1use std::path::Path;
6use std::time::Duration;
7
8use serde::{Deserialize, Serialize};
9
10use crate::error::{EvalError, Result};
11
12#[derive(Debug, Clone, Default, Serialize, Deserialize)]
14pub struct EvalConfig {
15 #[serde(default)]
17 pub orchestrator: OrchestratorSettings,
18
19 #[serde(default)]
21 pub eval: EvalSettings,
22
23 #[serde(default)]
25 pub assertions: Vec<AssertionConfig>,
26
27 #[serde(default)]
29 pub faults: Vec<FaultConfig>,
30}
31
32#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct OrchestratorSettings {
35 #[serde(default = "default_tick_duration_ms")]
37 pub tick_duration_ms: u64,
38
39 #[serde(default = "default_max_ticks")]
41 pub max_ticks: u64,
42
43 #[serde(default)]
45 pub dependency_provider: DependencyProviderKind,
46}
47
/// Selects the dependency provider configured via
/// `OrchestratorSettings::dependency_provider`.
///
/// Written in TOML as a snake_case string (`"learned"` or `"smart"`);
/// `Smart` is the default.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum DependencyProviderKind {
    /// The `"learned"` provider.
    /// NOTE(review): its behavior is implemented elsewhere — document there.
    Learned,

    /// The `"smart"` provider; used when no provider is configured.
    /// NOTE(review): its behavior is implemented elsewhere — document there.
    #[default]
    Smart,
}
69
/// Default value for `OrchestratorSettings::tick_duration_ms`: 10 ms.
const fn default_tick_duration_ms() -> u64 {
    10
}

/// Default value for `OrchestratorSettings::max_ticks`: 1000 ticks.
const fn default_max_ticks() -> u64 {
    1_000
}
77
78impl Default for OrchestratorSettings {
79 fn default() -> Self {
80 Self {
81 tick_duration_ms: default_tick_duration_ms(),
82 max_ticks: default_max_ticks(),
83 dependency_provider: DependencyProviderKind::default(),
84 }
85 }
86}
87
88impl EvalConfig {
89 pub fn from_toml_file(path: impl AsRef<Path>) -> Result<Self> {
91 let content = std::fs::read_to_string(path)?;
92 Self::from_toml_str(&content)
93 }
94
95 pub fn from_toml_str(content: &str) -> Result<Self> {
97 let config: EvalConfig = toml::from_str(content)?;
98 config.validate()?;
99 Ok(config)
100 }
101
102 fn validate(&self) -> Result<()> {
104 if self.eval.runs == 0 {
105 return Err(EvalError::Config("runs must be > 0".to_string()));
106 }
107 Ok(())
108 }
109}
110
111#[derive(Debug, Clone, Serialize, Deserialize)]
113pub struct EvalSettings {
114 #[serde(default = "default_runs")]
116 pub runs: usize,
117
118 pub base_seed: Option<u64>,
120
121 #[serde(default = "default_true")]
123 pub record_seeds: bool,
124
125 #[serde(default = "default_parallel")]
127 pub parallel: usize,
128
129 #[serde(default)]
131 pub target_tick_duration_ms: Option<u64>,
132}
133
/// Default value for `EvalSettings::runs`: 30 runs.
const fn default_runs() -> usize {
    30
}

/// Default for boolean settings that are on unless configured otherwise
/// (used for `EvalSettings::record_seeds`).
const fn default_true() -> bool {
    true
}

