use num_traits::Float;
#[cfg(feature = "parallel")]
use rayon::prelude::*;
/// Scans the index range `start..end` of the recurrence `h = a[i] * h + b[i]`,
/// starting from a zero state, and tracks the running product of `a` alongside it.
///
/// Returns `(products, states, total_product, final_state)` where:
/// - `products[j]` is the product of `a[start..=start + j]`,
/// - `states[j]` is the zero-initialised recurrence state after element `start + j`,
/// - `total_product` and `final_state` are the last values of those two sequences
///   (`T::one()` and `T::zero()` respectively when the range is empty).
///
/// # Panics
///
/// Panics if `start..end` is not a valid range for both `a` and `b`.
fn compute_chunk<T: Float>(a: &[T], b: &[T], start: usize, end: usize) -> (Vec<T>, Vec<T>, T, T) {
    let coeffs = &a[start..end];
    let inputs = &b[start..end];

    let mut products = Vec::with_capacity(coeffs.len());
    let mut states = Vec::with_capacity(coeffs.len());
    let mut product = T::one();
    let mut state = T::zero();

    for (&coeff, &input) in coeffs.iter().zip(inputs) {
        product = product * coeff;
        state = coeff * state + input;
        products.push(product);
        states.push(state);
    }

    (products, states, product, state)
}
/// Evaluates the linear recurrence `h[i] = a[i] * h[i - 1] + b[i]` (seeded with `h0`)
/// by splitting the input into chunks of at most `chunk_size` elements.
///
/// The work happens in three phases:
/// 1. every chunk is scanned independently from a zero state, also recording the
///    running product of `a` inside the chunk;
/// 2. a short serial pass over the chunks derives the state entering each chunk;
/// 3. every output is assembled as `local_state + running_product * entering_state`.
///
/// Phases 1 and 3 run on rayon's parallel iterators when the `parallel` feature is
/// enabled and on plain iterators otherwise.
///
/// Returns a vector with one state per input element; empty input gives an empty vector.
///
/// # Panics
///
/// Panics if `a` and `b` differ in length or if `chunk_size` is zero.
pub fn chunked_scan<T: Float + Send + Sync>(a: &[T], b: &[T], h0: T, chunk_size: usize) -> Vec<T> {
    assert_eq!(a.len(), b.len(), "a and b must be the same length");
    assert!(chunk_size > 0, "chunk_size must be positive");

    let total = a.len();
    if total == 0 {
        return Vec::new();
    }
    let chunk_count = total.div_ceil(chunk_size);

    // Phase 1: zero-state scan of each chunk; the final chunk may be shorter.
    let scan_one = |idx: usize| {
        let lo = idx * chunk_size;
        let hi = ((idx + 1) * chunk_size).min(total);
        compute_chunk(a, b, lo, hi)
    };
    #[cfg(feature = "parallel")]
    let per_chunk: Vec<(Vec<T>, Vec<T>, T, T)> =
        (0..chunk_count).into_par_iter().map(scan_one).collect();
    #[cfg(not(feature = "parallel"))]
    let per_chunk: Vec<(Vec<T>, Vec<T>, T, T)> = (0..chunk_count).map(scan_one).collect();

    // Phase 2: flatten the per-chunk vectors and propagate the carried state serially.
    let mut decay_products: Vec<T> = Vec::with_capacity(total);
    let mut zero_start_states: Vec<T> = Vec::with_capacity(total);
    let mut entering_states: Vec<T> = Vec::with_capacity(chunk_count);
    let mut carried = h0;
    for (products, states, span_decay, span_state) in per_chunk {
        decay_products.extend(products);
        zero_start_states.extend(states);
        // Record the state flowing *into* this chunk before advancing past it.
        entering_states.push(carried);
        carried = span_decay * carried + span_state;
    }

    // Phase 3: correct each zero-state result with the state that entered its chunk.
    let combine = |i: usize| {
        zero_start_states[i] + decay_products[i] * entering_states[i / chunk_size]
    };
    #[cfg(feature = "parallel")]
    let out: Vec<T> = (0..total).into_par_iter().map(combine).collect();
    #[cfg(not(feature = "parallel"))]
    let out: Vec<T> = (0..total).map(combine).collect();
    out
}
#[cfg(all(test, feature = "parallel"))]
mod parallel_tests {
    use super::*;
    use crate::scan::sequential_scan;

