1use std::collections::{HashMap, HashSet};
21use std::future::Future;
22use std::sync::{Arc, Mutex};
23
24use serde::Serialize;
25use thiserror::Error;
26use tokio::sync::broadcast;
27
/// Maximum number of outbound publishes buffered for the Redis publisher task.
///
/// When this queue is full, `RedisChannelsBackend::publish` rejects the message
/// with `ChannelPublishError::QueueFull` instead of waiting for space.
#[cfg(feature = "redis")]
const REDIS_PUBLISH_QUEUE_CAPACITY: usize = 1024;
30
/// Cheaply cloneable handle to a set of named pub/sub topics.
///
/// Every clone shares the same `Arc<dyn ChannelsBackend>`. A message published
/// through one clone therefore reaches subscribers obtained from any other clone.
#[derive(Clone)]
pub struct Channels {
    backend: Arc<dyn ChannelsBackend>,
}
36
/// Storage and fan-out strategy behind [`Channels`].
///
/// Implementations must be `Send + Sync + 'static` because [`Channels`] stores
/// them as a shared `Arc<dyn ChannelsBackend>`.
pub trait ChannelsBackend: Send + Sync + 'static {
    /// Publishes `msg` on `topic`.
    ///
    /// The backends in this module return the number of in-process subscribers
    /// that received the message (0 when there were none).
    ///
    /// # Errors
    ///
    /// Returns a [`ChannelPublishError`] when the backend refuses the message.
    /// For example, the Redis backend does this when its publish queue is full or closed.
    fn publish(&self, topic: &str, msg: ChannelMessage) -> Result<usize, ChannelPublishError>;

    /// Returns the broadcast sender for `topic`, registering the topic if needed.
    ///
    /// [`Channels::sender`] keeps the returned `Arc` alive as a keepalive handle.
    fn ensure_topic(&self, topic: &str) -> Arc<broadcast::Sender<ChannelMessage>>;

    /// Creates a new [`Subscriber`] for `topic`.
    fn subscribe(&self, topic: &str) -> Subscriber;

    /// Returns the number of topics currently registered.
    fn channel_count(&self) -> usize;

    /// Removes topics that are no longer in use.
    fn gc(&self);

    /// Returns per-topic subscriber counts and delivery metrics.
    fn snapshot(&self) -> HashMap<String, ChannelStats>;
}
62
/// In-process backend built on `tokio::sync::broadcast` channels.
///
/// Cloning is cheap: every clone shares one topic registry and one metrics table.
#[derive(Clone)]
pub struct LocalChannelsBackend {
    inner: Arc<LocalChannelsInner>,
}
68
/// Shared state behind a [`LocalChannelsBackend`].
struct LocalChannelsInner {
    // Per-topic broadcast buffer size (clamped to 1..=16_384 by `LocalChannelsBackend::new`).
    capacity: usize,
    // Topic name -> broadcast sender used to fan messages out to subscribers.
    registry: Mutex<HashMap<String, Arc<broadcast::Sender<ChannelMessage>>>>,
    // Per-topic counters; also shared with each `Subscriber` so it can record lag.
    metrics: Arc<ChannelMetrics>,
}
74
/// A text message carried on a channel topic.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ChannelMessage(pub String);
78
79impl From<String> for ChannelMessage {
80 fn from(s: String) -> Self {
81 Self(s)
82 }
83}
84
85impl From<&str> for ChannelMessage {
86 fn from(s: &str) -> Self {
87 Self(s.to_owned())
88 }
89}
90
91impl ChannelMessage {
92 #[must_use]
94 pub fn as_str(&self) -> &str {
95 &self.0
96 }
97
98 #[must_use]
100 pub fn into_string(self) -> String {
101 self.0
102 }
103}
104
105impl std::fmt::Display for ChannelMessage {
106 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
107 f.write_str(&self.0)
108 }
109}
110
/// Point-in-time statistics for one topic, as returned by `snapshot`.
///
/// The counters are dropped (reset) when `gc` removes the topic.
#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize)]
pub struct ChannelStats {
    /// Live receivers on the topic's broadcast channel (0 if the topic is not registered).
    pub subscriber_count: usize,
    /// Publishes that reached at least one local subscriber.
    pub lifetime_publish_count: u64,
    /// Publishes that found no local subscriber to deliver to.
    pub dropped_count: u64,
    /// Total messages skipped by subscribers that fell behind (`Lagged`).
    pub lagged_count: u64,
}
123
/// Per-topic publish/drop/lag counters, keyed by topic name.
#[derive(Default)]
struct ChannelMetrics {
    counters: Mutex<HashMap<String, ChannelMetricCounters>>,
}
128
/// Raw counters for one topic; exposed publicly through [`ChannelStats`].
#[derive(Clone, Default)]
struct ChannelMetricCounters {
    // Publishes delivered to at least one local subscriber.
    publishes: u64,
    // Publishes that found no local subscriber.
    drops: u64,
    // Messages skipped by lagging subscribers.
    lags: u64,
}
135
136impl ChannelMetrics {
137 fn ensure_topic(&self, topic: &str) {
138 let mut counters = self.counters.lock().expect("channel metrics lock poisoned");
139 counters.entry(topic.to_owned()).or_default();
140 }
141
142 fn record_publish(&self, topic: &str) {
143 let mut counters = self.counters.lock().expect("channel metrics lock poisoned");
144 let stats = counters.entry(topic.to_owned()).or_default();
145 stats.publishes = stats.publishes.saturating_add(1);
146 drop(counters);
147 }
148
149 fn record_dropped(&self, topic: &str, count: u64) {
150 let mut counters = self.counters.lock().expect("channel metrics lock poisoned");
151 let stats = counters.entry(topic.to_owned()).or_default();
152 stats.drops = stats.drops.saturating_add(count);
153 drop(counters);
154 }
155
156 fn record_lagged(&self, topic: &str, count: u64) {
157 let mut counters = self.counters.lock().expect("channel metrics lock poisoned");
158 let stats = counters.entry(topic.to_owned()).or_default();
159 stats.lags = stats.lags.saturating_add(count);
160 drop(counters);
161 }
162
163 fn snapshot(&self) -> HashMap<String, ChannelMetricCounters> {
164 self.counters
165 .lock()
166 .expect("channel metrics lock poisoned")
167 .clone()
168 }
169
170 fn remove_topics(&self, topics: &HashSet<String>) {
171 if topics.is_empty() {
172 return;
173 }
174
175 let mut counters = self.counters.lock().expect("channel metrics lock poisoned");
176 counters.retain(|topic, _| !topics.contains(topic));
177 drop(counters);
178 }
179}
180
/// Reasons a channel backend can refuse to publish a message.
#[derive(Debug, Clone, Error, PartialEq, Eq)]
pub enum ChannelPublishError {
    /// The backend can no longer accept messages (e.g. the Redis publisher queue is closed).
    #[error("channel backend is closed")]
    BackendClosed,
    /// The backend's bounded publish queue is full.
    ///
    /// The Redis backend returns this before delivering anything locally.
    #[error("channel backend publish queue is full")]
    QueueFull,
}
191
/// Errors returned by [`Broadcast`] publishing methods.
#[derive(Debug, Error)]
pub enum BroadcastError {
    /// A byte payload was not valid UTF-8 and was rejected before publishing.
    #[error("broadcast payload is not valid UTF-8: {0}")]
    InvalidUtf8(#[from] std::string::FromUtf8Error),

    /// The underlying channel backend rejected the publish.
    #[error(transparent)]
    Publish(#[from] ChannelPublishError),
}
204
/// Payload accepted by [`Broadcast::publish`].
pub enum BroadcastPayload {
    /// Text, published as-is.
    Text(String),
    /// Raw bytes; must be valid UTF-8 or publishing fails with `BroadcastError::InvalidUtf8`.
    Bytes(Vec<u8>),
}
212
213impl From<&str> for BroadcastPayload {
214 fn from(value: &str) -> Self {
215 Self::Text(value.to_owned())
216 }
217}
218
219impl From<String> for BroadcastPayload {
220 fn from(value: String) -> Self {
221 Self::Text(value)
222 }
223}
224
225impl From<&String> for BroadcastPayload {
226 fn from(value: &String) -> Self {
227 Self::Text(value.clone())
228 }
229}
230
231impl From<Vec<u8>> for BroadcastPayload {
232 fn from(value: Vec<u8>) -> Self {
233 Self::Bytes(value)
234 }
235}
236
237impl From<&[u8]> for BroadcastPayload {
238 fn from(value: &[u8]) -> Self {
239 Self::Bytes(value.to_vec())
240 }
241}
242
243impl<const N: usize> From<&[u8; N]> for BroadcastPayload {
244 fn from(value: &[u8; N]) -> Self {
245 Self::Bytes(value.to_vec())
246 }
247}
248
/// Convenience publisher over [`Channels`].
///
/// It accepts text, UTF-8 bytes, or (with the `maud` feature) HTML fragments.
#[derive(Clone)]
pub struct Broadcast {
    channels: Channels,
}
254
255impl Broadcast {
256 #[must_use]
258 pub const fn new(channels: Channels) -> Self {
259 Self { channels }
260 }
261
262 pub fn publish(
279 &self,
280 topic: &str,
281 payload: impl Into<BroadcastPayload>,
282 ) -> Result<usize, BroadcastError> {
283 let message = match payload.into() {
284 BroadcastPayload::Text(text) => ChannelMessage(text),
285 BroadcastPayload::Bytes(bytes) => ChannelMessage(String::from_utf8(bytes)?),
286 };
287 Ok(self.channels.publish(topic, message)?)
288 }
289
290 #[cfg(feature = "maud")]
308 pub fn publish_html(
309 &self,
310 topic: &str,
311 fragment: &maud::Markup,
312 ) -> Result<usize, BroadcastError> {
313 self.publish(topic, htmx_oob_envelope(fragment))
314 }
315}
316
/// Wraps `fragment` in `<template hx-swap-oob="true">…</template>` and renders it
/// to a string.
///
/// The fragment is inserted as already-rendered markup (not re-escaped).
/// NOTE(review): presumably meant for htmx out-of-band swaps over SSE/WebSocket.
/// Confirm that elements inside `fragment` carry the ids htmx needs to find their targets.
#[cfg(feature = "maud")]
fn htmx_oob_envelope(fragment: &maud::Markup) -> String {
    maud::html! {
        template hx-swap-oob="true" {
            (fragment)
        }
    }
    .into_string()
}
326
/// Publishing handle bound to one topic.
///
/// The `keepalive` clone of the topic's broadcast sender raises its `Arc` strong count.
/// The local backend's `gc` treats that as "in use" and keeps the topic.
#[derive(Clone)]
pub struct Sender {
    topic: String,
    backend: Arc<dyn ChannelsBackend>,
    keepalive: Arc<broadcast::Sender<ChannelMessage>>,
}
334
335impl Sender {
336 pub fn send(&self, msg: impl Into<ChannelMessage>) -> Result<usize, ChannelPublishError> {
345 self.backend.publish(&self.topic, msg.into())
346 }
347
348 #[must_use]
350 pub fn receiver_count(&self) -> usize {
351 self.keepalive.receiver_count()
352 }
353}
354
/// Receiving end of a topic subscription.
///
/// Lag events seen through this subscriber are added to the shared metrics under `topic`.
pub struct Subscriber {
    topic: String,
    inner: broadcast::Receiver<ChannelMessage>,
    metrics: Arc<ChannelMetrics>,
}
361
362impl Subscriber {
363 pub async fn recv(&mut self) -> Result<ChannelMessage, broadcast::error::RecvError> {
370 match self.inner.recv().await {
371 Err(broadcast::error::RecvError::Lagged(count)) => {
372 self.metrics.record_lagged(&self.topic, count);
373 Err(broadcast::error::RecvError::Lagged(count))
374 }
375 result => result,
376 }
377 }
378
379 pub fn try_recv(&mut self) -> Result<ChannelMessage, broadcast::error::TryRecvError> {
385 match self.inner.try_recv() {
386 Err(broadcast::error::TryRecvError::Lagged(count)) => {
387 self.metrics.record_lagged(&self.topic, count);
388 Err(broadcast::error::TryRecvError::Lagged(count))
389 }
390 result => result,
391 }
392 }
393
394 #[cfg(feature = "ws")]
396 pub fn into_stream(self) -> impl tokio_stream::Stream<Item = ChannelMessage> {
397 use tokio_stream::StreamExt;
398 let topic = self.topic;
399 let metrics = self.metrics;
400 tokio_stream::wrappers::BroadcastStream::new(self.inner).filter_map(move |result| {
401 if let Err(tokio_stream::wrappers::errors::BroadcastStreamRecvError::Lagged(count)) =
402 &result
403 {
404 metrics.record_lagged(&topic, *count);
405 }
406 result.ok()
407 })
408 }
409}
410
411impl LocalChannelsBackend {
412 #[must_use]
414 pub fn new(capacity: usize) -> Self {
415 Self {
416 inner: Arc::new(LocalChannelsInner {
417 capacity: capacity.clamp(1, 16_384),
418 registry: Mutex::new(HashMap::new()),
419 metrics: Arc::new(ChannelMetrics::default()),
420 }),
421 }
422 }
423
424 fn get_or_create_sender(&self, topic: &str) -> Arc<broadcast::Sender<ChannelMessage>> {
425 let mut registry = self.inner.registry.lock().expect("channels lock poisoned");
426
427 #[allow(clippy::option_if_let_else)]
428 if let Some(tx) = registry.get(topic) {
429 Arc::clone(tx)
430 } else {
431 let tx = Arc::new(broadcast::channel(self.inner.capacity).0);
432 registry.insert(topic.to_owned(), Arc::clone(&tx));
433 tx
434 }
435 }
436
437 fn publish_local(&self, topic: &str, msg: ChannelMessage) -> usize {
438 let count = self.send_without_publish_metric(topic, msg);
439 if count > 0 {
440 self.inner.metrics.record_publish(topic);
441 }
442 count
443 }
444
445 fn send_without_publish_metric(&self, topic: &str, msg: ChannelMessage) -> usize {
446 let tx = self.get_or_create_sender(topic);
447 match tx.send(msg) {
448 Ok(count) => count,
449 Err(_error) => {
450 self.inner.metrics.record_dropped(topic, 1);
451 0
452 }
453 }
454 }
455}
456
457impl ChannelsBackend for LocalChannelsBackend {
458 fn publish(&self, topic: &str, msg: ChannelMessage) -> Result<usize, ChannelPublishError> {
459 Ok(self.publish_local(topic, msg))
460 }
461
462 fn ensure_topic(&self, topic: &str) -> Arc<broadcast::Sender<ChannelMessage>> {
463 self.inner.metrics.ensure_topic(topic);
464 self.get_or_create_sender(topic)
465 }
466
467 fn subscribe(&self, topic: &str) -> Subscriber {
468 let tx = self.ensure_topic(topic);
469 Subscriber {
470 topic: topic.to_owned(),
471 inner: tx.subscribe(),
472 metrics: Arc::clone(&self.inner.metrics),
473 }
474 }
475
476 fn channel_count(&self) -> usize {
477 let registry = self.inner.registry.lock().expect("channels lock poisoned");
478 registry.len()
479 }
480
481 fn gc(&self) {
482 let mut registry = self.inner.registry.lock().expect("channels lock poisoned");
483 let mut removed_topics = HashSet::new();
484 registry.retain(|topic, tx| {
485 let keep = tx.receiver_count() > 0 || Arc::strong_count(tx) > 1;
486 if !keep {
487 removed_topics.insert(topic.clone());
488 }
489 keep
490 });
491 drop(registry);
492
493 self.inner.metrics.remove_topics(&removed_topics);
494 }
495
496 fn snapshot(&self) -> HashMap<String, ChannelStats> {
497 let subscriber_counts: HashMap<String, usize> = {
501 let registry = self.inner.registry.lock().expect("channels lock poisoned");
502 registry
503 .iter()
504 .map(|(topic, sender)| (topic.clone(), sender.receiver_count()))
505 .collect()
506 };
507 let metric_counters = self.inner.metrics.snapshot();
508
509 let mut topics: HashSet<String> = metric_counters.keys().cloned().collect();
510 topics.extend(subscriber_counts.keys().cloned());
511
512 topics
513 .into_iter()
514 .map(|topic| {
515 let subscriber_count = subscriber_counts.get(&topic).copied().unwrap_or(0);
516 let counters = metric_counters.get(&topic).cloned().unwrap_or_default();
517 (
518 topic,
519 ChannelStats {
520 subscriber_count,
521 lifetime_publish_count: counters.publishes,
522 dropped_count: counters.drops,
523 lagged_count: counters.lags,
524 },
525 )
526 })
527 .collect()
528 }
529}
530
/// Backend that fans out locally and also forwards every publish to Redis pub/sub.
///
/// Other processes listening on the same `key_prefix` receive the forwarded messages.
#[cfg(feature = "redis")]
#[derive(Clone)]
struct RedisChannelsBackend {
    // In-process fan-out. The listener task also delivers remote messages into it.
    local: LocalChannelsBackend,
    // Bounded queue feeding the background Redis publisher task.
    publisher: tokio::sync::mpsc::Sender<RedisPublishCommand>,
    // Random per-instance id stamped on outgoing envelopes so our own echoes can be skipped.
    origin_id: String,
    // Prefix of Redis channel names (`{key_prefix}:{topic}`).
    key_prefix: String,
}
539
/// A queued Redis publish: the target Redis channel plus the envelope to send as JSON.
#[cfg(feature = "redis")]
struct RedisPublishCommand {
    redis_channel: String,
    envelope: RedisEnvelope,
}
545
/// JSON wire format of messages exchanged through Redis.
#[cfg(feature = "redis")]
#[derive(serde::Deserialize, serde::Serialize)]
struct RedisEnvelope {
    // `origin_id` of the publishing backend instance.
    origin: String,
    // Local topic name; must match the Redis channel the envelope arrived on.
    topic: String,
    // Message text.
    payload: String,
}
553
/// Errors produced while building a [`Channels`] backend from configuration.
#[derive(Debug, Error)]
pub enum ChannelBackendConfigError {
    /// The Redis backend was selected but `channels.redis.url` is unset or blank.
    #[error("channels.redis.url is required when channels.backend = \"redis\"")]
    MissingRedisUrl,
    /// The Redis client rejected the configured URL; holds the client's error text.
    #[error("invalid channels.redis.url: {0}")]
    InvalidRedisUrl(String),
    /// The Redis backend was selected but this build lacks the `redis` cargo feature.
    #[error("channels.backend = \"redis\" requires the redis cargo feature")]
    RedisFeatureDisabled,
}
567
#[cfg(feature = "redis")]
impl RedisChannelsBackend {
    /// Builds a Redis-backed channel backend and starts its background publisher and
    /// listener tasks. Both tasks stop when `shutdown` is cancelled.
    ///
    /// Leading/trailing whitespace in `channels.redis.url` is ignored. The same trimmed
    /// value is used for the emptiness check and for parsing.
    ///
    /// Must run inside a Tokio runtime, because it calls `tokio::spawn`.
    ///
    /// # Errors
    ///
    /// - [`ChannelBackendConfigError::MissingRedisUrl`] if the URL is unset or blank.
    /// - [`ChannelBackendConfigError::InvalidRedisUrl`] if the Redis client rejects it.
    fn from_config(
        config: &crate::config::ChannelConfig,
        shutdown: tokio_util::sync::CancellationToken,
    ) -> Result<Self, ChannelBackendConfigError> {
        // Trim once and validate/parse the same value. Previously, " redis://… " passed
        // the blank check but was handed untrimmed to the URL parser.
        let url = config
            .redis
            .url
            .as_deref()
            .map(str::trim)
            .filter(|url| !url.is_empty())
            .ok_or(ChannelBackendConfigError::MissingRedisUrl)?;
        let client = redis::Client::open(url)
            .map_err(|error| ChannelBackendConfigError::InvalidRedisUrl(error.to_string()))?;
        let local = LocalChannelsBackend::new(config.capacity);
        let (publisher, receiver) = tokio::sync::mpsc::channel(REDIS_PUBLISH_QUEUE_CAPACITY);
        let origin_id = uuid::Uuid::new_v4().to_string();
        let backend = Self {
            local: local.clone(),
            publisher,
            origin_id: origin_id.clone(),
            key_prefix: config.redis.key_prefix.clone(),
        };
        spawn_redis_publisher(client.clone(), receiver, shutdown.clone());
        spawn_redis_listener(
            client,
            local,
            origin_id,
            config.redis.key_prefix.clone(),
            shutdown,
        );
        Ok(backend)
    }

    /// Redis channel name used for `topic` under this backend's key prefix.
    fn redis_channel(&self, topic: &str) -> String {
        redis_channel_name(&self.key_prefix, topic)
    }
}
606
/// Joins `prefix` and `topic` into a Redis channel name of the form `prefix:topic`.
#[cfg(feature = "redis")]
fn redis_channel_name(prefix: &str, topic: &str) -> String {
    let mut name = String::with_capacity(prefix.len() + 1 + topic.len());
    name.push_str(prefix);
    name.push(':');
    name.push_str(topic);
    name
}
611
/// Extracts the topic from a Redis channel name.
///
/// Returns `None` if `channel` does not start with `channel_prefix`.
#[cfg(feature = "redis")]
fn redis_channel_topic<'a>(channel_prefix: &str, channel: &'a str) -> Option<&'a str> {
    if channel.starts_with(channel_prefix) {
        channel.get(channel_prefix.len()..)
    } else {
        None
    }
}
616
/// Builds the `PSUBSCRIBE` pattern matching every channel under `prefix`.
///
/// The result is the escaped prefix followed by `:*`. Redis glob metacharacters in `prefix`
/// (`*`, `?`, `[`, `]`, `\`) are backslash-escaped so they match literally.
/// Without escaping, a prefix such as `app[1]` would match `app1:…` instead of its own
/// `app[1]:…` channels.
#[cfg(feature = "redis")]
fn redis_channel_pattern(prefix: &str) -> String {
    let mut pattern = String::with_capacity(prefix.len() + 2);
    for ch in prefix.chars() {
        if matches!(ch, '*' | '?' | '[' | ']' | '\\') {
            pattern.push('\\');
        }
        pattern.push(ch);
    }
    pattern.push_str(":*");
    pattern
}
621
/// Spawns the task that drains queued publishes and sends them to Redis as JSON envelopes.
///
/// The task exits when `shutdown` is cancelled or when every sender of the queue is
/// dropped (i.e. every `RedisChannelsBackend` clone is gone). Serialization and publish
/// failures are logged and the message is skipped.
#[cfg(feature = "redis")]
fn spawn_redis_publisher(
    client: redis::Client,
    mut receiver: tokio::sync::mpsc::Receiver<RedisPublishCommand>,
    shutdown: tokio_util::sync::CancellationToken,
) {
    tokio::spawn(async move {
        use redis::AsyncCommands as _;
        use redis::aio::{ConnectionManager, ConnectionManagerConfig};

        let mut connection =
            match ConnectionManager::new_lazy_with_config(client, ConnectionManagerConfig::new()) {
                Ok(connection) => connection,
                Err(error) => {
                    tracing::warn!(error = %error, "failed to create Redis channels publisher");
                    return;
                }
            };

        loop {
            // Match the raw `recv()` result instead of `Some(..)` in the select pattern.
            // With a `Some(..)` pattern, a closed queue only disables that branch, and the
            // loop would idle on `shutdown` forever instead of exiting.
            let command = tokio::select! {
                () = shutdown.cancelled() => break,
                command = receiver.recv() => command,
            };
            let Some(command) = command else {
                break;
            };

            let Ok(payload) = serde_json::to_string(&command.envelope) else {
                tracing::warn!("failed to serialize Redis channel envelope");
                continue;
            };
            if let Err(error) = connection
                .publish::<_, _, usize>(&command.redis_channel, payload)
                .await
            {
                tracing::warn!(error = %error, channel = %command.redis_channel, "Redis channel publish failed");
            }
        }
    });
}
661
/// Spawns the task that pattern-subscribes to `{key_prefix}:*` and relays remote messages
/// into `local`.
///
/// On connection, subscription or stream failure, the task waits a short backoff and
/// reconnects. Every await — connect, subscribe, receive and backoff — also watches
/// `shutdown`, so cancellation stops the task promptly.
#[cfg(feature = "redis")]
fn spawn_redis_listener(
    client: redis::Client,
    local: LocalChannelsBackend,
    origin_id: String,
    key_prefix: String,
    shutdown: tokio_util::sync::CancellationToken,
) {
    tokio::spawn(async move {
        use futures::StreamExt as _;

        let channel_prefix = redis_channel_name(&key_prefix, "");
        let pattern = redis_channel_pattern(&key_prefix);
        while !shutdown.is_cancelled() {
            let connected = tokio::select! {
                () = shutdown.cancelled() => return,
                connected = client.get_async_pubsub() => connected,
            };
            let mut pubsub = match connected {
                Ok(pubsub) => pubsub,
                Err(error) => {
                    tracing::warn!(error = %error, "failed to connect Redis channels listener");
                    if redis_listener_backoff(&shutdown).await {
                        return;
                    }
                    continue;
                }
            };

            let subscribed = tokio::select! {
                () = shutdown.cancelled() => return,
                subscribed = pubsub.psubscribe(&pattern) => subscribed,
            };
            if let Err(error) = subscribed {
                tracing::warn!(error = %error, pattern = %pattern, "failed to subscribe Redis channels listener");
                if redis_listener_backoff(&shutdown).await {
                    return;
                }
                continue;
            }

            let mut stream = pubsub.on_message();
            loop {
                tokio::select! {
                    () = shutdown.cancelled() => return,
                    message = stream.next() => {
                        let Some(message) = message else {
                            break;
                        };
                        let redis_channel = message.get_channel_name();
                        let payload: String = match message.get_payload() {
                            Ok(payload) => payload,
                            Err(error) => {
                                tracing::warn!(error = %error, "failed to decode Redis channel payload");
                                continue;
                            }
                        };
                        let envelope: RedisEnvelope = match serde_json::from_str(&payload) {
                            Ok(envelope) => envelope,
                            Err(error) => {
                                tracing::warn!(error = %error, "failed to parse Redis channel envelope");
                                continue;
                            }
                        };
                        deliver_redis_envelope(
                            &local,
                            &origin_id,
                            &channel_prefix,
                            redis_channel,
                            envelope,
                        );
                    }
                }
            }

            // The message stream ended (connection lost). Release the connection and back
            // off before reconnecting, so a server that keeps dropping us cannot trigger a
            // tight reconnect loop.
            drop(stream);
            drop(pubsub);
            tracing::warn!("Redis channels listener stream ended; reconnecting");
            if redis_listener_backoff(&shutdown).await {
                return;
            }
        }
    });
}

