use std::future::Future;
use std::panic::{AssertUnwindSafe, catch_unwind};
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::thread;
use tokio::task::{JoinError, JoinSet};
/// Worker count a freshly constructed tester builder starts with (see the `Default` impls).
pub(crate) const DEFAULT_WORKERS: usize = 16;
/// Smallest worker count the tester builders accept when `run` validates their settings.
pub(crate) const MIN_WORKERS: usize = 1;
/// Largest worker count accepted by the tester builders and by `run_concurrently`.
pub(crate) const MAX_WORKERS: usize = 2_000;
/// Round count a freshly constructed tester builder starts with (see the `Default` impls).
pub(crate) const DEFAULT_ROUNDS: usize = 2;
/// Smallest round count the tester builders accept when `run` validates their settings.
pub(crate) const MIN_ROUNDS: usize = 1;
/// Largest round count accepted by the tester builders; `run_concurrently` also uses it
/// as the upper bound for `ConcurrentConfig::iterations_per_worker`.
pub(crate) const MAX_ROUNDS: usize = 1_000_000;
/// Type-erased synchronous test block; stored behind an `Arc` so every worker thread can
/// share the same closure.
type SyncTestBlock<E> = Arc<dyn Fn() -> Result<(), E> + Send + Sync + 'static>;
/// Type-erased asynchronous test block: each call produces a fresh boxed, `Send` future,
/// so blocks with different concrete future types can live in one `Vec`.
type AsyncTestBlock<E> =
    Arc<dyn Fn() -> Pin<Box<dyn Future<Output = Result<(), E>> + Send>> + Send + Sync + 'static>;
/// Shape of a [`run_concurrently`] call: how many tasks to spawn and how many times
/// each task invokes the operation.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub struct ConcurrentConfig {
    /// Number of concurrent tasks; worker indices are `0..workers`.
    pub workers: usize,
    /// Number of sequential operation calls per task; iteration indices are
    /// `0..iterations_per_worker`.
    pub iterations_per_worker: usize,
}

