#[cfg(target_arch = "x86_64")]
pub mod x86_64;
#[cfg(target_arch = "aarch64")]
pub mod aarch64;
#[cfg(target_arch = "wasm32")]
pub mod wasm32;
pub mod complex;
pub mod dispatch;
pub mod multiver;
pub mod scalar;
use crate::scalar::{Field, Real, Scalar};
/// SIMD register-width tiers, declared from narrowest to widest.
///
/// The derived `Ord` follows declaration order, so
/// `Scalar < Simd128 < Simd256 < Simd512`; capping a level is a plain `min`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum SimdLevel {
    /// No vector registers: one element at a time.
    Scalar,
    /// 16-byte (128-bit) registers.
    Simd128,
    /// 32-byte (256-bit) registers.
    Simd256,
    /// 64-byte (512-bit) registers.
    Simd512,
}

impl SimdLevel {
    /// Register width in bytes for this level.
    ///
    /// NOTE(review): `Scalar` reports 8 bytes, which looks like a single `f64`-sized slot
    /// rather than a real register width — confirm whether any caller relies on it.
    #[inline]
    pub const fn width_bytes(self) -> usize {
        match self {
            Self::Scalar => 8,
            Self::Simd128 => 16,
            Self::Simd256 => 32,
            Self::Simd512 => 64,
        }
    }

