use super::{ApiProvider, ApiProviderConfig, ConfigCache};
use crate::error::{ConfigError, Result};
use serde_json;
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use tokio::fs;
use tracing::{debug, info, warn};
pub struct ConfigLoader {
config_dir: PathBuf,
cache: ConfigCache,
cache_path: PathBuf,
}
impl ConfigLoader {
pub fn new<P: AsRef<Path>>(config_dir: P) -> Self {
let config_dir = config_dir.as_ref().to_path_buf();
let cache_path = ConfigCache::default_cache_path();
Self {
config_dir,
cache: ConfigCache::new(),
cache_path,
}
}
pub async fn init(&mut self) -> Result<()> {
self.cache = ConfigCache::load(&self.cache_path).await?;
debug!("Loaded configuration cache from: {}", self.cache_path.display());
Ok(())
}
pub async fn discover_configs(&self) -> Result<HashMap<ApiProvider, ApiProviderConfig>> {
let mut configs = HashMap::new();
let json_configs = self.load_json_configs().await?;
configs.extend(json_configs);
let env_configs = self.load_env_configs(&configs).await?;
configs.extend(env_configs);
debug!("Discovered {} provider configurations", configs.len());
Ok(configs)
}
async fn load_json_configs(&self) -> Result<HashMap<ApiProvider, ApiProviderConfig>> {
let mut configs = HashMap::new();
let providers = [
ApiProvider::OpenAI,
ApiProvider::Anthropic,
ApiProvider::Google,
];
for provider in providers {
let config_path = self.config_dir.join(provider.config_filename());
if config_path.exists() {
match self.load_json_config(&config_path).await {
Ok(config) => {
info!("Loaded {} configuration from: {}", provider, config_path.display());
configs.insert(provider, config);
}
Err(e) => {
warn!("Failed to load {} configuration from {}: {}",
provider, config_path.display(), e);
}
}
}
}
Ok(configs)
}
async fn load_json_config(&self, path: &Path) -> Result<ApiProviderConfig> {
let content = fs::read_to_string(path).await?;
let config: ApiProviderConfig = serde_json::from_str(&content)
.map_err(|e| ConfigError::InvalidFormat)?;
Ok(config)
}
async fn load_env_configs(&self, existing_configs: &HashMap<ApiProvider, ApiProviderConfig>)
-> Result<HashMap<ApiProvider, ApiProviderConfig>> {
let mut configs = HashMap::new();
let providers = [
ApiProvider::OpenAI,
ApiProvider::Anthropic,
ApiProvider::Google,
];
for provider in providers {
if existing_configs.contains_key(&provider) {
continue;
}
if let Some(config) = self.load_env_config(&provider) {
info!("Loaded {} configuration from environment variables", provider);
configs.insert(provider, config);
}
}
Ok(configs)
}
fn load_env_config(&self, provider: &ApiProvider) -> Option<ApiProviderConfig> {
let prefix = provider.env_prefix();
let base_url = std::env::var(format!("{}_BASE_URL", prefix)).ok();
let api_key = std::env::var(format!("{}_API_KEY", prefix)).ok();
let model = std::env::var(format!("{}_MODEL", prefix)).ok();
if base_url.is_some() || api_key.is_some() || model.is_some() {
Some(ApiProviderConfig {
base_url,
api_key,
model,
extra: HashMap::new(),
})
} else {
None
}
}
pub async fn select_config(&mut self, configs: HashMap<ApiProvider, ApiProviderConfig>)
-> Result<(ApiProvider, ApiProviderConfig)> {
if configs.is_empty() {
return Err(ConfigError::NoConfigFound.into());
}
if configs.len() == 1 {
let (provider, config) = configs.into_iter().next().unwrap();
info!("Using single available configuration: {}", provider);
return Ok((provider, config));
}
if let Some(cached_provider) = self.cache.get_selected_provider() {
if !self.cache.is_expired() {
if let Ok(provider) = cached_provider.parse::<ApiProvider>() {
if let Some(config) = configs.get(&provider) {
info!("Using cached provider selection: {}", provider);
return Ok((provider, config.clone()));
}
}
}
}
self.prompt_user_selection(configs).await
}
async fn prompt_user_selection(&mut self, configs: HashMap<ApiProvider, ApiProviderConfig>)
-> Result<(ApiProvider, ApiProviderConfig)> {
println!("Multiple API provider configurations found:");
let providers: Vec<_> = configs.keys().collect();
for (i, provider) in providers.iter().enumerate() {
let config = configs.get(provider).unwrap();
let source = if config.api_key.is_some() { "configured" } else { "env vars" };
println!(" {}. {} ({})", i + 1, provider, source);
}
println!("Please select a provider (1-{}): ", providers.len());
let mut input = String::new();
std::io::stdin().read_line(&mut input)
.map_err(|_| ConfigError::InvalidValue {
field: "user_selection".to_string(),
value: "failed to read input".to_string(),
})?;
let selection: usize = input.trim().parse()
.map_err(|_| ConfigError::InvalidValue {
field: "user_selection".to_string(),
value: input.trim().to_string(),
})?;
if selection == 0 || selection > providers.len() {
return Err(ConfigError::InvalidValue {
field: "user_selection".to_string(),
value: selection.to_string(),
}.into());
}
let selected_provider = providers[selection - 1].clone();
let selected_config = configs.get(&selected_provider).unwrap().clone();
self.cache.set_selected_provider(selected_provider.to_string());
if let Err(e) = self.cache.save(&self.cache_path).await {
warn!("Failed to save configuration cache: {}", e);
}
info!("Selected provider: {}", selected_provider);
Ok((selected_provider, selected_config))
}
pub async fn load_config(&mut self) -> Result<(ApiProvider, ApiProviderConfig)> {
self.init().await?;
let configs = self.discover_configs().await?;
self.select_config(configs).await
}
}
#[cfg(test)]
mod tests {
use super::*;
use tempfile::tempdir;
#[tokio::test]
async fn test_load_json_config() {
let temp_dir = tempdir().unwrap();
let config_path = temp_dir.path().join("openai.json");
let config_content = r#"{
"base_url": "https://api.openai.com/v1",
"api_key": "test-key",
"model": "gpt-4"
}"#;
fs::write(&config_path, config_content).await.unwrap();
let loader = ConfigLoader::new(temp_dir.path());
let config = loader.load_json_config(&config_path).await.unwrap();
assert_eq!(config.base_url, Some("https://api.openai.com/v1".to_string()));
assert_eq!(config.api_key, Some("test-key".to_string()));
assert_eq!(config.model, Some("gpt-4".to_string()));
}
#[test]
fn test_load_env_config() {
std::env::set_var("OPENAI_API_KEY", "test-env-key");
std::env::set_var("OPENAI_MODEL", "gpt-3.5-turbo");
let loader = ConfigLoader::new(".");
let config = loader.load_env_config(&ApiProvider::OpenAI).unwrap();
assert_eq!(config.api_key, Some("test-env-key".to_string()));
assert_eq!(config.model, Some("gpt-3.5-turbo".to_string()));
std::env::remove_var("OPENAI_API_KEY");
std::env::remove_var("OPENAI_MODEL");
}
}
#[tokio::test]
async fn test_discover_configs_prefers_json_over_env() {
std::env::set_var("OPENAI_API_KEY", "env-key");
std::env::set_var("OPENAI_MODEL", "env-model");
let temp_dir = tempfile::tempdir().unwrap();
let openai_json = temp_dir.path().join("openai.json");
let content = r#"{
"base_url": "https://api.openai.com/v1",
"api_key": "json-key",
"model": "gpt-4"
}"#;
fs::write(&openai_json, content).await.unwrap();
let loader = ConfigLoader::new(temp_dir.path());
let configs = loader.discover_configs().await.unwrap();
let openai = configs.get(&ApiProvider::OpenAI).expect("openai config missing");
assert_eq!(openai.api_key.as_deref(), Some("json-key"));
assert_eq!(openai.model.as_deref(), Some("gpt-4"));
std::env::remove_var("OPENAI_API_KEY");
std::env::remove_var("OPENAI_MODEL");
}
#[tokio::test]
async fn test_select_config_single_and_cache() {
let temp_dir = tempfile::tempdir().unwrap();
let mut loader = ConfigLoader::new(temp_dir.path());
let mut single = HashMap::new();
single.insert(
ApiProvider::OpenAI,
ApiProviderConfig { base_url: Some("https://api.openai.com/v1".into()), api_key: Some("k".into()), model: Some("gpt-4".into()), extra: HashMap::new() }
);
let (prov, _cfg) = loader.select_config(single).await.unwrap();
assert!(matches!(prov, ApiProvider::OpenAI));
let mut many = HashMap::new();
many.insert(
ApiProvider::OpenAI,
ApiProviderConfig { base_url: None, api_key: Some("k1".into()), model: Some("m1".into()), extra: HashMap::new() }
);
many.insert(
ApiProvider::Anthropic,
ApiProviderConfig { base_url: None, api_key: Some("k2".into()), model: Some("m2".into()), extra: HashMap::new() }
);
loader.cache.set_selected_provider("openai".into());
let (prov2, _cfg2) = loader.select_config(many).await.unwrap();
assert!(matches!(prov2, ApiProvider::OpenAI));
}
#[tokio::test]
async fn test_load_config_uses_cache_file() {
let temp_dir = tempfile::tempdir().unwrap();
let openai_json = temp_dir.path().join("openai.json");
let content = r#"{
"base_url": "https://api.openai.com/v1",
"api_key": "json-key",
"model": "gpt-4"
}"#;
fs::write(&openai_json, content).await.unwrap();
let cache_path = temp_dir.path().join("cache.json");
let mut cache = ConfigCache::new();
cache.set_selected_provider("openai".into());
cache.save(&cache_path).await.unwrap();
let mut loader = ConfigLoader::new(temp_dir.path());
loader.cache_path = cache_path;
let (prov, cfg) = loader.load_config().await.unwrap();
assert!(matches!(prov, ApiProvider::OpenAI));
assert_eq!(cfg.api_key.as_deref(), Some("json-key"));
}
#[tokio::test]
async fn test_load_config_no_config_returns_error() {
for k in [
"OPENAI_API_KEY","OPENAI_BASE_URL","OPENAI_MODEL",
"ANTHROPIC_API_KEY","ANTHROPIC_BASE_URL","ANTHROPIC_MODEL",
"GOOGLE_API_KEY","GOOGLE_BASE_URL","GOOGLE_MODEL"
] { let _ = std::env::remove_var(k); }
let temp_dir = tempfile::tempdir().unwrap();
let mut loader = ConfigLoader::new(temp_dir.path());
loader.cache_path = temp_dir.path().join("cache.json");
let err = loader.load_config().await.err().expect("expected error");
let msg = format!("{}", err);
assert!(msg.contains("No configuration"), "unexpected error: {}", msg);
}