use tokio::sync::{broadcast, mpsc};
use crate::events::{ActionEvent, LearningEvent};
use crate::learn::record::Record;
/// Configuration shared by the event subscribers: how many records to
/// accumulate before a size-triggered flush, and (optionally) a time-based
/// flush interval for latency-bounded delivery.
#[derive(Debug, Clone)]
pub struct EventSubscriberConfig {
    /// Number of buffered records that triggers an immediate flush.
    pub batch_size: usize,
    /// Periodic flush interval in milliseconds; `None` disables time-based
    /// flushing so records are only sent when a batch fills up (plus one
    /// final flush on shutdown).
    pub flush_interval_ms: Option<u64>,
}

impl Default for EventSubscriberConfig {
    /// Defaults: batches of 100 records, flushed at least once per second.
    fn default() -> Self {
        Self {
            batch_size: 100,
            flush_interval_ms: Some(1000),
        }
    }
}

impl EventSubscriberConfig {
    /// Creates a configuration with the default settings; chain the builder
    /// methods below to customize it.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the number of buffered records that triggers a flush.
    ///
    /// The builders consume and return `self`; discarding the return value
    /// drops the configuration, hence `#[must_use]`.
    #[must_use]
    pub fn batch_size(mut self, size: usize) -> Self {
        self.batch_size = size;
        self
    }

    /// Enables time-based flushing every `ms` milliseconds.
    #[must_use]
    pub fn flush_interval_ms(mut self, ms: u64) -> Self {
        self.flush_interval_ms = Some(ms);
        self
    }

    /// Disables time-based flushing; records are flushed only when a batch
    /// fills up (and once more on shutdown).
    #[must_use]
    pub fn no_flush_interval(mut self) -> Self {
        self.flush_interval_ms = None;
        self
    }
}
/// Subscribes to the broadcast stream of `ActionEvent`s, converts each event
/// into a [`Record`], and forwards the records in batches over an mpsc
/// channel (presumably consumed by a "LearningDaemon", per the log messages —
/// confirm against the receiver's definition).
pub struct ActionEventSubscriber {
    /// Broadcast receiver for incoming action events.
    rx: broadcast::Receiver<ActionEvent>,
    /// Downstream channel on which batched records are delivered.
    record_tx: mpsc::Sender<Vec<Record>>,
    /// Batching / flush-timing configuration.
    config: EventSubscriberConfig,
    /// Records accumulated since the last flush; pre-sized to `batch_size`.
    buffer: Vec<Record>,
}

impl ActionEventSubscriber {
    /// Creates a subscriber with the default configuration
    /// (batch size 100, 1000 ms flush interval).
    pub fn new(rx: broadcast::Receiver<ActionEvent>, record_tx: mpsc::Sender<Vec<Record>>) -> Self {
        Self::with_config(rx, record_tx, EventSubscriberConfig::default())
    }

    /// Creates a subscriber with an explicit configuration.
    pub fn with_config(
        rx: broadcast::Receiver<ActionEvent>,
        record_tx: mpsc::Sender<Vec<Record>>,
        config: EventSubscriberConfig,
    ) -> Self {
        // Copy the size out first so `config` can still be moved into `Self`.
        let batch_size = config.batch_size;
        Self {
            rx,
            record_tx,
            config,
            // Pre-allocate so a full batch normally never reallocates.
            buffer: Vec::with_capacity(batch_size),
        }
    }

    /// Runs the subscriber until the broadcast channel closes or the
    /// downstream record channel is dropped. Consumes `self`; any records
    /// still buffered at shutdown are flushed before returning.
    pub async fn run(mut self) {
        tracing::info!(
            batch_size = self.config.batch_size,
            flush_interval_ms = ?self.config.flush_interval_ms,
            "ActionEventSubscriber started"
        );
        // Two loop variants: with a periodic flush timer, or batch-size-only.
        if let Some(interval_ms) = self.config.flush_interval_ms {
            self.run_with_flush_interval(interval_ms).await;
        } else {
            self.run_batch_only().await;
        }
        // Final best-effort flush of whatever remains in the buffer.
        self.flush().await;
        tracing::info!("ActionEventSubscriber stopped");
    }

