1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
//
//   This Source Code Form is subject to the terms of the Mozilla Public
//   License, v. 2.0. If a copy of the MPL was not distributed with this
//   file, You can obtain one at http://mozilla.org/MPL/2.0/.
//

use futures::{AsyncWrite, AsyncWriteExt};
use nom::{bytes::complete::take, number::complete::be_u16, IResult, Parser};
use nom_supreme::ParserExt;

use super::errors::MPacketWriteError;

/// A v3 MQTT string as defined in section 1.5.3
///
/// The text is borrowed rather than owned: `value` points into whatever buffer the string was
/// taken from, which is what the `'message` lifetime tracks.
///
/// On the wire the string is its UTF-8 bytes preceded by a two-byte big-endian byte count.
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
pub struct MString<'message> {
    /// The decoded text, without the length prefix.
    pub value: &'message str,
}

/// Lets an [`MString`] be used wherever a `&str` is expected by exposing the wrapped text.
impl std::ops::Deref for MString<'_> {
    type Target = str;

    /// Returns the wrapped string slice unchanged.
    fn deref(&self) -> &str {
        self.value
    }
}

impl<'message> MString<'message> {
    /// Returns the number of bytes `mstr` occupies when written with [`MString::write_to`]:
    /// two bytes for the length prefix plus the UTF-8 byte length of the text.
    pub fn get_len(mstr: &MString<'_>) -> usize {
        2 + mstr.value.len()
    }

    /// Writes `mstr` to `writer` as a two-byte big-endian byte count followed by the UTF-8 bytes
    /// of the text.
    ///
    /// # Errors
    ///
    /// Returns an error, before anything is written, if the text is longer than `u16::MAX`
    /// bytes and therefore cannot be described by the two-byte length prefix. The error is an
    /// `std::io::Error` of kind `InvalidInput`, converted into [`MPacketWriteError`].
    ///
    /// Also returns any error reported by `writer`.
    pub(crate) async fn write_to<W: AsyncWrite>(
        mstr: &MString<'_>,
        writer: &mut std::pin::Pin<&mut W>,
    ) -> Result<(), MPacketWriteError> {
        let byte_len = mstr.value.len();

        // A bare `as u16` cast wraps for longer strings, producing a prefix that disagrees with
        // the payload that follows and corrupting the rest of the stream. Refuse instead.
        if byte_len > usize::from(u16::MAX) {
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "MQTT strings must not be longer than 65535 bytes",
            )
            .into());
        }

        // Lossless: `byte_len` was checked against `u16::MAX` above.
        let prefix = (byte_len as u16).to_be_bytes();

        writer.write_all(&prefix).await?;
        writer.write_all(mstr.value.as_bytes()).await?;

        Ok(())
    }
}

/// Reasons why otherwise valid UTF-8 text is refused as an MQTT string.
#[derive(Debug, thiserror::Error)]
pub enum MStringError {
    /// The text contained at least one character matched by `control_characters`.
    #[error("The input contained control characters, which this implementation rejects.")]
    ControlCharacters,
}

/// Returns `true` if `c` is a character this implementation refuses inside an MQTT string.
///
/// Section 1.5.3 of the MQTT v3.1.1 specification forbids the null character U+0000 outright
/// (MQTT-1.5.3-2) and additionally discourages the control characters U+0001..=U+001F and
/// U+007F..=U+009F. All of them are treated as forbidden here, so the first range starts at
/// U+0000 rather than U+0001.
fn control_characters(c: char) -> bool {
    matches!(c, '\u{0000}'..='\u{001F}' | '\u{007F}'..='\u{009F}')
}

/// Parses an MQTT string from the start of `input`.
///
/// The encoding is a big-endian `u16` byte count followed by that many bytes of UTF-8 data. On
/// success the unconsumed remainder of `input` is returned together with an [`MString`] whose
/// text borrows from `input`.
///
/// # Errors
///
/// Fails if `input` is too short for the length prefix or for the announced payload, if the
/// payload is not valid UTF-8, or if the text contains a character matched by
/// `control_characters`.
pub fn mstring(input: &[u8]) -> IResult<&[u8], MString<'_>> {
    be_u16
        // The decoded length decides how many bytes `take` consumes next.
        .flat_map(take)
        .map_res(|raw| std::str::from_utf8(raw))
        .map_res(|text| {
            if text.chars().any(control_characters) {
                return Err(MStringError::ControlCharacters);
            }

            Ok(MString { value: text })
        })
        .parse(input)
}

#[cfg(test)]
mod tests {
    use std::pin::Pin;

    use super::{mstring, MString};

    // TODO(neikos): Unclear how MQTT-1.5.3-3 is to be tested. Since we don't touch the stream, I
    // think we are fulfilling that requirement

    // MQTT-1.5.3-1
    /// A five-byte payload ('A' followed by the four-byte encoding of U+2A6D4) parses completely
    /// and leaves no remaining input.
    #[test]
    fn check_simple_string() {
        let wire: [u8; 7] = [0x00, 0x05, 0x41, 0xF0, 0xAA, 0x9B, 0x94];
        let expected = MString {
            value: "A\u{2A6D4}",
        };

        let parsed = mstring(&wire);

        assert_eq!(parsed, Ok((&[][..], expected)))
    }

    /// Writing a parsed string back out reproduces the bytes it was parsed from.
    #[tokio::test]
    async fn check_simple_string_roundtrip() {
        let wire: [u8; 7] = [0x00, 0x05, 0x41, 0xF0, 0xAA, 0x9B, 0x94];

        let (_, parsed) = mstring(&wire).unwrap();

        let mut written: Vec<u8> = Vec::new();
        let mut sink = Pin::new(&mut written);
        MString::write_to(&parsed, &mut sink).await.unwrap();

        assert_eq!(wire, &written[..])
    }

    // MQTT-1.5.3-2
    /// A payload made of U+0000 and U+0001 is refused, and the error points at the start of the
    /// input.
    #[test]
    fn check_forbidden_characters() {
        let wire: [u8; 4] = [0x00, 0x02, 0x00, 0x01];

        let parsed = mstring(&wire);

        let expected = nom::Err::Error(nom::error::Error {
            input: &wire[..],
            code: nom::error::ErrorKind::MapRes,
        });
        assert_eq!(parsed, Err(expected))
    }
}