1use std::cell::Cell;
2
3use bytes::{Buf, BytesMut};
4use tokio_util::codec::{Decoder, Encoder};
5
6use super::{decode, encode, Packet};
7use crate::error::{DecodeError, EncodeError};
8use crate::types::{FixedHeader, QoS};
9use crate::utils::decode_variable_length;
10
11#[derive(Debug, Clone)]
12pub struct Codec {
14 state: Cell<DecodeState>,
15 max_size: Cell<u32>,
16}
17
18#[derive(Debug, Copy, Clone, PartialEq, Eq)]
19enum DecodeState {
20 FrameHeader,
21 Frame(FixedHeader),
22}
23
24impl Codec {
25 pub fn new(max_packet_size: u32) -> Self {
27 Codec { state: Cell::new(DecodeState::FrameHeader), max_size: Cell::new(max_packet_size) }
28 }
29
30 pub fn set_max_size(&mut self, size: u32) {
35 self.max_size.set(size);
36 }
37}
38
39impl Default for Codec {
40 fn default() -> Self {
41 Self::new(0)
42 }
43}
44
45impl Decoder for Codec {
46 type Item = (Packet, u32);
47 type Error = DecodeError;
48
49 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, DecodeError> {
50 loop {
51 match self.state.get() {
52 DecodeState::FrameHeader => {
53 if src.len() < 2 {
54 return Ok(None);
55 }
56 let src_slice = src.as_ref();
57 let first_byte = src_slice[0];
58 match decode_variable_length(&src_slice[1..])? {
59 Some((remaining_length, consumed)) => {
60 let max_size = self.max_size.get();
62 if max_size != 0 && max_size < remaining_length {
63 return Err(DecodeError::MaxSizeExceeded);
64 }
65 src.advance(consumed + 1);
66 self.state.set(DecodeState::Frame(FixedHeader { first_byte, remaining_length }));
67 let remaining_length = remaining_length as usize;
69 if src.len() < remaining_length {
70 src.reserve(remaining_length); return Ok(None);
73 }
74 }
75 None => {
76 return Ok(None);
77 }
78 }
79 }
80 DecodeState::Frame(fixed) => {
81 if src.len() < fixed.remaining_length as usize {
82 return Ok(None);
83 }
84 let packet_buf = src.split_to(fixed.remaining_length as usize);
85 let packet = decode::decode_packet(packet_buf.freeze(), fixed.first_byte)?;
86 self.state.set(DecodeState::FrameHeader);
87 src.reserve(2);
88 return Ok(Some((packet, fixed.remaining_length)));
89 }
90 }
91 }
92 }
93}
94
95impl Encoder<Packet> for Codec {
96 type Error = EncodeError;
98
99 fn encode(&mut self, item: Packet, dst: &mut BytesMut) -> Result<(), EncodeError> {
100 if let Packet::Publish(ref publish) = item {
101 if (publish.qos == QoS::AtLeastOnce || publish.qos == QoS::ExactlyOnce)
102 && publish.packet_id.is_none()
103 {
104 return Err(EncodeError::PacketIdRequired);
105 }
106 }
107 let content_size = encode::get_encoded_size(&item);
108 dst.reserve(content_size + 5);
109 encode::encode(&item, dst, content_size as u32)?;
110 Ok(())
111 }
112}
113
#[cfg(test)]
mod tests {
    use super::*;
    use crate::v3::packet::Publish;
    use bytes::Bytes;
    use bytestring::ByteString;
    use rmqtt_utils::timestamp_millis;

    /// A remaining length of 9 must be rejected when the limit is 5.
    #[test]
    fn test_max_size() {
        let mut codec = Codec::default();
        codec.set_max_size(5);

        // Fixed header only: first byte 0x00, remaining length 0x09.
        let mut buf = BytesMut::from(&b"\0\x09"[..]);
        assert!(matches!(codec.decode(&mut buf), Err(DecodeError::MaxSizeExceeded)));
    }

    /// A large (260 KiB payload) QoS 0 publish survives an encode/decode round trip.
    #[test]
    fn test_packet() {
        let mut codec = Codec::default();
        let mut buf = BytesMut::new();

        let mut expected = Box::new(Publish {
            dup: false,
            retain: false,
            qos: QoS::AtMostOnce,
            topic: ByteString::from_static("/test"),
            packet_id: None,
            payload: Bytes::from("a".repeat(260 * 1024)),
            properties: None,
            delay_interval: None,
            create_time: None,
        });
        codec.encode(Packet::Publish(expected.clone()), &mut buf).unwrap();

        let decoded = match codec.decode(&mut buf).unwrap() {
            Some((Packet::Publish(publish), _)) => publish,
            _ => panic!(),
        };
        // NOTE(review): presumably the decoder stamps `create_time` with the current
        // time. If so, this comparison can flake when a millisecond boundary is crossed
        // between decoding and this line — confirm against `decode_packet` and
        // `Publish`'s `PartialEq`.
        expected.create_time = Some(timestamp_millis());
        assert_eq!(expected, decoded);
    }
}