1use crate::{
7 codec::{EncodeSize, Read, Write},
8 error::Error,
9 RangeCfg,
10};
11use bytes::{Buf, BufMut};
12use std::{collections::HashMap, hash::Hash};
13
14impl<K: Ord + Hash + Eq + Write, V: Write> Write for HashMap<K, V> {
16 fn write(&self, buf: &mut impl BufMut) {
17 self.len().write(buf);
18
19 let mut keys: Vec<_> = self.keys().collect();
21 keys.sort();
22 for key in keys {
23 key.write(buf);
24 self.get(key).unwrap().write(buf);
25 }
26 }
27}
28
29impl<K: Ord + Hash + Eq + EncodeSize, V: EncodeSize> EncodeSize for HashMap<K, V> {
31 fn encode_size(&self) -> usize {
32 let mut size = self.len().encode_size();
34
35 for (key, value) in self {
38 size += key.encode_size();
39 size += value.encode_size();
40 }
41 size
42 }
43}
44
45impl<K: Read + Clone + Ord + Hash + Eq, V: Read + Clone> Read for HashMap<K, V> {
47 type Cfg = (RangeCfg, (K::Cfg, V::Cfg));
48
49 fn read_cfg(buf: &mut impl Buf, (range, (k_cfg, v_cfg)): &Self::Cfg) -> Result<Self, Error> {
50 let len = usize::read_cfg(buf, range)?;
52 let mut map = HashMap::with_capacity(len);
53
54 let mut last_key: Option<K> = None;
56
57 for _ in 0..len {
59 let key = K::read_cfg(buf, k_cfg)?;
60
61 if let Some(ref last) = last_key {
63 use std::cmp::Ordering;
64 match key.cmp(last) {
65 Ordering::Equal => return Err(Error::Invalid("HashMap", "Duplicate key")),
66 Ordering::Less => return Err(Error::Invalid("HashMap", "Keys must ascend")),
67 _ => {}
68 }
69 }
70 last_key = Some(key.clone());
71
72 let value = V::read_cfg(buf, v_cfg)?;
73 map.insert(key, value);
74 }
75
76 Ok(map)
77 }
78}
79
#[cfg(test)]
mod tests {
    use crate::{
        codec::{Decode, Encode, EncodeSize, FixedSize, Read, Write},
        error::Error,
        RangeCfg,
    };
    use bytes::{BufMut, Bytes, BytesMut};
    use std::{collections::HashMap, fmt::Debug, hash::Hash};

    /// Encodes `map`, decodes the bytes with the supplied configurations, and asserts that the
    /// decoded map equals the original.
    fn round_trip<K, V>(map: &HashMap<K, V>, range_cfg: RangeCfg, k_cfg: K::Cfg, v_cfg: V::Cfg)
    where
        K: Write + EncodeSize + Read + Clone + Ord + Hash + Eq + Debug + PartialEq,
        V: Write + EncodeSize + Read + Clone + Debug + PartialEq,
        HashMap<K, V>:
            Read<Cfg = (RangeCfg, (K::Cfg, V::Cfg))> + Debug + PartialEq + Write + EncodeSize,
    {
        let cfg = (range_cfg, (k_cfg, v_cfg));
        let decoded = HashMap::<K, V>::decode_cfg(map.encode(), &cfg).expect("decode_cfg failed");
        assert_eq!(map, &decoded);
    }

    /// Builds the `{1: 100, 5: 500}` map shared by several decode-failure tests.
    fn two_entry_map() -> HashMap<u32, u64> {
        let mut map = HashMap::new();
        map.insert(1u32, 100u64);
        map.insert(5u32, 500u64);
        map
    }

    /// Hand-encodes `entries` as a map in exactly the given order, bypassing the key sorting
    /// done by the real encoder so that malformed inputs can be produced.
    fn encode_raw(entries: &[(u32, u64)]) -> BytesMut {
        let mut raw = BytesMut::new();
        entries.len().write(&mut raw);
        for (key, value) in entries {
            key.write(&mut raw);
            value.write(&mut raw);
        }
        raw
    }

    #[test]
    fn test_empty_map() {
        let empty = HashMap::<u32, u64>::new();
        round_trip(&empty, (..).into(), (), ());

        // An empty map encodes to nothing but its length prefix.
        assert_eq!(empty.encode_size(), 1);
        assert_eq!(empty.encode(), 0usize.encode());
    }

    #[test]
    fn test_simple_map_u32_u64() {
        // Entries are inserted out of key order on purpose.
        let mut map = HashMap::new();
        for (key, value) in vec![(1u32, 100u64), (5u32, 500u64), (2u32, 200u64)] {
            map.insert(key, value);
        }

        round_trip(&map, (..).into(), (), ());
        assert_eq!(map.encode_size(), 1 + 3 * (u32::SIZE + u64::SIZE));
    }

    #[test]
    fn test_large_map() {
        let mut map = HashMap::new();
        for i in 0..1000 {
            map.insert(i, i as u64 * 2);
        }
        round_trip(&map, (..=1000).into(), (), ());
    }

    #[test]
    fn test_map_with_variable_values() {
        let mut map = HashMap::new();
        map.insert(Bytes::from_static(b"apple"), vec![1, 2]);
        map.insert(Bytes::from_static(b"banana"), vec![3, 4, 5]);
        map.insert(Bytes::from_static(b"cherry"), vec![]);

        round_trip(
            &map,
            RangeCfg::from(0..=10),
            RangeCfg::from(..=10),
            (RangeCfg::from(0..=100), ()),
        );
    }

