use std::path::Path;
use std::time::Duration;
use serde::{Deserialize, Serialize};
use crate::error::{EvalError, Result};
/// Root configuration for an evaluation, loadable from TOML (see `from_toml_str`).
///
/// Every section carries `#[serde(default)]`, so any of them may be omitted
/// from the input and then falls back to its `Default` value.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct EvalConfig {
/// Orchestrator settings (tick duration, tick limit, dependency provider).
#[serde(default)]
pub orchestrator: OrchestratorSettings,
/// Run-level settings such as the number of runs and seeding.
#[serde(default)]
pub eval: EvalSettings,
/// Metric assertions; an empty list when the key is absent.
#[serde(default)]
pub assertions: Vec<AssertionConfig>,
/// Fault descriptions; an empty list when the key is absent.
#[serde(default)]
pub faults: Vec<FaultConfig>,
}
/// Settings for the `orchestrator` section of the configuration.
///
/// Each missing key falls back to the per-field serde default named below.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OrchestratorSettings {
/// Tick duration; defaults to `default_tick_duration_ms()` (10).
/// Presumably milliseconds, per the `_ms` suffix — confirm against the consumer.
#[serde(default = "default_tick_duration_ms")]
pub tick_duration_ms: u64,
/// Tick limit; defaults to `default_max_ticks()` (1000).
/// NOTE(review): how the limit is enforced is not visible in this file.
#[serde(default = "default_max_ticks")]
pub max_ticks: u64,
/// Which dependency provider to use; defaults to `DependencyProviderKind::default()`.
#[serde(default)]
pub dependency_provider: DependencyProviderKind,
}
/// Kind of dependency provider selected by `OrchestratorSettings::dependency_provider`.
///
/// Variants are (de)serialized in snake_case, i.e. `"learned"` and `"smart"`.
/// `Smart` is the `Default` variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum DependencyProviderKind {
/// Serialized as `"learned"`.
Learned,
/// Serialized as `"smart"`; used when the key is omitted.
#[default]
Smart,
}
/// Serde fallback for `OrchestratorSettings::tick_duration_ms` when the key is absent.
fn default_tick_duration_ms() -> u64 {
    const FALLBACK_TICK_DURATION_MS: u64 = 10;
    FALLBACK_TICK_DURATION_MS
}
/// Serde fallback for `OrchestratorSettings::max_ticks` when the key is absent.
fn default_max_ticks() -> u64 {
    const FALLBACK_MAX_TICKS: u64 = 1000;
    FALLBACK_MAX_TICKS
}
impl Default for OrchestratorSettings {
fn default() -> Self {
Self {
tick_duration_ms: default_tick_duration_ms(),
max_ticks: default_max_ticks(),
dependency_provider: DependencyProviderKind::default(),
}
}
}
impl EvalConfig {
    /// Reads the file at `path` and parses its contents as a TOML configuration.
    ///
    /// # Errors
    ///
    /// Propagates the error from reading the file, and any error returned by
    /// [`EvalConfig::from_toml_str`].
    pub fn from_toml_file(path: impl AsRef<Path>) -> Result<Self> {
        let content = std::fs::read_to_string(path)?;
        Self::from_toml_str(&content)
    }

    /// Parses `content` as TOML and validates the resulting configuration.
    ///
    /// # Errors
    ///
    /// Returns an error when the TOML cannot be deserialized into an
    /// `EvalConfig`, or an `EvalError::Config` when validation fails
    /// (`eval.runs == 0`, or a fault probability outside `[0.0, 1.0]`).
    pub fn from_toml_str(content: &str) -> Result<Self> {
        let config: EvalConfig = toml::from_str(content)?;
        config.validate()?;
        Ok(config)
    }

