swarm-engine-core 0.1.6

Core types and orchestration for SwarmEngine
Documentation
//! `LearnModel` that builds episodes from completed worker tasks.

use std::collections::HashMap;

use super::super::episode::{Episode, EpisodeContext, Outcome};
use super::super::record::{ActionRecord, Record, RecordStream};
use super::super::training::TrainingData;
use super::{LearnError, LearnModel};
use crate::types::TaskId;

/// `LearnModel` driven by worker task completion.
///
/// Action records are grouped by `TaskId`; every run of actions that ends in a
/// terminal (`done`) action becomes one `Episode`.
#[derive(Debug, Clone)]
pub struct WorkerTaskLearn {
    /// System prompt placed in every generated SFT training sample.
    system_prompt: String,
    /// Minimum number of actions (terminal action included) an episode must contain.
    min_actions: usize,
}

impl WorkerTaskLearn {
    /// System prompt used unless overridden via [`WorkerTaskLearn::with_system_prompt`].
    const DEFAULT_SYSTEM_PROMPT: &'static str =
        "You are an intelligent agent that diagnoses and resolves system issues.";
    /// Minimum sequence length used unless overridden via [`WorkerTaskLearn::with_min_actions`].
    const DEFAULT_MIN_ACTIONS: usize = 2;

    /// Creates a model with the default system prompt and `min_actions = 2`.
    #[must_use]
    pub fn new() -> Self {
        Self {
            system_prompt: Self::DEFAULT_SYSTEM_PROMPT.to_string(),
            min_actions: Self::DEFAULT_MIN_ACTIONS,
        }
    }

    /// Replaces the system prompt used for generated training data (builder style).
    #[must_use]
    pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
        self.system_prompt = prompt.into();
        self
    }

    /// Sets the minimum number of actions (terminal action included) an episode
    /// needs to be built and converted (builder style).
    #[must_use]
    pub fn with_min_actions(mut self, min: usize) -> Self {
        self.min_actions = min;
        self
    }
}

impl Default for WorkerTaskLearn {
    /// Same as [`WorkerTaskLearn::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl LearnModel for WorkerTaskLearn {
    /// Identifier of this learn model.
    fn name(&self) -> &str {
        "worker_task"
    }

    fn objective(&self) -> &str {
        "Learn complete worker task sequences from start to done"
    }

    /// Judges a context by its final `ActionRecord`.
    ///
    /// - empty context → failure (nothing to evaluate)
    /// - last action is terminal and succeeded → binary success
    /// - last action is terminal and failed → failure
    /// - otherwise (no action records, or the last action is not terminal) → `Outcome::Unknown`
    fn evaluate(&self, context: &EpisodeContext) -> Outcome {
        // An empty context cannot be evaluated; it is reported as a failure.
        if context.is_empty() {
            return Outcome::failure("Empty context: no actions to evaluate");
        }

        // Success only when the last action is terminal (`done`) and succeeded.
        match context.iter::<ActionRecord>().last() {
            Some(action) if action.is_terminal() => {
                if action.success {
                    Outcome::success_binary()
                } else {
                    Outcome::failure("Task failed")
                }
            }
            _ => Outcome::Unknown,
        }
    }

    /// Groups action records by `TaskId` and turns every run of actions ending
    /// in a terminal action into an `Episode`.
    ///
    /// - One task may yield several episodes, one per terminal action.
    /// - Runs shorter than `min_actions` (terminal action included) are dropped.
    /// - Trailing actions not followed by a terminal action are dropped.
    /// - Tasks are emitted in the order they first appear in `records`, so the
    ///   output order is deterministic.
    fn build_episodes(&self, records: &[Record]) -> Vec<Episode> {
        let stream = RecordStream::new(records);

        // Group action records per task while remembering first-appearance order.
        // Iterating a HashMap directly would make the episode order vary between runs.
        let mut slot_of: HashMap<TaskId, usize> = HashMap::new();
        let mut task_groups: Vec<(TaskId, Vec<&ActionRecord>)> = Vec::new();
        for record in stream.actions() {
            let slot = *slot_of.entry(record.task_id).or_insert_with(|| {
                task_groups.push((record.task_id, Vec::new()));
                task_groups.len() - 1
            });
            task_groups[slot].1.push(record);
        }

        let mut episodes = Vec::new();

        for (task_id, task_records) in task_groups {
            // NOTE: group_id comes from the task's first record and is applied to
            // every episode built for that task.
            let group_id = task_records.first().and_then(|r| r.group_id);

            // Each chunk ends at (and includes) a terminal action. Only the last
            // chunk may lack one; that unfinished tail is skipped.
            for sequence in task_records.split_inclusive(|r| r.is_terminal()) {
                let finished = matches!(sequence.last(), Some(r) if r.is_terminal());
                if !finished || sequence.len() < self.min_actions {
                    continue;
                }

                let mut context = EpisodeContext::new();
                for r in sequence {
                    context.push((*r).clone());
                }

                // The outcome is decided by the same rules as `evaluate`.
                let outcome = self.evaluate(&context);

                // Keep in sync with `name()`.
                let mut builder = Episode::builder()
                    .learn_model("worker_task")
                    .task_id(task_id)
                    .context(context)
                    .outcome(outcome);

                if let Some(gid) = group_id {
                    builder = builder.group_id(gid);
                }

                episodes.push(builder.build());
            }
        }

        episodes
    }

