Skip to main content

autumn_web/
channels.rs

1//! Named broadcast channel registry for real-time messaging.
2//!
3//! [`Channels`] provides a lightweight pub-sub primitive with a local
4//! in-process backend by default and an optional Redis pub/sub backend for
5//! multi-replica fan-out.
6//!
7//! # Examples
8//!
9//! ```rust
10//! use autumn_web::channels::Channels;
11//!
12//! let channels = Channels::new(32);
13//! let tx = channels.sender("lobby");
14//! let mut rx = channels.subscribe("lobby");
15//!
16//! tx.send("hello").ok();
17//! # // In async context: let msg = rx.recv().await.expect("should receive");
18//! ```
19
20use std::collections::{HashMap, HashSet};
21use std::future::Future;
22use std::sync::{Arc, Mutex};
23
24use serde::Serialize;
25use thiserror::Error;
26use tokio::sync::broadcast;
27
/// Number of publish commands that may wait in the queue between callers and
/// the background Redis publisher task before `publish` reports `QueueFull`.
#[cfg(feature = "redis")]
const REDIS_PUBLISH_QUEUE_CAPACITY: usize = 1 << 10;
30
31/// A registry of named broadcast channels.
32#[derive(Clone)]
33pub struct Channels {
34    backend: Arc<dyn ChannelsBackend>,
35}
36
37/// Backend abstraction for channel fan-out.
38pub trait ChannelsBackend: Send + Sync + 'static {
39    /// Publish one message to a topic.
40    ///
41    /// # Errors
42    ///
43    /// Returns [`ChannelPublishError`] if the backend cannot accept the
44    /// publish request.
45    fn publish(&self, topic: &str, msg: ChannelMessage) -> Result<usize, ChannelPublishError>;
46
47    /// Ensure a local topic exists and return a keepalive sender handle.
48    fn ensure_topic(&self, topic: &str) -> Arc<broadcast::Sender<ChannelMessage>>;
49
50    /// Subscribe to future messages on a topic.
51    fn subscribe(&self, topic: &str) -> Subscriber;
52
53    /// Return the number of topics known to this backend.
54    fn channel_count(&self) -> usize;
55
56    /// Remove idle local topic registries when supported.
57    fn gc(&self);
58
59    /// Return per-topic subscriber and delivery metrics.
60    fn snapshot(&self) -> HashMap<String, ChannelStats>;
61}
62
63/// Local in-process [`tokio::sync::broadcast`] channel backend.
64#[derive(Clone)]
65pub struct LocalChannelsBackend {
66    inner: Arc<LocalChannelsInner>,
67}
68
69struct LocalChannelsInner {
70    capacity: usize,
71    registry: Mutex<HashMap<String, Arc<broadcast::Sender<ChannelMessage>>>>,
72    metrics: Arc<ChannelMetrics>,
73}
74
/// One text message delivered through a broadcast channel.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ChannelMessage(pub String);

impl From<String> for ChannelMessage {
    fn from(text: String) -> Self {
        ChannelMessage(text)
    }
}

impl From<&str> for ChannelMessage {
    fn from(text: &str) -> Self {
        ChannelMessage(String::from(text))
    }
}

impl ChannelMessage {
    /// Borrow the message text.
    #[must_use]
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }

    /// Take ownership of the message text.
    #[must_use]
    pub fn into_string(self) -> String {
        let ChannelMessage(text) = self;
        text
    }
}

impl std::fmt::Display for ChannelMessage {
    /// Writes the raw message text with no decoration.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}
110
111/// Per-topic channel metrics exposed by `/actuator/channels`.
112#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize)]
113pub struct ChannelStats {
114    /// Current active subscriber count.
115    pub subscriber_count: usize,
116    /// Successful local deliveries for this topic over this process lifetime.
117    pub lifetime_publish_count: u64,
118    /// Messages dropped because no local receiver accepted them.
119    pub dropped_count: u64,
120    /// Messages skipped by slow subscribers.
121    pub lagged_count: u64,
122}
123
/// Per-topic delivery counters shared by the local backend and its subscribers.
#[derive(Default)]
struct ChannelMetrics {
    counters: Mutex<HashMap<String, ChannelMetricCounters>>,
}

/// Raw counter values tracked for one topic.
#[derive(Clone, Default)]
struct ChannelMetricCounters {
    publishes: u64,
    drops: u64,
    lags: u64,
}

impl ChannelMetrics {
    /// Lock the counter map.
    ///
    /// # Panics
    ///
    /// Panics if another thread panicked while holding the lock (poisoned
    /// mutex), matching the other locks in this module.
    fn lock(&self) -> std::sync::MutexGuard<'_, HashMap<String, ChannelMetricCounters>> {
        self.counters.lock().expect("channel metrics lock poisoned")
    }

    /// Run `apply` on the counters for `topic`, creating them on first use.
    ///
    /// The lookup is done by `&str` first so the hot publish/lag paths only
    /// allocate an owned key the first time a topic is seen.
    fn update(&self, topic: &str, apply: impl FnOnce(&mut ChannelMetricCounters)) {
        let mut counters = self.lock();
        match counters.get_mut(topic) {
            Some(stats) => apply(stats),
            None => apply(counters.entry(topic.to_owned()).or_default()),
        }
    }

    /// Create an all-zero counter entry for `topic` if it does not exist yet.
    fn ensure_topic(&self, topic: &str) {
        self.update(topic, |_| {});
    }

    /// Count one publish that reached at least one receiver.
    fn record_publish(&self, topic: &str) {
        self.update(topic, |stats| stats.publishes = stats.publishes.saturating_add(1));
    }

    /// Count `count` messages that no receiver accepted (saturating).
    fn record_dropped(&self, topic: &str, count: u64) {
        self.update(topic, |stats| stats.drops = stats.drops.saturating_add(count));
    }

    /// Count `count` messages skipped by a lagging subscriber (saturating).
    fn record_lagged(&self, topic: &str, count: u64) {
        self.update(topic, |stats| stats.lags = stats.lags.saturating_add(count));
    }

    /// Copy the current counters for every known topic.
    fn snapshot(&self) -> HashMap<String, ChannelMetricCounters> {
        self.lock().clone()
    }

