commonware_codec/types/
set.rs

1//! Codec implementations for various set types.
2//!
3//! For portability and consistency between architectures,
4//! the size of the set must fit within a [u32].
5
6use crate::{
7    codec::{EncodeSize, Read, Write},
8    error::Error,
9    RangeCfg,
10};
11use bytes::{Buf, BufMut};
12use std::{
13    cmp::Ordering,
14    collections::{BTreeSet, HashSet},
15    hash::Hash,
16};
17
18const BTREESET_TYPE: &str = "BTreeSet";
19const HASHSET_TYPE: &str = "HashSet";
20
21/// Read items from [Buf] in ascending order.
22fn read_ordered_set<K, F>(
23    buf: &mut impl Buf,
24    len: usize,
25    cfg: &K::Cfg,
26    mut insert: F,
27    set_type: &'static str,
28) -> Result<(), Error>
29where
30    K: Read + Ord,
31    F: FnMut(K) -> bool,
32{
33    let mut last: Option<K> = None;
34    for _ in 0..len {
35        // Read item
36        let item = K::read_cfg(buf, cfg)?;
37
38        // Check if items are in ascending order
39        if let Some(ref last) = last {
40            match item.cmp(last) {
41                Ordering::Equal => return Err(Error::Invalid(set_type, "Duplicate item")),
42                Ordering::Less => return Err(Error::Invalid(set_type, "Items must ascend")),
43                _ => {}
44            }
45        }
46
47        // Add previous item, if exists
48        if let Some(last) = last.take() {
49            insert(last);
50        }
51        last = Some(item);
52    }
53
54    // Add last item, if exists
55    if let Some(last) = last {
56        insert(last);
57    }
58
59    Ok(())
60}
61
62// ---------- BTreeSet ----------
63
64impl<K: Ord + Hash + Eq + Write> Write for BTreeSet<K> {
65    fn write(&self, buf: &mut impl BufMut) {
66        self.len().write(buf);
67
68        // Items are already sorted in BTreeSet, so we can iterate directly
69        for item in self {
70            item.write(buf);
71        }
72    }
73}
74
75impl<K: Ord + Hash + Eq + EncodeSize> EncodeSize for BTreeSet<K> {
76    fn encode_size(&self) -> usize {
77        let mut size = self.len().encode_size();
78        for item in self {
79            size += item.encode_size();
80        }
81        size
82    }
83}
84
85impl<K: Read + Clone + Ord + Hash + Eq> Read for BTreeSet<K> {
86    type Cfg = (RangeCfg, K::Cfg);
87
88    fn read_cfg(buf: &mut impl Buf, (range, cfg): &Self::Cfg) -> Result<Self, Error> {
89        // Read and validate the length prefix
90        let len = usize::read_cfg(buf, range)?;
91        let mut set = BTreeSet::new();
92
93        // Read items in ascending order
94        read_ordered_set(buf, len, cfg, |item| set.insert(item), BTREESET_TYPE)?;
95
96        Ok(set)
97    }
98}
99
100// ---------- HashSet ----------
101
102impl<K: Ord + Hash + Eq + Write> Write for HashSet<K> {
103    fn write(&self, buf: &mut impl BufMut) {
104        self.len().write(buf);
105
106        // Sort the items to ensure deterministic encoding
107        let mut items: Vec<_> = self.iter().collect();
108        items.sort();
109        for item in items {
110            item.write(buf);
111        }
112    }
113}
114
115impl<K: Ord + Hash + Eq + EncodeSize> EncodeSize for HashSet<K> {
116    fn encode_size(&self) -> usize {
117        let mut size = self.len().encode_size();
118
119        // Note: Iteration order doesn't matter for size calculation.
120        for item in self {
121            size += item.encode_size();
122        }
123        size
124    }
125}
126
127impl<K: Read + Clone + Ord + Hash + Eq> Read for HashSet<K> {
128    type Cfg = (RangeCfg, K::Cfg);
129
130    fn read_cfg(buf: &mut impl Buf, (range, cfg): &Self::Cfg) -> Result<Self, Error> {
131        // Read and validate the length prefix
132        let len = usize::read_cfg(buf, range)?;
133        let mut set = HashSet::with_capacity(len);
134
135        // Read items in ascending order
136        read_ordered_set(buf, len, cfg, |item| set.insert(item), HASHSET_TYPE)?;
137
138        Ok(set)
139    }
140}
141
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        codec::{Decode, Encode},
        FixedSize,
    };
    use bytes::{Bytes, BytesMut};
    use std::{
        collections::{BTreeSet, HashSet},
        fmt::Debug,
    };

    /// Round-trips a [BTreeSet]: encodes `set`, checks that `encode_size` matches the
    /// encoded length, decodes with `(range_cfg, item_cfg)` and checks equality.
    fn round_trip_btree<K>(set: &BTreeSet<K>, range_cfg: RangeCfg, item_cfg: K::Cfg)
    where
        K: Write + EncodeSize + Read + Clone + Ord + Hash + Eq + Debug + PartialEq,
        BTreeSet<K>: Read<Cfg = (RangeCfg, K::Cfg)>
            + Decode<Cfg = (RangeCfg, K::Cfg)>
            + Debug
            + PartialEq
            + Write
            + EncodeSize,
    {
        let encoded = set.encode();
        assert_eq!(set.encode_size(), encoded.len());
        let config_tuple = (range_cfg, item_cfg);
        let decoded = BTreeSet::<K>::decode_cfg(encoded, &config_tuple).expect("decode_cfg failed");
        assert_eq!(set, &decoded);
    }

    /// Round-trips a [HashSet]: encodes `set`, checks that `encode_size` matches the
    /// encoded length, decodes with `(range_cfg, item_cfg)` and checks equality.
    fn round_trip_hash<K>(set: &HashSet<K>, range_cfg: RangeCfg, item_cfg: K::Cfg)
    where
        K: Write + EncodeSize + Read + Clone + Ord + Hash + Eq + Debug + PartialEq,
        HashSet<K>: Read<Cfg = (RangeCfg, K::Cfg)>
            + Decode<Cfg = (RangeCfg, K::Cfg)>
            + Debug
            + PartialEq
            + Write
            + EncodeSize,
    {
        let encoded = set.encode();
        assert_eq!(set.encode_size(), encoded.len());
        let config_tuple = (range_cfg, item_cfg);
        let decoded = HashSet::<K>::decode_cfg(encoded, &config_tuple).expect("decode_cfg failed");
        assert_eq!(set, &decoded);
    }

    // --- BTreeSet Tests ---

    #[test]
    fn test_empty_btreeset() {
        let set = BTreeSet::<u32>::new();
        round_trip_btree(&set, (..).into(), ());
        assert_eq!(set.encode_size(), 1); // varint 0
        let encoded = set.encode();
        assert_eq!(encoded, Bytes::from_static(&[0]));
    }

    #[test]
    fn test_simple_btreeset_u32() {
        let mut set = BTreeSet::new();
        set.insert(1u32);
        set.insert(5u32);
        set.insert(2u32);
        round_trip_btree(&set, (..).into(), ());
        // Size calculation: varint len + size of each item
        assert_eq!(set.encode_size(), 1 + 3 * u32::SIZE);
    }

    #[test]
    fn test_large_btreeset() {
        // Fixed-size items
        let set: BTreeSet<_> = (0..1000u16).collect();
        round_trip_btree(&set, (1000..=1000).into(), ());

        // Variable-size items
        let set: BTreeSet<_> = (0..1000usize).collect();
        round_trip_btree(&set, (1000..=1000).into(), (..=1000).into());
    }

    #[test]
    fn test_btreeset_with_variable_items() {
        let mut set = BTreeSet::new();
        set.insert(Bytes::from_static(b"apple"));
        set.insert(Bytes::from_static(b"banana"));
        set.insert(Bytes::from_static(b"cherry"));

        let set_range = 0..=10;
        let item_range = ..=10; // Range for Bytes length

        round_trip_btree(&set, set_range.into(), item_range.into());
    }

    #[test]
    fn test_btreeset_decode_length_limit_exceeded() {
        let mut set = BTreeSet::new();
        set.insert(1u32);
        set.insert(5u32);
        let encoded = set.encode();

        // Only 0 or 1 items are allowed, but 2 were encoded
        let config_tuple = ((0..=1).into(), ());
        let result = BTreeSet::<u32>::decode_cfg(encoded, &config_tuple);
        assert!(matches!(result, Err(Error::InvalidLength(2))));
    }

    #[test]
    fn test_btreeset_decode_item_length_limit_exceeded() {
        let mut set = BTreeSet::new();
        set.insert(Bytes::from_static(b"longitem")); // 8 bytes
        let encoded = set.encode();

        let set_range = 0..=10;
        let restrictive_item_range = ..=5; // Limit item length
        let config_tuple = (set_range.into(), restrictive_item_range.into());
        let result = BTreeSet::<Bytes>::decode_cfg(encoded, &config_tuple);

        assert!(matches!(result, Err(Error::InvalidLength(8))));
    }

    #[test]
    fn test_btreeset_decode_invalid_item_order() {
        let mut encoded = BytesMut::new();
        2usize.write(&mut encoded); // Set length = 2
        5u32.write(&mut encoded); // Item 5
        2u32.write(&mut encoded); // Item 2 (out of order)

        let config_tuple = ((..).into(), ());
        let result = BTreeSet::<u32>::decode_cfg(encoded, &config_tuple);
        assert!(matches!(
            result,
            Err(Error::Invalid("BTreeSet", "Items must ascend"))
        ));
    }

    #[test]
    fn test_btreeset_decode_duplicate_item() {
        let mut encoded = BytesMut::new();
        2usize.write(&mut encoded); // Set length = 2
        1u32.write(&mut encoded); // Item 1
        1u32.write(&mut encoded); // Duplicate Item 1

        let config_tuple = ((..).into(), ());
        let result = BTreeSet::<u32>::decode_cfg(encoded, &config_tuple);
        assert!(matches!(
            result,
            Err(Error::Invalid("BTreeSet", "Duplicate item"))
        ));
    }

    #[test]
    fn test_btreeset_decode_end_of_buffer() {
        let mut set = BTreeSet::new();
        set.insert(1u32);
        set.insert(5u32);

        let mut encoded = set.encode();
        encoded.truncate(set.encode_size() - 2); // Truncate during last item

        let config_tuple = ((..).into(), ());
        let result = BTreeSet::<u32>::decode_cfg(encoded, &config_tuple);
        assert!(matches!(result, Err(Error::EndOfBuffer)));
    }

    #[test]
    fn test_btreeset_decode_extra_data() {
        let mut set = BTreeSet::new();
        set.insert(1u32);

        let mut encoded = set.encode();
        encoded.put_u8(0xFF); // Add extra byte

        // Use decode_cfg which enforces buffer is fully consumed
        let config_tuple = ((..).into(), ());
        let result = BTreeSet::<u32>::decode_cfg(encoded.clone(), &config_tuple);
        assert!(matches!(result, Err(Error::ExtraData(1))));

        // Verify that read_cfg would succeed (doesn't check for extra data)
        let read_result = BTreeSet::<u32>::read_cfg(&mut encoded.clone(), &config_tuple);
        assert!(read_result.is_ok());
        let decoded_set = read_result.unwrap();
        assert_eq!(decoded_set.len(), 1);
        assert!(decoded_set.contains(&1u32));
    }

    #[test]
    fn test_btreeset_deterministic_encoding() {
        // Same contents, opposite insertion order
        let mut set1 = BTreeSet::new();
        (0..1000u32).for_each(|i| {
            set1.insert(i);
        });

        let mut set2 = BTreeSet::new();
        (0..1000u32).rev().for_each(|i| {
            set2.insert(i);
        });

        assert_eq!(set1.encode(), set2.encode());
    }

    #[test]
    fn test_btreeset_conformity() {
        // Case 1: Empty BTreeSet<u8>
        let set1 = BTreeSet::<u8>::new();
        let mut expected1 = BytesMut::new();
        0usize.write(&mut expected1); // Length 0
        assert_eq!(set1.encode(), expected1.freeze());
        assert_eq!(set1.encode_size(), 1);

        // Case 2: Simple BTreeSet<u8>
        // BTreeSet will store and encode items in sorted order: 1, 2, 5
        let mut set2 = BTreeSet::<u8>::new();
        set2.insert(5u8);
        set2.insert(1u8);
        set2.insert(2u8);

        let mut expected2 = BytesMut::new();
        3usize.write(&mut expected2); // Length 3
        1u8.write(&mut expected2); // Item 1
        2u8.write(&mut expected2); // Item 2
        5u8.write(&mut expected2); // Item 5
        assert_eq!(set2.encode(), expected2.freeze());
        assert_eq!(set2.encode_size(), 1 + 3 * u8::SIZE);

        // Case 3: BTreeSet<Bytes>
        // BTreeSet sorts items: "apple", "banana", "cherry"
        let mut set3 = BTreeSet::<Bytes>::new();
        set3.insert(Bytes::from_static(b"cherry"));
        set3.insert(Bytes::from_static(b"apple"));
        set3.insert(Bytes::from_static(b"banana"));

        let mut expected3 = BytesMut::new();
        3usize.write(&mut expected3); // Length 3
        Bytes::from_static(b"apple").write(&mut expected3);
        Bytes::from_static(b"banana").write(&mut expected3);
        Bytes::from_static(b"cherry").write(&mut expected3);
        assert_eq!(set3.encode(), expected3.freeze());
        // Length prefix (3 items) + each item
        let expected_size = 3usize.encode_size()
            + Bytes::from_static(b"apple").encode_size()
            + Bytes::from_static(b"banana").encode_size()
            + Bytes::from_static(b"cherry").encode_size();
        assert_eq!(set3.encode_size(), expected_size);
    }

    // --- HashSet Tests ---

    #[test]
    fn test_empty_hashset() {
        let set = HashSet::<u32>::new();
        round_trip_hash(&set, (..).into(), ());
        assert_eq!(set.encode_size(), 1); // varint 0
        let encoded = set.encode();
        assert_eq!(encoded, Bytes::from_static(&[0]));
    }

    #[test]
    fn test_simple_hashset_u32() {
        let mut set = HashSet::new();
        set.insert(1u32);
        set.insert(5u32);
        set.insert(2u32);
        round_trip_hash(&set, (..).into(), ());
        // Size calculation: varint len + size of each item (order doesn't matter for size)
        assert_eq!(set.encode_size(), 1 + 3 * u32::SIZE);
        // Encoding check: items must be sorted (1, 2, 5)
        let mut expected = BytesMut::new();
        3usize.write(&mut expected); // Set length = 3
        1u32.write(&mut expected);
        2u32.write(&mut expected);
        5u32.write(&mut expected);
        assert_eq!(set.encode(), expected.freeze());
    }

    #[test]
    fn test_large_hashset() {
        // Fixed-size items
        let set: HashSet<_> = (0..1000u16).collect();
        round_trip_hash(&set, (1000..=1000).into(), ());

        // Variable-size items
        let set: HashSet<_> = (0..1000usize).collect();
        round_trip_hash(&set, (1000..=1000).into(), (..=1000).into());
    }

    #[test]
    fn test_hashset_with_variable_items() {
        let mut set = HashSet::new();
        set.insert(Bytes::from_static(b"apple"));
        set.insert(Bytes::from_static(b"banana"));
        set.insert(Bytes::from_static(b"cherry"));

        let set_range = 0..=10;
        let item_range = ..=10; // Range for Bytes length

        round_trip_hash(&set, set_range.into(), item_range.into());
    }

    #[test]
    fn test_hashset_decode_length_limit_exceeded() {
        let mut set = HashSet::new();
        set.insert(1u32);
        set.insert(5u32);

        // Only 0 or 1 items are allowed, but 2 were encoded
        let encoded = set.encode();
        let config_tuple = ((0..=1).into(), ());

        let result = HashSet::<u32>::decode_cfg(encoded, &config_tuple);
        assert!(matches!(result, Err(Error::InvalidLength(2))));
    }

    #[test]
    fn test_hashset_decode_item_length_limit_exceeded() {
        let mut set = HashSet::new();
        set.insert(Bytes::from_static(b"longitem")); // 8 bytes

        let set_range = 0..=10;
        let restrictive_item_range = ..=5; // Limit item length

        let encoded = set.encode();
        let config_tuple = (set_range.into(), restrictive_item_range.into());
        let result = HashSet::<Bytes>::decode_cfg(encoded, &config_tuple);

        assert!(matches!(result, Err(Error::InvalidLength(8))));
    }

    #[test]
    fn test_hashset_decode_invalid_item_order() {
        let mut encoded = BytesMut::new();
        2usize.write(&mut encoded); // Set length = 2
        5u32.write(&mut encoded); // Item 5
        2u32.write(&mut encoded); // Item 2 (out of order)

        let config_tuple = ((..).into(), ());

        let result = HashSet::<u32>::decode_cfg(encoded, &config_tuple);
        assert!(matches!(
            result,
            Err(Error::Invalid("HashSet", "Items must ascend"))
        ));
    }

    #[test]
    fn test_hashset_decode_duplicate_item() {
        let mut encoded = BytesMut::new();
        2usize.write(&mut encoded); // Set length = 2
        1u32.write(&mut encoded); // Item 1
        1u32.write(&mut encoded); // Duplicate Item 1

        let config_tuple = ((..).into(), ());
        let result = HashSet::<u32>::decode_cfg(encoded, &config_tuple);
        assert!(matches!(
            result,
            Err(Error::Invalid("HashSet", "Duplicate item"))
        ));
    }

    #[test]
    fn test_hashset_decode_end_of_buffer() {
        let mut set = HashSet::new();
        set.insert(1u32);
        set.insert(5u32);

        let mut encoded = set.encode(); // Will be sorted: [1, 5]
        encoded.truncate(set.encode_size() - 2); // Truncate during last item (5)

        let config_tuple = ((..).into(), ());
        let result = HashSet::<u32>::decode_cfg(encoded, &config_tuple);
        assert!(matches!(result, Err(Error::EndOfBuffer)));
    }

    #[test]
    fn test_hashset_decode_extra_data() {
        let mut set = HashSet::new();
        set.insert(1u32);

        let mut encoded = set.encode();
        encoded.put_u8(0xFF); // Add extra byte

        // Use decode_cfg which enforces buffer is fully consumed
        let config_tuple = ((..).into(), ());
        let result = HashSet::<u32>::decode_cfg(encoded.clone(), &config_tuple);
        assert!(matches!(result, Err(Error::ExtraData(1))));

        // Verify that read_cfg would succeed (doesn't check for extra data)
        let read_result = HashSet::<u32>::read_cfg(&mut encoded.clone(), &config_tuple);
        assert!(read_result.is_ok());
        let decoded_set = read_result.unwrap();
        assert_eq!(decoded_set.len(), 1);
        assert!(decoded_set.contains(&1u32));
    }

    #[test]
    fn test_hashset_deterministic_encoding() {
        // Same contents, opposite insertion order
        let mut set1 = HashSet::new();
        (0..1000u32).for_each(|i| {
            set1.insert(i);
        });

        let mut set2 = HashSet::new();
        (0..1000u32).rev().for_each(|i| {
            set2.insert(i);
        });

        assert_eq!(set1.encode(), set2.encode());
    }

    #[test]
    fn test_hashset_conformity() {
        // Case 1: Empty HashSet<u8>
        let set1 = HashSet::<u8>::new();
        let mut expected1 = BytesMut::new();
        0usize.write(&mut expected1); // Length 0
        assert_eq!(set1.encode(), expected1.freeze());
        assert_eq!(set1.encode_size(), 1);

        // Case 2: Simple HashSet<u8>
        // HashSet will sort items for encoding: 1, 2, 5
        let mut set2 = HashSet::<u8>::new();
        set2.insert(5u8);
        set2.insert(1u8);
        set2.insert(2u8);

        let mut expected2 = BytesMut::new();
        3usize.write(&mut expected2); // Length 3
        1u8.write(&mut expected2); // Item 1
        2u8.write(&mut expected2); // Item 2
        5u8.write(&mut expected2); // Item 5
        assert_eq!(set2.encode(), expected2.freeze());
        assert_eq!(set2.encode_size(), 1 + 3 * u8::SIZE);

        // Case 3: HashSet<Bytes>
        // HashSet sorts items for encoding: "apple", "banana", "cherry"
        let mut set3 = HashSet::<Bytes>::new();
        set3.insert(Bytes::from_static(b"cherry"));
        set3.insert(Bytes::from_static(b"apple"));
        set3.insert(Bytes::from_static(b"banana"));

        let mut expected3 = BytesMut::new();
        3usize.write(&mut expected3); // Length 3
        Bytes::from_static(b"apple").write(&mut expected3);
        Bytes::from_static(b"banana").write(&mut expected3);
        Bytes::from_static(b"cherry").write(&mut expected3);
        assert_eq!(set3.encode(), expected3.freeze());
        // Length prefix (3 items) + each item
        let expected_size = 3usize.encode_size()
            + Bytes::from_static(b"apple").encode_size()
            + Bytes::from_static(b"banana").encode_size()
            + Bytes::from_static(b"cherry").encode_size();
        assert_eq!(set3.encode_size(), expected_size);
    }
}