Skip to main content

diskann_utils/sampling/
latin_hypercube.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use crate::views::{Matrix, MatrixView};
7use rand::{rngs::StdRng, Rng, SeedableRng};
8
9/// Return multiple rows sampled using Latin Hypercube Sampling in `data` that aproximetely uniformly distributed.
10/// This makes the assumtion that the data is uniformly distributed.
11pub trait SampleLatinHyperCube: Sized + Copy + Default {
12    fn sample_latin_hypercube(
13        data: MatrixView<Self>,
14        num_samples: usize,
15        seed: Option<u64>,
16    ) -> Matrix<Self>;
17}
18
19impl<T: Sized + Copy + Default> SampleLatinHyperCube for T {
20    fn sample_latin_hypercube(
21        data: MatrixView<Self>,
22        num_samples: usize,
23        seed: Option<u64>,
24    ) -> Matrix<Self> {
25        let nrows = data.nrows();
26        let ncols = data.ncols();
27        if ncols == 0 || nrows == 0 {
28            return Matrix::new(T::default(), num_samples, ncols);
29        }
30
31        let seed = seed.unwrap_or(0xaf2f5fa0b5161acf);
32        let mut rng = StdRng::seed_from_u64(seed);
33        let mut result: Matrix<Self> = Matrix::new(T::default(), num_samples, ncols);
34
35        // sample a random partitions down the diagonal
36        for (s, res) in result.row_iter_mut().enumerate() {
37            for (idx, val) in res.iter_mut().enumerate() {
38                let step = nrows / num_samples;
39                let value = data
40                    .get_row(rng.random_range(s * step..(s + 1) * step))
41                    .unwrap()
42                    .get(idx)
43                    .unwrap();
44                *val = *value;
45            }
46        }
47
48        // shuffle the dimensions between the vectors for random sampling
49        for start_idx in 0..num_samples {
50            for dim_idx in 0..ncols {
51                let swap_idx = rng.random_range(0..num_samples);
52                let swap = result[(start_idx, dim_idx)];
53                result[(start_idx, dim_idx)] = result[(swap_idx, dim_idx)];
54                result[(swap_idx, dim_idx)] = swap;
55            }
56        }
57
58        result
59    }
60}
61
62///////////
63// Tests //
64///////////
65
66#[cfg(test)]
67mod tests {
68    use std::fmt::Display;
69
70    use crate::views::{Init, Matrix};
71    use diskann_vector::conversion::CastFromSlice;
72    use half::f16;
73    use rand::{
74        distr::{Distribution, StandardUniform},
75        rngs::StdRng,
76        SeedableRng,
77    };
78
79    use super::*;
80
81    fn example_dataset() -> Matrix<f32> {
82        let data: Vec<f32> = vec![
83            // row 0
84            0.203688,
85            0.841956,
86            0.855665,
87            0.801917,
88            0.754536,
89            // row 1
90            0.312881,
91            0.217382,
92            0.0644115,
93            0.348708,
94            0.999495,
95            // row 2
96            0.657741,
97            0.914681,
98            0.555228,
99            0.13253,
100            0.118615,
101            // row 3
102            0.356464,
103            0.207449,
104            0.452471,
105            0.925219,
106            0.508498,
107            // row 4
108            0.749786,
109            0.90786,
110            0.129618,
111            0.597719,
112            0.000622153,
113            // row 5 -- this is the medoid
114            0.569517,
115            0.435447,
116            0.558136,
117            0.480974,
118            0.711425,
119            // row 6
120            0.896353,
121            0.275053,
122            0.0427179,
123            0.660916,
124            0.464851,
125            // row 7
126            0.558689,
127            0.596543,
128            0.740983,
129            0.122136,
130            0.453822,
131            // row 8
132            0.526895,
133            0.492643,
134            0.0951115,
135            0.495487,
136            0.446127,
137            // row 9
138            0.454093,
139            0.160239,
140            0.924585,
141            0.901708,
142            0.329328,
143        ];
144
145        Matrix::<f32>::try_from(data.into(), 10, 5).unwrap()
146    }
147
148    fn example_dataset_u8() -> Matrix<u8> {
149        let data: Vec<u8> = vec![
150            52, 215, 218, 204, 192, // row 0
151            79, 55, 16, 89, 255, // row 1
152            167, 233, 141, 33, 30, // row 2
153            91, 53, 115, 236, 130, // row 3
154            191, 232, 33, 152, 1, // row 4
155            145, 111, 142, 122, 181, // row 5 -- this is the medoid
156        ];
157
158        Matrix::<u8>::try_from(data.into(), 6, 5).unwrap()
159    }
160
161    // This is a test for the i8 function. Each entry is between -128 and 127.
162    fn example_dataset_i8() -> Matrix<i8> {
163        let data: Vec<i8> = vec![
164            -76, 87, 90, 76, 64, // row 0
165            -49, -73, -112, -39, 127, // row 1
166            39, 105, 13, -95, -98, // row 2
167            -37, -75, -13, 108, 2, // row 3
168            -37, -75, -13, 108, 2, // row 4
169            17, -17, 14, -6, 53, // row 5 -- this is the medoid
170        ];
171
172        Matrix::<i8>::try_from(data.into(), 6, 5).unwrap()
173    }
174
175    fn test_for_type<T>(data: Matrix<T>)
176    where
177        T: SampleLatinHyperCube + PartialEq + std::fmt::Debug + Display,
178        StandardUniform: Distribution<T>,
179    {
180        // No Rows
181        let x = Matrix::<T>::new(T::default(), 0, 10);
182        assert_eq!(
183            T::sample_latin_hypercube(x.as_view(), 1, None),
184            Matrix::<T>::new(T::default(), 1, x.ncols())
185        );
186
187        // No Cols0
188        let x = Matrix::<T>::new(T::default(), 1, 0);
189        assert_eq!(
190            T::sample_latin_hypercube(x.as_view(), 1, None),
191            Matrix::<T>::new(T::default(), 1, x.ncols())
192        );
193
194        let mut rng: StdRng = StdRng::seed_from_u64(0xaf2f5fa0b5161acf);
195
196        // One row
197        let dist = StandardUniform;
198        for dim in 1..20 {
199            let x = Matrix::<T>::new(Init(|| dist.sample(&mut rng)), 1, dim);
200            assert_eq!(
201                T::sample_latin_hypercube(x.as_view(), 1, None),
202                Matrix::<T>::try_from(x.row(0).to_vec().into_boxed_slice(), 1, dim).unwrap()
203            );
204        }
205
206        // Example dataset
207        let starts = T::sample_latin_hypercube(data.as_view(), 2, None);
208        for s in starts.row_iter() {
209            for (col, &val) in s.iter().enumerate() {
210                let col_vals: Vec<T> = (0..data.nrows())
211                    .map(|row| {
212                        *data
213                            .get_row(row)
214                            .expect("Row must exist")
215                            .get(col)
216                            .expect("Column must exist")
217                    })
218                    .collect();
219                assert!(
220                    col_vals.contains(&val),
221                    "Value {} in column {} not found in data",
222                    val,
223                    col
224                );
225            }
226        }
227    }
228
229    #[test]
230    fn test_f32() {
231        test_for_type(example_dataset())
232    }
233
234    #[test]
235    fn test_f16() {
236        let data = example_dataset();
237        let mut data_f16 = Matrix::<f16>::new(f16::default(), data.nrows(), data.ncols());
238        data_f16.as_mut_slice().cast_from_slice(data.as_slice());
239        test_for_type(data_f16);
240    }
241
242    #[test]
243    fn test_u8() {
244        test_for_type(example_dataset_u8());
245    }
246
247    #[test]
248    fn test_i8() {
249        test_for_type(example_dataset_i8());
250    }
251}