    /// Forget the counters of every topic in `topics`.
    fn remove_topics(&self, topics: &HashSet<String>) {
        // Skip taking the lock when there is nothing to remove.
        if topics.is_empty() {
            return;
        }
        self.lock().retain(|topic, _| !topics.contains(topic));
    }
}
180
181/// Error returned when a channel backend cannot accept a publish request.
182#[derive(Debug, Clone, Error, PartialEq, Eq)]
183pub enum ChannelPublishError {
184    /// The backend has shut down and can no longer accept publish requests.
185    #[error("channel backend is closed")]
186    BackendClosed,
187    /// The backend's bounded publish queue is full.
188    #[error("channel backend publish queue is full")]
189    QueueFull,
190}
191
192/// Error returned by the htmx/raw broadcast facade.
193#[derive(Debug, Error)]
194pub enum BroadcastError {
195    /// Raw byte payloads must be UTF-8 because htmx SSE and WebSocket text
196    /// transports consume text frames.
197    #[error("broadcast payload is not valid UTF-8: {0}")]
198    InvalidUtf8(#[from] std::string::FromUtf8Error),
199
200    /// The selected channel backend rejected the publish request.
201    #[error(transparent)]
202    Publish(#[from] ChannelPublishError),
203}
204
/// Payload accepted by [`Broadcast::publish`]: either text, or bytes that
/// are validated as UTF-8 at publish time.
pub enum BroadcastPayload {
    /// Already-decoded text.
    Text(String),
    /// Raw bytes, decoded as UTF-8 before publishing.
    Bytes(Vec<u8>),
}

impl From<&str> for BroadcastPayload {
    fn from(value: &str) -> Self {
        BroadcastPayload::Text(String::from(value))
    }
}

impl From<String> for BroadcastPayload {
    fn from(value: String) -> Self {
        BroadcastPayload::Text(value)
    }
}

impl From<&String> for BroadcastPayload {
    fn from(value: &String) -> Self {
        BroadcastPayload::Text(value.to_owned())
    }
}

impl From<Vec<u8>> for BroadcastPayload {
    fn from(value: Vec<u8>) -> Self {
        BroadcastPayload::Bytes(value)
    }
}

impl From<&[u8]> for BroadcastPayload {
    fn from(value: &[u8]) -> Self {
        BroadcastPayload::Bytes(Vec::from(value))
    }
}

impl<const N: usize> From<&[u8; N]> for BroadcastPayload {
    fn from(value: &[u8; N]) -> Self {
        BroadcastPayload::Bytes(value.as_slice().to_vec())
    }
}
248
249/// Productive publishing facade for htmx-oriented applications.
250#[derive(Clone)]
251pub struct Broadcast {
252    channels: Channels,
253}
254
255impl Broadcast {
256    /// Create a broadcast facade from a channel registry.
257    #[must_use]
258    pub const fn new(channels: Channels) -> Self {
259        Self { channels }
260    }
261
262    /// Publish a raw UTF-8 payload to a topic.
263    ///
264    /// ```
265    /// use autumn_web::channels::Channels;
266    ///
267    /// let channels = Channels::new(16);
268    /// channels
269    ///     .broadcast()
270    ///     .publish("feed", b"raw fragment".as_slice())
271    ///     .expect("raw publish should succeed");
272    /// ```
273    ///
274    /// # Errors
275    ///
276    /// Returns [`BroadcastError::InvalidUtf8`] for invalid byte payloads or
277    /// [`BroadcastError::Publish`] when the backend rejects the publish.
278    pub fn publish(
279        &self,
280        topic: &str,
281        payload: impl Into<BroadcastPayload>,
282    ) -> Result<usize, BroadcastError> {
283        let message = match payload.into() {
284            BroadcastPayload::Text(text) => ChannelMessage(text),
285            BroadcastPayload::Bytes(bytes) => ChannelMessage(String::from_utf8(bytes)?),
286        };
287        Ok(self.channels.publish(topic, message)?)
288    }
289
290    /// Publish a Maud fragment wrapped in an htmx out-of-band envelope.
291    ///
292    /// ```
293    /// use autumn_web::channels::Channels;
294    /// use maud::html;
295    ///
296    /// let channels = Channels::new(16);
297    /// channels
298    ///     .broadcast()
299    ///     .publish_html("feed", &html! { div id="notice" { "Saved" } })
300    ///     .expect("html publish should succeed");
301    /// ```
302    ///
303    /// # Errors
304    ///
305    /// Returns [`BroadcastError::Publish`] when the selected backend rejects
306    /// the publish request.
307    #[cfg(feature = "maud")]
308    pub fn publish_html(
309        &self,
310        topic: &str,
311        fragment: &maud::Markup,
312    ) -> Result<usize, BroadcastError> {
313        self.publish(topic, htmx_oob_envelope(fragment))
314    }
315}
316
/// Render `fragment` inside `<template hx-swap-oob="true">` so htmx performs
/// an out-of-band swap of its contents.
#[cfg(feature = "maud")]
fn htmx_oob_envelope(fragment: &maud::Markup) -> String {
    let envelope = maud::html! {
        template hx-swap-oob="true" { (fragment) }
    };
    envelope.into_string()
}
326
327/// A sender handle for a broadcast channel.
328#[derive(Clone)]
329pub struct Sender {
330    topic: String,
331    backend: Arc<dyn ChannelsBackend>,
332    keepalive: Arc<broadcast::Sender<ChannelMessage>>,
333}
334
335impl Sender {
336    /// Broadcast a message to all current subscribers of this channel.
337    ///
338    /// Publishing to a topic with no subscribers is not fatal; the backend
339    /// records a drop metric and returns `Ok(0)`.
340    ///
341    /// # Errors
342    ///
343    /// Returns [`ChannelPublishError`] if the backend is closed.
344    pub fn send(&self, msg: impl Into<ChannelMessage>) -> Result<usize, ChannelPublishError> {
345        self.backend.publish(&self.topic, msg.into())
346    }
347
348    /// Returns the current number of active subscribers.
349    #[must_use]
350    pub fn receiver_count(&self) -> usize {
351        self.keepalive.receiver_count()
352    }
353}
354
355/// A subscriber handle for a broadcast channel.
356pub struct Subscriber {
357    topic: String,
358    inner: broadcast::Receiver<ChannelMessage>,
359    metrics: Arc<ChannelMetrics>,
360}
361
362impl Subscriber {
363    /// Receive the next message from the channel.
364    ///
365    /// # Errors
366    ///
367    /// Returns `RecvError::Closed` if all senders have been dropped, or
368    /// `RecvError::Lagged(n)` if messages were skipped.
369    pub async fn recv(&mut self) -> Result<ChannelMessage, broadcast::error::RecvError> {
370        match self.inner.recv().await {
371            Err(broadcast::error::RecvError::Lagged(count)) => {
372                self.metrics.record_lagged(&self.topic, count);
373                Err(broadcast::error::RecvError::Lagged(count))
374            }
375            result => result,
376        }
377    }
378
379    /// Try to receive a message without waiting.
380    ///
381    /// # Errors
382    ///
383    /// Returns the same errors as [`broadcast::Receiver::try_recv`].
384    pub fn try_recv(&mut self) -> Result<ChannelMessage, broadcast::error::TryRecvError> {
385        match self.inner.try_recv() {
386            Err(broadcast::error::TryRecvError::Lagged(count)) => {
387                self.metrics.record_lagged(&self.topic, count);
388                Err(broadcast::error::TryRecvError::Lagged(count))
389            }
390            result => result,
391        }
392    }
393
394    /// Convert this subscriber into a stream of channel messages.
395    #[cfg(feature = "ws")]
396    pub fn into_stream(self) -> impl tokio_stream::Stream<Item = ChannelMessage> {
397        use tokio_stream::StreamExt;
398        let topic = self.topic;
399        let metrics = self.metrics;
400        tokio_stream::wrappers::BroadcastStream::new(self.inner).filter_map(move |result| {
401            if let Err(tokio_stream::wrappers::errors::BroadcastStreamRecvError::Lagged(count)) =
402                &result
403            {
404                metrics.record_lagged(&topic, *count);
405            }
406            result.ok()
407        })
408    }
409}
410
411impl LocalChannelsBackend {
412    /// Create a local backend with the given per-topic buffer capacity.
413    #[must_use]
414    pub fn new(capacity: usize) -> Self {
415        Self {
416            inner: Arc::new(LocalChannelsInner {
417                capacity: capacity.clamp(1, 16_384),
418                registry: Mutex::new(HashMap::new()),
419                metrics: Arc::new(ChannelMetrics::default()),
420            }),
421        }
422    }
423
424    fn get_or_create_sender(&self, topic: &str) -> Arc<broadcast::Sender<ChannelMessage>> {
425        let mut registry = self.inner.registry.lock().expect("channels lock poisoned");
426
427        #[allow(clippy::option_if_let_else)]
428        if let Some(tx) = registry.get(topic) {
429            Arc::clone(tx)
430        } else {
431            let tx = Arc::new(broadcast::channel(self.inner.capacity).0);
432            registry.insert(topic.to_owned(), Arc::clone(&tx));
433            tx
434        }
435    }
436
437    fn publish_local(&self, topic: &str, msg: ChannelMessage) -> usize {
438        let count = self.send_without_publish_metric(topic, msg);
439        if count > 0 {
440            self.inner.metrics.record_publish(topic);
441        }
442        count
443    }
444
445    fn send_without_publish_metric(&self, topic: &str, msg: ChannelMessage) -> usize {
446        let tx = self.get_or_create_sender(topic);
447        match tx.send(msg) {
448            Ok(count) => count,
449            Err(_error) => {
450                self.inner.metrics.record_dropped(topic, 1);
451                0
452            }
453        }
454    }
455}
456
457impl ChannelsBackend for LocalChannelsBackend {
458    fn publish(&self, topic: &str, msg: ChannelMessage) -> Result<usize, ChannelPublishError> {
459        Ok(self.publish_local(topic, msg))
460    }
461
462    fn ensure_topic(&self, topic: &str) -> Arc<broadcast::Sender<ChannelMessage>> {
463        self.inner.metrics.ensure_topic(topic);
464        self.get_or_create_sender(topic)
465    }
466
467    fn subscribe(&self, topic: &str) -> Subscriber {
468        let tx = self.ensure_topic(topic);
469        Subscriber {
470            topic: topic.to_owned(),
471            inner: tx.subscribe(),
472            metrics: Arc::clone(&self.inner.metrics),
473        }
474    }
475
476    fn channel_count(&self) -> usize {
477        let registry = self.inner.registry.lock().expect("channels lock poisoned");
478        registry.len()
479    }
480
481    fn gc(&self) {
482        let mut registry = self.inner.registry.lock().expect("channels lock poisoned");
483        let mut removed_topics = HashSet::new();
484        registry.retain(|topic, tx| {
485            let keep = tx.receiver_count() > 0 || Arc::strong_count(tx) > 1;
486            if !keep {
487                removed_topics.insert(topic.clone());
488            }
489            keep
490        });
491        drop(registry);
492
493        self.inner.metrics.remove_topics(&removed_topics);
494    }
495
496    fn snapshot(&self) -> HashMap<String, ChannelStats> {
497        // Keep registry and metrics collection in separate phases. Publish and
498        // subscribe paths touch metrics before registry, so snapshot must never
499        // hold the registry mutex while reading metrics.
500        let subscriber_counts: HashMap<String, usize> = {
501            let registry = self.inner.registry.lock().expect("channels lock poisoned");
502            registry
503                .iter()
504                .map(|(topic, sender)| (topic.clone(), sender.receiver_count()))
505                .collect()
506        };
507        let metric_counters = self.inner.metrics.snapshot();
508
509        let mut topics: HashSet<String> = metric_counters.keys().cloned().collect();
510        topics.extend(subscriber_counts.keys().cloned());
511
512        topics
513            .into_iter()
514            .map(|topic| {
515                let subscriber_count = subscriber_counts.get(&topic).copied().unwrap_or(0);
516                let counters = metric_counters.get(&topic).cloned().unwrap_or_default();
517                (
518                    topic,
519                    ChannelStats {
520                        subscriber_count,
521                        lifetime_publish_count: counters.publishes,
522                        dropped_count: counters.drops,
523                        lagged_count: counters.lags,
524                    },
525                )
526            })
527            .collect()
528    }
529}
530
/// Channel backend that delivers locally and mirrors every publish through
/// Redis pub/sub so other replicas can fan it out to their own subscribers.
#[cfg(feature = "redis")]
#[derive(Clone)]
struct RedisChannelsBackend {
    /// In-process fan-out for this replica's subscribers.
    local: LocalChannelsBackend,
    /// Bounded queue feeding the background Redis publisher task.
    publisher: tokio::sync::mpsc::Sender<RedisPublishCommand>,
    /// Random per-instance id stamped on outgoing envelopes so the listener
    /// can ignore this instance's own messages.
    origin_id: String,
    /// Prefix prepended to every Redis channel name.
    key_prefix: String,
}
539
/// One queued publish: the target Redis channel and the envelope to send.
#[cfg(feature = "redis")]
struct RedisPublishCommand {
    redis_channel: String,
    envelope: RedisEnvelope,
}
545
/// JSON wire format exchanged between replicas over Redis pub/sub.
///
/// Field names and order form the wire format and must stay stable.
#[cfg(feature = "redis")]
#[derive(serde::Deserialize, serde::Serialize)]
struct RedisEnvelope {
    /// `origin_id` of the publishing backend instance.
    origin: String,
    /// Logical topic name, without the Redis key prefix.
    topic: String,
    /// Message text.
    payload: String,
}
553
554/// Channel backend configuration error.
555#[derive(Debug, Error)]
556pub enum ChannelBackendConfigError {
557    /// `channels.backend = "redis"` needs `channels.redis.url`.
558    #[error("channels.redis.url is required when channels.backend = \"redis\"")]
559    MissingRedisUrl,
560    /// Redis URL failed validation by the Redis client.
561    #[error("invalid channels.redis.url: {0}")]
562    InvalidRedisUrl(String),
563    /// The `redis` cargo feature is required for the Redis backend.
564    #[error("channels.backend = \"redis\" requires the redis cargo feature")]
565    RedisFeatureDisabled,
566}
567
#[cfg(feature = "redis")]
impl RedisChannelsBackend {
    /// Validate the Redis settings, then start the publisher and listener
    /// tasks.
    ///
    /// The tasks are started with `tokio::spawn`, so this must run inside a
    /// Tokio runtime. Both tasks stop when `shutdown` is cancelled.
    fn from_config(
        config: &crate::config::ChannelConfig,
        shutdown: tokio_util::sync::CancellationToken,
    ) -> Result<Self, ChannelBackendConfigError> {
        let Some(url) = config
            .redis
            .url
            .clone()
            .filter(|raw| !raw.trim().is_empty())
        else {
            return Err(ChannelBackendConfigError::MissingRedisUrl);
        };
        let client = redis::Client::open(url)
            .map_err(|error| ChannelBackendConfigError::InvalidRedisUrl(error.to_string()))?;

        let local = LocalChannelsBackend::new(config.capacity);
        let (publisher, receiver) = tokio::sync::mpsc::channel(REDIS_PUBLISH_QUEUE_CAPACITY);
        let origin_id = uuid::Uuid::new_v4().to_string();
        let key_prefix = config.redis.key_prefix.clone();

        spawn_redis_publisher(client.clone(), receiver, shutdown.clone());
        spawn_redis_listener(
            client,
            local.clone(),
            origin_id.clone(),
            key_prefix.clone(),
            shutdown,
        );

        Ok(Self {
            local,
            publisher,
            origin_id,
            key_prefix,
        })
    }

    /// Full Redis channel name used for `topic`.
    fn redis_channel(&self, topic: &str) -> String {
        redis_channel_name(&self.key_prefix, topic)
    }
}
606
/// Redis channel carrying `topic`: `"{prefix}:{topic}"`.
#[cfg(feature = "redis")]
fn redis_channel_name(prefix: &str, topic: &str) -> String {
    [prefix, topic].join(":")
}

/// Recover the topic from a Redis channel name by stripping `channel_prefix`
/// (which already includes the trailing `:`). Returns `None` if it does not
/// match.
#[cfg(feature = "redis")]
fn redis_channel_topic<'a>(channel_prefix: &str, channel: &'a str) -> Option<&'a str> {
    channel.strip_prefix(channel_prefix)
}

