radiate_core/domain/
random_provider.rs1use rand::distr::{Distribution, StandardUniform, uniform::SampleUniform};
2use rand::rngs::SmallRng;
3use rand::rngs::SysRng;
4use rand::seq::SliceRandom;
5use rand::{Rng, RngExt, SeedableRng};
6use std::cell::RefCell;
7use std::ops::Range;
8use std::sync::{Arc, LazyLock, Mutex};
9
10static GLOBAL_RNG: LazyLock<Arc<Mutex<SmallRng>>> =
11 LazyLock::new(|| Arc::new(Mutex::new(SmallRng::try_from_rng(&mut SysRng).unwrap())));
12
13thread_local! {
14 static TLS_RNG: RefCell<SmallRng> = RefCell::new({
15 let mut global = GLOBAL_RNG.lock().unwrap();
16 SmallRng::seed_from_u64(global.next_u64())
17 });
18}
19
20pub fn with_rng<R>(f: impl FnOnce(&mut RdRand<'_>) -> R) -> R {
21 TLS_RNG.with(|cell| {
22 let mut rng = cell.borrow_mut();
23 f(&mut RdRand::new(&mut rng))
24 })
25}
26
27pub fn set_seed(seed: u64) {
29 let mut global = GLOBAL_RNG.lock().unwrap();
30 *global = SmallRng::seed_from_u64(seed);
31}
32
33pub fn scoped_seed<R>(seed: u64, f: impl FnOnce() -> R) -> R {
36 TLS_RNG.with(|cell| {
37 let original_seed = {
38 let mut rng = cell.borrow_mut();
39 let original = rng.clone();
40 *rng = SmallRng::seed_from_u64(seed);
41 original
42 };
43
44 let result = f();
45
46 let mut rng = cell.borrow_mut();
47 *rng = original_seed;
48
49 result
50 })
51}
52
53#[inline(always)]
57pub fn random<T>() -> T
58where
59 T: SampleUniform,
60 StandardUniform: Distribution<T>,
61{
62 with_rng(|rng| rng.random())
63}
64
65#[inline(always)]
67pub fn bool(prob: f32) -> bool {
68 with_rng(|rng| rng.bool(prob))
69}
70
71pub fn range<T>(range: Range<T>) -> T
73where
74 T: SampleUniform + PartialOrd,
75{
76 with_rng(|rng| rng.range(range))
77}
78
79pub fn choose<T>(items: &[T]) -> &T {
81 with_rng(|rng| rng.choose(items))
82}
83
84pub fn choose_mut<T>(items: &mut [T]) -> &mut T {
85 with_rng(|rng| rng.choose_mut(items))
86}
87
88pub fn gaussian(mean: f64, std_dev: f64) -> f64 {
91 with_rng(|rng| rng.gaussian(mean, std_dev))
92}
93
94pub fn shuffle<T>(items: &mut [T]) {
96 with_rng(|rng| rng.shuffle(items));
97}
98
99pub fn shuffled_indices(range: Range<usize>) -> Vec<usize> {
101 with_rng(|rng| rng.shuffled_indices(range))
102}
103
104pub fn sample_indices(range: Range<usize>, sample_size: usize) -> Vec<usize> {
105 with_rng(|rng| rng.sample_indices(range, sample_size))
106}
107
108pub fn cond_indices(range: Range<usize>, prob: f32) -> Vec<usize> {
110 with_rng(|rng| rng.cond_indices(range, prob))
111}
112
113pub struct RdRand<'a>(&'a mut SmallRng);
114
115impl<'a> RdRand<'a> {
116 pub fn new(rng: &'a mut SmallRng) -> Self {
117 RdRand(rng)
118 }
119
120 #[inline]
121 pub fn random<T>(&mut self) -> T
122 where
123 T: SampleUniform,
124 StandardUniform: Distribution<T>,
125 {
126 self.0.random()
127 }
128
129 #[inline]
130 pub fn range<T>(&mut self, range: Range<T>) -> T
131 where
132 T: SampleUniform + PartialOrd,
133 {
134 self.0.random_range(range)
135 }
136
137 #[inline]
138 pub fn bool(&mut self, prob: f32) -> bool {
139 self.0.random_bool(prob as f64)
140 }
141
142 #[inline]
143 pub fn choose<'b, T>(&mut self, items: &'b [T]) -> &'b T {
144 let index = self.0.random_range(0..items.len());
145 &items[index]
146 }
147
148 #[inline]
149 pub fn choose_mut<'b, T>(&mut self, items: &'b mut [T]) -> &'b mut T {
150 let index = self.0.random_range(0..items.len());
151 &mut items[index]
152 }
153
154 #[inline]
155 pub fn shuffle<T>(&mut self, items: &mut [T]) {
156 items.shuffle(&mut self.0);
157 }
158
159 #[inline]
160 pub fn gaussian(&mut self, mean: f64, std_dev: f64) -> f64 {
161 let u1: f64 = self.0.random();
162 let u2: f64 = self.0.random();
163 let z0 = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
164 mean + std_dev * z0
165 }
166
167 #[inline]
168 pub fn shuffled_indices(&mut self, range: Range<usize>) -> Vec<usize> {
169 let mut indexes = range.collect::<Vec<usize>>();
170 indexes.shuffle(&mut self.0);
171 indexes
172 }
173
174 #[inline]
175 pub fn sample_indices(&mut self, range: Range<usize>, sample_size: usize) -> Vec<usize> {
176 let mut indexes = range.collect::<Vec<usize>>();
177 indexes.shuffle(&mut self.0);
178 indexes.truncate(sample_size);
179 indexes
180 }
181
182 #[inline]
183 pub fn cond_indices(&mut self, range: Range<usize>, prob: f32) -> Vec<usize> {
184 if prob >= 1.0 {
185 return range.collect();
186 }
187
188 if prob <= 0.0 {
189 return Vec::new();
190 }
191
192 range.filter(|_| self.0.random::<f32>() < prob).collect()
193 }
194}
195
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_random() {
        for _ in 0..100 {
            let value: f64 = random();
            assert!((0.0..1.0).contains(&value));
        }
    }

    #[test]
    fn test_gen_range() {
        for _ in 0..100 {
            let value: f64 = range(0.0..100.0);
            assert!((0.0..100.0).contains(&value));
        }
    }

    #[test]
    fn test_choose() {
        let items = vec![1, 2, 3, 4, 5];
        for _ in 0..100 {
            let value = choose(&items);
            assert!(items.contains(value));
        }
    }

    #[test]
    fn test_shuffle() {
        let original: Vec<i32> = (1..=10).collect();
        let mut items = original.clone();
        // A fixed seed makes the "order changed" check deterministic. Unseeded, a shuffle
        // can legitimately return the identity permutation (probability 1/10!).
        scoped_seed(42, || shuffle(&mut items));
        assert_ne!(items, original);
        items.sort_unstable();
        assert_eq!(items, original);
    }

    #[test]
    fn test_indexes() {
        let mut indexes = scoped_seed(42, || shuffled_indices(0..10));
        assert_eq!(indexes.len(), 10);
        assert_ne!(indexes, (0..10).collect::<Vec<usize>>());
        indexes.sort_unstable();
        assert_eq!(indexes, (0..10).collect::<Vec<usize>>());
    }

    #[test]
    fn test_scoped_seed_is_reproducible() {
        let first: Vec<u64> = scoped_seed(7, || (0..5).map(|_| random::<u64>()).collect());
        let second: Vec<u64> = scoped_seed(7, || (0..5).map(|_| random::<u64>()).collect());
        assert_eq!(first, second);
    }

    #[test]
    fn test_scoped_seed_restores_previous_state() {
        let undisturbed: u64 = scoped_seed(1, || random());
        let after_nested: u64 = scoped_seed(1, || {
            let _: u64 = scoped_seed(2, || random());
            random()
        });
        assert_eq!(undisturbed, after_nested);
    }

    #[test]
    fn test_gaussian_is_finite() {
        for _ in 0..1000 {
            assert!(gaussian(0.0, 1.0).is_finite());
        }
    }

    #[test]
    fn test_cond_indices() {
        assert_eq!(cond_indices(0..5, 1.0), vec![0, 1, 2, 3, 4]);
        assert!(cond_indices(0..5, 0.0).is_empty());
        let picked = cond_indices(0..100, 0.5);
        assert!(picked.windows(2).all(|pair| pair[0] < pair[1]));
        assert!(picked.iter().all(|&i| i < 100));
    }

    #[test]
    fn test_sample_indices() {
        let mut sample = sample_indices(0..10, 4);
        assert_eq!(sample.len(), 4);
        sample.sort_unstable();
        sample.dedup();
        assert_eq!(sample.len(), 4);
        assert!(sample.iter().all(|&i| i < 10));
        assert_eq!(sample_indices(0..3, 10).len(), 3);
    }
}