/// Waits out the listener's reconnect delay.
///
/// Returns `true` if `shutdown` was cancelled while waiting.
#[cfg(feature = "redis")]
async fn redis_listener_backoff(shutdown: &tokio_util::sync::CancellationToken) -> bool {
    tokio::select! {
        () = shutdown.cancelled() => true,
        () = tokio::time::sleep(std::time::Duration::from_millis(250)) => false,
    }
}
731
/// Delivers an envelope received from Redis to local subscribers.
///
/// The envelope is ignored (with a warning) if the channel name lacks the expected prefix
/// or if the envelope's topic disagrees with the channel it arrived on. Envelopes from this
/// instance are skipped silently, because `publish` already delivered them locally.
#[cfg(feature = "redis")]
fn deliver_redis_envelope(
    local: &LocalChannelsBackend,
    origin_id: &str,
    channel_prefix: &str,
    redis_channel: &str,
    envelope: RedisEnvelope,
) {
    let topic = match redis_channel_topic(channel_prefix, redis_channel) {
        Some(topic) => topic,
        None => {
            tracing::warn!(channel = %redis_channel, "Redis channel name did not match channel prefix");
            return;
        }
    };

    // A message published on one Redis channel must never land in a different local topic.
    if envelope.topic != topic {
        tracing::warn!(
            channel = %redis_channel,
            channel_topic = %topic,
            envelope_topic = %envelope.topic,
            "Redis channel envelope topic mismatch"
        );
        return;
    }

    let published_by_us = envelope.origin == origin_id;
    if !published_by_us {
        local.publish_local(topic, ChannelMessage(envelope.payload));
    }
}
761
#[cfg(feature = "redis")]
impl ChannelsBackend for RedisChannelsBackend {
    /// Queues `msg` for Redis, then delivers it to local subscribers.
    ///
    /// If the Redis queue refuses the message, nothing is delivered locally.
    fn publish(&self, topic: &str, msg: ChannelMessage) -> Result<usize, ChannelPublishError> {
        use tokio::sync::mpsc::error::TrySendError;

        let envelope = RedisEnvelope {
            origin: self.origin_id.clone(),
            topic: String::from(topic),
            payload: msg.0.clone(),
        };
        let command = RedisPublishCommand {
            redis_channel: self.redis_channel(topic),
            envelope,
        };
        if let Err(rejected) = self.publisher.try_send(command) {
            return Err(match rejected {
                TrySendError::Full(_) => ChannelPublishError::QueueFull,
                TrySendError::Closed(_) => ChannelPublishError::BackendClosed,
            });
        }
        Ok(self.local.publish_local(topic, msg))
    }

