use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::RwLock;
use std::time::Instant;
use super::map::MapNodeState;
use super::mutation::ActionNodeData;
use super::node_rules::Rules;
use super::operator::{ConfigurableOperator, Operator, RulesBasedMutation};
use super::provider::{AdaptiveOperatorProvider, OperatorProvider, ProviderContext};
use super::selection::{AnySelection, SelectionKind};
use crate::events::{LearningEvent, LearningEventChannel};
use crate::online_stats::SwarmStats;
/// Snapshot of swarm search state handed to a [`StrategyAdvisor`] when
/// deciding whether the selection strategy should change.
#[derive(Debug, Clone)]
pub struct StrategyContext {
/// Number of nodes currently on the exploration frontier.
pub frontier_count: usize,
/// Total visits recorded across the swarm.
pub total_visits: u32,
/// Observed failure rate.
pub failure_rate: f64,
/// Observed success rate (complement of `failure_rate` when built via `new`).
pub success_rate: f64,
/// Selection strategy currently in effect.
pub current_strategy: SelectionKind,
/// Average exploration depth, when known; set via `with_avg_depth`.
pub avg_depth: Option<f32>,
}
impl StrategyContext {
    /// Builds a context from raw counters.
    ///
    /// `success_rate` is derived as the complement of `failure_rate`;
    /// `avg_depth` starts unset and can be attached with [`Self::with_avg_depth`].
    pub fn new(
        frontier_count: usize,
        total_visits: u32,
        failure_rate: f64,
        current_strategy: SelectionKind,
    ) -> Self {
        let success_rate = 1.0 - failure_rate;
        Self {
            frontier_count,
            total_visits,
            failure_rate,
            success_rate,
            current_strategy,
            avg_depth: None,
        }
    }

    /// Snapshots an operator [`ProviderContext`] into an advisor context.
    pub fn from_provider_context(
        ctx: &ProviderContext<'_, ActionNodeData, String, MapNodeState>,
        current: SelectionKind,
    ) -> Self {
        let stats = ctx.stats;
        Self {
            frontier_count: ctx.frontier_count(),
            total_visits: ctx.total_visits(),
            failure_rate: stats.failure_rate(),
            success_rate: stats.success_rate(),
            current_strategy: current,
            avg_depth: None,
        }
    }

    /// Builds a context from swarm statistics plus an externally supplied
    /// frontier size.
    pub fn from_stats(stats: &SwarmStats, frontier_count: usize, current: SelectionKind) -> Self {
        Self {
            frontier_count,
            total_visits: stats.total_visits(),
            failure_rate: stats.failure_rate(),
            success_rate: stats.success_rate(),
            current_strategy: current,
            avg_depth: None,
        }
    }

    /// Attaches the observed average depth (builder style).
    pub fn with_avg_depth(self, depth: f32) -> Self {
        Self {
            avg_depth: Some(depth),
            ..self
        }
    }
}
/// Outcome of a single strategy review.
#[derive(Debug, Clone)]
pub struct StrategyAdvice {
/// Strategy the advisor recommends (equals the current one when no change).
pub recommended: SelectionKind,
/// Whether the caller should actually switch to `recommended`.
pub should_change: bool,
/// Free-form explanation from the advisor.
pub reason: String,
/// Advisor-supplied confidence; `no_change` reports 1.0.
pub confidence: f64,
}
impl StrategyAdvice {
    /// Advice to keep the current strategy, reported with full confidence.
    pub fn no_change(current: SelectionKind, reason: impl Into<String>) -> Self {
        let reason = reason.into();
        Self {
            recommended: current,
            should_change: false,
            reason,
            confidence: 1.0,
        }
    }

