1use std::collections::HashMap;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::{Arc, Mutex, RwLock};
5use std::time::Duration;
6
7use tokio::sync::mpsc;
8use tokio::time;
9
10use super::message::Message;
11
/// Queue capacity used by [`LocalChannel::with_defaults`].
const DEFAULT_CAPACITY: usize = 256;

/// Publish and consume timeout used by [`LocalChannel::with_defaults`], in milliseconds
/// (thirty seconds).
const DEFAULT_TIMEOUT_MS: i64 = 30 * 1_000;
18
/// Reasons a publish can fail. Each variant carries a human-readable detail string.
#[derive(Debug)]
pub enum PublishError {
    /// There is nothing to deliver to (no active receivers, or an unknown channel).
    Closed(String),
    /// The target queue was at capacity and the publisher was not allowed to wait.
    Full(String),
    /// The publisher waited for capacity but its timeout elapsed first.
    Timeout(String),
}

impl std::fmt::Display for PublishError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Pick a category prefix, then render uniformly as "<prefix>: <detail>".
        let (prefix, detail) = match self {
            Self::Closed(detail) => ("channel closed", detail),
            Self::Full(detail) => ("channel full", detail),
            Self::Timeout(detail) => ("publish timeout", detail),
        };
        write!(f, "{prefix}: {detail}")
    }
}

impl std::error::Error for PublishError {}
41
/// Reasons a consume can fail. Each variant carries a human-readable detail string.
#[derive(Debug)]
pub enum ConsumeError {
    /// No message became available within the consumer's timeout (or immediately, for a
    /// zero timeout).
    Timeout(String),
    /// The sending side is gone, so no further messages can arrive.
    Closed(String),
}

impl std::fmt::Display for ConsumeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Pick a category prefix, then render uniformly as "<prefix>: <detail>".
        let (prefix, detail) = match self {
            Self::Timeout(detail) => ("consume timeout", detail),
            Self::Closed(detail) => ("channel closed", detail),
        };
        write!(f, "{prefix}: {detail}")
    }
}

impl std::error::Error for ConsumeError {}
61
/// Failure while settling a consume receipt (ack/nack).
#[derive(Debug)]
// NOTE(review): `Failed` is never constructed in this file — the in-process `()` receipt
// always succeeds — hence the allow; presumably other channel backends produce it.
#[allow(dead_code)]
pub enum ReceiptError {
    /// Settling the receipt failed; the string carries the detail.
    Failed(String),
}

impl std::fmt::Display for ReceiptError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Single-variant enum, so the pattern is irrefutable.
        let Self::Failed(detail) = self;
        f.write_str("receipt error: ")?;
        f.write_str(detail)
    }
}

impl std::error::Error for ReceiptError {}
79
/// Handle delivered alongside each consumed message, used to report how processing went.
///
/// Both methods take `self` by value, so a receipt can be settled at most once.
/// The methods return `impl Future + Send` so callers can await them on multi-threaded
/// executors.
pub trait ConsumeReceipt: Send {
    /// Acknowledges the message (presumably "processed successfully" — confirm against the
    /// backends that implement this).
    fn ack(self) -> impl Future<Output = Result<(), ReceiptError>> + Send;
    /// Negatively acknowledges the message (presumably "processing failed" — confirm against
    /// the backends that implement this).
    fn nack(self) -> impl Future<Output = Result<(), ReceiptError>> + Send;
}

// NOTE(review): `confirm` is never called in this file, which is presumably why dead-code
// warnings are silenced here.
#[allow(dead_code)]
/// Handle returned by a successful publish.
///
/// Takes `self` by value, so it can be confirmed at most once.
pub trait PublishReceipt: Send {
    /// Resolves once the publish is confirmed (assumed to mean the backend accepted the
    /// message — verify for each backend). The in-process `()` receipt resolves immediately.
    fn confirm(self) -> impl Future<Output = Result<(), PublishError>> + Send;
}
95
96impl ConsumeReceipt for () {
97 async fn ack(self) -> Result<(), ReceiptError> {
98 Ok(())
99 }
100 async fn nack(self) -> Result<(), ReceiptError> {
101 Ok(())
102 }
103}
104
105impl PublishReceipt for () {
106 async fn confirm(self) -> Result<(), PublishError> {
107 Ok(())
108 }
109}
110
/// Policy a [`LocalChannel`] applies when a consumer group cannot keep up.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Overflow {
    /// Each group has its own bounded queue. When a queue is full, publishers wait for space
    /// or fail, according to the channel's publish timeout.
    Block,
    /// All groups share one broadcast buffer and publishing never waits. A group that falls
    /// too far behind skips its oldest messages, which is logged as lag when it next consumes.
    // NOTE(review): presumably allowed because outside `#[cfg(test)]` nothing in this file
    // constructs this variant — confirm before removing the attribute.
    #[allow(dead_code)]
    DropOldest,
}
120
/// A message channel that supports publishing and consumption by named consumer groups.
///
/// Methods return `impl Future + Send` rather than using `async fn`, so the returned futures
/// are guaranteed `Send` and can be spawned on multi-threaded executors.
pub trait Channel: Send + Sync {
    /// Receipt handed to the consumer together with each message.
    type ConsumeReceipt: ConsumeReceipt;
    /// Receipt returned by a successful publish.
    type PublishReceipt: PublishReceipt;

