// cjc_runtime/dispatch.rs

//! Hybrid Summation Dispatch — routes reductions to the appropriate accumulator.
//!
//! # Dispatch Rules
//!
//! | Condition                          | Strategy              |
//! |------------------------------------|-----------------------|
//! | `ExecMode::Parallel`               | BinnedAccumulator     |
//! | `@nogc` context                    | BinnedAccumulator     |
//! | `ReproMode::Strict`                | BinnedAccumulator     |
//! | Reduction inside `LinalgOp`        | BinnedAccumulator     |
//! | Serial + `ReproMode::On`           | Kahan Summation       |
//! | Serial + no vectorization          | Kahan Summation       |
//! | Not forced strict                  | Kahan Summation       |
//!
//! The rows are evaluated top to bottom: a Kahan row applies only when none of
//! the BinnedAccumulator conditions above it holds.
//!
//! The dispatch path is deterministic and unit-tested.
16
17use crate::accumulator::{binned_sum_f64, BinnedAccumulatorF64};
18use cjc_repro::kahan_sum_f64;
19
20// ---------------------------------------------------------------------------
21// Execution context for dispatch decisions
22// ---------------------------------------------------------------------------
23
/// How the current reduction is being executed.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ExecMode {
    /// Single-threaded execution.
    Serial,
    /// Multi-threaded (parallel) execution.
    Parallel,
}
32
/// How strongly a reduction is required to be reproducible.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ReproMode {
    /// No reproducibility requirement — intended as the fastest path.
    Off,
    /// Reproducible: Kahan for serial work, Binned for parallel work.
    On,
    /// Strictly reproducible: always Binned, whatever the execution mode.
    Strict,
}
43
44/// Reduction context passed to the dispatch logic.
45#[derive(Debug, Clone, Copy)]
46pub struct ReductionContext {
47    /// Current execution mode.
48    pub exec_mode: ExecMode,
49    /// Reproducibility mode.
50    pub repro_mode: ReproMode,
51    /// Whether we are inside a @nogc function.
52    pub in_nogc: bool,
53    /// Whether this is a linalg operation (matmul, etc.).
54    pub is_linalg: bool,
55}
56
57impl ReductionContext {
58    /// Default context: serial, repro on, not in nogc, not linalg.
59    pub fn default_serial() -> Self {
60        ReductionContext {
61            exec_mode: ExecMode::Serial,
62            repro_mode: ReproMode::On,
63            in_nogc: false,
64            is_linalg: false,
65        }
66    }
67
68    /// Context for @nogc zones.
69    pub fn nogc() -> Self {
70        ReductionContext {
71            exec_mode: ExecMode::Serial,
72            repro_mode: ReproMode::Strict,
73            in_nogc: true,
74            is_linalg: false,
75        }
76    }
77
78    /// Context for linalg operations.
79    pub fn linalg() -> Self {
80        ReductionContext {
81            exec_mode: ExecMode::Serial,
82            repro_mode: ReproMode::On,
83            in_nogc: false,
84            is_linalg: true,
85        }
86    }
87
88    /// Context for strict parallel.
89    pub fn strict_parallel() -> Self {
90        ReductionContext {
91            exec_mode: ExecMode::Parallel,
92            repro_mode: ReproMode::Strict,
93            in_nogc: false,
94            is_linalg: false,
95        }
96    }
97}
98
99// ---------------------------------------------------------------------------
100// Strategy selection
101// ---------------------------------------------------------------------------
102
/// The summation algorithm chosen for a reduction.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SumStrategy {
    /// Kahan compensated summation (serial; result depends on element order).
    Kahan,
    /// Binned superaccumulator (order-invariant, deterministic).
    Binned,
}
111
112/// Determine the appropriate summation strategy for the given context.
113///
114/// # Rules (in priority order)
115///
116/// 1. `ExecMode::Parallel` → Binned
117/// 2. `@nogc` context → Binned
118/// 3. `ReproMode::Strict` → Binned
119/// 4. Linalg operation → Binned
120/// 5. Otherwise → Kahan
121pub fn select_strategy(ctx: &ReductionContext) -> SumStrategy {
122    if ctx.exec_mode == ExecMode::Parallel {
123        return SumStrategy::Binned;
124    }
125    if ctx.in_nogc {
126        return SumStrategy::Binned;
127    }
128    if ctx.repro_mode == ReproMode::Strict {
129        return SumStrategy::Binned;
130    }
131    if ctx.is_linalg {
132        return SumStrategy::Binned;
133    }
134    SumStrategy::Kahan
135}
136
137// ---------------------------------------------------------------------------
138// Dispatched summation functions
139// ---------------------------------------------------------------------------
140
141/// Sum f64 values using the strategy appropriate for the given context.
142///
143/// This is the primary entry point for all reductions in the CJC runtime.
144#[inline]
145pub fn dispatch_sum_f64(values: &[f64], ctx: &ReductionContext) -> f64 {
146    match select_strategy(ctx) {
147        SumStrategy::Kahan => kahan_sum_f64(values),
148        SumStrategy::Binned => binned_sum_f64(values),
149    }
150}
151
152/// Dot product of two equal-length f64 slices using dispatched summation.
153///
154/// Computes element-wise products, then sums with the selected strategy.
155/// For Binned strategy, the products Vec is collected on the stack (via Vec)
156/// before passing to the accumulator. This is acceptable because the Vec
157/// is for the products array, not the accumulator itself.
158#[inline]
159pub fn dispatch_dot_f64(a: &[f64], b: &[f64], ctx: &ReductionContext) -> f64 {
160    debug_assert_eq!(a.len(), b.len());
161    match select_strategy(ctx) {
162        SumStrategy::Kahan => {
163            let products: Vec<f64> = a.iter().zip(b.iter()).map(|(&x, &y)| x * y).collect();
164            kahan_sum_f64(&products)
165        }
166        SumStrategy::Binned => {
167            let mut acc = BinnedAccumulatorF64::new();
168            for (&x, &y) in a.iter().zip(b.iter()) {
169                acc.add(x * y);
170            }
171            acc.finalize()
172        }
173    }
174}
175
176// ---------------------------------------------------------------------------
177// Inline tests
178// ---------------------------------------------------------------------------
179
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a context from explicit field values so each test spells out
    /// exactly which flags it exercises.
    fn ctx_with(
        exec_mode: ExecMode,
        repro_mode: ReproMode,
        in_nogc: bool,
        is_linalg: bool,
    ) -> ReductionContext {
        ReductionContext {
            exec_mode,
            repro_mode,
            in_nogc,
            is_linalg,
        }
    }

    // --- strategy selection -------------------------------------------------

    #[test]
    fn test_serial_on_uses_kahan() {
        let strategy = select_strategy(&ReductionContext::default_serial());
        assert_eq!(strategy, SumStrategy::Kahan);
    }

    #[test]
    fn test_parallel_uses_binned() {
        let parallel = ctx_with(ExecMode::Parallel, ReproMode::On, false, false);
        assert_eq!(select_strategy(&parallel), SumStrategy::Binned);
    }

    #[test]
    fn test_nogc_uses_binned() {
        let strategy = select_strategy(&ReductionContext::nogc());
        assert_eq!(strategy, SumStrategy::Binned);
    }

    #[test]
    fn test_strict_uses_binned() {
        let strict = ctx_with(ExecMode::Serial, ReproMode::Strict, false, false);
        assert_eq!(select_strategy(&strict), SumStrategy::Binned);
    }

    #[test]
    fn test_linalg_uses_binned() {
        let strategy = select_strategy(&ReductionContext::linalg());
        assert_eq!(strategy, SumStrategy::Binned);
    }

    #[test]
    fn test_off_serial_uses_kahan() {
        let relaxed = ctx_with(ExecMode::Serial, ReproMode::Off, false, false);
        assert_eq!(select_strategy(&relaxed), SumStrategy::Kahan);
    }

    // --- dispatched sum -----------------------------------------------------

    #[test]
    fn test_dispatch_sum_kahan() {
        let input = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let total = dispatch_sum_f64(&input, &ReductionContext::default_serial());
        assert_eq!(total, 15.0);
    }

    #[test]
    fn test_dispatch_sum_binned() {
        let input = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let total = dispatch_sum_f64(&input, &ReductionContext::strict_parallel());
        assert_eq!(total, 15.0);
    }

    // --- dispatched dot product ---------------------------------------------

    #[test]
    fn test_dispatch_dot_kahan() {
        let lhs = vec![1.0, 2.0, 3.0];
        let rhs = vec![4.0, 5.0, 6.0];
        let dot = dispatch_dot_f64(&lhs, &rhs, &ReductionContext::default_serial());
        assert_eq!(dot, 32.0);
    }

    #[test]
    fn test_dispatch_dot_binned() {
        let lhs = vec![1.0, 2.0, 3.0];
        let rhs = vec![4.0, 5.0, 6.0];
        let dot = dispatch_dot_f64(&lhs, &rhs, &ReductionContext::strict_parallel());
        assert_eq!(dot, 32.0);
    }

    // --- cross-strategy agreement -------------------------------------------

    #[test]
    fn test_dispatch_strategies_agree_on_simple() {
        // 1 + 2 + ... + 100 is exactly representable, so both strategies
        // must return exactly 5050.
        let input: Vec<f64> = (1..=100).map(|i| i as f64).collect();

        let via_kahan = dispatch_sum_f64(&input, &ReductionContext::default_serial());
        let via_binned = dispatch_sum_f64(&input, &ReductionContext::strict_parallel());

        assert_eq!(via_kahan, 5050.0);
        assert_eq!(via_binned, 5050.0);
    }
}