    /// Delegates to the local backend.
    fn ensure_topic(&self, topic: &str) -> Arc<broadcast::Sender<ChannelMessage>> {
        self.local.ensure_topic(topic)
    }

    /// Delegates to the local backend; remote messages arrive via the listener task.
    fn subscribe(&self, topic: &str) -> Subscriber {
        self.local.subscribe(topic)
    }

    /// Delegates to the local backend.
    fn channel_count(&self) -> usize {
        self.local.channel_count()
    }

    /// Delegates to the local backend.
    fn gc(&self) {
        self.local.gc();
    }

    /// Delegates to the local backend (local subscribers and metrics only).
    fn snapshot(&self) -> HashMap<String, ChannelStats> {
        self.local.snapshot()
    }
}
804
805impl Channels {
806 #[must_use]
808 pub fn new(capacity: usize) -> Self {
809 Self::with_backend(LocalChannelsBackend::new(capacity))
810 }
811
812 #[must_use]
814 pub fn with_backend(backend: impl ChannelsBackend) -> Self {
815 Self {
816 backend: Arc::new(backend),
817 }
818 }
819
820 #[must_use]
822 pub fn with_shared_backend(backend: Arc<dyn ChannelsBackend>) -> Self {
823 Self { backend }
824 }
825
826 pub fn from_config(
833 config: &crate::config::ChannelConfig,
834 shutdown: tokio_util::sync::CancellationToken,
835 ) -> Result<Self, ChannelBackendConfigError> {
836 match config.backend {
837 crate::config::ChannelBackend::InProcess => Ok(Self::new(config.capacity)),
838 crate::config::ChannelBackend::Redis => Self::redis_from_config(config, shutdown),
839 }
840 }
841
842 #[cfg(feature = "redis")]
843 fn redis_from_config(
844 config: &crate::config::ChannelConfig,
845 shutdown: tokio_util::sync::CancellationToken,
846 ) -> Result<Self, ChannelBackendConfigError> {
847 Ok(Self::with_backend(RedisChannelsBackend::from_config(
848 config, shutdown,
849 )?))
850 }
851
852 #[cfg(not(feature = "redis"))]
853 fn redis_from_config(
854 _config: &crate::config::ChannelConfig,
855 _shutdown: tokio_util::sync::CancellationToken,
856 ) -> Result<Self, ChannelBackendConfigError> {
857 Err(ChannelBackendConfigError::RedisFeatureDisabled)
858 }
859
860 #[must_use]
862 pub fn broadcast(&self) -> Broadcast {
863 Broadcast::new(self.clone())
864 }
865
866 pub fn publish(
872 &self,
873 topic: &str,
874 msg: impl Into<ChannelMessage>,
875 ) -> Result<usize, ChannelPublishError> {
876 self.backend.publish(topic, msg.into())
877 }
878
879 #[must_use]
881 pub fn sender(&self, name: &str) -> Sender {
882 let keepalive = self.backend.ensure_topic(name);
883 Sender {
884 topic: name.to_owned(),
885 backend: Arc::clone(&self.backend),
886 keepalive,
887 }
888 }
889
890 #[must_use]
892 pub fn subscribe(&self, name: &str) -> Subscriber {
893 self.backend.subscribe(name)
894 }
895
896 pub async fn subscribe_authorized<E, Fut>(
923 &self,
924 name: &str,
925 authorize: impl FnOnce(String) -> Fut,
926 ) -> Result<Subscriber, E>
927 where
928 Fut: Future<Output = Result<(), E>>,
929 {
930 authorize(name.to_owned()).await?;
931 Ok(self.subscribe(name))
932 }
933
934 #[must_use]
936 pub fn channel_count(&self) -> usize {
937 self.backend.channel_count()
938 }
939
940 pub fn gc(&self) {
942 self.backend.gc();
943 }
944
945 #[must_use]
947 pub fn snapshot(&self) -> HashMap<String, ChannelStats> {
948 self.backend.snapshot()
949 }
950
951 #[cfg(feature = "ws")]
953 pub fn sse_stream(
954 &self,
955 name: &str,
956 ) -> axum::response::sse::Sse<
957 impl tokio_stream::Stream<Item = Result<axum::response::sse::Event, std::convert::Infallible>>
958 + use<>,
959 > {
960 crate::sse::from_subscriber(self.subscribe(name))
961 }
962}
963
// Unit tests for the channel registry, the broadcast helpers, metrics, and the
// (feature-gated) Redis envelope handling.
#[cfg(test)]
mod tests {
    use super::*;

