1#[cfg(target_arch = "x86_64")]
19pub mod x86_64;
20
21#[cfg(target_arch = "aarch64")]
22pub mod aarch64;
23
24#[cfg(target_arch = "wasm32")]
25pub mod wasm32;
26
27pub mod complex;
28pub mod dispatch;
29pub mod scalar;
30
31use crate::scalar::{Field, Real, Scalar};
32
/// Width class of the SIMD registers available at runtime.
///
/// Variants are declared in increasing width order so the derived `Ord`
/// lets callers compare/clamp levels (used by `detect_simd_level`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum SimdLevel {
    /// No vectorization: one element at a time.
    Scalar,
    /// 128-bit registers (x86_64 SSE2, aarch64 NEON, wasm32 simd128).
    Simd128,
    /// 256-bit registers (x86_64 AVX2 + FMA).
    Simd256,
    /// 512-bit registers (x86_64 AVX-512F).
    Simd512,
}
45
46impl SimdLevel {
47 #[inline]
49 pub const fn lanes<T: Scalar>(self) -> usize {
50 match self {
51 SimdLevel::Scalar => 1,
52 SimdLevel::Simd128 => 16 / core::mem::size_of::<T>(),
53 SimdLevel::Simd256 => 32 / core::mem::size_of::<T>(),
54 SimdLevel::Simd512 => 64 / core::mem::size_of::<T>(),
55 }
56 }
57
58 #[inline]
60 pub const fn width_bytes(self) -> usize {
61 match self {
62 SimdLevel::Scalar => 8, SimdLevel::Simd128 => 16,
64 SimdLevel::Simd256 => 32,
65 SimdLevel::Simd512 => 64,
66 }
67 }
68}
69
70#[inline]
77pub fn detect_simd_level() -> SimdLevel {
78 #[cfg(feature = "force-scalar")]
80 {
81 SimdLevel::Scalar
82 }
83
84 #[cfg(not(feature = "force-scalar"))]
85 {
86 let detected = detect_simd_level_raw();
87
88 #[cfg(feature = "max-simd-128")]
90 {
91 return if detected > SimdLevel::Simd128 {
92 SimdLevel::Simd128
93 } else {
94 detected
95 };
96 }
97
98 #[cfg(feature = "max-simd-256")]
99 #[cfg(not(feature = "max-simd-128"))]
100 {
101 return if detected > SimdLevel::Simd256 {
102 SimdLevel::Simd256
103 } else {
104 detected
105 };
106 }
107
108 #[cfg(not(any(feature = "max-simd-128", feature = "max-simd-256")))]
109 {
110 detected
111 }
112 }
113}
114
/// Detect the widest SIMD level supported by the current hardware,
/// ignoring any crate feature caps (see `detect_simd_level` for the
/// capped variant).
#[inline]
pub fn detect_simd_level_raw() -> SimdLevel {
    // x86_64: runtime CPUID-based detection. 512-bit requires only the
    // AVX-512 foundation subset; 256-bit requires AVX2 *and* FMA.
    // NOTE(review): SSE2 is part of the x86_64 baseline, so the final
    // Scalar branch is presumably unreachable on this arch — kept as a
    // defensive fallback.
    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx512f") {
            SimdLevel::Simd512
        } else if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
            SimdLevel::Simd256
        } else if is_x86_feature_detected!("sse2") {
            SimdLevel::Simd128
        } else {
            SimdLevel::Scalar
        }
    }

    // aarch64: 128-bit NEON is assumed unconditionally (no runtime probe).
    #[cfg(target_arch = "aarch64")]
    {
        SimdLevel::Simd128
    }

    // wasm32: simd128 is a compile-time target feature, not a runtime one.
    #[cfg(target_arch = "wasm32")]
    {
        #[cfg(target_feature = "simd128")]
        {
            SimdLevel::Simd128
        }
        #[cfg(not(target_feature = "simd128"))]
        {
            SimdLevel::Scalar
        }
    }

    // Any other architecture: scalar fallback.
    #[cfg(not(any(
        target_arch = "x86_64",
        target_arch = "aarch64",
        target_arch = "wasm32"
    )))]
    {
        SimdLevel::Scalar
    }
}
162
/// A scalar field type that has SIMD register types associated with it.
pub trait SimdScalar: Field {
    /// 256-bit-class register holding this scalar.
    type Simd256: SimdRegister<Scalar = Self>;
    /// 512-bit-class register holding this scalar.
    type Simd512: SimdRegister<Scalar = Self>;

    /// Lane count of [`Self::Simd256`] (32 bytes / element size).
    const LANES_256: usize = 32 / core::mem::size_of::<Self>();