impl ConcurrentConfig {
    /// Builds a configuration from its two counts.
    ///
    /// No validation happens here; the counts are range-checked when the configuration
    /// is passed to [`run_concurrently`].
    #[must_use]
    pub const fn new(workers: usize, iterations_per_worker: usize) -> Self {
        Self {
            iterations_per_worker,
            workers,
        }
    }
}
/// Error returned by [`MultithreadingTester::run`], [`SuspendedJobTester::run`] and
/// [`run_concurrently`].
///
/// The first group of variants reports invalid configuration, detected before any worker
/// starts. The remaining variants report failures observed while workers were running.
/// `E` is the error type returned by the caller's test blocks or operation.
#[derive(Debug)]
#[non_exhaustive]
pub enum ConcurrentAssertError<E> {
    /// A tester was run with no test blocks registered.
    NoBlocks,
    /// The worker count is outside the accepted range.
    InvalidWorkers {
        /// The rejected worker count.
        workers: usize,
    },
    /// The round count is outside the accepted range. `run_concurrently` also reports an
    /// `iterations_per_worker` above `MAX_ROUNDS` through this variant.
    InvalidRounds {
        /// The rejected round (or per-worker iteration) count.
        rounds: usize,
    },
    /// `ConcurrentConfig::workers` was zero.
    ZeroWorkers,
    /// `ConcurrentConfig::iterations_per_worker` was zero.
    ZeroIterations,
    /// Fewer workers than registered blocks; only `MultithreadingTester::run` enforces this.
    TooFewWorkers {
        /// Configured worker count.
        workers: usize,
        /// Number of registered blocks.
        blocks: usize,
    },
    /// A test block returned `Err`.
    TestBlockFailed {
        /// Index of the worker that executed the block.
        worker: usize,
        /// Round of the failing run (its run index divided by the worker count).
        round: usize,
        /// Index of the block, in registration order.
        block: usize,
        /// The error returned by the block.
        error: E,
    },
    /// A test block panicked; the panic was caught by `MultithreadingTester::run`.
    TestBlockPanicked {
        /// Index of the worker that executed the block.
        worker: usize,
        /// Round of the panicking run (its run index divided by the worker count).
        round: usize,
        /// Index of the block, in registration order.
        block: usize,
    },
    /// The operation passed to `run_concurrently` returned `Err`.
    OperationFailed {
        /// Index of the task that called the operation.
        worker: usize,
        /// Iteration index within that task.
        iteration: usize,
        /// The error returned by the operation.
        error: E,
    },
    /// A Tokio worker task could not be joined, because it panicked or was cancelled.
    WorkerJoinFailed {
        /// Worker index, when known; the async runners in this file always report `None`.
        worker: Option<usize>,
        /// The underlying Tokio join error.
        source: JoinError,
    },
    /// `std::thread::Builder::spawn` failed for a worker thread.
    WorkerThreadSpawnFailed {
        /// Index of the worker whose thread could not be created.
        worker: usize,
        /// The I/O error reported by the spawn attempt.
        source: std::io::Error,
    },
    /// Joining a worker thread reported a panic that escaped the per-block panic guard.
    WorkerThreadPanicked {
        /// Index of the worker whose thread panicked.
        worker: usize,
    },
}
impl<E> std::fmt::Display for ConcurrentAssertError<E>
where
    E: std::fmt::Display,
{
    /// Writes a one-line, human-readable description of the error.
    ///
    /// Variants that wrap a cause append that cause's `Display` output after a colon.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // Fixed messages that carry no data.
            Self::NoBlocks => f.write_str("no test blocks were registered"),
            Self::ZeroWorkers => f.write_str("workers must be greater than zero"),
            Self::ZeroIterations => {
                f.write_str("iterations_per_worker must be greater than zero")
            }
            // Range and relationship violations found during validation.
            Self::InvalidWorkers { workers } => write!(
                f,
                "workers must be in {MIN_WORKERS}..={MAX_WORKERS}, got {workers}"
            ),
            Self::InvalidRounds { rounds } => write!(
                f,
                "rounds must be in {MIN_ROUNDS}..={MAX_ROUNDS}, got {rounds}"
            ),
            Self::TooFewWorkers { workers, blocks } => write!(
                f,
                "workers ({workers}) must be greater than or equal to registered blocks ({blocks})"
            ),
            // Failures raised by user code while workers were running.
            Self::TestBlockFailed {
                worker,
                round,
                block,
                error,
            } => write!(
                f,
                "test block {block} failed at worker {worker}, round {round}: {error}"
            ),
            Self::TestBlockPanicked {
                worker,
                round,
                block,
            } => write!(
                f,
                "test block {block} panicked at worker {worker}, round {round}"
            ),
            Self::OperationFailed {
                worker,
                iteration,
                error,
            } => write!(
                f,
                "concurrent operation failed at worker {worker}, iteration {iteration}: {error}"
            ),
            // Failures of the worker infrastructure itself.
            Self::WorkerJoinFailed { worker, source } => match worker {
                Some(index) => write!(f, "worker {index} failed to join: {source}"),
                None => write!(f, "worker failed to join: {source}"),
            },
            Self::WorkerThreadSpawnFailed { worker, source } => {
                write!(f, "worker thread {worker} failed to spawn: {source}")
            }
            Self::WorkerThreadPanicked { worker } => {
                write!(f, "worker thread {worker} panicked")
            }
        }
    }
}
impl<E> std::error::Error for ConcurrentAssertError<E>
where
E: std::error::Error + 'static,
{
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
Self::TestBlockFailed { error, .. } => Some(error),
Self::OperationFailed { error, .. } => Some(error),
Self::WorkerJoinFailed { source, .. } => Some(source),
Self::WorkerThreadSpawnFailed { source, .. } => Some(source),
Self::NoBlocks
| Self::InvalidWorkers { .. }
| Self::InvalidRounds { .. }
| Self::ZeroWorkers
| Self::ZeroIterations
| Self::TooFewWorkers { .. }
| Self::TestBlockPanicked { .. }
| Self::WorkerThreadPanicked { .. } => None,
}
}
}
/// Builder that runs registered synchronous test blocks concurrently on OS threads.
///
/// Configure it with [`workers`](Self::workers), [`rounds`](Self::rounds) and
/// [`add`](Self::add), then call [`run`](Self::run). A fresh builder starts from
/// `DEFAULT_WORKERS` workers and `DEFAULT_ROUNDS` rounds.
pub struct MultithreadingTester<E> {
    /// Number of worker threads to spawn.
    workers: usize,
    /// Number of rounds; the total run budget is `workers * rounds`.
    rounds: usize,
    /// Registered blocks, executed round-robin by run index.
    blocks: Vec<SyncTestBlock<E>>,
}

