// swarm-engine-core 0.1.6
//
// Core types and orchestration for SwarmEngine
// Documentation
//! Applier - 学習済みモデルの適用
//!
//! 学習完了後に TrainedModel を llama-server に適用する。
//! Auto-apply モードでは自動的に適用、それ以外では通知のみ。

use std::sync::Arc;

use crate::learn::lora::{ApplicatorError, ModelApplicator, TrainedModel};

// ============================================================================
// ApplyMode
// ============================================================================

/// Policy deciding whether a freshly trained model is applied automatically.
///
/// [`Applier::apply`] consults this value: in `Manual` mode it only logs and
/// reports the model as skipped, in `Auto` mode it applies the model right away.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ApplyMode {
    /// Manual apply (notification only). This is the default variant.
    #[default]
    Manual,
    /// Automatic apply.
    Auto,
}

// ============================================================================
// ApplierConfig
// ============================================================================

/// Configuration for an [`Applier`].
#[derive(Debug, Clone)]
pub struct ApplierConfig {
    /// Apply mode (manual or automatic).
    pub mode: ApplyMode,
    /// Maximum number of entries kept in the applier's history
    /// (described by the original author as the rollback history).
    pub max_history: usize,
}

impl Default for ApplierConfig {
    /// Returns the baseline configuration: manual mode, with at most five
    /// history entries retained.
    fn default() -> Self {
        let mode = ApplyMode::Manual;
        let max_history = 5;
        Self { mode, max_history }
    }
}

impl ApplierConfig {
    /// Enables auto-apply by switching the mode to [`ApplyMode::Auto`].
    ///
    /// Consumes the configuration and returns the updated value, so calls can
    /// be chained builder-style. The original value is not modified in place,
    /// which is why the result must be used.
    #[must_use = "this returns the updated config; the original value is consumed, not mutated"]
    pub fn auto_apply(mut self) -> Self {
        self.mode = ApplyMode::Auto;
        self
    }

    /// Sets the maximum number of history entries to `n`.
    ///
    /// Consumes the configuration and returns the updated value, so calls can
    /// be chained builder-style.
    #[must_use = "this returns the updated config; the original value is consumed, not mutated"]
    pub fn max_history(mut self, n: usize) -> Self {
        self.max_history = n;
        self
    }
}

// ============================================================================
// ApplierError
// ============================================================================

/// Error type returned by [`Applier`] operations.
#[derive(Debug)]
pub enum ApplierError {
    /// Error reported by the underlying applicator.
    Applicator(ApplicatorError),
    // NOTE(review): nothing in the visible part of this module constructs this
    // variant; `Applier::apply` reports a manual-mode skip as
    // `Ok(ApplyResult::Skipped { .. })` instead. Confirm whether it is still needed.
    /// The apply was skipped (manual mode).
    Skipped(String),
    /// Any other failure, described by a free-form message.
    Other(String),
}

impl std::fmt::Display for ApplierError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Applicator(e) => write!(f, "Applicator error: {}", e),
            Self::Skipped(msg) => write!(f, "Apply skipped: {}", msg),
            Self::Other(msg) => write!(f, "{}", msg),
        }
    }
}

// `source()` is not overridden, so the trait's default is used; the wrapped
// applicator error is surfaced through the `Display` text instead.
impl std::error::Error for ApplierError {}

impl From<ApplicatorError> for ApplierError {
    fn from(e: ApplicatorError) -> Self {
        Self::Applicator(e)
    }
}

// ============================================================================
// ApplyResult
// ============================================================================

/// Outcome of an apply attempt.
#[derive(Debug)]
pub enum ApplyResult {
    /// The model was applied successfully.
    Applied {
        /// ID of the model that was applied.
        model_id: String,
        /// Value reported by the applicator's `previous_model_id()` just
        /// before the apply call, if any.
        previous_model_id: Option<String>,
    },
    /// The apply was skipped (manual mode); carries the model ID and the reason.
    Skipped { model_id: String, reason: String },
}

impl ApplyResult {
    /// 適用が成功したかどうか
    pub fn is_applied(&self) -> bool {
        matches!(self, Self::Applied { .. })
    }
}

// ============================================================================
// Applier
// ============================================================================

