Skip to main content

hybrid_array/
flatten.rs

1use crate::{
2    Array, ArraySize,
3    typenum::{Prod, Quot, U0, Unsigned},
4};
5use core::{
6    mem::ManuallyDrop,
7    ops::{Div, Mul, Rem},
8    ptr,
9};
10
11/// Defines a sequence of sequences that can be merged into a bigger overall sequence.
12pub trait Flatten<T, M: ArraySize> {
13    /// Size of the output array.
14    type OutputSize: ArraySize;
15
16    /// Flatten array.
17    fn flatten(self) -> Array<T, Self::OutputSize>;
18}
19
20impl<T, N, M> Flatten<T, Prod<M, N>> for Array<Array<T, M>, N>
21where
22    N: ArraySize,
23    M: ArraySize + Mul<N>,
24    Prod<M, N>: ArraySize,
25{
26    type OutputSize = Prod<M, N>;
27
28    // SAFETY: this is the reverse transmute between [T; K*N] and [[T; K], M], which is guaranteed
29    // to be safe by the Rust memory layout of these types.
30    fn flatten(self) -> Array<T, Self::OutputSize> {
31        let whole = ManuallyDrop::new(self);
32        unsafe { ptr::read(whole.as_ptr().cast()) }
33    }
34}
35
36/// Defines a sequence that can be split into a sequence of smaller sequences of uniform size.
37pub trait Unflatten<M>
38where
39    M: ArraySize,
40{
41    /// Part of the array we're decomposing into.
42    type Part;
43
44    /// Unflatten array into `Self::Part` chunks.
45    fn unflatten(self) -> Array<Self::Part, M>;
46}
47
48impl<T, N, M> Unflatten<M> for Array<T, N>
49where
50    N: ArraySize + Div<M> + Rem<M, Output = U0>,
51    M: ArraySize,
52    Quot<N, M>: ArraySize,
53{
54    type Part = Array<T, Quot<N, M>>;
55
56    // SAFETY: this is doing the same thing as what is done in `Array::split`.
57    // Basically, this is doing transmute between [T; K*N] and [[T; K], M], which is guaranteed to
58    // be safe by the Rust memory layout of these types.
59    fn unflatten(self) -> Array<Self::Part, M> {
60        let part_size = Quot::<N, M>::USIZE;
61        let whole = ManuallyDrop::new(self);
62        Array::from_fn(|i| unsafe {
63            let offset = i.checked_mul(part_size).expect("overflow");
64            ptr::read(whole.as_ptr().add(offset).cast())
65        })
66    }
67}
68
69impl<'a, T, N, M> Unflatten<M> for &'a Array<T, N>
70where
71    N: ArraySize + Div<M> + Rem<M, Output = U0>,
72    M: ArraySize,
73    Quot<N, M>: ArraySize,
74{
75    type Part = &'a Array<T, Quot<N, M>>;
76
77    // SAFETY: this is doing the same thing as what is done in `Array::split`.
78    // Basically, this is doing transmute between [T; K*N] and [[T; K], M], which is guaranteed to
79    // be safe by the Rust memory layout of these types.
80    fn unflatten(self) -> Array<Self::Part, M> {
81        let part_size = Quot::<N, M>::USIZE;
82        let mut ptr: *const T = self.as_ptr();
83        Array::from_fn(|_i| unsafe {
84            let part = &*(ptr.cast());
85            ptr = ptr.add(part_size);
86            part
87        })
88    }
89}
90
#[cfg(test)]
mod test {
    use super::*;
    use crate::{
        Array,
        sizes::{U10, U2, U5},
    };
    use core::cell::Cell;

    #[test]
    fn flatten() {
        let flat: Array<u8, _> = Array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]);
        let unflat2: Array<Array<u8, _>, _> = Array([
            Array([1, 2]),
            Array([3, 4]),
            Array([5, 6]),
            Array([7, 8]),
            Array([9, 10]),
        ]);
        let unflat5: Array<Array<u8, _>, _> =
            Array([Array([1, 2, 3, 4, 5]), Array([6, 7, 8, 9, 10])]);

        // Flatten
        let actual = unflat2.flatten();
        assert_eq!(flat, actual);

        let actual = unflat5.flatten();
        assert_eq!(flat, actual);

        // Unflatten
        let actual: Array<Array<u8, U2>, U5> = flat.unflatten();
        assert_eq!(unflat2, actual);

        let actual: Array<Array<u8, U5>, U2> = flat.unflatten();
        assert_eq!(unflat5, actual);

        // Unflatten on references
        let actual: Array<&Array<u8, U2>, U5> = (&flat).unflatten();
        for (i, part) in actual.iter().enumerate() {
            assert_eq!(&unflat2[i], *part);
        }

        let actual: Array<&Array<u8, U5>, U2> = (&flat).unflatten();
        for (i, part) in actual.iter().enumerate() {
            assert_eq!(&unflat5[i], *part);
        }
    }

    /// Non-`Copy` element that records how many times it has been dropped.
    ///
    /// Used to check that the `ManuallyDrop` + `ptr::read` moves in `flatten`/`unflatten`
    /// neither leak nor double-drop elements.
    struct DropCounter<'a> {
        value: usize,
        drops: &'a Cell<usize>,
    }

    impl Drop for DropCounter<'_> {
        fn drop(&mut self) {
            self.drops.set(self.drops.get() + 1);
        }
    }

    #[test]
    fn round_trip_non_copy_drops_each_element_once() {
        let drops = Cell::new(0);
        {
            let flat: Array<DropCounter<'_>, U10> = Array::from_fn(|i| DropCounter {
                value: i,
                drops: &drops,
            });

            // Moving into chunks must not drop anything and must preserve order.
            let parts: Array<Array<DropCounter<'_>, U5>, U2> = flat.unflatten();
            assert_eq!(drops.get(), 0);
            for (i, part) in parts.iter().enumerate() {
                for (j, elem) in part.iter().enumerate() {
                    assert_eq!(elem.value, i * 5 + j);
                }
            }

            // Moving back into a flat array must not drop anything either.
            let flat = parts.flatten();
            assert_eq!(drops.get(), 0);
            for (i, elem) in flat.iter().enumerate() {
                assert_eq!(elem.value, i);
            }
        }
        // Every element is dropped exactly once, when the final array goes out of scope.
        assert_eq!(drops.get(), 10);
    }
}