commonware_codec/types/
map.rs

1//! Codec implementations for HashMap.
2//!
3//! For portability and consistency between architectures,
4//! the size of the map must fit within a [`u32`].
5
6use crate::{
7    codec::{EncodeSize, Read, Write},
8    error::Error,
9    RangeCfg,
10};
11use bytes::{Buf, BufMut};
12use std::{collections::HashMap, hash::Hash};
13
14// Write implementation for HashMap
15impl<K: Ord + Hash + Eq + Write, V: Write> Write for HashMap<K, V> {
16    fn write(&self, buf: &mut impl BufMut) {
17        self.len().write(buf);
18
19        // Sort the keys to ensure deterministic encoding
20        let mut keys: Vec<_> = self.keys().collect();
21        keys.sort();
22        for key in keys {
23            key.write(buf);
24            self.get(key).unwrap().write(buf);
25        }
26    }
27}
28
29// EncodeSize implementation for HashMap
30impl<K: Ord + Hash + Eq + EncodeSize, V: EncodeSize> EncodeSize for HashMap<K, V> {
31    fn encode_size(&self) -> usize {
32        // Start with the size of the length prefix
33        let mut size = self.len().encode_size();
34
35        // Add the encoded size of each key and value
36        // Note: Iteration order doesn't matter for size calculation.
37        for (key, value) in self {
38            size += key.encode_size();
39            size += value.encode_size();
40        }
41        size
42    }
43}
44
45// Read implementation for HashMap
46impl<K: Read + Clone + Ord + Hash + Eq, V: Read + Clone> Read for HashMap<K, V> {
47    type Cfg = (RangeCfg, (K::Cfg, V::Cfg));
48
49    fn read_cfg(buf: &mut impl Buf, (range, (k_cfg, v_cfg)): &Self::Cfg) -> Result<Self, Error> {
50        // Read and validate the length prefix
51        let len = usize::read_cfg(buf, range)?;
52        let mut map = HashMap::with_capacity(len);
53
54        // Keep track of the last key read
55        let mut last_key: Option<K> = None;
56
57        // Read each key-value pair
58        for _ in 0..len {
59            let key = K::read_cfg(buf, k_cfg)?;
60
61            // Check if keys are in ascending order relative to the previous key
62            if let Some(ref last) = last_key {
63                use std::cmp::Ordering;
64                match key.cmp(last) {
65                    Ordering::Equal => return Err(Error::Invalid("HashMap", "Duplicate key")),
66                    Ordering::Less => return Err(Error::Invalid("HashMap", "Keys must ascend")),
67                    _ => {}
68                }
69            }
70            last_key = Some(key.clone());
71
72            let value = V::read_cfg(buf, v_cfg)?;
73            map.insert(key, value);
74        }
75
76        Ok(map)
77    }
78}
79
#[cfg(test)]
mod tests {
    use crate::{
        codec::{Decode, Encode, EncodeSize, FixedSize, Read, Write},
        error::Error,
        RangeCfg,
    };
    use bytes::{BufMut, Bytes, BytesMut};
    use std::{collections::HashMap, fmt::Debug, hash::Hash};

    /// Encodes `map`, decodes it with the given configuration, and asserts the
    /// decoded map equals the original.
    fn round_trip<K, V>(map: &HashMap<K, V>, range_cfg: RangeCfg, k_cfg: K::Cfg, v_cfg: V::Cfg)
    where
        K: Write + EncodeSize + Read + Clone + Ord + Hash + Eq + Debug + PartialEq,
        V: Write + EncodeSize + Read + Clone + Debug + PartialEq,
        HashMap<K, V>:
            Read<Cfg = (RangeCfg, (K::Cfg, V::Cfg))> + Debug + PartialEq + Write + EncodeSize,
    {
        let bytes = map.encode();
        let cfg = (range_cfg, (k_cfg, v_cfg));
        let restored = HashMap::<K, V>::decode_cfg(bytes, &cfg).expect("decode_cfg failed");
        assert_eq!(map, &restored);
    }

    /// Builds a map configuration for key/value types that need no config.
    fn unit_cfg(range: RangeCfg) -> (RangeCfg, ((), ())) {
        (range, ((), ()))
    }

    /// Hand-encodes `entries` in exactly the given order, bypassing the
    /// sorting performed by `Write for HashMap`.
    fn encode_raw(entries: &[(u32, u64)]) -> BytesMut {
        let mut out = BytesMut::new();
        entries.len().write(&mut out);
        for (key, value) in entries {
            key.write(&mut out);
            value.write(&mut out);
        }
        out
    }

    /// Encodes the two-entry map `{1: 100, 5: 500}`, drops its last `cut`
    /// bytes, and attempts to decode what is left.
    fn decode_truncated(cut: usize) -> Result<HashMap<u32, u64>, Error> {
        let map = HashMap::from([(1u32, 100u64), (5u32, 500u64)]);
        let mut encoded = map.encode();
        encoded.truncate(map.encode_size() - cut);
        HashMap::<u32, u64>::decode_cfg(encoded, &unit_cfg(RangeCfg::from(..)))
    }

    #[test]
    fn test_empty_map() {
        let empty: HashMap<u32, u64> = HashMap::new();
        round_trip(&empty, (..).into(), (), ());
        assert_eq!(empty.encode_size(), 1);
        assert_eq!(empty.encode(), 0usize.encode());
    }

