swarm-engine-core 0.1.6

Core types and orchestration for SwarmEngine
Documentation
//! Processor - 学習実行
//!
//! Trigger 発火時に呼び出され、以下の処理を実行:
//! - Offline 分析 → OfflineModel 生成
//! - LoRA 学習 → TrainedModel 生成

use std::sync::Arc;

use crate::learn::learn_model::LearnModel;
use crate::learn::lora::{LoraTrainer, LoraTrainerError, TrainedModel};
use crate::learn::offline::OfflineModel;
use crate::learn::snapshot::LearningStore;
use crate::learn::store::{EpisodeStore, StoreError};

// ============================================================================
// ProcessorMode
// ============================================================================

/// Selects which learning stages a `Processor` run executes.
///
/// Parsed case-insensitively from strings via [`std::str::FromStr`] and rendered
/// back with [`std::fmt::Display`] as `"offline"`, `"lora"` or `"full"`.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum ProcessorMode {
    /// Offline analysis only (produces an `OfflineModel`). This is the default.
    #[default]
    OfflineOnly,
    /// LoRA training only (produces a `TrainedModel`).
    LoraOnly,
    /// Both offline analysis and LoRA training.
    Full,
}

impl std::str::FromStr for ProcessorMode {
    type Err = String;

    /// Accepts `offline` / `offline_only`, `lora` / `lora_only` and `full` / `both`,
    /// compared after lower-casing the input. Anything else yields an error message
    /// that echoes the original (un-normalized) input.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let normalized = s.to_lowercase();
        let mode = match normalized.as_str() {
            "offline" | "offline_only" => Self::OfflineOnly,
            "lora" | "lora_only" => Self::LoraOnly,
            "full" | "both" => Self::Full,
            _ => return Err(format!("Unknown processor mode: {}", s)),
        };
        Ok(mode)
    }
}

impl std::fmt::Display for ProcessorMode {
    /// Writes the canonical short name; `from_str` parses it back to the same variant.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::OfflineOnly => "offline",
            Self::LoraOnly => "lora",
            Self::Full => "full",
        };
        f.write_str(name)
    }
}

// ============================================================================
// ProcessResult
// ============================================================================

/// Outcome of a `Processor` run; the variant mirrors the `ProcessorMode` that was executed.
#[derive(Debug)]
pub enum ProcessResult {
    /// Produced by `ProcessorMode::OfflineOnly`: the offline analysis model.
    Offline(OfflineModel),
    /// Produced by `ProcessorMode::LoraOnly`: the trained LoRA model.
    Lora(TrainedModel),
    /// Produced by `ProcessorMode::Full`: both models.
    Full {
        /// Offline analysis result.
        offline: OfflineModel,
        /// LoRA training result.
        lora: TrainedModel,
    },
}

impl ProcessResult {
    /// Borrows the LoRA model, if this result carries one.
    ///
    /// Returns `Some` for `Lora` and `Full`, `None` for `Offline`.
    pub fn lora_model(&self) -> Option<&TrainedModel> {
        if let Self::Lora(trained) | Self::Full { lora: trained, .. } = self {
            Some(trained)
        } else {
            None
        }
    }

    /// Borrows the offline model, if this result carries one.
    ///
    /// Returns `Some` for `Offline` and `Full`, `None` for `Lora`.
    pub fn offline_model(&self) -> Option<&OfflineModel> {
        if let Self::Offline(analysed) | Self::Full { offline: analysed, .. } = self {
            Some(analysed)
        } else {
            None
        }
    }
}

// ============================================================================
// ProcessorError
// ============================================================================

/// Errors returned by `Processor`.
#[derive(Debug)]
pub enum ProcessorError {
    /// An error reported by a store (wrapped `StoreError`).
    Store(StoreError),
    /// An error reported by the LoRA trainer.
    LoraTrainer(LoraTrainerError),
    /// An I/O error.
    Io(std::io::Error),
    /// Not enough data to run the requested stage; the string explains what is missing.
    InsufficientData(String),
    /// Any other failure (for example a missing component); the message is displayed verbatim.
    Other(String),
}

