1use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::path::PathBuf;
9use std::sync::RwLock;
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
13#[serde(tag = "type", rename_all = "snake_case")]
14pub enum AuthCredential {
15 ApiKey { key: String },
17 OAuth {
19 access_token: String,
20 refresh_token: Option<String>,
21 expires_at: u64,
22 },
23}
24
/// Where, if anywhere, a provider's credential comes from (as reported by
/// `AuthStorage::get_status`).
#[derive(Debug, Clone)]
pub struct AuthStatus {
    /// `true` only for a credential held in the stored credential map; runtime
    /// overrides and environment variables report `false`.
    pub configured: bool,
    /// Origin of the credential (`"runtime"`, `"stored"` or `"environment"`),
    /// or `None` when no source has one.
    pub source: Option<String>,
    /// Optional display hint, e.g. `--api-key` or the environment variable name.
    pub label: Option<String>,
}
35
36type AuthResult<T> = Result<T, AuthError>;
38
39#[derive(Debug, thiserror::Error)]
41pub enum AuthError {
42 #[error("Failed to read auth storage: {0}")]
43 ReadError(String),
44 #[error("Failed to write auth storage: {0}")]
45 WriteError(String),
46 #[error("Credential not found: {0}")]
47 NotFound(String),
48 #[error("Invalid credential format: {0}")]
49 InvalidFormat(String),
50 #[error("Keyring error: {0}")]
51 KeyringError(String),
52}
53
54pub trait AuthStorageBackend: Send + Sync {
56 fn read(&self) -> AuthResult<Option<String>>;
58 fn write(&self, data: &str) -> AuthResult<()>;
60 fn delete(&self) -> AuthResult<()>;
62}
63
/// Credential backend that keeps its data in a single file on disk.
pub struct FileAuthStorage {
    /// Location of the credential file.
    path: PathBuf,
    /// Contents most recently read from or written to `path`; `None` after a
    /// delete. NOTE(review): nothing in this file reads the cache back.
    cache: RwLock<Option<String>>,
}
69
70impl FileAuthStorage {
71 pub fn new(path: PathBuf) -> Self {
73 Self {
74 path,
75 cache: RwLock::new(None),
76 }
77 }
78
79 pub fn default_path() -> Option<PathBuf> {
81 dirs::config_dir().map(|p| p.join("oxi").join("auth.json"))
82 }
83}
84
85impl AuthStorageBackend for FileAuthStorage {
86 fn read(&self) -> AuthResult<Option<String>> {
87 if !self.path.exists() {
88 return Ok(None);
89 }
90
91 match std::fs::read_to_string(&self.path) {
92 Ok(content) => {
93 *self.cache.write().unwrap() = Some(content.clone());
94 Ok(Some(content))
95 }
96 Err(e) => Err(AuthError::ReadError(e.to_string())),
97 }
98 }
99
100 fn write(&self, data: &str) -> AuthResult<()> {
101 if let Some(parent) = self.path.parent() {
103 std::fs::create_dir_all(parent)
104 .map_err(|e| AuthError::WriteError(e.to_string()))?;
105 }
106
107 #[cfg(unix)]
109 {
110 use std::os::unix::fs::PermissionsExt;
111 let perms = std::fs::Permissions::from_mode(0o600);
112 std::fs::set_permissions(&self.path, perms)
113 .map_err(|e| AuthError::WriteError(e.to_string()))?;
114 }
115
116 std::fs::write(&self.path, data).map_err(|e| AuthError::WriteError(e.to_string()))?;
117 *self.cache.write().unwrap() = Some(data.to_string());
118 Ok(())
119 }
120
121 fn delete(&self) -> AuthResult<()> {
122 if self.path.exists() {
123 std::fs::remove_file(&self.path)
124 .map_err(|e| AuthError::WriteError(e.to_string()))?;
125 }
126 *self.cache.write().unwrap() = None;
127 Ok(())
128 }
129}
130
/// Credential backend that reads an API key from one environment variable.
pub struct EnvAuthStorage {
    /// Complete name of the environment variable consulted, e.g. `OPENAI_API_KEY`.
    /// Despite the field name, this is the whole variable name, not a prefix.
    provider_prefix: String,
}
135
136impl EnvAuthStorage {
137 pub fn new(provider: &str) -> Self {
139 Self {
140 provider_prefix: format!(
141 "{}_API_KEY",
142 provider.to_uppercase().replace('-', "_")
143 ),
144 }
145 }
146}
147
148impl AuthStorageBackend for EnvAuthStorage {
149 fn read(&self) -> AuthResult<Option<String>> {
150 Ok(std::env::var(&self.provider_prefix).ok())
151 }
152
153 fn write(&self, _data: &str) -> AuthResult<()> {
154 Err(AuthError::WriteError(
155 "Cannot write to environment variables".to_string(),
156 ))
157 }
158
159 fn delete(&self) -> AuthResult<()> {
160 std::env::remove_var(&self.provider_prefix);
161 Ok(())
162 }
163}
164
165pub struct MemoryAuthStorage {
167 data: RwLock<HashMap<String, AuthCredential>>,
168}
169
170impl MemoryAuthStorage {
171 pub fn new() -> Self {
172 Self {
173 data: RwLock::new(HashMap::new()),
174 }
175 }
176}
177
178impl Default for MemoryAuthStorage {
179 fn default() -> Self {
180 Self::new()
181 }
182}
183
184impl AuthStorageBackend for MemoryAuthStorage {
185 fn read(&self) -> AuthResult<Option<String>> {
186 Ok(None)
188 }
189
190 fn write(&self, _data: &str) -> AuthResult<()> {
191 Ok(())
193 }
194
195 fn delete(&self) -> AuthResult<()> {
196 self.data.write().unwrap().clear();
197 Ok(())
198 }
199}
200
201pub struct AuthStorage {
203 file_storage: Option<FileAuthStorage>,
205 credentials: RwLock<HashMap<String, AuthCredential>>,
207 runtime_overrides: RwLock<HashMap<String, String>>,
209}
210
211impl AuthStorage {
212 pub fn new() -> Self {
214 let file_storage = Self::default_path().map(FileAuthStorage::new);
215
216 let credentials = if let Some(ref storage) = file_storage {
217 if let Ok(Some(content)) = storage.read() {
218 serde_json::from_str(&content).unwrap_or_default()
219 } else {
220 HashMap::new()
221 }
222 } else {
223 HashMap::new()
224 };
225
226 Self {
227 file_storage,
228 credentials: RwLock::new(credentials),
229 runtime_overrides: RwLock::new(HashMap::new()),
230 }
231 }
232
233 pub fn with_backend<B: AuthStorageBackend + 'static>(backend: B) -> Self {
235 let credentials = if let Ok(Some(content)) = backend.read() {
236 serde_json::from_str(&content).unwrap_or_default()
237 } else {
238 HashMap::new()
239 };
240
241 Self {
242 file_storage: None,
243 credentials: RwLock::new(credentials),
244 runtime_overrides: RwLock::new(HashMap::new()),
245 }
246 }
247
248 pub fn in_memory() -> Self {
250 Self {
251 file_storage: None,
252 credentials: RwLock::new(HashMap::new()),
253 runtime_overrides: RwLock::new(HashMap::new()),
254 }
255 }
256
257 fn default_path() -> Option<PathBuf> {
259 dirs::config_dir().map(|p| p.join("oxi").join("auth.json"))
260 }
261
262 pub fn set_runtime_key(&self, provider: &str, api_key: String) {
264 self.runtime_overrides
265 .write()
266 .unwrap()
267 .insert(provider.to_string(), api_key);
268 }
269
270 pub fn remove_runtime_key(&self, provider: &str) {
272 self.runtime_overrides.write().unwrap().remove(provider);
273 }
274
275 pub fn has_auth(&self, provider: &str) -> bool {
277 if self.runtime_overrides.read().unwrap().contains_key(provider) {
279 return true;
280 }
281
282 if self.credentials.read().unwrap().contains_key(provider) {
284 return true;
285 }
286
287 let env_key = format!(
289 "{}_API_KEY",
290 provider.to_uppercase().replace('-', "_")
291 );
292 std::env::var(&env_key).is_ok()
293 }
294
295 pub fn get_status(&self, provider: &str) -> AuthStatus {
297 if self.runtime_overrides.read().unwrap().contains_key(provider) {
298 return AuthStatus {
299 configured: false,
300 source: Some("runtime".to_string()),
301 label: Some("--api-key".to_string()),
302 };
303 }
304
305 if self.credentials.read().unwrap().contains_key(provider) {
306 return AuthStatus {
307 configured: true,
308 source: Some("stored".to_string()),
309 label: None,
310 };
311 }
312
313 let env_key = format!(
314 "{}_API_KEY",
315 provider.to_uppercase().replace('-', "_")
316 );
317 if std::env::var(&env_key).is_ok() {
318 return AuthStatus {
319 configured: false,
320 source: Some("environment".to_string()),
321 label: Some(env_key),
322 };
323 }
324
325 AuthStatus {
326 configured: false,
327 source: None,
328 label: None,
329 }
330 }
331
332 pub fn get_api_key(&self, provider: &str) -> Option<String> {
340 if let Some(key) = self.runtime_overrides.read().unwrap().get(provider) {
342 return Some(key.clone());
343 }
344
345 if let Some(cred) = self.credentials.read().unwrap().get(provider) {
347 return match cred {
348 AuthCredential::ApiKey { key } => Some(key.clone()),
349 AuthCredential::OAuth { access_token, expires_at, .. } => {
350 if *expires_at > std::time::SystemTime::now()
352 .duration_since(std::time::UNIX_EPOCH)
353 .unwrap()
354 .as_secs()
355 {
356 Some(access_token.clone())
357 } else {
358 None
360 }
361 }
362 };
363 }
364
365 let env_key = format!(
367 "{}_API_KEY",
368 provider.to_uppercase().replace('-', "_")
369 );
370 std::env::var(&env_key).ok()
371 }
372
373 pub fn set_api_key(&self, provider: &str, key: String) {
375 self.credentials
376 .write()
377 .unwrap()
378 .insert(provider.to_string(), AuthCredential::ApiKey { key });
379 self.persist();
380 }
381
382 pub fn set_oauth(
384 &self,
385 provider: &str,
386 access_token: String,
387 refresh_token: Option<String>,
388 expires_at: u64,
389 ) {
390 self.credentials.write().unwrap().insert(
391 provider.to_string(),
392 AuthCredential::OAuth {
393 access_token,
394 refresh_token,
395 expires_at,
396 },
397 );
398 self.persist();
399 }
400
401 pub fn remove(&self, provider: &str) {
403 self.credentials.write().unwrap().remove(provider);
404 self.persist();
405 }
406
407 pub fn list_providers(&self) -> Vec<String> {
409 self.credentials.read().unwrap().keys().cloned().collect()
410 }
411
412 pub fn has(&self, provider: &str) -> bool {
414 self.credentials.read().unwrap().contains_key(provider)
415 }
416
417 pub fn get_all(&self) -> HashMap<String, AuthCredential> {
419 self.credentials.read().unwrap().clone()
420 }
421
422 pub fn clear(&self) {
424 self.credentials.write().unwrap().clear();
425 self.persist();
426 }
427
428 pub fn reload(&self) {
430 if let Some(ref storage) = self.file_storage {
431 if let Ok(Some(content)) = storage.read() {
432 if let Ok(creds) = serde_json::from_str(&content) {
433 *self.credentials.write().unwrap() = creds;
434 }
435 }
436 }
437 }
438
439 fn persist(&self) {
441 if let Some(ref storage) = self.file_storage {
442 let creds = self.credentials.read().unwrap();
443 if let Ok(json) = serde_json::to_string_pretty(&*creds) {
444 let _ = storage.write(&json);
445 }
446 }
447 }
448}
449
450impl Default for AuthStorage {
451 fn default() -> Self {
452 Self::new()
453 }
454}
455
456pub mod keyring_support {
458 use super::*;
459
460 #[cfg(feature = "keyring")]
462 pub fn get_keyring_secret(service: &str, account: &str) -> Option<String> {
463 use keyring::Entry;
464 Entry::new(service, account)
465 .ok()
466 .and_then(|entry| entry.get_password().ok())
467 }
468
469 #[cfg(feature = "keyring")]
471 pub fn set_keyring_secret(service: &str, account: &str, secret: &str) -> Result<(), AuthError> {
472 use keyring::Entry;
473 Entry::new(service, account)
474 .map_err(|e| AuthError::KeyringError(e.to_string()))?
475 .set_password(secret)
476 .map_err(|e| AuthError::KeyringError(e.to_string()))
477 }
478
479 #[cfg(feature = "keyring")]
481 pub fn delete_keyring_secret(service: &str, account: &str) -> Result<(), AuthError> {
482 use keyring::Entry;
483 Entry::new(service, account)
484 .map_err(|e| AuthError::KeyringError(e.to_string()))?
485 .delete_credential()
486 .map_err(|e| AuthError::KeyringError(e.to_string()))
487 }
488
489 #[cfg(not(feature = "keyring"))]
491 pub fn get_keyring_secret(_service: &str, _account: &str) -> Option<String> {
492 None
493 }
494
495 #[cfg(not(feature = "keyring"))]
496 pub fn set_keyring_secret(_service: &str, _account: &str, _secret: &str) -> Result<(), AuthError> {
497 Err(AuthError::KeyringError("Keyring support not compiled".to_string()))
498 }
499
500 #[cfg(not(feature = "keyring"))]
501 pub fn delete_keyring_secret(_service: &str, _account: &str) -> Result<(), AuthError> {
502 Err(AuthError::KeyringError("Keyring support not compiled".to_string()))
503 }
504}
505
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a store with no backend, so tests never touch the real credential file.
    fn fresh() -> AuthStorage {
        AuthStorage::in_memory()
    }

    /// Builds an owned `Some(String)` for compact comparisons.
    fn owned(s: &str) -> Option<String> {
        Some(s.to_owned())
    }

    #[test]
    fn test_auth_storage_new() {
        let store = fresh();
        assert!(!store.has("anthropic"));
    }

    #[test]
    fn test_set_and_get_api_key() {
        let store = fresh();
        store.set_api_key("anthropic", "sk-test123".to_owned());

        assert!(store.has("anthropic"));
        assert_eq!(store.get_api_key("anthropic"), owned("sk-test123"));
    }

    #[test]
    fn test_runtime_override() {
        let store = fresh();
        store.set_api_key("anthropic", "stored-key".to_owned());
        store.set_runtime_key("anthropic", "runtime-key".to_owned());

        // The runtime key shadows the stored one.
        assert_eq!(store.get_api_key("anthropic"), owned("runtime-key"));
    }

    #[test]
    fn test_remove_credential() {
        let store = fresh();
        store.set_api_key("anthropic", "sk-test123".to_owned());
        assert!(store.has("anthropic"));

        store.remove("anthropic");
        assert!(!store.has("anthropic"));
    }

    #[test]
    fn test_auth_status() {
        let store = fresh();
        store.set_api_key("anthropic", "sk-test123".to_owned());

        let status = store.get_status("anthropic");
        assert!(status.configured);
        assert_eq!(status.source, owned("stored"));
    }

    #[test]
    fn test_list_providers() {
        let store = fresh();
        store.set_api_key("anthropic", "key1".to_owned());
        store.set_api_key("openai", "key2".to_owned());

        let providers = store.list_providers();
        for expected in ["anthropic", "openai"].iter() {
            assert!(providers.contains(&expected.to_string()));
        }
    }

    #[test]
    fn test_oauth_credential() {
        let store = fresh();
        store.set_oauth("provider", "access123".to_owned(), owned("refresh456"), u64::MAX);

        assert!(store.has("provider"));
        assert_eq!(store.get_api_key("provider"), owned("access123"));
    }

    #[test]
    fn test_expired_oauth_token() {
        let store = fresh();
        // An expiry of 0 seconds since the epoch is always in the past.
        store.set_oauth("provider", "access123".to_owned(), None, 0);

        assert!(store.get_api_key("provider").is_none());
    }

    #[test]
    fn test_get_all_credentials() {
        let store = fresh();
        store.set_api_key("anthropic", "key1".to_owned());
        store.set_api_key("openai", "key2".to_owned());

        assert_eq!(store.get_all().len(), 2);
    }

    #[test]
    fn test_clear() {
        let store = fresh();
        store.set_api_key("anthropic", "key".to_owned());
        assert!(store.has("anthropic"));

        store.clear();
        assert!(!store.has("anthropic"));
    }

    #[test]
    fn test_remove_runtime_key() {
        let store = fresh();
        store.set_api_key("anthropic", "stored".to_owned());
        store.set_runtime_key("anthropic", "runtime".to_owned());
        assert_eq!(store.get_api_key("anthropic"), owned("runtime"));

        // Once the override is gone, the stored key is visible again.
        store.remove_runtime_key("anthropic");
        assert_eq!(store.get_api_key("anthropic"), owned("stored"));
    }
}