mqtt5_protocol/session/limits.rs

1use crate::error::{MqttError, Result};
2use crate::prelude::{String, Vec};
3use crate::time::{Duration, Instant};
4use crate::QoS;
5
/// Configuration consumed by [`LimitsManager`]: packet-size limits and
/// message-expiry policy.
#[derive(Debug, Clone)]
pub struct LimitsConfig {
    /// Maximum packet size on the client side. `LimitsManager` treats `0`
    /// as "no client limit".
    pub client_maximum_packet_size: u32,
    /// Maximum packet size imposed by the server, if known (presumably taken
    /// from the server's CONNACK — confirm at the call site). `None` or `0`
    /// means no server limit is applied.
    pub server_maximum_packet_size: Option<u32>,
    /// Expiry applied to messages that carry no expiry interval of their own.
    /// `None` leaves such messages without an expiry.
    pub default_message_expiry: Option<Duration>,
    /// Upper bound applied to every computed message expiry interval.
    /// `None` disables the cap.
    pub max_message_expiry: Option<Duration>,
}
13
14impl Default for LimitsConfig {
15    fn default() -> Self {
16        Self {
17            client_maximum_packet_size: crate::constants::limits::MAX_PACKET_SIZE,
18            server_maximum_packet_size: None,
19            default_message_expiry: None,
20            max_message_expiry: Some(Duration::from_secs(86400 * 7)),
21        }
22    }
23}
24
/// Applies the packet-size and message-expiry policy described by a
/// [`LimitsConfig`].
#[derive(Debug)]
pub struct LimitsManager {
    /// The active limits. Packet-size fields can be updated after
    /// construction through the setter methods.
    config: LimitsConfig,
}
29
30impl LimitsManager {
31    #[must_use]
32    pub fn new(config: LimitsConfig) -> Self {
33        Self { config }
34    }
35
36    #[must_use]
37    pub fn with_defaults() -> Self {
38        Self::new(LimitsConfig::default())
39    }
40
41    pub fn set_server_maximum_packet_size(&mut self, size: u32) {
42        self.config.server_maximum_packet_size = Some(size);
43    }
44
45    pub fn set_client_maximum_packet_size(&mut self, size: u32) {
46        self.config.client_maximum_packet_size = size;
47    }
48
49    #[must_use]
50    pub fn effective_maximum_packet_size(&self) -> u32 {
51        match self.config.server_maximum_packet_size {
52            Some(server_max) if server_max > 0 && self.config.client_maximum_packet_size > 0 => {
53                server_max.min(self.config.client_maximum_packet_size)
54            }
55            Some(server_max) if server_max > 0 => server_max,
56            _ => self.config.client_maximum_packet_size,
57        }
58    }
59
60    /// # Errors
61    /// Returns `PacketTooLarge` if the size exceeds the effective maximum.
62    pub fn check_packet_size(&self, size: usize) -> Result<()> {
63        let max_size = self.effective_maximum_packet_size();
64        if max_size > 0 && size > max_size as usize {
65            Err(MqttError::PacketTooLarge {
66                size,
67                max: max_size as usize,
68            })
69        } else {
70            Ok(())
71        }
72    }
73
74    #[must_use]
75    pub fn calculate_message_expiry(&self, expiry_interval: Option<u32>) -> Option<Instant> {
76        let interval = match expiry_interval {
77            Some(seconds) => Duration::from_secs(u64::from(seconds)),
78            None => self.config.default_message_expiry?,
79        };
80
81        let final_interval = match self.config.max_message_expiry {
82            Some(max) => interval.min(max),
83            None => interval,
84        };
85
86        Some(Instant::now() + final_interval)
87    }
88
89    #[must_use]
90    pub fn is_message_expired(&self, expiry_time: Option<Instant>) -> bool {
91        match expiry_time {
92            Some(expiry) => Instant::now() > expiry,
93            None => false,
94        }
95    }
96
97    #[must_use]
98    pub fn get_remaining_expiry(&self, expiry_time: Option<Instant>) -> Option<u32> {
99        match expiry_time {
100            Some(expiry) => {
101                let now = Instant::now();
102                if now < expiry {
103                    let remaining = expiry.duration_since(now);
104                    Some(u32::try_from(remaining.as_secs()).unwrap_or(u32::MAX))
105                } else {
106                    Some(0)
107                }
108            }
109            None => None,
110        }
111    }
112
113    #[must_use]
114    pub fn client_maximum_packet_size(&self) -> u32 {
115        self.config.client_maximum_packet_size
116    }
117
118    #[must_use]
119    pub fn server_maximum_packet_size(&self) -> Option<u32> {
120        self.config.server_maximum_packet_size
121    }
122}
123
/// A message paired with its expiry deadline.
#[derive(Debug, Clone)]
pub struct ExpiringMessage {
    /// Topic the message is published to.
    pub topic: String,
    /// Raw message payload.
    pub payload: Vec<u8>,
    /// Quality-of-service level of the message.
    pub qos: QoS,
    /// Retain flag of the message.
    pub retain: bool,
    /// Packet identifier, if any (presumably only set for QoS 1/2 — confirm
    /// with callers).
    pub packet_id: Option<u16>,
    /// Absolute expiry deadline, as computed by
    /// `LimitsManager::calculate_message_expiry` in `ExpiringMessage::new`.
    /// `None` means the message never expires.
    pub expiry_time: Option<Instant>,
    /// Expiry interval in seconds as passed to `ExpiringMessage::new`, before
    /// any default or cap was applied.
    pub expiry_interval: Option<u32>,
}
134
135impl ExpiringMessage {
136    #[must_use]
137    pub fn new(
138        topic: String,
139        payload: Vec<u8>,
140        qos: QoS,
141        retain: bool,
142        packet_id: Option<u16>,
143        expiry_interval: Option<u32>,
144        limits: &LimitsManager,
145    ) -> Self {
146        let expiry_time = limits.calculate_message_expiry(expiry_interval);
147
148        Self {
149            topic,
150            payload,
151            qos,
152            retain,
153            packet_id,
154            expiry_time,
155            expiry_interval,
156        }
157    }
158
159    #[must_use]
160    pub fn is_expired(&self) -> bool {
161        match self.expiry_time {
162            Some(expiry) => Instant::now() > expiry,
163            None => false,
164        }
165    }
166
167    #[must_use]
168    pub fn remaining_expiry_interval(&self) -> Option<u32> {
169        match self.expiry_time {
170            Some(expiry) => {
171                let now = Instant::now();
172                if now < expiry {
173                    let remaining = expiry.duration_since(now);
174                    Some(u32::try_from(remaining.as_secs()).unwrap_or(u32::MAX))
175                } else {
176                    Some(0)
177                }
178            }
179            None => self.expiry_interval,
180        }
181    }
182}
183
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_limits_manager_creation() {
        let limits = LimitsManager::with_defaults();
        assert_eq!(
            limits.client_maximum_packet_size(),
            crate::constants::limits::MAX_PACKET_SIZE
        );
        assert_eq!(limits.server_maximum_packet_size(), None);
    }

    #[test]
    fn test_effective_packet_size() {
        let mut limits = LimitsManager::with_defaults();

        assert_eq!(
            limits.effective_maximum_packet_size(),
            crate::constants::limits::MAX_PACKET_SIZE
        );

        limits.set_server_maximum_packet_size(1_048_576);
        assert_eq!(limits.effective_maximum_packet_size(), 1_048_576);

        let config = LimitsConfig {
            client_maximum_packet_size: 1_048_576,
            ..Default::default()
        };
        let mut limits = LimitsManager::new(config);
        limits.set_server_maximum_packet_size(10_485_760);
        assert_eq!(limits.effective_maximum_packet_size(), 1_048_576);

        // A zero server limit is ignored, so the client limit applies.
        limits.set_server_maximum_packet_size(0);
        assert_eq!(limits.effective_maximum_packet_size(), 1_048_576);

        // A zero client limit is ignored, so the server limit applies.
        limits.set_client_maximum_packet_size(0);
        limits.set_server_maximum_packet_size(2048);
        assert_eq!(limits.effective_maximum_packet_size(), 2048);
    }

    #[test]
    fn test_packet_size_checking() {
        let mut limits = LimitsManager::with_defaults();
        limits.set_server_maximum_packet_size(1024);

        assert!(limits.check_packet_size(512).is_ok());
        assert!(limits.check_packet_size(1024).is_ok());

        // Match the exact variant and payload. An `if let` here would pass
        // silently if a different error variant were returned.
        let result = limits.check_packet_size(2048);
        assert!(matches!(
            result,
            Err(MqttError::PacketTooLarge {
                size: 2048,
                max: 1024
            })
        ));
    }

    #[test]
    fn test_packet_size_unlimited() {
        let config = LimitsConfig {
            client_maximum_packet_size: 0,
            server_maximum_packet_size: None,
            ..Default::default()
        };
        let limits = LimitsManager::new(config);

        assert_eq!(limits.effective_maximum_packet_size(), 0);
        assert!(limits.check_packet_size(usize::MAX).is_ok());
    }

    #[test]
    fn test_message_expiry() {
        let config = LimitsConfig {
            default_message_expiry: Some(Duration::from_secs(60)),
            ..Default::default()
        };
        let limits = LimitsManager::new(config);

        let expiry_time = limits.calculate_message_expiry(Some(30));
        assert!(expiry_time.is_some());

        let expiry_time = limits.calculate_message_expiry(None);
        assert!(expiry_time.is_some());

        let past_time = Some(Instant::now().checked_sub(Duration::from_secs(10)).unwrap());
        assert!(limits.is_message_expired(past_time));

        let future_time = Some(Instant::now() + Duration::from_secs(10));
        assert!(!limits.is_message_expired(future_time));

        assert!(!limits.is_message_expired(None));
    }

    #[test]
    fn test_no_default_expiry_means_no_deadline() {
        let limits = LimitsManager::with_defaults();
        assert_eq!(limits.calculate_message_expiry(None), None);
    }

    #[test]
    fn test_remaining_expiry() {
        let limits = LimitsManager::with_defaults();

        let future_time = Some(Instant::now() + Duration::from_secs(100));
        let remaining = limits.get_remaining_expiry(future_time);
        assert!(remaining.is_some());
        assert!(remaining.unwrap() > 95 && remaining.unwrap() <= 100);

        let past_time = Some(Instant::now().checked_sub(Duration::from_secs(10)).unwrap());
        let remaining = limits.get_remaining_expiry(past_time);
        assert_eq!(remaining, Some(0));

        assert_eq!(limits.get_remaining_expiry(None), None);
    }

    #[test]
    fn test_expiring_message() {
        let limits = LimitsManager::with_defaults();

        let msg = ExpiringMessage::new(
            "test/topic".into(),
            vec![1, 2, 3],
            QoS::AtLeastOnce,
            false,
            Some(123),
            Some(60),
            &limits,
        );

        assert!(!msg.is_expired());
        assert!(msg.remaining_expiry_interval().is_some());

        let mut msg = ExpiringMessage::new(
            "test/topic".into(),
            vec![1, 2, 3],
            QoS::AtLeastOnce,
            false,
            Some(123),
            Some(0),
            &limits,
        );

        msg.expiry_time = Some(Instant::now().checked_sub(Duration::from_secs(10)).unwrap());
        assert!(msg.is_expired());
        assert_eq!(msg.remaining_expiry_interval(), Some(0));
    }

    #[test]
    fn test_expiring_message_without_expiry() {
        let limits = LimitsManager::with_defaults();

        let msg = ExpiringMessage::new(
            "test/topic".into(),
            vec![1, 2, 3],
            QoS::AtMostOnce,
            false,
            None,
            None,
            &limits,
        );

        assert_eq!(msg.expiry_time, None);
        assert!(!msg.is_expired());
        assert_eq!(msg.remaining_expiry_interval(), None);
    }

    #[test]
    fn test_max_expiry_limit() {
        let config = LimitsConfig {
            max_message_expiry: Some(crate::constants::time::DEFAULT_SESSION_EXPIRY),
            ..Default::default()
        };
        let limits = LimitsManager::new(config);

        let expiry_time = limits.calculate_message_expiry(Some(7200));
        assert!(expiry_time.is_some());

        let remaining = limits.get_remaining_expiry(expiry_time);
        assert!(remaining.is_some());
        assert!(
            remaining.unwrap()
                <= u32::try_from(crate::constants::time::DEFAULT_SESSION_EXPIRY.as_secs())
                    .unwrap_or(u32::MAX)
        );
    }

    #[test]
    fn test_max_expiry_caps_requested_interval() {
        // Use a cap well below the requested interval so the cap is
        // guaranteed to take effect.
        let config = LimitsConfig {
            max_message_expiry: Some(Duration::from_secs(60)),
            ..Default::default()
        };
        let limits = LimitsManager::new(config);

        let expiry_time = limits.calculate_message_expiry(Some(7200));
        let remaining = limits
            .get_remaining_expiry(expiry_time)
            .expect("capped interval still yields a deadline");
        // Truncation to whole seconds may shave off one second.
        assert!((59..=60).contains(&remaining));
    }
}
321}