use std::sync::Arc;
use tokio::sync::broadcast;
use tokio::task::JoinHandle;
use tracing::{debug, info, warn};
use crate::subscription::change_tracker::ChangeEvent;
use crate::subscription::subscription_manager::SubscriptionManager;
/// Plain-value copy of the broadcaster's counters at the moment it was taken.
#[derive(Debug, Clone)]
pub struct BroadcasterStats {
    /// Number of events successfully pulled from the source channel.
    pub received: u64,
    /// Number of events for which the hand-off to the subscription manager returned.
    pub forwarded: u64,
    /// Running total of events the source channel reported as skipped.
    pub lagged: u64,
}

/// Atomic counters shared between a broadcaster handle and its background task.
#[derive(Debug)]
struct BroadcasterState {
    received: std::sync::atomic::AtomicU64,
    forwarded: std::sync::atomic::AtomicU64,
    lagged: std::sync::atomic::AtomicU64,
}

impl BroadcasterState {
    /// Builds a reference-counted state with every counter starting at zero.
    fn new() -> Arc<Self> {
        let zeroed = || std::sync::atomic::AtomicU64::new(0);
        let fresh = Self {
            received: zeroed(),
            forwarded: zeroed(),
            lagged: zeroed(),
        };
        Arc::new(fresh)
    }

    /// Copies the three counters into a [`BroadcasterStats`] value.
    ///
    /// Each counter is loaded on its own with relaxed ordering, so the fields
    /// are not guaranteed to describe one single instant.
    fn snapshot(&self) -> BroadcasterStats {
        let relaxed = std::sync::atomic::Ordering::Relaxed;
        let received = self.received.load(relaxed);
        let forwarded = self.forwarded.load(relaxed);
        let lagged = self.lagged.load(relaxed);
        BroadcasterStats {
            received,
            forwarded,
            lagged,
        }
    }
}
/// Handle to the background task that forwards change events from a broadcast
/// receiver to a [`SubscriptionManager`].
///
/// The handle owns the sending half of the task's shutdown channel. Dropping the
/// handle drops that sender, and the task's shutdown branch fires on any outcome
/// of the receiving half (including "sender dropped"), so the task stops once the
/// handle is gone. Keep the value alive for as long as forwarding should continue.
#[must_use = "dropping a Broadcaster drops its shutdown sender, which stops the forwarding task"]
pub struct Broadcaster {
    /// Counters shared with the background task.
    state: Arc<BroadcasterState>,
    /// Sending half of the shutdown channel; the task holds the receiving half.
    shutdown_tx: tokio::sync::oneshot::Sender<()>,
    /// Join handle of the spawned forwarding task.
    task: JoinHandle<()>,
}
/// Debug output shows a single `stats` field holding the current counter snapshot.
impl std::fmt::Debug for Broadcaster {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The shutdown sender and the join handle are not part of the output.
        let stats = self.state.snapshot();
        let mut out = f.debug_struct("Broadcaster");
        out.field("stats", &stats);
        out.finish()
    }
}
impl Broadcaster {
pub fn spawn(
source: broadcast::Receiver<Arc<ChangeEvent>>,
manager: Arc<SubscriptionManager>,
) -> Self {
let state = BroadcasterState::new();
let state_clone = Arc::clone(&state);
let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>();
let task = tokio::spawn(run_broadcaster(source, manager, state_clone, shutdown_rx));
Self {
state,
shutdown_tx,
task,
}
}
pub fn shutdown(self) {
let _ = self.shutdown_tx.send(());
}
pub fn into_join_handle(self) -> JoinHandle<()> {
self.task
}
pub fn stats(&self) -> BroadcasterStats {
self.state.snapshot()
}
}
async fn run_broadcaster(
mut source: broadcast::Receiver<Arc<ChangeEvent>>,
manager: Arc<SubscriptionManager>,
state: Arc<BroadcasterState>,
mut shutdown_rx: tokio::sync::oneshot::Receiver<()>,
) {
info!("Broadcaster task started");
loop {
tokio::select! {
_ = &mut shutdown_rx => {
info!("Broadcaster received shutdown signal");
break;
}
result = source.recv() => {
match result {
Ok(event) => {
state.received.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
debug!(sequence = event.sequence, event_type = %event.event_type, "Broadcaster forwarding event");
manager.notify(event).await;
state.forwarded.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
}
Err(broadcast::error::RecvError::Lagged(count)) => {
warn!(count, "Broadcaster lagged; {} events were missed", count);
state.lagged.fetch_add(count, std::sync::atomic::Ordering::Relaxed);
}
Err(broadcast::error::RecvError::Closed) => {
info!("Source broadcast channel closed; broadcaster exiting");
break;
}
}
}
}
}
info!("Broadcaster task stopped");
}
/// Builder that connects a change tracker and a subscription manager into a
/// running [`Broadcaster`].
pub struct BroadcasterBuilder {
    /// Tracker whose `subscribe()` supplies the event receiver; required.
    tracker: Option<Arc<crate::subscription::change_tracker::ChangeTracker>>,
    /// Manager that receives the forwarded events; required.
    manager: Option<Arc<SubscriptionManager>>,
}

