use std::fmt::{Debug, Display, Formatter, Result as FmtResult};
use std::ops::{BitAnd, BitOr, Deref, DerefMut, Index, Not};
use crate::enums::shape_dim::ShapeDim;
use crate::traits::concatenate::Concatenate;
use crate::traits::shape::Shape;
use crate::{BitmaskV, Buffer, Length, Offset};
use vec64::Vec64;
/// Packed bitmask with LSB-first bit order: bit `i` lives in byte `i / 8`
/// at bit position `i % 8`.
///
/// `len` is the logical number of bits. The physical buffer `bits` is normally
/// `(len + 7) / 8` bytes long, but `Bitmask::new` accepts any buffer without
/// checking its size, so it may be larger (or smaller) than that.
///
/// NOTE: the derived `PartialEq` compares the raw `bits` buffer as well as
/// `len`, so two masks with identical logical bits but different physical
/// byte counts or padding bits compare unequal.
#[repr(C, align(64))]
#[derive(Clone, PartialEq, Default)]
pub struct Bitmask {
/// Packed bit storage (LSB-first within each byte).
pub bits: Buffer<u8>,
/// Logical number of bits; bits at or past this index are padding.
pub len: usize,
}
impl Bitmask {
    /// Wraps an existing packed byte buffer as a bitmask of `len` logical bits.
    ///
    /// The buffer is used as-is: its byte count is not checked against `len`
    /// and padding bits past `len` are not cleared.
    #[inline]
    pub fn new(data: impl Into<Buffer<u8>>, len: usize) -> Self {
        Self {
            bits: data.into(),
            len,
        }
    }

    /// Zeroes the unused high bits of the byte that holds the last logical bit.
    ///
    /// When `len` is not a multiple of 8, the bits above `len % 8` in byte
    /// `len / 8` are padding. Keeping them zero keeps byte-wise operations
    /// consistent, e.g. the derived `PartialEq` and `invert`, which would
    /// otherwise turn padding on.
    ///
    /// This is a no-op when `len` is a multiple of 8 or when the buffer does
    /// not reach byte `len / 8`.
    #[inline]
    pub fn mask_trailing_bits(&mut self) {
        let rem = self.len & 7;
        if rem == 0 {
            return;
        }
        // Locate the partial byte from the logical length, not from the
        // physical buffer length. The buffer may hold extra bytes, and
        // `bits.len() - 1` would underflow on an empty buffer.
        let last = self.len >> 3;
        if last < self.bits.len() {
            self.bits[last] &= (1u8 << rem) - 1;
        }
    }

    /// Creates a mask of `len` bits, every bit equal to `set`.
    #[inline]
    pub fn new_set_all(len: usize, set: bool) -> Self {
        let n_bytes = (len + 7) / 8;
        let mut data = Vec64::with_capacity(n_bytes);
        let fill = if set { 0xFF } else { 0 };
        data.resize(n_bytes, fill);
        let mut mask = Self {
            bits: data.into(),
            len,
        };
        mask.mask_trailing_bits();
        mask
    }

    /// Creates a mask of `bits` cleared bits.
    ///
    /// Despite the name, the returned mask has `len == bits`. It is not an
    /// empty mask with reserved space.
    #[inline]
    pub fn with_capacity(bits: usize) -> Self {
        let n_bytes = (bits + 7) / 8;
        let mut data = Vec64::with_capacity(n_bytes);
        data.resize(n_bytes, 0);
        let mut mask = Self {
            bits: data.into(),
            len: bits,
        };
        mask.mask_trailing_bits();
        mask
    }

    /// Copies `len` bits from raw memory into a new, owned mask.
    ///
    /// Returns an empty mask if `ptr` is null or `len` is 0.
    ///
    /// # Safety
    /// When `ptr` is non-null and `len > 0`, `ptr` must be valid for reads of
    /// `(len + 7) / 8` initialized bytes for the duration of the call. These
    /// are the requirements of [`std::slice::from_raw_parts`].
    pub unsafe fn from_raw_slice(ptr: *const u8, len: usize) -> Self {
        if ptr.is_null() || len == 0 {
            return Bitmask::default();
        }
        let n_bytes = (len + 7) / 8;
        // SAFETY: the caller guarantees `ptr` is readable for `n_bytes` bytes.
        let slice = unsafe { std::slice::from_raw_parts(ptr, n_bytes) };
        let mut buf = Vec64::with_capacity(n_bytes);
        buf.extend_from_slice(slice);
        let mut out = Bitmask {
            bits: buf.into(),
            len,
        };
        out.mask_trailing_bits();
        out
    }

    /// Borrows the packed bytes, including any padding past `len`.
    #[inline(always)]
    pub fn as_bytes(&self) -> &[u8] {
        self.as_ref()
    }