/// Redis `PSUBSCRIBE` pattern matching every channel under `prefix`.
#[cfg(feature = "redis")]
fn redis_channel_pattern(prefix: &str) -> String {
    [prefix, "*"].join(":")
}
621
/// Spawn the task that forwards queued publish commands to Redis.
///
/// The task runs until `shutdown` is cancelled or the publish queue closes
/// because every `RedisChannelsBackend` handle has been dropped. Serialization
/// and publish failures are logged and that command is skipped; they never
/// stop the task.
#[cfg(feature = "redis")]
fn spawn_redis_publisher(
    client: redis::Client,
    mut receiver: tokio::sync::mpsc::Receiver<RedisPublishCommand>,
    shutdown: tokio_util::sync::CancellationToken,
) {
    tokio::spawn(async move {
        use redis::AsyncCommands as _;
        use redis::aio::{ConnectionManager, ConnectionManagerConfig};

        let mut connection =
            match ConnectionManager::new_lazy_with_config(client, ConnectionManagerConfig::new()) {
                Ok(connection) => connection,
                Err(error) => {
                    tracing::warn!(error = %error, "failed to create Redis channels publisher");
                    return;
                }
            };

        loop {
            let command = tokio::select! {
                () = shutdown.cancelled() => break,
                command = receiver.recv() => command,
            };
            // `recv` yields `None` once all senders are dropped. Matching
            // `Some(..)` in the select arm would only disable that arm and
            // leave the task parked on `shutdown`, so handle it explicitly.
            let Some(command) = command else {
                break;
            };
            let Ok(payload) = serde_json::to_string(&command.envelope) else {
                tracing::warn!("failed to serialize Redis channel envelope");
                continue;
            };
            if let Err(error) = connection
                .publish::<_, _, usize>(&command.redis_channel, payload)
                .await
            {
                tracing::warn!(error = %error, channel = %command.redis_channel, "Redis channel publish failed");
            }
        }
    });
}
661
/// Delay between failed Redis listener connection/subscription attempts.
#[cfg(feature = "redis")]
const REDIS_LISTENER_RETRY_DELAY: std::time::Duration = std::time::Duration::from_millis(250);

