// nodedb_query/simd_agg_i64.rs
// SPDX-License-Identifier: Apache-2.0

//! SIMD-accelerated aggregation kernels for i64 columns.
//!
//! Mirrors the f64 dispatch in `simd_agg.rs`. Uses i128 accumulator for
//! overflow-safe sum. Same runtime detection: AVX-512 → AVX2 → NEON → scalar.

8/// SIMD runtime for i64 aggregation.
9pub struct I64SimdRuntime {
10    /// Sum with i128 accumulator (overflow-safe).
11    pub sum_i64: fn(&[i64]) -> i128,
12    pub min_i64: fn(&[i64]) -> i64,
13    pub max_i64: fn(&[i64]) -> i64,
14    pub name: &'static str,
15}
16
17impl I64SimdRuntime {
18    pub fn detect() -> Self {
19        #[cfg(target_arch = "x86_64")]
20        {
21            if std::is_x86_feature_detected!("avx512f") {
22                return Self {
23                    sum_i64: avx512_sum_i64,
24                    min_i64: avx512_min_i64,
25                    max_i64: avx512_max_i64,
26                    name: "avx512",
27                };
28            }
29            if std::is_x86_feature_detected!("avx2") {
30                return Self {
31                    sum_i64: avx2_sum_i64,
32                    min_i64: avx2_min_i64,
33                    max_i64: avx2_max_i64,
34                    name: "avx2",
35                };
36            }
37        }
38        #[cfg(target_arch = "aarch64")]
39        {
40            return Self {
41                sum_i64: neon_sum_i64,
42                min_i64: neon_min_i64,
43                max_i64: neon_max_i64,
44                name: "neon",
45            };
46        }
47        #[cfg(target_arch = "wasm32")]
48        {
49            return Self {
50                sum_i64: wasm_sum_i64,
51                min_i64: wasm_min_i64,
52                max_i64: wasm_max_i64,
53                name: "wasm-simd128",
54            };
55        }
56        #[allow(unreachable_code)]
57        Self {
58            sum_i64: scalar_sum_i64,
59            min_i64: scalar_min_i64,
60            max_i64: scalar_max_i64,
61            name: "scalar",
62        }
63    }
64}
65
66static I64_RUNTIME: std::sync::OnceLock<I64SimdRuntime> = std::sync::OnceLock::new();
67
68/// Get the global i64 SIMD runtime.
69pub fn i64_runtime() -> &'static I64SimdRuntime {
70    I64_RUNTIME.get_or_init(I64SimdRuntime::detect)
71}
72
// ── Scalar fallback ────────────────────────────────────────────────

75fn scalar_sum_i64(values: &[i64]) -> i128 {
76    let mut sum: i128 = 0;
77    for &v in values {
78        sum += v as i128;
79    }
80    sum
81}
82
83fn scalar_min_i64(values: &[i64]) -> i64 {
84    let mut m = i64::MAX;
85    for &v in values {
86        if v < m {
87            m = v;
88        }
89    }
90    m
91}
92
93fn scalar_max_i64(values: &[i64]) -> i64 {
94    let mut m = i64::MIN;
95    for &v in values {
96        if v > m {
97            m = v;
98        }
99    }
100    m
101}
102
// ── AVX-512 (x86_64) ──────────────────────────────────────────────

