use std::sync::Arc;
use vortex_buffer::BitBufferMut;
use vortex_error::vortex_panic;
use crate::Mask;
/// A mutable mask that can be built up incrementally and then turned into an immutable
/// [`Mask`] with [`MaskMut::freeze`].
///
/// The representation is private: a mask that is empty, or whose bits all share one value,
/// is tracked without a bit buffer, and a `BitBufferMut` is only used once the mask holds
/// (or is handed) explicit bits.
#[derive(Debug, Clone)]
pub struct MaskMut(Inner);
impl Default for MaskMut {
    /// Returns an empty mask with zero capacity, the same value as [`MaskMut::empty`].
    fn default() -> Self {
        Self::with_capacity(0)
    }
}
/// Internal representation of a [`MaskMut`].
#[derive(Debug, Clone)]
enum Inner {
/// The mask holds no bits yet; only the requested capacity is remembered.
Empty { capacity: usize },
/// Every bit of the mask has the same value, so no bit buffer is needed.
Constant {
// The value shared by all `len` bits.
value: bool,
// Number of bits in the mask.
len: usize,
// Capacity recorded for the mask; used to size the bit buffer if one is created later.
capacity: usize,
},
/// The bits are stored explicitly in a mutable bit buffer.
Builder(BitBufferMut),
}
impl MaskMut {
    /// Creates an empty mask with zero capacity.
    pub fn empty() -> Self {
        Self::with_capacity(0)
    }

    /// Creates an empty mask that remembers `capacity` as its capacity.
    ///
    /// No bit buffer is created here; the capacity is only handed to a bit buffer if the mask
    /// is materialized later.
    pub fn with_capacity(capacity: usize) -> Self {
        Self(Inner::Empty { capacity })
    }

    /// Creates a mask of `len` bits that all equal `value`. The capacity is set to `len`.
    pub fn new(len: usize, value: bool) -> Self {
        Self(Inner::Constant {
            value,
            len,
            capacity: len,
        })
    }

    /// Creates a mask of `len` bits that are all `true`.
    pub fn new_true(len: usize) -> Self {
        Self::new(len, true)
    }

    /// Creates a mask of `len` bits that are all `false`.
    pub fn new_false(len: usize) -> Self {
        Self::new(len, false)
    }

    /// Wraps an existing mutable bit buffer as a mask.
    pub fn from_buffer(bit_buffer: BitBufferMut) -> Self {
        Self(Inner::Builder(bit_buffer))
    }

    /// Returns the bit at `index`.
    ///
    /// # Panics
    ///
    /// Panics if the mask is in the empty state, or if `index` is not smaller than the length
    /// of a constant mask. For a buffer-backed mask the lookup is delegated to the bit buffer.
    pub fn value(&self, index: usize) -> bool {
        match &self.0 {
            Inner::Empty { .. } => {
                vortex_panic!("index out of bounds: the length is 0 but the index is {index}")
            }
            Inner::Constant { value, len, .. } => {
                assert!(
                    index < *len,
                    "index out of bounds: the length is {} but the index is {index}",
                    *len
                );
                *value
            }
            Inner::Builder(bit_buffer) => bit_buffer.value(index),
        }
    }

    /// Reserves room for `additional` more bits.
    ///
    /// For empty and constant masks this only increases the recorded capacity by `additional`;
    /// for a buffer-backed mask the request is forwarded to the bit buffer.
    pub fn reserve(&mut self, additional: usize) {
        match &mut self.0 {
            Inner::Empty { capacity } | Inner::Constant { capacity, .. } => {
                *capacity += additional;
            }
            Inner::Builder(bits) => bits.reserve(additional),
        }
    }

    /// Sets the length of the mask to `new_len`.
    ///
    /// An empty mask becomes a constant mask of `new_len` bits whose value is `false`
    /// (an arbitrary choice). A constant mask just changes its length. A buffer-backed mask
    /// forwards the call to the bit buffer.
    ///
    /// # Safety
    ///
    /// `new_len` must be less than or equal to [`Self::capacity`]. For a buffer-backed mask,
    /// the safety requirements of the underlying bit buffer's `set_len` apply as well.
    pub unsafe fn set_len(&mut self, new_len: usize) {
        // Filling the mask exactly up to its capacity is valid, hence `<=` rather than `<`.
        debug_assert!(
            new_len <= self.capacity(),
            "new length {new_len} exceeds capacity {}",
            self.capacity()
        );
        match &mut self.0 {
            Inner::Empty { capacity } => {
                let capacity = *capacity;
                self.0 = Inner::Constant {
                    value: false,
                    len: new_len,
                    capacity,
                };
            }
            Inner::Constant { len, .. } => *len = new_len,
            Inner::Builder(bits) => {
                // SAFETY: the caller upholds this function's contract, which is forwarded
                // unchanged to the bit buffer.
                unsafe { bits.set_len(new_len) };
            }
        }
    }