    /// Converts a successful episode into an SFT training sample.
    ///
    /// # Errors
    ///
    /// Returns `LearnError::InvalidEpisode` when the episode:
    /// - is not successful,
    /// - has fewer than `min_actions` actions, or
    /// - has no non-terminal actions, which would produce an empty prompt/response pair.
    fn convert(&self, episode: &Episode) -> Result<TrainingData, LearnError> {
        if !episode.outcome.is_success() {
            return Err(LearnError::InvalidEpisode(
                "Episode is not successful".into(),
            ));
        }

        let action_count = episode.context.iter::<ActionRecord>().count();
        if action_count < self.min_actions {
            return Err(LearnError::InvalidEpisode(format!(
                "Too few actions: {} < {}",
                action_count, self.min_actions
            )));
        }

        // The terminal action only marks completion, so it is excluded from the sample.
        let actions: Vec<&str> = episode
            .context
            .iter::<ActionRecord>()
            .filter(|a| !a.is_terminal())
            .map(|a| a.action.as_str())
            .collect();

        if actions.is_empty() {
            return Err(LearnError::InvalidEpisode(
                "No non-terminal actions to learn from".into(),
            ));
        }

        // NOTE(review): the prompt lists the same actions, in the same order, that the
        // response reproduces, so each sample contains its own answer. Confirm this is intended.
        let prompt = format!(
            "Diagnose and resolve the issue.\nAvailable actions: {}",
            actions.join(", ")
        );

        let response = format!("Execute the following sequence: {}", actions.join(" -> "));

        Ok(TrainingData::sft(&self.system_prompt, &prompt, &response)
            .with_episode_id(episode.id.to_string())
            .with_outcome_score(episode.outcome.score()))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::events::{ActionContext, ActionEvent, ActionEventBuilder, ActionEventResult};
    use crate::types::WorkerId;
    use std::time::Duration;

    /// Builds an `ActionEvent` for `task_id` that succeeds or fails according to `success`.
    fn make_action_with_task(
        tick: u64,
        worker_id: usize,
        action: &str,
        success: bool,
        task_id: TaskId,
    ) -> ActionEvent {
        let result = if success {
            ActionEventResult::success()
        } else {
            ActionEventResult::failure("error")
        };

        ActionEventBuilder::new(tick, WorkerId(worker_id), action)
            .task_id(task_id)
            .result(result)
            .duration(Duration::from_millis(10))
            .context(ActionContext::new())
            .build()
    }

    /// Converts action events into `Record`s, preserving their order.
    fn make_records(actions: &[ActionEvent]) -> Vec<Record> {
        actions.iter().map(Record::from).collect()
    }

    /// Returns the sorted action counts of `episodes` (order-independent comparison).
    fn sorted_lengths(episodes: &[Episode]) -> Vec<usize> {
        let mut lengths: Vec<usize> = episodes
            .iter()
            .map(|ep| ep.context.iter::<ActionRecord>().count())
            .collect();
        lengths.sort_unstable();
        lengths
    }

    #[test]
    fn test_worker_task_learn_build_episodes() {
        let learn = WorkerTaskLearn::new().with_min_actions(2);

        // Task 1: success (3 actions including done)
        let task1_id = TaskId::new();
        // Task 2: failure (2 actions including done)
        let task2_id = TaskId::new();

        let actions = vec![
            make_action_with_task(1, 0, "CheckStatus", true, task1_id),
            make_action_with_task(2, 0, "ReadLogs", true, task1_id),
            make_action_with_task(3, 0, "done", true, task1_id),
            make_action_with_task(4, 1, "Grep", true, task2_id),
            make_action_with_task(5, 1, "done", false, task2_id),
        ];
        let records = make_records(&actions);

        let episodes = learn.build_episodes(&records);

        // Task 1: success, Task 2: failure
        assert_eq!(episodes.len(), 2);

        let task1_ep = episodes
            .iter()
            .find(|ep| ep.task_id == Some(task1_id))
            .expect("an episode should be built for task 1");
        assert!(task1_ep.outcome.is_success());

        let task2_ep = episodes
            .iter()
            .find(|ep| ep.task_id == Some(task2_id))
            .expect("an episode should be built for task 2");
        assert!(task2_ep.outcome.is_failure());
    }

    #[test]
    fn test_worker_task_learn_splits_on_each_done_and_drops_tail() {
        let learn = WorkerTaskLearn::new().with_min_actions(2);
        let task_id = TaskId::new();

        let actions = vec![
            make_action_with_task(1, 0, "CheckStatus", true, task_id),
            make_action_with_task(2, 0, "done", true, task_id),
            make_action_with_task(3, 0, "ReadLogs", true, task_id),
            make_action_with_task(4, 0, "Restart", true, task_id),
            make_action_with_task(5, 0, "done", true, task_id),
            // Unfinished tail: no terminal action follows, so no episode is built from it.
            make_action_with_task(6, 0, "Grep", true, task_id),
        ];

        let episodes = learn.build_episodes(&make_records(&actions));

        assert_eq!(sorted_lengths(&episodes), vec![2, 3]);
        assert!(episodes.iter().all(|ep| ep.task_id == Some(task_id)));
    }

    #[test]
    fn test_worker_task_learn_min_actions_filters_short_sequences() {
        let learn = WorkerTaskLearn::new().with_min_actions(3);
        let task_id = TaskId::new();

        let actions = vec![
            make_action_with_task(1, 0, "CheckStatus", true, task_id),
            make_action_with_task(2, 0, "done", true, task_id),
            make_action_with_task(3, 0, "ReadLogs", true, task_id),
            make_action_with_task(4, 0, "Restart", true, task_id),
            make_action_with_task(5, 0, "done", true, task_id),
        ];

        let episodes = learn.build_episodes(&make_records(&actions));

        // Only the 3-action sequence reaches min_actions.
        assert_eq!(sorted_lengths(&episodes), vec![3]);
    }

    #[test]
    fn test_worker_task_learn_evaluate_edge_cases() {
        let learn = WorkerTaskLearn::new();

        // An empty context is reported as a failure.
        assert!(learn.evaluate(&EpisodeContext::new()).is_failure());

        // A context that does not end in a terminal action is Unknown.
        let mut unfinished = EpisodeContext::new();
        unfinished.push(ActionRecord::new(1, 0, "Check").success(true));
        assert!(matches!(learn.evaluate(&unfinished), Outcome::Unknown));

        // A failed terminal action is a failure.
        let mut failed = EpisodeContext::new();
        failed.push(ActionRecord::new(1, 0, "Check").success(true));
        failed.push(ActionRecord::new(2, 0, "done").success(false));
        assert!(learn.evaluate(&failed).is_failure());

        // A successful terminal action is a success.
        let mut succeeded = EpisodeContext::new();
        succeeded.push(ActionRecord::new(1, 0, "Check").success(true));
        succeeded.push(ActionRecord::new(2, 0, "done").success(true));
        assert!(learn.evaluate(&succeeded).is_success());
    }

    #[test]
    fn test_worker_task_learn_convert_success_only() {
        let learn = WorkerTaskLearn::new();

        // A failed episode cannot be converted.
        let failed_ep = Episode::builder()
            .learn_model("worker_task")
            .outcome(Outcome::failure("test"))
            .build();

        assert!(learn.convert(&failed_ep).is_err());

        // A successful episode can be converted.
        let success_ep = Episode::builder()
            .learn_model("worker_task")
            .record(ActionRecord::new(1, 0, "Check").success(true))
            .record(ActionRecord::new(2, 0, "Fix").success(true))
            .record(ActionRecord::new(3, 0, "done").success(true))
            .outcome(Outcome::success_binary())
            .build();

        assert!(learn.convert(&success_ep).is_ok());
    }
}