1use rand::SeedableRng;
9use rand::{
10 RngExt,
11 distr::{Distribution, StandardUniform, uniform::SampleUniform},
12};
13use rand_chacha::ChaCha8Rng;
14use std::cell::{Cell, RefCell};
15use std::collections::VecDeque;
16
thread_local! {
    /// Per-thread simulation RNG. Starts seeded with `0`; it is reseeded by
    /// `set_sim_seed`, by `reset_sim_rng`, and whenever an RNG breakpoint fires.
    static SIM_RNG: RefCell<ChaCha8Rng> = RefCell::new(ChaCha8Rng::seed_from_u64(0));

    /// The seed most recently installed into `SIM_RNG`. Starts at `0`, matching
    /// the initial seed of `SIM_RNG` above.
    static CURRENT_SEED: RefCell<u64> = const { RefCell::new(0) };

    /// Number of draws counted since the counter was last reset. `pre_sample`
    /// increments it before every draw. `reset_sim_rng` and
    /// `reset_rng_call_count` set it to `0`, and a firing breakpoint sets it to
    /// `1`. `set_sim_seed` does NOT touch it.
    static RNG_CALL_COUNT: Cell<u64> = const { Cell::new(0) };

    /// Pending reseed points as `(target_count, new_seed)`, consumed front to
    /// back. The front entry fires on the first draw whose call count exceeds
    /// `target_count`. Firing resets the count to `1`, so each later
    /// `target_count` is measured from the previous breakpoint.
    static RNG_BREAKPOINTS: RefCell<VecDeque<(u64, u64)>> = const { RefCell::new(VecDeque::new()) };
}
42
43fn pre_sample() {
48 RNG_CALL_COUNT.with(|c| c.set(c.get() + 1));
49 check_rng_breakpoint();
50}
51
52fn check_rng_breakpoint() {
58 RNG_BREAKPOINTS.with(|bp| {
59 let mut breakpoints = bp.borrow_mut();
60 while let Some(&(target_count, new_seed)) = breakpoints.front() {
61 let count = RNG_CALL_COUNT.with(|c| c.get());
62 if count > target_count {
63 breakpoints.pop_front();
64 SIM_RNG.with(|rng| {
65 *rng.borrow_mut() = ChaCha8Rng::seed_from_u64(new_seed);
66 });
67 CURRENT_SEED.with(|s| {
68 *s.borrow_mut() = new_seed;
69 });
70 RNG_CALL_COUNT.with(|c| c.set(1));
71 } else {
72 break;
73 }
74 }
75 });
76}
77
78pub fn sim_random<T>() -> T
90where
91 StandardUniform: Distribution<T>,
92{
93 pre_sample();
94 SIM_RNG.with(|rng| rng.borrow_mut().sample(StandardUniform))
95}
96
97pub fn sim_random_range<T>(range: std::ops::Range<T>) -> T
112where
113 T: SampleUniform + PartialOrd,
114{
115 pre_sample();
116 SIM_RNG.with(|rng| rng.borrow_mut().random_range(range))
117}
118
119pub fn sim_random_range_or_default<T>(range: std::ops::Range<T>) -> T
134where
135 T: SampleUniform + PartialOrd + Clone,
136{
137 if range.start >= range.end {
138 range.start
139 } else {
140 sim_random_range(range)
141 }
142}
143
144pub fn set_sim_seed(seed: u64) {
156 SIM_RNG.with(|rng| {
157 *rng.borrow_mut() = ChaCha8Rng::seed_from_u64(seed);
158 });
159 CURRENT_SEED.with(|current| {
160 *current.borrow_mut() = seed;
161 });
162}
163
164pub fn sim_random_f64() -> f64 {
172 pre_sample();
173 SIM_RNG.with(|rng| rng.borrow_mut().sample(StandardUniform))
174}
175
176pub fn get_current_sim_seed() -> u64 {
187 CURRENT_SEED.with(|current| *current.borrow())
188}
189
190pub fn reset_sim_rng() {
198 SIM_RNG.with(|rng| {
199 *rng.borrow_mut() = ChaCha8Rng::seed_from_u64(0);
200 });
201 CURRENT_SEED.with(|current| {
202 *current.borrow_mut() = 0;
203 });
204 RNG_CALL_COUNT.with(|c| c.set(0));
205 RNG_BREAKPOINTS.with(|bp| bp.borrow_mut().clear());
206}
207
208pub fn get_rng_call_count() -> u64 {
213 RNG_CALL_COUNT.with(|c| c.get())
214}
215
216pub fn reset_rng_call_count() {
220 RNG_CALL_COUNT.with(|c| c.set(0));
221}
222
223pub fn set_rng_breakpoints(breakpoints: Vec<(u64, u64)>) {
235 RNG_BREAKPOINTS.with(|bp| {
236 *bp.borrow_mut() = VecDeque::from(breakpoints);
237 });
238}
239
240pub fn clear_rng_breakpoints() {
242 RNG_BREAKPOINTS.with(|bp| bp.borrow_mut().clear());
243}
244
#[cfg(test)]
mod tests {
    use super::*;

