/// Dispatch table of SIMD-accelerated time-series kernels.
///
/// Populated once by [`TsSimdRuntime::detect`] with the fastest
/// implementations the current CPU supports; `name` identifies the
/// selected backend (e.g. "avx512", "avx2", "neon", "scalar").
///
/// Derives `Clone`/`Copy` (the table is just fn pointers plus a
/// `&'static str`) and `Debug` for diagnostics.
#[derive(Clone, Copy, Debug)]
pub struct TsSimdRuntime {
    /// Sums a slice of f64; empty input yields 0.0.
    pub sum_f64: fn(&[f64]) -> f64,
    /// Minimum of a slice; empty input yields +infinity.
    pub min_f64: fn(&[f64]) -> f64,
    /// Maximum of a slice; empty input yields -infinity.
    pub max_f64: fn(&[f64]) -> f64,
    /// Indices (as u32, in order) of elements within [min, max] inclusive.
    pub range_filter_i64: fn(&[i64], i64, i64) -> Vec<u32>,
    /// Name of the selected backend.
    pub name: &'static str,
}
22
impl TsSimdRuntime {
    /// Probes the current CPU and returns the fastest available kernel set.
    ///
    /// x86_64: runtime feature detection, preferring AVX-512 over AVX2
    /// over scalar. aarch64: NEON kernels unconditionally. wasm32:
    /// simd128 kernels unconditionally.
    ///
    /// NOTE(review): the wasm32 branch performs no runtime check — it
    /// assumes the engine supports simd128; confirm against build
    /// targets. Any other architecture falls through to scalar.
    pub fn detect() -> Self {
        #[cfg(target_arch = "x86_64")]
        {
            // Widest vectors first: 512-bit when the CPU reports avx512f.
            if std::is_x86_feature_detected!("avx512f") {
                return Self {
                    sum_f64: avx512_sum_f64,
                    min_f64: avx512_min_f64,
                    max_f64: avx512_max_f64,
                    // No dedicated AVX-512 range filter yet; that fn
                    // delegates to the scalar implementation.
                    range_filter_i64: avx512_range_filter_i64,
                    name: "avx512",
                };
            }
            if std::is_x86_feature_detected!("avx2") {
                return Self {
                    sum_f64: avx2_sum_f64,
                    min_f64: avx2_min_f64,
                    max_f64: avx2_max_f64,
                    // No dedicated AVX2 range filter yet; delegates to scalar.
                    range_filter_i64: avx2_range_filter_i64,
                    name: "avx2",
                };
            }
        }
        #[cfg(target_arch = "aarch64")]
        {
            // NEON is a baseline aarch64 feature, so no runtime probe.
            return Self {
                sum_f64: neon_sum_f64,
                min_f64: neon_min_f64,
                max_f64: neon_max_f64,
                range_filter_i64: scalar_range_filter_i64,
                name: "neon",
            };
        }
        #[cfg(target_arch = "wasm32")]
        {
            return Self {
                sum_f64: wasm_sum_f64,
                min_f64: wasm_min_f64,
                max_f64: wasm_max_f64,
                range_filter_i64: scalar_range_filter_i64,
                name: "wasm-simd128",
            };
        }
        // Reached on x86_64 without AVX2 and on any unlisted arch.
        // The allow exists because on aarch64/wasm32 the blocks above
        // return unconditionally, making this expression unreachable.
        #[allow(unreachable_code)]
        Self {
            sum_f64: scalar_sum_f64,
            min_f64: scalar_min_f64,
            max_f64: scalar_max_f64,
            range_filter_i64: scalar_range_filter_i64,
            name: "scalar",
        }
    }
}
76
// Process-wide runtime table, detected lazily on first use.
static TS_RUNTIME: std::sync::OnceLock<TsSimdRuntime> = std::sync::OnceLock::new();

