use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::{Mutex, OnceLock};
use serde::{Deserialize, Serialize};
use tokio::sync::broadcast;
use crate::util::epoch_millis;
/// Events that can be published through `LearningEventChannel`.
///
/// Serialized with serde's internally tagged representation: the
/// `event_type` field carries the identifier given by each variant's
/// `#[serde(rename = ...)]` attribute.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "event_type")]
pub enum LearningEvent {
/// Strategy advice record, tagged `llm_strategy_advice`.
/// Produced by `LearningEvent::strategy_advice(..)` followed by `build()`.
#[serde(rename = "llm_strategy_advice")]
StrategyAdvice {
/// Value of `epoch_millis()` captured when the builder was created.
timestamp_ms: u64,
/// Tick passed to `LearningEvent::strategy_advice`.
tick: u64,
/// Advisor name passed to `LearningEvent::strategy_advice`.
advisor: String,
current_strategy: String,
recommended: String,
should_change: bool,
// NOTE(review): presumably a value in [0.0, 1.0]; nothing in this file enforces a range — confirm.
confidence: f64,
reason: String,
frontier_count: usize,
total_visits: u32,
// NOTE(review): presumably a ratio in [0.0, 1.0]; nothing in this file enforces a range — confirm.
failure_rate: f64,
// NOTE(review): presumably the duration of the advisor call in milliseconds — confirm with callers.
latency_ms: u64,
/// `true` by default and after the builder's `success()`; `false` after `failure(..)`.
success: bool,
/// `Some(message)` after the builder's `failure(message)`, otherwise `None`.
error: Option<String>,
},
/// Dependency-graph inference record, tagged `dependency_graph_inference`.
/// Produced by `LearningEvent::dependency_graph_inference(..)` followed by `build()`.
#[serde(rename = "dependency_graph_inference")]
DependencyGraphInference {
/// Value of `epoch_millis()` captured when the builder was created.
timestamp_ms: u64,
prompt: String,
response: String,
available_actions: Vec<String>,
discover_order: Vec<String>,
not_discover_order: Vec<String>,
/// Model name passed to `LearningEvent::dependency_graph_inference`.
model: String,
endpoint: String,
/// `None` unless the builder's `lora(..)` was called.
lora: Option<String>,
// NOTE(review): presumably the duration of the inference call in milliseconds — confirm with callers.
latency_ms: u64,
/// `true` by default and after the builder's `success()`; `false` after `failure(..)`.
success: bool,
/// `Some(message)` after the builder's `failure(message)`, otherwise `None`.
error: Option<String>,
},
/// Statistics snapshot record, tagged `learn_stats_snapshot`.
/// Produced by `LearningEvent::learn_stats_snapshot(..)` followed by `build()`.
#[serde(rename = "learn_stats_snapshot")]
LearnStatsSnapshot {
/// Value of `epoch_millis()` captured when the builder was created.
timestamp_ms: u64,
/// Scenario name passed to `LearningEvent::learn_stats_snapshot`.
scenario: String,
session_id: String,
// NOTE(review): the name suggests a JSON-encoded payload kept as a string; nothing in this file parses or validates it.
stats_json: String,
/// Defaults to `LearnStatsOutcome::Success { score: 0.0 }` in the builder.
outcome: LearnStatsOutcome,
total_ticks: u64,
total_actions: u64,
},
}
/// Outcome stored in the `outcome` field of `LearningEvent::LearnStatsSnapshot`.
///
/// No serde attributes are applied to this enum, so it uses serde's default
/// enum representation rather than the `event_type` tagging of `LearningEvent`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum LearnStatsOutcome {
/// Set by `LearnStatsSnapshotBuilder::success(score)`.
Success { score: f64 },
/// Set by `LearnStatsSnapshotBuilder::failure(reason)`.
Failure { reason: String },
/// Set by `LearnStatsSnapshotBuilder::timeout(partial_score)`; the score may be absent.
Timeout { partial_score: Option<f64> },
}
impl LearningEvent {
    /// Starts a builder for a [`LearningEvent::StrategyAdvice`] event.
    ///
    /// The timestamp is taken from `epoch_millis()` right away. `tick` and
    /// `advisor` are stored as given; every other field starts empty or zero,
    /// with `success` preset to `true` and `error` to `None`.
    pub fn strategy_advice(tick: u64, advisor: impl Into<String>) -> StrategyAdviceBuilder {
        let timestamp_ms = epoch_millis();
        let advisor = advisor.into();
        StrategyAdviceBuilder {
            timestamp_ms,
            tick,
            advisor,
            current_strategy: String::default(),
            recommended: String::default(),
            should_change: false,
            confidence: 0.0,
            reason: String::default(),
            frontier_count: 0,
            total_visits: 0,
            failure_rate: 0.0,
            latency_ms: 0,
            success: true,
            error: None,
        }
    }

