Skip to main content

primitives/types/heap_array/
array.rs

1use std::{
2    fmt::{Debug, Display},
3    marker::PhantomData,
4    ops::{Add, Mul, Sub},
5    vec,
6};
7
8use bytemuck::{box_bytes_of, from_box_bytes, BoxBytes, Pod};
9use derive_more::derive::{AsMut, AsRef, Deref, DerefMut, IntoIterator};
10use hybrid_array::{Array, ArraySize};
11use rayon::iter::IntoParallelIterator;
12use serde::{Deserialize, Serialize};
13use typenum::{Diff, Prod, Sum, U1, U2, U3};
14
15use crate::{errors::PrimitiveError, types::Positive};
16
17/// An array on the heap that encodes its length in the type system.
18#[derive(Deref, DerefMut, Clone, IntoIterator, AsRef, AsMut, Eq)]
19#[into_iterator(owned, ref, ref_mut)]
20pub struct HeapArray<T: Sized, M: Positive> {
21    #[deref]
22    #[deref_mut]
23    #[as_ref(forward)]
24    #[as_mut(forward)]
25    pub(super) data: Box<[T]>,
26    #[into_iterator(ignore)]
27    // `fn() -> M` is used instead of `M` so `HeapArray<T, M>` doesn't need `M` to implement `Send
28    // + Sync` to be `Send + Sync` itself. This would be the case if `M` was used directly.
29    pub(super) _len: PhantomData<fn() -> M>,
30}
31
32impl<T: Sized, M: Positive> HeapArray<T, M> {
33    pub(super) fn new(data: Box<[T]>) -> Self {
34        Self {
35            data,
36            _len: PhantomData,
37        }
38    }
39}
40
41// bytemuck::BoxBytes transformation for copy-less casting
42impl<T: Pod, M: Positive> HeapArray<T, M> {
43    pub fn into_box_bytes(self) -> BoxBytes {
44        box_bytes_of(self.data)
45    }
46
47    pub fn from_box_bytes(buf: BoxBytes) -> Self {
48        Self {
49            data: from_box_bytes(buf),
50            _len: PhantomData,
51        }
52    }
53}
54
55impl<T: Sized, M: Positive> HeapArray<T, M> {
56    pub fn map<F, U>(self, f: F) -> HeapArray<U, M>
57    where
58        F: FnMut(T) -> U,
59    {
60        self.into_iter().map(f).collect()
61    }
62}
63
64impl<T: Sized + Default, M: Positive> HeapArray<T, M> {
65    pub fn from_single_value(val: T) -> Self {
66        Self {
67            data: vec![val].into_boxed_slice(),
68            _len: PhantomData,
69        }
70    }
71}
72
73impl<T: Sized + Default, M: Positive> Default for HeapArray<T, M> {
74    fn default() -> Self {
75        Self {
76            data: (0..M::USIZE).map(|_| T::default()).collect(),
77            _len: PhantomData,
78        }
79    }
80}
81
82impl<T: Sized + Debug, M: Positive> Debug for HeapArray<T, M> {
83    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
84        f.debug_struct(format!("HeapArray<{}>", M::USIZE).as_str())
85            .field("data", &self.data)
86            .finish()
87    }
88}
89
90impl<T: Sized + Serialize, M: Positive> Serialize for HeapArray<T, M> {
91    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
92        use serde::ser::SerializeTuple;
93        let mut tuple = serializer.serialize_tuple(M::USIZE)?;
94        for element in self.data.iter() {
95            tuple.serialize_element(element)?;
96        }
97        tuple.end()
98    }
99}
100
101impl<'de, T: Sized + Deserialize<'de>, M: Positive> Deserialize<'de> for HeapArray<T, M> {
102    fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
103        struct HeapArrayVisitor<T, M: Positive> {
104            _phantom: PhantomData<(T, M)>,
105        }
106
107        impl<'de, T: Deserialize<'de>, M: Positive> serde::de::Visitor<'de> for HeapArrayVisitor<T, M> {
108            type Value = HeapArray<T, M>;
109
110            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
111                write!(formatter, "a tuple of {} elements", M::USIZE)
112            }
113
114            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
115            where
116                A: serde::de::SeqAccess<'de>,
117            {
118                let mut data = Vec::with_capacity(M::USIZE);
119                for i in 0..M::USIZE {
120                    let element = seq
121                        .next_element()?
122                        .ok_or_else(|| serde::de::Error::invalid_length(i, &self))?;
123                    data.push(element);
124                }
125
126                Ok(HeapArray {
127                    data: data.into_boxed_slice(),
128                    _len: PhantomData,
129                })
130            }
131        }
132
133        deserializer.deserialize_tuple(
134            M::USIZE,
135            HeapArrayVisitor {
136                _phantom: PhantomData,
137            },
138        )
139    }
140}
141
142impl<T: Sized, M: Positive> FromIterator<T> for HeapArray<T, M> {
143    fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
144        let data = iter.into_iter().collect::<Box<_>>();
145        assert_eq!(data.len(), M::USIZE,);
146        Self {
147            data,
148            _len: PhantomData,
149        }
150    }
151}
152
153impl<T: Sized + Send, M: Positive> IntoParallelIterator for HeapArray<T, M> {
154    type Item = T;
155    type Iter = rayon::vec::IntoIter<T>;
156
157    fn into_par_iter(self) -> Self::Iter {
158        self.data.into_par_iter()
159    }
160}
161
162// -----------------------
163// |   Split and Merge   |
164// -----------------------
165
166impl<T: Sized + Copy, M: Positive> HeapArray<T, M> {
167    pub fn split<M1, M2>(&self) -> (HeapArray<T, M1>, HeapArray<T, M2>)
168    where
169        M1: Positive,
170        M2: Positive + Add<M1, Output = M>,
171    {
172        let (m1, m2) = self.split_at(M1::USIZE);
173        (
174            HeapArray::<T, M1> {
175                data: m1.into(),
176                _len: PhantomData,
177            },
178            HeapArray::<T, M2> {
179                data: m2.into(),
180                _len: PhantomData,
181            },
182        )
183    }
184
185    pub fn split_halves<MDiv2>(&self) -> (HeapArray<T, MDiv2>, HeapArray<T, MDiv2>)
186    where
187        MDiv2: Positive + Mul<U2, Output = M>,
188    {
189        let (m1, m2) = self.split_at(MDiv2::USIZE);
190        (
191            HeapArray::<T, MDiv2> {
192                data: m1.into(),
193                _len: PhantomData,
194            },
195            HeapArray::<T, MDiv2> {
196                data: m2.into(),
197                _len: PhantomData,
198            },
199        )
200    }
201
202    pub fn merge_halves(this: Self, other: Self) -> HeapArray<T, Prod<M, U2>>
203    where
204        M: Mul<U2, Output: Positive>,
205    {
206        let mut vec = this.data.into_vec();
207        vec.extend(other.data.into_vec());
208        HeapArray::<T, Prod<M, U2>> {
209            data: vec.into_boxed_slice(),
210            _len: PhantomData,
211        }
212    }
213
214    pub fn split3<M1, M2, M3>(&self) -> (HeapArray<T, M1>, HeapArray<T, M2>, HeapArray<T, M3>)
215    where
216        M1: Positive,
217        M2: Positive + Add<M1>,
218        M3: Positive + Add<Sum<M2, M1>, Output = M>,
219    {
220        let (m1, m_rest) = self.split_at(M1::USIZE);
221        let (m2, m3) = m_rest.split_at(M2::USIZE);
222        (
223            HeapArray::<T, M1> {
224                data: m1.into(),
225                _len: PhantomData,
226            },
227            HeapArray::<T, M2> {
228                data: m2.into(),
229                _len: PhantomData,
230            },
231            HeapArray::<T, M3> {
232                data: m3.into(),
233                _len: PhantomData,
234            },
235        )
236    }
237
238    pub fn split_thirds<MDiv3>(
239        &self,
240    ) -> (
241        HeapArray<T, MDiv3>,
242        HeapArray<T, MDiv3>,
243        HeapArray<T, MDiv3>,
244    )
245    where
246        MDiv3: Positive + Mul<U3, Output = M>,
247    {
248        let (m1, m_rest) = self.split_at(MDiv3::USIZE);
249        let (m2, m3) = m_rest.split_at(MDiv3::USIZE);
250        (
251            HeapArray::<T, MDiv3> {
252                data: m1.into(),
253                _len: PhantomData,
254            },
255            HeapArray::<T, MDiv3> {
256                data: m2.into(),
257                _len: PhantomData,
258            },
259            HeapArray::<T, MDiv3> {
260                data: m3.into(),
261                _len: PhantomData,
262            },
263        )
264    }
265
266    pub fn merge_thirds(first: Self, second: Self, third: Self) -> HeapArray<T, Prod<M, U3>>
267    where
268        M: Mul<U3, Output: Positive>,
269    {
270        let mut vec = first.data.into_vec();
271        vec.extend(second.data.into_vec());
272        vec.extend(third.data.into_vec());
273        HeapArray::<T, Prod<M, U3>> {
274            data: vec.into_boxed_slice(),
275            _len: PhantomData,
276        }
277    }
278}
279
280pub struct HeapArrayTuple<T1: Sized, T2: Sized, M: Positive>(
281    pub HeapArray<T1, M>,
282    pub HeapArray<T2, M>,
283);
284
285impl<T1: Sized, T2: Sized, M: Positive> FromIterator<(T1, T2)> for HeapArrayTuple<T1, T2, M> {
286    fn from_iter<I: IntoIterator<Item = (T1, T2)>>(iter: I) -> Self {
287        let (data1, data2): (Vec<_>, Vec<_>) = iter.into_iter().unzip();
288
289        assert_eq!(data1.len(), M::USIZE);
290        assert_eq!(data2.len(), M::USIZE);
291        HeapArrayTuple(
292            HeapArray::<T1, M> {
293                data: data1.into_boxed_slice(),
294                _len: PhantomData,
295            },
296            HeapArray::<T2, M> {
297                data: data2.into_boxed_slice(),
298                _len: PhantomData,
299            },
300        )
301    }
302}
303
304impl<T: Sized, M: Positive> TryFrom<Vec<T>> for HeapArray<T, M> {
305    type Error = PrimitiveError;
306
307    fn try_from(data: Vec<T>) -> Result<Self, PrimitiveError> {
308        if data.len() == M::USIZE {
309            Ok(Self {
310                data: data.into_boxed_slice(),
311                _len: PhantomData,
312            })
313        } else {
314            Err(PrimitiveError::InvalidSize(M::USIZE, data.len()))
315        }
316    }
317}
318
319impl<T: Sized> From<T> for HeapArray<T, U1> {
320    fn from(element: T) -> Self {
321        Self {
322            data: Box::new([element]),
323            _len: PhantomData,
324        }
325    }
326}
327
328impl<T: Sized, M: Positive> From<HeapArray<T, M>> for Vec<T> {
329    fn from(array: HeapArray<T, M>) -> Self {
330        array.data.into_vec()
331    }
332}
333
334impl<T: Sized + Clone, M: Positive + ArraySize> From<Array<T, M>> for HeapArray<T, M> {
335    fn from(array: Array<T, M>) -> Self {
336        Self {
337            data: array.to_vec().into_boxed_slice(),
338            _len: PhantomData,
339        }
340    }
341}
342
343impl<T: Sized, M: Positive> HeapArray<T, M> {
344    pub fn split_last<N: Positive>(self) -> (HeapArray<T, Diff<M, N>>, HeapArray<T, N>)
345    where
346        M: Sub<N, Output: Positive>,
347    {
348        let mut vec = self.data.into_vec();
349        let last_n = vec.split_off(M::USIZE - N::USIZE);
350
351        (
352            HeapArray {
353                data: vec.into_boxed_slice(),
354                _len: PhantomData,
355            },
356            HeapArray {
357                data: last_n.into_boxed_slice(),
358                _len: PhantomData,
359            },
360        )
361    }
362
363    pub fn from_fn(f: impl FnMut(usize) -> T) -> Self {
364        Self {
365            data: (0..M::USIZE).map(f).collect::<Box<_>>(),
366            _len: PhantomData,
367        }
368    }
369
370    pub fn from_constant(c: T) -> Self
371    where
372        T: Copy,
373    {
374        Self {
375            data: (0..M::USIZE).map(|_| c).collect::<Box<_>>(),
376            _len: PhantomData,
377        }
378    }
379}
380
381impl<T: Sized, M: Positive> Display for HeapArray<T, M>
382where
383    T: Display,
384{
385    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
386        write!(f, "[")?;
387        for (i, item) in self.data.iter().enumerate() {
388            if i != 0 {
389                write!(f, ", ")?;
390            }
391            write!(f, "{item}")?;
392        }
393        write!(f, "]")
394    }
395}
396
#[cfg(test)]
pub mod tests {
    use hybrid_array::sizes::{U2, U3, U6};
    use typenum::{U1, U4};

