1use arbitrary::{Arbitrary, Unstructured};
2use rand::prelude::*;
3use rand::rngs::StdRng;
4use std::panic::{self, AssertUnwindSafe};
5
6use crate::{Shrink, Shrinker};
7
/// Number of randomized passes a check runs before it returns without failing.
const MAX_PASSES: u64 = 100;

/// Starting length, in bytes, of the random buffer that inputs are decoded from (1 KiB).
const INITIAL_VEC_LEN: usize = 1 << 10;
13
14#[derive(Debug)]
16pub struct HeckCheck {
17 bytes: Vec<u8>,
18 max_count: u64,
19 seed: u64,
20 rng: StdRng,
21}
22
23impl Default for HeckCheck {
24 fn default() -> Self {
25 Self::new()
26 }
27}
28
29impl HeckCheck {
30 pub fn new() -> Self {
32 let seed = rand::random();
33 Self::from_seed(seed)
34 }
35
36 pub fn from_seed(seed: u64) -> Self {
38 let rng = StdRng::seed_from_u64(seed);
39 Self {
40 seed,
41 rng,
42 bytes: vec![0u8; INITIAL_VEC_LEN],
43 max_count: MAX_PASSES,
44 }
45 }
46
47 pub fn check<A, F>(&mut self, f: F)
49 where
50 A: for<'b> Arbitrary<'b>,
51 F: FnMut(A) -> arbitrary::Result<()>,
52 {
53 self.check_with_shrinker::<_, _, Shrinker>(f)
54 }
55
56 pub fn check_with_shrinker<A, F, S>(&mut self, mut f: F)
58 where
59 A: for<'b> Arbitrary<'b>,
60 F: FnMut(A) -> arbitrary::Result<()>,
61 S: Shrink,
62 {
63 if self.bytes.len() < A::size_hint(0).0 {
65 self.grow_vec(Some(A::size_hint(0).0));
66 }
67
68 let hook = panic::take_hook();
69 panic::set_hook(Box::new(|_| {}));
70
71 for _ in 0..self.max_count {
72 self.rng.fill_bytes(&mut self.bytes);
73 let mut u = Unstructured::new(&self.bytes);
74 let instance = A::arbitrary(&mut u).unwrap();
75
76 let mut more_data = false;
78
79 let res = std::panic::catch_unwind(AssertUnwindSafe(|| {
82 if let Err(arbitrary::Error::NotEnoughData) = f(instance) {
83 more_data = true;
84 }
85 }));
86
87 let u_len = u.len();
88 if more_data {
89 self.grow_vec(None);
90 }
91
92 if res.is_err() {
94 let upper = self.bytes.len() - u_len;
95 let mut shrinker = S::shrink(self.bytes[0..upper].to_owned());
96 loop {
97 let mut u = Unstructured::new(shrinker.next());
98 let instance = A::arbitrary(&mut u).unwrap();
99
100 let res = std::panic::catch_unwind(AssertUnwindSafe(|| {
101 f(instance).unwrap();
102 }));
103 if let Some(case) = shrinker.report(res.into()) {
104 panic::set_hook(hook);
105 let sequence = base64::encode(case);
106 match sequence.len() {
107 0 => panic!("The failing base64 sequence is: ``. Pass an empty string to `heckcheck::replay` to create a permanent reproduction."),
108 _ => panic!("The failing base64 sequence is: `{}`. Pass this to `heckcheck::replay` to create a permanent reproduction.", sequence),
109 }
110 }
111 }
112 }
113 }
114 }
115
116 fn grow_vec(&mut self, target: Option<usize>) {
117 match target {
118 Some(target) => {
119 if target.checked_sub(self.bytes.len()).is_some() {
120 self.bytes.resize_with(target, || 0);
121 }
122 }
123 None => self.bytes.resize_with(self.bytes.len() * 2, || 0),
124 };
125 }
126
127 pub fn seed(&self) -> u64 {
129 self.seed
130 }
131}