1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
use rand::{ RngCore};
/// Hands out individual random bits ("coin flips") drawn from an underlying RNG.
///
/// Randomness is pulled from `rng` one 32-bit word at a time and then consumed
/// bit by bit, starting from the least significant bit, so that a single RNG
/// call can serve many coin flips.
///
/// The type parameter is deliberately unbounded here; the `RngCore` requirement
/// lives on the `impl` block that actually calls into the RNG.
pub(crate) struct CoinFlipper<R> {
    /// Source of randomness; a new 32-bit word is requested whenever the
    /// current chunk runs out of bits.
    pub rng: R,
    /// Unconsumed bits of the most recently drawn word, kept in the low-order
    /// positions (consumed bits are shifted out to the right).
    chunk: u32,
    /// Number of unconsumed random bits still available in `chunk`.
    chunk_remaining: u32,
}
impl<R: RngCore> CoinFlipper<R> {
    /// Creates a coin flipper that draws its randomness from `rng`.
    ///
    /// Nothing is requested from `rng` up front: the first 32-bit word is
    /// pulled the first time a random bit is actually needed.
    pub fn new(rng: R) -> Self {
        Self {
            rng,
            chunk: 0,
            chunk_remaining: 0,
        }
    }

    /// Returns true with a probability of 1 / denominator.
    ///
    /// A denominator of one always yields `true` without consuming randomness.
    ///
    /// # Panics
    ///
    /// Panics if `denominator` is zero, because "one in zero" is not a
    /// probability.
    pub fn gen_ratio_one_over(&mut self, denominator: usize) -> bool {
        // Reject zero explicitly: the bit-length computation below would
        // otherwise underflow (panic in debug builds, wrap in release builds).
        assert!(denominator != 0, "denominator must be non-zero");

        // n = floor(log2(denominator)), so 2^n <= denominator < 2^(n + 1).
        let n = usize::BITS - 1 - denominator.leading_zeros();

        // Optimization for a numerator of one: the first n flips of `gen_ratio`
        // would each have to come up heads, so check them all at once. Only if
        // they all succeed (probability 1 / 2^n) do we continue with the
        // general algorithm from the ratio 2^n / denominator.
        self.all_next(n) && self.gen_ratio(1 << n, denominator)
    }

    /// Returns true with a probability of numerator / denominator.
    ///
    /// If `numerator >= denominator` (which includes a zero denominator) the
    /// result is `true` and no randomness is consumed. Otherwise one bit is
    /// consumed per loop iteration and each iteration ends the loop with
    /// probability one half, so an expected two bits are used.
    pub fn gen_ratio(&mut self, mut numerator: usize, denominator: usize) -> bool {
        // We want to return true with probability n / d.
        // If n >= d we simply return true. Otherwise flip a coin and:
        //
        // * if 2n < d:  tails -> return false
        //               heads -> double n and go again
        //   Fair because (0.5 * 0) + (0.5 * 2n / d) = n / d.
        //
        // * if 2n >= d: heads -> return true
        //               tails -> replace n with 2n - d and go again
        //   Fair because (0.5 * 1) + (0.5 * (2n - d) / d) = n / d.
        while numerator < denominator {
            let heads = self.next();
            match numerator.checked_mul(2) {
                // Heads: keep the doubled value; if it reached d the loop
                // condition ends the loop and we return true.
                Some(doubled) if heads => numerator = doubled,
                // Tails with 2n < d.
                Some(doubled) if doubled < denominator => return false,
                // Tails with 2n >= d.
                Some(doubled) => numerator = doubled - denominator,
                // 2n overflowed usize, so 2n > usize::MAX >= d.
                None if heads => return true,
                // 2n - d computed with wrapping arithmetic; the true value is
                // below n (because n < d), so the wrapped result is exact.
                None => numerator = numerator.wrapping_sub(denominator).wrapping_add(numerator),
            }
        }
        true
    }

    /// Consumes one bit of randomness.
    ///
    /// Returns true ("heads") when the consumed bit is zero, which happens
    /// with a probability of one half.
    fn next(&mut self) -> bool {
        match self.chunk_remaining.checked_sub(1) {
            Some(remaining) => self.chunk_remaining = remaining,
            None => {
                // Out of bits: draw a fresh word and account for the bit that
                // is about to be consumed from it.
                self.chunk = self.rng.next_u32();
                self.chunk_remaining = u32::BITS - 1;
            }
        }
        let heads = (self.chunk & 1) == 0;
        self.chunk >>= 1;
        heads
    }

    /// If the next `n` bits of randomness are all zeroes, consumes them and
    /// returns true. Otherwise returns false and consumes the leading zeroes
    /// plus the one bit that ended the run.
    ///
    /// Returns true with a probability of one in two to the power of `n`.
    /// `n` may exceed 32; further words are drawn from the RNG as needed.
    fn all_next(&mut self, mut n: u32) -> bool {
        let mut zeros = self.chunk.trailing_zeros();
        // While the current chunk cannot satisfy the whole request on its own...
        while self.chunk_remaining < n {
            if zeros < self.chunk_remaining {
                // There is a one among the remaining bits: consume up to and
                // including it and report failure.
                self.discard_bits(zeros + 1);
                return false;
            }
            // Every remaining bit is zero: count them towards the request and
            // continue with a fresh word.
            n -= self.chunk_remaining;
            self.chunk = self.rng.next_u32();
            self.chunk_remaining = u32::BITS;
            zeros = self.chunk.trailing_zeros();
        }
        let all_zero = zeros >= n;
        self.discard_bits(if all_zero { n } else { zeros + 1 });
        all_zero
    }

