swarm-engine-eval 0.1.6

Evaluation framework for SwarmEngine
Documentation
//! マイルストーン定義と KPI スコア計算
//!
//! マイルストーンは評価の中間目標を定義し、kpi_score の計算に使用される。

use serde::{Deserialize, Serialize};

use super::conditions::{Condition, ConditionValue};

/// A milestone: an intermediate evaluation goal that contributes to the KPI score.
///
/// kpi_score = Σ(milestone.weight * milestone.achieved)
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Milestone {
    /// Milestone name.
    pub name: String,

    /// Optional description; falls back to `None` when absent from the input.
    #[serde(default)]
    pub description: Option<String>,

    /// Condition that must hold for the milestone to count as achieved.
    pub condition: Condition,

    /// Score weight (0.0 - 1.0; weights summing to 1.0 are recommended).
    pub weight: f64,

    /// Whether partial achievement is allowed; falls back to `false` when absent from the input.
    #[serde(default)]
    pub partial: bool,

    /// How partial achievement is computed (only effective when `partial` is `true`).
    #[serde(default)]
    pub partial_config: Option<PartialConfig>,
}

impl Milestone {
    /// Creates a milestone from a name, an achievement condition and a weight.
    ///
    /// The description starts unset and partial achievement is disabled.
    pub fn new(name: impl Into<String>, condition: Condition, weight: f64) -> Self {
        Self {
            name: name.into(),
            description: None,
            condition,
            weight,
            partial: false,
            partial_config: None,
        }
    }

    /// Sets the description.
    pub fn with_description(mut self, description: impl Into<String>) -> Self {
        self.description = Some(description.into());
        self
    }

    /// Enables partial achievement, scored according to `config`.
    pub fn with_partial(mut self, config: PartialConfig) -> Self {
        self.partial = true;
        self.partial_config = Some(config);
        self
    }

    /// Evaluates the milestone against an observed metric value.
    ///
    /// Returns the achievement, always within `0.0..=1.0`:
    /// * `1.0` when the condition holds,
    /// * `0.0` when it does not hold and partial achievement is disabled,
    /// * otherwise the partial score, clamped to `0.0..=1.0` (a NaN score,
    ///   e.g. from a NaN metric value, counts as `0.0`).
    pub fn evaluate(&self, actual: &ConditionValue) -> f64 {
        if self.condition.evaluate(actual) {
            return 1.0;
        }

        if !self.partial {
            return 0.0;
        }

        let score = self.calculate_partial_score(actual);
        // `clamp` passes NaN through, so map it to "no achievement" explicitly;
        // otherwise a single NaN would poison every sum built on this value.
        if score.is_nan() {
            0.0
        } else {
            score.clamp(0.0, 1.0)
        }
    }

    /// Converts a numeric condition value to `f64`; non-numeric values yield `None`.
    fn numeric_value(value: &ConditionValue) -> Option<f64> {
        match value {
            ConditionValue::Integer(v) => Some(*v as f64),
            ConditionValue::Float(v) => Some(*v),
            _ => None,
        }
    }

    /// Computes the partial-achievement score for a value that failed the condition.
    ///
    /// Returns `0.0` when no partial configuration is set, or when either the
    /// observed value or the condition's target value is not numeric.
    fn calculate_partial_score(&self, actual: &ConditionValue) -> f64 {
        let config = match &self.partial_config {
            Some(config) => config,
            None => return 0.0,
        };

        let (actual_f64, target_f64) = match (
            Self::numeric_value(actual),
            Self::numeric_value(&self.condition.value),
        ) {
            (Some(a), Some(t)) => (a, t),
            _ => return 0.0,
        };

        match config {
            PartialConfig::Linear {
                min,
                max,
                descending,
            } => {
                // Missing bounds fall back to 0.0 (min) and the condition target (max).
                let min_val = min.unwrap_or(0.0);
                let max_val = max.unwrap_or(target_f64);

                // The interpolation branches below are only reached when
                // min_val < actual_f64 < max_val, so the divisor is never zero.
                if *descending {
                    // Smaller values score higher (for Lte/Lt conditions):
                    // 1.0 at min, 0.0 at max.
                    if actual_f64 <= min_val {
                        1.0
                    } else if actual_f64 >= max_val {
                        0.0
                    } else {
                        (max_val - actual_f64) / (max_val - min_val)
                    }
                } else {
                    // Larger values score higher (for Gte/Gt conditions):
                    // 0.0 at min, 1.0 at max.
                    if actual_f64 <= min_val {
                        0.0
                    } else if actual_f64 >= max_val {
                        1.0
                    } else {
                        (actual_f64 - min_val) / (max_val - min_val)
                    }
                }
            }
            // Threshold-based: take the score attached to the highest threshold the
            // value reaches, independent of the order the pairs were listed in.
            // `max_by` returns the last of equally-high entries, so duplicate
            // thresholds keep "later entry wins". NaN thresholds never pass the
            // filter, hence `partial_cmp` cannot actually fail here.
            PartialConfig::Threshold { thresholds } => thresholds
                .iter()
                .filter(|(threshold, _)| actual_f64 >= *threshold)
                .max_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal))
                .map_or(0.0, |(_, score)| *score),
        }
    }
}

