1use crate::{
7 codec::{EncodeSize, Read, Write},
8 error::Error,
9 varint, Config, RangeConfig,
10};
11use bytes::{Buf, BufMut};
12use std::{collections::HashMap, hash::Hash};
13
14impl<K: Ord + Hash + Eq + Write, V: Write> Write for HashMap<K, V> {
16 fn write(&self, buf: &mut impl BufMut) {
17 let len = u32::try_from(self.len()).expect("HashMap length exceeds u32::MAX");
18 varint::write(len, buf);
19
20 let mut keys: Vec<_> = self.keys().collect();
22 keys.sort();
23 for key in keys {
24 key.write(buf);
25 self.get(key).unwrap().write(buf);
26 }
27 }
28}
29
30impl<K: Ord + Hash + Eq + EncodeSize, V: EncodeSize> EncodeSize for HashMap<K, V> {
32 fn encode_size(&self) -> usize {
33 let len = u32::try_from(self.len()).expect("HashMap length exceeds u32::MAX");
35 let mut size = varint::size(len);
36
37 for (key, value) in self {
40 size += key.encode_size();
41 size += value.encode_size();
42 }
43 size
44 }
45}
46
47impl<
49 R: RangeConfig,
50 KCfg: Config,
51 VCfg: Config,
52 K: Read<KCfg> + Clone + Ord + Hash + Eq,
53 V: Read<VCfg> + Clone,
54 > Read<(R, (KCfg, VCfg))> for HashMap<K, V>
55{
56 fn read_cfg(
57 buf: &mut impl Buf,
58 (range, (k_cfg, v_cfg)): &(R, (KCfg, VCfg)),
59 ) -> Result<Self, Error> {
60 let len32 = varint::read::<u32>(buf)?;
62 let len = usize::try_from(len32).map_err(|_| Error::InvalidVarint)?;
63 if !range.contains(&len) {
64 return Err(Error::InvalidLength(len));
65 }
66 let mut map = HashMap::with_capacity(len);
67
68 let mut last_key: Option<K> = None;
70
71 for _ in 0..len {
73 let key = K::read_cfg(buf, k_cfg)?;
74
75 if let Some(ref last) = last_key {
77 use std::cmp::Ordering;
78 match key.cmp(last) {
79 Ordering::Equal => return Err(Error::Invalid("HashMap", "Duplicate key")),
80 Ordering::Less => return Err(Error::Invalid("HashMap", "Keys must ascend")),
81 _ => {}
82 }
83 }
84 last_key = Some(key.clone());
85
86 let value = V::read_cfg(buf, v_cfg)?;
87 map.insert(key, value);
88 }
89
90 Ok(map)
91 }
92}
93
#[cfg(test)]
mod tests {
    use crate::{
        codec::{Decode, Encode, EncodeSize, FixedSize, Read, Write},
        error::Error,
        varint, Config, RangeConfig,
    };
    use bytes::{BufMut, Bytes, BytesMut};
    use std::collections::HashMap;
    use std::fmt::Debug;
    use std::hash::Hash;
    use std::ops::RangeInclusive;

    /// Encodes `map`, decodes the bytes again with the supplied configuration
    /// and asserts that the decoded map equals the input.
    fn round_trip<K, V, R, KCfg, VCfg>(
        map: &HashMap<K, V>,
        range_cfg: R,
        k_cfg: KCfg,
        v_cfg: VCfg,
    ) where
        K: Write + EncodeSize + Read<KCfg> + Clone + Ord + Hash + Eq + Debug + PartialEq,
        V: Write + EncodeSize + Read<VCfg> + Clone + Debug + PartialEq,
        R: RangeConfig + Clone,
        KCfg: Config + Clone,
        VCfg: Config + Clone,
        HashMap<K, V>: Read<(R, (KCfg, VCfg))>
            + Decode<(R, (KCfg, VCfg))>
            + Debug
            + PartialEq
            + Write
            + EncodeSize,
    {
        let cfg = (range_cfg, (k_cfg, v_cfg));
        let restored = HashMap::<K, V>::decode_cfg(map.encode(), &cfg).expect("decode_cfg failed");
        assert_eq!(*map, restored);
    }

    /// Length range that accepts any entry count.
    fn allow_any_len() -> RangeInclusive<usize> {
        RangeInclusive::new(0, usize::MAX)
    }

    /// Two-entry fixture (`1 -> 100`, `5 -> 500`) shared by the decode-failure tests.
    fn two_entry_map() -> HashMap<u32, u64> {
        vec![(1u32, 100u64), (5u32, 500u64)].into_iter().collect()
    }

    /// Hand-encodes `pairs` in exactly the given order, bypassing the key
    /// sorting performed by the `Write` implementation.
    fn encode_raw(pairs: &[(u32, u64)]) -> BytesMut {
        let mut out = BytesMut::new();
        let count = u32::try_from(pairs.len()).expect("pair count exceeds u32::MAX");
        varint::write(count, &mut out);
        for &(key, value) in pairs {
            key.write(&mut out);
            value.write(&mut out);
        }
        out
    }

