1use crate::{Get, TryCollect};
21use alloc::collections::BTreeSet;
22use codec::{Compact, Decode, Encode, MaxEncodedLen};
23use core::{borrow::Borrow, marker::PhantomData, ops::Deref};
24#[cfg(feature = "serde")]
25use serde::{
26 de::{Error, SeqAccess, Visitor},
27 Deserialize, Deserializer, Serialize,
28};
29
30#[cfg_attr(feature = "serde", derive(Serialize), serde(transparent))]
38#[derive(Encode, scale_info::TypeInfo)]
39#[scale_info(skip_type_params(S))]
40pub struct BoundedBTreeSet<T, S>(BTreeSet<T>, #[cfg_attr(feature = "serde", serde(skip_serializing))] PhantomData<S>);
41
/// Deserializes a sequence into a bounded set, rejecting input that exceeds the bound.
///
/// Only `S: Get<u32>` is required of the bound marker; nothing in the implementation clones it.
#[cfg(feature = "serde")]
impl<'de, T, S> Deserialize<'de> for BoundedBTreeSet<T, S>
where
    T: Ord + Deserialize<'de>,
    S: Get<u32>,
{
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        /// Visitor that collects a sequence into a plain `BTreeSet`, enforcing the bound
        /// while elements are still being read.
        struct BTreeSetVisitor<T, S>(PhantomData<(T, S)>);

        impl<'de, T, S> Visitor<'de> for BTreeSetVisitor<T, S>
        where
            T: Ord + Deserialize<'de>,
            S: Get<u32>,
        {
            type Value = BTreeSet<T>;

            fn expecting(&self, formatter: &mut core::fmt::Formatter) -> core::fmt::Result {
                formatter.write_str("a sequence")
            }

            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
            where
                A: SeqAccess<'de>,
            {
                let max = usize::try_from(S::get())
                    .map_err(|_| A::Error::custom("can't convert to usize"))?;

                // Reject early when the format already announces too many elements.
                if seq.size_hint().unwrap_or(0) > max {
                    return Err(A::Error::custom("out of bounds"));
                }

                let mut values = BTreeSet::new();
                while let Some(value) = seq.next_element()? {
                    // The size hint is optional, so the bound is re-checked before every
                    // insertion; this stops reading as soon as the set is full and another
                    // element shows up.
                    if values.len() >= max {
                        return Err(A::Error::custom("out of bounds"));
                    }
                    values.insert(value);
                }

                Ok(values)
            }
        }

        deserializer
            .deserialize_seq(BTreeSetVisitor::<T, S>(PhantomData))
            .and_then(|set| Self::try_from(set).map_err(|_| D::Error::custom("out of bounds")))
    }
}
98
99impl<T, S> Decode for BoundedBTreeSet<T, S>
100where
101 T: Decode + Ord,
102 S: Get<u32>,
103{
104 fn decode<I: codec::Input>(input: &mut I) -> Result<Self, codec::Error> {
105 let len: u32 = <Compact<u32>>::decode(input)?.into();
108 if len > S::get() {
109 return Err("BoundedBTreeSet exceeds its limit".into())
110 }
111 input.descend_ref()?;
112 let inner = Result::from_iter((0..len).map(|_| Decode::decode(input)))?;
113 input.ascend_ref();
114 Ok(Self(inner, PhantomData))
115 }
116
117 fn skip<I: codec::Input>(input: &mut I) -> Result<(), codec::Error> {
118 BTreeSet::<T>::skip(input)
119 }
120}
121
122impl<T, S> BoundedBTreeSet<T, S>
123where
124 S: Get<u32>,
125{
126 pub fn bound() -> usize {
128 S::get() as usize
129 }
130}
131
132impl<T, S> BoundedBTreeSet<T, S>
133where
134 T: Ord,
135 S: Get<u32>,
136{
137 fn unchecked_from(t: BTreeSet<T>) -> Self {
139 Self(t, Default::default())
140 }
141
142 pub fn new() -> Self {
146 BoundedBTreeSet(BTreeSet::new(), PhantomData)
147 }
148
149 pub fn into_inner(self) -> BTreeSet<T> {
154 debug_assert!(self.0.len() <= Self::bound());
155 self.0
156 }
157
158 pub fn try_mutate(mut self, mut mutate: impl FnMut(&mut BTreeSet<T>)) -> Option<Self> {
166 mutate(&mut self.0);
167 (self.0.len() <= Self::bound()).then(move || self)
168 }
169
170 pub fn clear(&mut self) {
172 self.0.clear()
173 }
174
175 pub fn try_insert(&mut self, item: T) -> Result<bool, T> {
180 if self.len() < Self::bound() || self.0.contains(&item) {
181 Ok(self.0.insert(item))
182 } else {
183 Err(item)
184 }
185 }
186
187 pub fn remove<Q>(&mut self, item: &Q) -> bool
192 where
193 T: Borrow<Q>,
194 Q: Ord + ?Sized,
195 {
196 self.0.remove(item)
197 }
198
199 pub fn take<Q>(&mut self, value: &Q) -> Option<T>
204 where
205 T: Borrow<Q> + Ord,
206 Q: Ord + ?Sized,
207 {
208 self.0.take(value)
209 }
210
211 pub fn is_full(&self) -> bool {
213 self.len() >= Self::bound()
214 }
215}
216
217impl<T, S> Default for BoundedBTreeSet<T, S>
218where
219 T: Ord,
220 S: Get<u32>,
221{
222 fn default() -> Self {
223 Self::new()
224 }
225}
226
227impl<T, S> Clone for BoundedBTreeSet<T, S>
228where
229 BTreeSet<T>: Clone,
230{
231 fn clone(&self) -> Self {
232 BoundedBTreeSet(self.0.clone(), PhantomData)
233 }
234}
235
236impl<T, S> core::fmt::Debug for BoundedBTreeSet<T, S>
237where
238 BTreeSet<T>: core::fmt::Debug,
239 S: Get<u32>,
240{
241 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
242 f.debug_tuple("BoundedBTreeSet").field(&self.0).field(&Self::bound()).finish()
243 }
244}
245
246#[cfg(feature = "std")]
249impl<T: std::hash::Hash, S> std::hash::Hash for BoundedBTreeSet<T, S> {
250 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
251 self.0.hash(state);
252 }
253}
254
255impl<T, S1, S2> PartialEq<BoundedBTreeSet<T, S1>> for BoundedBTreeSet<T, S2>
256where
257 BTreeSet<T>: PartialEq,
258 S1: Get<u32>,
259 S2: Get<u32>,
260{
261 fn eq(&self, other: &BoundedBTreeSet<T, S1>) -> bool {
262 S1::get() == S2::get() && self.0 == other.0
263 }
264}
265
266impl<T, S> Eq for BoundedBTreeSet<T, S>
267where
268 BTreeSet<T>: Eq,
269 S: Get<u32>,
270{
271}
272
273impl<T, S> PartialEq<BTreeSet<T>> for BoundedBTreeSet<T, S>
274where
275 BTreeSet<T>: PartialEq,
276 S: Get<u32>,
277{
278 fn eq(&self, other: &BTreeSet<T>) -> bool {
279 self.0 == *other
280 }
281}
282
283impl<T, S> PartialOrd for BoundedBTreeSet<T, S>
284where
285 BTreeSet<T>: PartialOrd,
286 S: Get<u32>,
287{
288 fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
289 self.0.partial_cmp(&other.0)
290 }
291}
292
293impl<T, S> Ord for BoundedBTreeSet<T, S>
294where
295 BTreeSet<T>: Ord,
296 S: Get<u32>,
297{
298 fn cmp(&self, other: &Self) -> core::cmp::Ordering {
299 self.0.cmp(&other.0)
300 }
301}
302
303impl<T, S> IntoIterator for BoundedBTreeSet<T, S> {
304 type Item = T;
305 type IntoIter = alloc::collections::btree_set::IntoIter<T>;
306
307 fn into_iter(self) -> Self::IntoIter {
308 self.0.into_iter()
309 }
310}
311
312impl<'a, T, S> IntoIterator for &'a BoundedBTreeSet<T, S> {
313 type Item = &'a T;
314 type IntoIter = alloc::collections::btree_set::Iter<'a, T>;
315
316 fn into_iter(self) -> Self::IntoIter {
317 self.0.iter()
318 }
319}
320
321impl<T, S> MaxEncodedLen for BoundedBTreeSet<T, S>
322where
323 T: MaxEncodedLen,
324 S: Get<u32>,
325{
326 fn max_encoded_len() -> usize {
327 Self::bound()
328 .saturating_mul(T::max_encoded_len())
329 .saturating_add(codec::Compact(S::get()).encoded_size())
330 }
331}
332
333impl<T, S> Deref for BoundedBTreeSet<T, S>
334where
335 T: Ord,
336{
337 type Target = BTreeSet<T>;
338
339 fn deref(&self) -> &Self::Target {
340 &self.0
341 }
342}
343
344impl<T, S> AsRef<BTreeSet<T>> for BoundedBTreeSet<T, S>
345where
346 T: Ord,
347{
348 fn as_ref(&self) -> &BTreeSet<T> {
349 &self.0
350 }
351}
352
353impl<T, S> From<BoundedBTreeSet<T, S>> for BTreeSet<T>
354where
355 T: Ord,
356{
357 fn from(set: BoundedBTreeSet<T, S>) -> Self {
358 set.0
359 }
360}
361
362impl<T, S> TryFrom<BTreeSet<T>> for BoundedBTreeSet<T, S>
363where
364 T: Ord,
365 S: Get<u32>,
366{
367 type Error = ();
368
369 fn try_from(value: BTreeSet<T>) -> Result<Self, Self::Error> {
370 (value.len() <= Self::bound())
371 .then(move || BoundedBTreeSet(value, PhantomData))
372 .ok_or(())
373 }
374}
375
376impl<T, S> codec::DecodeLength for BoundedBTreeSet<T, S> {
377 fn len(self_encoded: &[u8]) -> Result<usize, codec::Error> {
378 <BTreeSet<T> as codec::DecodeLength>::len(self_encoded)
382 }
383}
384
385impl<T, S> codec::EncodeLike<BTreeSet<T>> for BoundedBTreeSet<T, S> where BTreeSet<T>: Encode {}
386
387impl<I, T, Bound> TryCollect<BoundedBTreeSet<T, Bound>> for I
388where
389 T: Ord,
390 I: ExactSizeIterator + Iterator<Item = T>,
391 Bound: Get<u32>,
392{
393 type Error = &'static str;
394
395 fn try_collect(self) -> Result<BoundedBTreeSet<T, Bound>, Self::Error> {
396 if self.len() > Bound::get() as usize {
397 Err("iterator length too big")
398 } else {
399 Ok(BoundedBTreeSet::<T, Bound>::unchecked_from(self.collect::<BTreeSet<T>>()))
400 }
401 }
402}
403
#[cfg(test)]
mod test {
    use super::*;
    use crate::ConstU32;
    use alloc::{vec, vec::Vec};
    use codec::CompactLen;

