1use crate::{Get, TryCollect};
21use alloc::collections::BTreeSet;
22use core::{borrow::Borrow, marker::PhantomData, ops::Deref};
23#[cfg(feature = "serde")]
24use serde::{
25 de::{Error, SeqAccess, Visitor},
26 Deserialize, Deserializer, Serialize,
27};
28
/// A `BTreeSet` wrapper whose size is capped by the type-level bound `S`.
///
/// The tuple holds the underlying [`BTreeSet`] plus a zero-sized marker for `S`. The checked
/// entry points in this file (`try_insert`, `try_mutate`, `TryFrom<BTreeSet<T>>`, decoding and
/// deserialization) refuse contents longer than `S::get()`.
#[cfg_attr(feature = "serde", derive(Serialize), serde(transparent))]
#[cfg_attr(
    feature = "scale-codec",
    derive(scale_codec::Encode, scale_info::TypeInfo),
    scale_info(skip_type_params(S))
)]
#[cfg_attr(feature = "jam-codec", derive(jam_codec::Encode))]
pub struct BoundedBTreeSet<T, S>(
    /// The wrapped set holding the actual elements.
    BTreeSet<T>,
    /// Marker tying the bound type `S` to this set; never serialized.
    #[cfg_attr(feature = "serde", serde(skip_serializing))]
    PhantomData<S>,
);
41
#[cfg(feature = "serde")]
impl<'de, T, S: Get<u32>> Deserialize<'de> for BoundedBTreeSet<T, S>
where
    T: Ord + Deserialize<'de>,
    S: Clone,
{
    /// Deserializes a sequence into a bounded set.
    ///
    /// # Errors
    ///
    /// Fails with `"out of bounds"` when the sequence advertises or yields more elements than
    /// `S::get()` allows, and with `"can't convert to usize"` when the bound does not fit into
    /// `usize`. Errors raised by the underlying deserializer are passed through.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        /// Sequence visitor that collects elements into a `BTreeSet`, giving up as soon as the
        /// bound `S` would be exceeded.
        struct BoundedSeqVisitor<T, S>(PhantomData<(T, S)>);

        impl<'de, T, S> Visitor<'de> for BoundedSeqVisitor<T, S>
        where
            T: Ord + Deserialize<'de>,
            S: Get<u32> + Clone,
        {
            type Value = BTreeSet<T>;

            fn expecting(&self, formatter: &mut core::fmt::Formatter) -> core::fmt::Result {
                formatter.write_str("a sequence")
            }

            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
            where
                A: SeqAccess<'de>,
            {
                let hinted = seq.size_hint().unwrap_or(0);
                let limit = usize::try_from(S::get())
                    .map_err(|_| A::Error::custom("can't convert to usize"))?;

                // Reject early when the input already announces too many elements.
                if hinted > limit {
                    return Err(A::Error::custom("out of bounds"));
                }

                let mut collected = BTreeSet::new();
                while let Some(element) = seq.next_element()? {
                    // The size hint is optional, so the bound is re-checked per element.
                    if collected.len() >= limit {
                        return Err(A::Error::custom("out of bounds"));
                    }
                    collected.insert(element);
                }

                Ok(collected)
            }
        }

        let raw = deserializer.deserialize_seq(BoundedSeqVisitor::<T, S>(PhantomData))?;
        BoundedBTreeSet::<T, S>::try_from(raw).map_err(|_| Error::custom("out of bounds"))
    }
}
98
99impl<T, S> BoundedBTreeSet<T, S>
100where
101 S: Get<u32>,
102{
103 pub fn bound() -> usize {
105 S::get() as usize
106 }
107}
108
109impl<T, S> BoundedBTreeSet<T, S>
110where
111 T: Ord,
112 S: Get<u32>,
113{
114 fn unchecked_from(t: BTreeSet<T>) -> Self {
116 Self(t, Default::default())
117 }
118
119 pub fn new() -> Self {
123 BoundedBTreeSet(BTreeSet::new(), PhantomData)
124 }
125
126 pub fn into_inner(self) -> BTreeSet<T> {
131 debug_assert!(self.0.len() <= Self::bound());
132 self.0
133 }
134
135 pub fn try_mutate(mut self, mut mutate: impl FnMut(&mut BTreeSet<T>)) -> Option<Self> {
143 mutate(&mut self.0);
144 (self.0.len() <= Self::bound()).then(move || self)
145 }
146
147 pub fn clear(&mut self) {
149 self.0.clear()
150 }
151
152 pub fn try_insert(&mut self, item: T) -> Result<bool, T> {
157 if self.len() < Self::bound() || self.0.contains(&item) {
158 Ok(self.0.insert(item))
159 } else {
160 Err(item)
161 }
162 }
163
164 pub fn remove<Q>(&mut self, item: &Q) -> bool
169 where
170 T: Borrow<Q>,
171 Q: Ord + ?Sized,
172 {
173 self.0.remove(item)
174 }
175
176 pub fn take<Q>(&mut self, value: &Q) -> Option<T>
181 where
182 T: Borrow<Q> + Ord,
183 Q: Ord + ?Sized,
184 {
185 self.0.take(value)
186 }
187
188 pub fn is_full(&self) -> bool {
190 self.len() >= Self::bound()
191 }
192}
193
194impl<T, S> Default for BoundedBTreeSet<T, S>
195where
196 T: Ord,
197 S: Get<u32>,
198{
199 fn default() -> Self {
200 Self::new()
201 }
202}
203
204impl<T, S> Clone for BoundedBTreeSet<T, S>
205where
206 BTreeSet<T>: Clone,
207{
208 fn clone(&self) -> Self {
209 BoundedBTreeSet(self.0.clone(), PhantomData)
210 }
211}
212
213impl<T, S> core::fmt::Debug for BoundedBTreeSet<T, S>
214where
215 BTreeSet<T>: core::fmt::Debug,
216 S: Get<u32>,
217{
218 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
219 f.debug_tuple("BoundedBTreeSet").field(&self.0).field(&Self::bound()).finish()
220 }
221}
222
#[cfg(feature = "std")]
impl<T: std::hash::Hash, S> std::hash::Hash for BoundedBTreeSet<T, S> {
    /// Hashes only the inner set; the bound type `S` does not contribute to the hash.
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        std::hash::Hash::hash(&self.0, state)
    }
}
231
232impl<T, S1, S2> PartialEq<BoundedBTreeSet<T, S1>> for BoundedBTreeSet<T, S2>
233where
234 BTreeSet<T>: PartialEq,
235 S1: Get<u32>,
236 S2: Get<u32>,
237{
238 fn eq(&self, other: &BoundedBTreeSet<T, S1>) -> bool {
239 S1::get() == S2::get() && self.0 == other.0
240 }
241}
242
/// Marker impl: equality between bounded sets of the same bound type is a full equivalence
/// relation whenever it is one for the inner `BTreeSet<T>`.
impl<T, S> Eq for BoundedBTreeSet<T, S>
where
    BTreeSet<T>: Eq,
    S: Get<u32>,
{
}
249
250impl<T, S> PartialEq<BTreeSet<T>> for BoundedBTreeSet<T, S>
251where
252 BTreeSet<T>: PartialEq,
253 S: Get<u32>,
254{
255 fn eq(&self, other: &BTreeSet<T>) -> bool {
256 self.0 == *other
257 }
258}
259
260impl<T, S> PartialOrd for BoundedBTreeSet<T, S>
261where
262 BTreeSet<T>: PartialOrd,
263 S: Get<u32>,
264{
265 fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
266 self.0.partial_cmp(&other.0)
267 }
268}
269
270impl<T, S> Ord for BoundedBTreeSet<T, S>
271where
272 BTreeSet<T>: Ord,
273 S: Get<u32>,
274{
275 fn cmp(&self, other: &Self) -> core::cmp::Ordering {
276 self.0.cmp(&other.0)
277 }
278}
279
280impl<T, S> IntoIterator for BoundedBTreeSet<T, S> {
281 type Item = T;
282 type IntoIter = alloc::collections::btree_set::IntoIter<T>;
283
284 fn into_iter(self) -> Self::IntoIter {
285 self.0.into_iter()
286 }
287}
288
289impl<'a, T, S> IntoIterator for &'a BoundedBTreeSet<T, S> {
290 type Item = &'a T;
291 type IntoIter = alloc::collections::btree_set::Iter<'a, T>;
292
293 fn into_iter(self) -> Self::IntoIter {
294 self.0.iter()
295 }
296}
297
298impl<T, S> Deref for BoundedBTreeSet<T, S>
299where
300 T: Ord,
301{
302 type Target = BTreeSet<T>;
303
304 fn deref(&self) -> &Self::Target {
305 &self.0
306 }
307}
308
309impl<T, S> AsRef<BTreeSet<T>> for BoundedBTreeSet<T, S>
310where
311 T: Ord,
312{
313 fn as_ref(&self) -> &BTreeSet<T> {
314 &self.0
315 }
316}
317
318impl<T, S> From<BoundedBTreeSet<T, S>> for BTreeSet<T>
319where
320 T: Ord,
321{
322 fn from(set: BoundedBTreeSet<T, S>) -> Self {
323 set.0
324 }
325}
326
327impl<T, S> TryFrom<BTreeSet<T>> for BoundedBTreeSet<T, S>
328where
329 T: Ord,
330 S: Get<u32>,
331{
332 type Error = ();
333
334 fn try_from(value: BTreeSet<T>) -> Result<Self, Self::Error> {
335 (value.len() <= Self::bound())
336 .then(move || BoundedBTreeSet(value, PhantomData))
337 .ok_or(())
338 }
339}
340
341impl<I, T, Bound> TryCollect<BoundedBTreeSet<T, Bound>> for I
342where
343 T: Ord,
344 I: ExactSizeIterator + Iterator<Item = T>,
345 Bound: Get<u32>,
346{
347 type Error = &'static str;
348
349 fn try_collect(self) -> Result<BoundedBTreeSet<T, Bound>, Self::Error> {
350 if self.len() > Bound::get() as usize {
351 Err("iterator length too big")
352 } else {
353 Ok(BoundedBTreeSet::<T, Bound>::unchecked_from(self.collect::<BTreeSet<T>>()))
354 }
355 }
356}
357
// Generates the codec trait implementations for `BoundedBTreeSet` against the crate named by
// `$codec`. The expansion brings its own `use` items, so it is meant to be invoked inside a
// dedicated module.
#[cfg(any(feature = "scale-codec", feature = "jam-codec"))]
macro_rules! codec_impl {
    ($codec:ident) => {
        use super::*;
        use crate::codec_utils::PrependCompactInput;
        use $codec::{
            Compact, Decode, DecodeLength, DecodeWithMemTracking, Encode, EncodeLike, Error, Input,
            MaxEncodedLen,
        };

        impl<T, S> Decode for BoundedBTreeSet<T, S>
        where
            T: Decode + Ord,
            S: Get<u32>,
        {
            /// Decodes a bounded set, rejecting the input right after the length prefix when
            /// the declared length exceeds `S::get()`.
            fn decode<I: Input>(input: &mut I) -> Result<Self, Error> {
                let declared = <Compact<u32>>::decode(input)?;
                if declared.0 > S::get() {
                    return Err("BoundedBTreeSet exceeds its limit".into());
                }

                // The prefix has already been consumed from `input`, so it is re-encoded and
                // passed along with the remaining input to `BTreeSet::decode`.
                // NOTE(review): presumably `PrependCompactInput` replays `encoded_len` before
                // reading from `inner` -- confirm against `codec_utils`.
                let prefix = declared.encode();
                let mut replay = PrependCompactInput {
                    encoded_len: prefix.as_ref(),
                    read: 0,
                    inner: input,
                };
                let inner = BTreeSet::decode(&mut replay)?;
                Ok(Self(inner, PhantomData))
            }

            /// Skips an encoded value by delegating to `BTreeSet<T>`'s `skip`.
            fn skip<I: Input>(input: &mut I) -> Result<(), Error> {
                <BTreeSet<T> as Decode>::skip(input)
            }
        }

        impl<T, S> MaxEncodedLen for BoundedBTreeSet<T, S>
        where
            T: MaxEncodedLen,
            S: Get<u32>,
        {
            /// Upper bound on the encoded size: `bound` elements of maximal size plus the
            /// compact encoding of the bound itself, with saturating arithmetic.
            fn max_encoded_len() -> usize {
                let payload = Self::bound().saturating_mul(T::max_encoded_len());
                let prefix = Compact(S::get()).encoded_size();
                payload.saturating_add(prefix)
            }
        }

        impl<T, S> DecodeLength for BoundedBTreeSet<T, S> {
            /// Reads the element count from the encoded bytes by delegating to
            /// `BTreeSet<T>`'s `DecodeLength` impl.
            fn len(self_encoded: &[u8]) -> Result<usize, Error> {
                <BTreeSet<T> as DecodeLength>::len(self_encoded)
            }
        }

        // Marker impl: a bounded set may be used wherever an encoded `BTreeSet<T>` is expected.
        impl<T, S> EncodeLike<BTreeSet<T>> for BoundedBTreeSet<T, S> where BTreeSet<T>: Encode {}

        impl<T, S> DecodeWithMemTracking for BoundedBTreeSet<T, S>
        where
            T: Decode + Ord,
            S: Get<u32>,
        {
        }
    };
}
423
/// Instantiates the `codec_impl!` trait implementations against the `scale_codec` crate.
/// Kept in its own module because the macro expansion contains `use` items.
#[cfg(feature = "scale-codec")]
mod scale_codec_impl {
    codec_impl!(scale_codec);
}
428
/// Instantiates the `codec_impl!` trait implementations against the `jam_codec` crate.
/// Kept in its own module because the macro expansion contains `use` items.
#[cfg(feature = "jam-codec")]
mod jam_codec_impl {
    codec_impl!(jam_codec);
}
433
#[cfg(test)]
mod test {
    use super::*;
    use crate::ConstU32;
    use alloc::{vec, vec::Vec};
    #[cfg(feature = "scale-codec")]
    use scale_codec::{Compact, CompactLen, Decode, Encode};

