use std::sync::mpsc;
use std::thread;
/// Minimum input length, in elements, at which [`parallel_scan`] and
/// [`parallel_count`] split work across threads; shorter inputs are handled
/// by a single call on the caller's thread.
pub const MIN_PARALLEL_ROWS: usize = 4096;
pub fn parallel_scan<T, U, F>(input: &[T], chunk_count: usize, worker: F) -> Vec<U>
where
T: Send + Sync + Clone + 'static,
U: Send + 'static,
F: Fn(&[T]) -> Vec<U> + Send + Sync + 'static,
{
if chunk_count <= 1 || input.len() < MIN_PARALLEL_ROWS {
return worker(input);
}
let chunk_size = input.len().div_ceil(chunk_count);
let mut chunks: Vec<Vec<T>> = Vec::with_capacity(chunk_count);
let mut idx = 0;
while idx < input.len() {
let end = (idx + chunk_size).min(input.len());
chunks.push(input[idx..end].to_vec());
idx = end;
}
let (tx, rx) = mpsc::channel();
let worker = std::sync::Arc::new(worker);
let mut handles = Vec::with_capacity(chunks.len());
for (chunk_idx, chunk) in chunks.into_iter().enumerate() {
let tx = tx.clone();
let worker = worker.clone();
let handle = thread::spawn(move || {
let result = worker(&chunk);
let _ = tx.send((chunk_idx, result));
});
handles.push(handle);
}
drop(tx);
let mut indexed: Vec<Option<Vec<U>>> = (0..handles.len()).map(|_| None).collect();
while let Ok((idx, result)) = rx.recv() {
indexed[idx] = Some(result);
}
for handle in handles {
let _ = handle.join();
}
let mut out: Vec<U> = Vec::new();
for chunk_result in indexed.into_iter().flatten() {
out.extend(chunk_result);
}
out
}
/// Runs `counter` over contiguous chunks of `input` on separate threads and
/// returns the sum of the per-chunk counts.
///
/// `input` is split into at most `chunk_count` chunks of
/// `input.len().div_ceil(chunk_count)` elements each (the last chunk may be
/// shorter), and one thread is spawned per chunk. The chunks are borrowed
/// from `input`; no elements are copied.
///
/// When `chunk_count <= 1` or `input.len() < MIN_PARALLEL_ROWS`, `counter` is
/// called exactly once, on the calling thread, with the whole slice.
///
/// # Panics
///
/// If `counter` panics on any chunk, this function panics on the calling
/// thread after all workers have finished, instead of returning a total that
/// silently leaves that chunk out.
pub fn parallel_count<T, F>(input: &[T], chunk_count: usize, counter: F) -> u64
where
    T: Send + Sync + Clone + 'static,
    F: Fn(&[T]) -> u64 + Send + Sync + 'static,
{
    if chunk_count <= 1 || input.len() < MIN_PARALLEL_ROWS {
        return counter(input);
    }
    // Here `input.len() >= MIN_PARALLEL_ROWS` and `chunk_count >= 2`, so the
    // ceiling division is at least 1 and yields at most `chunk_count` chunks.
    let chunk_size = input.len().div_ceil(chunk_count);
    // Scoped threads may borrow from this stack frame, so every worker shares
    // one `&F` and reads its chunk straight out of `input`.
    let counter = &counter;
    thread::scope(|scope| {
        let (tx, rx) = mpsc::channel::<u64>();
        for chunk in input.chunks(chunk_size) {
            let tx = tx.clone();
            scope.spawn(move || {
                // The receiver is only dropped after the drain below, so a
                // failed send means the caller is already unwinding.
                let _ = tx.send(counter(chunk));
            });
        }
        // Drop the original sender so the drain ends once every worker has
        // finished (or unwound) and released its clone.
        drop(tx);
        // Addition is order-independent, so counts are folded as they arrive.
        // On return the scope joins every worker and panics if any of them
        // panicked.
        rx.iter().sum::<u64>()
    })
}
/// Returns the chunk count used by the `*_default` helpers: the value
/// reported by `std::thread::available_parallelism`, capped at 8, or 1 when
/// that query fails.
pub fn default_parallelism() -> usize {
    const MAX_WORKERS: usize = 8;
    match std::thread::available_parallelism() {
        Ok(cores) => cores.get().min(MAX_WORKERS),
        Err(_) => 1,
    }
}
/// Convenience wrapper around [`parallel_scan`] that takes its chunk count
/// from [`default_parallelism`].
pub fn parallel_scan_default<T, U, F>(input: &[T], worker: F) -> Vec<U>
where
    T: Send + Sync + Clone + 'static,
    U: Send + 'static,
    F: Fn(&[T]) -> Vec<U> + Send + Sync + 'static,
{
    let chunk_count = default_parallelism();
    parallel_scan(input, chunk_count, worker)
}
/// Convenience wrapper around [`parallel_count`] that takes its chunk count
/// from [`default_parallelism`].
pub fn parallel_count_default<T, F>(input: &[T], counter: F) -> u64
where
    T: Send + Sync + Clone + 'static,
    F: Fn(&[T]) -> u64 + Send + Sync + 'static,
{
    let chunk_count = default_parallelism();
    parallel_count(input, chunk_count, counter)
}