Skip to main content

radiate_core/domain/
random_provider.rs

1use rand::distr::{Distribution, StandardUniform, uniform::SampleUniform};
2use rand::rngs::SmallRng;
3use rand::rngs::SysRng;
4use rand::seq::SliceRandom;
5use rand::{Rng, RngExt, SeedableRng};
6use std::cell::RefCell;
7use std::ops::Range;
8use std::sync::{Arc, LazyLock, Mutex};
9
10static GLOBAL_RNG: LazyLock<Arc<Mutex<SmallRng>>> =
11    LazyLock::new(|| Arc::new(Mutex::new(SmallRng::try_from_rng(&mut SysRng).unwrap())));
12
13thread_local! {
14    static TLS_RNG: RefCell<SmallRng> = RefCell::new({
15        let mut global = GLOBAL_RNG.lock().unwrap();
16        SmallRng::seed_from_u64(global.next_u64())
17    });
18}
19
20pub fn with_rng<R>(f: impl FnOnce(&mut RdRand<'_>) -> R) -> R {
21    TLS_RNG.with(|cell| {
22        let mut rng = cell.borrow_mut();
23        f(&mut RdRand::new(&mut rng))
24    })
25}
26
27/// Seeds the thread-local random number generator with the given seed.
28pub fn set_seed(seed: u64) {
29    let mut global = GLOBAL_RNG.lock().unwrap();
30    *global = SmallRng::seed_from_u64(seed);
31}
32
33/// Temporarily sets the seed of the thread-local random number generator to the given seed
34/// for the duration of the closure `f`. After `f` completes, the original state of the RNG is restored.
35pub fn scoped_seed<R>(seed: u64, f: impl FnOnce() -> R) -> R {
36    TLS_RNG.with(|cell| {
37        let original_seed = {
38            let mut rng = cell.borrow_mut();
39            let original = rng.clone();
40            *rng = SmallRng::seed_from_u64(seed);
41            original
42        };
43
44        let result = f();
45
46        let mut rng = cell.borrow_mut();
47        *rng = original_seed;
48
49        result
50    })
51}
52
53///
54/// For floating point types, the number will be in the range [0, 1).
55/// For integer types, the number will be in the range [0, MAX).
56#[inline(always)]
57pub fn random<T>() -> T
58where
59    T: SampleUniform,
60    StandardUniform: Distribution<T>,
61{
62    with_rng(|rng| rng.random())
63}
64
65/// Generates a random boolean with the given probability of being true.
66#[inline(always)]
67pub fn bool(prob: f32) -> bool {
68    with_rng(|rng| rng.bool(prob))
69}
70
71/// Generates a random number of type T in the given range.
72pub fn range<T>(range: Range<T>) -> T
73where
74    T: SampleUniform + PartialOrd,
75{
76    with_rng(|rng| rng.range(range))
77}
78
79/// Chooses a random item from the given slice.
80pub fn choose<T>(items: &[T]) -> &T {
81    with_rng(|rng| rng.choose(items))
82}
83
84pub fn choose_mut<T>(items: &mut [T]) -> &mut T {
85    with_rng(|rng| rng.choose_mut(items))
86}
87
88/// Generates a random number from a Gaussian distribution with the given mean and standard deviation.
89/// The Box-Muller transform is used to generate the random number.
90pub fn gaussian(mean: f64, std_dev: f64) -> f64 {
91    with_rng(|rng| rng.gaussian(mean, std_dev))
92}
93
94/// Shuffles the given slice in place.
95pub fn shuffle<T>(items: &mut [T]) {
96    with_rng(|rng| rng.shuffle(items));
97}
98
99/// Generates a vector of indexes from 0 to n-1 in random order.
100pub fn shuffled_indices(range: Range<usize>) -> Vec<usize> {
101    with_rng(|rng| rng.shuffled_indices(range))
102}
103
104pub fn sample_indices(range: Range<usize>, sample_size: usize) -> Vec<usize> {
105    with_rng(|rng| rng.sample_indices(range, sample_size))
106}
107
108/// Returns a vector of indexes from the given range, each included with the given probability.
109pub fn cond_indices(range: Range<usize>, prob: f32) -> Vec<usize> {
110    with_rng(|rng| rng.cond_indices(range, prob))
111}
112
113pub struct RdRand<'a>(&'a mut SmallRng);
114
115impl<'a> RdRand<'a> {
116    pub fn new(rng: &'a mut SmallRng) -> Self {
117        RdRand(rng)
118    }
119
120    #[inline]
121    pub fn random<T>(&mut self) -> T
122    where
123        T: SampleUniform,
124        StandardUniform: Distribution<T>,
125    {
126        self.0.random()
127    }
128
129    #[inline]
130    pub fn range<T>(&mut self, range: Range<T>) -> T
131    where
132        T: SampleUniform + PartialOrd,
133    {
134        self.0.random_range(range)
135    }
136
137    #[inline]
138    pub fn bool(&mut self, prob: f32) -> bool {
139        self.0.random_bool(prob as f64)
140    }
141
142    #[inline]
143    pub fn choose<'b, T>(&mut self, items: &'b [T]) -> &'b T {
144        let index = self.0.random_range(0..items.len());
145        &items[index]
146    }
147
148    #[inline]
149    pub fn choose_mut<'b, T>(&mut self, items: &'b mut [T]) -> &'b mut T {
150        let index = self.0.random_range(0..items.len());
151        &mut items[index]
152    }
153
154    #[inline]
155    pub fn shuffle<T>(&mut self, items: &mut [T]) {
156        items.shuffle(&mut self.0);
157    }
158
159    #[inline]
160    pub fn gaussian(&mut self, mean: f64, std_dev: f64) -> f64 {
161        let u1: f64 = self.0.random();
162        let u2: f64 = self.0.random();
163        let z0 = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
164        mean + std_dev * z0
165    }
166
167    #[inline]
168    pub fn shuffled_indices(&mut self, range: Range<usize>) -> Vec<usize> {
169        let mut indexes = range.collect::<Vec<usize>>();
170        indexes.shuffle(&mut self.0);
171        indexes
172    }
173
174    #[inline]
175    pub fn sample_indices(&mut self, range: Range<usize>, sample_size: usize) -> Vec<usize> {
176        let mut indexes = range.collect::<Vec<usize>>();
177        indexes.shuffle(&mut self.0);
178        indexes.truncate(sample_size);
179        indexes
180    }
181
182    #[inline]
183    pub fn cond_indices(&mut self, range: Range<usize>, prob: f32) -> Vec<usize> {
184        if prob >= 1.0 {
185            return range.collect();
186        }
187
188        if prob <= 0.0 {
189            return Vec::new();
190        }
191
192        range.filter(|_| self.0.random::<f32>() < prob).collect()
193    }
194}
195
#[cfg(test)]
mod tests {
    use super::*;

