use core::arch::aarch64::*;
/// 16-byte type-punning union for reinterpreting between plain arrays and
/// NEON vector registers.
///
/// Every field occupies the same 16 bytes, so writing one field and reading
/// another is a straight bit-cast; this is used (instead of `transmute`) so
/// the casts below can be `const fn`s.
pub(crate) union UnionCast128 {
    pub f32x4: [f32; 4],
    pub u32x4: [u32; 4],
    pub float32x4: float32x4_t,
    pub uint32x4: uint32x4_t,
}
/// Reinterprets an `[f32; 4]` as a NEON `float32x4_t` (usable in `const` contexts).
#[inline]
pub const fn float32x4_from_f32x4(f32x4: [f32; 4]) -> float32x4_t {
    // Constructing the union is safe; only the field read needs `unsafe`.
    let cast = UnionCast128 { f32x4 };
    // SAFETY: all fields of `UnionCast128` are 16 bytes, so reading the
    // vector field after writing the array field is a plain bit-cast.
    unsafe { cast.float32x4 }
}
/// Reinterprets a NEON `float32x4_t` as an `[f32; 4]` (usable in `const` contexts).
#[inline]
pub const fn f32x4_from_float32x4(float32x4: float32x4_t) -> [f32; 4] {
    // Constructing the union is safe; only the field read needs `unsafe`.
    let cast = UnionCast128 { float32x4 };
    // SAFETY: all fields of `UnionCast128` are 16 bytes, so reading the
    // array field after writing the vector field is a plain bit-cast.
    unsafe { cast.f32x4 }
}
/// Reinterprets a `[u32; 4]` as a NEON `uint32x4_t` (usable in `const` contexts).
#[inline]
pub const fn uint32x4_from_u32x4(u32x4: [u32; 4]) -> uint32x4_t {
    // Constructing the union is safe; only the field read needs `unsafe`.
    let cast = UnionCast128 { u32x4 };
    // SAFETY: all fields of `UnionCast128` are 16 bytes, so reading the
    // vector field after writing the array field is a plain bit-cast.
    unsafe { cast.uint32x4 }
}
/// Returns `v` with lane 3 replaced by a copy of lane 2; lanes 0-2 are
/// unchanged.
#[inline]
pub(crate) unsafe fn set_fourth_with_third(v: float32x4_t) -> float32x4_t {
    // Extract lane 2 and insert it into lane 3 — equivalent to
    // `vcopyq_laneq_f32(v, 3, v, 2)`.
    vsetq_lane_f32(vgetq_lane_f32(v, 2), v, 3)
}
/// Dot product of the first three lanes of `lhs` and `rhs`; lane 3 of both
/// inputs is ignored.
#[inline]
pub(crate) unsafe fn dot3(lhs: float32x4_t, rhs: float32x4_t) -> f32 {
    // Zero the *product's* lane 3 (rather than an input's) so an inf/NaN in
    // the inputs' unused lane cannot poison the reduction, then add across.
    let prod = vmulq_f32(lhs, rhs);
    vaddvq_f32(vsetq_lane_f32(0.0, prod, 3))
}
/// Dot product of all four lanes of `lhs` and `rhs`.
#[inline]
pub(crate) unsafe fn dot4(lhs: float32x4_t, rhs: float32x4_t) -> f32 {
    // Lane-wise multiply followed by a horizontal add-across reduction.
    vaddvq_f32(vmulq_f32(lhs, rhs))
}
/// Three-lane dot product of `lhs` and `rhs`, splatted into every lane of
/// the returned vector; lane 3 of the inputs is ignored.
#[inline]
pub(crate) unsafe fn dot3_into_float32x4(lhs: float32x4_t, rhs: float32x4_t) -> float32x4_t {
    let prod = vmulq_f32(lhs, rhs);
    // lo_sum = [p0 + p1, p0 + p1]
    let lo = vget_low_f32(prod);
    let lo_sum = vpadd_f32(lo, lo);
    // third = [p2, p2] — lane 3 of the product is never touched.
    let third = vdup_lane_f32(vget_high_f32(prod), 0);
    // half = [(p0 + p1) + p2, (p0 + p1) + p2]; duplicate into both halves.
    let half = vadd_f32(lo_sum, third);
    vcombine_f32(half, half)
}
/// Four-lane dot product of `lhs` and `rhs`, splatted into every lane of
/// the returned vector.
#[inline]
pub(crate) unsafe fn dot4_into_float32x4(lhs: float32x4_t, rhs: float32x4_t) -> float32x4_t {
    let prod = vmulq_f32(lhs, rhs);
    // Fold the high half onto the low half: [p0 + p2, p1 + p3].
    let half = vadd_f32(vget_low_f32(prod), vget_high_f32(prod));
    // Pairwise add leaves the full sum in both lanes of the 64-bit vector.
    let sum = vpadd_f32(half, half);
    vcombine_f32(sum, sum)
}
/// Per-lane `sin` approximation: range-reduces each lane to `[0, PI/2]`,
/// evaluates the degree-9 odd Taylor polynomial
/// `x - x^3/3! + x^5/5! - x^7/7! + x^9/9!`, then restores the sign from the
/// input's sign and the parity of the removed half-turns.
#[inline]
pub(crate) unsafe fn float32x4_sin(val: float32x4_t) -> float32x4_t {
    // Cascaded Taylor factors: 1/6 = 1/3!, then multiplying the previous
    // term by 1/20, 1/42, 1/72 turns x^(2k+1)/(2k+1)! into x^(2k+3)/(2k+3)!.
    const TE_SIN_COEFF2: f32 = 0.166_666_67_f32; const TE_SIN_COEFF3: f32 = 0.05_f32; const TE_SIN_COEFF4: f32 = 0.023_809_524_f32; const TE_SIN_COEFF5: f32 = 0.013_888_889_f32;
    const PI_V: float32x4_t = float32x4_from_f32x4([core::f32::consts::PI; 4]);
    const PIO2_V: float32x4_t = float32x4_from_f32x4([core::f32::consts::FRAC_PI_2; 4]);
    const IPI_V: float32x4_t = float32x4_from_f32x4([core::f32::consts::FRAC_1_PI; 4]);
    // Number of whole half-turns (multiples of PI) in each lane:
    // |trunc(val / PI)| (vcvtq truncates toward zero).
    let c_v = vabsq_s32(vcvtq_s32_f32(vmulq_f32(val, IPI_V)));
    // All-ones lane mask where the input is <= 0.
    let sign_v = vcleq_f32(val, vdupq_n_f32(0f32));
    // Parity (bit 0) of the half-turn count.
    let odd_v = vandq_u32(vreinterpretq_u32_s32(c_v), vdupq_n_u32(1));
    // The result is negated when exactly one of {input non-positive, odd
    // half-turn count} holds, so XOR the two flags.
    let neg_v = veorq_u32(odd_v, sign_v);
    // Reduce |val| into [0, PI) by removing the counted half-turns.
    let ma = vsubq_f32(vabsq_f32(val), vmulq_f32(PI_V, vcvtq_f32_s32(c_v)));
    // Reflect [PI/2, PI) onto [0, PI/2] via sin(x) = sin(PI - x).
    let reb_v = vcgeq_f32(ma, PIO2_V);
    let ma = vbslq_f32(reb_v, vsubq_f32(PI_V, ma), ma);
    // Polynomial evaluation: `elem` carries the running x^(2k+1)/(2k+1)!
    // term, added/subtracted with alternating sign.
    let ma2 = vmulq_f32(ma, ma);
    let elem = vmulq_f32(vmulq_f32(ma, ma2), vdupq_n_f32(TE_SIN_COEFF2));
    let res = vsubq_f32(ma, elem);
    let elem = vmulq_f32(vmulq_f32(elem, ma2), vdupq_n_f32(TE_SIN_COEFF3));
    let res = vaddq_f32(res, elem);
    let elem = vmulq_f32(vmulq_f32(elem, ma2), vdupq_n_f32(TE_SIN_COEFF4));
    let res = vsubq_f32(res, elem);
    let elem = vmulq_f32(vmulq_f32(elem, ma2), vdupq_n_f32(TE_SIN_COEFF5));
    let res = vaddq_f32(res, elem);
    // Move the negate flag into the IEEE 754 sign bit and XOR it into the
    // result (cheap conditional negation).
    let neg_v = vshlq_n_u32(neg_v, 31);
    vreinterpretq_f32_u32(veorq_u32(vreinterpretq_u32_f32(res), neg_v))
}
/// Swizzles the lanes of a `float32x4_t` by compile-time indices.
///
/// Invoked either with four indices (`swizzle_f32x4_t!(e0, e1, e2, e3, v)`)
/// or with three (`swizzle_f32x4_t!(e0, e1, e2, v)`). Index patterns that
/// map onto a single cheap NEON instruction (ext, dup, zip, uzp, trn,
/// rev64) are special-cased with literal-matching arms; any other
/// combination falls back to an array shuffle through
/// `f32x4_from_float32x4` / `float32x4_from_f32x4`.
///
/// NOTE(review): the specialised three-index arms do not all agree with the
/// generic three-index fallback about lane 3 — e.g. `(0, 1, 2, $v)` returns
/// the input unchanged (lane 3 = input lane 3), while the fallback expands
/// to `(e0, e1, e2, e2, v)` (lane 3 = lane `e2`). Callers should treat
/// lane 3 of a three-index swizzle as unspecified.
macro_rules! swizzle_f32x4_t {
    // --- three-index arms (lane 3 unspecified, see note above) ---
    // Identity on the first three lanes.
    (0, 1, 2, $v: expr) => {
        $v
    };
    // Duplicate the low 64-bit half: [0, 1, 0, 1].
    (0, 1, 0, $v: expr) => {
        unsafe {
            use core::arch::aarch64::*;
            let v = vget_low_f32($v);
            vcombine_f32(v, v)
        }
    };
    // Duplicate the high 64-bit half: [2, 3, 2, 3].
    (2, 3, 2, $v: expr) => {
        unsafe {
            use core::arch::aarch64::*;
            let v = vget_high_f32($v);
            vcombine_f32(v, v)
        }
    };
    // Lane rotations via byte-extract: [1, 2, 3, 0], [2, 3, 0, 1], [3, 0, 1, 2].
    (1, 2, 3, $v: expr) => {
        unsafe {
            use core::arch::aarch64::*;
            vextq_f32($v, $v, 1)
        }
    };
    (2, 3, 0, $v: expr) => {
        unsafe {
            use core::arch::aarch64::*;
            vextq_f32($v, $v, 2)
        }
    };
    (3, 0, 1, $v: expr) => {
        unsafe {
            use core::arch::aarch64::*;
            vextq_f32($v, $v, 3)
        }
    };
    // Single-lane broadcasts.
    (0, 0, 0, $v: expr) => {
        unsafe {
            use core::arch::aarch64::*;
            let v = vget_low_f32($v);
            vdupq_lane_f32(v, 0)
        }
    };
    (1, 1, 1, $v: expr) => {
        unsafe {
            use core::arch::aarch64::*;
            let v = vget_low_f32($v);
            vdupq_lane_f32(v, 1)
        }
    };
    (2, 2, 2, $v: expr) => {
        unsafe {
            use core::arch::aarch64::*;
            let v = vget_high_f32($v);
            vdupq_lane_f32(v, 0)
        }
    };
    (3, 3, 3, $v: expr) => {
        unsafe {
            use core::arch::aarch64::*;
            let v = vget_high_f32($v);
            vdupq_lane_f32(v, 1)
        }
    };
    // Interleaves: zip gives [0, 0, 1, 1] / [2, 2, 3, 3].
    (0, 0, 1, $v: expr) => {
        unsafe {
            use core::arch::aarch64::*;
            vzipq_f32($v, $v).0
        }
    };
    (2, 2, 3, $v: expr) => {
        unsafe {
            use core::arch::aarch64::*;
            vzipq_f32($v, $v).1
        }
    };
    // De-interleaves: uzp gives [0, 2, 0, 2] / [1, 3, 1, 3].
    (0, 2, 0, $v: expr) => {
        unsafe {
            use core::arch::aarch64::*;
            vuzpq_f32($v, $v).0
        }
    };
    (1, 3, 1, $v: expr) => {
        unsafe {
            use core::arch::aarch64::*;
            vuzpq_f32($v, $v).1
        }
    };
    // Swap within each 64-bit pair: [1, 0, 3, 2].
    (1, 0, 3, $v: expr) => {
        unsafe {
            use core::arch::aarch64::*;
            vrev64q_f32($v)
        }
    };
    // Transposes: trn gives [0, 0, 2, 2] / [1, 1, 3, 3].
    (0, 0, 2, $v: expr) => {
        unsafe {
            use core::arch::aarch64::*;
            vtrnq_f32($v, $v).0
        }
    };
    (1, 1, 3, $v: expr) => {
        unsafe {
            use core::arch::aarch64::*;
            vtrnq_f32($v, $v).1
        }
    };
    // Generic three-index fallback: duplicate the last index into lane 3.
    // The captured `$e…:expr` fragments are opaque to the nested invocation,
    // so this lands on the generic four-index (array-shuffle) arm below.
    ($e0: expr, $e1: expr, $e2: expr, $v: expr) => {
        crate::neon::swizzle_f32x4_t!($e0, $e1, $e2, $e2, $v)
    };
    // --- four-index arms ---
    // Identity.
    (0, 1, 2, 3, $v: expr) => {
        $v
    };
    // Duplicate the low / high 64-bit half.
    (0, 1, 0, 1, $v: expr) => {
        unsafe {
            use core::arch::aarch64::*;
            let v = vget_low_f32($v);
            vcombine_f32(v, v)
        }
    };
    (2, 3, 2, 3, $v: expr) => {
        unsafe {
            use core::arch::aarch64::*;
            let v = vget_high_f32($v);
            vcombine_f32(v, v)
        }
    };
    // Lane rotations via byte-extract.
    (1, 2, 3, 0, $v: expr) => {
        unsafe {
            use core::arch::aarch64::*;
            vextq_f32($v, $v, 1)
        }
    };
    (2, 3, 0, 1, $v: expr) => {
        unsafe {
            use core::arch::aarch64::*;
            vextq_f32($v, $v, 2)
        }
    };
    (3, 0, 1, 2, $v: expr) => {
        unsafe {
            use core::arch::aarch64::*;
            vextq_f32($v, $v, 3)
        }
    };
    // Single-lane broadcasts.
    (0, 0, 0, 0, $v: expr) => {
        unsafe {
            use core::arch::aarch64::*;
            let v = vget_low_f32($v);
            vdupq_lane_f32(v, 0)
        }
    };
    (1, 1, 1, 1, $v: expr) => {
        unsafe {
            use core::arch::aarch64::*;
            let v = vget_low_f32($v);
            vdupq_lane_f32(v, 1)
        }
    };
    (2, 2, 2, 2, $v: expr) => {
        unsafe {
            use core::arch::aarch64::*;
            let v = vget_high_f32($v);
            vdupq_lane_f32(v, 0)
        }
    };
    (3, 3, 3, 3, $v: expr) => {
        unsafe {
            use core::arch::aarch64::*;
            let v = vget_high_f32($v);
            vdupq_lane_f32(v, 1)
        }
    };
    // Interleaves.
    (0, 0, 1, 1, $v: expr) => {
        unsafe {
            use core::arch::aarch64::*;
            vzipq_f32($v, $v).0
        }
    };
    (2, 2, 3, 3, $v: expr) => {
        unsafe {
            use core::arch::aarch64::*;
            vzipq_f32($v, $v).1
        }
    };
    // De-interleaves.
    (0, 2, 0, 2, $v: expr) => {
        unsafe {
            use core::arch::aarch64::*;
            vuzpq_f32($v, $v).0
        }
    };
    (1, 3, 1, 3, $v: expr) => {
        unsafe {
            use core::arch::aarch64::*;
            vuzpq_f32($v, $v).1
        }
    };
    // Swap within each 64-bit pair.
    (1, 0, 3, 2, $v: expr) => {
        unsafe {
            use core::arch::aarch64::*;
            vrev64q_f32($v)
        }
    };
    // Transposes.
    (0, 0, 2, 2, $v: expr) => {
        unsafe {
            use core::arch::aarch64::*;
            vtrnq_f32($v, $v).0
        }
    };
    (1, 1, 3, 3, $v: expr) => {
        unsafe {
            use core::arch::aarch64::*;
            vtrnq_f32($v, $v).1
        }
    };
    // Generic fallback: round-trip through an array and shuffle by index.
    ($e0: expr, $e1: expr, $e2: expr, $e3: expr, $v: expr) => {{
        let arr = crate::neon::f32x4_from_float32x4($v);
        crate::neon::float32x4_from_f32x4([arr[$e0], arr[$e1], arr[$e2], arr[$e3]])
    }};
}
pub(crate) use swizzle_f32x4_t;
#[cfg(test)]
mod tests {
    use super::*;

