primitives/types/heap_array/
array.rs

1use std::{
2    fmt::{Debug, Display},
3    marker::PhantomData,
4    ops::{Add, Mul, Sub},
5    vec,
6};
7
8use bytemuck::{box_bytes_of, from_box_bytes, BoxBytes, Pod};
9use derive_more::derive::{AsMut, AsRef, Deref, DerefMut, IntoIterator};
10use hybrid_array::{Array, ArraySize};
11use rayon::iter::IntoParallelIterator;
12use serde::{Deserialize, Serialize};
13use typenum::{Diff, Prod, Sum, U1, U2, U3};
14
15use crate::{errors::PrimitiveError, types::Positive};
16
/// An array on the heap that encodes its length in the type system.
///
/// `M` is a type-level length and the intended invariant is
/// `data.len() == M::USIZE`. The checked conversions `TryFrom<Vec<T>>`,
/// `FromIterator` and `Deserialize` enforce it at runtime, and `from_fn` /
/// `Default` build exactly `M::USIZE` elements.
///
/// Derefs to `[T]`, so slice methods (`len`, `iter`, indexing, ...) are available
/// directly. `IntoIterator` is derived for owned, `&` and `&mut` receivers and
/// iterates over the elements of `data`.
// NOTE(review): `Eq` is derived here without `PartialEq`; a manual `PartialEq`
// impl presumably exists elsewhere in the crate — confirm.
#[derive(Deref, DerefMut, Clone, IntoIterator, AsRef, AsMut, Eq)]
#[into_iterator(owned, ref, ref_mut)]
pub struct HeapArray<T: Sized, M: Positive> {
    // Element storage; `Deref`/`DerefMut`/`AsRef`/`AsMut` all target this field.
    #[deref]
    #[deref_mut]
    #[as_ref(forward)]
    #[as_mut(forward)]
    pub(super) data: Box<[T]>,
    #[into_iterator(ignore)]
    // `fn() -> M` is used instead of `M` so `HeapArray<T, M>` doesn't need `M` to implement `Send
    // + Sync` to be `Send + Sync` itself. This would be the case if `M` was used directly.
    pub(super) _len: PhantomData<fn() -> M>,
}
31
32// bytemuck::BoxBytes transformation for copy-less casting
33impl<T: Pod, M: Positive> HeapArray<T, M> {
34    pub fn into_box_bytes(self) -> BoxBytes {
35        box_bytes_of(self.data)
36    }
37
38    pub fn from_box_bytes(buf: BoxBytes) -> Self {
39        Self {
40            data: from_box_bytes(buf),
41            _len: PhantomData,
42        }
43    }
44}
45
46impl<T: Sized + Default, M: Positive> HeapArray<T, M> {
47    pub fn from_single_value(val: T) -> Self {
48        Self {
49            data: vec![val].into_boxed_slice(),
50            _len: PhantomData,
51        }
52    }
53}
54
55impl<T: Sized + Default, M: Positive> Default for HeapArray<T, M> {
56    fn default() -> Self {
57        Self {
58            data: (0..M::USIZE).map(|_| T::default()).collect(),
59            _len: PhantomData,
60        }
61    }
62}
63
64impl<T: Sized + Debug, M: Positive> Debug for HeapArray<T, M> {
65    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
66        f.debug_struct(format!("HeapArray<{}>", M::USIZE).as_str())
67            .field("data", &self.data)
68            .finish()
69    }
70}
71
72impl<T: Sized + Serialize, M: Positive> Serialize for HeapArray<T, M> {
73    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
74        self.data.serialize(serializer)
75    }
76}
77
78impl<'de, T: Sized + Deserialize<'de>, M: Positive> Deserialize<'de> for HeapArray<T, M> {
79    fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
80        let data = Box::<[T]>::deserialize(deserializer)?;
81
82        if data.len() != M::USIZE {
83            return Err(serde::de::Error::custom(format!(
84                "Expected array of length {}, got {}",
85                M::USIZE,
86                data.len()
87            )));
88        }
89
90        Ok(Self {
91            data,
92            _len: PhantomData,
93        })
94    }
95}
96
97impl<T: Sized, M: Positive> FromIterator<T> for HeapArray<T, M> {
98    fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
99        let data = iter.into_iter().collect::<Box<_>>();
100        assert_eq!(data.len(), M::USIZE,);
101        Self {
102            data,
103            _len: PhantomData,
104        }
105    }
106}
107
108impl<T: Sized + Send, M: Positive> IntoParallelIterator for HeapArray<T, M> {
109    type Item = T;
110    type Iter = rayon::vec::IntoIter<T>;
111
112    fn into_par_iter(self) -> Self::Iter {
113        self.data.into_par_iter()
114    }
115}
116
117// -----------------------
118// |   Split and Merge   |
119// -----------------------
120
121impl<T: Sized + Copy, M: Positive> HeapArray<T, M> {
122    pub fn split<M1, M2>(&self) -> (HeapArray<T, M1>, HeapArray<T, M2>)
123    where
124        M1: Positive,
125        M2: Positive + Add<M1, Output = M>,
126    {
127        let (m1, m2) = self.split_at(M1::USIZE);
128        (
129            HeapArray::<T, M1> {
130                data: m1.into(),
131                _len: PhantomData,
132            },
133            HeapArray::<T, M2> {
134                data: m2.into(),
135                _len: PhantomData,
136            },
137        )
138    }
139
140    pub fn split_halves<MDiv2>(&self) -> (HeapArray<T, MDiv2>, HeapArray<T, MDiv2>)
141    where
142        MDiv2: Positive + Mul<U2, Output = M>,
143    {
144        let (m1, m2) = self.split_at(MDiv2::USIZE);
145        (
146            HeapArray::<T, MDiv2> {
147                data: m1.into(),
148                _len: PhantomData,
149            },
150            HeapArray::<T, MDiv2> {
151                data: m2.into(),
152                _len: PhantomData,
153            },
154        )
155    }
156
157    pub fn merge_halves(this: Self, other: Self) -> HeapArray<T, Prod<M, U2>>
158    where
159        M: Mul<U2, Output: Positive>,
160    {
161        let mut vec = this.data.into_vec();
162        vec.extend(other.data.into_vec());
163        HeapArray::<T, Prod<M, U2>> {
164            data: vec.into_boxed_slice(),
165            _len: PhantomData,
166        }
167    }
168
169    pub fn split3<M1, M2, M3>(&self) -> (HeapArray<T, M1>, HeapArray<T, M2>, HeapArray<T, M3>)
170    where
171        M1: Positive,
172        M2: Positive + Add<M1>,
173        M3: Positive + Add<Sum<M2, M1>, Output = M>,
174    {
175        let (m1, m_rest) = self.split_at(M1::USIZE);
176        let (m2, m3) = m_rest.split_at(M2::USIZE);
177        (
178            HeapArray::<T, M1> {
179                data: m1.into(),
180                _len: PhantomData,
181            },
182            HeapArray::<T, M2> {
183                data: m2.into(),
184                _len: PhantomData,
185            },
186            HeapArray::<T, M3> {
187                data: m3.into(),
188                _len: PhantomData,
189            },
190        )
191    }
192
193    pub fn split_thirds<MDiv3>(
194        &self,
195    ) -> (
196        HeapArray<T, MDiv3>,
197        HeapArray<T, MDiv3>,
198        HeapArray<T, MDiv3>,
199    )
200    where
201        MDiv3: Positive + Mul<U3, Output = M>,
202    {
203        let (m1, m_rest) = self.split_at(MDiv3::USIZE);
204        let (m2, m3) = m_rest.split_at(MDiv3::USIZE);
205        (
206            HeapArray::<T, MDiv3> {
207                data: m1.into(),
208                _len: PhantomData,
209            },
210            HeapArray::<T, MDiv3> {
211                data: m2.into(),
212                _len: PhantomData,
213            },
214            HeapArray::<T, MDiv3> {
215                data: m3.into(),
216                _len: PhantomData,
217            },
218        )
219    }
220
221    pub fn merge_thirds(first: Self, second: Self, third: Self) -> HeapArray<T, Prod<M, U3>>
222    where
223        M: Mul<U3, Output: Positive>,
224    {
225        let mut vec = first.data.into_vec();
226        vec.extend(second.data.into_vec());
227        vec.extend(third.data.into_vec());
228        HeapArray::<T, Prod<M, U3>> {
229            data: vec.into_boxed_slice(),
230            _len: PhantomData,
231        }
232    }
233}
234
/// A pair of heap arrays sharing the same type-level length `M`.
///
/// Built by collecting an iterator of `(T1, T2)` pairs (see its `FromIterator`
/// impl), which unzips the pairs into the two arrays.
pub struct HeapArrayTuple<T1: Sized, T2: Sized, M: Positive>(
    pub HeapArray<T1, M>,
    pub HeapArray<T2, M>,
);
239
240impl<T1: Sized, T2: Sized, M: Positive> FromIterator<(T1, T2)> for HeapArrayTuple<T1, T2, M> {
241    fn from_iter<I: IntoIterator<Item = (T1, T2)>>(iter: I) -> Self {
242        let (data1, data2): (Vec<_>, Vec<_>) = iter.into_iter().unzip();
243
244        assert_eq!(data1.len(), M::USIZE);
245        assert_eq!(data2.len(), M::USIZE);
246        HeapArrayTuple(
247            HeapArray::<T1, M> {
248                data: data1.into_boxed_slice(),
249                _len: PhantomData,
250            },
251            HeapArray::<T2, M> {
252                data: data2.into_boxed_slice(),
253                _len: PhantomData,
254            },
255        )
256    }
257}
258
259impl<T: Sized, M: Positive> TryFrom<Vec<T>> for HeapArray<T, M> {
260    type Error = PrimitiveError;
261
262    fn try_from(data: Vec<T>) -> Result<Self, PrimitiveError> {
263        if data.len() == M::USIZE {
264            Ok(Self {
265                data: data.into_boxed_slice(),
266                _len: PhantomData,
267            })
268        } else {
269            Err(PrimitiveError::SizeError(M::USIZE, data.len()))
270        }
271    }
272}
273
274impl<T: Sized> From<T> for HeapArray<T, U1> {
275    fn from(element: T) -> Self {
276        Self {
277            data: Box::new([element]),
278            _len: PhantomData,
279        }
280    }
281}
282
283impl<T: Sized, M: Positive> From<HeapArray<T, M>> for Vec<T> {
284    fn from(array: HeapArray<T, M>) -> Self {
285        array.data.into_vec()
286    }
287}
288
289impl<T: Sized + Clone, M: Positive + ArraySize> From<Array<T, M>> for HeapArray<T, M> {
290    fn from(array: Array<T, M>) -> Self {
291        Self {
292            data: array.to_vec().into_boxed_slice(),
293            _len: PhantomData,
294        }
295    }
296}
297
298impl<T: Sized, M: Positive> HeapArray<T, M> {
299    pub fn split_last<N: Positive>(self) -> (HeapArray<T, Diff<M, N>>, HeapArray<T, N>)
300    where
301        M: Sub<N, Output: Positive>,
302    {
303        let mut vec = self.data.into_vec();
304        let last_n = vec.split_off(M::USIZE - N::USIZE);
305
306        (
307            HeapArray {
308                data: vec.into_boxed_slice(),
309                _len: PhantomData,
310            },
311            HeapArray {
312                data: last_n.into_boxed_slice(),
313                _len: PhantomData,
314            },
315        )
316    }
317
318    pub fn from_fn(f: impl FnMut(usize) -> T) -> Self {
319        Self {
320            data: (0..M::USIZE).map(f).collect::<Box<_>>(),
321            _len: PhantomData,
322        }
323    }
324
325    pub fn from_constant(c: T) -> Self
326    where
327        T: Copy,
328    {
329        Self {
330            data: (0..M::USIZE).map(|_| c).collect::<Box<_>>(),
331            _len: PhantomData,
332        }
333    }
334}
335
336impl<T: Sized, M: Positive> Display for HeapArray<T, M>
337where
338    T: Display,
339{
340    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
341        write!(f, "[")?;
342        for (i, item) in self.data.iter().enumerate() {
343            if i != 0 {
344                write!(f, ", ")?;
345            }
346            write!(f, "{item}")?;
347        }
348        write!(f, "]")
349    }
350}
351
#[cfg(test)]
pub mod tests {
    use hybrid_array::sizes::{U2, U3, U6};
    use typenum::{U1, U4};