/// Wait [`REDIS_LISTENER_RETRY_DELAY`] before retrying.
///
/// Returns `false` if `shutdown` was cancelled during the wait, so callers
/// can stop immediately instead of after the full delay.
#[cfg(feature = "redis")]
async fn wait_before_redis_retry(shutdown: &tokio_util::sync::CancellationToken) -> bool {
    tokio::select! {
        () = shutdown.cancelled() => false,
        () = tokio::time::sleep(REDIS_LISTENER_RETRY_DELAY) => true,
    }
}

/// Spawn the task that pattern-subscribes to `"{key_prefix}:*"` and hands
/// every envelope from another replica to the local backend.
///
/// Connection and subscription failures are retried after a short delay;
/// when the message stream ends, the task logs it and reconnects. The task
/// exits as soon as `shutdown` is cancelled, including while connecting or
/// waiting to retry.
#[cfg(feature = "redis")]
fn spawn_redis_listener(
    client: redis::Client,
    local: LocalChannelsBackend,
    origin_id: String,
    key_prefix: String,
    shutdown: tokio_util::sync::CancellationToken,
) {
    tokio::spawn(async move {
        use futures::StreamExt as _;

        let channel_prefix = redis_channel_name(&key_prefix, "");
        let pattern = redis_channel_pattern(&key_prefix);
        while !shutdown.is_cancelled() {
            let connected = tokio::select! {
                () = shutdown.cancelled() => return,
                result = client.get_async_pubsub() => result,
            };
            let mut pubsub = match connected {
                Ok(pubsub) => pubsub,
                Err(error) => {
                    tracing::warn!(error = %error, "failed to connect Redis channels listener");
                    if !wait_before_redis_retry(&shutdown).await {
                        return;
                    }
                    continue;
                }
            };

            if let Err(error) = pubsub.psubscribe(&pattern).await {
                tracing::warn!(error = %error, pattern = %pattern, "failed to subscribe Redis channels listener");
                if !wait_before_redis_retry(&shutdown).await {
                    return;
                }
                continue;
            }

            let mut stream = pubsub.on_message();
            loop {
                tokio::select! {
                    () = shutdown.cancelled() => return,
                    message = stream.next() => {
                        let Some(message) = message else {
                            break;
                        };
                        let redis_channel = message.get_channel_name();
                        let payload: String = match message.get_payload() {
                            Ok(payload) => payload,
                            Err(error) => {
                                tracing::warn!(error = %error, "failed to decode Redis channel payload");
                                continue;
                            }
                        };
                        let envelope: RedisEnvelope = match serde_json::from_str(&payload) {
                            Ok(envelope) => envelope,
                            Err(error) => {
                                tracing::warn!(error = %error, "failed to parse Redis channel envelope");
                                continue;
                            }
                        };
                        deliver_redis_envelope(
                            &local,
                            &origin_id,
                            &channel_prefix,
                            redis_channel,
                            envelope,
                        );
                    }
                }
            }
            // The message stream ended, so the subscription is gone; make
            // that visible before reconnecting.
            tracing::warn!(pattern = %pattern, "Redis channels listener stream ended; reconnecting");
        }
    });
}
731
/// Validate an envelope received from Redis and publish it locally.
///
/// The envelope is dropped with a warning if the channel name lacks the
/// expected prefix, or if the topic in the channel name disagrees with the
/// envelope's topic. Envelopes published by this instance (`origin_id`) are
/// dropped silently, since they were already delivered locally at publish
/// time.
#[cfg(feature = "redis")]
fn deliver_redis_envelope(
    local: &LocalChannelsBackend,
    origin_id: &str,
    channel_prefix: &str,
    redis_channel: &str,
    envelope: RedisEnvelope,
) {
    let RedisEnvelope {
        origin,
        topic: envelope_topic,
        payload,
    } = envelope;

    match redis_channel_topic(channel_prefix, redis_channel) {
        None => {
            tracing::warn!(channel = %redis_channel, "Redis channel name did not match channel prefix");
        }
        Some(topic) if envelope_topic != topic => {
            tracing::warn!(
                channel = %redis_channel,
                channel_topic = %topic,
                envelope_topic = %envelope_topic,
                "Redis channel envelope topic mismatch"
            );
        }
        Some(_) if origin == origin_id => {}
        Some(topic) => {
            local.publish_local(topic, ChannelMessage(payload));
        }
    }
}
761
#[cfg(feature = "redis")]
impl ChannelsBackend for RedisChannelsBackend {
    /// Queue the message for Redis, then deliver it to local subscribers.
    ///
    /// Local delivery only happens if the Redis queue accepted the command.
    fn publish(&self, topic: &str, msg: ChannelMessage) -> Result<usize, ChannelPublishError> {
        use tokio::sync::mpsc::error::TrySendError;

        let envelope = RedisEnvelope {
            origin: self.origin_id.clone(),
            topic: topic.to_owned(),
            payload: msg.0.clone(),
        };
        let command = RedisPublishCommand {
            redis_channel: self.redis_channel(topic),
            envelope,
        };
        if let Err(error) = self.publisher.try_send(command) {
            return Err(match error {
                TrySendError::Full(_) => ChannelPublishError::QueueFull,
                TrySendError::Closed(_) => ChannelPublishError::BackendClosed,
            });
        }
        Ok(self.local.publish_local(topic, msg))
    }

