use itertools::Itertools;
use vortex_array::ArrayRef;
use vortex_array::IntoArray;
use vortex_array::LEGACY_SESSION;
use vortex_array::VortexSessionExecute;
use vortex_array::arrays::BoolArray;
use vortex_array::arrays::ConstantArray;
use vortex_array::arrays::PrimitiveArray;
use vortex_array::arrays::bool::BoolArrayExt;
use vortex_array::dtype::DType;
use vortex_array::dtype::Nullability;
use vortex_array::match_each_unsigned_integer_ptype;
use vortex_array::scalar::Scalar;
use vortex_array::validity::Validity;
use vortex_buffer::BitBuffer;
use vortex_buffer::BitBufferMut;
use vortex_error::VortexResult;
use vortex_mask::Mask;
use crate::iter::trimmed_ends_iter;
/// Run-count cutoff shared by the decoders in this file.
///
/// With fewer runs than this, the output is built by appending each run in order.
/// Otherwise the output is pre-filled with the value held by the majority of runs and
/// only the runs that differ from that pre-fill are overwritten.
const PREFILL_RUN_THRESHOLD: usize = 32;
/// Decodes a run-end encoded boolean array into a flat boolean array.
///
/// `ends` holds the (unsigned integer) run ends and `values` holds one boolean per run,
/// together with that run's validity. `offset` and `length` describe the logical window
/// to decode.
///
/// When `offset` is zero and there are fewer than [`PREFILL_RUN_THRESHOLD`] runs, the raw
/// run ends are handed straight to [`decode_few_runs_no_offset`]. In every other case the
/// ends are first passed through [`trimmed_ends_iter`] (presumably to apply `offset` and
/// `length`; confirm against its definition) and decoded by [`runend_decode_typed_bool`].
///
/// # Errors
///
/// Returns an error if the validity of `values` cannot be obtained or turned into a mask.
pub fn runend_decode_bools(
    ends: PrimitiveArray,
    values: BoolArray,
    offset: usize,
    length: usize,
) -> VortexResult<ArrayRef> {
    // Per-run validity, materialized as a mask so the decoders can branch on its shape.
    let validity_mask = values.as_ref().validity()?.to_mask(
        values.as_ref().len(),
        &mut LEGACY_SESSION.create_execution_ctx(),
    )?;
    let run_bits = values.to_bit_buffer();
    let run_nullability = values.dtype().nullability();

    // The fast path reads the raw ends directly, which is only done when there is no offset.
    let take_fast_path = offset == 0 && run_bits.len() < PREFILL_RUN_THRESHOLD;

    let decoded = match_each_unsigned_integer_ptype!(ends.ptype(), |E| {
        let typed_ends = ends.as_slice::<E>();
        if take_fast_path {
            decode_few_runs_no_offset(
                typed_ends,
                &run_bits,
                validity_mask,
                run_nullability,
                length,
            )
        } else {
            runend_decode_typed_bool(
                trimmed_ends_iter(typed_ends, offset, length),
                &run_bits,
                validity_mask,
                run_nullability,
                length,
            )
        }
    });

    Ok(decoded)
}
/// Decodes boolean runs whose ends are supplied as an iterator of `usize` positions.
///
/// The work is chosen by the shape of `values_validity`:
///
/// - every run valid: only the value bits are decoded and the result carries
///   `values_nullability`;
/// - every run null: a constant null array of nullable boolean type with `length`
///   elements is returned without touching `run_ends`;
/// - mixed: both the value bits and a per-element validity buffer are decoded.
///
/// NOTE(review): the decoders this calls subtract consecutive run ends, so `run_ends`
/// appears to be expected to yield non-decreasing positions that do not exceed `length`.
pub fn runend_decode_typed_bool(
    run_ends: impl Iterator<Item = usize>,
    values: &BitBuffer,
    values_validity: Mask,
    values_nullability: Nullability,
    length: usize,
) -> ArrayRef {
    // Handle the two trivial validity shapes first; only the mixed case needs the mask.
    let run_validity = match values_validity {
        Mask::AllFalse(_) => {
            let null_bool = Scalar::null(DType::Bool(Nullability::Nullable));
            return ConstantArray::new(null_bool, length).into_array();
        }
        Mask::AllTrue(_) => {
            let decoded = decode_bool_non_nullable(run_ends, values, values_nullability, length);
            return decoded.into_array();
        }
        Mask::Values(mask) => mask,
    };

    decode_bool_nullable(run_ends, values, run_validity.bit_buffer(), length).into_array()
}
/// Fast path for a small number of runs and no offset.
///
/// Walks the raw `ends` slice, clamps each end to `length`, and appends every run to the
/// output in order. `values` supplies one bit per run, looked up by run index.
///
/// - All runs valid: returns a boolean array whose validity is derived from `nullability`.
/// - All runs null: returns a constant null array of nullable boolean type.
/// - Mixed: returns a boolean array with an explicit validity buffer; null runs decode
///   to `false` bits.
#[inline(always)]
fn decode_few_runs_no_offset<E: vortex_array::dtype::IntegerPType>(
    ends: &[E],
    values: &BitBuffer,
    validity: Mask,
    nullability: Nullability,
    length: usize,
) -> ArrayRef {
    match validity {
        Mask::AllFalse(_) => {
            let null_bool = Scalar::null(DType::Bool(Nullability::Nullable));
            ConstantArray::new(null_bool, length).into_array()
        }
        Mask::AllTrue(_) => {
            let mut bits_out = BitBufferMut::with_capacity(length);
            let mut run_start = 0usize;
            for (run_idx, &raw_end) in ends.iter().enumerate() {
                // Clamp so that runs extending past the requested window are cut off.
                let run_end = raw_end.as_().min(length);
                bits_out.append_n(values.value(run_idx), run_end - run_start);
                run_start = run_end;
            }
            BoolArray::new(bits_out.freeze(), nullability.into()).into_array()
        }
        Mask::Values(mask) => {
            let run_is_valid = mask.bit_buffer();
            let mut bits_out = BitBufferMut::with_capacity(length);
            let mut validity_out = BitBufferMut::with_capacity(length);
            let mut run_start = 0usize;
            for (run_idx, &raw_end) in ends.iter().enumerate() {
                let run_end = raw_end.as_().min(length);
                let run_len = run_end - run_start;
                let is_valid = run_is_valid.value(run_idx);
                validity_out.append_n(is_valid, run_len);
                // Null runs always decode to `false`; the stored value is not consulted.
                bits_out.append_n(is_valid && values.value(run_idx), run_len);
                run_start = run_end;
            }
            BoolArray::new(bits_out.freeze(), Validity::from(validity_out.freeze())).into_array()
        }
    }
}
/// Decodes runs that are all valid into a [`BoolArray`].
///
/// `run_ends` yields the exclusive end position of each run and `values` holds one bit
/// per run. The resulting array's validity is derived from `nullability`.
///
/// Two strategies are used, selected by the number of runs:
///
/// - fewer than [`PREFILL_RUN_THRESHOLD`] runs: each run is appended to the output in
///   order;
/// - otherwise: the output is created with `length` bits all set to the value held by
///   the majority of runs, and only runs holding the other value are overwritten.
///
/// # Panics
///
/// In the pre-fill strategy, panics if `run_ends` and `values` yield a different number
/// of items (`zip_eq`), or if a run that has to be written ends beyond `length`.
fn decode_bool_non_nullable(
    run_ends: impl Iterator<Item = usize>,
    values: &BitBuffer,
    nullability: Nullability,
    length: usize,
) -> BoolArray {
    let num_runs = values.len();

    if num_runs < PREFILL_RUN_THRESHOLD {
        let mut decoded = BitBufferMut::with_capacity(length);
        for (end, value) in run_ends.zip(values.iter()) {
            // `decoded.len()` is the end of the previous run.
            decoded.append_n(value, end - decoded.len());
        }
        return BoolArray::new(decoded.freeze(), nullability.into());
    }

    // Pre-fill with the more common run value so that fewer ranges need rewriting.
    let true_runs = values.true_count();
    let prefill = true_runs > num_runs - true_runs;
    let mut decoded = BitBufferMut::full(prefill, length);
    let mut current_pos = 0usize;

    for (end, value) in run_ends.zip_eq(values.iter()) {
        if end > current_pos && value != prefill {
            // This function is reachable from a safe public API with an arbitrary
            // iterator, so the upper bound must be enforced before the unchecked write.
            assert!(
                end <= length,
                "run end {} exceeds decoded length {}",
                end,
                length
            );
            // SAFETY: `current_pos < end` was checked by the enclosing `if`, `end <= length`
            // by the assertion above, and `decoded` was created with exactly `length` bits.
            // NOTE(review): assumes these are the only preconditions of
            // `fill_range_unchecked`; confirm against its safety documentation.
            unsafe { decoded.fill_range_unchecked(current_pos, end, value) };
        }
        current_pos = end;
    }

    BoolArray::new(decoded.freeze(), nullability.into())
}
/// Decodes runs with mixed validity into a [`BoolArray`] with an explicit validity buffer.
///
/// `run_ends` yields the exclusive end position of each run; `values` and
/// `validity_mask` each hold one bit per run. Elements of a null run get a `false`
/// validity bit and a `false` value bit.
///
/// With fewer than [`PREFILL_RUN_THRESHOLD`] runs the work is delegated to
/// [`decode_nullable_sequential`]. Otherwise both output buffers are created with
/// `length` bits set to the majority run value / majority run validity, and only the
/// runs that differ from the pre-fill are overwritten.
///
/// # Panics
///
/// In the pre-fill strategy, panics if `run_ends` yields a different number of items
/// than `values` zipped with `validity_mask` (`zip_eq`), or if a non-empty run ends
/// beyond `length`.
fn decode_bool_nullable(
    run_ends: impl Iterator<Item = usize>,
    values: &BitBuffer,
    validity_mask: &BitBuffer,
    length: usize,
) -> BoolArray {
    let num_runs = values.len();
    if num_runs < PREFILL_RUN_THRESHOLD {
        return decode_nullable_sequential(run_ends, values, validity_mask, length);
    }

    // Choose each pre-fill from the more common per-run bit so fewer ranges are rewritten.
    let true_runs = values.true_count();
    let valid_runs = validity_mask.true_count();
    let prefill_decoded = true_runs > num_runs - true_runs;
    let prefill_valid = valid_runs > num_runs - valid_runs;

    let mut decoded = BitBufferMut::full(prefill_decoded, length);
    let mut decoded_validity = BitBufferMut::full(prefill_valid, length);
    let mut current_pos = 0usize;

    for (end, (value, is_valid)) in run_ends.zip_eq(values.iter().zip(validity_mask.iter())) {
        if end > current_pos {
            // This function is reachable from a safe public API with an arbitrary
            // iterator, so the upper bound must be enforced before the unchecked writes.
            assert!(
                end <= length,
                "run end {} exceeds decoded length {}",
                end,
                length
            );
            if is_valid != prefill_valid {
                // SAFETY: `current_pos < end <= length` holds (enclosing `if` and the
                // assertion above) and `decoded_validity` has exactly `length` bits.
                // NOTE(review): assumes these are the only preconditions of
                // `fill_range_unchecked`; confirm against its safety documentation.
                unsafe { decoded_validity.fill_range_unchecked(current_pos, end, is_valid) };
            }
            // The value bit is the stored value for valid runs and `false` for null runs.
            let want_decoded = is_valid && value;
            if want_decoded != prefill_decoded {
                // SAFETY: same bounds as above; `decoded` also has exactly `length` bits.
                unsafe { decoded.fill_range_unchecked(current_pos, end, want_decoded) };
            }
            current_pos = end;
        }
    }

    BoolArray::new(decoded.freeze(), Validity::from(decoded_validity.freeze()))
}
/// Sequentially decodes runs with mixed validity, appending one run at a time.
///
/// For each `(end, value, is_valid)` triple the run length is the distance from the
/// current output length to `end`. Valid runs append `value` with `true` validity bits;
/// null runs append `false` for both the value and the validity bits.
///
/// Iteration stops as soon as any of `run_ends`, `values` or `validity_mask` is exhausted.
#[inline(always)]
fn decode_nullable_sequential(
    run_ends: impl Iterator<Item = usize>,
    values: &BitBuffer,
    validity_mask: &BitBuffer,
    length: usize,
) -> BoolArray {
    let mut bits_out = BitBufferMut::with_capacity(length);
    let mut validity_out = BitBufferMut::with_capacity(length);

    let per_run_bits = values.iter().zip(validity_mask.iter());
    for (run_end, (run_value, run_is_valid)) in run_ends.zip(per_run_bits) {
        // The output length so far is where the previous run ended.
        let run_len = run_end - bits_out.len();
        validity_out.append_n(run_is_valid, run_len);
        bits_out.append_n(run_is_valid && run_value, run_len);
    }

    BoolArray::new(bits_out.freeze(), Validity::from(validity_out.freeze()))
}
#[cfg(test)]
mod tests {
    use vortex_array::LEGACY_SESSION;
    use vortex_array::ToCanonical;
    use vortex_array::VortexSessionExecute;
    use vortex_array::arrays::BoolArray;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::arrays::bool::BoolArrayExt;
    use vortex_array::assert_arrays_eq;
    use vortex_array::validity::Validity;
    use vortex_buffer::BitBuffer;
    use vortex_error::VortexResult;