    /// Builds a plain `BTreeSet` from a slice of keys.
    fn set_from_keys<T>(keys: &[T]) -> BTreeSet<T>
    where
        T: Ord + Copy,
    {
        keys.iter().copied().collect()
    }

    /// Builds a bounded set from a slice of keys, panicking if the keys exceed the bound.
    fn boundedset_from_keys<T, S>(keys: &[T]) -> BoundedBTreeSet<T, S>
    where
        T: Ord + Copy,
        S: Get<u32>,
    {
        set_from_keys(keys).try_into().unwrap()
    }

    #[test]
    #[cfg(feature = "scale-codec")]
    fn encoding_same_as_unbounded_set() {
        let b = boundedset_from_keys::<u32, ConstU32<7>>(&[1, 2, 3, 4, 5, 6]);
        let m = set_from_keys(&[1, 2, 3, 4, 5, 6]);

        assert_eq!(b.encode(), m.encode());
    }

    #[test]
    fn try_insert_works() {
        let mut bounded = boundedset_from_keys::<u32, ConstU32<4>>(&[1, 2, 3]);
        bounded.try_insert(0).unwrap();
        assert_eq!(*bounded, set_from_keys(&[1, 0, 2, 3]));

        // The set is now full: a new key is rejected and the contents stay untouched.
        assert!(bounded.try_insert(9).is_err());
        assert_eq!(*bounded, set_from_keys(&[1, 0, 2, 3]));
    }

