use std::sync::mpsc;
use std::thread;
/// Minimum input length, in rows, before `parallel_scan` / `parallel_count` split work
/// across threads; shorter inputs are handled by one direct call on the calling thread.
pub const MIN_PARALLEL_ROWS: usize = 4096;
pub fn parallel_scan<T, U, F>(input: &[T], chunk_count: usize, worker: F) -> Vec<U>
where
T: Send + Sync + Clone + 'static,
U: Send + 'static,
F: Fn(&[T]) -> Vec<U> + Send + Sync + 'static,
{
if chunk_count <= 1 || input.len() < MIN_PARALLEL_ROWS {
return worker(input);
}
let chunk_size = input.len().div_ceil(chunk_count);
let mut chunks: Vec<Vec<T>> = Vec::with_capacity(chunk_count);
let mut idx = 0;
while idx < input.len() {
let end = (idx + chunk_size).min(input.len());
chunks.push(input[idx..end].to_vec());
idx = end;
}
let (tx, rx) = mpsc::channel();
let worker = std::sync::Arc::new(worker);
let mut handles = Vec::with_capacity(chunks.len());
for (chunk_idx, chunk) in chunks.into_iter().enumerate() {
let tx = tx.clone();
let worker = worker.clone();
let handle = thread::spawn(move || {
let result = worker(&chunk);
let _ = tx.send((chunk_idx, result));
});
handles.push(handle);
}
drop(tx);
let mut indexed: Vec<Option<Vec<U>>> = (0..handles.len()).map(|_| None).collect();
while let Ok((idx, result)) = rx.recv() {
indexed[idx] = Some(result);
}
for handle in handles {
let _ = handle.join();
}
let mut out: Vec<U> = Vec::new();
for chunk_result in indexed.into_iter().flatten() {
out.extend(chunk_result);
}
out
}
/// Counts over `input` on several threads and returns the sum of the per-chunk counts.
///
/// `input` is split into at most `chunk_count` contiguous chunks of
/// `ceil(input.len() / chunk_count)` rows. Each chunk is passed to `counter` on its own
/// scoped thread and the returned values are added together.
///
/// When `chunk_count <= 1` or `input.len() < MIN_PARALLEL_ROWS`, no thread is spawned and
/// the result is exactly `counter(input)`.
///
/// # Panics
///
/// If `counter` panics on any chunk, that panic is re-raised on the calling thread after
/// the remaining workers have finished, rather than returning a total that silently
/// omits the failed chunk.
pub fn parallel_count<T, F>(input: &[T], chunk_count: usize, counter: F) -> u64
where
    // `Clone` and `'static` are no longer needed by the implementation; they are kept so
    // the public signature stays exactly what existing callers compile against.
    T: Send + Sync + Clone + 'static,
    F: Fn(&[T]) -> u64 + Send + Sync + 'static,
{
    if chunk_count <= 1 || input.len() < MIN_PARALLEL_ROWS {
        return counter(input);
    }

    // `input` is non-empty and `chunk_count >= 2` here, so `chunk_size >= 1`.
    let chunk_size = input.len().div_ceil(chunk_count);
    let counter = &counter;

    // Scoped threads may borrow `input` and `counter` directly, so chunks are not copied.
    thread::scope(|scope| {
        let (tx, rx) = mpsc::channel::<u64>();
        let handles: Vec<_> = input
            .chunks(chunk_size)
            .map(|chunk| {
                let tx = tx.clone();
                scope.spawn(move || {
                    // The receiver outlives every worker, so a send error is not expected.
                    let _ = tx.send(counter(chunk));
                })
            })
            .collect();
        // Drop the original sender so the receive loop ends once every worker is done.
        drop(tx);

        // Addition is order-independent, so partial counts are summed as they arrive.
        let total: u64 = rx.iter().sum();

        for handle in handles {
            if let Err(payload) = handle.join() {
                // Forward the original panic instead of returning a short total.
                std::panic::resume_unwind(payload);
            }
        }
        total
    })
}
/// Returns the chunk count used by the `*_default` helpers.
///
/// This is the value reported by `std::thread::available_parallelism`, capped at 8.
/// If that query fails, 1 is returned.
pub fn default_parallelism() -> usize {
    match std::thread::available_parallelism() {
        Ok(cores) => cores.get().min(8),
        Err(_) => 1,
    }
}
/// Convenience wrapper around `parallel_scan` that takes its chunk count from
/// `default_parallelism()`.
pub fn parallel_scan_default<T, U, F>(input: &[T], worker: F) -> Vec<U>
where
    T: Send + Sync + Clone + 'static,
    U: Send + 'static,
    F: Fn(&[T]) -> Vec<U> + Send + Sync + 'static,
{
    let chunk_count = default_parallelism();
    parallel_scan(input, chunk_count, worker)
}
/// Convenience wrapper around `parallel_count` that takes its chunk count from
/// `default_parallelism()`.
pub fn parallel_count_default<T, F>(input: &[T], counter: F) -> u64
where
    T: Send + Sync + Clone + 'static,
    F: Fn(&[T]) -> u64 + Send + Sync + 'static,
{
    let chunk_count = default_parallelism();
    parallel_count(input, chunk_count, counter)
}
/// Default number of rows per batch for the streaming scan helpers.
///
/// NOTE(review): nothing in the visible part of this file references this constant;
/// presumably callers pass it as `batch_rows` to `parallel_scan_stream` /
/// `parallel_scan_rows` — confirm against call sites.
pub const DEFAULT_SCAN_BATCH_ROWS: usize = 256;
/// Lazy iterator that applies a worker to consecutive batches of a slice.
///
/// Each call to `next` takes up to `batch_rows` rows starting at `cursor`, runs `worker`
/// on that sub-slice on the calling thread, and yields the resulting `Vec<U>`. Batches
/// are produced strictly in input order, and nothing is computed until the iterator is
/// polled. Instances are created by `parallel_scan_stream`.
#[must_use = "iterators are lazy and do nothing unless consumed"]
pub struct ScanBatches<'a, T, U, F>
where
    F: Fn(&[T]) -> Vec<U>,
{
    /// Full input slice being scanned.
    input: &'a [T],
    /// Index of the first row that has not yet been handed to `worker`.
    cursor: usize,
    /// Maximum number of rows passed to `worker` per batch.
    batch_rows: usize,
    /// Callback that turns one batch of rows into output values.
    worker: F,
    /// Mentions the output type `U`, which no other field stores.
    _marker: std::marker::PhantomData<fn() -> U>,
}

