pub mod vectors;
pub mod matrix;
pub mod energy;
pub use vectors::{dot_product_simd, norm_squared_simd, subtract_simd, scale_simd};
pub use matrix::{matmul_simd, matvec_simd};
pub use energy::{
batch_residuals_simd, weighted_energy_sum_simd, batch_lane_assignment_simd,
batch_residual_norms_simd,
};
/// A SIMD instruction-set level that code in this module can be asked to use.
///
/// The enum is `#[repr(u8)]` with explicitly pinned discriminants, so casting
/// a value with `as u8` yields a stable one-byte tag.
#[repr(u8)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum SimdWidth {
    /// No vector instructions; one element is processed at a time.
    Scalar = 0,
    /// x86-64 SSE4.2.
    Sse42 = 1,
    /// x86-64 AVX2.
    Avx2 = 2,
    /// x86-64 AVX-512.
    Avx512 = 3,
    /// AArch64 NEON.
    Neon = 4,
}
impl SimdWidth {
    /// Number of `f32` elements handled per vector at this level.
    ///
    /// [`SimdWidth::Scalar`] reports a single lane.
    #[inline]
    pub const fn lanes_f32(self) -> usize {
        match self {
            Self::Avx512 => 16,
            Self::Avx2 => 8,
            Self::Neon | Self::Sse42 => 4,
            Self::Scalar => 1,
        }
    }

    /// Number of `f64` elements handled per vector at this level.
    ///
    /// For every vector level this is half of [`SimdWidth::lanes_f32`];
    /// [`SimdWidth::Scalar`] still reports a single lane.
    #[inline]
    pub const fn lanes_f64(self) -> usize {
        match self {
            Self::Avx512 => 8,
            Self::Avx2 => 4,
            Self::Neon | Self::Sse42 => 2,
            Self::Scalar => 1,
        }
    }

    /// Reports whether this level can be used by the running program.
    ///
    /// [`SimdWidth::Scalar`] is always available. Every other level needs both
    /// the matching target architecture and a positive answer from its
    /// feature probe.
    #[inline]
    pub fn is_supported(self) -> bool {
        // Compile-time architecture facts, named once so each arm reads as
        // "right architecture AND feature present". The `&&` keeps the probe
        // from being evaluated on a foreign architecture.
        let on_x86_64 = cfg!(target_arch = "x86_64");
        let on_aarch64 = cfg!(target_arch = "aarch64");

        match self {
            Self::Scalar => true,
            Self::Sse42 => on_x86_64 && is_sse42_supported(),
            Self::Avx2 => on_x86_64 && is_avx2_supported(),
            Self::Avx512 => on_x86_64 && is_avx512_supported(),
            Self::Neon => on_aarch64 && is_neon_supported(),
        }
    }

