1use parking_lot::RwLock;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::path::PathBuf;
11use std::sync::{Arc, OnceLock};
12
/// A credential stored for a single provider.
///
/// Serialized as an internally tagged object: a `"type"` field selects the
/// variant, and variant names are rewritten with serde's `snake_case` rule.
///
/// NOTE(review): serde's `snake_case` rule renders the `OAuth` variant as
/// `"o_auth"`, whereas `type_name()` reports `"oauth"` — confirm which tag the
/// on-disk format is expected to carry.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum AuthCredential {
    /// A static API key. Never considered expired.
    ApiKey {
        /// The raw key value.
        key: String,
    },
    /// An OAuth token set.
    OAuth {
        /// Bearer token handed out while the credential is still valid.
        access_token: String,
        /// Token used to obtain a new access token, when one was issued.
        refresh_token: Option<String>,
        /// Expiry instant, compared against seconds since the Unix epoch.
        expires_at: u64,
        /// Granted scopes; absent in the serialized form deserializes to `None`.
        #[serde(default)]
        scopes: Option<String>,
        /// Free-form provider-specific JSON; defaults to `None` when absent.
        #[serde(default)]
        provider_data: Option<serde_json::Value>,
    },
    /// A session token.
    Session {
        /// The session token value.
        token: String,
        /// Expiry instant in seconds since the Unix epoch; `0` (the default
        /// when the field is absent) means the session never expires.
        #[serde(default)]
        expires_at: u64,
        /// Free-form JSON attached to the session; defaults to `None`.
        #[serde(default)]
        metadata: Option<serde_json::Value>,
    },
}
53
54impl AuthCredential {
55 pub fn is_expired(&self) -> bool {
57 match self {
58 AuthCredential::OAuth { expires_at, .. } => {
59 let now = now_secs();
60 *expires_at < now
61 }
62 AuthCredential::Session { expires_at, .. } => {
63 if *expires_at == 0 {
64 return false; }
66 *expires_at <= now_secs()
67 }
68 AuthCredential::ApiKey { .. } => false,
69 }
70 }
71
72 pub fn needs_refresh(&self) -> bool {
74 match self {
75 AuthCredential::OAuth {
76 expires_at,
77 refresh_token,
78 ..
79 } => {
80 let now = now_secs();
81 refresh_token.is_some() && *expires_at <= now + 60
82 }
83 AuthCredential::Session { .. } => false,
84 AuthCredential::ApiKey { .. } => false,
85 }
86 }
87
88 pub fn access_token(&self) -> Option<&str> {
90 match self {
91 AuthCredential::OAuth { access_token, .. } if !self.is_expired() => Some(access_token),
92 AuthCredential::Session { token, .. } if !self.is_expired() => Some(token),
93 _ => None,
94 }
95 }
96
97 pub fn type_name(&self) -> &'static str {
99 match self {
100 AuthCredential::ApiKey { .. } => "api_key",
101 AuthCredential::OAuth { .. } => "oauth",
102 AuthCredential::Session { .. } => "session",
103 }
104 }
105
106 pub fn validate(&self) -> Result<(), CredentialValidationError> {
108 match self {
109 AuthCredential::ApiKey { key } => {
110 if key.is_empty() {
111 return Err(CredentialValidationError::EmptyField("key".to_string()));
112 }
113 if key == "your-api-key-here" || key == "xxx" {
115 return Err(CredentialValidationError::PlaceholderValue(key.clone()));
116 }
117 Ok(())
118 }
119 AuthCredential::OAuth {
120 access_token,
121 expires_at,
122 ..
123 } => {
124 if access_token.is_empty() {
125 return Err(CredentialValidationError::EmptyField(
126 "access_token".to_string(),
127 ));
128 }
129 if *expires_at == 0 {
130 return Err(CredentialValidationError::InvalidExpiry);
131 }
132 Ok(())
133 }
134 AuthCredential::Session { token, .. } => {
135 if token.is_empty() {
136 return Err(CredentialValidationError::EmptyField("token".to_string()));
137 }
138 Ok(())
139 }
140 }
141 }
142}
143
/// Reasons a credential fails `validate`.
#[derive(Debug, Clone, thiserror::Error)]
pub enum CredentialValidationError {
    /// A required secret field is empty; carries the field name.
    #[error("Field '{0}' must not be empty")]
    EmptyField(String),
    /// The value is a known placeholder; carries the offending value.
    #[error("Placeholder value detected: '{0}'")]
    PlaceholderValue(String),
    /// The expiry timestamp is not usable (reported for an OAuth expiry of `0`).
    #[error("Invalid expiry timestamp")]
    InvalidExpiry,
}
157
/// Summary of where authentication for a provider comes from.
#[derive(Debug, Clone)]
pub struct AuthStatus {
    /// Whether a credential is configured; `AuthStorage::get_status` sets this
    /// to `true` only for credentials held in the stored map.
    pub configured: bool,
    /// Origin of the credential (e.g. `"runtime"`, `"stored"`, `"fallback"`).
    pub source: Option<String>,
    /// Extra human-readable detail about the source.
    pub label: Option<String>,
}
172
173impl std::fmt::Display for AuthStatus {
174 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
175 match (&self.source, &self.label) {
176 (Some(source), Some(label)) => write!(f, "{} ({})", source, label),
177 (Some(source), None) => write!(f, "{}", source),
178 (None, Some(label)) => write!(f, "{}", label),
179 (None, None) => write!(f, "not configured"),
180 }
181 }
182}
183
/// Result alias used by the storage backends and `AuthStorage`; the error is always [`AuthError`].
pub type AuthResult<T> = Result<T, AuthError>;
190
/// Errors produced while loading, saving or looking up credentials.
///
/// Every variant carries a human-readable detail string.
#[derive(Debug, Clone, thiserror::Error)]
pub enum AuthError {
    /// The backing storage could not be read.
    #[error("Failed to read auth storage: {0}")]
    ReadError(String),
    /// The backing storage could not be written or deleted.
    #[error("Failed to write auth storage: {0}")]
    WriteError(String),
    /// No credential exists for the named provider.
    #[error("Credential not found: {0}")]
    NotFound(String),
    /// A credential has an unexpected shape or type.
    #[error("Invalid credential format: {0}")]
    InvalidFormat(String),
    /// The OS keyring reported a failure, or keyring support is not compiled in.
    #[error("Keyring error: {0}")]
    KeyringError(String),
    /// A credential failed validation.
    #[error("Credential validation failed: {0}")]
    ValidationFailed(String),
}
213
/// Persistence layer for the serialized credential map.
///
/// Implementations must be `Send + Sync` because `AuthStorage` holds them
/// behind an `Arc`.
pub trait AuthStorageBackend: Send + Sync {
    /// Returns the stored payload, or `Ok(None)` when nothing is stored.
    fn read(&self) -> AuthResult<Option<String>>;
    /// Replaces the stored payload with `data`.
    fn write(&self, data: &str) -> AuthResult<()>;
    /// Removes the stored payload.
    fn delete(&self) -> AuthResult<()>;
}
227
/// Backend that keeps the credential payload in a single file.
pub struct FileAuthStorage {
    /// Location of the credential file.
    path: PathBuf,
    /// Copy of the payload last read or written. It is updated by
    /// `read`/`write`/`delete` but not read back anywhere in the visible code.
    cache: RwLock<Option<String>>,
}
237
238impl FileAuthStorage {
239 pub fn new(path: PathBuf) -> Self {
241 Self {
242 path,
243 cache: RwLock::new(None),
244 }
245 }
246
247 pub fn default_path() -> Option<PathBuf> {
249 dirs::home_dir().map(|p| p.join(".oxi").join("auth.json"))
250 }
251
252 pub fn path(&self) -> &PathBuf {
254 &self.path
255 }
256}
257
258impl AuthStorageBackend for FileAuthStorage {
259 fn read(&self) -> AuthResult<Option<String>> {
260 if !self.path.exists() {
261 return Ok(None);
262 }
263
264 match std::fs::read_to_string(&self.path) {
265 Ok(content) => {
266 *self.cache.write() = Some(content.clone());
267 Ok(Some(content))
268 }
269 Err(e) => Err(AuthError::ReadError(e.to_string())),
270 }
271 }
272
273 fn write(&self, data: &str) -> AuthResult<()> {
274 if let Some(parent) = self.path.parent() {
276 std::fs::create_dir_all(parent).map_err(|e| AuthError::WriteError(e.to_string()))?;
277
278 #[cfg(unix)]
279 {
280 use std::os::unix::fs::PermissionsExt;
281 let perms = std::fs::Permissions::from_mode(0o700);
282 let _ = std::fs::set_permissions(parent, perms);
283 }
284 }
285
286 std::fs::write(&self.path, data).map_err(|e| AuthError::WriteError(e.to_string()))?;
288
289 #[cfg(unix)]
291 {
292 use std::os::unix::fs::PermissionsExt;
293 let perms = std::fs::Permissions::from_mode(0o600);
294 std::fs::set_permissions(&self.path, perms)
295 .map_err(|e| AuthError::WriteError(e.to_string()))?;
296 }
297
298 *self.cache.write() = Some(data.to_string());
299 Ok(())
300 }
301
302 fn delete(&self) -> AuthResult<()> {
303 if self.path.exists() {
304 std::fs::remove_file(&self.path).map_err(|e| AuthError::WriteError(e.to_string()))?;
305 }
306 *self.cache.write() = None;
307 Ok(())
308 }
309}
310
/// Backend that never touches the filesystem.
///
/// NOTE(review): in the visible code `data` is only ever cleared — `write`
/// discards its input and `read` always returns `None` — so this behaves as a
/// null backend rather than a store; confirm that is intended.
pub struct MemoryAuthStorage {
    /// Credential map guarded by a lock.
    data: RwLock<HashMap<String, AuthCredential>>,
}
319
320impl MemoryAuthStorage {
321 pub fn new() -> Self {
323 Self {
324 data: RwLock::new(HashMap::new()),
325 }
326 }
327}
328
329impl Default for MemoryAuthStorage {
330 fn default() -> Self {
331 Self::new()
332 }
333}
334
335impl AuthStorageBackend for MemoryAuthStorage {
336 fn read(&self) -> AuthResult<Option<String>> {
337 Ok(None)
339 }
340
341 fn write(&self, _data: &str) -> AuthResult<()> {
342 Ok(())
343 }
344
345 fn delete(&self) -> AuthResult<()> {
346 self.data.write().clear();
347 Ok(())
348 }
349}
350
/// Last-resort source of a key for a provider, consulted by `AuthStorage`
/// after runtime overrides and stored credentials.
pub trait FallbackResolver: Send + Sync {
    /// Returns a key for `provider`, or `None` when this resolver has none.
    fn resolve(&self, provider: &str) -> Option<String>;
}
360
/// [`FallbackResolver`] backed by a boxed closure.
pub struct FnFallbackResolver {
    /// The closure invoked for every `resolve` call.
    #[allow(clippy::type_complexity)]
    f: Box<dyn Fn(&str) -> Option<String> + Send + Sync>,
}
366
367impl FnFallbackResolver {
368 #[allow(clippy::type_complexity)]
370 pub fn new(f: Box<dyn Fn(&str) -> Option<String> + Send + Sync>) -> Self {
371 Self { f }
372 }
373}
374
375impl FallbackResolver for FnFallbackResolver {
376 fn resolve(&self, provider: &str) -> Option<String> {
377 (self.f)(provider)
378 }
379}
380
/// [`FallbackResolver`] that looks a provider's key up in environment variables.
pub struct EnvVarFallbackResolver;
386
387impl FallbackResolver for EnvVarFallbackResolver {
388 fn resolve(&self, provider: &str) -> Option<String> {
389 let builtin = oxi_ai::get_builtin_provider(provider)?;
391 let key = builtin.env_key;
392
393 if let Ok(val) = std::env::var(key) {
395 if !val.is_empty() {
396 return Some(val);
397 }
398 }
399
400 for extra in builtin.extra_env_keys {
402 if let Ok(val) = std::env::var(extra) {
403 if !val.is_empty() {
404 return Some(val);
405 }
406 }
407 }
408
409 None
410 }
411}
412
/// Thread-safe credential store with layered lookup: runtime overrides first,
/// then stored credentials, then an optional fallback resolver.
pub struct AuthStorage {
    /// Persistence backend; `None` means credentials live only in memory.
    /// Despite the name it may be any `AuthStorageBackend`.
    file_storage: Option<Arc<dyn AuthStorageBackend>>,
    /// Stored credentials keyed by provider name.
    credentials: RwLock<HashMap<String, AuthCredential>>,
    /// Keys supplied at runtime; they take precedence and are never persisted.
    runtime_overrides: RwLock<HashMap<String, String>>,
    /// Optional last-resort key source.
    fallback_resolver: RwLock<Option<Arc<dyn FallbackResolver>>>,
    /// Errors accumulated by background operations until `drain_errors` is called.
    errors: RwLock<Vec<AuthError>>,
    /// The most recent load failure, if any.
    load_error: RwLock<Option<AuthError>>,
    /// Ensures the plaintext-storage warning is logged at most once.
    plaintext_warned: OnceLock<()>,
}
442
443impl AuthStorage {
444 pub fn new() -> Self {
446 let file_storage = FileAuthStorage::default_path()
447 .map(|p| Arc::new(FileAuthStorage::new(p)) as Arc<dyn AuthStorageBackend>);
448
449 let credentials = if let Some(ref storage) = file_storage {
450 match storage.read() {
451 Ok(Some(content)) => serde_json::from_str(&content).unwrap_or_default(),
452 _ => HashMap::new(),
453 }
454 } else {
455 HashMap::new()
456 };
457
458 Self {
459 file_storage,
460 credentials: RwLock::new(credentials),
461 runtime_overrides: RwLock::new(HashMap::new()),
462 fallback_resolver: RwLock::new(None),
463 errors: RwLock::new(Vec::new()),
464 load_error: RwLock::new(None),
465 plaintext_warned: OnceLock::new(),
466 }
467 }
468
469 pub fn with_backend(backend: impl AuthStorageBackend + 'static) -> Self {
471 let credentials = match backend.read() {
472 Ok(Some(content)) => serde_json::from_str(&content).unwrap_or_default(),
473 _ => HashMap::new(),
474 };
475
476 Self {
477 file_storage: Some(Arc::new(backend)),
478 credentials: RwLock::new(credentials),
479 runtime_overrides: RwLock::new(HashMap::new()),
480 fallback_resolver: RwLock::new(None),
481 errors: RwLock::new(Vec::new()),
482 load_error: RwLock::new(None),
483 plaintext_warned: OnceLock::new(),
484 }
485 }
486
487 pub fn in_memory() -> Self {
489 Self {
490 file_storage: None,
491 credentials: RwLock::new(HashMap::new()),
492 runtime_overrides: RwLock::new(HashMap::new()),
493 fallback_resolver: RwLock::new(None),
494 errors: RwLock::new(Vec::new()),
495 load_error: RwLock::new(None),
496 plaintext_warned: OnceLock::new(),
497 }
498 }
499
/// Location of the default credential file; delegates to
/// `FileAuthStorage::default_path`.
pub fn default_path() -> Option<PathBuf> {
    FileAuthStorage::default_path()
}
504
505 pub fn set_runtime_key(&self, provider: &str, api_key: String) {
511 self.runtime_overrides
512 .write()
513 .insert(provider.to_string(), api_key);
514 }
515
516 pub fn remove_runtime_key(&self, provider: &str) {
518 self.runtime_overrides.write().remove(provider);
519 }
520
521 pub fn set_fallback_resolver(&self, resolver: Arc<dyn FallbackResolver>) {
528 *self.fallback_resolver.write() = Some(resolver);
529 }
530
531 pub fn clear_fallback_resolver(&self) {
533 *self.fallback_resolver.write() = None;
534 }
535
536 pub fn has_auth(&self, provider: &str) -> bool {
542 if self.runtime_overrides.read().contains_key(provider) {
543 return true;
544 }
545 if self.credentials.read().contains_key(provider) {
546 return true;
547 }
548 if let Some(ref resolver) = *self.fallback_resolver.read() {
549 if resolver.resolve(provider).is_some() {
550 return true;
551 }
552 }
553 false
554 }
555
556 pub fn get_status(&self, provider: &str) -> AuthStatus {
558 if self.runtime_overrides.read().contains_key(provider) {
559 return AuthStatus {
560 configured: false,
561 source: Some("runtime".to_string()),
562 label: Some("--api-key".to_string()),
563 };
564 }
565
566 if let Some(cred) = self.credentials.read().get(provider) {
567 return AuthStatus {
568 configured: true,
569 source: Some("stored".to_string()),
570 label: Some(cred.type_name().to_string()),
571 };
572 }
573
574 if let Some(ref resolver) = *self.fallback_resolver.read() {
575 if resolver.resolve(provider).is_some() {
576 return AuthStatus {
577 configured: false,
578 source: Some("fallback".to_string()),
579 label: Some("custom provider config".to_string()),
580 };
581 }
582 }
583
584 AuthStatus {
585 configured: false,
586 source: None,
587 label: None,
588 }
589 }
590
/// Resolves the key or token to use for `provider`, including the fallback
/// resolver; shorthand for `get_api_key_with_options(provider, true)`.
pub fn get_api_key(&self, provider: &str) -> Option<String> {
    self.get_api_key_with_options(provider, true)
}
602
603 pub fn get_api_key_with_options(
605 &self,
606 provider: &str,
607 include_fallback: bool,
608 ) -> Option<String> {
609 if let Some(key) = self.runtime_overrides.read().get(provider) {
611 return Some(key.clone());
612 }
613
614 if let Some(cred) = self.credentials.read().get(provider) {
616 return match cred {
617 AuthCredential::ApiKey { key } => Some(key.clone()),
618 AuthCredential::OAuth {
619 access_token,
620 expires_at,
621 ..
622 } => {
623 if *expires_at > now_secs() {
624 Some(access_token.clone())
625 } else {
626 None
628 }
629 }
630 AuthCredential::Session {
631 token, expires_at, ..
632 } => {
633 if *expires_at == 0 || *expires_at > now_secs() {
634 Some(token.clone())
635 } else {
636 None
637 }
638 }
639 };
640 }
641
642 if include_fallback {
644 if let Some(ref resolver) = *self.fallback_resolver.read() {
645 return resolver.resolve(provider);
646 }
647 }
648
649 None
650 }
651
652 pub fn set_api_key(&self, provider: &str, key: String) {
658 self.credentials
659 .write()
660 .insert(provider.to_string(), AuthCredential::ApiKey { key });
661 if let Err(e) = self.persist() {
662 tracing::warn!("Failed to persist API key for '{}': {}", provider, e);
663 }
664 }
665
666 pub fn set_oauth(
668 &self,
669 provider: &str,
670 access_token: String,
671 refresh_token: Option<String>,
672 expires_at: u64,
673 ) {
674 self.set_oauth_full(
675 provider,
676 access_token,
677 refresh_token,
678 expires_at,
679 None,
680 None,
681 );
682 }
683
684 pub fn set_oauth_full(
686 &self,
687 provider: &str,
688 access_token: String,
689 refresh_token: Option<String>,
690 expires_at: u64,
691 scopes: Option<String>,
692 provider_data: Option<serde_json::Value>,
693 ) {
694 self.credentials.write().insert(
695 provider.to_string(),
696 AuthCredential::OAuth {
697 access_token,
698 refresh_token,
699 expires_at,
700 scopes,
701 provider_data,
702 },
703 );
704 if let Err(e) = self.persist() {
705 tracing::warn!("Failed to persist OAuth token for '{}': {}", provider, e);
706 }
707 }
708
709 pub fn set_session(
711 &self,
712 provider: &str,
713 token: String,
714 expires_at: u64,
715 metadata: Option<serde_json::Value>,
716 ) {
717 self.credentials.write().insert(
718 provider.to_string(),
719 AuthCredential::Session {
720 token,
721 expires_at,
722 metadata,
723 },
724 );
725 if let Err(e) = self.persist() {
726 tracing::warn!("Failed to persist session for '{}': {}", provider, e);
727 }
728 }
729
730 pub fn update_oauth_tokens(
732 &self,
733 provider: &str,
734 new_access_token: String,
735 new_refresh_token: Option<String>,
736 new_expires_at: u64,
737 ) -> AuthResult<()> {
738 let mut creds = self.credentials.write();
739 let cred = creds
740 .get_mut(provider)
741 .ok_or_else(|| AuthError::NotFound(provider.to_string()))?;
742
743 match cred {
744 AuthCredential::OAuth {
745 access_token,
746 refresh_token,
747 expires_at,
748 ..
749 } => {
750 *access_token = new_access_token;
751 *refresh_token = new_refresh_token;
752 *expires_at = new_expires_at;
753 }
754 _ => {
755 return Err(AuthError::InvalidFormat(format!(
756 "Provider '{}' does not have OAuth credentials",
757 provider
758 )));
759 }
760 }
761
762 drop(creds);
763 if let Err(e) = self.persist() {
764 tracing::warn!(
765 "Failed to persist OAuth token update for '{}': {}",
766 provider,
767 e
768 );
769 }
770 Ok(())
771 }
772
773 pub fn get(&self, provider: &str) -> Option<AuthCredential> {
779 self.credentials.read().get(provider).cloned()
780 }
781
782 pub fn get_oauth_credential(&self, provider: &str) -> Option<AuthCredential> {
784 self.credentials.read().get(provider).cloned()
785 }
786
787 pub fn has_oauth_with_refresh(&self, provider: &str) -> bool {
789 if let Some(cred) = self.credentials.read().get(provider) {
790 matches!(
791 cred,
792 AuthCredential::OAuth {
793 refresh_token: Some(_),
794 ..
795 }
796 )
797 } else {
798 false
799 }
800 }
801
802 pub fn set(&self, provider: &str, credential: AuthCredential) {
808 self.credentials
809 .write()
810 .insert(provider.to_string(), credential);
811 if let Err(e) = self.persist() {
812 tracing::warn!("Failed to persist credential for '{}': {}", provider, e);
813 }
814 }
815
816 pub fn remove(&self, provider: &str) {
818 self.credentials.write().remove(provider);
819 if let Err(e) = self.persist() {
820 tracing::warn!("Failed to persist after removing '{}': {}", provider, e);
821 }
822 }
823
824 pub fn list_providers(&self) -> Vec<String> {
826 self.credentials.read().keys().cloned().collect()
827 }
828
829 pub fn has(&self, provider: &str) -> bool {
831 self.credentials.read().contains_key(provider)
832 }
833
834 pub fn get_all(&self) -> HashMap<String, AuthCredential> {
836 self.credentials.read().clone()
837 }
838
839 pub fn clear(&self) {
841 self.credentials.write().clear();
842 if let Err(e) = self.persist() {
843 tracing::warn!("Failed to persist after clearing credentials: {}", e);
844 }
845 }
846
847 pub fn reload(&self) {
853 if let Some(ref storage) = self.file_storage {
854 match storage.read() {
855 Ok(Some(content)) => {
856 if let Ok(creds) = serde_json::from_str(&content) {
857 *self.credentials.write() = creds;
858 }
859 *self.load_error.write() = None;
860 }
861 Ok(None) => {
862 self.credentials.write().clear();
863 *self.load_error.write() = None;
864 }
865 Err(e) => {
866 *self.load_error.write() = Some(e);
867 self.record_error(AuthError::ReadError(
868 "Failed to reload auth storage".to_string(),
869 ));
870 }
871 }
872 }
873 }
874
875 fn persist(&self) -> Result<(), String> {
877 if let Some(ref storage) = self.file_storage {
878 let creds = self.credentials.read();
879 if let Ok(json) = serde_json::to_string_pretty(&*creds) {
880 #[cfg(not(feature = "keyring"))]
882 {
883 self.plaintext_warned.get_or_init(|| {
884 tracing::warn!(
885 "Auth credentials are stored in plaintext. \
886 Enable the 'keyring' feature for secure OS-level storage."
887 );
888 });
889 }
890
891 if let Err(e) = storage.write(&json) {
892 tracing::error!("Failed to persist auth storage: {}", e);
893 self.record_error(e);
894 return Err("persist failed".to_string());
895 }
896 }
897 }
898 Ok(())
899 }
900
901 fn record_error(&self, error: AuthError) {
907 self.errors.write().push(error);
908 }
909
910 pub fn drain_errors(&self) -> Vec<AuthError> {
912 let mut errors = self.errors.write();
913 std::mem::take(&mut *errors)
914 }
915
916 pub fn load_error(&self) -> Option<AuthError> {
918 self.load_error.read().clone()
919 }
920
921 pub fn validate_all(&self) -> Vec<(String, CredentialValidationError)> {
927 let creds = self.credentials.read();
928 let mut results = Vec::new();
929 for (provider, cred) in creds.iter() {
930 if let Err(e) = cred.validate() {
931 results.push((provider.clone(), e));
932 }
933 }
934 results
935 }
936
937 pub fn validate(&self, provider: &str) -> Result<(), CredentialValidationError> {
939 let creds = self.credentials.read();
940 let cred = creds.get(provider).ok_or_else(|| {
941 CredentialValidationError::EmptyField(format!(
942 "no credential for provider '{}'",
943 provider
944 ))
945 })?;
946 cred.validate()
947 }
948
949 pub fn configured_providers(&self) -> Vec<String> {
955 let mut providers: Vec<String> = self.credentials.read().keys().cloned().collect();
956 providers.sort();
957 providers
958 }
959
960 pub fn has_multiple_providers(&self) -> bool {
962 self.credentials.read().len() > 1
963 }
964
965 pub fn primary_provider(&self) -> Option<String> {
967 let creds = self.credentials.read();
968 creds.keys().next().cloned()
969 }
970
971 pub fn migrate_provider(&self, from: &str, to: &str) -> AuthResult<()> {
973 let mut creds = self.credentials.write();
974 let cred = creds
975 .remove(from)
976 .ok_or_else(|| AuthError::NotFound(from.to_string()))?;
977 creds.insert(to.to_string(), cred);
978 drop(creds);
979 let _ = self.persist();
980 Ok(())
981 }
982}
983
984impl Default for AuthStorage {
985 fn default() -> Self {
986 Self::new()
987 }
988}
989
/// Current time as whole seconds since the Unix epoch.
///
/// Returns `0` when the system clock reads earlier than the epoch.
fn now_secs() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs(),
        Err(_) => 0,
    }
}
1000
1001#[allow(unexpected_cfgs)]
1007pub mod keyring_support {
1008 use super::*;
1009
1010 #[cfg(feature = "keyring")]
1012 pub fn get_keyring_secret(service: &str, account: &str) -> Option<String> {
1013 use keyring::Entry;
1014 Entry::new(service, account)
1015 .ok()
1016 .and_then(|entry| entry.get_password().ok())
1017 }
1018
1019 #[cfg(feature = "keyring")]
1021 pub fn set_keyring_secret(service: &str, account: &str, secret: &str) -> AuthResult<()> {
1022 use keyring::Entry;
1023 Entry::new(service, account)
1024 .map_err(|e| AuthError::KeyringError(e.to_string()))?
1025 .set_password(secret)
1026 .map_err(|e| AuthError::KeyringError(e.to_string()))
1027 }
1028
1029 #[cfg(feature = "keyring")]
1031 pub fn delete_keyring_secret(service: &str, account: &str) -> AuthResult<()> {
1032 use keyring::Entry;
1033 Entry::new(service, account)
1034 .map_err(|e| AuthError::KeyringError(e.to_string()))?
1035 .delete_credential()
1036 .map_err(|e| AuthError::KeyringError(e.to_string()))
1037 }
1038
1039 #[cfg(not(feature = "keyring"))]
1041 pub fn get_keyring_secret(_service: &str, _account: &str) -> Option<String> {
1045 None
1046 }
1047
1048 #[cfg(not(feature = "keyring"))]
1049 pub fn set_keyring_secret(_service: &str, _account: &str, _secret: &str) -> AuthResult<()> {
1053 Err(AuthError::KeyringError(
1054 "Keyring support not compiled".to_string(),
1055 ))
1056 }
1057
1058 #[cfg(not(feature = "keyring"))]
1059 pub fn delete_keyring_secret(_service: &str, _account: &str) -> AuthResult<()> {
1063 Err(AuthError::KeyringError(
1064 "Keyring support not compiled".to_string(),
1065 ))
1066 }
1067}
1068
1069pub fn shared_auth_storage() -> Arc<AuthStorage> {
1079 static STORAGE: OnceLock<Arc<AuthStorage>> = OnceLock::new();
1080 STORAGE
1081 .get_or_init(|| {
1082 let storage = Arc::new(AuthStorage::new());
1083 storage.set_fallback_resolver(Arc::new(EnvVarFallbackResolver));
1086 storage
1087 })
1088 .clone()
1089}
1090
1091#[cfg(test)]
1096mod tests {
1097 use super::*;
1098
/// A fresh in-memory store holds no credential for any provider.
#[test]
fn test_auth_storage_new() {
    let store = AuthStorage::in_memory();
    let has_anthropic = store.has("anthropic");
    assert!(!has_anthropic);
}
1104
/// A stored API key is reported as present and returned verbatim.
#[test]
fn test_set_and_get_api_key() {
    let store = AuthStorage::in_memory();
    store.set_api_key("anthropic", String::from("sk-test123"));

    assert!(store.has("anthropic"));
    let resolved = store.get_api_key("anthropic");
    assert_eq!(resolved, Some(String::from("sk-test123")));
}
1115
/// A runtime key takes precedence over a stored key for the same provider.
#[test]
fn test_runtime_override() {
    let store = AuthStorage::in_memory();
    store.set_api_key("anthropic", String::from("stored-key"));
    store.set_runtime_key("anthropic", String::from("runtime-key"));

    let resolved = store.get_api_key("anthropic");
    assert_eq!(resolved, Some(String::from("runtime-key")));
}
1128
/// Removing a provider makes its stored credential disappear.
#[test]
fn test_remove_credential() {
    let store = AuthStorage::in_memory();
    store.set_api_key("anthropic", String::from("sk-test123"));
    assert!(store.has("anthropic"));

    store.remove("anthropic");
    let still_present = store.has("anthropic");
    assert!(!still_present);
}
1138
/// A stored API key is reported as configured, from the "stored" source,
/// labelled with its credential type.
#[test]
fn test_auth_status() {
    let store = AuthStorage::in_memory();
    store.set_api_key("anthropic", String::from("sk-test123"));

    let AuthStatus {
        configured,
        source,
        label,
    } = store.get_status("anthropic");
    assert!(configured);
    assert_eq!(source, Some(String::from("stored")));
    assert_eq!(label, Some(String::from("api_key")));
}
1149
/// `Display` renders "source (label)" and falls back to "not configured".
#[test]
fn test_auth_status_display() {
    let stored = AuthStatus {
        configured: true,
        source: Some(String::from("stored")),
        label: Some(String::from("api_key")),
    };
    assert_eq!(stored.to_string(), "stored (api_key)");

    let unconfigured = AuthStatus {
        configured: false,
        source: None,
        label: None,
    };
    assert_eq!(unconfigured.to_string(), "not configured");
}
1167
/// Every provider with a stored key shows up in the provider list.
#[test]
fn test_list_providers() {
    let store = AuthStorage::in_memory();
    store.set_api_key("anthropic", String::from("key1"));
    store.set_api_key("openai", String::from("key2"));

    let listed = store.list_providers();
    for expected in ["anthropic", "openai"] {
        assert!(listed.iter().any(|name| name == expected));
    }
}
1178
/// A non-expired OAuth credential resolves to its access token.
#[test]
fn test_oauth_credential() {
    let store = AuthStorage::in_memory();
    let never_expires = u64::MAX;
    store.set_oauth(
        "provider",
        String::from("access123"),
        Some(String::from("refresh456")),
        never_expires,
    );

    assert!(store.has("provider"));
    let resolved = store.get_api_key("provider");
    assert_eq!(resolved, Some(String::from("access123")));
}
1195
/// An OAuth credential whose expiry lies in the past yields no key.
#[test]
fn test_expired_oauth_token() {
    let store = AuthStorage::in_memory();
    let long_expired = 0;
    store.set_oauth("provider", String::from("access123"), None, long_expired);

    let resolved = store.get_api_key("provider");
    assert_eq!(resolved, None);
}
1206
/// `get_all` returns one entry per stored provider.
#[test]
fn test_get_all_credentials() {
    let store = AuthStorage::in_memory();
    store.set_api_key("anthropic", String::from("key1"));
    store.set_api_key("openai", String::from("key2"));

    let snapshot = store.get_all();
    assert_eq!(snapshot.len(), 2);
}
1216
/// `clear` removes every stored credential.
#[test]
fn test_clear() {
    let store = AuthStorage::in_memory();
    store.set_api_key("anthropic", String::from("key"));
    assert!(store.has("anthropic"));

    store.clear();
    let still_present = store.has("anthropic");
    assert!(!still_present);
}
1226
/// Dropping a runtime override exposes the stored key again.
#[test]
fn test_remove_runtime_key() {
    let store = AuthStorage::in_memory();
    store.set_api_key("anthropic", String::from("stored"));
    store.set_runtime_key("anthropic", String::from("runtime"));

    let while_overridden = store.get_api_key("anthropic");
    assert_eq!(while_overridden, Some(String::from("runtime")));

    store.remove_runtime_key("anthropic");
    let after_removal = store.get_api_key("anthropic");
    assert_eq!(after_removal, Some(String::from("stored")));
}
1241
/// API keys never expire; OAuth credentials expire according to `expires_at`.
#[test]
fn test_auth_credential_is_expired() {
    // Builds an OAuth credential that differs only in its expiry.
    let oauth_expiring_at = |expires_at: u64| AuthCredential::OAuth {
        access_token: String::from("token"),
        refresh_token: Some(String::from("refresh")),
        expires_at,
        scopes: None,
        provider_data: None,
    };

    let api_key = AuthCredential::ApiKey {
        key: String::from("test"),
    };
    assert!(!api_key.is_expired());

    let one_hour_ahead = now_secs() + 3600;
    assert!(!oauth_expiring_at(one_hour_ahead).is_expired());

    assert!(oauth_expiring_at(0).is_expired());
}
1271
/// Only OAuth credentials with a refresh token that expire within the
/// refresh window need a refresh.
#[test]
fn test_auth_credential_needs_refresh() {
    // Builds an OAuth credential from the two inputs that matter here.
    let oauth = |refresh_token: Option<&str>, expires_at: u64| AuthCredential::OAuth {
        access_token: String::from("token"),
        refresh_token: refresh_token.map(String::from),
        expires_at,
        scopes: None,
        provider_data: None,
    };

    let two_minutes_ahead = now_secs() + 120;
    let comfortable = oauth(Some("refresh"), two_minutes_ahead);
    assert!(!comfortable.needs_refresh());

    let half_minute_ahead = now_secs() + 30;
    let expiring_soon = oauth(Some("refresh"), half_minute_ahead);
    assert!(expiring_soon.needs_refresh());

    let without_refresh_token = oauth(None, two_minutes_ahead);
    assert!(!without_refresh_token.needs_refresh());

    let api_key = AuthCredential::ApiKey {
        key: String::from("test"),
    };
    assert!(!api_key.needs_refresh());
}
1313
/// `access_token` hands out a valid OAuth token, and nothing for expired
/// OAuth credentials or API keys.
#[test]
fn test_auth_credential_access_token() {
    // Builds an OAuth credential from its token and expiry.
    let oauth = |token: &str, expires_at: u64| AuthCredential::OAuth {
        access_token: String::from(token),
        refresh_token: Some(String::from("refresh")),
        expires_at,
        scopes: None,
        provider_data: None,
    };

    let one_hour_ahead = now_secs() + 3600;
    let valid = oauth("valid_token", one_hour_ahead);
    assert_eq!(valid.access_token(), Some("valid_token"));

    let expired = oauth("expired_token", 0);
    assert_eq!(expired.access_token(), None);

    let api_key = AuthCredential::ApiKey {
        key: String::from("api_key_token"),
    };
    assert_eq!(api_key.access_token(), None);
}
1343
/// A stored OAuth credential can be fetched back as an OAuth credential.
#[test]
fn test_get_oauth_credential() {
    let store = AuthStorage::in_memory();
    store.set_oauth(
        "provider",
        String::from("access"),
        Some(String::from("refresh")),
        u64::MAX,
    );

    let fetched = store.get_oauth_credential("provider");
    assert!(fetched.is_some());
    assert!(matches!(fetched, Some(AuthCredential::OAuth { .. })));
}
1358
/// Only OAuth credentials that carry a refresh token count.
#[test]
fn test_has_oauth_with_refresh() {
    let store = AuthStorage::in_memory();
    let never_expires = u64::MAX;

    store.set_oauth(
        "with_refresh",
        String::from("access"),
        Some(String::from("refresh")),
        never_expires,
    );
    assert!(store.has_oauth_with_refresh("with_refresh"));

    store.set_oauth("without_refresh", String::from("access"), None, never_expires);
    assert!(!store.has_oauth_with_refresh("without_refresh"));

    store.set_api_key("apikey_provider", String::from("key"));
    assert!(!store.has_oauth_with_refresh("apikey_provider"));
}
1380
/// `set_oauth_full` keeps the scopes and provider data it was given.
#[test]
fn test_set_oauth_full() {
    let store = AuthStorage::in_memory();
    store.set_oauth_full(
        "provider",
        String::from("access_token"),
        Some(String::from("refresh_token")),
        3600,
        Some(String::from("read write")),
        Some(serde_json::json!({"extra": "data"})),
    );

    let fetched = store.get_oauth_credential("provider");
    assert!(fetched.is_some());
    match fetched.unwrap() {
        AuthCredential::OAuth {
            scopes,
            provider_data,
            ..
        } => {
            assert_eq!(scopes, Some(String::from("read write")));
            assert!(provider_data.is_some());
        }
        _ => panic!("Expected OAuth credential"),
    }
}
1407
/// A session with expiry `0` never expires and resolves to its token.
#[test]
fn test_session_token() {
    let store = AuthStorage::in_memory();
    let no_expiry = 0;
    let metadata = Some(serde_json::json!({"user": "test"}));
    store.set_session(
        "browser",
        String::from("session-token-123"),
        no_expiry,
        metadata,
    );

    assert!(store.has("browser"));
    let resolved = store.get_api_key("browser");
    assert_eq!(resolved, Some(String::from("session-token-123")));

    let stored = store.get("browser").unwrap();
    assert!(matches!(stored, AuthCredential::Session { .. }));
    assert!(stored.access_token().is_some());
}
1428
/// A session whose expiry lies in the past yields no key.
#[test]
fn test_session_token_expired() {
    let store = AuthStorage::in_memory();
    let long_expired = 1;
    store.set_session("browser", String::from("session-token"), long_expired, None);

    let resolved = store.get_api_key("browser");
    assert!(resolved.is_none());
}
1437
/// `validate` accepts well-formed credentials and rejects empty or
/// placeholder secrets.
#[test]
fn test_credential_validation() {
    // Builds an API-key credential from its raw key.
    let api_key = |key: &str| AuthCredential::ApiKey {
        key: String::from(key),
    };
    // Builds an OAuth credential from its token and expiry.
    let oauth = |token: &str, expires_at: u64| AuthCredential::OAuth {
        access_token: String::from(token),
        refresh_token: None,
        expires_at,
        scopes: None,
        provider_data: None,
    };

    assert!(api_key("sk-valid").validate().is_ok());
    assert!(api_key("").validate().is_err());
    assert!(api_key("your-api-key-here").validate().is_err());

    let one_hour_ahead = now_secs() + 3600;
    assert!(oauth("token", one_hour_ahead).validate().is_ok());
    assert!(oauth("", 1000).validate().is_err());
}
1478
/// `validate_all` reports exactly the providers whose credential is invalid.
#[test]
fn test_validate_all() {
    let store = AuthStorage::in_memory();
    store.set_api_key("valid", String::from("sk-good"));
    store.set_api_key("empty", String::new());

    let failures = store.validate_all();
    assert_eq!(failures.len(), 1);
    let (failed_provider, _) = &failures[0];
    assert_eq!(failed_provider, "empty");
}
1489
/// After `update_oauth_tokens` succeeds, `get_api_key` must return the new
/// access token rather than the one originally stored.
#[test]
fn test_update_oauth_tokens() {
    let provider = "provider";
    let storage = AuthStorage::in_memory();

    let initial_expiry = now_secs() + 3600;
    storage.set_oauth(
        provider,
        String::from("old-access"),
        Some(String::from("old-refresh")),
        initial_expiry,
    );

    let rotated = storage.update_oauth_tokens(
        provider,
        String::from("new-access"),
        Some(String::from("new-refresh")),
        now_secs() + 7200,
    );
    rotated.unwrap();

    assert_eq!(
        storage.get_api_key(provider),
        Some(String::from("new-access"))
    );
}
1512
/// Updating OAuth tokens for a provider that holds an API key must fail.
#[test]
fn test_update_oauth_tokens_wrong_type() {
    let storage = AuthStorage::in_memory();
    storage.set_api_key("provider", String::from("key"));

    let expires_at = now_secs() + 3600;
    let outcome =
        storage.update_oauth_tokens("provider", String::from("new-access"), None, expires_at);

    assert!(outcome.is_err());
}
1526
/// Migrating a provider moves its credential: the old name disappears and the
/// new name resolves to the same key.
#[test]
fn test_migrate_provider() {
    let (source, target) = ("old-provider", "new-provider");

    let storage = AuthStorage::in_memory();
    storage.set_api_key(source, String::from("key123"));

    let migration = storage.migrate_provider(source, target);
    migration.unwrap();

    assert!(!storage.has(source));
    assert!(storage.has(target));
    assert_eq!(storage.get_api_key(target), Some(String::from("key123")));
}
1542
/// Migrating a provider that was never stored must return an error.
#[test]
fn test_migrate_provider_not_found() {
    let storage = AuthStorage::in_memory();
    assert!(storage.migrate_provider("nonexistent", "target").is_err());
}
1549
/// A freshly created in-memory store has nothing to drain.
#[test]
fn test_error_draining() {
    let storage = AuthStorage::in_memory();
    let pending = storage.drain_errors();
    assert!(pending.is_empty());
}
1556
/// With a fallback resolver installed, `get_api_key` returns whatever the
/// resolver yields for a provider; once the resolver is cleared the same
/// lookup returns nothing.
#[test]
fn test_fallback_resolver() {
    let storage = AuthStorage::in_memory();

    // Resolver that only knows about the "custom" provider.
    let resolver = FnFallbackResolver::new(Box::new(|provider| {
        (provider == "custom").then(|| "custom-key-from-config".to_string())
    }));
    storage.set_fallback_resolver(Arc::new(resolver));

    let custom = storage.get_api_key("custom");
    assert_eq!(custom, Some("custom-key-from-config".to_string()));
    assert!(storage.get_api_key("unknown").is_none());

    storage.clear_fallback_resolver();
    assert!(storage.get_api_key("custom").is_none());
}
1578
/// `get_api_key_with_options` yields the resolver's key when called with
/// `true` and nothing when called with `false` (the flag presumably toggles
/// use of the fallback resolver — confirm against the method's definition).
#[test]
fn test_get_api_key_with_options() {
    let storage = AuthStorage::in_memory();

    // Resolver that answers every provider with the same key.
    let always_fallback =
        FnFallbackResolver::new(Box::new(|_| Some("fallback-key".to_string())));
    storage.set_fallback_resolver(Arc::new(always_fallback));

    let with_flag = storage.get_api_key_with_options("test", true);
    assert_eq!(with_flag, Some("fallback-key".to_string()));

    let without_flag = storage.get_api_key_with_options("test", false);
    assert!(without_flag.is_none());
}
1595
/// `configured_providers` must list at least the two providers stored here
/// and must return them already sorted.
#[test]
fn test_configured_providers() {
    let storage = AuthStorage::in_memory();
    for name in ["openai", "anthropic"] {
        storage.set_api_key(name, "key".to_string());
    }

    let providers = storage.configured_providers();
    assert!(providers.len() >= 2);

    // Sorting a copy must be a no-op if the list is already in order.
    let expected_order = {
        let mut ordered = providers.clone();
        ordered.sort();
        ordered
    };
    assert_eq!(providers, expected_order);
}
1609
/// `has_multiple_providers` stays false for zero and one stored provider and
/// becomes true once a second provider is stored.
#[test]
fn test_has_multiple_providers() {
    let storage = AuthStorage::in_memory();
    assert!(!storage.has_multiple_providers());

    // (provider, key, expected result after storing it)
    let steps = [("openai", "key1", false), ("anthropic", "key2", true)];
    for (provider, key, expect_multiple) in steps {
        storage.set_api_key(provider, key.to_string());
        assert_eq!(storage.has_multiple_providers(), expect_multiple);
    }
}
1621
/// A credential stored with `set` comes back from `get` as the same variant.
#[test]
fn test_set_and_get_credential() {
    let storage = AuthStorage::in_memory();
    storage.set(
        "custom",
        AuthCredential::Session {
            token: String::from("abc"),
            expires_at: 0,
            metadata: None,
        },
    );

    let fetched = storage.get("custom");
    assert!(fetched.is_some());
    assert!(matches!(fetched.unwrap(), AuthCredential::Session { .. }));
}
1635
/// `type_name` returns the snake_case tag for each credential variant.
#[test]
fn test_credential_type_name() {
    let cases = [
        (
            AuthCredential::ApiKey {
                key: String::from("k"),
            },
            "api_key",
        ),
        (
            AuthCredential::OAuth {
                access_token: String::from("t"),
                refresh_token: None,
                expires_at: 0,
                scopes: None,
                provider_data: None,
            },
            "oauth",
        ),
        (
            AuthCredential::Session {
                token: String::from("t"),
                expires_at: 0,
                metadata: None,
            },
            "session",
        ),
    ];

    for (credential, expected) in cases {
        assert_eq!(credential.type_name(), expected);
    }
}
1666}