    /// Lane count of [`Self::Simd512`] (64 bytes / element size).
    const LANES_512: usize = 64 / core::mem::size_of::<Self>();
}
179
/// A SIMD register holding `LANES` elements of type [`Self::Scalar`].
pub trait SimdRegister: Copy + Clone + Send + Sync {
    /// Element type stored in each lane.
    type Scalar: SimdScalar;

    /// Number of lanes in this register type.
    const LANES: usize;

    /// Register with every lane set to zero.
    fn zero() -> Self;

    /// Broadcast `value` into every lane.
    fn splat(value: Self::Scalar) -> Self;

    /// Load `LANES` consecutive elements from `ptr`.
    ///
    /// # Safety
    /// `ptr` must be valid for reads of `LANES` elements and aligned to
    /// the register width (see [`SimdChunks`] for computing aligned
    /// regions).
    unsafe fn load_aligned(ptr: *const Self::Scalar) -> Self;

    /// Load `LANES` consecutive elements from `ptr` with no alignment
    /// requirement beyond that of `Scalar`.
    ///
    /// # Safety
    /// `ptr` must be valid for reads of `LANES` elements.
    unsafe fn load_unaligned(ptr: *const Self::Scalar) -> Self;

    /// Store all lanes to `LANES` consecutive elements at `ptr`.
    ///
    /// # Safety
    /// `ptr` must be valid for writes of `LANES` elements and aligned
    /// to the register width.
    unsafe fn store_aligned(self, ptr: *mut Self::Scalar);

    /// Store all lanes to `ptr` with no alignment requirement beyond
    /// that of `Scalar`.
    ///
    /// # Safety
    /// `ptr` must be valid for writes of `LANES` elements.
    unsafe fn store_unaligned(self, ptr: *mut Self::Scalar);

    /// Lane-wise addition.
    fn add(self, other: Self) -> Self;

    /// Lane-wise subtraction (`self - other`).
    fn sub(self, other: Self) -> Self;

    /// Lane-wise multiplication.
    fn mul(self, other: Self) -> Self;

    /// Lane-wise division (`self / other`).
    fn div(self, other: Self) -> Self;

    /// Fused multiply-add: `self * a + b` per lane.
    fn mul_add(self, a: Self, b: Self) -> Self;

    /// Fused multiply-subtract: `self * a - b` per lane.
    fn mul_sub(self, a: Self, b: Self) -> Self;

    /// Negated multiply-add: `b - self * a` per lane.
    fn neg_mul_add(self, a: Self, b: Self) -> Self;

    /// Sum of all lanes.
    fn reduce_sum(self) -> Self::Scalar;

    /// Maximum over all lanes (real-valued scalars only).
    fn reduce_max(self) -> Self::Scalar
    where
        Self::Scalar: Real;

    /// Minimum over all lanes (real-valued scalars only).
    fn reduce_min(self) -> Self::Scalar
    where
        Self::Scalar: Real;

    /// Read the lane at `index`.
    fn extract(self, index: usize) -> Self::Scalar;

    /// Return a copy of `self` with the lane at `index` replaced by `value`.
    fn insert(self, index: usize, value: Self::Scalar) -> Self;
}
263
/// Masked-lane extensions for SIMD registers.
pub trait SimdMask: SimdRegister {
    /// Per-lane mask type.
    type Mask: Copy + Clone;

    /// Build a mask from one `bool` per lane.
    fn mask_from_bools(bools: &[bool]) -> Self::Mask;

    /// Load lanes selected by `mask`, taking the remaining lanes from
    /// `default`.
    ///
    /// # Safety
    /// `ptr` must be valid for reads of the selected lanes.
    unsafe fn load_masked(ptr: *const Self::Scalar, mask: Self::Mask, default: Self) -> Self;

    /// Store only the lanes selected by `mask`.
    ///
    /// # Safety
    /// `ptr` must be valid for writes of the selected lanes.
    unsafe fn store_masked(self, ptr: *mut Self::Scalar, mask: Self::Mask);

