// commonware_codec/types/map.rs

//! Codec implementations for HashMap.
//!
//! For portability and consistency between architectures,
//! the size of the map must fit within a [`u32`].

use crate::{
    codec::{EncodeSize, Read, Write},
    error::Error,
    Config, RangeConfig,
};
use bytes::{Buf, BufMut};
use std::{collections::HashMap, hash::Hash};
13
14// Write implementation for HashMap
15impl<K: Ord + Hash + Eq + Write, V: Write> Write for HashMap<K, V> {
16    fn write(&self, buf: &mut impl BufMut) {
17        self.len().write(buf);
18
19        // Sort the keys to ensure deterministic encoding
20        let mut keys: Vec<_> = self.keys().collect();
21        keys.sort();
22        for key in keys {
23            key.write(buf);
24            self.get(key).unwrap().write(buf);
25        }
26    }
27}
28
29// EncodeSize implementation for HashMap
30impl<K: Ord + Hash + Eq + EncodeSize, V: EncodeSize> EncodeSize for HashMap<K, V> {
31    fn encode_size(&self) -> usize {
32        // Start with the size of the length prefix
33        let mut size = self.len().encode_size();
34
35        // Add the encoded size of each key and value
36        // Note: Iteration order doesn't matter for size calculation.
37        for (key, value) in self {
38            size += key.encode_size();
39            size += value.encode_size();
40        }
41        size
42    }
43}
44
45// Read implementation for HashMap
46impl<
47        R: RangeConfig,
48        KCfg: Config,
49        VCfg: Config,
50        K: Read<KCfg> + Clone + Ord + Hash + Eq,
51        V: Read<VCfg> + Clone,
52    > Read<(R, (KCfg, VCfg))> for HashMap<K, V>
53{
54    fn read_cfg(
55        buf: &mut impl Buf,
56        (range, (k_cfg, v_cfg)): &(R, (KCfg, VCfg)),
57    ) -> Result<Self, Error> {
58        // Read and validate the length prefix
59        let len = usize::read_cfg(buf, range)?;
60        let mut map = HashMap::with_capacity(len);
61
62        // Keep track of the last key read
63        let mut last_key: Option<K> = None;
64
65        // Read each key-value pair
66        for _ in 0..len {
67            let key = K::read_cfg(buf, k_cfg)?;
68
69            // Check if keys are in ascending order relative to the previous key
70            if let Some(ref last) = last_key {
71                use std::cmp::Ordering;
72                match key.cmp(last) {
73                    Ordering::Equal => return Err(Error::Invalid("HashMap", "Duplicate key")),
74                    Ordering::Less => return Err(Error::Invalid("HashMap", "Keys must ascend")),
75                    _ => {}
76                }
77            }
78            last_key = Some(key.clone());
79
80            let value = V::read_cfg(buf, v_cfg)?;
81            map.insert(key, value);
82        }
83
84        Ok(map)
85    }
86}
87
#[cfg(test)]
mod tests {
    use crate::{
        codec::{Decode, Encode, EncodeSize, FixedSize, Read, Write},
        error::Error,
        Config, RangeConfig,
    };
    use bytes::{BufMut, Bytes, BytesMut};
    use std::collections::HashMap;
    use std::fmt::Debug;
    use std::hash::Hash;
    use std::ops::RangeInclusive;

    /// Encodes `map`, decodes it back with the supplied configs, and
    /// asserts the round trip is lossless.
    fn round_trip<K, V, R, KCfg, VCfg>(map: &HashMap<K, V>, range_cfg: R, k_cfg: KCfg, v_cfg: VCfg)
    where
        K: Write + EncodeSize + Read<KCfg> + Clone + Ord + Hash + Eq + Debug + PartialEq,
        V: Write + EncodeSize + Read<VCfg> + Clone + Debug + PartialEq,
        R: RangeConfig + Clone,
        KCfg: Config + Clone,
        VCfg: Config + Clone,
        HashMap<K, V>: Read<(R, (KCfg, VCfg))>
            + Decode<(R, (KCfg, VCfg))>
            + Debug
            + PartialEq
            + Write
            + EncodeSize,
    {
        let bytes = map.encode();
        let cfg = (range_cfg, (k_cfg, v_cfg));
        let restored = HashMap::<K, V>::decode_cfg(bytes, &cfg).expect("decode_cfg failed");
        assert_eq!(&restored, map);
    }

    /// A length range that accepts any map size.
    fn allow_any_len() -> RangeInclusive<usize> {
        0..=usize::MAX
    }

    #[test]
    fn test_empty_map() {
        let map: HashMap<u32, u64> = HashMap::new();
        round_trip(&map, allow_any_len(), (), ());
        // An empty map is nothing but a single-byte length prefix of zero.
        assert_eq!(map.encode_size(), 1);
        assert_eq!(map.encode(), 0usize.encode());
    }

