Skip to main content

llm/providers/codex/oauth.rs

1use crate::LlmError;
2use crate::oauth::BrowserOAuthHandler;
3use crate::oauth::OAuthError;
4use crate::oauth::OAuthHandler;
5use crate::oauth::credential_store::{OAuthCredential, OAuthCredentialStorage, OAuthCredentialStore};
6use base64::Engine;
7use base64::engine::general_purpose::URL_SAFE_NO_PAD;
8use oauth2::TokenResponse;
9use oauth2::basic::BasicClient;
10use oauth2::reqwest::redirect::Policy;
11use oauth2::{AuthUrl, AuthorizationCode, ClientId, PkceCodeChallenge, RedirectUrl, TokenUrl};
12use tokio::sync::Mutex;
13use url::Url;
14
15const CLIENT_ID: &str = "app_EMoamEEZ73f0CkXaXp7hrann";
16const AUTHORIZE_URL: &str = "https://auth.openai.com/oauth/authorize";
17const TOKEN_URL: &str = "https://auth.openai.com/oauth/token";
18const REDIRECT_URI: &str = "http://localhost:1455/auth/callback";
19const SCOPE: &str = "openid profile email offline_access";
20
21/// Run the full Codex OAuth flow: open browser, capture callback, exchange token, save credentials.
22///
23/// This is designed to be called from `aether auth codex` CLI command.
24pub async fn perform_codex_oauth_flow() -> Result<(), LlmError> {
25    let store = OAuthCredentialStore::with_platform_store()?;
26    perform_codex_oauth_flow_with_store(&store).await
27}
28
29pub async fn perform_codex_oauth_flow_with_store(store: &impl OAuthCredentialStorage) -> Result<(), LlmError> {
30    let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
31    let state = generate_random_state();
32
33    let auth_url = Url::parse_with_params(
34        AUTHORIZE_URL,
35        &[
36            ("response_type", "code"),
37            ("client_id", CLIENT_ID),
38            ("redirect_uri", REDIRECT_URI),
39            ("scope", SCOPE),
40            ("code_challenge", pkce_challenge.as_str()),
41            ("code_challenge_method", "S256"),
42            ("state", &state),
43            ("id_token_add_organizations", "true"),
44            ("codex_cli_simplified_flow", "true"),
45            ("originator", "codex_cli_rs"),
46        ],
47    )
48    .map_err(|e| OAuthError::TokenExchange(format!("Failed to build auth URL: {e}")))?;
49
50    // Port 1455 is hardcoded because the Codex API has a fixed redirect URI
51    // (http://localhost:1455/auth/callback) registered with OpenAI's OAuth server.
52    let handler = BrowserOAuthHandler::with_redirect_uri(REDIRECT_URI, 1455)?;
53    let callback = handler.authorize(auth_url.as_str()).await?;
54
55    if callback.state != state {
56        return Err(OAuthError::StateMismatch.into());
57    }
58
59    let oauth_client = BasicClient::new(ClientId::new(CLIENT_ID.to_string()))
60        .set_auth_uri(
61            AuthUrl::new(AUTHORIZE_URL.to_string())
62                .map_err(|e| OAuthError::TokenExchange(format!("invalid auth URL: {e}")))?,
63        )
64        .set_token_uri(
65            TokenUrl::new(TOKEN_URL.to_string())
66                .map_err(|e| OAuthError::TokenExchange(format!("invalid token URL: {e}")))?,
67        )
68        .set_redirect_uri(
69            RedirectUrl::new(REDIRECT_URI.to_string())
70                .map_err(|e| OAuthError::TokenExchange(format!("invalid redirect URI: {e}")))?,
71        );
72
73    let http_client = oauth2::reqwest::Client::builder()
74        .redirect(Policy::none())
75        .build()
76        .map_err(|e| OAuthError::TokenExchange(format!("failed to build HTTP client: {e}")))?;
77
78    let token_response = oauth_client
79        .exchange_code(AuthorizationCode::new(callback.code))
80        .set_pkce_verifier(pkce_verifier)
81        .request_async(&http_client)
82        .await
83        .map_err(|e| OAuthError::TokenExchange(e.to_string()))?;
84
85    let expires_at = token_response.expires_in().map(|duration| {
86        let now_ms = u64::try_from(
87            std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap_or_default().as_millis(),
88        )
89        .unwrap_or(u64::MAX);
90        let duration_ms = u64::try_from(duration.as_millis()).unwrap_or(u64::MAX);
91        now_ms.saturating_add(duration_ms)
92    });
93
94    let credential = OAuthCredential {
95        client_id: CLIENT_ID.to_string(),
96        access_token: token_response.access_token().secret().clone(),
97        refresh_token: token_response.refresh_token().map(|t| t.secret().clone()),
98        expires_at,
99    };
100
101    store.save_credential(super::PROVIDER_ID, credential).await?;
102
103    Ok(())
104}
105
/// In-memory copy of an access token together with the account ID decoded from it.
struct CachedToken {
    /// The OAuth access token (a JWT).
    access_token: String,
    /// The `chatgpt_account_id` extracted from the token's JWT payload.
    account_id: String,
    /// Absolute expiry as Unix time in milliseconds; `None` means the token never expires.
    expires_at: Option<u64>,
}

