use std::ops::Range;
use std::thread;
/// Decides how many workers should share `total_work_items` units of work.
///
/// The starting count is `requested_threads` (raised to at least 1) or, when it is `None`,
/// the value of [`thread::available_parallelism`] (1 if that query fails). The count is
/// then capped by `total_work_items` and by `ceil(total_work_items / min_chunk_size)`,
/// where a `min_chunk_size` of 0 is treated as 1. The result is never below 1, and zero
/// work items always yield 1.
pub(crate) fn resolve_parallelism(
    total_work_items: usize,
    requested_threads: Option<usize>,
    min_chunk_size: usize,
) -> usize {
    if total_work_items == 0 {
        return 1;
    }
    let starting = if let Some(threads) = requested_threads {
        threads.max(1)
    } else {
        thread::available_parallelism().map_or(1, |n| n.get())
    };
    let effective_chunk = min_chunk_size.max(1);
    let chunk_limit = total_work_items.div_ceil(effective_chunk).max(1);
    let limit = chunk_limit.min(total_work_items);
    starting.min(limit).max(1)
}
/// Splits `0..total` into `workers` contiguous, non-overlapping ranges that cover it in order.
///
/// Boundary `i` is `floor(i * total / workers)`, so range lengths differ by at most one.
/// When `workers > total`, some ranges are empty. When `workers == 0`, the result is an
/// empty `Vec`.
pub(crate) fn split_ranges(total: usize, workers: usize) -> Vec<Range<usize>> {
    // `i * total` can overflow `usize` for large `total`, so compute the product in u128.
    // The quotient is at most `total`, so narrowing it back to `usize` is lossless.
    let boundary = |i: usize| -> usize { (i as u128 * total as u128 / workers as u128) as usize };
    // With `workers == 0` the range is empty and `boundary` (which would divide by zero)
    // is never called.
    (0..workers).map(|i| boundary(i)..boundary(i + 1)).collect()
}
/// Applies `mapper` to every element of `input` together with its index, possibly in
/// parallel, and returns the results in input order.
///
/// [`resolve_parallelism`] picks the worker count from `requested_threads` and
/// `min_chunk_size`.
///
/// - With a single worker, the mapping runs on the calling thread.
/// - Otherwise, [`split_ranges`] splits `input` into contiguous ranges, and each range is
///   mapped on its own scoped thread.
///
/// Either way, output element `i` is `mapper(i, &input[i])`. An empty `input` returns an
/// empty `Vec` without calling `mapper`.
///
/// # Panics
///
/// If `mapper` panics, the panic reaches the caller with its original payload. This holds
/// on the sequential path and when the panic happens on a worker thread.
pub(crate) fn parallel_map_indexed<T, R, F>(
    input: &[T],
    requested_threads: Option<usize>,
    min_chunk_size: usize,
    mapper: F,
) -> Vec<R>
where
    T: Sync,
    R: Send,
    F: Fn(usize, &T) -> R + Sync,
{
    let total = input.len();
    if total == 0 {
        return Vec::new();
    }
    let worker_count = resolve_parallelism(total, requested_threads, min_chunk_size);
    if worker_count <= 1 {
        return input
            .iter()
            .enumerate()
            .map(|(idx, item)| mapper(idx, item))
            .collect();
    }
    let ranges = split_ranges(total, worker_count);
    // Borrow once, so each spawned closure copies a `&F` instead of moving `mapper`.
    let mapper = &mapper;
    thread::scope(|scope| {
        let mut handles = Vec::with_capacity(ranges.len());
        for range in ranges {
            handles.push(scope.spawn(move || {
                // Zip the index range with the matching sub-slice. This avoids a bounds
                // check on every element.
                let chunk = &input[range.clone()];
                range
                    .zip(chunk)
                    .map(|(idx, item)| mapper(idx, item))
                    .collect::<Vec<R>>()
            }));
        }
        // Join in spawn order so the concatenated output stays in input order.
        let mut result = Vec::with_capacity(total);
        for handle in handles {
            match handle.join() {
                Ok(chunk) => result.extend(chunk),
                // Re-raise the worker's own panic payload instead of replacing it with a
                // generic message. This matches the sequential path above.
                Err(payload) => std::panic::resume_unwind(payload),
            }
        }
        result
    })
}
/// Returns the thread count to use for an optional user-supplied `num_threads`.
///
/// An explicit value is honoured but raised to at least 1, so `Some(0)` yields 1. With
/// `None`, the count reported by [`std::thread::available_parallelism`] is used, or 1 when
/// that query fails.
pub fn resolve_num_threads(num_threads: Option<usize>) -> usize {
    if let Some(explicit) = num_threads {
        return explicit.max(1);
    }
    std::thread::available_parallelism().map_or(1, |detected| detected.get())
}