    /// Human-readable label for this level (for example `"AVX-512"`).
    pub const fn name(self) -> &'static str {
        match self {
            Self::Scalar => "Scalar",
            Self::Sse42 => "SSE4.2",
            Self::Avx2 => "AVX2",
            Self::Avx512 => "AVX-512",
            Self::Neon => "NEON",
        }
    }
}
/// The default level is not a fixed variant: it is whatever
/// `best_simd_width` reports at the time of the call.
impl Default for SimdWidth {
/// Returns the result of `best_simd_width()`.
fn default() -> Self {
best_simd_width()
}
}
#[inline]
pub fn best_simd_width() -> SimdWidth {
#[cfg(target_arch = "x86_64")]
{
if is_avx512_supported() {
return SimdWidth::Avx512;
}
if is_avx2_supported() {
return SimdWidth::Avx2;
}
if is_sse42_supported() {
return SimdWidth::Sse42;
}
}
#[cfg(target_arch = "aarch64")]
{
if is_neon_supported() {
return SimdWidth::Neon;
}
}
SimdWidth::Scalar
}
/// Reports whether SSE4.2 can be used.
///
/// On x86-64 the answer is `true` when the `sse4.2` target feature was enabled
/// at compile time; otherwise the CPU is queried at run time. On every other
/// architecture the answer is `false`.
#[inline]
fn is_sse42_supported() -> bool {
    #[cfg(target_arch = "x86_64")]
    {
        // `cfg!` is a compile-time constant, so a build with the feature
        // enabled never reaches the runtime probe.
        cfg!(target_feature = "sse4.2") || std::arch::is_x86_feature_detected!("sse4.2")
    }
    #[cfg(not(target_arch = "x86_64"))]
    {
        false
    }
}
/// Reports whether AVX2 can be used.
///
/// On x86-64 the answer is `true` when the `avx2` target feature was enabled
/// at compile time; otherwise the CPU is queried at run time. On every other
/// architecture the answer is `false`.
#[inline]
fn is_avx2_supported() -> bool {
    #[cfg(target_arch = "x86_64")]
    {
        // `cfg!` is a compile-time constant, so a build with the feature
        // enabled never reaches the runtime probe.
        cfg!(target_feature = "avx2") || std::arch::is_x86_feature_detected!("avx2")
    }
    #[cfg(not(target_arch = "x86_64"))]
    {
        false
    }
}
/// Reports whether AVX-512 can be used.
///
/// Only the `avx512f` feature is consulted. On x86-64 the answer is `true`
/// when that target feature was enabled at compile time; otherwise the CPU is
/// queried at run time. On every other architecture the answer is `false`.
#[inline]
fn is_avx512_supported() -> bool {
    #[cfg(target_arch = "x86_64")]
    {
        // `cfg!` is a compile-time constant, so a build with the feature
        // enabled never reaches the runtime probe.
        cfg!(target_feature = "avx512f") || std::arch::is_x86_feature_detected!("avx512f")
    }
    #[cfg(not(target_arch = "x86_64"))]
    {
        false
    }
}
/// Reports whether NEON can be used.
///
/// This is decided purely by the compilation target: `true` when building for
/// AArch64, `false` for every other architecture. No runtime probe is made.
#[inline]
fn is_neon_supported() -> bool {
    cfg!(target_arch = "aarch64")
}
/// A SIMD level bundled with its per-vector lane counts.
///
/// All fields are public, so the lane counts are only guaranteed to agree
/// with `width` for values produced by the constructors in this file.
#[derive(Debug, Clone)]
pub struct SimdContext {
/// The SIMD level this context describes.
pub width: SimdWidth,
/// `f32` lanes per vector; the constructors set this to `width.lanes_f32()`.
pub f32_lanes: usize,
/// `f64` lanes per vector; the constructors set this to `width.lanes_f64()`.
pub f64_lanes: usize,
}
impl SimdContext {
    /// Builds a context for the level chosen by [`best_simd_width`].
    pub fn new() -> Self {
        Self::from_width_unchecked(best_simd_width())
    }

    /// Builds a context for an explicitly requested level.
    ///
    /// # Panics
    ///
    /// Panics if `width.is_supported()` is `false`.
    pub fn with_width(width: SimdWidth) -> Self {
        if !width.is_supported() {
            panic!("SIMD width {:?} not supported", width);
        }
        Self::from_width_unchecked(width)
    }

    /// Returns a process-wide context that is created on first use via
    /// [`SimdContext::new`] and shared by all later calls.
    pub fn global() -> &'static SimdContext {
        static SHARED: once_cell::sync::Lazy<SimdContext> =
            once_cell::sync::Lazy::new(SimdContext::new);
        &SHARED
    }

    /// Fills in the lane counts for `width` without checking availability.
    fn from_width_unchecked(width: SimdWidth) -> Self {
        let f32_lanes = width.lanes_f32();
        let f64_lanes = width.lanes_f64();
        Self {
            width,
            f32_lanes,
            f64_lanes,
        }
    }
}
/// The default context is the one built by `SimdContext::new`.
impl Default for SimdContext {
/// Returns the result of `Self::new()`.
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every variant, so table-style tests cannot silently skip one.
    const ALL_WIDTHS: [SimdWidth; 5] = [
        SimdWidth::Scalar,
        SimdWidth::Sse42,
        SimdWidth::Avx2,
        SimdWidth::Avx512,
        SimdWidth::Neon,
    ];

