1use crate::models::auth::ProviderAuth;
24use crate::oauth::error::{OAuthError, OAuthResult};
25use serde::{Deserialize, Serialize};
26use std::collections::HashMap;
27use std::path::{Path, PathBuf};
28
/// File name of the credential store inside the config directory.
const AUTH_FILE_NAME: &str = "auth.toml";

/// Reserved profile whose credentials are used as a fallback for every other
/// profile by `AuthManager::get` and `AuthManager::list_for_profile`.
const ALL_PROFILE: &str = "all";
34
/// In-memory representation of the whole `auth.toml` document.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct AuthFile {
    /// Profile name -> (provider name -> credentials).
    ///
    /// `flatten` makes the profile names the top-level keys of the TOML
    /// document instead of nesting them under a `profiles` key.
    #[serde(flatten)]
    pub profiles: HashMap<String, HashMap<String, ProviderAuth>>,
}
42
/// Loads, queries, and persists the credentials stored in `auth.toml`.
///
/// The whole file is held in memory; calls that change it write the full
/// document back to disk.
// NOTE(review): the derived `Debug` prints `auth_file`, i.e. the stored
// credentials. Whether secrets are redacted depends on `ProviderAuth`'s
// `Debug` impl — confirm before logging this type.
#[derive(Debug, Clone)]
pub struct AuthManager {
    /// Path of the credential file (`<config_dir>/auth.toml`).
    auth_path: PathBuf,
    /// In-memory copy of the file contents.
    auth_file: AuthFile,
}
51
52impl AuthManager {
53 pub fn new(config_dir: &Path) -> OAuthResult<Self> {
55 let auth_path = config_dir.join(AUTH_FILE_NAME);
56 let auth_file = if auth_path.exists() {
57 let content = std::fs::read_to_string(&auth_path)?;
58 toml::from_str(&content)?
59 } else {
60 AuthFile::default()
61 };
62
63 Ok(Self {
64 auth_path,
65 auth_file,
66 })
67 }
68
69 pub fn from_default_dir() -> OAuthResult<Self> {
71 let config_dir = get_default_config_dir()?;
72 Self::new(&config_dir)
73 }
74
75 pub fn get(&self, profile: &str, provider: &str) -> Option<&ProviderAuth> {
81 if let Some(providers) = self.auth_file.profiles.get(profile)
83 && let Some(auth) = providers.get(provider)
84 {
85 return Some(auth);
86 }
87
88 if profile != ALL_PROFILE
90 && let Some(providers) = self.auth_file.profiles.get(ALL_PROFILE)
91 && let Some(auth) = providers.get(provider)
92 {
93 return Some(auth);
94 }
95
96 None
97 }
98
99 pub fn set(&mut self, profile: &str, provider: &str, auth: ProviderAuth) -> OAuthResult<()> {
101 self.auth_file
102 .profiles
103 .entry(profile.to_string())
104 .or_default()
105 .insert(provider.to_string(), auth);
106
107 self.save()
108 }
109
110 pub fn remove(&mut self, profile: &str, provider: &str) -> OAuthResult<bool> {
112 let removed = if let Some(providers) = self.auth_file.profiles.get_mut(profile) {
113 let removed = providers.remove(provider).is_some();
114 if providers.is_empty() {
116 self.auth_file.profiles.remove(profile);
117 }
118 removed
119 } else {
120 false
121 };
122
123 if removed {
124 self.save()?;
125 }
126
127 Ok(removed)
128 }
129
130 pub fn list(&self) -> &HashMap<String, HashMap<String, ProviderAuth>> {
132 &self.auth_file.profiles
133 }
134
135 pub fn list_for_profile(&self, profile: &str) -> HashMap<String, &ProviderAuth> {
137 let mut result = HashMap::new();
138
139 if let Some(all_providers) = self.auth_file.profiles.get(ALL_PROFILE) {
141 for (provider, auth) in all_providers {
142 result.insert(provider.clone(), auth);
143 }
144 }
145
146 if profile != ALL_PROFILE
148 && let Some(profile_providers) = self.auth_file.profiles.get(profile)
149 {
150 for (provider, auth) in profile_providers {
151 result.insert(provider.clone(), auth);
152 }
153 }
154
155 result
156 }
157
158 pub fn has_credentials(&self) -> bool {
160 self.auth_file
161 .profiles
162 .values()
163 .any(|providers| !providers.is_empty())
164 }
165
166 pub fn auth_path(&self) -> &Path {
168 &self.auth_path
169 }
170
171 pub fn update_oauth_tokens(
173 &mut self,
174 profile: &str,
175 provider: &str,
176 access: &str,
177 refresh: &str,
178 expires: i64,
179 ) -> OAuthResult<()> {
180 let auth = ProviderAuth::oauth(access, refresh, expires);
181 self.set(profile, provider, auth)
182 }
183
184 fn save(&self) -> OAuthResult<()> {
186 if let Some(parent) = self.auth_path.parent() {
188 std::fs::create_dir_all(parent)?;
189 }
190
191 let content = toml::to_string_pretty(&self.auth_file)?;
192
193 let temp_path = self.auth_path.with_extension("toml.tmp");
195 std::fs::write(&temp_path, &content)?;
196
197 #[cfg(unix)]
199 {
200 use std::os::unix::fs::PermissionsExt;
201 let permissions = std::fs::Permissions::from_mode(0o600);
202 std::fs::set_permissions(&temp_path, permissions)?;
203 }
204
205 std::fs::rename(&temp_path, &self.auth_path)?;
207
208 Ok(())
209 }
210}
211
212pub fn get_default_config_dir() -> OAuthResult<PathBuf> {
214 let home = dirs::home_dir().ok_or_else(|| {
215 OAuthError::IoError(std::io::Error::new(
216 std::io::ErrorKind::NotFound,
217 "Could not determine home directory",
218 ))
219 })?;
220
221 Ok(home.join(".stakpak"))
222}
223
224pub fn get_auth_file_path(config_dir: &Path) -> PathBuf {
226 config_dir.join(AUTH_FILE_NAME)
227}
228
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Builds a manager rooted in a fresh temp dir; keep the `TempDir` alive
    /// for the duration of the test or the directory is deleted.
    fn create_test_auth_manager() -> (AuthManager, TempDir) {
        let temp_dir = TempDir::new().unwrap();
        let manager = AuthManager::new(temp_dir.path()).unwrap();
        (manager, temp_dir)
    }

    #[test]
    fn test_new_empty() {
        let (manager, _temp) = create_test_auth_manager();
        assert!(!manager.has_credentials());
        assert!(manager.list().is_empty());
    }

    #[test]
    fn test_new_rejects_invalid_toml() {
        let temp_dir = TempDir::new().unwrap();
        std::fs::write(temp_dir.path().join(AUTH_FILE_NAME), "not = [valid").unwrap();

        assert!(AuthManager::new(temp_dir.path()).is_err());
    }

    #[test]
    fn test_new_propagates_read_errors() {
        let temp_dir = TempDir::new().unwrap();
        // A directory where the file should be exists but cannot be read as a
        // file; this must be an error, not an empty credential set.
        std::fs::create_dir(temp_dir.path().join(AUTH_FILE_NAME)).unwrap();

        assert!(AuthManager::new(temp_dir.path()).is_err());
    }

    #[test]
    fn test_set_and_get() {
        let (mut manager, _temp) = create_test_auth_manager();

        let auth = ProviderAuth::api_key("sk-test-key");
        manager.set("default", "anthropic", auth.clone()).unwrap();

        let retrieved = manager.get("default", "anthropic");
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap(), &auth);
    }

    #[test]
    fn test_profile_inheritance() {
        let (mut manager, _temp) = create_test_auth_manager();

        // Credentials in `all` are visible from every profile.
        let all_auth = ProviderAuth::api_key("sk-all-key");
        manager.set("all", "anthropic", all_auth.clone()).unwrap();

        assert_eq!(manager.get("default", "anthropic"), Some(&all_auth));
        assert_eq!(manager.get("work", "anthropic"), Some(&all_auth));
        assert_eq!(manager.get("all", "anthropic"), Some(&all_auth));
    }

    #[test]
    fn test_profile_override() {
        let (mut manager, _temp) = create_test_auth_manager();

        let all_auth = ProviderAuth::api_key("sk-all-key");
        manager.set("all", "anthropic", all_auth.clone()).unwrap();

        let work_auth = ProviderAuth::api_key("sk-work-key");
        manager.set("work", "anthropic", work_auth.clone()).unwrap();

        // A profile-specific entry shadows the `all` entry...
        assert_eq!(manager.get("work", "anthropic"), Some(&work_auth));
        // ...but only for that profile.
        assert_eq!(manager.get("default", "anthropic"), Some(&all_auth));
    }

    #[test]
    fn test_remove() {
        let (mut manager, _temp) = create_test_auth_manager();

        let auth = ProviderAuth::api_key("sk-test-key");
        manager.set("default", "anthropic", auth).unwrap();

        assert!(manager.get("default", "anthropic").is_some());

        let removed = manager.remove("default", "anthropic").unwrap();
        assert!(removed);

        assert!(manager.get("default", "anthropic").is_none());
    }

    #[test]
    fn test_remove_nonexistent() {
        let (mut manager, _temp) = create_test_auth_manager();

        let removed = manager.remove("default", "anthropic").unwrap();
        assert!(!removed);
    }

    #[test]
    fn test_remove_prunes_empty_profile() {
        let (mut manager, _temp) = create_test_auth_manager();

        manager
            .set("work", "anthropic", ProviderAuth::api_key("sk-work-key"))
            .unwrap();
        assert!(manager.remove("work", "anthropic").unwrap());

        assert!(!manager.list().contains_key("work"));
        assert!(!manager.has_credentials());
    }

    #[test]
    fn test_list_for_profile() {
        let (mut manager, _temp) = create_test_auth_manager();

        let all_anthropic = ProviderAuth::api_key("sk-all-anthropic");
        let all_openai = ProviderAuth::api_key("sk-all-openai");
        let work_anthropic = ProviderAuth::api_key("sk-work-anthropic");

        manager
            .set("all", "anthropic", all_anthropic.clone())
            .unwrap();
        manager.set("all", "openai", all_openai.clone()).unwrap();
        manager
            .set("work", "anthropic", work_anthropic.clone())
            .unwrap();

        let work_creds = manager.list_for_profile("work");
        assert_eq!(work_creds.len(), 2);
        assert_eq!(work_creds.get("anthropic"), Some(&&work_anthropic));
        assert_eq!(work_creds.get("openai"), Some(&&all_openai));

        let default_creds = manager.list_for_profile("default");
        assert_eq!(default_creds.len(), 2);
        assert_eq!(default_creds.get("anthropic"), Some(&&all_anthropic));
        assert_eq!(default_creds.get("openai"), Some(&&all_openai));
    }

    #[test]
    fn test_list_for_all_profile() {
        let (mut manager, _temp) = create_test_auth_manager();

        let all_auth = ProviderAuth::api_key("sk-all-key");
        manager.set("all", "anthropic", all_auth.clone()).unwrap();
        manager
            .set("work", "openai", ProviderAuth::api_key("sk-work-openai"))
            .unwrap();

        // Other profiles' entries must not leak into the `all` view.
        let all_creds = manager.list_for_profile("all");
        assert_eq!(all_creds.len(), 1);
        assert_eq!(all_creds.get("anthropic"), Some(&&all_auth));
    }

    #[test]
    fn test_persistence() {
        let temp_dir = TempDir::new().unwrap();

        {
            // First manager writes the file...
            let mut manager = AuthManager::new(temp_dir.path()).unwrap();
            let auth = ProviderAuth::api_key("sk-test-key");
            manager.set("default", "anthropic", auth).unwrap();
        }

        {
            // ...and a fresh manager reads it back.
            let manager = AuthManager::new(temp_dir.path()).unwrap();
            let retrieved = manager.get("default", "anthropic");
            assert!(retrieved.is_some());
            assert_eq!(retrieved.unwrap().api_key_value(), Some("sk-test-key"));
        }
    }

    #[test]
    fn test_save_leaves_no_temp_file() {
        let (mut manager, temp) = create_test_auth_manager();

        manager
            .set("default", "anthropic", ProviderAuth::api_key("sk-test-key"))
            .unwrap();

        assert!(manager.auth_path().exists());
        assert!(!temp.path().join("auth.toml.tmp").exists());
    }

    #[test]
    fn test_oauth_tokens() {
        let (mut manager, _temp) = create_test_auth_manager();

        let expires = chrono::Utc::now().timestamp_millis() + 3600000;
        let auth = ProviderAuth::oauth("access-token", "refresh-token", expires);
        manager.set("default", "anthropic", auth).unwrap();

        let retrieved = manager.get("default", "anthropic").unwrap();
        assert!(retrieved.is_oauth());
        assert_eq!(retrieved.access_token(), Some("access-token"));
        assert_eq!(retrieved.refresh_token(), Some("refresh-token"));
    }

    #[test]
    fn test_update_oauth_tokens() {
        let (mut manager, _temp) = create_test_auth_manager();

        // Start from an already-expired token pair.
        manager
            .set(
                "default",
                "anthropic",
                ProviderAuth::oauth("old-access", "old-refresh", 0),
            )
            .unwrap();

        let new_expires = chrono::Utc::now().timestamp_millis() + 3600000;
        manager
            .update_oauth_tokens(
                "default",
                "anthropic",
                "new-access",
                "new-refresh",
                new_expires,
            )
            .unwrap();

        let retrieved = manager.get("default", "anthropic").unwrap();
        assert_eq!(retrieved.access_token(), Some("new-access"));
        assert_eq!(retrieved.refresh_token(), Some("new-refresh"));
    }

    #[cfg(unix)]
    #[test]
    fn test_file_permissions() {
        use std::os::unix::fs::PermissionsExt;

        let temp_dir = TempDir::new().unwrap();
        let mut manager = AuthManager::new(temp_dir.path()).unwrap();

        let auth = ProviderAuth::api_key("sk-test-key");
        manager.set("default", "anthropic", auth).unwrap();

        let metadata = std::fs::metadata(manager.auth_path()).unwrap();
        let mode = metadata.permissions().mode();

        // Owner read/write only.
        assert_eq!(mode & 0o777, 0o600);
    }

    #[cfg(unix)]
    #[test]
    fn test_file_permissions_with_stale_temp_file() {
        use std::os::unix::fs::PermissionsExt;

        let (mut manager, temp) = create_test_auth_manager();

        // Simulate a temp file left over by an interrupted save, with a
        // permissive mode that must not carry over to the credential file.
        let stale = temp.path().join("auth.toml.tmp");
        std::fs::write(&stale, "stale").unwrap();
        std::fs::set_permissions(&stale, std::fs::Permissions::from_mode(0o644)).unwrap();

        manager
            .set("default", "anthropic", ProviderAuth::api_key("sk-test-key"))
            .unwrap();

        let mode = std::fs::metadata(manager.auth_path())
            .unwrap()
            .permissions()
            .mode();
        assert_eq!(mode & 0o777, 0o600);
        assert!(!stale.exists());
    }
}