    /// Builds a mask from the first `len` bits of a packed LSB-first byte slice.
    ///
    /// The bytes are copied in bulk and the padding bits are cleared.
    ///
    /// # Panics
    /// Panics if `bytes` is shorter than `(len + 7) / 8`.
    pub fn from_bytes(bytes: impl AsRef<[u8]>, len: usize) -> Self {
        let n_bytes = (len + 7) / 8;
        let src = &bytes.as_ref()[..n_bytes];
        let mut buf = Vec64::with_capacity(n_bytes);
        buf.extend_from_slice(src);
        let mut mask = Bitmask {
            bits: buf.into(),
            len,
        };
        mask.mask_trailing_bits();
        mask
    }

    /// Logical number of bits.
    #[inline]
    pub fn len(&self) -> usize {
        self.len
    }

    /// Returns the logical length in bits (identical to [`Bitmask::len`]).
    ///
    /// It does not report the size of the physical buffer.
    #[inline]
    pub fn capacity(&self) -> usize {
        self.len
    }

    /// `true` when the mask holds zero logical bits.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// `true` when every logical bit is set (vacuously true when empty).
    #[inline]
    pub fn all_set(&self) -> bool {
        self.count_ones() == self.len
    }

    /// `true` when no logical bit is set.
    #[inline]
    pub fn all_unset(&self) -> bool {
        self.count_ones() == 0
    }

    /// `true` when at least one logical bit is cleared.
    #[inline]
    pub fn has_cleared(&self) -> bool {
        !self.all_set()
    }

    /// Returns a copy whose buffer comes from `Buffer::to_owned_copy`.
    ///
    /// Presumably this is a deep, non-shared copy; TODO confirm against
    /// `Buffer`.
    #[inline]
    pub fn to_owned_copy(&self) -> Self {
        Bitmask {
            bits: self.bits.to_owned_copy(),
            len: self.len,
        }
    }

    /// Reads bit `idx`.
    ///
    /// Returns `false` for indices at or past `len` that are still inside the
    /// physical buffer.
    ///
    /// # Panics
    /// Panics if `idx >= self.bits.len() * 8`.
    #[inline]
    pub fn get(&self, idx: usize) -> bool {
        let cap_bits = self.bits.len() * 8;
        assert!(
            idx < cap_bits,
            "Bitmask::get out of physical bounds (idx={idx}, cap={cap_bits})"
        );
        if idx >= self.len {
            return false;
        }
        // SAFETY: idx < cap_bits implies idx / 8 < bits.len().
        let byte = unsafe { self.bits.get_unchecked(idx >> 3) };
        (byte >> (idx & 7)) & 1 != 0
    }

    /// Sets bit `i` to `value`.
    ///
    /// If `i >= len`, the mask first grows to `i + 1` bits; the new bits are
    /// zero-filled.
    #[inline]
    pub fn set(&mut self, i: usize, value: bool) {
        self.ensure_capacity(i + 1);
        let byte = &mut self.bits[i >> 3];
        let bit = 1u8 << (i & 7);
        if value {
            *byte |= bit;
        } else {
            *byte &= !bit;
        }
        self.mask_trailing_bits();
    }

    /// Sets bit `i` without bounds checks and without updating `len`.
    ///
    /// # Safety
    /// `i / 8` must be less than `self.bits.len()`.
    #[inline(always)]
    pub unsafe fn set_unchecked(&mut self, i: usize, value: bool) {
        // SAFETY: the caller guarantees `i / 8` is in bounds.
        let byte = unsafe { self.bits.get_unchecked_mut(i >> 3) };
        let bit = 1u8 << (i & 7);
        if value {
            *byte |= bit;
        } else {
            *byte &= !bit;
        }
    }

    /// Reads the `w`-th 8-byte word of the buffer in native byte order.
    ///
    /// # Safety
    /// Bytes `w * 8 .. w * 8 + 8` must lie within the buffer.
    #[inline(always)]
    pub unsafe fn word_unchecked(&self, w: usize) -> u64 {
        // SAFETY: the caller guarantees the 8 bytes are in bounds.
        // `read_unaligned` is used because the byte buffer is not guaranteed to
        // be 8-byte aligned; an aligned `*` deref would be UB in that case.
        unsafe { self.bits.as_ptr().cast::<u64>().add(w).read_unaligned() }
    }

    /// Overwrites the `w`-th 8-byte word of the buffer (native byte order).
    ///
    /// `len` is not updated and padding is not re-masked.
    ///
    /// # Safety
    /// Bytes `w * 8 .. w * 8 + 8` must lie within the buffer.
    #[inline(always)]
    pub unsafe fn set_word_unchecked(&mut self, w: usize, word: u64) {
        // SAFETY: the caller guarantees the 8 bytes are in bounds; the write
        // is unaligned-tolerant for the same reason as `word_unchecked`.
        unsafe {
            self.bits
                .as_mut_ptr()
                .cast::<u64>()
                .add(w)
                .write_unaligned(word)
        };
    }