    #[test]
    fn test_simple_map_u32_u64() {
        let map = HashMap::from([(1u32, 100u64), (5u32, 500u64), (2u32, 200u64)]);
        round_trip(&map, allow_any_len(), (), ());
        // One length byte plus three fixed-size key/value pairs.
        assert_eq!(map.encode_size(), 1 + 3 * (u32::SIZE + u64::SIZE));
    }

    #[test]
    fn test_large_map() {
        let map: HashMap<_, _> = (0..1000).map(|i| (i, i as u64 * 2)).collect();
        round_trip(&map, 0..=1000, (), ());
    }

    #[test]
    fn test_map_with_variable_values() {
        let mut map = HashMap::new();
        map.insert(Bytes::from_static(b"apple"), vec![1, 2]);
        map.insert(Bytes::from_static(b"banana"), vec![3, 4, 5]);
        map.insert(Bytes::from_static(b"cherry"), vec![]);

        // map length range, key length range, and (value length range, ())
        round_trip(&map, 0..=10, ..=10, (0..=100, ()));
    }

    #[test]
    fn test_decode_length_limit_exceeded() {
        let map = HashMap::from([(1u32, 100u64), (5u32, 500u64)]);

        // A range permitting at most one entry must reject a 2-entry map.
        let cfg = (0..=1, ((), ()));
        let result = HashMap::<u32, u64>::decode_cfg(map.encode(), &cfg);
        assert!(matches!(result, Err(Error::InvalidLength(2))));
    }

    #[test]
    fn test_decode_value_length_limit_exceeded() {
        let mut map = HashMap::new();
        map.insert(Bytes::from_static(b"key1"), vec![1u8, 2, 3, 4, 5]);

        // Values may hold at most 3 bytes, but the encoded value has 5.
        let cfg = (0..=10, (..=10, (0..=3, ())));
        let result = HashMap::<Bytes, Vec<u8>>::decode_cfg(map.encode(), &cfg);
        assert!(matches!(result, Err(Error::InvalidLength(5))));
    }

    #[test]
    fn test_decode_invalid_key_order() {
        // Hand-build an encoding whose keys descend: 5 before 2.
        let mut raw = BytesMut::new();
        2usize.write(&mut raw); // map length = 2
        5u32.write(&mut raw); // first key
        500u64.write(&mut raw); // first value
        2u32.write(&mut raw); // second key descends: invalid
        200u64.write(&mut raw); // second value

        let cfg = (allow_any_len(), ((), ()));
        let result = HashMap::<u32, u64>::decode_cfg(raw, &cfg);
        assert!(matches!(
            result,
            Err(Error::Invalid("HashMap", "Keys must ascend"))
        ));
    }

    #[test]
    fn test_decode_duplicate_key() {
        // Hand-build an encoding that repeats key 1.
        let mut raw = BytesMut::new();
        2usize.write(&mut raw); // map length = 2
        1u32.write(&mut raw); // key 1
        100u64.write(&mut raw); // value 100
        1u32.write(&mut raw); // key 1 again: invalid
        200u64.write(&mut raw); // value 200

        let cfg = (allow_any_len(), ((), ()));
        let result = HashMap::<u32, u64>::decode_cfg(raw, &cfg);
        assert!(matches!(
            result,
            Err(Error::Invalid("HashMap", "Duplicate key"))
        ));
    }

    #[test]
    fn test_decode_end_of_buffer_key() {
        let map = HashMap::from([(1u32, 100u64), (5u32, 500u64)]);

        // Cut the buffer in the middle of the last key.
        let mut truncated = map.encode();
        truncated.truncate(map.encode_size() - 10);

        let cfg = (allow_any_len(), ((), ()));
        let result = HashMap::<u32, u64>::decode_cfg(truncated, &cfg);
        assert!(matches!(result, Err(Error::EndOfBuffer)));
    }

    #[test]
    fn test_decode_end_of_buffer_value() {
        let map = HashMap::from([(1u32, 100u64), (5u32, 500u64)]);

        // Cut the buffer in the middle of the last value.
        let mut truncated = map.encode();
        truncated.truncate(map.encode_size() - 4);

        let cfg = (allow_any_len(), ((), ()));
        let result = HashMap::<u32, u64>::decode_cfg(truncated, &cfg);
        assert!(matches!(result, Err(Error::EndOfBuffer)));
    }

    #[test]
    fn test_decode_extra_data() {
        let map = HashMap::from([(1u32, 100u64)]);

        let mut buf = map.encode();
        buf.put_u8(0xFF); // trailing garbage byte

        let cfg = (allow_any_len(), ((), ()));

        // decode_cfg requires the buffer be fully consumed, so the extra
        // byte is an error.
        let result = HashMap::<u32, u64>::decode_cfg(buf.clone(), &cfg);
        assert!(matches!(result, Err(Error::ExtraData(1))));

        // read_cfg stops after the map and ignores whatever follows.
        let decoded = HashMap::<u32, u64>::read_cfg(&mut buf, &cfg).expect("read_cfg failed");
        assert_eq!(decoded.len(), 1);
        assert_eq!(decoded.get(&1u32), Some(&100u64));
    }
}