    /// Decodes a `HashMap<u32, u64>` from `encoded` without restricting the entry count.
    fn decode_unbounded(encoded: BytesMut) -> Result<HashMap<u32, u64>, Error> {
        HashMap::<u32, u64>::decode_cfg(encoded, &(allow_any_len(), ((), ())))
    }

    #[test]
    fn test_empty_map() {
        let empty: HashMap<u32, u64> = HashMap::new();
        round_trip(&empty, allow_any_len(), (), ());

        // An empty map is nothing but the one-byte length prefix `0`.
        assert_eq!(empty.encode_size(), 1);
        assert_eq!(empty.encode(), Bytes::from_static(&[0]));
    }

    #[test]
    fn test_simple_map_u32_u64() {
        let map: HashMap<u32, u64> = vec![(1, 100), (5, 500), (2, 200)].into_iter().collect();
        round_trip(&map, allow_any_len(), (), ());

        // One byte of length prefix plus three fixed-size entries.
        let entry_size = u32::SIZE + u64::SIZE;
        assert_eq!(map.encode_size(), 1 + 3 * entry_size);
    }

    #[test]
    fn test_large_map() {
        let map: HashMap<_, _> = (0..1000).map(|i| (i, i as u64 * 2)).collect();
        round_trip(&map, 0..=1000, (), ());
    }

    #[test]
    fn test_map_with_variable_values() {
        let map: HashMap<_, _> = vec![
            (Bytes::from_static(b"apple"), vec![1, 2]),
            (Bytes::from_static(b"banana"), vec![3, 4, 5]),
            (Bytes::from_static(b"cherry"), vec![]),
        ]
        .into_iter()
        .collect();

        let val_range = 0..=100;
        let key_range = ..=10;
        let map_range = 0..=10;

        round_trip(&map, map_range, key_range, (val_range, ()));
    }

    #[test]
    fn test_decode_length_limit_exceeded() {
        let encoded = two_entry_map().encode();

        // Two entries were encoded, but at most one is allowed.
        let cfg = (0..=1, ((), ()));
        assert!(matches!(
            HashMap::<u32, u64>::decode_cfg(encoded, &cfg),
            Err(Error::InvalidLength(2))
        ));
    }

    #[test]
    fn test_decode_value_length_limit_exceeded() {
        let map: HashMap<_, _> = vec![(Bytes::from_static(b"key1"), vec![1, 2, 3, 4, 5])]
            .into_iter()
            .collect();

        // The map and key limits are generous; only the value limit (3) is
        // below the five elements that were encoded.
        let cfg = (0..=10, (..=10, (0..=3, ())));
        let result = HashMap::<Bytes, Vec<u8>>::decode_cfg(map.encode(), &cfg);

        assert!(matches!(result, Err(Error::InvalidLength(5))));
    }

    #[test]
    fn test_decode_invalid_key_order() {
        // Key 5 is written before key 2.
        let encoded = encode_raw(&[(5, 500), (2, 200)]);

        assert!(matches!(
            decode_unbounded(encoded),
            Err(Error::Invalid("HashMap", "Keys must ascend"))
        ));
    }

    #[test]
    fn test_decode_duplicate_key() {
        // Key 1 is written twice.
        let encoded = encode_raw(&[(1, 100), (1, 200)]);

        assert!(matches!(
            decode_unbounded(encoded),
            Err(Error::Invalid("HashMap", "Duplicate key"))
        ));
    }

    #[test]
    fn test_decode_end_of_buffer_key() {
        let map = two_entry_map();
        let mut encoded = map.encode();
        // Drop the trailing 10 bytes; per the test name this is meant to cut
        // the data off part-way through a key.
        encoded.truncate(map.encode_size() - 10);

        assert!(matches!(decode_unbounded(encoded), Err(Error::EndOfBuffer)));
    }

    #[test]
    fn test_decode_end_of_buffer_value() {
        let map = two_entry_map();
        let mut encoded = map.encode();
        // Drop the trailing 4 bytes; per the test name this is meant to cut
        // the data off part-way through a value.
        encoded.truncate(map.encode_size() - 4);

        assert!(matches!(decode_unbounded(encoded), Err(Error::EndOfBuffer)));
    }

    #[test]
    fn test_decode_extra_data() {
        let single: HashMap<u32, u64> = vec![(1, 100)].into_iter().collect();
        let mut encoded = single.encode();
        encoded.put_u8(0xFF);
        let cfg = (allow_any_len(), ((), ()));

        // `decode_cfg` reports the single trailing byte as an error...
        assert!(matches!(
            HashMap::<u32, u64>::decode_cfg(encoded.clone(), &cfg),
            Err(Error::ExtraData(1))
        ));

        // ...whereas `read_cfg` only consumes the map itself and succeeds.
        let read_result = HashMap::<u32, u64>::read_cfg(&mut encoded, &cfg);
        assert!(read_result.is_ok());
        let decoded = read_result.unwrap();
        assert_eq!(decoded.len(), 1);
        assert_eq!(decoded.get(&1u32), Some(&100u64));
    }
}