Skip to main content

commonware_codec/types/
hash_set.rs

//! Codec implementations for [HashSet] (requires `std`).
//!
//! For portability and consistency between architectures,
//! the number of items in the set must fit within a [u32].

6use crate::{
7    codec::{BufsMut, EncodeSize, Read, Write},
8    error::Error,
9    types::read_ordered_set,
10    RangeCfg,
11};
12use bytes::{Buf, BufMut};
13use std::{collections::HashSet, hash::Hash};
14
/// Type name passed to [`read_ordered_set`] and reported in decode errors
/// (e.g. `Error::Invalid("HashSet", "Duplicate item")`).
const HASHSET_TYPE: &str = "HashSet";
16
17impl<K: Ord + Hash + Eq + Write> Write for HashSet<K> {
18    fn write(&self, buf: &mut impl BufMut) {
19        self.len().write(buf);
20
21        // Sort the items to ensure deterministic encoding
22        let mut items: Vec<_> = self.iter().collect();
23        items.sort();
24        for item in items {
25            item.write(buf);
26        }
27    }
28
29    fn write_bufs(&self, buf: &mut impl BufsMut) {
30        self.len().write(buf);
31
32        // Sort the items to ensure deterministic encoding
33        let mut items: Vec<_> = self.iter().collect();
34        items.sort();
35        for item in items {
36            item.write_bufs(buf);
37        }
38    }
39}
40
41impl<K: Ord + Hash + Eq + EncodeSize> EncodeSize for HashSet<K> {
42    fn encode_size(&self) -> usize {
43        let mut size = self.len().encode_size();
44
45        // Note: Iteration order doesn't matter for size calculation.
46        for item in self {
47            size += item.encode_size();
48        }
49        size
50    }
51
52    fn encode_inline_size(&self) -> usize {
53        let mut size = self.len().encode_size();
54
55        // Note: Iteration order doesn't matter for size calculation.
56        for item in self {
57            size += item.encode_inline_size();
58        }
59        size
60    }
61}
62
63impl<K: Read + Clone + Ord + Hash + Eq> Read for HashSet<K> {
64    type Cfg = (RangeCfg<usize>, K::Cfg);
65
66    fn read_cfg(buf: &mut impl Buf, (range, cfg): &Self::Cfg) -> Result<Self, Error> {
67        // Read and validate the length prefix
68        let len = usize::read_cfg(buf, range)?;
69        let mut set = Self::with_capacity(len);
70
71        // Read items in ascending order
72        read_ordered_set(buf, len, cfg, |item| set.insert(item), HASHSET_TYPE)?;
73
74        Ok(set)
75    }
76}
77
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        codec::{Decode, Encode},
        FixedSize,
    };
    use bytes::{Bytes, BytesMut};
    use std::fmt::Debug;

    /// Encodes `set`, checks `encode_size` matches the actual encoding length, then
    /// decodes with `(range_cfg, item_cfg)` and checks the result equals `set`.
    fn round_trip_hash<K>(set: &HashSet<K>, range_cfg: RangeCfg<usize>, item_cfg: K::Cfg)
    where
        K: Write + EncodeSize + Read + Clone + Ord + Hash + Eq + Debug + PartialEq,
        HashSet<K>: Read<Cfg = (RangeCfg<usize>, K::Cfg)>
            + Decode<Cfg = (RangeCfg<usize>, K::Cfg)>
            + Debug
            + PartialEq
            + Write
            + EncodeSize,
    {
        let encoded = set.encode();
        assert_eq!(set.encode_size(), encoded.len());
        let config_tuple = (range_cfg, item_cfg);
        let decoded = HashSet::<K>::decode_cfg(encoded, &config_tuple).expect("decode_cfg failed");
        assert_eq!(set, &decoded);
    }

    // --- HashSet Tests ---

    #[test]
    fn test_empty_hashset() {
        let set = HashSet::<u32>::new();
        round_trip_hash(&set, (..).into(), ());
        assert_eq!(set.encode_size(), 1); // varint 0
        let encoded = set.encode();
        assert_eq!(encoded, Bytes::from_static(&[0]));
    }

    #[test]
    fn test_simple_hashset_u32() {
        let mut set = HashSet::new();
        set.insert(1u32);
        set.insert(5u32);
        set.insert(2u32);
        round_trip_hash(&set, (..).into(), ());
        // Size calculation: varint len + size of each item (order doesn't matter for size)
        assert_eq!(set.encode_size(), 1 + 3 * u32::SIZE);
        // Encoding check: items must be sorted (1, 2, 5)
        let mut expected = BytesMut::new();
        3usize.write(&mut expected); // Set length = 3
        1u32.write(&mut expected);
        2u32.write(&mut expected);
        5u32.write(&mut expected);
        assert_eq!(set.encode(), expected.freeze());
    }

    #[test]
    fn test_large_hashset() {
        // Fixed-size items
        let set: HashSet<_> = (0..1000u16).collect();
        round_trip_hash(&set, (1000..=1000).into(), ());

        // Variable-size items
        let set: HashSet<_> = (0..1000usize).collect();
        round_trip_hash(&set, (1000..=1000).into(), (..=1000).into());
    }

    #[test]
    fn test_hashset_with_variable_items() {
        let mut set = HashSet::new();
        set.insert(Bytes::from_static(b"apple"));
        set.insert(Bytes::from_static(b"banana"));
        set.insert(Bytes::from_static(b"cherry"));

        let set_range = 0..=10;
        let item_range = ..=10; // Range for Bytes length

        round_trip_hash(&set, set_range.into(), item_range.into());
    }

    #[test]
    fn test_hashset_decode_length_limit_exceeded() {
        let mut set = HashSet::new();
        set.insert(1u32);
        set.insert(5u32);

        let encoded = set.encode();
        let config_tuple = ((0..=1).into(), ());

        let result = HashSet::<u32>::decode_cfg(encoded, &config_tuple);
        assert!(matches!(result, Err(Error::InvalidLength(2))));
    }

    #[test]
    fn test_hashset_decode_item_length_limit_exceeded() {
        let mut set = HashSet::new();
        set.insert(Bytes::from_static(b"longitem")); // 8 bytes

        let set_range = 0..=10;
        let restrictive_item_range = ..=5; // Limit item length

        let encoded = set.encode();
        let config_tuple = (set_range.into(), restrictive_item_range.into());
        let result = HashSet::<Bytes>::decode_cfg(encoded, &config_tuple);

        assert!(matches!(result, Err(Error::InvalidLength(8))));
    }

    #[test]
    fn test_hashset_decode_invalid_item_order() {
        let mut encoded = BytesMut::new();
        2usize.write(&mut encoded); // Set length = 2
        5u32.write(&mut encoded); // Item 5
        2u32.write(&mut encoded); // Item 2 (out of order)

        let config_tuple = ((..).into(), ());

        let result = HashSet::<u32>::decode_cfg(encoded, &config_tuple);
        assert!(matches!(
            result,
            Err(Error::Invalid("HashSet", "Items must ascend"))
        ));
    }

    #[test]
    fn test_hashset_decode_duplicate_item() {
        let mut encoded = BytesMut::new();
        2usize.write(&mut encoded); // Set length = 2
        1u32.write(&mut encoded); // Item 1
        1u32.write(&mut encoded); // Duplicate Item 1

        let config_tuple = ((..).into(), ());
        let result = HashSet::<u32>::decode_cfg(encoded, &config_tuple);
        assert!(matches!(
            result,
            Err(Error::Invalid("HashSet", "Duplicate item"))
        ));
    }

    #[test]
    fn test_hashset_decode_end_of_buffer() {
        let mut set = HashSet::new();
        set.insert(1u32);
        set.insert(5u32);

        let mut encoded = set.encode(); // Will be sorted: [1, 5]
        encoded.truncate(set.encode_size() - 2); // Truncate during last item (5)

        let config_tuple = ((..).into(), ());
        let result = HashSet::<u32>::decode_cfg(encoded, &config_tuple);
        assert!(matches!(result, Err(Error::EndOfBuffer)));
    }

    #[test]
    fn test_hashset_decode_extra_data() {
        let mut set = HashSet::new();
        set.insert(1u32);

        let mut encoded = set.encode_mut();
        encoded.put_u8(0xFF); // Add extra byte

        // Same config is shared by the decode_cfg and read_cfg checks below
        let config_tuple = ((..).into(), ());

        // decode_cfg enforces that the buffer is fully consumed
        let result = HashSet::<u32>::decode_cfg(encoded.clone(), &config_tuple);
        assert!(matches!(result, Err(Error::ExtraData(1))));

        // read_cfg does not check for trailing data, so it succeeds
        let decoded_set = HashSet::<u32>::read_cfg(&mut encoded.clone(), &config_tuple)
            .expect("read_cfg should ignore trailing data");
        assert_eq!(decoded_set.len(), 1);
        assert!(decoded_set.contains(&1u32));
    }

    #[test]
    fn test_hashset_deterministic_encoding() {
        let mut set1 = HashSet::new();
        (0..1000u32).for_each(|i| {
            set1.insert(i);
        });

        let mut set2 = HashSet::new();
        (0..1000u32).rev().for_each(|i| {
            set2.insert(i);
        });

        assert_eq!(set1.encode(), set2.encode());
    }

    #[test]
    fn test_hashset_conformity() {
        // Case 1: Empty HashSet<u8>
        let set1 = HashSet::<u8>::new();
        let mut expected1 = BytesMut::new();
        0usize.write(&mut expected1); // Length 0
        assert_eq!(set1.encode(), expected1.freeze());
        assert_eq!(set1.encode_size(), 1);

        // Case 2: Simple HashSet<u8>
        // HashSet will sort items for encoding: 1, 2, 5
        let mut set2 = HashSet::<u8>::new();
        set2.insert(5u8);
        set2.insert(1u8);
        set2.insert(2u8);

        let mut expected2 = BytesMut::new();
        3usize.write(&mut expected2); // Length 3
        1u8.write(&mut expected2); // Item 1
        2u8.write(&mut expected2); // Item 2
        5u8.write(&mut expected2); // Item 5
        assert_eq!(set2.encode(), expected2.freeze());
        assert_eq!(set2.encode_size(), 1 + 3);

        // Case 3: HashSet<Bytes>
        // HashSet sorts items for encoding: "apple", "banana", "cherry"
        let mut set3 = HashSet::<Bytes>::new();
        set3.insert(Bytes::from_static(b"cherry"));
        set3.insert(Bytes::from_static(b"apple"));
        set3.insert(Bytes::from_static(b"banana"));

        let mut expected3 = BytesMut::new();
        3usize.write(&mut expected3); // Length 3
        Bytes::from_static(b"apple").write(&mut expected3);
        Bytes::from_static(b"banana").write(&mut expected3);
        Bytes::from_static(b"cherry").write(&mut expected3);
        assert_eq!(set3.encode(), expected3.freeze());
        let expected_size = 3usize.encode_size() // Length prefix for 3 items
            + Bytes::from_static(b"apple").encode_size()
            + Bytes::from_static(b"banana").encode_size()
            + Bytes::from_static(b"cherry").encode_size();
        assert_eq!(set3.encode_size(), expected_size);
    }

    #[cfg(feature = "arbitrary")]
    mod conformance {
        use super::*;
        use crate::conformance::CodecConformance;

        commonware_conformance::conformance_tests! {
            CodecConformance<HashSet<u32>>
        }
    }
}