swarm-engine-core 0.1.6

Core types and orchestration for SwarmEngine
Documentation
//! Parametric trait — parameter-providing capability.
//!
//! Models that supply the parameters used to build a strategy configuration.

use std::any::Any;
use std::collections::HashMap;

use serde::{Deserialize, Serialize};

use super::base::{Model, ModelMetadata, ModelType, ModelVersion};
use crate::learn::offline::{OfflineModel, RecommendedPath, StrategyConfig};
use crate::util::epoch_millis;

/// A [`Model`] that can hand out named parameters.
///
/// Implementors act as the source of values when a strategy configuration
/// is being assembled.
pub trait Parametric: Model {
    /// Looks up the parameter stored under `key`.
    ///
    /// Returns an owned value, or `None` when no parameter with that name exists.
    fn get_param(&self, key: &str) -> Option<ParamValue>;

    /// Returns an owned snapshot of every parameter, keyed by parameter name.
    fn all_params(&self) -> HashMap<String, ParamValue>;
}

/// A dynamically-typed parameter value.
///
/// `PartialEq` is derived so values can be compared directly (e.g. in
/// `assert_eq!`). It is not `Eq` because `Float` holds an `f64`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum ParamValue {
    /// Floating-point value.
    Float(f64),
    /// Signed integer value.
    Int(i64),
    /// Boolean flag.
    Bool(bool),
    /// Text value.
    String(String),
    /// Ordered list of nested values.
    Array(Vec<ParamValue>),
}

impl ParamValue {
    pub fn as_f64(&self) -> Option<f64> {
        match self {
            Self::Float(v) => Some(*v),
            Self::Int(v) => Some(*v as f64),
            _ => None,
        }
    }

    pub fn as_i64(&self) -> Option<i64> {
        match self {
            Self::Int(v) => Some(*v),
            Self::Float(v) => Some(*v as i64),
            _ => None,
        }
    }

    pub fn as_bool(&self) -> Option<bool> {
        match self {
            Self::Bool(v) => Some(*v),
            _ => None,
        }
    }

    pub fn as_str(&self) -> Option<&str> {
        match self {
            Self::String(v) => Some(v),
            _ => None,
        }
    }
}

impl From<f64> for ParamValue {
    /// Wraps the number as [`ParamValue::Float`].
    fn from(v: f64) -> Self {
        Self::Float(v)
    }
}

impl From<i64> for ParamValue {
    /// Wraps the number as [`ParamValue::Int`].
    fn from(v: i64) -> Self {
        Self::Int(v)
    }
}

impl From<bool> for ParamValue {
    /// Wraps the flag as [`ParamValue::Bool`].
    fn from(v: bool) -> Self {
        Self::Bool(v)
    }
}

impl From<String> for ParamValue {
    /// Wraps the owned text as [`ParamValue::String`].
    fn from(v: String) -> Self {
        Self::String(v)
    }
}

impl From<&str> for ParamValue {
    /// Copies the text into a [`ParamValue::String`].
    fn from(v: &str) -> Self {
        Self::String(v.to_owned())
    }
}

impl From<Vec<ParamValue>> for ParamValue {
    /// Wraps the list as [`ParamValue::Array`].
    ///
    /// Completes the set of conversions so every variant can be passed to
    /// APIs taking `impl Into<ParamValue>`.
    fn from(v: Vec<ParamValue>) -> Self {
        Self::Array(v)
    }
}

/// Well-known keys used in a model's parameter map.
pub mod param_keys {
    // --- Tuning parameters (stored as `Float`) -------------------------------

    /// UCB1 exploration coefficient.
    pub const UCB1_C: &str = "ucb1_c";
    /// Weight applied to learned statistics.
    pub const LEARNING_WEIGHT: &str = "learning_weight";
    /// Weight applied to N-gram statistics.
    pub const NGRAM_WEIGHT: &str = "ngram_weight";

    // --- Strategy settings (mirrors of `StrategyConfig` fields) -------------

    /// Mirror of `StrategyConfig::maturity_threshold` (stored as `Int`).
    pub const MATURITY_THRESHOLD: &str = "maturity_threshold";
    /// Mirror of `StrategyConfig::error_rate_threshold` (stored as `Float`).
    pub const ERROR_RATE_THRESHOLD: &str = "error_rate_threshold";
    /// Mirror of `StrategyConfig::initial_strategy` (stored as `String`).
    pub const INITIAL_STRATEGY: &str = "initial_strategy";
}

// ============================================================================
// OptimalParamsModel - parameter-optimization model
// ============================================================================

/// Parameter-optimization model.
///
/// Stores tuned parameters in a generic key/value map, exposed through the
/// `Parametric` trait, alongside structured strategy recommendations.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OptimalParamsModel {
    /// Version reported through `Model::version`.
    version: ModelVersion,
    /// Metadata reported through `Model::metadata`.
    metadata: ModelMetadata,
    /// Creation timestamp. Set from `epoch_millis()` for a fresh model, or
    /// copied from the legacy `updated_at` when converted from an `OfflineModel`.
    created_at: u64,

    /// Parameter values keyed by name (see `param_keys` for the well-known keys).
    params: HashMap<String, ParamValue>,

    /// Recommended strategy configuration.
    pub strategy_config: StrategyConfig,
    /// Recommended paths.
    pub recommended_paths: Vec<RecommendedPath>,

    /// Number of sessions that were analyzed to produce this model.
    pub analyzed_sessions: usize,
}

