primitives/types/heap_array/
array.rs

1use std::{
2    fmt::{Debug, Display},
3    marker::PhantomData,
4    ops::{Add, Mul, Sub},
5    vec,
6};
7
8use bytemuck::{box_bytes_of, from_box_bytes, BoxBytes, Pod};
9use derive_more::derive::{AsMut, AsRef, Deref, DerefMut, IntoIterator};
10use hybrid_array::{Array, ArraySize};
11use serde::{Deserialize, Serialize};
12use typenum::{Diff, Prod, Sum, U1, U2, U3};
13
14use crate::{errors::PrimitiveError, types::Positive};
15
16/// An array on the heap that encodes its length in the type system.
17#[derive(Deref, DerefMut, Clone, IntoIterator, AsRef, AsMut, Eq)]
18#[into_iterator(owned, ref, ref_mut)]
19pub struct HeapArray<T: Sized, M: Positive> {
20    #[deref]
21    #[deref_mut]
22    #[as_ref(forward)]
23    #[as_mut(forward)]
24    pub(super) data: Box<[T]>,
25    #[into_iterator(ignore)]
26    // `fn() -> M` is used instead of `M` so `HeapArray<T, M>` doesn't need `M` to implement `Send
27    // + Sync` to be `Send + Sync` itself. This would be the case if `M` was used directly.
28    pub(super) _len: PhantomData<fn() -> M>,
29}
30
31// bytemuck::BoxBytes transformation for copy-less casting
32impl<T: Pod, M: Positive> HeapArray<T, M> {
33    pub fn into_box_bytes(self) -> BoxBytes {
34        box_bytes_of(self.data)
35    }
36
37    pub fn from_box_bytes(buf: BoxBytes) -> Self {
38        Self {
39            data: from_box_bytes(buf),
40            _len: PhantomData,
41        }
42    }
43}
44
45impl<T: Sized + Default, M: Positive> HeapArray<T, M> {
46    pub fn from_single_value(val: T) -> Self {
47        Self {
48            data: vec![val].into_boxed_slice(),
49            _len: PhantomData,
50        }
51    }
52}
53
54impl<T: Sized + Default, M: Positive> Default for HeapArray<T, M> {
55    fn default() -> Self {
56        Self {
57            data: (0..M::USIZE).map(|_| T::default()).collect(),
58            _len: PhantomData,
59        }
60    }
61}
62
63impl<T: Sized + Debug, M: Positive> Debug for HeapArray<T, M> {
64    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
65        f.debug_struct(format!("HeapArray<{}>", M::USIZE).as_str())
66            .field("data", &self.data)
67            .finish()
68    }
69}
70
71impl<T: Sized + Serialize, M: Positive> Serialize for HeapArray<T, M> {
72    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
73        self.data.serialize(serializer)
74    }
75}
76
77impl<'de, T: Sized + Deserialize<'de>, M: Positive> Deserialize<'de> for HeapArray<T, M> {
78    fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
79        let data = Box::<[T]>::deserialize(deserializer)?;
80
81        if data.len() != M::USIZE {
82            return Err(serde::de::Error::custom(format!(
83                "Expected array of length {}, got {}",
84                M::USIZE,
85                data.len()
86            )));
87        }
88
89        Ok(Self {
90            data,
91            _len: PhantomData,
92        })
93    }
94}
95
96impl<T: Sized, M: Positive> FromIterator<T> for HeapArray<T, M> {
97    fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
98        let data = iter.into_iter().collect::<Box<_>>();
99        assert_eq!(data.len(), M::USIZE);
100        Self {
101            data,
102            _len: PhantomData,
103        }
104    }
105}
106
107// -----------------------
108// |   Split and Merge   |
109// -----------------------
110
111impl<T: Sized + Copy, M: Positive> HeapArray<T, M> {
112    pub fn split<M1, M2>(&self) -> (HeapArray<T, M1>, HeapArray<T, M2>)
113    where
114        M1: Positive,
115        M2: Positive + Add<M1, Output = M>,
116    {
117        let (m1, m2) = self.split_at(M1::USIZE);
118        (
119            HeapArray::<T, M1> {
120                data: m1.into(),
121                _len: PhantomData,
122            },
123            HeapArray::<T, M2> {
124                data: m2.into(),
125                _len: PhantomData,
126            },
127        )
128    }
129
130    pub fn split_halves<MDiv2>(&self) -> (HeapArray<T, MDiv2>, HeapArray<T, MDiv2>)
131    where
132        MDiv2: Positive + Mul<U2, Output = M>,
133    {
134        let (m1, m2) = self.split_at(MDiv2::USIZE);
135        (
136            HeapArray::<T, MDiv2> {
137                data: m1.into(),
138                _len: PhantomData,
139            },
140            HeapArray::<T, MDiv2> {
141                data: m2.into(),
142                _len: PhantomData,
143            },
144        )
145    }
146
147    pub fn merge_halves(this: Self, other: Self) -> HeapArray<T, Prod<M, U2>>
148    where
149        M: Mul<U2, Output: Positive>,
150    {
151        let mut vec = this.data.into_vec();
152        vec.extend(other.data.into_vec());
153        HeapArray::<T, Prod<M, U2>> {
154            data: vec.into_boxed_slice(),
155            _len: PhantomData,
156        }
157    }
158
159    pub fn split3<M1, M2, M3>(&self) -> (HeapArray<T, M1>, HeapArray<T, M2>, HeapArray<T, M3>)
160    where
161        M1: Positive,
162        M2: Positive + Add<M1>,
163        M3: Positive + Add<Sum<M2, M1>, Output = M>,
164    {
165        let (m1, m_rest) = self.split_at(M1::USIZE);
166        let (m2, m3) = m_rest.split_at(M2::USIZE);
167        (
168            HeapArray::<T, M1> {
169                data: m1.into(),
170                _len: PhantomData,
171            },
172            HeapArray::<T, M2> {
173                data: m2.into(),
174                _len: PhantomData,
175            },
176            HeapArray::<T, M3> {
177                data: m3.into(),
178                _len: PhantomData,
179            },
180        )
181    }
182
183    pub fn split_thirds<MDiv3>(
184        &self,
185    ) -> (
186        HeapArray<T, MDiv3>,
187        HeapArray<T, MDiv3>,
188        HeapArray<T, MDiv3>,
189    )
190    where
191        MDiv3: Positive + Mul<U3, Output = M>,
192    {
193        let (m1, m_rest) = self.split_at(MDiv3::USIZE);
194        let (m2, m3) = m_rest.split_at(MDiv3::USIZE);
195        (
196            HeapArray::<T, MDiv3> {
197                data: m1.into(),
198                _len: PhantomData,
199            },
200            HeapArray::<T, MDiv3> {
201                data: m2.into(),
202                _len: PhantomData,
203            },
204            HeapArray::<T, MDiv3> {
205                data: m3.into(),
206                _len: PhantomData,
207            },
208        )
209    }
210
211    pub fn merge_thirds(first: Self, second: Self, third: Self) -> HeapArray<T, Prod<M, U3>>
212    where
213        M: Mul<U3, Output: Positive>,
214    {
215        let mut vec = first.data.into_vec();
216        vec.extend(second.data.into_vec());
217        vec.extend(third.data.into_vec());
218        HeapArray::<T, Prod<M, U3>> {
219            data: vec.into_boxed_slice(),
220            _len: PhantomData,
221        }
222    }
223}
224
225pub struct HeapArrayTuple<T1: Sized, T2: Sized, M: Positive>(
226    pub HeapArray<T1, M>,
227    pub HeapArray<T2, M>,
228);
229
230impl<T1: Sized, T2: Sized, M: Positive> FromIterator<(T1, T2)> for HeapArrayTuple<T1, T2, M> {
231    fn from_iter<I: IntoIterator<Item = (T1, T2)>>(iter: I) -> Self {
232        let (data1, data2): (Vec<_>, Vec<_>) = iter.into_iter().unzip();
233
234        assert_eq!(data1.len(), M::USIZE);
235        assert_eq!(data2.len(), M::USIZE);
236        HeapArrayTuple(
237            HeapArray::<T1, M> {
238                data: data1.into_boxed_slice(),
239                _len: PhantomData,
240            },
241            HeapArray::<T2, M> {
242                data: data2.into_boxed_slice(),
243                _len: PhantomData,
244            },
245        )
246    }
247}
248
249impl<T: Sized, M: Positive> TryFrom<Vec<T>> for HeapArray<T, M> {
250    type Error = PrimitiveError;
251
252    fn try_from(data: Vec<T>) -> Result<Self, PrimitiveError> {
253        if data.len() == M::USIZE {
254            Ok(Self {
255                data: data.into_boxed_slice(),
256                _len: PhantomData,
257            })
258        } else {
259            Err(PrimitiveError::SizeError(M::USIZE, data.len()))
260        }
261    }
262}
263
264impl<T: Sized> From<T> for HeapArray<T, U1> {
265    fn from(element: T) -> Self {
266        Self {
267            data: Box::new([element]),
268            _len: PhantomData,
269        }
270    }
271}
272
273impl<T: Sized, M: Positive> From<HeapArray<T, M>> for Vec<T> {
274    fn from(array: HeapArray<T, M>) -> Self {
275        array.data.into_vec()
276    }
277}
278
279impl<T: Sized + Clone, M: Positive + ArraySize> From<Array<T, M>> for HeapArray<T, M> {
280    fn from(array: Array<T, M>) -> Self {
281        Self {
282            data: array.to_vec().into_boxed_slice(),
283            _len: PhantomData,
284        }
285    }
286}
287
288impl<T: Sized, M: Positive> HeapArray<T, M> {
289    pub fn split_last<N: Positive>(self) -> (HeapArray<T, Diff<M, N>>, HeapArray<T, N>)
290    where
291        M: Sub<N, Output: Positive>,
292    {
293        let mut vec = self.data.into_vec();
294        let last_n = vec.split_off(M::USIZE - N::USIZE);
295
296        (
297            HeapArray {
298                data: vec.into_boxed_slice(),
299                _len: PhantomData,
300            },
301            HeapArray {
302                data: last_n.into_boxed_slice(),
303                _len: PhantomData,
304            },
305        )
306    }
307
308    pub fn from_fn(f: impl FnMut(usize) -> T) -> Self {
309        Self {
310            data: (0..M::USIZE).map(f).collect::<Box<_>>(),
311            _len: PhantomData,
312        }
313    }
314
315    pub fn from_constant(c: T) -> Self
316    where
317        T: Copy,
318    {
319        Self {
320            data: (0..M::USIZE).map(|_| c).collect::<Box<_>>(),
321            _len: PhantomData,
322        }
323    }
324}
325
326impl<T: Sized, M: Positive> Display for HeapArray<T, M>
327where
328    T: Display,
329{
330    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
331        write!(f, "[")?;
332        for (i, item) in self.data.iter().enumerate() {
333            if i != 0 {
334                write!(f, ", ")?;
335            }
336            write!(f, "{item}")?;
337        }
338        write!(f, "]")
339    }
340}
341
#[cfg(test)]
pub mod tests {
    use hybrid_array::sizes::{U2, U3, U6};
    use typenum::{U1, U4};

