Skip to main content

diskann_wide/
splitjoin.rs

/*
 * Copyright (c) Microsoft Corporation.
 * Licensed under the MIT license.
 */
5
6/// Split a type into or join from two halves.
7///
8/// For example, even dimensional fixed size arrays of length `N` will be split so the first
9/// `N / 2` elements are in the low half, and the last `N / 2` elements are in the high half.
10pub trait SplitJoin {
11    /// The type of the halved element.
12    type Halved;
13
14    /// Split `self` into two equal halves.
15    fn split(self) -> LoHi<Self::Halved>;
16
17    /// Create `self` by joining the two halves.
18    fn join(halves: LoHi<Self::Halved>) -> Self;
19}
20
/// The low and high halves of a value, as produced or consumed by an implementation
/// of [`SplitJoin`] or [`crate::ZipUnzip`].
///
/// Two `LoHi` values compare equal when both their `lo` and their `hi` members are
/// equal; comparison is only available when `T: PartialEq`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct LoHi<T> {
    /// The first (low) half of a split entity.
    pub lo: T,
    /// The second (high) half of a split entity.
    pub hi: T,
}
30
31impl<T> LoHi<T> {
32    /// Construct a new `LoHi` from the low and high parts.
33    pub fn new(lo: T, hi: T) -> Self {
34        Self { lo, hi }
35    }
36
37    /// Join the `lo` and `hi` portions.
38    pub fn join<U>(self) -> U
39    where
40        U: SplitJoin<Halved = T>,
41    {
42        U::join(self)
43    }
44
45    /// Zip-interleave the `lo` and `hi` halves into a full-width vector.
46    pub fn zip<U>(self) -> U
47    where
48        U: crate::traits::ZipUnzip<Halved = T>,
49    {
50        U::zip(self)
51    }
52
53    /// Return a new [`LoHi`] with the function `f` applied to the pairwise members
54    /// of `self` and `x`.
55    ///
56    /// If it does not panic, `f` will be invoked exactly twice, first on `lo`, then on `hi`.
57    pub fn map_with<U, F, R>(self, x: LoHi<U>, mut f: F) -> LoHi<R>
58    where
59        F: FnMut(T, U) -> R,
60    {
61        let lo = f(self.lo, x.lo);
62        let hi = f(self.hi, x.hi);
63        LoHi { lo, hi }
64    }
65
66    /// Return a new [`LoHi`] with the function `f` applied to each member.
67    ///
68    /// If it does not panic, `f` will be invoked exactly twice, first on `lo`, then on `hi`.
69    pub fn map<F, R>(self, mut f: F) -> LoHi<R>
70    where
71        F: FnMut(T) -> R,
72    {
73        let lo = f(self.lo);
74        let hi = f(self.hi);
75        LoHi { lo, hi }
76    }
77}
78
79macro_rules! array_splitjoin {
80    ($N:literal) => {
81        impl<T: Copy> SplitJoin for [T; $N] {
82            type Halved = [T; { $N / 2 }];
83
84            #[inline(always)]
85            fn split(self) -> LoHi<Self::Halved> {
86                const BASE: usize = { $N / 2 };
87                LoHi {
88                    lo: core::array::from_fn(|i| self[i]),
89                    hi: core::array::from_fn(|i| self[BASE + i]),
90                }
91            }
92
93            #[inline(always)]
94            fn join(lohi: LoHi<Self::Halved>) -> Self {
95                const BASE: usize = { $N / 2 };
96                core::array::from_fn(|i| {
97                    if i < BASE {
98                        lohi.lo[i]
99                    } else {
100                        lohi.hi[i - BASE]
101                    }
102                })
103            }
104        }
105    };
106}
107
108array_splitjoin!(2);
109array_splitjoin!(4);
110array_splitjoin!(8);
111array_splitjoin!(16);
112array_splitjoin!(32);
113array_splitjoin!(64);
114
115///////////
116// Tests //
117///////////
118
#[cfg(test)]
mod tests {
    use std::fmt::Display;

    use rand::{
        SeedableRng,
        distr::{Distribution, StandardUniform},
        rngs::StdRng,
    };

    use super::*;

