swarm-engine-core 0.1.6

Core types and orchestration for SwarmEngine
Documentation
//! DependencyGraphLearnModel - DependencyGraph 推論の学習モデル
//!
//! ## 概要
//!
//! DependencyGraph 推論(アクション依存関係の推論)を学習するための LearnModel 実装。
//!
//! ## フロー
//!
//! ```text
//! Eval 実行
//!     ↓
//! LLM 呼び出し → DependencyGraphRecord 記録(prompt, response 含む)
//!     ↓
//! Eval 終了 → Outcome 確定
//!     ↓
//! Episode 保存(context に DependencyGraphRecord を含む)
//!     ↓
//! learn → DpoLearnModel<F> で group_id ベースの DPO ペア生成
//!     ↓
//! LoRA 生成 → 次回 Eval で適用
//! ```
//!
//! ## 責務
//!
//! - Episode の context から DependencyGraphRecord を取得
//! - Record を評価(成功/失敗判定)
//! - TrainingData への変換(SFT 用)
//!
//! DPO ペア生成は汎用の `DpoLearnModel<F>` を使用する。

use super::{LearnError, LearnModel};
use crate::learn::episode::{EpisodeContext, Outcome};
use crate::learn::record::DependencyGraphRecord;
use crate::learn::training::TrainingData;

/// Learn model for DependencyGraph inference (inferring dependencies between actions).
///
/// Evaluates DependencyGraph inference prompts/responses and converts them into
/// `TrainingData`.
#[derive(Debug, Clone)]
pub struct DependencyGraphLearnModel {
    /// System prompt used for the generated SFT training data.
    pub system_prompt: String,
}

impl Default for DependencyGraphLearnModel {
    /// Builds a model whose system prompt is the built-in
    /// `DependencyGraphLearnModel::default_system_prompt()`.
    fn default() -> Self {
        let system_prompt = String::from(Self::default_system_prompt());
        Self { system_prompt }
    }
}

impl DependencyGraphLearnModel {
    /// Creates a model configured with the built-in default system prompt.
    pub fn new() -> Self {
        Default::default()
    }

    /// Builder-style setter: replaces the system prompt and returns the model.
    pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
        let prompt: String = prompt.into();
        self.system_prompt = prompt;
        self
    }

    /// Built-in system prompt used when no custom prompt is supplied.
    ///
    /// It asks the LLM to order actions separately for the `discover` (NodeExpand)
    /// and `not_discover` (NodeStateChange) categories. The answer is expected as
    /// `discover_order: [...]` / `not_discover_order: [...]` lines.
    pub fn default_system_prompt() -> &'static str {
        r#"You are an expert at analyzing action dependencies for task execution.

Given a list of available actions with their descriptions, determine the correct execution order.

Actions are categorized as:
- NodeExpand (discover): Actions that discover new targets (e.g., Search, List)
- NodeStateChange (not_discover): Actions that modify state (e.g., Read, Submit)

Output the sorted order for each category based on:
1. Dependency relationships (e.g., "requires X first" in descriptions)
2. Logical execution flow (start actions first, terminal actions last)

Format your response as:
discover_order: [action1, action2, ...]
not_discover_order: [action1, action2, ...]"#
    }

    /// Returns a closure that maps an episode to its `(prompt, response)` pair.
    ///
    /// The pair is taken from the first `DependencyGraphRecord` in the episode
    /// context. Episodes without such a record yield `None`. Intended to be
    /// combined with `DpoLearnModel<F>`.
    pub fn extractor() -> impl Fn(&crate::learn::episode::Episode) -> Option<(String, String)> {
        |episode| {
            let record = episode.context.first::<DependencyGraphRecord>()?;
            Some((record.prompt.clone(), record.response.clone()))
        }
    }
}

impl LearnModel for DependencyGraphLearnModel {
    /// Identifier of this learn model.
    ///
    /// The same value is stamped on every episode built by `build_episodes`.
    fn name(&self) -> &str {
        "dependency_graph"
    }

