use super::types::{AgentConfig, AppConfig, ProviderConfig};
use crate::error::{Error, Result};
/// Checks that a provider entry is usable before an `AgentConfig` is built from it.
///
/// `name` is the provider's key in the config; it is only used in error messages.
///
/// # Errors
///
/// Returns a config error when:
/// * `api_base` is empty, or does not start with `http://` or `https://`;
/// * `default_model` is empty;
/// * `models` is non-empty and does not list `default_model` (an empty `models` list skips
///   this membership check).
fn validate_provider(name: &str, provider: &ProviderConfig) -> Result<()> {
    if provider.api_base.is_empty() {
        return Err(Error::config(format!(
            "Provider '{}' has an empty api_base",
            name
        )));
    }
    if !provider.api_base.starts_with("http://") && !provider.api_base.starts_with("https://") {
        return Err(Error::config(format!(
            "Provider '{}' api_base '{}' must start with http:// or https://",
            name, provider.api_base
        )));
    }
    // Without this check, a provider with an empty `models` list and an empty default model
    // would pass validation, and resolving it without an explicit model would yield "".
    if provider.default_model.is_empty() {
        return Err(Error::config(format!(
            "Provider '{}' has an empty default_model",
            name
        )));
    }
    if !provider.models.is_empty() && !provider.models.contains(&provider.default_model) {
        return Err(Error::config(format!(
            "Provider '{}' default_model '{}' is not listed in models: [{}]",
            name,
            provider.default_model,
            provider.models.join(", ")
        )));
    }
    Ok(())
}
impl AppConfig {
    /// Resolves the effective [`AgentConfig`] for a run.
    ///
    /// * `Some(model)`: selects the first provider, in `providers` iteration order
    ///   (alphabetical by name for a `BTreeMap`), that offers `model`, and uses that model.
    /// * `None`: selects `default_provider` and uses its `default_model`.
    ///
    /// A provider offers a model when the model appears in its `models` list. A provider whose
    /// `models` list is empty offers only its `default_model`. This matches
    /// `validate_provider`, which skips the membership check for an empty list.
    ///
    /// The selected provider is checked with `validate_provider` before use. When the provider
    /// sets no `timeout_secs`, the timeout falls back to 120 seconds.
    ///
    /// # Errors
    ///
    /// Returns a config error when `default_provider` is not present in `providers`, when no
    /// provider offers the requested model, or when the selected provider fails validation.
    pub fn resolve(&self, model_name: Option<&str>) -> Result<AgentConfig> {
        let (provider_name, provider, model) = match model_name {
            Some(model) => {
                let (name, provider) = self
                    .providers
                    .iter()
                    .find(|(_, candidate)| {
                        if candidate.models.is_empty() {
                            candidate.default_model == model
                        } else {
                            // Compare as &str so no String is allocated per provider.
                            candidate.models.iter().any(|listed| listed == model)
                        }
                    })
                    .ok_or_else(|| {
                        Error::config(format!(
                            "Model '{}' not found in any provider. Check the [providers] section in your config file.",
                            model
                        ))
                    })?;
                (name, provider, model)
            }
            None => {
                let (name, provider) = self
                    .providers
                    .get_key_value(&self.default_provider)
                    .ok_or_else(|| {
                        Error::config(format!(
                            "Default provider '{}' not found in config. Available providers: {}",
                            self.default_provider,
                            self.provider_names()
                        ))
                    })?;
                (name, provider, provider.default_model.as_str())
            }
        };

        validate_provider(provider_name, provider)?;

        // Single construction site, so the timeout fallback is defined in exactly one place.
        Ok(AgentConfig {
            provider_name: provider_name.clone(),
            api_base: provider.api_base.clone(),
            api_key: provider.resolve_api_key(),
            model: model.to_owned(),
            max_iterations: self.max_iterations,
            timeout_secs: provider.timeout_secs.unwrap_or(120),
            max_tokens: provider.max_tokens,
        })
    }

