use rand::rngs::StdRng;
use rand::{RngExt, SeedableRng};
/// Strategy used by [`Sampler`] to draw points from the unit interval.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SamplingMethod {
    /// Independent uniform draws.
    MonteCarlo,
    /// Stratified draws: one point per equal-width stratum, in shuffled order.
    LatinHypercube,
}

impl std::str::FromStr for SamplingMethod {
    type Err = String;

    /// Parses a method name, ignoring letter case.
    ///
    /// Accepted spellings are `monte_carlo`, `montecarlo` and `mc` for
    /// [`SamplingMethod::MonteCarlo`], and `latin_hypercube`, `latinhypercube`
    /// and `lhs` for [`SamplingMethod::LatinHypercube`]. Surrounding whitespace
    /// is not trimmed.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message when `s` matches none of the spellings.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        const ALIASES: [(&str, SamplingMethod); 6] = [
            ("monte_carlo", SamplingMethod::MonteCarlo),
            ("montecarlo", SamplingMethod::MonteCarlo),
            ("mc", SamplingMethod::MonteCarlo),
            ("latin_hypercube", SamplingMethod::LatinHypercube),
            ("latinhypercube", SamplingMethod::LatinHypercube),
            ("lhs", SamplingMethod::LatinHypercube),
        ];

        let wanted = s.to_lowercase();
        ALIASES
            .iter()
            .find(|(alias, _)| *alias == wanted.as_str())
            .map(|&(_, method)| method)
            .ok_or_else(|| {
                format!("Unknown sampling method: {s}. Use 'monte_carlo' or 'latin_hypercube'")
            })
    }
}

impl std::fmt::Display for SamplingMethod {
    /// Writes the canonical snake_case name, which `from_str` accepts back.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::MonteCarlo => "monte_carlo",
            Self::LatinHypercube => "latin_hypercube",
        };
        f.write_str(name)
    }
}
/// Draws uniform samples on `[0, 1)` using a configurable [`SamplingMethod`].
///
/// The sampler owns its random number generator, so successive calls continue
/// the same random stream. Two samplers built with the same seed produce
/// identical output.
pub struct Sampler {
    method: SamplingMethod,
    rng: StdRng,
}

impl Sampler {
    /// Largest `f64` strictly below `1.0` (that is, `1 - 2^-53`). Used to keep
    /// stratified samples inside the half-open unit interval.
    const BELOW_ONE: f64 = 1.0 - f64::EPSILON / 2.0;

    /// Creates a sampler that uses `method`.
    ///
    /// With `Some(seed)`, the generator is seeded deterministically through
    /// `StdRng::seed_from_u64`. With `None`, it is seeded from `rand::rng()`.
    #[must_use]
    pub fn new(method: SamplingMethod, seed: Option<u64>) -> Self {
        let rng = seed.map_or_else(|| StdRng::from_rng(&mut rand::rng()), StdRng::seed_from_u64);
        Self { method, rng }
    }

    /// Returns the sampling strategy this sampler was built with.
    #[must_use]
    pub const fn method(&self) -> SamplingMethod {
        self.method
    }

    /// Draws `n` one-dimensional samples in `[0, 1)` with the configured method.
    ///
    /// Returns an empty vector when `n == 0`.
    pub fn generate_uniform_samples(&mut self, n: usize) -> Vec<f64> {
        match self.method {
            SamplingMethod::MonteCarlo => self.monte_carlo_samples(n),
            SamplingMethod::LatinHypercube => self.latin_hypercube_samples(n),
        }
    }

    /// Draws `n` samples in each of `d` dimensions.
    ///
    /// The result is dimension-major: `result[dim][i]` is the `i`-th sample of
    /// dimension `dim`. For Latin hypercube sampling, each dimension is
    /// stratified and shuffled independently.
    pub fn generate_uniform_samples_nd(&mut self, n: usize, d: usize) -> Vec<Vec<f64>> {
        match self.method {
            SamplingMethod::MonteCarlo => (0..d).map(|_| self.monte_carlo_samples(n)).collect(),
            SamplingMethod::LatinHypercube => self.latin_hypercube_samples_nd(n, d),
        }
    }

    /// Draws `n` independent uniform values from the generator.
    fn monte_carlo_samples(&mut self, n: usize) -> Vec<f64> {
        (0..n).map(|_| self.rng.random::<f64>()).collect()
    }

