use crate::Variable;
use rand::{rngs::StdRng, seq::SliceRandom, Rng};
/// Strategy used to generate an initial set of sample points within variable bounds.
///
/// Equality is derived so that configurations can be compared (e.g. in tests or when
/// checking whether a run uses the same initialisation as another).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Initialisation {
    /// Latin hypercube sampling: each variable's range is split into `n_points` equal strata,
    /// every stratum receives exactly one value, and the per-variable values are shuffled
    /// independently before being combined into points.
    LatinHypercube {
        /// When `true`, each value sits at the midpoint of its stratum; when `false`, it is
        /// placed at a uniformly random offset within the stratum.
        centred: bool,
    },
    /// Each coordinate is drawn independently and uniformly within its variable's bounds.
    Random,
}
impl Initialisation {
    /// Generates `n_points` sample points, one coordinate per entry of `variables`.
    ///
    /// The result has `n_points` rows. Each row has `variables.len()` entries, and entry `j`
    /// is derived from the bounds `(variables[j].0, variables[j].1)`.
    ///
    /// * [`Initialisation::LatinHypercube`] delegates to [`latin_hypercube_sample`].
    /// * [`Initialisation::Random`] draws each coordinate independently with
    ///   `rng.gen_range(lower..upper)`, i.e. uniformly from the half-open range
    ///   `[lower, upper)`. A fixed-value variable (`lower == upper`) always yields `lower`
    ///   and consumes no randomness. This matches the Latin hypercube path, which maps a
    ///   zero-width range to its lower bound.
    ///
    /// # Panics
    ///
    /// For [`Initialisation::Random`], panics when `rand`'s `gen_range` rejects a variable's
    /// range, e.g. when `lower > upper`.
    pub fn generate_samples(
        &self,
        n_points: usize,
        variables: &[Variable],
        rng: &mut StdRng,
    ) -> Vec<Vec<f64>> {
        match self {
            Initialisation::LatinHypercube { centred } => {
                latin_hypercube_sample(n_points, variables, *centred, rng)
            }
            Initialisation::Random => (0..n_points)
                .map(|_| {
                    variables
                        .iter()
                        .map(|var| {
                            // `gen_range` panics on an empty range, so a degenerate variable
                            // (equal bounds) must be handled explicitly rather than sampled.
                            if var.0 == var.1 {
                                var.0
                            } else {
                                rng.gen_range(var.0..var.1)
                            }
                        })
                        .collect()
                })
                .collect(),
        }
    }
}
/// Draws `n_points` Latin-hypercube samples over the box described by `variables`.
///
/// The unit interval is cut into `n_points` strata of width `1 / n_points`. For every
/// variable, one value is produced per stratum, in stratum order:
///
/// * at the stratum midpoint when `centred` is `true`;
/// * otherwise at a uniformly random offset drawn with `rng.gen::<f64>()`.
///
/// Each value is then mapped affinely onto `[lower, upper]` using `value * (upper - lower) +
/// lower`. If `lower > upper`, the values therefore land between `upper` and `lower`.
///
/// Each variable's column of values is shuffled independently. The columns are then
/// transposed, so every returned inner vector is one point with `variables.len()`
/// coordinates.
///
/// Returns an empty vector when `n_points == 0`, without touching `rng`. When `variables`
/// is empty, the result is `n_points` empty points.
pub fn latin_hypercube_sample(
    n_points: usize,
    variables: &[Variable],
    centred: bool,
    rng: &mut StdRng,
) -> Vec<Vec<f64>> {
    if n_points == 0 {
        return Vec::new();
    }

    let width = 1.0 / (n_points as f64);

    // Build one column (all stratum values) per variable, shuffling each as we go.
    let mut columns: Vec<Vec<f64>> = Vec::with_capacity(variables.len());
    for var in variables {
        let lower = var.0;
        let span = var.1 - lower;

        let mut column = Vec::with_capacity(n_points);
        for stratum in 0..n_points {
            let offset = if !centred {
                rng.gen::<f64>() * width
            } else {
                0.5 * width
            };
            let unit = stratum as f64 * width + offset;
            column.push(unit * span + lower);
        }

        column.shuffle(rng);
        columns.push(column);
    }

    // Transpose columns into rows: row `r` collects the r-th entry of every column.
    let mut rows: Vec<Vec<f64>> = (0..n_points)
        .map(|_| Vec::with_capacity(columns.len()))
        .collect();
    for column in &columns {
        for (row, &value) in rows.iter_mut().zip(column) {
            row.push(value);
        }
    }
    rows
}
#[cfg(test)]
mod tests {
    use super::*;
    use rand::SeedableRng;

