mqtt_proto/common/
utils.rs

1use core::slice;
2
3use alloc::string::String;
4use alloc::vec::Vec;
5
6use simdutf8::basic::from_utf8;
7
8use crate::{from_read_exact_error, AsyncRead, Encodable, Error, SyncWrite};
9
10/// Read first byte(packet type and flags) and decode remaining length
11#[inline]
12pub async fn decode_raw_header<T: AsyncRead + Unpin>(reader: &mut T) -> Result<(u8, u32), Error> {
13    let typ = read_u8(reader).await?;
14    let (remaining_len, _bytes) = decode_var_int(reader).await?;
15    Ok((typ, remaining_len))
16}
17
18#[inline]
19pub(crate) async fn read_string<T: AsyncRead + Unpin>(reader: &mut T) -> Result<String, Error> {
20    let data_buf = read_bytes(reader).await?;
21    let _str = from_utf8(&data_buf).map_err(|_| Error::InvalidString)?;
22    Ok(unsafe { String::from_utf8_unchecked(data_buf) })
23}
24
25#[inline]
26pub(crate) async fn read_bytes<T: AsyncRead + Unpin>(reader: &mut T) -> Result<Vec<u8>, Error> {
27    let data_len = read_u16(reader).await?;
28    let mut data_buf = alloc::vec![0u8; data_len as usize];
29    reader
30        .read_exact(&mut data_buf)
31        .await
32        .map_err(from_read_exact_error)?;
33    Ok(data_buf)
34}
35
36// Only for v5.0
37#[inline]
38pub(crate) async fn read_u32<T: AsyncRead + Unpin>(reader: &mut T) -> Result<u32, Error> {
39    let mut len4_bytes = [0u8; 4];
40    reader
41        .read_exact(&mut len4_bytes)
42        .await
43        .map_err(from_read_exact_error)?;
44    Ok(u32::from_be_bytes(len4_bytes))
45}
46
47#[inline]
48pub(crate) async fn read_u16<T: AsyncRead + Unpin>(reader: &mut T) -> Result<u16, Error> {
49    let mut len2_bytes = [0u8; 2];
50    reader
51        .read_exact(&mut len2_bytes)
52        .await
53        .map_err(from_read_exact_error)?;
54    Ok(u16::from_be_bytes(len2_bytes))
55}
56
57#[inline]
58pub(crate) async fn read_u8<T: AsyncRead + Unpin>(reader: &mut T) -> Result<u8, Error> {
59    let mut byte = [0u8; 1];
60    reader
61        .read_exact(&mut byte)
62        .await
63        .map_err(from_read_exact_error)?;
64    Ok(byte[0])
65}
66
67pub(crate) fn write_string<W: SyncWrite>(writer: &mut W, value: &str) -> Result<(), Error> {
68    write_bytes(writer, value.as_bytes())?;
69    Ok(())
70}
71
72#[inline]
73pub(crate) fn write_bytes<W: SyncWrite>(writer: &mut W, data: &[u8]) -> Result<(), Error> {
74    write_u16(writer, data.len() as u16)?;
75    writer.write_all(data)?;
76    Ok(())
77}
78
79#[inline]
80pub(crate) fn write_u32<W: SyncWrite>(writer: &mut W, value: u32) -> Result<(), Error> {
81    writer.write_all(&value.to_be_bytes())?;
82    Ok(())
83}
84
85#[inline]
86pub(crate) fn write_u16<W: SyncWrite>(writer: &mut W, value: u16) -> Result<(), Error> {
87    writer.write_all(&value.to_be_bytes())?;
88    Ok(())
89}
90
91#[inline]
92pub(crate) fn write_u8<W: SyncWrite>(writer: &mut W, value: u8) -> Result<(), Error> {
93    writer.write_all(slice::from_ref(&value))?;
94    Ok(())
95}
96
97#[inline]
98pub(crate) fn write_var_int<W: SyncWrite>(writer: &mut W, mut len: usize) -> Result<(), Error> {
99    loop {
100        let mut byte = (len % 128) as u8;
101        len /= 128;
102        if len > 0 {
103            byte |= 128;
104        }
105        write_u8(writer, byte)?;
106        if len == 0 {
107            break;
108        }
109    }
110    Ok(())
111}
112
113/// Decode a variable byte integer (4 bytes max)
114#[inline]
115pub(crate) async fn decode_var_int<T: AsyncRead + Unpin>(
116    reader: &mut T,
117) -> Result<(u32, usize), Error> {
118    let mut var_int: u32 = 0;
119    let mut i = 0;
120    loop {
121        let mut buf = [0u8; 1];
122        reader
123            .read_exact(&mut buf)
124            .await
125            .map_err(from_read_exact_error)?;
126        let byte = buf[0];
127        var_int |= (u32::from(byte) & 0x7F) << (7 * i);
128        if byte & 0x80 == 0 {
129            break;
130        } else if i < 3 {
131            i += 1;
132        } else {
133            return Err(Error::InvalidVarByteInt);
134        }
135    }
136    Ok((var_int, i + 1))
137}
138
139/// Return the encoded size of the variable byte integer.
140#[inline]
141pub fn var_int_len(value: usize) -> Result<usize, Error> {
142    let len = if value < 128 {
143        1
144    } else if value < 16384 {
145        2
146    } else if value < 2097152 {
147        3
148    } else if value < 268435456 {
149        4
150    } else {
151        return Err(Error::InvalidVarByteInt);
152    };
153    Ok(len)
154}
155
156/// Return the packet total encoded length by a given remaining length.
157#[inline]
158pub fn total_len(remaining_len: usize) -> Result<usize, Error> {
159    let header_len = if remaining_len < 128 {
160        2
161    } else if remaining_len < 16384 {
162        3
163    } else if remaining_len < 2097152 {
164        4
165    } else if remaining_len < 268435456 {
166        5
167    } else {
168        return Err(Error::InvalidVarByteInt);
169    };
170    Ok(header_len + remaining_len)
171}
172
173/// Calculate remaining length by given total length (the total length MUST be
174/// valid value).
175#[inline]
176pub fn remaining_len(total_len: usize) -> usize {
177    total_len - header_len(total_len)
178}
179
/// Return the fixed-header size (control byte + remaining-length bytes) for a
/// packet of `total_len` encoded bytes.
///
/// `total_len` MUST be a value that `total_len()` could produce; the result
/// is unspecified for other values.
#[inline]
pub fn header_len(total_len: usize) -> usize {
    // Each upper bound is (largest remaining length for that header size) +
    // (that header size): 127 + 2, 16_383 + 3, 2_097_151 + 4.
    match total_len {
        0..=129 => 2,
        130..=16_386 => 3,
        16_387..=2_097_155 => 4,
        _ => 5,
    }
}
194
195/// Encode packet use control byte and body type
196#[inline]
197pub(crate) fn encode_packet<E: Encodable>(control_byte: u8, body: &E) -> Result<Vec<u8>, Error> {
198    let remaining_len = body.encode_len();
199    let total = total_len(remaining_len)?;
200    let mut buf = Vec::with_capacity(total);
201
202    // encode header
203    buf.push(control_byte);
204    write_var_int(&mut buf, remaining_len)?;
205
206    body.encode(&mut buf)?;
207    debug_assert_eq!(buf.len(), total);
208    Ok(buf)
209}
210
/// Implement `From<T> for Packet` for each listed type `T`, wrapping the value
/// in the `Packet` variant that has the same name as the type.
///
/// `Packet` is resolved where the macro is invoked, so it must be in scope
/// there.
macro_rules! packet_from {
    ($($variant:ident),+) => {
        $(
            impl From<$variant> for Packet {
                fn from(value: $variant) -> Self {
                    Self::$variant(value)
                }
            }
        )+
    };
}

