// nodedb_query/simd_agg.rs

// SPDX-License-Identifier: Apache-2.0

//! SIMD-accelerated aggregation kernels for timeseries f64 columns.
//!
//! Runtime CPU detection selects the fastest available path:
//! - AVX-512 (512-bit, 8 f64/op) — Intel Xeon, AMD Zen 4+
//! - AVX2 (256-bit, 4 f64/op) — most x86_64 CPUs since 2013
//! - NEON (128-bit, 2 f64/op) — ARM64 (Graviton, Apple Silicon)
//! - WASM simd128 (128-bit, 2 f64/op) — wasm32 built with `+simd128`
//! - Scalar fallback — auto-vectorized by LLVM
//!
//! The scalar sum kernel uses Kahan compensated summation; the SIMD sum
//! kernels use plain vector accumulation, trading a small amount of
//! accuracy for throughput.

/// SIMD runtime for timeseries f64 aggregation.
pub struct TsSimdRuntime {
    pub sum_f64: fn(&[f64]) -> f64,
    pub min_f64: fn(&[f64]) -> f64,
    pub max_f64: fn(&[f64]) -> f64,
    /// Timestamp range filter: returns the indices i with min <= values[i] <= max.
    pub range_filter_i64: fn(&[i64], i64, i64) -> Vec<u32>,
    /// Name of the selected path (e.g. "avx2").
    pub name: &'static str,
}

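// Intended query flow (sketch; `ts`, `vals`, `t0`, `t1` are hypothetical
// parallel timestamp/value columns and window bounds): filter the
// timestamps to the window, then aggregate the matching values.
//
//     let rt = ts_runtime();
//     let idx = (rt.range_filter_i64)(&ts, t0, t1);
//     let picked: Vec<f64> = idx.iter().map(|&i| vals[i as usize]).collect();
//     let window_sum = (rt.sum_f64)(&picked);
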
impl TsSimdRuntime {
    /// Detect the best kernel set for the current CPU.
    // `unreachable_code`: on aarch64/wasm32 the matching cfg block below
    // returns early, making the trailing scalar fallback unreachable.
    // (An attribute on the tail expression itself is not stable Rust.)
    #[allow(unreachable_code)]
    pub fn detect() -> Self {
        #[cfg(target_arch = "x86_64")]
        {
            if std::is_x86_feature_detected!("avx512f") {
                return Self {
                    sum_f64: avx512_sum_f64,
                    min_f64: avx512_min_f64,
                    max_f64: avx512_max_f64,
                    range_filter_i64: avx512_range_filter_i64,
                    name: "avx512",
                };
            }
            if std::is_x86_feature_detected!("avx2") {
                return Self {
                    sum_f64: avx2_sum_f64,
                    min_f64: avx2_min_f64,
                    max_f64: avx2_max_f64,
                    range_filter_i64: avx2_range_filter_i64,
                    name: "avx2",
                };
            }
        }
        #[cfg(target_arch = "aarch64")]
        {
            return Self {
                sum_f64: neon_sum_f64,
                min_f64: neon_min_f64,
                max_f64: neon_max_f64,
                range_filter_i64: scalar_range_filter_i64,
                name: "neon",
            };
        }
        #[cfg(target_arch = "wasm32")]
        {
            // wasm32 has no runtime feature detection; this path assumes the
            // module was built with `-C target-feature=+simd128`.
            return Self {
                sum_f64: wasm_sum_f64,
                min_f64: wasm_min_f64,
                max_f64: wasm_max_f64,
                range_filter_i64: scalar_range_filter_i64,
                name: "wasm-simd128",
            };
        }
        // Reached only when no SIMD path above matched.
        Self {
            sum_f64: scalar_sum_f64,
            min_f64: scalar_min_f64,
            max_f64: scalar_max_f64,
            range_filter_i64: scalar_range_filter_i64,
            name: "scalar",
        }
    }
}

static TS_RUNTIME: std::sync::OnceLock<TsSimdRuntime> = std::sync::OnceLock::new();

/// Get the global timeseries SIMD runtime.
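///
/// Detection runs once via `OnceLock`; later calls return the cached
/// runtime. A minimal usage sketch (assumes this module is reachable as
/// `nodedb_query::simd_agg`):
///
/// ```ignore
/// use nodedb_query::simd_agg::ts_runtime;
///
/// let rt = ts_runtime(); // e.g. "avx2" on most modern x86_64
/// assert!(((rt.sum_f64)(&[1.0, 2.0, 3.0]) - 6.0).abs() < 1e-12);
/// ```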
pub fn ts_runtime() -> &'static TsSimdRuntime {
    TS_RUNTIME.get_or_init(TsSimdRuntime::detect)
}

// ── Scalar fallback (auto-vectorized by LLVM) ──

fn scalar_sum_f64(values: &[f64]) -> f64 {
    // Kahan compensated summation: `comp` carries the low-order bits lost
    // when each small addend is absorbed into the large running `sum`.
    let mut sum = 0.0f64;
    let mut comp = 0.0f64;
    for &v in values {
        let y = v - comp;
        let t = sum + y;
        comp = (t - sum) - y;
        sum = t;
    }
    sum
}

fn scalar_min_f64(values: &[f64]) -> f64 {
    let mut m = f64::INFINITY;
    for &v in values {
        if v < m {
            m = v;
        }
    }
    m
}

fn scalar_max_f64(values: &[f64]) -> f64 {
    let mut m = f64::NEG_INFINITY;
    for &v in values {
        if v > m {
            m = v;
        }
    }
    m
}

fn scalar_range_filter_i64(values: &[i64], min: i64, max: i64) -> Vec<u32> {
    let mut out = Vec::new();
    for (i, &v) in values.iter().enumerate() {
        if v >= min && v <= max {
            out.push(i as u32);
        }
    }
    out
}

// ── AVX-512 (x86_64) ──

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
unsafe fn avx512_sum_f64_inner(values: &[f64]) -> f64 {
    use std::arch::x86_64::*;
    unsafe {
        let mut acc = _mm512_setzero_pd();
        let chunks = values.len() / 8;
        let ptr = values.as_ptr();
        for i in 0..chunks {
            let v = _mm512_loadu_pd(ptr.add(i * 8));
            acc = _mm512_add_pd(acc, v);
        }
        let mut sum = _mm512_reduce_add_pd(acc);
        for &v in &values[chunks * 8..] {
            sum += v;
        }
        sum
    }
}

#[cfg(target_arch = "x86_64")]
fn avx512_sum_f64(values: &[f64]) -> f64 {
    if values.len() < 16 {
        return scalar_sum_f64(values);
    }
    unsafe { avx512_sum_f64_inner(values) }
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
unsafe fn avx512_min_f64_inner(values: &[f64]) -> f64 {
    use std::arch::x86_64::*;
    unsafe {
        let mut acc = _mm512_set1_pd(f64::INFINITY);
        let chunks = values.len() / 8;
        let ptr = values.as_ptr();
        for i in 0..chunks {
            let v = _mm512_loadu_pd(ptr.add(i * 8));
            acc = _mm512_min_pd(acc, v);
        }
        let mut m = _mm512_reduce_min_pd(acc);
        for &v in &values[chunks * 8..] {
            if v < m {
                m = v;
            }
        }
        m
    }
}

#[cfg(target_arch = "x86_64")]
fn avx512_min_f64(values: &[f64]) -> f64 {
    if values.len() < 16 {
        return scalar_min_f64(values);
    }
    unsafe { avx512_min_f64_inner(values) }
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
unsafe fn avx512_max_f64_inner(values: &[f64]) -> f64 {
    use std::arch::x86_64::*;
    unsafe {
        let mut acc = _mm512_set1_pd(f64::NEG_INFINITY);
        let chunks = values.len() / 8;
        let ptr = values.as_ptr();
        for i in 0..chunks {
            let v = _mm512_loadu_pd(ptr.add(i * 8));
            acc = _mm512_max_pd(acc, v);
        }
        let mut m = _mm512_reduce_max_pd(acc);
        for &v in &values[chunks * 8..] {
            if v > m {
                m = v;
            }
        }
        m
    }
}

#[cfg(target_arch = "x86_64")]
fn avx512_max_f64(values: &[f64]) -> f64 {
    if values.len() < 16 {
        return scalar_max_f64(values);
    }
    unsafe { avx512_max_f64_inner(values) }
}

#[cfg(target_arch = "x86_64")]
fn avx512_range_filter_i64(values: &[i64], min: i64, max: i64) -> Vec<u32> {
    // AVX-512 has mask compares for i64, but extracting match indices from
    // the mask adds complexity. Delegate to scalar — LLVM auto-vectorizes
    // this loop well enough for i64.
    scalar_range_filter_i64(values, min, max)
}
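
// For reference, a hypothetical mask-walk variant (not wired into
// `detect()`): compare 8 timestamps per iteration with AVX-512 mask
// compares, then pop match indices off the 8-bit result mask. A sketch
// only; whether it beats the auto-vectorized scalar loop would need
// benchmarking on real data.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
#[allow(dead_code)]
unsafe fn avx512_range_filter_i64_masked(values: &[i64], min: i64, max: i64) -> Vec<u32> {
    use std::arch::x86_64::*;
    let mut out = Vec::with_capacity(values.len());
    let chunks = values.len() / 8;
    unsafe {
        let lo = _mm512_set1_epi64(min);
        let hi = _mm512_set1_epi64(max);
        let ptr = values.as_ptr();
        for i in 0..chunks {
            let v = _mm512_loadu_epi64(ptr.add(i * 8));
            // Bit k of `mask` is set iff min <= values[i * 8 + k] <= max.
            let mut mask = _mm512_cmpge_epi64_mask(v, lo) & _mm512_cmple_epi64_mask(v, hi);
            while mask != 0 {
                out.push((i * 8) as u32 + mask.trailing_zeros());
                mask &= mask - 1; // clear the lowest set bit
            }
        }
    }
    // Scalar tail for the remainder, as in the other kernels.
    for (j, &v) in values[chunks * 8..].iter().enumerate() {
        if v >= min && v <= max {
            out.push((chunks * 8 + j) as u32);
        }
    }
    out
}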

// ── AVX2 (x86_64) ──

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn avx2_sum_f64_inner(values: &[f64]) -> f64 {
    use std::arch::x86_64::*;
    unsafe {
        let mut acc = _mm256_setzero_pd();
        let chunks = values.len() / 4;
        let ptr = values.as_ptr();
        for i in 0..chunks {
            let v = _mm256_loadu_pd(ptr.add(i * 4));
            acc = _mm256_add_pd(acc, v);
        }
        // Horizontal reduction: fold the 256-bit accumulator to 128 bits,
        // then the two remaining lanes to one scalar.
        let hi = _mm256_extractf128_pd::<1>(acc);
        let lo = _mm256_castpd256_pd128(acc);
        let sum2 = _mm_add_pd(lo, hi);
        let hi2 = _mm_unpackhi_pd(sum2, sum2);
        let mut sum = _mm_cvtsd_f64(_mm_add_sd(sum2, hi2));
        for &v in &values[chunks * 4..] {
            sum += v;
        }
        sum
    }
}

#[cfg(target_arch = "x86_64")]
fn avx2_sum_f64(values: &[f64]) -> f64 {
    if values.len() < 8 {
        return scalar_sum_f64(values);
    }
    unsafe { avx2_sum_f64_inner(values) }
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn avx2_min_f64_inner(values: &[f64]) -> f64 {
    use std::arch::x86_64::*;
    unsafe {
        let mut acc = _mm256_set1_pd(f64::INFINITY);
        let chunks = values.len() / 4;
        let ptr = values.as_ptr();
        for i in 0..chunks {
            let v = _mm256_loadu_pd(ptr.add(i * 4));
            acc = _mm256_min_pd(acc, v);
        }
        let hi = _mm256_extractf128_pd::<1>(acc);
        let lo = _mm256_castpd256_pd128(acc);
        let min2 = _mm_min_pd(lo, hi);
        let hi2 = _mm_unpackhi_pd(min2, min2);
        let mut m = _mm_cvtsd_f64(_mm_min_sd(min2, hi2));
        for &v in &values[chunks * 4..] {
            if v < m {
                m = v;
            }
        }
        m
    }
}

#[cfg(target_arch = "x86_64")]
fn avx2_min_f64(values: &[f64]) -> f64 {
    if values.len() < 8 {
        return scalar_min_f64(values);
    }
    unsafe { avx2_min_f64_inner(values) }
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn avx2_max_f64_inner(values: &[f64]) -> f64 {
    use std::arch::x86_64::*;
    unsafe {
        let mut acc = _mm256_set1_pd(f64::NEG_INFINITY);
        let chunks = values.len() / 4;
        let ptr = values.as_ptr();
        for i in 0..chunks {
            let v = _mm256_loadu_pd(ptr.add(i * 4));
            acc = _mm256_max_pd(acc, v);
        }
        let hi = _mm256_extractf128_pd::<1>(acc);
        let lo = _mm256_castpd256_pd128(acc);
        let max2 = _mm_max_pd(lo, hi);
        let hi2 = _mm_unpackhi_pd(max2, max2);
        let mut m = _mm_cvtsd_f64(_mm_max_sd(max2, hi2));
        for &v in &values[chunks * 4..] {
            if v > m {
                m = v;
            }
        }
        m
    }
}

#[cfg(target_arch = "x86_64")]
fn avx2_max_f64(values: &[f64]) -> f64 {
    if values.len() < 8 {
        return scalar_max_f64(values);
    }
    unsafe { avx2_max_f64_inner(values) }
}

#[cfg(target_arch = "x86_64")]
fn avx2_range_filter_i64(values: &[i64], min: i64, max: i64) -> Vec<u32> {
    // Same rationale as avx512_range_filter_i64: delegate to scalar.
    scalar_range_filter_i64(values, min, max)
}

// ── NEON (aarch64) ──

#[cfg(target_arch = "aarch64")]
fn neon_sum_f64(values: &[f64]) -> f64 {
    use std::arch::aarch64::*;
    if values.len() < 4 {
        return scalar_sum_f64(values);
    }
    unsafe {
        let mut acc = vdupq_n_f64(0.0);
        let chunks = values.len() / 2;
        let ptr = values.as_ptr();
        for i in 0..chunks {
            let v = vld1q_f64(ptr.add(i * 2));
            acc = vaddq_f64(acc, v);
        }
        let mut sum = vgetq_lane_f64::<0>(acc) + vgetq_lane_f64::<1>(acc);
        for &v in &values[chunks * 2..] {
            sum += v;
        }
        sum
    }
}

#[cfg(target_arch = "aarch64")]
fn neon_min_f64(values: &[f64]) -> f64 {
    use std::arch::aarch64::*;
    if values.len() < 4 {
        return scalar_min_f64(values);
    }
    unsafe {
        let mut acc = vdupq_n_f64(f64::INFINITY);
        let chunks = values.len() / 2;
        let ptr = values.as_ptr();
        for i in 0..chunks {
            let v = vld1q_f64(ptr.add(i * 2));
            acc = vminq_f64(acc, v);
        }
        let mut m = vgetq_lane_f64::<0>(acc).min(vgetq_lane_f64::<1>(acc));
        for &v in &values[chunks * 2..] {
            if v < m {
                m = v;
            }
        }
        m
    }
}

#[cfg(target_arch = "aarch64")]
fn neon_max_f64(values: &[f64]) -> f64 {
    use std::arch::aarch64::*;
    if values.len() < 4 {
        return scalar_max_f64(values);
    }
    unsafe {
        let mut acc = vdupq_n_f64(f64::NEG_INFINITY);
        let chunks = values.len() / 2;
        let ptr = values.as_ptr();
        for i in 0..chunks {
            let v = vld1q_f64(ptr.add(i * 2));
            acc = vmaxq_f64(acc, v);
        }
        let mut m = vgetq_lane_f64::<0>(acc).max(vgetq_lane_f64::<1>(acc));
        for &v in &values[chunks * 2..] {
            if v > m {
                m = v;
            }
        }
        m
    }
}

// ── WASM SIMD (wasm32 with simd128) ──

#[cfg(target_arch = "wasm32")]
#[target_feature(enable = "simd128")]
unsafe fn wasm_sum_f64_inner(values: &[f64]) -> f64 {
    use core::arch::wasm32::*;
    unsafe {
        let mut acc = f64x2_splat(0.0);
        let chunks = values.len() / 2;
        let ptr = values.as_ptr();
        for i in 0..chunks {
            let v = v128_load(ptr.add(i * 2) as *const v128);
            acc = f64x2_add(acc, v);
        }
        let mut sum = f64x2_extract_lane::<0>(acc) + f64x2_extract_lane::<1>(acc);
        for &v in &values[chunks * 2..] {
            sum += v;
        }
        sum
    }
}

#[cfg(target_arch = "wasm32")]
fn wasm_sum_f64(values: &[f64]) -> f64 {
    if values.len() < 4 {
        return scalar_sum_f64(values);
    }
    unsafe { wasm_sum_f64_inner(values) }
}

#[cfg(target_arch = "wasm32")]
#[target_feature(enable = "simd128")]
unsafe fn wasm_min_f64_inner(values: &[f64]) -> f64 {
    use core::arch::wasm32::*;
    unsafe {
        let mut acc = f64x2_splat(f64::INFINITY);
        let chunks = values.len() / 2;
        let ptr = values.as_ptr();
        for i in 0..chunks {
            let v = v128_load(ptr.add(i * 2) as *const v128);
            acc = f64x2_min(acc, v);
        }
        let mut m = f64x2_extract_lane::<0>(acc).min(f64x2_extract_lane::<1>(acc));
        for &v in &values[chunks * 2..] {
            if v < m {
                m = v;
            }
        }
        m
    }
}

#[cfg(target_arch = "wasm32")]
fn wasm_min_f64(values: &[f64]) -> f64 {
    if values.len() < 4 {
        return scalar_min_f64(values);
    }
    unsafe { wasm_min_f64_inner(values) }
}

#[cfg(target_arch = "wasm32")]
#[target_feature(enable = "simd128")]
unsafe fn wasm_max_f64_inner(values: &[f64]) -> f64 {
    use core::arch::wasm32::*;
    unsafe {
        let mut acc = f64x2_splat(f64::NEG_INFINITY);
        let chunks = values.len() / 2;
        let ptr = values.as_ptr();
        for i in 0..chunks {
            let v = v128_load(ptr.add(i * 2) as *const v128);
            acc = f64x2_max(acc, v);
        }
        let mut m = f64x2_extract_lane::<0>(acc).max(f64x2_extract_lane::<1>(acc));
        for &v in &values[chunks * 2..] {
            if v > m {
                m = v;
            }
        }
        m
    }
}

#[cfg(target_arch = "wasm32")]
fn wasm_max_f64(values: &[f64]) -> f64 {
    if values.len() < 4 {
        return scalar_max_f64(values);
    }
    unsafe { wasm_max_f64_inner(values) }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn runtime_detects() {
        let rt = ts_runtime();
        assert!(!rt.name.is_empty());
        // Must be one of the paths detect() can select.
        assert!(
            ["avx512", "avx2", "neon", "wasm-simd128", "scalar"].contains(&rt.name),
            "unexpected runtime: {}",
            rt.name
        );
    }

    #[test]
    fn sum_correctness() {
        let rt = ts_runtime();
        let values: Vec<f64> = (0..1000).map(|i| i as f64).collect();
        let expected = 999.0 * 1000.0 / 2.0; // sum of 0..=999
        let result = (rt.sum_f64)(&values);
        assert!(
            (result - expected).abs() < 1e-6,
            "sum: got {result}, expected {expected}"
        );
    }

    #[test]
    fn min_max_correctness() {
        let rt = ts_runtime();
        let values: Vec<f64> = (0..1000).map(|i| (i as f64) - 500.0).collect();
        assert!(((rt.min_f64)(&values) - (-500.0)).abs() < f64::EPSILON);
        assert!(((rt.max_f64)(&values) - 499.0).abs() < f64::EPSILON);
    }

    #[test]
    fn range_filter_correctness() {
        let rt = ts_runtime();
        let values: Vec<i64> = (0..100).collect();
        let indices = (rt.range_filter_i64)(&values, 25, 75);
        assert_eq!(indices.len(), 51); // 25..=75 inclusive
        assert_eq!(indices[0], 25);
        assert_eq!(*indices.last().unwrap(), 75);
    }

    #[test]
    fn empty_input() {
        let rt = ts_runtime();
        assert_eq!((rt.sum_f64)(&[]), 0.0);
        // Identity elements: min of nothing is +inf, max of nothing is -inf.
        assert_eq!((rt.min_f64)(&[]), f64::INFINITY);
        assert_eq!((rt.max_f64)(&[]), f64::NEG_INFINITY);
        assert!((rt.range_filter_i64)(&[], 0, 100).is_empty());
    }

    #[test]
    fn small_input() {
        let rt = ts_runtime();
        assert!(((rt.sum_f64)(&[1.0, 2.0, 3.0]) - 6.0).abs() < f64::EPSILON);
        assert!(((rt.min_f64)(&[3.0, 1.0, 2.0]) - 1.0).abs() < f64::EPSILON);
        assert!(((rt.max_f64)(&[1.0, 3.0, 2.0]) - 3.0).abs() < f64::EPSILON);
    }
551}