use crate::views::{Matrix, MatrixView};
use rand::{rngs::StdRng, Rng, SeedableRng};
/// Element types that can be sampled from a matrix in a Latin-hypercube-like fashion.
///
/// Any `Copy + Default` type gets this trait through the blanket implementation in
/// this file, so it is normally called as `T::sample_latin_hypercube(..)`.
pub trait SampleLatinHyperCube: Sized + Copy + Default {
    /// Draws `num_samples` rows whose entries are taken, column by column, from `data`.
    ///
    /// * `data` - the source matrix; its column count determines the width of the result.
    /// * `num_samples` - number of rows in the returned matrix.
    /// * `seed` - seed for the random number generator; `None` selects a fixed
    ///   built-in seed, so the result is deterministic either way.
    ///
    /// Returns a matrix with `num_samples` rows and `data.ncols()` columns.
    fn sample_latin_hypercube(
        data: MatrixView<Self>,
        num_samples: usize,
        seed: Option<u64>,
    ) -> Matrix<Self>;
}
impl<T: Sized + Copy + Default> SampleLatinHyperCube for T {
    /// Builds `num_samples` rows from the values of `data`, one column at a time.
    ///
    /// The rows of `data` are split into `num_samples` consecutive strata of
    /// `nrows / num_samples` rows each. Every cell of sample row `s` is copied from
    /// a randomly chosen row of stratum `s` (same column). Afterwards each column of
    /// the result is permuted independently so that the strata are mixed across the
    /// returned rows.
    ///
    /// * `data` - source matrix.
    /// * `num_samples` - number of rows to produce. May exceed `data.nrows()`; in
    ///   that case every stratum collapses to a single source row and rows of
    ///   `data` are reused.
    /// * `seed` - RNG seed; `None` falls back to a fixed constant, so the output
    ///   is always deterministic.
    ///
    /// Returns a `num_samples x data.ncols()` matrix. If `data` has no rows or no
    /// columns, or `num_samples` is zero, the result is filled with `T::default()`.
    fn sample_latin_hypercube(
        data: MatrixView<Self>,
        num_samples: usize,
        seed: Option<u64>,
    ) -> Matrix<Self> {
        let nrows = data.nrows();
        let ncols = data.ncols();
        // Nothing to draw from, or nothing requested. Checking `num_samples` here
        // also keeps the division below from ever seeing a zero divisor.
        if ncols == 0 || nrows == 0 || num_samples == 0 {
            return Matrix::new(T::default(), num_samples, ncols);
        }

        let seed = seed.unwrap_or(0xaf2f5fa0b5161acf);
        let mut rng = StdRng::seed_from_u64(seed);
        let mut result: Matrix<Self> = Matrix::new(T::default(), num_samples, ncols);

        // Width of one stratum; zero when more samples than rows are requested.
        let step = nrows / num_samples;

        for (s, res) in result.row_iter_mut().enumerate() {
            // Half-open row range `lo..hi` of stratum `s`.
            let (lo, hi) = if step > 0 {
                (s * step, (s + 1) * step)
            } else {
                // More samples than rows: an empty range would make `random_range`
                // panic, so map the stratum proportionally onto one existing row.
                // `s < num_samples` guarantees `row < nrows`.
                let row = s * nrows / num_samples;
                (row, row + 1)
            };

            for (idx, val) in res.iter_mut().enumerate() {
                // A fresh row is drawn for every column, so a sample row mixes
                // values from several source rows of the same stratum.
                let value = data
                    .get_row(rng.random_range(lo..hi))
                    .expect("stratum bounds lie within the rows of `data`")
                    .get(idx)
                    .expect("`result` and `data` have the same number of columns");
                *val = *value;
            }
        }

        // Permute every column independently by swapping each entry with a randomly
        // chosen partner in the same column.
        // NOTE(review): the partner is drawn from the full range rather than from
        // `start_idx..num_samples`, so this is not a uniform Fisher-Yates shuffle.
        // Left unchanged so that existing seeds keep producing the same output.
        for start_idx in 0..num_samples {
            for dim_idx in 0..ncols {
                let swap_idx = rng.random_range(0..num_samples);
                let swap = result[(start_idx, dim_idx)];
                result[(start_idx, dim_idx)] = result[(swap_idx, dim_idx)];
                result[(swap_idx, dim_idx)] = swap;
            }
        }
        result
    }
}
#[cfg(test)]
mod tests {
use std::fmt::Display;
use crate::views::{Init, Matrix};
use diskann_vector::conversion::CastFromSlice;
use half::f16;
use rand::{
distr::{Distribution, StandardUniform},
rngs::StdRng,
SeedableRng,
};
use super::*;
/// Returns a fixed 10 x 5 `f32` matrix used as the reference dataset in these tests.
fn example_dataset() -> Matrix<f32> {
    // One matrix row per source line (row-major, 5 columns).
    #[rustfmt::skip]
    let values: Vec<f32> = vec![
        0.203688, 0.841956, 0.855665, 0.801917, 0.754536,
        0.312881, 0.217382, 0.0644115, 0.348708, 0.999495,
        0.657741, 0.914681, 0.555228, 0.13253, 0.118615,
        0.356464, 0.207449, 0.452471, 0.925219, 0.508498,
        0.749786, 0.90786, 0.129618, 0.597719, 0.000622153,
        0.569517, 0.435447, 0.558136, 0.480974, 0.711425,
        0.896353, 0.275053, 0.0427179, 0.660916, 0.464851,
        0.558689, 0.596543, 0.740983, 0.122136, 0.453822,
        0.526895, 0.492643, 0.0951115, 0.495487, 0.446127,
        0.454093, 0.160239, 0.924585, 0.901708, 0.329328,
    ];
    Matrix::<f32>::try_from(values.into(), 10, 5).unwrap()
}
/// Returns a fixed 6 x 5 `u8` matrix used by the `u8` test.
fn example_dataset_u8() -> Matrix<u8> {
    // One matrix row per source line (row-major, 5 columns).
    #[rustfmt::skip]
    let values: Vec<u8> = vec![
        52, 215, 218, 204, 192,
        79, 55, 16, 89, 255,
        167, 233, 141, 33, 30,
        91, 53, 115, 236, 130,
        191, 232, 33, 152, 1,
        145, 111, 142, 122, 181,
    ];
    Matrix::<u8>::try_from(values.into(), 6, 5).unwrap()
}
/// Returns a fixed 6 x 5 `i8` matrix used by the `i8` test.
fn example_dataset_i8() -> Matrix<i8> {
    // One matrix row per source line (row-major, 5 columns).
    // Note: the fourth and fifth rows hold identical values.
    #[rustfmt::skip]
    let values: Vec<i8> = vec![
        -76, 87, 90, 76, 64,
        -49, -73, -112, -39, 127,
        39, 105, 13, -95, -98,
        -37, -75, -13, 108, 2,
        -37, -75, -13, 108, 2,
        17, -17, 14, -6, 53,
    ];
    Matrix::<i8>::try_from(values.into(), 6, 5).unwrap()
}
/// Shared test body, run once per element type `T`.
///
/// Checks three things:
/// 1. inputs without rows or without columns yield a default-filled result,
/// 2. sampling one row from a one-row matrix returns that row unchanged,
/// 3. every value sampled from `data` occurs in the same column of `data`.
fn test_for_type<T>(data: Matrix<T>)
where
    T: SampleLatinHyperCube + PartialEq + std::fmt::Debug + Display,
    StandardUniform: Distribution<T>,
{
    // Degenerate shapes: first a matrix with no rows, then one with no columns.
    for (rows, cols) in [(0, 10), (1, 0)] {
        let degenerate = Matrix::<T>::new(T::default(), rows, cols);
        let expected = Matrix::<T>::new(T::default(), 1, degenerate.ncols());
        assert_eq!(
            T::sample_latin_hypercube(degenerate.as_view(), 1, None),
            expected
        );
    }

    // A single random row of varying width must come back unchanged.
    let mut rng: StdRng = StdRng::seed_from_u64(0xaf2f5fa0b5161acf);
    let uniform = StandardUniform;
    for dim in 1..20 {
        let single = Matrix::<T>::new(Init(|| uniform.sample(&mut rng)), 1, dim);
        let expected =
            Matrix::<T>::try_from(single.row(0).to_vec().into_boxed_slice(), 1, dim).unwrap();
        assert_eq!(
            T::sample_latin_hypercube(single.as_view(), 1, None),
            expected
        );
    }

    // Each sampled value has to exist somewhere in the matching column of `data`.
    let samples = T::sample_latin_hypercube(data.as_view(), 2, None);
    for sample in samples.row_iter() {
        for (col, &val) in sample.iter().enumerate() {
            let present = (0..data.nrows()).any(|row| {
                let cell = data
                    .get_row(row)
                    .expect("Row must exist")
                    .get(col)
                    .expect("Column must exist");
                *cell == val
            });
            assert!(
                present,
                "Value {} in column {} not found in data",
                val,
                col
            );
        }
    }
}
/// Runs the shared checks on the `f32` reference dataset.
#[test]
fn test_f32() {
    let dataset = example_dataset();
    test_for_type(dataset);
}
/// Runs the shared checks on the reference dataset converted to `f16`.
#[test]
fn test_f16() {
    let source = example_dataset();
    // Allocate an `f16` matrix of the same shape, then fill it by casting the `f32` values.
    let (rows, cols) = (source.nrows(), source.ncols());
    let mut halves = Matrix::<f16>::new(f16::default(), rows, cols);
    halves.as_mut_slice().cast_from_slice(source.as_slice());
    test_for_type(halves);
}
/// Runs the shared checks on the `u8` dataset.
#[test]
fn test_u8() {
    let dataset = example_dataset_u8();
    test_for_type(dataset);
}
/// Runs the shared checks on the `i8` dataset.
#[test]
fn test_i8() {
    let dataset = example_dataset_i8();
    test_for_type(dataset);
}
}