    /// Number of `T` elements that fit in one register at this level.
    ///
    /// `Scalar` always reports a single lane. Vector levels divide the register width by
    /// `size_of::<T>()` with truncation, so a `T` wider than the register yields `0`.
    #[inline]
    pub const fn lanes<T: Scalar>(self) -> usize {
        match self {
            Self::Scalar => 1,
            Self::Simd128 | Self::Simd256 | Self::Simd512 => {
                self.width_bytes() / core::mem::size_of::<T>()
            }
        }
    }
}
#[inline]
pub fn detect_simd_level() -> SimdLevel {
#[cfg(feature = "force-scalar")]
{
SimdLevel::Scalar
}
#[cfg(not(feature = "force-scalar"))]
{
let detected = detect_simd_level_raw();
#[cfg(feature = "max-simd-128")]
{
return if detected > SimdLevel::Simd128 {
SimdLevel::Simd128
} else {
detected
};
}
#[cfg(feature = "max-simd-256")]
#[cfg(not(feature = "max-simd-128"))]
{
return if detected > SimdLevel::Simd256 {
SimdLevel::Simd256
} else {
detected
};
}
#[cfg(not(any(feature = "max-simd-128", feature = "max-simd-256")))]
{
detected
}
}
}
/// Detects the widest SIMD tier available on this target, ignoring crate feature caps.
///
/// * x86_64 with the `std` feature: probes the running CPU with `is_x86_feature_detected!`.
///   `Simd512` needs AVX-512F; `Simd256` needs both AVX2 and FMA; `Simd128` needs SSE2.
/// * x86_64 without `std`: no runtime probe is used, so the answer comes from the
///   compile-time target features, with the same precedence as the runtime path.
/// * aarch64: always `Simd128`.
/// * wasm32: `Simd128` when compiled with the `simd128` target feature, else `Scalar`.
/// * any other architecture: `Scalar`.
#[inline]
pub fn detect_simd_level_raw() -> SimdLevel {
    #[cfg(all(target_arch = "x86_64", feature = "std"))]
    {
        if is_x86_feature_detected!("avx512f") {
            SimdLevel::Simd512
        } else if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
            SimdLevel::Simd256
        } else if is_x86_feature_detected!("sse2") {
            SimdLevel::Simd128
        } else {
            SimdLevel::Scalar
        }
    }
    #[cfg(all(target_arch = "x86_64", not(feature = "std")))]
    {
        // An exhaustive if/else over compile-time `cfg!` booleans, rather than
        // mutually-exclusive `#[cfg]` blocks, so every feature combination (e.g. AVX2
        // enabled without FMA) still resolves to a level instead of an empty block.
        if cfg!(target_feature = "avx512f") {
            SimdLevel::Simd512
        } else if cfg!(all(target_feature = "avx2", target_feature = "fma")) {
            SimdLevel::Simd256
        } else if cfg!(target_feature = "sse2") {
            SimdLevel::Simd128
        } else {
            SimdLevel::Scalar
        }
    }
    #[cfg(target_arch = "aarch64")]
    {
        SimdLevel::Simd128
    }
    #[cfg(target_arch = "wasm32")]
    {
        if cfg!(target_feature = "simd128") {
            SimdLevel::Simd128
        } else {
            SimdLevel::Scalar
        }
    }
    #[cfg(not(any(
        target_arch = "x86_64",
        target_arch = "aarch64",
        target_arch = "wasm32"
    )))]
    {
        SimdLevel::Scalar
    }
}
/// A [`Field`] element type that has register types bound for the 256- and 512-bit tiers.
///
/// The lane-count constants default to how many `Self` values fit in 256 / 512 bits
/// (truncating integer division by `size_of::<Self>()`); implementors may override them.
pub trait SimdScalar: Field {
    /// Default lane count at 256 bits: 256 bits = 32 bytes, divided by the element size.
    const LANES_256: usize = (256 / 8) / core::mem::size_of::<Self>();
    /// Default lane count at 512 bits: 512 bits = 64 bytes, divided by the element size.
    const LANES_512: usize = (512 / 8) / core::mem::size_of::<Self>();
    /// Register type used at the 256-bit tier; its lanes hold `Self`.
    type Simd256: SimdRegister<Scalar = Self>;
    /// Register type used at the 512-bit tier; its lanes hold `Self`.
    type Simd512: SimdRegister<Scalar = Self>;
}
/// A fixed-width register holding `LANES` elements of type `Self::Scalar`.
///
/// Binary arithmetic operates lane-wise; the `reduce_*` methods fold every lane into a
/// single scalar. Whether `mul_add`-style operations are fused (single rounding) is up to
/// each implementation and is not established here.
pub trait SimdRegister: Copy + Clone + Send + Sync {
    /// Element type stored in each lane.
    type Scalar: SimdScalar;
    /// Number of `Scalar` lanes in one register.
    const LANES: usize;
    /// Register whose lanes are all zero.
    fn zero() -> Self;
    /// Register with every lane set to `value`.
    fn splat(value: Self::Scalar) -> Self;
    /// Loads a register from memory starting at `ptr`.
    ///
    /// # Safety
    ///
    /// NOTE(review): presumably `ptr` must be valid for reading `LANES` consecutive
    /// elements and aligned to the register width (the aligned-body split computed by
    /// `SimdChunks` suggests this) — confirm against each implementation.
    unsafe fn load_aligned(ptr: *const Self::Scalar) -> Self;
    /// Loads a register from memory starting at `ptr`, with no extra alignment demand.
    ///
    /// # Safety
    ///
    /// NOTE(review): presumably `ptr` must be valid for reading `LANES` consecutive
    /// elements — confirm against each implementation.
    unsafe fn load_unaligned(ptr: *const Self::Scalar) -> Self;
    /// Stores all lanes to memory starting at `ptr`.
    ///
    /// # Safety
    ///
    /// NOTE(review): presumably `ptr` must be valid for writing `LANES` consecutive
    /// elements and aligned to the register width — confirm against each implementation.
    unsafe fn store_aligned(self, ptr: *mut Self::Scalar);
    /// Stores all lanes to memory starting at `ptr`, with no extra alignment demand.
    ///
    /// # Safety
    ///
    /// NOTE(review): presumably `ptr` must be valid for writing `LANES` consecutive
    /// elements — confirm against each implementation.
    unsafe fn store_unaligned(self, ptr: *mut Self::Scalar);
    /// Lane-wise `self + other`.
    fn add(self, other: Self) -> Self;
    /// Lane-wise `self - other`.
    fn sub(self, other: Self) -> Self;
    /// Lane-wise `self * other`.
    fn mul(self, other: Self) -> Self;
    /// Lane-wise `self / other`.
    fn div(self, other: Self) -> Self;
    /// Lane-wise `self * a + b`.
    fn mul_add(self, a: Self, b: Self) -> Self;
    /// Lane-wise `self * a - b`.
    fn mul_sub(self, a: Self, b: Self) -> Self;
    /// Lane-wise `b - self * a`.
    fn neg_mul_add(self, a: Self, b: Self) -> Self;
    /// Sum of all lanes.
    fn reduce_sum(self) -> Self::Scalar;
    /// Largest lane; only available for real-valued scalars.
    ///
    /// NOTE(review): NaN handling is implementation-defined here — confirm.
    fn reduce_max(self) -> Self::Scalar
    where
        Self::Scalar: Real;
    /// Smallest lane; only available for real-valued scalars.
    ///
    /// NOTE(review): NaN handling is implementation-defined here — confirm.
    fn reduce_min(self) -> Self::Scalar
    where
        Self::Scalar: Real;
    /// Value of lane `index`.
    ///
    /// NOTE(review): behaviour for `index >= LANES` (panic vs. unspecified) is not
    /// established by this trait — confirm against implementations.
    fn extract(self, index: usize) -> Self::Scalar;
    /// Copy of `self` with lane `index` replaced by `value`.
    ///
    /// NOTE(review): behaviour for `index >= LANES` is not established by this trait.
    fn insert(self, index: usize, value: Self::Scalar) -> Self;
}
/// A [`SimdRegister`] that also supports per-lane masked memory access and selection.
pub trait SimdMask: SimdRegister {
    /// Per-lane mask type; its representation is implementation-defined.
    type Mask: Copy + Clone;
    /// Builds a mask from per-lane flags.
    ///
    /// NOTE(review): presumably `bools[i]` controls lane `i`; behaviour when
    /// `bools.len() != LANES` is not specified here — confirm against implementations.
    fn mask_from_bools(bools: &[bool]) -> Self::Mask;
    /// Masked load from `ptr`.
    ///
    /// # Safety
    ///
    /// NOTE(review): presumably only lanes selected by `mask` are read from `ptr` (which
    /// must be valid for those reads) and the other lanes come from `default` — confirm
    /// against each implementation.
    unsafe fn load_masked(ptr: *const Self::Scalar, mask: Self::Mask, default: Self) -> Self;
    /// Masked store to `ptr`.
    ///
    /// # Safety
    ///
    /// NOTE(review): presumably only lanes selected by `mask` are written, so `ptr` must be
    /// valid for writing those lanes — confirm against each implementation.
    unsafe fn store_masked(self, ptr: *mut Self::Scalar, mask: Self::Mask);
    /// Lane-wise choice between `a` and `b` under `mask`.
    ///
    /// NOTE(review): which operand a set mask lane selects is not established here — confirm.
    fn blend(mask: Self::Mask, a: Self, b: Self) -> Self;
}
/// Split of a `len`-element buffer into three consecutive index ranges for a SIMD loop:
///
/// * head `0..head_end` — leading elements before the first register-aligned address,
/// * body `head_end..body_end` — a whole number of `lanes`-wide vectors,
/// * tail `body_end..len` — the leftover elements after the last full vector.
///
/// Built by [`SimdChunks::new`]. For single-lane levels or inputs shorter than two
/// vectors, `new` places every element in the head (empty body and tail).
#[derive(Debug, Clone, Copy)]
pub struct SimdChunks {
    /// Total number of elements covered by the split.
    pub len: usize,
    /// Elements per vector at the level the split was built for.
    pub lanes: usize,
    /// Exclusive end of the head, i.e. the first body index.
    pub head_end: usize,
    /// Exclusive end of the body, i.e. the first tail index.
    pub body_end: usize,
}
impl SimdChunks {
    /// Plans how to walk `len` elements starting at `ptr` with registers of `level`.
    ///
    /// The head runs up to the first element whose address is a multiple of
    /// `level.width_bytes()`. The body covers as many whole `lanes`-wide vectors as fit
    /// after that, and the tail holds the remainder.
    ///
    /// Every element is placed in the head (empty body and tail) when any of these hold:
    ///
    /// * the level has at most one lane for `T`;
    /// * `len` is shorter than two vectors;
    /// * `size_of::<T>()` does not divide the register width;
    /// * no whole-element step from `ptr` lands on a register boundary (e.g. a 16-byte
    ///   element sitting 8 bytes past a 32-byte boundary).
    ///
    /// Reporting a body in the last two cases would hand callers a region that is not
    /// actually aligned.
    ///
    /// `ptr` is only inspected for its address; it is never dereferenced.
    #[inline]
    pub fn new<T: Scalar>(ptr: *const T, len: usize, level: SimdLevel) -> Self {
        let lanes = level.lanes::<T>();
        let align = level.width_bytes();
        let all_head = SimdChunks {
            len,
            lanes,
            head_end: len,
            body_end: len,
        };
        if lanes <= 1 || len < lanes * 2 {
            return all_head;
        }
        // `lanes >= 2` implies `0 < elem_size <= align / 2`, so nothing below divides by zero.
        let elem_size = core::mem::size_of::<T>();
        if align % elem_size != 0 {
            // Vectors would span a non-multiple of `align`; consecutive ones can't all be aligned.
            return all_head;
        }
        let misalign = ptr as usize % align;
        // Bytes from `ptr` to the next `align` boundary (0 when already aligned).
        let gap = (align - misalign) % align;
        if gap % elem_size != 0 {
            // Stepping whole elements never reaches the boundary.
            return all_head;
        }
        let head_end = (gap / elem_size).min(len);
        let full_vectors = (len - head_end) / lanes;
        let body_end = head_end + full_vectors * lanes;
        SimdChunks {
            len,
            lanes,
            head_end,
            body_end,
        }
    }

