use std::{
collections::HashSet,
panic,
sync::{Mutex, OnceLock},
thread,
time::Duration,
};
mod macros;
/// Returns the process-wide registry of thread names whose panic output is
/// currently suppressed.
///
/// The first call installs a panic hook that wraps whatever hook was active at
/// that moment:
///
/// * a panic on a thread whose name is in the registry is *not* forwarded to
///   the previous hook;
/// * a panic on an unnamed thread, or on a thread that is not registered, is
///   forwarded to the previous hook unchanged.
///
/// The hook only affects what is reported; it does not stop the panic from
/// unwinding.
fn ignore_threads() -> &'static Mutex<HashSet<String>> {
    static INSTANCE: OnceLock<Mutex<HashSet<String>>> = OnceLock::new();
    INSTANCE.get_or_init(|| {
        let previous_hook = panic::take_hook();
        panic::set_hook(Box::new(move |panic_info| {
            // Decide under the lock, but release it before running the previous
            // hook so other threads are not blocked on the registry while the
            // panic is being reported.
            let suppressed = match thread::current().name() {
                Some(thread_name) => ignore_threads()
                    .lock()
                    // A panic inside a panic hook aborts the process, so a
                    // poisoned lock is recovered instead of unwrapped.
                    .unwrap_or_else(|poisoned| poisoned.into_inner())
                    .contains(thread_name),
                None => false,
            };
            if !suppressed {
                previous_hook(panic_info);
            }
        }));
        Mutex::new(HashSet::new())
    })
}
/// Scope guard that registers the creating thread's name in the
/// `ignore_threads` registry and unregisters it again when dropped.
///
/// Threads without a name are never registered, so the guard is a no-op for
/// them.
struct IgnoreGuard {
    /// The name inserted by `new`, kept so that `drop` removes exactly that
    /// entry even if the guard is dropped on a different thread.
    thread_name: Option<String>,
}

impl IgnoreGuard {
    /// Registers the current thread's name (if it has one) and returns the
    /// guard that will unregister it.
    fn new() -> IgnoreGuard {
        let thread_name = thread::current().name().map(str::to_owned);
        if let Some(name) = &thread_name {
            ignore_threads()
                .lock()
                // The set stays structurally valid after a poisoning panic, so
                // recover the guard instead of panicking.
                .unwrap_or_else(|poisoned| poisoned.into_inner())
                .insert(name.clone());
        }
        IgnoreGuard { thread_name }
    }
}

