swarm-engine-core 0.1.6

Core types and orchestration for SwarmEngine
Documentation
//! UCB1 Selection - Upper Confidence Bound selection
//!
//! A bandit algorithm that balances exploration and exploitation.
//! The learning bonus is obtained from the Provider (with context).

use std::fmt::Debug;

use super::SelectionLogic;
use crate::exploration::map::{ExplorationMap, GraphMap, MapNodeId, MapState};
use crate::exploration::mutation::ActionExtractor;
use crate::learn::{LearnedProvider, LearningQuery};
use crate::online_stats::SwarmStats;

/// UCB1 selection strategy.
///
/// score = success_rate + c * sqrt(ln(total_visits) / visits) + learning_weight * learned_bonus
///
/// - Entries with a high success rate are preferred (exploitation).
/// - Entries with few visits receive a larger bonus (exploration).
/// - A learned bonus obtained from the Provider is added, scaled by `learning_weight`.
/// - Entries that have never been visited score `f64::INFINITY`.
#[derive(Debug, Clone)]
pub struct Ucb1 {
    /// Exploration coefficient (larger values favor less-visited entries).
    c: f64,
    /// Weight applied to the learned bonus returned by the Provider.
    learning_weight: f64,
}

impl Default for Ucb1 {
    fn default() -> Self {
        Self {
            c: std::f64::consts::SQRT_2,
            learning_weight: 0.3,
        }
    }
}

impl Ucb1 {
    /// 指定した探索係数で生成
    pub fn new(c: f64) -> Self {
        Self {
            c,
            learning_weight: 0.3,
        }
    }

    /// 探索係数と学習ボーナス係数を指定して生成
    pub fn with_weights(c: f64, learning_weight: f64) -> Self {
        Self { c, learning_weight }
    }

    /// デフォルトの探索係数(√2)で生成
    pub fn with_default_c() -> Self {
        Self::default()
    }

    /// 探索係数を取得
    pub fn c(&self) -> f64 {
        self.c
    }

    /// 学習ボーナス係数を取得
    pub fn learning_weight(&self) -> f64 {
        self.learning_weight
    }

    /// UCB1 スコアを計算(コンテキストなし、後方互換用)
    pub fn compute_score(
        &self,
        stats: &SwarmStats,
        action: &str,
        target: Option<&str>,
        provider: &dyn LearnedProvider,
    ) -> f64 {
        self.compute_score_with_context(stats, action, target, provider, None, None)
    }

    /// UCB1 スコアを計算(コンテキスト付き)
    ///
    /// Provider に confidence_with_context クエリを投げて学習ボーナスを取得。
    pub fn compute_score_with_context(
        &self,
        stats: &SwarmStats,
        action: &str,
        target: Option<&str>,
        provider: &dyn LearnedProvider,
        prev_action: Option<&str>,
        prev_prev_action: Option<&str>,
    ) -> f64 {
        let action_stats = match target {
            Some(t) => stats.get_action_target_stats(action, t),
            None => stats.get_action_stats(action),
        };
        let total = stats.total_visits().max(1);

        if action_stats.visits == 0 {
            return f64::INFINITY; // 未訪問は必ず選択
        }

        let success_rate = action_stats.success_rate();
        let exploration = self.c * ((total as f64).ln() / action_stats.visits as f64).sqrt();

        // Provider から学習ボーナスを取得(コンテキスト付き)
        let learning_bonus = provider
            .query(LearningQuery::confidence_with_context(
                action,
                target,
                prev_action,
                prev_prev_action,
            ))
            .score();

        success_rate + exploration + learning_bonus * self.learning_weight
    }
}

