use super::*;
use crate::log;
use crate::mutators as m;
use std::fmt::{self, Debug};
use std::panic;
use std::prelude::v1::*;
/// The outcome of a property check.
///
/// `Ok(())` means the check passed; otherwise a [`CheckError`] explains why it
/// did not.
pub type CheckResult<T> = std::result::Result<(), CheckError<T>>;

/// The ways a property check can end without passing.
pub enum CheckError<T> {
    /// The property failed on some input: it either returned `Err` or panicked.
    Failed(CheckFailure<T>),
    /// The initial corpus yielded no values, so there was nothing to check.
    EmptyCorpus,
    /// The mutator reported an error while generating inputs.
    MutatorError(Error),
}
impl<T> From<Error> for CheckError<T> {
fn from(v: Error) -> Self {
Self::MutatorError(v)
}
}
impl<T> From<CheckFailure<T>> for CheckError<T> {
fn from(v: CheckFailure<T>) -> Self {
Self::Failed(v)
}
}
impl<T> std::error::Error for CheckError<T>
where
T: 'static + Debug,
{
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
CheckError::MutatorError(e) => Some(e),
CheckError::Failed(f) => Some(f),
CheckError::EmptyCorpus => None,
}
}
}
impl<T> fmt::Display for CheckError<T>
where
T: Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
CheckError::Failed(e) => write!(f, "check failure: {e}"),
CheckError::EmptyCorpus => write!(f, "cannot check an empty corpus"),
CheckError::MutatorError(e) => write!(f, "mutator error: {e}"),
}
}
}
impl<T> Debug for CheckError<T>
where
T: Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt::Display::fmt(self, f)
}
}
impl<T> CheckError<T> {
    /// Returns the [`CheckFailure`] held by a `CheckError::Failed`.
    ///
    /// # Panics
    ///
    /// Panics if `self` is any other variant.
    #[track_caller]
    pub fn unwrap_failed(self) -> CheckFailure<T> {
        match self {
            CheckError::Failed(f) => f,
            // Variants are listed explicitly (no `_`) so that adding a new variant
            // forces this method to be revisited.
            CheckError::EmptyCorpus | CheckError::MutatorError(_) => {
                panic!("CheckError::unwrap_failed called on non-failed CheckError")
            }
        }
    }

    /// Returns the mutator [`Error`] held by a `CheckError::MutatorError`.
    ///
    /// # Panics
    ///
    /// Panics if `self` is any other variant.
    #[track_caller]
    pub fn unwrap_mutator_error(self) -> Error {
        match self {
            CheckError::MutatorError(e) => e,
            // The panic message names this method correctly; it previously said
            // `unwrap_error`, which does not exist.
            CheckError::Failed(_) | CheckError::EmptyCorpus => panic!(
                "CheckError::unwrap_mutator_error called on non-mutator-error CheckError"
            ),
        }
    }
}
/// A property failure: the input the property failed on, plus its failure message.
#[non_exhaustive]
pub struct CheckFailure<T> {
    /// The input on which the property failed.
    pub value: T,
    /// The message produced by the failing property.
    pub message: String,
}

/// Renders as `failed on input <value:?>: <message>`.
impl<T: Debug> fmt::Display for CheckFailure<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "failed on input {:?}: {}", self.value, self.message)
    }
}

/// `Debug` delegates to `Display`, so both render the same text.
impl<T: Debug> Debug for CheckFailure<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        <Self as fmt::Display>::fmt(self, f)
    }
}

/// A failure is a leaf error; it has no underlying `source()`.
impl<T: Debug> std::error::Error for CheckFailure<T> {}
/// Configuration for a property check.
///
/// Controls how many mutate-and-check iterations are run, and how many
/// shrinking iterations are attempted once a failing input is found.
#[derive(Debug)]
pub struct Check {
    /// Number of mutate-and-check iterations.
    iters: usize,
    /// Number of shrinking iterations attempted after a failure.
    shrink_iters: usize,
}

impl Default for Check {
    /// Same configuration as [`Check::new`].
    fn default() -> Self {
        Self::new()
    }
}
impl Check {
    /// Creates a checker with the default configuration.
    ///
    /// The defaults are 1000 mutation iterations and 1000 shrinking iterations.
    pub fn new() -> Check {
        Check {
            iters: 1000,
            shrink_iters: 1000,
        }
    }