    /// Grows the mask so it holds at least `bits` bits.
    ///
    /// New bytes are zeroed. `len` is raised to `bits` if it was smaller.
    /// Never shrinks.
    #[inline]
    pub fn ensure_capacity(&mut self, bits: usize) {
        let needed = (bits + 7) / 8;
        if self.bits.len() < needed {
            self.bits.resize(needed, 0);
        }
        if bits > self.len {
            self.len = bits;
            self.mask_trailing_bits();
        }
    }

    /// Writes the low `n_bits` bits of `value` (LSB first) starting at bit
    /// `start`, growing the mask as needed.
    ///
    /// # Panics
    /// Panics if `n_bits > 64`.
    #[inline]
    pub fn set_bits_chunk(&mut self, start: usize, value: u64, n_bits: usize) {
        assert!(n_bits <= 64, "set_bits_chunk: n_bits > 64");
        for i in 0..n_bits {
            self.set(start + i, (value >> i) & 1 != 0);
        }
        self.mask_trailing_bits();
    }

    /// Appends `n` bits, all equal to `value`.
    #[inline]
    pub fn push_bits(&mut self, value: bool, n: usize) {
        self.resize(self.len + n, value);
    }

    /// `true` when every logical bit is set (vacuously true when empty).
    ///
    /// Checks whole bytes first, then the partial last byte under a mask.
    #[inline]
    pub fn all_true(&self) -> bool {
        if self.len == 0 {
            return true;
        }
        let full_bytes = self.len / 8;
        let last_bits = self.len & 7;
        if !self.bits[..full_bytes].iter().all(|&b| b == 0xFF) {
            return false;
        }
        if last_bits != 0 {
            let mask = (1u8 << last_bits) - 1;
            self.bits[full_bytes] & mask == mask
        } else {
            true
        }
    }

    /// `true` when no logical bit is set (vacuously true when empty).
    #[inline]
    pub fn all_false(&self) -> bool {
        if self.len == 0 {
            return true;
        }
        let full_bytes = self.len / 8;
        let last_bits = self.len & 7;
        if !self.bits[..full_bytes].iter().all(|&b| b == 0) {
            return false;
        }
        if last_bits != 0 {
            let mask = (1u8 << last_bits) - 1;
            self.bits[full_bytes] & mask == 0
        } else {
            true
        }
    }

    /// Packs a slice of booleans into a new mask (element `i` becomes bit `i`).
    #[inline]
    pub fn from_bools(bits: &[bool]) -> Self {
        let len = bits.len();
        let n_bytes = (len + 7) / 8;
        let mut data = Vec64::with_capacity(n_bytes);
        data.resize(n_bytes, 0);
        for (i, &b) in bits.iter().enumerate() {
            if b {
                data[i >> 3] |= 1u8 << (i & 7);
            }
        }
        let mut mask = Self {
            bits: data.into(),
            len,
        };
        mask.mask_trailing_bits();
        mask
    }

    /// `true` when at least one logical bit is cleared.
    #[inline]
    pub fn has_nulls(&self) -> bool {
        !self.all_true()
    }

    /// Raw pointer to the first byte of the buffer.
    #[inline]
    pub fn as_ptr(&self) -> *const u8 {
        self.bits.as_ptr()
    }

    /// Sets bit `idx` to `true`; grows the mask like [`Bitmask::set`].
    #[inline]
    pub fn set_true(&mut self, idx: usize) {
        self.set(idx, true)
    }

    /// Sets bit `idx` to `false`; grows the mask like [`Bitmask::set`].
    #[inline]
    pub fn set_false(&mut self, idx: usize) {
        self.set(idx, false)
    }

    /// Number of set bits among the first `len` bits; padding is ignored.
    #[inline]
    pub fn count_ones(&self) -> usize {
        let full_bytes = self.len / 8;
        let mut count = self.bits[..full_bytes]
            .iter()
            .map(|&b| b.count_ones() as usize)
            .sum::<usize>();
        let rem = self.len & 7;
        if rem != 0 {
            let mask = (1u8 << rem) - 1;
            count += (self.bits[full_bytes] & mask).count_ones() as usize;
        }
        count
    }

    /// Number of cleared bits among the first `len` bits.
    #[inline]
    pub fn count_zeros(&self) -> usize {
        self.len - self.count_ones()
    }

    /// Alias for [`Bitmask::count_zeros`] (a cleared bit marks a null).
    #[inline]
    pub fn null_count(&self) -> usize {
        self.count_zeros()
    }

