1pub trait SplitJoin {
11 type Halved;
13
14 fn split(self) -> LoHi<Self::Halved>;
16
17 fn join(halves: LoHi<Self::Halved>) -> Self;
19}
20
/// A pair of same-typed values: a low half (`lo`) and a high half (`hi`).
///
/// Produced by [`SplitJoin::split`] and consumed by [`SplitJoin::join`]. For the
/// array implementations, `lo` holds the leading elements and `hi` the trailing ones.
///
/// `PartialEq`/`Eq` are derived so halves can be compared directly
/// (e.g. `assert_eq!(a.split(), expected)`) whenever `T` supports it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct LoHi<T> {
    /// The low half.
    pub lo: T,
    /// The high half.
    pub hi: T,
}
30
31impl<T> LoHi<T> {
32 pub fn new(lo: T, hi: T) -> Self {
34 Self { lo, hi }
35 }
36
37 pub fn join<U>(self) -> U
39 where
40 U: SplitJoin<Halved = T>,
41 {
42 U::join(self)
43 }
44
45 pub fn zip<U>(self) -> U
47 where
48 U: crate::traits::ZipUnzip<Halved = T>,
49 {
50 U::zip(self)
51 }
52
53 pub fn map_with<U, F, R>(self, x: LoHi<U>, mut f: F) -> LoHi<R>
58 where
59 F: FnMut(T, U) -> R,
60 {
61 let lo = f(self.lo, x.lo);
62 let hi = f(self.hi, x.hi);
63 LoHi { lo, hi }
64 }
65
66 pub fn map<F, R>(self, mut f: F) -> LoHi<R>
70 where
71 F: FnMut(T) -> R,
72 {
73 let lo = f(self.lo);
74 let hi = f(self.hi);
75 LoHi { lo, hi }
76 }
77}
78
79macro_rules! array_splitjoin {
80 ($N:literal) => {
81 impl<T: Copy> SplitJoin for [T; $N] {
82 type Halved = [T; { $N / 2 }];
83
84 #[inline(always)]
85 fn split(self) -> LoHi<Self::Halved> {
86 const BASE: usize = { $N / 2 };
87 LoHi {
88 lo: core::array::from_fn(|i| self[i]),
89 hi: core::array::from_fn(|i| self[BASE + i]),
90 }
91 }
92
93 #[inline(always)]
94 fn join(lohi: LoHi<Self::Halved>) -> Self {
95 const BASE: usize = { $N / 2 };
96 core::array::from_fn(|i| {
97 if i < BASE {
98 lohi.lo[i]
99 } else {
100 lohi.hi[i - BASE]
101 }
102 })
103 }
104 }
105 };
106}
107
108array_splitjoin!(2);
109array_splitjoin!(4);
110array_splitjoin!(8);
111array_splitjoin!(16);
112array_splitjoin!(32);
113array_splitjoin!(64);
114
#[cfg(test)]
mod tests {
    use std::fmt::{Debug, Display, Formatter};

    use rand::{
        SeedableRng,
        distr::{Distribution, StandardUniform},
        rngs::StdRng,
    };

    use super::*;

    /// Asserts that `lo` and `hi` are, element for element, the first and the
    /// second half of `full`. `context` is appended to every failure message.
    fn test_split<T>(full: &[T], lo: &[T], hi: &[T], context: &dyn Display)
    where
        T: PartialEq + Debug,
    {
        let full_len = full.len();
        assert_eq!(
            full_len % 2,
            0,
            "full length must be even, instead got {} -- {}",
            full_len,
            context
        );

        let half_len = full_len / 2;
        assert_eq!(half_len, lo.len(), "unexpected \"lo\" length -- {}", context);
        assert_eq!(half_len, hi.len(), "unexpected \"hi\" length -- {}", context);

        // Lengths were verified above, so each zip walks exactly `half_len` pairs.
        let (front, back) = full.split_at(half_len);

        for (i, (expected, actual)) in front.iter().zip(lo).enumerate() {
            assert_eq!(expected, actual, "low check failed at index {} -- {}", i, context);
        }

        for (i, (expected, actual)) in back.iter().zip(hi).enumerate() {
            assert_eq!(expected, actual, "high check failed at index {} -- {}", i, context);
        }
    }

    /// Borrowed view of the arrays under test; formatted only when an assertion
    /// message actually needs it.
    struct Lazy<'a, T> {
        base: &'a [T],
        lo: &'a [T],
        hi: &'a [T],
    }

    impl<T: Debug> Display for Lazy<'_, T> {
        fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
            let Lazy { base, lo, hi } = self;
            write!(f, "base = {:?}, lo = {:?}, hi = {:?}", base, lo, hi)
        }
    }

    // Generates a seeded, randomized test for `[i8; $len]`: each trial checks
    // `split` against the input, then checks that `join` undoes `split`.
    macro_rules! test_splitjoin {
        ($fn:ident, $len:literal, $trials:literal, $seed:literal) => {
            #[test]
            fn $fn() {
                const TRIALS: usize = $trials;
                let mut rng = StdRng::seed_from_u64($seed);

                for _trial in 0..TRIALS {
                    let base: [i8; $len] =
                        core::array::from_fn(|_| StandardUniform {}.sample(&mut rng));

                    let halves = base.split();
                    let (lo, hi) = (halves.lo, halves.hi);

                    let context = Lazy {
                        base: &base,
                        lo: &lo,
                        hi: &hi,
                    };
                    test_split(&base, &lo, &hi, &context);

                    let rejoined = <[i8; $len] as SplitJoin>::join(LoHi::new(lo, hi));
                    assert_eq!(base, rejoined);
                }
            }
        };
    }

    test_splitjoin!(test_splitjoin_2, 2, 100, 0x5943d0578df47cdd);
    test_splitjoin!(test_splitjoin_4, 4, 100, 0xc735a1c37c9a8c2c);
    test_splitjoin!(test_splitjoin_8, 8, 100, 0x4dcf648800b9f9b6);
    test_splitjoin!(test_splitjoin_16, 16, 50, 0xf7386a0621134477);
    test_splitjoin!(test_splitjoin_32, 32, 50, 0xb3b0ded762020295);
    test_splitjoin!(test_splitjoin_64, 64, 25, 0x0fc17da7d8a9e1d0);
}