/// Returns the process-wide SIMD runtime, running CPU feature
/// detection on the first call. Thread-safe: `OnceLock` guarantees
/// `detect` runs at most once even under concurrent first calls.
pub fn ts_runtime() -> &'static TsSimdRuntime {
    TS_RUNTIME.get_or_init(TsSimdRuntime::detect)
}
83
/// Scalar fallback sum using Kahan (compensated) summation to keep
/// floating-point rounding error from accumulating across long inputs.
/// Empty input returns 0.0.
fn scalar_sum_f64(values: &[f64]) -> f64 {
    let (total, _compensation) = values.iter().fold((0.0f64, 0.0f64), |(total, comp), &v| {
        // Remove the carried compensation before adding, then capture
        // the low-order bits lost by this addition for the next step.
        let adjusted = v - comp;
        let next = total + adjusted;
        (next, (next - total) - adjusted)
    });
    total
}
98
/// Scalar fallback minimum. Starts from +infinity (the identity for
/// min), so an empty slice yields +infinity. Uses a `<` comparison,
/// so NaN values never replace the running minimum.
fn scalar_min_f64(values: &[f64]) -> f64 {
    values
        .iter()
        .fold(f64::INFINITY, |best, &v| if v < best { v } else { best })
}
108
/// Scalar fallback maximum. Starts from -infinity (the identity for
/// max), so an empty slice yields -infinity. Uses a `>` comparison,
/// so NaN values never replace the running maximum.
fn scalar_max_f64(values: &[f64]) -> f64 {
    values
        .iter()
        .fold(f64::NEG_INFINITY, |best, &v| if v > best { v } else { best })
}
118
/// Scalar fallback range filter: the positions (as u32, ascending) of
/// every element v with min <= v <= max (inclusive bounds).
/// NOTE(review): positions are cast with `as u32`, which truncates —
/// assumes input slices shorter than 2^32 elements; confirm callers.
fn scalar_range_filter_i64(values: &[i64], min: i64, max: i64) -> Vec<u32> {
    values
        .iter()
        .enumerate()
        .filter(|&(_, &v)| v >= min && v <= max)
        .map(|(i, _)| i as u32)
        .collect()
}
128
/// AVX-512 sum: accumulates 8 f64 lanes per iteration, reduces the
/// vector with `_mm512_reduce_add_pd`, then adds the scalar tail.
/// Plain (non-compensated) summation, so low-order bits may differ
/// from `scalar_sum_f64`'s Kahan result.
///
/// # Safety
/// The CPU must support `avx512f`; guaranteed when reached through the
/// pointer installed by `TsSimdRuntime::detect`.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
unsafe fn avx512_sum_f64_inner(values: &[f64]) -> f64 {
    use std::arch::x86_64::*;
    unsafe {
        // 8 parallel partial sums.
        let mut acc = _mm512_setzero_pd();
        let chunks = values.len() / 8;
        let ptr = values.as_ptr();
        for i in 0..chunks {
            // SAFETY: i * 8 + 7 < values.len(); unaligned load is allowed.
            let v = _mm512_loadu_pd(ptr.add(i * 8));
            acc = _mm512_add_pd(acc, v);
        }
        // Horizontal reduction of the 8 lanes, then the leftover tail.
        let mut sum = _mm512_reduce_add_pd(acc);
        for &v in &values[chunks * 8..] {
            sum += v;
        }
        sum
    }
}
150
151#[cfg(target_arch = "x86_64")]
152fn avx512_sum_f64(values: &[f64]) -> f64 {
153 if values.len() < 16 {
154 return scalar_sum_f64(values);
155 }
156 unsafe { avx512_sum_f64_inner(values) }
157}
158
/// AVX-512 minimum: 8-lane vector min per iteration, horizontal
/// reduce, then a `<` scan of the scalar tail.
///
/// NOTE(review): lanes combine via `_mm512_min_pd` while the tail uses
/// `<`; NaN handling between the two paths may differ — assumes
/// NaN-free input, confirm with callers.
///
/// # Safety
/// The CPU must support `avx512f`; guaranteed when reached through the
/// pointer installed by `TsSimdRuntime::detect`.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
unsafe fn avx512_min_f64_inner(values: &[f64]) -> f64 {
    use std::arch::x86_64::*;
    unsafe {
        // Start every lane at +inf, the identity for min.
        let mut acc = _mm512_set1_pd(f64::INFINITY);
        let chunks = values.len() / 8;
        let ptr = values.as_ptr();
        for i in 0..chunks {
            // SAFETY: i * 8 + 7 < values.len(); unaligned load is allowed.
            let v = _mm512_loadu_pd(ptr.add(i * 8));
            acc = _mm512_min_pd(acc, v);
        }
        // Reduce the 8 lanes, then fold in the remainder elements.
        let mut m = _mm512_reduce_min_pd(acc);
        for &v in &values[chunks * 8..] {
            if v < m {
                m = v;
            }
        }
        m
    }
}
180
181#[cfg(target_arch = "x86_64")]
182fn avx512_min_f64(values: &[f64]) -> f64 {
183 if values.len() < 16 {
184 return scalar_min_f64(values);
185 }
186 unsafe { avx512_min_f64_inner(values) }
187}
188
/// AVX-512 maximum: 8-lane vector max per iteration, horizontal
/// reduce, then a `>` scan of the scalar tail.
///
/// NOTE(review): lanes combine via `_mm512_max_pd` while the tail uses
/// `>`; NaN handling between the two paths may differ — assumes
/// NaN-free input, confirm with callers.
///
/// # Safety
/// The CPU must support `avx512f`; guaranteed when reached through the
/// pointer installed by `TsSimdRuntime::detect`.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
unsafe fn avx512_max_f64_inner(values: &[f64]) -> f64 {
    use std::arch::x86_64::*;
    unsafe {
        // Start every lane at -inf, the identity for max.
        let mut acc = _mm512_set1_pd(f64::NEG_INFINITY);
        let chunks = values.len() / 8;
        let ptr = values.as_ptr();
        for i in 0..chunks {
            // SAFETY: i * 8 + 7 < values.len(); unaligned load is allowed.
            let v = _mm512_loadu_pd(ptr.add(i * 8));
            acc = _mm512_max_pd(acc, v);
        }
        // Reduce the 8 lanes, then fold in the remainder elements.
        let mut m = _mm512_reduce_max_pd(acc);
        for &v in &values[chunks * 8..] {
            if v > m {
                m = v;
            }
        }
        m
    }
}
210
211#[cfg(target_arch = "x86_64")]
212fn avx512_max_f64(values: &[f64]) -> f64 {
213 if values.len() < 16 {
214 return scalar_max_f64(values);
215 }
216 unsafe { avx512_max_f64_inner(values) }
217}
218
/// Placeholder: no vectorized range filter for the AVX-512 backend
/// yet; delegates to the scalar implementation so the dispatch table
/// stays uniform across backends.
#[cfg(target_arch = "x86_64")]
fn avx512_range_filter_i64(values: &[i64], min: i64, max: i64) -> Vec<u32> {
    scalar_range_filter_i64(values, min, max)
}
225
/// AVX2 sum: accumulates 4 f64 lanes per iteration, reduces the
/// 256-bit vector manually (AVX2 has no reduce-add intrinsic), then
/// adds the scalar tail. Plain (non-compensated) summation, so
/// low-order bits may differ from `scalar_sum_f64`'s Kahan result.
///
/// # Safety
/// The CPU must support `avx2`; guaranteed when reached through the
/// pointer installed by `TsSimdRuntime::detect`.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn avx2_sum_f64_inner(values: &[f64]) -> f64 {
    use std::arch::x86_64::*;
    unsafe {
        let mut acc = _mm256_setzero_pd();
        let chunks = values.len() / 4;
        let ptr = values.as_ptr();
        for i in 0..chunks {
            // SAFETY: i * 4 + 3 < values.len(); unaligned load is allowed.
            let v = _mm256_loadu_pd(ptr.add(i * 4));
            acc = _mm256_add_pd(acc, v);
        }
        // Horizontal reduction: fold the upper 128-bit half onto the
        // lower, then combine the remaining two lanes.
        let hi = _mm256_extractf128_pd(acc, 1);
        let lo = _mm256_castpd256_pd128(acc);
        let sum2 = _mm_add_pd(lo, hi);
        let hi2 = _mm_unpackhi_pd(sum2, sum2);
        let mut sum = _mm_cvtsd_f64(_mm_add_sd(sum2, hi2));
        for &v in &values[chunks * 4..] {
            sum += v;
        }
        sum
    }
}
251
252#[cfg(target_arch = "x86_64")]
253fn avx2_sum_f64(values: &[f64]) -> f64 {
254 if values.len() < 8 {
255 return scalar_sum_f64(values);
256 }
257 unsafe { avx2_sum_f64_inner(values) }
258}
259
/// AVX2 minimum: 4-lane vector min per iteration, manual horizontal
/// reduction (fold upper half onto lower, then combine the last two
/// lanes), then a `<` scan of the scalar tail.
///
/// NOTE(review): lanes combine via `_mm256_min_pd`/`_mm_min_sd` while
/// the tail uses `<`; NaN handling between paths may differ — assumes
/// NaN-free input, confirm with callers.
///
/// # Safety
/// The CPU must support `avx2`; guaranteed when reached through the
/// pointer installed by `TsSimdRuntime::detect`.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn avx2_min_f64_inner(values: &[f64]) -> f64 {
    use std::arch::x86_64::*;
    unsafe {
        // Start every lane at +inf, the identity for min.
        let mut acc = _mm256_set1_pd(f64::INFINITY);
        let chunks = values.len() / 4;
        let ptr = values.as_ptr();
        for i in 0..chunks {
            // SAFETY: i * 4 + 3 < values.len(); unaligned load is allowed.
            let v = _mm256_loadu_pd(ptr.add(i * 4));
            acc = _mm256_min_pd(acc, v);
        }
        let hi = _mm256_extractf128_pd(acc, 1);
        let lo = _mm256_castpd256_pd128(acc);
        let min2 = _mm_min_pd(lo, hi);
        let hi2 = _mm_unpackhi_pd(min2, min2);
        let mut m = _mm_cvtsd_f64(_mm_min_sd(min2, hi2));
        for &v in &values[chunks * 4..] {
            if v < m {
                m = v;
            }
        }
        m
    }
}
285
286#[cfg(target_arch = "x86_64")]
287fn avx2_min_f64(values: &[f64]) -> f64 {
288 if values.len() < 8 {
289 return scalar_min_f64(values);
290 }
291 unsafe { avx2_min_f64_inner(values) }
292}
293
/// AVX2 maximum: 4-lane vector max per iteration, manual horizontal
/// reduction (fold upper half onto lower, then combine the last two
/// lanes), then a `>` scan of the scalar tail.
///
/// NOTE(review): lanes combine via `_mm256_max_pd`/`_mm_max_sd` while
/// the tail uses `>`; NaN handling between paths may differ — assumes
/// NaN-free input, confirm with callers.
///
/// # Safety
/// The CPU must support `avx2`; guaranteed when reached through the
/// pointer installed by `TsSimdRuntime::detect`.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn avx2_max_f64_inner(values: &[f64]) -> f64 {
    use std::arch::x86_64::*;
    unsafe {
        // Start every lane at -inf, the identity for max.
        let mut acc = _mm256_set1_pd(f64::NEG_INFINITY);
        let chunks = values.len() / 4;
        let ptr = values.as_ptr();
        for i in 0..chunks {
            // SAFETY: i * 4 + 3 < values.len(); unaligned load is allowed.
            let v = _mm256_loadu_pd(ptr.add(i * 4));
            acc = _mm256_max_pd(acc, v);
        }
        let hi = _mm256_extractf128_pd(acc, 1);
        let lo = _mm256_castpd256_pd128(acc);
        let max2 = _mm_max_pd(lo, hi);
        let hi2 = _mm_unpackhi_pd(max2, max2);
        let mut m = _mm_cvtsd_f64(_mm_max_sd(max2, hi2));
        for &v in &values[chunks * 4..] {
            if v > m {
                m = v;
            }
        }
        m
    }
}
319
320#[cfg(target_arch = "x86_64")]
321fn avx2_max_f64(values: &[f64]) -> f64 {
322 if values.len() < 8 {
323 return scalar_max_f64(values);
324 }
325 unsafe { avx2_max_f64_inner(values) }
326}
327
/// Placeholder: no vectorized range filter for the AVX2 backend yet;
/// delegates to the scalar implementation so the dispatch table stays
/// uniform across backends.
#[cfg(target_arch = "x86_64")]
fn avx2_range_filter_i64(values: &[i64], min: i64, max: i64) -> Vec<u32> {
    scalar_range_filter_i64(values, min, max)
}
332
/// NEON sum: 2-lane f64 accumulation, horizontal add of the two lanes,
/// then the odd tail element. NEON is baseline on aarch64, so no
/// runtime probe is needed. Plain (non-compensated) summation, so
/// low-order bits may differ from `scalar_sum_f64`'s Kahan result.
#[cfg(target_arch = "aarch64")]
fn neon_sum_f64(values: &[f64]) -> f64 {
    use std::arch::aarch64::*;
    // Too short to amortize vector setup; use the scalar path.
    if values.len() < 4 {
        return scalar_sum_f64(values);
    }
    // SAFETY: NEON is mandatory on aarch64; every load below stays
    // within the slice bounds (see loop comment).
    unsafe {
        let mut acc = vdupq_n_f64(0.0);
        let chunks = values.len() / 2;
        let ptr = values.as_ptr();
        for i in 0..chunks {
            // i * 2 + 1 < values.len(), so the 2-lane load is in bounds.
            let v = vld1q_f64(ptr.add(i * 2));
            acc = vaddq_f64(acc, v);
        }
        let mut sum = vgetq_lane_f64(acc, 0) + vgetq_lane_f64(acc, 1);
        for &v in &values[chunks * 2..] {
            sum += v;
        }
        sum
    }
}
356
/// NEON minimum: 2-lane vector min, lane combine via `f64::min`, then
/// a `<` scan of the odd tail element.
/// NOTE(review): `f64::min` and the `<` tail treat NaN differently —
/// assumes NaN-free input, confirm with callers.
#[cfg(target_arch = "aarch64")]
fn neon_min_f64(values: &[f64]) -> f64 {
    use std::arch::aarch64::*;
    // Too short to amortize vector setup; use the scalar path.
    if values.len() < 4 {
        return scalar_min_f64(values);
    }
    // SAFETY: NEON is mandatory on aarch64; every load below stays
    // within the slice bounds (see loop comment).
    unsafe {
        // Start both lanes at +inf, the identity for min.
        let mut acc = vdupq_n_f64(f64::INFINITY);
        let chunks = values.len() / 2;
        let ptr = values.as_ptr();
        for i in 0..chunks {
            // i * 2 + 1 < values.len(), so the 2-lane load is in bounds.
            let v = vld1q_f64(ptr.add(i * 2));
            acc = vminq_f64(acc, v);
        }
        let mut m = vgetq_lane_f64(acc, 0).min(vgetq_lane_f64(acc, 1));
        for &v in &values[chunks * 2..] {
            if v < m {
                m = v;
            }
        }
        m
    }
}
380
/// NEON maximum: 2-lane vector max, lane combine via `f64::max`, then
/// a `>` scan of the odd tail element.
/// NOTE(review): `f64::max` and the `>` tail treat NaN differently —
/// assumes NaN-free input, confirm with callers.
#[cfg(target_arch = "aarch64")]
fn neon_max_f64(values: &[f64]) -> f64 {
    use std::arch::aarch64::*;
    // Too short to amortize vector setup; use the scalar path.
    if values.len() < 4 {
        return scalar_max_f64(values);
    }
    // SAFETY: NEON is mandatory on aarch64; every load below stays
    // within the slice bounds (see loop comment).
    unsafe {
        // Start both lanes at -inf, the identity for max.
        let mut acc = vdupq_n_f64(f64::NEG_INFINITY);
        let chunks = values.len() / 2;
        let ptr = values.as_ptr();
        for i in 0..chunks {
            // i * 2 + 1 < values.len(), so the 2-lane load is in bounds.
            let v = vld1q_f64(ptr.add(i * 2));
            acc = vmaxq_f64(acc, v);
        }
        let mut m = vgetq_lane_f64(acc, 0).max(vgetq_lane_f64(acc, 1));
        for &v in &values[chunks * 2..] {
            if v > m {
                m = v;
            }
        }
        m
    }
}
404
/// wasm simd128 sum: 2-lane f64 accumulation, horizontal add of the
/// two lanes, then the odd tail element. Plain (non-compensated)
/// summation, so low-order bits may differ from `scalar_sum_f64`.
///
/// The raw-pointer operations are wrapped in an explicit `unsafe`
/// block (instead of relying on the implicit unsafe-fn scope) for
/// consistency with the x86 kernels and compatibility with the
/// `unsafe_op_in_unsafe_fn` lint / Rust 2024 edition.
///
/// # Safety
/// The executing engine must support simd128; `detect` installs this
/// function unconditionally on wasm32 builds.
#[cfg(target_arch = "wasm32")]
#[target_feature(enable = "simd128")]
unsafe fn wasm_sum_f64_inner(values: &[f64]) -> f64 {
    use core::arch::wasm32::*;
    let mut acc = f64x2_splat(0.0);
    let chunks = values.len() / 2;
    let ptr = values.as_ptr();
    for i in 0..chunks {
        // SAFETY: i * 2 + 1 < values.len(), so the 16-byte load stays
        // in bounds; v128_load permits unaligned addresses.
        let v = unsafe { v128_load(ptr.add(i * 2) as *const v128) };
        acc = f64x2_add(acc, v);
    }
    let mut sum = f64x2_extract_lane::<0>(acc) + f64x2_extract_lane::<1>(acc);
    for &v in &values[chunks * 2..] {
        sum += v;
    }
    sum
}
424
/// Dispatch wrapper for the simd128 sum kernel; short inputs take the
/// scalar path.
#[cfg(target_arch = "wasm32")]
fn wasm_sum_f64(values: &[f64]) -> f64 {
    if values.len() >= 4 {
        // SAFETY: the simd128 target feature is enabled on the inner
        // kernel; detect() installs this pointer on wasm32 builds.
        unsafe { wasm_sum_f64_inner(values) }
    } else {
        scalar_sum_f64(values)
    }
}
432
/// wasm simd128 minimum: 2-lane vector min, lane combine via
/// `f64::min`, then a `<` scan of the odd tail element.
/// NOTE(review): `f64::min` and the `<` tail treat NaN differently —
/// assumes NaN-free input, confirm with callers.
///
/// The raw-pointer operations are wrapped in an explicit `unsafe`
/// block for consistency with the x86 kernels and compatibility with
/// the `unsafe_op_in_unsafe_fn` lint / Rust 2024 edition.
///
/// # Safety
/// The executing engine must support simd128.
#[cfg(target_arch = "wasm32")]
#[target_feature(enable = "simd128")]
unsafe fn wasm_min_f64_inner(values: &[f64]) -> f64 {
    use core::arch::wasm32::*;
    // Start both lanes at +inf, the identity for min.
    let mut acc = f64x2_splat(f64::INFINITY);
    let chunks = values.len() / 2;
    let ptr = values.as_ptr();
    for i in 0..chunks {
        // SAFETY: i * 2 + 1 < values.len(), so the 16-byte load stays
        // in bounds; v128_load permits unaligned addresses.
        let v = unsafe { v128_load(ptr.add(i * 2) as *const v128) };
        acc = f64x2_min(acc, v);
    }
    let mut m = f64x2_extract_lane::<0>(acc).min(f64x2_extract_lane::<1>(acc));
    for &v in &values[chunks * 2..] {
        if v < m {
            m = v;
        }
    }
    m
}
452
/// Dispatch wrapper for the simd128 minimum kernel; short inputs take
/// the scalar path.
#[cfg(target_arch = "wasm32")]
fn wasm_min_f64(values: &[f64]) -> f64 {
    match values.len() {
        0..=3 => scalar_min_f64(values),
        // SAFETY: the simd128 target feature is enabled on the inner
        // kernel; detect() installs this pointer on wasm32 builds.
        _ => unsafe { wasm_min_f64_inner(values) },
    }
}
460
/// wasm simd128 maximum: 2-lane vector max, lane combine via
/// `f64::max`, then a `>` scan of the odd tail element.
/// NOTE(review): `f64::max` and the `>` tail treat NaN differently —
/// assumes NaN-free input, confirm with callers.
///
/// The raw-pointer operations are wrapped in an explicit `unsafe`
/// block for consistency with the x86 kernels and compatibility with
/// the `unsafe_op_in_unsafe_fn` lint / Rust 2024 edition.
///
/// # Safety
/// The executing engine must support simd128.
#[cfg(target_arch = "wasm32")]
#[target_feature(enable = "simd128")]
unsafe fn wasm_max_f64_inner(values: &[f64]) -> f64 {
    use core::arch::wasm32::*;
    // Start both lanes at -inf, the identity for max.
    let mut acc = f64x2_splat(f64::NEG_INFINITY);
    let chunks = values.len() / 2;
    let ptr = values.as_ptr();
    for i in 0..chunks {
        // SAFETY: i * 2 + 1 < values.len(), so the 16-byte load stays
        // in bounds; v128_load permits unaligned addresses.
        let v = unsafe { v128_load(ptr.add(i * 2) as *const v128) };
        acc = f64x2_max(acc, v);
    }
    let mut m = f64x2_extract_lane::<0>(acc).max(f64x2_extract_lane::<1>(acc));
    for &v in &values[chunks * 2..] {
        if v > m {
            m = v;
        }
    }
    m
}
480
/// Dispatch wrapper for the simd128 maximum kernel; short inputs take
/// the scalar path.
#[cfg(target_arch = "wasm32")]
fn wasm_max_f64(values: &[f64]) -> f64 {
    if values.len() >= 4 {
        // SAFETY: the simd128 target feature is enabled on the inner
        // kernel; detect() installs this pointer on wasm32 builds.
        unsafe { wasm_max_f64_inner(values) }
    } else {
        scalar_max_f64(values)
    }
}
488
#[cfg(test)]
mod tests {
    use super::*;

