commonware_codec/types/
hash_set.rs

1//! Codec implementations for HashSet (requires std).
2//!
3//! For portability and consistency between architectures,
4//! the size of the set must fit within a [u32].
5
6use crate::{
7    codec::{EncodeSize, Read, Write},
8    error::Error,
9    types::read_ordered_set,
10    RangeCfg,
11};
12use bytes::{Buf, BufMut};
13use std::{collections::HashSet, hash::Hash};
14
/// Label handed to `read_ordered_set` when decoding a `HashSet`; the decode tests expect it as
/// the first field of `Error::Invalid`.
const HASHSET_TYPE: &str = "HashSet";
16
17impl<K: Ord + Hash + Eq + Write> Write for HashSet<K> {
18    fn write(&self, buf: &mut impl BufMut) {
19        self.len().write(buf);
20
21        // Sort the items to ensure deterministic encoding
22        let mut items: Vec<_> = self.iter().collect();
23        items.sort();
24        for item in items {
25            item.write(buf);
26        }
27    }
28}
29
30impl<K: Ord + Hash + Eq + EncodeSize> EncodeSize for HashSet<K> {
31    fn encode_size(&self) -> usize {
32        let mut size = self.len().encode_size();
33
34        // Note: Iteration order doesn't matter for size calculation.
35        for item in self {
36            size += item.encode_size();
37        }
38        size
39    }
40}
41
42impl<K: Read + Clone + Ord + Hash + Eq> Read for HashSet<K> {
43    type Cfg = (RangeCfg<usize>, K::Cfg);
44
45    fn read_cfg(buf: &mut impl Buf, (range, cfg): &Self::Cfg) -> Result<Self, Error> {
46        // Read and validate the length prefix
47        let len = usize::read_cfg(buf, range)?;
48        let mut set = Self::with_capacity(len);
49
50        // Read items in ascending order
51        read_ordered_set(buf, len, cfg, |item| set.insert(item), HASHSET_TYPE)?;
52
53        Ok(set)
54    }
55}
56
57#[cfg(test)]
58mod tests {
59    use super::*;
60    use crate::{
61        codec::{Decode, Encode},
62        FixedSize,
63    };
64    use bytes::{Bytes, BytesMut};
65    use std::fmt::Debug;
66
67    // Generic round trip test function for HashSet
68    fn round_trip_hash<K>(set: &HashSet<K>, range_cfg: RangeCfg<usize>, item_cfg: K::Cfg)
69    where
70        K: Write + EncodeSize + Read + Clone + Ord + Hash + Eq + Debug + PartialEq,
71        HashSet<K>: Read<Cfg = (RangeCfg<usize>, K::Cfg)>
72            + Decode<Cfg = (RangeCfg<usize>, K::Cfg)>
73            + Debug
74            + PartialEq
75            + Write
76            + EncodeSize,
77    {
78        let encoded = set.encode();
79        assert_eq!(set.encode_size(), encoded.len());
80        let config_tuple = (range_cfg, item_cfg);
81        let decoded = HashSet::<K>::decode_cfg(encoded, &config_tuple).expect("decode_cfg failed");
82        assert_eq!(set, &decoded);
83    }
84
85    // --- HashSet Tests ---
86
87    #[test]
88    fn test_empty_hashset() {
89        let set = HashSet::<u32>::new();
90        round_trip_hash(&set, (..).into(), ());
91        assert_eq!(set.encode_size(), 1); // varint 0
92        let encoded = set.encode();
93        assert_eq!(encoded, Bytes::from_static(&[0]));
94    }
95
96    #[test]
97    fn test_simple_hashset_u32() {
98        let mut set = HashSet::new();
99        set.insert(1u32);
100        set.insert(5u32);
101        set.insert(2u32);
102        round_trip_hash(&set, (..).into(), ());
103        // Size calculation: varint len + size of each item (order doesn't matter for size)
104        assert_eq!(set.encode_size(), 1 + 3 * u32::SIZE);
105        // Encoding check: items must be sorted (1, 2, 5)
106        let mut expected = BytesMut::new();
107        3usize.write(&mut expected); // Set length = 3
108        1u32.write(&mut expected);
109        2u32.write(&mut expected);
110        5u32.write(&mut expected);
111        assert_eq!(set.encode(), expected.freeze());
112    }
113
114    #[test]
115    fn test_large_hashset() {
116        // Fixed-size items
117        let set: HashSet<_> = (0..1000u16).collect();
118        round_trip_hash(&set, (1000..=1000).into(), ());
119
120        // Variable-size items
121        let set: HashSet<_> = (0..1000usize).collect();
122        round_trip_hash(&set, (1000..=1000).into(), (..=1000).into());
123    }
124
125    #[test]
126    fn test_hashset_with_variable_items() {
127        let mut set = HashSet::new();
128        set.insert(Bytes::from_static(b"apple"));
129        set.insert(Bytes::from_static(b"banana"));
130        set.insert(Bytes::from_static(b"cherry"));
131
132        let set_range = 0..=10;
133        let item_range = ..=10; // Range for Bytes length
134
135        round_trip_hash(&set, set_range.into(), item_range.into());
136    }
137
138    #[test]
139    fn test_hashset_decode_length_limit_exceeded() {
140        let mut set = HashSet::new();
141        set.insert(1u32);
142        set.insert(5u32);
143
144        let encoded = set.encode();
145        let config_tuple = ((0..=1).into(), ());
146
147        let result = HashSet::<u32>::decode_cfg(encoded, &config_tuple);
148        assert!(matches!(result, Err(Error::InvalidLength(2))));
149    }
150
151    #[test]
152    fn test_hashset_decode_item_length_limit_exceeded() {
153        let mut set = HashSet::new();
154        set.insert(Bytes::from_static(b"longitem")); // 8 bytes
155
156        let set_range = 0..=10;
157        let restrictive_item_range = ..=5; // Limit item length
158
159        let encoded = set.encode();
160        let config_tuple = (set_range.into(), restrictive_item_range.into());
161        let result = HashSet::<Bytes>::decode_cfg(encoded, &config_tuple);
162
163        assert!(matches!(result, Err(Error::InvalidLength(8))));
164    }
165
166    #[test]
167    fn test_hashset_decode_invalid_item_order() {
168        let mut encoded = BytesMut::new();
169        2usize.write(&mut encoded); // Set length = 2
170        5u32.write(&mut encoded); // Item 5
171        2u32.write(&mut encoded); // Item 2 (out of order)
172
173        let config_tuple = ((..).into(), ());
174
175        let result = HashSet::<u32>::decode_cfg(encoded, &config_tuple);
176        assert!(matches!(
177            result,
178            Err(Error::Invalid("HashSet", "Items must ascend"))
179        ));
180    }
181
182    #[test]
183    fn test_hashset_decode_duplicate_item() {
184        let mut encoded = BytesMut::new();
185        2usize.write(&mut encoded); // Set length = 2
186        1u32.write(&mut encoded); // Item 1
187        1u32.write(&mut encoded); // Duplicate Item 1
188
189        let config_tuple = ((..).into(), ());
190        let result = HashSet::<u32>::decode_cfg(encoded, &config_tuple);
191        assert!(matches!(
192            result,
193            Err(Error::Invalid("HashSet", "Duplicate item"))
194        ));
195    }
196
197    #[test]
198    fn test_hashset_decode_end_of_buffer() {
199        let mut set = HashSet::new();
200        set.insert(1u32);
201        set.insert(5u32);
202
203        let mut encoded = set.encode(); // Will be sorted: [1, 5]
204        encoded.truncate(set.encode_size() - 2); // Truncate during last item (5)
205
206        let config_tuple = ((..).into(), ());
207        let result = HashSet::<u32>::decode_cfg(encoded, &config_tuple);
208        assert!(matches!(result, Err(Error::EndOfBuffer)));
209    }
210
211    #[test]
212    fn test_hashset_decode_extra_data() {
213        let mut set = HashSet::new();
214        set.insert(1u32);
215
216        let mut encoded = set.encode_mut();
217        encoded.put_u8(0xFF); // Add extra byte
218
219        // Use decode_cfg which enforces buffer is fully consumed
220        let config_tuple = ((..).into(), ()); // Clone range for read_cfg later
221        let result = HashSet::<u32>::decode_cfg(encoded.clone(), &config_tuple);
222        assert!(matches!(result, Err(Error::ExtraData(1))));
223
224        // Verify that read_cfg would succeed (doesn't check for extra data)
225        let read_result = HashSet::<u32>::read_cfg(&mut encoded.clone(), &config_tuple);
226        assert!(read_result.is_ok());
227        let decoded_set = read_result.unwrap();
228        assert_eq!(decoded_set.len(), 1);
229        assert!(decoded_set.contains(&1u32));
230    }
231
232    #[test]
233    fn test_hashset_deterministic_encoding() {
234        let mut set1 = HashSet::new();
235        (0..1000u32).for_each(|i| {
236            set1.insert(i);
237        });
238
239        let mut set2 = HashSet::new();
240        (0..1000u32).rev().for_each(|i| {
241            set2.insert(i);
242        });
243
244        assert_eq!(set1.encode(), set2.encode());
245    }
246
247    #[test]
248    fn test_hashset_conformity() {
249        // Case 1: Empty HashSet<u8>
250        let set1 = HashSet::<u8>::new();
251        let mut expected1 = BytesMut::new();
252        0usize.write(&mut expected1); // Length 0
253        assert_eq!(set1.encode(), expected1.freeze());
254        assert_eq!(set1.encode_size(), 1);
255
256        // Case 2: Simple HashSet<u8>
257        // HashSet will sort items for encoding: 1, 2, 5
258        let mut set2 = HashSet::<u8>::new();
259        set2.insert(5u8);
260        set2.insert(1u8);
261        set2.insert(2u8);
262
263        let mut expected2 = BytesMut::new();
264        3usize.write(&mut expected2); // Length 3
265        1u8.write(&mut expected2); // Item 1
266        2u8.write(&mut expected2); // Item 2
267        5u8.write(&mut expected2); // Item 5
268        assert_eq!(set2.encode(), expected2.freeze());
269        assert_eq!(set2.encode_size(), 1 + 3);
270
271        // Case 3: HashSet<Bytes>
272        // HashSet sorts items for encoding: "apple", "banana", "cherry"
273        let mut set3 = HashSet::<Bytes>::new();
274        set3.insert(Bytes::from_static(b"cherry"));
275        set3.insert(Bytes::from_static(b"apple"));
276        set3.insert(Bytes::from_static(b"banana"));
277
278        let mut expected3 = BytesMut::new();
279        3usize.write(&mut expected3); // Length 3
280        Bytes::from_static(b"apple").write(&mut expected3);
281        Bytes::from_static(b"banana").write(&mut expected3);
282        Bytes::from_static(b"cherry").write(&mut expected3);
283        assert_eq!(set3.encode(), expected3.freeze());
284        let expected_size = 1usize.encode_size()
285            + Bytes::from_static(b"apple").encode_size()
286            + Bytes::from_static(b"banana").encode_size()
287            + Bytes::from_static(b"cherry").encode_size();
288        assert_eq!(set3.encode_size(), expected_size);
289    }
290
    // Instantiates `conformance_tests!` for `CodecConformance<HashSet<u32>>`.
    // Only compiled when the `arbitrary` feature is enabled.
    #[cfg(feature = "arbitrary")]
    mod conformance {
        use super::*;
        use crate::conformance::CodecConformance;

        commonware_conformance::conformance_tests! {
            CodecConformance<HashSet<u32>>
        }
    }
300}