1use std::cell::Cell;
2
3use bytes::{Buf, BytesMut};
4use tokio_util::codec::{Decoder, Encoder};
5
6use super::{decode, encode, Packet};
7use crate::error::{DecodeError, EncodeError};
8use crate::types::{FixedHeader, QoS};
9use crate::utils::decode_variable_length;
10
11#[derive(Debug, Clone)]
12pub struct Codec {
14 state: Cell<DecodeState>,
15 max_size: Cell<u32>,
16}
17
18#[derive(Debug, Copy, Clone, PartialEq, Eq)]
19enum DecodeState {
20 FrameHeader,
21 Frame(FixedHeader),
22}
23
24impl Codec {
25 pub fn new(max_packet_size: u32) -> Self {
27 Codec { state: Cell::new(DecodeState::FrameHeader), max_size: Cell::new(max_packet_size) }
28 }
29
30 pub fn set_max_size(&mut self, size: u32) {
35 self.max_size.set(size);
36 }
37}
38
39impl Default for Codec {
40 fn default() -> Self {
41 Self::new(0)
42 }
43}
44
45impl Decoder for Codec {
46 type Item = (Packet, u32);
47 type Error = DecodeError;
48
49 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, DecodeError> {
50 loop {
51 match self.state.get() {
52 DecodeState::FrameHeader => {
53 if src.len() < 2 {
54 return Ok(None);
55 }
56 let src_slice = src.as_ref();
57 let first_byte = src_slice[0];
58 match decode_variable_length(&src_slice[1..])? {
59 Some((remaining_length, consumed)) => {
60 let max_size = self.max_size.get();
62 if max_size != 0 && max_size < remaining_length {
63 return Err(DecodeError::MaxSizeExceeded);
64 }
65 src.advance(consumed + 1);
66 self.state.set(DecodeState::Frame(FixedHeader { first_byte, remaining_length }));
67 let remaining_length = remaining_length as usize;
69 if src.len() < remaining_length {
70 src.reserve(remaining_length); return Ok(None);
73 }
74 }
75 None => {
76 return Ok(None);
77 }
78 }
79 }
80 DecodeState::Frame(fixed) => {
81 if src.len() < fixed.remaining_length as usize {
82 return Ok(None);
83 }
84 let packet_buf = src.split_to(fixed.remaining_length as usize);
85 let packet = decode::decode_packet(packet_buf.freeze(), fixed.first_byte)?;
86 self.state.set(DecodeState::FrameHeader);
87 src.reserve(2);
88 return Ok(Some((packet, fixed.remaining_length)));
89 }
90 }
91 }
92 }
93}
94
95impl Encoder<Packet> for Codec {
96 type Error = EncodeError;
98
99 fn encode(&mut self, item: Packet, dst: &mut BytesMut) -> Result<(), EncodeError> {
100 if let Packet::Publish(ref publish) = item {
101 if (publish.qos == QoS::AtLeastOnce || publish.qos == QoS::ExactlyOnce)
102 && publish.packet_id.is_none()
103 {
104 return Err(EncodeError::PacketIdRequired);
105 }
106 }
107 let content_size = encode::get_encoded_size(&item);
108 dst.reserve(content_size + 5);
109 encode::encode(&item, dst, content_size as u32)?;
110 Ok(())
111 }
112}
113
#[cfg(test)]
mod tests {
    use super::*;
    use crate::v3::packet::Publish;
    use bytes::Bytes;
    use bytestring::ByteString;

    /// A header announcing 9 remaining bytes is rejected when the limit is 5.
    #[test]
    fn test_max_size() {
        let mut codec = Codec::default();
        codec.set_max_size(5);

        let mut buf = BytesMut::from(&b"\0\x09"[..]);
        let result = codec.decode(&mut buf);
        assert!(matches!(result, Err(DecodeError::MaxSizeExceeded)));
    }

    /// A large QoS 0 publish survives an encode/decode round trip unchanged.
    #[test]
    fn test_packet() {
        let mut codec = Codec::default();
        let mut buf = BytesMut::new();

        let payload = Bytes::from(vec![b'a'; 260 * 1024]);
        let original = Box::new(Publish {
            dup: false,
            retain: false,
            qos: QoS::AtMostOnce,
            topic: ByteString::from_static("/test"),
            packet_id: None,
            payload,
            properties: None,
        });
        codec.encode(Packet::Publish(original.clone()), &mut buf).unwrap();

        let decoded = match codec.decode(&mut buf).unwrap().unwrap() {
            (Packet::Publish(publish), _) => publish,
            _ => panic!(),
        };
        assert_eq!(original, decoded);
    }
}