// nodedb_query/simd_agg_i64.rs
1//! SIMD-accelerated aggregation kernels for i64 columns.
2//!
3//! Mirrors the f64 dispatch in `simd_agg.rs`. Uses i128 accumulator for
4//! overflow-safe sum. Same runtime detection: AVX-512 → AVX2 → NEON → scalar.
5
/// SIMD runtime for i64 aggregation.
///
/// Holds function pointers to the kernels selected by
/// [`I64SimdRuntime::detect`]; mirrors the f64 runtime in `simd_agg.rs`.
pub struct I64SimdRuntime {
    /// Sum with i128 accumulator (overflow-safe).
    pub sum_i64: fn(&[i64]) -> i128,
    /// Minimum; returns `i64::MAX` for an empty slice.
    pub min_i64: fn(&[i64]) -> i64,
    /// Maximum; returns `i64::MIN` for an empty slice.
    pub max_i64: fn(&[i64]) -> i64,
    /// Name of the selected backend (e.g. "avx2", "scalar"), for diagnostics.
    pub name: &'static str,
}
14
15impl I64SimdRuntime {
16    pub fn detect() -> Self {
17        #[cfg(target_arch = "x86_64")]
18        {
19            if std::is_x86_feature_detected!("avx512f") {
20                return Self {
21                    sum_i64: avx512_sum_i64,
22                    min_i64: avx512_min_i64,
23                    max_i64: avx512_max_i64,
24                    name: "avx512",
25                };
26            }
27            if std::is_x86_feature_detected!("avx2") {
28                return Self {
29                    sum_i64: avx2_sum_i64,
30                    min_i64: avx2_min_i64,
31                    max_i64: avx2_max_i64,
32                    name: "avx2",
33                };
34            }
35        }
36        #[cfg(target_arch = "aarch64")]
37        {
38            return Self {
39                sum_i64: neon_sum_i64,
40                min_i64: neon_min_i64,
41                max_i64: neon_max_i64,
42                name: "neon",
43            };
44        }
45        #[cfg(target_arch = "wasm32")]
46        {
47            return Self {
48                sum_i64: wasm_sum_i64,
49                min_i64: wasm_min_i64,
50                max_i64: wasm_max_i64,
51                name: "wasm-simd128",
52            };
53        }
54        #[allow(unreachable_code)]
55        Self {
56            sum_i64: scalar_sum_i64,
57            min_i64: scalar_min_i64,
58            max_i64: scalar_max_i64,
59            name: "scalar",
60        }
61    }
62}
63
// Lazily-initialized global runtime; detection runs at most once.
static I64_RUNTIME: std::sync::OnceLock<I64SimdRuntime> = std::sync::OnceLock::new();

