1use rand::prelude::*;
15use rand_pcg::Pcg64;
16use serde::{Deserialize, Serialize};
17
18#[derive(Debug, Clone, Serialize, Deserialize)]
26pub struct SimRng {
27 master_seed: u64,
29 stream: u64,
31 rng: Pcg64,
33}
34
35impl SimRng {
36 #[must_use]
38 pub fn new(master_seed: u64) -> Self {
39 let rng = Pcg64::seed_from_u64(master_seed);
40 Self {
41 master_seed,
42 stream: 0,
43 rng,
44 }
45 }
46
47 #[must_use]
49 pub const fn master_seed(&self) -> u64 {
50 self.master_seed
51 }
52
53 #[must_use]
55 pub const fn stream(&self) -> u64 {
56 self.stream
57 }
58
59 #[must_use]
74 pub fn partition(&mut self, n: usize) -> Vec<Self> {
75 let partitions: Vec<Self> = (0..n)
76 .map(|i| {
77 let stream = self.stream + i as u64;
78 let seed = self
79 .master_seed
80 .wrapping_add(stream.wrapping_mul(0x9E37_79B9_7F4A_7C15));
81 Self {
82 master_seed: self.master_seed,
83 stream,
84 rng: Pcg64::seed_from_u64(seed),
85 }
86 })
87 .collect();
88
89 self.stream += n as u64;
90 partitions
91 }
92
93 pub fn gen_f64(&mut self) -> f64 {
95 self.rng.gen()
96 }
97
98 pub fn gen_range_f64(&mut self, min: f64, max: f64) -> f64 {
104 assert!(min <= max, "Invalid range: min > max");
105 min + (max - min) * self.gen_f64()
106 }
107
108 pub fn gen_u64(&mut self) -> u64 {
110 self.rng.gen()
111 }
112
113 #[must_use]
115 pub fn sample_n(&mut self, n: usize) -> Vec<f64> {
116 (0..n).map(|_| self.gen_f64()).collect()
117 }
118
119 pub fn gen_standard_normal(&mut self) -> f64 {
121 let u1 = self.gen_f64();
123 let u2 = self.gen_f64();
124
125 let u1 = if u1 < f64::EPSILON { f64::EPSILON } else { u1 };
127
128 (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()
129 }
130
131 pub fn gen_normal(&mut self, mean: f64, std: f64) -> f64 {
133 mean + std * self.gen_standard_normal()
134 }
135
136 #[must_use]
140 pub fn state_bytes(&self) -> Vec<u8> {
141 let mut bytes = Vec::with_capacity(24);
143 bytes.extend_from_slice(&self.master_seed.to_le_bytes());
144 bytes.extend_from_slice(&self.stream.to_le_bytes());
145 if let Ok(serialized) = bincode::serialize(&self.rng) {
147 bytes.extend_from_slice(&serialized);
148 }
149 bytes
150 }
151
152 #[must_use]
157 pub fn save_state(&self) -> RngState {
158 let mut test_rng = self.rng.clone();
160 let verification: Vec<u64> = (0..4).map(|_| test_rng.gen()).collect();
161
162 RngState {
163 master_seed: self.master_seed,
164 stream: self.stream,
165 verification_values: Some(verification),
166 }
167 }
168
169 pub fn restore_state(&mut self, state: &RngState) -> Result<(), RngRestoreError> {
175 if state.master_seed != self.master_seed {
176 return Err(RngRestoreError::SeedMismatch {
177 expected: self.master_seed,
178 found: state.master_seed,
179 });
180 }
181
182 self.stream = state.stream;
183
184 let seed = self
186 .master_seed
187 .wrapping_add(self.stream.wrapping_mul(0x9E37_79B9_7F4A_7C15));
188 self.rng = Pcg64::seed_from_u64(seed);
189
190 Ok(())
191 }
192}
193
194#[derive(Debug, Clone, Serialize, Deserialize)]
196pub struct RngState {
197 pub master_seed: u64,
199 pub stream: u64,
201 pub verification_values: Option<Vec<u64>>,
203}
204
205#[derive(Debug, Clone, thiserror::Error)]
207pub enum RngRestoreError {
208 #[error("Seed mismatch: expected {expected}, found {found}")]
210 SeedMismatch {
211 expected: u64,
213 found: u64,
215 },
216 #[error("Corrupted RNG state")]
218 CorruptedState,
219}
220
#[cfg(test)]
mod tests {
    use super::*;

    /// Draws `len` uniform values from `rng`, in generation order.
    fn uniform_draws(rng: &mut SimRng, len: usize) -> Vec<f64> {
        (0..len).map(|_| rng.gen_f64()).collect()
    }

    /// Population mean and variance of `samples` (expects a non-empty slice).
    fn mean_and_variance(samples: &[f64]) -> (f64, f64) {
        let len = samples.len() as f64;
        let mean = samples.iter().sum::<f64>() / len;
        let variance = samples.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / len;
        (mean, variance)
    }

    #[test]
    fn test_reproducibility() {
        let mut rng1 = SimRng::new(42);
        let mut rng2 = SimRng::new(42);

        assert_eq!(
            uniform_draws(&mut rng1, 100),
            uniform_draws(&mut rng2, 100),
            "Same seed must produce identical sequences"
        );
    }

    #[test]
    fn test_different_seeds() {
        let mut rng1 = SimRng::new(42);
        let mut rng2 = SimRng::new(43);

        assert_ne!(
            uniform_draws(&mut rng1, 100),
            uniform_draws(&mut rng2, 100),
            "Different seeds must produce different sequences"
        );
    }

    #[test]
    fn test_partition_independence() {
        let mut rng = SimRng::new(42);
        let seqs: Vec<Vec<f64>> = rng
            .partition(4)
            .iter_mut()
            .map(|p| uniform_draws(p, 10))
            .collect();

        for (i, a) in seqs.iter().enumerate() {
            for b in &seqs[i + 1..] {
                assert_ne!(a, b, "Partitions must be independent");
            }
        }
    }

    #[test]
    fn test_partition_reproducibility() {
        let mut rng1 = SimRng::new(42);
        let mut rng2 = SimRng::new(42);

        let mut partitions1 = rng1.partition(4);
        let mut partitions2 = rng2.partition(4);

        for (p1, p2) in partitions1.iter_mut().zip(partitions2.iter_mut()) {
            assert_eq!(
                uniform_draws(p1, 10),
                uniform_draws(p2, 10),
                "Partition sequences must be reproducible"
            );
        }
    }

    #[test]
    fn test_range_bounds() {
        let mut rng = SimRng::new(42);

        for _ in 0..1000 {
            let v = rng.gen_range_f64(-10.0, 10.0);
            assert!((-10.0..10.0).contains(&v), "Value out of range: {v}");
        }
    }

    #[test]
    fn test_normal_distribution() {
        let mut rng = SimRng::new(42);
        let samples: Vec<f64> = (0..10_000).map(|_| rng.gen_standard_normal()).collect();
        let (mean, variance) = mean_and_variance(&samples);

        assert!(mean.abs() < 0.1, "Mean {mean} too far from 0");
        assert!(
            (variance - 1.0).abs() < 0.1,
            "Variance {variance} too far from 1"
        );
    }

    #[test]
    fn test_state_save_restore() {
        let rng = SimRng::new(42);
        let state = rng.save_state();

        assert_eq!(state.master_seed, 42);
        assert_eq!(state.stream, 0);

        // The verification values are the next four u64 draws, taken without
        // advancing the saved generator.
        let mut fresh = SimRng::new(42);
        let expected: Vec<u64> = (0..4).map(|_| fresh.gen_u64()).collect();
        assert_eq!(state.verification_values, Some(expected));

        let mut rng2 = SimRng::new(42);
        assert!(rng2.restore_state(&state).is_ok());
        assert_eq!(rng2.master_seed(), 42);
        assert_eq!(rng2.stream(), 0);

        // The restored generator must replay the saved generator's sequence.
        let mut saved = rng.clone();
        assert_eq!(uniform_draws(&mut rng2, 10), uniform_draws(&mut saved, 10));
    }

    #[test]
    fn test_gen_u64() {
        let mut rng = SimRng::new(42);
        let v1 = rng.gen_u64();
        let v2 = rng.gen_u64();
        assert_ne!(v1, v2);
    }

    #[test]
    fn test_sample_n() {
        let mut rng = SimRng::new(42);
        let samples = rng.sample_n(10);
        assert_eq!(samples.len(), 10);
        for s in &samples {
            assert!((0.0..1.0).contains(s), "Sample {s} not in [0, 1)");
        }
    }

    #[test]
    fn test_gen_normal() {
        let mut rng = SimRng::new(42);
        let v = rng.gen_normal(10.0, 2.0);
        assert!(v > 0.0 && v < 20.0);
    }

    #[test]
    fn test_restore_state_seed_mismatch() {
        let state = SimRng::new(42).save_state();

        let mut rng2 = SimRng::new(99);
        let err = rng2
            .restore_state(&state)
            .expect_err("restoring a state saved under another seed must fail");

        assert!(matches!(
            err,
            RngRestoreError::SeedMismatch {
                expected: 99,
                found: 42
            }
        ));
        assert!(err.to_string().contains("mismatch"));

        // A rejected restore must leave the generator untouched.
        assert_eq!(rng2.master_seed(), 99);
        assert_eq!(rng2.stream(), 0);
    }

    #[test]
    fn test_rng_state_clone() {
        let state = SimRng::new(42).save_state();
        let cloned = state.clone();
        assert_eq!(cloned.master_seed, state.master_seed);
        assert_eq!(cloned.stream, state.stream);
        assert_eq!(cloned.verification_values, state.verification_values);
    }

    #[test]
    fn test_rng_restore_error_clone() {
        let err = RngRestoreError::SeedMismatch {
            expected: 42,
            found: 99,
        };
        let cloned = err.clone();
        assert!(matches!(
            cloned,
            RngRestoreError::SeedMismatch {
                expected: 42,
                found: 99
            }
        ));

        let err2 = RngRestoreError::CorruptedState;
        let cloned2 = err2.clone();
        assert!(matches!(cloned2, RngRestoreError::CorruptedState));
    }

    #[test]
    fn test_rng_restore_error_display() {
        let err = RngRestoreError::CorruptedState;
        assert!(err.to_string().contains("Corrupted"));
    }

    #[test]
    fn test_sim_rng_clone() {
        let mut rng = SimRng::new(42);
        // Advance first so the clone must copy the generator position, not just the seed.
        rng.gen_u64();
        let mut cloned = rng.clone();

        assert_eq!(cloned.master_seed(), rng.master_seed());
        assert_eq!(cloned.stream(), rng.stream());
        assert_eq!(uniform_draws(&mut cloned, 10), uniform_draws(&mut rng, 10));
    }

    #[test]
    fn test_sim_rng_debug() {
        let rng = SimRng::new(42);
        let debug = format!("{rng:?}");
        assert!(debug.contains("SimRng"));
    }

    #[test]
    fn test_rng_state_debug() {
        let state = SimRng::new(42).save_state();
        let debug = format!("{state:?}");
        assert!(debug.contains("RngState"));
    }

    #[test]
    fn test_rng_restore_error_debug() {
        let err = RngRestoreError::CorruptedState;
        let debug = format!("{err:?}");
        assert!(debug.contains("CorruptedState"));
    }

    #[test]
    fn test_gen_normal_mean_is_added() {
        let mut rng = SimRng::new(42);
        for _ in 0..10 {
            let v = rng.gen_normal(100.0, 0.0);
            assert!(
                (v - 100.0).abs() < 1e-10,
                "gen_normal with std=0 must return mean exactly, got {v}"
            );
        }
    }

    #[test]
    fn test_gen_normal_std_is_multiplied() {
        let mut rng = SimRng::new(42);
        let samples: Vec<f64> = (0..10_000).map(|_| rng.gen_normal(0.0, 10.0)).collect();
        let (_, variance) = mean_and_variance(&samples);
        assert!(
            (variance - 100.0).abs() < 15.0,
            "Variance {variance} not close to 100"
        );
    }

    #[test]
    fn test_gen_normal_not_constant() {
        let mut rng = SimRng::new(42);
        let samples: Vec<f64> = (0..100).map(|_| rng.gen_normal(0.0, 1.0)).collect();

        let all_ones = samples.iter().all(|&x| (x - 1.0).abs() < 1e-10);
        assert!(!all_ones, "gen_normal should not return constant 1.0");

        let unique_count = samples
            .iter()
            .map(|x| (*x * 1e6) as i64)
            .collect::<std::collections::HashSet<_>>()
            .len();
        assert!(
            unique_count > 50,
            "gen_normal should produce varied outputs"
        );
    }

    #[test]
    fn test_partition_stream_increment() {
        let mut rng = SimRng::new(42);
        assert_eq!(rng.stream(), 0);

        let _ = rng.partition(4);
        assert_eq!(
            rng.stream(),
            4,
            "Stream should increment by partition count"
        );

        let parts = rng.partition(3);
        assert_eq!(rng.stream(), 7, "Stream should be 4 + 3 = 7");

        // Children continue numbering where the previous partition stopped.
        let child_streams: Vec<u64> = parts.iter().map(SimRng::stream).collect();
        assert_eq!(child_streams, vec![4, 5, 6]);
    }

    #[test]
    fn test_standard_normal_formula_correctness() {
        let mut rng = SimRng::new(42);
        let samples: Vec<f64> = (0..10_000).map(|_| rng.gen_standard_normal()).collect();

        let min = samples.iter().copied().fold(f64::INFINITY, f64::min);
        let max = samples.iter().copied().fold(f64::NEG_INFINITY, f64::max);

        assert!(min < -2.0, "Min {min} should be < -2 for standard normal");
        assert!(max > 2.0, "Max {max} should be > 2 for standard normal");
    }

    #[test]
    fn test_standard_normal_epsilon_guard() {
        let mut rng = SimRng::new(12345);
        for _ in 0..50_000 {
            let v = rng.gen_standard_normal();
            assert!(
                v.is_finite(),
                "gen_standard_normal produced non-finite value: {v}"
            );
        }
    }

    #[test]
    fn test_standard_normal_angle_formula() {
        let mut rng = SimRng::new(999);
        let samples: Vec<f64> = (0..50_000).map(|_| rng.gen_standard_normal()).collect();

        let (mean, variance) = mean_and_variance(&samples);
        let fourth_moment: f64 =
            samples.iter().map(|x| (x - mean).powi(4)).sum::<f64>() / samples.len() as f64;
        let kurtosis = fourth_moment / (variance * variance);

        // A normal distribution has kurtosis 3; a wrong angle term skews this.
        assert!(
            (kurtosis - 3.0).abs() < 0.5,
            "Kurtosis {kurtosis} far from expected 3.0, suggesting formula error"
        );
    }
}
571
#[cfg(test)]
mod proptests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        // Same seed => same uniform sequence. `any::<u64>()` covers the whole seed
        // domain; the previous `0u64..u64::MAX` never generated `u64::MAX`.
        #[test]
        fn prop_reproducibility(seed in any::<u64>()) {
            let mut rng1 = SimRng::new(seed);
            let mut rng2 = SimRng::new(seed);

            let seq1: Vec<f64> = (0..100).map(|_| rng1.gen_f64()).collect();
            let seq2: Vec<f64> = (0..100).map(|_| rng2.gen_f64()).collect();

            prop_assert_eq!(seq1, seq2);
        }

        // Every uniform draw lies in the half-open unit interval.
        #[test]
        fn prop_unit_interval(seed in any::<u64>()) {
            let mut rng = SimRng::new(seed);

            for _ in 0..100 {
                let v = rng.gen_f64();
                prop_assert!((0.0..1.0).contains(&v), "Value {} not in [0, 1)", v);
            }
        }

        // `partition(n)` yields n children on streams 0..n of the same family,
        // and advances the parent's stream counter by n.
        #[test]
        fn prop_partition_count(seed in any::<u64>(), n in 1usize..100) {
            let mut rng = SimRng::new(seed);
            let partitions = rng.partition(n);

            prop_assert_eq!(partitions.len(), n);
            prop_assert_eq!(rng.stream(), n as u64);
            for (i, p) in partitions.iter().enumerate() {
                prop_assert_eq!(p.stream(), i as u64);
                prop_assert_eq!(p.master_seed(), seed);
            }
        }
    }
}