Skip to main content

agent_runtime/
config.rs

1use crate::error::{ConfigError, ConfigErrorCode};
2use crate::retry::RetryPolicy;
3use crate::timeout::TimeoutConfig;
4use serde::{Deserialize, Serialize};
5use std::path::Path;
6use std::time::Duration;
7
8/// Main runtime configuration
9#[derive(Debug, Clone, Serialize, Deserialize, Default)]
10pub struct RuntimeConfig {
11    /// LLM provider configurations
12    #[serde(default)]
13    pub llm: LlmConfig,
14
15    /// Retry policy configuration
16    #[serde(default)]
17    pub retry: RetryConfig,
18
19    /// Timeout configuration
20    #[serde(default)]
21    pub timeout: TimeoutConfigSettings,
22
23    /// Logging configuration
24    #[serde(default)]
25    pub logging: LoggingConfig,
26
27    /// Workflow configuration
28    #[serde(default)]
29    pub workflow: WorkflowConfig,
30}
31
32impl RuntimeConfig {
33    /// Load configuration from a TOML file
34    pub fn from_toml_file<P: AsRef<Path>>(path: P) -> Result<Self, ConfigError> {
35        let path = path.as_ref();
36        let content = std::fs::read_to_string(path).map_err(|e| ConfigError {
37            code: ConfigErrorCode::FileNotFound,
38            message: format!("Failed to read config file: {}", e),
39            field: Some(path.display().to_string()),
40        })?;
41
42        toml::from_str(&content).map_err(|e| ConfigError {
43            code: ConfigErrorCode::ParseError,
44            message: format!("Failed to parse TOML: {}", e),
45            field: None,
46        })
47    }
48
49    /// Load configuration from a YAML file
50    pub fn from_yaml_file<P: AsRef<Path>>(path: P) -> Result<Self, ConfigError> {
51        let path = path.as_ref();
52        let content = std::fs::read_to_string(path).map_err(|e| ConfigError {
53            code: ConfigErrorCode::FileNotFound,
54            message: format!("Failed to read config file: {}", e),
55            field: Some(path.display().to_string()),
56        })?;
57
58        serde_yaml::from_str(&content).map_err(|e| ConfigError {
59            code: ConfigErrorCode::ParseError,
60            message: format!("Failed to parse YAML: {}", e),
61            field: None,
62        })
63    }
64
65    /// Load configuration from a file (auto-detects format from extension)
66    pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self, ConfigError> {
67        let path = path.as_ref();
68        let extension = path.extension().and_then(|s| s.to_str()).unwrap_or("");
69
70        match extension {
71            "toml" => Self::from_toml_file(path),
72            "yaml" | "yml" => Self::from_yaml_file(path),
73            _ => Err(ConfigError {
74                code: ConfigErrorCode::ParseError,
75                message: format!(
76                    "Unsupported file extension '{}'. Use .toml, .yaml, or .yml",
77                    extension
78                ),
79                field: Some(path.display().to_string()),
80            }),
81        }
82    }
83
84    /// Load configuration from environment variables
85    /// Prefix: AGENT_RUNTIME_
86    pub fn from_env() -> Result<Self, ConfigError> {
87        let mut settings = config::Config::builder();
88
89        // Add environment variables with prefix
90        settings = settings.add_source(
91            config::Environment::with_prefix("AGENT_RUNTIME")
92                .separator("__")
93                .try_parsing(true),
94        );
95
96        settings
97            .build()
98            .and_then(|c| c.try_deserialize())
99            .map_err(|e| ConfigError {
100                code: ConfigErrorCode::ParseError,
101                message: format!("Failed to parse environment config: {}", e),
102                field: None,
103            })
104    }
105
106    /// Load configuration from multiple sources (file, then env overrides)
107    pub fn from_sources<P: AsRef<Path>>(file_path: Option<P>) -> Result<Self, ConfigError> {
108        let mut settings = config::Config::builder();
109
110        // Start with defaults
111        settings = settings.add_source(config::Config::try_from(&Self::default()).unwrap());
112
113        // Add file if provided
114        if let Some(path) = file_path {
115            let path_str = path.as_ref().display().to_string();
116            settings = settings.add_source(config::File::with_name(&path_str).required(false));
117        }
118
119        // Add environment variables (highest priority)
120        settings = settings.add_source(
121            config::Environment::with_prefix("AGENT_RUNTIME")
122                .separator("__")
123                .try_parsing(true),
124        );
125
126        settings
127            .build()
128            .and_then(|c| c.try_deserialize())
129            .map_err(|e| ConfigError {
130                code: ConfigErrorCode::ParseError,
131                message: format!("Failed to build config: {}", e),
132                field: None,
133            })
134    }
135
136    /// Validate the configuration
137    pub fn validate(&self) -> Result<(), ConfigError> {
138        // Validate LLM config
139        self.llm.validate()?;
140
141        // Validate retry config
142        self.retry.validate()?;
143
144        // Validate timeout config
145        self.timeout.validate()?;
146
147        Ok(())
148    }
149}
150
151/// LLM provider configuration
152#[derive(Debug, Clone, Serialize, Deserialize)]
153pub struct LlmConfig {
154    /// Default provider to use
155    pub default_provider: Option<String>,
156
157    /// OpenAI configuration
158    pub openai: Option<OpenAIConfig>,
159
160    /// Llama.cpp configuration
161    pub llama: Option<LlamaConfig>,
162
163    /// Default model name
164    pub default_model: Option<String>,
165
166    /// Default temperature
167    #[serde(default = "default_temperature")]
168    pub default_temperature: f32,
169
170    /// Default max tokens
171    pub default_max_tokens: Option<u32>,
172}
173
/// Serde default for `LlmConfig::default_temperature`.
fn default_temperature() -> f32 {
    0.7_f32
}
177
178impl Default for LlmConfig {
179    fn default() -> Self {
180        Self {
181            default_provider: None,
182            openai: None,
183            llama: None,
184            default_model: None,
185            default_temperature: 0.7,
186            default_max_tokens: None,
187        }
188    }
189}
190
191impl LlmConfig {
192    fn validate(&self) -> Result<(), ConfigError> {
193        if let Some(temp) = Some(self.default_temperature) {
194            if !(0.0..=2.0).contains(&temp) {
195                return Err(ConfigError {
196                    code: ConfigErrorCode::InvalidValue,
197                    message: "Temperature must be between 0.0 and 2.0".to_string(),
198                    field: Some("llm.default_temperature".to_string()),
199                });
200            }
201        }
202        Ok(())
203    }
204}
205
206/// OpenAI-specific configuration
207#[derive(Debug, Clone, Serialize, Deserialize)]
208pub struct OpenAIConfig {
209    pub api_key: Option<String>,
210    pub api_base: Option<String>,
211    pub organization: Option<String>,
212}
213
214/// Llama.cpp-specific configuration
215#[derive(Debug, Clone, Serialize, Deserialize)]
216pub struct LlamaConfig {
217    pub base_url: String,
218    pub insecure: bool,
219}
220
221/// Retry policy configuration
222#[derive(Debug, Clone, Serialize, Deserialize)]
223pub struct RetryConfig {
224    /// Maximum retry attempts
225    #[serde(default = "default_max_attempts")]
226    pub max_attempts: u32,
227
228    /// Initial delay in milliseconds
229    #[serde(default = "default_initial_delay_ms")]
230    pub initial_delay_ms: u64,
231
232    /// Maximum delay in milliseconds
233    #[serde(default = "default_max_delay_ms")]
234    pub max_delay_ms: u64,
235
236    /// Backoff multiplier
237    #[serde(default = "default_backoff_multiplier")]
238    pub backoff_multiplier: f64,
239
240    /// Jitter factor (0.0 - 1.0)
241    #[serde(default = "default_jitter_factor")]
242    pub jitter_factor: f64,
243}
244
/// Serde default for `RetryConfig::max_attempts`.
fn default_max_attempts() -> u32 {
    3
}

