use num_traits::AsPrimitive;
use vortex_buffer::BitBuffer;
use vortex_buffer::BitBufferMut;
use vortex_error::VortexResult;
use vortex_error::vortex_ensure;
use vortex_mask::Mask;
use super::super::Interleave;
use super::super::InterleaveArrayExt;
use crate::array::Array;
use crate::arrays::Bool;
use crate::arrays::BoolArray;
use crate::arrays::Primitive;
use crate::arrays::bool::BoolArrayExt;
use crate::executor::ExecutionCtx;
use crate::executor::ExecutionResult;
use crate::match_each_unsigned_integer_ptype;
use crate::require_child;
use crate::validity::Validity;
pub(super) fn execute(
array: Array<Interleave>,
ctx: &mut ExecutionCtx,
) -> VortexResult<ExecutionResult> {
let num_values = array.num_values();
let mut array = array;
for i in 0..num_values {
array = require_child!(array, array.value(i), i => Bool);
}
array = require_child!(array, array.array_indices(), num_values => Primitive);
array = require_child!(array, array.row_indices(), num_values + 1 => Primitive);
let dtype = array.as_ref().dtype().clone();
let len = array.as_ref().len();
let nullable = dtype.is_nullable();
let mut value_bits = Vec::with_capacity(num_values);
let mut value_validity = Vec::with_capacity(num_values);
for i in 0..num_values {
let value = array.value(i).as_::<Bool>();
let bits = value.to_bit_buffer();
let validity = nullable
.then(|| value.validity()?.execute_mask(bits.len(), ctx))
.transpose()?;
value_bits.push(bits);
value_validity.push(validity);
}
let array_indices = array.array_indices().as_::<Primitive>();
let row_indices = array.row_indices().as_::<Primitive>();
let (values, validity) = match_each_unsigned_integer_ptype!(array_indices.ptype(), |A| {
match_each_unsigned_integer_ptype!(row_indices.ptype(), |R| {
gather(
len,
num_values,
&value_bits,
&value_validity,
array_indices.as_slice::<A>(),
row_indices.as_slice::<R>(),
nullable,
)?
})
});
let validity = match validity {
Some(bits) => Validity::from(bits.freeze()),
None => Validity::NonNullable,
};
Ok(ExecutionResult::done(BoolArray::try_new(
values.freeze(),
validity,
)?))
}
/// Gathers `len` booleans from the per-value bit buffers.
///
/// Output position `i` reads row `rows[i]` of value `branches[i]`. When
/// `nullable` is true, it also gathers validity bits for the same positions.
/// A value whose entry in `value_validity` is `None` is treated as all-valid.
///
/// Returns the gathered value bits and, only when `nullable` is true, the
/// gathered validity bits.
///
/// # Errors
///
/// Returns an error in any of these cases:
/// - `branches` or `rows` holds fewer than `len` entries;
/// - any `branches[i]` is not below `num_values`;
/// - any `rows[i]` is out of range for the value it selects.
// The kernel takes pre-extracted buffers rather than the array so it can be
// monomorphised per index ptype, which is what inflates the argument count.
#[allow(clippy::too_many_arguments)]
fn gather<A: AsPrimitive<usize>, R: AsPrimitive<usize>>(
    len: usize,
    num_values: usize,
    value_bits: &[BitBuffer],
    value_validity: &[Option<Mask>],
    branches: &[A],
    rows: &[R],
    nullable: bool,
) -> VortexResult<(BitBufferMut, Option<BitBufferMut>)> {
    // A short index slice would otherwise panic on indexing. Report it as an
    // error, like the other malformed-index cases below.
    vortex_ensure!(
        branches.len() >= len && rows.len() >= len,
        "interleave index arrays shorter than array length"
    );
    let branches = &branches[..len];
    let rows = &rows[..len];

    // Validate every index up front. The infallible `collect_bool` closures
    // below can then index without further checks.
    for (branch, row) in branches.iter().copied().zip(rows.iter().copied()) {
        let branch: usize = branch.as_();
        vortex_ensure!(branch < num_values, "interleave array index out of bounds");
        let row: usize = row.as_();
        vortex_ensure!(
            row < value_bits[branch].len(),
            "interleave row index out of bounds"
        );
    }

    let values =
        BitBufferMut::collect_bool(len, |i| value_bits[branches[i].as_()].value(rows[i].as_()));
    let validity = nullable.then(|| {
        BitBufferMut::collect_bool(len, |i| {
            value_validity[branches[i].as_()]
                .as_ref()
                .is_none_or(|mask| mask.value(rows[i].as_()))
        })
    });
    Ok((values, validity))
}