1use crate::{
7 codec::{EncodeSize, Read, Write},
8 error::Error,
9 RangeCfg,
10};
11use bytes::{Buf, BufMut};
12use std::{
13 cmp::Ordering,
14 collections::{BTreeSet, HashSet},
15 hash::Hash,
16};
17
/// Type name reported in `Error::Invalid` when decoding a `BTreeSet` fails validation.
const BTREESET_TYPE: &str = "BTreeSet";
/// Type name reported in `Error::Invalid` when decoding a `HashSet` fails validation.
const HASHSET_TYPE: &str = "HashSet";
20
21fn read_ordered_set<K, F>(
23 buf: &mut impl Buf,
24 len: usize,
25 cfg: &K::Cfg,
26 mut insert: F,
27 set_type: &'static str,
28) -> Result<(), Error>
29where
30 K: Read + Ord,
31 F: FnMut(K) -> bool,
32{
33 let mut last: Option<K> = None;
34 for _ in 0..len {
35 let item = K::read_cfg(buf, cfg)?;
37
38 if let Some(ref last) = last {
40 match item.cmp(last) {
41 Ordering::Equal => return Err(Error::Invalid(set_type, "Duplicate item")),
42 Ordering::Less => return Err(Error::Invalid(set_type, "Items must ascend")),
43 _ => {}
44 }
45 }
46
47 if let Some(last) = last.take() {
49 insert(last);
50 }
51 last = Some(item);
52 }
53
54 if let Some(last) = last {
56 insert(last);
57 }
58
59 Ok(())
60}
61
62impl<K: Ord + Hash + Eq + Write> Write for BTreeSet<K> {
65 fn write(&self, buf: &mut impl BufMut) {
66 self.len().write(buf);
67
68 for item in self {
70 item.write(buf);
71 }
72 }
73}
74
75impl<K: Ord + Hash + Eq + EncodeSize> EncodeSize for BTreeSet<K> {
76 fn encode_size(&self) -> usize {
77 let mut size = self.len().encode_size();
78 for item in self {
79 size += item.encode_size();
80 }
81 size
82 }
83}
84
85impl<K: Read + Clone + Ord + Hash + Eq> Read for BTreeSet<K> {
86 type Cfg = (RangeCfg, K::Cfg);
87
88 fn read_cfg(buf: &mut impl Buf, (range, cfg): &Self::Cfg) -> Result<Self, Error> {
89 let len = usize::read_cfg(buf, range)?;
91 let mut set = BTreeSet::new();
92
93 read_ordered_set(buf, len, cfg, |item| set.insert(item), BTREESET_TYPE)?;
95
96 Ok(set)
97 }
98}
99
100impl<K: Ord + Hash + Eq + Write> Write for HashSet<K> {
103 fn write(&self, buf: &mut impl BufMut) {
104 self.len().write(buf);
105
106 let mut items: Vec<_> = self.iter().collect();
108 items.sort();
109 for item in items {
110 item.write(buf);
111 }
112 }
113}
114
115impl<K: Ord + Hash + Eq + EncodeSize> EncodeSize for HashSet<K> {
116 fn encode_size(&self) -> usize {
117 let mut size = self.len().encode_size();
118
119 for item in self {
121 size += item.encode_size();
122 }
123 size
124 }
125}
126
127impl<K: Read + Clone + Ord + Hash + Eq> Read for HashSet<K> {
128 type Cfg = (RangeCfg, K::Cfg);
129
130 fn read_cfg(buf: &mut impl Buf, (range, cfg): &Self::Cfg) -> Result<Self, Error> {
131 let len = usize::read_cfg(buf, range)?;
133 let mut set = HashSet::with_capacity(len);
134
135 read_ordered_set(buf, len, cfg, |item| set.insert(item), HASHSET_TYPE)?;
137
138 Ok(set)
139 }
140}
141
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        codec::{Decode, Encode},
        FixedSize,
    };
    use bytes::{Bytes, BytesMut};
    use std::{
        collections::{BTreeSet, HashSet},
        fmt::Debug,
    };

    /// Encodes `set` and checks that `encode_size` matches the encoded length. Then
    /// decodes it with `(range_cfg, item_cfg)` and checks the result equals `set`.
    fn round_trip_btree<K>(set: &BTreeSet<K>, range_cfg: RangeCfg, item_cfg: K::Cfg)
    where
        K: Write + EncodeSize + Read + Clone + Ord + Hash + Eq + Debug + PartialEq,
        BTreeSet<K>: Read<Cfg = (RangeCfg, K::Cfg)>
            + Decode<Cfg = (RangeCfg, K::Cfg)>
            + Debug
            + PartialEq
            + Write
            + EncodeSize,
    {
        let encoded = set.encode();
        assert_eq!(set.encode_size(), encoded.len());
        let config_tuple = (range_cfg, item_cfg);
        let decoded =
            BTreeSet::<K>::decode_cfg(encoded, &config_tuple).expect("decode_cfg failed");
        assert_eq!(set, &decoded);
    }

    /// `HashSet` counterpart of `round_trip_btree`.
    fn round_trip_hash<K>(set: &HashSet<K>, range_cfg: RangeCfg, item_cfg: K::Cfg)
    where
        K: Write + EncodeSize + Read + Clone + Ord + Hash + Eq + Debug + PartialEq,
        HashSet<K>: Read<Cfg = (RangeCfg, K::Cfg)>
            + Decode<Cfg = (RangeCfg, K::Cfg)>
            + Debug
            + PartialEq
            + Write
            + EncodeSize,
    {
        let encoded = set.encode();
        assert_eq!(set.encode_size(), encoded.len());
        let config_tuple = (range_cfg, item_cfg);
        let decoded =
            HashSet::<K>::decode_cfg(encoded, &config_tuple).expect("decode_cfg failed");
        assert_eq!(set, &decoded);
    }

    #[test]
    fn test_empty_btreeset() {
        let set = BTreeSet::<u32>::new();
        round_trip_btree(&set, (..).into(), ());
        assert_eq!(set.encode_size(), 1);
        let encoded = set.encode();
        assert_eq!(encoded, Bytes::from_static(&[0]));
    }

    #[test]
    fn test_simple_btreeset_u32() {
        let mut set = BTreeSet::new();
        set.insert(1u32);
        set.insert(5u32);
        set.insert(2u32);
        round_trip_btree(&set, (..).into(), ());
        assert_eq!(set.encode_size(), 1 + 3 * u32::SIZE);
    }

    #[test]
    fn test_large_btreeset() {
        let set: BTreeSet<_> = (0..1000u16).collect();
        round_trip_btree(&set, (1000..=1000).into(), ());

        let set: BTreeSet<_> = (0..1000usize).collect();
        round_trip_btree(&set, (1000..=1000).into(), (..=1000).into());
    }

    #[test]
    fn test_btreeset_with_variable_items() {
        let mut set = BTreeSet::new();
        set.insert(Bytes::from_static(b"apple"));
        set.insert(Bytes::from_static(b"banana"));
        set.insert(Bytes::from_static(b"cherry"));

        let set_range = 0..=10;
        let item_range = ..=10;
        round_trip_btree(&set, set_range.into(), item_range.into());
    }

    #[test]
    fn test_btreeset_decode_length_limit_exceeded() {
        let mut set = BTreeSet::new();
        set.insert(1u32);
        set.insert(5u32);
        let encoded = set.encode();

        let config_tuple = ((0..=1).into(), ());
        let result = BTreeSet::<u32>::decode_cfg(encoded, &config_tuple);
        assert!(matches!(result, Err(Error::InvalidLength(2))));
    }

    #[test]
    fn test_btreeset_decode_item_length_limit_exceeded() {
        let mut set = BTreeSet::new();
        set.insert(Bytes::from_static(b"longitem"));
        let encoded = set.encode();

        let set_range = 0..=10;
        let restrictive_item_range = ..=5;
        let config_tuple = (set_range.into(), restrictive_item_range.into());
        let result = BTreeSet::<Bytes>::decode_cfg(encoded, &config_tuple);

        assert!(matches!(result, Err(Error::InvalidLength(8))));
    }

    #[test]
    fn test_btreeset_decode_invalid_item_order() {
        let mut encoded = BytesMut::new();
        2usize.write(&mut encoded);
        5u32.write(&mut encoded);
        2u32.write(&mut encoded);
        let config_tuple = ((..).into(), ());
        let result = BTreeSet::<u32>::decode_cfg(encoded, &config_tuple);
        assert!(matches!(
            result,
            Err(Error::Invalid("BTreeSet", "Items must ascend"))
        ));
    }

    #[test]
    fn test_btreeset_decode_duplicate_item() {
        let mut encoded = BytesMut::new();
        2usize.write(&mut encoded);
        1u32.write(&mut encoded);
        1u32.write(&mut encoded);
        let config_tuple = ((..).into(), ());
        let result = BTreeSet::<u32>::decode_cfg(encoded, &config_tuple);
        assert!(matches!(
            result,
            Err(Error::Invalid("BTreeSet", "Duplicate item"))
        ));
    }

    #[test]
    fn test_btreeset_decode_end_of_buffer() {
        let mut set = BTreeSet::new();
        set.insert(1u32);
        set.insert(5u32);

        let mut encoded = set.encode();
        encoded.truncate(set.encode_size() - 2);
        let config_tuple = ((..).into(), ());
        let result = BTreeSet::<u32>::decode_cfg(encoded, &config_tuple);
        assert!(matches!(result, Err(Error::EndOfBuffer)));
    }

    #[test]
    fn test_btreeset_decode_extra_data() {
        let mut set = BTreeSet::new();
        set.insert(1u32);

        let mut encoded = set.encode();
        encoded.put_u8(0xFF);
        let config_tuple = ((..).into(), ());
        let result = BTreeSet::<u32>::decode_cfg(encoded.clone(), &config_tuple);
        assert!(matches!(result, Err(Error::ExtraData(1))));

        // `read_cfg` does not require the buffer to be fully consumed.
        let read_result = BTreeSet::<u32>::read_cfg(&mut encoded.clone(), &config_tuple);
        assert!(read_result.is_ok());
        let decoded_set = read_result.unwrap();
        assert_eq!(decoded_set.len(), 1);
        assert!(decoded_set.contains(&1u32));
    }

    #[test]
    fn test_btreeset_deterministic_encoding() {
        let mut set1 = BTreeSet::new();
        (0..1000u32).for_each(|i| {
            set1.insert(i);
        });

        let mut set2 = BTreeSet::new();
        (0..1000u32).rev().for_each(|i| {
            set2.insert(i);
        });

        assert_eq!(set1.encode(), set2.encode());
    }

    #[test]
    fn test_btreeset_conformity() {
        let set1 = BTreeSet::<u8>::new();
        let mut expected1 = BytesMut::new();
        0usize.write(&mut expected1);
        assert_eq!(set1.encode(), expected1.freeze());
        assert_eq!(set1.encode_size(), 1);

        let mut set2 = BTreeSet::<u8>::new();
        set2.insert(5u8);
        set2.insert(1u8);
        set2.insert(2u8);

        let mut expected2 = BytesMut::new();
        3usize.write(&mut expected2);
        1u8.write(&mut expected2);
        2u8.write(&mut expected2);
        5u8.write(&mut expected2);
        assert_eq!(set2.encode(), expected2.freeze());
        assert_eq!(set2.encode_size(), 1 + 3 * u8::SIZE);

        let mut set3 = BTreeSet::<Bytes>::new();
        set3.insert(Bytes::from_static(b"cherry"));
        set3.insert(Bytes::from_static(b"apple"));
        set3.insert(Bytes::from_static(b"banana"));

        let mut expected3 = BytesMut::new();
        3usize.write(&mut expected3);
        Bytes::from_static(b"apple").write(&mut expected3);
        Bytes::from_static(b"banana").write(&mut expected3);
        Bytes::from_static(b"cherry").write(&mut expected3);
        assert_eq!(set3.encode(), expected3.freeze());
        // The length prefix encodes the actual item count (3), not a placeholder.
        let expected_size = set3.len().encode_size()
            + Bytes::from_static(b"apple").encode_size()
            + Bytes::from_static(b"banana").encode_size()
            + Bytes::from_static(b"cherry").encode_size();
        assert_eq!(set3.encode_size(), expected_size);
    }

    #[test]
    fn test_empty_hashset() {
        let set = HashSet::<u32>::new();
        round_trip_hash(&set, (..).into(), ());
        assert_eq!(set.encode_size(), 1);
        let encoded = set.encode();
        assert_eq!(encoded, Bytes::from_static(&[0]));
    }

    #[test]
    fn test_simple_hashset_u32() {
        let mut set = HashSet::new();
        set.insert(1u32);
        set.insert(5u32);
        set.insert(2u32);
        round_trip_hash(&set, (..).into(), ());
        assert_eq!(set.encode_size(), 1 + 3 * u32::SIZE);

        // Items are written in ascending order regardless of hash iteration order.
        let mut expected = BytesMut::new();
        3usize.write(&mut expected);
        1u32.write(&mut expected);
        2u32.write(&mut expected);
        5u32.write(&mut expected);
        assert_eq!(set.encode(), expected.freeze());
    }

    #[test]
    fn test_large_hashset() {
        let set: HashSet<_> = (0..1000u16).collect();
        round_trip_hash(&set, (1000..=1000).into(), ());

        let set: HashSet<_> = (0..1000usize).collect();
        round_trip_hash(&set, (1000..=1000).into(), (..=1000).into());
    }

    #[test]
    fn test_hashset_with_variable_items() {
        let mut set = HashSet::new();
        set.insert(Bytes::from_static(b"apple"));
        set.insert(Bytes::from_static(b"banana"));
        set.insert(Bytes::from_static(b"cherry"));

        let set_range = 0..=10;
        let item_range = ..=10;
        round_trip_hash(&set, set_range.into(), item_range.into());
    }

    #[test]
    fn test_hashset_decode_length_limit_exceeded() {
        let mut set = HashSet::new();
        set.insert(1u32);
        set.insert(5u32);

        let encoded = set.encode();
        let config_tuple = ((0..=1).into(), ());

        let result = HashSet::<u32>::decode_cfg(encoded, &config_tuple);
        assert!(matches!(result, Err(Error::InvalidLength(2))));
    }

    #[test]
    fn test_hashset_decode_item_length_limit_exceeded() {
        let mut set = HashSet::new();
        set.insert(Bytes::from_static(b"longitem"));
        let set_range = 0..=10;
        let restrictive_item_range = ..=5;
        let encoded = set.encode();
        let config_tuple = (set_range.into(), restrictive_item_range.into());
        let result = HashSet::<Bytes>::decode_cfg(encoded, &config_tuple);

        assert!(matches!(result, Err(Error::InvalidLength(8))));
    }

    #[test]
    fn test_hashset_decode_invalid_item_order() {
        let mut encoded = BytesMut::new();
        2usize.write(&mut encoded);
        5u32.write(&mut encoded);
        2u32.write(&mut encoded);
        let config_tuple = ((..).into(), ());

        let result = HashSet::<u32>::decode_cfg(encoded, &config_tuple);
        assert!(matches!(
            result,
            Err(Error::Invalid("HashSet", "Items must ascend"))
        ));
    }

    #[test]
    fn test_hashset_decode_duplicate_item() {
        let mut encoded = BytesMut::new();
        2usize.write(&mut encoded);
        1u32.write(&mut encoded);
        1u32.write(&mut encoded);
        let config_tuple = ((..).into(), ());
        let result = HashSet::<u32>::decode_cfg(encoded, &config_tuple);
        assert!(matches!(
            result,
            Err(Error::Invalid("HashSet", "Duplicate item"))
        ));
    }

    #[test]
    fn test_hashset_decode_end_of_buffer() {
        let mut set = HashSet::new();
        set.insert(1u32);
        set.insert(5u32);

        let mut encoded = set.encode();
        encoded.truncate(set.encode_size() - 2);
        let config_tuple = ((..).into(), ());
        let result = HashSet::<u32>::decode_cfg(encoded, &config_tuple);
        assert!(matches!(result, Err(Error::EndOfBuffer)));
    }

    #[test]
    fn test_hashset_decode_extra_data() {
        let mut set = HashSet::new();
        set.insert(1u32);

        let mut encoded = set.encode();
        encoded.put_u8(0xFF);
        let config_tuple = ((..).into(), ());
        let result = HashSet::<u32>::decode_cfg(encoded.clone(), &config_tuple);
        assert!(matches!(result, Err(Error::ExtraData(1))));

        // `read_cfg` does not require the buffer to be fully consumed.
        let read_result = HashSet::<u32>::read_cfg(&mut encoded.clone(), &config_tuple);
        assert!(read_result.is_ok());
        let decoded_set = read_result.unwrap();
        assert_eq!(decoded_set.len(), 1);
        assert!(decoded_set.contains(&1u32));
    }

    #[test]
    fn test_hashset_deterministic_encoding() {
        let mut set1 = HashSet::new();
        (0..1000u32).for_each(|i| {
            set1.insert(i);
        });

        let mut set2 = HashSet::new();
        (0..1000u32).rev().for_each(|i| {
            set2.insert(i);
        });

        assert_eq!(set1.encode(), set2.encode());
    }

    #[test]
    fn test_hashset_conformity() {
        let set1 = HashSet::<u8>::new();
        let mut expected1 = BytesMut::new();
        0usize.write(&mut expected1);
        assert_eq!(set1.encode(), expected1.freeze());
        assert_eq!(set1.encode_size(), 1);

        let mut set2 = HashSet::<u8>::new();
        set2.insert(5u8);
        set2.insert(1u8);
        set2.insert(2u8);

        let mut expected2 = BytesMut::new();
        3usize.write(&mut expected2);
        1u8.write(&mut expected2);
        2u8.write(&mut expected2);
        5u8.write(&mut expected2);
        assert_eq!(set2.encode(), expected2.freeze());
        assert_eq!(set2.encode_size(), 1 + 3 * u8::SIZE);

        let mut set3 = HashSet::<Bytes>::new();
        set3.insert(Bytes::from_static(b"cherry"));
        set3.insert(Bytes::from_static(b"apple"));
        set3.insert(Bytes::from_static(b"banana"));

        let mut expected3 = BytesMut::new();
        3usize.write(&mut expected3);
        Bytes::from_static(b"apple").write(&mut expected3);
        Bytes::from_static(b"banana").write(&mut expected3);
        Bytes::from_static(b"cherry").write(&mut expected3);
        assert_eq!(set3.encode(), expected3.freeze());
        // The length prefix encodes the actual item count (3), not a placeholder.
        let expected_size = set3.len().encode_size()
            + Bytes::from_static(b"apple").encode_size()
            + Bytes::from_static(b"banana").encode_size()
            + Bytes::from_static(b"cherry").encode_size();
        assert_eq!(set3.encode_size(), expected_size);
    }
}