    /// Returns the capacity of the mask: the recorded capacity for empty and constant masks,
    /// or the bit buffer's capacity for a buffer-backed mask.
    pub fn capacity(&self) -> usize {
        match &self.0 {
            Inner::Empty { capacity } | Inner::Constant { capacity, .. } => *capacity,
            Inner::Builder(bits) => bits.capacity(),
        }
    }

    /// Removes all bits from the mask.
    ///
    /// A constant mask goes back to the empty state and keeps its recorded capacity; a
    /// buffer-backed mask clears its bit buffer.
    pub fn clear(&mut self) {
        match &mut self.0 {
            Inner::Empty { .. } => {}
            Inner::Constant { capacity, .. } => {
                let capacity = *capacity;
                self.0 = Inner::Empty { capacity };
            }
            Inner::Builder(bit_buffer) => bit_buffer.clear(),
        }
    }

    /// Shortens the mask so that only its first `len` bits remain.
    ///
    /// Does nothing if `len` is greater than the current length.
    pub fn truncate(&mut self, len: usize) {
        if len > self.len() {
            return;
        }
        match &mut self.0 {
            Inner::Empty { .. } => {}
            // The guard above guarantees `len <= *current_len`.
            Inner::Constant {
                len: current_len, ..
            } => *current_len = len,
            Inner::Builder(bit_buffer) => bit_buffer.truncate(len),
        }
    }

    /// Appends `n` bits that all equal `new_value`.
    ///
    /// The mask stays in its buffer-free form as long as every bit seen so far has the same
    /// value; appending a different value to a constant mask materializes a bit buffer.
    pub fn append_n(&mut self, new_value: bool, n: usize) {
        match &mut self.0 {
            Inner::Empty { capacity } => {
                let capacity = (*capacity).max(n);
                self.0 = Inner::Constant {
                    value: new_value,
                    len: n,
                    capacity,
                };
            }
            Inner::Constant {
                value,
                len,
                capacity,
            } => {
                if *value == new_value {
                    // Still constant: grow the length and keep the capacity at least as large.
                    *len += n;
                    *capacity = (*capacity).max(*len);
                } else {
                    self.materialize().append_n(new_value, n);
                }
            }
            Inner::Builder(bits) => bits.append_n(new_value, n),
        }
    }

    /// Appends all bits of `other` to this mask.
    ///
    /// All-true and all-false masks go through [`Self::append_n`]; a mask with explicit values
    /// materializes this mask and appends the other mask's bit buffer.
    pub fn append_mask(&mut self, other: &Mask) {
        match other {
            Mask::AllTrue(len) => self.append_n(true, *len),
            Mask::AllFalse(len) => self.append_n(false, *len),
            Mask::Values(values) => {
                // Borrow the buffer through the accessor; no clone is needed just to read it.
                self.materialize().append_buffer(values.bit_buffer());
            }
        }
    }

    /// Makes sure the mask is backed by a bit buffer and returns that buffer.
    ///
    /// An empty mask becomes an empty buffer with the recorded capacity; a constant mask
    /// becomes a buffer holding `len` copies of its value.
    fn materialize(&mut self) -> &mut BitBufferMut {
        let replacement = match &self.0 {
            Inner::Builder(_) => None,
            Inner::Empty { capacity } => Some(BitBufferMut::with_capacity(*capacity)),
            Inner::Constant {
                value,
                len,
                capacity,
            } => {
                let mut bits = BitBufferMut::with_capacity((*capacity).max(*len));
                bits.append_n(*value, *len);
                Some(bits)
            }
        };
        if let Some(bits) = replacement {
            self.0 = Inner::Builder(bits);
        }
        match &mut self.0 {
            Inner::Builder(bits) => bits,
            // Every other state was replaced by a builder just above.
            _ => unreachable!(),
        }
    }

