Skip to main content

emissary_core/primitives/
mapping.rs

// Permission is hereby granted, free of charge, to any person obtaining a
// copy of this software and associated documentation files (the "Software"),
// to deal in the Software without restriction, including without limitation
// the rights to use, copy, modify, merge, publish, distribute, sublicense,
// and/or sell copies of the Software, and to permit persons to whom the
// Software is furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.
18
19use crate::{error::parser::MappingParseError, primitives::Str};
20
21use bytes::{BufMut, Bytes, BytesMut};
22use hashbrown::{
23    hash_map::{IntoIter, Iter},
24    HashMap,
25};
26use nom::{
27    number::complete::{be_u16, be_u8},
28    Err, IResult,
29};
30
31use alloc::vec::Vec;
32use core::{
33    fmt::{self, Debug},
34    num::NonZeroUsize,
35};
36
37/// Key-value mapping
38#[derive(Debug, PartialEq, Eq, Clone, Default)]
39pub struct Mapping(HashMap<Str, Str>);
40
41impl fmt::Display for Mapping {
42    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
43        self.0.fmt(f)
44    }
45}
46
47impl Mapping {
48    /// Serialize [`Mapping`] into a byte vector.
49    pub fn serialize(&self) -> Bytes {
50        // Allocate at least two bytes for the size prefix
51        let out = BytesMut::with_capacity(2);
52        self.serialize_into(out).freeze()
53    }
54
55    pub fn serialize_into(&self, mut out: BytesMut) -> BytesMut {
56        let start_len = out.len();
57
58        // Reserve 2 bytes for size prefix
59        if out.capacity() - out.len() < 2 {
60            out.reserve(2);
61        }
62        out.put_u16(0); // temporary placeholder
63
64        let data_start = out.len();
65        let mut entries: Vec<_> = self.0.iter().collect();
66
67        // Our mapping implementation does not support duplicate keys, so we do not need to preserve
68        // order
69        entries.sort_unstable_by(|a, b| a.0.cmp(b.0));
70        for (key, value) in entries {
71            let key = key.serialize();
72            let value = value.serialize();
73            out.reserve(key.len() + value.len() + 2);
74            out.extend_from_slice(&key);
75            out.put_u8(b'=');
76            out.extend_from_slice(&value);
77            out.put_u8(b';');
78        }
79
80        let data_len = out.len() - data_start;
81        debug_assert!(data_len <= u16::MAX as usize);
82
83        // Write actual length into the 2-byte prefix we reserved
84        let prefix_pos = start_len;
85        out[prefix_pos..prefix_pos + 2].copy_from_slice(&(data_len as u16).to_be_bytes());
86
87        out
88    }
89
90    /// Parse [`Mapping`] from `input`, returning rest of `input` and parsed mapping.
91    pub fn parse_frame(input: &[u8]) -> IResult<&[u8], Self, MappingParseError> {
92        let (rest, size) = be_u16(input)?;
93        let mut mapping = Self::default();
94
95        match rest.split_at_checked(size as usize) {
96            Some((mut data, rest)) => {
97                while !data.is_empty() {
98                    let (remaining, key) = Str::parse_frame(data).map_err(Err::convert)?;
99                    let (remaining, _) = be_u8(remaining)?;
100                    let (remaining, value) = Str::parse_frame(remaining).map_err(Err::convert)?;
101                    let (remaining, _) = be_u8(remaining)?;
102                    mapping.insert(key, value);
103                    data = remaining;
104                }
105
106                Ok((rest, mapping))
107            }
108            None => {
109                // This is safe as the zero case will always pass `split_at_checked`
110                let non_zero_size = NonZeroUsize::new(size as usize).expect("non-zero size");
111                Err(nom::Err::Incomplete(nom::Needed::Size(non_zero_size)))
112            }
113        }
114    }
115
116    /// Try to convert `bytes` into a [`Mapping`].
117    pub fn parse(bytes: impl AsRef<[u8]>) -> Result<Mapping, MappingParseError> {
118        Ok(Self::parse_frame(bytes.as_ref())?.1)
119    }
120
121    /// Equivalent to `HashMap::insert`
122    pub fn insert(&mut self, key: Str, value: Str) -> Option<Str> {
123        self.0.insert(key, value)
124    }
125
126    /// Equivalent to `HashMap::get`
127    pub fn get(&self, key: &Str) -> Option<&Str> {
128        self.0.get(key)
129    }
130
131    /// Equivalent to `HashMap::len`
132    pub fn len(&self) -> usize {
133        self.0.len()
134    }
135
136    /// Equivalent to `HashMap::is_empty`
137    pub fn is_empty(&self) -> bool {
138        self.0.is_empty()
139    }
140
141    /// Equivalent to `HashMap::iter`
142    pub fn iter(&self) -> Iter<'_, Str, Str> {
143        self.0.iter()
144    }
145}
146
147impl IntoIterator for Mapping {
148    type Item = (Str, Str);
149    type IntoIter = IntoIter<Str, Str>;
150
151    fn into_iter(self) -> Self::IntoIter {
152        self.0.into_iter()
153    }
154}
155
156impl FromIterator<(Str, Str)> for Mapping {
157    fn from_iter<T: IntoIterator<Item = (Str, Str)>>(iter: T) -> Self {
158        Self(HashMap::from_iter(iter))
159    }
160}
161
162impl From<HashMap<Str, Str>> for Mapping {
163    fn from(value: HashMap<Str, Str>) -> Self {
164        Mapping(value)
165    }
166}
167
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the single-entry `hello -> world` mapping shared by several tests.
    fn hello_world() -> Mapping {
        let mut mapping = Mapping::default();
        mapping.insert("hello".into(), "world".into());
        mapping
    }

    /// Serializes `mapping` and appends the trailer bytes `1, 2, 3, 4` after the frame.
    fn serialized_with_trailer(mapping: &Mapping) -> Vec<u8> {
        let mut encoded = mapping.serialize().to_vec();
        encoded.extend_from_slice(&[1, 2, 3, 4]);
        encoded
    }

    #[test]
    fn empty_mapping() {
        let parsed = Mapping::parse(b"\0\0");
        assert_eq!(parsed, Ok(Mapping::default()));
    }

    #[test]
    fn valid_mapping() {
        let expected = hello_world();
        let parsed = Mapping::parse(expected.serialize());
        assert_eq!(parsed, Ok(expected));
    }

    #[test]
    fn valid_string_with_extra_end_bytes() {
        let expected = hello_world();
        let encoded = serialized_with_trailer(&expected);
        assert_eq!(Mapping::parse(encoded), Ok(expected));
    }

    #[test]
    fn valid_string_with_extra_start_bytes() {
        const PREFIX: &[u8] = b"prefix";

        let expected = hello_world();
        let encoded = expected.serialize_into(BytesMut::from(PREFIX)).to_vec();
        let (head, frame) = encoded.split_at(PREFIX.len());

        assert_eq!(head, b"prefix");
        assert_eq!(Mapping::parse(frame), Ok(expected));
    }

    #[test]
    fn extra_bytes_returned() {
        let expected = hello_world();
        let encoded = serialized_with_trailer(&expected);

        let (rest, parsed) = Mapping::parse_frame(&encoded).unwrap();

        assert_eq!(parsed, expected);
        assert_eq!(rest, [1, 2, 3, 4]);
    }

    #[test]
    fn multiple_mappings() {
        let expected_ser = b"\x00\x19\x01a=\x01b;\x01c=\x01d;\x01e=\x01f;\x02zz=\x01z;";
        let mapping = Mapping::parse(expected_ser).expect("to be valid");

        for (key, value) in [("a", "b"), ("c", "d"), ("e", "f"), ("zz", "z")] {
            assert_eq!(mapping.get(&key.into()), Some(&Str::from(value)));
        }

        // Serialization sorts by key, so it must reproduce the input byte-for-byte.
        assert_eq!(mapping.serialize().to_vec(), expected_ser);
    }

    #[test]
    fn over_sized() {
        // Declares a 256-byte body but supplies only 18 bytes of entries.
        let ser = b"\x01\x00\x01a=\x01b;\x01c=\x01d;\x01e=\x01f;";
        let err = Mapping::parse(ser).unwrap_err();
        assert_eq!(err, MappingParseError::InvalidBitstream);
    }
}