    // A fresh registry has no topics.
    #[test]
    fn create_channels() {
        let channels = Channels::new(16);
        assert_eq!(channels.channel_count(), 0);
    }

    // Obtaining a sender registers its topic.
    #[test]
    fn sender_creates_channel_lazily() {
        let channels = Channels::new(16);
        let _tx = channels.sender("test");
        assert_eq!(channels.channel_count(), 1);
    }

    // Subscribing registers its topic.
    #[test]
    fn subscribe_creates_channel_lazily() {
        let channels = Channels::new(16);
        let _rx = channels.subscribe("test");
        assert_eq!(channels.channel_count(), 1);
    }

    #[tokio::test]
    async fn send_and_receive() -> Result<(), broadcast::error::RecvError> {
        let channels = Channels::new(16);
        let tx = channels.sender("chat");
        let mut rx = channels.subscribe("chat");

        tx.send("hello").expect("should send");
        let msg = rx.recv().await?;
        assert_eq!(msg.as_str(), "hello");
        Ok(())
    }

    // Every subscriber of a topic gets its own copy of each message.
    #[tokio::test]
    async fn multiple_subscribers() -> Result<(), broadcast::error::RecvError> {
        let channels = Channels::new(16);
        let tx = channels.sender("chat");
        let mut rx1 = channels.subscribe("chat");
        let mut rx2 = channels.subscribe("chat");

        tx.send("broadcast").expect("should send");

        let msg1 = rx1.recv().await?;
        let msg2 = rx2.recv().await?;
        assert_eq!(msg1.as_str(), "broadcast");
        assert_eq!(msg2.as_str(), "broadcast");
        Ok(())
    }