impl std::fmt::Display for ProcessorError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Store(e) => write!(f, "Store error: {}", e),
            Self::LoraTrainer(e) => write!(f, "LoRA trainer error: {}", e),
            Self::Io(e) => write!(f, "IO error: {}", e),
            Self::InsufficientData(msg) => write!(f, "Insufficient data: {}", msg),
            Self::Other(msg) => write!(f, "{}", msg),
        }
    }
}

impl std::error::Error for ProcessorError {}

impl From<StoreError> for ProcessorError {
    fn from(e: StoreError) -> Self {
        Self::Store(e)
    }
}

impl From<LoraTrainerError> for ProcessorError {
    fn from(e: LoraTrainerError) -> Self {
        Self::LoraTrainer(e)
    }
}

impl From<std::io::Error> for ProcessorError {
    fn from(e: std::io::Error) -> Self {
        Self::Io(e)
    }
}

// ============================================================================
// ProcessorConfig
// ============================================================================

/// Configuration for a `Processor`.
///
/// Build with [`ProcessorConfig::new`] and the chained setters, or use `Default`.
#[derive(Debug, Clone)]
pub struct ProcessorConfig {
    /// Which learning stages to run.
    pub mode: ProcessorMode,
    /// Scenario name passed to offline analysis (also included in log output).
    pub scenario: String,
    /// Maximum number of sessions passed to offline analysis.
    pub max_sessions: usize,
}

impl Default for ProcessorConfig {
    fn default() -> Self {
        Self {
            mode: ProcessorMode::OfflineOnly,
            scenario: "default".to_string(),
            max_sessions: 20,
        }
    }
}

impl ProcessorConfig {
    /// Creates a configuration for `scenario`; every other field comes from `Default`.
    #[must_use]
    pub fn new(scenario: impl Into<String>) -> Self {
        Self {
            scenario: scenario.into(),
            ..Default::default()
        }
    }

    /// Sets the processing mode.
    ///
    /// Consumes `self` and returns the updated config. Discarding the return
    /// value loses the setting, hence `#[must_use]`.
    #[must_use]
    pub fn mode(mut self, mode: ProcessorMode) -> Self {
        self.mode = mode;
        self
    }

    /// Sets the maximum number of sessions used by offline analysis.
    ///
    /// Consumes `self` and returns the updated config. Discarding the return
    /// value loses the setting, hence `#[must_use]`.
    #[must_use]
    pub fn max_sessions(mut self, n: usize) -> Self {
        self.max_sessions = n;
        self
    }
}

// ============================================================================
// Processor
// ============================================================================

/// Runs the learning stages selected by its `ProcessorConfig`.
///
/// Components are optional and attached with the `with_*` builder methods.
/// Which ones are required depends on the configured `ProcessorMode`.
pub struct Processor {
    /// Run configuration.
    config: ProcessorConfig,
    /// Store used for offline analysis.
    learning_store: Option<LearningStore>,
    /// Trainer used for LoRA training.
    lora_trainer: Option<LoraTrainer>,
    /// Model handed to the LoRA trainer.
    learn_model: Option<Arc<dyn LearnModel>>,
}

impl Processor {
    /// Creates a processor for `config` with no learning components attached.
    ///
    /// Attach the components the configured mode needs with the `with_*` methods.
    /// [`Processor::run`] reports a missing component as [`ProcessorError::Other`].
    #[must_use]
    pub fn new(config: ProcessorConfig) -> Self {
        Self {
            config,
            learning_store: None,
            lora_trainer: None,
            learn_model: None,
        }
    }

    /// Attaches the [`LearningStore`] used by offline analysis
    /// ([`ProcessorMode::OfflineOnly`] and [`ProcessorMode::Full`]).
    #[must_use]
    pub fn with_learning_store(mut self, store: LearningStore) -> Self {
        self.learning_store = Some(store);
        self
    }

    /// Attaches the [`LoraTrainer`] used by LoRA training
    /// ([`ProcessorMode::LoraOnly`] and [`ProcessorMode::Full`]).
    #[must_use]
    pub fn with_lora_trainer(mut self, trainer: LoraTrainer) -> Self {
        self.lora_trainer = Some(trainer);
        self
    }

    /// Attaches the [`LearnModel`] handed to the LoRA trainer.
    #[must_use]
    pub fn with_learn_model(mut self, model: Arc<dyn LearnModel>) -> Self {
        self.learn_model = Some(model);
        self
    }

