use super::feature_detect::{detect_cpu_features, CpuFeatures};
use super::simd_select::{select_simd_implementation, SimdImplementation};
use crate::array::Array;
use crate::error::{NumRs2Error, Result};
use std::sync::OnceLock;
#[cfg(target_arch = "x86_64")]
use super::avx2_enhanced::EnhancedSimdOps;
#[cfg(all(target_arch = "x86_64", feature = "unstable"))]
use super::avx512_enhanced::Avx2EnhancedOps;
#[cfg(target_arch = "aarch64")]
use super::neon_enhanced::NeonEnhancedOps;
/// Runtime dispatcher that routes `f32` array operations to a SIMD backend.
///
/// An instance holds the CPU features reported by `detect_cpu_features` and
/// the backend chosen from them by `select_simd_implementation`; every
/// `optimized_*` method matches on that stored backend.
// NOTE(review): the 64-byte alignment is presumably meant to match a cache
// line — confirm against the original design intent.
#[repr(align(64))]
pub struct UnifiedSimdDispatcher {
/// CPU capabilities captured when the dispatcher was built.
features: CpuFeatures,
/// Backend selected from `features`; drives every dispatch `match`.
implementation: SimdImplementation,
}
impl UnifiedSimdDispatcher {
    /// Builds a dispatcher for the machine the code is running on.
    ///
    /// Calls `detect_cpu_features` once and passes the result to
    /// `select_simd_implementation`; both values are stored and reused by
    /// every `optimized_*` method.
    pub fn new() -> Self {
        let features = detect_cpu_features();
        let implementation = select_simd_implementation(&features);
        Self {
            features,
            implementation,
        }
    }

    /// Describes the backend this dispatcher selected.
    ///
    /// `supports_avx512` reflects whether the *selected* implementation is
    /// `SimdImplementation::AVX512`, not merely whether the CPU advertises it.
    pub fn implementation_info(&self) -> SimdImplementationInfo {
        SimdImplementationInfo {
            name: self.implementation.name(),
            vector_width: self.implementation.vector_width(),
            supports_fma: self.features.fma,
            supports_avx512: matches!(self.implementation, SimdImplementation::AVX512),
            architecture: std::env::consts::ARCH,
        }
    }

    /// Multiplies two 2-D `f32` matrices with the selected backend.
    ///
    /// # Errors
    ///
    /// * `NumRs2Error::DimensionMismatch` if `a` or `b` is not 2-D.
    /// * `NumRs2Error::ShapeMismatch` if the inner dimensions differ.
    /// * Any error returned by the backend kernel or by `Array::matmul`.
    pub fn optimized_matmul_f32(&self, a: &Array<f32>, b: &Array<f32>) -> Result<Array<f32>> {
        // `m` and `n` are consumed only by the cfg-gated SIMD arms below, so on
        // targets where none of those arms is compiled they would otherwise
        // trigger an unused-variable warning.
        #[allow(unused_variables)]
        let [m, k] = a.shape()[..] else {
            return Err(NumRs2Error::DimensionMismatch(
                "Matrix A must be 2D".to_string(),
            ));
        };
        #[allow(unused_variables)]
        let [k2, n] = b.shape()[..] else {
            return Err(NumRs2Error::DimensionMismatch(
                "Matrix B must be 2D".to_string(),
            ));
        };
        if k != k2 {
            return Err(NumRs2Error::ShapeMismatch {
                expected: vec![k],
                actual: vec![k2],
            });
        }

        // Each SIMD arm allocates its own output buffer; the fallback arm lets
        // `Array::matmul` produce the result directly instead of overwriting a
        // zero-filled matrix that was allocated for nothing.
        let product = match self.implementation {
            #[cfg(all(target_arch = "x86_64", feature = "unstable"))]
            SimdImplementation::AVX512 => {
                // NOTE(review): `Avx2EnhancedOps` is imported from the
                // `avx512_enhanced` module — the naming looks inconsistent;
                // confirm this is the intended kernel.
                let tile_size = 64;
                let mut out = Array::zeros(&[m, n]);
                Avx2EnhancedOps::avx2_matmul_f32(a, b, &mut out, tile_size)?;
                out
            }
            #[cfg(target_arch = "x86_64")]
            SimdImplementation::AVX2 => {
                let block_size = 32;
                let mut out = Array::zeros(&[m, n]);
                EnhancedSimdOps::cache_aware_matmul_f32(a, b, &mut out, block_size)?;
                out
            }
            #[cfg(target_arch = "aarch64")]
            SimdImplementation::NEON => {
                let block_size = 32;
                let mut out = Array::zeros(&[m, n]);
                NeonEnhancedOps::neon_matmul_f32(a, b, &mut out, block_size)?;
                out
            }
            _ => a.matmul(b)?,
        };
        Ok(product)
    }

