swarm-engine-core 0.1.6

Core types and orchestration for SwarmEngine
Documentation
//! Base model trait: the common interface shared by all models.

use std::any::Any;
use std::collections::HashMap;

use serde::{Deserialize, Serialize};

use crate::util::epoch_millis_for_ordering;

/// Identifier of a statistics model.
///
/// A thin newtype over `String`. The inner value is public, so callers may
/// construct or destructure it directly.
#[derive(Debug, Clone, Hash)]
#[derive(PartialEq, Eq)]
#[derive(Serialize, Deserialize)]
pub struct StatsModelId(pub String);

impl StatsModelId {
    /// Wraps any string-like value as an identifier (no validation is done).
    pub fn new(id: impl Into<String>) -> Self {
        let raw: String = id.into();
        Self(raw)
    }

    /// Creates an identifier of the form `stats-<hex>`, where `<hex>` is the
    /// lowercase hexadecimal rendering of `epoch_millis_for_ordering()`.
    ///
    /// NOTE(review): uniqueness depends entirely on that helper. If it returns a
    /// plain wall-clock millisecond value, two calls within the same
    /// millisecond produce equal ids. Confirm before using this as a unique key.
    pub fn generate() -> Self {
        let millis = epoch_millis_for_ordering();
        Self::new(format!("stats-{:x}", millis))
    }

    /// Borrows the identifier as a string slice.
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }
}

impl std::fmt::Display for StatsModelId {
    /// Writes the raw identifier string.
    ///
    /// Delegates to [`std::fmt::Formatter::pad`], so width, fill, alignment and
    /// precision flags (e.g. `{:>12}`, `{:-<8}`, `{:.6}`) apply exactly as they
    /// would for a plain `str`. The previous `write!(f, "{}", ..)` silently
    /// ignored them. Output for a bare `{}` is unchanged.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.pad(&self.0)
    }
}

/// Base trait implemented by every trained model.
///
/// The `Send + Sync` supertraits allow implementors to be shared across threads.
pub trait Model: Send + Sync {
    /// The kind of model this is.
    fn model_type(&self) -> ModelType;

    /// Version information, including source ids used for lineage tracking.
    fn version(&self) -> &ModelVersion;

    /// Creation time as a Unix timestamp in milliseconds.
    fn created_at(&self) -> u64;

    /// Descriptive metadata (name, description, tags).
    fn metadata(&self) -> &ModelMetadata;

    /// Returns `self` as `&dyn Any` so callers can downcast to the concrete type.
    fn as_any(&self) -> &dyn Any;
}

/// The kind of a model.
#[derive(Debug, Clone, Hash)]
#[derive(PartialEq, Eq)]
#[derive(Serialize, Deserialize)]
pub enum ModelType {
    /// Action-selection scores.
    ActionScore,
    /// Parameter optimization results.
    OptimalParams,
    /// Escape hatch for kinds not covered above (reserved for future extension).
    Custom(String),
}

impl ModelType {
    /// ディレクトリ名を取得
    pub fn dir_name(&self) -> &str {
        match self {
            Self::ActionScore => "action_scores",
            Self::OptimalParams => "optimal_params",
            Self::Custom(name) => name,
        }
    }
}

impl std::fmt::Display for ModelType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::ActionScore => write!(f, "ActionScore"),
            Self::OptimalParams => write!(f, "OptimalParams"),
            Self::Custom(name) => write!(f, "Custom({})", name),
        }
    }
}

/// Version of a model, plus the ids of the data it was derived from.
#[derive(Debug, Clone, Default)]
#[derive(Serialize, Deserialize)]
pub struct ModelVersion {
    /// Major version number.
    pub major: u32,
    /// Minor version number.
    pub minor: u32,
    /// Identifiers of the source data (episode ids, snapshot ids, etc.).
    pub source_ids: Vec<String>,
}

impl ModelVersion {
    pub fn new(major: u32, minor: u32) -> Self {
        Self {
            major,
            minor,
            source_ids: Vec::new(),
        }
    }

    pub fn with_sources(major: u32, minor: u32, source_ids: Vec<String>) -> Self {
        Self {
            major,
            minor,
            source_ids,
        }
    }
}

