fuzzerang/distributions/
standard_buffered.rs

1use super::{Buffered, Ranged, TryDistribution, TryRanged};
2use anyhow::{anyhow, Result};
3use bitvec::{order::Lsb0, vec::BitVec};
4use core::ops::{Range, RangeInclusive};
5use rand::{prelude::Distribution, Rng};
6use std::{cell::RefCell, io::Read, mem::size_of};
7
/// Marker trait with no methods. In this file it is implemented for every integer
/// type passed to `impl_distribution_integral!`, i.e. the types [`StandardBuffered`]
/// can sample over their full range.
pub trait StandardBufferedSample {}
/// Marker trait with no methods. In this file it is implemented for every type
/// passed to `impl_ranged_integral!`, i.e. the types [`StandardBuffered`] can sample
/// from a caller-supplied range.
pub trait StandardBufferedSampleRange {}
10
11#[derive(Clone, Debug)]
12/// Similar to the [`rand::distributions::Standard`] distribution in that it generates
13/// values in the "expected" way for each type
14pub struct StandardBuffered {
15    buf: RefCell<BitVec<u8, Lsb0>>,
16}
17
18impl StandardBuffered {
19    /// Create a new [`StandardBuffered`] rng that buffers data from the sampling RNG and
20    /// uses it to generate values using the smallest possible number of bits. For example,
21    /// a [`bool`] is 1 bit in size, so only 1 bit will be used to generate it
22    pub fn new() -> Self {
23        Self {
24            buf: RefCell::new(BitVec::new()),
25        }
26    }
27}
28
29impl Buffered for StandardBuffered {
30    /// Ensures enough bits are in the buffer to generate an instance of a type.
31    /// Returns Ok if enough bits were generated.
32    /// Returns Err if not enough bits were generated, and fills the
33    /// buffer zero.
34    fn try_ensure<R: Rng + ?Sized>(&self, bits: usize, rng: &mut R) -> Result<()> {
35        if self.buf.borrow().len() < bits {
36            let bits_needed = bits - self.buf.borrow().len();
37            let bytes_needed = ((bits_needed + (u8::BITS as usize - 1))
38                & (!(u8::BITS as usize - 1)))
39                / u8::BITS as usize;
40            let mut bits = vec![0u8; bytes_needed];
41            rng.try_fill_bytes(&mut bits)?;
42            self.buf.borrow_mut().extend(bits);
43        }
44        Ok(())
45    }
46
47    fn ensure<R: Rng + ?Sized>(&self, bits: usize, rng: &mut R) {
48        self.try_ensure::<R>(bits, rng)
49            .expect("Generator::ensure failed");
50    }
51}
52
53impl Default for StandardBuffered {
54    fn default() -> Self {
55        Self::new()
56    }
57}
58
59impl Distribution<bool> for StandardBuffered {
60    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> bool {
61        // Special case: bool only requires 1 bit, even though a bool is a full byte in size
62        self.ensure::<R>(1, rng);
63        self.buf.borrow_mut().remove(0)
64    }
65}
66
67impl TryDistribution<bool> for StandardBuffered {
68    fn try_sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Result<bool> {
69        // Special case: bool only requires 1 bit, even though a bool is a full byte in size
70        self.try_ensure::<R>(1, rng)?;
71        Ok(self.buf.borrow_mut().remove(0))
72    }
73}
74
75impl Distribution<char> for StandardBuffered {
76    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> char {
77        self.ensure::<R>(u8::BITS as usize, rng);
78        let mut bytes = vec![0u8; 1];
79        self.buf
80            .borrow_mut()
81            .read_exact(&mut bytes)
82            .expect("Failed to read into buffer");
83        bytes[0] as char
84    }
85}
86
87impl TryDistribution<char> for StandardBuffered {
88    fn try_sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Result<char> {
89        self.try_ensure::<R>(u8::BITS as usize, rng)?;
90        let mut bytes = vec![0u8; 1];
91        self.buf.borrow_mut().read_exact(&mut bytes)?;
92        Ok(bytes[0] as char)
93    }
94}
95
96macro_rules! impl_distribution_integral {
97    ($T:ty) => {
98        impl StandardBufferedSample for $T {}
99
100        impl Distribution<$T> for StandardBuffered {
101            fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> $T {
102                self.ensure::<R>(<$T>::BITS as usize, rng);
103                let mut bytes = vec![0u8; size_of::<$T>()];
104                self.buf
105                    .borrow_mut()
106                    .read_exact(&mut bytes)
107                    .expect("Failed to read into buffer");
108                <$T>::from_le_bytes(bytes.as_slice().try_into().expect("Invalid bytes"))
109            }
110        }
111
112        impl TryDistribution<$T> for StandardBuffered {
113            fn try_sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Result<$T> {
114                self.try_ensure::<R>(<$T>::BITS as usize, rng)?;
115                let mut bytes = vec![0u8; size_of::<$T>()];
116                self.buf.borrow_mut().read_exact(&mut bytes)?;
117                bytes
118                    .as_slice()
119                    .try_into()
120                    .map(|a| <$T>::from_le_bytes(a))
121                    .map_err(|e| anyhow!("Invalid bytes: {}", e))
122            }
123        }
124    };
125}
126
127impl_distribution_integral! { u8 }
128impl_distribution_integral! { u16 }
129impl_distribution_integral! { u32 }
130impl_distribution_integral! { u64 }
131impl_distribution_integral! { usize }
132impl_distribution_integral! { i8 }
133impl_distribution_integral! { i16 }
134impl_distribution_integral! { i32 }
135impl_distribution_integral! { i64 }
136impl_distribution_integral! { isize }
137
138macro_rules! impl_ranged_integral {
139    ($T:ty, $UT:ty) => {
140        impl_ranged_integral! { $T, $UT, $T }
141    };
142    ($T:ty, $UT:ty, $C:ty) => {
143        impl StandardBufferedSampleRange for $C {}
144
145        impl Ranged<$C> for StandardBuffered {
146            fn sample_range<R: Rng + ?Sized>(&self, rng: &mut R, range: Range<$C>) -> $C {
147                self.sample_range_inclusive(rng, range.start..=(range.end as $T - 1) as $C)
148            }
149
150            fn sample_range_inclusive<R: Rng + ?Sized>(
151                &self,
152                rng: &mut R,
153                range: RangeInclusive<$C>,
154            ) -> $C {
155                let end = *range.end() as $T;
156                let start = *range.start() as $T;
157                // Get the size of the range
158                let range_size = end.wrapping_sub(start).wrapping_add(1) as $UT;
159
160                if range_size == 0 {
161                    self.sample(rng)
162                } else {
163                    // Get the number of bits needed to represent the maximum value in the range
164                    let bits_needed: u32 = range_size.ilog2() as u32 + 1;
165                    // Ensure we have enough bits in the buffer to generate a value
166                    self.ensure::<R>(bits_needed as usize, rng);
167                    // Find the maximum usable value
168                    let mut v = loop {
169                        // We can use T because we know the range is small enough to fit in T
170                        let mut v: $UT = 0;
171                        // Read bits from the buffer and OR bits into v
172                        for i in 0..bits_needed {
173                            let bit = self.buf.borrow_mut().remove(0);
174                            v |= (bit as $UT) << i;
175                        }
176
177                        if v < range_size {
178                            break v;
179                        }
180
181                        self.ensure::<R>(bits_needed as usize, rng);
182                    } as $T;
183
184                    v += start;
185                    v as $C
186                }
187            }
188        }
189
190        impl TryRanged<$C> for StandardBuffered {
191            fn try_sample_range<R: Rng + ?Sized>(
192                &self,
193                rng: &mut R,
194                range: Range<$C>,
195            ) -> Result<$C> {
196                self.try_sample_range_inclusive(rng, range.start..=(range.end as $T - 1) as $C)
197            }
198
199            fn try_sample_range_inclusive<R: Rng + ?Sized>(
200                &self,
201                rng: &mut R,
202                range: RangeInclusive<$C>,
203            ) -> Result<$C> {
204                let end = *range.end() as $T;
205                let start = *range.start() as $T;
206                // Get the size of the range
207                let range_size = end.wrapping_sub(start).wrapping_add(1) as $UT;
208
209                if range_size == 0 {
210                    self.try_sample(rng)
211                } else {
212                    // Get the number of bits needed to represent the maximum value in the range
213                    let bits_needed: u32 = range_size.ilog2() as u32 + 1;
214                    // Ensure we have enough bits in the buffer to generate a value
215                    self.try_ensure::<R>(bits_needed as usize, rng)?;
216                    // Find the maximum usable value
217                    let mut v = loop {
218                        // We can use T because we know the range is small enough to fit in T
219                        let mut v: $UT = 0;
220                        // Read bits from the buffer and OR bits into v
221                        for i in 0..bits_needed {
222                            let bit = self.buf.borrow_mut().remove(0);
223                            v |= (bit as $UT) << i;
224                        }
225
226                        if v < range_size {
227                            break v;
228                        }
229
230                        self.try_ensure::<R>(bits_needed as usize, rng)?;
231                    } as $T;
232
233                    v += start;
234                    Ok(v as $C)
235                }
236            }
237        }
238    };
239}
240
241impl_ranged_integral! { u8, u8, char }
242impl_ranged_integral! { u8, u8 }
243impl_ranged_integral! { u16, u16 }
244impl_ranged_integral! { u32, u32 }
245impl_ranged_integral! { u64, u64 }
246impl_ranged_integral! { usize, usize }
247impl_ranged_integral! { i8, u8 }
248impl_ranged_integral! { i16, u16 }
249impl_ranged_integral! { i32, u32 }
250impl_ranged_integral! { i64, u64 }
251impl_ranged_integral! { isize, usize }
252
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{distributions::Ranged, rngs::StandardSeedableRng};
    use concat_idents::concat_idents;
    use rand::{thread_rng, SeedableRng};
    use std::iter::repeat;

    /// Generates a test named `$TN` that seeds the RNG with `0xff` bytes and checks
    /// that every full-range sample of `$T` decodes to the all-ones bit pattern.
    macro_rules! test_sample_impl {
        ($T:ty, $TN:ident) => {
            #[test]
            fn $TN() {
                const SAMPLES: usize = 8;
                const BYTES_NEEDED: usize = size_of::<$T>() * SAMPLES;

                let mut rng = StandardSeedableRng::from_seed(vec![0xff; BYTES_NEEDED]);
                let dist = StandardBuffered::new();
                for i in 0..SAMPLES {
                    let s: $T = rng.sample(&dist);
                    assert_eq!(
                        s,
                        <$T>::from_le_bytes([0xff; size_of::<$T>()]),
                        "Expected every bit set on iteration {}",
                        i
                    );
                }
            }
        };
    }

    /// One `0xff` seed byte must yield eight `true` samples.
    #[test]
    fn test_bool() {
        let mut rng = StandardSeedableRng::from_seed(vec![0xff]);
        let dist = StandardBuffered::new();
        for i in 0..8 {
            let s: bool = rng.sample(&dist);
            assert!(s, "Expected true on iteration {}", i);
        }
    }

    /// Each `0x41` seed byte must decode to the character `'A'`.
    #[test]
    fn test_char() {
        let mut rng = StandardSeedableRng::from_seed(vec![0x41; 8]);
        let dist = StandardBuffered::new();
        for i in 0..8 {
            let s: char = rng.sample(&dist);
            assert_eq!(s, 'A', "Expected character on iteration {}", i);
        }
    }

    test_sample_impl!(u8, test_sample_u8);
    test_sample_impl!(u16, test_sample_u16);
    test_sample_impl!(u32, test_sample_u32);
    test_sample_impl!(u64, test_sample_u64);
    test_sample_impl!(usize, test_sample_usize);
    test_sample_impl!(i8, test_sample_i8);
    test_sample_impl!(i16, test_sample_i16);
    test_sample_impl!(i32, test_sample_i32);
    test_sample_impl!(i64, test_sample_i64);
    test_sample_impl!(isize, test_sample_isize);

    /// Samples from the exclusive range `'A'..'Z'` must stay inside it.
    #[test]
    fn test_sample_range_char() {
        const RANGE_MAX: char = 'Z';
        const RANGE_MIN: char = 'A';
        const SAMPLES: usize = 64;
        // Seed length in bytes: half ascending values, half descending values.
        let bytes_needed: usize =
            ((RANGE_MAX as u8 - RANGE_MIN as u8).ilog2() as usize + 1) * SAMPLES;
        let mut rng = StandardSeedableRng::from_seed(
            (0..255)
                .take(bytes_needed / 2)
                .chain((0..255).rev().take(bytes_needed / 2))
                .collect(),
        );
        let dist = StandardBuffered::new();
        for _ in 0..SAMPLES * 2 {
            let s: char = dist.sample_range(&mut rng, RANGE_MIN..RANGE_MAX);
            assert!(s >= RANGE_MIN, "Unexpected value");
            assert!(s < RANGE_MAX, "Unexpected value");
        }
    }

    /// Generates three bounds tests for `$T`:
    /// * `<$TN>_one`       - a single sample from `RANGE_MIN..RANGE_MAX`,
    /// * `$TN`             - 64 samples from `0..RANGE_MAX`,
    /// * `<$TN>_inclusive` - 64 samples from `0..=RANGE_MAX`.
    macro_rules! test_sample_rangeimpl {
        ($T:ty, $TN:ident) => {
            concat_idents!(test_name = $TN, _one, {
                #[test]
                fn test_name() {
                    const RANGE_MAX: $T = 48;
                    const RANGE_MIN: $T = 8;
                    const SAMPLES: usize = 1;
                    let mut rng = StandardSeedableRng::from_seed(
                        (0..255)
                            .take(size_of::<$T>())
                            .chain((0..255).rev().take(size_of::<$T>()))
                            .collect(),
                    );
                    let dist = StandardBuffered::new();
                    for _ in 0..SAMPLES {
                        let s: $T = dist.sample_range(&mut rng, RANGE_MIN..RANGE_MAX);
                        assert!(s < RANGE_MAX, "Unexpected value");
                        assert!(s >= RANGE_MIN, "Unexpected value");
                    }
                }
            });

            #[test]
            fn $TN() {
                const RANGE_MAX: $T = 48;
                const RANGE_MIN: $T = 8;
                const SAMPLES: usize = 64;
                let bytes_needed: usize =
                    ((RANGE_MAX - RANGE_MIN).ilog2() as usize + 1) * SAMPLES * 2;
                let mut rng = StandardSeedableRng::from_seed(
                    repeat(0..255).flatten().take(bytes_needed).collect(),
                );
                let dist = StandardBuffered::new();
                for _ in 0..SAMPLES {
                    let s: $T = dist.sample_range(&mut rng, 0..RANGE_MAX);
                    assert!(s < RANGE_MAX, "Unexpected value");
                }
            }

            concat_idents!(test_name = $TN, _inclusive, {
                #[test]
                fn test_name() {
                    const RANGE_MAX: $T = 48;
                    const RANGE_MIN: $T = 8;
                    const SAMPLES: usize = 64;
                    let bytes_needed: usize =
                        ((RANGE_MAX - RANGE_MIN).ilog2() as usize + 1) * SAMPLES * 2;
                    let mut rng = StandardSeedableRng::from_seed(
                        repeat(0..255).flatten().take(bytes_needed).collect(),
                    );
                    let dist = StandardBuffered::new();
                    for _ in 0..SAMPLES {
                        let s: $T = dist.sample_range_inclusive(&mut rng, 0..=RANGE_MAX);
                        assert!(s <= RANGE_MAX, "Unexpected value");
                    }
                }
            });
        };
    }

    test_sample_rangeimpl!(u8, test_sample_range_u8);
    test_sample_rangeimpl!(u16, test_sample_range_u16);
    test_sample_rangeimpl!(u32, test_sample_range_u32);
    test_sample_rangeimpl!(u64, test_sample_range_u64);
    test_sample_rangeimpl!(usize, test_sample_range_usize);
    test_sample_rangeimpl!(i8, test_sample_range_i8);
    test_sample_rangeimpl!(i16, test_sample_range_i16);
    test_sample_rangeimpl!(i32, test_sample_range_i32);
    test_sample_rangeimpl!(i64, test_sample_range_i64);
    test_sample_rangeimpl!(isize, test_sample_range_isize);

    /// Generates a statistical test named `$TN`: up to ten batches of 100 000 range
    /// samples of `$T` are drawn from freshly randomised seeds, and the test passes
    /// as soon as one batch looks uniform to `is_random`.
    macro_rules! test_sample_rangeimpl_uniform {
        ($T:ty, $TN:ident) => {
            #[test]
            fn $TN() {
                // Chi-square style check: counts how often each value of
                // `min..max` occurs in `data` and accepts when the statistic lies
                // within two times sqrt(r) of r, where r = max - min.
                fn is_random(data: &[$T], min: $T, max: $T) -> bool {
                    let r: f32 = (max - min) as f32;
                    let mut counts = vec![0; r as usize];
                    for &d in data {
                        counts[(d - min) as usize] += 1;
                    }
                    // Expected number of occurrences per value.
                    let n_r = data.len() as f32 / r;
                    let chi_sq_n: f32 = counts.iter().map(|&c| (c as f32 - n_r).powi(2)).sum();
                    let chi_sq = chi_sq_n / n_r;
                    (chi_sq - r).abs() <= 2.0 * r.sqrt()
                }

                const RANGE_MAX: $T = 106;
                const RANGE_MIN: $T = 0;
                const SAMPLES: usize = 100_000;

                let mut trng = thread_rng();
                for _ in 0..10 {
                    let seed = (0..SAMPLES * 2).map(|_| trng.gen()).collect::<Vec<_>>();
                    let mut rng = StandardSeedableRng::from_seed(seed);
                    let dist = StandardBuffered::new();
                    let sampled = (0..SAMPLES)
                        .map(|_| dist.sample_range(&mut rng, RANGE_MIN..RANGE_MAX))
                        .collect::<Vec<$T>>();
                    if is_random(&sampled, RANGE_MIN, RANGE_MAX) {
                        // One batch that looks uniform is enough.
                        return;
                    }
                }
                panic!("Sampled values were not random in 10 tries");
            }
        };
    }

    test_sample_rangeimpl_uniform!(u8, test_sample_range_uniform_u8);
    test_sample_rangeimpl_uniform!(u16, test_sample_range_uniform_u16);
    test_sample_rangeimpl_uniform!(u32, test_sample_range_uniform_u32);
    test_sample_rangeimpl_uniform!(i8, test_sample_range_uniform_i8);
    test_sample_rangeimpl_uniform!(i16, test_sample_range_uniform_i16);
    test_sample_rangeimpl_uniform!(i32, test_sample_range_uniform_i32);
}