105#[cfg(target_arch = "x86_64")]
106#[target_feature(enable = "avx512f")]
107unsafe fn avx512_sum_i64_inner(values: &[i64]) -> i128 {
108    use std::arch::x86_64::*;
109    unsafe {
110        let mut acc = _mm512_setzero_si512();
111        let chunks = values.len() / 8;
112        let ptr = values.as_ptr();
113        for i in 0..chunks {
114            let v = _mm512_loadu_si512(ptr.add(i * 8).cast());
115            acc = _mm512_add_epi64(acc, v);
116        }
117        // Horizontal sum: extract 8 i64 lanes into scalar.
118        let mut sum: i128 = 0;
119        let mut buf = [0i64; 8];
120        _mm512_storeu_si512(buf.as_mut_ptr().cast(), acc);
121        for &v in &buf {
122            sum += v as i128;
123        }
124        for &v in &values[chunks * 8..] {
125            sum += v as i128;
126        }
127        sum
128    }
129}
130
131#[cfg(target_arch = "x86_64")]
132fn avx512_sum_i64(values: &[i64]) -> i128 {
133    if values.len() < 16 {
134        return scalar_sum_i64(values);
135    }
136    unsafe { avx512_sum_i64_inner(values) }
137}
138
139#[cfg(target_arch = "x86_64")]
140#[target_feature(enable = "avx512f")]
141unsafe fn avx512_min_i64_inner(values: &[i64]) -> i64 {
142    use std::arch::x86_64::*;
143    unsafe {
144        let mut acc = _mm512_set1_epi64(i64::MAX);
145        let chunks = values.len() / 8;
146        let ptr = values.as_ptr();
147        for i in 0..chunks {
148            let v = _mm512_loadu_si512(ptr.add(i * 8).cast());
149            acc = _mm512_min_epi64(acc, v);
150        }
151        let mut buf = [0i64; 8];
152        _mm512_storeu_si512(buf.as_mut_ptr().cast(), acc);
153        let mut m = buf[0];
154        for &v in &buf[1..] {
155            if v < m {
156                m = v;
157            }
158        }
159        for &v in &values[chunks * 8..] {
160            if v < m {
161                m = v;
162            }
163        }
164        m
165    }
166}
167
168#[cfg(target_arch = "x86_64")]
169fn avx512_min_i64(values: &[i64]) -> i64 {
170    if values.len() < 16 {
171        return scalar_min_i64(values);
172    }
173    unsafe { avx512_min_i64_inner(values) }
174}
175
176#[cfg(target_arch = "x86_64")]
177#[target_feature(enable = "avx512f")]
178unsafe fn avx512_max_i64_inner(values: &[i64]) -> i64 {
179    use std::arch::x86_64::*;
180    unsafe {
181        let mut acc = _mm512_set1_epi64(i64::MIN);
182        let chunks = values.len() / 8;
183        let ptr = values.as_ptr();
184        for i in 0..chunks {
185            let v = _mm512_loadu_si512(ptr.add(i * 8).cast());
186            acc = _mm512_max_epi64(acc, v);
187        }
188        let mut buf = [0i64; 8];
189        _mm512_storeu_si512(buf.as_mut_ptr().cast(), acc);
190        let mut m = buf[0];
191        for &v in &buf[1..] {
192            if v > m {
193                m = v;
194            }
195        }
196        for &v in &values[chunks * 8..] {
197            if v > m {
198                m = v;
199            }
200        }
201        m
202    }
203}
204
205#[cfg(target_arch = "x86_64")]
206fn avx512_max_i64(values: &[i64]) -> i64 {
207    if values.len() < 16 {
208        return scalar_max_i64(values);
209    }
210    unsafe { avx512_max_i64_inner(values) }
211}
212
// ── AVX2 (x86_64) ─────────────────────────────────────────────────

215#[cfg(target_arch = "x86_64")]
216#[target_feature(enable = "avx2")]
217unsafe fn avx2_sum_i64_inner(values: &[i64]) -> i128 {
218    use std::arch::x86_64::*;
219    unsafe {
220        let mut acc = _mm256_setzero_si256();
221        let chunks = values.len() / 4;
222        let ptr = values.as_ptr();
223        for i in 0..chunks {
224            let v = _mm256_loadu_si256(ptr.add(i * 4).cast());
225            acc = _mm256_add_epi64(acc, v);
226        }
227        let mut buf = [0i64; 4];
228        _mm256_storeu_si256(buf.as_mut_ptr().cast(), acc);
229        let mut sum: i128 = 0;
230        for &v in &buf {
231            sum += v as i128;
232        }
233        for &v in &values[chunks * 4..] {
234            sum += v as i128;
235        }
236        sum
237    }
238}
239
240#[cfg(target_arch = "x86_64")]
241fn avx2_sum_i64(values: &[i64]) -> i128 {
242    if values.len() < 8 {
243        return scalar_sum_i64(values);
244    }
245    unsafe { avx2_sum_i64_inner(values) }
246}
247
248#[cfg(target_arch = "x86_64")]
249#[target_feature(enable = "avx2")]
250unsafe fn avx2_minmax_i64_inner(values: &[i64], is_min: bool) -> i64 {
251    // AVX2 doesn't have native _mm256_min_epi64, so process 4 at a time
252    // and compare manually with blend.
253    use std::arch::x86_64::*;
254    unsafe {
255        let init_val = if is_min { i64::MAX } else { i64::MIN };
256        let mut acc = _mm256_set1_epi64x(init_val);
257        let chunks = values.len() / 4;
258        let ptr = values.as_ptr();
259        for i in 0..chunks {
260            let v = _mm256_loadu_si256(ptr.add(i * 4).cast());
261            let cmp = _mm256_cmpgt_epi64(acc, v);
262            // If is_min: pick smaller (where acc > v, pick v)
263            // If is_max: pick larger (where acc > v, keep acc)
264            if is_min {
265                acc = _mm256_blendv_epi8(acc, v, cmp);
266            } else {
267                acc = _mm256_blendv_epi8(v, acc, cmp);
268            }
269        }
270        let mut buf = [0i64; 4];
271        _mm256_storeu_si256(buf.as_mut_ptr().cast(), acc);
272        let mut m = buf[0];
273        for &v in &buf[1..] {
274            if is_min {
275                if v < m {
276                    m = v;
277                }
278            } else if v > m {
279                m = v;
280            }
281        }
282        for &v in &values[chunks * 4..] {
283            if is_min {
284                if v < m {
285                    m = v;
286                }
287            } else if v > m {
288                m = v;
289            }
290        }
291        m
292    }
293}
294
295#[cfg(target_arch = "x86_64")]
296fn avx2_min_i64(values: &[i64]) -> i64 {
297    if values.len() < 8 {
298        return scalar_min_i64(values);
299    }
300    unsafe { avx2_minmax_i64_inner(values, true) }
301}
302
303#[cfg(target_arch = "x86_64")]
304fn avx2_max_i64(values: &[i64]) -> i64 {
305    if values.len() < 8 {
306        return scalar_max_i64(values);
307    }
308    unsafe { avx2_minmax_i64_inner(values, false) }
309}
310
// ── NEON (AArch64) ─────────────────────────────────────────────────

