primitives/types/heap_array/array.rs

1use std::{
2    fmt::{Debug, Display},
3    marker::PhantomData,
4    ops::{Add, Mul, Sub},
5    vec,
6};
7
8use bytemuck::{box_bytes_of, from_box_bytes, BoxBytes, Pod};
9use derive_more::derive::{AsMut, AsRef, Deref, DerefMut, IntoIterator};
10use hybrid_array::{Array, ArraySize};
11use rayon::iter::IntoParallelIterator;
12use serde::{Deserialize, Serialize};
13use typenum::{Diff, Prod, Sum, U1, U2, U3};
14
15use crate::{errors::PrimitiveError, types::Positive};
16
/// An array on the heap that encodes its length in the type system.
///
/// The type parameter `M` carries the intended element count (read at runtime as
/// `M::USIZE`), while the elements themselves live in a boxed slice. The
/// `TryFrom<Vec<T>>`, `FromIterator` and `Deserialize` impls check that the slice
/// length equals `M::USIZE`.
///
/// `Deref`/`DerefMut` (derived via `derive_more`) expose the inner `Box<[T]>`, so
/// slice methods such as `len`, `iter` and `split_at` are reachable by auto-deref.
// NOTE(review): `Eq` is derived here but `PartialEq` is not, so a `PartialEq` impl
// must exist elsewhere in the crate — confirm it compares `data` only.
#[derive(Deref, DerefMut, Clone, IntoIterator, AsRef, AsMut, Eq)]
#[into_iterator(owned, ref, ref_mut)]
pub struct HeapArray<T: Sized, M: Positive> {
    /// The elements; its length is expected to equal `M::USIZE`.
    #[deref]
    #[deref_mut]
    #[as_ref(forward)]
    #[as_mut(forward)]
    pub(super) data: Box<[T]>,
    #[into_iterator(ignore)]
    // Zero-sized marker tying the type-level length `M` to the value.
    // `fn() -> M` is used instead of `M` so `HeapArray<T, M>` doesn't need `M` to
    // implement `Send + Sync` to be `Send + Sync` itself, as it would if `M` were
    // stored directly.
    pub(super) _len: PhantomData<fn() -> M>,
}
31
32impl<T: Sized, M: Positive> HeapArray<T, M> {
33    pub(super) fn new(data: Box<[T]>) -> Self {
34        Self {
35            data,
36            _len: PhantomData,
37        }
38    }
39}
40
41// bytemuck::BoxBytes transformation for copy-less casting
42impl<T: Pod, M: Positive> HeapArray<T, M> {
43    pub fn into_box_bytes(self) -> BoxBytes {
44        box_bytes_of(self.data)
45    }
46
47    pub fn from_box_bytes(buf: BoxBytes) -> Self {
48        Self {
49            data: from_box_bytes(buf),
50            _len: PhantomData,
51        }
52    }
53}
54
55impl<T: Sized + Default, M: Positive> HeapArray<T, M> {
56    pub fn from_single_value(val: T) -> Self {
57        Self {
58            data: vec![val].into_boxed_slice(),
59            _len: PhantomData,
60        }
61    }
62}
63
64impl<T: Sized + Default, M: Positive> Default for HeapArray<T, M> {
65    fn default() -> Self {
66        Self {
67            data: (0..M::USIZE).map(|_| T::default()).collect(),
68            _len: PhantomData,
69        }
70    }
71}
72
73impl<T: Sized + Debug, M: Positive> Debug for HeapArray<T, M> {
74    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
75        f.debug_struct(format!("HeapArray<{}>", M::USIZE).as_str())
76            .field("data", &self.data)
77            .finish()
78    }
79}
80
81impl<T: Sized + Serialize, M: Positive> Serialize for HeapArray<T, M> {
82    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
83        self.data.serialize(serializer)
84    }
85}
86
87impl<'de, T: Sized + Deserialize<'de>, M: Positive> Deserialize<'de> for HeapArray<T, M> {
88    fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
89        let data = Box::<[T]>::deserialize(deserializer)?;
90
91        if data.len() != M::USIZE {
92            return Err(serde::de::Error::custom(format!(
93                "Expected array of length {}, got {}",
94                M::USIZE,
95                data.len()
96            )));
97        }
98
99        Ok(Self {
100            data,
101            _len: PhantomData,
102        })
103    }
104}
105
106impl<T: Sized, M: Positive> FromIterator<T> for HeapArray<T, M> {
107    fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
108        let data = iter.into_iter().collect::<Box<_>>();
109        assert_eq!(data.len(), M::USIZE,);
110        Self {
111            data,
112            _len: PhantomData,
113        }
114    }
115}
116
117impl<T: Sized + Send, M: Positive> IntoParallelIterator for HeapArray<T, M> {
118    type Item = T;
119    type Iter = rayon::vec::IntoIter<T>;
120
121    fn into_par_iter(self) -> Self::Iter {
122        self.data.into_par_iter()
123    }
124}
125
126// -----------------------
127// |   Split and Merge   |
128// -----------------------
129
130impl<T: Sized + Copy, M: Positive> HeapArray<T, M> {
131    pub fn split<M1, M2>(&self) -> (HeapArray<T, M1>, HeapArray<T, M2>)
132    where
133        M1: Positive,
134        M2: Positive + Add<M1, Output = M>,
135    {
136        let (m1, m2) = self.split_at(M1::USIZE);
137        (
138            HeapArray::<T, M1> {
139                data: m1.into(),
140                _len: PhantomData,
141            },
142            HeapArray::<T, M2> {
143                data: m2.into(),
144                _len: PhantomData,
145            },
146        )
147    }
148
149    pub fn split_halves<MDiv2>(&self) -> (HeapArray<T, MDiv2>, HeapArray<T, MDiv2>)
150    where
151        MDiv2: Positive + Mul<U2, Output = M>,
152    {
153        let (m1, m2) = self.split_at(MDiv2::USIZE);
154        (
155            HeapArray::<T, MDiv2> {
156                data: m1.into(),
157                _len: PhantomData,
158            },
159            HeapArray::<T, MDiv2> {
160                data: m2.into(),
161                _len: PhantomData,
162            },
163        )
164    }
165
166    pub fn merge_halves(this: Self, other: Self) -> HeapArray<T, Prod<M, U2>>
167    where
168        M: Mul<U2, Output: Positive>,
169    {
170        let mut vec = this.data.into_vec();
171        vec.extend(other.data.into_vec());
172        HeapArray::<T, Prod<M, U2>> {
173            data: vec.into_boxed_slice(),
174            _len: PhantomData,
175        }
176    }
177
178    pub fn split3<M1, M2, M3>(&self) -> (HeapArray<T, M1>, HeapArray<T, M2>, HeapArray<T, M3>)
179    where
180        M1: Positive,
181        M2: Positive + Add<M1>,
182        M3: Positive + Add<Sum<M2, M1>, Output = M>,
183    {
184        let (m1, m_rest) = self.split_at(M1::USIZE);
185        let (m2, m3) = m_rest.split_at(M2::USIZE);
186        (
187            HeapArray::<T, M1> {
188                data: m1.into(),
189                _len: PhantomData,
190            },
191            HeapArray::<T, M2> {
192                data: m2.into(),
193                _len: PhantomData,
194            },
195            HeapArray::<T, M3> {
196                data: m3.into(),
197                _len: PhantomData,
198            },
199        )
200    }
201
202    pub fn split_thirds<MDiv3>(
203        &self,
204    ) -> (
205        HeapArray<T, MDiv3>,
206        HeapArray<T, MDiv3>,
207        HeapArray<T, MDiv3>,
208    )
209    where
210        MDiv3: Positive + Mul<U3, Output = M>,
211    {
212        let (m1, m_rest) = self.split_at(MDiv3::USIZE);
213        let (m2, m3) = m_rest.split_at(MDiv3::USIZE);
214        (
215            HeapArray::<T, MDiv3> {
216                data: m1.into(),
217                _len: PhantomData,
218            },
219            HeapArray::<T, MDiv3> {
220                data: m2.into(),
221                _len: PhantomData,
222            },
223            HeapArray::<T, MDiv3> {
224                data: m3.into(),
225                _len: PhantomData,
226            },
227        )
228    }
229
230    pub fn merge_thirds(first: Self, second: Self, third: Self) -> HeapArray<T, Prod<M, U3>>
231    where
232        M: Mul<U3, Output: Positive>,
233    {
234        let mut vec = first.data.into_vec();
235        vec.extend(second.data.into_vec());
236        vec.extend(third.data.into_vec());
237        HeapArray::<T, Prod<M, U3>> {
238            data: vec.into_boxed_slice(),
239            _len: PhantomData,
240        }
241    }
242}
243
/// A pair of heap arrays that share the same type-level length `M`.
///
/// Built, for example, by collecting an iterator of pairs through its
/// `FromIterator<(T1, T2)>` impl, which unzips the pairs into the two arrays.
pub struct HeapArrayTuple<T1: Sized, T2: Sized, M: Positive>(
    pub HeapArray<T1, M>,
    pub HeapArray<T2, M>,
);
248
249impl<T1: Sized, T2: Sized, M: Positive> FromIterator<(T1, T2)> for HeapArrayTuple<T1, T2, M> {
250    fn from_iter<I: IntoIterator<Item = (T1, T2)>>(iter: I) -> Self {
251        let (data1, data2): (Vec<_>, Vec<_>) = iter.into_iter().unzip();
252
253        assert_eq!(data1.len(), M::USIZE);
254        assert_eq!(data2.len(), M::USIZE);
255        HeapArrayTuple(
256            HeapArray::<T1, M> {
257                data: data1.into_boxed_slice(),
258                _len: PhantomData,
259            },
260            HeapArray::<T2, M> {
261                data: data2.into_boxed_slice(),
262                _len: PhantomData,
263            },
264        )
265    }
266}
267
268impl<T: Sized, M: Positive> TryFrom<Vec<T>> for HeapArray<T, M> {
269    type Error = PrimitiveError;
270
271    fn try_from(data: Vec<T>) -> Result<Self, PrimitiveError> {
272        if data.len() == M::USIZE {
273            Ok(Self {
274                data: data.into_boxed_slice(),
275                _len: PhantomData,
276            })
277        } else {
278            Err(PrimitiveError::SizeError(M::USIZE, data.len()))
279        }
280    }
281}
282
283impl<T: Sized> From<T> for HeapArray<T, U1> {
284    fn from(element: T) -> Self {
285        Self {
286            data: Box::new([element]),
287            _len: PhantomData,
288        }
289    }
290}
291
292impl<T: Sized, M: Positive> From<HeapArray<T, M>> for Vec<T> {
293    fn from(array: HeapArray<T, M>) -> Self {
294        array.data.into_vec()
295    }
296}
297
298impl<T: Sized + Clone, M: Positive + ArraySize> From<Array<T, M>> for HeapArray<T, M> {
299    fn from(array: Array<T, M>) -> Self {
300        Self {
301            data: array.to_vec().into_boxed_slice(),
302            _len: PhantomData,
303        }
304    }
305}
306
307impl<T: Sized, M: Positive> HeapArray<T, M> {
308    pub fn split_last<N: Positive>(self) -> (HeapArray<T, Diff<M, N>>, HeapArray<T, N>)
309    where
310        M: Sub<N, Output: Positive>,
311    {
312        let mut vec = self.data.into_vec();
313        let last_n = vec.split_off(M::USIZE - N::USIZE);
314
315        (
316            HeapArray {
317                data: vec.into_boxed_slice(),
318                _len: PhantomData,
319            },
320            HeapArray {
321                data: last_n.into_boxed_slice(),
322                _len: PhantomData,
323            },
324        )
325    }
326
327    pub fn from_fn(f: impl FnMut(usize) -> T) -> Self {
328        Self {
329            data: (0..M::USIZE).map(f).collect::<Box<_>>(),
330            _len: PhantomData,
331        }
332    }
333
334    pub fn from_constant(c: T) -> Self
335    where
336        T: Copy,
337    {
338        Self {
339            data: (0..M::USIZE).map(|_| c).collect::<Box<_>>(),
340            _len: PhantomData,
341        }
342    }
343}
344
345impl<T: Sized, M: Positive> Display for HeapArray<T, M>
346where
347    T: Display,
348{
349    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
350        write!(f, "[")?;
351        for (i, item) in self.data.iter().enumerate() {
352            if i != 0 {
353                write!(f, ", ")?;
354            }
355            write!(f, "{item}")?;
356        }
357        write!(f, "]")
358    }
359}
360
#[cfg(test)]
pub mod tests {
    use hybrid_array::sizes::{U2, U3, U6};
    use typenum::{U1, U4};

