1use parking_lot::RwLock;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::path::PathBuf;
11use std::sync::{Arc, OnceLock};
12
/// A stored authentication credential for a single provider.
///
/// Serialized as an internally tagged JSON object: the `type` field carries the
/// variant name converted by serde's `snake_case` rule.
///
/// NOTE(review): serde's `snake_case` rule turns `OAuth` into the tag `"o_auth"`,
/// which differs from the `"oauth"` returned by `AuthCredential::type_name`.
/// Keep this in mind before hand-editing or migrating auth files.
///
/// NOTE(review): the derived `Debug` prints every field verbatim, including
/// `key`, `access_token`, `refresh_token` and `token`. Avoid logging values of
/// this type with `{:?}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum AuthCredential {
    /// A static API key that never expires.
    ApiKey {
        /// The raw API key.
        key: String,
    },
    /// An OAuth access token, optionally refreshable.
    OAuth {
        /// Bearer token sent to the provider.
        access_token: String,
        /// Token used to obtain a new access token, if the provider issued one.
        refresh_token: Option<String>,
        /// Expiry as Unix time in seconds (compared against the current
        /// system time in seconds). Required when deserializing.
        expires_at: u64,
        /// Granted scopes as a free-form string; not interpreted in this module.
        #[serde(default)]
        scopes: Option<String>,
        /// Opaque provider-specific JSON; not interpreted in this module.
        #[serde(default)]
        provider_data: Option<serde_json::Value>,
    },
    /// A session token (e.g. from a browser login).
    Session {
        /// The session token.
        token: String,
        /// Expiry as Unix time in seconds; `0` (also the default when the field
        /// is missing) means the session never expires.
        #[serde(default)]
        expires_at: u64,
        /// Opaque JSON metadata; not interpreted in this module.
        #[serde(default)]
        metadata: Option<serde_json::Value>,
    },
}
53
54impl AuthCredential {
55 pub fn is_expired(&self) -> bool {
57 match self {
58 AuthCredential::OAuth { expires_at, .. } => {
59 let now = now_secs();
60 *expires_at < now
61 }
62 AuthCredential::Session { expires_at, .. } => {
63 if *expires_at == 0 {
64 return false; }
66 *expires_at <= now_secs()
67 }
68 AuthCredential::ApiKey { .. } => false,
69 }
70 }
71
72 pub fn needs_refresh(&self) -> bool {
74 match self {
75 AuthCredential::OAuth {
76 expires_at,
77 refresh_token,
78 ..
79 } => {
80 let now = now_secs();
81 refresh_token.is_some() && *expires_at <= now + 60
82 }
83 AuthCredential::Session { .. } => false,
84 AuthCredential::ApiKey { .. } => false,
85 }
86 }
87
88 pub fn access_token(&self) -> Option<&str> {
90 match self {
91 AuthCredential::OAuth { access_token, .. } if !self.is_expired() => Some(access_token),
92 AuthCredential::Session { token, .. } if !self.is_expired() => Some(token),
93 _ => None,
94 }
95 }
96
97 pub fn type_name(&self) -> &'static str {
99 match self {
100 AuthCredential::ApiKey { .. } => "api_key",
101 AuthCredential::OAuth { .. } => "oauth",
102 AuthCredential::Session { .. } => "session",
103 }
104 }
105
106 pub fn validate(&self) -> Result<(), CredentialValidationError> {
108 match self {
109 AuthCredential::ApiKey { key } => {
110 if key.is_empty() {
111 return Err(CredentialValidationError::EmptyField("key".to_string()));
112 }
113 if key == "your-api-key-here" || key == "xxx" {
115 return Err(CredentialValidationError::PlaceholderValue(key.clone()));
116 }
117 Ok(())
118 }
119 AuthCredential::OAuth {
120 access_token,
121 expires_at,
122 ..
123 } => {
124 if access_token.is_empty() {
125 return Err(CredentialValidationError::EmptyField(
126 "access_token".to_string(),
127 ));
128 }
129 if *expires_at == 0 {
130 return Err(CredentialValidationError::InvalidExpiry);
131 }
132 Ok(())
133 }
134 AuthCredential::Session { token, .. } => {
135 if token.is_empty() {
136 return Err(CredentialValidationError::EmptyField("token".to_string()));
137 }
138 Ok(())
139 }
140 }
141 }
142}
143
/// Reasons a credential is rejected by `AuthCredential::validate` /
/// `AuthStorage::validate`.
#[derive(Debug, Clone, thiserror::Error)]
pub enum CredentialValidationError {
    /// A required secret field is empty; holds the field name (or, from
    /// `AuthStorage::validate`, a description of the missing credential).
    #[error("Field '{0}' must not be empty")]
    EmptyField(String),
    /// The value matches a known placeholder; holds the offending value.
    #[error("Placeholder value detected: '{0}'")]
    PlaceholderValue(String),
    /// The expiry timestamp is unusable (an OAuth `expires_at` of 0).
    #[error("Invalid expiry timestamp")]
    InvalidExpiry,
}
157
/// Summary of where a provider's authentication comes from, as reported by
/// `AuthStorage::get_status`.
#[derive(Debug, Clone)]
pub struct AuthStatus {
    /// `true` only when a credential is stored for the provider; runtime keys
    /// and fallback-resolved keys report `false`.
    pub configured: bool,
    /// Origin of the auth (`"runtime"`, `"stored"`, `"fallback"`), or `None`.
    pub source: Option<String>,
    /// Human-readable detail, e.g. the credential type or `"--api-key"`.
    pub label: Option<String>,
}
172
173impl std::fmt::Display for AuthStatus {
174 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
175 match (&self.source, &self.label) {
176 (Some(source), Some(label)) => write!(f, "{} ({})", source, label),
177 (Some(source), None) => write!(f, "{}", source),
178 (None, Some(label)) => write!(f, "{}", label),
179 (None, None) => write!(f, "not configured"),
180 }
181 }
182}
183
/// Result type used by storage backends and `AuthStorage` operations.
pub type AuthResult<T> = Result<T, AuthError>;
190
/// Errors raised while loading, persisting or manipulating stored credentials.
///
/// `Clone` so the same error can be kept as the last load error and also
/// queued for `AuthStorage::drain_errors`.
#[derive(Debug, Clone, thiserror::Error)]
pub enum AuthError {
    /// Reading the backing store failed; holds the underlying message.
    #[error("Failed to read auth storage: {0}")]
    ReadError(String),
    /// Writing or deleting the backing store failed; holds the underlying message.
    #[error("Failed to write auth storage: {0}")]
    WriteError(String),
    /// No credential exists for the named provider.
    #[error("Credential not found: {0}")]
    NotFound(String),
    /// Stored data has the wrong shape (e.g. an OAuth update on a non-OAuth credential).
    #[error("Invalid credential format: {0}")]
    InvalidFormat(String),
    /// An OS keyring operation failed or keyring support is not compiled in.
    #[error("Keyring error: {0}")]
    KeyringError(String),
    /// A credential failed validation.
    /// NOTE(review): not constructed anywhere in the code visible here.
    #[error("Credential validation failed: {0}")]
    ValidationFailed(String),
}
213
/// Persistence backend for the serialized credential map.
///
/// `AuthStorage` hands backends a JSON document and reads the same document
/// back; backends only move the text around. `Send + Sync` because a backend
/// is shared behind an `Arc` by a storage that may be used from many threads.
pub trait AuthStorageBackend: Send + Sync {
    /// Returns the stored document, or `Ok(None)` when nothing is stored.
    /// `AuthStorage` treats `Ok(None)` as "no credentials".
    fn read(&self) -> AuthResult<Option<String>>;
    /// Replaces the stored document with `data`.
    fn write(&self, data: &str) -> AuthResult<()>;
    /// Removes the stored document.
    fn delete(&self) -> AuthResult<()>;
}
227
/// Backend that keeps the credential document in a single file on disk.
pub struct FileAuthStorage {
    /// Location of the auth file.
    path: PathBuf,
    /// Copy of the most recently read or written document.
    /// NOTE(review): this cache is updated by the backend but never read in the
    /// code visible here — confirm whether it is still needed.
    cache: RwLock<Option<String>>,
}
237
238impl FileAuthStorage {
239 pub fn new(path: PathBuf) -> Self {
241 Self {
242 path,
243 cache: RwLock::new(None),
244 }
245 }
246
247 pub fn default_path() -> Option<PathBuf> {
249 dirs::home_dir().map(|p| p.join(".oxi").join("auth.json"))
250 }
251
252 pub fn path(&self) -> &PathBuf {
254 &self.path
255 }
256}
257
258impl AuthStorageBackend for FileAuthStorage {
259 fn read(&self) -> AuthResult<Option<String>> {
260 if !self.path.exists() {
261 return Ok(None);
262 }
263
264 match std::fs::read_to_string(&self.path) {
265 Ok(content) => {
266 *self.cache.write() = Some(content.clone());
267 Ok(Some(content))
268 }
269 Err(e) => Err(AuthError::ReadError(e.to_string())),
270 }
271 }
272
273 fn write(&self, data: &str) -> AuthResult<()> {
274 if let Some(parent) = self.path.parent() {
276 std::fs::create_dir_all(parent).map_err(|e| AuthError::WriteError(e.to_string()))?;
277
278 #[cfg(unix)]
279 {
280 use std::os::unix::fs::PermissionsExt;
281 let perms = std::fs::Permissions::from_mode(0o700);
282 let _ = std::fs::set_permissions(parent, perms);
283 }
284 }
285
286 std::fs::write(&self.path, data).map_err(|e| AuthError::WriteError(e.to_string()))?;
288
289 #[cfg(unix)]
291 {
292 use std::os::unix::fs::PermissionsExt;
293 let perms = std::fs::Permissions::from_mode(0o600);
294 std::fs::set_permissions(&self.path, perms)
295 .map_err(|e| AuthError::WriteError(e.to_string()))?;
296 }
297
298 *self.cache.write() = Some(data.to_string());
299 Ok(())
300 }
301
302 fn delete(&self) -> AuthResult<()> {
303 if self.path.exists() {
304 std::fs::remove_file(&self.path).map_err(|e| AuthError::WriteError(e.to_string()))?;
305 }
306 *self.cache.write() = None;
307 Ok(())
308 }
309}
310
/// Backend that keeps data in process memory instead of on disk.
///
/// Intended for tests and ephemeral sessions; what it retains across
/// `write`/`read` is defined by its `AuthStorageBackend` impl.
pub struct MemoryAuthStorage {
    /// Credentials held by this backend, keyed by provider name.
    data: RwLock<HashMap<String, AuthCredential>>,
}
319
320impl MemoryAuthStorage {
321 pub fn new() -> Self {
323 Self {
324 data: RwLock::new(HashMap::new()),
325 }
326 }
327}
328
329impl Default for MemoryAuthStorage {
330 fn default() -> Self {
331 Self::new()
332 }
333}
334
335impl AuthStorageBackend for MemoryAuthStorage {
336 fn read(&self) -> AuthResult<Option<String>> {
337 Ok(None)
339 }
340
341 fn write(&self, _data: &str) -> AuthResult<()> {
342 Ok(())
343 }
344
345 fn delete(&self) -> AuthResult<()> {
346 self.data.write().clear();
347 Ok(())
348 }
349}
350
/// Last-resort source of API keys, consulted by `AuthStorage` when neither a
/// runtime override nor a stored credential exists for a provider.
pub trait FallbackResolver: Send + Sync {
    /// Returns a key for `provider`, or `None` if this resolver has none.
    fn resolve(&self, provider: &str) -> Option<String>;
}
360
/// [`FallbackResolver`] backed by an arbitrary closure.
pub struct FnFallbackResolver {
    // The boxed, thread-safe closure type is spelled out in place on purpose;
    // clippy's type_complexity lint is silenced instead of adding a type alias.
    #[allow(clippy::type_complexity)]
    f: Box<dyn Fn(&str) -> Option<String> + Send + Sync>,
}
366
367impl FnFallbackResolver {
368 #[allow(clippy::type_complexity)]
370 pub fn new(f: Box<dyn Fn(&str) -> Option<String> + Send + Sync>) -> Self {
371 Self { f }
372 }
373}
374
375impl FallbackResolver for FnFallbackResolver {
376 fn resolve(&self, provider: &str) -> Option<String> {
377 (self.f)(provider)
378 }
379}
380
/// [`FallbackResolver`] that reads keys from the environment variables that
/// `oxi_ai`'s builtin provider table associates with each provider.
pub struct EnvVarFallbackResolver;
386
387impl FallbackResolver for EnvVarFallbackResolver {
388 fn resolve(&self, provider: &str) -> Option<String> {
389 let builtin = oxi_ai::get_builtin_provider(provider)?;
391 let key = builtin.env_key;
392
393 if let Ok(val) = std::env::var(key) {
395 if !val.is_empty() {
396 return Some(val);
397 }
398 }
399
400 for extra in builtin.extra_env_keys {
402 if let Ok(val) = std::env::var(extra) {
403 if !val.is_empty() {
404 return Some(val);
405 }
406 }
407 }
408
409 None
410 }
411}
412
/// Thread-safe store of provider credentials.
///
/// Lookups consult, in order: runtime overrides, stored credentials, and an
/// optional fallback resolver.
pub struct AuthStorage {
    /// Where stored credentials are persisted; `None` keeps them in memory only.
    file_storage: Option<Arc<dyn AuthStorageBackend>>,
    /// Stored credentials keyed by provider name (the persisted state).
    credentials: RwLock<HashMap<String, AuthCredential>>,
    /// Process-local keys that take precedence over everything else; never persisted.
    runtime_overrides: RwLock<HashMap<String, String>>,
    /// Optional last-resort key source.
    fallback_resolver: RwLock<Option<Arc<dyn FallbackResolver>>>,
    /// Errors recorded during persistence/reload, handed out by `drain_errors`.
    errors: RwLock<Vec<AuthError>>,
    /// Error from the most recent failed load, if any.
    load_error: RwLock<Option<AuthError>>,
    /// Ensures the plaintext-storage warning is logged at most once per instance.
    plaintext_warned: OnceLock<()>,
}
442
443impl AuthStorage {
444 pub fn new() -> Self {
446 let file_storage = FileAuthStorage::default_path()
447 .map(|p| Arc::new(FileAuthStorage::new(p)) as Arc<dyn AuthStorageBackend>);
448
449 let credentials = if let Some(ref storage) = file_storage {
450 match storage.read() {
451 Ok(Some(content)) => serde_json::from_str(&content).unwrap_or_default(),
452 _ => HashMap::new(),
453 }
454 } else {
455 HashMap::new()
456 };
457
458 Self {
459 file_storage,
460 credentials: RwLock::new(credentials),
461 runtime_overrides: RwLock::new(HashMap::new()),
462 fallback_resolver: RwLock::new(None),
463 errors: RwLock::new(Vec::new()),
464 load_error: RwLock::new(None),
465 plaintext_warned: OnceLock::new(),
466 }
467 }
468
469 pub fn with_backend(backend: impl AuthStorageBackend + 'static) -> Self {
471 let credentials = match backend.read() {
472 Ok(Some(content)) => serde_json::from_str(&content).unwrap_or_default(),
473 _ => HashMap::new(),
474 };
475
476 Self {
477 file_storage: Some(Arc::new(backend)),
478 credentials: RwLock::new(credentials),
479 runtime_overrides: RwLock::new(HashMap::new()),
480 fallback_resolver: RwLock::new(None),
481 errors: RwLock::new(Vec::new()),
482 load_error: RwLock::new(None),
483 plaintext_warned: OnceLock::new(),
484 }
485 }
486
487 pub fn in_memory() -> Self {
489 Self {
490 file_storage: None,
491 credentials: RwLock::new(HashMap::new()),
492 runtime_overrides: RwLock::new(HashMap::new()),
493 fallback_resolver: RwLock::new(None),
494 errors: RwLock::new(Vec::new()),
495 load_error: RwLock::new(None),
496 plaintext_warned: OnceLock::new(),
497 }
498 }
499
500 pub fn default_path() -> Option<PathBuf> {
502 FileAuthStorage::default_path()
503 }
504
505 pub fn set_runtime_key(&self, provider: &str, api_key: String) {
511 self.runtime_overrides
512 .write()
513 .insert(provider.to_string(), api_key);
514 }
515
516 pub fn remove_runtime_key(&self, provider: &str) {
518 self.runtime_overrides.write().remove(provider);
519 }
520
521 pub fn set_fallback_resolver(&self, resolver: Arc<dyn FallbackResolver>) {
528 *self.fallback_resolver.write() = Some(resolver);
529 }
530
531 pub fn clear_fallback_resolver(&self) {
533 *self.fallback_resolver.write() = None;
534 }
535
536 pub fn has_auth(&self, provider: &str) -> bool {
542 if self.runtime_overrides.read().contains_key(provider) {
543 return true;
544 }
545 if self.credentials.read().contains_key(provider) {
546 return true;
547 }
548 if let Some(ref resolver) = *self.fallback_resolver.read() {
549 if resolver.resolve(provider).is_some() {
550 return true;
551 }
552 }
553 false
554 }
555
556 pub fn get_status(&self, provider: &str) -> AuthStatus {
558 if self.runtime_overrides.read().contains_key(provider) {
559 return AuthStatus {
560 configured: false,
561 source: Some("runtime".to_string()),
562 label: Some("--api-key".to_string()),
563 };
564 }
565
566 if let Some(cred) = self.credentials.read().get(provider) {
567 return AuthStatus {
568 configured: true,
569 source: Some("stored".to_string()),
570 label: Some(cred.type_name().to_string()),
571 };
572 }
573
574 if let Some(ref resolver) = *self.fallback_resolver.read() {
575 if resolver.resolve(provider).is_some() {
576 return AuthStatus {
577 configured: false,
578 source: Some("fallback".to_string()),
579 label: Some("custom provider config".to_string()),
580 };
581 }
582 }
583
584 AuthStatus {
585 configured: false,
586 source: None,
587 label: None,
588 }
589 }
590
591 pub fn get_api_key(&self, provider: &str) -> Option<String> {
600 self.get_api_key_with_options(provider, true)
601 }
602
603 pub fn get_api_key_with_options(
605 &self,
606 provider: &str,
607 include_fallback: bool,
608 ) -> Option<String> {
609 if let Some(key) = self.runtime_overrides.read().get(provider) {
611 return Some(key.clone());
612 }
613
614 if let Some(cred) = self.credentials.read().get(provider) {
616 return match cred {
617 AuthCredential::ApiKey { key } => Some(key.clone()),
618 AuthCredential::OAuth {
619 access_token,
620 expires_at,
621 ..
622 } => {
623 if *expires_at > now_secs() {
624 Some(access_token.clone())
625 } else {
626 None
628 }
629 }
630 AuthCredential::Session {
631 token, expires_at, ..
632 } => {
633 if *expires_at == 0 || *expires_at > now_secs() {
634 Some(token.clone())
635 } else {
636 None
637 }
638 }
639 };
640 }
641
642 if let Some(builtin) = oxi_ai::register_builtins::get_builtin_provider(provider) {
647 let env_key = builtin.env_key;
648 let credentials = self.credentials.read();
649 for other in oxi_ai::register_builtins::get_builtin_providers() {
650 if other.name == provider {
651 continue; }
653 if other.env_key == env_key {
654 if let Some(cred) = credentials.get(other.name) {
655 return match cred {
656 AuthCredential::ApiKey { key } => Some(key.clone()),
657 AuthCredential::OAuth {
658 access_token, expires_at, ..
659 } => {
660 if *expires_at > now_secs() {
661 Some(access_token.clone())
662 } else {
663 None
664 }
665 }
666 AuthCredential::Session {
667 token, expires_at, ..
668 } => {
669 if *expires_at == 0 || *expires_at > now_secs() {
670 Some(token.clone())
671 } else {
672 None
673 }
674 }
675 };
676 }
677 }
678 }
679 }
680
681 if include_fallback {
683 if let Some(ref resolver) = *self.fallback_resolver.read() {
684 return resolver.resolve(provider);
685 }
686 }
687
688 None
689 }
690
691 pub fn set_api_key(&self, provider: &str, key: String) {
697 self.credentials
698 .write()
699 .insert(provider.to_string(), AuthCredential::ApiKey { key });
700 if let Err(e) = self.persist() {
701 tracing::warn!("Failed to persist API key for '{}': {}", provider, e);
702 }
703 }
704
705 pub fn set_oauth(
707 &self,
708 provider: &str,
709 access_token: String,
710 refresh_token: Option<String>,
711 expires_at: u64,
712 ) {
713 self.set_oauth_full(
714 provider,
715 access_token,
716 refresh_token,
717 expires_at,
718 None,
719 None,
720 );
721 }
722
723 pub fn set_oauth_full(
725 &self,
726 provider: &str,
727 access_token: String,
728 refresh_token: Option<String>,
729 expires_at: u64,
730 scopes: Option<String>,
731 provider_data: Option<serde_json::Value>,
732 ) {
733 self.credentials.write().insert(
734 provider.to_string(),
735 AuthCredential::OAuth {
736 access_token,
737 refresh_token,
738 expires_at,
739 scopes,
740 provider_data,
741 },
742 );
743 if let Err(e) = self.persist() {
744 tracing::warn!("Failed to persist OAuth token for '{}': {}", provider, e);
745 }
746 }
747
748 pub fn set_session(
750 &self,
751 provider: &str,
752 token: String,
753 expires_at: u64,
754 metadata: Option<serde_json::Value>,
755 ) {
756 self.credentials.write().insert(
757 provider.to_string(),
758 AuthCredential::Session {
759 token,
760 expires_at,
761 metadata,
762 },
763 );
764 if let Err(e) = self.persist() {
765 tracing::warn!("Failed to persist session for '{}': {}", provider, e);
766 }
767 }
768
769 pub fn update_oauth_tokens(
771 &self,
772 provider: &str,
773 new_access_token: String,
774 new_refresh_token: Option<String>,
775 new_expires_at: u64,
776 ) -> AuthResult<()> {
777 let mut creds = self.credentials.write();
778 let cred = creds
779 .get_mut(provider)
780 .ok_or_else(|| AuthError::NotFound(provider.to_string()))?;
781
782 match cred {
783 AuthCredential::OAuth {
784 access_token,
785 refresh_token,
786 expires_at,
787 ..
788 } => {
789 *access_token = new_access_token;
790 *refresh_token = new_refresh_token;
791 *expires_at = new_expires_at;
792 }
793 _ => {
794 return Err(AuthError::InvalidFormat(format!(
795 "Provider '{}' does not have OAuth credentials",
796 provider
797 )));
798 }
799 }
800
801 drop(creds);
802 if let Err(e) = self.persist() {
803 tracing::warn!(
804 "Failed to persist OAuth token update for '{}': {}",
805 provider,
806 e
807 );
808 }
809 Ok(())
810 }
811
812 pub fn get(&self, provider: &str) -> Option<AuthCredential> {
818 self.credentials.read().get(provider).cloned()
819 }
820
821 pub fn get_oauth_credential(&self, provider: &str) -> Option<AuthCredential> {
823 self.credentials.read().get(provider).cloned()
824 }
825
826 pub fn has_oauth_with_refresh(&self, provider: &str) -> bool {
828 if let Some(cred) = self.credentials.read().get(provider) {
829 matches!(
830 cred,
831 AuthCredential::OAuth {
832 refresh_token: Some(_),
833 ..
834 }
835 )
836 } else {
837 false
838 }
839 }
840
841 pub fn set(&self, provider: &str, credential: AuthCredential) {
847 self.credentials
848 .write()
849 .insert(provider.to_string(), credential);
850 if let Err(e) = self.persist() {
851 tracing::warn!("Failed to persist credential for '{}': {}", provider, e);
852 }
853 }
854
855 pub fn remove(&self, provider: &str) {
857 self.credentials.write().remove(provider);
858 if let Err(e) = self.persist() {
859 tracing::warn!("Failed to persist after removing '{}': {}", provider, e);
860 }
861 }
862
863 pub fn list_providers(&self) -> Vec<String> {
865 self.credentials.read().keys().cloned().collect()
866 }
867
868 pub fn has(&self, provider: &str) -> bool {
870 self.credentials.read().contains_key(provider)
871 }
872
873 pub fn get_all(&self) -> HashMap<String, AuthCredential> {
875 self.credentials.read().clone()
876 }
877
878 pub fn clear(&self) {
880 self.credentials.write().clear();
881 if let Err(e) = self.persist() {
882 tracing::warn!("Failed to persist after clearing credentials: {}", e);
883 }
884 }
885
886 pub fn reload(&self) {
892 if let Some(ref storage) = self.file_storage {
893 match storage.read() {
894 Ok(Some(content)) => {
895 if let Ok(creds) = serde_json::from_str(&content) {
896 *self.credentials.write() = creds;
897 }
898 *self.load_error.write() = None;
899 }
900 Ok(None) => {
901 self.credentials.write().clear();
902 *self.load_error.write() = None;
903 }
904 Err(e) => {
905 *self.load_error.write() = Some(e);
906 self.record_error(AuthError::ReadError(
907 "Failed to reload auth storage".to_string(),
908 ));
909 }
910 }
911 }
912 }
913
914 fn persist(&self) -> Result<(), String> {
916 if let Some(ref storage) = self.file_storage {
917 let creds = self.credentials.read();
918 if let Ok(json) = serde_json::to_string_pretty(&*creds) {
919 #[cfg(not(feature = "keyring"))]
921 {
922 self.plaintext_warned.get_or_init(|| {
923 tracing::warn!(
924 "Auth credentials are stored in plaintext. \
925 Enable the 'keyring' feature for secure OS-level storage."
926 );
927 });
928 }
929
930 if let Err(e) = storage.write(&json) {
931 tracing::error!("Failed to persist auth storage: {}", e);
932 self.record_error(e);
933 return Err("persist failed".to_string());
934 }
935 }
936 }
937 Ok(())
938 }
939
940 fn record_error(&self, error: AuthError) {
946 self.errors.write().push(error);
947 }
948
949 pub fn drain_errors(&self) -> Vec<AuthError> {
951 let mut errors = self.errors.write();
952 std::mem::take(&mut *errors)
953 }
954
955 pub fn load_error(&self) -> Option<AuthError> {
957 self.load_error.read().clone()
958 }
959
960 pub fn validate_all(&self) -> Vec<(String, CredentialValidationError)> {
966 let creds = self.credentials.read();
967 let mut results = Vec::new();
968 for (provider, cred) in creds.iter() {
969 if let Err(e) = cred.validate() {
970 results.push((provider.clone(), e));
971 }
972 }
973 results
974 }
975
976 pub fn validate(&self, provider: &str) -> Result<(), CredentialValidationError> {
978 let creds = self.credentials.read();
979 let cred = creds.get(provider).ok_or_else(|| {
980 CredentialValidationError::EmptyField(format!(
981 "no credential for provider '{}'",
982 provider
983 ))
984 })?;
985 cred.validate()
986 }
987
988 pub fn configured_providers(&self) -> Vec<String> {
994 let mut providers: Vec<String> = self.credentials.read().keys().cloned().collect();
995 providers.sort();
996 providers
997 }
998
999 pub fn has_multiple_providers(&self) -> bool {
1001 self.credentials.read().len() > 1
1002 }
1003
1004 pub fn primary_provider(&self) -> Option<String> {
1006 let creds = self.credentials.read();
1007 creds.keys().next().cloned()
1008 }
1009
1010 pub fn migrate_provider(&self, from: &str, to: &str) -> AuthResult<()> {
1012 let mut creds = self.credentials.write();
1013 let cred = creds
1014 .remove(from)
1015 .ok_or_else(|| AuthError::NotFound(from.to_string()))?;
1016 creds.insert(to.to_string(), cred);
1017 drop(creds);
1018 let _ = self.persist();
1019 Ok(())
1020 }
1021}
1022
1023impl Default for AuthStorage {
1024 fn default() -> Self {
1025 Self::new()
1026 }
1027}
1028
/// Current Unix time in whole seconds; `0` if the system clock reports a time
/// before the Unix epoch.
fn now_secs() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
1039
1040#[allow(unexpected_cfgs)]
1046pub mod keyring_support {
1047 use super::*;
1048
1049 #[cfg(feature = "keyring")]
1051 pub fn get_keyring_secret(service: &str, account: &str) -> Option<String> {
1052 use keyring::Entry;
1053 Entry::new(service, account)
1054 .ok()
1055 .and_then(|entry| entry.get_password().ok())
1056 }
1057
1058 #[cfg(feature = "keyring")]
1060 pub fn set_keyring_secret(service: &str, account: &str, secret: &str) -> AuthResult<()> {
1061 use keyring::Entry;
1062 Entry::new(service, account)
1063 .map_err(|e| AuthError::KeyringError(e.to_string()))?
1064 .set_password(secret)
1065 .map_err(|e| AuthError::KeyringError(e.to_string()))
1066 }
1067
1068 #[cfg(feature = "keyring")]
1070 pub fn delete_keyring_secret(service: &str, account: &str) -> AuthResult<()> {
1071 use keyring::Entry;
1072 Entry::new(service, account)
1073 .map_err(|e| AuthError::KeyringError(e.to_string()))?
1074 .delete_credential()
1075 .map_err(|e| AuthError::KeyringError(e.to_string()))
1076 }
1077
1078 #[cfg(not(feature = "keyring"))]
1080 pub fn get_keyring_secret(_service: &str, _account: &str) -> Option<String> {
1084 None
1085 }
1086
1087 #[cfg(not(feature = "keyring"))]
1088 pub fn set_keyring_secret(_service: &str, _account: &str, _secret: &str) -> AuthResult<()> {
1092 Err(AuthError::KeyringError(
1093 "Keyring support not compiled".to_string(),
1094 ))
1095 }
1096
1097 #[cfg(not(feature = "keyring"))]
1098 pub fn delete_keyring_secret(_service: &str, _account: &str) -> AuthResult<()> {
1102 Err(AuthError::KeyringError(
1103 "Keyring support not compiled".to_string(),
1104 ))
1105 }
1106}
1107
1108pub fn shared_auth_storage() -> Arc<AuthStorage> {
1118 static STORAGE: OnceLock<Arc<AuthStorage>> = OnceLock::new();
1119 STORAGE
1120 .get_or_init(|| {
1121 let storage = Arc::new(AuthStorage::new());
1122 storage.set_fallback_resolver(Arc::new(EnvVarFallbackResolver));
1125 storage
1126 })
1127 .clone()
1128}
1129
1130#[cfg(test)]
1135mod tests {
1136 use super::*;
1137
1138 #[test]
1139 fn test_auth_storage_new() {
1140 let storage = AuthStorage::in_memory();
1141 assert!(!storage.has("anthropic"));
1142 }
1143
1144 #[test]
1145 fn test_set_and_get_api_key() {
1146 let storage = AuthStorage::in_memory();
1147 storage.set_api_key("anthropic", "sk-test123".to_string());
1148 assert!(storage.has("anthropic"));
1149 assert_eq!(
1150 storage.get_api_key("anthropic"),
1151 Some("sk-test123".to_string())
1152 );
1153 }
1154
1155 #[test]
1156 fn test_runtime_override() {
1157 let storage = AuthStorage::in_memory();
1158 storage.set_api_key("anthropic", "stored-key".to_string());
1159 storage.set_runtime_key("anthropic", "runtime-key".to_string());
1160
1161 assert_eq!(
1163 storage.get_api_key("anthropic"),
1164 Some("runtime-key".to_string())
1165 );
1166 }
1167
1168 #[test]
1169 fn test_remove_credential() {
1170 let storage = AuthStorage::in_memory();
1171 storage.set_api_key("anthropic", "sk-test123".to_string());
1172 assert!(storage.has("anthropic"));
1173
1174 storage.remove("anthropic");
1175 assert!(!storage.has("anthropic"));
1176 }
1177
1178 #[test]
1179 fn test_auth_status() {
1180 let storage = AuthStorage::in_memory();
1181 storage.set_api_key("anthropic", "sk-test123".to_string());
1182
1183 let status = storage.get_status("anthropic");
1184 assert!(status.configured);
1185 assert_eq!(status.source, Some("stored".to_string()));
1186 assert_eq!(status.label, Some("api_key".to_string()));
1187 }
1188
1189 #[test]
1190 fn test_auth_status_display() {
1191 let status = AuthStatus {
1192 configured: true,
1193 source: Some("stored".to_string()),
1194 label: Some("api_key".to_string()),
1195 };
1196 let display = format!("{}", status);
1197 assert_eq!(display, "stored (api_key)");
1198
1199 let no_config = AuthStatus {
1200 configured: false,
1201 source: None,
1202 label: None,
1203 };
1204 assert_eq!(format!("{}", no_config), "not configured");
1205 }
1206
1207 #[test]
1208 fn test_list_providers() {
1209 let storage = AuthStorage::in_memory();
1210 storage.set_api_key("anthropic", "key1".to_string());
1211 storage.set_api_key("openai", "key2".to_string());
1212
1213 let providers = storage.list_providers();
1214 assert!(providers.contains(&"anthropic".to_string()));
1215 assert!(providers.contains(&"openai".to_string()));
1216 }
1217
1218 #[test]
1219 fn test_oauth_credential() {
1220 let storage = AuthStorage::in_memory();
1221 storage.set_oauth(
1222 "provider",
1223 "access123".to_string(),
1224 Some("refresh456".to_string()),
1225 u64::MAX,
1226 );
1227
1228 assert!(storage.has("provider"));
1229 assert_eq!(
1230 storage.get_api_key("provider"),
1231 Some("access123".to_string())
1232 );
1233 }
1234
1235 #[test]
1236 fn test_expired_oauth_token() {
1237 let storage = AuthStorage::in_memory();
1238 storage.set_oauth("provider", "access123".to_string(), None, 0);
1240
1241 let key = storage.get_api_key("provider");
1243 assert!(key.is_none());
1244 }
1245
1246 #[test]
1247 fn test_get_all_credentials() {
1248 let storage = AuthStorage::in_memory();
1249 storage.set_api_key("anthropic", "key1".to_string());
1250 storage.set_api_key("openai", "key2".to_string());
1251
1252 let all = storage.get_all();
1253 assert_eq!(all.len(), 2);
1254 }
1255
1256 #[test]
1257 fn test_clear() {
1258 let storage = AuthStorage::in_memory();
1259 storage.set_api_key("anthropic", "key".to_string());
1260 assert!(storage.has("anthropic"));
1261
1262 storage.clear();
1263 assert!(!storage.has("anthropic"));
1264 }
1265
1266 #[test]
1267 fn test_remove_runtime_key() {
1268 let storage = AuthStorage::in_memory();
1269 storage.set_api_key("anthropic", "stored".to_string());
1270 storage.set_runtime_key("anthropic", "runtime".to_string());
1271
1272 assert_eq!(
1273 storage.get_api_key("anthropic"),
1274 Some("runtime".to_string())
1275 );
1276
1277 storage.remove_runtime_key("anthropic");
1278 assert_eq!(storage.get_api_key("anthropic"), Some("stored".to_string()));
1279 }
1280
1281 #[test]
1282 fn test_auth_credential_is_expired() {
1283 let api_key_cred = AuthCredential::ApiKey {
1285 key: "test".to_string(),
1286 };
1287 assert!(!api_key_cred.is_expired());
1288
1289 let future_time = now_secs() + 3600;
1291 let oauth_cred = AuthCredential::OAuth {
1292 access_token: "token".to_string(),
1293 refresh_token: Some("refresh".to_string()),
1294 expires_at: future_time,
1295 scopes: None,
1296 provider_data: None,
1297 };
1298 assert!(!oauth_cred.is_expired());
1299
1300 let oauth_cred_expired = AuthCredential::OAuth {
1302 access_token: "token".to_string(),
1303 refresh_token: Some("refresh".to_string()),
1304 expires_at: 0,
1305 scopes: None,
1306 provider_data: None,
1307 };
1308 assert!(oauth_cred_expired.is_expired());
1309 }
1310
1311 #[test]
1312 fn test_auth_credential_needs_refresh() {
1313 let future_time = now_secs() + 120; let oauth_cred = AuthCredential::OAuth {
1317 access_token: "token".to_string(),
1318 refresh_token: Some("refresh".to_string()),
1319 expires_at: future_time,
1320 scopes: None,
1321 provider_data: None,
1322 };
1323 assert!(!oauth_cred.needs_refresh());
1324
1325 let soon = now_secs() + 30;
1327 let oauth_soon = AuthCredential::OAuth {
1328 access_token: "token".to_string(),
1329 refresh_token: Some("refresh".to_string()),
1330 expires_at: soon,
1331 scopes: None,
1332 provider_data: None,
1333 };
1334 assert!(oauth_soon.needs_refresh());
1335
1336 let no_refresh = AuthCredential::OAuth {
1338 access_token: "token".to_string(),
1339 refresh_token: None,
1340 expires_at: future_time,
1341 scopes: None,
1342 provider_data: None,
1343 };
1344 assert!(!no_refresh.needs_refresh());
1345
1346 let api_key_cred = AuthCredential::ApiKey {
1348 key: "test".to_string(),
1349 };
1350 assert!(!api_key_cred.needs_refresh());
1351 }
1352
1353 #[test]
1354 fn test_auth_credential_access_token() {
1355 let future_time = now_secs() + 3600;
1356
1357 let oauth_cred = AuthCredential::OAuth {
1358 access_token: "valid_token".to_string(),
1359 refresh_token: Some("refresh".to_string()),
1360 expires_at: future_time,
1361 scopes: None,
1362 provider_data: None,
1363 };
1364 assert_eq!(oauth_cred.access_token(), Some("valid_token"));
1365
1366 let expired_cred = AuthCredential::OAuth {
1368 access_token: "expired_token".to_string(),
1369 refresh_token: Some("refresh".to_string()),
1370 expires_at: 0,
1371 scopes: None,
1372 provider_data: None,
1373 };
1374 assert!(expired_cred.access_token().is_none());
1375
1376 let api_key_cred = AuthCredential::ApiKey {
1378 key: "api_key_token".to_string(),
1379 };
1380 assert!(api_key_cred.access_token().is_none());
1381 }
1382
1383 #[test]
1384 fn test_get_oauth_credential() {
1385 let storage = AuthStorage::in_memory();
1386 storage.set_oauth(
1387 "provider",
1388 "access".to_string(),
1389 Some("refresh".to_string()),
1390 u64::MAX,
1391 );
1392
1393 let cred = storage.get_oauth_credential("provider");
1394 assert!(cred.is_some());
1395 assert!(matches!(cred.unwrap(), AuthCredential::OAuth { .. }));
1396 }
1397
1398 #[test]
1399 fn test_has_oauth_with_refresh() {
1400 let storage = AuthStorage::in_memory();
1401
1402 storage.set_oauth(
1404 "with_refresh",
1405 "access".to_string(),
1406 Some("refresh".to_string()),
1407 u64::MAX,
1408 );
1409 assert!(storage.has_oauth_with_refresh("with_refresh"));
1410
1411 storage.set_oauth("without_refresh", "access".to_string(), None, u64::MAX);
1413 assert!(!storage.has_oauth_with_refresh("without_refresh"));
1414
1415 storage.set_api_key("apikey_provider", "key".to_string());
1417 assert!(!storage.has_oauth_with_refresh("apikey_provider"));
1418 }
1419
1420 #[test]
1421 fn test_set_oauth_full() {
1422 let storage = AuthStorage::in_memory();
1423 storage.set_oauth_full(
1424 "provider",
1425 "access_token".to_string(),
1426 Some("refresh_token".to_string()),
1427 3600,
1428 Some("read write".to_string()),
1429 Some(serde_json::json!({"extra": "data"})),
1430 );
1431
1432 let cred = storage.get_oauth_credential("provider");
1433 assert!(cred.is_some());
1434 if let AuthCredential::OAuth {
1435 scopes,
1436 provider_data,
1437 ..
1438 } = cred.unwrap()
1439 {
1440 assert_eq!(scopes, Some("read write".to_string()));
1441 assert!(provider_data.is_some());
1442 } else {
1443 panic!("Expected OAuth credential");
1444 }
1445 }
1446
1447 #[test]
1448 fn test_session_token() {
1449 let storage = AuthStorage::in_memory();
1450 storage.set_session(
1451 "browser",
1452 "session-token-123".to_string(),
1453 0, Some(serde_json::json!({"user": "test"})),
1455 );
1456
1457 assert!(storage.has("browser"));
1458 assert_eq!(
1459 storage.get_api_key("browser"),
1460 Some("session-token-123".to_string())
1461 );
1462
1463 let cred = storage.get("browser").unwrap();
1464 assert!(matches!(cred, AuthCredential::Session { .. }));
1465 assert!(cred.access_token().is_some());
1466 }
1467
1468 #[test]
1469 fn test_session_token_expired() {
1470 let storage = AuthStorage::in_memory();
1471 storage.set_session("browser", "session-token".to_string(), 1, None);
1472
1473 assert!(storage.get_api_key("browser").is_none());
1475 }
1476
/// `validate` accepts well-formed credentials and rejects empty or placeholder secrets.
#[test]
fn test_credential_validation() {
    let api_key = |key: &str| AuthCredential::ApiKey {
        key: key.to_string(),
    };
    assert!(api_key("sk-valid").validate().is_ok());
    assert!(api_key("").validate().is_err());
    assert!(api_key("your-api-key-here").validate().is_err());

    // Both OAuth fixtures share an unexpired timestamp. A validation failure can then
    // only come from the empty access token, not from an expiry check.
    let oauth = |access_token: &str| AuthCredential::OAuth {
        access_token: access_token.to_string(),
        refresh_token: None,
        expires_at: now_secs() + 3600,
        scopes: None,
        provider_data: None,
    };
    assert!(oauth("token").validate().is_ok());
    assert!(oauth("").validate().is_err());
}
1517
/// `validate_all` reports exactly the stored providers whose credentials fail validation.
#[test]
fn test_validate_all() {
    let store = AuthStorage::in_memory();
    for (provider, key) in [("valid", "sk-good"), ("empty", "")] {
        store.set_api_key(provider, key.to_string());
    }

    let failures = store.validate_all();
    assert_eq!(failures.len(), 1);
    assert_eq!(failures[0].0, "empty");
}
1528
/// `update_oauth_tokens` replaces the access token, refresh token and expiry of an
/// existing OAuth credential.
#[test]
fn test_update_oauth_tokens() {
    let storage = AuthStorage::in_memory();
    storage.set_oauth(
        "provider",
        "old-access".to_string(),
        Some("old-refresh".to_string()),
        now_secs() + 3600,
    );

    let new_expiry = now_secs() + 7200;
    storage
        .update_oauth_tokens(
            "provider",
            "new-access".to_string(),
            Some("new-refresh".to_string()),
            new_expiry,
        )
        .expect("updating an existing OAuth credential should succeed");

    assert_eq!(storage.get_api_key("provider").as_deref(), Some("new-access"));

    // The refresh token and expiry must be rotated as well, not just the access token.
    let Some(AuthCredential::OAuth {
        refresh_token,
        expires_at,
        ..
    }) = storage.get_oauth_credential("provider")
    else {
        panic!("Expected OAuth credential");
    };
    assert_eq!(refresh_token.as_deref(), Some("new-refresh"));
    assert_eq!(expires_at, new_expiry);
}
1551
/// Updating OAuth tokens for a provider that holds an API key is rejected.
#[test]
fn test_update_oauth_tokens_wrong_type() {
    let store = AuthStorage::in_memory();
    store.set_api_key("provider", "key".to_string());

    let expiry = now_secs() + 3600;
    let outcome = store.update_oauth_tokens("provider", "new-access".to_string(), None, expiry);
    assert!(outcome.is_err());
}
1565
/// `migrate_provider` moves a credential under a new provider name and drops the old entry.
#[test]
fn test_migrate_provider() {
    let (from, to) = ("old-provider", "new-provider");
    let store = AuthStorage::in_memory();
    store.set_api_key(from, "key123".to_string());
    store.migrate_provider(from, to).unwrap();

    assert!(!store.has(from));
    assert!(store.has(to));
    assert_eq!(store.get_api_key(to).as_deref(), Some("key123"));
}
1581
/// Migrating a provider that has no stored credential is an error.
#[test]
fn test_migrate_provider_not_found() {
    let store = AuthStorage::in_memory();
    let outcome = store.migrate_provider("nonexistent", "target");
    assert!(outcome.is_err());
}
1588
/// A freshly created in-memory store has no recorded errors to drain.
#[test]
fn test_error_draining() {
    let store = AuthStorage::in_memory();
    let drained = store.drain_errors();
    assert!(drained.is_empty());
}
1595
/// A registered fallback resolver supplies keys for providers that have no stored
/// credential; once it is cleared, those lookups return nothing.
#[test]
fn test_fallback_resolver() {
    let store = AuthStorage::in_memory();
    let resolver = FnFallbackResolver::new(Box::new(|provider| {
        (provider == "custom").then(|| "custom-key-from-config".to_string())
    }));
    store.set_fallback_resolver(Arc::new(resolver));

    assert_eq!(
        store.get_api_key("custom").as_deref(),
        Some("custom-key-from-config")
    );
    assert_eq!(store.get_api_key("unknown"), None);

    store.clear_fallback_resolver();
    assert_eq!(store.get_api_key("custom"), None);
}
1617
/// `get_api_key_with_options` returns the fallback key when its flag is `true` and
/// nothing when it is `false` (no credential is stored for "test").
#[test]
fn test_get_api_key_with_options() {
    let store = AuthStorage::in_memory();
    let always_fallback =
        FnFallbackResolver::new(Box::new(|_| Some("fallback-key".to_string())));
    store.set_fallback_resolver(Arc::new(always_fallback));

    let with_fallback = store.get_api_key_with_options("test", true);
    assert_eq!(with_fallback.as_deref(), Some("fallback-key"));

    let without_fallback = store.get_api_key_with_options("test", false);
    assert!(without_fallback.is_none());
}
1634
/// `configured_providers` includes every stored provider and returns the list sorted.
#[test]
fn test_configured_providers() {
    let storage = AuthStorage::in_memory();
    storage.set_api_key("openai", "key".to_string());
    storage.set_api_key("anthropic", "key".to_string());

    let providers = storage.configured_providers();
    // The list may contain providers beyond the ones stored here (NOTE(review):
    // presumably from other sources such as the environment — confirm). So check
    // membership rather than an exact length; a bare `len() >= 2` would pass even
    // if neither stored provider were reported.
    for expected in ["openai", "anthropic"] {
        assert!(
            providers.iter().any(|p| *p == expected),
            "missing provider {expected}: {providers:?}"
        );
    }
    // Sortedness check without cloning: every adjacent pair is non-decreasing.
    assert!(
        providers.windows(2).all(|w| w[0] <= w[1]),
        "providers not sorted: {providers:?}"
    );
}
1648
/// `has_multiple_providers` is false for zero or one stored provider and true once a
/// second one is added.
#[test]
fn test_has_multiple_providers() {
    let store = AuthStorage::in_memory();
    assert!(!store.has_multiple_providers());

    for (provider, key, expect_multiple) in [
        ("openai", "key1", false),
        ("anthropic", "key2", true),
    ] {
        store.set_api_key(provider, key.to_string());
        assert_eq!(store.has_multiple_providers(), expect_multiple);
    }
}
1660
/// A credential stored with the generic `set` is returned by `get` as the same variant.
#[test]
fn test_set_and_get_credential() {
    let store = AuthStorage::in_memory();
    store.set(
        "custom",
        AuthCredential::Session {
            token: "abc".to_string(),
            expires_at: 0,
            metadata: None,
        },
    );
    assert!(matches!(
        store.get("custom"),
        Some(AuthCredential::Session { .. })
    ));
}
1674
/// `type_name` returns a fixed identifier for each credential variant.
#[test]
fn test_credential_type_name() {
    let cases = [
        (
            AuthCredential::ApiKey {
                key: "k".to_string(),
            },
            "api_key",
        ),
        (
            AuthCredential::OAuth {
                access_token: "t".to_string(),
                refresh_token: None,
                expires_at: 0,
                scopes: None,
                provider_data: None,
            },
            "oauth",
        ),
        (
            AuthCredential::Session {
                token: "t".to_string(),
                expires_at: 0,
                metadata: None,
            },
            "session",
        ),
    ];
    for (credential, expected) in cases {
        assert_eq!(credential.type_name(), expected);
    }
}
1705}