/// How the score of a partially achieved milestone is computed.
///
/// Serialized with an internal `type` tag in snake_case (e.g. `"type": "linear"`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum PartialConfig {
    /// Linear interpolation between two boundary values.
    Linear {
        /// Minimum (boundary value).
        min: Option<f64>,
        /// Maximum (boundary value).
        max: Option<f64>,
        /// Whether smaller values score higher (for Lte/Lt conditions).
        /// true: 1.0 at min, 0.0 at max (smaller is better)
        /// false: 0.0 at min, 1.0 at max (larger is better)
        #[serde(default)]
        descending: bool,
    },
    /// Threshold-based scoring.
    Threshold {
        /// Pairs of (threshold, score).
        thresholds: Vec<(f64, f64)>,
    },
}

/// Outcome of evaluating a single milestone.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MilestoneResult {
    /// Milestone name.
    pub name: String,
    /// Achievement level (0.0 - 1.0).
    pub achievement: f64,
    /// Weight of the milestone.
    pub weight: f64,
    /// Weighted score (achievement * weight).
    pub weighted_score: f64,
    /// Whether the milestone was fully achieved.
    pub completed: bool,
}

impl MilestoneResult {
    /// Builds the result record for `milestone` given its achievement level.
    ///
    /// The name and weight are copied from the milestone, `weighted_score` is
    /// `achievement * weight`, and `completed` is set when `achievement >= 1.0`.
    pub fn new(milestone: &Milestone, achievement: f64) -> Self {
        let weight = milestone.weight;
        let weighted_score = achievement * weight;
        let completed = achievement >= 1.0;
        let name = milestone.name.clone();

        Self {
            name,
            achievement,
            weight,
            weighted_score,
            completed,
        }
    }
}

/// KPI score calculator over a fixed list of milestones.
#[derive(Debug, Clone)]
pub struct KpiCalculator {
    milestones: Vec<Milestone>,
}

impl KpiCalculator {
    /// 計算機を作成
    pub fn new(milestones: Vec<Milestone>) -> Self {
        Self { milestones }
    }

    /// KPI スコアを計算
    ///
    /// # Arguments
    /// * `metric_getter` - メトリクス取得関数 (metric_path -> value)
    pub fn calculate<F>(&self, metric_getter: F) -> KpiScore
    where
        F: Fn(&str) -> Option<ConditionValue>,
    {
        let mut results = Vec::new();
        let mut total_score = 0.0;
        let mut total_weight = 0.0;

        for milestone in &self.milestones {
            let achievement = match metric_getter(&milestone.condition.metric) {
                Some(value) => milestone.evaluate(&value),
                None => 0.0,
            };

            let result = MilestoneResult::new(milestone, achievement);
            total_score += result.weighted_score;
            total_weight += milestone.weight;
            results.push(result);
        }

        // 正規化 (weight の合計が 1.0 でない場合に対応)
        let normalized_score = if total_weight > 0.0 {
            total_score / total_weight
        } else {
            0.0
        };

        KpiScore {
            score: normalized_score,
            raw_score: total_score,
            total_weight,
            results,
        }
    }
}

