1use super::{Buffered, Ranged, TryDistribution, TryRanged};
2use anyhow::{anyhow, Result};
3use bitvec::{order::Lsb0, vec::BitVec};
4use core::ops::{Range, RangeInclusive};
5use rand::{prelude::Distribution, Rng};
6use std::{cell::RefCell, io::Read, mem::size_of};
7
/// Marker trait with no methods.
///
/// In this file it is implemented (via `impl_distribution_integral!`) for every
/// integer type that `StandardBuffered` can sample directly.
pub trait StandardBufferedSample {}
/// Marker trait with no methods.
///
/// In this file it is implemented (via `impl_ranged_integral!`) for every type
/// that `StandardBuffered` can sample from a range.
pub trait StandardBufferedSampleRange {}
10
11#[derive(Clone, Debug)]
12pub struct StandardBuffered {
15 buf: RefCell<BitVec<u8, Lsb0>>,
16}
17
18impl StandardBuffered {
19 pub fn new() -> Self {
23 Self {
24 buf: RefCell::new(BitVec::new()),
25 }
26 }
27}
28
29impl Buffered for StandardBuffered {
30 fn try_ensure<R: Rng + ?Sized>(&self, bits: usize, rng: &mut R) -> Result<()> {
35 if self.buf.borrow().len() < bits {
36 let bits_needed = bits - self.buf.borrow().len();
37 let bytes_needed = ((bits_needed + (u8::BITS as usize - 1))
38 & (!(u8::BITS as usize - 1)))
39 / u8::BITS as usize;
40 let mut bits = vec![0u8; bytes_needed];
41 rng.try_fill_bytes(&mut bits)?;
42 self.buf.borrow_mut().extend(bits);
43 }
44 Ok(())
45 }
46
47 fn ensure<R: Rng + ?Sized>(&self, bits: usize, rng: &mut R) {
48 self.try_ensure::<R>(bits, rng)
49 .expect("Generator::ensure failed");
50 }
51}
52
53impl Default for StandardBuffered {
54 fn default() -> Self {
55 Self::new()
56 }
57}
58
59impl Distribution<bool> for StandardBuffered {
60 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> bool {
61 self.ensure::<R>(1, rng);
63 self.buf.borrow_mut().remove(0)
64 }
65}
66
67impl TryDistribution<bool> for StandardBuffered {
68 fn try_sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Result<bool> {
69 self.try_ensure::<R>(1, rng)?;
71 Ok(self.buf.borrow_mut().remove(0))
72 }
73}
74
75impl Distribution<char> for StandardBuffered {
76 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> char {
77 self.ensure::<R>(u8::BITS as usize, rng);
78 let mut bytes = vec![0u8; 1];
79 self.buf
80 .borrow_mut()
81 .read_exact(&mut bytes)
82 .expect("Failed to read into buffer");
83 bytes[0] as char
84 }
85}
86
87impl TryDistribution<char> for StandardBuffered {
88 fn try_sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Result<char> {
89 self.try_ensure::<R>(u8::BITS as usize, rng)?;
90 let mut bytes = vec![0u8; 1];
91 self.buf.borrow_mut().read_exact(&mut bytes)?;
92 Ok(bytes[0] as char)
93 }
94}
95
// Generates, for one integer type `$T`:
//   * the `StandardBufferedSample` marker impl,
//   * `Distribution<$T>`  (panicking), and
//   * `TryDistribution<$T>` (fallible)
// Both samplers consume `size_of::<$T>()` bytes from the buffer and decode
// them as a little-endian value.
macro_rules! impl_distribution_integral {
    ($T:ty) => {
        impl StandardBufferedSample for $T {}

        impl Distribution<$T> for StandardBuffered {
            /// Consumes one value's worth of bytes and decodes them as little-endian.
            ///
            /// Panics if the buffer cannot be refilled or read.
            fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> $T {
                const WIDTH: usize = size_of::<$T>();

                self.ensure::<R>(<$T>::BITS as usize, rng);
                let mut raw = vec![0u8; WIDTH];
                self.buf
                    .borrow_mut()
                    .read_exact(&mut raw)
                    .expect("Failed to read into buffer");
                let le_bytes = <[u8; WIDTH]>::try_from(raw.as_slice()).expect("Invalid bytes");
                <$T>::from_le_bytes(le_bytes)
            }
        }

        impl TryDistribution<$T> for StandardBuffered {
            /// Consumes one value's worth of bytes and decodes them as little-endian.
            ///
            /// Returns an error if the buffer cannot be refilled or read.
            fn try_sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Result<$T> {
                const WIDTH: usize = size_of::<$T>();

                self.try_ensure::<R>(<$T>::BITS as usize, rng)?;
                let mut raw = vec![0u8; WIDTH];
                self.buf.borrow_mut().read_exact(&mut raw)?;
                let le_bytes = <[u8; WIDTH]>::try_from(raw.as_slice())
                    .map_err(|e| anyhow!("Invalid bytes: {}", e))?;
                Ok(<$T>::from_le_bytes(le_bytes))
            }
        }
    };
}
126
127impl_distribution_integral! { u8 }
128impl_distribution_integral! { u16 }
129impl_distribution_integral! { u32 }
130impl_distribution_integral! { u64 }
131impl_distribution_integral! { usize }
132impl_distribution_integral! { i8 }
133impl_distribution_integral! { i16 }
134impl_distribution_integral! { i32 }
135impl_distribution_integral! { i64 }
136impl_distribution_integral! { isize }
137
// Generates range sampling (`Ranged` / `TryRanged`) for one type.
//
// Parameters:
//   $T  - integer type the arithmetic is carried out in,
//   $UT - unsigned type of the same width, used for the range size and the
//         rejection-sampled offset,
//   $C  - the type exposed to callers (defaults to `$T`; `char` uses `u8`
//         arithmetic, so only the low byte of each bound is considered).
//
// Algorithm: compute `size = end - start + 1` with wrapping arithmetic, draw
// `ilog2(size) + 1` bits from the buffer, and retry until the drawn value is
// below `size`. A size that wraps to 0 means the range covers the whole type,
// so a plain sample is returned instead.
//
// NOTE: an exclusive range with `start == end` also wraps to size 0 and is
// therefore sampled from the full domain rather than rejected.
macro_rules! impl_ranged_integral {
    ($T:ty, $UT:ty) => {
        impl_ranged_integral! { $T, $UT, $T }
    };
    ($T:ty, $UT:ty, $C:ty) => {
        impl StandardBufferedSampleRange for $C {}

        impl Ranged<$C> for StandardBuffered {
            /// Samples from the half-open range `start..end`.
            ///
            /// Panics if the buffer cannot be refilled from `rng`.
            fn sample_range<R: Rng + ?Sized>(&self, rng: &mut R, range: Range<$C>) -> $C {
                self.sample_range_inclusive(rng, range.start..=(range.end as $T - 1) as $C)
            }

            /// Samples from the closed range `start..=end`.
            ///
            /// Panics if the buffer cannot be refilled from `rng`.
            fn sample_range_inclusive<R: Rng + ?Sized>(
                &self,
                rng: &mut R,
                range: RangeInclusive<$C>,
            ) -> $C {
                let end = *range.end() as $T;
                let start = *range.start() as $T;
                let range_size = end.wrapping_sub(start).wrapping_add(1) as $UT;

                // Size wrapped to zero: the range spans every value of the type.
                if range_size == 0 {
                    return self.sample(rng);
                }

                let bits_needed = range_size.ilog2() as usize + 1;
                let offset = loop {
                    self.ensure::<R>(bits_needed, rng);

                    let mut candidate: $UT = 0;
                    {
                        let mut pending = self.buf.borrow_mut();
                        for i in 0..bits_needed {
                            candidate |= (pending.remove(0) as $UT) << i;
                        }
                    }

                    // Rejection sampling keeps the accepted values uniform.
                    if candidate < range_size {
                        break candidate;
                    }
                };

                // The offset may exceed the signed maximum of `$T`; wrapping
                // addition yields the correct two's-complement result where a
                // plain `+` would overflow (and panic in debug builds).
                (offset as $T).wrapping_add(start) as $C
            }
        }

        impl TryRanged<$C> for StandardBuffered {
            /// Samples from the half-open range `start..end`.
            ///
            /// Returns an error if the buffer cannot be refilled from `rng`.
            fn try_sample_range<R: Rng + ?Sized>(
                &self,
                rng: &mut R,
                range: Range<$C>,
            ) -> Result<$C> {
                self.try_sample_range_inclusive(rng, range.start..=(range.end as $T - 1) as $C)
            }

            /// Samples from the closed range `start..=end`.
            ///
            /// Returns an error if the buffer cannot be refilled from `rng`.
            fn try_sample_range_inclusive<R: Rng + ?Sized>(
                &self,
                rng: &mut R,
                range: RangeInclusive<$C>,
            ) -> Result<$C> {
                let end = *range.end() as $T;
                let start = *range.start() as $T;
                let range_size = end.wrapping_sub(start).wrapping_add(1) as $UT;

                // Size wrapped to zero: the range spans every value of the type.
                if range_size == 0 {
                    return self.try_sample(rng);
                }

                let bits_needed = range_size.ilog2() as usize + 1;
                let offset = loop {
                    self.try_ensure::<R>(bits_needed, rng)?;

                    let mut candidate: $UT = 0;
                    {
                        let mut pending = self.buf.borrow_mut();
                        for i in 0..bits_needed {
                            candidate |= (pending.remove(0) as $UT) << i;
                        }
                    }

                    // Rejection sampling keeps the accepted values uniform.
                    if candidate < range_size {
                        break candidate;
                    }
                };

                // See `sample_range_inclusive`: wrapping addition avoids a
                // spurious signed overflow for wide signed ranges.
                Ok((offset as $T).wrapping_add(start) as $C)
            }
        }
    };
}
240
241impl_ranged_integral! { u8, u8, char }
242impl_ranged_integral! { u8, u8 }
243impl_ranged_integral! { u16, u16 }
244impl_ranged_integral! { u32, u32 }
245impl_ranged_integral! { u64, u64 }
246impl_ranged_integral! { usize, usize }
247impl_ranged_integral! { i8, u8 }
248impl_ranged_integral! { i16, u16 }
249impl_ranged_integral! { i32, u32 }
250impl_ranged_integral! { i64, u64 }
251impl_ranged_integral! { isize, usize }
252
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{distributions::Ranged, rngs::StandardSeedableRng};
    use concat_idents::concat_idents;
    use rand::{thread_rng, SeedableRng};
    use std::iter::repeat;

    // Generates a test that seeds the RNG with all-ones bytes and checks that
    // every sampled `$T` decodes to the all-ones bit pattern.
    macro_rules! test_sample_impl {
        ($T:ty, $TN:ident) => {
            #[test]
            fn $TN() {
                const SAMPLES: usize = 8;
                const BYTES_NEEDED: usize = size_of::<$T>() * SAMPLES;

                let mut rng = StandardSeedableRng::from_seed(vec![0xff; BYTES_NEEDED]);
                let dist = StandardBuffered::new();
                let expected = <$T>::from_le_bytes([0xff; size_of::<$T>()]);
                for i in 0..SAMPLES {
                    let s: $T = rng.sample(&dist);
                    assert_eq!(s, expected, "Expected all-ones value on iteration {}", i);
                }
            }
        };
    }

    // A single 0xff seed byte must yield eight `true` bits.
    #[test]
    fn test_bool() {
        let mut rng = StandardSeedableRng::from_seed(vec![0xff]);
        let dist = StandardBuffered::new();
        for i in 0..8 {
            let s: bool = rng.sample(&dist);
            assert!(s, "Expected true on iteration {}", i);
        }
    }

    // Each 0x41 seed byte must be returned as the character 'A'.
    #[test]
    fn test_char() {
        let mut rng = StandardSeedableRng::from_seed(vec![0x41; 8]);
        let dist = StandardBuffered::new();
        for i in 0..8 {
            let s: char = rng.sample(&dist);
            assert_eq!(s, 'A', "Expected character on iteration {}", i);
        }
    }

    test_sample_impl!(u8, test_sample_u8);
    test_sample_impl!(u16, test_sample_u16);
    test_sample_impl!(u32, test_sample_u32);
    test_sample_impl!(u64, test_sample_u64);
    test_sample_impl!(usize, test_sample_usize);
    test_sample_impl!(i8, test_sample_i8);
    test_sample_impl!(i16, test_sample_i16);
    test_sample_impl!(i32, test_sample_i32);
    test_sample_impl!(i64, test_sample_i64);
    test_sample_impl!(isize, test_sample_isize);

    // Samples drawn from 'A'..'Z' must stay inside the half-open range.
    #[test]
    fn test_sample_range_char() {
        const RANGE_MAX: char = 'Z';
        const RANGE_MIN: char = 'A';
        const SAMPLES: usize = 64;
        let bytes_needed: usize =
            ((RANGE_MAX as u8 - RANGE_MIN as u8).ilog2() as usize + 1) * SAMPLES;
        let mut rng = StandardSeedableRng::from_seed(
            (0..255)
                .take(bytes_needed / 2)
                .chain((0..255).rev().take(bytes_needed / 2))
                .collect(),
        );
        let dist = StandardBuffered::new();
        for _ in 0..SAMPLES * 2 {
            let s: char = dist.sample_range(&mut rng, RANGE_MIN..RANGE_MAX);
            assert!(s >= RANGE_MIN, "Unexpected value");
            assert!(s < RANGE_MAX, "Unexpected value");
        }
    }

    // Generates three bounds tests per integer type: a single sample from
    // MIN..MAX (`<name>_one`), many samples from 0..MAX (`<name>`), and many
    // samples from 0..=MAX (`<name>_inclusive`).
    macro_rules! test_sample_rangeimpl {
        ($T:ty, $TN:ident) => {
            concat_idents!(test_name = $TN, _one, {
                #[test]
                fn test_name() {
                    const RANGE_MAX: $T = 48;
                    const RANGE_MIN: $T = 8;
                    const SAMPLES: usize = 1;
                    let mut rng = StandardSeedableRng::from_seed(
                        (0..255)
                            .take(size_of::<$T>())
                            .chain((0..255).rev().take(size_of::<$T>()))
                            .collect(),
                    );
                    let dist = StandardBuffered::new();
                    for _ in 0..SAMPLES {
                        let s: $T = dist.sample_range(&mut rng, RANGE_MIN..RANGE_MAX);
                        assert!(s < RANGE_MAX, "Unexpected value");
                        assert!(s >= RANGE_MIN, "Unexpected value");
                    }
                }
            });

            #[test]
            fn $TN() {
                const RANGE_MAX: $T = 48;
                const RANGE_MIN: $T = 8;
                const SAMPLES: usize = 64;
                let bytes_needed: usize =
                    ((RANGE_MAX - RANGE_MIN).ilog2() as usize + 1) * SAMPLES * 2;
                let mut rng = StandardSeedableRng::from_seed(
                    repeat(0..255).flatten().take(bytes_needed).collect(),
                );
                let dist = StandardBuffered::new();
                for _ in 0..SAMPLES {
                    let s: $T = dist.sample_range(&mut rng, 0..RANGE_MAX);
                    assert!(s < RANGE_MAX, "Unexpected value");
                }
            }

            concat_idents!(test_name = $TN, _inclusive, {
                #[test]
                fn test_name() {
                    const RANGE_MAX: $T = 48;
                    const RANGE_MIN: $T = 8;
                    const SAMPLES: usize = 64;
                    let bytes_needed: usize =
                        ((RANGE_MAX - RANGE_MIN).ilog2() as usize + 1) * SAMPLES * 2;
                    let mut rng = StandardSeedableRng::from_seed(
                        repeat(0..255).flatten().take(bytes_needed).collect(),
                    );
                    let dist = StandardBuffered::new();
                    for _ in 0..SAMPLES {
                        let s: $T = dist.sample_range_inclusive(&mut rng, 0..=RANGE_MAX);
                        assert!(s <= RANGE_MAX, "Unexpected value");
                    }
                }
            });
        };
    }

    test_sample_rangeimpl!(u8, test_sample_range_u8);
    test_sample_rangeimpl!(u16, test_sample_range_u16);
    test_sample_rangeimpl!(u32, test_sample_range_u32);
    test_sample_rangeimpl!(u64, test_sample_range_u64);
    test_sample_rangeimpl!(usize, test_sample_range_usize);
    test_sample_rangeimpl!(i8, test_sample_range_i8);
    test_sample_rangeimpl!(i16, test_sample_range_i16);
    test_sample_rangeimpl!(i32, test_sample_range_i32);
    test_sample_rangeimpl!(i64, test_sample_range_i64);
    test_sample_rangeimpl!(isize, test_sample_range_isize);

    // Generates a statistical test: 100_000 samples from 0..106 must pass a
    // chi-squared uniformity check in at least one of ten attempts, each
    // seeded from the thread RNG.
    macro_rules! test_sample_rangeimpl_uniform {
        ($T:ty, $TN:ident) => {
            #[test]
            fn $TN() {
                // Chi-squared check: with `r` buckets the statistic is accepted
                // when it lies within 2 * sqrt(r) of r.
                fn is_random(data: &[$T], min: $T, max: $T) -> bool {
                    let r: f32 = (max - min) as f32;
                    let mut counts = vec![0; r as usize];
                    for &d in data {
                        counts[(d - min) as usize] += 1;
                    }
                    // Expected number of hits per bucket under uniformity.
                    let n_r = data.len() as f32 / r;
                    let chi_sq_n: f32 = counts.iter().map(|&c| (c as f32 - n_r).powi(2)).sum();
                    let chi_sq = chi_sq_n / n_r;
                    (chi_sq - r).abs() <= 2.0 * r.sqrt()
                }

                let mut trng = thread_rng();
                const RANGE_MAX: $T = 106;
                const RANGE_MIN: $T = 0;
                const SAMPLES: usize = 100_000;
                for _ in 0..10 {
                    let seed = (0..SAMPLES * 2).map(|_| trng.gen()).collect::<Vec<_>>();
                    let mut rng = StandardSeedableRng::from_seed(seed);
                    let dist = StandardBuffered::new();
                    let sampled = (0..SAMPLES)
                        .map(|_| dist.sample_range(&mut rng, RANGE_MIN..RANGE_MAX))
                        .collect::<Vec<$T>>();
                    if is_random(&sampled, RANGE_MIN, RANGE_MAX) {
                        // One passing attempt is enough.
                        return;
                    }
                }
                panic!("Sampled values were not random in 10 tries");
            }
        };
    }

    test_sample_rangeimpl_uniform!(u8, test_sample_range_uniform_u8);
    test_sample_rangeimpl_uniform!(u16, test_sample_range_uniform_u16);
    test_sample_rangeimpl_uniform!(u32, test_sample_range_uniform_u32);
    test_sample_rangeimpl_uniform!(i8, test_sample_range_uniform_i8);
    test_sample_rangeimpl_uniform!(i16, test_sample_range_uniform_i16);
    test_sample_rangeimpl_uniform!(i32, test_sample_range_uniform_i32);
}