    /// Comma-separated provider names in `providers` iteration order, for error messages.
    fn provider_names(&self) -> String {
        self.providers
            .keys()
            .cloned()
            .collect::<Vec<_>>()
            .join(", ")
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::ProviderConfig;
    use std::collections::BTreeMap;

    /// Provider pointing at `api_base` whose only (and default) model is `gpt-4`,
    /// with API key `key`, no env var, and no timeout or token limit.
    fn provider_with_base(api_base: &str) -> ProviderConfig {
        ProviderConfig {
            api_base: api_base.into(),
            default_model: "gpt-4".into(),
            models: vec!["gpt-4".into()],
            env_var: None,
            api_key: Some("key".into()),
            timeout_secs: None,
            max_tokens: None,
        }
    }

    /// Two-provider fixture: `openai` (the default, with an explicit timeout and token
    /// limit) and `anthropic` (no timeout or token limit set).
    fn sample_config() -> AppConfig {
        let openai = ProviderConfig {
            models: vec!["gpt-4".into(), "gpt-3.5-turbo".into()],
            api_key: Some("test-key".into()),
            timeout_secs: Some(60),
            max_tokens: Some(4096),
            ..provider_with_base("https://api.openai.com/v1")
        };
        let anthropic = ProviderConfig {
            default_model: "claude-3".into(),
            models: vec!["claude-3".into()],
            api_key: Some("anthropic-key".into()),
            ..provider_with_base("https://api.anthropic.com/v1")
        };

        let mut providers = BTreeMap::new();
        providers.insert("openai".into(), openai);
        providers.insert("anthropic".into(), anthropic);

        AppConfig {
            default_provider: "openai".into(),
            max_iterations: 5,
            providers,
            mcp_servers: BTreeMap::new(),
        }
    }

    #[test]
    fn resolve_none_returns_default_provider() {
        let agent = sample_config().resolve(None).unwrap();
        assert_eq!(agent.provider_name, "openai");
        assert_eq!(agent.model, "gpt-4");
        assert_eq!(agent.api_key, "test-key");
        assert_eq!(agent.max_iterations, 5);
        assert_eq!(agent.timeout_secs, 60);
        assert_eq!(agent.max_tokens, Some(4096));
    }

    #[test]
    fn resolve_specific_model_finds_correct_provider() {
        let agent = sample_config().resolve(Some("claude-3")).unwrap();
        assert_eq!(agent.provider_name, "anthropic");
        assert_eq!(agent.model, "claude-3");
        assert_eq!(agent.api_key, "anthropic-key");
        // anthropic sets no timeout, so the resolver's fallback applies.
        assert_eq!(agent.timeout_secs, 120);
    }

    #[test]
    fn resolve_unknown_model_returns_error() {
        let message = sample_config()
            .resolve(Some("unknown-model"))
            .unwrap_err()
            .to_string();
        assert!(message.contains("unknown-model"));
    }

    #[test]
    fn resolve_default_errors_when_provider_missing() {
        let empty = AppConfig {
            default_provider: "nonexistent".into(),
            max_iterations: 10,
            providers: BTreeMap::new(),
            mcp_servers: BTreeMap::new(),
        };
        let message = empty.resolve(None).unwrap_err().to_string();
        assert!(message.contains("nonexistent"));
    }

    #[test]
    fn validate_rejects_empty_api_base() {
        let message = validate_provider("test", &provider_with_base(""))
            .unwrap_err()
            .to_string();
        assert!(message.contains("empty api_base"));
    }

    #[test]
    fn validate_rejects_non_http_api_base() {
        let message = validate_provider("test", &provider_with_base("ftp://example.com"))
            .unwrap_err()
            .to_string();
        assert!(message.contains("http://") || message.contains("https://"));
    }

    #[test]
    fn validate_rejects_default_model_not_in_models() {
        let mismatched = ProviderConfig {
            default_model: "gpt-5".into(),
            ..provider_with_base("https://api.example.com")
        };
        let message = validate_provider("test", &mismatched)
            .unwrap_err()
            .to_string();
        assert!(message.contains("gpt-5"));
    }

    #[test]
    fn validate_accepts_valid_provider() {
        let valid = provider_with_base("https://api.example.com");
        assert!(validate_provider("test", &valid).is_ok());
    }
}