swarm-engine-core 0.1.6

Core types and orchestration for SwarmEngine
Documentation
//! Greedy Selection - LearnedProvider ボーナス最大優先選択
//!
//! Provider から得られた学習済みボーナスに基づいて選択する。

use std::fmt::Debug;

use super::SelectionLogic;
use crate::exploration::map::{ExplorationMap, GraphMap, MapNodeId, MapState};
use crate::exploration::mutation::ActionExtractor;
use crate::learn::{LearnedProvider, LearningQuery};
use crate::online_stats::SwarmStats;

/// Greedy 選択(ボーナス最大を優先)
///
/// 探索を行わず、Provider からのボーナスが最も高いノードを選択する。
/// 学習済みデータを活用する場合に有効。
#[derive(Debug, Clone, Default)]
pub struct Greedy {
    /// ボーナスの重み(デフォルト: 1.0)
    weight: f64,
}

impl Greedy {
    pub fn new() -> Self {
        Self { weight: 1.0 }
    }

    /// 重みを指定して生成
    pub fn with_weight(weight: f64) -> Self {
        Self { weight }
    }

    /// スコアを計算(コンテキストなし、後方互換用)
    pub fn compute_score(
        &self,
        action: &str,
        target: Option<&str>,
        provider: &dyn LearnedProvider,
    ) -> f64 {
        self.compute_score_with_context(action, target, provider, None, None)
    }

    /// スコアを計算(コンテキスト付き)
    ///
    /// Provider に confidence_with_context クエリを投げてボーナスを取得。
    pub fn compute_score_with_context(
        &self,
        action: &str,
        target: Option<&str>,
        provider: &dyn LearnedProvider,
        prev_action: Option<&str>,
        prev_prev_action: Option<&str>,
    ) -> f64 {
        // Provider からボーナス取得(コンテキスト付き)
        let bonus = provider
            .query(LearningQuery::confidence_with_context(
                action,
                target,
                prev_action,
                prev_prev_action,
            ))
            .score();

        // 基本スコア 0.5 + ボーナス * weight
        0.5 + bonus * self.weight
    }
}

impl<N, E, S> SelectionLogic<N, E, S> for Greedy
where
    N: Debug + Clone + ActionExtractor,
    E: Debug + Clone,
    S: MapState,
{
    fn next(
        &self,
        map: &GraphMap<N, E, S>,
        stats: &SwarmStats,
        provider: &dyn LearnedProvider,
    ) -> Option<MapNodeId> {
        self.select(map, 1, stats, provider).into_iter().next()
    }

    fn select(
        &self,
        map: &GraphMap<N, E, S>,
        count: usize,
        _stats: &SwarmStats,
        provider: &dyn LearnedProvider,
    ) -> Vec<MapNodeId> {
        let frontiers = map.frontiers();
        if frontiers.is_empty() || count == 0 {
            return vec![];
        }

        // スコア計算してソート(親ノードの情報を活用)
        let mut scored: Vec<_> = frontiers
            .iter()
            .filter_map(|&id| {
                map.get(id).map(|node| {
                    let (action, target) = node.data.extract();

                    // 親ノードのアクションを取得(学習ボーナス用)
                    let parent_id = map.parent(id);
                    let prev_action = parent_id.and_then(|pid| {
                        map.get(pid)
                            .map(|parent_node| parent_node.data.action_name().to_string())
                    });

                    // 親の親のアクションを取得(N-gram ボーナス用)
                    let prev_prev_action =
                        parent_id.and_then(|pid| map.parent(pid)).and_then(|ppid| {
                            map.get(ppid).map(|grandparent_node| {
                                grandparent_node.data.action_name().to_string()
                            })
                        });

                    let score = self.compute_score_with_context(
                        action,
                        target,
                        provider,
                        prev_action.as_deref(),
                        prev_prev_action.as_deref(),
                    );
                    (id, score)
                })
            })
            .collect();

        // スコア降順でソート(高スコアを優先)
        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

        scored.into_iter().take(count).map(|(id, _)| id).collect()
    }

    fn score(
        &self,
        action: &str,
        target: Option<&str>,
        _stats: &SwarmStats,
        provider: &dyn LearnedProvider,
    ) -> f64 {
        self.compute_score(action, target, provider)
    }

    fn name(&self) -> &str {
        "Greedy"
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::learn::NullProvider;

    #[test]
    fn test_greedy_default_score() {
        let greedy = Greedy::new();
        let provider = NullProvider;

        // NullProvider は常に NotAvailable → score() = 0.0 → 基本スコア 0.5
        assert_eq!(greedy.compute_score("unknown", None, &provider), 0.5);
    }

    #[test]
    fn test_greedy_with_null_provider() {
        let greedy = Greedy::new();
        let provider = NullProvider;

        // NullProvider では全て 0.5
        assert_eq!(greedy.compute_score("grep", None, &provider), 0.5);
        assert_eq!(greedy.compute_score("glob", Some("svc1"), &provider), 0.5);
        assert_eq!(greedy.compute_score("other", None, &provider), 0.5);
    }

    #[test]
    fn test_greedy_with_weight() {
        let greedy = Greedy::with_weight(2.0);
        let provider = NullProvider;

        // NullProvider でもスコアは 0.5(ボーナス 0 × weight は 0)
        assert_eq!(greedy.compute_score("action", None, &provider), 0.5);
    }
}