1use crate::auth::oauth_common::{OAuthProfileState, OAuthTokenManager};
2use crate::auth::profiles::{
3 profile_id, ProviderAuthKind, ProviderAuthProfile, ProviderAuthProfilesData, ProviderTokenSet,
4};
5use crate::auth::store::ProviderAuthStore;
6use anyhow::{Context, Result};
7use base64::Engine;
8use chrono::Utc;
9use serde::Deserialize;
10use std::path::Path;
11use std::time::Duration;
12
/// Provider key under which OpenAI Codex OAuth profiles are stored.
const OPENAI_CODEX_PROVIDER: &str = "openai-codex";
/// OAuth client id sent with refresh-token grants to OpenAI's token endpoint.
const OPENAI_OAUTH_CLIENT_ID: &str = "app_EMoamEEZ73f0CkXaXp7hrann";
/// OpenAI OAuth token endpoint that refresh-token grants are POSTed to.
const OPENAI_OAUTH_TOKEN_URL: &str = "https://auth.openai.com/oauth/token";
/// Access tokens expiring within this many seconds are refreshed before being handed out.
const OPENAI_REFRESH_SKEW_SECS: u64 = 90;

/// Profile name used as the last-resort fallback when a provider has neither an explicit
/// override nor an active profile.
const DEFAULT_PROFILE_NAME: &str = "default";
18
19#[derive(Clone)]
20pub struct ProviderAuthService {
21 store: ProviderAuthStore,
22 client: reqwest::Client,
23}
24
25#[derive(Debug, Deserialize)]
26struct OpenAiTokenResponse {
27 access_token: String,
28 #[serde(default)]
29 refresh_token: Option<String>,
30 #[serde(default)]
31 id_token: Option<String>,
32 #[serde(default)]
33 expires_in: Option<i64>,
34 #[serde(default)]
35 token_type: Option<String>,
36 #[serde(default)]
37 scope: Option<String>,
38}
39
40impl ProviderAuthService {
41 pub fn new(config_dir: &Path) -> Self {
42 Self {
43 store: ProviderAuthStore::new(config_dir),
44 client: reqwest::Client::new(),
45 }
46 }
47
48 pub fn store(&self) -> &ProviderAuthStore {
49 &self.store
50 }
51
52 pub async fn load_profiles(&self) -> Result<ProviderAuthProfilesData> {
53 self.store.load().await
54 }
55
56 pub async fn store_openai_codex_tokens(
57 &self,
58 profile_name: &str,
59 token_set: ProviderTokenSet,
60 ) -> Result<ProviderAuthProfile> {
61 self.store_oauth_profile(
62 OPENAI_CODEX_PROVIDER,
63 profile_name,
64 OAuthProfileState {
65 account_id: extract_account_id_from_jwt(&token_set.access_token),
66 token_set,
67 metadata: Default::default(),
68 },
69 true,
70 )
71 .await
72 }
73
74 pub async fn store_oauth_profile(
75 &self,
76 provider: &str,
77 profile_name: &str,
78 state: OAuthProfileState,
79 set_active: bool,
80 ) -> Result<ProviderAuthProfile> {
81 let mut profile =
82 ProviderAuthProfile::new_oauth(provider, profile_name, state.token_set.clone());
83 profile.account_id = state.account_id;
84 profile.metadata = state.metadata;
85 self.store
86 .upsert_profile(profile.clone(), set_active)
87 .await?;
88 Ok(profile)
89 }
90
91 pub async fn get_profile(
92 &self,
93 provider: &str,
94 profile_override: Option<&str>,
95 ) -> Result<Option<ProviderAuthProfile>> {
96 let data = self.store.load().await?;
97 let Some(id) = select_profile_id(&data, provider, profile_override) else {
98 return Ok(None);
99 };
100 Ok(data.profiles.get(&id).cloned())
101 }
102
103 pub async fn get_active_profile(&self, provider: &str) -> Result<Option<ProviderAuthProfile>> {
104 self.get_profile(provider, None).await
105 }
106
107 pub async fn get_provider_bearer_token(
108 &self,
109 provider: &str,
110 profile_override: Option<&str>,
111 ) -> Result<Option<String>> {
112 let profile = self.get_profile(provider, profile_override).await?;
113 let Some(profile) = profile else {
114 return Ok(None);
115 };
116 Ok(match profile.kind {
117 ProviderAuthKind::OAuth => profile.token_set.and_then(|token_set| {
118 (!token_set.access_token.trim().is_empty()).then_some(token_set.access_token)
119 }),
120 ProviderAuthKind::Token => profile
121 .token
122 .and_then(|token| (!token.trim().is_empty()).then_some(token)),
123 })
124 }
125
126 pub async fn get_valid_openai_codex_access_token(
127 &self,
128 profile_override: Option<&str>,
129 ) -> Result<Option<String>> {
130 let profile = self
131 .get_profile(OPENAI_CODEX_PROVIDER, profile_override)
132 .await?;
133 let Some(profile) = profile else {
134 return Ok(None);
135 };
136 let Some(token_set) = profile.token_set.clone() else {
137 anyhow::bail!("OpenAI Codex auth profile is missing OAuth token set");
138 };
139
140 if !token_set.is_expiring_within(Duration::from_secs(OPENAI_REFRESH_SKEW_SECS)) {
141 return Ok(Some(token_set.access_token));
142 }
143
144 let Some(refresh_token) = token_set.refresh_token.clone() else {
145 return Ok(Some(token_set.access_token));
146 };
147 let refreshed = self
148 .refresh_openai_codex_tokens_with_refresh_token(&refresh_token)
149 .await?;
150 let access_token = refreshed.access_token.clone();
151 let account_id = extract_account_id_from_jwt(&access_token);
152 let profile_id = profile.id.clone();
153 self.store
154 .update_profile(&profile_id, |profile| {
155 profile.token_set = Some(refreshed.clone());
156 profile.account_id = account_id.clone();
157 Ok(())
158 })
159 .await?;
160 Ok(Some(access_token))
161 }
162
163 pub async fn refresh_openai_codex_tokens(
164 &self,
165 profile_override: Option<&str>,
166 ) -> Result<ProviderAuthProfile> {
167 self.refresh_oauth_profile(
168 OPENAI_CODEX_PROVIDER,
169 profile_override,
170 &OpenAiCodexTokenManager {
171 client: self.client.clone(),
172 },
173 )
174 .await
175 }
176
177 pub async fn refresh_oauth_profile(
178 &self,
179 provider: &str,
180 profile_override: Option<&str>,
181 token_manager: &dyn OAuthTokenManager,
182 ) -> Result<ProviderAuthProfile> {
183 let profile = self
184 .get_profile(provider, profile_override)
185 .await?
186 .ok_or_else(|| {
187 anyhow::anyhow!("OAuth auth profile not found for provider '{provider}'")
188 })?;
189 let refresh_token = profile
190 .token_set
191 .as_ref()
192 .and_then(|tokens| tokens.refresh_token.clone())
193 .ok_or_else(|| {
194 anyhow::anyhow!(
195 "OAuth auth profile for provider '{provider}' does not contain a refresh token"
196 )
197 })?;
198 let refreshed = token_manager.refresh_oauth_state(&refresh_token).await?;
199 let account_id = refreshed
200 .account_id
201 .clone()
202 .or_else(|| token_manager.extract_account_id(&refreshed.token_set.access_token));
203 let metadata = refreshed.metadata.clone();
204 let token_set = refreshed.token_set.clone();
205 self.store
206 .update_profile(&profile.id, |existing| {
207 existing.token_set = Some(token_set.clone());
208 existing.account_id = account_id.clone();
209 existing.metadata = metadata.clone();
210 Ok(())
211 })
212 .await
213 }
214
215 pub async fn set_active_profile(&self, provider: &str, profile_name: &str) -> Result<String> {
216 let requested_id = if profile_name.contains(':') {
217 profile_name.to_string()
218 } else {
219 profile_id(provider, profile_name)
220 };
221 self.store
222 .set_active_profile(provider, &requested_id)
223 .await?;
224 Ok(requested_id)
225 }
226
227 pub async fn remove_profile(&self, provider: &str, profile_name: &str) -> Result<bool> {
228 let requested_id = if profile_name.contains(':') {
229 profile_name.to_string()
230 } else {
231 profile_id(provider, profile_name)
232 };
233 self.store.remove_profile(&requested_id).await
234 }
235
236 async fn refresh_openai_codex_tokens_with_refresh_token(
237 &self,
238 refresh_token: &str,
239 ) -> Result<ProviderTokenSet> {
240 OpenAiCodexTokenManager {
241 client: self.client.clone(),
242 }
243 .refresh_oauth_state(refresh_token)
244 .await
245 .map(|state| state.token_set)
246 }
247}
248
249struct OpenAiCodexTokenManager {
250 client: reqwest::Client,
251}
252
253#[async_trait::async_trait]
254impl OAuthTokenManager for OpenAiCodexTokenManager {
255 async fn refresh_oauth_state(&self, refresh_token: &str) -> Result<OAuthProfileState> {
256 let response = self
257 .client
258 .post(OPENAI_OAUTH_TOKEN_URL)
259 .form(&[
260 ("grant_type", "refresh_token"),
261 ("refresh_token", refresh_token),
262 ("client_id", OPENAI_OAUTH_CLIENT_ID),
263 ])
264 .send()
265 .await
266 .context("Failed to refresh OpenAI Codex OAuth token")?;
267 if !response.status().is_success() {
268 let status = response.status();
269 let body = response.text().await.unwrap_or_default();
270 anyhow::bail!("OpenAI Codex token refresh failed ({status}): {body}");
271 }
272 let parsed: OpenAiTokenResponse = response
273 .json()
274 .await
275 .context("Failed to parse OpenAI Codex token refresh response")?;
276 let expires_at = parsed
277 .expires_in
278 .map(|seconds| Utc::now() + chrono::Duration::seconds(seconds));
279 Ok(ProviderTokenSet {
280 access_token: parsed.access_token,
281 refresh_token: parsed
282 .refresh_token
283 .or_else(|| Some(refresh_token.to_string())),
284 id_token: parsed.id_token,
285 expires_at,
286 token_type: parsed.token_type,
287 scope: parsed.scope,
288 })
289 .map(|token_set| OAuthProfileState {
290 account_id: extract_account_id_from_jwt(&token_set.access_token),
291 token_set,
292 metadata: Default::default(),
293 })
294 }
295
296 fn extract_account_id(&self, access_token: &str) -> Option<String> {
297 extract_account_id_from_jwt(access_token)
298 }
299}
300
301fn select_profile_id(
302 data: &ProviderAuthProfilesData,
303 provider: &str,
304 profile_override: Option<&str>,
305) -> Option<String> {
306 profile_override
307 .map(|value| {
308 if value.contains(':') {
309 value.to_string()
310 } else {
311 profile_id(provider, value)
312 }
313 })
314 .or_else(|| data.active_profiles.get(provider).cloned())
315 .or_else(|| {
316 let fallback = profile_id(provider, DEFAULT_PROFILE_NAME);
317 data.profiles.contains_key(&fallback).then_some(fallback)
318 })
319}
320
321pub fn extract_account_id_from_jwt(token: &str) -> Option<String> {
322 let payload = token.split('.').nth(1)?;
323 let decoded = base64::engine::general_purpose::URL_SAFE_NO_PAD
324 .decode(payload)
325 .ok()
326 .or_else(|| {
327 base64::engine::general_purpose::URL_SAFE
328 .decode(payload)
329 .ok()
330 })?;
331 let json: serde_json::Value = serde_json::from_slice(&decoded).ok()?;
332 for key in ["https://api.openai.com/auth", "org_id", "account_id", "sub"] {
333 if let Some(value) = json.get(key).and_then(|value| value.as_str()) {
334 return Some(value.to_string());
335 }
336 }
337 None
338}
339
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::BTreeMap;
    use tempfile::tempdir;

    /// Token set with fixed `access`/`refresh` values; `token_type` and `scope` vary per test.
    fn sample_tokens(token_type: Option<&str>, scope: Option<&str>) -> ProviderTokenSet {
        ProviderTokenSet {
            access_token: "access".into(),
            refresh_token: Some("refresh".into()),
            id_token: None,
            expires_at: None,
            token_type: token_type.map(str::to_string),
            scope: scope.map(str::to_string),
        }
    }

    /// Test double: echoes the refresh token back, reports a fixed account id and an
    /// `api_base` metadata entry, and treats the access token itself as the account id.
    struct FakeTokenManager;

    #[async_trait::async_trait]
    impl OAuthTokenManager for FakeTokenManager {
        async fn refresh_oauth_state(&self, refresh_token: &str) -> Result<OAuthProfileState> {
            let token_set = ProviderTokenSet {
                access_token: "refreshed-access".into(),
                refresh_token: Some(refresh_token.to_string()),
                id_token: None,
                expires_at: None,
                token_type: Some("Bearer".into()),
                scope: Some("openid".into()),
            };
            let metadata = BTreeMap::from([(
                "api_base".to_string(),
                "https://oauth.example/v1".to_string(),
            )]);
            Ok(OAuthProfileState {
                token_set,
                account_id: Some("acct-1".into()),
                metadata,
            })
        }

        fn extract_account_id(&self, access_token: &str) -> Option<String> {
            Some(access_token.to_string())
        }
    }

    #[tokio::test]
    async fn store_and_get_bearer_token() {
        let dir = tempdir().unwrap();
        let service = ProviderAuthService::new(dir.path());
        service
            .store_openai_codex_tokens("default", sample_tokens(Some("Bearer"), Some("openid")))
            .await
            .unwrap();

        let bearer = service
            .get_provider_bearer_token("openai-codex", None)
            .await
            .unwrap();
        assert_eq!(bearer.as_deref(), Some("access"));
    }

    #[tokio::test]
    async fn set_active_profile_uses_profile_name() {
        let dir = tempdir().unwrap();
        let service = ProviderAuthService::new(dir.path());
        service
            .store_openai_codex_tokens("work", sample_tokens(None, None))
            .await
            .unwrap();

        let selected = service
            .set_active_profile("openai-codex", "work")
            .await
            .unwrap();
        assert_eq!(selected, "openai-codex:work");
    }

    #[tokio::test]
    async fn refresh_oauth_profile_updates_metadata_for_generic_provider() {
        let dir = tempdir().unwrap();
        let service = ProviderAuthService::new(dir.path());
        let initial_state = OAuthProfileState {
            token_set: sample_tokens(Some("Bearer"), Some("openid")),
            account_id: None,
            metadata: BTreeMap::new(),
        };
        service
            .store_oauth_profile("qwen-login", "default", initial_state, true)
            .await
            .unwrap();

        let updated = service
            .refresh_oauth_profile("qwen-login", None, &FakeTokenManager)
            .await
            .unwrap();

        let refreshed_access = &updated.token_set.as_ref().unwrap().access_token;
        assert_eq!(refreshed_access, "refreshed-access");
        assert_eq!(updated.account_id.as_deref(), Some("acct-1"));
        assert_eq!(
            updated.metadata.get("api_base").map(String::as_str),
            Some("https://oauth.example/v1")
        );
    }
}