1#[cfg(target_arch = "x86_64")]
19pub mod x86_64;
20
21#[cfg(target_arch = "aarch64")]
22pub mod aarch64;
23
24#[cfg(target_arch = "wasm32")]
25pub mod wasm32;
26
27pub mod complex;
28pub mod dispatch;
29pub mod multiver;
30pub mod scalar;
31
32use crate::scalar::{Field, Real, Scalar};
33
/// SIMD capability level, ordered from narrowest to widest register.
///
/// The derived `Ord` gives `Scalar < Simd128 < Simd256 < Simd512`, which is
/// what makes "clamp detected level to a cap" comparisons meaningful
/// (see `detect_simd_level`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum SimdLevel {
    /// No SIMD: process one element at a time.
    Scalar,
    /// 128-bit registers (SSE2, NEON, wasm `simd128`).
    Simd128,
    /// 256-bit registers (AVX2 + FMA).
    Simd256,
    /// 512-bit registers (AVX-512F).
    Simd512,
}
46
47impl SimdLevel {
48 #[inline]
50 pub const fn lanes<T: Scalar>(self) -> usize {
51 match self {
52 SimdLevel::Scalar => 1,
53 SimdLevel::Simd128 => 16 / core::mem::size_of::<T>(),
54 SimdLevel::Simd256 => 32 / core::mem::size_of::<T>(),
55 SimdLevel::Simd512 => 64 / core::mem::size_of::<T>(),
56 }
57 }
58
59 #[inline]
61 pub const fn width_bytes(self) -> usize {
62 match self {
63 SimdLevel::Scalar => 8, SimdLevel::Simd128 => 16,
65 SimdLevel::Simd256 => 32,
66 SimdLevel::Simd512 => 64,
67 }
68 }
69}
70
71#[inline]
78pub fn detect_simd_level() -> SimdLevel {
79 #[cfg(feature = "force-scalar")]
81 {
82 SimdLevel::Scalar
83 }
84
85 #[cfg(not(feature = "force-scalar"))]
86 {
87 let detected = detect_simd_level_raw();
88
89 #[cfg(feature = "max-simd-128")]
91 {
92 return if detected > SimdLevel::Simd128 {
93 SimdLevel::Simd128
94 } else {
95 detected
96 };
97 }
98
99 #[cfg(feature = "max-simd-256")]
100 #[cfg(not(feature = "max-simd-128"))]
101 {
102 return if detected > SimdLevel::Simd256 {
103 SimdLevel::Simd256
104 } else {
105 detected
106 };
107 }
108
109 #[cfg(not(any(feature = "max-simd-128", feature = "max-simd-256")))]
110 {
111 detected
112 }
113 }
114}
115
116#[inline]
121pub fn detect_simd_level_raw() -> SimdLevel {
122 #[cfg(all(target_arch = "x86_64", feature = "std"))]
124 {
125 if is_x86_feature_detected!("avx512f") {
126 SimdLevel::Simd512
127 } else if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
128 SimdLevel::Simd256
129 } else if is_x86_feature_detected!("sse2") {
130 SimdLevel::Simd128
131 } else {
132 SimdLevel::Scalar
133 }
134 }
135
136 #[cfg(all(target_arch = "x86_64", not(feature = "std")))]
138 {
139 #[cfg(target_feature = "avx512f")]
140 {
141 SimdLevel::Simd512
142 }
143 #[cfg(all(
144 target_feature = "avx2",
145 target_feature = "fma",
146 not(target_feature = "avx512f")
147 ))]
148 {
149 SimdLevel::Simd256
150 }
151 #[cfg(all(
152 target_feature = "sse2",
153 not(target_feature = "avx2"),
154 not(target_feature = "avx512f")
155 ))]
156 {
157 SimdLevel::Simd128
158 }
159 #[cfg(not(any(
160 target_feature = "sse2",
161 target_feature = "avx2",
162 target_feature = "avx512f"
163 )))]
164 {
165 SimdLevel::Scalar
166 }
167 }
168
169 #[cfg(target_arch = "aarch64")]
170 {
171 SimdLevel::Simd128
173 }
174
175 #[cfg(target_arch = "wasm32")]
176 {
177 #[cfg(target_feature = "simd128")]
179 {
180 SimdLevel::Simd128
181 }
182 #[cfg(not(target_feature = "simd128"))]
183 {
184 SimdLevel::Scalar
185 }
186 }
187
188 #[cfg(not(any(
189 target_arch = "x86_64",
190 target_arch = "aarch64",
191 target_arch = "wasm32"
192 )))]
193 {
194 SimdLevel::Scalar
195 }
196}
197
/// Scalars that have SIMD register implementations at the 256- and 512-bit
/// levels.
///
/// NOTE(review): no 128-bit associated type is declared here — presumably
/// 128-bit paths are handled per-arch; confirm against the dispatch module.
pub trait SimdScalar: Field {
    /// 256-bit register type holding `LANES_256` elements of `Self`.
    type Simd256: SimdRegister<Scalar = Self>;
    /// 512-bit register type holding `LANES_512` elements of `Self`.
    type Simd512: SimdRegister<Scalar = Self>;