    /// Event loop that flushes either when the buffer reaches `batch_size`
    /// or when `interval_ms` has elapsed since the last flush, whichever
    /// comes first.
    async fn run_with_flush_interval(&mut self, interval_ms: u64) {
        use std::time::Duration;
        use tokio::time::{interval, Instant};
        let mut flush_interval = interval(Duration::from_millis(interval_ms));
        let mut last_flush = Instant::now();
        loop {
            tokio::select! {
                result = self.rx.recv() => {
                    match result {
                        Ok(event) => {
                            self.buffer.push(Record::from(&event));
                            // Size-triggered flush; a `false` return means the
                            // downstream channel is closed, so stop entirely.
                            if self.buffer.len() >= self.config.batch_size {
                                if !self.flush().await {
                                    return;
                                }
                                last_flush = Instant::now();
                            }
                        }
                        Err(broadcast::error::RecvError::Closed) => {
                            // All senders dropped: normal shutdown path.
                            tracing::debug!("ActionEvent channel closed");
                            return;
                        }
                        Err(broadcast::error::RecvError::Lagged(n)) => {
                            // The broadcast ring buffer overwrote `n` events we
                            // never saw; log and keep receiving from the new
                            // position. Those events are permanently dropped.
                            tracing::warn!(lagged = n, "ActionEventSubscriber lagged behind");
                        }
                    }
                }
                _ = flush_interval.tick() => {
                    // The elapsed guard suppresses redundant timer flushes right
                    // after a size-triggered flush (and the interval's immediate
                    // first tick). NOTE(review): because the tick and `last_flush`
                    // clocks are independent, `elapsed()` can land just under
                    // `interval_ms` and postpone a flush to the next tick — worst
                    // case latency is ~2x the interval; confirm this is acceptable.
                    if !self.buffer.is_empty() && last_flush.elapsed().as_millis() as u64 >= interval_ms {
                        if !self.flush().await {
                            return;
                        }
                        last_flush = Instant::now();
                    }
                }
            }
        }
    }

    /// Event loop used when no flush interval is configured: records are
    /// flushed only when the buffer reaches `batch_size`.
    async fn run_batch_only(&mut self) {
        loop {
            match self.rx.recv().await {
                Ok(event) => {
                    self.buffer.push(Record::from(&event));
                    // Stop if the downstream channel is gone (flush returned false).
                    if self.buffer.len() >= self.config.batch_size && !self.flush().await {
                        return;
                    }
                }
                Err(broadcast::error::RecvError::Closed) => {
                    tracing::debug!("ActionEvent channel closed");
                    return;
                }
                Err(broadcast::error::RecvError::Lagged(n)) => {
                    // `n` events were overwritten before we could read them.
                    tracing::warn!(lagged = n, "ActionEventSubscriber lagged behind");
                }
            }
        }
    }

    /// Sends the buffered records downstream. Returns `true` on success (or
    /// when there was nothing to send), `false` if the receiver was dropped —
    /// callers treat `false` as a signal to shut down.
    async fn flush(&mut self) -> bool {
        if self.buffer.is_empty() {
            return true;
        }
        // `mem::take` moves the records out and leaves a fresh empty Vec,
        // avoiding a clone. Note: the replacement Vec has no pre-reserved
        // capacity; it regrows as events arrive.
        let records = std::mem::take(&mut self.buffer);
        let count = records.len();
        match self.record_tx.send(records).await {
            Ok(()) => {
                tracing::debug!(count, "Flushed ActionEvent records to LearningDaemon");
                true
            }
            Err(_) => {
                // Receiver dropped; the records in this batch are lost.
                tracing::warn!("LearningDaemon channel closed");
                false
            }
        }
    }
}
/// Subscribes to the broadcast stream of `LearningEvent`s, converts each
/// event into a [`Record`], and forwards the records in batches over an mpsc
/// channel. Structurally identical to `ActionEventSubscriber`, differing only
/// in the event type consumed.
pub struct LearningEventSubscriber {
    /// Broadcast receiver for incoming learning events.
    rx: broadcast::Receiver<LearningEvent>,
    /// Downstream channel on which batched records are delivered.
    record_tx: mpsc::Sender<Vec<Record>>,
    /// Batching / flush-timing configuration.
    config: EventSubscriberConfig,
    /// Records accumulated since the last flush; pre-sized to `batch_size`.
    buffer: Vec<Record>,
}

