1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
use futures::{AsyncWrite, AsyncWriteExt};
use nom::{bytes::complete::take, number::complete::be_u16, IResult, Parser};
use nom_supreme::ParserExt;
use super::errors::MPacketWriteError;
/// A UTF-8 string used by the MQTT packet codec.
///
/// On the wire it is encoded as a big-endian `u16` byte count followed by the string bytes
/// (parsed by `mstring`, written by `MString::write_to`).
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
pub struct MString<'message> {
    /// The string contents.
    pub value: &'message str,
}

/// Lets an `MString` be used wherever a `&str` is expected.
impl std::ops::Deref for MString<'_> {
    type Target = str;

    fn deref(&self) -> &str {
        let Self { value } = *self;
        value
    }
}
impl<'message> MString<'message> {
    /// Returns the number of bytes `mstr` occupies when written by `write_to`:
    /// the two-byte length prefix plus the UTF-8 bytes of the value.
    pub fn get_len(mstr: &MString<'_>) -> usize {
        2 + mstr.value.len()
    }

    /// Writes `mstr` to `writer` as a big-endian `u16` byte count followed by the UTF-8 bytes.
    ///
    /// # Errors
    ///
    /// Returns an error converted from a [`std::io::Error`] of kind
    /// [`std::io::ErrorKind::InvalidInput`] if the value is longer than `u16::MAX` bytes
    /// (nothing is written in that case). Any I/O error raised by `writer` is also
    /// propagated.
    pub(crate) async fn write_to<W: AsyncWrite>(
        mstr: &MString<'_>,
        writer: &mut std::pin::Pin<&mut W>,
    ) -> Result<(), MPacketWriteError> {
        let byte_len = mstr.value.len();
        // The length prefix is only two bytes wide. A plain `as u16` cast would silently
        // truncate the length of an oversized value and emit a prefix that disagrees with
        // the bytes that follow it, corrupting the packet stream.
        if byte_len > usize::from(u16::MAX) {
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "MString value is longer than 65535 bytes and cannot be length-prefixed",
            )
            .into());
        }
        // Lossless: `byte_len <= u16::MAX` was checked above.
        writer
            .write_all(&(byte_len as u16).to_be_bytes())
            .await?;
        writer.write_all(mstr.value.as_bytes()).await?;
        Ok(())
    }
}
#[derive(Debug, thiserror::Error)]
pub enum MStringError {
#[error("The input contained control characters, which this implementation rejects.")]
ControlCharacters,
}
/// Returns `true` for characters that `mstring` refuses to accept.
///
/// This set is exactly the Unicode general category `Cc`:
/// - the C0 controls U+0000..=U+001F,
/// - DEL (U+007F),
/// - the C1 controls U+0080..=U+009F.
///
/// U+0000 is included deliberately. The MQTT specification forbids the null character in
/// UTF-8 encoded strings outright, whereas the remaining controls are characters it says
/// should not be sent.
fn control_characters(c: char) -> bool {
    c.is_control()
}
/// Parses a length-prefixed UTF-8 string from the front of `input`.
///
/// The encoding is a big-endian `u16` byte count followed by exactly that many bytes.
/// Those bytes must be valid UTF-8 and must not contain any character for which
/// `control_characters` returns `true`.
///
/// # Errors
///
/// Invalid UTF-8 and rejected characters are both reported as a nom error of kind
/// [`nom::error::ErrorKind::MapRes`] positioned at the start of `input`. Input that is
/// too short for the declared length is reported by the `be_u16`/`take` parsers.
pub fn mstring(input: &[u8]) -> IResult<&[u8], MString<'_>> {
    be_u16
        .flat_map(take)
        .map_res(|raw| std::str::from_utf8(raw))
        .map_res(|text| {
            if text.contains(control_characters) {
                return Err(MStringError::ControlCharacters);
            }
            Ok(MString { value: text })
        })
        .parse(input)
}
#[cfg(test)]
mod tests {
    use std::pin::Pin;

    use super::{mstring, MString};

    /// Length prefix 5, then `A` and the four-byte UTF-8 encoding of U+2A6D4.
    const SIMPLE: [u8; 7] = [0x00, 0x05, 0x41, 0xF0, 0xAA, 0x9B, 0x94];

    #[test]
    fn check_simple_string() {
        let parsed = mstring(&SIMPLE);
        let expected = MString {
            value: "A\u{2A6D4}",
        };
        assert_eq!(parsed, Ok((&[][..], expected)));
    }

    #[tokio::test]
    async fn check_simple_string_roundtrip() {
        let (_, parsed) = mstring(&SIMPLE).unwrap();
        let mut encoded: Vec<u8> = Vec::new();
        MString::write_to(&parsed, &mut Pin::new(&mut encoded))
            .await
            .unwrap();
        assert_eq!(&encoded[..], &SIMPLE[..]);
    }

    #[test]
    fn check_forbidden_characters() {
        // Length prefix 2, then the bytes 0x00 and 0x01 (U+0000 and U+0001).
        let bytes = [0x00, 0x02, 0x00, 0x01];
        let expected_error = nom::error::Error {
            input: &bytes[..],
            code: nom::error::ErrorKind::MapRes,
        };
        assert_eq!(mstring(&bytes), Err(nom::Err::Error(expected_error)));
    }
}