    /// Draws one point uniformly inside each stratum `[i/n, (i+1)/n)`, then
    /// shuffles the points so they are not returned in ascending order.
    fn latin_hypercube_samples(&mut self, n: usize) -> Vec<f64> {
        let mut samples: Vec<f64> = (0..n)
            .map(|i| {
                let lower = i as f64 / n as f64;
                let upper = (i + 1) as f64 / n as f64;
                // `lower + u * (upper - lower)` can round up to `upper` when `u`
                // is close to 1. In the last stratum that gives exactly 1.0, so
                // clamp to keep every sample strictly below 1. Non-degenerate
                // results are left untouched, so seeded output does not change.
                self.rng
                    .random::<f64>()
                    .mul_add(upper - lower, lower)
                    .min(Self::BELOW_ONE)
            })
            .collect();
        // Fisher-Yates shuffle: swap each position with a uniformly chosen
        // position at or below it.
        for i in (1..n).rev() {
            let j = self.rng.random_range(0..=i);
            samples.swap(i, j);
        }
        samples
    }

    /// Builds `d` independent Latin hypercube dimensions of `n` samples each.
    fn latin_hypercube_samples_nd(&mut self, n: usize, d: usize) -> Vec<Vec<f64>> {
        (0..d).map(|_| self.latin_hypercube_samples(n)).collect()
    }

    /// Gives mutable access to the underlying generator, for example to draw
    /// extra values from the same stream.
    pub const fn rng_mut(&mut self) -> &mut StdRng {
        &mut self.rng
    }
}
/// Summary statistics of a slice of samples.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct SampleStats {
    /// Arithmetic mean.
    pub mean: f64,
    /// Population variance: the sum of squared deviations divided by `n`, not `n - 1`.
    pub variance: f64,
    /// Smallest sample.
    pub min: f64,
    /// Largest sample.
    pub max: f64,
}

impl SampleStats {
    /// Computes the mean, population variance, minimum and maximum of `samples`.
    ///
    /// An empty slice yields all-zero statistics (the `Default` value).
    /// NaN entries are ignored by the min/max folds, following `f64::min` and
    /// `f64::max`, but they propagate into `mean` and `variance`.
    #[must_use]
    pub fn from_samples(samples: &[f64]) -> Self {
        if samples.is_empty() {
            return Self::default();
        }
        let n = samples.len() as f64;
        let mean = samples.iter().sum::<f64>() / n;
        // Two-pass variance: subtracting the already-computed mean avoids the
        // cancellation error of the `E[x^2] - E[x]^2` formulation.
        let variance = samples.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / n;
        let (min, max) = samples
            .iter()
            .fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), &x| {
                (lo.min(x), hi.max(x))
            });
        Self {
            mean,
            variance,
            min,
            max,
        }
    }
}
// Exact float comparisons are intentional here: the compared values (means and
// extrema of small integers, all-zero defaults) are exactly representable.
#[allow(clippy::float_cmp)]
// Paired names such as `samples1`/`samples2` and `mc_*`/`lhs_*` are deliberate.
#[allow(clippy::similar_names)]
#[cfg(test)]
mod tests {
    use super::*;
    use std::str::FromStr;

    #[test]
    fn test_sampling_method_from_str() {
        assert_eq!(
            SamplingMethod::from_str("monte_carlo").unwrap(),
            SamplingMethod::MonteCarlo
        );
        assert_eq!(
            SamplingMethod::from_str("latin_hypercube").unwrap(),
            SamplingMethod::LatinHypercube
        );
        assert_eq!(
            SamplingMethod::from_str("LHS").unwrap(),
            SamplingMethod::LatinHypercube
        );
        assert!(SamplingMethod::from_str("invalid").is_err());
    }

    #[test]
    fn test_sampling_method_display_roundtrip() {
        for method in [SamplingMethod::MonteCarlo, SamplingMethod::LatinHypercube] {
            assert_eq!(
                SamplingMethod::from_str(&method.to_string()).unwrap(),
                method
            );
        }
    }

    #[test]
    fn test_monte_carlo_samples() {
        let mut sampler = Sampler::new(SamplingMethod::MonteCarlo, Some(12345));
        let samples = sampler.generate_uniform_samples(1000);
        assert_eq!(samples.len(), 1000);
        assert!(samples.iter().all(|&x| (0.0..1.0).contains(&x)));
        let mean = samples.iter().sum::<f64>() / samples.len() as f64;
        assert!((mean - 0.5).abs() < 0.05);
    }

