bounded_collections/
bounded_btree_set.rs

1// This file is part of Substrate.
2
3// Copyright (C) 2023 Parity Technologies (UK) Ltd.
4// SPDX-License-Identifier: Apache-2.0
5
6// Licensed under the Apache License, Version 2.0 (the "License");
7// you may not use this file except in compliance with the License.
8// You may obtain a copy of the License at
9//
10// 	http://www.apache.org/licenses/LICENSE-2.0
11//
12// Unless required by applicable law or agreed to in writing, software
13// distributed under the License is distributed on an "AS IS" BASIS,
14// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15// See the License for the specific language governing permissions and
16// limitations under the License.
17
18//! Traits, types and structs to support a bounded `BTreeSet`.
19
20use crate::{Get, TryCollect};
21use alloc::collections::BTreeSet;
22use codec::{Compact, Decode, Encode, MaxEncodedLen};
23use core::{borrow::Borrow, marker::PhantomData, ops::Deref};
24#[cfg(feature = "serde")]
25use serde::{
26	de::{Error, SeqAccess, Visitor},
27	Deserialize, Deserializer, Serialize,
28};
29
30/// A bounded set based on a B-Tree.
31///
32/// B-Trees represent a fundamental compromise between cache-efficiency and actually minimizing
33/// the amount of work performed in a search. See [`BTreeSet`] for more details.
34///
35/// Unlike a standard `BTreeSet`, there is an enforced upper limit to the number of items in the
36/// set. All internal operations ensure this bound is respected.
37#[cfg_attr(feature = "serde", derive(Serialize), serde(transparent))]
38#[derive(Encode, scale_info::TypeInfo)]
39#[scale_info(skip_type_params(S))]
40pub struct BoundedBTreeSet<T, S>(BTreeSet<T>, #[cfg_attr(feature = "serde", serde(skip_serializing))] PhantomData<S>);
41
#[cfg(feature = "serde")]
impl<'de, T, S: Get<u32>> Deserialize<'de> for BoundedBTreeSet<T, S>
where
	T: Ord + Deserialize<'de>,
	S: Clone,
{
	/// Deserializes a sequence into a bounded set.
	///
	/// Fails with an `"out of bounds"` error when the sequence announces, or actually yields, more
	/// elements than the bound `S` permits.
	fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
	where
		D: Deserializer<'de>,
	{
		// Visitor that gathers the sequence elements into a plain `BTreeSet`, giving up as soon as
		// the bound would be exceeded.
		struct BTreeSetVisitor<T, S>(PhantomData<(T, S)>);

		impl<'de, T, S> Visitor<'de> for BTreeSetVisitor<T, S>
		where
			T: Ord + Deserialize<'de>,
			S: Get<u32> + Clone,
		{
			type Value = BTreeSet<T>;

			fn expecting(&self, formatter: &mut core::fmt::Formatter) -> core::fmt::Result {
				formatter.write_str("a sequence")
			}

			fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
			where
				A: SeqAccess<'de>,
			{
				// A missing size hint is treated as zero; the per-element check below still
				// enforces the bound in that case.
				let hinted = seq.size_hint().unwrap_or(0);
				let limit = usize::try_from(S::get())
					.map_err(|_| A::Error::custom("can't convert to usize"))?;
				if hinted > limit {
					return Err(A::Error::custom("out of bounds"));
				}

				let mut collected = BTreeSet::new();
				while let Some(element) = seq.next_element()? {
					// Checked before inserting, so an element arriving once the set is already
					// full is rejected.
					if collected.len() >= limit {
						return Err(A::Error::custom("out of bounds"));
					}
					collected.insert(element);
				}

				Ok(collected)
			}
		}

		let inner = deserializer.deserialize_seq(BTreeSetVisitor::<T, S>(PhantomData))?;
		BoundedBTreeSet::<T, S>::try_from(inner).map_err(|_| D::Error::custom("out of bounds"))
	}
}
98
99impl<T, S> Decode for BoundedBTreeSet<T, S>
100where
101	T: Decode + Ord,
102	S: Get<u32>,
103{
104	fn decode<I: codec::Input>(input: &mut I) -> Result<Self, codec::Error> {
105		// Same as the underlying implementation for `Decode` on `BTreeSet`, except we fail early if
106		// the len is too big.
107		let len: u32 = <Compact<u32>>::decode(input)?.into();
108		if len > S::get() {
109			return Err("BoundedBTreeSet exceeds its limit".into())
110		}
111		input.descend_ref()?;
112		let inner = Result::from_iter((0..len).map(|_| Decode::decode(input)))?;
113		input.ascend_ref();
114		Ok(Self(inner, PhantomData))
115	}
116
117	fn skip<I: codec::Input>(input: &mut I) -> Result<(), codec::Error> {
118		BTreeSet::<T>::skip(input)
119	}
120}
121
122impl<T, S> BoundedBTreeSet<T, S>
123where
124	S: Get<u32>,
125{
126	/// Get the bound of the type in `usize`.
127	pub fn bound() -> usize {
128		S::get() as usize
129	}
130}
131
132impl<T, S> BoundedBTreeSet<T, S>
133where
134	T: Ord,
135	S: Get<u32>,
136{
137	/// Create `Self` from `t` without any checks.
138	fn unchecked_from(t: BTreeSet<T>) -> Self {
139		Self(t, Default::default())
140	}
141
142	/// Create a new `BoundedBTreeSet`.
143	///
144	/// Does not allocate.
145	pub fn new() -> Self {
146		BoundedBTreeSet(BTreeSet::new(), PhantomData)
147	}
148
149	/// Consume self, and return the inner `BTreeSet`.
150	///
151	/// This is useful when a mutating API of the inner type is desired, and closure-based mutation
152	/// such as provided by [`try_mutate`][Self::try_mutate] is inconvenient.
153	pub fn into_inner(self) -> BTreeSet<T> {
154		debug_assert!(self.0.len() <= Self::bound());
155		self.0
156	}
157
158	/// Consumes self and mutates self via the given `mutate` function.
159	///
160	/// If the outcome of mutation is within bounds, `Some(Self)` is returned. Else, `None` is
161	/// returned.
162	///
163	/// This is essentially a *consuming* shorthand [`Self::into_inner`] -> `...` ->
164	/// [`Self::try_from`].
165	pub fn try_mutate(mut self, mut mutate: impl FnMut(&mut BTreeSet<T>)) -> Option<Self> {
166		mutate(&mut self.0);
167		(self.0.len() <= Self::bound()).then(move || self)
168	}
169
170	/// Clears the set, removing all elements.
171	pub fn clear(&mut self) {
172		self.0.clear()
173	}
174
175	/// Exactly the same semantics as [`BTreeSet::insert`], but returns an `Err` (and is a noop) if
176	/// the new length of the set exceeds `S`.
177	///
178	/// In the `Err` case, returns the inserted item so it can be further used without cloning.
179	pub fn try_insert(&mut self, item: T) -> Result<bool, T> {
180		if self.len() < Self::bound() || self.0.contains(&item) {
181			Ok(self.0.insert(item))
182		} else {
183			Err(item)
184		}
185	}
186
187	/// Remove an item from the set, returning whether it was previously in the set.
188	///
189	/// The item may be any borrowed form of the set's item type, but the ordering on the borrowed
190	/// form _must_ match the ordering on the item type.
191	pub fn remove<Q>(&mut self, item: &Q) -> bool
192	where
193		T: Borrow<Q>,
194		Q: Ord + ?Sized,
195	{
196		self.0.remove(item)
197	}
198
199	/// Removes and returns the value in the set, if any, that is equal to the given one.
200	///
201	/// The value may be any borrowed form of the set's value type, but the ordering on the borrowed
202	/// form _must_ match the ordering on the value type.
203	pub fn take<Q>(&mut self, value: &Q) -> Option<T>
204	where
205		T: Borrow<Q> + Ord,
206		Q: Ord + ?Sized,
207	{
208		self.0.take(value)
209	}
210
211	/// Returns true if this set is full.
212	pub fn is_full(&self) -> bool {
213		self.len() >= Self::bound()
214	}
215}
216
217impl<T, S> Default for BoundedBTreeSet<T, S>
218where
219	T: Ord,
220	S: Get<u32>,
221{
222	fn default() -> Self {
223		Self::new()
224	}
225}
226
227impl<T, S> Clone for BoundedBTreeSet<T, S>
228where
229	BTreeSet<T>: Clone,
230{
231	fn clone(&self) -> Self {
232		BoundedBTreeSet(self.0.clone(), PhantomData)
233	}
234}
235
236impl<T, S> core::fmt::Debug for BoundedBTreeSet<T, S>
237where
238	BTreeSet<T>: core::fmt::Debug,
239	S: Get<u32>,
240{
241	fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
242		f.debug_tuple("BoundedBTreeSet").field(&self.0).field(&Self::bound()).finish()
243	}
244}
245
246// Custom implementation of `Hash` since deriving it would require all generic bounds to also
247// implement it.
248#[cfg(feature = "std")]
249impl<T: std::hash::Hash, S> std::hash::Hash for BoundedBTreeSet<T, S> {
250	fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
251		self.0.hash(state);
252	}
253}
254
255impl<T, S1, S2> PartialEq<BoundedBTreeSet<T, S1>> for BoundedBTreeSet<T, S2>
256where
257	BTreeSet<T>: PartialEq,
258	S1: Get<u32>,
259	S2: Get<u32>,
260{
261	fn eq(&self, other: &BoundedBTreeSet<T, S1>) -> bool {
262		S1::get() == S2::get() && self.0 == other.0
263	}
264}
265
266impl<T, S> Eq for BoundedBTreeSet<T, S>
267where
268	BTreeSet<T>: Eq,
269	S: Get<u32>,
270{
271}
272
273impl<T, S> PartialEq<BTreeSet<T>> for BoundedBTreeSet<T, S>
274where
275	BTreeSet<T>: PartialEq,
276	S: Get<u32>,
277{
278	fn eq(&self, other: &BTreeSet<T>) -> bool {
279		self.0 == *other
280	}
281}
282
283impl<T, S> PartialOrd for BoundedBTreeSet<T, S>
284where
285	BTreeSet<T>: PartialOrd,
286	S: Get<u32>,
287{
288	fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
289		self.0.partial_cmp(&other.0)
290	}
291}
292
293impl<T, S> Ord for BoundedBTreeSet<T, S>
294where
295	BTreeSet<T>: Ord,
296	S: Get<u32>,
297{
298	fn cmp(&self, other: &Self) -> core::cmp::Ordering {
299		self.0.cmp(&other.0)
300	}
301}
302
303impl<T, S> IntoIterator for BoundedBTreeSet<T, S> {
304	type Item = T;
305	type IntoIter = alloc::collections::btree_set::IntoIter<T>;
306
307	fn into_iter(self) -> Self::IntoIter {
308		self.0.into_iter()
309	}
310}
311
312impl<'a, T, S> IntoIterator for &'a BoundedBTreeSet<T, S> {
313	type Item = &'a T;
314	type IntoIter = alloc::collections::btree_set::Iter<'a, T>;
315
316	fn into_iter(self) -> Self::IntoIter {
317		self.0.iter()
318	}
319}
320
321impl<T, S> MaxEncodedLen for BoundedBTreeSet<T, S>
322where
323	T: MaxEncodedLen,
324	S: Get<u32>,
325{
326	fn max_encoded_len() -> usize {
327		Self::bound()
328			.saturating_mul(T::max_encoded_len())
329			.saturating_add(codec::Compact(S::get()).encoded_size())
330	}
331}
332
333impl<T, S> Deref for BoundedBTreeSet<T, S>
334where
335	T: Ord,
336{
337	type Target = BTreeSet<T>;
338
339	fn deref(&self) -> &Self::Target {
340		&self.0
341	}
342}
343
344impl<T, S> AsRef<BTreeSet<T>> for BoundedBTreeSet<T, S>
345where
346	T: Ord,
347{
348	fn as_ref(&self) -> &BTreeSet<T> {
349		&self.0
350	}
351}
352
353impl<T, S> From<BoundedBTreeSet<T, S>> for BTreeSet<T>
354where
355	T: Ord,
356{
357	fn from(set: BoundedBTreeSet<T, S>) -> Self {
358		set.0
359	}
360}
361
362impl<T, S> TryFrom<BTreeSet<T>> for BoundedBTreeSet<T, S>
363where
364	T: Ord,
365	S: Get<u32>,
366{
367	type Error = ();
368
369	fn try_from(value: BTreeSet<T>) -> Result<Self, Self::Error> {
370		(value.len() <= Self::bound())
371			.then(move || BoundedBTreeSet(value, PhantomData))
372			.ok_or(())
373	}
374}
375
376impl<T, S> codec::DecodeLength for BoundedBTreeSet<T, S> {
377	fn len(self_encoded: &[u8]) -> Result<usize, codec::Error> {
378		// `BoundedBTreeSet<T, S>` is stored just a `BTreeSet<T>`, which is stored as a
379		// `Compact<u32>` with its length followed by an iteration of its items. We can just use
380		// the underlying implementation.
381		<BTreeSet<T> as codec::DecodeLength>::len(self_encoded)
382	}
383}
384
385impl<T, S> codec::EncodeLike<BTreeSet<T>> for BoundedBTreeSet<T, S> where BTreeSet<T>: Encode {}
386
387impl<I, T, Bound> TryCollect<BoundedBTreeSet<T, Bound>> for I
388where
389	T: Ord,
390	I: ExactSizeIterator + Iterator<Item = T>,
391	Bound: Get<u32>,
392{
393	type Error = &'static str;
394
395	fn try_collect(self) -> Result<BoundedBTreeSet<T, Bound>, Self::Error> {
396		if self.len() > Bound::get() as usize {
397			Err("iterator length too big")
398		} else {
399			Ok(BoundedBTreeSet::<T, Bound>::unchecked_from(self.collect::<BTreeSet<T>>()))
400		}
401	}
402}
403
#[cfg(test)]
mod test {
	use super::*;
	use crate::ConstU32;
	use alloc::{vec, vec::Vec};
	use codec::CompactLen;

