/// A table of function pointers for the time-series kernels, bound once at
/// startup (via `detect()`) to the best SIMD implementation the current
/// target/CPU supports.
pub struct TsSimdRuntime {
    /// Sum of all elements; returns 0.0 for an empty slice.
    pub sum_f64: fn(&[f64]) -> f64,
    /// Minimum element; returns +inf for an empty slice.
    pub min_f64: fn(&[f64]) -> f64,
    /// Maximum element; returns -inf for an empty slice.
    pub max_f64: fn(&[f64]) -> f64,
    /// Ascending indices of elements within `[min, max]` inclusive.
    pub range_filter_i64: fn(&[i64], i64, i64) -> Vec<u32>,
    /// Name of the selected backend: "avx512", "avx2", "neon",
    /// "wasm-simd128", or "scalar".
    pub name: &'static str,
}
20
impl TsSimdRuntime {
    /// Probes the current target/CPU and returns the best kernel set.
    ///
    /// Priority on x86_64: AVX-512F, then AVX2, then scalar. aarch64 always
    /// takes the NEON path; wasm32 always takes the simd128 path.
    ///
    /// NOTE(review): the wasm32 branch is unconditional — there is no runtime
    /// feature probe on wasm here, so this assumes the host engine accepts
    /// simd128 instructions; confirm the build/deployment guarantees that.
    pub fn detect() -> Self {
        #[cfg(target_arch = "x86_64")]
        {
            if std::is_x86_feature_detected!("avx512f") {
                return Self {
                    sum_f64: avx512_sum_f64,
                    min_f64: avx512_min_f64,
                    max_f64: avx512_max_f64,
                    range_filter_i64: avx512_range_filter_i64,
                    name: "avx512",
                };
            }
            if std::is_x86_feature_detected!("avx2") {
                return Self {
                    sum_f64: avx2_sum_f64,
                    min_f64: avx2_min_f64,
                    max_f64: avx2_max_f64,
                    range_filter_i64: avx2_range_filter_i64,
                    name: "avx2",
                };
            }
        }
        #[cfg(target_arch = "aarch64")]
        {
            // No runtime probe: selected unconditionally on this target.
            // Note the range filter has no NEON kernel and stays scalar.
            return Self {
                sum_f64: neon_sum_f64,
                min_f64: neon_min_f64,
                max_f64: neon_max_f64,
                range_filter_i64: scalar_range_filter_i64,
                name: "neon",
            };
        }
        #[cfg(target_arch = "wasm32")]
        {
            return Self {
                sum_f64: wasm_sum_f64,
                min_f64: wasm_min_f64,
                max_f64: wasm_max_f64,
                range_filter_i64: scalar_range_filter_i64,
                name: "wasm-simd128",
            };
        }
        // Fallback: reached on x86_64 without AVX2/AVX-512 and on any other
        // architecture. On aarch64/wasm32 the unconditional returns above
        // make this unreachable — hence the allow.
        #[allow(unreachable_code)]
        Self {
            sum_f64: scalar_sum_f64,
            min_f64: scalar_min_f64,
            max_f64: scalar_max_f64,
            range_filter_i64: scalar_range_filter_i64,
            name: "scalar",
        }
    }
}
74
/// Process-wide kernel table; `detect()` runs at most once, on first access.
static TS_RUNTIME: std::sync::OnceLock<TsSimdRuntime> = std::sync::OnceLock::new();

