Skip to main content

llm/providers/codex/
oauth.rs

1use crate::LlmError;
2use aether_auth::{
3    BrowserOAuthHandler, OAuthCredential, OAuthCredentialStorage, OAuthError, OAuthHandler, oauth_http_client,
4};
5use base64::Engine;
6use base64::engine::general_purpose::URL_SAFE_NO_PAD;
7use oauth2::basic::BasicClient;
8use oauth2::{AuthUrl, AuthorizationCode, ClientId, PkceCodeChallenge, RedirectUrl, TokenUrl};
9use std::sync::Arc;
10use tokio::sync::Mutex;
11use url::Url;
12
/// OAuth client ID registered with OpenAI for the Codex CLI flow.
const CLIENT_ID: &str = "app_EMoamEEZ73f0CkXaXp7hrann";
/// Authorization endpoint the user's browser is sent to.
const AUTHORIZE_URL: &str = "https://auth.openai.com/oauth/authorize";
/// Token endpoint used for the code exchange and for refreshes.
const TOKEN_URL: &str = "https://auth.openai.com/oauth/token";
/// Fixed loopback redirect URI; the local callback listener must bind its port.
const REDIRECT_URI: &str = "http://localhost:1455/auth/callback";
/// Space-separated scopes requested at login (`offline_access` asks for a refresh token).
const SCOPE: &str = "openid profile email offline_access";
18
19/// Run the full Codex OAuth flow: open browser, capture callback, exchange token, save credentials.
20///
21pub async fn perform_codex_oauth_flow(store: &dyn OAuthCredentialStorage) -> Result<(), LlmError> {
22    let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
23    let state = generate_random_state();
24
25    let auth_url = Url::parse_with_params(
26        AUTHORIZE_URL,
27        &[
28            ("response_type", "code"),
29            ("client_id", CLIENT_ID),
30            ("redirect_uri", REDIRECT_URI),
31            ("scope", SCOPE),
32            ("code_challenge", pkce_challenge.as_str()),
33            ("code_challenge_method", "S256"),
34            ("state", &state),
35            ("id_token_add_organizations", "true"),
36            ("codex_cli_simplified_flow", "true"),
37            ("originator", "codex_cli_rs"),
38        ],
39    )
40    .map_err(|e| OAuthError::TokenExchange(format!("Failed to build auth URL: {e}")))?;
41
42    // Port 1455 is hardcoded because the Codex API has a fixed redirect URI
43    // (http://localhost:1455/auth/callback) registered with OpenAI's OAuth server.
44    let handler = BrowserOAuthHandler::with_redirect_uri(REDIRECT_URI, 1455)?;
45    let callback = handler.authorize(auth_url.as_str()).await?;
46
47    if callback.state != state {
48        return Err(OAuthError::StateMismatch.into());
49    }
50
51    let oauth_client = BasicClient::new(ClientId::new(CLIENT_ID.to_string()))
52        .set_auth_uri(
53            AuthUrl::new(AUTHORIZE_URL.to_string())
54                .map_err(|e| OAuthError::TokenExchange(format!("invalid auth URL: {e}")))?,
55        )
56        .set_token_uri(
57            TokenUrl::new(TOKEN_URL.to_string())
58                .map_err(|e| OAuthError::TokenExchange(format!("invalid token URL: {e}")))?,
59        )
60        .set_redirect_uri(
61            RedirectUrl::new(REDIRECT_URI.to_string())
62                .map_err(|e| OAuthError::TokenExchange(format!("invalid redirect URI: {e}")))?,
63        );
64
65    let http_client = oauth_http_client()?;
66
67    let token_response = oauth_client
68        .exchange_code(AuthorizationCode::new(callback.code))
69        .set_pkce_verifier(pkce_verifier)
70        .request_async(&http_client)
71        .await
72        .map_err(|e| OAuthError::TokenExchange(e.to_string()))?;
73
74    let credential = OAuthCredential::from_token_response(CLIENT_ID.to_string(), &token_response);
75    store.save_credential(super::PROVIDER_ID, credential).await?;
76
77    Ok(())
78}
79
80/// In-memory cache of the most recently validated credential and its derived account ID.
81struct CachedToken {
82    credential: OAuthCredential,
83    account_id: String,
84}
85
86/// Manages OAuth tokens for the Codex backend API.
87///
88/// Holds an `Arc<dyn OAuthCredentialStorage>` so callers can swap in keyring-backed,
89/// file-backed, or in-memory stores without changing this type.
90pub struct CodexTokenManager {
91    store: Arc<dyn OAuthCredentialStorage>,
92    credential_key: String,
93    token_url: TokenUrl,
94    cached: Mutex<Option<CachedToken>>,
95}
96
97impl CodexTokenManager {
98    pub fn new(store: Arc<dyn OAuthCredentialStorage>, credential_key: &str) -> Self {
99        Self::new_with_token_url(
100            store,
101            credential_key,
102            TokenUrl::new(TOKEN_URL.to_string()).expect("hardcoded Codex token URL is valid"),
103        )
104    }
105
106    fn new_with_token_url(store: Arc<dyn OAuthCredentialStorage>, credential_key: &str, token_url: TokenUrl) -> Self {
107        Self { store, credential_key: credential_key.to_string(), token_url, cached: Mutex::new(None) }
108    }
109
110    /// Get a valid access token and account ID.
111    ///
112    /// Returns `(access_token, account_id)`. The account ID is extracted from
113    /// the JWT's `https://api.openai.com/auth` claim field `chatgpt_account_id`.
114    pub async fn get_valid_token(&self) -> Result<(String, String), LlmError> {
115        let mut cache = self.cached.lock().await;
116        if let Some(cached) = cache.as_ref()
117            && !cached.credential.needs_refresh()
118        {
119            return Ok((cached.credential.access_token.clone(), cached.account_id.clone()));
120        }
121
122        let credential = self.load_or_refresh().await?;
123        let account_id = extract_account_id(&credential.access_token)?;
124        let access_token = credential.access_token.clone();
125        *cache = Some(CachedToken { credential, account_id: account_id.clone() });
126        Ok((access_token, account_id))
127    }
128
129    async fn load_or_refresh(&self) -> Result<OAuthCredential, LlmError> {
130        let stored = self.store.load_credential(&self.credential_key).await?.ok_or_else(|| {
131            OAuthError::NoCredentials(
132                "No Codex OAuth credentials found. Run `aether` and select a codex model to trigger OAuth login."
133                    .to_string(),
134            )
135        })?;
136
137        if !stored.needs_refresh() {
138            return Ok(stored);
139        }
140
141        let refreshed = stored.refresh(&self.token_url).await?;
142        self.store.save_credential(&self.credential_key, refreshed.clone()).await?;
143        Ok(refreshed)
144    }
145
146    /// Clear the cached token (e.g. after a 401 response)
147    pub async fn clear_cache(&self) {
148        *self.cached.lock().await = None;
149    }
150}
151
152/// Extract the account ID from a JWT access token.
153///
154/// The JWT payload contains a claim at `https://api.openai.com/auth`
155/// with a `chatgpt_account_id` field.
156pub fn extract_account_id(access_token: &str) -> Result<String, LlmError> {
157    let parts: Vec<&str> = access_token.split('.').collect();
158    if parts.len() != 3 {
159        return Err(OAuthError::InvalidJwt("expected 3 dot-separated parts".to_string()).into());
160    }
161
162    let decoded = URL_SAFE_NO_PAD
163        .decode(parts[1])
164        .map_err(|e| OAuthError::InvalidJwt(format!("failed to decode payload: {e}")))?;
165
166    let payload: serde_json::Value = serde_json::from_slice(&decoded)
167        .map_err(|e| OAuthError::InvalidJwt(format!("failed to parse payload: {e}")))?;
168
169    let account_id = payload
170        .get("https://api.openai.com/auth")
171        .and_then(|auth| auth.get("chatgpt_account_id"))
172        .and_then(|v| v.as_str())
173        .ok_or_else(|| OAuthError::InvalidJwt("missing chatgpt_account_id in token".to_string()))?;
174
175    Ok(account_id.to_string())
176}
177
178fn generate_random_state() -> String {
179    uuid::Uuid::new_v4().to_string()
180}
181
#[cfg(test)]
mod tests {
    use super::*;
    use aether_auth::{FakeOAuthCredentialStore, OAuthCredential};
    use axum::Router;
    use axum::body::{Body, to_bytes};
    use axum::extract::State;
    use axum::http::{HeaderMap, Method, Request, StatusCode};
    use axum::response::IntoResponse;
    use axum::routing::post;
    use std::collections::HashMap;
    use tokio::net::TcpListener;
    use tokio::sync::{Mutex as TokioMutex, oneshot};