impl LearningEventSubscriber {
    /// Creates a subscriber with the default configuration
    /// (batch size 100, 1000 ms flush interval).
    pub fn new(
        rx: broadcast::Receiver<LearningEvent>,
        record_tx: mpsc::Sender<Vec<Record>>,
    ) -> Self {
        Self::with_config(rx, record_tx, EventSubscriberConfig::default())
    }

    /// Creates a subscriber with an explicit configuration.
    pub fn with_config(
        rx: broadcast::Receiver<LearningEvent>,
        record_tx: mpsc::Sender<Vec<Record>>,
        config: EventSubscriberConfig,
    ) -> Self {
        // Copy the size out first so `config` can still be moved into `Self`.
        let batch_size = config.batch_size;
        Self {
            rx,
            record_tx,
            config,
            // Pre-allocate so a full batch normally never reallocates.
            buffer: Vec::with_capacity(batch_size),
        }
    }

    /// Runs the subscriber until the broadcast channel closes or the
    /// downstream record channel is dropped. Consumes `self`; any records
    /// still buffered at shutdown are flushed before returning.
    pub async fn run(mut self) {
        tracing::info!(
            batch_size = self.config.batch_size,
            flush_interval_ms = ?self.config.flush_interval_ms,
            "LearningEventSubscriber started"
        );
        // Two loop variants: with a periodic flush timer, or batch-size-only.
        if let Some(interval_ms) = self.config.flush_interval_ms {
            self.run_with_flush_interval(interval_ms).await;
        } else {
            self.run_batch_only().await;
        }
        // Final best-effort flush of whatever remains in the buffer.
        self.flush().await;
        tracing::info!("LearningEventSubscriber stopped");
    }

    /// Event loop that flushes either when the buffer reaches `batch_size`
    /// or when `interval_ms` has elapsed since the last flush, whichever
    /// comes first.
    async fn run_with_flush_interval(&mut self, interval_ms: u64) {
        use std::time::Duration;
        use tokio::time::{interval, Instant};
        let mut flush_interval = interval(Duration::from_millis(interval_ms));
        let mut last_flush = Instant::now();
        loop {
            tokio::select! {
                result = self.rx.recv() => {
                    match result {
                        Ok(event) => {
                            self.buffer.push(Record::from(&event));
                            // Size-triggered flush; a `false` return means the
                            // downstream channel is closed, so stop entirely.
                            if self.buffer.len() >= self.config.batch_size {
                                if !self.flush().await {
                                    return;
                                }
                                last_flush = Instant::now();
                            }
                        }
                        Err(broadcast::error::RecvError::Closed) => {
                            // All senders dropped: normal shutdown path.
                            tracing::debug!("LearningEvent channel closed");
                            return;
                        }
                        Err(broadcast::error::RecvError::Lagged(n)) => {
                            // The broadcast ring buffer overwrote `n` events we
                            // never saw; log and keep receiving from the new
                            // position. Those events are permanently dropped.
                            tracing::warn!(lagged = n, "LearningEventSubscriber lagged behind");
                        }
                    }
                }
                _ = flush_interval.tick() => {
                    // The elapsed guard suppresses redundant timer flushes right
                    // after a size-triggered flush (and the interval's immediate
                    // first tick). NOTE(review): same near-miss latency caveat as
                    // ActionEventSubscriber — a flush can slip to the next tick.
                    if !self.buffer.is_empty() && last_flush.elapsed().as_millis() as u64 >= interval_ms {
                        if !self.flush().await {
                            return;
                        }
                        last_flush = Instant::now();
                    }
                }
            }
        }
    }

    /// Event loop used when no flush interval is configured: records are
    /// flushed only when the buffer reaches `batch_size`.
    async fn run_batch_only(&mut self) {
        loop {
            match self.rx.recv().await {
                Ok(event) => {
                    self.buffer.push(Record::from(&event));
                    // Stop if the downstream channel is gone (flush returned false).
                    if self.buffer.len() >= self.config.batch_size && !self.flush().await {
                        return;
                    }
                }
                Err(broadcast::error::RecvError::Closed) => {
                    tracing::debug!("LearningEvent channel closed");
                    return;
                }
                Err(broadcast::error::RecvError::Lagged(n)) => {
                    // `n` events were overwritten before we could read them.
                    tracing::warn!(lagged = n, "LearningEventSubscriber lagged behind");
                }
            }
        }
    }