    use super::*;

    #[test]
    fn test_heap_array() {
        let array = HeapArray::<_, U3>::from_fn(|i| i);
        assert_eq!(array.into_iter().collect::<Vec<_>>(), vec![0, 1, 2]);
    }

    #[test]
    fn test_default() {
        let array = HeapArray::<usize, U3>::default();
        assert_eq!(array.len(), 3);
        assert!(array.iter().all(|&x| x == 0));
    }

    #[test]
    fn test_heap_array_split_last() {
        let array = HeapArray::<_, U6>::from_fn(|i| i);
        let (first, last) = array.split_last::<U2>();
        assert_eq!(first.into_iter().collect::<Vec<_>>(), vec![0, 1, 2, 3]);
        assert_eq!(last.into_iter().collect::<Vec<_>>(), vec![4, 5]);
    }

    #[test]
    fn test_heap_array_from_array() {
        let array = Array::<_, U3>::from_fn(|i| i);
        let heap_array = HeapArray::<_, U3>::from(array);
        assert_eq!(heap_array.into_iter().collect::<Vec<_>>(), vec![0, 1, 2]);
    }

    #[test]
    fn test_heap_array_from_vec() {
        let vec = vec![0, 1, 2];
        let heap_array = HeapArray::<_, U3>::try_from(vec).unwrap();
        assert_eq!(heap_array.into_iter().collect::<Vec<_>>(), vec![0, 1, 2]);

        let vec = vec![0, 1];
        let heap_array = HeapArray::<_, U3>::try_from(vec);
        assert!(heap_array.is_err());
    }

