use crate::error::{CoreError, CoreResult};
#[cfg(feature = "parallel")]
use rayon::prelude::*;
/// Minimum input length at which the scan helpers in this file stop using the
/// plain sequential loop and hand the work to the parallel implementation.
const PAR_THRESHOLD: usize = 1024;
/// Computes the exclusive prefix sum of `input`.
///
/// Element `i` of the result is the sum of `input[..i]`, so a non-empty
/// result always starts with `0` and the output has the same length as the
/// input. Additions wrap on `u64` overflow instead of panicking.
pub fn parallel_scan_sum(input: &[u64]) -> Vec<u64> {
exclusive_scan_u64(input)
}
/// Computes an exclusive scan of `input` under the binary operator `op`.
///
/// Element `i` of the result is `identity` combined, left to right, with
/// `input[..i]`; the output has the same length as the input.
///
/// Inputs with at least `PAR_THRESHOLD` elements are handed to the chunked
/// scan, which (when the `parallel` feature is enabled) scans chunks
/// independently and then combines per-chunk totals with `op`. For that to
/// agree with the sequential result, `op` must be associative and `identity`
/// must be a true identity element for it.
pub fn parallel_scan_generic<T, F>(input: &[T], identity: T, op: F) -> Vec<T>
where
T: Clone + Send + Sync,
F: Fn(T, T) -> T + Send + Sync + Clone,
{
blelloch_exclusive_scan(input, identity, op)
}
/// Computes an exclusive scan of `input` that restarts at every segment head.
///
/// `segments[i] == true` marks `input[i]` as the first element of a new
/// segment: the running accumulator is reset to `identity` before the output
/// for position `i` is produced. Within a segment, output `i` is `identity`
/// combined (left to right, via `op`) with the segment's elements that
/// precede `input[i]`.
///
/// The scan is performed sequentially on the calling thread.
///
/// # Errors
///
/// Returns [`CoreError::InvalidInput`] when `input` and `segments` do not
/// have the same length.
pub fn segmented_scan<T, F>(
    input: &[T],
    segments: &[bool],
    identity: T,
    op: F,
) -> CoreResult<Vec<T>>
where
    T: Clone + Send + Sync,
    F: Fn(T, T) -> T + Send + Sync,
{
    if input.len() != segments.len() {
        return Err(CoreError::InvalidInput(
            crate::error::ErrorContext::new(format!(
                "segmented_scan: input length {} != segments length {} (input and segments slices must have the same length)",
                input.len(),
                segments.len()
            )),
        ));
    }

    // Push results as they are produced instead of pre-filling the vector
    // with `n` identity clones that would all be overwritten.
    let mut output = Vec::with_capacity(input.len());
    let mut acc = identity.clone();
    for (item, &starts_segment) in input.iter().zip(segments) {
        if starts_segment {
            acc = identity.clone();
        }
        // Exclusive scan: emit the accumulator *before* folding in `item`.
        output.push(acc.clone());
        acc = op(acc, item.clone());
    }
    Ok(output)
}
/// Computes an exclusive scan of `input` under `op`, starting from `identity`.
///
/// Output `i` holds `identity` combined, left to right, with `input[..i]`,
/// so `input[i]` itself is not included. The output has the same length as
/// the input.
///
/// This forwards to the same internal helper as `parallel_scan_generic`.
/// Because large inputs may be scanned in independent chunks whose totals are
/// then combined with `op`, `op` must be associative and `identity` must be
/// an identity element for it.
pub fn exclusive_scan<T, F>(input: &[T], identity: T, op: F) -> Vec<T>
where
T: Clone + Send + Sync,
F: Fn(T, T) -> T + Send + Sync + Clone,
{
blelloch_exclusive_scan(input, identity, op)
}
/// Computes an inclusive scan of `input` under `op`, starting from `identity`.
///
/// Output `i` holds `identity` combined, left to right, with `input[..=i]`.
/// It is derived from the exclusive scan by folding each input element into
/// the exclusive prefix at the same position.
pub fn inclusive_scan<T, F>(input: &[T], identity: T, op: F) -> Vec<T>
where
    T: Clone + Send + Sync,
    F: Fn(T, T) -> T + Send + Sync + Clone,
{
    if input.is_empty() {
        return Vec::new();
    }

    // The exclusive scan consumes its operator, so hand it a clone and keep
    // the original for the fix-up pass below.
    let mut scanned = blelloch_exclusive_scan(input, identity, op.clone());
    for (slot, element) in scanned.iter_mut().zip(input) {
        *slot = op(slot.clone(), element.clone());
    }
    scanned
}
/// Exclusive prefix sum over `u64` values using wrapping addition.
///
/// Inputs shorter than `PAR_THRESHOLD` are scanned with a single sequential
/// pass; longer inputs are delegated to `blelloch_scan_u64_parallel`.
fn exclusive_scan_u64(input: &[u64]) -> Vec<u64> {
    if input.is_empty() {
        return Vec::new();
    }

    if input.len() >= PAR_THRESHOLD {
        let mut scanned = vec![0u64; input.len()];
        blelloch_scan_u64_parallel(input, &mut scanned);
        return scanned;
    }

    let mut running = 0u64;
    input
        .iter()
        .map(|&value| {
            // Emit the total seen so far, then fold in the current value.
            let before = running;
            running = running.wrapping_add(value);
            before
        })
        .collect()
}
/// Dispatches a generic exclusive scan to the sequential or chunked variant.
///
/// * empty input yields an empty vector;
/// * fewer than `PAR_THRESHOLD` elements use `sequential_exclusive_scan`;
/// * everything else goes to `parallel_chunked_exclusive_scan`.
fn blelloch_exclusive_scan<T, F>(input: &[T], identity: T, op: F) -> Vec<T>
where
    T: Clone + Send + Sync,
    F: Fn(T, T) -> T + Send + Sync + Clone,
{
    match input.len() {
        0 => Vec::new(),
        len if len < PAR_THRESHOLD => sequential_exclusive_scan(input, identity, &op),
        _ => parallel_chunked_exclusive_scan(input, identity, op),
    }
}
/// Single-pass exclusive scan on the calling thread.
///
/// Output `i` is `identity` combined, left to right, with `input[..i]` using
/// `op`. The result has exactly `input.len()` elements.
fn sequential_exclusive_scan<T, F>(input: &[T], identity: T, op: &F) -> Vec<T>
where
    T: Clone,
    F: Fn(T, T) -> T,
{
    let mut running = identity;
    input
        .iter()
        .map(|element| {
            // Compute the next prefix, then hand back the previous one as
            // the output for this position.
            let next = op(running.clone(), element.clone());
            std::mem::replace(&mut running, next)
        })
        .collect()
}
/// Exclusive scan that splits `input` into roughly one chunk per Rayon thread.
///
/// The work happens in three phases:
/// 1. every chunk is scanned independently (in parallel), producing its local
///    exclusive prefixes and its total;
/// 2. the chunk totals are scanned sequentially to get each chunk's offset;
/// 3. every local prefix is combined with its chunk's offset (in parallel).
///
/// The result only matches a sequential scan when `op` is associative and
/// `identity` is an identity element for it.
#[cfg(feature = "parallel")]
fn parallel_chunked_exclusive_scan<T, F>(input: &[T], identity: T, op: F) -> Vec<T>
where
    T: Clone + Send + Sync,
    F: Fn(T, T) -> T + Send + Sync + Clone,
{
    use std::sync::Arc;

    let workers = rayon::current_num_threads().max(1);
    // Ceiling division, clamped to at least one element per chunk.
    let per_chunk = ((input.len() + workers - 1) / workers).max(1);
    let shared_op = Arc::new(op);

    // Phase 1: local exclusive scan of every chunk, keeping each chunk total.
    let scan_op = Arc::clone(&shared_op);
    let scanned: Vec<(Vec<T>, T)> = input
        .par_chunks(per_chunk)
        .map(|block| {
            let mut running = identity.clone();
            let prefixes: Vec<T> = block
                .iter()
                .map(|element| {
                    let next = scan_op(running.clone(), element.clone());
                    std::mem::replace(&mut running, next)
                })
                .collect();
            (prefixes, running)
        })
        .collect();

    // Phase 2: sequential exclusive scan over the chunk totals.
    let mut offsets: Vec<T> = Vec::with_capacity(scanned.len());
    let mut carried = identity.clone();
    for (_, block_total) in scanned.iter() {
        offsets.push(carried.clone());
        carried = shared_op(carried, block_total.clone());
    }

    // Phase 3: shift every local prefix by the offset of its chunk. The
    // offset is the left operand so non-commutative operators stay ordered.
    let partials: Vec<Vec<T>> = scanned.into_iter().map(|(prefixes, _)| prefixes).collect();
    let shift_op = Arc::clone(&shared_op);
    partials
        .into_par_iter()
        .zip(offsets.into_par_iter())
        .flat_map(move |(prefixes, base)| {
            let apply = Arc::clone(&shift_op);
            prefixes
                .into_iter()
                .map(move |value| apply(base.clone(), value))
                .collect::<Vec<_>>()
        })
        .collect()
}
/// Fallback used when the `parallel` feature is disabled.
///
/// Despite the name, this simply runs the sequential exclusive scan on the
/// calling thread. The `Send + Sync + Clone` bounds are not needed by this
/// body; they mirror the signature of the `parallel`-feature variant so
/// callers compile identically under both configurations.
#[cfg(not(feature = "parallel"))]
fn parallel_chunked_exclusive_scan<T, F>(input: &[T], identity: T, op: F) -> Vec<T>
where
T: Clone + Send + Sync,
F: Fn(T, T) -> T + Send + Sync + Clone,
{
sequential_exclusive_scan(input, identity, &op)
}
/// Exclusive prefix sum of `input` written into `out[..input.len()]`, using an
/// up-sweep / down-sweep tree scan whose levels are processed with Rayon.
///
/// The input is copied into a scratch buffer padded with zeros to the next
/// power of two; zero is the additive identity, so the padding does not
/// affect the first `input.len()` results. All additions wrap on overflow.
///
/// # Panics
///
/// Panics if `out` is shorter than `input`.
#[cfg(feature = "parallel")]
fn blelloch_scan_u64_parallel(input: &[u64], out: &mut [u64]) {
    let n = input.len();
    let tree_size = n.next_power_of_two();
    let mut tree = vec![0u64; tree_size];
    tree[..n].copy_from_slice(input);

    // Up-sweep (reduce): at each level, every block of `step` elements adds
    // the total of its left half (stored at `stride - 1`) into its last slot.
    // `tree_size` is a power of two, so every block is exactly `step` long
    // and blocks never overlap, which makes the in-place update race-free.
    let mut stride = 1usize;
    while stride < tree_size {
        let step = stride * 2;
        tree.par_chunks_mut(step).for_each(|node| {
            node[step - 1] = node[step - 1].wrapping_add(node[stride - 1]);
        });
        stride = step;
    }

    // The root now holds the grand total; replace it with the identity so the
    // down-sweep produces an *exclusive* scan.
    tree[tree_size - 1] = 0;

    // Down-sweep: each block passes its parent value to the left child and
    // parent + (old left total) to the right child.
    let mut stride = tree_size / 2;
    while stride >= 1 {
        let step = stride * 2;
        tree.par_chunks_mut(step).for_each(|node| {
            let left_total = node[stride - 1];
            let parent = node[step - 1];
            node[stride - 1] = parent;
            node[step - 1] = parent.wrapping_add(left_total);
        });
        stride /= 2;
    }

    out[..n].copy_from_slice(&tree[..n]);
}
/// Fallback used when the `parallel` feature is disabled: a sequential
/// exclusive prefix sum of `input` written into the leading slots of `out`.
///
/// Additions wrap on overflow. Indexing `out` panics if it is shorter than
/// `input`.
#[cfg(not(feature = "parallel"))]
fn blelloch_scan_u64_parallel(input: &[u64], out: &mut [u64]) {
    // The fold threads the running total through the iteration; the final
    // total is not part of an exclusive scan and is discarded.
    let _ = input
        .iter()
        .enumerate()
        .fold(0u64, |prefix, (slot, &value)| {
            out[slot] = prefix;
            prefix.wrapping_add(value)
        });
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Reference exclusive prefix sum built with one sequential pass.
    fn reference_prefix_sum(values: &[u64]) -> Vec<u64> {
        let mut acc = 0u64;
        values
            .iter()
            .map(|&v| {
                let before = acc;
                acc = acc.wrapping_add(v);
                before
            })
            .collect()
    }

    /// Compares two scans element by element so a failure names the index
    /// instead of dumping thousands of values.
    fn assert_scan_eq(got: &[u64], want: &[u64], label: &str) {
        assert_eq!(got.len(), want.len(), "{label}: length mismatch");
        for (i, (g, w)) in got.iter().zip(want).enumerate() {
            assert_eq!(g, w, "{label}: mismatch at index {i}");
        }
    }

    #[test]
    fn test_parallel_scan_sum_empty() {
        let v: Vec<u64> = vec![];
        assert_eq!(parallel_scan_sum(&v), Vec::<u64>::new());
    }

    #[test]
    fn test_parallel_scan_sum_single() {
        assert_eq!(parallel_scan_sum(&[42u64]), vec![0u64]);
    }

    #[test]
    fn test_parallel_scan_sum_basic() {
        let v = vec![1u64, 2, 3, 4, 5];
        assert_eq!(parallel_scan_sum(&v), vec![0, 1, 3, 6, 10]);
    }

    #[test]
    fn test_parallel_scan_sum_large() {
        // Expected values come from a single linear pass rather than
        // re-summing the prefix for every index.
        let input: Vec<u64> = (0..10_000u64).collect();
        let result = parallel_scan_sum(&input);
        assert_scan_eq(&result, &reference_prefix_sum(&input), "n = 10000");
    }

    #[test]
    fn test_exclusive_scan_multiplication() {
        let v = vec![2i64, 3, 4, 5];
        let r = exclusive_scan(&v, 1i64, |a, b| a * b);
        assert_eq!(r, vec![1, 2, 6, 24]);
    }

    #[test]
    fn test_inclusive_scan_sum() {
        let v = vec![1i32, 2, 3, 4, 5];
        let r = inclusive_scan(&v, 0i32, |a, b| a + b);
        assert_eq!(r, vec![1, 3, 6, 10, 15]);
    }

    #[test]
    fn test_segmented_scan_basic() {
        let v = vec![1i32, 2, 3, 4, 5];
        let segs = vec![true, false, false, true, false];
        let r = segmented_scan(&v, &segs, 0i32, |a, b| a + b).expect("segmented scan failed");
        assert_eq!(r, vec![0, 1, 3, 0, 4]);
    }

    #[test]
    fn test_segmented_scan_length_mismatch() {
        let v = vec![1i32, 2, 3];
        let segs = vec![true, false];
        assert!(segmented_scan(&v, &segs, 0i32, |a, b| a + b).is_err());
    }

    #[test]
    fn test_parallel_scan_generic_sum() {
        let v = vec![1i32, 2, 3, 4];
        let r = parallel_scan_generic(&v, 0i32, |a, b| a + b);
        assert_eq!(r, vec![0, 1, 3, 6]);
    }

    #[test]
    fn test_scan_consistency_small_large() {
        // Sizes on both sides of the 1024-element parallel threshold,
        // including lengths that are not powers of two.
        for n in [1usize, 100, 1023, 1024, 1025, 4096, 5000] {
            let v: Vec<u64> = (1..=n as u64).collect();
            let result = parallel_scan_sum(&v);
            assert_scan_eq(&result, &reference_prefix_sum(&v), &format!("n = {n}"));
        }
    }

    #[test]
    fn test_generic_scans_large_preserve_order() {
        // "Last non-zero value wins" is associative with identity 0 but NOT
        // commutative, so it catches swapped operands when chunk results are
        // recombined on the large-input path.
        let keep_last = |a: u64, b: u64| if b == 0 { a } else { b };
        let input: Vec<u64> = (1..=5000).collect();

        let excl = exclusive_scan(&input, 0u64, keep_last);
        let incl = inclusive_scan(&input, 0u64, keep_last);
        assert_eq!(excl.len(), input.len());
        assert_eq!(incl.len(), input.len());

        for (i, (e, inc)) in excl.iter().zip(&incl).enumerate() {
            // Exclusive: the previous element (or the identity at index 0).
            assert_eq!(*e, i as u64, "exclusive mismatch at index {i}");
            // Inclusive: the element itself.
            assert_eq!(*inc, i as u64 + 1, "inclusive mismatch at index {i}");
        }
    }
}