use std::fmt::Debug;
use super::SelectionLogic;
use crate::exploration::map::{ExplorationMap, GraphMap, MapNodeId, MapState};
use crate::exploration::mutation::ActionExtractor;
use crate::learn::{LearnedProvider, LearningQuery};
use crate::online_stats::SwarmStats;
/// Frontier-selection strategy that ranks candidate actions by a Beta-posterior
/// success estimate plus a weighted bonus reported by a learned provider.
#[derive(Clone, Debug)]
pub struct Thompson {
    /// Multiplier applied to the learned provider's confidence score before it is
    /// added to the statistical base score.
    learning_weight: f64,
}
impl Default for Thompson {
fn default() -> Self {
Self {
learning_weight: 0.3,
}
}
}
impl Thompson {
pub fn new() -> Self {
Self::default()
}
pub fn with_learning_weight(learning_weight: f64) -> Self {
Self { learning_weight }
}
pub fn learning_weight(&self) -> f64 {
self.learning_weight
}
pub fn compute_score(
&self,
stats: &SwarmStats,
action: &str,
target: Option<&str>,
provider: &dyn LearnedProvider,
) -> f64 {
self.compute_score_with_context(stats, action, target, provider, None, None)
}
pub fn compute_score_with_context(
&self,
stats: &SwarmStats,
action: &str,
target: Option<&str>,
provider: &dyn LearnedProvider,
prev_action: Option<&str>,
prev_prev_action: Option<&str>,
) -> f64 {
let action_stats = match target {
Some(t) => stats.get_action_target_stats(action, t),
None => stats.get_action_stats(action),
};
let alpha = action_stats.successes as f64 + 1.0;
let beta = action_stats.failures as f64 + 1.0;
let base_score = alpha / (alpha + beta);
let learning_bonus = provider
.query(LearningQuery::confidence_with_context(
action,
target,
prev_action,
prev_prev_action,
))
.score();
base_score + learning_bonus * self.learning_weight
}
}
impl<N, E, S> SelectionLogic<N, E, S> for Thompson
where
    N: Debug + Clone + ActionExtractor,
    E: Debug + Clone,
    S: MapState,
{
    /// Returns the highest-ranked frontier node, or `None` when there is none.
    fn next(
        &self,
        map: &GraphMap<N, E, S>,
        stats: &SwarmStats,
        provider: &dyn LearnedProvider,
    ) -> Option<MapNodeId> {
        self.select(map, 1, stats, provider).into_iter().next()
    }

    /// Scores every frontier node and returns up to `count` node ids, best first.
    ///
    /// Each frontier node's extracted `(action, target)` is scored with
    /// [`Thompson::compute_score_with_context`], using the action names of the
    /// node's parent and grandparent (when present in the map) as context. Frontier
    /// ids that are not found in the map are skipped.
    ///
    /// Ordering: higher scores first; equal scores keep frontier order because
    /// `sort_by` is stable. NaN scores are ranked last.
    ///
    /// Returns an empty vector when `count == 0` or there are no frontiers.
    fn select(
        &self,
        map: &GraphMap<N, E, S>,
        count: usize,
        stats: &SwarmStats,
        provider: &dyn LearnedProvider,
    ) -> Vec<MapNodeId> {
        if count == 0 {
            return vec![];
        }
        let frontiers = map.frontiers();
        if frontiers.is_empty() {
            return vec![];
        }

        // Action name of `node_id`, if that node exists in the map.
        let action_name_of = |node_id: MapNodeId| {
            map.get(node_id)
                .map(|node| node.data.action_name().to_string())
        };

        let mut scored: Vec<_> = frontiers
            .iter()
            .filter_map(|&id| {
                let node = map.get(id)?;
                let (action, target) = node.data.extract();
                let parent_id = map.parent(id);
                let prev_action = parent_id.and_then(&action_name_of);
                let prev_prev_action = parent_id
                    .and_then(|pid| map.parent(pid))
                    .and_then(&action_name_of);
                let score = self.compute_score_with_context(
                    stats,
                    action,
                    target,
                    provider,
                    prev_action.as_deref(),
                    prev_prev_action.as_deref(),
                );
                // A NaN score would make the comparator below non-transitive
                // (NaN "equal" to everything), giving an unspecified order and a
                // possible panic in `sort_by`. Rank such nodes last instead.
                let score = if score.is_nan() {
                    f64::NEG_INFINITY
                } else {
                    score
                };
                Some((id, score))
            })
            .collect();

        // With NaN removed, `partial_cmp` always returns `Some`, so this is a total
        // order; the fallback is unreachable.
        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        scored.into_iter().take(count).map(|(id, _)| id).collect()
    }

    /// Scores `action`/`target` without preceding-action context.
    fn score(
        &self,
        action: &str,
        target: Option<&str>,
        stats: &SwarmStats,
        provider: &dyn LearnedProvider,
    ) -> f64 {
        self.compute_score(stats, action, target, provider)
    }

    /// Human-readable strategy name.
    fn name(&self) -> &str {
        "Thompson"
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::events::{ActionEventBuilder, ActionEventResult};
    use crate::learn::NullProvider;
    use crate::types::WorkerId;

    /// Records one successful `action` event attributed to `WorkerId(0)`.
    fn record_success(stats: &mut SwarmStats, action: &str) {
        stats.record(
            &ActionEventBuilder::new(0, WorkerId(0), action)
                .result(ActionEventResult::success())
                .build(),
        );
    }

    /// Records one failed `action` event attributed to `WorkerId(0)`.
    fn record_failure(stats: &mut SwarmStats, action: &str) {
        stats.record(
            &ActionEventBuilder::new(0, WorkerId(0), action)
                .result(ActionEventResult::failure("error"))
                .build(),
        );
    }

    #[test]
    fn test_thompson_initial_score() {
        // With no observations the base score is 1/2; this expects NullProvider
        // to contribute no bonus.
        let score =
            Thompson::new().compute_score(&SwarmStats::new(), "grep", None, &NullProvider);
        assert_eq!(score, 0.5);
    }

    #[test]
    fn test_thompson_score_increases_with_success() {
        let thompson = Thompson::new();
        let provider = NullProvider;
        let score_of = |s: &SwarmStats| thompson.compute_score(s, "grep", None, &provider);

        let mut stats = SwarmStats::new();
        let before = score_of(&stats);
        record_success(&mut stats, "grep");
        assert!(score_of(&stats) > before);
    }

    #[test]
    fn test_thompson_score_decreases_with_failure() {
        let thompson = Thompson::new();
        let provider = NullProvider;
        let score_of = |s: &SwarmStats| thompson.compute_score(s, "grep", None, &provider);

        let mut stats = SwarmStats::new();
        let before = score_of(&stats);
        record_failure(&mut stats, "grep");
        assert!(score_of(&stats) < before);
    }

    #[test]
    fn test_thompson_with_learning_weight() {
        let weight = Thompson::with_learning_weight(0.5).learning_weight();
        assert!((weight - 0.5).abs() < 1e-10);
    }
}