use std::collections::HashMap;
use std::fs;
use std::path::{Path, PathBuf};

use crate::controller::{CompactionConfig, LLMSessionConfig, ToolCompaction};
use serde::Deserialize;
11
12pub trait AgentConfig {
17 fn config_path(&self) -> &str;
22
23 fn default_system_prompt(&self) -> &str;
25
26 fn log_prefix(&self) -> &str;
28
29 fn name(&self) -> &str;
31
32 fn channel_buffer_size(&self) -> Option<usize> {
41 None
42 }
43}
44
45#[derive(Debug, Deserialize)]
65pub struct ProviderConfig {
66 pub provider: String,
68 pub api_key: String,
70 #[serde(default)]
72 pub model: String,
73 pub system_prompt: Option<String>,
75}
76
77#[derive(Debug, Deserialize)]
79pub struct ConfigFile {
80 #[serde(default)]
82 pub providers: Vec<ProviderConfig>,
83
84 pub default_provider: Option<String>,
86}
87
88pub struct LLMRegistry {
90 configs: HashMap<String, LLMSessionConfig>,
91 default_provider: Option<String>,
92}
93
94impl LLMRegistry {
95 pub fn new() -> Self {
97 Self {
98 configs: HashMap::new(),
99 default_provider: None,
100 }
101 }
102
103 pub fn load_from_file(path: &PathBuf, default_system_prompt: &str) -> Result<Self, ConfigError> {
105 let content = fs::read_to_string(path).map_err(|e| ConfigError::ReadError {
106 path: path.display().to_string(),
107 source: e.to_string(),
108 })?;
109
110 let config_file: ConfigFile =
111 serde_yaml::from_str(&content).map_err(|e| ConfigError::ParseError {
112 path: path.display().to_string(),
113 source: e.to_string(),
114 })?;
115
116 let mut registry = Self::new();
117 registry.default_provider = config_file.default_provider;
118
119 for provider_config in config_file.providers {
120 let session_config = Self::create_session_config(&provider_config, default_system_prompt)?;
121 registry
122 .configs
123 .insert(provider_config.provider.clone(), session_config);
124
125 if registry.default_provider.is_none() {
127 registry.default_provider = Some(provider_config.provider);
128 }
129 }
130
131 Ok(registry)
132 }
133
134 fn create_session_config(config: &ProviderConfig, default_system_prompt: &str) -> Result<LLMSessionConfig, ConfigError> {
136 use super::providers::get_provider_info;
137
138 let provider_name = config.provider.to_lowercase();
139
140 let mut session_config = if let Some(info) = get_provider_info(&provider_name) {
142 let model = if config.model.is_empty() {
144 info.default_model.to_string()
145 } else {
146 config.model.clone()
147 };
148
149 LLMSessionConfig::openai_compatible(
150 &config.api_key,
151 &model,
152 info.base_url,
153 info.context_limit,
154 )
155 } else {
156 match provider_name.as_str() {
158 "anthropic" => {
159 let model = if config.model.is_empty() {
160 "claude-sonnet-4-20250514".to_string()
161 } else {
162 config.model.clone()
163 };
164 LLMSessionConfig::anthropic(&config.api_key, &model)
165 }
166 "openai" => {
167 let model = if config.model.is_empty() {
168 "gpt-4-turbo-preview".to_string()
169 } else {
170 config.model.clone()
171 };
172 LLMSessionConfig::openai(&config.api_key, &model)
173 }
174 "google" => {
175 let model = if config.model.is_empty() {
176 "gemini-2.5-flash".to_string()
177 } else {
178 config.model.clone()
179 };
180 LLMSessionConfig::google(&config.api_key, &model)
181 }
182 other => {
183 return Err(ConfigError::UnknownProvider {
184 provider: other.to_string(),
185 })
186 }
187 }
188 };
189
190 let system_prompt = config
192 .system_prompt
193 .clone()
194 .unwrap_or_else(|| default_system_prompt.to_string());
195 session_config = session_config.with_system_prompt(system_prompt);
196
197 session_config = session_config.with_threshold_compaction(CompactionConfig {
202 threshold: 0.05,
203 keep_recent_turns: 1,
204 tool_compaction: ToolCompaction::Summarize,
205 });
206
207 Ok(session_config)
208 }
209
210 pub fn get_default(&self) -> Option<&LLMSessionConfig> {
212 self.default_provider
213 .as_ref()
214 .and_then(|p| self.configs.get(p))
215 .or_else(|| self.configs.values().next())
216 }
217
218 pub fn get(&self, provider: &str) -> Option<&LLMSessionConfig> {
220 self.configs.get(provider)
221 }
222
223 pub fn default_provider_name(&self) -> Option<&str> {
225 self.default_provider.as_deref()
226 }
227
228 pub fn is_empty(&self) -> bool {
230 self.configs.is_empty()
231 }
232
233 pub fn providers(&self) -> Vec<&str> {
235 self.configs.keys().map(|s| s.as_str()).collect()
236 }
237
238 pub fn with_environment_context(mut self) -> Self {
250 use super::environment::EnvironmentContext;
251
252 let context = EnvironmentContext::gather();
253 let context_section = context.to_prompt_section();
254
255 for config in self.configs.values_mut() {
256 if let Some(ref prompt) = config.system_prompt {
257 config.system_prompt = Some(format!("{}\n\n{}", prompt, context_section));
258 } else {
259 config.system_prompt = Some(context_section.clone());
260 }
261 }
262
263 self
264 }
265}
266
267impl Default for LLMRegistry {
268 fn default() -> Self {
269 Self::new()
270 }
271}
272
273#[derive(Debug)]
275pub enum ConfigError {
276 NoHomeDirectory,
278 ReadError { path: String, source: String },
280 ParseError { path: String, source: String },
282 UnknownProvider { provider: String },
284}
285
286impl std::fmt::Display for ConfigError {
287 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
288 match self {
289 ConfigError::NoHomeDirectory => write!(f, "Could not determine home directory"),
290 ConfigError::ReadError { path, source } => {
291 write!(f, "Failed to read config file '{}': {}", path, source)
292 }
293 ConfigError::ParseError { path, source } => {
294 write!(f, "Failed to parse config file '{}': {}", path, source)
295 }
296 ConfigError::UnknownProvider { provider } => {
297 write!(f, "Unknown provider: {}", provider)
298 }
299 }
300 }
301}
302
303impl std::error::Error for ConfigError {}
304
305pub fn load_config<A: AgentConfig>(agent_config: &A) -> LLMRegistry {
310 let config_path = agent_config.config_path();
311 let default_prompt = agent_config.default_system_prompt();
312
313 let path = if let Some(rest) = config_path.strip_prefix("~/") {
315 match dirs::home_dir() {
316 Some(home) => home.join(rest),
317 None => {
318 tracing::debug!("Could not determine home directory");
319 PathBuf::from(config_path)
320 }
321 }
322 } else {
323 PathBuf::from(config_path)
324 };
325
326 match LLMRegistry::load_from_file(&path, default_prompt) {
328 Ok(registry) if !registry.is_empty() => {
329 tracing::info!("Loaded configuration from {}", path.display());
330 return registry;
331 }
332 Ok(_) => {
333 tracing::debug!("Config file empty, trying environment variables");
334 }
335 Err(e) => {
336 tracing::debug!("Could not load config file: {}", e);
337 }
338 }
339
340 let mut registry = LLMRegistry::new();
342
343 let compaction = CompactionConfig {
345 threshold: 0.05,
346 keep_recent_turns: 1,
347 tool_compaction: ToolCompaction::Summarize,
348 };
349
350 if let Ok(api_key) = std::env::var("ANTHROPIC_API_KEY") {
351 let model = std::env::var("ANTHROPIC_MODEL")
352 .unwrap_or_else(|_| "claude-sonnet-4-20250514".to_string());
353
354 let config = LLMSessionConfig::anthropic(&api_key, &model)
355 .with_system_prompt(default_prompt)
356 .with_threshold_compaction(compaction.clone());
357
358 registry.configs.insert("anthropic".to_string(), config);
359 registry.default_provider = Some("anthropic".to_string());
360
361 tracing::info!("Loaded Anthropic configuration from environment");
362 }
363
364 if let Ok(api_key) = std::env::var("OPENAI_API_KEY") {
365 let model =
366 std::env::var("OPENAI_MODEL").unwrap_or_else(|_| "gpt-4-turbo-preview".to_string());
367
368 let config = LLMSessionConfig::openai(&api_key, &model)
369 .with_system_prompt(default_prompt)
370 .with_threshold_compaction(compaction.clone());
371
372 registry.configs.insert("openai".to_string(), config);
373 if registry.default_provider.is_none() {
374 registry.default_provider = Some("openai".to_string());
375 }
376
377 tracing::info!("Loaded OpenAI configuration from environment");
378 }
379
380 if let Ok(api_key) = std::env::var("GOOGLE_API_KEY") {
381 let model =
382 std::env::var("GOOGLE_MODEL").unwrap_or_else(|_| "gemini-2.5-flash".to_string());
383
384 let config = LLMSessionConfig::google(&api_key, &model)
385 .with_system_prompt(default_prompt)
386 .with_threshold_compaction(compaction.clone());
387
388 registry.configs.insert("google".to_string(), config);
389 if registry.default_provider.is_none() {
390 registry.default_provider = Some("google".to_string());
391 }
392
393 tracing::info!("Loaded Google (Gemini) configuration from environment");
394 }
395
396 for (name, info) in super::providers::KNOWN_PROVIDERS {
398 let api_key = if info.requires_api_key {
401 match std::env::var(info.env_var) {
402 Ok(key) if !key.is_empty() => key,
403 _ => continue, }
405 } else {
406 if std::env::var(info.env_var).is_err() {
408 continue;
409 }
410 String::new() };
412
413 let model =
414 std::env::var(info.model_env_var).unwrap_or_else(|_| info.default_model.to_string());
415
416 let config = LLMSessionConfig::openai_compatible(&api_key, &model, info.base_url, info.context_limit)
417 .with_system_prompt(default_prompt)
418 .with_threshold_compaction(compaction.clone());
419
420 registry.configs.insert(name.to_string(), config);
421 if registry.default_provider.is_none() {
422 registry.default_provider = Some(name.to_string());
423 }
424
425 tracing::info!("Loaded {} configuration from environment", info.name);
426 }
427
428 registry
429}
430
431#[cfg(test)]
432mod tests {
433 use super::*;
434
435 #[test]
436 fn test_parse_config() {
437 let yaml = r#"
438providers:
439 - provider: anthropic
440 api_key: test-key
441 model: claude-sonnet-4-20250514
442default_provider: anthropic
443"#;
444 let config: ConfigFile = serde_yaml::from_str(yaml).unwrap();
445 assert_eq!(config.providers.len(), 1);
446 assert_eq!(config.providers[0].provider, "anthropic");
447 assert_eq!(config.default_provider, Some("anthropic".to_string()));
448 }
449
450 #[test]
451 fn test_parse_known_provider() {
452 let yaml = r#"
453providers:
454 - provider: groq
455 api_key: gsk_test_key
456 model: llama-3.3-70b-versatile
457"#;
458 let config: ConfigFile = serde_yaml::from_str(yaml).unwrap();
459 assert_eq!(config.providers.len(), 1);
460 assert_eq!(config.providers[0].provider, "groq");
461 }
462
463 #[test]
464 fn test_known_provider_default_model() {
465 let provider_config = ProviderConfig {
467 provider: "groq".to_string(),
468 api_key: "test-key".to_string(),
469 model: String::new(), system_prompt: None,
471 };
472
473 let session_config = LLMRegistry::create_session_config(&provider_config, "test prompt").unwrap();
474 assert_eq!(session_config.model, "llama-3.3-70b-versatile");
476 assert!(session_config.base_url.is_some());
478 assert!(session_config.base_url.as_ref().unwrap().contains("groq.com"));
479 }
480
481 #[test]
482 fn test_empty_registry() {
483 let registry = LLMRegistry::new();
484 assert!(registry.is_empty());
485 assert!(registry.get_default().is_none());
486 }
487}