    use super::*;

    #[test]
    fn test_heap_array() {
        let array = HeapArray::<_, U3>::from_fn(|i| i);
        assert_eq!(array.into_iter().collect::<Vec<_>>(), vec![0, 1, 2]);
    }

    #[test]
    fn test_default() {
        let array = HeapArray::<usize, U3>::default();
        assert_eq!(array.len(), 3);
        assert!(array.iter().all(|&x| x == 0));
    }

    #[test]
    fn test_heap_array_split_last() {
        let array = HeapArray::<_, U6>::from_fn(|i| i);
        let (first, last) = array.split_last::<U2>();
        assert_eq!(first.into_iter().collect::<Vec<_>>(), vec![0, 1, 2, 3]);
        assert_eq!(last.into_iter().collect::<Vec<_>>(), vec![4, 5]);
    }

    #[test]
    fn test_heap_array_from_array() {
        let array = Array::<_, U3>::from_fn(|i| i);
        let heap_array = HeapArray::<_, U3>::from(array);
        assert_eq!(heap_array.into_iter().collect::<Vec<_>>(), vec![0, 1, 2]);
    }

    #[test]
    fn test_heap_array_from_vec() {
        let vec = vec![0, 1, 2];
        let heap_array = HeapArray::<_, U3>::try_from(vec).unwrap();
        assert_eq!(heap_array.into_iter().collect::<Vec<_>>(), vec![0, 1, 2]);

        let vec = vec![0, 1];
        let heap_array = HeapArray::<_, U3>::try_from(vec);
        assert!(heap_array.is_err());
    }

