use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use parking_lot::RwLock;
/// A credential stored for a single provider.
///
/// Serialized as an internally tagged enum: a `"type"` field selects the
/// variant, using the tags `"api_key"`, `"oauth"` and `"session"` — the same
/// strings that `type_name()` reports.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum AuthCredential {
    /// A static API key.
    ApiKey {
        /// The key value.
        key: String,
    },
    /// An OAuth token set.
    // `rename_all = "snake_case"` alone turns `OAuth` into `"o_auth"`, which
    // disagrees with `type_name()` ("oauth"). The explicit rename fixes the
    // tag; the alias keeps data written with the old tag readable.
    #[serde(rename = "oauth", alias = "o_auth")]
    OAuth {
        /// Token presented to the provider.
        access_token: String,
        /// Token used to obtain a new access token, when one was issued.
        refresh_token: Option<String>,
        /// Expiry as seconds since the Unix epoch (compared against `now_secs()`).
        expires_at: u64,
        /// Granted scopes; absent in the serialized form means `None`.
        #[serde(default)]
        scopes: Option<String>,
        /// Free-form provider-specific JSON; absent means `None`.
        #[serde(default)]
        provider_data: Option<serde_json::Value>,
    },
    /// A session token.
    Session {
        /// The session token value.
        token: String,
        /// Expiry as seconds since the Unix epoch; `0` (also the default when
        /// the field is missing) is treated as "never expires".
        #[serde(default)]
        expires_at: u64,
        /// Free-form JSON metadata; absent means `None`.
        #[serde(default)]
        metadata: Option<serde_json::Value>,
    },
}
impl AuthCredential {
    /// Reports whether the credential's expiry time has been reached.
    ///
    /// API keys never expire, and a session with `expires_at == 0` is treated
    /// as non-expiring. OAuth credentials have no such exemption.
    pub fn is_expired(&self) -> bool {
        match self {
            AuthCredential::ApiKey { .. } => false,
            AuthCredential::Session { expires_at: 0, .. } => false,
            AuthCredential::OAuth { expires_at, .. }
            | AuthCredential::Session { expires_at, .. } => *expires_at <= now_secs(),
        }
    }

    /// Reports whether an OAuth credential should be refreshed now.
    ///
    /// True only for OAuth credentials that carry a refresh token and expire
    /// within the next 60 seconds (or already have).
    pub fn needs_refresh(&self) -> bool {
        if let AuthCredential::OAuth {
            expires_at,
            refresh_token: Some(_),
            ..
        } = self
        {
            return *expires_at <= now_secs() + 60;
        }
        false
    }

    /// Returns the bearer token of a non-expired OAuth or session credential.
    ///
    /// API keys and expired credentials yield `None`.
    pub fn access_token(&self) -> Option<&str> {
        let token = match self {
            AuthCredential::OAuth { access_token, .. } => access_token,
            AuthCredential::Session { token, .. } => token,
            AuthCredential::ApiKey { .. } => return None,
        };
        if self.is_expired() {
            None
        } else {
            Some(token.as_str())
        }
    }

