primitives/types/heap_array/
array.rs

1use std::{
2    fmt::{Debug, Display},
3    marker::PhantomData,
4    ops::{Add, Mul, Sub},
5    vec,
6};
7
8use bytemuck::{box_bytes_of, from_box_bytes, BoxBytes, Pod};
9use derive_more::derive::{AsMut, AsRef, Deref, DerefMut, IntoIterator};
10use hybrid_array::{Array, ArraySize};
11use rayon::iter::IntoParallelIterator;
12use serde::{Deserialize, Serialize};
13use typenum::{Diff, Prod, Sum, U1, U2, U3};
14
15use crate::{errors::PrimitiveError, types::Positive};
16
/// An array on the heap that encodes its length in the type system.
///
/// The element count is carried by the type parameter `M` rather than checked at every use.
/// The checked constructors in this module (`TryFrom<Vec<T>>`, `Deserialize`, `FromIterator`)
/// reject (or, for `FromIterator`, panic on) a buffer whose length differs from `M::USIZE`.
/// The `pub(super)` constructor `new` performs no such check, so in-module callers are
/// responsible for handing it a correctly sized buffer.
///
/// Dereferences to the underlying `Box<[T]>` (and so, through auto-deref, to `[T]`), which is
/// how slice methods such as `len` and `split_at` become available on it.
// NOTE(review): `Eq` is derived here but `PartialEq` is not in the derive list; presumably it is
// implemented by hand elsewhere in the crate — confirm.
#[derive(Deref, DerefMut, Clone, IntoIterator, AsRef, AsMut, Eq)]
#[into_iterator(owned, ref, ref_mut)]
pub struct HeapArray<T: Sized, M: Positive> {
    /// The elements, stored in a single heap allocation.
    #[deref]
    #[deref_mut]
    #[as_ref(forward)]
    #[as_mut(forward)]
    pub(super) data: Box<[T]>,
    #[into_iterator(ignore)]
    // `fn() -> M` is used instead of `M` so `HeapArray<T, M>` doesn't need `M` to implement `Send
    // + Sync` to be `Send + Sync` itself. This would be the case if `M` was used directly.
    pub(super) _len: PhantomData<fn() -> M>,
}
31
32impl<T: Sized, M: Positive> HeapArray<T, M> {
33    pub(super) fn new(data: Box<[T]>) -> Self {
34        Self {
35            data,
36            _len: PhantomData,
37        }
38    }
39}
40
41// bytemuck::BoxBytes transformation for copy-less casting
42impl<T: Pod, M: Positive> HeapArray<T, M> {
43    pub fn into_box_bytes(self) -> BoxBytes {
44        box_bytes_of(self.data)
45    }
46
47    pub fn from_box_bytes(buf: BoxBytes) -> Self {
48        Self {
49            data: from_box_bytes(buf),
50            _len: PhantomData,
51        }
52    }
53}
54
55impl<T: Sized, M: Positive> HeapArray<T, M> {
56    pub fn map<F, U>(self, f: F) -> HeapArray<U, M>
57    where
58        F: FnMut(T) -> U,
59    {
60        self.into_iter().map(f).collect()
61    }
62}
63
64impl<T: Sized + Default, M: Positive> HeapArray<T, M> {
65    pub fn from_single_value(val: T) -> Self {
66        Self {
67            data: vec![val].into_boxed_slice(),
68            _len: PhantomData,
69        }
70    }
71}
72
73impl<T: Sized + Default, M: Positive> Default for HeapArray<T, M> {
74    fn default() -> Self {
75        Self {
76            data: (0..M::USIZE).map(|_| T::default()).collect(),
77            _len: PhantomData,
78        }
79    }
80}
81
82impl<T: Sized + Debug, M: Positive> Debug for HeapArray<T, M> {
83    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
84        f.debug_struct(format!("HeapArray<{}>", M::USIZE).as_str())
85            .field("data", &self.data)
86            .finish()
87    }
88}
89
90impl<T: Sized + Serialize, M: Positive> Serialize for HeapArray<T, M> {
91    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
92        self.data.serialize(serializer)
93    }
94}
95
96impl<'de, T: Sized + Deserialize<'de>, M: Positive> Deserialize<'de> for HeapArray<T, M> {
97    fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
98        let data = Box::<[T]>::deserialize(deserializer)?;
99
100        if data.len() != M::USIZE {
101            return Err(serde::de::Error::custom(format!(
102                "Expected array of length {}, got {}",
103                M::USIZE,
104                data.len()
105            )));
106        }
107
108        Ok(Self {
109            data,
110            _len: PhantomData,
111        })
112    }
113}
114
115impl<T: Sized, M: Positive> FromIterator<T> for HeapArray<T, M> {
116    fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
117        let data = iter.into_iter().collect::<Box<_>>();
118        assert_eq!(data.len(), M::USIZE,);
119        Self {
120            data,
121            _len: PhantomData,
122        }
123    }
124}
125
126impl<T: Sized + Send, M: Positive> IntoParallelIterator for HeapArray<T, M> {
127    type Item = T;
128    type Iter = rayon::vec::IntoIter<T>;
129
130    fn into_par_iter(self) -> Self::Iter {
131        self.data.into_par_iter()
132    }
133}
134
135// -----------------------
136// |   Split and Merge   |
137// -----------------------
138
139impl<T: Sized + Copy, M: Positive> HeapArray<T, M> {
140    pub fn split<M1, M2>(&self) -> (HeapArray<T, M1>, HeapArray<T, M2>)
141    where
142        M1: Positive,
143        M2: Positive + Add<M1, Output = M>,
144    {
145        let (m1, m2) = self.split_at(M1::USIZE);
146        (
147            HeapArray::<T, M1> {
148                data: m1.into(),
149                _len: PhantomData,
150            },
151            HeapArray::<T, M2> {
152                data: m2.into(),
153                _len: PhantomData,
154            },
155        )
156    }
157
158    pub fn split_halves<MDiv2>(&self) -> (HeapArray<T, MDiv2>, HeapArray<T, MDiv2>)
159    where
160        MDiv2: Positive + Mul<U2, Output = M>,
161    {
162        let (m1, m2) = self.split_at(MDiv2::USIZE);
163        (
164            HeapArray::<T, MDiv2> {
165                data: m1.into(),
166                _len: PhantomData,
167            },
168            HeapArray::<T, MDiv2> {
169                data: m2.into(),
170                _len: PhantomData,
171            },
172        )
173    }
174
175    pub fn merge_halves(this: Self, other: Self) -> HeapArray<T, Prod<M, U2>>
176    where
177        M: Mul<U2, Output: Positive>,
178    {
179        let mut vec = this.data.into_vec();
180        vec.extend(other.data.into_vec());
181        HeapArray::<T, Prod<M, U2>> {
182            data: vec.into_boxed_slice(),
183            _len: PhantomData,
184        }
185    }
186
187    pub fn split3<M1, M2, M3>(&self) -> (HeapArray<T, M1>, HeapArray<T, M2>, HeapArray<T, M3>)
188    where
189        M1: Positive,
190        M2: Positive + Add<M1>,
191        M3: Positive + Add<Sum<M2, M1>, Output = M>,
192    {
193        let (m1, m_rest) = self.split_at(M1::USIZE);
194        let (m2, m3) = m_rest.split_at(M2::USIZE);
195        (
196            HeapArray::<T, M1> {
197                data: m1.into(),
198                _len: PhantomData,
199            },
200            HeapArray::<T, M2> {
201                data: m2.into(),
202                _len: PhantomData,
203            },
204            HeapArray::<T, M3> {
205                data: m3.into(),
206                _len: PhantomData,
207            },
208        )
209    }
210
211    pub fn split_thirds<MDiv3>(
212        &self,
213    ) -> (
214        HeapArray<T, MDiv3>,
215        HeapArray<T, MDiv3>,
216        HeapArray<T, MDiv3>,
217    )
218    where
219        MDiv3: Positive + Mul<U3, Output = M>,
220    {
221        let (m1, m_rest) = self.split_at(MDiv3::USIZE);
222        let (m2, m3) = m_rest.split_at(MDiv3::USIZE);
223        (
224            HeapArray::<T, MDiv3> {
225                data: m1.into(),
226                _len: PhantomData,
227            },
228            HeapArray::<T, MDiv3> {
229                data: m2.into(),
230                _len: PhantomData,
231            },
232            HeapArray::<T, MDiv3> {
233                data: m3.into(),
234                _len: PhantomData,
235            },
236        )
237    }
238
239    pub fn merge_thirds(first: Self, second: Self, third: Self) -> HeapArray<T, Prod<M, U3>>
240    where
241        M: Mul<U3, Output: Positive>,
242    {
243        let mut vec = first.data.into_vec();
244        vec.extend(second.data.into_vec());
245        vec.extend(third.data.into_vec());
246        HeapArray::<T, Prod<M, U3>> {
247            data: vec.into_boxed_slice(),
248            _len: PhantomData,
249        }
250    }
251}
252
253pub struct HeapArrayTuple<T1: Sized, T2: Sized, M: Positive>(
254    pub HeapArray<T1, M>,
255    pub HeapArray<T2, M>,
256);
257
258impl<T1: Sized, T2: Sized, M: Positive> FromIterator<(T1, T2)> for HeapArrayTuple<T1, T2, M> {
259    fn from_iter<I: IntoIterator<Item = (T1, T2)>>(iter: I) -> Self {
260        let (data1, data2): (Vec<_>, Vec<_>) = iter.into_iter().unzip();
261
262        assert_eq!(data1.len(), M::USIZE);
263        assert_eq!(data2.len(), M::USIZE);
264        HeapArrayTuple(
265            HeapArray::<T1, M> {
266                data: data1.into_boxed_slice(),
267                _len: PhantomData,
268            },
269            HeapArray::<T2, M> {
270                data: data2.into_boxed_slice(),
271                _len: PhantomData,
272            },
273        )
274    }
275}
276
277impl<T: Sized, M: Positive> TryFrom<Vec<T>> for HeapArray<T, M> {
278    type Error = PrimitiveError;
279
280    fn try_from(data: Vec<T>) -> Result<Self, PrimitiveError> {
281        if data.len() == M::USIZE {
282            Ok(Self {
283                data: data.into_boxed_slice(),
284                _len: PhantomData,
285            })
286        } else {
287            Err(PrimitiveError::InvalidSize(M::USIZE, data.len()))
288        }
289    }
290}
291
292impl<T: Sized> From<T> for HeapArray<T, U1> {
293    fn from(element: T) -> Self {
294        Self {
295            data: Box::new([element]),
296            _len: PhantomData,
297        }
298    }
299}
300
301impl<T: Sized, M: Positive> From<HeapArray<T, M>> for Vec<T> {
302    fn from(array: HeapArray<T, M>) -> Self {
303        array.data.into_vec()
304    }
305}
306
307impl<T: Sized + Clone, M: Positive + ArraySize> From<Array<T, M>> for HeapArray<T, M> {
308    fn from(array: Array<T, M>) -> Self {
309        Self {
310            data: array.to_vec().into_boxed_slice(),
311            _len: PhantomData,
312        }
313    }
314}
315
316impl<T: Sized, M: Positive> HeapArray<T, M> {
317    pub fn split_last<N: Positive>(self) -> (HeapArray<T, Diff<M, N>>, HeapArray<T, N>)
318    where
319        M: Sub<N, Output: Positive>,
320    {
321        let mut vec = self.data.into_vec();
322        let last_n = vec.split_off(M::USIZE - N::USIZE);
323
324        (
325            HeapArray {
326                data: vec.into_boxed_slice(),
327                _len: PhantomData,
328            },
329            HeapArray {
330                data: last_n.into_boxed_slice(),
331                _len: PhantomData,
332            },
333        )
334    }
335
336    pub fn from_fn(f: impl FnMut(usize) -> T) -> Self {
337        Self {
338            data: (0..M::USIZE).map(f).collect::<Box<_>>(),
339            _len: PhantomData,
340        }
341    }
342
343    pub fn from_constant(c: T) -> Self
344    where
345        T: Copy,
346    {
347        Self {
348            data: (0..M::USIZE).map(|_| c).collect::<Box<_>>(),
349            _len: PhantomData,
350        }
351    }
352}
353
354impl<T: Sized, M: Positive> Display for HeapArray<T, M>
355where
356    T: Display,
357{
358    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
359        write!(f, "[")?;
360        for (i, item) in self.data.iter().enumerate() {
361            if i != 0 {
362                write!(f, ", ")?;
363            }
364            write!(f, "{item}")?;
365        }
366        write!(f, "]")
367    }
368}
369
#[cfg(test)]
pub mod tests {
    use hybrid_array::sizes::{U2, U3, U6};
    use typenum::{U1, U4};