    /// Builds an unbounded set from a slice of keys.
    fn set_from_keys<T>(keys: &[T]) -> BTreeSet<T>
    where
        T: Ord + Copy,
    {
        keys.iter().copied().collect()
    }

    /// Builds a bounded set from a slice of keys; panics if the keys do not fit the bound.
    fn boundedset_from_keys<T, S>(keys: &[T]) -> BoundedBTreeSet<T, S>
    where
        T: Ord + Copy,
        S: Get<u32>,
    {
        set_from_keys(keys).try_into().unwrap()
    }

    #[test]
    fn encoding_same_as_unbounded_set() {
        let b = boundedset_from_keys::<u32, ConstU32<7>>(&[1, 2, 3, 4, 5, 6]);
        let m = set_from_keys(&[1, 2, 3, 4, 5, 6]);

        assert_eq!(b.encode(), m.encode());
    }

    #[test]
    fn try_insert_works() {
        let mut bounded = boundedset_from_keys::<u32, ConstU32<4>>(&[1, 2, 3]);
        bounded.try_insert(0).unwrap();
        assert_eq!(*bounded, set_from_keys(&[1, 0, 2, 3]));

        // The set is full now: a new element is rejected and the contents stay untouched.
        assert!(bounded.try_insert(9).is_err());
        assert_eq!(*bounded, set_from_keys(&[1, 0, 2, 3]));
    }