    /// Element-wise `exp` using the selected backend.
    ///
    /// The AVX512 arm (compiled only with the `unstable` feature) reuses the
    /// same `EnhancedSimdOps::vectorized_exp_f32` kernel as the AVX2 arm; any
    /// other backend falls back to scalar `f32::exp`.
    pub fn optimized_exp_f32(&self, input: &Array<f32>) -> Array<f32> {
        match self.implementation {
            #[cfg(all(target_arch = "x86_64", feature = "unstable"))]
            SimdImplementation::AVX512 => EnhancedSimdOps::vectorized_exp_f32(input),
            #[cfg(target_arch = "x86_64")]
            SimdImplementation::AVX2 => EnhancedSimdOps::vectorized_exp_f32(input),
            #[cfg(target_arch = "aarch64")]
            SimdImplementation::NEON => NeonEnhancedOps::neon_exp_f32(input),
            _ => input.map(|x| x.exp()),
        }
    }

    /// Element-wise natural logarithm using the selected backend.
    ///
    /// AVX2 and NEON have dedicated kernels; everything else, including the
    /// AVX512 arm, uses scalar `f32::ln`.
    pub fn optimized_log_f32(&self, input: &Array<f32>) -> Array<f32> {
        match self.implementation {
            // NOTE(review): AVX512 uses the scalar path here although
            // `optimized_exp_f32` routes it to the AVX2 kernel — confirm
            // whether this asymmetry is intentional.
            #[cfg(all(target_arch = "x86_64", feature = "unstable"))]
            SimdImplementation::AVX512 => input.map(|x| x.ln()),
            #[cfg(target_arch = "x86_64")]
            SimdImplementation::AVX2 => EnhancedSimdOps::vectorized_log_f32(input),
            #[cfg(target_arch = "aarch64")]
            SimdImplementation::NEON => NeonEnhancedOps::neon_log_f32(input),
            _ => input.map(|x| x.ln()),
        }
    }

    /// Element-wise sine and cosine, returned as `(sin, cos)`.
    ///
    /// On AVX2 only the sine is vectorized; the cosine is computed with scalar
    /// `f32::cos`. NEON computes both in one kernel call. All other backends,
    /// including the AVX512 arm, are fully scalar.
    pub fn optimized_sin_cos_f32(&self, input: &Array<f32>) -> (Array<f32>, Array<f32>) {
        match self.implementation {
            #[cfg(all(target_arch = "x86_64", feature = "unstable"))]
            SimdImplementation::AVX512 => (input.map(|x| x.sin()), input.map(|x| x.cos())),
            #[cfg(target_arch = "x86_64")]
            SimdImplementation::AVX2 => {
                let sin_result = EnhancedSimdOps::vectorized_sin_f32_simd(input);
                let cos_result = input.map(|x| x.cos());
                (sin_result, cos_result)
            }
            #[cfg(target_arch = "aarch64")]
            SimdImplementation::NEON => NeonEnhancedOps::neon_sin_cos_f32(input),
            _ => (input.map(|x| x.sin()), input.map(|x| x.cos())),
        }
    }

    /// Sums every element of `input`.
    ///
    /// AVX2 calls `EnhancedSimdOps::simd_kahan_sum_f32`, NEON calls
    /// `NeonEnhancedOps::neon_sum_f32`, and every other backend (including the
    /// AVX512 arm) calls `Array::sum`.
    pub fn optimized_sum_f32(&self, input: &Array<f32>) -> f32 {
        match self.implementation {
            #[cfg(all(target_arch = "x86_64", feature = "unstable"))]
            SimdImplementation::AVX512 => input.sum(),
            #[cfg(target_arch = "x86_64")]
            SimdImplementation::AVX2 => EnhancedSimdOps::simd_kahan_sum_f32(input),
            #[cfg(target_arch = "aarch64")]
            SimdImplementation::NEON => NeonEnhancedOps::neon_sum_f32(input),
            _ => input.sum(),
        }
    }

