use std::fmt::Debug;
use super::map::{ExplorationMap, GraphMap, MapNodeState, MapState};
use super::mutation::ActionNodeData;
use super::node_rules::Rules;
use super::operator::{ConfigurableOperator, Operator, RulesBasedMutation};
use super::selection::{AnySelection, SelectionKind};
use crate::online_stats::SwarmStats;
/// Settings describing which selection strategy an operator should use.
///
/// `ConfigBasedOperatorProvider::provide` forwards both fields to
/// `AnySelection::from_kind`.
#[derive(Debug, Clone)]
pub struct OperatorConfig {
/// Which selection strategy to construct.
pub selection: SelectionKind,
/// Constant handed to the selection constructor alongside `selection`.
/// NOTE(review): presumably the UCB1 exploration weight, and only
/// meaningful for `SelectionKind::Ucb1` — confirm against `AnySelection`.
pub ucb1_c: f64,
}
impl Default for OperatorConfig {
fn default() -> Self {
Self {
selection: SelectionKind::Fifo,
ucb1_c: std::f64::consts::SQRT_2,
}
}
}
impl OperatorConfig {
pub fn ucb1(c: f64) -> Self {
Self {
selection: SelectionKind::Ucb1,
ucb1_c: c,
}
}
pub fn greedy() -> Self {
Self {
selection: SelectionKind::Greedy,
..Default::default()
}
}
pub fn thompson() -> Self {
Self {
selection: SelectionKind::Thompson,
..Default::default()
}
}
}
/// Borrowed snapshot handed to an `OperatorProvider` when it builds or
/// re-evaluates an operator.
///
/// Both fields are shared references, so the context never owns or mutates
/// the map or the statistics.
pub struct ProviderContext<'a, N, E, S>
where
N: Debug + Clone,
E: Debug + Clone,
S: MapState,
{
/// The exploration graph being searched.
pub map: &'a GraphMap<N, E, S>,
/// Swarm-wide statistics (visit counts, failure rate).
pub stats: &'a SwarmStats,
}
impl<'a, N, E, S> ProviderContext<'a, N, E, S>
where
N: Debug + Clone,
E: Debug + Clone,
S: MapState,
{
pub fn new(map: &'a GraphMap<N, E, S>, stats: &'a SwarmStats) -> Self {
Self { map, stats }
}
pub fn frontier_count(&self) -> usize {
self.map.frontiers().len()
}
pub fn total_visits(&self) -> u32 {
self.stats.total_visits()
}
pub fn is_exploration_mature(&self, threshold: u32) -> bool {
self.stats.total_visits() >= threshold
}
}
/// Factory for `ConfigurableOperator`s driven by a rule set `R`.
///
/// Implementors must be `Send + Sync`.
pub trait OperatorProvider<R>: Send + Sync
where
R: Rules,
{
/// Builds an operator that applies `rules`.
///
/// `context` may be `None`; each implementation decides how to behave
/// when no map/statistics snapshot is supplied.
fn provide(
&self,
rules: R,
context: Option<&ProviderContext<'_, ActionNodeData, String, MapNodeState>>,
) -> ConfigurableOperator<R>;
/// Gives the provider a chance to adjust an existing operator in place,
/// based on a fresh `context`.
///
/// The default implementation does nothing.
fn reevaluate(
&self,
_operator: &mut ConfigurableOperator<R>,
_context: &ProviderContext<'_, ActionNodeData, String, MapNodeState>,
) {
}
/// Short label identifying this provider.
fn name(&self) -> &str;
}
/// Operator provider that always builds operators from one fixed
/// `OperatorConfig`.
#[derive(Debug, Clone)]
pub struct ConfigBasedOperatorProvider {
/// Configuration applied on every `provide` call.
config: OperatorConfig,
}
impl ConfigBasedOperatorProvider {
pub fn new(config: OperatorConfig) -> Self {
Self { config }
}
pub fn fifo() -> Self {
Self::new(OperatorConfig::default())
}
pub fn ucb1(c: f64) -> Self {
Self::new(OperatorConfig::ucb1(c))
}
pub fn config(&self) -> &OperatorConfig {
&self.config
}
}
impl<R> OperatorProvider<R> for ConfigBasedOperatorProvider
where
    R: Rules + 'static,
{
    /// Builds an operator whose selection comes straight from the stored
    /// configuration. The context is ignored.
    fn provide(
        &self,
        rules: R,
        _context: Option<&ProviderContext<'_, ActionNodeData, String, MapNodeState>>,
    ) -> ConfigurableOperator<R> {
        let config = &self.config;
        let strategy = AnySelection::from_kind(config.selection, config.ucb1_c);
        let mutation = RulesBasedMutation::new();
        Operator::new(mutation, strategy, rules)
    }

    /// Always `"ConfigBased"`.
    fn name(&self) -> &str {
        "ConfigBased"
    }
}
/// Operator provider that chooses a selection strategy from the current
/// swarm statistics.
///
/// The decision made by `select_strategy` is:
/// - fewer than `maturity_threshold` total visits → `SelectionKind::Ucb1`
/// - otherwise, failure rate strictly above `error_rate_threshold` →
///   `SelectionKind::Thompson`
/// - otherwise → `SelectionKind::Greedy`
#[derive(Debug, Clone)]
pub struct AdaptiveOperatorProvider {
    /// Total-visit count below which UCB1 is selected.
    maturity_threshold: u32,
    /// Failure rate above which Thompson is selected once mature.
    /// Both `new` and `with_error_rate_threshold` clamp it to `[0.0, 1.0]`.
    error_rate_threshold: f64,
    /// Constant passed along whenever a selection is constructed.
    ucb1_c: f64,
}

