Skip to main content

agent_diva_core/auth/
service.rs

1use crate::auth::oauth_common::{OAuthProfileState, OAuthTokenManager};
2use crate::auth::profiles::{
3    profile_id, ProviderAuthKind, ProviderAuthProfile, ProviderAuthProfilesData, ProviderTokenSet,
4};
5use crate::auth::store::ProviderAuthStore;
6use anyhow::{Context, Result};
7use base64::Engine;
8use chrono::Utc;
9use serde::Deserialize;
10use std::path::Path;
11use std::time::Duration;
12
/// Provider key under which OpenAI Codex OAuth profiles are stored.
const OPENAI_CODEX_PROVIDER: &str = "openai-codex";
/// OpenAI OAuth token endpoint that refresh-token grants are posted to.
const OPENAI_OAUTH_TOKEN_URL: &str = "https://auth.openai.com/oauth/token";
/// OAuth client id sent with every OpenAI refresh-token grant.
const OPENAI_OAUTH_CLIENT_ID: &str = "app_EMoamEEZ73f0CkXaXp7hrann";
/// Access tokens expiring within this many seconds are refreshed before being handed out.
const OPENAI_REFRESH_SKEW_SECS: u64 = 90;

/// Profile name tried last, when neither an override nor an active profile is recorded.
const DEFAULT_PROFILE_NAME: &str = "default";
18
19#[derive(Clone)]
20pub struct ProviderAuthService {
21    store: ProviderAuthStore,
22    client: reqwest::Client,
23}
24
25#[derive(Debug, Deserialize)]
26struct OpenAiTokenResponse {
27    access_token: String,
28    #[serde(default)]
29    refresh_token: Option<String>,
30    #[serde(default)]
31    id_token: Option<String>,
32    #[serde(default)]
33    expires_in: Option<i64>,
34    #[serde(default)]
35    token_type: Option<String>,
36    #[serde(default)]
37    scope: Option<String>,
38}
39
40impl ProviderAuthService {
41    pub fn new(config_dir: &Path) -> Self {
42        Self {
43            store: ProviderAuthStore::new(config_dir),
44            client: reqwest::Client::new(),
45        }
46    }
47
48    pub fn store(&self) -> &ProviderAuthStore {
49        &self.store
50    }
51
52    pub async fn load_profiles(&self) -> Result<ProviderAuthProfilesData> {
53        self.store.load().await
54    }
55
56    pub async fn store_openai_codex_tokens(
57        &self,
58        profile_name: &str,
59        token_set: ProviderTokenSet,
60    ) -> Result<ProviderAuthProfile> {
61        self.store_oauth_profile(
62            OPENAI_CODEX_PROVIDER,
63            profile_name,
64            OAuthProfileState {
65                account_id: extract_account_id_from_jwt(&token_set.access_token),
66                token_set,
67                metadata: Default::default(),
68            },
69            true,
70        )
71        .await
72    }
73
74    pub async fn store_oauth_profile(
75        &self,
76        provider: &str,
77        profile_name: &str,
78        state: OAuthProfileState,
79        set_active: bool,
80    ) -> Result<ProviderAuthProfile> {
81        let mut profile =
82            ProviderAuthProfile::new_oauth(provider, profile_name, state.token_set.clone());
83        profile.account_id = state.account_id;
84        profile.metadata = state.metadata;
85        self.store
86            .upsert_profile(profile.clone(), set_active)
87            .await?;
88        Ok(profile)
89    }
90
91    pub async fn get_profile(
92        &self,
93        provider: &str,
94        profile_override: Option<&str>,
95    ) -> Result<Option<ProviderAuthProfile>> {
96        let data = self.store.load().await?;
97        let Some(id) = select_profile_id(&data, provider, profile_override) else {
98            return Ok(None);
99        };
100        Ok(data.profiles.get(&id).cloned())
101    }
102
103    pub async fn get_active_profile(&self, provider: &str) -> Result<Option<ProviderAuthProfile>> {
104        self.get_profile(provider, None).await
105    }
106
107    pub async fn get_provider_bearer_token(
108        &self,
109        provider: &str,
110        profile_override: Option<&str>,
111    ) -> Result<Option<String>> {
112        let profile = self.get_profile(provider, profile_override).await?;
113        let Some(profile) = profile else {
114            return Ok(None);
115        };
116        Ok(match profile.kind {
117            ProviderAuthKind::OAuth => profile.token_set.and_then(|token_set| {
118                (!token_set.access_token.trim().is_empty()).then_some(token_set.access_token)
119            }),
120            ProviderAuthKind::Token => profile
121                .token
122                .and_then(|token| (!token.trim().is_empty()).then_some(token)),
123        })
124    }
125
126    pub async fn get_valid_openai_codex_access_token(
127        &self,
128        profile_override: Option<&str>,
129    ) -> Result<Option<String>> {
130        let profile = self
131            .get_profile(OPENAI_CODEX_PROVIDER, profile_override)
132            .await?;
133        let Some(profile) = profile else {
134            return Ok(None);
135        };
136        let Some(token_set) = profile.token_set.clone() else {
137            anyhow::bail!("OpenAI Codex auth profile is missing OAuth token set");
138        };
139
140        if !token_set.is_expiring_within(Duration::from_secs(OPENAI_REFRESH_SKEW_SECS)) {
141            return Ok(Some(token_set.access_token));
142        }
143
144        let Some(refresh_token) = token_set.refresh_token.clone() else {
145            return Ok(Some(token_set.access_token));
146        };
147        let refreshed = self
148            .refresh_openai_codex_tokens_with_refresh_token(&refresh_token)
149            .await?;
150        let access_token = refreshed.access_token.clone();
151        let account_id = extract_account_id_from_jwt(&access_token);
152        let profile_id = profile.id.clone();
153        self.store
154            .update_profile(&profile_id, |profile| {
155                profile.token_set = Some(refreshed.clone());
156                profile.account_id = account_id.clone();
157                Ok(())
158            })
159            .await?;
160        Ok(Some(access_token))
161    }
162
163    pub async fn refresh_openai_codex_tokens(
164        &self,
165        profile_override: Option<&str>,
166    ) -> Result<ProviderAuthProfile> {
167        self.refresh_oauth_profile(
168            OPENAI_CODEX_PROVIDER,
169            profile_override,
170            &OpenAiCodexTokenManager {
171                client: self.client.clone(),
172            },
173        )
174        .await
175    }
176
177    pub async fn refresh_oauth_profile(
178        &self,
179        provider: &str,
180        profile_override: Option<&str>,
181        token_manager: &dyn OAuthTokenManager,
182    ) -> Result<ProviderAuthProfile> {
183        let profile = self
184            .get_profile(provider, profile_override)
185            .await?
186            .ok_or_else(|| {
187                anyhow::anyhow!("OAuth auth profile not found for provider '{provider}'")
188            })?;
189        let refresh_token = profile
190            .token_set
191            .as_ref()
192            .and_then(|tokens| tokens.refresh_token.clone())
193            .ok_or_else(|| {
194                anyhow::anyhow!(
195                    "OAuth auth profile for provider '{provider}' does not contain a refresh token"
196                )
197            })?;
198        let refreshed = token_manager.refresh_oauth_state(&refresh_token).await?;
199        let account_id = refreshed
200            .account_id
201            .clone()
202            .or_else(|| token_manager.extract_account_id(&refreshed.token_set.access_token));
203        let metadata = refreshed.metadata.clone();
204        let token_set = refreshed.token_set.clone();
205        self.store
206            .update_profile(&profile.id, |existing| {
207                existing.token_set = Some(token_set.clone());
208                existing.account_id = account_id.clone();
209                existing.metadata = metadata.clone();
210                Ok(())
211            })
212            .await
213    }
214
215    pub async fn set_active_profile(&self, provider: &str, profile_name: &str) -> Result<String> {
216        let requested_id = if profile_name.contains(':') {
217            profile_name.to_string()
218        } else {
219            profile_id(provider, profile_name)
220        };
221        self.store
222            .set_active_profile(provider, &requested_id)
223            .await?;
224        Ok(requested_id)
225    }
226
227    pub async fn remove_profile(&self, provider: &str, profile_name: &str) -> Result<bool> {
228        let requested_id = if profile_name.contains(':') {
229            profile_name.to_string()
230        } else {
231            profile_id(provider, profile_name)
232        };
233        self.store.remove_profile(&requested_id).await
234    }
235
236    async fn refresh_openai_codex_tokens_with_refresh_token(
237        &self,
238        refresh_token: &str,
239    ) -> Result<ProviderTokenSet> {
240        OpenAiCodexTokenManager {
241            client: self.client.clone(),
242        }
243        .refresh_oauth_state(refresh_token)
244        .await
245        .map(|state| state.token_set)
246    }
247}
248
249struct OpenAiCodexTokenManager {
250    client: reqwest::Client,
251}
252
253#[async_trait::async_trait]
254impl OAuthTokenManager for OpenAiCodexTokenManager {
255    async fn refresh_oauth_state(&self, refresh_token: &str) -> Result<OAuthProfileState> {
256        let response = self
257            .client
258            .post(OPENAI_OAUTH_TOKEN_URL)
259            .form(&[
260                ("grant_type", "refresh_token"),
261                ("refresh_token", refresh_token),
262                ("client_id", OPENAI_OAUTH_CLIENT_ID),
263            ])
264            .send()
265            .await
266            .context("Failed to refresh OpenAI Codex OAuth token")?;
267        if !response.status().is_success() {
268            let status = response.status();
269            let body = response.text().await.unwrap_or_default();
270            anyhow::bail!("OpenAI Codex token refresh failed ({status}): {body}");
271        }
272        let parsed: OpenAiTokenResponse = response
273            .json()
274            .await
275            .context("Failed to parse OpenAI Codex token refresh response")?;
276        let expires_at = parsed
277            .expires_in
278            .map(|seconds| Utc::now() + chrono::Duration::seconds(seconds));
279        Ok(ProviderTokenSet {
280            access_token: parsed.access_token,
281            refresh_token: parsed
282                .refresh_token
283                .or_else(|| Some(refresh_token.to_string())),
284            id_token: parsed.id_token,
285            expires_at,
286            token_type: parsed.token_type,
287            scope: parsed.scope,
288        })
289        .map(|token_set| OAuthProfileState {
290            account_id: extract_account_id_from_jwt(&token_set.access_token),
291            token_set,
292            metadata: Default::default(),
293        })
294    }
295
296    fn extract_account_id(&self, access_token: &str) -> Option<String> {
297        extract_account_id_from_jwt(access_token)
298    }
299}
300
301fn select_profile_id(
302    data: &ProviderAuthProfilesData,
303    provider: &str,
304    profile_override: Option<&str>,
305) -> Option<String> {
306    profile_override
307        .map(|value| {
308            if value.contains(':') {
309                value.to_string()
310            } else {
311                profile_id(provider, value)
312            }
313        })
314        .or_else(|| data.active_profiles.get(provider).cloned())
315        .or_else(|| {
316            let fallback = profile_id(provider, DEFAULT_PROFILE_NAME);
317            data.profiles.contains_key(&fallback).then_some(fallback)
318        })
319}
320
321pub fn extract_account_id_from_jwt(token: &str) -> Option<String> {
322    let payload = token.split('.').nth(1)?;
323    let decoded = base64::engine::general_purpose::URL_SAFE_NO_PAD
324        .decode(payload)
325        .ok()
326        .or_else(|| {
327            base64::engine::general_purpose::URL_SAFE
328                .decode(payload)
329                .ok()
330        })?;
331    let json: serde_json::Value = serde_json::from_slice(&decoded).ok()?;
332    for key in ["https://api.openai.com/auth", "org_id", "account_id", "sub"] {
333        if let Some(value) = json.get(key).and_then(|value| value.as_str()) {
334            return Some(value.to_string());
335        }
336    }
337    None
338}
339
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::BTreeMap;
    use tempfile::tempdir;

    /// Token manager double that refreshes without network access.
    ///
    /// It echoes the refresh token back and attaches a fixed account id plus an `api_base`
    /// metadata entry.
    struct FakeTokenManager;

    #[async_trait::async_trait]
    impl OAuthTokenManager for FakeTokenManager {
        async fn refresh_oauth_state(&self, refresh_token: &str) -> Result<OAuthProfileState> {
            let token_set = ProviderTokenSet {
                access_token: "refreshed-access".into(),
                refresh_token: Some(refresh_token.to_string()),
                id_token: None,
                expires_at: None,
                token_type: Some("Bearer".into()),
                scope: Some("openid".into()),
            };
            let mut metadata = BTreeMap::new();
            metadata.insert(
                "api_base".to_string(),
                "https://oauth.example/v1".to_string(),
            );
            Ok(OAuthProfileState {
                token_set,
                account_id: Some("acct-1".into()),
                metadata,
            })
        }

        fn extract_account_id(&self, access_token: &str) -> Option<String> {
            Some(access_token.to_string())
        }
    }

    /// Token set with fixed `access`/`refresh` values; `token_type` and `scope` vary per test.
    fn sample_tokens(token_type: Option<&str>, scope: Option<&str>) -> ProviderTokenSet {
        ProviderTokenSet {
            access_token: "access".into(),
            refresh_token: Some("refresh".into()),
            id_token: None,
            expires_at: None,
            token_type: token_type.map(String::from),
            scope: scope.map(String::from),
        }
    }

    #[tokio::test]
    async fn store_and_get_bearer_token() {
        let tmp = tempdir().unwrap();
        let auth = ProviderAuthService::new(tmp.path());
        auth.store_openai_codex_tokens("default", sample_tokens(Some("Bearer"), Some("openid")))
            .await
            .unwrap();

        let bearer = auth
            .get_provider_bearer_token("openai-codex", None)
            .await
            .unwrap();
        assert_eq!(bearer.as_deref(), Some("access"));
    }

    #[tokio::test]
    async fn set_active_profile_uses_profile_name() {
        let tmp = tempdir().unwrap();
        let auth = ProviderAuthService::new(tmp.path());
        auth.store_openai_codex_tokens("work", sample_tokens(None, None))
            .await
            .unwrap();

        let selected = auth
            .set_active_profile("openai-codex", "work")
            .await
            .unwrap();
        assert_eq!(selected, "openai-codex:work");
    }

    #[tokio::test]
    async fn refresh_oauth_profile_updates_metadata_for_generic_provider() {
        let tmp = tempdir().unwrap();
        let auth = ProviderAuthService::new(tmp.path());
        let initial = OAuthProfileState {
            token_set: sample_tokens(Some("Bearer"), Some("openid")),
            account_id: None,
            metadata: BTreeMap::new(),
        };
        auth.store_oauth_profile("qwen-login", "default", initial, true)
            .await
            .unwrap();

        let updated = auth
            .refresh_oauth_profile("qwen-login", None, &FakeTokenManager)
            .await
            .unwrap();

        let tokens = updated.token_set.as_ref().unwrap();
        assert_eq!(tokens.access_token, "refreshed-access");
        assert_eq!(updated.account_id.as_deref(), Some("acct-1"));
        assert_eq!(
            updated.metadata.get("api_base").map(String::as_str),
            Some("https://oauth.example/v1")
        );
    }
}