//! Topic-based broadcast channels built on `tokio::sync::broadcast`.
//!
//! Source path: `oxigdal_websocket/broadcast/channel.rs`.

use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};

use dashmap::DashMap;
use tokio::sync::broadcast;

use crate::error::{Error, Result};
use crate::protocol::message::Message;
use crate::server::connection::ConnectionId;

/// Limits applied to a single [`TopicChannel`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ChannelConfig {
    /// Maximum number of subscribers a topic accepts; subscribing a further
    /// connection once this many are registered fails with
    /// `Error::ResourceExhausted`.
    pub max_subscribers: usize,
    /// Capacity handed to `tokio::sync::broadcast::channel` when the topic's
    /// channel is created.
    pub buffer_size: usize,
}

impl Default for ChannelConfig {
    /// Returns a configuration allowing 10 000 subscribers with a
    /// 1 000-message broadcast buffer.
    fn default() -> Self {
        Self {
            max_subscribers: 10_000,
            buffer_size: 1000,
        }
    }
}

29pub trait Channel: Send + Sync {
31 fn subscribe(
33 &self,
34 subscriber: ConnectionId,
35 ) -> impl std::future::Future<Output = Result<()>> + Send;
36
37 fn unsubscribe(
39 &self,
40 subscriber: &ConnectionId,
41 ) -> impl std::future::Future<Output = Result<()>> + Send;
42
43 fn publish(&self, message: Message) -> impl std::future::Future<Output = Result<usize>> + Send;
45
46 fn subscriber_count(&self) -> impl std::future::Future<Output = usize> + Send;
48}
49
50pub struct TopicChannel {
52 topic: String,
53 config: ChannelConfig,
54 subscribers: Arc<DashMap<ConnectionId, broadcast::Sender<Message>>>,
55 tx: broadcast::Sender<Message>,
56 stats: Arc<ChannelStatistics>,
57}
58
/// Lock-free counters behind a topic's `ChannelStats` snapshot.
///
/// Every counter starts at zero (also available via `Default`). The counters are
/// independent tallies, not synchronization points.
#[derive(Debug, Default)]
struct ChannelStatistics {
    /// Number of publish attempts.
    messages_published: AtomicU64,
    /// Sum of the receiver counts reported by successful broadcast sends.
    messages_delivered: AtomicU64,
    /// Number of publishes whose broadcast send failed (no live receivers).
    messages_dropped: AtomicU64,
}

66impl TopicChannel {
67 pub fn new(topic: String, config: ChannelConfig) -> Self {
69 let (tx, _) = broadcast::channel(config.buffer_size);
70
71 Self {
72 topic,
73 config,
74 subscribers: Arc::new(DashMap::new()),
75 tx,
76 stats: Arc::new(ChannelStatistics {
77 messages_published: AtomicU64::new(0),
78 messages_delivered: AtomicU64::new(0),
79 messages_dropped: AtomicU64::new(0),
80 }),
81 }
82 }
83
84 pub fn topic(&self) -> &str {
86 &self.topic
87 }
88
89 pub async fn stats(&self) -> ChannelStats {
91 ChannelStats {
92 topic: self.topic.clone(),
93 subscriber_count: self.subscribers.len(),
94 messages_published: self.stats.messages_published.load(Ordering::Relaxed),
95 messages_delivered: self.stats.messages_delivered.load(Ordering::Relaxed),
96 messages_dropped: self.stats.messages_dropped.load(Ordering::Relaxed),
97 }
98 }
99}
100
101impl Channel for TopicChannel {
102 async fn subscribe(&self, subscriber: ConnectionId) -> Result<()> {
103 if self.subscribers.len() >= self.config.max_subscribers {
104 return Err(Error::ResourceExhausted(format!(
105 "Topic {} has reached maximum subscribers ({})",
106 self.topic, self.config.max_subscribers
107 )));
108 }
109
110 self.subscribers.insert(subscriber, self.tx.clone());
111 tracing::debug!("Subscriber {} joined topic {}", subscriber, self.topic);
112 Ok(())
113 }
114
115 async fn unsubscribe(&self, subscriber: &ConnectionId) -> Result<()> {
116 self.subscribers.remove(subscriber);
117 tracing::debug!("Subscriber {} left topic {}", subscriber, self.topic);
118 Ok(())
119 }
120
121 async fn publish(&self, message: Message) -> Result<usize> {
122 self.stats
123 .messages_published
124 .fetch_add(1, Ordering::Relaxed);
125
126 match self.tx.send(message) {
127 Ok(count) => {
128 self.stats
129 .messages_delivered
130 .fetch_add(count as u64, Ordering::Relaxed);
131 Ok(count)
132 }
133 Err(_) => {
134 self.stats.messages_dropped.fetch_add(1, Ordering::Relaxed);
135 Ok(0)
136 }
137 }
138 }
139
140 async fn subscriber_count(&self) -> usize {
141 self.subscribers.len()
142 }
143}
144
/// Point-in-time snapshot of one topic's subscriber count and message counters.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ChannelStats {
    /// Topic name.
    pub topic: String,
    /// Number of registered subscribers when the snapshot was taken.
    pub subscriber_count: usize,
    /// Total publish attempts.
    pub messages_published: u64,
    /// Sum of receiver counts across successful broadcast sends.
    pub messages_delivered: u64,
    /// Publishes that reached no receiver.
    pub messages_dropped: u64,
}