    /// Assert that `lo` and `hi` are exactly the first and second halves of `full`.
    ///
    /// `context` is appended to every failure message to aid debugging.
    fn test_split<T>(full: &[T], lo: &[T], hi: &[T], context: &dyn Display)
    where
        T: PartialEq + std::fmt::Debug,
    {
        let full_len = full.len();
        assert_eq!(
            full_len % 2,
            0,
            "full length must be even, instead got {} -- {}",
            full_len,
            context
        );
        let half_len = full_len / 2;
        assert_eq!(
            half_len,
            lo.len(),
            "unexpected \"lo\" length -- {}",
            context
        );
        assert_eq!(
            half_len,
            hi.len(),
            "unexpected \"hi\" length -- {}",
            context
        );

        let (expected_lo, expected_hi) = full.split_at(half_len);
        for (i, (expected, got)) in expected_lo.iter().zip(lo).enumerate() {
            assert_eq!(
                expected, got,
                "low check failed at index {} -- {}",
                i, context
            );
        }
        for (i, (expected, got)) in expected_hi.iter().zip(hi).enumerate() {
            assert_eq!(
                expected, got,
                "high check failed at index {} -- {}",
                i, context
            );
        }
    }

    /// Deferred formatter for failure messages: only renders the arrays if an
    /// assertion actually fails.
    struct Lazy<'a, T> {
        base: &'a [T],
        lo: &'a [T],
        hi: &'a [T],
    }

    impl<T> std::fmt::Display for Lazy<'_, T>
    where
        T: std::fmt::Debug,
    {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(
                f,
                "base = {:?}, lo = {:?}, hi = {:?}",
                self.base, self.lo, self.hi
            )
        }
    }

    /// Generate a randomized split/join round-trip test for `[i8; $len]`.
    macro_rules! test_splitjoin {
        ($fn:ident, $len:literal, $trials:literal, $seed:literal) => {
            #[test]
            fn $fn() {
                const NUM_TRIALS: usize = $trials;
                let mut rng = StdRng::seed_from_u64($seed);
                for _ in 0..NUM_TRIALS {
                    let base: [i8; $len] =
                        core::array::from_fn(|_| StandardUniform {}.sample(&mut rng));

                    let LoHi { lo, hi } = base.split();

                    let context = Lazy {
                        base: &base,
                        lo: &lo,
                        hi: &hi,
                    };

                    test_split(&base, &lo, &hi, &context);

                    let rejoined = <[i8; $len]>::join(LoHi::new(lo, hi));
                    assert_eq!(base, rejoined);
                }
            }
        };
    }

    test_splitjoin!(test_splitjoin_2, 2, 100, 0x5943d0578df47cdd);
    test_splitjoin!(test_splitjoin_4, 4, 100, 0xc735a1c37c9a8c2c);
    test_splitjoin!(test_splitjoin_8, 8, 100, 0x4dcf648800b9f9b6);
    test_splitjoin!(test_splitjoin_16, 16, 50, 0xf7386a0621134477);
    test_splitjoin!(test_splitjoin_32, 32, 50, 0xb3b0ded762020295);
    test_splitjoin!(test_splitjoin_64, 64, 25, 0x0fc17da7d8a9e1d0);

    /// `LoHi::new` stores its arguments in order and `LoHi::join` delegates to
    /// the array `SplitJoin` impl.
    #[test]
    fn test_lohi_new_and_join() {
        let halves = LoHi::new([1i8, 2], [3, 4]);
        assert_eq!(halves.lo, [1, 2]);
        assert_eq!(halves.hi, [3, 4]);

        let full: [i8; 4] = halves.join();
        assert_eq!(full, [1, 2, 3, 4]);
    }

    /// `LoHi::map` applies `f` to `lo` first, then `hi`, exactly once each.
    #[test]
    fn test_lohi_map() {
        let mut calls = Vec::new();
        let mapped = LoHi::new(3i32, 7).map(|v| {
            calls.push(v);
            v * 10
        });
        assert_eq!((mapped.lo, mapped.hi), (30, 70));
        assert_eq!(calls, [3, 7]);
    }

    /// `LoHi::map_with` pairs members positionally and processes `lo` before `hi`.
    #[test]
    fn test_lohi_map_with() {
        let mut calls = Vec::new();
        let combined = LoHi::new(1i32, 2).map_with(LoHi::new(10i32, 20), |a, b| {
            calls.push((a, b));
            a + b
        });
        assert_eq!((combined.lo, combined.hi), (11, 22));
        assert_eq!(calls, [(1, 10), (2, 20)]);
    }
}