1use crate::LlmError;
2use aether_auth::{BrowserOAuthHandler, OAuthCredential, OAuthCredentialStorage, OAuthError, OAuthHandler};
3use base64::Engine;
4use base64::engine::general_purpose::URL_SAFE_NO_PAD;
5use oauth2::TokenResponse;
6use oauth2::basic::BasicClient;
7use oauth2::reqwest::redirect::Policy;
8use oauth2::{AuthUrl, AuthorizationCode, ClientId, PkceCodeChallenge, RedirectUrl, TokenUrl};
9use std::sync::Arc;
10use tokio::sync::Mutex;
11use url::Url;
12
// OpenAI auth-server endpoints used by the Codex login flow.

/// Authorization endpoint the user's browser is sent to.
const AUTHORIZE_URL: &str = "https://auth.openai.com/oauth/authorize";
/// Token endpoint used to exchange the authorization code (together with the PKCE verifier)
/// for tokens.
const TOKEN_URL: &str = "https://auth.openai.com/oauth/token";

// Client registration details presented to the auth server.

/// OAuth client id sent with the authorize and token requests and recorded on the saved
/// credential.
const CLIENT_ID: &str = "app_EMoamEEZ73f0CkXaXp7hrann";
/// Loopback redirect URI. Its port must match the port the local callback listener is started
/// on in `perform_codex_oauth_flow`.
const REDIRECT_URI: &str = "http://localhost:1455/auth/callback";
/// Space-separated scopes requested at authorization time.
const SCOPE: &str = "openid profile email offline_access";
18
19pub async fn perform_codex_oauth_flow(store: &dyn OAuthCredentialStorage) -> Result<(), LlmError> {
22 let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
23 let state = generate_random_state();
24
25 let auth_url = Url::parse_with_params(
26 AUTHORIZE_URL,
27 &[
28 ("response_type", "code"),
29 ("client_id", CLIENT_ID),
30 ("redirect_uri", REDIRECT_URI),
31 ("scope", SCOPE),
32 ("code_challenge", pkce_challenge.as_str()),
33 ("code_challenge_method", "S256"),
34 ("state", &state),
35 ("id_token_add_organizations", "true"),
36 ("codex_cli_simplified_flow", "true"),
37 ("originator", "codex_cli_rs"),
38 ],
39 )
40 .map_err(|e| OAuthError::TokenExchange(format!("Failed to build auth URL: {e}")))?;
41
42 let handler = BrowserOAuthHandler::with_redirect_uri(REDIRECT_URI, 1455)?;
45 let callback = handler.authorize(auth_url.as_str()).await?;
46
47 if callback.state != state {
48 return Err(OAuthError::StateMismatch.into());
49 }
50
51 let oauth_client = BasicClient::new(ClientId::new(CLIENT_ID.to_string()))
52 .set_auth_uri(
53 AuthUrl::new(AUTHORIZE_URL.to_string())
54 .map_err(|e| OAuthError::TokenExchange(format!("invalid auth URL: {e}")))?,
55 )
56 .set_token_uri(
57 TokenUrl::new(TOKEN_URL.to_string())
58 .map_err(|e| OAuthError::TokenExchange(format!("invalid token URL: {e}")))?,
59 )
60 .set_redirect_uri(
61 RedirectUrl::new(REDIRECT_URI.to_string())
62 .map_err(|e| OAuthError::TokenExchange(format!("invalid redirect URI: {e}")))?,
63 );
64
65 let http_client = oauth2::reqwest::Client::builder()
66 .redirect(Policy::none())
67 .build()
68 .map_err(|e| OAuthError::TokenExchange(format!("failed to build HTTP client: {e}")))?;
69
70 let token_response = oauth_client
71 .exchange_code(AuthorizationCode::new(callback.code))
72 .set_pkce_verifier(pkce_verifier)
73 .request_async(&http_client)
74 .await
75 .map_err(|e| OAuthError::TokenExchange(e.to_string()))?;
76
77 let expires_at = token_response.expires_in().map(|duration| {
78 let now_ms = u64::try_from(
79 std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap_or_default().as_millis(),
80 )
81 .unwrap_or(u64::MAX);
82 let duration_ms = u64::try_from(duration.as_millis()).unwrap_or(u64::MAX);
83 now_ms.saturating_add(duration_ms)
84 });
85
86 let credential = OAuthCredential {
87 client_id: CLIENT_ID.to_string(),
88 access_token: token_response.access_token().secret().clone(),
89 refresh_token: token_response.refresh_token().map(|t| t.secret().clone()),
90 expires_at,
91 };
92
93 store.save_credential(super::PROVIDER_ID, credential).await?;
94
95 Ok(())
96}
97
/// In-memory copy of the most recently loaded Codex credential, held by
/// `CodexTokenManager`.
struct CachedToken {
    /// Access token handed back to callers.
    access_token: String,
    /// `chatgpt_account_id` claim decoded from `access_token`.
    account_id: String,
    /// Absolute expiry in milliseconds since the Unix epoch; `None` means no expiry is known.
    expires_at: Option<u64>,
}

