rng_buffer/
lib.rs

1// Copyright ©️ 2024 Chris Hennick
2#![cfg_attr(not(feature = "std"), no_std)]
3
4extern crate alloc;
5
6use alloc::rc::Rc;
7use bytemuck::cast_slice_mut;
8use core::cell::RefCell;
9use core::mem::size_of;
10use delegate::delegate;
11use rand::rngs::adapter::ReseedingRng;
12use rand_chacha::ChaCha12Core;
13use rand_core::{CryptoRng, Error, OsRng, RngCore, SeedableRng};
14use rand_core::block::{BlockRng64, BlockRngCore};
15
/// Fixed-size array newtype whose [Default] impl fills every slot with `T::default()`.
///
/// `core` only implements `Default` for arrays up to length 32, so this wrapper is what lets an
/// arbitrarily long array serve as a `BlockRngCore::Results` buffer (which must be `Default`,
/// `AsRef<[T]>` and `AsMut<[T]>`).
#[repr(transparent)]
#[derive(Clone, Copy)]
pub struct DefaultableArray<const N: usize, T: Default + Copy>([T; N]);

impl<const N: usize, T> Default for DefaultableArray<N, T>
where
    T: Default + Copy,
{
    fn default() -> Self {
        // `T: Copy` is what allows the `[value; N]` repeat expression.
        let fill = T::default();
        Self([fill; N])
    }
}

impl<const N: usize, T> AsMut<[T; N]> for DefaultableArray<N, T>
where
    T: Default + Copy,
{
    fn as_mut(&mut self) -> &mut [T; N] {
        let DefaultableArray(inner) = self;
        inner
    }
}

impl<const N: usize, T> AsRef<[T; N]> for DefaultableArray<N, T>
where
    T: Default + Copy,
{
    fn as_ref(&self) -> &[T; N] {
        let DefaultableArray(inner) = self;
        inner
    }
}

impl<const N: usize, T> AsRef<[T]> for DefaultableArray<N, T>
where
    T: Default + Copy,
{
    fn as_ref(&self) -> &[T] {
        &self.0[..]
    }
}

