use crate::error::{CoreError, CoreResult};
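/// Reports whether NEON SIMD is usable on the compilation target.
///
/// The check is purely compile-time: NEON is mandatory on `aarch64`, and on
/// 32-bit `arm` it is only reported when the build enables
/// `target_feature = "neon"`.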
#[inline]
pub fn is_neon_available() -> bool {
#[cfg(target_arch = "aarch64")]
{
true
}
#[cfg(all(target_arch = "arm", target_feature = "neon"))]
{
true
}
#[cfg(not(any(
target_arch = "aarch64",
all(target_arch = "arm", target_feature = "neon")
)))]
{
false
}
}
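/// Dot product of `a` and `b`: a NEON fast path on `aarch64`, with a scalar
/// fallback elsewhere.
///
/// Mismatched lengths are not treated as an error here; the scalar fallback
/// zips the slices, so the product is taken over the shorter of the two.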
#[inline]
pub fn neon_dot_product(a: &[f32], b: &[f32]) -> f32 {
if a.len() != b.len() {
return scalar_dot_product(a, b);
}
#[cfg(target_arch = "aarch64")]
{
unsafe { neon_dot_product_impl(a, b) }
}
#[cfg(not(target_arch = "aarch64"))]
{
scalar_dot_product(a, b)
}
}
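/// Four-lane NEON loop with fused multiply-add, plus a scalar tail.
///
/// # Safety
/// `a` and `b` must have equal lengths; the raw loads assume `b` is at least
/// as long as `a`.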
#[cfg(target_arch = "aarch64")]
#[inline]
unsafe fn neon_dot_product_impl(a: &[f32], b: &[f32]) -> f32 {
use std::arch::aarch64::*;
let len = a.len();
let chunks = len / 4;
let mut sum = vdupq_n_f32(0.0);
for i in 0..chunks {
let idx = i * 4;
let va = vld1q_f32(a.as_ptr().add(idx));
let vb = vld1q_f32(b.as_ptr().add(idx));
        sum = vfmaq_f32(sum, va, vb);
    }
let mut result = vaddvq_f32(sum);
for i in (chunks * 4)..len {
result += a[i] * b[i];
}
result
}
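/// Portable fallback; `zip` truncates to the shorter slice.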
#[inline]
fn scalar_dot_product(a: &[f32], b: &[f32]) -> f32 {
a.iter().zip(b).map(|(x, y)| x * y).sum()
}
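/// Element-wise addition `c[i] = a[i] + b[i]` with a NEON fast path on
/// `aarch64` and a scalar fallback elsewhere.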
pub fn neon_vec_add(a: &[f32], b: &[f32], c: &mut [f32]) -> CoreResult<()> {
    if a.len() != b.len() || a.len() != c.len() {
        return Err(CoreError::DimensionMismatch {
            expected: a.len(),
            got: if a.len() != b.len() { b.len() } else { c.len() },
        });
    }
#[cfg(target_arch = "aarch64")]
unsafe {
neon_vec_add_impl(a, b, c);
}
#[cfg(not(target_arch = "aarch64"))]
{
for i in 0..a.len() {
c[i] = a[i] + b[i];
}
}
Ok(())
}
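/// # Safety
/// All three slices must have equal lengths.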
#[cfg(target_arch = "aarch64")]
#[inline]
unsafe fn neon_vec_add_impl(a: &[f32], b: &[f32], c: &mut [f32]) {
use std::arch::aarch64::*;
let len = a.len();
let chunks = len / 4;
for i in 0..chunks {
let idx = i * 4;
let va = vld1q_f32(a.as_ptr().add(idx));
let vb = vld1q_f32(b.as_ptr().add(idx));
let vc = vaddq_f32(va, vb);
vst1q_f32(c.as_mut_ptr().add(idx), vc);
}
for i in (chunks * 4)..len {
c[i] = a[i] + b[i];
}
}
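/// Element-wise (Hadamard) product `c[i] = a[i] * b[i]` with a NEON fast
/// path on `aarch64` and a scalar fallback elsewhere.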
pub fn neon_vec_mul(a: &[f32], b: &[f32], c: &mut [f32]) -> CoreResult<()> {
    if a.len() != b.len() || a.len() != c.len() {
        return Err(CoreError::DimensionMismatch {
            expected: a.len(),
            got: if a.len() != b.len() { b.len() } else { c.len() },
        });
    }
#[cfg(target_arch = "aarch64")]
unsafe {
neon_vec_mul_impl(a, b, c);
}
#[cfg(not(target_arch = "aarch64"))]
{
for i in 0..a.len() {
c[i] = a[i] * b[i];
}
}
Ok(())
}
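/// # Safety
/// All three slices must have equal lengths.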
#[cfg(target_arch = "aarch64")]
#[inline]
unsafe fn neon_vec_mul_impl(a: &[f32], b: &[f32], c: &mut [f32]) {
use std::arch::aarch64::*;
let len = a.len();
let chunks = len / 4;
for i in 0..chunks {
let idx = i * 4;
let va = vld1q_f32(a.as_ptr().add(idx));
let vb = vld1q_f32(b.as_ptr().add(idx));
let vc = vmulq_f32(va, vb);
vst1q_f32(c.as_mut_ptr().add(idx), vc);
}
for i in (chunks * 4)..len {
c[i] = a[i] * b[i];
}
}
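/// Fused multiply-add that accumulates into `c`: `c[i] += a[i] * b[i]`.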
pub fn neon_vec_fma(a: &[f32], b: &[f32], c: &mut [f32]) -> CoreResult<()> {
    if a.len() != b.len() || a.len() != c.len() {
        return Err(CoreError::DimensionMismatch {
            expected: a.len(),
            got: if a.len() != b.len() { b.len() } else { c.len() },
        });
    }
#[cfg(target_arch = "aarch64")]
unsafe {
neon_vec_fma_impl(a, b, c);
}
#[cfg(not(target_arch = "aarch64"))]
{
for i in 0..a.len() {
c[i] += a[i] * b[i];
}
}
Ok(())
}
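/// # Safety
/// All three slices must have equal lengths.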
#[cfg(target_arch = "aarch64")]
#[inline]
unsafe fn neon_vec_fma_impl(a: &[f32], b: &[f32], c: &mut [f32]) {
use std::arch::aarch64::*;
let len = a.len();
let chunks = len / 4;
for i in 0..chunks {
let idx = i * 4;
let va = vld1q_f32(a.as_ptr().add(idx));
let vb = vld1q_f32(b.as_ptr().add(idx));
let vc = vld1q_f32(c.as_ptr().add(idx));
        let vresult = vfmaq_f32(vc, va, vb);
        vst1q_f32(c.as_mut_ptr().add(idx), vresult);
}
for i in (chunks * 4)..len {
c[i] += a[i] * b[i];
}
}
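/// Matrix-vector product `y = M * x` for a row-major `rows x cols` matrix,
/// reducing each row with `neon_dot_product`.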
pub fn neon_matvec(
matrix: &[f32],
x: &[f32],
y: &mut [f32],
rows: usize,
cols: usize,
) -> CoreResult<()> {
if matrix.len() != rows * cols {
return Err(CoreError::DimensionMismatch {
expected: rows * cols,
got: matrix.len(),
});
}
    if x.len() != cols {
        return Err(CoreError::DimensionMismatch {
            expected: cols,
            got: x.len(),
        });
    }
    if y.len() != rows {
        return Err(CoreError::DimensionMismatch {
            expected: rows,
            got: y.len(),
        });
    }
for i in 0..rows {
let row = &matrix[i * cols..(i + 1) * cols];
y[i] = neon_dot_product(row, x);
}
Ok(())
}
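/// ReLU activation `y[i] = max(x[i], 0.0)` with a NEON fast path.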
pub fn neon_relu(x: &[f32], y: &mut [f32]) -> CoreResult<()> {
if x.len() != y.len() {
return Err(CoreError::DimensionMismatch {
expected: x.len(),
got: y.len(),
});
}
#[cfg(target_arch = "aarch64")]
unsafe {
neon_relu_impl(x, y);
}
#[cfg(not(target_arch = "aarch64"))]
{
for i in 0..x.len() {
y[i] = x[i].max(0.0);
}
}
Ok(())
}
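/// # Safety
/// `x` and `y` must have equal lengths.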
#[cfg(target_arch = "aarch64")]
#[inline]
unsafe fn neon_relu_impl(x: &[f32], y: &mut [f32]) {
use std::arch::aarch64::*;
let len = x.len();
let chunks = len / 4;
let zeros = vdupq_n_f32(0.0);
for i in 0..chunks {
let idx = i * 4;
let vx = vld1q_f32(x.as_ptr().add(idx));
        let vy = vmaxq_f32(vx, zeros);
        vst1q_f32(y.as_mut_ptr().add(idx), vy);
}
for i in (chunks * 4)..len {
y[i] = x[i].max(0.0);
}
}
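/// Layer normalization without learned scale/shift:
/// `y[i] = (x[i] - mean) / sqrt(var + eps)`.
///
/// Mean and variance are accumulated scalar-wise; only the final
/// normalization pass is vectorized.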
pub fn neon_layer_norm(x: &[f32], y: &mut [f32], eps: f32) -> CoreResult<()> {
if x.len() != y.len() {
return Err(CoreError::DimensionMismatch {
expected: x.len(),
got: y.len(),
});
}
let n = x.len() as f32;
let sum: f32 = x.iter().sum();
let mean = sum / n;
let var_sum: f32 = x.iter().map(|&xi| (xi - mean).powi(2)).sum();
let variance = var_sum / n;
let std_inv = 1.0 / (variance + eps).sqrt();
#[cfg(target_arch = "aarch64")]
unsafe {
neon_layer_norm_impl(x, y, mean, std_inv);
}
#[cfg(not(target_arch = "aarch64"))]
{
for i in 0..x.len() {
y[i] = (x[i] - mean) * std_inv;
}
}
Ok(())
}
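/// # Safety
/// `x` and `y` must have equal lengths.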
#[cfg(target_arch = "aarch64")]
#[inline]
unsafe fn neon_layer_norm_impl(x: &[f32], y: &mut [f32], mean: f32, std_inv: f32) {
use std::arch::aarch64::*;
let len = x.len();
let chunks = len / 4;
let vmean = vdupq_n_f32(mean);
let vstd_inv = vdupq_n_f32(std_inv);
for i in 0..chunks {
let idx = i * 4;
let vx = vld1q_f32(x.as_ptr().add(idx));
let centered = vsubq_f32(vx, vmean);
let normalized = vmulq_f32(centered, vstd_inv);
vst1q_f32(y.as_mut_ptr().add(idx), normalized);
}
for i in (chunks * 4)..len {
y[i] = (x[i] - mean) * std_inv;
}
}
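/// In-place, numerically stable softmax.
///
/// A single online pass tracks the running maximum and rescales the running
/// sum of exponentials whenever a new maximum appears, so only one further
/// pass over the data is needed to normalize.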
pub fn neon_softmax(x: &mut [f32]) {
let n = x.len();
if n == 0 {
return;
}
let mut max_val = x[0];
let mut sum = 1.0f32;
for &xi in x.iter().skip(1) {
if xi > max_val {
sum *= (max_val - xi).exp();
sum += 1.0;
max_val = xi;
} else {
sum += (xi - max_val).exp();
}
}
let inv_sum = 1.0 / sum;
#[cfg(target_arch = "aarch64")]
unsafe {
neon_softmax_normalize_impl(x, max_val, inv_sum);
}
#[cfg(not(target_arch = "aarch64"))]
{
for val in x.iter_mut() {
*val = (*val - max_val).exp() * inv_sum;
}
}
}
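/// Normalization pass: `x[i] = exp(x[i] - max_val) * inv_sum`, four lanes at
/// a time. NEON has no vector `exp`, so the exponential is evaluated per lane.
///
/// # Safety
/// Requires NEON, which is always available on `aarch64`.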
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn neon_softmax_normalize_impl(x: &mut [f32], max_val: f32, inv_sum: f32) {
use std::arch::aarch64::*;
let len = x.len();
let chunks = len / 4;
let vmax = vdupq_n_f32(max_val);
let vinv = vdupq_n_f32(inv_sum);
for i in 0..chunks {
let idx = i * 4;
let vx = vld1q_f32(x.as_ptr().add(idx));
let vshifted = vsubq_f32(vx, vmax);
        // NEON has no vector exp; evaluate each lane scalar-wise and reload.
        let e = [
            vgetq_lane_f32(vshifted, 0).exp(),
            vgetq_lane_f32(vshifted, 1).exp(),
            vgetq_lane_f32(vshifted, 2).exp(),
            vgetq_lane_f32(vshifted, 3).exp(),
        ];
        let vexp = vld1q_f32(e.as_ptr());
let vnorm = vmulq_f32(vexp, vinv);
vst1q_f32(x.as_mut_ptr().add(idx), vnorm);
}
for val in x.iter_mut().skip(chunks * 4) {
*val = (*val - max_val).exp() * inv_sum;
}
}
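/// RMS normalization: `output[i] = x[i] / sqrt(mean(x^2) + eps)`.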
pub fn neon_rms_norm(x: &[f32], output: &mut [f32], eps: f32) -> CoreResult<()> {
if x.len() != output.len() {
return Err(CoreError::DimensionMismatch {
expected: x.len(),
got: output.len(),
});
}
let n = x.len();
if n == 0 {
return Ok(());
}
let sum_sq: f32 = {
#[cfg(target_arch = "aarch64")]
{
unsafe { neon_sum_squares(x) }
}
#[cfg(not(target_arch = "aarch64"))]
{
x.iter().map(|&v| v * v).sum()
}
};
let rms = (sum_sq / n as f32 + eps).sqrt();
let inv_rms = 1.0 / rms;
#[cfg(target_arch = "aarch64")]
unsafe {
neon_scale_impl(x, inv_rms, output);
}
#[cfg(not(target_arch = "aarch64"))]
{
for (o, &xi) in output.iter_mut().zip(x.iter()) {
*o = xi * inv_rms;
}
}
Ok(())
}
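/// # Safety
/// Requires NEON, which is always available on `aarch64`.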
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn neon_sum_squares(x: &[f32]) -> f32 {
use std::arch::aarch64::*;
let len = x.len();
let chunks = len / 4;
let mut accum = vdupq_n_f32(0.0);
for i in 0..chunks {
let idx = i * 4;
let vx = vld1q_f32(x.as_ptr().add(idx));
accum = vfmaq_f32(accum, vx, vx);
}
let mut result = vaddvq_f32(accum);
for &v in x.iter().skip(chunks * 4) {
result += v * v;
}
result
}
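/// # Safety
/// Requires NEON; `input` and `output` must have equal lengths.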
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn neon_scale_impl(input: &[f32], scale: f32, output: &mut [f32]) {
use std::arch::aarch64::*;
let len = input.len();
let chunks = len / 4;
let vscale = vdupq_n_f32(scale);
for i in 0..chunks {
let idx = i * 4;
let vx = vld1q_f32(input.as_ptr().add(idx));
let vr = vmulq_f32(vx, vscale);
vst1q_f32(output.as_mut_ptr().add(idx), vr);
}
for i in (chunks * 4)..len {
output[i] = input[i] * scale;
}
}
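/// Diagonal state-space update `h[i] = a_bar[i] * h[i] + b_bar[i] * x_val`,
/// as used in the recurrence of SSM-style layers.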
pub fn neon_ssm_update(a_bar: &[f32], h: &mut [f32], b_bar: &[f32], x_val: f32) -> CoreResult<()> {
if a_bar.len() != h.len() || b_bar.len() != h.len() {
return Err(CoreError::DimensionMismatch {
expected: h.len(),
got: if a_bar.len() != h.len() {
a_bar.len()
} else {
b_bar.len()
},
});
}
#[cfg(target_arch = "aarch64")]
unsafe {
neon_ssm_update_impl(a_bar, h, b_bar, x_val);
}
#[cfg(not(target_arch = "aarch64"))]
{
for i in 0..h.len() {
h[i] = a_bar[i] * h[i] + b_bar[i] * x_val;
}
}
Ok(())
}
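/// # Safety
/// Requires NEON; `a_bar`, `h`, and `b_bar` must have equal lengths.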
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn neon_ssm_update_impl(a_bar: &[f32], h: &mut [f32], b_bar: &[f32], x_val: f32) {
use std::arch::aarch64::*;
let len = h.len();
let chunks = len / 4;
let vx = vdupq_n_f32(x_val);
for i in 0..chunks {
let idx = i * 4;
let va = vld1q_f32(a_bar.as_ptr().add(idx));
let vh = vld1q_f32(h.as_ptr().add(idx));
let vb = vld1q_f32(b_bar.as_ptr().add(idx));
let vah = vmulq_f32(va, vh);
let vresult = vfmaq_f32(vah, vb, vx);
vst1q_f32(h.as_mut_ptr().add(idx), vresult);
}
for i in (chunks * 4)..len {
h[i] = a_bar[i] * h[i] + b_bar[i] * x_val;
}
}
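/// Approximate element-wise exponential.
///
/// Currently scalar-only: each element goes through the polynomial
/// approximation in `fast_exp_scalar`.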
pub fn neon_fast_exp(x: &[f32], y: &mut [f32]) -> CoreResult<()> {
if x.len() != y.len() {
return Err(CoreError::DimensionMismatch {
expected: x.len(),
got: y.len(),
});
}
for i in 0..x.len() {
y[i] = fast_exp_scalar(x[i]);
}
Ok(())
}
#[inline]
fn fast_exp_scalar(x: f32) -> f32 {
    // Range-reduce x = k*ln(2) + r so the Taylor polynomial only ever sees
    // |r| <= ln(2)/2; the raw 5th-order polynomial diverges badly (and can
    // even go negative) once |x| grows past ~2.
    let x_clamped = x.clamp(-88.0, 88.0);
    let k = (x_clamped * std::f32::consts::LOG2_E).round();
    let r = x_clamped - k * std::f32::consts::LN_2;
    // 5th-order Taylor polynomial for exp(r), accurate for |r| <= ln(2)/2.
    let r2 = r * r;
    let poly = 1.0 + r + 0.5 * r2 + 0.166_666_67 * r2 * r + 0.041_666_67 * r2 * r2
        + 0.008_333_33 * r2 * r2 * r;
    // Reassemble: exp(x) = 2^k * exp(r).
    poly * (2.0f32).powi(k as i32)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_neon_availability() {
let _available = is_neon_available();
#[cfg(target_arch = "aarch64")]
assert!(_available);
}
#[test]
fn test_neon_dot_product() {
let a = vec![1.0, 2.0, 3.0, 4.0, 5.0];
let b = vec![2.0, 3.0, 4.0, 5.0, 6.0];
let result = neon_dot_product(&a, &b);
let expected = 1.0 * 2.0 + 2.0 * 3.0 + 3.0 * 4.0 + 4.0 * 5.0 + 5.0 * 6.0;
assert!((result - expected).abs() < 1e-5);
}
#[test]
fn test_neon_vec_add() {
let a = vec![1.0, 2.0, 3.0, 4.0];
let b = vec![5.0, 6.0, 7.0, 8.0];
let mut c = vec![0.0; 4];
neon_vec_add(&a, &b, &mut c).expect("neon_vec_add failed");
assert_eq!(c, vec![6.0, 8.0, 10.0, 12.0]);
}
#[test]
fn test_neon_vec_mul() {
let a = vec![2.0, 3.0, 4.0, 5.0];
let b = vec![3.0, 4.0, 5.0, 6.0];
let mut c = vec![0.0; 4];
neon_vec_mul(&a, &b, &mut c).expect("neon_vec_mul failed");
assert_eq!(c, vec![6.0, 12.0, 20.0, 30.0]);
}
#[test]
fn test_neon_vec_fma() {
let a = vec![2.0, 3.0, 4.0, 5.0];
let b = vec![3.0, 4.0, 5.0, 6.0];
let mut c = vec![1.0, 1.0, 1.0, 1.0];
neon_vec_fma(&a, &b, &mut c).expect("neon_vec_fma failed");
        assert_eq!(c, vec![7.0, 13.0, 21.0, 31.0]);
    }
#[test]
fn test_neon_matvec() {
        let matrix = vec![
            1.0, 2.0, 3.0, // row 0
            4.0, 5.0, 6.0, // row 1
        ];
let x = vec![1.0, 2.0, 3.0];
let mut y = vec![0.0; 2];
neon_matvec(&matrix, &x, &mut y, 2, 3).expect("neon_matvec failed");
        assert_eq!(y[0], 1.0 * 1.0 + 2.0 * 2.0 + 3.0 * 3.0);
        assert_eq!(y[1], 4.0 * 1.0 + 5.0 * 2.0 + 6.0 * 3.0);
    }
#[test]
fn test_neon_relu() {
let x = vec![-2.0, -1.0, 0.0, 1.0, 2.0];
let mut y = vec![0.0; 5];
neon_relu(&x, &mut y).expect("neon_relu failed");
assert_eq!(y, vec![0.0, 0.0, 0.0, 1.0, 2.0]);
}
#[test]
fn test_neon_layer_norm() {
let x = vec![1.0, 2.0, 3.0, 4.0];
let mut y = vec![0.0; 4];
neon_layer_norm(&x, &mut y, 1e-5).expect("neon_layer_norm failed");
let mean: f32 = y.iter().sum::<f32>() / y.len() as f32;
let variance: f32 = y.iter().map(|&yi| yi.powi(2)).sum::<f32>() / y.len() as f32;
assert!(mean.abs() < 1e-5);
assert!((variance - 1.0).abs() < 1e-3);
}
#[test]
fn test_neon_fast_exp() {
let x = vec![-1.0, 0.0, 1.0, 2.0];
let mut y = vec![0.0; 4];
neon_fast_exp(&x, &mut y).expect("neon_fast_exp failed");
assert!(y.iter().all(|&val| val > 0.0));
        assert!((y[1] - 1.0).abs() < 0.01);
    }
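    // Sanity check for the range-reduced polynomial in `fast_exp_scalar`:
    // it should track the reference `exp` across moderate inputs, including
    // the region where a raw 5th-order Taylor expansion would go negative.
    #[test]
    fn test_fast_exp_matches_std_exp() {
        for i in -60..=60 {
            let x = i as f32 * 0.1; // covers [-6.0, 6.0]
            let expected = x.exp();
            let got = fast_exp_scalar(x);
            assert!(
                (got - expected).abs() <= expected * 1e-4 + 1e-6,
                "fast_exp({x}) = {got}, expected {expected}"
            );
        }
    }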
#[test]
fn test_dimension_mismatch() {
let a = vec![1.0, 2.0];
let b = vec![1.0, 2.0, 3.0];
let mut c = vec![0.0; 2];
assert!(neon_vec_add(&a, &b, &mut c).is_err());
}
#[test]
fn test_neon_softmax_sums_to_one() {
let mut x = vec![1.0f32, 2.0, 3.0, 4.0];
neon_softmax(&mut x);
let sum: f32 = x.iter().sum();
assert!((sum - 1.0).abs() < 1e-5, "softmax sum = {sum}");
}
#[test]
fn test_neon_softmax_monotone() {
let mut x = vec![1.0f32, 2.0, 3.0];
neon_softmax(&mut x);
assert!(x[0] < x[1] && x[1] < x[2]);
}
#[test]
fn test_neon_softmax_numerical_stability() {
let mut x = vec![100.0f32, 101.0, 102.0];
neon_softmax(&mut x);
for &v in &x {
assert!(v.is_finite(), "neon_softmax produced non-finite: {v}");
}
let sum: f32 = x.iter().sum();
assert!((sum - 1.0).abs() < 1e-5);
}
#[test]
fn test_neon_rms_norm_basic() {
let x = vec![3.0f32, 4.0];
let mut out = vec![0.0f32; 2];
neon_rms_norm(&x, &mut out, 0.0).expect("rms_norm failed");
let rms = (12.5f32).sqrt();
assert!((out[0] - 3.0 / rms).abs() < 1e-5);
assert!((out[1] - 4.0 / rms).abs() < 1e-5);
}
#[test]
fn test_neon_rms_norm_dimension_mismatch() {
let x = vec![1.0f32, 2.0];
let mut out = vec![0.0f32; 3];
assert!(neon_rms_norm(&x, &mut out, 1e-5).is_err());
}
#[test]
fn test_neon_ssm_update_basic() {
let a_bar = vec![0.5f32, 0.5, 0.5, 0.5];
let mut h = vec![2.0f32, 4.0, 6.0, 8.0];
let b_bar = vec![1.0f32, 1.0, 1.0, 1.0];
neon_ssm_update(&a_bar, &mut h, &b_bar, 1.0).expect("ssm_update failed");
        assert!((h[0] - 2.0).abs() < 1e-5);
        assert!((h[1] - 3.0).abs() < 1e-5);
        assert!((h[2] - 4.0).abs() < 1e-5);
        assert!((h[3] - 5.0).abs() < 1e-5);
    }
#[test]
fn test_neon_ssm_update_dimension_mismatch() {
let a_bar = vec![1.0f32, 2.0];
let mut h = vec![1.0f32, 2.0, 3.0];
let b_bar = vec![1.0f32, 2.0, 3.0];
assert!(neon_ssm_update(&a_bar, &mut h, &b_bar, 1.0).is_err());
}
}