    /// Dot product of two equally-shaped arrays.
    ///
    /// Only NEON has a dedicated kernel; every other backend delegates to
    /// `Array::dot`.
    ///
    /// # Errors
    ///
    /// `NumRs2Error::ShapeMismatch` when the shapes differ, plus any error the
    /// underlying kernel returns.
    pub fn optimized_dot_f32(&self, a: &Array<f32>, b: &Array<f32>) -> Result<f32> {
        if a.shape() != b.shape() {
            return Err(NumRs2Error::ShapeMismatch {
                expected: a.shape(),
                actual: b.shape(),
            });
        }
        let result = match self.implementation {
            #[cfg(target_arch = "aarch64")]
            SimdImplementation::NEON => NeonEnhancedOps::neon_dot_f32(a, b)?,
            _ => a.dot(b)?,
        };
        Ok(result)
    }

    /// Element-wise complex multiplication on split real/imaginary arrays.
    ///
    /// Returns `(real, imag)` of `(a_real + i·a_imag) · (b_real + i·b_imag)`.
    /// AVX2 and AVX512 both use `EnhancedSimdOps::complex_multiply_f32`; other
    /// backends use the scalar fallback below.
    ///
    /// # Errors
    ///
    /// In the fallback path, `NumRs2Error::ShapeMismatch` when any operand's
    /// shape differs from `a_real`'s; `actual` carries the shape of the first
    /// operand that disagrees. SIMD paths return whatever their kernel returns.
    pub fn optimized_complex_multiply_f32(
        &self,
        a_real: &Array<f32>,
        a_imag: &Array<f32>,
        b_real: &Array<f32>,
        b_imag: &Array<f32>,
    ) -> Result<(Array<f32>, Array<f32>)> {
        match self.implementation {
            #[cfg(target_arch = "x86_64")]
            SimdImplementation::AVX2 => {
                EnhancedSimdOps::complex_multiply_f32(a_real, a_imag, b_real, b_imag)
            }
            #[cfg(all(target_arch = "x86_64", feature = "unstable"))]
            SimdImplementation::AVX512 => {
                EnhancedSimdOps::complex_multiply_f32(a_real, a_imag, b_real, b_imag)
            }
            _ => {
                // All four operands must share one shape. Comparing each
                // against `a_real` accepts exactly the same inputs as a
                // pairwise check, and lets the error name the real offender.
                let reference = a_real.shape();
                for operand in [a_imag, b_real, b_imag] {
                    let actual = operand.shape();
                    if actual != reference {
                        return Err(NumRs2Error::ShapeMismatch {
                            expected: reference,
                            actual,
                        });
                    }
                }

                let a_r = a_real.to_vec();
                let a_i = a_imag.to_vec();
                let b_r = b_real.to_vec();
                let b_i = b_imag.to_vec();

                // Single pass producing both output components:
                // (ar + i·ai)(br + i·bi) = (ar·br − ai·bi) + i·(ar·bi + ai·br)
                let (c_r, c_i): (Vec<f32>, Vec<f32>) = a_r
                    .iter()
                    .zip(a_i.iter())
                    .zip(b_r.iter().zip(b_i.iter()))
                    .map(|((&ar, &ai), (&br, &bi))| (ar * br - ai * bi, ar * bi + ai * br))
                    .unzip();

                Ok((
                    Array::from_vec(c_r).reshape(&reference),
                    Array::from_vec(c_i).reshape(&reference),
                ))
            }
        }
    }

