// simd_agg_i64.rs — runtime-dispatched SIMD aggregation kernels for i64 columns.

/// Function-pointer table of i64 aggregation kernels, populated by
/// `I64SimdRuntime::detect()` with the fastest implementations the current
/// CPU supports (AVX-512 / AVX2 / NEON / wasm-simd128 / scalar).
#[derive(Clone, Copy, Debug)]
pub struct I64SimdRuntime {
    /// Sum of all elements, widened to i128.
    pub sum_i64: fn(&[i64]) -> i128,
    /// Minimum element; returns i64::MAX for an empty slice.
    pub min_i64: fn(&[i64]) -> i64,
    /// Maximum element; returns i64::MIN for an empty slice.
    pub max_i64: fn(&[i64]) -> i64,
    /// Human-readable backend name, e.g. "avx512", "avx2", "scalar".
    pub name: &'static str,
}
14
impl I64SimdRuntime {
    /// Selects the best available i64 aggregation kernels for the current CPU.
    ///
    /// x86_64 uses runtime feature detection and prefers AVX-512F, then AVX2,
    /// then the scalar fallback. aarch64 and wasm32 builds return their
    /// kernels unconditionally (selection happens at compile time via cfg).
    pub fn detect() -> Self {
        #[cfg(target_arch = "x86_64")]
        {
            // Probe the widest ISA first; detection is a runtime CPUID query.
            if std::is_x86_feature_detected!("avx512f") {
                return Self {
                    sum_i64: avx512_sum_i64,
                    min_i64: avx512_min_i64,
                    max_i64: avx512_max_i64,
                    name: "avx512",
                };
            }
            if std::is_x86_feature_detected!("avx2") {
                return Self {
                    sum_i64: avx2_sum_i64,
                    min_i64: avx2_min_i64,
                    max_i64: avx2_max_i64,
                    name: "avx2",
                };
            }
        }
        #[cfg(target_arch = "aarch64")]
        {
            return Self {
                sum_i64: neon_sum_i64,
                min_i64: neon_min_i64,
                max_i64: neon_max_i64,
                name: "neon",
            };
        }
        #[cfg(target_arch = "wasm32")]
        {
            // NOTE(review): the name claims simd128 even when the build lacks
            // `target_feature = "simd128"` and every wasm kernel falls back to
            // the scalar loops — consider reflecting that in `name`.
            return Self {
                sum_i64: wasm_sum_i64,
                min_i64: wasm_min_i64,
                max_i64: wasm_max_i64,
                name: "wasm-simd128",
            };
        }
        // Reached on x86_64 without AVX2/AVX-512 and on any other target;
        // unreachable on aarch64/wasm32, hence the allow.
        #[allow(unreachable_code)]
        Self {
            sum_i64: scalar_sum_i64,
            min_i64: scalar_min_i64,
            max_i64: scalar_max_i64,
            name: "scalar",
        }
    }
}
63
/// Process-wide kernel table, initialized lazily on first use.
static I64_RUNTIME: std::sync::OnceLock<I64SimdRuntime> = std::sync::OnceLock::new();