    #[test]
    fn deref_coercion_works() {
        let bounded = boundedset_from_keys::<u32, ConstU32<7>>(&[1, 2, 3]);
        // These methods live on `BTreeSet` and are reached through `Deref`.
        assert_eq!(bounded.len(), 3);
        assert!(bounded.iter().next().is_some());
        assert!(!bounded.is_empty());
    }

    #[test]
    fn try_mutate_works() {
        let bounded = boundedset_from_keys::<u32, ConstU32<7>>(&[1, 2, 3, 4, 5, 6]);
        let bounded = bounded
            .try_mutate(|v| {
                v.insert(7);
            })
            .unwrap();
        assert_eq!(bounded.len(), 7);
        // An eighth element would exceed the bound of 7.
        assert!(bounded
            .try_mutate(|v| {
                v.insert(8);
            })
            .is_none());
    }

    // Compares a bounded set against a plain `BTreeSet`; the historical test name is kept.
    #[test]
    fn btree_map_eq_works() {
        let bounded = boundedset_from_keys::<u32, ConstU32<7>>(&[1, 2, 3, 4, 5, 6]);
        assert_eq!(bounded, set_from_keys(&[1, 2, 3, 4, 5, 6]));
    }

    #[test]
    #[cfg(feature = "scale-codec")]
    fn too_big_fail_to_decode() {
        let v: Vec<u32> = vec![1, 2, 3, 4, 5];
        assert_eq!(
            BoundedBTreeSet::<u32, ConstU32<4>>::decode(&mut &v.encode()[..]),
            Err("BoundedBTreeSet exceeds its limit".into()),
        );
    }

