swarm-engine-core 0.1.6

Core types and orchestration for SwarmEngine
Documentation
//! LearnModel - 学習の統合モデル
//!
//! ## 設計思想
//!
//! LearnModel は「何を学習するか」を統合的に定義する。
//!
//! - **何を目的とするか** (objective)
//! - **何を Episode として切り出すか** (build_episodes)
//! - **何を Success/Failure とするか** (evaluate)
//! - **どう TrainingData に変換するか** (convert)
//!
//! ## Learn の価値
//!
//! Core(Swarm本体)は性能制約で 3-gram までしか取れない。
//! しかし Learn は非同期/オフラインなので、5-gram や 10-gram など
//! 自由に分析できる。これが Learn モジュールの価値。
//!
//! ## モジュール構造
//!
//! ```text
//! learn_model/
//! ├── mod.rs              # LearnModel trait 定義
//! ├── dpo.rs              # 汎用 DPO 基盤 (DpoLearnModel<F>)
//! ├── dependency_graph.rs # DependencyGraph 推論の学習
//! ├── worker_task.rs      # Worker の Task 完了パターン学習
//! ├── worker_decision.rs  # Worker の意思決定シーケンス学習
//! └── error.rs            # エラー型
//! ```
//!
//! ## 実装一覧
//!
//! | 実装 | 目的 | グルーピング |
//! |------|------|-------------|
//! | [`DpoLearnModel`] | 汎用 DPO 学習 | `group_id` |
//! | [`DependencyGraphLearnModel`] | DependencyGraph 推論 | - |
//! | [`WorkerTaskLearn`] | Task 完了パターン | `task_id` |
//! | [`WorkerDecisionSequenceLearn`] | 意思決定シーケンス | `worker_id` |
//!
//! ## DPO (Direct Preference Optimization)
//!
//! DPO は成功/失敗 Episode のペアから学習する手法。
//! `group_id` でグルーピングし、カスタム `extractor` で prompt/response を抽出。
//!
//! [`DependencyGraphLearnModel::extractor()`] を使って DPO を行う例:
//!
//! ```ignore
//! let dpo = DpoLearnModel::new(
//!     DependencyGraphLearnModel::default_system_prompt(),
//!     DependencyGraphLearnModel::extractor(),
//! );
//! let pairs = dpo.build_pairs(&episodes);
//! ```
//!
//! ## Record による抽象化
//!
//! ActionEvent と LlmDebugEvent を `Record` enum で統一的に扱う。
//! LearnModel は Record のストリームから Episode を構築する。
//!
//! ```text
//! ActionEvent ──┐
//!               ├──▶ Vec<Record> ──▶ LearnModel.build_episodes()
//! LlmDebugEvent ┘                         ↓
//!                                    Vec<Episode>
//!//!                                    LearnModel.convert()
//!//!                                    TrainingData
//! ```

mod dependency_graph;
mod dpo;
mod error;
mod worker_decision;
mod worker_task;

pub use dependency_graph::DependencyGraphLearnModel;
pub use dpo::{DpoConfig, DpoLearnModel, DpoPair};
pub use error::LearnError;
pub use worker_decision::WorkerDecisionSequenceLearn;
pub use worker_task::WorkerTaskLearn;

use crate::events::ActionEvent;

use super::episode::{Episode, EpisodeContext, Outcome};
use super::record::Record;
use super::training::TrainingData;

// ============================================================================
// System Event Constants
// ============================================================================

/// システムイベント定数
pub mod system_events {
    /// Tick 開始イベント
    pub const TICK_START: &str = "tick_start";
    /// Tick 終了イベント
    pub const TICK_END: &str = "tick_end";
    /// タスク完了イベント
    pub const DONE: &str = "done";

    /// デフォルトのシステムイベント一覧
    pub const DEFAULT_SYSTEM_EVENTS: &[&str] = &[TICK_START, TICK_END, DONE];
}

// ============================================================================
// LearnModel Trait
// ============================================================================

/// 学習の統合モデル
///
/// 何を学習対象とし、何を成功とするかを統合的に定義する。
/// Record[] から Episode を構築し、TrainingData に変換するまでの全責務を担う。
///
/// ## Record による統一インターフェース
///
/// ActionEvent も LlmDebugEvent も `Record` として統一的に扱う。
/// これにより:
/// - ActionEvent ベースの Learn
/// - LlmDebugEvent ベースの Learn
/// - 両方を混ぜた Learn
///
/// 全て同じインターフェースで実装可能。
pub trait LearnModel: Send + Sync {
    /// 名前
    fn name(&self) -> &str;

    /// 目的を表す説明
    fn objective(&self) -> &str;

    /// Record のストリームから Episode を構築
    ///
    /// N-gram、Worker単位、任意のグルーピングが可能。
    /// Core が 3-gram までしか取れなくても、Learn は 5-gram や 10-gram を
    /// 自由に構築できる。
    fn build_episodes(&self, records: &[Record]) -> Vec<Episode>;

    /// Records から Success/Failure を判定
    ///
    /// 純粋なロジック: EpisodeContext (Records) → Outcome
    /// build_episodes() 内でこれを呼んで Episode.outcome を設定する。
    fn evaluate(&self, context: &EpisodeContext) -> Outcome;

    /// Episode を TrainingData に変換
    fn convert(&self, episode: &Episode) -> Result<TrainingData, LearnError>;

    /// 複数 Episode を一括変換(デフォルト実装)
    fn convert_batch(&self, episodes: &[Episode]) -> Vec<TrainingData> {
        episodes
            .iter()
            .filter_map(|ep| self.convert(ep).ok())
            .collect()
    }

    /// 便利メソッド: ActionEvent[] から直接変換
    fn build_episodes_from_actions(&self, actions: &[ActionEvent]) -> Vec<Episode> {
        let records: Vec<Record> = actions.iter().map(Record::from).collect();
        self.build_episodes(&records)
    }
}

// ============================================================================
// Tests
// ============================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use crate::events::{ActionContext, ActionEventBuilder, ActionEventResult};
    use crate::learn::record::RecordStream;
    use crate::types::WorkerId;
    use std::time::Duration;

    fn make_action(tick: u64, worker_id: usize, action: &str, success: bool) -> ActionEvent {
        let result = if success {
            ActionEventResult::success()
        } else {
            ActionEventResult::failure("error")
        };

        ActionEventBuilder::new(tick, WorkerId(worker_id), action)
            .result(result)
            .duration(Duration::from_millis(10))
            .context(ActionContext::new())
            .build()
    }

    fn make_records(actions: &[ActionEvent]) -> Vec<Record> {
        actions.iter().map(Record::from).collect()
    }

    #[test]
    fn test_record_accessors() {
        let action = make_action(1, 5, "CheckStatus", true);
        let record = Record::from(&action);

        assert!(record.is_action());
        assert!(!record.is_llm());
        assert_eq!(record.worker_id(), Some(5));
        assert!(record.as_action().is_some());
        assert!(record.as_llm().is_none());
    }

    #[test]
    fn test_record_stream_group_by_worker() {
        let actions = vec![
            make_action(1, 0, "A", true),
            make_action(2, 1, "B", true),
            make_action(3, 0, "C", true),
            make_action(4, 1, "D", true),
        ];
        let records = make_records(&actions);
        let stream = RecordStream::new(&records);

        let groups = stream.group_by_worker();
        assert_eq!(groups.len(), 2);
        assert_eq!(groups.get(&0).map(|v| v.len()), Some(2));
        assert_eq!(groups.get(&1).map(|v| v.len()), Some(2));
    }
}