1pub trait SplitJoin {
11 type Halved;
13
14 fn split(self) -> LoHi<Self::Halved>;
16
17 fn join(halves: LoHi<Self::Halved>) -> Self;
19}
20
/// A pair of values holding a low (`lo`) half and a high (`hi`) half.
///
/// This is the type returned by [`SplitJoin::split`] and accepted by
/// [`SplitJoin::join`]. It derives equality and hashing so split results can be
/// compared directly or stored in hashed collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct LoHi<T> {
    /// The low half.
    pub lo: T,
    /// The high half.
    pub hi: T,
}

31impl<T> LoHi<T> {
32 pub fn new(lo: T, hi: T) -> Self {
34 Self { lo, hi }
35 }
36
37 pub fn join<U>(self) -> U
39 where
40 U: SplitJoin<Halved = T>,
41 {
42 U::join(self)
43 }
44
45 pub fn map_with<U, F, R>(self, x: LoHi<U>, mut f: F) -> LoHi<R>
50 where
51 F: FnMut(T, U) -> R,
52 {
53 let lo = f(self.lo, x.lo);
54 let hi = f(self.hi, x.hi);
55 LoHi { lo, hi }
56 }
57
58 pub fn map<F, R>(self, mut f: F) -> LoHi<R>
62 where
63 F: FnMut(T) -> R,
64 {
65 let lo = f(self.lo);
66 let hi = f(self.hi);
67 LoHi { lo, hi }
68 }
69}
70
71macro_rules! array_splitjoin {
72 ($N:literal) => {
73 impl<T: Copy> SplitJoin for [T; $N] {
74 type Halved = [T; { $N / 2 }];
75
76 #[inline(always)]
77 fn split(self) -> LoHi<Self::Halved> {
78 const BASE: usize = { $N / 2 };
79 LoHi {
80 lo: core::array::from_fn(|i| self[i]),
81 hi: core::array::from_fn(|i| self[BASE + i]),
82 }
83 }
84
85 #[inline(always)]
86 fn join(lohi: LoHi<Self::Halved>) -> Self {
87 const BASE: usize = { $N / 2 };
88 core::array::from_fn(|i| {
89 if i < BASE {
90 lohi.lo[i]
91 } else {
92 lohi.hi[i - BASE]
93 }
94 })
95 }
96 }
97 };
98}
99
100array_splitjoin!(2);
101array_splitjoin!(4);
102array_splitjoin!(8);
103array_splitjoin!(16);
104array_splitjoin!(32);
105array_splitjoin!(64);
106
#[cfg(test)]
mod tests {
    use std::fmt::Display;

    use rand::{
        SeedableRng,
        distr::{Distribution, StandardUniform},
        rngs::StdRng,
    };

    use super::*;

    /// Asserts that `lo` and `hi` are exactly the first and second halves of `full`.
    ///
    /// Every failure message ends with `context`, so the offending inputs are visible.
    fn test_split<T>(full: &[T], lo: &[T], hi: &[T], context: &dyn Display)
    where
        T: PartialEq + std::fmt::Debug,
    {
        let total = full.len();
        assert_eq!(
            total % 2,
            0,
            "full length must be even, instead got {} -- {}",
            total,
            context
        );

        let half = total / 2;
        assert_eq!(half, lo.len(), "unexpected \"lo\" length -- {}", context);
        assert_eq!(half, hi.len(), "unexpected \"hi\" length -- {}", context);

        let (front, back) = full.split_at(half);

        for (idx, (want, got)) in front.iter().zip(lo).enumerate() {
            assert_eq!(want, got, "low check failed at index {} -- {}", idx, context);
        }

        for (idx, (want, got)) in back.iter().zip(hi).enumerate() {
            assert_eq!(want, got, "high check failed at index {} -- {}", idx, context);
        }
    }

    /// Failure-message context that formats the input and both halves only when displayed.
    struct SplitReport<'a, T> {
        base: &'a [T],
        lo: &'a [T],
        hi: &'a [T],
    }

    impl<T: std::fmt::Debug> std::fmt::Display for SplitReport<'_, T> {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            let SplitReport { base, lo, hi } = self;
            write!(f, "base = {:?}, lo = {:?}, hi = {:?}", base, lo, hi)
        }
    }

    /// Generates a seeded randomized round-trip test for arrays of length `$len`.
    macro_rules! test_splitjoin {
        ($fn:ident, $len:literal, $trials:literal, $seed:literal) => {
            #[test]
            fn $fn() {
                const TRIALS: usize = $trials;
                let mut rng = StdRng::seed_from_u64($seed);
                for _trial in 0..TRIALS {
                    let base: [i8; $len] =
                        core::array::from_fn(|_| StandardUniform.sample(&mut rng));

                    let halves = base.split();
                    let lo = halves.lo;
                    let hi = halves.hi;

                    test_split(
                        &base,
                        &lo,
                        &hi,
                        &SplitReport {
                            base: &base,
                            lo: &lo,
                            hi: &hi,
                        },
                    );

                    let rejoined: [i8; $len] = LoHi::new(lo, hi).join();
                    assert_eq!(base, rejoined);
                }
            }
        };
    }

    test_splitjoin!(test_splitjoin_2, 2, 100, 0x5943d0578df47cdd);
    test_splitjoin!(test_splitjoin_4, 4, 100, 0xc735a1c37c9a8c2c);
    test_splitjoin!(test_splitjoin_8, 8, 100, 0x4dcf648800b9f9b6);
    test_splitjoin!(test_splitjoin_16, 16, 50, 0xf7386a0621134477);
    test_splitjoin!(test_splitjoin_32, 32, 50, 0xb3b0ded762020295);
    test_splitjoin!(test_splitjoin_64, 64, 25, 0x0fc17da7d8a9e1d0);
}