    /// Number of elements in the head (`0..head_end`).
    #[inline]
    pub fn head_len(&self) -> usize {
        self.head_end
    }

    /// Number of elements in the body (`head_end..body_end`).
    #[inline]
    pub fn body_len(&self) -> usize {
        self.body_end - self.head_end
    }

    /// Number of elements in the tail (`body_end..len`).
    #[inline]
    pub fn tail_len(&self) -> usize {
        self.len - self.body_end
    }

    /// Number of whole vectors in the body.
    ///
    /// Returns `0` when `lanes` is `0` (an element wider than the register) instead of
    /// dividing by zero.
    #[inline]
    pub fn body_vectors(&self) -> usize {
        self.body_len().checked_div(self.lanes).unwrap_or(0)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_detect_simd_level() {
        let level = detect_simd_level();
        println!("Detected SIMD level: {:?}", level);
        #[cfg(feature = "force-scalar")]
        {
            assert_eq!(level, SimdLevel::Scalar);
            let raw = detect_simd_level_raw();
            println!("Raw hardware SIMD level: {:?}", raw);
        }
        #[cfg(not(feature = "force-scalar"))]
        {
            // The `max-simd-*` features may only lower the hardware level, never raise it.
            assert!(level <= detect_simd_level_raw());
            #[cfg(target_arch = "x86_64")]
            assert!(level >= SimdLevel::Simd128);
            #[cfg(target_arch = "aarch64")]
            assert_eq!(level, SimdLevel::Simd128);
        }
    }

    #[test]
    fn test_simd_level_lanes() {
        assert_eq!(SimdLevel::Scalar.lanes::<f64>(), 1);
        assert_eq!(SimdLevel::Scalar.lanes::<f32>(), 1);
        assert_eq!(SimdLevel::Simd128.lanes::<f64>(), 2);
        assert_eq!(SimdLevel::Simd128.lanes::<f32>(), 4);
        assert_eq!(SimdLevel::Simd256.lanes::<f64>(), 4);
        assert_eq!(SimdLevel::Simd256.lanes::<f32>(), 8);
        assert_eq!(SimdLevel::Simd512.lanes::<f64>(), 8);
        assert_eq!(SimdLevel::Simd512.lanes::<f32>(), 16);
    }

    #[test]
    fn test_simd_level_width_bytes() {
        assert_eq!(SimdLevel::Simd128.width_bytes(), 16);
        assert_eq!(SimdLevel::Simd256.width_bytes(), 32);
        assert_eq!(SimdLevel::Simd512.width_bytes(), 64);
    }

    #[test]
    fn test_simd_chunks() {
        let data: Vec<f64> = vec![0.0; 100];
        let ptr = data.as_ptr();
        let chunks = SimdChunks::new(ptr, 100, SimdLevel::Simd256);
        println!(
            "Chunks: head_end={}, body_end={}",
            chunks.head_end, chunks.body_end
        );
        assert_eq!(
            chunks.head_len() + chunks.body_len() + chunks.tail_len(),
            100
        );
    }

    #[test]
    fn test_simd_chunks_body_is_aligned() {
        let data: Vec<f64> = vec![0.0; 100];
        let width = SimdLevel::Simd256.width_bytes();
        for offset in 0..4 {
            let slice = &data[offset..];
            let chunks = SimdChunks::new(slice.as_ptr(), slice.len(), SimdLevel::Simd256);
            assert_eq!(
                chunks.head_len() + chunks.body_len() + chunks.tail_len(),
                slice.len()
            );
            assert!(chunks.head_len() < chunks.lanes);
            assert!(chunks.tail_len() < chunks.lanes);
            assert_eq!(chunks.body_vectors() * chunks.lanes, chunks.body_len());
            assert_eq!(slice[chunks.head_end..].as_ptr() as usize % width, 0);
        }
    }

    #[test]
    fn test_simd_chunks_degenerate_inputs() {
        let data = [1.0f64; 10];
        let scalar = SimdChunks::new(data.as_ptr(), data.len(), SimdLevel::Scalar);
        assert_eq!(scalar.head_len(), 10);
        assert_eq!(scalar.body_vectors(), 0);
        let short = SimdChunks::new(data.as_ptr(), 3, SimdLevel::Simd256);
        assert_eq!((short.head_len(), short.body_len(), short.tail_len()), (3, 0, 0));
        let empty = SimdChunks::new(data.as_ptr(), 0, SimdLevel::Simd512);
        assert_eq!((empty.head_len(), empty.body_len(), empty.tail_len()), (0, 0, 0));
    }

    #[test]
    fn test_scalar_fma_accuracy() {
        use crate::simd::scalar::ScalarF64;
        let a = ScalarF64::splat(1.0 + 1e-15);
        let b = ScalarF64::splat(1.0 + 1e-15);
        let c = ScalarF64::splat(-(1.0 + 2e-15));
        let fma_result = a.mul_add(b, c);
        let mul_add_result = a.mul(b).add(c);
        // Both forms land near zero; this checks closeness only and does not
        // distinguish fused from unfused evaluation.
        assert!(fma_result.0.abs() < 1e-14);
        assert!(mul_add_result.0.abs() < 1e-14);
    }

    #[test]
    fn test_load_store_roundtrip() {
        use crate::simd::scalar::ScalarF64;
        let values = [42.0f64, 1.5, -3.5, 1000.0];
        for &val in &values {
            let v = ScalarF64::splat(val);
            assert_eq!(v.reduce_sum(), val);
            assert_eq!(v.extract(0), val);

            let src = vec![val; ScalarF64::LANES];
            let mut dst = vec![0.0f64; ScalarF64::LANES];
            // SAFETY: both buffers hold exactly `LANES` f64 values, so a register-wide
            // unaligned load from `src` and store into `dst` stay in bounds.
            let loaded = unsafe { ScalarF64::load_unaligned(src.as_ptr()) };
            // SAFETY: see above.
            unsafe { loaded.store_unaligned(dst.as_mut_ptr()) };
            assert_eq!(dst, src);
        }
    }

    #[test]
    fn test_arithmetic_identities() {
        use crate::simd::scalar::{ScalarF32, ScalarF64};
        let a = ScalarF64::splat(5.0);
        let zero = ScalarF64::zero();
        let one = ScalarF64::splat(1.0);
        assert_eq!(a.add(zero).0, 5.0);
        assert_eq!(a.sub(zero).0, 5.0);
        assert_eq!(a.mul(one).0, 5.0);
        assert_eq!(a.div(one).0, 5.0);
        assert_eq!(a.mul(zero).0, 0.0);
        let a32 = ScalarF32::splat(5.0);
        let zero32 = ScalarF32::zero();
        let one32 = ScalarF32::splat(1.0);
        assert_eq!(a32.add(zero32).0, 5.0);
        assert_eq!(a32.mul(one32).0, 5.0);
    }

    #[test]
    fn test_reductions() {
        use crate::simd::scalar::{ScalarF32, ScalarF64};
        let a = ScalarF64::splat(42.0);
        assert_eq!(a.reduce_sum(), 42.0);
        assert_eq!(a.reduce_max(), 42.0);
        assert_eq!(a.reduce_min(), 42.0);
        let b = ScalarF32::splat(-3.5);
        assert_eq!(b.reduce_sum(), -3.5);
        assert_eq!(b.reduce_max(), -3.5);
        assert_eq!(b.reduce_min(), -3.5);
    }

    #[test]
    fn test_negative_values() {
        use crate::simd::scalar::ScalarF64;
        let neg = ScalarF64::splat(-5.0);
        let pos = ScalarF64::splat(3.0);
        assert_eq!(neg.add(pos).0, -2.0);
        assert_eq!(neg.mul(pos).0, -15.0);
        assert_eq!(neg.sub(pos).0, -8.0);
    }

    #[test]
    fn test_fma_variants() {
        use crate::simd::scalar::ScalarF64;
        let a = ScalarF64::splat(2.0);
        let b = ScalarF64::splat(3.0);
        let c = ScalarF64::splat(4.0);
        assert_eq!(a.mul_add(b, c).0, 10.0);
        assert_eq!(a.mul_sub(b, c).0, 2.0);
        assert_eq!(a.neg_mul_add(b, c).0, -2.0);
    }

    #[test]
    fn test_insert_extract() {
        use crate::simd::scalar::ScalarF64;
        let a = ScalarF64::splat(1.0);
        let b = a.insert(0, 42.0);
        assert_eq!(b.extract(0), 42.0);
    }

    #[cfg(target_arch = "aarch64")]
    #[test]
    fn test_aarch64_simd_correctness() {
        use crate::simd::aarch64::{F32x4, F64x2, F64x4};
        let a = F64x2::splat(2.0);
        let b = F64x2::splat(3.0);
        let sum = a.add(b);
        assert_eq!(sum.extract(0), 5.0);
        assert_eq!(sum.extract(1), 5.0);
        let fma = a.mul_add(b, F64x2::splat(1.0));
        assert_eq!(fma.extract(0), 7.0);
        let c = F64x4::splat(2.0);
        let d = F64x4::splat(3.0);
        assert_eq!(c.add(d).reduce_sum(), 20.0);
        let e = F32x4::splat(2.0);
        let f = F32x4::splat(3.0);
        assert_eq!(e.add(f).reduce_sum(), 20.0);
    }
}