mod serial;
mod spawn;
#[cfg(feature = "threading")]
mod rayon_impl;
pub use serial::SerialPool;
pub use spawn::ThreadPool;
#[cfg(feature = "threading")]
pub use rayon_impl::RayonPool;
/// Returns the default pool for this build.
///
/// With the `threading` feature enabled this is a [`RayonPool`] produced by
/// [`RayonPool::new`].
#[must_use]
#[cfg(feature = "threading")]
pub fn get_default_pool() -> RayonPool {
    RayonPool::new()
}

/// Returns the default pool for this build.
///
/// Without the `threading` feature this is a [`SerialPool`] produced by
/// [`SerialPool::new`].
#[must_use]
#[cfg(not(feature = "threading"))]
pub fn get_default_pool() -> SerialPool {
    SerialPool::new()
}
/// Creates a pool that runs on `num_threads` worker threads.
///
/// A `num_threads` of `0` means "no explicit request". The pool is then built
/// with [`RayonPool::new`], the same constructor [`get_default_pool`] uses.
/// This matches how `PoolConfig::build` treats `0`.
#[cfg(feature = "threading")]
#[must_use]
pub fn pool_with_threads(num_threads: usize) -> RayonPool {
    // Mirror `PoolConfig::build`: send 0 to the default constructor instead of
    // asking the backend for an explicit zero-thread pool.
    if num_threads == 0 {
        RayonPool::new()
    } else {
        RayonPool::with_num_threads(num_threads)
    }
}

/// Creates a pool for the requested thread count.
///
/// Without the `threading` feature only serial execution is available, so
/// `num_threads` is ignored and a [`SerialPool`] is returned.
#[cfg(not(feature = "threading"))]
#[must_use]
pub fn pool_with_threads(_num_threads: usize) -> SerialPool {
    SerialPool::new()
}
/// Builder-style configuration used to construct a pool.
///
/// The derived [`Default`] sets `num_threads` to `0`.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct PoolConfig {
    /// Requested worker-thread count. `0` means no explicit request: the
    /// pool's default constructor is used instead.
    pub num_threads: usize,
}
impl PoolConfig {
    /// Creates a configuration with every setting at its default value
    /// (`num_threads == 0`).
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Builder setter: requests `num_threads` worker threads.
    ///
    /// Passing `0` restores the "no explicit request" default.
    #[must_use]
    pub fn threads(self, num_threads: usize) -> Self {
        Self { num_threads, ..self }
    }

    /// Consumes the configuration and constructs a [`RayonPool`].
    ///
    /// A thread count of `0` uses [`RayonPool::new`]; any other value is
    /// forwarded to [`RayonPool::with_num_threads`].
    #[cfg(feature = "threading")]
    #[must_use]
    pub fn build(self) -> RayonPool {
        match self.num_threads {
            0 => RayonPool::new(),
            requested => RayonPool::with_num_threads(requested),
        }
    }

    /// Consumes the configuration and constructs a [`SerialPool`].
    ///
    /// Without the `threading` feature the requested thread count is ignored.
    #[cfg(not(feature = "threading"))]
    #[must_use]
    pub fn build(self) -> SerialPool {
        SerialPool::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_get_default_pool() {
        let pool = get_default_pool();
        assert!(pool.num_threads() >= 1);
    }

    #[test]
    fn test_pool_config() {
        let config = PoolConfig::new().threads(2);
        let pool = config.build();
        assert!(pool.num_threads() >= 1);
    }

    /// `PoolConfig` starts at `num_threads == 0`, and `threads` overwrites it
    /// (the last call wins when chained).
    #[test]
    fn test_pool_config_builder_fields() {
        assert_eq!(PoolConfig::new().num_threads, 0);
        assert_eq!(PoolConfig::default().num_threads, 0);
        assert_eq!(PoolConfig::new().threads(3).num_threads, 3);
        assert_eq!(PoolConfig::new().threads(3).threads(5).num_threads, 5);
    }

    /// A zero thread count must still produce a usable pool.
    #[test]
    fn test_pool_config_zero_threads_builds_usable_pool() {
        let pool = PoolConfig::new().threads(0).build();
        assert!(pool.num_threads() >= 1);
    }

    #[test]
    fn test_serial_pool() {
        let pool = SerialPool::new();
        assert_eq!(pool.num_threads(), 1);
        pool.parallel_for(5, |i| {
            let _ = i;
        });
        let (a, b) = pool.join(|| 1, || 2);
        assert_eq!(a, 1);
        assert_eq!(b, 2);
    }

    #[test]
    fn test_parallel_for_chunks() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        let pool = get_default_pool();
        let counter = AtomicUsize::new(0);
        let n = 100;
        pool.parallel_for_chunks(n, 10, |_start, len| {
            counter.fetch_add(len, Ordering::SeqCst);
        });
        assert_eq!(counter.load(Ordering::SeqCst), n);
    }

