use crate::error::Result;
use scirs2_core::parallel_ops::*;
use scirs2_core::simd_ops::PlatformCapabilities;
/// Adds `a` and `b` element-wise, writing each sum into `result`.
///
/// The slices are walked in four-element chunks via `par_chunks` /
/// `par_chunks_mut`.
///
/// # Errors
///
/// Returns `NumRs2Error::ShapeMismatch` when the three slices do not all share
/// the same length. `expected` holds `a.len()`; `actual` holds the length of
/// the first slice (`b`, then `result`) that disagrees with it.
#[cfg(target_arch = "aarch64")]
pub fn neon_add_f32(a: &[f32], b: &[f32], result: &mut [f32]) -> Result<()> {
    if a.len() != b.len() || a.len() != result.len() {
        // Report the length that really differs from `a`. Always reporting
        // `b.len()` would give `expected == actual` when only `result` is
        // mis-sized, which makes the error useless for diagnosis.
        let mismatched = if a.len() != b.len() {
            b.len()
        } else {
            result.len()
        };
        return Err(crate::error::NumRs2Error::ShapeMismatch {
            expected: vec![a.len()],
            actual: vec![mismatched],
        });
    }
    result
        .par_chunks_mut(4)
        .zip(a.par_chunks(4))
        .zip(b.par_chunks(4))
        .for_each(|((r_chunk, a_chunk), b_chunk)| {
            // All three slices have equal length, so matching chunks do too.
            for ((out, &x), &y) in r_chunk.iter_mut().zip(a_chunk).zip(b_chunk) {
                *out = x + y;
            }
        });
    Ok(())
}
/// Multiplies `a` and `b` element-wise, writing each product into `result`.
///
/// The slices are walked in four-element chunks via `par_chunks` /
/// `par_chunks_mut`.
///
/// # Errors
///
/// Returns `NumRs2Error::ShapeMismatch` when the three slices do not all share
/// the same length. `expected` holds `a.len()`; `actual` holds the length of
/// the first slice (`b`, then `result`) that disagrees with it.
#[cfg(target_arch = "aarch64")]
pub fn neon_mul_f32(a: &[f32], b: &[f32], result: &mut [f32]) -> Result<()> {
    if a.len() != b.len() || a.len() != result.len() {
        // Pick the length that actually disagrees with `a`; reporting
        // `b.len()` unconditionally hides a mis-sized `result` buffer behind
        // an error whose expected and actual shapes are identical.
        let mismatched = if a.len() != b.len() {
            b.len()
        } else {
            result.len()
        };
        return Err(crate::error::NumRs2Error::ShapeMismatch {
            expected: vec![a.len()],
            actual: vec![mismatched],
        });
    }
    result
        .par_chunks_mut(4)
        .zip(a.par_chunks(4))
        .zip(b.par_chunks(4))
        .for_each(|((r_chunk, a_chunk), b_chunk)| {
            // All three slices have equal length, so matching chunks do too.
            r_chunk
                .iter_mut()
                .zip(a_chunk.iter().zip(b_chunk))
                .for_each(|(out, (x, y))| *out = x * y);
        });
    Ok(())
}
/// Computes the dot product of `a` and `b`.
///
/// Each pair of four-element chunks yields one partial sum of products; the
/// partial sums are then added together.
///
/// # Errors
///
/// Returns `NumRs2Error::ShapeMismatch` (expected `a.len()`, actual `b.len()`)
/// when the two slices differ in length.
#[cfg(target_arch = "aarch64")]
pub fn neon_dot_f32(a: &[f32], b: &[f32]) -> Result<f32> {
    let (lhs_len, rhs_len) = (a.len(), b.len());
    if lhs_len != rhs_len {
        return Err(crate::error::NumRs2Error::ShapeMismatch {
            expected: vec![lhs_len],
            actual: vec![rhs_len],
        });
    }
    // One partial result per chunk pair.
    let partials = a
        .par_chunks(4)
        .zip(b.par_chunks(4))
        .map(|(lhs, rhs)| lhs.iter().zip(rhs).map(|(x, y)| x * y).sum::<f32>());
    Ok(partials.sum::<f32>())
}
/// Returns the sum of every element in `data`.
///
/// The slice is split into four-element chunks with `par_chunks`; each chunk
/// is summed on its own and the per-chunk totals are then added together.
#[cfg(target_arch = "aarch64")]
pub fn neon_sum_f32(data: &[f32]) -> f32 {
    let partial_sums = data
        .par_chunks(4)
        .map(|lane| lane.iter().copied().sum::<f32>());
    partial_sums.sum()
}
/// Returns the largest element of `data`, or `None` when `data` is empty.
///
/// Each four-element chunk is folded with `f32::max` starting from
/// `f32::NEG_INFINITY`, and the per-chunk maxima are reduced the same way.
#[cfg(target_arch = "aarch64")]
pub fn neon_max_f32(data: &[f32]) -> Option<f32> {
    // An empty slice has no maximum. Without this guard the reduction would
    // hand back its identity and the caller would see `Some(NEG_INFINITY)`.
    if data.is_empty() {
        return None;
    }
    let largest = data
        .par_chunks(4)
        .map(|chunk| chunk.iter().copied().fold(f32::NEG_INFINITY, f32::max))
        .reduce(|| f32::NEG_INFINITY, f32::max);
    Some(largest)
}
/// Returns the smallest element of `data`, or `None` when `data` is empty.
///
/// Each four-element chunk is folded with `f32::min` starting from
/// `f32::INFINITY`, and the per-chunk minima are reduced the same way.
#[cfg(target_arch = "aarch64")]
pub fn neon_min_f32(data: &[f32]) -> Option<f32> {
    // An empty slice has no minimum. Without this guard the reduction would
    // hand back its identity and the caller would see `Some(INFINITY)`.
    if data.is_empty() {
        return None;
    }
    let smallest = data
        .par_chunks(4)
        .map(|chunk| chunk.iter().copied().fold(f32::INFINITY, f32::min))
        .reduce(|| f32::INFINITY, f32::min);
    Some(smallest)
}
/// Writes `exp(x)` for every `x` in `data` into the matching slot of `result`.
///
/// Input and output are walked together in four-element chunks.
///
/// # Errors
///
/// Returns `NumRs2Error::ShapeMismatch` (expected `data.len()`, actual
/// `result.len()`) when the two slices differ in length.
#[cfg(target_arch = "aarch64")]
pub fn neon_exp_f32(data: &[f32], result: &mut [f32]) -> Result<()> {
    let (input_len, output_len) = (data.len(), result.len());
    if input_len != output_len {
        return Err(crate::error::NumRs2Error::ShapeMismatch {
            expected: vec![input_len],
            actual: vec![output_len],
        });
    }
    result
        .par_chunks_mut(4)
        .zip(data.par_chunks(4))
        .for_each(|(dst, src)| {
            // Equal slice lengths mean `dst` and `src` always pair up fully.
            for (out, &x) in dst.iter_mut().zip(src) {
                *out = x.exp();
            }
        });
    Ok(())
}
/// Writes the square root of every element of `data` into the matching slot
/// of `result`.
///
/// Input and output are walked together in four-element chunks.
///
/// # Errors
///
/// Returns `NumRs2Error::ShapeMismatch` (expected `data.len()`, actual
/// `result.len()`) when the two slices differ in length.
#[cfg(target_arch = "aarch64")]
pub fn neon_sqrt_f32(data: &[f32], result: &mut [f32]) -> Result<()> {
    let (input_len, output_len) = (data.len(), result.len());
    if input_len != output_len {
        return Err(crate::error::NumRs2Error::ShapeMismatch {
            expected: vec![input_len],
            actual: vec![output_len],
        });
    }
    result
        .par_chunks_mut(4)
        .zip(data.par_chunks(4))
        .for_each(|(dst, src)| {
            // Equal slice lengths mean `dst` and `src` always pair up fully.
            dst.iter_mut()
                .zip(src)
                .for_each(|(out, x)| *out = x.sqrt());
        });
    Ok(())
}
/// Reports the `neon_available` flag from `PlatformCapabilities::detect()`.
#[cfg(target_arch = "aarch64")]
pub fn is_neon_available() -> bool {
    PlatformCapabilities::detect().neon_available
}
#[cfg(not(target_arch = "aarch64"))]
pub fn neon_add_f32(_a: &[f32], _b: &[f32], _result: &mut [f32]) -> Result<()> {
Err(crate::error::NumRs2Error::FeatureNotEnabled(
"NEON is only available on ARM64 architectures".to_string(),
))
}
#[cfg(not(target_arch = "aarch64"))]
pub fn neon_mul_f32(_a: &[f32], _b: &[f32], _result: &mut [f32]) -> Result<()> {
Err(crate::error::NumRs2Error::FeatureNotEnabled(
"NEON is only available on ARM64 architectures".to_string(),
))
}
#[cfg(not(target_arch = "aarch64"))]
pub fn neon_dot_f32(_a: &[f32], _b: &[f32]) -> Result<f32> {
Err(crate::error::NumRs2Error::FeatureNotEnabled(
"NEON is only available on ARM64 architectures".to_string(),
))
}
/// Non-ARM64 stand-in for `neon_sum_f32`: returns the sum of `data`.
///
/// The return type has no way to carry an error, and panicking inside a public
/// library function would abort callers on every non-ARM64 target. This build
/// therefore computes the sum sequentially, accumulating in four-element
/// chunks, rather than panicking.
#[cfg(not(target_arch = "aarch64"))]
pub fn neon_sum_f32(_data: &[f32]) -> f32 {
    _data
        .chunks(4)
        .map(|chunk| chunk.iter().sum::<f32>())
        .sum()
}
/// Non-ARM64 stand-in for `neon_max_f32`.
///
/// Always returns `None`; `_data` is not inspected.
#[cfg(not(target_arch = "aarch64"))]
pub fn neon_max_f32(_data: &[f32]) -> Option<f32> {
    let unsupported: Option<f32> = None;
    unsupported
}
/// Non-ARM64 stand-in for `neon_min_f32`.
///
/// Always returns `None`; `_data` is not inspected.
#[cfg(not(target_arch = "aarch64"))]
pub fn neon_min_f32(_data: &[f32]) -> Option<f32> {
    let unsupported: Option<f32> = None;
    unsupported
}
#[cfg(not(target_arch = "aarch64"))]
pub fn neon_exp_f32(_data: &[f32], _result: &mut [f32]) -> Result<()> {
Err(crate::error::NumRs2Error::FeatureNotEnabled(
"NEON is only available on ARM64 architectures".to_string(),
))
}
#[cfg(not(target_arch = "aarch64"))]
pub fn neon_sqrt_f32(_data: &[f32], _result: &mut [f32]) -> Result<()> {
Err(crate::error::NumRs2Error::FeatureNotEnabled(
"NEON is only available on ARM64 architectures".to_string(),
))
}
/// Non-ARM64 stand-in for `is_neon_available`; always reports `false`.
#[cfg(not(target_arch = "aarch64"))]
pub fn is_neon_available() -> bool {
    // This item only exists when the target is not aarch64, so the
    // compile-time check below is always `false` here.
    cfg!(target_arch = "aarch64")
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Off ARM64 the detection must report `false`; on ARM64 the detected
    /// value is only printed.
    #[test]
    fn test_neon_availability() {
        let detected = is_neon_available();
        #[cfg(target_arch = "aarch64")]
        {
            println!("NEON is available: {}", detected);
        }
        #[cfg(not(target_arch = "aarch64"))]
        {
            assert!(
                !detected,
                "NEON should not be available on non-ARM architectures"
            );
        }
    }

    /// Element-wise addition over two full four-element chunks.
    #[test]
    #[cfg(target_arch = "aarch64")]
    fn test_neon_add_f32() {
        // The body only runs when NEON was detected at runtime.
        if is_neon_available() {
            let lhs = [1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
            let rhs = [8.0_f32, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0];
            let mut sums = [0.0_f32; 8];
            neon_add_f32(&lhs, &rhs, &mut sums)
                .expect("neon_add_f32 should succeed for equal length arrays");
            assert_eq!(sums, [9.0_f32; 8]);
        }
    }

    /// Element-wise multiplication over a single four-element chunk.
    #[test]
    #[cfg(target_arch = "aarch64")]
    fn test_neon_mul_f32() {
        if is_neon_available() {
            let lhs = [1.0_f32, 2.0, 3.0, 4.0];
            let rhs = [2.0_f32; 4];
            let mut products = [0.0_f32; 4];
            neon_mul_f32(&lhs, &rhs, &mut products)
                .expect("neon_mul_f32 should succeed for equal length arrays");
            assert_eq!(products, [2.0_f32, 4.0, 6.0, 8.0]);
        }
    }

    /// Dot product against an all-ones vector equals the plain sum.
    #[test]
    #[cfg(target_arch = "aarch64")]
    fn test_neon_dot_f32() {
        if is_neon_available() {
            let lhs = [1.0_f32, 2.0, 3.0, 4.0];
            let ones = [1.0_f32; 4];
            let dot = neon_dot_f32(&lhs, &ones)
                .expect("neon_dot_f32 should succeed for equal length arrays");
            assert_eq!(dot, 10.0);
        }
    }

    /// Sum over a length that is not a multiple of the chunk size.
    #[test]
    #[cfg(target_arch = "aarch64")]
    fn test_neon_sum_f32() {
        if is_neon_available() {
            let values = [1.0_f32, 2.0, 3.0, 4.0, 5.0];
            let total = neon_sum_f32(&values);
            assert_eq!(total, 15.0);
        }
    }

    /// Maximum and minimum of the same eight-element input.
    #[test]
    #[cfg(target_arch = "aarch64")]
    fn test_neon_max_min_f32() {
        if is_neon_available() {
            let values = [3.0_f32, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0, 6.0];
            let largest = neon_max_f32(&values)
                .expect("neon_max_f32 should return a value for non-empty array");
            let smallest = neon_min_f32(&values)
                .expect("neon_min_f32 should return a value for non-empty array");
            assert_eq!(largest, 9.0);
            assert_eq!(smallest, 1.0);
        }
    }
}