1use parking_lot::RwLock;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::path::PathBuf;
11use std::sync::{Arc, OnceLock};
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
19#[serde(tag = "type", rename_all = "snake_case")]
20pub enum AuthCredential {
21 ApiKey {
23 key: String,
25 },
26 OAuth {
28 access_token: String,
30 refresh_token: Option<String>,
32 expires_at: u64,
34 #[serde(default)]
36 scopes: Option<String>,
37 #[serde(default)]
39 provider_data: Option<serde_json::Value>,
40 },
41 Session {
43 token: String,
45 #[serde(default)]
47 expires_at: u64,
48 #[serde(default)]
50 metadata: Option<serde_json::Value>,
51 },
52}
53
54impl AuthCredential {
55 pub fn is_expired(&self) -> bool {
57 match self {
58 AuthCredential::OAuth { expires_at, .. } => {
59 let now = now_secs();
60 *expires_at < now
61 }
62 AuthCredential::Session { expires_at, .. } => {
63 if *expires_at == 0 {
64 return false; }
66 *expires_at <= now_secs()
67 }
68 AuthCredential::ApiKey { .. } => false,
69 }
70 }
71
72 pub fn needs_refresh(&self) -> bool {
74 match self {
75 AuthCredential::OAuth {
76 expires_at,
77 refresh_token,
78 ..
79 } => {
80 let now = now_secs();
81 refresh_token.is_some() && *expires_at <= now + 60
82 }
83 AuthCredential::Session { .. } => false,
84 AuthCredential::ApiKey { .. } => false,
85 }
86 }
87
88 pub fn access_token(&self) -> Option<&str> {
90 match self {
91 AuthCredential::OAuth { access_token, .. } if !self.is_expired() => Some(access_token),
92 AuthCredential::Session { token, .. } if !self.is_expired() => Some(token),
93 _ => None,
94 }
95 }
96
97 pub fn type_name(&self) -> &'static str {
99 match self {
100 AuthCredential::ApiKey { .. } => "api_key",
101 AuthCredential::OAuth { .. } => "oauth",
102 AuthCredential::Session { .. } => "session",
103 }
104 }
105
106 pub fn validate(&self) -> Result<(), CredentialValidationError> {
108 match self {
109 AuthCredential::ApiKey { key } => {
110 if key.is_empty() {
111 return Err(CredentialValidationError::EmptyField("key".to_string()));
112 }
113 if key == "your-api-key-here" || key == "xxx" {
115 return Err(CredentialValidationError::PlaceholderValue(key.clone()));
116 }
117 Ok(())
118 }
119 AuthCredential::OAuth {
120 access_token,
121 expires_at,
122 ..
123 } => {
124 if access_token.is_empty() {
125 return Err(CredentialValidationError::EmptyField(
126 "access_token".to_string(),
127 ));
128 }
129 if *expires_at == 0 {
130 return Err(CredentialValidationError::InvalidExpiry);
131 }
132 Ok(())
133 }
134 AuthCredential::Session { token, .. } => {
135 if token.is_empty() {
136 return Err(CredentialValidationError::EmptyField("token".to_string()));
137 }
138 Ok(())
139 }
140 }
141 }
142}
143
/// Reasons a credential fails [`AuthCredential::validate`].
#[derive(Debug, Clone, thiserror::Error)]
pub enum CredentialValidationError {
    /// A required field is empty; the payload names the field.
    ///
    /// NOTE(review): `AuthStorage::validate` also uses this variant for an unknown
    /// provider, with a full sentence as the payload.
    #[error("Field '{0}' must not be empty")]
    EmptyField(String),
    /// The API key is a known placeholder; the payload is the offending value.
    #[error("Placeholder value detected: '{0}'")]
    PlaceholderValue(String),
    /// An OAuth credential has `expires_at == 0`.
    #[error("Invalid expiry timestamp")]
    InvalidExpiry,
}
157
/// Summary of where authentication for a provider comes from, as produced by
/// `AuthStorage::get_status`.
#[derive(Debug, Clone)]
pub struct AuthStatus {
    /// `true` only when a credential is stored for the provider. Runtime keys
    /// and fallback-resolved keys report `false` even though they are usable.
    pub configured: bool,
    /// Where the key comes from (`"runtime"`, `"stored"`, `"fallback"`), or
    /// `None` when nothing is available.
    pub source: Option<String>,
    /// Extra detail for display, e.g. `"--api-key"` or the credential type name.
    pub label: Option<String>,
}
172
173impl std::fmt::Display for AuthStatus {
174 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
175 match (&self.source, &self.label) {
176 (Some(source), Some(label)) => write!(f, "{} ({})", source, label),
177 (Some(source), None) => write!(f, "{}", source),
178 (None, Some(label)) => write!(f, "{}", label),
179 (None, None) => write!(f, "not configured"),
180 }
181 }
182}
183
/// Result type used by credential storage operations.
pub type AuthResult<T> = Result<T, AuthError>;
190
/// Errors produced by storage backends and `AuthStorage`.
///
/// Each variant carries a human-readable message. The enum is `Clone` so that
/// `AuthStorage::load_error` can hand out copies of the stored error.
#[derive(Debug, Clone, thiserror::Error)]
pub enum AuthError {
    /// Reading the backing store failed.
    #[error("Failed to read auth storage: {0}")]
    ReadError(String),
    /// Writing, creating the parent directory of, or deleting the backing store failed.
    #[error("Failed to write auth storage: {0}")]
    WriteError(String),
    /// No credential exists for the provider named in the payload.
    #[error("Credential not found: {0}")]
    NotFound(String),
    /// A credential has the wrong shape for the requested operation
    /// (e.g. updating OAuth tokens on a non-OAuth credential).
    #[error("Invalid credential format: {0}")]
    InvalidFormat(String),
    /// An OS keyring operation failed, or keyring support is not compiled in.
    #[error("Keyring error: {0}")]
    KeyringError(String),
    /// Credential validation failed.
    ///
    /// NOTE(review): not constructed anywhere in the visible part of this file.
    #[error("Credential validation failed: {0}")]
    ValidationFailed(String),
}
213
/// Raw persistence layer for the serialized credential document.
///
/// `AuthStorage` keeps the backend in an `Arc<dyn AuthStorageBackend>` and may
/// itself live in a process-wide static, hence the `Send + Sync` bound.
pub trait AuthStorageBackend: Send + Sync {
    /// Returns the whole stored document, or `Ok(None)` when nothing is stored.
    fn read(&self) -> AuthResult<Option<String>>;
    /// Replaces the whole stored document with `data`.
    fn write(&self, data: &str) -> AuthResult<()>;
    /// Removes the stored document.
    fn delete(&self) -> AuthResult<()>;
}
227
/// Backend that stores the credential document as a plaintext JSON file.
pub struct FileAuthStorage {
    /// Location of the credential file.
    path: PathBuf,
    /// Last document read from or written to `path`.
    ///
    /// NOTE(review): written but never read in the visible part of this file.
    cache: RwLock<Option<String>>,
}
237
238impl FileAuthStorage {
239 pub fn new(path: PathBuf) -> Self {
241 Self {
242 path,
243 cache: RwLock::new(None),
244 }
245 }
246
247 pub fn default_path() -> Option<PathBuf> {
249 dirs::config_dir().map(|p| p.join("oxi").join("auth.json"))
250 }
251
252 pub fn path(&self) -> &PathBuf {
254 &self.path
255 }
256}
257
258impl AuthStorageBackend for FileAuthStorage {
259 fn read(&self) -> AuthResult<Option<String>> {
260 if !self.path.exists() {
261 return Ok(None);
262 }
263
264 match std::fs::read_to_string(&self.path) {
265 Ok(content) => {
266 *self.cache.write() = Some(content.clone());
267 Ok(Some(content))
268 }
269 Err(e) => Err(AuthError::ReadError(e.to_string())),
270 }
271 }
272
273 fn write(&self, data: &str) -> AuthResult<()> {
274 if let Some(parent) = self.path.parent() {
276 std::fs::create_dir_all(parent).map_err(|e| AuthError::WriteError(e.to_string()))?;
277
278 #[cfg(unix)]
279 {
280 use std::os::unix::fs::PermissionsExt;
281 let perms = std::fs::Permissions::from_mode(0o700);
282 let _ = std::fs::set_permissions(parent, perms);
283 }
284 }
285
286 std::fs::write(&self.path, data).map_err(|e| AuthError::WriteError(e.to_string()))?;
288
289 #[cfg(unix)]
291 {
292 use std::os::unix::fs::PermissionsExt;
293 let perms = std::fs::Permissions::from_mode(0o600);
294 std::fs::set_permissions(&self.path, perms)
295 .map_err(|e| AuthError::WriteError(e.to_string()))?;
296 }
297
298 *self.cache.write() = Some(data.to_string());
299 Ok(())
300 }
301
302 fn delete(&self) -> AuthResult<()> {
303 if self.path.exists() {
304 std::fs::remove_file(&self.path).map_err(|e| AuthError::WriteError(e.to_string()))?;
305 }
306 *self.cache.write() = None;
307 Ok(())
308 }
309}
310
311pub struct MemoryAuthStorage {
317 data: RwLock<HashMap<String, AuthCredential>>,
318}
319
320impl MemoryAuthStorage {
321 pub fn new() -> Self {
323 Self {
324 data: RwLock::new(HashMap::new()),
325 }
326 }
327}
328
329impl Default for MemoryAuthStorage {
330 fn default() -> Self {
331 Self::new()
332 }
333}
334
335impl AuthStorageBackend for MemoryAuthStorage {
336 fn read(&self) -> AuthResult<Option<String>> {
337 Ok(None)
339 }
340
341 fn write(&self, _data: &str) -> AuthResult<()> {
342 Ok(())
343 }
344
345 fn delete(&self) -> AuthResult<()> {
346 self.data.write().clear();
347 Ok(())
348 }
349}
350
/// Last-resort source of API keys.
///
/// `AuthStorage` consults it only after runtime overrides and stored credentials.
pub trait FallbackResolver: Send + Sync {
    /// Returns a key for `provider`, or `None` if this resolver has none.
    fn resolve(&self, provider: &str) -> Option<String>;
}
360
/// [`FallbackResolver`] backed by a closure.
pub struct FnFallbackResolver {
    // The boxed closure type is spelled out in full, which trips clippy's
    // `type_complexity` lint.
    #[allow(clippy::type_complexity)]
    f: Box<dyn Fn(&str) -> Option<String> + Send + Sync>,
}
366
367impl FnFallbackResolver {
368 #[allow(clippy::type_complexity)]
370 pub fn new(f: Box<dyn Fn(&str) -> Option<String> + Send + Sync>) -> Self {
371 Self { f }
372 }
373}
374
375impl FallbackResolver for FnFallbackResolver {
376 fn resolve(&self, provider: &str) -> Option<String> {
377 (self.f)(provider)
378 }
379}
380
/// Credential store that resolves API keys and tokens per provider.
///
/// Lookups consult, in order:
/// 1. runtime overrides;
/// 2. stored credentials (loaded from and persisted to an optional backend);
/// 3. an optional fallback resolver.
///
/// All mutable state sits behind locks so a single instance can be shared
/// (see `shared_auth_storage`).
pub struct AuthStorage {
    /// Backend the credential map is loaded from and persisted to; `None` for
    /// purely in-memory storage.
    file_storage: Option<Arc<dyn AuthStorageBackend>>,
    /// Stored credentials keyed by provider name.
    credentials: RwLock<HashMap<String, AuthCredential>>,
    /// Process-local keys (e.g. from `--api-key`); take precedence and are never persisted.
    runtime_overrides: RwLock<HashMap<String, String>>,
    /// Key source consulted after runtime overrides and stored credentials.
    fallback_resolver: RwLock<Option<Arc<dyn FallbackResolver>>>,
    /// Errors recorded by persist/reload failures, returned by `drain_errors`.
    errors: RwLock<Vec<AuthError>>,
    /// Error from the most recent failed load of the backend; cleared by a
    /// successful `reload`.
    load_error: RwLock<Option<AuthError>>,
    /// Ensures the plaintext-storage warning is logged at most once.
    plaintext_warned: OnceLock<()>,
}
410
411impl AuthStorage {
412 pub fn new() -> Self {
414 let file_storage = FileAuthStorage::default_path()
415 .map(|p| Arc::new(FileAuthStorage::new(p)) as Arc<dyn AuthStorageBackend>);
416
417 let credentials = if let Some(ref storage) = file_storage {
418 match storage.read() {
419 Ok(Some(content)) => serde_json::from_str(&content).unwrap_or_default(),
420 _ => HashMap::new(),
421 }
422 } else {
423 HashMap::new()
424 };
425
426 Self {
427 file_storage,
428 credentials: RwLock::new(credentials),
429 runtime_overrides: RwLock::new(HashMap::new()),
430 fallback_resolver: RwLock::new(None),
431 errors: RwLock::new(Vec::new()),
432 load_error: RwLock::new(None),
433 plaintext_warned: OnceLock::new(),
434 }
435 }
436
437 pub fn with_backend(backend: impl AuthStorageBackend + 'static) -> Self {
439 let credentials = match backend.read() {
440 Ok(Some(content)) => serde_json::from_str(&content).unwrap_or_default(),
441 _ => HashMap::new(),
442 };
443
444 Self {
445 file_storage: Some(Arc::new(backend)),
446 credentials: RwLock::new(credentials),
447 runtime_overrides: RwLock::new(HashMap::new()),
448 fallback_resolver: RwLock::new(None),
449 errors: RwLock::new(Vec::new()),
450 load_error: RwLock::new(None),
451 plaintext_warned: OnceLock::new(),
452 }
453 }
454
455 pub fn in_memory() -> Self {
457 Self {
458 file_storage: None,
459 credentials: RwLock::new(HashMap::new()),
460 runtime_overrides: RwLock::new(HashMap::new()),
461 fallback_resolver: RwLock::new(None),
462 errors: RwLock::new(Vec::new()),
463 load_error: RwLock::new(None),
464 plaintext_warned: OnceLock::new(),
465 }
466 }
467
468 pub fn default_path() -> Option<PathBuf> {
470 FileAuthStorage::default_path()
471 }
472
473 pub fn set_runtime_key(&self, provider: &str, api_key: String) {
479 self.runtime_overrides
480 .write()
481 .insert(provider.to_string(), api_key);
482 }
483
484 pub fn remove_runtime_key(&self, provider: &str) {
486 self.runtime_overrides.write().remove(provider);
487 }
488
489 pub fn set_fallback_resolver(&self, resolver: Arc<dyn FallbackResolver>) {
496 *self.fallback_resolver.write() = Some(resolver);
497 }
498
499 pub fn clear_fallback_resolver(&self) {
501 *self.fallback_resolver.write() = None;
502 }
503
504 pub fn has_auth(&self, provider: &str) -> bool {
510 if self.runtime_overrides.read().contains_key(provider) {
511 return true;
512 }
513 if self.credentials.read().contains_key(provider) {
514 return true;
515 }
516 if let Some(ref resolver) = *self.fallback_resolver.read() {
517 if resolver.resolve(provider).is_some() {
518 return true;
519 }
520 }
521 false
522 }
523
524 pub fn get_status(&self, provider: &str) -> AuthStatus {
526 if self.runtime_overrides.read().contains_key(provider) {
527 return AuthStatus {
528 configured: false,
529 source: Some("runtime".to_string()),
530 label: Some("--api-key".to_string()),
531 };
532 }
533
534 if let Some(cred) = self.credentials.read().get(provider) {
535 return AuthStatus {
536 configured: true,
537 source: Some("stored".to_string()),
538 label: Some(cred.type_name().to_string()),
539 };
540 }
541
542 if let Some(ref resolver) = *self.fallback_resolver.read() {
543 if resolver.resolve(provider).is_some() {
544 return AuthStatus {
545 configured: false,
546 source: Some("fallback".to_string()),
547 label: Some("custom provider config".to_string()),
548 };
549 }
550 }
551
552 AuthStatus {
553 configured: false,
554 source: None,
555 label: None,
556 }
557 }
558
559 pub fn get_api_key(&self, provider: &str) -> Option<String> {
568 self.get_api_key_with_options(provider, true)
569 }
570
571 pub fn get_api_key_with_options(
573 &self,
574 provider: &str,
575 include_fallback: bool,
576 ) -> Option<String> {
577 if let Some(key) = self.runtime_overrides.read().get(provider) {
579 return Some(key.clone());
580 }
581
582 if let Some(cred) = self.credentials.read().get(provider) {
584 return match cred {
585 AuthCredential::ApiKey { key } => Some(key.clone()),
586 AuthCredential::OAuth {
587 access_token,
588 expires_at,
589 ..
590 } => {
591 if *expires_at > now_secs() {
592 Some(access_token.clone())
593 } else {
594 None
596 }
597 }
598 AuthCredential::Session {
599 token, expires_at, ..
600 } => {
601 if *expires_at == 0 || *expires_at > now_secs() {
602 Some(token.clone())
603 } else {
604 None
605 }
606 }
607 };
608 }
609
610 if include_fallback {
612 if let Some(ref resolver) = *self.fallback_resolver.read() {
613 return resolver.resolve(provider);
614 }
615 }
616
617 None
618 }
619
620 pub fn set_api_key(&self, provider: &str, key: String) {
626 self.credentials
627 .write()
628 .insert(provider.to_string(), AuthCredential::ApiKey { key });
629 if let Err(e) = self.persist() {
630 tracing::warn!("Failed to persist API key for '{}': {}", provider, e);
631 }
632 }
633
634 pub fn set_oauth(
636 &self,
637 provider: &str,
638 access_token: String,
639 refresh_token: Option<String>,
640 expires_at: u64,
641 ) {
642 self.set_oauth_full(
643 provider,
644 access_token,
645 refresh_token,
646 expires_at,
647 None,
648 None,
649 );
650 }
651
652 pub fn set_oauth_full(
654 &self,
655 provider: &str,
656 access_token: String,
657 refresh_token: Option<String>,
658 expires_at: u64,
659 scopes: Option<String>,
660 provider_data: Option<serde_json::Value>,
661 ) {
662 self.credentials.write().insert(
663 provider.to_string(),
664 AuthCredential::OAuth {
665 access_token,
666 refresh_token,
667 expires_at,
668 scopes,
669 provider_data,
670 },
671 );
672 if let Err(e) = self.persist() {
673 tracing::warn!("Failed to persist OAuth token for '{}': {}", provider, e);
674 }
675 }
676
677 pub fn set_session(
679 &self,
680 provider: &str,
681 token: String,
682 expires_at: u64,
683 metadata: Option<serde_json::Value>,
684 ) {
685 self.credentials.write().insert(
686 provider.to_string(),
687 AuthCredential::Session {
688 token,
689 expires_at,
690 metadata,
691 },
692 );
693 if let Err(e) = self.persist() {
694 tracing::warn!("Failed to persist session for '{}': {}", provider, e);
695 }
696 }
697
698 pub fn update_oauth_tokens(
700 &self,
701 provider: &str,
702 new_access_token: String,
703 new_refresh_token: Option<String>,
704 new_expires_at: u64,
705 ) -> AuthResult<()> {
706 let mut creds = self.credentials.write();
707 let cred = creds
708 .get_mut(provider)
709 .ok_or_else(|| AuthError::NotFound(provider.to_string()))?;
710
711 match cred {
712 AuthCredential::OAuth {
713 access_token,
714 refresh_token,
715 expires_at,
716 ..
717 } => {
718 *access_token = new_access_token;
719 *refresh_token = new_refresh_token;
720 *expires_at = new_expires_at;
721 }
722 _ => {
723 return Err(AuthError::InvalidFormat(format!(
724 "Provider '{}' does not have OAuth credentials",
725 provider
726 )));
727 }
728 }
729
730 drop(creds);
731 if let Err(e) = self.persist() {
732 tracing::warn!(
733 "Failed to persist OAuth token update for '{}': {}",
734 provider,
735 e
736 );
737 }
738 Ok(())
739 }
740
741 pub fn get(&self, provider: &str) -> Option<AuthCredential> {
747 self.credentials.read().get(provider).cloned()
748 }
749
750 pub fn get_oauth_credential(&self, provider: &str) -> Option<AuthCredential> {
752 self.credentials.read().get(provider).cloned()
753 }
754
755 pub fn has_oauth_with_refresh(&self, provider: &str) -> bool {
757 if let Some(cred) = self.credentials.read().get(provider) {
758 matches!(
759 cred,
760 AuthCredential::OAuth {
761 refresh_token: Some(_),
762 ..
763 }
764 )
765 } else {
766 false
767 }
768 }
769
770 pub fn set(&self, provider: &str, credential: AuthCredential) {
776 self.credentials
777 .write()
778 .insert(provider.to_string(), credential);
779 if let Err(e) = self.persist() {
780 tracing::warn!("Failed to persist credential for '{}': {}", provider, e);
781 }
782 }
783
784 pub fn remove(&self, provider: &str) {
786 self.credentials.write().remove(provider);
787 if let Err(e) = self.persist() {
788 tracing::warn!("Failed to persist after removing '{}': {}", provider, e);
789 }
790 }
791
792 pub fn list_providers(&self) -> Vec<String> {
794 self.credentials.read().keys().cloned().collect()
795 }
796
797 pub fn has(&self, provider: &str) -> bool {
799 self.credentials.read().contains_key(provider)
800 }
801
802 pub fn get_all(&self) -> HashMap<String, AuthCredential> {
804 self.credentials.read().clone()
805 }
806
807 pub fn clear(&self) {
809 self.credentials.write().clear();
810 if let Err(e) = self.persist() {
811 tracing::warn!("Failed to persist after clearing credentials: {}", e);
812 }
813 }
814
815 pub fn reload(&self) {
821 if let Some(ref storage) = self.file_storage {
822 match storage.read() {
823 Ok(Some(content)) => {
824 if let Ok(creds) = serde_json::from_str(&content) {
825 *self.credentials.write() = creds;
826 }
827 *self.load_error.write() = None;
828 }
829 Ok(None) => {
830 self.credentials.write().clear();
831 *self.load_error.write() = None;
832 }
833 Err(e) => {
834 *self.load_error.write() = Some(e);
835 self.record_error(AuthError::ReadError(
836 "Failed to reload auth storage".to_string(),
837 ));
838 }
839 }
840 }
841 }
842
843 fn persist(&self) -> Result<(), String> {
845 if let Some(ref storage) = self.file_storage {
846 let creds = self.credentials.read();
847 if let Ok(json) = serde_json::to_string_pretty(&*creds) {
848 #[cfg(not(feature = "keyring"))]
850 {
851 self.plaintext_warned.get_or_init(|| {
852 tracing::warn!(
853 "Auth credentials are stored in plaintext. \
854 Enable the 'keyring' feature for secure OS-level storage."
855 );
856 });
857 }
858
859 if let Err(e) = storage.write(&json) {
860 tracing::error!("Failed to persist auth storage: {}", e);
861 self.record_error(e);
862 return Err("persist failed".to_string());
863 }
864 }
865 }
866 Ok(())
867 }
868
869 fn record_error(&self, error: AuthError) {
875 self.errors.write().push(error);
876 }
877
878 pub fn drain_errors(&self) -> Vec<AuthError> {
880 let mut errors = self.errors.write();
881 std::mem::take(&mut *errors)
882 }
883
884 pub fn load_error(&self) -> Option<AuthError> {
886 self.load_error.read().clone()
887 }
888
889 pub fn validate_all(&self) -> Vec<(String, CredentialValidationError)> {
895 let creds = self.credentials.read();
896 let mut results = Vec::new();
897 for (provider, cred) in creds.iter() {
898 if let Err(e) = cred.validate() {
899 results.push((provider.clone(), e));
900 }
901 }
902 results
903 }
904
905 pub fn validate(&self, provider: &str) -> Result<(), CredentialValidationError> {
907 let creds = self.credentials.read();
908 let cred = creds.get(provider).ok_or_else(|| {
909 CredentialValidationError::EmptyField(format!(
910 "no credential for provider '{}'",
911 provider
912 ))
913 })?;
914 cred.validate()
915 }
916
917 pub fn configured_providers(&self) -> Vec<String> {
923 let mut providers: Vec<String> = self.credentials.read().keys().cloned().collect();
924 providers.sort();
925 providers
926 }
927
928 pub fn has_multiple_providers(&self) -> bool {
930 self.credentials.read().len() > 1
931 }
932
933 pub fn primary_provider(&self) -> Option<String> {
935 let creds = self.credentials.read();
936 creds.keys().next().cloned()
937 }
938
939 pub fn migrate_provider(&self, from: &str, to: &str) -> AuthResult<()> {
941 let mut creds = self.credentials.write();
942 let cred = creds
943 .remove(from)
944 .ok_or_else(|| AuthError::NotFound(from.to_string()))?;
945 creds.insert(to.to_string(), cred);
946 drop(creds);
947 let _ = self.persist();
948 Ok(())
949 }
950}
951
952impl Default for AuthStorage {
953 fn default() -> Self {
954 Self::new()
955 }
956}
957
/// Current Unix time in whole seconds.
///
/// A system clock set before the Unix epoch yields `0` instead of panicking.
fn now_secs() -> u64 {
    match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs(),
        Err(_) => 0,
    }
}
968
969#[allow(unexpected_cfgs)]
975pub mod keyring_support {
976 use super::*;
977
978 #[cfg(feature = "keyring")]
980 pub fn get_keyring_secret(service: &str, account: &str) -> Option<String> {
981 use keyring::Entry;
982 Entry::new(service, account)
983 .ok()
984 .and_then(|entry| entry.get_password().ok())
985 }
986
987 #[cfg(feature = "keyring")]
989 pub fn set_keyring_secret(service: &str, account: &str, secret: &str) -> AuthResult<()> {
990 use keyring::Entry;
991 Entry::new(service, account)
992 .map_err(|e| AuthError::KeyringError(e.to_string()))?
993 .set_password(secret)
994 .map_err(|e| AuthError::KeyringError(e.to_string()))
995 }
996
997 #[cfg(feature = "keyring")]
999 pub fn delete_keyring_secret(service: &str, account: &str) -> AuthResult<()> {
1000 use keyring::Entry;
1001 Entry::new(service, account)
1002 .map_err(|e| AuthError::KeyringError(e.to_string()))?
1003 .delete_credential()
1004 .map_err(|e| AuthError::KeyringError(e.to_string()))
1005 }
1006
1007 #[cfg(not(feature = "keyring"))]
1009 pub fn get_keyring_secret(_service: &str, _account: &str) -> Option<String> {
1013 None
1014 }
1015
1016 #[cfg(not(feature = "keyring"))]
1017 pub fn set_keyring_secret(_service: &str, _account: &str, _secret: &str) -> AuthResult<()> {
1021 Err(AuthError::KeyringError(
1022 "Keyring support not compiled".to_string(),
1023 ))
1024 }
1025
1026 #[cfg(not(feature = "keyring"))]
1027 pub fn delete_keyring_secret(_service: &str, _account: &str) -> AuthResult<()> {
1031 Err(AuthError::KeyringError(
1032 "Keyring support not compiled".to_string(),
1033 ))
1034 }
1035}
1036
1037pub fn shared_auth_storage() -> Arc<AuthStorage> {
1047 static STORAGE: OnceLock<Arc<AuthStorage>> = OnceLock::new();
1048 STORAGE.get_or_init(|| Arc::new(AuthStorage::new())).clone()
1049}
1050
1051#[cfg(test)]
1056mod tests {
1057 use super::*;
1058
1059 #[test]
1060 fn test_auth_storage_new() {
1061 let storage = AuthStorage::in_memory();
1062 assert!(!storage.has("anthropic"));
1063 }
1064
1065 #[test]
1066 fn test_set_and_get_api_key() {
1067 let storage = AuthStorage::in_memory();
1068 storage.set_api_key("anthropic", "sk-test123".to_string());
1069 assert!(storage.has("anthropic"));
1070 assert_eq!(
1071 storage.get_api_key("anthropic"),
1072 Some("sk-test123".to_string())
1073 );
1074 }
1075
1076 #[test]
1077 fn test_runtime_override() {
1078 let storage = AuthStorage::in_memory();
1079 storage.set_api_key("anthropic", "stored-key".to_string());
1080 storage.set_runtime_key("anthropic", "runtime-key".to_string());
1081
1082 assert_eq!(
1084 storage.get_api_key("anthropic"),
1085 Some("runtime-key".to_string())
1086 );
1087 }
1088
1089 #[test]
1090 fn test_remove_credential() {
1091 let storage = AuthStorage::in_memory();
1092 storage.set_api_key("anthropic", "sk-test123".to_string());
1093 assert!(storage.has("anthropic"));
1094
1095 storage.remove("anthropic");
1096 assert!(!storage.has("anthropic"));
1097 }
1098
1099 #[test]
1100 fn test_auth_status() {
1101 let storage = AuthStorage::in_memory();
1102 storage.set_api_key("anthropic", "sk-test123".to_string());
1103
1104 let status = storage.get_status("anthropic");
1105 assert!(status.configured);
1106 assert_eq!(status.source, Some("stored".to_string()));
1107 assert_eq!(status.label, Some("api_key".to_string()));
1108 }
1109
1110 #[test]
1111 fn test_auth_status_display() {
1112 let status = AuthStatus {
1113 configured: true,
1114 source: Some("stored".to_string()),
1115 label: Some("api_key".to_string()),
1116 };
1117 let display = format!("{}", status);
1118 assert_eq!(display, "stored (api_key)");
1119
1120 let no_config = AuthStatus {
1121 configured: false,
1122 source: None,
1123 label: None,
1124 };
1125 assert_eq!(format!("{}", no_config), "not configured");
1126 }
1127
1128 #[test]
1129 fn test_list_providers() {
1130 let storage = AuthStorage::in_memory();
1131 storage.set_api_key("anthropic", "key1".to_string());
1132 storage.set_api_key("openai", "key2".to_string());
1133
1134 let providers = storage.list_providers();
1135 assert!(providers.contains(&"anthropic".to_string()));
1136 assert!(providers.contains(&"openai".to_string()));
1137 }
1138
1139 #[test]
1140 fn test_oauth_credential() {
1141 let storage = AuthStorage::in_memory();
1142 storage.set_oauth(
1143 "provider",
1144 "access123".to_string(),
1145 Some("refresh456".to_string()),
1146 u64::MAX,
1147 );
1148
1149 assert!(storage.has("provider"));
1150 assert_eq!(
1151 storage.get_api_key("provider"),
1152 Some("access123".to_string())
1153 );
1154 }
1155
1156 #[test]
1157 fn test_expired_oauth_token() {
1158 let storage = AuthStorage::in_memory();
1159 storage.set_oauth("provider", "access123".to_string(), None, 0);
1161
1162 let key = storage.get_api_key("provider");
1164 assert!(key.is_none());
1165 }
1166
1167 #[test]
1168 fn test_get_all_credentials() {
1169 let storage = AuthStorage::in_memory();
1170 storage.set_api_key("anthropic", "key1".to_string());
1171 storage.set_api_key("openai", "key2".to_string());
1172
1173 let all = storage.get_all();
1174 assert_eq!(all.len(), 2);
1175 }
1176
1177 #[test]
1178 fn test_clear() {
1179 let storage = AuthStorage::in_memory();
1180 storage.set_api_key("anthropic", "key".to_string());
1181 assert!(storage.has("anthropic"));
1182
1183 storage.clear();
1184 assert!(!storage.has("anthropic"));
1185 }
1186
1187 #[test]
1188 fn test_remove_runtime_key() {
1189 let storage = AuthStorage::in_memory();
1190 storage.set_api_key("anthropic", "stored".to_string());
1191 storage.set_runtime_key("anthropic", "runtime".to_string());
1192
1193 assert_eq!(
1194 storage.get_api_key("anthropic"),
1195 Some("runtime".to_string())
1196 );
1197
1198 storage.remove_runtime_key("anthropic");
1199 assert_eq!(storage.get_api_key("anthropic"), Some("stored".to_string()));
1200 }
1201
1202 #[test]
1203 fn test_auth_credential_is_expired() {
1204 let api_key_cred = AuthCredential::ApiKey {
1206 key: "test".to_string(),
1207 };
1208 assert!(!api_key_cred.is_expired());
1209
1210 let future_time = now_secs() + 3600;
1212 let oauth_cred = AuthCredential::OAuth {
1213 access_token: "token".to_string(),
1214 refresh_token: Some("refresh".to_string()),
1215 expires_at: future_time,
1216 scopes: None,
1217 provider_data: None,
1218 };
1219 assert!(!oauth_cred.is_expired());
1220
1221 let oauth_cred_expired = AuthCredential::OAuth {
1223 access_token: "token".to_string(),
1224 refresh_token: Some("refresh".to_string()),
1225 expires_at: 0,
1226 scopes: None,
1227 provider_data: None,
1228 };
1229 assert!(oauth_cred_expired.is_expired());
1230 }
1231
1232 #[test]
1233 fn test_auth_credential_needs_refresh() {
1234 let future_time = now_secs() + 120; let oauth_cred = AuthCredential::OAuth {
1238 access_token: "token".to_string(),
1239 refresh_token: Some("refresh".to_string()),
1240 expires_at: future_time,
1241 scopes: None,
1242 provider_data: None,
1243 };
1244 assert!(!oauth_cred.needs_refresh());
1245
1246 let soon = now_secs() + 30;
1248 let oauth_soon = AuthCredential::OAuth {
1249 access_token: "token".to_string(),
1250 refresh_token: Some("refresh".to_string()),
1251 expires_at: soon,
1252 scopes: None,
1253 provider_data: None,
1254 };
1255 assert!(oauth_soon.needs_refresh());
1256
1257 let no_refresh = AuthCredential::OAuth {
1259 access_token: "token".to_string(),
1260 refresh_token: None,
1261 expires_at: future_time,
1262 scopes: None,
1263 provider_data: None,
1264 };
1265 assert!(!no_refresh.needs_refresh());
1266
1267 let api_key_cred = AuthCredential::ApiKey {
1269 key: "test".to_string(),
1270 };
1271 assert!(!api_key_cred.needs_refresh());
1272 }
1273
1274 #[test]
1275 fn test_auth_credential_access_token() {
1276 let future_time = now_secs() + 3600;
1277
1278 let oauth_cred = AuthCredential::OAuth {
1279 access_token: "valid_token".to_string(),
1280 refresh_token: Some("refresh".to_string()),
1281 expires_at: future_time,
1282 scopes: None,
1283 provider_data: None,
1284 };
1285 assert_eq!(oauth_cred.access_token(), Some("valid_token"));
1286
1287 let expired_cred = AuthCredential::OAuth {
1289 access_token: "expired_token".to_string(),
1290 refresh_token: Some("refresh".to_string()),
1291 expires_at: 0,
1292 scopes: None,
1293 provider_data: None,
1294 };
1295 assert!(expired_cred.access_token().is_none());
1296
1297 let api_key_cred = AuthCredential::ApiKey {
1299 key: "api_key_token".to_string(),
1300 };
1301 assert!(api_key_cred.access_token().is_none());
1302 }
1303
1304 #[test]
1305 fn test_get_oauth_credential() {
1306 let storage = AuthStorage::in_memory();
1307 storage.set_oauth(
1308 "provider",
1309 "access".to_string(),
1310 Some("refresh".to_string()),
1311 u64::MAX,
1312 );
1313
1314 let cred = storage.get_oauth_credential("provider");
1315 assert!(cred.is_some());
1316 assert!(matches!(cred.unwrap(), AuthCredential::OAuth { .. }));
1317 }
1318
1319 #[test]
1320 fn test_has_oauth_with_refresh() {
1321 let storage = AuthStorage::in_memory();
1322
1323 storage.set_oauth(
1325 "with_refresh",
1326 "access".to_string(),
1327 Some("refresh".to_string()),
1328 u64::MAX,
1329 );
1330 assert!(storage.has_oauth_with_refresh("with_refresh"));
1331
1332 storage.set_oauth("without_refresh", "access".to_string(), None, u64::MAX);
1334 assert!(!storage.has_oauth_with_refresh("without_refresh"));
1335
1336 storage.set_api_key("apikey_provider", "key".to_string());
1338 assert!(!storage.has_oauth_with_refresh("apikey_provider"));
1339 }
1340
1341 #[test]
1342 fn test_set_oauth_full() {
1343 let storage = AuthStorage::in_memory();
1344 storage.set_oauth_full(
1345 "provider",
1346 "access_token".to_string(),
1347 Some("refresh_token".to_string()),
1348 3600,
1349 Some("read write".to_string()),
1350 Some(serde_json::json!({"extra": "data"})),
1351 );
1352
1353 let cred = storage.get_oauth_credential("provider");
1354 assert!(cred.is_some());
1355 if let AuthCredential::OAuth {
1356 scopes,
1357 provider_data,
1358 ..
1359 } = cred.unwrap()
1360 {
1361 assert_eq!(scopes, Some("read write".to_string()));
1362 assert!(provider_data.is_some());
1363 } else {
1364 panic!("Expected OAuth credential");
1365 }
1366 }
1367
1368 #[test]
1369 fn test_session_token() {
1370 let storage = AuthStorage::in_memory();
1371 storage.set_session(
1372 "browser",
1373 "session-token-123".to_string(),
1374 0, Some(serde_json::json!({"user": "test"})),
1376 );
1377
1378 assert!(storage.has("browser"));
1379 assert_eq!(
1380 storage.get_api_key("browser"),
1381 Some("session-token-123".to_string())
1382 );
1383
1384 let cred = storage.get("browser").unwrap();
1385 assert!(matches!(cred, AuthCredential::Session { .. }));
1386 assert!(cred.access_token().is_some());
1387 }
1388
1389 #[test]
1390 fn test_session_token_expired() {
1391 let storage = AuthStorage::in_memory();
1392 storage.set_session("browser", "session-token".to_string(), 1, None);
1393
1394 assert!(storage.get_api_key("browser").is_none());
1396 }
1397
1398 #[test]
1399 fn test_credential_validation() {
1400 let valid = AuthCredential::ApiKey {
1402 key: "sk-valid".to_string(),
1403 };
1404 assert!(valid.validate().is_ok());
1405
1406 let empty = AuthCredential::ApiKey {
1408 key: "".to_string(),
1409 };
1410 assert!(empty.validate().is_err());
1411
1412 let placeholder = AuthCredential::ApiKey {
1414 key: "your-api-key-here".to_string(),
1415 };
1416 assert!(placeholder.validate().is_err());
1417
1418 let valid_oauth = AuthCredential::OAuth {
1420 access_token: "token".to_string(),
1421 refresh_token: None,
1422 expires_at: now_secs() + 3600,
1423 scopes: None,
1424 provider_data: None,
1425 };
1426 assert!(valid_oauth.validate().is_ok());
1427
1428 let invalid_oauth = AuthCredential::OAuth {
1430 access_token: "".to_string(),
1431 refresh_token: None,
1432 expires_at: 1000,
1433 scopes: None,
1434 provider_data: None,
1435 };
1436 assert!(invalid_oauth.validate().is_err());
1437 }
1438
1439 #[test]
1440 fn test_validate_all() {
1441 let storage = AuthStorage::in_memory();
1442 storage.set_api_key("valid", "sk-good".to_string());
1443 storage.set_api_key("empty", "".to_string());
1444
1445 let errors = storage.validate_all();
1446 assert_eq!(errors.len(), 1);
1447 assert_eq!(errors[0].0, "empty");
1448 }
1449
1450 #[test]
1451 fn test_update_oauth_tokens() {
1452 let storage = AuthStorage::in_memory();
1453 storage.set_oauth(
1454 "provider",
1455 "old-access".to_string(),
1456 Some("old-refresh".to_string()),
1457 now_secs() + 3600,
1458 );
1459
1460 storage
1461 .update_oauth_tokens(
1462 "provider",
1463 "new-access".to_string(),
1464 Some("new-refresh".to_string()),
1465 now_secs() + 7200,
1466 )
1467 .unwrap();
1468
1469 let key = storage.get_api_key("provider");
1470 assert_eq!(key, Some("new-access".to_string()));
1471 }
1472
1473 #[test]
1474 fn test_update_oauth_tokens_wrong_type() {
1475 let storage = AuthStorage::in_memory();
1476 storage.set_api_key("provider", "key".to_string());
1477
1478 let result = storage.update_oauth_tokens(
1479 "provider",
1480 "new-access".to_string(),
1481 None,
1482 now_secs() + 3600,
1483 );
1484 assert!(result.is_err());
1485 }
1486
1487 #[test]
1488 fn test_migrate_provider() {
1489 let storage = AuthStorage::in_memory();
1490 storage.set_api_key("old-provider", "key123".to_string());
1491 storage
1492 .migrate_provider("old-provider", "new-provider")
1493 .unwrap();
1494
1495 assert!(!storage.has("old-provider"));
1496 assert!(storage.has("new-provider"));
1497 assert_eq!(
1498 storage.get_api_key("new-provider"),
1499 Some("key123".to_string())
1500 );
1501 }
1502
1503 #[test]
1504 fn test_migrate_provider_not_found() {
1505 let storage = AuthStorage::in_memory();
1506 let result = storage.migrate_provider("nonexistent", "target");
1507 assert!(result.is_err());
1508 }
1509
1510 #[test]
1511 fn test_error_draining() {
1512 let storage = AuthStorage::in_memory();
1513 let errors = storage.drain_errors();
1514 assert!(errors.is_empty());
1515 }
1516
1517 #[test]
1518 fn test_fallback_resolver() {
1519 let storage = AuthStorage::in_memory();
1520 storage.set_fallback_resolver(Arc::new(FnFallbackResolver::new(Box::new(|provider| {
1521 if provider == "custom" {
1522 Some("custom-key-from-config".to_string())
1523 } else {
1524 None
1525 }
1526 }))));
1527
1528 assert_eq!(
1529 storage.get_api_key("custom"),
1530 Some("custom-key-from-config".to_string())
1531 );
1532 assert!(storage.get_api_key("unknown").is_none());
1533
1534 storage.clear_fallback_resolver();
1536 assert!(storage.get_api_key("custom").is_none());
1537 }
1538
1539 #[test]
1540 fn test_get_api_key_with_options() {
1541 let storage = AuthStorage::in_memory();
1542 storage.set_fallback_resolver(Arc::new(FnFallbackResolver::new(Box::new(|_| {
1543 Some("fallback-key".to_string())
1544 }))));
1545
1546 assert_eq!(
1548 storage.get_api_key_with_options("test", true),
1549 Some("fallback-key".to_string())
1550 );
1551
1552 assert!(storage.get_api_key_with_options("test", false).is_none());
1554 }
1555
1556 #[test]
1557 fn test_configured_providers() {
1558 let storage = AuthStorage::in_memory();
1559 storage.set_api_key("openai", "key".to_string());
1560 storage.set_api_key("anthropic", "key".to_string());
1561
1562 let providers = storage.configured_providers();
1563 assert!(providers.len() >= 2);
1564 let mut sorted = providers.clone();
1566 sorted.sort();
1567 assert_eq!(providers, sorted);
1568 }
1569
1570 #[test]
1571 fn test_has_multiple_providers() {
1572 let storage = AuthStorage::in_memory();
1573 assert!(!storage.has_multiple_providers());
1574
1575 storage.set_api_key("openai", "key1".to_string());
1576 assert!(!storage.has_multiple_providers());
1577
1578 storage.set_api_key("anthropic", "key2".to_string());
1579 assert!(storage.has_multiple_providers());
1580 }
1581
1582 #[test]
1583 fn test_set_and_get_credential() {
1584 let storage = AuthStorage::in_memory();
1585 let cred = AuthCredential::Session {
1586 token: "abc".to_string(),
1587 expires_at: 0,
1588 metadata: None,
1589 };
1590 storage.set("custom", cred);
1591 let retrieved = storage.get("custom");
1592 assert!(retrieved.is_some());
1593 assert!(matches!(retrieved.unwrap(), AuthCredential::Session { .. }));
1594 }
1595
1596 #[test]
1597 fn test_credential_type_name() {
1598 assert_eq!(
1599 AuthCredential::ApiKey {
1600 key: "k".to_string()
1601 }
1602 .type_name(),
1603 "api_key"
1604 );
1605 assert_eq!(
1606 AuthCredential::OAuth {
1607 access_token: "t".to_string(),
1608 refresh_token: None,
1609 expires_at: 0,
1610 scopes: None,
1611 provider_data: None,
1612 }
1613 .type_name(),
1614 "oauth"
1615 );
1616 assert_eq!(
1617 AuthCredential::Session {
1618 token: "t".to_string(),
1619 expires_at: 0,
1620 metadata: None,
1621 }
1622 .type_name(),
1623 "session"
1624 );
1625 }
1626}