Skip to main content

simular/engine/
rng.rs

//! Deterministic random number generation.
//!
//! Implements PCG (Permuted Congruential Generator) with partitioned seeds
//! for reproducible parallel execution.
//!
//! # Reproducibility Guarantee
//!
//! Given the same master seed, all random number sequences will be
//! bitwise-identical across:
//! - Different runs
//! - Different platforms
//! - Different thread counts (via partitioning)
13
14use rand::prelude::*;
15use rand_pcg::Pcg64;
16use serde::{Deserialize, Serialize};
17
/// Deterministic, reproducible random number generator.
///
/// Based on PCG (Permuted Congruential Generator) which provides:
/// - Excellent statistical properties
/// - Fast generation
/// - Predictable sequences from seed
/// - Independent streams via partitioning
///
/// The struct bundles three things: the master seed it was created from, a
/// stream counter, and the underlying [`Pcg64`] generator. Field order is part
/// of the derived serde representation, so do not reorder fields without
/// considering previously serialized data.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SimRng {
    /// Master seed for reproducibility. Copied unchanged into every
    /// partition created by `partition`.
    master_seed: u64,
    /// Current stream index for partitioning. On a parent generator this is
    /// the index the next partition will receive; on a partition it is that
    /// partition's own stream index.
    stream: u64,
    /// Internal PCG state. Every `gen_*` call advances this generator.
    rng: Pcg64,
}
34
35impl SimRng {
36    /// Create a new RNG with the given master seed.
37    #[must_use]
38    pub fn new(master_seed: u64) -> Self {
39        let rng = Pcg64::seed_from_u64(master_seed);
40        Self {
41            master_seed,
42            stream: 0,
43            rng,
44        }
45    }
46
47    /// Get the master seed.
48    #[must_use]
49    pub const fn master_seed(&self) -> u64 {
50        self.master_seed
51    }
52
53    /// Get current stream index.
54    #[must_use]
55    pub const fn stream(&self) -> u64 {
56        self.stream
57    }
58
59    /// Create partitioned RNGs for parallel execution.
60    ///
61    /// Each partition gets an independent stream derived from the master seed,
62    /// ensuring reproducibility regardless of execution order.
63    ///
64    /// # Example
65    ///
66    /// ```rust
67    /// use simular::engine::rng::SimRng;
68    ///
69    /// let mut rng = SimRng::new(42);
70    /// let partitions = rng.partition(4);
71    /// assert_eq!(partitions.len(), 4);
72    /// ```
73    #[must_use]
74    pub fn partition(&mut self, n: usize) -> Vec<Self> {
75        let partitions: Vec<Self> = (0..n)
76            .map(|i| {
77                let stream = self.stream + i as u64;
78                let seed = self
79                    .master_seed
80                    .wrapping_add(stream.wrapping_mul(0x9E37_79B9_7F4A_7C15));
81                Self {
82                    master_seed: self.master_seed,
83                    stream,
84                    rng: Pcg64::seed_from_u64(seed),
85                }
86            })
87            .collect();
88
89        self.stream += n as u64;
90        partitions
91    }
92
93    /// Generate a random f64 in [0, 1).
94    pub fn gen_f64(&mut self) -> f64 {
95        self.rng.gen()
96    }
97
98    /// Generate a random f64 in the given range.
99    ///
100    /// # Panics
101    ///
102    /// Panics if `min > max`.
103    pub fn gen_range_f64(&mut self, min: f64, max: f64) -> f64 {
104        assert!(min <= max, "Invalid range: min > max");
105        min + (max - min) * self.gen_f64()
106    }
107
108    /// Generate a random u64.
109    pub fn gen_u64(&mut self) -> u64 {
110        self.rng.gen()
111    }
112
113    /// Generate n random f64 samples in [0, 1).
114    #[must_use]
115    pub fn sample_n(&mut self, n: usize) -> Vec<f64> {
116        (0..n).map(|_| self.gen_f64()).collect()
117    }
118
119    /// Generate a standard normal sample using Box-Muller transform.
120    pub fn gen_standard_normal(&mut self) -> f64 {
121        // Box-Muller transform
122        let u1 = self.gen_f64();
123        let u2 = self.gen_f64();
124
125        // Avoid log(0)
126        let u1 = if u1 < f64::EPSILON { f64::EPSILON } else { u1 };
127
128        (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()
129    }
130
131    /// Generate a normal sample with given mean and std.
132    pub fn gen_normal(&mut self, mean: f64, std: f64) -> f64 {
133        mean + std * self.gen_standard_normal()
134    }
135
136    /// Get RNG state as bytes for hashing (audit logging).
137    ///
138    /// Returns a deterministic byte representation of the RNG state.
139    #[must_use]
140    pub fn state_bytes(&self) -> Vec<u8> {
141        // Use master seed, stream, and serialized RNG state
142        let mut bytes = Vec::with_capacity(24);
143        bytes.extend_from_slice(&self.master_seed.to_le_bytes());
144        bytes.extend_from_slice(&self.stream.to_le_bytes());
145        // Also include serialized PCG state for uniqueness
146        if let Ok(serialized) = bincode::serialize(&self.rng) {
147            bytes.extend_from_slice(&serialized);
148        }
149        bytes
150    }
151
152    /// Save RNG state for checkpoint.
153    ///
154    /// Note: PCG internal state is not directly serializable, so we save
155    /// enough information to recreate the RNG at the same point in the stream.
156    #[must_use]
157    pub fn save_state(&self) -> RngState {
158        // Generate a sequence of values that can be used to verify restoration
159        let mut test_rng = self.rng.clone();
160        let verification: Vec<u64> = (0..4).map(|_| test_rng.gen()).collect();
161
162        RngState {
163            master_seed: self.master_seed,
164            stream: self.stream,
165            verification_values: Some(verification),
166        }
167    }
168
169    /// Restore RNG state from checkpoint.
170    ///
171    /// # Errors
172    ///
173    /// Returns error if state cannot be restored.
174    pub fn restore_state(&mut self, state: &RngState) -> Result<(), RngRestoreError> {
175        if state.master_seed != self.master_seed {
176            return Err(RngRestoreError::SeedMismatch {
177                expected: self.master_seed,
178                found: state.master_seed,
179            });
180        }
181
182        self.stream = state.stream;
183
184        // Recreate from seed and stream
185        let seed = self
186            .master_seed
187            .wrapping_add(self.stream.wrapping_mul(0x9E37_79B9_7F4A_7C15));
188        self.rng = Pcg64::seed_from_u64(seed);
189
190        Ok(())
191    }
192}
193
/// Saved RNG state for checkpointing.
///
/// Produced by `SimRng::save_state` and consumed by `SimRng::restore_state`.
/// It holds the seed and stream identity of a generator, not the generator's
/// internal PCG state.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RngState {
    /// Master seed. `restore_state` rejects a checkpoint whose master seed
    /// differs from the target generator's.
    pub master_seed: u64,
    /// Stream index.
    pub stream: u64,
    /// Verification values for testing restoration (optional).
    ///
    /// `save_state` fills this with the next four `u64` outputs of the
    /// generator at save time. `restore_state` in this file does not read it.
    pub verification_values: Option<Vec<u64>>,
}
204
/// Error restoring RNG state.
#[derive(Debug, Clone, thiserror::Error)]
pub enum RngRestoreError {
    /// Seed mismatch: the checkpoint was taken from a generator with a
    /// different master seed than the one being restored.
    #[error("Seed mismatch: expected {expected}, found {found}")]
    SeedMismatch {
        /// Expected seed (the master seed of the generator being restored).
        expected: u64,
        /// Found seed (the master seed recorded in the checkpoint).
        found: u64,
    },
    /// Corrupted state.
    ///
    /// NOTE(review): nothing in this file's non-test code constructs this
    /// variant; confirm whether other modules return it.
    #[error("Corrupted RNG state")]
    CorruptedState,
}
220
#[cfg(test)]
mod tests {
    use super::*;