/// Responsible for applying trained models through a [`ModelApplicator`].
pub struct Applier {
    /// Configuration (apply mode and history cap).
    config: ApplierConfig,
    /// Backend that performs apply/rollback and reports the current and
    /// previous model.
    applicator: Arc<dyn ModelApplicator>,
    /// Apply history: IDs of successfully applied models, oldest first.
    history: Vec<String>,
}

impl Applier {
    /// Builds an applier around `applicator` using `config`.
    ///
    /// The apply history starts out empty.
    pub fn new(config: ApplierConfig, applicator: Arc<dyn ModelApplicator>) -> Self {
        let history = Vec::new();
        Self {
            applicator,
            config,
            history,
        }
    }

    /// Returns the configuration held by this applier.
    pub fn config(&self) -> &ApplierConfig {
        &self.config
    }

    /// Returns the IDs of the models applied so far, oldest first.
    ///
    /// The list never grows beyond `config.max_history` entries.
    pub fn history(&self) -> &[String] {
        self.history.as_slice()
    }

    /// Applies `model` according to the configured [`ApplyMode`].
    ///
    /// * `Auto`: delegates to [`Applier::apply_now`].
    /// * `Manual`: only logs that the model is ready and returns
    ///   [`ApplyResult::Skipped`]; neither the applicator nor the history is touched.
    ///
    /// # Errors
    ///
    /// In `Auto` mode, returns whatever [`Applier::apply_now`] returns.
    /// `Manual` mode never fails.
    pub async fn apply(&mut self, model: &TrainedModel) -> Result<ApplyResult, ApplierError> {
        match self.config.mode {
            ApplyMode::Auto => self.apply_now(model).await,
            ApplyMode::Manual => {
                tracing::info!(
                    model_id = %model.id,
                    "Model ready for manual apply (auto-apply disabled)"
                );
                let skipped = ApplyResult::Skipped {
                    model_id: model.id.to_string(),
                    reason: String::from("Auto-apply disabled"),
                };
                Ok(skipped)
            }
        }
    }

    /// Applies `model` immediately, regardless of the configured mode.
    ///
    /// On success the model ID is appended to the history (evicting the oldest
    /// entry when the cap is exceeded) and [`ApplyResult::Applied`] is returned.
    ///
    /// # Errors
    ///
    /// Returns [`ApplierError::Applicator`] when the applicator fails; in that
    /// case the history is left unchanged.
    pub async fn apply_now(&mut self, model: &TrainedModel) -> Result<ApplyResult, ApplierError> {
        // NOTE(review): this is read *before* the new model is applied, so it is
        // whatever the applicator considered "previous" at that point, which is
        // presumably not the model being replaced by this call. Confirm against
        // the `ModelApplicator` contract whether `current()` was intended here.
        let prior = self
            .applicator
            .previous_model_id()
            .map(|id| id.to_string());

        tracing::info!(
            model_id = %model.id,
            previous = ?prior,
            "Applying trained model"
        );

        // A failed apply returns early here, before the history is modified.
        self.applicator.apply(model).await?;

        // Record the apply, then drop the oldest entry if the cap is exceeded.
        let applied_id = model.id.to_string();
        self.history.push(applied_id.clone());
        if self.config.max_history < self.history.len() {
            self.history.remove(0);
        }

        tracing::info!(
            model_id = %model.id,
            "Model applied successfully"
        );

        let applied = ApplyResult::Applied {
            model_id: applied_id,
            previous_model_id: prior,
        };
        Ok(applied)
    }

    /// Rolls back to the model the applicator reports as the previous one.
    ///
    /// NOTE(review): the local apply history is not modified by a rollback
    /// (this method takes `&self`); confirm that this is intended.
    ///
    /// # Errors
    ///
    /// Returns [`ApplierError::Other`] when the applicator reports no previous
    /// model, and [`ApplierError::Applicator`] when the rollback itself fails.
    pub async fn rollback(&self) -> Result<(), ApplierError> {
        let target = match self.applicator.previous_model_id() {
            Some(id) => id,
            None => {
                return Err(ApplierError::Other(
                    "No previous model to rollback to".into(),
                ))
            }
        };

        tracing::info!(target_id = %target, "Rolling back to previous model");
        self.applicator.rollback(&target).await?;

        Ok(())
    }

    /// Returns the model the applicator reports as currently applied, if any.
    pub fn current_model(&self) -> Option<TrainedModel> {
        self.applicator.current()
    }
}