    /// Returns a short, stable name for the credential variant.
    pub fn type_name(&self) -> &'static str {
        match self {
            AuthCredential::Session { .. } => "session",
            AuthCredential::OAuth { .. } => "oauth",
            AuthCredential::ApiKey { .. } => "api_key",
        }
    }

    /// Performs basic sanity checks on the credential's fields.
    ///
    /// # Errors
    ///
    /// * `EmptyField` when the key / access token / session token is empty.
    /// * `PlaceholderValue` when an API key is one of the known placeholders.
    /// * `InvalidExpiry` when an OAuth credential has `expires_at == 0`.
    pub fn validate(&self) -> Result<(), CredentialValidationError> {
        let empty = |field: &str| -> Result<(), CredentialValidationError> {
            Err(CredentialValidationError::EmptyField(field.to_string()))
        };

        // Arms are ordered so that the emptiness check of each variant wins
        // over that variant's secondary check.
        match self {
            AuthCredential::ApiKey { key } if key.is_empty() => empty("key"),
            AuthCredential::ApiKey { key }
                if matches!(key.as_str(), "your-api-key-here" | "xxx") =>
            {
                Err(CredentialValidationError::PlaceholderValue(key.clone()))
            }
            AuthCredential::OAuth { access_token, .. } if access_token.is_empty() => {
                empty("access_token")
            }
            AuthCredential::OAuth { expires_at: 0, .. } => {
                Err(CredentialValidationError::InvalidExpiry)
            }
            AuthCredential::Session { token, .. } if token.is_empty() => empty("token"),
            AuthCredential::ApiKey { .. }
            | AuthCredential::OAuth { .. }
            | AuthCredential::Session { .. } => Ok(()),
        }
    }
}
/// Reasons a credential can fail `AuthCredential::validate`.
#[derive(Debug, Clone, thiserror::Error)]
pub enum CredentialValidationError {
/// A required field was empty; the payload is the field name.
#[error("Field '{0}' must not be empty")]
EmptyField(String),
/// The value is a known placeholder; the payload is the offending value.
#[error("Placeholder value detected: '{0}'")]
PlaceholderValue(String),
/// The expiry timestamp is not usable (an OAuth `expires_at` of `0`).
#[error("Invalid expiry timestamp")]
InvalidExpiry,
}
/// Summary of where authentication for a provider comes from.
#[derive(Debug, Clone)]
pub struct AuthStatus {
/// Set to `true` by `AuthStorage::get_status` only when a stored credential
/// exists; runtime, environment and fallback sources report `false`.
pub configured: bool,
/// Origin of the credential (e.g. "stored", "runtime", "environment",
/// "fallback"), or `None` when nothing was found.
pub source: Option<String>,
/// Extra detail about the source, such as the credential type or the
/// environment variable name.
pub label: Option<String>,
}
impl std::fmt::Display for AuthStatus {
    /// Renders `source (label)`, whichever of the two is present, or the
    /// literal `not configured` when both are absent.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match (self.source.as_deref(), self.label.as_deref()) {
            (None, None) => f.write_str("not configured"),
            (Some(only), None) | (None, Some(only)) => f.write_str(only),
            (Some(source), Some(label)) => write!(f, "{} ({})", source, label),
        }
    }
}
/// Result alias used by the storage backends and `AuthStorage`; the error type is `AuthError`.
pub type AuthResult<T> = Result<T, AuthError>;
/// Errors produced by credential storage operations.
///
/// Every variant carries a human-readable detail string that is interpolated
/// into the `Display` message.
#[derive(Debug, Clone, thiserror::Error)]
pub enum AuthError {
/// Reading from the backing store failed.
#[error("Failed to read auth storage: {0}")]
ReadError(String),
/// Writing to (or deleting from) the backing store failed.
#[error("Failed to write auth storage: {0}")]
WriteError(String),
/// No credential exists for the named provider.
#[error("Credential not found: {0}")]
NotFound(String),
/// The credential has an unexpected shape or type.
#[error("Invalid credential format: {0}")]
InvalidFormat(String),
/// The keyring operation failed or keyring support is unavailable.
#[error("Keyring error: {0}")]
KeyringError(String),
/// A credential did not pass validation.
#[error("Credential validation failed: {0}")]
ValidationFailed(String),
}
/// Maps a normalized (lowercase) provider name to the environment variables
/// that may hold its API key, in the order `find_env_keys` reports them.
static PROVIDER_ENV_KEYS: &[(&str, &[&str])] = &[
("anthropic", &["ANTHROPIC_API_KEY"]),
("openai", &["OPENAI_API_KEY"]),
("google", &["GOOGLE_API_KEY", "GEMINI_API_KEY"]),
("groq", &["GROQ_API_KEY"]),
("mistral", &["MISTRAL_API_KEY"]),
("deepseek", &["DEEPSEEK_API_KEY"]),
("xai", &["XAI_API_KEY"]),
("cohere", &["COHERE_API_KEY", "CO_API_KEY"]),
("perplexity", &["PERPLEXITY_API_KEY"]),
];
/// Lists the environment variables that currently hold an API key for `provider`.
///
/// The provider name is lowercased and `-` is replaced with `_` before it is
/// looked up in `PROVIDER_ENV_KEYS`. When the table yields nothing, the
/// generic `<PROVIDER>_API_KEY` variable is tried. A variable counts only if
/// it is set, valid Unicode, and non-empty — an exported-but-empty variable
/// cannot authenticate anything and is treated like an unset one.
///
/// Returns the matching variable names (not their values); empty when none
/// are usable.
pub fn find_env_keys(provider: &str) -> Vec<String> {
    /// True when the variable is readable and holds a non-empty value.
    fn is_set(name: &str) -> bool {
        std::env::var(name).map_or(false, |value| !value.is_empty())
    }

    let normalized = provider.to_lowercase().replace('-', "_");

    for (name, keys) in PROVIDER_ENV_KEYS {
        if *name != normalized {
            continue;
        }
        let found: Vec<String> = keys
            .iter()
            .copied()
            .filter(|key| is_set(key))
            .map(str::to_owned)
            .collect();
        if !found.is_empty() {
            return found;
        }
    }

    // Unknown provider, or none of its known variables are set: fall back to
    // the conventional `<PROVIDER>_API_KEY` name.
    let generic_key = format!("{}_API_KEY", normalized.to_uppercase());
    if is_set(&generic_key) {
        vec![generic_key]
    } else {
        Vec::new()
    }
}
/// Returns the value of the first environment variable that `find_env_keys`
/// reports for `provider`, or `None` when there is none or it cannot be read.
pub fn get_env_api_key(provider: &str) -> Option<String> {
    find_env_keys(provider)
        .into_iter()
        .next()
        .and_then(|name| std::env::var(name).ok())
}
/// A place where serialized credential data can be kept.
///
/// Implementations must be `Send + Sync`. `AuthStorage` passes JSON text
/// through `write` and parses whatever `read` returns as JSON.
pub trait AuthStorageBackend: Send + Sync {
/// Returns the stored data, or `Ok(None)` when nothing is stored.
fn read(&self) -> AuthResult<Option<String>>;
/// Replaces the stored data with `data`.
fn write(&self, data: &str) -> AuthResult<()>;
/// Removes the stored data.
fn delete(&self) -> AuthResult<()>;
}
/// Backend that keeps the serialized credentials in a single file.
pub struct FileAuthStorage {
// Location of the credential file.
path: PathBuf,
// Copy of the most recently read or written content; reset to `None` on delete.
cache: RwLock<Option<String>>,
}
impl FileAuthStorage {
    /// Creates a backend for the file at `path`. Nothing is read or created
    /// until the backend is used.
    pub fn new(path: PathBuf) -> Self {
        let cache = RwLock::new(None);
        Self { path, cache }
    }

    /// Returns `<config dir>/oxi/auth.json`, or `None` when `dirs` cannot
    /// determine a configuration directory.
    pub fn default_path() -> Option<PathBuf> {
        let config_root = dirs::config_dir()?;
        Some(config_root.join("oxi").join("auth.json"))
    }