// Implemented by hand: `#[derive(Clone)]` would add an `E: Clone` bound, even though
// cloning only bumps the reference counts of the `Arc`-wrapped closures and never
// copies an `E` value.
impl<E> Clone for MultithreadingTester<E> {
    fn clone(&self) -> Self {
        Self {
            workers: self.workers,
            rounds: self.rounds,
            blocks: self.blocks.clone(),
        }
    }
}
impl<E> Default for MultithreadingTester<E> {
    /// Returns a tester with no registered blocks, `DEFAULT_WORKERS` workers and
    /// `DEFAULT_ROUNDS` rounds.
    fn default() -> Self {
        let (workers, rounds) = (DEFAULT_WORKERS, DEFAULT_ROUNDS);
        Self {
            blocks: Vec::new(),
            rounds,
            workers,
        }
    }
}
impl<E> MultithreadingTester<E> {
#[must_use]
pub fn new() -> Self {
Self::default()
}
#[must_use]
pub fn workers(mut self, workers: usize) -> Self {
self.workers = workers;
self
}
#[must_use]
pub fn rounds(mut self, rounds: usize) -> Self {
self.rounds = rounds;
self
}
#[must_use]
#[allow(clippy::should_implement_trait)]
pub fn add<F>(mut self, block: F) -> Self
where
F: Fn() -> Result<(), E> + Send + Sync + 'static,
{
self.blocks.push(Arc::new(block));
self
}
pub fn run(self) -> Result<(), ConcurrentAssertError<E>>
where
E: Send + 'static,
{
validate_tester_config(self.workers, self.rounds, self.blocks.len(), true)?;
let workers = self.workers;
let rounds = self.rounds;
let blocks = Arc::new(self.blocks);
let total_runs = workers * rounds;
let next = Arc::new(std::sync::atomic::AtomicUsize::new(0));
let stop = Arc::new(AtomicBool::new(false));
let mut handles = Vec::with_capacity(workers);
let mut spawn_error = None;
for worker in 0..workers {
let blocks = Arc::clone(&blocks);
let next = Arc::clone(&next);
let stop = Arc::clone(&stop);
let stop_on_spawn_error = Arc::clone(&stop);
match thread::Builder::new().spawn(move || {
loop {
if stop.load(Ordering::SeqCst) {
return Ok(());
}
let run_index = next.fetch_add(1, Ordering::SeqCst);
if run_index >= total_runs {
return Ok(());
}
let block = run_index % blocks.len();
let round = run_index / workers;
let result = catch_unwind(AssertUnwindSafe(|| blocks[block]()));
match result {
Ok(Ok(())) => {}
Ok(Err(error)) => {
stop.store(true, Ordering::SeqCst);
return Err(ConcurrentAssertError::TestBlockFailed {
worker,
round,
block,
error,
});
}
Err(_) => {
stop.store(true, Ordering::SeqCst);
return Err(ConcurrentAssertError::TestBlockPanicked {
worker,
round,
block,
});
}
}
}
}) {
Ok(handle) => handles.push((worker, handle)),
Err(source) => {
stop_on_spawn_error.store(true, Ordering::SeqCst);
spawn_error =
Some(ConcurrentAssertError::WorkerThreadSpawnFailed { worker, source });
break;
}
}
}
let mut first_error = spawn_error;
for (worker, handle) in handles {
match handle.join() {
Ok(Ok(())) => {}
Ok(Err(error)) => {
if first_error.is_none() {
first_error = Some(error);
}
}
Err(_) => {
if first_error.is_none() {
first_error = Some(ConcurrentAssertError::WorkerThreadPanicked { worker });
}
}
}
}
match first_error {
Some(error) => Err(error),
None => Ok(()),
}
}
}
/// Builder that runs registered asynchronous test blocks concurrently as Tokio tasks.
///
/// Configure it with [`workers`](Self::workers), [`rounds`](Self::rounds) and
/// [`add`](Self::add), then await [`run`](Self::run). A fresh builder starts from
/// `DEFAULT_WORKERS` workers and `DEFAULT_ROUNDS` rounds.
pub struct SuspendedJobTester<E> {
    /// Number of worker tasks to spawn.
    workers: usize,
    /// Number of rounds; the total run budget is `workers * rounds`.
    rounds: usize,
    /// Registered blocks, executed round-robin by run index.
    blocks: Vec<AsyncTestBlock<E>>,
}

