use std::arch::aarch64::*;
use super::NEON256Vector;
/// A 128-bit NEON vector viewed as sixteen `u8` lanes.
///
/// The `*_u16` operations of the trait impls reinterpret the same register as
/// eight `u16` lanes.
#[derive(Debug, Clone, Copy)]
pub struct NEONVector(uint8x16_t);

impl NEONVector {
    /// Loads exactly `len` (< 8) bytes from `ptr` into the low lanes of a vector
    /// and zeroes every other lane, never touching memory at or past `ptr + len`.
    ///
    /// This is the fallback for short loads when an 8-byte over-read might cross
    /// a page boundary. The bytes are gathered with exact-width unaligned scalar
    /// reads into a `u64` and placed in the low half of the result. The pieces
    /// are combined with shifts, so lane `i == ptr[i]` assumes a little-endian
    /// target.
    ///
    /// # Safety
    ///
    /// `ptr` must be valid for reads of `len` bytes, and `len` must be less than
    /// 8. Larger values hit `unreachable_unchecked` (UB) in release builds.
    #[inline(always)]
    unsafe fn load_partial_safe(ptr: *const u8, len: usize) -> uint8x16_t {
        // SAFETY: every read below stays within `ptr[0..len]`, which the caller
        // guarantees is readable; `len < 8` is part of the caller's contract.
        unsafe {
            debug_assert!(len < 8);
            let val: u64 = match len {
                0 => 0,
                1 => *ptr as u64,
                2 => (ptr as *const u16).read_unaligned() as u64,
                3 => {
                    let lo = (ptr as *const u16).read_unaligned() as u64;
                    let hi = *ptr.add(2) as u64;
                    lo | (hi << 16)
                }
                4 => (ptr as *const u32).read_unaligned() as u64,
                5 => {
                    let lo = (ptr as *const u32).read_unaligned() as u64;
                    let hi = *ptr.add(4) as u64;
                    lo | (hi << 32)
                }
                6 => {
                    let lo = (ptr as *const u32).read_unaligned() as u64;
                    let hi = (ptr.add(4) as *const u16).read_unaligned() as u64;
                    lo | (hi << 32)
                }
                7 => {
                    // Two overlapping 4-byte reads cover bytes 0..4 and 3..7.
                    // A u32 read at ptr+4 would touch ptr[7], one byte past the
                    // buffer, which defeats the purpose of this fallback.
                    // Byte 3 lands in the same bit position from both halves, so
                    // OR-ing the overlap is harmless.
                    let lo = (ptr as *const u32).read_unaligned() as u64;
                    let hi = (ptr.add(3) as *const u32).read_unaligned() as u64;
                    lo | (hi << 24)
                }
                _ => std::hint::unreachable_unchecked(),
            };
            vcombine_u8(vreinterpret_u8_u64(vdup_n_u64(val)), vdup_n_u8(0))
        }
    }

    /// Returns `true` when the 8 bytes starting at `ptr` all lie within the same
    /// 4 KiB-aligned block.
    ///
    /// Such a read cannot straddle into a different page, assuming pages are at
    /// least 4 KiB.
    #[inline(always)]
    fn can_overread_8(ptr: *const u8) -> bool {
        (ptr as usize & 0xFFF) <= (4096 - 8)
    }
}
impl super::Vector for NEONVector {
    /// Reports whether this backend is usable; true exactly when the build
    /// targets aarch64.
    #[inline]
    fn is_available() -> bool {
        cfg!(target_arch = "aarch64")
    }

    /// A vector whose sixteen byte lanes are all zero.
    #[inline(always)]
    unsafe fn zero() -> Self {
        let zeros = unsafe { vdupq_n_u8(0) };
        Self(zeros)
    }

    /// Broadcasts `value` into every byte lane.
    #[inline(always)]
    unsafe fn splat_u8(value: u8) -> Self {
        let lanes = unsafe { vdupq_n_u8(value) };
        Self(lanes)
    }

