use core::arch::aarch64::{self, uint32x4_t};
use core::mem::transmute;
#[inline(always)]
fn array_to_uint32x4(input: [u32; 4]) -> uint32x4_t {
unsafe { transmute::<[u32; 4], uint32x4_t>(input) }
}
#[inline(always)]
fn uint32x4_to_array(input: uint32x4_t) -> [u32; 4] {
unsafe { transmute::<uint32x4_t, [u32; 4]>(input) }
}
/// Lane-wise modular addition: returns `(a + b) mod p` in each of the four `u32` lanes.
///
/// Works by forming `sum = a + b` and the candidate `sum - p`, then keeping the smaller of
/// the two (unsigned).
/// - If `sum >= p`, the candidate is the smaller value and lies in `[0, p)`.
/// - If `sum < p`, subtracting `p` wraps around to a huge value, so `sum` itself is kept.
///
/// Preconditions (not checked): every lane of `p` is below `2^31`, and every lane of `a` and
/// `b` lies in `[0, p)`. This keeps `a + b` from overflowing 32 bits. Outside that range the
/// output is not guaranteed to equal the reduced sum.
#[inline]
#[must_use]
pub fn uint32x4_mod_add(a: uint32x4_t, b: uint32x4_t, p: uint32x4_t) -> uint32x4_t {
    // SAFETY: pure lane arithmetic on NEON registers; no memory is read or written.
    unsafe {
        let sum = aarch64::vaddq_u32(a, b);
        aarch64::vminq_u32(sum, aarch64::vsubq_u32(sum, p))
    }
}
/// Lane-wise modular subtraction: returns `(a - b) mod p` in each of the four `u32` lanes.
///
/// Forms the wrapping difference `diff = a - b` and the candidate `diff + p`, then keeps the
/// smaller of the two (unsigned).
/// - If `a >= b`, `diff` is already in `[0, p)` and adding `p` only makes it larger.
/// - If `a < b`, `diff` has wrapped to a value above `2^32 - p`, while `diff + p` wraps back
///   into `[0, p)` and is kept.
///
/// Preconditions (not checked): every lane of `p` is below `2^31`, and every lane of `a` and
/// `b` lies in `[0, p)`. Outside that range the output is not guaranteed to equal the reduced
/// difference.
#[inline]
#[must_use]
pub fn uint32x4_mod_sub(a: uint32x4_t, b: uint32x4_t, p: uint32x4_t) -> uint32x4_t {
    // SAFETY: pure lane arithmetic on NEON registers; no memory is read or written.
    unsafe {
        let diff = aarch64::vsubq_u32(a, b);
        aarch64::vminq_u32(diff, aarch64::vaddq_u32(diff, p))
    }
}
/// Element-wise modular addition of two `WIDTH`-long arrays, written into `res`.
///
/// - Each complete group of four elements is handled by one call to the NEON kernel
///   [`uint32x4_mod_add`].
/// - The remaining `WIDTH % 4` trailing elements are handled one at a time by `scalar_add`.
/// - Every `WIDTH` is supported; `WIDTH == 0` leaves `res` untouched.
///
/// Preconditions (not checked): `p < 2^31`, and every element of `a` and `b` lies in `[0, p)`.
///
/// NOTE(review): this assumes `scalar_add` computes the same addition modulo `p` as the vector
/// path, so the result does not depend on which path a given element takes. Confirm at call
/// sites.
#[inline(always)]
pub fn packed_mod_add<const WIDTH: usize>(
    a: &[u32; WIDTH],
    b: &[u32; WIDTH],
    res: &mut [u32; WIDTH],
    p: u32,
    scalar_add: fn(u32, u32) -> u32,
) {
    // SAFETY: `vdupq_n_u32` only broadcasts a scalar into a register and touches no memory.
    // This file uses `core::arch::aarch64` NEON intrinsics unconditionally, so NEON is assumed
    // to be available.
    let p_vec: uint32x4_t = unsafe { aarch64::vdupq_n_u32(p) };

    let a_chunks = a.chunks_exact(4);
    let b_chunks = b.chunks_exact(4);
    // `remainder` borrows from the arrays themselves, not from the iterators, so the tails can
    // be taken now, before the iterators are consumed by the loop below.
    let (a_tail, b_tail) = (a_chunks.remainder(), b_chunks.remainder());
    let mut res_chunks = res.chunks_exact_mut(4);

    // Vector path: every chunk here is guaranteed to have exactly four elements.
    for ((r4, a4), b4) in res_chunks.by_ref().zip(a_chunks).zip(b_chunks) {
        let va = array_to_uint32x4([a4[0], a4[1], a4[2], a4[3]]);
        let vb = array_to_uint32x4([b4[0], b4[1], b4[2], b4[3]]);
        r4.copy_from_slice(&uint32x4_to_array(uint32x4_mod_add(va, vb, p_vec)));
    }

    // Scalar path for the trailing WIDTH % 4 elements.
    for ((r, &x), &y) in res_chunks.into_remainder().iter_mut().zip(a_tail).zip(b_tail) {
        *r = scalar_add(x, y);
    }
}
/// Element-wise modular subtraction `a - b` of two `WIDTH`-long arrays, written into `res`.
///
/// - Each complete group of four elements is handled by one call to the NEON kernel
///   [`uint32x4_mod_sub`].
/// - The remaining `WIDTH % 4` trailing elements are handled one at a time by `scalar_sub`.
/// - Every `WIDTH` is supported; `WIDTH == 0` leaves `res` untouched.
///
/// Preconditions (not checked): `p < 2^31`, and every element of `a` and `b` lies in `[0, p)`.
///
/// NOTE(review): this assumes `scalar_sub` computes the same subtraction modulo `p` as the
/// vector path, so the result does not depend on which path a given element takes. Confirm at
/// call sites.
#[inline(always)]
pub fn packed_mod_sub<const WIDTH: usize>(
    a: &[u32; WIDTH],
    b: &[u32; WIDTH],
    res: &mut [u32; WIDTH],
    p: u32,
    scalar_sub: fn(u32, u32) -> u32,
) {
    // SAFETY: `vdupq_n_u32` only broadcasts a scalar into a register and touches no memory.
    // This file uses `core::arch::aarch64` NEON intrinsics unconditionally, so NEON is assumed
    // to be available.
    let p_vec: uint32x4_t = unsafe { aarch64::vdupq_n_u32(p) };

    let a_chunks = a.chunks_exact(4);
    let b_chunks = b.chunks_exact(4);
    // `remainder` borrows from the arrays themselves, not from the iterators, so the tails can
    // be taken now, before the iterators are consumed by the loop below.
    let (a_tail, b_tail) = (a_chunks.remainder(), b_chunks.remainder());
    let mut res_chunks = res.chunks_exact_mut(4);

    // Vector path: every chunk here is guaranteed to have exactly four elements.
    for ((r4, a4), b4) in res_chunks.by_ref().zip(a_chunks).zip(b_chunks) {
        let va = array_to_uint32x4([a4[0], a4[1], a4[2], a4[3]]);
        let vb = array_to_uint32x4([b4[0], b4[1], b4[2], b4[3]]);
        r4.copy_from_slice(&uint32x4_to_array(uint32x4_mod_sub(va, vb, p_vec)));
    }

    // Scalar path for the trailing WIDTH % 4 elements.
    for ((r, &x), &y) in res_chunks.into_remainder().iter_mut().zip(a_tail).zip(b_tail) {
        *r = scalar_sub(x, y);
    }
}