// nodedb_query/simd_agg_i64.rs
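
//! Runtime-dispatched SIMD aggregation kernels for `i64` columns.
//!
//! `I64SimdRuntime::detect()` probes the CPU once and fills a table of plain
//! function pointers (AVX-512, AVX2, NEON, or wasm simd128), with scalar loops
//! as the portable fallback. The x86 kernels hand short slices back to the
//! scalar loops, and all sums are widened to `i128` in the final reduction.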

pub struct I64SimdRuntime {
    /// Sum, widened to `i128` in the final reduction.
    pub sum_i64: fn(&[i64]) -> i128,
    pub min_i64: fn(&[i64]) -> i64,
    pub max_i64: fn(&[i64]) -> i64,
    /// Implementation label, e.g. for logging or test diagnostics.
    pub name: &'static str,
}

impl I64SimdRuntime {
    /// Detects the best implementation available on the current CPU,
    /// checking the widest vector width first.
    pub fn detect() -> Self {
        #[cfg(target_arch = "x86_64")]
        {
            if std::is_x86_feature_detected!("avx512f") {
                return Self {
                    sum_i64: avx512_sum_i64,
                    min_i64: avx512_min_i64,
                    max_i64: avx512_max_i64,
                    name: "avx512",
                };
            }
            if std::is_x86_feature_detected!("avx2") {
                return Self {
                    sum_i64: avx2_sum_i64,
                    min_i64: avx2_min_i64,
                    max_i64: avx2_max_i64,
                    name: "avx2",
                };
            }
        }
        #[cfg(target_arch = "aarch64")]
        {
            // NEON is baseline on aarch64, so no runtime check is needed.
            return Self {
                sum_i64: neon_sum_i64,
                min_i64: neon_min_i64,
                max_i64: neon_max_i64,
                name: "neon",
            };
        }
        #[cfg(target_arch = "wasm32")]
        {
            // simd128 support is resolved at compile time inside the wasm_*
            // functions.
            return Self {
                sum_i64: wasm_sum_i64,
                min_i64: wasm_min_i64,
                max_i64: wasm_max_i64,
                name: "wasm-simd128",
            };
        }
        // Reached on x86_64 without AVX2 and on any other architecture; the
        // lint allow covers aarch64/wasm32, where a block above always returns.
        #[allow(unreachable_code)]
        Self {
            sum_i64: scalar_sum_i64,
            min_i64: scalar_min_i64,
            max_i64: scalar_max_i64,
            name: "scalar",
        }
    }
}

static I64_RUNTIME: std::sync::OnceLock<I64SimdRuntime> = std::sync::OnceLock::new();

/// Returns the shared runtime, running feature detection once per process.
pub fn i64_runtime() -> &'static I64SimdRuntime {
    I64_RUNTIME.get_or_init(I64SimdRuntime::detect)
}
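
// Illustrative usage (the `column_stats` helper is hypothetical, shown only to
// demonstrate the dispatch): detection runs once on first use, after which
// each aggregate is a single indirect call through the cached table.
#[allow(dead_code)]
fn column_stats(batch: &[i64]) -> (i128, i64, i64) {
    let rt = i64_runtime();
    ((rt.sum_i64)(batch), (rt.min_i64)(batch), (rt.max_i64)(batch))
}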

// Scalar reference implementations: used directly for short inputs and as the
// portable fallback. On an empty slice, min/max return their fold identities
// (i64::MAX / i64::MIN), which the tests below rely on.
fn scalar_sum_i64(values: &[i64]) -> i128 {
    let mut sum: i128 = 0;
    for &v in values {
        sum += v as i128;
    }
    sum
}

fn scalar_min_i64(values: &[i64]) -> i64 {
    let mut m = i64::MAX;
    for &v in values {
        if v < m {
            m = v;
        }
    }
    m
}

fn scalar_max_i64(values: &[i64]) -> i64 {
    let mut m = i64::MIN;
    for &v in values {
        if v > m {
            m = v;
        }
    }
    m
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
unsafe fn avx512_sum_i64_inner(values: &[i64]) -> i128 {
    use std::arch::x86_64::*;
    unsafe {
        // Accumulate eight i64 lanes per iteration. Lanes are added in i64;
        // only the final horizontal reduction is widened to i128.
        let mut acc = _mm512_setzero_si512();
        let chunks = values.len() / 8;
        let ptr = values.as_ptr();
        for i in 0..chunks {
            let v = _mm512_loadu_si512(ptr.add(i * 8).cast());
            acc = _mm512_add_epi64(acc, v);
        }
        // Horizontal reduction: spill the lanes and sum them, then fold in
        // the scalar tail that did not fill a whole vector.
        let mut sum: i128 = 0;
        let mut buf = [0i64; 8];
        _mm512_storeu_si512(buf.as_mut_ptr().cast(), acc);
        for &v in &buf {
            sum += v as i128;
        }
        for &v in &values[chunks * 8..] {
            sum += v as i128;
        }
        sum
    }
}

#[cfg(target_arch = "x86_64")]
fn avx512_sum_i64(values: &[i64]) -> i128 {
    // Short slices (under two full vectors) go through the scalar path.
    if values.len() < 16 {
        return scalar_sum_i64(values);
    }
    // Safety: detect() installs this pointer only when avx512f is present.
    unsafe { avx512_sum_i64_inner(values) }
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
unsafe fn avx512_min_i64_inner(values: &[i64]) -> i64 {
    use std::arch::x86_64::*;
    unsafe {
        let mut acc = _mm512_set1_epi64(i64::MAX);
        let chunks = values.len() / 8;
        let ptr = values.as_ptr();
        for i in 0..chunks {
            let v = _mm512_loadu_si512(ptr.add(i * 8).cast());
            acc = _mm512_min_epi64(acc, v);
        }
        let mut buf = [0i64; 8];
        _mm512_storeu_si512(buf.as_mut_ptr().cast(), acc);
        let mut m = buf[0];
        for &v in &buf[1..] {
            if v < m {
                m = v;
            }
        }
        for &v in &values[chunks * 8..] {
            if v < m {
                m = v;
            }
        }
        m
    }
}

#[cfg(target_arch = "x86_64")]
fn avx512_min_i64(values: &[i64]) -> i64 {
    if values.len() < 16 {
        return scalar_min_i64(values);
    }
    unsafe { avx512_min_i64_inner(values) }
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
unsafe fn avx512_max_i64_inner(values: &[i64]) -> i64 {
    use std::arch::x86_64::*;
    unsafe {
        let mut acc = _mm512_set1_epi64(i64::MIN);
        let chunks = values.len() / 8;
        let ptr = values.as_ptr();
        for i in 0..chunks {
            let v = _mm512_loadu_si512(ptr.add(i * 8).cast());
            acc = _mm512_max_epi64(acc, v);
        }
        let mut buf = [0i64; 8];
        _mm512_storeu_si512(buf.as_mut_ptr().cast(), acc);
        let mut m = buf[0];
        for &v in &buf[1..] {
            if v > m {
                m = v;
            }
        }
        for &v in &values[chunks * 8..] {
            if v > m {
                m = v;
            }
        }
        m
    }
}

#[cfg(target_arch = "x86_64")]
fn avx512_max_i64(values: &[i64]) -> i64 {
    if values.len() < 16 {
        return scalar_max_i64(values);
    }
    unsafe { avx512_max_i64_inner(values) }
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn avx2_sum_i64_inner(values: &[i64]) -> i128 {
    use std::arch::x86_64::*;
    unsafe {
        // Four i64 lanes per 256-bit vector; same shape as the AVX-512 kernel.
        let mut acc = _mm256_setzero_si256();
        let chunks = values.len() / 4;
        let ptr = values.as_ptr();
        for i in 0..chunks {
            let v = _mm256_loadu_si256(ptr.add(i * 4).cast());
            acc = _mm256_add_epi64(acc, v);
        }
        let mut buf = [0i64; 4];
        _mm256_storeu_si256(buf.as_mut_ptr().cast(), acc);
        let mut sum: i128 = 0;
        for &v in &buf {
            sum += v as i128;
        }
        for &v in &values[chunks * 4..] {
            sum += v as i128;
        }
        sum
    }
}

#[cfg(target_arch = "x86_64")]
fn avx2_sum_i64(values: &[i64]) -> i128 {
    if values.len() < 8 {
        return scalar_sum_i64(values);
    }
    unsafe { avx2_sum_i64_inner(values) }
}

// AVX2 has no 64-bit integer min/max instruction, so both are emulated with a
// compare-and-blend; `is_min` selects which operand wins.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn avx2_minmax_i64_inner(values: &[i64], is_min: bool) -> i64 {
    use std::arch::x86_64::*;
    unsafe {
        let init_val = if is_min { i64::MAX } else { i64::MIN };
        let mut acc = _mm256_set1_epi64x(init_val);
        let chunks = values.len() / 4;
        let ptr = values.as_ptr();
        for i in 0..chunks {
            let v = _mm256_loadu_si256(ptr.add(i * 4).cast());
            // cmp is all-ones in each lane where acc > v. The byte-wise blend
            // is safe because the mask is uniform within every 64-bit lane.
            let cmp = _mm256_cmpgt_epi64(acc, v);
            if is_min {
                // Where acc > v, keep v: lane-wise minimum.
                acc = _mm256_blendv_epi8(acc, v, cmp);
            } else {
                // Where acc > v, keep acc: lane-wise maximum.
                acc = _mm256_blendv_epi8(v, acc, cmp);
            }
        }
        let mut buf = [0i64; 4];
        _mm256_storeu_si256(buf.as_mut_ptr().cast(), acc);
        let mut m = buf[0];
        for &v in &buf[1..] {
            if is_min {
                if v < m {
                    m = v;
                }
            } else if v > m {
                m = v;
            }
        }
        for &v in &values[chunks * 4..] {
            if is_min {
                if v < m {
                    m = v;
                }
            } else if v > m {
                m = v;
            }
        }
        m
    }
}

#[cfg(target_arch = "x86_64")]
fn avx2_min_i64(values: &[i64]) -> i64 {
    if values.len() < 8 {
        return scalar_min_i64(values);
    }
    unsafe { avx2_minmax_i64_inner(values, true) }
}

#[cfg(target_arch = "x86_64")]
fn avx2_max_i64(values: &[i64]) -> i64 {
    if values.len() < 8 {
        return scalar_max_i64(values);
    }
    unsafe { avx2_minmax_i64_inner(values, false) }
}

#[cfg(target_arch = "aarch64")]
fn neon_sum_i64(values: &[i64]) -> i128 {
    use std::arch::aarch64::*;
    // Two i64 lanes per 128-bit NEON vector.
    let chunks = values.len() / 2;
    let ptr = values.as_ptr();
    let mut acc = unsafe { vdupq_n_s64(0) };
    for i in 0..chunks {
        let v = unsafe { vld1q_s64(ptr.add(i * 2)) };
        acc = unsafe { vaddq_s64(acc, v) };
    }
    let mut buf = [0i64; 2];
    unsafe { vst1q_s64(buf.as_mut_ptr(), acc) };
    let mut sum: i128 = buf[0] as i128 + buf[1] as i128;
    for &v in &values[chunks * 2..] {
        sum += v as i128;
    }
    sum
}

// AArch64 NEON provides integer SMIN/SMAX only for 8/16/32-bit lanes, not
// 64-bit, so min/max delegate to the scalar loops.
#[cfg(target_arch = "aarch64")]
fn neon_min_i64(values: &[i64]) -> i64 {
    scalar_min_i64(values)
}

#[cfg(target_arch = "aarch64")]
fn neon_max_i64(values: &[i64]) -> i64 {
    scalar_max_i64(values)
}
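
// A 64-bit lane minimum can still be emulated on NEON with compare-and-select
// (CMGT + BSL). The sketch below is illustrative only and is not wired into
// detect(); the `neon_min_i64_select` name is hypothetical. It mirrors the
// scalar empty-slice semantics by starting from i64::MAX.
#[cfg(target_arch = "aarch64")]
#[allow(dead_code)]
fn neon_min_i64_select(values: &[i64]) -> i64 {
    use std::arch::aarch64::*;
    let chunks = values.len() / 2;
    let ptr = values.as_ptr();
    let mut acc = unsafe { vdupq_n_s64(i64::MAX) };
    for i in 0..chunks {
        unsafe {
            let v = vld1q_s64(ptr.add(i * 2));
            // All-ones lanes where acc > v; BSL then keeps v in those lanes,
            // which is a lane-wise minimum.
            let gt = vcgtq_s64(acc, v);
            acc = vbslq_s64(gt, v, acc);
        }
    }
    let mut buf = [0i64; 2];
    unsafe { vst1q_s64(buf.as_mut_ptr(), acc) };
    let mut m = buf[0].min(buf[1]);
    for &v in &values[chunks * 2..] {
        m = m.min(v);
    }
    m
}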

// With simd128 compiled in, sum two i64 lanes per vector; the cfg pair below
// resolves SIMD availability at compile time rather than at runtime.
#[cfg(target_arch = "wasm32")]
#[cfg(target_feature = "simd128")]
fn wasm_sum_i64(values: &[i64]) -> i128 {
    use std::arch::wasm32::*;
    let chunks = values.len() / 2;
    let ptr = values.as_ptr() as *const v128;
    let mut acc = i64x2_splat(0);
    for i in 0..chunks {
        let v = unsafe { v128_load(ptr.add(i)) };
        acc = i64x2_add(acc, v);
    }
    let lo = i64x2_extract_lane::<0>(acc) as i128;
    let hi = i64x2_extract_lane::<1>(acc) as i128;
    let mut sum = lo + hi;
    for &v in &values[chunks * 2..] {
        sum += v as i128;
    }
    sum
}

#[cfg(target_arch = "wasm32")]
#[cfg(not(target_feature = "simd128"))]
fn wasm_sum_i64(values: &[i64]) -> i128 {
    scalar_sum_i64(values)
}

// The simd128 proposal defines integer min/max for 8/16/32-bit lanes only, so
// the i64 min/max go through the scalar loops on wasm as well.
#[cfg(target_arch = "wasm32")]
fn wasm_min_i64(values: &[i64]) -> i64 {
    scalar_min_i64(values)
}

#[cfg(target_arch = "wasm32")]
fn wasm_max_i64(values: &[i64]) -> i64 {
    scalar_max_i64(values)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn runtime_detects() {
        let rt = i64_runtime();
        assert!(!rt.name.is_empty());
    }

    #[test]
    fn sum_correctness() {
        let rt = i64_runtime();
        let values: Vec<i64> = (0..1000).collect();
        // 0 + 1 + ... + 999 = 999 * 1000 / 2.
        let expected: i128 = 999 * 1000 / 2;
        assert_eq!((rt.sum_i64)(&values), expected);
    }

    #[test]
    fn sum_overflow_safe() {
        let rt = i64_runtime();
        // The i128 return type keeps the total lossless even though each
        // element is i64::MAX.
        let values = vec![i64::MAX, i64::MAX, i64::MAX];
        let result = (rt.sum_i64)(&values);
        assert_eq!(result, 3 * i64::MAX as i128);
    }

    #[test]
    fn min_max_correctness() {
        let rt = i64_runtime();
        let values: Vec<i64> = (-500..500).collect();
        assert_eq!((rt.min_i64)(&values), -500);
        assert_eq!((rt.max_i64)(&values), 499);
    }

    #[test]
    fn empty_input() {
        let rt = i64_runtime();
        assert_eq!((rt.sum_i64)(&[]), 0);
        assert_eq!((rt.min_i64)(&[]), i64::MAX);
        assert_eq!((rt.max_i64)(&[]), i64::MIN);
    }

    #[test]
    fn small_input() {
        let rt = i64_runtime();
        assert_eq!((rt.sum_i64)(&[1, 2, 3]), 6);
        assert_eq!((rt.min_i64)(&[3, 1, 2]), 1);
        assert_eq!((rt.max_i64)(&[1, 3, 2]), 3);
    }
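
    // Cross-checks the dispatched kernels against the scalar references for
    // every length from 0 to 64, crossing the 8- and 16-element thresholds so
    // both the vector body and the scalar tail get exercised.
    #[test]
    fn matches_scalar_reference_across_thresholds() {
        let rt = i64_runtime();
        for len in 0..=64usize {
            let values: Vec<i64> = (0..len as i64).map(|i| i * 37 - 100).collect();
            assert_eq!((rt.sum_i64)(&values), scalar_sum_i64(&values));
            assert_eq!((rt.min_i64)(&values), scalar_min_i64(&values));
            assert_eq!((rt.max_i64)(&values), scalar_max_i64(&values));
        }
    }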
}