    /// Sets how many mutate-and-check iterations [`Check::run_with`] performs.
    pub fn iters(&mut self, iters: usize) -> &mut Check {
        self.iters = iters;
        self
    }

    /// Sets how many shrinking iterations are attempted after a failure is found.
    ///
    /// Zero disables shrinking; the first failing input is reported as-is.
    pub fn shrink_iters(&mut self, shrink_iters: usize) -> &mut Check {
        self.shrink_iters = shrink_iters;
        self
    }

    /// Checks `property` using `T`'s default mutator.
    ///
    /// The corpus starts with just `T::default()`. See [`Check::run_with`] for
    /// the checking procedure and the possible errors.
    pub fn run<T, S>(
        &self,
        property: impl FnMut(&T) -> std::result::Result<(), S>,
    ) -> CheckResult<T>
    where
        T: Clone + Debug + Default + DefaultMutate,
        S: ToString,
    {
        self.run_with(m::default::<T>(), [T::default()], property)
    }

    /// Checks `property` on generated inputs.
    ///
    /// Every value of `initial_corpus` is checked first. After that, each
    /// iteration mutates a randomly chosen corpus entry in place with `mutator`
    /// and checks the result.
    ///
    /// A property fails if it returns `Err` or panics. The first failing input is
    /// then shrunk (see [`Check::shrink_iters`]) and reported.
    ///
    /// Returns `Ok(())` in two cases:
    /// - all `iters` iterations pass, or
    /// - every corpus entry has exhausted its mutations.
    ///
    /// # Errors
    ///
    /// - `CheckError::EmptyCorpus` if `initial_corpus` yields no values.
    /// - `CheckError::Failed` if the property fails on some input.
    /// - `CheckError::MutatorError` if the mutator returns a non-exhaustion error.
    pub fn run_with<M, T, S>(
        &self,
        mut mutator: M,
        initial_corpus: impl IntoIterator<Item = T>,
        mut property: impl FnMut(&T) -> std::result::Result<(), S>,
    ) -> CheckResult<T>
    where
        M: Mutate<T>,
        T: Clone + Debug,
        S: ToString,
    {
        let mut corpus = initial_corpus.into_iter().collect::<Vec<_>>();
        if corpus.is_empty() {
            return Err(CheckError::EmptyCorpus);
        }

        // Check the seed inputs unmodified before mutating anything.
        for value in &corpus {
            if let Err(msg) = Self::check_one(value, &mut property) {
                return self.shrink(mutator, value.clone(), property, msg);
            }
        }

        let mut session = Session::new();
        for _ in 0..self.iters {
            let index = session
                .context
                .rng()
                .gen_index(corpus.len())
                .expect("corpus is non-empty inside the mutation loop");
            match session.mutate_with(&mut mutator, &mut corpus[index]) {
                Ok(()) => {}
                Err(e) if e.is_exhausted() => {
                    // Drop the exhausted entry. The slot now holds nothing freshly
                    // mutated, so skip the property check below. Checking
                    // `corpus[index]` here would either index out of bounds (when
                    // `index` was the last slot) or re-check the unmutated entry
                    // that `swap_remove` moved into this slot.
                    corpus.swap_remove(index);
                    if corpus.is_empty() {
                        return Ok(());
                    }
                    continue;
                }
                Err(e) => return Err(e.into()),
            }
            if let Err(msg) = Self::check_one(&corpus[index], &mut property) {
                return self.shrink(mutator, corpus[index].clone(), property, msg);
            }
        }
        Ok(())
    }

    /// Runs `property` on `value` and turns any failure into `Err(message)`.
    ///
    /// An `Err` return is converted with `ToString`. A panic is reported as
    /// `"<panicked>"`.
    fn check_one<T, S>(
        value: &T,
        mut property: impl FnMut(&T) -> std::result::Result<(), S>,
    ) -> std::result::Result<(), String>
    where
        S: ToString,
    {
        // The property is captured by `&mut`, so it is not `UnwindSafe`; we assert
        // unwind safety anyway. The same property keeps being called after a panic
        // (e.g. while shrinking), so stateful properties must tolerate that.
        match panic::catch_unwind(panic::AssertUnwindSafe(|| property(value))) {
            Ok(Ok(())) => Ok(()),
            Ok(Err(msg)) => Err(msg.to_string()),
            Err(_) => Err("<panicked>".into()),
        }
    }

