// agent_diva_core/auth/store.rs
use crate::auth::profiles::{ProviderAuthProfile, ProviderAuthProfilesData};
use anyhow::{Context, Result};
use chrono::Utc;
use std::path::{Path, PathBuf};
use std::time::Duration;
use tokio::fs;
use tokio::io::AsyncWriteExt;
use tokio::time::{sleep, Instant};

/// Directory, relative to the config directory, that holds the auth store files.
const AUTH_DIR: &str = "data/auth";
/// File name of the JSON profile store inside `AUTH_DIR`.
const AUTH_FILENAME: &str = "profiles.json";
/// File name of the lock file guarding the profile store inside `AUTH_DIR`.
const LOCK_FILENAME: &str = "profiles.lock";
/// Pause between attempts to create the lock file, in milliseconds.
const LOCK_RETRY_MS: u64 = 50;
/// How long to keep retrying lock acquisition before giving up, in milliseconds.
const LOCK_TIMEOUT_MS: u64 = 10 * 1_000;

15#[derive(Debug, Clone)]
16pub struct ProviderAuthStore {
17 path: PathBuf,
18 lock_path: PathBuf,
19}
20
21impl ProviderAuthStore {
22 pub fn new(config_dir: &Path) -> Self {
23 let auth_dir = config_dir.join(AUTH_DIR);
24 Self {
25 path: auth_dir.join(AUTH_FILENAME),
26 lock_path: auth_dir.join(LOCK_FILENAME),
27 }
28 }
29
30 pub fn path(&self) -> &Path {
31 &self.path
32 }
33
34 pub async fn load(&self) -> Result<ProviderAuthProfilesData> {
35 let _lock = self.acquire_lock().await?;
36 self.load_locked().await
37 }
38
39 pub async fn upsert_profile(
40 &self,
41 mut profile: ProviderAuthProfile,
42 set_active: bool,
43 ) -> Result<()> {
44 let _lock = self.acquire_lock().await?;
45 let mut data = self.load_locked().await?;
46
47 profile.updated_at = Utc::now();
48 if let Some(existing) = data.profiles.get(&profile.id) {
49 profile.created_at = existing.created_at;
50 }
51
52 if set_active {
53 data.active_profiles
54 .insert(profile.provider.clone(), profile.id.clone());
55 }
56 data.profiles.insert(profile.id.clone(), profile);
57 data.updated_at = Utc::now();
58 self.save_locked(&data).await
59 }
60
61 pub async fn remove_profile(&self, profile_id: &str) -> Result<bool> {
62 let _lock = self.acquire_lock().await?;
63 let mut data = self.load_locked().await?;
64 let removed = data.profiles.remove(profile_id).is_some();
65 if removed {
66 data.active_profiles.retain(|_, id| id != profile_id);
67 data.updated_at = Utc::now();
68 self.save_locked(&data).await?;
69 }
70 Ok(removed)
71 }
72
73 pub async fn set_active_profile(&self, provider: &str, profile_id: &str) -> Result<()> {
74 let _lock = self.acquire_lock().await?;
75 let mut data = self.load_locked().await?;
76 if !data.profiles.contains_key(profile_id) {
77 anyhow::bail!("Auth profile not found: {profile_id}");
78 }
79 data.active_profiles
80 .insert(provider.to_string(), profile_id.to_string());
81 data.updated_at = Utc::now();
82 self.save_locked(&data).await
83 }
84
85 pub async fn clear_active_profile(&self, provider: &str) -> Result<()> {
86 let _lock = self.acquire_lock().await?;
87 let mut data = self.load_locked().await?;
88 data.active_profiles.remove(provider);
89 data.updated_at = Utc::now();
90 self.save_locked(&data).await
91 }
92
93 pub async fn update_profile<F>(
94 &self,
95 profile_id: &str,
96 mut updater: F,
97 ) -> Result<ProviderAuthProfile>
98 where
99 F: FnMut(&mut ProviderAuthProfile) -> Result<()>,
100 {
101 let _lock = self.acquire_lock().await?;
102 let mut data = self.load_locked().await?;
103 let profile = data
104 .profiles
105 .get_mut(profile_id)
106 .ok_or_else(|| anyhow::anyhow!("Auth profile not found: {profile_id}"))?;
107 updater(profile)?;
108 profile.updated_at = Utc::now();
109 let updated = profile.clone();
110 data.updated_at = Utc::now();
111 self.save_locked(&data).await?;
112 Ok(updated)
113 }
114
115 async fn load_locked(&self) -> Result<ProviderAuthProfilesData> {
116 if !self.path.exists() {
117 return Ok(ProviderAuthProfilesData::default());
118 }
119 let raw = fs::read_to_string(&self.path)
120 .await
121 .with_context(|| format!("Failed to read auth store {}", self.path.display()))?;
122 let data = serde_json::from_str(&raw)
123 .with_context(|| format!("Failed to parse auth store {}", self.path.display()))?;
124 Ok(data)
125 }
126
127 async fn save_locked(&self, data: &ProviderAuthProfilesData) -> Result<()> {
128 if let Some(parent) = self.path.parent() {
129 fs::create_dir_all(parent).await.with_context(|| {
130 format!("Failed to create auth store directory {}", parent.display())
131 })?;
132 }
133 let temp_path = self.path.with_extension("json.tmp");
134 let payload = serde_json::to_vec_pretty(data)?;
135 fs::write(&temp_path, payload)
136 .await
137 .with_context(|| format!("Failed to write auth temp file {}", temp_path.display()))?;
138 fs::rename(&temp_path, &self.path).await.with_context(|| {
139 format!(
140 "Failed to move auth temp file {} into {}",
141 temp_path.display(),
142 self.path.display()
143 )
144 })?;
145 Ok(())
146 }
147
148 async fn acquire_lock(&self) -> Result<LockGuard> {
149 if let Some(parent) = self.lock_path.parent() {
150 fs::create_dir_all(parent).await.with_context(|| {
151 format!("Failed to create auth lock directory {}", parent.display())
152 })?;
153 }
154
155 let deadline = Instant::now() + Duration::from_millis(LOCK_TIMEOUT_MS);
156 loop {
157 match fs::OpenOptions::new()
158 .write(true)
159 .create_new(true)
160 .open(&self.lock_path)
161 .await
162 {
163 Ok(_) => return Ok(LockGuard(self.lock_path.clone())),
164 Err(err) if err.kind() == std::io::ErrorKind::AlreadyExists => {
165 if Instant::now() >= deadline {
166 anyhow::bail!("Timed out acquiring auth store lock");
167 }
168 sleep(Duration::from_millis(LOCK_RETRY_MS)).await;
169 }
170 Err(err) => {
171 return Err(err).with_context(|| {
172 format!("Failed to open auth lock {}", self.lock_path.display())
173 })
174 }
175 }
176 }
177 }
178}
179
/// RAII guard for the auth store lock: deleting the lock file on drop releases it.
struct LockGuard(PathBuf);

