drasi_lib/channels/
dispatcher.rs1use anyhow::Result;
2use async_trait::async_trait;
3use serde::{Deserialize, Serialize};
4use std::fmt::Debug;
5use std::sync::Arc;
6use tokio::sync::{broadcast, mpsc};
7
8#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
158#[serde(rename_all = "lowercase")]
159pub enum DispatchMode {
160 Broadcast,
162 #[default]
164 Channel,
165}
166
167#[async_trait]
169pub trait ChangeDispatcher<T>: Send + Sync
170where
171 T: Clone + Send + Sync + 'static,
172{
173 async fn dispatch_change(&self, change: Arc<T>) -> Result<()>;
175
176 async fn dispatch_changes(&self, changes: Vec<Arc<T>>) -> Result<()> {
178 for change in changes {
179 self.dispatch_change(change).await?;
180 }
181 Ok(())
182 }
183
184 async fn create_receiver(&self) -> Result<Box<dyn ChangeReceiver<T>>>;
186}
187
188#[async_trait]
190pub trait ChangeReceiver<T>: Send + Sync
191where
192 T: Clone + Send + Sync + 'static,
193{
194 async fn recv(&mut self) -> Result<Arc<T>>;
196}
197
198pub struct BroadcastChangeDispatcher<T>
200where
201 T: Clone + Send + Sync + 'static,
202{
203 tx: broadcast::Sender<Arc<T>>,
204 _capacity: usize,
205}
206
207impl<T> BroadcastChangeDispatcher<T>
208where
209 T: Clone + Send + Sync + 'static,
210{
211 pub fn new(capacity: usize) -> Self {
213 let (tx, _) = broadcast::channel(capacity);
214 Self {
215 tx,
216 _capacity: capacity,
217 }
218 }
219}
220
221#[async_trait]
222impl<T> ChangeDispatcher<T> for BroadcastChangeDispatcher<T>
223where
224 T: Clone + Send + Sync + 'static,
225{
226 async fn dispatch_change(&self, change: Arc<T>) -> Result<()> {
227 let _ = self.tx.send(change);
229 Ok(())
230 }
231
232 async fn create_receiver(&self) -> Result<Box<dyn ChangeReceiver<T>>> {
233 let rx = self.tx.subscribe();
234 Ok(Box::new(BroadcastChangeReceiver { rx }))
235 }
236}
237
238pub struct BroadcastChangeReceiver<T>
240where
241 T: Clone + Send + Sync + 'static,
242{
243 rx: broadcast::Receiver<Arc<T>>,
244}
245
246#[async_trait]
247impl<T> ChangeReceiver<T> for BroadcastChangeReceiver<T>
248where
249 T: Clone + Send + Sync + 'static,
250{
251 async fn recv(&mut self) -> Result<Arc<T>> {
252 match self.rx.recv().await {
253 Ok(change) => Ok(change),
254 Err(broadcast::error::RecvError::Closed) => {
255 Err(anyhow::anyhow!("Broadcast channel closed"))
256 }
257 Err(broadcast::error::RecvError::Lagged(n)) => {
258 log::warn!("Broadcast receiver lagged by {n} messages");
260 self.recv().await
262 }
263 }
264 }
265}
266
267pub struct ChannelChangeDispatcher<T>
269where
270 T: Clone + Send + Sync + 'static,
271{
272 tx: mpsc::Sender<Arc<T>>,
273 rx: Arc<tokio::sync::Mutex<Option<mpsc::Receiver<Arc<T>>>>>,
274 _capacity: usize,
275}
276
277impl<T> ChannelChangeDispatcher<T>
278where
279 T: Clone + Send + Sync + 'static,
280{
281 pub fn new(capacity: usize) -> Self {
283 let (tx, rx) = mpsc::channel(capacity);
284 Self {
285 tx,
286 rx: Arc::new(tokio::sync::Mutex::new(Some(rx))),
287 _capacity: capacity,
288 }
289 }
290}
291
292#[async_trait]
293impl<T> ChangeDispatcher<T> for ChannelChangeDispatcher<T>
294where
295 T: Clone + Send + Sync + 'static,
296{
297 async fn dispatch_change(&self, change: Arc<T>) -> Result<()> {
298 self.tx
299 .send(change)
300 .await
301 .map_err(|_| anyhow::anyhow!("Failed to send on channel"))?;
302 Ok(())
303 }
304
305 async fn create_receiver(&self) -> Result<Box<dyn ChangeReceiver<T>>> {
306 let mut rx_opt = self.rx.lock().await;
309 let rx = rx_opt.take().ok_or_else(|| {
310 anyhow::anyhow!("Receiver already created for this channel dispatcher")
311 })?;
312 Ok(Box::new(ChannelChangeReceiver { rx }))
313 }
314}
315
316pub struct ChannelChangeReceiver<T>
318where
319 T: Clone + Send + Sync + 'static,
320{
321 rx: mpsc::Receiver<Arc<T>>,
322}
323
324#[async_trait]
325impl<T> ChangeReceiver<T> for ChannelChangeReceiver<T>
326where
327 T: Clone + Send + Sync + 'static,
328{
329 async fn recv(&mut self) -> Result<Arc<T>> {
330 self.rx
331 .recv()
332 .await
333 .ok_or_else(|| anyhow::anyhow!("Channel closed"))
334 }
335}
336
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Clone, Debug, PartialEq)]
    struct TestMessage {
        id: u32,
        content: String,
    }

    /// Builds a shared test message.
    fn message(id: u32, content: &str) -> Arc<TestMessage> {
        Arc::new(TestMessage {
            id,
            content: content.to_string(),
        })
    }

    #[tokio::test]
    async fn test_broadcast_dispatcher_single_receiver() {
        let dispatcher = BroadcastChangeDispatcher::<TestMessage>::new(100);
        let mut receiver = dispatcher.create_receiver().await.unwrap();

        let msg = message(1, "test");
        dispatcher.dispatch_change(msg.clone()).await.unwrap();

        let received = receiver.recv().await.unwrap();
        assert_eq!(*received, *msg);
    }

    #[tokio::test]
    async fn test_broadcast_dispatcher_multiple_receivers() {
        let dispatcher = BroadcastChangeDispatcher::<TestMessage>::new(100);
        let mut receiver1 = dispatcher.create_receiver().await.unwrap();
        let mut receiver2 = dispatcher.create_receiver().await.unwrap();

        let msg = message(1, "broadcast");
        dispatcher.dispatch_change(msg.clone()).await.unwrap();

        assert_eq!(*receiver1.recv().await.unwrap(), *msg);
        assert_eq!(*receiver2.recv().await.unwrap(), *msg);
    }

    #[tokio::test]
    async fn test_broadcast_dispatcher_dispatch_changes() {
        let dispatcher = BroadcastChangeDispatcher::<TestMessage>::new(100);
        let mut receiver = dispatcher.create_receiver().await.unwrap();

        let messages = vec![
            message(1, "first"),
            message(2, "second"),
            message(3, "third"),
        ];
        dispatcher.dispatch_changes(messages.clone()).await.unwrap();

        for expected in messages {
            let received = receiver.recv().await.unwrap();
            assert_eq!(*received, *expected);
        }
    }

    #[tokio::test]
    async fn test_broadcast_dispatch_without_receivers_is_ok() {
        let dispatcher = BroadcastChangeDispatcher::<TestMessage>::new(4);
        assert!(dispatcher
            .dispatch_change(message(1, "dropped"))
            .await
            .is_ok());
    }

    #[tokio::test]
    async fn test_broadcast_receiver_errors_after_dispatcher_dropped() {
        let dispatcher = BroadcastChangeDispatcher::<TestMessage>::new(4);
        let mut receiver = dispatcher.create_receiver().await.unwrap();
        drop(dispatcher);
        assert!(receiver.recv().await.is_err());
    }

    #[tokio::test]
    async fn test_channel_dispatcher_single_receiver() {
        let dispatcher = ChannelChangeDispatcher::<TestMessage>::new(100);
        let mut receiver = dispatcher.create_receiver().await.unwrap();

        let msg = message(1, "channel");
        dispatcher.dispatch_change(msg.clone()).await.unwrap();

        let received = receiver.recv().await.unwrap();
        assert_eq!(*received, *msg);
    }

    #[tokio::test]
    async fn test_channel_dispatcher_only_one_receiver() {
        let dispatcher = ChannelChangeDispatcher::<TestMessage>::new(100);
        let _receiver1 = dispatcher.create_receiver().await.unwrap();

        // `unwrap_err` is unavailable: the Ok type (a boxed trait object) is not `Debug`.
        match dispatcher.create_receiver().await {
            Ok(_) => panic!("second create_receiver call should fail"),
            Err(e) => assert!(e.to_string().contains("Receiver already created")),
        }
    }

    #[tokio::test]
    async fn test_channel_dispatcher_dispatch_changes() {
        let dispatcher = ChannelChangeDispatcher::<TestMessage>::new(100);
        let mut receiver = dispatcher.create_receiver().await.unwrap();

        let messages = vec![message(1, "first"), message(2, "second")];
        dispatcher.dispatch_changes(messages.clone()).await.unwrap();

        for expected in messages {
            let received = receiver.recv().await.unwrap();
            assert_eq!(*received, *expected);
        }
    }

    #[tokio::test]
    async fn test_channel_dispatch_fails_after_receiver_dropped() {
        let dispatcher = ChannelChangeDispatcher::<TestMessage>::new(4);
        let receiver = dispatcher.create_receiver().await.unwrap();
        drop(receiver);
        assert!(dispatcher.dispatch_change(message(1, "x")).await.is_err());
    }

    #[tokio::test]
    async fn test_channel_receiver_errors_after_dispatcher_dropped() {
        let dispatcher = ChannelChangeDispatcher::<TestMessage>::new(4);
        let mut receiver = dispatcher.create_receiver().await.unwrap();
        drop(dispatcher);
        assert!(receiver.recv().await.is_err());
    }

    #[tokio::test]
    async fn test_broadcast_receiver_handles_lag() {
        // Capacity 2 with 5 sends: ids 0..=2 are overwritten before the
        // receiver polls, so it must skip the gap and resume at id 3.
        let dispatcher = BroadcastChangeDispatcher::<TestMessage>::new(2);
        let mut receiver = dispatcher.create_receiver().await.unwrap();

        for i in 0..5 {
            dispatcher
                .dispatch_change(message(i, &format!("msg{i}")))
                .await
                .unwrap();
        }

        // Broadcast `send` is synchronous, so the overflow has already
        // happened; no sleep is needed.
        assert_eq!(receiver.recv().await.unwrap().id, 3);
        assert_eq!(receiver.recv().await.unwrap().id, 4);
    }

    #[test]
    fn test_dispatch_mode_default() {
        assert_eq!(DispatchMode::default(), DispatchMode::Channel);
    }

    #[test]
    fn test_dispatch_mode_serialization() {
        for (mode, expected) in [
            (DispatchMode::Broadcast, "\"broadcast\""),
            (DispatchMode::Channel, "\"channel\""),
        ] {
            let json = serde_json::to_string(&mode).unwrap();
            assert_eq!(json, expected);

            let deserialized: DispatchMode = serde_json::from_str(&json).unwrap();
            assert_eq!(deserialized, mode);
        }
    }
}