nodedb_query/simd_agg.rs

//! SIMD-accelerated aggregation kernels for timeseries f64 columns.
//!
//! Runtime CPU detection selects the fastest available path:
//! - AVX-512 (512-bit, 8 f64/op) — Intel Xeon, AMD Zen 4+
//! - AVX2 (256-bit, 4 f64/op) — most x86_64 CPUs since 2013
//! - NEON (128-bit, 2 f64/op) — ARM64 (Graviton, Apple Silicon)
//! - WASM SIMD128 (128-bit, 2 f64/op) — wasm32 built with simd128
//! - Scalar fallback — auto-vectorized by LLVM
//!
//! The scalar sum path uses Kahan compensated summation; the SIMD sum
//! kernels use plain vector accumulation, trading a little accuracy for
//! throughput.
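//!
//! A minimal usage sketch (assuming this module is reachable as
//! `nodedb_query::simd_agg`; marked `ignore` because the crate path is an
//! assumption, not shown in this file):
//!
//! ```ignore
//! use nodedb_query::simd_agg::ts_runtime;
//!
//! let rt = ts_runtime();                      // CPU detection runs once, then cached
//! let total = (rt.sum_f64)(&[1.0, 2.0, 3.0]); // dispatches to the selected kernel
//! println!("kernel = {}, sum = {}", rt.name, total);
//! ```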
/// SIMD runtime for timeseries f64 aggregation.
pub struct TsSimdRuntime {
    pub sum_f64: fn(&[f64]) -> f64,
    pub min_f64: fn(&[f64]) -> f64,
    pub max_f64: fn(&[f64]) -> f64,
    /// Timestamp range filter: returns indices where min <= val <= max.
    pub range_filter_i64: fn(&[i64], i64, i64) -> Vec<u32>,
    pub name: &'static str,
}

impl TsSimdRuntime {
    pub fn detect() -> Self {
        #[cfg(target_arch = "x86_64")]
        {
            if std::is_x86_feature_detected!("avx512f") {
                return Self {
                    sum_f64: avx512_sum_f64,
                    min_f64: avx512_min_f64,
                    max_f64: avx512_max_f64,
                    range_filter_i64: avx512_range_filter_i64,
                    name: "avx512",
                };
            }
            if std::is_x86_feature_detected!("avx2") {
                return Self {
                    sum_f64: avx2_sum_f64,
                    min_f64: avx2_min_f64,
                    max_f64: avx2_max_f64,
                    range_filter_i64: avx2_range_filter_i64,
                    name: "avx2",
                };
            }
        }
        #[cfg(target_arch = "aarch64")]
        {
            // NEON is baseline on aarch64, so no runtime check is needed.
            return Self {
                sum_f64: neon_sum_f64,
                min_f64: neon_min_f64,
                max_f64: neon_max_f64,
                range_filter_i64: scalar_range_filter_i64,
                name: "neon",
            };
        }
        // wasm has no runtime feature detection; gate on the compile-time
        // simd128 feature so non-SIMD builds fall through to scalar.
        #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
        {
            return Self {
                sum_f64: wasm_sum_f64,
                min_f64: wasm_min_f64,
                max_f64: wasm_max_f64,
                range_filter_i64: scalar_range_filter_i64,
                name: "wasm-simd128",
            };
        }
        #[allow(unreachable_code)]
        Self {
            sum_f64: scalar_sum_f64,
            min_f64: scalar_min_f64,
            max_f64: scalar_max_f64,
            range_filter_i64: scalar_range_filter_i64,
            name: "scalar",
        }
    }
}

static TS_RUNTIME: std::sync::OnceLock<TsSimdRuntime> = std::sync::OnceLock::new();

/// Get the global timeseries SIMD runtime. Detection runs once on first
/// call; the result is cached in a `OnceLock`.
pub fn ts_runtime() -> &'static TsSimdRuntime {
    TS_RUNTIME.get_or_init(TsSimdRuntime::detect)
}

// ── Scalar fallback (auto-vectorized by LLVM) ──

fn scalar_sum_f64(values: &[f64]) -> f64 {
    // Kahan compensated summation: `comp` carries the low-order bits lost
    // when `y` is folded into the running `sum`, and feeds them back into
    // the next addition.
    let mut sum = 0.0f64;
    let mut comp = 0.0f64;
    for &v in values {
        let y = v - comp;
        let t = sum + y;
        comp = (t - sum) - y;
        sum = t;
    }
    sum
}

fn scalar_min_f64(values: &[f64]) -> f64 {
    let mut m = f64::INFINITY;
    for &v in values {
        if v < m {
            m = v;
        }
    }
    m
}

fn scalar_max_f64(values: &[f64]) -> f64 {
    let mut m = f64::NEG_INFINITY;
    for &v in values {
        if v > m {
            m = v;
        }
    }
    m
}

fn scalar_range_filter_i64(values: &[i64], min: i64, max: i64) -> Vec<u32> {
    let mut out = Vec::new();
    for (i, &v) in values.iter().enumerate() {
        if v >= min && v <= max {
            out.push(i as u32);
        }
    }
    out
}

// ── AVX-512 (x86_64) ──

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
unsafe fn avx512_sum_f64_inner(values: &[f64]) -> f64 {
    use std::arch::x86_64::*;
    unsafe {
        let mut acc = _mm512_setzero_pd();
        let chunks = values.len() / 8;
        let ptr = values.as_ptr();
        for i in 0..chunks {
            let v = _mm512_loadu_pd(ptr.add(i * 8));
            acc = _mm512_add_pd(acc, v);
        }
        let mut sum = _mm512_reduce_add_pd(acc);
        for &v in &values[chunks * 8..] {
            sum += v;
        }
        sum
    }
}

#[cfg(target_arch = "x86_64")]
fn avx512_sum_f64(values: &[f64]) -> f64 {
    if values.len() < 16 {
        return scalar_sum_f64(values);
    }
    unsafe { avx512_sum_f64_inner(values) }
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
unsafe fn avx512_min_f64_inner(values: &[f64]) -> f64 {
    use std::arch::x86_64::*;
    unsafe {
        let mut acc = _mm512_set1_pd(f64::INFINITY);
        let chunks = values.len() / 8;
        let ptr = values.as_ptr();
        for i in 0..chunks {
            let v = _mm512_loadu_pd(ptr.add(i * 8));
            acc = _mm512_min_pd(acc, v);
        }
        let mut m = _mm512_reduce_min_pd(acc);
        for &v in &values[chunks * 8..] {
            if v < m {
                m = v;
            }
        }
        m
    }
}

#[cfg(target_arch = "x86_64")]
fn avx512_min_f64(values: &[f64]) -> f64 {
    if values.len() < 16 {
        return scalar_min_f64(values);
    }
    unsafe { avx512_min_f64_inner(values) }
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
unsafe fn avx512_max_f64_inner(values: &[f64]) -> f64 {
    use std::arch::x86_64::*;
    unsafe {
        let mut acc = _mm512_set1_pd(f64::NEG_INFINITY);
        let chunks = values.len() / 8;
        let ptr = values.as_ptr();
        for i in 0..chunks {
            let v = _mm512_loadu_pd(ptr.add(i * 8));
            acc = _mm512_max_pd(acc, v);
        }
        let mut m = _mm512_reduce_max_pd(acc);
        for &v in &values[chunks * 8..] {
            if v > m {
                m = v;
            }
        }
        m
    }
}

#[cfg(target_arch = "x86_64")]
fn avx512_max_f64(values: &[f64]) -> f64 {
    if values.len() < 16 {
        return scalar_max_f64(values);
    }
    unsafe { avx512_max_f64_inner(values) }
}

#[cfg(target_arch = "x86_64")]
fn avx512_range_filter_i64(values: &[i64], min: i64, max: i64) -> Vec<u32> {
    // A fully vectorized path is possible with AVX-512 mask compares plus a
    // compressed store of the matching lane indices (see the sketch below),
    // but it also requires avx512vl. Delegate to scalar for now; LLVM
    // auto-vectorizes this i64 loop well enough.
    scalar_range_filter_i64(values, min, max)
}
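
// A hedged sketch of the vectorized filter mentioned above. It is not wired
// into `detect()`, which would additionally need an `is_x86_feature_detected!`
// check for "avx512vl"; `avx512_range_filter_i64_sketch` is a hypothetical
// name, not part of this module's public API. Per 8-lane chunk it builds a
// k-mask of in-range lanes, compresses the matching lane indices to the low
// end of a vector, and stores them; only the first `popcount(k)` lanes are
// kept.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f,avx512vl")]
#[allow(dead_code)]
unsafe fn avx512_range_filter_i64_sketch(values: &[i64], min: i64, max: i64) -> Vec<u32> {
    use std::arch::x86_64::*;
    unsafe {
        // +8 slack: each chunk stores a full 8-lane vector before trimming.
        let mut out: Vec<u32> = Vec::with_capacity(values.len() + 8);
        let lo = _mm512_set1_epi64(min);
        let hi = _mm512_set1_epi64(max);
        let chunks = values.len() / 8;
        let ptr = values.as_ptr();
        // Lane indices 0..8 as i32, bumped by 8 each iteration.
        let mut idx = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
        let step = _mm256_set1_epi32(8);
        for i in 0..chunks {
            let v = _mm512_loadu_epi64(ptr.add(i * 8));
            // k-mask of lanes with min <= v <= max.
            let k = _mm512_cmpge_epi64_mask(v, lo) & _mm512_cmple_epi64_mask(v, hi);
            // Pack the indices of the set lanes to the low end, zero the rest.
            let packed = _mm256_maskz_compress_epi32(k, idx);
            _mm256_storeu_si256(out.as_mut_ptr().add(out.len()) as *mut __m256i, packed);
            out.set_len(out.len() + k.count_ones() as usize);
            idx = _mm256_add_epi32(idx, step);
        }
        // Scalar tail for the last len % 8 elements.
        for (j, &v) in values[chunks * 8..].iter().enumerate() {
            if v >= min && v <= max {
                out.push((chunks * 8 + j) as u32);
            }
        }
        out
    }
}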

// ── AVX2 (x86_64) ──

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn avx2_sum_f64_inner(values: &[f64]) -> f64 {
    use std::arch::x86_64::*;
    unsafe {
        let mut acc = _mm256_setzero_pd();
        let chunks = values.len() / 4;
        let ptr = values.as_ptr();
        for i in 0..chunks {
            let v = _mm256_loadu_pd(ptr.add(i * 4));
            acc = _mm256_add_pd(acc, v);
        }
        // Horizontal reduction: fold the upper 128-bit half onto the lower,
        // then fold the remaining two lanes into one scalar.
        let hi = _mm256_extractf128_pd(acc, 1);
        let lo = _mm256_castpd256_pd128(acc);
        let sum2 = _mm_add_pd(lo, hi);
        let hi2 = _mm_unpackhi_pd(sum2, sum2);
        let mut sum = _mm_cvtsd_f64(_mm_add_sd(sum2, hi2));
        for &v in &values[chunks * 4..] {
            sum += v;
        }
        sum
    }
}

#[cfg(target_arch = "x86_64")]
fn avx2_sum_f64(values: &[f64]) -> f64 {
    if values.len() < 8 {
        return scalar_sum_f64(values);
    }
    unsafe { avx2_sum_f64_inner(values) }
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn avx2_min_f64_inner(values: &[f64]) -> f64 {
    use std::arch::x86_64::*;
    unsafe {
        let mut acc = _mm256_set1_pd(f64::INFINITY);
        let chunks = values.len() / 4;
        let ptr = values.as_ptr();
        for i in 0..chunks {
            let v = _mm256_loadu_pd(ptr.add(i * 4));
            acc = _mm256_min_pd(acc, v);
        }
        // Horizontal reduction: min of the two 128-bit halves, then of the
        // last two lanes.
        let hi = _mm256_extractf128_pd(acc, 1);
        let lo = _mm256_castpd256_pd128(acc);
        let min2 = _mm_min_pd(lo, hi);
        let hi2 = _mm_unpackhi_pd(min2, min2);
        let mut m = _mm_cvtsd_f64(_mm_min_sd(min2, hi2));
        for &v in &values[chunks * 4..] {
            if v < m {
                m = v;
            }
        }
        m
    }
}

#[cfg(target_arch = "x86_64")]
fn avx2_min_f64(values: &[f64]) -> f64 {
    if values.len() < 8 {
        return scalar_min_f64(values);
    }
    unsafe { avx2_min_f64_inner(values) }
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn avx2_max_f64_inner(values: &[f64]) -> f64 {
    use std::arch::x86_64::*;
    unsafe {
        let mut acc = _mm256_set1_pd(f64::NEG_INFINITY);
        let chunks = values.len() / 4;
        let ptr = values.as_ptr();
        for i in 0..chunks {
            let v = _mm256_loadu_pd(ptr.add(i * 4));
            acc = _mm256_max_pd(acc, v);
        }
        // Horizontal reduction, mirroring the min kernel.
        let hi = _mm256_extractf128_pd(acc, 1);
        let lo = _mm256_castpd256_pd128(acc);
        let max2 = _mm_max_pd(lo, hi);
        let hi2 = _mm_unpackhi_pd(max2, max2);
        let mut m = _mm_cvtsd_f64(_mm_max_sd(max2, hi2));
        for &v in &values[chunks * 4..] {
            if v > m {
                m = v;
            }
        }
        m
    }
}

#[cfg(target_arch = "x86_64")]
fn avx2_max_f64(values: &[f64]) -> f64 {
    if values.len() < 8 {
        return scalar_max_f64(values);
    }
    unsafe { avx2_max_f64_inner(values) }
}

#[cfg(target_arch = "x86_64")]
fn avx2_range_filter_i64(values: &[i64], min: i64, max: i64) -> Vec<u32> {
    // Same rationale as avx512_range_filter_i64: delegate to scalar.
    scalar_range_filter_i64(values, min, max)
}

// ── NEON (aarch64) ──

#[cfg(target_arch = "aarch64")]
fn neon_sum_f64(values: &[f64]) -> f64 {
    use std::arch::aarch64::*;
    if values.len() < 4 {
        return scalar_sum_f64(values);
    }
    unsafe {
        let mut acc = vdupq_n_f64(0.0);
        let chunks = values.len() / 2;
        let ptr = values.as_ptr();
        for i in 0..chunks {
            let v = vld1q_f64(ptr.add(i * 2));
            acc = vaddq_f64(acc, v);
        }
        // Lane indices are const generics in Rust's NEON intrinsics.
        let mut sum = vgetq_lane_f64::<0>(acc) + vgetq_lane_f64::<1>(acc);
        for &v in &values[chunks * 2..] {
            sum += v;
        }
        sum
    }
}

#[cfg(target_arch = "aarch64")]
fn neon_min_f64(values: &[f64]) -> f64 {
    use std::arch::aarch64::*;
    if values.len() < 4 {
        return scalar_min_f64(values);
    }
    unsafe {
        let mut acc = vdupq_n_f64(f64::INFINITY);
        let chunks = values.len() / 2;
        let ptr = values.as_ptr();
        for i in 0..chunks {
            let v = vld1q_f64(ptr.add(i * 2));
            acc = vminq_f64(acc, v);
        }
        let mut m = vgetq_lane_f64::<0>(acc).min(vgetq_lane_f64::<1>(acc));
        for &v in &values[chunks * 2..] {
            if v < m {
                m = v;
            }
        }
        m
    }
}

#[cfg(target_arch = "aarch64")]
fn neon_max_f64(values: &[f64]) -> f64 {
    use std::arch::aarch64::*;
    if values.len() < 4 {
        return scalar_max_f64(values);
    }
    unsafe {
        let mut acc = vdupq_n_f64(f64::NEG_INFINITY);
        let chunks = values.len() / 2;
        let ptr = values.as_ptr();
        for i in 0..chunks {
            let v = vld1q_f64(ptr.add(i * 2));
            acc = vmaxq_f64(acc, v);
        }
        let mut m = vgetq_lane_f64::<0>(acc).max(vgetq_lane_f64::<1>(acc));
        for &v in &values[chunks * 2..] {
            if v > m {
                m = v;
            }
        }
        m
    }
}

// ── WASM SIMD (wasm32 with simd128 enabled at compile time) ──
//
// wasm has no runtime feature detection, so these kernels exist only when
// the simd128 target feature is enabled; otherwise detect() falls through
// to the scalar path.

#[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
#[target_feature(enable = "simd128")]
unsafe fn wasm_sum_f64_inner(values: &[f64]) -> f64 {
    use core::arch::wasm32::*;
    unsafe {
        let mut acc = f64x2_splat(0.0);
        let chunks = values.len() / 2;
        let ptr = values.as_ptr();
        for i in 0..chunks {
            let v = v128_load(ptr.add(i * 2) as *const v128);
            acc = f64x2_add(acc, v);
        }
        let mut sum = f64x2_extract_lane::<0>(acc) + f64x2_extract_lane::<1>(acc);
        for &v in &values[chunks * 2..] {
            sum += v;
        }
        sum
    }
}

#[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
fn wasm_sum_f64(values: &[f64]) -> f64 {
    if values.len() < 4 {
        return scalar_sum_f64(values);
    }
    unsafe { wasm_sum_f64_inner(values) }
}

#[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
#[target_feature(enable = "simd128")]
unsafe fn wasm_min_f64_inner(values: &[f64]) -> f64 {
    use core::arch::wasm32::*;
    unsafe {
        let mut acc = f64x2_splat(f64::INFINITY);
        let chunks = values.len() / 2;
        let ptr = values.as_ptr();
        for i in 0..chunks {
            let v = v128_load(ptr.add(i * 2) as *const v128);
            acc = f64x2_min(acc, v);
        }
        let mut m = f64x2_extract_lane::<0>(acc).min(f64x2_extract_lane::<1>(acc));
        for &v in &values[chunks * 2..] {
            if v < m {
                m = v;
            }
        }
        m
    }
}

#[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
fn wasm_min_f64(values: &[f64]) -> f64 {
    if values.len() < 4 {
        return scalar_min_f64(values);
    }
    unsafe { wasm_min_f64_inner(values) }
}

#[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
#[target_feature(enable = "simd128")]
unsafe fn wasm_max_f64_inner(values: &[f64]) -> f64 {
    use core::arch::wasm32::*;
    unsafe {
        let mut acc = f64x2_splat(f64::NEG_INFINITY);
        let chunks = values.len() / 2;
        let ptr = values.as_ptr();
        for i in 0..chunks {
            let v = v128_load(ptr.add(i * 2) as *const v128);
            acc = f64x2_max(acc, v);
        }
        let mut m = f64x2_extract_lane::<0>(acc).max(f64x2_extract_lane::<1>(acc));
        for &v in &values[chunks * 2..] {
            if v > m {
                m = v;
            }
        }
        m
    }
}

#[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
fn wasm_max_f64(values: &[f64]) -> f64 {
    if values.len() < 4 {
        return scalar_max_f64(values);
    }
    unsafe { wasm_max_f64_inner(values) }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn runtime_detects() {
        let rt = ts_runtime();
        assert!(!rt.name.is_empty());
        // Should be one of: avx512, avx2, neon, wasm-simd128, scalar.
        assert!(
            ["avx512", "avx2", "neon", "wasm-simd128", "scalar"].contains(&rt.name),
            "unexpected runtime: {}",
            rt.name
        );
    }

    #[test]
    fn sum_correctness() {
        let rt = ts_runtime();
        let values: Vec<f64> = (0..1000).map(|i| i as f64).collect();
        let expected = 999.0 * 1000.0 / 2.0; // sum of 0..=999
        let result = (rt.sum_f64)(&values);
        assert!(
            (result - expected).abs() < 1e-6,
            "sum: got {result}, expected {expected}"
        );
    }

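    // Kahan accuracy check for the scalar path: summing many copies of the
    // binary-inexact value 0.1 lets rounding error pile up in a naive fold,
    // while the compensated sum should stay within an ulp or two of the
    // exact result. The `<=` comparison keeps the assertion robust even if
    // the two happen to agree.
    #[test]
    fn kahan_sum_accuracy() {
        let values = vec![0.1f64; 1_000_000];
        let exact = 0.1f64 * 1_000_000.0; // one rounding step instead of a million
        let naive: f64 = values.iter().sum();
        let kahan = scalar_sum_f64(&values);
        assert!((kahan - exact).abs() <= (naive - exact).abs());
        assert!((kahan - 100_000.0).abs() < 1e-6);
    }
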
    #[test]
    fn min_max_correctness() {
        let rt = ts_runtime();
        let values: Vec<f64> = (0..1000).map(|i| (i as f64) - 500.0).collect();
        assert!(((rt.min_f64)(&values) - (-500.0)).abs() < f64::EPSILON);
        assert!(((rt.max_f64)(&values) - 499.0).abs() < f64::EPSILON);
    }

    #[test]
    fn range_filter_correctness() {
        let rt = ts_runtime();
        let values: Vec<i64> = (0..100).collect();
        let indices = (rt.range_filter_i64)(&values, 25, 75);
        assert_eq!(indices.len(), 51); // 25..=75 inclusive
        assert_eq!(indices[0], 25);
        assert_eq!(*indices.last().unwrap(), 75);
    }

    #[test]
    fn empty_input() {
        let rt = ts_runtime();
        assert_eq!((rt.sum_f64)(&[]), 0.0);
        // Empty min/max return the identity elements of each reduction.
        assert_eq!((rt.min_f64)(&[]), f64::INFINITY);
        assert_eq!((rt.max_f64)(&[]), f64::NEG_INFINITY);
        assert!((rt.range_filter_i64)(&[], 0, 100).is_empty());
    }

    #[test]
    fn small_input() {
        let rt = ts_runtime();
        assert!(((rt.sum_f64)(&[1.0, 2.0, 3.0]) - 6.0).abs() < f64::EPSILON);
        assert!(((rt.min_f64)(&[3.0, 1.0, 2.0]) - 1.0).abs() < f64::EPSILON);
        assert!(((rt.max_f64)(&[1.0, 3.0, 2.0]) - 3.0).abs() < f64::EPSILON);
    }
}
549}