1use std::collections::HashMap;
7use std::sync::Arc;
8use tokio::sync::{RwLock, mpsc};
9use tokio::time::Duration;
10use chrono::{DateTime, Utc};
11use tracing::{debug, warn};
12
13use crate::{KVResult, Value, PubSubMessage};
14
/// Key under which subscribers register: one concrete channel name or a
/// glob-style pattern.
///
/// In a [`ChannelPattern::Wildcard`], every `*` stands for any run of
/// characters (including none); every other character must match literally.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ChannelPattern {
    /// Matches only the channel whose name equals the contained string.
    Exact(String),
    /// Matches channels against a pattern that may contain `*` wildcards.
    Wildcard(String),
}

impl ChannelPattern {
    /// Builds a pattern that matches exactly `channel`.
    #[must_use]
    pub fn exact(channel: String) -> Self {
        Self::Exact(channel)
    }

    /// Builds a wildcard pattern from `pattern` (see [`ChannelPattern::matches`]).
    #[must_use]
    pub fn wildcard(pattern: String) -> Self {
        Self::Wildcard(pattern)
    }

    /// Returns `true` when `channel` is selected by this pattern.
    ///
    /// * `Exact` compares the two strings for equality.
    /// * `Wildcard` without any `*` also compares for equality.
    /// * `Wildcard` with one or more `*` requires the text before the first
    ///   `*` to be a prefix of `channel`, the text after the last `*` to be a
    ///   suffix of what remains, and every segment between two `*` to occur,
    ///   in order and without overlapping, in between.
    #[must_use]
    pub fn matches(&self, channel: &str) -> bool {
        match self {
            Self::Exact(exact) => exact == channel,
            Self::Wildcard(pattern) => Self::glob_matches(pattern, channel),
        }
    }