    // Every remaining operation is delegated unchanged to the local backend.

    fn ensure_topic(&self, topic: &str) -> Arc<broadcast::Sender<ChannelMessage>> {
        self.local.ensure_topic(topic)
    }

    fn subscribe(&self, topic: &str) -> Subscriber {
        self.local.subscribe(topic)
    }

    fn channel_count(&self) -> usize {
        self.local.channel_count()
    }

    fn gc(&self) {
        self.local.gc();
    }

    fn snapshot(&self) -> HashMap<String, ChannelStats> {
        self.local.snapshot()
    }
}
804
805impl Channels {
806    /// Create a new local channel registry with the given buffer capacity.
807    #[must_use]
808    pub fn new(capacity: usize) -> Self {
809        Self::with_backend(LocalChannelsBackend::new(capacity))
810    }
811
812    /// Create a registry from any backend implementation.
813    #[must_use]
814    pub fn with_backend(backend: impl ChannelsBackend) -> Self {
815        Self {
816            backend: Arc::new(backend),
817        }
818    }
819
820    /// Create a registry from a shared backend implementation.
821    #[must_use]
822    pub fn with_shared_backend(backend: Arc<dyn ChannelsBackend>) -> Self {
823        Self { backend }
824    }
825
826    /// Create a channel registry from resolved framework config.
827    ///
828    /// # Errors
829    ///
830    /// Returns [`ChannelBackendConfigError`] when a Redis backend is requested
831    /// without usable Redis configuration or without the `redis` feature.
832    pub fn from_config(
833        config: &crate::config::ChannelConfig,
834        shutdown: tokio_util::sync::CancellationToken,
835    ) -> Result<Self, ChannelBackendConfigError> {
836        match config.backend {
837            crate::config::ChannelBackend::InProcess => Ok(Self::new(config.capacity)),
838            crate::config::ChannelBackend::Redis => Self::redis_from_config(config, shutdown),
839        }
840    }
841
842    #[cfg(feature = "redis")]
843    fn redis_from_config(
844        config: &crate::config::ChannelConfig,
845        shutdown: tokio_util::sync::CancellationToken,
846    ) -> Result<Self, ChannelBackendConfigError> {
847        Ok(Self::with_backend(RedisChannelsBackend::from_config(
848            config, shutdown,
849        )?))
850    }
851
852    #[cfg(not(feature = "redis"))]
853    fn redis_from_config(
854        _config: &crate::config::ChannelConfig,
855        _shutdown: tokio_util::sync::CancellationToken,
856    ) -> Result<Self, ChannelBackendConfigError> {
857        Err(ChannelBackendConfigError::RedisFeatureDisabled)
858    }
859
860    /// Return a htmx-friendly broadcast facade.
861    #[must_use]
862    pub fn broadcast(&self) -> Broadcast {
863        Broadcast::new(self.clone())
864    }
865
866    /// Publish a raw channel message through the selected backend.
867    ///
868    /// # Errors
869    ///
870    /// Returns [`ChannelPublishError`] if the backend is closed.
871    pub fn publish(
872        &self,
873        topic: &str,
874        msg: impl Into<ChannelMessage>,
875    ) -> Result<usize, ChannelPublishError> {
876        self.backend.publish(topic, msg.into())
877    }
878
879    /// Get or create a sender for the named channel.
880    #[must_use]
881    pub fn sender(&self, name: &str) -> Sender {
882        let keepalive = self.backend.ensure_topic(name);
883        Sender {
884            topic: name.to_owned(),
885            backend: Arc::clone(&self.backend),
886            keepalive,
887        }
888    }
889
890    /// Subscribe to the named channel.
891    #[must_use]
892    pub fn subscribe(&self, name: &str) -> Subscriber {
893        self.backend.subscribe(name)
894    }
895
896    /// Authorize a channel subscription before allocating the subscriber.
897    ///
898    /// The hook receives the requested topic name. If it returns an error,
899    /// no subscriber is created and the error is returned unchanged.
900    ///
901    /// ```rust,no_run
902    /// use autumn_web::channels::Channels;
903    ///
904    /// # async fn example(channels: Channels) -> autumn_web::AutumnResult<()> {
905    /// let mut rx = channels
906    ///     .subscribe_authorized("private-feed", |topic| async move {
907    ///         if topic == "private-feed" {
908    ///             Ok(())
909    ///         } else {
910    ///             Err(autumn_web::AutumnError::forbidden_msg("not your feed"))
911    ///         }
912    ///     })
913    ///     .await?;
914    /// # let _ = &mut rx;
915    /// # Ok(())
916    /// # }
917    /// ```
918    ///
919    /// # Errors
920    ///
921    /// Returns the error produced by the authorization hook.
922    pub async fn subscribe_authorized<E, Fut>(
923        &self,
924        name: &str,
925        authorize: impl FnOnce(String) -> Fut,
926    ) -> Result<Subscriber, E>
927    where
928        Fut: Future<Output = Result<(), E>>,
929    {
930        authorize(name.to_owned()).await?;
931        Ok(self.subscribe(name))
932    }
933
934    /// Returns the number of active topics in the registry.
935    #[must_use]
936    pub fn channel_count(&self) -> usize {
937        self.backend.channel_count()
938    }
939
940    /// Remove channels with no active senders or receivers.
941    pub fn gc(&self) {
942        self.backend.gc();
943    }
944
945    /// Get a snapshot of all active channels and their metrics.
946    #[must_use]
947    pub fn snapshot(&self) -> HashMap<String, ChannelStats> {
948        self.backend.snapshot()
949    }
950
951    /// Creates an SSE response stream for a channel.
952    #[cfg(feature = "ws")]
953    pub fn sse_stream(
954        &self,
955        name: &str,
956    ) -> axum::response::sse::Sse<
957        impl tokio_stream::Stream<Item = Result<axum::response::sse::Event, std::convert::Infallible>>
958        + use<>,
959    > {
960        crate::sse::from_subscriber(self.subscribe(name))
961    }
962}
963
#[cfg(test)]
mod tests {
    // Unit tests for the `Channels` registry. Most tests use the default
    // in-process backend. The `redis`-gated tests drive the Redis helpers
    // directly through in-memory queues rather than a live server.
    use super::*;

