use rayon::prelude::*;
/// Chunk length used by the `par_*` functions in this file.
///
/// Inputs with at most `BLOCK` elements are handled by the sequential
/// routines; longer inputs are split into chunks of this length (the last one
/// may be shorter). The value is `1 << 12`, i.e. 4096 elements.
pub const BLOCK: usize = 1 << 12;
/// Sums `xs` by pairwise (tree) reduction.
///
/// Neighbouring elements are added two at a time, producing a new level of
/// roughly half the length; this repeats until one value is left. When a level
/// has an odd number of elements, the last one is carried to the next level
/// unchanged. The order of additions depends only on `xs.len()`.
///
/// Returns `0.0` for an empty slice and the element itself for a slice of
/// length one. The input is not modified; a working copy is allocated.
#[must_use]
pub fn tree_sum(xs: &[f64]) -> f64 {
    match xs {
        [] => 0.0,
        [only] => *only,
        _ => {
            let mut level = xs.to_vec();
            // Number of leading entries of `level` that belong to the current
            // tree level; everything past it is stale scratch space.
            let mut live = level.len();
            while live > 1 {
                let half = live / 2;
                // Writing slot `k` only after reading `2k` and `2k + 1` is safe
                // in place because the write index never passes the read index.
                for slot in 0..half {
                    let left = 2 * slot;
                    level[slot] = level[left] + level[left + 1];
                }
                if live % 2 == 1 {
                    // Odd level: promote the unpaired last element as-is.
                    level[half] = level[live - 1];
                    live = half + 1;
                } else {
                    live = half;
                }
            }
            level[0]
        }
    }
}
/// Parallel counterpart of [`tree_sum`].
///
/// Slices of at most [`BLOCK`] elements are forwarded to [`tree_sum`]
/// directly. Longer slices are cut into `BLOCK`-sized chunks, each chunk is
/// reduced with [`tree_sum`] through rayon's parallel chunk iterator, and the
/// collected per-chunk sums are reduced once more with [`tree_sum`].
#[must_use]
pub fn par_tree_sum(xs: &[f64]) -> f64 {
    if xs.len() > BLOCK {
        // One partial sum per chunk, gathered into a `Vec` before the final
        // sequential reduction.
        let partials: Vec<f64> = xs.par_chunks(BLOCK).map(tree_sum).collect();
        tree_sum(&partials)
    } else {
        tree_sum(xs)
    }
}
/// Dot product of `a` and `b`, accumulated with [`tree_sum`].
///
/// The element-wise products are materialised into a temporary vector and
/// then reduced pairwise. Returns `0.0` when both slices are empty.
///
/// # Panics
///
/// Panics if `a` and `b` have different lengths.
#[must_use]
pub fn tree_dot(a: &[f64], b: &[f64]) -> f64 {
    assert_eq!(a.len(), b.len(), "tree_dot: length mismatch");
    let products: Vec<f64> = a.iter().zip(b).map(|(&lhs, &rhs)| lhs * rhs).collect();
    if products.is_empty() {
        0.0
    } else {
        tree_sum(&products)
    }
}
/// Parallel counterpart of [`tree_dot`].
///
/// Inputs of at most [`BLOCK`] elements are forwarded to [`tree_dot`].
/// Longer inputs are split into matching `BLOCK`-sized chunks of `a` and `b`;
/// each chunk pair is reduced with [`tree_dot`] via rayon, and the collected
/// per-chunk results are combined with [`tree_sum`].
///
/// # Panics
///
/// Panics if `a` and `b` have different lengths.
#[must_use]
pub fn par_tree_dot(a: &[f64], b: &[f64]) -> f64 {
    assert_eq!(a.len(), b.len(), "par_tree_dot: length mismatch");
    if a.len() > BLOCK {
        let lhs_chunks = a.par_chunks(BLOCK);
        let rhs_chunks = b.par_chunks(BLOCK);
        // Equal input lengths mean both sides yield the same chunk layout.
        let partials: Vec<f64> = lhs_chunks
            .zip(rhs_chunks)
            .map(|(lhs, rhs)| tree_dot(lhs, rhs))
            .collect();
        tree_sum(&partials)
    } else {
        tree_dot(a, b)
    }
}
/// Sample variance of `xs` (divisor `n - 1`), computed in two passes.
///
/// The first pass obtains the mean as `tree_sum(xs) / n`; the second pass
/// sums the squared deviations from that mean, again with [`tree_sum`], and
/// divides by `n - 1`.
///
/// Returns `0.0` when `xs` has fewer than two elements.
#[must_use]
pub fn tree_var(xs: &[f64]) -> f64 {
    let n = xs.len();
    if n < 2 {
        return 0.0;
    }
    // `usize -> f64` only loses precision for counts above 2^53, which is
    // accepted here.
    #[allow(clippy::cast_precision_loss)]
    let n_f = n as f64;
    let mean = tree_sum(xs) / n_f;
    let centered_sq: Vec<f64> = xs
        .iter()
        .map(|&x| {
            let d = x - mean;
            d * d
        })
        .collect();
    // Same cast consideration as for `n_f` above.
    #[allow(clippy::cast_precision_loss)]
    let denom = (n - 1) as f64;
    // The argument is a borrow of `centered_sq`; the source had this token
    // mangled into `¢ered_sq` by an `&cent` entity decoding, which does not
    // compile.
    tree_sum(&centered_sq) / denom
}
/// Parallel counterpart of [`tree_var`]: sample variance with divisor `n - 1`.
///
/// The mean comes from [`par_tree_sum`], the squared deviations are computed
/// with rayon's `par_iter` and collected into a vector, and that vector is
/// reduced with [`par_tree_sum`] before dividing by `n - 1`.
///
/// Returns `0.0` when `xs` has fewer than two elements.
#[must_use]
pub fn par_tree_var(xs: &[f64]) -> f64 {
    let n = xs.len();
    if n < 2 {
        return 0.0;
    }
    // `usize -> f64` only loses precision for counts above 2^53, which is
    // accepted here.
    #[allow(clippy::cast_precision_loss)]
    let n_f = n as f64;
    let mean = par_tree_sum(xs) / n_f;
    let centered_sq: Vec<f64> = xs
        .par_iter()
        .map(|&x| {
            let d = x - mean;
            d * d
        })
        .collect();
    // Same cast consideration as for `n_f` above.
    #[allow(clippy::cast_precision_loss)]
    let denom = (n - 1) as f64;
    // The argument is a borrow of `centered_sq`; the source had this token
    // mangled into `¢ered_sq` by an `&cent` entity decoding, which does not
    // compile.
    par_tree_sum(&centered_sq) / denom
}
// Unit tests for the tree reductions and their parallel variants.
//
// Lint allowances for this module:
// - `float_cmp`: many assertions compare floats exactly (often via `to_bits`)
//   because the point is to pin the order of additions, not an approximation.
// - `approx_constant`: the literal `3.14` is used as an arbitrary test value.
// - `cast_precision_loss`: test helpers cast small `usize` counts to `f64`.
#[cfg(test)]
#[allow(
clippy::float_cmp,
clippy::approx_constant,
clippy::cast_precision_loss
)]
mod tests {
use super::*;
// Builds `[0.0, 0.5, 1.0, ...]` with `n` elements (element `i` is `i * 0.5`).
fn linear_vec(n: usize) -> Vec<f64> {
(0..n).map(|i| (i as f64) * 0.5).collect()
}
#[test]
fn tree_sum_empty_is_zero() {
assert_eq!(tree_sum(&[]), 0.0);
}
#[test]
fn tree_sum_single_is_passthrough() {
assert_eq!(tree_sum(&[3.14]), 3.14);
}
#[test]
fn tree_sum_pair_is_sum() {
assert_eq!(tree_sum(&[1.0, 2.0]), 3.0);
}
#[test]
fn tree_sum_odd_length_handles_trailing_element() {
assert_eq!(tree_sum(&[1.0, 2.0, 3.0, 4.0, 5.0]), 15.0);
}
// The input is 0.5 * i for i in 0..n, whose exact sum is 0.5 * n * (n - 1) / 2.
#[test]
fn tree_sum_matches_arithmetic_progression_closed_form() {
let n = 1000usize;
let xs = linear_vec(n);
#[allow(clippy::cast_precision_loss)]
let expected = 0.5 * (n as f64) * ((n - 1) as f64) / 2.0;
let got = tree_sum(&xs);
assert!(
(got - expected).abs() < 1e-9,
"tree_sum {got} vs closed form {expected}"
);
}
// 100 elements is below BLOCK, so par_tree_sum takes its sequential branch.
#[test]
fn par_tree_sum_short_input_falls_through_to_tree_sum() {
let xs = linear_vec(100);
assert_eq!(par_tree_sum(&xs), tree_sum(&xs));
}
// Length is deliberately not a multiple of BLOCK (tail of 137 elements), so
// the final chunk is short. Equality is checked on the raw bit patterns.
#[test]
fn par_tree_sum_long_input_matches_tree_sum_bitwise() {
let xs = linear_vec(BLOCK * 16 + 137);
let p = par_tree_sum(&xs);
let s = tree_sum(&xs);
assert!(
p.to_bits() == s.to_bits(),
"par_tree_sum != tree_sum bitwise"
);
}
// Two runs over the same input must give the same bits.
#[test]
fn par_tree_sum_is_self_consistent_across_reruns() {
let xs = linear_vec(BLOCK * 8);
let a = par_tree_sum(&xs);
let b = par_tree_sum(&xs);
assert!(a.to_bits() == b.to_bits());
}
// 1*2 + 2*3 + 3*4 + 4*5 = 40.
#[test]
fn tree_dot_matches_naive_dot_for_short_inputs() {
let a = vec![1.0, 2.0, 3.0, 4.0];
let b = vec![2.0, 3.0, 4.0, 5.0];
assert_eq!(tree_dot(&a, &b), 40.0);
}
#[test]
#[should_panic(expected = "length mismatch")]
fn tree_dot_length_mismatch_panics() {
let a = vec![1.0, 2.0];
let b = vec![1.0, 2.0, 3.0];
let _ = tree_dot(&a, &b);
}
// `b` is `a` reversed, so the two operands differ element by element.
#[test]
fn par_tree_dot_long_input_matches_tree_dot_bitwise() {
let a = linear_vec(BLOCK * 4 + 11);
let b: Vec<f64> = a.iter().rev().copied().collect();
let p = par_tree_dot(&a, &b);
let s = tree_dot(&a, &b);
assert!(p.to_bits() == s.to_bits());
}
#[test]
fn tree_var_constant_input_is_zero() {
let xs = vec![3.0; 1000];
let v = tree_var(&xs);
assert!(v.abs() < 1e-12, "constant input variance {v} not ~0");
}
// Mean is 3; squared deviations 4 + 1 + 0 + 1 + 4 = 10; 10 / (5 - 1) = 2.5.
#[test]
fn tree_var_matches_unbiased_formula_for_simple_input() {
let xs = vec![1.0, 2.0, 3.0, 4.0, 5.0];
let v = tree_var(&xs);
assert!((v - 2.5).abs() < 1e-12, "got {v}");
}
#[test]
fn par_tree_var_matches_tree_var_bitwise_long_input() {
let xs = linear_vec(BLOCK * 4 + 11);
let p = par_tree_var(&xs);
let s = tree_var(&xs);
assert!(p.to_bits() == s.to_bits());
}
#[test]
fn tree_var_short_input_returns_zero_below_n_two() {
assert_eq!(tree_var(&[]), 0.0);
assert_eq!(tree_var(&[42.0]), 0.0);
}
#[test]
fn tree_sum_two_elements_uses_single_addition() {
let a = 1.234_567_8_f64;
let b = 9.876_543_2_f64;
assert_eq!(tree_sum(&[a, b]).to_bits(), (a + b).to_bits());
}
// Pins the association order for four elements: (a + b) + (c + d).
#[test]
fn tree_sum_four_elements_pairs_then_pairs() {
let a = 1.0;
let b = 2.0;
let c = 3.0;
let d = 4.0;
let expected = (a + b) + (c + d);
assert_eq!(tree_sum(&[a, b, c, d]).to_bits(), expected.to_bits());
}
// One huge value followed by 2^20 ones. In a left-to-right fold every
// `+ 1.0` is smaller than the spacing of f64 values near 1e16 and is rounded
// away, so the naive sum stays at 1e16. The pairwise tree first merges the
// ones into larger partial sums, which survive the final additions.
// NOTE(review): despite the test name, the effect exercised here is
// absorption of small addends rather than cancellation.
#[test]
fn tree_sum_distinguishes_naive_left_fold_under_catastrophic_cancellation() {
let mut xs = Vec::with_capacity(1 + (1 << 20));
xs.push(1e16);
xs.extend(std::iter::repeat_n(1.0_f64, 1 << 20));
let naive: f64 = xs.iter().copied().sum();
let tree = tree_sum(&xs);
assert!(
tree > naive,
"tree_sum {tree} should exceed naive sum {naive} on this input"
);
let exact = 1e16 + ((1u64 << 20) as f64);
assert_eq!(tree.to_bits(), exact.to_bits());
}
#[test]
fn tree_sum_zero_vector_is_zero() {
let xs = vec![0.0_f64; 1024];
assert_eq!(tree_sum(&xs), 0.0);
}
#[test]
fn tree_sum_negative_values_sum_correctly() {
let xs = vec![-1.0, -2.0, -3.0, -4.0];
assert_eq!(tree_sum(&xs), -10.0);
}
#[test]
fn tree_sum_alternating_signs_cancels() {
let xs = vec![1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0];
assert_eq!(tree_sum(&xs), 0.0);
}
// Lengths straddle BLOCK (4096) and 2 * BLOCK, covering both branches of
// par_tree_sum and chunk layouts with a one-element or nearly full tail.
#[test]
fn tree_sum_handles_lengths_through_block_boundary() {
for n in [
1, 2, 3, 7, 16, 100, 1023, 1024, 4095, 4096, 4097, 8191, 8192, 8193,
] {
let xs = linear_vec(n);
assert_eq!(
par_tree_sum(&xs).to_bits(),
tree_sum(&xs).to_bits(),
"par_tree_sum != tree_sum at N={n}"
);
}
}
// Scaling is only checked to a relative tolerance, not bitwise.
#[test]
fn tree_sum_scaled_input_scales_output_for_finite_factor() {
let xs = linear_vec(2048);
let c = 7.5_f64;
let scaled: Vec<f64> = xs.iter().map(|x| c * x).collect();
let lhs = tree_sum(&scaled);
let rhs = c * tree_sum(&xs);
assert!(
(lhs - rhs).abs() <= 1e-9 * lhs.abs().max(rhs.abs()),
"tree_sum scaling: lhs={lhs} rhs={rhs}"
);
}
// For a power-of-two length, the last addition of the tree combines the
// sums of the two halves, so summing the halves separately and adding them
// must reproduce the full result bit for bit.
#[test]
fn tree_sum_concat_equals_sum_of_subsums_for_powers_of_two() {
for k in [4, 8, 12] {
let n = 1usize << k;
let xs = linear_vec(n);
let half = n / 2;
let lhs = tree_sum(&xs[..half]) + tree_sum(&xs[half..]);
let rhs = tree_sum(&xs);
assert_eq!(
lhs.to_bits(),
rhs.to_bits(),
"split-half identity at N=2^{k}"
);
}
}
#[test]
fn tree_dot_with_zero_vector_is_zero() {
let a = linear_vec(1024);
let zero = vec![0.0_f64; 1024];
assert_eq!(tree_dot(&a, &zero), 0.0);
}
// Swapping the operands swaps the factors of each product only; the
// reduction order over the products is unchanged.
#[test]
fn tree_dot_is_commutative_bitwise() {
let a = linear_vec(4096);
let b: Vec<f64> = a.iter().rev().copied().collect();
assert_eq!(tree_dot(&a, &b).to_bits(), tree_dot(&b, &a).to_bits());
}
#[test]
fn tree_dot_scales_with_either_argument() {
let a = linear_vec(2048);
let b = linear_vec(2048);
let c = 3.0_f64;
let scaled_a: Vec<f64> = a.iter().map(|x| c * x).collect();
let lhs = tree_dot(&scaled_a, &b);
let rhs = c * tree_dot(&a, &b);
assert!((lhs - rhs).abs() <= 1e-9 * lhs.abs().max(rhs.abs()));
}
#[test]
fn tree_dot_of_empty_pair_is_zero() {
assert_eq!(tree_dot(&[], &[]), 0.0);
}
#[test]
fn tree_dot_of_unit_vectors_is_inner_product() {
let a = vec![3.0];
let b = vec![4.0];
assert_eq!(tree_dot(&a, &b), 12.0);
}
// tree_dot(a, a) forms the same products as the explicit squares below and
// reduces them with tree_sum, so the results must agree bitwise.
#[test]
fn tree_dot_squares_match_tree_sum_of_squares() {
let a = linear_vec(1024);
let squares: Vec<f64> = a.iter().map(|x| x * x).collect();
let lhs = tree_dot(&a, &a);
let rhs = tree_sum(&squares);
assert_eq!(lhs.to_bits(), rhs.to_bits());
}
// Shifting every element by a constant should leave the variance unchanged
// up to rounding; checked with a relative tolerance of 1e-7.
#[test]
fn tree_var_is_translation_invariant() {
let xs = linear_vec(2048);
let c = 100.0_f64;
let shifted: Vec<f64> = xs.iter().map(|x| x + c).collect();
let v_xs = tree_var(&xs);
let v_shift = tree_var(&shifted);
assert!(
(v_xs - v_shift).abs() <= 1e-7 * v_xs.abs().max(v_shift.abs()),
"translation invariance: var(xs)={v_xs} var(xs+c)={v_shift}"
);
}
// var(c * xs) should equal c^2 * var(xs) up to rounding.
#[test]
fn tree_var_scales_quadratically() {
let xs = linear_vec(2048);
let c = 4.0_f64;
let scaled: Vec<f64> = xs.iter().map(|x| c * x).collect();
let lhs = tree_var(&scaled);
let rhs = c * c * tree_var(&xs);
assert!(
(lhs - rhs).abs() <= 1e-9 * lhs.abs().max(rhs.abs()),
"var quadratic scaling: lhs={lhs} rhs={rhs}"
);
}
// For [1, 2]: mean 1.5, squared deviations 0.25 + 0.25 = 0.5, divided by
// n - 1 = 1 gives 0.5 (dividing by n would give 0.25).
#[test]
fn tree_var_uses_bessel_correction_n_minus_one() {
let xs = vec![1.0, 2.0];
let v = tree_var(&xs);
assert!((v - 0.5).abs() < 1e-12, "Bessel-corrected var; got {v}");
}
#[test]
fn tree_var_of_two_equal_values_is_zero() {
let xs = vec![5.0, 5.0];
assert_eq!(tree_var(&xs), 0.0);
}
// Mean 0; squared deviations 1 + 0 + 1 = 2; 2 / (3 - 1) = 1.
#[test]
fn tree_var_of_centered_dataset_matches_naive_unbiased() {
let xs = vec![-1.0, 0.0, 1.0];
let v = tree_var(&xs);
assert!((v - 1.0).abs() < 1e-12, "got {v}");
}
// Reference value is computed with plain iterator sums and compared with a
// relative tolerance.
#[test]
fn tree_var_matches_naive_var_for_uniform_grid() {
let n = 100usize;
let xs: Vec<f64> = (0..n).map(|i| i as f64).collect();
let mean: f64 = xs.iter().copied().sum::<f64>() / n as f64;
let naive_var: f64 = xs.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / (n - 1) as f64;
let v = tree_var(&xs);
assert!(
(v - naive_var).abs() <= 1e-9 * v.abs().max(naive_var.abs()),
"tree_var {v} vs naive {naive_var}"
);
}
#[test]
fn par_tree_sum_matches_tree_sum_for_zero_vector() {
let xs = vec![0.0_f64; BLOCK * 4];
assert_eq!(par_tree_sum(&xs).to_bits(), tree_sum(&xs).to_bits());
}
#[test]
fn par_tree_sum_matches_tree_sum_for_negative_values() {
let xs: Vec<f64> = (0..BLOCK * 3).map(|i| -(i as f64) * 0.25).collect();
assert_eq!(par_tree_sum(&xs).to_bits(), tree_sum(&xs).to_bits());
}
// Same boundary sweep as for par_tree_sum, applied to the dot product.
#[test]
fn par_tree_dot_handles_lengths_through_block_boundary() {
for n in [1, 2, 1023, 1024, 4095, 4096, 4097, 8193] {
let a = linear_vec(n);
let b: Vec<f64> = a.iter().map(|x| x + 1.0).collect();
assert_eq!(
par_tree_dot(&a, &b).to_bits(),
tree_dot(&a, &b).to_bits(),
"par_tree_dot != tree_dot at N={n}"
);
}
}
// Starts at n = 2 because both variance functions return 0.0 below that.
#[test]
fn par_tree_var_handles_lengths_through_block_boundary() {
for n in [2, 100, 1024, 4096, 4097, 8192] {
let xs = linear_vec(n);
assert_eq!(
par_tree_var(&xs).to_bits(),
tree_var(&xs).to_bits(),
"par_tree_var != tree_var at N={n}"
);
}
}
// Pins the association order for three elements: (a + b) + c. The values
// are chosen so that a different grouping would round differently.
#[test]
fn tree_sum_of_three_elements_pairs_first_two_then_adds_third() {
let a = 1e10_f64;
let b = 1e-10_f64;
let c = -1e10_f64;
let expected = (a + b) + c;
assert_eq!(tree_sum(&[a, b, c]).to_bits(), expected.to_bits());
}
// `f64::MIN_POSITIVE / 2.0` is below the smallest normal f64, i.e. subnormal.
#[test]
fn tree_sum_handles_subnormals() {
let small = f64::MIN_POSITIVE / 2.0; let xs = vec![small; 8];
let s = tree_sum(&xs);
assert!(s > 0.0);
assert_eq!(s.to_bits(), (8.0 * small).to_bits());
}
// Exactly BLOCK elements: still handled by the sequential branch.
#[test]
fn par_tree_sum_one_block_exact() {
let xs = linear_vec(BLOCK);
assert_eq!(par_tree_sum(&xs).to_bits(), tree_sum(&xs).to_bits());
}
// Eight full chunks plus a single trailing element.
#[test]
fn par_tree_sum_many_blocks_with_tail() {
let xs = linear_vec(BLOCK * 8 + 1);
assert_eq!(par_tree_sum(&xs).to_bits(), tree_sum(&xs).to_bits());
}
#[test]
fn par_tree_dot_with_self_equals_sum_of_squares() {
let a = linear_vec(BLOCK * 3 + 7);
let p = par_tree_dot(&a, &a);
let s_seq = tree_dot(&a, &a);
assert_eq!(p.to_bits(), s_seq.to_bits());
}
// a . (b + c) versus a . b + a . c, compared with a relative tolerance
// because the two sides round differently.
#[test]
fn tree_dot_distributes_over_addition_within_fp() {
let a = linear_vec(1024);
let b = linear_vec(1024);
let c: Vec<f64> = (0..1024_i32).map(|i| f64::from(i) * 0.1).collect();
let bc: Vec<f64> = b.iter().zip(c.iter()).map(|(x, y)| x + y).collect();
let lhs = tree_dot(&a, &bc);
let rhs = tree_dot(&a, &b) + tree_dot(&a, &c);
assert!(
(lhs - rhs).abs() <= 1e-9 * lhs.abs().max(rhs.abs()),
"lhs={lhs} rhs={rhs}"
);
}
}