	/// Builds a plain `BTreeSet` from a slice of keys.
	fn set_from_keys<T>(keys: &[T]) -> BTreeSet<T>
	where
		T: Ord + Copy,
	{
		keys.iter().copied().collect()
	}

	/// Builds a `BoundedBTreeSet` from a slice of keys, panicking if they do not fit the bound.
	fn boundedset_from_keys<T, S>(keys: &[T]) -> BoundedBTreeSet<T, S>
	where
		T: Ord + Copy,
		S: Get<u32>,
	{
		set_from_keys(keys).try_into().unwrap()
	}

	#[test]
	fn encoding_same_as_unbounded_set() {
		let b = boundedset_from_keys::<u32, ConstU32<7>>(&[1, 2, 3, 4, 5, 6]);
		let m = set_from_keys(&[1, 2, 3, 4, 5, 6]);

		assert_eq!(b.encode(), m.encode());
	}

	#[test]
	fn try_insert_works() {
		let mut bounded = boundedset_from_keys::<u32, ConstU32<4>>(&[1, 2, 3]);
		bounded.try_insert(0).unwrap();
		assert_eq!(*bounded, set_from_keys(&[1, 0, 2, 3]));

		assert!(bounded.try_insert(9).is_err());
		assert_eq!(*bounded, set_from_keys(&[1, 0, 2, 3]));
	}