    #[test]
    fn create_channels() {
        let channels = Channels::new(16);
        assert_eq!(channels.channel_count(), 0);
    }

    #[test]
    fn sender_creates_channel_lazily() {
        let channels = Channels::new(16);
        let _tx = channels.sender("test");
        assert_eq!(channels.channel_count(), 1);
    }

    #[test]
    fn subscribe_creates_channel_lazily() {
        let channels = Channels::new(16);
        let _rx = channels.subscribe("test");
        assert_eq!(channels.channel_count(), 1);
    }

    #[tokio::test]
    async fn send_and_receive() -> Result<(), broadcast::error::RecvError> {
        let channels = Channels::new(16);
        let tx = channels.sender("chat");
        let mut rx = channels.subscribe("chat");

        tx.send("hello").expect("should send");
        let msg = rx.recv().await?;
        assert_eq!(msg.as_str(), "hello");
        Ok(())
    }

    #[tokio::test]
    async fn multiple_subscribers() -> Result<(), broadcast::error::RecvError> {
        let channels = Channels::new(16);
        let tx = channels.sender("chat");
        let mut rx1 = channels.subscribe("chat");
        let mut rx2 = channels.subscribe("chat");

        tx.send("broadcast").expect("should send");

        let msg1 = rx1.recv().await?;
        let msg2 = rx2.recv().await?;
        assert_eq!(msg1.as_str(), "broadcast");
        assert_eq!(msg2.as_str(), "broadcast");
        Ok(())
    }

    #[test]
    fn sender_receiver_count() {
        let channels = Channels::new(16);
        let tx = channels.sender("chat");
        assert_eq!(tx.receiver_count(), 0);

        let _rx1 = channels.subscribe("chat");
        assert_eq!(tx.receiver_count(), 1);

        let _rx2 = channels.subscribe("chat");
        assert_eq!(tx.receiver_count(), 2);
    }

