mqtt_format/v3/strings.rs

1//
2//   This Source Code Form is subject to the terms of the Mozilla Public
3//   License, v. 2.0. If a copy of the MPL was not distributed with this
4//   file, You can obtain one at http://mozilla.org/MPL/2.0/.
5//
6
7use futures::{AsyncWrite, AsyncWriteExt};
8use nom::{bytes::complete::take, number::complete::be_u16, IResult, Parser};
9use nom_supreme::ParserExt;
10
11use super::errors::MPacketWriteError;
12
/// A v3 MQTT string as defined in section 1.5.3.
///
/// On the wire this is a big-endian `u16` byte length followed by UTF-8 data;
/// this type holds only the decoded text, borrowed for the `'message` lifetime.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq, PartialOrd, Ord)]
pub struct MString<'message> {
    /// The string contents, without the length prefix.
    pub value: &'message str,
}
18
19impl<'message> std::ops::Deref for MString<'message> {
20    type Target = str;
21
22    fn deref(&self) -> &Self::Target {
23        self.value
24    }
25}
26
27impl<'message> MString<'message> {
28    pub fn get_len(mstr: &MString<'_>) -> usize {
29        2 + mstr.value.len()
30    }
31
32    pub(crate) async fn write_to<W: AsyncWrite>(
33        mstr: &MString<'_>,
34        writer: &mut std::pin::Pin<&mut W>,
35    ) -> Result<(), MPacketWriteError> {
36        writer
37            .write_all(&(mstr.value.len() as u16).to_be_bytes())
38            .await?;
39        writer.write_all(mstr.value.as_bytes()).await?;
40
41        Ok(())
42    }
43}
44
45#[derive(Debug, thiserror::Error)]
46pub enum MStringError {
47    #[error("The input contained control characters, which this implementation rejects.")]
48    ControlCharacters,
49}
50
/// Returns `true` for characters that MQTT strings must not or should not
/// contain, and which this implementation therefore rejects.
///
/// This is exactly the Unicode general category `Cc`: U+0000..=U+001F and
/// U+007F..=U+009F. U+0000 (NUL) is forbidden outright by [MQTT-1.5.3-2]. The
/// other control characters are only "SHOULD NOT" in section 1.5.3 and are
/// rejected as a policy choice.
fn control_characters(c: char) -> bool {
    // `char::is_control` covers the whole Cc category, so NUL is included.
    c.is_control()
}
54
55pub fn mstring(input: &[u8]) -> IResult<&[u8], MString<'_>> {
56    let len = be_u16;
57    let string_data = len.flat_map(take);
58
59    string_data
60        .map_res(|data| std::str::from_utf8(data).map(|s| MString { value: s }))
61        .map_res(|s| {
62            if s.contains(control_characters) {
63                Err(MStringError::ControlCharacters)
64            } else {
65                Ok(s)
66            }
67        })
68        .parse(input)
69}
70
#[cfg(test)]
mod tests {
    use std::pin::Pin;

    use super::{mstring, MString};

    // MQTT-1.5.3-1
    #[test]
    fn check_simple_string() {
        let input = [0x00, 0x05, 0x41, 0xF0, 0xAA, 0x9B, 0x94];

        let s = mstring(&input);

        assert_eq!(
            s,
            Ok((
                &[][..],
                MString {
                    value: "A\u{2A6D4}"
                }
            ))
        );
    }

    #[tokio::test]
    async fn check_simple_string_roundtrip() {
        let input = [0x00, 0x05, 0x41, 0xF0, 0xAA, 0x9B, 0x94];

        let (_, s) = mstring(&input).unwrap();

        let mut vec: Vec<u8> = vec![];

        MString::write_to(&s, &mut Pin::new(&mut vec))
            .await
            .unwrap();

        assert_eq!(input, &vec[..]);
    }

    // MQTT-1.5.3-2
    #[test]
    fn check_forbidden_characters() {
        let input = [0x00, 0x02, 0x00, 0x01];

        let s = mstring(&input);

        assert_eq!(
            s,
            Err(nom::Err::Error(nom::error::Error {
                input: &input[..],
                code: nom::error::ErrorKind::MapRes
            }))
        );
    }

    // MQTT-1.5.3-3: the byte sequence 0xEF 0xBB 0xBF must always be read as
    // U+FEFF (ZERO WIDTH NO-BREAK SPACE) and must not be skipped or stripped.
    #[test]
    fn check_bom_is_preserved() {
        let input = [0x00, 0x03, 0xEF, 0xBB, 0xBF];

        assert_eq!(
            mstring(&input),
            Ok((&[][..], MString { value: "\u{FEFF}" }))
        );
    }

    // MQTT-1.5.3-3 applies wherever the sequence appears, not only at the start.
    #[test]
    fn check_bom_is_preserved_mid_string() {
        let input = [0x00, 0x05, 0x41, 0xEF, 0xBB, 0xBF, 0x42];

        assert_eq!(
            mstring(&input),
            Ok((&[][..], MString { value: "A\u{FEFF}B" }))
        );
    }
}