/// Serde default for `RetryConfig::initial_delay_ms` (100 ms).
fn default_initial_delay_ms() -> u64 {
    100
}

/// Serde default for `RetryConfig::max_delay_ms` (30 s).
fn default_max_delay_ms() -> u64 {
    30_000
}

/// Serde default for `RetryConfig::backoff_multiplier`.
fn default_backoff_multiplier() -> f64 {
    2.0_f64
}

/// Serde default for `RetryConfig::jitter_factor`.
fn default_jitter_factor() -> f64 {
    0.1_f64
}
260
261impl Default for RetryConfig {
262    fn default() -> Self {
263        Self {
264            max_attempts: 3,
265            initial_delay_ms: 100,
266            max_delay_ms: 30000,
267            backoff_multiplier: 2.0,
268            jitter_factor: 0.1,
269        }
270    }
271}
272
273impl RetryConfig {
274    fn validate(&self) -> Result<(), ConfigError> {
275        if self.backoff_multiplier < 1.0 {
276            return Err(ConfigError {
277                code: ConfigErrorCode::InvalidValue,
278                message: "Backoff multiplier must be >= 1.0".to_string(),
279                field: Some("retry.backoff_multiplier".to_string()),
280            });
281        }
282
283        if self.jitter_factor < 0.0 || self.jitter_factor > 1.0 {
284            return Err(ConfigError {
285                code: ConfigErrorCode::InvalidValue,
286                message: "Jitter factor must be between 0.0 and 1.0".to_string(),
287                field: Some("retry.jitter_factor".to_string()),
288            });
289        }
290
291        Ok(())
292    }
293
294    /// Convert to RetryPolicy
295    pub fn to_policy(&self) -> RetryPolicy {
296        RetryPolicy {
297            max_attempts: self.max_attempts,
298            initial_delay: Duration::from_millis(self.initial_delay_ms),
299            max_delay: Duration::from_millis(self.max_delay_ms),
300            backoff_multiplier: self.backoff_multiplier,
301            jitter_factor: self.jitter_factor,
302            max_total_duration: None, // Can be added if needed
303        }
304    }
305}
306
307/// Timeout configuration settings
308#[derive(Debug, Clone, Serialize, Deserialize)]
309pub struct TimeoutConfigSettings {
310    /// Total timeout in milliseconds
311    pub total_ms: Option<u64>,
312
313    /// First response timeout in milliseconds
314    pub first_response_ms: Option<u64>,
315}
316
317impl Default for TimeoutConfigSettings {
318    fn default() -> Self {
319        Self {
320            total_ms: Some(300000),         // 5 minutes
321            first_response_ms: Some(30000), // 30 seconds
322        }
323    }
324}
325
326impl TimeoutConfigSettings {
327    fn validate(&self) -> Result<(), ConfigError> {
328        // Validation passes for now
329        Ok(())
330    }
331
332    /// Convert to TimeoutConfig
333    pub fn to_config(&self) -> TimeoutConfig {
334        TimeoutConfig {
335            total: self.total_ms.map(Duration::from_millis),
336            first_response: self.first_response_ms.map(Duration::from_millis),
337        }
338    }
339}
340
341/// Logging configuration
342#[derive(Debug, Clone, Serialize, Deserialize)]
343pub struct LoggingConfig {
344    /// Log level (trace, debug, info, warn, error)
345    #[serde(default = "default_log_level")]
346    pub level: String,
347
348    /// Log output directory
349    #[serde(default = "default_log_dir")]
350    pub directory: String,
351
352    /// Enable JSON format
353    #[serde(default)]
354    pub json_format: bool,
355}
356
/// Serde default for `LoggingConfig::level`.
fn default_log_level() -> String {
    String::from("info")
}