    /// Resizes the mask to `new_len` bits.
    ///
    /// Every newly added bit equals `set`, including bits that grow into the
    /// partially used last byte. The buffer is resized to exactly
    /// `(new_len + 7) / 8` bytes.
    pub fn resize(&mut self, new_len: usize, set: bool) {
        let fill: u8 = if set { 0xFF } else { 0 };
        let old_len = self.len;
        if new_len > old_len {
            // Drop any physical bytes past the current logical end, so the
            // final `resize` below refills them with `fill` instead of
            // exposing stale contents.
            let old_bytes = ((old_len + 7) / 8).min(self.bits.len());
            self.bits.resize(old_bytes, 0);
            // `bits.resize` only fills whole new bytes. The padding bits of
            // the current partial byte are about to become live, so give them
            // the requested value too.
            let rem = old_len & 7;
            let idx = old_len >> 3;
            if rem != 0 && idx < self.bits.len() {
                let keep = (1u8 << rem) - 1;
                self.bits[idx] = (self.bits[idx] & keep) | (fill & !keep);
            }
        }
        self.bits.resize((new_len + 7) / 8, fill);
        self.len = new_len;
        self.mask_trailing_bits();
    }

    /// Splits the mask at bit `at`.
    ///
    /// `self` keeps bits `[0, at)` and the returned mask holds `[at, len)`.
    ///
    /// # Panics
    /// Panics if `at > len`.
    pub fn split_off(&mut self, at: usize) -> Self {
        assert!(at <= self.len, "split_off index out of bounds");
        if at == self.len {
            return Bitmask::new_set_all(0, false);
        }
        let start_byte = at / 8;
        let bit_offset = at % 8;
        let new_len = self.len - at;
        if bit_offset == 0 {
            // Byte-aligned split: hand the tail bytes over directly.
            let after_bits = self.bits.split_off(start_byte);
            self.len = at;
            self.mask_trailing_bits();
            let mut after = Bitmask {
                bits: after_bits,
                len: new_len,
            };
            after.mask_trailing_bits();
            return after;
        }
        // Unaligned split: copy the tail bit by bit into a fresh buffer.
        let after_bytes_needed = (new_len + 7) / 8;
        let mut after_buf = Vec64::with_capacity(after_bytes_needed);
        after_buf.resize(after_bytes_needed, 0);
        let original_bytes = self.bits.as_slice();
        for i in 0..new_len {
            let src_bit = at + i;
            let src_byte = src_bit / 8;
            if src_byte < original_bytes.len() {
                let bit_value = (original_bytes[src_byte] >> (src_bit % 8)) & 1;
                after_buf[i / 8] |= bit_value << (i % 8);
            }
        }
        let self_bytes_needed = (at + 7) / 8;
        self.bits.resize(self_bytes_needed, 0);
        self.len = at;
        self.mask_trailing_bits();
        let mut after = Bitmask {
            bits: after_buf.into(),
            len: new_len,
        };
        after.mask_trailing_bits();
        after
    }

    /// Appends each boolean yielded by `iter` as a new bit.
    #[inline]
    pub fn extend<I: IntoIterator<Item = bool>>(&mut self, iter: I) {
        for bit in iter {
            // `set` grows the mask to `idx + 1` bits through `ensure_capacity`,
            // so `len` must not be incremented again here. Doing so would
            // insert a spurious cleared bit after every pushed bit.
            let idx = self.len;
            self.set(idx, bit);
        }
        self.mask_trailing_bits();
    }

    /// Appends all logical bits of `other`.
    ///
    /// Uses the byte-wise copy path when `self.len` is byte-aligned.
    pub fn extend_from_bitmask(&mut self, other: &Bitmask) {
        self.extend_from_bitmask_range(other, 0, other.len());
    }

    /// Appends `len` bits of `other`, starting at bit `offset`.
    ///
    /// A byte-aligned `offset` copies straight from `other`'s bytes. Otherwise
    /// the source bytes are first realigned by shifting them down by
    /// `offset % 8`.
    ///
    /// # Panics
    /// Panics (via slice indexing) if `other`'s buffer does not contain the
    /// requested range.
    pub fn extend_from_bitmask_range(&mut self, other: &Bitmask, offset: usize, len: usize) {
        if len == 0 {
            return;
        }
        let src_bytes = other.bits.as_slice();
        let bit_shift = (offset & 7) as u32;
        if bit_shift == 0 {
            self.extend_from_slice(&src_bytes[offset >> 3..], len);
            return;
        }
        // Byte k of the realigned window combines the high bits of source
        // byte `start + k` with the low bits of the byte that follows it.
        let src_byte_start = offset >> 3;
        let n_src_bytes = ((len + 7) >> 3) + 1;
        let end = (src_byte_start + n_src_bytes).min(src_bytes.len());
        let shifted: Vec<u8> = (src_byte_start..end)
            .map(|i| {
                let lo = src_bytes[i] >> bit_shift;
                let hi = src_bytes
                    .get(i + 1)
                    .map_or(0, |&next| next << (8 - bit_shift));
                lo | hi
            })
            .collect();
        self.extend_from_slice(&shifted, len);
    }