    #[test]
    fn test_heap_array_from_iter() {
        // Exercise the `FromIterator` impl itself (via `collect`), not `from_fn`.
        let heap_array: HeapArray<_, U3> = (0..3).collect();
        assert_eq!(heap_array.into_iter().collect::<Vec<_>>(), vec![0, 1, 2]);
    }

    #[test]
    #[should_panic]
    fn test_heap_array_from_iter_wrong_size() {
        HeapArray::<_, U2>::from_iter(0..3);
    }

    #[test]
    fn test_heap_array_deserialize() {
        let array = HeapArray::<usize, U6>::from_fn(|i| i);
        let serialized = bincode::serialize(&array).unwrap();
        let roundtrip = bincode::deserialize::<HeapArray<usize, U6>>(&serialized).unwrap();
        assert_eq!(roundtrip.to_vec(), array.to_vec());

        let wrong_deserialize = bincode::deserialize::<HeapArray<usize, U3>>(&serialized);
        assert!(wrong_deserialize.is_err());
    }

    #[test]
    fn test_heap_array_split() {
        let array = HeapArray::<_, U6>::from_fn(|i| i);
        let (first, second) = array.split::<U4, U2>();
        assert_eq!(first.into_iter().collect::<Vec<_>>(), vec![0, 1, 2, 3]);
        assert_eq!(second.into_iter().collect::<Vec<_>>(), vec![4, 5]);

        let (first, second) = array.split_halves::<U3>();
        assert_eq!(first.into_iter().collect::<Vec<_>>(), vec![0, 1, 2]);
        assert_eq!(second.into_iter().collect::<Vec<_>>(), vec![3, 4, 5]);

        // Rejected at compile time (lengths don't add up to U6):
        //   array.split::<U4, U1>();
        //   array.split_halves::<U2>();
    }