    #[test]
    fn test_simple_map_u32_u64() {
        let map = HashMap::from([(1u32, 100u64), (5u32, 500u64), (2u32, 200u64)]);
        round_trip(&map, (..).into(), (), ());
        assert_eq!(map.encode_size(), 1 + 3 * (u32::SIZE + u64::SIZE));
    }

    #[test]
    fn test_large_map() {
        let map: HashMap<_, _> = (0..1000).map(|i| (i, i as u64 * 2)).collect();
        round_trip(&map, (..=1000).into(), (), ());
    }

    #[test]
    fn test_map_with_variable_values() {
        let map = HashMap::from([
            (Bytes::from_static(b"apple"), vec![1, 2]),
            (Bytes::from_static(b"banana"), vec![3, 4, 5]),
            (Bytes::from_static(b"cherry"), vec![]),
        ]);
        let value_cfg = (RangeCfg::from(0..=100), ());
        round_trip(&map, RangeCfg::from(0..=10), RangeCfg::from(..=10), value_cfg);
    }

    #[test]
    fn test_decode_length_limit_exceeded() {
        let map = HashMap::from([(1u32, 100u64), (5u32, 500u64)]);
        // Two entries, but the configuration permits at most one.
        let result = HashMap::<u32, u64>::decode_cfg(map.encode(), &unit_cfg((0..=1).into()));
        assert!(matches!(result, Err(Error::InvalidLength(2))));
    }

    #[test]
    fn test_decode_value_length_limit_exceeded() {
        let map = HashMap::from([(Bytes::from_static(b"key1"), vec![1u8, 2u8, 3u8, 4u8, 5u8])]);
        // The value holds five bytes, but values are limited to three.
        let cfg = (
            RangeCfg::from(0..=10),
            (RangeCfg::from(..=10), (RangeCfg::from(0..=3), ())),
        );
        let result = HashMap::<Bytes, Vec<u8>>::decode_cfg(map.encode(), &cfg);
        assert!(matches!(result, Err(Error::InvalidLength(5))));
    }

    #[test]
    fn test_decode_invalid_key_order() {
        // Key 2 follows key 5, violating ascending order.
        let encoded = encode_raw(&[(5, 500), (2, 200)]);
        let result = HashMap::<u32, u64>::decode_cfg(encoded, &unit_cfg((..).into()));
        assert!(matches!(
            result,
            Err(Error::Invalid("HashMap", "Keys must ascend"))
        ));
    }

    #[test]
    fn test_decode_duplicate_key() {
        // Key 1 appears twice.
        let encoded = encode_raw(&[(1, 100), (1, 200)]);
        let result = HashMap::<u32, u64>::decode_cfg(encoded, &unit_cfg((..).into()));
        assert!(matches!(
            result,
            Err(Error::Invalid("HashMap", "Duplicate key"))
        ));
    }

    #[test]
    fn test_decode_end_of_buffer_key() {
        // Dropping the last 10 bytes leaves only part of the second key.
        assert!(matches!(decode_truncated(10), Err(Error::EndOfBuffer)));
    }

    #[test]
    fn test_decode_end_of_buffer_value() {
        // Dropping the last 4 bytes leaves the second value incomplete.
        assert!(matches!(decode_truncated(4), Err(Error::EndOfBuffer)));
    }

    #[test]
    fn test_decode_extra_data() {
        let map = HashMap::from([(1u32, 100u64)]);
        let mut encoded = map.encode();
        encoded.put_u8(0xFF); // a trailing byte that does not belong to the map
        let cfg = unit_cfg(RangeCfg::from(..));

        // `decode_cfg` insists the whole buffer is consumed...
        let strict = HashMap::<u32, u64>::decode_cfg(encoded.clone(), &cfg);
        assert!(matches!(strict, Err(Error::ExtraData(1))));

        // ...while `read_cfg` stops after the map and leaves the rest alone.
        let lenient = HashMap::<u32, u64>::read_cfg(&mut encoded, &cfg).expect("read_cfg failed");
        assert_eq!(lenient.len(), 1);
        assert_eq!(lenient.get(&1u32), Some(&100u64));
    }

    #[test]
    fn test_conformity() {
        let empty = HashMap::<u8, u16>::new();
        assert_eq!(empty.encode(), &[0x00][..]);

        // Count 2, then (key, u16 value) pairs in ascending key order.
        let small = HashMap::from([(1u8, 0xAAAAu16), (2u8, 0xBBBBu16)]);
        assert_eq!(
            small.encode(),
            &[0x02, 0x01, 0xAA, 0xAA, 0x02, 0xBB, 0xBB][..]
        );

        // Inserted out of order; encoded as 0x0101, 0x0202, 0x0303.
        let flags = HashMap::from([(0x0303u16, true), (0x0101u16, false), (0x0202u16, true)]);
        assert_eq!(
            flags.encode(),
            &[0x03, 0x01, 0x01, 0x00, 0x02, 0x02, 0x01, 0x03, 0x03, 0x01][..]
        );

        // Length-prefixed keys and values, still sorted by key ("a" before "b").
        let nested = HashMap::from([
            (Bytes::from_static(b"b"), vec![20u8, 21u8]),
            (Bytes::from_static(b"a"), vec![10u8]),
        ]);
        let expected: &[u8] = &[
            0x02, // entry count
            0x01, 0x61, // key "a"
            0x01, 0x0A, // value [10]
            0x01, 0x62, // key "b"
            0x02, 0x14, 0x15, // value [20, 21]
        ];
        assert_eq!(nested.encode(), expected);
    }
}