// stakpak_shared/auth_manager.rs

//! Authentication manager for storing and retrieving provider credentials.
//!
//! This module manages provider credentials stored in `auth.toml` in the config directory.
//! Credentials are organized first by profile and then by provider. The special `all`
//! profile is shared: when a profile has no entry for a given provider, the
//! `[all.{provider}]` entry is used instead. On Unix the file is written with mode `0600`.
//!
//! # File Structure
//!
//! ```toml
//! # Shared across all profiles
//! [all.anthropic]
//! type = "oauth"
//! access = "eyJ..."
//! refresh = "eyJ..."
//! expires = 1735600000000
//!
//! # Profile-specific override
//! [work.anthropic]
//! type = "api"
//! key = "sk-ant-..."
//! ```

23use crate::models::auth::ProviderAuth;
24use crate::oauth::error::{OAuthError, OAuthResult};
25use serde::{Deserialize, Serialize};
26use std::collections::HashMap;
27use std::path::{Path, PathBuf};
28
/// File name, relative to the config directory, under which credentials are persisted.
const AUTH_FILE_NAME: &str = "auth.toml";

/// Reserved profile whose per-provider credentials every other profile falls back to.
const ALL_PROFILE: &str = "all";

35/// Structure of the auth.toml file
36#[derive(Debug, Clone, Serialize, Deserialize, Default)]
37pub struct AuthFile {
38    /// Profile-scoped credentials: profile_name -> provider_name -> auth
39    #[serde(flatten)]
40    pub profiles: HashMap<String, HashMap<String, ProviderAuth>>,
41}
42
43/// Manages provider credentials stored in auth.toml
44#[derive(Debug, Clone)]
45pub struct AuthManager {
46    /// Path to the auth.toml file
47    auth_path: PathBuf,
48    /// Loaded auth file contents
49    auth_file: AuthFile,
50}
51
52impl AuthManager {
53    /// Load auth manager for the given config directory
54    pub fn new(config_dir: &Path) -> OAuthResult<Self> {
55        let auth_path = config_dir.join(AUTH_FILE_NAME);
56        let auth_file = if auth_path.exists() {
57            let content = std::fs::read_to_string(&auth_path)?;
58            toml::from_str(&content)?
59        } else {
60            AuthFile::default()
61        };
62
63        Ok(Self {
64            auth_path,
65            auth_file,
66        })
67    }
68
69    /// Load auth manager from the default Stakpak config directory (~/.stakpak/)
70    pub fn from_default_dir() -> OAuthResult<Self> {
71        let config_dir = get_default_config_dir()?;
72        Self::new(&config_dir)
73    }
74
75    /// Get credentials for a provider, respecting profile inheritance
76    ///
77    /// Resolution order:
78    /// 1. `[{profile}.{provider}]` - profile-specific
79    /// 2. `[all.{provider}]` - shared fallback
80    pub fn get(&self, profile: &str, provider: &str) -> Option<&ProviderAuth> {
81        // First, check profile-specific credentials
82        if let Some(providers) = self.auth_file.profiles.get(profile)
83            && let Some(auth) = providers.get(provider)
84        {
85            return Some(auth);
86        }
87
88        // Fall back to "all" profile
89        if profile != ALL_PROFILE
90            && let Some(providers) = self.auth_file.profiles.get(ALL_PROFILE)
91            && let Some(auth) = providers.get(provider)
92        {
93            return Some(auth);
94        }
95
96        None
97    }
98
99    /// Set credentials for a provider in a specific profile
100    pub fn set(&mut self, profile: &str, provider: &str, auth: ProviderAuth) -> OAuthResult<()> {
101        self.auth_file
102            .profiles
103            .entry(profile.to_string())
104            .or_default()
105            .insert(provider.to_string(), auth);
106
107        self.save()
108    }
109
110    /// Remove credentials for a provider from a specific profile
111    pub fn remove(&mut self, profile: &str, provider: &str) -> OAuthResult<bool> {
112        let removed = if let Some(providers) = self.auth_file.profiles.get_mut(profile) {
113            let removed = providers.remove(provider).is_some();
114            // Clean up empty profile entries
115            if providers.is_empty() {
116                self.auth_file.profiles.remove(profile);
117            }
118            removed
119        } else {
120            false
121        };
122
123        if removed {
124            self.save()?;
125        }
126
127        Ok(removed)
128    }
129
130    /// List all credentials
131    pub fn list(&self) -> &HashMap<String, HashMap<String, ProviderAuth>> {
132        &self.auth_file.profiles
133    }
134
135    /// Get all credentials for a specific profile (including inherited from "all")
136    pub fn list_for_profile(&self, profile: &str) -> HashMap<String, &ProviderAuth> {
137        let mut result = HashMap::new();
138
139        // Start with "all" profile credentials
140        if let Some(all_providers) = self.auth_file.profiles.get(ALL_PROFILE) {
141            for (provider, auth) in all_providers {
142                result.insert(provider.clone(), auth);
143            }
144        }
145
146        // Override with profile-specific credentials
147        if profile != ALL_PROFILE
148            && let Some(profile_providers) = self.auth_file.profiles.get(profile)
149        {
150            for (provider, auth) in profile_providers {
151                result.insert(provider.clone(), auth);
152            }
153        }
154
155        result
156    }
157
158    /// Check if any credentials are configured
159    pub fn has_credentials(&self) -> bool {
160        self.auth_file
161            .profiles
162            .values()
163            .any(|providers| !providers.is_empty())
164    }
165
166    /// Get the path to the auth file
167    pub fn auth_path(&self) -> &Path {
168        &self.auth_path
169    }
170
171    /// Update OAuth tokens for a provider (used during token refresh)
172    pub fn update_oauth_tokens(
173        &mut self,
174        profile: &str,
175        provider: &str,
176        access: &str,
177        refresh: &str,
178        expires: i64,
179    ) -> OAuthResult<()> {
180        let auth = ProviderAuth::oauth(access, refresh, expires);
181        self.set(profile, provider, auth)
182    }
183
184    /// Save changes to disk
185    fn save(&self) -> OAuthResult<()> {
186        // Ensure parent directory exists
187        if let Some(parent) = self.auth_path.parent() {
188            std::fs::create_dir_all(parent)?;
189        }
190
191        let content = toml::to_string_pretty(&self.auth_file)?;
192
193        // Write to a temp file first, then rename for atomicity
194        let temp_path = self.auth_path.with_extension("toml.tmp");
195        std::fs::write(&temp_path, &content)?;
196
197        // Set file permissions to 0600 (owner read/write only) on Unix
198        #[cfg(unix)]
199        {
200            use std::os::unix::fs::PermissionsExt;
201            let permissions = std::fs::Permissions::from_mode(0o600);
202            std::fs::set_permissions(&temp_path, permissions)?;
203        }
204
205        // Atomic rename
206        std::fs::rename(&temp_path, &self.auth_path)?;
207
208        Ok(())
209    }
210}
211
212/// Get the default Stakpak config directory
213pub fn get_default_config_dir() -> OAuthResult<PathBuf> {
214    let home = dirs::home_dir().ok_or_else(|| {
215        OAuthError::IoError(std::io::Error::new(
216            std::io::ErrorKind::NotFound,
217            "Could not determine home directory",
218        ))
219    })?;
220
221    Ok(home.join(".stakpak"))
222}
223
224/// Get the auth file path for a given config directory
225pub fn get_auth_file_path(config_dir: &Path) -> PathBuf {
226    config_dir.join(AUTH_FILE_NAME)
227}
228
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Build a manager rooted in a fresh temporary directory.
    ///
    /// The `TempDir` is returned so the directory lives as long as the test needs it.
    fn create_test_auth_manager() -> (AuthManager, TempDir) {
        let temp_dir = TempDir::new().unwrap();
        let manager = AuthManager::new(temp_dir.path()).unwrap();
        (manager, temp_dir)
    }

    #[test]
    fn test_new_empty() {
        let (manager, _temp) = create_test_auth_manager();
        assert!(!manager.has_credentials());
        assert!(manager.list().is_empty());
    }

    #[test]
    fn test_get_missing_returns_none() {
        let (manager, _temp) = create_test_auth_manager();
        assert!(manager.get("default", "anthropic").is_none());
        assert!(manager.get(ALL_PROFILE, "anthropic").is_none());
    }

    #[test]
    fn test_set_and_get() {
        let (mut manager, _temp) = create_test_auth_manager();

        let auth = ProviderAuth::api_key("sk-test-key");
        manager.set("default", "anthropic", auth.clone()).unwrap();

        let retrieved = manager.get("default", "anthropic");
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap(), &auth);
    }

    #[test]
    fn test_profile_inheritance() {
        let (mut manager, _temp) = create_test_auth_manager();

        // Set in "all" profile
        let all_auth = ProviderAuth::api_key("sk-all-key");
        manager.set("all", "anthropic", all_auth.clone()).unwrap();

        // Should be accessible from any profile
        assert_eq!(manager.get("default", "anthropic"), Some(&all_auth));
        assert_eq!(manager.get("work", "anthropic"), Some(&all_auth));
        assert_eq!(manager.get("all", "anthropic"), Some(&all_auth));
    }

    #[test]
    fn test_profile_override() {
        let (mut manager, _temp) = create_test_auth_manager();

        // Set in "all" profile
        let all_auth = ProviderAuth::api_key("sk-all-key");
        manager.set("all", "anthropic", all_auth.clone()).unwrap();

        // Override in "work" profile
        let work_auth = ProviderAuth::api_key("sk-work-key");
        manager.set("work", "anthropic", work_auth.clone()).unwrap();

        // "work" should get its own key
        assert_eq!(manager.get("work", "anthropic"), Some(&work_auth));

        // "default" should still get the "all" key
        assert_eq!(manager.get("default", "anthropic"), Some(&all_auth));
    }

    #[test]
    fn test_remove() {
        let (mut manager, _temp) = create_test_auth_manager();

        let auth = ProviderAuth::api_key("sk-test-key");
        manager.set("default", "anthropic", auth).unwrap();

        assert!(manager.get("default", "anthropic").is_some());

        let removed = manager.remove("default", "anthropic").unwrap();
        assert!(removed);

        assert!(manager.get("default", "anthropic").is_none());
    }

    #[test]
    fn test_remove_nonexistent() {
        let (mut manager, _temp) = create_test_auth_manager();

        let removed = manager.remove("default", "anthropic").unwrap();
        assert!(!removed);
    }

    #[test]
    fn test_remove_cleans_up_empty_profile() {
        let (mut manager, _temp) = create_test_auth_manager();
        manager
            .set("work", "anthropic", ProviderAuth::api_key("sk-a"))
            .unwrap();
        manager
            .set("work", "openai", ProviderAuth::api_key("sk-o"))
            .unwrap();

        // Profile survives while it still has a provider...
        assert!(manager.remove("work", "anthropic").unwrap());
        assert!(manager.list().contains_key("work"));

        // ...and is dropped once the last provider is gone.
        assert!(manager.remove("work", "openai").unwrap());
        assert!(!manager.list().contains_key("work"));
        assert!(!manager.has_credentials());
    }

    #[test]
    fn test_remove_does_not_touch_all_fallback() {
        let (mut manager, _temp) = create_test_auth_manager();
        let all_auth = ProviderAuth::api_key("sk-all-key");
        manager.set(ALL_PROFILE, "anthropic", all_auth.clone()).unwrap();

        // "default" only inherits the key, so there is nothing to remove from it.
        assert!(!manager.remove("default", "anthropic").unwrap());
        assert_eq!(manager.get("default", "anthropic"), Some(&all_auth));
    }

    #[test]
    fn test_list_for_profile() {
        let (mut manager, _temp) = create_test_auth_manager();

        let all_anthropic = ProviderAuth::api_key("sk-all-anthropic");
        let all_openai = ProviderAuth::api_key("sk-all-openai");
        let work_anthropic = ProviderAuth::api_key("sk-work-anthropic");

        manager
            .set("all", "anthropic", all_anthropic.clone())
            .unwrap();
        manager.set("all", "openai", all_openai.clone()).unwrap();
        manager
            .set("work", "anthropic", work_anthropic.clone())
            .unwrap();

        let work_creds = manager.list_for_profile("work");
        assert_eq!(work_creds.len(), 2);
        assert_eq!(work_creds.get("anthropic"), Some(&&work_anthropic));
        assert_eq!(work_creds.get("openai"), Some(&&all_openai));

        let default_creds = manager.list_for_profile("default");
        assert_eq!(default_creds.len(), 2);
        assert_eq!(default_creds.get("anthropic"), Some(&&all_anthropic));
        assert_eq!(default_creds.get("openai"), Some(&&all_openai));
    }

    #[test]
    fn test_list_for_all_profile() {
        let (mut manager, _temp) = create_test_auth_manager();
        let all_auth = ProviderAuth::api_key("sk-all-key");
        manager.set(ALL_PROFILE, "anthropic", all_auth.clone()).unwrap();
        manager
            .set("work", "openai", ProviderAuth::api_key("sk-work-openai"))
            .unwrap();

        // The "all" view must not pick up other profiles' entries.
        let all_creds = manager.list_for_profile(ALL_PROFILE);
        assert_eq!(all_creds.len(), 1);
        assert_eq!(all_creds.get("anthropic"), Some(&&all_auth));
    }

    #[test]
    fn test_persistence() {
        let temp_dir = TempDir::new().unwrap();

        // Create and save credentials
        {
            let mut manager = AuthManager::new(temp_dir.path()).unwrap();
            let auth = ProviderAuth::api_key("sk-test-key");
            manager.set("default", "anthropic", auth).unwrap();
        }

        // Load and verify
        {
            let manager = AuthManager::new(temp_dir.path()).unwrap();
            let retrieved = manager.get("default", "anthropic");
            assert!(retrieved.is_some());
            assert_eq!(retrieved.unwrap().api_key_value(), Some("sk-test-key"));
        }
    }

    #[test]
    fn test_save_creates_missing_config_dir() {
        let temp_dir = TempDir::new().unwrap();
        let config_dir = temp_dir.path().join("nested").join("config");

        let mut manager = AuthManager::new(&config_dir).unwrap();
        assert!(!manager.has_credentials());
        assert_eq!(
            manager.auth_path(),
            get_auth_file_path(&config_dir).as_path()
        );

        manager
            .set("default", "anthropic", ProviderAuth::api_key("sk-test-key"))
            .unwrap();
        assert!(get_auth_file_path(&config_dir).is_file());
    }

    #[test]
    fn test_new_rejects_invalid_toml() {
        let temp_dir = TempDir::new().unwrap();
        std::fs::write(
            get_auth_file_path(temp_dir.path()),
            "this is = [not valid toml",
        )
        .unwrap();

        assert!(AuthManager::new(temp_dir.path()).is_err());
    }

    #[test]
    fn test_save_leaves_no_temp_file() {
        let (mut manager, temp) = create_test_auth_manager();
        manager
            .set("default", "anthropic", ProviderAuth::api_key("sk-test-key"))
            .unwrap();

        assert!(manager.auth_path().is_file());
        assert!(!temp.path().join("auth.toml.tmp").exists());
    }

    #[test]
    fn test_oauth_tokens() {
        let (mut manager, _temp) = create_test_auth_manager();

        let expires = chrono::Utc::now().timestamp_millis() + 3600000;
        let auth = ProviderAuth::oauth("access-token", "refresh-token", expires);
        manager.set("default", "anthropic", auth).unwrap();

        let retrieved = manager.get("default", "anthropic").unwrap();
        assert!(retrieved.is_oauth());
        assert_eq!(retrieved.access_token(), Some("access-token"));
        assert_eq!(retrieved.refresh_token(), Some("refresh-token"));
    }

    #[test]
    fn test_update_oauth_tokens() {
        let (mut manager, _temp) = create_test_auth_manager();

        // Initial set
        manager
            .set(
                "default",
                "anthropic",
                ProviderAuth::oauth("old-access", "old-refresh", 0),
            )
            .unwrap();

        // Update tokens
        let new_expires = chrono::Utc::now().timestamp_millis() + 3600000;
        manager
            .update_oauth_tokens(
                "default",
                "anthropic",
                "new-access",
                "new-refresh",
                new_expires,
            )
            .unwrap();

        let retrieved = manager.get("default", "anthropic").unwrap();
        assert_eq!(retrieved.access_token(), Some("new-access"));
        assert_eq!(retrieved.refresh_token(), Some("new-refresh"));
    }

    #[cfg(unix)]
    #[test]
    fn test_file_permissions() {
        use std::os::unix::fs::PermissionsExt;

        let temp_dir = TempDir::new().unwrap();
        let mut manager = AuthManager::new(temp_dir.path()).unwrap();

        let auth = ProviderAuth::api_key("sk-test-key");
        manager.set("default", "anthropic", auth).unwrap();

        let metadata = std::fs::metadata(manager.auth_path()).unwrap();
        let mode = metadata.permissions().mode();

        // Check that file is readable/writable only by owner (0600)
        assert_eq!(mode & 0o777, 0o600);
    }

    #[cfg(unix)]
    #[test]
    fn test_file_permissions_with_stale_temp_file() {
        use std::os::unix::fs::PermissionsExt;

        // A leftover, world-readable temp file from an earlier crash must not leak
        // its looser mode into the final credentials file.
        let temp_dir = TempDir::new().unwrap();
        let stale = temp_dir.path().join("auth.toml.tmp");
        std::fs::write(&stale, "stale").unwrap();
        std::fs::set_permissions(&stale, std::fs::Permissions::from_mode(0o644)).unwrap();

        let mut manager = AuthManager::new(temp_dir.path()).unwrap();
        manager
            .set("default", "anthropic", ProviderAuth::api_key("sk-test-key"))
            .unwrap();

        let mode = std::fs::metadata(manager.auth_path())
            .unwrap()
            .permissions()
            .mode();
        assert_eq!(mode & 0o777, 0o600);
    }
}