    /// Broadcasts `value` into every 16-bit lane, kept in byte-vector form.
    #[inline(always)]
    unsafe fn splat_u16(value: u16) -> Self {
        unsafe {
            let halves = vdupq_n_u16(value);
            Self(vreinterpretq_u8_u16(halves))
        }
    }

    /// Byte-wise equality mask (all-ones lanes where the inputs match).
    #[inline(always)]
    unsafe fn eq_u8(self, other: Self) -> Self {
        let (lhs, rhs) = (self.0, other.0);
        unsafe { Self(vceqq_u8(lhs, rhs)) }
    }

    /// Byte-wise unsigned `self > other` mask.
    #[inline(always)]
    unsafe fn gt_u8(self, other: Self) -> Self {
        let (lhs, rhs) = (self.0, other.0);
        unsafe { Self(vcgtq_u8(lhs, rhs)) }
    }

    /// Byte-wise unsigned `self < other` mask.
    #[inline(always)]
    unsafe fn lt_u8(self, other: Self) -> Self {
        let (lhs, rhs) = (self.0, other.0);
        unsafe { Self(vcltq_u8(lhs, rhs)) }
    }

    /// Lane-wise unsigned maximum over the eight 16-bit lanes.
    #[inline(always)]
    unsafe fn max_u16(self, other: Self) -> Self {
        unsafe {
            let lhs = vreinterpretq_u16_u8(self.0);
            let rhs = vreinterpretq_u16_u8(other.0);
            let widest = vmaxq_u16(lhs, rhs);
            Self(vreinterpretq_u8_u16(widest))
        }
    }

    /// Horizontal maximum across all eight unsigned 16-bit lanes.
    #[inline(always)]
    unsafe fn smax_u16(self) -> u16 {
        unsafe {
            let halves = vreinterpretq_u16_u8(self.0);
            vmaxvq_u16(halves)
        }
    }

    /// Lane-wise 16-bit addition (`vaddq_u16`, non-saturating).
    #[inline(always)]
    unsafe fn add_u16(self, other: Self) -> Self {
        unsafe {
            let lhs = vreinterpretq_u16_u8(self.0);
            let rhs = vreinterpretq_u16_u8(other.0);
            let sum = vaddq_u16(lhs, rhs);
            Self(vreinterpretq_u8_u16(sum))
        }
    }

    /// Lane-wise unsigned saturating 16-bit subtraction (`vqsubq_u16`).
    #[inline(always)]
    unsafe fn subs_u16(self, other: Self) -> Self {
        unsafe {
            let lhs = vreinterpretq_u16_u8(self.0);
            let rhs = vreinterpretq_u16_u8(other.0);
            let diff = vqsubq_u16(lhs, rhs);
            Self(vreinterpretq_u8_u16(diff))
        }
    }

    /// Bitwise AND of the two vectors.
    #[inline(always)]
    unsafe fn and(self, other: Self) -> Self {
        let (lhs, rhs) = (self.0, other.0);
        unsafe { Self(vandq_u8(lhs, rhs)) }
    }

    /// Bitwise OR of the two vectors.
    #[inline(always)]
    unsafe fn or(self, other: Self) -> Self {
        let (lhs, rhs) = (self.0, other.0);
        unsafe { Self(vorrq_u8(lhs, rhs)) }
    }

    /// Bitwise complement.
    #[inline(always)]
    unsafe fn not(self) -> Self {
        let bits = self.0;
        unsafe { Self(vmvnq_u8(bits)) }
    }

    /// Moves every 16-bit lane of `self` up by `L` lane positions.
    ///
    /// The vacated lanes `0..L` are filled with the highest `L` lanes of
    /// `other`. This is a byte-wise `EXT` over the concatenation `other:self`.
    /// `L == 8` yields `other` unchanged.
    #[inline(always)]
    unsafe fn shift_right_padded_u16<const L: i32>(self, other: Self) -> Self {
        unsafe {
            assert!(L >= 0 && L <= 8);
            let (prev, cur) = (other.0, self.0);
            let shifted = match L {
                0 => cur,
                1 => vextq_u8(prev, cur, 14),
                2 => vextq_u8(prev, cur, 12),
                3 => vextq_u8(prev, cur, 10),
                4 => vextq_u8(prev, cur, 8),
                5 => vextq_u8(prev, cur, 6),
                6 => vextq_u8(prev, cur, 4),
                7 => vextq_u8(prev, cur, 2),
                8 => prev,
                _ => unreachable!(),
            };
            Self(shifted)
        }
    }