    /// Shrinks a failing input and reports it.
    ///
    /// Each iteration clones the current failing `value` and mutates the clone in
    /// a shrink-mode session (`Session::new().shrink(true)`). A candidate replaces
    /// `value` only if the property still fails on it.
    ///
    /// Always returns `Err(CheckError::Failed)`, carrying the last failing input
    /// found and its message.
    fn shrink<M, T, S>(
        &self,
        mut mutator: M,
        mut value: T,
        mut property: impl FnMut(&T) -> std::result::Result<(), S>,
        mut message: String,
    ) -> CheckResult<T>
    where
        M: Mutate<T>,
        T: Clone + Debug,
        S: ToString,
    {
        log::warn!("failed on input {value:?}: {message}");
        if self.shrink_iters == 0 {
            return Err(CheckFailure { value, message }.into());
        }

        log::debug!("shrinking for {} iters...", self.shrink_iters);
        let mut session = Session::new().shrink(true);
        for _ in 0..self.shrink_iters {
            let mut candidate = value.clone();
            match session.mutate_with(&mut mutator, &mut candidate) {
                Ok(()) => {}
                Err(e) if e.is_exhausted() => break,
                Err(e) => {
                    // Best effort: a mutator error only costs this one shrink attempt.
                    log::info!("got mutator error during shrinking, ignoring: {e}");
                    continue;
                }
            }
            // Keep the candidate only if it still fails; passing candidates are dropped.
            if let Err(msg) = Self::check_one(&candidate, &mut property) {
                // Log the new failing candidate, not the value it replaces.
                log::debug!("got failure for shrunken input {candidate:?}: {msg}");
                value = candidate;
                message = msg;
            }
        }
        log::info!("shrunk failing input down to {value:?}");
        Err(CheckFailure { value, message }.into())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::result::Result;

    /// Returns a default `Check`, initializing test logging.
    ///
    /// Only the first call initializes logging; later calls are ignored.
    fn check() -> Check {
        let _ = env_logger::builder().is_test(true).try_init();
        Check::new()
    }

    /// Property that holds only for `true`.
    fn must_be_true(b: &bool) -> Result<(), &'static str> {
        if *b {
            Ok(())
        } else {
            Err("expected true!")
        }
    }

    /// Property that holds only for values below 10.
    fn must_be_below_ten(x: &u8) -> Result<(), &'static str> {
        if *x < 10 {
            Ok(())
        } else {
            Err("expected < 10")
        }
    }

    #[test]
    fn check_run_with_okay() {
        check()
            .run_with(m::just(true), [true], must_be_true)
            .unwrap();
    }

    #[test]
    fn check_run_with_fail() {
        let CheckFailure { value, message } = check()
            .run_with(m::bool(), [true], must_be_true)
            .unwrap_err()
            .unwrap_failed();
        assert!(!value);
        assert_eq!(message, "expected true!");
    }

    #[test]
    fn check_run_with_empty_corpus() {
        let outcome = check().run_with(m::bool(), [], must_be_true);
        assert!(matches!(outcome, Err(CheckError::EmptyCorpus)));
    }

    #[test]
    fn check_run_with_fail_and_shrink() {
        let CheckFailure { value, message } = check()
            .shrink_iters(1000)
            .run_with(m::u8(), [u8::MAX], must_be_below_ten)
            .unwrap_err()
            .unwrap_failed();
        assert_eq!(value, 10);
        assert_eq!(message, "expected < 10");
    }

    #[test]
    fn check_run_with_fail_on_panic() {
        let outcome = check().run_with(m::bool(), [true], |_: &bool| -> Result<(), String> {
            panic!("oh no!")
        });
        assert!(matches!(outcome, Err(CheckError::Failed(_))));
    }

    #[test]
    fn check_run_with_fail_on_panic_and_shrink() {
        let CheckFailure { value, message } = check()
            .shrink_iters(1000)
            .run_with(m::u8(), [u8::MAX], |x: &u8| -> Result<(), String> {
                assert!(*x < 10);
                Ok(())
            })
            .unwrap_err()
            .unwrap_failed();
        assert_eq!(value, 10);
        assert_eq!(message, "<panicked>");
    }
}