    /// Starts a builder for a [`LearningEvent::DependencyGraphInference`] event.
    ///
    /// The timestamp is taken from `epoch_millis()` right away and `model` is
    /// stored as given. Strings and lists start empty, `lora` and `error` are
    /// `None`, `latency_ms` is 0 and `success` is preset to `true`.
    pub fn dependency_graph_inference(model: impl Into<String>) -> DependencyGraphInferenceBuilder {
        let timestamp_ms = epoch_millis();
        let model = model.into();
        DependencyGraphInferenceBuilder {
            timestamp_ms,
            prompt: String::default(),
            response: String::default(),
            available_actions: Vec::default(),
            discover_order: Vec::default(),
            not_discover_order: Vec::default(),
            model,
            endpoint: String::default(),
            lora: None,
            latency_ms: 0,
            success: true,
            error: None,
        }
    }

    /// Starts a builder for a [`LearningEvent::LearnStatsSnapshot`] event.
    ///
    /// The timestamp is taken from `epoch_millis()` right away and `scenario`
    /// is stored as given. The outcome defaults to
    /// `LearnStatsOutcome::Success { score: 0.0 }`; counters start at 0 and
    /// the remaining strings start empty.
    pub fn learn_stats_snapshot(scenario: impl Into<String>) -> LearnStatsSnapshotBuilder {
        let timestamp_ms = epoch_millis();
        let scenario = scenario.into();
        let outcome = LearnStatsOutcome::Success { score: 0.0 };
        LearnStatsSnapshotBuilder {
            timestamp_ms,
            scenario,
            session_id: String::default(),
            stats_json: String::default(),
            outcome,
            total_ticks: 0,
            total_actions: 0,
        }
    }
}
/// Builder for the `StrategyAdvice` variant of `LearningEvent`.
///
/// Created by `LearningEvent::strategy_advice`; each setter consumes and
/// returns the builder, and `build()` produces the finished event.
///
/// Derives `Debug` and `Clone` so a partially configured builder can be
/// logged or reused as a template. Marked `#[must_use]` because dropping a
/// builder without calling `build()` silently discards the event.
#[derive(Debug, Clone)]
#[must_use = "a builder does nothing until `build()` is called"]
pub struct StrategyAdviceBuilder {
    /// Captured from `epoch_millis()` when the builder is created.
    timestamp_ms: u64,
    tick: u64,
    advisor: String,
    current_strategy: String,
    recommended: String,
    should_change: bool,
    confidence: f64,
    reason: String,
    frontier_count: usize,
    total_visits: u32,
    failure_rate: f64,
    latency_ms: u64,
    /// Kept in step with `error` by the `success()` / `failure(..)` setters.
    success: bool,
    error: Option<String>,
}
impl StrategyAdviceBuilder {
    /// Sets the strategy currently in use.
    pub fn current_strategy(self, strategy: impl Into<String>) -> Self {
        Self {
            current_strategy: strategy.into(),
            ..self
        }
    }

    /// Sets the strategy being recommended.
    pub fn recommended(self, strategy: impl Into<String>) -> Self {
        Self {
            recommended: strategy.into(),
            ..self
        }
    }

    /// Sets the `should_change` flag.
    pub fn should_change(self, should: bool) -> Self {
        Self {
            should_change: should,
            ..self
        }
    }

