#![cfg(target_arch = "aarch64")]
#![allow(dead_code)]
use crate::Quantization;
use std::arch::aarch64::*;
/// Additive term of the affine dequantization identity:
/// (x - zero_point) * scale == x * scale + (-zero_point * scale).
/// Precomputing this bias lets the SIMD kernels use a single FMA per lane.
#[inline(always)]
fn affine_bias(q: Quantization) -> f32 {
    let zero_point = q.zero_point as f32;
    // (-zp) * scale and zp * (-scale) are bit-identical under IEEE-754.
    zero_point * -q.scale
}
/// Dequantizes i8 values to f32: out[i] = (in[i] - zero_point) * scale.
///
/// Vector lanes use the fused form x * scale + (-zp * scale); the scalar
/// tail mirrors the reference scalar kernel's exact expression.
#[target_feature(enable = "neon")]
pub(crate) unsafe fn dequant_i8_to_f32_neon(input: &[i8], q: Quantization, output: &mut [f32]) {
    debug_assert_eq!(input.len(), output.len());
    let scale = q.scale;
    let bias_v = vdupq_n_f32(affine_bias(q));
    let len = input.len();
    let src = input.as_ptr();
    let dst = output.as_mut_ptr();
    // 16 elements per iteration: widen i8 -> i16 -> i32, convert to f32,
    // then apply the affine transform with one FMA per 4-lane group.
    let vec_len = len - len % 16;
    let mut idx = 0usize;
    while idx < vec_len {
        let bytes = vld1q_s8(src.add(idx));
        let lo16 = vmovl_s8(vget_low_s8(bytes));
        let hi16 = vmovl_high_s8(bytes);
        let w0 = vmovl_s16(vget_low_s16(lo16));
        let w1 = vmovl_high_s16(lo16);
        let w2 = vmovl_s16(vget_low_s16(hi16));
        let w3 = vmovl_high_s16(hi16);
        vst1q_f32(dst.add(idx), vfmaq_n_f32(bias_v, vcvtq_f32_s32(w0), scale));
        vst1q_f32(dst.add(idx + 4), vfmaq_n_f32(bias_v, vcvtq_f32_s32(w1), scale));
        vst1q_f32(dst.add(idx + 8), vfmaq_n_f32(bias_v, vcvtq_f32_s32(w2), scale));
        vst1q_f32(dst.add(idx + 12), vfmaq_n_f32(bias_v, vcvtq_f32_s32(w3), scale));
        idx += 16;
    }
    // Scalar tail for the remaining < 16 elements.
    let zp = q.zero_point as f32;
    while idx < len {
        *dst.add(idx) = (*src.add(idx) as f32 - zp) * scale;
        idx += 1;
    }
}
/// Dequantizes u8 values to f32: out[i] = (in[i] - zero_point) * scale.
///
/// Unsigned counterpart of `dequant_i8_to_f32_neon`; identical structure
/// with the unsigned widening/conversion intrinsics.
#[target_feature(enable = "neon")]
pub(crate) unsafe fn dequant_u8_to_f32_neon(input: &[u8], q: Quantization, output: &mut [f32]) {
    debug_assert_eq!(input.len(), output.len());
    let scale = q.scale;
    let bias_v = vdupq_n_f32(affine_bias(q));
    let len = input.len();
    let src = input.as_ptr();
    let dst = output.as_mut_ptr();
    // 16 elements per iteration: widen u8 -> u16 -> u32, convert to f32,
    // then apply the affine transform with one FMA per 4-lane group.
    let vec_len = len - len % 16;
    let mut idx = 0usize;
    while idx < vec_len {
        let bytes = vld1q_u8(src.add(idx));
        let lo16 = vmovl_u8(vget_low_u8(bytes));
        let hi16 = vmovl_high_u8(bytes);
        let w0 = vmovl_u16(vget_low_u16(lo16));
        let w1 = vmovl_high_u16(lo16);
        let w2 = vmovl_u16(vget_low_u16(hi16));
        let w3 = vmovl_high_u16(hi16);
        vst1q_f32(dst.add(idx), vfmaq_n_f32(bias_v, vcvtq_f32_u32(w0), scale));
        vst1q_f32(dst.add(idx + 4), vfmaq_n_f32(bias_v, vcvtq_f32_u32(w1), scale));
        vst1q_f32(dst.add(idx + 8), vfmaq_n_f32(bias_v, vcvtq_f32_u32(w2), scale));
        vst1q_f32(dst.add(idx + 12), vfmaq_n_f32(bias_v, vcvtq_f32_u32(w3), scale));
        idx += 16;
    }
    // Scalar tail for the remaining < 16 elements.
    let zp = q.zero_point as f32;
    while idx < len {
        *dst.add(idx) = (*src.add(idx) as f32 - zp) * scale;
        idx += 1;
    }
}
/// Dequantizes i16 values to f32: out[i] = (in[i] - zero_point) * scale.
#[target_feature(enable = "neon")]
pub(crate) unsafe fn dequant_i16_to_f32_neon(input: &[i16], q: Quantization, output: &mut [f32]) {
    debug_assert_eq!(input.len(), output.len());
    let scale = q.scale;
    let bias_v = vdupq_n_f32(affine_bias(q));
    let len = input.len();
    let src = input.as_ptr();
    let dst = output.as_mut_ptr();
    // 8 elements per iteration: widen i16 -> i32, convert to f32, then
    // apply the affine transform with one FMA per 4-lane group.
    let vec_len = len - len % 8;
    let mut idx = 0usize;
    while idx < vec_len {
        let halves = vld1q_s16(src.add(idx));
        let w0 = vmovl_s16(vget_low_s16(halves));
        let w1 = vmovl_high_s16(halves);
        vst1q_f32(dst.add(idx), vfmaq_n_f32(bias_v, vcvtq_f32_s32(w0), scale));
        vst1q_f32(dst.add(idx + 4), vfmaq_n_f32(bias_v, vcvtq_f32_s32(w1), scale));
        idx += 8;
    }
    // Scalar tail for the remaining < 8 elements.
    let zp = q.zero_point as f32;
    while idx < len {
        *dst.add(idx) = (*src.add(idx) as f32 - zp) * scale;
        idx += 1;
    }
}
/// Dequantizes u16 values to f32: out[i] = (in[i] - zero_point) * scale.
#[target_feature(enable = "neon")]
pub(crate) unsafe fn dequant_u16_to_f32_neon(input: &[u16], q: Quantization, output: &mut [f32]) {
    debug_assert_eq!(input.len(), output.len());
    let scale = q.scale;
    let bias_v = vdupq_n_f32(affine_bias(q));
    let len = input.len();
    let src = input.as_ptr();
    let dst = output.as_mut_ptr();
    // 8 elements per iteration: widen u16 -> u32, convert to f32, then
    // apply the affine transform with one FMA per 4-lane group.
    let vec_len = len - len % 8;
    let mut idx = 0usize;
    while idx < vec_len {
        let halves = vld1q_u16(src.add(idx));
        let w0 = vmovl_u16(vget_low_u16(halves));
        let w1 = vmovl_high_u16(halves);
        vst1q_f32(dst.add(idx), vfmaq_n_f32(bias_v, vcvtq_f32_u32(w0), scale));
        vst1q_f32(dst.add(idx + 4), vfmaq_n_f32(bias_v, vcvtq_f32_u32(w1), scale));
        idx += 8;
    }
    // Scalar tail for the remaining < 8 elements.
    let zp = q.zero_point as f32;
    while idx < len {
        *dst.add(idx) = (*src.add(idx) as f32 - zp) * scale;
        idx += 1;
    }
}
// Vectorized e^x over four f32 lanes, Cephes-style range reduction:
//   1. clamp x to ±88.376_26 so the result stays finite in f32,
//   2. split x = k*ln(2) + r with k = round-to-nearest(x * log2(e)); ln(2)
//      is subtracted as a hi/lo pair so r retains full precision,
//   3. approximate e^r with a degree-6 minimax polynomial,
//   4. scale by 2^k, built by writing (k + 127) into the IEEE-754 exponent
//      field.
#[allow(clippy::excessive_precision)]
#[inline]
#[target_feature(enable = "neon")]
pub(crate) unsafe fn expf_neon_f32x4(x: float32x4_t) -> float32x4_t {
let exp_hi = vdupq_n_f32(88.376_26);
let exp_lo = vdupq_n_f32(-88.376_26);
let log2_e = vdupq_n_f32(core::f32::consts::LOG2_E);
// ln2_hi + ln2_lo together represent ln(2); the small negative tail
// restores the bits lost when subtracting k*ln2_hi from x.
let ln2_hi = vdupq_n_f32(0.693_359_375);
let ln2_lo = vdupq_n_f32(-2.121_944_400e-4);
let one = vdupq_n_f32(1.0);
// Cephes expf polynomial coefficients, highest degree first.
let p0 = vdupq_n_f32(1.987_569_150e-4);
let p1 = vdupq_n_f32(1.398_199_950e-3);
let p2 = vdupq_n_f32(8.333_451_900e-3);
let p3 = vdupq_n_f32(4.166_579_590e-2);
let p4 = vdupq_n_f32(1.666_666_550e-1);
let p5 = vdupq_n_f32(5.000_000_120e-1);
let x = vminq_f32(vmaxq_f32(x, exp_lo), exp_hi);
let fx_f = vmulq_f32(x, log2_e);
// k = round-to-nearest-even(x / ln 2); with the clamp above, k >= -127.
let fx_i = vcvtnq_s32_f32(fx_f);
let fx = vcvtq_f32_s32(fx_i);
// r = x - k*ln2_hi - k*ln2_lo (vfmsq_f32(a, b, c) computes a - b*c).
let z = vfmsq_f32(x, fx, ln2_hi);
let r = vfmsq_f32(z, fx, ln2_lo);
// Horner evaluation: y = ((((p0*r + p1)*r + p2)*r + p3)*r + p4)*r + p5.
let mut y = vfmaq_f32(p1, p0, r);
y = vfmaq_f32(p2, y, r);
y = vfmaq_f32(p3, y, r);
y = vfmaq_f32(p4, y, r);
y = vfmaq_f32(p5, y, r);
let r2 = vmulq_f32(r, r);
let y_r2 = vmulq_f32(y, r2);
// e^r ≈ 1 + r + r^2 * y.
let exp_r = vaddq_f32(vaddq_f32(y_r2, r), one);
// Construct 2^k directly: biased exponent (k + 127) shifted into the
// f32 exponent field (bits 30..23).
let bias = vdupq_n_s32(127);
let pow2k_bits = vshlq_n_s32::<23>(vaddq_s32(fx_i, bias));
let pow2k = vreinterpretq_f32_s32(pow2k_bits);
vmulq_f32(exp_r, pow2k)
}
/// In-place numerically stable sigmoid over `buf`:
///   x >= 0: 1 / (1 + e^-x)
///   x <  0: e^x / (1 + e^x)
/// Both branches feed exp a non-positive argument (-|x|), so the
/// exponential can never overflow.
#[target_feature(enable = "neon")]
pub(crate) unsafe fn sigmoid_slice_f32_neon(buf: &mut [f32]) {
    let len = buf.len();
    let p = buf.as_mut_ptr();
    let zero = vdupq_n_f32(0.0);
    let one = vdupq_n_f32(1.0);
    let vec_len = len - len % 4;
    let mut idx = 0usize;
    while idx < vec_len {
        let x = vld1q_f32(p.add(idx));
        let non_neg = vcgeq_f32(x, zero);
        // Select -|x| as the exp argument for every lane.
        let arg = vbslq_f32(non_neg, vnegq_f32(x), x);
        let e = expf_neon_f32x4(arg);
        let inv = vdivq_f32(one, vaddq_f32(one, e));
        // Non-negative lanes take 1/(1+e); negative lanes take e/(1+e).
        let res = vbslq_f32(non_neg, inv, vmulq_f32(e, inv));
        vst1q_f32(p.add(idx), res);
        idx += 4;
    }
    // Scalar tail with the same two-branch stable formula.
    while idx < len {
        let x = *p.add(idx);
        *p.add(idx) = if x >= 0.0 {
            let e = (-x).exp();
            1.0 / (1.0 + e)
        } else {
            let e = x.exp();
            e / (1.0 + e)
        };
        idx += 1;
    }
}
// In-place numerically stable softmax over `buf`, in three passes:
//   1. find the maximum m (vector reduction plus scalar tail),
//   2. overwrite each element with e^(x - m) while accumulating the sum,
//   3. scale everything by 1/sum.
// Subtracting the max keeps every exponent argument <= 0, so the
// exponentials cannot overflow even for very large inputs.
#[target_feature(enable = "neon")]
pub(crate) unsafe fn softmax_inplace_f32_neon(buf: &mut [f32]) {
if buf.is_empty() {
return;
}
let n = buf.len();
let chunks_4 = n / 4;
let ptr = buf.as_mut_ptr();
let mut m = f32::NEG_INFINITY;
// Pass 1a: vector max over the 4-lane chunks.
if chunks_4 > 0 {
let mut max_v = vld1q_f32(ptr);
let mut i = 4;
for _ in 1..chunks_4 {
let v = vld1q_f32(ptr.add(i));
max_v = vmaxq_f32(max_v, v);
i += 4;
}
// Horizontal reduction of the four lanes.
m = vmaxvq_f32(max_v);
}
// Pass 1b: scalar max over the < 4 tail elements.
{
let mut i = chunks_4 * 4;
while i < n {
let v = *ptr.add(i);
if v > m {
m = v;
}
i += 1;
}
}
let m_v = vdupq_n_f32(m);
let mut sum_v = vdupq_n_f32(0.0);
// Pass 2a: vector exponentiation, written back in place.
{
let mut i = 0;
for _ in 0..chunks_4 {
let v = vld1q_f32(ptr.add(i));
let s = vsubq_f32(v, m_v);
let e = expf_neon_f32x4(s);
sum_v = vaddq_f32(sum_v, e);
vst1q_f32(ptr.add(i), e);
i += 4;
}
}
let mut sum = vaddvq_f32(sum_v);
// Pass 2b: scalar tail, folded into the same running sum.
{
let mut i = chunks_4 * 4;
while i < n {
let e = (*ptr.add(i) - m).exp();
*ptr.add(i) = e;
sum += e;
i += 1;
}
}
// Pass 3: normalize. For finite input the max element contributes
// e^0 = 1, so sum >= 1; the guard only trips on NaN-poisoned input,
// in which case the buffer is left as unnormalized exponentials.
if sum > 0.0 {
let inv = 1.0 / sum;
let inv_v = vdupq_n_f32(inv);
let mut i = 0;
for _ in 0..chunks_4 {
let v = vld1q_f32(ptr.add(i));
let r = vmulq_f32(v, inv_v);
vst1q_f32(ptr.add(i), r);
i += 4;
}
let mut i = chunks_4 * 4;
while i < n {
*ptr.add(i) *= inv;
i += 1;
}
}
}
// Eight-lane f16 fused multiply-add: result = acc + a * b ("fmla .8h").
// f16 vectors travel as raw bit patterns in uint16x8_t because stable Rust
// exposes no f16 NEON intrinsics; inline asm reaches the instruction
// directly. NOTE(review): this executes an fp16 instruction, which the
// `neon` target_feature alone does not guarantee — callers are expected to
// gate on runtime fp16 detection (as the tests do).
#[inline]
#[target_feature(enable = "neon")]
unsafe fn fmla_f16x8(acc: uint16x8_t, a: uint16x8_t, b: uint16x8_t) -> uint16x8_t {
let result: uint16x8_t;
core::arch::asm!(
".arch_extension fp16",
"fmla {acc:v}.8h, {a:v}.8h, {b:v}.8h",
acc = inout(vreg) acc => result,
a = in(vreg) a,
b = in(vreg) b,
// pure + nomem: register-only computation with no side effects.
options(pure, nomem, nostack),
);
result
}
// Eight-lane f16 fused multiply-subtract: result = acc - a * b ("fmls .8h").
// Same inline-asm rationale and fp16-extension caveat as fmla_f16x8.
#[inline]
#[target_feature(enable = "neon")]
unsafe fn fmls_f16x8(acc: uint16x8_t, a: uint16x8_t, b: uint16x8_t) -> uint16x8_t {
let result: uint16x8_t;
core::arch::asm!(
".arch_extension fp16",
"fmls {acc:v}.8h, {a:v}.8h, {b:v}.8h",
acc = inout(vreg) acc => result,
a = in(vreg) a,
b = in(vreg) b,
options(pure, nomem, nostack),
);
result
}
// Eight-lane f16 multiply: result = a * b ("fmul .8h").
// Same inline-asm rationale and fp16-extension caveat as fmla_f16x8.
#[inline]
#[target_feature(enable = "neon")]
unsafe fn fmul_f16x8(a: uint16x8_t, b: uint16x8_t) -> uint16x8_t {
let result: uint16x8_t;
core::arch::asm!(
".arch_extension fp16",
"fmul {r:v}.8h, {a:v}.8h, {b:v}.8h",
r = out(vreg) result,
a = in(vreg) a,
b = in(vreg) b,
options(pure, nomem, nostack),
);
result
}
// Eight-lane f16 add: result = a + b ("fadd .8h").
// Same inline-asm rationale and fp16-extension caveat as fmla_f16x8.
#[inline]
#[target_feature(enable = "neon")]
unsafe fn fadd_f16x8(a: uint16x8_t, b: uint16x8_t) -> uint16x8_t {
let result: uint16x8_t;
core::arch::asm!(
".arch_extension fp16",
"fadd {r:v}.8h, {a:v}.8h, {b:v}.8h",
r = out(vreg) result,
a = in(vreg) a,
b = in(vreg) b,
options(pure, nomem, nostack),
);
result
}
// Converts eight f16 lanes to signed 16-bit integers, rounding to nearest
// with ties to even ("fcvtns .8h"). Used to extract the integer exponent k
// during range reduction in expf_neon_f32x8_via_f16.
#[inline]
#[target_feature(enable = "neon")]
unsafe fn fcvtns_s16_from_f16x8(x: uint16x8_t) -> int16x8_t {
let result: int16x8_t;
core::arch::asm!(
".arch_extension fp16",
"fcvtns {r:v}.8h, {x:v}.8h",
r = out(vreg) result,
x = in(vreg) x,
options(pure, nomem, nostack),
);
result
}
// Converts eight signed 16-bit integers to f16 lanes ("scvtf .8h") — the
// inverse of fcvtns_s16_from_f16x8, used to turn the rounded exponent k
// back into a float for the k*ln2 subtraction.
#[inline]
#[target_feature(enable = "neon")]
unsafe fn scvtf_f16_from_s16x8(x: int16x8_t) -> uint16x8_t {
let result: uint16x8_t;
core::arch::asm!(
".arch_extension fp16",
"scvtf {r:v}.8h, {x:v}.8h",
r = out(vreg) result,
x = in(vreg) x,
options(pure, nomem, nostack),
);
result
}
// Narrows two f32x4 vectors into one f16x8 vector: "fcvtn" fills the low
// four lanes from `lo`, "fcvtn2" fills the high four lanes from `hi`.
// The f16 results are returned as raw bit patterns in uint16x8_t.
#[inline]
#[target_feature(enable = "neon")]
unsafe fn pack_f16x8_from_f32x4_pair(lo: float32x4_t, hi: float32x4_t) -> uint16x8_t {
let result: uint16x8_t;
core::arch::asm!(
".arch_extension fp16",
"fcvtn {r:v}.4h, {lo:v}.4s",
"fcvtn2 {r:v}.8h, {hi:v}.4s",
r = out(vreg) result,
lo = in(vreg) lo,
hi = in(vreg) hi,
options(pure, nomem, nostack),
);
result
}
// Widens one f16x8 vector back into two f32x4 vectors: "fcvtl" expands the
// low four f16 lanes, "fcvtl2" the high four. Inverse of
// pack_f16x8_from_f32x4_pair (up to f16 rounding already applied).
#[inline]
#[target_feature(enable = "neon")]
unsafe fn unpack_f16x8_to_f32x4_pair(p: uint16x8_t) -> (float32x4_t, float32x4_t) {
let lo: float32x4_t;
let hi: float32x4_t;
core::arch::asm!(
".arch_extension fp16",
"fcvtl {lo:v}.4s, {p:v}.4h",
"fcvtl2 {hi:v}.4s, {p:v}.8h",
lo = out(vreg) lo,
hi = out(vreg) hi,
p = in(vreg) p,
options(pure, nomem, nostack),
);
(lo, hi)
}
// Eight-lane e^x evaluated through the half-precision pipeline: the two
// f32x4 inputs are clamped to ±10, packed into a single f16x8 vector, run
// through the same range-reduction + polynomial scheme as expf_neon_f32x4
// (but in f16 arithmetic), then widened back to two f32x4 results.
// The ±10 clamp keeps e^x within the f16 finite range (max 65504); accuracy
// is f16-limited (roughly 3 decimal digits — see the 2% tolerance used by
// the parity tests). Requires the fp16 CPU extension at runtime.
#[inline]
#[target_feature(enable = "neon")]
pub(crate) unsafe fn expf_neon_f32x8_via_f16(
lo: float32x4_t,
hi: float32x4_t,
) -> (float32x4_t, float32x4_t) {
// f16 bit patterns: LOG2_E ≈ 1.4424, LN2_HI ≈ 0.693359,
// LN2_LO ≈ -2.122e-4 (hi/lo split of ln 2), ONE = 1.0.
const LOG2_E: u16 = 0x3DC5; const LN2_HI: u16 = 0x398C; const LN2_LO: u16 = 0x8AF4; const ONE: u16 = 0x3C00;
// P0..P5: the expf polynomial coefficients rounded to f16 (P5 = 0.5).
const P0: u16 = 0x0A83; const P1: u16 = 0x15BA; const P2: u16 = 0x2044; const P3: u16 = 0x2955; const P4: u16 = 0x3155; const P5: u16 = 0x3800;
let max_v = vdupq_n_f32(10.0);
let min_v = vdupq_n_f32(-10.0);
let lo_c = vminq_f32(vmaxq_f32(lo, min_v), max_v);
let hi_c = vminq_f32(vmaxq_f32(hi, min_v), max_v);
let x = pack_f16x8_from_f32x4_pair(lo_c, hi_c);
let log2_e = vdupq_n_u16(LOG2_E);
let ln2_hi = vdupq_n_u16(LN2_HI);
let ln2_lo = vdupq_n_u16(LN2_LO);
let one = vdupq_n_u16(ONE);
let p0 = vdupq_n_u16(P0);
let p1 = vdupq_n_u16(P1);
let p2 = vdupq_n_u16(P2);
let p3 = vdupq_n_u16(P3);
let p4 = vdupq_n_u16(P4);
let p5 = vdupq_n_u16(P5);
// Range reduction: k = round(x * log2(e)), r = x - k*ln2_hi - k*ln2_lo.
let fx_f = fmul_f16x8(x, log2_e);
let k_int = fcvtns_s16_from_f16x8(fx_f);
let k_f = scvtf_f16_from_s16x8(k_int);
let z = fmls_f16x8(x, k_f, ln2_hi);
let r = fmls_f16x8(z, k_f, ln2_lo);
// Horner evaluation of the degree-6 polynomial, then e^r ≈ 1 + r + r^2*y.
let mut y = fmla_f16x8(p1, p0, r);
y = fmla_f16x8(p2, y, r);
y = fmla_f16x8(p3, y, r);
y = fmla_f16x8(p4, y, r);
y = fmla_f16x8(p5, y, r);
let r2 = fmul_f16x8(r, r);
let exp_r = fadd_f16x8(fadd_f16x8(fmul_f16x8(y, r2), r), one);
// Construct 2^k in f16: biased exponent (k + 15) shifted into the f16
// exponent field (bits 14..10). |k| <= 15 thanks to the ±10 input clamp.
let bias = vdupq_n_s16(15);
let pow2k_bits = vshlq_n_s16::<10>(vaddq_s16(k_int, bias));
let pow2k = vreinterpretq_u16_s16(pow2k_bits);
let result = fmul_f16x8(exp_r, pow2k);
unpack_f16x8_to_f32x4_pair(result)
}
/// In-place sigmoid using the FP16 exp pipeline: eight lanes per iteration
/// through expf_neon_f32x8_via_f16, with the same numerically stable
/// two-branch formula as sigmoid_slice_f32_neon. The caller is expected to
/// have verified runtime fp16 support before dispatching here.
#[target_feature(enable = "neon")]
pub(crate) unsafe fn sigmoid_slice_f32_neon_fp16(buf: &mut [f32]) {
    let len = buf.len();
    let p = buf.as_mut_ptr();
    let zero = vdupq_n_f32(0.0);
    let one = vdupq_n_f32(1.0);
    let mut idx = 0usize;
    while idx + 8 <= len {
        let a = vld1q_f32(p.add(idx));
        let b = vld1q_f32(p.add(idx + 4));
        let a_nonneg = vcgeq_f32(a, zero);
        let b_nonneg = vcgeq_f32(b, zero);
        // exp is always fed -|x|, keeping the f16 pipeline well-behaved.
        let a_arg = vbslq_f32(a_nonneg, vnegq_f32(a), a);
        let b_arg = vbslq_f32(b_nonneg, vnegq_f32(b), b);
        let (ea, eb) = expf_neon_f32x8_via_f16(a_arg, b_arg);
        let inv_a = vdivq_f32(one, vaddq_f32(one, ea));
        let inv_b = vdivq_f32(one, vaddq_f32(one, eb));
        // Non-negative lanes take 1/(1+e); negative lanes take e/(1+e).
        let out_a = vbslq_f32(a_nonneg, inv_a, vmulq_f32(ea, inv_a));
        let out_b = vbslq_f32(b_nonneg, inv_b, vmulq_f32(eb, inv_b));
        vst1q_f32(p.add(idx), out_a);
        vst1q_f32(p.add(idx + 4), out_b);
        idx += 8;
    }
    // Scalar tail in full f32 precision.
    while idx < len {
        let x = *p.add(idx);
        *p.add(idx) = if x >= 0.0 {
            1.0 / (1.0 + (-x).exp())
        } else {
            let e = x.exp();
            e / (1.0 + e)
        };
        idx += 1;
    }
}
// In-place softmax with the FP16 exp pipeline. Structure mirrors
// softmax_inplace_f32_neon, except the exp and normalize passes run eight
// lanes at a time; the max pass still uses 4-lane f32 chunks (so its
// vector coverage, chunks_4 * 4, can exceed the exp pass's chunks_8 * 8 —
// both have correct scalar tails from their own boundaries).
// Accuracy is f16-limited; requires runtime fp16 support.
#[target_feature(enable = "neon")]
pub(crate) unsafe fn softmax_inplace_f32_neon_fp16(buf: &mut [f32]) {
if buf.is_empty() {
return;
}
let n = buf.len();
let chunks_8 = n / 8;
let ptr = buf.as_mut_ptr();
let chunks_4 = n / 4;
let mut m = f32::NEG_INFINITY;
// Pass 1: maximum, computed in f32 (4-lane chunks + scalar tail).
if chunks_4 > 0 {
let mut max_v = vld1q_f32(ptr);
let mut i = 4;
for _ in 1..chunks_4 {
let v = vld1q_f32(ptr.add(i));
max_v = vmaxq_f32(max_v, v);
i += 4;
}
m = vmaxvq_f32(max_v);
}
{
let mut i = chunks_4 * 4;
while i < n {
let v = *ptr.add(i);
if v > m {
m = v;
}
i += 1;
}
}
let m_v = vdupq_n_f32(m);
let mut sum_v = vdupq_n_f32(0.0);
// Pass 2: e^(x - m) via the f16 exp kernel, 8 lanes per iteration;
// the sum is accumulated in f32.
for c in 0..chunks_8 {
let i = c * 8;
let lo = vld1q_f32(ptr.add(i));
let hi = vld1q_f32(ptr.add(i + 4));
let s_lo = vsubq_f32(lo, m_v);
let s_hi = vsubq_f32(hi, m_v);
let (e_lo, e_hi) = expf_neon_f32x8_via_f16(s_lo, s_hi);
sum_v = vaddq_f32(sum_v, e_lo);
sum_v = vaddq_f32(sum_v, e_hi);
vst1q_f32(ptr.add(i), e_lo);
vst1q_f32(ptr.add(i + 4), e_hi);
}
let mut sum = vaddvq_f32(sum_v);
// Scalar tail uses full-precision exp for the last < 8 elements.
let tail_start = chunks_8 * 8;
for i in tail_start..n {
let e = (*ptr.add(i) - m).exp();
*ptr.add(i) = e;
sum += e;
}
// Pass 3: normalize. For finite input sum >= 1 (the max maps to e^0);
// the guard only trips on NaN-poisoned input.
if sum > 0.0 {
let inv = 1.0 / sum;
let inv_v = vdupq_n_f32(inv);
for c in 0..chunks_8 {
let i = c * 8;
let lo = vld1q_f32(ptr.add(i));
let hi = vld1q_f32(ptr.add(i + 4));
vst1q_f32(ptr.add(i), vmulq_f32(lo, inv_v));
vst1q_f32(ptr.add(i + 4), vmulq_f32(hi, inv_v));
}
for i in tail_start..n {
*ptr.add(i) *= inv;
}
}
}
#[cfg(test)]
mod tests {
// Parity suite: every NEON kernel is checked against its scalar reference
// implementation from crate::per_scale::kernels, including tail-only,
// exact-chunk-boundary, and overflow cases. FP16 tests self-skip on CPUs
// without the fp16 extension.
use super::*;
use crate::per_scale::kernels::dequant::{
dequant_i16_to_f32, dequant_i8_to_f32, dequant_u16_to_f32, dequant_u8_to_f32,
};
use crate::per_scale::kernels::sigmoid::sigmoid_slice_f32;
use crate::per_scale::kernels::softmax::softmax_inplace_f32;
// Approximate equality: an absolute floor for near-zero values, otherwise
// a relative bound of `ulps` machine epsilons at the larger magnitude.
fn close_within_ulp(a: f32, b: f32, ulps: u32) -> bool {
let diff = (a - b).abs();
if diff < 1e-7 {
return true;
}
let scale_max = a.abs().max(b.abs());
diff < f32::EPSILON * scale_max * (ulps as f32)
}
#[test]
fn dequant_i8_to_f32_neon_matches_scalar() {
let input: Vec<i8> = (-100..100).map(|i| i as i8).collect();
let q = Quantization {
scale: 0.0326_f32,
zero_point: -42,
};
let mut neon_out = vec![0f32; input.len()];
let mut scalar_out = vec![0f32; input.len()];
unsafe {
dequant_i8_to_f32_neon(&input, q, &mut neon_out);
}
dequant_i8_to_f32(&input, q, &mut scalar_out);
for (i, (&n_v, &s_v)) in neon_out.iter().zip(scalar_out.iter()).enumerate() {
assert!(
close_within_ulp(n_v, s_v, 2),
"i8 dequant NEON/scalar mismatch at {i}: neon={n_v} scalar={s_v}"
);
}
}
#[test]
fn dequant_u8_to_f32_neon_matches_scalar() {
let input: Vec<u8> = (0..240).map(|i| i as u8).collect();
let q = Quantization {
scale: 0.00392_f32,
zero_point: 128,
};
let mut neon_out = vec![0f32; input.len()];
let mut scalar_out = vec![0f32; input.len()];
unsafe {
dequant_u8_to_f32_neon(&input, q, &mut neon_out);
}
dequant_u8_to_f32(&input, q, &mut scalar_out);
for (i, (&n_v, &s_v)) in neon_out.iter().zip(scalar_out.iter()).enumerate() {
assert!(
close_within_ulp(n_v, s_v, 2),
"u8 dequant NEON/scalar mismatch at {i}: neon={n_v} scalar={s_v}"
);
}
}
#[test]
fn dequant_i16_to_f32_neon_matches_scalar() {
let input: Vec<i16> = (-1000..1000).step_by(7).collect();
let q = Quantization {
scale: 0.0001_f32,
zero_point: 0,
};
let mut neon_out = vec![0f32; input.len()];
let mut scalar_out = vec![0f32; input.len()];
unsafe {
dequant_i16_to_f32_neon(&input, q, &mut neon_out);
}
dequant_i16_to_f32(&input, q, &mut scalar_out);
for (i, (&n_v, &s_v)) in neon_out.iter().zip(scalar_out.iter()).enumerate() {
assert!(
close_within_ulp(n_v, s_v, 2),
"i16 dequant NEON/scalar mismatch at {i}: neon={n_v} scalar={s_v}"
);
}
}
#[test]
fn dequant_u16_to_f32_neon_matches_scalar() {
let input: Vec<u16> = (0..2000).step_by(11).collect();
let q = Quantization {
scale: 0.0001_f32,
zero_point: 1024,
};
let mut neon_out = vec![0f32; input.len()];
let mut scalar_out = vec![0f32; input.len()];
unsafe {
dequant_u16_to_f32_neon(&input, q, &mut neon_out);
}
dequant_u16_to_f32(&input, q, &mut scalar_out);
for (i, (&n_v, &s_v)) in neon_out.iter().zip(scalar_out.iter()).enumerate() {
assert!(
close_within_ulp(n_v, s_v, 2),
"u16 dequant NEON/scalar mismatch at {i}: neon={n_v} scalar={s_v}"
);
}
}
// 12 elements < one 16-lane chunk: exercises the scalar tail exclusively.
#[test]
fn dequant_i8_to_f32_neon_handles_short_input_under_chunk_size() {
let input: [i8; 12] = [-128, -64, -32, -1, 0, 1, 32, 64, 127, 50, -50, 25];
let q = Quantization {
scale: 0.5_f32,
zero_point: 0,
};
let mut neon_out = [0f32; 12];
let mut scalar_out = [0f32; 12];
unsafe {
dequant_i8_to_f32_neon(&input, q, &mut neon_out);
}
dequant_i8_to_f32(&input, q, &mut scalar_out);
assert_eq!(neon_out, scalar_out, "tail-only path must be exact");
}
// Covers x in [-50, 50] — well inside the ±88.376 clamp of the kernel.
#[test]
fn expf_neon_f32x4_matches_libm() {
let cases: Vec<f32> = (-100..=100).map(|i| i as f32 * 0.5).collect();
for chunk in cases.chunks(4) {
if chunk.len() < 4 {
continue;
}
let arr = [chunk[0], chunk[1], chunk[2], chunk[3]];
let neon = unsafe {
let v = vld1q_f32(arr.as_ptr());
let r = expf_neon_f32x4(v);
let mut out = [0f32; 4];
vst1q_f32(out.as_mut_ptr(), r);
out
};
for (i, &x) in arr.iter().enumerate() {
let oracle = x.exp();
assert!(
close_within_ulp(neon[i], oracle, 8),
"expf NEON/libm mismatch at x={x}: neon={} libm={oracle}",
neon[i]
);
}
}
}
#[test]
fn sigmoid_slice_f32_neon_matches_scalar() {
let cases: Vec<f32> = (-50..=50).map(|i| i as f32 * 0.5).collect();
let mut neon_buf = cases.clone();
let mut scalar_buf = cases;
unsafe {
sigmoid_slice_f32_neon(&mut neon_buf);
}
sigmoid_slice_f32(&mut scalar_buf);
for (i, (&n_v, &s_v)) in neon_buf.iter().zip(scalar_buf.iter()).enumerate() {
let abs_diff = (n_v - s_v).abs();
assert!(
abs_diff < 1e-6 || close_within_ulp(n_v, s_v, 16),
"sigmoid NEON/scalar mismatch at {i}: neon={n_v} scalar={s_v} diff={abs_diff}"
);
}
}
// 3 elements < one 4-lane chunk: scalar tail only, must match exactly.
#[test]
fn sigmoid_slice_f32_neon_short_input_under_chunk_size() {
let mut neon_buf = [-2.0_f32, -1.0, 1.5];
let mut scalar_buf = neon_buf;
unsafe {
sigmoid_slice_f32_neon(&mut neon_buf);
}
sigmoid_slice_f32(&mut scalar_buf);
for (i, (&n_v, &s_v)) in neon_buf.iter().zip(scalar_buf.iter()).enumerate() {
assert!(
close_within_ulp(n_v, s_v, 1),
"sigmoid tail-only mismatch at {i}: neon={n_v} scalar={s_v}"
);
}
}
#[test]
fn sigmoid_slice_f32_neon_zero_is_half() {
let mut buf = [0.0_f32; 4];
unsafe {
sigmoid_slice_f32_neon(&mut buf);
}
for &v in &buf {
assert!((v - 0.5).abs() < 1e-7, "sigmoid(0) = {v}, expected 0.5");
}
}
// Looser tolerance for softmax: normalization compounds per-element error.
fn softmax_close(a: f32, b: f32) -> bool {
let abs = (a - b).abs();
abs < 1e-5 || close_within_ulp(a, b, 32)
}
#[test]
fn softmax_neon_matches_scalar_reg_max_16() {
let cases: Vec<f32> = (0..16).map(|i| (i as f32) * 0.5 - 4.0).collect();
let mut neon_buf = cases.clone();
let mut scalar_buf = cases;
unsafe {
softmax_inplace_f32_neon(&mut neon_buf);
}
softmax_inplace_f32(&mut scalar_buf);
for (i, (&n_v, &s_v)) in neon_buf.iter().zip(scalar_buf.iter()).enumerate() {
assert!(
softmax_close(n_v, s_v),
"softmax NEON/scalar mismatch at {i}: neon={n_v} scalar={s_v}"
);
}
let sum: f32 = neon_buf.iter().sum();
assert!((sum - 1.0).abs() < 1e-5, "softmax sum != 1.0: {sum}");
}
// 18 elements: 4 full vector chunks plus a 2-element scalar tail.
#[test]
fn softmax_neon_matches_scalar_with_tail() {
let cases: Vec<f32> = (0..18).map(|i| i as f32 * 0.3).collect();
let mut neon_buf = cases.clone();
let mut scalar_buf = cases;
unsafe {
softmax_inplace_f32_neon(&mut neon_buf);
}
softmax_inplace_f32(&mut scalar_buf);
for (i, (&n_v, &s_v)) in neon_buf.iter().zip(scalar_buf.iter()).enumerate() {
assert!(
softmax_close(n_v, s_v),
"softmax tail mismatch at {i}: neon={n_v} scalar={s_v}"
);
}
}
// Inputs around 1000 would overflow a naive e^x; the max-subtraction in
// both implementations must keep the result finite and normalized.
#[test]
fn softmax_neon_overflow_safety() {
let mut neon_buf = [1000.0_f32, 1001.0, 1002.0, 1003.0];
let mut scalar_buf = neon_buf;
unsafe {
softmax_inplace_f32_neon(&mut neon_buf);
}
softmax_inplace_f32(&mut scalar_buf);
for (i, (&n_v, &s_v)) in neon_buf.iter().zip(scalar_buf.iter()).enumerate() {
assert!(
softmax_close(n_v, s_v),
"softmax overflow test mismatch at {i}: neon={n_v} scalar={s_v}"
);
}
let sum: f32 = neon_buf.iter().sum();
assert!((sum - 1.0).abs() < 1e-5);
}
#[test]
fn softmax_neon_empty_is_noop() {
let mut buf: [f32; 0] = [];
unsafe {
softmax_inplace_f32_neon(&mut buf);
}
}
// 64 elements = exactly 4 full 16-lane chunks: no scalar tail at all.
#[test]
fn dequant_i8_to_f32_neon_handles_exact_chunk_boundary() {
let input: Vec<i8> = (-32..32).map(|i| i as i8).collect();
let q = Quantization {
scale: 0.1_f32,
zero_point: -10,
};
let mut neon_out = vec![0f32; 64];
let mut scalar_out = vec![0f32; 64];
unsafe {
dequant_i8_to_f32_neon(&input, q, &mut neon_out);
}
dequant_i8_to_f32(&input, q, &mut scalar_out);
for (i, (&n_v, &s_v)) in neon_out.iter().zip(scalar_out.iter()).enumerate() {
assert!(
close_within_ulp(n_v, s_v, 2),
"exact-chunk boundary mismatch at {i}: neon={n_v} scalar={s_v}"
);
}
}
// The fp16 asm kernels need runtime detection: `neon` support alone does
// not imply the fp16 extension.
fn fp16_supported() -> bool {
std::arch::is_aarch64_feature_detected!("fp16")
}
#[test]
fn expf_neon_f32x8_via_f16_matches_libm_relative() {
if !fp16_supported() {
eprintln!("fp16 not supported on this CPU; skipping FP16 expf parity test");
return;
}
let cases: Vec<f32> = (-30..=30).map(|i| i as f32 * 0.5).collect();
for chunk in cases.chunks(8) {
if chunk.len() < 8 {
continue;
}
let lo = unsafe { vld1q_f32(chunk.as_ptr()) };
let hi = unsafe { vld1q_f32(chunk.as_ptr().add(4)) };
let (out_lo, out_hi) = unsafe { expf_neon_f32x8_via_f16(lo, hi) };
let mut out = [0f32; 8];
unsafe {
vst1q_f32(out.as_mut_ptr(), out_lo);
vst1q_f32(out.as_mut_ptr().add(4), out_hi);
}
for (i, &x) in chunk.iter().enumerate() {
// The oracle mirrors the kernel's ±10 input clamp.
let oracle = x.clamp(-10.0, 10.0).exp();
let abs_err = (out[i] - oracle).abs();
let rel_err = abs_err / oracle.max(1e-6);
assert!(
rel_err < 0.02 || abs_err < 1e-4,
"f16 expf mismatch at x={x}: got={} oracle={oracle} rel_err={rel_err:.4}",
out[i]
);
}
}
}
#[test]
fn sigmoid_slice_f32_neon_fp16_matches_scalar() {
if !fp16_supported() {
eprintln!("fp16 not supported; skipping");
return;
}
let cases: Vec<f32> = (-16..16).map(|i| i as f32 * 0.5).collect();
let mut neon_buf = cases.clone();
let mut scalar_buf = cases;
unsafe {
sigmoid_slice_f32_neon_fp16(&mut neon_buf);
}
sigmoid_slice_f32(&mut scalar_buf);
for (i, (&n_v, &s_v)) in neon_buf.iter().zip(scalar_buf.iter()).enumerate() {
let abs_err = (n_v - s_v).abs();
assert!(
abs_err < 5e-3,
"fp16 sigmoid mismatch at {i}: neon={n_v} scalar={s_v} err={abs_err:.5}"
);
}
}
#[test]
fn softmax_neon_fp16_matches_scalar() {
if !fp16_supported() {
eprintln!("fp16 not supported; skipping");
return;
}
let cases: Vec<f32> = (0..16).map(|i| (i as f32) * 0.5 - 4.0).collect();
let mut neon_buf = cases.clone();
let mut scalar_buf = cases;
unsafe {
softmax_inplace_f32_neon_fp16(&mut neon_buf);
}
softmax_inplace_f32(&mut scalar_buf);
for (i, (&n_v, &s_v)) in neon_buf.iter().zip(scalar_buf.iter()).enumerate() {
let abs_err = (n_v - s_v).abs();
assert!(
abs_err < 5e-3,
"fp16 softmax mismatch at {i}: neon={n_v} scalar={s_v} err={abs_err:.5}"
);
}
let sum: f32 = neon_buf.iter().sum();
assert!((sum - 1.0).abs() < 1e-3, "fp16 softmax sum drift: {sum}");
}
}