// agent_diva_core/auth/store.rs

1use crate::auth::profiles::{ProviderAuthProfile, ProviderAuthProfilesData};
2use anyhow::{Context, Result};
3use chrono::Utc;
4use std::path::{Path, PathBuf};
5use std::time::Duration;
6use tokio::fs;
7use tokio::time::{sleep, Instant};
8
/// Directory, relative to the config directory, that holds the auth store files.
const AUTH_DIR: &str = "data/auth";
/// File name of the JSON document holding every provider auth profile.
const AUTH_FILENAME: &str = "profiles.json";
/// File name of the sentinel whose existence marks the store as locked.
const LOCK_FILENAME: &str = "profiles.lock";
/// Pause between attempts to create the lock file while someone else holds it.
const LOCK_RETRY_MS: u64 = 50;
/// How long to keep retrying before giving up on the lock (ten seconds).
const LOCK_TIMEOUT_MS: u64 = 10 * 1_000;

/// File-backed store of provider auth profiles.
///
/// Every read-modify-write cycle runs while holding a lock file that sits next
/// to the JSON document, so concurrent writers do not overwrite each other.
#[derive(Debug, Clone)]
pub struct ProviderAuthStore {
    /// Location of the JSON document holding all profiles.
    path: PathBuf,
    /// Location of the lock file guarding `path`.
    lock_path: PathBuf,
}

21impl ProviderAuthStore {
22    pub fn new(config_dir: &Path) -> Self {
23        let auth_dir = config_dir.join(AUTH_DIR);
24        Self {
25            path: auth_dir.join(AUTH_FILENAME),
26            lock_path: auth_dir.join(LOCK_FILENAME),
27        }
28    }
29
30    pub fn path(&self) -> &Path {
31        &self.path
32    }
33
34    pub async fn load(&self) -> Result<ProviderAuthProfilesData> {
35        let _lock = self.acquire_lock().await?;
36        self.load_locked().await
37    }
38
39    pub async fn upsert_profile(
40        &self,
41        mut profile: ProviderAuthProfile,
42        set_active: bool,
43    ) -> Result<()> {
44        let _lock = self.acquire_lock().await?;
45        let mut data = self.load_locked().await?;
46
47        profile.updated_at = Utc::now();
48        if let Some(existing) = data.profiles.get(&profile.id) {
49            profile.created_at = existing.created_at;
50        }
51
52        if set_active {
53            data.active_profiles
54                .insert(profile.provider.clone(), profile.id.clone());
55        }
56        data.profiles.insert(profile.id.clone(), profile);
57        data.updated_at = Utc::now();
58        self.save_locked(&data).await
59    }
60
61    pub async fn remove_profile(&self, profile_id: &str) -> Result<bool> {
62        let _lock = self.acquire_lock().await?;
63        let mut data = self.load_locked().await?;
64        let removed = data.profiles.remove(profile_id).is_some();
65        if removed {
66            data.active_profiles.retain(|_, id| id != profile_id);
67            data.updated_at = Utc::now();
68            self.save_locked(&data).await?;
69        }
70        Ok(removed)
71    }
72
73    pub async fn set_active_profile(&self, provider: &str, profile_id: &str) -> Result<()> {
74        let _lock = self.acquire_lock().await?;
75        let mut data = self.load_locked().await?;
76        if !data.profiles.contains_key(profile_id) {
77            anyhow::bail!("Auth profile not found: {profile_id}");
78        }
79        data.active_profiles
80            .insert(provider.to_string(), profile_id.to_string());
81        data.updated_at = Utc::now();
82        self.save_locked(&data).await
83    }
84
85    pub async fn clear_active_profile(&self, provider: &str) -> Result<()> {
86        let _lock = self.acquire_lock().await?;
87        let mut data = self.load_locked().await?;
88        data.active_profiles.remove(provider);
89        data.updated_at = Utc::now();
90        self.save_locked(&data).await
91    }
92
93    pub async fn update_profile<F>(
94        &self,
95        profile_id: &str,
96        mut updater: F,
97    ) -> Result<ProviderAuthProfile>
98    where
99        F: FnMut(&mut ProviderAuthProfile) -> Result<()>,
100    {
101        let _lock = self.acquire_lock().await?;
102        let mut data = self.load_locked().await?;
103        let profile = data
104            .profiles
105            .get_mut(profile_id)
106            .ok_or_else(|| anyhow::anyhow!("Auth profile not found: {profile_id}"))?;
107        updater(profile)?;
108        profile.updated_at = Utc::now();
109        let updated = profile.clone();
110        data.updated_at = Utc::now();
111        self.save_locked(&data).await?;
112        Ok(updated)
113    }
114
115    async fn load_locked(&self) -> Result<ProviderAuthProfilesData> {
116        if !self.path.exists() {
117            return Ok(ProviderAuthProfilesData::default());
118        }
119        let raw = fs::read_to_string(&self.path)
120            .await
121            .with_context(|| format!("Failed to read auth store {}", self.path.display()))?;
122        let data = serde_json::from_str(&raw)
123            .with_context(|| format!("Failed to parse auth store {}", self.path.display()))?;
124        Ok(data)
125    }
126
127    async fn save_locked(&self, data: &ProviderAuthProfilesData) -> Result<()> {
128        if let Some(parent) = self.path.parent() {
129            fs::create_dir_all(parent).await.with_context(|| {
130                format!("Failed to create auth store directory {}", parent.display())
131            })?;
132        }
133        let temp_path = self.path.with_extension("json.tmp");
134        let payload = serde_json::to_vec_pretty(data)?;
135        fs::write(&temp_path, payload)
136            .await
137            .with_context(|| format!("Failed to write auth temp file {}", temp_path.display()))?;
138        fs::rename(&temp_path, &self.path).await.with_context(|| {
139            format!(
140                "Failed to move auth temp file {} into {}",
141                temp_path.display(),
142                self.path.display()
143            )
144        })?;
145        Ok(())
146    }
147
148    async fn acquire_lock(&self) -> Result<LockGuard> {
149        if let Some(parent) = self.lock_path.parent() {
150            fs::create_dir_all(parent).await.with_context(|| {
151                format!("Failed to create auth lock directory {}", parent.display())
152            })?;
153        }
154
155        let deadline = Instant::now() + Duration::from_millis(LOCK_TIMEOUT_MS);
156        loop {
157            match fs::OpenOptions::new()
158                .write(true)
159                .create_new(true)
160                .open(&self.lock_path)
161                .await
162            {
163                Ok(_) => return Ok(LockGuard(self.lock_path.clone())),
164                Err(err) if err.kind() == std::io::ErrorKind::AlreadyExists => {
165                    if Instant::now() >= deadline {
166                        anyhow::bail!("Timed out acquiring auth store lock");
167                    }
168                    sleep(Duration::from_millis(LOCK_RETRY_MS)).await;
169                }
170                Err(err) => {
171                    return Err(err).with_context(|| {
172                        format!("Failed to open auth lock {}", self.lock_path.display())
173                    })
174                }
175            }
176        }
177    }
178}
179
/// Holds the auth store lock for as long as it is alive.
///
/// The lock is the existence of the file at the wrapped path. Dropping the
/// guard deletes that file, which releases the lock.
#[must_use = "dropping the guard immediately releases the auth store lock"]
struct LockGuard(PathBuf);