impl Drop for LockGuard {
    fn drop(&mut self) {
        let Self(lock_path) = self;
        // Best effort: `drop` cannot report errors, so a failed removal (for
        // example, the file is already gone) is ignored.
        let _ = std::fs::remove_file(lock_path.as_path());
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::profiles::{
        profile_id, ProviderAuthKind, ProviderAuthProfile, ProviderTokenSet,
    };
    use tempfile::tempdir;

    /// Builds an OAuth profile for provider `openai-codex` named `default`.
    /// The tests below expect its id to be `openai-codex:default`.
    fn oauth_profile() -> ProviderAuthProfile {
        ProviderAuthProfile::new_oauth(
            "openai-codex",
            "default",
            ProviderTokenSet {
                access_token: "access".into(),
                refresh_token: Some("refresh".into()),
                id_token: None,
                expires_at: None,
                token_type: Some("Bearer".into()),
                scope: Some("openid".into()),
            },
        )
    }

    #[tokio::test]
    async fn upsert_load_remove_profile_roundtrip() {
        let dir = tempdir().unwrap();
        let store = ProviderAuthStore::new(dir.path());
        store.upsert_profile(oauth_profile(), true).await.unwrap();

        let loaded = store.load().await.unwrap();
        assert_eq!(
            loaded.active_profiles.get("openai-codex").unwrap(),
            "openai-codex:default"
        );
        assert!(loaded.profiles.contains_key("openai-codex:default"));

        let removed = store.remove_profile("openai-codex:default").await.unwrap();
        assert!(removed);
        let loaded = store.load().await.unwrap();
        assert!(!loaded.profiles.contains_key("openai-codex:default"));
    }

    #[tokio::test]
    async fn set_and_clear_active_profile() {
        let dir = tempdir().unwrap();
        let store = ProviderAuthStore::new(dir.path());
        store.upsert_profile(oauth_profile(), false).await.unwrap();
        store
            .set_active_profile("openai-codex", &profile_id("openai-codex", "default"))
            .await
            .unwrap();
        assert_eq!(
            store
                .load()
                .await
                .unwrap()
                .active_profiles
                .get("openai-codex")
                .cloned(),
            Some("openai-codex:default".into())
        );
        store.clear_active_profile("openai-codex").await.unwrap();
        assert!(!store
            .load()
            .await
            .unwrap()
            .active_profiles
            .contains_key("openai-codex"));
    }

    #[tokio::test]
    async fn update_profile_changes_token() {
        let dir = tempdir().unwrap();
        let store = ProviderAuthStore::new(dir.path());
        store.upsert_profile(oauth_profile(), true).await.unwrap();
        let updated = store
            .update_profile("openai-codex:default", |profile| {
                profile.kind = ProviderAuthKind::Token;
                profile.token_set = None;
                profile.token = Some("plain".into());
                Ok(())
            })
            .await
            .unwrap();
        assert_eq!(updated.token.as_deref(), Some("plain"));
    }

    #[tokio::test]
    async fn damaged_file_returns_error() {
        let dir = tempdir().unwrap();
        let store = ProviderAuthStore::new(dir.path());
        std::fs::create_dir_all(store.path().parent().unwrap()).unwrap();
        std::fs::write(store.path(), "{not-json").unwrap();
        let err = store.load().await.unwrap_err().to_string();
        assert!(err.contains("Failed to parse auth store"));
    }

    #[tokio::test]
    async fn load_without_store_file_returns_empty_data() {
        let dir = tempdir().unwrap();
        let store = ProviderAuthStore::new(dir.path());
        let loaded = store.load().await.unwrap();
        assert!(loaded.profiles.is_empty());
        assert!(loaded.active_profiles.is_empty());
        // Loading must not create the store file, and the lock must be released.
        assert!(!store.path().exists());
        assert!(!dir.path().join(AUTH_DIR).join(LOCK_FILENAME).exists());
    }

    #[tokio::test]
    async fn remove_unknown_profile_returns_false() {
        let dir = tempdir().unwrap();
        let store = ProviderAuthStore::new(dir.path());
        assert!(!store.remove_profile("openai-codex:missing").await.unwrap());
    }

    #[tokio::test]
    async fn set_active_profile_rejects_unknown_profile() {
        let dir = tempdir().unwrap();
        let store = ProviderAuthStore::new(dir.path());
        let err = store
            .set_active_profile("openai-codex", "openai-codex:missing")
            .await
            .unwrap_err()
            .to_string();
        assert!(err.contains("Auth profile not found"));
    }

    #[tokio::test]
    async fn update_unknown_profile_returns_error() {
        let dir = tempdir().unwrap();
        let store = ProviderAuthStore::new(dir.path());
        let err = store
            .update_profile("openai-codex:missing", |_| Ok(()))
            .await
            .unwrap_err()
            .to_string();
        assert!(err.contains("Auth profile not found"));
    }

    #[tokio::test]
    async fn upsert_keeps_created_at_and_active_selection() {
        let dir = tempdir().unwrap();
        let store = ProviderAuthStore::new(dir.path());
        store.upsert_profile(oauth_profile(), true).await.unwrap();
        let created_at = store.load().await.unwrap().profiles["openai-codex:default"].created_at;

        // Re-upserting without `set_active` keeps the original creation time and
        // the existing active selection.
        store.upsert_profile(oauth_profile(), false).await.unwrap();
        let loaded = store.load().await.unwrap();
        assert_eq!(loaded.profiles["openai-codex:default"].created_at, created_at);
        assert_eq!(
            loaded.active_profiles.get("openai-codex").unwrap(),
            "openai-codex:default"
        );
    }
}