mollendorff_forge/monte_carlo/
sampler.rs1use rand::rngs::StdRng;
8use rand::{RngExt, SeedableRng};
9
/// Strategy a [`Sampler`] uses to draw uniform samples.
///
/// Parsing is case-insensitive and accepts `"monte_carlo"`, `"montecarlo"`,
/// `"mc"`, `"latin_hypercube"`, `"latinhypercube"` and `"lhs"`. Formatting
/// with `Display` yields the canonical snake_case name, which parses back to
/// the same variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SamplingMethod {
    /// Independent pseudo-random draws.
    MonteCarlo,
    /// Latin hypercube sampling: one draw per equal-width stratum, returned
    /// in shuffled order.
    LatinHypercube,
}

impl std::str::FromStr for SamplingMethod {
    type Err = String;

    /// Parses a method name (case-insensitive); unknown names produce an
    /// error message that lists the canonical names.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let normalized = s.to_lowercase();
        let method = match normalized.as_str() {
            "mc" | "monte_carlo" | "montecarlo" => Self::MonteCarlo,
            "lhs" | "latin_hypercube" | "latinhypercube" => Self::LatinHypercube,
            _ => {
                return Err(format!(
                    "Unknown sampling method: {s}. Use 'monte_carlo' or 'latin_hypercube'"
                ))
            }
        };
        Ok(method)
    }
}

impl std::fmt::Display for SamplingMethod {
    /// Writes the canonical snake_case name accepted by `FromStr`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::MonteCarlo => "monte_carlo",
            Self::LatinHypercube => "latin_hypercube",
        };
        f.write_str(name)
    }
}
41
42pub struct Sampler {
44 method: SamplingMethod,
45 rng: StdRng,
46}
47
48impl Sampler {
49 #[must_use]
51 pub fn new(method: SamplingMethod, seed: Option<u64>) -> Self {
52 let rng = seed.map_or_else(|| StdRng::from_rng(&mut rand::rng()), StdRng::seed_from_u64);
53 Self { method, rng }
54 }
55
56 #[must_use]
58 pub const fn method(&self) -> SamplingMethod {
59 self.method
60 }
61
62 pub fn generate_uniform_samples(&mut self, n: usize) -> Vec<f64> {
66 match self.method {
67 SamplingMethod::MonteCarlo => self.monte_carlo_samples(n),
68 SamplingMethod::LatinHypercube => self.latin_hypercube_samples(n),
69 }
70 }
71
72 pub fn generate_uniform_samples_nd(&mut self, n: usize, d: usize) -> Vec<Vec<f64>> {
75 match self.method {
76 SamplingMethod::MonteCarlo => (0..d).map(|_| self.monte_carlo_samples(n)).collect(),
77 SamplingMethod::LatinHypercube => self.latin_hypercube_samples_nd(n, d),
78 }
79 }
80
81 fn monte_carlo_samples(&mut self, n: usize) -> Vec<f64> {
83 (0..n).map(|_| self.rng.random::<f64>()).collect()
84 }
85
86 fn latin_hypercube_samples(&mut self, n: usize) -> Vec<f64> {
88 let mut samples: Vec<f64> = (0..n)
92 .map(|i| {
93 let lower = i as f64 / n as f64;
94 let upper = (i + 1) as f64 / n as f64;
95 self.rng.random::<f64>().mul_add(upper - lower, lower)
96 })
97 .collect();
98
99 for i in (1..n).rev() {
101 let j = self.rng.random_range(0..=i);
102 samples.swap(i, j);
103 }
104
105 samples
106 }
107
108 fn latin_hypercube_samples_nd(&mut self, n: usize, d: usize) -> Vec<Vec<f64>> {
111 (0..d).map(|_| self.latin_hypercube_samples(n)).collect()
112 }
113
114 pub const fn rng_mut(&mut self) -> &mut StdRng {
116 &mut self.rng
117 }
118}
119
/// Summary statistics of a slice of samples.
#[derive(Debug, Clone)]
pub struct SampleStats {
    /// Arithmetic mean.
    pub mean: f64,
    /// Population variance: the sum of squared deviations divided by `n`
    /// (not `n - 1`).
    pub variance: f64,
    /// Smallest sample.
    pub min: f64,
    /// Largest sample.
    pub max: f64,
}