    use super::*;

    #[test]
    fn test_heap_array() {
        let array = HeapArray::<_, U3>::from_fn(|i| i);
        assert_eq!(array.into_iter().collect::<Vec<_>>(), vec![0, 1, 2]);
    }

    #[test]
    fn test_default() {
        let array = HeapArray::<usize, U3>::default();
        assert_eq!(array.len(), 3);
    }

    #[test]
    fn test_heap_array_split_last() {
        let array = HeapArray::<_, U6>::from_fn(|i| i);
        let (first, last) = array.split_last::<U2>();
        assert_eq!(first.into_iter().collect::<Vec<_>>(), vec![0, 1, 2, 3]);
        assert_eq!(last.into_iter().collect::<Vec<_>>(), vec![4, 5]);
    }

    #[test]
    fn test_heap_array_from_array() {
        let array = Array::<_, U3>::from_fn(|i| i);
        let heap_array = HeapArray::<_, U3>::from(array);
        assert_eq!(heap_array.into_iter().collect::<Vec<_>>(), vec![0, 1, 2]);
    }

    #[test]
    fn test_heap_array_from_vec() {
        let vec = vec![0, 1, 2];
        let heap_array = HeapArray::<_, U3>::try_from(vec).unwrap();
        assert_eq!(heap_array.into_iter().collect::<Vec<_>>(), vec![0, 1, 2]);

        let vec = vec![0, 1];
        let heap_array = HeapArray::<_, U3>::try_from(vec);
        assert!(heap_array.is_err());
    }