    /// Produces a copy of `src` using the selected backend.
    ///
    /// NEON copies into a freshly zeroed array of the same shape; the x86
    /// backends call `EnhancedSimdOps::simd_copy_f32`; everything else clones.
    ///
    /// # Errors
    ///
    /// Propagates any error from `NeonEnhancedOps::neon_copy_f32`.
    pub fn optimized_copy_f32(&self, src: &Array<f32>) -> Result<Array<f32>> {
        let dst = match self.implementation {
            #[cfg(target_arch = "x86_64")]
            SimdImplementation::AVX2 => EnhancedSimdOps::simd_copy_f32(src),
            #[cfg(all(target_arch = "x86_64", feature = "unstable"))]
            SimdImplementation::AVX512 => EnhancedSimdOps::simd_copy_f32(src),
            #[cfg(target_arch = "aarch64")]
            SimdImplementation::NEON => {
                let mut dst = Array::zeros(&src.shape());
                NeonEnhancedOps::neon_copy_f32(src, &mut dst)?;
                dst
            }
            _ => src.clone(),
        };
        Ok(dst)
    }

    /// Times `exp`, `sum` and `copy` over a ramp of `size` elements.
    ///
    /// Each operation is run `iterations` times and the mean wall-clock time
    /// per call (in nanoseconds) is reported along with the derived
    /// elements-per-second throughput. An `iterations` of 0 is treated as 1 so
    /// the mean is never a division by zero.
    pub fn benchmark_operations(&self, size: usize, iterations: usize) -> SimdBenchmarkResults {
        use std::hint::black_box;

        /// Mean nanoseconds per call of `op` over `rounds` calls.
        fn mean_nanos(rounds: usize, mut op: impl FnMut()) -> f64 {
            let start = std::time::Instant::now();
            for _ in 0..rounds {
                op();
            }
            start.elapsed().as_nanos() as f64 / rounds as f64
        }

        let rounds = iterations.max(1);
        let data = Array::from_vec((0..size).map(|i| i as f32).collect::<Vec<_>>());

        // `black_box` keeps the optimizer from discarding the measured work
        // because its result is otherwise unused.
        let exp_time = mean_nanos(rounds, || {
            black_box(self.optimized_exp_f32(black_box(&data)));
        });
        let sum_time = mean_nanos(rounds, || {
            black_box(self.optimized_sum_f32(black_box(&data)));
        });
        let copy_time = mean_nanos(rounds, || {
            // The copy `Result` is deliberately not inspected, only timed.
            let _ = black_box(self.optimized_copy_f32(black_box(&data)));
        });

        // elements / (ns · 1e-9) == elements / ns · 1e9
        let throughput = |time_ns: f64| size as f64 / time_ns * 1e9;

        SimdBenchmarkResults {
            implementation: self.implementation.name(),
            elements: size,
            exp_time_ns: exp_time,
            sum_time_ns: sum_time,
            copy_time_ns: copy_time,
            exp_throughput: throughput(exp_time),
            sum_throughput: throughput(sum_time),
            copy_throughput: throughput(copy_time),
        }
    }
}
/// `Default` is the same as [`UnifiedSimdDispatcher::new`].
impl Default for UnifiedSimdDispatcher {
/// Delegates to `Self::new()`.
fn default() -> Self {
Self::new()
}
}
/// Summary of the SIMD backend chosen by a dispatcher.
///
/// The [`Display`](std::fmt::Display) implementation renders the same
/// multi-line report that [`SimdImplementationInfo::print_info`] writes to
/// stdout, so the text can also be logged or embedded elsewhere.
#[derive(Debug, Clone)]
pub struct SimdImplementationInfo {
    /// Backend name.
    pub name: &'static str,
    /// Vector width of the backend; the report labels this value as bits.
    pub vector_width: usize,
    /// Whether fused multiply-add is available.
    pub supports_fma: bool,
    /// Whether the AVX-512 backend is in use.
    pub supports_avx512: bool,
    /// Target architecture string.
    pub architecture: &'static str,
}

impl SimdImplementationInfo {
    /// Prints the report to stdout, followed by a newline.
    pub fn print_info(&self) {
        println!("{}", self);
    }
}