impl SampleStats {
    /// Computes the mean, population variance, minimum and maximum of `samples`.
    ///
    /// An empty slice yields all-zero statistics. Because `f64::min` and
    /// `f64::max` ignore NaN operands, NaN samples do not affect `min`/`max`,
    /// but they do make `mean` and `variance` NaN.
    pub fn from_samples(samples: &[f64]) -> Self {
        let [_, ..] = samples else {
            return Self {
                mean: 0.0,
                variance: 0.0,
                min: 0.0,
                max: 0.0,
            };
        };

        let count = samples.len() as f64;
        let mean = samples.iter().sum::<f64>() / count;

        // Two-pass variance: deviations are taken from the already-known mean.
        let squared_deviations = samples.iter().map(|&v| (v - mean).powi(2));
        let variance = squared_deviations.sum::<f64>() / count;

        let (min, max) = samples.iter().fold(
            (f64::INFINITY, f64::NEG_INFINITY),
            |(lo, hi), &v| (lo.min(v), hi.max(v)),
        );

        Self {
            mean,
            variance,
            min,
            max,
        }
    }
}
155
// Exact float comparisons are intentional: the asserted values (e.g. 3.0, 1.0,
// 0.0) are exactly representable and produced without rounding error.
#[allow(clippy::float_cmp)]
// Paired names such as `sampler1`/`sampler2` are deliberate in the
// reproducibility test.
#[allow(clippy::similar_names)]
#[cfg(test)]
mod tests {
    use super::*;
    use std::str::FromStr;

    #[test]
    fn test_sampling_method_from_str() {
        assert_eq!(
            SamplingMethod::from_str("monte_carlo").unwrap(),
            SamplingMethod::MonteCarlo
        );
        assert_eq!(
            SamplingMethod::from_str("latin_hypercube").unwrap(),
            SamplingMethod::LatinHypercube
        );
        assert_eq!(
            SamplingMethod::from_str("LHS").unwrap(),
            SamplingMethod::LatinHypercube
        );
        assert!(SamplingMethod::from_str("invalid").is_err());
    }

    #[test]
    fn test_sampling_method_display_round_trip() {
        assert_eq!(SamplingMethod::MonteCarlo.to_string(), "monte_carlo");
        assert_eq!(SamplingMethod::LatinHypercube.to_string(), "latin_hypercube");
        for method in [SamplingMethod::MonteCarlo, SamplingMethod::LatinHypercube] {
            let text = method.to_string();
            assert_eq!(SamplingMethod::from_str(&text).unwrap(), method);
        }
    }

    #[test]
    fn test_monte_carlo_samples() {
        let mut sampler = Sampler::new(SamplingMethod::MonteCarlo, Some(12345));
        let samples = sampler.generate_uniform_samples(1000);

        assert_eq!(samples.len(), 1000);
        assert!(samples.iter().all(|&x| (0.0..1.0).contains(&x)));

        // The standard error of the mean is ~0.009 for n = 1000, so 0.05 is a
        // comfortable bound.
        let mean = samples.iter().sum::<f64>() / samples.len() as f64;
        assert!((mean - 0.5).abs() < 0.05);
    }

    #[test]
    fn test_latin_hypercube_samples() {
        let n = 1000;
        let mut sampler = Sampler::new(SamplingMethod::LatinHypercube, Some(12345));
        let samples = sampler.generate_uniform_samples(n);

        assert_eq!(samples.len(), n);
        assert!(samples.iter().all(|&x| (0.0..1.0).contains(&x)));

        let mean = samples.iter().sum::<f64>() / n as f64;
        assert!((mean - 0.5).abs() < 0.02);

        // Defining LHS property: every one of the n equal-width strata holds
        // exactly one sample. Out-of-range strata are a failure, not skipped.
        let mut stratum_counts = vec![0_u32; n];
        for &sample in &samples {
            #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
            let stratum = (sample * n as f64).floor() as usize;
            assert!(stratum < n, "sample {sample} maps outside the {n} strata");
            stratum_counts[stratum] += 1;
        }
        let misfilled = stratum_counts.iter().filter(|&&c| c != 1).count();
        assert_eq!(
            misfilled, 0,
            "LHS must place exactly one sample per stratum; {misfilled} strata differ"
        );
    }