impl BroadcasterBuilder {
    /// Creates a builder with neither the tracker nor the manager set.
    pub fn new() -> Self {
        let (tracker, manager) = (None, None);
        Self { tracker, manager }
    }

    /// Sets the change tracker, replacing any tracker set earlier.
    pub fn tracker(
        self,
        tracker: Arc<crate::subscription::change_tracker::ChangeTracker>,
    ) -> Self {
        Self {
            tracker: Some(tracker),
            ..self
        }
    }

    /// Sets the subscription manager, replacing any manager set earlier.
    pub fn manager(self, manager: Arc<SubscriptionManager>) -> Self {
        Self {
            manager: Some(manager),
            ..self
        }
    }

    /// Subscribes to the tracker and spawns the broadcaster task.
    ///
    /// # Panics
    ///
    /// Panics if the tracker has not been set, or (checked second) if the
    /// manager has not been set.
    pub fn build(self) -> Broadcaster {
        let Self { tracker, manager } = self;
        let tracker = tracker.expect("BroadcasterBuilder: tracker must be set before build()");
        let manager = manager.expect("BroadcasterBuilder: manager must be set before build()");
        Broadcaster::spawn(tracker.subscribe(), manager)
    }
}
impl Default for BroadcasterBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::subscription::change_tracker::ChangeTracker;
    use crate::subscription::filter::SubscriptionFilter;
    use std::time::Duration;
    use tokio::time::{sleep, timeout};

    /// Builds a tracker, a manager, and a broadcaster wired between them.
    ///
    /// The returned broadcaster must stay bound for the duration of a test:
    /// dropping it drops the shutdown sender and stops the forwarding task.
    fn make_stack() -> (Arc<ChangeTracker>, Arc<SubscriptionManager>, Broadcaster) {
        let tracker = Arc::new(ChangeTracker::new(128));
        let manager = Arc::new(SubscriptionManager::with_defaults());
        let broadcaster = BroadcasterBuilder::new()
            .tracker(Arc::clone(&tracker))
            .manager(Arc::clone(&manager))
            .build();
        (tracker, manager, broadcaster)
    }

    /// An event recorded on the tracker reaches a subscriber of the manager.
    #[tokio::test]
    async fn test_broadcaster_forwards_event_end_to_end() {
        let (tracker, manager, _broadcaster) = make_stack();
        let (_id, mut rx) = manager.subscribe(SubscriptionFilter::all()).await;
        // Give the spawned task a chance to start polling the source.
        sleep(Duration::from_millis(5)).await;
        tracker.record_insert("http://ex.org/s", "http://ex.org/p", "obj", None);
        let received = timeout(Duration::from_millis(200), rx.recv())
            .await
            .expect("no timeout")
            .expect("received event");
        assert_eq!(received.subject, "http://ex.org/s");
    }

    /// Both counters advance once events have been forwarded.
    #[tokio::test]
    async fn test_broadcaster_stats_increment() {
        let (tracker, manager, broadcaster) = make_stack();
        let (_id, _rx) = manager.subscribe(SubscriptionFilter::all()).await;
        sleep(Duration::from_millis(5)).await;
        tracker.record_insert("s", "p", "o", None);
        tracker.record_delete("s", "p", "o", None);
        sleep(Duration::from_millis(50)).await;
        let stats = broadcaster.stats();
        assert!(stats.received >= 2, "Expected at least 2 received events");
        assert!(stats.forwarded >= 2, "Expected at least 2 forwarded events");
    }

    /// A single event is delivered to every matching subscriber.
    #[tokio::test]
    async fn test_broadcaster_fanout_to_multiple_subscribers() {
        let (tracker, manager, _broadcaster) = make_stack();
        let (_id1, mut rx1) = manager.subscribe(SubscriptionFilter::all()).await;
        let (_id2, mut rx2) = manager.subscribe(SubscriptionFilter::all()).await;
        sleep(Duration::from_millis(5)).await;
        tracker.record_insert("s", "p", "o", None);
        let r1 = timeout(Duration::from_millis(200), rx1.recv()).await;
        let r2 = timeout(Duration::from_millis(200), rx2.recv()).await;
        // `Ok(Some(_))` means: no timeout, and an event was actually delivered.
        assert!(matches!(r1, Ok(Some(_))), "rx1 should receive");
        assert!(matches!(r2, Ok(Some(_))), "rx2 should receive");
    }

    /// An inserts-only subscriber skips the delete and sees the insert.
    #[tokio::test]
    async fn test_broadcaster_filtered_fanout() {
        let (tracker, manager, _broadcaster) = make_stack();
        let insert_filter = SubscriptionFilter::inserts_only();
        let (_id, mut rx) = manager.subscribe(insert_filter).await;
        sleep(Duration::from_millis(5)).await;
        tracker.record_delete("s", "p", "o", None);
        tracker.record_insert("s2", "p", "o", None);
        let received = timeout(Duration::from_millis(200), rx.recv())
            .await
            .expect("no timeout")
            .expect("received");
        assert_eq!(
            received.event_type,
            crate::subscription::change_tracker::ChangeType::Insert
        );
    }

    /// Relinquishing the broadcaster stops its background task.
    ///
    /// `into_join_handle` consumes the broadcaster, which drops the shutdown
    /// sender; the task's shutdown branch then fires and the task exits.
    #[tokio::test]
    async fn test_broadcaster_shutdown_stops_task() {
        let tracker = Arc::new(ChangeTracker::new(128));
        let manager = Arc::new(SubscriptionManager::with_defaults());
        let receiver = tracker.subscribe();
        let broadcaster = Broadcaster::spawn(receiver, Arc::clone(&manager));
        sleep(Duration::from_millis(5)).await;
        let handle = broadcaster.into_join_handle();
        // On the default current-thread test runtime the task cannot have been
        // polled again yet, so it is still alive at this point.
        assert!(!handle.is_finished());
        // The task must now wind down on its own; a timeout means it kept running.
        timeout(Duration::from_millis(200), handle)
            .await
            .expect("broadcaster task should stop after the handle is relinquished")
            .expect("broadcaster task should exit without panicking");
    }

    /// A freshly built broadcaster reports all-zero counters.
    #[tokio::test]
    async fn test_broadcaster_builder_defaults() {
        let tracker = Arc::new(ChangeTracker::new(64));
        let manager = Arc::new(SubscriptionManager::with_defaults());
        let broadcaster = BroadcasterBuilder::default()
            .tracker(Arc::clone(&tracker))
            .manager(Arc::clone(&manager))
            .build();
        let stats = broadcaster.stats();
        assert_eq!(stats.received, 0);
        assert_eq!(stats.forwarded, 0);
        assert_eq!(stats.lagged, 0);
    }
}