    /// Lanes in a 256-bit register (32 bytes / element size).
    const LANES_256: usize = 32 / core::mem::size_of::<Self>();

    /// Lanes in a 512-bit register (64 bytes / element size).
    const LANES_512: usize = 64 / core::mem::size_of::<Self>();
}
214
/// A SIMD register of `LANES` elements of `Scalar`: lanewise arithmetic,
/// fused multiply ops, horizontal reductions, and raw loads/stores.
pub trait SimdRegister: Copy + Clone + Send + Sync {
    /// Element type stored in each lane.
    type Scalar: SimdScalar;

    /// Number of lanes held by this register type.
    const LANES: usize;

    /// Register with every lane set to zero.
    fn zero() -> Self;

    /// Register with every lane set to `value`.
    fn splat(value: Self::Scalar) -> Self;

    /// Load `LANES` elements starting at `ptr`.
    ///
    /// # Safety
    /// `ptr` must be valid for reads of `LANES` elements and — per the
    /// method name — presumably aligned to the register width; confirm
    /// against each arch implementation.
    unsafe fn load_aligned(ptr: *const Self::Scalar) -> Self;

    /// Load `LANES` elements starting at `ptr`, with no register-width
    /// alignment requirement.
    ///
    /// # Safety
    /// `ptr` must be valid for reads of `LANES` elements.
    unsafe fn load_unaligned(ptr: *const Self::Scalar) -> Self;

    /// Store all lanes to `ptr`.
    ///
    /// # Safety
    /// `ptr` must be valid for writes of `LANES` elements and — per the
    /// method name — presumably aligned to the register width; confirm
    /// against each arch implementation.
    unsafe fn store_aligned(self, ptr: *mut Self::Scalar);

    /// Store all lanes to `ptr`, with no register-width alignment requirement.
    ///
    /// # Safety
    /// `ptr` must be valid for writes of `LANES` elements.
    unsafe fn store_unaligned(self, ptr: *mut Self::Scalar);

    /// Lanewise addition.
    fn add(self, other: Self) -> Self;

    /// Lanewise subtraction (`self - other`).
    fn sub(self, other: Self) -> Self;

    /// Lanewise multiplication.
    fn mul(self, other: Self) -> Self;

    /// Lanewise division (`self / other`).
    fn div(self, other: Self) -> Self;

    /// Fused multiply-add: `self * a + b`.
    fn mul_add(self, a: Self, b: Self) -> Self;

    /// Fused multiply-subtract: `self * a - b`.
    fn mul_sub(self, a: Self, b: Self) -> Self;

    /// Negated multiply-add: `-(self * a) + b`.
    fn neg_mul_add(self, a: Self, b: Self) -> Self;

    /// Sum of all lanes.
    fn reduce_sum(self) -> Self::Scalar;

    /// Maximum over all lanes.
    fn reduce_max(self) -> Self::Scalar
    where
        Self::Scalar: Real;

    /// Minimum over all lanes.
    fn reduce_min(self) -> Self::Scalar
    where
        Self::Scalar: Real;

    /// Read lane `index`.
    fn extract(self, index: usize) -> Self::Scalar;

    /// Return a copy with lane `index` set to `value`.
    fn insert(self, index: usize, value: Self::Scalar) -> Self;
}
298
/// Per-lane predicated (masked) operations layered on `SimdRegister`.
pub trait SimdMask: SimdRegister {
    /// Architecture-specific lane-mask type.
    type Mask: Copy + Clone;

    /// Build a mask from one `bool` per lane.
    /// NOTE(review): behavior when `bools.len() != LANES` is not specified
    /// here — confirm against each implementation.
    fn mask_from_bools(bools: &[bool]) -> Self::Mask;

    /// Load lanes selected by `mask`; presumably unselected lanes take the
    /// corresponding lane of `default` — confirm per implementation.
    ///
    /// # Safety
    /// `ptr` must be valid for reads of the lanes selected by `mask`.
    unsafe fn load_masked(ptr: *const Self::Scalar, mask: Self::Mask, default: Self) -> Self;

