1#[cfg(test)]
2#[path = "../../tests/unit/utils/random_test.rs"]
3mod random_test;
4
5use crate::utils::Float;
6use rand::prelude::*;
7use rand::Error;
8use rand_distr::{Gamma, Normal};
9use std::cell::RefCell;
10use std::cmp::Ordering;
11use std::sync::Arc;
12
/// Draws single samples from parametric probability distributions.
///
/// [`DefaultDistributionSampler`] implements this trait on top of a [`Random`] using
/// `rand_distr`'s `Gamma` and `Normal` distributions.
pub trait DistributionSampler {
    /// Returns one sample from a gamma distribution parameterised by `shape` and `scale`.
    fn gamma(&self, shape: Float, scale: Float) -> Float;

    /// Returns one sample from a normal distribution with the given `mean` and `std_dev`.
    fn normal(&self, mean: Float, std_dev: Float) -> Float;
}
21
/// Source of randomness used throughout the crate.
///
/// Implementors must be `Send + Sync`, which allows sharing them as `Arc<dyn Random>`
/// (as [`DefaultDistributionSampler`] does).
pub trait Random: Send + Sync {
    /// Returns a random integer between `min` and `max`.
    ///
    /// [`DefaultRandom`] treats the range as inclusive on both ends and panics if `min > max`.
    fn uniform_int(&self, min: i32, max: i32) -> i32;

    /// Returns a random float between `min` and `max`.
    ///
    /// [`DefaultRandom`] samples the half-open range `[min, max)`, returns `min` when the bounds
    /// differ by less than `Float::EPSILON`, and otherwise panics unless `min < max`.
    fn uniform_real(&self, min: Float, max: Float) -> Float;

    /// Returns `true` or `false` like a coin flip ([`DefaultRandom`] uses probability 0.5).
    fn is_head_not_tails(&self) -> bool;

    /// Returns `true` with the given `probability`.
    ///
    /// [`DefaultRandom`] clamps `probability` to `[0, 1]` before sampling.
    fn is_hit(&self, probability: Float) -> bool;

    /// Returns an index into `weights`.
    ///
    /// In [`DefaultRandom`] each index is chosen with probability proportional to its weight,
    /// and an empty slice causes a panic.
    fn weighted(&self, weights: &[usize]) -> usize;

