#![cfg(feature = "simd")]
use super::scalar;
use core::simd::cmp::{SimdPartialEq, SimdPartialOrd};
use core::simd::{mask16x8, u16x8, u8x16, Mask, Select as _, Simd, SimdElement, ToBytes};
/// Vectorized union (OR) of two sorted, deduplicated `u16` slices.
///
/// Streams 8-lane vectors from whichever input currently has the smallest
/// upcoming value, keeps a running sorted window via `simd_merge_u16`, and
/// reports each deduplicated output vector through `visitor`. Inputs shorter
/// than one vector, and the final sub-vector tails, are handled by the
/// scalar implementation.
pub fn or(lhs: &[u16], rhs: &[u16], visitor: &mut impl BinaryOperationVisitor) {
    // De-duplicates the sorted `slice` in place, returning the deduped length.
    // NOTE(review): `pos` starts at 1, so this assumes a non-empty slice; the
    // only call site below guards with `rem > 0`.
    #[inline]
    fn dedup(slice: &mut [u16]) -> usize {
        let mut pos: usize = 1;
        for i in 1..slice.len() {
            if slice[i] != slice[i - 1] {
                slice[pos] = slice[i];
                pos += 1;
            }
        }
        pos
    }
    // Passes `new` to `f` together with a bitmask of the lanes that differ
    // from their predecessor (lane 0 is compared against the last lane of
    // `old`), i.e. the lanes holding values not already emitted.
    #[inline]
    fn handle_vector(old: u16x8, new: u16x8, f: impl FnOnce(u16x8, u8)) {
        // `new` shifted right by one lane; lane 0 takes the last lane of `old`.
        let tmp: u16x8 = Shr1::concat_swizzle(new, old);
        // Complement of the "equal to predecessor" mask -> unique lanes.
        let mask = 255 - tmp.simd_eq(new).to_bitmask() as u8;
        f(new, mask);
    }
    // Inputs shorter than one vector: defer entirely to the scalar version.
    if (lhs.len() < 8) || (rhs.len() < 8) {
        scalar::or(lhs, rhs, visitor);
        return;
    }
    // Number of whole 8-lane vectors in each input.
    let len1: usize = lhs.len() / 8;
    let len2: usize = rhs.len() / 8;
    // Seed the merge window with the first vector of each input.
    let v_a: u16x8 = load(lhs);
    let v_b: u16x8 = load(rhs);
    let [mut v_min, mut v_max] = simd_merge_u16(v_a, v_b);
    let mut i = 1;
    let mut j = 1;
    // Emit the 8 smallest values; `u16::MAX` acts as the "previous value"
    // sentinel so lane 0 counts as new.
    handle_vector(Simd::splat(u16::MAX), v_min, |v, m| visitor.visit_vector(v, m));
    let mut v_prev: u16x8 = v_min;
    if (i < len1) && (j < len2) {
        let mut v: u16x8;
        let mut cur_a: u16 = lhs[8 * i];
        let mut cur_b: u16 = rhs[8 * j];
        loop {
            // Pull the next vector from the input whose upcoming value is
            // smaller, so the merge window always consumes ascending data.
            if cur_a <= cur_b {
                v = load(&lhs[8 * i..]);
                i += 1;
                if i < len1 {
                    cur_a = lhs[8 * i];
                } else {
                    break;
                }
            } else {
                v = load(&rhs[8 * j..]);
                j += 1;
                if j < len2 {
                    cur_b = rhs[8 * j];
                } else {
                    break;
                }
            }
            // Fold the new vector into the window: `v_min` becomes the next 8
            // sorted outputs, `v_max` carries the 8 largest values seen so far.
            [v_min, v_max] = simd_merge_u16(v, v_max);
            handle_vector(v_prev, v_min, |v, m| visitor.visit_vector(v, m));
            v_prev = v_min;
        }
        // The vector loaded just before `break` still has to be merged.
        [v_min, v_max] = simd_merge_u16(v, v_max);
        handle_vector(v_prev, v_min, |v, m| visitor.visit_vector(v, m));
        v_prev = v_min;
    }
    debug_assert!(i == len1 || j == len2);
    // Compact the still-unemitted unique lanes of `v_max` to the front of a
    // stack buffer (up to 8 values, plus up to 8 tail values below).
    let mut buffer: [u16; 16] = [0; 16];
    let mut rem = 0;
    handle_vector(v_prev, v_max, |v, m| {
        store(swizzle_to_front(v, m), buffer.as_mut_slice());
        rem = m.count_ones() as usize;
    });
    // `tail_a` is the sub-vector remainder of the exhausted input; `tail_b`
    // is everything still unprocessed in the other input.
    let (tail_a, tail_b, tail_len) = if i == len1 {
        (&lhs[8 * i..], &rhs[8 * j..], lhs.len() - 8 * len1)
    } else {
        (&rhs[8 * j..], &lhs[8 * i..], rhs.len() - 8 * len2)
    };
    buffer[rem..rem + tail_len].copy_from_slice(tail_a);
    rem += tail_len;
    if rem == 0 {
        visitor.visit_slice(tail_b)
    } else {
        // The buffer mixes window leftovers and the tail: restore order and
        // uniqueness, then finish with the scalar union.
        buffer[..rem].sort_unstable();
        rem = dedup(&mut buffer[..rem]);
        scalar::or(&buffer[..rem], tail_b, visitor);
    }
}
/// Vectorized intersection (AND) of two sorted, deduplicated `u16` slices.
///
/// Walks both inputs one 8-lane vector at a time, reporting through
/// `visitor` the lanes of `v_a` that occur anywhere in the current `v_b`.
/// Leftover sub-vector tails are handled by the scalar implementation.
pub fn and(lhs: &[u16], rhs: &[u16], visitor: &mut impl BinaryOperationVisitor) {
    // Lengths rounded down to a whole number of vectors.
    let st_a = (lhs.len() / u16x8::LEN) * u16x8::LEN;
    let st_b = (rhs.len() / u16x8::LEN) * u16x8::LEN;
    let mut i: usize = 0;
    let mut j: usize = 0;
    if (i < st_a) && (j < st_b) {
        let mut v_a: u16x8 = load(&lhs[i..]);
        let mut v_b: u16x8 = load(&rhs[j..]);
        loop {
            // All-pairs equality: which lanes of `v_a` are present in `v_b`.
            // The same `v_a` may be visited several times (once per `v_b`
            // block it is compared with); each visit reports only the lanes
            // matched against that block.
            let mask = matrix_cmp_u16(v_a, v_b).to_bitmask() as u8;
            visitor.visit_vector(v_a, mask);
            let a_max: u16 = lhs[i + u16x8::LEN - 1];
            let b_max: u16 = rhs[j + u16x8::LEN - 1];
            // Advance whichever side has the smaller maximum (both on a tie);
            // a block whose max is exceeded cannot match anything upcoming.
            if a_max <= b_max {
                i += u16x8::LEN;
                if i == st_a {
                    break;
                }
                v_a = load(&lhs[i..]);
            }
            if b_max <= a_max {
                j += u16x8::LEN;
                if j == st_b {
                    break;
                }
                v_b = load(&rhs[j..]);
            }
        }
    }
    // Finish the remaining (sub-vector or unvisited) portions scalar-wise.
    scalar::and(&lhs[i..], &rhs[j..], visitor);
}
/// Vectorized symmetric difference (XOR) of two sorted, deduplicated `u16`
/// slices.
///
/// Mirrors `or`: vectors are streamed through a sorted merge window, but a
/// value present in both inputs (an adjacent duplicate in the merged
/// stream) is cancelled entirely instead of deduplicated.
pub fn xor(lhs: &[u16], rhs: &[u16], visitor: &mut impl BinaryOperationVisitor) {
    // Like `dedup` in `or`, but removes BOTH copies of a duplicated value:
    // `pos -= 1` un-keeps the previously kept copy. Relies on each value
    // occurring at most twice, which holds because each input is itself
    // deduplicated.
    #[inline]
    fn xor_slice(slice: &mut [u16]) -> usize {
        let mut pos: usize = 1;
        for i in 1..slice.len() {
            if slice[i] != slice[i - 1] {
                slice[pos] = slice[i];
                pos += 1;
            } else {
                pos -= 1;
            }
        }
        pos
    }
    // Passes the one-lane-lagged vector to `f`, masking in only the lanes
    // whose value differs from BOTH neighbours in the merged stream (i.e.
    // values appearing exactly once). Because output lags by one lane, the
    // final lane of the last window is handled separately after the loop.
    #[inline]
    fn handle_vector(old: u16x8, new: u16x8, f: impl FnOnce(u16x8, u8)) {
        // Lag-2 and lag-1 views; lanes 0(..1) are filled from `old`.
        let tmp1: u16x8 = Shr2::concat_swizzle(new, old);
        let tmp2: u16x8 = Shr1::concat_swizzle(new, old);
        let eq_l: mask16x8 = tmp2.simd_eq(tmp1); // equal to left neighbour
        let eq_r: mask16x8 = tmp2.simd_eq(new); // equal to right neighbour
        let eq_l_or_r: mask16x8 = eq_l | eq_r;
        let mask: u8 = eq_l_or_r.to_bitmask() as u8;
        // Emit the lagged vector with the "unique" lanes selected.
        f(tmp2, 255 - mask);
    }
    // Inputs shorter than one vector: defer entirely to the scalar version.
    if (lhs.len() < 8) || (rhs.len() < 8) {
        scalar::xor(lhs, rhs, visitor);
        return;
    }
    // Number of whole 8-lane vectors in each input.
    let len1: usize = lhs.len() / 8;
    let len2: usize = rhs.len() / 8;
    // Seed the merge window with the first vector of each input.
    let v_a: u16x8 = load(lhs);
    let v_b: u16x8 = load(rhs);
    let [mut v_min, mut v_max] = simd_merge_u16(v_a, v_b);
    let mut i = 1;
    let mut j = 1;
    // `u16::MAX` acts as the "previous values" sentinel for the first window.
    handle_vector(Simd::splat(u16::MAX), v_min, |v, m| visitor.visit_vector(v, m));
    let mut v_prev: u16x8 = v_min;
    if (i < len1) && (j < len2) {
        let mut v: u16x8;
        let mut cur_a: u16 = lhs[8 * i];
        let mut cur_b: u16 = rhs[8 * j];
        loop {
            // Pull the next vector from the input whose upcoming value is
            // smaller, keeping the merged stream sorted.
            if cur_a <= cur_b {
                v = load(&lhs[8 * i..]);
                i += 1;
                if i < len1 {
                    cur_a = lhs[8 * i];
                } else {
                    break;
                }
            } else {
                v = load(&rhs[8 * j..]);
                j += 1;
                if j < len2 {
                    cur_b = rhs[8 * j];
                } else {
                    break;
                }
            }
            // Fold into the window and emit the lagged deduplicated lanes.
            [v_min, v_max] = simd_merge_u16(v, v_max);
            handle_vector(v_prev, v_min, |v, m| visitor.visit_vector(v, m));
            v_prev = v_min;
        }
        // The vector loaded just before `break` still has to be merged.
        [v_min, v_max] = simd_merge_u16(v, v_max);
        handle_vector(v_prev, v_min, |v, m| visitor.visit_vector(v, m));
        v_prev = v_min;
    }
    debug_assert!(i == len1 || j == len2);
    // 17 slots: up to 8 compacted lanes + 1 pending last lane + 8 tail values.
    let mut buffer: [u16; 17] = [0; 17];
    let mut rem = 0;
    handle_vector(v_prev, v_max, |v, m| {
        store(swizzle_to_front(v, m), buffer.as_mut_slice());
        rem = m.count_ones() as usize;
    });
    // The last lane of `v_max` was never emitted by the lagged
    // `handle_vector`; keep it unless it cancels against lane 6.
    let arr_max = v_max.as_array();
    let vec7 = arr_max[7];
    let vec6 = arr_max[6];
    if vec6 != vec7 {
        buffer[rem] = vec7;
        rem += 1;
    }
    // `tail_a` is the sub-vector remainder of the exhausted input; `tail_b`
    // is everything still unprocessed in the other input.
    let (tail_a, tail_b, tail_len) = if i == len1 {
        (&lhs[8 * i..], &rhs[8 * j..], lhs.len() - 8 * len1)
    } else {
        (&rhs[8 * j..], &lhs[8 * i..], rhs.len() - 8 * len2)
    };
    buffer[rem..rem + tail_len].copy_from_slice(tail_a);
    rem += tail_len;
    if rem == 0 {
        visitor.visit_slice(tail_b)
    } else {
        // Restore order, cancel duplicate pairs, then finish scalar-wise.
        buffer[..rem].sort_unstable();
        rem = xor_slice(&mut buffer[..rem]);
        scalar::xor(&buffer[..rem], tail_b, visitor);
    }
}
/// Vectorized difference (`lhs - rhs`) of two sorted, deduplicated `u16`
/// slices, reporting the kept values of `lhs` through `visitor`.
pub fn sub(lhs: &[u16], rhs: &[u16], visitor: &mut impl BinaryOperationVisitor) {
    // Trivial cases: nothing to subtract from, or nothing to subtract.
    if lhs.is_empty() {
        return;
    } else if rhs.is_empty() {
        visitor.visit_slice(lhs);
        return;
    }
    // Lengths rounded down to a whole number of vectors.
    let st_a = (lhs.len() / u16x8::LEN) * u16x8::LEN;
    let st_b = (rhs.len() / u16x8::LEN) * u16x8::LEN;
    let mut i = 0;
    let mut j = 0;
    if (i < st_a) && (j < st_b) {
        let mut v_a: u16x8 = load(&lhs[i..]);
        let mut v_b: u16x8 = load(&rhs[j..]);
        // Accumulates, across successive `v_b` blocks, which lanes of the
        // current `v_a` were found anywhere in `rhs`.
        let mut runningmask_a_found_in_b: u8 = 0;
        loop {
            let a_found_in_b: u8 = matrix_cmp_u16(v_a, v_b).to_bitmask() as u8;
            runningmask_a_found_in_b |= a_found_in_b;
            let a_max: u16 = lhs[i + u16x8::LEN - 1];
            let b_max: u16 = rhs[j + u16x8::LEN - 1];
            if a_max <= b_max {
                // `v_a` has now been compared against every `rhs` block that
                // could contain its values: emit the unmatched lanes.
                let bitmask_belongs_to_difference = runningmask_a_found_in_b ^ 0xFF;
                visitor.visit_vector(v_a, bitmask_belongs_to_difference);
                i += u16x8::LEN;
                if i == st_a {
                    break;
                }
                runningmask_a_found_in_b = 0;
                v_a = load(&lhs[i..]);
            }
            if b_max <= a_max {
                j += u16x8::LEN;
                if j == st_b {
                    break;
                }
                v_b = load(&rhs[j..]);
            }
        }
        debug_assert!(i == st_a || j == st_b);
        if i < st_a {
            // `rhs` ran out of whole vectors while a partially-checked `v_a`
            // is still in flight: compare it against the scalar tail of `rhs`.
            let remaining_rhs = &rhs[j..];
            if !remaining_rhs.is_empty() {
                // Pad the tail to a full vector. Reusing `remaining_rhs[0]`
                // as filler is safe: a duplicate of a value already in `rhs`
                // cannot create a spurious new match.
                let mut buffer: [u16; 8] = [0; 8];
                buffer[..remaining_rhs.len()].copy_from_slice(remaining_rhs);
                buffer[remaining_rhs.len()..].fill(remaining_rhs[0]);
                v_b = Simd::from_array(buffer);
                let a_found_in_b: u8 = matrix_cmp_u16(v_a, v_b).to_bitmask() as u8;
                runningmask_a_found_in_b |= a_found_in_b;
                // Consume the tail values that this `v_a` has fully covered,
                // so the scalar pass below does not recheck them.
                let [.., max_va] = *v_a.as_array();
                let used_rhs = remaining_rhs.partition_point(|&b| b <= max_va);
                j += used_rhs;
            }
            let bitmask_belongs_to_difference: u8 = runningmask_a_found_in_b ^ 0xFF;
            visitor.visit_vector(v_a, bitmask_belongs_to_difference);
            i += u16x8::LEN;
        }
    }
    // Scalar pass over whatever remains on both sides.
    scalar::sub(&lhs[i..], &rhs[j..], visitor);
}
#[inline]
fn lanes_min_u16<const LANES: usize>(
lhs: Simd<u16, LANES>,
rhs: Simd<u16, LANES>,
) -> Simd<u16, LANES> {
lhs.simd_le(rhs).select(lhs, rhs)
}
#[inline]
fn lanes_max_u16<const LANES: usize>(
lhs: Simd<u16, LANES>,
rhs: Simd<u16, LANES>,
) -> Simd<u16, LANES> {
lhs.simd_gt(rhs).select(lhs, rhs)
}
/// Loads the first `LANES` elements of `src` into a vector (unaligned read).
///
/// Debug builds assert that `src` is long enough; release builds rely on
/// callers upholding the precondition.
#[inline]
pub fn load<U, const LANES: usize>(src: &[U]) -> Simd<U, LANES>
where
    U: SimdElement + PartialOrd,
{
    debug_assert!(src.len() >= LANES);
    // SAFETY: length precondition checked above (in debug builds); all call
    // sites in this module pass slices with at least LANES elements.
    unsafe { load_unchecked(src) }
}
/// Reads `LANES` elements from the start of `src` as one unaligned vector.
///
/// # Safety
/// `src` must contain at least `LANES` elements.
#[inline]
pub unsafe fn load_unchecked<U, const LANES: usize>(src: &[U]) -> Simd<U, LANES>
where
    U: SimdElement + PartialOrd,
{
    // SAFETY: the caller guarantees at least LANES readable elements, and
    // `read_unaligned` imposes no alignment requirement on the pointer.
    unsafe { core::ptr::read_unaligned(src.as_ptr().cast::<Simd<U, LANES>>()) }
}
/// Writes the vector `v` to the first `LANES` slots of `out` (unaligned).
///
/// Debug builds assert that `out` is long enough; release builds rely on
/// callers upholding the precondition.
#[inline]
pub fn store<U, const LANES: usize>(v: Simd<U, LANES>, out: &mut [U])
where
    U: SimdElement + PartialOrd,
{
    debug_assert!(out.len() >= LANES);
    // SAFETY: length precondition checked above (in debug builds).
    unsafe { store_unchecked(v, out) }
}
/// Writes `v` over the first `LANES` elements of `out` as one unaligned store.
///
/// # Safety
/// `out` must have room for at least `LANES` elements.
#[inline]
unsafe fn store_unchecked<U, const LANES: usize>(v: Simd<U, LANES>, out: &mut [U])
where
    U: SimdElement + PartialOrd,
{
    // SAFETY: the caller guarantees at least LANES writable elements, and
    // `write_unaligned` imposes no alignment requirement on the pointer.
    unsafe { core::ptr::write_unaligned(out.as_mut_ptr().cast::<Simd<U, LANES>>(), v) }
}
/// Compares every lane of `a` against every lane of `b` by testing all 8
/// rotations of `b`, returning a mask of the lanes of `a` whose value
/// occurs anywhere in `b`.
#[inline]
fn matrix_cmp_u16(a: Simd<u16, 8>, b: Simd<u16, 8>) -> Mask<i16, 8> {
    // Accumulate matches one rotation at a time (rotation counts are const
    // generics, so the eight steps cannot be expressed as a runtime loop).
    let mut found = a.simd_eq(b);
    found |= a.simd_eq(b.rotate_elements_left::<1>());
    found |= a.simd_eq(b.rotate_elements_left::<2>());
    found |= a.simd_eq(b.rotate_elements_left::<3>());
    found |= a.simd_eq(b.rotate_elements_left::<4>());
    found |= a.simd_eq(b.rotate_elements_left::<5>());
    found |= a.simd_eq(b.rotate_elements_left::<6>());
    found |= a.simd_eq(b.rotate_elements_left::<7>());
    found
}
use crate::bitmap::store::array_store::visitor::BinaryOperationVisitor;
use core::simd::Swizzle;
/// `concat_swizzle(new, old)` helper that shifts `new` right by one lane,
/// pulling the last lane of `old` into lane 0.
pub struct Shr1;
impl Swizzle<8> for Shr1 {
    // In `concat_swizzle`, indices 0..8 address the first operand and 8..16
    // the second; 15 is the second operand's last lane.
    const INDEX: [usize; 8] = [15, 0, 1, 2, 3, 4, 5, 6];
}
/// `concat_swizzle(new, old)` helper that shifts `new` right by two lanes,
/// pulling the last two lanes of `old` into lanes 0 and 1.
pub struct Shr2;
impl Swizzle<8> for Shr2 {
    // 14 and 15 are the last two lanes of the second `concat_swizzle` operand.
    const INDEX: [usize; 8] = [14, 15, 0, 1, 2, 3, 4, 5];
}
/// Merges two sorted 8-lane vectors into a sorted 16-value sequence,
/// returned as `[low_half, high_half]` — the 8 smallest values (sorted)
/// and the 8 largest values (sorted). Callers (`or`/`xor`) depend on the
/// low half coming out in ascending lane order.
///
/// Implemented as an 8-round compare/rotate merge network: each round takes
/// per-lane minima/maxima, then rotates the minima left so the next round
/// compares them against the remaining maxima.
#[inline]
fn simd_merge_u16(a: Simd<u16, 8>, b: Simd<u16, 8>) -> [Simd<u16, 8>; 2] {
    // Prologue: first compare/exchange round.
    let mut tmp: Simd<u16, 8> = lanes_min_u16(a, b);
    let mut max: Simd<u16, 8> = lanes_max_u16(a, b);
    tmp = tmp.rotate_elements_left::<1>();
    let mut min: Simd<u16, 8> = lanes_min_u16(tmp, max);
    // Six middle rounds (eight in total with the prologue and epilogue).
    for _ in 0..6 {
        max = lanes_max_u16(tmp, max);
        tmp = min.rotate_elements_left::<1>();
        min = lanes_min_u16(tmp, max);
    }
    max = lanes_max_u16(tmp, max);
    // Undo the final rotation so the low half is in sorted lane order.
    min = min.rotate_elements_left::<1>();
    [min, max]
}
/// Compacts the lanes of `val` selected by `bitmask` (bit k => lane k) to
/// the front of the returned vector, preserving their order. Lanes past the
/// first `bitmask.count_ones()` hold unspecified filler (the table's zero
/// entries replicate byte 0); callers only read the selected prefix.
pub fn swizzle_to_front(val: u16x8, bitmask: u8) -> u16x8 {
    // Const-evaluated lookup table: for each of the 256 lane masks, the 16
    // byte-shuffle indices that move the selected u16 lanes — as adjacent
    // byte pairs — to the front. Index pairs keep both bytes of a lane
    // together, so the mapping is endianness-agnostic under `to_ne_bytes`.
    static SWIZZLE_TABLE: [[u8; 16]; 256] = {
        let mut table = [[0; 16]; 256];
        let mut n = 0usize;
        while n < table.len() {
            let mut x = n;
            let mut i = 0;
            // Walk the set bits from least to most significant.
            while x > 0 {
                let lsb = x.trailing_zeros() as u8;
                x ^= 1 << lsb;
                // Lane `lsb` occupies bytes 2*lsb and 2*lsb + 1.
                table[n][i] = lsb * 2;
                table[n][i + 1] = lsb * 2 + 1;
                i += 2;
            }
            n += 1;
        }
        table
    };
    // Reinterpret the vector as 16 bytes and fetch the shuffle pattern.
    let val_convert: u8x16 = val.to_ne_bytes();
    let swizzle_idxs = u8x16::from_array(SWIZZLE_TABLE[bitmask as usize]);
    #[cfg(all(target_arch = "x86_64", any(target_feature = "ssse3", feature = "std")))]
    {
        // Prefer the single-instruction SSSE3 byte shuffle when available:
        // proven at compile time via target_feature, or (with std) detected
        // at runtime.
        let has_ssse3 = {
            #[cfg(target_feature = "ssse3")]
            {
                true
            }
            #[cfg(not(target_feature = "ssse3"))]
            {
                std::arch::is_x86_feature_detected!("ssse3")
            }
        };
        if has_ssse3 {
            use core::arch::x86_64::{__m128i, _mm_shuffle_epi8};
            let val_m128 = __m128i::from(val_convert);
            let swizzle_m128 = __m128i::from(swizzle_idxs);
            // SAFETY: guarded by the SSSE3 availability check above.
            let swizzled_m128 = unsafe { _mm_shuffle_epi8(val_m128, swizzle_m128) };
            return u16x8::from(swizzled_m128);
        }
    }
    // Portable fallback for all other targets.
    let swizzled: u8x16 = val_convert.swizzle_dyn(swizzle_idxs);
    u16x8::from_ne_bytes(swizzled)
}