    /// Matches `channel` against `pattern`, where `*` is the only metacharacter.
    fn glob_matches(pattern: &str, channel: &str) -> bool {
        // No `*` at all: the pattern is a plain channel name.
        let (prefix, rest) = match pattern.split_once('*') {
            Some(split) => split,
            None => return pattern == channel,
        };

        // Whatever follows the last `*` is anchored to the end of the channel;
        // anything between the first and last `*` floats in the middle.
        let (middle, suffix) = match rest.rsplit_once('*') {
            Some((middle, suffix)) => (Some(middle), suffix),
            None => (None, rest),
        };

        // Strip the prefix first and the suffix from the remainder, so the two
        // anchors can never claim the same characters of `channel`.
        let after_prefix = match channel.strip_prefix(prefix) {
            Some(remainder) => remainder,
            None => return false,
        };
        let mut window = match after_prefix.strip_suffix(suffix) {
            Some(remainder) => remainder,
            None => return false,
        };

        // Taking the leftmost occurrence of each middle segment leaves the most
        // room for the following ones, which is sufficient for `*`-only globs.
        if let Some(middle) = middle {
            for segment in middle.split('*') {
                match window.find(segment) {
                    Some(start) => window = &window[start + segment.len()..],
                    None => return false,
                }
            }
        }

        true
    }
}
60
61#[derive(Debug)]
63struct Subscription {
64 pattern: ChannelPattern,
65 sender: mpsc::UnboundedSender<PubSubMessage>,
66 created_at: DateTime<Utc>,
67 last_activity: DateTime<Utc>,
68}
69
70pub struct PubSubManager {
72 subscriptions: Arc<RwLock<HashMap<ChannelPattern, Vec<Subscription>>>>,
74 cleanup_interval: Duration,
76 subscription_timeout: Duration,
78 cleanup_handle: Option<tokio::task::JoinHandle<()>>,
80}
81
82impl PubSubManager {
83 #[must_use]
85 pub fn new(cleanup_interval: Duration, subscription_timeout: Duration) -> Self {
86 Self {
87 subscriptions: Arc::new(RwLock::new(HashMap::new())),
88 cleanup_interval,
89 subscription_timeout,
90 cleanup_handle: None,
91 }
92 }
93
94 pub fn start_cleanup(&mut self) {
96 let subscriptions = Arc::clone(&self.subscriptions);
97 let cleanup_interval = self.cleanup_interval;
98 let subscription_timeout = self.subscription_timeout;
99
100 let handle = tokio::spawn(async move {
101 let mut interval = tokio::time::interval(cleanup_interval);
102
103 loop {
104 interval.tick().await;
105
106 let now = Utc::now();
107 let mut subs = subscriptions.write().await;
108
109 for (pattern, subs_list) in subs.iter_mut() {
111 subs_list.retain(|sub| {
112 let is_active = (now - sub.last_activity).to_std()
113 .map(|d| d < subscription_timeout)
114 .unwrap_or(false);
115
116 if !is_active {
117 debug!("Removing inactive subscription for pattern: {:?}", pattern);
118 }
119
120 is_active
121 });
122 }
123
124 subs.retain(|_, subs_list| !subs_list.is_empty());
126 }
127 });
128
129 self.cleanup_handle = Some(handle);
130 }
131
132 pub fn stop_cleanup(&mut self) {
134 if let Some(handle) = self.cleanup_handle.take() {
135 handle.abort();
136 }
137 }
138
139 pub async fn subscribe(&self, pattern: ChannelPattern) -> KVResult<mpsc::UnboundedReceiver<PubSubMessage>> {
144 let (sender, receiver) = mpsc::unbounded_channel();
145 let pattern_clone = pattern.clone();
146
147 let subscription = Subscription {
148 pattern: pattern.clone(),
149 sender,
150 created_at: Utc::now(),
151 last_activity: Utc::now(),
152 };
153
154 let mut subscriptions = self.subscriptions.write().await;
155 subscriptions.entry(pattern).or_insert_with(Vec::new).push(subscription);
156
157 debug!("New subscription created for pattern: {:?}", pattern_clone);
158 Ok(receiver)
159 }
160
161 pub async fn unsubscribe(&self, pattern: &ChannelPattern) -> KVResult<usize> {
166 let mut subscriptions = self.subscriptions.write().await;
167
168 if let Some(subs_list) = subscriptions.get_mut(pattern) {
169 let count = subs_list.len();
170 subs_list.clear();
171
172 if subs_list.is_empty() {
173 subscriptions.remove(pattern);
174 }
175
176 debug!("Unsubscribed {} subscribers from pattern: {:?}", count, pattern);
177 Ok(count)
178 } else {
179 Ok(0)
180 }
181 }
182
183 pub async fn publish(&self, channel: &str, message: Value) -> KVResult<usize> {
188 let pubsub_message = PubSubMessage {
189 channel: channel.to_string(),
190 message,
191 timestamp: Utc::now(),
192 };
193
194 let subscriptions = self.subscriptions.read().await;
195 let mut delivered_count = 0;
196 let mut failed_deliveries = Vec::new();
197
198 for (pattern, subs_list) in subscriptions.iter() {
200 if pattern.matches(channel) {
201 for (index, subscription) in subs_list.iter().enumerate() {
202 if let Err(e) = subscription.sender.send(pubsub_message.clone()) {
203 warn!("Failed to deliver message to subscriber: {}", e);
204 failed_deliveries.push((pattern.clone(), index));
205 } else {
206 delivered_count += 1;
207 }
208 }
209 }
210 }
211
212 if !failed_deliveries.is_empty() {
214 drop(subscriptions);
215 let mut subs = self.subscriptions.write().await;
216
217 for (pattern, index) in failed_deliveries {
218 if let Some(subs_list) = subs.get_mut(&pattern) {
219 if index < subs_list.len() {
220 subs_list.remove(index);
221 }
222 if subs_list.is_empty() {
223 subs.remove(&pattern);
224 }
225 }
226 }
227 }
228
229 debug!("Published message to channel '{}', delivered to {} subscribers", channel, delivered_count);
230 Ok(delivered_count)
231 }
232
233 pub async fn get_stats(&self) -> KVResult<PubSubStats> {
238 let subscriptions = self.subscriptions.read().await;
239
240 let mut total_subscriptions = 0;
241 let mut pattern_count = 0;
242 let mut exact_patterns = 0;
243 let mut wildcard_patterns = 0;
244
245 for (pattern, subs_list) in subscriptions.iter() {
246 pattern_count += 1;
247 total_subscriptions += subs_list.len();
248
249 match pattern {
250 ChannelPattern::Exact(_) => exact_patterns += 1,
251 ChannelPattern::Wildcard(_) => wildcard_patterns += 1,
252 }
253 }
254
255 Ok(PubSubStats {
256 total_subscriptions,
257 pattern_count,
258 exact_patterns,
259 wildcard_patterns,
260 })
261 }
262
263 pub async fn get_active_patterns(&self) -> KVResult<Vec<ChannelPattern>> {
268 let subscriptions = self.subscriptions.read().await;
269 Ok(subscriptions.keys().cloned().collect())
270 }
271}
272
/// Point-in-time counters describing the subscription registry.
#[derive(Debug, Clone)]
pub struct PubSubStats {
    /// Number of subscribers summed over all patterns.
    pub total_subscriptions: usize,
    /// Number of distinct patterns that have an entry.
    pub pattern_count: usize,
    /// How many of those patterns are exact channel names.
    pub exact_patterns: usize,
    /// How many of those patterns are wildcard patterns.
    pub wildcard_patterns: usize,
}
281
282impl Default for PubSubManager {
283 fn default() -> Self {
284 Self::new(
285 Duration::from_secs(300), Duration::from_secs(3600), )
288 }
289}
290
291impl Drop for PubSubManager {
292 fn drop(&mut self) {
293 self.stop_cleanup();
294 }
295}
296
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::timeout;

    /// Builds a manager with a 100 ms cleanup interval and a 1 s timeout.
    async fn create_test_manager() -> PubSubManager {
        let cleanup_interval = Duration::from_millis(100);
        let subscription_timeout = Duration::from_secs(1);
        PubSubManager::new(cleanup_interval, subscription_timeout)
    }

    /// Waits up to 100 ms for the next message; panics on timeout or if the
    /// channel is closed.
    async fn next_message(
        receiver: &mut mpsc::UnboundedReceiver<PubSubMessage>,
    ) -> PubSubMessage {
        timeout(Duration::from_millis(100), receiver.recv())
            .await
            .unwrap()
            .unwrap()
    }

    /// An exact subscription receives a message published on its channel.
    #[tokio::test]
    async fn test_exact_channel_subscription() {
        let manager = create_test_manager().await;
        let mut inbox = manager
            .subscribe(ChannelPattern::exact("test:channel".to_string()))
            .await
            .unwrap();

        let payload = Value::String("Hello, World!".to_string());
        let delivered = manager.publish("test:channel", payload.clone()).await.unwrap();
        assert_eq!(delivered, 1);

        let got = next_message(&mut inbox).await;
        assert_eq!(got.channel, "test:channel");
        assert_eq!(got.message, payload);
    }

    /// A prefix wildcard receives matching channels and ignores others.
    #[tokio::test]
    async fn test_wildcard_channel_subscription() {
        let manager = create_test_manager().await;
        let mut inbox = manager
            .subscribe(ChannelPattern::wildcard("cache:*".to_string()))
            .await
            .unwrap();

        let payload = Value::String("Invalidate user123".to_string());
        let delivered = manager
            .publish("cache:invalidate:user123", payload.clone())
            .await
            .unwrap();
        assert_eq!(delivered, 1);

        let got = next_message(&mut inbox).await;
        assert_eq!(got.channel, "cache:invalidate:user123");
        assert_eq!(got.message, payload);

        // A channel outside the prefix reaches nobody.
        let delivered = manager
            .publish("other:channel", Value::String("test".to_string()))
            .await
            .unwrap();
        assert_eq!(delivered, 0);
    }

    /// Two subscribers on the same pattern each get their own copy.
    #[tokio::test]
    async fn test_multiple_subscribers() {
        let manager = create_test_manager().await;

        let pattern = ChannelPattern::exact("broadcast".to_string());
        let mut first_inbox = manager.subscribe(pattern.clone()).await.unwrap();
        let mut second_inbox = manager.subscribe(pattern).await.unwrap();

        let payload = Value::String("Broadcast message".to_string());
        let delivered = manager.publish("broadcast", payload.clone()).await.unwrap();
        assert_eq!(delivered, 2);

        let first = next_message(&mut first_inbox).await;
        let second = next_message(&mut second_inbox).await;

        assert_eq!(first.message, payload);
        assert_eq!(second.message, payload);
    }

    /// After unsubscribing, publishing on the channel reaches nobody.
    #[tokio::test]
    async fn test_unsubscribe() {
        let manager = create_test_manager().await;

        let pattern = ChannelPattern::exact("test:unsub".to_string());
        // Keep the receiver alive so the first publish can succeed.
        let _inbox = manager.subscribe(pattern.clone()).await.unwrap();

        let before = manager
            .publish("test:unsub", Value::String("test".to_string()))
            .await
            .unwrap();
        assert_eq!(before, 1);

        let removed = manager.unsubscribe(&pattern).await.unwrap();
        assert_eq!(removed, 1);

        let after = manager
            .publish("test:unsub", Value::String("test2".to_string()))
            .await
            .unwrap();
        assert_eq!(after, 0);
    }

    /// Stats start at zero and count subscriptions per pattern kind.
    #[tokio::test]
    async fn test_stats() {
        let manager = create_test_manager().await;

        let empty = manager.get_stats().await.unwrap();
        assert_eq!(empty.total_subscriptions, 0);
        assert_eq!(empty.pattern_count, 0);

        let _inbox_a = manager
            .subscribe(ChannelPattern::exact("exact1".to_string()))
            .await
            .unwrap();
        let _inbox_b = manager
            .subscribe(ChannelPattern::exact("exact2".to_string()))
            .await
            .unwrap();
        let _inbox_c = manager
            .subscribe(ChannelPattern::wildcard("wild:*".to_string()))
            .await
            .unwrap();

        let filled = manager.get_stats().await.unwrap();
        assert_eq!(filled.total_subscriptions, 3);
        assert_eq!(filled.pattern_count, 3);
        assert_eq!(filled.exact_patterns, 2);
        assert_eq!(filled.wildcard_patterns, 1);
    }

    /// Direct checks of `ChannelPattern::matches` for exact and prefix patterns.
    #[tokio::test]
    async fn test_pattern_matching() {
        let exact = ChannelPattern::exact("test:channel".to_string());
        assert!(exact.matches("test:channel"));
        assert!(!exact.matches("test:other"));

        let cache = ChannelPattern::wildcard("cache:*".to_string());
        assert!(cache.matches("cache:invalidate"));
        assert!(cache.matches("cache:invalidate:user123"));
        assert!(!cache.matches("other:invalidate"));

        let auth = ChannelPattern::wildcard("auth:*".to_string());
        assert!(auth.matches("auth:login"));
        assert!(auth.matches("auth:logout"));
        assert!(!auth.matches("cache:login"));
    }
}