    #[test]
    fn sender_receiver_count() {
        let channels = Channels::new(16);
        let tx = channels.sender("chat");
        assert_eq!(tx.receiver_count(), 0);

        let _rx1 = channels.subscribe("chat");
        assert_eq!(tx.receiver_count(), 1);

        let _rx2 = channels.subscribe("chat");
        assert_eq!(tx.receiver_count(), 2);
    }

    #[test]
    fn channel_message_conversions() {
        let msg: ChannelMessage = "hello".into();
        assert_eq!(msg.as_str(), "hello");
        assert_eq!(msg.to_string(), "hello");

        let msg2: ChannelMessage = String::from("world").into();
        assert_eq!(msg2.into_string(), "world");
    }

    // Compile-time check that `Channels` implements `Clone`.
    #[test]
    #[allow(clippy::redundant_clone)]
    fn channels_is_clone() {
        let channels = Channels::new(16);
        let _cloned = channels.clone();
    }

    #[test]
    fn snapshot_returns_counts() {
        let channels = Channels::new(16);
        let _tx = channels.sender("empty");

        let _tx2 = channels.sender("one");
        let _rx_one = channels.subscribe("one");

        let _tx3 = channels.sender("two");
        let _rx_two_1 = channels.subscribe("two");
        let _rx_two_2 = channels.subscribe("two");

        let snap = channels.snapshot();
        assert_eq!(
            snap.get("empty").map(|stats| stats.subscriber_count),
            Some(0)
        );
        assert_eq!(snap.get("one").map(|stats| stats.subscriber_count), Some(1));
        assert_eq!(snap.get("two").map(|stats| stats.subscriber_count), Some(2));
        assert_eq!(snap.len(), 3);
    }

