Skip to main content

llm/providers/codex/
oauth.rs

1use crate::LlmError;
2use aether_auth::{BrowserOAuthHandler, OAuthCredential, OAuthCredentialStorage, OAuthError, OAuthHandler};
3use base64::Engine;
4use base64::engine::general_purpose::URL_SAFE_NO_PAD;
5use oauth2::TokenResponse;
6use oauth2::basic::BasicClient;
7use oauth2::reqwest::redirect::Policy;
8use oauth2::{AuthUrl, AuthorizationCode, ClientId, PkceCodeChallenge, RedirectUrl, TokenUrl};
9use std::sync::Arc;
10use tokio::sync::Mutex;
11use url::Url;
12
/// OAuth client ID sent in both the authorize request and the token request.
const CLIENT_ID: &str = "app_EMoamEEZ73f0CkXaXp7hrann";
/// Authorization endpoint the user's browser is sent to.
const AUTHORIZE_URL: &str = "https://auth.openai.com/oauth/authorize";
/// OAuth token endpoint.
const TOKEN_URL: &str = "https://auth.openai.com/oauth/token";
/// Fixed redirect URI; the local callback listener must bind the port it names (1455).
const REDIRECT_URI: &str = "http://localhost:1455/auth/callback";
/// Space-separated scopes; `offline_access` is the OpenID Connect scope that requests a refresh token.
const SCOPE: &str = "openid profile email offline_access";
18
19/// Run the full Codex OAuth flow: open browser, capture callback, exchange token, save credentials.
20///
21pub async fn perform_codex_oauth_flow(store: &dyn OAuthCredentialStorage) -> Result<(), LlmError> {
22    let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
23    let state = generate_random_state();
24
25    let auth_url = Url::parse_with_params(
26        AUTHORIZE_URL,
27        &[
28            ("response_type", "code"),
29            ("client_id", CLIENT_ID),
30            ("redirect_uri", REDIRECT_URI),
31            ("scope", SCOPE),
32            ("code_challenge", pkce_challenge.as_str()),
33            ("code_challenge_method", "S256"),
34            ("state", &state),
35            ("id_token_add_organizations", "true"),
36            ("codex_cli_simplified_flow", "true"),
37            ("originator", "codex_cli_rs"),
38        ],
39    )
40    .map_err(|e| OAuthError::TokenExchange(format!("Failed to build auth URL: {e}")))?;
41
42    // Port 1455 is hardcoded because the Codex API has a fixed redirect URI
43    // (http://localhost:1455/auth/callback) registered with OpenAI's OAuth server.
44    let handler = BrowserOAuthHandler::with_redirect_uri(REDIRECT_URI, 1455)?;
45    let callback = handler.authorize(auth_url.as_str()).await?;
46
47    if callback.state != state {
48        return Err(OAuthError::StateMismatch.into());
49    }
50
51    let oauth_client = BasicClient::new(ClientId::new(CLIENT_ID.to_string()))
52        .set_auth_uri(
53            AuthUrl::new(AUTHORIZE_URL.to_string())
54                .map_err(|e| OAuthError::TokenExchange(format!("invalid auth URL: {e}")))?,
55        )
56        .set_token_uri(
57            TokenUrl::new(TOKEN_URL.to_string())
58                .map_err(|e| OAuthError::TokenExchange(format!("invalid token URL: {e}")))?,
59        )
60        .set_redirect_uri(
61            RedirectUrl::new(REDIRECT_URI.to_string())
62                .map_err(|e| OAuthError::TokenExchange(format!("invalid redirect URI: {e}")))?,
63        );
64
65    let http_client = oauth2::reqwest::Client::builder()
66        .redirect(Policy::none())
67        .build()
68        .map_err(|e| OAuthError::TokenExchange(format!("failed to build HTTP client: {e}")))?;
69
70    let token_response = oauth_client
71        .exchange_code(AuthorizationCode::new(callback.code))
72        .set_pkce_verifier(pkce_verifier)
73        .request_async(&http_client)
74        .await
75        .map_err(|e| OAuthError::TokenExchange(e.to_string()))?;
76
77    let expires_at = token_response.expires_in().map(|duration| {
78        let now_ms = u64::try_from(
79            std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap_or_default().as_millis(),
80        )
81        .unwrap_or(u64::MAX);
82        let duration_ms = u64::try_from(duration.as_millis()).unwrap_or(u64::MAX);
83        now_ms.saturating_add(duration_ms)
84    });
85
86    let credential = OAuthCredential {
87        client_id: CLIENT_ID.to_string(),
88        access_token: token_response.access_token().secret().clone(),
89        refresh_token: token_response.refresh_token().map(|t| t.secret().clone()),
90        expires_at,
91    };
92
93    store.save_credential(super::PROVIDER_ID, credential).await?;
94
95    Ok(())
96}
97
/// An access token held in memory together with its account ID and optional expiry.
struct CachedToken {
    access_token: String,
    account_id: String,
    /// Expiry as Unix-epoch milliseconds; `None` means the token is treated as never expiring.
    expires_at: Option<u64>,
}