    /// Create a test JWT with a given payload
    fn make_test_jwt(payload: &serde_json::Value) -> String {
        let header = URL_SAFE_NO_PAD.encode(r#"{"alg":"RS256","typ":"JWT"}"#);
        let payload_json = serde_json::to_string(payload).unwrap();
        let payload_b64url = URL_SAFE_NO_PAD.encode(payload_json.as_bytes());
        format!("{header}.{payload_b64url}.fake_signature")
    }

    /// Token URL on the discard port; tests using it expect no refresh to be attempted
    /// (a refresh against it is expected to fail).
    fn unreachable_token_url() -> TokenUrl {
        TokenUrl::new("http://127.0.0.1:9/oauth/token".to_string()).unwrap()
    }

    #[test]
    fn extract_account_id_from_valid_jwt() {
        let payload = serde_json::json!({
            "sub": "user_123",
            "https://api.openai.com/auth": {
                "chatgpt_account_id": "acct_abc123"
            }
        });

        let jwt = make_test_jwt(&payload);
        let account_id = extract_account_id(&jwt).unwrap();
        assert_eq!(account_id, "acct_abc123");
    }

    #[test]
    fn extract_account_id_missing_claim() {
        let payload = serde_json::json!({
            "sub": "user_123"
        });

        let jwt = make_test_jwt(&payload);
        let result = extract_account_id(&jwt);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("chatgpt_account_id"));
    }