// Implemented by hand: `#[derive(Clone)]` would add an `E: Clone` bound, even though
// cloning only bumps the reference counts of the `Arc`-wrapped closures and never
// copies an `E` value.
impl<E> Clone for SuspendedJobTester<E> {
    fn clone(&self) -> Self {
        Self {
            workers: self.workers,
            rounds: self.rounds,
            blocks: self.blocks.clone(),
        }
    }
}
impl<E> Default for SuspendedJobTester<E> {
    /// Returns a tester with no registered blocks, `DEFAULT_WORKERS` workers and
    /// `DEFAULT_ROUNDS` rounds.
    fn default() -> Self {
        let (workers, rounds) = (DEFAULT_WORKERS, DEFAULT_ROUNDS);
        Self {
            blocks: Vec::new(),
            rounds,
            workers,
        }
    }
}
impl<E> SuspendedJobTester<E> {
    /// Creates a tester with `DEFAULT_WORKERS` workers, `DEFAULT_ROUNDS` rounds and no
    /// registered blocks (identical to `Default::default`).
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the number of worker tasks; validated later by [`run`](Self::run).
    #[must_use]
    pub fn workers(self, workers: usize) -> Self {
        Self { workers, ..self }
    }

    /// Sets the number of rounds; validated later by [`run`](Self::run).
    #[must_use]
    pub fn rounds(self, rounds: usize) -> Self {
        Self { rounds, ..self }
    }

    /// Registers an asynchronous test block.
    ///
    /// `block` is called once per scheduled run, and each call's future is awaited to
    /// completion. Blocks are executed round-robin in registration order.
    #[must_use]
    // `add` is a consuming builder step, not an arithmetic `std::ops::Add` impl.
    #[allow(clippy::should_implement_trait)]
    pub fn add<F, Fut>(mut self, block: F) -> Self
    where
        F: Fn() -> Fut + Send + Sync + 'static,
        Fut: Future<Output = Result<(), E>> + Send + 'static,
    {
        // Box the future so blocks with different concrete future types share one Vec.
        self.blocks.push(Arc::new(move || Box::pin(block())));
        self
    }