impl std::fmt::Display for SimdImplementationInfo {
    /// Writes the report as six lines with no trailing newline.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "SIMD Implementation Info:")?;
        writeln!(f, " Name: {}", self.name)?;
        writeln!(f, " Architecture: {}", self.architecture)?;
        writeln!(f, " Vector width: {} bits", self.vector_width)?;
        writeln!(f, " FMA support: {}", self.supports_fma)?;
        write!(f, " AVX-512 support: {}", self.supports_avx512)
    }
}
/// Timings and throughputs for the exp, sum and copy operations.
///
/// The [`Display`](std::fmt::Display) implementation renders the same
/// multi-line report that [`SimdBenchmarkResults::print_summary`] writes to
/// stdout, so the text can also be logged or embedded elsewhere.
#[derive(Debug, Clone)]
pub struct SimdBenchmarkResults {
    /// Name of the backend that was measured.
    pub implementation: &'static str,
    /// Number of elements in the benchmark input.
    pub elements: usize,
    /// Time for the exponential operation, in nanoseconds.
    pub exp_time_ns: f64,
    /// Time for the sum reduction, in nanoseconds.
    pub sum_time_ns: f64,
    /// Time for the copy operation, in nanoseconds.
    pub copy_time_ns: f64,
    /// Exponential throughput, in elements per second.
    pub exp_throughput: f64,
    /// Sum-reduction throughput, in elements per second.
    pub sum_throughput: f64,
    /// Copy throughput, in elements per second.
    pub copy_throughput: f64,
}

impl SimdBenchmarkResults {
    /// Prints the report to stdout, followed by a newline.
    pub fn print_summary(&self) {
        println!("{}", self);
    }
}