    // Every test begins with `reset_sim_rng()`. The RNG, seed, call counter and
    // breakpoint queue are all thread-local. If several tests ever share a
    // thread, state left behind by one test (a stale counter or an unfired
    // breakpoint) would otherwise change the outcome of the next.

    #[test]
    fn test_deterministic_randomness() {
        reset_sim_rng();
        set_sim_seed(42);
        let value1: f64 = sim_random();
        let value2: u32 = sim_random();
        let value3: bool = sim_random();

        // Reseeding replays the identical stream.
        set_sim_seed(42);
        assert_eq!(value1, sim_random::<f64>());
        assert_eq!(value2, sim_random::<u32>());
        assert_eq!(value3, sim_random::<bool>());
    }

    #[test]
    fn test_different_seeds_produce_different_values() {
        reset_sim_rng();
        set_sim_seed(1);
        let value1_seed1: f64 = sim_random();
        let value2_seed1: f64 = sim_random();

        set_sim_seed(2);
        let value1_seed2: f64 = sim_random();
        let value2_seed2: f64 = sim_random();

        assert_ne!(value1_seed1, value1_seed2);
        assert_ne!(value2_seed1, value2_seed2);
    }

    #[test]
    fn test_sim_random_range() {
        reset_sim_rng();
        set_sim_seed(42);

        for _ in 0..100 {
            let value = sim_random_range(10..20);
            assert!(value >= 10);
            assert!(value < 20);
        }

        for _ in 0..100 {
            let value = sim_random_range(0.0..1.0);
            assert!(value >= 0.0);
            assert!(value < 1.0);
        }
    }

    #[test]
    fn test_range_determinism() {
        reset_sim_rng();
        set_sim_seed(123);
        let value1 = sim_random_range(100..1000);
        let value2 = sim_random_range(0.0..10.0);

        set_sim_seed(123);
        assert_eq!(value1, sim_random_range(100..1000));
        assert_eq!(value2, sim_random_range(0.0..10.0));
    }

    #[test]
    fn test_range_or_default_empty_range_returns_start() {
        reset_sim_rng();
        set_sim_seed(7);
        assert_eq!(sim_random_range_or_default(5..5), 5);
        assert_eq!(sim_random_range_or_default(9..3), 9);
        // Empty ranges never touch the RNG.
        assert_eq!(get_rng_call_count(), 0);
    }

    #[test]
    fn test_sim_random_f64_matches_generic() {
        reset_sim_rng();
        set_sim_seed(77);
        let via_generic: f64 = sim_random();

        set_sim_seed(77);
        assert_eq!(via_generic, sim_random_f64());
    }

    #[test]
    fn test_reset_clears_state() {
        reset_sim_rng();
        set_sim_seed(42);
        let first_before: f64 = sim_random();
        let _advance: f64 = sim_random();
        let after_advance: f64 = sim_random();

        // After a reset and the same seed, the stream restarts from the top
        // rather than continuing where it left off.
        reset_sim_rng();
        set_sim_seed(42);
        let first_after: f64 = sim_random();

        assert_eq!(first_before, first_after);
        assert_ne!(after_advance, first_after);
    }

    #[test]
    fn test_sequence_persistence_within_thread() {
        reset_sim_rng();
        set_sim_seed(42);
        let value1: f64 = sim_random();
        let value2: f64 = sim_random();
        let value3: f64 = sim_random();

        set_sim_seed(42);
        assert_eq!(value1, sim_random::<f64>());
        assert_eq!(value2, sim_random::<f64>());
        assert_eq!(value3, sim_random::<f64>());
    }

    #[test]
    fn test_multiple_resets_and_seeds() {
        for seed in [1, 42, 12345] {
            reset_sim_rng();
            set_sim_seed(seed);
            let first: f64 = sim_random();

            reset_sim_rng();
            set_sim_seed(seed);
            assert_eq!(first, sim_random::<f64>());
        }
    }

    #[test]
    fn test_get_current_sim_seed() {
        reset_sim_rng();
        set_sim_seed(12345);
        assert_eq!(get_current_sim_seed(), 12345);

        set_sim_seed(98765);
        assert_eq!(get_current_sim_seed(), 98765);

        reset_sim_rng();
        assert_eq!(get_current_sim_seed(), 0);
    }