    /// `random::<f64>()` must stay inside `[0, 1)`.
    #[test]
    fn test_random() {
        for _ in 0..100 {
            let value: f64 = random();
            assert!((0.0..1.0).contains(&value));
        }
    }

    /// `range` must stay inside the requested half-open interval.
    #[test]
    fn test_gen_range() {
        for _ in 0..100 {
            let value: f64 = range(0.0..100.0);
            assert!((0.0..100.0).contains(&value));
        }
    }

    /// `choose` must return an element of the slice it was given.
    #[test]
    fn test_choose() {
        for _ in 0..100 {
            let items = vec![1, 2, 3, 4, 5];
            let value = choose(&items);
            assert!(items.contains(value));
        }
    }

    /// `shuffle` must produce a permutation of its input.
    #[test]
    fn test_shuffle() {
        // 100 elements make an identity shuffle (1 in 100!) a non-issue for the inequality
        // check; with only 10 elements it would fail spuriously about once in 3.6 million runs.
        let original: Vec<i32> = (1..=100).collect();
        let mut items = original.clone();
        shuffle(&mut items);
        assert_ne!(items, original);

        items.sort_unstable();
        assert_eq!(items, original);
    }

    /// `shuffled_indices` must return each index of the range exactly once.
    #[test]
    fn test_indexes() {
        let ordered: Vec<usize> = (0..100).collect();
        let mut indexes = shuffled_indices(0..100);
        assert_eq!(indexes.len(), 100);
        assert_ne!(indexes, ordered);

        indexes.sort_unstable();
        assert_eq!(indexes, ordered);
    }
}