    /// Draw `count` uniform samples from `rng`, in order.
    fn uniform_draws(rng: &mut SimRng, count: usize) -> Vec<f64> {
        (0..count).map(|_| rng.gen_f64()).collect()
    }

    /// Draw `count` standard-normal samples from `rng`, in order.
    fn standard_normal_draws(rng: &mut SimRng, count: usize) -> Vec<f64> {
        std::iter::repeat_with(|| rng.gen_standard_normal())
            .take(count)
            .collect()
    }

    /// Population mean and variance, accumulated in slice order.
    fn mean_and_variance(samples: &[f64]) -> (f64, f64) {
        let len = samples.len() as f64;
        let mean: f64 = samples.iter().sum::<f64>() / len;
        let variance: f64 = samples.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / len;
        (mean, variance)
    }

    /// Property: Same seed produces same sequence.
    #[test]
    fn test_reproducibility() {
        let mut first = SimRng::new(42);
        let mut second = SimRng::new(42);

        let seq1 = uniform_draws(&mut first, 100);
        let seq2 = uniform_draws(&mut second, 100);

        assert_eq!(seq1, seq2, "Same seed must produce identical sequences");
    }

    /// Property: Different seeds produce different sequences.
    #[test]
    fn test_different_seeds() {
        let mut first = SimRng::new(42);
        let mut second = SimRng::new(43);

        let seq1 = uniform_draws(&mut first, 100);
        let seq2 = uniform_draws(&mut second, 100);

        assert_ne!(
            seq1, seq2,
            "Different seeds must produce different sequences"
        );
    }

