Skip to main content

oxigdal_websocket/broadcast/
channel.rs

1//! Pub/sub channels for topic-based messaging
2
3use crate::error::{Error, Result};
4use crate::protocol::message::Message;
5use crate::server::connection::ConnectionId;
6use dashmap::DashMap;
7use std::sync::Arc;
8use std::sync::atomic::{AtomicU64, Ordering};
9use tokio::sync::broadcast;
10
/// Tunable limits used when a channel is created.
#[derive(Clone, Debug)]
pub struct ChannelConfig {
    /// Upper bound on the number of subscribers a single channel will register.
    pub max_subscribers: usize,
    /// Requested capacity of the broadcast channel's message buffer.
    pub buffer_size: usize,
}
19
20impl Default for ChannelConfig {
21    fn default() -> Self {
22        Self {
23            max_subscribers: 10_000,
24            buffer_size: 1000,
25        }
26    }
27}
28
29/// Generic channel interface
30pub trait Channel: Send + Sync {
31    /// Subscribe to the channel
32    fn subscribe(
33        &self,
34        subscriber: ConnectionId,
35    ) -> impl std::future::Future<Output = Result<()>> + Send;
36
37    /// Unsubscribe from the channel
38    fn unsubscribe(
39        &self,
40        subscriber: &ConnectionId,
41    ) -> impl std::future::Future<Output = Result<()>> + Send;
42
43    /// Publish a message to the channel
44    fn publish(&self, message: Message) -> impl std::future::Future<Output = Result<usize>> + Send;
45
46    /// Get subscriber count
47    fn subscriber_count(&self) -> impl std::future::Future<Output = usize> + Send;
48}
49
50/// Topic-based channel
51pub struct TopicChannel {
52    topic: String,
53    config: ChannelConfig,
54    subscribers: Arc<DashMap<ConnectionId, broadcast::Sender<Message>>>,
55    tx: broadcast::Sender<Message>,
56    stats: Arc<ChannelStatistics>,
57}
58
/// Live atomic counters describing a channel's message traffic.
struct ChannelStatistics {
    /// Running total of published messages.
    messages_published: AtomicU64,
    /// Running total of delivered messages.
    messages_delivered: AtomicU64,
    /// Running total of dropped messages.
    messages_dropped: AtomicU64,
}
65
66impl TopicChannel {
67    /// Create a new topic channel
68    pub fn new(topic: String, config: ChannelConfig) -> Self {
69        let (tx, _) = broadcast::channel(config.buffer_size);
70
71        Self {
72            topic,
73            config,
74            subscribers: Arc::new(DashMap::new()),
75            tx,
76            stats: Arc::new(ChannelStatistics {
77                messages_published: AtomicU64::new(0),
78                messages_delivered: AtomicU64::new(0),
79                messages_dropped: AtomicU64::new(0),
80            }),
81        }
82    }
83
84    /// Get topic name
85    pub fn topic(&self) -> &str {
86        &self.topic
87    }
88
89    /// Get statistics
90    pub async fn stats(&self) -> ChannelStats {
91        ChannelStats {
92            topic: self.topic.clone(),
93            subscriber_count: self.subscribers.len(),
94            messages_published: self.stats.messages_published.load(Ordering::Relaxed),
95            messages_delivered: self.stats.messages_delivered.load(Ordering::Relaxed),
96            messages_dropped: self.stats.messages_dropped.load(Ordering::Relaxed),
97        }
98    }
99}
100
101impl Channel for TopicChannel {
102    async fn subscribe(&self, subscriber: ConnectionId) -> Result<()> {
103        if self.subscribers.len() >= self.config.max_subscribers {
104            return Err(Error::ResourceExhausted(format!(
105                "Topic {} has reached maximum subscribers ({})",
106                self.topic, self.config.max_subscribers
107            )));
108        }
109
110        self.subscribers.insert(subscriber, self.tx.clone());
111        tracing::debug!("Subscriber {} joined topic {}", subscriber, self.topic);
112        Ok(())
113    }
114
115    async fn unsubscribe(&self, subscriber: &ConnectionId) -> Result<()> {
116        self.subscribers.remove(subscriber);
117        tracing::debug!("Subscriber {} left topic {}", subscriber, self.topic);
118        Ok(())
119    }
120
121    async fn publish(&self, message: Message) -> Result<usize> {
122        self.stats
123            .messages_published
124            .fetch_add(1, Ordering::Relaxed);
125
126        match self.tx.send(message) {
127            Ok(count) => {
128                self.stats
129                    .messages_delivered
130                    .fetch_add(count as u64, Ordering::Relaxed);
131                Ok(count)
132            }
133            Err(_) => {
134                self.stats.messages_dropped.fetch_add(1, Ordering::Relaxed);
135                Ok(0)
136            }
137        }
138    }
139
140    async fn subscriber_count(&self) -> usize {
141        self.subscribers.len()
142    }
143}
144
/// Point-in-time copy of a channel's counters.
#[derive(Clone, Debug)]
pub struct ChannelStats {
    /// Name of the topic the snapshot belongs to.
    pub topic: String,
    /// How many subscribers were registered.
    pub subscriber_count: usize,
    /// Total messages published.
    pub messages_published: u64,
    /// Total messages delivered.
    pub messages_delivered: u64,
    /// Total messages dropped.
    pub messages_dropped: u64,
}
159
160/// Multi-topic channel manager
161pub struct MultiChannelManager {
162    channels: Arc<DashMap<String, Arc<TopicChannel>>>,
163    default_config: ChannelConfig,
164}
165
166impl MultiChannelManager {
167    /// Create a new multi-channel manager
168    pub fn new(default_config: ChannelConfig) -> Self {
169        Self {
170            channels: Arc::new(DashMap::new()),
171            default_config,
172        }
173    }
174
175    /// Get or create a channel
176    pub fn get_or_create(&self, topic: &str) -> Arc<TopicChannel> {
177        self.channels
178            .entry(topic.to_string())
179            .or_insert_with(|| {
180                Arc::new(TopicChannel::new(
181                    topic.to_string(),
182                    self.default_config.clone(),
183                ))
184            })
185            .clone()
186    }
187
188    /// Subscribe to a topic
189    pub async fn subscribe(&self, topic: &str, subscriber: ConnectionId) -> Result<()> {
190        let channel = self.get_or_create(topic);
191        channel.subscribe(subscriber).await
192    }
193
194    /// Unsubscribe from a topic
195    pub async fn unsubscribe(&self, topic: &str, subscriber: &ConnectionId) -> Result<()> {
196        if let Some(channel) = self.channels.get(topic) {
197            channel.unsubscribe(subscriber).await?;
198        }
199        Ok(())
200    }
201
202    /// Publish to a topic
203    pub async fn publish(&self, topic: &str, message: Message) -> Result<usize> {
204        if let Some(channel) = self.channels.get(topic) {
205            channel.publish(message).await
206        } else {
207            Ok(0)
208        }
209    }
210
211    /// Get all topics
212    pub fn topics(&self) -> Vec<String> {
213        self.channels.iter().map(|r| r.key().clone()).collect()
214    }
215
216    /// Get channel count
217    pub fn channel_count(&self) -> usize {
218        self.channels.len()
219    }
220
221    /// Remove a channel
222    pub fn remove_channel(&self, topic: &str) -> Option<Arc<TopicChannel>> {
223        self.channels.remove(topic).map(|(_, v)| v)
224    }
225
226    /// Clear all channels
227    pub fn clear(&self) {
228        self.channels.clear();
229    }
230}
231
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly created channel reports its topic and has no subscribers.
    #[tokio::test]
    async fn test_topic_channel() {
        let channel = TopicChannel::new(String::from("test"), ChannelConfig::default());

        assert_eq!(channel.topic(), "test");
        assert_eq!(channel.subscriber_count().await, 0);
    }

    /// Subscribing one connection raises the subscriber count to one.
    #[tokio::test]
    async fn test_channel_subscribe() -> Result<()> {
        let channel = TopicChannel::new(String::from("test"), ChannelConfig::default());
        let id = ConnectionId::new_v4();

        channel.subscribe(id).await?;

        assert_eq!(channel.subscriber_count().await, 1);
        Ok(())
    }

    /// Subscribing and then unsubscribing leaves the channel empty.
    #[tokio::test]
    async fn test_channel_unsubscribe() -> Result<()> {
        let channel = TopicChannel::new(String::from("test"), ChannelConfig::default());
        let id = ConnectionId::new_v4();

        channel.subscribe(id).await?;
        channel.unsubscribe(&id).await?;

        assert_eq!(channel.subscriber_count().await, 0);
        Ok(())
    }

    /// A third distinct subscriber is rejected when the limit is two.
    #[tokio::test]
    async fn test_channel_max_subscribers() {
        let limits = ChannelConfig {
            max_subscribers: 2,
            buffer_size: 10,
        };
        let channel = TopicChannel::new(String::from("test"), limits);

        let first = ConnectionId::new_v4();
        let second = ConnectionId::new_v4();
        let third = ConnectionId::new_v4();

        assert!(channel.subscribe(first).await.is_ok());
        assert!(channel.subscribe(second).await.is_ok());
        assert!(channel.subscribe(third).await.is_err());
    }

    /// `get_or_create` registers the topic and returns its channel.
    #[tokio::test]
    async fn test_multi_channel_manager() {
        let manager = MultiChannelManager::new(ChannelConfig::default());
        assert_eq!(manager.channel_count(), 0);

        let created = manager.get_or_create("test");

        assert_eq!(manager.channel_count(), 1);
        assert_eq!(created.topic(), "test");
    }

    /// Publishing through the manager reaches the one live receiver.
    #[tokio::test]
    async fn test_multi_channel_publish() -> Result<()> {
        let manager = MultiChannelManager::new(ChannelConfig::default());

        let id = ConnectionId::new_v4();
        manager.subscribe("test", id).await?;

        // Hold a receiver for the whole test; without one the send has
        // nobody to deliver to.
        let channel = manager.get_or_create("test");
        let _rx = channel.tx.subscribe();

        let delivered = manager.publish("test", Message::ping()).await?;

        // Exactly one receiver is alive.
        assert_eq!(delivered, 1);
        Ok(())
    }
}