    /// Appends the first `len` bits of the packed LSB-first slice `src`.
    ///
    /// # Panics
    /// Panics if `src` holds fewer than `(len + 7) / 8` bytes (when `len > 0`).
    pub fn extend_from_slice(&mut self, src: &[u8], len: usize) {
        let start = self.len;
        self.resize(start + len, false);
        let dst = self.bits.as_mut_slice();
        if start & 7 == 0 {
            // Byte-aligned destination: copy whole bytes, then splice in the
            // tail without disturbing the bits around it.
            let dst_byte = start >> 3;
            let n_full_bytes = len >> 3;
            dst[dst_byte..dst_byte + n_full_bytes].copy_from_slice(&src[..n_full_bytes]);
            let tail = len & 7;
            if tail != 0 {
                let mask = (1u8 << tail) - 1;
                let slot = &mut dst[dst_byte + n_full_bytes];
                *slot = (*slot & !mask) | (src[n_full_bytes] & mask);
            }
        } else {
            // Unaligned destination: copy bit by bit.
            for i in 0..len {
                let j = start + i;
                let bit = 1u8 << (j & 7);
                if (src[i >> 3] >> (i & 7)) & 1 != 0 {
                    dst[j >> 3] |= bit;
                } else {
                    dst[j >> 3] &= !bit;
                }
            }
        }
        self.mask_trailing_bits();
    }

    /// Borrows the packed bytes.
    #[inline]
    pub fn as_slice(&self) -> &[u8] {
        self.bits.as_slice()
    }

    /// Copies bits `offset .. offset + len` into a new, zero-based mask.
    ///
    /// # Panics
    /// Panics if `offset + len > self.len`.
    #[inline]
    pub fn slice_clone(&self, offset: usize, len: usize) -> Self {
        assert!(
            offset + len <= self.len,
            "Bitmask::slice_clone out of bounds"
        );
        let mut out = Bitmask::new_set_all(len, false);
        let src = self.bits.as_slice();
        let dst = out.bits.as_mut_slice();
        for i in 0..len {
            let src_idx = offset + i;
            if (src[src_idx / 8] & (1 << (src_idx % 8))) != 0 {
                dst[i / 8] |= 1 << (i % 8);
            }
        }
        out.mask_trailing_bits();
        out
    }

    /// Zero-copy window over bits `offset .. offset + len`.
    ///
    /// Returns `(bytes, bit_offset, len)`, where `bytes` starts at byte
    /// `offset / 8` and `bit_offset = offset % 8` is the position of the first
    /// bit within `bytes[0]`.
    ///
    /// # Panics
    /// Panics if `offset + len > self.len`.
    #[inline]
    pub fn slice(&self, offset: usize, len: usize) -> (&[u8], usize, usize) {
        assert!(offset + len <= self.len, "Bitmask::slice out of bounds");
        let start_byte = offset / 8;
        let end_byte = (offset + len + 7) / 8;
        (&self.bits[start_byte..end_byte], offset % 8, len)
    }

    /// Wraps a clone of this mask in a `BitmaskV` window.
    ///
    /// The cost depends on `Buffer`'s `Clone` impl (TODO confirm whether it
    /// is shared or deep).
    #[inline(always)]
    pub fn view(&self, offset: Offset, len: Length) -> BitmaskV {
        BitmaskV::new(self.clone(), offset, len)
    }

    /// Union of two optional masks.
    ///
    /// `None` acts as the identity; when both are present the result is
    /// [`Bitmask::union`].
    pub fn union_opt(a: Option<&Bitmask>, b: Option<&Bitmask>) -> Option<Bitmask> {
        match (a, b) {
            (None, None) => None,
            (Some(m), None) | (None, Some(m)) => Some(m.clone()),
            (Some(a), Some(b)) => Some(a.union(b)),
        }
    }

    /// Bitwise OR of two equal-length masks.
    ///
    /// # Panics
    /// Panics if the lengths differ.
    #[inline]
    pub fn union(&self, other: &Self) -> Self {
        assert_eq!(self.len, other.len, "Bitmask::union length mismatch");
        let mut out = self.clone();
        for (a, b) in out.bits.iter_mut().zip(other.bits.iter()) {
            *a |= *b;
        }
        out.mask_trailing_bits();
        out
    }

    /// Bitwise AND of two equal-length masks.
    ///
    /// # Panics
    /// Panics if the lengths differ.
    #[inline]
    pub fn intersect(&self, other: &Self) -> Self {
        assert_eq!(self.len, other.len, "Bitmask::intersect length mismatch");
        let mut out = self.clone();
        for (a, b) in out.bits.iter_mut().zip(other.bits.iter()) {
            *a &= *b;
        }
        out.mask_trailing_bits();
        out
    }