    /// Property: Partitions are independent.
    #[test]
    fn test_partition_independence() {
        let mut parent = SimRng::new(42);
        let mut partitions = parent.partition(4);

        // One short sequence per partition, in partition order.
        let seqs: Vec<Vec<f64>> = partitions
            .iter_mut()
            .map(|p| uniform_draws(p, 10))
            .collect();

        // Every unordered pair of partitions must disagree.
        for (i, earlier) in seqs.iter().enumerate() {
            for later in &seqs[i + 1..] {
                assert_ne!(earlier, later, "Partitions must be independent");
            }
        }
    }

    /// Property: Partitions are reproducible.
    #[test]
    fn test_partition_reproducibility() {
        let partitions1 = SimRng::new(42).partition(4);
        let partitions2 = SimRng::new(42).partition(4);

        for (mut p1, mut p2) in partitions1.into_iter().zip(partitions2) {
            let seq1 = uniform_draws(&mut p1, 10);
            let seq2 = uniform_draws(&mut p2, 10);
            assert_eq!(seq1, seq2, "Partition sequences must be reproducible");
        }
    }

    /// Property: Range sampling stays in bounds.
    #[test]
    fn test_range_bounds() {
        let mut rng = SimRng::new(42);
        let bounds = -10.0..10.0;

        for _ in 0..1000 {
            let v = rng.gen_range_f64(-10.0, 10.0);
            assert!(bounds.contains(&v), "Value out of range: {v}");
        }
    }

    /// Property: Normal distribution has correct moments.
    #[test]
    fn test_normal_distribution() {
        let mut rng = SimRng::new(42);
        let samples = standard_normal_draws(&mut rng, 10000);

        let (mean, variance) = mean_and_variance(&samples);

        // Mean should be close to 0
        assert!(mean.abs() < 0.1, "Mean {mean} too far from 0");
        // Variance should be close to 1
        assert!(
            (variance - 1.0).abs() < 0.1,
            "Variance {variance} too far from 1"
        );
    }

    /// Property: State save/restore preserves seed and stream info.
    /// Note: this only checks seed and stream bookkeeping, not sequence position.
    #[test]
    fn test_state_save_restore() {
        let state = SimRng::new(42).save_state();

        // The checkpoint carries the identity of the generator it came from.
        assert_eq!(state.master_seed, 42);
        assert_eq!(state.stream, 0);
        assert!(state.verification_values.is_some());

        // Restoring into a generator with the same master seed succeeds.
        let mut target = SimRng::new(42);
        let result = target.restore_state(&state);
        assert!(result.is_ok());
        assert_eq!(target.master_seed(), 42);
        assert_eq!(target.stream(), 0);
    }

    #[test]
    fn test_gen_u64() {
        let mut rng = SimRng::new(42);
        let first = rng.gen_u64();
        let second = rng.gen_u64();
        // Consecutive draws should differ
        assert_ne!(first, second);
    }