    /// Runs the registered blocks on `workers` Tokio tasks in the current runtime.
    ///
    /// Execution continues until `workers * rounds` executions have been claimed, or until
    /// a failure is observed.
    ///
    /// Tasks pull run indices from a shared counter. Run index `i` executes block
    /// `i % blocks` and is reported as round `i / workers`. Unlike the thread-based
    /// tester, there is no requirement that `workers >= blocks`.
    ///
    /// # Errors
    ///
    /// - Validation errors: `InvalidWorkers`, `InvalidRounds`, or `NoBlocks`.
    /// - `TestBlockFailed` for the first failing task to be joined.
    /// - `WorkerJoinFailed` with `worker: None` if a task cannot be joined, for example
    ///   because a block panicked.
    ///
    /// On any error, all remaining tasks are aborted and awaited before returning.
    pub async fn run(self) -> Result<(), ConcurrentAssertError<E>>
    where
        E: Send + 'static,
    {
        let Self {
            workers,
            rounds,
            blocks,
        } = self;
        validate_tester_config(workers, rounds, blocks.len(), false)?;
        let block_count = blocks.len();
        let total_runs = workers * rounds;
        let shared_blocks = Arc::new(blocks);
        let cursor = Arc::new(std::sync::atomic::AtomicUsize::new(0));

        let mut tasks = JoinSet::new();
        for worker in 0..workers {
            let shared_blocks = Arc::clone(&shared_blocks);
            let cursor = Arc::clone(&cursor);
            tasks.spawn(async move {
                let mut run_index = cursor.fetch_add(1, Ordering::SeqCst);
                while run_index < total_runs {
                    let block = run_index % block_count;
                    let round = run_index / workers;
                    if let Err(error) = shared_blocks[block]().await {
                        let failure = ConcurrentAssertError::TestBlockFailed {
                            worker,
                            round,
                            block,
                            error,
                        };
                        return (worker, Err(failure));
                    }
                    run_index = cursor.fetch_add(1, Ordering::SeqCst);
                }
                (worker, Ok(()))
            });
        }

        while let Some(joined) = tasks.join_next().await {
            let failure = match joined {
                Ok((_, Ok(()))) => continue,
                Ok((_, Err(error))) => error,
                Err(source) => ConcurrentAssertError::WorkerJoinFailed {
                    worker: None,
                    source,
                },
            };
            abort_and_drain(&mut tasks).await;
            return Err(failure);
        }
        Ok(())
    }
}
/// Cancels every task still in `tasks` and waits until the set is empty.
///
/// Task outcomes, including panics, are discarded. This lets callers return early
/// without leaving worker tasks running.
async fn abort_and_drain<E>(tasks: &mut JoinSet<(usize, Result<(), ConcurrentAssertError<E>>)>)
where
    E: Send + 'static,
{
    // `JoinSet::shutdown` is documented as `abort_all` followed by `join_next` until `None`.
    tasks.shutdown().await;
}
/// Checks a [`ConcurrentConfig`] before any task is spawned.
///
/// Checks are applied in order and the first violation wins:
/// 1. zero workers;
/// 2. zero iterations;
/// 3. too many workers (above `MAX_WORKERS`);
/// 4. too many iterations (above `MAX_ROUNDS`, reported as `InvalidRounds`).
fn validate_concurrent_config<E>(config: ConcurrentConfig) -> Result<(), ConcurrentAssertError<E>> {
    match (config.workers, config.iterations_per_worker) {
        (0, _) => Err(ConcurrentAssertError::ZeroWorkers),
        (_, 0) => Err(ConcurrentAssertError::ZeroIterations),
        (workers, _) if workers > MAX_WORKERS => {
            Err(ConcurrentAssertError::InvalidWorkers { workers })
        }
        (_, rounds) if rounds > MAX_ROUNDS => Err(ConcurrentAssertError::InvalidRounds { rounds }),
        _ => Ok(()),
    }
}
/// Checks the settings shared by both tester builders before any worker starts.
///
/// Checks are applied in order and the first violation wins:
/// 1. worker range;
/// 2. round range;
/// 3. at least one block;
/// 4. `workers >= blocks`, only when `require_workers_at_least_blocks` is set.
fn validate_tester_config<E>(
    workers: usize,
    rounds: usize,
    blocks: usize,
    require_workers_at_least_blocks: bool,
) -> Result<(), ConcurrentAssertError<E>> {
    let workers_in_range = (MIN_WORKERS..=MAX_WORKERS).contains(&workers);
    let rounds_in_range = (MIN_ROUNDS..=MAX_ROUNDS).contains(&rounds);
    if !workers_in_range {
        Err(ConcurrentAssertError::InvalidWorkers { workers })
    } else if !rounds_in_range {
        Err(ConcurrentAssertError::InvalidRounds { rounds })
    } else if blocks == 0 {
        Err(ConcurrentAssertError::NoBlocks)
    } else if require_workers_at_least_blocks && blocks > workers {
        Err(ConcurrentAssertError::TooFewWorkers { workers, blocks })
    } else {
        Ok(())
    }
}
/// Spawns `config.workers` Tokio tasks.
///
/// Task `w` awaits `operation(w, i)` for each `i` in `0..config.iterations_per_worker`,
/// one call after another within that task.
///
/// # Errors
///
/// - Validation errors: `ZeroWorkers`, `ZeroIterations`, `InvalidWorkers`, or
///   `InvalidRounds` (for too many iterations).
/// - `OperationFailed` for the first failing task to be joined. That task stops at its
///   failing iteration.
/// - `WorkerJoinFailed` with `worker: None` if a task cannot be joined, for example
///   because the operation panicked.
///
/// On any error, all remaining tasks are aborted and awaited before returning.
pub async fn run_concurrently<F, Fut, E>(
    config: ConcurrentConfig,
    operation: F,
) -> Result<(), ConcurrentAssertError<E>>
where
    F: Fn(usize, usize) -> Fut + Send + Sync + 'static,
    Fut: Future<Output = Result<(), E>> + Send + 'static,
    E: Send + 'static,
{
    validate_concurrent_config(config)?;
    let shared_operation = Arc::new(operation);
    let mut tasks = JoinSet::new();
    for worker in 0..config.workers {
        let op = Arc::clone(&shared_operation);
        let iterations = config.iterations_per_worker;
        tasks.spawn(async move {
            let mut outcome = Ok(());
            for iteration in 0..iterations {
                if let Err(error) = op(worker, iteration).await {
                    outcome = Err(ConcurrentAssertError::OperationFailed {
                        worker,
                        iteration,
                        error,
                    });
                    break;
                }
            }
            (worker, outcome)
        });
    }

    while let Some(joined) = tasks.join_next().await {
        let failure = match joined {
            Ok((_, Ok(()))) => continue,
            Ok((_, Err(error))) => error,
            Err(source) => ConcurrentAssertError::WorkerJoinFailed {
                worker: None,
                source,
            },
        };
        abort_and_drain(&mut tasks).await;
        return Err(failure);
    }
    Ok(())
}