    /// Drops the `count` lowest bits of the current chunk.
    ///
    /// `count` may be the full width of the chunk (32), which a plain `>>`
    /// would reject as a shift overflow; in that case the chunk becomes zero.
    fn discard_bits(&mut self, count: u32) {
        self.chunk = self.chunk.checked_shr(count).unwrap_or(0);
        self.chunk_remaining = self.chunk_remaining.saturating_sub(count);
    }
}
#[cfg(test)]
mod tests {
    use core::ops::Range;
    use crate::CoinFlipper;
    use rand::{Rng, RngCore, SeedableRng};

    /// Number of trials performed for each measured probability.
    const RUNS: usize = 10000;
    /// Number of calls made in the big-denominator test, and the largest
    /// denominator exercised in the sweep test.
    const LENGTH: usize = 10000;
    /// Smallest denominator exercised in the sweep test.
    const START: usize = 1;
    /// Seed for the deterministic RNG shared by all tests.
    const SEED: u64 = 123;

    /// A one-in-(2^33 + 1) event should practically never fire, and the calls
    /// should average out at roughly 16 per 32-bit word drawn from the RNG.
    #[test]
    pub fn test_one_over_for_big_numbers() {
        let mut flipper = CoinFlipper::new(get_rng());
        let denominator = (2_i64.pow(33) + 1) as usize;

        let count = (0..LENGTH)
            .filter(|_| flipper.gen_ratio_one_over(denominator))
            .count();

        let gens = flipper.rng.count;
        let average_gens = LENGTH as f64 / gens as f64;
        println!("Gens: {} (1 per {} gens)", gens, average_gens);
        println!("Count: {count}");

        assert_contains(15.5..16.5, &average_gens); // Should be about 16
        assert!(count < 2); // Should not get it twice
    }

    /// Exercises the overflow branch of `gen_ratio`: doubling the numerator
    /// does not fit in a `usize`, and the ratio is very close to one half.
    #[test]
    pub fn test_gen_ratio_for_big_numbers() {
        let mut flipper = CoinFlipper::new(get_rng());

        let count = (0..RUNS)
            .filter(|_| flipper.gen_ratio((usize::MAX / 2) + 1, usize::MAX))
            .count();

        let gens = flipper.rng.count;
        let average_gens = RUNS as f64 / gens as f64;
        println!("Gens: {} (1 per {} gens)", gens, average_gens);
        println!("Count: {count}");

        let mean = count as f64 / RUNS as f64;
        println!("Mean: {mean}");
        assert_contains(0.45..0.55, &mean); // Should be about 0.5
    }

    /// Sweeps every denominator in `START..=LENGTH`, measures how often
    /// `gen_ratio_one_over` succeeds, and scales each hit rate by its
    /// denominator so that every entry is expected to be about one.
    #[test]
    pub fn test_coin_flipper_gen_ratio() {
        let mut flipper = CoinFlipper::new(get_rng());

        let adjusted_counts: Vec<f64> = (START..=LENGTH)
            .map(|d| {
                let count = (0..RUNS)
                    .filter(|_| flipper.gen_ratio_one_over(d))
                    .count();
                (d * count) as f64 / RUNS as f64
            })
            .collect();

        // The average is only reported here, it is not asserted on.
        let gens = flipper.rng.count;
        let average_gens = (RUNS * LENGTH) as f64 / gens as f64;
        println!("Gens: {} (1 per {} gens)", gens, average_gens);

        let (mean, variance, standard_deviation) = get_stats(adjusted_counts);
        println!("mean: {mean}, variance: {variance}, standard deviation: {standard_deviation}");

        assert_contains(0.95..1.05, &mean); // Should be about 1 because we are adjusting
        assert_contains(0.0..10.0, &standard_deviation);
    }

    /// Builds the seeded, call-counting RNG used by every test.
    fn get_rng() -> CountingRng<rand::rngs::StdRng> {
        CountingRng {
            rng: rand::rngs::StdRng::seed_from_u64(SEED),
            count: 0,
        }
    }

    /// Returns `(mean, variance, standard deviation)` of the given samples,
    /// where the variance is the population variance (divided by the sample
    /// count, not by the count minus one).
    pub fn get_stats(vec: Vec<f64>) -> (f64, f64, f64) {
        let len = vec.len() as f64;
        // Each term is divided before summing, so the accumulated value stays
        // in the range of the samples themselves.
        let mean: f64 = vec.iter().map(|&x| x / len).sum();
        let variance: f64 = vec.iter().map(|&x| (x - mean).powi(2) / len).sum();
        (mean, variance, variance.sqrt())
    }

    /// Panics unless `n` lies within the half-open `range`.
    fn assert_contains(range: Range<f64>, n: &f64) {
        assert!(
            range.contains(n),
            "The range {:?} does not contain {n}",
            range
        );
    }

    /// RNG wrapper that counts how many times the inner RNG was asked for data.
    struct CountingRng<Inner: Rng> {
        pub rng: Inner,
        pub count: usize,
    }

    impl<Inner: Rng> CountingRng<Inner> {
        /// Records one request and hands back the inner RNG to serve it.
        fn tally(&mut self) -> &mut Inner {
            self.count += 1;
            &mut self.rng
        }
    }

    // Every request counts as one "gen", regardless of how much data it yields.
    impl<Inner: Rng> RngCore for CountingRng<Inner> {
        fn next_u32(&mut self) -> u32 {
            self.tally().next_u32()
        }

        fn next_u64(&mut self) -> u64 {
            self.tally().next_u64()
        }

        fn fill_bytes(&mut self, dest: &mut [u8]) {
            self.tally().fill_bytes(dest)
        }

        fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand::Error> {
            self.tally().try_fill_bytes(dest)
        }
    }
}