mockforge_mqtt/
topics.rs

1use regex::Regex;
2use std::collections::HashMap;
3
/// A single client's subscription to an MQTT topic filter.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Subscription {
    /// Topic filter supplied by the client; may contain the `+` and `#` wildcards.
    pub filter: String,
    /// QoS level supplied when subscribing (not validated by this module).
    pub qos: u8,
    /// Identifier of the subscribing client.
    pub client_id: String,
}
11
/// A message stored for a topic by [`TopicTree::retain_message`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RetainedMessage {
    /// Raw message body; never empty (an empty payload clears the retained message instead).
    pub payload: Vec<u8>,
    /// QoS level the message was published with.
    pub qos: u8,
    /// Wall-clock time the message was stored, in whole seconds since the UNIX epoch.
    pub timestamp: u64,
}
19
20/// Topic tree for managing subscriptions and retained messages
21pub struct TopicTree {
22    subscriptions: HashMap<String, Vec<Subscription>>,
23    retained: HashMap<String, RetainedMessage>,
24}
25
26impl Default for TopicTree {
27    fn default() -> Self {
28        Self::new()
29    }
30}
31
32impl TopicTree {
33    pub fn new() -> Self {
34        Self {
35            subscriptions: HashMap::new(),
36            retained: HashMap::new(),
37        }
38    }
39
40    /// Match a topic against all subscriptions
41    pub fn match_topic(&self, topic: &str) -> Vec<&Subscription> {
42        let mut matches = Vec::new();
43
44        for subscriptions in self.subscriptions.values() {
45            for subscription in subscriptions {
46                if self.matches_filter(topic, &subscription.filter) {
47                    matches.push(subscription);
48                }
49            }
50        }
51
52        matches
53    }
54
55    /// Check if a topic matches a filter (supports wildcards + and #)
56    fn matches_filter(&self, topic: &str, filter: &str) -> bool {
57        // Convert MQTT wildcard filter to regex
58        let regex_pattern = filter
59            .replace('+', "[^/]+")  // + matches any single level
60            .replace("#", ".+")      // # matches any remaining levels
61            .replace("$", "\\$"); // Escape $ for regex
62
63        let regex = match Regex::new(&format!("^{}$", regex_pattern)) {
64            Ok(r) => r,
65            Err(_) => return false,
66        };
67
68        regex.is_match(topic)
69    }
70
71    /// Add a subscription
72    pub fn subscribe(&mut self, filter: &str, qos: u8, client_id: &str) {
73        let subscription = Subscription {
74            filter: filter.to_string(),
75            qos,
76            client_id: client_id.to_string(),
77        };
78
79        self.subscriptions.entry(filter.to_string()).or_default().push(subscription);
80    }
81
82    /// Remove a subscription
83    pub fn unsubscribe(&mut self, filter: &str, client_id: &str) {
84        if let Some(subscriptions) = self.subscriptions.get_mut(filter) {
85            subscriptions.retain(|s| s.client_id != client_id);
86            if subscriptions.is_empty() {
87                self.subscriptions.remove(filter);
88            }
89        }
90    }
91
92    /// Store a retained message
93    pub fn retain_message(&mut self, topic: &str, payload: Vec<u8>, qos: u8) {
94        if payload.is_empty() {
95            // Empty payload removes retained message
96            self.retained.remove(topic);
97        } else {
98            let message = RetainedMessage {
99                payload,
100                qos,
101                timestamp: std::time::SystemTime::now()
102                    .duration_since(std::time::UNIX_EPOCH)
103                    .expect("system time before UNIX epoch")
104                    .as_secs(),
105            };
106            self.retained.insert(topic.to_string(), message);
107        }
108    }
109
110    /// Get retained message for a topic
111    pub fn get_retained(&self, topic: &str) -> Option<&RetainedMessage> {
112        self.retained.get(topic)
113    }
114
115    /// Get all retained messages that match a subscription filter
116    pub fn get_retained_for_filter(&self, filter: &str) -> Vec<(&str, &RetainedMessage)> {
117        self.retained
118            .iter()
119            .filter(|(topic, _)| self.matches_filter(topic, filter))
120            .map(|(topic, message)| (topic.as_str(), message))
121            .collect()
122    }
123
124    /// Clean up expired retained messages (basic implementation)
125    pub fn cleanup_expired_retained(&mut self, max_age_secs: u64) {
126        let now = std::time::SystemTime::now()
127            .duration_since(std::time::UNIX_EPOCH)
128            .expect("system time before UNIX epoch")
129            .as_secs();
130
131        self.retained
132            .retain(|_, message| now.saturating_sub(message.timestamp) < max_age_secs);
133    }
134
135    /// Get all topic filters (subscription patterns)
136    pub fn get_all_topic_filters(&self) -> Vec<String> {
137        self.subscriptions.keys().cloned().collect()
138    }
139
140    /// Get all retained message topics
141    pub fn get_all_retained_topics(&self) -> Vec<String> {
142        self.retained.keys().cloned().collect()
143    }
144
145    /// Get topic statistics
146    pub fn stats(&self) -> TopicStats {
147        TopicStats {
148            total_subscriptions: self.subscriptions.len(),
149            total_subscribers: self.subscriptions.values().map(|subs| subs.len()).sum(),
150            retained_messages: self.retained.len(),
151        }
152    }
153}
154
/// Snapshot of [`TopicTree`] sizes, as produced by [`TopicTree::stats`].
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct TopicStats {
    /// Number of distinct subscription filters.
    pub total_subscriptions: usize,
    /// Number of subscriptions summed over all filters.
    pub total_subscribers: usize,
    /// Number of topics holding a retained message.
    pub retained_messages: usize,
}
162
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the client ids of `matches`, sorted so assertions ignore map iteration order.
    fn sorted_client_ids(matches: &[&Subscription]) -> Vec<String> {
        let mut ids: Vec<String> = matches.iter().map(|s| s.client_id.clone()).collect();
        ids.sort();
        ids
    }

    #[test]
    fn test_subscription_clone() {
        let sub = Subscription {
            filter: "test/topic".to_string(),
            qos: 1,
            client_id: "client-1".to_string(),
        };

        let cloned = sub.clone();
        assert_eq!(sub.filter, cloned.filter);
        assert_eq!(sub.qos, cloned.qos);
        assert_eq!(sub.client_id, cloned.client_id);
    }

    #[test]
    fn test_subscription_debug() {
        let sub = Subscription {
            filter: "sensor/#".to_string(),
            qos: 2,
            client_id: "sensor-client".to_string(),
        };
        let debug = format!("{:?}", sub);
        assert!(debug.contains("Subscription"));
        assert!(debug.contains("sensor/#"));
    }

    #[test]
    fn test_retained_message_clone() {
        let msg = RetainedMessage {
            payload: b"hello".to_vec(),
            qos: 1,
            timestamp: 1234567890,
        };

        let cloned = msg.clone();
        assert_eq!(msg.payload, cloned.payload);
        assert_eq!(msg.qos, cloned.qos);
        assert_eq!(msg.timestamp, cloned.timestamp);
    }

    #[test]
    fn test_retained_message_debug() {
        let msg = RetainedMessage {
            payload: b"test".to_vec(),
            qos: 0,
            timestamp: 0,
        };
        let debug = format!("{:?}", msg);
        assert!(debug.contains("RetainedMessage"));
    }

    #[test]
    fn test_topic_tree_new() {
        let stats = TopicTree::new().stats();
        assert_eq!(stats.total_subscriptions, 0);
        assert_eq!(stats.total_subscribers, 0);
        assert_eq!(stats.retained_messages, 0);
    }

    #[test]
    fn test_topic_tree_default() {
        let tree = TopicTree::default();
        assert!(tree.get_all_topic_filters().is_empty());
        assert!(tree.get_all_retained_topics().is_empty());
    }

    #[test]
    fn test_subscribe() {
        let mut tree = TopicTree::new();
        tree.subscribe("sensor/temp", 1, "client-1");

        let stats = tree.stats();
        assert_eq!(stats.total_subscriptions, 1);
        assert_eq!(stats.total_subscribers, 1);
    }

    #[test]
    fn test_subscribe_multiple_clients() {
        let mut tree = TopicTree::new();
        tree.subscribe("sensor/temp", 1, "client-1");
        tree.subscribe("sensor/temp", 2, "client-2");

        let stats = tree.stats();
        assert_eq!(stats.total_subscriptions, 1);
        assert_eq!(stats.total_subscribers, 2);
    }

    #[test]
    fn test_unsubscribe() {
        let mut tree = TopicTree::new();
        tree.subscribe("sensor/temp", 1, "client-1");
        tree.subscribe("sensor/temp", 1, "client-2");

        tree.unsubscribe("sensor/temp", "client-1");

        assert_eq!(tree.stats().total_subscribers, 1);
        assert_eq!(sorted_client_ids(&tree.match_topic("sensor/temp")), vec!["client-2"]);
    }

    #[test]
    fn test_unsubscribe_removes_filter() {
        let mut tree = TopicTree::new();
        tree.subscribe("sensor/temp", 1, "client-1");
        tree.unsubscribe("sensor/temp", "client-1");

        assert_eq!(tree.stats().total_subscriptions, 0);
        assert!(tree.get_all_topic_filters().is_empty());
    }

    #[test]
    fn test_unsubscribe_unknown_is_noop() {
        let mut tree = TopicTree::new();
        tree.subscribe("sensor/temp", 1, "client-1");

        tree.unsubscribe("sensor/temp", "client-2");
        tree.unsubscribe("other/topic", "client-1");

        let stats = tree.stats();
        assert_eq!(stats.total_subscriptions, 1);
        assert_eq!(stats.total_subscribers, 1);
    }

    #[test]
    fn test_match_topic_exact() {
        let mut tree = TopicTree::new();
        tree.subscribe("sensor/temp", 1, "client-1");

        let matches = tree.match_topic("sensor/temp");
        assert_eq!(matches.len(), 1);
        assert_eq!(matches[0].client_id, "client-1");
        assert_eq!(matches[0].qos, 1);
    }

    #[test]
    fn test_match_topic_plus_wildcard() {
        let mut tree = TopicTree::new();
        tree.subscribe("sensor/+/temp", 1, "client-1");

        assert_eq!(tree.match_topic("sensor/room1/temp").len(), 1);

        // A different depth must not match.
        assert!(tree.match_topic("sensor/temp").is_empty());
    }

    #[test]
    fn test_match_topic_hash_wildcard() {
        let mut tree = TopicTree::new();
        tree.subscribe("sensor/#", 1, "client-1");

        assert_eq!(tree.match_topic("sensor/temp").len(), 1);
        assert_eq!(tree.match_topic("sensor/room/temp/value").len(), 1);
    }

    #[test]
    fn test_match_topic_multiple_filters() {
        let mut tree = TopicTree::new();
        tree.subscribe("sensor/temp", 1, "client-1");
        tree.subscribe("sensor/+", 1, "client-2");
        tree.subscribe("#", 0, "client-3");
        tree.subscribe("actuator/#", 0, "client-4");

        assert_eq!(
            sorted_client_ids(&tree.match_topic("sensor/temp")),
            vec!["client-1", "client-2", "client-3"]
        );
    }

    #[test]
    fn test_match_topic_no_match() {
        let mut tree = TopicTree::new();
        tree.subscribe("sensor/temp", 1, "client-1");

        assert!(tree.match_topic("actuator/temp").is_empty());
    }

    #[test]
    fn test_retain_message() {
        let mut tree = TopicTree::new();
        tree.retain_message("sensor/temp", b"25.5".to_vec(), 1);

        let retained = tree
            .get_retained("sensor/temp")
            .expect("message should have been retained");
        assert_eq!(retained.payload, b"25.5".to_vec());
        assert_eq!(retained.qos, 1);
    }

    #[test]
    fn test_retain_message_overwrites() {
        let mut tree = TopicTree::new();
        tree.retain_message("sensor/temp", b"25.5".to_vec(), 1);
        tree.retain_message("sensor/temp", b"26.0".to_vec(), 0);

        let retained = tree
            .get_retained("sensor/temp")
            .expect("message should have been retained");
        assert_eq!(retained.payload, b"26.0".to_vec());
        assert_eq!(retained.qos, 0);
        assert_eq!(tree.stats().retained_messages, 1);
    }

    #[test]
    fn test_retain_message_empty_removes() {
        let mut tree = TopicTree::new();
        tree.retain_message("sensor/temp", b"25.5".to_vec(), 1);
        tree.retain_message("sensor/temp", vec![], 0);

        assert!(tree.get_retained("sensor/temp").is_none());
    }

    #[test]
    fn test_get_retained_for_filter() {
        let mut tree = TopicTree::new();
        tree.retain_message("sensor/temp", b"25.5".to_vec(), 1);
        tree.retain_message("sensor/humidity", b"60".to_vec(), 1);
        tree.retain_message("actuator/fan", b"on".to_vec(), 1);

        let mut topics: Vec<&str> = tree
            .get_retained_for_filter("sensor/#")
            .into_iter()
            .map(|(topic, _)| topic)
            .collect();
        topics.sort();
        assert_eq!(topics, vec!["sensor/humidity", "sensor/temp"]);
    }

    #[test]
    fn test_cleanup_expired_retained() {
        let mut tree = TopicTree::new();
        tree.retain_message("sensor/temp", b"25.5".to_vec(), 1);

        // A one-year max age keeps a message that was just stored.
        tree.cleanup_expired_retained(365 * 24 * 60 * 60);
        assert!(tree.get_retained("sensor/temp").is_some());

        // A zero max age expires everything.
        tree.cleanup_expired_retained(0);
        assert!(tree.get_retained("sensor/temp").is_none());
    }

    #[test]
    fn test_get_all_topic_filters() {
        let mut tree = TopicTree::new();
        tree.subscribe("sensor/temp", 1, "client-1");
        tree.subscribe("sensor/humidity", 1, "client-2");

        let mut filters = tree.get_all_topic_filters();
        filters.sort();
        assert_eq!(filters, vec!["sensor/humidity", "sensor/temp"]);
    }

    #[test]
    fn test_get_all_retained_topics() {
        let mut tree = TopicTree::new();
        tree.retain_message("topic1", b"msg1".to_vec(), 1);
        tree.retain_message("topic2", b"msg2".to_vec(), 1);

        let mut topics = tree.get_all_retained_topics();
        topics.sort();
        assert_eq!(topics, vec!["topic1", "topic2"]);
    }

    #[test]
    fn test_topic_stats_clone() {
        let stats = TopicStats {
            total_subscriptions: 5,
            total_subscribers: 10,
            retained_messages: 3,
        };

        let cloned = stats.clone();
        assert_eq!(stats.total_subscriptions, cloned.total_subscriptions);
        assert_eq!(stats.total_subscribers, cloned.total_subscribers);
        assert_eq!(stats.retained_messages, cloned.retained_messages);
    }

    #[test]
    fn test_topic_stats_debug() {
        let stats = TopicStats {
            total_subscriptions: 1,
            total_subscribers: 2,
            retained_messages: 3,
        };
        assert!(format!("{:?}", stats).contains("TopicStats"));
    }
}