160pub struct MultiChannelManager {
162 channels: Arc<DashMap<String, Arc<TopicChannel>>>,
163 default_config: ChannelConfig,
164}
165
166impl MultiChannelManager {
167 pub fn new(default_config: ChannelConfig) -> Self {
169 Self {
170 channels: Arc::new(DashMap::new()),
171 default_config,
172 }
173 }
174
175 pub fn get_or_create(&self, topic: &str) -> Arc<TopicChannel> {
177 self.channels
178 .entry(topic.to_string())
179 .or_insert_with(|| {
180 Arc::new(TopicChannel::new(
181 topic.to_string(),
182 self.default_config.clone(),
183 ))
184 })
185 .clone()
186 }
187
188 pub async fn subscribe(&self, topic: &str, subscriber: ConnectionId) -> Result<()> {
190 let channel = self.get_or_create(topic);
191 channel.subscribe(subscriber).await
192 }
193
194 pub async fn unsubscribe(&self, topic: &str, subscriber: &ConnectionId) -> Result<()> {
196 if let Some(channel) = self.channels.get(topic) {
197 channel.unsubscribe(subscriber).await?;
198 }
199 Ok(())
200 }
201
202 pub async fn publish(&self, topic: &str, message: Message) -> Result<usize> {
204 if let Some(channel) = self.channels.get(topic) {
205 channel.publish(message).await
206 } else {
207 Ok(0)
208 }
209 }
210
211 pub fn topics(&self) -> Vec<String> {
213 self.channels.iter().map(|r| r.key().clone()).collect()
214 }
215
216 pub fn channel_count(&self) -> usize {
218 self.channels.len()
219 }
220
221 pub fn remove_channel(&self, topic: &str) -> Option<Arc<TopicChannel>> {
223 self.channels.remove(topic).map(|(_, v)| v)
224 }
225
226 pub fn clear(&self) {
228 self.channels.clear();
229 }
230}
231
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_topic_channel() {
        let config = ChannelConfig::default();
        let channel = TopicChannel::new("test".to_string(), config);

        assert_eq!(channel.topic(), "test");
        assert_eq!(channel.subscriber_count().await, 0);
    }

    #[tokio::test]
    async fn test_channel_subscribe() -> Result<()> {
        let config = ChannelConfig::default();
        let channel = TopicChannel::new("test".to_string(), config);

        let subscriber = ConnectionId::new_v4();
        channel.subscribe(subscriber).await?;

        assert_eq!(channel.subscriber_count().await, 1);
        Ok(())
    }

    #[tokio::test]
    async fn test_channel_unsubscribe() -> Result<()> {
        let config = ChannelConfig::default();
        let channel = TopicChannel::new("test".to_string(), config);

        let subscriber = ConnectionId::new_v4();
        channel.subscribe(subscriber).await?;
        channel.unsubscribe(&subscriber).await?;

        assert_eq!(channel.subscriber_count().await, 0);
        Ok(())
    }

    #[tokio::test]
    async fn test_channel_max_subscribers() {
        let config = ChannelConfig {
            max_subscribers: 2,
            buffer_size: 10,
        };
        let channel = TopicChannel::new("test".to_string(), config);

        let sub1 = ConnectionId::new_v4();
        let sub2 = ConnectionId::new_v4();
        let sub3 = ConnectionId::new_v4();

        assert!(channel.subscribe(sub1).await.is_ok());
        assert!(channel.subscribe(sub2).await.is_ok());
        assert!(channel.subscribe(sub3).await.is_err());
    }

    #[tokio::test]
    async fn test_channel_stats_track_delivery_and_drops() -> Result<()> {
        let channel = TopicChannel::new("stats".to_string(), ChannelConfig::default());

        // No broadcast receiver exists yet, so the send is counted as dropped.
        assert_eq!(channel.publish(Message::ping()).await?, 0);

        let _rx = channel.tx.subscribe();
        assert_eq!(channel.publish(Message::ping()).await?, 1);

        let stats = channel.stats().await;
        assert_eq!(stats.topic, "stats");
        assert_eq!(stats.messages_published, 2);
        assert_eq!(stats.messages_delivered, 1);
        assert_eq!(stats.messages_dropped, 1);
        Ok(())
    }

    #[tokio::test]
    async fn test_multi_channel_manager() {
        let config = ChannelConfig::default();
        let manager = MultiChannelManager::new(config);

        assert_eq!(manager.channel_count(), 0);

        let channel = manager.get_or_create("test");
        assert_eq!(manager.channel_count(), 1);
        assert_eq!(channel.topic(), "test");
    }

    #[tokio::test]
    async fn test_multi_channel_publish() -> Result<()> {
        let config = ChannelConfig::default();
        let manager = MultiChannelManager::new(config);

        let subscriber = ConnectionId::new_v4();
        manager.subscribe("test", subscriber).await?;

        let channel = manager.get_or_create("test");
        // Hold a broadcast receiver so the send below has a live recipient;
        // `subscribe` alone registers a sender clone, not a receiver.
        let _rx = channel.tx.subscribe();

        let message = Message::ping();
        let count = manager.publish("test", message).await?;

        assert_eq!(count, 1);
        Ok(())
    }

    #[tokio::test]
    async fn test_unknown_topic_publish_and_unsubscribe() -> Result<()> {
        let manager = MultiChannelManager::new(ChannelConfig::default());

        assert_eq!(manager.publish("missing", Message::ping()).await?, 0);
        manager
            .unsubscribe("missing", &ConnectionId::new_v4())
            .await?;

        // Neither call creates the topic.
        assert_eq!(manager.channel_count(), 0);
        Ok(())
    }

    #[tokio::test]
    async fn test_topics_remove_and_clear() {
        let manager = MultiChannelManager::new(ChannelConfig::default());

        let created = manager.get_or_create("a");
        let again = manager.get_or_create("a");
        assert!(Arc::ptr_eq(&created, &again));
        let _b = manager.get_or_create("b");

        let mut topics = manager.topics();
        topics.sort();
        assert_eq!(topics, vec!["a".to_string(), "b".to_string()]);

        let removed = manager
            .remove_channel("a")
            .expect("topic \"a\" was just created");
        assert!(Arc::ptr_eq(&created, &removed));
        assert!(manager.remove_channel("a").is_none());
        assert_eq!(manager.channel_count(), 1);

        manager.clear();
        assert_eq!(manager.channel_count(), 0);
    }
}