use std::collections::HashMap;
use std::fs;
use std::path::{Path, PathBuf};

use serde::Deserialize;

use crate::controller::{CompactionConfig, LLMSessionConfig, ToolCompaction};
11
/// Per-agent configuration: where the agent's config file lives and the
/// defaults used when building LLM sessions for it.
pub trait AgentConfig {
    /// Path to the agent's config file; a leading `~/` is expanded to the
    /// user's home directory by `load_config`.
    fn config_path(&self) -> &str;

    /// System prompt applied to every provider session built for this agent.
    fn default_system_prompt(&self) -> &str;

    /// Prefix used for this agent's log output.
    fn log_prefix(&self) -> &str;

    /// Human-readable agent name.
    fn name(&self) -> &str;

    /// Optional channel buffer size override; `None` (the default) lets the
    /// consumer choose its own default size.
    fn channel_buffer_size(&self) -> Option<usize> {
        None
    }
}
44
/// Plain-data [`AgentConfig`] implementation built from owned strings.
pub struct SimpleConfig {
    name: String,
    config_path: String,
    system_prompt: String,
    // Derived from `name` in `new`: lowercased, non-alphanumerics become '_'.
    log_prefix: String,
}
65
66impl SimpleConfig {
67 pub fn new(
74 name: impl Into<String>,
75 config_path: impl Into<String>,
76 system_prompt: impl Into<String>,
77 ) -> Self {
78 let name = name.into();
79 let log_prefix = name
81 .chars()
82 .map(|c| {
83 if c.is_alphanumeric() {
84 c.to_ascii_lowercase()
85 } else {
86 '_'
87 }
88 })
89 .collect();
90
91 Self {
92 name,
93 config_path: config_path.into(),
94 system_prompt: system_prompt.into(),
95 log_prefix,
96 }
97 }
98}
99
/// Straightforward field accessors; `channel_buffer_size` keeps the trait's
/// default of `None`.
impl AgentConfig for SimpleConfig {
    fn config_path(&self) -> &str {
        &self.config_path
    }

    fn default_system_prompt(&self) -> &str {
        &self.system_prompt
    }

    fn log_prefix(&self) -> &str {
        &self.log_prefix
    }

    fn name(&self) -> &str {
        &self.name
    }
}
117
/// One `providers:` entry in the YAML config file.
#[derive(Debug, Deserialize)]
pub struct ProviderConfig {
    /// Provider identifier (e.g. "anthropic", "groq"); matched
    /// case-insensitively by `LLMRegistry::create_session_config`.
    pub provider: String,
    /// API key forwarded to the session config.
    pub api_key: String,
    /// Model name; empty (the serde default) selects the provider's
    /// default model.
    #[serde(default)]
    pub model: String,
}
147
/// Top-level shape of the YAML config file.
#[derive(Debug, Deserialize)]
pub struct ConfigFile {
    /// Provider entries; an absent key deserializes to an empty list.
    #[serde(default)]
    pub providers: Vec<ProviderConfig>,

