commonware_codec/types/
map.rs

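//! Codec implementations for HashMap.
//!
//! For portability and consistency between architectures,
//! the size of the map must fit within a [`u32`].
//! Entries are encoded in ascending key order, so equal maps always encode to the same
//! bytes, and decoding rejects out-of-order or duplicate keys.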
5
6use crate::{
7    codec::{EncodeSize, Read, Write},
8    error::Error,
9    varint, Config, RangeConfig,
10};
11use bytes::{Buf, BufMut};
12use std::{collections::HashMap, hash::Hash};
13
14// Write implementation for HashMap
15impl<K: Ord + Hash + Eq + Write, V: Write> Write for HashMap<K, V> {
16    fn write(&self, buf: &mut impl BufMut) {
17        let len = u32::try_from(self.len()).expect("HashMap length exceeds u32::MAX");
18        varint::write(len, buf);
19
20        // Sort the keys to ensure deterministic encoding
21        let mut keys: Vec<_> = self.keys().collect();
22        keys.sort();
23        for key in keys {
24            key.write(buf);
25            self.get(key).unwrap().write(buf);
26        }
27    }
28}
29
30// EncodeSize implementation for HashMap
31impl<K: Ord + Hash + Eq + EncodeSize, V: EncodeSize> EncodeSize for HashMap<K, V> {
32    fn encode_size(&self) -> usize {
33        // Start with the varint size of the length
34        let len = u32::try_from(self.len()).expect("HashMap length exceeds u32::MAX");
35        let mut size = varint::size(len);
36
37        // Add the encoded size of each key and value
38        // Note: Iteration order doesn't matter for size calculation.
39        for (key, value) in self {
40            size += key.encode_size();
41            size += value.encode_size();
42        }
43        size
44    }
45}
46
47// Read implementation for HashMap
48impl<
49        R: RangeConfig,
50        KCfg: Config,
51        VCfg: Config,
52        K: Read<KCfg> + Clone + Ord + Hash + Eq,
53        V: Read<VCfg> + Clone,
54    > Read<(R, (KCfg, VCfg))> for HashMap<K, V>
55{
56    fn read_cfg(
57        buf: &mut impl Buf,
58        (range, (k_cfg, v_cfg)): &(R, (KCfg, VCfg)),
59    ) -> Result<Self, Error> {
60        // Read and validate the length prefix
61        let len32 = varint::read::<u32>(buf)?;
62        let len = usize::try_from(len32).map_err(|_| Error::InvalidVarint)?;
63        if !range.contains(&len) {
64            return Err(Error::InvalidLength(len));
65        }
66        let mut map = HashMap::with_capacity(len);
67
68        // Keep track of the last key read
69        let mut last_key: Option<K> = None;
70
71        // Read each key-value pair
72        for _ in 0..len {
73            let key = K::read_cfg(buf, k_cfg)?;
74
75            // Check if keys are in ascending order relative to the previous key
76            if let Some(ref last) = last_key {
77                use std::cmp::Ordering;
78                match key.cmp(last) {
79                    Ordering::Equal => return Err(Error::Invalid("HashMap", "Duplicate key")),
80                    Ordering::Less => return Err(Error::Invalid("HashMap", "Keys must ascend")),
81                    _ => {}
82                }
83            }
84            last_key = Some(key.clone());
85
86            let value = V::read_cfg(buf, v_cfg)?;
87            map.insert(key, value);
88        }
89
90        Ok(map)
91    }
92}
93
#[cfg(test)]
mod tests {
    use crate::{
        codec::{Decode, Encode, EncodeSize, FixedSize, Read, Write},
        error::Error,
        varint, Config, RangeConfig,
    };
    use bytes::{BufMut, Bytes, BytesMut};
    use std::collections::HashMap;
    use std::fmt::Debug;
    use std::hash::Hash;
    use std::ops::RangeInclusive;

    // Encodes `map`, decodes it with the given (possibly non-default) configs, and asserts
    // that the decoded map equals the original.
    fn round_trip<K, V, R, KCfg, VCfg>(map: &HashMap<K, V>, range_cfg: R, k_cfg: KCfg, v_cfg: VCfg)
    where
        K: Write + EncodeSize + Read<KCfg> + Clone + Ord + Hash + Eq + Debug + PartialEq,
        V: Write + EncodeSize + Read<VCfg> + Clone + Debug + PartialEq,
        R: RangeConfig + Clone,
        KCfg: Config + Clone,
        VCfg: Config + Clone,
        HashMap<K, V>: Read<(R, (KCfg, VCfg))>
            + Decode<(R, (KCfg, VCfg))>
            + Debug
            + PartialEq
            + Write
            + EncodeSize,
    {
        let encoded = map.encode();
        let config_tuple = (range_cfg, (k_cfg, v_cfg));
        let decoded =
            HashMap::<K, V>::decode_cfg(encoded, &config_tuple).expect("decode_cfg failed");
        assert_eq!(map, &decoded);
    }

    // A map-length range that accepts any size.
    fn allow_any_len() -> RangeInclusive<usize> {
        0..=usize::MAX
    }

    #[test]
    fn test_empty_map() {
        let map = HashMap::<u32, u64>::new();
        round_trip(&map, allow_any_len(), (), ());
        assert_eq!(map.encode_size(), 1);
        let encoded = map.encode();
        assert_eq!(encoded, Bytes::from_static(&[0])); // varint 0
    }

    #[test]
    fn test_simple_map_u32_u64() {
        let mut map = HashMap::new();
        map.insert(1u32, 100u64);
        map.insert(5u32, 500u64);
        map.insert(2u32, 200u64);
        round_trip(&map, allow_any_len(), (), ());
        assert_eq!(map.encode_size(), 1 + 3 * (u32::SIZE + u64::SIZE));
    }

    #[test]
    fn test_encoding_is_sorted_by_key() {
        // Insert in descending order; the encoding must still list keys in ascending order.
        let mut map = HashMap::new();
        for key in (0u32..10).rev() {
            map.insert(key, u64::from(key) * 10);
        }

        let mut expected = BytesMut::new();
        varint::write(10u32, &mut expected);
        for key in 0u32..10 {
            key.write(&mut expected);
            (u64::from(key) * 10).write(&mut expected);
        }

        assert_eq!(map.encode(), expected);
        assert_eq!(map.encode_size(), expected.len());
    }

    #[test]
    fn test_large_map() {
        let mut map = HashMap::new();
        for i in 0..1000 {
            map.insert(i, i as u64 * 2);
        }
        round_trip(&map, 0..=1000, (), ());
    }

    #[test]
    fn test_map_with_variable_values() {
        let mut map = HashMap::new();
        map.insert(Bytes::from_static(b"apple"), vec![1, 2]);
        map.insert(Bytes::from_static(b"banana"), vec![3, 4, 5]);
        map.insert(Bytes::from_static(b"cherry"), vec![]);

        let map_range = 0..=10;
        let key_range = ..=10;
        let val_range = 0..=100;

        round_trip(&map, map_range, key_range, (val_range, ()));
    }

    #[test]
    fn test_decode_length_limit_exceeded() {
        let mut map = HashMap::new();
        map.insert(1u32, 100u64);
        map.insert(5u32, 500u64);

        let encoded = map.encode();
        let restrictive_range = 0..=1;
        let config_tuple = (restrictive_range, ((), ()));

        let result = HashMap::<u32, u64>::decode_cfg(encoded, &config_tuple);
        assert!(matches!(result, Err(Error::InvalidLength(2))));
    }

    #[test]
    fn test_decode_value_length_limit_exceeded() {
        let mut map = HashMap::new();
        map.insert(Bytes::from_static(b"key1"), vec![1, 2, 3, 4, 5]);

        let key_range = ..=10;
        let map_range = 0..=10;
        let restrictive_val_range = 0..=3;

        let encoded = map.encode();
        let config_tuple = (map_range, (key_range, (restrictive_val_range, ())));
        let result = HashMap::<Bytes, Vec<u8>>::decode_cfg(encoded, &config_tuple);

        assert!(matches!(result, Err(Error::InvalidLength(5))));
    }

    #[test]
    fn test_decode_invalid_key_order() {
        let mut encoded = BytesMut::new();
        varint::write(2u32, &mut encoded); // Map length = 2
        5u32.write(&mut encoded); // Key 5
        500u64.write(&mut encoded); // Value 500
        2u32.write(&mut encoded); // Key 2 (out of order)
        200u64.write(&mut encoded); // Value 200

        let range = allow_any_len();
        let config_tuple = (range, ((), ()));

        let result = HashMap::<u32, u64>::decode_cfg(encoded, &config_tuple);
        assert!(matches!(
            result,
            Err(Error::Invalid("HashMap", "Keys must ascend"))
        ));
    }

    #[test]
    fn test_decode_duplicate_key() {
        let mut encoded = BytesMut::new();
        varint::write(2u32, &mut encoded); // Map length = 2
        1u32.write(&mut encoded); // Key 1
        100u64.write(&mut encoded); // Value 100
        1u32.write(&mut encoded); // Duplicate Key 1
        200u64.write(&mut encoded); // Value 200

        let range = allow_any_len();
        let config_tuple = (range, ((), ()));

        let result = HashMap::<u32, u64>::decode_cfg(encoded, &config_tuple);
        assert!(matches!(
            result,
            Err(Error::Invalid("HashMap", "Duplicate key"))
        ));
    }

    #[test]
    fn test_decode_end_of_buffer_key() {
        let mut map = HashMap::new();
        map.insert(1u32, 100u64);
        map.insert(5u32, 500u64);

        let mut encoded = map.encode();
        encoded.truncate(map.encode_size() - 10); // Cut off the last value and part of the last key

        let range = allow_any_len();
        let config_tuple = (range, ((), ()));
        let result = HashMap::<u32, u64>::decode_cfg(encoded, &config_tuple);
        assert!(matches!(result, Err(Error::EndOfBuffer)));
    }

    #[test]
    fn test_decode_end_of_buffer_value() {
        let mut map = HashMap::new();
        map.insert(1u32, 100u64);
        map.insert(5u32, 500u64);

        let mut encoded = map.encode();
        encoded.truncate(map.encode_size() - 4); // Truncate during last value

        let range = allow_any_len();
        let config_tuple = (range, ((), ()));
        let result = HashMap::<u32, u64>::decode_cfg(encoded, &config_tuple);
        assert!(matches!(result, Err(Error::EndOfBuffer)));
    }

    #[test]
    fn test_decode_extra_data() {
        let mut map = HashMap::new();
        map.insert(1u32, 100u64);

        let mut encoded = map.encode();
        encoded.put_u8(0xFF); // Add extra byte

        // The same config is reused for both decode_cfg and read_cfg below.
        let config_tuple = (allow_any_len(), ((), ()));

        // decode_cfg requires the buffer to be fully consumed
        let result = HashMap::<u32, u64>::decode_cfg(encoded.clone(), &config_tuple);
        assert!(matches!(result, Err(Error::ExtraData(1))));

        // read_cfg does not check for extra data: it stops after the map and leaves the
        // trailing byte unread
        let decoded_map = HashMap::<u32, u64>::read_cfg(&mut encoded, &config_tuple)
            .expect("read_cfg failed");
        assert_eq!(decoded_map.len(), 1);
        assert_eq!(decoded_map.get(&1u32), Some(&100u64));
        assert_eq!(encoded.len(), 1);
    }
}