use std::collections::HashMap;
use std::path::Path;
use serde::{Deserialize, Serialize};
use crate::error::{Result, SaorsaAgentError};
/// Describes where the credential for a single provider comes from.
///
/// Serialized as an internally tagged JSON object whose `type` field is the
/// snake_case variant name, e.g. `{"type": "env_var", "name": "OPENAI_API_KEY"}`.
///
/// `Debug` is implemented by hand (below) so that literal keys are redacted and
/// cannot leak through `{:?}` logging of this type or of [`AuthConfig`].
#[derive(Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum AuthEntry {
    /// A literal API key stored directly in the config file.
    ApiKey {
        key: String,
    },
    /// The name of an environment variable that is read when the entry is resolved.
    EnvVar {
        name: String,
    },
    /// A shell command (run via `sh -c` by [`resolve`]) whose trimmed stdout is the key.
    Command {
        command: String,
    },
}

impl std::fmt::Debug for AuthEntry {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // Never print the secret itself; only show that a key is present.
            Self::ApiKey { .. } => f
                .debug_struct("ApiKey")
                .field("key", &format_args!("<redacted>"))
                .finish(),
            Self::EnvVar { name } => f.debug_struct("EnvVar").field("name", name).finish(),
            Self::Command { command } => f
                .debug_struct("Command")
                .field("command", command)
                .finish(),
        }
    }
}
/// Credential configuration for all providers, persisted as JSON by [`save`]
/// and read back by [`load`].
///
/// The provider map is flattened into the top-level JSON object, so a file
/// looks like `{"anthropic": {"type": "api_key", "key": "..."}}`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct AuthConfig {
    /// Credential source for each provider, keyed by provider name.
    #[serde(flatten)]
    pub providers: HashMap<String, AuthEntry>,
}
/// Loads the auth configuration stored as JSON at `path`.
///
/// A missing file is not an error: an empty [`AuthConfig`] is returned, so
/// first-run callers need no special casing.
///
/// # Errors
///
/// - [`SaorsaAgentError::ConfigIo`] if the file exists but cannot be read
///   (e.g. permission denied, or `path` is a directory).
/// - [`SaorsaAgentError::ConfigParse`] if the contents are not a valid auth config.
pub fn load(path: &Path) -> Result<AuthConfig> {
    // Read directly and treat only `NotFound` as "no config yet". Checking
    // `path.exists()` first is racy, and `exists()` also returns `false` on
    // permission/metadata errors, which would silently discard an unreadable
    // config and replace it with an empty one.
    let data = match std::fs::read_to_string(path) {
        Ok(data) => data,
        Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
            return Ok(AuthConfig::default());
        }
        Err(e) => return Err(SaorsaAgentError::ConfigIo(e)),
    };
    serde_json::from_str(&data).map_err(SaorsaAgentError::ConfigParse)
}
/// Writes `config` to `path` as pretty-printed JSON, creating parent
/// directories as needed.
///
/// Because the file can contain literal API keys, on Unix it is created with
/// mode `0o600`. A pre-existing file is also tightened to `0o600` before any
/// data is written to it. The write is not atomic: an interrupted write can
/// leave a truncated file behind.
///
/// # Errors
///
/// - [`SaorsaAgentError::ConfigIo`] on any filesystem failure (creating
///   directories, opening, changing permissions, writing).
/// - [`SaorsaAgentError::ConfigParse`] if serialization fails.
pub fn save(config: &AuthConfig, path: &Path) -> Result<()> {
    if let Some(parent) = path.parent() {
        std::fs::create_dir_all(parent).map_err(SaorsaAgentError::ConfigIo)?;
    }
    let data = serde_json::to_string_pretty(config).map_err(SaorsaAgentError::ConfigParse)?;

    // Same semantics as `std::fs::write` (create + truncate), plus
    // owner-only permissions on Unix.
    let mut options = std::fs::OpenOptions::new();
    options.write(true).create(true).truncate(true);
    #[cfg(unix)]
    {
        std::os::unix::fs::OpenOptionsExt::mode(&mut options, 0o600);
    }
    let mut file = options.open(path).map_err(SaorsaAgentError::ConfigIo)?;

    // `mode` only applies when the file is newly created; tighten an existing
    // file too, before the secrets are written into it.
    #[cfg(unix)]
    {
        let owner_only =
            <std::fs::Permissions as std::os::unix::fs::PermissionsExt>::from_mode(0o600);
        file.set_permissions(owner_only)
            .map_err(SaorsaAgentError::ConfigIo)?;
    }

    std::io::Write::write_all(&mut file, data.as_bytes()).map_err(SaorsaAgentError::ConfigIo)?;
    Ok(())
}
/// Resolves an [`AuthEntry`] to the credential string it describes.
///
/// - [`AuthEntry::ApiKey`] returns the stored key unchanged.
/// - [`AuthEntry::EnvVar`] returns the value of the named environment variable.
/// - [`AuthEntry::Command`] runs the command through `sh -c` and returns its
///   stdout with surrounding whitespace trimmed.
///
/// # Errors
///
/// - [`SaorsaAgentError::EnvVarNotFound`] if the variable cannot be read
///   (unset or not valid Unicode).
/// - [`SaorsaAgentError::CommandFailed`] if the command cannot be spawned,
///   exits unsuccessfully (stderr is included in the message), or exits
///   successfully without printing anything on stdout.
pub fn resolve(entry: &AuthEntry) -> Result<String> {
    match entry {
        AuthEntry::ApiKey { key } => Ok(key.clone()),
        AuthEntry::EnvVar { name } => {
            std::env::var(name).map_err(|_| SaorsaAgentError::EnvVarNotFound { name: name.clone() })
        }
        AuthEntry::Command { command } => {
            // Running through `sh -c` lets entries use shell syntax (pipes,
            // expansion). NOTE(review): only safe while the config is trusted.
            let output = std::process::Command::new("sh")
                .arg("-c")
                .arg(command)
                .output()
                .map_err(|e| SaorsaAgentError::CommandFailed(e.to_string()))?;
            if !output.status.success() {
                let stderr = String::from_utf8_lossy(&output.stderr);
                return Err(SaorsaAgentError::CommandFailed(format!(
                    "command exited with {}: {}",
                    output.status,
                    stderr.trim()
                )));
            }
            let key = String::from_utf8_lossy(&output.stdout).trim().to_string();
            // A helper that "succeeds" without printing a key (e.g. it only
            // reports a problem on stderr) would otherwise yield an empty
            // credential that fails much later, far from the cause.
            if key.is_empty() {
                return Err(SaorsaAgentError::CommandFailed(
                    "command succeeded but printed nothing on stdout".to_string(),
                ));
            }
            Ok(key)
        }
    }
}
/// Looks up `provider` in `config` and resolves its credential with [`resolve`].
///
/// # Errors
///
/// A provider with no entry is reported as [`SaorsaAgentError::EnvVarNotFound`]
/// whose `name` is the provider name. Otherwise any error from [`resolve`] is
/// propagated unchanged.
pub fn get_key(config: &AuthConfig, provider: &str) -> Result<String> {
    // NOTE(review): `EnvVarNotFound` is a surprising variant for "no such
    // provider", but existing tests match on it, so it is kept as-is.
    let Some(entry) = config.providers.get(provider) else {
        return Err(SaorsaAgentError::EnvVarNotFound {
            name: provider.to_owned(),
        });
    };
    resolve(entry)
}
#[cfg(test)]
// Panicking on an unexpected `Err`/`None` is the intended failure mode in tests.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    #[test]
    fn roundtrip_auth_config() {
        let tmp = tempfile::tempdir().unwrap();
        let path = tmp.path().join("auth.json");
        let mut config = AuthConfig::default();
        config.providers.insert(
            "anthropic".into(),
            AuthEntry::ApiKey {
                key: "sk-test-123".into(),
            },
        );
        config.providers.insert(
            "openai".into(),
            AuthEntry::EnvVar {
                name: "OPENAI_API_KEY".into(),
            },
        );
        save(&config, &path).unwrap();
        let loaded = load(&path).unwrap();
        assert_eq!(loaded.providers.len(), 2);
        // Check variant and payload, not just key presence, so a regression in
        // the serde tag or field names is caught.
        assert!(matches!(
            loaded.providers.get("anthropic"),
            Some(AuthEntry::ApiKey { key }) if key == "sk-test-123"
        ));
        assert!(matches!(
            loaded.providers.get("openai"),
            Some(AuthEntry::EnvVar { name }) if name == "OPENAI_API_KEY"
        ));
    }

    #[test]
    fn load_missing_file_returns_default() {
        let tmp = tempfile::tempdir().unwrap();
        let path = tmp.path().join("nonexistent.json");
        let config = load(&path).unwrap();
        assert!(config.providers.is_empty());
    }

    #[test]
    fn resolve_api_key() {
        let entry = AuthEntry::ApiKey {
            key: "sk-direct".into(),
        };
        let resolved = resolve(&entry).unwrap();
        assert_eq!(resolved, "sk-direct");
    }

    #[test]
    fn resolve_env_var() {
        let entry = AuthEntry::EnvVar {
            name: "SAORSA_TEST_AUTH_KEY".into(),
        };
        // SAFETY: `set_var` is unsafe because other threads may read the
        // environment concurrently through non-`std::env` paths. Tests in this
        // module only read it through `std`, and this variable name is used by
        // no other test.
        unsafe {
            std::env::set_var("SAORSA_TEST_AUTH_KEY", "sk-from-env");
        }
        let resolved = resolve(&entry);
        // SAFETY: same reasoning as the `set_var` above. Clean up before
        // asserting so a failure does not leak the variable.
        unsafe {
            std::env::remove_var("SAORSA_TEST_AUTH_KEY");
        }
        assert_eq!(resolved.unwrap(), "sk-from-env");
    }

    #[test]
    fn resolve_env_var_missing() {
        let entry = AuthEntry::EnvVar {
            name: "SAORSA_NONEXISTENT_VAR_12345".into(),
        };
        let err = resolve(&entry).unwrap_err();
        assert!(matches!(err, SaorsaAgentError::EnvVarNotFound { .. }));
    }

    // `resolve` shells out to `sh`, which is only guaranteed on Unix.
    #[cfg(unix)]
    #[test]
    fn resolve_command() {
        let entry = AuthEntry::Command {
            command: "echo sk-from-cmd".into(),
        };
        let resolved = resolve(&entry).unwrap();
        assert_eq!(resolved, "sk-from-cmd");
    }

    // `resolve` shells out to `sh`, which is only guaranteed on Unix.
    #[cfg(unix)]
    #[test]
    fn resolve_command_failure() {
        let entry = AuthEntry::Command {
            command: "exit 1".into(),
        };
        let err = resolve(&entry).unwrap_err();
        assert!(matches!(err, SaorsaAgentError::CommandFailed(_)));
    }

    #[test]
    fn get_key_found() {
        let mut config = AuthConfig::default();
        config.providers.insert(
            "test".into(),
            AuthEntry::ApiKey {
                key: "sk-test".into(),
            },
        );
        let key = get_key(&config, "test").unwrap();
        assert_eq!(key, "sk-test");
    }

    #[test]
    fn get_key_missing_provider() {
        let config = AuthConfig::default();
        let err = get_key(&config, "missing").unwrap_err();
        assert!(matches!(err, SaorsaAgentError::EnvVarNotFound { .. }));
    }

    #[test]
    fn save_creates_parent_dirs() {
        let tmp = tempfile::tempdir().unwrap();
        let path = tmp.path().join("nested").join("deep").join("auth.json");
        let config = AuthConfig::default();
        save(&config, &path).unwrap();
        assert!(path.exists());
    }
}