    /// Returns the bitwise complement; padding bits stay cleared.
    #[inline]
    pub fn invert(&self) -> Self {
        let mut out = self.clone();
        for b in out.bits.iter_mut() {
            *b = !*b;
        }
        out.mask_trailing_bits();
        out
    }

    /// Iterates, in ascending order, over the indices of set bits below `len`.
    pub fn iter_set(&self) -> impl Iterator<Item = usize> + '_ {
        let n = self.len;
        self.bits.iter().enumerate().flat_map(move |(byte_i, &b)| {
            let base = byte_i * 8;
            (0..8).filter_map(move |bit| {
                let idx = base + bit;
                if idx < n && ((b >> bit) & 1) != 0 {
                    Some(idx)
                } else {
                    None
                }
            })
        })
    }

    /// Iterates, in ascending order, over the indices of cleared bits below `len`.
    pub fn iter_cleared(&self) -> impl Iterator<Item = usize> + '_ {
        let n = self.len;
        self.bits.iter().enumerate().flat_map(move |(byte_i, &b)| {
            let base = byte_i * 8;
            (0..8).filter_map(move |bit| {
                let idx = base + bit;
                if idx < n && ((b >> bit) & 1) == 0 {
                    Some(idx)
                } else {
                    None
                }
            })
        })
    }

    /// Sets every bit to `value`; padding bits stay cleared.
    #[inline]
    pub fn fill(&mut self, value: bool) {
        let fill = if value { 0xFF } else { 0 };
        for b in &mut self.bits {
            *b = fill;
        }
        self.mask_trailing_bits();
    }

    /// Borrows the raw packed bytes.
    #[inline]
    pub fn buffer(&self) -> &[u8] {
        &self.bits
    }

    /// Reads bit `idx` without any bounds check (ignores `len`).
    ///
    /// # Safety
    /// `idx / 8` must be less than `self.bits.len()`.
    #[inline(always)]
    pub unsafe fn get_unchecked(&self, idx: usize) -> bool {
        // SAFETY: the caller guarantees `idx / 8` is in bounds.
        let byte = unsafe { self.bits.get_unchecked(idx >> 3) };
        ((*byte) >> (idx & 7)) & 1 != 0
    }

    /// Reads byte `byte_idx` of the buffer without a bounds check.
    ///
    /// # Safety
    /// `byte_idx` must be less than `self.bits.len()`.
    #[inline(always)]
    pub unsafe fn get_unchecked_byte(&self, byte_idx: usize) -> u8 {
        // SAFETY: the caller guarantees `byte_idx` is in bounds.
        *unsafe { self.bits.get_unchecked(byte_idx) }
    }
}
#[cfg(feature = "parallel_proc")]
mod parallel {
    use rayon::prelude::*;
    use super::Bitmask;

    impl Bitmask {
        /// Parallel iterator over all `len` logical bits, in index order.
        #[inline]
        pub fn par_iter(&self) -> impl ParallelIterator<Item = bool> + '_ {
            (0..self.len)
                .into_par_iter()
                // SAFETY: `i < self.len`. This assumes the buffer holds at
                // least `len` bits (TODO(review): `Bitmask::new` does not
                // enforce it).
                .map(move |i| unsafe { self.get_unchecked(i) })
        }

        /// Parallel iterator over bits `start..end`, in index order.
        ///
        /// # Panics
        /// Panics if `start > end` or `end > self.len()`.
        #[inline]
        pub fn par_iter_range(
            &self,
            start: usize,
            end: usize,
        ) -> impl ParallelIterator<Item = bool> + '_ {
            // A hard assert rather than a debug_assert: this is a safe fn and
            // the closure reads with `get_unchecked`, so an out-of-range `end`
            // would be undefined behaviour in release builds.
            assert!(
                start <= end && end <= self.len,
                "Bitmask::par_iter_range out of bounds"
            );
            (start..end)
                .into_par_iter()
                // SAFETY: `i < end <= self.len`, with the same buffer-size
                // assumption as `par_iter`.
                .map(move |i| unsafe { self.get_unchecked(i) })
        }
    }
}
impl Index<usize> for Bitmask {
    type Output = bool;

    /// `mask[i]`: returns a reference to a static `true`/`false` for bit `i`.
    ///
    /// Delegates to the bounds-checked [`Bitmask::get`], so indices in
    /// `len .. bits.len() * 8` read as `false`. Indexing is callable from safe
    /// code, so an unchecked read here would make a bad index undefined
    /// behaviour.
    ///
    /// # Panics
    /// Panics if `index >= self.bits.len() * 8`.
    #[inline]
    fn index(&self, index: usize) -> &Self::Output {
        if self.get(index) {
            &true
        } else {
            &false
        }
    }
}
impl Debug for Bitmask {
    /// Diagnostic view: logical length, set/cleared bit counts and the raw buffer.
    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
        let ones = self.count_ones();
        let zeros = self.count_zeros();
        let mut out = f.debug_struct("Bitmask");
        out.field("len", &self.len);
        out.field("ones", &ones);
        out.field("zeros", &zeros);
        out.field("buffer", &self.bits);
        out.finish()
    }
}
impl BitAnd for &Bitmask {
type Output = Bitmask;
#[inline]
fn bitand(self, rhs: Self) -> Bitmask {
self.intersect(rhs)
}
}
impl BitOr for &Bitmask {
type Output = Bitmask;
#[inline]
fn bitor(self, rhs: Self) -> Bitmask {
self.union(rhs)
}
}
impl Not for &Bitmask {
type Output = Bitmask;
#[inline]
fn not(self) -> Bitmask {
self.invert()
}
}
impl Not for Bitmask {
    type Output = Bitmask;