    #[test]
    fn extract_account_id_invalid_jwt_format() {
        let result = extract_account_id("not.a.valid.jwt.too.many.parts");
        assert!(result.is_err());

        let result = extract_account_id("toofewparts");
        assert!(result.is_err());
    }

    #[test]
    fn extract_account_id_invalid_base64() {
        let result = extract_account_id("header.!!!invalid!!!.signature");
        assert!(result.is_err());
    }

    #[test]
    fn auth_url_is_well_formed() {
        let (pkce_challenge, _) = PkceCodeChallenge::new_random_sha256();
        let state = "test-state";

        let auth_url = Url::parse_with_params(
            AUTHORIZE_URL,
            &[
                ("response_type", "code"),
                ("client_id", CLIENT_ID),
                ("redirect_uri", REDIRECT_URI),
                ("scope", SCOPE),
                ("code_challenge", pkce_challenge.as_str()),
                ("code_challenge_method", "S256"),
                ("state", state),
                ("id_token_add_organizations", "true"),
                ("codex_cli_simplified_flow", "true"),
                ("originator", "codex_cli_rs"),
            ],
        )
        .unwrap();

        assert!(auth_url.as_str().starts_with(AUTHORIZE_URL));

        // Check decoded values exactly rather than substring matches on the raw URL.
        let query: HashMap<String, String> = auth_url.query_pairs().into_owned().collect();
        assert_eq!(query.get("response_type").map(String::as_str), Some("code"));
        assert_eq!(query.get("client_id").map(String::as_str), Some(CLIENT_ID));
        assert_eq!(query.get("redirect_uri").map(String::as_str), Some(REDIRECT_URI));
        assert_eq!(query.get("scope").map(String::as_str), Some(SCOPE));
        assert_eq!(query.get("code_challenge").map(String::as_str), Some(pkce_challenge.as_str()));
        assert_eq!(query.get("code_challenge_method").map(String::as_str), Some("S256"));
        assert_eq!(query.get("state").map(String::as_str), Some(state));
    }

    #[test]
    fn generate_random_state_is_valid_uuid() {
        let state = generate_random_state();
        assert!(!state.is_empty());
        assert!(uuid::Uuid::parse_str(&state).is_ok());
    }

    #[test]
    fn oauth_constants_are_valid() {
        assert!(AUTHORIZE_URL.starts_with("https://"));
        assert!(TOKEN_URL.starts_with("https://"));
        assert!(REDIRECT_URI.starts_with("http://localhost:"));
        assert!(SCOPE.contains("openid"));
    }

    #[tokio::test]
    async fn codex_token_manager_refreshes_expired_credential() {
        let new_access_token = test_jwt_for_account("acct_new");
        let endpoint = FakeTokenEndpoint::start(TokenEndpointResponse::success(&new_access_token, None)).await;
        let store = Arc::new(
            FakeOAuthCredentialStore::new()
                .with_credential("codex", expired_credential("old-access", Some("refresh-old"))),
        );

        let manager = CodexTokenManager::new_with_token_url(store.clone(), "codex", endpoint.url.clone());
        let (access_token, account_id) = manager.get_valid_token().await.unwrap();
        let request = endpoint.request.await.expect("token endpoint request");
        let saved = store.load_credential("codex").await.unwrap().unwrap();

        assert_eq!(access_token, new_access_token);
        assert_eq!(account_id, "acct_new");
        assert_eq!(saved.access_token, new_access_token);
        assert_eq!(saved.refresh_token.as_deref(), Some("refresh-old"));
        assert_eq!(request.method, Method::POST);
        assert_eq!(request.path, "/oauth/token");
        assert_eq!(request.form.get("grant_type").map(String::as_str), Some("refresh_token"));
        assert_eq!(request.form.get("refresh_token").map(String::as_str), Some("refresh-old"));
        assert_eq!(request.form.get("client_id").map(String::as_str), Some(CLIENT_ID));
        assert!(request.headers.get("accept").is_some());
    }