impl CachedToken {
    /// Report whether the token's expiry time has been reached.
    ///
    /// A token without an expiry is never considered expired. The comparison is inclusive:
    /// the token counts as expired at exactly `expires_at`.
    fn is_expired(&self) -> bool {
        self.expires_at.is_some_and(|deadline_ms| {
            let since_epoch = std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap_or_default();
            let now_ms: u64 = since_epoch.as_millis().try_into().unwrap_or(u64::MAX);
            deadline_ms <= now_ms
        })
    }
}
126
127/// Manages OAuth tokens for the Codex backend API.
128///
129/// Generic over `OAuthCredentialStorage` so tests can inject an in-memory fake
130/// instead of hitting the OS keychain.
131pub struct CodexTokenManager<T: OAuthCredentialStorage> {
132    store: T,
133    credential_key: String,
134    cached: Mutex<Option<CachedToken>>,
135}
136
137impl<T: OAuthCredentialStorage> CodexTokenManager<T> {
138    pub fn new(store: T, credential_key: &str) -> Self {
139        Self { store, credential_key: credential_key.to_string(), cached: Mutex::new(None) }
140    }
141
142    /// Get a valid access token and account ID.
143    ///
144    /// Returns `(access_token, account_id)`. The account ID is extracted from
145    /// the JWT's `https://api.openai.com/auth` claim field `chatgpt_account_id`.
146    pub async fn get_valid_token(&self) -> Result<(String, String), LlmError> {
147        // Check cache first — return if present and not expired
148        {
149            let guard = self.cached.lock().await;
150            if let Some(cached) = guard.as_ref()
151                && !cached.is_expired()
152            {
153                return Ok((cached.access_token.clone(), cached.account_id.clone()));
154            }
155        }
156
157        let credential = self
158            .store
159            .load_credential(&self.credential_key)
160            .await
161            .map_err(|e| OAuthError::NoCredentials(e.to_string()))?
162            .ok_or_else(|| {
163                OAuthError::NoCredentials(
164                    "No Codex OAuth credentials found. Run `aether` and select a codex model to trigger OAuth login."
165                        .to_string(),
166                )
167            })?;
168
169        let account_id = extract_account_id(&credential.access_token)?;
170
171        let cached = CachedToken {
172            access_token: credential.access_token.clone(),
173            account_id: account_id.clone(),
174            expires_at: credential.expires_at,
175        };
176        *self.cached.lock().await = Some(cached);
177
178        Ok((credential.access_token, account_id))
179    }
180
181    /// Clear the cached token (e.g. after a 401 response)
182    pub async fn clear_cache(&self) {
183        *self.cached.lock().await = None;
184    }
185}
186
187/// Extract the account ID from a JWT access token.
188///
189/// The JWT payload contains a claim at `https://api.openai.com/auth`
190/// with a `chatgpt_account_id` field.
191pub fn extract_account_id(access_token: &str) -> Result<String, LlmError> {
192    let parts: Vec<&str> = access_token.split('.').collect();
193    if parts.len() != 3 {
194        return Err(OAuthError::InvalidJwt("expected 3 dot-separated parts".to_string()).into());
195    }
196
197    let decoded = URL_SAFE_NO_PAD
198        .decode(parts[1])
199        .map_err(|e| OAuthError::InvalidJwt(format!("failed to decode payload: {e}")))?;
200
201    let payload: serde_json::Value = serde_json::from_slice(&decoded)
202        .map_err(|e| OAuthError::InvalidJwt(format!("failed to parse payload: {e}")))?;
203
204    let account_id = payload
205        .get("https://api.openai.com/auth")
206        .and_then(|auth| auth.get("chatgpt_account_id"))
207        .and_then(|v| v.as_str())
208        .ok_or_else(|| OAuthError::InvalidJwt("missing chatgpt_account_id in token".to_string()))?;
209
210    Ok(account_id.to_string())
211}
212
213fn generate_random_state() -> String {
214    uuid::Uuid::new_v4().to_string()
215}
216
#[cfg(test)]
mod tests {
    use super::*;

