1use crate::{
7 codec::{EncodeSize, Read, Write},
8 error::Error,
9 Config, RangeConfig,
10};
11use bytes::{Buf, BufMut};
12use std::{collections::HashMap, hash::Hash};
13
14impl<K: Ord + Hash + Eq + Write, V: Write> Write for HashMap<K, V> {
16 fn write(&self, buf: &mut impl BufMut) {
17 self.len().write(buf);
18
19 let mut keys: Vec<_> = self.keys().collect();
21 keys.sort();
22 for key in keys {
23 key.write(buf);
24 self.get(key).unwrap().write(buf);
25 }
26 }
27}
28
29impl<K: Ord + Hash + Eq + EncodeSize, V: EncodeSize> EncodeSize for HashMap<K, V> {
31 fn encode_size(&self) -> usize {
32 let mut size = self.len().encode_size();
34
35 for (key, value) in self {
38 size += key.encode_size();
39 size += value.encode_size();
40 }
41 size
42 }
43}
44
45impl<
47 R: RangeConfig,
48 KCfg: Config,
49 VCfg: Config,
50 K: Read<KCfg> + Clone + Ord + Hash + Eq,
51 V: Read<VCfg> + Clone,
52 > Read<(R, (KCfg, VCfg))> for HashMap<K, V>
53{
54 fn read_cfg(
55 buf: &mut impl Buf,
56 (range, (k_cfg, v_cfg)): &(R, (KCfg, VCfg)),
57 ) -> Result<Self, Error> {
58 let len = usize::read_cfg(buf, range)?;
60 let mut map = HashMap::with_capacity(len);
61
62 let mut last_key: Option<K> = None;
64
65 for _ in 0..len {
67 let key = K::read_cfg(buf, k_cfg)?;
68
69 if let Some(ref last) = last_key {
71 use std::cmp::Ordering;
72 match key.cmp(last) {
73 Ordering::Equal => return Err(Error::Invalid("HashMap", "Duplicate key")),
74 Ordering::Less => return Err(Error::Invalid("HashMap", "Keys must ascend")),
75 _ => {}
76 }
77 }
78 last_key = Some(key.clone());
79
80 let value = V::read_cfg(buf, v_cfg)?;
81 map.insert(key, value);
82 }
83
84 Ok(map)
85 }
86}
87
#[cfg(test)]
mod tests {
    use crate::{
        codec::{Decode, Encode, EncodeSize, FixedSize, Read, Write},
        error::Error,
        Config, RangeConfig,
    };
    use bytes::{BufMut, Bytes, BytesMut};
    use std::collections::HashMap;
    use std::fmt::Debug;
    use std::hash::Hash;
    use std::ops::RangeInclusive;

    /// Encodes `map`, decodes it back with the given configs, and asserts the result equals
    /// the original.
    fn round_trip<K, V, R, KCfg, VCfg>(map: &HashMap<K, V>, range_cfg: R, k_cfg: KCfg, v_cfg: VCfg)
    where
        K: Write + EncodeSize + Read<KCfg> + Clone + Ord + Hash + Eq + Debug + PartialEq,
        V: Write + EncodeSize + Read<VCfg> + Clone + Debug + PartialEq,
        R: RangeConfig + Clone,
        KCfg: Config + Clone,
        VCfg: Config + Clone,
        HashMap<K, V>: Read<(R, (KCfg, VCfg))>
            + Decode<(R, (KCfg, VCfg))>
            + Debug
            + PartialEq
            + Write
            + EncodeSize,
    {
        let encoded = map.encode();
        let config_tuple = (range_cfg, (k_cfg, v_cfg));
        let decoded =
            HashMap::<K, V>::decode_cfg(encoded, &config_tuple).expect("decode_cfg failed");
        assert_eq!(map, &decoded);
    }

    /// A map-length range that accepts any number of entries.
    fn allow_any_len() -> RangeInclusive<usize> {
        0..=usize::MAX
    }

    #[test]
    fn test_empty_map() {
        let map = HashMap::<u32, u64>::new();
        round_trip(&map, allow_any_len(), (), ());
        assert_eq!(map.encode_size(), 1);
        let encoded = map.encode();
        assert_eq!(encoded, 0usize.encode());
    }

    #[test]
    fn test_simple_map_u32_u64() {
        let mut map = HashMap::new();
        map.insert(1u32, 100u64);
        map.insert(5u32, 500u64);
        map.insert(2u32, 200u64);
        round_trip(&map, allow_any_len(), (), ());
        assert_eq!(map.encode_size(), 1 + 3 * (u32::SIZE + u64::SIZE));
    }

    #[test]
    fn test_large_map() {
        let mut map = HashMap::new();
        for i in 0..1000u32 {
            map.insert(i, u64::from(i) * 2);
        }
        round_trip(&map, 0..=1000, (), ());
    }

    #[test]
    fn test_map_with_variable_values() {
        let mut map = HashMap::new();
        map.insert(Bytes::from_static(b"apple"), vec![1u8, 2]);
        map.insert(Bytes::from_static(b"banana"), vec![3u8, 4, 5]);
        map.insert(Bytes::from_static(b"cherry"), vec![]);

        let map_range = 0..=10;
        let key_range = ..=10;
        let val_range = 0..=100;

        round_trip(&map, map_range, key_range, (val_range, ()));
    }

