1use crate::LlmError;
2use aether_auth::{
3 BrowserOAuthHandler, OAuthCredential, OAuthCredentialStorage, OAuthError, OAuthHandler, oauth_http_client,
4};
5use base64::Engine;
6use base64::engine::general_purpose::URL_SAFE_NO_PAD;
7use oauth2::basic::BasicClient;
8use oauth2::{AuthUrl, AuthorizationCode, ClientId, PkceCodeChallenge, RedirectUrl, TokenUrl};
9use std::sync::Arc;
10use tokio::sync::Mutex;
11use url::Url;
12
/// OAuth client id used in the authorization request, the code exchange, and the
/// refresh request; it is also recorded on the saved `OAuthCredential`.
const CLIENT_ID: &str = "app_EMoamEEZ73f0CkXaXp7hrann";
/// OpenAI authorization endpoint the user is sent to in order to log in.
const AUTHORIZE_URL: &str = "https://auth.openai.com/oauth/authorize";
/// OpenAI token endpoint used for the authorization-code exchange and, by default,
/// for refreshing credentials in `CodexTokenManager`.
const TOKEN_URL: &str = "https://auth.openai.com/oauth/token";
/// Loopback redirect URI registered for the login flow. Its port must agree with the
/// port literal (1455) passed to `BrowserOAuthHandler::with_redirect_uri`.
const REDIRECT_URI: &str = "http://localhost:1455/auth/callback";
/// Space-separated scopes requested at login (`offline_access` is the OpenID Connect
/// scope for requesting a refresh token).
const SCOPE: &str = "openid profile email offline_access";
18
19pub async fn perform_codex_oauth_flow(store: &dyn OAuthCredentialStorage) -> Result<(), LlmError> {
22 let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
23 let state = generate_random_state();
24
25 let auth_url = Url::parse_with_params(
26 AUTHORIZE_URL,
27 &[
28 ("response_type", "code"),
29 ("client_id", CLIENT_ID),
30 ("redirect_uri", REDIRECT_URI),
31 ("scope", SCOPE),
32 ("code_challenge", pkce_challenge.as_str()),
33 ("code_challenge_method", "S256"),
34 ("state", &state),
35 ("id_token_add_organizations", "true"),
36 ("codex_cli_simplified_flow", "true"),
37 ("originator", "codex_cli_rs"),
38 ],
39 )
40 .map_err(|e| OAuthError::TokenExchange(format!("Failed to build auth URL: {e}")))?;
41
42 let handler = BrowserOAuthHandler::with_redirect_uri(REDIRECT_URI, 1455)?;
45 let callback = handler.authorize(auth_url.as_str()).await?;
46
47 if callback.state != state {
48 return Err(OAuthError::StateMismatch.into());
49 }
50
51 let oauth_client = BasicClient::new(ClientId::new(CLIENT_ID.to_string()))
52 .set_auth_uri(
53 AuthUrl::new(AUTHORIZE_URL.to_string())
54 .map_err(|e| OAuthError::TokenExchange(format!("invalid auth URL: {e}")))?,
55 )
56 .set_token_uri(
57 TokenUrl::new(TOKEN_URL.to_string())
58 .map_err(|e| OAuthError::TokenExchange(format!("invalid token URL: {e}")))?,
59 )
60 .set_redirect_uri(
61 RedirectUrl::new(REDIRECT_URI.to_string())
62 .map_err(|e| OAuthError::TokenExchange(format!("invalid redirect URI: {e}")))?,
63 );
64
65 let http_client = oauth_http_client()?;
66
67 let token_response = oauth_client
68 .exchange_code(AuthorizationCode::new(callback.code))
69 .set_pkce_verifier(pkce_verifier)
70 .request_async(&http_client)
71 .await
72 .map_err(|e| OAuthError::TokenExchange(e.to_string()))?;
73
74 let credential = OAuthCredential::from_token_response(CLIENT_ID.to_string(), &token_response);
75 store.save_credential(super::PROVIDER_ID, credential).await?;
76
77 Ok(())
78}
79
80struct CachedToken {
82 credential: OAuthCredential,
83 account_id: String,
84}
85
86pub struct CodexTokenManager {
91 store: Arc<dyn OAuthCredentialStorage>,
92 credential_key: String,
93 token_url: TokenUrl,
94 cached: Mutex<Option<CachedToken>>,
95}
96
97impl CodexTokenManager {
98 pub fn new(store: Arc<dyn OAuthCredentialStorage>, credential_key: &str) -> Self {
99 Self::new_with_token_url(
100 store,
101 credential_key,
102 TokenUrl::new(TOKEN_URL.to_string()).expect("hardcoded Codex token URL is valid"),
103 )
104 }
105
106 fn new_with_token_url(store: Arc<dyn OAuthCredentialStorage>, credential_key: &str, token_url: TokenUrl) -> Self {
107 Self { store, credential_key: credential_key.to_string(), token_url, cached: Mutex::new(None) }
108 }
109
110 pub async fn get_valid_token(&self) -> Result<(String, String), LlmError> {
115 let mut cache = self.cached.lock().await;
116 if let Some(cached) = cache.as_ref()
117 && !cached.credential.needs_refresh()
118 {
119 return Ok((cached.credential.access_token.clone(), cached.account_id.clone()));
120 }
121
122 let credential = self.load_or_refresh().await?;
123 let account_id = extract_account_id(&credential.access_token)?;
124 let access_token = credential.access_token.clone();
125 *cache = Some(CachedToken { credential, account_id: account_id.clone() });
126 Ok((access_token, account_id))
127 }
128
129 async fn load_or_refresh(&self) -> Result<OAuthCredential, LlmError> {
130 let stored = self.store.load_credential(&self.credential_key).await?.ok_or_else(|| {
131 OAuthError::NoCredentials(
132 "No Codex OAuth credentials found. Run `aether` and select a codex model to trigger OAuth login."
133 .to_string(),
134 )
135 })?;
136
137 if !stored.needs_refresh() {
138 return Ok(stored);
139 }
140
141 let refreshed = stored.refresh(&self.token_url).await?;
142 self.store.save_credential(&self.credential_key, refreshed.clone()).await?;
143 Ok(refreshed)
144 }
145
146 pub async fn clear_cache(&self) {
148 *self.cached.lock().await = None;
149 }
150}
151
152pub fn extract_account_id(access_token: &str) -> Result<String, LlmError> {
157 let parts: Vec<&str> = access_token.split('.').collect();
158 if parts.len() != 3 {
159 return Err(OAuthError::InvalidJwt("expected 3 dot-separated parts".to_string()).into());
160 }
161
162 let decoded = URL_SAFE_NO_PAD
163 .decode(parts[1])
164 .map_err(|e| OAuthError::InvalidJwt(format!("failed to decode payload: {e}")))?;
165
166 let payload: serde_json::Value = serde_json::from_slice(&decoded)
167 .map_err(|e| OAuthError::InvalidJwt(format!("failed to parse payload: {e}")))?;
168
169 let account_id = payload
170 .get("https://api.openai.com/auth")
171 .and_then(|auth| auth.get("chatgpt_account_id"))
172 .and_then(|v| v.as_str())
173 .ok_or_else(|| OAuthError::InvalidJwt("missing chatgpt_account_id in token".to_string()))?;
174
175 Ok(account_id.to_string())
176}
177
178fn generate_random_state() -> String {
179 uuid::Uuid::new_v4().to_string()
180}
181
#[cfg(test)]
mod tests {
    use super::*;
    use aether_auth::{FakeOAuthCredentialStore, OAuthCredential};
    use axum::Router;
    use axum::body::{Body, to_bytes};
    use axum::extract::State;
    use axum::http::{HeaderMap, Method, Request, StatusCode};
    use axum::response::IntoResponse;
    use axum::routing::post;
    use std::collections::HashMap;
    use tokio::net::TcpListener;
    use tokio::sync::{Mutex as TokioMutex, oneshot};

