1use crate::accumulator::{binned_sum_f64, BinnedAccumulatorF64};
18use cjc_repro::kahan_sum_f64;
19
/// How a reduction is executed.
///
/// `select_strategy` always routes `Parallel` reductions to
/// `SumStrategy::Binned`; `Serial` reductions may use either strategy
/// depending on the other fields of `ReductionContext`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ExecMode {
    /// Serial execution.
    Serial,
    /// Parallel execution; forces the binned strategy.
    Parallel,
}
32
/// Requested level of numerical reproducibility for a reduction.
///
/// `select_strategy` treats `Off` and `On` identically; only `Strict`
/// forces `SumStrategy::Binned` on its own.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ReproMode {
    /// No reproducibility requested.
    Off,
    /// Default level, used by `ReductionContext::default_serial` and
    /// `ReductionContext::linalg`.
    On,
    /// Strictest level; always selects the binned strategy.
    Strict,
}
43
44#[derive(Debug, Clone, Copy)]
46pub struct ReductionContext {
47 pub exec_mode: ExecMode,
49 pub repro_mode: ReproMode,
51 pub in_nogc: bool,
53 pub is_linalg: bool,
55}
56
57impl ReductionContext {
58 pub fn default_serial() -> Self {
60 ReductionContext {
61 exec_mode: ExecMode::Serial,
62 repro_mode: ReproMode::On,
63 in_nogc: false,
64 is_linalg: false,
65 }
66 }
67
68 pub fn nogc() -> Self {
70 ReductionContext {
71 exec_mode: ExecMode::Serial,
72 repro_mode: ReproMode::Strict,
73 in_nogc: true,
74 is_linalg: false,
75 }
76 }
77
78 pub fn linalg() -> Self {
80 ReductionContext {
81 exec_mode: ExecMode::Serial,
82 repro_mode: ReproMode::On,
83 in_nogc: false,
84 is_linalg: true,
85 }
86 }
87
88 pub fn strict_parallel() -> Self {
90 ReductionContext {
91 exec_mode: ExecMode::Parallel,
92 repro_mode: ReproMode::Strict,
93 in_nogc: false,
94 is_linalg: false,
95 }
96 }
97}
98
/// Summation algorithm returned by `select_strategy`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SumStrategy {
    /// Compensated summation via `cjc_repro::kahan_sum_f64`.
    Kahan,
    /// Binned accumulation via `crate::accumulator` (`binned_sum_f64` /
    /// `BinnedAccumulatorF64`).
    Binned,
}
111
112pub fn select_strategy(ctx: &ReductionContext) -> SumStrategy {
122 if ctx.exec_mode == ExecMode::Parallel {
123 return SumStrategy::Binned;
124 }
125 if ctx.in_nogc {
126 return SumStrategy::Binned;
127 }
128 if ctx.repro_mode == ReproMode::Strict {
129 return SumStrategy::Binned;
130 }
131 if ctx.is_linalg {
132 return SumStrategy::Binned;
133 }
134 SumStrategy::Kahan
135}
136
137#[inline]
145pub fn dispatch_sum_f64(values: &[f64], ctx: &ReductionContext) -> f64 {
146 match select_strategy(ctx) {
147 SumStrategy::Kahan => kahan_sum_f64(values),
148 SumStrategy::Binned => binned_sum_f64(values),
149 }
150}
151
152#[inline]
159pub fn dispatch_dot_f64(a: &[f64], b: &[f64], ctx: &ReductionContext) -> f64 {
160 debug_assert_eq!(a.len(), b.len());
161 match select_strategy(ctx) {
162 SumStrategy::Kahan => {
163 let products: Vec<f64> = a.iter().zip(b.iter()).map(|(&x, &y)| x * y).collect();
164 kahan_sum_f64(&products)
165 }
166 SumStrategy::Binned => {
167 let mut acc = BinnedAccumulatorF64::new();
168 for (&x, &y) in a.iter().zip(b.iter()) {
169 acc.add(x * y);
170 }
171 acc.finalize()
172 }
173 }
174}
175
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a context from its four fields, in declaration order.
    fn ctx_of(
        exec_mode: ExecMode,
        repro_mode: ReproMode,
        in_nogc: bool,
        is_linalg: bool,
    ) -> ReductionContext {
        ReductionContext {
            exec_mode,
            repro_mode,
            in_nogc,
            is_linalg,
        }
    }

    // ---- strategy selection -------------------------------------------

    #[test]
    fn test_serial_on_uses_kahan() {
        let strategy = select_strategy(&ReductionContext::default_serial());
        assert_eq!(strategy, SumStrategy::Kahan);
    }

    #[test]
    fn test_parallel_uses_binned() {
        let ctx = ctx_of(ExecMode::Parallel, ReproMode::On, false, false);
        assert_eq!(select_strategy(&ctx), SumStrategy::Binned);
    }

    #[test]
    fn test_nogc_uses_binned() {
        let strategy = select_strategy(&ReductionContext::nogc());
        assert_eq!(strategy, SumStrategy::Binned);
    }

    #[test]
    fn test_strict_uses_binned() {
        let ctx = ctx_of(ExecMode::Serial, ReproMode::Strict, false, false);
        assert_eq!(select_strategy(&ctx), SumStrategy::Binned);
    }

    #[test]
    fn test_linalg_uses_binned() {
        let strategy = select_strategy(&ReductionContext::linalg());
        assert_eq!(strategy, SumStrategy::Binned);
    }

    #[test]
    fn test_off_serial_uses_kahan() {
        let ctx = ctx_of(ExecMode::Serial, ReproMode::Off, false, false);
        assert_eq!(select_strategy(&ctx), SumStrategy::Kahan);
    }

    // ---- dispatch ------------------------------------------------------

    #[test]
    fn test_dispatch_sum_kahan() {
        let values = [1.0, 2.0, 3.0, 4.0, 5.0];
        let total = dispatch_sum_f64(&values, &ReductionContext::default_serial());
        assert_eq!(total, 15.0);
    }

    #[test]
    fn test_dispatch_sum_binned() {
        let values = [1.0, 2.0, 3.0, 4.0, 5.0];
        let total = dispatch_sum_f64(&values, &ReductionContext::strict_parallel());
        assert_eq!(total, 15.0);
    }

    #[test]
    fn test_dispatch_dot_kahan() {
        let (a, b) = ([1.0, 2.0, 3.0], [4.0, 5.0, 6.0]);
        let dot = dispatch_dot_f64(&a, &b, &ReductionContext::default_serial());
        assert_eq!(dot, 32.0);
    }

    #[test]
    fn test_dispatch_dot_binned() {
        let (a, b) = ([1.0, 2.0, 3.0], [4.0, 5.0, 6.0]);
        let dot = dispatch_dot_f64(&a, &b, &ReductionContext::strict_parallel());
        assert_eq!(dot, 32.0);
    }

    #[test]
    fn test_dispatch_strategies_agree_on_simple() {
        let values: Vec<f64> = (1..=100_i32).map(f64::from).collect();

        let via_kahan = dispatch_sum_f64(&values, &ReductionContext::default_serial());
        let via_binned = dispatch_sum_f64(&values, &ReductionContext::strict_parallel());

        assert_eq!(via_kahan, 5050.0);
        assert_eq!(via_binned, 5050.0);
    }
}