Skip to main content

diskann_wide/
splitjoin.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6/// Split a type into or join from two halves.
7///
8/// For example, even dimensional fixed size arrays of length `N` will be split so the first
9/// `N / 2` elements are in the low half, and the last `N / 2` elements are in the high half.
10pub trait SplitJoin {
11    /// The type of the halved element.
12    type Halved;
13
14    /// Split `self` into two equal halves.
15    fn split(self) -> LoHi<Self::Halved>;
16
17    /// Create `self` by joining the two halves.
18    fn join(halves: LoHi<Self::Halved>) -> Self;
19}
20
/// The low and high halves produced by [`SplitJoin::split`] and consumed by
/// [`SplitJoin::join`].
///
/// Besides `Debug`/`Clone`/`Copy`, `PartialEq`, `Eq` and `Hash` are derived (each bounded on
/// the same trait for `T`). This lets split results be compared directly with `==` or stored in
/// hashed collections instead of being inspected field by field.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct LoHi<T> {
    /// The first (low) half of a split entity.
    pub lo: T,
    /// The second (high) half of a split entity.
    pub hi: T,
}
30
31impl<T> LoHi<T> {
32    /// Construct a new `LoHi` from the low and high parts.
33    pub fn new(lo: T, hi: T) -> Self {
34        Self { lo, hi }
35    }
36
37    /// Join the `lo` and `hi` portions.
38    pub fn join<U>(self) -> U
39    where
40        U: SplitJoin<Halved = T>,
41    {
42        U::join(self)
43    }
44
45    /// Return a new [`LoHi`] with the function `f` applied to the pairwise members of
46    /// of `self` and `x`.
47    ///
48    /// If it does not panic, `f` will be invoked exactly twice, first on `lo`, then on `hi`.
49    pub fn map_with<U, F, R>(self, x: LoHi<U>, mut f: F) -> LoHi<R>
50    where
51        F: FnMut(T, U) -> R,
52    {
53        let lo = f(self.lo, x.lo);
54        let hi = f(self.hi, x.hi);
55        LoHi { lo, hi }
56    }
57
58    /// Return a new [`LoHi`] with the function `f` applied to each member.
59    ///
60    /// If it does not panic, `f` will be invoked exactly twice, first on `lo`, then on `hi`.
61    pub fn map<F, R>(self, mut f: F) -> LoHi<R>
62    where
63        F: FnMut(T) -> R,
64    {
65        let lo = f(self.lo);
66        let hi = f(self.hi);
67        LoHi { lo, hi }
68    }
69}
70
71macro_rules! array_splitjoin {
72    ($N:literal) => {
73        impl<T: Copy> SplitJoin for [T; $N] {
74            type Halved = [T; { $N / 2 }];
75
76            #[inline(always)]
77            fn split(self) -> LoHi<Self::Halved> {
78                const BASE: usize = { $N / 2 };
79                LoHi {
80                    lo: core::array::from_fn(|i| self[i]),
81                    hi: core::array::from_fn(|i| self[BASE + i]),
82                }
83            }
84
85            #[inline(always)]
86            fn join(lohi: LoHi<Self::Halved>) -> Self {
87                const BASE: usize = { $N / 2 };
88                core::array::from_fn(|i| {
89                    if i < BASE {
90                        lohi.lo[i]
91                    } else {
92                        lohi.hi[i - BASE]
93                    }
94                })
95            }
96        }
97    };
98}
99
100array_splitjoin!(2);
101array_splitjoin!(4);
102array_splitjoin!(8);
103array_splitjoin!(16);
104array_splitjoin!(32);
105array_splitjoin!(64);
106
107///////////
108// Tests //
109///////////
110
#[cfg(test)]
mod tests {
    use std::fmt::Display;

    use rand::{
        SeedableRng,
        distr::{Distribution, StandardUniform},
        rngs::StdRng,
    };

    use super::*;

