1use std::io::Cursor;
2
3use actix_codec::{Decoder, Encoder};
4use bytes::BytesMut;
5
6use crate::error::ParseError;
7use crate::proto::QoS;
8use crate::{Packet, Publish};
9
10mod decode;
11mod encode;
12
13use self::decode::*;
14use self::encode::*;
15
16bitflags! {
17 pub struct ConnectFlags: u8 {
18 const USERNAME = 0b1000_0000;
19 const PASSWORD = 0b0100_0000;
20 const WILL_RETAIN = 0b0010_0000;
21 const WILL_QOS = 0b0001_1000;
22 const WILL = 0b0000_0100;
23 const CLEAN_SESSION = 0b0000_0010;
24 }
25}
26
/// Bit offset of the two-bit `ConnectFlags::WILL_QOS` field (`0b0001_1000`):
/// `(flags & WILL_QOS) >> WILL_QOS_SHIFT` yields a value in `0..=3`.
pub const WILL_QOS_SHIFT: u8 = 3;
28
29bitflags! {
30 pub struct ConnectAckFlags: u8 {
31 const SESSION_PRESENT = 0b0000_0001;
32 }
33}
34
/// Codec that decodes incoming bytes into [`Packet`]s and encodes [`Packet`]s into bytes.
#[derive(Debug)]
pub struct Codec {
    // Current step of the decoder: expecting a fixed header, or a packet body.
    state: DecodeState,
    // Largest "remaining length" accepted by the decoder; 0 disables the check.
    max_size: usize,
}
40
/// Decoder state machine used by [`Codec`]'s `Decoder` implementation.
#[derive(Debug, Clone, Copy)]
enum DecodeState {
    /// Waiting for a fixed header: one type/flags byte followed by the
    /// variable-length remaining-length field.
    FrameHeader,
    /// The fixed header has been consumed; waiting for `remaining_length`
    /// body bytes before the packet can be parsed.
    Frame(FixedHeader),
}
46
47impl Codec {
48 pub fn new() -> Self {
50 Codec {
51 state: DecodeState::FrameHeader,
52 max_size: 0,
53 }
54 }
55
56 pub fn max_size(mut self, size: usize) -> Self {
61 self.max_size = size;
62 self
63 }
64}
65
66impl Default for Codec {
67 fn default() -> Self {
68 Self::new()
69 }
70}
71
72impl Decoder for Codec {
73 type Item = Packet;
74 type Error = ParseError;
75
76 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, ParseError> {
77 loop {
78 match self.state {
79 DecodeState::FrameHeader => {
80 if src.len() < 2 {
81 return Ok(None);
82 }
83 let fixed = src.as_ref()[0];
84 match decode_variable_length(&src.as_ref()[1..])? {
85 Some((remaining_length, consumed)) => {
86 if self.max_size != 0 && self.max_size < remaining_length {
88 return Err(ParseError::MaxSizeExceeded);
89 }
90 src.split_to(consumed + 1);
91 self.state = DecodeState::Frame(FixedHeader {
92 packet_type: fixed >> 4,
93 packet_flags: fixed & 0xF,
94 remaining_length,
95 });
96 if src.len() < remaining_length {
98 src.reserve(remaining_length); return Ok(None);
101 }
102 }
103 None => {
104 return Ok(None);
105 }
106 }
107 }
108 DecodeState::Frame(fixed) => {
109 if src.len() < fixed.remaining_length {
110 return Ok(None);
111 }
112 let packet_buf = src.split_to(fixed.remaining_length);
113 let mut packet_cur = Cursor::new(packet_buf.freeze());
114 let packet = read_packet(&mut packet_cur, fixed)?;
115 self.state = DecodeState::FrameHeader;
116 src.reserve(2);
117 return Ok(Some(packet));
118 }
119 }
120 }
121 }
122}
123
124impl Encoder for Codec {
125 type Item = Packet;
126 type Error = ParseError;
127
128 fn encode(&mut self, item: Self::Item, dst: &mut BytesMut) -> Result<(), ParseError> {
129 if let Packet::Publish(Publish { qos, packet_id, .. }) = item {
130 if (qos == QoS::AtLeastOnce || qos == QoS::ExactlyOnce) && packet_id.is_none() {
131 return Err(ParseError::PacketIdRequired);
132 }
133 }
134 let content_size = get_encoded_size(&item);
135 dst.reserve(content_size + 5);
136 write_packet(&item, dst, content_size);
137 Ok(())
138 }
139}
140
/// Fixed header of a packet, as produced by [`Codec`]'s decoder.
#[derive(Debug, PartialEq, Clone, Copy)]
pub(crate) struct FixedHeader {
    /// Packet type: the upper 4 bits of the first header byte.
    pub packet_type: u8,
    /// Packet flags: the lower 4 bits of the first header byte.
    pub packet_flags: u8,
    /// Number of body bytes that follow the fixed header.
    pub remaining_length: usize,
}
151
#[cfg(test)]
mod tests {
    use super::*;

    /// A remaining length above `max_size` is rejected.
    #[test]
    fn test_max_size() {
        let mut codec = Codec::new().max_size(5);

        let mut buf = BytesMut::new();
        buf.extend_from_slice(b"\0\x09");
        assert_eq!(codec.decode(&mut buf), Err(ParseError::MaxSizeExceeded));
    }

    /// A remaining length equal to `max_size` is accepted. Without the body
    /// the decoder just waits for more input.
    #[test]
    fn test_max_size_boundary_is_accepted() {
        let mut codec = Codec::new().max_size(5);

        let mut buf = BytesMut::new();
        buf.extend_from_slice(b"\0\x05");
        assert_eq!(codec.decode(&mut buf), Ok(None));
    }

    /// `max_size == 0` (the default) imposes no limit.
    #[test]
    fn test_zero_max_size_is_unlimited() {
        let mut codec = Codec::new();

        let mut buf = BytesMut::new();
        buf.extend_from_slice(b"\0\x7f");
        assert_eq!(codec.decode(&mut buf), Ok(None));
    }

    /// A lone type/flags byte is not enough to decode a header, and nothing is consumed.
    #[test]
    fn test_incomplete_header() {
        let mut codec = Codec::new();

        let mut buf = BytesMut::new();
        buf.extend_from_slice(b"\x30");
        assert_eq!(codec.decode(&mut buf), Ok(None));
        assert_eq!(buf.len(), 1);
    }
}