use crate::error::{CoreError, CoreResult, ErrorContext, ErrorLocation};
/// Inclusive prefix sum: `out[i] = data[0] + data[1] + … + data[i]`.
///
/// Delegates to [`parallel_scan`] with `T::default()` as the identity and
/// `+` as the combining operator. An empty slice yields an empty vector.
pub fn parallel_prefix_sum<T>(data: &[T]) -> Vec<T>
where
    T: Clone + Send + Sync + Default + std::ops::Add<Output = T>,
{
    parallel_scan(data, T::default(), |acc, x| acc + x)
}
/// Exclusive prefix sum: `out[0] = T::default()` and
/// `out[i] = data[0] + … + data[i-1]` for `i > 0`.
///
/// Delegates to [`parallel_scan_exclusive`] with `T::default()` as the
/// identity and `+` as the combining operator.
pub fn parallel_prefix_sum_exclusive<T>(data: &[T]) -> Vec<T>
where
    T: Clone + Send + Sync + Default + std::ops::Add<Output = T>,
{
    parallel_scan_exclusive(data, T::default(), |acc, x| acc + x)
}
/// Generic inclusive scan with a caller-supplied operator.
///
/// For the tiled parallel path to agree with the sequential scan, `op`
/// must be associative and `identity` must satisfy `op(identity, x) == x`.
/// Inputs shorter than `SEQUENTIAL_THRESHOLD` are scanned directly;
/// larger inputs use the two-pass tiled algorithm.
pub fn parallel_scan<T, F>(data: &[T], identity: T, op: F) -> Vec<T>
where
    T: Clone + Send + Sync,
    F: Fn(T, T) -> T + Send + Sync + Clone,
{
    match data.len() {
        0 => Vec::new(),
        len if len < SEQUENTIAL_THRESHOLD => sequential_inclusive_scan(data, &identity, &op),
        _ => blelloch_inclusive_scan(data, &identity, &op),
    }
}
/// Generic exclusive scan with a caller-supplied operator.
///
/// `out[0]` is `identity`; `out[i]` combines `data[0..i]`. The same
/// associativity/identity requirements as [`parallel_scan`] apply.
/// Inputs shorter than `SEQUENTIAL_THRESHOLD` are scanned directly;
/// larger inputs use the two-pass tiled algorithm.
pub fn parallel_scan_exclusive<T, F>(data: &[T], identity: T, op: F) -> Vec<T>
where
    T: Clone + Send + Sync,
    F: Fn(T, T) -> T + Send + Sync + Clone,
{
    match data.len() {
        0 => Vec::new(),
        len if len < SEQUENTIAL_THRESHOLD => sequential_exclusive_scan(data, &identity, &op),
        _ => blelloch_exclusive_scan(data, &identity, &op),
    }
}
/// Fallible wrapper around [`parallel_prefix_sum`].
///
/// Currently always succeeds — empty input yields `Ok` with an empty
/// vector. The `CoreResult` return keeps the signature uniform with the
/// other `try_*` entry points (e.g. [`try_parallel_scan`]).
pub fn try_parallel_prefix_sum<T>(data: &[T]) -> CoreResult<Vec<T>>
where
    T: Clone + Send + Sync + Default + std::ops::Add<Output = T>,
{
    let sums = parallel_prefix_sum(data);
    Ok(sums)
}
/// Fallible wrapper around [`parallel_scan`].
///
/// # Errors
///
/// Returns `CoreError::ValueError` when `data` is empty (unlike
/// [`parallel_scan`], which silently returns an empty vector).
pub fn try_parallel_scan<T, F>(data: &[T], identity: T, op: F) -> CoreResult<Vec<T>>
where
    T: Clone + Send + Sync,
    F: Fn(T, T) -> T + Send + Sync + Clone,
{
    if !data.is_empty() {
        return Ok(parallel_scan(data, identity, op));
    }
    let context = ErrorContext::new("parallel_scan requires non-empty input".to_string())
        .with_location(ErrorLocation::new(file!(), line!()));
    Err(CoreError::ValueError(context))
}
/// Inputs shorter than this are scanned sequentially; below this size the
/// tiling/thread-dispatch overhead of the parallel path is not worth paying.
const SEQUENTIAL_THRESHOLD: usize = 1024;
/// Number of elements handled per tile (one parallel task per tile) in the
/// two-pass tiled scan.
const TILE_SIZE: usize = 256;
/// Plain left-to-right inclusive scan: each output element is the running
/// combination of all inputs up to and including that position.
fn sequential_inclusive_scan<T, F>(data: &[T], identity: &T, op: &F) -> Vec<T>
where
    T: Clone,
    F: Fn(T, T) -> T,
{
    data.iter()
        .scan(identity.clone(), |running, item| {
            // Fold the current element in first, then emit the new total.
            *running = op(running.clone(), item.clone());
            Some(running.clone())
        })
        .collect()
}
/// Plain left-to-right exclusive scan: each output element is the running
/// combination of all inputs strictly before that position, starting with
/// `identity` at index 0.
fn sequential_exclusive_scan<T, F>(data: &[T], identity: &T, op: &F) -> Vec<T>
where
    T: Clone,
    F: Fn(T, T) -> T,
{
    data.iter()
        .scan(identity.clone(), |running, item| {
            // Emit the total *before* folding the current element in.
            let emitted = running.clone();
            *running = op(emitted.clone(), item.clone());
            Some(emitted)
        })
        .collect()
}
/// Tiled two-pass (reduce-then-rescan) inclusive scan.
///
/// Phase 1 reduces each `TILE_SIZE` tile in parallel; phase 2 runs a short
/// sequential exclusive scan over the tile totals to obtain each tile's
/// starting prefix; phase 3 rescans every tile in parallel, seeded with its
/// prefix, writing directly into its disjoint output chunk.
///
/// `op` must be associative and `identity` a true identity for the result
/// to match [`sequential_inclusive_scan`].
#[cfg(feature = "parallel")]
fn blelloch_inclusive_scan<T, F>(data: &[T], identity: &T, op: &F) -> Vec<T>
where
    T: Clone + Send + Sync,
    F: Fn(T, T) -> T + Send + Sync + Clone,
{
    use rayon::prelude::*;
    let n = data.len();
    let num_tiles = (n + TILE_SIZE - 1) / TILE_SIZE;

    // Phase 1: per-tile reductions, in parallel.
    let tile_reductions: Vec<T> = (0..num_tiles)
        .into_par_iter()
        .map(|tile_idx| {
            let start = tile_idx * TILE_SIZE;
            let end = (start + TILE_SIZE).min(n);
            let mut acc = identity.clone();
            for item in &data[start..end] {
                acc = op(acc, item.clone());
            }
            acc
        })
        .collect();

    // Phase 2: exclusive scan over the tile totals. This is sequential but
    // cheap — only `num_tiles` elements.
    let mut tile_prefixes: Vec<T> = Vec::with_capacity(num_tiles);
    let mut acc = identity.clone();
    for red in &tile_reductions {
        tile_prefixes.push(acc.clone());
        acc = op(acc, red.clone());
    }

    // Phase 3: rescan each tile seeded with its prefix. `par_chunks_mut`
    // hands each task a disjoint `&mut` chunk, so no `unsafe` is needed;
    // normal assignment drops the placeholder values (the previous
    // `ptr::write` write-back leaked them for `T` with destructors), the
    // write-back itself is parallel instead of serialized on the caller's
    // thread, and a panic inside `op` propagates instead of being silently
    // discarded by an ignored `join()` result.
    let mut result: Vec<T> = vec![identity.clone(); n];
    result
        .par_chunks_mut(TILE_SIZE)
        .zip(data.par_chunks(TILE_SIZE))
        .zip(tile_prefixes.into_par_iter())
        .for_each(|((out_tile, in_tile), prefix)| {
            let mut acc = prefix;
            for (slot, item) in out_tile.iter_mut().zip(in_tile) {
                acc = op(acc, item.clone());
                *slot = acc.clone();
            }
        });
    result
}
/// Tiled two-pass (reduce-then-rescan) exclusive scan.
///
/// Same structure as [`blelloch_inclusive_scan`], except phase 3 emits the
/// running value *before* folding in each element, so `out[i]` combines
/// `data[0..i]` and `out[0]` is `identity`.
///
/// `op` must be associative and `identity` a true identity for the result
/// to match [`sequential_exclusive_scan`].
#[cfg(feature = "parallel")]
fn blelloch_exclusive_scan<T, F>(data: &[T], identity: &T, op: &F) -> Vec<T>
where
    T: Clone + Send + Sync,
    F: Fn(T, T) -> T + Send + Sync + Clone,
{
    use rayon::prelude::*;
    let n = data.len();
    let num_tiles = (n + TILE_SIZE - 1) / TILE_SIZE;

    // Phase 1: per-tile reductions, in parallel.
    let tile_reductions: Vec<T> = (0..num_tiles)
        .into_par_iter()
        .map(|tile_idx| {
            let start = tile_idx * TILE_SIZE;
            let end = (start + TILE_SIZE).min(n);
            let mut acc = identity.clone();
            for item in &data[start..end] {
                acc = op(acc, item.clone());
            }
            acc
        })
        .collect();

    // Phase 2: exclusive scan over the tile totals (sequential, num_tiles
    // elements).
    let mut tile_prefixes: Vec<T> = Vec::with_capacity(num_tiles);
    let mut acc = identity.clone();
    for red in &tile_reductions {
        tile_prefixes.push(acc.clone());
        acc = op(acc, red.clone());
    }

    // Phase 3: exclusive rescan of each tile. Disjoint `&mut` chunks from
    // `par_chunks_mut` replace the previous raw-pointer write-back, which
    // leaked the placeholder values for `T` with destructors (`ptr::write`
    // does not drop the overwritten value) and silently ignored worker
    // panics via a discarded `join()` result.
    let mut result: Vec<T> = vec![identity.clone(); n];
    result
        .par_chunks_mut(TILE_SIZE)
        .zip(data.par_chunks(TILE_SIZE))
        .zip(tile_prefixes.into_par_iter())
        .for_each(|((out_tile, in_tile), prefix)| {
            let mut acc = prefix;
            for (slot, item) in out_tile.iter_mut().zip(in_tile) {
                *slot = acc.clone();
                acc = op(acc, item.clone());
            }
        });
    result
}
/// Fallback used when the `parallel` feature is disabled: without a thread
/// pool there is nothing to tile, so delegate to the sequential scan. Keeps
/// the same signature so `parallel_scan` compiles under either feature set.
#[cfg(not(feature = "parallel"))]
fn blelloch_inclusive_scan<T, F>(data: &[T], identity: &T, op: &F) -> Vec<T>
where
    T: Clone + Send + Sync,
    F: Fn(T, T) -> T + Send + Sync + Clone,
{
    sequential_inclusive_scan(data, identity, op)
}
/// Fallback used when the `parallel` feature is disabled: delegates to the
/// sequential exclusive scan. Keeps the same signature so
/// `parallel_scan_exclusive` compiles under either feature set.
#[cfg(not(feature = "parallel"))]
fn blelloch_exclusive_scan<T, F>(data: &[T], identity: &T, op: &F) -> Vec<T>
where
    T: Clone + Send + Sync,
    F: Fn(T, T) -> T + Send + Sync + Clone,
{
    sequential_exclusive_scan(data, identity, op)
}
/// Monomorphic convenience wrapper over [`parallel_prefix_sum`] for `f64`.
/// NOTE(review): `f64` addition is not associative, so the tiled parallel
/// path may differ from a strict left-to-right sum in the last bits; the
/// in-file tests compare with a tolerance accordingly.
pub fn parallel_prefix_sum_f64(data: &[f64]) -> Vec<f64> {
    parallel_prefix_sum(data)
}
/// Monomorphic convenience wrapper over [`parallel_prefix_sum`] for `i64`.
/// NOTE(review): overflow behavior follows the build profile (panic in
/// debug, wrap in release), same as the generic path.
pub fn parallel_prefix_sum_i64(data: &[i64]) -> Vec<i64> {
    parallel_prefix_sum(data)
}
/// Running minimum: `out[i] = min(data[0..=i])`.
///
/// Empty input yields an empty vector. The first element is used as the
/// scan seed, which is safe because taking the minimum is idempotent —
/// folding `data[0]` in twice does not change any prefix.
pub fn parallel_prefix_min<T>(data: &[T]) -> Vec<T>
where
    T: Clone + Send + Sync + Ord,
{
    match data.first() {
        None => Vec::new(),
        Some(seed) => parallel_scan(data, seed.clone(), |lhs, rhs| {
            if rhs < lhs {
                rhs
            } else {
                lhs
            }
        }),
    }
}
/// Running maximum: `out[i] = max(data[0..=i])`.
///
/// Empty input yields an empty vector. The first element is used as the
/// scan seed, which is safe because taking the maximum is idempotent —
/// folding `data[0]` in twice does not change any prefix.
pub fn parallel_prefix_max<T>(data: &[T]) -> Vec<T>
where
    T: Clone + Send + Sync + Ord,
{
    match data.first() {
        None => Vec::new(),
        Some(seed) => parallel_scan(data, seed.clone(), |lhs, rhs| {
            if rhs > lhs {
                rhs
            } else {
                lhs
            }
        }),
    }
}
/// Prefix sums restarted at each segment boundary.
///
/// A `true` at `flags[i]` resets the running sum to `T::default()` before
/// `data[i]` is added, so each flagged position starts a fresh segment.
/// Only the first `min(data.len(), flags.len())` positions are produced
/// (`zip` truncates to the shorter slice); both slices empty yields an
/// empty vector. This path is sequential.
pub fn segmented_prefix_sum<T>(data: &[T], flags: &[bool]) -> Vec<T>
where
    T: Clone + Send + Sync + Default + std::ops::Add<Output = T>,
{
    let mut out = Vec::with_capacity(data.len().min(flags.len()));
    let mut running = T::default();
    for (value, &is_segment_start) in data.iter().zip(flags.iter()) {
        if is_segment_start {
            running = T::default();
        }
        running = running + value.clone();
        out.push(running.clone());
    }
    out
}
#[cfg(test)]
mod tests {
    use super::*;