// ============================================================================
// Tests
// ============================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use crate::learn::lora::{LoraModelId, NoOpApplicator};
    use std::path::PathBuf;

    /// Builds a minimal `TrainedModel` whose ID is parsed from `id`.
    fn create_test_model(id: &str) -> TrainedModel {
        TrainedModel {
            id: LoraModelId::parse(id),
            base_model: "test-base".to_string(),
            adapter_path: PathBuf::from("/tmp/test"),
            learn_model_name: "test".to_string(),
            episode_ids: vec![],
            sample_count: 10,
            created_at: 0,
            metrics: None,
        }
    }

    /// The builder methods only change the field they are named after.
    #[test]
    fn test_config_builder() {
        let config = ApplierConfig::default();
        assert_eq!(config.mode, ApplyMode::Manual);
        assert_eq!(config.max_history, 5);

        let config = config.auto_apply();
        assert_eq!(config.mode, ApplyMode::Auto);
        assert_eq!(config.max_history, 5);

        let config = config.max_history(3);
        assert_eq!(config.mode, ApplyMode::Auto);
        assert_eq!(config.max_history, 3);
    }

    /// Manual mode reports the model as skipped and records nothing.
    #[tokio::test]
    async fn test_applier_manual_mode() {
        let config = ApplierConfig::default(); // Manual mode
        let applicator = Arc::new(NoOpApplicator::new());
        let mut applier = Applier::new(config, applicator);

        let model = create_test_model("test-model-1");
        let result = applier.apply(&model).await.unwrap();

        assert!(!result.is_applied());
        // Exhaustive match: a newly added variant must be handled here explicitly.
        match result {
            ApplyResult::Skipped { model_id, .. } => {
                assert_eq!(model_id, "test-model-1");
            }
            ApplyResult::Applied { .. } => panic!("Expected Skipped"),
        }
        assert!(applier.history().is_empty());
    }

    /// Auto mode applies the model and records its ID in the history.
    #[tokio::test]
    async fn test_applier_auto_mode() {
        let config = ApplierConfig::default().auto_apply();
        let applicator = Arc::new(NoOpApplicator::new());
        let mut applier = Applier::new(config, applicator);

        let model = create_test_model("test-model-1");
        let result = applier.apply(&model).await.unwrap();

        assert!(result.is_applied());
        match result {
            ApplyResult::Applied { model_id, .. } => {
                assert_eq!(model_id, "test-model-1");
            }
            ApplyResult::Skipped { .. } => panic!("Expected Applied"),
        }
        assert_eq!(applier.history().len(), 1);
        assert_eq!(applier.history()[0], "test-model-1");
    }

    /// `apply_now` applies the model even when the mode is manual.
    #[tokio::test]
    async fn test_apply_now_ignores_manual_mode() {
        let config = ApplierConfig::default(); // Manual mode
        let applicator = Arc::new(NoOpApplicator::new());
        let mut applier = Applier::new(config, applicator);

        let model = create_test_model("test-model-1");
        let result = applier.apply_now(&model).await.unwrap();

        assert!(result.is_applied());
        assert_eq!(applier.history().len(), 1);
        assert_eq!(applier.history()[0], "test-model-1");
    }

    /// The history keeps only the most recent `max_history` entries.
    #[tokio::test]
    async fn test_applier_history_limit() {
        let config = ApplierConfig::default().auto_apply().max_history(2);
        let applicator = Arc::new(NoOpApplicator::new());
        let mut applier = Applier::new(config, applicator);

        // Apply 3 models
        for i in 0..3 {
            let model = create_test_model(&format!("model-{}", i));
            applier.apply(&model).await.unwrap();
        }

        // History should only have last 2
        assert_eq!(applier.history().len(), 2);
        assert_eq!(applier.history()[0], "model-1");
        assert_eq!(applier.history()[1], "model-2");
    }

    /// With a cap of zero the apply still succeeds but nothing is retained.
    #[tokio::test]
    async fn test_applier_zero_history_keeps_nothing() {
        let config = ApplierConfig::default().auto_apply().max_history(0);
        let applicator = Arc::new(NoOpApplicator::new());
        let mut applier = Applier::new(config, applicator);

        let model = create_test_model("test-model-1");
        let result = applier.apply(&model).await.unwrap();

        assert!(result.is_applied());
        assert!(applier.history().is_empty());
    }
}