// agent_air_runtime/agent/config.rs

1// Configuration management for LLM agents
2//
3// Provides trait-based customization for config paths and system prompts.
4
use std::collections::HashMap;
use std::fs;
use std::path::{Path, PathBuf};

use serde::Deserialize;

use crate::controller::{CompactionConfig, LLMSessionConfig, ToolCompaction};
11
/// Agent-specific configuration hook.
///
/// Implementors supply the config-file location, the default system
/// prompt, and the names used for display and log files. See
/// [`SimpleConfig`] for a ready-made implementation.
pub trait AgentConfig {
    /// Agent name for display and logging.
    fn name(&self) -> &str;

    /// Location of the config file.
    ///
    /// A leading `~/` is expanded to the home directory; any other path
    /// (absolute or relative) is used verbatim.
    fn config_path(&self) -> &str;

    /// System prompt applied when no other prompt is configured.
    fn default_system_prompt(&self) -> &str;

    /// Prefix for this agent's log files (e.g., "multi_code", "europa").
    fn log_prefix(&self) -> &str;

    /// Buffer size for the runtime's internal async channels (LLM
    /// responses, tool results, UI events, etc.).
    ///
    /// Returning `None` selects the default of 500. Larger buffers reduce
    /// backpressure at the cost of memory; smaller ones give tighter flow
    /// control.
    fn channel_buffer_size(&self) -> Option<usize> {
        None
    }
}
44
/// Ready-made configuration for quick agent setup.
///
/// Reach for this when a dedicated config struct would be overkill;
/// normally constructed through `AgentAir::with_config()`.
///
/// # Example
///
/// ```ignore
/// let agent = AgentAir::with_config(
///     "my-agent",
///     "~/.config/my-agent/config.yaml",
///     "You are a helpful assistant."
/// )?;
/// ```
pub struct SimpleConfig {
    // Display/logging name.
    name: String,
    // Config file location (may start with `~/`).
    config_path: String,
    // Default system prompt for the agent.
    system_prompt: String,
    // Derived from `name` in `new`: lowercased, non-alphanumerics -> `_`.
    log_prefix: String,
}

impl SimpleConfig {
    /// Build a simple configuration.
    ///
    /// # Arguments
    /// * `name` - Agent name for display (e.g., "my-agent")
    /// * `config_path` - Path to config file (e.g., "~/.config/my-agent/config.yaml")
    /// * `system_prompt` - Default system prompt for the agent
    ///
    /// The log prefix is derived from `name`: every alphanumeric character
    /// is ASCII-lowercased and everything else becomes an underscore.
    pub fn new(
        name: impl Into<String>,
        config_path: impl Into<String>,
        system_prompt: impl Into<String>,
    ) -> Self {
        let name = name.into();

        // Sanitize the display name into a filesystem-friendly prefix.
        let mut log_prefix = String::with_capacity(name.len());
        for ch in name.chars() {
            let mapped = if ch.is_alphanumeric() {
                ch.to_ascii_lowercase()
            } else {
                '_'
            };
            log_prefix.push(mapped);
        }

        Self {
            log_prefix,
            name,
            config_path: config_path.into(),
            system_prompt: system_prompt.into(),
        }
    }
}
99
100impl AgentConfig for SimpleConfig {
101    fn config_path(&self) -> &str {
102        &self.config_path
103    }
104
105    fn default_system_prompt(&self) -> &str {
106        &self.system_prompt
107    }
108
109    fn log_prefix(&self) -> &str {
110        &self.log_prefix
111    }
112
113    fn name(&self) -> &str {
114        &self.name
115    }
116}
117
/// Provider configuration from YAML
///
/// Supported providers (names are matched case-insensitively):
/// - `anthropic` - Anthropic Claude models
/// - `openai` - OpenAI GPT models
/// - `google` - Google Gemini models
/// - `groq` - Groq (Llama, Mixtral)
/// - `together` - Together AI
/// - `fireworks` - Fireworks AI
/// - `mistral` - Mistral AI
/// - `perplexity` - Perplexity
/// - `deepseek` - DeepSeek
/// - `openrouter` - OpenRouter (access to multiple providers)
/// - `ollama` - Local Ollama server
/// - `lmstudio` - Local LM Studio server
/// - `anyscale` - Anyscale Endpoints
/// - `cerebras` - Cerebras
/// - `sambanova` - SambaNova
/// - `xai` - xAI (Grok)
#[derive(Debug, Deserialize)]
pub struct ProviderConfig {
    /// Provider name (see above for supported values)
    pub provider: String,
    /// API token/key
    pub api_key: String,
    /// Model identifier (optional - uses provider default if not specified)
    ///
    /// `#[serde(default)]` makes a missing `model` key deserialize to the
    /// empty string, which downstream code treats as "use the provider's
    /// default model".
    #[serde(default)]
    pub model: String,
}
147
/// Root configuration structure from YAML
#[derive(Debug, Deserialize)]
pub struct ConfigFile {
    /// List of LLM provider configurations
    ///
    /// `#[serde(default)]` lets the key be omitted entirely, yielding an
    /// empty provider list rather than a parse error.
    #[serde(default)]
    pub providers: Vec<ProviderConfig>,