    use super::runend_decode_bools;

    /// Three runs with alternating values, decoded over the full range.
    #[test]
    fn decode_bools_alternating() -> VortexResult<()> {
        let ends = PrimitiveArray::from_iter([2u32, 5, 10]);
        let values = BoolArray::from(BitBuffer::from(vec![true, false, true]));
        let decoded = runend_decode_bools(ends, values, 0, 10)?;
        let expected = BoolArray::from(BitBuffer::from(vec![
            true, true, false, false, false, true, true, true, true, true,
        ]));
        assert_arrays_eq!(decoded, expected);
        Ok(())
    }

    /// A single short `false` run surrounded by `true` runs.
    #[test]
    fn decode_bools_mostly_true() -> VortexResult<()> {
        let ends = PrimitiveArray::from_iter([5u32, 6, 10]);
        let values = BoolArray::from(BitBuffer::from(vec![true, false, true]));
        let decoded = runend_decode_bools(ends, values, 0, 10)?;
        let expected = BoolArray::from(BitBuffer::from(vec![
            true, true, true, true, true, false, true, true, true, true,
        ]));
        assert_arrays_eq!(decoded, expected);
        Ok(())
    }

    /// A single short `true` run surrounded by `false` runs.
    #[test]
    fn decode_bools_mostly_false() -> VortexResult<()> {
        let ends = PrimitiveArray::from_iter([5u32, 6, 10]);
        let values = BoolArray::from(BitBuffer::from(vec![false, true, false]));
        let decoded = runend_decode_bools(ends, values, 0, 10)?;
        let expected = BoolArray::from(BitBuffer::from(vec![
            false, false, false, false, false, true, false, false, false, false,
        ]));
        assert_arrays_eq!(decoded, expected);
        Ok(())
    }