    /// Rejects configurations whose values cannot be meaningful.
    fn validate(&self) -> Result<()> {
        if self.eval.runs == 0 {
            return Err(EvalError::Config("runs must be > 0".to_string()));
        }

        for (index, fault) in self.faults.iter().enumerate() {
            let probability = fault.probability;
            // `contains` is false for NaN, so NaN is rejected together with
            // values below 0.0 or above 1.0.
            if !(0.0..=1.0).contains(&probability) {
                return Err(EvalError::Config(format!(
                    "faults[{index}].probability must be within [0.0, 1.0], got {probability}"
                )));
            }
        }

        Ok(())
    }
}
/// Settings for the `eval` section of the configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EvalSettings {
/// Number of runs; defaults to `default_runs()` (30).
/// A value of 0 is rejected by `EvalConfig::validate` when parsing TOML.
#[serde(default = "default_runs")]
pub runs: usize,
/// Optional base seed.
/// NOTE(review): presumably per-run seeds are derived from it — confirm with the consumer.
pub base_seed: Option<u64>,
/// Defaults to `true` via `default_true()`.
/// NOTE(review): presumably controls whether the seeds used are recorded — confirm.
#[serde(default = "default_true")]
pub record_seeds: bool,
/// Defaults to `default_parallel()` (1).
/// NOTE(review): presumably the number of runs executed concurrently — confirm.
#[serde(default = "default_parallel")]
pub parallel: usize,
/// Optional target tick duration in milliseconds; exposed as a `Duration`
/// by `EvalSettings::target_tick_duration`.
#[serde(default)]
pub target_tick_duration_ms: Option<u64>,
}
/// Serde fallback for `EvalSettings::runs` when the key is absent.
fn default_runs() -> usize {
    const FALLBACK_RUNS: usize = 30;
    FALLBACK_RUNS
}
/// Serde fallback for boolean fields that should be enabled unless the
/// configuration says otherwise (used by `EvalSettings::record_seeds`).
fn default_true() -> bool {
    const ENABLED: bool = true;
    ENABLED
}
/// Serde fallback for `EvalSettings::parallel` when the key is absent.
fn default_parallel() -> usize {
    const FALLBACK_PARALLEL: usize = 1;
    FALLBACK_PARALLEL
}
impl Default for EvalSettings {
fn default() -> Self {
Self {
runs: default_runs(),
base_seed: None,
record_seeds: true,
parallel: default_parallel(),
target_tick_duration_ms: None,
}
}
}
impl EvalSettings {
    /// Returns `target_tick_duration_ms` as a [`Duration`], or `None` when
    /// no target is configured.
    pub fn target_tick_duration(&self) -> Option<Duration> {
        let millis = self.target_tick_duration_ms?;
        Some(Duration::from_millis(millis))
    }
}
/// One entry of the `[[assertions]]` array: a named comparison of a metric
/// against an expected value.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AssertionConfig {
/// Name of the assertion (e.g. `"success_rate_threshold"`).
pub name: String,
/// Name of the metric being compared (e.g. `"success_rate"`).
/// NOTE(review): the set of valid metric names is not defined in this file.
pub metric: String,
/// Comparison operator applied to the metric.
pub op: ComparisonOp,
/// Value the metric is compared against.
pub expected: f64,
}
/// Comparison operator used by assertions; evaluated by `ComparisonOp::check`.
///
/// Variants are (de)serialized in snake_case: `"gt"`, `"gte"`, `"lt"`,
/// `"lte"`, `"eq"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ComparisonOp {
/// Greater than.
Gt,
/// Greater than or equal.
Gte,
/// Less than.
Lt,
/// Less than or equal.
Lte,
/// Equal.
Eq,
}
impl ComparisonOp {
    /// Evaluates `actual <op> expected` with an absolute tolerance of `1e-9`.
    ///
    /// Values closer than the tolerance are treated as equal by every
    /// operator: `Gte`, `Lte` and `Eq` accept them, while the strict
    /// operators `Gt` and `Lt` reject them. This keeps `Gt` the complement
    /// of `Lte` (and `Lt` of `Gte`), so rounding noise such as
    /// `0.1 + 0.2` vs `0.3` cannot satisfy both `Eq` and `Gt`.
    ///
    /// Returns `false` for every operator when either argument is NaN.
    pub fn check(&self, actual: f64, expected: f64) -> bool {
        const EPSILON: f64 = 1e-9;
        match self {
            ComparisonOp::Gt => actual > expected + EPSILON,
            ComparisonOp::Gte => actual >= expected - EPSILON,
            ComparisonOp::Lt => actual < expected - EPSILON,
            ComparisonOp::Lte => actual <= expected + EPSILON,
            // The exact comparison covers equal infinities, whose difference
            // is NaN and would otherwise never pass the tolerance test.
            ComparisonOp::Eq => actual == expected || (actual - expected).abs() < EPSILON,
        }
    }
}
/// One entry of the `[[faults]]` array describing a fault to inject.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FaultConfig {
/// The kind of fault and its variant-specific parameters.
pub fault_type: FaultType,
/// Optional pair of ticks, written in TOML as a two-element array
/// (e.g. `[10, 50]`).
/// NOTE(review): whether the bounds are inclusive is not visible here — confirm.
#[serde(default)]
pub tick_range: Option<(u64, u64)>,
/// Probability associated with the fault; defaults to `default_probability()` (1.0).
#[serde(default = "default_probability")]
pub probability: f64,
/// Optional duration expressed in ticks.
#[serde(default)]
pub duration_ticks: Option<u64>,
/// Optional list of worker indices.
/// NOTE(review): presumably `None` means all workers — confirm with the consumer.
#[serde(default)]
pub target_workers: Option<Vec<usize>>,
}
/// Serde fallback for `FaultConfig::probability` when the key is absent.
fn default_probability() -> f64 {
    const FALLBACK_PROBABILITY: f64 = 1.0;
    FALLBACK_PROBABILITY
}
/// Kind of fault, internally tagged by a `type` key whose value is the
/// snake_case variant name, e.g. `{ type = "delay_injection", delay_ms = 100 }`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum FaultType {
/// Tagged as `"delay_injection"`.
DelayInjection {
/// Delay amount; presumably milliseconds, per the `_ms` suffix — confirm.
delay_ms: u64,
},
/// Tagged as `"worker_skip"`; carries no parameters.
WorkerSkip,
/// Tagged as `"guidance_override"`.
GuidanceOverride {
/// Goal string supplied with the override.
goal: String,
},
/// Tagged as `"action_failure"`; carries no parameters.
ActionFailure,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `EvalConfig::default()` exposes the fallback run settings.
    #[test]
    fn test_default_config() {
        let EvalConfig { eval, .. } = EvalConfig::default();
        assert_eq!(eval.runs, 30);
        assert!(eval.record_seeds);
        assert_eq!(eval.parallel, 1);
    }

    /// A config with only `[eval].runs` parses and keeps the given value.
    #[test]
    fn test_parse_minimal_toml() {
        let source = r#"
[eval]
runs = 10
"#;
        let parsed = EvalConfig::from_toml_str(source).unwrap();
        assert_eq!(parsed.eval.runs, 10);
    }

    /// An `[[assertions]]` entry is parsed with its name and operator.
    #[test]
    fn test_parse_with_assertions() {
        let source = r#"
[eval]
runs = 30
[[assertions]]
name = "success_rate_threshold"
metric = "success_rate"
op = "gte"
expected = 0.8
"#;
        let parsed = EvalConfig::from_toml_str(source).unwrap();
        assert_eq!(parsed.assertions.len(), 1);
        let first = &parsed.assertions[0];
        assert_eq!(first.name, "success_rate_threshold");
        assert_eq!(first.op, ComparisonOp::Gte);
    }

    /// A `[[faults]]` entry is parsed, including its tick range.
    #[test]
    fn test_parse_with_faults() {
        let source = r#"
[eval]
runs = 10
[[faults]]
fault_type = { type = "delay_injection", delay_ms = 100 }
tick_range = [10, 50]
probability = 0.1
"#;
        let parsed = EvalConfig::from_toml_str(source).unwrap();
        assert_eq!(parsed.faults.len(), 1);
        let first = &parsed.faults[0];
        assert_eq!(first.tick_range, Some((10, 50)));
    }

    /// Each operator accepts and rejects the expected value pairs.
    #[test]
    fn test_comparison_op() {
        // (operator, actual, expected, result of `check`)
        let cases = [
            (ComparisonOp::Gt, 0.9, 0.8, true),
            (ComparisonOp::Gt, 0.8, 0.8, false),
            (ComparisonOp::Gte, 0.8, 0.8, true),
            (ComparisonOp::Gte, 0.9, 0.8, true),
            (ComparisonOp::Lt, 0.7, 0.8, true),
            (ComparisonOp::Lt, 0.8, 0.8, false),
            (ComparisonOp::Lte, 0.8, 0.8, true),
            (ComparisonOp::Lte, 0.7, 0.8, true),
            (ComparisonOp::Eq, 0.8, 0.8, true),
            (ComparisonOp::Eq, 0.81, 0.8, false),
        ];
        for (op, actual, expected, want) in cases {
            assert_eq!(
                op.check(actual, expected),
                want,
                "{op:?}.check({actual}, {expected})"
            );
        }
    }

    /// `runs = 0` is rejected by validation.
    #[test]
    fn test_invalid_runs() {
        let source = r#"
[eval]
runs = 0
"#;
        assert!(EvalConfig::from_toml_str(source).is_err());
    }
}