    /// Sets the confidence value; stored as given, without range checks.
    pub fn confidence(self, conf: f64) -> Self {
        Self {
            confidence: conf,
            ..self
        }
    }

    /// Sets the free-form reason text.
    pub fn reason(self, reason: impl Into<String>) -> Self {
        Self {
            reason: reason.into(),
            ..self
        }
    }

    /// Sets the frontier count.
    pub fn frontier_count(self, count: usize) -> Self {
        Self {
            frontier_count: count,
            ..self
        }
    }

    /// Sets the total visit count.
    pub fn total_visits(self, visits: u32) -> Self {
        Self {
            total_visits: visits,
            ..self
        }
    }

    /// Sets the failure rate; stored as given, without range checks.
    pub fn failure_rate(self, rate: f64) -> Self {
        Self {
            failure_rate: rate,
            ..self
        }
    }

    /// Sets the latency in milliseconds.
    pub fn latency_ms(self, ms: u64) -> Self {
        Self {
            latency_ms: ms,
            ..self
        }
    }

    /// Marks the event as successful and clears any recorded error.
    pub fn success(self) -> Self {
        Self {
            success: true,
            error: None,
            ..self
        }
    }

    /// Marks the event as failed and records `error` as the message.
    pub fn failure(self, error: impl Into<String>) -> Self {
        Self {
            success: false,
            error: Some(error.into()),
            ..self
        }
    }

    /// Consumes the builder and returns the `LearningEvent::StrategyAdvice`
    /// carrying every field exactly as configured.
    pub fn build(self) -> LearningEvent {
        let Self {
            timestamp_ms,
            tick,
            advisor,
            current_strategy,
            recommended,
            should_change,
            confidence,
            reason,
            frontier_count,
            total_visits,
            failure_rate,
            latency_ms,
            success,
            error,
        } = self;
        LearningEvent::StrategyAdvice {
            timestamp_ms,
            tick,
            advisor,
            current_strategy,
            recommended,
            should_change,
            confidence,
            reason,
            frontier_count,
            total_visits,
            failure_rate,
            latency_ms,
            success,
            error,
        }
    }
}
/// Builder for the `DependencyGraphInference` variant of `LearningEvent`.
///
/// Created by `LearningEvent::dependency_graph_inference`; each setter
/// consumes and returns the builder, and `build()` produces the finished
/// event.
///
/// Derives `Debug` and `Clone` so a partially configured builder can be
/// logged or reused as a template. Marked `#[must_use]` because dropping a
/// builder without calling `build()` silently discards the event.
#[derive(Debug, Clone)]
#[must_use = "a builder does nothing until `build()` is called"]
pub struct DependencyGraphInferenceBuilder {
    /// Captured from `epoch_millis()` when the builder is created.
    timestamp_ms: u64,
    prompt: String,
    response: String,
    available_actions: Vec<String>,
    discover_order: Vec<String>,
    not_discover_order: Vec<String>,
    model: String,
    endpoint: String,
    /// `None` until the `lora(..)` setter is called.
    lora: Option<String>,
    latency_ms: u64,
    /// Kept in step with `error` by the `success()` / `failure(..)` setters.
    success: bool,
    error: Option<String>,
}
impl DependencyGraphInferenceBuilder {
    /// Sets the prompt text.
    pub fn prompt(self, prompt: impl Into<String>) -> Self {
        Self {
            prompt: prompt.into(),
            ..self
        }
    }

    /// Sets the response text.
    pub fn response(self, response: impl Into<String>) -> Self {
        Self {
            response: response.into(),
            ..self
        }
    }

    /// Replaces the list of available actions.
    pub fn available_actions(self, actions: Vec<String>) -> Self {
        Self {
            available_actions: actions,
            ..self
        }
    }

    /// Replaces the `discover_order` list.
    pub fn discover_order(self, order: Vec<String>) -> Self {
        Self {
            discover_order: order,
            ..self
        }
    }

    /// Replaces the `not_discover_order` list.
    pub fn not_discover_order(self, order: Vec<String>) -> Self {
        Self {
            not_discover_order: order,
            ..self
        }
    }