/// Defaults: maturity after 10 visits, error-rate threshold 0.3, `ucb1_c` = √2.
impl Default for AdaptiveOperatorProvider {
    fn default() -> Self {
        Self {
            maturity_threshold: 10,
            error_rate_threshold: 0.3,
            ucb1_c: std::f64::consts::SQRT_2,
        }
    }
}

impl AdaptiveOperatorProvider {
    /// Creates a provider with explicit thresholds.
    ///
    /// `error_rate_threshold` is clamped to `[0.0, 1.0]`, exactly as
    /// `with_error_rate_threshold` does, so both construction paths yield the
    /// same provider for the same input. A NaN threshold passes through the
    /// clamp unchanged.
    pub fn new(maturity_threshold: u32, error_rate_threshold: f64, ucb1_c: f64) -> Self {
        Self {
            maturity_threshold,
            error_rate_threshold: error_rate_threshold.clamp(0.0, 1.0),
            ucb1_c,
        }
    }

    /// Replaces the visit count below which UCB1 is selected.
    pub fn with_maturity_threshold(mut self, threshold: u32) -> Self {
        self.maturity_threshold = threshold;
        self
    }

    /// Replaces the failure-rate threshold, clamped to `[0.0, 1.0]`.
    pub fn with_error_rate_threshold(mut self, threshold: f64) -> Self {
        self.error_rate_threshold = threshold.clamp(0.0, 1.0);
        self
    }

    /// Replaces the constant passed to the selection constructor. The value
    /// is stored as given, without validation.
    pub fn with_ucb1_c(mut self, c: f64) -> Self {
        self.ucb1_c = c;
        self
    }

    /// Maps the current statistics to a selection strategy.
    ///
    /// The failure rate is only queried once the visit threshold has been
    /// reached, because it plays no part in the decision before that.
    fn select_strategy(&self, stats: &SwarmStats) -> SelectionKind {
        if stats.total_visits() < self.maturity_threshold {
            SelectionKind::Ucb1
        } else if stats.failure_rate() > self.error_rate_threshold {
            SelectionKind::Thompson
        } else {
            SelectionKind::Greedy
        }
    }

