drasi_lib/channels/
dispatcher.rs1use anyhow::Result;
2use async_trait::async_trait;
3use serde::{Deserialize, Serialize};
4use std::fmt::Debug;
5use std::sync::Arc;
6use tokio::sync::{broadcast, mpsc};
7
8#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
158#[serde(rename_all = "lowercase")]
159pub enum DispatchMode {
160 Broadcast,
162 #[default]
164 Channel,
165}
166
167#[async_trait]
169pub trait ChangeDispatcher<T>: Send + Sync
170where
171 T: Clone + Send + Sync + 'static,
172{
173 async fn dispatch_change(&self, change: Arc<T>) -> Result<()>;
175
176 async fn dispatch_changes(&self, changes: Vec<Arc<T>>) -> Result<()> {
178 for change in changes {
179 self.dispatch_change(change).await?;
180 }
181 Ok(())
182 }
183
184 async fn create_receiver(&self) -> Result<Box<dyn ChangeReceiver<T>>>;
186}
187
188#[async_trait]
190pub trait ChangeReceiver<T>: Send + Sync
191where
192 T: Clone + Send + Sync + 'static,
193{
194 async fn recv(&mut self) -> Result<Arc<T>>;
196}
197
198pub struct BroadcastChangeDispatcher<T>
200where
201 T: Clone + Send + Sync + 'static,
202{
203 tx: broadcast::Sender<Arc<T>>,
204 _capacity: usize,
205}
206
207impl<T> BroadcastChangeDispatcher<T>
208where
209 T: Clone + Send + Sync + 'static,
210{
211 pub fn new(capacity: usize) -> Self {
213 let (tx, _) = broadcast::channel(capacity);
214 Self {
215 tx,
216 _capacity: capacity,
217 }
218 }
219}
220
221#[async_trait]
222impl<T> ChangeDispatcher<T> for BroadcastChangeDispatcher<T>
223where
224 T: Clone + Send + Sync + 'static,
225{
226 async fn dispatch_change(&self, change: Arc<T>) -> Result<()> {
227 let _ = self.tx.send(change);
229 Ok(())
230 }
231
232 async fn create_receiver(&self) -> Result<Box<dyn ChangeReceiver<T>>> {
233 let rx = self.tx.subscribe();
234 Ok(Box::new(BroadcastChangeReceiver { rx }))
235 }
236}
237
238pub struct BroadcastChangeReceiver<T>
240where
241 T: Clone + Send + Sync + 'static,
242{
243 rx: broadcast::Receiver<Arc<T>>,
244}
245
246#[async_trait]
247impl<T> ChangeReceiver<T> for BroadcastChangeReceiver<T>
248where
249 T: Clone + Send + Sync + 'static,
250{
251 async fn recv(&mut self) -> Result<Arc<T>> {
252 loop {
253 match self.rx.recv().await {
254 Ok(change) => return Ok(change),
255 Err(broadcast::error::RecvError::Closed) => {
256 return Err(anyhow::anyhow!("Broadcast channel closed"));
257 }
258 Err(broadcast::error::RecvError::Lagged(n)) => {
259 log::warn!("Broadcast receiver lagged by {n} messages");
260 }
262 }
263 }
264 }
265}
266
267pub struct ChannelChangeDispatcher<T>
269where
270 T: Clone + Send + Sync + 'static,
271{
272 tx: mpsc::Sender<Arc<T>>,
273 rx: Arc<tokio::sync::Mutex<Option<mpsc::Receiver<Arc<T>>>>>,
274 _capacity: usize,
275}
276
277impl<T> ChannelChangeDispatcher<T>
278where
279 T: Clone + Send + Sync + 'static,
280{
281 pub fn new(capacity: usize) -> Self {
283 let (tx, rx) = mpsc::channel(capacity);
284 Self {
285 tx,
286 rx: Arc::new(tokio::sync::Mutex::new(Some(rx))),
287 _capacity: capacity,
288 }
289 }
290}
291
292#[async_trait]
293impl<T> ChangeDispatcher<T> for ChannelChangeDispatcher<T>
294where
295 T: Clone + Send + Sync + 'static,
296{
297 async fn dispatch_change(&self, change: Arc<T>) -> Result<()> {
298 self.tx
299 .send(change)
300 .await
301 .map_err(|_| anyhow::anyhow!("Failed to send on channel"))?;
302 Ok(())
303 }
304
305 async fn create_receiver(&self) -> Result<Box<dyn ChangeReceiver<T>>> {
306 let mut rx_opt = self.rx.lock().await;
309 let rx = rx_opt.take().ok_or_else(|| {
310 anyhow::anyhow!("Receiver already created for this channel dispatcher")
311 })?;
312 Ok(Box::new(ChannelChangeReceiver { rx }))
313 }
314}
315
316pub struct ChannelChangeReceiver<T>
318where
319 T: Clone + Send + Sync + 'static,
320{
321 rx: mpsc::Receiver<Arc<T>>,
322}
323
324impl<T> ChannelChangeReceiver<T>
325where
326 T: Clone + Send + Sync + 'static,
327{
328 pub fn new(rx: mpsc::Receiver<Arc<T>>) -> Self {
330 Self { rx }
331 }
332}
333
334#[async_trait]
335impl<T> ChangeReceiver<T> for ChannelChangeReceiver<T>
336where
337 T: Clone + Send + Sync + 'static,
338{
339 async fn recv(&mut self) -> Result<Arc<T>> {
340 self.rx
341 .recv()
342 .await
343 .ok_or_else(|| anyhow::anyhow!("Channel closed"))
344 }
345}
346
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal payload type used to exercise the dispatchers.
    #[derive(Clone, Debug, PartialEq)]
    struct TestMessage {
        id: u32,
        content: String,
    }

    /// Builds a shared `TestMessage` so each test does not repeat the
    /// struct literal.
    fn message(id: u32, content: &str) -> Arc<TestMessage> {
        Arc::new(TestMessage {
            id,
            content: content.to_string(),
        })
    }

    #[tokio::test]
    async fn test_broadcast_dispatcher_single_receiver() {
        let dispatcher = BroadcastChangeDispatcher::<TestMessage>::new(100);
        let mut receiver = dispatcher.create_receiver().await.unwrap();

        let msg = message(1, "test");
        dispatcher.dispatch_change(Arc::clone(&msg)).await.unwrap();

        let received = receiver.recv().await.unwrap();
        assert_eq!(*received, *msg);
    }

    #[tokio::test]
    async fn test_broadcast_dispatcher_multiple_receivers() {
        let dispatcher = BroadcastChangeDispatcher::<TestMessage>::new(100);
        let mut receiver1 = dispatcher.create_receiver().await.unwrap();
        let mut receiver2 = dispatcher.create_receiver().await.unwrap();

        let msg = message(1, "broadcast");
        dispatcher.dispatch_change(Arc::clone(&msg)).await.unwrap();

        // Every subscribed receiver must see the same message.
        let received1 = receiver1.recv().await.unwrap();
        let received2 = receiver2.recv().await.unwrap();
        assert_eq!(*received1, *msg);
        assert_eq!(*received2, *msg);
    }

    #[tokio::test]
    async fn test_broadcast_dispatcher_dispatch_changes() {
        let dispatcher = BroadcastChangeDispatcher::<TestMessage>::new(100);
        let mut receiver = dispatcher.create_receiver().await.unwrap();

        let messages = vec![
            message(1, "first"),
            message(2, "second"),
            message(3, "third"),
        ];
        dispatcher.dispatch_changes(messages.clone()).await.unwrap();

        // Messages must arrive in the order they were dispatched.
        for expected in messages {
            let received = receiver.recv().await.unwrap();
            assert_eq!(*received, *expected);
        }
    }

    #[tokio::test]
    async fn test_channel_dispatcher_single_receiver() {
        let dispatcher = ChannelChangeDispatcher::<TestMessage>::new(100);
        let mut receiver = dispatcher.create_receiver().await.unwrap();

        let msg = message(1, "channel");
        dispatcher.dispatch_change(Arc::clone(&msg)).await.unwrap();

        let received = receiver.recv().await.unwrap();
        assert_eq!(*received, *msg);
    }

    #[tokio::test]
    async fn test_channel_dispatcher_only_one_receiver() {
        let dispatcher = ChannelChangeDispatcher::<TestMessage>::new(100);
        let _receiver1 = dispatcher.create_receiver().await.unwrap();

        // `unwrap_err` is unavailable because the boxed receiver is not
        // `Debug`, so the outcome is matched explicitly.
        match dispatcher.create_receiver().await {
            Ok(_) => panic!("second create_receiver call should fail"),
            Err(e) => assert!(e.to_string().contains("Receiver already created")),
        }
    }

    #[tokio::test]
    async fn test_channel_dispatcher_dispatch_changes() {
        let dispatcher = ChannelChangeDispatcher::<TestMessage>::new(100);
        let mut receiver = dispatcher.create_receiver().await.unwrap();

        let messages = vec![message(1, "first"), message(2, "second")];
        dispatcher.dispatch_changes(messages.clone()).await.unwrap();

        for expected in messages {
            let received = receiver.recv().await.unwrap();
            assert_eq!(*received, *expected);
        }
    }

    #[tokio::test]
    async fn test_broadcast_receiver_handles_lag() {
        // Five sends into a capacity-2 channel overwrite the three oldest
        // messages before the receiver reads anything.
        let dispatcher = BroadcastChangeDispatcher::<TestMessage>::new(2);
        let mut receiver = dispatcher.create_receiver().await.unwrap();

        for i in 0..5 {
            dispatcher
                .dispatch_change(message(i, &format!("msg{i}")))
                .await
                .unwrap();
        }

        // `recv` must swallow the lag notification and resume at the oldest
        // message the channel still holds, then continue in order.
        let first = receiver.recv().await.unwrap();
        assert_eq!(*first, *message(3, "msg3"));
        let second = receiver.recv().await.unwrap();
        assert_eq!(*second, *message(4, "msg4"));
    }

    // The two tests below await nothing, so they run as plain synchronous
    // tests instead of starting a tokio runtime.

    #[test]
    fn test_dispatch_mode_default() {
        assert_eq!(DispatchMode::default(), DispatchMode::Channel);
    }

    #[test]
    fn test_dispatch_mode_serialization() {
        let cases = [
            (DispatchMode::Broadcast, "\"broadcast\""),
            (DispatchMode::Channel, "\"channel\""),
        ];

        for (mode, expected_json) in cases {
            let json = serde_json::to_string(&mode).unwrap();
            assert_eq!(json, expected_json);

            // Round-trip back to the same variant.
            let deserialized: DispatchMode = serde_json::from_str(&json).unwrap();
            assert_eq!(deserialized, mode);
        }
    }
}