impl Drop for LockGuard {
    fn drop(&mut self) {
        // Best effort: `drop` cannot report errors, and a lock file that is
        // already gone means the lock is released anyway. Blocking `std::fs`
        // is used because `drop` cannot await.
        let _ = std::fs::remove_file(&self.0);
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::profiles::{
        profile_id, ProviderAuthKind, ProviderAuthProfile, ProviderTokenSet,
    };
    use tempfile::tempdir;

    /// OAuth profile for provider `openai-codex`, name `default`
    /// (stored under the id `openai-codex:default`).
    fn oauth_profile() -> ProviderAuthProfile {
        ProviderAuthProfile::new_oauth(
            "openai-codex",
            "default",
            ProviderTokenSet {
                access_token: "access".into(),
                refresh_token: Some("refresh".into()),
                id_token: None,
                expires_at: None,
                token_type: Some("Bearer".into()),
                scope: Some("openid".into()),
            },
        )
    }

    #[tokio::test]
    async fn upsert_load_remove_profile_roundtrip() {
        let dir = tempdir().unwrap();
        let store = ProviderAuthStore::new(dir.path());
        store.upsert_profile(oauth_profile(), true).await.unwrap();

        let loaded = store.load().await.unwrap();
        assert_eq!(
            loaded.active_profiles.get("openai-codex").unwrap(),
            "openai-codex:default"
        );
        assert!(loaded.profiles.contains_key("openai-codex:default"));

        let removed = store.remove_profile("openai-codex:default").await.unwrap();
        assert!(removed);
        let loaded = store.load().await.unwrap();
        assert!(!loaded.profiles.contains_key("openai-codex:default"));
    }

    #[tokio::test]
    async fn upsert_preserves_created_at() {
        let dir = tempdir().unwrap();
        let store = ProviderAuthStore::new(dir.path());
        store.upsert_profile(oauth_profile(), true).await.unwrap();
        let first = store
            .load()
            .await
            .unwrap()
            .profiles
            .get("openai-codex:default")
            .unwrap()
            .created_at;

        store.upsert_profile(oauth_profile(), false).await.unwrap();
        let second = store
            .load()
            .await
            .unwrap()
            .profiles
            .get("openai-codex:default")
            .unwrap()
            .created_at;
        assert_eq!(first, second);
    }

    #[tokio::test]
    async fn load_missing_store_returns_default_without_creating_file() {
        let dir = tempdir().unwrap();
        let store = ProviderAuthStore::new(dir.path());
        let loaded = store.load().await.unwrap();
        assert!(loaded.profiles.is_empty());
        assert!(loaded.active_profiles.is_empty());
        assert!(!store.path().exists());
    }

    #[tokio::test]
    async fn remove_missing_profile_returns_false() {
        let dir = tempdir().unwrap();
        let store = ProviderAuthStore::new(dir.path());
        assert!(!store.remove_profile("openai-codex:missing").await.unwrap());
    }

    #[tokio::test]
    async fn set_and_clear_active_profile() {
        let dir = tempdir().unwrap();
        let store = ProviderAuthStore::new(dir.path());
        store.upsert_profile(oauth_profile(), false).await.unwrap();
        store
            .set_active_profile("openai-codex", &profile_id("openai-codex", "default"))
            .await
            .unwrap();
        assert_eq!(
            store
                .load()
                .await
                .unwrap()
                .active_profiles
                .get("openai-codex")
                .cloned(),
            Some("openai-codex:default".into())
        );
        store.clear_active_profile("openai-codex").await.unwrap();
        assert!(!store
            .load()
            .await
            .unwrap()
            .active_profiles
            .contains_key("openai-codex"));
    }

    #[tokio::test]
    async fn set_active_profile_rejects_unknown_id() {
        let dir = tempdir().unwrap();
        let store = ProviderAuthStore::new(dir.path());
        let err = store
            .set_active_profile("openai-codex", "openai-codex:missing")
            .await
            .unwrap_err()
            .to_string();
        assert!(err.contains("Auth profile not found"));
    }

    #[tokio::test]
    async fn update_profile_changes_token() {
        let dir = tempdir().unwrap();
        let store = ProviderAuthStore::new(dir.path());
        store.upsert_profile(oauth_profile(), true).await.unwrap();
        let updated = store
            .update_profile("openai-codex:default", |profile| {
                profile.kind = ProviderAuthKind::Token;
                profile.token_set = None;
                profile.token = Some("plain".into());
                Ok(())
            })
            .await
            .unwrap();
        assert_eq!(updated.token.as_deref(), Some("plain"));
    }

    #[tokio::test]
    async fn failed_update_releases_lock_and_writes_nothing() {
        let dir = tempdir().unwrap();
        let store = ProviderAuthStore::new(dir.path());
        store.upsert_profile(oauth_profile(), true).await.unwrap();
        let before = store
            .load()
            .await
            .unwrap()
            .profiles
            .get("openai-codex:default")
            .unwrap()
            .token
            .clone();

        let result = store
            .update_profile("openai-codex:default", |profile| {
                profile.token = Some("discarded".into());
                Err(anyhow::anyhow!("updater failed"))
            })
            .await;
        assert!(result.is_err());

        // The guard must have removed the lock file even on the error path.
        let lock_path = store.path().with_file_name(LOCK_FILENAME);
        assert!(!lock_path.exists());

        let after = store
            .load()
            .await
            .unwrap()
            .profiles
            .get("openai-codex:default")
            .unwrap()
            .token
            .clone();
        assert_eq!(before, after);
    }

    #[tokio::test]
    async fn damaged_file_returns_error() {
        let dir = tempdir().unwrap();
        let store = ProviderAuthStore::new(dir.path());
        std::fs::create_dir_all(store.path().parent().unwrap()).unwrap();
        std::fs::write(store.path(), "{not-json").unwrap();
        let err = store.load().await.unwrap_err().to_string();
        assert!(err.contains("Failed to parse auth store"));
    }
}