impl<N, E, S> SelectionLogic<N, E, S> for Ucb1
where
    N: Debug + Clone + ActionExtractor,
    E: Debug + Clone,
    S: MapState,
{
    /// Returns the single best-scoring frontier node, if any.
    fn next(
        &self,
        map: &GraphMap<N, E, S>,
        stats: &SwarmStats,
        provider: &dyn LearnedProvider,
    ) -> Option<MapNodeId> {
        self.select(map, 1, stats, provider).into_iter().next()
    }

    /// Returns up to `count` frontier nodes ordered by descending UCB1 score.
    ///
    /// Each frontier node is scored with the action names of its parent and
    /// grandparent (when present) as context for the learned bonus. Returns an
    /// empty vector when there are no frontiers or `count` is zero. Nodes whose
    /// score is NaN are ranked after every node with a comparable score.
    fn select(
        &self,
        map: &GraphMap<N, E, S>,
        count: usize,
        stats: &SwarmStats,
        provider: &dyn LearnedProvider,
    ) -> Vec<MapNodeId> {
        use std::cmp::Ordering;

        let frontiers = map.frontiers();
        if frontiers.is_empty() || count == 0 {
            return vec![];
        }

        // Action name of an arbitrary node, used for the ancestor context.
        let action_name_of = |node_id: MapNodeId| -> Option<String> {
            map.get(node_id)
                .map(|ancestor| ancestor.data.action_name().to_string())
        };

        // Score every frontier node, using ancestor actions as context.
        let mut scored: Vec<_> = frontiers
            .iter()
            .filter_map(|&id| {
                map.get(id).map(|node| {
                    let (action, target) = node.data.extract();

                    // Parent action (context for the learned bonus).
                    let parent_id = map.parent(id);
                    let prev_action = parent_id.and_then(action_name_of);

                    // Grandparent action (context for the N-gram bonus).
                    let prev_prev_action = parent_id
                        .and_then(|pid| map.parent(pid))
                        .and_then(action_name_of);

                    let score = self.compute_score_with_context(
                        stats,
                        action,
                        target,
                        provider,
                        prev_action.as_deref(),
                        prev_prev_action.as_deref(),
                    );
                    (id, score)
                })
            })
            .collect();

        // Sort by descending score. NaN is handled explicitly so the comparator
        // is a genuine total order: `partial_cmp(..).unwrap_or(Equal)` alone is
        // not transitive when NaN is present, which makes the resulting order
        // unspecified and may cause `sort_by` to panic.
        scored.sort_by(|a, b| match (a.1.is_nan(), b.1.is_nan()) {
            (false, false) => b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal),
            (true, true) => Ordering::Equal,
            (true, false) => Ordering::Greater,
            (false, true) => Ordering::Less,
        });

        scored.into_iter().take(count).map(|(id, _)| id).collect()
    }

    /// Scores a single action/target pair without ancestor context.
    fn score(
        &self,
        action: &str,
        target: Option<&str>,
        stats: &SwarmStats,
        provider: &dyn LearnedProvider,
    ) -> f64 {
        self.compute_score(stats, action, target, provider)
    }

    /// Returns the name of this selection strategy.
    fn name(&self) -> &str {
        "UCB1"
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::events::{ActionEventBuilder, ActionEventResult};
    use crate::learn::NullProvider;
    use crate::types::WorkerId;

    /// Records one successful event for `action` (optionally with a target).
    fn record_success(stats: &mut SwarmStats, action: &str, target: Option<&str>) {
        let mut builder = ActionEventBuilder::new(0, WorkerId(0), action);
        if let Some(name) = target {
            builder = builder.target(name);
        }
        let success_event = builder.result(ActionEventResult::success()).build();
        stats.record(&success_event);
    }

    #[test]
    fn test_ucb1_unvisited_is_infinity() {
        let selector = Ucb1::new(1.41);
        let empty_stats = SwarmStats::new();
        let provider = NullProvider;

        // An entry that was never visited scores INFINITY.
        let score = selector.compute_score(&empty_stats, "grep", None, &provider);
        assert!(score.is_infinite());
    }

    #[test]
    fn test_ucb1_score_changes_with_visits() {
        let selector = Ucb1::new(1.41);
        let mut stats = SwarmStats::new();
        let provider = NullProvider;

        // Before any visit the score is INFINITY.
        let before_visit = selector.compute_score(&stats, "grep", None, &provider);
        assert!(before_visit.is_infinite());

        // After one visit: ln(1) = 0, so the exploration term vanishes.
        record_success(&mut stats, "grep", None);
        let after_first = selector.compute_score(&stats, "grep", None, &provider);
        assert!(after_first.is_finite());
        // success_rate = 1.0, exploration = c * sqrt(ln(1)/1) = 0, bonus = 0
        let deviation = (after_first - 1.0).abs();
        assert!(deviation < 0.01);

        // Recording another action raises the total, growing the exploration term.
        record_success(&mut stats, "glob", None);
        let after_other = selector.compute_score(&stats, "grep", None, &provider);
        // total = 2, visits = 1, so the exploration term increased.
        assert!(after_other > after_first);
    }

    #[test]
    fn test_ucb1_default_c() {
        let selector = Ucb1::default();
        let c_error = (selector.c() - std::f64::consts::SQRT_2).abs();
        let weight_error = (selector.learning_weight() - 0.3).abs();
        assert!(c_error < 1e-10);
        assert!(weight_error < 1e-10);
    }

    #[test]
    fn test_ucb1_with_weights() {
        let selector = Ucb1::with_weights(2.0, 0.5);
        let c_error = (selector.c() - 2.0).abs();
        let weight_error = (selector.learning_weight() - 0.5).abs();
        assert!(c_error < 1e-10);
        assert!(weight_error < 1e-10);
    }
}