1use crate::error::{MqttError, Result};
2use crate::prelude::{String, Vec};
3use crate::time::{Duration, Instant};
4use crate::QoS;
5
6#[derive(Debug, Clone)]
7pub struct LimitsConfig {
8 pub client_maximum_packet_size: u32,
9 pub server_maximum_packet_size: Option<u32>,
10 pub default_message_expiry: Option<Duration>,
11 pub max_message_expiry: Option<Duration>,
12}
13
14impl Default for LimitsConfig {
15 fn default() -> Self {
16 Self {
17 client_maximum_packet_size: crate::constants::limits::MAX_PACKET_SIZE,
18 server_maximum_packet_size: None,
19 default_message_expiry: None,
20 max_message_expiry: Some(Duration::from_secs(86400 * 7)),
21 }
22 }
23}
24
25#[derive(Debug)]
26pub struct LimitsManager {
27 config: LimitsConfig,
28}
29
30impl LimitsManager {
31 #[must_use]
32 pub fn new(config: LimitsConfig) -> Self {
33 Self { config }
34 }
35
36 #[must_use]
37 pub fn with_defaults() -> Self {
38 Self::new(LimitsConfig::default())
39 }
40
41 pub fn set_server_maximum_packet_size(&mut self, size: u32) {
42 self.config.server_maximum_packet_size = Some(size);
43 }
44
45 pub fn set_client_maximum_packet_size(&mut self, size: u32) {
46 self.config.client_maximum_packet_size = size;
47 }
48
49 #[must_use]
50 pub fn effective_maximum_packet_size(&self) -> u32 {
51 match self.config.server_maximum_packet_size {
52 Some(server_max) if server_max > 0 && self.config.client_maximum_packet_size > 0 => {
53 server_max.min(self.config.client_maximum_packet_size)
54 }
55 Some(server_max) if server_max > 0 => server_max,
56 _ => self.config.client_maximum_packet_size,
57 }
58 }
59
60 pub fn check_packet_size(&self, size: usize) -> Result<()> {
63 let max_size = self.effective_maximum_packet_size();
64 if max_size > 0 && size > max_size as usize {
65 Err(MqttError::PacketTooLarge {
66 size,
67 max: max_size as usize,
68 })
69 } else {
70 Ok(())
71 }
72 }
73
74 #[must_use]
75 pub fn calculate_message_expiry(&self, expiry_interval: Option<u32>) -> Option<Instant> {
76 let interval = match expiry_interval {
77 Some(seconds) => Duration::from_secs(u64::from(seconds)),
78 None => self.config.default_message_expiry?,
79 };
80
81 let final_interval = match self.config.max_message_expiry {
82 Some(max) => interval.min(max),
83 None => interval,
84 };
85
86 Some(Instant::now() + final_interval)
87 }
88
89 #[must_use]
90 pub fn is_message_expired(&self, expiry_time: Option<Instant>) -> bool {
91 match expiry_time {
92 Some(expiry) => Instant::now() > expiry,
93 None => false,
94 }
95 }
96
97 #[must_use]
98 pub fn get_remaining_expiry(&self, expiry_time: Option<Instant>) -> Option<u32> {
99 match expiry_time {
100 Some(expiry) => {
101 let now = Instant::now();
102 if now < expiry {
103 let remaining = expiry.duration_since(now);
104 Some(u32::try_from(remaining.as_secs()).unwrap_or(u32::MAX))
105 } else {
106 Some(0)
107 }
108 }
109 None => None,
110 }
111 }
112
113 #[must_use]
114 pub fn client_maximum_packet_size(&self) -> u32 {
115 self.config.client_maximum_packet_size
116 }
117
118 #[must_use]
119 pub fn server_maximum_packet_size(&self) -> Option<u32> {
120 self.config.server_maximum_packet_size
121 }
122}
123
124#[derive(Debug, Clone)]
125pub struct ExpiringMessage {
126 pub topic: String,
127 pub payload: Vec<u8>,
128 pub qos: QoS,
129 pub retain: bool,
130 pub packet_id: Option<u16>,
131 pub expiry_time: Option<Instant>,
132 pub expiry_interval: Option<u32>,
133}
134
135impl ExpiringMessage {
136 #[must_use]
137 pub fn new(
138 topic: String,
139 payload: Vec<u8>,
140 qos: QoS,
141 retain: bool,
142 packet_id: Option<u16>,
143 expiry_interval: Option<u32>,
144 limits: &LimitsManager,
145 ) -> Self {
146 let expiry_time = limits.calculate_message_expiry(expiry_interval);
147
148 Self {
149 topic,
150 payload,
151 qos,
152 retain,
153 packet_id,
154 expiry_time,
155 expiry_interval,
156 }
157 }
158
159 #[must_use]
160 pub fn is_expired(&self) -> bool {
161 match self.expiry_time {
162 Some(expiry) => Instant::now() > expiry,
163 None => false,
164 }
165 }
166
167 #[must_use]
168 pub fn remaining_expiry_interval(&self) -> Option<u32> {
169 match self.expiry_time {
170 Some(expiry) => {
171 let now = Instant::now();
172 if now < expiry {
173 let remaining = expiry.duration_since(now);
174 Some(u32::try_from(remaining.as_secs()).unwrap_or(u32::MAX))
175 } else {
176 Some(0)
177 }
178 }
179 None => self.expiry_interval,
180 }
181 }
182}
183
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns an instant ten seconds before "now", used as an already-passed deadline.
    fn ten_seconds_ago() -> Instant {
        Instant::now().checked_sub(Duration::from_secs(10)).unwrap()
    }

    #[test]
    fn test_limits_manager_creation() {
        let limits = LimitsManager::with_defaults();
        assert_eq!(
            limits.client_maximum_packet_size(),
            crate::constants::limits::MAX_PACKET_SIZE
        );
        assert_eq!(limits.server_maximum_packet_size(), None);
    }

    #[test]
    fn test_effective_packet_size() {
        let mut limits = LimitsManager::with_defaults();

        // Without a server limit the client limit applies.
        assert_eq!(
            limits.effective_maximum_packet_size(),
            crate::constants::limits::MAX_PACKET_SIZE
        );

        // A smaller server limit wins.
        limits.set_server_maximum_packet_size(1_048_576);
        assert_eq!(limits.effective_maximum_packet_size(), 1_048_576);

        // A smaller client limit wins over a larger server limit.
        let config = LimitsConfig {
            client_maximum_packet_size: 1_048_576,
            ..Default::default()
        };
        let mut limits = LimitsManager::new(config);
        limits.set_server_maximum_packet_size(10_485_760);
        assert_eq!(limits.effective_maximum_packet_size(), 1_048_576);
    }

    #[test]
    fn test_packet_size_checking() {
        let mut limits = LimitsManager::with_defaults();
        limits.set_server_maximum_packet_size(1024);

        assert!(limits.check_packet_size(512).is_ok());
        assert!(limits.check_packet_size(1024).is_ok());

        // Assert on the exact variant and payload; an `if let` here would let
        // the test pass silently if a different error variant came back.
        let result = limits.check_packet_size(2048);
        assert!(matches!(
            result,
            Err(MqttError::PacketTooLarge {
                size: 2048,
                max: 1024
            })
        ));
    }

    #[test]
    fn test_zero_limit_disables_size_check() {
        let mut limits = LimitsManager::with_defaults();

        // A zero client limit with no server limit means "unlimited".
        limits.set_client_maximum_packet_size(0);
        assert_eq!(limits.effective_maximum_packet_size(), 0);
        assert!(limits.check_packet_size(usize::MAX).is_ok());

        // A zero server limit is ignored as well.
        limits.set_server_maximum_packet_size(0);
        assert_eq!(limits.effective_maximum_packet_size(), 0);
        assert!(limits.check_packet_size(usize::MAX).is_ok());

        // A non-zero server limit applies even when the client limit is zero.
        limits.set_server_maximum_packet_size(16);
        assert_eq!(limits.effective_maximum_packet_size(), 16);
        assert!(limits.check_packet_size(16).is_ok());
        assert!(limits.check_packet_size(17).is_err());
    }

    #[test]
    fn test_message_expiry() {
        let config = LimitsConfig {
            default_message_expiry: Some(Duration::from_secs(60)),
            ..Default::default()
        };
        let limits = LimitsManager::new(config);

        // An explicit interval is used as given.
        let expiry_time = limits.calculate_message_expiry(Some(30));
        assert!(expiry_time.is_some());
        let remaining = limits.get_remaining_expiry(expiry_time).unwrap();
        assert!(remaining > 25 && remaining <= 30);

        // Without an explicit interval the configured default applies.
        let expiry_time = limits.calculate_message_expiry(None);
        assert!(expiry_time.is_some());
        let remaining = limits.get_remaining_expiry(expiry_time).unwrap();
        assert!(remaining > 55 && remaining <= 60);

        let past_time = Some(ten_seconds_ago());
        assert!(limits.is_message_expired(past_time));

        let future_time = Some(Instant::now() + Duration::from_secs(10));
        assert!(!limits.is_message_expired(future_time));

        // No deadline means the message never expires.
        assert!(!limits.is_message_expired(None));
    }

    #[test]
    fn test_remaining_expiry() {
        let limits = LimitsManager::with_defaults();

        let future_time = Some(Instant::now() + Duration::from_secs(100));
        let remaining = limits.get_remaining_expiry(future_time);
        assert!(remaining.is_some());
        assert!(remaining.unwrap() > 95 && remaining.unwrap() <= 100);

        let past_time = Some(ten_seconds_ago());
        let remaining = limits.get_remaining_expiry(past_time);
        assert_eq!(remaining, Some(0));

        assert_eq!(limits.get_remaining_expiry(None), None);
    }

    #[test]
    fn test_expiring_message() {
        let limits = LimitsManager::with_defaults();

        let msg = ExpiringMessage::new(
            "test/topic".into(),
            vec![1, 2, 3],
            QoS::AtLeastOnce,
            false,
            Some(123),
            Some(60),
            &limits,
        );

        assert!(!msg.is_expired());
        assert!(msg.remaining_expiry_interval().is_some());

        let mut msg = ExpiringMessage::new(
            "test/topic".into(),
            vec![1, 2, 3],
            QoS::AtLeastOnce,
            false,
            Some(123),
            Some(0),
            &limits,
        );

        // Force a deadline that has definitely passed.
        msg.expiry_time = Some(ten_seconds_ago());
        assert!(msg.is_expired());
        assert_eq!(msg.remaining_expiry_interval(), Some(0));
    }

    #[test]
    fn test_max_expiry_limit() {
        let config = LimitsConfig {
            max_message_expiry: Some(crate::constants::time::DEFAULT_SESSION_EXPIRY),
            ..Default::default()
        };
        let limits = LimitsManager::new(config);

        let expiry_time = limits.calculate_message_expiry(Some(7200));
        assert!(expiry_time.is_some());

        // The remaining time can never exceed the configured ceiling.
        let remaining = limits.get_remaining_expiry(expiry_time);
        assert!(remaining.is_some());
        assert!(
            remaining.unwrap()
                <= u32::try_from(crate::constants::time::DEFAULT_SESSION_EXPIRY.as_secs())
                    .unwrap_or(u32::MAX)
        );
    }
}