radiate_core/domain/
random_provider.rs1use rand::distr::{Distribution, StandardUniform, uniform::SampleUniform};
2use rand::rngs::StdRng;
3use rand::seq::SliceRandom;
4use rand::{Rng, SeedableRng};
5use std::sync::{Arc, Mutex, OnceLock};
6
7struct RandomProvider {
8 rng: Arc<Mutex<StdRng>>,
9}
10
11impl RandomProvider {
12 pub(self) fn global() -> &'static RandomProvider {
14 static INSTANCE: OnceLock<RandomProvider> = OnceLock::new();
15
16 INSTANCE.get_or_init(|| RandomProvider {
17 rng: Arc::new(Mutex::new(StdRng::from_os_rng())),
18 })
19 }
20
21 pub(self) fn get_rng() -> StdRng {
22 let instance = Self::global();
23 let rng = instance.rng.lock().unwrap();
24 rng.clone()
25 }
26
27 pub(self) fn set_seed(seed: u64) {
29 let instance = Self::global();
30 let mut rng = instance.rng.lock().unwrap();
31 *rng = StdRng::seed_from_u64(seed);
32 }
33
34 pub(self) fn random<T>() -> T
36 where
37 T: SampleUniform,
38 StandardUniform: Distribution<T>,
39 {
40 let instance = Self::global();
41 let mut rng = instance.rng.lock().unwrap();
42 rng.random()
43 }
44
45 pub(self) fn range<T>(range: std::ops::Range<T>) -> T
46 where
47 T: SampleUniform + PartialOrd,
48 {
49 let instance = Self::global();
50 let mut rng = instance.rng.lock().unwrap();
51 rng.random_range(range)
52 }
53
54 pub(self) fn bool(prob: f64) -> bool {
55 let instance = Self::global();
56 let mut rng = instance.rng.lock().unwrap();
57 rng.random_bool(prob)
58 }
59}
60
61pub fn set_seed(seed: u64) {
63 RandomProvider::set_seed(seed);
64}
65
66pub fn random<T>() -> T
71where
72 T: SampleUniform,
73 StandardUniform: Distribution<T>,
74{
75 RandomProvider::random()
76}
77
78pub fn bool(prob: f64) -> bool {
80 RandomProvider::bool(prob)
81}
82
83pub fn range<T>(range: std::ops::Range<T>) -> T
85where
86 T: SampleUniform + PartialOrd,
87{
88 RandomProvider::range(range)
89}
90
91pub fn choose<T>(items: &[T]) -> &T {
93 let index = range(0..items.len());
94 &items[index]
95}
96
97pub fn gaussian(mean: f64, std_dev: f64) -> f64 {
100 let u1: f64 = RandomProvider::random();
101 let u2: f64 = RandomProvider::random();
102
103 let z0 = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
104
105 mean + std_dev * z0
106}
107
108pub fn shuffle<T>(items: &mut [T]) {
110 let instance = RandomProvider::global();
111 items.shuffle(&mut *instance.rng.lock().unwrap());
112}
113
114pub fn indexes(range: std::ops::Range<usize>) -> Vec<usize> {
116 let mut indexes = range.collect::<Vec<usize>>();
117 shuffle(&mut indexes);
118 indexes
119}
120
121pub fn weighted_choice(weights: &[f32]) -> usize {
122 let mut rng = RandomProvider::get_rng();
123
124 let mut cumulative_weights = vec![0.0; weights.len()];
125 cumulative_weights[0] = weights[0];
126
127 for i in 1..weights.len() {
128 cumulative_weights[i] = cumulative_weights[i - 1] + weights[i];
129 }
130
131 let random_value = rng.random_range(0.0..*cumulative_weights.last().unwrap());
132 cumulative_weights
133 .iter()
134 .position(|&x| x > random_value)
135 .unwrap_or(weights.len() - 1)
136}
137
#[cfg(test)]
mod tests {
    use super::*;

    // These tests avoid `set_seed`: the generator is a process-wide static shared
    // by tests running in parallel, so seeded sequences would race each other.

    #[test]
    fn test_random() {
        for _ in 0..100 {
            let value: f64 = random();
            assert!((0.0..1.0).contains(&value));
        }
    }

    #[test]
    fn test_gen_range() {
        for _ in 0..100 {
            let value: f64 = range(0.0..100.0);
            assert!((0.0..100.0).contains(&value));
        }
    }

    #[test]
    fn test_bool_extremes() {
        for _ in 0..100 {
            assert!(!bool(0.0));
            assert!(bool(1.0));
        }
    }

    #[test]
    fn test_choose() {
        let items = vec![1, 2, 3, 4, 5];
        for _ in 0..100 {
            let value = choose(&items);
            assert!(items.contains(value));
        }
        assert_eq!(*choose(&[42]), 42);
    }

    #[test]
    fn test_gaussian() {
        for _ in 0..100 {
            assert!(gaussian(0.0, 1.0).is_finite());
            assert_eq!(gaussian(3.5, 0.0), 3.5);
        }
    }

    #[test]
    fn test_shuffle() {
        let original: Vec<i32> = (0..100).collect();
        let mut items = original.clone();
        shuffle(&mut items);

        // A shuffle must be a permutation of its input...
        let mut sorted = items.clone();
        sorted.sort_unstable();
        assert_eq!(sorted, original);
        // ...and with 100! possible orderings it should never come back unchanged.
        assert_ne!(items, original);
    }

    #[test]
    fn test_indexes() {
        let indexes = indexes(0..100);
        assert_eq!(indexes.len(), 100);

        let mut sorted = indexes.clone();
        sorted.sort_unstable();
        assert_eq!(sorted, (0..100).collect::<Vec<usize>>());
        assert_ne!(indexes, sorted);

        assert!(super::indexes(0..0).is_empty());
    }

    #[test]
    fn test_weighted_choice_single_positive_weight() {
        for _ in 0..100 {
            assert_eq!(weighted_choice(&[0.0, 0.0, 2.5, 0.0]), 2);
        }
    }

    #[test]
    fn test_weighted_choice_in_bounds() {
        let weights = [1.0, 2.0, 3.0, 4.0];
        for _ in 0..100 {
            assert!(weighted_choice(&weights) < weights.len());
        }
    }
}