    /// Uniform random sampling returns the requested shape and stays within each
    /// variable's bounds for every point, not just the first.
    #[test]
    fn test_random_initialisation() {
        let vars = vec![Variable(-5.0, 5.0), Variable(0.0, 10.0)];
        let n_points = 10;
        let mut rng = StdRng::seed_from_u64(42);
        let init = Initialisation::Random;
        let samples = init.generate_samples(n_points, &vars, &mut rng);
        assert_eq!(samples.len(), n_points);
        for sample in samples {
            assert_eq!(sample.len(), vars.len());
            assert!(sample[0] >= -5.0 && sample[0] <= 5.0);
            assert!(sample[1] >= 0.0 && sample[1] <= 10.0);
        }
    }

    /// Non-centred LHS places exactly one point in each stratum of every dimension.
    #[test]
    fn test_lhs_properties() {
        let vars = vec![Variable(0.0, 1.0), Variable(100.0, 200.0)];
        let n_points = 10;
        let mut rng = StdRng::seed_from_u64(42);
        let init = Initialisation::LatinHypercube { centred: false };
        let samples = init.generate_samples(n_points, &vars, &mut rng);
        assert_eq!(samples.len(), n_points);
        for sample in &samples {
            assert_eq!(sample.len(), vars.len());
            assert!(sample[0] >= 0.0 && sample[0] <= 1.0);
            assert!(sample[1] >= 100.0 && sample[1] <= 200.0);
        }
        for dim in 0..vars.len() {
            let mut interval_counts = vec![0; n_points];
            let (min, max) = (vars[dim].0, vars[dim].1);
            let interval_size = (max - min) / n_points as f64;
            for sample in &samples {
                let value = sample[dim];
                let interval_index = ((value - min) / interval_size).floor() as usize;
                // Guard against a value landing exactly on the upper bound.
                let clamped_index = interval_index.min(n_points - 1);
                interval_counts[clamped_index] += 1;
            }
            assert!(
                interval_counts.iter().all(|&count| count == 1),
                "LHS property failed for dimension {}",
                dim
            );
        }
    }

    /// Centred LHS puts every point exactly at a stratum midpoint. On [0, 10] with 10
    /// points the sorted values must be 0.5, 1.5, ..., 9.5, checked for every stratum.
    #[test]
    fn test_lhs_centred() {
        let vars = vec![Variable(0.0, 10.0)];
        let n_points = 10;
        let mut rng = StdRng::seed_from_u64(42);
        let init = Initialisation::LatinHypercube { centred: true };
        let samples = init.generate_samples(n_points, &vars, &mut rng);
        assert_eq!(samples.len(), n_points);

        let mut values: Vec<f64> = samples.iter().map(|s| s[0]).collect();
        values.sort_by(|a, b| a.partial_cmp(b).expect("centred samples are finite"));
        for (i, value) in values.iter().enumerate() {
            let expected = i as f64 + 0.5;
            assert!(
                (value - expected).abs() < 1e-9,
                "stratum {} expected midpoint {}, got {}",
                i,
                expected,
                value
            );
        }
    }

    /// Requesting zero points yields no samples for every strategy.
    #[test]
    fn test_zero_points_yields_no_samples() {
        let vars = vec![Variable(0.0, 1.0), Variable(-3.0, 3.0)];
        let mut rng = StdRng::seed_from_u64(7);
        for init in [
            Initialisation::Random,
            Initialisation::LatinHypercube { centred: true },
            Initialisation::LatinHypercube { centred: false },
        ] {
            assert!(init.generate_samples(0, &vars, &mut rng).is_empty());
        }
    }

    /// Identical seeds must reproduce identical samples (runs are replayable).
    #[test]
    fn test_same_seed_gives_same_samples() {
        let vars = vec![Variable(-1.0, 1.0), Variable(10.0, 20.0)];
        for init in [
            Initialisation::Random,
            Initialisation::LatinHypercube { centred: false },
        ] {
            let first = init.generate_samples(8, &vars, &mut StdRng::seed_from_u64(3));
            let second = init.generate_samples(8, &vars, &mut StdRng::seed_from_u64(3));
            assert_eq!(first, second);
        }
    }
}