    /// Assert that `lo` and `hi` are exactly the first and second halves of `full`.
    ///
    /// `context` is appended to every failure message. It is only formatted if an assertion
    /// actually fails.
    fn test_split<T>(full: &[T], lo: &[T], hi: &[T], context: &dyn Display)
    where
        T: PartialEq + std::fmt::Debug,
    {
        let full_len = full.len();
        assert_eq!(
            full_len % 2,
            0,
            "full length must be even, instead got {} -- {}",
            full_len,
            context
        );
        let half_len = full_len / 2;
        assert_eq!(
            half_len,
            lo.len(),
            "unexpected \"lo\" length -- {}",
            context
        );
        assert_eq!(
            half_len,
            hi.len(),
            "unexpected \"hi\" length -- {}",
            context
        );

        let (full_lo, full_hi) = full.split_at(half_len);

        for (i, (expected, actual)) in full_lo.iter().zip(lo).enumerate() {
            assert_eq!(
                expected, actual,
                "low check failed at index {} -- {}",
                i, context
            );
        }

        for (i, (expected, actual)) in full_hi.iter().zip(hi).enumerate() {
            assert_eq!(
                expected, actual,
                "high check failed at index {} -- {}",
                i, context
            );
        }
    }

    /// Failure-message context that defers formatting the arrays until it is displayed.
    struct Lazy<'a, T> {
        base: &'a [T],
        lo: &'a [T],
        hi: &'a [T],
    }

    impl<T> std::fmt::Display for Lazy<'_, T>
    where
        T: std::fmt::Debug,
    {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(
                f,
                "base = {:?}, lo = {:?}, hi = {:?}",
                self.base, self.lo, self.hi
            )
        }
    }

    /// Generate a randomized split/join round-trip test for `[i8; $len]`.
    macro_rules! test_splitjoin {
        ($fn:ident, $len:literal, $trials:literal, $seed:literal) => {
            #[test]
            fn $fn() {
                const NUM_TRIALS: usize = $trials;
                let mut rng = StdRng::seed_from_u64($seed);
                for _ in 0..NUM_TRIALS {
                    let base: [i8; $len] =
                        core::array::from_fn(|_| StandardUniform.sample(&mut rng));

                    let LoHi { lo, hi } = base.split();

                    let context = Lazy {
                        base: &base,
                        lo: &lo,
                        hi: &hi,
                    };

                    test_split(&base, &lo, &hi, &context);

                    let rejoined = <[i8; $len]>::join(LoHi::new(lo, hi));
                    assert_eq!(base, rejoined);
                }
            }
        };
    }

    test_splitjoin!(test_splitjoin_2, 2, 100, 0x5943d0578df47cdd);
    test_splitjoin!(test_splitjoin_4, 4, 100, 0xc735a1c37c9a8c2c);
    test_splitjoin!(test_splitjoin_8, 8, 100, 0x4dcf648800b9f9b6);
    test_splitjoin!(test_splitjoin_16, 16, 50, 0xf7386a0621134477);
    test_splitjoin!(test_splitjoin_32, 32, 50, 0xb3b0ded762020295);
    test_splitjoin!(test_splitjoin_64, 64, 25, 0x0fc17da7d8a9e1d0);

    /// `map` applies `f` to both halves, invoking it on `lo` before `hi`.
    #[test]
    fn test_lohi_map() {
        let mut order = Vec::new();
        let mapped = LoHi::new(1i32, 2i32).map(|x| {
            order.push(x);
            x * 10
        });
        assert_eq!((mapped.lo, mapped.hi), (10, 20));
        assert_eq!(order, [1, 2]);
    }

    /// `map_with` pairs `lo` with `lo` and `hi` with `hi`, invoking `f` on `lo` first.
    #[test]
    fn test_lohi_map_with() {
        let mut order = Vec::new();
        let combined = LoHi::new(1u8, 2u8).map_with(LoHi::new("a", "b"), |n, s| {
            order.push(n);
            format!("{s}{n}")
        });
        assert_eq!(combined.lo, "a1");
        assert_eq!(combined.hi, "b2");
        assert_eq!(order, [1, 2]);
    }

    /// `LoHi::join` is equivalent to calling `SplitJoin::join` on the target type.
    #[test]
    fn test_lohi_join() {
        let joined: [u16; 4] = LoHi::new([1, 2], [3, 4]).join();
        assert_eq!(joined, [1, 2, 3, 4]);
    }
}