cl_aux/structures/array_wrapper.rs

1use core::{
2  borrow::{Borrow, BorrowMut},
3  ops::{Deref, DerefMut},
4  slice::{Iter, IterMut},
5};
6
/// Layout-transparent newtype over a plain `[T; N]` array.
///
/// Used for serialization, de-serialization or to construct custom arrays. Comparison and
/// formatting traits are derived, so they simply forward to the wrapped array.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq, PartialOrd, Ord)]
#[repr(transparent)]
pub struct ArrayWrapper<T, const N: usize>(
  /// The wrapped array, publicly accessible.
  pub [T; N],
);
14
15impl<T, const N: usize> ArrayWrapper<T, N> {
16  /// Creates an array `[T; N]` where each array element `T` is returned by the `cb` call.
17  #[inline]
18  pub fn from_fn(mut cb: impl FnMut(usize) -> T) -> Self {
19    let mut idx = 0;
20    Self([(); N].map(|_| {
21      let res = cb(idx);
22      idx = idx.wrapping_add(1);
23      res
24    }))
25  }
26
27  /// Creates an array `ArrayWrapper` where each fallible array element `T` is returned by the `cb` call.
28  /// Unlike [`ArrayWrapper::from_fn`], where the element creation can't fail, this version will return an error
29  /// if any element creation was unsuccessful.
30  #[inline]
31  pub fn try_from_fn<E>(cb: impl FnMut(usize) -> Result<T, E>) -> Result<Self, E> {
32    Ok(Self(here_be_dragons::try_from_fn(cb)?))
33  }
34}
35
36impl<T, const N: usize> AsRef<[T; N]> for ArrayWrapper<T, N> {
37  #[inline]
38  fn as_ref(&self) -> &[T; N] {
39    self
40  }
41}
42
43impl<T, const N: usize> AsMut<[T; N]> for ArrayWrapper<T, N> {
44  #[inline]
45  fn as_mut(&mut self) -> &mut [T; N] {
46    self
47  }
48}
49
50impl<T, const N: usize> Borrow<[T; N]> for ArrayWrapper<T, N> {
51  #[inline]
52  fn borrow(&self) -> &[T; N] {
53    self
54  }
55}
56
57impl<T, const N: usize> BorrowMut<[T; N]> for ArrayWrapper<T, N> {
58  #[inline]
59  fn borrow_mut(&mut self) -> &mut [T; N] {
60    self
61  }
62}
63
64impl<T, const N: usize> Default for ArrayWrapper<T, N>
65where
66  T: Default,
67{
68  #[inline]
69  fn default() -> Self {
70    Self::from_fn(|_| T::default())
71  }
72}
73
74impl<T, const N: usize> Deref for ArrayWrapper<T, N> {
75  type Target = [T; N];
76
77  #[inline]
78  fn deref(&self) -> &[T; N] {
79    &self.0
80  }
81}
82
83impl<T, const N: usize> DerefMut for ArrayWrapper<T, N> {
84  #[inline]
85  fn deref_mut(&mut self) -> &mut [T; N] {
86    &mut self.0
87  }
88}
89
90impl<T, const N: usize> From<[T; N]> for ArrayWrapper<T, N> {
91  #[inline]
92  fn from(from: [T; N]) -> Self {
93    Self(from)
94  }
95}
96
97impl<'array, T, const N: usize> IntoIterator for &'array ArrayWrapper<T, N> {
98  type IntoIter = Iter<'array, T>;
99  type Item = &'array T;
100
101  #[inline]
102  fn into_iter(self) -> Self::IntoIter {
103    self.0.iter()
104  }
105}
106
107impl<'array, T, const N: usize> IntoIterator for &'array mut ArrayWrapper<T, N> {
108  type IntoIter = IterMut<'array, T>;
109  type Item = &'array mut T;
110
111  #[inline]
112  fn into_iter(self) -> Self::IntoIter {
113    self.0.iter_mut()
114  }
115}
116
// Adapted from `core`'s internal `try_collect_into_array`. Soundness depends on the invariants
// spelled out in the `SAFETY` comments below; keep them accurate when editing.
mod here_be_dragons {
  #![allow(clippy::as_conversions, clippy::mem_forget, trivial_casts, unsafe_code)]

