use async_trait::async_trait;
use std::sync::Arc;
use thiserror::Error;
/// Errors reported by the executors in this module.
#[derive(Debug, Error)]
pub enum ExecutionError {
/// Generic failure carrying a description. Not constructed anywhere in this
/// module; presumably intended to be returned by the `op` callbacks passed to
/// `execute_batch` — confirm with callers.
#[error("Execution failed: {0}")]
Failed(String),
/// Thread-pool setup failed: either a rejected thread count (`Some(0)`) or an
/// error from rayon's pool builder (see `RayonExecutor::new`).
#[error("Thread pool error: {0}")]
ThreadPool(String),
/// A spawned task could not be joined, or a concurrency-limiting semaphore
/// permit could not be acquired (both raised by the tokio executor).
#[error("Task join error: {0}")]
Join(String),
}
/// Strategy for applying a fallible operation to every item of a batch.
///
/// The methods are generic, so this trait is used through static dispatch
/// only (a trait with generic methods is not dyn-compatible); `Executor`
/// wraps the concrete implementations when the strategy is chosen at runtime.
#[async_trait]
pub trait ConcurrencyExecutor: Send + Sync {
/// Applies `op` to each element of `items` and returns one outcome per item.
///
/// The outer `Result` is for failures of the executor itself (e.g. a task
/// that could not be joined); the inner `Result`s are whatever `op` returned.
/// NOTE(review): every implementation in this module returns outcomes in input
/// order, but the trait does not state that as a guarantee — confirm before
/// relying on it.
async fn execute_batch<F, T>(
&self,
items: Vec<T>,
op: F,
) -> Result<Vec<Result<(), ExecutionError>>, ExecutionError>
where
F: Fn(T) -> Result<(), ExecutionError> + Send + Sync + 'static,
T: Send + 'static;
/// Short identifier of the implementation (`"sequential"`, `"tokio"` or
/// `"rayon"` for the executors in this module).
fn name(&self) -> &str;
}
/// Executor that runs batches on its own rayon thread pool.
#[cfg(feature = "parallel")]
#[derive(Debug)]
pub struct RayonExecutor {
    /// Pool created by `new`; owned by this executor rather than rayon's global pool.
    thread_pool: rayon::ThreadPool,
}

#[cfg(feature = "parallel")]
impl RayonExecutor {
    /// Builds an executor backed by a freshly created rayon thread pool.
    ///
    /// `None` leaves the thread count to rayon's builder defaults; `Some(n)`
    /// pins the pool to exactly `n` threads.
    ///
    /// # Errors
    ///
    /// Returns [`ExecutionError::ThreadPool`] when `num_threads` is `Some(0)` or
    /// when rayon fails to build the pool.
    pub fn new(num_threads: Option<usize>) -> Result<Self, ExecutionError> {
        let configured = match num_threads {
            // rayon would treat 0 as "pick automatically"; reject it explicitly instead.
            Some(0) => {
                return Err(ExecutionError::ThreadPool(
                    "Thread count must be > 0".to_string(),
                ));
            }
            Some(count) => rayon::ThreadPoolBuilder::new().num_threads(count),
            None => rayon::ThreadPoolBuilder::new(),
        };

        configured
            .build()
            .map(|thread_pool| Self { thread_pool })
            .map_err(|build_err| {
                ExecutionError::ThreadPool(format!("Failed to create thread pool: {}", build_err))
            })
    }
}
#[cfg(feature = "parallel")]
#[async_trait]
impl ConcurrencyExecutor for RayonExecutor {
    /// Runs `op` over `items` in parallel on this executor's rayon pool.
    ///
    /// Outcomes are collected from an indexed parallel iterator, so they come
    /// back in input order. `ThreadPool::install` is synchronous: the thread
    /// polling this future is blocked until the whole batch has finished, so
    /// awaiting it occupies an async worker thread for the duration of the batch.
    ///
    /// Never returns the outer `Err`; per-item failures are reported in the
    /// returned vector.
    async fn execute_batch<F, T>(
        &self,
        items: Vec<T>,
        op: F,
    ) -> Result<Vec<Result<(), ExecutionError>>, ExecutionError>
    where
        F: Fn(T) -> Result<(), ExecutionError> + Send + Sync + 'static,
        T: Send + 'static,
    {
        use rayon::prelude::*;

        // The parallel map only borrows `op`; `&F` is itself `Fn(T)` and is
        // `Send + Sync` because `F: Sync`, so no `Arc` allocation is needed.
        let results = self
            .thread_pool
            .install(|| items.into_par_iter().map(&op).collect::<Vec<_>>());
        Ok(results)
    }