    use super::*;

    #[test]
    fn test_heap_array() {
        let array = HeapArray::<_, U3>::from_fn(|i| i);
        assert_eq!(array.into_iter().collect::<Vec<_>>(), vec![0, 1, 2]);
    }

    #[test]
    fn test_default() {
        let array = HeapArray::<usize, U3>::default();
        assert_eq!(array.len(), 3);
    }

    #[test]
    fn test_heap_array_split_last() {
        let array = HeapArray::<_, U6>::from_fn(|i| i);
        let (first, last) = array.split_last::<U2>();
        assert_eq!(first.into_iter().collect::<Vec<_>>(), vec![0, 1, 2, 3]);
        assert_eq!(last.into_iter().collect::<Vec<_>>(), vec![4, 5]);
    }

    #[test]
    fn test_heap_array_from_array() {
        let array = Array::<_, U3>::from_fn(|i| i);
        let heap_array = HeapArray::<_, U3>::from(array);
        assert_eq!(heap_array.into_iter().collect::<Vec<_>>(), vec![0, 1, 2]);
    }

    #[test]
    fn test_heap_array_from_vec() {
        let vec = vec![0, 1, 2];
        let heap_array = HeapArray::<_, U3>::try_from(vec).unwrap();
        assert_eq!(heap_array.into_iter().collect::<Vec<_>>(), vec![0, 1, 2]);

        let vec = vec![0, 1];
        let heap_array = HeapArray::<_, U3>::try_from(vec);
        assert!(heap_array.is_err());
    }