    #[cfg(all(feature = "ws", feature = "maud"))]
    #[tokio::test]
    async fn broadcast_publish_html_wraps_fragment_in_hx_swap_oob_envelope()
    -> Result<(), broadcast::error::RecvError> {
        let channels = Channels::new(16);
        let broadcast = Broadcast::new(channels.clone());
        let mut rx = channels.subscribe("feed");

        let sent = broadcast
            .publish_html(
                "feed",
                &maud::html! {
                    li id="item-1" { "one" }
                },
            )
            .expect("html publish should succeed");

        assert_eq!(sent, 1);
        let msg = rx.recv().await?;
        assert!(msg.as_str().contains("hx-swap-oob"));
        // The fragment must be embedded as markup, not HTML-escaped text.
        assert!(msg.as_str().contains("<li id=\"item-1\">one</li>"));
        Ok(())
    }

    #[cfg(feature = "ws")]
    #[tokio::test]
    async fn broadcast_publish_raw_bytes_delivers_text_payload()
    -> Result<(), broadcast::error::RecvError> {
        let channels = Channels::new(16);
        let broadcast = Broadcast::new(channels.clone());
        let mut rx = channels.subscribe("raw");

        let sent = broadcast
            .publish("raw", b"hello".as_slice())
            .expect("raw publish should succeed");

        assert_eq!(sent, 1);
        assert_eq!(rx.recv().await?.as_str(), "hello");
        Ok(())
    }

