Skip to main content

rosomaxa/utils/
random.rs

// Unit tests for this module are kept in a separate file (see the `path` attribute below)
// and are compiled only in test builds.
#[cfg(test)]
#[path = "../../tests/unit/utils/random_test.rs"]
mod random_test;
4
5use crate::utils::Float;
6use rand::prelude::*;
7use rand::Error;
8use rand_distr::{Gamma, Normal};
9use std::cell::RefCell;
10use std::cmp::Ordering;
11use std::sync::Arc;
12
/// Provides the way to sample from different distributions.
pub trait DistributionSampler {
    /// Returns a sample from the gamma distribution parameterized by `shape` and `scale`.
    fn gamma(&self, shape: Float, scale: Float) -> Float;

    /// Returns a sample from the normal distribution parameterized by `mean` and `std_dev`.
    fn normal(&self, mean: Float, std_dev: Float) -> Float;
}
21
/// Provides the way to use randomized values in a generic way.
pub trait Random: Send + Sync {
    /// Produces an integral random value, uniformly distributed on the closed interval [min, max].
    fn uniform_int(&self, min: i32, max: i32) -> i32;

    /// Produces a real random value, uniformly distributed on the half-open interval [min, max).
    fn uniform_real(&self, min: Float, max: Float) -> Float;

    /// Flips a coin and returns true if it is "heads", false otherwise.
    fn is_head_not_tails(&self) -> bool;

    /// Tests a probability value expected to be in the [0., 1.] range: returns true with
    /// that probability, false otherwise.
    fn is_hit(&self, probability: Float) -> bool;

    /// Returns an index into `weights`, where a larger weight makes its index more likely
    /// to be selected.
    /// Uses exponential distribution where the weights are the rate of the distribution (lambda)
    /// and selects the smallest sampled value.
    fn weighted(&self, weights: &[usize]) -> usize;

    /// Returns the underlying random number generator.
    fn get_rng(&self) -> RandomGen;
}
44
45/// Provides way to sample from different distributions.
46#[derive(Clone)]
47pub struct DefaultDistributionSampler(Arc<dyn Random>);
48
49impl DefaultDistributionSampler {
50    /// Creates a new instance of `DefaultDistributionSampler`.
51    pub fn new(random: Arc<dyn Random>) -> Self {
52        Self(random)
53    }
54
55    /// Returns a sample from gamma distribution.
56    pub fn sample_gamma(shape: Float, scale: Float, random: &dyn Random) -> Float {
57        Gamma::new(shape, scale)
58            .unwrap_or_else(|_| panic!("cannot create gamma dist: shape={shape}, scale={scale}"))
59            .sample(&mut random.get_rng())
60    }
61
62    /// Returns a sample from normal distribution.
63    pub fn sample_normal(mean: Float, std_dev: Float, random: &dyn Random) -> Float {
64        Normal::new(mean, std_dev)
65            .unwrap_or_else(|_| panic!("cannot create normal dist: mean={mean}, std_dev={std_dev}"))
66            .sample(&mut random.get_rng())
67    }
68}
69
70impl DistributionSampler for DefaultDistributionSampler {
71    fn gamma(&self, shape: Float, scale: Float) -> Float {
72        Self::sample_gamma(shape, scale, self.0.as_ref())
73    }
74
75    fn normal(&self, mean: Float, std_dev: Float) -> Float {
76        Self::sample_normal(mean, std_dev, self.0.as_ref())
77    }
78}
79
/// The default random implementation.
///
/// An instance created via `Default` is not repeatable; use [`DefaultRandom::new_repeatable`]
/// to get predictable values.
#[derive(Default)]
pub struct DefaultRandom {
    /// Passed to every `RandomGen` created by this instance to select the fixed-seed generator.
    use_repeatable: bool,
}

