1use parking_lot::RwLock;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::path::PathBuf;
11use std::sync::{Arc, OnceLock};
12
/// A credential stored for a single provider.
///
/// Serialized as an internally tagged JSON object: the `"type"` field holds the
/// variant name converted by serde's `snake_case` rule (`"api_key"`,
/// `"o_auth"`, `"session"`).
///
/// NOTE(review): serde's snake_case conversion turns `OAuth` into `"o_auth"`,
/// which differs from the `"oauth"` returned by `AuthCredential::type_name`;
/// confirm this is intended before anyone hand-writes auth files.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum AuthCredential {
    /// A static API key. Treated as never expiring.
    ApiKey {
        /// The raw key value.
        key: String,
    },
    /// An OAuth token set.
    OAuth {
        /// Token presented to the provider.
        access_token: String,
        /// Token used to obtain a new access token, if one was issued.
        refresh_token: Option<String>,
        /// Expiry as Unix time in seconds (compared against the current time).
        expires_at: u64,
        /// Granted scopes; `None` when absent from the stored JSON.
        #[serde(default)]
        scopes: Option<String>,
        /// Opaque provider-specific data; `None` when absent.
        #[serde(default)]
        provider_data: Option<serde_json::Value>,
    },
    /// A session token.
    Session {
        /// The session token value.
        token: String,
        /// Expiry as Unix time in seconds; `0` (also the default when the
        /// field is missing) means the session never expires.
        #[serde(default)]
        expires_at: u64,
        /// Opaque session metadata; `None` when absent.
        #[serde(default)]
        metadata: Option<serde_json::Value>,
    },
}
53
54impl AuthCredential {
55 pub fn is_expired(&self) -> bool {
57 match self {
58 AuthCredential::OAuth { expires_at, .. } => {
59 let now = now_secs();
60 *expires_at < now
61 }
62 AuthCredential::Session { expires_at, .. } => {
63 if *expires_at == 0 {
64 return false; }
66 *expires_at <= now_secs()
67 }
68 AuthCredential::ApiKey { .. } => false,
69 }
70 }
71
72 pub fn needs_refresh(&self) -> bool {
74 match self {
75 AuthCredential::OAuth {
76 expires_at,
77 refresh_token,
78 ..
79 } => {
80 let now = now_secs();
81 refresh_token.is_some() && *expires_at <= now + 60
82 }
83 AuthCredential::Session { .. } => false,
84 AuthCredential::ApiKey { .. } => false,
85 }
86 }
87
88 pub fn access_token(&self) -> Option<&str> {
90 match self {
91 AuthCredential::OAuth { access_token, .. } if !self.is_expired() => Some(access_token),
92 AuthCredential::Session { token, .. } if !self.is_expired() => Some(token),
93 _ => None,
94 }
95 }
96
97 pub fn type_name(&self) -> &'static str {
99 match self {
100 AuthCredential::ApiKey { .. } => "api_key",
101 AuthCredential::OAuth { .. } => "oauth",
102 AuthCredential::Session { .. } => "session",
103 }
104 }
105
106 pub fn validate(&self) -> Result<(), CredentialValidationError> {
108 match self {
109 AuthCredential::ApiKey { key } => {
110 if key.is_empty() {
111 return Err(CredentialValidationError::EmptyField("key".to_string()));
112 }
113 if key == "your-api-key-here" || key == "xxx" {
115 return Err(CredentialValidationError::PlaceholderValue(key.clone()));
116 }
117 Ok(())
118 }
119 AuthCredential::OAuth {
120 access_token,
121 expires_at,
122 ..
123 } => {
124 if access_token.is_empty() {
125 return Err(CredentialValidationError::EmptyField(
126 "access_token".to_string(),
127 ));
128 }
129 if *expires_at == 0 {
130 return Err(CredentialValidationError::InvalidExpiry);
131 }
132 Ok(())
133 }
134 AuthCredential::Session { token, .. } => {
135 if token.is_empty() {
136 return Err(CredentialValidationError::EmptyField("token".to_string()));
137 }
138 Ok(())
139 }
140 }
141 }
142}
143
/// Reasons `AuthCredential::validate` can reject a credential.
#[derive(Debug, Clone, thiserror::Error)]
pub enum CredentialValidationError {
    /// A required field is empty. Carries the field name; `AuthStorage::validate`
    /// also uses it to describe a provider with no credential at all.
    #[error("Field '{0}' must not be empty")]
    EmptyField(String),
    /// The value is a known placeholder rather than a real secret.
    #[error("Placeholder value detected: '{0}'")]
    PlaceholderValue(String),
    /// The expiry timestamp is unusable (an OAuth `expires_at` of `0`).
    #[error("Invalid expiry timestamp")]
    InvalidExpiry,
}
157
/// Summary of where a provider's authentication comes from, for display.
#[derive(Debug, Clone)]
pub struct AuthStatus {
    /// `AuthStorage::get_status` sets this to `true` only for stored credentials;
    /// runtime overrides and fallback resolution report `false`.
    pub configured: bool,
    /// Origin of the credential (`"runtime"`, `"stored"`, `"fallback"`), if any.
    pub source: Option<String>,
    /// Extra detail, e.g. the credential type name.
    pub label: Option<String>,
}
172
173impl std::fmt::Display for AuthStatus {
174 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
175 match (&self.source, &self.label) {
176 (Some(source), Some(label)) => write!(f, "{} ({})", source, label),
177 (Some(source), None) => write!(f, "{}", source),
178 (None, Some(label)) => write!(f, "{}", label),
179 (None, None) => write!(f, "not configured"),
180 }
181 }
182}
183
/// Result type used by auth storage operations.
pub type AuthResult<T> = Result<T, AuthError>;
190
/// Errors produced by auth storage backends and `AuthStorage`.
#[derive(Debug, Clone, thiserror::Error)]
pub enum AuthError {
    /// Reading from the backing store failed.
    #[error("Failed to read auth storage: {0}")]
    ReadError(String),
    /// Writing to (or deleting from) the backing store failed.
    #[error("Failed to write auth storage: {0}")]
    WriteError(String),
    /// No credential exists for the named provider.
    #[error("Credential not found: {0}")]
    NotFound(String),
    /// A credential exists but has the wrong shape for the operation.
    #[error("Invalid credential format: {0}")]
    InvalidFormat(String),
    /// An OS keyring operation failed, or keyring support is not compiled in.
    #[error("Keyring error: {0}")]
    KeyringError(String),
    /// A credential failed validation.
    #[error("Credential validation failed: {0}")]
    ValidationFailed(String),
}
213
/// Raw persistence layer for the serialized credential map.
///
/// Backends store and return an opaque string; `AuthStorage` passes
/// pretty-printed JSON through `write` and parses what `read` returns.
pub trait AuthStorageBackend: Send + Sync {
    /// Returns the stored data, or `Ok(None)` when nothing is stored.
    fn read(&self) -> AuthResult<Option<String>>;
    /// Replaces the stored data with `data`.
    fn write(&self, data: &str) -> AuthResult<()>;
    /// Removes the stored data.
    fn delete(&self) -> AuthResult<()>;
}
227
/// An [`AuthStorageBackend`] that keeps the serialized credentials in a file.
pub struct FileAuthStorage {
    /// Location of the credentials file.
    path: PathBuf,
    /// Copy of the last content read from or written to `path`.
    // NOTE(review): written by the backend methods but never read in the
    // visible code — confirm whether anything consumes it.
    cache: RwLock<Option<String>>,
}
237
238impl FileAuthStorage {
239 pub fn new(path: PathBuf) -> Self {
241 Self {
242 path,
243 cache: RwLock::new(None),
244 }
245 }
246
247 pub fn default_path() -> Option<PathBuf> {
249 dirs::home_dir().map(|p| p.join(".oxi").join("auth.json"))
250 }
251
252 pub fn path(&self) -> &PathBuf {
254 &self.path
255 }
256}
257
258impl AuthStorageBackend for FileAuthStorage {
259 fn read(&self) -> AuthResult<Option<String>> {
260 if !self.path.exists() {
261 return Ok(None);
262 }
263
264 match std::fs::read_to_string(&self.path) {
265 Ok(content) => {
266 *self.cache.write() = Some(content.clone());
267 Ok(Some(content))
268 }
269 Err(e) => Err(AuthError::ReadError(e.to_string())),
270 }
271 }
272
273 fn write(&self, data: &str) -> AuthResult<()> {
274 if let Some(parent) = self.path.parent() {
276 std::fs::create_dir_all(parent).map_err(|e| AuthError::WriteError(e.to_string()))?;
277
278 #[cfg(unix)]
279 {
280 use std::os::unix::fs::PermissionsExt;
281 let perms = std::fs::Permissions::from_mode(0o700);
282 let _ = std::fs::set_permissions(parent, perms);
283 }
284 }
285
286 std::fs::write(&self.path, data).map_err(|e| AuthError::WriteError(e.to_string()))?;
288
289 #[cfg(unix)]
291 {
292 use std::os::unix::fs::PermissionsExt;
293 let perms = std::fs::Permissions::from_mode(0o600);
294 std::fs::set_permissions(&self.path, perms)
295 .map_err(|e| AuthError::WriteError(e.to_string()))?;
296 }
297
298 *self.cache.write() = Some(data.to_string());
299 Ok(())
300 }
301
302 fn delete(&self) -> AuthResult<()> {
303 if self.path.exists() {
304 std::fs::remove_file(&self.path).map_err(|e| AuthError::WriteError(e.to_string()))?;
305 }
306 *self.cache.write() = None;
307 Ok(())
308 }
309}
310
/// An [`AuthStorageBackend`] that lives entirely in process memory and never
/// touches the filesystem.
pub struct MemoryAuthStorage {
    /// Provider name → credential.
    data: RwLock<HashMap<String, AuthCredential>>,
}
319
320impl MemoryAuthStorage {
321 pub fn new() -> Self {
323 Self {
324 data: RwLock::new(HashMap::new()),
325 }
326 }
327}
328
329impl Default for MemoryAuthStorage {
330 fn default() -> Self {
331 Self::new()
332 }
333}
334
335impl AuthStorageBackend for MemoryAuthStorage {
336 fn read(&self) -> AuthResult<Option<String>> {
337 Ok(None)
339 }
340
341 fn write(&self, _data: &str) -> AuthResult<()> {
342 Ok(())
343 }
344
345 fn delete(&self) -> AuthResult<()> {
346 self.data.write().clear();
347 Ok(())
348 }
349}
350
/// Supplies an API key for a provider when no runtime override or stored
/// credential exists.
pub trait FallbackResolver: Send + Sync {
    /// Returns a key for `provider`, or `None` if this resolver has none.
    fn resolve(&self, provider: &str) -> Option<String>;
}
360
/// A [`FallbackResolver`] that delegates to a boxed closure.
pub struct FnFallbackResolver {
    // The boxed trait-object type is spelled out deliberately; a type alias
    // would only move the complexity elsewhere.
    #[allow(clippy::type_complexity)]
    f: Box<dyn Fn(&str) -> Option<String> + Send + Sync>,
}
366
367impl FnFallbackResolver {
368 #[allow(clippy::type_complexity)]
370 pub fn new(f: Box<dyn Fn(&str) -> Option<String> + Send + Sync>) -> Self {
371 Self { f }
372 }
373}
374
375impl FallbackResolver for FnFallbackResolver {
376 fn resolve(&self, provider: &str) -> Option<String> {
377 (self.f)(provider)
378 }
379}
380
/// A [`FallbackResolver`] that reads keys from the environment variables
/// registered for built-in providers in `oxi_ai`.
pub struct EnvVarFallbackResolver;
386
387impl FallbackResolver for EnvVarFallbackResolver {
388 fn resolve(&self, provider: &str) -> Option<String> {
389 let builtin = oxi_ai::get_builtin_provider(provider)?;
391 let key = builtin.env_key;
392
393 if let Ok(val) = std::env::var(key) {
395 if !val.is_empty() {
396 return Some(val);
397 }
398 }
399
400 for extra in builtin.extra_env_keys {
402 if let Ok(val) = std::env::var(extra) {
403 if !val.is_empty() {
404 return Some(val);
405 }
406 }
407 }
408
409 None
410 }
411}
412
/// Thread-safe store of provider credentials with layered key resolution.
///
/// Keys are resolved in this order: runtime overrides, then stored
/// credentials, then the optional fallback resolver.
pub struct AuthStorage {
    /// Backend used to persist `credentials`; `None` means nothing is persisted.
    file_storage: Option<Arc<dyn AuthStorageBackend>>,
    /// Provider name → stored credential.
    credentials: RwLock<HashMap<String, AuthCredential>>,
    /// Provider name → key set via `set_runtime_key`; never persisted.
    runtime_overrides: RwLock<HashMap<String, String>>,
    /// Resolver consulted when nothing else provides a key.
    fallback_resolver: RwLock<Option<Arc<dyn FallbackResolver>>>,
    /// Non-fatal errors accumulated until `drain_errors` is called.
    errors: RwLock<Vec<AuthError>>,
    /// Most recent failure to load from the backend, if any.
    load_error: RwLock<Option<AuthError>>,
    /// Initialized once the plaintext-storage warning has been logged.
    plaintext_warned: OnceLock<()>,
}
442
443impl AuthStorage {
444 pub fn new() -> Self {
446 let file_storage = FileAuthStorage::default_path()
447 .map(|p| Arc::new(FileAuthStorage::new(p)) as Arc<dyn AuthStorageBackend>);
448
449 let credentials = if let Some(ref storage) = file_storage {
450 match storage.read() {
451 Ok(Some(content)) => serde_json::from_str(&content).unwrap_or_default(),
452 _ => HashMap::new(),
453 }
454 } else {
455 HashMap::new()
456 };
457
458 Self {
459 file_storage,
460 credentials: RwLock::new(credentials),
461 runtime_overrides: RwLock::new(HashMap::new()),
462 fallback_resolver: RwLock::new(None),
463 errors: RwLock::new(Vec::new()),
464 load_error: RwLock::new(None),
465 plaintext_warned: OnceLock::new(),
466 }
467 }
468
469 pub fn with_backend(backend: impl AuthStorageBackend + 'static) -> Self {
471 let credentials = match backend.read() {
472 Ok(Some(content)) => serde_json::from_str(&content).unwrap_or_default(),
473 _ => HashMap::new(),
474 };
475
476 Self {
477 file_storage: Some(Arc::new(backend)),
478 credentials: RwLock::new(credentials),
479 runtime_overrides: RwLock::new(HashMap::new()),
480 fallback_resolver: RwLock::new(None),
481 errors: RwLock::new(Vec::new()),
482 load_error: RwLock::new(None),
483 plaintext_warned: OnceLock::new(),
484 }
485 }
486
487 pub fn in_memory() -> Self {
489 Self {
490 file_storage: None,
491 credentials: RwLock::new(HashMap::new()),
492 runtime_overrides: RwLock::new(HashMap::new()),
493 fallback_resolver: RwLock::new(None),
494 errors: RwLock::new(Vec::new()),
495 load_error: RwLock::new(None),
496 plaintext_warned: OnceLock::new(),
497 }
498 }
499
500 pub fn default_path() -> Option<PathBuf> {
502 FileAuthStorage::default_path()
503 }
504
505 pub fn set_runtime_key(&self, provider: &str, api_key: String) {
511 self.runtime_overrides
512 .write()
513 .insert(provider.to_string(), api_key);
514 }
515
516 pub fn remove_runtime_key(&self, provider: &str) {
518 self.runtime_overrides.write().remove(provider);
519 }
520
521 pub fn set_fallback_resolver(&self, resolver: Arc<dyn FallbackResolver>) {
528 *self.fallback_resolver.write() = Some(resolver);
529 }
530
531 pub fn clear_fallback_resolver(&self) {
533 *self.fallback_resolver.write() = None;
534 }
535
536 pub fn has_auth(&self, provider: &str) -> bool {
542 if self.runtime_overrides.read().contains_key(provider) {
543 return true;
544 }
545 if self.credentials.read().contains_key(provider) {
546 return true;
547 }
548 if let Some(ref resolver) = *self.fallback_resolver.read() {
549 if resolver.resolve(provider).is_some() {
550 return true;
551 }
552 }
553 false
554 }
555
556 pub fn get_status(&self, provider: &str) -> AuthStatus {
558 if self.runtime_overrides.read().contains_key(provider) {
559 return AuthStatus {
560 configured: false,
561 source: Some("runtime".to_string()),
562 label: Some("--api-key".to_string()),
563 };
564 }
565
566 if let Some(cred) = self.credentials.read().get(provider) {
567 return AuthStatus {
568 configured: true,
569 source: Some("stored".to_string()),
570 label: Some(cred.type_name().to_string()),
571 };
572 }
573
574 if let Some(ref resolver) = *self.fallback_resolver.read() {
575 if resolver.resolve(provider).is_some() {
576 return AuthStatus {
577 configured: false,
578 source: Some("fallback".to_string()),
579 label: Some("custom provider config".to_string()),
580 };
581 }
582 }
583
584 AuthStatus {
585 configured: false,
586 source: None,
587 label: None,
588 }
589 }
590
591 pub fn get_api_key(&self, provider: &str) -> Option<String> {
600 self.get_api_key_with_options(provider, true)
601 }
602
603 pub fn get_api_key_with_options(
605 &self,
606 provider: &str,
607 include_fallback: bool,
608 ) -> Option<String> {
609 if let Some(key) = self.runtime_overrides.read().get(provider) {
611 return Some(key.clone());
612 }
613
614 if let Some(cred) = self.credentials.read().get(provider) {
616 return match cred {
617 AuthCredential::ApiKey { key } => Some(key.clone()),
618 AuthCredential::OAuth {
619 access_token,
620 expires_at,
621 ..
622 } => {
623 if *expires_at > now_secs() {
624 Some(access_token.clone())
625 } else {
626 None
628 }
629 }
630 AuthCredential::Session {
631 token, expires_at, ..
632 } => {
633 if *expires_at == 0 || *expires_at > now_secs() {
634 Some(token.clone())
635 } else {
636 None
637 }
638 }
639 };
640 }
641
642 if let Some(builtin) = oxi_ai::register_builtins::get_builtin_provider(provider) {
647 let env_key = builtin.env_key;
648 let credentials = self.credentials.read();
649 for other in oxi_ai::register_builtins::get_builtin_providers() {
650 if other.name == provider {
651 continue; }
653 if other.env_key == env_key {
654 if let Some(cred) = credentials.get(other.name) {
655 return match cred {
656 AuthCredential::ApiKey { key } => Some(key.clone()),
657 AuthCredential::OAuth {
658 access_token,
659 expires_at,
660 ..
661 } => {
662 if *expires_at > now_secs() {
663 Some(access_token.clone())
664 } else {
665 None
666 }
667 }
668 AuthCredential::Session {
669 token, expires_at, ..
670 } => {
671 if *expires_at == 0 || *expires_at > now_secs() {
672 Some(token.clone())
673 } else {
674 None
675 }
676 }
677 };
678 }
679 }
680 }
681 }
682
683 if include_fallback {
685 if let Some(ref resolver) = *self.fallback_resolver.read() {
686 return resolver.resolve(provider);
687 }
688 }
689
690 None
691 }
692
693 pub fn set_api_key(&self, provider: &str, key: String) {
699 self.credentials
700 .write()
701 .insert(provider.to_string(), AuthCredential::ApiKey { key });
702 if let Err(e) = self.persist() {
703 tracing::warn!("Failed to persist API key for '{}': {}", provider, e);
704 }
705 }
706
707 pub fn set_oauth(
709 &self,
710 provider: &str,
711 access_token: String,
712 refresh_token: Option<String>,
713 expires_at: u64,
714 ) {
715 self.set_oauth_full(
716 provider,
717 access_token,
718 refresh_token,
719 expires_at,
720 None,
721 None,
722 );
723 }
724
725 pub fn set_oauth_full(
727 &self,
728 provider: &str,
729 access_token: String,
730 refresh_token: Option<String>,
731 expires_at: u64,
732 scopes: Option<String>,
733 provider_data: Option<serde_json::Value>,
734 ) {
735 self.credentials.write().insert(
736 provider.to_string(),
737 AuthCredential::OAuth {
738 access_token,
739 refresh_token,
740 expires_at,
741 scopes,
742 provider_data,
743 },
744 );
745 if let Err(e) = self.persist() {
746 tracing::warn!("Failed to persist OAuth token for '{}': {}", provider, e);
747 }
748 }
749
750 pub fn set_session(
752 &self,
753 provider: &str,
754 token: String,
755 expires_at: u64,
756 metadata: Option<serde_json::Value>,
757 ) {
758 self.credentials.write().insert(
759 provider.to_string(),
760 AuthCredential::Session {
761 token,
762 expires_at,
763 metadata,
764 },
765 );
766 if let Err(e) = self.persist() {
767 tracing::warn!("Failed to persist session for '{}': {}", provider, e);
768 }
769 }
770
771 pub fn update_oauth_tokens(
773 &self,
774 provider: &str,
775 new_access_token: String,
776 new_refresh_token: Option<String>,
777 new_expires_at: u64,
778 ) -> AuthResult<()> {
779 let mut creds = self.credentials.write();
780 let cred = creds
781 .get_mut(provider)
782 .ok_or_else(|| AuthError::NotFound(provider.to_string()))?;
783
784 match cred {
785 AuthCredential::OAuth {
786 access_token,
787 refresh_token,
788 expires_at,
789 ..
790 } => {
791 *access_token = new_access_token;
792 *refresh_token = new_refresh_token;
793 *expires_at = new_expires_at;
794 }
795 _ => {
796 return Err(AuthError::InvalidFormat(format!(
797 "Provider '{}' does not have OAuth credentials",
798 provider
799 )));
800 }
801 }
802
803 drop(creds);
804 if let Err(e) = self.persist() {
805 tracing::warn!(
806 "Failed to persist OAuth token update for '{}': {}",
807 provider,
808 e
809 );
810 }
811 Ok(())
812 }
813
814 pub fn get(&self, provider: &str) -> Option<AuthCredential> {
820 self.credentials.read().get(provider).cloned()
821 }
822
823 pub fn get_oauth_credential(&self, provider: &str) -> Option<AuthCredential> {
825 self.credentials.read().get(provider).cloned()
826 }
827
828 pub fn has_oauth_with_refresh(&self, provider: &str) -> bool {
830 if let Some(cred) = self.credentials.read().get(provider) {
831 matches!(
832 cred,
833 AuthCredential::OAuth {
834 refresh_token: Some(_),
835 ..
836 }
837 )
838 } else {
839 false
840 }
841 }
842
843 pub fn set(&self, provider: &str, credential: AuthCredential) {
849 self.credentials
850 .write()
851 .insert(provider.to_string(), credential);
852 if let Err(e) = self.persist() {
853 tracing::warn!("Failed to persist credential for '{}': {}", provider, e);
854 }
855 }
856
857 pub fn remove(&self, provider: &str) {
859 self.credentials.write().remove(provider);
860 if let Err(e) = self.persist() {
861 tracing::warn!("Failed to persist after removing '{}': {}", provider, e);
862 }
863 }
864
865 pub fn list_providers(&self) -> Vec<String> {
867 self.credentials.read().keys().cloned().collect()
868 }
869
870 pub fn has(&self, provider: &str) -> bool {
872 self.credentials.read().contains_key(provider)
873 }
874
875 pub fn get_all(&self) -> HashMap<String, AuthCredential> {
877 self.credentials.read().clone()
878 }
879
880 pub fn clear(&self) {
882 self.credentials.write().clear();
883 if let Err(e) = self.persist() {
884 tracing::warn!("Failed to persist after clearing credentials: {}", e);
885 }
886 }
887
888 pub fn reload(&self) {
894 if let Some(ref storage) = self.file_storage {
895 match storage.read() {
896 Ok(Some(content)) => {
897 if let Ok(creds) = serde_json::from_str(&content) {
898 *self.credentials.write() = creds;
899 }
900 *self.load_error.write() = None;
901 }
902 Ok(None) => {
903 self.credentials.write().clear();
904 *self.load_error.write() = None;
905 }
906 Err(e) => {
907 *self.load_error.write() = Some(e);
908 self.record_error(AuthError::ReadError(
909 "Failed to reload auth storage".to_string(),
910 ));
911 }
912 }
913 }
914 }
915
916 fn persist(&self) -> Result<(), String> {
918 if let Some(ref storage) = self.file_storage {
919 let creds = self.credentials.read();
920 if let Ok(json) = serde_json::to_string_pretty(&*creds) {
921 #[cfg(not(feature = "keyring"))]
923 {
924 self.plaintext_warned.get_or_init(|| {
925 tracing::warn!(
926 "Auth credentials are stored in plaintext. \
927 Enable the 'keyring' feature for secure OS-level storage."
928 );
929 });
930 }
931
932 if let Err(e) = storage.write(&json) {
933 tracing::error!("Failed to persist auth storage: {}", e);
934 self.record_error(e);
935 return Err("persist failed".to_string());
936 }
937 }
938 }
939 Ok(())
940 }
941
942 fn record_error(&self, error: AuthError) {
948 self.errors.write().push(error);
949 }
950
951 pub fn drain_errors(&self) -> Vec<AuthError> {
953 let mut errors = self.errors.write();
954 std::mem::take(&mut *errors)
955 }
956
957 pub fn load_error(&self) -> Option<AuthError> {
959 self.load_error.read().clone()
960 }
961
962 pub fn validate_all(&self) -> Vec<(String, CredentialValidationError)> {
968 let creds = self.credentials.read();
969 let mut results = Vec::new();
970 for (provider, cred) in creds.iter() {
971 if let Err(e) = cred.validate() {
972 results.push((provider.clone(), e));
973 }
974 }
975 results
976 }
977
978 pub fn validate(&self, provider: &str) -> Result<(), CredentialValidationError> {
980 let creds = self.credentials.read();
981 let cred = creds.get(provider).ok_or_else(|| {
982 CredentialValidationError::EmptyField(format!(
983 "no credential for provider '{}'",
984 provider
985 ))
986 })?;
987 cred.validate()
988 }
989
990 pub fn configured_providers(&self) -> Vec<String> {
996 let mut providers: Vec<String> = self.credentials.read().keys().cloned().collect();
997 providers.sort();
998 providers
999 }
1000
1001 pub fn has_multiple_providers(&self) -> bool {
1003 self.credentials.read().len() > 1
1004 }
1005
1006 pub fn primary_provider(&self) -> Option<String> {
1008 let creds = self.credentials.read();
1009 creds.keys().next().cloned()
1010 }
1011
1012 pub fn migrate_provider(&self, from: &str, to: &str) -> AuthResult<()> {
1014 let mut creds = self.credentials.write();
1015 let cred = creds
1016 .remove(from)
1017 .ok_or_else(|| AuthError::NotFound(from.to_string()))?;
1018 creds.insert(to.to_string(), cred);
1019 drop(creds);
1020 let _ = self.persist();
1021 Ok(())
1022 }
1023}
1024
1025impl Default for AuthStorage {
1026 fn default() -> Self {
1027 Self::new()
1028 }
1029}
1030
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// Returns `0` if the system clock reports a time before the epoch.
fn now_secs() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
1041
/// Thin wrappers around the OS keyring.
///
/// With the `keyring` feature enabled these call the `keyring` crate. Without
/// it, `get_keyring_secret` returns `None` and the other two functions return
/// `AuthError::KeyringError("Keyring support not compiled")`.
// NOTE(review): presumably allowed because `keyring` is not always declared
// as a Cargo feature, which would trigger `unexpected_cfgs` — confirm.
#[allow(unexpected_cfgs)]
pub mod keyring_support {
    use super::*;