/// Returns the lazily-initialized, process-wide [`TsSimdRuntime`].
pub fn ts_runtime() -> &'static TsSimdRuntime {
    TS_RUNTIME.get_or_init(TsSimdRuntime::detect)
}
81
/// Sums `values` using Kahan (compensated) summation: a running error term
/// captures the low-order bits lost by each addition, keeping accumulated
/// rounding error O(1) instead of O(n). Returns 0.0 for an empty slice.
fn scalar_sum_f64(values: &[f64]) -> f64 {
    let (total, _residual) = values.iter().fold((0.0f64, 0.0f64), |(acc, err), &x| {
        let adjusted = x - err;
        let next = acc + adjusted;
        // (next - acc) - adjusted recovers the rounding error of this step.
        (next, (next - acc) - adjusted)
    });
    total
}
96
/// Minimum element of `values`; +inf for an empty slice. NaN elements are
/// ignored (`f64::min` returns the non-NaN argument), matching the original
/// `v < m` comparison, which is false for NaN.
fn scalar_min_f64(values: &[f64]) -> f64 {
    values.iter().copied().fold(f64::INFINITY, f64::min)
}
106
/// Maximum element of `values`; -inf for an empty slice. NaN elements are
/// ignored (`f64::max` returns the non-NaN argument), matching the original
/// `v > m` comparison, which is false for NaN.
fn scalar_max_f64(values: &[f64]) -> f64 {
    values.iter().copied().fold(f64::NEG_INFINITY, f64::max)
}
116
/// Returns the indices (as `u32`) of all elements of `values` that lie in
/// the inclusive range `[min, max]`, in ascending order.
///
/// An inverted range (`min > max`) or empty slice yields an empty vector.
fn scalar_range_filter_i64(values: &[i64], min: i64, max: i64) -> Vec<u32> {
    // Indices are stored as u32 to halve output size. Guard against the
    // silent truncation `i as u32` would perform on slices with more than
    // u32::MAX elements (possible on 64-bit targets).
    debug_assert!(
        values.len() <= u32::MAX as usize,
        "slice too long for u32 indices"
    );
    values
        .iter()
        .enumerate()
        .filter(|&(_, &v)| (min..=max).contains(&v))
        .map(|(i, _)| i as u32)
        .collect()
}
126
/// AVX-512 sum kernel: accumulates 8 f64 lanes per iteration, horizontally
/// reduces with `_mm512_reduce_add_pd`, then adds the scalar tail
/// (len % 8 elements). Unlike `scalar_sum_f64`, this is uncompensated
/// summation, so rounding behavior differs slightly from the scalar path.
///
/// # Safety
/// The CPU must support AVX-512F; callers go through `detect()`, which
/// checks `is_x86_feature_detected!("avx512f")` first.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
unsafe fn avx512_sum_f64_inner(values: &[f64]) -> f64 {
    use std::arch::x86_64::*;
    // SAFETY: each unaligned load reads lanes i*8 .. i*8+7, and
    // chunks * 8 <= values.len(), so all loads stay in bounds.
    unsafe {
        let mut acc = _mm512_setzero_pd();
        let chunks = values.len() / 8;
        let ptr = values.as_ptr();
        for i in 0..chunks {
            let v = _mm512_loadu_pd(ptr.add(i * 8));
            acc = _mm512_add_pd(acc, v);
        }
        let mut sum = _mm512_reduce_add_pd(acc);
        // Scalar tail for the final len % 8 elements.
        for &v in &values[chunks * 8..] {
            sum += v;
        }
        sum
    }
}
148
149#[cfg(target_arch = "x86_64")]
150fn avx512_sum_f64(values: &[f64]) -> f64 {
151 if values.len() < 16 {
152 return scalar_sum_f64(values);
153 }
154 unsafe { avx512_sum_f64_inner(values) }
155}
156
/// AVX-512 min kernel: lane-wise `_mm512_min_pd` over 8 f64 lanes, a
/// horizontal `_mm512_reduce_min_pd`, then a scalar tail where NaN elements
/// are skipped by the `<` comparison.
///
/// # Safety
/// The CPU must support AVX-512F (guaranteed by `detect()`).
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
unsafe fn avx512_min_f64_inner(values: &[f64]) -> f64 {
    use std::arch::x86_64::*;
    // SAFETY: loads cover lanes i*8 .. i*8+7 with chunks * 8 <= len.
    unsafe {
        // Identity element for min, so unused lanes never win.
        let mut acc = _mm512_set1_pd(f64::INFINITY);
        let chunks = values.len() / 8;
        let ptr = values.as_ptr();
        for i in 0..chunks {
            let v = _mm512_loadu_pd(ptr.add(i * 8));
            acc = _mm512_min_pd(acc, v);
        }
        let mut m = _mm512_reduce_min_pd(acc);
        // Scalar tail for the final len % 8 elements.
        for &v in &values[chunks * 8..] {
            if v < m {
                m = v;
            }
        }
        m
    }
}
178
179#[cfg(target_arch = "x86_64")]
180fn avx512_min_f64(values: &[f64]) -> f64 {
181 if values.len() < 16 {
182 return scalar_min_f64(values);
183 }
184 unsafe { avx512_min_f64_inner(values) }
185}
186
/// AVX-512 max kernel: lane-wise `_mm512_max_pd` over 8 f64 lanes, a
/// horizontal `_mm512_reduce_max_pd`, then a scalar tail where NaN elements
/// are skipped by the `>` comparison.
///
/// # Safety
/// The CPU must support AVX-512F (guaranteed by `detect()`).
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
unsafe fn avx512_max_f64_inner(values: &[f64]) -> f64 {
    use std::arch::x86_64::*;
    // SAFETY: loads cover lanes i*8 .. i*8+7 with chunks * 8 <= len.
    unsafe {
        // Identity element for max, so unused lanes never win.
        let mut acc = _mm512_set1_pd(f64::NEG_INFINITY);
        let chunks = values.len() / 8;
        let ptr = values.as_ptr();
        for i in 0..chunks {
            let v = _mm512_loadu_pd(ptr.add(i * 8));
            acc = _mm512_max_pd(acc, v);
        }
        let mut m = _mm512_reduce_max_pd(acc);
        // Scalar tail for the final len % 8 elements.
        for &v in &values[chunks * 8..] {
            if v > m {
                m = v;
            }
        }
        m
    }
}
208
209#[cfg(target_arch = "x86_64")]
210fn avx512_max_f64(values: &[f64]) -> f64 {
211 if values.len() < 16 {
212 return scalar_max_f64(values);
213 }
214 unsafe { avx512_max_f64_inner(values) }
215}
216
/// Range filter for the AVX-512 backend. Currently a pass-through to the
/// scalar implementation — a vectorized version (e.g. compare-to-mask plus
/// compress-store) has not been written yet.
#[cfg(target_arch = "x86_64")]
fn avx512_range_filter_i64(values: &[i64], min: i64, max: i64) -> Vec<u32> {
    scalar_range_filter_i64(values, min, max)
}
223
/// AVX2 sum kernel: accumulates 4 f64 lanes per iteration, reduces
/// horizontally (extract high 128, add halves, add the two remaining
/// lanes), then adds the scalar tail (len % 4 elements). Uncompensated
/// summation — rounding differs slightly from `scalar_sum_f64`.
///
/// # Safety
/// The CPU must support AVX2; callers go through `detect()`, which checks
/// `is_x86_feature_detected!("avx2")` first.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn avx2_sum_f64_inner(values: &[f64]) -> f64 {
    use std::arch::x86_64::*;
    // SAFETY: each unaligned load reads lanes i*4 .. i*4+3, and
    // chunks * 4 <= values.len(), so all loads stay in bounds.
    unsafe {
        let mut acc = _mm256_setzero_pd();
        let chunks = values.len() / 4;
        let ptr = values.as_ptr();
        for i in 0..chunks {
            let v = _mm256_loadu_pd(ptr.add(i * 4));
            acc = _mm256_add_pd(acc, v);
        }
        // Horizontal reduce: 256 -> 128 -> 64 bits.
        let hi = _mm256_extractf128_pd(acc, 1);
        let lo = _mm256_castpd256_pd128(acc);
        let sum2 = _mm_add_pd(lo, hi);
        let hi2 = _mm_unpackhi_pd(sum2, sum2);
        let mut sum = _mm_cvtsd_f64(_mm_add_sd(sum2, hi2));
        // Scalar tail for the final len % 4 elements.
        for &v in &values[chunks * 4..] {
            sum += v;
        }
        sum
    }
}
249
250#[cfg(target_arch = "x86_64")]
251fn avx2_sum_f64(values: &[f64]) -> f64 {
252 if values.len() < 8 {
253 return scalar_sum_f64(values);
254 }
255 unsafe { avx2_sum_f64_inner(values) }
256}
257
/// AVX2 min kernel: lane-wise `_mm256_min_pd` over 4 f64 lanes, horizontal
/// reduce (128-bit halves, then the two remaining lanes), then a scalar
/// tail where NaN elements are skipped by the `<` comparison.
///
/// # Safety
/// The CPU must support AVX2 (guaranteed by `detect()`).
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn avx2_min_f64_inner(values: &[f64]) -> f64 {
    use std::arch::x86_64::*;
    // SAFETY: loads cover lanes i*4 .. i*4+3 with chunks * 4 <= len.
    unsafe {
        // Identity element for min, so unused lanes never win.
        let mut acc = _mm256_set1_pd(f64::INFINITY);
        let chunks = values.len() / 4;
        let ptr = values.as_ptr();
        for i in 0..chunks {
            let v = _mm256_loadu_pd(ptr.add(i * 4));
            acc = _mm256_min_pd(acc, v);
        }
        // Horizontal reduce: 256 -> 128 -> 64 bits.
        let hi = _mm256_extractf128_pd(acc, 1);
        let lo = _mm256_castpd256_pd128(acc);
        let min2 = _mm_min_pd(lo, hi);
        let hi2 = _mm_unpackhi_pd(min2, min2);
        let mut m = _mm_cvtsd_f64(_mm_min_sd(min2, hi2));
        // Scalar tail for the final len % 4 elements.
        for &v in &values[chunks * 4..] {
            if v < m {
                m = v;
            }
        }
        m
    }
}
283
284#[cfg(target_arch = "x86_64")]
285fn avx2_min_f64(values: &[f64]) -> f64 {
286 if values.len() < 8 {
287 return scalar_min_f64(values);
288 }
289 unsafe { avx2_min_f64_inner(values) }
290}
291
/// AVX2 max kernel: lane-wise `_mm256_max_pd` over 4 f64 lanes, horizontal
/// reduce (128-bit halves, then the two remaining lanes), then a scalar
/// tail where NaN elements are skipped by the `>` comparison.
///
/// # Safety
/// The CPU must support AVX2 (guaranteed by `detect()`).
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn avx2_max_f64_inner(values: &[f64]) -> f64 {
    use std::arch::x86_64::*;
    // SAFETY: loads cover lanes i*4 .. i*4+3 with chunks * 4 <= len.
    unsafe {
        // Identity element for max, so unused lanes never win.
        let mut acc = _mm256_set1_pd(f64::NEG_INFINITY);
        let chunks = values.len() / 4;
        let ptr = values.as_ptr();
        for i in 0..chunks {
            let v = _mm256_loadu_pd(ptr.add(i * 4));
            acc = _mm256_max_pd(acc, v);
        }
        // Horizontal reduce: 256 -> 128 -> 64 bits.
        let hi = _mm256_extractf128_pd(acc, 1);
        let lo = _mm256_castpd256_pd128(acc);
        let max2 = _mm_max_pd(lo, hi);
        let hi2 = _mm_unpackhi_pd(max2, max2);
        let mut m = _mm_cvtsd_f64(_mm_max_sd(max2, hi2));
        // Scalar tail for the final len % 4 elements.
        for &v in &values[chunks * 4..] {
            if v > m {
                m = v;
            }
        }
        m
    }
}
317
318#[cfg(target_arch = "x86_64")]
319fn avx2_max_f64(values: &[f64]) -> f64 {
320 if values.len() < 8 {
321 return scalar_max_f64(values);
322 }
323 unsafe { avx2_max_f64_inner(values) }
324}
325
/// Range filter for the AVX2 backend. Currently a pass-through to the
/// scalar implementation — no vectorized version has been written yet.
#[cfg(target_arch = "x86_64")]
fn avx2_range_filter_i64(values: &[i64], min: i64, max: i64) -> Vec<u32> {
    scalar_range_filter_i64(values, min, max)
}
330
/// NEON sum: 2 f64 lanes per iteration plus a scalar tail. Falls back to
/// the scalar (Kahan) kernel for slices shorter than two vectors; the SIMD
/// path uses uncompensated summation, so rounding differs slightly.
#[cfg(target_arch = "aarch64")]
fn neon_sum_f64(values: &[f64]) -> f64 {
    use std::arch::aarch64::*;
    if values.len() < 4 {
        return scalar_sum_f64(values);
    }
    // SAFETY: NEON is a baseline feature of aarch64, and every load at
    // `ptr.add(i * 2)` reads 2 lanes with chunks * 2 <= values.len().
    unsafe {
        let mut acc = vdupq_n_f64(0.0);
        let chunks = values.len() / 2;
        let ptr = values.as_ptr();
        for i in 0..chunks {
            let v = vld1q_f64(ptr.add(i * 2));
            acc = vaddq_f64(acc, v);
        }
        let mut sum = vgetq_lane_f64(acc, 0) + vgetq_lane_f64(acc, 1);
        // Scalar tail for the final len % 2 element, if any.
        for &v in &values[chunks * 2..] {
            sum += v;
        }
        sum
    }
}
354
/// NEON min: lane-wise `vminq_f64` over 2 f64 lanes, reduced with
/// `f64::min` across the two lanes, plus a scalar tail where NaN elements
/// are skipped by the `<` comparison. Short slices use the scalar kernel.
#[cfg(target_arch = "aarch64")]
fn neon_min_f64(values: &[f64]) -> f64 {
    use std::arch::aarch64::*;
    if values.len() < 4 {
        return scalar_min_f64(values);
    }
    // SAFETY: NEON is a baseline feature of aarch64; each load reads
    // 2 lanes with chunks * 2 <= values.len().
    unsafe {
        // Identity element for min, so unused lanes never win.
        let mut acc = vdupq_n_f64(f64::INFINITY);
        let chunks = values.len() / 2;
        let ptr = values.as_ptr();
        for i in 0..chunks {
            let v = vld1q_f64(ptr.add(i * 2));
            acc = vminq_f64(acc, v);
        }
        let mut m = vgetq_lane_f64(acc, 0).min(vgetq_lane_f64(acc, 1));
        // Scalar tail for the final len % 2 element, if any.
        for &v in &values[chunks * 2..] {
            if v < m {
                m = v;
            }
        }
        m
    }
}
378
/// NEON max: lane-wise `vmaxq_f64` over 2 f64 lanes, reduced with
/// `f64::max` across the two lanes, plus a scalar tail where NaN elements
/// are skipped by the `>` comparison. Short slices use the scalar kernel.
#[cfg(target_arch = "aarch64")]
fn neon_max_f64(values: &[f64]) -> f64 {
    use std::arch::aarch64::*;
    if values.len() < 4 {
        return scalar_max_f64(values);
    }
    // SAFETY: NEON is a baseline feature of aarch64; each load reads
    // 2 lanes with chunks * 2 <= values.len().
    unsafe {
        // Identity element for max, so unused lanes never win.
        let mut acc = vdupq_n_f64(f64::NEG_INFINITY);
        let chunks = values.len() / 2;
        let ptr = values.as_ptr();
        for i in 0..chunks {
            let v = vld1q_f64(ptr.add(i * 2));
            acc = vmaxq_f64(acc, v);
        }
        let mut m = vgetq_lane_f64(acc, 0).max(vgetq_lane_f64(acc, 1));
        // Scalar tail for the final len % 2 element, if any.
        for &v in &values[chunks * 2..] {
            if v > m {
                m = v;
            }
        }
        m
    }
}
402
/// wasm-simd128 sum kernel: 2 f64 lanes per iteration, lane extraction for
/// the horizontal reduce, then a scalar tail. Uncompensated summation —
/// rounding differs slightly from `scalar_sum_f64`.
///
/// # Safety
/// The executing wasm engine must support the simd128 proposal (this is
/// assumed by `detect()` on wasm32 — there is no runtime probe).
#[cfg(target_arch = "wasm32")]
#[target_feature(enable = "simd128")]
unsafe fn wasm_sum_f64_inner(values: &[f64]) -> f64 {
    use core::arch::wasm32::*;
    let mut acc = f64x2_splat(0.0);
    let chunks = values.len() / 2;
    let ptr = values.as_ptr();
    // In-bounds: each load reads 2 lanes with chunks * 2 <= values.len().
    for i in 0..chunks {
        let v = v128_load(ptr.add(i * 2) as *const v128);
        acc = f64x2_add(acc, v);
    }
    let mut sum = f64x2_extract_lane::<0>(acc) + f64x2_extract_lane::<1>(acc);
    // Scalar tail for the final len % 2 element, if any.
    for &v in &values[chunks * 2..] {
        sum += v;
    }
    sum
}
422
/// Sum dispatcher for the wasm backend: slices shorter than two full
/// vectors (4 elements) use the scalar kernel, the rest use simd128.
#[cfg(target_arch = "wasm32")]
fn wasm_sum_f64(values: &[f64]) -> f64 {
    match values.len() {
        0..=3 => scalar_sum_f64(values),
        // SAFETY: the simd128 target feature is assumed available on this
        // wasm32 build path (selected unconditionally by `detect()`).
        _ => unsafe { wasm_sum_f64_inner(values) },
    }
}
430
/// wasm-simd128 min kernel: lane-wise `f64x2_min`, lane extraction reduced
/// with `f64::min`, then a scalar tail where NaN elements are skipped by
/// the `<` comparison.
///
/// # Safety
/// The executing wasm engine must support the simd128 proposal.
#[cfg(target_arch = "wasm32")]
#[target_feature(enable = "simd128")]
unsafe fn wasm_min_f64_inner(values: &[f64]) -> f64 {
    use core::arch::wasm32::*;
    // Identity element for min, so unused lanes never win.
    let mut acc = f64x2_splat(f64::INFINITY);
    let chunks = values.len() / 2;
    let ptr = values.as_ptr();
    // In-bounds: each load reads 2 lanes with chunks * 2 <= values.len().
    for i in 0..chunks {
        let v = v128_load(ptr.add(i * 2) as *const v128);
        acc = f64x2_min(acc, v);
    }
    let mut m = f64x2_extract_lane::<0>(acc).min(f64x2_extract_lane::<1>(acc));
    // Scalar tail for the final len % 2 element, if any.
    for &v in &values[chunks * 2..] {
        if v < m {
            m = v;
        }
    }
    m
}
450
/// Min dispatcher for the wasm backend: slices shorter than two full
/// vectors (4 elements) use the scalar kernel, the rest use simd128.
#[cfg(target_arch = "wasm32")]
fn wasm_min_f64(values: &[f64]) -> f64 {
    match values.len() {
        0..=3 => scalar_min_f64(values),
        // SAFETY: the simd128 target feature is assumed available on this
        // wasm32 build path (selected unconditionally by `detect()`).
        _ => unsafe { wasm_min_f64_inner(values) },
    }
}
458
/// wasm-simd128 max kernel: lane-wise `f64x2_max`, lane extraction reduced
/// with `f64::max`, then a scalar tail where NaN elements are skipped by
/// the `>` comparison.
///
/// # Safety
/// The executing wasm engine must support the simd128 proposal.
#[cfg(target_arch = "wasm32")]
#[target_feature(enable = "simd128")]
unsafe fn wasm_max_f64_inner(values: &[f64]) -> f64 {
    use core::arch::wasm32::*;
    // Identity element for max, so unused lanes never win.
    let mut acc = f64x2_splat(f64::NEG_INFINITY);
    let chunks = values.len() / 2;
    let ptr = values.as_ptr();
    // In-bounds: each load reads 2 lanes with chunks * 2 <= values.len().
    for i in 0..chunks {
        let v = v128_load(ptr.add(i * 2) as *const v128);
        acc = f64x2_max(acc, v);
    }
    let mut m = f64x2_extract_lane::<0>(acc).max(f64x2_extract_lane::<1>(acc));
    // Scalar tail for the final len % 2 element, if any.
    for &v in &values[chunks * 2..] {
        if v > m {
            m = v;
        }
    }
    m
}
478
/// Max dispatcher for the wasm backend: slices shorter than two full
/// vectors (4 elements) use the scalar kernel, the rest use simd128.
#[cfg(target_arch = "wasm32")]
fn wasm_max_f64(values: &[f64]) -> f64 {
    match values.len() {
        0..=3 => scalar_max_f64(values),
        // SAFETY: the simd128 target feature is assumed available on this
        // wasm32 build path (selected unconditionally by `detect()`).
        _ => unsafe { wasm_max_f64_inner(values) },
    }
}
486
#[cfg(test)]
mod tests {
    use super::*;