    /// Test helper: builds a vector from sixteen bytes (lane `i` = `arr[i]`).
    #[cfg(test)]
    fn from_array(arr: [u8; 16]) -> Self {
        let src = arr.as_ptr();
        unsafe { Self(vld1q_u8(src)) }
    }

    /// Test helper: dumps the sixteen byte lanes into an array.
    #[cfg(test)]
    fn to_array(self) -> [u8; 16] {
        let mut bytes = [0u8; 16];
        let dst = bytes.as_mut_ptr();
        unsafe { vst1q_u8(dst, self.0) };
        bytes
    }

    /// Test helper: builds a vector from eight 16-bit values.
    #[cfg(test)]
    fn from_array_u16(arr: [u16; 8]) -> Self {
        let src = arr.as_ptr().cast::<u8>();
        unsafe { Self(vld1q_u8(src)) }
    }

    /// Test helper: dumps the eight 16-bit lanes into an array.
    #[cfg(test)]
    fn to_array_u16(self) -> [u16; 8] {
        let mut halves = [0u16; 8];
        let dst = halves.as_mut_ptr();
        unsafe { vst1q_u16(dst, vreinterpretq_u16_u8(self.0)) };
        halves
    }
}
impl super::Vector128 for NEONVector {
    /// Reads sixteen bytes starting at `ptr` with a single `vld1q_u8`.
    #[inline(always)]
    unsafe fn load_aligned_16(ptr: *const u8) -> Self {
        let bytes = unsafe { vld1q_u8(ptr) };
        Self(bytes)
    }

    /// Loads the 16-byte window of a `len`-byte buffer that begins at `start`,
    /// zero-filling every lane that falls past `len`.
    ///
    /// * `len <= 16`: the result is `data[0..len]` zero-padded. `start` is not
    ///   consulted on this path; presumably callers only pass `start == 0`
    ///   here (NOTE(review): confirm against callers).
    /// * `len > 16`: if `start + 16 <= len`, the result is
    ///   `data[start..start + 16]`. Otherwise the last 16 bytes of the buffer
    ///   are loaded and shifted down, giving `data[start..len]` followed by
    ///   zeros, or all zeros once `start >= len`.
    ///
    /// # Safety
    ///
    /// `data` must be readable for `len` bytes. For 1 to 7 bytes, a full 8-byte
    /// load is issued when `can_overread_8(data)` holds, so those extra bytes
    /// stay in the same 4 KiB-aligned block. The extra bytes are masked off.
    #[inline(always)]
    unsafe fn load_partial(data: *const u8, start: usize, len: usize) -> Self {
        unsafe {
            if len > 16 {
                if start + 16 <= len {
                    return Self(vld1q_u8(data.add(start)));
                }
                // The window runs past the end: load the buffer's final 16 bytes
                // and discard the `skip` leading bytes that precede `start`.
                let skip = start + 16 - len;
                let last16 = vld1q_u8(data.add(len - 16));
                let zeros = vdupq_n_u8(0);
                let window = match skip {
                    1 => vextq_u8(last16, zeros, 1),
                    2 => vextq_u8(last16, zeros, 2),
                    3 => vextq_u8(last16, zeros, 3),
                    4 => vextq_u8(last16, zeros, 4),
                    5 => vextq_u8(last16, zeros, 5),
                    6 => vextq_u8(last16, zeros, 6),
                    7 => vextq_u8(last16, zeros, 7),
                    8 => vextq_u8(last16, zeros, 8),
                    9 => vextq_u8(last16, zeros, 9),
                    10 => vextq_u8(last16, zeros, 10),
                    11 => vextq_u8(last16, zeros, 11),
                    12 => vextq_u8(last16, zeros, 12),
                    13 => vextq_u8(last16, zeros, 13),
                    14 => vextq_u8(last16, zeros, 14),
                    15 => vextq_u8(last16, zeros, 15),
                    _ => zeros,
                };
                return Self(window);
            }

            let vector = match len {
                0 => vdupq_n_u8(0),
                1..=7 => {
                    if Self::can_overread_8(data) {
                        // Over-read 8 bytes (cannot leave the 4 KiB block) and
                        // keep only the first `len` lanes.
                        let mut keep = [0u8; 8];
                        for lane in keep.iter_mut().take(len) {
                            *lane = 0xFF;
                        }
                        let head = vand_u8(vld1_u8(data), vld1_u8(keep.as_ptr()));
                        vcombine_u8(head, vdup_n_u8(0))
                    } else {
                        Self::load_partial_safe(data, len)
                    }
                }
                8 => vcombine_u8(vld1_u8(data), vdup_n_u8(0)),
                9..=15 => {
                    let head = vld1_u8(data);
                    // Final 8 bytes of the buffer: data[len - 8 .. len]. The first
                    // (16 - len) of them duplicate `head`, so shift them out.
                    let last8 = vld1_u8(data.add(len - 8));
                    let zeros = vdup_n_u8(0);
                    let tail = match 16 - len {
                        1 => vext_u8(last8, zeros, 1),
                        2 => vext_u8(last8, zeros, 2),
                        3 => vext_u8(last8, zeros, 3),
                        4 => vext_u8(last8, zeros, 4),
                        5 => vext_u8(last8, zeros, 5),
                        6 => vext_u8(last8, zeros, 6),
                        7 => vext_u8(last8, zeros, 7),
                        _ => unreachable!(),
                    };
                    vcombine_u8(head, tail)
                }
                // Only len == 16 remains: a full, exact-size load.
                _ => vld1q_u8(data),
            };
            Self(vector)
        }
    }