    /// Sets the endpoint string.
    pub fn endpoint(self, endpoint: impl Into<String>) -> Self {
        Self {
            endpoint: endpoint.into(),
            ..self
        }
    }

    /// Records a LoRA identifier; the field stays `None` if this is never called.
    pub fn lora(self, lora: impl Into<String>) -> Self {
        Self {
            lora: Some(lora.into()),
            ..self
        }
    }

    /// Sets the latency in milliseconds.
    pub fn latency_ms(self, ms: u64) -> Self {
        Self {
            latency_ms: ms,
            ..self
        }
    }

    /// Marks the event as successful and clears any recorded error.
    pub fn success(self) -> Self {
        Self {
            success: true,
            error: None,
            ..self
        }
    }

    /// Marks the event as failed and records `error` as the message.
    pub fn failure(self, error: impl Into<String>) -> Self {
        Self {
            success: false,
            error: Some(error.into()),
            ..self
        }
    }

    /// Consumes the builder and returns the
    /// `LearningEvent::DependencyGraphInference` carrying every field exactly
    /// as configured.
    pub fn build(self) -> LearningEvent {
        let Self {
            timestamp_ms,
            prompt,
            response,
            available_actions,
            discover_order,
            not_discover_order,
            model,
            endpoint,
            lora,
            latency_ms,
            success,
            error,
        } = self;
        LearningEvent::DependencyGraphInference {
            timestamp_ms,
            prompt,
            response,
            available_actions,
            discover_order,
            not_discover_order,
            model,
            endpoint,
            lora,
            latency_ms,
            success,
            error,
        }
    }
}
/// Builder for the `LearnStatsSnapshot` variant of `LearningEvent`.
///
/// Created by `LearningEvent::learn_stats_snapshot`; each setter consumes and
/// returns the builder, and `build()` produces the finished event.
///
/// Derives `Debug` and `Clone` so a partially configured builder can be
/// logged or reused as a template. Marked `#[must_use]` because dropping a
/// builder without calling `build()` silently discards the event.
#[derive(Debug, Clone)]
#[must_use = "a builder does nothing until `build()` is called"]
pub struct LearnStatsSnapshotBuilder {
    /// Captured from `epoch_millis()` when the builder is created.
    timestamp_ms: u64,
    scenario: String,
    session_id: String,
    stats_json: String,
    /// Replaced wholesale by the `success` / `failure` / `timeout` setters.
    outcome: LearnStatsOutcome,
    total_ticks: u64,
    total_actions: u64,
}
impl LearnStatsSnapshotBuilder {
    /// Sets the session identifier.
    pub fn session_id(self, id: impl Into<String>) -> Self {
        Self {
            session_id: id.into(),
            ..self
        }
    }

    /// Sets the stats payload string; stored as given, without validation.
    pub fn stats_json(self, json: impl Into<String>) -> Self {
        Self {
            stats_json: json.into(),
            ..self
        }
    }

    /// Sets the outcome to `LearnStatsOutcome::Success` with `score`.
    pub fn success(self, score: f64) -> Self {
        let outcome = LearnStatsOutcome::Success { score };
        Self { outcome, ..self }
    }

    /// Sets the outcome to `LearnStatsOutcome::Failure` with `reason`.
    pub fn failure(self, reason: impl Into<String>) -> Self {
        let reason = reason.into();
        let outcome = LearnStatsOutcome::Failure { reason };
        Self { outcome, ..self }
    }

    /// Sets the outcome to `LearnStatsOutcome::Timeout` with an optional partial score.
    pub fn timeout(self, partial_score: Option<f64>) -> Self {
        let outcome = LearnStatsOutcome::Timeout { partial_score };
        Self { outcome, ..self }
    }

    /// Sets the total tick count.
    pub fn total_ticks(self, ticks: u64) -> Self {
        Self {
            total_ticks: ticks,
            ..self
        }
    }

    /// Sets the total action count.
    pub fn total_actions(self, actions: u64) -> Self {
        Self {
            total_actions: actions,
            ..self
        }
    }