    /// Name of the provider to prefer as default; when omitted the first
    /// listed provider is used (see `LLMRegistry::load_from_file`).
    pub default_provider: Option<String>,
}
158
/// Provider-name → session-config map plus the name of the default provider.
pub struct LLMRegistry {
    configs: HashMap<String, LLMSessionConfig>,
    default_provider: Option<String>,
}
164
165impl LLMRegistry {
166 pub fn new() -> Self {
168 Self {
169 configs: HashMap::new(),
170 default_provider: None,
171 }
172 }
173
174 pub fn load_from_file(
176 path: &PathBuf,
177 default_system_prompt: &str,
178 ) -> Result<Self, ConfigError> {
179 let content = fs::read_to_string(path).map_err(|e| ConfigError::ReadError {
180 path: path.display().to_string(),
181 source: e.to_string(),
182 })?;
183
184 let config_file: ConfigFile =
185 serde_yaml::from_str(&content).map_err(|e| ConfigError::ParseError {
186 path: path.display().to_string(),
187 source: e.to_string(),
188 })?;
189
190 let mut registry = Self::new();
191 registry.default_provider = config_file.default_provider;
192
193 for provider_config in config_file.providers {
194 let session_config =
195 Self::create_session_config(&provider_config, default_system_prompt)?;
196 registry
197 .configs
198 .insert(provider_config.provider.clone(), session_config);
199
200 if registry.default_provider.is_none() {
202 registry.default_provider = Some(provider_config.provider);
203 }
204 }
205
206 Ok(registry)
207 }
208
209 fn create_session_config(
211 config: &ProviderConfig,
212 default_system_prompt: &str,
213 ) -> Result<LLMSessionConfig, ConfigError> {
214 use super::providers::get_provider_info;
215
216 let provider_name = config.provider.to_lowercase();
217
218 let mut session_config = if let Some(info) = get_provider_info(&provider_name) {
220 let model = if config.model.is_empty() {
222 info.default_model.to_string()
223 } else {
224 config.model.clone()
225 };
226
227 LLMSessionConfig::openai_compatible(
228 &config.api_key,
229 &model,
230 info.base_url,
231 info.context_limit,
232 )
233 } else {
234 match provider_name.as_str() {
236 "anthropic" => {
237 let model = if config.model.is_empty() {
238 "claude-sonnet-4-20250514".to_string()
239 } else {
240 config.model.clone()
241 };
242 LLMSessionConfig::anthropic(&config.api_key, &model)
243 }
244 "openai" => {
245 let model = if config.model.is_empty() {
246 "gpt-4-turbo-preview".to_string()
247 } else {
248 config.model.clone()
249 };
250 LLMSessionConfig::openai(&config.api_key, &model)
251 }
252 "google" => {
253 let model = if config.model.is_empty() {
254 "gemini-2.5-flash".to_string()
255 } else {
256 config.model.clone()
257 };
258 LLMSessionConfig::google(&config.api_key, &model)
259 }
260 other => {
261 return Err(ConfigError::UnknownProvider {
262 provider: other.to_string(),
263 });
264 }
265 }
266 };
267
268 session_config = session_config.with_system_prompt(default_system_prompt);
270
271 session_config = session_config.with_threshold_compaction(CompactionConfig {
276 threshold: 0.05,
277 keep_recent_turns: 1,
278 tool_compaction: ToolCompaction::Summarize,
279 });
280
281 Ok(session_config)
282 }
283
284 pub fn get_default(&self) -> Option<&LLMSessionConfig> {
286 self.default_provider
287 .as_ref()
288 .and_then(|p| self.configs.get(p))
289 .or_else(|| self.configs.values().next())
290 }
291
292 pub fn get(&self, provider: &str) -> Option<&LLMSessionConfig> {
294 self.configs.get(provider)
295 }
296
297 pub fn default_provider_name(&self) -> Option<&str> {
299 self.default_provider.as_deref()
300 }
301
302 pub fn is_empty(&self) -> bool {
304 self.configs.is_empty()
305 }
306
307 pub fn providers(&self) -> Vec<&str> {
309 self.configs.keys().map(|s| s.as_str()).collect()
310 }
311
312 pub fn with_environment_context(mut self) -> Self {
324 use super::environment::EnvironmentContext;
325
326 let context = EnvironmentContext::gather();
327 let context_section = context.to_prompt_section();
328
329 for config in self.configs.values_mut() {
330 if let Some(ref prompt) = config.system_prompt {
331 config.system_prompt = Some(format!("{}\n\n{}", prompt, context_section));
332 } else {
333 config.system_prompt = Some(context_section.clone());
334 }
335 }
336
337 self
338 }
339}
340
341impl Default for LLMRegistry {
342 fn default() -> Self {
343 Self::new()
344 }
345}
346
/// Errors produced while loading or interpreting a config file.
#[derive(Debug)]
pub enum ConfigError {
    /// The user's home directory could not be determined.
    NoHomeDirectory,
    /// Reading the file at `path` failed; `source` is the I/O error text.
    ReadError { path: String, source: String },
    /// The file at `path` was not valid YAML; `source` is the parser error.
    ParseError { path: String, source: String },
    /// The config named a provider this build does not recognize.
    UnknownProvider { provider: String },
}
359
360impl std::fmt::Display for ConfigError {
361 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
362 match self {
363 ConfigError::NoHomeDirectory => write!(f, "Could not determine home directory"),
364 ConfigError::ReadError { path, source } => {
365 write!(f, "Failed to read config file '{}': {}", path, source)
366 }
367 ConfigError::ParseError { path, source } => {
368 write!(f, "Failed to parse config file '{}': {}", path, source)
369 }
370 ConfigError::UnknownProvider { provider } => {
371 write!(f, "Unknown provider: {}", provider)
372 }
373 }
374 }
375}
376
377impl std::error::Error for ConfigError {}
378
379pub fn load_config<A: AgentConfig>(agent_config: &A) -> LLMRegistry {
384 let config_path = agent_config.config_path();
385 let default_prompt = agent_config.default_system_prompt();
386
387 let path = if let Some(rest) = config_path.strip_prefix("~/") {
389 match dirs::home_dir() {
390 Some(home) => home.join(rest),
391 None => {
392 tracing::debug!("Could not determine home directory");
393 PathBuf::from(config_path)
394 }
395 }
396 } else {
397 PathBuf::from(config_path)
398 };
399
400 match LLMRegistry::load_from_file(&path, default_prompt) {
402 Ok(registry) if !registry.is_empty() => {
403 tracing::info!("Loaded configuration from {}", path.display());
404 return registry;
405 }
406 Ok(_) => {
407 tracing::debug!("Config file empty, trying environment variables");
408 }
409 Err(e) => {
410 tracing::debug!("Could not load config file: {}", e);
411 }
412 }
413
414 let mut registry = LLMRegistry::new();
416
417 let compaction = CompactionConfig {
419 threshold: 0.05,
420 keep_recent_turns: 1,
421 tool_compaction: ToolCompaction::Summarize,
422 };
423
424 if let Ok(api_key) = std::env::var("ANTHROPIC_API_KEY") {
425 let model = std::env::var("ANTHROPIC_MODEL")
426 .unwrap_or_else(|_| "claude-sonnet-4-20250514".to_string());
427
428 let config = LLMSessionConfig::anthropic(&api_key, &model)
429 .with_system_prompt(default_prompt)
430 .with_threshold_compaction(compaction.clone());
431
432 registry.configs.insert("anthropic".to_string(), config);
433 registry.default_provider = Some("anthropic".to_string());
434
435 tracing::info!("Loaded Anthropic configuration from environment");
436 }
437
438 if let Ok(api_key) = std::env::var("OPENAI_API_KEY") {
439 let model =
440 std::env::var("OPENAI_MODEL").unwrap_or_else(|_| "gpt-4-turbo-preview".to_string());
441
442 let config = LLMSessionConfig::openai(&api_key, &model)
443 .with_system_prompt(default_prompt)
444 .with_threshold_compaction(compaction.clone());
445
446 registry.configs.insert("openai".to_string(), config);
447 if registry.default_provider.is_none() {
448 registry.default_provider = Some("openai".to_string());
449 }
450
451 tracing::info!("Loaded OpenAI configuration from environment");
452 }
453
454 if let Ok(api_key) = std::env::var("GOOGLE_API_KEY") {
455 let model =
456 std::env::var("GOOGLE_MODEL").unwrap_or_else(|_| "gemini-2.5-flash".to_string());
457
458 let config = LLMSessionConfig::google(&api_key, &model)
459 .with_system_prompt(default_prompt)
460 .with_threshold_compaction(compaction.clone());
461
462 registry.configs.insert("google".to_string(), config);
463 if registry.default_provider.is_none() {
464 registry.default_provider = Some("google".to_string());
465 }
466
467 tracing::info!("Loaded Google (Gemini) configuration from environment");
468 }
469
470 for (name, info) in super::providers::KNOWN_PROVIDERS {
472 let api_key = if info.requires_api_key {
475 match std::env::var(info.env_var) {
476 Ok(key) if !key.is_empty() => key,
477 _ => continue, }
479 } else {
480 if std::env::var(info.env_var).is_err() {
482 continue;
483 }
484 String::new() };
486
487 let model =
488 std::env::var(info.model_env_var).unwrap_or_else(|_| info.default_model.to_string());
489
490 let config = LLMSessionConfig::openai_compatible(
491 &api_key,
492 &model,
493 info.base_url,
494 info.context_limit,
495 )
496 .with_system_prompt(default_prompt)
497 .with_threshold_compaction(compaction.clone());
498
499 registry.configs.insert(name.to_string(), config);
500 if registry.default_provider.is_none() {
501 registry.default_provider = Some(name.to_string());
502 }
503
504 tracing::info!("Loaded {} configuration from environment", info.name);
505 }
506
507 registry
508}
509
#[cfg(test)]
mod tests {
    use super::*;