    #[test]
    #[cfg(feature = "scale-codec")]
    fn dont_consume_more_data_than_bounded_len() {
        let s = set_from_keys(&[1, 2, 3, 4, 5, 6]);
        let data = s.encode();
        let data_input = &mut &data[..];

        BoundedBTreeSet::<u32, ConstU32<4>>::decode(data_input).unwrap_err();
        // Decoding must stop right after the compact length prefix. That prefix encodes the
        // number of elements, so its size is derived from `s.len()`, not the byte length.
        let prefix_len = Compact::<u32>::compact_len(&(s.len() as u32));
        assert_eq!(data_input.len(), data.len() - prefix_len);
    }

    #[test]
    fn unequal_eq_impl_insert_works() {
        // Equality and ordering only look at the first field; the second field tells
        // otherwise "equal" values apart.
        #[derive(Debug)]
        struct Unequal(u32, bool);

        impl PartialEq for Unequal {
            fn eq(&self, other: &Self) -> bool {
                self.0 == other.0
            }
        }
        impl Eq for Unequal {}

        impl Ord for Unequal {
            fn cmp(&self, other: &Self) -> core::cmp::Ordering {
                self.0.cmp(&other.0)
            }
        }

        impl PartialOrd for Unequal {
            fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
                Some(self.cmp(other))
            }
        }