    /// Consumes the builder and returns the `LearningEvent::LearnStatsSnapshot`
    /// carrying every field exactly as configured.
    pub fn build(self) -> LearningEvent {
        let Self {
            timestamp_ms,
            scenario,
            session_id,
            stats_json,
            outcome,
            total_ticks,
            total_actions,
        } = self;
        LearningEvent::LearnStatsSnapshot {
            timestamp_ms,
            scenario,
            session_id,
            stats_json,
            outcome,
            total_ticks,
            total_actions,
        }
    }
}
/// Publishing point for `LearningEvent`s: a tokio broadcast sender plus a
/// mutex-guarded buffer that can be drained without an async receiver.
pub struct LearningEventChannel {
/// Broadcast sender that `emit` forwards events to and `subscribe` creates receivers from.
tx: broadcast::Sender<LearningEvent>,
/// Gate checked by `emit`; initialized to `false` in `new`.
enabled: AtomicBool,
/// Last value stored via `set_tick`; initialized to 0 in `new`.
current_tick: AtomicU64,
/// Clones of emitted events, held until `drain_sync` takes them.
sync_buffer: Mutex<Vec<LearningEvent>>,
}
impl LearningEventChannel {
    /// Creates a disabled channel whose broadcast queue holds `capacity` events.
    ///
    /// The receiver returned by `broadcast::channel` is dropped immediately,
    /// so the channel starts with no subscribers. The tick counter starts at
    /// 0 and the synchronous buffer starts empty.
    ///
    /// # Panics
    ///
    /// `tokio::sync::broadcast::channel` panics when `capacity` is 0 or
    /// larger than `usize::MAX / 2`; this constructor does not guard against
    /// that.
    pub fn new(capacity: usize) -> Self {
        let (tx, _) = broadcast::channel(capacity);
        Self {
            tx,
            enabled: AtomicBool::new(false),
            current_tick: AtomicU64::new(0),
            sync_buffer: Mutex::new(Vec::new()),
        }
    }