  use core::{
    mem::{self, MaybeUninit},
    ptr::{self, addr_of, addr_of_mut},
  };

  /// Builds `[T; N]` from `cb(0)`, `cb(1)`, ..., `cb(N - 1)`, called in that order.
  ///
  /// On the first `Err` no further calls are made. Every element produced so far is dropped and
  /// the error is returned.
  #[inline]
  pub(crate) fn try_from_fn<E, T, const N: usize>(
    cb: impl FnMut(usize) -> Result<T, E>,
  ) -> Result<[T; N], E> {
    let mut iter = (0..N).map(cb);
    debug_assert!(N <= iter.size_hint().1.unwrap_or(usize::MAX));
    debug_assert!(N <= iter.size_hint().0);
    // SAFETY: `iter` yields exactly `N` items. So `try_collect_into_array` either returns early
    // with `Some(Err(_))` or fills all `N` slots and returns `Some(Ok(_))`. `None` is only
    // produced when the iterator runs dry before `N` items, which cannot happen here.
    unsafe { try_collect_into_array(&mut iter).unwrap_unchecked() }
  }

  /// Converts a fully initialized array of `MaybeUninit<T>` into `[T; N]`.
  ///
  /// # Safety
  ///
  /// Every element of `array` must be initialized.
  #[allow(
    // Takes ownership to prevent a double reference.
    clippy::needless_pass_by_value
  )]
  unsafe fn array_assume_init<T, const N: usize>(array: [MaybeUninit<T>; N]) -> [T; N] {
    // SAFETY:
    // * The caller guarantees that all elements of the array are initialized
    // * `MaybeUninit<T>` and T are guaranteed to have the same layout
    // * `MaybeUninit` does not drop, so there are no double-frees
    // And thus the conversion is safe
    unsafe { (addr_of!(array).cast::<[T; N]>()).read() }
  }

  /// Pulls items from `iter` into a `[T; N]`.
  ///
  /// Return values:
  /// * `Some(Ok(array))` once `N` `Ok` items were collected.
  /// * `Some(Err(err))` on the first `Err` item.
  /// * `None` if `iter` is exhausted before `N` items were seen.
  ///
  /// In the last two cases every element collected so far is dropped.
  fn try_collect_into_array<E, T, const N: usize>(
    iter: &mut impl Iterator<Item = Result<T, E>>,
  ) -> Option<Result<[T; N], E>> {
    /// Drops the first `initialized` slots of `array_mut` unless it is forgotten. This means an
    /// early return, or a panic raised while iterating, does not leak elements already built.
    struct Guard<'array, T, const N: usize> {
      array_mut: &'array mut [MaybeUninit<T>; N],
      initialized: usize,
    }

    impl<T, const N: usize> Drop for Guard<'_, T, N> {
      fn drop(&mut self) {
        debug_assert!(self.initialized <= N);

        // SAFETY: `initialized <= N`, so the range is within the array's bounds.
        let slice = unsafe { self.array_mut.get_unchecked_mut(..self.initialized) };
        // SAFETY: each of the first `initialized` slots was written before the counter was
        // incremented.
        let init_slice = unsafe { slice_assume_init_mut(slice) };
        // SAFETY: the elements are initialized and are not dropped anywhere else. On success
        // the guard is forgotten before the array is moved out.
        unsafe {
          ptr::drop_in_place(init_slice);
        }
      }
    }

    if N == 0 {
      // SAFETY: `[T; 0]` is zero-sized and has no validity invariants, so a zeroed value is a
      // valid empty array. Only the array itself is zeroed. Zeroing the whole
      // `Result<[T; N], E>` instead could decode as an `Err` (or an invalid `E`), depending on
      // how the enum is laid out.
      return Some(Ok(unsafe { mem::zeroed() }));
    }

    let mut array = uninit_array::<T, N>();
    let mut guard = Guard { array_mut: &mut array, initialized: 0 };

    for item_rslt in iter {
      let item = match item_rslt {
        Err(err) => {
          // `guard` is dropped here, releasing the elements built so far.
          return Some(Err(err));
        }
        Ok(elem) => elem,
      };

      // SAFETY: `guard.initialized` starts at 0 and is increased by one per iteration. The loop
      // is aborted once it reaches N (which is `array.len()`), so the index is in bounds.
      unsafe {
        let _ = guard.array_mut.get_unchecked_mut(guard.initialized).write(item);
      }
      guard.initialized = guard.initialized.wrapping_add(1);

      // Check if the whole array was initialized.
      if guard.initialized == N {
        // Ownership of the elements moves to the returned array; the guard must not drop them.
        mem::forget(guard);

        // SAFETY: the condition above asserts that all elements are initialized.
        let out = unsafe { array_assume_init(array) };
        return Some(Ok(out));
      }
    }

    None
  }

  /// Reinterprets a slice of initialized `MaybeUninit<T>` as `&mut [T]`.
  ///
  /// # Safety
  ///
  /// Every element of `slice` must be initialized.
  unsafe fn slice_assume_init_mut<T>(slice: &mut [MaybeUninit<T>]) -> &mut [T] {
    // SAFETY: `MaybeUninit<T>` has the same layout as `T`. The caller guarantees that every
    // element is initialized, and the returned reference keeps the unique borrow of `slice`.
    unsafe { &mut *(addr_of_mut!(*slice) as *mut [T]) }
  }

  /// Returns an array of `LEN` uninitialized slots.
  const fn uninit_array<T, const LEN: usize>() -> [MaybeUninit<T>; LEN] {
    // SAFETY: An uninitialized `[MaybeUninit<_>; LEN]` is valid.
    unsafe { MaybeUninit::<[MaybeUninit<T>; LEN]>::uninit().assume_init() }
  }
}
222
#[cfg(feature = "serde")]
mod serde {
  use crate::{ArrayWrapper, ArrayWrapperRef};
  use core::{fmt::Formatter, marker::PhantomData};
  use serde::{
    de::{self, SeqAccess, Visitor},
    Deserialize, Deserializer, Serialize, Serializer,
  };