    #[test]
    fn test_latin_hypercube_samples() {
        let mut sampler = Sampler::new(SamplingMethod::LatinHypercube, Some(12345));
        let samples = sampler.generate_uniform_samples(1000);
        assert_eq!(samples.len(), 1000);
        assert!(samples.iter().all(|&x| (0.0..1.0).contains(&x)));
        let mean = samples.iter().sum::<f64>() / samples.len() as f64;
        assert!((mean - 0.5).abs() < 0.02);
        let n = samples.len();
        let mut stratum_counts = vec![0; n];
        for &sample in &samples {
            #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
            let stratum = (sample * n as f64).floor() as usize;
            if stratum < n {
                stratum_counts[stratum] += 1;
            }
        }
        let variance: f64 = stratum_counts
            .iter()
            .map(|&c| (f64::from(c) - 1.0).powi(2))
            .sum::<f64>()
            / n as f64;
        assert!(
            variance < 0.1,
            "LHS stratum counts should be uniform, variance: {variance}"
        );
    }

    #[test]
    fn test_lhs_small_n_one_per_stratum() {
        let n = 7;
        let mut sampler = Sampler::new(SamplingMethod::LatinHypercube, Some(99));
        let samples = sampler.generate_uniform_samples(n);
        // Each stratum [i/n, (i+1)/n) must hold exactly one sample.
        #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
        let mut strata: Vec<usize> = samples
            .iter()
            .map(|&x| (x * n as f64).floor() as usize)
            .collect();
        strata.sort_unstable();
        assert_eq!(strata, (0..n).collect::<Vec<_>>());
    }

    #[test]
    fn test_zero_and_single_sample() {
        for method in [SamplingMethod::MonteCarlo, SamplingMethod::LatinHypercube] {
            let mut sampler = Sampler::new(method, Some(7));
            assert!(sampler.generate_uniform_samples(0).is_empty());
            let one = sampler.generate_uniform_samples(1);
            assert_eq!(one.len(), 1);
            assert!((0.0..1.0).contains(&one[0]));
            assert!(sampler.generate_uniform_samples_nd(5, 0).is_empty());
        }
    }

    #[test]
    fn test_lhs_better_convergence() {
        let n = 1000;
        let mut mc_variances = Vec::new();
        for seed in 0..10 {
            let mut sampler = Sampler::new(SamplingMethod::MonteCarlo, Some(seed));
            let samples = sampler.generate_uniform_samples(n);
            let mean = samples.iter().sum::<f64>() / n as f64;
            mc_variances.push((mean - 0.5).powi(2));
        }
        let mc_avg_variance: f64 = mc_variances.iter().sum::<f64>() / mc_variances.len() as f64;
        let mut lhs_variances = Vec::new();
        for seed in 0..10 {
            let mut sampler = Sampler::new(SamplingMethod::LatinHypercube, Some(seed));
            let samples = sampler.generate_uniform_samples(n);
            let mean = samples.iter().sum::<f64>() / n as f64;
            lhs_variances.push((mean - 0.5).powi(2));
        }
        let lhs_avg_variance: f64 =
            lhs_variances.iter().sum::<f64>() / lhs_variances.len() as f64;
        assert!(
            lhs_avg_variance < mc_avg_variance,
            "LHS ({lhs_avg_variance}) should have lower variance than MC ({mc_avg_variance})"
        );
    }

    #[test]
    fn test_seed_reproducibility() {
        let mut sampler1 = Sampler::new(SamplingMethod::LatinHypercube, Some(42));
        let samples1 = sampler1.generate_uniform_samples(100);
        let mut sampler2 = Sampler::new(SamplingMethod::LatinHypercube, Some(42));
        let samples2 = sampler2.generate_uniform_samples(100);
        assert_eq!(
            samples1, samples2,
            "Same seed should produce identical results"
        );
    }

    #[test]
    fn test_multidimensional_samples() {
        let mut sampler = Sampler::new(SamplingMethod::LatinHypercube, Some(12345));
        let samples = sampler.generate_uniform_samples_nd(100, 3);
        assert_eq!(samples.len(), 3);
        assert!(samples.iter().all(|dim| dim.len() == 100));
        assert!(samples
            .iter()
            .all(|dim| dim.iter().all(|&x| (0.0..1.0).contains(&x))));
    }

    #[test]
    fn test_sample_stats() {
        let samples = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let stats = SampleStats::from_samples(&samples);
        assert_eq!(stats.mean, 3.0);
        assert_eq!(stats.min, 1.0);
        assert_eq!(stats.max, 5.0);
        assert!((stats.variance - 2.0).abs() < 0.001);
    }

    #[test]
    fn test_sample_stats_empty() {
        let stats = SampleStats::from_samples(&[]);
        assert_eq!(stats.mean, 0.0);
        assert_eq!(stats.variance, 0.0);
        assert_eq!(stats.min, 0.0);
        assert_eq!(stats.max, 0.0);
    }
}