    /// Human-readable learning objective.
    fn objective(&self) -> &str {
        "Learn correct DependencyGraph inference from action descriptions"
    }

    /// Builds one episode per `DependencyGraphRecord` found in `records`.
    ///
    /// Records of any other kind are ignored. Each episode's context holds exactly
    /// one `DependencyGraphRecord` (design: 1 Episode = 1 Record). Its outcome is
    /// computed by `evaluate`.
    ///
    /// NOTE: the former `runner.rs` implementation differed in two ways. It derived
    /// a graded, tick-based outcome (`ticks / max_ticks`) and tagged each episode
    /// with the scenario key. This method produces the binary outcome from
    /// `evaluate` and sets no scenario, so `convert` falls back to `"unknown"` for
    /// the scenario name.
    fn build_episodes(
        &self,
        records: &[crate::learn::record::Record],
    ) -> Vec<crate::learn::episode::Episode> {
        use crate::learn::episode::Episode;
        use crate::learn::record::RecordStream;

        let stream = RecordStream::new(records);
        let mut episodes = Vec::new();

        for record in stream.dependency_graphs() {
            // Each record becomes its own single-record context.
            let context = EpisodeContext::new().with_record(record.clone());
            let outcome = self.evaluate(&context);

            // Use `self.name()` so the tag can never drift from the model name.
            let episode = Episode::builder()
                .learn_model(self.name())
                .context(context)
                .outcome(outcome)
                .build();

            episodes.push(episode);
        }

        episodes
    }

    /// Judges the `DependencyGraphRecord`s in `context`.
    ///
    /// The result is one of three outcomes:
    /// - `Outcome::Unknown` when the context holds no such record.
    /// - A binary success when every record reports `is_success()`.
    /// - A failure as soon as any record did not succeed.
    fn evaluate(&self, context: &EpisodeContext) -> Outcome {
        // Design is 1 Episode = 1 Record, but several records are tolerated: all must succeed.
        let mut records = context.iter::<DependencyGraphRecord>().peekable();

        // Nothing to judge: the outcome cannot be determined.
        if records.peek().is_none() {
            return Outcome::Unknown;
        }

        if records.all(|r| r.is_success()) {
            Outcome::success_binary()
        } else {
            Outcome::failure("DependencyGraph inference failed")
        }
    }

