use anyhow::Result;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::fmt::Debug;
use std::sync::Arc;
use tokio::sync::{broadcast, mpsc};
/// Selects which dispatch strategy is used to deliver changes.
///
/// Serialized in lowercase, i.e. `"broadcast"` or `"channel"`.
/// The default variant is [`DispatchMode::Channel`].
#[derive(Clone, Copy, Debug, Default, Deserialize, Eq, PartialEq, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum DispatchMode {
    /// Broadcast-style dispatch (presumably backed by [`BroadcastChangeDispatcher`];
    /// the mode-to-dispatcher mapping is not shown in this module).
    Broadcast,
    /// Channel-style dispatch (presumably backed by [`ChannelChangeDispatcher`];
    /// the mode-to-dispatcher mapping is not shown in this module).
    #[default]
    Channel,
}
/// Something that publishes shared changes of type `T` to [`ChangeReceiver`]s.
///
/// Changes are passed as `Arc<T>` so every consumer can hold the same value
/// without copying it.
#[async_trait]
pub trait ChangeDispatcher<T>: Send + Sync
where
    T: Clone + Send + Sync + 'static,
{
    /// Publishes a single change.
    ///
    /// # Errors
    /// Implementation-defined; returned when the change could not be handed off.
    async fn dispatch_change(&self, change: Arc<T>) -> Result<()>;

    /// Publishes `changes` one at a time, in order.
    ///
    /// # Errors
    /// Stops at, and returns, the first error from [`ChangeDispatcher::dispatch_change`];
    /// changes after the failing one are not dispatched.
    async fn dispatch_changes(&self, changes: Vec<Arc<T>>) -> Result<()> {
        for pending in changes.into_iter() {
            self.dispatch_change(pending).await?;
        }
        Ok(())
    }

    /// Creates a receiver for changes published through this dispatcher.
    ///
    /// # Errors
    /// Implementation-defined; e.g. when the dispatcher only supports a single receiver.
    async fn create_receiver(&self) -> Result<Box<dyn ChangeReceiver<T>>>;
}
/// The consuming side of a [`ChangeDispatcher`].
#[async_trait]
pub trait ChangeReceiver<T>: Send + Sync
where
    T: Clone + Send + Sync + 'static,
{
    /// Waits for and returns the next change.
    ///
    /// # Errors
    /// Implementation-defined; typically returned once the underlying channel is closed.
    async fn recv(&mut self) -> Result<Arc<T>>;
}
/// A [`ChangeDispatcher`] that fans each change out to every subscribed receiver
/// via a `tokio::sync::broadcast` channel.
pub struct BroadcastChangeDispatcher<T: Clone + Send + Sync + 'static> {
    /// Sending half; receivers are created from it by subscribing.
    tx: broadcast::Sender<Arc<T>>,
    /// Capacity recorded at construction; not read by any code in this module.
    _capacity: usize,
}
impl<T> BroadcastChangeDispatcher<T>
where
T: Clone + Send + Sync + 'static,
{
pub fn new(capacity: usize) -> Self {
let (tx, _) = broadcast::channel(capacity);
Self {
tx,
_capacity: capacity,
}
}
}
#[async_trait]
impl<T> ChangeDispatcher<T> for BroadcastChangeDispatcher<T>
where
T: Clone + Send + Sync + 'static,
{
async fn dispatch_change(&self, change: Arc<T>) -> Result<()> {
let _ = self.tx.send(change);
Ok(())
}
async fn create_receiver(&self) -> Result<Box<dyn ChangeReceiver<T>>> {
let rx = self.tx.subscribe();
Ok(Box::new(BroadcastChangeReceiver { rx }))
}
}
/// A [`ChangeReceiver`] subscribed to a [`BroadcastChangeDispatcher`].
pub struct BroadcastChangeReceiver<T: Clone + Send + Sync + 'static> {
    /// Subscription handle on the dispatcher's broadcast channel.
    rx: broadcast::Receiver<Arc<T>>,
}
#[async_trait]
impl<T> ChangeReceiver<T> for BroadcastChangeReceiver<T>
where
    T: Clone + Send + Sync + 'static,
{
    /// Returns the next broadcast change.
    ///
    /// If this receiver fell behind and missed messages, the gap is logged as a
    /// warning and receiving continues with the next available change, rather than
    /// surfacing the lag to the caller.
    ///
    /// # Errors
    /// Returns an error once the broadcast channel is closed.
    async fn recv(&mut self) -> Result<Arc<T>> {
        loop {
            let outcome = self.rx.recv().await;
            match outcome {
                Err(broadcast::error::RecvError::Lagged(n)) => {
                    log::warn!("Broadcast receiver lagged by {n} messages");
                    continue;
                }
                Err(broadcast::error::RecvError::Closed) => {
                    return Err(anyhow::anyhow!("Broadcast channel closed"));
                }
                Ok(change) => break Ok(change),
            }
        }
    }
}
/// A [`ChangeDispatcher`] backed by a bounded `tokio::sync::mpsc` channel that
/// supports exactly one receiver.
pub struct ChannelChangeDispatcher<T: Clone + Send + Sync + 'static> {
    /// Sending half of the channel.
    tx: mpsc::Sender<Arc<T>>,
    /// Receiving half, parked here until `create_receiver` takes it (`None` afterwards).
    rx: Arc<tokio::sync::Mutex<Option<mpsc::Receiver<Arc<T>>>>>,
    /// Capacity recorded at construction; not read by any code in this module.
    _capacity: usize,
}
impl<T> ChannelChangeDispatcher<T>
where
T: Clone + Send + Sync + 'static,
{
pub fn new(capacity: usize) -> Self {
let (tx, rx) = mpsc::channel(capacity);
Self {
tx,
rx: Arc::new(tokio::sync::Mutex::new(Some(rx))),
_capacity: capacity,
}
}
}
#[async_trait]
impl<T> ChangeDispatcher<T> for ChannelChangeDispatcher<T>
where
    T: Clone + Send + Sync + 'static,
{
    /// Sends `change` to the single receiver, waiting for buffer space if needed.
    ///
    /// NOTE(review): until `create_receiver` is called, the dispatcher itself owns the
    /// receiving half, so nothing drains the buffer. Once `capacity` changes are queued,
    /// further sends wait until a receiver is created and consumes them.
    ///
    /// # Errors
    /// Fails with "Failed to send on channel" when the receiving half has been dropped.
    async fn dispatch_change(&self, change: Arc<T>) -> Result<()> {
        if self.tx.send(change).await.is_err() {
            return Err(anyhow::anyhow!("Failed to send on channel"));
        }
        Ok(())
    }

    /// Hands out the one and only receiver for this dispatcher.
    ///
    /// # Errors
    /// Fails if a receiver has already been created.
    async fn create_receiver(&self) -> Result<Box<dyn ChangeReceiver<T>>> {
        // The lock is only needed for the `take`; it is released at the end of this statement.
        let parked = self.rx.lock().await.take();
        match parked {
            Some(rx) => {
                let receiver: Box<dyn ChangeReceiver<T>> = Box::new(ChannelChangeReceiver { rx });
                Ok(receiver)
            }
            None => Err(anyhow::anyhow!(
                "Receiver already created for this channel dispatcher"
            )),
        }
    }
}
/// The single [`ChangeReceiver`] handed out by a [`ChannelChangeDispatcher`].
pub struct ChannelChangeReceiver<T: Clone + Send + Sync + 'static> {
    /// Receiving half taken from the dispatcher.
    rx: mpsc::Receiver<Arc<T>>,
}
#[async_trait]
impl<T> ChangeReceiver<T> for ChannelChangeReceiver<T>
where
    T: Clone + Send + Sync + 'static,
{
    /// Waits for the next change sent through the channel.
    ///
    /// # Errors
    /// Returns "Channel closed" once the channel yields no more values.
    async fn recv(&mut self) -> Result<Arc<T>> {
        match self.rx.recv().await {
            Some(change) => Ok(change),
            None => Err(anyhow::anyhow!("Channel closed")),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Clone, Debug, PartialEq)]
    struct TestMessage {
        id: u32,
        content: String,
    }

    /// Builds a shared test message.
    fn message(id: u32, content: &str) -> Arc<TestMessage> {
        Arc::new(TestMessage {
            id,
            content: content.to_string(),
        })
    }

    #[tokio::test]
    async fn test_broadcast_dispatcher_single_receiver() {
        let dispatcher = BroadcastChangeDispatcher::<TestMessage>::new(100);
        let mut receiver = dispatcher.create_receiver().await.unwrap();
        let msg = message(1, "test");
        dispatcher.dispatch_change(msg.clone()).await.unwrap();
        let received = receiver.recv().await.unwrap();
        assert_eq!(*received, *msg);
    }

    #[tokio::test]
    async fn test_broadcast_dispatcher_multiple_receivers() {
        let dispatcher = BroadcastChangeDispatcher::<TestMessage>::new(100);
        let mut receiver1 = dispatcher.create_receiver().await.unwrap();
        let mut receiver2 = dispatcher.create_receiver().await.unwrap();
        let msg = message(1, "broadcast");
        dispatcher.dispatch_change(msg.clone()).await.unwrap();
        let received1 = receiver1.recv().await.unwrap();
        let received2 = receiver2.recv().await.unwrap();
        assert_eq!(*received1, *msg);
        assert_eq!(*received2, *msg);
    }

    #[tokio::test]
    async fn test_broadcast_dispatcher_dispatch_changes() {
        let dispatcher = BroadcastChangeDispatcher::<TestMessage>::new(100);
        let mut receiver = dispatcher.create_receiver().await.unwrap();
        let messages = vec![
            message(1, "first"),
            message(2, "second"),
            message(3, "third"),
        ];
        dispatcher.dispatch_changes(messages.clone()).await.unwrap();
        for expected in messages {
            let received = receiver.recv().await.unwrap();
            assert_eq!(*received, *expected);
        }
    }

    #[tokio::test]
    async fn test_broadcast_dispatch_without_receivers_is_ok() {
        // Broadcasting is best-effort: with no subscribers the change is dropped, not an error.
        let dispatcher = BroadcastChangeDispatcher::<TestMessage>::new(4);
        assert!(dispatcher
            .dispatch_change(message(1, "dropped"))
            .await
            .is_ok());
    }

    #[tokio::test]
    async fn test_channel_dispatcher_single_receiver() {
        let dispatcher = ChannelChangeDispatcher::<TestMessage>::new(100);
        let mut receiver = dispatcher.create_receiver().await.unwrap();
        let msg = message(1, "channel");
        dispatcher.dispatch_change(msg.clone()).await.unwrap();
        let received = receiver.recv().await.unwrap();
        assert_eq!(*received, *msg);
    }

    #[tokio::test]
    async fn test_channel_dispatcher_only_one_receiver() {
        let dispatcher = ChannelChangeDispatcher::<TestMessage>::new(100);
        let _receiver1 = dispatcher.create_receiver().await.unwrap();
        match dispatcher.create_receiver().await {
            Ok(_) => panic!("second create_receiver call should fail"),
            Err(e) => assert!(e.to_string().contains("Receiver already created")),
        }
    }

    #[tokio::test]
    async fn test_channel_dispatcher_dispatch_changes() {
        let dispatcher = ChannelChangeDispatcher::<TestMessage>::new(100);
        let mut receiver = dispatcher.create_receiver().await.unwrap();
        let messages = vec![message(1, "first"), message(2, "second")];
        dispatcher.dispatch_changes(messages.clone()).await.unwrap();
        for expected in messages {
            let received = receiver.recv().await.unwrap();
            assert_eq!(*received, *expected);
        }
    }

    #[tokio::test]
    async fn test_channel_dispatch_after_receiver_dropped_fails() {
        let dispatcher = ChannelChangeDispatcher::<TestMessage>::new(4);
        drop(dispatcher.create_receiver().await.unwrap());
        let err = dispatcher
            .dispatch_change(message(1, "orphan"))
            .await
            .unwrap_err();
        assert!(err.to_string().contains("Failed to send on channel"));
    }

    #[tokio::test]
    async fn test_broadcast_receiver_handles_lag() {
        let dispatcher = BroadcastChangeDispatcher::<TestMessage>::new(2);
        let mut receiver = dispatcher.create_receiver().await.unwrap();
        for i in 0..5 {
            dispatcher
                .dispatch_change(message(i, &format!("msg{i}")))
                .await
                .unwrap();
        }
        // With capacity 2 only the two newest messages (ids 3 and 4) are retained. The
        // receiver must swallow the lag notification and resume at the oldest retained one.
        let first = receiver.recv().await.unwrap();
        assert_eq!(first.id, 3);
        let second = receiver.recv().await.unwrap();
        assert_eq!(second.id, 4);
    }

    #[test]
    fn test_dispatch_mode_default() {
        assert_eq!(DispatchMode::default(), DispatchMode::Channel);
    }

    #[test]
    fn test_dispatch_mode_serialization() {
        for (mode, expected_json) in [
            (DispatchMode::Broadcast, "\"broadcast\""),
            (DispatchMode::Channel, "\"channel\""),
        ] {
            let json = serde_json::to_string(&mode).unwrap();
            assert_eq!(json, expected_json);
            let deserialized: DispatchMode = serde_json::from_str(&json).unwrap();
            assert_eq!(deserialized, mode);
        }
    }
}