Skip to main content

llm/providers/codex/
oauth.rs

1use crate::LlmError;
2use crate::oauth::BrowserOAuthHandler;
3use crate::oauth::OAuthError;
4use crate::oauth::OAuthHandler;
5use crate::oauth::credential_store::{OAuthCredential, OAuthCredentialStorage, OAuthCredentialStore};
6use base64::Engine;
7use base64::engine::general_purpose::URL_SAFE_NO_PAD;
8use oauth2::TokenResponse;
9use oauth2::basic::BasicClient;
10use oauth2::reqwest::redirect::Policy;
11use oauth2::{AuthUrl, AuthorizationCode, ClientId, PkceCodeChallenge, RedirectUrl, TokenUrl};
12use tokio::sync::Mutex;
13use url::Url;
14
/// OpenAI authorization endpoint that the user's browser is sent to.
const AUTHORIZE_URL: &str = "https://auth.openai.com/oauth/authorize";
/// OpenAI endpoint that exchanges an authorization code for tokens.
const TOKEN_URL: &str = "https://auth.openai.com/oauth/token";
/// OAuth client ID used by the Codex login flow.
const CLIENT_ID: &str = "app_EMoamEEZ73f0CkXaXp7hrann";
/// Local callback URL the browser is redirected to after login. It is fixed because it is
/// registered with OpenAI's OAuth server; its port must match the local listener.
const REDIRECT_URI: &str = "http://localhost:1455/auth/callback";
/// Space-separated scopes requested during authorization.
const SCOPE: &str = "openid profile email offline_access";
20
21/// Run the full Codex OAuth flow: open browser, capture callback, exchange token, save credentials.
22///
23/// This is designed to be called from `aether auth codex` CLI command.
24pub async fn perform_codex_oauth_flow() -> Result<(), LlmError> {
25    let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
26    let state = generate_random_state();
27
28    let auth_url = Url::parse_with_params(
29        AUTHORIZE_URL,
30        &[
31            ("response_type", "code"),
32            ("client_id", CLIENT_ID),
33            ("redirect_uri", REDIRECT_URI),
34            ("scope", SCOPE),
35            ("code_challenge", pkce_challenge.as_str()),
36            ("code_challenge_method", "S256"),
37            ("state", &state),
38            ("id_token_add_organizations", "true"),
39            ("codex_cli_simplified_flow", "true"),
40            ("originator", "codex_cli_rs"),
41        ],
42    )
43    .map_err(|e| OAuthError::TokenExchange(format!("Failed to build auth URL: {e}")))?;
44
45    // Port 1455 is hardcoded because the Codex API has a fixed redirect URI
46    // (http://localhost:1455/auth/callback) registered with OpenAI's OAuth server.
47    let handler = BrowserOAuthHandler::with_redirect_uri(REDIRECT_URI, 1455)?;
48    let callback = handler.authorize(auth_url.as_str()).await?;
49
50    if callback.state != state {
51        return Err(OAuthError::StateMismatch.into());
52    }
53
54    let oauth_client = BasicClient::new(ClientId::new(CLIENT_ID.to_string()))
55        .set_auth_uri(
56            AuthUrl::new(AUTHORIZE_URL.to_string())
57                .map_err(|e| OAuthError::TokenExchange(format!("invalid auth URL: {e}")))?,
58        )
59        .set_token_uri(
60            TokenUrl::new(TOKEN_URL.to_string())
61                .map_err(|e| OAuthError::TokenExchange(format!("invalid token URL: {e}")))?,
62        )
63        .set_redirect_uri(
64            RedirectUrl::new(REDIRECT_URI.to_string())
65                .map_err(|e| OAuthError::TokenExchange(format!("invalid redirect URI: {e}")))?,
66        );
67
68    let http_client = oauth2::reqwest::Client::builder()
69        .redirect(Policy::none())
70        .build()
71        .map_err(|e| OAuthError::TokenExchange(format!("failed to build HTTP client: {e}")))?;
72
73    let token_response = oauth_client
74        .exchange_code(AuthorizationCode::new(callback.code))
75        .set_pkce_verifier(pkce_verifier)
76        .request_async(&http_client)
77        .await
78        .map_err(|e| OAuthError::TokenExchange(e.to_string()))?;
79
80    let expires_at = token_response.expires_in().map(|duration| {
81        let now_ms = u64::try_from(
82            std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap_or_default().as_millis(),
83        )
84        .unwrap_or(u64::MAX);
85        let duration_ms = u64::try_from(duration.as_millis()).unwrap_or(u64::MAX);
86        now_ms.saturating_add(duration_ms)
87    });
88
89    let credential = OAuthCredential {
90        client_id: CLIENT_ID.to_string(),
91        access_token: token_response.access_token().secret().clone(),
92        refresh_token: token_response.refresh_token().map(|t| t.secret().clone()),
93        expires_at,
94    };
95
96    let store = OAuthCredentialStore::new(super::PROVIDER_ID);
97    store.save_credential(credential).await.map_err(|e| OAuthError::CredentialStore(e.to_string()))?;
98
99    Ok(())
100}
101
/// An access token kept in memory, together with the account ID it belongs to.
struct CachedToken {
    access_token: String,
    account_id: String,
    /// Expiry deadline in milliseconds since the Unix epoch. `None` means the cache never
    /// treats the token as expired.
    expires_at: Option<u64>,
}