    #[tokio::test]
    async fn codex_token_manager_saves_rotated_refresh_token() {
        let new_access_token = test_jwt_for_account("acct_new");
        let endpoint =
            FakeTokenEndpoint::start(TokenEndpointResponse::success(&new_access_token, Some("refresh-new"))).await;
        let store = Arc::new(
            FakeOAuthCredentialStore::new()
                .with_credential("codex", expired_credential("old-access", Some("refresh-old"))),
        );
        let manager = CodexTokenManager::new_with_token_url(store.clone(), "codex", endpoint.url.clone());
        manager.get_valid_token().await.unwrap();
        let saved = store.load_credential("codex").await.unwrap().unwrap();

        assert_eq!(saved.access_token, new_access_token);
        assert_eq!(saved.refresh_token.as_deref(), Some("refresh-new"));
    }

    #[tokio::test]
    async fn codex_token_manager_uses_unexpired_credential_without_refresh() {
        let access_token = test_jwt_for_account("acct_existing");
        let store =
            Arc::new(FakeOAuthCredentialStore::new().with_credential("codex", unexpired_credential(&access_token)));

        let manager = CodexTokenManager::new_with_token_url(store, "codex", unreachable_token_url());

        let (returned_token, account_id) = manager.get_valid_token().await.unwrap();
        assert_eq!(returned_token, access_token);
        assert_eq!(account_id, "acct_existing");
    }

    #[tokio::test]
    async fn codex_token_manager_serves_cached_token_until_cleared() {
        let first_token = test_jwt_for_account("acct_first");
        let second_token = test_jwt_for_account("acct_second");
        let store =
            Arc::new(FakeOAuthCredentialStore::new().with_credential("codex", unexpired_credential(&first_token)));
        let manager = CodexTokenManager::new_with_token_url(store.clone(), "codex", unreachable_token_url());

        assert_eq!(manager.get_valid_token().await.unwrap(), (first_token.clone(), "acct_first".to_string()));

        // Replace the stored credential behind the manager's back: the fresh cached copy wins.
        store.save_credential("codex", unexpired_credential(&second_token)).await.unwrap();
        assert_eq!(manager.get_valid_token().await.unwrap(), (first_token, "acct_first".to_string()));

        // Once the cache is cleared, the manager re-reads the store.
        manager.clear_cache().await;
        assert_eq!(manager.get_valid_token().await.unwrap(), (second_token, "acct_second".to_string()));
    }

    #[tokio::test]
    async fn codex_token_manager_errors_when_credential_is_missing() {
        let store = Arc::new(FakeOAuthCredentialStore::new());
        let manager = CodexTokenManager::new_with_token_url(store, "codex", unreachable_token_url());

        let error = manager.get_valid_token().await.unwrap_err();
        assert!(error.to_string().contains("No Codex OAuth credentials found"));
        assert!(error.to_string().contains("select a codex model"));
    }

    #[tokio::test]
    async fn codex_token_manager_errors_when_expired_without_refresh_token() {
        let original = expired_credential("old-access", None);
        let store = Arc::new(FakeOAuthCredentialStore::new().with_credential("codex", original.clone()));
        let manager = CodexTokenManager::new_with_token_url(store.clone(), "codex", unreachable_token_url());

        let error = manager.get_valid_token().await.unwrap_err();
        let saved = store.load_credential("codex").await.unwrap().unwrap();

        assert!(error.to_string().contains("Re-run OAuth login"));
        assert_eq!(saved.access_token, original.access_token);
        assert_eq!(saved.refresh_token, original.refresh_token);
    }

    #[tokio::test]
    async fn codex_token_manager_does_not_overwrite_credential_when_refresh_fails() {
        let endpoint = FakeTokenEndpoint::start(TokenEndpointResponse::failure()).await;
        let original = expired_credential("old-access", Some("refresh-old"));
        let store = Arc::new(FakeOAuthCredentialStore::new().with_credential("codex", original.clone()));
        let manager = CodexTokenManager::new_with_token_url(store.clone(), "codex", endpoint.url.clone());

        let result = manager.get_valid_token().await;
        let saved = store.load_credential("codex").await.unwrap().unwrap();

        assert!(result.is_err());
        assert_eq!(saved.access_token, original.access_token);
        assert_eq!(saved.refresh_token, original.refresh_token);
    }