    #[cfg(feature = "ws")]
    #[test]
    fn broadcast_publish_rejects_invalid_utf8_bytes() {
        let channels = Channels::new(16);
        let broadcast = Broadcast::new(channels);

        let error = broadcast
            .publish("raw", vec![0xff, 0xfe])
            .expect_err("invalid UTF-8 should be rejected before publishing");

        assert!(matches!(error, BroadcastError::InvalidUtf8(_)));
    }

    #[cfg(feature = "ws")]
    #[tokio::test]
    async fn snapshot_returns_channel_metrics() -> Result<(), broadcast::error::RecvError> {
        let channels = Channels::new(16);
        let broadcast = Broadcast::new(channels.clone());
        let mut rx = channels.subscribe("metrics");

        broadcast
            .publish("metrics", "one")
            .expect("publish should succeed");
        let _ = rx.recv().await?;

        let snap = channels.snapshot();
        let stats = snap.get("metrics").expect("topic should be tracked");
        assert_eq!(stats.subscriber_count, 1);
        assert_eq!(stats.lifetime_publish_count, 1);
        assert_eq!(stats.dropped_count, 0);
        assert_eq!(stats.lagged_count, 0);
        Ok(())
    }

    // A publish with no subscribers counts as a drop, not a publish.
    #[cfg(feature = "ws")]
    #[test]
    fn snapshot_counts_dropped_publish_without_successful_delivery() {
        let channels = Channels::new(16);
        let sent = channels
            .broadcast()
            .publish("metrics", "one")
            .expect("publish with no subscribers should not fail");

        assert_eq!(sent, 0);
        let snap = channels.snapshot();
        let stats = snap.get("metrics").expect("topic should be tracked");
        assert_eq!(stats.subscriber_count, 0);
        assert_eq!(stats.lifetime_publish_count, 0);
        assert_eq!(stats.dropped_count, 1);
        assert_eq!(stats.lagged_count, 0);
    }

    // `gc` must forget metrics for the topics it evicts, so they do not accumulate.
    #[test]
    fn gc_prunes_metrics_for_removed_idle_topics() {
        let channels = Channels::new(16);
        channels
            .publish("tenant:gone", "one")
            .expect("publish with no subscribers should only record a drop");

        let before_gc = channels.snapshot();
        assert!(before_gc.contains_key("tenant:gone"));

        channels.gc();

        let after_gc = channels.snapshot();
        assert!(!after_gc.contains_key("tenant:gone"));
        assert_eq!(channels.channel_count(), 0);
    }

    // An envelope claiming topic "private" but arriving on the "public" channel must
    // not reach "private" subscribers, and must not create "public".
    #[cfg(feature = "redis")]
    #[test]
    fn redis_listener_rejects_envelope_topic_that_mismatches_channel() {
        let local = LocalChannelsBackend::new(16);
        let mut private_rx = local.subscribe("private");
        let channel_prefix = redis_channel_name("autumn:channels", "");

        deliver_redis_envelope(
            &local,
            "local-origin",
            &channel_prefix,
            "autumn:channels:public",
            RedisEnvelope {
                origin: "remote-origin".to_owned(),
                topic: "private".to_owned(),
                payload: "secret".to_owned(),
            },
        );

        assert!(matches!(
            private_rx.try_recv(),
            Err(broadcast::error::TryRecvError::Empty)
        ));
        assert!(!local.snapshot().contains_key("public"));
    }

    #[cfg(feature = "redis")]
    #[test]
    fn redis_listener_counts_successful_remote_deliveries() {
        let local = LocalChannelsBackend::new(16);
        let mut rx = local.subscribe("public");
        let channel_prefix = redis_channel_name("autumn:channels", "");

        deliver_redis_envelope(
            &local,
            "local-origin",
            &channel_prefix,
            "autumn:channels:public",
            RedisEnvelope {
                origin: "remote-origin".to_owned(),
                topic: "public".to_owned(),
                payload: "hello".to_owned(),
            },
        );

        assert_eq!(
            rx.try_recv()
                .expect("remote message should fan out")
                .as_str(),
            "hello"
        );
        let snapshot = local.snapshot();
        let stats = snapshot.get("public").expect("topic should be tracked");
        assert_eq!(stats.lifetime_publish_count, 1);
        assert_eq!(stats.dropped_count, 0);
    }