    #[test]
    fn test_heap_array_from_iter() {
        // Exercise the `FromIterator` impl itself (not `from_fn`).
        let heap_array: HeapArray<_, U3> = (0..3).collect();
        assert_eq!(heap_array.into_iter().collect::<Vec<_>>(), vec![0, 1, 2]);
    }

    #[test]
    #[should_panic]
    fn test_heap_array_from_iter_wrong_size() {
        HeapArray::<_, U2>::from_iter(0..3);
    }

    #[test]
    fn test_heap_array_deserialize() {
        let array = HeapArray::<usize, U6>::from_fn(|i| i);
        let serialized = bincode::serialize(&array).unwrap();
        let deserialized = bincode::deserialize::<HeapArray<usize, U6>>(&serialized).unwrap();
        assert_eq!(
            deserialized.into_iter().collect::<Vec<_>>(),
            vec![0, 1, 2, 3, 4, 5]
        );

        let wrong_deserialize = bincode::deserialize::<HeapArray<usize, U3>>(&serialized);
        assert!(wrong_deserialize.is_err());
    }

    #[test]
    fn test_heap_array_split() {
        let array = HeapArray::<_, U6>::from_fn(|i| i);
        let (first, second) = array.split::<U4, U2>();
        assert_eq!(first.into_iter().collect::<Vec<_>>(), vec![0, 1, 2, 3]);
        assert_eq!(second.into_iter().collect::<Vec<_>>(), vec![4, 5]);

        let (first, second) = array.split_halves::<U3>();
        assert_eq!(first.into_iter().collect::<Vec<_>>(), vec![0, 1, 2]);
        assert_eq!(second.into_iter().collect::<Vec<_>>(), vec![3, 4, 5]);

        // Rejected at compile time (sizes don't add up to U6):
        // let (first, second) = array.split::<U4, U1>();
        // let (first, second) = array.split_halves::<U2>();
    }

