use std::panic::resume_unwind;
use std::sync::{
Arc,
Mutex,
mpsc,
};
use std::thread;
/// One unit of work sent through the channel from the dispatching thread to a
/// worker in `run_scoped_parallel`.
struct ScopedWorkItem<T> {
    /// Zero-based index that the dispatcher assigns. It is passed to the
    /// per-item callback as its first argument.
    index: usize,
    /// The item itself, moved to the worker that receives it.
    item: T,
}
/// Runs `run_item` over the elements of `items` on `worker_count` scoped worker
/// threads. Dispatch stops early once `observe_item` reports more than
/// `declared_count` items.
///
/// `items` is iterated on the calling thread. For each element, `observe_item`
/// is called and its return value is treated as the running count. If that
/// count exceeds `declared_count`, the element is dropped and dispatch stops.
/// Otherwise the element is sent to a worker with index `count - 1`.
/// NOTE(review): this assumes `observe_item` never returns 0, because the index
/// would underflow. Confirm against callers.
///
/// Work travels through a bounded channel with `worker_count` slots, so the
/// caller blocks while every slot is full. If every worker has exited, the
/// channel disconnects and dispatch stops instead of blocking.
///
/// Returns the last value returned by `observe_item`, or 0 if `items` yielded
/// nothing.
///
/// # Panics
///
/// Panics if `worker_count` is zero. If any worker panicked inside `run_item`,
/// the payload of the first one in spawn order is re-raised on the calling
/// thread once dispatch has finished.
pub(crate) fn run_scoped_parallel<I, T, O, F>(
    items: I,
    declared_count: usize,
    worker_count: usize,
    observe_item: O,
    run_item: F,
) -> usize
where
    I: IntoIterator<Item = T>,
    T: Send,
    O: Fn() -> usize,
    F: Fn(usize, T) + Sync,
{
    assert!(
        worker_count > 0,
        "scoped parallel worker count must be positive"
    );

    thread::scope(|scope| {
        let (sender, receiver) = mpsc::sync_channel(worker_count);
        let shared_receiver = Arc::new(Mutex::new(receiver));
        let run_item_ref = &run_item;

        let workers: Vec<_> = (0..worker_count)
            .map(|_| {
                let receiver_handle = Arc::clone(&shared_receiver);
                scope.spawn(move || run_scoped_worker(receiver_handle, run_item_ref))
            })
            .collect();

        // Only the workers may keep the receiver alive. If all of them die, the
        // channel disconnects and `send` below fails instead of blocking
        // forever on a full buffer.
        drop(shared_receiver);

        let mut last_count = 0usize;
        for item in items {
            last_count = observe_item();
            if last_count > declared_count {
                break;
            }
            let work = ScopedWorkItem {
                index: last_count - 1,
                item,
            };
            if sender.send(work).is_err() {
                // Every worker has exited; nobody is left to take more work.
                break;
            }
        }

        // Disconnect the channel so idle workers see `recv` fail and return.
        drop(sender);

        let joined = workers.into_iter().try_for_each(|worker| worker.join());
        if let Err(payload) = joined {
            resume_unwind(payload);
        }

        last_count
    })
}
/// Worker loop. It repeatedly takes the next `ScopedWorkItem` from the shared
/// receiver and calls `run_item(index, item)` on it. It returns once `recv`
/// reports that the channel is disconnected.
///
/// The mutex is held only while waiting in `recv` and is released before
/// `run_item` runs, so other workers can receive in the meantime. A poisoned
/// mutex is tolerated: the receiver inside is used as-is.
fn run_scoped_worker<T, F>(
    work_receiver: Arc<Mutex<mpsc::Receiver<ScopedWorkItem<T>>>>,
    run_item: &F,
) where
    F: Fn(usize, T),
{
    loop {
        let guard = work_receiver
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        let next = guard.recv();
        // Release the lock before doing the work so receiving is not serialized
        // behind `run_item`.
        drop(guard);

        match next {
            Ok(ScopedWorkItem { index, item }) => run_item(index, item),
            Err(_) => return,
        }
    }
}