swarm-engine-core 0.1.6

Core types and orchestration for SwarmEngine
Documentation
//! Thompson Selection - Thompson Sampling 選択
//!
//! Beta 分布の期待値に基づいて選択する bandit アルゴリズム。
//! 学習ボーナスは Provider から取得(コンテキスト付き)。

use std::fmt::Debug;

use super::SelectionLogic;
use crate::exploration::map::{ExplorationMap, GraphMap, MapNodeId, MapState};
use crate::exploration::mutation::ActionExtractor;
use crate::learn::{LearnedProvider, LearningQuery};
use crate::online_stats::SwarmStats;

/// Selection strategy that scores candidates by a Beta-posterior mean plus a learned bonus.
///
/// A candidate's base score is the mean of `Beta(successes + 1, failures + 1)`, i.e.
/// `(successes + 1) / (successes + failures + 2)`. The bonus reported by a
/// [`LearnedProvider`] is scaled by `learning_weight` and added to it.
///
/// Classic Thompson Sampling draws a random sample from the Beta distribution; this
/// implementation uses the expected value instead so that selection is deterministic.
#[derive(Clone, Debug)]
pub struct Thompson {
    /// Multiplier applied to the provider's learning bonus.
    learning_weight: f64,
}

/// Learning-bonus multiplier used when a `Thompson` is built via `Default`.
const DEFAULT_LEARNING_WEIGHT: f64 = 0.3;

impl Default for Thompson {
    /// Builds a selector whose learning weight is `0.3`.
    fn default() -> Self {
        Thompson {
            learning_weight: DEFAULT_LEARNING_WEIGHT,
        }
    }
}

impl Thompson {
    pub fn new() -> Self {
        Self::default()
    }

    /// 学習ボーナス係数を指定して生成
    pub fn with_learning_weight(learning_weight: f64) -> Self {
        Self { learning_weight }
    }

    /// 学習ボーナス係数を取得
    pub fn learning_weight(&self) -> f64 {
        self.learning_weight
    }

    /// Thompson スコアを計算(コンテキストなし、後方互換用)
    pub fn compute_score(
        &self,
        stats: &SwarmStats,
        action: &str,
        target: Option<&str>,
        provider: &dyn LearnedProvider,
    ) -> f64 {
        self.compute_score_with_context(stats, action, target, provider, None, None)
    }

    /// Thompson スコアを計算(コンテキスト付き)
    ///
    /// Provider に confidence_with_context クエリを投げて学習ボーナスを取得。
    pub fn compute_score_with_context(
        &self,
        stats: &SwarmStats,
        action: &str,
        target: Option<&str>,
        provider: &dyn LearnedProvider,
        prev_action: Option<&str>,
        prev_prev_action: Option<&str>,
    ) -> f64 {
        let action_stats = match target {
            Some(t) => stats.get_action_target_stats(action, t),
            None => stats.get_action_stats(action),
        };

        // Beta 分布の期待値: α / (α + β)
        let alpha = action_stats.successes as f64 + 1.0;
        let beta = action_stats.failures as f64 + 1.0;
        let base_score = alpha / (alpha + beta);

        // Provider から学習ボーナスを取得(コンテキスト付き)
        let learning_bonus = provider
            .query(LearningQuery::confidence_with_context(
                action,
                target,
                prev_action,
                prev_prev_action,
            ))
            .score();

        base_score + learning_bonus * self.learning_weight
    }
}

