1use crate::{Error, QoS, Result, MAGIC_BYTE};
22use bytes::{Buf, BufMut, Bytes, BytesMut};
23
/// Length in bytes of the fixed frame header: magic (1) + flags (1) + payload length (2).
pub const HEADER_SIZE: usize = 1 + 1 + 2;

/// Length in bytes of the header when the optional 8-byte timestamp follows it.
pub const HEADER_SIZE_WITH_TS: usize = HEADER_SIZE + 8;

/// Largest encodable payload; the on-wire length field is a `u16`.
pub const MAX_PAYLOAD_SIZE: usize = u16::MAX as usize;
32
33#[derive(Debug, Clone, Copy, Default)]
35pub struct FrameFlags {
36 pub qos: QoS,
37 pub has_timestamp: bool,
38 pub encrypted: bool,
39 pub compressed: bool,
40 pub version: u8,
42}
43
44impl FrameFlags {
45 pub fn to_byte(&self) -> u8 {
46 let mut flags = 0u8;
47 flags |= (self.qos as u8) << 6;
48 if self.has_timestamp {
49 flags |= 0x20;
50 }
51 if self.encrypted {
52 flags |= 0x10;
53 }
54 if self.compressed {
55 flags |= 0x08;
56 }
57 flags |= self.version & 0x07;
59 flags
60 }
61
62 pub fn from_byte(byte: u8) -> Self {
63 Self {
64 qos: QoS::from_u8((byte >> 6) & 0x03).unwrap_or(QoS::Fire),
65 has_timestamp: (byte & 0x20) != 0,
66 encrypted: (byte & 0x10) != 0,
67 compressed: (byte & 0x08) != 0,
68 version: byte & 0x07,
69 }
70 }
71
72 pub fn is_binary_encoding(&self) -> bool {
74 self.version >= 1
75 }
76}
77
78#[derive(Debug, Clone)]
80pub struct Frame {
81 pub flags: FrameFlags,
82 pub timestamp: Option<u64>,
83 pub payload: Bytes,
84}
85
86impl Frame {
87 pub fn new(payload: impl Into<Bytes>) -> Self {
89 Self {
90 flags: FrameFlags::default(),
91 timestamp: None,
92 payload: payload.into(),
93 }
94 }
95
96 pub fn with_qos(mut self, qos: QoS) -> Self {
98 self.flags.qos = qos;
99 self
100 }
101
102 pub fn with_timestamp(mut self, timestamp: u64) -> Self {
104 self.timestamp = Some(timestamp);
105 self.flags.has_timestamp = true;
106 self
107 }
108
109 pub fn with_encrypted(mut self, encrypted: bool) -> Self {
111 self.flags.encrypted = encrypted;
112 self
113 }
114
115 pub fn with_compressed(mut self, compressed: bool) -> Self {
117 self.flags.compressed = compressed;
118 self
119 }
120
121 pub fn size(&self) -> usize {
123 let header = if self.flags.has_timestamp {
124 HEADER_SIZE_WITH_TS
125 } else {
126 HEADER_SIZE
127 };
128 header + self.payload.len()
129 }
130
131 pub fn encode(&self) -> Result<Bytes> {
133 if self.payload.len() > MAX_PAYLOAD_SIZE {
134 return Err(Error::PayloadTooLarge(self.payload.len()));
135 }
136
137 let mut buf = BytesMut::with_capacity(self.size());
138
139 buf.put_u8(MAGIC_BYTE);
141
142 buf.put_u8(self.flags.to_byte());
144
145 buf.put_u16(self.payload.len() as u16);
147
148 if let Some(ts) = self.timestamp {
150 buf.put_u64(ts);
151 }
152
153 buf.extend_from_slice(&self.payload);
155
156 Ok(buf.freeze())
157 }
158
159 pub fn decode(mut buf: impl Buf) -> Result<Self> {
161 if buf.remaining() < HEADER_SIZE {
162 return Err(Error::BufferTooSmall {
163 needed: HEADER_SIZE,
164 have: buf.remaining(),
165 });
166 }
167
168 let magic = buf.get_u8();
170 if magic != MAGIC_BYTE {
171 return Err(Error::InvalidMagic(magic));
172 }
173
174 let flags = FrameFlags::from_byte(buf.get_u8());
176
177 let payload_len = buf.get_u16() as usize;
179
180 let header_size = if flags.has_timestamp {
182 HEADER_SIZE_WITH_TS
183 } else {
184 HEADER_SIZE
185 };
186 let total_remaining = if flags.has_timestamp { 8 } else { 0 } + payload_len;
187
188 if buf.remaining() < total_remaining {
189 return Err(Error::BufferTooSmall {
190 needed: header_size + payload_len,
191 have: HEADER_SIZE + buf.remaining(),
192 });
193 }
194
195 let timestamp = if flags.has_timestamp {
197 Some(buf.get_u64())
198 } else {
199 None
200 };
201
202 let payload = buf.copy_to_bytes(payload_len);
204
205 Ok(Self {
206 flags,
207 timestamp,
208 payload,
209 })
210 }
211
212 pub fn check_complete(buf: &[u8]) -> Option<usize> {
214 if buf.len() < HEADER_SIZE {
215 return None;
216 }
217
218 if buf[0] != MAGIC_BYTE {
219 return None;
220 }
221
222 let flags = FrameFlags::from_byte(buf[1]);
223 let payload_len = u16::from_be_bytes([buf[2], buf[3]]) as usize;
224
225 let header_size = if flags.has_timestamp {
226 HEADER_SIZE_WITH_TS
227 } else {
228 HEADER_SIZE
229 };
230
231 let total_size = header_size + payload_len;
232
233 if buf.len() >= total_size {
234 Some(total_size)
235 } else {
236 None
237 }
238 }
239}
240
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_frame_encode_decode() {
        let payload = b"hello world";
        let frame = Frame::new(payload.as_slice())
            .with_qos(QoS::Confirm)
            .with_timestamp(1234567890);

        let encoded = frame.encode().unwrap();
        let decoded = Frame::decode(&encoded[..]).unwrap();

        assert_eq!(decoded.flags.qos, QoS::Confirm);
        assert_eq!(decoded.timestamp, Some(1234567890));
        assert_eq!(decoded.payload.as_ref(), payload);
    }

    #[test]
    fn test_flags_roundtrip() {
        let flags = FrameFlags {
            qos: QoS::Commit,
            has_timestamp: true,
            encrypted: true,
            compressed: false,
            version: 1,
        };

        let byte = flags.to_byte();
        let decoded = FrameFlags::from_byte(byte);

        assert_eq!(decoded.qos, QoS::Commit);
        assert!(decoded.has_timestamp);
        assert!(decoded.encrypted);
        assert!(!decoded.compressed);
        assert_eq!(decoded.version, 1);
        assert!(decoded.is_binary_encoding());
    }

    #[test]
    fn test_flags_version_bits() {
        // Version 0 is the only non-binary encoding.
        let v2_flags = FrameFlags {
            version: 0,
            ..Default::default()
        };
        assert!(!v2_flags.is_binary_encoding());

        let v3_flags = FrameFlags {
            version: 1,
            ..Default::default()
        };
        assert!(v3_flags.is_binary_encoding());
    }

    #[test]
    fn test_check_complete() {
        let frame = Frame::new(b"test".as_slice());
        let encoded = frame.encode().unwrap();

        // The whole frame is available.
        assert_eq!(Frame::check_complete(&encoded), Some(encoded.len()));

        // Shorter than the fixed header.
        assert_eq!(Frame::check_complete(&encoded[..2]), None);

        // Header present but payload truncated.
        assert_eq!(Frame::check_complete(&encoded[..5]), None);
    }

    #[test]
    fn test_frame_max_payload_size() {
        let payload = vec![0u8; MAX_PAYLOAD_SIZE];
        let frame = Frame::new(payload)
            .with_qos(QoS::Fire)
            .with_encrypted(true);

        let encoded = frame.encode().expect("encode max payload");
        let decoded = Frame::decode(&encoded[..]).expect("decode max payload");

        assert_eq!(decoded.payload.len(), MAX_PAYLOAD_SIZE);
        assert_eq!(decoded.flags.qos, QoS::Fire);
        assert!(decoded.flags.encrypted);
        assert!(!decoded.flags.has_timestamp);
    }

    #[test]
    fn test_frame_payload_too_large() {
        let payload = vec![0u8; MAX_PAYLOAD_SIZE + 1];
        let frame = Frame::new(payload);

        let err = frame.encode().expect_err("expected PayloadTooLarge error");
        match err {
            Error::PayloadTooLarge(len) => assert_eq!(len, MAX_PAYLOAD_SIZE + 1),
            other => panic!("unexpected error: {:?}", other),
        }
    }

    #[test]
    fn test_decode_invalid_magic() {
        let frame = Frame::new(b"magic".as_slice());
        let mut encoded_vec = frame.encode().unwrap().to_vec();

        // Corrupt the magic byte.
        encoded_vec[0] = 0x00;

        let err = Frame::decode(&encoded_vec[..]).expect_err("expected InvalidMagic error");
        match err {
            Error::InvalidMagic(byte) => assert_eq!(byte, 0x00),
            other => panic!("unexpected error: {:?}", other),
        }
    }

    #[test]
    fn test_check_complete_with_timestamp() {
        let frame = Frame::new(b"ts".as_slice()).with_timestamp(42);
        let encoded = frame.encode().unwrap();

        assert_eq!(Frame::check_complete(&encoded), Some(encoded.len()));

        // Dropping the last payload byte makes the frame incomplete.
        let truncated = &encoded[..encoded.len() - 1];
        assert_eq!(Frame::check_complete(truncated), None);
    }

    #[test]
    fn test_decode_header_too_short() {
        let err = Frame::decode(&[MAGIC_BYTE, 0x00][..])
            .expect_err("expected BufferTooSmall error");
        match err {
            Error::BufferTooSmall { needed, have } => {
                assert_eq!(needed, HEADER_SIZE);
                assert_eq!(have, 2);
            }
            other => panic!("unexpected error: {:?}", other),
        }
    }

    #[test]
    fn test_decode_truncated_body_reports_frame_totals() {
        let frame = Frame::new(b"hello".as_slice()).with_timestamp(7);
        let encoded = frame.encode().unwrap();
        assert_eq!(encoded.len(), HEADER_SIZE_WITH_TS + 5);

        // Header plus part of the timestamp only.
        let err = Frame::decode(&encoded[..10]).expect_err("expected BufferTooSmall error");
        match err {
            Error::BufferTooSmall { needed, have } => {
                assert_eq!(needed, HEADER_SIZE_WITH_TS + 5);
                assert_eq!(have, 10);
            }
            other => panic!("unexpected error: {:?}", other),
        }
    }

    #[test]
    fn test_check_complete_rejects_bad_magic() {
        let mut bytes = Frame::new(b"x".as_slice()).encode().unwrap().to_vec();
        // Bitwise NOT guarantees a byte different from MAGIC_BYTE.
        bytes[0] = !MAGIC_BYTE;
        assert_eq!(Frame::check_complete(&bytes), None);
    }

    #[test]
    fn test_check_complete_back_to_back_frames() {
        let first = Frame::new(b"one".as_slice()).encode().unwrap();
        let second = Frame::new(b"two".as_slice())
            .with_timestamp(9)
            .encode()
            .unwrap();

        let mut stream = first.to_vec();
        stream.extend_from_slice(&second);

        // Only the first frame's length is reported; trailing bytes are ignored.
        assert_eq!(Frame::check_complete(&stream), Some(first.len()));

        let decoded = Frame::decode(&stream[first.len()..]).unwrap();
        assert_eq!(decoded.timestamp, Some(9));
        assert_eq!(decoded.payload.as_ref(), b"two");
    }
}