    /// Checks the rayon-backed path against `sequential_scan` for chunk sizes that are
    /// smaller than, equal to, and larger than the input length.
    #[test]
    fn parallel_path_matches_sequential() {
        let n = 500;
        let a: Vec<f64> = (0..n).map(|i| 0.5 + 0.4 * (i as f64 * 0.013).sin()).collect();
        let b: Vec<f64> = (0..n).map(|i| (i as f64 * 0.029).cos()).collect();
        let expected = sequential_scan(&a, &b, 0.0);
        for chunk_size in [1, 7, 32, 128, 500, 999] {
            let actual = chunked_scan(&a, &b, 0.0, chunk_size);
            // `zip` stops at the shorter side, so without this check a truncated
            // output would pass the element-wise comparison below.
            assert_eq!(
                actual.len(),
                expected.len(),
                "chunk_size={chunk_size}: length mismatch"
            );
            for (e, got) in expected.iter().zip(actual.iter()) {
                assert!((e - got).abs() < 1e-9, "chunk_size={chunk_size}: expected {e}, got {got}");
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::scan::sequential_scan;

    /// Absolute tolerance for comparisons where the chunked formulation reassociates
    /// the floating-point arithmetic and may therefore differ from the sequential
    /// result by rounding.
    const TOLERANCE: f64 = 1e-12;

    /// Asserts that both slices have the same length and agree element-wise within
    /// [`TOLERANCE`]. `context` is included in failure messages.
    fn assert_close(actual: &[f64], expected: &[f64], context: &str) {
        assert_eq!(actual.len(), expected.len(), "{context}: length mismatch");
        for (i, (got, want)) in actual.iter().zip(expected).enumerate() {
            assert!(
                (got - want).abs() <= TOLERANCE,
                "{context}: index {i}: expected {want}, got {got}"
            );
        }
    }

    /// With unit coefficients and small integer inputs every intermediate value is
    /// exactly representable, so exact equality is appropriate here.
    #[test]
    fn matches_sequential_cumsum() {
        let a = [1.0, 1.0, 1.0, 1.0];
        let b = [1.0, 2.0, 3.0, 4.0];
        let expected = sequential_scan(&a, &b, 0.0);
        for chunk_size in 1..=5 {
            assert_eq!(chunked_scan(&a, &b, 0.0, chunk_size), expected);
        }
    }

    /// Constant decay with a non-zero initial state, for chunk sizes up to and past
    /// the input length.
    #[test]
    fn matches_sequential_with_decay() {
        let a = [0.5, 0.5, 0.5, 0.5, 0.5];
        let b = [1.0, 1.0, 1.0, 1.0, 1.0];
        let h0 = 0.3;
        let expected = sequential_scan(&a, &b, h0);
        for chunk_size in 1..=6 {
            let actual = chunked_scan(&a, &b, h0, chunk_size);
            assert_close(&actual, &expected, &format!("chunk_size={chunk_size}"));
        }
    }

    /// Input length (7) is not a multiple of the chunk size (3), so the last chunk
    /// is shorter than the others.
    #[test]
    fn handles_uneven_chunks() {
        let a = [0.9; 7];
        let b: [f64; 7] = [1.0, 0.5, 2.0, 0.0, 1.5, 3.0, 0.2];
        let expected = sequential_scan(&a, &b, 1.0);
        let actual = chunked_scan(&a, &b, 1.0, 3);
        assert_close(&actual, &expected, "chunk_size=3");
    }

    /// Empty input yields an empty output regardless of `h0` and `chunk_size`.
    #[test]
    fn empty_input() {
        let a: [f64; 0] = [];
        let b: [f64; 0] = [];
        assert_eq!(chunked_scan(&a, &b, 1.0, 4), Vec::<f64>::new());
    }
}