    /// Converts an episode into SFT training data.
    ///
    /// The sample is built from this model's `system_prompt` together with the
    /// `prompt`/`response` of the first `DependencyGraphRecord` in the context.
    /// The scenario comes from `episode.metadata.scenario_name` and defaults to
    /// `"unknown"`. The episode id is attached to the sample.
    ///
    /// # Errors
    ///
    /// Returns `LearnError::MissingData` if the context holds no
    /// `DependencyGraphRecord`.
    fn convert(
        &self,
        episode: &crate::learn::episode::Episode,
    ) -> Result<TrainingData, LearnError> {
        let record = episode
            .context
            .first::<DependencyGraphRecord>()
            .ok_or_else(|| LearnError::MissingData("DependencyGraphRecord".into()))?;

        let scenario = episode
            .metadata
            .scenario_name
            .as_deref()
            .unwrap_or("unknown");

        Ok(
            TrainingData::sft(&self.system_prompt, &record.prompt, &record.response)
                .with_scenario(scenario)
                .with_episode_id(episode.id.to_string()),
        )
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_system_prompt() {
        let model = DependencyGraphLearnModel::new();
        assert!(model.system_prompt.contains("action dependencies"));
    }

    #[test]
    fn test_with_system_prompt() {
        let model = DependencyGraphLearnModel::new().with_system_prompt("custom");
        assert_eq!(model.system_prompt, "custom");
    }

    #[test]
    fn test_extractor() {
        use crate::learn::episode::{Episode, EpisodeContext};

        let record = DependencyGraphRecord::new("test-model")
            .prompt("test prompt")
            .response("test response")
            .discover_order(vec!["A".into()]);

        let context = EpisodeContext::new().with_record(record);
        let episode = Episode::builder()
            .learn_model("dependency_graph")
            .context(context)
            .build();

        let extractor = DependencyGraphLearnModel::extractor();
        let result = extractor(&episode);

        assert!(result.is_some());
        let (prompt, response) = result.unwrap();
        assert_eq!(prompt, "test prompt");
        assert_eq!(response, "test response");
    }

    #[test]
    fn test_evaluate_empty_context_is_unknown() {
        let model = DependencyGraphLearnModel::new();

        // With no DependencyGraphRecord, the outcome cannot be determined.
        let outcome = model.evaluate(&EpisodeContext::new());
        assert!(matches!(outcome, Outcome::Unknown));
    }

    #[test]
    fn test_convert_missing_record() {
        use crate::learn::episode::Episode;

        let model = DependencyGraphLearnModel::new();
        let episode = Episode::builder()
            .learn_model("dependency_graph")
            .context(EpisodeContext::new())
            .build();

        assert!(matches!(
            model.convert(&episode),
            Err(LearnError::MissingData(_))
        ));
    }

    #[test]
    fn test_convert_success() {
        use crate::learn::episode::Episode;

        let model = DependencyGraphLearnModel::new();
        let record = DependencyGraphRecord::new("test-model")
            .prompt("test prompt")
            .response("test response")
            .discover_order(vec!["A".into()]);
        let episode = Episode::builder()
            .learn_model("dependency_graph")
            .context(EpisodeContext::new().with_record(record))
            .build();

        assert!(model.convert(&episode).is_ok());
    }

    #[test]
    fn test_build_episodes_success() {
        use crate::learn::record::Record;

        let model = DependencyGraphLearnModel::new();

        // Successful record: non-empty discover_order and no error.
        let success_record = DependencyGraphRecord::new("test-model")
            .prompt("What is the dependency order?")
            .response("discover_order: [A, B]")
            .discover_order(vec!["A".into(), "B".into()]);

        let records = vec![Record::from(success_record)];
        let episodes = model.build_episodes(&records);

        assert_eq!(episodes.len(), 1);
        assert!(episodes[0].outcome.is_success());
    }

    #[test]
    fn test_build_episodes_failure() {
        use crate::learn::record::Record;

        let model = DependencyGraphLearnModel::new();

        // Failed record: carries an error.
        let failure_record = DependencyGraphRecord::new("test-model")
            .prompt("What is the dependency order?")
            .response("parse error")
            .error("Failed to parse response");

        let records = vec![Record::from(failure_record)];
        let episodes = model.build_episodes(&records);

        assert_eq!(episodes.len(), 1);
        assert!(episodes[0].outcome.is_failure());
    }

    #[test]
    fn test_build_episodes_mixed() {
        use crate::learn::record::Record;

        let model = DependencyGraphLearnModel::new();

        // A mix of successful and failed records.
        let success = DependencyGraphRecord::new("model").discover_order(vec!["A".into()]);
        let failure = DependencyGraphRecord::new("model").error("error");

        let records = vec![Record::from(success), Record::from(failure)];
        let episodes = model.build_episodes(&records);

        assert_eq!(episodes.len(), 2);

        let success_count = episodes.iter().filter(|e| e.outcome.is_success()).count();
        let failure_count = episodes.iter().filter(|e| e.outcome.is_failure()).count();

        assert_eq!(success_count, 1);
        assert_eq!(failure_count, 1);
    }

    #[test]
    fn test_build_episodes_ignores_other_records() {
        use crate::learn::record::{ActionRecord, Record};

        let model = DependencyGraphLearnModel::new();

        // ActionRecords are ignored.
        let action = ActionRecord::new(1, 0, "CheckStatus").success(true);
        let dep_graph = DependencyGraphRecord::new("model").discover_order(vec!["A".into()]);

        let records = vec![Record::from(action), Record::from(dep_graph)];
        let episodes = model.build_episodes(&records);

        // Only the DependencyGraphRecord becomes an episode.
        assert_eq!(episodes.len(), 1);
        assert!(episodes[0].outcome.is_success());
    }
}