    /// Lane 3 must become a copy of lane 2; lanes 0-2 unchanged.
    #[test]
    fn test_set_fourth_with_third() {
        let v0 = float32x4_from_f32x4([1., 2., 3., 4.]);
        let v1 = unsafe { set_fourth_with_third(v0) };
        // assert_eq! (not assert!(a == b)) so a failure prints both sides.
        assert_eq!(f32x4_from_float32x4(v1), [1., 2., 3., 3.]);
    }

    /// 1*5 + 2*6 + 3*7 = 38; lane 3 must be ignored in both argument orders.
    #[test]
    fn test_dot3() {
        let v0 = float32x4_from_f32x4([1., 2., 3., 4.]);
        let v1 = float32x4_from_f32x4([5., 6., 7., 8.]);
        assert_eq!(unsafe { dot3(v0, v1) }, 38.);
        assert_eq!(unsafe { dot3(v1, v0) }, 38.);
    }

    /// 1*5 + 2*6 + 3*7 + 4*8 = 70, in both argument orders.
    #[test]
    fn test_dot4() {
        let v0 = float32x4_from_f32x4([1., 2., 3., 4.]);
        let v1 = float32x4_from_f32x4([5., 6., 7., 8.]);
        assert_eq!(unsafe { dot4(v0, v1) }, 70.);
        assert_eq!(unsafe { dot4(v1, v0) }, 70.);
    }