    fn name(&self) -> &str {
        "rayon"
    }
}
/// Executor that hands each item to tokio's blocking thread pool, with a cap
/// on how many items may be in flight at once.
#[derive(Debug)]
pub struct TokioExecutor {
    /// Maximum number of items processed concurrently. Use at least 1: a
    /// zero limit admits no work.
    max_concurrent: usize,
}

impl TokioExecutor {
    /// Creates an executor that processes at most `max_concurrent` items at a time.
    pub fn new(max_concurrent: usize) -> Self {
        TokioExecutor { max_concurrent }
    }
}
#[async_trait]
impl ConcurrencyExecutor for TokioExecutor {
    /// Runs `op` on tokio's blocking thread pool, at most `max_concurrent`
    /// items at a time.
    ///
    /// Each item is passed to [`tokio::task::spawn_blocking`] only after a
    /// semaphore permit is acquired, so no more than `max_concurrent` calls to
    /// `op` run concurrently. Outcomes are returned in input order.
    ///
    /// # Errors
    ///
    /// * [`ExecutionError::ThreadPool`] if `max_concurrent` is 0 and `items` is
    ///   non-empty. A zero-permit semaphore would otherwise wait forever.
    /// * [`ExecutionError::Join`] if a spawned task cannot be joined (e.g. `op`
    ///   panicked). Tasks not yet joined at that point are detached and keep
    ///   running to completion.
    async fn execute_batch<F, T>(
        &self,
        items: Vec<T>,
        op: F,
    ) -> Result<Vec<Result<(), ExecutionError>>, ExecutionError>
    where
        F: Fn(T) -> Result<(), ExecutionError> + Send + Sync + 'static,
        T: Send + 'static,
    {
        use tokio::sync::Semaphore;
        use tokio::task;

        // `Semaphore::new(0)` never yields a permit, so the first
        // `acquire_owned().await` below would hang forever. Fail fast instead,
        // mirroring the thread-count validation in `RayonExecutor::new`.
        if self.max_concurrent == 0 && !items.is_empty() {
            return Err(ExecutionError::ThreadPool(
                "max_concurrent must be > 0".to_string(),
            ));
        }

        let semaphore = Arc::new(Semaphore::new(self.max_concurrent));
        // `spawn_blocking` needs a `'static` closure, so every task gets its own
        // shared handle to the operation.
        let op = Arc::new(op);
        let mut handles = Vec::with_capacity(items.len());
        for item in items {
            // Waiting for a permit *before* spawning is what bounds concurrency;
            // the permit moves into the task and is released when it finishes.
            let permit = semaphore.clone().acquire_owned().await.map_err(|e| {
                ExecutionError::Join(format!("Semaphore acquisition failed: {}", e))
            })?;
            let op = Arc::clone(&op);
            let handle = task::spawn_blocking(move || {
                let result = op(item);
                drop(permit);
                result
            });
            handles.push(handle);
        }

        let mut results = Vec::with_capacity(handles.len());
        for handle in handles {
            let result = handle
                .await
                .map_err(|e| ExecutionError::Join(format!("Task join failed: {}", e)))?;
            results.push(result);
        }
        Ok(results)
    }

    fn name(&self) -> &str {
        "tokio"
    }
}
/// Executor that processes items one after another on the calling task.
#[derive(Debug)]
pub struct SequentialExecutor;

#[async_trait]
impl ConcurrencyExecutor for SequentialExecutor {
    /// Calls `op` on each item in input order and gathers the outcomes.
    ///
    /// Never returns the outer `Err`; per-item failures are reported in the
    /// returned vector.
    async fn execute_batch<F, T>(
        &self,
        items: Vec<T>,
        op: F,
    ) -> Result<Vec<Result<(), ExecutionError>>, ExecutionError>
    where
        F: Fn(T) -> Result<(), ExecutionError> + Send + Sync + 'static,
        T: Send + 'static,
    {
        let mut outcomes = Vec::with_capacity(items.len());
        for item in items {
            outcomes.push(op(item));
        }
        Ok(outcomes)
    }