    /// Returns the configuration this processor was built with.
    pub fn config(&self) -> &ProcessorConfig {
        &self.config
    }

    /// Runs the learning stages selected by `config.mode`.
    ///
    /// In [`ProcessorMode::Full`], offline analysis runs first and LoRA training
    /// second. Every component both stages need is checked before either stage starts.
    ///
    /// # Errors
    ///
    /// - [`ProcessorError::Other`] if a component required by the mode is not attached.
    /// - [`ProcessorError::InsufficientData`] if LoRA training is requested and
    ///   `episode_store` reports zero episodes.
    /// - Errors from the learning store, the episode store or the LoRA trainer,
    ///   converted into `ProcessorError` via `From`.
    pub async fn run(
        &self,
        episode_store: &dyn EpisodeStore,
    ) -> Result<ProcessResult, ProcessorError> {
        tracing::info!(
            mode = %self.config.mode,
            scenario = %self.config.scenario,
            "Starting learning process"
        );

        match self.config.mode {
            ProcessorMode::OfflineOnly => {
                let model = self.run_offline()?;
                Ok(ProcessResult::Offline(model))
            }
            ProcessorMode::LoraOnly => {
                let model = self.run_lora(episode_store).await?;
                Ok(ProcessResult::Lora(model))
            }
            ProcessorMode::Full => {
                // Validate both stages up front. Otherwise a missing LoRA component
                // is only discovered after offline analysis has already run, and
                // that offline result is thrown away together with the error.
                self.require_learning_store()?;
                self.require_lora_components()?;

                let offline = self.run_offline()?;
                let lora = self.run_lora(episode_store).await?;
                Ok(ProcessResult::Full { offline, lora })
            }
        }
    }

    /// Returns the attached [`LearningStore`], or an error if none was configured.
    fn require_learning_store(&self) -> Result<&LearningStore, ProcessorError> {
        self.learning_store.as_ref().ok_or_else(|| {
            ProcessorError::Other("LearningStore not configured for offline analysis".into())
        })
    }

    /// Returns the attached LoRA trainer and learn model.
    ///
    /// If either is missing, returns an error naming the first missing one
    /// (the trainer is checked first).
    fn require_lora_components(
        &self,
    ) -> Result<(&LoraTrainer, &Arc<dyn LearnModel>), ProcessorError> {
        let trainer = self
            .lora_trainer
            .as_ref()
            .ok_or_else(|| ProcessorError::Other("LoraTrainer not configured".into()))?;

        let learn_model = self.learn_model.as_ref().ok_or_else(|| {
            ProcessorError::Other("LearnModel not configured for LoRA training".into())
        })?;

        Ok((trainer, learn_model))
    }

    /// Runs offline analysis for `config.scenario` over at most `config.max_sessions` sessions.
    // NOTE(review): `run_offline_learning` is called synchronously from the async `run`.
    // If it performs blocking I/O it will stall the executor thread. Confirm, and
    // consider offloading it to a blocking task.
    fn run_offline(&self) -> Result<OfflineModel, ProcessorError> {
        let store = self.require_learning_store()?;

        tracing::info!(
            scenario = %self.config.scenario,
            max_sessions = self.config.max_sessions,
            "Running offline analysis"
        );

        let model = store.run_offline_learning(&self.config.scenario, self.config.max_sessions)?;

        tracing::info!(
            analyzed_sessions = model.analyzed_sessions,
            ucb1_c = model.parameters.ucb1_c,
            "Offline analysis completed"
        );

        Ok(model)
    }

    /// Runs LoRA training with the attached trainer and learn model.
    ///
    /// Fails with [`ProcessorError::InsufficientData`] if `episode_store` holds no
    /// episodes. The store is only used for this count; the trainer is called
    /// without it.
    async fn run_lora(
        &self,
        episode_store: &dyn EpisodeStore,
    ) -> Result<TrainedModel, ProcessorError> {
        let (trainer, learn_model) = self.require_lora_components()?;

        // Refuse to train on an empty episode store.
        let episode_count = episode_store.count(None)?;
        if episode_count == 0 {
            return Err(ProcessorError::InsufficientData(
                "No episodes available for LoRA training".into(),
            ));
        }

        tracing::info!(
            episode_count,
            learn_model = learn_model.name(),
            "Running LoRA training"
        );

        let model = trainer.train(learn_model.as_ref(), None).await?;

        tracing::info!(
            model_id = %model.id,
            sample_count = model.sample_count,
            "LoRA training completed"
        );

        Ok(model)
    }
}

