1use crate::LlmError;
2use crate::oauth::BrowserOAuthHandler;
3use crate::oauth::OAuthError;
4use crate::oauth::OAuthHandler;
5use crate::oauth::credential_store::{OAuthCredential, OAuthCredentialStorage, OAuthCredentialStore};
6use base64::Engine;
7use base64::engine::general_purpose::URL_SAFE_NO_PAD;
8use oauth2::TokenResponse;
9use oauth2::basic::BasicClient;
10use oauth2::reqwest::redirect::Policy;
11use oauth2::{AuthUrl, AuthorizationCode, ClientId, PkceCodeChallenge, RedirectUrl, TokenUrl};
12use tokio::sync::Mutex;
13use url::Url;
14
/// OAuth client id sent both in the authorize request and in the token exchange.
const CLIENT_ID: &str = "app_EMoamEEZ73f0CkXaXp7hrann";
/// Authorization endpoint the user's browser is sent to.
const AUTHORIZE_URL: &str = "https://auth.openai.com/oauth/authorize";
/// Endpoint used to exchange the authorization code for tokens.
const TOKEN_URL: &str = "https://auth.openai.com/oauth/token";
/// Local redirect target that receives the authorization callback.
/// NOTE: the port (1455) is repeated as a literal where
/// `BrowserOAuthHandler::with_redirect_uri` is called; keep the two in sync.
const REDIRECT_URI: &str = "http://localhost:1455/auth/callback";
/// Space-separated scopes requested during authorization.
const SCOPE: &str = "openid profile email offline_access";
20
21pub async fn perform_codex_oauth_flow() -> Result<(), LlmError> {
25 let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
26 let state = generate_random_state();
27
28 let auth_url = Url::parse_with_params(
29 AUTHORIZE_URL,
30 &[
31 ("response_type", "code"),
32 ("client_id", CLIENT_ID),
33 ("redirect_uri", REDIRECT_URI),
34 ("scope", SCOPE),
35 ("code_challenge", pkce_challenge.as_str()),
36 ("code_challenge_method", "S256"),
37 ("state", &state),
38 ("id_token_add_organizations", "true"),
39 ("codex_cli_simplified_flow", "true"),
40 ("originator", "codex_cli_rs"),
41 ],
42 )
43 .map_err(|e| OAuthError::TokenExchange(format!("Failed to build auth URL: {e}")))?;
44
45 let handler = BrowserOAuthHandler::with_redirect_uri(REDIRECT_URI, 1455)?;
48 let callback = handler.authorize(auth_url.as_str()).await?;
49
50 if callback.state != state {
51 return Err(OAuthError::StateMismatch.into());
52 }
53
54 let oauth_client = BasicClient::new(ClientId::new(CLIENT_ID.to_string()))
55 .set_auth_uri(
56 AuthUrl::new(AUTHORIZE_URL.to_string())
57 .map_err(|e| OAuthError::TokenExchange(format!("invalid auth URL: {e}")))?,
58 )
59 .set_token_uri(
60 TokenUrl::new(TOKEN_URL.to_string())
61 .map_err(|e| OAuthError::TokenExchange(format!("invalid token URL: {e}")))?,
62 )
63 .set_redirect_uri(
64 RedirectUrl::new(REDIRECT_URI.to_string())
65 .map_err(|e| OAuthError::TokenExchange(format!("invalid redirect URI: {e}")))?,
66 );
67
68 let http_client = oauth2::reqwest::Client::builder()
69 .redirect(Policy::none())
70 .build()
71 .map_err(|e| OAuthError::TokenExchange(format!("failed to build HTTP client: {e}")))?;
72
73 let token_response = oauth_client
74 .exchange_code(AuthorizationCode::new(callback.code))
75 .set_pkce_verifier(pkce_verifier)
76 .request_async(&http_client)
77 .await
78 .map_err(|e| OAuthError::TokenExchange(e.to_string()))?;
79
80 let expires_at = token_response.expires_in().map(|duration| {
81 let now_ms = u64::try_from(
82 std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap_or_default().as_millis(),
83 )
84 .unwrap_or(u64::MAX);
85 let duration_ms = u64::try_from(duration.as_millis()).unwrap_or(u64::MAX);
86 now_ms.saturating_add(duration_ms)
87 });
88
89 let credential = OAuthCredential {
90 client_id: CLIENT_ID.to_string(),
91 access_token: token_response.access_token().secret().clone(),
92 refresh_token: token_response.refresh_token().map(|t| t.secret().clone()),
93 expires_at,
94 };
95
96 let store = OAuthCredentialStore::new(super::PROVIDER_ID);
97 store.save_credential(credential).await.map_err(|e| OAuthError::CredentialStore(e.to_string()))?;
98
99 Ok(())
100}
101
/// In-memory copy of the most recently loaded Codex credential.
struct CachedToken {
    /// Access token as loaded from the credential store.
    access_token: String,
    /// `chatgpt_account_id` claim extracted from `access_token`.
    account_id: String,
    /// Absolute expiry in milliseconds since the Unix epoch; `None` means no known expiry.
    expires_at: Option<u64>,
}