/// Default value for `EvalSettings::parallel`: 1.
const fn default_parallel() -> usize {
    1
}
145
146impl Default for EvalSettings {
147 fn default() -> Self {
148 Self {
149 runs: default_runs(),
150 base_seed: None,
151 record_seeds: true,
152 parallel: default_parallel(),
153 target_tick_duration_ms: None,
154 }
155 }
156}
157
158impl EvalSettings {
159 pub fn target_tick_duration(&self) -> Option<Duration> {
161 self.target_tick_duration_ms.map(Duration::from_millis)
162 }
163}
164
/// One pass/fail check, read from an `[[assertions]]` table entry.
///
/// NOTE(review): presumably evaluated as
/// `op.check(<observed value of metric>, expected)` — confirm at the call site.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AssertionConfig {
    /// Identifier for this assertion (e.g. `"success_rate_threshold"`).
    pub name: String,

    /// Name of the metric to check (e.g. `"success_rate"`).
    /// NOTE(review): the set of valid metric names is defined elsewhere.
    pub metric: String,

    /// Comparison operator, written in snake_case (`"gt"`, `"gte"`, …).
    pub op: ComparisonOp,

    /// Value the metric is compared against.
    pub expected: f64,
}
180
/// Comparison applied by [`ComparisonOp::check`] between an actual and an
/// expected value. Written in TOML as `"gt"`, `"gte"`, `"lt"`, `"lte"` or `"eq"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ComparisonOp {
    /// `actual > expected` (strict, no tolerance).
    Gt,
    /// `actual >= expected`, allowing `actual` to fall short by up to 1e-9.
    Gte,
    /// `actual < expected` (strict, no tolerance).
    Lt,
    /// `actual <= expected`, allowing `actual` to exceed by up to 1e-9.
    Lte,
    /// `|actual - expected| < 1e-9`.
    Eq,
}
196
197impl ComparisonOp {
198 pub fn check(&self, actual: f64, expected: f64) -> bool {
200 const EPSILON: f64 = 1e-9;
201 match self {
202 ComparisonOp::Gt => actual > expected,
203 ComparisonOp::Gte => actual >= expected - EPSILON,
204 ComparisonOp::Lt => actual < expected,
205 ComparisonOp::Lte => actual <= expected + EPSILON,
206 ComparisonOp::Eq => (actual - expected).abs() < EPSILON,
207 }
208 }
209}
210
/// One fault-injection rule, read from a `[[faults]]` table entry.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FaultConfig {
    /// The kind of fault to inject; see [`FaultType`].
    pub fault_type: FaultType,

    /// Optional `(start, end)` tick window, written in TOML as a two-element
    /// array such as `[10, 50]`. `None` when omitted.
    /// NOTE(review): whether `end` is inclusive is decided by the consumer — confirm.
    #[serde(default)]
    pub tick_range: Option<(u64, u64)>,

    /// Probability that the fault fires; defaults to `1.0` when omitted.
    /// NOTE(review): presumably expected to lie in [0, 1] — confirm.
    #[serde(default = "default_probability")]
    pub probability: f64,

    /// Optional fault duration in ticks; `None` when omitted.
    #[serde(default)]
    pub duration_ticks: Option<u64>,

    /// Optional list of workers (by `usize` index) to target; `None` when omitted.
    /// NOTE(review): presumably `None` means "all workers" — confirm at the consumer.
    #[serde(default)]
    pub target_workers: Option<Vec<usize>>,
}
233
/// Default value for `FaultConfig::probability`: `1.0`.
const fn default_probability() -> f64 {
    1.0
}
237
/// The kind of fault a [`FaultConfig`] injects.
///
/// Internally tagged on the `type` key with snake_case variant names, e.g.
/// `fault_type = { type = "delay_injection", delay_ms = 100 }`. The other
/// tags are `"worker_skip"`, `"guidance_override"` and `"action_failure"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum FaultType {
    /// Delay-injection fault (`type = "delay_injection"`).
    DelayInjection {
        /// Delay length in milliseconds.
        delay_ms: u64,
    },

    /// Worker-skip fault (`type = "worker_skip"`); carries no parameters.
    /// NOTE(review): exact semantics are defined by the fault consumer.
    WorkerSkip,

    /// Guidance-override fault (`type = "guidance_override"`).
    GuidanceOverride {
        /// Goal text used for the override.
        goal: String,
    },

    /// Action-failure fault (`type = "action_failure"`); carries no parameters.
    /// NOTE(review): exact semantics are defined by the fault consumer.
    ActionFailure,
}
260
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses and validates `src`, panicking if either step fails.
    fn parse(src: &str) -> EvalConfig {
        EvalConfig::from_toml_str(src).expect("config should parse and validate")
    }

    #[test]
    fn test_default_config() {
        let EvalSettings {
            runs,
            record_seeds,
            parallel,
            ..
        } = EvalConfig::default().eval;

        assert_eq!(runs, 30);
        assert!(record_seeds);
        assert_eq!(parallel, 1);
    }

    #[test]
    fn test_parse_minimal_toml() {
        let src = r#"
[eval]
runs = 10
"#;
        assert_eq!(parse(src).eval.runs, 10);
    }

    #[test]
    fn test_parse_with_assertions() {
        let src = r#"
[eval]
runs = 30

[[assertions]]
name = "success_rate_threshold"
metric = "success_rate"
op = "gte"
expected = 0.8
"#;
        let cfg = parse(src);
        assert_eq!(cfg.assertions.len(), 1);

        let first = &cfg.assertions[0];
        assert_eq!(first.name, "success_rate_threshold");
        assert_eq!(first.op, ComparisonOp::Gte);
    }

    #[test]
    fn test_parse_with_faults() {
        let src = r#"
[eval]
runs = 10

[[faults]]
fault_type = { type = "delay_injection", delay_ms = 100 }
tick_range = [10, 50]
probability = 0.1
"#;
        let cfg = parse(src);
        assert_eq!(cfg.faults.len(), 1);
        assert_eq!(cfg.faults[0].tick_range, Some((10, 50)));
    }

    #[test]
    fn test_comparison_op() {
        // (operator, actual, expected, expected result)
        let cases = [
            (ComparisonOp::Gt, 0.9, 0.8, true),
            (ComparisonOp::Gt, 0.8, 0.8, false),
            (ComparisonOp::Gte, 0.8, 0.8, true),
            (ComparisonOp::Gte, 0.9, 0.8, true),
            (ComparisonOp::Lt, 0.7, 0.8, true),
            (ComparisonOp::Lt, 0.8, 0.8, false),
            (ComparisonOp::Lte, 0.8, 0.8, true),
            (ComparisonOp::Lte, 0.7, 0.8, true),
            (ComparisonOp::Eq, 0.8, 0.8, true),
            (ComparisonOp::Eq, 0.81, 0.8, false),
        ];

        for &(op, actual, expected, want) in cases.iter() {
            assert_eq!(
                op.check(actual, expected),
                want,
                "{op:?}.check({actual}, {expected})"
            );
        }
    }

    #[test]
    fn test_invalid_runs() {
        let src = r#"
[eval]
runs = 0
"#;
        assert!(EvalConfig::from_toml_str(src).is_err());
    }
}