    /// The detected backend must be one of the known names.
    #[test]
    fn runtime_detects() {
        let rt = ts_runtime();
        assert!(!rt.name.is_empty());
        assert!(
            ["avx512", "avx2", "neon", "wasm-simd128", "scalar"].contains(&rt.name),
            "unexpected runtime: {}",
            rt.name
        );
    }

    #[test]
    fn sum_correctness() {
        let rt = ts_runtime();
        let values: Vec<f64> = (0..1000).map(|i| i as f64).collect();
        let expected = 999.0 * 1000.0 / 2.0;
        let result = (rt.sum_f64)(&values);
        assert!(
            (result - expected).abs() < 1e-6,
            "sum: got {result}, expected {expected}"
        );
    }

    /// 1000 is a multiple of every lane width (2, 4, 8), so the test above
    /// never exercises the SIMD kernels' scalar tail loops. These lengths do.
    #[test]
    fn sum_handles_remainder_lengths() {
        let rt = ts_runtime();
        for len in [1usize, 1001, 1002, 1003, 1007] {
            let values: Vec<f64> = (0..len).map(|i| i as f64).collect();
            let expected = (len * (len - 1)) as f64 / 2.0;
            let result = (rt.sum_f64)(&values);
            assert!(
                (result - expected).abs() < 1e-6,
                "len {len}: got {result}, expected {expected}"
            );
        }
    }

