Skip to main content

oxi/
auth_storage.rs

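//! Authentication storage for API keys and OAuth tokens.
//!
//! Provides storage and retrieval of authentication credentials. Credentials
//! are persisted as plain (unencrypted) JSON in a file that is restricted to
//! owner-only permissions on Unix; optional OS keyring helpers are available
//! behind the `keyring` feature.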
5
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::path::PathBuf;
9use std::sync::RwLock;
10
11/// Authentication credential
12#[derive(Debug, Clone, Serialize, Deserialize)]
13#[serde(tag = "type", rename_all = "snake_case")]
14pub enum AuthCredential {
15    /// API key credential
16    ApiKey { key: String },
17    /// OAuth credential
18    OAuth {
19        access_token: String,
20        refresh_token: Option<String>,
21        expires_at: u64,
22    },
23}
24
/// Summary of where (if anywhere) a provider's authentication comes from.
#[derive(Debug, Clone)]
pub struct AuthStatus {
    /// Whether auth is configured; `AuthStorage::get_status` defines exactly
    /// which sources set this to `true`.
    pub configured: bool,
    /// Origin of the auth, such as "stored", "runtime" or "environment".
    pub source: Option<String>,
    /// Optional human-readable label (for example an environment variable name).
    pub label: Option<String>,
}
35
36/// Result of an auth operation
37type AuthResult<T> = Result<T, AuthError>;
38
39/// Authentication errors
40#[derive(Debug, thiserror::Error)]
41pub enum AuthError {
42    #[error("Failed to read auth storage: {0}")]
43    ReadError(String),
44    #[error("Failed to write auth storage: {0}")]
45    WriteError(String),
46    #[error("Credential not found: {0}")]
47    NotFound(String),
48    #[error("Invalid credential format: {0}")]
49    InvalidFormat(String),
50    #[error("Keyring error: {0}")]
51    KeyringError(String),
52}
53
54/// Storage backend trait
55pub trait AuthStorageBackend: Send + Sync {
56    /// Read a value
57    fn read(&self) -> AuthResult<Option<String>>;
58    /// Write a value
59    fn write(&self, data: &str) -> AuthResult<()>;
60    /// Delete a value
61    fn delete(&self) -> AuthResult<()>;
62}
63
64/// File-based auth storage backend
65pub struct FileAuthStorage {
66    path: PathBuf,
67    cache: RwLock<Option<String>>,
68}
69
70impl FileAuthStorage {
71    /// Create a new file-based auth storage
72    pub fn new(path: PathBuf) -> Self {
73        Self {
74            path,
75            cache: RwLock::new(None),
76        }
77    }
78
79    /// Get the default auth file path
80    pub fn default_path() -> Option<PathBuf> {
81        dirs::config_dir().map(|p| p.join("oxi").join("auth.json"))
82    }
83}
84
85impl AuthStorageBackend for FileAuthStorage {
86    fn read(&self) -> AuthResult<Option<String>> {
87        if !self.path.exists() {
88            return Ok(None);
89        }
90
91        match std::fs::read_to_string(&self.path) {
92            Ok(content) => {
93                *self.cache.write().unwrap() = Some(content.clone());
94                Ok(Some(content))
95            }
96            Err(e) => Err(AuthError::ReadError(e.to_string())),
97        }
98    }
99
100    fn write(&self, data: &str) -> AuthResult<()> {
101        // Ensure parent directory exists
102        if let Some(parent) = self.path.parent() {
103            std::fs::create_dir_all(parent)
104                .map_err(|e| AuthError::WriteError(e.to_string()))?;
105        }
106
107        // Set file permissions to owner-only on Unix
108        #[cfg(unix)]
109        {
110            use std::os::unix::fs::PermissionsExt;
111            let perms = std::fs::Permissions::from_mode(0o600);
112            std::fs::set_permissions(&self.path, perms)
113                .map_err(|e| AuthError::WriteError(e.to_string()))?;
114        }
115
116        std::fs::write(&self.path, data).map_err(|e| AuthError::WriteError(e.to_string()))?;
117        *self.cache.write().unwrap() = Some(data.to_string());
118        Ok(())
119    }
120
121    fn delete(&self) -> AuthResult<()> {
122        if self.path.exists() {
123            std::fs::remove_file(&self.path)
124                .map_err(|e| AuthError::WriteError(e.to_string()))?;
125        }
126        *self.cache.write().unwrap() = None;
127        Ok(())
128    }
129}
130
/// Backend that reads an API key from a `<PROVIDER>_API_KEY` environment variable.
pub struct EnvAuthStorage {
    /// Full variable name, e.g. `MY_PROVIDER_API_KEY`.
    provider_prefix: String,
}

impl EnvAuthStorage {
    /// Create a backend for `provider`: the name is upper-cased, `-` becomes
    /// `_`, and `_API_KEY` is appended.
    pub fn new(provider: &str) -> Self {
        let normalized = provider.to_uppercase().replace('-', "_");
        let provider_prefix = format!("{}_API_KEY", normalized);
        EnvAuthStorage { provider_prefix }
    }
}
147
148impl AuthStorageBackend for EnvAuthStorage {
149    fn read(&self) -> AuthResult<Option<String>> {
150        Ok(std::env::var(&self.provider_prefix).ok())
151    }
152
153    fn write(&self, _data: &str) -> AuthResult<()> {
154        Err(AuthError::WriteError(
155            "Cannot write to environment variables".to_string(),
156        ))
157    }
158
159    fn delete(&self) -> AuthResult<()> {
160        std::env::remove_var(&self.provider_prefix);
161        Ok(())
162    }
163}
164
165/// Memory-based auth storage (for testing)
166pub struct MemoryAuthStorage {
167    data: RwLock<HashMap<String, AuthCredential>>,
168}
169
170impl MemoryAuthStorage {
171    pub fn new() -> Self {
172        Self {
173            data: RwLock::new(HashMap::new()),
174        }
175    }
176}
177
178impl Default for MemoryAuthStorage {
179    fn default() -> Self {
180        Self::new()
181    }
182}
183
184impl AuthStorageBackend for MemoryAuthStorage {
185    fn read(&self) -> AuthResult<Option<String>> {
186        // Not applicable for memory storage
187        Ok(None)
188    }
189
190    fn write(&self, _data: &str) -> AuthResult<()> {
191        // Not applicable for memory storage
192        Ok(())
193    }
194
195    fn delete(&self) -> AuthResult<()> {
196        self.data.write().unwrap().clear();
197        Ok(())
198    }
199}
200
201/// Main auth storage struct
202pub struct AuthStorage {
203    /// File-based storage
204    file_storage: Option<FileAuthStorage>,
205    /// In-memory cache
206    credentials: RwLock<HashMap<String, AuthCredential>>,
207    /// Runtime overrides (CLI --api-key)
208    runtime_overrides: RwLock<HashMap<String, String>>,
209}
210
211impl AuthStorage {
212    /// Create a new auth storage with default file backend
213    pub fn new() -> Self {
214        let file_storage = Self::default_path().map(FileAuthStorage::new);
215
216        let credentials = if let Some(ref storage) = file_storage {
217            if let Ok(Some(content)) = storage.read() {
218                serde_json::from_str(&content).unwrap_or_default()
219            } else {
220                HashMap::new()
221            }
222        } else {
223            HashMap::new()
224        };
225
226        Self {
227            file_storage,
228            credentials: RwLock::new(credentials),
229            runtime_overrides: RwLock::new(HashMap::new()),
230        }
231    }
232
233    /// Create with explicit storage backend
234    pub fn with_backend<B: AuthStorageBackend + 'static>(backend: B) -> Self {
235        let credentials = if let Ok(Some(content)) = backend.read() {
236            serde_json::from_str(&content).unwrap_or_default()
237        } else {
238            HashMap::new()
239        };
240
241        Self {
242            file_storage: None,
243            credentials: RwLock::new(credentials),
244            runtime_overrides: RwLock::new(HashMap::new()),
245        }
246    }
247
248    /// Create a memory-only storage
249    pub fn in_memory() -> Self {
250        Self {
251            file_storage: None,
252            credentials: RwLock::new(HashMap::new()),
253            runtime_overrides: RwLock::new(HashMap::new()),
254        }
255    }
256
257    /// Get the default auth file path
258    fn default_path() -> Option<PathBuf> {
259        dirs::config_dir().map(|p| p.join("oxi").join("auth.json"))
260    }
261
262    /// Set a runtime API key override (from CLI --api-key)
263    pub fn set_runtime_key(&self, provider: &str, api_key: String) {
264        self.runtime_overrides
265            .write()
266            .unwrap()
267            .insert(provider.to_string(), api_key);
268    }
269
270    /// Remove a runtime override
271    pub fn remove_runtime_key(&self, provider: &str) {
272        self.runtime_overrides.write().unwrap().remove(provider);
273    }
274
275    /// Check if a provider has any auth configured
276    pub fn has_auth(&self, provider: &str) -> bool {
277        // Check runtime override
278        if self.runtime_overrides.read().unwrap().contains_key(provider) {
279            return true;
280        }
281
282        // Check stored credentials
283        if self.credentials.read().unwrap().contains_key(provider) {
284            return true;
285        }
286
287        // Check environment
288        let env_key = format!(
289            "{}_API_KEY",
290            provider.to_uppercase().replace('-', "_")
291        );
292        std::env::var(&env_key).is_ok()
293    }
294
295    /// Get auth status for a provider
296    pub fn get_status(&self, provider: &str) -> AuthStatus {
297        if self.runtime_overrides.read().unwrap().contains_key(provider) {
298            return AuthStatus {
299                configured: false,
300                source: Some("runtime".to_string()),
301                label: Some("--api-key".to_string()),
302            };
303        }
304
305        if self.credentials.read().unwrap().contains_key(provider) {
306            return AuthStatus {
307                configured: true,
308                source: Some("stored".to_string()),
309                label: None,
310            };
311        }
312
313        let env_key = format!(
314            "{}_API_KEY",
315            provider.to_uppercase().replace('-', "_")
316        );
317        if std::env::var(&env_key).is_ok() {
318            return AuthStatus {
319                configured: false,
320                source: Some("environment".to_string()),
321                label: Some(env_key),
322            };
323        }
324
325        AuthStatus {
326            configured: false,
327            source: None,
328            label: None,
329        }
330    }
331
332    /// Get API key for a provider
333    ///
334    /// Priority:
335    /// 1. Runtime override (CLI --api-key)
336    /// 2. Stored API key
337    /// 3. OAuth token (auto-refreshed)
338    /// 4. Environment variable
339    pub fn get_api_key(&self, provider: &str) -> Option<String> {
340        // 1. Runtime override
341        if let Some(key) = self.runtime_overrides.read().unwrap().get(provider) {
342            return Some(key.clone());
343        }
344
345        // 2. Stored credential
346        if let Some(cred) = self.credentials.read().unwrap().get(provider) {
347            return match cred {
348                AuthCredential::ApiKey { key } => Some(key.clone()),
349                AuthCredential::OAuth { access_token, expires_at, .. } => {
350                    // Check if token needs refresh
351                    if *expires_at > std::time::SystemTime::now()
352                        .duration_since(std::time::UNIX_EPOCH)
353                        .unwrap()
354                        .as_secs()
355                    {
356                        Some(access_token.clone())
357                    } else {
358                        // Token expired
359                        None
360                    }
361                }
362            };
363        }
364
365        // 3. Environment variable
366        let env_key = format!(
367            "{}_API_KEY",
368            provider.to_uppercase().replace('-', "_")
369        );
370        std::env::var(&env_key).ok()
371    }
372
373    /// Set API key for a provider
374    pub fn set_api_key(&self, provider: &str, key: String) {
375        self.credentials
376            .write()
377            .unwrap()
378            .insert(provider.to_string(), AuthCredential::ApiKey { key });
379        self.persist();
380    }
381
382    /// Set OAuth credential for a provider
383    pub fn set_oauth(
384        &self,
385        provider: &str,
386        access_token: String,
387        refresh_token: Option<String>,
388        expires_at: u64,
389    ) {
390        self.credentials.write().unwrap().insert(
391            provider.to_string(),
392            AuthCredential::OAuth {
393                access_token,
394                refresh_token,
395                expires_at,
396            },
397        );
398        self.persist();
399    }
400
401    /// Remove credential for a provider
402    pub fn remove(&self, provider: &str) {
403        self.credentials.write().unwrap().remove(provider);
404        self.persist();
405    }
406
407    /// List all providers with credentials
408    pub fn list_providers(&self) -> Vec<String> {
409        self.credentials.read().unwrap().keys().cloned().collect()
410    }
411
412    /// Check if credential exists for provider
413    pub fn has(&self, provider: &str) -> bool {
414        self.credentials.read().unwrap().contains_key(provider)
415    }
416
417    /// Get all credentials (for debugging)
418    pub fn get_all(&self) -> HashMap<String, AuthCredential> {
419        self.credentials.read().unwrap().clone()
420    }
421
422    /// Clear all stored credentials
423    pub fn clear(&self) {
424        self.credentials.write().unwrap().clear();
425        self.persist();
426    }
427
428    /// Reload from disk
429    pub fn reload(&self) {
430        if let Some(ref storage) = self.file_storage {
431            if let Ok(Some(content)) = storage.read() {
432                if let Ok(creds) = serde_json::from_str(&content) {
433                    *self.credentials.write().unwrap() = creds;
434                }
435            }
436        }
437    }
438
439    /// Persist to disk
440    fn persist(&self) {
441        if let Some(ref storage) = self.file_storage {
442            let creds = self.credentials.read().unwrap();
443            if let Ok(json) = serde_json::to_string_pretty(&*creds) {
444                let _ = storage.write(&json);
445            }
446        }
447    }
448}
449
450impl Default for AuthStorage {
451    fn default() -> Self {
452        Self::new()
453    }
454}
455
456/// Wrapper for using keyring crate with fallback
457pub mod keyring_support {
458    use super::*;
459
460    /// Try to get a secret from the OS keyring
461    #[cfg(feature = "keyring")]
462    pub fn get_keyring_secret(service: &str, account: &str) -> Option<String> {
463        use keyring::Entry;
464        Entry::new(service, account)
465            .ok()
466            .and_then(|entry| entry.get_password().ok())
467    }
468
469    /// Try to set a secret in the OS keyring
470    #[cfg(feature = "keyring")]
471    pub fn set_keyring_secret(service: &str, account: &str, secret: &str) -> Result<(), AuthError> {
472        use keyring::Entry;
473        Entry::new(service, account)
474            .map_err(|e| AuthError::KeyringError(e.to_string()))?
475            .set_password(secret)
476            .map_err(|e| AuthError::KeyringError(e.to_string()))
477    }
478
479    /// Try to delete a secret from the OS keyring
480    #[cfg(feature = "keyring")]
481    pub fn delete_keyring_secret(service: &str, account: &str) -> Result<(), AuthError> {
482        use keyring::Entry;
483        Entry::new(service, account)
484            .map_err(|e| AuthError::KeyringError(e.to_string()))?
485            .delete_credential()
486            .map_err(|e| AuthError::KeyringError(e.to_string()))
487    }
488
489    // Non-keyring fallbacks
490    #[cfg(not(feature = "keyring"))]
491    pub fn get_keyring_secret(_service: &str, _account: &str) -> Option<String> {
492        None
493    }
494
495    #[cfg(not(feature = "keyring"))]
496    pub fn set_keyring_secret(_service: &str, _account: &str, _secret: &str) -> Result<(), AuthError> {
497        Err(AuthError::KeyringError("Keyring support not compiled".to_string()))
498    }
499
500    #[cfg(not(feature = "keyring"))]
501    pub fn delete_keyring_secret(_service: &str, _account: &str) -> Result<(), AuthError> {
502        Err(AuthError::KeyringError("Keyring support not compiled".to_string()))
503    }
504}
505
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a memory-only storage holding the given `(provider, api_key)` pairs.
    fn storage_with_keys(pairs: &[(&str, &str)]) -> AuthStorage {
        let storage = AuthStorage::in_memory();
        for &(provider, key) in pairs {
            storage.set_api_key(provider, key.to_string());
        }
        storage
    }

    #[test]
    fn test_auth_storage_new() {
        let storage = AuthStorage::in_memory();
        assert!(!storage.has("anthropic"));
    }

    #[test]
    fn test_set_and_get_api_key() {
        let storage = storage_with_keys(&[("anthropic", "sk-test123")]);
        assert!(storage.has("anthropic"));
        assert_eq!(storage.get_api_key("anthropic").as_deref(), Some("sk-test123"));
    }

    #[test]
    fn test_runtime_override() {
        let storage = storage_with_keys(&[("anthropic", "stored-key")]);
        storage.set_runtime_key("anthropic", "runtime-key".to_string());

        // A runtime key beats the stored one.
        assert_eq!(storage.get_api_key("anthropic").as_deref(), Some("runtime-key"));
    }

    #[test]
    fn test_remove_credential() {
        let storage = storage_with_keys(&[("anthropic", "sk-test123")]);
        assert!(storage.has("anthropic"));

        storage.remove("anthropic");
        assert!(!storage.has("anthropic"));
    }

    #[test]
    fn test_auth_status() {
        let status = storage_with_keys(&[("anthropic", "sk-test123")]).get_status("anthropic");
        assert!(status.configured);
        assert_eq!(status.source.as_deref(), Some("stored"));
    }

    #[test]
    fn test_list_providers() {
        let providers =
            storage_with_keys(&[("anthropic", "key1"), ("openai", "key2")]).list_providers();
        for expected in &["anthropic", "openai"] {
            assert!(
                providers.iter().any(|p| p.as_str() == *expected),
                "missing provider {}",
                expected
            );
        }
    }

    #[test]
    fn test_oauth_credential() {
        let storage = AuthStorage::in_memory();
        storage.set_oauth(
            "provider",
            "access123".to_string(),
            Some("refresh456".to_string()),
            u64::MAX,
        );

        assert!(storage.has("provider"));
        assert_eq!(storage.get_api_key("provider").as_deref(), Some("access123"));
    }

    #[test]
    fn test_expired_oauth_token() {
        let storage = AuthStorage::in_memory();
        // An expiry of 0 lies in the past, so the token must be rejected.
        storage.set_oauth("provider", "access123".to_string(), None, 0);

        assert_eq!(storage.get_api_key("provider"), None);
    }

    #[test]
    fn test_get_all_credentials() {
        let all = storage_with_keys(&[("anthropic", "key1"), ("openai", "key2")]).get_all();
        assert_eq!(all.len(), 2);
    }

    #[test]
    fn test_clear() {
        let storage = storage_with_keys(&[("anthropic", "key")]);
        assert!(storage.has("anthropic"));

        storage.clear();
        assert!(!storage.has("anthropic"));
    }

    #[test]
    fn test_remove_runtime_key() {
        let storage = storage_with_keys(&[("anthropic", "stored")]);
        storage.set_runtime_key("anthropic", "runtime".to_string());
        assert_eq!(storage.get_api_key("anthropic").as_deref(), Some("runtime"));

        storage.remove_runtime_key("anthropic");
        assert_eq!(storage.get_api_key("anthropic").as_deref(), Some("stored"));
    }
}