kv_core/
pubsub.rs

1//! Pub/Sub system for the KV service
2//! 
3//! Provides channel-based publish/subscribe functionality with pattern matching
4//! and thread-safe message broadcasting for cache invalidation events.
5
6use std::collections::HashMap;
7use std::sync::Arc;
8use tokio::sync::{RwLock, mpsc};
9use tokio::time::Duration;
10use chrono::{DateTime, Utc};
11use tracing::{debug, warn};
12
13use crate::{KVResult, Value, PubSubMessage};
14
/// Pattern matching for channel subscriptions.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ChannelPattern {
    /// Exact channel match: the channel name must equal the stored string.
    Exact(String),
    /// Glob-style pattern in which `*` matches any (possibly empty) sequence of
    /// characters, including `:` separators (e.g., `"cache:*"` matches
    /// `"cache:invalidate:user123"`). Every other character matches literally;
    /// there is no escape for a literal `*`. A pattern without `*` behaves like
    /// an exact match.
    Wildcard(String),
}

impl ChannelPattern {
    /// Create a new exact channel pattern.
    #[must_use]
    pub fn exact(channel: String) -> Self {
        Self::Exact(channel)
    }

    /// Create a new wildcard pattern.
    #[must_use]
    pub fn wildcard(pattern: String) -> Self {
        Self::Wildcard(pattern)
    }

    /// Check if a channel matches this pattern.
    #[must_use]
    pub fn matches(&self, channel: &str) -> bool {
        match self {
            Self::Exact(exact) => exact == channel,
            Self::Wildcard(pattern) => Self::glob_matches(pattern, channel),
        }
    }