        let mut set = BoundedBTreeSet::<Unequal, ConstU32<4>>::new();

        // Fill the set up to its bound.
        for i in 0..4 {
            set.try_insert(Unequal(i, false)).unwrap();
        }

        // A genuinely new key does not fit any more.
        set.try_insert(Unequal(5, false)).unwrap_err();

        // A key that compares equal to an existing one is accepted, but the stored element is
        // not replaced.
        set.try_insert(Unequal(0, true)).unwrap();
        assert_eq!(set.len(), 4);
        let zero_item = set.get(&Unequal(0, true)).unwrap();
        assert_eq!(zero_item.0, 0);
        assert!(!zero_item.1);
    }

    #[test]
    fn eq_works() {
        // Same bound type.
        let b1 = boundedset_from_keys::<u32, ConstU32<7>>(&[1, 2]);
        let b2 = boundedset_from_keys::<u32, ConstU32<7>>(&[1, 2]);
        assert_eq!(b1, b2);

        // Different bound types that carry the same numeric value.
        crate::parameter_types! {
            B1: u32 = 7;
            B2: u32 = 7;
        }
        let b1 = boundedset_from_keys::<u32, B1>(&[1, 2]);
        let b2 = boundedset_from_keys::<u32, B2>(&[1, 2]);
        assert_eq!(b1, b2);
    }

    #[test]
    fn can_be_collected() {
        let b1 = boundedset_from_keys::<u32, ConstU32<5>>(&[1, 2, 3, 4]);
        let b2: BoundedBTreeSet<u32, ConstU32<5>> = b1.iter().map(|k| k + 1).try_collect().unwrap();
        assert_eq!(b2.into_iter().collect::<Vec<_>>(), vec![2, 3, 4, 5]);

        // Exactly at the bound.
        let b2: BoundedBTreeSet<u32, ConstU32<4>> = b1.iter().map(|k| k + 1).try_collect().unwrap();
        assert_eq!(b2.into_iter().collect::<Vec<_>>(), vec![2, 3, 4, 5]);

        // Adaptors that keep the iterator exact-sized still work.
        let b2: BoundedBTreeSet<u32, ConstU32<5>> =
            b1.iter().map(|k| k + 1).rev().skip(2).try_collect().unwrap();
        assert_eq!(b2.into_iter().collect::<Vec<_>>(), vec![2, 3]);

        let b2: BoundedBTreeSet<u32, ConstU32<5>> =
            b1.iter().map(|k| k + 1).take(2).try_collect().unwrap();
        assert_eq!(b2.into_iter().collect::<Vec<_>>(), vec![2, 3]);

        // Too many items for the bound.
        let b2: Result<BoundedBTreeSet<u32, ConstU32<3>>, _> = b1.iter().map(|k| k + 1).try_collect();
        assert!(b2.is_err());

        let b2: Result<BoundedBTreeSet<u32, ConstU32<1>>, _> =
            b1.iter().map(|k| k + 1).skip(2).try_collect();
        assert!(b2.is_err());
    }

    // Compile-time check that a struct containing a bounded set can derive `Hash`.
    #[test]
    #[cfg(feature = "std")]
    fn container_can_derive_hash() {
        #[derive(Hash, Default)]
        struct Foo {
            bar: u8,
            set: BoundedBTreeSet<String, ConstU32<16>>,
        }
        let _foo = Foo::default();
    }

    #[test]
    fn is_full_works() {
        let mut bounded = boundedset_from_keys::<u32, ConstU32<4>>(&[1, 2, 3]);
        assert!(!bounded.is_full());
        bounded.try_insert(0).unwrap();
        assert_eq!(*bounded, set_from_keys(&[1, 0, 2, 3]));

        assert!(bounded.is_full());
        assert!(bounded.try_insert(9).is_err());
        assert_eq!(*bounded, set_from_keys(&[1, 0, 2, 3]));
    }

    #[cfg(feature = "serde")]
    mod serde {
        use super::*;
        use crate::alloc::string::ToString as _;

        #[test]
        fn test_serializer() {
            let mut c = BoundedBTreeSet::<u32, ConstU32<6>>::new();
            c.try_insert(0).unwrap();
            c.try_insert(1).unwrap();
            c.try_insert(2).unwrap();

            assert_eq!(serde_json::json!(&c).to_string(), r#"[0,1,2]"#);
        }

        #[test]
        fn test_deserializer() {
            let c: BoundedBTreeSet<u32, ConstU32<6>> =
                serde_json::from_str(r#"[0,1,2]"#).expect("three elements fit a bound of six");

            assert_eq!(c.len(), 3);
            assert!(c.contains(&0));
            assert!(c.contains(&1));
            assert!(c.contains(&2));
        }

        #[test]
        fn test_deserializer_bound() {
            // Exactly as many elements as the bound allows.
            let c: BoundedBTreeSet<u32, ConstU32<3>> =
                serde_json::from_str(r#"[0,1,2]"#).expect("three elements fit a bound of three");

            assert_eq!(c.len(), 3);
            assert!(c.contains(&0));
            assert!(c.contains(&1));
            assert!(c.contains(&2));
        }

        #[test]
        fn test_deserializer_failed() {
            let c: Result<BoundedBTreeSet<u32, ConstU32<4>>, serde_json::error::Error> =
                serde_json::from_str(r#"[0,1,2,3,4]"#);

            match c {
                Err(msg) => assert_eq!(msg.to_string(), "out of bounds at line 1 column 11"),
                _ => unreachable!("deserializer must raise error"),
            }
        }
    }
}