swarm-engine-eval 0.1.6

Evaluation framework for SwarmEngine
Documentation
//! Evaluation configuration
//!
//! Loads the evaluation settings from a TOML configuration file.

use std::path::Path;
use std::time::Duration;

use serde::{Deserialize, Serialize};

use crate::error::{EvalError, Result};

/// Top-level evaluation configuration.
///
/// Every section carries `#[serde(default)]`, so any of them may be omitted
/// from the input; a missing section falls back to its type's `Default`
/// (empty `Vec` for the list sections).
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct EvalConfig {
    /// Orchestrator settings (tick duration, tick limit, dependency provider).
    #[serde(default)]
    pub orchestrator: OrchestratorSettings,

    /// Evaluation-specific settings (run count, seeding, parallelism).
    #[serde(default)]
    pub eval: EvalSettings,

    /// Assertions to verify; empty when the section is omitted.
    #[serde(default)]
    pub assertions: Vec<AssertionConfig>,

    /// Fault injection configurations; empty when the section is omitted.
    #[serde(default)]
    pub faults: Vec<FaultConfig>,
}

/// Orchestrator settings.
///
/// Each field has a serde default, so a partially specified section is
/// accepted and the missing fields are filled in.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OrchestratorSettings {
    /// Tick duration in milliseconds (defaults to `default_tick_duration_ms()`).
    #[serde(default = "default_tick_duration_ms")]
    pub tick_duration_ms: u64,

    /// Maximum number of ticks (defaults to `default_max_ticks()`).
    #[serde(default = "default_max_ticks")]
    pub max_ticks: u64,

    /// Kind of DependencyGraph provider (defaults to
    /// `DependencyProviderKind::default()`).
    #[serde(default)]
    pub dependency_provider: DependencyProviderKind,
}

/// Kind of DependencyGraph provider.
///
/// Selects the strategy used to provide a graph from learned action orderings.
///
/// Note: `Smart` and `Learned` have been merged; both use `LearnedDependencyProvider`.
/// Both values are still accepted for backward compatibility, but they behave identically.
///
/// Variants are (de)serialized in snake_case, i.e. `"learned"` and `"smart"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum DependencyProviderKind {
    /// LearnedDependencyProvider
    ///
    /// On a 100% match the learned graph is used.
    /// On a partial match the voting strategy is decided via `select()`.
    Learned,

    /// LearnedDependencyProvider (`Smart` has been merged into `Learned`).
    ///
    /// Alias kept for backward compatibility; behaves identically to `Learned`.
    /// This is the variant returned by `Default`.
    #[default]
    Smart,
}

/// Serde fallback for `OrchestratorSettings::tick_duration_ms`, in milliseconds.
fn default_tick_duration_ms() -> u64 {
    const FALLBACK_TICK_MS: u64 = 10;
    FALLBACK_TICK_MS
}

/// Serde fallback for `OrchestratorSettings::max_ticks`.
fn default_max_ticks() -> u64 {
    const FALLBACK_MAX_TICKS: u64 = 1000;
    FALLBACK_MAX_TICKS
}

/// Builds the same values serde substitutes for omitted fields, so a
/// defaulted struct and an empty `[orchestrator]` section agree.
impl Default for OrchestratorSettings {
    fn default() -> Self {
        let tick_duration_ms = default_tick_duration_ms();
        let max_ticks = default_max_ticks();
        let dependency_provider = DependencyProviderKind::default();

        Self {
            tick_duration_ms,
            max_ticks,
            dependency_provider,
        }
    }
}

impl EvalConfig {
    /// Load a configuration from a TOML file.
    ///
    /// Reads the whole file into memory and delegates to
    /// [`EvalConfig::from_toml_str`].
    ///
    /// # Errors
    ///
    /// Propagates the I/O error if the file cannot be read, and any error
    /// produced by [`EvalConfig::from_toml_str`].
    pub fn from_toml_file(path: impl AsRef<Path>) -> Result<Self> {
        let content = std::fs::read_to_string(path)?;
        Self::from_toml_str(&content)
    }

    /// Parse a configuration from a TOML string and validate it.
    ///
    /// # Errors
    ///
    /// Propagates the TOML deserialization error, or returns
    /// `EvalError::Config` when the parsed values fail validation
    /// (`eval.runs == 0`, or a fault probability outside `0.0..=1.0`).
    pub fn from_toml_str(content: &str) -> Result<Self> {
        let config: EvalConfig = toml::from_str(content)?;
        config.validate()?;
        Ok(config)
    }