    /// Sends the buffered records downstream. Returns `true` on success (or
    /// when there was nothing to send), `false` if the receiver was dropped —
    /// callers treat `false` as a signal to shut down.
    async fn flush(&mut self) -> bool {
        if self.buffer.is_empty() {
            return true;
        }
        // `mem::take` moves the records out and leaves a fresh empty Vec,
        // avoiding a clone. Note: the replacement Vec has no pre-reserved
        // capacity; it regrows as events arrive.
        let records = std::mem::take(&mut self.buffer);
        let count = records.len();
        match self.record_tx.send(records).await {
            Ok(()) => {
                tracing::debug!(count, "Flushed LearningEvent records to LearningDaemon");
                true
            }
            Err(_) => {
                // Receiver dropped; the records in this batch are lost.
                tracing::warn!("LearningDaemon channel closed");
                false
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::events::{ActionEventBuilder, ActionEventResult, LearningEvent};
    use crate::types::WorkerId;
    use std::time::Duration;

    // Builds a minimal successful ActionEvent for a given tick/action name.
    fn make_action_event(tick: u64, action: &str) -> ActionEvent {
        ActionEventBuilder::new(tick, WorkerId(0), action)
            .result(ActionEventResult::success())
            .duration(Duration::from_millis(10))
            .build()
    }

    // Builds a minimal successful dependency-graph LearningEvent.
    fn make_learning_event(model: &str) -> LearningEvent {
        LearningEvent::dependency_graph_inference(model)
            .prompt("test prompt")
            .response("test response")
            .discover_order(vec!["A".into(), "B".into()])
            .success()
            .build()
    }

    // Size-only batching: 5 events with batch_size=3 should produce one
    // 3-record batch immediately, plus a 2-record batch flushed on shutdown.
    // NOTE(review): the 50 ms sleep is a real-time wait and could flake on a
    // heavily loaded machine.
    #[tokio::test]
    async fn test_action_subscriber_batch() {
        let (tx, rx) = broadcast::channel::<ActionEvent>(16);
        let (record_tx, mut record_rx) = mpsc::channel::<Vec<Record>>(16);
        let config = EventSubscriberConfig::new()
            .batch_size(3)
            .no_flush_interval();
        let subscriber = ActionEventSubscriber::with_config(rx, record_tx, config);
        let handle = tokio::spawn(async move {
            subscriber.run().await;
        });
        for i in 0..5 {
            tx.send(make_action_event(i, &format!("Action{}", i)))
                .unwrap();
        }
        tokio::time::sleep(Duration::from_millis(50)).await;
        let batch = record_rx.try_recv().unwrap();
        assert_eq!(batch.len(), 3);
        // Dropping the sender closes the broadcast channel, which triggers
        // the subscriber's final flush of the remaining 2 records.
        drop(tx);
        let _ = handle.await;
        let batch = record_rx.try_recv().unwrap();
        assert_eq!(batch.len(), 2);
    }

    // Time-based flushing: batch_size is too large to trigger, so the 50 ms
    // interval must flush both records within the 100 ms wait.
    #[tokio::test]
    async fn test_action_subscriber_flush_interval() {
        let (tx, rx) = broadcast::channel::<ActionEvent>(16);
        let (record_tx, mut record_rx) = mpsc::channel::<Vec<Record>>(16);
        let config = EventSubscriberConfig::new()
            .batch_size(100)
            .flush_interval_ms(50);
        let subscriber = ActionEventSubscriber::with_config(rx, record_tx, config);
        let handle = tokio::spawn(async move {
            subscriber.run().await;
        });
        tx.send(make_action_event(0, "Action0")).unwrap();
        tx.send(make_action_event(1, "Action1")).unwrap();
        tokio::time::sleep(Duration::from_millis(100)).await;
        let batch = record_rx.try_recv().unwrap();
        assert_eq!(batch.len(), 2);
        drop(tx);
        let _ = handle.await;
    }

    // Dropping the record receiver must make the subscriber shut down
    // gracefully (flush returns false) instead of panicking or hanging.
    #[tokio::test]
    async fn test_action_subscriber_channel_closed() {
        let (tx, rx) = broadcast::channel::<ActionEvent>(16);
        let (record_tx, record_rx) = mpsc::channel::<Vec<Record>>(16);
        let config = EventSubscriberConfig::new()
            .batch_size(100)
            .no_flush_interval();
        let subscriber = ActionEventSubscriber::with_config(rx, record_tx, config);
        let handle = tokio::spawn(async move {
            subscriber.run().await;
        });
        tx.send(make_action_event(0, "Action0")).unwrap();
        drop(record_rx);
        tx.send(make_action_event(1, "Action1")).unwrap();
        tokio::time::sleep(Duration::from_millis(50)).await;
        drop(tx);
        let _ = handle.await;
    }

    // Size-only batching for the learning subscriber, and checks that the
    // converted records are dependency-graph records.
    #[tokio::test]
    async fn test_learning_subscriber_batch() {
        let (tx, rx) = broadcast::channel::<LearningEvent>(16);
        let (record_tx, mut record_rx) = mpsc::channel::<Vec<Record>>(16);
        let config = EventSubscriberConfig::new()
            .batch_size(2)
            .no_flush_interval();
        let subscriber = LearningEventSubscriber::with_config(rx, record_tx, config);
        let handle = tokio::spawn(async move {
            subscriber.run().await;
        });
        for i in 0..3 {
            tx.send(make_learning_event(&format!("model{}", i)))
                .unwrap();
        }
        tokio::time::sleep(Duration::from_millis(50)).await;
        let batch = record_rx.try_recv().unwrap();
        assert_eq!(batch.len(), 2);
        for record in &batch {
            assert!(record.is_dependency_graph());
        }
        // Closing the channel flushes the last buffered record.
        drop(tx);
        let _ = handle.await;
        let batch = record_rx.try_recv().unwrap();
        assert_eq!(batch.len(), 1);
    }

    // Time-based flushing for the learning subscriber.
    #[tokio::test]
    async fn test_learning_subscriber_flush_interval() {
        let (tx, rx) = broadcast::channel::<LearningEvent>(16);
        let (record_tx, mut record_rx) = mpsc::channel::<Vec<Record>>(16);
        let config = EventSubscriberConfig::new()
            .batch_size(100)
            .flush_interval_ms(50);
        let subscriber = LearningEventSubscriber::with_config(rx, record_tx, config);
        let handle = tokio::spawn(async move {
            subscriber.run().await;
        });
        tx.send(make_learning_event("model")).unwrap();
        tokio::time::sleep(Duration::from_millis(100)).await;
        let batch = record_rx.try_recv().unwrap();
        assert_eq!(batch.len(), 1);
        assert!(batch[0].is_dependency_graph());
        drop(tx);
        let _ = handle.await;
    }

    // Verifies the event -> Record conversion preserves the event's fields
    // (model, prompt, discover order) in the dependency-graph record.
    #[tokio::test]
    async fn test_learning_subscriber_converts_to_dependency_graph_record() {
        let (tx, rx) = broadcast::channel::<LearningEvent>(16);
        let (record_tx, mut record_rx) = mpsc::channel::<Vec<Record>>(16);
        let config = EventSubscriberConfig::new()
            .batch_size(1)
            .no_flush_interval();
        let subscriber = LearningEventSubscriber::with_config(rx, record_tx, config);
        let handle = tokio::spawn(async move {
            subscriber.run().await;
        });
        tx.send(make_learning_event("test-model")).unwrap();
        tokio::time::sleep(Duration::from_millis(50)).await;
        let batch = record_rx.try_recv().unwrap();
        assert_eq!(batch.len(), 1);
        let record = batch[0].as_dependency_graph().unwrap();
        assert_eq!(record.model, "test-model");
        assert_eq!(record.prompt, "test prompt");
        assert_eq!(record.discover_order, vec!["A", "B"]);
        drop(tx);
        let _ = handle.await;
    }

    // Pure builder test for EventSubscriberConfig — no async machinery needed.
    #[tokio::test]
    async fn test_subscriber_config() {
        let config = EventSubscriberConfig::new()
            .batch_size(50)
            .flush_interval_ms(500);
        assert_eq!(config.batch_size, 50);
        assert_eq!(config.flush_interval_ms, Some(500));
        let config2 = EventSubscriberConfig::new().no_flush_interval();
        assert_eq!(config2.flush_interval_ms, None);
    }
}