    fn name(&self) -> &str {
        "sequential"
    }
}
/// Closed set of executor implementations, selected at runtime and dispatched
/// with a `match`.
///
/// `ConcurrencyExecutor` has generic methods and therefore cannot be used as a
/// trait object; this enum stands in for `Box<dyn ConcurrencyExecutor>`.
#[derive(Debug)]
pub enum Executor {
    /// Items run one by one on the calling task.
    Sequential(SequentialExecutor),
    /// Items run on tokio's blocking pool with bounded concurrency.
    Tokio(TokioExecutor),
    /// Items run on a rayon thread pool owned by the executor.
    #[cfg(feature = "parallel")]
    Rayon(RayonExecutor),
}

impl Executor {
    /// Creates an [`Executor::Sequential`].
    pub fn sequential() -> Self {
        Executor::Sequential(SequentialExecutor)
    }

    /// Creates an [`Executor::Tokio`] limited to `max_concurrent` in-flight items.
    pub fn tokio(max_concurrent: usize) -> Self {
        let inner = TokioExecutor::new(max_concurrent);
        Executor::Tokio(inner)
    }

    /// Creates an [`Executor::Rayon`] backed by a new thread pool.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`RayonExecutor::new`].
    #[cfg(feature = "parallel")]
    pub fn rayon(num_threads: Option<usize>) -> Result<Self, ExecutionError> {
        let inner = RayonExecutor::new(num_threads)?;
        Ok(Executor::Rayon(inner))
    }

    /// Name of the wrapped executor, taken from its
    /// [`ConcurrencyExecutor::name`] so the two can never disagree.
    pub fn name(&self) -> &str {
        match self {
            Executor::Sequential(inner) => inner.name(),
            Executor::Tokio(inner) => inner.name(),
            #[cfg(feature = "parallel")]
            Executor::Rayon(inner) => inner.name(),
        }
    }

    /// Forwards to the wrapped executor's
    /// [`ConcurrencyExecutor::execute_batch`]; see that method for semantics.
    pub async fn execute_batch<F, T>(
        &self,
        items: Vec<T>,
        op: F,
    ) -> Result<Vec<Result<(), ExecutionError>>, ExecutionError>
    where
        F: Fn(T) -> Result<(), ExecutionError> + Send + Sync + 'static,
        T: Send + 'static,
    {
        match self {
            Executor::Sequential(inner) => inner.execute_batch(items, op).await,
            Executor::Tokio(inner) => inner.execute_batch(items, op).await,
            #[cfg(feature = "parallel")]
            Executor::Rayon(inner) => inner.execute_batch(items, op).await,
        }
    }
}
/// Configuration-level choice of executor, resolved by [`create_executor`].
#[derive(Debug, Clone)]
pub enum ConcurrencyMode {
    /// Rayon thread pool; `None` leaves the thread count to rayon's defaults.
    Rayon { num_threads: Option<usize> },
    /// Tokio blocking pool with at most `max_concurrent` items in flight.
    Tokio { max_concurrent: usize },
    /// In-order execution on the calling task.
    Sequential,
}