    /// Validate the parsed configuration.
    ///
    /// Checks performed:
    /// - `eval.runs` must be greater than zero.
    /// - every `faults[i].probability` must lie within `0.0..=1.0`, the range
    ///   documented on the field; NaN is rejected as well.
    fn validate(&self) -> Result<()> {
        if self.eval.runs == 0 {
            return Err(EvalError::Config("runs must be > 0".to_string()));
        }

        for (index, fault) in self.faults.iter().enumerate() {
            // `RangeInclusive::contains` is false for NaN, so NaN is rejected
            // together with values below 0.0 or above 1.0.
            if !(0.0..=1.0).contains(&fault.probability) {
                return Err(EvalError::Config(format!(
                    "faults[{}].probability must be within 0.0..=1.0, got {}",
                    index, fault.probability
                )));
            }
        }

        Ok(())
    }
}

/// Evaluation-specific settings.
///
/// All fields may be omitted from the input: those with an explicit serde
/// default use the named helper, and the `Option` fields become `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EvalSettings {
    /// Number of runs for statistical analysis (defaults to `default_runs()`).
    #[serde(default = "default_runs")]
    pub runs: usize,

    /// Base seed for reproducibility (None = use current time)
    pub base_seed: Option<u64>,

    /// Record seeds in report (defaults to `default_true()`).
    #[serde(default = "default_true")]
    pub record_seeds: bool,

    /// Parallel execution (number of concurrent runs); defaults to
    /// `default_parallel()`.
    #[serde(default = "default_parallel")]
    pub parallel: usize,

    /// Target tick duration for miss rate calculation, in milliseconds.
    /// `None` when not configured.
    #[serde(default)]
    pub target_tick_duration_ms: Option<u64>,
}

/// Serde fallback for `EvalSettings::runs`.
fn default_runs() -> usize {
    const FALLBACK_RUNS: usize = 30;
    FALLBACK_RUNS
}

/// Serde fallback for boolean fields that are enabled unless stated otherwise.
fn default_true() -> bool {
    const ENABLED: bool = true;
    ENABLED
}

/// Serde fallback for `EvalSettings::parallel`.
fn default_parallel() -> usize {
    const FALLBACK_PARALLEL: usize = 1;
    FALLBACK_PARALLEL
}

impl Default for EvalSettings {
    fn default() -> Self {
        Self {
            runs: default_runs(),
            base_seed: None,
            record_seeds: true,
            parallel: default_parallel(),
            target_tick_duration_ms: None,
        }
    }
}

impl EvalSettings {
    /// Returns the configured target tick duration as a [`Duration`].
    ///
    /// Yields `None` when `target_tick_duration_ms` is not set; otherwise the
    /// stored millisecond count is converted with `Duration::from_millis`.
    pub fn target_tick_duration(&self) -> Option<Duration> {
        let millis = self.target_tick_duration_ms?;
        Some(Duration::from_millis(millis))
    }
}

/// Assertion configuration.
///
/// None of the fields has a serde default, so every one of them must be
/// present in each `[[assertions]]` entry.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AssertionConfig {
    /// Assertion name
    pub name: String,

    /// Name of the metric to check
    pub metric: String,

    /// Comparison operator applied between the metric value and `expected`
    pub op: ComparisonOp,

    /// Expected value the metric is compared against
    pub expected: f64,
}

/// Comparison operator.
///
/// Variants are (de)serialized in snake_case: `"gt"`, `"gte"`, `"lt"`,
/// `"lte"` and `"eq"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ComparisonOp {
    /// Greater than
    Gt,
    /// Greater than or equal
    Gte,
    /// Less than
    Lt,
    /// Less than or equal
    Lte,
    /// Equal (within epsilon)
    Eq,
}

impl ComparisonOp {
    /// Check if actual value satisfies the comparison
    pub fn check(&self, actual: f64, expected: f64) -> bool {
        const EPSILON: f64 = 1e-9;
        match self {
            ComparisonOp::Gt => actual > expected,
            ComparisonOp::Gte => actual >= expected - EPSILON,
            ComparisonOp::Lt => actual < expected,
            ComparisonOp::Lte => actual <= expected + EPSILON,
            ComparisonOp::Eq => (actual - expected).abs() < EPSILON,
        }
    }
}