impl Drop for IgnoreGuard {
    /// Removes the name that `new` registered, if any.
    fn drop(&mut self) {
        if let Some(name) = self.thread_name.take() {
            ignore_threads()
                .lock()
                // `drop` may run while unwinding; panicking here would abort
                // the process, so a poisoned lock is recovered.
                .unwrap_or_else(|poisoned| poisoned.into_inner())
                .remove(&name);
        }
    }
}
/// Runs `assert` repeatedly until it stops panicking or the attempts are used up.
///
/// The closure is tried up to `repetitions - 1` times with panics caught; after
/// each failed attempt the current thread sleeps for `delay`. While these
/// attempts run, an `IgnoreGuard` is active for the current thread. If none of
/// them succeeds, the guard is released and `assert` is called one final time
/// without catching, so its panic propagates to the caller.
///
/// # Parameters
/// * `repetitions` - total number of attempts, including the final uncaught
///   one. `0` is treated like `1`: a single uncaught attempt.
/// * `delay` - pause after each failed, caught attempt.
/// * `assert` - closure containing the assertions; its return value is passed
///   through.
///
/// # Panics
/// Panics with whatever the final call to `assert` panics with.
pub fn that<A, R>(repetitions: usize, delay: Duration, assert: A) -> R
where
    A: Fn() -> R,
{
    let ignore_guard = IgnoreGuard::new();
    // `saturating_sub` keeps `repetitions == 0` from underflowing, which would
    // panic in debug builds and loop almost endlessly in release builds.
    for _ in 0..repetitions.saturating_sub(1) {
        let result = panic::catch_unwind(panic::AssertUnwindSafe(&assert));
        if let Ok(value) = result {
            return value;
        }
        thread::sleep(delay);
    }
    // Release the guard first so the last attempt is not covered by it.
    drop(ignore_guard);
    assert()
}
/// Async counterpart of `that`: awaits the future produced by `assert`
/// repeatedly until it completes without panicking or the attempts are used up.
///
/// Up to `repetitions - 1` futures are awaited with panics caught; after each
/// failed attempt the task sleeps for `delay` via `tokio::time::sleep`. An
/// `IgnoreGuard` created on entry is active during these attempts. If none
/// succeeds, the guard is released and one final future is awaited without
/// catching, so its panic propagates to the caller.
///
/// Only panics raised while the future is polled are caught; `assert()` itself
/// is called outside of `catch_unwind`.
///
/// # Parameters
/// * `repetitions` - total number of attempts, including the final uncaught
///   one. `0` is treated like `1`: a single uncaught attempt.
/// * `delay` - pause after each failed, caught attempt.
/// * `assert` - closure returning the future that contains the assertions.
///
/// # Panics
/// Panics with whatever the final attempt panics with.
#[cfg(feature = "async")]
pub async fn that_async<A, F, R>(repetitions: usize, delay: Duration, assert: A) -> R
where
    A: Fn() -> F,
    F: std::future::Future<Output = R>,
{
    use futures::future::FutureExt;
    let ignore_guard = IgnoreGuard::new();
    // `saturating_sub` keeps `repetitions == 0` from underflowing, which would
    // panic in debug builds and loop almost endlessly in release builds.
    for _ in 0..repetitions.saturating_sub(1) {
        let result = panic::AssertUnwindSafe(assert()).catch_unwind().await;
        if let Ok(value) = result {
            return value;
        }
        tokio::time::sleep(delay).await;
    }
    // Release the guard first so the last attempt is not covered by it.
    drop(ignore_guard);
    assert().await
}
/// Like `that`, but runs a recovery action once if the first attempts all fail.
///
/// Sequence:
/// 1. `assert` is tried up to `repetitions_catch` times with panics caught,
///    sleeping `delay` after each failure.
/// 2. If all of those fail, a notice is printed to stdout and `catch` is called
///    once.
/// 3. `assert` is tried with panics caught for the remaining attempts, i.e. for
///    the indices `repetitions_catch..repetitions - 1` (none if that range is
///    empty).
/// 4. The `IgnoreGuard` is released and `assert` is called one final time
///    without catching, so its panic propagates to the caller.
///
/// The function returns as soon as any attempt succeeds.
///
/// # Parameters
/// * `repetitions` - intended total number of attempts, including the final
///   uncaught one. `0` is treated like `1`.
/// * `delay` - pause after each failed, caught attempt.
/// * `repetitions_catch` - number of attempts made before `catch` runs.
/// * `catch` - action executed once between the two retry phases.
/// * `assert` - closure containing the assertions; its return value is passed
///   through.
///
/// # Panics
/// Panics with whatever the final call to `assert` panics with, or with a panic
/// raised by `catch`.
pub fn with_catch<A, C, R>(
    repetitions: usize,
    delay: Duration,
    repetitions_catch: usize,
    catch: C,
    assert: A,
) -> R
where
    A: Fn() -> R,
    C: FnOnce(),
{
    let ignore_guard = IgnoreGuard::new();
    for _ in 0..repetitions_catch {
        let result = panic::catch_unwind(panic::AssertUnwindSafe(&assert));
        if let Ok(value) = result {
            return value;
        }
        thread::sleep(delay);
    }
    {
        // Borrow the name instead of allocating; the printed text is unchanged.
        let current = thread::current();
        let thread_name = current.name().unwrap_or("<unnamed thread>");
        println!("{}: executing repeated-assert catch block", thread_name);
    }
    // NOTE: `catch` runs while the guard is still active, so a panic raised here
    // propagates but is not reported by the previous hook on named threads.
    catch();
    // `saturating_sub` keeps `repetitions == 0` from underflowing, which would
    // panic in debug builds and loop almost endlessly in release builds.
    for _ in repetitions_catch..repetitions.saturating_sub(1) {
        let result = panic::catch_unwind(panic::AssertUnwindSafe(&assert));
        if let Ok(value) = result {
            return value;
        }
        thread::sleep(delay);
    }
    // Release the guard first so the last attempt is not covered by it.
    drop(ignore_guard);
    assert()
}
/// Async counterpart of `with_catch`.
///
/// Sequence:
/// 1. The future from `assert` is awaited up to `repetitions_catch` times with
///    panics caught, sleeping `delay` (via `tokio::time::sleep`) after each
///    failure.
/// 2. If all of those fail, a notice is printed to stdout and the future
///    returned by `catch` is awaited once.
/// 3. Further caught attempts are made for the indices
///    `repetitions_catch..repetitions - 1` (none if that range is empty).
/// 4. The `IgnoreGuard` is released and one final future is awaited without
///    catching, so its panic propagates to the caller.
///
/// The function returns as soon as any attempt succeeds. Only panics raised
/// while a future is polled are caught; `assert()` itself is called outside of
/// `catch_unwind`.
///
/// # Parameters
/// * `repetitions` - intended total number of attempts, including the final
///   uncaught one. `0` is treated like `1`.
/// * `delay` - pause after each failed, caught attempt.
/// * `repetitions_catch` - number of attempts made before `catch` runs.
/// * `catch` - closure returning the recovery future, awaited once.
/// * `assert` - closure returning the future that contains the assertions.
///
/// # Panics
/// Panics with whatever the final attempt panics with, or with a panic raised
/// by `catch`.
#[cfg(feature = "async")]
pub async fn with_catch_async<A, F, C, G, R>(
    repetitions: usize,
    delay: Duration,
    repetitions_catch: usize,
    catch: C,
    assert: A,
) -> R
where
    A: Fn() -> F,
    F: std::future::Future<Output = R>,
    C: FnOnce() -> G,
    G: std::future::Future<Output = ()>,
{
    use futures::future::FutureExt;
    let ignore_guard = IgnoreGuard::new();
    for _ in 0..repetitions_catch {
        let result = panic::AssertUnwindSafe(assert()).catch_unwind().await;
        if let Ok(value) = result {
            return value;
        }
        tokio::time::sleep(delay).await;
    }
    {
        // Borrow the name instead of allocating; the printed text is unchanged.
        let current = thread::current();
        let thread_name = current.name().unwrap_or("<unnamed thread>");
        println!("{}: executing repeated-assert catch block", thread_name);
    }
    // NOTE: `catch` runs while the guard is still active.
    catch().await;
    // `saturating_sub` keeps `repetitions == 0` from underflowing, which would
    // panic in debug builds and loop almost endlessly in release builds.
    for _ in repetitions_catch..repetitions.saturating_sub(1) {
        let result = panic::AssertUnwindSafe(assert()).catch_unwind().await;
        if let Ok(value) = result {
            return value;
        }
        tokio::time::sleep(delay).await;
    }
    // Release the guard first so the last attempt is not covered by it.
    drop(ignore_guard);
    assert().await
}
// Tests for the retry helpers (`that`, `with_catch` and their async variants).
// Every test shares a counter with a background thread that increments it
// periodically, so the outcomes depend on wall-clock timing.
#[cfg(test)]
mod tests {
use crate as repeated_assert;
use std::sync::{Arc, Mutex};
use std::thread;
use std::time::Duration;
// Base time unit in milliseconds; every sleep and delay below is a multiple of it.
static STEP_MS: u64 = 100;
// Spawns a detached thread that loops forever: sleep 10 * STEP_MS, then
// increment the shared counter. The increment is skipped when the lock
// cannot be acquired (i.e. the mutex is poisoned).
fn spawn_thread(x: Arc<Mutex<i32>>) {
thread::spawn(move || loop {
thread::sleep(Duration::from_millis(10 * STEP_MS));
if let Ok(mut x) = x.lock() {
*x += 1;
}
});
}
// Up to 5 attempts, 5 * STEP_MS apart. The counter is first incremented after
// roughly 10 * STEP_MS, so one of the later attempts is expected to pass.
#[test]
fn single_success() {
let x = Arc::new(Mutex::new(0));
spawn_thread(x.clone());
repeated_assert::that(5, Duration::from_millis(5 * STEP_MS), || {
assert!(*x.lock().unwrap() > 0);
});
}
// Async variant of `single_success`.
#[cfg(feature = "async")]
#[tokio::test]
async fn single_success_async() {
let x = Arc::new(Mutex::new(0));
spawn_thread(x.clone());
repeated_assert::that_async(5, Duration::from_millis(5 * STEP_MS), || async {
assert!(*x.lock().unwrap() > 0);
})
.await;
}
// Only 3 attempts, STEP_MS apart: all of them are expected to run before the
// first increment (after roughly 10 * STEP_MS), so the final, uncaught
// attempt panics with the assertion message checked by `should_panic`.
#[test]
#[should_panic(expected = "assertion failed: *x.lock().unwrap() > 0")]
fn single_failure() {
let x = Arc::new(Mutex::new(0));
spawn_thread(x.clone());
repeated_assert::that(3, Duration::from_millis(STEP_MS), || {
assert!(*x.lock().unwrap() > 0);
});
}
// Async variant of `single_failure`.
#[cfg(feature = "async")]
#[tokio::test]
#[should_panic(expected = "assertion failed: *x.lock().unwrap() > 0")]
async fn single_failure_async() {
let x = Arc::new(Mutex::new(0));
spawn_thread(x.clone());
repeated_assert::that_async(3, Duration::from_millis(STEP_MS), || async {
assert!(*x.lock().unwrap() > 0);
})
.await;
}
// Two assertions in one closure: the counter check plus an equality that
// always holds (a == b). Expected to pass once the counter is positive.
#[test]
fn multiple_success() {
let x = Arc::new(Mutex::new(0));
let a = 11;
let b = 11;
spawn_thread(x.clone());
repeated_assert::that(5, Duration::from_millis(5 * STEP_MS), || {
assert!(*x.lock().unwrap() > 0);
assert_eq!(a, b);
});
}
// Async variant of `multiple_success`.
#[cfg(feature = "async")]
#[tokio::test]
async fn multiple_success_async() {
let x = Arc::new(Mutex::new(0));
let a = 11;
let b = 11;
spawn_thread(x.clone());
repeated_assert::that_async(5, Duration::from_millis(5 * STEP_MS), || async {
assert!(*x.lock().unwrap() > 0);
assert_eq!(a, b);
})
.await;
}
// Two assertions, short retry window: the first assertion (counter check) is
// the one expected to fail on the final attempt.
#[test]
#[should_panic(expected = "assertion failed: *x.lock().unwrap() > 0")]
fn multiple_failure_1() {
let x = Arc::new(Mutex::new(0));
let a = 11;
let b = 11;
spawn_thread(x.clone());
repeated_assert::that(3, Duration::from_millis(STEP_MS), || {
assert!(*x.lock().unwrap() > 0);
assert_eq!(a, b);
});
}
// Async variant of `multiple_failure_1`.
#[cfg(feature = "async")]
#[tokio::test]
#[should_panic(expected = "assertion failed: *x.lock().unwrap() > 0")]
async fn multiple_failure_1_async() {
let x = Arc::new(Mutex::new(0));
let a = 11;
let b = 11;
spawn_thread(x.clone());
repeated_assert::that_async(3, Duration::from_millis(STEP_MS), || async {
assert!(*x.lock().unwrap() > 0);
assert_eq!(a, b);
})
.await;
}
// Long retry window, but a != b: the counter check is expected to pass by the
// final attempt, while the equality assertion can never hold, so the final
// attempt panics with the `assert_eq!` message.
#[test]
#[should_panic(expected = "assertion `left == right` failed")]
fn multiple_failure_2() {
let x = Arc::new(Mutex::new(0));
let a = 11;
let b = 12;
spawn_thread(x.clone());
repeated_assert::that(5, Duration::from_millis(5 * STEP_MS), || {
assert!(*x.lock().unwrap() > 0);
assert_eq!(a, b);
});
}
// Async variant of `multiple_failure_2`.
#[cfg(feature = "async")]
#[tokio::test]
#[should_panic(expected = "assertion `left == right` failed")]
async fn multiple_failure_2_async() {
let x = Arc::new(Mutex::new(0));
let a = 11;
let b = 12;
spawn_thread(x.clone());
repeated_assert::that_async(5, Duration::from_millis(5 * STEP_MS), || async {
assert!(*x.lock().unwrap() > 0);
assert_eq!(a, b);
})
.await;
}
// The counter starts at -1000, far too low for the background increments to
// make it positive during the first 5 attempts. The catch block then resets
// it to 0, after which a later increment lets one of the remaining attempts
// pass.
#[test]
fn catch() {
let x = Arc::new(Mutex::new(-1_000));
spawn_thread(x.clone());
repeated_assert::with_catch(
10,
Duration::from_millis(5 * STEP_MS),
5,
|| {
*x.lock().unwrap() = 0;
},
|| {
assert!(*x.lock().unwrap() > 0);
},
);
}
// Async variant of `catch`.
#[cfg(feature = "async")]
#[tokio::test]
async fn catch_async() {
let x = Arc::new(Mutex::new(-1_000));
spawn_thread(x.clone());
repeated_assert::with_catch_async(
10,
Duration::from_millis(5 * STEP_MS),
5,
|| async {
*x.lock().unwrap() = 0;
},
|| async {
assert!(*x.lock().unwrap() > 0);
},
)
.await;
}
}