    #[test]
    fn test_encoding_sorts_keys() {
        // Inserted out of order; the encoding must list key 1 before key 2.
        let mut map = HashMap::new();
        map.insert(2u32, 20u64);
        map.insert(1u32, 10u64);

        let mut expected = BytesMut::new();
        2usize.write(&mut expected);
        1u32.write(&mut expected);
        10u64.write(&mut expected);
        2u32.write(&mut expected);
        20u64.write(&mut expected);

        assert_eq!(map.encode(), expected);
    }

    #[test]
    fn test_encoding_independent_of_insertion_order() {
        // Same contents built in opposite orders must produce identical bytes.
        let mut forward = HashMap::new();
        let mut backward = HashMap::new();
        for i in 0..100u32 {
            forward.insert(i, u64::from(i) * 3);
        }
        for i in (0..100u32).rev() {
            backward.insert(i, u64::from(i) * 3);
        }
        assert_eq!(forward.encode(), backward.encode());
    }

    #[test]
    fn test_decode_length_limit_exceeded() {
        let mut map = HashMap::new();
        map.insert(1u32, 100u64);
        map.insert(5u32, 500u64);

        let encoded = map.encode();
        let restrictive_range = 0..=1;
        let config_tuple = (restrictive_range, ((), ()));

        let result = HashMap::<u32, u64>::decode_cfg(encoded, &config_tuple);
        assert!(matches!(result, Err(Error::InvalidLength(2))));
    }

    #[test]
    fn test_decode_value_length_limit_exceeded() {
        let mut map = HashMap::new();
        map.insert(Bytes::from_static(b"key1"), vec![1u8, 2u8, 3u8, 4u8, 5u8]);

        let key_range = ..=10;
        let map_range = 0..=10;
        let restrictive_val_range = 0..=3;

        let encoded = map.encode();
        let config_tuple = (map_range, (key_range, (restrictive_val_range, ())));
        let result = HashMap::<Bytes, Vec<u8>>::decode_cfg(encoded, &config_tuple);

        assert!(matches!(result, Err(Error::InvalidLength(5))));
    }

    #[test]
    fn test_decode_invalid_key_order() {
        // Hand-encode two entries whose keys descend (5, then 2).
        let mut encoded = BytesMut::new();
        2usize.write(&mut encoded);
        5u32.write(&mut encoded);
        500u64.write(&mut encoded);
        2u32.write(&mut encoded);
        200u64.write(&mut encoded);

        let config_tuple = (allow_any_len(), ((), ()));
        let result = HashMap::<u32, u64>::decode_cfg(encoded, &config_tuple);
        assert!(matches!(
            result,
            Err(Error::Invalid("HashMap", "Keys must ascend"))
        ));
    }

    #[test]
    fn test_decode_duplicate_key() {
        // Hand-encode two entries that share key 1.
        let mut encoded = BytesMut::new();
        2usize.write(&mut encoded);
        1u32.write(&mut encoded);
        100u64.write(&mut encoded);
        1u32.write(&mut encoded);
        200u64.write(&mut encoded);

        let config_tuple = (allow_any_len(), ((), ()));
        let result = HashMap::<u32, u64>::decode_cfg(encoded, &config_tuple);
        assert!(matches!(
            result,
            Err(Error::Invalid("HashMap", "Duplicate key"))
        ));
    }

    #[test]
    fn test_decode_end_of_buffer_key() {
        let mut map = HashMap::new();
        map.insert(1u32, 100u64);
        map.insert(5u32, 500u64);

        // Drop the final value plus the last two bytes of the final key, so decoding
        // runs out of input in the middle of a key.
        let mut encoded = map.encode();
        encoded.truncate(map.encode_size() - (u64::SIZE + 2));

        let config_tuple = (allow_any_len(), ((), ()));
        let result = HashMap::<u32, u64>::decode_cfg(encoded, &config_tuple);
        assert!(matches!(result, Err(Error::EndOfBuffer)));
    }

    #[test]
    fn test_decode_end_of_buffer_value() {
        let mut map = HashMap::new();
        map.insert(1u32, 100u64);
        map.insert(5u32, 500u64);

        // Drop half of the final value, so decoding runs out of input mid-value.
        let mut encoded = map.encode();
        encoded.truncate(map.encode_size() - u64::SIZE / 2);

        let config_tuple = (allow_any_len(), ((), ()));
        let result = HashMap::<u32, u64>::decode_cfg(encoded, &config_tuple);
        assert!(matches!(result, Err(Error::EndOfBuffer)));
    }

    #[test]
    fn test_decode_extra_data() {
        let mut map = HashMap::new();
        map.insert(1u32, 100u64);

        let mut encoded = map.encode();
        encoded.put_u8(0xFF);
        let config_tuple = (allow_any_len(), ((), ()));

        // `decode_cfg` rejects input that is not fully consumed...
        let result = HashMap::<u32, u64>::decode_cfg(encoded.clone(), &config_tuple);
        assert!(matches!(result, Err(Error::ExtraData(1))));

        // ...whereas `read_cfg` stops after the map and leaves the trailing byte unread.
        let decoded_map = HashMap::<u32, u64>::read_cfg(&mut encoded, &config_tuple)
            .expect("read_cfg should ignore trailing data");
        assert_eq!(decoded_map.len(), 1);
        assert_eq!(decoded_map.get(&1u32), Some(&100u64));
        assert_eq!(encoded.len(), 1);
    }
}