use std::cmp;
use std::env;
use std::fmt::Debug;
use std::panic;
use sample_std::{Random, Sample};
use crate::tester::Status::{Discard, Fail, Pass};
use crate::{error, info, trace};
/// Configuration and random state for a property-based test run.
///
/// Build one with [`SampleTest::new`], optionally adjust it with the builder
/// methods, then call [`SampleTest::sample_test`].
pub struct SampleTest {
    /// Random source passed to `Testable::test_once` for every generated case.
    gen: Random,
    /// Number of passing cases after which a run stops early.
    tests: u64,
    /// Upper bound on generated cases, discarded ones included.
    max_tests: u64,
    /// Minimum number of passing cases `sample_test` requires before it
    /// reports success instead of panicking.
    min_tests_passed: u64,
}
/// Reads environment variable `name` as a `u64`.
///
/// Returns `default` when the variable is unset, is not valid Unicode, or does
/// not parse as an unsigned integer. Leading/trailing whitespace is ignored so
/// that values such as `"42 "` (easy to produce from shell scripts) are
/// accepted rather than silently replaced by the default.
fn env_u64_or(name: &str, default: u64) -> u64 {
    env::var(name)
        .ok()
        .and_then(|val| val.trim().parse().ok())
        .unwrap_or(default)
}

/// Passing cases required per run: `SAMPLE_TEST_TESTS`, default 100.
fn st_tests() -> u64 {
    env_u64_or("SAMPLE_TEST_TESTS", 100)
}

/// Generation budget per run: `SAMPLE_TEST_MAX_TESTS`, default 10 000.
fn st_max_tests() -> u64 {
    env_u64_or("SAMPLE_TEST_MAX_TESTS", 10_000)
}