    #[test]
    fn test_parallel_split() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        let pool = get_default_pool();
        let counter = AtomicUsize::new(0);
        pool.parallel_split(0, 100, 10, &|start, count| {
            for i in start..(start + count) {
                counter.fetch_add(i, Ordering::SeqCst);
            }
        });
        // Sum of 0..100.
        assert_eq!(counter.load(Ordering::SeqCst), 4950);
    }

    #[test]
    fn test_join() {
        let pool = get_default_pool();
        let (a, b) = pool.join(|| 42, || 43);
        assert_eq!(a, 42);
        assert_eq!(b, 43);
    }

    #[test]
    fn test_parallel_for_data_integrity() {
        use std::sync::atomic::{AtomicU64, Ordering};
        let pool = get_default_pool();
        let sum = AtomicU64::new(0);
        let n = 1000;
        pool.parallel_for(n, |i| {
            sum.fetch_add(i as u64, Ordering::SeqCst);
        });
        // Gauss: sum of 0..n == n * (n - 1) / 2.
        let expected = (n as u64 * (n as u64 - 1)) / 2;
        assert_eq!(sum.load(Ordering::SeqCst), expected);
    }

    #[cfg(feature = "threading")]
    #[test]
    fn test_rayon_pool_threads() {
        let pool = RayonPool::with_num_threads(4);
        assert_eq!(pool.num_threads(), 4);
    }

    #[cfg(feature = "threading")]
    #[test]
    fn test_parallel_correctness_with_mutex() {
        use std::sync::Mutex;
        let pool = get_default_pool();
        let results = Mutex::new(vec![0usize; 100]);
        pool.parallel_for(100, |i| {
            let mut r = results.lock().unwrap();
            r[i] = i * 2;
        });
        let r = results.lock().unwrap();
        // Iterate the vector directly instead of indexing by range
        // (clippy::needless_range_loop).
        for (i, &value) in r.iter().enumerate() {
            assert_eq!(value, i * 2, "Element {i} has wrong value");
        }
    }

    #[cfg(feature = "threading")]
    #[test]
    fn test_parallel_chunks_boundary() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        let pool = get_default_pool();
        let counter = AtomicUsize::new(0);
        // 97 is not a multiple of the chunk size, so the last chunk is short.
        let n = 97;
        pool.parallel_for_chunks(n, 10, |start, len| {
            for i in start..(start + len) {
                counter.fetch_add(i, Ordering::SeqCst);
            }
        });
        let expected: usize = (0..n).sum();
        assert_eq!(counter.load(Ordering::SeqCst), expected);
    }

    /// Each parallel task runs a plain sequential inner loop; this checks that
    /// all per-task work is accounted for. The inner loop is not itself
    /// dispatched to the pool.
    #[cfg(feature = "threading")]
    #[test]
    fn test_nested_parallel() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        let pool = get_default_pool();
        let counter = AtomicUsize::new(0);
        pool.parallel_for(4, |_i| {
            for j in 0..25 {
                counter.fetch_add(j, Ordering::SeqCst);
            }
        });
        let expected = 4 * (25 * 24 / 2);
        assert_eq!(counter.load(Ordering::SeqCst), expected);
    }
}