    /// Match `channel` against a glob where `*` stands for any character sequence.
    fn glob_matches(pattern: &str, channel: &str) -> bool {
        // Peel off the text after the last `*`: it must end the channel.
        let Some((head, suffix)) = pattern.rsplit_once('*') else {
            // No `*` at all: the whole pattern is a literal.
            return pattern == channel;
        };

        // `head` is `prefix*seg1*seg2...`; `split` always yields the prefix first.
        let mut segments = head.split('*');
        let prefix = segments.next().unwrap_or_default();
        let Some(mut rest) = channel.strip_prefix(prefix) else {
            return false;
        };

        // Consume the leftmost occurrence of each middle segment. For globs that
        // only contain `*`, the leftmost choice leaves the longest remainder and
        // therefore never rules out a valid match.
        for segment in segments {
            let Some(idx) = rest.find(segment) else {
                return false;
            };
            rest = &rest[idx + segment.len()..];
        }

        // Testing only the unconsumed remainder stops the suffix from overlapping
        // text already matched by the prefix (e.g. `ab*ba` must not match `aba`).
        rest.ends_with(suffix)
    }
}
60
61/// Subscription information
62#[derive(Debug)]
63struct Subscription {
64    pattern: ChannelPattern,
65    sender: mpsc::UnboundedSender<PubSubMessage>,
66    created_at: DateTime<Utc>,
67    last_activity: DateTime<Utc>,
68}
69
70/// Pub/Sub manager for handling channel subscriptions and message broadcasting
71pub struct PubSubManager {
72    /// Map of channel patterns to subscribers
73    subscriptions: Arc<RwLock<HashMap<ChannelPattern, Vec<Subscription>>>>,
74    /// Cleanup interval for inactive subscriptions
75    cleanup_interval: Duration,
76    /// Subscription timeout
77    subscription_timeout: Duration,
78    /// Background cleanup task handle
79    cleanup_handle: Option<tokio::task::JoinHandle<()>>,
80}
81
82impl PubSubManager {
83    /// Create a new pub/sub manager
84    #[must_use]
85    pub fn new(cleanup_interval: Duration, subscription_timeout: Duration) -> Self {
86        Self {
87            subscriptions: Arc::new(RwLock::new(HashMap::new())),
88            cleanup_interval,
89            subscription_timeout,
90            cleanup_handle: None,
91        }
92    }
93
94    /// Start the background cleanup task
95    pub fn start_cleanup(&mut self) {
96        let subscriptions = Arc::clone(&self.subscriptions);
97        let cleanup_interval = self.cleanup_interval;
98        let subscription_timeout = self.subscription_timeout;
99
100        let handle = tokio::spawn(async move {
101            let mut interval = tokio::time::interval(cleanup_interval);
102            
103            loop {
104                interval.tick().await;
105                
106                let now = Utc::now();
107                let mut subs = subscriptions.write().await;
108                
109                // Remove inactive subscriptions
110                for (pattern, subs_list) in subs.iter_mut() {
111                    subs_list.retain(|sub| {
112                        let is_active = (now - sub.last_activity).to_std()
113                            .map(|d| d < subscription_timeout)
114                            .unwrap_or(false);
115                        
116                        if !is_active {
117                            debug!("Removing inactive subscription for pattern: {:?}", pattern);
118                        }
119                        
120                        is_active
121                    });
122                }
123                
124                // Remove empty pattern entries
125                subs.retain(|_, subs_list| !subs_list.is_empty());
126            }
127        });
128
129        self.cleanup_handle = Some(handle);
130    }
131
132    /// Stop the background cleanup task
133    pub fn stop_cleanup(&mut self) {
134        if let Some(handle) = self.cleanup_handle.take() {
135            handle.abort();
136        }
137    }
138
139    /// Subscribe to a channel pattern
140    /// 
141    /// # Errors
142    /// Returns error if subscription setup fails
143    pub async fn subscribe(&self, pattern: ChannelPattern) -> KVResult<mpsc::UnboundedReceiver<PubSubMessage>> {
144        let (sender, receiver) = mpsc::unbounded_channel();
145        let pattern_clone = pattern.clone();
146        
147        let subscription = Subscription {
148            pattern: pattern.clone(),
149            sender,
150            created_at: Utc::now(),
151            last_activity: Utc::now(),
152        };
153
154        let mut subscriptions = self.subscriptions.write().await;
155        subscriptions.entry(pattern).or_insert_with(Vec::new).push(subscription);
156        
157        debug!("New subscription created for pattern: {:?}", pattern_clone);
158        Ok(receiver)
159    }
160
161    /// Unsubscribe from a channel pattern
162    /// 
163    /// # Errors
164    /// Returns error if unsubscription fails
165    pub async fn unsubscribe(&self, pattern: &ChannelPattern) -> KVResult<usize> {
166        let mut subscriptions = self.subscriptions.write().await;
167        
168        if let Some(subs_list) = subscriptions.get_mut(pattern) {
169            let count = subs_list.len();
170            subs_list.clear();
171            
172            if subs_list.is_empty() {
173                subscriptions.remove(pattern);
174            }
175            
176            debug!("Unsubscribed {} subscribers from pattern: {:?}", count, pattern);
177            Ok(count)
178        } else {
179            Ok(0)
180        }
181    }
182
183    /// Publish a message to a channel
184    /// 
185    /// # Errors
186    /// Returns error if publishing fails
187    pub async fn publish(&self, channel: &str, message: Value) -> KVResult<usize> {
188        let pubsub_message = PubSubMessage {
189            channel: channel.to_string(),
190            message,
191            timestamp: Utc::now(),
192        };
193
194        let subscriptions = self.subscriptions.read().await;
195        let mut delivered_count = 0;
196        let mut failed_deliveries = Vec::new();
197
198        // Find all matching subscriptions
199        for (pattern, subs_list) in subscriptions.iter() {
200            if pattern.matches(channel) {
201                for (index, subscription) in subs_list.iter().enumerate() {
202                    if let Err(e) = subscription.sender.send(pubsub_message.clone()) {
203                        warn!("Failed to deliver message to subscriber: {}", e);
204                        failed_deliveries.push((pattern.clone(), index));
205                    } else {
206                        delivered_count += 1;
207                    }
208                }
209            }
210        }
211
212        // Clean up failed deliveries
213        if !failed_deliveries.is_empty() {
214            drop(subscriptions);
215            let mut subs = self.subscriptions.write().await;
216            
217            for (pattern, index) in failed_deliveries {
218                if let Some(subs_list) = subs.get_mut(&pattern) {
219                    if index < subs_list.len() {
220                        subs_list.remove(index);
221                    }
222                    if subs_list.is_empty() {
223                        subs.remove(&pattern);
224                    }
225                }
226            }
227        }
228
229        debug!("Published message to channel '{}', delivered to {} subscribers", channel, delivered_count);
230        Ok(delivered_count)
231    }
232
233    /// Get statistics about subscriptions
234    /// 
235    /// # Errors
236    /// Returns error if stats calculation fails
237    pub async fn get_stats(&self) -> KVResult<PubSubStats> {
238        let subscriptions = self.subscriptions.read().await;
239        
240        let mut total_subscriptions = 0;
241        let mut pattern_count = 0;
242        let mut exact_patterns = 0;
243        let mut wildcard_patterns = 0;
244
245        for (pattern, subs_list) in subscriptions.iter() {
246            pattern_count += 1;
247            total_subscriptions += subs_list.len();
248            
249            match pattern {
250                ChannelPattern::Exact(_) => exact_patterns += 1,
251                ChannelPattern::Wildcard(_) => wildcard_patterns += 1,
252            }
253        }
254
255        Ok(PubSubStats {
256            total_subscriptions,
257            pattern_count,
258            exact_patterns,
259            wildcard_patterns,
260        })
261    }
262
263    /// Get all active channel patterns
264    /// 
265    /// # Errors
266    /// Returns error if pattern retrieval fails
267    pub async fn get_active_patterns(&self) -> KVResult<Vec<ChannelPattern>> {
268        let subscriptions = self.subscriptions.read().await;
269        Ok(subscriptions.keys().cloned().collect())
270    }
271}
272
/// Point-in-time statistics for the pub/sub system, as reported by
/// `PubSubManager::get_stats`.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct PubSubStats {
    /// Total number of subscriptions across all patterns.
    pub total_subscriptions: usize,
    /// Number of distinct patterns currently registered.
    pub pattern_count: usize,
    /// Number of registered patterns of the `Exact` kind.
    pub exact_patterns: usize,
    /// Number of registered patterns of the `Wildcard` kind.
    pub wildcard_patterns: usize,
}
281
282impl Default for PubSubManager {
283    fn default() -> Self {
284        Self::new(
285            Duration::from_secs(300), // 5 minutes cleanup interval
286            Duration::from_secs(3600), // 1 hour subscription timeout
287        )
288    }
289}
290
291impl Drop for PubSubManager {
292    fn drop(&mut self) {
293        self.stop_cleanup();
294    }
295}
296
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::timeout;

    /// Manager with short intervals. These tests never start its cleanup task.
    fn test_manager() -> PubSubManager {
        PubSubManager::new(Duration::from_millis(100), Duration::from_secs(1))
    }

    /// Receive the next message. Panics if none arrives within 100 ms or if the
    /// channel is closed.
    async fn recv_within_deadline(
        receiver: &mut mpsc::UnboundedReceiver<PubSubMessage>,
    ) -> PubSubMessage {
        timeout(Duration::from_millis(100), receiver.recv())
            .await
            .expect("timed out waiting for a pub/sub message")
            .expect("subscription channel closed unexpectedly")
    }

    #[tokio::test]
    async fn test_exact_channel_subscription() {
        let manager = test_manager();
        let mut rx = manager
            .subscribe(ChannelPattern::exact("test:channel".to_string()))
            .await
            .unwrap();

        let payload = Value::String("Hello, World!".to_string());
        let delivered = manager.publish("test:channel", payload.clone()).await.unwrap();
        assert_eq!(delivered, 1);

        let msg = recv_within_deadline(&mut rx).await;
        assert_eq!(msg.channel, "test:channel");
        assert_eq!(msg.message, payload);
    }

    #[tokio::test]
    async fn test_wildcard_channel_subscription() {
        let manager = test_manager();
        let mut rx = manager
            .subscribe(ChannelPattern::wildcard("cache:*".to_string()))
            .await
            .unwrap();

        // A matching channel reaches the subscriber.
        let payload = Value::String("Invalidate user123".to_string());
        let delivered = manager
            .publish("cache:invalidate:user123", payload.clone())
            .await
            .unwrap();
        assert_eq!(delivered, 1);

        let msg = recv_within_deadline(&mut rx).await;
        assert_eq!(msg.channel, "cache:invalidate:user123");
        assert_eq!(msg.message, payload);

        // A non-matching channel reaches nobody.
        let delivered = manager
            .publish("other:channel", Value::String("test".to_string()))
            .await
            .unwrap();
        assert_eq!(delivered, 0);
    }

    #[tokio::test]
    async fn test_multiple_subscribers() {
        let manager = test_manager();
        let pattern = ChannelPattern::exact("broadcast".to_string());
        let mut first = manager.subscribe(pattern.clone()).await.unwrap();
        let mut second = manager.subscribe(pattern).await.unwrap();

        let payload = Value::String("Broadcast message".to_string());
        let delivered = manager.publish("broadcast", payload.clone()).await.unwrap();
        assert_eq!(delivered, 2);

        // Every subscriber gets its own copy.
        for rx in [&mut first, &mut second] {
            assert_eq!(recv_within_deadline(rx).await.message, payload);
        }
    }

    #[tokio::test]
    async fn test_unsubscribe() {
        let manager = test_manager();
        let pattern = ChannelPattern::exact("test:unsub".to_string());
        // Keep the receiver alive so the first publish is delivered.
        let _rx = manager.subscribe(pattern.clone()).await.unwrap();

        let before = manager
            .publish("test:unsub", Value::String("test".to_string()))
            .await
            .unwrap();
        assert_eq!(before, 1);

        let removed = manager.unsubscribe(&pattern).await.unwrap();
        assert_eq!(removed, 1);

        let after = manager
            .publish("test:unsub", Value::String("test2".to_string()))
            .await
            .unwrap();
        assert_eq!(after, 0);
    }

    #[tokio::test]
    async fn test_stats() {
        let manager = test_manager();

        let empty = manager.get_stats().await.unwrap();
        assert_eq!(empty.total_subscriptions, 0);
        assert_eq!(empty.pattern_count, 0);

        // Hold on to the receivers so the subscriptions stay registered.
        let _receivers = [
            manager.subscribe(ChannelPattern::exact("exact1".to_string())).await.unwrap(),
            manager.subscribe(ChannelPattern::exact("exact2".to_string())).await.unwrap(),
            manager.subscribe(ChannelPattern::wildcard("wild:*".to_string())).await.unwrap(),
        ];

        let stats = manager.get_stats().await.unwrap();
        assert_eq!(stats.total_subscriptions, 3);
        assert_eq!(stats.pattern_count, 3);
        assert_eq!(stats.exact_patterns, 2);
        assert_eq!(stats.wildcard_patterns, 1);
    }

    #[test]
    fn test_pattern_matching() {
        let exact = ChannelPattern::exact("test:channel".to_string());
        assert!(exact.matches("test:channel"));
        assert!(!exact.matches("test:other"));

        let cache = ChannelPattern::wildcard("cache:*".to_string());
        for channel in ["cache:invalidate", "cache:invalidate:user123"] {
            assert!(cache.matches(channel), "{} should match cache:*", channel);
        }
        assert!(!cache.matches("other:invalidate"));

        let auth = ChannelPattern::wildcard("auth:*".to_string());
        assert!(auth.matches("auth:login"));
        assert!(auth.matches("auth:logout"));
        assert!(!auth.matches("cache:login"));
    }
}