    /// A one-shot local HTTP server standing in for the OAuth token endpoint.
    struct FakeTokenEndpoint {
        url: TokenUrl,
        request: oneshot::Receiver<CapturedTokenRequest>,
    }

    /// The first request the fake endpoint received.
    struct CapturedTokenRequest {
        method: Method,
        path: String,
        headers: HeaderMap,
        form: HashMap<String, String>,
    }

    #[derive(Clone)]
    struct FakeTokenState {
        response: TokenEndpointResponse,
        request_tx: Arc<TokioMutex<Option<oneshot::Sender<CapturedTokenRequest>>>>,
        shutdown_tx: Arc<TokioMutex<Option<oneshot::Sender<()>>>>,
    }

    /// Canned status + JSON body the fake endpoint replies with.
    #[derive(Clone)]
    struct TokenEndpointResponse {
        status: StatusCode,
        body: serde_json::Value,
    }

    impl TokenEndpointResponse {
        fn success(access_token: &str, refresh_token: Option<&str>) -> Self {
            let mut body = serde_json::json!({
                "access_token": access_token,
                "token_type": "Bearer",
                "expires_in": 3600
            });
            if let Some(refresh_token) = refresh_token {
                body["refresh_token"] = serde_json::Value::String(refresh_token.to_string());
            }
            Self { status: StatusCode::OK, body }
        }

        fn failure() -> Self {
            Self { status: StatusCode::BAD_REQUEST, body: serde_json::json!({ "error": "invalid_grant" }) }
        }
    }

    impl FakeTokenEndpoint {
        async fn start(response: TokenEndpointResponse) -> Self {
            let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind fake token endpoint");
            let url = TokenUrl::new(format!(
                "http://{}/oauth/token",
                listener.local_addr().expect("fake token endpoint address")
            ))
            .expect("fake token endpoint URL is valid");
            let (request_tx, request) = oneshot::channel();
            let (shutdown_tx, shutdown) = oneshot::channel();
            let state = FakeTokenState {
                response,
                request_tx: Arc::new(TokioMutex::new(Some(request_tx))),
                shutdown_tx: Arc::new(TokioMutex::new(Some(shutdown_tx))),
            };
            let app = Router::new().route("/oauth/token", post(capture_token_request)).with_state(state);
            tokio::spawn(async move {
                axum::serve(listener, app)
                    .with_graceful_shutdown(async {
                        let _ = shutdown.await;
                    })
                    .await
                    .expect("serve fake token endpoint");
            });
            Self { url, request }
        }
    }

    /// Records the incoming request, triggers graceful shutdown, and replies with the canned response.
    async fn capture_token_request(State(state): State<FakeTokenState>, request: Request<Body>) -> impl IntoResponse {
        let (parts, body) = request.into_parts();
        let body = to_bytes(body, usize::MAX).await.expect("read token request body");
        let form = url::form_urlencoded::parse(&body).into_owned().collect();
        if let Some(tx) = state.request_tx.lock().await.take() {
            let _ = tx.send(CapturedTokenRequest {
                method: parts.method,
                path: parts.uri.path().to_string(),
                headers: parts.headers,
                form,
            });
        }
        if let Some(tx) = state.shutdown_tx.lock().await.take() {
            let _ = tx.send(());
        }
        (state.response.status, axum::Json(state.response.body))
    }

    fn expired_credential(access_token: &str, refresh_token: Option<&str>) -> OAuthCredential {
        OAuthCredential {
            client_id: CLIENT_ID.to_string(),
            access_token: access_token.to_string(),
            refresh_token: refresh_token.map(str::to_string),
            expires_at: Some(0),
            granted_scopes: Vec::new(),
        }
    }

    fn unexpired_credential(access_token: &str) -> OAuthCredential {
        OAuthCredential {
            client_id: CLIENT_ID.to_string(),
            access_token: access_token.to_string(),
            refresh_token: Some("refresh-old".to_string()),
            expires_at: Some(u64::MAX),
            granted_scopes: Vec::new(),
        }
    }

    fn test_jwt_for_account(account_id: &str) -> String {
        make_test_jwt(&serde_json::json!({
            "https://api.openai.com/auth": {
                "chatgpt_account_id": account_id
            }
        }))
    }
}