1use crate::LlmError;
2use crate::oauth::BrowserOAuthHandler;
3use crate::oauth::OAuthError;
4use crate::oauth::OAuthHandler;
5use crate::oauth::credential_store::{OAuthCredential, OAuthCredentialStorage, OAuthCredentialStore};
6use base64::Engine;
7use base64::engine::general_purpose::URL_SAFE_NO_PAD;
8use oauth2::TokenResponse;
9use oauth2::basic::BasicClient;
10use oauth2::reqwest::redirect::Policy;
11use oauth2::{AuthUrl, AuthorizationCode, ClientId, PkceCodeChallenge, RedirectUrl, TokenUrl};
12use tokio::sync::Mutex;
13use url::Url;
14
15const CLIENT_ID: &str = "app_EMoamEEZ73f0CkXaXp7hrann";
16const AUTHORIZE_URL: &str = "https://auth.openai.com/oauth/authorize";
17const TOKEN_URL: &str = "https://auth.openai.com/oauth/token";
18const REDIRECT_URI: &str = "http://localhost:1455/auth/callback";
19const SCOPE: &str = "openid profile email offline_access";
20
21pub async fn perform_codex_oauth_flow() -> Result<(), LlmError> {
25 let store = OAuthCredentialStore::with_platform_store()?;
26 perform_codex_oauth_flow_with_store(&store).await
27}
28
29pub async fn perform_codex_oauth_flow_with_store(store: &impl OAuthCredentialStorage) -> Result<(), LlmError> {
30 let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
31 let state = generate_random_state();
32
33 let auth_url = Url::parse_with_params(
34 AUTHORIZE_URL,
35 &[
36 ("response_type", "code"),
37 ("client_id", CLIENT_ID),
38 ("redirect_uri", REDIRECT_URI),
39 ("scope", SCOPE),
40 ("code_challenge", pkce_challenge.as_str()),
41 ("code_challenge_method", "S256"),
42 ("state", &state),
43 ("id_token_add_organizations", "true"),
44 ("codex_cli_simplified_flow", "true"),
45 ("originator", "codex_cli_rs"),
46 ],
47 )
48 .map_err(|e| OAuthError::TokenExchange(format!("Failed to build auth URL: {e}")))?;
49
50 let handler = BrowserOAuthHandler::with_redirect_uri(REDIRECT_URI, 1455)?;
53 let callback = handler.authorize(auth_url.as_str()).await?;
54
55 if callback.state != state {
56 return Err(OAuthError::StateMismatch.into());
57 }
58
59 let oauth_client = BasicClient::new(ClientId::new(CLIENT_ID.to_string()))
60 .set_auth_uri(
61 AuthUrl::new(AUTHORIZE_URL.to_string())
62 .map_err(|e| OAuthError::TokenExchange(format!("invalid auth URL: {e}")))?,
63 )
64 .set_token_uri(
65 TokenUrl::new(TOKEN_URL.to_string())
66 .map_err(|e| OAuthError::TokenExchange(format!("invalid token URL: {e}")))?,
67 )
68 .set_redirect_uri(
69 RedirectUrl::new(REDIRECT_URI.to_string())
70 .map_err(|e| OAuthError::TokenExchange(format!("invalid redirect URI: {e}")))?,
71 );
72
73 let http_client = oauth2::reqwest::Client::builder()
74 .redirect(Policy::none())
75 .build()
76 .map_err(|e| OAuthError::TokenExchange(format!("failed to build HTTP client: {e}")))?;
77
78 let token_response = oauth_client
79 .exchange_code(AuthorizationCode::new(callback.code))
80 .set_pkce_verifier(pkce_verifier)
81 .request_async(&http_client)
82 .await
83 .map_err(|e| OAuthError::TokenExchange(e.to_string()))?;
84
85 let expires_at = token_response.expires_in().map(|duration| {
86 let now_ms = u64::try_from(
87 std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap_or_default().as_millis(),
88 )
89 .unwrap_or(u64::MAX);
90 let duration_ms = u64::try_from(duration.as_millis()).unwrap_or(u64::MAX);
91 now_ms.saturating_add(duration_ms)
92 });
93
94 let credential = OAuthCredential {
95 client_id: CLIENT_ID.to_string(),
96 access_token: token_response.access_token().secret().clone(),
97 refresh_token: token_response.refresh_token().map(|t| t.secret().clone()),
98 expires_at,
99 };
100
101 store.save_credential(super::PROVIDER_ID, credential).await?;
102
103 Ok(())
104}
105
/// In-memory snapshot of the most recently loaded Codex credential.
struct CachedToken {
    /// Access token as loaded from the credential store.
    access_token: String,
    /// `chatgpt_account_id` value extracted from `access_token`.
    account_id: String,
    /// Absolute expiry in milliseconds since the Unix epoch; `None` means the
    /// token is never treated as expired.
    expires_at: Option<u64>,
}

