// primitives/random/test_rng.rs

1use std::{cell::UnsafeCell, rc::Rc};
2
3use rand::{CryptoRng, RngCore, SeedableRng};
4use rand_chacha::ChaCha12Rng;
5
6use crate::random::Prng;
7
8/// Returns a test RNG that can be used for testing purposes.
9///
10/// If the environment variable `ASYNC_MPC_NON_DETERMINISTIC_TESTS` is set, it will use
11/// a random (non-deterministic) seed for the RNG. Otherwise, it will use the current
12/// thread's name as the seed (deterministic).
13pub fn test_rng() -> TestRng {
14    let rng = TEST_RNG_KEY.with(|t| t.clone());
15    TestRng { rng }
16}
17
18thread_local!(
19    static TEST_RNG_KEY: Rc<UnsafeCell<ChaCha12Rng>> = {
20    let rng = match std::env::var("ASYNC_MPC_NON_DETERMINISTIC_TESTS") {
21        Ok(_) => ChaCha12Rng::from_rng(rand::thread_rng()).unwrap_or_else(|err|
22                panic!("could not initialize test_rng: {err}")),
23        Err(_) => {
24            let thread = std::thread::current();
25            ChaCha12Rng::from_hashed_seed(thread.name().unwrap_or("async_mpc_test"))
26        }
27    };
28    Rc::new(UnsafeCell::new(rng))
29});
30
31/// A reference to the thread-local random number generator used for testing.
32///
33/// Based on [ThreadRng] from the rand crate, explicitly not Send or Sync.
34///
35/// [`ThreadRng`]: rand::ThreadRng
36#[derive(Clone, Debug)]
37pub struct TestRng {
38    // Rc is explicitly !Send and !Sync
39    rng: Rc<UnsafeCell<ChaCha12Rng>>,
40}
41
42impl Default for TestRng {
43    fn default() -> TestRng {
44        test_rng()
45    }
46}
47
48impl RngCore for TestRng {
49    #[inline(always)]
50    fn next_u32(&mut self) -> u32 {
51        // SAFETY: We must make sure to stop using `rng` before anyone else
52        // creates another mutable reference
53        let rng = unsafe { &mut *self.rng.get() };
54        rng.next_u32()
55    }
56
57    #[inline(always)]
58    fn next_u64(&mut self) -> u64 {
59        // SAFETY: We must make sure to stop using `rng` before anyone else
60        // creates another mutable reference
61        let rng = unsafe { &mut *self.rng.get() };
62        rng.next_u64()
63    }
64
65    fn fill_bytes(&mut self, dest: &mut [u8]) {
66        // SAFETY: We must make sure to stop using `rng` before anyone else
67        // creates another mutable reference
68        let rng = unsafe { &mut *self.rng.get() };
69        rng.fill_bytes(dest)
70    }
71
72    fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand::Error> {
73        // SAFETY: We must make sure to stop using `rng` before anyone else
74        // creates another mutable reference
75        let rng = unsafe { &mut *self.rng.get() };
76        rng.try_fill_bytes(dest)
77    }
78}
79
80impl CryptoRng for TestRng {}