    #[test]
    fn channel_message_conversions() {
        let msg: ChannelMessage = "hello".into();
        assert_eq!(msg.as_str(), "hello");
        assert_eq!(msg.to_string(), "hello");

        let msg2: ChannelMessage = String::from("world").into();
        assert_eq!(msg2.into_string(), "world");
    }

    #[test]
    // This test exists only to prove `Channels: Clone`, so the clone is
    // intentionally unused and clippy's redundant-clone lint is expected.
    #[allow(clippy::redundant_clone)]
    fn channels_is_clone() {
        let channels = Channels::new(16);
        let _cloned = channels.clone();
    }

    #[test]
    fn snapshot_returns_counts() {
        let channels = Channels::new(16);
        let _tx = channels.sender("empty");

        let _tx2 = channels.sender("one");
        let _rx_one = channels.subscribe("one");

        let _tx3 = channels.sender("two");
        let _rx_two_1 = channels.subscribe("two");
        let _rx_two_2 = channels.subscribe("two");

        let snap = channels.snapshot();
        assert_eq!(
            snap.get("empty").map(|stats| stats.subscriber_count),
            Some(0)
        );
        assert_eq!(snap.get("one").map(|stats| stats.subscriber_count), Some(1));
        assert_eq!(snap.get("two").map(|stats| stats.subscriber_count), Some(2));
        assert_eq!(snap.len(), 3);
    }

    #[cfg(all(feature = "ws", feature = "maud"))]
    #[tokio::test]
    async fn broadcast_publish_html_wraps_fragment_in_hx_swap_oob_envelope()
    -> Result<(), broadcast::error::RecvError> {
        let channels = Channels::new(16);
        let broadcast = Broadcast::new(channels.clone());
        let mut rx = channels.subscribe("feed");

        let sent = broadcast
            .publish_html(
                "feed",
                &maud::html! {
                    li id="item-1" { "one" }
                },
            )
            .expect("html publish should succeed");

        assert_eq!(sent, 1);
        let msg = rx.recv().await?;
        assert!(msg.as_str().contains("hx-swap-oob"));
        assert!(msg.as_str().contains("<li id=\"item-1\">one</li>"));
        Ok(())
    }

    #[cfg(feature = "ws")]
    #[tokio::test]
    async fn broadcast_publish_raw_bytes_delivers_text_payload()
    -> Result<(), broadcast::error::RecvError> {
        let channels = Channels::new(16);
        let broadcast = Broadcast::new(channels.clone());
        let mut rx = channels.subscribe("raw");

        let sent = broadcast
            .publish("raw", b"hello".as_slice())
            .expect("raw publish should succeed");

        assert_eq!(sent, 1);
        assert_eq!(rx.recv().await?.as_str(), "hello");
        Ok(())
    }

    #[cfg(feature = "ws")]
    #[test]
    fn broadcast_publish_rejects_invalid_utf8_bytes() {
        let channels = Channels::new(16);
        let broadcast = Broadcast::new(channels);

        let error = broadcast
            .publish("raw", vec![0xff, 0xfe])
            .expect_err("invalid UTF-8 should be rejected before publishing");

        assert!(matches!(error, BroadcastError::InvalidUtf8(_)));
    }

    #[cfg(feature = "ws")]
    #[tokio::test]
    async fn snapshot_returns_channel_metrics() -> Result<(), broadcast::error::RecvError> {
        let channels = Channels::new(16);
        let broadcast = Broadcast::new(channels.clone());
        let mut rx = channels.subscribe("metrics");

        broadcast
            .publish("metrics", "one")
            .expect("publish should succeed");
        let _ = rx.recv().await?;

        let snap = channels.snapshot();
        let stats = snap.get("metrics").expect("topic should be tracked");
        assert_eq!(stats.subscriber_count, 1);
        assert_eq!(stats.lifetime_publish_count, 1);
        assert_eq!(stats.dropped_count, 0);
        assert_eq!(stats.lagged_count, 0);
        Ok(())
    }

    #[cfg(feature = "ws")]
    #[test]
    fn snapshot_counts_dropped_publish_without_successful_delivery() {
        let channels = Channels::new(16);
        let sent = channels
            .broadcast()
            .publish("metrics", "one")
            .expect("publish with no subscribers should not fail");

        assert_eq!(sent, 0);
        let snap = channels.snapshot();
        let stats = snap.get("metrics").expect("topic should be tracked");
        assert_eq!(stats.subscriber_count, 0);
        assert_eq!(stats.lifetime_publish_count, 0);
        assert_eq!(stats.dropped_count, 1);
        assert_eq!(stats.lagged_count, 0);
    }

    #[test]
    fn gc_prunes_metrics_for_removed_idle_topics() {
        let channels = Channels::new(16);
        channels
            .publish("tenant:gone", "one")
            .expect("publish with no subscribers should only record a drop");

        let before_gc = channels.snapshot();
        assert!(before_gc.contains_key("tenant:gone"));

        channels.gc();

        let after_gc = channels.snapshot();
        assert!(!after_gc.contains_key("tenant:gone"));
        assert_eq!(channels.channel_count(), 0);
    }

    #[cfg(feature = "redis")]
    #[test]
    fn redis_listener_rejects_envelope_topic_that_mismatches_channel() {
        let local = LocalChannelsBackend::new(16);
        let mut private_rx = local.subscribe("private");
        let channel_prefix = redis_channel_name("autumn:channels", "");

        // The envelope claims topic "private" but arrived on the "public"
        // Redis channel; it must be dropped rather than fanned out.
        deliver_redis_envelope(
            &local,
            "local-origin",
            &channel_prefix,
            "autumn:channels:public",
            RedisEnvelope {
                origin: "remote-origin".to_owned(),
                topic: "private".to_owned(),
                payload: "secret".to_owned(),
            },
        );

        assert!(matches!(
            private_rx.try_recv(),
            Err(broadcast::error::TryRecvError::Empty)
        ));
        assert!(!local.snapshot().contains_key("public"));
    }

    #[cfg(feature = "redis")]
    #[test]
    fn redis_listener_counts_successful_remote_deliveries() {
        let local = LocalChannelsBackend::new(16);
        let mut rx = local.subscribe("public");
        let channel_prefix = redis_channel_name("autumn:channels", "");

        deliver_redis_envelope(
            &local,
            "local-origin",
            &channel_prefix,
            "autumn:channels:public",
            RedisEnvelope {
                origin: "remote-origin".to_owned(),
                topic: "public".to_owned(),
                payload: "hello".to_owned(),
            },
        );

        assert_eq!(
            rx.try_recv()
                .expect("remote message should fan out")
                .as_str(),
            "hello"
        );
        let snapshot = local.snapshot();
        let stats = snapshot.get("public").expect("topic should be tracked");
        assert_eq!(stats.lifetime_publish_count, 1);
        assert_eq!(stats.dropped_count, 0);
    }

