commonware_codec/types/
hash_set.rs

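//! Codec implementations for [HashSet] (requires std).
//!
//! For portability and consistency between architectures,
//! the number of items in the set (its encoded length prefix) must fit within a [u32].
//! Items are encoded in ascending order so that equal sets always produce identical bytes.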
5
6use crate::{
7    codec::{EncodeSize, Read, Write},
8    error::Error,
9    RangeCfg,
10};
11use bytes::{Buf, BufMut};
12use std::{cmp::Ordering, collections::HashSet, hash::Hash};
13
14const HASHSET_TYPE: &str = "HashSet";
15
16/// Read items from [Buf] in ascending order.
17fn read_ordered_set<K, F>(
18    buf: &mut impl Buf,
19    len: usize,
20    cfg: &K::Cfg,
21    mut insert: F,
22    set_type: &'static str,
23) -> Result<(), Error>
24where
25    K: Read + Ord,
26    F: FnMut(K) -> bool,
27{
28    let mut last: Option<K> = None;
29    for _ in 0..len {
30        // Read item
31        let item = K::read_cfg(buf, cfg)?;
32
33        // Check if items are in ascending order
34        if let Some(ref last) = last {
35            match item.cmp(last) {
36                Ordering::Equal => return Err(Error::Invalid(set_type, "Duplicate item")),
37                Ordering::Less => return Err(Error::Invalid(set_type, "Items must ascend")),
38                _ => {}
39            }
40        }
41
42        // Add previous item, if exists
43        if let Some(last) = last.take() {
44            insert(last);
45        }
46        last = Some(item);
47    }
48
49    // Add last item, if exists
50    if let Some(last) = last {
51        insert(last);
52    }
53
54    Ok(())
55}
56
57impl<K: Ord + Hash + Eq + Write> Write for HashSet<K> {
58    fn write(&self, buf: &mut impl BufMut) {
59        self.len().write(buf);
60
61        // Sort the items to ensure deterministic encoding
62        let mut items: Vec<_> = self.iter().collect();
63        items.sort();
64        for item in items {
65            item.write(buf);
66        }
67    }
68}
69
70impl<K: Ord + Hash + Eq + EncodeSize> EncodeSize for HashSet<K> {
71    fn encode_size(&self) -> usize {
72        let mut size = self.len().encode_size();
73
74        // Note: Iteration order doesn't matter for size calculation.
75        for item in self {
76            size += item.encode_size();
77        }
78        size
79    }
80}
81
82impl<K: Read + Clone + Ord + Hash + Eq> Read for HashSet<K> {
83    type Cfg = (RangeCfg, K::Cfg);
84
85    fn read_cfg(buf: &mut impl Buf, (range, cfg): &Self::Cfg) -> Result<Self, Error> {
86        // Read and validate the length prefix
87        let len = usize::read_cfg(buf, range)?;
88        let mut set = HashSet::with_capacity(len);
89
90        // Read items in ascending order
91        read_ordered_set(buf, len, cfg, |item| set.insert(item), HASHSET_TYPE)?;
92
93        Ok(set)
94    }
95}
96
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        codec::{Decode, Encode},
        FixedSize,
    };
    use bytes::{Bytes, BytesMut};
    use std::fmt::Debug;

    /// Encodes `set`, checks `encode_size` matches the encoded length, decodes it with
    /// `(range_cfg, item_cfg)`, and checks the decoded set equals the original.
    fn round_trip_hash<K>(set: &HashSet<K>, range_cfg: RangeCfg, item_cfg: K::Cfg)
    where
        K: Write + EncodeSize + Read + Clone + Ord + Hash + Eq + Debug + PartialEq,
        HashSet<K>: Read<Cfg = (RangeCfg, K::Cfg)>
            + Decode<Cfg = (RangeCfg, K::Cfg)>
            + Debug
            + PartialEq
            + Write
            + EncodeSize,
    {
        let encoded = set.encode();
        assert_eq!(set.encode_size(), encoded.len());
        let config_tuple = (range_cfg, item_cfg);
        let decoded = HashSet::<K>::decode_cfg(encoded, &config_tuple).expect("decode_cfg failed");
        assert_eq!(set, &decoded);
    }

    // --- HashSet Tests ---

    #[test]
    fn test_empty_hashset() {
        let set = HashSet::<u32>::new();
        round_trip_hash(&set, (..).into(), ());
        assert_eq!(set.encode_size(), 1); // varint 0
        let encoded = set.encode();
        assert_eq!(encoded, Bytes::from_static(&[0]));
    }

    #[test]
    fn test_simple_hashset_u32() {
        let mut set = HashSet::new();
        set.insert(1u32);
        set.insert(5u32);
        set.insert(2u32);
        round_trip_hash(&set, (..).into(), ());
        // Size calculation: varint len + size of each item (order doesn't matter for size)
        assert_eq!(set.encode_size(), 1 + 3 * u32::SIZE);
        // Encoding check: items must be sorted (1, 2, 5)
        let mut expected = BytesMut::new();
        3usize.write(&mut expected); // Set length = 3
        1u32.write(&mut expected);
        2u32.write(&mut expected);
        5u32.write(&mut expected);
        assert_eq!(set.encode(), expected.freeze());
    }

    #[test]
    fn test_large_hashset() {
        // Fixed-size items
        let set: HashSet<_> = (0..1000u16).collect();
        round_trip_hash(&set, (1000..=1000).into(), ());

        // Variable-size items
        let set: HashSet<_> = (0..1000usize).collect();
        round_trip_hash(&set, (1000..=1000).into(), (..=1000).into());
    }

    #[test]
    fn test_hashset_with_variable_items() {
        let mut set = HashSet::new();
        set.insert(Bytes::from_static(b"apple"));
        set.insert(Bytes::from_static(b"banana"));
        set.insert(Bytes::from_static(b"cherry"));

        let set_range = 0..=10;
        let item_range = ..=10; // Range for Bytes length

        round_trip_hash(&set, set_range.into(), item_range.into());
    }

    #[test]
    fn test_hashset_decode_length_limit_exceeded() {
        let mut set = HashSet::new();
        set.insert(1u32);
        set.insert(5u32);

        let encoded = set.encode();
        let config_tuple = ((0..=1).into(), ());

        let result = HashSet::<u32>::decode_cfg(encoded, &config_tuple);
        assert!(matches!(result, Err(Error::InvalidLength(2))));
    }

    #[test]
    fn test_hashset_decode_item_length_limit_exceeded() {
        let mut set = HashSet::new();
        set.insert(Bytes::from_static(b"longitem")); // 8 bytes

        let set_range = 0..=10;
        let restrictive_item_range = ..=5; // Limit item length

        let encoded = set.encode();
        let config_tuple = (set_range.into(), restrictive_item_range.into());
        let result = HashSet::<Bytes>::decode_cfg(encoded, &config_tuple);

        assert!(matches!(result, Err(Error::InvalidLength(8))));
    }

    #[test]
    fn test_hashset_decode_invalid_item_order() {
        let mut encoded = BytesMut::new();
        2usize.write(&mut encoded); // Set length = 2
        5u32.write(&mut encoded); // Item 5
        2u32.write(&mut encoded); // Item 2 (out of order)

        let config_tuple = ((..).into(), ());

        let result = HashSet::<u32>::decode_cfg(encoded, &config_tuple);
        assert!(matches!(
            result,
            Err(Error::Invalid("HashSet", "Items must ascend"))
        ));
    }

    #[test]
    fn test_hashset_decode_out_of_order_after_valid_prefix() {
        let mut encoded = BytesMut::new();
        3usize.write(&mut encoded); // Set length = 3
        1u32.write(&mut encoded); // Item 1
        3u32.write(&mut encoded); // Item 3
        2u32.write(&mut encoded); // Item 2 (out of order, after a valid ascending prefix)

        let config_tuple = ((..).into(), ());

        let result = HashSet::<u32>::decode_cfg(encoded, &config_tuple);
        assert!(matches!(
            result,
            Err(Error::Invalid("HashSet", "Items must ascend"))
        ));
    }

    #[test]
    fn test_hashset_decode_duplicate_item() {
        let mut encoded = BytesMut::new();
        2usize.write(&mut encoded); // Set length = 2
        1u32.write(&mut encoded); // Item 1
        1u32.write(&mut encoded); // Duplicate Item 1

        let config_tuple = ((..).into(), ());
        let result = HashSet::<u32>::decode_cfg(encoded, &config_tuple);
        assert!(matches!(
            result,
            Err(Error::Invalid("HashSet", "Duplicate item"))
        ));
    }

    #[test]
    fn test_hashset_decode_end_of_buffer() {
        let mut set = HashSet::new();
        set.insert(1u32);
        set.insert(5u32);

        let mut encoded = set.encode(); // Will be sorted: [1, 5]
        encoded.truncate(set.encode_size() - 2); // Truncate during last item (5)

        let config_tuple = ((..).into(), ());
        let result = HashSet::<u32>::decode_cfg(encoded, &config_tuple);
        assert!(matches!(result, Err(Error::EndOfBuffer)));
    }

    #[test]
    fn test_hashset_decode_extra_data() {
        let mut set = HashSet::new();
        set.insert(1u32);

        let mut encoded = set.encode();
        encoded.put_u8(0xFF); // Add extra byte

        // decode_cfg enforces that the buffer is fully consumed
        let config_tuple = ((..).into(), ());
        let result = HashSet::<u32>::decode_cfg(encoded.clone(), &config_tuple);
        assert!(matches!(result, Err(Error::ExtraData(1))));

        // read_cfg succeeds because it doesn't check for extra data
        let read_result = HashSet::<u32>::read_cfg(&mut encoded.clone(), &config_tuple);
        assert!(read_result.is_ok());
        let decoded_set = read_result.unwrap();
        assert_eq!(decoded_set.len(), 1);
        assert!(decoded_set.contains(&1u32));
    }

    #[test]
    fn test_hashset_deterministic_encoding() {
        let mut set1 = HashSet::new();
        (0..1000u32).for_each(|i| {
            set1.insert(i);
        });

        let mut set2 = HashSet::new();
        (0..1000u32).rev().for_each(|i| {
            set2.insert(i);
        });

        assert_eq!(set1.encode(), set2.encode());
    }

    #[test]
    fn test_hashset_conformity() {
        // Case 1: Empty HashSet<u8>
        let set1 = HashSet::<u8>::new();
        let mut expected1 = BytesMut::new();
        0usize.write(&mut expected1); // Length 0
        assert_eq!(set1.encode(), expected1.freeze());
        assert_eq!(set1.encode_size(), 1);

        // Case 2: Simple HashSet<u8>
        // HashSet will sort items for encoding: 1, 2, 5
        let mut set2 = HashSet::<u8>::new();
        set2.insert(5u8);
        set2.insert(1u8);
        set2.insert(2u8);

        let mut expected2 = BytesMut::new();
        3usize.write(&mut expected2); // Length 3
        1u8.write(&mut expected2); // Item 1
        2u8.write(&mut expected2); // Item 2
        5u8.write(&mut expected2); // Item 5
        assert_eq!(set2.encode(), expected2.freeze());
        assert_eq!(set2.encode_size(), 1 + 3 * u8::SIZE);

        // Case 3: HashSet<Bytes>
        // HashSet sorts items for encoding: "apple", "banana", "cherry"
        let mut set3 = HashSet::<Bytes>::new();
        set3.insert(Bytes::from_static(b"cherry"));
        set3.insert(Bytes::from_static(b"apple"));
        set3.insert(Bytes::from_static(b"banana"));

        let mut expected3 = BytesMut::new();
        3usize.write(&mut expected3); // Length 3
        Bytes::from_static(b"apple").write(&mut expected3);
        Bytes::from_static(b"banana").write(&mut expected3);
        Bytes::from_static(b"cherry").write(&mut expected3);
        assert_eq!(set3.encode(), expected3.freeze());
        let expected_size = 3usize.encode_size() // Length prefix for 3 items
            + Bytes::from_static(b"apple").encode_size()
            + Bytes::from_static(b"banana").encode_size()
            + Bytes::from_static(b"cherry").encode_size();
        assert_eq!(set3.encode_size(), expected_size);
    }
}