1#![no_std]
2#![warn(missing_docs)]
3use core::{cell::UnsafeCell, mem::MaybeUninit};
22
23use bit_iter::BitIter;
24
25mod mask_trait;
26#[cfg(feature = "serde")]
27#[doc(hidden)]
28pub mod serde_impl;
29pub use mask_trait::Mask;
30
31pub trait MaskTrackedArray<T>: Default + FromIterator<T> + FromIterator<(usize, T)> {
34 type MaskType: Mask;
36 fn contains_item_at(&self, index: usize) -> bool;
39 fn len(&self) -> u32;
41 fn is_empty(&self) -> bool {
43 self.len() == 0
44 }
45 #[must_use]
47 fn new() -> Self {
48 Self::default()
49 }
50 fn clear(&mut self);
53 fn get_ref(&self, index: usize) -> Option<&T> {
55 if self.contains_item_at(index) {
56 Some(unsafe { self.get_unchecked_ref(index) })
57 } else {
58 None
59 }
60 }
61 fn get_mut(&mut self, index: usize) -> Option<&mut T> {
63 if self.contains_item_at(index) {
64 Some(unsafe { self.get_unchecked_mut(index) })
65 } else {
66 None
67 }
68 }
69 unsafe fn get_unchecked_ref(&self, index: usize) -> &T;
73 #[allow(clippy::mut_from_ref)]
79 unsafe fn get_unchecked_mut(&self, index: usize) -> &mut T;
80 unsafe fn insert_unchecked(&self, index: usize, value: T);
87 #[must_use]
91 fn insert(&self, index: usize, value: T) -> Option<T> {
92 if self.contains_item_at(index) || index >= Self::MaskType::MAX_SELECTIONS as usize {
93 Some(value)
94 } else {
95 unsafe { self.insert_unchecked(index, value) };
96 None
97 }
98 }
99 unsafe fn remove_unchecked(&self, index: usize) -> T;
105 fn remove(&mut self, index: usize) -> Option<T> {
107 if self.contains_item_at(index) {
108 Some(unsafe { self.remove_unchecked(index) })
109 } else {
110 None
111 }
112 }
113 fn iter_filled_indices(&self) -> impl Iterator<Item = usize>;
115 fn iter_filled_indices_mask(&self, mask: Self::MaskType) -> impl Iterator<Item = usize>;
118 fn iter_empty_indices(&self) -> impl Iterator<Item = usize>;
120 fn iter<'a>(&'a self) -> impl Iterator<Item = &'a T>
122 where
123 T: 'a,
124 {
125 self.iter_filled_indices()
126 .map(|index| unsafe { self.get_unchecked_ref(index) })
127 }
128 fn iter_mut<'a>(&'a mut self) -> impl Iterator<Item = &'a mut T>
130 where
131 T: 'a,
132 {
133 self.iter_filled_indices()
134 .map(|index| unsafe { self.get_unchecked_mut(index) })
135 }
136 fn iter_mask<'a>(&'a self, mask: Self::MaskType) -> impl Iterator<Item = &'a T>
138 where
139 T: 'a,
140 {
141 self.iter_filled_indices_mask(mask)
142 .map(|index| unsafe { self.get_unchecked_ref(index) })
143 }
144 fn iter_mut_mask<'a>(&'a mut self, mask: Self::MaskType) -> impl Iterator<Item = &'a mut T>
147 where
148 T: 'a,
149 {
150 self.iter_filled_indices_mask(mask)
151 .map(|index| unsafe { self.get_unchecked_mut(index) })
152 }
153 fn mask(&self) -> Self::MaskType;
155 fn push(&mut self, value: T) -> Result<usize, T> {
158 if let Some(smallest) = self.iter_empty_indices().next() {
159 let None = self.insert(smallest, value) else {
160 unreachable!()
161 };
162 Ok(smallest)
163 } else {
164 Err(value)
165 }
166 }
167}
168
/// Fixed-capacity array of `N` slots whose occupancy is tracked by the bits of a mask of type `M`.
///
/// Bit `i` of the mask is set exactly when slot `i` holds an initialized value. The
/// per-width aliases (`MaskTrackedArrayU8` … `MaskTrackedArrayU128`) are the intended way to
/// name this type.
pub struct MaskTrackedArrayBase<T, M, const N: usize>
where
    Self: MaskTrackedArray<T, MaskType = M>,
{
    /// Occupancy mask: bit `i` is set iff slot `i` is initialized. Held in a `Cell` so that
    /// insertions can go through `&self`.
    mask: core::cell::Cell<M>,
    /// Slot storage. A slot's contents are only initialized while its mask bit is set;
    /// `UnsafeCell` allows writing into an empty slot through a shared reference.
    array: [UnsafeCell<MaybeUninit<T>>; N],
}
180
/// Owning iterator over the values of a [`MaskTrackedArrayBase`], produced by its
/// `IntoIterator` implementation.
///
/// Values are moved out in ascending slot order; values not yet yielded are dropped together
/// with the iterator (through the source array's `Drop`).
pub struct MaskTrackedArrayIter<T, M, const N: usize>
where
    MaskTrackedArrayBase<T, M, N>: MaskTrackedArray<T, MaskType = M>,
{
    /// Remaining occupied indices, taken from the source mask when iteration began.
    bit_iterator: BitIter<M>,
    /// The array being drained; each yielded index is removed from it.
    source: MaskTrackedArrayBase<T, M, N>,
}
189
190impl<T, M, const N: usize> Drop for MaskTrackedArrayBase<T, M, N>
191where
192 Self: MaskTrackedArray<T, MaskType = M>,
193{
194 fn drop(&mut self) {
195 self.clear();
196 }
197}
/// Generates, for each `(mask integer type, capacity, alias)` triple, the public type alias,
/// the [`MaskTrackedArray`] implementation, the standard trait impls (`Default`,
/// `IntoIterator`, `FromIterator`, `PartialEq`, `Eq`, `Hash`, `Debug`) and a test module.
///
/// Several comma-separated triples may be passed at once; the macro recurses over them.
/// The capacity must equal the bit width of the mask type, because index bounds are checked
/// against the mask width rather than the slot array length (enforced at compile time).
macro_rules! mask_tracked_array_impl {
    () => {};
    (($num_ty:ty, $bits:expr, $alias_ident:ident), $($rest:tt)*) => {
        mask_tracked_array_impl!(($num_ty, $bits, $alias_ident));
        mask_tracked_array_impl!($($rest)*);
    };
    (($num_ty:ty, $bits:expr, $alias_ident:ident)) => {
        #[doc = concat!(
            "A mask-tracked array using a `",
            stringify!($num_ty),
            "` occupancy mask, which can hold up to ",
            stringify!($bits),
            " items."
        )]
        pub type $alias_ident<T> = MaskTrackedArrayBase<T, $num_ty, $bits>;

        // Indices are bounds-checked against the mask width, so the slot array must be
        // exactly that long for the unchecked slot accesses below to stay in bounds.
        const _: () = assert!(<$num_ty>::BITS as usize == $bits);

        impl<T> MaskTrackedArray<T> for MaskTrackedArrayBase<T, $num_ty, $bits> {
            type MaskType = $num_ty;

            fn contains_item_at(&self, index: usize) -> bool {
                // Guard first: shifting by `BITS` or more would overflow.
                if index >= <$num_ty>::BITS as usize {
                    return false;
                }
                self.mask.get() & (1 << index) != 0
            }

            fn mask(&self) -> Self::MaskType {
                self.mask.get()
            }

            fn len(&self) -> u32 {
                self.mask.get().count_ones()
            }

            unsafe fn get_unchecked_ref(&self, index: usize) -> &T {
                // SAFETY: the caller guarantees `index` is in bounds and the slot is occupied,
                // hence initialized.
                unsafe { (&*self.array.get_unchecked(index).get()).assume_init_ref() }
            }

            unsafe fn get_unchecked_mut(&self, index: usize) -> &mut T {
                // SAFETY: as for `get_unchecked_ref`; the caller additionally guarantees no
                // other reference to this slot is alive.
                unsafe { (&mut *self.array.get_unchecked(index).get()).assume_init_mut() }
            }

            fn clear(&mut self) {
                if core::mem::needs_drop::<T>() {
                    for index in BitIter::from(self.mask.get()) {
                        // Unmark the slot *before* dropping its value. If `T::drop` panics,
                        // the mask must no longer claim this slot; otherwise `Drop` (or a
                        // later `clear`) would drop the same value a second time.
                        self.mask.update(|v| v & !(1 << index));
                        // SAFETY: `index` came from the occupancy mask, so the slot is in
                        // bounds and initialized; its bit is now cleared, so it will never
                        // be read or dropped again.
                        unsafe {
                            self.array
                                .get_unchecked_mut(index)
                                .get_mut()
                                .assume_init_drop()
                        };
                    }
                }
                self.mask.set(0);
            }

            unsafe fn insert_unchecked(&self, index: usize, value: T) {
                // SAFETY: the caller guarantees `index` is in bounds and the slot is empty,
                // so nothing is overwritten and no reference to the slot exists.
                unsafe { (&mut *self.array.get_unchecked(index).get()).write(value) };
                self.mask.update(|v| v | (1 << index));
            }

            unsafe fn remove_unchecked(&self, index: usize) -> T {
                // Clear the bit first: from here on the slot counts as empty, so the value
                // read below is the only copy that will ever be used or dropped.
                self.mask.update(|v| v & !(1 << index));
                // SAFETY: the caller guarantees `index` is in bounds and the slot was occupied.
                unsafe { (*self.array.get_unchecked(index).get()).assume_init_read() }
            }

            #[inline]
            fn iter_filled_indices(&self) -> impl Iterator<Item = usize> {
                BitIter::from(self.mask.get())
            }

            #[inline]
            fn iter_filled_indices_mask(&self, mask: Self::MaskType) -> impl Iterator<Item = usize> {
                BitIter::from(self.mask.get() & mask)
            }

            #[inline]
            fn iter_empty_indices(&self) -> impl Iterator<Item = usize> {
                BitIter::from(!self.mask.get())
            }
        }

        impl<T> Default for MaskTrackedArrayBase<T, $num_ty, $bits> {
            fn default() -> Self {
                Self {
                    mask: core::cell::Cell::new(0),
                    array: [const { core::cell::UnsafeCell::new(core::mem::MaybeUninit::uninit()) }; $bits],
                }
            }
        }

        impl<T> core::iter::Iterator for MaskTrackedArrayIter<T, $num_ty, $bits> {
            type Item = T;

            fn next(&mut self) -> Option<Self::Item> {
                let next_index = self.bit_iterator.next()?;
                // SAFETY: `bit_iterator` was built from the source mask and yields each
                // occupied index once; yielded indices are removed from the source, so this
                // one is still occupied.
                Some(unsafe { self.source.remove_unchecked(next_index) })
            }

            fn size_hint(&self) -> (usize, Option<usize>) {
                // Every yielded value is removed from `source`, so the slots still marked in
                // its mask are exactly the values left to yield.
                let remaining = self.source.mask.get().count_ones() as usize;
                (remaining, Some(remaining))
            }
        }

        impl<T> core::iter::ExactSizeIterator for MaskTrackedArrayIter<T, $num_ty, $bits> {}

        impl<T> core::iter::IntoIterator for MaskTrackedArrayBase<T, $num_ty, $bits> {
            type Item = T;
            type IntoIter = MaskTrackedArrayIter<T, $num_ty, $bits>;

            fn into_iter(self) -> Self::IntoIter {
                let bit_iterator = BitIter::from(self.mask.get());
                MaskTrackedArrayIter {
                    source: self,
                    bit_iterator,
                }
            }
        }

        impl<T> core::iter::FromIterator<T> for MaskTrackedArrayBase<T, $num_ty, $bits> {
            /// Fills slots `0, 1, 2, …` in iteration order. At most as many items as the array
            /// can hold are taken; the rest of the iterator is left unconsumed.
            fn from_iter<I>(iter: I) -> Self
            where
                I: IntoIterator<Item = T>,
            {
                let array = Self::new();
                for (index, value) in iter.into_iter().take($bits).enumerate() {
                    // SAFETY: `take($bits)` keeps `index` in bounds, and each index appears
                    // once, so the slot is still empty.
                    unsafe { array.insert_unchecked(index, value) };
                }
                array
            }
        }

        impl<T> core::iter::FromIterator<(usize, T)> for MaskTrackedArrayBase<T, $num_ty, $bits> {
            /// Inserts each `(index, value)` pair at `index`. Pairs whose index is out of range
            /// or already occupied (the first occurrence wins) are dropped.
            fn from_iter<I>(iter: I) -> Self
            where
                I: IntoIterator<Item = (usize, T)>,
            {
                let array = Self::new();
                for (index, value) in iter.into_iter() {
                    // Rejected values are handed back by `insert` and intentionally dropped.
                    let _ = array.insert(index, value);
                }
                array
            }
        }

        impl<T: PartialEq> PartialEq for MaskTrackedArrayBase<T, $num_ty, $bits> {
            fn eq(&self, other: &Self) -> bool {
                // Equal masks guarantee both iterators walk the same slot indices.
                if self.mask != other.mask {
                    return false;
                }
                self.iter().zip(other.iter()).all(|(left, right)| left.eq(right))
            }
        }

        impl<T: Eq> Eq for MaskTrackedArrayBase<T, $num_ty, $bits> {}

        impl<T: core::hash::Hash> core::hash::Hash for MaskTrackedArrayBase<T, $num_ty, $bits> {
            fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
                // Hash the mask too, so equal values stored at different slots differ.
                self.mask.get().hash(state);
                self.iter().for_each(|v| v.hash(state));
            }
        }

        impl<T: core::fmt::Debug> core::fmt::Debug for MaskTrackedArrayBase<T, $num_ty, $bits> {
            /// Formats the stored values as a list in ascending slot order; slot indices are
            /// not shown.
            fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> Result<(), core::fmt::Error> {
                f.debug_list().entries(self.iter()).finish()
            }
        }

        paste::paste! {
            #[cfg(test)]
            mod [<$num_ty _tests>] {
                use super::*;
                extern crate std;

                #[test]
                fn from_iterator_and_back() {
                    let mask = $alias_ident::from_iter(0..$bits);
                    for (index, number) in mask.into_iter().enumerate() {
                        assert_eq!(index, number);
                    }
                }

                #[test]
                fn from_too_big_iterator() {
                    let numbers = [0; $bits + 1];
                    let mask = $alias_ident::from_iter(numbers);
                    assert_eq!(mask.len(), $bits);
                }

                #[test]
                fn from_iter_does_not_overconsume() {
                    let mut source = 0..($bits + 5);
                    let array = $alias_ident::from_iter(source.by_ref());
                    assert_eq!(array.len(), $bits);
                    assert_eq!(source.next(), Some($bits));
                }

                #[test]
                fn from_empty_iterator() {
                    let numbers: [u8; 0] = [];
                    let mask = $alias_ident::from_iter(numbers);
                    assert_eq!(mask.len(), 0);
                }

                #[test]
                fn hash_equality() {
                    let mask = $alias_ident::new();
                    assert!(mask.insert(0, 0).is_none());
                    assert!(mask.insert(1, 1).is_none());
                    let mask_2 = $alias_ident::new();
                    assert!(mask_2.insert(1, 1).is_none());
                    assert!(mask_2.insert(0, 0).is_none());
                    assert_eq!(mask, mask_2);
                    use std::hash::{DefaultHasher, Hash, Hasher};
                    let mut hasher = DefaultHasher::new();
                    mask.hash(&mut hasher);
                    let first_hash = hasher.finish();
                    let mut hasher = DefaultHasher::new();
                    mask_2.hash(&mut hasher);
                    let second_hash = hasher.finish();
                    assert_eq!(first_hash, second_hash);
                }

                #[test]
                fn equality() {
                    let first = $alias_ident::from_iter([1, 2]);
                    let second = $alias_ident::from_iter([1]);
                    assert_ne!(first, second);
                }

                #[test]
                fn removal() {
                    let mut array = $alias_ident::from_iter([1, 2, 3]);
                    assert_eq!(Some(1), array.remove(0));
                    assert_eq!(Some(2), array.remove(1));
                    assert_eq!(Some(3), array.remove(2));
                    assert_eq!(None, array.remove(0))
                }

                #[test]
                fn failing_getters() {
                    let mut array = $alias_ident::from_iter([1, 2, 3, 4]);
                    assert_eq!(None, array.get_ref(5));
                    assert_eq!(None, array.get_ref(1000));
                    assert_eq!(None, array.get_mut(5));
                    assert_eq!(None, array.get_mut(1000));
                }

                #[test]
                fn succeeding_getters() {
                    let mut array = $alias_ident::from_iter([1, 2, 3, 4]);
                    assert_eq!(Some(&1), array.get_ref(0));
                    assert_eq!(Some(&mut 2), array.get_mut(1));
                }

                #[test]
                fn clearing() {
                    let mut array = $alias_ident::from_iter([1, 2, 3, 4]);
                    array.clear();
                    assert_eq!(array, $alias_ident::new());
                    assert_eq!(array.len(), 0);
                }

                #[test]
                fn clearing_with_drop() {
                    use std::rc::Rc;
                    let rc1 = Rc::new(1);
                    let rc2 = Rc::new(2);
                    let mut array = $alias_ident::from_iter([rc1.clone(), rc2.clone()]);
                    assert_eq!(Rc::strong_count(&rc1), 2);
                    assert_eq!(Rc::strong_count(&rc2), 2);
                    array.clear();
                    assert_eq!(Rc::strong_count(&rc1), 1);
                    assert_eq!(Rc::strong_count(&rc2), 1);
                }

                #[test]
                fn clear_is_panic_safe() {
                    use std::cell::Cell;
                    use std::panic::{catch_unwind, AssertUnwindSafe};
                    use std::rc::Rc;
                    struct PanicOnDrop {
                        drops: Rc<Cell<u32>>,
                        panics: bool,
                    }
                    impl Drop for PanicOnDrop {
                        fn drop(&mut self) {
                            self.drops.set(self.drops.get() + 1);
                            if self.panics {
                                panic!("intentional panic in drop");
                            }
                        }
                    }
                    let drops = Rc::new(Cell::new(0));
                    let mut array = $alias_ident::new();
                    assert!(array
                        .insert(0, PanicOnDrop { drops: Rc::clone(&drops), panics: true })
                        .is_none());
                    assert!(array
                        .insert(1, PanicOnDrop { drops: Rc::clone(&drops), panics: false })
                        .is_none());
                    assert!(catch_unwind(AssertUnwindSafe(|| array.clear())).is_err());
                    // The panicking value was dropped exactly once and is no longer tracked.
                    assert_eq!(drops.get(), 1);
                    assert!(!array.contains_item_at(0));
                    drop(array);
                    // The remaining value is dropped by `Drop`; nothing is dropped twice.
                    assert_eq!(drops.get(), 2);
                }

                #[test]
                fn empty_indices_iterator() {
                    let array = $alias_ident::from_iter([0, 1]);
                    assert!(array.iter_empty_indices().all(|v| v != 0 && v != 1))
                }

                #[test]
                fn mutable_ref_iterator() {
                    let mut array = $alias_ident::from_iter([0, 1]);
                    array.iter_mut().for_each(|v| *v += 1);
                    let new_version = $alias_ident::from_iter([1, 2]);
                    assert_eq!(array, new_version);
                }

                #[test]
                fn insertion() {
                    let array = $alias_ident::from_iter([0, 1]);
                    assert_eq!(None, array.insert(2, 2));
                    let new_array = $alias_ident::from_iter([0, 1, 2]);
                    assert_eq!(array, new_array);
                }

                #[test]
                fn debug_print_no_ub() {
                    let array = $alias_ident::from_iter([0, 1]);
                    let formatted_string = std::format!("{:?}", array);
                    assert!(formatted_string.is_ascii());
                }

                #[test]
                fn emptiness() {
                    let array: $alias_ident<u8> = $alias_ident::new();
                    assert!(array.is_empty());
                    assert_eq!(0, array.len());
                }

                #[test]
                fn pushing() {
                    let mut array: $alias_ident<u8> = $alias_ident::new();
                    assert_eq!(Ok(0), array.push(1));
                    assert_eq!(Ok(1), array.push(2));
                    assert_eq!(Ok(2), array.push(3));
                    assert_eq!(Some(&1), array.get_ref(0));
                    assert_eq!(Some(&2), array.get_ref(1));
                    assert_eq!(Some(&3), array.get_ref(2));
                }

                #[test]
                fn pushing_maxed_out() {
                    let mut full_array = $alias_ident::from_iter([0u8; $bits]);
                    assert_eq!(Err(1), full_array.push(1));
                    assert!(full_array.iter().all(|v| *v == 0));
                }

                #[test]
                fn successful_insertions() {
                    let array = $alias_ident::new();
                    assert_eq!(None, array.insert(0, 1));
                    assert_eq!(None, array.insert(1, 1));
                    assert!(array.contains_item_at(0));
                    assert_eq!(Some(1), array.insert(0, 1));
                    assert_eq!(0b11, array.mask());
                }

                #[test]
                fn masked_iteration() {
                    let mut array = $alias_ident::from_iter([true; $bits]);
                    assert!(array.iter_mask(<$num_ty>::ALL_SELECTED).all(|b| *b));
                    assert!(array
                        .iter_mut_mask(<$num_ty>::ALL_SELECTED)
                        .all(|b| {
                            *b = false;
                            true
                        }));
                }

                #[test]
                fn from_iter_init() {
                    let mut array: $alias_ident<u8> = $alias_ident::from_iter([(1, 10)]);
                    assert_eq!(<$num_ty>::index_to_mask(1), array.mask());
                    assert_eq!(10, *array.get_mut(1).unwrap());
                }

                #[test]
                fn into_iter_reports_exact_len() {
                    let array = $alias_ident::from_iter([1, 2, 3]);
                    let mut iter = array.into_iter();
                    assert_eq!(iter.len(), 3);
                    assert_eq!(iter.next(), Some(1));
                    assert_eq!(iter.size_hint(), (2, Some(2)));
                }
            }
        }
    };
}
499
500mask_tracked_array_impl!(
501 (u8, 8, MaskTrackedArrayU8),
502 (u16, 16, MaskTrackedArrayU16),
503 (u32, 32, MaskTrackedArrayU32),
504 (u64, 64, MaskTrackedArrayU64),
505 (u128, 128, MaskTrackedArrayU128)
506);