313#[cfg(target_arch = "aarch64")]
314fn neon_sum_i64(values: &[i64]) -> i128 {
315    use std::arch::aarch64::*;
316    let chunks = values.len() / 2;
317    let ptr = values.as_ptr();
318    let mut acc = unsafe { vdupq_n_s64(0) };
319    for i in 0..chunks {
320        let v = unsafe { vld1q_s64(ptr.add(i * 2)) };
321        acc = unsafe { vaddq_s64(acc, v) };
322    }
323    let mut buf = [0i64; 2];
324    unsafe { vst1q_s64(buf.as_mut_ptr(), acc) };
325    let mut sum: i128 = buf[0] as i128 + buf[1] as i128;
326    for &v in &values[chunks * 2..] {
327        sum += v as i128;
328    }
329    sum
330}
331
332#[cfg(target_arch = "aarch64")]
333fn neon_min_i64(values: &[i64]) -> i64 {
334    // NEON doesn't have vminq_s64 on all targets; use scalar for correctness.
335    scalar_min_i64(values)
336}
337
338#[cfg(target_arch = "aarch64")]
339fn neon_max_i64(values: &[i64]) -> i64 {
340    scalar_max_i64(values)
341}
342
// ── WASM SIMD128 ───────────────────────────────────────────────────

345#[cfg(target_arch = "wasm32")]
346#[cfg(target_feature = "simd128")]
347fn wasm_sum_i64(values: &[i64]) -> i128 {
348    use std::arch::wasm32::*;
349    let chunks = values.len() / 2;
350    let ptr = values.as_ptr() as *const v128;
351    let mut acc = i64x2_splat(0);
352    for i in 0..chunks {
353        let v = unsafe { v128_load(ptr.add(i)) };
354        acc = i64x2_add(acc, v);
355    }
356    let lo = i64x2_extract_lane::<0>(acc) as i128;
357    let hi = i64x2_extract_lane::<1>(acc) as i128;
358    let mut sum = lo + hi;
359    for &v in &values[chunks * 2..] {
360        sum += v as i128;
361    }
362    sum
363}
364
365#[cfg(target_arch = "wasm32")]
366#[cfg(not(target_feature = "simd128"))]
367fn wasm_sum_i64(values: &[i64]) -> i128 {
368    scalar_sum_i64(values)
369}
370
371#[cfg(target_arch = "wasm32")]
372fn wasm_min_i64(values: &[i64]) -> i64 {
373    scalar_min_i64(values)
374}
375
376#[cfg(target_arch = "wasm32")]
377fn wasm_max_i64(values: &[i64]) -> i64 {
378    scalar_max_i64(values)
379}
380
381#[cfg(test)]
382mod tests {
383    use super::*;
384
385    #[test]
386    fn runtime_detects() {
387        let rt = i64_runtime();
388        assert!(!rt.name.is_empty());
389    }
390
391    #[test]
392    fn sum_correctness() {
393        let rt = i64_runtime();
394        let values: Vec<i64> = (0..1000).collect();
395        let expected: i128 = 999 * 1000 / 2;
396        assert_eq!((rt.sum_i64)(&values), expected);
397    }
398
399    #[test]
400    fn sum_overflow_safe() {
401        let rt = i64_runtime();
402        let values = vec![i64::MAX, i64::MAX, i64::MAX];
403        let result = (rt.sum_i64)(&values);
404        assert_eq!(result, 3 * i64::MAX as i128);
405    }
406
407    #[test]
408    fn min_max_correctness() {
409        let rt = i64_runtime();
410        let values: Vec<i64> = (-500..500).collect();
411        assert_eq!((rt.min_i64)(&values), -500);
412        assert_eq!((rt.max_i64)(&values), 499);
413    }
414
415    #[test]
416    fn empty_input() {
417        let rt = i64_runtime();
418        assert_eq!((rt.sum_i64)(&[]), 0);
419        assert_eq!((rt.min_i64)(&[]), i64::MAX);
420        assert_eq!((rt.max_i64)(&[]), i64::MIN);
421    }
422
423    #[test]
424    fn small_input() {
425        let rt = i64_runtime();
426        assert_eq!((rt.sum_i64)(&[1, 2, 3]), 6);
427        assert_eq!((rt.min_i64)(&[3, 1, 2]), 1);
428        assert_eq!((rt.max_i64)(&[1, 3, 2]), 3);
429    }
430}