impl CachedToken {
    /// Report whether the wall clock has reached `expires_at`.
    ///
    /// A token without an expiry is never expired. The clock is read only when an expiry is
    /// present. A clock before the Unix epoch reads as 0, and a millisecond count too large
    /// for `u64` saturates to `u64::MAX`.
    fn is_expired(&self) -> bool {
        self.expires_at.is_some_and(|deadline_ms| {
            let since_epoch = std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap_or_default();
            let now_ms = u64::try_from(since_epoch.as_millis()).unwrap_or(u64::MAX);
            deadline_ms <= now_ms
        })
    }
}
122
123/// Manages OAuth tokens for the Codex backend API.
124///
125/// Generic over `OAuthCredentialStorage` so tests can inject an in-memory fake
126/// instead of hitting the OS keychain.
127pub struct CodexTokenManager<T: OAuthCredentialStorage> {
128    store: T,
129    server_id: String,
130    cached: Mutex<Option<CachedToken>>,
131}
132
133impl<T: OAuthCredentialStorage> CodexTokenManager<T> {
134    pub fn new(store: T, server_id: &str) -> Self {
135        Self { store, server_id: server_id.to_string(), cached: Mutex::new(None) }
136    }
137
138    /// Get a valid access token and account ID.
139    ///
140    /// Returns `(access_token, account_id)`. The account ID is extracted from
141    /// the JWT's `https://api.openai.com/auth` claim field `chatgpt_account_id`.
142    pub async fn get_valid_token(&self) -> Result<(String, String), LlmError> {
143        // Check cache first — return if present and not expired
144        {
145            let guard = self.cached.lock().await;
146            if let Some(cached) = guard.as_ref()
147                && !cached.is_expired()
148            {
149                return Ok((cached.access_token.clone(), cached.account_id.clone()));
150            }
151        }
152
153        let credential = self
154            .store
155            .load_credential(&self.server_id)
156            .await
157            .map_err(|e| OAuthError::NoCredentials(e.to_string()))?
158            .ok_or_else(|| {
159                OAuthError::NoCredentials(
160                    "No Codex OAuth credentials found. Run `aether` and select a codex model to trigger OAuth login."
161                        .to_string(),
162                )
163            })?;
164
165        let account_id = extract_account_id(&credential.access_token)?;
166
167        let cached = CachedToken {
168            access_token: credential.access_token.clone(),
169            account_id: account_id.clone(),
170            expires_at: credential.expires_at,
171        };
172        *self.cached.lock().await = Some(cached);
173
174        Ok((credential.access_token, account_id))
175    }
176
177    /// Clear the cached token (e.g. after a 401 response)
178    pub async fn clear_cache(&self) {
179        *self.cached.lock().await = None;
180    }
181}
182
183/// Extract the account ID from a JWT access token.
184///
185/// The JWT payload contains a claim at `https://api.openai.com/auth`
186/// with a `chatgpt_account_id` field.
187pub fn extract_account_id(access_token: &str) -> Result<String, LlmError> {
188    let parts: Vec<&str> = access_token.split('.').collect();
189    if parts.len() != 3 {
190        return Err(OAuthError::InvalidJwt("expected 3 dot-separated parts".to_string()).into());
191    }
192
193    let decoded = URL_SAFE_NO_PAD
194        .decode(parts[1])
195        .map_err(|e| OAuthError::InvalidJwt(format!("failed to decode payload: {e}")))?;
196
197    let payload: serde_json::Value = serde_json::from_slice(&decoded)
198        .map_err(|e| OAuthError::InvalidJwt(format!("failed to parse payload: {e}")))?;
199
200    let account_id = payload
201        .get("https://api.openai.com/auth")
202        .and_then(|auth| auth.get("chatgpt_account_id"))
203        .and_then(|v| v.as_str())
204        .ok_or_else(|| OAuthError::InvalidJwt("missing chatgpt_account_id in token".to_string()))?;
205
206    Ok(account_id.to_string())
207}
208
209fn generate_random_state() -> String {
210    uuid::Uuid::new_v4().to_string()
211}
212
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an unsigned test JWT (fixed RS256 header, fake signature) around `payload`.
    fn make_test_jwt(payload: &serde_json::Value) -> String {
        let header = URL_SAFE_NO_PAD.encode(r#"{"alg":"RS256","typ":"JWT"}"#);
        let payload_json = serde_json::to_string(payload).unwrap();
        let payload_b64url = URL_SAFE_NO_PAD.encode(payload_json.as_bytes());
        format!("{header}.{payload_b64url}.fake_signature")
    }

    #[test]
    fn extract_account_id_from_valid_jwt() {
        let payload = serde_json::json!({
            "sub": "user_123",
            "https://api.openai.com/auth": {
                "chatgpt_account_id": "acct_abc123"
            }
        });

        let jwt = make_test_jwt(&payload);
        let account_id = extract_account_id(&jwt).unwrap();
        assert_eq!(account_id, "acct_abc123");
    }

    #[test]
    fn extract_account_id_missing_claim() {
        let payload = serde_json::json!({
            "sub": "user_123"
        });

        let jwt = make_test_jwt(&payload);
        let result = extract_account_id(&jwt);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("chatgpt_account_id"));
    }

    #[test]
    fn extract_account_id_rejects_non_string_account_id() {
        // The claim exists but holds a number, which must not be accepted as an ID.
        let payload = serde_json::json!({
            "https://api.openai.com/auth": {
                "chatgpt_account_id": 42
            }
        });

        let result = extract_account_id(&make_test_jwt(&payload));
        assert!(result.unwrap_err().to_string().contains("chatgpt_account_id"));
    }

    #[test]
    fn extract_account_id_rejects_non_json_payload() {
        let header = URL_SAFE_NO_PAD.encode(r#"{"alg":"RS256","typ":"JWT"}"#);
        let payload = URL_SAFE_NO_PAD.encode("definitely not json");
        let jwt = format!("{header}.{payload}.fake_signature");

        let result = extract_account_id(&jwt);
        assert!(result.unwrap_err().to_string().contains("failed to parse payload"));
    }

    #[test]
    fn extract_account_id_invalid_jwt_format() {
        let result = extract_account_id("not.a.valid.jwt.too.many.parts");
        assert!(result.is_err());

        let result = extract_account_id("toofewparts");
        assert!(result.is_err());

        let result = extract_account_id("");
        assert!(result.is_err());
    }

    #[test]
    fn extract_account_id_invalid_base64() {
        let result = extract_account_id("header.!!!invalid!!!.signature");
        assert!(result.unwrap_err().to_string().contains("failed to decode payload"));
    }

    #[test]
    fn auth_url_is_well_formed() {
        let (pkce_challenge, _) = PkceCodeChallenge::new_random_sha256();
        let state = "test-state";

        let auth_url = Url::parse_with_params(
            AUTHORIZE_URL,
            &[
                ("response_type", "code"),
                ("client_id", CLIENT_ID),
                ("redirect_uri", REDIRECT_URI),
                ("scope", SCOPE),
                ("code_challenge", pkce_challenge.as_str()),
                ("code_challenge_method", "S256"),
                ("state", state),
                ("id_token_add_organizations", "true"),
                ("codex_cli_simplified_flow", "true"),
                ("originator", "codex_cli_rs"),
            ],
        )
        .unwrap();

        let url_str = auth_url.as_str();
        assert!(url_str.starts_with(AUTHORIZE_URL));
        assert!(url_str.contains("client_id="));
        assert!(url_str.contains("redirect_uri="));
        assert!(url_str.contains("scope="));
        assert!(url_str.contains("code_challenge="));
        assert!(url_str.contains("state=test-state"));

        // Check the decoded values too, so encoding problems (spaces in the scope, the
        // `://` in the redirect URI) would be caught rather than hidden by substring checks.
        let params: std::collections::HashMap<String, String> = auth_url.query_pairs().into_owned().collect();
        assert_eq!(params.get("response_type").map(String::as_str), Some("code"));
        assert_eq!(params.get("client_id").map(String::as_str), Some(CLIENT_ID));
        assert_eq!(params.get("redirect_uri").map(String::as_str), Some(REDIRECT_URI));
        assert_eq!(params.get("scope").map(String::as_str), Some(SCOPE));
        assert_eq!(params.get("code_challenge").map(String::as_str), Some(pkce_challenge.as_str()));
        assert_eq!(params.get("code_challenge_method").map(String::as_str), Some("S256"));
        assert_eq!(params.get("state").map(String::as_str), Some(state));
    }

    #[test]
    fn generate_random_state_is_valid_uuid() {
        let state = generate_random_state();
        assert!(!state.is_empty());
        assert!(uuid::Uuid::parse_str(&state).is_ok());
    }

    #[test]
    fn oauth_constants_are_valid() {
        assert!(AUTHORIZE_URL.starts_with("https://"));
        assert!(TOKEN_URL.starts_with("https://"));
        assert!(REDIRECT_URI.starts_with("http://localhost:"));
        assert!(SCOPE.contains("openid"));
    }

    #[test]
    fn cached_token_not_expired_when_no_expiry() {
        let token = CachedToken { access_token: "tok".to_string(), account_id: "acct".to_string(), expires_at: None };
        assert!(!token.is_expired());
    }

    #[test]
    fn cached_token_not_expired_when_future() {
        let future_ms =
            u64::try_from(std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_millis())
                .unwrap()
                + 3_600_000; // 1 hour from now
        let token = CachedToken {
            access_token: "tok".to_string(),
            account_id: "acct".to_string(),
            expires_at: Some(future_ms),
        };
        assert!(!token.is_expired());
    }

    #[test]
    fn cached_token_expired_when_past() {
        let token = CachedToken {
            access_token: "tok".to_string(),
            account_id: "acct".to_string(),
            expires_at: Some(1000), // way in the past
        };
        assert!(token.is_expired());
    }

    #[test]
    fn cached_token_expired_at_epoch_zero() {
        let token =
            CachedToken { access_token: "tok".to_string(), account_id: "acct".to_string(), expires_at: Some(0) };
        assert!(token.is_expired());
    }
}