1use crate::{
7 codec::{EncodeSize, Read, Write},
8 error::Error,
9 RangeCfg,
10};
11use bytes::{Buf, BufMut};
12use std::{
13 cmp::Ordering,
14 collections::{BTreeSet, HashSet},
15 hash::Hash,
16};
17
/// Type name attached to [`Error::Invalid`] when a [`BTreeSet`] fails decode-time validation.
const BTREESET_TYPE: &str = "BTreeSet";
/// Type name attached to [`Error::Invalid`] when a [`HashSet`] fails decode-time validation.
const HASHSET_TYPE: &str = "HashSet";
20
21fn read_ordered_set<K, F>(
23 buf: &mut impl Buf,
24 len: usize,
25 cfg: &K::Cfg,
26 mut insert: F,
27 set_type: &'static str,
28) -> Result<(), Error>
29where
30 K: Read + Ord,
31 F: FnMut(K) -> bool,
32{
33 let mut last: Option<K> = None;
34 for _ in 0..len {
35 let item = K::read_cfg(buf, cfg)?;
37
38 if let Some(ref last) = last {
40 match item.cmp(last) {
41 Ordering::Equal => return Err(Error::Invalid(set_type, "Duplicate item")),
42 Ordering::Less => return Err(Error::Invalid(set_type, "Items must ascend")),
43 _ => {}
44 }
45 }
46
47 if let Some(last) = last.take() {
49 insert(last);
50 }
51 last = Some(item);
52 }
53
54 if let Some(last) = last {
56 insert(last);
57 }
58
59 Ok(())
60}
61
62impl<K: Ord + Hash + Eq + Write> Write for BTreeSet<K> {
65 fn write(&self, buf: &mut impl BufMut) {
66 self.len().write(buf);
67
68 for item in self {
70 item.write(buf);
71 }
72 }
73}
74
75impl<K: Ord + Hash + Eq + EncodeSize> EncodeSize for BTreeSet<K> {
76 fn encode_size(&self) -> usize {
77 let mut size = self.len().encode_size();
78 for item in self {
79 size += item.encode_size();
80 }
81 size
82 }
83}
84
85impl<K: Read + Clone + Ord + Hash + Eq> Read for BTreeSet<K> {
86 type Cfg = (RangeCfg, K::Cfg);
87
88 fn read_cfg(buf: &mut impl Buf, (range, cfg): &Self::Cfg) -> Result<Self, Error> {
89 let len = usize::read_cfg(buf, range)?;
91 let mut set = BTreeSet::new();
92
93 read_ordered_set(buf, len, cfg, |item| set.insert(item), BTREESET_TYPE)?;
95
96 Ok(set)
97 }
98}
99
100impl<K: Ord + Hash + Eq + Write> Write for HashSet<K> {
103 fn write(&self, buf: &mut impl BufMut) {
104 self.len().write(buf);
105
106 let mut items: Vec<_> = self.iter().collect();
108 items.sort();
109 for item in items {
110 item.write(buf);
111 }
112 }
113}
114
115impl<K: Ord + Hash + Eq + EncodeSize> EncodeSize for HashSet<K> {
116 fn encode_size(&self) -> usize {
117 let mut size = self.len().encode_size();
118
119 for item in self {
121 size += item.encode_size();
122 }
123 size
124 }
125}
126
127impl<K: Read + Clone + Ord + Hash + Eq> Read for HashSet<K> {
128 type Cfg = (RangeCfg, K::Cfg);
129
130 fn read_cfg(buf: &mut impl Buf, (range, cfg): &Self::Cfg) -> Result<Self, Error> {
131 let len = usize::read_cfg(buf, range)?;
133 let mut set = HashSet::with_capacity(len);
134
135 read_ordered_set(buf, len, cfg, |item| set.insert(item), HASHSET_TYPE)?;
137
138 Ok(set)
139 }
140}
141
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        codec::{Decode, Encode},
        FixedSize,
    };
    use bytes::{Bytes, BytesMut};
    use std::collections::{BTreeSet, HashSet};
    use std::fmt::Debug;

    /// Encodes `set`, checks that `encode_size` matches the encoded length, then decodes with
    /// `(range_cfg, item_cfg)` and asserts the decoded set equals the original.
    fn round_trip_btree<K>(set: &BTreeSet<K>, range_cfg: RangeCfg, item_cfg: K::Cfg)
    where
        K: Write + EncodeSize + Read + Clone + Ord + Hash + Eq + Debug + PartialEq,
        BTreeSet<K>: Read<Cfg = (RangeCfg, K::Cfg)>
            + Decode<Cfg = (RangeCfg, K::Cfg)>
            + Debug
            + PartialEq
            + Write
            + EncodeSize,
    {
        let encoded = set.encode();
        assert_eq!(set.encode_size(), encoded.len());
        let config_tuple = (range_cfg, item_cfg);
        let decoded =
            BTreeSet::<K>::decode_cfg(encoded, &config_tuple).expect("decode_cfg failed");
        assert_eq!(set, &decoded);
    }

    /// `HashSet` counterpart of [`round_trip_btree`].
    fn round_trip_hash<K>(set: &HashSet<K>, range_cfg: RangeCfg, item_cfg: K::Cfg)
    where
        K: Write + EncodeSize + Read + Clone + Ord + Hash + Eq + Debug + PartialEq,
        HashSet<K>: Read<Cfg = (RangeCfg, K::Cfg)>
            + Decode<Cfg = (RangeCfg, K::Cfg)>
            + Debug
            + PartialEq
            + Write
            + EncodeSize,
    {
        let encoded = set.encode();
        assert_eq!(set.encode_size(), encoded.len());
        let config_tuple = (range_cfg, item_cfg);
        let decoded =
            HashSet::<K>::decode_cfg(encoded, &config_tuple).expect("decode_cfg failed");
        assert_eq!(set, &decoded);
    }

    // --- BTreeSet ---

    #[test]
    fn test_empty_btreeset() {
        let set = BTreeSet::<u32>::new();
        round_trip_btree(&set, (..).into(), ());
        // Only the length prefix (0) is written.
        assert_eq!(set.encode_size(), 1);
        let encoded = set.encode();
        assert_eq!(encoded, Bytes::from_static(&[0]));
    }

    #[test]
    fn test_simple_btreeset_u32() {
        let mut set = BTreeSet::new();
        set.insert(1u32);
        set.insert(5u32);
        set.insert(2u32);
        round_trip_btree(&set, (..).into(), ());
        assert_eq!(set.encode_size(), 1 + 3 * u32::SIZE);
    }

    #[test]
    fn test_large_btreeset() {
        let set: BTreeSet<_> = (0..1000u16).collect();
        round_trip_btree(&set, (1000..=1000).into(), ());

        let set: BTreeSet<_> = (0..1000usize).collect();
        round_trip_btree(&set, (1000..=1000).into(), (..=1000).into());
    }

    #[test]
    fn test_btreeset_with_variable_items() {
        let mut set = BTreeSet::new();
        set.insert(Bytes::from_static(b"apple"));
        set.insert(Bytes::from_static(b"banana"));
        set.insert(Bytes::from_static(b"cherry"));

        let set_range = 0..=10;
        let item_range = ..=10;
        round_trip_btree(&set, set_range.into(), item_range.into());
    }

    #[test]
    fn test_btreeset_decode_length_limit_exceeded() {
        let mut set = BTreeSet::new();
        set.insert(1u32);
        set.insert(5u32);
        let encoded = set.encode();

        let config_tuple = ((0..=1).into(), ());
        let result = BTreeSet::<u32>::decode_cfg(encoded, &config_tuple);
        assert!(matches!(result, Err(Error::InvalidLength(2))));
    }

    #[test]
    fn test_btreeset_decode_item_length_limit_exceeded() {
        let mut set = BTreeSet::new();
        set.insert(Bytes::from_static(b"longitem"));
        let encoded = set.encode();

        let set_range = 0..=10;
        let restrictive_item_range = ..=5;
        let config_tuple = (set_range.into(), restrictive_item_range.into());
        let result = BTreeSet::<Bytes>::decode_cfg(encoded, &config_tuple);

        assert!(matches!(result, Err(Error::InvalidLength(8))));
    }

    #[test]
    fn test_btreeset_decode_invalid_item_order() {
        let mut encoded = BytesMut::new();
        2usize.write(&mut encoded);
        5u32.write(&mut encoded);
        2u32.write(&mut encoded); // Out of order.
        let config_tuple = ((..).into(), ());

        let result = BTreeSet::<u32>::decode_cfg(encoded, &config_tuple);
        assert!(matches!(
            result,
            Err(Error::Invalid("BTreeSet", "Items must ascend"))
        ));
    }

    #[test]
    fn test_btreeset_decode_duplicate_item() {
        let mut encoded = BytesMut::new();
        2usize.write(&mut encoded);
        1u32.write(&mut encoded);
        1u32.write(&mut encoded); // Duplicate.
        let config_tuple = ((..).into(), ());

        let result = BTreeSet::<u32>::decode_cfg(encoded, &config_tuple);
        assert!(matches!(
            result,
            Err(Error::Invalid("BTreeSet", "Duplicate item"))
        ));
    }

    #[test]
    fn test_btreeset_decode_end_of_buffer() {
        let mut set = BTreeSet::new();
        set.insert(1u32);
        set.insert(5u32);

        let mut encoded = set.encode();
        // Cut the last item short.
        encoded.truncate(set.encode_size() - 2);
        let config_tuple = ((..).into(), ());

        let result = BTreeSet::<u32>::decode_cfg(encoded, &config_tuple);
        assert!(matches!(result, Err(Error::EndOfBuffer)));
    }

    #[test]
    fn test_btreeset_decode_extra_data() {
        let mut set = BTreeSet::new();
        set.insert(1u32);

        let mut encoded = set.encode();
        encoded.put_u8(0xFF); // Trailing byte.
        let config_tuple = ((..).into(), ());

        // `decode_cfg` rejects trailing bytes...
        let result = BTreeSet::<u32>::decode_cfg(encoded.clone(), &config_tuple);
        assert!(matches!(result, Err(Error::ExtraData(1))));

        // ...while `read_cfg` stops after the set and leaves them in the buffer.
        let read_result = BTreeSet::<u32>::read_cfg(&mut encoded.clone(), &config_tuple);
        assert!(read_result.is_ok());
        let decoded_set = read_result.unwrap();
        assert_eq!(decoded_set.len(), 1);
        assert!(decoded_set.contains(&1u32));
    }

    #[test]
    fn test_btreeset_deterministic_encoding() {
        let mut set1 = BTreeSet::new();
        (0..1000u32).for_each(|i| {
            set1.insert(i);
        });

        let mut set2 = BTreeSet::new();
        (0..1000u32).rev().for_each(|i| {
            set2.insert(i);
        });

        assert_eq!(set1.encode(), set2.encode());
    }

    #[test]
    fn test_btreeset_conformity() {
        // Empty set: length prefix only.
        let set1 = BTreeSet::<u8>::new();
        let mut expected1 = BytesMut::new();
        0usize.write(&mut expected1);
        assert_eq!(set1.encode(), expected1.freeze());
        assert_eq!(set1.encode_size(), 1);

        // Fixed-size items are written in ascending order regardless of insertion order.
        let mut set2 = BTreeSet::<u8>::new();
        set2.insert(5u8);
        set2.insert(1u8);
        set2.insert(2u8);

        let mut expected2 = BytesMut::new();
        3usize.write(&mut expected2);
        1u8.write(&mut expected2);
        2u8.write(&mut expected2);
        5u8.write(&mut expected2);
        assert_eq!(set2.encode(), expected2.freeze());
        assert_eq!(set2.encode_size(), 1 + 3 * u8::SIZE);

        // Variable-size items are also written in ascending order.
        let mut set3 = BTreeSet::<Bytes>::new();
        set3.insert(Bytes::from_static(b"cherry"));
        set3.insert(Bytes::from_static(b"apple"));
        set3.insert(Bytes::from_static(b"banana"));

        let mut expected3 = BytesMut::new();
        3usize.write(&mut expected3);
        Bytes::from_static(b"apple").write(&mut expected3);
        Bytes::from_static(b"banana").write(&mut expected3);
        Bytes::from_static(b"cherry").write(&mut expected3);
        assert_eq!(set3.encode(), expected3.freeze());
        // The length prefix encodes the item count (3).
        let expected_size = 3usize.encode_size()
            + Bytes::from_static(b"apple").encode_size()
            + Bytes::from_static(b"banana").encode_size()
            + Bytes::from_static(b"cherry").encode_size();
        assert_eq!(set3.encode_size(), expected_size);
    }

    // --- HashSet ---

    #[test]
    fn test_empty_hashset() {
        let set = HashSet::<u32>::new();
        round_trip_hash(&set, (..).into(), ());
        // Only the length prefix (0) is written.
        assert_eq!(set.encode_size(), 1);
        let encoded = set.encode();
        assert_eq!(encoded, Bytes::from_static(&[0]));
    }

    #[test]
    fn test_simple_hashset_u32() {
        let mut set = HashSet::new();
        set.insert(1u32);
        set.insert(5u32);
        set.insert(2u32);
        round_trip_hash(&set, (..).into(), ());
        assert_eq!(set.encode_size(), 1 + 3 * u32::SIZE);

        // Items are written sorted, independent of hash iteration order.
        let mut expected = BytesMut::new();
        3usize.write(&mut expected);
        1u32.write(&mut expected);
        2u32.write(&mut expected);
        5u32.write(&mut expected);
        assert_eq!(set.encode(), expected.freeze());
    }

    #[test]
    fn test_large_hashset() {
        let set: HashSet<_> = (0..1000u16).collect();
        round_trip_hash(&set, (1000..=1000).into(), ());

        let set: HashSet<_> = (0..1000usize).collect();
        round_trip_hash(&set, (1000..=1000).into(), (..=1000).into());
    }

    #[test]
    fn test_hashset_with_variable_items() {
        let mut set = HashSet::new();
        set.insert(Bytes::from_static(b"apple"));
        set.insert(Bytes::from_static(b"banana"));
        set.insert(Bytes::from_static(b"cherry"));

        let set_range = 0..=10;
        let item_range = ..=10;
        round_trip_hash(&set, set_range.into(), item_range.into());
    }

    #[test]
    fn test_hashset_decode_length_limit_exceeded() {
        let mut set = HashSet::new();
        set.insert(1u32);
        set.insert(5u32);

        let encoded = set.encode();
        let config_tuple = ((0..=1).into(), ());

        let result = HashSet::<u32>::decode_cfg(encoded, &config_tuple);
        assert!(matches!(result, Err(Error::InvalidLength(2))));
    }

    #[test]
    fn test_hashset_decode_item_length_limit_exceeded() {
        let mut set = HashSet::new();
        set.insert(Bytes::from_static(b"longitem"));

        let set_range = 0..=10;
        let restrictive_item_range = ..=5;

        let encoded = set.encode();
        let config_tuple = (set_range.into(), restrictive_item_range.into());
        let result = HashSet::<Bytes>::decode_cfg(encoded, &config_tuple);

        assert!(matches!(result, Err(Error::InvalidLength(8))));
    }

    #[test]
    fn test_hashset_decode_invalid_item_order() {
        let mut encoded = BytesMut::new();
        2usize.write(&mut encoded);
        5u32.write(&mut encoded);
        2u32.write(&mut encoded); // Out of order.
        let config_tuple = ((..).into(), ());

        let result = HashSet::<u32>::decode_cfg(encoded, &config_tuple);
        assert!(matches!(
            result,
            Err(Error::Invalid("HashSet", "Items must ascend"))
        ));
    }

    #[test]
    fn test_hashset_decode_duplicate_item() {
        let mut encoded = BytesMut::new();
        2usize.write(&mut encoded);
        1u32.write(&mut encoded);
        1u32.write(&mut encoded); // Duplicate.
        let config_tuple = ((..).into(), ());

        let result = HashSet::<u32>::decode_cfg(encoded, &config_tuple);
        assert!(matches!(
            result,
            Err(Error::Invalid("HashSet", "Duplicate item"))
        ));
    }

    #[test]
    fn test_hashset_decode_end_of_buffer() {
        let mut set = HashSet::new();
        set.insert(1u32);
        set.insert(5u32);

        let mut encoded = set.encode();
        // Cut the last item short.
        encoded.truncate(set.encode_size() - 2);
        let config_tuple = ((..).into(), ());

        let result = HashSet::<u32>::decode_cfg(encoded, &config_tuple);
        assert!(matches!(result, Err(Error::EndOfBuffer)));
    }

    #[test]
    fn test_hashset_decode_extra_data() {
        let mut set = HashSet::new();
        set.insert(1u32);

        let mut encoded = set.encode();
        encoded.put_u8(0xFF); // Trailing byte.
        let config_tuple = ((..).into(), ());

        // `decode_cfg` rejects trailing bytes...
        let result = HashSet::<u32>::decode_cfg(encoded.clone(), &config_tuple);
        assert!(matches!(result, Err(Error::ExtraData(1))));

        // ...while `read_cfg` stops after the set and leaves them in the buffer.
        let read_result = HashSet::<u32>::read_cfg(&mut encoded.clone(), &config_tuple);
        assert!(read_result.is_ok());
        let decoded_set = read_result.unwrap();
        assert_eq!(decoded_set.len(), 1);
        assert!(decoded_set.contains(&1u32));
    }

    #[test]
    fn test_hashset_deterministic_encoding() {
        let mut set1 = HashSet::new();
        (0..1000u32).for_each(|i| {
            set1.insert(i);
        });

        let mut set2 = HashSet::new();
        (0..1000u32).rev().for_each(|i| {
            set2.insert(i);
        });

        assert_eq!(set1.encode(), set2.encode());
    }

    #[test]
    fn test_hashset_conformity() {
        // Empty set: length prefix only.
        let set1 = HashSet::<u8>::new();
        let mut expected1 = BytesMut::new();
        0usize.write(&mut expected1);
        assert_eq!(set1.encode(), expected1.freeze());
        assert_eq!(set1.encode_size(), 1);

        // Fixed-size items are written in ascending order regardless of insertion order.
        let mut set2 = HashSet::<u8>::new();
        set2.insert(5u8);
        set2.insert(1u8);
        set2.insert(2u8);

        let mut expected2 = BytesMut::new();
        3usize.write(&mut expected2);
        1u8.write(&mut expected2);
        2u8.write(&mut expected2);
        5u8.write(&mut expected2);
        assert_eq!(set2.encode(), expected2.freeze());
        assert_eq!(set2.encode_size(), 1 + 3 * u8::SIZE);

        // Variable-size items are also written in ascending order.
        let mut set3 = HashSet::<Bytes>::new();
        set3.insert(Bytes::from_static(b"cherry"));
        set3.insert(Bytes::from_static(b"apple"));
        set3.insert(Bytes::from_static(b"banana"));

        let mut expected3 = BytesMut::new();
        3usize.write(&mut expected3);
        Bytes::from_static(b"apple").write(&mut expected3);
        Bytes::from_static(b"banana").write(&mut expected3);
        Bytes::from_static(b"cherry").write(&mut expected3);
        assert_eq!(set3.encode(), expected3.freeze());
        // The length prefix encodes the item count (3).
        let expected_size = 3usize.encode_size()
            + Bytes::from_static(b"apple").encode_size()
            + Bytes::from_static(b"banana").encode_size()
            + Bytes::from_static(b"cherry").encode_size();
        assert_eq!(set3.encode_size(), expected_size);
    }
}
589}