use crate::{Bitmask, BitmaskVT};
use crate::{
enums::operators::{LogicalOperator, UnaryOperator},
kernels::bitmask::{
bitmask_window_bytes, bitmask_window_bytes_mut, clear_trailing_bits, mask_bits_as_words,
mask_bits_as_words_mut,
},
};
/// Combines two bitmask windows with a logical operator (`And`, `Or` or `Xor`).
///
/// Each window is a `(mask, bit_offset, len)` triple. The result is a new
/// [`Bitmask`] of `lhs`'s length, starting at bit 0, with slack bits past that
/// length cleared. `rhs` is read for the same number of bits as `lhs` (its own
/// length field is not consulted), so the caller must ensure it covers at least
/// that many bits.
///
/// Fast path: when both offsets are multiples of 64 the operator is applied one
/// `u64` word at a time. Otherwise a per-bit loop is used. The word path
/// reinterprets the byte window starting at `offset / 8` as whole words, so it
/// cannot honour a start position that is not on a word boundary.
#[inline(always)]
pub fn bitmask_binop_std(lhs: BitmaskVT<'_>, rhs: BitmaskVT<'_>, op: LogicalOperator) -> Bitmask {
    let (lhs_mask, lhs_off, len) = lhs;
    let (rhs_mask, rhs_off, _) = rhs;
    // Shared by both paths; in the per-bit path booleans are fed through as 0/1 words.
    let combine = |a: u64, b: u64| -> u64 {
        match op {
            LogicalOperator::And => a & b,
            LogicalOperator::Or => a | b,
            LogicalOperator::Xor => a ^ b,
        }
    };
    let mut out = Bitmask::new_set_all(len, false);
    if lhs_off % 64 == 0 && rhs_off % 64 == 0 {
        let nw = (len + 63) / 64;
        // SAFETY: both offsets are word-aligned, so each byte window starts on a
        // `u64` boundary of its mask. The caller guarantees each window covers
        // `len` bits, and `out` was allocated for `len` bits.
        // NOTE(review): the final word may extend past the last byte the window
        // strictly needs. This assumes the mask buffers are padded to whole
        // words; confirm against `Bitmask`'s allocation.
        unsafe {
            let lp = mask_bits_as_words(bitmask_window_bytes(lhs_mask, lhs_off, len));
            let rp = mask_bits_as_words(bitmask_window_bytes(rhs_mask, rhs_off, len));
            let dp = mask_bits_as_words_mut(bitmask_window_bytes_mut(&mut out, 0, len));
            for k in 0..nw {
                *dp.add(k) = combine(*lp.add(k), *rp.add(k));
            }
        }
    } else {
        for i in 0..len {
            // SAFETY: the caller guarantees both windows cover `len` bits, and
            // `out` holds `len` bits.
            unsafe {
                let a = lhs_mask.get_unchecked(lhs_off + i) as u64;
                let b = rhs_mask.get_unchecked(rhs_off + i) as u64;
                out.set_unchecked(i, combine(a, b) != 0);
            }
        }
    }
    out.len = len;
    // The word path may leave garbage in the slack bits of the final word.
    clear_trailing_bits(&mut out);
    out
}
/// Applies a unary logical operator to a bitmask window `(mask, bit_offset, len)`.
///
/// Only `UnaryOperator::Not` is supported. The result is a new [`Bitmask`] of
/// `len` bits, starting at bit 0, holding the complement of the window, with
/// slack bits past `len` cleared.
///
/// Word-aligned offsets (multiples of 64) are processed one `u64` at a time. Any
/// other offset uses a per-bit loop, because the word path cannot honour a start
/// position inside a word.
///
/// # Panics
/// Panics if `op` is anything other than `UnaryOperator::Not`.
#[inline(always)]
pub fn bitmask_unop_std(src: BitmaskVT<'_>, op: UnaryOperator) -> Bitmask {
    let (mask, offset, len) = src;
    // Reject unsupported operators up front with a descriptive message, so the
    // outcome does not depend on whether the window happens to be empty.
    if !matches!(op, UnaryOperator::Not) {
        panic!("bitmask_unop_std: only UnaryOperator::Not is supported for bitmasks");
    }
    let mut out = Bitmask::new_set_all(len, false);
    if offset % 64 == 0 {
        let nw = (len + 63) / 64;
        // SAFETY: the offset is word-aligned, so the byte window starts on a `u64`
        // boundary. The caller guarantees the window covers `len` bits, and `out`
        // was allocated for `len` bits.
        // NOTE(review): the final word may extend past the last byte strictly
        // needed. This assumes word-padded buffers; confirm against `Bitmask`.
        unsafe {
            let sp = mask_bits_as_words(bitmask_window_bytes(mask, offset, len));
            let dp = mask_bits_as_words_mut(bitmask_window_bytes_mut(&mut out, 0, len));
            for k in 0..nw {
                *dp.add(k) = !*sp.add(k);
            }
        }
    } else {
        for i in 0..len {
            // SAFETY: the caller guarantees the window covers `len` bits, and `out`
            // holds `len` bits.
            unsafe {
                let bit = mask.get_unchecked(offset + i);
                out.set_unchecked(i, !bit);
            }
        }
    }
    out.len = len;
    // Word-wise negation sets the slack bits of the final word; clear them.
    clear_trailing_bits(&mut out);
    out
}
/// Bit-by-bit logical AND of two windows, delegating to [`bitmask_binop_std`].
#[inline]
pub fn and_masks(lhs: BitmaskVT<'_>, rhs: BitmaskVT<'_>) -> Bitmask {
    let operator = LogicalOperator::And;
    bitmask_binop_std(lhs, rhs, operator)
}
/// Bit-by-bit logical OR of two windows, delegating to [`bitmask_binop_std`].
#[inline]
pub fn or_masks(lhs: BitmaskVT<'_>, rhs: BitmaskVT<'_>) -> Bitmask {
    let operator = LogicalOperator::Or;
    bitmask_binop_std(lhs, rhs, operator)
}
/// Bit-by-bit logical XOR of two windows, delegating to [`bitmask_binop_std`].
#[inline]
pub fn xor_masks(lhs: BitmaskVT<'_>, rhs: BitmaskVT<'_>) -> Bitmask {
    let operator = LogicalOperator::Xor;
    bitmask_binop_std(lhs, rhs, operator)
}
/// Logical complement of a window, delegating to [`bitmask_unop_std`] with `Not`.
#[inline]
pub fn not_mask(src: BitmaskVT<'_>) -> Bitmask {
    let (mask, offset, len) = src;
    bitmask_unop_std((mask, offset, len), UnaryOperator::Not)
}
/// Boolean set membership: for each bit of `lhs`, reports whether that value
/// occurs anywhere in `rhs`.
///
/// The result has `lhs`'s length. Only the distinct values present in `rhs`
/// matter, so `rhs` is scanned over its own length, stopping early once both
/// `true` and `false` have been seen. `rhs` may therefore be shorter or longer
/// than `lhs`. The outcomes are:
/// - both values present: every bit is `true`;
/// - only `true` present: a copy of the `lhs` window;
/// - only `false` present: the complement of the `lhs` window;
/// - `rhs` empty: every bit is `false`.
#[inline]
pub fn in_mask(lhs: BitmaskVT<'_>, rhs: BitmaskVT<'_>) -> Bitmask {
    let (lhs_mask, lhs_off, len) = lhs;
    let (rhs_mask, rhs_off, rhs_len) = rhs;
    let mut has_true = false;
    let mut has_false = false;
    // Scan rhs with *its own* length: it describes the candidate set and need
    // not match lhs. Using lhs's length would read past a shorter rhs window.
    for i in 0..rhs_len {
        // SAFETY: the caller guarantees the rhs window covers `rhs_len` bits.
        if unsafe { rhs_mask.get_unchecked(rhs_off + i) } {
            has_true = true;
        } else {
            has_false = true;
        }
        if has_true && has_false {
            break;
        }
    }
    match (has_true, has_false) {
        (true, true) => Bitmask::new_set_all(len, true),
        (true, false) => lhs_mask.slice_clone(lhs_off, len),
        (false, true) => not_mask((lhs_mask, lhs_off, len)),
        (false, false) => Bitmask::new_set_all(len, false),
    }
}
/// Complement of [`in_mask`]: bit `i` is set when `lhs[i]` does NOT occur in `rhs`.
#[inline]
pub fn not_in_mask(lhs: BitmaskVT<'_>, rhs: BitmaskVT<'_>) -> Bitmask {
    let membership = in_mask(lhs, rhs);
    let n = membership.len;
    not_mask((&membership, 0, n))
}
/// Element-wise equality of two bitmask windows: bit `i` of the result is set
/// when `a[i] == b[i]`.
///
/// Both windows must have the same length (checked only in debug builds). When
/// both offsets are multiples of 64 the comparison runs one word at a time as
/// `!(wa ^ wb)`. For any other offsets it compares bit by bit instead of
/// rejecting the input.
#[inline]
pub fn eq_mask(a: BitmaskVT<'_>, b: BitmaskVT<'_>) -> Bitmask {
    let (am, ao, len) = a;
    let (bm, bo, blen) = b;
    debug_assert_eq!(len, blen, "BitWindow length mismatch in eq_mask");
    let mut out = Bitmask::new_set_all(len, false);
    if ao % 64 == 0 && bo % 64 == 0 {
        let a_words = ao / 64;
        let b_words = bo / 64;
        let n_words = (len + 63) / 64;
        for k in 0..n_words {
            // SAFETY: offsets are word-aligned and the caller guarantees each
            // window covers `len` bits. `out` has `n_words` words.
            let wa = unsafe { am.word_unchecked(a_words + k) };
            let wb = unsafe { bm.word_unchecked(b_words + k) };
            unsafe {
                out.set_word_unchecked(k, !(wa ^ wb));
            }
        }
    } else {
        for i in 0..len {
            // SAFETY: the caller guarantees both windows cover `len` bits, and
            // `out` holds `len` bits.
            unsafe {
                let equal = am.get_unchecked(ao + i) == bm.get_unchecked(bo + i);
                out.set_unchecked(i, equal);
            }
        }
    }
    // XNOR sets every slack bit past `len` in the final word; clear them.
    out.mask_trailing_bits();
    out
}
/// Element-wise inequality: the complement of [`eq_mask`], so bit `i` is set
/// when `a[i] != b[i]`. Has the same preconditions as [`eq_mask`].
#[inline]
pub fn ne_mask(a: BitmaskVT<'_>, b: BitmaskVT<'_>) -> Bitmask {
    let equal = eq_mask(a, b);
    let n = equal.len;
    not_mask((&equal, 0, n))
}
/// Returns `true` when the two windows hold identical bits over `a`'s length.
///
/// Short windows (fewer than 64 bits) and windows whose offsets are not
/// multiples of 64 are compared bit by bit. Otherwise whole words are compared,
/// and only the valid low bits of the final word are checked.
#[inline]
pub fn all_eq_mask(a: BitmaskVT<'_>, b: BitmaskVT<'_>) -> bool {
    let (am, ao, len) = a;
    let (bm, bo, blen) = b;
    // The word path reads whole `u64`s starting at word `offset >> 6`. That is
    // only correct when both windows start exactly on a word boundary.
    if len < 64 || ao & 63 != 0 || bo & 63 != 0 {
        // SAFETY: the caller guarantees both windows cover `len` bits.
        return (0..len)
            .all(|i| unsafe { am.get_unchecked(ao + i) == bm.get_unchecked(bo + i) });
    }
    debug_assert_eq!(len, blen);
    let aw = ao >> 6;
    let bw = bo >> 6;
    let n_words = (len + 63) >> 6;
    let trailing = len & 63;
    for k in 0..n_words {
        // SAFETY: offsets are word-aligned and both windows cover `len` bits.
        let wa = unsafe { am.word_unchecked(aw + k) };
        let wb = unsafe { bm.word_unchecked(bw + k) };
        let mut diff = wa ^ wb;
        if k == n_words - 1 && trailing != 0 {
            // Ignore slack bits beyond `len` in the final word.
            diff &= (1u64 << trailing) - 1;
        }
        if diff != 0 {
            return false;
        }
    }
    true
}
/// Returns `true` when the two windows differ in at least one bit position.
///
/// This is the logical negation of [`all_eq_mask`] ("not all equal"). It does
/// NOT check that *every* position differs.
/// NOTE(review): the name could be read as "every element differs"; confirm
/// callers rely on the "not all equal" meaning.
#[inline]
pub fn all_ne_mask(a: BitmaskVT<'_>, b: BitmaskVT<'_>) -> bool {
    let identical = all_eq_mask(a, b);
    !identical
}
/// Counts the set bits in a bitmask window `(mask, bit_offset, len)`.
///
/// Any bit offset is supported. The words spanning `[offset, offset + len)` are
/// visited once. In the first word the bits before `offset` are masked off, and
/// in the last word the bits at or past `offset + len` are masked off.
/// An empty window counts as 0.
#[inline]
pub fn popcount_mask(m: BitmaskVT<'_>) -> usize {
    let (mask, offset, len) = m;
    if len == 0 {
        return 0;
    }
    let end = offset + len; // exclusive end bit
    let first_word = offset / 64;
    let last_word = (end - 1) / 64;
    // Keep only bits at or after `offset` in the first word.
    let head_mask = !0u64 << (offset % 64);
    // Keep only bits before `end` in the last word (all bits if `end` is word-aligned).
    let tail_bits = end % 64;
    let tail_mask = if tail_bits == 0 {
        !0u64
    } else {
        (1u64 << tail_bits) - 1
    };
    let mut acc = 0usize;
    for w in first_word..=last_word {
        // SAFETY: the caller guarantees bits `[offset, end)` exist, so every word
        // index up to `(end - 1) / 64` is valid.
        let mut word = unsafe { mask.word_unchecked(w) };
        if w == first_word {
            word &= head_mask;
        }
        if w == last_word {
            word &= tail_mask;
        }
        acc += word.count_ones() as usize;
    }
    acc
}
/// Returns `true` when every one of the `mask.len` bits is set. An empty mask
/// yields `true`.
///
/// Masks shorter than 64 bits are checked bit by bit. Longer masks are checked
/// one word at a time through `word_unchecked`, the same accessor the other
/// kernels in this file use, instead of reinterpreting the byte buffer as a
/// `&[u64]` (which would assume 8-byte alignment and a buffer of at least
/// `n_words * 8` bytes). Only the valid low bits of the final word are checked.
#[inline]
pub fn all_true_mask(mask: &Bitmask) -> bool {
    let n_bits = mask.len;
    if n_bits == 0 {
        return true;
    }
    if n_bits < 64 {
        // SAFETY: every index is below `mask.len`.
        return (0..n_bits).all(|i| unsafe { mask.get_unchecked(i) });
    }
    let n_words = (n_bits + 63) >> 6;
    let trailing = n_bits & 63;
    (0..n_words).all(|k| {
        // SAFETY: `k < n_words`, the number of words needed to hold `mask.len` bits.
        let w = unsafe { mask.word_unchecked(k) };
        if k == n_words - 1 && trailing != 0 {
            let m = (1u64 << trailing) - 1;
            (w & m) == m
        } else {
            w == !0u64
        }
    })
}
/// Returns `true` when none of the `mask.len` bits is set. An empty mask
/// yields `true`.
///
/// Masks shorter than 64 bits, including the empty mask, are checked bit by bit.
/// Longer masks are checked one word at a time through `word_unchecked`, the
/// same accessor the other kernels in this file use, instead of reinterpreting
/// the byte buffer as a `&[u64]`. Only the valid low bits of the final word are
/// checked.
#[inline]
pub fn all_false_mask(mask: &Bitmask) -> bool {
    let n_bits = mask.len;
    if n_bits < 64 {
        // SAFETY: every index is below `mask.len`.
        return !(0..n_bits).any(|i| unsafe { mask.get_unchecked(i) });
    }
    let n_words = (n_bits + 63) >> 6;
    let trailing = n_bits & 63;
    (0..n_words).all(|k| {
        // SAFETY: `k < n_words`, the number of words needed to hold `mask.len` bits.
        let w = unsafe { mask.word_unchecked(k) };
        let live = if k == n_words - 1 && trailing != 0 {
            // Ignore slack bits beyond `len` in the final word.
            w & ((1u64 << trailing) - 1)
        } else {
            w
        };
        live == 0
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Bitmask;

    /// Builds a `Bitmask` whose bit `i` equals `bits[i]`.
    fn bm(bits: &[bool]) -> Bitmask {
        let mut mask = Bitmask::new_set_all(bits.len(), false);
        for (idx, bit) in bits.iter().copied().enumerate() {
            // SAFETY: `idx < bits.len()`, the length the mask was created with.
            unsafe { mask.set_unchecked(idx, bit) };
        }
        mask
    }

    #[test]
    fn test_and_masks() {
        let lhs = bm(&[true, false, true, true, false, false, true, true]);
        let rhs = bm(&[false, false, true, false, true, false, true, false]);
        let want = bm(&[false, false, true, false, false, false, true, false]);
        let got = and_masks((&lhs, 0, lhs.len), (&rhs, 0, rhs.len()));
        for bit in 0..8 {
            assert_eq!(got.get(bit), want.get(bit), "Mismatch at bit {}", bit);
        }
    }

    #[test]
    fn test_or_masks() {
        let lhs = bm(&[true, false, true, true]);
        let rhs = bm(&[false, false, true, false]);
        let want = bm(&[true, false, true, true]);
        let got = or_masks((&lhs, 0, lhs.len), (&rhs, 0, rhs.len()));
        for bit in 0..4 {
            assert_eq!(got.get(bit), want.get(bit));
        }
    }

    #[test]
    fn test_xor_masks() {
        let lhs = bm(&[true, false, true, false]);
        let rhs = bm(&[false, true, true, false]);
        let want = bm(&[true, true, false, false]);
        let got = xor_masks((&lhs, 0, lhs.len), (&rhs, 0, rhs.len()));
        for bit in 0..4 {
            assert_eq!(got.get(bit), want.get(bit));
        }
    }

    #[test]
    fn test_not_mask() {
        let src = bm(&[true, false, true, false]);
        let want = bm(&[false, true, false, true]);
        let got = not_mask((&src, 0, src.len));
        for bit in 0..4 {
            assert_eq!(got.get(bit), want.get(bit));
        }
    }

    #[test]
    fn test_in_mask_all() {
        let lhs = bm(&[true, false, true]);
        let rhs = bm(&[true, false, true]);
        let got = in_mask((&lhs, 0, lhs.len), (&rhs, 0, rhs.len()));
        for bit in 0..lhs.len {
            assert!(got.get(bit), "in_mask (all true/false in rhs) bit {}", bit);
        }
    }

    #[test]
    fn test_in_mask_true_only() {
        let lhs = bm(&[true, false, true]);
        let rhs = bm(&[true, true, true]);
        let got = in_mask((&lhs, 0, lhs.len), (&rhs, 0, rhs.len()));
        assert!(got.get(0));
        assert!(!got.get(1));
        assert!(got.get(2));
    }

    #[test]
    fn test_in_mask_false_only() {
        let lhs = bm(&[true, false, true]);
        let rhs = bm(&[false, false, false]);
        let got = in_mask((&lhs, 0, lhs.len), (&rhs, 0, rhs.len()));
        assert!(!got.get(0));
        assert!(got.get(1));
        assert!(!got.get(2));
    }

    #[test]
    fn test_not_in_mask() {
        let lhs = bm(&[true, false]);
        let rhs = bm(&[true, false]);
        let got = not_in_mask((&lhs, 0, lhs.len), (&rhs, 0, rhs.len()));
        for bit in 0..lhs.len {
            assert!(!got.get(bit));
        }
    }

    #[test]
    fn test_eq_mask() {
        let lhs = bm(&[true, false, true]);
        let rhs = bm(&[true, false, false]);
        let want = bm(&[true, true, false]);
        let got = eq_mask((&lhs, 0, lhs.len), (&rhs, 0, rhs.len()));
        for bit in 0..lhs.len {
            assert_eq!(got.get(bit), want.get(bit));
        }
    }

    #[test]
    fn test_ne_mask() {
        let lhs = bm(&[true, false, true]);
        let rhs = bm(&[true, true, false]);
        let want = bm(&[false, true, true]);
        let got = ne_mask((&lhs, 0, lhs.len), (&rhs, 0, rhs.len()));
        for bit in 0..lhs.len {
            assert_eq!(got.get(bit), want.get(bit));
        }
    }

    #[test]
    fn test_all_eq_mask_true() {
        let lhs = bm(&[true, false, true, false]);
        let rhs = bm(&[true, false, true, false]);
        let verdict = all_eq_mask((&lhs, 0, lhs.len), (&rhs, 0, rhs.len()));
        assert!(verdict);
    }

    #[test]
    fn test_all_eq_mask_false() {
        let lhs = bm(&[true, false, true, false]);
        let rhs = bm(&[false, true, false, true]);
        let verdict = all_eq_mask((&lhs, 0, lhs.len), (&rhs, 0, rhs.len()));
        assert!(!verdict);
    }

    #[test]
    fn test_all_ne_mask_true() {
        let lhs = bm(&[true, false]);
        let rhs = bm(&[false, true]);
        let verdict = all_ne_mask((&lhs, 0, lhs.len), (&rhs, 0, rhs.len()));
        assert!(verdict);
    }

    #[test]
    fn test_all_ne_mask_false() {
        let lhs = bm(&[true, false]);
        let rhs = bm(&[true, false]);
        let verdict = all_ne_mask((&lhs, 0, lhs.len), (&rhs, 0, rhs.len()));
        assert!(!verdict);
    }

    #[test]
    fn test_popcount_mask() {
        let src = bm(&[true, false, true, false, true, true]);
        let count = popcount_mask((&src, 0, src.len));
        assert_eq!(count, 4);
    }

    #[test]
    fn test_all_true_mask() {
        let full = bm(&[true, true, true, true]);
        let holed = bm(&[true, true, false, true]);
        assert!(all_true_mask(&full));
        assert!(!all_true_mask(&holed));
    }

    #[test]
    fn test_all_false_mask() {
        let empty = bm(&[false, false, false, false]);
        let spotted = bm(&[false, true, false, false]);
        assert!(all_false_mask(&empty));
        assert!(!all_false_mask(&spotted));
    }

    #[test]
    fn test_clear_trailing_bits_and_window() {
        // Nine valid bits: byte 1 holds bit 8 plus seven slack bits that must be cleared.
        let mut nine = Bitmask::new_set_all(9, true);
        nine.bits[1] = 0xFF;
        clear_trailing_bits(&mut nine);
        assert!(nine.get(8));
        if nine.bits[1] >> 1 != 0 {
            panic!("Trailing slack bits not cleared");
        }

        // A 4-bit window starting at bit 2 fits inside a single byte.
        let src = bm(&[true, false, true, true, false, false, true, false]);
        let window = bitmask_window_bytes(&src, 2, 4);
        assert_eq!(window.len(), 1);
        let mut copy = src.clone();
        let window_mut = bitmask_window_bytes_mut(&mut copy, 2, 4);
        assert_eq!(window_mut.len(), 1);
    }
}