	#[test]
	fn deref_coercion_works() {
		let bounded = boundedset_from_keys::<u32, ConstU32<7>>(&[1, 2, 3]);
		// these methods come from the deref-ed `BTreeSet`.
		assert_eq!(bounded.len(), 3);
		assert!(bounded.iter().next().is_some());
		assert!(!bounded.is_empty());
	}

	#[test]
	fn try_mutate_works() {
		let bounded = boundedset_from_keys::<u32, ConstU32<7>>(&[1, 2, 3, 4, 5, 6]);
		let bounded = bounded
			.try_mutate(|v| {
				v.insert(7);
			})
			.unwrap();
		assert_eq!(bounded.len(), 7);
		assert!(bounded
			.try_mutate(|v| {
				v.insert(8);
			})
			.is_none());
	}

	#[test]
	fn btree_set_eq_works() {
		let bounded = boundedset_from_keys::<u32, ConstU32<7>>(&[1, 2, 3, 4, 5, 6]);
		assert_eq!(bounded, set_from_keys(&[1, 2, 3, 4, 5, 6]));
	}

	#[test]
	fn too_big_fail_to_decode() {
		let v: Vec<u32> = vec![1, 2, 3, 4, 5];
		assert_eq!(
			BoundedBTreeSet::<u32, ConstU32<4>>::decode(&mut &v.encode()[..]),
			Err("BoundedBTreeSet exceeds its limit".into()),
		);
	}