    /// Publishes `msg` on this channel.
    fn publish(
        &self,
        msg: Message,
    ) -> impl Future<Output = Result<Self::PublishReceipt, PublishError>> + Send;

    /// Receives the next message for consumer group `group`, together with a receipt for
    /// settling it. How messages are distributed across groups is up to the implementation.
    fn consume(
        &self,
        group: &str,
    ) -> impl Future<Output = Result<(Message, Self::ConsumeReceipt), ConsumeError>> + Send;
}
140
/// In-process [`Channel`] backed by tokio channels.
///
/// Both timeout fields use a signed-millisecond convention:
/// - a negative value waits indefinitely;
/// - `0` never waits;
/// - a positive value waits at most that many milliseconds.
pub struct LocalChannel {
    /// Bounded queue size for each group in `Block` mode. In `DropOldest` mode this is also
    /// the size the broadcast buffer was created with.
    capacity: usize,
    /// Publish timeout in milliseconds. Only consulted in `Block` mode, because broadcast
    /// publishing never waits.
    publish_timeout_ms: i64,
    /// Consume timeout in milliseconds, used in both modes.
    consume_timeout_ms: i64,
    /// State of the backing implementation, chosen by [`Overflow`].
    inner: ChannelInner,
}
158
/// Backing state of a `LocalChannel`, one variant per [`Overflow`] policy.
enum ChannelInner {
    /// One bounded `mpsc` queue per consumer group.
    ///
    /// Publishing clones the message into every group's queue. Consumers of the same group
    /// share a single receiver behind an async mutex, so each message reaches exactly one
    /// consumer within a group.
    Block {
        /// Maps group name to the sending half of that group's queue.
        senders: Mutex<HashMap<String, mpsc::Sender<Message>>>,
        /// Maps group name to the receiving half. The receiver sits behind a tokio mutex
        /// because it is held across `.await` while waiting.
        receivers: Mutex<HashMap<String, Arc<tokio::sync::Mutex<mpsc::Receiver<Message>>>>>,
    },
    /// One broadcast channel shared by every group. Each group owns a single subscription.
    DropOldest {
        /// The broadcast sender. Groups subscribe from it lazily.
        sender: tokio::sync::broadcast::Sender<Message>,
        /// Maps group name to that group's subscription, behind a tokio mutex for the same
        /// reason as in `Block`.
        receivers: Mutex<
            HashMap<String, Arc<tokio::sync::Mutex<tokio::sync::broadcast::Receiver<Message>>>>,
        >,
    },
}
175
176impl LocalChannel {
177 pub fn new(
179 capacity: usize,
180 overflow: Overflow,
181 publish_timeout_ms: i64,
182 consume_timeout_ms: i64,
183 ) -> Self {
184 let inner = match overflow {
185 Overflow::Block => ChannelInner::Block {
186 senders: Mutex::new(HashMap::new()),
187 receivers: Mutex::new(HashMap::new()),
188 },
189 Overflow::DropOldest => {
190 let (sender, _) = tokio::sync::broadcast::channel(capacity);
191 ChannelInner::DropOldest {
192 sender,
193 receivers: Mutex::new(HashMap::new()),
194 }
195 }
196 };
197
198 Self {
199 capacity,
200 publish_timeout_ms,
201 consume_timeout_ms,
202 inner,
203 }
204 }
205
206 pub(crate) fn init_group(&self, group: &str) {
209 match &self.inner {
210 ChannelInner::Block { senders, receivers } => {
211 let mut receiver_map = receivers.lock().unwrap();
212 if receiver_map.contains_key(group) {
213 return;
214 }
215 let (tx, rx) = mpsc::channel(self.capacity);
216 let arc = Arc::new(tokio::sync::Mutex::new(rx));
217 receiver_map.insert(group.to_string(), arc);
218 drop(receiver_map);
219 let mut sender_map = senders.lock().unwrap();
220 sender_map.insert(group.to_string(), tx);
221 }
222 ChannelInner::DropOldest { sender, receivers } => {
223 let mut map = receivers.lock().unwrap();
224 if map.contains_key(group) {
225 return;
226 }
227 let arc = Arc::new(tokio::sync::Mutex::new(sender.subscribe()));
228 map.insert(group.to_string(), arc);
229 }
230 }
231 }
232
233 pub(crate) fn close(&self) {
237 if let ChannelInner::Block { senders, .. } = &self.inner {
238 senders.lock().unwrap().clear();
239 }
240 }
241
242 pub fn with_defaults() -> Self {
244 Self::new(
245 DEFAULT_CAPACITY,
246 Overflow::Block,
247 DEFAULT_TIMEOUT_MS,
248 DEFAULT_TIMEOUT_MS,
249 )
250 }
251}
252
253impl Channel for LocalChannel {
254 type ConsumeReceipt = ();
255 type PublishReceipt = ();
256
257 async fn publish(&self, msg: Message) -> Result<(), PublishError> {
258 match &self.inner {
259 ChannelInner::Block { senders, .. } => {
260 let group_senders: Vec<mpsc::Sender<Message>> = {
261 let map = senders.lock().unwrap();
262 map.values().cloned().collect()
263 };
264
265 if group_senders.is_empty() {
266 return Err(PublishError::Closed("no active receivers".to_string()));
267 }
268
269 for sender in &group_senders {
270 match self.publish_timeout_ms {
271 t if t < 0 => {
272 sender.send(msg.clone()).await.map_err(|_| {
273 PublishError::Closed("no active receivers".to_string())
274 })?;
275 }
276 0 => {
277 sender.try_send(msg.clone()).map_err(|e| match e {
278 mpsc::error::TrySendError::Full(_) => {
279 PublishError::Full("channel at capacity".to_string())
280 }
281 mpsc::error::TrySendError::Closed(_) => {
282 PublishError::Closed("no active receivers".to_string())
283 }
284 })?;
285 }
286 t => {
287 let timeout_dur = Duration::from_millis(t as u64);
288 match time::timeout(timeout_dur, sender.send(msg.clone())).await {
289 Ok(Ok(())) => {}
290 Ok(Err(_)) => {
291 return Err(PublishError::Closed(
292 "no active receivers".to_string(),
293 ));
294 }
295 Err(_) => {
296 return Err(PublishError::Timeout(format!(
297 "no space after {t}ms"
298 )));
299 }
300 }
301 }
302 }
303 }
304
305 Ok(())
306 }
307 ChannelInner::DropOldest { sender, .. } => {
308 sender
309 .send(msg)
310 .map_err(|_| PublishError::Closed("no active receivers".to_string()))?;
311 Ok(())
312 }
313 }
314 }
315
316 async fn consume(&self, group: &str) -> Result<(Message, ()), ConsumeError> {
317 match &self.inner {
318 ChannelInner::Block { senders, receivers } => {
319 let receiver_mutex = {
320 let mut receiver_map = receivers.lock().unwrap();
321
322 if let Some(r) = receiver_map.get(group) {
323 Arc::clone(r)
324 } else {
325 let (tx, rx) = mpsc::channel(self.capacity);
326 let arc = Arc::new(tokio::sync::Mutex::new(rx));
327 receiver_map.insert(group.to_string(), Arc::clone(&arc));
328 drop(receiver_map);
329 let mut sender_map = senders.lock().unwrap();
330 sender_map.insert(group.to_string(), tx);
331 arc
332 }
333 };
334
335 let mut receiver = receiver_mutex.lock().await;
336 let msg = match self.consume_timeout_ms {
337 t if t < 0 => receiver
338 .recv()
339 .await
340 .ok_or_else(|| ConsumeError::Closed("no active senders".to_string())),
341 0 => receiver.try_recv().map_err(|e| match e {
342 mpsc::error::TryRecvError::Empty => {
343 ConsumeError::Timeout("no message available".to_string())
344 }
345 mpsc::error::TryRecvError::Disconnected => {
346 ConsumeError::Closed("no active senders".to_string())
347 }
348 }),
349 t => {
350 let timeout_dur = Duration::from_millis(t as u64);
351 match time::timeout(timeout_dur, receiver.recv()).await {
352 Ok(Some(msg)) => Ok(msg),
353 Ok(None) => Err(ConsumeError::Closed("no active senders".to_string())),
354 Err(_) => Err(ConsumeError::Timeout(format!("no message after {t}ms"))),
355 }
356 }
357 }?;
358 Ok((msg, ()))
359 }
360 ChannelInner::DropOldest { sender, receivers } => {
361 let receiver_mutex = {
362 let mut map = receivers.lock().unwrap();
363
364 if let Some(r) = map.get(group) {
365 Arc::clone(r)
366 } else {
367 let arc = Arc::new(tokio::sync::Mutex::new(sender.subscribe()));
368 map.insert(group.to_string(), Arc::clone(&arc));
369 arc
370 }
371 };
372
373 let mut receiver = receiver_mutex.lock().await;
374 let msg = match self.consume_timeout_ms {
375 0 => match receiver.try_recv() {
376 Ok(msg) => Ok(msg),
377 Err(tokio::sync::broadcast::error::TryRecvError::Empty) => {
378 Err(ConsumeError::Timeout("no message available".to_string()))
379 }
380 Err(tokio::sync::broadcast::error::TryRecvError::Lagged(n)) => {
381 tracing::warn!(
382 group,
383 skipped = n,
384 "consumer group lagged, skipped {n} messages"
385 );
386 Err(ConsumeError::Timeout("no message available".to_string()))
387 }
388 Err(tokio::sync::broadcast::error::TryRecvError::Closed) => {
389 Err(ConsumeError::Closed("no active senders".to_string()))
390 }
391 },
392 t => {
393 let recv_fut = async {
394 loop {
395 match receiver.recv().await {
396 Ok(msg) => return Ok(msg),
397 Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => {
398 tracing::warn!(
399 group,
400 skipped = n,
401 "consumer group lagged, skipped {n} messages"
402 );
403 continue;
404 }
405 Err(tokio::sync::broadcast::error::RecvError::Closed) => {
406 return Err(ConsumeError::Closed(
407 "no active senders".to_string(),
408 ));
409 }
410 }
411 }
412 };
413
414 if t < 0 {
415 recv_fut.await
416 } else {
417 let timeout_dur = Duration::from_millis(t as u64);
418 match time::timeout(timeout_dur, recv_fut).await {
419 Ok(result) => result,
420 Err(_) => {
421 Err(ConsumeError::Timeout(format!("no message after {t}ms")))
422 }
423 }
424 }
425 }
426 }?;
427 Ok((msg, ()))
428 }
429 }
430 }
431}
432
/// Publishes a message to a channel identified by name.
///
/// It returns a boxed future instead of `impl Future`. That keeps the trait dyn-compatible,
/// so it can be used as `dyn ReplyPublisher`, which the RPITIT methods on [`Channel`] cannot.
/// Presumably used for sending replies, as the name suggests — confirm with callers.
pub(crate) trait ReplyPublisher: Send + Sync {
    /// Publishes `msg` to the channel registered under `channel`. The returned future
    /// borrows both `self` and `channel` for `'a`.
    fn publish<'a>(
        &'a self,
        channel: &'a str,
        msg: Message,
    ) -> Pin<Box<dyn Future<Output = Result<(), PublishError>> + Send + 'a>>;
}
441
442impl<C: Channel + 'static> ReplyPublisher for ChannelRegistry<C> {
443 fn publish<'a>(
444 &'a self,
445 channel: &'a str,
446 msg: Message,
447 ) -> Pin<Box<dyn Future<Output = Result<(), PublishError>> + Send + 'a>> {
448 Box::pin(async move {
449 let ch = self
450 .lookup(channel)
451 .ok_or_else(|| PublishError::Closed(format!("channel '{channel}' not found")))?;
452 ch.publish(msg).await?;
453 Ok(())
454 })
455 }
456}
457
458pub struct ChannelRegistry<C: Channel> {
460 channels: RwLock<HashMap<String, Arc<C>>>,
461}
462
463impl<C: Channel> ChannelRegistry<C> {
464 pub fn new() -> Self {
465 Self {
466 channels: RwLock::new(HashMap::new()),
467 }
468 }
469
470 pub fn register(&self, name: impl Into<String>, channel: Arc<C>) {
472 self.channels
473 .write()
474 .expect("channel registry lock poisoned")
475 .insert(name.into(), channel);
476 }
477
478 #[allow(dead_code)]
480 pub fn remove(&self, name: &str) -> Option<Arc<C>> {
481 self.channels
482 .write()
483 .expect("channel registry lock poisoned")
484 .remove(name)
485 }
486
487 pub fn list(&self) -> Vec<Arc<C>> {
488 self.channels
489 .read()
490 .expect("channel registry lock poisoned")
491 .values()
492 .cloned()
493 .collect()
494 }
495
496 pub fn lookup(&self, name: &str) -> Option<Arc<C>> {
498 self.channels
499 .read()
500 .expect("channel registry lock poisoned")
501 .get(name)
502 .cloned()
503 }
504}
505
// Behavioural tests for `LocalChannel`.
//
// `#[tokio::test]` runs on a current-thread runtime by default. `tokio::task::yield_now()`
// therefore lets the spawned consumers run, register their group and park on `recv`
// before the test publishes.
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use crate::messaging::message::MessageBuilder;

    // No group has consumed or been initialised, so Block mode has no senders to deliver to.
    #[tokio::test]
    async fn publish_no_consumers_returns_error() {
        let channel = LocalChannel::with_defaults();

        let msg = MessageBuilder::new(b"hello".to_vec()).build();
        let result = channel.publish(msg).await;
        assert!(matches!(result, Err(PublishError::Closed(_))));
    }

    // DropOldest: the broadcast channel has no subscribers yet, so `send` fails.
    #[tokio::test]
    async fn publish_no_consumers_drop_oldest_returns_error() {
        let channel = LocalChannel::new(256, Overflow::DropOldest, -1, -1);

        let msg = MessageBuilder::new(b"hello".to_vec()).build();
        let result = channel.publish(msg).await;
        assert!(matches!(result, Err(PublishError::Closed(_))));
    }

    // A consumer that registered first receives the next published message.
    #[tokio::test]
    async fn consume_then_publish() {
        let channel = Arc::new(LocalChannel::with_defaults());

        let consumer = {
            let ch = channel.clone();
            tokio::spawn(async move { ch.consume("test").await })
        };

        tokio::task::yield_now().await;

        let msg = MessageBuilder::new(b"hello".to_vec()).build();
        channel.publish(msg).await.unwrap();

        let (received, _) = consumer.await.unwrap().unwrap();
        assert_eq!(received.body(), b"hello");
    }

    // Two consumers share group "test", so each of the two messages goes to a different one.
    #[tokio::test]
    async fn competing_consumers_same_group() {
        let channel = Arc::new(LocalChannel::with_defaults());

        let c1 = {
            let ch = channel.clone();
            tokio::spawn(async move { ch.consume("test").await })
        };
        let c2 = {
            let ch = channel.clone();
            tokio::spawn(async move { ch.consume("test").await })
        };

        tokio::task::yield_now().await;

        let msg1 = MessageBuilder::new(b"msg1".to_vec()).build();
        let msg2 = MessageBuilder::new(b"msg2".to_vec()).build();
        channel.publish(msg1).await.unwrap();
        channel.publish(msg2).await.unwrap();

        let (r1, _) = c1.await.unwrap().unwrap();
        let (r2, _) = c2.await.unwrap().unwrap();

        assert_ne!(r1.body(), r2.body());
    }

    // Different groups have separate queues, and each receives its own copy of a message.
    #[tokio::test]
    async fn independent_groups_each_get_message() {
        let channel = Arc::new(LocalChannel::with_defaults());

        let c1 = {
            let ch = channel.clone();
            tokio::spawn(async move { ch.consume("group-a").await })
        };
        let c2 = {
            let ch = channel.clone();
            tokio::spawn(async move { ch.consume("group-b").await })
        };

        tokio::task::yield_now().await;

        let msg = MessageBuilder::new(b"broadcast".to_vec()).build();
        channel.publish(msg).await.unwrap();

        let (r1, _) = c1.await.unwrap().unwrap();
        let (r2, _) = c2.await.unwrap().unwrap();

        assert_eq!(r1.body(), b"broadcast");
        assert_eq!(r2.body(), b"broadcast");
    }

    // Capacity 1 with a zero publish timeout. The consumer task has not run between the two
    // publishes, so the first message still occupies the queue and the second `try_send`
    // reports Full.
    #[tokio::test]
    async fn publish_timeout_zero_returns_full() {
        let channel = Arc::new(LocalChannel::new(1, Overflow::Block, 0, -1));

        let consumer = {
            let ch = channel.clone();
            tokio::spawn(async move { ch.consume("test").await })
        };

        tokio::task::yield_now().await;

        let msg1 = MessageBuilder::new(b"first".to_vec()).build();
        channel.publish(msg1).await.unwrap();

        let msg2 = MessageBuilder::new(b"second".to_vec()).build();
        let result = channel.publish(msg2).await;
        assert!(matches!(result, Err(PublishError::Full(_))));

        let _ = consumer.await;
    }

    // Capacity 1 with a 50 ms publish timeout. Once the consumer has taken "first", "second"
    // fills the queue and nobody drains it, so "third" times out.
    #[tokio::test]
    async fn publish_timeout_expires() {
        let channel = Arc::new(LocalChannel::new(1, Overflow::Block, 50, -1));

        let c1 = {
            let ch = channel.clone();
            tokio::spawn(async move { ch.consume("test").await })
        };

        tokio::task::yield_now().await;

        let msg1 = MessageBuilder::new(b"first".to_vec()).build();
        channel.publish(msg1).await.unwrap();

        let _ = c1.await;

        let msg2 = MessageBuilder::new(b"second".to_vec()).build();
        channel.publish(msg2).await.unwrap();

        let msg3 = MessageBuilder::new(b"third".to_vec()).build();
        let result = channel.publish(msg3).await;
        assert!(matches!(result, Err(PublishError::Timeout(_))));

        let _ = channel.consume("test").await;
    }

    // A 50 ms consume timeout on an empty queue reports Timeout rather than blocking.
    #[tokio::test]
    async fn consume_timeout_expires() {
        let channel = LocalChannel::new(256, Overflow::Block, -1, 50);

        let result = channel.consume("test").await;
        assert!(matches!(result, Err(ConsumeError::Timeout(_))));
    }

    // DropOldest: each group has its own broadcast subscription and sees every message.
    #[tokio::test]
    async fn drop_oldest_independent_groups() {
        let channel = Arc::new(LocalChannel::new(256, Overflow::DropOldest, -1, -1));

        let c1 = {
            let ch = channel.clone();
            tokio::spawn(async move { ch.consume("group-a").await })
        };
        let c2 = {
            let ch = channel.clone();
            tokio::spawn(async move { ch.consume("group-b").await })
        };

        tokio::task::yield_now().await;

        let msg = MessageBuilder::new(b"fanout".to_vec()).build();
        channel.publish(msg).await.unwrap();

        let (r1, _) = c1.await.unwrap().unwrap();
        let (r2, _) = c2.await.unwrap().unwrap();

        assert_eq!(r1.body(), b"fanout");
        assert_eq!(r2.body(), b"fanout");
    }

    // group-b registers only after "before" was published, so its first message is "after".
    // group-a receives both messages in order.
    #[tokio::test]
    async fn new_group_only_sees_messages_after_first_consume() {
        let channel = Arc::new(LocalChannel::with_defaults());

        let c1 = {
            let ch = channel.clone();
            tokio::spawn(async move { ch.consume("group-a").await })
        };

        tokio::task::yield_now().await;

        let msg1 = MessageBuilder::new(b"before".to_vec()).build();
        channel.publish(msg1).await.unwrap();

        let (r1, _) = c1.await.unwrap().unwrap();
        assert_eq!(r1.body(), b"before");

        let c2 = {
            let ch = channel.clone();
            tokio::spawn(async move { ch.consume("group-b").await })
        };

        tokio::task::yield_now().await;

        let c1_again = {
            let ch = channel.clone();
            tokio::spawn(async move { ch.consume("group-a").await })
        };

        let msg2 = MessageBuilder::new(b"after".to_vec()).build();
        channel.publish(msg2).await.unwrap();

        let (r2, _) = c2.await.unwrap().unwrap();
        let (r1_again, _) = c1_again.await.unwrap().unwrap();
        assert_eq!(r2.body(), b"after");
        assert_eq!(r1_again.body(), b"after");
    }
}