    #[test]
    fn test_sample_n() {
        let samples = SimRng::new(42).sample_n(10);
        assert_eq!(samples.len(), 10);
        // Every sample lies in [0, 1)
        for s in samples {
            assert!((0.0..1.0).contains(&s));
        }
    }

    #[test]
    fn test_gen_normal() {
        let v = SimRng::new(42).gen_normal(10.0, 2.0);
        // Within five standard deviations of the mean
        assert!(v > 0.0 && v < 20.0);
    }

    #[test]
    fn test_restore_state_seed_mismatch() {
        let state = SimRng::new(42).save_state();

        let mut other = SimRng::new(99); // Different seed
        let result = other.restore_state(&state);
        assert!(result.is_err());

        if let Some(e) = result.err() {
            assert!(e.to_string().contains("mismatch"));
        }
    }

    #[test]
    fn test_rng_state_clone() {
        let state = SimRng::new(42).save_state();
        let cloned = state.clone();
        assert_eq!(cloned.master_seed, state.master_seed);
        assert_eq!(cloned.stream, state.stream);
    }

    #[test]
    fn test_rng_restore_error_clone() {
        let mismatch = RngRestoreError::SeedMismatch {
            expected: 42,
            found: 99,
        };
        assert!(matches!(
            mismatch.clone(),
            RngRestoreError::SeedMismatch { .. }
        ));

        let corrupted = RngRestoreError::CorruptedState;
        assert!(matches!(
            corrupted.clone(),
            RngRestoreError::CorruptedState
        ));
    }

    #[test]
    fn test_rng_restore_error_display() {
        let display = RngRestoreError::CorruptedState.to_string();
        assert!(display.contains("Corrupted"));
    }

    #[test]
    fn test_sim_rng_clone() {
        let rng = SimRng::new(42);
        let cloned = rng.clone();
        assert_eq!(cloned.master_seed(), rng.master_seed());
    }

    #[test]
    fn test_sim_rng_debug() {
        let rng = SimRng::new(42);
        let debug = format!("{rng:?}");
        assert!(debug.contains("SimRng"));
    }

    #[test]
    fn test_rng_state_debug() {
        let state = SimRng::new(42).save_state();
        let debug = format!("{state:?}");
        assert!(debug.contains("RngState"));
    }

    #[test]
    fn test_rng_restore_error_debug() {
        let err = RngRestoreError::CorruptedState;
        let debug = format!("{err:?}");
        assert!(debug.contains("CorruptedState"));
    }

    /// Mutation test: gen_normal must add mean correctly (catches + -> - mutation)
    #[test]
    fn test_gen_normal_mean_is_added() {
        let mut rng = SimRng::new(42);
        // With std=0 the noise term vanishes, so every sample must equal the mean.
        for _ in 0..10 {
            let v = rng.gen_normal(100.0, 0.0);
            let deviation = (v - 100.0).abs();
            assert!(
                deviation < 1e-10,
                "gen_normal with std=0 must return mean exactly, got {v}"
            );
        }
    }

    /// Mutation test: gen_normal must multiply std correctly (catches * -> + mutation)
    #[test]
    fn test_gen_normal_std_is_multiplied() {
        let mut rng = SimRng::new(42);
        // With mean=0, std=10, variance should be 100
        let samples: Vec<f64> = std::iter::repeat_with(|| rng.gen_normal(0.0, 10.0))
            .take(10000)
            .collect();
        let (_mean, variance) = mean_and_variance(&samples);
        // Variance should be close to 100 (std^2)
        assert!(
            (variance - 100.0).abs() < 15.0,
            "Variance {variance} not close to 100"
        );
    }

    /// Mutation test: gen_normal return value correctness (catches -> 1.0 mutation)
    #[test]
    fn test_gen_normal_not_constant() {
        let mut rng = SimRng::new(42);
        let samples: Vec<f64> = std::iter::repeat_with(|| rng.gen_normal(0.0, 1.0))
            .take(100)
            .collect();

        // A constant-1.0 mutant would make every sample equal to 1.0.
        let all_ones = samples.iter().all(|&x| (x - 1.0).abs() < 1e-10);
        assert!(!all_ones, "gen_normal should not return constant 1.0");

        // Bucket to six decimal places and count the distinct buckets.
        let buckets: std::collections::HashSet<i64> =
            samples.iter().map(|x| (*x * 1e6) as i64).collect();
        let unique_count = buckets.len();
        assert!(
            unique_count > 50,
            "gen_normal should produce varied outputs"
        );
    }