    #[test]
    fn deref_coercion_works() {
        let bounded = boundedset_from_keys::<u32, ConstU32<7>>(&[1, 2, 3]);
        // These methods come from `BTreeSet` via `Deref`.
        assert_eq!(bounded.len(), 3);
        assert!(bounded.iter().next().is_some());
        assert!(!bounded.is_empty());
    }

    #[test]
    fn try_mutate_works() {
        let bounded = boundedset_from_keys::<u32, ConstU32<7>>(&[1, 2, 3, 4, 5, 6]);
        let bounded = bounded
            .try_mutate(|v| {
                v.insert(7);
            })
            .unwrap();
        assert_eq!(bounded.len(), 7);
        // One more element would exceed the bound of 7.
        assert!(bounded
            .try_mutate(|v| {
                v.insert(8);
            })
            .is_none());
    }

    // Despite the name, this compares a bounded set against a plain `BTreeSet`.
    #[test]
    fn btree_map_eq_works() {
        let bounded = boundedset_from_keys::<u32, ConstU32<7>>(&[1, 2, 3, 4, 5, 6]);
        assert_eq!(bounded, set_from_keys(&[1, 2, 3, 4, 5, 6]));
    }

    #[test]
    fn too_big_fail_to_decode() {
        let v: Vec<u32> = vec![1, 2, 3, 4, 5];
        assert_eq!(
            BoundedBTreeSet::<u32, ConstU32<4>>::decode(&mut &v.encode()[..]),
            Err("BoundedBTreeSet exceeds its limit".into()),
        );
    }

