use std::fmt::Debug;
use super::SelectionLogic;
use crate::exploration::map::{ExplorationMap, GraphMap, MapNodeId, MapState};
use crate::exploration::mutation::ActionExtractor;
use crate::learn::{LearnedProvider, LearningQuery};
use crate::online_stats::SwarmStats;
/// UCB1 (Upper Confidence Bound) selection strategy extended with a
/// learned-confidence term.
///
/// A candidate's score is
/// `success_rate + exploration + learning_bonus * learning_weight`, where the
/// exploration term is scaled by `c`.
#[derive(Clone, Debug)]
pub struct Ucb1 {
    /// Exploration constant that multiplies the confidence-bound term.
    c: f64,
    /// Factor applied to the learned provider's score before it is added to
    /// the total.
    learning_weight: f64,
}
impl Default for Ucb1 {
    /// Builds the stock configuration: exploration constant `sqrt(2)` and a
    /// learning weight of `0.3`.
    fn default() -> Self {
        let c = std::f64::consts::SQRT_2;
        let learning_weight = 0.3;
        Self { c, learning_weight }
    }
}
impl Ucb1 {
pub fn new(c: f64) -> Self {
Self {
c,
learning_weight: 0.3,
}
}
pub fn with_weights(c: f64, learning_weight: f64) -> Self {
Self { c, learning_weight }
}
pub fn with_default_c() -> Self {
Self::default()
}
pub fn c(&self) -> f64 {
self.c
}
pub fn learning_weight(&self) -> f64 {
self.learning_weight
}
pub fn compute_score(
&self,
stats: &SwarmStats,
action: &str,
target: Option<&str>,
provider: &dyn LearnedProvider,
) -> f64 {
self.compute_score_with_context(stats, action, target, provider, None, None)
}
pub fn compute_score_with_context(
&self,
stats: &SwarmStats,
action: &str,
target: Option<&str>,
provider: &dyn LearnedProvider,
prev_action: Option<&str>,
prev_prev_action: Option<&str>,
) -> f64 {
let action_stats = match target {
Some(t) => stats.get_action_target_stats(action, t),
None => stats.get_action_stats(action),
};
let total = stats.total_visits().max(1);
if action_stats.visits == 0 {
return f64::INFINITY; }
let success_rate = action_stats.success_rate();
let exploration = self.c * ((total as f64).ln() / action_stats.visits as f64).sqrt();
let learning_bonus = provider
.query(LearningQuery::confidence_with_context(
action,
target,
prev_action,
prev_prev_action,
))
.score();
success_rate + exploration + learning_bonus * self.learning_weight
}
}
/// UCB1-based frontier selection over a [`GraphMap`].
///
/// Every frontier node is scored with
/// [`Ucb1::compute_score_with_context`], using the action names of its parent
/// and grandparent (when present) as context, and the highest-scoring nodes
/// are returned first.
impl<N, E, S> SelectionLogic<N, E, S> for Ucb1
where
    N: Debug + Clone + ActionExtractor,
    E: Debug + Clone,
    S: MapState,
{
    /// Returns the single best frontier node, or `None` when `select` yields
    /// nothing.
    fn next(
        &self,
        map: &GraphMap<N, E, S>,
        stats: &SwarmStats,
        provider: &dyn LearnedProvider,
    ) -> Option<MapNodeId> {
        self.select(map, 1, stats, provider).into_iter().next()
    }

    /// Returns up to `count` frontier node ids ordered from highest to lowest
    /// score.
    ///
    /// Returns an empty vector when there are no frontiers or `count` is zero.
    /// Frontier ids that `map.get` cannot resolve are skipped. Nodes whose
    /// score is NaN are ranked after every node with a real score, so a single
    /// bad score cannot hide better candidates. Ties keep frontier order
    /// because the sort is stable.
    fn select(
        &self,
        map: &GraphMap<N, E, S>,
        count: usize,
        stats: &SwarmStats,
        provider: &dyn LearnedProvider,
    ) -> Vec<MapNodeId> {
        /// Descending comparison of two scores that treats NaN as the lowest
        /// possible value, which makes it a total order as `sort_by` requires.
        fn by_score_desc(a: f64, b: f64) -> std::cmp::Ordering {
            use std::cmp::Ordering;
            match (a.is_nan(), b.is_nan()) {
                (true, true) => Ordering::Equal,
                (true, false) => Ordering::Greater,
                (false, true) => Ordering::Less,
                // Neither side is NaN, so `partial_cmp` always yields `Some`.
                (false, false) => b.partial_cmp(&a).unwrap_or(Ordering::Equal),
            }
        }

        let frontiers = map.frontiers();
        if frontiers.is_empty() || count == 0 {
            return Vec::new();
        }

        let mut scored: Vec<_> = frontiers
            .iter()
            .filter_map(|&id| {
                map.get(id).map(|node| {
                    let (action, target) = node.data.extract();

                    // Walk up to two ancestors to build the action context.
                    let parent_id = map.parent(id);
                    let prev_action = parent_id.and_then(|pid| {
                        map.get(pid)
                            .map(|parent_node| parent_node.data.action_name().to_string())
                    });
                    let prev_prev_action =
                        parent_id.and_then(|pid| map.parent(pid)).and_then(|ppid| {
                            map.get(ppid).map(|grandparent_node| {
                                grandparent_node.data.action_name().to_string()
                            })
                        });

                    let score = self.compute_score_with_context(
                        stats,
                        action,
                        target,
                        provider,
                        prev_action.as_deref(),
                        prev_prev_action.as_deref(),
                    );
                    (id, score)
                })
            })
            .collect();

        scored.sort_by(|a, b| by_score_desc(a.1, b.1));
        scored.into_iter().take(count).map(|(id, _)| id).collect()
    }

    /// Scores a single action/target pair without ancestor context by
    /// delegating to [`Ucb1::compute_score`].
    fn score(
        &self,
        action: &str,
        target: Option<&str>,
        stats: &SwarmStats,
        provider: &dyn LearnedProvider,
    ) -> f64 {
        self.compute_score(stats, action, target, provider)
    }

    /// Returns the strategy name, `"UCB1"`.
    fn name(&self) -> &str {
        "UCB1"
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::events::{ActionEventBuilder, ActionEventResult};
    use crate::learn::NullProvider;
    use crate::types::WorkerId;

    /// Records one successful event for `action` (and `target`, if given)
    /// into `stats`.
    fn record_success(stats: &mut SwarmStats, action: &str, target: Option<&str>) {
        let mut builder = ActionEventBuilder::new(0, WorkerId(0), action);
        if let Some(t) = target {
            builder = builder.target(t);
        }
        let with_result = builder.result(ActionEventResult::success());
        let event = with_result.build();
        stats.record(&event);
    }

    /// An action with no recorded visits must score infinity.
    #[test]
    fn test_ucb1_unvisited_is_infinity() {
        let strategy = Ucb1::new(1.41);
        let empty_stats = SwarmStats::new();
        let score = strategy.compute_score(&empty_stats, "grep", None, &NullProvider);
        assert!(score.is_infinite());
    }

    /// The score becomes finite after the first visit and grows again once
    /// other actions add to the total visit count.
    #[test]
    fn test_ucb1_score_changes_with_visits() {
        let strategy = Ucb1::new(1.41);
        let provider = NullProvider;
        let mut stats = SwarmStats::new();

        // Before any visit the action is maximally attractive.
        let before_visit = strategy.compute_score(&stats, "grep", None, &provider);
        assert!(before_visit.is_infinite());

        // One success out of one total visit: ln(1) = 0, so only the success
        // rate contributes.
        record_success(&mut stats, "grep", None);
        let after_first = strategy.compute_score(&stats, "grep", None, &provider);
        assert!(after_first.is_finite());
        assert!((after_first - 1.0).abs() < 0.01);

        // Visiting a different action raises the total, which raises the
        // exploration term for "grep".
        record_success(&mut stats, "glob", None);
        let after_other = strategy.compute_score(&stats, "grep", None, &provider);
        assert!(after_other > after_first);
    }

    /// `Default` yields `sqrt(2)` for `c` and `0.3` for the learning weight.
    #[test]
    fn test_ucb1_default_c() {
        let strategy = Ucb1::default();
        let c_error = (strategy.c() - std::f64::consts::SQRT_2).abs();
        let weight_error = (strategy.learning_weight() - 0.3).abs();
        assert!(c_error < 1e-10);
        assert!(weight_error < 1e-10);
    }

    /// `with_weights` stores both values verbatim.
    #[test]
    fn test_ucb1_with_weights() {
        let strategy = Ucb1::with_weights(2.0, 0.5);
        let c_error = (strategy.c() - 2.0).abs();
        let weight_error = (strategy.learning_weight() - 0.5).abs();
        assert!(c_error < 1e-10);
        assert!(weight_error < 1e-10);
    }
}