commonware_codec/types/
set.rs

1//! Codec implementations for various set types.
2//!
3//! For portability and consistency between architectures,
4//! the size of the set must fit within a [`u32`].
5
6use crate::{
7    codec::{EncodeSize, Read, Write},
8    error::Error,
9    RangeCfg,
10};
11use bytes::{Buf, BufMut};
12use std::{
13    cmp::Ordering,
14    collections::{BTreeSet, HashSet},
15    hash::Hash,
16};
17
/// Type label passed to [`read_ordered_set`] (and reported in [`Error::Invalid`]) when decoding a [`BTreeSet`].
const BTREESET_TYPE: &str = "BTreeSet";
/// Type label passed to [`read_ordered_set`] (and reported in [`Error::Invalid`]) when decoding a [`HashSet`].
const HASHSET_TYPE: &str = "HashSet";
20
21/// Read items from [Buf] in ascending order.
22fn read_ordered_set<K, F>(
23    buf: &mut impl Buf,
24    len: usize,
25    cfg: &K::Cfg,
26    mut insert: F,
27    set_type: &'static str,
28) -> Result<(), Error>
29where
30    K: Read + Ord,
31    F: FnMut(K) -> bool,
32{
33    let mut last: Option<K> = None;
34    for _ in 0..len {
35        // Read item
36        let item = K::read_cfg(buf, cfg)?;
37
38        // Check if items are in ascending order
39        if let Some(ref last) = last {
40            match item.cmp(last) {
41                Ordering::Equal => return Err(Error::Invalid(set_type, "Duplicate item")),
42                Ordering::Less => return Err(Error::Invalid(set_type, "Items must ascend")),
43                _ => {}
44            }
45        }
46
47        // Add previous item, if exists
48        if let Some(last) = last.take() {
49            insert(last);
50        }
51        last = Some(item);
52    }
53
54    // Add last item, if exists
55    if let Some(last) = last {
56        insert(last);
57    }
58
59    Ok(())
60}
61
62// ---------- BTreeSet ----------
63
64impl<K: Ord + Hash + Eq + Write> Write for BTreeSet<K> {
65    fn write(&self, buf: &mut impl BufMut) {
66        self.len().write(buf);
67
68        // Items are already sorted in BTreeSet, so we can iterate directly
69        for item in self {
70            item.write(buf);
71        }
72    }
73}
74
75impl<K: Ord + Hash + Eq + EncodeSize> EncodeSize for BTreeSet<K> {
76    fn encode_size(&self) -> usize {
77        let mut size = self.len().encode_size();
78        for item in self {
79            size += item.encode_size();
80        }
81        size
82    }
83}
84
85impl<K: Read + Clone + Ord + Hash + Eq> Read for BTreeSet<K> {
86    type Cfg = (RangeCfg, K::Cfg);
87
88    fn read_cfg(buf: &mut impl Buf, (range, cfg): &Self::Cfg) -> Result<Self, Error> {
89        // Read and validate the length prefix
90        let len = usize::read_cfg(buf, range)?;
91        let mut set = BTreeSet::new();
92
93        // Read items in ascending order
94        read_ordered_set(buf, len, cfg, |item| set.insert(item), BTREESET_TYPE)?;
95
96        Ok(set)
97    }
98}
99
100// ---------- HashSet ----------
101
102impl<K: Ord + Hash + Eq + Write> Write for HashSet<K> {
103    fn write(&self, buf: &mut impl BufMut) {
104        self.len().write(buf);
105
106        // Sort the items to ensure deterministic encoding
107        let mut items: Vec<_> = self.iter().collect();
108        items.sort();
109        for item in items {
110            item.write(buf);
111        }
112    }
113}
114
115impl<K: Ord + Hash + Eq + EncodeSize> EncodeSize for HashSet<K> {
116    fn encode_size(&self) -> usize {
117        let mut size = self.len().encode_size();
118
119        // Note: Iteration order doesn't matter for size calculation.
120        for item in self {
121            size += item.encode_size();
122        }
123        size
124    }
125}
126
127impl<K: Read + Clone + Ord + Hash + Eq> Read for HashSet<K> {
128    type Cfg = (RangeCfg, K::Cfg);
129
130    fn read_cfg(buf: &mut impl Buf, (range, cfg): &Self::Cfg) -> Result<Self, Error> {
131        // Read and validate the length prefix
132        let len = usize::read_cfg(buf, range)?;
133        let mut set = HashSet::with_capacity(len);
134
135        // Read items in ascending order
136        read_ordered_set(buf, len, cfg, |item| set.insert(item), HASHSET_TYPE)?;
137
138        Ok(set)
139    }
140}
141
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        codec::{Decode, Encode},
        FixedSize,
    };
    use bytes::{Bytes, BytesMut};
    use std::{
        collections::{BTreeSet, HashSet},
        fmt::Debug,
    };

    /// Encodes `set`, checks that `encode_size` matches the encoded length, then decodes it
    /// with the given configuration and checks the result equals the input.
    fn round_trip_btree<K>(set: &BTreeSet<K>, range_cfg: RangeCfg, item_cfg: K::Cfg)
    where
        K: Write + EncodeSize + Read + Clone + Ord + Hash + Eq + Debug + PartialEq,
        BTreeSet<K>: Read<Cfg = (RangeCfg, K::Cfg)>
            + Decode<Cfg = (RangeCfg, K::Cfg)>
            + Debug
            + PartialEq
            + Write
            + EncodeSize,
    {
        let encoded = set.encode();
        assert_eq!(set.encode_size(), encoded.len());
        let config_tuple = (range_cfg, item_cfg);
        let decoded = BTreeSet::<K>::decode_cfg(encoded, &config_tuple).expect("decode_cfg failed");
        assert_eq!(set, &decoded);
    }

    /// Encodes `set`, checks that `encode_size` matches the encoded length, then decodes it
    /// with the given configuration and checks the result equals the input.
    fn round_trip_hash<K>(set: &HashSet<K>, range_cfg: RangeCfg, item_cfg: K::Cfg)
    where
        K: Write + EncodeSize + Read + Clone + Ord + Hash + Eq + Debug + PartialEq,
        HashSet<K>: Read<Cfg = (RangeCfg, K::Cfg)>
            + Decode<Cfg = (RangeCfg, K::Cfg)>
            + Debug
            + PartialEq
            + Write
            + EncodeSize,
    {
        let encoded = set.encode();
        assert_eq!(set.encode_size(), encoded.len());
        let config_tuple = (range_cfg, item_cfg);
        let decoded = HashSet::<K>::decode_cfg(encoded, &config_tuple).expect("decode_cfg failed");
        assert_eq!(set, &decoded);
    }

    // --- BTreeSet Tests ---

    #[test]
    fn test_empty_btreeset() {
        let set = BTreeSet::<u32>::new();
        round_trip_btree(&set, (..).into(), ());
        assert_eq!(set.encode_size(), 1); // varint 0
        let encoded = set.encode();
        assert_eq!(encoded, Bytes::from_static(&[0]));
    }

    #[test]
    fn test_simple_btreeset_u32() {
        let mut set = BTreeSet::new();
        set.insert(1u32);
        set.insert(5u32);
        set.insert(2u32);
        round_trip_btree(&set, (..).into(), ());
        assert_eq!(set.encode_size(), 1 + 3 * u32::SIZE);
    }

    #[test]
    fn test_large_btreeset() {
        // Fixed-size items
        let set: BTreeSet<_> = (0..1000u16).collect();
        round_trip_btree(&set, (1000..=1000).into(), ());

        // Variable-size items
        let set: BTreeSet<_> = (0..1000usize).collect();
        round_trip_btree(&set, (1000..=1000).into(), (..=1000).into());
    }

    #[test]
    fn test_btreeset_with_variable_items() {
        let mut set = BTreeSet::new();
        set.insert(Bytes::from_static(b"apple"));
        set.insert(Bytes::from_static(b"banana"));
        set.insert(Bytes::from_static(b"cherry"));

        let set_range = 0..=10;
        let item_range = ..=10; // Range for Bytes length

        round_trip_btree(&set, set_range.into(), item_range.into());
    }

    #[test]
    fn test_btreeset_decode_length_limit_exceeded() {
        let mut set = BTreeSet::new();
        set.insert(1u32);
        set.insert(5u32);
        let encoded = set.encode();

        // Only sets of length 0 or 1 are allowed, but the encoded set has 2 items
        let config_tuple = ((0..=1).into(), ());
        let result = BTreeSet::<u32>::decode_cfg(encoded, &config_tuple);
        assert!(matches!(result, Err(Error::InvalidLength(2))));
    }

    #[test]
    fn test_btreeset_decode_item_length_limit_exceeded() {
        let mut set = BTreeSet::new();
        set.insert(Bytes::from_static(b"longitem")); // 8 bytes
        let encoded = set.encode();

        let set_range = 0..=10;
        let restrictive_item_range = ..=5; // Limit item length
        let config_tuple = (set_range.into(), restrictive_item_range.into());
        let result = BTreeSet::<Bytes>::decode_cfg(encoded, &config_tuple);

        assert!(matches!(result, Err(Error::InvalidLength(8))));
    }

    #[test]
    fn test_btreeset_decode_invalid_item_order() {
        let mut encoded = BytesMut::new();
        2usize.write(&mut encoded); // Set length = 2
        5u32.write(&mut encoded); // Item 5
        2u32.write(&mut encoded); // Item 2 (out of order)

        let config_tuple = ((..).into(), ());
        let result = BTreeSet::<u32>::decode_cfg(encoded, &config_tuple);
        assert!(matches!(
            result,
            Err(Error::Invalid("BTreeSet", "Items must ascend"))
        ));
    }

    #[test]
    fn test_btreeset_decode_duplicate_item() {
        let mut encoded = BytesMut::new();
        2usize.write(&mut encoded); // Set length = 2
        1u32.write(&mut encoded); // Item 1
        1u32.write(&mut encoded); // Duplicate Item 1

        let config_tuple = ((..).into(), ());
        let result = BTreeSet::<u32>::decode_cfg(encoded, &config_tuple);
        assert!(matches!(
            result,
            Err(Error::Invalid("BTreeSet", "Duplicate item"))
        ));
    }

    #[test]
    fn test_btreeset_decode_end_of_buffer() {
        let mut set = BTreeSet::new();
        set.insert(1u32);
        set.insert(5u32);

        let mut encoded = set.encode();
        encoded.truncate(set.encode_size() - 2); // Truncate during last item

        let config_tuple = ((..).into(), ());
        let result = BTreeSet::<u32>::decode_cfg(encoded, &config_tuple);
        assert!(matches!(result, Err(Error::EndOfBuffer)));
    }

    #[test]
    fn test_btreeset_decode_extra_data() {
        let mut set = BTreeSet::new();
        set.insert(1u32);

        let mut encoded = set.encode();
        encoded.put_u8(0xFF); // Add extra byte

        // Use decode_cfg which enforces buffer is fully consumed
        let config_tuple = ((..).into(), ());
        let result = BTreeSet::<u32>::decode_cfg(encoded.clone(), &config_tuple);
        assert!(matches!(result, Err(Error::ExtraData(1))));

        // Verify that read_cfg succeeds (it doesn't check for extra data)
        let decoded_set = BTreeSet::<u32>::read_cfg(&mut encoded.clone(), &config_tuple)
            .expect("read_cfg should ignore trailing bytes");
        assert_eq!(decoded_set.len(), 1);
        assert!(decoded_set.contains(&1u32));
    }

    #[test]
    fn test_btreeset_deterministic_encoding() {
        // Insertion order must not affect the encoding
        let mut set1 = BTreeSet::new();
        (0..1000u32).for_each(|i| {
            set1.insert(i);
        });

        let mut set2 = BTreeSet::new();
        (0..1000u32).rev().for_each(|i| {
            set2.insert(i);
        });

        assert_eq!(set1.encode(), set2.encode());
    }

    #[test]
    fn test_btreeset_conformity() {
        // Case 1: Empty BTreeSet<u8>
        let set1 = BTreeSet::<u8>::new();
        let mut expected1 = BytesMut::new();
        0usize.write(&mut expected1); // Length 0
        assert_eq!(set1.encode(), expected1.freeze());
        assert_eq!(set1.encode_size(), 1);

        // Case 2: Simple BTreeSet<u8>
        // BTreeSet will store and encode items in sorted order: 1, 2, 5
        let mut set2 = BTreeSet::<u8>::new();
        set2.insert(5u8);
        set2.insert(1u8);
        set2.insert(2u8);

        let mut expected2 = BytesMut::new();
        3usize.write(&mut expected2); // Length 3
        1u8.write(&mut expected2); // Item 1
        2u8.write(&mut expected2); // Item 2
        5u8.write(&mut expected2); // Item 5
        assert_eq!(set2.encode(), expected2.freeze());
        assert_eq!(set2.encode_size(), 1 + 3 * u8::SIZE);

        // Case 3: BTreeSet<Bytes>
        // BTreeSet sorts items: "apple", "banana", "cherry"
        let mut set3 = BTreeSet::<Bytes>::new();
        set3.insert(Bytes::from_static(b"cherry"));
        set3.insert(Bytes::from_static(b"apple"));
        set3.insert(Bytes::from_static(b"banana"));

        let mut expected3 = BytesMut::new();
        3usize.write(&mut expected3); // Length 3
        Bytes::from_static(b"apple").write(&mut expected3);
        Bytes::from_static(b"banana").write(&mut expected3);
        Bytes::from_static(b"cherry").write(&mut expected3);
        assert_eq!(set3.encode(), expected3.freeze());
        // Size of the length prefix (3 items) plus the size of every item
        let expected_size = 3usize.encode_size()
            + Bytes::from_static(b"apple").encode_size()
            + Bytes::from_static(b"banana").encode_size()
            + Bytes::from_static(b"cherry").encode_size();
        assert_eq!(set3.encode_size(), expected_size);
    }

    // --- HashSet Tests ---

    #[test]
    fn test_empty_hashset() {
        let set = HashSet::<u32>::new();
        round_trip_hash(&set, (..).into(), ());
        assert_eq!(set.encode_size(), 1); // varint 0
        let encoded = set.encode();
        assert_eq!(encoded, Bytes::from_static(&[0]));
    }

    #[test]
    fn test_simple_hashset_u32() {
        let mut set = HashSet::new();
        set.insert(1u32);
        set.insert(5u32);
        set.insert(2u32);
        round_trip_hash(&set, (..).into(), ());
        // Size calculation: varint len + size of each item (order doesn't matter for size)
        assert_eq!(set.encode_size(), 1 + 3 * u32::SIZE);
        // Encoding check: items must be sorted (1, 2, 5)
        let mut expected = BytesMut::new();
        3usize.write(&mut expected); // Set length = 3
        1u32.write(&mut expected);
        2u32.write(&mut expected);
        5u32.write(&mut expected);
        assert_eq!(set.encode(), expected.freeze());
    }

    #[test]
    fn test_large_hashset() {
        // Fixed-size items
        let set: HashSet<_> = (0..1000u16).collect();
        round_trip_hash(&set, (1000..=1000).into(), ());

        // Variable-size items
        let set: HashSet<_> = (0..1000usize).collect();
        round_trip_hash(&set, (1000..=1000).into(), (..=1000).into());
    }

    #[test]
    fn test_hashset_with_variable_items() {
        let mut set = HashSet::new();
        set.insert(Bytes::from_static(b"apple"));
        set.insert(Bytes::from_static(b"banana"));
        set.insert(Bytes::from_static(b"cherry"));

        let set_range = 0..=10;
        let item_range = ..=10; // Range for Bytes length

        round_trip_hash(&set, set_range.into(), item_range.into());
    }

    #[test]
    fn test_hashset_decode_length_limit_exceeded() {
        let mut set = HashSet::new();
        set.insert(1u32);
        set.insert(5u32);

        // Only sets of length 0 or 1 are allowed, but the encoded set has 2 items
        let encoded = set.encode();
        let config_tuple = ((0..=1).into(), ());

        let result = HashSet::<u32>::decode_cfg(encoded, &config_tuple);
        assert!(matches!(result, Err(Error::InvalidLength(2))));
    }

    #[test]
    fn test_hashset_decode_item_length_limit_exceeded() {
        let mut set = HashSet::new();
        set.insert(Bytes::from_static(b"longitem")); // 8 bytes

        let set_range = 0..=10;
        let restrictive_item_range = ..=5; // Limit item length

        let encoded = set.encode();
        let config_tuple = (set_range.into(), restrictive_item_range.into());
        let result = HashSet::<Bytes>::decode_cfg(encoded, &config_tuple);

        assert!(matches!(result, Err(Error::InvalidLength(8))));
    }

    #[test]
    fn test_hashset_decode_invalid_item_order() {
        let mut encoded = BytesMut::new();
        2usize.write(&mut encoded); // Set length = 2
        5u32.write(&mut encoded); // Item 5
        2u32.write(&mut encoded); // Item 2 (out of order)

        let config_tuple = ((..).into(), ());

        let result = HashSet::<u32>::decode_cfg(encoded, &config_tuple);
        assert!(matches!(
            result,
            Err(Error::Invalid("HashSet", "Items must ascend"))
        ));
    }

    #[test]
    fn test_hashset_decode_duplicate_item() {
        let mut encoded = BytesMut::new();
        2usize.write(&mut encoded); // Set length = 2
        1u32.write(&mut encoded); // Item 1
        1u32.write(&mut encoded); // Duplicate Item 1

        let config_tuple = ((..).into(), ());
        let result = HashSet::<u32>::decode_cfg(encoded, &config_tuple);
        assert!(matches!(
            result,
            Err(Error::Invalid("HashSet", "Duplicate item"))
        ));
    }

    #[test]
    fn test_hashset_decode_end_of_buffer() {
        let mut set = HashSet::new();
        set.insert(1u32);
        set.insert(5u32);

        let mut encoded = set.encode(); // Will be sorted: [1, 5]
        encoded.truncate(set.encode_size() - 2); // Truncate during last item (5)

        let config_tuple = ((..).into(), ());
        let result = HashSet::<u32>::decode_cfg(encoded, &config_tuple);
        assert!(matches!(result, Err(Error::EndOfBuffer)));
    }

    #[test]
    fn test_hashset_decode_extra_data() {
        let mut set = HashSet::new();
        set.insert(1u32);

        let mut encoded = set.encode();
        encoded.put_u8(0xFF); // Add extra byte

        // Use decode_cfg which enforces buffer is fully consumed
        let config_tuple = ((..).into(), ());
        let result = HashSet::<u32>::decode_cfg(encoded.clone(), &config_tuple);
        assert!(matches!(result, Err(Error::ExtraData(1))));

        // Verify that read_cfg succeeds (it doesn't check for extra data)
        let decoded_set = HashSet::<u32>::read_cfg(&mut encoded.clone(), &config_tuple)
            .expect("read_cfg should ignore trailing bytes");
        assert_eq!(decoded_set.len(), 1);
        assert!(decoded_set.contains(&1u32));
    }

    #[test]
    fn test_hashset_deterministic_encoding() {
        // Insertion order must not affect the encoding
        let mut set1 = HashSet::new();
        (0..1000u32).for_each(|i| {
            set1.insert(i);
        });

        let mut set2 = HashSet::new();
        (0..1000u32).rev().for_each(|i| {
            set2.insert(i);
        });

        assert_eq!(set1.encode(), set2.encode());
    }

    #[test]
    fn test_hashset_conformity() {
        // Case 1: Empty HashSet<u8>
        let set1 = HashSet::<u8>::new();
        let mut expected1 = BytesMut::new();
        0usize.write(&mut expected1); // Length 0
        assert_eq!(set1.encode(), expected1.freeze());
        assert_eq!(set1.encode_size(), 1);

        // Case 2: Simple HashSet<u8>
        // HashSet will sort items for encoding: 1, 2, 5
        let mut set2 = HashSet::<u8>::new();
        set2.insert(5u8);
        set2.insert(1u8);
        set2.insert(2u8);

        let mut expected2 = BytesMut::new();
        3usize.write(&mut expected2); // Length 3
        1u8.write(&mut expected2); // Item 1
        2u8.write(&mut expected2); // Item 2
        5u8.write(&mut expected2); // Item 5
        assert_eq!(set2.encode(), expected2.freeze());
        assert_eq!(set2.encode_size(), 1 + 3 * u8::SIZE);

        // Case 3: HashSet<Bytes>
        // HashSet sorts items for encoding: "apple", "banana", "cherry"
        let mut set3 = HashSet::<Bytes>::new();
        set3.insert(Bytes::from_static(b"cherry"));
        set3.insert(Bytes::from_static(b"apple"));
        set3.insert(Bytes::from_static(b"banana"));

        let mut expected3 = BytesMut::new();
        3usize.write(&mut expected3); // Length 3
        Bytes::from_static(b"apple").write(&mut expected3);
        Bytes::from_static(b"banana").write(&mut expected3);
        Bytes::from_static(b"cherry").write(&mut expected3);
        assert_eq!(set3.encode(), expected3.freeze());
        // Size of the length prefix (3 items) plus the size of every item
        let expected_size = 3usize.encode_size()
            + Bytes::from_static(b"apple").encode_size()
            + Bytes::from_static(b"banana").encode_size()
            + Bytes::from_static(b"cherry").encode_size();
        assert_eq!(set3.encode_size(), expected_size);
    }
}