    #[test]
    fn dont_consume_more_data_than_bounded_len() {
        let s = set_from_keys(&[1, 2, 3, 4, 5, 6]);
        let data = s.encode();
        let data_input = &mut &data[..];

        BoundedBTreeSet::<u32, ConstU32<4>>::decode(data_input).unwrap_err();
        // Only the compact length prefix may have been consumed. That prefix encodes the
        // number of elements, so its size is derived from `s.len()`, not from the byte
        // length of the whole encoding.
        assert_eq!(data_input.len(), data.len() - Compact::<u32>::compact_len(&(s.len() as u32)));
    }

    #[test]
    fn unequal_eq_impl_insert_works() {
        // Equality and ordering look at the first field only; the second field lets the
        // test tell apart two elements that compare as equal.
        #[derive(Debug)]
        struct Unequal(u32, bool);

        impl PartialEq for Unequal {
            fn eq(&self, other: &Self) -> bool {
                self.0 == other.0
            }
        }
        impl Eq for Unequal {}

        impl Ord for Unequal {
            fn cmp(&self, other: &Self) -> core::cmp::Ordering {
                self.0.cmp(&other.0)
            }
        }

        impl PartialOrd for Unequal {
            fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
                Some(self.cmp(other))
            }
        }

        let mut set = BoundedBTreeSet::<Unequal, ConstU32<4>>::new();

        // Fill the set up to its bound.
        for i in 0..4 {
            set.try_insert(Unequal(i, false)).unwrap();
        }

        // A new element no longer fits.
        set.try_insert(Unequal(5, false)).unwrap_err();