    /// Builds an unsigned JWT-shaped string `header.payload.fake_signature`.
    ///
    /// The payload segment is `payload` serialized to JSON and base64url-encoded
    /// without padding, which is what `extract_account_id` decodes.
    fn make_test_jwt(payload: &serde_json::Value) -> String {
        let header = URL_SAFE_NO_PAD.encode(r#"{"alg":"RS256","typ":"JWT"}"#);
        let payload_json = serde_json::to_string(payload).unwrap();
        let payload_b64url = URL_SAFE_NO_PAD.encode(payload_json.as_bytes());
        format!("{header}.{payload_b64url}.fake_signature")
    }

    /// The account id is read from the nested `https://api.openai.com/auth` claim.
    #[test]
    fn extract_account_id_from_valid_jwt() {
        let payload = serde_json::json!({
            "sub": "user_123",
            "https://api.openai.com/auth": {
                "chatgpt_account_id": "acct_abc123"
            }
        });

        let jwt = make_test_jwt(&payload);
        let account_id = extract_account_id(&jwt).unwrap();
        assert_eq!(account_id, "acct_abc123");
    }

    /// A well-formed token without the claim fails with an error that names the claim.
    #[test]
    fn extract_account_id_missing_claim() {
        let payload = serde_json::json!({
            "sub": "user_123"
        });

        let jwt = make_test_jwt(&payload);
        let result = extract_account_id(&jwt);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("chatgpt_account_id"));
    }