impl CachedToken {
    /// Reports whether the current wall-clock time has reached `expires_at`.
    ///
    /// Tokens without an expiry never expire. A system clock that reads earlier
    /// than the Unix epoch is treated as time zero.
    fn is_expired(&self) -> bool {
        match self.expires_at {
            None => false,
            Some(deadline_ms) => {
                let since_epoch = std::time::SystemTime::now()
                    .duration_since(std::time::UNIX_EPOCH)
                    .unwrap_or_default();
                let current_ms = u64::try_from(since_epoch.as_millis()).unwrap_or(u64::MAX);
                deadline_ms <= current_ms
            }
        }
    }
}
126
127pub struct CodexTokenManager<T: OAuthCredentialStorage> {
132 store: T,
133 credential_key: String,
134 cached: Mutex<Option<CachedToken>>,
135}
136
137impl<T: OAuthCredentialStorage> CodexTokenManager<T> {
138 pub fn new(store: T, credential_key: &str) -> Self {
139 Self { store, credential_key: credential_key.to_string(), cached: Mutex::new(None) }
140 }
141
142 pub async fn get_valid_token(&self) -> Result<(String, String), LlmError> {
147 {
149 let guard = self.cached.lock().await;
150 if let Some(cached) = guard.as_ref()
151 && !cached.is_expired()
152 {
153 return Ok((cached.access_token.clone(), cached.account_id.clone()));
154 }
155 }
156
157 let credential = self
158 .store
159 .load_credential(&self.credential_key)
160 .await
161 .map_err(|e| OAuthError::NoCredentials(e.to_string()))?
162 .ok_or_else(|| {
163 OAuthError::NoCredentials(
164 "No Codex OAuth credentials found. Run `aether` and select a codex model to trigger OAuth login."
165 .to_string(),
166 )
167 })?;
168
169 let account_id = extract_account_id(&credential.access_token)?;
170
171 let cached = CachedToken {
172 access_token: credential.access_token.clone(),
173 account_id: account_id.clone(),
174 expires_at: credential.expires_at,
175 };
176 *self.cached.lock().await = Some(cached);
177
178 Ok((credential.access_token, account_id))
179 }
180
181 pub async fn clear_cache(&self) {
183 *self.cached.lock().await = None;
184 }
185}
186
187pub fn extract_account_id(access_token: &str) -> Result<String, LlmError> {
192 let parts: Vec<&str> = access_token.split('.').collect();
193 if parts.len() != 3 {
194 return Err(OAuthError::InvalidJwt("expected 3 dot-separated parts".to_string()).into());
195 }
196
197 let decoded = URL_SAFE_NO_PAD
198 .decode(parts[1])
199 .map_err(|e| OAuthError::InvalidJwt(format!("failed to decode payload: {e}")))?;
200
201 let payload: serde_json::Value = serde_json::from_slice(&decoded)
202 .map_err(|e| OAuthError::InvalidJwt(format!("failed to parse payload: {e}")))?;
203
204 let account_id = payload
205 .get("https://api.openai.com/auth")
206 .and_then(|auth| auth.get("chatgpt_account_id"))
207 .and_then(|v| v.as_str())
208 .ok_or_else(|| OAuthError::InvalidJwt("missing chatgpt_account_id in token".to_string()))?;
209
210 Ok(account_id.to_string())
211}
212
213fn generate_random_state() -> String {
214 uuid::Uuid::new_v4().to_string()
215}
216
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an unsigned, JWT-shaped string whose payload segment encodes `claims`.
    fn make_test_jwt(claims: &serde_json::Value) -> String {
        let encoded_header = URL_SAFE_NO_PAD.encode(r#"{"alg":"RS256","typ":"JWT"}"#);
        let encoded_claims = URL_SAFE_NO_PAD.encode(serde_json::to_vec(claims).unwrap());
        [encoded_header.as_str(), encoded_claims.as_str(), "fake_signature"].join(".")
    }

    /// Current wall-clock time in milliseconds since the Unix epoch.
    fn now_epoch_ms() -> u64 {
        let elapsed = std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap();
        u64::try_from(elapsed.as_millis()).unwrap()
    }

    /// A cached token with fixed placeholder strings and the given expiry.
    fn token_expiring_at(expires_at: Option<u64>) -> CachedToken {
        CachedToken { access_token: "tok".to_string(), account_id: "acct".to_string(), expires_at }
    }

    #[test]
    fn extract_account_id_from_valid_jwt() {
        let jwt = make_test_jwt(&serde_json::json!({
            "sub": "user_123",
            "https://api.openai.com/auth": { "chatgpt_account_id": "acct_abc123" }
        }));
        assert_eq!(extract_account_id(&jwt).unwrap(), "acct_abc123");
    }

    #[test]
    fn extract_account_id_missing_claim() {
        let jwt = make_test_jwt(&serde_json::json!({ "sub": "user_123" }));
        let message = extract_account_id(&jwt).expect_err("claim is absent").to_string();
        assert!(message.contains("chatgpt_account_id"));
    }

    #[test]
    fn extract_account_id_invalid_jwt_format() {
        for malformed in ["not.a.valid.jwt.too.many.parts", "toofewparts"] {
            assert!(extract_account_id(malformed).is_err(), "{malformed} should be rejected");
        }
    }

    #[test]
    fn extract_account_id_invalid_base64() {
        assert!(extract_account_id("header.!!!invalid!!!.signature").is_err());
    }

    #[test]
    fn auth_url_is_well_formed() {
        let (challenge, _verifier) = PkceCodeChallenge::new_random_sha256();
        let params = [
            ("response_type", "code"),
            ("client_id", CLIENT_ID),
            ("redirect_uri", REDIRECT_URI),
            ("scope", SCOPE),
            ("code_challenge", challenge.as_str()),
            ("code_challenge_method", "S256"),
            ("state", "test-state"),
            ("id_token_add_organizations", "true"),
            ("codex_cli_simplified_flow", "true"),
            ("originator", "codex_cli_rs"),
        ];
        let url = Url::parse_with_params(AUTHORIZE_URL, &params).unwrap();
        let rendered = url.as_str();

        assert!(rendered.starts_with(AUTHORIZE_URL));
        for needle in ["client_id=", "redirect_uri=", "scope=", "code_challenge=", "state=test-state"] {
            assert!(rendered.contains(needle), "missing {needle} in {rendered}");
        }
    }

    #[test]
    fn generate_random_state_is_valid_uuid() {
        let state = generate_random_state();
        assert!(!state.is_empty());
        assert!(uuid::Uuid::parse_str(&state).is_ok());
    }

    #[test]
    fn oauth_constants_are_valid() {
        assert!(AUTHORIZE_URL.starts_with("https://"));
        assert!(TOKEN_URL.starts_with("https://"));
        assert!(REDIRECT_URI.starts_with("http://localhost:"));
        assert!(SCOPE.contains("openid"));
    }

    #[test]
    fn cached_token_not_expired_when_no_expiry() {
        assert!(!token_expiring_at(None).is_expired());
    }

    #[test]
    fn cached_token_not_expired_when_future() {
        let one_hour_ahead = now_epoch_ms() + 3_600_000;
        assert!(!token_expiring_at(Some(one_hour_ahead)).is_expired());
    }

    #[test]
    fn cached_token_expired_when_past() {
        // One second after the Unix epoch is always in the past.
        assert!(token_expiring_at(Some(1000)).is_expired());
    }
}