    /// Returns the secret stored for (`service`, `account`).
    ///
    /// Any keyring error (entry creation or lookup) is discarded and reported as `None`.
    #[cfg(feature = "keyring")]
    pub fn get_keyring_secret(service: &str, account: &str) -> Option<String> {
        use keyring::Entry;
        Entry::new(service, account)
            .ok()
            .and_then(|entry| entry.get_password().ok())
    }

    /// Stores `secret` for (`service`, `account`).
    ///
    /// # Errors
    ///
    /// `AuthError::KeyringError` if the entry cannot be created or written.
    #[cfg(feature = "keyring")]
    pub fn set_keyring_secret(service: &str, account: &str, secret: &str) -> AuthResult<()> {
        use keyring::Entry;
        Entry::new(service, account)
            .map_err(|e| AuthError::KeyringError(e.to_string()))?
            .set_password(secret)
            .map_err(|e| AuthError::KeyringError(e.to_string()))
    }

    /// Deletes the secret stored for (`service`, `account`).
    ///
    /// # Errors
    ///
    /// `AuthError::KeyringError` if the entry cannot be created or deleted.
    #[cfg(feature = "keyring")]
    pub fn delete_keyring_secret(service: &str, account: &str) -> AuthResult<()> {
        use keyring::Entry;
        Entry::new(service, account)
            .map_err(|e| AuthError::KeyringError(e.to_string()))?
            .delete_credential()
            .map_err(|e| AuthError::KeyringError(e.to_string()))
    }