    /// Moves every byte lane of `self` up by `L` positions.
    ///
    /// The vacated lanes `0..L` receive the highest `L` bytes of `other`. This is
    /// a byte-wise `EXT` over the concatenation `other:self`.
    #[inline(always)]
    unsafe fn shift_right_padded_u8<const L: i32>(self, other: Self) -> Self {
        unsafe {
            assert!(L >= 0 && L <= 15);
            let (prev, cur) = (other.0, self.0);
            let shifted = match L {
                0 => cur,
                1 => vextq_u8(prev, cur, 15),
                2 => vextq_u8(prev, cur, 14),
                3 => vextq_u8(prev, cur, 13),
                4 => vextq_u8(prev, cur, 12),
                5 => vextq_u8(prev, cur, 11),
                6 => vextq_u8(prev, cur, 10),
                7 => vextq_u8(prev, cur, 9),
                8 => vextq_u8(prev, cur, 8),
                9 => vextq_u8(prev, cur, 7),
                10 => vextq_u8(prev, cur, 6),
                11 => vextq_u8(prev, cur, 5),
                12 => vextq_u8(prev, cur, 4),
                13 => vextq_u8(prev, cur, 3),
                14 => vextq_u8(prev, cur, 2),
                15 => vextq_u8(prev, cur, 1),
                _ => unreachable!(),
            };
            Self(shifted)
        }
    }
}
impl super::Vector128Expansion<NEON256Vector> for NEONVector {
    /// Sign-extends all sixteen lanes, read as `i8`, to `i16`.
    ///
    /// Lanes 0..8 widen into the first vector of the returned pair and lanes
    /// 8..16 into the second. Both halves are stored reinterpreted as bytes.
    #[inline(always)]
    unsafe fn cast_i8_to_i16(self) -> NEON256Vector {
        unsafe {
            let signed = vreinterpretq_s8_u8(self.0);
            let widened_lo = vmovl_s8(vget_low_s8(signed));
            let widened_hi = vmovl_high_s8(signed);
            NEON256Vector((
                vreinterpretq_u8_s16(widened_lo),
                vreinterpretq_u8_s16(widened_hi),
            ))
        }
    }
}