    #[test]
    fn test_heap_array_split3() {
        let array = HeapArray::<_, U6>::from_fn(|i| i);
        let (first, second, third) = array.split3::<U3, U2, U1>();
        assert_eq!(first.into_iter().collect::<Vec<_>>(), vec![0, 1, 2]);
        assert_eq!(second.into_iter().collect::<Vec<_>>(), vec![3, 4]);
        assert_eq!(third.into_iter().collect::<Vec<_>>(), vec![5]);

        let (first, second, third) = array.split_thirds::<U2>();
        assert_eq!(first.into_iter().collect::<Vec<_>>(), vec![0, 1]);
        assert_eq!(second.into_iter().collect::<Vec<_>>(), vec![2, 3]);
        assert_eq!(third.into_iter().collect::<Vec<_>>(), vec![4, 5]);

        // Rejected at compile time (lengths don't add up to U6):
        //   array.split3::<U3, U2, U2>();
        //   array.split_thirds::<U3>();
    }

    #[test]
    fn test_heap_array_merge() {
        let first = HeapArray::<_, U3>::from_fn(|i| i);
        let second = HeapArray::<_, U3>::from_fn(|i| i + 3);
        let merged = HeapArray::merge_halves(first, second);
        assert_eq!(merged.to_vec(), vec![0, 1, 2, 3, 4, 5]);

        let a = HeapArray::<_, U2>::from_fn(|i| i);
        let b = HeapArray::<_, U2>::from_fn(|i| i + 2);
        let c = HeapArray::<_, U2>::from_fn(|i| i + 4);
        let merged = HeapArray::merge_thirds(a, b, c);
        assert_eq!(merged.to_vec(), vec![0, 1, 2, 3, 4, 5]);
    }

    #[test]
    fn test_heap_array_box_bytes_roundtrip() {
        let array = HeapArray::<u32, U3>::from_fn(|i| u32::try_from(i).unwrap() * 7);
        let bytes = array.into_box_bytes();
        assert_eq!(bytes.len(), 3 * std::mem::size_of::<u32>());
        let restored = HeapArray::<u32, U3>::from_box_bytes(bytes);
        assert_eq!(restored.to_vec(), vec![0, 7, 14]);
    }

    #[test]
    fn test_heap_array_display_and_debug() {
        let array = HeapArray::<_, U3>::from_fn(|i| i);
        assert_eq!(array.to_string(), "[0, 1, 2]");

        let array = HeapArray::<_, U2>::from_fn(|i| i);
        assert_eq!(format!("{array:?}"), "HeapArray<2> { data: [0, 1] }");
    }
}