use crate::{Bitmask, BitmaskVT};
use crate::{
enums::operators::{LogicalOperator, UnaryOperator},
kernels::bitmask::{
bitmask_window_bytes, bitmask_window_bytes_mut, clear_trailing_bits,
},
};
/// Applies a binary logical operator (`And`, `Or`, `Xor`) bit-wise across two
/// bitmask windows and returns a newly allocated mask of `lhs.2` bits.
///
/// Both windows are read over `lhs.2` bits; the length carried in `rhs` is not
/// consulted, so callers must make sure `rhs` covers at least that many bits.
///
/// When both offsets are byte-aligned, the work is done byte-wise over the
/// backing storage with a safe, bounds-checked-once zip that the compiler can
/// vectorise. Otherwise a bit-by-bit path is used: the byte window returned by
/// [`bitmask_window_bytes`] starts at a byte boundary and (as far as this file
/// shows) is not realigned to a sub-byte offset, so combining raw bytes would
/// pair the wrong bits.
#[inline(always)]
pub fn bitmask_binop_std(lhs: BitmaskVT<'_>, rhs: BitmaskVT<'_>, op: LogicalOperator) -> Bitmask {
    let (lhs_mask, lhs_off, len) = lhs;
    let (rhs_mask, rhs_off, _) = rhs;
    if len == 0 {
        return Bitmask::new_set_all(0, false);
    }
    // `new_set_all` already sets `out.len == len`.
    let mut out = Bitmask::new_set_all(len, false);
    if lhs_off % 8 == 0 && rhs_off % 8 == 0 {
        let lhs_bytes = bitmask_window_bytes(lhs_mask, lhs_off, len);
        let rhs_bytes = bitmask_window_bytes(rhs_mask, rhs_off, len);
        let dst = bitmask_window_bytes_mut(&mut out, 0, len);
        // Dispatch on the operator once, outside the loop. The zip stops at the
        // shortest slice, so no write can go past `dst`, and no alignment
        // assumption is made about the backing buffers.
        let pairs = dst.iter_mut().zip(lhs_bytes.iter().zip(rhs_bytes.iter()));
        match op {
            LogicalOperator::And => pairs.for_each(|(d, (&a, &b))| *d = a & b),
            LogicalOperator::Or => pairs.for_each(|(d, (&a, &b))| *d = a | b),
            LogicalOperator::Xor => pairs.for_each(|(d, (&a, &b))| *d = a ^ b),
        }
    } else {
        // Sub-byte offsets: evaluate bit by bit so each output bit pairs
        // lhs[lhs_off + i] with rhs[rhs_off + i].
        for i in 0..len {
            let a = lhs_mask.get(lhs_off + i);
            let b = rhs_mask.get(rhs_off + i);
            let bit = match op {
                LogicalOperator::And => a & b,
                LogicalOperator::Or => a | b,
                LogicalOperator::Xor => a ^ b,
            };
            if bit {
                // SAFETY: i < len == out.len, so the index is inside `out`.
                unsafe { out.set_unchecked(i, true) };
            }
        }
    }
    // Bits past `len` in the last byte may have been copied from the sources.
    clear_trailing_bits(&mut out);
    out
}
/// Applies a unary logical operator to a bitmask window and returns a newly
/// allocated mask of `src.2` bits.
///
/// Only [`UnaryOperator::Not`] is meaningful for bitmasks.
///
/// # Panics
/// If `op` is any operator other than `UnaryOperator::Not` and the window is
/// non-empty.
#[inline(always)]
pub fn bitmask_unop_std(src: BitmaskVT<'_>, op: UnaryOperator) -> Bitmask {
    let (mask, offset, len) = src;
    if len == 0 {
        return Bitmask::new_set_all(0, false);
    }
    // Reject unsupported operators up front with a descriptive message
    // instead of hitting `unreachable!()` inside the loop.
    if !matches!(op, UnaryOperator::Not) {
        panic!("bitmask_unop_std: only UnaryOperator::Not is supported on bitmasks");
    }
    let mut out = Bitmask::new_set_all(len, false);
    if offset % 8 == 0 {
        let src_bytes = bitmask_window_bytes(mask, offset, len);
        let dst = bitmask_window_bytes_mut(&mut out, 0, len);
        // Safe, bounded byte loop; no alignment assumption on either buffer.
        for (d, &s) in dst.iter_mut().zip(src_bytes.iter()) {
            *d = !s;
        }
    } else {
        // Sub-byte offset: the byte window is not realigned, so negate bit by bit.
        for i in 0..len {
            if !mask.get(offset + i) {
                // SAFETY: i < len == out.len, so the index is inside `out`.
                unsafe { out.set_unchecked(i, true) };
            }
        }
    }
    // Negation turns the slack bits of the last byte on; clear them.
    clear_trailing_bits(&mut out);
    out
}
/// Element-wise logical AND of two bitmask windows.
///
/// Thin wrapper over [`bitmask_binop_std`] with `LogicalOperator::And`.
#[inline]
pub fn and_masks(lhs: BitmaskVT<'_>, rhs: BitmaskVT<'_>) -> Bitmask {
    let op = LogicalOperator::And;
    bitmask_binop_std(lhs, rhs, op)
}
/// Element-wise logical OR of two bitmask windows.
///
/// Thin wrapper over [`bitmask_binop_std`] with `LogicalOperator::Or`.
#[inline]
pub fn or_masks(lhs: BitmaskVT<'_>, rhs: BitmaskVT<'_>) -> Bitmask {
    let op = LogicalOperator::Or;
    bitmask_binop_std(lhs, rhs, op)
}
/// Element-wise logical XOR of two bitmask windows.
///
/// Thin wrapper over [`bitmask_binop_std`] with `LogicalOperator::Xor`.
#[inline]
pub fn xor_masks(lhs: BitmaskVT<'_>, rhs: BitmaskVT<'_>) -> Bitmask {
    let op = LogicalOperator::Xor;
    bitmask_binop_std(lhs, rhs, op)
}
/// Logical negation of a bitmask window.
///
/// Thin wrapper over [`bitmask_unop_std`] with `UnaryOperator::Not`.
#[inline]
pub fn not_mask(src: BitmaskVT<'_>) -> Bitmask {
    let op = UnaryOperator::Not;
    bitmask_unop_std(src, op)
}
/// Boolean set-membership: bit `i` of the result is set when the value
/// `lhs[i]` occurs anywhere in the `rhs` window.
///
/// Because a bitmask only holds `true`/`false`, the answer depends only on
/// which values `rhs` contains:
/// - both values → every bit set;
/// - only `true` → a copy of the `lhs` window;
/// - only `false` → the negation of the `lhs` window;
/// - neither (empty `rhs`) → every bit clear.
///
/// The result has `lhs.2` bits.
#[inline]
pub fn in_mask(lhs: BitmaskVT<'_>, rhs: BitmaskVT<'_>) -> Bitmask {
    let (lhs_mask, lhs_off, len) = lhs;
    let (rhs_mask, rhs_off, rhs_len) = rhs;
    let mut has_true = false;
    let mut has_false = false;
    // Scan rhs over its own length (`rhs_len`), not the lhs length: the two
    // windows may differ in size.
    for i in 0..rhs_len {
        if rhs_mask.get(rhs_off + i) {
            has_true = true;
        } else {
            has_false = true;
        }
        if has_true && has_false {
            // Both values present; further scanning cannot change the answer.
            break;
        }
    }
    match (has_true, has_false) {
        (true, true) => Bitmask::new_set_all(len, true),
        (true, false) => lhs_mask.slice_clone(lhs_off, len),
        (false, true) => not_mask((lhs_mask, lhs_off, len)),
        (false, false) => Bitmask::new_set_all(len, false),
    }
}
/// Complement of [`in_mask`]: bit `i` is set when `lhs[i]` does not occur
/// in the `rhs` window.
#[inline]
pub fn not_in_mask(lhs: BitmaskVT<'_>, rhs: BitmaskVT<'_>) -> Bitmask {
    let membership = in_mask(lhs, rhs);
    let bit_count = membership.len;
    not_mask((&membership, 0, bit_count))
}
/// Bit-wise equality (XNOR) of two bitmask windows: bit `i` of the result is
/// set when `a[i] == b[i]`.
///
/// The windows are expected to have the same length (checked only in debug
/// builds); `a.2` bits are produced.
///
/// # Panics
/// If either offset is not a multiple of 64.
#[inline]
pub fn eq_mask(a: BitmaskVT<'_>, b: BitmaskVT<'_>) -> Bitmask {
    let (am, ao, len) = a;
    let (bm, bo, blen) = b;
    debug_assert_eq!(len, blen, "BitWindow length mismatch in eq_mask");
    if len == 0 {
        return Bitmask::new_set_all(0, true);
    }
    if ao % 64 != 0 || bo % 64 != 0 {
        panic!(
            "eq_mask: offsets must be 64-bit aligned (got a: {}, b: {})",
            ao, bo
        );
    }
    let mut out = Bitmask::new_set_all(len, false);
    let a_bytes = bitmask_window_bytes(am, ao, len);
    let b_bytes = bitmask_window_bytes(bm, bo, len);
    let dst = bitmask_window_bytes_mut(&mut out, 0, len);
    // Safe byte-wise zip: makes no assumption about the alignment of the
    // backing buffers, cannot write past `dst`, and vectorises well.
    for (d, (&x, &y)) in dst.iter_mut().zip(a_bytes.iter().zip(b_bytes.iter())) {
        *d = !(x ^ y);
    }
    // XNOR sets the slack bits of the last byte; clear them.
    out.mask_trailing_bits();
    out
}
/// Bit-wise inequality of two bitmask windows: bit `i` is set when
/// `a[i] != b[i]`. Computed as the negation of [`eq_mask`], so it inherits
/// that function's panics.
#[inline]
pub fn ne_mask(a: BitmaskVT<'_>, b: BitmaskVT<'_>) -> Bitmask {
    let equal_bits = eq_mask(a, b);
    let bit_count = equal_bits.len;
    not_mask((&equal_bits, 0, bit_count))
}
/// Returns `true` when every bit of window `a` equals the corresponding bit of
/// window `b`. Empty windows compare equal.
///
/// Only the `len` logical bits are compared. Bits past `len` (slack bits, or
/// real data when the window is a view into a longer mask) are ignored.
///
/// # Panics
/// If either offset is not a multiple of 64.
#[inline]
pub fn all_eq_mask(a: BitmaskVT<'_>, b: BitmaskVT<'_>) -> bool {
    let (am, ao, len) = a;
    let (bm, bo, blen) = b;
    debug_assert_eq!(len, blen);
    if len == 0 {
        return true;
    }
    if ao % 64 != 0 || bo % 64 != 0 {
        panic!(
            "all_eq_mask: offsets must be 64-bit aligned (got a: {}, b: {})",
            ao, bo
        );
    }
    let a_bytes = bitmask_window_bytes(am, ao, len);
    let b_bytes = bitmask_window_bytes(bm, bo, len);
    let full_bytes = len / 8;
    let rem_bits = len % 8;
    // Whole logical bytes: slice equality is memcmp-backed. Restricting to
    // `..full_bytes` keeps bits past `len` out of the comparison even when the
    // last partial byte falls inside a 64-bit word.
    if a_bytes[..full_bytes] != b_bytes[..full_bytes] {
        return false;
    }
    if rem_bits == 0 {
        return true;
    }
    let tail_mask = (1u8 << rem_bits) - 1;
    (a_bytes[full_bytes] ^ b_bytes[full_bytes]) & tail_mask == 0
}
/// Returns `true` when the two windows are *not* entirely equal, i.e. at least
/// one bit position differs. Despite the name, it does not require every bit
/// to differ. It is the negation of [`all_eq_mask`] and inherits its panics.
#[inline]
pub fn all_ne_mask(a: BitmaskVT<'_>, b: BitmaskVT<'_>) -> bool {
    let identical = all_eq_mask(a, b);
    !identical
}
/// Counts the set bits in the window `[offset, offset + len)` of the mask.
///
/// Only bits inside the window are counted. Bits past `len` in the last word
/// or byte, and bits before a sub-byte `offset`, are excluded.
#[inline]
pub fn popcount_mask(m: BitmaskVT<'_>) -> usize {
    let (mask, offset, len) = m;
    if len == 0 {
        return 0;
    }
    if offset % 8 != 0 {
        // The byte window starts at a byte boundary and is not realigned to a
        // sub-byte offset, so count bit by bit.
        return (0..len).filter(|&i| mask.get(offset + i)).count();
    }
    let bytes = bitmask_window_bytes(mask, offset, len);
    let full_bytes = len / 8;
    let rem_bits = len % 8;
    // Count whole logical bytes a u64 word at a time. The unaligned-safe load
    // goes through a stack buffer, and the final partial byte is excluded so
    // bits past `len` are never counted.
    let mut words = bytes[..full_bytes].chunks_exact(8);
    let mut acc: usize = words
        .by_ref()
        .map(|w| {
            let mut buf = [0u8; 8];
            buf.copy_from_slice(w);
            u64::from_ne_bytes(buf).count_ones() as usize
        })
        .sum();
    acc += words
        .remainder()
        .iter()
        .map(|b| b.count_ones() as usize)
        .sum::<usize>();
    if rem_bits != 0 {
        let tail_mask = (1u8 << rem_bits) - 1;
        acc += (bytes[full_bytes] & tail_mask).count_ones() as usize;
    }
    acc
}
/// Returns `true` when all `mask.len` bits of the mask are set. An empty mask
/// returns `true`.
///
/// Slack bits past `mask.len` in the last byte are ignored.
#[inline]
pub fn all_true_mask(mask: &Bitmask) -> bool {
    let n_bits = mask.len;
    if n_bits == 0 {
        return true;
    }
    let bytes = mask.bits.as_ref();
    let full_bytes = n_bits / 8;
    let rem_bits = n_bits % 8;
    // Check only whole logical bytes word-wise. Including the partial last
    // byte in a u64 compare would demand that (normally cleared) slack bits be
    // set, e.g. for lengths 57..=63.
    let mut words = bytes[..full_bytes].chunks_exact(8);
    let words_all_set = words.by_ref().all(|w| {
        let mut buf = [0u8; 8];
        buf.copy_from_slice(w);
        u64::from_ne_bytes(buf) == u64::MAX
    });
    if !words_all_set || words.remainder().iter().any(|&b| b != 0xFF) {
        return false;
    }
    if rem_bits == 0 {
        return true;
    }
    let tail_mask = (1u8 << rem_bits) - 1;
    bytes[full_bytes] & tail_mask == tail_mask
}
/// Returns `true` when none of the `mask.len` bits of the mask are set. An
/// empty mask returns `true`.
///
/// Slack bits past `mask.len` in the last byte are ignored, consistently for
/// every length.
#[inline]
pub fn all_false_mask(mask: &Bitmask) -> bool {
    let n_bits = mask.len;
    if n_bits == 0 {
        return true;
    }
    let bytes = mask.bits.as_ref();
    let full_bytes = n_bits / 8;
    let rem_bits = n_bits % 8;
    // Check whole logical bytes word-wise. The partial last byte is kept out of
    // the word compare so stray slack bits cannot produce a false negative.
    let mut words = bytes[..full_bytes].chunks_exact(8);
    let words_all_clear = words.by_ref().all(|w| {
        let mut buf = [0u8; 8];
        buf.copy_from_slice(w);
        u64::from_ne_bytes(buf) == 0
    });
    if !words_all_clear || words.remainder().iter().any(|&b| b != 0) {
        return false;
    }
    if rem_bits == 0 {
        return true;
    }
    let tail_mask = (1u8 << rem_bits) - 1;
    bytes[full_bytes] & tail_mask == 0
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Bitmask;

    /// Builds a `Bitmask` whose bit `i` equals `bits[i]`.
    fn bm(bits: &[bool]) -> Bitmask {
        let mut built = Bitmask::new_set_all(bits.len(), false);
        for (idx, bit) in bits.iter().copied().enumerate() {
            unsafe { built.set_unchecked(idx, bit) };
        }
        built
    }

    #[test]
    fn test_and_masks() {
        let left = bm(&[true, false, true, true, false, false, true, true]);
        let right = bm(&[false, false, true, false, true, false, true, false]);
        let want = bm(&[false, false, true, false, false, false, true, false]);
        let got = and_masks((&left, 0, left.len), (&right, 0, right.len()));
        (0..8).for_each(|i| assert_eq!(got.get(i), want.get(i), "Mismatch at bit {}", i));
    }

    #[test]
    fn test_or_masks() {
        let left = bm(&[true, false, true, true]);
        let right = bm(&[false, false, true, false]);
        let want = bm(&[true, false, true, true]);
        let got = or_masks((&left, 0, left.len), (&right, 0, right.len()));
        (0..4).for_each(|i| assert_eq!(got.get(i), want.get(i)));
    }

    #[test]
    fn test_xor_masks() {
        let left = bm(&[true, false, true, false]);
        let right = bm(&[false, true, true, false]);
        let want = bm(&[true, true, false, false]);
        let got = xor_masks((&left, 0, left.len), (&right, 0, right.len()));
        (0..4).for_each(|i| assert_eq!(got.get(i), want.get(i)));
    }

    #[test]
    fn test_not_mask() {
        let src = bm(&[true, false, true, false]);
        let want = bm(&[false, true, false, true]);
        let got = not_mask((&src, 0, src.len));
        (0..4).for_each(|i| assert_eq!(got.get(i), want.get(i)));
    }

    #[test]
    fn test_in_mask_all() {
        let left = bm(&[true, false, true]);
        let right = bm(&[true, false, true]);
        let got = in_mask((&left, 0, left.len), (&right, 0, right.len()));
        (0..left.len).for_each(|i| {
            assert!(got.get(i), "in_mask (all true/false in rhs) bit {}", i)
        });
    }

    #[test]
    fn test_in_mask_true_only() {
        let left = bm(&[true, false, true]);
        let right = bm(&[true, true, true]);
        let got = in_mask((&left, 0, left.len), (&right, 0, right.len()));
        assert!(got.get(0));
        assert!(!got.get(1));
        assert!(got.get(2));
    }

    #[test]
    fn test_in_mask_false_only() {
        let left = bm(&[true, false, true]);
        let right = bm(&[false, false, false]);
        let got = in_mask((&left, 0, left.len), (&right, 0, right.len()));
        assert!(!got.get(0));
        assert!(got.get(1));
        assert!(!got.get(2));
    }

    #[test]
    fn test_not_in_mask() {
        let left = bm(&[true, false]);
        let right = bm(&[true, false]);
        let got = not_in_mask((&left, 0, left.len), (&right, 0, right.len()));
        (0..left.len).for_each(|i| assert!(!got.get(i)));
    }

    #[test]
    fn test_eq_mask() {
        let left = bm(&[true, false, true]);
        let right = bm(&[true, false, false]);
        let want = bm(&[true, true, false]);
        let got = eq_mask((&left, 0, left.len), (&right, 0, right.len()));
        (0..left.len).for_each(|i| assert_eq!(got.get(i), want.get(i)));
    }

    #[test]
    fn test_ne_mask() {
        let left = bm(&[true, false, true]);
        let right = bm(&[true, true, false]);
        let want = bm(&[false, true, true]);
        let got = ne_mask((&left, 0, left.len), (&right, 0, right.len()));
        (0..left.len).for_each(|i| assert_eq!(got.get(i), want.get(i)));
    }

    #[test]
    fn test_all_eq_mask_true() {
        let left = bm(&[true, false, true, false]);
        let right = bm(&[true, false, true, false]);
        assert!(all_eq_mask((&left, 0, left.len), (&right, 0, right.len())));
    }

    #[test]
    fn test_all_eq_mask_false() {
        let left = bm(&[true, false, true, false]);
        let right = bm(&[false, true, false, true]);
        assert!(!all_eq_mask((&left, 0, left.len), (&right, 0, right.len())));
    }

    #[test]
    fn test_all_ne_mask_true() {
        let left = bm(&[true, false]);
        let right = bm(&[false, true]);
        assert!(all_ne_mask((&left, 0, left.len), (&right, 0, right.len())));
    }

    #[test]
    fn test_all_ne_mask_false() {
        let left = bm(&[true, false]);
        let right = bm(&[true, false]);
        assert!(!all_ne_mask((&left, 0, left.len), (&right, 0, right.len())));
    }

    #[test]
    fn test_popcount_mask() {
        let src = bm(&[true, false, true, false, true, true]);
        assert_eq!(popcount_mask((&src, 0, src.len)), 4);
    }

    #[test]
    fn test_all_true_mask() {
        assert!(all_true_mask(&bm(&[true, true, true, true])));
        assert!(!all_true_mask(&bm(&[true, true, false, true])));
    }

    #[test]
    fn test_all_false_mask() {
        assert!(all_false_mask(&bm(&[false, false, false, false])));
        assert!(!all_false_mask(&bm(&[false, true, false, false])));
    }

    #[test]
    fn test_clear_trailing_bits_and_window() {
        // Dirty the slack bits of a 9-bit mask, then make sure only bit 8 survives.
        let mut nine = Bitmask::new_set_all(9, true);
        nine.bits[1] = 0xFF;
        clear_trailing_bits(&mut nine);
        assert!(nine.get(8));
        if nine.bits[1] >> 1 != 0 {
            panic!("Trailing slack bits not cleared");
        }

        // A 4-bit window starting at bit 2 fits inside a single byte.
        let source = bm(&[true, false, true, true, false, false, true, false]);
        assert_eq!(bitmask_window_bytes(&source, 2, 4).len(), 1);
        let mut scratch = source.clone();
        assert_eq!(bitmask_window_bytes_mut(&mut scratch, 2, 4).len(), 1);
    }
}