openai-auth 1.0.0

OpenAI/ChatGPT OAuth 2.0 authentication with PKCE - sync and async APIs
Documentation
use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode};
use serde::{Deserialize, Serialize};

use crate::{OpenAIAuthError, Result, types::SessionData};

/// Contents of the OpenAI-specific `https://api.openai.com/auth` JWT claim.
///
/// Every field is optional, so a claim that omits any of them still
/// deserializes. The Rust field names are used verbatim as the JSON keys.
#[derive(Debug, Serialize, Deserialize)]
struct OpenAIAuth {
    /// Taken from the access token by `extract_session_data`.
    chatgpt_account_id: Option<String>,
    /// Taken from the id token by `extract_session_data`.
    organization_id: Option<String>,
    /// Taken from the id token by `extract_session_data`.
    project_id: Option<String>,
    /// Taken from the id token by `extract_session_data`; `false` when absent.
    completed_platform_onboarding: Option<bool>,
    /// Taken from the id token by `extract_session_data`; `false` when absent.
    is_org_owner: Option<bool>,
    /// Taken from the access token by `extract_session_data`.
    chatgpt_plan_type: Option<String>,
}

/// The subset of a JWT payload this module reads.
///
/// Only the OpenAI namespace claim is deserialized; every other claim in the
/// payload is ignored. A payload without that claim yields
/// `openai_auth: None` rather than an error.
#[derive(Debug, Serialize, Deserialize)]
struct Claims {
    /// The `https://api.openai.com/auth` claim, if present.
    #[serde(rename = "https://api.openai.com/auth")]
    openai_auth: Option<OpenAIAuth>,
}

/// Decode a JWT's payload claims without verifying the token.
///
/// This performs **no** security checks: the signature, expiry, audience and
/// presence of registered claims (`exp`, ...) are all ignored. Callers must
/// only use the result for tokens obtained directly from OpenAI's OAuth flow,
/// never to make trust decisions about tokens from other sources.
///
/// # Errors
///
/// Returns an error if `token` cannot be decoded as a JWT (wrong segment
/// count, bad encoding, or a payload that does not deserialize into
/// [`Claims`]).
fn decode_jwt_claims(token: &str) -> Result<Claims> {
    let mut validation = Validation::new(Algorithm::RS256);
    validation.insecure_disable_signature_validation();
    // jsonwebtoken still runs claim validation with signature checks off, so
    // every check that could reject an otherwise readable token is disabled:
    // expired tokens are still decoded...
    validation.validate_exp = false;
    // ...tokens carrying an `aud` claim are not rejected for lack of a
    // configured expected audience (OpenAI tokens include `aud`)...
    validation.validate_aud = false;
    // ...and `Validation::new`'s default requirement that `exp` be present
    // is dropped.
    validation.required_spec_claims.clear();

    let token_data = decode::<Claims>(token, &DecodingKey::from_secret(&[]), &validation)?;
    Ok(token_data.claims)
}

/// Extract the ChatGPT account ID from an access-token JWT.
///
/// The token is decoded without signature verification (it is assumed to
/// come straight from OpenAI's OAuth flow). The ID is read from the
/// `chatgpt_account_id` field of the `https://api.openai.com/auth` claim.
///
/// # Arguments
///
/// * `token` - The JWT access token from OpenAI
///
/// # Returns
///
/// The ChatGPT account ID as a string
///
/// # Errors
///
/// Returns an error if:
/// - The JWT cannot be decoded
/// - The `https://api.openai.com/auth` claim, or its `chatgpt_account_id`
///   field, is missing (reported as `OpenAIAuthError::MissingJwtClaim`)
pub fn extract_account_id(token: &str) -> Result<String> {
    let openai_claim = decode_jwt_claims(token)?.openai_auth;
    let account_id = openai_claim.and_then(|auth| auth.chatgpt_account_id);

    match account_id {
        Some(id) => Ok(id),
        None => Err(OpenAIAuthError::MissingJwtClaim(
            "chatgpt_account_id".to_string(),
        )),
    }
}

/// Extract session data from an OAuth `id_token` and `access_token`.
///
/// Both tokens are decoded without signature verification.
///
/// - From the `id_token`: organization ID, project ID, onboarding status and
///   org-owner flag.
/// - From the `access_token`: plan type and ChatGPT account ID.
///
/// A field missing from its token becomes `None`, and the two flags become
/// `false`. The same happens when the whole `https://api.openai.com/auth`
/// claim is absent.
///
/// # Arguments
///
/// * `id_token` - The id_token from OAuth flow
/// * `access_token` - The access_token from OAuth flow
///
/// # Returns
///
/// SessionData containing organization, project, and user information
///
/// # Errors
///
/// Returns an error if either token cannot be decoded as a JWT; the
/// `id_token` is decoded first. Missing claims are not an error.
pub fn extract_session_data(id_token: &str, access_token: &str) -> Result<SessionData> {
    let id_auth = decode_jwt_claims(id_token)?.openai_auth;
    let access_auth = decode_jwt_claims(access_token)?.openai_auth;

    // Match on the optional claim blocks directly instead of building
    // all-`None` placeholder structs. A missing block yields the same defaults
    // as individually missing fields.
    let (organization_id, project_id, completed_platform_onboarding, is_org_owner) =
        match id_auth {
            Some(auth) => (
                auth.organization_id,
                auth.project_id,
                auth.completed_platform_onboarding.unwrap_or(false),
                auth.is_org_owner.unwrap_or(false),
            ),
            None => (None, None, false, false),
        };

    let (chatgpt_plan_type, chatgpt_account_id) = match access_auth {
        Some(auth) => (auth.chatgpt_plan_type, auth.chatgpt_account_id),
        None => (None, None),
    };

    Ok(SessionData {
        organization_id,
        project_id,
        completed_platform_onboarding,
        is_org_owner,
        chatgpt_plan_type,
        chatgpt_account_id,
    })
}

#[cfg(test)]
mod tests {
    use super::*;

    /// A well-formed token whose payload is `{"sub":"1234567890"}`: it has no
    /// `https://api.openai.com/auth` claim (and no `exp`).
    const TOKEN_WITHOUT_OPENAI_CLAIM: &str =
        "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.invalid";

    #[test]
    fn test_extract_account_id_missing_claim() {
        // A token without the required claim should return an error
        let result = extract_account_id(TOKEN_WITHOUT_OPENAI_CLAIM);
        assert!(result.is_err());
    }

    #[test]
    fn test_extract_account_id_malformed_token() {
        // Strings that are not dot-separated JWT segments cannot be decoded.
        assert!(extract_account_id("").is_err());
        assert!(extract_account_id("not-a-jwt").is_err());
    }

    #[test]
    fn test_extract_session_data_malformed_token() {
        // A decode failure in either token must surface as an error.
        assert!(extract_session_data("not-a-jwt", TOKEN_WITHOUT_OPENAI_CLAIM).is_err());
        assert!(extract_session_data(TOKEN_WITHOUT_OPENAI_CLAIM, "not-a-jwt").is_err());
    }
}