    /// Splits the mask in two at bit index `at`.
    ///
    /// Afterwards `self` keeps the bits before `at` and the returned mask holds the bits from
    /// `at` onwards. For empty and constant masks the recorded capacity is divided the same
    /// way (`self` keeps `at`, the returned mask gets the remainder). A buffer-backed mask
    /// delegates the split to the bit buffer.
    ///
    /// # Panics
    ///
    /// Panics if `at` is greater than [`Self::capacity`].
    pub fn split_off(&mut self, at: usize) -> Self {
        assert!(at <= self.capacity(), "split_off index out of bounds");
        match &mut self.0 {
            Inner::Empty { capacity } => {
                let new_capacity = *capacity - at;
                *capacity = at;
                Self(Inner::Empty {
                    capacity: new_capacity,
                })
            }
            Inner::Constant {
                value,
                len,
                capacity,
            } => {
                // The length may be shorter than `at`, in which case the tail holds no bits.
                let new_len = len.saturating_sub(at);
                let new_capacity = *capacity - at;
                *len = (*len).min(at);
                *capacity = at;
                Self(Inner::Constant {
                    value: *value,
                    len: new_len,
                    capacity: new_capacity,
                })
            }
            Inner::Builder(bits) => Self(Inner::Builder(bits.split_off(at))),
        }
    }

    /// Appends the bits of `other` to the end of this mask, consuming `other`.
    ///
    /// An empty `other` is ignored, a constant `other` is appended with [`Self::append_n`],
    /// and a buffer-backed `other` is joined onto this mask's (materialized) bit buffer.
    pub fn unsplit(&mut self, other: Self) {
        match other.0 {
            Inner::Empty { .. } => {
                // Nothing to append.
            }
            Inner::Constant { value, len, .. } => self.append_n(value, len),
            Inner::Builder(bits) => self.materialize().unsplit(bits),
        }
    }

    /// Converts the mask into an immutable [`Mask`].
    ///
    /// A mask in the empty state becomes an all-true mask of length zero.
    pub fn freeze(self) -> Mask {
        match self.0 {
            Inner::Empty { .. } => Mask::new_true(0),
            Inner::Constant { value, len, .. } => {
                if value {
                    Mask::new_true(len)
                } else {
                    Mask::new_false(len)
                }
            }
            Inner::Builder(bits) => Mask::from_buffer(bits.freeze()),
        }
    }

    /// Returns the number of bits in the mask.
    pub fn len(&self) -> usize {
        match &self.0 {
            Inner::Empty { .. } => 0,
            Inner::Constant { len, .. } => *len,
            Inner::Builder(bits) => bits.len(),
        }
    }

    /// Returns `true` if the mask holds no bits.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Returns `true` if every bit in the mask is set.
    ///
    /// A mask without any bits is vacuously all-true, whichever internal state it is in.
    pub fn all_true(&self) -> bool {
        match &self.0 {
            Inner::Empty { .. } => true,
            Inner::Constant { value, len, .. } => *value || *len == 0,
            Inner::Builder(bits) => bits.true_count() == bits.len(),
        }
    }

    /// Returns `true` if every bit in the mask is unset.
    ///
    /// A mask without any bits is vacuously all-false, whichever internal state it is in.
    pub fn all_false(&self) -> bool {
        match &self.0 {
            Inner::Empty { .. } => true,
            Inner::Constant { value, len, .. } => !*value || *len == 0,
            Inner::Builder(bits) => bits.true_count() == 0,
        }
    }

    /// Returns the underlying bit buffer if the mask is buffer-backed, or `None` when the
    /// mask is empty or constant.
    pub fn as_bit_buffer_mut(&mut self) -> Option<&mut BitBufferMut> {
        match &mut self.0 {
            Inner::Builder(bits) => Some(bits),
            Inner::Empty { .. } | Inner::Constant { .. } => None,
        }
    }

    /// Sets the bit at `index` to `true`. See [`Self::set_to`] for the panic conditions.
    pub fn set(&mut self, index: usize) {
        self.set_to(index, true);
    }

    /// Sets the bit at `index` to `false`. See [`Self::set_to`] for the panic conditions.
    pub fn unset(&mut self, index: usize) {
        self.set_to(index, false);
    }

