1use parking_lot::RwLock;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::path::PathBuf;
11use std::sync::{Arc, OnceLock};
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
19#[serde(tag = "type", rename_all = "snake_case")]
20pub enum AuthCredential {
21 ApiKey {
23 key: String,
25 },
26 OAuth {
28 access_token: String,
30 refresh_token: Option<String>,
32 expires_at: u64,
34 #[serde(default)]
36 scopes: Option<String>,
37 #[serde(default)]
39 provider_data: Option<serde_json::Value>,
40 },
41 Session {
43 token: String,
45 #[serde(default)]
47 expires_at: u64,
48 #[serde(default)]
50 metadata: Option<serde_json::Value>,
51 },
52}
53
54impl AuthCredential {
55 pub fn is_expired(&self) -> bool {
57 match self {
58 AuthCredential::OAuth { expires_at, .. } => {
59 let now = now_secs();
60 *expires_at < now
61 }
62 AuthCredential::Session { expires_at, .. } => {
63 if *expires_at == 0 {
64 return false; }
66 *expires_at <= now_secs()
67 }
68 AuthCredential::ApiKey { .. } => false,
69 }
70 }
71
72 pub fn needs_refresh(&self) -> bool {
74 match self {
75 AuthCredential::OAuth {
76 expires_at,
77 refresh_token,
78 ..
79 } => {
80 let now = now_secs();
81 refresh_token.is_some() && *expires_at <= now + 60
82 }
83 AuthCredential::Session { .. } => false,
84 AuthCredential::ApiKey { .. } => false,
85 }
86 }
87
88 pub fn access_token(&self) -> Option<&str> {
90 match self {
91 AuthCredential::OAuth { access_token, .. } if !self.is_expired() => Some(access_token),
92 AuthCredential::Session { token, .. } if !self.is_expired() => Some(token),
93 _ => None,
94 }
95 }
96
97 pub fn type_name(&self) -> &'static str {
99 match self {
100 AuthCredential::ApiKey { .. } => "api_key",
101 AuthCredential::OAuth { .. } => "oauth",
102 AuthCredential::Session { .. } => "session",
103 }
104 }
105
106 pub fn validate(&self) -> Result<(), CredentialValidationError> {
108 match self {
109 AuthCredential::ApiKey { key } => {
110 if key.is_empty() {
111 return Err(CredentialValidationError::EmptyField("key".to_string()));
112 }
113 if key == "your-api-key-here" || key == "xxx" {
115 return Err(CredentialValidationError::PlaceholderValue(key.clone()));
116 }
117 Ok(())
118 }
119 AuthCredential::OAuth {
120 access_token,
121 expires_at,
122 ..
123 } => {
124 if access_token.is_empty() {
125 return Err(CredentialValidationError::EmptyField(
126 "access_token".to_string(),
127 ));
128 }
129 if *expires_at == 0 {
130 return Err(CredentialValidationError::InvalidExpiry);
131 }
132 Ok(())
133 }
134 AuthCredential::Session { token, .. } => {
135 if token.is_empty() {
136 return Err(CredentialValidationError::EmptyField("token".to_string()));
137 }
138 Ok(())
139 }
140 }
141 }
142}
143
/// Reasons `AuthCredential::validate` rejects a credential.
#[derive(Debug, Clone, thiserror::Error)]
pub enum CredentialValidationError {
    /// A required field is empty; carries the field name. `AuthStorage::validate` also
    /// uses this variant to describe a provider with no stored credential.
    #[error("Field '{0}' must not be empty")]
    EmptyField(String),
    /// The API key matches a known placeholder; carries the offending value.
    #[error("Placeholder value detected: '{0}'")]
    PlaceholderValue(String),
    /// An OAuth credential has an `expires_at` of `0`.
    #[error("Invalid expiry timestamp")]
    InvalidExpiry,
}
157
/// Summary of where a provider's credential comes from, as reported by
/// `AuthStorage::get_status`.
#[derive(Debug, Clone)]
pub struct AuthStatus {
    /// `true` only when a credential is stored in `AuthStorage`'s credential map.
    /// Runtime keys and fallback-resolved keys report `false` even though a key
    /// is available.
    pub configured: bool,
    /// Origin of the credential (`get_status` uses "runtime", "stored" or
    /// "fallback"), or `None` when nothing was found.
    pub source: Option<String>,
    /// Human-readable detail, e.g. the stored credential's type name.
    pub label: Option<String>,
}
172
173impl std::fmt::Display for AuthStatus {
174 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
175 match (&self.source, &self.label) {
176 (Some(source), Some(label)) => write!(f, "{} ({})", source, label),
177 (Some(source), None) => write!(f, "{}", source),
178 (None, Some(label)) => write!(f, "{}", label),
179 (None, None) => write!(f, "not configured"),
180 }
181 }
182}
183
/// Result alias used throughout the auth storage APIs.
pub type AuthResult<T> = Result<T, AuthError>;
190
/// Errors produced by auth storage backends and `AuthStorage` operations.
///
/// Payloads are plain strings, which keeps the type `Clone` so errors can be both
/// stored and handed out.
#[derive(Debug, Clone, thiserror::Error)]
pub enum AuthError {
    /// Reading the backing storage failed; carries the underlying error message.
    #[error("Failed to read auth storage: {0}")]
    ReadError(String),
    /// Writing to or deleting the backing storage failed; carries the underlying message.
    #[error("Failed to write auth storage: {0}")]
    WriteError(String),
    /// No credential exists under the given provider name.
    #[error("Credential not found: {0}")]
    NotFound(String),
    /// A credential was not of the expected kind or format.
    #[error("Invalid credential format: {0}")]
    InvalidFormat(String),
    /// An OS keyring operation failed, or keyring support is not compiled in.
    #[error("Keyring error: {0}")]
    KeyringError(String),
    /// A credential failed validation.
    #[error("Credential validation failed: {0}")]
    ValidationFailed(String),
}
213
/// Raw persistence layer behind `AuthStorage`.
///
/// A backend stores one opaque string. `AuthStorage` writes its provider → credential
/// map as pretty-printed JSON. Implementations must be `Send + Sync` so a single store
/// can be shared across threads.
pub trait AuthStorageBackend: Send + Sync {
    /// Returns the stored data, or `Ok(None)` when nothing has been stored.
    fn read(&self) -> AuthResult<Option<String>>;
    /// Stores `data`, replacing any previous contents.
    fn write(&self, data: &str) -> AuthResult<()>;
    /// Removes any stored data.
    fn delete(&self) -> AuthResult<()>;
}
227
/// `AuthStorageBackend` that keeps the serialized credentials in a file on disk.
///
/// By default the file is `~/.oxi/auth.json` (see `FileAuthStorage::default_path`).
pub struct FileAuthStorage {
    /// Location of the auth file.
    path: PathBuf,
    /// Copy of the most recently read or written file contents; cleared on delete.
    cache: RwLock<Option<String>>,
}
237
238impl FileAuthStorage {
239 pub fn new(path: PathBuf) -> Self {
241 Self {
242 path,
243 cache: RwLock::new(None),
244 }
245 }
246
247 pub fn default_path() -> Option<PathBuf> {
249 dirs::home_dir().map(|p| p.join(".oxi").join("auth.json"))
250 }
251
252 pub fn path(&self) -> &PathBuf {
254 &self.path
255 }
256}
257
258impl AuthStorageBackend for FileAuthStorage {
259 fn read(&self) -> AuthResult<Option<String>> {
260 if !self.path.exists() {
261 return Ok(None);
262 }
263
264 match std::fs::read_to_string(&self.path) {
265 Ok(content) => {
266 *self.cache.write() = Some(content.clone());
267 Ok(Some(content))
268 }
269 Err(e) => Err(AuthError::ReadError(e.to_string())),
270 }
271 }
272
273 fn write(&self, data: &str) -> AuthResult<()> {
274 if let Some(parent) = self.path.parent() {
276 std::fs::create_dir_all(parent).map_err(|e| AuthError::WriteError(e.to_string()))?;
277
278 #[cfg(unix)]
279 {
280 use std::os::unix::fs::PermissionsExt;
281 let perms = std::fs::Permissions::from_mode(0o700);
282 let _ = std::fs::set_permissions(parent, perms);
283 }
284 }
285
286 std::fs::write(&self.path, data).map_err(|e| AuthError::WriteError(e.to_string()))?;
288
289 #[cfg(unix)]
291 {
292 use std::os::unix::fs::PermissionsExt;
293 let perms = std::fs::Permissions::from_mode(0o600);
294 std::fs::set_permissions(&self.path, perms)
295 .map_err(|e| AuthError::WriteError(e.to_string()))?;
296 }
297
298 *self.cache.write() = Some(data.to_string());
299 Ok(())
300 }
301
302 fn delete(&self) -> AuthResult<()> {
303 if self.path.exists() {
304 std::fs::remove_file(&self.path).map_err(|e| AuthError::WriteError(e.to_string()))?;
305 }
306 *self.cache.write() = None;
307 Ok(())
308 }
309}
310
/// `AuthStorageBackend` that lives entirely in memory and never touches the filesystem.
pub struct MemoryAuthStorage {
    /// Credential map held by this backend (presumably provider name → credential,
    /// mirroring `AuthStorage`; TODO confirm).
    data: RwLock<HashMap<String, AuthCredential>>,
}
319
320impl MemoryAuthStorage {
321 pub fn new() -> Self {
323 Self {
324 data: RwLock::new(HashMap::new()),
325 }
326 }
327}
328
329impl Default for MemoryAuthStorage {
330 fn default() -> Self {
331 Self::new()
332 }
333}
334
335impl AuthStorageBackend for MemoryAuthStorage {
336 fn read(&self) -> AuthResult<Option<String>> {
337 Ok(None)
339 }
340
341 fn write(&self, _data: &str) -> AuthResult<()> {
342 Ok(())
343 }
344
345 fn delete(&self) -> AuthResult<()> {
346 self.data.write().clear();
347 Ok(())
348 }
349}
350
/// Last-resort source of API keys, consulted by `AuthStorage` after runtime keys and
/// stored credentials.
pub trait FallbackResolver: Send + Sync {
    /// Returns the API key for `provider`, or `None` if this resolver has none.
    fn resolve(&self, provider: &str) -> Option<String>;
}
360
/// `FallbackResolver` backed by an arbitrary closure.
pub struct FnFallbackResolver {
    // The boxed-closure type trips clippy's complexity lint; it is spelled out
    // inline because it appears only here and in `FnFallbackResolver::new`.
    #[allow(clippy::type_complexity)]
    f: Box<dyn Fn(&str) -> Option<String> + Send + Sync>,
}
366
367impl FnFallbackResolver {
368 #[allow(clippy::type_complexity)]
370 pub fn new(f: Box<dyn Fn(&str) -> Option<String> + Send + Sync>) -> Self {
371 Self { f }
372 }
373}
374
375impl FallbackResolver for FnFallbackResolver {
376 fn resolve(&self, provider: &str) -> Option<String> {
377 (self.f)(provider)
378 }
379}
380
/// `FallbackResolver` that reads API keys from the environment variables declared
/// for `oxi_ai` built-in providers.
pub struct EnvVarFallbackResolver;
386
387impl FallbackResolver for EnvVarFallbackResolver {
388 fn resolve(&self, provider: &str) -> Option<String> {
389 let builtin = oxi_ai::get_builtin_provider(provider)?;
391 let key = builtin.env_key;
392
393 if let Ok(val) = std::env::var(key)
395 && !val.is_empty()
396 {
397 return Some(val);
398 }
399
400 for extra in builtin.extra_env_keys {
402 if let Ok(val) = std::env::var(extra)
403 && !val.is_empty()
404 {
405 return Some(val);
406 }
407 }
408
409 None
410 }
411}
412
/// Central credential store.
///
/// API-key lookups consult, in order:
/// 1. runtime overrides;
/// 2. stored credentials, including those of built-in providers sharing the same env key;
/// 3. an optional fallback resolver.
///
/// Changes to stored credentials are written through to `file_storage` when one is
/// configured. Every mutable field sits behind a lock, so one instance can be shared
/// (see `shared_auth_storage`).
pub struct AuthStorage {
    /// Persistence backend; `None` keeps credentials in memory only.
    file_storage: Option<Arc<dyn AuthStorageBackend>>,
    /// Stored credentials keyed by provider name.
    credentials: RwLock<HashMap<String, AuthCredential>>,
    /// Process-local API keys keyed by provider. They take precedence over stored
    /// credentials and are never persisted.
    runtime_overrides: RwLock<HashMap<String, String>>,
    /// Optional resolver consulted after runtime overrides and stored credentials.
    fallback_resolver: RwLock<Option<Arc<dyn FallbackResolver>>>,
    /// Storage errors accumulated until `drain_errors` is called.
    errors: RwLock<Vec<AuthError>>,
    /// Last load failure, reported via `load_error()`; cleared by a successful `reload()`.
    load_error: RwLock<Option<AuthError>>,
    /// Ensures the plaintext-storage warning is logged at most once.
    plaintext_warned: OnceLock<()>,
}
442
443impl AuthStorage {
444 pub fn new() -> Self {
446 let file_storage = FileAuthStorage::default_path()
447 .map(|p| Arc::new(FileAuthStorage::new(p)) as Arc<dyn AuthStorageBackend>);
448
449 let credentials = if let Some(ref storage) = file_storage {
450 match storage.read() {
451 Ok(Some(content)) => serde_json::from_str(&content).unwrap_or_default(),
452 _ => HashMap::new(),
453 }
454 } else {
455 HashMap::new()
456 };
457
458 Self {
459 file_storage,
460 credentials: RwLock::new(credentials),
461 runtime_overrides: RwLock::new(HashMap::new()),
462 fallback_resolver: RwLock::new(None),
463 errors: RwLock::new(Vec::new()),
464 load_error: RwLock::new(None),
465 plaintext_warned: OnceLock::new(),
466 }
467 }
468
469 pub fn with_backend(backend: impl AuthStorageBackend + 'static) -> Self {
471 let credentials = match backend.read() {
472 Ok(Some(content)) => serde_json::from_str(&content).unwrap_or_default(),
473 _ => HashMap::new(),
474 };
475
476 Self {
477 file_storage: Some(Arc::new(backend)),
478 credentials: RwLock::new(credentials),
479 runtime_overrides: RwLock::new(HashMap::new()),
480 fallback_resolver: RwLock::new(None),
481 errors: RwLock::new(Vec::new()),
482 load_error: RwLock::new(None),
483 plaintext_warned: OnceLock::new(),
484 }
485 }
486
487 pub fn in_memory() -> Self {
489 Self {
490 file_storage: None,
491 credentials: RwLock::new(HashMap::new()),
492 runtime_overrides: RwLock::new(HashMap::new()),
493 fallback_resolver: RwLock::new(None),
494 errors: RwLock::new(Vec::new()),
495 load_error: RwLock::new(None),
496 plaintext_warned: OnceLock::new(),
497 }
498 }
499
500 pub fn default_path() -> Option<PathBuf> {
502 FileAuthStorage::default_path()
503 }
504
505 pub fn set_runtime_key(&self, provider: &str, api_key: String) {
511 self.runtime_overrides
512 .write()
513 .insert(provider.to_string(), api_key);
514 }
515
516 pub fn remove_runtime_key(&self, provider: &str) {
518 self.runtime_overrides.write().remove(provider);
519 }
520
521 pub fn set_fallback_resolver(&self, resolver: Arc<dyn FallbackResolver>) {
528 *self.fallback_resolver.write() = Some(resolver);
529 }
530
531 pub fn clear_fallback_resolver(&self) {
533 *self.fallback_resolver.write() = None;
534 }
535
536 pub fn has_auth(&self, provider: &str) -> bool {
542 if self.runtime_overrides.read().contains_key(provider) {
543 return true;
544 }
545 if self.credentials.read().contains_key(provider) {
546 return true;
547 }
548 if let Some(ref resolver) = *self.fallback_resolver.read()
549 && resolver.resolve(provider).is_some()
550 {
551 return true;
552 }
553 false
554 }
555
556 pub fn get_status(&self, provider: &str) -> AuthStatus {
558 if self.runtime_overrides.read().contains_key(provider) {
559 return AuthStatus {
560 configured: false,
561 source: Some("runtime".to_string()),
562 label: Some("--api-key".to_string()),
563 };
564 }
565
566 if let Some(cred) = self.credentials.read().get(provider) {
567 return AuthStatus {
568 configured: true,
569 source: Some("stored".to_string()),
570 label: Some(cred.type_name().to_string()),
571 };
572 }
573
574 if let Some(ref resolver) = *self.fallback_resolver.read()
575 && resolver.resolve(provider).is_some()
576 {
577 return AuthStatus {
578 configured: false,
579 source: Some("fallback".to_string()),
580 label: Some("custom provider config".to_string()),
581 };
582 }
583
584 AuthStatus {
585 configured: false,
586 source: None,
587 label: None,
588 }
589 }
590
591 pub fn get_api_key(&self, provider: &str) -> Option<String> {
600 self.get_api_key_with_options(provider, true)
601 }
602
603 pub fn get_api_key_with_options(
605 &self,
606 provider: &str,
607 include_fallback: bool,
608 ) -> Option<String> {
609 if let Some(key) = self.runtime_overrides.read().get(provider) {
611 return Some(key.clone());
612 }
613
614 if let Some(cred) = self.credentials.read().get(provider) {
616 return match cred {
617 AuthCredential::ApiKey { key } => Some(key.clone()),
618 AuthCredential::OAuth {
619 access_token,
620 expires_at,
621 ..
622 } => {
623 if *expires_at > now_secs() {
624 Some(access_token.clone())
625 } else {
626 None
628 }
629 }
630 AuthCredential::Session {
631 token, expires_at, ..
632 } => {
633 if *expires_at == 0 || *expires_at > now_secs() {
634 Some(token.clone())
635 } else {
636 None
637 }
638 }
639 };
640 }
641
642 if let Some(builtin) = oxi_ai::register_builtins::get_builtin_provider(provider) {
647 let env_key = builtin.env_key;
648 let credentials = self.credentials.read();
649 for other in oxi_ai::register_builtins::get_builtin_providers() {
650 if other.name == provider {
651 continue; }
653 if other.env_key == env_key
654 && let Some(cred) = credentials.get(other.name)
655 {
656 return match cred {
657 AuthCredential::ApiKey { key } => Some(key.clone()),
658 AuthCredential::OAuth {
659 access_token,
660 expires_at,
661 ..
662 } => {
663 if *expires_at > now_secs() {
664 Some(access_token.clone())
665 } else {
666 None
667 }
668 }
669 AuthCredential::Session {
670 token, expires_at, ..
671 } => {
672 if *expires_at == 0 || *expires_at > now_secs() {
673 Some(token.clone())
674 } else {
675 None
676 }
677 }
678 };
679 }
680 }
681 }
682
683 if include_fallback && let Some(ref resolver) = *self.fallback_resolver.read() {
685 return resolver.resolve(provider);
686 }
687
688 None
689 }
690
691 pub fn set_api_key(&self, provider: &str, key: String) {
697 self.credentials
698 .write()
699 .insert(provider.to_string(), AuthCredential::ApiKey { key });
700 if let Err(e) = self.persist() {
701 tracing::warn!("Failed to persist API key for '{}': {}", provider, e);
702 }
703 }
704
705 pub fn set_oauth(
707 &self,
708 provider: &str,
709 access_token: String,
710 refresh_token: Option<String>,
711 expires_at: u64,
712 ) {
713 self.set_oauth_full(
714 provider,
715 access_token,
716 refresh_token,
717 expires_at,
718 None,
719 None,
720 );
721 }
722
723 pub fn set_oauth_full(
725 &self,
726 provider: &str,
727 access_token: String,
728 refresh_token: Option<String>,
729 expires_at: u64,
730 scopes: Option<String>,
731 provider_data: Option<serde_json::Value>,
732 ) {
733 self.credentials.write().insert(
734 provider.to_string(),
735 AuthCredential::OAuth {
736 access_token,
737 refresh_token,
738 expires_at,
739 scopes,
740 provider_data,
741 },
742 );
743 if let Err(e) = self.persist() {
744 tracing::warn!("Failed to persist OAuth token for '{}': {}", provider, e);
745 }
746 }
747
748 pub fn set_session(
750 &self,
751 provider: &str,
752 token: String,
753 expires_at: u64,
754 metadata: Option<serde_json::Value>,
755 ) {
756 self.credentials.write().insert(
757 provider.to_string(),
758 AuthCredential::Session {
759 token,
760 expires_at,
761 metadata,
762 },
763 );
764 if let Err(e) = self.persist() {
765 tracing::warn!("Failed to persist session for '{}': {}", provider, e);
766 }
767 }
768
769 pub fn update_oauth_tokens(
771 &self,
772 provider: &str,
773 new_access_token: String,
774 new_refresh_token: Option<String>,
775 new_expires_at: u64,
776 ) -> AuthResult<()> {
777 let mut creds = self.credentials.write();
778 let cred = creds
779 .get_mut(provider)
780 .ok_or_else(|| AuthError::NotFound(provider.to_string()))?;
781
782 match cred {
783 AuthCredential::OAuth {
784 access_token,
785 refresh_token,
786 expires_at,
787 ..
788 } => {
789 *access_token = new_access_token;
790 *refresh_token = new_refresh_token;
791 *expires_at = new_expires_at;
792 }
793 _ => {
794 return Err(AuthError::InvalidFormat(format!(
795 "Provider '{}' does not have OAuth credentials",
796 provider
797 )));
798 }
799 }
800
801 drop(creds);
802 if let Err(e) = self.persist() {
803 tracing::warn!(
804 "Failed to persist OAuth token update for '{}': {}",
805 provider,
806 e
807 );
808 }
809 Ok(())
810 }
811
812 pub fn get(&self, provider: &str) -> Option<AuthCredential> {
818 self.credentials.read().get(provider).cloned()
819 }
820
821 pub fn get_oauth_credential(&self, provider: &str) -> Option<AuthCredential> {
823 self.credentials.read().get(provider).cloned()
824 }
825
826 pub fn has_oauth_with_refresh(&self, provider: &str) -> bool {
828 if let Some(cred) = self.credentials.read().get(provider) {
829 matches!(
830 cred,
831 AuthCredential::OAuth {
832 refresh_token: Some(_),
833 ..
834 }
835 )
836 } else {
837 false
838 }
839 }
840
841 pub fn set(&self, provider: &str, credential: AuthCredential) {
847 self.credentials
848 .write()
849 .insert(provider.to_string(), credential);
850 if let Err(e) = self.persist() {
851 tracing::warn!("Failed to persist credential for '{}': {}", provider, e);
852 }
853 }
854
855 pub fn remove(&self, provider: &str) {
857 self.credentials.write().remove(provider);
858 if let Err(e) = self.persist() {
859 tracing::warn!("Failed to persist after removing '{}': {}", provider, e);
860 }
861 }
862
863 pub fn list_providers(&self) -> Vec<String> {
865 self.credentials.read().keys().cloned().collect()
866 }
867
868 pub fn has(&self, provider: &str) -> bool {
870 self.credentials.read().contains_key(provider)
871 }
872
873 pub fn get_all(&self) -> HashMap<String, AuthCredential> {
875 self.credentials.read().clone()
876 }
877
878 pub fn clear(&self) {
880 self.credentials.write().clear();
881 if let Err(e) = self.persist() {
882 tracing::warn!("Failed to persist after clearing credentials: {}", e);
883 }
884 }
885
886 pub fn reload(&self) {
892 if let Some(ref storage) = self.file_storage {
893 match storage.read() {
894 Ok(Some(content)) => {
895 if let Ok(creds) = serde_json::from_str(&content) {
896 *self.credentials.write() = creds;
897 }
898 *self.load_error.write() = None;
899 }
900 Ok(None) => {
901 self.credentials.write().clear();
902 *self.load_error.write() = None;
903 }
904 Err(e) => {
905 *self.load_error.write() = Some(e);
906 self.record_error(AuthError::ReadError(
907 "Failed to reload auth storage".to_string(),
908 ));
909 }
910 }
911 }
912 }
913
914 #[allow(unexpected_cfgs)]
916 fn persist(&self) -> Result<(), String> {
917 if let Some(ref storage) = self.file_storage {
918 let creds = self.credentials.read();
919 if let Ok(json) = serde_json::to_string_pretty(&*creds) {
920 #[cfg(not(feature = "keyring"))]
922 {
923 self.plaintext_warned.get_or_init(|| {
924 tracing::warn!(
925 "Auth credentials are stored in plaintext. \
926 Enable the 'keyring' feature for secure OS-level storage."
927 );
928 });
929 }
930
931 if let Err(e) = storage.write(&json) {
932 tracing::error!("Failed to persist auth storage: {}", e);
933 self.record_error(e);
934 return Err("persist failed".to_string());
935 }
936 }
937 }
938 Ok(())
939 }
940
941 fn record_error(&self, error: AuthError) {
947 self.errors.write().push(error);
948 }
949
950 pub fn drain_errors(&self) -> Vec<AuthError> {
952 let mut errors = self.errors.write();
953 std::mem::take(&mut *errors)
954 }
955
956 pub fn load_error(&self) -> Option<AuthError> {
958 self.load_error.read().clone()
959 }
960
961 pub fn validate_all(&self) -> Vec<(String, CredentialValidationError)> {
967 let creds = self.credentials.read();
968 let mut results = Vec::new();
969 for (provider, cred) in creds.iter() {
970 if let Err(e) = cred.validate() {
971 results.push((provider.clone(), e));
972 }
973 }
974 results
975 }
976
977 pub fn validate(&self, provider: &str) -> Result<(), CredentialValidationError> {
979 let creds = self.credentials.read();
980 let cred = creds.get(provider).ok_or_else(|| {
981 CredentialValidationError::EmptyField(format!(
982 "no credential for provider '{}'",
983 provider
984 ))
985 })?;
986 cred.validate()
987 }
988
989 pub fn configured_providers(&self) -> Vec<String> {
995 let mut providers: Vec<String> = self.credentials.read().keys().cloned().collect();
996 providers.sort();
997 providers
998 }
999
1000 pub fn has_multiple_providers(&self) -> bool {
1002 self.credentials.read().len() > 1
1003 }
1004
1005 pub fn primary_provider(&self) -> Option<String> {
1007 let creds = self.credentials.read();
1008 creds.keys().next().cloned()
1009 }
1010
1011 pub fn migrate_provider(&self, from: &str, to: &str) -> AuthResult<()> {
1013 let mut creds = self.credentials.write();
1014 let cred = creds
1015 .remove(from)
1016 .ok_or_else(|| AuthError::NotFound(from.to_string()))?;
1017 creds.insert(to.to_string(), cred);
1018 drop(creds);
1019 let _ = self.persist();
1020 Ok(())
1021 }
1022}
1023
1024impl Default for AuthStorage {
1025 fn default() -> Self {
1026 Self::new()
1027 }
1028}
1029
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// Returns `0` if the system clock reports a time before the epoch.
fn now_secs() -> u64 {
    match std::time::UNIX_EPOCH.elapsed() {
        Ok(since_epoch) => since_epoch.as_secs(),
        Err(_) => 0,
    }
}
1040
1041#[allow(unexpected_cfgs)]
1047pub mod keyring_support {
1048 use super::*;
1049
1050 #[cfg(feature = "keyring")]
1052 pub fn get_keyring_secret(service: &str, account: &str) -> Option<String> {
1053 use keyring::Entry;
1054 Entry::new(service, account)
1055 .ok()
1056 .and_then(|entry| entry.get_password().ok())
1057 }
1058
1059 #[cfg(feature = "keyring")]
1061 pub fn set_keyring_secret(service: &str, account: &str, secret: &str) -> AuthResult<()> {
1062 use keyring::Entry;
1063 Entry::new(service, account)
1064 .map_err(|e| AuthError::KeyringError(e.to_string()))?
1065 .set_password(secret)
1066 .map_err(|e| AuthError::KeyringError(e.to_string()))
1067 }
1068
1069 #[cfg(feature = "keyring")]
1071 pub fn delete_keyring_secret(service: &str, account: &str) -> AuthResult<()> {
1072 use keyring::Entry;
1073 Entry::new(service, account)
1074 .map_err(|e| AuthError::KeyringError(e.to_string()))?
1075 .delete_credential()
1076 .map_err(|e| AuthError::KeyringError(e.to_string()))
1077 }
1078
1079 #[cfg(not(feature = "keyring"))]
1081 pub fn get_keyring_secret(_service: &str, _account: &str) -> Option<String> {
1085 None
1086 }
1087
1088 #[cfg(not(feature = "keyring"))]
1089 pub fn set_keyring_secret(_service: &str, _account: &str, _secret: &str) -> AuthResult<()> {
1093 Err(AuthError::KeyringError(
1094 "Keyring support not compiled".to_string(),
1095 ))
1096 }
1097
1098 #[cfg(not(feature = "keyring"))]
1099 pub fn delete_keyring_secret(_service: &str, _account: &str) -> AuthResult<()> {
1103 Err(AuthError::KeyringError(
1104 "Keyring support not compiled".to_string(),
1105 ))
1106 }
1107}
1108
1109pub fn shared_auth_storage() -> Arc<AuthStorage> {
1119 static STORAGE: OnceLock<Arc<AuthStorage>> = OnceLock::new();
1120 STORAGE
1121 .get_or_init(|| {
1122 let storage = Arc::new(AuthStorage::new());
1123 storage.set_fallback_resolver(Arc::new(EnvVarFallbackResolver));
1126 storage
1127 })
1128 .clone()
1129}
1130
1131#[cfg(test)]
1136mod tests {
1137 use super::*;
1138
1139 #[test]
1140 fn test_auth_storage_new() {
1141 let storage = AuthStorage::in_memory();
1142 assert!(!storage.has("anthropic"));
1143 }
1144
1145 #[test]
1146 fn test_set_and_get_api_key() {
1147 let storage = AuthStorage::in_memory();
1148 storage.set_api_key("anthropic", "sk-test123".to_string());
1149 assert!(storage.has("anthropic"));
1150 assert_eq!(
1151 storage.get_api_key("anthropic"),
1152 Some("sk-test123".to_string())
1153 );
1154 }
1155
1156 #[test]
1157 fn test_runtime_override() {
1158 let storage = AuthStorage::in_memory();
1159 storage.set_api_key("anthropic", "stored-key".to_string());
1160 storage.set_runtime_key("anthropic", "runtime-key".to_string());
1161
1162 assert_eq!(
1164 storage.get_api_key("anthropic"),
1165 Some("runtime-key".to_string())
1166 );
1167 }
1168
1169 #[test]
1170 fn test_remove_credential() {
1171 let storage = AuthStorage::in_memory();
1172 storage.set_api_key("anthropic", "sk-test123".to_string());
1173 assert!(storage.has("anthropic"));
1174
1175 storage.remove("anthropic");
1176 assert!(!storage.has("anthropic"));
1177 }
1178
1179 #[test]
1180 fn test_auth_status() {
1181 let storage = AuthStorage::in_memory();
1182 storage.set_api_key("anthropic", "sk-test123".to_string());
1183
1184 let status = storage.get_status("anthropic");
1185 assert!(status.configured);
1186 assert_eq!(status.source, Some("stored".to_string()));
1187 assert_eq!(status.label, Some("api_key".to_string()));
1188 }
1189
1190 #[test]
1191 fn test_auth_status_display() {
1192 let status = AuthStatus {
1193 configured: true,
1194 source: Some("stored".to_string()),
1195 label: Some("api_key".to_string()),
1196 };
1197 let display = format!("{}", status);
1198 assert_eq!(display, "stored (api_key)");
1199
1200 let no_config = AuthStatus {
1201 configured: false,
1202 source: None,
1203 label: None,
1204 };
1205 assert_eq!(format!("{}", no_config), "not configured");
1206 }
1207
1208 #[test]
1209 fn test_list_providers() {
1210 let storage = AuthStorage::in_memory();
1211 storage.set_api_key("anthropic", "key1".to_string());
1212 storage.set_api_key("openai", "key2".to_string());
1213
1214 let providers = storage.list_providers();
1215 assert!(providers.contains(&"anthropic".to_string()));
1216 assert!(providers.contains(&"openai".to_string()));
1217 }
1218
1219 #[test]
1220 fn test_oauth_credential() {
1221 let storage = AuthStorage::in_memory();
1222 storage.set_oauth(
1223 "provider",
1224 "access123".to_string(),
1225 Some("refresh456".to_string()),
1226 u64::MAX,
1227 );
1228
1229 assert!(storage.has("provider"));
1230 assert_eq!(
1231 storage.get_api_key("provider"),
1232 Some("access123".to_string())
1233 );
1234 }
1235
1236 #[test]
1237 fn test_expired_oauth_token() {
1238 let storage = AuthStorage::in_memory();
1239 storage.set_oauth("provider", "access123".to_string(), None, 0);
1241
1242 let key = storage.get_api_key("provider");
1244 assert!(key.is_none());
1245 }
1246
1247 #[test]
1248 fn test_get_all_credentials() {
1249 let storage = AuthStorage::in_memory();
1250 storage.set_api_key("anthropic", "key1".to_string());
1251 storage.set_api_key("openai", "key2".to_string());
1252
1253 let all = storage.get_all();
1254 assert_eq!(all.len(), 2);
1255 }
1256
1257 #[test]
1258 fn test_clear() {
1259 let storage = AuthStorage::in_memory();
1260 storage.set_api_key("anthropic", "key".to_string());
1261 assert!(storage.has("anthropic"));
1262
1263 storage.clear();
1264 assert!(!storage.has("anthropic"));
1265 }
1266
1267 #[test]
1268 fn test_remove_runtime_key() {
1269 let storage = AuthStorage::in_memory();
1270 storage.set_api_key("anthropic", "stored".to_string());
1271 storage.set_runtime_key("anthropic", "runtime".to_string());
1272
1273 assert_eq!(
1274 storage.get_api_key("anthropic"),
1275 Some("runtime".to_string())
1276 );
1277
1278 storage.remove_runtime_key("anthropic");
1279 assert_eq!(storage.get_api_key("anthropic"), Some("stored".to_string()));
1280 }
1281
1282 #[test]
1283 fn test_auth_credential_is_expired() {
1284 let api_key_cred = AuthCredential::ApiKey {
1286 key: "test".to_string(),
1287 };
1288 assert!(!api_key_cred.is_expired());
1289
1290 let future_time = now_secs() + 3600;
1292 let oauth_cred = AuthCredential::OAuth {
1293 access_token: "token".to_string(),
1294 refresh_token: Some("refresh".to_string()),
1295 expires_at: future_time,
1296 scopes: None,
1297 provider_data: None,
1298 };
1299 assert!(!oauth_cred.is_expired());
1300
1301 let oauth_cred_expired = AuthCredential::OAuth {
1303 access_token: "token".to_string(),
1304 refresh_token: Some("refresh".to_string()),
1305 expires_at: 0,
1306 scopes: None,
1307 provider_data: None,
1308 };
1309 assert!(oauth_cred_expired.is_expired());
1310 }
1311
1312 #[test]
1313 fn test_auth_credential_needs_refresh() {
1314 let future_time = now_secs() + 120; let oauth_cred = AuthCredential::OAuth {
1318 access_token: "token".to_string(),
1319 refresh_token: Some("refresh".to_string()),
1320 expires_at: future_time,
1321 scopes: None,
1322 provider_data: None,
1323 };
1324 assert!(!oauth_cred.needs_refresh());
1325
1326 let soon = now_secs() + 30;
1328 let oauth_soon = AuthCredential::OAuth {
1329 access_token: "token".to_string(),
1330 refresh_token: Some("refresh".to_string()),
1331 expires_at: soon,
1332 scopes: None,
1333 provider_data: None,
1334 };
1335 assert!(oauth_soon.needs_refresh());
1336
1337 let no_refresh = AuthCredential::OAuth {
1339 access_token: "token".to_string(),
1340 refresh_token: None,
1341 expires_at: future_time,
1342 scopes: None,
1343 provider_data: None,
1344 };
1345 assert!(!no_refresh.needs_refresh());
1346
1347 let api_key_cred = AuthCredential::ApiKey {
1349 key: "test".to_string(),
1350 };
1351 assert!(!api_key_cred.needs_refresh());
1352 }
1353
1354 #[test]
1355 fn test_auth_credential_access_token() {
1356 let future_time = now_secs() + 3600;
1357
1358 let oauth_cred = AuthCredential::OAuth {
1359 access_token: "valid_token".to_string(),
1360 refresh_token: Some("refresh".to_string()),
1361 expires_at: future_time,
1362 scopes: None,
1363 provider_data: None,
1364 };
1365 assert_eq!(oauth_cred.access_token(), Some("valid_token"));
1366
1367 let expired_cred = AuthCredential::OAuth {
1369 access_token: "expired_token".to_string(),
1370 refresh_token: Some("refresh".to_string()),
1371 expires_at: 0,
1372 scopes: None,
1373 provider_data: None,
1374 };
1375 assert!(expired_cred.access_token().is_none());
1376
1377 let api_key_cred = AuthCredential::ApiKey {
1379 key: "api_key_token".to_string(),
1380 };
1381 assert!(api_key_cred.access_token().is_none());
1382 }
1383
1384 #[test]
1385 fn test_get_oauth_credential() {
1386 let storage = AuthStorage::in_memory();
1387 storage.set_oauth(
1388 "provider",
1389 "access".to_string(),
1390 Some("refresh".to_string()),
1391 u64::MAX,
1392 );
1393
1394 let cred = storage.get_oauth_credential("provider");
1395 assert!(cred.is_some());
1396 assert!(matches!(cred.unwrap(), AuthCredential::OAuth { .. }));
1397 }
1398
1399 #[test]
1400 fn test_has_oauth_with_refresh() {
1401 let storage = AuthStorage::in_memory();
1402
1403 storage.set_oauth(
1405 "with_refresh",
1406 "access".to_string(),
1407 Some("refresh".to_string()),
1408 u64::MAX,
1409 );
1410 assert!(storage.has_oauth_with_refresh("with_refresh"));
1411
1412 storage.set_oauth("without_refresh", "access".to_string(), None, u64::MAX);
1414 assert!(!storage.has_oauth_with_refresh("without_refresh"));
1415
1416 storage.set_api_key("apikey_provider", "key".to_string());
1418 assert!(!storage.has_oauth_with_refresh("apikey_provider"));
1419 }
1420
1421 #[test]
1422 fn test_set_oauth_full() {
1423 let storage = AuthStorage::in_memory();
1424 storage.set_oauth_full(
1425 "provider",
1426 "access_token".to_string(),
1427 Some("refresh_token".to_string()),
1428 3600,
1429 Some("read write".to_string()),
1430 Some(serde_json::json!({"extra": "data"})),
1431 );
1432
1433 let cred = storage.get_oauth_credential("provider");
1434 assert!(cred.is_some());
1435 if let AuthCredential::OAuth {
1436 scopes,
1437 provider_data,
1438 ..
1439 } = cred.unwrap()
1440 {
1441 assert_eq!(scopes, Some("read write".to_string()));
1442 assert!(provider_data.is_some());
1443 } else {
1444 panic!("Expected OAuth credential");
1445 }
1446 }
1447
1448 #[test]
1449 fn test_session_token() {
1450 let storage = AuthStorage::in_memory();
1451 storage.set_session(
1452 "browser",
1453 "session-token-123".to_string(),
1454 0, Some(serde_json::json!({"user": "test"})),
1456 );
1457
1458 assert!(storage.has("browser"));
1459 assert_eq!(
1460 storage.get_api_key("browser"),
1461 Some("session-token-123".to_string())
1462 );
1463
1464 let cred = storage.get("browser").unwrap();
1465 assert!(matches!(cred, AuthCredential::Session { .. }));
1466 assert!(cred.access_token().is_some());
1467 }
1468
1469 #[test]
1470 fn test_session_token_expired() {
1471 let storage = AuthStorage::in_memory();
1472 storage.set_session("browser", "session-token".to_string(), 1, None);
1473
1474 assert!(storage.get_api_key("browser").is_none());
1476 }
1477
1478 #[test]
1479 fn test_credential_validation() {
1480 let valid = AuthCredential::ApiKey {
1482 key: "sk-valid".to_string(),
1483 };
1484 assert!(valid.validate().is_ok());
1485
1486 let empty = AuthCredential::ApiKey {
1488 key: "".to_string(),
1489 };
1490 assert!(empty.validate().is_err());
1491
1492 let placeholder = AuthCredential::ApiKey {
1494 key: "your-api-key-here".to_string(),
1495 };
1496 assert!(placeholder.validate().is_err());
1497
1498 let valid_oauth = AuthCredential::OAuth {
1500 access_token: "token".to_string(),
1501 refresh_token: None,
1502 expires_at: now_secs() + 3600,
1503 scopes: None,
1504 provider_data: None,
1505 };
1506 assert!(valid_oauth.validate().is_ok());
1507
1508 let invalid_oauth = AuthCredential::OAuth {
1510 access_token: "".to_string(),
1511 refresh_token: None,
1512 expires_at: 1000,
1513 scopes: None,
1514 provider_data: None,
1515 };
1516 assert!(invalid_oauth.validate().is_err());
1517 }
1518
1519 #[test]
1520 fn test_validate_all() {
1521 let storage = AuthStorage::in_memory();
1522 storage.set_api_key("valid", "sk-good".to_string());
1523 storage.set_api_key("empty", "".to_string());
1524
1525 let errors = storage.validate_all();
1526 assert_eq!(errors.len(), 1);
1527 assert_eq!(errors[0].0, "empty");
1528 }
1529
1530 #[test]
1531 fn test_update_oauth_tokens() {
1532 let storage = AuthStorage::in_memory();
1533 storage.set_oauth(
1534 "provider",
1535 "old-access".to_string(),
1536 Some("old-refresh".to_string()),
1537 now_secs() + 3600,
1538 );
1539
1540 storage
1541 .update_oauth_tokens(
1542 "provider",
1543 "new-access".to_string(),
1544 Some("new-refresh".to_string()),
1545 now_secs() + 7200,
1546 )
1547 .unwrap();
1548
1549 let key = storage.get_api_key("provider");
1550 assert_eq!(key, Some("new-access".to_string()));
1551 }
1552
1553 #[test]
1554 fn test_update_oauth_tokens_wrong_type() {
1555 let storage = AuthStorage::in_memory();
1556 storage.set_api_key("provider", "key".to_string());
1557
1558 let result = storage.update_oauth_tokens(
1559 "provider",
1560 "new-access".to_string(),
1561 None,
1562 now_secs() + 3600,
1563 );
1564 assert!(result.is_err());
1565 }
1566
1567 #[test]
1568 fn test_migrate_provider() {
1569 let storage = AuthStorage::in_memory();
1570 storage.set_api_key("old-provider", "key123".to_string());
1571 storage
1572 .migrate_provider("old-provider", "new-provider")
1573 .unwrap();
1574
1575 assert!(!storage.has("old-provider"));
1576 assert!(storage.has("new-provider"));
1577 assert_eq!(
1578 storage.get_api_key("new-provider"),
1579 Some("key123".to_string())
1580 );
1581 }
1582
1583 #[test]
1584 fn test_migrate_provider_not_found() {
1585 let storage = AuthStorage::in_memory();
1586 let result = storage.migrate_provider("nonexistent", "target");
1587 assert!(result.is_err());
1588 }
1589
1590 #[test]
1591 fn test_error_draining() {
1592 let storage = AuthStorage::in_memory();
1593 let errors = storage.drain_errors();
1594 assert!(errors.is_empty());
1595 }
1596
1597 #[test]
1598 fn test_fallback_resolver() {
1599 let storage = AuthStorage::in_memory();
1600 storage.set_fallback_resolver(Arc::new(FnFallbackResolver::new(Box::new(|provider| {
1601 if provider == "custom" {
1602 Some("custom-key-from-config".to_string())
1603 } else {
1604 None
1605 }
1606 }))));
1607
1608 assert_eq!(
1609 storage.get_api_key("custom"),
1610 Some("custom-key-from-config".to_string())
1611 );
1612 assert!(storage.get_api_key("unknown").is_none());
1613
1614 storage.clear_fallback_resolver();
1616 assert!(storage.get_api_key("custom").is_none());
1617 }
1618
1619 #[test]
1620 fn test_get_api_key_with_options() {
1621 let storage = AuthStorage::in_memory();
1622 storage.set_fallback_resolver(Arc::new(FnFallbackResolver::new(Box::new(|_| {
1623 Some("fallback-key".to_string())
1624 }))));
1625
1626 assert_eq!(
1628 storage.get_api_key_with_options("test", true),
1629 Some("fallback-key".to_string())
1630 );
1631
1632 assert!(storage.get_api_key_with_options("test", false).is_none());
1634 }
1635
1636 #[test]
1637 fn test_configured_providers() {
1638 let storage = AuthStorage::in_memory();
1639 storage.set_api_key("openai", "key".to_string());
1640 storage.set_api_key("anthropic", "key".to_string());
1641
1642 let providers = storage.configured_providers();
1643 assert!(providers.len() >= 2);
1644 let mut sorted = providers.clone();
1646 sorted.sort();
1647 assert_eq!(providers, sorted);
1648 }
1649
1650 #[test]
1651 fn test_has_multiple_providers() {
1652 let storage = AuthStorage::in_memory();
1653 assert!(!storage.has_multiple_providers());
1654
1655 storage.set_api_key("openai", "key1".to_string());
1656 assert!(!storage.has_multiple_providers());
1657
1658 storage.set_api_key("anthropic", "key2".to_string());
1659 assert!(storage.has_multiple_providers());
1660 }
1661
1662 #[test]
1663 fn test_set_and_get_credential() {
1664 let storage = AuthStorage::in_memory();
1665 let cred = AuthCredential::Session {
1666 token: "abc".to_string(),
1667 expires_at: 0,
1668 metadata: None,
1669 };
1670 storage.set("custom", cred);
1671 let retrieved = storage.get("custom");
1672 assert!(retrieved.is_some());
1673 assert!(matches!(retrieved.unwrap(), AuthCredential::Session { .. }));
1674 }
1675
1676 #[test]
1677 fn test_credential_type_name() {
1678 assert_eq!(
1679 AuthCredential::ApiKey {
1680 key: "k".to_string()
1681 }
1682 .type_name(),
1683 "api_key"
1684 );
1685 assert_eq!(
1686 AuthCredential::OAuth {
1687 access_token: "t".to_string(),
1688 refresh_token: None,
1689 expires_at: 0,
1690 scopes: None,
1691 provider_data: None,
1692 }
1693 .type_name(),
1694 "oauth"
1695 );
1696 assert_eq!(
1697 AuthCredential::Session {
1698 token: "t".to_string(),
1699 expires_at: 0,
1700 metadata: None,
1701 }
1702 .type_name(),
1703 "session"
1704 );
1705 }
1706}