    /// Public view of the strategy that would be chosen for `stats`.
    pub fn current_selection(&self, stats: &SwarmStats) -> SelectionKind {
        self.select_strategy(stats)
    }
}
impl<R> OperatorProvider<R> for AdaptiveOperatorProvider
where
    R: Rules + 'static,
{
    /// Builds an operator whose selection is derived from the statistics in
    /// `context`. Without a context, UCB1 is used.
    fn provide(
        &self,
        rules: R,
        context: Option<&ProviderContext<'_, ActionNodeData, String, MapNodeState>>,
    ) -> ConfigurableOperator<R> {
        let kind = match context {
            Some(ctx) => self.select_strategy(ctx.stats),
            None => SelectionKind::Ucb1,
        };
        let strategy = AnySelection::from_kind(kind, self.ucb1_c);
        Operator::new(RulesBasedMutation::new(), strategy, rules)
    }

    /// Swaps the operator's selection when the statistics now call for a
    /// different kind. The operator is left untouched if the kind is
    /// unchanged.
    fn reevaluate(
        &self,
        operator: &mut ConfigurableOperator<R>,
        context: &ProviderContext<'_, ActionNodeData, String, MapNodeState>,
    ) {
        let active = operator.selection.kind();
        let wanted = self.select_strategy(context.stats);
        if active == wanted {
            return;
        }
        operator.selection = AnySelection::from_kind(wanted, self.ucb1_c);
    }

    /// Always `"Adaptive"`.
    fn name(&self) -> &str {
        "Adaptive"
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::events::{ActionEventBuilder, ActionEventResult};
    use crate::exploration::NodeRules;
    use crate::types::WorkerId;

    /// Records one successful event for `action` into `stats`.
    fn record_success(stats: &mut SwarmStats, action: &str) {
        let success_event = ActionEventBuilder::new(0, WorkerId(0), action)
            .result(ActionEventResult::success())
            .build();
        stats.record(&success_event);
    }

    /// Records one failed event (message `"error"`) for `action` into `stats`.
    fn record_failure(stats: &mut SwarmStats, action: &str) {
        let failure_event = ActionEventBuilder::new(0, WorkerId(0), action)
            .result(ActionEventResult::failure("error"))
            .build();
        stats.record(&failure_event);
    }

    /// Fresh statistics holding `successes` successful `"grep"` events
    /// followed by `failures` failed ones.
    fn stats_with(successes: usize, failures: usize) -> SwarmStats {
        let mut stats = SwarmStats::new();
        (0..successes).for_each(|_| record_success(&mut stats, "grep"));
        (0..failures).for_each(|_| record_failure(&mut stats, "grep"));
        stats
    }

    /// Empty map with the concrete type parameters the providers expect.
    fn empty_map() -> GraphMap<ActionNodeData, String, MapNodeState> {
        GraphMap::new()
    }

    // --- OperatorConfig -------------------------------------------------

    #[test]
    fn test_operator_config_default() {
        let OperatorConfig { selection, ucb1_c } = OperatorConfig::default();
        assert_eq!(selection, SelectionKind::Fifo);
        assert!((ucb1_c - std::f64::consts::SQRT_2).abs() < 1e-10);
    }

    #[test]
    fn test_operator_config_ucb1() {
        let OperatorConfig { selection, ucb1_c } = OperatorConfig::ucb1(2.0);
        assert_eq!(selection, SelectionKind::Ucb1);
        assert_eq!(ucb1_c, 2.0);
    }

    #[test]
    fn test_operator_config_greedy() {
        assert_eq!(OperatorConfig::greedy().selection, SelectionKind::Greedy);
    }

    #[test]
    fn test_operator_config_thompson() {
        assert_eq!(
            OperatorConfig::thompson().selection,
            SelectionKind::Thompson
        );
    }

    // --- ConfigBasedOperatorProvider -------------------------------------

    #[test]
    fn test_config_based_provider_fifo() {
        let provider = ConfigBasedOperatorProvider::fifo();
        let built = provider.provide(NodeRules::for_testing(), None);
        assert_eq!(built.name(), "RulesBased+FIFO");
    }

    #[test]
    fn test_config_based_provider_ucb1() {
        let provider = ConfigBasedOperatorProvider::ucb1(1.41);
        let built = provider.provide(NodeRules::for_testing(), None);
        assert_eq!(built.name(), "RulesBased+UCB1");
    }

    /// The config-based provider must produce the same operator whether or
    /// not a context is supplied.
    #[test]
    fn test_config_based_provider_with_context() {
        let provider = ConfigBasedOperatorProvider::new(OperatorConfig::greedy());
        let rules = NodeRules::for_testing();

        let without_context = provider.provide(rules.clone(), None);
        assert_eq!(without_context.name(), "RulesBased+Greedy");

        let map = empty_map();
        let stats = SwarmStats::new();
        let ctx = ProviderContext::new(&map, &stats);
        let with_context = provider.provide(rules, Some(&ctx));
        assert_eq!(with_context.name(), "RulesBased+Greedy");
    }

    // --- ProviderContext --------------------------------------------------

    #[test]
    fn test_provider_context_queries() {
        let mut map = empty_map();
        let _root = map.create_root(ActionNodeData::new("root"), MapNodeState::Open);
        let stats = SwarmStats::new();

        let ctx = ProviderContext::new(&map, &stats);
        assert_eq!(ctx.frontier_count(), 1);
        assert_eq!(ctx.total_visits(), 0);
        assert!(!ctx.is_exploration_mature(10));
    }

    // --- AdaptiveOperatorProvider ----------------------------------------

    #[test]
    fn test_adaptive_provider_initial_ucb1() {
        let provider = AdaptiveOperatorProvider::default();
        let untouched = SwarmStats::new();
        assert_eq!(provider.current_selection(&untouched), SelectionKind::Ucb1);
    }

    #[test]
    fn test_adaptive_provider_mature_low_error_greedy() {
        let provider = AdaptiveOperatorProvider::default().with_maturity_threshold(5);
        let stats = stats_with(10, 0);
        assert_eq!(stats.failure_rate(), 0.0);
        assert_eq!(provider.current_selection(&stats), SelectionKind::Greedy);
    }

    #[test]
    fn test_adaptive_provider_mature_high_error_thompson() {
        let provider = AdaptiveOperatorProvider::default()
            .with_maturity_threshold(5)
            .with_error_rate_threshold(0.3);
        let stats = stats_with(5, 5);
        assert_eq!(stats.failure_rate(), 0.5);
        assert_eq!(provider.current_selection(&stats), SelectionKind::Thompson);
    }

    #[test]
    fn test_adaptive_provider_provide() {
        let provider = AdaptiveOperatorProvider::default();
        let rules = NodeRules::for_testing();

        let without_context = provider.provide(rules.clone(), None);
        assert_eq!(without_context.name(), "RulesBased+UCB1");

        let map = empty_map();
        let stats = stats_with(20, 0);
        let ctx = ProviderContext::new(&map, &stats);
        let with_context = provider.provide(rules, Some(&ctx));
        assert_eq!(with_context.name(), "RulesBased+Greedy");
    }

    #[test]
    fn test_adaptive_provider_reevaluate() {
        let provider = AdaptiveOperatorProvider::default().with_maturity_threshold(5);
        let mut operator = provider.provide(NodeRules::for_testing(), None);
        assert_eq!(operator.selection.kind(), SelectionKind::Ucb1);

        let map = empty_map();
        let stats = stats_with(10, 0);
        let ctx = ProviderContext::new(&map, &stats);
        provider.reevaluate(&mut operator, &ctx);
        assert_eq!(operator.selection.kind(), SelectionKind::Greedy);
    }

    #[test]
    fn test_adaptive_provider_reevaluate_to_thompson() {
        let provider = AdaptiveOperatorProvider::default()
            .with_maturity_threshold(5)
            .with_error_rate_threshold(0.3);
        let mut operator = provider.provide(NodeRules::for_testing(), None);
        assert_eq!(operator.selection.kind(), SelectionKind::Ucb1);

        let map = empty_map();
        let stats = stats_with(3, 7);
        let ctx = ProviderContext::new(&map, &stats);
        provider.reevaluate(&mut operator, &ctx);
        assert_eq!(operator.selection.kind(), SelectionKind::Thompson);
    }
}