    #[test]
    fn decode_bools_all_true_single_run() -> VortexResult<()> {
        let ends = PrimitiveArray::from_iter([10u32]);
        let values = BoolArray::from(BitBuffer::from(vec![true]));
        let decoded = runend_decode_bools(ends, values, 0, 10)?;
        let expected = BoolArray::from(BitBuffer::from(vec![
            true, true, true, true, true, true, true, true, true, true,
        ]));
        assert_arrays_eq!(decoded, expected);
        Ok(())
    }

    #[test]
    fn decode_bools_all_false_single_run() -> VortexResult<()> {
        let ends = PrimitiveArray::from_iter([10u32]);
        let values = BoolArray::from(BitBuffer::from(vec![false]));
        let decoded = runend_decode_bools(ends, values, 0, 10)?;
        let expected = BoolArray::from(BitBuffer::from(vec![
            false, false, false, false, false, false, false, false, false, false,
        ]));
        assert_arrays_eq!(decoded, expected);
        Ok(())
    }

    /// A non-zero offset bypasses the no-offset fast path.
    #[test]
    fn decode_bools_with_offset() -> VortexResult<()> {
        let ends = PrimitiveArray::from_iter([2u32, 5, 10]);
        let values = BoolArray::from(BitBuffer::from(vec![true, false, true]));
        let decoded = runend_decode_bools(ends, values, 2, 6)?;
        let expected =
            BoolArray::from(BitBuffer::from(vec![false, false, false, true, true, true]));
        assert_arrays_eq!(decoded, expected);
        Ok(())
    }

