use crate::accumulator::{binned_sum_f64, BinnedAccumulatorF64};
use cjc_repro::kahan_sum_f64;
/// Execution mode carried by a [`ReductionContext`].
///
/// [`select_strategy`] inspects this value first: a `Parallel` context always
/// resolves to [`SumStrategy::Binned`], regardless of the other fields.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ExecMode {
    /// Serial execution; on its own this does not force a particular strategy.
    Serial,
    /// Parallel execution; always routed to the binned accumulator.
    Parallel,
}
/// Reproducibility level carried by a [`ReductionContext`].
///
/// Within this module only `Strict` changes the outcome of
/// [`select_strategy`]; `Off` and `On` are treated identically there.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ReproMode {
    /// Reproducibility disabled.
    Off,
    /// Reproducibility enabled; used by the serial default constructors.
    On,
    /// Strict reproducibility; always routed to the binned accumulator.
    Strict,
}
/// Describes the circumstances under which a floating-point reduction runs.
///
/// [`select_strategy`] reads all four fields to choose between
/// [`SumStrategy::Kahan`] and [`SumStrategy::Binned`].
///
/// The type is a small `Copy` value and compares by value (`PartialEq`/`Eq`),
/// matching the enums it is built from, so contexts can be checked with `==`
/// or `assert_eq!`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ReductionContext {
    /// Execution mode; `Parallel` forces the binned strategy.
    pub exec_mode: ExecMode,
    /// Reproducibility level; `Strict` forces the binned strategy.
    pub repro_mode: ReproMode,
    /// When `true`, forces the binned strategy.
    /// NOTE(review): presumably marks code running inside a no-GC region — confirm.
    pub in_nogc: bool,
    /// When `true`, forces the binned strategy.
    /// NOTE(review): presumably set for linear-algebra kernels — confirm.
    pub is_linalg: bool,
}
impl ReductionContext {
    /// Serial execution with reproducibility `On` and both flags cleared.
    ///
    /// This is the only preset for which [`select_strategy`] picks
    /// [`SumStrategy::Kahan`].
    pub fn default_serial() -> Self {
        Self {
            exec_mode: ExecMode::Serial,
            repro_mode: ReproMode::On,
            in_nogc: false,
            is_linalg: false,
        }
    }

    /// Serial execution with `Strict` reproducibility and `in_nogc` set.
    pub fn nogc() -> Self {
        Self {
            repro_mode: ReproMode::Strict,
            in_nogc: true,
            ..Self::default_serial()
        }
    }

    /// Same as [`ReductionContext::default_serial`] but with `is_linalg` set.
    pub fn linalg() -> Self {
        Self {
            is_linalg: true,
            ..Self::default_serial()
        }
    }

