mqtt5_protocol/session/
subscription.rs

1use crate::error::{MqttError, Result};
2use crate::packet::subscribe::SubscriptionOptions;
3use crate::prelude::{HashMap, String, Vec};
4use crate::topic_matching::matches as topic_matches;
5use crate::validation::is_valid_topic_filter;
6
7#[derive(Debug, Clone, PartialEq)]
8pub struct Subscription {
9    pub topic_filter: String,
10    pub options: SubscriptionOptions,
11}
12
13#[derive(Debug)]
14pub struct SubscriptionManager {
15    subscriptions: HashMap<String, Subscription>,
16}
17
18impl SubscriptionManager {
19    #[must_use]
20    pub fn new() -> Self {
21        Self {
22            subscriptions: HashMap::new(),
23        }
24    }
25
26    /// # Errors
27    /// Returns `InvalidTopicFilter` if the topic filter is invalid.
28    pub fn add(&mut self, topic_filter: String, subscription: Subscription) -> Result<()> {
29        if !is_valid_topic_filter(&topic_filter) {
30            return Err(MqttError::InvalidTopicFilter(topic_filter));
31        }
32
33        self.subscriptions.insert(topic_filter, subscription);
34        Ok(())
35    }
36
37    /// # Errors
38    /// This function currently cannot fail but returns Result for API consistency.
39    pub fn remove(&mut self, topic_filter: &str) -> Result<bool> {
40        Ok(self.subscriptions.remove(topic_filter).is_some())
41    }
42
43    #[must_use]
44    pub fn matching_subscriptions(&self, topic: &str) -> Vec<(String, Subscription)> {
45        self.subscriptions
46            .iter()
47            .filter(|(filter, _)| topic_matches(topic, filter))
48            .map(|(filter, sub)| (filter.clone(), sub.clone()))
49            .collect()
50    }
51
52    #[must_use]
53    pub fn get(&self, topic_filter: &str) -> Option<&Subscription> {
54        self.subscriptions.get(topic_filter)
55    }
56
57    #[must_use]
58    pub fn all(&self) -> HashMap<String, Subscription> {
59        self.subscriptions.clone()
60    }
61
62    #[must_use]
63    pub fn count(&self) -> usize {
64        self.subscriptions.len()
65    }
66
67    pub fn clear(&mut self) {
68        self.subscriptions.clear();
69    }
70
71    #[must_use]
72    pub fn contains(&self, topic_filter: &str) -> bool {
73        self.subscriptions.contains_key(topic_filter)
74    }
75}
76
77impl Default for SubscriptionManager {
78    fn default() -> Self {
79        Self::new()
80    }
81}
82
#[cfg(test)]
mod tests {
    use super::*;
    use crate::prelude::ToString;
    use crate::QoS;

    /// Builds a subscription whose stored filter equals `filter`, so it can be
    /// registered under the same key.
    fn sub(filter: &str, options: SubscriptionOptions) -> Subscription {
        Subscription {
            topic_filter: filter.to_string(),
            options,
        }
    }

    #[test]
    fn test_topic_matching_exact() {
        assert!(topic_matches(
            "sport/tennis/player1",
            "sport/tennis/player1"
        ));
        assert!(!topic_matches(
            "sport/tennis/player1",
            "sport/tennis/player2"
        ));
        assert!(!topic_matches("sport/tennis", "sport/tennis/player1"));
    }

    #[test]
    fn test_topic_matching_single_level_wildcard() {
        assert!(topic_matches("sport/tennis/player1", "sport/tennis/+"));
        assert!(topic_matches("sport/tennis/player2", "sport/tennis/+"));
        assert!(!topic_matches(
            "sport/tennis/player1/ranking",
            "sport/tennis/+"
        ));

        assert!(topic_matches("sport/tennis/player1", "sport/+/player1"));
        assert!(topic_matches("sport/basketball/player1", "sport/+/player1"));

        assert!(topic_matches(
            "sensors/temperature/room1",
            "+/temperature/+"
        ));
        assert!(topic_matches(
            "devices/temperature/kitchen",
            "+/temperature/+"
        ));
    }

    #[test]
    fn test_topic_matching_multi_level_wildcard() {
        assert!(topic_matches("sport/tennis/player1", "sport/#"));
        assert!(topic_matches("sport/tennis/player1/ranking", "sport/#"));
        assert!(topic_matches("sport", "sport/#"));
        assert!(topic_matches(
            "sport/tennis/player1/score/final",
            "sport/tennis/#"
        ));

        assert!(!topic_matches("sports/tennis/player1", "sport/#"));

        assert!(topic_matches("sport/tennis/player1", "#"));
        assert!(topic_matches("anything/at/all", "#"));
        assert!(topic_matches("single", "#"));
    }