    #[test]
    fn test_simd_width_detection() {
        let width = best_simd_width();
        println!("Detected SIMD width: {:?}", width);
        assert!(width.is_supported());
        // `Default` is documented to delegate to the detector.
        assert_eq!(SimdWidth::default(), width);
    }

    #[test]
    fn test_scalar_always_supported() {
        assert!(SimdWidth::Scalar.is_supported());
    }

    #[test]
    fn test_best_width_is_widest_supported() {
        let best = best_simd_width();
        for width in ALL_WIDTHS.iter().copied().filter(|w| w.is_supported()) {
            assert!(
                width.lanes_f32() <= best.lanes_f32(),
                "{:?} is supported but wider than the detected best {:?}",
                width,
                best
            );
        }
    }

    #[test]
    fn test_simd_lanes() {
        assert_eq!(SimdWidth::Scalar.lanes_f32(), 1);
        assert_eq!(SimdWidth::Sse42.lanes_f32(), 4);
        assert_eq!(SimdWidth::Avx2.lanes_f32(), 8);
        assert_eq!(SimdWidth::Avx512.lanes_f32(), 16);
        assert_eq!(SimdWidth::Neon.lanes_f32(), 4);
    }

    #[test]
    fn test_simd_lanes_f64() {
        assert_eq!(SimdWidth::Scalar.lanes_f64(), 1);
        assert_eq!(SimdWidth::Sse42.lanes_f64(), 2);
        assert_eq!(SimdWidth::Avx2.lanes_f64(), 4);
        assert_eq!(SimdWidth::Avx512.lanes_f64(), 8);
        assert_eq!(SimdWidth::Neon.lanes_f64(), 2);
    }

    #[test]
    fn test_simd_names() {
        assert_eq!(SimdWidth::Scalar.name(), "Scalar");
        assert_eq!(SimdWidth::Sse42.name(), "SSE4.2");
        assert_eq!(SimdWidth::Avx2.name(), "AVX2");
        assert_eq!(SimdWidth::Avx512.name(), "AVX-512");
        assert_eq!(SimdWidth::Neon.name(), "NEON");
    }

    #[test]
    fn test_simd_context() {
        let ctx = SimdContext::new();
        assert!(ctx.width.is_supported());
        assert_eq!(ctx.f32_lanes, ctx.width.lanes_f32());
        assert_eq!(ctx.f64_lanes, ctx.width.lanes_f64());
    }

    #[test]
    fn test_default_context_matches_new() {
        let from_default = SimdContext::default();
        let from_new = SimdContext::new();
        assert_eq!(from_default.width, from_new.width);
        assert_eq!(from_default.f32_lanes, from_new.f32_lanes);
        assert_eq!(from_default.f64_lanes, from_new.f64_lanes);
    }

    #[test]
    fn test_with_width_scalar() {
        let ctx = SimdContext::with_width(SimdWidth::Scalar);
        assert_eq!(ctx.width, SimdWidth::Scalar);
        assert_eq!(ctx.f32_lanes, 1);
        assert_eq!(ctx.f64_lanes, 1);
    }

    #[test]
    #[should_panic(expected = "not supported")]
    fn test_with_width_rejects_unsupported() {
        // Pick a level that belongs to a different architecture than the one
        // the tests are compiled for, so it is guaranteed to be unavailable.
        let foreign = if cfg!(target_arch = "aarch64") {
            SimdWidth::Avx2
        } else {
            SimdWidth::Neon
        };
        let _ = SimdContext::with_width(foreign);
    }

    #[test]
    fn test_global_context() {
        let ctx1 = SimdContext::global();
        let ctx2 = SimdContext::global();
        assert_eq!(ctx1.width, ctx2.width);
        // Both calls must hand out the same shared instance, not equal copies.
        assert!(std::ptr::eq(ctx1, ctx2));
    }
}