    /// Both too many and too few dot-separated segments are rejected.
    #[test]
    fn extract_account_id_invalid_jwt_format() {
        let result = extract_account_id("not.a.valid.jwt.too.many.parts");
        assert!(result.is_err());

        let result = extract_account_id("toofewparts");
        assert!(result.is_err());
    }

    /// A payload segment containing characters outside the base64url alphabet is rejected.
    #[test]
    fn extract_account_id_invalid_base64() {
        let result = extract_account_id("header.!!!invalid!!!.signature");
        assert!(result.is_err());
    }

    // NOTE(review): this test rebuilds the query-parameter list by hand instead of
    // exercising `perform_codex_oauth_flow`, so it will not catch drift between the
    // two lists.
    #[test]
    fn auth_url_is_well_formed() {
        let (pkce_challenge, _) = PkceCodeChallenge::new_random_sha256();
        let state = "test-state";

        let auth_url = Url::parse_with_params(
            AUTHORIZE_URL,
            &[
                ("response_type", "code"),
                ("client_id", CLIENT_ID),
                ("redirect_uri", REDIRECT_URI),
                ("scope", SCOPE),
                ("code_challenge", pkce_challenge.as_str()),
                ("code_challenge_method", "S256"),
                ("state", state),
                ("id_token_add_organizations", "true"),
                ("codex_cli_simplified_flow", "true"),
                ("originator", "codex_cli_rs"),
            ],
        )
        .unwrap();

        let url_str = auth_url.as_str();
        assert!(url_str.starts_with(AUTHORIZE_URL));
        assert!(url_str.contains("client_id="));
        assert!(url_str.contains("redirect_uri="));
        assert!(url_str.contains("scope="));
        assert!(url_str.contains("code_challenge="));
        assert!(url_str.contains("state=test-state"));
    }

    /// The generated `state` is non-empty and parses as a UUID.
    #[test]
    fn generate_random_state_is_valid_uuid() {
        let state = generate_random_state();
        assert!(!state.is_empty());
        assert!(uuid::Uuid::parse_str(&state).is_ok());
    }

    /// Sanity checks on the hard-coded endpoints and scope string.
    #[test]
    fn oauth_constants_are_valid() {
        assert!(AUTHORIZE_URL.starts_with("https://"));
        assert!(TOKEN_URL.starts_with("https://"));
        assert!(REDIRECT_URI.starts_with("http://localhost:"));
        assert!(SCOPE.contains("openid"));
    }

    /// An expired stored credential is refreshed through the token endpoint.
    ///
    /// The endpoint returns no `refresh_token`, so the saved credential must keep the
    /// old one. The test also checks the shape of the refresh request itself.
    #[tokio::test]
    async fn codex_token_manager_refreshes_expired_credential() {
        let new_access_token = test_jwt_for_account("acct_new");
        let endpoint = FakeTokenEndpoint::start(TokenEndpointResponse::success(&new_access_token, None)).await;
        let store = Arc::new(
            FakeOAuthCredentialStore::new()
                .with_credential("codex", expired_credential("old-access", Some("refresh-old"))),
        );

        let manager = CodexTokenManager::new_with_token_url(store.clone(), "codex", endpoint.url.clone());
        let (access_token, account_id) = manager.get_valid_token().await.unwrap();
        let request = endpoint.request.await.expect("token endpoint request");
        let saved = store.load_credential("codex").await.unwrap().unwrap();

        assert_eq!(access_token, new_access_token);
        assert_eq!(account_id, "acct_new");
        assert_eq!(saved.access_token, new_access_token);
        assert_eq!(saved.refresh_token.as_deref(), Some("refresh-old"));
        assert_eq!(request.method, Method::POST);
        assert_eq!(request.path, "/oauth/token");
        assert_eq!(request.form.get("grant_type").map(String::as_str), Some("refresh_token"));
        assert_eq!(request.form.get("refresh_token").map(String::as_str), Some("refresh-old"));
        assert_eq!(request.form.get("client_id").map(String::as_str), Some(CLIENT_ID));
        assert!(request.headers.get("accept").is_some());
    }

