mqtt_proto/common/
utils.rs1use core::slice;
2
3use alloc::string::String;
4use alloc::vec::Vec;
5
6use simdutf8::basic::from_utf8;
7
8use crate::{from_read_exact_error, AsyncRead, Encodable, Error, SyncWrite};
9
10#[inline]
12pub async fn decode_raw_header<T: AsyncRead + Unpin>(reader: &mut T) -> Result<(u8, u32), Error> {
13 let typ = read_u8(reader).await?;
14 let (remaining_len, _bytes) = decode_var_int(reader).await?;
15 Ok((typ, remaining_len))
16}
17
18#[inline]
19pub(crate) async fn read_string<T: AsyncRead + Unpin>(reader: &mut T) -> Result<String, Error> {
20 let data_buf = read_bytes(reader).await?;
21 let _str = from_utf8(&data_buf).map_err(|_| Error::InvalidString)?;
22 Ok(unsafe { String::from_utf8_unchecked(data_buf) })
23}
24
25#[inline]
26pub(crate) async fn read_bytes<T: AsyncRead + Unpin>(reader: &mut T) -> Result<Vec<u8>, Error> {
27 let data_len = read_u16(reader).await?;
28 let mut data_buf = alloc::vec![0u8; data_len as usize];
29 reader
30 .read_exact(&mut data_buf)
31 .await
32 .map_err(from_read_exact_error)?;
33 Ok(data_buf)
34}
35
36#[inline]
38pub(crate) async fn read_u32<T: AsyncRead + Unpin>(reader: &mut T) -> Result<u32, Error> {
39 let mut len4_bytes = [0u8; 4];
40 reader
41 .read_exact(&mut len4_bytes)
42 .await
43 .map_err(from_read_exact_error)?;
44 Ok(u32::from_be_bytes(len4_bytes))
45}
46
47#[inline]
48pub(crate) async fn read_u16<T: AsyncRead + Unpin>(reader: &mut T) -> Result<u16, Error> {
49 let mut len2_bytes = [0u8; 2];
50 reader
51 .read_exact(&mut len2_bytes)
52 .await
53 .map_err(from_read_exact_error)?;
54 Ok(u16::from_be_bytes(len2_bytes))
55}
56
57#[inline]
58pub(crate) async fn read_u8<T: AsyncRead + Unpin>(reader: &mut T) -> Result<u8, Error> {
59 let mut byte = [0u8; 1];
60 reader
61 .read_exact(&mut byte)
62 .await
63 .map_err(from_read_exact_error)?;
64 Ok(byte[0])
65}
66
67pub(crate) fn write_string<W: SyncWrite>(writer: &mut W, value: &str) -> Result<(), Error> {
68 write_bytes(writer, value.as_bytes())?;
69 Ok(())
70}
71
72#[inline]
73pub(crate) fn write_bytes<W: SyncWrite>(writer: &mut W, data: &[u8]) -> Result<(), Error> {
74 write_u16(writer, data.len() as u16)?;
75 writer.write_all(data)?;
76 Ok(())
77}
78
79#[inline]
80pub(crate) fn write_u32<W: SyncWrite>(writer: &mut W, value: u32) -> Result<(), Error> {
81 writer.write_all(&value.to_be_bytes())?;
82 Ok(())
83}
84
85#[inline]
86pub(crate) fn write_u16<W: SyncWrite>(writer: &mut W, value: u16) -> Result<(), Error> {
87 writer.write_all(&value.to_be_bytes())?;
88 Ok(())
89}
90
91#[inline]
92pub(crate) fn write_u8<W: SyncWrite>(writer: &mut W, value: u8) -> Result<(), Error> {
93 writer.write_all(slice::from_ref(&value))?;
94 Ok(())
95}
96
97#[inline]
98pub(crate) fn write_var_int<W: SyncWrite>(writer: &mut W, mut len: usize) -> Result<(), Error> {
99 loop {
100 let mut byte = (len % 128) as u8;
101 len /= 128;
102 if len > 0 {
103 byte |= 128;
104 }
105 write_u8(writer, byte)?;
106 if len == 0 {
107 break;
108 }
109 }
110 Ok(())
111}
112
113#[inline]
115pub(crate) async fn decode_var_int<T: AsyncRead + Unpin>(
116 reader: &mut T,
117) -> Result<(u32, usize), Error> {
118 let mut var_int: u32 = 0;
119 let mut i = 0;
120 loop {
121 let mut buf = [0u8; 1];
122 reader
123 .read_exact(&mut buf)
124 .await
125 .map_err(from_read_exact_error)?;
126 let byte = buf[0];
127 var_int |= (u32::from(byte) & 0x7F) << (7 * i);
128 if byte & 0x80 == 0 {
129 break;
130 } else if i < 3 {
131 i += 1;
132 } else {
133 return Err(Error::InvalidVarByteInt);
134 }
135 }
136 Ok((var_int, i + 1))
137}
138
139#[inline]
141pub fn var_int_len(value: usize) -> Result<usize, Error> {
142 let len = if value < 128 {
143 1
144 } else if value < 16384 {
145 2
146 } else if value < 2097152 {
147 3
148 } else if value < 268435456 {
149 4
150 } else {
151 return Err(Error::InvalidVarByteInt);
152 };
153 Ok(len)
154}
155
156#[inline]
158pub fn total_len(remaining_len: usize) -> Result<usize, Error> {
159 let header_len = if remaining_len < 128 {
160 2
161 } else if remaining_len < 16384 {
162 3
163 } else if remaining_len < 2097152 {
164 4
165 } else if remaining_len < 268435456 {
166 5
167 } else {
168 return Err(Error::InvalidVarByteInt);
169 };
170 Ok(header_len + remaining_len)
171}
172
173#[inline]
176pub fn remaining_len(total_len: usize) -> usize {
177 total_len - header_len(total_len)
178}
179
/// Returns the fixed-header size (control byte + remaining-length var int) of
/// a packet whose full encoded size is `total_len`.
///
/// Each threshold is the largest remaining length for an N-byte var int, plus
/// the N + 1 header bytes it implies. Totals that no valid packet can have
/// (e.g. 130) fall into the next bucket, which is harmless.
#[inline]
pub fn header_len(total_len: usize) -> usize {
    match total_len {
        0..=129 => 2,
        130..=16_386 => 3,
        16_387..=2_097_155 => 4,
        _ => 5,
    }
}
194
195#[inline]
197pub(crate) fn encode_packet<E: Encodable>(control_byte: u8, body: &E) -> Result<Vec<u8>, Error> {
198 let remaining_len = body.encode_len();
199 let total = total_len(remaining_len)?;
200 let mut buf = Vec::with_capacity(total);
201
202 buf.push(control_byte);
204 write_var_int(&mut buf, remaining_len)?;
205
206 body.encode(&mut buf)?;
207 debug_assert_eq!(buf.len(), total);
208 Ok(buf)
209}
210
/// Generates `impl From<T> for Packet` for every listed type `T`.
///
/// Each impl wraps the value in the `Packet` variant that has the same name as
/// the type, e.g. `packet_from!(Publish)` yields `Packet::Publish(value)`.
macro_rules! packet_from {
    ($($variant:ident),+) => {
        $(
            impl From<$variant> for Packet {
                fn from(inner: $variant) -> Self {
                    Self::$variant(inner)
                }
            }
        )+
    };
}
222
223pub(crate) use packet_from;
224
#[cfg(test)]
mod tests {
    use crate::block_on;