    /// Parallel execution with `Strict` reproducibility and both flags cleared.
    pub fn strict_parallel() -> Self {
        Self {
            exec_mode: ExecMode::Parallel,
            repro_mode: ReproMode::Strict,
            ..Self::default_serial()
        }
    }
}
/// Summation algorithm chosen by [`select_strategy`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SumStrategy {
    /// Dispatches to `kahan_sum_f64`.
    Kahan,
    /// Dispatches to `binned_sum_f64` / `BinnedAccumulatorF64`.
    Binned,
}
/// Picks the summation strategy for the given reduction context.
///
/// Returns [`SumStrategy::Binned`] when any of the following holds:
/// the context is parallel, `in_nogc` is set, reproducibility is `Strict`,
/// or `is_linalg` is set. Otherwise returns [`SumStrategy::Kahan`].
pub fn select_strategy(ctx: &ReductionContext) -> SumStrategy {
    let needs_binned = matches!(ctx.exec_mode, ExecMode::Parallel)
        || ctx.in_nogc
        || matches!(ctx.repro_mode, ReproMode::Strict)
        || ctx.is_linalg;

    if needs_binned {
        SumStrategy::Binned
    } else {
        SumStrategy::Kahan
    }
}
/// Sums `values` using whichever algorithm [`select_strategy`] picks for `ctx`.
///
/// `Kahan` contexts go through `kahan_sum_f64`; `Binned` contexts go through
/// `binned_sum_f64`. The slice is passed on untouched in both cases.
#[inline]
pub fn dispatch_sum_f64(values: &[f64], ctx: &ReductionContext) -> f64 {
    let strategy = select_strategy(ctx);
    match strategy {
        SumStrategy::Binned => binned_sum_f64(values),
        SumStrategy::Kahan => kahan_sum_f64(values),
    }
}
/// Computes the dot product of `a` and `b` with the strategy selected for `ctx`.
///
/// Element-wise products are formed in index order and then reduced:
/// * `Kahan`: the products are collected into a temporary `Vec` and handed to
///   `kahan_sum_f64`.
/// * `Binned`: each product is fed to a fresh `BinnedAccumulatorF64`, whose
///   `finalize()` result is returned.
///
/// The length equality check is a `debug_assert_eq!`, so it only fires in
/// builds with debug assertions enabled; otherwise `zip` stops at the shorter
/// slice.
#[inline]
pub fn dispatch_dot_f64(a: &[f64], b: &[f64], ctx: &ReductionContext) -> f64 {
    debug_assert_eq!(a.len(), b.len());

    // Lazy stream of pairwise products, consumed by exactly one arm below.
    let pairwise = a.iter().zip(b).map(|(lhs, rhs)| lhs * rhs);

    match select_strategy(ctx) {
        SumStrategy::Binned => {
            let mut accumulator = BinnedAccumulatorF64::new();
            for term in pairwise {
                accumulator.add(term);
            }
            accumulator.finalize()
        }
        SumStrategy::Kahan => {
            let terms: Vec<f64> = pairwise.collect();
            kahan_sum_f64(&terms)
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a context with the given modes and both boolean flags cleared.
    fn plain_ctx(exec_mode: ExecMode, repro_mode: ReproMode) -> ReductionContext {
        ReductionContext {
            exec_mode,
            repro_mode,
            in_nogc: false,
            is_linalg: false,
        }
    }

    // --- Strategy selection -------------------------------------------------

    #[test]
    fn test_serial_on_uses_kahan() {
        let picked = select_strategy(&ReductionContext::default_serial());
        assert_eq!(picked, SumStrategy::Kahan);
    }

    #[test]
    fn test_parallel_uses_binned() {
        let context = plain_ctx(ExecMode::Parallel, ReproMode::On);
        assert_eq!(select_strategy(&context), SumStrategy::Binned);
    }

    #[test]
    fn test_nogc_uses_binned() {
        let picked = select_strategy(&ReductionContext::nogc());
        assert_eq!(picked, SumStrategy::Binned);
    }

    #[test]
    fn test_strict_uses_binned() {
        let context = plain_ctx(ExecMode::Serial, ReproMode::Strict);
        assert_eq!(select_strategy(&context), SumStrategy::Binned);
    }

    #[test]
    fn test_linalg_uses_binned() {
        let picked = select_strategy(&ReductionContext::linalg());
        assert_eq!(picked, SumStrategy::Binned);
    }

    #[test]
    fn test_off_serial_uses_kahan() {
        let context = plain_ctx(ExecMode::Serial, ReproMode::Off);
        assert_eq!(select_strategy(&context), SumStrategy::Kahan);
    }

    // --- Sum dispatch -------------------------------------------------------

    #[test]
    fn test_dispatch_sum_kahan() {
        let context = ReductionContext::default_serial();
        let input = [1.0, 2.0, 3.0, 4.0, 5.0];
        assert_eq!(dispatch_sum_f64(&input, &context), 15.0);
    }

    #[test]
    fn test_dispatch_sum_binned() {
        let context = ReductionContext::strict_parallel();
        let input = [1.0, 2.0, 3.0, 4.0, 5.0];
        assert_eq!(dispatch_sum_f64(&input, &context), 15.0);
    }

    // --- Dot-product dispatch -----------------------------------------------

    #[test]
    fn test_dispatch_dot_kahan() {
        let context = ReductionContext::default_serial();
        let lhs = [1.0, 2.0, 3.0];
        let rhs = [4.0, 5.0, 6.0];
        assert_eq!(dispatch_dot_f64(&lhs, &rhs, &context), 32.0);
    }

    #[test]
    fn test_dispatch_dot_binned() {
        let context = ReductionContext::strict_parallel();
        let lhs = [1.0, 2.0, 3.0];
        let rhs = [4.0, 5.0, 6.0];
        assert_eq!(dispatch_dot_f64(&lhs, &rhs, &context), 32.0);
    }

    // --- Cross-strategy agreement -------------------------------------------

    #[test]
    fn test_dispatch_strategies_agree_on_simple() {
        // 1 + 2 + ... + 100 is exactly representable, so both paths must hit 5050.
        let input: Vec<f64> = (1..=100).map(f64::from).collect();
        let via_kahan = dispatch_sum_f64(&input, &ReductionContext::default_serial());
        let via_binned = dispatch_sum_f64(&input, &ReductionContext::strict_parallel());
        assert_eq!(via_kahan, 5050.0);
        assert_eq!(via_binned, 5050.0);
    }
}