    #[test]
    fn test_topic_matching_combined_wildcards() {
        assert!(topic_matches(
            "sport/tennis/player1/score",
            "sport/+/+/score"
        ));
        assert!(topic_matches(
            "sport/tennis/player1/score/final",
            "sport/+/player1/#"
        ));
        assert!(topic_matches("sensors/temperature/room1", "+/+/+"));
        assert!(!topic_matches("sensors/temperature", "+/+/+"));
    }

    #[test]
    fn test_valid_topic_filter() {
        assert!(is_valid_topic_filter("sport/tennis/player1"));
        assert!(is_valid_topic_filter("sport/tennis/+"));
        assert!(is_valid_topic_filter("sport/#"));
        assert!(is_valid_topic_filter("#"));
        assert!(is_valid_topic_filter("+/tennis/+"));
        assert!(is_valid_topic_filter("sport/+/player1/#"));

        assert!(!is_valid_topic_filter(""));
        assert!(!is_valid_topic_filter("sport/tennis#"));
        assert!(!is_valid_topic_filter("sport/#/player"));
        assert!(!is_valid_topic_filter("sport/ten+nis"));
        assert!(!is_valid_topic_filter("sport/tennis/\0"));
    }

    #[test]
    fn test_subscription_manager() {
        let mut manager = SubscriptionManager::new();

        let sub1 = Subscription {
            topic_filter: "sport/tennis/+".to_string(),
            options: SubscriptionOptions::default(),
        };

        let sub2 = Subscription {
            topic_filter: "sport/#".to_string(),
            options: SubscriptionOptions::default().with_qos(QoS::ExactlyOnce),
        };

        manager.add("sport/tennis/+".to_string(), sub1).unwrap();
        manager.add("sport/#".to_string(), sub2).unwrap();

        assert_eq!(manager.count(), 2);
        assert!(manager.contains("sport/tennis/+"));

        let matches = manager.matching_subscriptions("sport/tennis/player1");
        assert_eq!(matches.len(), 2);
        // Result order follows map iteration order, so check membership, not position.
        assert!(matches.iter().any(|(filter, _)| filter == "sport/tennis/+"));
        assert!(matches.iter().any(|(filter, _)| filter == "sport/#"));

        let matches = manager.matching_subscriptions("sport/basketball/team1");
        assert_eq!(matches.len(), 1);
        assert_eq!(matches[0].0, "sport/#");

        // `remove` reports whether an entry existed; pin that, not just the side effect.
        assert!(manager.remove("sport/tennis/+").unwrap());
        assert_eq!(manager.count(), 1);
        assert!(!manager.contains("sport/tennis/+"));

        // Removing an absent filter is not an error and reports `false`.
        assert!(!manager.remove("sport/tennis/+").unwrap());
        assert_eq!(manager.count(), 1);
    }

    #[test]
    fn test_subscription_manager_edge_cases() {
        let mut manager = SubscriptionManager::new();

        let sub = Subscription {
            topic_filter: "sport/#/invalid".to_string(),
            options: SubscriptionOptions::default(),
        };

        assert!(manager.add("sport/#/invalid".to_string(), sub).is_err());
        // A rejected filter must not be stored.
        assert_eq!(manager.count(), 0);
        assert!(!manager.contains("sport/#/invalid"));

        let sub1 = Subscription {
            topic_filter: "test/topic".to_string(),
            options: SubscriptionOptions::default().with_qos(QoS::AtMostOnce),
        };

        let sub2 = Subscription {
            topic_filter: "test/topic".to_string(),
            options: SubscriptionOptions::default().with_qos(QoS::AtLeastOnce),
        };

        manager.add("test/topic".to_string(), sub1).unwrap();
        manager.add("test/topic".to_string(), sub2).unwrap();

        assert_eq!(manager.count(), 1);
        assert_eq!(
            manager.get("test/topic").unwrap().options.qos,
            QoS::AtLeastOnce
        );
    }

    #[test]
    fn test_subscription_manager_get_all_clear() {
        let mut manager = SubscriptionManager::default();
        assert_eq!(manager.count(), 0);
        assert!(manager.get("home/+/temp").is_none());
        assert!(manager
            .matching_subscriptions("home/kitchen/temp")
            .is_empty());

        let temp = sub("home/+/temp", SubscriptionOptions::default());
        let everything = sub(
            "home/#",
            SubscriptionOptions::default().with_qos(QoS::AtLeastOnce),
        );
        manager
            .add("home/+/temp".to_string(), temp.clone())
            .unwrap();
        manager.add("home/#".to_string(), everything.clone()).unwrap();

        assert_eq!(manager.get("home/+/temp"), Some(&temp));
        // `get` is an exact-key lookup, not a wildcard match.
        assert!(manager.get("home/kitchen/temp").is_none());

        let snapshot = manager.all();
        assert_eq!(snapshot.len(), 2);
        assert_eq!(snapshot.get("home/#"), Some(&everything));

        manager.clear();
        assert_eq!(manager.count(), 0);
        assert!(manager.all().is_empty());
        // The earlier snapshot is an independent copy and is unaffected by `clear`.
        assert_eq!(snapshot.len(), 2);
    }
}