/// Get the global i64 SIMD runtime.
///
/// The first call runs [`I64SimdRuntime::detect`]; subsequent calls return
/// the cached result.
pub fn i64_runtime() -> &'static I64SimdRuntime {
    I64_RUNTIME.get_or_init(I64SimdRuntime::detect)
}
70
71// ── Scalar fallback ────────────────────────────────────────────────
72
/// Portable sum; widens every element to i128 so the total cannot overflow.
fn scalar_sum_i64(values: &[i64]) -> i128 {
    values.iter().map(|&v| i128::from(v)).sum()
}
80
/// Portable minimum; yields `i64::MAX` for an empty slice.
fn scalar_min_i64(values: &[i64]) -> i64 {
    values.iter().copied().min().unwrap_or(i64::MAX)
}
90
/// Portable maximum; yields `i64::MIN` for an empty slice.
fn scalar_max_i64(values: &[i64]) -> i64 {
    values.iter().copied().max().unwrap_or(i64::MIN)
}
100
101// ── AVX-512 (x86_64) ──────────────────────────────────────────────
102
/// AVX-512 sum kernel: eight i64 lanes per iteration, scalar tail.
///
/// NOTE(review): lane accumulation uses wrapping i64 adds
/// (`_mm512_add_epi64`) and only widens to i128 after the loop, so a
/// per-lane running sum exceeding i64 range wraps silently. That can
/// violate the module's "overflow-safe" claim for large-magnitude inputs
/// (the existing overflow test uses only 3 elements, which takes the
/// scalar path) — confirm whether callers bound their inputs, or widen
/// inside the loop.
///
/// # Safety
/// CPU must support AVX-512F; `detect()` only installs the wrapper after
/// `is_x86_feature_detected!("avx512f")` succeeds.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
unsafe fn avx512_sum_i64_inner(values: &[i64]) -> i128 {
    use std::arch::x86_64::*;
    // SAFETY: loads stay in bounds because chunks * 8 <= values.len(),
    // and the unaligned loadu/storeu forms tolerate any address.
    unsafe {
        let mut acc = _mm512_setzero_si512();
        let chunks = values.len() / 8;
        let ptr = values.as_ptr();
        for i in 0..chunks {
            let v = _mm512_loadu_si512(ptr.add(i * 8).cast());
            acc = _mm512_add_epi64(acc, v);
        }
        // Horizontal sum: extract 8 i64 lanes into scalar.
        let mut sum: i128 = 0;
        let mut buf = [0i64; 8];
        _mm512_storeu_si512(buf.as_mut_ptr().cast(), acc);
        for &v in &buf {
            sum += v as i128;
        }
        // Fold in the tail (fewer than 8 remaining elements).
        for &v in &values[chunks * 8..] {
            sum += v as i128;
        }
        sum
    }
}
128
129#[cfg(target_arch = "x86_64")]
130fn avx512_sum_i64(values: &[i64]) -> i128 {
131    if values.len() < 16 {
132        return scalar_sum_i64(values);
133    }
134    unsafe { avx512_sum_i64_inner(values) }
135}
136
/// AVX-512 min kernel: eight i64 lanes per iteration, scalar tail.
///
/// The accumulator starts at `i64::MAX` (the identity for min), so lanes
/// that never see data do not affect the result.
///
/// # Safety
/// CPU must support AVX-512F; `detect()` only installs the wrapper after
/// `is_x86_feature_detected!("avx512f")` succeeds.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
unsafe fn avx512_min_i64_inner(values: &[i64]) -> i64 {
    use std::arch::x86_64::*;
    // SAFETY: loads stay in bounds because chunks * 8 <= values.len(),
    // and the unaligned loadu/storeu forms tolerate any address.
    unsafe {
        let mut acc = _mm512_set1_epi64(i64::MAX);
        let chunks = values.len() / 8;
        let ptr = values.as_ptr();
        for i in 0..chunks {
            let v = _mm512_loadu_si512(ptr.add(i * 8).cast());
            acc = _mm512_min_epi64(acc, v);
        }
        // Horizontal reduce the eight lanes.
        let mut buf = [0i64; 8];
        _mm512_storeu_si512(buf.as_mut_ptr().cast(), acc);
        let mut m = buf[0];
        for &v in &buf[1..] {
            if v < m {
                m = v;
            }
        }
        // Fold in the tail (fewer than 8 remaining elements).
        for &v in &values[chunks * 8..] {
            if v < m {
                m = v;
            }
        }
        m
    }
}
165
166#[cfg(target_arch = "x86_64")]
167fn avx512_min_i64(values: &[i64]) -> i64 {
168    if values.len() < 16 {
169        return scalar_min_i64(values);
170    }
171    unsafe { avx512_min_i64_inner(values) }
172}
173
/// AVX-512 max kernel: eight i64 lanes per iteration, scalar tail.
///
/// The accumulator starts at `i64::MIN` (the identity for max), so lanes
/// that never see data do not affect the result.
///
/// # Safety
/// CPU must support AVX-512F; `detect()` only installs the wrapper after
/// `is_x86_feature_detected!("avx512f")` succeeds.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
unsafe fn avx512_max_i64_inner(values: &[i64]) -> i64 {
    use std::arch::x86_64::*;
    // SAFETY: loads stay in bounds because chunks * 8 <= values.len(),
    // and the unaligned loadu/storeu forms tolerate any address.
    unsafe {
        let mut acc = _mm512_set1_epi64(i64::MIN);
        let chunks = values.len() / 8;
        let ptr = values.as_ptr();
        for i in 0..chunks {
            let v = _mm512_loadu_si512(ptr.add(i * 8).cast());
            acc = _mm512_max_epi64(acc, v);
        }
        // Horizontal reduce the eight lanes.
        let mut buf = [0i64; 8];
        _mm512_storeu_si512(buf.as_mut_ptr().cast(), acc);
        let mut m = buf[0];
        for &v in &buf[1..] {
            if v > m {
                m = v;
            }
        }
        // Fold in the tail (fewer than 8 remaining elements).
        for &v in &values[chunks * 8..] {
            if v > m {
                m = v;
            }
        }
        m
    }
}
202
203#[cfg(target_arch = "x86_64")]
204fn avx512_max_i64(values: &[i64]) -> i64 {
205    if values.len() < 16 {
206        return scalar_max_i64(values);
207    }
208    unsafe { avx512_max_i64_inner(values) }
209}
210
211// ── AVX2 (x86_64) ─────────────────────────────────────────────────
212
/// AVX2 sum kernel: four i64 lanes per iteration, scalar tail.
///
/// NOTE(review): lane accumulation uses wrapping i64 adds
/// (`_mm256_add_epi64`) and only widens to i128 after the loop, so a
/// per-lane running sum exceeding i64 range wraps silently — same
/// potential conflict with the module's "overflow-safe" claim as the
/// AVX-512 kernel. Confirm input bounds or widen inside the loop.
///
/// # Safety
/// CPU must support AVX2; `detect()` only installs the wrapper after
/// `is_x86_feature_detected!("avx2")` succeeds.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn avx2_sum_i64_inner(values: &[i64]) -> i128 {
    use std::arch::x86_64::*;
    // SAFETY: loads stay in bounds because chunks * 4 <= values.len(),
    // and the unaligned loadu/storeu forms tolerate any address.
    unsafe {
        let mut acc = _mm256_setzero_si256();
        let chunks = values.len() / 4;
        let ptr = values.as_ptr();
        for i in 0..chunks {
            let v = _mm256_loadu_si256(ptr.add(i * 4).cast());
            acc = _mm256_add_epi64(acc, v);
        }
        // Horizontal sum: spill the four lanes and widen to i128.
        let mut buf = [0i64; 4];
        _mm256_storeu_si256(buf.as_mut_ptr().cast(), acc);
        let mut sum: i128 = 0;
        for &v in &buf {
            sum += v as i128;
        }
        // Fold in the tail (fewer than 4 remaining elements).
        for &v in &values[chunks * 4..] {
            sum += v as i128;
        }
        sum
    }
}
237
238#[cfg(target_arch = "x86_64")]
239fn avx2_sum_i64(values: &[i64]) -> i128 {
240    if values.len() < 8 {
241        return scalar_sum_i64(values);
242    }
243    unsafe { avx2_sum_i64_inner(values) }
244}
245
/// AVX2 shared min/max kernel over four i64 lanes (`is_min` selects which).
///
/// AVX2 has no `_mm256_min_epi64`/`_mm256_max_epi64`, so each step does a
/// signed 64-bit compare plus a byte-wise blend. `_mm256_cmpgt_epi64`
/// yields an all-ones lane where `acc > v` and all-zeros otherwise;
/// `_mm256_blendv_epi8` picks its second operand wherever a mask byte's
/// high bit is set. Because the lane masks are uniformly all-ones or
/// all-zeros, the byte granularity of the blend is safe here.
///
/// # Safety
/// CPU must support AVX2; `detect()` only installs the wrappers after
/// `is_x86_feature_detected!("avx2")` succeeds.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn avx2_minmax_i64_inner(values: &[i64], is_min: bool) -> i64 {
    // AVX2 doesn't have native _mm256_min_epi64, so process 4 at a time
    // and compare manually with blend.
    use std::arch::x86_64::*;
    // SAFETY: loads stay in bounds because chunks * 4 <= values.len(),
    // and the unaligned loadu/storeu forms tolerate any address.
    unsafe {
        // Identity element: MAX for a min-reduction, MIN for a max-reduction.
        let init_val = if is_min { i64::MAX } else { i64::MIN };
        let mut acc = _mm256_set1_epi64x(init_val);
        let chunks = values.len() / 4;
        let ptr = values.as_ptr();
        for i in 0..chunks {
            let v = _mm256_loadu_si256(ptr.add(i * 4).cast());
            let cmp = _mm256_cmpgt_epi64(acc, v);
            // If is_min: pick smaller (where acc > v, pick v)
            // If is_max: pick larger (where acc > v, keep acc)
            if is_min {
                acc = _mm256_blendv_epi8(acc, v, cmp);
            } else {
                acc = _mm256_blendv_epi8(v, acc, cmp);
            }
        }
        // Horizontal reduce the four lanes with plain scalar compares.
        let mut buf = [0i64; 4];
        _mm256_storeu_si256(buf.as_mut_ptr().cast(), acc);
        let mut m = buf[0];
        for &v in &buf[1..] {
            if is_min {
                if v < m {
                    m = v;
                }
            } else if v > m {
                m = v;
            }
        }
        // Fold in the tail (fewer than 4 remaining elements).
        for &v in &values[chunks * 4..] {
            if is_min {
                if v < m {
                    m = v;
                }
            } else if v > m {
                m = v;
            }
        }
        m
    }
}
292
293#[cfg(target_arch = "x86_64")]
294fn avx2_min_i64(values: &[i64]) -> i64 {
295    if values.len() < 8 {
296        return scalar_min_i64(values);
297    }
298    unsafe { avx2_minmax_i64_inner(values, true) }
299}
300
301#[cfg(target_arch = "x86_64")]
302fn avx2_max_i64(values: &[i64]) -> i64 {
303    if values.len() < 8 {
304        return scalar_max_i64(values);
305    }
306    unsafe { avx2_minmax_i64_inner(values, false) }
307}
308
309// ── NEON (AArch64) ─────────────────────────────────────────────────
310
/// NEON sum kernel: two i64 lanes per iteration, scalar tail.
///
/// No length threshold here: with only two lanes the setup cost is tiny,
/// and `chunks` is 0 for slices shorter than 2.
///
/// NOTE(review): like the x86 kernels, lanes accumulate with wrapping
/// i64 adds (`vaddq_s64`) and widen to i128 only at the end, so extreme
/// inputs can wrap before widening — confirm against the module's
/// "overflow-safe" contract.
#[cfg(target_arch = "aarch64")]
fn neon_sum_i64(values: &[i64]) -> i128 {
    use std::arch::aarch64::*;
    let chunks = values.len() / 2;
    let ptr = values.as_ptr();
    // SAFETY: NEON intrinsics are available on aarch64 by default; each
    // vld1q_s64 reads elements i*2..i*2+2, in bounds since
    // chunks * 2 <= values.len().
    let mut acc = unsafe { vdupq_n_s64(0) };
    for i in 0..chunks {
        let v = unsafe { vld1q_s64(ptr.add(i * 2)) };
        acc = unsafe { vaddq_s64(acc, v) };
    }
    let mut buf = [0i64; 2];
    unsafe { vst1q_s64(buf.as_mut_ptr(), acc) };
    // Horizontal sum of the two lanes, widened to i128.
    let mut sum: i128 = buf[0] as i128 + buf[1] as i128;
    // Fold in the tail (at most 1 remaining element).
    for &v in &values[chunks * 2..] {
        sum += v as i128;
    }
    sum
}
329
/// NEON min: delegates to the portable scalar path.
///
/// Base NEON has no 64-bit integer min/max instructions (vminq/vmaxq stop
/// at 32-bit lanes), so a vector path would need compare+select; scalar
/// is kept for simplicity and correctness.
#[cfg(target_arch = "aarch64")]
fn neon_min_i64(values: &[i64]) -> i64 {
    scalar_min_i64(values)
}
335
/// NEON max: delegates to the portable scalar path (see `neon_min_i64` —
/// base NEON lacks 64-bit integer min/max instructions).
#[cfg(target_arch = "aarch64")]
fn neon_max_i64(values: &[i64]) -> i64 {
    scalar_max_i64(values)
}
340
341// ── WASM SIMD128 ───────────────────────────────────────────────────
342
/// wasm SIMD128 sum kernel: two i64 lanes per iteration, scalar tail.
/// Compiled only when the `simd128` target feature is enabled.
///
/// NOTE(review): lanes accumulate with wrapping i64 adds (`i64x2_add`)
/// and widen to i128 only at the end — same potential conflict with the
/// module's "overflow-safe" contract as the other SIMD kernels.
#[cfg(target_arch = "wasm32")]
#[cfg(target_feature = "simd128")]
fn wasm_sum_i64(values: &[i64]) -> i128 {
    use std::arch::wasm32::*;
    let chunks = values.len() / 2;
    // Pointer arithmetic below is in units of v128 (16 bytes = 2 lanes).
    let ptr = values.as_ptr() as *const v128;
    let mut acc = i64x2_splat(0);
    for i in 0..chunks {
        // SAFETY: chunk i covers elements i*2..i*2+2, in bounds since
        // chunks * 2 <= values.len(); wasm loads accept any alignment.
        let v = unsafe { v128_load(ptr.add(i)) };
        acc = i64x2_add(acc, v);
    }
    // Horizontal sum of the two lanes, widened to i128.
    let lo = i64x2_extract_lane::<0>(acc) as i128;
    let hi = i64x2_extract_lane::<1>(acc) as i128;
    let mut sum = lo + hi;
    // Fold in the tail (at most 1 remaining element).
    for &v in &values[chunks * 2..] {
        sum += v as i128;
    }
    sum
}
362
/// Scalar stand-in used when the build lacks the `simd128` target feature.
#[cfg(target_arch = "wasm32")]
#[cfg(not(target_feature = "simd128"))]
fn wasm_sum_i64(values: &[i64]) -> i128 {
    // Same contract as scalar_sum_i64: widen each element, then sum.
    values.iter().map(|&v| i128::from(v)).sum()
}
368
/// wasm min: simd128 has no i64x2 min/max ops, so use the portable
/// scalar reduction (same result as `scalar_min_i64`).
#[cfg(target_arch = "wasm32")]
fn wasm_min_i64(values: &[i64]) -> i64 {
    values.iter().copied().min().unwrap_or(i64::MAX)
}
373
/// wasm max: simd128 has no i64x2 min/max ops, so use the portable
/// scalar reduction (same result as `scalar_max_i64`).
#[cfg(target_arch = "wasm32")]
fn wasm_max_i64(values: &[i64]) -> i64 {
    values.iter().copied().max().unwrap_or(i64::MIN)
}
378
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn runtime_detects() {
        let rt = i64_runtime();
        assert!(!rt.name.is_empty());
    }

    #[test]
    fn sum_correctness() {
        let rt = i64_runtime();
        let values: Vec<i64> = (0..1000).collect();
        let expected: i128 = 999 * 1000 / 2;
        assert_eq!((rt.sum_i64)(&values), expected);
    }

    // New: previous sum coverage was non-negative only.
    #[test]
    fn sum_negative_values() {
        let rt = i64_runtime();
        // -1000..=999: every value pairs off except the unmatched -1000.
        let values: Vec<i64> = (-1000..1000).collect();
        assert_eq!((rt.sum_i64)(&values), -1000);
    }

    #[test]
    fn sum_overflow_safe() {
        let rt = i64_runtime();
        // NOTE(review): len 3 falls below every SIMD threshold, so this
        // exercises only the scalar path; the SIMD kernels' lane
        // accumulators are not covered by this test.
        let values = vec![i64::MAX, i64::MAX, i64::MAX];
        let result = (rt.sum_i64)(&values);
        assert_eq!(result, 3 * i64::MAX as i128);
    }

    #[test]
    fn min_max_correctness() {
        let rt = i64_runtime();
        let values: Vec<i64> = (-500..500).collect();
        assert_eq!((rt.min_i64)(&values), -500);
        assert_eq!((rt.max_i64)(&values), 499);
    }

    // New: repeated extremes and unsorted input.
    #[test]
    fn min_max_with_duplicates() {
        let rt = i64_runtime();
        let values = vec![5, -2, 5, -2, 0, 0, 7, -2];
        assert_eq!((rt.min_i64)(&values), -2);
        assert_eq!((rt.max_i64)(&values), 7);
    }

    #[test]
    fn empty_input() {
        let rt = i64_runtime();
        assert_eq!((rt.sum_i64)(&[]), 0);
        assert_eq!((rt.min_i64)(&[]), i64::MAX);
        assert_eq!((rt.max_i64)(&[]), i64::MIN);
    }

    #[test]
    fn small_input() {
        let rt = i64_runtime();
        assert_eq!((rt.sum_i64)(&[1, 2, 3]), 6);
        assert_eq!((rt.min_i64)(&[3, 1, 2]), 1);
        assert_eq!((rt.max_i64)(&[1, 3, 2]), 3);
    }
}