	#[test]
	fn dont_consume_more_data_than_bounded_len() {
		let s = set_from_keys(&[1, 2, 3, 4, 5, 6]);
		let data = s.encode();
		let data_input = &mut &data[..];

		BoundedBTreeSet::<u32, ConstU32<4>>::decode(data_input).unwrap_err();
		// Only the compact length prefix, which encodes the number of items, may have been
		// consumed before the decoder bailed out.
		assert_eq!(data_input.len(), data.len() - Compact::<u32>::compact_len(&(s.len() as u32)));
	}

	#[test]
	fn unequal_eq_impl_insert_works() {
		// given a struct with a strange notion of equality
		#[derive(Debug)]
		struct Unequal(u32, bool);

		impl PartialEq for Unequal {
			fn eq(&self, other: &Self) -> bool {
				self.0 == other.0
			}
		}
		impl Eq for Unequal {}

		impl Ord for Unequal {
			fn cmp(&self, other: &Self) -> core::cmp::Ordering {
				self.0.cmp(&other.0)
			}
		}

		impl PartialOrd for Unequal {
			fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
				Some(self.cmp(other))
			}
		}

		let mut set = BoundedBTreeSet::<Unequal, ConstU32<4>>::new();

		// when the set is full

		for i in 0..4 {
			set.try_insert(Unequal(i, false)).unwrap();
		}