    // Empty input must yield empty output for both scan flavors.
    #[test]
    fn test_empty_prefix_sum() {
        let data: Vec<i32> = Vec::new();
        assert!(parallel_prefix_sum(&data).is_empty());
        assert!(parallel_prefix_sum_exclusive(&data).is_empty());
    }

    // One element: inclusive scan is the element itself, exclusive is the identity.
    #[test]
    fn test_single_element() {
        assert_eq!(parallel_prefix_sum(&[42]), vec![42]);
        assert_eq!(parallel_prefix_sum_exclusive(&[42]), vec![0]);
    }

    #[test]
    fn test_small_inclusive_sum() {
        let data = vec![1, 2, 3, 4, 5];
        let result = parallel_prefix_sum(&data);
        assert_eq!(result, vec![1, 3, 6, 10, 15]);
    }

    #[test]
    fn test_small_exclusive_sum() {
        let data = vec![1, 2, 3, 4, 5];
        let result = parallel_prefix_sum_exclusive(&data);
        assert_eq!(result, vec![0, 1, 3, 6, 10]);
    }

    // Generic scan with a non-additive operator (running product, identity 1).
    #[test]
    fn test_generic_scan_multiplication() {
        let data = vec![1, 2, 3, 4, 5];
        let result = parallel_scan(&data, 1, |a, b| a * b);
        assert_eq!(result, vec![1, 2, 6, 24, 120]);
    }

