use std::collections::HashMap;
use std::fs;
use std::path::{Path, PathBuf};

use serde::Deserialize;

use crate::controller::{CompactionConfig, LLMSessionConfig, ToolCompaction};
/// Per-agent configuration contract: where the config file lives, the
/// fallback system prompt, and the names used in logs.
pub trait AgentConfig {
    /// Path to the agent's config file; a leading `~/` is expanded to the
    /// home directory by `load_config`.
    fn config_path(&self) -> &str;
    /// System prompt to use when the config file does not supply one.
    fn default_system_prompt(&self) -> &str;
    /// Identifier used to prefix this agent's log output.
    fn log_prefix(&self) -> &str;
    /// Human-readable agent name.
    fn name(&self) -> &str;
    /// Optional override for the agent's channel buffer size; `None`
    /// (the default) means "use the caller's default".
    fn channel_buffer_size(&self) -> Option<usize> {
        None
    }
}
/// A straightforward `AgentConfig` backed by owned strings, with the log
/// prefix derived from the agent name in `SimpleConfig::new`.
pub struct SimpleConfig {
    // Human-readable agent name.
    name: String,
    // Path to the config file (may start with `~/`).
    config_path: String,
    // Fallback system prompt.
    system_prompt: String,
    // Sanitized, lowercase version of `name` (non-alphanumerics -> `_`).
    log_prefix: String,
}
impl SimpleConfig {
    /// Builds a config, deriving `log_prefix` from `name`: alphanumeric
    /// characters are ASCII-lowercased, everything else becomes `_`.
    pub fn new(
        name: impl Into<String>,
        config_path: impl Into<String>,
        system_prompt: impl Into<String>,
    ) -> Self {
        let name: String = name.into();
        // Build the sanitized prefix character by character.
        let mut log_prefix = String::with_capacity(name.len());
        for ch in name.chars() {
            let mapped = if ch.is_alphanumeric() {
                ch.to_ascii_lowercase()
            } else {
                '_'
            };
            log_prefix.push(mapped);
        }
        Self {
            name,
            config_path: config_path.into(),
            system_prompt: system_prompt.into(),
            log_prefix,
        }
    }
}
/// `SimpleConfig` satisfies `AgentConfig` by handing out borrows of its
/// owned fields; `channel_buffer_size` keeps the trait's default (`None`).
impl AgentConfig for SimpleConfig {
    fn name(&self) -> &str {
        self.name.as_str()
    }
    fn config_path(&self) -> &str {
        self.config_path.as_str()
    }
    fn default_system_prompt(&self) -> &str {
        self.system_prompt.as_str()
    }
    fn log_prefix(&self) -> &str {
        self.log_prefix.as_str()
    }
}
/// One `providers:` entry in the YAML config file.
#[derive(Debug, Deserialize)]
pub struct ProviderConfig {
    /// Provider name: "anthropic", "openai", "google", or any entry known
    /// to `super::providers` (matched case-insensitively).
    pub provider: String,
    /// API key for the provider.
    pub api_key: String,
    /// Model identifier; the empty string (serde default when absent)
    /// means "use the provider's default model".
    #[serde(default)]
    pub model: String,
}
/// Top-level shape of the YAML config file.
#[derive(Debug, Deserialize)]
pub struct ConfigFile {
    /// Configured providers; defaults to empty when the key is absent.
    #[serde(default)]
    pub providers: Vec<ProviderConfig>,
    /// Optional name of the preferred provider; when absent, the first
    /// listed provider becomes the default (see `LLMRegistry::load_from_file`).
    pub default_provider: Option<String>,
}
/// Registry of ready-to-use LLM session configurations, keyed by provider
/// name, with an optional preferred default provider.
pub struct LLMRegistry {
    // Session config per registered provider name.
    configs: HashMap<String, LLMSessionConfig>,
    // Name of the provider `get_default` should prefer, if any.
    default_provider: Option<String>,
}
impl LLMRegistry {
pub fn new() -> Self {
Self {
configs: HashMap::new(),
default_provider: None,
}
}
pub fn load_from_file(
path: &PathBuf,
default_system_prompt: &str,
) -> Result<Self, ConfigError> {
let content = fs::read_to_string(path).map_err(|e| ConfigError::ReadError {
path: path.display().to_string(),
source: e.to_string(),
})?;
let config_file: ConfigFile =
serde_yaml::from_str(&content).map_err(|e| ConfigError::ParseError {
path: path.display().to_string(),
source: e.to_string(),
})?;
let mut registry = Self::new();
registry.default_provider = config_file.default_provider;
for provider_config in config_file.providers {
let session_config =
Self::create_session_config(&provider_config, default_system_prompt)?;
registry
.configs
.insert(provider_config.provider.clone(), session_config);
if registry.default_provider.is_none() {
registry.default_provider = Some(provider_config.provider);
}
}
Ok(registry)
}
fn create_session_config(
config: &ProviderConfig,
default_system_prompt: &str,
) -> Result<LLMSessionConfig, ConfigError> {
use super::providers::get_provider_info;
let provider_name = config.provider.to_lowercase();
let mut session_config = if let Some(info) = get_provider_info(&provider_name) {
let model = if config.model.is_empty() {
info.default_model.to_string()
} else {
config.model.clone()
};
LLMSessionConfig::openai_compatible(
&config.api_key,
&model,
info.base_url,
info.context_limit,
)
} else {
match provider_name.as_str() {
"anthropic" => {
let model = if config.model.is_empty() {
"claude-sonnet-4-20250514".to_string()
} else {
config.model.clone()
};
LLMSessionConfig::anthropic(&config.api_key, &model)
}
"openai" => {
let model = if config.model.is_empty() {
"gpt-4-turbo-preview".to_string()
} else {
config.model.clone()
};
LLMSessionConfig::openai(&config.api_key, &model)
}
"google" => {
let model = if config.model.is_empty() {
"gemini-2.5-flash".to_string()
} else {
config.model.clone()
};
LLMSessionConfig::google(&config.api_key, &model)
}
other => {
return Err(ConfigError::UnknownProvider {
provider: other.to_string(),
});
}
}
};
session_config = session_config.with_system_prompt(default_system_prompt);
session_config = session_config.with_threshold_compaction(CompactionConfig {
threshold: 0.05,
keep_recent_turns: 1,
tool_compaction: ToolCompaction::Summarize,
});
Ok(session_config)
}
pub fn get_default(&self) -> Option<&LLMSessionConfig> {
self.default_provider
.as_ref()
.and_then(|p| self.configs.get(p))
.or_else(|| self.configs.values().next())
}
pub fn get(&self, provider: &str) -> Option<&LLMSessionConfig> {
self.configs.get(provider)
}
pub fn default_provider_name(&self) -> Option<&str> {
self.default_provider.as_deref()
}
pub fn is_empty(&self) -> bool {
self.configs.is_empty()
}
pub fn providers(&self) -> Vec<&str> {
self.configs.keys().map(|s| s.as_str()).collect()
}
pub fn with_environment_context(mut self) -> Self {
use super::environment::EnvironmentContext;
let context = EnvironmentContext::gather();
let context_section = context.to_prompt_section();
for config in self.configs.values_mut() {
if let Some(ref prompt) = config.system_prompt {
config.system_prompt = Some(format!("{}\n\n{}", prompt, context_section));
} else {
config.system_prompt = Some(context_section.clone());
}
}
self
}
}
impl Default for LLMRegistry {
fn default() -> Self {
Self::new()
}
}
/// Errors produced while locating, reading, or interpreting the config file.
#[derive(Debug)]
pub enum ConfigError {
    /// The user's home directory could not be determined.
    /// NOTE(review): not constructed anywhere in this file — `load_config`
    /// logs and falls back instead. Verify callers elsewhere use it.
    NoHomeDirectory,
    /// The config file could not be read; `source` is the I/O error text.
    ReadError { path: String, source: String },
    /// The config file was not valid YAML for `ConfigFile`.
    ParseError { path: String, source: String },
    /// A provider name with no known backend.
    UnknownProvider { provider: String },
}
impl std::fmt::Display for ConfigError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ConfigError::NoHomeDirectory => write!(f, "Could not determine home directory"),
ConfigError::ReadError { path, source } => {
write!(f, "Failed to read config file '{}': {}", path, source)
}
ConfigError::ParseError { path, source } => {
write!(f, "Failed to parse config file '{}': {}", path, source)
}
ConfigError::UnknownProvider { provider } => {
write!(f, "Unknown provider: {}", provider)
}
}
}
}
impl std::error::Error for ConfigError {}
pub fn load_config<A: AgentConfig>(agent_config: &A) -> LLMRegistry {
let config_path = agent_config.config_path();
let default_prompt = agent_config.default_system_prompt();
let path = if let Some(rest) = config_path.strip_prefix("~/") {
match dirs::home_dir() {
Some(home) => home.join(rest),
None => {
tracing::debug!("Could not determine home directory");
PathBuf::from(config_path)
}
}
} else {
PathBuf::from(config_path)
};
match LLMRegistry::load_from_file(&path, default_prompt) {
Ok(registry) if !registry.is_empty() => {
tracing::info!("Loaded configuration from {}", path.display());
return registry;
}
Ok(_) => {
tracing::debug!("Config file empty, trying environment variables");
}
Err(e) => {
tracing::debug!("Could not load config file: {}", e);
}
}
let mut registry = LLMRegistry::new();
let compaction = CompactionConfig {
threshold: 0.05,
keep_recent_turns: 1,
tool_compaction: ToolCompaction::Summarize,
};
if let Ok(api_key) = std::env::var("ANTHROPIC_API_KEY") {
let model = std::env::var("ANTHROPIC_MODEL")
.unwrap_or_else(|_| "claude-sonnet-4-20250514".to_string());
let config = LLMSessionConfig::anthropic(&api_key, &model)
.with_system_prompt(default_prompt)
.with_threshold_compaction(compaction.clone());
registry.configs.insert("anthropic".to_string(), config);
registry.default_provider = Some("anthropic".to_string());
tracing::info!("Loaded Anthropic configuration from environment");
}
if let Ok(api_key) = std::env::var("OPENAI_API_KEY") {
let model =
std::env::var("OPENAI_MODEL").unwrap_or_else(|_| "gpt-4-turbo-preview".to_string());
let config = LLMSessionConfig::openai(&api_key, &model)
.with_system_prompt(default_prompt)
.with_threshold_compaction(compaction.clone());
registry.configs.insert("openai".to_string(), config);
if registry.default_provider.is_none() {
registry.default_provider = Some("openai".to_string());
}
tracing::info!("Loaded OpenAI configuration from environment");
}
if let Ok(api_key) = std::env::var("GOOGLE_API_KEY") {
let model =
std::env::var("GOOGLE_MODEL").unwrap_or_else(|_| "gemini-2.5-flash".to_string());
let config = LLMSessionConfig::google(&api_key, &model)
.with_system_prompt(default_prompt)
.with_threshold_compaction(compaction.clone());
registry.configs.insert("google".to_string(), config);
if registry.default_provider.is_none() {
registry.default_provider = Some("google".to_string());
}
tracing::info!("Loaded Google (Gemini) configuration from environment");
}
for (name, info) in super::providers::KNOWN_PROVIDERS {
let api_key = if info.requires_api_key {
match std::env::var(info.env_var) {
Ok(key) if !key.is_empty() => key,
_ => continue, }
} else {
if std::env::var(info.env_var).is_err() {
continue;
}
String::new() };
let model =
std::env::var(info.model_env_var).unwrap_or_else(|_| info.default_model.to_string());
let config = LLMSessionConfig::openai_compatible(
&api_key,
&model,
info.base_url,
info.context_limit,
)
.with_system_prompt(default_prompt)
.with_threshold_compaction(compaction.clone());
registry.configs.insert(name.to_string(), config);
if registry.default_provider.is_none() {
registry.default_provider = Some(name.to_string());
}
tracing::info!("Loaded {} configuration from environment", info.name);
}
registry
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A file naming an explicit `default_provider` deserializes into
    /// `ConfigFile` with that default preserved.
    #[test]
    fn test_parse_config() {
        let yaml = r#"
providers:
- provider: anthropic
api_key: test-key
model: claude-sonnet-4-20250514
default_provider: anthropic
"#;
        let config: ConfigFile = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(config.providers.len(), 1);
        assert_eq!(config.providers[0].provider, "anthropic");
        assert_eq!(config.default_provider, Some("anthropic".to_string()));
    }

    /// An OpenAI-compatible provider from the known-providers table parses
    /// like any other entry; `default_provider` may be absent.
    #[test]
    fn test_parse_known_provider() {
        let yaml = r#"
providers:
- provider: groq
api_key: gsk_test_key
model: llama-3.3-70b-versatile
"#;
        let config: ConfigFile = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(config.providers.len(), 1);
        assert_eq!(config.providers[0].provider, "groq");
    }

    /// An empty `model` field falls back to the known provider's default
    /// model and base URL.
    #[test]
    fn test_known_provider_default_model() {
        let provider_config = ProviderConfig {
            provider: "groq".to_string(),
            api_key: "test-key".to_string(),
            // Empty model string triggers the default-model fallback.
            model: String::new(), };
        let session_config =
            LLMRegistry::create_session_config(&provider_config, "test prompt").unwrap();
        assert_eq!(session_config.model, "llama-3.3-70b-versatile");
        assert!(session_config.base_url.is_some());
        assert!(
            session_config
                .base_url
                .as_ref()
                .unwrap()
                .contains("groq.com")
        );
    }

    /// A freshly-created registry has no providers and no default.
    #[test]
    fn test_empty_registry() {
        let registry = LLMRegistry::new();
        assert!(registry.is_empty());
        assert!(registry.get_default().is_none());
    }
}