    /// `!mask`: consumes the mask and returns its complement.
    ///
    /// The owned buffer is inverted in place. Calling `invert()` here would
    /// clone the buffer only to drop the original immediately afterwards.
    /// Padding bits stay cleared.
    #[inline]
    fn not(mut self) -> Bitmask {
        for byte in self.bits.iter_mut() {
            *byte = !*byte;
        }
        self.mask_trailing_bits();
        self
    }
}
impl AsRef<[u8]> for Bitmask {
    /// Borrows the packed bytes, including any padding past `len`.
    #[inline]
    fn as_ref(&self) -> &[u8] {
        let bytes: &[u8] = self.bits.as_ref();
        bytes
    }
}
impl AsMut<[u8]> for Bitmask {
    /// Mutably borrows the packed bytes; `len` is not adjusted.
    #[inline]
    fn as_mut(&mut self) -> &mut [u8] {
        let bytes: &mut [u8] = self.bits.as_mut();
        bytes
    }
}
impl Deref for Bitmask {
    type Target = [u8];

    /// Lets a `Bitmask` be used wherever a `&[u8]` of its packed bytes is expected.
    #[inline]
    fn deref(&self) -> &Self::Target {
        let bytes: &[u8] = self.bits.as_ref();
        bytes
    }
}
impl DerefMut for Bitmask {
    /// Mutable counterpart of `Deref`: exposes the packed bytes as `&mut [u8]`.
    #[inline]
    fn deref_mut(&mut self) -> &mut Self::Target {
        let bytes: &mut [u8] = self.bits.as_mut();
        bytes
    }
}
impl Display for Bitmask {
    /// Human-readable summary.
    ///
    /// Prints a header line with the bit/ones/zeros counts, then the first (up
    /// to 64) bits as space-separated `0`/`1`. A `… (N total)` suffix appears
    /// when the preview is truncated.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        const MAX_PREVIEW: usize = 64;
        let (ones, zeros) = (self.count_ones(), self.count_zeros());
        writeln!(
            f,
            "Bitmask [{} bits] (ones: {}, zeros: {})",
            self.len, ones, zeros
        )?;
        write!(f, "[")?;
        let shown = self.len.min(MAX_PREVIEW);
        for i in 0..shown {
            if i != 0 {
                write!(f, " ")?;
            }
            // SAFETY: i < self.len; assumes the buffer holds at least `len` bits.
            let bit = unsafe { self.get_unchecked(i) };
            let ch = if bit { '1' } else { '0' };
            write!(f, "{}", ch)?;
        }
        if shown < self.len {
            write!(f, " … ({} total)", self.len)?;
        }
        write!(f, "]")
    }
}
impl Shape for Bitmask {
    /// A bitmask is one-dimensional; its shape is its logical bit length.
    fn shape(&self) -> ShapeDim {
        let n_bits = self.len();
        ShapeDim::Rank1(n_bits)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that the first `expected.len()` bits of `mask` equal `expected`.
    fn assert_bits(mask: &Bitmask, expected: &[bool]) {
        for (i, &want) in expected.iter().enumerate() {
            assert_eq!(mask.get(i), want, "Mismatch at bit {}", i);
        }
    }

    #[test]
    fn test_bitmask_new_set_get() {
        let mut mask = Bitmask::new_set_all(10, false);
        assert!((0..10).all(|i| !mask.get(i)));
        mask.set(3, true);
        assert!(mask.get(3));
        mask.set(3, false);
        assert!(!mask.get(3));
    }

    #[test]
    fn test_ensure_capacity_and_resize() {
        let mut mask = Bitmask::new_set_all(1, false);
        mask.ensure_capacity(20);
        assert!(mask.len() >= 20);
        mask.set(15, true);
        assert!(mask.get(15));
        mask.resize(100, false);
        assert_eq!(mask.len(), 100);
    }

    #[test]
    fn test_count_and_all() {
        let mut mask = Bitmask::new_set_all(16, true);
        assert_eq!(mask.count_ones(), 16);
        assert!(mask.all_set());
        mask.set(0, false);
        assert_eq!(mask.count_zeros(), 1);
        assert!(!mask.all_set());
        assert!(!mask.all_unset());
    }

    #[test]
    fn test_invert_union_and_intersect() {
        let mut a = Bitmask::new_set_all(8, false);
        let mut b = Bitmask::new_set_all(8, false);
        for i in [1, 3] {
            a.set(i, true);
        }
        for i in [3, 4] {
            b.set(i, true);
        }

        let either = &a | &b;
        assert!([1, 3, 4].iter().all(|&i| either.get(i)));

        let both = &a & &b;
        assert!(!both.get(1));
        assert!(both.get(3));

        let flipped = !&a;
        assert!(!flipped.get(3));
        assert!(flipped.get(2));
    }

    #[test]
    fn test_set_bits_chunk_and_push_bits() {
        let mut mask = Bitmask::new_set_all(16, false);
        mask.set_bits_chunk(0, 0b10101, 5);
        assert_bits(&mask, &[true, false, true, false, true]);

        mask.push_bits(true, 3);
        assert!((16..19).all(|i| mask.get(i)));
    }

    #[test]
    fn test_slice_clone_and_view() {
        let mut mask = Bitmask::new_set_all(10, false);
        mask.set(2, true);
        mask.set(5, true);

        let sub = mask.slice_clone(2, 4);
        assert_eq!(sub.capacity(), 4);
        assert!(sub.get(0));
        assert!(sub.get(3));

        let (bytes, bit_offset, n_bits) = mask.slice(2, 4);
        assert!((bytes[0] >> bit_offset) & 1 != 0);
        assert_eq!(n_bits, 4);
    }

    #[test]
    fn test_iter_set_and_iter_cleared() {
        let mut mask = Bitmask::new_set_all(12, false);
        for i in [2, 5, 10] {
            mask.set(i, true);
        }

        let set_positions: Vec<usize> = mask.iter_set().collect();
        assert_eq!(set_positions, vec![2, 5, 10]);

        let cleared_positions: Vec<usize> = mask.iter_cleared().collect();
        assert!(cleared_positions.contains(&0));
        assert!(cleared_positions.contains(&11));
        assert!(!cleared_positions.contains(&2));
    }

    #[test]
    fn test_extend_from_slice() {
        // Unaligned append: 5 existing bits followed by 7 bits of 0b01100110.
        let mut mask = Bitmask::new_set_all(5, false);
        for i in [0, 2, 4] {
            mask.set(i, true);
        }
        mask.extend_from_slice(&[0b01100110u8], 7);
        assert_bits(
            &mask,
            &[
                true, false, true, false, true, false, true, true, false, false, true, true,
            ],
        );

        // Byte-aligned append: 8 set bits followed by 8 bits of 0b10101100.
        let mut aligned = Bitmask::new_set_all(8, true);
        aligned.extend_from_slice(&[0b10101100u8], 8);
        assert_bits(
            &aligned,
            &[
                true, true, true, true, true, true, true, true, false, false, true, true, false,
                true, false, true,
            ],
        );

        // Appending zero bits leaves the length untouched.
        let mut unchanged = Bitmask::new_set_all(3, false);
        unchanged.extend_from_slice(&[0u8], 0);
        assert_eq!(unchanged.len(), 3);
    }

    #[test]
    fn test_union_opt_none_none() {
        assert!(Bitmask::union_opt(None, None).is_none());
    }

    #[test]
    fn test_union_opt_some_none() {
        let mask = Bitmask::from_bools(&[true, false, true]);
        assert_eq!(Bitmask::union_opt(Some(&mask), None).unwrap(), mask);
    }

    #[test]
    fn test_union_opt_none_some() {
        let mask = Bitmask::from_bools(&[false, true, false]);
        assert_eq!(Bitmask::union_opt(None, Some(&mask)).unwrap(), mask);
    }

    #[test]
    fn test_union_opt_some_some() {
        let a = Bitmask::from_bools(&[true, false, false, true]);
        let b = Bitmask::from_bools(&[false, true, false, true]);
        let merged = Bitmask::union_opt(Some(&a), Some(&b)).unwrap();
        assert_bits(&merged, &[true, true, false, true]);
    }

    #[test]
    fn test_concatenate() {
        let mut left = Bitmask::new_set_all(5, false);
        for i in [0, 2, 4] {
            left.set(i, true);
        }
        let mut right = Bitmask::new_set_all(4, false);
        for i in [1, 3] {
            right.set(i, true);
        }

        let joined = left.concat(right).unwrap();
        assert_eq!(joined.len(), 9);
        assert_bits(
            &joined,
            &[true, false, true, false, true, false, true, false, true],
        );
    }
}
impl Concatenate for Bitmask {
fn concat(
mut self,
other: Self,
) -> core::result::Result<Self, crate::enums::error::MinarrowError> {
self.extend_from_bitmask(&other);
Ok(self)
}
}