    use super::*;

    #[test]
    fn test_heap_array() {
        let array = HeapArray::<_, U3>::from_fn(|i| i);
        assert_eq!(array.into_iter().collect::<Vec<_>>(), vec![0, 1, 2]);
    }

    #[test]
    fn test_default() {
        let array = HeapArray::<usize, U3>::default();
        assert_eq!(array.len(), 3);
    }

    #[test]
    fn test_heap_array_split_last() {
        let array = HeapArray::<_, U6>::from_fn(|i| i);
        let (first, last) = array.split_last::<U2>();
        assert_eq!(first.into_iter().collect::<Vec<_>>(), vec![0, 1, 2, 3]);
        assert_eq!(last.into_iter().collect::<Vec<_>>(), vec![4, 5]);
    }

    #[test]
    fn test_heap_array_from_array() {
        let array = Array::<_, U3>::from_fn(|i| i);
        let heap_array = HeapArray::<_, U3>::from(array);
        assert_eq!(heap_array.into_iter().collect::<Vec<_>>(), vec![0, 1, 2]);
    }

    #[test]
    fn test_heap_array_from_vec() {
        let vec = vec![0, 1, 2];
        let heap_array = HeapArray::<_, U3>::try_from(vec).unwrap();
        assert_eq!(heap_array.into_iter().collect::<Vec<_>>(), vec![0, 1, 2]);

        let vec = vec![0, 1];
        let heap_array = HeapArray::<_, U3>::try_from(vec);
        assert!(heap_array.is_err());
    }