    /// Default provider to use (optional, defaults to first provider)
    pub default_provider: Option<String>,
}
158
/// LLM Registry - stores loaded provider configurations
///
/// Maps provider names to ready-to-use session configs and tracks which
/// provider acts as the default.
pub struct LLMRegistry {
    // Session configs keyed by provider name (e.g. "anthropic").
    configs: HashMap<String, LLMSessionConfig>,
    // Provider preferred by `get_default`; when `None`, `get_default`
    // falls back to any available config.
    default_provider: Option<String>,
}
164
165impl LLMRegistry {
166    /// Creates an empty registry
167    pub fn new() -> Self {
168        Self {
169            configs: HashMap::new(),
170            default_provider: None,
171        }
172    }
173
174    /// Load configuration from the specified config file path
175    pub fn load_from_file(
176        path: &PathBuf,
177        default_system_prompt: &str,
178    ) -> Result<Self, ConfigError> {
179        let content = fs::read_to_string(path).map_err(|e| ConfigError::ReadError {
180            path: path.display().to_string(),
181            source: e.to_string(),
182        })?;
183
184        let config_file: ConfigFile =
185            serde_yaml::from_str(&content).map_err(|e| ConfigError::ParseError {
186                path: path.display().to_string(),
187                source: e.to_string(),
188            })?;
189
190        let mut registry = Self::new();
191        registry.default_provider = config_file.default_provider;
192
193        for provider_config in config_file.providers {
194            let session_config =
195                Self::create_session_config(&provider_config, default_system_prompt)?;
196            registry
197                .configs
198                .insert(provider_config.provider.clone(), session_config);
199
200            // Set first provider as default if not specified
201            if registry.default_provider.is_none() {
202                registry.default_provider = Some(provider_config.provider);
203            }
204        }
205
206        Ok(registry)
207    }
208
209    /// Create session config from provider config
210    fn create_session_config(
211        config: &ProviderConfig,
212        default_system_prompt: &str,
213    ) -> Result<LLMSessionConfig, ConfigError> {
214        use super::providers::get_provider_info;
215
216        let provider_name = config.provider.to_lowercase();
217
218        // Check if it's a known OpenAI-compatible provider
219        let mut session_config = if let Some(info) = get_provider_info(&provider_name) {
220            // Use model from config, or fall back to provider default
221            let model = if config.model.is_empty() {
222                info.default_model.to_string()
223            } else {
224                config.model.clone()
225            };
226
227            LLMSessionConfig::openai_compatible(
228                &config.api_key,
229                &model,
230                info.base_url,
231                info.context_limit,
232            )
233        } else {
234            // Handle built-in providers
235            match provider_name.as_str() {
236                "anthropic" => {
237                    let model = if config.model.is_empty() {
238                        "claude-sonnet-4-20250514".to_string()
239                    } else {
240                        config.model.clone()
241                    };
242                    LLMSessionConfig::anthropic(&config.api_key, &model)
243                }
244                "openai" => {
245                    let model = if config.model.is_empty() {
246                        "gpt-4-turbo-preview".to_string()
247                    } else {
248                        config.model.clone()
249                    };
250                    LLMSessionConfig::openai(&config.api_key, &model)
251                }
252                "google" => {
253                    let model = if config.model.is_empty() {
254                        "gemini-2.5-flash".to_string()
255                    } else {
256                        config.model.clone()
257                    };
258                    LLMSessionConfig::google(&config.api_key, &model)
259                }
260                other => {
261                    return Err(ConfigError::UnknownProvider {
262                        provider: other.to_string(),
263                    });
264                }
265            }
266        };
267
268        // Set system prompt from AgentConfig default
269        session_config = session_config.with_system_prompt(default_system_prompt);
270
271        // Configure aggressive compaction to avoid rate limits
272        // With 0.05 threshold on 200K context = 10K tokens triggers compaction
273        // keep_recent_turns=1 means only current turn keeps full tool results
274        // All previous tool results are summarized to compact strings
275        session_config = session_config.with_threshold_compaction(CompactionConfig {
276            threshold: 0.05,
277            keep_recent_turns: 1,
278            tool_compaction: ToolCompaction::Summarize,
279        });
280
281        Ok(session_config)
282    }
283
284    /// Get the default session config
285    pub fn get_default(&self) -> Option<&LLMSessionConfig> {
286        self.default_provider
287            .as_ref()
288            .and_then(|p| self.configs.get(p))
289            .or_else(|| self.configs.values().next())
290    }
291
292    /// Get session config by provider name
293    pub fn get(&self, provider: &str) -> Option<&LLMSessionConfig> {
294        self.configs.get(provider)
295    }
296
297    /// Get the default provider name
298    pub fn default_provider_name(&self) -> Option<&str> {
299        self.default_provider.as_deref()
300    }
301
302    /// Check if registry is empty
303    pub fn is_empty(&self) -> bool {
304        self.configs.is_empty()
305    }
306
307    /// Get list of available providers
308    pub fn providers(&self) -> Vec<&str> {
309        self.configs.keys().map(|s| s.as_str()).collect()
310    }
311
312    /// Inject environment context into all session prompts.
313    ///
314    /// This appends environment information (working directory, platform, date)
315    /// to all configured system prompts, giving the LLM awareness of its
316    /// execution context.
317    ///
318    /// # Example
319    ///
320    /// ```ignore
321    /// let registry = load_config(&config).with_environment_context();
322    /// ```
323    pub fn with_environment_context(mut self) -> Self {
324        use super::environment::EnvironmentContext;
325
326        let context = EnvironmentContext::gather();
327        let context_section = context.to_prompt_section();
328
329        for config in self.configs.values_mut() {
330            if let Some(ref prompt) = config.system_prompt {
331                config.system_prompt = Some(format!("{}\n\n{}", prompt, context_section));
332            } else {
333                config.system_prompt = Some(context_section.clone());
334            }
335        }
336
337        self
338    }
339}
340
341impl Default for LLMRegistry {
342    fn default() -> Self {
343        Self::new()
344    }
345}
346
/// Errors produced while loading or interpreting configuration.
#[derive(Debug)]
pub enum ConfigError {
    /// Home directory not found
    NoHomeDirectory,
    /// Failed to read config file
    ReadError { path: String, source: String },
    /// Failed to parse config file
    ParseError { path: String, source: String },
    /// Unknown provider
    UnknownProvider { provider: String },
}