    use super::*;

    /// Remaining-length values on both sides of every var-int size boundary.
    const BOUNDARY_VALUES: [usize; 8] = [
        0,
        127,
        128,
        16_383,
        16_384,
        2_097_151,
        2_097_152,
        268_435_455,
    ];

    #[test]
    fn test_decode_var_int() {
        for (mut data, value, size) in [
            (&[0xff, 0xff, 0xff, 0x7f][..], 268435455, 4),
            (&[0x80, 0x80, 0x80, 0x01][..], 2097152, 4),
            (&[0xff, 0xff, 0x7f][..], 2097151, 3),
            (&[0x80, 0x80, 0x01][..], 16384, 3),
            (&[0xff, 0x7f][..], 16383, 2),
            (&[0x80, 0x01][..], 128, 2),
            (&[0x7f][..], 127, 1),
            (&[0x00][..], 0, 1),
        ] {
            assert_eq!(block_on(decode_var_int(&mut data)).unwrap(), (value, size));
        }

        // Truncated input: the continuation bits promise bytes that never arrive.
        let mut err_data = &[0xff, 0xff, 0xff][..];
        assert!(block_on(decode_var_int(&mut err_data))
            .unwrap_err()
            .is_eof());

        // A fifth byte is never allowed: the fourth byte may not set 0x80.
        let mut too_long = &[0xff, 0xff, 0xff, 0xff, 0x01][..];
        assert!(matches!(
            block_on(decode_var_int(&mut too_long)).unwrap_err(),
            Error::InvalidVarByteInt
        ));
    }

    #[test]
    fn test_var_int_round_trip() {
        for value in BOUNDARY_VALUES {
            let mut buf = Vec::new();
            write_var_int(&mut buf, value).unwrap();
            assert_eq!(buf.len(), var_int_len(value).unwrap());

            let mut reader = &buf[..];
            assert_eq!(
                block_on(decode_var_int(&mut reader)).unwrap(),
                (value as u32, buf.len())
            );
        }
        assert!(var_int_len(268_435_456).is_err());
    }

    #[test]
    fn test_header_and_remaining_len_agree() {
        for remaining in BOUNDARY_VALUES {
            let total = total_len(remaining).unwrap();
            assert_eq!(header_len(total), total - remaining);
            assert_eq!(remaining_len(total), remaining);
        }
        assert!(total_len(268_435_456).is_err());
    }
}