    #[test]
    fn test_heap_array_from_single_element() {
        let heap_array = HeapArray::<u8, U1>::from(5);
        assert_eq!(heap_array.into_iter().collect::<Vec<_>>(), vec![5]);
    }

    /// Exercises `FromIterator` itself, both via an explicit call and via `collect`.
    #[test]
    fn test_heap_array_from_iter() {
        let heap_array = HeapArray::<_, U3>::from_iter(0..3);
        assert_eq!(heap_array.into_iter().collect::<Vec<_>>(), vec![0, 1, 2]);

        let collected: HeapArray<_, U3> = (10..13).collect();
        assert_eq!(collected.into_iter().collect::<Vec<_>>(), vec![10, 11, 12]);
    }

    #[test]
    #[should_panic]
    fn test_heap_array_from_iter_wrong_size() {
        HeapArray::<_, U2>::from_iter(0..3);
    }

    #[test]
    fn test_heap_array_from_constant() {
        let heap_array = HeapArray::<u8, U4>::from_constant(7);
        assert_eq!(heap_array.into_iter().collect::<Vec<_>>(), vec![7, 7, 7, 7]);
    }

    #[test]
    fn test_heap_array_deserialize() {
        let array = HeapArray::<usize, U6>::from_fn(|i| i);
        let serialized = bincode::serialize(&array).unwrap();
        bincode::deserialize::<HeapArray<usize, U6>>(&serialized).unwrap();

        let wrong_deserialize = bincode::deserialize::<HeapArray<usize, U3>>(&serialized);
        assert!(wrong_deserialize.is_err());
    }