    #[test]
    fn test_generic_scan_max() {
        let data = vec![3, 1, 4, 1, 5, 9, 2, 6];
        let result = parallel_prefix_max(&data);
        assert_eq!(result, vec![3, 3, 4, 4, 5, 9, 9, 9]);
    }

    #[test]
    fn test_generic_scan_min() {
        let data = vec![5, 3, 7, 1, 4, 2, 8, 6];
        let result = parallel_prefix_min(&data);
        assert_eq!(result, vec![5, 3, 3, 1, 1, 1, 1, 1]);
    }

    // n = 5000 exceeds SEQUENTIAL_THRESHOLD, so this exercises the tiled
    // parallel path; the last element is checked against the closed form
    // n*(n+1)/2 and the result must be strictly increasing.
    #[test]
    fn test_large_prefix_sum() {
        let n = 5000;
        let data: Vec<i64> = (1..=n).collect();
        let result = parallel_prefix_sum(&data);
        assert_eq!(result[0], 1);
        assert_eq!(result[n as usize - 1], n * (n + 1) / 2);
        for i in 1..result.len() {
            assert!(result[i] > result[i - 1]);
        }
    }

    // Exclusive counterpart of the test above; last element is the sum of
    // everything *before* it: n*(n-1)/2.
    #[test]
    fn test_large_exclusive_sum() {
        let n = 5000;
        let data: Vec<i64> = (1..=n).collect();
        let result = parallel_prefix_sum_exclusive(&data);
        assert_eq!(result[0], 0);
        assert_eq!(result[1], 1);
        assert_eq!(result[n as usize - 1], n * (n - 1) / 2);
    }

