mockforge_mqtt/
topics.rs

1use regex::Regex;
2use std::collections::HashMap;
3
/// A single client's subscription to a topic filter.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Subscription {
    /// Topic filter as supplied by the client; may contain the `+` and `#` wildcards.
    pub filter: String,
    /// Requested QoS level (MQTT defines 0–2; not validated here).
    pub qos: u8,
    /// Identifier of the subscribing client.
    pub client_id: String,
}
11
/// The retained message currently stored for a topic.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RetainedMessage {
    /// Raw message payload bytes.
    pub payload: Vec<u8>,
    /// QoS level supplied when the message was retained (MQTT defines 0–2; not validated here).
    pub qos: u8,
    /// Unix timestamp (whole seconds) recorded when the message was stored; used for expiry.
    pub timestamp: u64,
}
19
20/// Topic tree for managing subscriptions and retained messages
21pub struct TopicTree {
22    subscriptions: HashMap<String, Vec<Subscription>>,
23    retained: HashMap<String, RetainedMessage>,
24}
25
26impl Default for TopicTree {
27    fn default() -> Self {
28        Self::new()
29    }
30}
31
32impl TopicTree {
33    pub fn new() -> Self {
34        Self {
35            subscriptions: HashMap::new(),
36            retained: HashMap::new(),
37        }
38    }
39
40    /// Match a topic against all subscriptions
41    pub fn match_topic(&self, topic: &str) -> Vec<&Subscription> {
42        let mut matches = Vec::new();
43
44        for subscriptions in self.subscriptions.values() {
45            for subscription in subscriptions {
46                if self.matches_filter(topic, &subscription.filter) {
47                    matches.push(subscription);
48                }
49            }
50        }
51
52        matches
53    }
54
55    /// Check if a topic matches a filter (supports wildcards + and #)
56    fn matches_filter(&self, topic: &str, filter: &str) -> bool {
57        // Convert MQTT wildcard filter to regex
58        let regex_pattern = filter
59            .replace('+', "[^/]+")  // + matches any single level
60            .replace("#", ".+")      // # matches any remaining levels
61            .replace("$", "\\$"); // Escape $ for regex
62
63        let regex = match Regex::new(&format!("^{}$", regex_pattern)) {
64            Ok(r) => r,
65            Err(_) => return false,
66        };
67
68        regex.is_match(topic)
69    }
70
71    /// Add a subscription
72    pub fn subscribe(&mut self, filter: &str, qos: u8, client_id: &str) {
73        let subscription = Subscription {
74            filter: filter.to_string(),
75            qos,
76            client_id: client_id.to_string(),
77        };
78
79        self.subscriptions.entry(filter.to_string()).or_default().push(subscription);
80    }
81
82    /// Remove a subscription
83    pub fn unsubscribe(&mut self, filter: &str, client_id: &str) {
84        if let Some(subscriptions) = self.subscriptions.get_mut(filter) {
85            subscriptions.retain(|s| s.client_id != client_id);
86            if subscriptions.is_empty() {
87                self.subscriptions.remove(filter);
88            }
89        }
90    }
91
92    /// Store a retained message
93    pub fn retain_message(&mut self, topic: &str, payload: Vec<u8>, qos: u8) {
94        if payload.is_empty() {
95            // Empty payload removes retained message
96            self.retained.remove(topic);
97        } else {
98            let message = RetainedMessage {
99                payload,
100                qos,
101                timestamp: std::time::SystemTime::now()
102                    .duration_since(std::time::UNIX_EPOCH)
103                    .unwrap()
104                    .as_secs(),
105            };
106            self.retained.insert(topic.to_string(), message);
107        }
108    }
109
110    /// Get retained message for a topic
111    pub fn get_retained(&self, topic: &str) -> Option<&RetainedMessage> {
112        self.retained.get(topic)
113    }
114
115    /// Get all retained messages that match a subscription filter
116    pub fn get_retained_for_filter(&self, filter: &str) -> Vec<(&str, &RetainedMessage)> {
117        self.retained
118            .iter()
119            .filter(|(topic, _)| self.matches_filter(topic, filter))
120            .map(|(topic, message)| (topic.as_str(), message))
121            .collect()
122    }
123
124    /// Clean up expired retained messages (basic implementation)
125    pub fn cleanup_expired_retained(&mut self, max_age_secs: u64) {
126        let now = std::time::SystemTime::now()
127            .duration_since(std::time::UNIX_EPOCH)
128            .unwrap()
129            .as_secs();
130
131        self.retained
132            .retain(|_, message| now.saturating_sub(message.timestamp) < max_age_secs);
133    }
134
135    /// Get all topic filters (subscription patterns)
136    pub fn get_all_topic_filters(&self) -> Vec<String> {
137        self.subscriptions.keys().cloned().collect()
138    }
139
140    /// Get all retained message topics
141    pub fn get_all_retained_topics(&self) -> Vec<String> {
142        self.retained.keys().cloned().collect()
143    }
144
145    /// Get topic statistics
146    pub fn stats(&self) -> TopicStats {
147        TopicStats {
148            total_subscriptions: self.subscriptions.len(),
149            total_subscribers: self.subscriptions.values().map(|subs| subs.len()).sum(),
150            retained_messages: self.retained.len(),
151        }
152    }
153}
154
/// Snapshot of [`TopicTree`] counters.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct TopicStats {
    /// Number of distinct topic-filter strings that have at least one subscriber.
    pub total_subscriptions: usize,
    /// Total number of individual subscriptions, summed over all filters.
    pub total_subscribers: usize,
    /// Number of topics that currently hold a retained message.
    pub retained_messages: usize,
}