    /// Current Unix time in milliseconds.
    fn now_unix_ms() -> u64 {
        let since_epoch = std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap();
        u64::try_from(since_epoch.as_millis()).unwrap()
    }

    /// Build a `CachedToken` with fixed token/account values and the given expiry.
    fn token_expiring_at(expires_at: Option<u64>) -> CachedToken {
        CachedToken { access_token: "tok".to_string(), account_id: "acct".to_string(), expires_at }
    }

    /// Create an unsigned test JWT whose payload segment encodes `payload`.
    fn make_test_jwt(payload: &serde_json::Value) -> String {
        let segments = [
            URL_SAFE_NO_PAD.encode(r#"{"alg":"RS256","typ":"JWT"}"#),
            URL_SAFE_NO_PAD.encode(serde_json::to_vec(payload).unwrap()),
            "fake_signature".to_string(),
        ];
        segments.join(".")
    }

    #[test]
    fn extract_account_id_from_valid_jwt() {
        let jwt = make_test_jwt(&serde_json::json!({
            "sub": "user_123",
            "https://api.openai.com/auth": { "chatgpt_account_id": "acct_abc123" }
        }));
        assert_eq!(extract_account_id(&jwt).unwrap(), "acct_abc123");
    }

    #[test]
    fn extract_account_id_missing_claim() {
        let jwt = make_test_jwt(&serde_json::json!({ "sub": "user_123" }));
        let err = extract_account_id(&jwt).expect_err("a token without the auth claim must be rejected");
        assert!(err.to_string().contains("chatgpt_account_id"));
    }

    #[test]
    fn extract_account_id_invalid_jwt_format() {
        for malformed in ["not.a.valid.jwt.too.many.parts", "toofewparts"] {
            assert!(extract_account_id(malformed).is_err(), "{malformed} should be rejected");
        }
    }

    #[test]
    fn extract_account_id_invalid_base64() {
        assert!(extract_account_id("header.!!!invalid!!!.signature").is_err());
    }

    #[test]
    fn auth_url_is_well_formed() {
        let (challenge, _verifier) = PkceCodeChallenge::new_random_sha256();
        let params = [
            ("response_type", "code"),
            ("client_id", CLIENT_ID),
            ("redirect_uri", REDIRECT_URI),
            ("scope", SCOPE),
            ("code_challenge", challenge.as_str()),
            ("code_challenge_method", "S256"),
            ("state", "test-state"),
            ("id_token_add_organizations", "true"),
            ("codex_cli_simplified_flow", "true"),
            ("originator", "codex_cli_rs"),
        ];

        let url = Url::parse_with_params(AUTHORIZE_URL, &params).unwrap();
        let rendered = url.as_str();

        assert!(rendered.starts_with(AUTHORIZE_URL));
        for needle in ["client_id=", "redirect_uri=", "scope=", "code_challenge=", "state=test-state"] {
            assert!(rendered.contains(needle), "missing {needle} in {rendered}");
        }
    }

    #[test]
    fn generate_random_state_is_valid_uuid() {
        let state = generate_random_state();
        assert!(!state.is_empty());
        assert!(uuid::Uuid::parse_str(&state).is_ok());
    }

    #[test]
    fn oauth_constants_are_valid() {
        assert!(AUTHORIZE_URL.starts_with("https://"));
        assert!(TOKEN_URL.starts_with("https://"));
        assert!(REDIRECT_URI.starts_with("http://localhost:"));
        assert!(SCOPE.contains("openid"));
    }

    #[test]
    fn cached_token_not_expired_when_no_expiry() {
        assert!(!token_expiring_at(None).is_expired());
    }

    #[test]
    fn cached_token_not_expired_when_future() {
        let one_hour_from_now = now_unix_ms() + 3_600_000;
        assert!(!token_expiring_at(Some(one_hour_from_now)).is_expired());
    }

    #[test]
    fn cached_token_expired_when_past() {
        // 1 second after the Unix epoch: long gone.
        assert!(token_expiring_at(Some(1000)).is_expired());
    }
}