Skip to main content

mask_tracked_array/
lib.rs

#![no_std]
#![warn(missing_docs)]
//! A `no_std`, no-`alloc` abstraction for fixed-capacity slot storage on
//! microcontrollers. Each [`MaskTrackedArray`] implementation is an array of
//! [`UnsafeCell<MaybeUninit<T>>`] slots paired with an integer bit mask that
//! records which slots are filled and which aren't. The slot storage lives
//! inline in the value (no heap allocation). Its length is fixed at compile
//! time by a const generic, and the capacity equals the number of bits in the
//! mask integer type, so each mask type gives a different size.
//!
//! You can think of a [`MaskTrackedArray`] as an array of numbered slots which
//! can be accessed independently (unsafe code may be required).
//!
//! The current implementations supplied by this crate are
//! - [`MaskTrackedArrayU8`]
//! - [`MaskTrackedArrayU16`]
//! - [`MaskTrackedArrayU32`]
//! - [`MaskTrackedArrayU64`]
//! - [`MaskTrackedArrayU128`]
//!
//! See the documentation on [`MaskTrackedArray`] to see what methods are
//! available.
21use core::{cell::UnsafeCell, mem::MaybeUninit};
22
23use bit_iter::BitIter;
24
// Serde integration module (contents not visible here). It is compiled only
// when the "serde" cargo feature is enabled, and is hidden from rustdoc output.
#[cfg(feature = "serde")]
#[doc(hidden)]
pub mod serde_impl;
28
29/// Implemented by every variant of the mask tracked array. The
30/// [`MaskTrackedArray::MaskType`] is the number type used for the mask with
31/// [`MaskTrackedArray::MAX_COUNT`] being the size of the slots array.
32pub trait MaskTrackedArray<T>: Default {
33    /// The number type used as the mask.
34    type MaskType;
35    /// The maximum number of elements in the array. Note that this is based
36    /// on the number of bits in [`MaskTrackedArray::MaskType`].
37    const MAX_COUNT: usize;
38    /// Check if there is an item at a slot. This function will also return
39    /// false if the index is out of range.
40    fn contains_item_at(&self, index: usize) -> bool;
41    /// Check how many slots are occupied.
42    fn len(&self) -> u32;
43    /// Returns true if this array is completely empty.
44    fn is_empty(&self) -> bool {
45        self.len() == 0
46    }
47    /// Construct a new empty instance of this array.
48    #[must_use]
49    fn new() -> Self {
50        Self::default()
51    }
52    /// Clear out all items in this array, returning to its empty state. Drop
53    /// is called if needed.
54    fn clear(&mut self);
55    /// Get a reference to an item inside a slot if available.
56    fn get_ref(&self, index: usize) -> Option<&T> {
57        if self.contains_item_at(index) {
58            Some(unsafe { self.get_unchecked_ref(index) })
59        } else {
60            None
61        }
62    }
63    /// Get a mutable reference to an item inside a slot if available.
64    fn get_mut(&mut self, index: usize) -> Option<&mut T> {
65        if self.contains_item_at(index) {
66            Some(unsafe { self.get_unchecked_mut(index) })
67        } else {
68            None
69        }
70    }
71    /// Get an immutable reference to slot without bounds or validity checking.
72    /// # Safety
73    /// The given index must be valid.
74    unsafe fn get_unchecked_ref(&self, index: usize) -> &T;
75    /// Get a mutable reference to a slot without bounds, validity or borrow
76    /// checking.
77    /// # Safety
78    /// The given index must be valid and there must be no other references
79    /// to the same slot.
80    #[allow(clippy::mut_from_ref)]
81    unsafe fn get_unchecked_mut(&self, index: usize) -> &mut T;
82    /// Insert an item at a given index without bounds, validity or borrow
83    /// checking.
84    /// # Safety
85    /// The given index must be valid to avoid undefined behaviour. If a value
86    /// was already present, that value will be forgotten without running its
87    /// drop implementation.
88    unsafe fn insert_unchecked(&self, index: usize, value: T);
89    /// Try to insert an item at a given index. If insertion fails, return
90    /// the value in an option. If the option is none then the insertion
91    /// succeeded.
92    #[must_use]
93    fn insert(&self, index: usize, value: T) -> Option<T> {
94        if self.contains_item_at(index) || index >= Self::MAX_COUNT {
95            Some(value)
96        } else {
97            unsafe { self.insert_unchecked(index, value) };
98            None
99        }
100    }
101    /// Remove a value from a slot without checking the index or if the item is
102    /// there.
103    /// # Safety
104    /// Calling this function on currently referenced or nonexistent slots will
105    /// result in undefined behaviour.
106    unsafe fn remove_unchecked(&self, index: usize) -> T;
107    /// Remove a value at a specific index and return it if available.
108    fn remove(&mut self, index: usize) -> Option<T> {
109        if self.contains_item_at(index) {
110            Some(unsafe { self.remove_unchecked(index) })
111        } else {
112            None
113        }
114    }
115    /// Get an iterator over all filled slot indices.
116    fn iter_filled_indices(&self) -> impl Iterator<Item = usize>;
117    /// Get an iterator over all unfilled slot indices.
118    fn iter_empty_indices(&self) -> impl Iterator<Item = usize>;
119    /// Iterate over references to every filled slot.
120    fn iter<'a>(&'a self) -> impl Iterator<Item = &'a T>
121    where
122        T: 'a;
123    /// Iterate over mutable references to every filled slot.
124    fn iter_mut<'a>(&'a mut self) -> impl Iterator<Item = &'a mut T>
125    where
126        T: 'a;
127}
128
/// An array with slots for values backed by a mask for tracking which slots
/// are under use. Each slot has its own [`UnsafeCell`].
///
/// Invariant maintained by the implementations in this crate: bit `i` of
/// `mask` is set exactly when `array[i]` holds an initialized value.
pub struct MaskTrackedArrayBase<T, M, const N: usize>
where
    // The `Drop` impl must repeat the struct's bounds exactly. Putting the
    // bound here lets `drop` call `MaskTrackedArray::clear`.
    Self: MaskTrackedArray<T, MaskType = M>,
{
    /// Mask used for tracking which slots in the array are filled.
    mask: core::cell::Cell<M>,
    /// An array of fillable slots. Only the slots whose mask bit is set are
    /// initialized.
    array: [UnsafeCell<MaybeUninit<T>>; N],
}
140
/// An owning iterator for [`MaskTrackedArray`] variants, produced by
/// `into_iter`.
///
/// Each yielded item is removed from `source`. Any items not yet yielded are
/// dropped together with `source` when the iterator itself is dropped.
pub struct MaskTrackedArrayIter<T, M, const N: usize>
where
    MaskTrackedArrayBase<T, M, N>: MaskTrackedArray<T, MaskType = M>,
{
    /// Indices of the filled slots that have not been yielded yet.
    bit_iterator: BitIter<M>,
    /// The array being drained.
    source: MaskTrackedArrayBase<T, M, N>,
}
149
150impl<T, M, const N: usize> Drop for MaskTrackedArrayBase<T, M, N>
151where
152    Self: MaskTrackedArray<T, MaskType = M>,
153{
154    fn drop(&mut self) {
155        self.clear();
156    }
157}
// Generates, for each `(mask integer type, bit width, alias name)` triple:
// - the public type alias and the `MaskTrackedArray` impl;
// - the `Default`, `IntoIterator`, `FromIterator`, `PartialEq`, `Eq`, `Hash`
//   and `Debug` impls;
// - the owning-iterator impls and a per-width test module.
// Comma-separated triples are handled by recursing on the tail.
macro_rules! mask_tracked_array_impl {
    () => {};
    (($num_ty:ty, $bits:expr, $alias_ident:ident), $($rest:tt)*) => {
        mask_tracked_array_impl!(($num_ty, $bits, $alias_ident));
        mask_tracked_array_impl!($($rest)*);
    };
    (($num_ty:ty, $bits:expr, $alias_ident:ident)) => {
        #[doc = stringify!(A $num_ty tracked array which can hold $bits items) ]
        pub type $alias_ident<T> = MaskTrackedArrayBase<T, $num_ty, $bits>;
        impl<T> MaskTrackedArray<T> for MaskTrackedArrayBase<T, $num_ty, $bits> {
            type MaskType = $num_ty;
            const MAX_COUNT: usize = $bits;
            fn contains_item_at(&self, index: usize) -> bool {
                // Reject out-of-range indices first so the shift below can
                // never exceed the mask width.
                if index >= $bits {
                    return false;
                }
                self.mask.get() & (1 << index) != 0
            }
            fn len(&self) -> u32 {
                self.mask.get().count_ones()
            }
            unsafe fn get_unchecked_ref(&self, index: usize) -> &T {
                // SAFETY: the caller guarantees `index` is in range, the slot is
                // initialized, and no mutable reference to it is alive.
                unsafe { (&*self.array.get_unchecked(index).get()).assume_init_ref() }
            }
            unsafe fn get_unchecked_mut(&self, index: usize) -> &mut T {
                // SAFETY: the caller guarantees `index` is in range, the slot is
                // initialized, and no other reference to it is alive.
                unsafe { (&mut *self.array.get_unchecked(index).get()).assume_init_mut() }
            }
            fn clear(&mut self) {
                if core::mem::needs_drop::<T>() {
                    for index in BitIter::from(self.mask.get()) {
                        // Mark the slot empty *before* dropping its value. If
                        // that drop panics, the mask still only describes live
                        // values, so a later `clear` (e.g. from `Drop` during
                        // unwinding) cannot drop this value a second time.
                        self.mask.update(|mask| mask & !(1 << index));
                        // SAFETY: bit `index` was set, so the slot is in range
                        // and initialized; `&mut self` rules out live references.
                        unsafe {
                            self.array
                                .get_unchecked_mut(index)
                                .get_mut()
                                .assume_init_drop()
                        };
                    }
                }
                self.mask.set(0);
            }
            unsafe fn insert_unchecked(&self, index: usize, value: T) {
                // SAFETY: the caller guarantees `index` is in range and that no
                // reference to the slot is alive. `write` does not drop any
                // previous contents.
                unsafe { (&mut *self.array.get_unchecked(index).get()).write(value) };
                self.mask.update(|v| v | (1 << index));
            }
            unsafe fn remove_unchecked(&self, index: usize) -> T {
                self.mask.update(|v| v & !(1 << index));
                // SAFETY: the caller guarantees `index` is in range and the slot
                // is initialized and unreferenced. Its bit was cleared above, so
                // the bytes left behind are treated as uninitialized and are
                // never read or dropped again.
                unsafe { (&*self.array.get_unchecked(index).get()).assume_init_read() }
            }
            #[inline]
            fn iter_filled_indices(&self) -> impl Iterator<Item = usize> {
                BitIter::from(self.mask.get())
            }
            #[inline]
            fn iter_empty_indices(&self) -> impl Iterator<Item = usize> {
                // The mask type has exactly `$bits` bits, so inverting it never
                // yields an out-of-range index.
                BitIter::from(!self.mask.get())
            }
            #[inline]
            fn iter<'a>(&'a self) -> impl Iterator<Item = &'a T> where T: 'a {
                // SAFETY: every index comes from a set mask bit, and the safe API
                // only mutates filled slots through `&mut self`.
                BitIter::from(self.mask.get()).map(|index| unsafe { self.get_unchecked_ref(index) } )
            }
            #[inline]
            fn iter_mut<'a>(&'a mut self) -> impl Iterator<Item = &'a mut T> where T: 'a {
                // SAFETY: every index comes from a set mask bit and is yielded
                // once, so the mutable references never alias.
                BitIter::from(self.mask.get()).map(|index| unsafe { self.get_unchecked_mut(index) } )
            }
        }
        impl<T> Default for MaskTrackedArrayBase<T, $num_ty, $bits> {
            fn default() -> Self {
                Self {
                    mask: core::cell::Cell::new(0),
                    array: [const {core::cell::UnsafeCell::new(core::mem::MaybeUninit::uninit())}; $bits]
                }
            }
        }
        impl<T> core::iter::Iterator for MaskTrackedArrayIter<T, $num_ty, $bits> {
            type Item = T;
            fn next(&mut self) -> Option<Self::Item> {
                let next_index = self.bit_iterator.next()?;
                // SAFETY: `bit_iterator` was built from the source mask, and each
                // index is removed from `source` as it is yielded. The slot is
                // therefore filled, and since `source` is owned by this
                // iterator, nothing else can reference it.
                Some( unsafe { self.source.remove_unchecked(next_index) } )
            }
            fn size_hint(&self) -> (usize, Option<usize>) {
                // Yielded items are removed from `source`, so its occupancy is
                // exactly the number of items still to come. The count is at
                // most 128, so the cast cannot truncate.
                let remaining = self.source.len() as usize;
                (remaining, Some(remaining))
            }
        }
        impl<T> core::iter::ExactSizeIterator for MaskTrackedArrayIter<T, $num_ty, $bits> {}
        impl<T> core::iter::IntoIterator for MaskTrackedArrayBase<T, $num_ty, $bits>
        {
            type Item = T;
            type IntoIter = MaskTrackedArrayIter<T, $num_ty, $bits>;
            fn into_iter(self) -> Self::IntoIter {
                let bit_iterator = BitIter::from(self.mask.get());
                MaskTrackedArrayIter {
                    source: self,
                    bit_iterator
                }
            }
        }
        impl<T> core::iter::FromIterator<T> for MaskTrackedArrayBase<T, $num_ty, $bits> {
            // Fills slots 0, 1, 2, ... in order. Items beyond the capacity are
            // never pulled from the source iterator.
            fn from_iter<I>(iter: I) -> Self
                where I: IntoIterator<Item = T>
            {
                let array = Self::new();
                for (index, value) in iter.into_iter().take($bits).enumerate() {
                    // SAFETY: `take` keeps `index` below `$bits`, each index is
                    // used once, and `array` is not borrowed anywhere else.
                    unsafe { array.insert_unchecked(index, value) };
                }
                array
            }
        }
        impl<T: PartialEq> PartialEq for MaskTrackedArrayBase<T, $num_ty, $bits> {
            fn eq(&self, other: &Self) -> bool {
                // Equal masks mean both iterators visit the same slots in the
                // same order, so a pairwise comparison is sufficient.
                self.mask.get() == other.mask.get()
                    && self.iter().zip(other.iter()).all(|(left, right)| left == right)
            }
        }
        impl<T: Eq> Eq for MaskTrackedArrayBase<T, $num_ty, $bits> {}
        impl<T: core::hash::Hash> core::hash::Hash for MaskTrackedArrayBase<T, $num_ty, $bits> {
            fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
                self.mask.get().hash(state);
                self.iter().for_each(|v| v.hash(state));
            }
        }
        impl<T: core::fmt::Debug> core::fmt::Debug for MaskTrackedArrayBase<T, $num_ty, $bits> {
            // Prints the values of the filled slots as a list; indices are
            // not shown.
            fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> Result<(), core::fmt::Error> {
                f.debug_list().entries(self.iter()).finish()
            }
        }
        paste::paste! {
            #[cfg(test)]
            mod [<$num_ty _tests>] {
                use super::*;
                extern crate std;
                #[test]
                fn from_iterator_and_back() {
                    let mask = $alias_ident::from_iter(0..$bits);
                    for (index, number) in mask.into_iter().enumerate() {
                        assert_eq!(index, number);
                    }
                }
                #[test]
                fn from_too_big_iterator() {
                    let numbers = [0; $bits + 1];
                    let mask = $alias_ident::from_iter(numbers);
                    assert_eq!(mask.len(), $bits);
                }
                #[test]
                fn from_iterator_stops_at_capacity() {
                    let mut source = 0..($bits + 5);
                    let array = $alias_ident::from_iter(source.by_ref());
                    assert_eq!(array.len(), $bits);
                    assert_eq!(source.next(), Some($bits));
                }
                #[test]
                fn from_empty_iterator() {
                    let numbers: [u8; 0] = [];
                    let mask = $alias_ident::from_iter(numbers);
                    assert_eq!(mask.len(), 0);
                }
                #[test]
                fn into_iter_reports_exact_size() {
                    let array = $alias_ident::from_iter([5, 6, 7]);
                    let mut iter = array.into_iter();
                    assert_eq!(iter.len(), 3);
                    assert_eq!(iter.next(), Some(5));
                    assert_eq!(iter.size_hint(), (2, Some(2)));
                }
                #[test]
                fn hash_equality() {
                    let mask = $alias_ident::new();
                    assert!(mask.insert(0, 0).is_none());
                    assert!(mask.insert(1, 1).is_none());
                    let mask_2 = $alias_ident::new();
                    assert!(mask_2.insert(1, 1).is_none());
                    assert!(mask_2.insert(0, 0).is_none());
                    assert_eq!(mask, mask_2);
                    use std::hash::{ Hash, DefaultHasher, Hasher };
                    let mut hasher = DefaultHasher::new();
                    mask.hash(&mut hasher);
                    let first_hash = hasher.finish();
                    let mut hasher = DefaultHasher::new();
                    mask_2.hash(&mut hasher);
                    let second_hash = hasher.finish();
                    assert_eq!(first_hash, second_hash);
                }
                #[test]
                fn equality() {
                    let first = $alias_ident::from_iter([1, 2]);
                    let second = $alias_ident::from_iter([1]);
                    assert_ne!(first, second);
                }
                #[test]
                fn removal() {
                    let mut array = $alias_ident::from_iter([1, 2, 3]);
                    assert_eq!(Some(1), array.remove(0));
                    assert_eq!(Some(2), array.remove(1));
                    assert_eq!(Some(3), array.remove(2));
                    assert_eq!(None, array.remove(0))
                }
                #[test]
                fn failing_getters() {
                    let mut array = $alias_ident::from_iter([1, 2, 3, 4]);
                    assert_eq!(None, array.get_ref(5));
                    assert_eq!(None, array.get_ref(1000));
                    assert_eq!(None, array.get_mut(5));
                    assert_eq!(None, array.get_mut(1000));
                }
                #[test]
                fn succeeding_getters() {
                    let mut array = $alias_ident::from_iter([1, 2, 3, 4]);
                    assert_eq!(Some(&1), array.get_ref(0));
                    assert_eq!(Some(&mut 2), array.get_mut(1));
                }
                #[test]
                fn clearing() {
                    let mut array = $alias_ident::from_iter([1, 2, 3, 4]);
                    array.clear();
                    assert_eq!(array, $alias_ident::new());
                    assert_eq!(array.len(), 0);
                }
                #[test]
                fn clearing_with_drop() {
                    use std::rc::Rc;
                    let rc1 = Rc::new(1);
                    let rc2 = Rc::new(2);
                    let mut array = $alias_ident::from_iter([rc1.clone(), rc2.clone()]);
                    assert_eq!(Rc::strong_count(&rc1), 2);
                    assert_eq!(Rc::strong_count(&rc2), 2);
                    array.clear();
                    assert_eq!(Rc::strong_count(&rc1), 1);
                    assert_eq!(Rc::strong_count(&rc2), 1);
                }
                #[test]
                fn clear_with_panicking_drop_never_double_drops() {
                    use std::cell::Cell;
                    use std::rc::Rc;
                    struct Bomb(Rc<Cell<u32>>, bool);
                    impl Drop for Bomb {
                        fn drop(&mut self) {
                            self.0.set(self.0.get() + 1);
                            if self.1 {
                                panic!("intentional panic in drop");
                            }
                        }
                    }
                    let drops = Rc::new(Cell::new(0u32));
                    let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
                        let mut array = $alias_ident::from_iter([
                            Bomb(drops.clone(), true),
                            Bomb(drops.clone(), false),
                        ]);
                        array.clear();
                    }));
                    assert!(result.is_err());
                    // Each value is dropped exactly once, even though one drop
                    // panicked part-way through `clear`.
                    assert_eq!(drops.get(), 2);
                }
                #[test]
                fn empty_indices_iterator() {
                    let array = $alias_ident::from_iter([0, 1]);
                    assert!(array.iter_empty_indices().all(|v| v != 0 && v != 1))
                }
                #[test]
                fn mutable_ref_iterator() {
                    let mut array = $alias_ident::from_iter([0, 1]);
                    array.iter_mut().for_each(|v| *v += 1);
                    let new_version = $alias_ident::from_iter([1, 2]);
                    assert_eq!(array, new_version);
                }
                #[test]
                fn insertion() {
                    let array = $alias_ident::from_iter([0, 1]);
                    assert_eq!(None, array.insert(2, 2));
                    let new_array = $alias_ident::from_iter([0, 1, 2]);
                    assert_eq!(array, new_array);
                }
                #[test]
                fn debug_print_no_ub() {
                    let array = $alias_ident::from_iter([0, 1]);
                    let formatted_string = std::format!("{:?}", array);
                    assert!(formatted_string.is_ascii());
                }
                #[test]
                fn emptiness() {
                    let array: $alias_ident<u8> = $alias_ident::new();
                    assert!(array.is_empty());
                    assert_eq!(0, array.len());
                }
            }
        }
    };
}
414
// One concrete array type per unsigned integer width. Each array's capacity
// (`MAX_COUNT`) equals the bit width of its mask type.
mask_tracked_array_impl!(
    (u8, 8, MaskTrackedArrayU8),
    (u16, 16, MaskTrackedArrayU16),
    (u32, 32, MaskTrackedArrayU32),
    (u64, 64, MaskTrackedArrayU64),
    (u128, 128, MaskTrackedArrayU128)
);