use std::fmt::Debug;
use super::{Fifo, Greedy, SelectionKind, SelectionLogic, Thompson, Ucb1};
use crate::exploration::map::{GraphMap, MapNodeId, MapState};
use crate::exploration::mutation::ActionExtractor;
use crate::learn::LearnedProvider;
use crate::online_stats::SwarmStats;
/// Enum wrapper over the concrete node-selection strategies.
///
/// Each method on this type forwards to the wrapped strategy with a `match`,
/// so one value can hold any of the strategies listed below.
#[derive(Debug, Clone)]
pub enum AnySelection {
    /// First-in-first-out selection.
    Fifo(Fifo),
    /// UCB1 selection; constructed with an exploration constant (see `from_kind`).
    Ucb1(Ucb1),
    /// Greedy selection.
    Greedy(Greedy),
    /// Thompson-sampling selection.
    Thompson(Thompson),
}
impl Default for AnySelection {
fn default() -> Self {
Self::Fifo(Fifo::new())
}
}
impl AnySelection {
    /// Builds the selection strategy that corresponds to `kind`.
    ///
    /// `ucb1_c` is forwarded to [`Ucb1::new`] when `kind` is
    /// [`SelectionKind::Ucb1`]; it is ignored for every other kind.
    pub fn from_kind(kind: SelectionKind, ucb1_c: f64) -> Self {
        match kind {
            SelectionKind::Fifo => Self::Fifo(Fifo::new()),
            SelectionKind::Ucb1 => Self::Ucb1(Ucb1::new(ucb1_c)),
            SelectionKind::Greedy => Self::Greedy(Greedy::new()),
            SelectionKind::Thompson => Self::Thompson(Thompson::new()),
        }
    }

    /// Returns the [`SelectionKind`] tag of the wrapped strategy.
    ///
    /// This is the inverse of [`Self::from_kind`] with respect to the kind.
    pub fn kind(&self) -> SelectionKind {
        match self {
            Self::Fifo(_) => SelectionKind::Fifo,
            Self::Ucb1(_) => SelectionKind::Ucb1,
            Self::Greedy(_) => SelectionKind::Greedy,
            Self::Thompson(_) => SelectionKind::Thompson,
        }
    }

    /// Human-readable name of the wrapped strategy.
    ///
    /// Every name is a string literal, so the result is `&'static str` and
    /// may outlive `self` (e.g. stored in logs or reports without cloning).
    pub fn selection_name(&self) -> &'static str {
        match self {
            Self::Fifo(_) => "FIFO",
            Self::Ucb1(_) => "UCB1",
            Self::Greedy(_) => "Greedy",
            Self::Thompson(_) => "Thompson",
        }
    }

    /// Scores `action` (optionally against `target`) with the wrapped strategy.
    ///
    /// FIFO does not rank candidates and always returns `0.0`. The other
    /// strategies are forwarded to their own `compute_score`. Those do not
    /// share an argument order (Greedy takes no `stats`), and this method
    /// presents a single uniform signature.
    pub fn compute_score(
        &self,
        action: &str,
        target: Option<&str>,
        stats: &SwarmStats,
        provider: &dyn LearnedProvider,
    ) -> f64 {
        match self {
            Self::Fifo(_) => 0.0,
            Self::Ucb1(s) => s.compute_score(stats, action, target, provider),
            Self::Greedy(s) => s.compute_score(action, target, provider),
            Self::Thompson(s) => s.compute_score(stats, action, target, provider),
        }
    }
}
impl<N, E, S> SelectionLogic<N, E, S> for AnySelection
where
N: Debug + Clone + ActionExtractor,
E: Debug + Clone,
S: MapState,
{
fn next(
&self,
map: &GraphMap<N, E, S>,
stats: &SwarmStats,
provider: &dyn LearnedProvider,
) -> Option<MapNodeId> {
match self {
Self::Fifo(s) => s.next(map, stats, provider),
Self::Ucb1(s) => s.next(map, stats, provider),
Self::Greedy(s) => s.next(map, stats, provider),
Self::Thompson(s) => s.next(map, stats, provider),
}
}
fn select(
&self,
map: &GraphMap<N, E, S>,
count: usize,
stats: &SwarmStats,
provider: &dyn LearnedProvider,
) -> Vec<MapNodeId> {
match self {
Self::Fifo(s) => s.select(map, count, stats, provider),
Self::Ucb1(s) => s.select(map, count, stats, provider),
Self::Greedy(s) => s.select(map, count, stats, provider),
Self::Thompson(s) => s.select(map, count, stats, provider),
}
}
fn score(
&self,
action: &str,
target: Option<&str>,
stats: &SwarmStats,
provider: &dyn LearnedProvider,
) -> f64 {
self.compute_score(action, target, stats, provider)
}
fn name(&self) -> &str {
self.selection_name()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::learn::NullProvider;

    /// Builds a selection from `make_kind()` and checks the kind round-trip and display name.
    ///
    /// `make_kind` is a constructor rather than a value so the expected kind can be
    /// produced twice without requiring `SelectionKind: Copy`.
    fn assert_from_kind(make_kind: fn() -> SelectionKind, ucb1_c: f64, expected_name: &str) {
        let selection = AnySelection::from_kind(make_kind(), ucb1_c);
        assert_eq!(selection.kind(), make_kind());
        assert_eq!(selection.selection_name(), expected_name);
    }

    #[test]
    fn test_any_selection_from_kind() {
        assert_from_kind(|| SelectionKind::Fifo, 1.0, "FIFO");
        assert_from_kind(|| SelectionKind::Ucb1, 1.41, "UCB1");
        assert_from_kind(|| SelectionKind::Greedy, 1.0, "Greedy");
        assert_from_kind(|| SelectionKind::Thompson, 1.0, "Thompson");
    }

    #[test]
    fn test_any_selection_score() {
        let stats = SwarmStats::new();
        let provider = NullProvider;
        // Scores the "grep" action with no target using a freshly built strategy.
        let score_grep = |kind: SelectionKind, ucb1_c: f64| {
            AnySelection::from_kind(kind, ucb1_c).compute_score("grep", None, &stats, &provider)
        };
        assert_eq!(score_grep(SelectionKind::Fifo, 1.0), 0.0);
        assert!(score_grep(SelectionKind::Ucb1, 1.41).is_infinite());
        assert_eq!(score_grep(SelectionKind::Greedy, 1.0), 0.5);
        assert_eq!(score_grep(SelectionKind::Thompson, 1.0), 0.5);
    }

    #[test]
    fn test_any_selection_default() {
        assert_eq!(AnySelection::default().kind(), SelectionKind::Fifo);
    }
}