    /// Returns the process-wide channel, created on first use with capacity 256.
    pub fn global() -> &'static Self {
        static INSTANCE: OnceLock<LearningEventChannel> = OnceLock::new();
        INSTANCE.get_or_init(|| Self::new(256))
    }

    /// Turns event emission on.
    pub fn enable(&self) {
        self.enabled.store(true, Ordering::Relaxed);
    }

    /// Turns event emission off; later `emit` calls become no-ops.
    pub fn disable(&self) {
        self.enabled.store(false, Ordering::Relaxed);
    }

    /// Reports whether `emit` currently records and forwards events.
    pub fn is_enabled(&self) -> bool {
        self.enabled.load(Ordering::Relaxed)
    }

    /// Stores `tick` as the current tick.
    ///
    /// The value is only exposed through `current_tick`; `emit` does not
    /// stamp it onto events.
    pub fn set_tick(&self, tick: u64) {
        self.current_tick.store(tick, Ordering::Relaxed);
    }

    /// Returns the value most recently stored by `set_tick` (0 initially).
    pub fn current_tick(&self) -> u64 {
        self.current_tick.load(Ordering::Relaxed)
    }

    /// Records `event` in the synchronous buffer and broadcasts it, but only
    /// while the channel is enabled.
    ///
    /// The buffer grows until `drain_sync` is called, so an enabled channel
    /// that is never drained keeps every emitted event in memory.
    pub fn emit(&self, event: LearningEvent) {
        if !self.is_enabled() {
            return;
        }
        // A poisoned lock only means another thread panicked while holding
        // it. The Vec is still structurally valid after a push or take, so
        // recover the guard instead of silently dropping the event.
        self.sync_buffer
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .push(event.clone());
        // The guard above is released at the end of that statement, so the
        // lock is not held while broadcasting. `send` fails only when there
        // are no active receivers, which is an expected state here.
        let _ = self.tx.send(event);
    }

    /// Removes and returns every buffered event in emission order, leaving
    /// the buffer empty.
    pub fn drain_sync(&self) -> Vec<LearningEvent> {
        // Recover from poisoning for the same reason as in `emit`; returning
        // an empty Vec here would strand the buffered events forever.
        let mut buffer = self
            .sync_buffer
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        std::mem::take(&mut *buffer)
    }

    /// Creates a new broadcast receiver for events emitted from now on.
    pub fn subscribe(&self) -> broadcast::Receiver<LearningEvent> {
        self.tx.subscribe()
    }

    /// Returns the number of currently active broadcast receivers.
    pub fn receiver_count(&self) -> usize {
        self.tx.receiver_count()
    }
}
impl Default for LearningEventChannel {
fn default() -> Self {
Self::new(256)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `StrategyAdvice` event with only the strategy names set.
    fn advice(tick: u64, advisor: &str, current: &str, recommended: &str) -> LearningEvent {
        LearningEvent::strategy_advice(tick, advisor)
            .current_strategy(current)
            .recommended(recommended)
            .build()
    }

    /// Extracts the tick of a `StrategyAdvice` event, failing the test on any other variant.
    fn tick_of(event: &LearningEvent) -> u64 {
        match event {
            LearningEvent::StrategyAdvice { tick, .. } => *tick,
            other => panic!("Expected StrategyAdvice, got {other:?}"),
        }
    }

    #[test]
    fn test_channel_disabled_by_default() {
        let channel = LearningEventChannel::new(16);
        assert!(!channel.is_enabled());
    }

    #[test]
    fn test_channel_enable_disable() {
        let channel = LearningEventChannel::new(16);
        channel.enable();
        assert!(channel.is_enabled());
        channel.disable();
        assert!(!channel.is_enabled());
    }

    #[tokio::test]
    async fn test_channel_emit_when_enabled() {
        let channel = LearningEventChannel::new(16);
        channel.enable();
        let mut rx = channel.subscribe();
        let event = LearningEvent::strategy_advice(42, "TestAdvisor")
            .current_strategy("ucb1")
            .recommended("greedy")
            .should_change(true)
            .confidence(0.9)
            .reason("test reason")
            .frontier_count(10)
            .total_visits(100)
            .failure_rate(0.1)
            .latency_ms(50)
            .success()
            .build();
        channel.emit(event);
        let received = rx.recv().await.unwrap();
        match received {
            LearningEvent::StrategyAdvice {
                tick,
                advisor,
                should_change,
                ..
            } => {
                assert_eq!(tick, 42);
                assert_eq!(advisor, "TestAdvisor");
                assert!(should_change);
            }
            other => panic!("Expected StrategyAdvice, got {other:?}"),
        }
    }

    // `try_recv` reports an empty queue immediately, so the check does not
    // depend on a wall-clock timeout or on a running timer driver.
    #[test]
    fn test_channel_no_emit_when_disabled() {
        let channel = LearningEventChannel::new(16);
        let mut rx = channel.subscribe();
        channel.emit(advice(0, "Test", "ucb1", "ucb1"));
        assert!(matches!(
            rx.try_recv(),
            Err(tokio::sync::broadcast::error::TryRecvError::Empty)
        ));
    }

    #[test]
    fn test_tick_management() {
        let channel = LearningEventChannel::new(16);
        assert_eq!(channel.current_tick(), 0);
        channel.set_tick(42);
        assert_eq!(channel.current_tick(), 42);
        channel.set_tick(100);
        assert_eq!(channel.current_tick(), 100);
    }

    #[test]
    fn test_drain_sync() {
        let channel = LearningEventChannel::new(16);
        channel.enable();
        channel.emit(advice(1, "Advisor1", "ucb1", "greedy"));
        channel.emit(advice(2, "Advisor2", "greedy", "thompson"));

        let events = channel.drain_sync();
        assert_eq!(events.len(), 2);
        // Events come back in emission order.
        assert_eq!(tick_of(&events[0]), 1);
        assert_eq!(tick_of(&events[1]), 2);

        // A second drain finds nothing left.
        let events2 = channel.drain_sync();
        assert!(events2.is_empty());
    }

    #[test]
    fn test_drain_sync_disabled() {
        let channel = LearningEventChannel::new(16);
        channel.emit(advice(1, "Advisor", "ucb1", "ucb1"));
        let events = channel.drain_sync();
        assert!(events.is_empty());
    }
}