    #[test]
    fn test_heap_array_from_iter() {
        // Exercise `FromIterator` itself (this test previously called `from_fn`).
        let heap_array: HeapArray<_, U3> = (0..3).collect();
        assert_eq!(heap_array.into_iter().collect::<Vec<_>>(), vec![0, 1, 2]);
    }

    #[test]
    #[should_panic]
    fn test_heap_array_from_iter_wrong_size() {
        HeapArray::<_, U2>::from_iter(0..3);
    }

    #[test]
    fn test_heap_array_deserialize() {
        let array = HeapArray::<usize, U6>::from_fn(|i| i);
        let serialized = bincode::serialize(&array).unwrap();
        bincode::deserialize::<HeapArray<usize, U6>>(&serialized).unwrap();

        let wrong_deserialize = bincode::deserialize::<HeapArray<usize, U3>>(&serialized);
        assert!(wrong_deserialize.is_err());
    }

    #[test]
    fn test_heap_array_split() {
        let array = HeapArray::<_, U6>::from_fn(|i| i);
        let (first, second) = array.split::<U4, U2>();
        assert_eq!(first.into_iter().collect::<Vec<_>>(), vec![0, 1, 2, 3]);
        assert_eq!(second.into_iter().collect::<Vec<_>>(), vec![4, 5]);

        let (first, second) = array.split_halves::<U3>();
        assert_eq!(first.into_iter().collect::<Vec<_>>(), vec![0, 1, 2]);
        assert_eq!(second.into_iter().collect::<Vec<_>>(), vec![3, 4, 5]);

        // let (first, second) = array.split::<U4, U1>(); --> doesn't compile

        // let (first, second) = array.split_halves::<U2>(); --> doesn't compile
    }