    // A full Redis queue rejects the publish before any local delivery happens.
    // The queue has capacity 1 and is pre-filled; `_receiver` is kept alive so the
    // channel is full rather than closed.
    #[cfg(feature = "redis")]
    #[test]
    fn redis_publish_rejects_when_bounded_queue_is_full() {
        let local = LocalChannelsBackend::new(16);
        let mut rx = local.subscribe("queue");
        let (publisher, _receiver) = tokio::sync::mpsc::channel(1);
        publisher
            .try_send(RedisPublishCommand {
                redis_channel: "autumn:channels:queue".to_owned(),
                envelope: RedisEnvelope {
                    origin: "origin".to_owned(),
                    topic: "queue".to_owned(),
                    payload: "already queued".to_owned(),
                },
            })
            .expect("first command should fill the queue");

        let backend = RedisChannelsBackend {
            local,
            publisher,
            origin_id: "origin".to_owned(),
            key_prefix: "autumn:channels".to_owned(),
        };

        let error = backend
            .publish("queue", ChannelMessage::from("second"))
            .expect_err("full Redis queue should reject the publish");

        assert_eq!(error, ChannelPublishError::QueueFull);
        assert!(matches!(
            rx.try_recv(),
            Err(broadcast::error::TryRecvError::Empty)
        ));
    }

    // Lock-ordering regression test. While this thread holds the metrics lock, `snapshot`
    // (on another thread) must not keep the registry lock while it waits for metrics.
    // After releasing the registry guard, this thread polls `try_lock` on the registry.
    // If it can re-acquire the lock, `snapshot` released it. Timing is sleep-based.
    #[test]
    fn snapshot_releases_registry_before_waiting_on_metrics() {
        let backend = LocalChannelsBackend::new(16);
        backend.ensure_topic("race");

        let metrics_guard = backend
            .inner
            .metrics
            .counters
            .lock()
            .expect("channel metrics lock should not be poisoned");
        let registry_guard = backend
            .inner
            .registry
            .lock()
            .expect("channel registry lock should not be poisoned");
        let snapshot_backend = backend.clone();

        let handle = std::thread::spawn(move || {
            let snapshot = snapshot_backend.snapshot();
            assert!(snapshot.contains_key("race"));
        });

        std::thread::sleep(std::time::Duration::from_millis(25));
        drop(registry_guard);
        std::thread::sleep(std::time::Duration::from_millis(25));

        let deadline = std::time::Instant::now() + std::time::Duration::from_secs(1);
        let registry_released_before_metrics = loop {
            match backend.inner.registry.try_lock() {
                Ok(registry) => {
                    drop(registry);
                    break true;
                }
                Err(std::sync::TryLockError::WouldBlock)
                    if std::time::Instant::now() < deadline =>
                {
                    std::thread::yield_now();
                }
                Err(std::sync::TryLockError::WouldBlock) => break false,
                Err(std::sync::TryLockError::Poisoned(error)) => {
                    panic!("channel registry lock should not be poisoned: {error}");
                }
            }
        };

        drop(metrics_guard);
        handle.join().expect("snapshot thread should finish");
        assert!(
            registry_released_before_metrics,
            "snapshot held the registry mutex while waiting on metrics"
        );
    }

    #[cfg(feature = "ws")]
    #[tokio::test]
    async fn app_state_broadcast_uses_state_channels() -> Result<(), broadcast::error::RecvError> {
        let state = crate::AppState::for_test();
        let mut rx = state.channels().subscribe("state-topic");

        state
            .broadcast()
            .publish("state-topic", "from-state")
            .expect("publish should succeed");

        assert_eq!(rx.recv().await?.as_str(), "from-state");
        Ok(())
    }

    // A denied authorization must not create the topic as a side effect.
    #[cfg(feature = "ws")]
    #[tokio::test]
    async fn subscribe_authorized_rejects_before_creating_subscriber() {
        let channels = Channels::new(16);

        let result: Result<Subscriber, &'static str> = channels
            .subscribe_authorized("private", |topic| async move {
                assert_eq!(topic, "private");
                Err("denied")
            })
            .await;

        assert!(matches!(result, Err("denied")));
        assert!(!channels.snapshot().contains_key("private"));
    }

    #[cfg(feature = "ws")]
    #[tokio::test]
    async fn subscribe_authorized_allows_after_hook_passes()
    -> Result<(), broadcast::error::RecvError> {
        let channels = Channels::new(16);
        let mut rx = channels
            .subscribe_authorized("private", |topic| async move {
                assert_eq!(topic, "private");
                Ok::<(), std::convert::Infallible>(())
            })
            .await
            .expect("authorization should pass");

        channels
            .broadcast()
            .publish("private", "secret")
            .expect("publish should succeed");

        assert_eq!(rx.recv().await?.as_str(), "secret");
        Ok(())
    }

    // A live `Sender` keeps its topic; a dropped one lets `gc` evict it.
    #[test]
    fn gc_removes_dead_channels() {
        let channels = Channels::new(16);
        let _tx = channels.sender("alive");
        {
            let _tx = channels.sender("dead");
        }
        assert_eq!(channels.channel_count(), 2);
        channels.gc();
        assert_eq!(channels.channel_count(), 1);
    }

    #[cfg(feature = "ws")]
    #[tokio::test]
    async fn subscriber_into_stream() {
        use tokio_stream::StreamExt;
        let channels = Channels::new(16);
        let tx = channels.sender("test_stream");
        let rx = channels.subscribe("test_stream");

        tx.send("message 1").unwrap();
        tx.send("message 2").unwrap();

        let mut stream = rx.into_stream();
        let msg1 = stream.next().await.unwrap();
        assert_eq!(msg1.as_str(), "message 1");

        let msg2 = stream.next().await.unwrap();
        assert_eq!(msg2.as_str(), "message 2");
    }

    // Smoke test only: builds an SSE response and sends on its topic. The stream
    // itself is never polled, so delivery is not asserted here.
    #[cfg(feature = "ws")]
    #[tokio::test]
    async fn channels_sse_stream() {
        let channels = Channels::new(16);
        let tx = channels.sender("test_sse");

        let sse = channels.sse_stream("test_sse");

        tx.send("sse message").unwrap();
        let _stream = sse;
    }
}