/// Returns the shared kernel table, running CPU detection exactly once
/// (thread-safe via `OnceLock`).
pub fn i64_runtime() -> &'static I64SimdRuntime {
    I64_RUNTIME.get_or_init(I64SimdRuntime::detect)
}
70
/// Portable SUM fallback: widens every element to i128 before adding, so the
/// total is exact for any input. Also serves as the short-slice path for the
/// SIMD backends.
fn scalar_sum_i64(values: &[i64]) -> i128 {
    values.iter().map(|&v| i128::from(v)).sum()
}
80
/// Portable MIN fallback. Returns i64::MAX for an empty slice — the identity
/// element for MIN, matching the SIMD kernels' accumulator seed.
fn scalar_min_i64(values: &[i64]) -> i64 {
    values.iter().copied().min().unwrap_or(i64::MAX)
}
90
/// Portable MAX fallback. Returns i64::MIN for an empty slice — the identity
/// element for MAX, matching the SIMD kernels' accumulator seed.
fn scalar_max_i64(values: &[i64]) -> i64 {
    values.iter().copied().max().unwrap_or(i64::MIN)
}
100
/// AVX-512 SUM kernel: accumulates eight i64 lanes per iteration, then
/// reduces the lanes and the scalar tail into an i128.
///
/// # Safety
/// Caller must ensure the CPU supports avx512f (detect() guarantees this).
///
/// NOTE(review): `_mm512_add_epi64` wraps per 64-bit lane, so a partial lane
/// sum outside the i64 range wraps *before* the widening to i128 below and
/// can diverge from `scalar_sum_i64` — confirm inputs are bounded or flush
/// the accumulator periodically.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
unsafe fn avx512_sum_i64_inner(values: &[i64]) -> i128 {
    use std::arch::x86_64::*;
    // SAFETY: all pointer arithmetic below stays within `values`
    // (i < chunks implies i * 8 + 8 <= values.len()).
    unsafe {
        // One running sum per lane.
        let mut acc = _mm512_setzero_si512();
        let chunks = values.len() / 8;
        let ptr = values.as_ptr();
        for i in 0..chunks {
            // Unaligned load: &[i64] carries no 64-byte alignment guarantee.
            let v = _mm512_loadu_si512(ptr.add(i * 8).cast());
            acc = _mm512_add_epi64(acc, v);
        }
        let mut sum: i128 = 0;
        // Spill the lanes and reduce horizontally in scalar code.
        let mut buf = [0i64; 8];
        _mm512_storeu_si512(buf.as_mut_ptr().cast(), acc);
        for &v in &buf {
            sum += v as i128;
        }
        // Tail beyond the last full 8-lane chunk.
        for &v in &values[chunks * 8..] {
            sum += v as i128;
        }
        sum
    }
}
128
129#[cfg(target_arch = "x86_64")]
130fn avx512_sum_i64(values: &[i64]) -> i128 {
131 if values.len() < 16 {
132 return scalar_sum_i64(values);
133 }
134 unsafe { avx512_sum_i64_inner(values) }
135}
136
/// AVX-512 MIN kernel: eight-lane running minimum, then a scalar reduction
/// of the lanes and of the tail that did not fill a whole chunk.
///
/// # Safety
/// Caller must ensure the CPU supports avx512f (detect() guarantees this).
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
unsafe fn avx512_min_i64_inner(values: &[i64]) -> i64 {
    use std::arch::x86_64::*;
    // SAFETY: all pointer arithmetic below stays within `values`
    // (i < chunks implies i * 8 + 8 <= values.len()).
    unsafe {
        // Seed every lane with the identity element for MIN.
        let mut acc = _mm512_set1_epi64(i64::MAX);
        let chunks = values.len() / 8;
        let ptr = values.as_ptr();
        for i in 0..chunks {
            // Unaligned load: &[i64] carries no 64-byte alignment guarantee.
            let v = _mm512_loadu_si512(ptr.add(i * 8).cast());
            acc = _mm512_min_epi64(acc, v);
        }
        // Horizontal reduction of the eight lane minima.
        let mut buf = [0i64; 8];
        _mm512_storeu_si512(buf.as_mut_ptr().cast(), acc);
        let mut m = buf[0];
        for &v in &buf[1..] {
            if v < m {
                m = v;
            }
        }
        // Scalar tail.
        for &v in &values[chunks * 8..] {
            if v < m {
                m = v;
            }
        }
        m
    }
}
165
166#[cfg(target_arch = "x86_64")]
167fn avx512_min_i64(values: &[i64]) -> i64 {
168 if values.len() < 16 {
169 return scalar_min_i64(values);
170 }
171 unsafe { avx512_min_i64_inner(values) }
172}
173
/// AVX-512 MAX kernel: eight-lane running maximum, then a scalar reduction
/// of the lanes and of the tail that did not fill a whole chunk.
///
/// # Safety
/// Caller must ensure the CPU supports avx512f (detect() guarantees this).
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
unsafe fn avx512_max_i64_inner(values: &[i64]) -> i64 {
    use std::arch::x86_64::*;
    // SAFETY: all pointer arithmetic below stays within `values`
    // (i < chunks implies i * 8 + 8 <= values.len()).
    unsafe {
        // Seed every lane with the identity element for MAX.
        let mut acc = _mm512_set1_epi64(i64::MIN);
        let chunks = values.len() / 8;
        let ptr = values.as_ptr();
        for i in 0..chunks {
            // Unaligned load: &[i64] carries no 64-byte alignment guarantee.
            let v = _mm512_loadu_si512(ptr.add(i * 8).cast());
            acc = _mm512_max_epi64(acc, v);
        }
        // Horizontal reduction of the eight lane maxima.
        let mut buf = [0i64; 8];
        _mm512_storeu_si512(buf.as_mut_ptr().cast(), acc);
        let mut m = buf[0];
        for &v in &buf[1..] {
            if v > m {
                m = v;
            }
        }
        // Scalar tail.
        for &v in &values[chunks * 8..] {
            if v > m {
                m = v;
            }
        }
        m
    }
}
202
203#[cfg(target_arch = "x86_64")]
204fn avx512_max_i64(values: &[i64]) -> i64 {
205 if values.len() < 16 {
206 return scalar_max_i64(values);
207 }
208 unsafe { avx512_max_i64_inner(values) }
209}
210
/// AVX2 SUM kernel: four i64 lanes per iteration, reduced together with the
/// scalar tail into an i128.
///
/// # Safety
/// Caller must ensure the CPU supports avx2 (detect() guarantees this).
///
/// NOTE(review): `_mm256_add_epi64` wraps per 64-bit lane, so a partial lane
/// sum outside the i64 range wraps *before* the widening to i128 below and
/// can diverge from `scalar_sum_i64` — confirm inputs are bounded or flush
/// the accumulator periodically.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn avx2_sum_i64_inner(values: &[i64]) -> i128 {
    use std::arch::x86_64::*;
    // SAFETY: all pointer arithmetic below stays within `values`
    // (i < chunks implies i * 4 + 4 <= values.len()).
    unsafe {
        // One running sum per lane.
        let mut acc = _mm256_setzero_si256();
        let chunks = values.len() / 4;
        let ptr = values.as_ptr();
        for i in 0..chunks {
            // Unaligned load: &[i64] carries no 32-byte alignment guarantee.
            let v = _mm256_loadu_si256(ptr.add(i * 4).cast());
            acc = _mm256_add_epi64(acc, v);
        }
        // Spill the lanes and reduce horizontally in scalar code.
        let mut buf = [0i64; 4];
        _mm256_storeu_si256(buf.as_mut_ptr().cast(), acc);
        let mut sum: i128 = 0;
        for &v in &buf {
            sum += v as i128;
        }
        // Tail beyond the last full 4-lane chunk.
        for &v in &values[chunks * 4..] {
            sum += v as i128;
        }
        sum
    }
}
237
238#[cfg(target_arch = "x86_64")]
239fn avx2_sum_i64(values: &[i64]) -> i128 {
240 if values.len() < 8 {
241 return scalar_sum_i64(values);
242 }
243 unsafe { avx2_sum_i64_inner(values) }
244}
245
/// AVX2 combined MIN/MAX kernel (`is_min` selects which). AVX2 has no packed
/// 64-bit integer min/max instruction, so both are emulated with a signed
/// 64-bit compare plus a blend.
///
/// # Safety
/// Caller must ensure the CPU supports avx2 (detect() guarantees this).
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn avx2_minmax_i64_inner(values: &[i64], is_min: bool) -> i64 {
    use std::arch::x86_64::*;
    // SAFETY: all pointer arithmetic below stays within `values`
    // (i < chunks implies i * 4 + 4 <= values.len()).
    unsafe {
        // Seed every lane with the identity element for the chosen op.
        let init_val = if is_min { i64::MAX } else { i64::MIN };
        let mut acc = _mm256_set1_epi64x(init_val);
        let chunks = values.len() / 4;
        let ptr = values.as_ptr();
        for i in 0..chunks {
            let v = _mm256_loadu_si256(ptr.add(i * 4).cast());
            // cmp lane = all-ones where acc > v, all-zeros otherwise, so the
            // byte-granular blendv below still switches whole 64-bit lanes.
            let cmp = _mm256_cmpgt_epi64(acc, v);
            if is_min {
                // Take v where acc > v (keep the smaller lane).
                acc = _mm256_blendv_epi8(acc, v, cmp);
            } else {
                // Take acc where acc > v (keep the larger lane).
                acc = _mm256_blendv_epi8(v, acc, cmp);
            }
        }
        // Horizontal reduction of the four lanes.
        let mut buf = [0i64; 4];
        _mm256_storeu_si256(buf.as_mut_ptr().cast(), acc);
        let mut m = buf[0];
        for &v in &buf[1..] {
            if is_min {
                if v < m {
                    m = v;
                }
            } else if v > m {
                m = v;
            }
        }
        // Scalar tail beyond the last full 4-lane chunk.
        for &v in &values[chunks * 4..] {
            if is_min {
                if v < m {
                    m = v;
                }
            } else if v > m {
                m = v;
            }
        }
        m
    }
}
292
293#[cfg(target_arch = "x86_64")]
294fn avx2_min_i64(values: &[i64]) -> i64 {
295 if values.len() < 8 {
296 return scalar_min_i64(values);
297 }
298 unsafe { avx2_minmax_i64_inner(values, true) }
299}
300
301#[cfg(target_arch = "x86_64")]
302fn avx2_max_i64(values: &[i64]) -> i64 {
303 if values.len() < 8 {
304 return scalar_max_i64(values);
305 }
306 unsafe { avx2_minmax_i64_inner(values, false) }
307}
308
/// NEON SUM kernel: two i64 lanes per iteration plus a scalar tail.
/// No length threshold — NEON (AdvSIMD) is baseline on aarch64, so this is
/// installed unconditionally by detect().
///
/// NOTE(review): `vaddq_s64` wraps per 64-bit lane, so a partial lane sum
/// outside the i64 range wraps *before* the final widening to i128 and can
/// diverge from `scalar_sum_i64` — confirm inputs are bounded.
#[cfg(target_arch = "aarch64")]
fn neon_sum_i64(values: &[i64]) -> i128 {
    use std::arch::aarch64::*;
    let chunks = values.len() / 2;
    let ptr = values.as_ptr();
    // SAFETY: NEON is mandatory on aarch64; vdupq_n_s64 touches no memory.
    let mut acc = unsafe { vdupq_n_s64(0) };
    for i in 0..chunks {
        // SAFETY: i < chunks, so ptr.add(i * 2) reads two in-bounds i64s.
        let v = unsafe { vld1q_s64(ptr.add(i * 2)) };
        // SAFETY: register-only lane add.
        acc = unsafe { vaddq_s64(acc, v) };
    }
    let mut buf = [0i64; 2];
    // SAFETY: buf holds exactly the two lanes vst1q_s64 writes.
    unsafe { vst1q_s64(buf.as_mut_ptr(), acc) };
    let mut sum: i128 = buf[0] as i128 + buf[1] as i128;
    for &v in &values[chunks * 2..] {
        sum += v as i128;
    }
    sum
}
329
/// NEON MIN: AArch64 NEON has no packed 64-bit integer min instruction
/// (vmin covers only 8/16/32-bit lanes), so delegate to the scalar loop.
#[cfg(target_arch = "aarch64")]
fn neon_min_i64(values: &[i64]) -> i64 {
    scalar_min_i64(values)
}
335
/// NEON MAX: AArch64 NEON has no packed 64-bit integer max instruction
/// (vmax covers only 8/16/32-bit lanes), so delegate to the scalar loop.
#[cfg(target_arch = "aarch64")]
fn neon_max_i64(values: &[i64]) -> i64 {
    scalar_max_i64(values)
}
340
/// wasm SIMD128 SUM kernel: two i64 lanes per iteration plus a scalar tail.
/// Compiled only when the module is built with `target_feature = "simd128"`
/// (compile-time cfg, unlike the runtime detection used on x86_64).
///
/// NOTE(review): `i64x2_add` wraps per 64-bit lane, so a partial lane sum
/// outside the i64 range wraps *before* the widening to i128 and can
/// diverge from `scalar_sum_i64` — confirm inputs are bounded.
#[cfg(target_arch = "wasm32")]
#[cfg(target_feature = "simd128")]
fn wasm_sum_i64(values: &[i64]) -> i128 {
    use std::arch::wasm32::*;
    let chunks = values.len() / 2;
    let ptr = values.as_ptr() as *const v128;
    let mut acc = i64x2_splat(0);
    for i in 0..chunks {
        // SAFETY: i < chunks, so ptr.add(i) reads 16 in-bounds bytes;
        // v128_load permits unaligned access on wasm.
        let v = unsafe { v128_load(ptr.add(i)) };
        acc = i64x2_add(acc, v);
    }
    // Horizontal reduction: widen the two lanes and add the tail.
    let lo = i64x2_extract_lane::<0>(acc) as i128;
    let hi = i64x2_extract_lane::<1>(acc) as i128;
    let mut sum = lo + hi;
    for &v in &values[chunks * 2..] {
        sum += v as i128;
    }
    sum
}
362
/// SUM fallback for wasm builds compiled without simd128: plain scalar loop.
#[cfg(target_arch = "wasm32")]
#[cfg(not(target_feature = "simd128"))]
fn wasm_sum_i64(values: &[i64]) -> i128 {
    scalar_sum_i64(values)
}
368
/// wasm MIN: simd128 provides no 64-bit lane integer min (only 8/16/32-bit),
/// so always delegate to the scalar loop.
#[cfg(target_arch = "wasm32")]
fn wasm_min_i64(values: &[i64]) -> i64 {
    scalar_min_i64(values)
}
373
/// wasm MAX: simd128 provides no 64-bit lane integer max (only 8/16/32-bit),
/// so always delegate to the scalar loop.
#[cfg(target_arch = "wasm32")]
fn wasm_max_i64(values: &[i64]) -> i64 {
    scalar_max_i64(values)
}
378
#[cfg(test)]
mod tests {
    use super::*;