    /// Sets the bit at `index` to `value`.
    ///
    /// A constant mask is only materialized into a bit buffer when `value` differs from the
    /// value it already holds.
    ///
    /// # Panics
    ///
    /// Panics if the mask is in the empty state, or if `index` is not smaller than the length
    /// of a constant mask. For a buffer-backed mask the write is delegated to the bit buffer.
    pub fn set_to(&mut self, index: usize, value: bool) {
        match &mut self.0 {
            Inner::Empty { .. } => {
                vortex_panic!("index out of bounds: the length is 0 but the index is {index}")
            }
            Inner::Constant {
                value: current_value,
                len,
                ..
            } => {
                assert!(
                    index < *len,
                    "index out of bounds: the length is {} but the index is {index}",
                    *len
                );
                if *current_value != value {
                    self.materialize().set_to(index, value);
                }
            }
            Inner::Builder(bit_buffer) => bit_buffer.set_to(index, value),
        }
    }

    /// Sets the bit at `index` to `true` without bounds checking.
    ///
    /// # Safety
    ///
    /// Same contract as [`Self::set_to_unchecked`].
    pub unsafe fn set_unchecked(&mut self, index: usize) {
        // SAFETY: the caller upholds the contract of `set_to_unchecked`.
        unsafe { self.set_to_unchecked(index, true) }
    }

    /// Sets the bit at `index` to `false` without bounds checking.
    ///
    /// # Safety
    ///
    /// Same contract as [`Self::set_to_unchecked`].
    pub unsafe fn unset_unchecked(&mut self, index: usize) {
        // SAFETY: the caller upholds the contract of `set_to_unchecked`.
        unsafe { self.set_to_unchecked(index, false) }
    }

    /// Sets the bit at `index` to `value` without bounds checking.
    ///
    /// The preconditions are only verified by debug assertions.
    ///
    /// # Safety
    ///
    /// The mask must not be in the empty state and `index` must be smaller than
    /// [`Self::len`].
    pub unsafe fn set_to_unchecked(&mut self, index: usize, value: bool) {
        match &mut self.0 {
            Inner::Empty { .. } => {
                debug_assert!(false, "cannot set value in empty mask");
            }
            Inner::Constant {
                value: current_value,
                len,
                ..
            } => {
                debug_assert!(
                    index < *len,
                    "index out of bounds: the length is {} but the index is {index}",
                    *len
                );
                if *current_value != value {
                    // SAFETY: the caller guarantees `index < len`, and the materialized buffer
                    // holds exactly `len` bits.
                    unsafe { self.materialize().set_to_unchecked(index, value) };
                }
            }
            Inner::Builder(bit_buffer) => {
                // SAFETY: the caller guarantees `index` is within the buffer's length.
                unsafe { bit_buffer.set_to_unchecked(index, value) };
            }
        }
    }
}
impl Mask {
    /// Attempts to convert this mask into a [`MaskMut`] without copying any bits.
    ///
    /// All-true and all-false masks always succeed. A mask with explicit values succeeds only
    /// if both the shared values and the bit buffer inside them can be taken over exclusively.
    ///
    /// # Errors
    ///
    /// Returns the mask back as the error when exclusive ownership could not be obtained.
    pub fn try_into_mut(self) -> Result<MaskMut, Self> {
        let shared_values = match self {
            Mask::AllTrue(len) => return Ok(MaskMut::new_true(len)),
            Mask::AllFalse(len) => return Ok(MaskMut::new_false(len)),
            Mask::Values(values) => values,
        };

        // Another handle to the values exists: hand the mask back untouched.
        let exclusive_values = match Arc::try_unwrap(shared_values) {
            Ok(inner) => inner,
            Err(still_shared) => return Err(Mask::Values(still_shared)),
        };

        match exclusive_values.into_buffer().try_into_mut() {
            Ok(bits) => Ok(MaskMut(Inner::Builder(bits))),
            Err(frozen) => Err(Mask::from_buffer(frozen)),
        }
    }