impl DefaultRandom {
    /// Builds a `DefaultRandom` whose random generation is repeatable (predictable).
    pub fn new_repeatable() -> Self {
        DefaultRandom { use_repeatable: true }
    }
}
92
93impl Random for DefaultRandom {
94    fn uniform_int(&self, min: i32, max: i32) -> i32 {
95        if min == max {
96            return min;
97        }
98
99        assert!(min < max);
100        self.get_rng().gen_range(min..max + 1)
101    }
102
103    fn uniform_real(&self, min: Float, max: Float) -> Float {
104        if (min - max).abs() < Float::EPSILON {
105            return min;
106        }
107
108        assert!(min < max);
109        self.get_rng().gen_range(min..max)
110    }
111
112    fn is_head_not_tails(&self) -> bool {
113        self.get_rng().gen_bool(0.5)
114    }
115
116    fn is_hit(&self, probability: Float) -> bool {
117        #![allow(clippy::unnecessary_cast)]
118        self.get_rng().gen_bool(probability.clamp(0., 1.) as f64)
119    }
120
121    fn weighted(&self, weights: &[usize]) -> usize {
122        weights
123            .iter()
124            .zip(0_usize..)
125            .map(|(&weight, index)| (-self.uniform_real(0., 1.).ln() / weight as Float, index))
126            .min_by(|a, b| a.0.partial_cmp(&b.0).unwrap())
127            .unwrap()
128            .1
129    }
130
131    fn get_rng(&self) -> RandomGen {
132        RandomGen { use_repeatable: self.use_repeatable }
133    }
134}
135
136thread_local! {
137    /// Random generator seeded from thread_rng to make runs non-repeatable.
138    static RANDOMIZED_RNG: RefCell<SmallRng> = RefCell::new(SmallRng::from_rng(thread_rng()).expect("cannot get RNG from thread rng"));
139
140    /// Random generator seeded with 0 SmallRng to make runs repeatable.
141    static REPEATABLE_RNG: RefCell<SmallRng> = RefCell::new(SmallRng::seed_from_u64(0));
142}
143
/// A handle which exposes the underlying random generator API.
#[derive(Clone, Debug)]
pub struct RandomGen {
    /// Selects the fixed-seed generator when true, the randomly seeded one otherwise.
    use_repeatable: bool,
}

impl RandomGen {
    /// Builds a `RandomGen` backed by the random generator with a fixed seed.
    pub fn new_repeatable() -> Self {
        RandomGen { use_repeatable: true }
    }

    /// Builds a `RandomGen` backed by the random generator with a randomized seed.
    pub fn new_randomized() -> Self {
        RandomGen { use_repeatable: false }
    }
}
161
162impl RngCore for RandomGen {
163    fn next_u32(&mut self) -> u32 {
164        // NOTE use 'likely!' macro for better branch prediction once it is stabilized?
165        if self.use_repeatable {
166            REPEATABLE_RNG.with(|t| t.borrow_mut().next_u32())
167        } else {
168            RANDOMIZED_RNG.with(|t| t.borrow_mut().next_u32())
169        }
170    }
171
172    fn next_u64(&mut self) -> u64 {
173        if self.use_repeatable {
174            REPEATABLE_RNG.with(|t| t.borrow_mut().next_u64())
175        } else {
176            RANDOMIZED_RNG.with(|t| t.borrow_mut().next_u64())
177        }
178    }
179
180    fn fill_bytes(&mut self, dest: &mut [u8]) {
181        if self.use_repeatable {
182            REPEATABLE_RNG.with(|t| t.borrow_mut().fill_bytes(dest))
183        } else {
184            RANDOMIZED_RNG.with(|t| t.borrow_mut().fill_bytes(dest))
185        }
186    }
187
188    fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> {
189        if self.use_repeatable {
190            REPEATABLE_RNG.with(|t| t.borrow_mut().try_fill_bytes(dest))
191        } else {
192            RANDOMIZED_RNG.with(|t| t.borrow_mut().try_fill_bytes(dest))
193        }
194    }
195}
196
// Marks `RandomGen` as a cryptographically secure generator.
// NOTE(review): the thread-local generators this type delegates to are `SmallRng`, and the
// repeatable one always starts from seed 0, so its output is predictable. Confirm that this
// marker is needed only to satisfy trait bounds and is never relied upon for security.
impl CryptoRng for RandomGen {}
198
199/// Returns an index of max element in values. In case of many same max elements,
200/// returns the one from them at random.
201pub fn random_argmax<I>(values: I, random: &dyn Random) -> Option<usize>
202where
203    I: Iterator<Item = Float>,
204{
205    let mut rng = random.get_rng();
206    let mut count = 0;
207    values
208        .enumerate()
209        .max_by(move |(_, r), (_, s)| match r.total_cmp(s) {
210            Ordering::Equal => {
211                count += 1;
212                if rng.gen_range(0..=count) == 0 {
213                    Ordering::Less
214                } else {
215                    Ordering::Greater
216                }
217            }
218            Ordering::Less => {
219                count = 0;
220                Ordering::Less
221            }
222            Ordering::Greater => Ordering::Greater,
223        })
224        .map(|(idx, _)| idx)
225}