swarm-engine-core 0.1.6

Core types and orchestration for SwarmEngine
Documentation
//! Worker Decision learning model (sequence-based).

use std::collections::HashMap;

use super::super::episode::{Episode, EpisodeContext, Outcome};
use super::super::record::{ActionRecord, Record};
use super::super::training::TrainingData;
use super::{system_events, LearnError, LearnModel};
use crate::types::TaskId;

/// Worker Decision learning model (sequence-based).
///
/// **Target**: improve the Worker's decision (action selection).
/// **Method**: learn from successful action sequences.
///
/// ## Purpose
///
/// Generates data for fine-tuning the Worker's Decider (the LLM that decides
/// which action to take). From successful action sequences it learns in which
/// order actions should be executed.
///
/// ## Prompt format
///
/// - **Input**: context + list of available actions
/// - **Output**: the optimal action sequence
///
/// ## Uses
///
/// - Generating LoRA fine-tuning data (for the Decider)
/// - Extracting and reusing success patterns
#[derive(Debug, Clone)]
pub struct WorkerDecisionSequenceLearn {
    /// System prompt attached to every generated training sample.
    system_prompt: String,
    /// Minimum number of actions; sequences with fewer actions are ignored.
    min_actions: usize,
    /// Available actions (listed in the generated prompt).
    available_actions: Vec<String>,
    /// System events (filtered out of learned sequences).
    system_events: Vec<String>,
}

impl WorkerDecisionSequenceLearn {
    /// Creates a model with the default configuration.
    ///
    /// Defaults:
    /// - a generic "diagnose and resolve system issues" system prompt,
    /// - `min_actions = 3`,
    /// - available actions `CheckStatus`, `ReadLogs`, `AnalyzeMetrics`,
    ///   `Diagnose`, `Restart`,
    /// - system events taken from `system_events::DEFAULT_SYSTEM_EVENTS`.
    #[must_use]
    pub fn new() -> Self {
        Self {
            system_prompt: "You are an intelligent agent that diagnoses and resolves system issues. \
                           Given a context and available actions, determine the optimal action sequence.".to_string(),
            min_actions: 3,
            available_actions: vec![
                "CheckStatus".to_string(),
                "ReadLogs".to_string(),
                "AnalyzeMetrics".to_string(),
                "Diagnose".to_string(),
                "Restart".to_string(),
            ],
            system_events: system_events::DEFAULT_SYSTEM_EVENTS
                .iter()
                .map(|s| s.to_string())
                .collect(),
        }
    }

    /// Sets the system prompt.
    ///
    /// Consumes `self` and returns the updated model, so the result must be used.
    #[must_use]
    pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
        self.system_prompt = prompt.into();
        self
    }

    /// Sets the minimum number of (learnable) actions a sequence must contain.
    #[must_use]
    pub fn with_min_actions(mut self, min: usize) -> Self {
        self.min_actions = min;
        self
    }

    /// Replaces the list of available actions shown in the generated prompt.
    #[must_use]
    pub fn with_available_actions(mut self, actions: Vec<String>) -> Self {
        self.available_actions = actions;
        self
    }

    /// Adds one more system event to filter out of learned sequences.
    #[must_use]
    pub fn with_system_event(mut self, event: impl Into<String>) -> Self {
        self.system_events.push(event.into());
        self
    }

    /// Returns `true` if `action` exactly matches one of the configured system events.
    fn is_system_event(&self, action: &str) -> bool {
        self.system_events.iter().any(|e| e == action)
    }
}

impl Default for WorkerDecisionSequenceLearn {
    fn default() -> Self {
        Self::new()
    }
}

impl LearnModel for WorkerDecisionSequenceLearn {
    fn name(&self) -> &str {
        "worker_decision_sequence"
    }

    fn objective(&self) -> &str {
        "Learn successful action sequences for problem resolution"
    }

    fn evaluate(&self, context: &EpisodeContext) -> Outcome {
        // 空のコンテキストは評価不能
        if context.is_empty() {
            return Outcome::failure("Empty context: no actions to evaluate");
        }

        // 成功したアクション(システムイベント除く)をカウント
        let successful_actions: Vec<_> = context
            .iter::<ActionRecord>()
            .filter(|a| a.success && !self.is_system_event(&a.action))
            .collect();

        if successful_actions.len() >= self.min_actions {
            Outcome::success(1.0)
        } else {
            Outcome::failure(format!(
                "Insufficient successful actions: {} < {}",
                successful_actions.len(),
                self.min_actions
            ))
        }
    }