		// can't insert a new distinct member
		set.try_insert(Unequal(5, false)).unwrap_err();

		// but _can_ insert a distinct member which compares equal, though per the documentation,
		// neither the set length nor the actual member are changed
		set.try_insert(Unequal(0, true)).unwrap();
		assert_eq!(set.len(), 4);
		let zero_item = set.get(&Unequal(0, true)).unwrap();
		assert_eq!(zero_item.0, 0);
		assert!(!zero_item.1);
	}

	#[test]
	fn eq_works() {
		// of same type
		let b1 = boundedset_from_keys::<u32, ConstU32<7>>(&[1, 2]);
		let b2 = boundedset_from_keys::<u32, ConstU32<7>>(&[1, 2]);
		assert_eq!(b1, b2);

		// of different type, but same value and bound.
		crate::parameter_types! {
			B1: u32 = 7;
			B2: u32 = 7;
		}
		let b1 = boundedset_from_keys::<u32, B1>(&[1, 2]);
		let b2 = boundedset_from_keys::<u32, B2>(&[1, 2]);
		assert_eq!(b1, b2);
	}

	#[test]
	fn can_be_collected() {
		let b1 = boundedset_from_keys::<u32, ConstU32<5>>(&[1, 2, 3, 4]);
		let b2: BoundedBTreeSet<u32, ConstU32<5>> = b1.iter().map(|k| k + 1).try_collect().unwrap();
		assert_eq!(b2.into_iter().collect::<Vec<_>>(), vec![2, 3, 4, 5]);

		// can also be collected into a collection of length 4.
		let b2: BoundedBTreeSet<u32, ConstU32<4>> = b1.iter().map(|k| k + 1).try_collect().unwrap();
		assert_eq!(b2.into_iter().collect::<Vec<_>>(), vec![2, 3, 4, 5]);

		// can be mutated further into iterators that are `ExactSizeIterator`.
		let b2: BoundedBTreeSet<u32, ConstU32<5>> = b1.iter().map(|k| k + 1).rev().skip(2).try_collect().unwrap();
		// note that the binary tree will re-sort this, so rev() is not really seen
		assert_eq!(b2.into_iter().collect::<Vec<_>>(), vec![2, 3]);

		let b2: BoundedBTreeSet<u32, ConstU32<5>> = b1.iter().map(|k| k + 1).take(2).try_collect().unwrap();
		assert_eq!(b2.into_iter().collect::<Vec<_>>(), vec![2, 3]);

		// but these won't work
		let b2: Result<BoundedBTreeSet<u32, ConstU32<3>>, _> = b1.iter().map(|k| k + 1).try_collect();
		assert!(b2.is_err());

		let b2: Result<BoundedBTreeSet<u32, ConstU32<1>>, _> = b1.iter().map(|k| k + 1).skip(2).try_collect();
		assert!(b2.is_err());
	}

	// Just a test that structs containing `BoundedBTreeSet` can derive `Hash`. (This was broken
	// when it was deriving `Hash`).
	#[test]
	#[cfg(feature = "std")]
	fn container_can_derive_hash() {
		#[derive(Hash, Default)]
		struct Foo {
			bar: u8,
			set: BoundedBTreeSet<String, ConstU32<16>>,
		}
		let _foo = Foo::default();
	}

	#[test]
	fn is_full_works() {
		let mut bounded = boundedset_from_keys::<u32, ConstU32<4>>(&[1, 2, 3]);
		assert!(!bounded.is_full());
		bounded.try_insert(0).unwrap();
		assert_eq!(*bounded, set_from_keys(&[1, 0, 2, 3]));

		assert!(bounded.is_full());
		assert!(bounded.try_insert(9).is_err());
		assert_eq!(*bounded, set_from_keys(&[1, 0, 2, 3]));
	}

	#[cfg(feature = "serde")]
	mod serde {
		use super::*;
		use crate::alloc::string::ToString as _;

		#[test]
		fn test_serializer() {
			let mut c = BoundedBTreeSet::<u32, ConstU32<6>>::new();
			c.try_insert(0).unwrap();
			c.try_insert(1).unwrap();
			c.try_insert(2).unwrap();

			assert_eq!(serde_json::json!(&c).to_string(), r#"[0,1,2]"#);
		}

		#[test]
		fn test_deserializer() {
			let c: Result<BoundedBTreeSet<u32, ConstU32<6>>, serde_json::error::Error> =
				serde_json::from_str(r#"[0,1,2]"#);
			assert!(c.is_ok());
			let c = c.unwrap();

			assert_eq!(c.len(), 3);
			assert!(c.contains(&0));
			assert!(c.contains(&1));
			assert!(c.contains(&2));
		}

		// A sequence holding exactly `bound` elements must still be accepted.
		#[test]
		fn test_deserializer_bound() {
			let c: Result<BoundedBTreeSet<u32, ConstU32<3>>, serde_json::error::Error> =
				serde_json::from_str(r#"[0,1,2]"#);
			assert!(c.is_ok());
			let c = c.unwrap();

			assert_eq!(c.len(), 3);
			assert!(c.contains(&0));
			assert!(c.contains(&1));
			assert!(c.contains(&2));
		}

		#[test]
		fn test_deserializer_failed() {
			let c: Result<BoundedBTreeSet<u32, ConstU32<4>>, serde_json::error::Error> =
				serde_json::from_str(r#"[0,1,2,3,4]"#);

			match c {
				Err(msg) => assert_eq!(msg.to_string(), "out of bounds at line 1 column 11"),
				_ => unreachable!("deserializer must raise error"),
			}
		}
	}
}