    /// Returns the path of the backing file.
    pub fn path(&self) -> &PathBuf {
        let Self { path, .. } = self;
        path
    }
}
impl AuthStorageBackend for FileAuthStorage {
fn read(&self) -> AuthResult<Option<String>> {
if !self.path.exists() {
return Ok(None);
}
match std::fs::read_to_string(&self.path) {
Ok(content) => {
*self.cache.write() = Some(content.clone());
Ok(Some(content))
}
Err(e) => Err(AuthError::ReadError(e.to_string())),
}
}
fn write(&self, data: &str) -> AuthResult<()> {
if let Some(parent) = self.path.parent() {
std::fs::create_dir_all(parent).map_err(|e| AuthError::WriteError(e.to_string()))?;
#[cfg(unix)]
{
use std::os::unix::fs::PermissionsExt;
let perms = std::fs::Permissions::from_mode(0o700);
let _ = std::fs::set_permissions(parent, perms);
}
}
std::fs::write(&self.path, data).map_err(|e| AuthError::WriteError(e.to_string()))?;
#[cfg(unix)]
{
use std::os::unix::fs::PermissionsExt;
let perms = std::fs::Permissions::from_mode(0o600);
std::fs::set_permissions(&self.path, perms)
.map_err(|e| AuthError::WriteError(e.to_string()))?;
}
*self.cache.write() = Some(data.to_string());
Ok(())
}
fn delete(&self) -> AuthResult<()> {
if self.path.exists() {
std::fs::remove_file(&self.path).map_err(|e| AuthError::WriteError(e.to_string()))?;
}
*self.cache.write() = None;
Ok(())
}
}
/// Backend that exposes a single environment variable as the stored data.
pub struct EnvAuthStorage {
// Full name of the environment variable (`<PROVIDER>_API_KEY`), despite the
// field being called a prefix.
provider_prefix: String,
}
impl EnvAuthStorage {
    /// Creates a backend bound to `<PROVIDER>_API_KEY`, where the provider
    /// name is uppercased and `-` is replaced with `_`.
    pub fn new(provider: &str) -> Self {
        let env_stem = provider.to_uppercase().replace('-', "_");
        let provider_prefix = format!("{}_API_KEY", env_stem);
        Self { provider_prefix }
    }
}
impl AuthStorageBackend for EnvAuthStorage {
    /// Returns the variable's value, or `Ok(None)` when it is unset or not
    /// valid Unicode. Never fails.
    fn read(&self) -> AuthResult<Option<String>> {
        match std::env::var(&self.provider_prefix) {
            Ok(value) => Ok(Some(value)),
            Err(_) => Ok(None),
        }
    }

    /// Always fails: this backend does not write to the environment.
    fn write(&self, _data: &str) -> AuthResult<()> {
        let message = String::from("Cannot write to environment variables");
        Err(AuthError::WriteError(message))
    }

