1use rand_core::RngCore;
17
18pub fn bucket<R: RngCore>(rng: &mut R, num_buckets: u32) -> u32 {
24 if num_buckets <= 1 {
25 return 0;
26 }
27
28 let r0 = rng.next_u64();
29
30 let x_masked_init =
31 ((r0 ^ (r0 >> 32)) as u32) & (0xFFFF_FFFFu32 >> (num_buckets - 1).leading_zeros());
32
33 let mut x_masked = x_masked_init;
34
35 loop {
36 if x_masked == 0 {
37 return 0;
38 }
39 let bucket_range_min: u32 = 1u32 << (31 - x_masked.leading_zeros());
40 let bit_count = x_masked.count_ones();
41 let bucket_idx =
42 bucket_range_min | (r0.wrapping_shr(bit_count * 32) as u32 & (bucket_range_min - 1));
43
44 if bucket_idx < num_buckets {
45 return bucket_idx;
46 }
47
48 let bucket_range_max = (bucket_range_min << 1) - 1;
49
50 loop {
51 let r1 = rng.next_u64();
52
53 let idx1 = (r1 as u32) & bucket_range_max;
54 if idx1 < bucket_range_min {
55 break;
56 }
57 if idx1 < num_buckets {
58 return idx1;
59 }
60
61 let idx2 = ((r1 >> 32) as u32) & bucket_range_max;
62 if idx2 < bucket_range_min {
63 break;
64 }
65 if idx2 < num_buckets {
66 return idx2;
67 }
68 }
69
70 x_masked ^= bucket_range_min;
71 }
72}
73
#[cfg(test)]
mod tests {
    use super::*;
    use rand_core::RngCore;

    /// Minimal generator whose state advances by a fixed constant on every call.
    ///
    /// Its output is trivially predictable; it only needs to be reproducible so the
    /// range and determinism tests run on fixed inputs.
    struct AddRng(u64);

    impl RngCore for AddRng {
        fn next_u32(&mut self) -> u32 {
            self.next_u64() as u32
        }
        fn next_u64(&mut self) -> u64 {
            self.0 = self.0.wrapping_add(19_561_023);
            self.0
        }
        fn fill_bytes(&mut self, dest: &mut [u8]) {
            rand_core::impls::fill_bytes_via_next(self, dest);
        }
        fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand_core::Error> {
            self.fill_bytes(dest);
            Ok(())
        }
    }

    fn add_rng(seed: u64) -> AddRng {
        AddRng(seed)
    }

    /// SplitMix64 generator, used where the tests need well-mixed output from a seed
    /// (e.g. the statistical remapping check).
    struct SplitMix64(u64);

    impl SplitMix64 {
        fn new(seed: u64) -> Self {
            Self(seed)
        }
    }

    impl RngCore for SplitMix64 {
        fn next_u32(&mut self) -> u32 {
            self.next_u64() as u32
        }
        fn next_u64(&mut self) -> u64 {
            self.0 = self.0.wrapping_add(0x9e3779b97f4a7c15);
            let mut z = self.0;
            z = (z ^ (z >> 30)).wrapping_mul(0xbf58476d1ce4e5b9);
            z = (z ^ (z >> 27)).wrapping_mul(0x94d049bb133111eb);
            z ^ (z >> 31)
        }
        fn fill_bytes(&mut self, dest: &mut [u8]) {
            rand_core::impls::fill_bytes_via_next(self, dest);
        }
        fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand_core::Error> {
            self.fill_bytes(dest);
            Ok(())
        }
    }

    #[test]
    fn single_bucket_always_zero() {
        for seed in [0u64, 1, u64::MAX, 0xdeadbeef] {
            assert_eq!(bucket(&mut add_rng(seed), 1), 0);
        }
    }

    #[test]
    fn zero_buckets_returns_zero() {
        assert_eq!(bucket(&mut add_rng(42), 0), 0);
    }

    #[test]
    fn result_in_range() {
        for n in 2u32..=64 {
            for seed in [0u64, 1, 12345, 0xdeadbeefcafe, u64::MAX] {
                let b = bucket(&mut add_rng(seed), n);
                assert!(b < n, "bucket(seed={seed}, n={n}) = {b} out of range");
            }
        }
    }

    #[test]
    fn deterministic() {
        for n in [2u32, 3, 10, 100] {
            for seed in [0u64, 1, 999, u64::MAX / 3] {
                assert_eq!(bucket(&mut add_rng(seed), n), bucket(&mut add_rng(seed), n));
            }
        }
    }

    /// Growing the bucket count by one must never shuffle keys between existing
    /// buckets: each key either stays put or moves to the newly added bucket `n`.
    #[test]
    fn growing_by_one_moves_keys_only_to_new_bucket() {
        let mut keys = SplitMix64::new(0x0bad_5eed);

        for n in 1u32..=100 {
            for _ in 0..100 {
                let key = keys.next_u64();
                let before = bucket(&mut SplitMix64::new(key), n);
                let after = bucket(&mut SplitMix64::new(key), n + 1);
                assert!(
                    after == before || after == n,
                    "key {key:#x}: bucket(n={n}) = {before} but bucket(n+1) = {after}"
                );
            }
        }
    }

    /// Growing the bucket count from `n` to `n + 1` should remap about `1 / (n + 1)` of
    /// keys. The number of moved keys is compared against the binomial expectation with
    /// a 4-sigma tolerance.
    #[test]
    fn consistent_remapping_fraction() {
        const KEY_COUNT: u64 = 10_000;

        // Keys come from a fixed-seed generator so any statistical failure can be
        // replayed exactly (and no OS entropy source is required).
        let mut keys = SplitMix64::new(0x0123_4567_89ab_cdef);

        for n in [2u32, 3, 5, 10, 20, 50, 100] {
            let mut moved: u64 = 0;

            for _ in 0..KEY_COUNT {
                let key = keys.next_u64();
                let b_n = bucket(&mut SplitMix64::new(key), n);
                let b_n1 = bucket(&mut SplitMix64::new(key), n + 1);
                if b_n != b_n1 {
                    moved += 1;
                }
            }

            let p = 1.0 / (n + 1) as f64;
            let expected = KEY_COUNT as f64 * p;
            let std_dev = (KEY_COUNT as f64 * p * (1.0 - p)).sqrt();
            let tolerance = 4.0 * std_dev;

            assert!(
                (moved as f64 - expected).abs() <= tolerance,
                "n={n}: moved {moved}/{KEY_COUNT} keys, expected {expected:.1} ± {tolerance:.1}"
            );
        }
    }
}