// drasi_lib/channels/dispatcher.rs
use anyhow::Result;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::fmt::Debug;
use std::sync::Arc;
use tokio::sync::{broadcast, mpsc};
7
8#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
158#[serde(rename_all = "lowercase")]
159pub enum DispatchMode {
160 Broadcast,
162 #[default]
164 Channel,
165}
166
167#[async_trait]
169pub trait ChangeDispatcher<T>: Send + Sync
170where
171 T: Clone + Send + Sync + 'static,
172{
173 async fn dispatch_change(&self, change: Arc<T>) -> Result<()>;
175
176 async fn dispatch_changes(&self, changes: Vec<Arc<T>>) -> Result<()> {
178 for change in changes {
179 self.dispatch_change(change).await?;
180 }
181 Ok(())
182 }
183
184 async fn create_receiver(&self) -> Result<Box<dyn ChangeReceiver<T>>>;
186}
187
188#[async_trait]
190pub trait ChangeReceiver<T>: Send + Sync
191where
192 T: Clone + Send + Sync + 'static,
193{
194 async fn recv(&mut self) -> Result<Arc<T>>;
196}
197
198pub struct BroadcastChangeDispatcher<T>
200where
201 T: Clone + Send + Sync + 'static,
202{
203 tx: broadcast::Sender<Arc<T>>,
204 _capacity: usize,
205}
206
207impl<T> BroadcastChangeDispatcher<T>
208where
209 T: Clone + Send + Sync + 'static,
210{
211 pub fn new(capacity: usize) -> Self {
213 let (tx, _) = broadcast::channel(capacity);
214 Self {
215 tx,
216 _capacity: capacity,
217 }
218 }
219}
220
221#[async_trait]
222impl<T> ChangeDispatcher<T> for BroadcastChangeDispatcher<T>
223where
224 T: Clone + Send + Sync + 'static,
225{
226 async fn dispatch_change(&self, change: Arc<T>) -> Result<()> {
227 let _ = self.tx.send(change);
229 Ok(())
230 }
231
232 async fn create_receiver(&self) -> Result<Box<dyn ChangeReceiver<T>>> {
233 let rx = self.tx.subscribe();
234 Ok(Box::new(BroadcastChangeReceiver { rx }))
235 }
236}
237
238pub struct BroadcastChangeReceiver<T>
240where
241 T: Clone + Send + Sync + 'static,
242{
243 rx: broadcast::Receiver<Arc<T>>,
244}
245
246#[async_trait]
247impl<T> ChangeReceiver<T> for BroadcastChangeReceiver<T>
248where
249 T: Clone + Send + Sync + 'static,
250{
251 async fn recv(&mut self) -> Result<Arc<T>> {
252 loop {
253 match self.rx.recv().await {
254 Ok(change) => return Ok(change),
255 Err(broadcast::error::RecvError::Closed) => {
256 return Err(anyhow::anyhow!("Broadcast channel closed"));
257 }
258 Err(broadcast::error::RecvError::Lagged(n)) => {
259 log::warn!("Broadcast receiver lagged by {n} messages");
260 }
262 }
263 }
264 }
265}
266
267pub struct ChannelChangeDispatcher<T>
269where
270 T: Clone + Send + Sync + 'static,
271{
272 tx: mpsc::Sender<Arc<T>>,
273 rx: Arc<tokio::sync::Mutex<Option<mpsc::Receiver<Arc<T>>>>>,
274 _capacity: usize,
275}
276
277impl<T> ChannelChangeDispatcher<T>
278where
279 T: Clone + Send + Sync + 'static,
280{
281 pub fn new(capacity: usize) -> Self {
283 let (tx, rx) = mpsc::channel(capacity);
284 Self {
285 tx,
286 rx: Arc::new(tokio::sync::Mutex::new(Some(rx))),
287 _capacity: capacity,
288 }
289 }
290}
291
292#[async_trait]
293impl<T> ChangeDispatcher<T> for ChannelChangeDispatcher<T>
294where
295 T: Clone + Send + Sync + 'static,
296{
297 async fn dispatch_change(&self, change: Arc<T>) -> Result<()> {
298 self.tx
299 .send(change)
300 .await
301 .map_err(|_| anyhow::anyhow!("Failed to send on channel"))?;
302 Ok(())
303 }
304
305 async fn create_receiver(&self) -> Result<Box<dyn ChangeReceiver<T>>> {
306 let mut rx_opt = self.rx.lock().await;
309 let rx = rx_opt.take().ok_or_else(|| {
310 anyhow::anyhow!("Receiver already created for this channel dispatcher")
311 })?;
312 Ok(Box::new(ChannelChangeReceiver { rx }))
313 }
314}
315
316pub struct ChannelChangeReceiver<T>
318where
319 T: Clone + Send + Sync + 'static,
320{
321 rx: mpsc::Receiver<Arc<T>>,
322}
323
324#[async_trait]
325impl<T> ChangeReceiver<T> for ChannelChangeReceiver<T>
326where
327 T: Clone + Send + Sync + 'static,
328{
329 async fn recv(&mut self) -> Result<Arc<T>> {
330 self.rx
331 .recv()
332 .await
333 .ok_or_else(|| anyhow::anyhow!("Channel closed"))
334 }
335}
336
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Clone, Debug, PartialEq)]
    struct TestMessage {
        id: u32,
        content: String,
    }

    /// Builds a shared `TestMessage` with the given id and content.
    fn message(id: u32, content: &str) -> Arc<TestMessage> {
        Arc::new(TestMessage {
            id,
            content: content.to_string(),
        })
    }

    #[tokio::test]
    async fn test_broadcast_dispatcher_single_receiver() {
        let dispatcher = BroadcastChangeDispatcher::<TestMessage>::new(100);
        let mut receiver = dispatcher.create_receiver().await.unwrap();

        let msg = message(1, "test");
        dispatcher.dispatch_change(msg.clone()).await.unwrap();

        let received = receiver.recv().await.unwrap();
        assert_eq!(*received, *msg);
    }

    #[tokio::test]
    async fn test_broadcast_dispatcher_multiple_receivers() {
        let dispatcher = BroadcastChangeDispatcher::<TestMessage>::new(100);
        let mut receiver1 = dispatcher.create_receiver().await.unwrap();
        let mut receiver2 = dispatcher.create_receiver().await.unwrap();

        let msg = message(1, "broadcast");
        dispatcher.dispatch_change(msg.clone()).await.unwrap();

        let received1 = receiver1.recv().await.unwrap();
        let received2 = receiver2.recv().await.unwrap();

        assert_eq!(*received1, *msg);
        assert_eq!(*received2, *msg);
    }

    #[tokio::test]
    async fn test_broadcast_dispatcher_dispatch_changes() {
        let dispatcher = BroadcastChangeDispatcher::<TestMessage>::new(100);
        let mut receiver = dispatcher.create_receiver().await.unwrap();

        let messages = vec![
            message(1, "first"),
            message(2, "second"),
            message(3, "third"),
        ];

        dispatcher.dispatch_changes(messages.clone()).await.unwrap();

        for expected in messages {
            let received = receiver.recv().await.unwrap();
            assert_eq!(*received, *expected);
        }
    }

    #[tokio::test]
    async fn test_broadcast_dispatch_without_receivers_succeeds() {
        // Broadcast delivery is best-effort: with no subscribers the send is dropped, not an error.
        let dispatcher = BroadcastChangeDispatcher::<TestMessage>::new(4);
        assert!(dispatcher
            .dispatch_change(message(1, "nobody listening"))
            .await
            .is_ok());
    }

    #[tokio::test]
    async fn test_broadcast_receiver_reports_closed() {
        let dispatcher = BroadcastChangeDispatcher::<TestMessage>::new(4);
        let mut receiver = dispatcher.create_receiver().await.unwrap();

        // Dropping the dispatcher drops the only sender, closing the channel.
        drop(dispatcher);

        let err = receiver.recv().await.unwrap_err();
        assert!(err.to_string().contains("Broadcast channel closed"));
    }

    #[tokio::test]
    async fn test_channel_dispatcher_single_receiver() {
        let dispatcher = ChannelChangeDispatcher::<TestMessage>::new(100);
        let mut receiver = dispatcher.create_receiver().await.unwrap();

        let msg = message(1, "channel");
        dispatcher.dispatch_change(msg.clone()).await.unwrap();

        let received = receiver.recv().await.unwrap();
        assert_eq!(*received, *msg);
    }

    #[tokio::test]
    async fn test_channel_dispatcher_only_one_receiver() {
        let dispatcher = ChannelChangeDispatcher::<TestMessage>::new(100);
        let _receiver1 = dispatcher.create_receiver().await.unwrap();

        match dispatcher.create_receiver().await {
            Ok(_) => panic!("a second receiver must not be created"),
            Err(e) => assert!(e.to_string().contains("Receiver already created")),
        }
    }

    #[tokio::test]
    async fn test_channel_dispatcher_dispatch_changes() {
        let dispatcher = ChannelChangeDispatcher::<TestMessage>::new(100);
        let mut receiver = dispatcher.create_receiver().await.unwrap();

        let messages = vec![message(1, "first"), message(2, "second")];

        dispatcher.dispatch_changes(messages.clone()).await.unwrap();

        for expected in messages {
            let received = receiver.recv().await.unwrap();
            assert_eq!(*received, *expected);
        }
    }

    #[tokio::test]
    async fn test_channel_receiver_drains_then_reports_closed() {
        let dispatcher = ChannelChangeDispatcher::<TestMessage>::new(4);
        let mut receiver = dispatcher.create_receiver().await.unwrap();

        let msg = message(7, "last");
        dispatcher.dispatch_change(msg.clone()).await.unwrap();
        // Dropping the dispatcher drops the only sender.
        drop(dispatcher);

        // Buffered changes are still delivered after the sender is gone...
        let received = receiver.recv().await.unwrap();
        assert_eq!(*received, *msg);
        // ...and only afterwards does the receiver report closure.
        let err = receiver.recv().await.unwrap_err();
        assert!(err.to_string().contains("Channel closed"));
    }

    #[tokio::test]
    async fn test_broadcast_receiver_handles_lag() {
        // Capacity 2 is already a power of two, so after five sends only ids 3 and 4 remain.
        let dispatcher = BroadcastChangeDispatcher::<TestMessage>::new(2);
        let mut receiver = dispatcher.create_receiver().await.unwrap();

        for i in 0..5 {
            dispatcher
                .dispatch_change(message(i, &format!("msg{i}")))
                .await
                .unwrap();
        }

        // The lag is absorbed inside `recv`; delivery resumes at the oldest retained change.
        let first = receiver.recv().await.unwrap();
        assert_eq!(first.id, 3);
        let second = receiver.recv().await.unwrap();
        assert_eq!(second.id, 4);
    }

    #[test]
    fn test_dispatch_mode_default() {
        assert_eq!(DispatchMode::default(), DispatchMode::Channel);
    }

    #[test]
    fn test_dispatch_mode_serialization() {
        let mode = DispatchMode::Broadcast;
        let json = serde_json::to_string(&mode).unwrap();
        assert_eq!(json, "\"broadcast\"");

        let deserialized: DispatchMode = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized, DispatchMode::Broadcast);

        let mode = DispatchMode::Channel;
        let json = serde_json::to_string(&mode).unwrap();
        assert_eq!(json, "\"channel\"");

        let deserialized: DispatchMode = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized, DispatchMode::Channel);
    }
}