    /// Exercises four-index swizzles that hit the generic array-shuffle arm
    /// (none of these index patterns has a specialised NEON arm).
    #[test]
    fn test_neon_swizzle() {
        let v0 = float32x4_from_f32x4([1., 2., 3., 4.]);
        let rev = swizzle_f32x4_t!(3, 2, 1, 0, v0);
        assert_eq!(f32x4_from_float32x4(rev), [4., 3., 2., 1.]);
        let rot = swizzle_f32x4_t!(1, 2, 3, 1, v0);
        assert_eq!(f32x4_from_float32x4(rot), [2., 3., 4., 2.]);
        let mix = swizzle_f32x4_t!(1, 0, 3, 0, v0);
        assert_eq!(f32x4_from_float32x4(mix), [2., 1., 4., 1.]);
    }

    /// Compares the vectorized approximation against `f32::sin` per lane.
    #[test]
    fn test_sin() {
        let arr = [1.2, 2.3, 3.4, 4.5];
        let v_sin = unsafe { float32x4_sin(float32x4_from_f32x4(arr)) };
        let v_sin_arr = f32x4_from_float32x4(v_sin);
        let cmp_arr = arr.map(f32::sin);
        for i in 0..4 {
            assert!(
                (v_sin_arr[i] - cmp_arr[i]).abs() < 1e-6,
                "left = {}, right = {}",
                v_sin_arr[i],
                cmp_arr[i]
            );
        }
    }
}