use std::sync::Arc;
/// Selects which rayon thread pool parallel work is executed on.
///
/// A context built with [`WorkStealingContext::with_pool`] or
/// [`WorkStealingContext::with_rayon_pool`] runs all work inside that dedicated pool.
/// A context built with [`WorkStealingContext::new`] (or [`Default`]) has no dedicated
/// pool, so work runs in whatever rayon pool is current for the calling thread: the
/// global pool, unless the caller is already executing inside another pool.
///
/// Cloning is cheap because the pool is shared behind an [`Arc`].
#[derive(Clone, Debug, Default)]
pub struct WorkStealingContext {
    /// Dedicated pool to run work on; `None` means "use the caller's current rayon pool".
    pool: Option<Arc<rayon::ThreadPool>>,
}
impl WorkStealingContext {
    /// Creates a context with no dedicated pool.
    ///
    /// Work submitted through this context runs on the caller's current rayon pool.
    #[must_use]
    pub fn new() -> Self {
        Self { pool: None }
    }

    /// Creates a context that executes all work inside `pool`.
    #[must_use]
    pub fn with_pool(pool: Arc<rayon::ThreadPool>) -> Self {
        Self { pool: Some(pool) }
    }

    /// Builder-style setter that makes this context execute work inside `pool`.
    ///
    /// Any previously configured pool is replaced.
    #[must_use]
    pub fn with_rayon_pool(mut self, pool: Arc<rayon::ThreadPool>) -> Self {
        self.pool = Some(pool);
        self
    }

    /// Returns the number of worker threads that parallel work may be spread across.
    ///
    /// With a dedicated pool, this is that pool's thread count. Without one, it is
    /// [`rayon::current_num_threads`]: the size of the pool the calling thread currently
    /// belongs to, or of the global pool when called from outside any rayon pool.
    #[must_use]
    pub fn num_threads(&self) -> usize {
        match &self.pool {
            Some(p) => p.current_num_threads(),
            None => rayon::current_num_threads(),
        }
    }

    /// Runs `f` in this context and returns its result.
    ///
    /// With a dedicated pool, this delegates to [`rayon::ThreadPool::install`], so rayon
    /// parallel iterators used inside `f` execute on that pool. Without one, `f` is called
    /// directly on the current thread, and any rayon work it spawns uses the current pool.
    ///
    /// # Panics
    ///
    /// Propagates any panic raised by `f`.
    pub fn install<F, R>(&self, f: F) -> R
    where
        F: FnOnce() -> R + Send,
        R: Send,
    {
        match &self.pool {
            Some(pool) => pool.install(f),
            None => f(),
        }
    }

