1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
#[cfg(test)]
#[path = "../../tests/unit/utils/random_test.rs"]
mod random_test;

use rand::prelude::*;
use rand::Error;
use rand_distr::{Gamma, Normal};
use std::cell::RefCell;
use std::cmp::Ordering;
use std::sync::Arc;

/// Provides a way to sample from different probability distributions.
pub trait DistributionSampler {
    /// Returns a sample from the gamma distribution parameterized by `shape` and `scale`.
    fn gamma(&self, shape: f64, scale: f64) -> f64;

    /// Returns a sample from the normal distribution parameterized by `mean` and `std_dev`.
    fn normal(&self, mean: f64, std_dev: f64) -> f64;
}

/// Provides a generic way to use randomized values.
pub trait Random {
    /// Produces an integral random value, uniformly distributed on the closed interval [min, max].
    fn uniform_int(&self, min: i32, max: i32) -> i32;

    /// Produces a real random value, uniformly distributed on the half-open interval [min, max).
    fn uniform_real(&self, min: f64, max: f64) -> f64;

    /// Flips a coin and returns true if it is "heads", false otherwise.
    fn is_head_not_tails(&self) -> bool;

    /// Returns true with the given probability; `probability` is expected to lie
    /// between 0. and 1.
    fn is_hit(&self, probability: f64) -> bool;

    /// Returns an index into `weights`, chosen at random with probability proportional
    /// to the weight stored at that index.
    /// Uses exponential distribution where the weights are the rate of the distribution (lambda)
    /// and selects the index with the smallest sampled value.
    fn weighted(&self, weights: &[usize]) -> usize;

    /// Returns a handle to the underlying random number generator.
    fn get_rng(&self) -> RandomGen;
}

/// Default [`DistributionSampler`] which draws its randomness from a shared [`Random`].
#[derive(Clone)]
pub struct DefaultDistributionSampler(Arc<dyn Random + Send + Sync>);

impl DefaultDistributionSampler {
    /// Wraps the given `random` source into a new `DefaultDistributionSampler`.
    pub fn new(random: Arc<dyn Random + Send + Sync>) -> Self {
        DefaultDistributionSampler(random)
    }
}

impl DistributionSampler for DefaultDistributionSampler {
    /// Samples the gamma distribution with the given `shape` and `scale`.
    ///
    /// # Panics
    ///
    /// Panics when the distribution cannot be created from the given parameters.
    fn gamma(&self, shape: f64, scale: f64) -> f64 {
        // The distribution is built (and validated) before the generator is requested.
        let distribution = match Gamma::new(shape, scale) {
            Ok(distribution) => distribution,
            Err(_) => panic!("cannot create gamma dist: shape={shape}, scale={scale}"),
        };
        let mut rng = self.0.get_rng();

        distribution.sample(&mut rng)
    }

    /// Samples the normal distribution with the given `mean` and `std_dev`.
    ///
    /// # Panics
    ///
    /// Panics when the distribution cannot be created from the given parameters.
    fn normal(&self, mean: f64, std_dev: f64) -> f64 {
        let distribution = match Normal::new(mean, std_dev) {
            Ok(distribution) => distribution,
            Err(_) => panic!("cannot create normal dist: mean={mean}, std_dev={std_dev}"),
        };
        let mut rng = self.0.get_rng();

        distribution.sample(&mut rng)
    }
}

/// The default source of randomness.
///
/// `DefaultRandom::default()` leaves the repeatable mode switched off.
#[derive(Default)]
pub struct DefaultRandom {
    /// Mode flag handed over to every `RandomGen` produced by this instance.
    use_repeatable: bool,
}

impl DefaultRandom {
    /// Builds a `DefaultRandom` whose random generation is repeatable (predictable).
    pub fn new_repeatable() -> Self {
        DefaultRandom { use_repeatable: true }
    }
}

impl Random for DefaultRandom {
    /// Produces an integer uniformly distributed on the closed interval `[min, max]`.
    ///
    /// Returns `min` without touching the generator when `min == max`.
    ///
    /// # Panics
    ///
    /// Panics if `min > max`.
    fn uniform_int(&self, min: i32, max: i32) -> i32 {
        if min == max {
            return min;
        }

        assert!(min < max);
        // An inclusive range avoids computing `max + 1`, which overflows when `max == i32::MAX`.
        self.get_rng().gen_range(min..=max)
    }

    /// Produces a real value uniformly distributed on the half-open interval `[min, max)`.
    ///
    /// Returns `min` without touching the generator when the bounds differ by less than
    /// `f64::EPSILON`.
    ///
    /// # Panics
    ///
    /// Panics if `min < max` does not hold (which includes NaN bounds).
    fn uniform_real(&self, min: f64, max: f64) -> f64 {
        if (min - max).abs() < f64::EPSILON {
            return min;
        }

        assert!(min < max);
        self.get_rng().gen_range(min..max)
    }

    /// Flips a fair coin: returns `true` with probability 0.5.
    fn is_head_not_tails(&self) -> bool {
        self.get_rng().gen_bool(0.5)
    }

    /// Returns `true` with the given `probability`, clamped into `[0., 1.]` beforehand.
    fn is_hit(&self, probability: f64) -> bool {
        self.get_rng().gen_bool(probability.clamp(0., 1.))
    }