    // Detection must always land on some backend with a non-empty name.
    #[test]
    fn runtime_detects() {
        let rt = i64_runtime();
        assert!(!rt.name.is_empty());
    }

    // 1000 elements exceeds every dispatch threshold (16 for AVX-512, 8 for
    // AVX2), so this exercises the vectorized path where available.
    #[test]
    fn sum_correctness() {
        let rt = i64_runtime();
        let values: Vec<i64> = (0..1000).collect();
        let expected: i128 = 999 * 1000 / 2;
        assert_eq!((rt.sum_i64)(&values), expected);
    }

    // NOTE(review): three elements is below every SIMD threshold, so this
    // only proves the scalar path widens to i128; the SIMD kernels still
    // accumulate in 64-bit lanes and would wrap on a long run of such values.
    #[test]
    fn sum_overflow_safe() {
        let rt = i64_runtime();
        let values = vec![i64::MAX, i64::MAX, i64::MAX];
        let result = (rt.sum_i64)(&values);
        assert_eq!(result, 3 * i64::MAX as i128);
    }

    // 1000 elements: exercises both the vector lanes and the scalar tail.
    #[test]
    fn min_max_correctness() {
        let rt = i64_runtime();
        let values: Vec<i64> = (-500..500).collect();
        assert_eq!((rt.min_i64)(&values), -500);
        assert_eq!((rt.max_i64)(&values), 499);
    }

    // Empty input returns the identity element of each aggregate.
    #[test]
    fn empty_input() {
        let rt = i64_runtime();
        assert_eq!((rt.sum_i64)(&[]), 0);
        assert_eq!((rt.min_i64)(&[]), i64::MAX);
        assert_eq!((rt.max_i64)(&[]), i64::MIN);
    }

    // Inputs below the SIMD thresholds route through the scalar fallbacks.
    #[test]
    fn small_input() {
        let rt = i64_runtime();
        assert_eq!((rt.sum_i64)(&[1, 2, 3]), 6);
        assert_eq!((rt.min_i64)(&[3, 1, 2]), 1);
        assert_eq!((rt.max_i64)(&[1, 3, 2]), 3);
    }
}