    /// Advice to switch to `new`, with a caller-supplied confidence.
    pub fn change_to(new: SelectionKind, reason: impl Into<String>, confidence: f64) -> Self {
        let reason = reason.into();
        Self {
            recommended: new,
            should_change: true,
            reason,
            confidence,
        }
    }
}
/// Errors a [`StrategyAdvisor`] can return from `advise`.
#[derive(Debug, Clone, thiserror::Error)]
pub enum StrategyAdviceError {
/// The underlying LLM request failed (transport or service error).
#[error("LLM call failed: {0}")]
LlmError(String),
/// The LLM responded, but the response could not be parsed into advice.
#[error("Failed to parse response: {0}")]
ParseError(String),
/// No advisor is available to answer right now.
#[error("Advisor not available")]
Unavailable,
}
/// Pluggable source of strategy advice; implementations must be shareable
/// across worker threads (`Send + Sync`).
pub trait StrategyAdvisor: Send + Sync {
/// Produces advice for the given context, or an error when unavailable.
fn advise(&self, context: &StrategyContext) -> Result<StrategyAdvice, StrategyAdviceError>;
/// Advisor name used in learning events and log output.
fn name(&self) -> &str;
}
/// Controls how often the LLM advisor is consulted.
#[derive(Debug, Clone)]
pub struct ReviewPolicy {
/// Visit count between regular reviews.
pub interval: u32,
/// Minimum visits that must elapse before any new review.
pub min_interval: u32,
/// Absolute failure-rate change (0.0..=1.0) that triggers an early review.
pub state_change_threshold: f64,
}
impl Default for ReviewPolicy {
    /// Balanced defaults: a regular review every 20 visits, never more often
    /// than every 5 visits, with a 0.15 failure-rate swing forcing an early one.
    fn default() -> Self {
        Self {
            interval: 20,
            min_interval: 5,
            state_change_threshold: 0.15,
        }
    }
}
impl ReviewPolicy {
    /// Creates a policy; `state_change_threshold` is clamped into `[0.0, 1.0]`.
    pub fn new(interval: u32, min_interval: u32, state_change_threshold: f64) -> Self {
        Self {
            interval,
            min_interval,
            state_change_threshold: state_change_threshold.clamp(0.0, 1.0),
        }
    }

    /// Frequent reviews: every 10 visits, no closer than 3 apart, and a 0.1
    /// failure-rate swing triggers an early review.
    pub fn frequent() -> Self {
        Self::new(10, 3, 0.1)
    }