/// Result of a KPI score calculation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KpiScore {
    /// Normalized score (0.0 - 1.0).
    pub score: f64,
    /// Raw score (sum of the weighted scores).
    pub raw_score: f64,
    /// Total weight.
    pub total_weight: f64,
    /// Per-milestone results.
    pub results: Vec<MilestoneResult>,
}

impl KpiScore {
    /// Number of milestones that were fully completed.
    pub fn completed_count(&self) -> usize {
        self.results
            .iter()
            .map(|result| usize::from(result.completed))
            .sum()
    }

    /// Total number of milestones that were evaluated.
    pub fn total_count(&self) -> usize {
        let Self { results, .. } = self;
        results.len()
    }
}

#[cfg(test)]
mod tests {
    use super::super::conditions::CompareOp;
    use super::*;

    /// Builds a milestone without partial scoring whose condition compares
    /// `metric` against the integer `value` using `op`.
    fn create_test_milestone(
        name: &str,
        metric: &str,
        op: CompareOp,
        value: i64,
        weight: f64,
    ) -> Milestone {
        let condition = Condition::new(name, metric, op, value);
        Milestone::new(name, condition, weight)
    }

    #[test]
    fn test_milestone_evaluate_complete() {
        let milestone = create_test_milestone(
            "first_collection",
            "resources_collected",
            CompareOp::Gte,
            1,
            0.2,
        );
        let achievement_for = |collected| milestone.evaluate(&ConditionValue::Integer(collected));

        assert_eq!(achievement_for(1), 1.0);
        assert_eq!(achievement_for(5), 1.0);
        assert_eq!(achievement_for(0), 0.0);
    }

    #[test]
    fn test_milestone_evaluate_partial_linear() {
        // Smaller is better: tick <= 300 is the goal.
        let milestone = create_test_milestone("efficiency", "tick", CompareOp::Lte, 300, 0.3)
            .with_partial(PartialConfig::Linear {
                min: Some(300.0),
                max: Some(400.0),
                descending: true,
            });
        let achievement_at = |tick| milestone.evaluate(&ConditionValue::Integer(tick));

        // tick <= 300 is full achievement.
        assert_eq!(achievement_at(250), 1.0);
        assert_eq!(achievement_at(300), 1.0);

        // tick > 300 is partial achievement:
        // 350 ticks = (400 - 350) / (400 - 300) = 0.5
        assert!((achievement_at(350) - 0.5).abs() < 0.01);

        // tick >= 400 scores 0.
        assert_eq!(achievement_at(400), 0.0);
        assert_eq!(achievement_at(500), 0.0);
    }

    #[test]
    fn test_kpi_calculator() {
        let calculator = KpiCalculator::new(vec![
            create_test_milestone("first", "collected", CompareOp::Gte, 1, 0.2),
            create_test_milestone("half", "collected", CompareOp::Gte, 3, 0.3),
            create_test_milestone("all", "collected", CompareOp::Gte, 5, 0.5),
        ]);
        let score_for =
            |collected| calculator.calculate(|_| Some(ConditionValue::Integer(collected)));

        // Everything achieved.
        let all = score_for(5);
        assert_eq!(all.score, 1.0);
        assert_eq!(all.completed_count(), 3);

        // Half achieved: first (0.2) + half (0.3) = 0.5
        let half = score_for(3);
        assert!((half.score - 0.5).abs() < 0.01);
        assert_eq!(half.completed_count(), 2);

        // Only the first achieved: first (0.2) = 0.2
        let first_only = score_for(1);
        assert!((first_only.score - 0.2).abs() < 0.01);
        assert_eq!(first_only.completed_count(), 1);
    }

    #[test]
    fn test_milestone_deserialize() {
        let json = r#"{
            "name": "efficiency_bonus",
            "description": "Complete within 300 ticks",
            "condition": {
                "name": "efficiency",
                "metric": "tick",
                "op": "lte",
                "value": 300
            },
            "weight": 0.2,
            "partial": true,
            "partial_config": {
                "type": "linear",
                "min": 300.0,
                "max": 400.0
            }
        }"#;

        let milestone = serde_json::from_str::<Milestone>(json).unwrap();
        assert_eq!(milestone.name, "efficiency_bonus");
        assert!(milestone.partial);
    }
}