/// Fault injection configuration.
///
/// Only `fault_type` is required; every other field has a serde default.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FaultConfig {
    /// Fault type
    pub fault_type: FaultType,

    /// Tick range for fault injection (start, end); `None` when omitted.
    // NOTE(review): whether the bounds are inclusive is decided by the
    // consumer of this config, not here — confirm before relying on it.
    #[serde(default)]
    pub tick_range: Option<(u64, u64)>,

    /// Probability of fault occurrence (0.0 - 1.0); defaults to
    /// `default_probability()`.
    #[serde(default = "default_probability")]
    pub probability: f64,

    /// Duration in ticks (for delay injection); `None` when omitted.
    #[serde(default)]
    pub duration_ticks: Option<u64>,

    /// Target workers (None = all workers)
    #[serde(default)]
    pub target_workers: Option<Vec<usize>>,
}

/// Serde fallback for `FaultConfig::probability`.
fn default_probability() -> f64 {
    const FALLBACK_PROBABILITY: f64 = 1.0;
    FALLBACK_PROBABILITY
}

/// Fault type.
///
/// Serialized as an internally tagged enum: the variant is chosen by a
/// `type` key holding the snake_case variant name (for example
/// `{ type = "delay_injection", delay_ms = 100 }`), with the variant's own
/// fields alongside it.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum FaultType {
    /// Inject delay into tick processing
    DelayInjection {
        /// Delay in milliseconds
        delay_ms: u64,
    },

    /// Skip worker execution
    WorkerSkip,

    /// Override worker guidance
    GuidanceOverride {
        /// Goal to inject
        goal: String,
    },

    /// Cause action to fail
    ActionFailure,
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `input` as an `EvalConfig`, panicking if it is rejected.
    fn parse(input: &str) -> EvalConfig {
        EvalConfig::from_toml_str(input).unwrap()
    }

    /// The built-in defaults match the documented values.
    #[test]
    fn test_default_config() {
        let EvalSettings {
            runs,
            record_seeds,
            parallel,
            ..
        } = EvalConfig::default().eval;

        assert_eq!(runs, 30);
        assert!(record_seeds);
        assert_eq!(parallel, 1);
    }

    /// A config containing only `[eval] runs` is accepted.
    #[test]
    fn test_parse_minimal_toml() {
        let parsed = parse(
            r#"
[eval]
runs = 10
"#,
        );
        assert_eq!(parsed.eval.runs, 10);
    }

    /// `[[assertions]]` entries are deserialized, including the snake_case operator.
    #[test]
    fn test_parse_with_assertions() {
        let parsed = parse(
            r#"
[eval]
runs = 30

[[assertions]]
name = "success_rate_threshold"
metric = "success_rate"
op = "gte"
expected = 0.8
"#,
        );

        assert_eq!(parsed.assertions.len(), 1);
        let first = &parsed.assertions[0];
        assert_eq!(first.name, "success_rate_threshold");
        assert_eq!(first.op, ComparisonOp::Gte);
    }

    /// `[[faults]]` entries are deserialized, including the tagged fault type
    /// and the two-element tick range.
    #[test]
    fn test_parse_with_faults() {
        let parsed = parse(
            r#"
[eval]
runs = 10

[[faults]]
fault_type = { type = "delay_injection", delay_ms = 100 }
tick_range = [10, 50]
probability = 0.1
"#,
        );

        assert_eq!(parsed.faults.len(), 1);
        assert_eq!(parsed.faults[0].tick_range, Some((10, 50)));
    }

    /// Each operator is exercised at and around the boundary value.
    #[test]
    fn test_comparison_op() {
        // (operator, actual, expected, should_hold)
        let cases = [
            (ComparisonOp::Gt, 0.9, 0.8, true),
            (ComparisonOp::Gt, 0.8, 0.8, false),
            (ComparisonOp::Gte, 0.8, 0.8, true),
            (ComparisonOp::Gte, 0.9, 0.8, true),
            (ComparisonOp::Lt, 0.7, 0.8, true),
            (ComparisonOp::Lt, 0.8, 0.8, false),
            (ComparisonOp::Lte, 0.8, 0.8, true),
            (ComparisonOp::Lte, 0.7, 0.8, true),
            (ComparisonOp::Eq, 0.8, 0.8, true),
            (ComparisonOp::Eq, 0.81, 0.8, false),
        ];

        for (op, actual, expected, should_hold) in cases {
            assert_eq!(
                op.check(actual, expected),
                should_hold,
                "{:?}.check({}, {})",
                op,
                actual,
                expected
            );
        }
    }

    /// A run count of zero is rejected by validation.
    #[test]
    fn test_invalid_runs() {
        let toml = r#"
[eval]
runs = 0
"#;
        assert!(EvalConfig::from_toml_str(toml).is_err());
    }
}