    // `true` flags restart the running sum at that position.
    #[test]
    fn test_segmented_prefix_sum() {
        let data = vec![1, 2, 3, 1, 2, 3];
        let flags = vec![true, false, false, true, false, false];
        let result = segmented_prefix_sum(&data, &flags);
        assert_eq!(result, vec![1, 3, 6, 1, 3, 6]);
    }

    #[test]
    fn test_segmented_prefix_sum_single_segment() {
        let data = vec![1, 2, 3, 4];
        let flags = vec![true, false, false, false];
        let result = segmented_prefix_sum(&data, &flags);
        assert_eq!(result, vec![1, 3, 6, 10]);
    }

    // Every position flagged: each element forms its own one-element segment.
    #[test]
    fn test_segmented_prefix_sum_all_segments() {
        let data = vec![10, 20, 30];
        let flags = vec![true, true, true];
        let result = segmented_prefix_sum(&data, &flags);
        assert_eq!(result, vec![10, 20, 30]);
    }

    // Floating point compared with a tolerance rather than exact equality.
    #[test]
    fn test_f64_prefix_sum() {
        let data = vec![1.0, 2.0, 3.0, 4.0];
        let result = parallel_prefix_sum_f64(&data);
        assert!((result[0] - 1.0).abs() < 1e-10);
        assert!((result[1] - 3.0).abs() < 1e-10);
        assert!((result[2] - 6.0).abs() < 1e-10);
        assert!((result[3] - 10.0).abs() < 1e-10);
    }