    #[test]
    fn decode_bools_nullable() -> VortexResult<()> {
        let ends = PrimitiveArray::from_iter([2u32, 5, 10]);
        let values = BoolArray::new(
            BitBuffer::from(vec![true, false, true]),
            Validity::from(BitBuffer::from(vec![true, false, true])),
        );
        let decoded = runend_decode_bools(ends, values, 0, 10)?;
        let expected = BoolArray::new(
            BitBuffer::from(vec![
                true, true, false, false, false, true, true, true, true, true,
            ]),
            Validity::from(BitBuffer::from(vec![
                true, true, false, false, false, true, true, true, true, true,
            ])),
        );
        assert_arrays_eq!(decoded, expected);
        Ok(())
    }

    #[test]
    fn decode_bools_nullable_few_runs() -> VortexResult<()> {
        let ends = PrimitiveArray::from_iter([2000u32, 4000, 6000, 8000, 10000]);
        let values = BoolArray::new(
            BitBuffer::from(vec![true, false, true, false, true]),
            Validity::from(BitBuffer::from(vec![true, false, true, false, true])),
        );
        let decoded = runend_decode_bools(ends, values, 0, 10000)?.to_bool();
        assert_eq!(decoded.len(), 10000);

        // Compute the validity mask and value bits once and probe them at run boundaries.
        let validity = decoded.as_ref().validity()?.to_mask(
            decoded.as_ref().len(),
            &mut LEGACY_SESSION.create_execution_ctx(),
        )?;
        let bits = decoded.to_bit_buffer();

        // First run: valid and true.
        assert!(validity.value(0));
        assert!(bits.value(0));
        // Second run: null.
        assert!(!validity.value(2000));
        // Third run: valid and true.
        assert!(validity.value(4000));
        assert!(bits.value(4000));
        Ok(())
    }