impl CachedToken {
    /// Returns `true` once the current wall-clock time (in ms) has reached `expires_at`.
    ///
    /// A token without an expiry is never reported as expired. A clock before the epoch
    /// counts as time 0; a millisecond count that doesn't fit in `u64` saturates to `u64::MAX`.
    fn is_expired(&self) -> bool {
        self.expires_at.is_some_and(|deadline| {
            let since_epoch = std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap_or_default();
            let now_ms = u64::try_from(since_epoch.as_millis()).unwrap_or(u64::MAX);
            now_ms >= deadline
        })
    }
}
118
119/// Manages OAuth tokens for the Codex backend API.
120///
121/// Holds an `Arc<dyn OAuthCredentialStorage>` so callers can swap in keyring-backed,
122/// file-backed, or in-memory stores without changing this type.
123pub struct CodexTokenManager {
124    store: Arc<dyn OAuthCredentialStorage>,
125    credential_key: String,
126    cached: Mutex<Option<CachedToken>>,
127}
128
129impl CodexTokenManager {
130    pub fn new(store: Arc<dyn OAuthCredentialStorage>, credential_key: &str) -> Self {
131        Self { store, credential_key: credential_key.to_string(), cached: Mutex::new(None) }
132    }
133
134    /// Get a valid access token and account ID.
135    ///
136    /// Returns `(access_token, account_id)`. The account ID is extracted from
137    /// the JWT's `https://api.openai.com/auth` claim field `chatgpt_account_id`.
138    pub async fn get_valid_token(&self) -> Result<(String, String), LlmError> {
139        // Check cache first — return if present and not expired
140        {
141            let guard = self.cached.lock().await;
142            if let Some(cached) = guard.as_ref()
143                && !cached.is_expired()
144            {
145                return Ok((cached.access_token.clone(), cached.account_id.clone()));
146            }
147        }
148
149        let credential = self
150            .store
151            .load_credential(&self.credential_key)
152            .await
153            .map_err(|e| OAuthError::NoCredentials(e.to_string()))?
154            .ok_or_else(|| {
155                OAuthError::NoCredentials(
156                    "No Codex OAuth credentials found. Run `aether` and select a codex model to trigger OAuth login."
157                        .to_string(),
158                )
159            })?;
160
161        let account_id = extract_account_id(&credential.access_token)?;
162
163        let cached = CachedToken {
164            access_token: credential.access_token.clone(),
165            account_id: account_id.clone(),
166            expires_at: credential.expires_at,
167        };
168        *self.cached.lock().await = Some(cached);
169
170        Ok((credential.access_token, account_id))
171    }
172
173    /// Clear the cached token (e.g. after a 401 response)
174    pub async fn clear_cache(&self) {
175        *self.cached.lock().await = None;
176    }
177}
178
179/// Extract the account ID from a JWT access token.
180///
181/// The JWT payload contains a claim at `https://api.openai.com/auth`
182/// with a `chatgpt_account_id` field.
183pub fn extract_account_id(access_token: &str) -> Result<String, LlmError> {
184    let parts: Vec<&str> = access_token.split('.').collect();
185    if parts.len() != 3 {
186        return Err(OAuthError::InvalidJwt("expected 3 dot-separated parts".to_string()).into());
187    }
188
189    let decoded = URL_SAFE_NO_PAD
190        .decode(parts[1])
191        .map_err(|e| OAuthError::InvalidJwt(format!("failed to decode payload: {e}")))?;
192
193    let payload: serde_json::Value = serde_json::from_slice(&decoded)
194        .map_err(|e| OAuthError::InvalidJwt(format!("failed to parse payload: {e}")))?;
195
196    let account_id = payload
197        .get("https://api.openai.com/auth")
198        .and_then(|auth| auth.get("chatgpt_account_id"))
199        .and_then(|v| v.as_str())
200        .ok_or_else(|| OAuthError::InvalidJwt("missing chatgpt_account_id in token".to_string()))?;
201
202    Ok(account_id.to_string())
203}
204
205fn generate_random_state() -> String {
206    uuid::Uuid::new_v4().to_string()
207}
208
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Current Unix time in milliseconds, for building expiry timestamps.
    fn now_ms() -> u64 {
        u64::try_from(std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_millis())
            .unwrap()
    }

    /// Create a test JWT with a given payload (header and signature are placeholders).
    fn make_test_jwt(payload: &serde_json::Value) -> String {
        let header = URL_SAFE_NO_PAD.encode(r#"{"alg":"RS256","typ":"JWT"}"#);
        let payload_json = serde_json::to_string(payload).unwrap();
        let payload_b64url = URL_SAFE_NO_PAD.encode(payload_json.as_bytes());
        format!("{header}.{payload_b64url}.fake_signature")
    }

    /// Build a `CachedToken` with placeholder strings and the given expiry.
    fn cached_token(expires_at: Option<u64>) -> CachedToken {
        CachedToken { access_token: "tok".to_string(), account_id: "acct".to_string(), expires_at }
    }

    #[test]
    fn extract_account_id_from_valid_jwt() {
        let payload = serde_json::json!({
            "sub": "user_123",
            "https://api.openai.com/auth": {
                "chatgpt_account_id": "acct_abc123"
            }
        });

        let jwt = make_test_jwt(&payload);
        assert_eq!(extract_account_id(&jwt).unwrap(), "acct_abc123");
    }

    #[test]
    fn extract_account_id_missing_claim() {
        let payload = serde_json::json!({
            "sub": "user_123"
        });

        let jwt = make_test_jwt(&payload);
        let err = extract_account_id(&jwt).unwrap_err();
        assert!(err.to_string().contains("chatgpt_account_id"), "{err}");
    }

    #[test]
    fn extract_account_id_non_string_claim() {
        let payload = serde_json::json!({
            "https://api.openai.com/auth": { "chatgpt_account_id": 42 }
        });

        let err = extract_account_id(&make_test_jwt(&payload)).unwrap_err();
        assert!(err.to_string().contains("chatgpt_account_id"), "{err}");
    }

    #[test]
    fn extract_account_id_invalid_jwt_format() {
        for token in ["not.a.valid.jwt.too.many.parts", "toofewparts", "two.parts", ""] {
            let err = extract_account_id(token).unwrap_err();
            assert!(err.to_string().contains("3 dot-separated parts"), "token {token:?}: {err}");
        }
    }

    #[test]
    fn extract_account_id_invalid_base64() {
        let err = extract_account_id("header.!!!invalid!!!.signature").unwrap_err();
        assert!(err.to_string().contains("failed to decode payload"), "{err}");
    }

    #[test]
    fn extract_account_id_non_json_payload() {
        let token = format!("header.{}.signature", URL_SAFE_NO_PAD.encode(b"not json"));
        let err = extract_account_id(&token).unwrap_err();
        assert!(err.to_string().contains("failed to parse payload"), "{err}");
    }

    /// Mirrors the parameter list built in `perform_codex_oauth_flow`; keep the two in sync.
    #[test]
    fn auth_url_is_well_formed() {
        let (pkce_challenge, _) = PkceCodeChallenge::new_random_sha256();
        let state = "test-state";

        let auth_url = Url::parse_with_params(
            AUTHORIZE_URL,
            &[
                ("response_type", "code"),
                ("client_id", CLIENT_ID),
                ("redirect_uri", REDIRECT_URI),
                ("scope", SCOPE),
                ("code_challenge", pkce_challenge.as_str()),
                ("code_challenge_method", "S256"),
                ("state", state),
                ("id_token_add_organizations", "true"),
                ("codex_cli_simplified_flow", "true"),
                ("originator", "codex_cli_rs"),
            ],
        )
        .unwrap();

        assert!(auth_url.as_str().starts_with(AUTHORIZE_URL));

        // Compare decoded values, not substrings, so encoding mistakes are caught.
        let params: HashMap<String, String> = auth_url.query_pairs().into_owned().collect();
        let param = |key: &str| params.get(key).map(String::as_str);
        assert_eq!(param("response_type"), Some("code"));
        assert_eq!(param("client_id"), Some(CLIENT_ID));
        assert_eq!(param("redirect_uri"), Some(REDIRECT_URI));
        assert_eq!(param("scope"), Some(SCOPE));
        assert_eq!(param("code_challenge"), Some(pkce_challenge.as_str()));
        assert_eq!(param("code_challenge_method"), Some("S256"));
        assert_eq!(param("state"), Some(state));
        assert_eq!(param("originator"), Some("codex_cli_rs"));
    }

    #[test]
    fn generate_random_state_is_valid_uuid() {
        let state = generate_random_state();
        assert!(!state.is_empty());
        assert!(uuid::Uuid::parse_str(&state).is_ok());
        assert_ne!(state, generate_random_state(), "state must differ between logins");
    }

    #[test]
    fn oauth_constants_are_valid() {
        assert!(AUTHORIZE_URL.starts_with("https://"));
        assert!(TOKEN_URL.starts_with("https://"));
        assert!(REDIRECT_URI.starts_with("http://localhost:"));
        assert!(SCOPE.contains("openid"));
        // The login flow binds port 1455 separately; it must match the redirect URI.
        assert_eq!(Url::parse(REDIRECT_URI).unwrap().port(), Some(1455));
    }

    #[test]
    fn cached_token_not_expired_when_no_expiry() {
        assert!(!cached_token(None).is_expired());
    }

    #[test]
    fn cached_token_not_expired_when_future() {
        let one_hour_from_now = now_ms() + 3_600_000;
        assert!(!cached_token(Some(one_hour_from_now)).is_expired());
    }

    #[test]
    fn cached_token_expired_when_past() {
        assert!(cached_token(Some(1000)).is_expired()); // way in the past
        assert!(cached_token(Some(0)).is_expired());
    }
}