    #[test]
    fn test_heap_array_from_iter() {
        // Exercise `FromIterator` itself; this test previously went through `from_fn`.
        let heap_array = HeapArray::<_, U3>::from_iter(0..3);
        assert_eq!(heap_array.into_iter().collect::<Vec<_>>(), vec![0, 1, 2]);
    }

    #[test]
    #[should_panic]
    fn test_heap_array_from_iter_wrong_size() {
        HeapArray::<_, U2>::from_iter(0..3);
    }

    #[test]
    fn test_heap_array_deserialize() {
        let array = HeapArray::<usize, U6>::from_fn(|i| i);
        let serialized = bincode::serialize(&array).unwrap();
        bincode::deserialize::<HeapArray<usize, U6>>(&serialized).unwrap();

        // With the new tuple-based serialization (which doesn't include length),
        // we can't detect length mismatches unless we use strict deserialization
        // that rejects trailing bytes.
        use bincode::Options;
        let config = bincode::DefaultOptions::new()
            .with_fixint_encoding()
            .reject_trailing_bytes();

        let wrong_deserialize = config.deserialize::<HeapArray<usize, U3>>(&serialized);
        assert!(wrong_deserialize.is_err());
    }

    #[test]
    fn test_heap_array_split() {
        let array = HeapArray::<_, U6>::from_fn(|i| i);
        let (first, second) = array.split::<U4, U2>();
        assert_eq!(first.into_iter().collect::<Vec<_>>(), vec![0, 1, 2, 3]);
        assert_eq!(second.into_iter().collect::<Vec<_>>(), vec![4, 5]);

        let (first, second) = array.split_halves::<U3>();
        assert_eq!(first.into_iter().collect::<Vec<_>>(), vec![0, 1, 2]);
        assert_eq!(second.into_iter().collect::<Vec<_>>(), vec![3, 4, 5]);

        // Rejected at compile time because 4 + 1 != 6:
        // let (first, second) = array.split::<U4, U1>();

        // Rejected at compile time because 2 * 2 != 6:
        // let (first, second) = array.split_halves::<U2>();
    }

