radiate_core/domain/
random_provider.rs

1use rand::distr::{Distribution, StandardUniform, uniform::SampleUniform};
2use rand::rngs::StdRng;
3use rand::seq::SliceRandom;
4use rand::{Rng, SeedableRng};
5use std::sync::{Arc, Mutex, OnceLock};
6
7struct RandomProvider {
8    rng: Arc<Mutex<StdRng>>,
9}
10
11impl RandomProvider {
12    /// Returns the global instance of the registry.
13    pub(self) fn global() -> &'static RandomProvider {
14        static INSTANCE: OnceLock<RandomProvider> = OnceLock::new();
15
16        INSTANCE.get_or_init(|| RandomProvider {
17            rng: Arc::new(Mutex::new(StdRng::from_os_rng())),
18        })
19    }
20
21    pub(self) fn get_rng() -> StdRng {
22        let instance = Self::global();
23        let rng = instance.rng.lock().unwrap();
24        rng.clone()
25    }
26
27    /// Sets a new seed for the global RNG.
28    pub(self) fn set_seed(seed: u64) {
29        let instance = Self::global();
30        let mut rng = instance.rng.lock().unwrap();
31        *rng = StdRng::seed_from_u64(seed);
32    }
33
34    /// Generates a random number using the global RNG.
35    pub(self) fn random<T>() -> T
36    where
37        T: SampleUniform,
38        StandardUniform: Distribution<T>,
39    {
40        let instance = Self::global();
41        let mut rng = instance.rng.lock().unwrap();
42        rng.random()
43    }
44
45    pub(self) fn range<T>(range: std::ops::Range<T>) -> T
46    where
47        T: SampleUniform + PartialOrd,
48    {
49        let instance = Self::global();
50        let mut rng = instance.rng.lock().unwrap();
51        rng.random_range(range)
52    }
53
54    pub(self) fn bool(prob: f64) -> bool {
55        let instance = Self::global();
56        let mut rng = instance.rng.lock().unwrap();
57        rng.random_bool(prob)
58    }
59}
60
61/// Seeds the thread-local random number generator with the given seed.
62pub fn set_seed(seed: u64) {
63    RandomProvider::set_seed(seed);
64}
65
66/// Generates a random number of type T.
67///
68/// For floating point types, the number will be in the range [0, 1).
69/// For integer types, the number will be in the range [0, MAX).
70pub fn random<T>() -> T
71where
72    T: SampleUniform,
73    StandardUniform: Distribution<T>,
74{
75    RandomProvider::random()
76}
77
78/// Generates a random boolean with the given probability of being true.
79pub fn bool(prob: f64) -> bool {
80    RandomProvider::bool(prob)
81}
82
83/// Generates a random number of type T in the given range.
84pub fn range<T>(range: std::ops::Range<T>) -> T
85where
86    T: SampleUniform + PartialOrd,
87{
88    RandomProvider::range(range)
89}
90
91/// Chooses a random item from the given slice.
92pub fn choose<T>(items: &[T]) -> &T {
93    let index = range(0..items.len());
94    &items[index]
95}
96
97/// Generates a random number from a Gaussian distribution with the given mean and standard deviation.
98/// The Box-Muller transform is used to generate the random number.
99pub fn gaussian(mean: f64, std_dev: f64) -> f64 {
100    let u1: f64 = RandomProvider::random();
101    let u2: f64 = RandomProvider::random();
102
103    let z0 = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
104
105    mean + std_dev * z0
106}
107
108/// Shuffles the given slice in place.
109pub fn shuffle<T>(items: &mut [T]) {
110    let instance = RandomProvider::global();
111    items.shuffle(&mut *instance.rng.lock().unwrap());
112}
113
114/// Generates a vector of indexes from 0 to n-1 in random order.
115pub fn indexes(range: std::ops::Range<usize>) -> Vec<usize> {
116    let mut indexes = range.collect::<Vec<usize>>();
117    shuffle(&mut indexes);
118    indexes
119}
120
121pub fn weighted_choice(weights: &[f32]) -> usize {
122    let mut rng = RandomProvider::get_rng();
123
124    let mut cumulative_weights = vec![0.0; weights.len()];
125    cumulative_weights[0] = weights[0];
126
127    for i in 1..weights.len() {
128        cumulative_weights[i] = cumulative_weights[i - 1] + weights[i];
129    }
130
131    let random_value = rng.random_range(0.0..*cumulative_weights.last().unwrap());
132    cumulative_weights
133        .iter()
134        .position(|&x| x > random_value)
135        .unwrap_or(weights.len() - 1)
136}
137
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_random() {
        for _ in 0..100 {
            let value: f64 = random();
            assert!((0.0..1.0).contains(&value));
        }
    }

    #[test]
    fn test_gen_range() {
        for _ in 0..100 {
            let value: f64 = range(0.0..100.0);
            assert!((0.0..100.0).contains(&value));
        }
    }

    #[test]
    fn test_bool_extremes() {
        for _ in 0..100 {
            assert!(!bool(0.0));
            assert!(bool(1.0));
        }
    }

    #[test]
    fn test_choose() {
        let items = vec![1, 2, 3, 4, 5];
        for _ in 0..100 {
            let value = choose(&items);
            assert!(items.contains(value));
        }
    }

    #[test]
    fn test_gaussian() {
        for _ in 0..100 {
            assert!(gaussian(0.0, 1.0).is_finite());
            assert_eq!(gaussian(5.0, 0.0), 5.0);
        }
    }

    #[test]
    fn test_shuffle() {
        let original: Vec<i32> = (0..100).collect();
        let mut items = original.clone();
        shuffle(&mut items);
        // With 100 elements the chance of an identity shuffle is 1/100!, so this cannot
        // realistically flake (10 elements gave roughly 1 in 3.6 million).
        assert_ne!(items, original);
        items.sort_unstable();
        assert_eq!(items, original);
    }

    #[test]
    fn test_indexes() {
        let mut shuffled = indexes(0..100);
        assert_eq!(shuffled.len(), 100);
        assert_ne!(shuffled, (0..100).collect::<Vec<usize>>());
        shuffled.sort_unstable();
        assert_eq!(shuffled, (0..100).collect::<Vec<usize>>());

        let mut offset = indexes(5..10);
        offset.sort_unstable();
        assert_eq!(offset, vec![5, 6, 7, 8, 9]);
    }

    #[test]
    fn test_weighted_choice_skips_zero_weights() {
        let weights = [0.0, 1.0, 0.0, 3.0];
        for _ in 0..100 {
            let index = weighted_choice(&weights);
            assert!(index == 1 || index == 3);
        }
    }

    #[test]
    fn test_weighted_choice_single_weight() {
        assert_eq!(weighted_choice(&[2.5]), 0);
    }
}
180}