heckcheck/
checker.rs

1use arbitrary::{Arbitrary, Unstructured};
2use rand::prelude::*;
3use rand::rngs::StdRng;
4use std::panic::{self, AssertUnwindSafe};
5
6use crate::{Shrink, Shrinker};
7
/// Default number of passes a checker runs while searching for a failing
/// input.
const MAX_PASSES: u64 = 100;
10
/// Length, in bytes, of the input buffer allocated when a checker is
/// created.
const INITIAL_VEC_LEN: usize = 1024;
13
14/// The main test checker.
15#[derive(Debug)]
16pub struct HeckCheck {
17    bytes: Vec<u8>,
18    max_count: u64,
19    seed: u64,
20    rng: StdRng,
21}
22
23impl Default for HeckCheck {
24    fn default() -> Self {
25        Self::new()
26    }
27}
28
29impl HeckCheck {
30    /// Create a new instance.
31    pub fn new() -> Self {
32        let seed = rand::random();
33        Self::from_seed(seed)
34    }
35
36    /// Create a new instance from a seed.
37    pub fn from_seed(seed: u64) -> Self {
38        let rng = StdRng::seed_from_u64(seed);
39        Self {
40            seed,
41            rng,
42            bytes: vec![0u8; INITIAL_VEC_LEN],
43            max_count: MAX_PASSES,
44        }
45    }
46
47    /// Check the target.
48    pub fn check<A, F>(&mut self, f: F)
49    where
50        A: for<'b> Arbitrary<'b>,
51        F: FnMut(A) -> arbitrary::Result<()>,
52    {
53        self.check_with_shrinker::<_, _, Shrinker>(f)
54    }
55
56    /// Check the target with the specified shrinker.
57    pub fn check_with_shrinker<A, F, S>(&mut self, mut f: F)
58    where
59        A: for<'b> Arbitrary<'b>,
60        F: FnMut(A) -> arbitrary::Result<()>,
61        S: Shrink,
62    {
63        // Make sure we have enough bytes in our buffer before we start testing.
64        if self.bytes.len() < A::size_hint(0).0 {
65            self.grow_vec(Some(A::size_hint(0).0));
66        }
67
68        let hook = panic::take_hook();
69        panic::set_hook(Box::new(|_| {}));
70
71        for _ in 0..self.max_count {
72            self.rng.fill_bytes(&mut self.bytes);
73            let mut u = Unstructured::new(&self.bytes);
74            let instance = A::arbitrary(&mut u).unwrap();
75
76            // Track whether we should allocate more data for a future loop.
77            let mut more_data = false;
78
79            // Call the closure. Handle the return type from `Arbitrary`, and
80            // handle possible panics from the closure.
81            let res = std::panic::catch_unwind(AssertUnwindSafe(|| {
82                if let Err(arbitrary::Error::NotEnoughData) = f(instance) {
83                    more_data = true;
84                }
85            }));
86
87            let u_len = u.len();
88            if more_data {
89                self.grow_vec(None);
90            }
91
92            // If the test panicked we start reducing the test case.
93            if res.is_err() {
94                let upper = self.bytes.len() - u_len;
95                let mut shrinker = S::shrink(self.bytes[0..upper].to_owned());
96                loop {
97                    let mut u = Unstructured::new(shrinker.next());
98                    let instance = A::arbitrary(&mut u).unwrap();
99
100                    let res = std::panic::catch_unwind(AssertUnwindSafe(|| {
101                        f(instance).unwrap();
102                    }));
103                    if let Some(case) = shrinker.report(res.into()) {
104                        panic::set_hook(hook);
105                        let sequence = base64::encode(case);
106                        match sequence.len() {
107                            0 => panic!("The failing base64 sequence is: ``. Pass an empty string to `heckcheck::replay` to create a permanent reproduction."),
108                            _ => panic!("The failing base64 sequence is: `{}`. Pass this to `heckcheck::replay` to create a permanent reproduction.", sequence),
109                        }
110                    }
111                }
112            }
113        }
114    }
115
116    fn grow_vec(&mut self, target: Option<usize>) {
117        match target {
118            Some(target) => {
119                if target.checked_sub(self.bytes.len()).is_some() {
120                    self.bytes.resize_with(target, || 0);
121                }
122            }
123            None => self.bytes.resize_with(self.bytes.len() * 2, || 0),
124        };
125    }
126
127    /// Access the value of `seed`.
128    pub fn seed(&self) -> u64 {
129        self.seed
130    }
131}