commonware_codec/types/
hash_map.rs

1//! Codec implementations for HashMap (requires std).
2//!
3//! For portability and consistency between architectures,
4//! the size of the map must fit within a [u32].
5
6use crate::{
7    codec::{EncodeSize, Read, Write},
8    error::Error,
9    RangeCfg,
10};
11use bytes::{Buf, BufMut};
12use std::{cmp::Ordering, collections::HashMap, hash::Hash};
13
/// Type label passed to `read_ordered_map` as `map_type`, so that any
/// `Error::Invalid` raised while decoding a `HashMap` names "HashMap" as its source.
14const HASHMAP_TYPE: &str = "HashMap";
15
16/// Read keyed items from [Buf] in ascending order.
17fn read_ordered_map<K, V, F>(
18    buf: &mut impl Buf,
19    len: usize,
20    k_cfg: &K::Cfg,
21    v_cfg: &V::Cfg,
22    mut insert: F,
23    map_type: &'static str,
24) -> Result<(), Error>
25where
26    K: Read + Ord,
27    V: Read,
28    F: FnMut(K, V) -> Option<V>,
29{
30    let mut last: Option<(K, V)> = None;
31    for _ in 0..len {
32        // Read key
33        let key = K::read_cfg(buf, k_cfg)?;
34
35        // Check if keys are in ascending order relative to the previous key
36        if let Some((ref last_key, _)) = last {
37            match key.cmp(last_key) {
38                Ordering::Equal => return Err(Error::Invalid(map_type, "Duplicate key")),
39                Ordering::Less => return Err(Error::Invalid(map_type, "Keys must ascend")),
40                _ => {}
41            }
42        }
43
44        // Read value
45        let value = V::read_cfg(buf, v_cfg)?;
46
47        // Add previous item, if exists
48        if let Some((last_key, last_value)) = last.take() {
49            insert(last_key, last_value);
50        }
51        last = Some((key, value));
52    }
53
54    // Add last item, if exists
55    if let Some((last_key, last_value)) = last {
56        insert(last_key, last_value);
57    }
58
59    Ok(())
60}
61
62// ---------- HashMap ----------
63
64impl<K: Ord + Hash + Eq + Write, V: Write> Write for HashMap<K, V> {
65    fn write(&self, buf: &mut impl BufMut) {
66        self.len().write(buf);
67
68        // Sort the keys to ensure deterministic encoding
69        let mut entries: Vec<_> = self.iter().collect();
70        entries.sort_by(|a, b| a.0.cmp(b.0));
71        for (k, v) in entries {
72            k.write(buf);
73            v.write(buf);
74        }
75    }
76}
77
78impl<K: Ord + Hash + Eq + EncodeSize, V: EncodeSize> EncodeSize for HashMap<K, V> {
79    fn encode_size(&self) -> usize {
80        // Start with the size of the length prefix
81        let mut size = self.len().encode_size();
82
83        // Add the encoded size of each key and value
84        // Note: Iteration order doesn't matter for size calculation.
85        for (k, v) in self {
86            size += k.encode_size();
87            size += v.encode_size();
88        }
89        size
90    }
91}
92
93// Read implementation for HashMap
94impl<K: Read + Clone + Ord + Hash + Eq, V: Read + Clone> Read for HashMap<K, V> {
95    type Cfg = (RangeCfg<usize>, (K::Cfg, V::Cfg));
96
97    fn read_cfg(buf: &mut impl Buf, (range, (k_cfg, v_cfg)): &Self::Cfg) -> Result<Self, Error> {
98        // Read and validate the length prefix
99        let len = usize::read_cfg(buf, range)?;
100        let mut map = HashMap::with_capacity(len);
101
102        // Read items in ascending order
103        read_ordered_map(
104            buf,
105            len,
106            k_cfg,
107            v_cfg,
108            |k, v| map.insert(k, v),
109            HASHMAP_TYPE,
110        )?;
111
112        Ok(map)
113    }
114}
115
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Decode, Encode, FixedSize};
    use bytes::{Bytes, BytesMut};
    use std::fmt::Debug;

    /// Encodes `map`, checks the encoded length against `encode_size`, decodes the
    /// bytes again with the supplied (possibly non-default) configs, and asserts that
    /// the decoded map equals the input.
    fn round_trip_hash<K, V, KCfg, VCfg>(
        map: &HashMap<K, V>,
        range_cfg: RangeCfg<usize>,
        k_cfg: KCfg,
        v_cfg: VCfg,
    ) where
        K: Write + EncodeSize + Read<Cfg = KCfg> + Clone + Ord + Hash + Eq + PartialEq + Debug,
        V: Write + EncodeSize + Read<Cfg = VCfg> + Clone + PartialEq + Debug,
        HashMap<K, V>: Read<Cfg = (RangeCfg<usize>, (K::Cfg, V::Cfg))>
            + Decode<Cfg = (RangeCfg<usize>, (K::Cfg, V::Cfg))>
            + PartialEq
            + Write
            + EncodeSize,
    {
        let cfg = (range_cfg, (k_cfg, v_cfg));
        let bytes = map.encode();
        assert_eq!(bytes.len(), map.encode_size());
        let restored =
            HashMap::<K, V>::decode_cfg(bytes, &cfg).expect("decode_cfg failed for HashMap");
        assert_eq!(map, &restored);
    }

    // --- HashMap Tests ---

    /// An empty map round-trips and encodes to nothing but a zero length prefix.
    #[test]
    fn test_empty_hashmap() {
        let empty: HashMap<u32, u64> = HashMap::new();
        round_trip_hash(&empty, RangeCfg::from(..), (), ());
        assert_eq!(empty.encode_size(), 1);
        assert_eq!(empty.encode(), 0usize.encode());
    }

    /// A small map of fixed-size entries round-trips with the expected size.
    #[test]
    fn test_simple_hashmap_u32_u64() {
        let map = HashMap::from([(1u32, 100u64), (5u32, 500u64), (2u32, 200u64)]);
        round_trip_hash(&map, RangeCfg::from(..), (), ());
        assert_eq!(map.encode_size(), 1 + 3 * (u32::SIZE + u64::SIZE));
    }

    /// Maps with 1000 entries round-trip for fixed-size and variable-size items.
    #[test]
    fn test_large_hashmap() {
        // Fixed-size items
        let fixed: HashMap<u16, u64> = (0..1000u16).map(|i| (i, u64::from(i) * 2)).collect();
        round_trip_hash(&fixed, (0..=1000).into(), (), ());

        // Variable-size items
        let variable: HashMap<usize, usize> = (0..1000usize).map(|i| (i, 1000 + i)).collect();
        round_trip_hash(
            &variable,
            (0..=1000).into(),
            (..=1000).into(),
            (1000..=2000).into(),
        );
    }

    /// Byte-string keys with vector values of differing lengths round-trip.
    #[test]
    fn test_hashmap_with_variable_values() {
        let map = HashMap::from([
            (Bytes::from_static(b"apple"), vec![1, 2]),
            (Bytes::from_static(b"banana"), vec![3, 4, 5]),
            (Bytes::from_static(b"cherry"), vec![]),
        ]);

        let entries_range = RangeCfg::from(0..=10);
        let key_len_range = RangeCfg::from(..=10);
        let value_len_range = RangeCfg::from(0..=100);

        round_trip_hash(&map, entries_range, key_len_range, (value_len_range, ()));
    }

    /// Decoding fails when the entry count exceeds the configured range.
    #[test]
    fn test_hashmap_decode_length_limit_exceeded() {
        let map = HashMap::from([(1u32, 100u64), (5u32, 500u64)]);

        // Only 0 or 1 entries are allowed, but the map holds 2
        let cfg = (RangeCfg::from(0..=1), ((), ()));

        let outcome = HashMap::<u32, u64>::decode_cfg(map.encode(), &cfg);
        assert!(matches!(outcome, Err(Error::InvalidLength(2))));
    }

    /// Decoding fails when a value is longer than the value config allows.
    #[test]
    fn test_hashmap_decode_value_length_limit_exceeded() {
        let map = HashMap::from([(
            Bytes::from_static(b"key1"),
            vec![1u8, 2u8, 3u8, 4u8, 5u8],
        )]);

        // Values may hold at most 3 items, but the stored value holds 5
        let cfg = (
            RangeCfg::from(0..=10),
            (RangeCfg::from(..=10), (RangeCfg::from(0..=3), ())),
        );

        let outcome = HashMap::<Bytes, Vec<u8>>::decode_cfg(map.encode(), &cfg);
        assert!(matches!(outcome, Err(Error::InvalidLength(5))));
    }

    /// Decoding rejects keys that are not in ascending order.
    #[test]
    fn test_hashmap_decode_invalid_key_order() {
        let mut raw = BytesMut::new();
        2usize.write(&mut raw); // Map length = 2
        // Key 5 followed by key 2 (out of order)
        for (key, value) in [(5u32, 500u64), (2u32, 200u64)] {
            key.write(&mut raw);
            value.write(&mut raw);
        }

        let cfg = (RangeCfg::from(..), ((), ()));
        let outcome = HashMap::<u32, u64>::decode_cfg(raw, &cfg);
        assert!(matches!(
            outcome,
            Err(Error::Invalid("HashMap", "Keys must ascend"))
        ));
    }

    /// Decoding rejects a key that appears twice.
    #[test]
    fn test_hashmap_decode_duplicate_key() {
        let mut raw = BytesMut::new();
        2usize.write(&mut raw); // Map length = 2
        // Key 1 appears twice with different values
        for (key, value) in [(1u32, 100u64), (1u32, 200u64)] {
            key.write(&mut raw);
            value.write(&mut raw);
        }

        let cfg = (RangeCfg::from(..), ((), ()));
        let outcome = HashMap::<u32, u64>::decode_cfg(raw, &cfg);
        assert!(matches!(
            outcome,
            Err(Error::Invalid("HashMap", "Duplicate key"))
        ));
    }

    /// Decoding reports end-of-buffer when the data stops inside the last pair.
    #[test]
    fn test_hashmap_decode_end_of_buffer_key() {
        let map = HashMap::from([(1u32, 100u64), (5u32, 500u64)]);

        // Truncate during last key/value pair
        let mut clipped = map.encode();
        clipped.truncate(map.encode_size() - 10);

        let cfg = (RangeCfg::from(..), ((), ()));
        let outcome = HashMap::<u32, u64>::decode_cfg(clipped, &cfg);
        assert!(matches!(outcome, Err(Error::EndOfBuffer)));
    }

    /// Decoding reports end-of-buffer when the data stops inside the last value.
    #[test]
    fn test_hashmap_decode_end_of_buffer_value() {
        let map = HashMap::from([(1u32, 100u64), (5u32, 500u64)]);

        // Truncate during last value
        let mut clipped = map.encode();
        clipped.truncate(map.encode_size() - 4);

        let cfg = (RangeCfg::from(..), ((), ()));
        let outcome = HashMap::<u32, u64>::decode_cfg(clipped, &cfg);
        assert!(matches!(outcome, Err(Error::EndOfBuffer)));
    }

    /// `decode_cfg` rejects trailing bytes, while `read_cfg` leaves them unread.
    #[test]
    fn test_hashmap_decode_extra_data() {
        let map = HashMap::from([(1u32, 100u64)]);

        let mut padded = map.encode();
        padded.put_u8(0xFF); // Add extra byte

        // Use decode_cfg which enforces buffer is fully consumed
        let cfg = (RangeCfg::from(..), ((), ()));
        let strict = HashMap::<u32, u64>::decode_cfg(padded.clone(), &cfg);
        assert!(matches!(strict, Err(Error::ExtraData(1))));

        // Verify that read_cfg would succeed (doesn't check for extra data)
        let lenient = HashMap::<u32, u64>::read_cfg(&mut padded, &cfg);
        assert!(lenient.is_ok());
        let decoded = lenient.unwrap();
        assert_eq!(decoded.len(), 1);
        assert_eq!(decoded.get(&1u32), Some(&100u64));
    }

    /// Insertion order must not influence the encoded bytes.
    #[test]
    fn test_hashmap_deterministic_encoding() {
        // In-order
        let mut ascending = HashMap::new();
        for i in 0..=1000u32 {
            ascending.insert(i, i * 2);
        }

        // Reverse order
        let mut descending = HashMap::new();
        for i in (0..=1000u32).rev() {
            descending.insert(i, i * 2);
        }

        assert_eq!(descending.encode(), ascending.encode());
    }

    /// The encoding matches a hand-built byte sequence: length prefix, then each
    /// key/value pair in ascending key order.
    #[test]
    fn test_hashmap_conformity() {
        // Case 1: Empty HashMap<u8, u16>
        {
            let map: HashMap<u8, u16> = HashMap::new();
            let mut expected = BytesMut::new();
            0usize.write(&mut expected); // Length 0
            assert_eq!(map.encode(), expected.freeze());
        }

        // Case 2: Simple HashMap<u8, u16>
        // Keys are sorted for encoding: 1, 2
        {
            // Inserted out of order
            let map = HashMap::from([(2u8, 0xBBBBu16), (1u8, 0xAAAAu16)]);

            let mut expected = BytesMut::new();
            2usize.write(&mut expected); // Length 2
            for (key, value) in [(1u8, 0xAAAAu16), (2u8, 0xBBBBu16)] {
                key.write(&mut expected);
                value.write(&mut expected);
            }
            assert_eq!(map.encode(), expected.freeze());
        }

        // Case 3: HashMap<u16, bool>
        // Keys are sorted for encoding: 0x0101, 0x0202, 0x0303
        {
            let map = HashMap::from([(0x0303u16, true), (0x0101u16, false), (0x0202u16, true)]);

            let mut expected = BytesMut::new();
            3usize.write(&mut expected); // Length 3
            // false encodes as 0x00, true as 0x01
            for (key, value) in [(0x0101u16, false), (0x0202u16, true), (0x0303u16, true)] {
                key.write(&mut expected);
                value.write(&mut expected);
            }
            assert_eq!(map.encode(), expected.freeze());
        }

        // Case 4: HashMap with Bytes as key and Vec<u8> as value
        // Keys are sorted for encoding: "a", "b"
        {
            let map = HashMap::from([
                (Bytes::from_static(b"b"), vec![20u8, 21u8]),
                (Bytes::from_static(b"a"), vec![10u8]),
            ]);

            let mut expected = BytesMut::new();
            2usize.write(&mut expected); // Map length = 2
            // Each key and each value carries its own length prefix
            for (key, value) in [
                (Bytes::from_static(b"a"), vec![10u8]),
                (Bytes::from_static(b"b"), vec![20u8, 21u8]),
            ] {
                key.write(&mut expected);
                value.write(&mut expected);
            }
            assert_eq!(map.encode(), expected.freeze());
        }
    }
}