    #[test]
    fn test_heap_array_split3() {
        let array = HeapArray::<_, U6>::from_fn(|i| i);
        let (first, second, third) = array.split3::<U3, U2, U1>();
        assert_eq!(first.into_iter().collect::<Vec<_>>(), vec![0, 1, 2]);
        assert_eq!(second.into_iter().collect::<Vec<_>>(), vec![3, 4]);
        assert_eq!(third.into_iter().collect::<Vec<_>>(), vec![5]);

        let (first, second, third) = array.split_thirds::<U2>();
        assert_eq!(first.into_iter().collect::<Vec<_>>(), vec![0, 1]);
        assert_eq!(second.into_iter().collect::<Vec<_>>(), vec![2, 3]);
        assert_eq!(third.into_iter().collect::<Vec<_>>(), vec![4, 5]);

        // let (a1, a2, a3) = array.split3::<U3, U2, U2>(); // doesn't compile

        // let (a1, a2, a3) = array.split_thirds::<U2>(); // doesn't compile
    }

    #[test]
    fn test_heap_array_merge() {
        let array = HeapArray::<_, U6>::from_fn(|i| i);

        let (first, second) = array.split_halves::<U3>();
        let merged = HeapArray::merge_halves(first, second);
        assert_eq!(merged.into_iter().collect::<Vec<_>>(), vec![0, 1, 2, 3, 4, 5]);

        let (first, second, third) = array.split_thirds::<U2>();
        let merged = HeapArray::merge_thirds(first, second, third);
        assert_eq!(merged.into_iter().collect::<Vec<_>>(), vec![0, 1, 2, 3, 4, 5]);
    }

    #[test]
    fn test_heap_array_display() {
        let array = HeapArray::<_, U3>::from_fn(|i| i);
        assert_eq!(array.to_string(), "[0, 1, 2]");
    }
}