    fn build_episodes(&self, records: &[Record]) -> Vec<Episode> {
        // TaskId ごとにグルーピング(成功アクション、システムイベント除く)
        let mut task_actions: HashMap<TaskId, Vec<&ActionRecord>> = HashMap::new();
        for record in records.iter().filter_map(Record::as_action) {
            if record.success && !self.is_system_event(&record.action) {
                task_actions.entry(record.task_id).or_default().push(record);
            }
        }

        let mut episodes = Vec::new();

        for (task_id, successful_actions) in task_actions {
            if successful_actions.len() < self.min_actions {
                continue;
            }

            let group_id = successful_actions.first().and_then(|r| r.group_id);

            // TaskId ごとに 1 Episode を構築
            let mut context = EpisodeContext::new();
            for action in &successful_actions {
                context.push((*action).clone());
            }

            let outcome = self.evaluate(&context);

            let mut builder = Episode::builder()
                .learn_model("worker_decision_sequence")
                .task_id(task_id)
                .context(context)
                .outcome(outcome);

            if let Some(gid) = group_id {
                builder = builder.group_id(gid);
            }

            episodes.push(builder.build());
        }

        episodes
    }

    fn convert(&self, episode: &Episode) -> Result<TrainingData, LearnError> {
        if !episode.outcome.is_success() {
            return Err(LearnError::InvalidEpisode(
                "Episode is not successful".into(),
            ));
        }

        let actions: Vec<&str> = episode
            .context
            .iter::<ActionRecord>()
            .map(|a| a.action.as_str())
            .collect();

        if actions.len() < self.min_actions {
            return Err(LearnError::InvalidEpisode(format!(
                "Too few actions: {} < {}",
                actions.len(),
                self.min_actions
            )));
        }

        // プロンプト生成
        let available = self.available_actions.join(", ");
        let prompt = format!(
            "Current context: default\n\
             Available actions: {}\n\n\
             What is the best sequence of actions to resolve this issue?",
            available
        );

        // レスポンス生成
        let action_sequence = actions.join(" -> ");
        let response = format!(
            "Based on the context, the optimal action sequence is: {}",
            action_sequence
        );

        Ok(TrainingData::sft(&self.system_prompt, &prompt, &response)
            .with_episode_id(episode.id.to_string())
            .with_outcome_score(episode.outcome.score()))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::events::{ActionContext, ActionEvent, ActionEventBuilder, ActionEventResult};
    use crate::types::WorkerId;
    use std::time::Duration;

    /// Builds an `ActionEvent` for `task_id` whose result is success or a
    /// generic `"error"` failure depending on `success`.
    fn make_action_with_task(
        tick: u64,
        worker_id: usize,
        action: &str,
        success: bool,
        task_id: TaskId,
    ) -> ActionEvent {
        let result = if success {
            ActionEventResult::success()
        } else {
            ActionEventResult::failure("error")
        };

        ActionEventBuilder::new(tick, WorkerId(worker_id), action)
            .task_id(task_id)
            .result(result)
            .duration(Duration::from_millis(10))
            .context(ActionContext::new())
            .build()
    }

    /// Converts action events into `Record`s.
    fn make_records(actions: &[ActionEvent]) -> Vec<Record> {
        actions.iter().map(Record::from).collect()
    }

    #[test]
    fn test_worker_decision_sequence_learn_build_episodes() {
        let learn = WorkerDecisionSequenceLearn::new().with_min_actions(3);

        // All actions belong to the same task
        let task_id = TaskId::new();

        let actions = vec![
            make_action_with_task(1, 0, "CheckStatus", true, task_id),
            make_action_with_task(2, 0, "ReadLogs", true, task_id),
            make_action_with_task(3, 0, "Restart", true, task_id),
            make_action_with_task(4, 0, "tick_end", true, task_id), // system event, should be filtered
            make_action_with_task(5, 0, "done", true, task_id), // system event, should be filtered
        ];
        let records = make_records(&actions);

        let episodes = learn.build_episodes(&records);

        // One episode (3 successful actions: CheckStatus, ReadLogs, Restart)
        assert_eq!(episodes.len(), 1);
        assert!(episodes[0].outcome.is_success());
        assert_eq!(episodes[0].task_id, Some(task_id));

        // System events are excluded
        let action_count = episodes[0].context.iter::<ActionRecord>().count();
        assert_eq!(action_count, 3);
    }

    #[test]
    fn test_worker_decision_sequence_learn_build_episodes_per_task() {
        let learn = WorkerDecisionSequenceLearn::new().with_min_actions(3);

        let task_a = TaskId::new();
        let task_b = TaskId::new();

        // Interleaved actions of two tasks must be split into two episodes.
        let actions = vec![
            make_action_with_task(1, 0, "CheckStatus", true, task_a),
            make_action_with_task(2, 1, "CheckStatus", true, task_b),
            make_action_with_task(3, 0, "ReadLogs", true, task_a),
            make_action_with_task(4, 1, "ReadLogs", true, task_b),
            make_action_with_task(5, 0, "Restart", true, task_a),
            make_action_with_task(6, 1, "Restart", true, task_b),
        ];
        let records = make_records(&actions);

        let episodes = learn.build_episodes(&records);

        assert_eq!(episodes.len(), 2);
        for task_id in [task_a, task_b] {
            let episode = episodes
                .iter()
                .find(|e| e.task_id == Some(task_id))
                .expect("one episode per task");
            assert_eq!(episode.context.iter::<ActionRecord>().count(), 3);
        }
    }

    #[test]
    fn test_worker_decision_sequence_learn_filters_failed_actions() {
        let learn = WorkerDecisionSequenceLearn::new().with_min_actions(3);

        let task_id = TaskId::new();

        let actions = vec![
            make_action_with_task(1, 0, "CheckStatus", true, task_id),
            make_action_with_task(2, 0, "ReadLogs", false, task_id), // failed, should be filtered
            make_action_with_task(3, 0, "Restart", true, task_id),
        ];
        let records = make_records(&actions);

        let episodes = learn.build_episodes(&records);

        // Only 2 successful actions, so no episode is produced
        assert_eq!(episodes.len(), 0);
    }

    #[test]
    fn test_worker_decision_sequence_learn_evaluate_empty_context() {
        let learn = WorkerDecisionSequenceLearn::new();

        let outcome = learn.evaluate(&EpisodeContext::new());

        assert!(!outcome.is_success());
    }

    #[test]
    fn test_worker_decision_sequence_learn_convert() {
        let learn = WorkerDecisionSequenceLearn::new().with_available_actions(vec![
            "A".to_string(),
            "B".to_string(),
            "C".to_string(),
        ]);

        let episode = Episode::builder()
            .learn_model("worker_decision_sequence")
            .record(ActionRecord::new(1, 0, "A").success(true))
            .record(ActionRecord::new(2, 0, "B").success(true))
            .record(ActionRecord::new(3, 0, "C").success(true))
            .outcome(Outcome::success(1.0))
            .build();

        let result = learn.convert(&episode);
        assert!(result.is_ok());

        let training_data = result.unwrap();
        assert!(training_data.prompt.contains("Available actions: A, B, C"));
        assert!(training_data.chosen.contains("A -> B -> C"));
    }

    #[test]
    fn test_worker_decision_sequence_learn_convert_rejects_failed_episode() {
        let learn = WorkerDecisionSequenceLearn::new();

        let episode = Episode::builder()
            .learn_model("worker_decision_sequence")
            .record(ActionRecord::new(1, 0, "A").success(true))
            .record(ActionRecord::new(2, 0, "B").success(true))
            .record(ActionRecord::new(3, 0, "C").success(true))
            .outcome(Outcome::failure("not solved"))
            .build();

        assert!(learn.convert(&episode).is_err());
    }

    #[test]
    fn test_worker_decision_sequence_learn_convert_rejects_short_sequence() {
        let learn = WorkerDecisionSequenceLearn::new().with_min_actions(3);

        let episode = Episode::builder()
            .learn_model("worker_decision_sequence")
            .record(ActionRecord::new(1, 0, "A").success(true))
            .record(ActionRecord::new(2, 0, "B").success(true))
            .outcome(Outcome::success(1.0))
            .build();

        assert!(learn.convert(&episode).is_err());
    }
}