impl<'a, T, U, F> Iterator for ScanBatches<'a, T, U, F>
where
    F: Fn(&[T]) -> Vec<U>,
{
    type Item = Vec<U>;

    /// Runs `worker` on the next batch and advances the cursor past it.
    ///
    /// Returns `None` once every row has been consumed, and keeps returning `None` on
    /// later calls.
    fn next(&mut self) -> Option<Vec<U>> {
        if self.cursor >= self.input.len() {
            return None;
        }
        // Step by at least one row so the cursor always advances, and saturate so the
        // end index cannot wrap for an extreme `batch_rows`.
        let step = self.batch_rows.max(1);
        let end = self.cursor.saturating_add(step).min(self.input.len());
        let batch = &self.input[self.cursor..end];
        // Advance before invoking the worker so the cursor is already past this batch.
        self.cursor = end;
        Some((self.worker)(batch))
    }

    /// Reports the exact number of batches still to be yielded, which lets consumers
    /// such as `collect` preallocate.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining_rows = self.input.len().saturating_sub(self.cursor);
        let remaining_batches = remaining_rows.div_ceil(self.batch_rows.max(1));
        (remaining_batches, Some(remaining_batches))
    }
}
/// Builds a lazy `ScanBatches` iterator over `input`.
///
/// Each item is the `Vec<U>` returned by `worker` for one batch of at most `batch_rows`
/// rows, in input order. A `batch_rows` of zero is treated as one row per batch.
///
/// Despite the `parallel_` prefix, the returned iterator runs `worker` on whichever
/// thread polls it; no threads are spawned here.
pub fn parallel_scan_stream<T, U, F>(
    input: &[T],
    batch_rows: usize,
    worker: F,
) -> ScanBatches<'_, T, U, F>
where
    F: Fn(&[T]) -> Vec<U>,
{
    // Clamp so every batch contains at least one row.
    let rows_per_batch = if batch_rows == 0 { 1 } else { batch_rows };
    ScanBatches {
        worker,
        input,
        batch_rows: rows_per_batch,
        cursor: 0,
        _marker: std::marker::PhantomData,
    }
}
/// Streams the individual output rows of a batched scan over `input`.
///
/// Equivalent to flattening the batches produced by `parallel_scan_stream`: `worker` is
/// called lazily on one batch of at most `batch_rows` rows at a time, and the elements
/// of each returned vector are yielded in order before the next batch is requested.
pub fn parallel_scan_rows<'a, T, U, F>(
    input: &'a [T],
    batch_rows: usize,
    worker: F,
) -> impl Iterator<Item = U> + 'a
where
    T: 'a,
    U: 'a,
    F: Fn(&[T]) -> Vec<U> + 'a,
{
    let batches = parallel_scan_stream(input, batch_rows, worker);
    batches.flatten()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Row count comfortably above `MIN_PARALLEL_ROWS`, so `parallel_scan` and
    /// `parallel_count` take their multi-threaded path instead of the direct-call fallback.
    const THREADED_ROWS: u64 = MIN_PARALLEL_ROWS as u64 + 1000;

    /// Chunk count that never selects the `chunk_count <= 1` fallback, even on a machine
    /// where `default_parallelism()` reports a single core.
    fn threaded_chunk_count() -> usize {
        default_parallelism().max(2)
    }

    /// Identity worker: returns the rows it was given.
    fn copy_worker(chunk: &[u64]) -> Vec<u64> {
        chunk.to_vec()
    }

    #[test]
    fn scan_stream_yields_batches_in_order_and_matches_eager_collect() {
        let input: Vec<u64> = (0..THREADED_ROWS).collect();
        let eager = parallel_scan(&input, threaded_chunk_count(), copy_worker);
        let streamed: Vec<u64> = parallel_scan_rows(&input, 64, copy_worker).collect();
        assert_eq!(eager, streamed);
        assert_eq!(streamed, input);
    }

    #[test]
    fn scan_stream_applies_filter_worker_with_parity() {
        let input: Vec<u64> = (0..THREADED_ROWS).collect();
        let even =
            |chunk: &[u64]| -> Vec<u64> { chunk.iter().copied().filter(|n| n % 2 == 0).collect() };
        let eager = parallel_scan(&input, threaded_chunk_count(), even);
        let streamed: Vec<u64> = parallel_scan_rows(&input, 16, even).collect();
        assert_eq!(eager, streamed);
        assert!(streamed.iter().all(|n| n % 2 == 0));
    }

    #[test]
    fn count_matches_single_call_on_threaded_path() {
        let input: Vec<u64> = (0..THREADED_ROWS).collect();
        let count_even =
            |chunk: &[u64]| -> u64 { chunk.iter().filter(|n| **n % 2 == 0).count() as u64 };
        let threaded = parallel_count(&input, threaded_chunk_count(), count_even);
        assert_eq!(threaded, count_even(&input));
    }

    #[test]
    fn scan_stream_batch_rows_zero_is_clamped_to_one() {
        let input: Vec<u64> = (0..5).collect();
        let batches: Vec<Vec<u64>> = parallel_scan_stream(&input, 0, copy_worker).collect();
        assert_eq!(batches.len(), 5, "batch_rows 0 must clamp to 1 row/batch");
        assert_eq!(batches.concat(), input);
    }

    #[test]
    fn scan_stream_materialises_at_most_one_batch_at_a_time() {
        let input: Vec<u64> = (0..10_000).collect();
        const BATCH: usize = 128;
        // The worker itself enforces the cap: it panics if handed more than one batch.
        let bounded = |chunk: &[u64]| -> Vec<u64> {
            assert!(
                chunk.len() <= BATCH,
                "worker saw {} rows, exceeding batch cap {BATCH}",
                chunk.len()
            );
            chunk.to_vec()
        };
        let total: usize = parallel_scan_rows(&input, BATCH, bounded).count();
        assert_eq!(total, input.len());
    }
}