        // An element equal to an existing one is accepted even when full; the length does
        // not change and the originally stored element is kept.
        set.try_insert(Unequal(0, true)).unwrap();
        assert_eq!(set.len(), 4);
        let zero_item = set.get(&Unequal(0, true)).unwrap();
        assert_eq!(zero_item.0, 0);
        assert_eq!(zero_item.1, false);
    }

    #[test]
    fn eq_works() {
        // Same bound type.
        let b1 = boundedset_from_keys::<u32, ConstU32<7>>(&[1, 2]);
        let b2 = boundedset_from_keys::<u32, ConstU32<7>>(&[1, 2]);
        assert_eq!(b1, b2);

        // Different bound types carrying the same bound value.
        crate::parameter_types! {
            B1: u32 = 7;
            B2: u32 = 7;
        }
        let b1 = boundedset_from_keys::<u32, B1>(&[1, 2]);
        let b2 = boundedset_from_keys::<u32, B2>(&[1, 2]);
        assert_eq!(b1, b2);
    }

    #[test]
    fn can_be_collected() {
        let b1 = boundedset_from_keys::<u32, ConstU32<5>>(&[1, 2, 3, 4]);
        let b2: BoundedBTreeSet<u32, ConstU32<5>> = b1.iter().map(|k| k + 1).try_collect().unwrap();
        assert_eq!(b2.into_iter().collect::<Vec<_>>(), vec![2, 3, 4, 5]);

        // Exactly as many items as the bound allows.
        let b2: BoundedBTreeSet<u32, ConstU32<4>> = b1.iter().map(|k| k + 1).try_collect().unwrap();
        assert_eq!(b2.into_iter().collect::<Vec<_>>(), vec![2, 3, 4, 5]);

        // Works with adaptors that keep the iterator exact-sized.
        let b2: BoundedBTreeSet<u32, ConstU32<5>> = b1.iter().map(|k| k + 1).rev().skip(2).try_collect().unwrap();
        assert_eq!(b2.into_iter().collect::<Vec<_>>(), vec![2, 3]);

        let b2: BoundedBTreeSet<u32, ConstU32<5>> = b1.iter().map(|k| k + 1).take(2).try_collect().unwrap();
        assert_eq!(b2.into_iter().collect::<Vec<_>>(), vec![2, 3]);

        // Too many items for the bound.
        let b2: Result<BoundedBTreeSet<u32, ConstU32<3>>, _> = b1.iter().map(|k| k + 1).try_collect();
        assert!(b2.is_err());

        let b2: Result<BoundedBTreeSet<u32, ConstU32<1>>, _> = b1.iter().map(|k| k + 1).skip(2).try_collect();
        assert!(b2.is_err());
    }

    #[test]
    #[cfg(feature = "std")]
    fn container_can_derive_hash() {
        // Compile-time check: a struct containing a bounded set can derive `Hash`.
        #[derive(Hash, Default)]
        struct Foo {
            bar: u8,
            set: BoundedBTreeSet<String, ConstU32<16>>,
        }
        let _foo = Foo::default();
    }

    #[test]
    fn is_full_works() {
        let mut bounded = boundedset_from_keys::<u32, ConstU32<4>>(&[1, 2, 3]);
        assert!(!bounded.is_full());
        bounded.try_insert(0).unwrap();
        assert_eq!(*bounded, set_from_keys(&[1, 0, 2, 3]));

        assert!(bounded.is_full());
        assert!(bounded.try_insert(9).is_err());
        assert_eq!(*bounded, set_from_keys(&[1, 0, 2, 3]));
    }

    #[cfg(feature = "serde")]
    mod serde {
        use super::*;
        use crate::alloc::string::ToString as _;

        #[test]
        fn test_serializer() {
            let mut c = BoundedBTreeSet::<u32, ConstU32<6>>::new();
            c.try_insert(0).unwrap();
            c.try_insert(1).unwrap();
            c.try_insert(2).unwrap();

            assert_eq!(serde_json::json!(&c).to_string(), r#"[0,1,2]"#);
        }

        #[test]
        fn test_deserializer() {
            let c: Result<BoundedBTreeSet<u32, ConstU32<6>>, serde_json::error::Error> =
                serde_json::from_str(r#"[0,1,2]"#);
            assert!(c.is_ok());
            let c = c.unwrap();

            assert_eq!(c.len(), 3);
            assert!(c.contains(&0));
            assert!(c.contains(&1));
            assert!(c.contains(&2));
        }

        #[test]
        fn test_deserializer_bound() {
            // Input length equal to the bound is accepted.
            let c: Result<BoundedBTreeSet<u32, ConstU32<3>>, serde_json::error::Error> =
                serde_json::from_str(r#"[0,1,2]"#);
            assert!(c.is_ok());
            let c = c.unwrap();

            assert_eq!(c.len(), 3);
            assert!(c.contains(&0));
            assert!(c.contains(&1));
            assert!(c.contains(&2));
        }

        #[test]
        fn test_deserializer_failed() {
            // Five elements do not fit a bound of four.
            let c: Result<BoundedBTreeSet<u32, ConstU32<4>>, serde_json::error::Error> =
                serde_json::from_str(r#"[0,1,2,3,4]"#);

            match c {
                Err(msg) => assert_eq!(msg.to_string(), "out of bounds at line 1 column 11"),
                _ => unreachable!("deserializer must raise error"),
            }
        }
    }
}