    /// Store only the lanes selected by `mask`.
    ///
    /// # Safety
    /// `ptr` must be valid for writes of the lanes selected by `mask`.
    unsafe fn store_masked(self, ptr: *mut Self::Scalar, mask: Self::Mask);

    /// Per-lane select between `a` and `b` under `mask` (presumably `a`
    /// where the mask is set — confirm per implementation).
    fn blend(mask: Self::Mask, a: Self, b: Self) -> Self;
}
322
/// Partition of a contiguous buffer into three ranges for SIMD processing:
/// `[0, head_end)` is the scalar head (elements before the first aligned
/// address), `[head_end, body_end)` holds whole aligned vectors, and
/// `[body_end, len)` is the scalar tail.
#[derive(Debug, Clone, Copy)]
pub struct SimdChunks {
    // Total number of elements.
    pub len: usize,
    // Lanes per vector at the chosen SIMD level.
    pub lanes: usize,
    // End (exclusive) of the scalar head.
    pub head_end: usize,
    // End (exclusive) of the vectorized body.
    pub body_end: usize,
}
335
336impl SimdChunks {
337 #[inline]
339 pub fn new<T: Scalar>(ptr: *const T, len: usize, level: SimdLevel) -> Self {
340 let lanes = level.lanes::<T>();
341 let align = level.width_bytes();
342
343 if lanes <= 1 || len < lanes * 2 {
344 return SimdChunks {
346 len,
347 lanes,
348 head_end: len,
349 body_end: len,
350 };
351 }
352
353 let addr = ptr as usize;
354 let misalign = addr % align;
355
356 let head_end = if misalign == 0 {
357 0
358 } else {
359 let elements_to_align = (align - misalign) / core::mem::size_of::<T>();
360 elements_to_align.min(len)
361 };
362
363 let remaining = len - head_end;
364 let full_vectors = remaining / lanes;
365 let body_end = head_end + full_vectors * lanes;
366
367 SimdChunks {
368 len,
369 lanes,
370 head_end,
371 body_end,
372 }
373 }
374
375 #[inline]
377 pub fn head_len(&self) -> usize {
378 self.head_end
379 }
380
381 #[inline]
383 pub fn body_len(&self) -> usize {
384 self.body_end - self.head_end
385 }
386
387 #[inline]
389 pub fn tail_len(&self) -> usize {
390 self.len - self.body_end
391 }
392
393 #[inline]
395 pub fn body_vectors(&self) -> usize {
396 self.body_len() / self.lanes
397 }
398}
399
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_detect_simd_level() {
        let level = detect_simd_level();
        println!("Detected SIMD level: {:?}", level);

        #[cfg(feature = "force-scalar")]
        {
            // The override must win regardless of hardware.
            assert_eq!(level, SimdLevel::Scalar);
            let raw = detect_simd_level_raw();
            println!("Raw hardware SIMD level: {:?}", raw);
        }

        #[cfg(not(feature = "force-scalar"))]
        {
            // SSE2 is part of the x86_64 baseline.
            #[cfg(target_arch = "x86_64")]
            assert!(level >= SimdLevel::Simd128);

            // NEON (128-bit) is mandatory on aarch64.
            #[cfg(target_arch = "aarch64")]
            assert_eq!(level, SimdLevel::Simd128);
        }
    }

    #[test]
    fn test_simd_level_lanes() {
        // width_bytes / size_of::<T>() lanes at each wide level.
        assert_eq!(SimdLevel::Simd256.lanes::<f32>(), 8);
        assert_eq!(SimdLevel::Simd256.lanes::<f64>(), 4);
        assert_eq!(SimdLevel::Simd512.lanes::<f32>(), 16);
        assert_eq!(SimdLevel::Simd512.lanes::<f64>(), 8);
    }

    #[test]
    fn test_simd_chunks() {
        let buf: Vec<f64> = vec![0.0; 100];

        let chunks = SimdChunks::new(buf.as_ptr(), 100, SimdLevel::Simd256);
        println!(
            "Chunks: head_end={}, body_end={}",
            chunks.head_end, chunks.body_end
        );

        // Head + body + tail must account for every element exactly once.
        let total = chunks.head_len() + chunks.body_len() + chunks.tail_len();
        assert_eq!(total, 100);
    }

    #[test]
    fn test_scalar_fma_accuracy() {
        use crate::simd::scalar::ScalarF64;

        // (1 + eps)^2 - (1 + 2*eps) is on the order of eps^2; both the fused
        // and the unfused forms should land well below 1e-14.
        let a = ScalarF64::splat(1.0 + 1e-15);
        let b = ScalarF64::splat(1.0 + 1e-15);
        let c = ScalarF64::splat(-(1.0 + 2e-15));

        let fused = a.mul_add(b, c);
        let unfused = a.mul(b).add(c);

        assert!(fused.0.abs() < 1e-14);
        assert!(unfused.0.abs() < 1e-14);
    }

    #[test]
    fn test_load_store_roundtrip() {
        use crate::simd::scalar::ScalarF64;

        for &x in &[42.0f64, 1.5, -3.5, 1000.0] {
            let v = ScalarF64::splat(x);
            assert_eq!(v.reduce_sum(), x);
            assert_eq!(v.extract(0), x);
        }
    }

    #[test]
    fn test_arithmetic_identities() {
        use crate::simd::scalar::{ScalarF32, ScalarF64};

        let five = ScalarF64::splat(5.0);
        let zero = ScalarF64::zero();
        let one = ScalarF64::splat(1.0);

        // Additive and multiplicative identities, plus annihilation by zero.
        assert_eq!(five.add(zero).0, 5.0);
        assert_eq!(five.sub(zero).0, 5.0);
        assert_eq!(five.mul(one).0, 5.0);
        assert_eq!(five.div(one).0, 5.0);
        assert_eq!(five.mul(zero).0, 0.0);

        // Same identities for the f32 register.
        let five32 = ScalarF32::splat(5.0);
        assert_eq!(five32.add(ScalarF32::zero()).0, 5.0);
        assert_eq!(five32.mul(ScalarF32::splat(1.0)).0, 5.0);
    }

    #[test]
    fn test_reductions() {
        use crate::simd::scalar::{ScalarF32, ScalarF64};

        // On a one-lane register every reduction returns the lone element.
        let x = ScalarF64::splat(42.0);
        assert_eq!(x.reduce_sum(), 42.0);
        assert_eq!(x.reduce_max(), 42.0);
        assert_eq!(x.reduce_min(), 42.0);

        let y = ScalarF32::splat(-3.5);
        assert_eq!(y.reduce_sum(), -3.5);
        assert_eq!(y.reduce_max(), -3.5);
        assert_eq!(y.reduce_min(), -3.5);
    }

    #[test]
    fn test_negative_values() {
        use crate::simd::scalar::ScalarF64;

        let minus_five = ScalarF64::splat(-5.0);
        let three = ScalarF64::splat(3.0);

        assert_eq!(minus_five.add(three).0, -2.0);
        assert_eq!(minus_five.mul(three).0, -15.0);
        assert_eq!(minus_five.sub(three).0, -8.0);
    }

    #[test]
    fn test_fma_variants() {
        use crate::simd::scalar::ScalarF64;

        let a = ScalarF64::splat(2.0);
        let b = ScalarF64::splat(3.0);
        let c = ScalarF64::splat(4.0);

        // a*b + c, a*b - c, -(a*b) + c.
        assert_eq!(a.mul_add(b, c).0, 10.0);
        assert_eq!(a.mul_sub(b, c).0, 2.0);
        assert_eq!(a.neg_mul_add(b, c).0, -2.0);
    }

    #[test]
    fn test_insert_extract() {
        use crate::simd::scalar::ScalarF64;

        let original = ScalarF64::splat(1.0);
        let patched = original.insert(0, 42.0);
        assert_eq!(patched.extract(0), 42.0);
    }

    #[cfg(target_arch = "aarch64")]
    #[test]
    fn test_aarch64_simd_correctness() {
        use crate::simd::aarch64::{F32x4, F64x2, F64x4};

        // Two-lane f64 NEON register.
        let a = F64x2::splat(2.0);
        let b = F64x2::splat(3.0);

        let sum = a.add(b);
        assert_eq!(sum.extract(0), 5.0);
        assert_eq!(sum.extract(1), 5.0);

        let fma = a.mul_add(b, F64x2::splat(1.0));
        assert_eq!(fma.extract(0), 7.0);

        // Four-lane f64 type (presumably a register pair — confirm).
        let c = F64x4::splat(2.0);
        let d = F64x4::splat(3.0);
        assert_eq!(c.add(d).reduce_sum(), 20.0);

        // Four-lane f32 register.
        let e = F32x4::splat(2.0);
        let f = F32x4::splat(3.0);
        assert_eq!(e.add(f).reduce_sum(), 20.0);
    }
}
613}