    #[test]
    fn test_heap_array_split3() {
        let array = HeapArray::<_, U6>::from_fn(|i| i);
        let (first, second, third) = array.split3::<U3, U2, U1>();
        assert_eq!(first.into_iter().collect::<Vec<_>>(), vec![0, 1, 2]);
        assert_eq!(second.into_iter().collect::<Vec<_>>(), vec![3, 4]);
        assert_eq!(third.into_iter().collect::<Vec<_>>(), vec![5]);

        let (first, second, third) = array.split_thirds::<U2>();
        assert_eq!(first.into_iter().collect::<Vec<_>>(), vec![0, 1]);
        assert_eq!(second.into_iter().collect::<Vec<_>>(), vec![2, 3]);
        assert_eq!(third.into_iter().collect::<Vec<_>>(), vec![4, 5]);

        // Rejected at compile time (sizes don't add up to U6):
        // let (a1, a2, a3) = array.split3::<U3, U2, U2>();
        // let (a1, a2, a3) = array.split_thirds::<U2>();
    }

    #[test]
    fn test_heap_array_merge() {
        let array = HeapArray::<_, U6>::from_fn(|i| i);

        let (first, second) = array.split_halves::<U3>();
        let merged = HeapArray::merge_halves(first, second);
        assert_eq!(merged.into_iter().collect::<Vec<_>>(), vec![0, 1, 2, 3, 4, 5]);

        let (first, second, third) = array.split_thirds::<U2>();
        let merged = HeapArray::merge_thirds(first, second, third);
        assert_eq!(merged.into_iter().collect::<Vec<_>>(), vec![0, 1, 2, 3, 4, 5]);
    }

    #[test]
    fn test_heap_array_display() {
        let array = HeapArray::<_, U3>::from_fn(|i| i);
        assert_eq!(array.to_string(), "[0, 1, 2]");
    }

    #[test]
    fn test_heap_array_box_bytes_roundtrip() {
        let array = HeapArray::<u32, U3>::from_fn(|i| i as u32 * 7);
        let bytes = array.into_box_bytes();
        let restored = HeapArray::<u32, U3>::from_box_bytes(bytes);
        assert_eq!(restored.into_iter().collect::<Vec<_>>(), vec![0, 7, 14]);
    }
}