    /// Returns an index into `weights`, chosen with probability proportional to its weight.
    ///
    /// Every index gets a sample `-ln(u) / weight` with `u` drawn from `[0, 1)`, i.e. an
    /// exponentially distributed value whose rate is the weight; the index with the smallest
    /// sample wins. A zero weight produces an infinite sample, so such an index can win only
    /// when no other index has a finite sample.
    ///
    /// # Panics
    ///
    /// Panics if `weights` is empty.
    fn weighted(&self, weights: &[usize]) -> usize {
        weights
            .iter()
            .enumerate()
            .map(|(index, &weight)| (-self.uniform_real(0., 1.).ln() / weight as f64, index))
            // `total_cmp` is a total order on `f64`, so the comparison itself cannot panic.
            .min_by(|a, b| a.0.total_cmp(&b.0))
            .map(|(_, index)| index)
            .expect("cannot select a weighted index: weights must not be empty")
    }

    /// Returns a generator handle which uses the same mode (repeatable or not) as `self`.
    fn get_rng(&self) -> RandomGen {
        RandomGen { use_repeatable: self.use_repeatable }
    }
}

// Every thread owns its own pair of generators; `RandomGen` picks one of them on each call
// depending on its `use_repeatable` flag.
thread_local! {
    /// Per-thread `SmallRng` seeded from `thread_rng`, which makes runs non-repeatable.
    static RANDOMIZED_RNG: RefCell<SmallRng> = RefCell::new(SmallRng::from_rng(thread_rng()).expect("cannot get RNG from thread rng"));

    /// Per-thread `SmallRng` seeded with the fixed value 0, which makes runs repeatable.
    static REPEATABLE_RNG: RefCell<SmallRng> = RefCell::new(SmallRng::seed_from_u64(0));
}

/// Exposes the underlying random generator API.
///
/// The value itself only stores which generator mode (repeatable or randomized) to use.
#[derive(Clone, Debug)]
pub struct RandomGen {
    /// `true` selects the fixed-seed generator, `false` the randomly seeded one.
    use_repeatable: bool,
}

impl RandomGen {
    /// Builds a `RandomGen` backed by the random generator with a fixed seed.
    pub fn new_repeatable() -> Self {
        RandomGen { use_repeatable: true }
    }

    /// Builds a `RandomGen` backed by the random generator with a randomized seed.
    pub fn new_randomized() -> Self {
        RandomGen { use_repeatable: false }
    }
}

impl RngCore for RandomGen {
    fn next_u32(&mut self) -> u32 {
        // NOTE use 'likely!' macro for better branch prediction once it is stabilized?
        if self.use_repeatable {
            REPEATABLE_RNG.with(|t| t.borrow_mut().next_u32())
        } else {
            RANDOMIZED_RNG.with(|t| t.borrow_mut().next_u32())
        }
    }

    fn next_u64(&mut self) -> u64 {
        if self.use_repeatable {
            REPEATABLE_RNG.with(|t| t.borrow_mut().next_u64())
        } else {
            RANDOMIZED_RNG.with(|t| t.borrow_mut().next_u64())
        }
    }

    fn fill_bytes(&mut self, dest: &mut [u8]) {
        if self.use_repeatable {
            REPEATABLE_RNG.with(|t| t.borrow_mut().fill_bytes(dest))
        } else {
            RANDOMIZED_RNG.with(|t| t.borrow_mut().fill_bytes(dest))
        }
    }

    fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> {
        if self.use_repeatable {
            REPEATABLE_RNG.with(|t| t.borrow_mut().try_fill_bytes(dest))
        } else {
            RANDOMIZED_RNG.with(|t| t.borrow_mut().try_fill_bytes(dest))
        }
    }
}

// NOTE(review): `CryptoRng` marks a generator as cryptographically secure, yet `RandomGen`
// forwards to thread-local `SmallRng` instances, which the `rand` documentation describes as
// a non-cryptographic generator. Presumably this marker exists only to satisfy a trait bound
// elsewhere — confirm it is never relied on for security-sensitive randomness.
impl CryptoRng for RandomGen {}

/// Returns the index of the maximum element of `values`, or `None` when `values` is empty.
///
/// Elements are ordered with `f64::total_cmp`. When several elements share the maximum,
/// one of their indices is picked at random with equal chance.
pub fn random_argmax<I>(values: I, random: &dyn Random) -> Option<usize>
where
    I: Iterator<Item = f64>,
{
    let mut rng = random.get_rng();
    // How many later elements have tied with the current maximum so far.
    let mut ties = 0_i32;
    let mut best: Option<(usize, f64)> = None;

    for candidate in values.enumerate() {
        let replace = match best {
            None => true,
            Some((_, best_value)) => match best_value.total_cmp(&candidate.1) {
                Ordering::Greater => false,
                Ordering::Less => {
                    // A strictly larger value starts a new group of ties.
                    ties = 0;
                    true
                }
                Ordering::Equal => {
                    // The k-th tie takes over with chance 1 / (k + 1), which keeps every
                    // tied element equally likely to be the final pick.
                    ties += 1;
                    rng.gen_range(0..=ties) == 0
                }
            },
        };

        if replace {
            best = Some(candidate);
        }
    }

    best.map(|(idx, _)| idx)
}