swarm-engine-core 0.1.6

Core types and orchestration for SwarmEngine
Documentation
//! Session Group - 複数セッションをまとめる単位
//!
//! Bootstrap / Release / Validate の各フェーズで実行された
//! 複数のセッションをグループとして管理する。

use std::time::{SystemTime, UNIX_EPOCH};

use serde::{Deserialize, Serialize};

use super::snapshot::SessionId;

/// セッショングループ ID
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct SessionGroupId(pub String);

impl SessionGroupId {
    /// 新しいグループ ID を生成(タイムスタンプベース)
    pub fn new() -> Self {
        let timestamp = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map(|d| d.as_millis())
            .unwrap_or(0);
        Self(format!("g{}", timestamp))
    }

    /// 文字列から生成
    pub fn from_raw(s: impl Into<String>) -> Self {
        Self(s.into())
    }
}

impl Default for SessionGroupId {
    fn default() -> Self {
        Self::new()
    }
}

impl std::fmt::Display for SessionGroupId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.0)
    }
}

/// 学習フェーズ
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum LearningPhase {
    /// Bootstrap: 正解グラフで強制的に成功させ、学習データを蓄積
    Bootstrap,
    /// Release: 学習済みモデルで自律実行
    Release,
    /// Validate: 検証・修正(将来用)
    Validate,
}

impl std::fmt::Display for LearningPhase {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Bootstrap => write!(f, "bootstrap"),
            Self::Release => write!(f, "release"),
            Self::Validate => write!(f, "validate"),
        }
    }
}

impl std::str::FromStr for LearningPhase {
    type Err = String;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s.to_lowercase().as_str() {
            "bootstrap" => Ok(Self::Bootstrap),
            "release" => Ok(Self::Release),
            "validate" => Ok(Self::Validate),
            _ => Err(format!("Unknown phase: {}", s)),
        }
    }
}

/// セッショングループのメタデータ
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionGroupMetadata {
    /// シナリオ名
    pub scenario: String,
    /// 作成日時(Unix timestamp)
    pub created_at: u64,
    /// 完了日時(Unix timestamp)
    pub completed_at: Option<u64>,
    /// 目標実行回数
    pub target_runs: usize,
    /// 成功回数
    pub success_count: usize,
    /// 失敗回数
    pub failure_count: usize,
    /// 使用した variant(with_graph 等)
    pub variant: Option<String>,
}

impl SessionGroupMetadata {
    /// 新しいメタデータを作成
    pub fn new(scenario: impl Into<String>, target_runs: usize) -> Self {
        Self {
            scenario: scenario.into(),
            created_at: SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .map(|d| d.as_secs())
                .unwrap_or(0),
            completed_at: None,
            target_runs,
            success_count: 0,
            failure_count: 0,
            variant: None,
        }
    }

    /// variant を設定
    pub fn with_variant(mut self, variant: impl Into<String>) -> Self {
        self.variant = Some(variant.into());
        self
    }

    /// 成功を記録
    pub fn record_success(&mut self) {
        self.success_count += 1;
    }

    /// 失敗を記録
    pub fn record_failure(&mut self) {
        self.failure_count += 1;
    }

    /// 完了をマーク
    pub fn mark_completed(&mut self) {
        self.completed_at = Some(
            SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .map(|d| d.as_secs())
                .unwrap_or(0),
        );
    }

    /// 成功率を計算
    pub fn success_rate(&self) -> f64 {
        let total = self.success_count + self.failure_count;
        if total == 0 {
            0.0
        } else {
            self.success_count as f64 / total as f64
        }
    }

    /// 完了した実行回数
    pub fn completed_runs(&self) -> usize {
        self.success_count + self.failure_count
    }
}

/// セッショングループ
///
/// 複数の Eval セッションをまとめて管理する単位。
/// Bootstrap / Release / Validate の各フェーズで使用。
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionGroup {
    /// グループ ID
    pub id: SessionGroupId,
    /// フェーズ
    pub phase: LearningPhase,
    /// 含まれるセッション ID
    pub session_ids: Vec<SessionId>,
    /// メタデータ
    pub metadata: SessionGroupMetadata,
}

impl SessionGroup {
    /// 新しいセッショングループを作成
    pub fn new(phase: LearningPhase, scenario: impl Into<String>, target_runs: usize) -> Self {
        let scenario = scenario.into();
        Self {
            id: SessionGroupId::new(),
            phase,
            session_ids: Vec::new(),
            metadata: SessionGroupMetadata::new(&scenario, target_runs),
        }
    }

    /// variant を設定
    pub fn with_variant(mut self, variant: impl Into<String>) -> Self {
        self.metadata = self.metadata.with_variant(variant);
        self
    }

    /// セッションを追加
    pub fn add_session(&mut self, session_id: SessionId, success: bool) {
        self.session_ids.push(session_id);
        if success {
            self.metadata.record_success();
        } else {
            self.metadata.record_failure();
        }
    }

    /// 完了をマーク
    pub fn mark_completed(&mut self) {
        self.metadata.mark_completed();
    }

    /// 成功率を取得
    pub fn success_rate(&self) -> f64 {
        self.metadata.success_rate()
    }

    /// 目標回数に達したか
    pub fn is_target_reached(&self) -> bool {
        self.metadata.completed_runs() >= self.metadata.target_runs
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_session_group_id_generation() {
        let id1 = SessionGroupId::new();
        let id2 = SessionGroupId::new();
        // 同じミリ秒内でも異なる可能性があるが、フォーマットは一貫
        assert!(id1.0.starts_with('g'));
        assert!(id2.0.starts_with('g'));
    }

    #[test]
    fn test_learning_phase_display() {
        assert_eq!(LearningPhase::Bootstrap.to_string(), "bootstrap");
        assert_eq!(LearningPhase::Release.to_string(), "release");
        assert_eq!(LearningPhase::Validate.to_string(), "validate");
    }

    #[test]
    fn test_learning_phase_parse() {
        assert_eq!(
            "bootstrap".parse::<LearningPhase>().unwrap(),
            LearningPhase::Bootstrap
        );
        assert_eq!(
            "RELEASE".parse::<LearningPhase>().unwrap(),
            LearningPhase::Release
        );
        assert!("unknown".parse::<LearningPhase>().is_err());
    }

    #[test]
    fn test_session_group_success_rate() {
        let mut group = SessionGroup::new(LearningPhase::Bootstrap, "test", 10);

        // 初期状態
        assert_eq!(group.success_rate(), 0.0);

        // 3 成功、2 失敗
        group.add_session(SessionId("1".to_string()), true);
        group.add_session(SessionId("2".to_string()), true);
        group.add_session(SessionId("3".to_string()), true);
        group.add_session(SessionId("4".to_string()), false);
        group.add_session(SessionId("5".to_string()), false);

        assert_eq!(group.success_rate(), 0.6);
        assert!(!group.is_target_reached());

        // 残り 5 回追加
        for i in 6..=10 {
            group.add_session(SessionId(i.to_string()), true);
        }
        assert!(group.is_target_reached());
    }
}