impl std::fmt::Display for ConfigError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // One human-readable message per variant; the wording surfaces in
        // logs, so it is kept stable.
        match self {
            Self::NoHomeDirectory => f.write_str("Could not determine home directory"),
            Self::ReadError { path, source } => {
                write!(f, "Failed to read config file '{}': {}", path, source)
            }
            Self::ParseError { path, source } => {
                write!(f, "Failed to parse config file '{}': {}", path, source)
            }
            Self::UnknownProvider { provider } => write!(f, "Unknown provider: {}", provider),
        }
    }
}

// Marker impl: Debug + Display are enough for the default Error methods.
impl std::error::Error for ConfigError {}
378
379/// Load config for an agent using its AgentConfig trait implementation.
380///
381/// Tries to load from the config file first, then falls back to environment variables.
382/// Supports both absolute paths and paths relative to home directory.
383pub fn load_config<A: AgentConfig>(agent_config: &A) -> LLMRegistry {
384    let config_path = agent_config.config_path();
385    let default_prompt = agent_config.default_system_prompt();
386
387    // Resolve config path - expand ~/ to home directory, otherwise use as-is
388    let path = if let Some(rest) = config_path.strip_prefix("~/") {
389        match dirs::home_dir() {
390            Some(home) => home.join(rest),
391            None => {
392                tracing::debug!("Could not determine home directory");
393                PathBuf::from(config_path)
394            }
395        }
396    } else {
397        PathBuf::from(config_path)
398    };
399
400    // Try loading from config file first
401    match LLMRegistry::load_from_file(&path, default_prompt) {
402        Ok(registry) if !registry.is_empty() => {
403            tracing::info!("Loaded configuration from {}", path.display());
404            return registry;
405        }
406        Ok(_) => {
407            tracing::debug!("Config file empty, trying environment variables");
408        }
409        Err(e) => {
410            tracing::debug!("Could not load config file: {}", e);
411        }
412    }
413
414    // Fall back to environment variables
415    let mut registry = LLMRegistry::new();
416
417    // Default compaction config for environment-based configuration
418    let compaction = CompactionConfig {
419        threshold: 0.05,
420        keep_recent_turns: 1,
421        tool_compaction: ToolCompaction::Summarize,
422    };
423
424    if let Ok(api_key) = std::env::var("ANTHROPIC_API_KEY") {
425        let model = std::env::var("ANTHROPIC_MODEL")
426            .unwrap_or_else(|_| "claude-sonnet-4-20250514".to_string());
427
428        let config = LLMSessionConfig::anthropic(&api_key, &model)
429            .with_system_prompt(default_prompt)
430            .with_threshold_compaction(compaction.clone());
431
432        registry.configs.insert("anthropic".to_string(), config);
433        registry.default_provider = Some("anthropic".to_string());
434
435        tracing::info!("Loaded Anthropic configuration from environment");
436    }
437
438    if let Ok(api_key) = std::env::var("OPENAI_API_KEY") {
439        let model =
440            std::env::var("OPENAI_MODEL").unwrap_or_else(|_| "gpt-4-turbo-preview".to_string());
441
442        let config = LLMSessionConfig::openai(&api_key, &model)
443            .with_system_prompt(default_prompt)
444            .with_threshold_compaction(compaction.clone());
445
446        registry.configs.insert("openai".to_string(), config);
447        if registry.default_provider.is_none() {
448            registry.default_provider = Some("openai".to_string());
449        }
450
451        tracing::info!("Loaded OpenAI configuration from environment");
452    }
453
454    if let Ok(api_key) = std::env::var("GOOGLE_API_KEY") {
455        let model =
456            std::env::var("GOOGLE_MODEL").unwrap_or_else(|_| "gemini-2.5-flash".to_string());
457
458        let config = LLMSessionConfig::google(&api_key, &model)
459            .with_system_prompt(default_prompt)
460            .with_threshold_compaction(compaction.clone());
461
462        registry.configs.insert("google".to_string(), config);
463        if registry.default_provider.is_none() {
464            registry.default_provider = Some("google".to_string());
465        }
466
467        tracing::info!("Loaded Google (Gemini) configuration from environment");
468    }
469
470    // Check for known OpenAI-compatible providers via environment variables
471    for (name, info) in super::providers::KNOWN_PROVIDERS {
472        // For providers that require API keys, the env var must contain the key
473        // For local providers (Ollama, LM Studio), the env var just signals enablement
474        let api_key = if info.requires_api_key {
475            match std::env::var(info.env_var) {
476                Ok(key) if !key.is_empty() => key,
477                _ => continue, // Skip if no API key provided
478            }
479        } else {
480            // Local provider - check if env var is set (any value enables it)
481            if std::env::var(info.env_var).is_err() {
482                continue;
483            }
484            String::new() // Empty API key for local providers
485        };
486
487        let model =
488            std::env::var(info.model_env_var).unwrap_or_else(|_| info.default_model.to_string());
489
490        let config = LLMSessionConfig::openai_compatible(
491            &api_key,
492            &model,
493            info.base_url,
494            info.context_limit,
495        )
496        .with_system_prompt(default_prompt)
497        .with_threshold_compaction(compaction.clone());
498
499        registry.configs.insert(name.to_string(), config);
500        if registry.default_provider.is_none() {
501            registry.default_provider = Some(name.to_string());
502        }
503
504        tracing::info!("Loaded {} configuration from environment", info.name);
505    }
506
507    registry
508}
509
#[cfg(test)]
mod tests {
    use super::*;