    #[test]
    fn test_heap_array_split3() {
        let array = HeapArray::<_, U6>::from_fn(|i| i);
        let (first, second, third) = array.split3::<U3, U2, U1>();
        assert_eq!(first.into_iter().collect::<Vec<_>>(), vec![0, 1, 2]);
        assert_eq!(second.into_iter().collect::<Vec<_>>(), vec![3, 4]);
        assert_eq!(third.into_iter().collect::<Vec<_>>(), vec![5]);

        let (first, second, third) = array.split_thirds::<U2>();
        assert_eq!(first.into_iter().collect::<Vec<_>>(), vec![0, 1]);
        assert_eq!(second.into_iter().collect::<Vec<_>>(), vec![2, 3]);
        assert_eq!(third.into_iter().collect::<Vec<_>>(), vec![4, 5]);

        // Rejected at compile time because 3 + 2 + 2 != 6:
        // let (a1, a2, a3) = array.split3::<U3, U2, U2>();

        // Rejected at compile time because 3 * 3 != 6:
        // let (a1, a2, a3) = array.split_thirds::<U3>();
    }

    #[test]
    fn test_heap_array_merge() {
        let array = HeapArray::<_, U6>::from_fn(|i| i);

        let (first, second) = array.split_halves::<U3>();
        let merged = HeapArray::merge_halves(first, second);
        assert_eq!(
            merged.into_iter().collect::<Vec<_>>(),
            vec![0, 1, 2, 3, 4, 5]
        );

        let (first, second, third) = array.split_thirds::<U2>();
        let merged = HeapArray::merge_thirds(first, second, third);
        assert_eq!(
            merged.into_iter().collect::<Vec<_>>(),
            vec![0, 1, 2, 3, 4, 5]
        );
    }
}