impl Default for OptimalParamsModel {
    /// Builds a model with `ModelVersion::new(1, 0)` and default metadata.
    ///
    /// The parameter map is seeded with the baseline tuning values
    /// (`ucb1_c = √2`, `learning_weight = 0.3`, `ngram_weight = 1.0`). The
    /// strategy config is its default, there are no recommended paths and no
    /// analyzed sessions, and `created_at` is stamped from `epoch_millis()`.
    fn default() -> Self {
        let baseline = HashMap::from([
            (
                param_keys::UCB1_C.to_string(),
                ParamValue::Float(std::f64::consts::SQRT_2),
            ),
            (
                param_keys::LEARNING_WEIGHT.to_string(),
                ParamValue::Float(0.3),
            ),
            (
                param_keys::NGRAM_WEIGHT.to_string(),
                ParamValue::Float(1.0),
            ),
        ]);

        Self {
            version: ModelVersion::new(1, 0),
            metadata: ModelMetadata::default(),
            created_at: epoch_millis(),
            params: baseline,
            strategy_config: StrategyConfig::default(),
            recommended_paths: Vec::new(),
            analyzed_sessions: 0,
        }
    }
}

impl OptimalParamsModel {
    /// 新しいモデルを作成
    pub fn new() -> Self {
        Self::default()
    }

    /// パラメータを設定
    pub fn set_param(&mut self, key: &str, value: impl Into<ParamValue>) {
        self.params.insert(key.to_string(), value.into());
    }

    /// UCB1 の探索係数を取得
    pub fn ucb1_c(&self) -> f64 {
        self.get_param(param_keys::UCB1_C)
            .and_then(|v| v.as_f64())
            .unwrap_or(std::f64::consts::SQRT_2)
    }

    /// 学習重みを取得
    pub fn learning_weight(&self) -> f64 {
        self.get_param(param_keys::LEARNING_WEIGHT)
            .and_then(|v| v.as_f64())
            .unwrap_or(0.3)
    }

    /// N-gram 重みを取得
    pub fn ngram_weight(&self) -> f64 {
        self.get_param(param_keys::NGRAM_WEIGHT)
            .and_then(|v| v.as_f64())
            .unwrap_or(1.0)
    }
}

impl Model for OptimalParamsModel {
    /// Upcasts to `&dyn Any` so callers can downcast back to this concrete type.
    fn as_any(&self) -> &dyn Any {
        self as &dyn Any
    }

    /// Metadata attached to this model.
    fn metadata(&self) -> &ModelMetadata {
        &self.metadata
    }

    /// Stored creation timestamp.
    fn created_at(&self) -> u64 {
        self.created_at
    }

    /// Stored model version.
    fn version(&self) -> &ModelVersion {
        &self.version
    }

    /// Always `ModelType::OptimalParams`.
    fn model_type(&self) -> ModelType {
        ModelType::OptimalParams
    }
}

impl Parametric for OptimalParamsModel {
    /// Returns a clone of the value stored under `key`, or `None` if absent.
    fn get_param(&self, key: &str) -> Option<ParamValue> {
        let stored = self.params.get(key)?;
        Some(stored.clone())
    }

    /// Returns a full copy of the parameter map.
    fn all_params(&self) -> HashMap<String, ParamValue> {
        HashMap::clone(&self.params)
    }
}

// ============================================================================
// Conversion from the legacy OfflineModel
// ============================================================================

