use core::simd::Simd;
use crate::{Bitmask, BitmaskVT};
use crate::kernels::arithmetic::simd::{W8, W16, W32, W64};
use crate::enums::operators::{LogicalOperator, UnaryOperator};
use crate::kernels::bitmask::{
bitmask_window_bytes, bitmask_window_bytes_mut, clear_trailing_bits, mask_bits_as_words,
mask_bits_as_words_mut,
};
/// Combines two bitmask windows with a bitwise logical operator.
///
/// The result is a fresh [`Bitmask`] of `len` bits, where `len` is taken from
/// the `lhs` window. The length stored in `rhs` is ignored, so the caller is
/// responsible for passing windows that cover the same number of bits.
///
/// Words are processed `LANES` at a time while at least `LANES` words remain;
/// the leftover words are processed one by one. The finished mask is passed
/// through `clear_trailing_bits` before it is returned.
///
/// An empty `lhs` window yields an empty mask.
#[inline(always)]
pub fn bitmask_binop_simd<const LANES: usize>(
    lhs: BitmaskVT<'_>,
    rhs: BitmaskVT<'_>,
    op: LogicalOperator,
) -> Bitmask {
    let (left, left_off, len) = lhs;
    let (right, right_off, _) = rhs;
    if len == 0 {
        return Bitmask::new_set_all(0, false);
    }

    let mut result = Bitmask::new_set_all(len, false);
    let word_count = (len + 63) / 64;

    // SAFETY: assumes the window helpers hand back pointers that are valid for
    // `word_count` 64-bit words each — this cannot be verified from this file.
    unsafe {
        let left_words = mask_bits_as_words(bitmask_window_bytes(left, left_off, len));
        let right_words = mask_bits_as_words(bitmask_window_bytes(right, right_off, len));
        let out_words = mask_bits_as_words_mut(bitmask_window_bytes_mut(&mut result, 0, len));

        // Vector body: `LANES` words per step.
        let mut done = 0;
        while done + LANES <= word_count {
            let va =
                Simd::<u64, LANES>::from_slice(std::slice::from_raw_parts(left_words.add(done), LANES));
            let vb =
                Simd::<u64, LANES>::from_slice(std::slice::from_raw_parts(right_words.add(done), LANES));
            let combined = match op {
                LogicalOperator::And => va & vb,
                LogicalOperator::Or => va | vb,
                LogicalOperator::Xor => va ^ vb,
            };
            std::ptr::copy_nonoverlapping(combined.as_array().as_ptr(), out_words.add(done), LANES);
            done += LANES;
        }

        // Scalar remainder: fewer than `LANES` words left.
        while done < word_count {
            let wa = *left_words.add(done);
            let wb = *right_words.add(done);
            *out_words.add(done) = match op {
                LogicalOperator::And => wa & wb,
                LogicalOperator::Or => wa | wb,
                LogicalOperator::Xor => wa ^ wb,
            };
            done += 1;
        }
    }

    result.len = len;
    clear_trailing_bits(&mut result);
    result
}
/// Applies a unary operator to every bit of a bitmask window.
///
/// Only [`UnaryOperator::Not`] is handled; for a non-empty window any other
/// variant reaches `unreachable!()`. An empty window returns an empty mask
/// before the operator is inspected.
///
/// Words are processed `LANES` at a time while at least `LANES` words remain,
/// then one by one. The finished mask is passed through `clear_trailing_bits`.
#[inline(always)]
pub fn bitmask_unop_simd<const LANES: usize>(src: BitmaskVT<'_>, op: UnaryOperator) -> Bitmask {
    let (source, source_off, len) = src;
    if len == 0 {
        return Bitmask::new_set_all(0, false);
    }

    let mut result = Bitmask::new_set_all(len, false);
    let word_count = (len + 63) / 64;

    // SAFETY: assumes the window helpers hand back pointers that are valid for
    // `word_count` 64-bit words each — this cannot be verified from this file.
    unsafe {
        let in_words = mask_bits_as_words(bitmask_window_bytes(source, source_off, len));
        let out_words = mask_bits_as_words_mut(bitmask_window_bytes_mut(&mut result, 0, len));

        // Vector body: `LANES` words per step.
        let mut done = 0;
        while done + LANES <= word_count {
            let v =
                Simd::<u64, LANES>::from_slice(std::slice::from_raw_parts(in_words.add(done), LANES));
            let flipped = match op {
                UnaryOperator::Not => !v,
                _ => unreachable!(),
            };
            std::ptr::copy_nonoverlapping(flipped.as_array().as_ptr(), out_words.add(done), LANES);
            done += LANES;
        }

        // Scalar remainder: fewer than `LANES` words left.
        while done < word_count {
            let w = *in_words.add(done);
            *out_words.add(done) = match op {
                UnaryOperator::Not => !w,
                _ => unreachable!(),
            };
            done += 1;
        }
    }

    result.len = len;
    clear_trailing_bits(&mut result);
    result
}
/// Bitwise AND of two bitmask windows.
///
/// Forwards to [`bitmask_binop_simd`] with [`LogicalOperator::And`].
#[inline(always)]
pub fn and_masks_simd<const LANES: usize>(lhs: BitmaskVT<'_>, rhs: BitmaskVT<'_>) -> Bitmask {
    let op = LogicalOperator::And;
    bitmask_binop_simd::<LANES>(lhs, rhs, op)
}
/// Bitwise OR of two bitmask windows.
///
/// Forwards to [`bitmask_binop_simd`] with [`LogicalOperator::Or`].
#[inline(always)]
pub fn or_masks_simd<const LANES: usize>(lhs: BitmaskVT<'_>, rhs: BitmaskVT<'_>) -> Bitmask {
    let op = LogicalOperator::Or;
    bitmask_binop_simd::<LANES>(lhs, rhs, op)
}
/// Bitwise XOR of two bitmask windows.
///
/// Forwards to [`bitmask_binop_simd`] with [`LogicalOperator::Xor`].
#[inline(always)]
pub fn xor_masks_simd<const LANES: usize>(lhs: BitmaskVT<'_>, rhs: BitmaskVT<'_>) -> Bitmask {
    let op = LogicalOperator::Xor;
    bitmask_binop_simd::<LANES>(lhs, rhs, op)
}
/// Bitwise NOT of a bitmask window.
///
/// Forwards to [`bitmask_unop_simd`] with [`UnaryOperator::Not`].
#[inline(always)]
pub fn not_mask_simd<const LANES: usize>(src: BitmaskVT<'_>) -> Bitmask {
    let op = UnaryOperator::Not;
    bitmask_unop_simd::<LANES>(src, op)
}
/// Evaluates boolean set membership (`lhs IN rhs`) over bitmask windows.
///
/// The `rhs` window is treated as a set of boolean values; only whether it
/// contains at least one set bit and/or at least one unset bit matters:
///
/// - both present: every bit of the result is set,
/// - only set bits: the result is the `lhs` window copied via `slice_clone`,
/// - only unset bits: the result is the bitwise NOT of the `lhs` window,
/// - neither (cannot happen for a non-empty window): every bit is unset.
///
/// An empty window yields an empty mask. The two window lengths are compared
/// with `debug_assert_eq!` only.
///
/// The `rhs` window may start at any bit offset: bits of the first word that
/// precede `rhs_off` and bits of the last word at or after `rhs_off + len`
/// are masked out before the scan, so only bits inside the window are
/// inspected.
#[inline]
pub fn in_mask_simd<const LANES: usize>(lhs: BitmaskVT<'_>, rhs: BitmaskVT<'_>) -> Bitmask {
    let (lhs_mask, lhs_off, len) = lhs;
    let (rhs_mask, rhs_off, rlen) = rhs;
    debug_assert_eq!(len, rlen, "in_mask: window length mismatch");
    if len == 0 {
        return Bitmask::new_set_all(0, false);
    }

    // Bit range [rhs_off, end) of the rhs window and the words it touches.
    let end = rhs_off + len;
    let first_word = rhs_off / 64;
    let last_word = (end - 1) / 64;
    // Keeps only bits at or above the window start inside the first word.
    let head_mask = !0u64 << (rhs_off % 64);
    // Keeps only bits below the window end inside the last word.
    let tail_mask = match end % 64 {
        0 => !0u64,
        t => (1u64 << t) - 1,
    };

    let mut any_set = 0u64;
    let mut any_unset = 0u64;
    // SAFETY: every word in `first_word..=last_word` holds at least one bit of
    // the window, so the reads stay inside the mask provided the caller passed
    // a valid window. Like the rest of this module, this assumes the backing
    // buffer is 8-byte aligned and readable in whole 64-bit words.
    unsafe {
        let words = rhs_mask.bits.as_ptr().cast::<u64>();
        for k in first_word..=last_word {
            let mut valid = !0u64;
            if k == first_word {
                valid &= head_mask;
            }
            if k == last_word {
                valid &= tail_mask;
            }
            let w = *words.add(k);
            any_set |= w & valid;
            any_unset |= !w & valid;
            // Both values seen: the outcome is already decided.
            if any_set != 0 && any_unset != 0 {
                break;
            }
        }
    }

    let has_true = any_set != 0;
    let has_false = any_unset != 0;
    match (has_true, has_false) {
        (true, true) => Bitmask::new_set_all(len, true),
        (true, false) => lhs_mask.slice_clone(lhs_off, len),
        (false, true) => not_mask_simd::<LANES>((lhs_mask, lhs_off, len)),
        (false, false) => Bitmask::new_set_all(len, false),
    }
}
/// Negated set membership (`lhs NOT IN rhs`) over bitmask windows.
///
/// Computes [`in_mask_simd`] and returns the bitwise NOT of the whole result.
#[inline]
pub fn not_in_mask_simd<const LANES: usize>(lhs: BitmaskVT<'_>, rhs: BitmaskVT<'_>) -> Bitmask {
    let contained = in_mask_simd::<LANES>(lhs, rhs);
    let bit_len = contained.len;
    not_mask_simd::<LANES>((&contained, 0, bit_len))
}
/// Bitwise equality of two bitmask windows.
///
/// Each output word is `!(a ^ b)`, so a result bit is set where the
/// corresponding input bits agree. The output has `len` bits and
/// `mask_trailing_bits` is called on it before it is returned.
///
/// An empty window yields an empty mask. Window lengths are compared with
/// `debug_assert_eq!` only.
///
/// # Panics
///
/// Panics when the window is non-empty and either offset is not a multiple
/// of 64.
#[inline]
pub fn eq_mask_simd<const LANES: usize>(a: BitmaskVT<'_>, b: BitmaskVT<'_>) -> Bitmask {
    let (lhs, lhs_off, len) = a;
    let (rhs, rhs_off, rhs_len) = b;
    debug_assert_eq!(len, rhs_len, "BitWindow length mismatch in eq_bits_mask");
    if len == 0 {
        return Bitmask::new_set_all(0, true);
    }

    let word_aligned = lhs_off % 64 == 0 && rhs_off % 64 == 0;
    if !word_aligned {
        panic!(
            "eq_bits_mask: offsets must be 64-bit aligned (got a: {}, b: {})",
            lhs_off, rhs_off
        );
    }

    let word_count = (len + 63) / 64;
    let mut result = Bitmask::new_set_all(len, false);

    // SAFETY: assumes both inputs and the output buffer are 8-byte aligned and
    // hold at least `word_count` whole 64-bit words from the given offsets —
    // this cannot be verified from this file.
    unsafe {
        let lhs_words =
            std::slice::from_raw_parts(lhs.bits.as_ptr().cast::<u64>().add(lhs_off / 64), word_count);
        let rhs_words =
            std::slice::from_raw_parts(rhs.bits.as_ptr().cast::<u64>().add(rhs_off / 64), word_count);
        let dst = result.bits.as_mut_ptr().cast::<u64>();

        // Vector body; yields the index of the first word it did not cover.
        #[cfg(feature = "simd")]
        let scalar_from = {
            let mut w = 0;
            while w + LANES <= word_count {
                let va = Simd::<u64, LANES>::from_slice(&lhs_words[w..w + LANES]);
                let vb = Simd::<u64, LANES>::from_slice(&rhs_words[w..w + LANES]);
                let same = !(va ^ vb);
                std::ptr::copy_nonoverlapping(same.as_array().as_ptr(), dst.add(w), LANES);
                w += LANES;
            }
            w
        };
        #[cfg(not(feature = "simd"))]
        let scalar_from = 0;

        // Scalar pass over whatever the vector body left behind.
        for w in scalar_from..word_count {
            *dst.add(w) = !(lhs_words[w] ^ rhs_words[w]);
        }
    }

    result.mask_trailing_bits();
    result
}
/// Bitwise inequality of two bitmask windows.
///
/// Computes [`eq_mask_simd`] and applies `!` to the resulting mask, so the
/// same panics apply.
#[inline]
pub fn ne_mask_simd<const LANES: usize>(a: BitmaskVT<'_>, b: BitmaskVT<'_>) -> Bitmask {
    let equal = eq_mask_simd::<LANES>(a, b);
    !equal
}
/// Logical negation of [`all_eq_mask_simd`].
///
/// Returns `true` when the two windows are not entirely equal, i.e. when at
/// least one pair of bits differs.
///
/// NOTE(review): despite the name, this does not check that *every* pair of
/// bits differs — confirm which meaning callers expect.
#[inline]
pub fn all_ne_mask_simd<const LANES: usize>(a: BitmaskVT<'_>, b: BitmaskVT<'_>) -> bool {
    let fully_equal = all_eq_mask_simd::<LANES>(a, b);
    !fully_equal
}
#[inline]
pub fn all_eq_mask_simd<const LANES: usize>(a: BitmaskVT<'_>, b: BitmaskVT<'_>) -> bool
where
{
let (am, ao, len) = a;
let (bm, bo, blen) = b;
debug_assert_eq!(len, blen, "BitWindow length mismatch in all_eq_mask");
if len == 0 {
return true;
}
if len < 64 {
let wa = unsafe { am.word_unchecked(ao / 64) };
let wb = unsafe { bm.word_unchecked(bo / 64) };
let valid_mask = (1u64 << len) - 1;
return (wa & valid_mask) == (wb & valid_mask);
}
if ao % 64 != 0 || bo % 64 != 0 {
panic!(
"all_eq_mask_simd: offsets must be 64-bit aligned (got a: {}, b: {})",
ao, bo
);
}
let n_words = (len + 63) / 64;
let trailing = len & 63;
unsafe {
let aw = std::slice::from_raw_parts(am.bits.as_ptr().cast::<u64>().add(ao / 64), n_words);
let bw = std::slice::from_raw_parts(bm.bits.as_ptr().cast::<u64>().add(bo / 64), n_words);
#[cfg(feature = "simd")]
{
use std::simd::prelude::SimdPartialEq;
let mut i = 0;
while i + LANES <= n_words {
let sa = Simd::<u64, LANES>::from_slice(&aw[i..i + LANES]);
let sb = Simd::<u64, LANES>::from_slice(&bw[i..i + LANES]);
if !sa.simd_eq(sb).all() {
return false;
}
i += LANES;
}
for k in i..n_words {
if k == n_words - 1 && trailing != 0 {
let mask = (1u64 << trailing) - 1;
if (aw[k] & mask) != (bw[k] & mask) {
return false;
}
} else if aw[k] != bw[k] {
return false;
}
}
}
#[cfg(not(feature = "simd"))]
{
for k in 0..n_words {
if k == n_words - 1 && trailing != 0 {
let mask = (1u64 << trailing) - 1;
if (aw[k] & mask) != (bw[k] & mask) {
return false;
}
} else if aw[k] != bw[k] {
return false;
}
}
}
}
true
}
#[inline]
pub fn popcount_mask_simd<const LANES: usize>(m: BitmaskVT<'_>) -> usize
where
{
let (mask, offset, len) = m;
if len == 0 {
return 0;
}
let n_words = (len + 63) / 64;
let word_start = offset / 64;
let mut acc = 0usize;
unsafe {
let words = std::slice::from_raw_parts(
mask.bits.as_ptr().cast::<u64>().add(word_start),
n_words,
);
#[cfg(feature = "simd")]
{
use std::simd::prelude::SimdUint;
let mut i = 0;
while i + LANES <= n_words {
let v = Simd::<u64, LANES>::from_slice(&words[i..i + LANES]);
acc += v.count_ones().reduce_sum() as usize;
i += LANES;
}
for k in i..n_words {
if k == n_words - 1 && len % 64 != 0 {
let slack_mask = (1u64 << (len % 64)) - 1;
acc += (words[k] & slack_mask).count_ones() as usize;
} else {
acc += words[k].count_ones() as usize;
}
}
}
#[cfg(not(feature = "simd"))]
{
for k in 0..n_words {
if k == n_words - 1 && len % 64 != 0 {
let slack_mask = (1u64 << (len % 64)) - 1;
acc += (words[k] & slack_mask).count_ones() as usize;
} else {
acc += words[k].count_ones() as usize;
}
}
}
}
acc
}
/// Returns `true` when every one of the mask's `len` bits is set.
///
/// Only the first `mask.len` bits are examined. Whole words are compared to
/// all-ones `LANES` at a time and then one by one; a final partial word is
/// checked under a mask, so the state of the padding bits beyond `len` never
/// affects the result.
///
/// An empty mask returns `true`.
#[inline]
pub fn all_true_mask_simd<const LANES: usize>(mask: &Bitmask) -> bool {
    let n_bits = mask.len;
    if n_bits == 0 {
        return true;
    }
    if n_bits < 64 {
        // SAFETY: a non-empty mask has at least one word.
        let w = unsafe { mask.word_unchecked(0) };
        let valid_mask = (1u64 << n_bits) - 1;
        return (w & valid_mask) == valid_mask;
    }

    // Words lying entirely inside the mask, plus an optional partial word.
    let full_words = n_bits / 64;
    let trailing = n_bits % 64;
    let n_words = (n_bits + 63) / 64;

    // SAFETY: assumes the backing buffer is 8-byte aligned and holds at least
    // `n_words` whole 64-bit words — this cannot be verified from this file.
    let words: &[u64] =
        unsafe { std::slice::from_raw_parts(mask.bits.as_ptr() as *const u64, n_words) };
    let (full, partial) = words.split_at(full_words);

    // Vector pass over full words only.
    let all_ones = Simd::<u64, LANES>::splat(!0u64);
    let mut chunks = full.chunks_exact(LANES);
    for chunk in chunks.by_ref() {
        if Simd::<u64, LANES>::from_slice(chunk) != all_ones {
            return false;
        }
    }
    // Full words the vector pass could not cover.
    if chunks.remainder().iter().any(|&w| w != !0u64) {
        return false;
    }

    // Final partial word: only the low `trailing` bits belong to the mask.
    match partial.first() {
        Some(&w) => {
            let slack_mask = (1u64 << trailing) - 1;
            (w & slack_mask) == slack_mask
        }
        None => true,
    }
}
/// Returns `true` when none of the mask's `len` bits is set.
///
/// Only the first `mask.len` bits are examined. Whole words are compared to
/// zero `LANES` at a time and then one by one; a final partial word is
/// checked under a mask, so the state of the padding bits beyond `len` never
/// affects the result.
///
/// An empty mask returns `true`.
#[inline]
pub fn all_false_mask_simd<const LANES: usize>(mask: &Bitmask) -> bool {
    let n_bits = mask.len;
    if n_bits == 0 {
        return true;
    }
    if n_bits < 64 {
        // SAFETY: a non-empty mask has at least one word.
        let w = unsafe { mask.word_unchecked(0) };
        let valid_mask = (1u64 << n_bits) - 1;
        return (w & valid_mask) == 0;
    }

    // Words lying entirely inside the mask, plus an optional partial word.
    let full_words = n_bits / 64;
    let trailing = n_bits % 64;
    let n_words = (n_bits + 63) / 64;

    // SAFETY: assumes the backing buffer is 8-byte aligned and holds at least
    // `n_words` whole 64-bit words — this cannot be verified from this file.
    let words: &[u64] =
        unsafe { std::slice::from_raw_parts(mask.bits.as_ptr() as *const u64, n_words) };
    let (full, partial) = words.split_at(full_words);

    // Vector pass over full words only.
    let zeros = Simd::<u64, LANES>::splat(0u64);
    let mut chunks = full.chunks_exact(LANES);
    for chunk in chunks.by_ref() {
        if Simd::<u64, LANES>::from_slice(chunk) != zeros {
            return false;
        }
    }
    // Full words the vector pass could not cover.
    if chunks.remainder().iter().any(|&w| w != 0u64) {
        return false;
    }

    // Final partial word: only the low `trailing` bits belong to the mask.
    match partial.first() {
        Some(&w) => {
            let slack_mask = (1u64 << trailing) - 1;
            (w & slack_mask) == 0
        }
        None => true,
    }
}
macro_rules! impl_simd_eq_mask {
($fn_name:ident, $t:ty, $lanes:expr) => {
pub fn $fn_name(data: &[$t], field_mask: $t, target: $t) -> Bitmask {
use vec64::Vec64;
use std::simd::cmp::SimdPartialEq;
let n = data.len();
let n_bytes = (n + 7) / 8;
let mut bytes = Vec64::<u8>::with_capacity(n_bytes);
bytes.resize(n_bytes, 0);
let mask_vec = Simd::<$t, $lanes>::splat(field_mask);
let target_vec = Simd::<$t, $lanes>::splat(target);
let chunks = n / $lanes;
for i in 0..chunks {
let d = Simd::<$t, $lanes>::from_slice(&data[i * $lanes..]);
let masked = d & mask_vec;
let cmp = masked.simd_eq(target_vec);
let bits = cmp.to_bitmask() as u64;
let bit_pos = i * $lanes;
let byte_idx = bit_pos / 8;
let bit_shift = bit_pos % 8;
let shifted = bits << bit_shift;
for b in 0..(($lanes + 7) / 8) {
bytes[byte_idx + b] |= (shifted >> (b * 8)) as u8;
}
}
let start = chunks * $lanes;
for j in start..n {
if (data[j] & field_mask) == target {
bytes[j / 8] |= 1 << (j % 8);
}
}
Bitmask::new(bytes, n)
}
};
}
impl_simd_eq_mask!(simd_eq_mask_u8, u8, W8);
impl_simd_eq_mask!(simd_eq_mask_u16, u16, W16);
impl_simd_eq_mask!(simd_eq_mask_u32, u32, W32);
impl_simd_eq_mask!(simd_eq_mask_u64, u64, W64);
#[cfg(test)]
mod tests {
    use crate::{Bitmask, BitmaskVT};
    use super::*;

    // Instantiates the full kernel test suite in its own module for one lane
    // count, so every kernel is exercised at each configured width.
    macro_rules! simd_bitmask_suite {
        ($mod_name:ident, $lanes:expr) => {
            mod $mod_name {
                use super::*;
                const LANES: usize = $lanes;

                // Builds a mask from a slice of booleans, one bit per entry.
                fn bm(bits: &[bool]) -> Bitmask {
                    let mut m = Bitmask::new_set_all(bits.len(), false);
                    for (i, b) in bits.iter().enumerate() {
                        if *b {
                            m.set(i, true);
                        }
                    }
                    m
                }

                // Views a whole mask as a window starting at bit 0.
                fn slice(mask: &Bitmask) -> BitmaskVT<'_> {
                    (mask, 0, mask.len)
                }

                #[test]
                fn test_and_masks_simd() {
                    let a = bm(&[true, false, true, false, true, true, false, false]);
                    let b = bm(&[true, true, false, false, true, false, true, false]);
                    let c = and_masks_simd::<LANES>(slice(&a), slice(&b));
                    for i in 0..a.len {
                        assert_eq!(c.get(i), a.get(i) & b.get(i), "bit {i}");
                    }
                }

                #[test]
                fn test_or_masks_simd() {
                    let a = bm(&[true, false, true, false, true, true, false, false]);
                    let b = bm(&[true, true, false, false, true, false, true, false]);
                    let c = or_masks_simd::<LANES>(slice(&a), slice(&b));
                    for i in 0..a.len {
                        assert_eq!(c.get(i), a.get(i) | b.get(i), "bit {i}");
                    }
                }

                #[test]
                fn test_xor_masks_simd() {
                    let a = bm(&[true, false, true, false, true, true, false, false]);
                    let b = bm(&[true, true, false, false, true, false, true, false]);
                    let c = xor_masks_simd::<LANES>(slice(&a), slice(&b));
                    for i in 0..a.len {
                        assert_eq!(c.get(i), a.get(i) ^ b.get(i), "bit {i}");
                    }
                }

                #[test]
                fn test_not_mask_simd() {
                    let a = bm(&[true, false, true, false]);
                    let c = not_mask_simd::<LANES>(slice(&a));
                    for i in 0..a.len {
                        assert_eq!(c.get(i), !a.get(i));
                    }
                }

                #[test]
                fn test_in_mask_simd_variants() {
                    let lhs = bm(&[true, false, true, false]);

                    // rhs holds only `true`: result mirrors lhs.
                    let rhs_true = bm(&[true; 4]);
                    let out = in_mask_simd::<LANES>(slice(&lhs), slice(&rhs_true));
                    for i in 0..lhs.len {
                        assert_eq!(out.get(i), lhs.get(i), "in_mask, only true, bit {i}");
                    }

                    // rhs holds only `false`: result is the negation of lhs.
                    let rhs_false = bm(&[false; 4]);
                    let out = in_mask_simd::<LANES>(slice(&lhs), slice(&rhs_false));
                    for i in 0..lhs.len {
                        assert_eq!(out.get(i), !lhs.get(i), "in_mask, only false, bit {i}");
                    }

                    // rhs holds both values: every bit is set.
                    let rhs_both = bm(&[true, false, true, false]);
                    let out = in_mask_simd::<LANES>(slice(&lhs), slice(&rhs_both));
                    for i in 0..lhs.len {
                        assert!(out.get(i), "in_mask, both true/false, bit {i}");
                    }

                    // Empty windows produce an empty mask.
                    let rhs_empty = bm(&[false; 0]);
                    let out = in_mask_simd::<LANES>((&lhs, 0, 0), (&rhs_empty, 0, 0));
                    assert_eq!(out.len, 0);
                }

                #[test]
                fn test_not_in_mask_simd() {
                    let lhs = bm(&[true, false, true, false]);
                    let rhs = bm(&[true, false, true, false]);
                    let in_mask = in_mask_simd::<LANES>(slice(&lhs), slice(&rhs));
                    let not_in = not_in_mask_simd::<LANES>(slice(&lhs), slice(&rhs));
                    for i in 0..lhs.len {
                        assert_eq!(not_in.get(i), !in_mask.get(i));
                    }
                }

                #[test]
                fn test_eq_mask_simd_and_ne_mask_simd() {
                    let a = bm(&[true, false, true, false]);
                    let b = bm(&[true, false, false, true]);
                    let eq = eq_mask_simd::<LANES>(slice(&a), slice(&b));
                    let ne = ne_mask_simd::<LANES>(slice(&a), slice(&b));
                    for i in 0..a.len {
                        assert_eq!(eq.get(i), a.get(i) == b.get(i), "eq_mask bit {i}");
                        assert_eq!(ne.get(i), a.get(i) != b.get(i), "ne_mask bit {i}");
                    }
                }

                #[test]
                fn test_all_eq_mask_simd() {
                    let a = bm(&[true, false, true, false, true, true, false, false]);
                    let b = bm(&[true, false, true, false, true, true, false, false]);
                    assert!(all_eq_mask_simd::<LANES>(slice(&a), slice(&b)));
                    let mut b2 = b.clone();
                    b2.set(0, false);
                    assert!(!all_eq_mask_simd::<LANES>(slice(&a), slice(&b2)));
                }

                #[test]
                fn test_all_ne_mask_simd() {
                    let a = bm(&[true, false, true]);
                    let b = bm(&[false, true, false]);
                    assert!(all_ne_mask_simd::<LANES>(slice(&a), slice(&b)));
                    assert!(!all_ne_mask_simd::<LANES>(slice(&a), slice(&a)));
                }

                #[test]
                fn test_popcount_mask_simd() {
                    let a = bm(&[true, false, true, false, true, false, false, true]);
                    let pop = popcount_mask_simd::<LANES>(slice(&a));
                    assert_eq!(pop, 4);
                }

                #[test]
                fn test_all_true_mask_simd_and_false() {
                    let all_true = Bitmask::new_set_all(64 * LANES, true);
                    assert!(all_true_mask_simd::<LANES>(&all_true));
                    let mut not_true = all_true.clone();
                    not_true.set(3, false);
                    assert!(!all_true_mask_simd::<LANES>(&not_true));
                }

                #[test]
                fn test_all_false_mask_simd() {
                    let all_true = Bitmask::new_set_all(64 * LANES, true);
                    assert!(!all_false_mask_simd::<LANES>(&all_true));
                    let all_false = Bitmask::new_set_all(64 * LANES, false);
                    assert!(all_false_mask_simd::<LANES>(&all_false));
                }
            }
        };
    }

    // Build-script output; presumably defines the lane-count constants used
    // below — verify against the build script.
    include!(concat!(env!("OUT_DIR"), "/simd_lanes.rs"));

    simd_bitmask_suite!(simd_bitmask_w8, W8);
    simd_bitmask_suite!(simd_bitmask_w16, W16);
    simd_bitmask_suite!(simd_bitmask_w32, W32);
    simd_bitmask_suite!(simd_bitmask_w64, W64);
}