    /// Conservative reviews: every 50 visits, no closer than 20 apart, and
    /// only a 0.25 failure-rate swing triggers an early review.
    pub fn conservative() -> Self {
        Self::new(50, 20, 0.25)
    }
}
/// Operator provider that layers an LLM-backed [`StrategyAdvisor`] on top of
/// an [`AdaptiveOperatorProvider`] fallback.
pub struct AdaptiveLlmOperatorProvider {
// Heuristic fallback used until/unless the LLM issues an override.
adaptive: AdaptiveOperatorProvider,
// Advisor consulted according to `policy`.
advisor: Box<dyn StrategyAdvisor>,
// When and how often the advisor is consulted.
policy: ReviewPolicy,
// Exploration constant handed to the selection built via `AnySelection::from_kind`.
ucb1_c: f64,
// Total-visit count recorded at the last successful review.
last_review_visits: AtomicU32,
// Failure rate at the last successful review, stored fixed-point as rate * 1000.
last_failure_rate: AtomicU32,
// Latest LLM-recommended selection; takes precedence over `adaptive`.
llm_override: RwLock<Option<SelectionKind>>,
}
impl AdaptiveLlmOperatorProvider {
/// Creates a provider with a default adaptive fallback, the default
/// [`ReviewPolicy`], and a UCB1 exploration constant of `sqrt(2)`.
pub fn new(advisor: Box<dyn StrategyAdvisor>) -> Self {
Self {
adaptive: AdaptiveOperatorProvider::default(),
advisor,
policy: ReviewPolicy::default(),
ucb1_c: std::f64::consts::SQRT_2,
last_review_visits: AtomicU32::new(0),
last_failure_rate: AtomicU32::new(0),
llm_override: RwLock::new(None),
}
}
/// Replaces the review policy (builder style).
pub fn with_policy(mut self, policy: ReviewPolicy) -> Self {
self.policy = policy;
self
}
/// Replaces the adaptive fallback provider (builder style).
pub fn with_adaptive(mut self, adaptive: AdaptiveOperatorProvider) -> Self {
self.adaptive = adaptive;
self
}
/// Overrides the UCB1 exploration constant (builder style).
pub fn with_ucb1_c(mut self, c: f64) -> Self {
self.ucb1_c = c;
self
}
/// Returns the currently active LLM-imposed selection override, if any.
pub fn llm_override(&self) -> Option<SelectionKind> {
*self.llm_override.read().unwrap()
}
/// Decides whether the advisor should be consulted now: true when the
/// regular `interval` has elapsed since the last recorded review, or when
/// the failure rate has shifted by at least `state_change_threshold` —
/// but never before `min_interval` visits have passed.
fn should_review(&self, stats: &SwarmStats) -> bool {
let current_visits = stats.total_visits();
let last_visits = self.last_review_visits.load(Ordering::Relaxed);
if current_visits < last_visits + self.policy.min_interval {
return false;
}
if current_visits >= last_visits + self.policy.interval {
return true;
}
// Rates are compared in fixed-point (rate * 1000) to match the AtomicU32
// encoding used by `last_failure_rate`.
let current_rate = (stats.failure_rate() * 1000.0) as u32;
let last_rate = self.last_failure_rate.load(Ordering::Relaxed);
let rate_diff = (current_rate as i32 - last_rate as i32).unsigned_abs() as f64 / 1000.0;
rate_diff >= self.policy.state_change_threshold
}
/// Consults the advisor, emits a learning event for the outcome (success
/// or failure), and returns `Some(new_kind)` only when the advice is to
/// change strategy.
fn do_review(
&self,
stats: &SwarmStats,
frontier_count: usize,
current: SelectionKind,
) -> Option<SelectionKind> {
let context = StrategyContext::from_stats(stats, frontier_count, current);
let start_time = Instant::now();
let result = self.advisor.advise(&context);
let elapsed = start_time.elapsed();
match result {
Ok(advice) => {
let latency_ms = elapsed.as_millis() as u64;
let tick = LearningEventChannel::global().current_tick();
LearningEventChannel::global().emit(
LearningEvent::strategy_advice(tick, self.advisor.name())
.current_strategy(current.to_string())
.recommended(advice.recommended.to_string())
.should_change(advice.should_change)
.confidence(advice.confidence)
.reason(&advice.reason)
.frontier_count(frontier_count)
.total_visits(stats.total_visits())
.failure_rate(stats.failure_rate())
.latency_ms(latency_ms)
.success()
.build(),
);
tracing::debug!(
target: "swarm_engine::learning",
advisor = %self.advisor.name(),
current_strategy = %current,
recommended = %advice.recommended,
should_change = advice.should_change,
confidence = advice.confidence,
reason = %advice.reason,
latency_ms = latency_ms,
"Strategy advice completed"
);
// Record this review point so `should_review` measures intervals and
// rate deltas from here.
self.last_review_visits
.store(stats.total_visits(), Ordering::Relaxed);
self.last_failure_rate
.store((stats.failure_rate() * 1000.0) as u32, Ordering::Relaxed);
if advice.should_change {
Some(advice.recommended)
} else {
None
}
}
Err(e) => {
let latency_ms = elapsed.as_millis() as u64;
let tick = LearningEventChannel::global().current_tick();
LearningEventChannel::global().emit(
LearningEvent::strategy_advice(tick, self.advisor.name())
.current_strategy(current.to_string())
.recommended(current.to_string()) .frontier_count(frontier_count)
.total_visits(stats.total_visits())
.failure_rate(stats.failure_rate())
.latency_ms(latency_ms)
.failure(e.to_string())
.build(),
);
tracing::warn!(
advisor = %self.advisor.name(),
error = %e,
latency_ms = latency_ms,
"Strategy advisor failed, falling back to Adaptive"
);
// NOTE(review): the error path leaves `last_review_visits` and
// `last_failure_rate` untouched, so a failing advisor will be retried
// as soon as `min_interval` allows — TODO confirm this
// retry-without-backoff behavior is intended.
None
}
}
}
/// Selection currently in effect: an LLM override wins; otherwise defer to
/// the adaptive fallback's choice for these stats.
fn effective_selection(&self, stats: &SwarmStats) -> SelectionKind {
if let Some(kind) = *self.llm_override.read().unwrap() {
return kind;
}
self.adaptive.current_selection(stats)
}
}
impl<R> OperatorProvider<R> for AdaptiveLlmOperatorProvider
where
R: Rules + 'static,
{
/// Builds an operator whose selection strategy is chosen by, in order: a
/// fresh LLM review (when one is due), the existing LLM override, or the
/// adaptive fallback. Without a context, plain UCB1 is used.
fn provide(
&self,
rules: R,
context: Option<&ProviderContext<'_, ActionNodeData, String, MapNodeState>>,
) -> ConfigurableOperator<R> {
let selection_kind = match context {
Some(ctx) => {
let current = self.effective_selection(ctx.stats);
if self.should_review(ctx.stats) {
if let Some(new_kind) = self.do_review(ctx.stats, ctx.frontier_count(), current)
{
// Persist the LLM's decision so subsequent calls keep using it.
*self.llm_override.write().unwrap() = Some(new_kind);
new_kind
} else {
current
}
} else {
current
}
}
// NOTE(review): with no context we default to UCB1 and ignore any
// existing `llm_override` — confirm this is intended.
None => SelectionKind::Ucb1, };
let selection = AnySelection::from_kind(selection_kind, self.ucb1_c);
Operator::new(RulesBasedMutation::new(), selection, rules)
}
/// Re-tunes an existing operator in place. A due LLM review takes
/// priority; the adaptive fallback only runs while no LLM override is set.
fn reevaluate(
&self,
operator: &mut ConfigurableOperator<R>,
context: &ProviderContext<'_, ActionNodeData, String, MapNodeState>,
) {
let current = operator.selection().kind();
if self.should_review(context.stats) {
if let Some(new_kind) = self.do_review(context.stats, context.frontier_count(), current)
{
if new_kind != current {
tracing::info!(
from = %current,
to = %new_kind,
"Strategy changed by LLM advisor"
);
operator.set_selection(AnySelection::from_kind(new_kind, self.ucb1_c));
*self.llm_override.write().unwrap() = Some(new_kind);
}
// The LLM issued a change directive; don't let the adaptive pass
// override it this round.
return;
}
}
// NOTE(review): an LLM override is never cleared once set, so from that
// point on the adaptive provider is permanently bypassed here — confirm
// this one-way handoff is intended.
if self.llm_override.read().unwrap().is_none() {
self.adaptive.reevaluate(operator, context);
}
}
/// Provider name reported in diagnostics and events.
fn name(&self) -> &str {
"HybridLlm"
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::events::{ActionEventBuilder, ActionEventResult};
    use crate::exploration::{GraphMap, NodeRules};
    use crate::types::WorkerId;
    use std::sync::atomic::AtomicUsize;
    use std::sync::Arc;

    /// Records one successful `action` event into `stats`.
    fn record_success(stats: &mut SwarmStats, action: &str) {
        let event = ActionEventBuilder::new(0, WorkerId(0), action)
            .result(ActionEventResult::success())
            .build();
        stats.record(&event);
    }

    /// Records one failed `action` event into `stats`.
    fn record_failure(stats: &mut SwarmStats, action: &str) {
        let event = ActionEventBuilder::new(0, WorkerId(0), action)
            .result(ActionEventResult::failure("error"))
            .build();
        stats.record(&event);
    }

    /// Advisor stub that always returns a fixed piece of advice.
    ///
    /// The invocation count lives behind an `Arc` handed back by
    /// `with_counter`, so tests can observe it after the advisor has been
    /// boxed into the provider — replacing the previous `unsafe` raw-pointer
    /// downcast of the `dyn StrategyAdvisor` trait object, which was fragile
    /// and unsound to rely on.
    struct MockAdvisor {
        advice: StrategyAdvice,
        call_count: Arc<AtomicUsize>,
    }
    impl MockAdvisor {
        /// Returns the advisor together with a shared handle to its call counter.
        fn with_counter(advice: StrategyAdvice) -> (Self, Arc<AtomicUsize>) {
            let counter = Arc::new(AtomicUsize::new(0));
            let advisor = Self {
                advice,
                call_count: Arc::clone(&counter),
            };
            (advisor, counter)
        }
        /// Convenience constructor for tests that don't inspect the counter.
        fn new(advice: StrategyAdvice) -> Self {
            Self::with_counter(advice).0
        }
    }
    impl StrategyAdvisor for MockAdvisor {
        fn advise(
            &self,
            _context: &StrategyContext,
        ) -> Result<StrategyAdvice, StrategyAdviceError> {
            self.call_count.fetch_add(1, Ordering::Relaxed);
            Ok(self.advice.clone())
        }
        fn name(&self) -> &str {
            "MockAdvisor"
        }
    }

    /// Advisor stub that always errors, to exercise the fallback path.
    struct FailingAdvisor;
    impl StrategyAdvisor for FailingAdvisor {
        fn advise(
            &self,
            _context: &StrategyContext,
        ) -> Result<StrategyAdvice, StrategyAdviceError> {
            Err(StrategyAdviceError::LlmError("Mock error".into()))
        }
        fn name(&self) -> &str {
            "FailingAdvisor"
        }
    }

    #[test]
    fn test_strategy_context_new() {
        let ctx = StrategyContext::new(15, 47, 0.23, SelectionKind::Ucb1);
        assert_eq!(ctx.frontier_count, 15);
        assert_eq!(ctx.total_visits, 47);
        assert!((ctx.failure_rate - 0.23).abs() < 0.001);
        assert!((ctx.success_rate - 0.77).abs() < 0.001);
        assert_eq!(ctx.current_strategy, SelectionKind::Ucb1);
    }

    #[test]
    fn test_strategy_context_from_stats() {
        let mut stats = SwarmStats::new();
        for _ in 0..7 {
            record_success(&mut stats, "action");
        }
        for _ in 0..3 {
            record_failure(&mut stats, "action");
        }
        let ctx = StrategyContext::from_stats(&stats, 10, SelectionKind::Greedy);
        assert_eq!(ctx.frontier_count, 10);
        assert_eq!(ctx.total_visits, 10);
        assert!((ctx.failure_rate - 0.3).abs() < 0.01);
    }

    #[test]
    fn test_review_policy_default() {
        let policy = ReviewPolicy::default();
        assert_eq!(policy.interval, 20);
        assert_eq!(policy.min_interval, 5);
        assert!((policy.state_change_threshold - 0.15).abs() < 0.001);
    }

    #[test]
    fn test_review_policy_frequent() {
        let policy = ReviewPolicy::frequent();
        assert_eq!(policy.interval, 10);
        assert_eq!(policy.min_interval, 3);
    }

    #[test]
    fn test_review_policy_conservative() {
        let policy = ReviewPolicy::conservative();
        assert_eq!(policy.interval, 50);
        assert_eq!(policy.min_interval, 20);
    }

    #[test]
    fn test_hybrid_provider_initial_ucb1() {
        let advice = StrategyAdvice::no_change(SelectionKind::Ucb1, "test");
        let advisor = MockAdvisor::new(advice);
        let provider = AdaptiveLlmOperatorProvider::new(Box::new(advisor));
        let rules = NodeRules::for_testing();
        let operator = provider.provide(rules, None);
        assert_eq!(operator.name(), "RulesBased+UCB1");
    }

    #[test]
    fn test_hybrid_provider_review_at_interval() {
        let advice = StrategyAdvice::change_to(SelectionKind::Greedy, "test", 0.9);
        // Keep a handle to the call counter before boxing the advisor.
        let (advisor, calls) = MockAdvisor::with_counter(advice);
        let provider =
            AdaptiveLlmOperatorProvider::new(Box::new(advisor)).with_policy(ReviewPolicy {
                interval: 10,
                min_interval: 5,
                state_change_threshold: 0.5,
            });
        let rules = NodeRules::for_testing();
        let mut stats = SwarmStats::new();
        for _ in 0..20 {
            record_success(&mut stats, "action");
        }
        let map: GraphMap<ActionNodeData, String, MapNodeState> = GraphMap::new();
        let ctx = ProviderContext::new(&map, &stats);
        let operator = provider.provide(rules, Some(&ctx));
        assert_eq!(operator.name(), "RulesBased+Greedy");
        assert_eq!(calls.load(Ordering::Relaxed), 1);
    }

    #[test]
    fn test_hybrid_provider_no_review_before_min_interval() {
        let advice = StrategyAdvice::change_to(SelectionKind::Greedy, "test", 0.9);
        let advisor = MockAdvisor::new(advice);
        let provider =
            AdaptiveLlmOperatorProvider::new(Box::new(advisor)).with_policy(ReviewPolicy {
                interval: 10,
                min_interval: 5,
                state_change_threshold: 0.5,
            });
        let rules = NodeRules::for_testing();
        let mut stats = SwarmStats::new();
        for _ in 0..3 {
            record_success(&mut stats, "action");
        }
        let map: GraphMap<ActionNodeData, String, MapNodeState> = GraphMap::new();
        let ctx = ProviderContext::new(&map, &stats);
        let operator = provider.provide(rules, Some(&ctx));
        assert_eq!(operator.name(), "RulesBased+UCB1");
    }

    #[test]
    fn test_hybrid_provider_fallback_on_error() {
        let provider =
            AdaptiveLlmOperatorProvider::new(Box::new(FailingAdvisor)).with_policy(ReviewPolicy {
                interval: 1,
                min_interval: 1,
                state_change_threshold: 0.0,
            });
        let rules = NodeRules::for_testing();
        let mut stats = SwarmStats::new();
        for _ in 0..10 {
            record_success(&mut stats, "action");
        }
        let map: GraphMap<ActionNodeData, String, MapNodeState> = GraphMap::new();
        let ctx = ProviderContext::new(&map, &stats);
        let operator = provider.provide(rules, Some(&ctx));
        assert!(operator.name().contains("RulesBased"));
    }

    #[test]
    fn test_hybrid_provider_reevaluate() {
        let advice = StrategyAdvice::change_to(SelectionKind::Thompson, "high variance", 0.85);
        let advisor = MockAdvisor::new(advice);
        let provider =
            AdaptiveLlmOperatorProvider::new(Box::new(advisor)).with_policy(ReviewPolicy {
                interval: 5,
                min_interval: 1,
                state_change_threshold: 0.5,
            });
        let rules = NodeRules::for_testing();
        let mut operator = provider.provide(rules, None);
        assert_eq!(operator.selection().kind(), SelectionKind::Ucb1);
        let mut stats = SwarmStats::new();
        for _ in 0..10 {
            record_success(&mut stats, "action");
        }
        let map: GraphMap<ActionNodeData, String, MapNodeState> = GraphMap::new();
        let ctx = ProviderContext::new(&map, &stats);
        provider.reevaluate(&mut operator, &ctx);
        assert_eq!(operator.selection().kind(), SelectionKind::Thompson);
    }

    #[test]
    fn test_hybrid_provider_state_change_trigger() {
        let advice = StrategyAdvice::change_to(SelectionKind::Thompson, "high variance", 0.8);
        let advisor = MockAdvisor::new(advice);
        let provider =
            AdaptiveLlmOperatorProvider::new(Box::new(advisor)).with_policy(ReviewPolicy {
                interval: 100,
                min_interval: 1,
                state_change_threshold: 0.1,
            });
        let rules = NodeRules::for_testing();
        let mut stats = SwarmStats::new();
        for _ in 0..5 {
            record_success(&mut stats, "action");
        }
        let map: GraphMap<ActionNodeData, String, MapNodeState> = GraphMap::new();
        let ctx = ProviderContext::new(&map, &stats);
        let _ = provider.provide(rules.clone(), Some(&ctx));
        // Drive the failure rate from 0.0 to 0.5 — past the 0.1 threshold —
        // so the next provide triggers an early review.
        for _ in 0..5 {
            record_failure(&mut stats, "action");
        }
        let ctx2 = ProviderContext::new(&map, &stats);
        let operator = provider.provide(rules, Some(&ctx2));
        assert_eq!(operator.selection().kind(), SelectionKind::Thompson);
    }
}