impl CachedToken {
    /// Reports whether the current wall-clock time has reached `expires_at`.
    ///
    /// Tokens without an expiry never count as expired. A system clock set before the
    /// Unix epoch is treated as time 0, and a millisecond count that does not fit in
    /// `u64` saturates to `u64::MAX`.
    fn is_expired(&self) -> bool {
        self.expires_at.is_some_and(|deadline_ms| {
            let since_epoch = std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap_or_default();
            let now_ms = u64::try_from(since_epoch.as_millis()).unwrap_or(u64::MAX);
            now_ms >= deadline_ms
        })
    }
}
122
123pub struct CodexTokenManager<T: OAuthCredentialStorage> {
128 store: T,
129 server_id: String,
130 cached: Mutex<Option<CachedToken>>,
131}
132
133impl<T: OAuthCredentialStorage> CodexTokenManager<T> {
134 pub fn new(store: T, server_id: &str) -> Self {
135 Self { store, server_id: server_id.to_string(), cached: Mutex::new(None) }
136 }
137
138 pub async fn get_valid_token(&self) -> Result<(String, String), LlmError> {
143 {
145 let guard = self.cached.lock().await;
146 if let Some(cached) = guard.as_ref()
147 && !cached.is_expired()
148 {
149 return Ok((cached.access_token.clone(), cached.account_id.clone()));
150 }
151 }
152
153 let credential = self
154 .store
155 .load_credential(&self.server_id)
156 .await
157 .map_err(|e| OAuthError::NoCredentials(e.to_string()))?
158 .ok_or_else(|| {
159 OAuthError::NoCredentials(
160 "No Codex OAuth credentials found. Run `aether` and select a codex model to trigger OAuth login."
161 .to_string(),
162 )
163 })?;
164
165 let account_id = extract_account_id(&credential.access_token)?;
166
167 let cached = CachedToken {
168 access_token: credential.access_token.clone(),
169 account_id: account_id.clone(),
170 expires_at: credential.expires_at,
171 };
172 *self.cached.lock().await = Some(cached);
173
174 Ok((credential.access_token, account_id))
175 }
176
177 pub async fn clear_cache(&self) {
179 *self.cached.lock().await = None;
180 }
181}
182
183pub fn extract_account_id(access_token: &str) -> Result<String, LlmError> {
188 let parts: Vec<&str> = access_token.split('.').collect();
189 if parts.len() != 3 {
190 return Err(OAuthError::InvalidJwt("expected 3 dot-separated parts".to_string()).into());
191 }
192
193 let decoded = URL_SAFE_NO_PAD
194 .decode(parts[1])
195 .map_err(|e| OAuthError::InvalidJwt(format!("failed to decode payload: {e}")))?;
196
197 let payload: serde_json::Value = serde_json::from_slice(&decoded)
198 .map_err(|e| OAuthError::InvalidJwt(format!("failed to parse payload: {e}")))?;
199
200 let account_id = payload
201 .get("https://api.openai.com/auth")
202 .and_then(|auth| auth.get("chatgpt_account_id"))
203 .and_then(|v| v.as_str())
204 .ok_or_else(|| OAuthError::InvalidJwt("missing chatgpt_account_id in token".to_string()))?;
205
206 Ok(account_id.to_string())
207}
208
209fn generate_random_state() -> String {
210 uuid::Uuid::new_v4().to_string()
211}
212
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds an unsigned, JWT-shaped token (`header.payload.signature`) whose payload
    /// segment is the unpadded base64url encoding of `payload`.
    fn make_test_jwt(payload: &serde_json::Value) -> String {
        let header = URL_SAFE_NO_PAD.encode(r#"{"alg":"RS256","typ":"JWT"}"#);
        let payload_json = serde_json::to_string(payload).unwrap();
        let payload_b64url = URL_SAFE_NO_PAD.encode(payload_json.as_bytes());
        format!("{header}.{payload_b64url}.fake_signature")
    }

    /// Current wall-clock time in milliseconds since the Unix epoch.
    fn now_ms() -> u64 {
        let since_epoch = std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap();
        u64::try_from(since_epoch.as_millis()).unwrap()
    }

    /// Builds a `CachedToken` with dummy token and account values and the given expiry.
    fn cached_token(expires_at: Option<u64>) -> CachedToken {
        CachedToken { access_token: "tok".to_string(), account_id: "acct".to_string(), expires_at }
    }

    #[test]
    fn extract_account_id_from_valid_jwt() {
        let payload = serde_json::json!({
            "sub": "user_123",
            "https://api.openai.com/auth": {
                "chatgpt_account_id": "acct_abc123"
            }
        });

        let jwt = make_test_jwt(&payload);
        let account_id = extract_account_id(&jwt).unwrap();
        assert_eq!(account_id, "acct_abc123");
    }

    #[test]
    fn extract_account_id_missing_claim() {
        let payload = serde_json::json!({
            "sub": "user_123"
        });

        let jwt = make_test_jwt(&payload);
        let result = extract_account_id(&jwt);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("chatgpt_account_id"));
    }

    #[test]
    fn extract_account_id_rejects_non_string_claim() {
        let payload = serde_json::json!({
            "https://api.openai.com/auth": {
                "chatgpt_account_id": 42
            }
        });

        let result = extract_account_id(&make_test_jwt(&payload));
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("chatgpt_account_id"));
    }

    #[test]
    fn extract_account_id_invalid_jwt_format() {
        assert!(extract_account_id("not.a.valid.jwt.too.many.parts").is_err());
        assert!(extract_account_id("toofewparts").is_err());
        assert!(extract_account_id("").is_err());
    }

    #[test]
    fn extract_account_id_invalid_base64() {
        assert!(extract_account_id("header.!!!invalid!!!.signature").is_err());
    }

    #[test]
    fn extract_account_id_payload_not_json() {
        let payload_b64url = URL_SAFE_NO_PAD.encode(b"not json");
        assert!(extract_account_id(&format!("header.{payload_b64url}.signature")).is_err());
    }

    #[test]
    fn auth_url_is_well_formed() {
        let (pkce_challenge, _) = PkceCodeChallenge::new_random_sha256();
        let state = "test-state";

        // Mirrors the authorize-URL parameters used by the production login flow;
        // keep the two lists in sync.
        let auth_url = Url::parse_with_params(
            AUTHORIZE_URL,
            &[
                ("response_type", "code"),
                ("client_id", CLIENT_ID),
                ("redirect_uri", REDIRECT_URI),
                ("scope", SCOPE),
                ("code_challenge", pkce_challenge.as_str()),
                ("code_challenge_method", "S256"),
                ("state", state),
                ("id_token_add_organizations", "true"),
                ("codex_cli_simplified_flow", "true"),
                ("originator", "codex_cli_rs"),
            ],
        )
        .unwrap();

        assert!(auth_url.as_str().starts_with(AUTHORIZE_URL));

        // Compare decoded values, not substrings of the encoded URL, so a wrong value
        // cannot pass just because the key is present.
        let params: HashMap<String, String> = auth_url.query_pairs().into_owned().collect();
        assert_eq!(params["response_type"], "code");
        assert_eq!(params["client_id"], CLIENT_ID);
        assert_eq!(params["redirect_uri"], REDIRECT_URI);
        assert_eq!(params["scope"], SCOPE);
        assert_eq!(params["code_challenge"], pkce_challenge.as_str());
        assert_eq!(params["code_challenge_method"], "S256");
        assert_eq!(params["state"], state);
    }

    #[test]
    fn generate_random_state_is_valid_uuid() {
        let state = generate_random_state();
        assert!(!state.is_empty());
        assert!(uuid::Uuid::parse_str(&state).is_ok());
    }

    #[test]
    fn oauth_constants_are_valid() {
        assert!(AUTHORIZE_URL.starts_with("https://"));
        assert!(TOKEN_URL.starts_with("https://"));
        assert!(REDIRECT_URI.starts_with("http://localhost:"));
        assert!(SCOPE.contains("openid"));
    }

    #[test]
    fn cached_token_not_expired_when_no_expiry() {
        assert!(!cached_token(None).is_expired());
    }

    #[test]
    fn cached_token_not_expired_when_future() {
        let one_hour_ms = 3_600_000;
        assert!(!cached_token(Some(now_ms() + one_hour_ms)).is_expired());
    }

    #[test]
    fn cached_token_expired_when_past() {
        // One second after the Unix epoch.
        assert!(cached_token(Some(1000)).is_expired());
    }
}