    /// Applies `f` in parallel to consecutive, non-overlapping chunks of `data`.
    ///
    /// `data` is split into chunks of `chunk_size` elements; the last chunk may be shorter.
    /// Each chunk is passed to `f` exactly once, possibly on different worker threads of
    /// this context's pool. The call returns after every chunk has been processed.
    ///
    /// [`rayon::slice::ParallelSliceMut::par_chunks_mut`] panics on a zero chunk size.
    /// Here, a `chunk_size` of `0` is instead treated as a no-op, as is an empty `data`;
    /// `f` is not called in either case.
    ///
    /// # Panics
    ///
    /// Propagates any panic raised by `f`.
    pub fn par_map_slices_mut<T, F>(&self, data: &mut [T], chunk_size: usize, f: F)
    where
        T: Send,
        F: Fn(&mut [T]) + Send + Sync,
    {
        use rayon::prelude::*;
        if chunk_size == 0 || data.is_empty() {
            return;
        }
        // Route through `install` so the "dedicated pool vs. current pool" decision lives
        // in exactly one place. `&F` is `Send + Sync` because `F: Sync`.
        self.install(|| data.par_chunks_mut(chunk_size).for_each(&f));
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    #[test]
    fn test_default_context_increments_all_chunks() {
        let ctx = WorkStealingContext::new();
        let mut data = vec![0u32; 100];
        ctx.par_map_slices_mut(&mut data, 10, |chunk| {
            for v in chunk.iter_mut() {
                *v += 1;
            }
        });
        // Incrementing (not assigning) makes a double visit observable as 2.
        assert!(data.iter().all(|&v| v == 1));
    }

    #[test]
    fn test_with_custom_pool() {
        let pool = rayon::ThreadPoolBuilder::new()
            .num_threads(2)
            .build()
            .expect("failed to build pool");
        let ctx = WorkStealingContext::with_pool(Arc::new(pool));
        assert_eq!(ctx.num_threads(), 2);
        let mut data = vec![0i64; 64];
        ctx.par_map_slices_mut(&mut data, 16, |chunk| {
            for v in chunk.iter_mut() {
                *v += 7;
            }
        });
        assert!(data.iter().all(|&v| v == 7));
    }

    #[test]
    fn test_with_rayon_pool_builder() {
        let pool = Arc::new(
            rayon::ThreadPoolBuilder::new()
                .num_threads(2)
                .build()
                .expect("pool build failed"),
        );
        let ctx = WorkStealingContext::new().with_rayon_pool(pool);
        let mut data = vec![1u8; 32];
        ctx.par_map_slices_mut(&mut data, 8, |chunk| {
            for v in chunk.iter_mut() {
                *v = v.wrapping_mul(2);
            }
        });
        assert!(data.iter().all(|&v| v == 2));
    }

    #[test]
    fn test_empty_data_is_noop() {
        let ctx = WorkStealingContext::new();
        let mut data: Vec<u32> = vec![];
        ctx.par_map_slices_mut(&mut data, 8, |_chunk| panic!("should not be called"));
    }

    #[test]
    fn test_chunk_size_zero_is_noop() {
        let ctx = WorkStealingContext::new();
        let mut data = vec![42u32; 8];
        ctx.par_map_slices_mut(&mut data, 0, |_chunk| panic!("should not be called"));
        assert!(data.iter().all(|&v| v == 42));
    }

    #[test]
    fn test_uneven_chunk_size_covers_tail() {
        let ctx = WorkStealingContext::new();
        // 23 is not a multiple of 5, so the final chunk has 3 elements.
        let mut data = vec![0u32; 23];
        ctx.par_map_slices_mut(&mut data, 5, |chunk| {
            assert!(!chunk.is_empty() && chunk.len() <= 5);
            for v in chunk.iter_mut() {
                *v += 1;
            }
        });
        assert!(data.iter().all(|&v| v == 1));
    }

    #[test]
    fn test_chunk_size_larger_than_data_is_single_chunk() {
        let ctx = WorkStealingContext::new();
        let mut data = vec![3u32; 4];
        let calls = AtomicUsize::new(0);
        ctx.par_map_slices_mut(&mut data, 100, |chunk| {
            calls.fetch_add(1, Ordering::SeqCst);
            assert_eq!(chunk.len(), 4);
            for v in chunk.iter_mut() {
                *v *= 2;
            }
        });
        assert_eq!(calls.load(Ordering::SeqCst), 1);
        assert!(data.iter().all(|&v| v == 6));
    }

    #[test]
    fn test_install_runs_on_custom_pool() {
        let pool = rayon::ThreadPoolBuilder::new()
            .num_threads(3)
            .build()
            .expect("failed to build pool");
        let ctx = WorkStealingContext::with_pool(Arc::new(pool));
        let (in_pool, threads) = ctx.install(|| {
            (
                rayon::current_thread_index().is_some(),
                rayon::current_num_threads(),
            )
        });
        assert!(in_pool);
        assert_eq!(threads, 3);
    }

    #[test]
    fn test_install_without_pool_runs_inline() {
        let ctx = WorkStealingContext::new();
        assert_eq!(ctx.install(|| 6 * 7), 42);
        assert_eq!(ctx.num_threads(), rayon::current_num_threads());
    }
}