/// Turns a [`ConcurrencyMode`] into a ready-to-use [`Executor`].
///
/// This never fails. A `Rayon` request falls back to the sequential executor
/// in two cases: when the `parallel` feature is disabled, or when the rayon
/// pool cannot be built (e.g. `num_threads: Some(0)`). In the second case the
/// build error is discarded.
pub fn create_executor(mode: ConcurrencyMode) -> Executor {
    match mode {
        ConcurrencyMode::Sequential => Executor::sequential(),
        ConcurrencyMode::Tokio { max_concurrent } => Executor::tokio(max_concurrent),
        #[cfg(feature = "parallel")]
        ConcurrencyMode::Rayon { num_threads } => {
            // Deliberate best-effort: degrade to sequential instead of failing.
            Executor::rayon(num_threads).unwrap_or_else(|_| Executor::sequential())
        }
        #[cfg(not(feature = "parallel"))]
        ConcurrencyMode::Rayon { .. } => Executor::sequential(),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Operation used by the ordering tests: fails for item `2` only.
    fn fail_on_two(item: i32) -> Result<(), ExecutionError> {
        if item == 2 {
            Err(ExecutionError::Failed(format!("item {}", item)))
        } else {
            Ok(())
        }
    }

    /// Checks that a batch over `[1, 2, 3]` run with `fail_on_two` reported its
    /// single failure at index 1, i.e. outcomes line up with their inputs.
    fn assert_only_middle_failed(results: &[Result<(), ExecutionError>]) {
        assert_eq!(results.len(), 3);
        assert!(results[0].is_ok());
        assert!(matches!(results[1], Err(ExecutionError::Failed(_))));
        assert!(results[2].is_ok());
    }

    #[tokio::test]
    async fn test_sequential_basic() {
        let executor = SequentialExecutor;
        let items = vec![1, 2, 3];
        let results = executor.execute_batch(items, |_| Ok(())).await.unwrap();
        assert_eq!(results.len(), 3);
        assert!(results.iter().all(|r| r.is_ok()));
    }

    #[tokio::test]
    async fn test_sequential_preserves_order() {
        let executor = SequentialExecutor;
        let results = executor
            .execute_batch(vec![1, 2, 3], fail_on_two)
            .await
            .unwrap();
        assert_only_middle_failed(&results);
    }

    #[tokio::test]
    async fn test_tokio_basic() {
        let executor = TokioExecutor::new(2);
        let items = vec![1, 2, 3];
        let results = executor.execute_batch(items, |_| Ok(())).await.unwrap();
        assert_eq!(results.len(), 3);
        assert!(results.iter().all(|r| r.is_ok()));
    }

    #[tokio::test]
    async fn test_tokio_preserves_order() {
        let executor = TokioExecutor::new(2);
        let results = executor
            .execute_batch(vec![1, 2, 3], fail_on_two)
            .await
            .unwrap();
        assert_only_middle_failed(&results);
    }

    #[tokio::test]
    async fn test_executor_enum_dispatches() {
        let executor = Executor::sequential();
        let results = executor
            .execute_batch(vec![1, 2, 3], fail_on_two)
            .await
            .unwrap();
        assert_only_middle_failed(&results);
    }

    #[cfg(feature = "parallel")]
    #[tokio::test]
    async fn test_rayon_basic() {
        let executor = RayonExecutor::new(None).unwrap();
        let items = vec![1, 2, 3];
        let results = executor.execute_batch(items, |_| Ok(())).await.unwrap();
        assert_eq!(results.len(), 3);
        assert!(results.iter().all(|r| r.is_ok()));
    }

    #[cfg(feature = "parallel")]
    #[tokio::test]
    async fn test_rayon_preserves_order() {
        let executor = RayonExecutor::new(Some(2)).unwrap();
        let results = executor
            .execute_batch(vec![1, 2, 3], fail_on_two)
            .await
            .unwrap();
        assert_only_middle_failed(&results);
    }

    #[cfg(feature = "parallel")]
    #[test]
    fn test_rayon_rejects_zero_threads() {
        assert!(matches!(
            RayonExecutor::new(Some(0)),
            Err(ExecutionError::ThreadPool(_))
        ));
    }

    #[test]
    fn test_factory_sequential() {
        let executor = create_executor(ConcurrencyMode::Sequential);
        assert_eq!(executor.name(), "sequential");
    }

    #[test]
    fn test_factory_tokio() {
        let executor = create_executor(ConcurrencyMode::Tokio { max_concurrent: 5 });
        assert_eq!(executor.name(), "tokio");
    }

    #[cfg(feature = "parallel")]
    #[test]
    fn test_factory_rayon() {
        let executor = create_executor(ConcurrencyMode::Rayon { num_threads: None });
        assert_eq!(executor.name(), "rayon");
    }

    #[cfg(feature = "parallel")]
    #[test]
    fn test_factory_rayon_invalid_threads_falls_back() {
        let executor = create_executor(ConcurrencyMode::Rayon {
            num_threads: Some(0),
        });
        assert_eq!(executor.name(), "sequential");
    }

    #[cfg(not(feature = "parallel"))]
    #[test]
    fn test_factory_rayon_fallback() {
        let executor = create_executor(ConcurrencyMode::Rayon { num_threads: None });
        assert_eq!(executor.name(), "sequential");
    }
}