1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
use core::ops::Not;
use rand::RngCore;
pub(crate) struct CoinFlipper<R: RngCore> {
pub rng: R,
chunk: u32,
chunk_remaining: u32,
}
impl<R: RngCore> CoinFlipper<R> {
pub fn new(rng: R) -> Self {
Self {
rng,
chunk: 0,
chunk_remaining: 0,
}
}
#[inline]
pub fn gen_index(&mut self, inclusive_upper_bound: u32) -> u32 {
if inclusive_upper_bound == 0 {
return 0;
}
let lz = inclusive_upper_bound.leading_zeros();
let digits = u32::BITS - lz;
loop {
if self.chunk_remaining < digits {
self.chunk = self.rng.next_u32();
self.chunk_remaining = u32::BITS;
}
let r = self.chunk >> lz;
self.chunk <<= digits;
self.chunk_remaining -= digits;
if r <= inclusive_upper_bound {
return r;
}
}
}
#[inline]
/// Returns true with a probability of 1 / d
/// Uses an expected two bits of randomness
/// Panics if d == 0
pub fn gen_ratio_one_over(&mut self, d: usize) -> bool {
debug_assert_ne!(d, 0);
// This uses the same logic as `gen_ratio` but is optimized for the case that
// the starting numerator is one (which it always is for `Sequence::Choose()`)
// In this case (but not `gen_ratio`), this way of calculating c is always accurate
let c = (usize::BITS - 1 - d.leading_zeros()).min(32);
if self.flip_c_heads(c) {
let numerator = 1 << c;
return self.gen_ratio(numerator, d);
} else {
return false;
}
}
#[inline]
/// Returns true with a probability of n / d
/// Uses an expected two bits of randomness
fn gen_ratio(&mut self, mut n: usize, d: usize) -> bool {
// Explanation:
// We are trying to return true with a probability of n / d
// If n >= d, we can just return true
// Otherwise there are two possibilities 2n < d and 2n >= d
// In either case we flip a coin.
// If 2n < d
// If it comes up tails, return false
// If it comes up heads, double n and start again
// This is fair because (0.5 * 0) + (0.5 * 2n / d) = n / d and 2n is less than d
// (if 2n was greater than d we would effectively round it down to 1
// by returning true)
// If 2n >= d
// If it comes up tails, set n to 2n - d and start again
// If it comes up heads, return true
// This is fair because (0.5 * 1) + (0.5 * (2n - d) / d) = n / d
// Note that if 2n = d and the coin comes up tails, n will be set to 0
// before restarting which is equivalent to returning false.
// As a performance optimization we can flip multiple coins at once
// This is efficient because we can use the `lzcnt` intrinsic
// We can check up to 32 flips at once but we only receive one bit of information
// - all heads or at least one tail.
// Let c be the number of coins to flip. 1 <= c <= 32
// If 2n < d, n * 2^c < d
// If the result is all heads, then set n to n * 2^c
// If there was at least one tail, return false
// If 2n >= d, the order of results matters so we flip one coin at a time so c = 1
// Ideally, c will be as high as possible within these constraints
while n < d {
// Find a good value for c by counting leading zeros
// This will either give the highest possible c, or 1 less than that
let c = n
.leading_zeros()
.saturating_sub(d.leading_zeros() + 1)
.clamp(1, 32);
if self.flip_c_heads(c) {
// All heads
// Set n to n * 2^c
// If 2n >= d, the while loop will exit and we will return `true`
// If n * 2^c > `usize::MAX` we always return `true` anyway
n = n.saturating_mul(2_usize.pow(c));
} else {
//At least one tail
if c == 1 {
// Calculate 2n - d.
// We need to use wrapping as 2n might be greater than `usize::MAX`
let next_n = n.wrapping_add(n).wrapping_sub(d);
if next_n == 0 || next_n > n {
// This will happen if 2n < d
return false;
}
n = next_n;
} else {
// c > 1 so 2n < d so we can return false
return false;
}
}
}
true
}
pub fn try_skip(&mut self, min_denominator: u32) -> u32 {
if self.chunk == 0 {
//chunk is all zeros - can't skip anything
return 0;
}
let heads_needed = 31_u32.saturating_sub(min_denominator.leading_zeros());
if heads_needed == 0 {
return 0;
}
let mut bits = self.chunk;
//let just_padding_zeros = (u32::MAX << 32 - self.chunk_remaining);
for i in 1..heads_needed {
// if bits == just_padding_zeros {
// break; //Non-padding bits are all ones - there is no long enough string of zeros
// }
bits |= self.chunk.wrapping_shl(i);
}
let amount_to_consume = bits.leading_ones();
let total_ones = self.chunk.count_ones();
self.chunk = self.chunk.wrapping_shl(amount_to_consume);
self.chunk_remaining = self.chunk_remaining.saturating_sub(amount_to_consume);
let skip = total_ones.saturating_sub(self.chunk.count_ones());
return skip;
}
/// If the next `c` bits of randomness all represent heads, consume them, return true
/// Otherwise return false and consume the number of heads plus one.
/// Generates new bits of randomness when necessary (in 32 bit chunks)
/// Has a 1 in 2 to the `c` chance of returning true
/// `c` must be less than or equal to 32
fn flip_c_heads(&mut self, mut c: u32) -> bool {
debug_assert!(c <= 32);
// Note that zeros on the left of the chunk represent heads.
// It needs to be this way round because zeros are filled in when left shifting
loop {
let zeros = self.chunk.leading_zeros();
if zeros < c {
// The happy path - we found a 1 and can return false
// Note that because a 1 bit was detected,
// We cannot have run out of random bits so we don't need to check
// First consume all of the bits read
// Using shl seems to give worse performance for size-hinted iterators
self.chunk = self.chunk.wrapping_shl(zeros + 1);
self.chunk_remaining = self.chunk_remaining.saturating_sub(zeros + 1);
return false;
} else {
// The number of zeros is larger than `c`
// There are two possibilities
if let Some(new_remaining) = self.chunk_remaining.checked_sub(c) {
// Those zeroes were all part of our random chunk,
// throw away `c` bits of randomness and return true
self.chunk_remaining = new_remaining;
self.chunk <<= c;
return true;
} else {
// Some of those zeroes were part of the random chunk
// and some were part of the space behind it
// We need to take into account only the zeroes that were random
c -= self.chunk_remaining;
// Generate a new chunk
self.chunk = self.rng.next_u32();
self.chunk_remaining = 32;
// Go back to start of loop
}
}
}
}
}