    /// Converts this mask into a [`MaskMut`], copying the bits whenever they cannot be taken
    /// over exclusively.
    pub fn into_mut(self) -> MaskMut {
        let shared_values = match self {
            Mask::AllTrue(len) => return MaskMut::new_true(len),
            Mask::AllFalse(len) => return MaskMut::new_false(len),
            Mask::Values(values) => values,
        };

        let bits = match Arc::try_unwrap(shared_values) {
            // The values are exclusively ours; the buffer inside may still refuse conversion.
            Ok(exclusive_values) => match exclusive_values.into_buffer().try_into_mut() {
                Ok(bits) => bits,
                Err(frozen) => BitBufferMut::copy_from(&frozen),
            },
            // Someone else still holds the values: copy the bits out.
            Err(still_shared) => BitBufferMut::copy_from(still_shared.bit_buffer()),
        };
        MaskMut(Inner::Builder(bits))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a mask with the given capacity holding `trues` set bits followed by `falses`
    /// unset bits.
    fn trues_then_falses(capacity: usize, trues: usize, falses: usize) -> MaskMut {
        let mut mask = MaskMut::with_capacity(capacity);
        mask.append_n(true, trues);
        mask.append_n(false, falses);
        mask
    }

    #[test]
    fn test_split_off_empty() {
        let mut head = MaskMut::with_capacity(10);
        assert_eq!(head.len(), 0);

        let tail = head.split_off(0);
        assert_eq!(head.len(), 0);
        assert_eq!(tail.len(), 0);
    }

    #[test]
    fn test_split_off_constant_true_at_zero() {
        let mut head = MaskMut::new_true(10);
        let tail = head.split_off(0);

        assert_eq!(head.len(), 0);
        assert_eq!(tail.len(), 10);
        assert_eq!(tail.freeze().true_count(), 10);
    }

    #[test]
    fn test_split_off_constant_true_at_end() {
        let mut head = MaskMut::new_true(10);
        let tail = head.split_off(10);

        assert_eq!(head.len(), 10);
        assert_eq!(tail.len(), 0);
        assert_eq!(head.freeze().true_count(), 10);
    }

    #[test]
    fn test_split_off_constant_true_in_middle() {
        let mut head = MaskMut::new_true(10);
        let tail = head.split_off(6);

        assert_eq!(head.len(), 6);
        assert_eq!(tail.len(), 4);
        assert_eq!(head.freeze().true_count(), 6);
        assert_eq!(tail.freeze().true_count(), 4);
    }

    #[test]
    fn test_split_off_constant_false() {
        let mut head = MaskMut::new_false(20);
        let tail = head.split_off(12);

        assert_eq!(head.len(), 12);
        assert_eq!(tail.len(), 8);
        assert_eq!(head.freeze().true_count(), 0);
        assert_eq!(tail.freeze().true_count(), 0);
    }

    #[test]
    fn test_split_off_builder_at_byte_boundary() {
        // Address of the first byte of the builder's storage.
        let storage_ptr = |mask: &MaskMut| match &mask.0 {
            Inner::Builder(bits) => bits.as_slice().as_ptr(),
            _ => unreachable!(),
        };

        let mut head = trues_then_falses(16, 8, 8);
        let ptr_before = storage_ptr(&head);

        let tail = head.split_off(8);
        assert_eq!(head.len(), 8);
        assert_eq!(tail.len(), 8);

        // Splitting and rejoining must leave the mask on the same storage.
        head.unsplit(tail);
        let ptr_after = storage_ptr(&head);
        assert_eq!(ptr_before, ptr_after);
    }

    #[test]
    fn test_split_off_builder_not_byte_aligned() {
        let mut head = trues_then_falses(20, 10, 10);
        let tail = head.split_off(10);

        assert_eq!(head.len(), 10);
        assert_eq!(tail.len(), 10);
        assert_eq!(head.freeze().true_count(), 10);
        assert_eq!(tail.freeze().true_count(), 0);
    }

    #[test]
    fn test_split_off_builder_mixed_pattern() {
        // Alternating bits, starting with `true` at index 0.
        let mut head = MaskMut::with_capacity(15);
        (0..15).for_each(|i| head.append_n(i % 2 == 0, 1));

        let tail = head.split_off(7);
        assert_eq!(head.len(), 7);
        assert_eq!(tail.len(), 8);
        assert_eq!(head.freeze().true_count(), 4);
        assert_eq!(tail.freeze().true_count(), 4);
    }

    #[test]
    fn test_unsplit_empty_with_empty() {
        let mut target = MaskMut::with_capacity(10);
        target.unsplit(MaskMut::with_capacity(10));
        assert_eq!(target.len(), 0);
    }

    #[test]
    fn test_unsplit_empty_with_constant() {
        let mut target = MaskMut::with_capacity(10);
        target.unsplit(MaskMut::new_true(5));

        assert_eq!(target.len(), 5);
        assert_eq!(target.freeze().true_count(), 5);
    }

    #[test]
    fn test_unsplit_constant_with_constant_same() {
        let mut target = MaskMut::new_true(5);
        target.unsplit(MaskMut::new_true(5));

        assert_eq!(target.len(), 10);
        assert_eq!(target.freeze().true_count(), 10);
    }

    #[test]
    fn test_unsplit_constant_with_constant_different() {
        let mut target = MaskMut::new_true(5);
        target.unsplit(MaskMut::new_false(5));

        assert_eq!(target.len(), 10);
        assert_eq!(target.freeze().true_count(), 5);
    }

    #[test]
    fn test_unsplit_constant_with_builder() {
        let mut target = MaskMut::new_true(5);
        let addition = trues_then_falses(10, 3, 2);
        target.unsplit(addition);

        assert_eq!(target.len(), 10);
        // 5 constant trues plus 3 trues from the builder.
        assert_eq!(target.freeze().true_count(), 8);
    }

    #[test]
    fn test_unsplit_builder_with_constant() {
        let mut target = trues_then_falses(10, 3, 2);
        target.unsplit(MaskMut::new_true(5));

        assert_eq!(target.len(), 10);
        // 3 trues from the builder plus 5 constant trues.
        assert_eq!(target.freeze().true_count(), 8);
    }

    #[test]
    fn test_unsplit_builder_with_builder() {
        let mut target = trues_then_falses(10, 3, 2);

        let mut addition = MaskMut::with_capacity(10);
        addition.append_n(false, 3);
        addition.append_n(true, 2);
        target.unsplit(addition);

        assert_eq!(target.len(), 10);
        // 3 trues from the first builder plus 2 from the second.
        assert_eq!(target.freeze().true_count(), 5);
    }

    #[test]
    fn test_round_trip_split_unsplit() {
        let frozen_original = trues_then_falses(20, 10, 10).freeze();
        let expected_true_count = frozen_original.true_count();

        let mut head = frozen_original.try_into_mut().unwrap();
        let tail = head.split_off(10);
        head.unsplit(tail);

        assert_eq!(head.len(), 20);
        assert_eq!(head.freeze().true_count(), expected_true_count);
    }

    #[test]
    #[should_panic(expected = "split_off index out of bounds")]
    fn test_split_off_out_of_bounds() {
        let mut mask = MaskMut::new_true(10);
        mask.split_off(11);
    }

    #[test]
    fn test_split_off_builder_at_bit_1() {
        let mut head = MaskMut::with_capacity(16);
        head.append_n(true, 16);

        let tail = head.split_off(1);
        assert_eq!(head.len(), 1);
        assert_eq!(tail.len(), 15);
        assert_eq!(head.freeze().true_count(), 1);
        assert_eq!(tail.freeze().true_count(), 15);
    }

    #[test]
    fn test_multiple_split_unsplit() {
        let mut first = MaskMut::new_true(30);
        // Peel the pieces off from the back so each one ends up with 10 bits.
        let third = first.split_off(20);
        let second = first.split_off(10);

        assert_eq!(first.len(), 10);
        assert_eq!(second.len(), 10);
        assert_eq!(third.len(), 10);

        first.unsplit(second);
        first.unsplit(third);
        assert_eq!(first.len(), 30);
        assert_eq!(first.freeze().true_count(), 30);
    }

    #[test]
    fn test_try_into_mut_all_variants() {
        let from_all_true = Mask::new_true(100).try_into_mut().unwrap();
        assert_eq!(from_all_true.len(), 100);
        assert_eq!(from_all_true.freeze().true_count(), 100);

        let from_all_false = Mask::new_false(50).try_into_mut().unwrap();
        assert_eq!(from_all_false.len(), 50);
        assert_eq!(from_all_false.freeze().true_count(), 0);
    }

    #[test]
    fn test_try_into_mut_with_references() {
        let shared = trues_then_falses(10, 5, 5).freeze();
        let unique = trues_then_falses(10, 5, 5).freeze();

        // A mask without other handles converts successfully.
        let converted = unique.try_into_mut();
        assert!(converted.is_ok());
        assert_eq!(converted.unwrap().len(), 10);

        // Keeping a clone alive makes the conversion fail and hands the mask back.
        let _cloned = shared.clone();
        let refused = shared.try_into_mut();
        assert!(refused.is_err());
        if let Err(returned_mask) = refused {
            assert_eq!(returned_mask.len(), 10);
            assert_eq!(returned_mask.true_count(), 5);
        }
    }

    #[test]
    fn test_try_into_mut_round_trip() {
        let frozen = trues_then_falses(20, 10, 10).freeze();
        assert_eq!(frozen.true_count(), 10);

        let mut thawed = frozen.try_into_mut().unwrap();
        thawed.append_n(true, 5);
        assert_eq!(thawed.len(), 25);
        assert_eq!(thawed.freeze().true_count(), 15);
    }
}