    #[test]
    fn test_call_counting() {
        reset_sim_rng();
        set_sim_seed(42);
        assert_eq!(get_rng_call_count(), 0);

        let _: f64 = sim_random();
        assert_eq!(get_rng_call_count(), 1);

        let _: u32 = sim_random();
        assert_eq!(get_rng_call_count(), 2);

        let _ = sim_random_range(0..100);
        assert_eq!(get_rng_call_count(), 3);

        let _ = sim_random_f64();
        assert_eq!(get_rng_call_count(), 4);

        // A non-empty range draws once.
        let _ = sim_random_range_or_default(0..100);
        assert_eq!(get_rng_call_count(), 5);

        // An empty range short-circuits without drawing.
        let _ = sim_random_range_or_default(100..100);
        assert_eq!(get_rng_call_count(), 5);
    }

    #[test]
    fn test_breakpoint_reseed() {
        reset_sim_rng();
        set_sim_seed(100);

        let mut old_values = Vec::new();
        for _ in 0..5 {
            old_values.push(sim_random::<f64>());
        }

        // First value a fresh seed-200 stream produces.
        reset_sim_rng();
        set_sim_seed(200);
        let new_seed_first: f64 = sim_random();

        // Replay seed 100 with a breakpoint after 5 calls.
        reset_sim_rng();
        set_sim_seed(100);
        set_rng_breakpoints(vec![(5, 200)]);

        for (i, expected) in old_values.iter().enumerate() {
            let actual: f64 = sim_random();
            assert_eq!(*expected, actual, "Mismatch at call {}", i + 1);
        }

        // The 6th call is served by the new seed and counts as its first draw.
        let after_breakpoint: f64 = sim_random();
        assert_eq!(after_breakpoint, new_seed_first);
        assert_eq!(get_rng_call_count(), 1);
        assert_eq!(get_current_sim_seed(), 200);
    }

    #[test]
    fn test_chained_breakpoints() {
        reset_sim_rng();
        set_sim_seed(10);
        set_rng_breakpoints(vec![(3, 20), (2, 30)]);

        // Calls 1-3 still draw from seed 10.
        for _ in 0..3 {
            let _: f64 = sim_random();
        }
        assert_eq!(get_current_sim_seed(), 10);

        // Call 4 exceeds the first target: reseed to 20, count restarts at 1.
        let _: f64 = sim_random();
        assert_eq!(get_current_sim_seed(), 20);
        assert_eq!(get_rng_call_count(), 1);

        // Two more calls push the count to 3 > 2, firing the second breakpoint.
        for _ in 0..2 {
            let _: f64 = sim_random();
        }
        assert_eq!(get_current_sim_seed(), 30);
        assert_eq!(get_rng_call_count(), 1);
    }

    #[test]
    fn test_clear_rng_breakpoints() {
        reset_sim_rng();
        set_sim_seed(5);
        set_rng_breakpoints(vec![(1, 6)]);
        clear_rng_breakpoints();

        // Without the cleared breakpoint, the second call would have reseeded.
        let _: f64 = sim_random();
        let _: f64 = sim_random();
        assert_eq!(get_current_sim_seed(), 5);
        assert_eq!(get_rng_call_count(), 2);
    }

    #[test]
    fn test_replay_determinism() {
        reset_sim_rng();

        // Original run: 3 draws on seed 42, then fork to seed 99.
        set_sim_seed(42);
        let _: f64 = sim_random();
        let _: f64 = sim_random();
        let _: f64 = sim_random();
        let fork_count = get_rng_call_count();
        set_sim_seed(99);
        reset_rng_call_count();
        let post_fork_1: f64 = sim_random();
        let post_fork_2: f64 = sim_random();

        // Replay: a breakpoint reproduces the fork automatically.
        reset_sim_rng();
        set_sim_seed(42);
        set_rng_breakpoints(vec![(fork_count, 99)]);
        let _: f64 = sim_random();
        let _: f64 = sim_random();
        let _: f64 = sim_random();
        let replay_1: f64 = sim_random();
        let replay_2: f64 = sim_random();

        assert_eq!(post_fork_1, replay_1);
        assert_eq!(post_fork_2, replay_2);
    }

    #[test]
    fn test_reset_clears_everything_including_breakpoints() {
        reset_sim_rng();
        set_sim_seed(42);
        let _: f64 = sim_random();
        let _: f64 = sim_random();
        set_rng_breakpoints(vec![(10, 99)]);

        assert_eq!(get_rng_call_count(), 2);

        reset_sim_rng();

        assert_eq!(get_rng_call_count(), 0);
        assert_eq!(get_current_sim_seed(), 0);

        // 11 draws would trip the (10, 99) breakpoint had reset not cleared it.
        set_sim_seed(42);
        for _ in 0..11 {
            let _: f64 = sim_random();
        }
        assert_eq!(get_rng_call_count(), 11);
        assert_eq!(get_current_sim_seed(), 42);
    }
}