    /// When the endpoint rotates the refresh token, the new one is persisted.
    #[tokio::test]
    async fn codex_token_manager_saves_rotated_refresh_token() {
        let new_access_token = test_jwt_for_account("acct_new");
        let endpoint =
            FakeTokenEndpoint::start(TokenEndpointResponse::success(&new_access_token, Some("refresh-new"))).await;
        let store = Arc::new(
            FakeOAuthCredentialStore::new()
                .with_credential("codex", expired_credential("old-access", Some("refresh-old"))),
        );
        let manager = CodexTokenManager::new_with_token_url(store.clone(), "codex", endpoint.url.clone());
        manager.get_valid_token().await.unwrap();
        let saved = store.load_credential("codex").await.unwrap().unwrap();

        assert_eq!(saved.access_token, new_access_token);
        assert_eq!(saved.refresh_token.as_deref(), Some("refresh-new"));
    }

    /// A credential that never expires is returned as-is.
    ///
    /// No endpoint is started; the token URL points at 127.0.0.1:9, which presumably
    /// has no listener, so an unexpected refresh attempt would make the call fail.
    #[tokio::test]
    async fn codex_token_manager_uses_unexpired_credential_without_refresh() {
        let access_token = test_jwt_for_account("acct_existing");
        let store = Arc::new(FakeOAuthCredentialStore::new().with_credential(
            "codex",
            OAuthCredential {
                client_id: CLIENT_ID.to_string(),
                access_token: access_token.clone(),
                refresh_token: Some("refresh-old".to_string()),
                expires_at: Some(u64::MAX),
                granted_scopes: Vec::new(),
            },
        ));

        let manager = CodexTokenManager::new_with_token_url(
            store,
            "codex",
            TokenUrl::new("http://127.0.0.1:9/oauth/token".to_string()).unwrap(),
        );

        let (returned_token, account_id) = manager.get_valid_token().await.unwrap();
        assert_eq!(returned_token, access_token);
        assert_eq!(account_id, "acct_existing");
    }

    /// An empty store produces the user-facing "run the login" error.
    #[tokio::test]
    async fn codex_token_manager_errors_when_credential_is_missing() {
        let store = Arc::new(FakeOAuthCredentialStore::new());
        let manager = CodexTokenManager::new_with_token_url(
            store,
            "codex",
            TokenUrl::new("http://127.0.0.1:9/oauth/token".to_string()).unwrap(),
        );

        let error = manager.get_valid_token().await.unwrap_err();
        assert!(error.to_string().contains("No Codex OAuth credentials found"));
        assert!(error.to_string().contains("select a codex model"));
    }

    /// An expired credential without a refresh token errors and leaves the store unchanged.
    #[tokio::test]
    async fn codex_token_manager_errors_when_expired_without_refresh_token() {
        let original = expired_credential("old-access", None);
        let store = Arc::new(FakeOAuthCredentialStore::new().with_credential("codex", original.clone()));
        let manager = CodexTokenManager::new_with_token_url(
            store.clone(),
            "codex",
            TokenUrl::new("http://127.0.0.1:9/oauth/token".to_string()).unwrap(),
        );

        let error = manager.get_valid_token().await.unwrap_err();
        let saved = store.load_credential("codex").await.unwrap().unwrap();

        assert!(error.to_string().contains("Re-run OAuth login"));
        assert_eq!(saved.access_token, original.access_token);
        assert_eq!(saved.refresh_token, original.refresh_token);
    }

    /// A failed refresh (400 `invalid_grant`) must not overwrite the stored credential.
    #[tokio::test]
    async fn codex_token_manager_does_not_overwrite_credential_when_refresh_fails() {
        let endpoint = FakeTokenEndpoint::start(TokenEndpointResponse::failure()).await;
        let original = expired_credential("old-access", Some("refresh-old"));
        let store = Arc::new(FakeOAuthCredentialStore::new().with_credential("codex", original.clone()));
        let manager = CodexTokenManager::new_with_token_url(store.clone(), "codex", endpoint.url.clone());

        let result = manager.get_valid_token().await;
        let saved = store.load_credential("codex").await.unwrap().unwrap();

        assert!(result.is_err());
        assert_eq!(saved.access_token, original.access_token);
        assert_eq!(saved.refresh_token, original.refresh_token);
    }

    /// Handle to a local fake token endpoint.
    struct FakeTokenEndpoint {
        /// URL of the fake `/oauth/token` route on an ephemeral 127.0.0.1 port.
        url: TokenUrl,
        /// Resolves with the first request the endpoint receives.
        request: oneshot::Receiver<CapturedTokenRequest>,
    }

    /// Parts of an incoming token request recorded for assertions.
    struct CapturedTokenRequest {
        method: Method,
        path: String,
        headers: HeaderMap,
        /// Body decoded as `application/x-www-form-urlencoded` pairs.
        form: HashMap<String, String>,
    }