    #[test]
    fn test_heap_array_box_bytes_roundtrip() {
        let array = HeapArray::<usize, U4>::from_fn(|i| i);
        let bytes = array.into_box_bytes();
        let restored = HeapArray::<usize, U4>::from_box_bytes(bytes);
        assert_eq!(restored.into_iter().collect::<Vec<_>>(), vec![0, 1, 2, 3]);
    }

    #[test]
    fn test_heap_array_display() {
        assert_eq!(HeapArray::<_, U3>::from_fn(|i| i).to_string(), "[0, 1, 2]");
        assert_eq!(HeapArray::<u8, U1>::from(5).to_string(), "[5]");
    }

    #[test]
    fn test_heap_array_split() {
        let array = HeapArray::<_, U6>::from_fn(|i| i);
        let (first, second) = array.split::<U4, U2>();
        assert_eq!(first.into_iter().collect::<Vec<_>>(), vec![0, 1, 2, 3]);
        assert_eq!(second.into_iter().collect::<Vec<_>>(), vec![4, 5]);

        let (first, second) = array.split_halves::<U3>();
        assert_eq!(first.into_iter().collect::<Vec<_>>(), vec![0, 1, 2]);
        assert_eq!(second.into_iter().collect::<Vec<_>>(), vec![3, 4, 5]);

        // let (first, second) = array.split::<U4, U1>(); --> doesn't compile

        // let (first, second) = array.split_halves::<U2>(); --> doesn't compile
    }

    #[test]
    fn test_heap_array_split3() {
        let array = HeapArray::<_, U6>::from_fn(|i| i);
        let (first, second, third) = array.split3::<U3, U2, U1>();
        assert_eq!(first.into_iter().collect::<Vec<_>>(), vec![0, 1, 2]);
        assert_eq!(second.into_iter().collect::<Vec<_>>(), vec![3, 4]);
        assert_eq!(third.into_iter().collect::<Vec<_>>(), vec![5]);

        let (first, second, third) = array.split_thirds::<U2>();
        assert_eq!(first.into_iter().collect::<Vec<_>>(), vec![0, 1]);
        assert_eq!(second.into_iter().collect::<Vec<_>>(), vec![2, 3]);
        assert_eq!(third.into_iter().collect::<Vec<_>>(), vec![4, 5]);

        // let (a1, a2, a3) = array.split3::<U3, U2, U2>(); // doesn't compile

        // let (a1, a2, a3) = array.split_thirds::<U2>(); // doesn't compile
    }

    /// Splitting and merging back must reproduce the original order.
    #[test]
    fn test_heap_array_merge() {
        let array = HeapArray::<usize, U6>::from_fn(|i| i);

        let (first, second) = array.split_halves::<U3>();
        let merged = HeapArray::merge_halves(first, second);
        assert_eq!(merged.into_iter().collect::<Vec<_>>(), vec![0, 1, 2, 3, 4, 5]);

        let (first, second, third) = array.split_thirds::<U2>();
        let merged = HeapArray::merge_thirds(first, second, third);
        assert_eq!(merged.into_iter().collect::<Vec<_>>(), vec![0, 1, 2, 3, 4, 5]);
    }
}