impl<const N: usize, T> AsMut<[T]> for DefaultableArray<N, T>
where
    T: Default + Copy,
{
    fn as_mut(&mut self) -> &mut [T] {
        &mut self.0[..]
    }
}
50
51/// Wrapper around an [RngCore] that fills an 8*[N]-byte buffer at a time in order to make fewer system calls.
52#[derive(Copy, Clone, Debug)]
53#[repr(transparent)]
54pub struct RngBufferCore<const N: usize, T: RngCore>(pub T);
55
56const WORDS_PER_STD_RNG_SEED: usize = 4;
57const DEFAULT_SEEDS_PER_BUFFER: usize = 16;
58const DEFAULT_BUFFER_SIZE: usize = WORDS_PER_STD_RNG_SEED * DEFAULT_SEEDS_PER_BUFFER;
59
60impl <const N: usize, T: RngCore> BlockRngCore for RngBufferCore<N, T> {
61    type Item = u64;
62    type Results = DefaultableArray<N, u64>;
63
64    fn generate(&mut self, results: &mut Self::Results) {
65        self.0.fill_bytes(cast_slice_mut(results.as_mut()));
66    }
67}
68
69impl <const N: usize, T: RngCore> From<T> for RngBufferCore<N, T> {
70    fn from(value: T) -> Self {
71        Self(value)
72    }
73}
74
75/// Wraps an [RngBufferCore] using a [BlockRng64]. Also wraps it in an [Rc] and [RefCell] so that the buffer will be
76/// shared with all clones of the instance in the same thread. (This buffer isn't meant to be shared between threads,
77/// because benchmarks indicate that the overhead cost of communication between threads is usually larger than that of
78/// the system call that an [OsRng] makes.)
79#[derive(Clone)]
80#[repr(transparent)]
81pub struct RngBufferWrapper<const N: usize, T: RngCore>(Rc<RefCell<BlockRng64<RngBufferCore<N, T>>>>);
82
83impl <const N: usize, T: RngCore> From<T> for RngBufferWrapper<N, T> {
84    fn from(value: T) -> Self {
85        Self(Rc::new(RefCell::new(BlockRng64::new(value.into()))))
86    }
87}
88
89/// Wraps an RNG in an [Rc] and [RefCell] so that it can be shared (within the same thread) across structs that expect
90/// to own one.
91#[derive(Clone)]
92#[repr(transparent)]
93pub struct RngWrapper<T: RngCore>(Rc<RefCell<T>>);
94
95impl <T: RngCore> From<T> for RngWrapper<T> {
96    fn from(value: T) -> Self {
97        Self(Rc::new(RefCell::new(value)))
98    }
99}
100
101// This isn't implemented for RngBufferWrapper because the buffering loses fast key erasure if the underlying RNG has
102// that feature.
103impl <T: RngCore + CryptoRng> CryptoRng for RngWrapper<T> {}
104
105impl <const N: usize, T: RngCore> RngCore for RngBufferWrapper<N, T> {
106    delegate!{
107        to self.0.as_ref().borrow_mut().core.0 {
108            fn next_u32(&mut self) -> u32;
109            fn next_u64(&mut self) -> u64;
110        }
111    }
112
113    fn fill_bytes(&mut self, dest: &mut [u8]) {
114        self.try_fill_bytes(dest).expect("RngBufferWrapper core threw an error from try_fill_bytes")
115    }
116
117    fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> {
118        if dest.len() >= N * size_of::<u64>() {
119            self.0.as_ref().borrow_mut().core.0.try_fill_bytes(dest)
120        } else {
121            unsafe { self.0.as_ptr().as_mut().unwrap().try_fill_bytes(dest) }
122        }
123    }
124}
125
126
127impl <T: RngCore> RngCore for RngWrapper<T> {
128    delegate!{
129        to self.0.borrow_mut() {
130            fn next_u32(&mut self) -> u32;
131            fn next_u64(&mut self) -> u64;
132            fn fill_bytes(&mut self, dest: &mut [u8]);
133            fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error>;
134        }
135    }
136}
137
138/// A wrapper around [OsRng] that uses an [RngBufferCore] with a reasonable default buffer size.
139pub type DefaultSeedSourceRng = RngBufferWrapper<DEFAULT_BUFFER_SIZE, OsRng>;
140
141/// Creates an instance of [DefaultSeedSourceRng] that doesn't share state with any other instance.
142pub fn build_default_seeder() -> DefaultSeedSourceRng {
143   OsRng::default().into()
144}
145
146impl Default for DefaultSeedSourceRng {
147    #[cfg(feature = "std")]
148    fn default() -> Self {
149        thread_seed_source()
150    }
151    #[cfg(not(feature = "std"))]
152    fn default() -> Self {
153        build_default_seeder()
154    }
155}
156
157/// A drop-in replacement for [rand::ThreadRng] that behaves identically, except that it uses an [RngBufferCore] to
158/// wrap the [OsRng] that it uses to reseed itself.
159pub type DefaultRng = RngWrapper<ReseedingRng<ChaCha12Core, DefaultSeedSourceRng>>;
160
161/// Creates an instance of [DefaultRng] that uses the given seed source.
162pub fn build_default_rng(mut seeder: DefaultSeedSourceRng) -> DefaultRng {
163    let mut seed = [0u8; 32];
164    seeder.fill_bytes(&mut seed);
165    ReseedingRng::new(ChaCha12Core::from_seed(seed), THREAD_RNG_RESEED_THRESHOLD, seeder).into()
166}
167
168impl Default for DefaultRng {
169    #[cfg(feature = "std")]
170    fn default() -> Self {
171        thread_rng()
172    }
173
174    #[cfg(not(feature = "std"))]
175    fn default() -> Self {
176        build_default_rng(DefaultSeedSourceRng::default())
177    }
178}
179
/// Number of bytes a [DefaultRng] generates before reseeding from its seeder (64 KiB).
const THREAD_RNG_RESEED_THRESHOLD: u64 = 64 * 1024;
181
#[cfg(feature = "std")]
thread_local! {
    /// Per-thread buffered [OsRng] seed source.
    static THREAD_SEEDER_KEY: DefaultSeedSourceRng = build_default_seeder();

    /// Per-thread [DefaultRng]. It reseeds from a clone of `THREAD_SEEDER_KEY`, so the two
    /// share one buffer.
    static THREAD_RNG_KEY: DefaultRng = THREAD_SEEDER_KEY.with(|seed_source| {
        build_default_rng(DefaultSeedSourceRng::clone(seed_source))
    });
}
189
/// Returns a handle to this thread's shared [DefaultSeedSourceRng]; every handle obtained on
/// the same thread draws from the same buffer.
#[cfg(feature = "std")]
pub fn thread_seed_source() -> DefaultSeedSourceRng {
    THREAD_SEEDER_KEY.with(|seeder| seeder.clone())
}
195
/// Returns a handle to this thread's shared [DefaultRng]. It is the counterpart of
/// [rand::thread_rng](), except that it reseeds from [thread_seed_source]() rather than
/// directly invoking [OsRng].
#[cfg(feature = "std")]
pub fn thread_rng() -> DefaultRng {
    THREAD_RNG_KEY.with(|rng| rng.clone())
}
202
#[cfg(test)]
mod tests {
    use crate::{build_default_seeder, DefaultSeedSourceRng};
    use rand_core::Error;

    /// A `StdRng` seeded from the buffered seeder must differ from one seeded with all zeros.
    #[test]
    fn basic_test() -> Result<(), Error> {
        use rand::rngs::StdRng;
        use rand::SeedableRng;

        let seeder: DefaultSeedSourceRng = build_default_seeder();
        let seeded = StdRng::from_rng(seeder)?;
        let all_zero = StdRng::from_seed([0; 32]);
        assert_ne!(seeded, all_zero);
        Ok(())
    }
}