/// Minimum passing cases for success: `SAMPLE_TEST_MIN_TESTS_PASSED`, default 0.
fn st_min_tests_passed() -> u64 {
    env_u64_or("SAMPLE_TEST_MIN_TESTS_PASSED", 0)
}
impl SampleTest {
pub fn new() -> SampleTest {
let gen = Random::new();
let tests = st_tests();
let max_tests = cmp::max(tests, st_max_tests());
let min_tests_passed = st_min_tests_passed();
SampleTest {
tests,
max_tests,
min_tests_passed,
gen,
}
}
pub fn tests(mut self, tests: u64) -> SampleTest {
self.tests = tests;
self
}
pub fn max_tests(mut self, max_tests: u64) -> SampleTest {
self.max_tests = max_tests;
self
}
pub fn min_tests_passed(mut self, min_tests_passed: u64) -> SampleTest {
self.min_tests_passed = min_tests_passed;
self
}
pub fn sample_test_count<S, A>(&mut self, mut s: S, f: A) -> Result<u64, TestResult>
where
A: Testable<S>,
S: Sample,
S::Output: Clone + Debug,
{
let mut n_tests_passed = 0;
for _ in 0..self.max_tests {
if n_tests_passed >= self.tests {
break;
}
match f.test_once(&mut s, &mut self.gen) {
TestResult { status: Pass, .. } => n_tests_passed += 1,
TestResult {
status: Discard, ..
} => continue,
r @ TestResult { status: Fail, .. } => return Err(r),
}
}
Ok(n_tests_passed)
}
pub fn sample_test<S, A>(&mut self, s: S, f: A)
where
A: Testable<S>,
S: Sample,
S::Output: Clone + Debug,
{
let _ = crate::env_logger_init();
let n_tests_passed = match self.sample_test_count(s, f) {
Ok(n_tests_passed) => n_tests_passed,
Err(result) => panic!("{}", result.failed_msg()),
};
if n_tests_passed >= self.min_tests_passed {
info!("(Passed {} SampleTest tests.)", n_tests_passed)
} else {
panic!(
"(Unable to generate enough tests, {} not discarded.)",
n_tests_passed
)
}
}
}
/// Checks property `f` against values drawn from `s` using a tester built by
/// [`SampleTest::new`], i.e. configured from the environment.
///
/// # Panics
///
/// Panics when the property fails or too few cases pass; see
/// [`SampleTest::sample_test`].
pub fn sample_test<S, A>(s: S, f: A)
where
    A: Testable<S>,
    S: Sample,
    S::Output: Clone + Debug,
{
    let mut runner = SampleTest::new();
    runner.sample_test(s, f);
}
/// Outcome of evaluating a property on a single input.
///
/// The field order below is kept deliberately: it determines the derived
/// `Debug` output, which is what gets logged for failures.
#[derive(Debug, Clone)]
pub struct TestResult {
    /// Whether the property passed, failed, or the input was discarded.
    status: Status,
    /// `Debug` rendering of the inputs; empty unless a caller fills it in
    /// (the `testable_fn!` implementations do).
    arguments: String,
    /// Message for runtime errors (see `TestResult::error`); `None` otherwise.
    err: Option<String>,
}
/// Verdict of a single property evaluation.
#[derive(Debug, Clone)]
enum Status {
    /// The property held for the input.
    Pass,
    /// The property did not hold; also used for runtime errors
    /// (see `TestResult::error`).
    Fail,
    /// The input was rejected; it consumes budget but is not counted as a pass.
    Discard,
}
impl TestResult {
pub fn passed() -> TestResult {
TestResult::from_bool(true)
}
pub fn failed() -> TestResult {
TestResult::from_bool(false)
}
pub fn error<S: Into<String>>(msg: S) -> TestResult {
let mut r = TestResult::from_bool(false);
r.err = Some(msg.into());
r
}
pub fn discard() -> TestResult {
TestResult {
status: Discard,
arguments: String::from(""),
err: None,
}
}
pub fn from_bool(b: bool) -> TestResult {
TestResult {
status: if b { Pass } else { Fail },
arguments: String::from(""),
err: None,
}
}
pub fn must_fail<T, F>(f: F) -> TestResult
where
F: FnOnce() -> T,
F: 'static,
T: 'static,
{
let f = panic::AssertUnwindSafe(f);
TestResult::from_bool(panic::catch_unwind(f).is_err())
}
pub fn is_success(&self) -> bool {
match self.status {
Pass => true,
Fail | Discard => false,
}
}
pub fn is_failure(&self) -> bool {
match self.status {
Fail => true,
Pass | Discard => false,
}
}
pub fn is_error(&self) -> bool {
self.is_failure() && self.err.is_some()
}
pub fn arguments(&self) -> &str {
&self.arguments
}
fn failed_msg(&self) -> String {
match self.err {
None => format!("[sample_test] TEST FAILED. Arguments: ({})", self.arguments),
Some(ref err) => format!(
"[sample_test] TEST FAILED (runtime error). \
Arguments: ({})\nError: {}",
self.arguments, err
),
}
}
}
/// A property that can be checked against values produced by a `Sample`.
///
/// Implementors provide [`Testable::result`]; the default `test_once` and
/// `shrink` methods drive input generation and counterexample shrinking.
pub trait Testable<S>: 'static
where
    S: Sample,
{
    /// Evaluates the property on a single input value.
    fn result(&self, v: S::Output) -> TestResult;

    /// Generates one input from `s` using `rng` and evaluates the property.
    ///
    /// Passing and discarded results are returned unchanged. A failing result
    /// is logged at error level and then passed to [`Testable::shrink`]
    /// together with the generated input, so the returned failure may
    /// describe a smaller input.
    fn test_once(&self, s: &mut S, rng: &mut Random) -> TestResult
    where
        S::Output: Clone + Debug,
    {
        let v = Sample::generate(s, rng);
        // `result` takes its argument by value; keep `v` for shrinking.
        let r = self.result(v.clone());
        match r.status {
            Pass | Discard => r,
            Fail => {
                error!("{:?}", r);
                self.shrink(s, r, v)
            }
        }
    }

    /// Greedily searches for a smaller failing input, starting from `v`.
    ///
    /// Candidates come from `s.shrink(..)`. A candidate whose result is a
    /// failure (`is_failure`) becomes the new best result, and the search
    /// restarts from that candidate's own shrinks; passing or discarded
    /// candidates are skipped. The search stops when the current candidate
    /// iterator is exhausted, or after `iterations` candidate pulls in total,
    /// whichever comes first.
    ///
    /// Returns the last failing result found, or `r` if no candidate failed.
    fn shrink(&self, s: &S, r: TestResult, v: S::Output) -> TestResult
    where
        S::Output: Clone + Debug,
    {
        trace!("shrinking {:?}", v);
        let mut result = r;
        let mut it = s.shrink(v);
        // Hard cap so an unbounded shrink sequence cannot loop forever.
        let iterations = 10_000_000;
        for _ in 0..iterations {
            let sv = it.next();
            if let Some(sv) = sv {
                let r_new = self.result(sv.clone());
                if r_new.is_failure() {
                    trace!("shrinking {:?}", sv);
                    result = r_new;
                    // Continue shrinking from the smaller failing value.
                    it = s.shrink(sv);
                }
            } else {
                // No more candidates: `result` is the smallest failure found.
                return result;
            }
        }
        trace!(
            "halting shrinkage after {} iterations with: {:?}",
            iterations,
            result
        );
        result
    }
}
impl From<bool> for TestResult {
fn from(value: bool) -> TestResult {
TestResult::from_bool(value)
}
}
impl From<()> for TestResult {
fn from(_: ()) -> TestResult {
TestResult::passed()
}
}
impl<A, E> From<Result<A, E>> for TestResult
where
TestResult: From<A>,
E: Debug + 'static,
{
fn from(value: Result<A, E>) -> TestResult {
match value {
Ok(r) => r.into(),
Err(err) => TestResult::error(format!("{:?}", err)),
}
}
}
// Implements `Testable<S>` for plain function pointers `fn(A, B, ...) -> T`
// whose sampler `S` yields a tuple of the argument types.
//
// The function runs inside `safe`, so a panic is turned into an `Err(String)`
// and then, through the `From<Result<T, String>>` impl, into an error result
// instead of aborting the whole run. The inputs are recorded with `Debug` in
// `arguments` for failure reports.
macro_rules! testable_fn {
    ($($name: ident),*) => {
        impl<T: 'static, S, $($name),*> Testable<S> for fn($($name),*) -> T
        where
            TestResult: From<T>,
            S: Sample<Output=($($name),*,)>,
            ($($name),*,): Clone,
            $($name: Debug + 'static),*
        {
            // The type-parameter idents (A, B, ...) double as local variable
            // names when the argument tuple is destructured below.
            #[allow(non_snake_case)]
            fn result(&self, v: S::Output) -> TestResult {
                // Destructure a clone so `v` is still available for formatting
                // after the arguments have been moved into the call.
                let ( $($name,)* ) = v.clone();
                let f: fn($($name),*) -> T = *self;
                let mut r = <TestResult as From<Result<T, String>>>::from(safe(move || {f($($name),*)}));
                {
                    let ( $(ref $name,)* ) = v;
                    // NOTE: with a single argument, `($name)` expands to a
                    // parenthesised expression rather than a 1-tuple, so it is
                    // formatted bare; with several it is formatted as a tuple.
                    r.arguments = format!("{:?}", &($($name),*));
                }
                r
            }
        }}}
// Implementations for functions taking one through eight arguments.
testable_fn!(A);
testable_fn!(A, B);
testable_fn!(A, B, C);
testable_fn!(A, B, C, D);
testable_fn!(A, B, C, D, E);
testable_fn!(A, B, C, D, E, F);
testable_fn!(A, B, C, D, E, F, G);
testable_fn!(A, B, C, D, E, F, G, H);
/// Runs `fun`, converting a panic into `Err` with the panic message.
///
/// Payloads of type `String` (formatted panics) or `&str` (literal panics)
/// become the error text; any other payload type yields a fixed placeholder.
fn safe<T, F>(fun: F) -> Result<T, String>
where
    F: FnOnce() -> T,
    F: 'static,
    T: 'static,
{
    let outcome = panic::catch_unwind(panic::AssertUnwindSafe(fun));
    outcome.map_err(|payload| match payload.downcast::<String>() {
        Ok(owned) => *owned,
        Err(payload) => match payload.downcast_ref::<&str>() {
            Some(msg) => (*msg).to_owned(),
            None => String::from("UNABLE TO SHOW RESULT OF PANIC."),
        },
    })
}