1use crate::views::{Matrix, MatrixView};
7use rand::{rngs::StdRng, Rng, SeedableRng};
8
9pub trait SampleLatinHyperCube: Sized + Copy + Default {
12 fn sample_latin_hypercube(
13 data: MatrixView<Self>,
14 num_samples: usize,
15 seed: Option<u64>,
16 ) -> Matrix<Self>;
17}
18
19impl<T: Sized + Copy + Default> SampleLatinHyperCube for T {
20 fn sample_latin_hypercube(
21 data: MatrixView<Self>,
22 num_samples: usize,
23 seed: Option<u64>,
24 ) -> Matrix<Self> {
25 let nrows = data.nrows();
26 let ncols = data.ncols();
27 if ncols == 0 || nrows == 0 {
28 return Matrix::new(T::default(), num_samples, ncols);
29 }
30
31 let seed = seed.unwrap_or(0xaf2f5fa0b5161acf);
32 let mut rng = StdRng::seed_from_u64(seed);
33 let mut result: Matrix<Self> = Matrix::new(T::default(), num_samples, ncols);
34
35 for (s, res) in result.row_iter_mut().enumerate() {
37 for (idx, val) in res.iter_mut().enumerate() {
38 let step = nrows / num_samples;
39 let value = data
40 .get_row(rng.random_range(s * step..(s + 1) * step))
41 .unwrap()
42 .get(idx)
43 .unwrap();
44 *val = *value;
45 }
46 }
47
48 for start_idx in 0..num_samples {
50 for dim_idx in 0..ncols {
51 let swap_idx = rng.random_range(0..num_samples);
52 let swap = result[(start_idx, dim_idx)];
53 result[(start_idx, dim_idx)] = result[(swap_idx, dim_idx)];
54 result[(swap_idx, dim_idx)] = swap;
55 }
56 }
57
58 result
59 }
60}
61
#[cfg(test)]
mod tests {
    use std::fmt::{Debug, Display};

    use crate::views::{Init, Matrix};
    use diskann_vector::conversion::CastFromSlice;
    use half::f16;
    use rand::{
        distr::{Distribution, StandardUniform},
        rngs::StdRng,
        SeedableRng,
    };

    use super::*;

    /// 10 x 5 `f32` fixture with all values in `[0, 1)`.
    #[rustfmt::skip]
    fn example_dataset() -> Matrix<f32> {
        let data: Vec<f32> = vec![
            0.203688, 0.841956, 0.855665,  0.801917, 0.754536,
            0.312881, 0.217382, 0.0644115, 0.348708, 0.999495,
            0.657741, 0.914681, 0.555228,  0.13253,  0.118615,
            0.356464, 0.207449, 0.452471,  0.925219, 0.508498,
            0.749786, 0.90786,  0.129618,  0.597719, 0.000622153,
            0.569517, 0.435447, 0.558136,  0.480974, 0.711425,
            0.896353, 0.275053, 0.0427179, 0.660916, 0.464851,
            0.558689, 0.596543, 0.740983,  0.122136, 0.453822,
            0.526895, 0.492643, 0.0951115, 0.495487, 0.446127,
            0.454093, 0.160239, 0.924585,  0.901708, 0.329328,
        ];

        Matrix::<f32>::try_from(data.into(), 10, 5).unwrap()
    }

    /// 6 x 5 `u8` fixture.
    #[rustfmt::skip]
    fn example_dataset_u8() -> Matrix<u8> {
        let data: Vec<u8> = vec![
            52,  215, 218, 204, 192,
            79,  55,  16,  89,  255,
            167, 233, 141, 33,  30,
            91,  53,  115, 236, 130,
            191, 232, 33,  152, 1,
            145, 111, 142, 122, 181,
        ];

        Matrix::<u8>::try_from(data.into(), 6, 5).unwrap()
    }

    /// 6 x 5 `i8` fixture: the signed counterpart of `example_dataset_u8`.
    /// Each entry equals the corresponding `u8` entry minus 128.
    #[rustfmt::skip]
    fn example_dataset_i8() -> Matrix<i8> {
        let data: Vec<i8> = vec![
            -76, 87,  90,   76,  64,
            -49, -73, -112, -39, 127,
            39,  105, 13,   -95, -98,
            -37, -75, -13,  108, 2,
            // This row previously duplicated the one above (copy-paste slip).
            63,  104, -95,  24,  -127,
            17,  -17, 14,   -6,  53,
        ];

        Matrix::<i8>::try_from(data.into(), 6, 5).unwrap()
    }

    /// Checks shared by every element type:
    /// - degenerate (zero-row / zero-column) inputs yield default-filled output;
    /// - sampling a single-row matrix once reproduces that row exactly;
    /// - every value sampled from `data` occurs in the same column of `data`.
    fn test_for_type<T>(data: Matrix<T>)
    where
        T: SampleLatinHyperCube + PartialEq + Debug + Display,
        StandardUniform: Distribution<T>,
    {
        let no_rows = Matrix::<T>::new(T::default(), 0, 10);
        assert_eq!(
            T::sample_latin_hypercube(no_rows.as_view(), 1, None),
            Matrix::<T>::new(T::default(), 1, no_rows.ncols())
        );

        let no_cols = Matrix::<T>::new(T::default(), 1, 0);
        assert_eq!(
            T::sample_latin_hypercube(no_cols.as_view(), 1, None),
            Matrix::<T>::new(T::default(), 1, no_cols.ncols())
        );

        let mut rng: StdRng = StdRng::seed_from_u64(0xaf2f5fa0b5161acf);
        let dist = StandardUniform;
        for dim in 1..20 {
            let x = Matrix::<T>::new(Init(|| dist.sample(&mut rng)), 1, dim);
            assert_eq!(
                T::sample_latin_hypercube(x.as_view(), 1, None),
                Matrix::<T>::try_from(x.row(0).to_vec().into_boxed_slice(), 1, dim).unwrap()
            );
        }

        // Gather each input column once rather than once per sampled value.
        let columns: Vec<Vec<T>> = (0..data.ncols())
            .map(|col| {
                (0..data.nrows())
                    .map(|row| {
                        *data
                            .get_row(row)
                            .expect("Row must exist")
                            .get(col)
                            .expect("Column must exist")
                    })
                    .collect()
            })
            .collect();

        let starts = T::sample_latin_hypercube(data.as_view(), 2, None);
        for sample in starts.row_iter() {
            for (col, &val) in sample.iter().enumerate() {
                assert!(
                    columns[col].contains(&val),
                    "Value {} in column {} not found in data",
                    val,
                    col
                );
            }
        }
    }

    #[test]
    fn test_f32() {
        test_for_type(example_dataset())
    }

    #[test]
    fn test_f16() {
        let data = example_dataset();
        let mut data_f16 = Matrix::<f16>::new(f16::default(), data.nrows(), data.ncols());
        data_f16.as_mut_slice().cast_from_slice(data.as_slice());
        test_for_type(data_f16);
    }

    #[test]
    fn test_u8() {
        test_for_type(example_dataset_u8());
    }

    #[test]
    fn test_i8() {
        test_for_type(example_dataset_i8());
    }

    /// Guards the fixtures against drifting apart again: i8 == u8 - 128.
    #[test]
    fn i8_fixture_is_shifted_u8_fixture() {
        let unsigned = example_dataset_u8();
        let signed = example_dataset_i8();
        assert_eq!(unsigned.as_slice().len(), signed.as_slice().len());
        for (&u, &s) in unsigned.as_slice().iter().zip(signed.as_slice()) {
            assert_eq!(i16::from(s), i16::from(u) - 128);
        }
    }
}