    #[test]
    fn min_max_correctness() {
        let rt = ts_runtime();
        let values: Vec<f64> = (0..1000).map(|i| (i as f64) - 500.0).collect();
        // Exact values: both are representable, no tolerance needed.
        assert_eq!((rt.min_f64)(&values), -500.0);
        assert_eq!((rt.max_f64)(&values), 499.0);
    }

    #[test]
    fn range_filter_correctness() {
        let rt = ts_runtime();
        let values: Vec<i64> = (0..100).collect();
        let indices = (rt.range_filter_i64)(&values, 25, 75);
        assert_eq!(indices.len(), 51);
        assert_eq!(indices[0], 25);
        assert_eq!(*indices.last().unwrap(), 75);
    }

    #[test]
    fn empty_input() {
        let rt = ts_runtime();
        assert_eq!((rt.sum_f64)(&[]), 0.0);
        // Check the sign, not just `is_infinite()`: min of nothing is +inf,
        // max of nothing is -inf.
        assert_eq!((rt.min_f64)(&[]), f64::INFINITY);
        assert_eq!((rt.max_f64)(&[]), f64::NEG_INFINITY);
        assert!((rt.range_filter_i64)(&[], 0, 100).is_empty());
    }

    /// Inputs below every SIMD threshold must route through the scalar path.
    #[test]
    fn small_input() {
        let rt = ts_runtime();
        assert_eq!((rt.sum_f64)(&[1.0, 2.0, 3.0]), 6.0);
        assert_eq!((rt.min_f64)(&[3.0, 1.0, 2.0]), 1.0);
        assert_eq!((rt.max_f64)(&[1.0, 3.0, 2.0]), 3.0);
    }
}
549}