// ============================================================================
// Tests
// ============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_processor_mode_from_str() {
        assert_eq!(
            "offline".parse::<ProcessorMode>().unwrap(),
            ProcessorMode::OfflineOnly
        );
        assert_eq!(
            "lora".parse::<ProcessorMode>().unwrap(),
            ProcessorMode::LoraOnly
        );
        assert_eq!(
            "full".parse::<ProcessorMode>().unwrap(),
            ProcessorMode::Full
        );
        assert!("invalid".parse::<ProcessorMode>().is_err());
    }

    #[test]
    fn test_processor_mode_aliases_and_case() {
        assert_eq!(
            "offline_only".parse::<ProcessorMode>().unwrap(),
            ProcessorMode::OfflineOnly
        );
        assert_eq!(
            "lora_only".parse::<ProcessorMode>().unwrap(),
            ProcessorMode::LoraOnly
        );
        assert_eq!(
            "both".parse::<ProcessorMode>().unwrap(),
            ProcessorMode::Full
        );
        assert_eq!(
            "OFFLINE".parse::<ProcessorMode>().unwrap(),
            ProcessorMode::OfflineOnly
        );
        assert_eq!(
            "Full".parse::<ProcessorMode>().unwrap(),
            ProcessorMode::Full
        );
    }

    #[test]
    fn test_processor_mode_error_message() {
        let err = "invalid".parse::<ProcessorMode>().unwrap_err();
        assert_eq!(err, "Unknown processor mode: invalid");
    }

    #[test]
    fn test_processor_mode_display_round_trips() {
        for mode in [
            ProcessorMode::OfflineOnly,
            ProcessorMode::LoraOnly,
            ProcessorMode::Full,
        ] {
            assert_eq!(mode.to_string().parse::<ProcessorMode>().unwrap(), mode);
        }
        assert_eq!(ProcessorMode::default(), ProcessorMode::OfflineOnly);
    }

    #[test]
    fn test_processor_config_default() {
        let config = ProcessorConfig::default();
        assert_eq!(config.mode, ProcessorMode::OfflineOnly);
        assert_eq!(config.scenario, "default");
        assert_eq!(config.max_sessions, 20);

        // `new` only overrides the scenario.
        let config = ProcessorConfig::new("s");
        assert_eq!(config.scenario, "s");
        assert_eq!(config.mode, ProcessorMode::OfflineOnly);
        assert_eq!(config.max_sessions, 20);
    }

    #[test]
    fn test_processor_config_builder() {
        let config = ProcessorConfig::new("test-scenario")
            .mode(ProcessorMode::Full)
            .max_sessions(50);

        assert_eq!(config.scenario, "test-scenario");
        assert_eq!(config.mode, ProcessorMode::Full);
        assert_eq!(config.max_sessions, 50);
    }

    #[test]
    fn test_processor_keeps_config() {
        let processor = Processor::new(ProcessorConfig::new("kept").max_sessions(3));
        assert_eq!(processor.config().scenario, "kept");
        assert_eq!(processor.config().max_sessions, 3);
    }

    #[test]
    fn test_process_result_accessors() {
        // Offline only
        let offline_model = OfflineModel::default();
        let result = ProcessResult::Offline(offline_model);
        assert!(result.offline_model().is_some());
        assert!(result.lora_model().is_none());
    }

    #[test]
    fn test_processor_error_display() {
        assert_eq!(
            ProcessorError::InsufficientData("x".into()).to_string(),
            "Insufficient data: x"
        );
        assert_eq!(ProcessorError::Other("boom".into()).to_string(), "boom");
    }

    #[test]
    fn test_processor_error_from_io() {
        let err: ProcessorError =
            std::io::Error::new(std::io::ErrorKind::Other, "disk").into();
        assert!(matches!(err, ProcessorError::Io(_)));
        assert_eq!(err.to_string(), "IO error: disk");
    }
}