Skip to main content

mask_tracked_array/
lib.rs

1#![no_std]
2#![warn(missing_docs)]
//! A `no_std`, allocation-free abstraction for some slot-array data structures
//! on microcontrollers. A [`MaskTrackedArray`] is an array of
//! [`UnsafeCell<MaybeUninit<T>>`] slots paired with an integer bit mask that
//! tracks which slots are filled and which aren't. The arrays are sized at
//! compile time through generics, so each variant's capacity is determined by
//! the integer type used for its mask.
8//!
9//! You can think of a [`MaskTrackedArray`] as an array of numbered slots which
10//! can be accessed independently (unsafe code may be required).
11//!
12//! The current implementations supplied by this crate are
13//! - [`MaskTrackedArrayU8`]
14//! - [`MaskTrackedArrayU16`]
15//! - [`MaskTrackedArrayU32`]
16//! - [`MaskTrackedArrayU64`]
17//! - [`MaskTrackedArrayU128`]
18//!
19//! See the documentation on [`MaskTrackedArray`] to see what methods are
20//! available.
21use core::{cell::UnsafeCell, mem::MaybeUninit};
22
23use bit_iter::BitIter;
24
25mod mask_trait;
26#[cfg(feature = "serde")]
27#[doc(hidden)]
28pub mod serde_impl;
29pub use mask_trait::Mask;
30
31/// Implemented by every variant of the mask tracked array. The
32/// [`MaskTrackedArray::MaskType`] is the number type used for the mask.
33pub trait MaskTrackedArray<T>: Default + FromIterator<T> + FromIterator<(usize, T)> {
34    /// The number type used as the mask.
35    type MaskType: Mask;
36    /// Check if there is an item at a slot. This function will also return
37    /// false if the index is out of range.
38    fn contains_item_at(&self, index: usize) -> bool;
39    /// Check how many slots are occupied.
40    fn len(&self) -> u32;
41    /// Returns true if this array is completely empty.
42    fn is_empty(&self) -> bool {
43        self.len() == 0
44    }
45    /// Construct a new empty instance of this array.
46    #[must_use]
47    fn new() -> Self {
48        Self::default()
49    }
50    /// Clear out all items in this array, returning to its empty state. Drop
51    /// is called if needed.
52    fn clear(&mut self);
53    /// Get a reference to an item inside a slot if available.
54    fn get_ref(&self, index: usize) -> Option<&T> {
55        if self.contains_item_at(index) {
56            Some(unsafe { self.get_unchecked_ref(index) })
57        } else {
58            None
59        }
60    }
61    /// Get a mutable reference to an item inside a slot if available.
62    fn get_mut(&mut self, index: usize) -> Option<&mut T> {
63        if self.contains_item_at(index) {
64            Some(unsafe { self.get_unchecked_mut(index) })
65        } else {
66            None
67        }
68    }
69    /// Get an immutable reference to slot without bounds or validity checking.
70    /// # Safety
71    /// The given index must be valid.
72    unsafe fn get_unchecked_ref(&self, index: usize) -> &T;
73    /// Get a mutable reference to a slot without bounds, validity or borrow
74    /// checking.
75    /// # Safety
76    /// The given index must be valid and there must be no other references
77    /// to the same slot.
78    #[allow(clippy::mut_from_ref)]
79    unsafe fn get_unchecked_mut(&self, index: usize) -> &mut T;
80    /// Insert an item at a given index without bounds, validity or borrow
81    /// checking.
82    /// # Safety
83    /// The given index must be valid to avoid undefined behaviour. If a value
84    /// was already present, that value will be forgotten without running its
85    /// drop implementation.
86    unsafe fn insert_unchecked(&self, index: usize, value: T);
87    /// Try to insert an item at a given index. If insertion fails, return
88    /// the value in an option. If the option is none then the insertion
89    /// succeeded.
90    #[must_use]
91    fn insert(&self, index: usize, value: T) -> Option<T> {
92        if self.contains_item_at(index) || index >= Self::MaskType::MAX_SELECTIONS as usize {
93            Some(value)
94        } else {
95            unsafe { self.insert_unchecked(index, value) };
96            None
97        }
98    }
99    /// Remove a value from a slot without checking the index or if the item is
100    /// there.
101    /// # Safety
102    /// Calling this function on currently referenced or nonexistent slots will
103    /// result in undefined behaviour.
104    unsafe fn remove_unchecked(&self, index: usize) -> T;
105    /// Remove a value at a specific index and return it if available.
106    fn remove(&mut self, index: usize) -> Option<T> {
107        if self.contains_item_at(index) {
108            Some(unsafe { self.remove_unchecked(index) })
109        } else {
110            None
111        }
112    }
113    /// Get an iterator over all filled slot indices.
114    fn iter_filled_indices(&self) -> impl Iterator<Item = usize>;
115    /// Get an iterator over all filled slot indices also present in the given
116    /// mask.
117    fn iter_filled_indices_mask(&self, mask: Self::MaskType) -> impl Iterator<Item = usize>;
118    /// Get an iterator over all unfilled slot indices.
119    fn iter_empty_indices(&self) -> impl Iterator<Item = usize>;
120    /// Iterate over references to every filled slot.
121    fn iter<'a>(&'a self) -> impl Iterator<Item = &'a T>
122    where
123        T: 'a,
124    {
125        self.iter_filled_indices()
126            .map(|index| unsafe { self.get_unchecked_ref(index) })
127    }
128    /// Iterate over mutable references to every filled slot.
129    fn iter_mut<'a>(&'a mut self) -> impl Iterator<Item = &'a mut T>
130    where
131        T: 'a,
132    {
133        self.iter_filled_indices()
134            .map(|index| unsafe { self.get_unchecked_mut(index) })
135    }
136    /// Iterate over references which are only present in the given mask.
137    fn iter_mask<'a>(&'a self, mask: Self::MaskType) -> impl Iterator<Item = &'a T>
138    where
139        T: 'a,
140    {
141        self.iter_filled_indices_mask(mask)
142            .map(|index| unsafe { self.get_unchecked_ref(index) })
143    }
144    /// Iterate over mutable references which are only present in the given
145    /// mask.
146    fn iter_mut_mask<'a>(&'a mut self, mask: Self::MaskType) -> impl Iterator<Item = &'a mut T>
147    where
148        T: 'a,
149    {
150        self.iter_filled_indices_mask(mask)
151            .map(|index| unsafe { self.get_unchecked_mut(index) })
152    }
153    /// Get the internal mask used.
154    fn mask(&self) -> Self::MaskType;
155    /// Try to push a value into the lowest indexed position possible. Returns
156    /// the value if failed.
157    fn push(&mut self, value: T) -> Result<usize, T> {
158        if let Some(smallest) = self.iter_empty_indices().next() {
159            let None = self.insert(smallest, value) else {
160                unreachable!()
161            };
162            Ok(smallest)
163        } else {
164            Err(value)
165        }
166    }
167}
168
/// An array with slots for values backed by a mask for tracking which slots
/// are under use. Each slot has its own [`UnsafeCell`].
///
/// Bit `i` of `mask` is set exactly when `array[i]` holds an initialised
/// value; slots whose bit is clear are treated as uninitialised and are never
/// read or dropped.
pub struct MaskTrackedArrayBase<T, M, const N: usize>
where
    // The `Drop` impl needs this bound to call `MaskTrackedArray::clear`, and a
    // `Drop` impl's bounds must match the bounds declared on the type itself.
    Self: MaskTrackedArray<T, MaskType = M>,
{
    /// Mask used for tracking which slots in the array are filled.
    mask: core::cell::Cell<M>,
    /// An array of fillable slots.
    array: [UnsafeCell<MaybeUninit<T>>; N],
}
180
/// An owning iterator for [`MaskTrackedArray`] variants, produced by
/// `into_iter`.
///
/// Yields the stored values in ascending slot order, moving each one out of
/// the array.
pub struct MaskTrackedArrayIter<T, M, const N: usize>
where
    MaskTrackedArrayBase<T, M, N>: MaskTrackedArray<T, MaskType = M>,
{
    /// Indices of the filled slots that have not been yielded yet.
    bit_iterator: BitIter<M>,
    /// The array being drained; values not yet yielded are dropped with it.
    source: MaskTrackedArrayBase<T, M, N>,
}
189
190impl<T, M, const N: usize> Drop for MaskTrackedArrayBase<T, M, N>
191where
192    Self: MaskTrackedArray<T, MaskType = M>,
193{
194    fn drop(&mut self) {
195        self.clear();
196    }
197}
/// Generates the concrete [`MaskTrackedArray`] variants.
///
/// Every entry has the form `(mask_type, capacity, AliasName)`, and entries are
/// comma separated. For each entry the macro emits a public type alias, the
/// [`MaskTrackedArray`] implementation, the standard trait impls (`Default`,
/// iteration, collection, equality, hashing, `Debug`) and a unit-test module.
/// `capacity` must equal the bit width of `mask_type`, because empty-slot
/// iteration walks every bit of the mask.
macro_rules! mask_tracked_array_impl {
    () => {};
    (($num_ty:ty, $bits:expr, $alias_ident:ident), $($rest:tt)*) => {
        mask_tracked_array_impl!(($num_ty, $bits, $alias_ident));
        mask_tracked_array_impl!($($rest)*);
    };
    (($num_ty:ty, $bits:expr, $alias_ident:ident)) => {
        #[doc = stringify!(A $num_ty tracked array which can hold $bits items) ]
        pub type $alias_ident<T> = MaskTrackedArrayBase<T, $num_ty, $bits>;

        impl<T> MaskTrackedArray<T> for MaskTrackedArrayBase<T, $num_ty, $bits> {
            type MaskType = $num_ty;

            fn contains_item_at(&self, index: usize) -> bool {
                // Reject out-of-range indices before shifting, which would overflow.
                if index >= <$num_ty>::BITS as usize {
                    return false;
                }
                self.mask.get() & (1 << index) != 0
            }

            fn mask(&self) -> Self::MaskType {
                self.mask.get()
            }

            fn len(&self) -> u32 {
                self.mask.get().count_ones()
            }

            unsafe fn get_unchecked_ref(&self, index: usize) -> &T {
                // SAFETY: the caller guarantees `index` is in bounds, the slot is
                // initialised and no mutable reference to it is live.
                unsafe { (&*self.array.get_unchecked(index).get()).assume_init_ref() }
            }

            unsafe fn get_unchecked_mut(&self, index: usize) -> &mut T {
                // SAFETY: the caller guarantees `index` is in bounds, the slot is
                // initialised and no other reference to it exists.
                unsafe { (&mut *self.array.get_unchecked(index).get()).assume_init_mut() }
            }

            fn clear(&mut self) {
                if core::mem::needs_drop::<T>() {
                    for index in BitIter::from(self.mask.get()) {
                        // Unmark the slot *before* running its destructor. If the
                        // destructor panics, the slot is already considered empty,
                        // so a later `clear` (e.g. from `Drop` while unwinding)
                        // cannot drop the same value a second time.
                        self.mask.update(|v| v & !(1 << index));
                        // SAFETY: `index` came from the mask, so it is in bounds and
                        // the slot is initialised; `&mut self` rules out borrows.
                        unsafe {
                            self.array
                                .get_unchecked_mut(index)
                                .get_mut()
                                .assume_init_drop()
                        };
                    }
                }
                self.mask.set(0);
            }

            unsafe fn insert_unchecked(&self, index: usize, value: T) {
                // SAFETY: the caller guarantees `index` is in bounds and that no
                // reference to this slot is live.
                unsafe { (&mut *self.array.get_unchecked(index).get()).write(value) };
                self.mask.update(|v| v | (1 << index));
            }

            unsafe fn remove_unchecked(&self, index: usize) -> T {
                // Mark the slot empty first; its bytes are treated as uninitialised
                // from now on, so the value read below is never dropped twice.
                self.mask.update(|v| v & !(1 << index));
                // SAFETY: the caller guarantees `index` is in bounds, the slot is
                // initialised and no reference to it is live.
                unsafe { (&*self.array.get_unchecked(index).get()).assume_init_read() }
            }

            #[inline]
            fn iter_filled_indices(&self) -> impl Iterator<Item = usize> {
                BitIter::from(self.mask.get())
            }

            #[inline]
            fn iter_filled_indices_mask(&self, mask: Self::MaskType) -> impl Iterator<Item = usize> {
                BitIter::from(self.mask.get() & mask)
            }

            #[inline]
            fn iter_empty_indices(&self) -> impl Iterator<Item = usize> {
                BitIter::from(!self.mask.get())
            }
        }

        impl<T> Default for MaskTrackedArrayBase<T, $num_ty, $bits> {
            /// Creates an array with every slot empty.
            fn default() -> Self {
                Self {
                    mask: core::cell::Cell::new(0),
                    array: [const { core::cell::UnsafeCell::new(core::mem::MaybeUninit::uninit()) }; $bits],
                }
            }
        }

        impl<T> core::iter::Iterator for MaskTrackedArrayIter<T, $num_ty, $bits> {
            type Item = T;

            fn next(&mut self) -> Option<Self::Item> {
                let next_index = self.bit_iterator.next()?;
                // SAFETY: `bit_iterator` started as a copy of `source`'s mask and
                // every yielded index is removed from `source` right here, so the
                // slot is still filled and nothing else can reference it.
                Some(unsafe { self.source.remove_unchecked(next_index) })
            }

            fn size_hint(&self) -> (usize, Option<usize>) {
                // Each index still pending in `bit_iterator` is exactly one filled
                // slot of `source`, so the remaining count is known precisely.
                let remaining = self.source.len() as usize;
                (remaining, Some(remaining))
            }
        }

        impl<T> core::iter::ExactSizeIterator for MaskTrackedArrayIter<T, $num_ty, $bits> {}

        impl<T> core::iter::IntoIterator for MaskTrackedArrayBase<T, $num_ty, $bits> {
            type Item = T;
            type IntoIter = MaskTrackedArrayIter<T, $num_ty, $bits>;

            /// Moves the stored values out in ascending slot order.
            fn into_iter(self) -> Self::IntoIter {
                let bit_iterator = BitIter::from(self.mask.get());
                MaskTrackedArrayIter {
                    source: self,
                    bit_iterator,
                }
            }
        }

        impl<T> core::iter::FromIterator<T> for MaskTrackedArrayBase<T, $num_ty, $bits> {
            /// Stores the items in consecutive slots starting at index 0. Items
            /// beyond the capacity are left unconsumed in the source iterator.
            fn from_iter<I>(iter: I) -> Self
            where
                I: IntoIterator<Item = T>,
            {
                let array = Self::new();
                for (index, value) in iter.into_iter().take($bits).enumerate() {
                    // SAFETY: `take` keeps `index` below the capacity, and each
                    // index is written exactly once into a fresh, unborrowed slot.
                    unsafe { array.insert_unchecked(index, value) };
                }
                array
            }
        }

        impl<T> core::iter::FromIterator<(usize, T)> for MaskTrackedArrayBase<T, $num_ty, $bits> {
            /// Stores each value at its paired index. Pairs whose index is out of
            /// range or already filled are ignored and their value is dropped.
            fn from_iter<I>(iter: I) -> Self
            where
                I: IntoIterator<Item = (usize, T)>,
            {
                let array = Self::new();
                for (index, value) in iter {
                    let _ = array.insert(index, value);
                }
                array
            }
        }

        impl<T: PartialEq> PartialEq for MaskTrackedArrayBase<T, $num_ty, $bits> {
            /// Arrays are equal when the same slots are filled and the values in
            /// those slots compare equal.
            fn eq(&self, other: &Self) -> bool {
                if self.mask != other.mask {
                    return false;
                }
                self.iter().zip(other.iter()).all(|(left, right)| left.eq(right))
            }
        }

        impl<T: Eq> Eq for MaskTrackedArrayBase<T, $num_ty, $bits> {}

        impl<T: core::hash::Hash> core::hash::Hash for MaskTrackedArrayBase<T, $num_ty, $bits> {
            /// Hashes the occupancy mask followed by the stored values, matching
            /// the `PartialEq` implementation.
            fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
                self.mask.get().hash(state);
                self.iter().for_each(|v| v.hash(state));
            }
        }

        impl<T: core::fmt::Debug> core::fmt::Debug for MaskTrackedArrayBase<T, $num_ty, $bits> {
            /// Formats the stored values as a list in ascending slot order; slot
            /// indices are not shown.
            fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> Result<(), core::fmt::Error> {
                f.debug_list().entries(self.iter()).finish()
            }
        }

        paste::paste! {
            #[cfg(test)]
            mod [<$num_ty _tests>] {
                use super::*;
                extern crate std;
                #[test]
                fn from_iterator_and_back() {
                    let mask = $alias_ident::from_iter(0..$bits);
                    for (index, number) in mask.into_iter().enumerate() {
                        assert_eq!(index, number);
                    }
                }
                #[test]
                fn from_too_big_iterator() {
                    let numbers = [0; $bits + 1];
                    let mask = $alias_ident::from_iter(numbers);
                    assert_eq!(mask.len(), $bits);
                }
                #[test]
                fn from_iterator_stops_at_capacity() {
                    let mut source = 0..($bits + 5);
                    let array = $alias_ident::from_iter(source.by_ref());
                    assert_eq!(array.len(), $bits);
                    // The first surplus item must not have been consumed.
                    assert_eq!(source.next(), Some($bits));
                }
                #[test]
                fn from_empty_iterator() {
                    let numbers: [u8; 0] = [];
                    let mask = $alias_ident::from_iter(numbers);
                    assert_eq!(mask.len(), 0);
                }
                #[test]
                fn hash_equality() {
                    let mask = $alias_ident::new();
                    assert!(mask.insert(0, 0).is_none());
                    assert!(mask.insert(1, 1).is_none());
                    let mask_2 = $alias_ident::new();
                    assert!(mask_2.insert(1, 1).is_none());
                    assert!(mask_2.insert(0, 0).is_none());
                    assert_eq!(mask, mask_2);
                    use std::hash::{ Hash, DefaultHasher, Hasher };
                    let mut hasher = DefaultHasher::new();
                    mask.hash(&mut hasher);
                    let first_hash = hasher.finish();
                    let mut hasher = DefaultHasher::new();
                    mask_2.hash(&mut hasher);
                    let second_hash = hasher.finish();
                    assert_eq!(first_hash, second_hash);
                }
                #[test]
                fn equality() {
                    let first = $alias_ident::from_iter([1, 2]);
                    let second = $alias_ident::from_iter([1]);
                    assert_ne!(first, second);
                }
                #[test]
                fn removal() {
                    let mut array = $alias_ident::from_iter([1, 2, 3]);
                    assert_eq!(Some(1), array.remove(0));
                    assert_eq!(Some(2), array.remove(1));
                    assert_eq!(Some(3), array.remove(2));
                    assert_eq!(None, array.remove(0))
                }
                #[test]
                fn failing_getters() {
                    let mut array = $alias_ident::from_iter([1, 2, 3, 4]);
                    assert_eq!(None, array.get_ref(5));
                    assert_eq!(None, array.get_ref(1000));
                    assert_eq!(None, array.get_mut(5));
                    assert_eq!(None, array.get_mut(1000));
                }
                #[test]
                fn succeeding_getters() {
                    let mut array = $alias_ident::from_iter([1, 2, 3, 4]);
                    assert_eq!(Some(&1), array.get_ref(0));
                    assert_eq!(Some(&mut 2), array.get_mut(1));
                }
                #[test]
                fn clearing() {
                    let mut array = $alias_ident::from_iter([1, 2, 3, 4]);
                    array.clear();
                    assert_eq!(array, $alias_ident::new());
                    assert_eq!(array.len(), 0);
                }
                #[test]
                fn clearing_with_drop() {
                    use std::rc::Rc;
                    let rc1 = Rc::new(1);
                    let rc2 = Rc::new(2);
                    let mut array = $alias_ident::from_iter([rc1.clone(), rc2.clone()]);
                    assert_eq!(Rc::strong_count(&rc1), 2);
                    assert_eq!(Rc::strong_count(&rc2), 2);
                    array.clear();
                    assert_eq!(Rc::strong_count(&rc1), 1);
                    assert_eq!(Rc::strong_count(&rc2), 1);
                }
                #[test]
                fn clearing_with_panicking_drop_does_not_double_drop() {
                    use std::cell::Cell;
                    use std::rc::Rc;
                    struct Bomb {
                        drops: Rc<Cell<usize>>,
                        explode: bool,
                    }
                    impl Drop for Bomb {
                        fn drop(&mut self) {
                            self.drops.set(self.drops.get() + 1);
                            if self.explode {
                                panic!("boom");
                            }
                        }
                    }
                    let drops = Rc::new(Cell::new(0));
                    let mut array = $alias_ident::from_iter([
                        Bomb { drops: drops.clone(), explode: false },
                        Bomb { drops: drops.clone(), explode: true },
                        Bomb { drops: drops.clone(), explode: false },
                    ]);
                    let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
                        array.clear()
                    }));
                    assert!(result.is_err());
                    // Only the slot after the panicking one is still live.
                    assert_eq!(array.len(), 1);
                    drop(array);
                    // Every value was dropped exactly once.
                    assert_eq!(drops.get(), 3);
                }
                #[test]
                fn into_iter_reports_exact_length() {
                    let mut iter = $alias_ident::from_iter([1, 2, 3]).into_iter();
                    assert_eq!(iter.len(), 3);
                    assert_eq!(iter.next(), Some(1));
                    assert_eq!(iter.size_hint(), (2, Some(2)));
                }
                #[test]
                fn empty_indices_iterator() {
                    let array = $alias_ident::from_iter([0, 1]);
                    assert!(array.iter_empty_indices().all(|v| v != 0 && v != 1))
                }
                #[test]
                fn mutable_ref_iterator() {
                    let mut array = $alias_ident::from_iter([0, 1]);
                    array.iter_mut().for_each(|v| *v += 1);
                    let new_version = $alias_ident::from_iter([1, 2]);
                    assert_eq!(array, new_version);
                }
                #[test]
                fn insertion() {
                    let array = $alias_ident::from_iter([0, 1]);
                    assert_eq!(None, array.insert(2, 2));
                    let new_array = $alias_ident::from_iter([0, 1, 2]);
                    assert_eq!(array, new_array);
                }
                #[test]
                fn debug_print_no_ub() {
                    let array = $alias_ident::from_iter([0, 1]);
                    let formatted_string = std::format!("{:?}", array);
                    assert!(formatted_string.is_ascii());
                }
                #[test]
                fn emptiness() {
                    let array: $alias_ident<u8> = $alias_ident::new();
                    assert!(array.is_empty());
                    assert_eq!(0, array.len());
                }
                #[test]
                fn pushing() {
                    let mut array: $alias_ident<u8> = $alias_ident::new();
                    assert_eq!(Ok(0), array.push(1));
                    assert_eq!(Ok(1), array.push(2));
                    assert_eq!(Ok(2), array.push(3));
                    assert_eq!(Some(&1), array.get_ref(0));
                    assert_eq!(Some(&2), array.get_ref(1));
                    assert_eq!(Some(&3), array.get_ref(2));
                }
                #[test]
                fn pushing_maxed_out() {
                    let mut full_array = $alias_ident::from_iter([0u8; $bits]);
                    assert_eq!(Err(1), full_array.push(1));
                    assert!(full_array.iter().all(|v| *v == 0));
                }
                #[test]
                fn successful_insertions() {
                    let array = $alias_ident::new();
                    assert_eq!(None, array.insert(0, 1));
                    assert_eq!(None, array.insert(1, 1));
                    assert!(array.contains_item_at(0));
                    assert_eq!(Some(1), array.insert(0, 1));
                    assert_eq!(0b11, array.mask());
                }
                #[test]
                fn masked_iteration() {
                    let mut array = $alias_ident::from_iter([true; $bits]);
                    assert!(array.iter_mask($num_ty::ALL_SELECTED).all(|b| *b));
                    assert!(array.iter_mut_mask($num_ty::ALL_SELECTED).all(|b| {*b = false; true}));
                }
                #[test]
                fn from_iter_init() {
                    let mut array: $alias_ident<u8> = $alias_ident::from_iter([(1, 10)]);
                    assert_eq!($num_ty::index_to_mask(1), array.mask());
                    assert_eq!(10, *array.get_mut(1).unwrap());
                }
            }
        }
    };
}
499
500mask_tracked_array_impl!(
501    (u8, 8, MaskTrackedArrayU8),
502    (u16, 16, MaskTrackedArrayU16),
503    (u32, 32, MaskTrackedArrayU32),
504    (u64, 64, MaskTrackedArrayU64),
505    (u128, 128, MaskTrackedArrayU128)
506);