impl std::fmt::Display for SimdBenchmarkResults {
    /// Writes the report: a header, the element count, then a three-line
    /// section per operation. No trailing newline is emitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "SIMD Benchmark Results ({}):", self.implementation)?;
        writeln!(f, " Elements: {}", self.elements)?;
        writeln!(f, " Exponential:")?;
        writeln!(f, " Time: {:.2} ns", self.exp_time_ns)?;
        writeln!(f, " Throughput: {:.2} elements/sec", self.exp_throughput)?;
        writeln!(f, " Sum reduction:")?;
        writeln!(f, " Time: {:.2} ns", self.sum_time_ns)?;
        writeln!(f, " Throughput: {:.2} elements/sec", self.sum_throughput)?;
        writeln!(f, " Memory copy:")?;
        writeln!(f, " Time: {:.2} ns", self.copy_time_ns)?;
        write!(f, " Throughput: {:.2} elements/sec", self.copy_throughput)
    }
}
/// Storage for the shared dispatcher; empty until the first call to
/// `global_dispatcher`.
static GLOBAL_DISPATCHER: OnceLock<UnifiedSimdDispatcher> = OnceLock::new();
/// Returns the shared dispatcher, building it with
/// `UnifiedSimdDispatcher::new` on first use and reusing it afterwards.
pub fn global_dispatcher() -> &'static UnifiedSimdDispatcher {
GLOBAL_DISPATCHER.get_or_init(UnifiedSimdDispatcher::new)
}
/// Free-function shortcuts that forward to the dispatcher returned by
/// `global_dispatcher()`, so callers do not need to hold a dispatcher.
pub mod optimized {
use super::*;
/// Forwards to `optimized_matmul_f32` on the global dispatcher.
pub fn matmul_f32(a: &Array<f32>, b: &Array<f32>) -> Result<Array<f32>> {
global_dispatcher().optimized_matmul_f32(a, b)
}
/// Forwards to `optimized_exp_f32` on the global dispatcher.
pub fn exp_f32(input: &Array<f32>) -> Array<f32> {
global_dispatcher().optimized_exp_f32(input)
}
/// Forwards to `optimized_log_f32` on the global dispatcher.
pub fn log_f32(input: &Array<f32>) -> Array<f32> {
global_dispatcher().optimized_log_f32(input)
}
/// Forwards to `optimized_sin_cos_f32` on the global dispatcher; the tuple
/// is returned unchanged.
pub fn sin_cos_f32(input: &Array<f32>) -> (Array<f32>, Array<f32>) {
global_dispatcher().optimized_sin_cos_f32(input)
}
/// Forwards to `optimized_sum_f32` on the global dispatcher.
pub fn sum_f32(input: &Array<f32>) -> f32 {
global_dispatcher().optimized_sum_f32(input)
}
/// Forwards to `optimized_dot_f32` on the global dispatcher.
pub fn dot_f32(a: &Array<f32>, b: &Array<f32>) -> Result<f32> {
global_dispatcher().optimized_dot_f32(a, b)
}
/// Forwards to `optimized_complex_multiply_f32` on the global dispatcher,
/// passing the four operands through in the same order.
pub fn complex_multiply_f32(
a_real: &Array<f32>,
a_imag: &Array<f32>,
b_real: &Array<f32>,
b_imag: &Array<f32>,
) -> Result<(Array<f32>, Array<f32>)> {
global_dispatcher().optimized_complex_multiply_f32(a_real, a_imag, b_real, b_imag)
}
/// Forwards to `optimized_copy_f32` on the global dispatcher.
pub fn copy_f32(src: &Array<f32>) -> Result<Array<f32>> {
global_dispatcher().optimized_copy_f32(src)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// A freshly built dispatcher reports a non-empty backend name.
    #[test]
    fn test_dispatcher_creation() {
        let info = UnifiedSimdDispatcher::new().implementation_info();
        println!("Dispatcher created with: {:?}", info);
        assert!(!info.name.is_empty());
    }

    /// Exercises sum, dot and exp through the `optimized` shortcuts.
    #[test]
    fn test_optimized_operations() {
        let lhs = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
        let rhs = Array::from_vec(vec![5.0, 6.0, 7.0, 8.0]);

        let total = optimized::sum_f32(&lhs);
        assert_relative_eq!(total, 10.0, epsilon = 1e-6);

        let inner_product = optimized::dot_f32(&lhs, &rhs)
            .expect("dot_f32 should succeed with equal-length vectors");
        assert_relative_eq!(inner_product, 70.0, epsilon = 1e-6);

        // NOTE(review): the dispatched exp result is only printed, never
        // asserted; the assertions below check a reference computation
        // instead. Confirm whether that is intentional.
        let exp_input = Array::from_vec(vec![0.0, 1.0]);
        let dispatched = optimized::exp_f32(&exp_input).to_vec();
        println!("exp_result values: {:?}", dispatched);
        println!("Expected: [1.0, {}]", std::f32::consts::E);

        // Reference values: the AVX2 kernel called directly on x86_64, the
        // scalar map everywhere else.
        #[cfg(target_arch = "x86_64")]
        let reference =
            crate::simd_optimize::avx2_enhanced::EnhancedSimdOps::vectorized_exp_f32(&exp_input)
                .to_vec();
        #[cfg(target_arch = "x86_64")]
        println!("Direct AVX2 result: {:?}", reference);
        #[cfg(not(target_arch = "x86_64"))]
        let reference = exp_input.map(|x| x.exp()).to_vec();

        assert_relative_eq!(reference[0], 1.0, epsilon = 1e-6);
        assert_relative_eq!(reference[1], std::f32::consts::E, epsilon = 1e-5);
    }

    /// The shared dispatcher is usable and reports a non-empty backend name.
    #[test]
    fn test_global_dispatcher() {
        let info = global_dispatcher().implementation_info();
        println!("Global dispatcher: {:?}", info);
        assert!(!info.name.is_empty());
    }

    /// A short benchmark run yields strictly positive throughputs.
    #[test]
    fn test_benchmarking() {
        let report = UnifiedSimdDispatcher::new().benchmark_operations(1000, 100);
        report.print_summary();
        for throughput in [
            report.exp_throughput,
            report.sum_throughput,
            report.copy_throughput,
        ] {
            assert!(throughput > 0.0);
        }
    }

    /// (1 + 3i)(5 + 7i) = -16 + 22i, checked on the first element.
    #[test]
    fn test_complex_multiply() {
        let lhs_real = Array::from_vec(vec![1.0, 2.0]);
        let lhs_imag = Array::from_vec(vec![3.0, 4.0]);
        let rhs_real = Array::from_vec(vec![5.0, 6.0]);
        let rhs_imag = Array::from_vec(vec![7.0, 8.0]);

        let (prod_real, prod_imag) =
            optimized::complex_multiply_f32(&lhs_real, &lhs_imag, &rhs_real, &rhs_imag)
                .expect("complex_multiply_f32 should succeed with equal-sized complex vectors");

        let first_real = prod_real.to_vec()[0];
        assert_relative_eq!(first_real, -16.0, epsilon = 1e-6);
        let first_imag = prod_imag.to_vec()[0];
        assert_relative_eq!(first_imag, 22.0, epsilon = 1e-6);
    }
}