    /// Per-lane select between `a` and `b`.
    // NOTE(review): presumably `a` where the mask is set and `b`
    // elsewhere — confirm against the per-arch implementations.
    fn blend(mask: Self::Mask, a: Self, b: Self) -> Self;
}
287
/// Partition of a contiguous run of elements into a scalar head, an
/// aligned SIMD body, and a scalar tail.
#[derive(Debug, Clone, Copy)]
pub struct SimdChunks {
    /// Total number of elements covered.
    pub len: usize,
    /// Lane count used for the body.
    pub lanes: usize,
    /// Exclusive end of the scalar head (`0..head_end`).
    pub head_end: usize,
    /// Exclusive end of the SIMD body (`head_end..body_end`);
    /// the scalar tail is `body_end..len`.
    pub body_end: usize,
}
300
301impl SimdChunks {
302 #[inline]
304 pub fn new<T: Scalar>(ptr: *const T, len: usize, level: SimdLevel) -> Self {
305 let lanes = level.lanes::<T>();
306 let align = level.width_bytes();
307
308 if lanes <= 1 || len < lanes * 2 {
309 return SimdChunks {
311 len,
312 lanes,
313 head_end: len,
314 body_end: len,
315 };
316 }
317
318 let addr = ptr as usize;
319 let misalign = addr % align;
320
321 let head_end = if misalign == 0 {
322 0
323 } else {
324 let elements_to_align = (align - misalign) / core::mem::size_of::<T>();
325 elements_to_align.min(len)
326 };
327
328 let remaining = len - head_end;
329 let full_vectors = remaining / lanes;
330 let body_end = head_end + full_vectors * lanes;
331
332 SimdChunks {
333 len,
334 lanes,
335 head_end,
336 body_end,
337 }
338 }
339
340 #[inline]
342 pub fn head_len(&self) -> usize {
343 self.head_end
344 }
345
346 #[inline]
348 pub fn body_len(&self) -> usize {
349 self.body_end - self.head_end
350 }
351
352 #[inline]
354 pub fn tail_len(&self) -> usize {
355 self.len - self.body_end
356 }
357
358 #[inline]
360 pub fn body_vectors(&self) -> usize {
361 self.body_len() / self.lanes
362 }
363}
364
/// Monomorphize `$body` over the register type selected by `$level`.
///
/// Binds the type alias `$reg` to the [`SimdScalar`] associated register
/// type for `$scalar` and evaluates `$body` with it in scope.
///
/// NOTE(review): `Simd128` is dispatched to the `Simd256` register type,
/// and `Simd512` falls back to `Simd256` off x86_64. The `Scalar` arm
/// binds `$reg` to the scalar type itself — presumably the scalar types
/// implement `SimdRegister` (see the `scalar` module); confirm against
/// the per-arch implementations.
#[macro_export]
macro_rules! simd_dispatch {
    ($level:expr, $scalar:ty, |$reg:ident| $body:expr) => {{
        match $level {
            $crate::simd::SimdLevel::Simd512 => {
                // 512-bit registers only exist on x86_64.
                #[cfg(target_arch = "x86_64")]
                {
                    type $reg = <$scalar as $crate::simd::SimdScalar>::Simd512;
                    $body
                }
                #[cfg(not(target_arch = "x86_64"))]
                {
                    type $reg = <$scalar as $crate::simd::SimdScalar>::Simd256;
                    $body
                }
            }
            $crate::simd::SimdLevel::Simd256 | $crate::simd::SimdLevel::Simd128 => {
                type $reg = <$scalar as $crate::simd::SimdScalar>::Simd256;
                $body
            }
            $crate::simd::SimdLevel::Scalar => {
                type $reg = $scalar;
                $body
            }
        }
    }};
}
398
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_detect_simd_level() {
        let level = detect_simd_level();
        println!("Detected SIMD level: {:?}", level);

        #[cfg(feature = "force-scalar")]
        {
            // Forcing scalar must win over whatever the hardware offers.
            assert_eq!(level, SimdLevel::Scalar);
            let raw = detect_simd_level_raw();
            println!("Raw hardware SIMD level: {:?}", raw);
        }

        #[cfg(not(feature = "force-scalar"))]
        {
            // x86_64 always has at least SSE2; aarch64 always reports NEON.
            #[cfg(target_arch = "x86_64")]
            assert!(level >= SimdLevel::Simd128);

            #[cfg(target_arch = "aarch64")]
            assert_eq!(level, SimdLevel::Simd128);
        }
    }

    #[test]
    fn test_simd_level_lanes() {
        // lanes = register width in bytes / element size.
        assert_eq!(SimdLevel::Simd256.lanes::<f64>(), 4);
        assert_eq!(SimdLevel::Simd512.lanes::<f64>(), 8);
        assert_eq!(SimdLevel::Simd256.lanes::<f32>(), 8);
        assert_eq!(SimdLevel::Simd512.lanes::<f32>(), 16);
    }

    #[test]
    fn test_simd_chunks() {
        let buf: Vec<f64> = vec![0.0; 100];

        let chunks = SimdChunks::new(buf.as_ptr(), 100, SimdLevel::Simd256);
        println!(
            "Chunks: head_end={}, body_end={}",
            chunks.head_end, chunks.body_end
        );

        // Head, body, and tail must partition the slice exactly.
        assert_eq!(
            chunks.head_len() + chunks.body_len() + chunks.tail_len(),
            100
        );
    }

    #[test]
    fn test_scalar_fma_accuracy() {
        use crate::simd::scalar::ScalarF64;

        // (1+e)^2 - (1+2e) is ~e^2; both fused and unfused forms should
        // land very close to zero for e = 1e-15.
        let a = ScalarF64::splat(1.0 + 1e-15);
        let b = ScalarF64::splat(1.0 + 1e-15);
        let c = ScalarF64::splat(-(1.0 + 2e-15));

        let fused = a.mul_add(b, c);
        let unfused = a.mul(b).add(c);

        assert!(fused.0.abs() < 1e-14);
        assert!(unfused.0.abs() < 1e-14);
    }

    #[test]
    fn test_load_store_roundtrip() {
        use crate::simd::scalar::ScalarF64;

        // A splat must round-trip through reduce_sum and extract.
        for &val in &[42.0f64, 1.5, -3.5, 1000.0] {
            let reg = ScalarF64::splat(val);
            assert_eq!(reg.reduce_sum(), val);
            assert_eq!(reg.extract(0), val);
        }
    }

    #[test]
    fn test_arithmetic_identities() {
        use crate::simd::scalar::{ScalarF32, ScalarF64};

        let x = ScalarF64::splat(5.0);
        let zero = ScalarF64::zero();
        let one = ScalarF64::splat(1.0);

        // Additive and multiplicative identities, plus annihilation.
        assert_eq!(x.add(zero).0, 5.0);
        assert_eq!(x.sub(zero).0, 5.0);
        assert_eq!(x.mul(one).0, 5.0);
        assert_eq!(x.div(one).0, 5.0);
        assert_eq!(x.mul(zero).0, 0.0);

        // Same identities for the f32 flavor.
        let y = ScalarF32::splat(5.0);
        let zero32 = ScalarF32::zero();
        let one32 = ScalarF32::splat(1.0);

        assert_eq!(y.add(zero32).0, 5.0);
        assert_eq!(y.mul(one32).0, 5.0);
    }

    #[test]
    fn test_reductions() {
        use crate::simd::scalar::{ScalarF32, ScalarF64};

        // With a single lane, every reduction returns the lane value.
        let r64 = ScalarF64::splat(42.0);
        assert_eq!(r64.reduce_sum(), 42.0);
        assert_eq!(r64.reduce_max(), 42.0);
        assert_eq!(r64.reduce_min(), 42.0);

        let r32 = ScalarF32::splat(-3.5);
        assert_eq!(r32.reduce_sum(), -3.5);
        assert_eq!(r32.reduce_max(), -3.5);
        assert_eq!(r32.reduce_min(), -3.5);
    }

    #[test]
    fn test_negative_values() {
        use crate::simd::scalar::ScalarF64;

        let minus_five = ScalarF64::splat(-5.0);
        let three = ScalarF64::splat(3.0);

        assert_eq!(minus_five.add(three).0, -2.0);
        assert_eq!(minus_five.mul(three).0, -15.0);
        assert_eq!(minus_five.sub(three).0, -8.0);
    }

    #[test]
    fn test_fma_variants() {
        use crate::simd::scalar::ScalarF64;

        let a = ScalarF64::splat(2.0);
        let b = ScalarF64::splat(3.0);
        let c = ScalarF64::splat(4.0);

        // mul_add:     self * a + b  =  2*3 + 4
        assert_eq!(a.mul_add(b, c).0, 10.0);
        // mul_sub:     self * a - b  =  2*3 - 4
        assert_eq!(a.mul_sub(b, c).0, 2.0);
        // neg_mul_add: b - self * a  =  4 - 2*3
        assert_eq!(a.neg_mul_add(b, c).0, -2.0);
    }

    #[test]
    fn test_insert_extract() {
        use crate::simd::scalar::ScalarF64;

        let base = ScalarF64::splat(1.0);
        let patched = base.insert(0, 42.0);
        assert_eq!(patched.extract(0), 42.0);
    }

    #[cfg(target_arch = "aarch64")]
    #[test]
    fn test_aarch64_simd_correctness() {
        use crate::simd::aarch64::{F32x4, F64x2, F64x4};

        // 2-lane f64: lane-wise add and FMA.
        let a = F64x2::splat(2.0);
        let b = F64x2::splat(3.0);

        let sum = a.add(b);
        assert_eq!(sum.extract(0), 5.0);
        assert_eq!(sum.extract(1), 5.0);

        let fma = a.mul_add(b, F64x2::splat(1.0));
        assert_eq!(fma.extract(0), 7.0);

        // 4-lane f64: reduce over all lanes.
        let c = F64x4::splat(2.0);
        let d = F64x4::splat(3.0);
        assert_eq!(c.add(d).reduce_sum(), 20.0);

        // 4-lane f32: reduce over all lanes.
        let e = F32x4::splat(2.0);
        let f = F32x4::splat(3.0);
        assert_eq!(e.add(f).reduce_sum(), 20.0);
    }
}