    #[cfg(feature = "redis")]
    #[test]
    fn redis_publish_rejects_when_bounded_queue_is_full() {
        let local = LocalChannelsBackend::new(16);
        let mut rx = local.subscribe("queue");
        // Capacity-1 queue, pre-filled so the next publish cannot enqueue.
        let (publisher, _receiver) = tokio::sync::mpsc::channel(1);
        publisher
            .try_send(RedisPublishCommand {
                redis_channel: "autumn:channels:queue".to_owned(),
                envelope: RedisEnvelope {
                    origin: "origin".to_owned(),
                    topic: "queue".to_owned(),
                    payload: "already queued".to_owned(),
                },
            })
            .expect("first command should fill the queue");

        let backend = RedisChannelsBackend {
            local,
            publisher,
            origin_id: "origin".to_owned(),
            key_prefix: "autumn:channels".to_owned(),
        };

        let error = backend
            .publish("queue", ChannelMessage::from("second"))
            .expect_err("full Redis queue should reject the publish");

        assert_eq!(error, ChannelPublishError::QueueFull);
        assert!(matches!(
            rx.try_recv(),
            Err(broadcast::error::TryRecvError::Empty)
        ));
    }

    #[test]
    fn snapshot_releases_registry_before_waiting_on_metrics() {
        // Hold both locks, start a snapshot on another thread, then release
        // only the registry. A well-behaved snapshot takes the registry,
        // releases it, and only then blocks on metrics. That leaves the
        // registry lockable while metrics is still held by this thread.
        let backend = LocalChannelsBackend::new(16);
        backend.ensure_topic("race");

        let metrics_guard = backend
            .inner
            .metrics
            .counters
            .lock()
            .expect("channel metrics lock should not be poisoned");
        let registry_guard = backend
            .inner
            .registry
            .lock()
            .expect("channel registry lock should not be poisoned");
        let snapshot_backend = backend.clone();

        let handle = std::thread::spawn(move || {
            let snapshot = snapshot_backend.snapshot();
            assert!(snapshot.contains_key("race"));
        });

        std::thread::sleep(std::time::Duration::from_millis(25));
        drop(registry_guard);
        std::thread::sleep(std::time::Duration::from_millis(25));

        let deadline = std::time::Instant::now() + std::time::Duration::from_secs(1);
        let registry_released_before_metrics = loop {
            match backend.inner.registry.try_lock() {
                Ok(registry) => {
                    drop(registry);
                    break true;
                }
                Err(std::sync::TryLockError::WouldBlock)
                    if std::time::Instant::now() < deadline =>
                {
                    std::thread::yield_now();
                }
                Err(std::sync::TryLockError::WouldBlock) => break false,
                Err(std::sync::TryLockError::Poisoned(error)) => {
                    panic!("channel registry lock should not be poisoned: {error}");
                }
            }
        };

        // Release metrics before joining so the snapshot thread can finish
        // even when the assertion below is about to fail.
        drop(metrics_guard);
        handle.join().expect("snapshot thread should finish");
        assert!(
            registry_released_before_metrics,
            "snapshot held the registry mutex while waiting on metrics"
        );
    }

    #[cfg(feature = "ws")]
    #[tokio::test]
    async fn app_state_broadcast_uses_state_channels() -> Result<(), broadcast::error::RecvError> {
        let state = crate::AppState::for_test();
        let mut rx = state.channels().subscribe("state-topic");

        state
            .broadcast()
            .publish("state-topic", "from-state")
            .expect("publish should succeed");

        assert_eq!(rx.recv().await?.as_str(), "from-state");
        Ok(())
    }

    #[cfg(feature = "ws")]
    #[tokio::test]
    async fn subscribe_authorized_rejects_before_creating_subscriber() {
        let channels = Channels::new(16);

        let result: Result<Subscriber, &'static str> = channels
            .subscribe_authorized("private", |topic| async move {
                assert_eq!(topic, "private");
                Err("denied")
            })
            .await;

        assert!(matches!(result, Err("denied")));
        assert!(!channels.snapshot().contains_key("private"));
    }

    #[cfg(feature = "ws")]
    #[tokio::test]
    async fn subscribe_authorized_allows_after_hook_passes()
    -> Result<(), broadcast::error::RecvError> {
        let channels = Channels::new(16);
        let mut rx = channels
            .subscribe_authorized("private", |topic| async move {
                assert_eq!(topic, "private");
                Ok::<(), std::convert::Infallible>(())
            })
            .await
            .expect("authorization should pass");

        channels
            .broadcast()
            .publish("private", "secret")
            .expect("publish should succeed");

        assert_eq!(rx.recv().await?.as_str(), "secret");
        Ok(())
    }

    #[test]
    fn gc_removes_dead_channels() {
        let channels = Channels::new(16);
        let _tx = channels.sender("alive");
        {
            let _tx = channels.sender("dead");
        }
        assert_eq!(channels.channel_count(), 2);
        channels.gc();
        assert_eq!(channels.channel_count(), 1);
    }

    #[cfg(feature = "ws")]
    #[tokio::test]
    async fn subscriber_into_stream() {
        use tokio_stream::StreamExt;
        let channels = Channels::new(16);
        let tx = channels.sender("test_stream");
        let rx = channels.subscribe("test_stream");

        tx.send("message 1").unwrap();
        tx.send("message 2").unwrap();

        let mut stream = rx.into_stream();
        let msg1 = stream.next().await.unwrap();
        assert_eq!(msg1.as_str(), "message 1");

        let msg2 = stream.next().await.unwrap();
        assert_eq!(msg2.as_str(), "message 2");
    }

    #[cfg(feature = "ws")]
    #[tokio::test]
    async fn channels_sse_stream() {
        let channels = Channels::new(16);
        let tx = channels.sender("test_sse");
        assert_eq!(tx.receiver_count(), 0);

        let sse = channels.sse_stream("test_sse");
        // `sse_stream` subscribes eagerly, so the response must already hold
        // a live receiver before anything is published.
        assert_eq!(tx.receiver_count(), 1);

        tx.send("sse message").unwrap();
        drop(sse);
    }
}