Skip to main content

mollendorff_forge/monte_carlo/
sampler.rs

1//! Sampling Methods for Monte Carlo Simulation
2//!
3//! Supports:
4//! - Monte Carlo (pure random sampling)
5//! - Latin Hypercube (stratified sampling, 5x faster convergence)
6
7use rand::rngs::StdRng;
8use rand::{RngExt, SeedableRng};
9
/// Sampling method enumeration
///
/// Selects how a [`Sampler`] produces its uniform values. The type parses
/// from text through its `FromStr` impl (case-insensitive, with short
/// aliases) and prints its canonical snake_case name through `Display`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SamplingMethod {
    /// Pure random Monte Carlo sampling: every value is an independent draw
    MonteCarlo,
    /// Latin Hypercube Sampling (stratified, faster convergence): one draw
    /// per equal-width stratum of the unit interval, then shuffled
    LatinHypercube,
}
18
19impl std::str::FromStr for SamplingMethod {
20    type Err = String;
21
22    fn from_str(s: &str) -> Result<Self, Self::Err> {
23        match s.to_lowercase().as_str() {
24            "monte_carlo" | "montecarlo" | "mc" => Ok(Self::MonteCarlo),
25            "latin_hypercube" | "latinhypercube" | "lhs" => Ok(Self::LatinHypercube),
26            _ => Err(format!(
27                "Unknown sampling method: {s}. Use 'monte_carlo' or 'latin_hypercube'"
28            )),
29        }
30    }
31}
32
33impl std::fmt::Display for SamplingMethod {
34    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35        match self {
36            Self::MonteCarlo => write!(f, "monte_carlo"),
37            Self::LatinHypercube => write!(f, "latin_hypercube"),
38        }
39    }
40}
41
/// Sampler for generating random values
///
/// Bundles the chosen [`SamplingMethod`] with the sampler's own `StdRng`.
/// Both are set once by `Sampler::new`; every sampling call then advances
/// the generator.
pub struct Sampler {
    /// Strategy that `generate_uniform_samples` and
    /// `generate_uniform_samples_nd` dispatch on.
    method: SamplingMethod,
    /// Generator behind every draw, including the Latin Hypercube shuffle.
    rng: StdRng,
}
47
48impl Sampler {
49    /// Create a new sampler with the given method and optional seed
50    #[must_use]
51    pub fn new(method: SamplingMethod, seed: Option<u64>) -> Self {
52        let rng = seed.map_or_else(|| StdRng::from_rng(&mut rand::rng()), StdRng::seed_from_u64);
53        Self { method, rng }
54    }
55
56    /// Get the sampling method
57    #[must_use]
58    pub const fn method(&self) -> SamplingMethod {
59        self.method
60    }
61
62    /// Generate n uniform samples in [0, 1)
63    /// For Monte Carlo: pure random
64    /// For Latin Hypercube: stratified sampling
65    pub fn generate_uniform_samples(&mut self, n: usize) -> Vec<f64> {
66        match self.method {
67            SamplingMethod::MonteCarlo => self.monte_carlo_samples(n),
68            SamplingMethod::LatinHypercube => self.latin_hypercube_samples(n),
69        }
70    }
71
72    /// Generate samples for multiple dimensions
73    /// Returns n samples for each of d dimensions
74    pub fn generate_uniform_samples_nd(&mut self, n: usize, d: usize) -> Vec<Vec<f64>> {
75        match self.method {
76            SamplingMethod::MonteCarlo => (0..d).map(|_| self.monte_carlo_samples(n)).collect(),
77            SamplingMethod::LatinHypercube => self.latin_hypercube_samples_nd(n, d),
78        }
79    }
80
81    /// Pure random Monte Carlo samples
82    fn monte_carlo_samples(&mut self, n: usize) -> Vec<f64> {
83        (0..n).map(|_| self.rng.random::<f64>()).collect()
84    }
85
86    /// Latin Hypercube samples for 1 dimension
87    fn latin_hypercube_samples(&mut self, n: usize) -> Vec<f64> {
88        // Divide [0, 1) into n equal intervals
89        // Sample one value from each interval
90        // Then shuffle
91        let mut samples: Vec<f64> = (0..n)
92            .map(|i| {
93                let lower = i as f64 / n as f64;
94                let upper = (i + 1) as f64 / n as f64;
95                self.rng.random::<f64>().mul_add(upper - lower, lower)
96            })
97            .collect();
98
99        // Fisher-Yates shuffle
100        for i in (1..n).rev() {
101            let j = self.rng.random_range(0..=i);
102            samples.swap(i, j);
103        }
104
105        samples
106    }
107
108    /// Latin Hypercube samples for d dimensions
109    /// Each dimension is independently stratified, then shuffled
110    fn latin_hypercube_samples_nd(&mut self, n: usize, d: usize) -> Vec<Vec<f64>> {
111        (0..d).map(|_| self.latin_hypercube_samples(n)).collect()
112    }
113
114    /// Get mutable reference to RNG for custom sampling
115    pub const fn rng_mut(&mut self) -> &mut StdRng {
116        &mut self.rng
117    }
118}
119
/// Statistics about a sample set
///
/// Field descriptions below state what `SampleStats::from_samples` stores;
/// that constructor sets all four fields to `0.0` for an empty slice.
#[derive(Debug, Clone)]
pub struct SampleStats {
    /// Arithmetic mean of the samples.
    pub mean: f64,
    /// Population variance: squared deviations from the mean, averaged over
    /// the sample count `n` (not `n - 1`).
    pub variance: f64,
    /// Smallest sample value.
    pub min: f64,
    /// Largest sample value.
    pub max: f64,
}
128
129impl SampleStats {
130    /// Calculate statistics from samples
131    pub fn from_samples(samples: &[f64]) -> Self {
132        if samples.is_empty() {
133            return Self {
134                mean: 0.0,
135                variance: 0.0,
136                min: 0.0,
137                max: 0.0,
138            };
139        }
140
141        let n = samples.len() as f64;
142        let mean = samples.iter().sum::<f64>() / n;
143        let variance = samples.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / n;
144        let min = samples.iter().copied().fold(f64::INFINITY, f64::min);
145        let max = samples.iter().copied().fold(f64::NEG_INFINITY, f64::max);
146
147        Self {
148            mean,
149            variance,
150            min,
151            max,
152        }
153    }
154}
155
// Financial math: exact float comparison validated against Excel/Gnumeric/R
#[allow(clippy::float_cmp)]
// sampler/samples, sampler1/samples1 — standard statistical terminology
#[allow(clippy::similar_names)]
#[cfg(test)]
mod tests {
    use super::*;
    use std::str::FromStr;

