drasi_lib/channels/
dispatcher.rs1use anyhow::Result;
16use async_trait::async_trait;
17use serde::{Deserialize, Serialize};
18use std::fmt::Debug;
19use std::sync::Arc;
20use tokio::sync::{broadcast, mpsc};
21
/// Identifies a dispatch mode.
///
/// Serialized and deserialized with lowercase variant names
/// (`"broadcast"` / `"channel"`). The default value is [`DispatchMode::Channel`].
///
/// NOTE(review): the variants presumably select between
/// `BroadcastChangeDispatcher` and `ChannelChangeDispatcher`; the code that
/// makes that choice is not in this file — confirm against the caller.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum DispatchMode {
    /// Broadcast mode; serialized as `"broadcast"`.
    Broadcast,
    /// Channel mode; serialized as `"channel"`. Returned by `DispatchMode::default()`.
    #[default]
    Channel,
}
180
/// Sending side of a change pipeline carrying values of type `T`.
///
/// Changes are handed over as `Arc<T>`. This file provides two
/// implementations: `BroadcastChangeDispatcher` and `ChannelChangeDispatcher`.
#[async_trait]
pub trait ChangeDispatcher<T>: Send + Sync
where
    T: Clone + Send + Sync + 'static,
{
    /// Dispatches a single change.
    ///
    /// # Errors
    ///
    /// Error conditions are defined by the implementing type.
    async fn dispatch_change(&self, change: Arc<T>) -> Result<()>;

    /// Dispatches every element of `changes`, in order, by calling
    /// [`dispatch_change`](Self::dispatch_change) once per element.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by `dispatch_change`. Elements before
    /// the failing one have already been dispatched; elements after it are
    /// not attempted.
    async fn dispatch_changes(&self, changes: Vec<Arc<T>>) -> Result<()> {
        for change in changes {
            self.dispatch_change(change).await?;
        }
        Ok(())
    }

    /// Creates a boxed [`ChangeReceiver`] for this dispatcher.
    ///
    /// # Errors
    ///
    /// Error conditions are defined by the implementing type.
    async fn create_receiver(&self) -> Result<Box<dyn ChangeReceiver<T>>>;
}
201
/// Receiving side of a change pipeline carrying values of type `T`.
///
/// This file provides two implementations: `BroadcastChangeReceiver` and
/// `ChannelChangeReceiver`.
#[async_trait]
pub trait ChangeReceiver<T>: Send + Sync
where
    T: Clone + Send + Sync + 'static,
{
    /// Waits for and returns the next change.
    ///
    /// # Errors
    ///
    /// Error conditions are defined by the implementing type; both
    /// implementations in this file return an error once their underlying
    /// channel reports that it is closed.
    async fn recv(&mut self) -> Result<Arc<T>>;
}
211
/// [`ChangeDispatcher`] backed by a `tokio::sync::broadcast` channel.
///
/// Each `create_receiver` call subscribes a new receiver to the stored sender,
/// so any number of receivers can be created.
pub struct BroadcastChangeDispatcher<T>
where
    T: Clone + Send + Sync + 'static,
{
    /// Sending half of the broadcast channel; receivers are created by
    /// subscribing to it.
    tx: broadcast::Sender<Arc<T>>,
    /// Capacity the channel was created with. Stored by `new` but not read
    /// anywhere in this file.
    _capacity: usize,
}
220
221impl<T> BroadcastChangeDispatcher<T>
222where
223 T: Clone + Send + Sync + 'static,
224{
225 pub fn new(capacity: usize) -> Self {
227 let (tx, _) = broadcast::channel(capacity);
228 Self {
229 tx,
230 _capacity: capacity,
231 }
232 }
233}
234
235#[async_trait]
236impl<T> ChangeDispatcher<T> for BroadcastChangeDispatcher<T>
237where
238 T: Clone + Send + Sync + 'static,
239{
240 async fn dispatch_change(&self, change: Arc<T>) -> Result<()> {
241 let _ = self.tx.send(change);
243 Ok(())
244 }
245
246 async fn create_receiver(&self) -> Result<Box<dyn ChangeReceiver<T>>> {
247 let rx = self.tx.subscribe();
248 Ok(Box::new(BroadcastChangeReceiver { rx }))
249 }
250}
251
/// [`ChangeReceiver`] wrapping one subscription to a
/// `tokio::sync::broadcast` channel.
///
/// Instances are produced by `BroadcastChangeDispatcher::create_receiver`.
pub struct BroadcastChangeReceiver<T>
where
    T: Clone + Send + Sync + 'static,
{
    /// Receiving half of the broadcast subscription.
    rx: broadcast::Receiver<Arc<T>>,
}
259
260#[async_trait]
261impl<T> ChangeReceiver<T> for BroadcastChangeReceiver<T>
262where
263 T: Clone + Send + Sync + 'static,
264{
265 async fn recv(&mut self) -> Result<Arc<T>> {
266 loop {
267 match self.rx.recv().await {
268 Ok(change) => return Ok(change),
269 Err(broadcast::error::RecvError::Closed) => {
270 return Err(anyhow::anyhow!("Broadcast channel closed"));
271 }
272 Err(broadcast::error::RecvError::Lagged(n)) => {
273 log::warn!("Broadcast receiver lagged by {n} messages");
274 }
276 }
277 }
278 }
279}
280
/// [`ChangeDispatcher`] backed by a `tokio::sync::mpsc` channel.
///
/// The channel's single receiver is stored inside the dispatcher until
/// `create_receiver` takes it, so only one receiver can ever be handed out.
pub struct ChannelChangeDispatcher<T>
where
    T: Clone + Send + Sync + 'static,
{
    /// Sending half of the mpsc channel.
    tx: mpsc::Sender<Arc<T>>,
    /// Receiving half, held as `Some` until `create_receiver` takes it and
    /// leaves `None` behind. Guarded by an async mutex.
    rx: Arc<tokio::sync::Mutex<Option<mpsc::Receiver<Arc<T>>>>>,
    /// Capacity the channel was created with. Stored by `new` but not read
    /// anywhere in this file.
    _capacity: usize,
}
290
291impl<T> ChannelChangeDispatcher<T>
292where
293 T: Clone + Send + Sync + 'static,
294{
295 pub fn new(capacity: usize) -> Self {
297 let (tx, rx) = mpsc::channel(capacity);
298 Self {
299 tx,
300 rx: Arc::new(tokio::sync::Mutex::new(Some(rx))),
301 _capacity: capacity,
302 }
303 }
304}
305
306#[async_trait]
307impl<T> ChangeDispatcher<T> for ChannelChangeDispatcher<T>
308where
309 T: Clone + Send + Sync + 'static,
310{
311 async fn dispatch_change(&self, change: Arc<T>) -> Result<()> {
312 self.tx
313 .send(change)
314 .await
315 .map_err(|_| anyhow::anyhow!("Failed to send on channel"))?;
316 Ok(())
317 }
318
319 async fn create_receiver(&self) -> Result<Box<dyn ChangeReceiver<T>>> {
320 let mut rx_opt = self.rx.lock().await;
323 let rx = rx_opt.take().ok_or_else(|| {
324 anyhow::anyhow!("Receiver already created for this channel dispatcher")
325 })?;
326 Ok(Box::new(ChannelChangeReceiver { rx }))
327 }
328}
329
/// [`ChangeReceiver`] wrapping the receiving half of a `tokio::sync::mpsc`
/// channel.
///
/// Instances are produced by `ChannelChangeDispatcher::create_receiver` or
/// directly through `ChannelChangeReceiver::new`.
pub struct ChannelChangeReceiver<T>
where
    T: Clone + Send + Sync + 'static,
{
    /// Receiving half of the mpsc channel.
    rx: mpsc::Receiver<Arc<T>>,
}
337
338impl<T> ChannelChangeReceiver<T>
339where
340 T: Clone + Send + Sync + 'static,
341{
342 pub fn new(rx: mpsc::Receiver<Arc<T>>) -> Self {
344 Self { rx }
345 }
346}
347
348#[async_trait]
349impl<T> ChangeReceiver<T> for ChannelChangeReceiver<T>
350where
351 T: Clone + Send + Sync + 'static,
352{
353 async fn recv(&mut self) -> Result<Arc<T>> {
354 self.rx
355 .recv()
356 .await
357 .ok_or_else(|| anyhow::anyhow!("Channel closed"))
358 }
359}
360
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal payload type used to exercise the dispatchers.
    #[derive(Clone, Debug, PartialEq)]
    struct TestMessage {
        id: u32,
        content: String,
    }

    /// Builds a shared `TestMessage` with the given id and content.
    fn message(id: u32, content: &str) -> Arc<TestMessage> {
        Arc::new(TestMessage {
            id,
            content: content.to_string(),
        })
    }

    /// A single broadcast receiver gets the dispatched message.
    #[tokio::test]
    async fn test_broadcast_dispatcher_single_receiver() {
        let dispatcher = BroadcastChangeDispatcher::<TestMessage>::new(100);
        let mut receiver = dispatcher.create_receiver().await.unwrap();

        let sent = message(1, "test");
        dispatcher
            .dispatch_change(Arc::clone(&sent))
            .await
            .unwrap();

        let received = receiver.recv().await.unwrap();
        assert_eq!(*received, *sent);
    }

    /// Every broadcast receiver gets its own copy of the dispatched message.
    #[tokio::test]
    async fn test_broadcast_dispatcher_multiple_receivers() {
        let dispatcher = BroadcastChangeDispatcher::<TestMessage>::new(100);
        let mut receiver1 = dispatcher.create_receiver().await.unwrap();
        let mut receiver2 = dispatcher.create_receiver().await.unwrap();

        let sent = message(1, "broadcast");
        dispatcher
            .dispatch_change(Arc::clone(&sent))
            .await
            .unwrap();

        let received1 = receiver1.recv().await.unwrap();
        let received2 = receiver2.recv().await.unwrap();

        assert_eq!(*received1, *sent);
        assert_eq!(*received2, *sent);
    }

    /// A batch sent through the broadcast dispatcher arrives in order.
    #[tokio::test]
    async fn test_broadcast_dispatcher_dispatch_changes() {
        let dispatcher = BroadcastChangeDispatcher::<TestMessage>::new(100);
        let mut receiver = dispatcher.create_receiver().await.unwrap();

        let batch = vec![
            message(1, "first"),
            message(2, "second"),
            message(3, "third"),
        ];

        dispatcher.dispatch_changes(batch.clone()).await.unwrap();

        for expected in &batch {
            let received = receiver.recv().await.unwrap();
            assert_eq!(*received, **expected);
        }
    }

    /// The channel dispatcher's receiver gets the dispatched message.
    #[tokio::test]
    async fn test_channel_dispatcher_single_receiver() {
        let dispatcher = ChannelChangeDispatcher::<TestMessage>::new(100);
        let mut receiver = dispatcher.create_receiver().await.unwrap();

        let sent = message(1, "channel");
        dispatcher
            .dispatch_change(Arc::clone(&sent))
            .await
            .unwrap();

        let received = receiver.recv().await.unwrap();
        assert_eq!(*received, *sent);
    }

    /// A second `create_receiver` call on the channel dispatcher is rejected.
    #[tokio::test]
    async fn test_channel_dispatcher_only_one_receiver() {
        let dispatcher = ChannelChangeDispatcher::<TestMessage>::new(100);
        let _receiver1 = dispatcher.create_receiver().await.unwrap();

        match dispatcher.create_receiver().await {
            Ok(_) => panic!("second create_receiver call unexpectedly succeeded"),
            Err(err) => assert!(err.to_string().contains("Receiver already created")),
        }
    }

    /// A batch sent through the channel dispatcher arrives in order.
    #[tokio::test]
    async fn test_channel_dispatcher_dispatch_changes() {
        let dispatcher = ChannelChangeDispatcher::<TestMessage>::new(100);
        let mut receiver = dispatcher.create_receiver().await.unwrap();

        let batch = vec![message(1, "first"), message(2, "second")];

        dispatcher.dispatch_changes(batch.clone()).await.unwrap();

        for expected in &batch {
            let received = receiver.recv().await.unwrap();
            assert_eq!(*received, **expected);
        }
    }

    /// Sending more messages than the broadcast capacity still lets the
    /// receiver obtain a message afterwards.
    #[tokio::test]
    async fn test_broadcast_receiver_handles_lag() {
        let dispatcher = BroadcastChangeDispatcher::<TestMessage>::new(2);
        let mut receiver = dispatcher.create_receiver().await.unwrap();

        // Five sends against a capacity of two.
        for i in 0..5 {
            dispatcher
                .dispatch_change(message(i, &format!("msg{i}")))
                .await
                .unwrap();
        }

        tokio::task::yield_now().await;

        assert!(receiver.recv().await.is_ok());
    }

    /// `Channel` is the default dispatch mode.
    #[tokio::test]
    async fn test_dispatch_mode_default() {
        assert_eq!(DispatchMode::default(), DispatchMode::Channel);
    }

    /// Both modes serialize to lowercase JSON strings and round-trip.
    #[tokio::test]
    async fn test_dispatch_mode_serialization() {
        /// Serializes `mode`, checks the JSON text, then parses it back.
        fn assert_round_trip(mode: DispatchMode, expected_json: &str) {
            let json = serde_json::to_string(&mode).unwrap();
            assert_eq!(json, expected_json);

            let deserialized: DispatchMode = serde_json::from_str(&json).unwrap();
            assert_eq!(deserialized, mode);
        }

        assert_round_trip(DispatchMode::Broadcast, "\"broadcast\"");
        assert_round_trip(DispatchMode::Channel, "\"channel\"");
    }
}