    /// Keyring support not compiled in: there is never a stored secret.
    #[cfg(not(feature = "keyring"))]
    pub fn get_keyring_secret(_service: &str, _account: &str) -> Option<String> {
        None
    }

    /// Keyring support not compiled in: always fails.
    #[cfg(not(feature = "keyring"))]
    pub fn set_keyring_secret(_service: &str, _account: &str, _secret: &str) -> AuthResult<()> {
        Err(AuthError::KeyringError(
            "Keyring support not compiled".to_string(),
        ))
    }

    /// Keyring support not compiled in: always fails.
    #[cfg(not(feature = "keyring"))]
    pub fn delete_keyring_secret(_service: &str, _account: &str) -> AuthResult<()> {
        Err(AuthError::KeyringError(
            "Keyring support not compiled".to_string(),
        ))
    }
}
1109
1110pub fn shared_auth_storage() -> Arc<AuthStorage> {
1120 static STORAGE: OnceLock<Arc<AuthStorage>> = OnceLock::new();
1121 STORAGE
1122 .get_or_init(|| {
1123 let storage = Arc::new(AuthStorage::new());
1124 storage.set_fallback_resolver(Arc::new(EnvVarFallbackResolver));
1127 storage
1128 })
1129 .clone()
1130}
1131
1132#[cfg(test)]
1137mod tests {
1138 use super::*;
1139
1140 #[test]
1141 fn test_auth_storage_new() {
1142 let storage = AuthStorage::in_memory();
1143 assert!(!storage.has("anthropic"));
1144 }
1145
1146 #[test]
1147 fn test_set_and_get_api_key() {
1148 let storage = AuthStorage::in_memory();
1149 storage.set_api_key("anthropic", "sk-test123".to_string());
1150 assert!(storage.has("anthropic"));
1151 assert_eq!(
1152 storage.get_api_key("anthropic"),
1153 Some("sk-test123".to_string())
1154 );
1155 }
1156
1157 #[test]
1158 fn test_runtime_override() {
1159 let storage = AuthStorage::in_memory();
1160 storage.set_api_key("anthropic", "stored-key".to_string());
1161 storage.set_runtime_key("anthropic", "runtime-key".to_string());
1162
1163 assert_eq!(
1165 storage.get_api_key("anthropic"),
1166 Some("runtime-key".to_string())
1167 );
1168 }
1169
1170 #[test]
1171 fn test_remove_credential() {
1172 let storage = AuthStorage::in_memory();
1173 storage.set_api_key("anthropic", "sk-test123".to_string());
1174 assert!(storage.has("anthropic"));
1175
1176 storage.remove("anthropic");
1177 assert!(!storage.has("anthropic"));
1178 }
1179
1180 #[test]
1181 fn test_auth_status() {
1182 let storage = AuthStorage::in_memory();
1183 storage.set_api_key("anthropic", "sk-test123".to_string());
1184
1185 let status = storage.get_status("anthropic");
1186 assert!(status.configured);
1187 assert_eq!(status.source, Some("stored".to_string()));
1188 assert_eq!(status.label, Some("api_key".to_string()));
1189 }
1190
1191 #[test]
1192 fn test_auth_status_display() {
1193 let status = AuthStatus {
1194 configured: true,
1195 source: Some("stored".to_string()),
1196 label: Some("api_key".to_string()),
1197 };
1198 let display = format!("{}", status);
1199 assert_eq!(display, "stored (api_key)");
1200
1201 let no_config = AuthStatus {
1202 configured: false,
1203 source: None,
1204 label: None,
1205 };
1206 assert_eq!(format!("{}", no_config), "not configured");
1207 }
1208
1209 #[test]
1210 fn test_list_providers() {
1211 let storage = AuthStorage::in_memory();
1212 storage.set_api_key("anthropic", "key1".to_string());
1213 storage.set_api_key("openai", "key2".to_string());
1214
1215 let providers = storage.list_providers();
1216 assert!(providers.contains(&"anthropic".to_string()));
1217 assert!(providers.contains(&"openai".to_string()));
1218 }
1219
1220 #[test]
1221 fn test_oauth_credential() {
1222 let storage = AuthStorage::in_memory();
1223 storage.set_oauth(
1224 "provider",
1225 "access123".to_string(),
1226 Some("refresh456".to_string()),
1227 u64::MAX,
1228 );
1229
1230 assert!(storage.has("provider"));
1231 assert_eq!(
1232 storage.get_api_key("provider"),
1233 Some("access123".to_string())
1234 );
1235 }
1236
1237 #[test]
1238 fn test_expired_oauth_token() {
1239 let storage = AuthStorage::in_memory();
1240 storage.set_oauth("provider", "access123".to_string(), None, 0);
1242
1243 let key = storage.get_api_key("provider");
1245 assert!(key.is_none());
1246 }
1247
1248 #[test]
1249 fn test_get_all_credentials() {
1250 let storage = AuthStorage::in_memory();
1251 storage.set_api_key("anthropic", "key1".to_string());
1252 storage.set_api_key("openai", "key2".to_string());
1253
1254 let all = storage.get_all();
1255 assert_eq!(all.len(), 2);
1256 }
1257
1258 #[test]
1259 fn test_clear() {
1260 let storage = AuthStorage::in_memory();
1261 storage.set_api_key("anthropic", "key".to_string());
1262 assert!(storage.has("anthropic"));
1263
1264 storage.clear();
1265 assert!(!storage.has("anthropic"));
1266 }
1267
1268 #[test]
1269 fn test_remove_runtime_key() {
1270 let storage = AuthStorage::in_memory();
1271 storage.set_api_key("anthropic", "stored".to_string());
1272 storage.set_runtime_key("anthropic", "runtime".to_string());
1273
1274 assert_eq!(
1275 storage.get_api_key("anthropic"),
1276 Some("runtime".to_string())
1277 );
1278
1279 storage.remove_runtime_key("anthropic");
1280 assert_eq!(storage.get_api_key("anthropic"), Some("stored".to_string()));
1281 }
1282
1283 #[test]
1284 fn test_auth_credential_is_expired() {
1285 let api_key_cred = AuthCredential::ApiKey {
1287 key: "test".to_string(),
1288 };
1289 assert!(!api_key_cred.is_expired());
1290
1291 let future_time = now_secs() + 3600;
1293 let oauth_cred = AuthCredential::OAuth {
1294 access_token: "token".to_string(),
1295 refresh_token: Some("refresh".to_string()),
1296 expires_at: future_time,
1297 scopes: None,
1298 provider_data: None,
1299 };
1300 assert!(!oauth_cred.is_expired());
1301
1302 let oauth_cred_expired = AuthCredential::OAuth {
1304 access_token: "token".to_string(),
1305 refresh_token: Some("refresh".to_string()),
1306 expires_at: 0,
1307 scopes: None,
1308 provider_data: None,
1309 };
1310 assert!(oauth_cred_expired.is_expired());
1311 }
1312
1313 #[test]
1314 fn test_auth_credential_needs_refresh() {
1315 let future_time = now_secs() + 120; let oauth_cred = AuthCredential::OAuth {
1319 access_token: "token".to_string(),
1320 refresh_token: Some("refresh".to_string()),
1321 expires_at: future_time,
1322 scopes: None,
1323 provider_data: None,
1324 };
1325 assert!(!oauth_cred.needs_refresh());
1326
1327 let soon = now_secs() + 30;
1329 let oauth_soon = AuthCredential::OAuth {
1330 access_token: "token".to_string(),
1331 refresh_token: Some("refresh".to_string()),
1332 expires_at: soon,
1333 scopes: None,
1334 provider_data: None,
1335 };
1336 assert!(oauth_soon.needs_refresh());
1337
1338 let no_refresh = AuthCredential::OAuth {
1340 access_token: "token".to_string(),
1341 refresh_token: None,
1342 expires_at: future_time,
1343 scopes: None,
1344 provider_data: None,
1345 };
1346 assert!(!no_refresh.needs_refresh());
1347
1348 let api_key_cred = AuthCredential::ApiKey {
1350 key: "test".to_string(),
1351 };
1352 assert!(!api_key_cred.needs_refresh());
1353 }
1354
1355 #[test]
1356 fn test_auth_credential_access_token() {
1357 let future_time = now_secs() + 3600;
1358
1359 let oauth_cred = AuthCredential::OAuth {
1360 access_token: "valid_token".to_string(),
1361 refresh_token: Some("refresh".to_string()),
1362 expires_at: future_time,
1363 scopes: None,
1364 provider_data: None,
1365 };
1366 assert_eq!(oauth_cred.access_token(), Some("valid_token"));
1367
1368 let expired_cred = AuthCredential::OAuth {
1370 access_token: "expired_token".to_string(),
1371 refresh_token: Some("refresh".to_string()),
1372 expires_at: 0,
1373 scopes: None,
1374 provider_data: None,
1375 };
1376 assert!(expired_cred.access_token().is_none());
1377
1378 let api_key_cred = AuthCredential::ApiKey {
1380 key: "api_key_token".to_string(),
1381 };
1382 assert!(api_key_cred.access_token().is_none());
1383 }
1384
1385 #[test]
1386 fn test_get_oauth_credential() {
1387 let storage = AuthStorage::in_memory();
1388 storage.set_oauth(
1389 "provider",
1390 "access".to_string(),
1391 Some("refresh".to_string()),
1392 u64::MAX,
1393 );
1394
1395 let cred = storage.get_oauth_credential("provider");
1396 assert!(cred.is_some());
1397 assert!(matches!(cred.unwrap(), AuthCredential::OAuth { .. }));
1398 }
1399
1400 #[test]
1401 fn test_has_oauth_with_refresh() {
1402 let storage = AuthStorage::in_memory();
1403
1404 storage.set_oauth(
1406 "with_refresh",
1407 "access".to_string(),
1408 Some("refresh".to_string()),
1409 u64::MAX,
1410 );
1411 assert!(storage.has_oauth_with_refresh("with_refresh"));
1412
1413 storage.set_oauth("without_refresh", "access".to_string(), None, u64::MAX);
1415 assert!(!storage.has_oauth_with_refresh("without_refresh"));
1416
1417 storage.set_api_key("apikey_provider", "key".to_string());
1419 assert!(!storage.has_oauth_with_refresh("apikey_provider"));
1420 }
1421
1422 #[test]
1423 fn test_set_oauth_full() {
1424 let storage = AuthStorage::in_memory();
1425 storage.set_oauth_full(
1426 "provider",
1427 "access_token".to_string(),
1428 Some("refresh_token".to_string()),
1429 3600,
1430 Some("read write".to_string()),
1431 Some(serde_json::json!({"extra": "data"})),
1432 );
1433
1434 let cred = storage.get_oauth_credential("provider");
1435 assert!(cred.is_some());
1436 if let AuthCredential::OAuth {
1437 scopes,
1438 provider_data,
1439 ..
1440 } = cred.unwrap()
1441 {
1442 assert_eq!(scopes, Some("read write".to_string()));
1443 assert!(provider_data.is_some());
1444 } else {
1445 panic!("Expected OAuth credential");
1446 }
1447 }
1448
1449 #[test]
1450 fn test_session_token() {
1451 let storage = AuthStorage::in_memory();
1452 storage.set_session(
1453 "browser",
1454 "session-token-123".to_string(),
1455 0, Some(serde_json::json!({"user": "test"})),
1457 );
1458
1459 assert!(storage.has("browser"));
1460 assert_eq!(
1461 storage.get_api_key("browser"),
1462 Some("session-token-123".to_string())
1463 );
1464
1465 let cred = storage.get("browser").unwrap();
1466 assert!(matches!(cred, AuthCredential::Session { .. }));
1467 assert!(cred.access_token().is_some());
1468 }
1469
1470 #[test]
1471 fn test_session_token_expired() {
1472 let storage = AuthStorage::in_memory();
1473 storage.set_session("browser", "session-token".to_string(), 1, None);
1474
1475 assert!(storage.get_api_key("browser").is_none());
1477 }
1478
1479 #[test]
1480 fn test_credential_validation() {
1481 let valid = AuthCredential::ApiKey {
1483 key: "sk-valid".to_string(),
1484 };
1485 assert!(valid.validate().is_ok());
1486
1487 let empty = AuthCredential::ApiKey {
1489 key: "".to_string(),
1490 };
1491 assert!(empty.validate().is_err());
1492
1493 let placeholder = AuthCredential::ApiKey {
1495 key: "your-api-key-here".to_string(),
1496 };
1497 assert!(placeholder.validate().is_err());
1498
1499 let valid_oauth = AuthCredential::OAuth {
1501 access_token: "token".to_string(),
1502 refresh_token: None,
1503 expires_at: now_secs() + 3600,
1504 scopes: None,
1505 provider_data: None,
1506 };
1507 assert!(valid_oauth.validate().is_ok());
1508
1509 let invalid_oauth = AuthCredential::OAuth {
1511 access_token: "".to_string(),
1512 refresh_token: None,
1513 expires_at: 1000,
1514 scopes: None,
1515 provider_data: None,
1516 };
1517 assert!(invalid_oauth.validate().is_err());
1518 }
1519
1520 #[test]
1521 fn test_validate_all() {
1522 let storage = AuthStorage::in_memory();
1523 storage.set_api_key("valid", "sk-good".to_string());
1524 storage.set_api_key("empty", "".to_string());
1525
1526 let errors = storage.validate_all();
1527 assert_eq!(errors.len(), 1);
1528 assert_eq!(errors[0].0, "empty");
1529 }
1530
1531 #[test]
1532 fn test_update_oauth_tokens() {
1533 let storage = AuthStorage::in_memory();
1534 storage.set_oauth(
1535 "provider",
1536 "old-access".to_string(),
1537 Some("old-refresh".to_string()),
1538 now_secs() + 3600,
1539 );
1540
1541 storage
1542 .update_oauth_tokens(
1543 "provider",
1544 "new-access".to_string(),
1545 Some("new-refresh".to_string()),
1546 now_secs() + 7200,
1547 )
1548 .unwrap();
1549
1550 let key = storage.get_api_key("provider");
1551 assert_eq!(key, Some("new-access".to_string()));
1552 }
1553
1554 #[test]
1555 fn test_update_oauth_tokens_wrong_type() {
1556 let storage = AuthStorage::in_memory();
1557 storage.set_api_key("provider", "key".to_string());
1558
1559 let result = storage.update_oauth_tokens(
1560 "provider",
1561 "new-access".to_string(),
1562 None,
1563 now_secs() + 3600,
1564 );
1565 assert!(result.is_err());
1566 }
1567
1568 #[test]
1569 fn test_migrate_provider() {
1570 let storage = AuthStorage::in_memory();
1571 storage.set_api_key("old-provider", "key123".to_string());
1572 storage
1573 .migrate_provider("old-provider", "new-provider")
1574 .unwrap();
1575
1576 assert!(!storage.has("old-provider"));
1577 assert!(storage.has("new-provider"));
1578 assert_eq!(
1579 storage.get_api_key("new-provider"),
1580 Some("key123".to_string())
1581 );
1582 }
1583
1584 #[test]
1585 fn test_migrate_provider_not_found() {
1586 let storage = AuthStorage::in_memory();
1587 let result = storage.migrate_provider("nonexistent", "target");
1588 assert!(result.is_err());
1589 }
1590
1591 #[test]
1592 fn test_error_draining() {
1593 let storage = AuthStorage::in_memory();
1594 let errors = storage.drain_errors();
1595 assert!(errors.is_empty());
1596 }
1597
1598 #[test]
1599 fn test_fallback_resolver() {
1600 let storage = AuthStorage::in_memory();
1601 storage.set_fallback_resolver(Arc::new(FnFallbackResolver::new(Box::new(|provider| {
1602 if provider == "custom" {
1603 Some("custom-key-from-config".to_string())
1604 } else {
1605 None
1606 }
1607 }))));
1608
1609 assert_eq!(
1610 storage.get_api_key("custom"),
1611 Some("custom-key-from-config".to_string())
1612 );
1613 assert!(storage.get_api_key("unknown").is_none());
1614
1615 storage.clear_fallback_resolver();
1617 assert!(storage.get_api_key("custom").is_none());
1618 }
1619
1620 #[test]
1621 fn test_get_api_key_with_options() {
1622 let storage = AuthStorage::in_memory();
1623 storage.set_fallback_resolver(Arc::new(FnFallbackResolver::new(Box::new(|_| {
1624 Some("fallback-key".to_string())
1625 }))));
1626
1627 assert_eq!(
1629 storage.get_api_key_with_options("test", true),
1630 Some("fallback-key".to_string())
1631 );
1632
1633 assert!(storage.get_api_key_with_options("test", false).is_none());
1635 }
1636
1637 #[test]
1638 fn test_configured_providers() {
1639 let storage = AuthStorage::in_memory();
1640 storage.set_api_key("openai", "key".to_string());
1641 storage.set_api_key("anthropic", "key".to_string());
1642
1643 let providers = storage.configured_providers();
1644 assert!(providers.len() >= 2);
1645 let mut sorted = providers.clone();
1647 sorted.sort();
1648 assert_eq!(providers, sorted);
1649 }
1650
1651 #[test]
1652 fn test_has_multiple_providers() {
1653 let storage = AuthStorage::in_memory();
1654 assert!(!storage.has_multiple_providers());
1655
1656 storage.set_api_key("openai", "key1".to_string());
1657 assert!(!storage.has_multiple_providers());
1658
1659 storage.set_api_key("anthropic", "key2".to_string());
1660 assert!(storage.has_multiple_providers());
1661 }
1662
1663 #[test]
1664 fn test_set_and_get_credential() {
1665 let storage = AuthStorage::in_memory();
1666 let cred = AuthCredential::Session {
1667 token: "abc".to_string(),
1668 expires_at: 0,
1669 metadata: None,
1670 };
1671 storage.set("custom", cred);
1672 let retrieved = storage.get("custom");
1673 assert!(retrieved.is_some());
1674 assert!(matches!(retrieved.unwrap(), AuthCredential::Session { .. }));
1675 }
1676
1677 #[test]
1678 fn test_credential_type_name() {
1679 assert_eq!(
1680 AuthCredential::ApiKey {
1681 key: "k".to_string()
1682 }
1683 .type_name(),
1684 "api_key"
1685 );
1686 assert_eq!(
1687 AuthCredential::OAuth {
1688 access_token: "t".to_string(),
1689 refresh_token: None,
1690 expires_at: 0,
1691 scopes: None,
1692 provider_data: None,
1693 }
1694 .type_name(),
1695 "oauth"
1696 );
1697 assert_eq!(
1698 AuthCredential::Session {
1699 token: "t".to_string(),
1700 expires_at: 0,
1701 metadata: None,
1702 }
1703 .type_name(),
1704 "session"
1705 );
1706 }
1707}