    // A file with one provider and an explicit default parses into the
    // expected ConfigFile shape.
    #[test]
    fn test_parse_config() {
        let yaml = r#"
providers:
  - provider: anthropic
    api_key: test-key
    model: claude-sonnet-4-20250514
default_provider: anthropic
"#;
        let config: ConfigFile = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(config.providers.len(), 1);
        assert_eq!(config.providers[0].provider, "anthropic");
        assert_eq!(config.default_provider, Some("anthropic".to_string()));
    }

    // An OpenAI-compatible provider from the known-provider table parses
    // the same way as the built-in ones.
    #[test]
    fn test_parse_known_provider() {
        let yaml = r#"
providers:
  - provider: groq
    api_key: gsk_test_key
    model: llama-3.3-70b-versatile
"#;
        let config: ConfigFile = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(config.providers.len(), 1);
        assert_eq!(config.providers[0].provider, "groq");
    }

    // An empty `model` field selects the known provider's default model
    // and its OpenAI-compatible base URL.
    #[test]
    fn test_known_provider_default_model() {
        let provider_config = ProviderConfig {
            provider: "groq".to_string(),
            api_key: "test-key".to_string(),
            model: String::new(),
        };

        let session_config =
            LLMRegistry::create_session_config(&provider_config, "test prompt").unwrap();
        assert_eq!(session_config.model, "llama-3.3-70b-versatile");
        assert!(session_config.base_url.is_some());
        assert!(
            session_config
                .base_url
                .as_ref()
                .unwrap()
                .contains("groq.com")
        );
    }

    // A fresh registry has no providers and therefore no default.
    #[test]
    fn test_empty_registry() {
        let registry = LLMRegistry::new();
        assert!(registry.is_empty());
        assert!(registry.get_default().is_none());
    }
}