emissary_core/primitives/
mapping.rs1use crate::{error::parser::MappingParseError, primitives::Str};
20
21use bytes::{BufMut, Bytes, BytesMut};
22use hashbrown::{
23 hash_map::{IntoIter, Iter},
24 HashMap,
25};
26use nom::{
27 number::complete::{be_u16, be_u8},
28 Err, IResult,
29};
30
31use alloc::vec::Vec;
32use core::{
33 fmt::{self, Debug},
34 num::NonZeroUsize,
35};
36
37#[derive(Debug, PartialEq, Eq, Clone, Default)]
39pub struct Mapping(HashMap<Str, Str>);
40
41impl fmt::Display for Mapping {
42 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
43 self.0.fmt(f)
44 }
45}
46
47impl Mapping {
48 pub fn serialize(&self) -> Bytes {
50 let out = BytesMut::with_capacity(2);
52 self.serialize_into(out).freeze()
53 }
54
55 pub fn serialize_into(&self, mut out: BytesMut) -> BytesMut {
56 let start_len = out.len();
57
58 if out.capacity() - out.len() < 2 {
60 out.reserve(2);
61 }
62 out.put_u16(0); let data_start = out.len();
65 let mut entries: Vec<_> = self.0.iter().collect();
66
67 entries.sort_unstable_by(|a, b| a.0.cmp(b.0));
70 for (key, value) in entries {
71 let key = key.serialize();
72 let value = value.serialize();
73 out.reserve(key.len() + value.len() + 2);
74 out.extend_from_slice(&key);
75 out.put_u8(b'=');
76 out.extend_from_slice(&value);
77 out.put_u8(b';');
78 }
79
80 let data_len = out.len() - data_start;
81 debug_assert!(data_len <= u16::MAX as usize);
82
83 let prefix_pos = start_len;
85 out[prefix_pos..prefix_pos + 2].copy_from_slice(&(data_len as u16).to_be_bytes());
86
87 out
88 }
89
90 pub fn parse_frame(input: &[u8]) -> IResult<&[u8], Self, MappingParseError> {
92 let (rest, size) = be_u16(input)?;
93 let mut mapping = Self::default();
94
95 match rest.split_at_checked(size as usize) {
96 Some((mut data, rest)) => {
97 while !data.is_empty() {
98 let (remaining, key) = Str::parse_frame(data).map_err(Err::convert)?;
99 let (remaining, _) = be_u8(remaining)?;
100 let (remaining, value) = Str::parse_frame(remaining).map_err(Err::convert)?;
101 let (remaining, _) = be_u8(remaining)?;
102 mapping.insert(key, value);
103 data = remaining;
104 }
105
106 Ok((rest, mapping))
107 }
108 None => {
109 let non_zero_size = NonZeroUsize::new(size as usize).expect("non-zero size");
111 Err(nom::Err::Incomplete(nom::Needed::Size(non_zero_size)))
112 }
113 }
114 }
115
116 pub fn parse(bytes: impl AsRef<[u8]>) -> Result<Mapping, MappingParseError> {
118 Ok(Self::parse_frame(bytes.as_ref())?.1)
119 }
120
121 pub fn insert(&mut self, key: Str, value: Str) -> Option<Str> {
123 self.0.insert(key, value)
124 }
125
126 pub fn get(&self, key: &Str) -> Option<&Str> {
128 self.0.get(key)
129 }
130
131 pub fn len(&self) -> usize {
133 self.0.len()
134 }
135
136 pub fn is_empty(&self) -> bool {
138 self.0.is_empty()
139 }
140
141 pub fn iter(&self) -> Iter<'_, Str, Str> {
143 self.0.iter()
144 }
145}
146
147impl IntoIterator for Mapping {
148 type Item = (Str, Str);
149 type IntoIter = IntoIter<Str, Str>;
150
151 fn into_iter(self) -> Self::IntoIter {
152 self.0.into_iter()
153 }
154}
155
156impl FromIterator<(Str, Str)> for Mapping {
157 fn from_iter<T: IntoIterator<Item = (Str, Str)>>(iter: T) -> Self {
158 Self(HashMap::from_iter(iter))
159 }
160}
161
162impl From<HashMap<Str, Str>> for Mapping {
163 fn from(value: HashMap<Str, Str>) -> Self {
164 Mapping(value)
165 }
166}
167
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn empty_mapping() {
        assert_eq!(Mapping::parse(b"\0\0"), Ok(Mapping::default()));
    }

    #[test]
    fn empty_mapping_serializes_to_zero_length() {
        assert_eq!(Mapping::default().serialize().to_vec(), b"\0\0");
    }

    #[test]
    fn valid_mapping() {
        let mut mapping = Mapping::default();
        mapping.insert("hello".into(), "world".into());

        let ser = mapping.serialize();

        assert_eq!(Mapping::parse(ser), Ok(mapping));
    }

    #[test]
    fn valid_mapping_with_extra_end_bytes() {
        let mut mapping = Mapping::default();
        mapping.insert("hello".into(), "world".into());

        let mut ser = mapping.serialize().to_vec();
        ser.extend_from_slice(&[1, 2, 3, 4]);

        assert_eq!(Mapping::parse(ser), Ok(mapping));
    }

    #[test]
    fn valid_mapping_with_extra_start_bytes() {
        let mut mapping = Mapping::default();
        mapping.insert("hello".into(), "world".into());

        const PREFIX: &[u8] = b"prefix";

        let buf = BytesMut::from(PREFIX);
        let ser = mapping.serialize_into(buf).to_vec();

        assert_eq!(&ser[..PREFIX.len()], b"prefix");
        assert_eq!(Mapping::parse(&ser[PREFIX.len()..]), Ok(mapping));
    }

    #[test]
    fn extra_bytes_returned() {
        let mut mapping = Mapping::default();
        mapping.insert("hello".into(), "world".into());

        let mut ser = mapping.serialize().to_vec();
        ser.extend_from_slice(&[1, 2, 3, 4]);

        let (rest, parsed_mapping) = Mapping::parse_frame(&ser).unwrap();

        assert_eq!(parsed_mapping, mapping);
        assert_eq!(rest, [1, 2, 3, 4]);
    }

    #[test]
    fn multiple_mappings() {
        let expected_ser = b"\x00\x19\x01a=\x01b;\x01c=\x01d;\x01e=\x01f;\x02zz=\x01z;";
        let mapping = Mapping::parse(expected_ser).expect("to be valid");

        assert_eq!(mapping.get(&"a".into()), Some(&Str::from("b")));
        assert_eq!(mapping.get(&"c".into()), Some(&Str::from("d")));
        assert_eq!(mapping.get(&"e".into()), Some(&Str::from("f")));
        assert_eq!(mapping.get(&"zz".into()), Some(&Str::from("z")));

        assert_eq!(mapping.serialize().to_vec(), expected_ser);
    }

    #[test]
    fn serialization_is_sorted_by_key() {
        // Insertion order deliberately differs from key order.
        let mapping: Mapping = [("zz", "z"), ("e", "f"), ("a", "b"), ("c", "d")]
            .into_iter()
            .map(|(key, value)| (Str::from(key), Str::from(value)))
            .collect();

        assert_eq!(
            mapping.serialize().to_vec(),
            b"\x00\x19\x01a=\x01b;\x01c=\x01d;\x01e=\x01f;\x02zz=\x01z;"
        );
    }

    #[test]
    fn over_sized() {
        let ser = b"\x01\x00\x01a=\x01b;\x01c=\x01d;\x01e=\x01f;";
        assert_eq!(
            Mapping::parse(ser).unwrap_err(),
            MappingParseError::InvalidBitstream
        );
    }
}