impl std::fmt::Display for ModelVersion {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}.{}", self.major, self.minor)
    }
}

/// Descriptive metadata attached to a model.
#[derive(Debug, Clone, Default)]
#[derive(Serialize, Deserialize)]
pub struct ModelMetadata {
    /// Optional name.
    pub name: Option<String>,
    /// Optional free-text description.
    pub description: Option<String>,
    /// Arbitrary key/value tags.
    pub tags: HashMap<String, String>,
}

impl ModelMetadata {
    /// Creates empty metadata: no name, no description, no tags.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the name, replacing any previous value.
    ///
    /// Consumes `self` and returns the updated metadata; discarding the result
    /// discards the change, hence `#[must_use]`.
    #[must_use = "builder methods consume `self`; the updated metadata is only in the return value"]
    pub fn with_name(mut self, name: impl Into<String>) -> Self {
        self.name = Some(name.into());
        self
    }

    /// Sets the description, replacing any previous value.
    #[must_use = "builder methods consume `self`; the updated metadata is only in the return value"]
    pub fn with_description(mut self, desc: impl Into<String>) -> Self {
        self.description = Some(desc.into());
        self
    }

    /// Inserts a tag. An existing value for the same key is overwritten.
    #[must_use = "builder methods consume `self`; the updated metadata is only in the return value"]
    pub fn with_tag(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.tags.insert(key.into(), value.into());
        self
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// `generate` must produce `stats-` followed by a non-empty lowercase hex suffix.
    #[test]
    fn test_stats_model_id_generate() {
        for id in [StatsModelId::generate(), StatsModelId::generate()] {
            let suffix = id
                .as_str()
                .strip_prefix("stats-")
                .expect("generated id must start with `stats-`");
            assert!(!suffix.is_empty());
            assert!(suffix
                .chars()
                .all(|c| c.is_ascii_digit() || ('a'..='f').contains(&c)));
        }
    }

    #[test]
    fn test_stats_model_id_new_and_display() {
        let id = StatsModelId::new("abc");
        assert_eq!(id.as_str(), "abc");
        assert_eq!(id.to_string(), "abc");
        assert_eq!(id, StatsModelId("abc".to_string()));
    }

    #[test]
    fn test_model_type_dir_name() {
        assert_eq!(ModelType::ActionScore.dir_name(), "action_scores");
        assert_eq!(ModelType::OptimalParams.dir_name(), "optimal_params");
        assert_eq!(
            ModelType::Custom("my_model".to_string()).dir_name(),
            "my_model"
        );
    }

    #[test]
    fn test_model_type_display() {
        assert_eq!(ModelType::ActionScore.to_string(), "ActionScore");
        assert_eq!(ModelType::OptimalParams.to_string(), "OptimalParams");
        assert_eq!(
            ModelType::Custom("my_model".to_string()).to_string(),
            "Custom(my_model)"
        );
    }

    #[test]
    fn test_model_version() {
        let v = ModelVersion::new(1, 2);
        assert_eq!(format!("{}", v), "1.2");
        assert!(v.source_ids.is_empty());

        // Display shows only `major.minor`; source ids are kept on the struct.
        let w = ModelVersion::with_sources(2, 0, vec!["ep-1".to_string()]);
        assert_eq!(w.to_string(), "2.0");
        assert_eq!(w.source_ids, vec!["ep-1".to_string()]);
    }

    #[test]
    fn test_model_metadata_builder() {
        let meta = ModelMetadata::new()
            .with_name("test")
            .with_description("desc")
            .with_tag("env", "prod");

        assert_eq!(meta.name.as_deref(), Some("test"));
        assert_eq!(meta.description.as_deref(), Some("desc"));
        assert_eq!(meta.tags.get("env").map(|s| s.as_str()), Some("prod"));
    }

    #[test]
    fn test_model_metadata_tag_overwrite() {
        let meta = ModelMetadata::new().with_tag("k", "1").with_tag("k", "2");
        assert_eq!(meta.tags.len(), 1);
        assert_eq!(meta.tags.get("k").map(|s| s.as_str()), Some("2"));
    }
}