/// Serde default for `LoggingConfig::directory`.
fn default_log_dir() -> String {
    String::from("output")
}
364
365impl Default for LoggingConfig {
366    fn default() -> Self {
367        Self {
368            level: "info".to_string(),
369            directory: "output".to_string(),
370            json_format: false,
371        }
372    }
373}
374
375/// Workflow execution configuration
376#[derive(Debug, Clone, Serialize, Deserialize)]
377pub struct WorkflowConfig {
378    /// Maximum concurrent workflows
379    pub max_concurrent: Option<usize>,
380
381    /// Maximum tool iterations per agent
382    #[serde(default = "default_max_tool_iterations")]
383    pub max_tool_iterations: u32,
384}
385
/// Serde default for `WorkflowConfig::max_tool_iterations`.
fn default_max_tool_iterations() -> u32 {
    5_u32
}
389
390impl Default for WorkflowConfig {
391    fn default() -> Self {
392        Self {
393            max_concurrent: None,
394            max_tool_iterations: 5,
395        }
396    }
397}
398
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_config() {
        let RuntimeConfig {
            retry,
            logging,
            workflow,
            ..
        } = RuntimeConfig::default();

        assert_eq!(retry.max_attempts, 3);
        assert_eq!(logging.level, "info");
        assert_eq!(workflow.max_tool_iterations, 5);
    }

    #[test]
    fn test_toml_serialization() {
        let rendered = toml::to_string(&RuntimeConfig::default()).unwrap();
        assert!(rendered.contains("max_attempts"));
    }

    #[test]
    fn test_toml_deserialization() {
        let input = r#"
            [retry]
            max_attempts = 5
            initial_delay_ms = 200

            [logging]
            level = "debug"
            directory = "logs"
        "#;

        let parsed: RuntimeConfig = toml::from_str(input).unwrap();
        assert_eq!(parsed.retry.max_attempts, 5);
        assert_eq!(parsed.logging.level, "debug");
    }

    #[test]
    fn test_yaml_serialization() {
        let rendered = serde_yaml::to_string(&RuntimeConfig::default()).unwrap();
        assert!(rendered.contains("max_attempts"));
    }

    #[test]
    fn test_yaml_deserialization() {
        let input = r#"
retry:
  max_attempts: 5
  initial_delay_ms: 200

logging:
  level: debug
  directory: logs
        "#;

        let parsed: RuntimeConfig = serde_yaml::from_str(input).unwrap();
        assert_eq!(parsed.retry.max_attempts, 5);
        assert_eq!(parsed.logging.level, "debug");
    }

    #[test]
    fn test_retry_config_to_policy() {
        let policy = RetryConfig {
            max_attempts: 5,
            initial_delay_ms: 200,
            max_delay_ms: 60000,
            backoff_multiplier: 2.0,
            jitter_factor: 0.2,
        }
        .to_policy();

        assert_eq!(policy.max_attempts, 5);
        assert_eq!(policy.initial_delay, Duration::from_millis(200));
        assert_eq!(policy.backoff_multiplier, 2.0);
    }

    #[test]
    fn test_validation_invalid_temperature() {
        let too_hot = LlmConfig {
            default_temperature: 3.0,
            ..Default::default()
        };

        assert!(too_hot.validate().is_err());
    }

    #[test]
    fn test_validation_invalid_jitter() {
        let too_jittery = RetryConfig {
            jitter_factor: 1.5,
            ..Default::default()
        };

        assert!(too_jittery.validate().is_err());
    }

    #[test]
    fn test_timeout_config_conversion() {
        let converted = TimeoutConfigSettings {
            total_ms: Some(5000),
            first_response_ms: Some(1000),
        }
        .to_config();

        assert_eq!(converted.total, Some(Duration::from_millis(5000)));
        assert_eq!(converted.first_response, Some(Duration::from_millis(1000)));
    }
}