mqtt_format/v3/
strings.rs1use futures::{AsyncWrite, AsyncWriteExt};
8use nom::{bytes::complete::take, number::complete::be_u16, IResult, Parser};
9use nom_supreme::ParserExt;
10
11use super::errors::MPacketWriteError;
12
/// A UTF-8 string carried inside an MQTT message.
///
/// On the wire it is a big-endian `u16` byte length followed by the string's
/// bytes. When produced by the parser it borrows directly from the input
/// buffer, so no copy is made. It dereferences to `str`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct MString<'message> {
    /// The string contents (typically borrowed from the parsed message bytes).
    pub value: &'message str,
}
18
19impl<'message> std::ops::Deref for MString<'message> {
20 type Target = str;
21
22 fn deref(&self) -> &Self::Target {
23 self.value
24 }
25}
26
27impl<'message> MString<'message> {
28 pub fn get_len(mstr: &MString<'_>) -> usize {
29 2 + mstr.value.len()
30 }
31
32 pub(crate) async fn write_to<W: AsyncWrite>(
33 mstr: &MString<'_>,
34 writer: &mut std::pin::Pin<&mut W>,
35 ) -> Result<(), MPacketWriteError> {
36 writer
37 .write_all(&(mstr.value.len() as u16).to_be_bytes())
38 .await?;
39 writer.write_all(mstr.value.as_bytes()).await?;
40
41 Ok(())
42 }
43}
44
45#[derive(Debug, thiserror::Error)]
46pub enum MStringError {
47 #[error("The input contained control characters, which this implementation rejects.")]
48 ControlCharacters,
49}
50
/// Returns `true` for characters this implementation refuses to accept inside
/// an MQTT string.
///
/// The rejected set is the Unicode `Cc` (control) category, which is
/// U+0000..=U+001F plus U+007F..=U+009F. U+0000 must be included: MQTT 3.1.1
/// (section 1.5.3) says a string MUST NOT contain the null character. The
/// remaining control characters are SHOULD NOT, and this implementation
/// chooses to reject them too.
fn control_characters(c: char) -> bool {
    c.is_control()
}
54
55pub fn mstring(input: &[u8]) -> IResult<&[u8], MString<'_>> {
56 let len = be_u16;
57 let string_data = len.flat_map(take);
58
59 string_data
60 .map_res(|data| std::str::from_utf8(data).map(|s| MString { value: s }))
61 .map_res(|s| {
62 if s.contains(control_characters) {
63 Err(MStringError::ControlCharacters)
64 } else {
65 Ok(s)
66 }
67 })
68 .parse(input)
69}
70
#[cfg(test)]
mod tests {
    use std::pin::Pin;

    use super::{mstring, MString};

    /// Encoded form of "A\u{2A6D4}": a length prefix of 5, then `A` (1 byte)
    /// and the 4-byte UTF-8 encoding of U+2A6D4.
    fn encoded_sample() -> [u8; 7] {
        [0x00, 0x05, 0x41, 0xF0, 0xAA, 0x9B, 0x94]
    }

    #[test]
    fn check_simple_string() {
        let bytes = encoded_sample();
        let expected = MString {
            value: "A\u{2A6D4}",
        };

        assert_eq!(mstring(&bytes), Ok((&[][..], expected)));
    }

    #[tokio::test]
    async fn check_simple_string_roundtrip() {
        let bytes = encoded_sample();
        let (_, parsed) = mstring(&bytes).unwrap();

        let mut written = vec![];
        MString::write_to(&parsed, &mut Pin::new(&mut written))
            .await
            .unwrap();

        assert_eq!(bytes, &written[..]);
    }

    #[test]
    fn check_forbidden_characters() {
        // Length prefix 2, followed by the bytes 0x00 and 0x01.
        let bytes = [0x00, 0x02, 0x00, 0x01];

        let expected = Err(nom::Err::Error(nom::error::Error {
            input: &bytes[..],
            code: nom::error::ErrorKind::MapRes,
        }));

        assert_eq!(mstring(&bytes), expected);
    }
}