swarm-engine-core 0.1.6

Core types and orchestration for SwarmEngine
Documentation
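//! Validator - performance validation of learning results.
//!
//! Used outside the Learn Pipeline.
//! Splits the data into train/test portions, then validates performance on the test portion.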

use super::result::ValidationResult;
use super::strategy::ValidationStrategy;

/// Validator - performance validation of learning results.
///
/// Splits the data into a train portion and a test portion, then validates
/// performance on the test portion. Used outside the Learn Pipeline.
///
/// # Type Parameter
///
/// * `T` - Type of the data being validated (e.g. Episode, ActionRecord).
///
/// # Example
///
/// ```ignore
/// use swarm_engine_core::validation::{Validator, NoRegression};
///
/// let validator = Validator::new(0.8, Box::new(NoRegression::new()));
/// let result = validator.validate(
///     &episodes,
///     // Baseline: score computed on the train portion.
///     |train| compute_success_rate(train),
///     // Current: score computed on the held-out test portion.
///     |test| compute_success_rate(test),
/// );
/// ```
pub struct Validator<T> {
    /// Train/test split ratio (0.8 = 80% train, 20% test).
    split_ratio: f64,
    /// Validation strategy that compares the baseline with the current score.
    strategy: Box<dyn ValidationStrategy>,
    /// Data type marker.
    ///
    /// The validator never stores a `T`; it only borrows the `&[T]` slices that
    /// callers pass in. `fn() -> T` keeps `Validator<T>` covariant in `T` (as
    /// `PhantomData<T>` did). Unlike `PhantomData<T>`, it does not make
    /// `Validator<T>` inherit `T`'s `Send`/`Sync` or drop-check behaviour.
    _marker: std::marker::PhantomData<fn() -> T>,
}

impl<T> Validator<T> {
    /// Creates a new `Validator`.
    ///
    /// # Arguments
    ///
    /// * `split_ratio` - Fraction of the data used as the train portion (0.0-1.0).
    ///   The remainder becomes the test portion.
    /// * `strategy` - Validation strategy.
    ///
    /// # Panics
    ///
    /// Panics if `split_ratio` is outside `0.0..=1.0`. This includes NaN, which
    /// is not contained in any range.
    pub fn new(split_ratio: f64, strategy: Box<dyn ValidationStrategy>) -> Self {
        assert!(
            (0.0..=1.0).contains(&split_ratio),
            "split_ratio must be between 0.0 and 1.0"
        );
        Self {
            split_ratio,
            strategy,
            _marker: std::marker::PhantomData,
        }
    }

    /// Creates a validator with an 80/20 train/test split.
    pub fn with_80_20_split(strategy: Box<dyn ValidationStrategy>) -> Self {
        Self::new(0.8, strategy)
    }

    /// Creates a validator with a 70/30 train/test split.
    pub fn with_70_30_split(strategy: Box<dyn ValidationStrategy>) -> Self {
        Self::new(0.7, strategy)
    }

    /// Runs validation.
    ///
    /// Splits `data` into a leading train portion and a trailing test portion
    /// (in the given order; nothing is shuffled). It computes the baseline from
    /// the train portion and the current score from the test portion. The
    /// strategy then judges the two scores, with the test-portion length as the
    /// sample size.
    ///
    /// # Arguments
    ///
    /// * `data` - All data.
    /// * `baseline_fn` - Computes the baseline score from the train portion.
    /// * `evaluate_fn` - Computes the current score from the test portion.
    ///
    /// Either portion may be empty (empty `data`, or a `split_ratio` of 0.0 or
    /// 1.0), so both closures should cope with an empty slice.
    ///
    /// # Returns
    ///
    /// The strategy's validation result.
    pub fn validate<F, G>(&self, data: &[T], baseline_fn: F, evaluate_fn: G) -> ValidationResult
    where
        F: FnOnce(&[T]) -> f64,
        G: FnOnce(&[T]) -> f64,
    {
        let (train, test) = self.split(data);

        let baseline = baseline_fn(train);
        let current = evaluate_fn(test);

        self.strategy.evaluate(baseline, current, test.len())
    }

    /// Runs validation against an externally supplied baseline.
    ///
    /// Use this when the baseline has already been computed, e.g. during
    /// Bootstrap. The train portion of `data` is discarded. Only the test
    /// portion is passed to `evaluate_fn`, and its length is reported to the
    /// strategy as the sample size.
    pub fn validate_with_baseline<F>(
        &self,
        data: &[T],
        baseline: f64,
        evaluate_fn: F,
    ) -> ValidationResult
    where
        F: FnOnce(&[T]) -> f64,
    {
        let (_, test) = self.split(data);
        let current = evaluate_fn(test);
        self.strategy.evaluate(baseline, current, test.len())
    }

    /// Splits `data` into `(train, test)` at `floor(len * split_ratio)`.
    ///
    /// The index is computed as if `split_ratio` were the exact decimal the
    /// caller wrote. Floating-point noise in the product does not shift the
    /// split by one.
    fn split<'a>(&self, data: &'a [T]) -> (&'a [T], &'a [T]) {
        // A product within a few ULPs of an integer is treated as that integer.
        const SNAP_ULPS: f64 = 4.0;

        let len = data.len();
        let exact = len as f64 * self.split_ratio;

        // Decimal ratios such as 0.57 have no exact binary representation, so
        // `100.0 * 0.57` evaluates to 56.99999999999999 and plain truncation
        // yields 56 instead of 57. Snap such near-integer results first.
        // Genuinely fractional products (e.g. 10 * 0.85 = 8.5) still round down.
        let nearest = exact.round();
        let tolerance = SNAP_ULPS * f64::EPSILON * nearest.max(1.0);
        let train_len = if (exact - nearest).abs() <= tolerance {
            nearest
        } else {
            exact.floor()
        };

        // `split_ratio` is validated to 0.0..=1.0, but clamp anyway so the
        // index can never exceed the slice length.
        let split_idx = (train_len as usize).min(len);
        data.split_at(split_idx)
    }

    /// Returns the name reported by the validation strategy.
    pub fn strategy_name(&self) -> &str {
        self.strategy.name()
    }

    /// Returns the train/test split ratio.
    pub fn split_ratio(&self) -> f64 {
        self.split_ratio
    }
}

#[cfg(test)]
mod tests {
    use super::super::strategy::{Absolute, Improvement, NoRegression};
    use super::*;

    #[test]
    fn test_validator_split() {
        let data: Vec<i32> = (0..100).collect();
        let validator: Validator<i32> = Validator::new(0.8, Box::new(NoRegression::new()));

        let (train, test) = validator.split(&data);
        assert_eq!(train.len(), 80);
        assert_eq!(test.len(), 20);
    }

    #[test]
    fn test_validator_split_boundary_ratios() {
        let data: Vec<i32> = (0..100).collect();

        // 0.0: everything goes to the test portion.
        let validator: Validator<i32> = Validator::new(0.0, Box::new(NoRegression::new()));
        let (train, test) = validator.split(&data);
        assert!(train.is_empty());
        assert_eq!(test.len(), 100);

        // 1.0: everything goes to the train portion.
        let validator: Validator<i32> = Validator::new(1.0, Box::new(NoRegression::new()));
        let (train, test) = validator.split(&data);
        assert_eq!(train.len(), 100);
        assert!(test.is_empty());
    }

    #[test]
    fn test_validator_validate() {
        // 100 samples: 80 train, 20 test.
        let mut data: Vec<f64> = Vec::with_capacity(100);
        // Train: 80 samples, 64 successes (80%).
        for i in 0..80 {
            data.push(if i < 64 { 1.0 } else { 0.0 });
        }
        // Test: 20 samples, 18 successes (90%).
        for i in 0..20 {
            data.push(if i < 18 { 1.0 } else { 0.0 });
        }

        let validator = Validator::with_80_20_split(Box::new(NoRegression::new()));

        let result = validator.validate(
            &data,
            |train| train.iter().sum::<f64>() / train.len() as f64,
            |test| test.iter().sum::<f64>() / test.len() as f64,
        );

        assert!(result.passed);
        assert!((result.baseline - 0.8).abs() < 0.01);
        assert!((result.current - 0.9).abs() < 0.01);
    }

    #[test]
    fn test_validator_with_baseline() {
        let data: Vec<f64> = (0..100).map(|i| if i < 85 { 1.0 } else { 0.0 }).collect();

        let validator = Validator::with_80_20_split(Box::new(Improvement::ten_percent()));

        // Use an external baseline of 0.7.
        let result = validator.validate_with_baseline(&data, 0.7, |test| {
            test.iter().sum::<f64>() / test.len() as f64
        });

        // The test portion is indices 80..100, of which only 80..85 are
        // successes: 5 successes / 20 = 0.25.
        // 0.25 >= 0.7 * 1.1 = 0.77? -> No, fail.
        assert!((result.current - 0.25).abs() < 1e-9);
        assert!(!result.passed);
    }

    #[test]
    fn test_validator_absolute_strategy() {
        let data: Vec<f64> = (0..100).map(|i| if i < 90 { 1.0 } else { 0.0 }).collect();

        let validator = Validator::with_80_20_split(Box::new(Absolute::eighty_percent()));

        let result = validator.validate_with_baseline(&data, 0.5, |test| {
            test.iter().sum::<f64>() / test.len() as f64
        });

        // Test portion: the last 20 items hold 10 successes, 10 / 20 = 0.5 < 0.8 -> fail.
        assert!(!result.passed);
    }

    #[test]
    fn test_validator_empty_data() {
        let data: Vec<i32> = vec![];
        let validator: Validator<i32> = Validator::new(0.8, Box::new(NoRegression::new()));

        let (train, test) = validator.split(&data);
        assert!(train.is_empty());
        assert!(test.is_empty());
    }

    #[test]
    #[should_panic(expected = "split_ratio must be between")]
    fn test_validator_invalid_ratio() {
        let _: Validator<i32> = Validator::new(1.5, Box::new(NoRegression::new()));
    }

    #[test]
    #[should_panic(expected = "split_ratio must be between")]
    fn test_validator_nan_ratio() {
        let _: Validator<i32> = Validator::new(f64::NAN, Box::new(NoRegression::new()));
    }
}