    /// Arithmetic mean of a non-empty slice.
    fn mean_of(values: &[f64]) -> f64 {
        values.iter().sum::<f64>() / values.len() as f64
    }

    /// True when every value lies in the half-open unit interval.
    fn all_in_unit_interval(values: &[f64]) -> bool {
        values.iter().all(|&x| (0.0..1.0).contains(&x))
    }

    /// Average squared error of the sample mean (against 0.5) over seeds 0..10.
    fn mean_squared_error(method: SamplingMethod, n: usize) -> f64 {
        let squared_errors: Vec<f64> = (0..10)
            .map(|seed| {
                let mut sampler = Sampler::new(method, Some(seed));
                let samples = sampler.generate_uniform_samples(n);
                (mean_of(&samples) - 0.5).powi(2)
            })
            .collect();
        squared_errors.iter().sum::<f64>() / squared_errors.len() as f64
    }

    #[test]
    fn test_sampling_method_from_str() {
        let accepted = [
            ("monte_carlo", SamplingMethod::MonteCarlo),
            ("latin_hypercube", SamplingMethod::LatinHypercube),
            ("LHS", SamplingMethod::LatinHypercube),
        ];
        for (text, expected) in accepted {
            assert_eq!(SamplingMethod::from_str(text).unwrap(), expected);
        }
        assert!(SamplingMethod::from_str("invalid").is_err());
    }