    /// Mutation test: partition must increment stream by n (catches += -> *= mutation)
    #[test]
    fn test_partition_stream_increment() {
        let mut rng = SimRng::new(42);
        assert_eq!(rng.stream(), 0);

        let _ = rng.partition(4);
        assert_eq!(
            rng.stream(),
            4,
            "Stream should increment by partition count"
        );

        // A multiplicative mutant would give 0 or 12 here instead of 7.
        let _ = rng.partition(3);
        assert_eq!(rng.stream(), 7, "Stream should be 4 + 3 = 7");
    }

    /// Mutation test: gen_standard_normal uses correct formula (catches * -> / mutation)
    #[test]
    fn test_standard_normal_formula_correctness() {
        let mut rng = SimRng::new(42);
        // A wrong angle argument would distort the spread of the samples,
        // so check that both tails are actually reached.
        let samples = standard_normal_draws(&mut rng, 10000);

        let min = samples.iter().copied().fold(f64::INFINITY, f64::min);
        let max = samples.iter().copied().fold(f64::NEG_INFINITY, f64::max);

        // Standard normal should span roughly [-4, 4] with high probability
        assert!(min < -2.0, "Min {min} should be < -2 for standard normal");
        assert!(max > 2.0, "Max {max} should be > 2 for standard normal");
    }

    /// Mutation test: gen_standard_normal must handle near-zero u1 (catches < -> == mutation)
    #[test]
    fn test_standard_normal_epsilon_guard() {
        // The epsilon clamp keeps ln() away from zero; a weakened guard could
        // let a non-finite value through, so every sample must be finite.
        let mut rng = SimRng::new(12345);
        for _ in 0..50000 {
            let v = rng.gen_standard_normal();
            assert!(
                v.is_finite(),
                "gen_standard_normal produced non-finite value: {v}"
            );
        }
    }

    /// Mutation test: Box-Muller 2*PI*u2 formula (catches second * -> / mutation)
    #[test]
    fn test_standard_normal_angle_formula() {
        // Dividing by u2 instead of multiplying would distort the angle and
        // with it the shape of the distribution; kurtosis detects that.
        let mut rng = SimRng::new(999);
        let samples = standard_normal_draws(&mut rng, 50000);

        let (mean, variance) = mean_and_variance(&samples);
        let fourth_moment: f64 =
            samples.iter().map(|x| (x - mean).powi(4)).sum::<f64>() / samples.len() as f64;
        let kurtosis = fourth_moment / (variance * variance);

        // A normal distribution has kurtosis 3; allow some sampling tolerance.
        assert!(
            (kurtosis - 3.0).abs() < 0.5,
            "Kurtosis {kurtosis} far from expected 3.0, suggesting formula error"
        );
    }
}
571
#[cfg(test)]
mod proptests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        /// Falsification test: two generators built from the same seed agree
        /// on their first 100 uniform draws.
        #[test]
        fn prop_reproducibility(seed in 0u64..u64::MAX) {
            let seq1 = SimRng::new(seed).sample_n(100);
            let seq2 = SimRng::new(seed).sample_n(100);

            prop_assert_eq!(seq1, seq2);
        }

        /// Falsification test: the first 100 uniform draws lie in [0, 1)
        /// for any seed.
        #[test]
        fn prop_unit_interval(seed in 0u64..u64::MAX) {
            let draws = SimRng::new(seed).sample_n(100);

            for v in draws {
                prop_assert!((0.0..1.0).contains(&v), "Value {} not in [0, 1)", v);
            }
        }

        /// Falsification test: `partition(n)` returns exactly `n` generators.
        #[test]
        fn prop_partition_count(seed in 0u64..u64::MAX, n in 1usize..100) {
            let produced = SimRng::new(seed).partition(n).len();
            prop_assert_eq!(produced, n);
        }
    }
}