    #[test]
    fn test_lhs_better_convergence() {
        let n = 1000;

        // Squared error of the sample mean against the true mean 0.5, averaged
        // over several seeds for each method.
        let mut mc_sq_errors = Vec::new();
        for seed in 0..10 {
            let mut sampler = Sampler::new(SamplingMethod::MonteCarlo, Some(seed));
            let samples = sampler.generate_uniform_samples(n);
            let mean = samples.iter().sum::<f64>() / n as f64;
            mc_sq_errors.push((mean - 0.5).powi(2));
        }
        let mc_avg_variance: f64 = mc_sq_errors.iter().sum::<f64>() / mc_sq_errors.len() as f64;

        let mut lhs_sq_errors = Vec::new();
        for seed in 0..10 {
            let mut sampler = Sampler::new(SamplingMethod::LatinHypercube, Some(seed));
            let samples = sampler.generate_uniform_samples(n);
            let mean = samples.iter().sum::<f64>() / n as f64;
            lhs_sq_errors.push((mean - 0.5).powi(2));
        }
        let lhs_avg_variance: f64 =
            lhs_sq_errors.iter().sum::<f64>() / lhs_sq_errors.len() as f64;

        assert!(
            lhs_avg_variance < mc_avg_variance,
            "LHS ({lhs_avg_variance}) should have lower variance than MC ({mc_avg_variance})"
        );
    }

    #[test]
    fn test_seed_reproducibility() {
        for method in [SamplingMethod::MonteCarlo, SamplingMethod::LatinHypercube] {
            let mut sampler1 = Sampler::new(method, Some(42));
            let samples1 = sampler1.generate_uniform_samples(100);

            let mut sampler2 = Sampler::new(method, Some(42));
            let samples2 = sampler2.generate_uniform_samples(100);

            assert_eq!(
                samples1, samples2,
                "Same seed should produce identical results"
            );
        }
    }

    #[test]
    fn test_zero_sized_requests() {
        for method in [SamplingMethod::MonteCarlo, SamplingMethod::LatinHypercube] {
            let mut sampler = Sampler::new(method, Some(7));
            assert_eq!(sampler.method(), method);
            assert!(sampler.generate_uniform_samples(0).is_empty());
            assert!(sampler.generate_uniform_samples_nd(5, 0).is_empty());

            let empty_dims = sampler.generate_uniform_samples_nd(0, 2);
            assert_eq!(empty_dims.len(), 2);
            assert!(empty_dims.iter().all(Vec::is_empty));
        }
    }

    #[test]
    fn test_multidimensional_samples() {
        let mut sampler = Sampler::new(SamplingMethod::LatinHypercube, Some(12345));
        let samples = sampler.generate_uniform_samples_nd(100, 3);

        assert_eq!(samples.len(), 3);
        assert!(samples.iter().all(|dim| dim.len() == 100));
        assert!(samples
            .iter()
            .all(|dim| dim.iter().all(|&x| (0.0..1.0).contains(&x))));
    }

    #[test]
    fn test_sample_stats() {
        let samples = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let stats = SampleStats::from_samples(&samples);

        assert_eq!(stats.mean, 3.0);
        assert_eq!(stats.min, 1.0);
        assert_eq!(stats.max, 5.0);
        assert!((stats.variance - 2.0).abs() < 0.001);
    }

    #[test]
    fn test_sample_stats_empty() {
        let stats = SampleStats::from_samples(&[]);
        assert_eq!(
            (stats.mean, stats.variance, stats.min, stats.max),
            (0.0, 0.0, 0.0, 0.0)
        );
    }
}