    /// Returns a [`RandomGen`] handle, which implements `rand::RngCore` and can therefore be
    /// passed wherever a `rand` generator is expected.
    fn get_rng(&self) -> RandomGen;
}
44
/// [`DistributionSampler`] backed by a shared [`Random`].
///
/// Every sample obtains a fresh [`RandomGen`] from the wrapped `Random` via `get_rng`.
/// Cloning only clones the inner `Arc`, so clones share the same `Random`.
#[derive(Clone)]
pub struct DefaultDistributionSampler(Arc<dyn Random>);
48
49impl DefaultDistributionSampler {
50 pub fn new(random: Arc<dyn Random>) -> Self {
52 Self(random)
53 }
54
55 pub fn sample_gamma(shape: Float, scale: Float, random: &dyn Random) -> Float {
57 Gamma::new(shape, scale)
58 .unwrap_or_else(|_| panic!("cannot create gamma dist: shape={shape}, scale={scale}"))
59 .sample(&mut random.get_rng())
60 }
61
62 pub fn sample_normal(mean: Float, std_dev: Float, random: &dyn Random) -> Float {
64 Normal::new(mean, std_dev)
65 .unwrap_or_else(|_| panic!("cannot create normal dist: mean={mean}, std_dev={std_dev}"))
66 .sample(&mut random.get_rng())
67 }
68}
69
70impl DistributionSampler for DefaultDistributionSampler {
71 fn gamma(&self, shape: Float, scale: Float) -> Float {
72 Self::sample_gamma(shape, scale, self.0.as_ref())
73 }
74
75 fn normal(&self, mean: Float, std_dev: Float) -> Float {
76 Self::sample_normal(mean, std_dev, self.0.as_ref())
77 }
78}
79
/// Default [`Random`] implementation backed by thread-local `SmallRng` generators.
///
/// `DefaultRandom::default()` leaves `use_repeatable` as `false`, which selects the generator
/// seeded from `thread_rng()`. [`DefaultRandom::new_repeatable`] selects the fixed-seed one.
#[derive(Default)]
pub struct DefaultRandom {
    // When `true`, `get_rng` hands out the fixed-seed (repeatable) thread-local generator.
    use_repeatable: bool,
}
85
86impl DefaultRandom {
87 pub fn new_repeatable() -> Self {
89 Self { use_repeatable: true }
90 }
91}
92
93impl Random for DefaultRandom {
94 fn uniform_int(&self, min: i32, max: i32) -> i32 {
95 if min == max {
96 return min;
97 }
98
99 assert!(min < max);
100 self.get_rng().gen_range(min..max + 1)
101 }
102
103 fn uniform_real(&self, min: Float, max: Float) -> Float {
104 if (min - max).abs() < Float::EPSILON {
105 return min;
106 }
107
108 assert!(min < max);
109 self.get_rng().gen_range(min..max)
110 }
111
112 fn is_head_not_tails(&self) -> bool {
113 self.get_rng().gen_bool(0.5)
114 }
115
116 fn is_hit(&self, probability: Float) -> bool {
117 #![allow(clippy::unnecessary_cast)]
118 self.get_rng().gen_bool(probability.clamp(0., 1.) as f64)
119 }
120
121 fn weighted(&self, weights: &[usize]) -> usize {
122 weights
123 .iter()
124 .zip(0_usize..)
125 .map(|(&weight, index)| (-self.uniform_real(0., 1.).ln() / weight as Float, index))
126 .min_by(|a, b| a.0.partial_cmp(&b.0).unwrap())
127 .unwrap()
128 .1
129 }
130
131 fn get_rng(&self) -> RandomGen {
132 RandomGen { use_repeatable: self.use_repeatable }
133 }
134}
135
136thread_local! {
137 static RANDOMIZED_RNG: RefCell<SmallRng> = RefCell::new(SmallRng::from_rng(thread_rng()).expect("cannot get RNG from thread rng"));
139
140 static REPEATABLE_RNG: RefCell<SmallRng> = RefCell::new(SmallRng::seed_from_u64(0));
142}
143
/// Handle implementing `RngCore` by forwarding each call to one of two thread-local
/// `SmallRng` instances.
///
/// The handle stores only a selector flag. All handles with the same flag on the same thread
/// therefore share and advance the same underlying generator.
#[derive(Clone, Debug)]
pub struct RandomGen {
    // `true` selects `REPEATABLE_RNG` (fixed seed), `false` selects `RANDOMIZED_RNG`.
    use_repeatable: bool,
}
149
150impl RandomGen {
151 pub fn new_repeatable() -> Self {
153 Self { use_repeatable: true }
154 }
155
156 pub fn new_randomized() -> Self {
158 Self { use_repeatable: false }
159 }
160}
161
162impl RngCore for RandomGen {
163 fn next_u32(&mut self) -> u32 {
164 if self.use_repeatable {
166 REPEATABLE_RNG.with(|t| t.borrow_mut().next_u32())
167 } else {
168 RANDOMIZED_RNG.with(|t| t.borrow_mut().next_u32())
169 }
170 }
171
172 fn next_u64(&mut self) -> u64 {
173 if self.use_repeatable {
174 REPEATABLE_RNG.with(|t| t.borrow_mut().next_u64())
175 } else {
176 RANDOMIZED_RNG.with(|t| t.borrow_mut().next_u64())
177 }
178 }
179
180 fn fill_bytes(&mut self, dest: &mut [u8]) {
181 if self.use_repeatable {
182 REPEATABLE_RNG.with(|t| t.borrow_mut().fill_bytes(dest))
183 } else {
184 RANDOMIZED_RNG.with(|t| t.borrow_mut().fill_bytes(dest))
185 }
186 }
187
188 fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> {
189 if self.use_repeatable {
190 REPEATABLE_RNG.with(|t| t.borrow_mut().try_fill_bytes(dest))
191 } else {
192 RANDOMIZED_RNG.with(|t| t.borrow_mut().try_fill_bytes(dest))
193 }
194 }
195}
196
// NOTE(review): `RandomGen` forwards to `SmallRng`, which the `rand` documentation describes as
// unsuitable for cryptographic use, so this marker impl overstates the guarantee. It is
// presumably here to satisfy a `CryptoRng` bound required elsewhere. Confirm before relying on
// it for anything security-sensitive, and consider removing it if no caller needs it.
impl CryptoRng for RandomGen {}
198
199pub fn random_argmax<I>(values: I, random: &dyn Random) -> Option<usize>
202where
203 I: Iterator<Item = Float>,
204{
205 let mut rng = random.get_rng();
206 let mut count = 0;
207 values
208 .enumerate()
209 .max_by(move |(_, r), (_, s)| match r.total_cmp(s) {
210 Ordering::Equal => {
211 count += 1;
212 if rng.gen_range(0..=count) == 0 {
213 Ordering::Less
214 } else {
215 Ordering::Greater
216 }
217 }
218 Ordering::Less => {
219 count = 0;
220 Ordering::Less
221 }
222 Ordering::Greater => Ordering::Greater,
223 })
224 .map(|(idx, _)| idx)
225}