    /// Removes the variable from the current process environment.
    fn delete(&self) -> AuthResult<()> {
        std::env::remove_var(&self.provider_prefix);
        Ok(())
    }
}
/// Backend that keeps credentials in process memory only.
pub struct MemoryAuthStorage {
// Credentials keyed by provider name, guarded by a read/write lock.
data: RwLock<HashMap<String, AuthCredential>>,
}
impl MemoryAuthStorage {
    /// Creates an empty in-memory backend.
    pub fn new() -> Self {
        let data = RwLock::new(HashMap::new());
        Self { data }
    }
}
impl Default for MemoryAuthStorage {
fn default() -> Self {
Self::new()
}
}
impl AuthStorageBackend for MemoryAuthStorage {
fn read(&self) -> AuthResult<Option<String>> {
Ok(None)
}
fn write(&self, _data: &str) -> AuthResult<()> {
Ok(())
}
fn delete(&self) -> AuthResult<()> {
self.data.write().clear();
Ok(())
}
}
/// Last-resort source of API keys, consulted by `AuthStorage` after its other
/// sources come up empty.
pub trait FallbackResolver: Send + Sync {
/// Returns a key for `provider`, or `None` when this resolver has none.
fn resolve(&self, provider: &str) -> Option<String>;
}
/// `FallbackResolver` implemented by a boxed closure.
pub struct FnFallbackResolver {
// The closure invoked with the provider name on every `resolve` call.
f: Box<dyn Fn(&str) -> Option<String> + Send + Sync>,
}
impl FnFallbackResolver {
pub fn new(f: Box<dyn Fn(&str) -> Option<String> + Send + Sync>) -> Self {
Self { f }
}
}
impl FallbackResolver for FnFallbackResolver {
    /// Forwards `provider` to the wrapped closure and returns its answer.
    fn resolve(&self, provider: &str) -> Option<String> {
        let callback = &self.f;
        callback(provider)
    }
}
/// Credential store combining persisted credentials, runtime overrides,
/// environment variables and an optional fallback resolver.
///
/// All state sits behind `RwLock`s, so every method takes `&self`.
pub struct AuthStorage {
// Backend used to load and persist `credentials`; `None` means nothing is persisted.
file_storage: Option<Arc<dyn AuthStorageBackend>>,
// Stored credentials keyed by provider name.
credentials: RwLock<HashMap<String, AuthCredential>>,
// Per-provider API keys set at runtime; these are never persisted.
runtime_overrides: RwLock<HashMap<String, String>>,
// Optional resolver consulted after the other sources.
fallback_resolver: RwLock<Option<Arc<dyn FallbackResolver>>>,
// Errors collected by `record_error`, handed out by `drain_errors`.
errors: RwLock<Vec<AuthError>>,
// Error from the most recent failed load, if any.
load_error: RwLock<Option<AuthError>>,
}
impl AuthStorage {
pub fn new() -> Self {
let file_storage = FileAuthStorage::default_path()
.map(|p| Arc::new(FileAuthStorage::new(p)) as Arc<dyn AuthStorageBackend>);
let credentials = if let Some(ref storage) = file_storage {
match storage.read() {
Ok(Some(content)) => serde_json::from_str(&content).unwrap_or_default(),
_ => HashMap::new(),
}
} else {
HashMap::new()
};
Self {
file_storage,
credentials: RwLock::new(credentials),
runtime_overrides: RwLock::new(HashMap::new()),
fallback_resolver: RwLock::new(None),
errors: RwLock::new(Vec::new()),
load_error: RwLock::new(None),
}
}
pub fn with_backend(backend: impl AuthStorageBackend + 'static) -> Self {
let credentials = match backend.read() {
Ok(Some(content)) => serde_json::from_str(&content).unwrap_or_default(),
_ => HashMap::new(),
};
Self {
file_storage: Some(Arc::new(backend)),
credentials: RwLock::new(credentials),
runtime_overrides: RwLock::new(HashMap::new()),
fallback_resolver: RwLock::new(None),
errors: RwLock::new(Vec::new()),
load_error: RwLock::new(None),
}
}
pub fn in_memory() -> Self {
Self {
file_storage: None,
credentials: RwLock::new(HashMap::new()),
runtime_overrides: RwLock::new(HashMap::new()),
fallback_resolver: RwLock::new(None),
errors: RwLock::new(Vec::new()),
load_error: RwLock::new(None),
}
}
pub fn default_path() -> Option<PathBuf> {
FileAuthStorage::default_path()
}
pub fn set_runtime_key(&self, provider: &str, api_key: String) {
self.runtime_overrides
.write()
.insert(provider.to_string(), api_key);
}
pub fn remove_runtime_key(&self, provider: &str) {
self.runtime_overrides.write().remove(provider);
}
pub fn set_fallback_resolver(&self, resolver: Arc<dyn FallbackResolver>) {
*self.fallback_resolver.write() = Some(resolver);
}
pub fn clear_fallback_resolver(&self) {
*self.fallback_resolver.write() = None;
}
pub fn has_auth(&self, provider: &str) -> bool {
if self.runtime_overrides.read().contains_key(provider) {
return true;
}
if self.credentials.read().contains_key(provider) {
return true;
}
if get_env_api_key(provider).is_some() {
return true;
}
if let Some(ref resolver) = *self.fallback_resolver.read() {
if resolver.resolve(provider).is_some() {
return true;
}
}
false
}
pub fn get_status(&self, provider: &str) -> AuthStatus {
if self.runtime_overrides.read().contains_key(provider) {
return AuthStatus {
configured: false,
source: Some("runtime".to_string()),
label: Some("--api-key".to_string()),
};
}
if let Some(cred) = self.credentials.read().get(provider) {
return AuthStatus {
configured: true,
source: Some("stored".to_string()),
label: Some(cred.type_name().to_string()),
};
}
let env_keys = find_env_keys(provider);
if let Some(first_key) = env_keys.first() {
return AuthStatus {
configured: false,
source: Some("environment".to_string()),
label: Some(first_key.clone()),
};
}
if let Some(ref resolver) = *self.fallback_resolver.read() {
if resolver.resolve(provider).is_some() {
return AuthStatus {
configured: false,
source: Some("fallback".to_string()),
label: Some("custom provider config".to_string()),
};
}
}
AuthStatus {
configured: false,
source: None,
label: None,
}
}
pub fn get_api_key(&self, provider: &str) -> Option<String> {
self.get_api_key_with_options(provider, true)
}
pub fn get_api_key_with_options(&self, provider: &str, include_fallback: bool) -> Option<String> {
if let Some(key) = self.runtime_overrides.read().get(provider) {
return Some(key.clone());
}
if let Some(cred) = self.credentials.read().get(provider) {
return match cred {
AuthCredential::ApiKey { key } => Some(key.clone()),
AuthCredential::OAuth { access_token, expires_at, .. } => {
if *expires_at > now_secs() {
Some(access_token.clone())
} else {
None
}
}
AuthCredential::Session { token, expires_at, .. } => {
if *expires_at == 0 || *expires_at > now_secs() {
Some(token.clone())
} else {
None
}
}
};
}
if include_fallback {
if let Some(ref resolver) = *self.fallback_resolver.read() {
return resolver.resolve(provider);
}
}
None
}
pub fn set_api_key(&self, provider: &str, key: String) {
self.credentials
.write()
.insert(provider.to_string(), AuthCredential::ApiKey { key });
self.persist();
}
pub fn set_oauth(
&self,
provider: &str,
access_token: String,
refresh_token: Option<String>,
expires_at: u64,
) {
self.set_oauth_full(
provider,
access_token,
refresh_token,
expires_at,
None,
None,
);
}
pub fn set_oauth_full(
&self,
provider: &str,
access_token: String,
refresh_token: Option<String>,
expires_at: u64,
scopes: Option<String>,
provider_data: Option<serde_json::Value>,
) {
self.credentials.write().insert(
provider.to_string(),
AuthCredential::OAuth {
access_token,
refresh_token,
expires_at,
scopes,
provider_data,
},
);
self.persist();
}
pub fn set_session(
&self,
provider: &str,
token: String,
expires_at: u64,
metadata: Option<serde_json::Value>,
) {
self.credentials.write().insert(
provider.to_string(),
AuthCredential::Session {
token,
expires_at,
metadata,
},
);
self.persist();
}
pub fn update_oauth_tokens(
&self,
provider: &str,
new_access_token: String,
new_refresh_token: Option<String>,
new_expires_at: u64,
) -> AuthResult<()> {
let mut creds = self.credentials.write();
let cred = creds.get_mut(provider).ok_or_else(|| {
AuthError::NotFound(provider.to_string())
})?;
match cred {
AuthCredential::OAuth {
access_token,
refresh_token,
expires_at,
..
} => {
*access_token = new_access_token;
*refresh_token = new_refresh_token;
*expires_at = new_expires_at;
}
_ => {
return Err(AuthError::InvalidFormat(
format!("Provider '{}' does not have OAuth credentials", provider),
));
}
}
drop(creds);
self.persist();
Ok(())
}
pub fn get(&self, provider: &str) -> Option<AuthCredential> {
self.credentials.read().get(provider).cloned()
}
pub fn get_oauth_credential(&self, provider: &str) -> Option<AuthCredential> {
self.credentials.read().get(provider).cloned()
}
pub fn has_oauth_with_refresh(&self, provider: &str) -> bool {
if let Some(cred) = self.credentials.read().get(provider) {
matches!(cred, AuthCredential::OAuth { refresh_token: Some(_), .. })
} else {
false
}
}
pub fn set(&self, provider: &str, credential: AuthCredential) {
self.credentials.write().insert(provider.to_string(), credential);
self.persist();
}
pub fn remove(&self, provider: &str) {
self.credentials.write().remove(provider);
self.persist();
}
pub fn list_providers(&self) -> Vec<String> {
self.credentials.read().keys().cloned().collect()
}
pub fn has(&self, provider: &str) -> bool {
self.credentials.read().contains_key(provider)
}
pub fn get_all(&self) -> HashMap<String, AuthCredential> {
self.credentials.read().clone()
}
pub fn clear(&self) {
self.credentials.write().clear();
self.persist();
}
pub fn reload(&self) {
if let Some(ref storage) = self.file_storage {
match storage.read() {
Ok(Some(content)) => {
if let Ok(creds) = serde_json::from_str(&content) {
*self.credentials.write() = creds;
}
*self.load_error.write() = None;
}
Ok(None) => {
self.credentials.write().clear();
*self.load_error.write() = None;
}
Err(e) => {
*self.load_error.write() = Some(e);
self.record_error(AuthError::ReadError("Failed to reload auth storage".to_string()));
}
}
}
}
fn persist(&self) {
if let Some(ref storage) = self.file_storage {
let creds = self.credentials.read();
if let Ok(json) = serde_json::to_string_pretty(&*creds) {
if let Err(e) = storage.write(&json) {
self.record_error(e);
}
}
}
}
fn record_error(&self, error: AuthError) {
self.errors.write().push(error);
}
pub fn drain_errors(&self) -> Vec<AuthError> {
let mut errors = self.errors.write();
std::mem::take(&mut *errors)
}
pub fn load_error(&self) -> Option<AuthError> {
self.load_error.read().clone()
}
pub fn validate_all(&self) -> Vec<(String, CredentialValidationError)> {
let creds = self.credentials.read();
let mut results = Vec::new();
for (provider, cred) in creds.iter() {
if let Err(e) = cred.validate() {
results.push((provider.clone(), e));
}
}
results
}
pub fn validate(&self, provider: &str) -> Result<(), CredentialValidationError> {
let creds = self.credentials.read();
let cred = creds.get(provider).ok_or_else(|| {
CredentialValidationError::EmptyField(format!("no credential for provider '{}'", provider))
})?;
cred.validate()
}
pub fn configured_providers(&self) -> Vec<String> {
let mut providers: Vec<String> = self.credentials.read().keys().cloned().collect();
providers.sort();
providers
}
pub fn has_multiple_providers(&self) -> bool {
self.credentials.read().len() > 1
}
pub fn primary_provider(&self) -> Option<String> {
let creds = self.credentials.read();
creds.keys().next().cloned()
}
pub fn migrate_provider(&self, from: &str, to: &str) -> AuthResult<()> {
let mut creds = self.credentials.write();
let cred = creds.remove(from).ok_or_else(|| {
AuthError::NotFound(from.to_string())
})?;
creds.insert(to.to_string(), cred);
drop(creds);
self.persist();
Ok(())
}
}
impl Default for AuthStorage {
fn default() -> Self {
Self::new()
}
}
/// Returns the current wall-clock time as whole seconds since the Unix
/// epoch, or `0` when the system clock reads earlier than the epoch.
fn now_secs() -> u64 {
    match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
// The `keyring` feature may not be declared in the crate manifest, hence the
// lint allowance for the `cfg(feature = "keyring")` checks below.
#[allow(unexpected_cfgs)]
/// Thin wrappers around the OS keyring.
///
/// Each function has two definitions selected by the `keyring` feature: a real
/// one that delegates to `keyring::Entry`, and a stub used when the feature is
/// off.
pub mod keyring_support {
use super::*;
/// Returns the secret stored for `service`/`account`, or `None` when the
/// entry cannot be opened or read.
#[cfg(feature = "keyring")]
pub fn get_keyring_secret(service: &str, account: &str) -> Option<String> {
use keyring::Entry;
Entry::new(service, account)
.ok()
.and_then(|entry| entry.get_password().ok())
}
/// Stores `secret` for `service`/`account`; failures map to `AuthError::KeyringError`.
#[cfg(feature = "keyring")]
pub fn set_keyring_secret(service: &str, account: &str, secret: &str) -> AuthResult<()> {
use keyring::Entry;
Entry::new(service, account)
.map_err(|e| AuthError::KeyringError(e.to_string()))?
.set_password(secret)
.map_err(|e| AuthError::KeyringError(e.to_string()))
}
/// Deletes the entry for `service`/`account`; failures map to `AuthError::KeyringError`.
#[cfg(feature = "keyring")]
pub fn delete_keyring_secret(service: &str, account: &str) -> AuthResult<()> {
use keyring::Entry;
Entry::new(service, account)
.map_err(|e| AuthError::KeyringError(e.to_string()))?
.delete_credential()
.map_err(|e| AuthError::KeyringError(e.to_string()))
}
/// Stub used without the `keyring` feature: always returns `None`.
#[cfg(not(feature = "keyring"))]
pub fn get_keyring_secret(_service: &str, _account: &str) -> Option<String> {
None
}
/// Stub used without the `keyring` feature: always fails with `AuthError::KeyringError`.
#[cfg(not(feature = "keyring"))]
pub fn set_keyring_secret(
_service: &str,
_account: &str,
_secret: &str,
) -> AuthResult<()> {
Err(AuthError::KeyringError(
"Keyring support not compiled".to_string(),
))
}
/// Stub used without the `keyring` feature: always fails with `AuthError::KeyringError`.
#[cfg(not(feature = "keyring"))]
pub fn delete_keyring_secret(_service: &str, _account: &str) -> AuthResult<()> {
Err(AuthError::KeyringError(
"Keyring support not compiled".to_string(),
))
}
}
// Unit tests. Every store is created with `AuthStorage::in_memory()`, so no
// test touches the filesystem.
#[cfg(test)]
mod tests {
use super::*;
// A fresh in-memory store has no credentials.
#[test]
fn test_auth_storage_new() {
let storage = AuthStorage::in_memory();
assert!(!storage.has("anthropic"));
}
// A stored API key is returned by `get_api_key`.
#[test]
fn test_set_and_get_api_key() {
let storage = AuthStorage::in_memory();
storage.set_api_key("anthropic", "sk-test123".to_string());
assert!(storage.has("anthropic"));
assert_eq!(
storage.get_api_key("anthropic"),
Some("sk-test123".to_string())
);
}
// A runtime key takes precedence over a stored key.
#[test]
fn test_runtime_override() {
let storage = AuthStorage::in_memory();
storage.set_api_key("anthropic", "stored-key".to_string());
storage.set_runtime_key("anthropic", "runtime-key".to_string());
assert_eq!(
storage.get_api_key("anthropic"),
Some("runtime-key".to_string())
);
}
// `remove` deletes the stored credential.
#[test]
fn test_remove_credential() {
let storage = AuthStorage::in_memory();
storage.set_api_key("anthropic", "sk-test123".to_string());
assert!(storage.has("anthropic"));
storage.remove("anthropic");
assert!(!storage.has("anthropic"));
}
// A stored credential is reported as configured, with source "stored" and
// the credential type as label.
#[test]
fn test_auth_status() {
let storage = AuthStorage::in_memory();
storage.set_api_key("anthropic", "sk-test123".to_string());
let status = storage.get_status("anthropic");
assert!(status.configured);
assert_eq!(status.source, Some("stored".to_string()));
assert_eq!(status.label, Some("api_key".to_string()));
}
// `Display` renders "source (label)" and "not configured".
#[test]
fn test_auth_status_display() {
let status = AuthStatus {
configured: true,
source: Some("stored".to_string()),
label: Some("api_key".to_string()),
};
let display = format!("{}", status);
assert_eq!(display, "stored (api_key)");
let no_config = AuthStatus {
configured: false,
source: None,
label: None,
};
assert_eq!(format!("{}", no_config), "not configured");
}
// `list_providers` contains every stored provider (order not checked).
#[test]
fn test_list_providers() {
let storage = AuthStorage::in_memory();
storage.set_api_key("anthropic", "key1".to_string());
storage.set_api_key("openai", "key2".to_string());
let providers = storage.list_providers();
assert!(providers.contains(&"anthropic".to_string()));
assert!(providers.contains(&"openai".to_string()));
}
// A non-expired OAuth credential yields its access token.
#[test]
fn test_oauth_credential() {
let storage = AuthStorage::in_memory();
storage.set_oauth(
"provider",
"access123".to_string(),
Some("refresh456".to_string()),
u64::MAX,
);
assert!(storage.has("provider"));
assert_eq!(
storage.get_api_key("provider"),
Some("access123".to_string())
);
}
// An OAuth credential with `expires_at == 0` is expired and yields no key.
#[test]
fn test_expired_oauth_token() {
let storage = AuthStorage::in_memory();
storage.set_oauth("provider", "access123".to_string(), None, 0);
let key = storage.get_api_key("provider");
assert!(key.is_none());
}
// `get_all` returns every stored credential.
#[test]
fn test_get_all_credentials() {
let storage = AuthStorage::in_memory();
storage.set_api_key("anthropic", "key1".to_string());
storage.set_api_key("openai", "key2".to_string());
let all = storage.get_all();
assert_eq!(all.len(), 2);
}
// `clear` removes all stored credentials.
#[test]
fn test_clear() {
let storage = AuthStorage::in_memory();
storage.set_api_key("anthropic", "key".to_string());
assert!(storage.has("anthropic"));
storage.clear();
assert!(!storage.has("anthropic"));
}
// Removing a runtime key reveals the stored key again.
#[test]
fn test_remove_runtime_key() {
let storage = AuthStorage::in_memory();
storage.set_api_key("anthropic", "stored".to_string());
storage.set_runtime_key("anthropic", "runtime".to_string());
assert_eq!(
storage.get_api_key("anthropic"),
Some("runtime".to_string())
);
storage.remove_runtime_key("anthropic");
assert_eq!(storage.get_api_key("anthropic"), Some("stored".to_string()));
}
// `is_expired`: never for API keys; for OAuth it depends on `expires_at`.
#[test]
fn test_auth_credential_is_expired() {
let api_key_cred = AuthCredential::ApiKey {
key: "test".to_string(),
};
assert!(!api_key_cred.is_expired());
let future_time = now_secs() + 3600;
let oauth_cred = AuthCredential::OAuth {
access_token: "token".to_string(),
refresh_token: Some("refresh".to_string()),
expires_at: future_time,
scopes: None,
provider_data: None,
};
assert!(!oauth_cred.is_expired());
let oauth_cred_expired = AuthCredential::OAuth {
access_token: "token".to_string(),
refresh_token: Some("refresh".to_string()),
expires_at: 0,
scopes: None,
provider_data: None,
};
assert!(oauth_cred_expired.is_expired());
}
// `needs_refresh` requires a refresh token and an expiry within 60 seconds.
#[test]
fn test_auth_credential_needs_refresh() {
let future_time = now_secs() + 120;
let oauth_cred = AuthCredential::OAuth {
access_token: "token".to_string(),
refresh_token: Some("refresh".to_string()),
expires_at: future_time,
scopes: None,
provider_data: None,
};
assert!(!oauth_cred.needs_refresh());
let soon = now_secs() + 30;
let oauth_soon = AuthCredential::OAuth {
access_token: "token".to_string(),
refresh_token: Some("refresh".to_string()),
expires_at: soon,
scopes: None,
provider_data: None,
};
assert!(oauth_soon.needs_refresh());
let no_refresh = AuthCredential::OAuth {
access_token: "token".to_string(),
refresh_token: None,
expires_at: future_time,
scopes: None,
provider_data: None,
};
assert!(!no_refresh.needs_refresh());
let api_key_cred = AuthCredential::ApiKey {
key: "test".to_string(),
};
assert!(!api_key_cred.needs_refresh());
}
// `access_token` is `Some` only for non-expired OAuth/session credentials.
#[test]
fn test_auth_credential_access_token() {
let future_time = now_secs() + 3600;
let oauth_cred = AuthCredential::OAuth {
access_token: "valid_token".to_string(),
refresh_token: Some("refresh".to_string()),
expires_at: future_time,
scopes: None,
provider_data: None,
};
assert_eq!(oauth_cred.access_token(), Some("valid_token"));
let expired_cred = AuthCredential::OAuth {
access_token: "expired_token".to_string(),
refresh_token: Some("refresh".to_string()),
expires_at: 0,
scopes: None,
provider_data: None,
};
assert!(expired_cred.access_token().is_none());
let api_key_cred = AuthCredential::ApiKey {
key: "api_key_token".to_string(),
};
assert!(api_key_cred.access_token().is_none());
}
// `get_oauth_credential` returns a stored OAuth credential.
#[test]
fn test_get_oauth_credential() {
let storage = AuthStorage::in_memory();
storage.set_oauth(
"provider",
"access".to_string(),
Some("refresh".to_string()),
u64::MAX,
);
let cred = storage.get_oauth_credential("provider");
assert!(cred.is_some());
assert!(matches!(cred.unwrap(), AuthCredential::OAuth { .. }));
}
// `has_oauth_with_refresh` is true only for OAuth credentials with a refresh token.
#[test]
fn test_has_oauth_with_refresh() {
let storage = AuthStorage::in_memory();
storage.set_oauth(
"with_refresh",
"access".to_string(),
Some("refresh".to_string()),
u64::MAX,
);
assert!(storage.has_oauth_with_refresh("with_refresh"));
storage.set_oauth("without_refresh", "access".to_string(), None, u64::MAX);
assert!(!storage.has_oauth_with_refresh("without_refresh"));
storage.set_api_key("apikey_provider", "key".to_string());
assert!(!storage.has_oauth_with_refresh("apikey_provider"));
}
// `set_oauth_full` keeps scopes and provider data.
#[test]
fn test_set_oauth_full() {
let storage = AuthStorage::in_memory();
storage.set_oauth_full(
"provider",
"access_token".to_string(),
Some("refresh_token".to_string()),
3600,
Some("read write".to_string()),
Some(serde_json::json!({"extra": "data"})),
);
let cred = storage.get_oauth_credential("provider");
assert!(cred.is_some());
if let AuthCredential::OAuth {
scopes,
provider_data,
..
} = cred.unwrap()
{
assert_eq!(scopes, Some("read write".to_string()));
assert!(provider_data.is_some());
} else {
panic!("Expected OAuth credential");
}
}
// A session with `expires_at == 0` never expires and yields its token.
#[test]
fn test_session_token() {
let storage = AuthStorage::in_memory();
storage.set_session(
"browser",
"session-token-123".to_string(),
0, Some(serde_json::json!({"user": "test"})),
);
assert!(storage.has("browser"));
assert_eq!(
storage.get_api_key("browser"),
Some("session-token-123".to_string())
);
let cred = storage.get("browser").unwrap();
assert!(matches!(cred, AuthCredential::Session { .. }));
assert!(cred.access_token().is_some());
}
// A session whose non-zero expiry is in the past yields no key.
#[test]
fn test_session_token_expired() {
let storage = AuthStorage::in_memory();
storage.set_session("browser", "session-token".to_string(), 1, None);
assert!(storage.get_api_key("browser").is_none());
}
// `validate` rejects empty and placeholder API keys and empty OAuth tokens.
#[test]
fn test_credential_validation() {
let valid = AuthCredential::ApiKey { key: "sk-valid".to_string() };
assert!(valid.validate().is_ok());
let empty = AuthCredential::ApiKey { key: "".to_string() };
assert!(empty.validate().is_err());
let placeholder = AuthCredential::ApiKey { key: "your-api-key-here".to_string() };
assert!(placeholder.validate().is_err());
let valid_oauth = AuthCredential::OAuth {
access_token: "token".to_string(),
refresh_token: None,
expires_at: now_secs() + 3600,
scopes: None,
provider_data: None,
};
assert!(valid_oauth.validate().is_ok());
let invalid_oauth = AuthCredential::OAuth {
access_token: "".to_string(),
refresh_token: None,
expires_at: 1000,
scopes: None,
provider_data: None,
};
assert!(invalid_oauth.validate().is_err());
}
// `validate_all` reports only the invalid credentials.
#[test]
fn test_validate_all() {
let storage = AuthStorage::in_memory();
storage.set_api_key("valid", "sk-good".to_string());
storage.set_api_key("empty", "".to_string());
let errors = storage.validate_all();
assert_eq!(errors.len(), 1);
assert_eq!(errors[0].0, "empty");
}
// `update_oauth_tokens` replaces the access token of an OAuth credential.
#[test]
fn test_update_oauth_tokens() {
let storage = AuthStorage::in_memory();
storage.set_oauth(
"provider",
"old-access".to_string(),
Some("old-refresh".to_string()),
now_secs() + 3600,
);
storage
.update_oauth_tokens(
"provider",
"new-access".to_string(),
Some("new-refresh".to_string()),
now_secs() + 7200,
)
.unwrap();
let key = storage.get_api_key("provider");
assert_eq!(key, Some("new-access".to_string()));
}
// `update_oauth_tokens` fails when the stored credential is not OAuth.
#[test]
fn test_update_oauth_tokens_wrong_type() {
let storage = AuthStorage::in_memory();
storage.set_api_key("provider", "key".to_string());
let result = storage.update_oauth_tokens(
"provider",
"new-access".to_string(),
None,
now_secs() + 3600,
);
assert!(result.is_err());
}
// `migrate_provider` moves the credential to the new provider name.
#[test]
fn test_migrate_provider() {
let storage = AuthStorage::in_memory();
storage.set_api_key("old-provider", "key123".to_string());
storage.migrate_provider("old-provider", "new-provider").unwrap();
assert!(!storage.has("old-provider"));
assert!(storage.has("new-provider"));
assert_eq!(storage.get_api_key("new-provider"), Some("key123".to_string()));
}
// `migrate_provider` fails when the source provider does not exist.
#[test]
fn test_migrate_provider_not_found() {
let storage = AuthStorage::in_memory();
let result = storage.migrate_provider("nonexistent", "target");
assert!(result.is_err());
}
// A fresh in-memory store has no recorded errors.
#[test]
fn test_error_draining() {
let storage = AuthStorage::in_memory();
let errors = storage.drain_errors();
assert!(errors.is_empty());
}
// The fallback resolver supplies keys until it is cleared.
#[test]
fn test_fallback_resolver() {
let storage = AuthStorage::in_memory();
storage.set_fallback_resolver(Arc::new(FnFallbackResolver::new(
Box::new(|provider| {
if provider == "custom" {
Some("custom-key-from-config".to_string())
} else {
None
}
}),
)));
assert_eq!(
storage.get_api_key("custom"),
Some("custom-key-from-config".to_string())
);
assert!(storage.get_api_key("unknown").is_none());
storage.clear_fallback_resolver();
assert!(storage.get_api_key("custom").is_none());
}
// `include_fallback = false` skips the fallback resolver.
#[test]
fn test_get_api_key_with_options() {
let storage = AuthStorage::in_memory();
storage.set_fallback_resolver(Arc::new(FnFallbackResolver::new(
Box::new(|_| Some("fallback-key".to_string())),
)));
assert_eq!(
storage.get_api_key_with_options("test", true),
Some("fallback-key".to_string())
);
assert!(storage.get_api_key_with_options("test", false).is_none());
}
// `configured_providers` returns the providers in sorted order.
#[test]
fn test_configured_providers() {
let storage = AuthStorage::in_memory();
storage.set_api_key("openai", "key".to_string());
storage.set_api_key("anthropic", "key".to_string());
let providers = storage.configured_providers();
assert!(providers.len() >= 2);
let mut sorted = providers.clone();
sorted.sort();
assert_eq!(providers, sorted);
}
// `has_multiple_providers` becomes true with the second stored provider.
#[test]
fn test_has_multiple_providers() {
let storage = AuthStorage::in_memory();
assert!(!storage.has_multiple_providers());
storage.set_api_key("openai", "key1".to_string());
assert!(!storage.has_multiple_providers());
storage.set_api_key("anthropic", "key2".to_string());
assert!(storage.has_multiple_providers());
}
// `set` followed by `get` returns a credential of the same variant.
#[test]
fn test_set_and_get_credential() {
let storage = AuthStorage::in_memory();
let cred = AuthCredential::Session {
token: "abc".to_string(),
expires_at: 0,
metadata: None,
};
storage.set("custom", cred);
let retrieved = storage.get("custom");
assert!(retrieved.is_some());
assert!(matches!(retrieved.unwrap(), AuthCredential::Session { .. }));
}
// `type_name` returns the expected name for each variant.
#[test]
fn test_credential_type_name() {
assert_eq!(
AuthCredential::ApiKey { key: "k".to_string() }.type_name(),
"api_key"
);
assert_eq!(
AuthCredential::OAuth {
access_token: "t".to_string(),
refresh_token: None,
expires_at: 0,
scopes: None,
provider_data: None,
}
.type_name(),
"oauth"
);
assert_eq!(
AuthCredential::Session {
token: "t".to_string(),
expires_at: 0,
metadata: None,
}
.type_name(),
"session"
);
}
}