impl CachedToken {
    /// Reports whether the current wall-clock time has reached `expires_at`.
    ///
    /// A token without an expiry is never considered expired. A system clock that reads
    /// before the Unix epoch is treated as time 0; a millisecond count that does not fit in
    /// `u64` saturates to `u64::MAX`.
    fn is_expired(&self) -> bool {
        self.expires_at.is_some_and(|deadline_ms| {
            let elapsed = std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap_or_default();
            let now_ms = u64::try_from(elapsed.as_millis()).unwrap_or(u64::MAX);
            now_ms >= deadline_ms
        })
    }
}
118
119pub struct CodexTokenManager {
124 store: Arc<dyn OAuthCredentialStorage>,
125 credential_key: String,
126 cached: Mutex<Option<CachedToken>>,
127}
128
129impl CodexTokenManager {
130 pub fn new(store: Arc<dyn OAuthCredentialStorage>, credential_key: &str) -> Self {
131 Self { store, credential_key: credential_key.to_string(), cached: Mutex::new(None) }
132 }
133
134 pub async fn get_valid_token(&self) -> Result<(String, String), LlmError> {
139 {
141 let guard = self.cached.lock().await;
142 if let Some(cached) = guard.as_ref()
143 && !cached.is_expired()
144 {
145 return Ok((cached.access_token.clone(), cached.account_id.clone()));
146 }
147 }
148
149 let credential = self
150 .store
151 .load_credential(&self.credential_key)
152 .await
153 .map_err(|e| OAuthError::NoCredentials(e.to_string()))?
154 .ok_or_else(|| {
155 OAuthError::NoCredentials(
156 "No Codex OAuth credentials found. Run `aether` and select a codex model to trigger OAuth login."
157 .to_string(),
158 )
159 })?;
160
161 let account_id = extract_account_id(&credential.access_token)?;
162
163 let cached = CachedToken {
164 access_token: credential.access_token.clone(),
165 account_id: account_id.clone(),
166 expires_at: credential.expires_at,
167 };
168 *self.cached.lock().await = Some(cached);
169
170 Ok((credential.access_token, account_id))
171 }
172
173 pub async fn clear_cache(&self) {
175 *self.cached.lock().await = None;
176 }
177}
178
179pub fn extract_account_id(access_token: &str) -> Result<String, LlmError> {
184 let parts: Vec<&str> = access_token.split('.').collect();
185 if parts.len() != 3 {
186 return Err(OAuthError::InvalidJwt("expected 3 dot-separated parts".to_string()).into());
187 }
188
189 let decoded = URL_SAFE_NO_PAD
190 .decode(parts[1])
191 .map_err(|e| OAuthError::InvalidJwt(format!("failed to decode payload: {e}")))?;
192
193 let payload: serde_json::Value = serde_json::from_slice(&decoded)
194 .map_err(|e| OAuthError::InvalidJwt(format!("failed to parse payload: {e}")))?;
195
196 let account_id = payload
197 .get("https://api.openai.com/auth")
198 .and_then(|auth| auth.get("chatgpt_account_id"))
199 .and_then(|v| v.as_str())
200 .ok_or_else(|| OAuthError::InvalidJwt("missing chatgpt_account_id in token".to_string()))?;
201
202 Ok(account_id.to_string())
203}
204
205fn generate_random_state() -> String {
206 uuid::Uuid::new_v4().to_string()
207}
208
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds an unsigned, JWT-shaped string (`header.payload.signature`) whose payload is
    /// `payload` serialized to JSON and encoded as base64url without padding.
    fn make_test_jwt(payload: &serde_json::Value) -> String {
        let header = URL_SAFE_NO_PAD.encode(r#"{"alg":"RS256","typ":"JWT"}"#);
        let payload_json = serde_json::to_string(payload).unwrap();
        let payload_b64url = URL_SAFE_NO_PAD.encode(payload_json.as_bytes());
        format!("{header}.{payload_b64url}.fake_signature")
    }

    /// Current wall-clock time in milliseconds since the Unix epoch.
    fn now_ms() -> u64 {
        u64::try_from(
            std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_millis(),
        )
        .unwrap()
    }

    /// A `CachedToken` with dummy token/account values and the given expiry.
    fn cached_token(expires_at: Option<u64>) -> CachedToken {
        CachedToken { access_token: "tok".to_string(), account_id: "acct".to_string(), expires_at }
    }

    #[test]
    fn extract_account_id_from_valid_jwt() {
        let payload = serde_json::json!({
            "sub": "user_123",
            "https://api.openai.com/auth": {
                "chatgpt_account_id": "acct_abc123"
            }
        });

        let jwt = make_test_jwt(&payload);
        let account_id = extract_account_id(&jwt).unwrap();
        assert_eq!(account_id, "acct_abc123");
    }

    #[test]
    fn extract_account_id_missing_claim() {
        let payload = serde_json::json!({
            "sub": "user_123"
        });

        let jwt = make_test_jwt(&payload);
        let result = extract_account_id(&jwt);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("chatgpt_account_id"));
    }

    #[test]
    fn extract_account_id_non_string_claim() {
        let payload = serde_json::json!({
            "https://api.openai.com/auth": {
                "chatgpt_account_id": 42
            }
        });

        let err = extract_account_id(&make_test_jwt(&payload)).unwrap_err();
        assert!(err.to_string().contains("chatgpt_account_id"));
    }

    #[test]
    fn extract_account_id_invalid_jwt_format() {
        assert!(extract_account_id("not.a.valid.jwt.too.many.parts").is_err());
        assert!(extract_account_id("a.b.c.d").is_err());
        assert!(extract_account_id("toofewparts").is_err());
        assert!(extract_account_id("").is_err());
    }

    #[test]
    fn extract_account_id_invalid_base64() {
        let result = extract_account_id("header.!!!invalid!!!.signature");
        assert!(result.is_err());
    }

    #[test]
    fn extract_account_id_invalid_json_payload() {
        let header = URL_SAFE_NO_PAD.encode(b"{}");
        let payload = URL_SAFE_NO_PAD.encode(b"not json");
        let result = extract_account_id(&format!("{header}.{payload}.sig"));
        assert!(result.is_err());
    }

    #[test]
    fn auth_url_is_well_formed() {
        let (pkce_challenge, _) = PkceCodeChallenge::new_random_sha256();
        let state = "test-state";

        // NOTE: mirrors the parameter list built in `perform_codex_oauth_flow`; keep the two in
        // sync.
        let auth_url = Url::parse_with_params(
            AUTHORIZE_URL,
            &[
                ("response_type", "code"),
                ("client_id", CLIENT_ID),
                ("redirect_uri", REDIRECT_URI),
                ("scope", SCOPE),
                ("code_challenge", pkce_challenge.as_str()),
                ("code_challenge_method", "S256"),
                ("state", state),
                ("id_token_add_organizations", "true"),
                ("codex_cli_simplified_flow", "true"),
                ("originator", "codex_cli_rs"),
            ],
        )
        .unwrap();

        assert!(auth_url.as_str().starts_with(AUTHORIZE_URL));

        // Check decoded values, not just key presence, so percent-encoding mistakes surface.
        let params: HashMap<String, String> = auth_url.query_pairs().into_owned().collect();
        assert_eq!(params.len(), 10);
        assert_eq!(params["response_type"], "code");
        assert_eq!(params["client_id"], CLIENT_ID);
        assert_eq!(params["redirect_uri"], REDIRECT_URI);
        assert_eq!(params["scope"], SCOPE);
        assert_eq!(params["code_challenge"], pkce_challenge.as_str());
        assert_eq!(params["code_challenge_method"], "S256");
        assert_eq!(params["state"], state);
        assert_eq!(params["id_token_add_organizations"], "true");
        assert_eq!(params["codex_cli_simplified_flow"], "true");
        assert_eq!(params["originator"], "codex_cli_rs");
    }

    #[test]
    fn generate_random_state_is_valid_uuid() {
        let state = generate_random_state();
        assert!(!state.is_empty());
        let parsed = uuid::Uuid::parse_str(&state).unwrap();
        assert_eq!(parsed.get_version_num(), 4);
        assert_ne!(state, generate_random_state());
    }

    #[test]
    fn oauth_constants_are_valid() {
        assert!(AUTHORIZE_URL.starts_with("https://"));
        assert!(TOKEN_URL.starts_with("https://"));
        assert!(REDIRECT_URI.starts_with("http://localhost:"));
        assert!(SCOPE.contains("openid"));
        assert!(SCOPE.split(' ').any(|scope| scope == "offline_access"));
    }

    #[test]
    fn redirect_uri_port_matches_callback_listener() {
        // `perform_codex_oauth_flow` starts the callback listener on the literal port 1455;
        // this catches the constant and that literal drifting apart.
        assert_eq!(Url::parse(REDIRECT_URI).unwrap().port(), Some(1455));
    }

    #[test]
    fn cached_token_not_expired_when_no_expiry() {
        assert!(!cached_token(None).is_expired());
    }

    #[test]
    fn cached_token_not_expired_when_future() {
        let future_ms = now_ms() + 3_600_000;
        assert!(!cached_token(Some(future_ms)).is_expired());
    }

    #[test]
    fn cached_token_expired_when_past() {
        assert!(cached_token(Some(1000)).is_expired());
    }

    #[test]
    fn cached_token_expiry_boundaries() {
        assert!(cached_token(Some(0)).is_expired());
        assert!(!cached_token(Some(u64::MAX)).is_expired());
    }
}
337}