    // The fallible API rejects empty input (unlike the plain API, which
    // returns an empty vector).
    #[test]
    fn test_try_parallel_scan_empty_error() {
        let data: Vec<i32> = Vec::new();
        let result = try_parallel_scan(&data, 0, |a, b| a + b);
        assert!(result.is_err());
    }

    #[test]
    fn test_try_parallel_prefix_sum() {
        let data = vec![1, 2, 3];
        let result = try_parallel_prefix_sum(&data).expect("should succeed");
        assert_eq!(result, vec![1, 3, 6]);
    }

    // Invariant linking the two flavors: exclusive[i] + data[i] == inclusive[i].
    #[test]
    fn test_consistency_inclusive_vs_exclusive() {
        let data: Vec<i32> = (1..=100).collect();
        let inclusive = parallel_prefix_sum(&data);
        let exclusive = parallel_prefix_sum_exclusive(&data);
        for i in 0..data.len() {
            assert_eq!(exclusive[i] + data[i], inclusive[i]);
        }
    }

    // Non-Copy element type (String) exercises the clone-heavy code paths.
    #[test]
    fn test_string_concat_scan() {
        let data = vec!["a".to_string(), "b".to_string(), "c".to_string()];
        let result = parallel_scan(&data, String::new(), |mut a, b| {
            a.push_str(&b);
            a
        });
        assert_eq!(result, vec!["a", "ab", "abc"]);
    }

    // The parallel result must agree element-for-element with the plain
    // sequential scan on an input large enough to take the tiled path.
    #[test]
    fn test_large_parallel_correctness() {
        let n = 10_000;
        let data: Vec<i64> = (0..n).collect();
        let par_result = parallel_prefix_sum(&data);
        let seq_result = sequential_inclusive_scan(&data, &0i64, &|a, b| a + b);
        assert_eq!(par_result, seq_result);
    }
}