  impl<'de, T, const N: usize> Deserialize<'de> for ArrayWrapper<T, N>
  where
    T: Deserialize<'de>,
  {
    /// Deserializes exactly `N` elements, requested from the format as an `N`-tuple.
    #[inline]
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
      D: Deserializer<'de>,
    {
      /// Visitor that reads `N` consecutive sequence elements into an `ArrayWrapper`.
      struct ArrayVisitor<T, const N: usize>(PhantomData<T>);

      impl<'de, T, const N: usize> Visitor<'de> for ArrayVisitor<T, N>
      where
        T: Deserialize<'de>,
      {
        type Value = ArrayWrapper<T, N>;

        #[inline]
        fn expecting(&self, formatter: &mut Formatter<'_>) -> Result<(), core::fmt::Error> {
          formatter.write_fmt(format_args!("an array with {N} elements"))
        }

        #[inline]
        fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
        where
          A: SeqAccess<'de>,
        {
          ArrayWrapper::try_from_fn(|idx| {
            // When the sequence runs out at index `idx`, exactly `idx` elements were read.
            // That is the actual length `invalid_length` expects. The expected length is
            // described by `self` through `expecting`.
            seq.next_element::<T>()?.ok_or_else(|| de::Error::invalid_length(idx, &self))
          })
        }
      }

      deserializer.deserialize_tuple(N, ArrayVisitor::<T, N>(PhantomData))
    }
  }

  impl<T, const N: usize> Serialize for ArrayWrapper<T, N>
  where
    T: Serialize,
  {
    /// Serializes the wrapped array by delegating to `ArrayWrapperRef`.
    #[inline]
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
      S: Serializer,
    {
      ArrayWrapperRef::from(&self.0).serialize(serializer)
    }
  }
}
283}