    #[test]
    fn test_decode_length_limit_exceeded() {
        // Two entries are encoded, but at most one is allowed when decoding.
        let cfg = (RangeCfg::from(0..=1), ((), ()));

        let result = HashMap::<u32, u64>::decode_cfg(two_entry_map().encode(), &cfg);
        assert!(matches!(result, Err(Error::InvalidLength(2))));
    }

    #[test]
    fn test_decode_value_length_limit_exceeded() {
        let mut map = HashMap::new();
        map.insert(Bytes::from_static(b"key1"), vec![1u8, 2u8, 3u8, 4u8, 5u8]);

        // The map and key limits are generous; only the value limit (3) is below the
        // five-element value that was encoded.
        let cfg = (
            RangeCfg::from(0..=10),
            (RangeCfg::from(..=10), (RangeCfg::from(0..=3), ())),
        );

        let result = HashMap::<Bytes, Vec<u8>>::decode_cfg(map.encode(), &cfg);
        assert!(matches!(result, Err(Error::InvalidLength(5))));
    }

    #[test]
    fn test_decode_invalid_key_order() {
        // Key 2 follows key 5, so the keys descend.
        let encoded = encode_raw(&[(5, 500), (2, 200)]);
        let cfg = (RangeCfg::from(..), ((), ()));

        let result = HashMap::<u32, u64>::decode_cfg(encoded, &cfg);
        assert!(matches!(
            result,
            Err(Error::Invalid("HashMap", "Keys must ascend"))
        ));
    }

    #[test]
    fn test_decode_duplicate_key() {
        // Key 1 appears twice in a row.
        let encoded = encode_raw(&[(1, 100), (1, 200)]);
        let cfg = (RangeCfg::from(..), ((), ()));

        let result = HashMap::<u32, u64>::decode_cfg(encoded, &cfg);
        assert!(matches!(
            result,
            Err(Error::Invalid("HashMap", "Duplicate key"))
        ));
    }

    #[test]
    fn test_decode_end_of_buffer_key() {
        let map = two_entry_map();

        // Dropping 10 bytes removes the last value and part of the last key (assuming the
        // fixed 4-byte `u32` / 8-byte `u64` encodings), so decoding runs out while reading a key.
        let mut encoded = map.encode();
        encoded.truncate(map.encode_size() - 10);

        let cfg = (RangeCfg::from(..), ((), ()));
        let result = HashMap::<u32, u64>::decode_cfg(encoded, &cfg);
        assert!(matches!(result, Err(Error::EndOfBuffer)));
    }

    #[test]
    fn test_decode_end_of_buffer_value() {
        let map = two_entry_map();

        // Dropping 4 bytes leaves the last key intact but cuts its value short (assuming the
        // fixed 8-byte `u64` encoding), so decoding runs out while reading a value.
        let mut encoded = map.encode();
        encoded.truncate(map.encode_size() - 4);

        let cfg = (RangeCfg::from(..), ((), ()));
        let result = HashMap::<u32, u64>::decode_cfg(encoded, &cfg);
        assert!(matches!(result, Err(Error::EndOfBuffer)));
    }

    #[test]
    fn test_decode_extra_data() {
        let mut map = HashMap::new();
        map.insert(1u32, 100u64);

        // Append one trailing byte after a valid encoding.
        let mut encoded = map.encode();
        encoded.put_u8(0xFF);
        let cfg = (RangeCfg::from(..), ((), ()));

        // `decode_cfg` reports the leftover byte as an error.
        let decode_result = HashMap::<u32, u64>::decode_cfg(encoded.clone(), &cfg);
        assert!(matches!(decode_result, Err(Error::ExtraData(1))));

        // `read_cfg` succeeds on the same bytes and yields the single entry.
        let read_result = HashMap::<u32, u64>::read_cfg(&mut encoded, &cfg);
        assert!(read_result.is_ok());
        let decoded = read_result.unwrap();
        assert_eq!(decoded.len(), 1);
        assert_eq!(decoded.get(&1u32), Some(&100u64));
    }

    #[test]
    fn test_conformity() {
        // u8 -> u16: an empty map is a single zero byte.
        let mut map1 = HashMap::<u8, u16>::new();
        assert_eq!(map1.encode(), &[0x00][..]);

        map1.insert(1u8, 0xAAAAu16);
        map1.insert(2u8, 0xBBBBu16);
        // Layout: count, then (key, value) pairs in ascending key order.
        assert_eq!(
            map1.encode(),
            &[0x02, 0x01, 0xAA, 0xAA, 0x02, 0xBB, 0xBB][..]
        );

        // u16 -> bool: inserted out of order, encoded in ascending key order.
        let mut map2 = HashMap::<u16, bool>::new();
        map2.insert(0x0303u16, true);
        map2.insert(0x0101u16, false);
        map2.insert(0x0202u16, true);
        assert_eq!(
            map2.encode(),
            &[0x03, 0x01, 0x01, 0x00, 0x02, 0x02, 0x01, 0x03, 0x03, 0x01][..]
        );

        // Bytes -> Vec<u8>: keys and values each carry their own length prefix.
        let mut map3 = HashMap::<Bytes, Vec<u8>>::new();
        map3.insert(Bytes::from_static(b"b"), vec![20u8, 21u8]);
        map3.insert(Bytes::from_static(b"a"), vec![10u8]);
        let expected_map3: Vec<u8> = vec![
            0x02, // entry count
            0x01, 0x61, // key "a"
            0x01, 0x0A, // value [10]
            0x01, 0x62, // key "b"
            0x02, 0x14, 0x15, // value [20, 21]
        ];
        assert_eq!(map3.encode(), expected_map3.as_slice());
    }
}