impl<N, E, S> SelectionLogic<N, E, S> for Thompson
where
    N: Debug + Clone + ActionExtractor,
    E: Debug + Clone,
    S: MapState,
{
    /// Returns the highest-scoring frontier node, or `None` if there is none.
    ///
    /// Delegates to `select` with `count == 1`.
    fn next(
        &self,
        map: &GraphMap<N, E, S>,
        stats: &SwarmStats,
        provider: &dyn LearnedProvider,
    ) -> Option<MapNodeId> {
        self.select(map, 1, stats, provider).into_iter().next()
    }

    /// Returns up to `count` frontier nodes, best Thompson score first.
    ///
    /// Each frontier node is scored with [`Thompson::compute_score_with_context`]. The action
    /// names of its parent and grandparent, when those nodes exist, are passed as context for
    /// the learning bonus. Frontier ids that `map.get` cannot resolve are skipped.
    ///
    /// Ordering is descending by score. Ties keep the order produced by `map.frontiers()`
    /// because `sort_by` is stable. A NaN score ranks below every other score.
    ///
    /// Returns an empty `Vec` when there are no frontier nodes or `count == 0`.
    fn select(
        &self,
        map: &GraphMap<N, E, S>,
        count: usize,
        stats: &SwarmStats,
        provider: &dyn LearnedProvider,
    ) -> Vec<MapNodeId> {
        let frontiers = map.frontiers();
        if frontiers.is_empty() || count == 0 {
            return vec![];
        }

        // Score every frontier node, using its ancestors' actions as context.
        let mut scored: Vec<_> = frontiers
            .iter()
            .filter_map(|&id| {
                map.get(id).map(|node| {
                    let (action, target) = node.data.extract();

                    // Parent's action: the 1-step context for the learning bonus.
                    let parent_id = map.parent(id);
                    let prev_action = parent_id.and_then(|pid| {
                        map.get(pid)
                            .map(|parent_node| parent_node.data.action_name().to_string())
                    });

                    // Grandparent's action: the 2-step (N-gram) context for the learning bonus.
                    let prev_prev_action =
                        parent_id.and_then(|pid| map.parent(pid)).and_then(|ppid| {
                            map.get(ppid).map(|grandparent_node| {
                                grandparent_node.data.action_name().to_string()
                            })
                        });

                    let score = self.compute_score_with_context(
                        stats,
                        action,
                        target,
                        provider,
                        prev_action.as_deref(),
                        prev_prev_action.as_deref(),
                    );
                    (id, score)
                })
            })
            .collect();

        // Sort by descending score. `partial_cmp(..).unwrap_or(Equal)` is not a total order
        // once a NaN is present: NaN would compare "equal" to everything while other values
        // stay ordered. That leaves the result order unspecified, and `sort_by` may panic on
        // such comparators. Map NaN to the lowest rank and compare with `total_cmp`, which is
        // a true total order. `sort_by` stays stable, so ties keep frontier order.
        let rank = |score: f64| if score.is_nan() { f64::NEG_INFINITY } else { score };
        scored.sort_by(|a, b| rank(b.1).total_cmp(&rank(a.1)));

        scored.into_iter().take(count).map(|(id, _)| id).collect()
    }

    /// Scores `action` (optionally scoped to `target`) without parent context.
    ///
    /// Delegates to [`Thompson::compute_score`].
    fn score(
        &self,
        action: &str,
        target: Option<&str>,
        stats: &SwarmStats,
        provider: &dyn LearnedProvider,
    ) -> f64 {
        self.compute_score(stats, action, target, provider)
    }

    /// Human-readable name of this selection strategy.
    fn name(&self) -> &str {
        "Thompson"
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::events::{ActionEventBuilder, ActionEventResult};
    use crate::learn::NullProvider;
    use crate::types::WorkerId;

    /// Builds one event for `action` on `WorkerId(0)` with the given result and records it.
    fn record(stats: &mut SwarmStats, action: &str, result: ActionEventResult) {
        let event = ActionEventBuilder::new(0, WorkerId(0), action)
            .result(result)
            .build();
        stats.record(&event);
    }

    /// Records a single successful `action`.
    fn record_success(stats: &mut SwarmStats, action: &str) {
        record(stats, action, ActionEventResult::success());
    }

    /// Records a single failed `action`.
    fn record_failure(stats: &mut SwarmStats, action: &str) {
        record(stats, action, ActionEventResult::failure("error"));
    }

    #[test]
    fn test_thompson_initial_score() {
        // No observations: the prior Beta(1, 1) has mean 0.5, and NullProvider adds no bonus.
        let score = Thompson::new().compute_score(&SwarmStats::new(), "grep", None, &NullProvider);
        assert_eq!(score, 0.5);
    }

    #[test]
    fn test_thompson_score_increases_with_success() {
        let selector = Thompson::new();
        let mut stats = SwarmStats::new();
        let score_of = |s: &SwarmStats| selector.compute_score(s, "grep", None, &NullProvider);

        let before = score_of(&stats);
        record_success(&mut stats, "grep");
        let after = score_of(&stats);

        // A success should raise the posterior mean.
        assert!(after > before, "score should rise after a success: {before} -> {after}");
    }

    #[test]
    fn test_thompson_score_decreases_with_failure() {
        let selector = Thompson::new();
        let mut stats = SwarmStats::new();
        let score_of = |s: &SwarmStats| selector.compute_score(s, "grep", None, &NullProvider);

        let before = score_of(&stats);
        record_failure(&mut stats, "grep");
        let after = score_of(&stats);

        // A failure should lower the posterior mean.
        assert!(after < before, "score should fall after a failure: {before} -> {after}");
    }

    #[test]
    fn test_thompson_with_learning_weight() {
        let weight = Thompson::with_learning_weight(0.5).learning_weight();
        assert!((weight - 0.5).abs() < 1e-10);
    }
}