    /// Shared axum state for the fake endpoint.
    ///
    /// The senders sit in `Option`s so the handler can `take()` them and fire each
    /// one at most once.
    #[derive(Clone)]
    struct FakeTokenState {
        response: TokenEndpointResponse,
        request_tx: Arc<TokioMutex<Option<oneshot::Sender<CapturedTokenRequest>>>>,
        shutdown_tx: Arc<TokioMutex<Option<oneshot::Sender<()>>>>,
    }

    /// Canned HTTP response the fake endpoint returns.
    #[derive(Clone)]
    struct TokenEndpointResponse {
        status: StatusCode,
        body: serde_json::Value,
    }

    impl TokenEndpointResponse {
        /// 200 OK bearer-token response with a 3600-second `expires_in`.
        ///
        /// `refresh_token` is included only when `Some`.
        fn success(access_token: &str, refresh_token: Option<&str>) -> Self {
            let mut body = serde_json::json!({
                "access_token": access_token,
                "token_type": "Bearer",
                "expires_in": 3600
            });
            if let Some(refresh_token) = refresh_token {
                body["refresh_token"] = serde_json::Value::String(refresh_token.to_string());
            }
            Self { status: StatusCode::OK, body }
        }

        /// 400 Bad Request with an `invalid_grant` error body.
        fn failure() -> Self {
            Self { status: StatusCode::BAD_REQUEST, body: serde_json::json!({ "error": "invalid_grant" }) }
        }
    }

    impl FakeTokenEndpoint {
        /// Starts the fake endpoint and returns its URL plus a receiver for the first request.
        ///
        /// The server binds an ephemeral port on 127.0.0.1 and serves `POST
        /// /oauth/token` with `response` from a spawned task. It starts a graceful
        /// shutdown once the handler has signalled after its first request.
        async fn start(response: TokenEndpointResponse) -> Self {
            let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind fake token endpoint");
            let url = TokenUrl::new(format!(
                "http://{}/oauth/token",
                listener.local_addr().expect("fake token endpoint address")
            ))
            .expect("fake token endpoint URL is valid");
            let (request_tx, request) = oneshot::channel();
            let (shutdown_tx, shutdown) = oneshot::channel();
            let state = FakeTokenState {
                response,
                request_tx: Arc::new(TokioMutex::new(Some(request_tx))),
                shutdown_tx: Arc::new(TokioMutex::new(Some(shutdown_tx))),
            };
            let app = Router::new().route("/oauth/token", post(capture_token_request)).with_state(state);
            tokio::spawn(async move {
                axum::serve(listener, app)
                    .with_graceful_shutdown(async {
                        // Either a signal or a dropped sender ends the wait.
                        let _ = shutdown.await;
                    })
                    .await
                    .expect("serve fake token endpoint");
            });
            Self { url, request }
        }
    }

    /// Handler for the fake endpoint.
    ///
    /// It:
    /// 1. records the request (method, path, headers, form body) on the first call;
    /// 2. signals the server to shut down;
    /// 3. replies with the configured status and JSON body.
    async fn capture_token_request(State(state): State<FakeTokenState>, request: Request<Body>) -> impl IntoResponse {
        let (parts, body) = request.into_parts();
        // Test-only: read the whole body with no size cap.
        let body = to_bytes(body, usize::MAX).await.expect("read token request body");
        let form = url::form_urlencoded::parse(&body).into_owned().collect();
        if let Some(tx) = state.request_tx.lock().await.take() {
            let _ = tx.send(CapturedTokenRequest {
                method: parts.method,
                path: parts.uri.path().to_string(),
                headers: parts.headers,
                form,
            });
        }
        if let Some(tx) = state.shutdown_tx.lock().await.take() {
            let _ = tx.send(());
        }
        (state.response.status, axum::Json(state.response.body))
    }

    /// Builds a credential whose `expires_at` is 0 (the Unix epoch), so it reads as
    /// expired to the manager.
    fn expired_credential(access_token: &str, refresh_token: Option<&str>) -> OAuthCredential {
        OAuthCredential {
            client_id: CLIENT_ID.to_string(),
            access_token: access_token.to_string(),
            refresh_token: refresh_token.map(str::to_string),
            expires_at: Some(0),
            granted_scopes: Vec::new(),
        }
    }

    /// Builds a test JWT carrying only the `chatgpt_account_id` claim for `account_id`.
    fn test_jwt_for_account(account_id: &str) -> String {
        make_test_jwt(&serde_json::json!({
            "https://api.openai.com/auth": {
                "chatgpt_account_id": account_id
            }
        }))
    }
}
512}