    /// 40 runs is above the pre-fill threshold, so this exercises the pre-fill decoder
    /// for fully valid values rather than the sequential one.
    #[test]
    fn decode_bools_many_runs_prefill() -> VortexResult<()> {
        let run_values: Vec<bool> = (0..40).map(|i| i % 3 != 0).collect();
        // Every run is two elements long: ends are 2, 4, ..., 80.
        let ends = PrimitiveArray::from_iter((1u32..=40).map(|i| i * 2));
        let values = BoolArray::from(BitBuffer::from(run_values.clone()));
        let decoded = runend_decode_bools(ends, values, 0, 80)?;

        let expected_bits: Vec<bool> = run_values.iter().flat_map(|&b| [b, b]).collect();
        let expected = BoolArray::from(BitBuffer::from(expected_bits));
        assert_arrays_eq!(decoded, expected);
        Ok(())
    }

    /// 40 runs with mixed validity exercise the pre-fill decoder for nullable values.
    #[test]
    fn decode_bools_nullable_many_runs_prefill() -> VortexResult<()> {
        let run_values: Vec<bool> = (0..40).map(|i| i % 2 == 0).collect();
        let run_valid: Vec<bool> = (0..40).map(|i| i % 5 != 0).collect();
        // Every run is two elements long: ends are 2, 4, ..., 80.
        let ends = PrimitiveArray::from_iter((1u32..=40).map(|i| i * 2));
        let values = BoolArray::new(
            BitBuffer::from(run_values.clone()),
            Validity::from(BitBuffer::from(run_valid.clone())),
        );
        let decoded = runend_decode_bools(ends, values, 0, 80)?;

        // Null runs decode to `false` value bits.
        let expected_bits: Vec<bool> = run_values
            .iter()
            .zip(run_valid.iter())
            .flat_map(|(&value, &valid)| [valid && value, valid && value])
            .collect();
        let expected_valid: Vec<bool> = run_valid.iter().flat_map(|&v| [v, v]).collect();
        let expected = BoolArray::new(
            BitBuffer::from(expected_bits),
            Validity::from(BitBuffer::from(expected_valid)),
        );
        assert_arrays_eq!(decoded, expected);
        Ok(())
    }
}