// Make the macro importable by path from other modules in this crate.
pub(crate) use packet_from;
224
#[cfg(test)]
mod tests {
    use crate::block_on;

    use super::*;

    #[test]
    fn test_decode_var_int() {
        for (mut data, value, size) in [
            (&[0xff, 0xff, 0xff, 0x7f][..], 268435455, 4),
            (&[0x80, 0x80, 0x80, 0x01][..], 2097152, 4),
            (&[0xff, 0xff, 0x7f][..], 2097151, 3),
            (&[0x80, 0x80, 0x01][..], 16384, 3),
            (&[0xff, 0x7f][..], 16383, 2),
            (&[0x80, 0x01][..], 128, 2),
            (&[0x7f][..], 127, 1),
            (&[0x00][..], 0, 1),
        ] {
            assert_eq!(block_on(decode_var_int(&mut data)).unwrap(), (value, size));
        }

        let mut err_data = &[0xff, 0xff, 0xff][..];
        assert!(block_on(decode_var_int(&mut err_data))
            .unwrap_err()
            .is_eof());
    }

    /// A continuation bit on the 4th byte must be rejected rather than
    /// reading a 5th byte.
    #[test]
    fn test_decode_var_int_too_long() {
        let mut data = &[0xff, 0xff, 0xff, 0xff, 0x01][..];
        assert!(matches!(
            block_on(decode_var_int(&mut data)),
            Err(Error::InvalidVarByteInt)
        ));
    }

    /// `write_var_int` output decodes back to the same value and has the
    /// length predicted by `var_int_len`.
    #[test]
    fn test_var_int_round_trip() {
        for value in [0usize, 127, 128, 16_383, 16_384, 2_097_151, 2_097_152, 268_435_455] {
            let mut buf: Vec<u8> = Vec::new();
            write_var_int(&mut buf, value).unwrap();
            assert_eq!(buf.len(), var_int_len(value).unwrap());
            let mut encoded = &buf[..];
            assert_eq!(
                block_on(decode_var_int(&mut encoded)).unwrap(),
                (value as u32, buf.len())
            );
        }
    }

    /// `total_len`, `header_len`, `remaining_len` and `var_int_len` must agree
    /// at every header-size boundary.
    #[test]
    fn test_length_helpers_agree() {
        assert!(var_int_len(268_435_456).is_err());
        assert!(total_len(268_435_456).is_err());
        for remaining in [0usize, 127, 128, 16_383, 16_384, 2_097_151, 2_097_152, 268_435_455] {
            let total = total_len(remaining).unwrap();
            assert_eq!(total, 1 + var_int_len(remaining).unwrap() + remaining);
            assert_eq!(header_len(total), total - remaining);
            assert_eq!(remaining_len(total), remaining);
        }
    }
}