impl From<OfflineModel> for OptimalParamsModel {
    fn from(old: OfflineModel) -> Self {
        let mut params = HashMap::new();
        params.insert(
            param_keys::UCB1_C.to_string(),
            ParamValue::Float(old.parameters.ucb1_c),
        );
        params.insert(
            param_keys::LEARNING_WEIGHT.to_string(),
            ParamValue::Float(old.parameters.learning_weight),
        );
        params.insert(
            param_keys::NGRAM_WEIGHT.to_string(),
            ParamValue::Float(old.parameters.ngram_weight),
        );
        params.insert(
            param_keys::MATURITY_THRESHOLD.to_string(),
            ParamValue::Int(old.strategy_config.maturity_threshold as i64),
        );
        params.insert(
            param_keys::ERROR_RATE_THRESHOLD.to_string(),
            ParamValue::Float(old.strategy_config.error_rate_threshold),
        );
        params.insert(
            param_keys::INITIAL_STRATEGY.to_string(),
            ParamValue::String(old.strategy_config.initial_strategy.clone()),
        );

        Self {
            version: ModelVersion::new(old.version, 0),
            metadata: ModelMetadata::default(),
            created_at: old.updated_at,
            params,
            strategy_config: old.strategy_config,
            recommended_paths: old.recommended_paths,
            analyzed_sessions: old.analyzed_sessions,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Each accessor returns `Some` for its own variant(s) and `None` otherwise.
    #[test]
    fn test_param_value_conversions() {
        let f = ParamValue::Float(1.5);
        assert_eq!(f.as_f64(), Some(1.5));
        assert_eq!(f.as_i64(), Some(1));
        assert_eq!(f.as_bool(), None);
        assert_eq!(f.as_str(), None);

        let i = ParamValue::Int(42);
        assert_eq!(i.as_i64(), Some(42));
        assert_eq!(i.as_f64(), Some(42.0));
        assert_eq!(i.as_bool(), None);
        assert_eq!(i.as_str(), None);

        let b = ParamValue::Bool(true);
        assert_eq!(b.as_bool(), Some(true));
        assert_eq!(b.as_f64(), None);
        assert_eq!(b.as_i64(), None);
        assert_eq!(b.as_str(), None);

        let s = ParamValue::String("test".to_string());
        assert_eq!(s.as_str(), Some("test"));
        assert_eq!(s.as_f64(), None);
        assert_eq!(s.as_i64(), None);
        assert_eq!(s.as_bool(), None);

        // Arrays are never coerced to a scalar.
        let a = ParamValue::Array(vec![ParamValue::Int(1)]);
        assert_eq!(a.as_f64(), None);
        assert_eq!(a.as_i64(), None);
        assert_eq!(a.as_bool(), None);
        assert_eq!(a.as_str(), None);
    }

    #[test]
    fn test_param_value_from_impls() {
        assert_eq!(ParamValue::from(2.5_f64).as_f64(), Some(2.5));
        assert_eq!(ParamValue::from(7_i64).as_i64(), Some(7));
        assert_eq!(ParamValue::from(false).as_bool(), Some(false));
        assert_eq!(ParamValue::from("abc").as_str(), Some("abc"));
        assert_eq!(ParamValue::from(String::from("xyz")).as_str(), Some("xyz"));
    }

    #[test]
    fn test_optimal_params_model_default() {
        let model = OptimalParamsModel::new();

        assert!((model.ucb1_c() - std::f64::consts::SQRT_2).abs() < 1e-10);
        assert!((model.learning_weight() - 0.3).abs() < 1e-10);
        assert!((model.ngram_weight() - 1.0).abs() < 1e-10);

        // Only the three tuning parameters are seeded; strategy keys are absent.
        assert_eq!(model.all_params().len(), 3);
        assert!(model.get_param(param_keys::MATURITY_THRESHOLD).is_none());
        assert_eq!(model.analyzed_sessions, 0);
        assert!(model.recommended_paths.is_empty());
    }

    #[test]
    fn test_optimal_params_model_set_param() {
        let mut model = OptimalParamsModel::new();
        model.set_param(param_keys::UCB1_C, 2.0);
        assert!((model.ucb1_c() - 2.0).abs() < 1e-10);

        // Integer values are read back through the f64 getters.
        model.set_param(param_keys::LEARNING_WEIGHT, 1_i64);
        assert!((model.learning_weight() - 1.0).abs() < 1e-10);

        // A non-numeric value makes the getter fall back to its default.
        model.set_param(param_keys::NGRAM_WEIGHT, "oops");
        assert!((model.ngram_weight() - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_parametric_trait() {
        let model = OptimalParamsModel::new();

        let value = model.get_param(param_keys::UCB1_C);
        assert!(value.is_some());
        assert!(model.get_param("no_such_key").is_none());

        let all = model.all_params();
        assert!(all.contains_key(param_keys::UCB1_C));
        assert!(all.contains_key(param_keys::LEARNING_WEIGHT));
    }

    #[test]
    fn test_from_offline_model() {
        use crate::learn::offline::{OfflineModel, OptimalParameters, StrategyConfig};

        let old = OfflineModel {
            version: 2,
            parameters: OptimalParameters {
                ucb1_c: 1.5,
                learning_weight: 0.4,
                ngram_weight: 1.2,
            },
            recommended_paths: vec![],
            strategy_config: StrategyConfig::default(),
            analyzed_sessions: 5,
            updated_at: 12345,
            action_order: None,
        };

        let model: OptimalParamsModel = old.into();

        assert!((model.ucb1_c() - 1.5).abs() < 1e-10);
        assert!((model.learning_weight() - 0.4).abs() < 1e-10);
        assert!((model.ngram_weight() - 1.2).abs() < 1e-10);
        assert_eq!(model.analyzed_sessions, 5);
        assert_eq!(model.created_at, 12345);

        // The strategy settings are mirrored into the parameter map.
        assert_eq!(
            model
                .get_param(param_keys::MATURITY_THRESHOLD)
                .and_then(|v| v.as_i64()),
            Some(model.strategy_config.maturity_threshold as i64)
        );
        assert_eq!(
            model
                .get_param(param_keys::ERROR_RATE_THRESHOLD)
                .and_then(|v| v.as_f64()),
            Some(model.strategy_config.error_rate_threshold)
        );
        let initial = model.get_param(param_keys::INITIAL_STRATEGY);
        assert_eq!(
            initial.as_ref().and_then(|v| v.as_str()),
            Some(model.strategy_config.initial_strategy.as_str())
        );
    }
}