1use crate::{
2 Array, ArraySize,
3 typenum::{Prod, Quot, U0, Unsigned},
4};
5use core::{
6 mem::ManuallyDrop,
7 ops::{Div, Mul, Rem},
8 ptr,
9};
10
11pub trait Flatten<T, M: ArraySize> {
13 type OutputSize: ArraySize;
15
16 fn flatten(self) -> Array<T, Self::OutputSize>;
18}
19
20impl<T, N, M> Flatten<T, Prod<M, N>> for Array<Array<T, M>, N>
21where
22 N: ArraySize,
23 M: ArraySize + Mul<N>,
24 Prod<M, N>: ArraySize,
25{
26 type OutputSize = Prod<M, N>;
27
28 fn flatten(self) -> Array<T, Self::OutputSize> {
31 let whole = ManuallyDrop::new(self);
32 unsafe { ptr::read(whole.as_ptr().cast()) }
33 }
34}
35
36pub trait Unflatten<M>
38where
39 M: ArraySize,
40{
41 type Part;
43
44 fn unflatten(self) -> Array<Self::Part, M>;
46}
47
48impl<T, N, M> Unflatten<M> for Array<T, N>
49where
50 N: ArraySize + Div<M> + Rem<M, Output = U0>,
51 M: ArraySize,
52 Quot<N, M>: ArraySize,
53{
54 type Part = Array<T, Quot<N, M>>;
55
56 fn unflatten(self) -> Array<Self::Part, M> {
60 let part_size = Quot::<N, M>::USIZE;
61 let whole = ManuallyDrop::new(self);
62 Array::from_fn(|i| unsafe {
63 let offset = i.checked_mul(part_size).expect("overflow");
64 ptr::read(whole.as_ptr().add(offset).cast())
65 })
66 }
67}
68
69impl<'a, T, N, M> Unflatten<M> for &'a Array<T, N>
70where
71 N: ArraySize + Div<M> + Rem<M, Output = U0>,
72 M: ArraySize,
73 Quot<N, M>: ArraySize,
74{
75 type Part = &'a Array<T, Quot<N, M>>;
76
77 fn unflatten(self) -> Array<Self::Part, M> {
81 let part_size = Quot::<N, M>::USIZE;
82 let mut ptr: *const T = self.as_ptr();
83 Array::from_fn(|_i| unsafe {
84 let part = &*(ptr.cast());
85 ptr = ptr.add(part_size);
86 part
87 })
88 }
89}
90
#[cfg(test)]
mod test {
    use super::*;
    use crate::{
        Array,
        sizes::{U2, U5},
    };

    /// Round-trips ten bytes through `flatten` and `unflatten`, using both a
    /// five-by-two and a two-by-five nesting, by value and by reference.
    #[test]
    fn flatten() {
        // Fixtures: the flat array and its two nested forms. The `_` sizes are
        // left to inference, which is pinned by the annotated results below.
        let flat: Array<u8, _> = Array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]);
        let pairs: Array<Array<u8, _>, _> = Array([
            Array([1, 2]),
            Array([3, 4]),
            Array([5, 6]),
            Array([7, 8]),
            Array([9, 10]),
        ]);
        let quints: Array<Array<u8, _>, _> =
            Array([Array([1, 2, 3, 4, 5]), Array([6, 7, 8, 9, 10])]);

        // Nested -> flat.
        let from_pairs = pairs.flatten();
        assert_eq!(flat, from_pairs);

        let from_quints = quints.flatten();
        assert_eq!(flat, from_quints);

        // Flat -> nested, by value.
        let owned_pairs: Array<Array<u8, U2>, U5> = flat.unflatten();
        assert_eq!(pairs, owned_pairs);

        let owned_quints: Array<Array<u8, U5>, U2> = flat.unflatten();
        assert_eq!(quints, owned_quints);

        // Flat -> nested, by reference.
        let borrowed_pairs: Array<&Array<u8, U2>, U5> = (&flat).unflatten();
        for (idx, chunk) in borrowed_pairs.iter().enumerate() {
            assert_eq!(&pairs[idx], *chunk);
        }

        let borrowed_quints: Array<&Array<u8, U5>, U2> = (&flat).unflatten();
        for (idx, chunk) in borrowed_quints.iter().enumerate() {
            assert_eq!(&quints[idx], *chunk);
        }
    }
}