/// Number of elements folded sequentially (left to right) at each leaf of the
/// pairwise reduction tree. Inputs of at most this many elements are folded
/// directly without any splitting.
pub const BASE_CHUNK: usize = 128;

/// Returns the split point for a range of `len` elements: the largest value of
/// the form `BASE_CHUNK * 2^j` that is strictly less than `len`.
///
/// Splitting at a power-of-two multiple of `BASE_CHUNK` makes the left part a
/// perfect binary tree of full blocks, matching the binary-counter forest that
/// `StreamingPairwise` builds incrementally.
///
/// # Panics
///
/// Panics if `len <= BASE_CHUNK`.
#[inline]
const fn left_split(len: usize) -> usize {
    assert!(
        len > BASE_CHUNK,
        "left_split: caller must guarantee len > BASE_CHUNK"
    );
    // "strictly less than len" == "at most len - 1"; the highest set bit of the
    // number of whole blocks fitting in len - 1 is the largest admissible 2^j.
    // blocks >= 1 here, so the exponent never underflows, and
    // BASE_CHUNK << exponent <= BASE_CHUNK * blocks <= len - 1 cannot overflow.
    let blocks = (len - 1) / BASE_CHUNK;
    let exponent = usize::BITS - 1 - blocks.leading_zeros();
    BASE_CHUNK << exponent
}
/// Folds `items` into `acc` strictly left to right, i.e.
/// `combine(...combine(combine(acc, items[0]), items[1])..., items[n - 1])`.
///
/// Returns `acc` unchanged when `items` is empty.
#[inline]
fn reduce_block<T, F>(acc: T, items: &[T], combine: &F) -> T
where
    T: Copy,
    F: Fn(T, T) -> T,
{
    items
        .iter()
        .fold(acc, |running, &next| combine(running, next))
}
/// Reduces `items` with `combine` in a pairwise (tree-shaped) order.
///
/// Slices of at most [`BASE_CHUNK`] elements are folded left to right. Longer
/// slices are split at the largest power-of-two multiple of `BASE_CHUNK` that
/// is strictly below their length. Both parts are reduced recursively and then
/// joined as `combine(left, right)`.
///
/// `identity` is returned only for an empty slice; it is never combined with
/// any element of a non-empty input.
pub fn pairwise_reduce<T, F>(items: &[T], combine: F, identity: T) -> T
where
    T: Copy,
    F: Fn(T, T) -> T,
{
    let op = &combine;
    reduce_range(items, op, identity)
}
/// Recursive worker behind [`pairwise_reduce`]: reduces `items` in the tree
/// order documented there, returning `identity` for an empty slice.
fn reduce_range<T, F>(items: &[T], combine: &F, identity: T) -> T
where
    T: Copy,
    F: Fn(T, T) -> T,
{
    let (&first, rest) = match items.split_first() {
        Some(split) => split,
        None => return identity,
    };
    if items.len() > BASE_CHUNK {
        // BASE_CHUNK <= split point < len, so both halves are non-empty and
        // the `identity` early return is never reached from recursion.
        let (lo, hi) = items.split_at(left_split(items.len()));
        let lo_value = reduce_range(lo, combine, identity);
        let hi_value = reduce_range(hi, combine, identity);
        combine(lo_value, hi_value)
    } else {
        reduce_block(first, rest, combine)
    }
}
/// Sums `xs` using [`pairwise_reduce`]; an empty slice sums to `0.0`.
pub fn pairwise_sum(xs: &[f64]) -> f64 {
    let add = <f64 as std::ops::Add>::add;
    pairwise_reduce(xs, add, 0.0)
}
/// Incremental pairwise reducer: feed elements with `push` or
/// `extend_from_slice`, then call `finish` to obtain the result.
///
/// The combination tree it builds matches the one `pairwise_reduce` builds for
/// the same elements in the same order, so the results agree bit-for-bit even
/// for non-associative operations such as floating-point addition.
///
/// Trait bounds are placed on the `impl` block rather than on the type, so the
/// type can be named (e.g. in other structs) without repeating them.
pub struct StreamingPairwise<T, F> {
    /// Binary operation applied at every internal node of the tree.
    combine: F,
    /// Returned by `finish` when no element was ever added.
    identity: T,
    /// Staging buffer for the current, not yet complete, leaf block.
    buf: Vec<T>,
    /// Completed subtrees as `(element_count, reduced_value)`, leftmost first.
    /// Equal adjacent counts are merged eagerly, so counts strictly decrease
    /// from left to right (a binary-counter layout).
    forest: Vec<(usize, T)>,
}
impl<T, F> StreamingPairwise<T, F>
where
    T: Copy,
    F: Fn(T, T) -> T,
{
    /// Creates an empty accumulator.
    ///
    /// `combine` is the binary operation; `identity` is returned by
    /// [`finish`](Self::finish) only if nothing was ever added.
    pub fn new(combine: F, identity: T) -> Self {
        Self {
            combine,
            identity,
            buf: Vec::with_capacity(BASE_CHUNK),
            forest: Vec::new(),
        }
    }

    /// Appends one element.
    ///
    /// Elements are staged until `BASE_CHUNK` of them are available. The
    /// staged block is then folded left to right into one value and merged
    /// into the forest of completed subtrees.
    pub fn push(&mut self, x: T) {
        self.buf.push(x);
        if self.buf.len() == BASE_CHUNK {
            self.flush_full_buffer();
        }
    }

    /// Appends every element of `chunk`, in order.
    ///
    /// The resulting state is exactly the same as calling [`push`](Self::push)
    /// once per element. The difference is that whole `BASE_CHUNK`-sized blocks
    /// are reduced straight from `chunk` instead of being copied through the
    /// staging buffer one element at a time.
    pub fn extend_from_slice(&mut self, chunk: &[T]) {
        let mut rest = chunk;

        // Top up a partially filled buffer first, so block boundaries stay
        // aligned with the global element index.
        if !self.buf.is_empty() {
            // Invariant: buf.len() < BASE_CHUNK, so this cannot underflow.
            let take = (BASE_CHUNK - self.buf.len()).min(rest.len());
            let (head, tail) = rest.split_at(take);
            self.buf.extend_from_slice(head);
            rest = tail;
            if self.buf.len() == BASE_CHUNK {
                self.flush_full_buffer();
            }
        }

        // Either the buffer is empty now or `rest` is, so any full blocks left
        // in `rest` start exactly on a block boundary.
        let mut blocks = rest.chunks_exact(BASE_CHUNK);
        for block in &mut blocks {
            let value = reduce_block(block[0], &block[1..], &self.combine);
            self.absorb(BASE_CHUNK, value);
        }
        self.buf.extend_from_slice(blocks.remainder());
    }

    /// Reduces the full staging buffer to one value, clears the buffer, and
    /// merges the value into the forest.
    fn flush_full_buffer(&mut self) {
        let block = reduce_block(self.buf[0], &self.buf[1..], &self.combine);
        self.buf.clear();
        self.absorb(BASE_CHUNK, block);
    }

    /// Merges a completed subtree of `weight` elements into the forest.
    ///
    /// While the rightmost entry has the same weight, it is popped and combined
    /// as the *left* operand, and the weight doubles (a binary-counter carry).
    fn absorb(&mut self, weight: usize, value: T) {
        let mut w = weight;
        let mut v = value;
        while let Some(&(top_w, top_v)) = self.forest.last() {
            if top_w != w {
                break;
            }
            self.forest.pop();
            v = (self.combine)(top_v, v);
            w = w.saturating_mul(2);
        }
        self.forest.push((w, v));
    }

    /// Consumes the accumulator and returns the reduced value.
    ///
    /// Returns `identity` if nothing was added. Otherwise the staged tail (if
    /// any) is folded into one value. Then the forest is combined right to
    /// left, so each entry acts as the left operand of everything after it.
    pub fn finish(self) -> T {
        let Self {
            combine,
            identity,
            buf,
            mut forest,
        } = self;
        if let Some((&first, rest)) = buf.split_first() {
            // The tail is shorter than BASE_CHUNK, so it never merges with a
            // forest entry; its weight is recorded only for uniformity.
            forest.push((buf.len(), reduce_block(first, rest, &combine)));
        }
        forest
            .into_iter()
            .rev()
            .map(|(_, value)| value)
            .reduce(|acc, left| combine(left, acc))
            .unwrap_or(identity)
    }
}
/// Pairwise-reduces the concatenation of `chunks` without materialising it.
///
/// Chunk boundaries do not influence the result: it is the same value that
/// `pairwise_reduce` would produce for all chunks laid end to end.
/// `identity` is returned only when every chunk is empty (or there are none).
pub fn pairwise_reduce_chunked<'a, T, F, I>(chunks: I, combine: F, identity: T) -> T
where
    T: Copy + 'a,
    F: Fn(T, T) -> T,
    I: IntoIterator<Item = &'a [T]>,
{
    chunks
        .into_iter()
        .fold(StreamingPairwise::new(combine, identity), |mut acc, chunk| {
            acc.extend_from_slice(chunk);
            acc
        })
        .finish()
}
/// Sums the concatenation of `chunks` with [`pairwise_reduce_chunked`].
///
/// Returns `0.0` when there are no elements at all.
pub fn pairwise_sum_chunked<'a, I>(chunks: I) -> f64
where
    I: IntoIterator<Item = &'a [f64]>,
{
    let add = <f64 as std::ops::Add>::add;
    pairwise_reduce_chunked(chunks, add, 0.0)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Deliberately non-associative and non-commutative combine:
    /// `f(f(a, b), c) = 961a + 31b + c` but `f(a, f(b, c)) = 31a + 31b + c`,
    /// so any difference in tree shape or operand order changes the result.
    fn shape_sensitive(a: u64, b: u64) -> u64 {
        a.wrapping_mul(31).wrapping_add(b)
    }

    #[test]
    fn left_split_minimal_case() {
        assert_eq!(super::left_split(BASE_CHUNK + 1), BASE_CHUNK);
    }
    #[test]
    fn left_split_at_two_blocks() {
        assert_eq!(super::left_split(2 * BASE_CHUNK), BASE_CHUNK);
    }
    #[test]
    fn left_split_just_above_two_blocks() {
        assert_eq!(super::left_split(2 * BASE_CHUNK + 1), 2 * BASE_CHUNK);
    }
    #[test]
    fn left_split_at_four_blocks() {
        assert_eq!(super::left_split(4 * BASE_CHUNK), 2 * BASE_CHUNK);
    }
    #[test]
    fn left_split_handles_usize_max_without_overflow() {
        // BASE_CHUNK is a power of two, so the largest admissible split below
        // usize::MAX is the top bit.
        assert_eq!(
            super::left_split(usize::MAX),
            1usize << (usize::BITS - 1)
        );
    }
    #[test]
    fn pairwise_reduce_empty_returns_identity() {
        let result = pairwise_reduce::<u64, _>(&[], |a, b| a + b, 99);
        assert_eq!(result, 99);
    }
    #[test]
    fn pairwise_reduce_single_element() {
        assert_eq!(pairwise_reduce(&[42u64], |a, b| a + b, 0), 42);
    }
    #[test]
    fn identity_is_not_folded_into_non_empty_input() {
        assert_eq!(pairwise_reduce(&[5u64], |a, b| a + b, 100), 5);
        let mut acc = StreamingPairwise::new(|a: u64, b: u64| a + b, 100);
        acc.push(5);
        assert_eq!(acc.finish(), 5);
    }
    #[test]
    fn pairwise_reduce_small_sum() {
        let xs = [1u64, 2, 3, 4, 5];
        assert_eq!(pairwise_reduce(&xs, |a, b| a + b, 0), 15);
    }
    #[test]
    fn pairwise_reduce_product() {
        let xs = [2u64, 3, 4, 5];
        assert_eq!(pairwise_reduce(&xs, |a, b| a * b, 1), 120);
    }
    #[test]
    fn pairwise_sum_empty_is_zero() {
        assert_eq!(pairwise_sum(&[]), 0.0);
    }
    #[test]
    fn pairwise_sum_single_element() {
        assert_eq!(pairwise_sum(&[3.5f64]), 3.5);
    }
    #[test]
    fn pairwise_sum_small_slice_exact() {
        assert_eq!(pairwise_sum(&[1.0f64, 2.0, 3.0, 4.0, 5.0]), 15.0);
    }
    #[test]
    fn pairwise_sum_exactly_base_chunk_elements() {
        let xs: Vec<f64> = (1..=BASE_CHUNK as u64).map(|x| x as f64).collect();
        let naive: f64 = (1..=BASE_CHUNK as u64).map(|x| x as f64).sum();
        assert_eq!(pairwise_sum(&xs), naive);
    }
    #[test]
    fn pairwise_sum_one_above_base_chunk_triggers_split() {
        let xs = vec![1.0f64; BASE_CHUNK + 1];
        assert_eq!(pairwise_sum(&xs), (BASE_CHUNK + 1) as f64);
    }
    #[test]
    fn pairwise_sum_two_base_chunks() {
        let xs = vec![1.0f64; 2 * BASE_CHUNK];
        assert_eq!(pairwise_sum(&xs), (2 * BASE_CHUNK) as f64);
    }
    #[test]
    fn streaming_one_at_a_time_matches_whole_slice() {
        let xs: Vec<f64> = (0..300).map(|i| i as f64 * 0.1).collect();
        let expected = pairwise_sum(&xs);
        let mut acc = StreamingPairwise::new(|a: f64, b: f64| a + b, 0.0);
        for &x in &xs {
            acc.push(x);
        }
        assert_eq!(acc.finish().to_bits(), expected.to_bits());
    }
    #[test]
    fn chunked_matches_whole_slice_across_chunk_sizes() {
        let xs: Vec<f64> = (0..500).map(|i| i as f64).collect();
        let expected = pairwise_sum(&xs);
        for chunk_size in [1usize, 7, 64, 128, 129, 200, 499, 500] {
            let chunks: Vec<&[f64]> = xs.chunks(chunk_size).collect();
            let result = pairwise_sum_chunked(chunks);
            assert_eq!(
                result.to_bits(),
                expected.to_bits(),
                "chunk_size={chunk_size}"
            );
        }
    }
    #[test]
    fn pairwise_reduce_chunked_matches_whole_slice() {
        let xs: Vec<u64> = (1..=300).collect();
        let expected = pairwise_reduce(&xs, |a, b| a + b, 0u64);
        let chunks: Vec<&[u64]> = xs.chunks(77).collect();
        let result = pairwise_reduce_chunked(chunks, |a, b| a + b, 0u64);
        assert_eq!(result, expected);
    }
    #[test]
    fn pairwise_sum_chunked_basic() {
        let a = [1.0f64, 2.0, 3.0];
        let b = [4.0f64, 5.0];
        assert_eq!(pairwise_sum_chunked([a.as_ref(), b.as_ref()]), 15.0);
    }
    /// Every feeding strategy must reproduce the exact recursive tree at every
    /// length, including all block boundaries.
    #[test]
    fn streaming_and_chunked_reproduce_recursive_tree_shape() {
        let xs: Vec<u64> = (1..=5 * BASE_CHUNK as u64 + 17).collect();
        for len in 0..=xs.len() {
            let items = &xs[..len];
            let expected = pairwise_reduce(items, shape_sensitive, 0);

            let mut one_by_one = StreamingPairwise::new(shape_sensitive, 0);
            for &x in items {
                one_by_one.push(x);
            }
            assert_eq!(one_by_one.finish(), expected, "push, len={len}");

            let mut mixed = StreamingPairwise::new(shape_sensitive, 0);
            let (head, tail) = items.split_at(len.min(5));
            for &x in head {
                mixed.push(x);
            }
            mixed.extend_from_slice(tail);
            assert_eq!(mixed.finish(), expected, "mixed, len={len}");

            for chunk_size in [
                1usize,
                3,
                BASE_CHUNK - 1,
                BASE_CHUNK,
                BASE_CHUNK + 1,
                2 * BASE_CHUNK + 5,
            ] {
                let result =
                    pairwise_reduce_chunked(items.chunks(chunk_size), shape_sensitive, 0);
                assert_eq!(result, expected, "len={len}, chunk_size={chunk_size}");
            }
        }
    }
}