    /// detect() must select one of the known backends.
    #[test]
    fn runtime_detects() {
        let rt = ts_runtime();
        assert!(!rt.name.is_empty());
        assert!(
            ["avx512", "avx2", "neon", "wasm-simd128", "scalar"].contains(&rt.name),
            "unexpected runtime: {}",
            rt.name
        );
    }

    #[test]
    fn sum_correctness() {
        let rt = ts_runtime();
        let values: Vec<f64> = (0..1000).map(|i| i as f64).collect();
        // Sum of 0..=999 is 999 * 1000 / 2.
        let expected = 999.0 * 1000.0 / 2.0;
        let result = (rt.sum_f64)(&values);
        assert!(
            (result - expected).abs() < 1e-6,
            "sum: got {result}, expected {expected}"
        );
    }

    #[test]
    fn min_max_correctness() {
        let rt = ts_runtime();
        let values: Vec<f64> = (0..1000).map(|i| (i as f64) - 500.0).collect();
        assert!(((rt.min_f64)(&values) - (-500.0)).abs() < f64::EPSILON);
        assert!(((rt.max_f64)(&values) - 499.0).abs() < f64::EPSILON);
    }

    #[test]
    fn range_filter_correctness() {
        let rt = ts_runtime();
        let values: Vec<i64> = (0..100).collect();
        let indices = (rt.range_filter_i64)(&values, 25, 75);
        // Bounds are inclusive: 25..=75 is 51 elements — and the
        // result must be exactly that index sequence, in order.
        let expected: Vec<u32> = (25..=75).collect();
        assert_eq!(indices, expected);
    }

    /// Empty inputs: sum is the additive identity; min/max return the
    /// identity elements of their reductions. Previously this only
    /// checked `.is_infinite()`, which would also pass for the
    /// wrong-signed infinity — assert the exact values instead.
    #[test]
    fn empty_input() {
        let rt = ts_runtime();
        assert_eq!((rt.sum_f64)(&[]), 0.0);
        assert_eq!((rt.min_f64)(&[]), f64::INFINITY);
        assert_eq!((rt.max_f64)(&[]), f64::NEG_INFINITY);
        assert!((rt.range_filter_i64)(&[], 0, 100).is_empty());
    }

    /// Inputs below every backend's SIMD threshold must still be
    /// handled (via the scalar fallback paths).
    #[test]
    fn small_input() {
        let rt = ts_runtime();
        assert!(((rt.sum_f64)(&[1.0, 2.0, 3.0]) - 6.0).abs() < f64::EPSILON);
        assert!(((rt.min_f64)(&[3.0, 1.0, 2.0]) - 1.0).abs() < f64::EPSILON);
        assert!(((rt.max_f64)(&[1.0, 3.0, 2.0]) - 3.0).abs() < f64::EPSILON);
    }
}
551}