    #[test]
    fn test_monte_carlo_samples() {
        let mut sampler = Sampler::new(SamplingMethod::MonteCarlo, Some(12345));
        let samples = sampler.generate_uniform_samples(1000);

        assert_eq!(samples.len(), 1000);
        assert!(all_in_unit_interval(&samples));

        // A uniform sample of this size should average close to 0.5
        assert!((mean_of(&samples) - 0.5).abs() < 0.05);
    }

    #[test]
    fn test_latin_hypercube_samples() {
        let mut sampler = Sampler::new(SamplingMethod::LatinHypercube, Some(12345));
        let samples = sampler.generate_uniform_samples(1000);

        assert_eq!(samples.len(), 1000);
        assert!(all_in_unit_interval(&samples));

        // Stratification pins the mean more tightly than plain Monte Carlo
        assert!((mean_of(&samples) - 0.5).abs() < 0.02);

        // Count how many samples land in each of the n strata
        let n = samples.len();
        let mut stratum_counts = vec![0_i32; n];
        for &sample in &samples {
            // cast_possible_truncation: stratum index is in [0, n) by construction
            #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
            let stratum = (sample * n as f64).floor() as usize;
            if let Some(count) = stratum_counts.get_mut(stratum) {
                *count += 1;
            }
        }

        // Ideal is exactly one per stratum; floating point rounding at a
        // stratum edge may shift a sample, so check the spread around 1
        let squared_spread: f64 = stratum_counts
            .iter()
            .map(|&c| (f64::from(c) - 1.0).powi(2))
            .sum();
        let variance = squared_spread / n as f64;
        assert!(
            variance < 0.1,
            "LHS stratum counts should be uniform, variance: {variance}"
        );
    }

    #[test]
    fn test_lhs_better_convergence() {
        // For equal sample size, the LHS mean should sit closer to 0.5 on average
        let n = 1000;

        let mc_avg_variance = mean_squared_error(SamplingMethod::MonteCarlo, n);
        let lhs_avg_variance = mean_squared_error(SamplingMethod::LatinHypercube, n);

        assert!(
            lhs_avg_variance < mc_avg_variance,
            "LHS ({lhs_avg_variance}) should have lower variance than MC ({mc_avg_variance})"
        );
    }

    #[test]
    fn test_seed_reproducibility() {
        let mut sampler1 = Sampler::new(SamplingMethod::LatinHypercube, Some(42));
        let mut sampler2 = Sampler::new(SamplingMethod::LatinHypercube, Some(42));

        let samples1 = sampler1.generate_uniform_samples(100);
        let samples2 = sampler2.generate_uniform_samples(100);

        assert_eq!(
            samples1, samples2,
            "Same seed should produce identical results"
        );
    }

    #[test]
    fn test_multidimensional_samples() {
        let mut sampler = Sampler::new(SamplingMethod::LatinHypercube, Some(12345));
        let samples = sampler.generate_uniform_samples_nd(100, 3);

        assert_eq!(samples.len(), 3);
        for dim in &samples {
            assert!(dim.len() == 100);
        }
        for dim in &samples {
            assert!(all_in_unit_interval(dim));
        }
    }

    #[test]
    fn test_sample_stats() {
        let stats = SampleStats::from_samples(&[1.0, 2.0, 3.0, 4.0, 5.0]);

        assert_eq!(stats.mean, 3.0);
        assert_eq!(stats.min, 1.0);
        assert_eq!(stats.max, 5.0);
        assert!((stats.variance - 2.0).abs() < 0.001);
    }
}
295        assert_eq!(stats.min, 1.0);
296        assert_eq!(stats.max, 5.0);
297        assert!((stats.variance - 2.0).abs() < 0.001);
298    }
299}