    // A minimal file with one provider and an explicit default.
    #[test]
    fn test_parse_config() {
        let raw = r#"
providers:
  - provider: anthropic
    api_key: test-key
    model: claude-sonnet-4-20250514
default_provider: anthropic
"#;
        let parsed: ConfigFile = serde_yaml::from_str(raw).unwrap();
        assert_eq!(parsed.providers.len(), 1);
        assert_eq!(parsed.providers[0].provider, "anthropic");
        assert_eq!(parsed.default_provider, Some("anthropic".to_string()));
    }

    // OpenAI-compatible providers parse like any other entry.
    #[test]
    fn test_parse_known_provider() {
        let raw = r#"
providers:
  - provider: groq
    api_key: gsk_test_key
    model: llama-3.3-70b-versatile
"#;
        let parsed: ConfigFile = serde_yaml::from_str(raw).unwrap();
        assert_eq!(parsed.providers.len(), 1);
        assert_eq!(parsed.providers[0].provider, "groq");
    }

    // When model is not specified, it should use the provider's default.
    #[test]
    fn test_known_provider_default_model() {
        let entry = ProviderConfig {
            provider: "groq".to_string(),
            api_key: "test-key".to_string(),
            model: String::new(), // Empty model
        };

        let session = LLMRegistry::create_session_config(&entry, "test prompt").unwrap();
        // Should use groq's default model
        assert_eq!(session.model, "llama-3.3-70b-versatile");
        // Should have groq's base_url set
        assert!(session.base_url.is_some());
        assert!(session.base_url.as_ref().unwrap().contains("groq.com"));
    }

    // A fresh registry has no configs and therefore no default.
    #[test]
    fn test_empty_registry() {
        let registry = LLMRegistry::new();
        assert!(registry.is_empty());
        assert!(registry.get_default().is_none());
    }
}