use std::mem::MaybeUninit;
use vortex_buffer::BufferMut;
use vortex_error::VortexExpect;
use vortex_error::VortexResult;
use vortex_error::vortex_bail;
use vortex_mask::Mask;
use crate::ArrayRef;
use crate::ExecutionCtx;
use crate::IntoArray;
use crate::array::ArrayView;
use crate::arrays::Primitive;
use crate::arrays::PrimitiveArray;
use crate::dtype::NativePType;
use crate::match_each_native_ptype;
use crate::scalar_fn::fns::zip::ZipKernel;
use crate::scalar_fn::fns::zip::zip_validity;
impl ZipKernel for Primitive {
    /// Selects between two primitive arrays element-wise using a boolean mask.
    ///
    /// Where the mask is `true`, the output takes the value from `if_true`. Everywhere
    /// else it takes the value from `if_false`; null mask entries count as `false` because
    /// the mask is converted with `try_to_mask_fill_null_false`. The output validity comes
    /// from `zip_validity` over both inputs' validities and the same mask.
    ///
    /// Returns `Ok(None)`, meaning this kernel does not handle the inputs, in two cases:
    /// when `if_false` is not a primitive array, or when the mask is entirely true or
    /// entirely false.
    ///
    /// # Errors
    ///
    /// Fails if the two inputs have different primitive types. Also propagates any error
    /// from converting the mask or reading either input's validity.
    fn zip(
        if_true: ArrayView<'_, Primitive>,
        if_false: &ArrayRef,
        mask: &ArrayRef,
        ctx: &mut ExecutionCtx,
    ) -> VortexResult<Option<ArrayRef>> {
        // This kernel only handles the primitive/primitive pairing.
        let Some(if_false) = if_false.as_opt::<Primitive>() else {
            return Ok(None);
        };

        let ptype = if_true.ptype();
        if ptype != if_false.ptype() {
            vortex_bail!(
                "zip requires if_true and if_false to share a primitive type, got {} and {}",
                ptype,
                if_false.ptype()
            );
        }

        // Null mask entries resolve to `false`, so they select from `if_false`.
        let mask = mask.try_to_mask_fill_null_false(ctx)?;

        // Uniform (all-true / all-false) masks are declined; only a mixed selection is
        // handled by the bit-by-bit path below.
        let Mask::Values(_) = &mask else {
            return Ok(None);
        };

        let validity = zip_validity(if_true.validity()?, if_false.validity()?, &mask)?;
        let zipped = match_each_native_ptype!(ptype, |T| {
            let selected = select_values::<T>(
                if_true.as_slice::<T>(),
                if_false.as_slice::<T>(),
                &mask,
            );
            PrimitiveArray::new(selected.freeze(), validity).into_array()
        });
        Ok(Some(zipped))
    }
}
/// Builds a buffer holding, at each position `i`, `true_values[i]` if bit `i` of `mask`
/// is set and `false_values[i]` otherwise.
///
/// `mask` must be a [`Mask::Values`]. Its bit buffer is consumed one 64-bit word at a
/// time, and any trailing partial word is handled last.
///
/// # Panics
///
/// Panics if `false_values` or `mask` has a different length than `true_values`, or if
/// `mask` is not [`Mask::Values`].
fn select_values<T: NativePType>(
    true_values: &[T],
    false_values: &[T],
    mask: &Mask,
) -> BufferMut<T> {
    let len = true_values.len();
    // These checks are required for the `set_len` below to be sound. The loops write one
    // output slot per mask bit, so a mask shorter than `len` would leave uninitialized
    // elements in the returned buffer. Mismatches with a longer input would otherwise only
    // surface as an out-of-bounds slice panic partway through.
    assert_eq!(
        false_values.len(),
        len,
        "zip: if_true and if_false must have the same length"
    );
    assert_eq!(
        mask.len(),
        len,
        "zip: mask length must match the input length"
    );

    let mut out = BufferMut::<T>::with_capacity(len);
    {
        // Scoped so the mutable borrow of `out` ends before `set_len`.
        let out_slice = out.spare_capacity_mut();
        let mask_bits = mask
            .values()
            .vortex_expect("mask is Mask::Values")
            .bit_buffer();
        let chunks = mask_bits.chunks();
        let mut base = 0;
        // Full 64-bit words: each word selects exactly 64 consecutive values.
        for word in chunks.iter() {
            let end = base + 64;
            select_block(
                word,
                &true_values[base..end],
                &false_values[base..end],
                &mut out_slice[base..end],
            );
            base = end;
        }
        // Trailing partial word, if the length is not a multiple of 64.
        let remainder = chunks.remainder_len();
        if remainder > 0 {
            let end = base + remainder;
            select_block(
                chunks.remainder_bits(),
                &true_values[base..end],
                &false_values[base..end],
                &mut out_slice[base..end],
            );
        }
    }
    // SAFETY: `mask.len() == len` was asserted above. The loops over the mask's full words
    // and its remainder wrote one slot per mask bit into `out[0..len]`, which relies on
    // `chunks()` covering the bit buffer's full length. So every element up to `len` is
    // initialized, and `with_capacity(len)` guarantees the capacity.
    unsafe { out.set_len(len) };
    out
}
/// Fills `out` with one block of selected values driven by the bits of `word`.
///
/// Bit `k` of `word`, counting from the least-significant bit, picks `true_values[k]` when
/// set and `false_values[k]` when clear. Only the low `out.len()` bits are consulted. The
/// callers pass blocks of at most 64 slots (one mask word).
///
/// # Panics
///
/// Panics if either value slice is shorter than `out`.
#[inline]
fn select_block<T: NativePType>(
    word: u64,
    true_values: &[T],
    false_values: &[T],
    out: &mut [MaybeUninit<T>],
) {
    let width = out.len();
    // Trim both inputs to the output width up front. This panics early on short inputs and
    // leaves three equal-length slices for the loop, so it needs no per-element bounds checks.
    let (true_values, false_values) = (&true_values[..width], &false_values[..width]);
    let candidates = true_values.iter().zip(false_values);
    for (bit, (slot, (&on_true, &on_false))) in out.iter_mut().zip(candidates).enumerate() {
        let take_true = (word >> bit) & 1 != 0;
        slot.write(if take_true { on_true } else { on_false });
    }
}
#[cfg(test)]
mod tests {
    #![allow(
        clippy::cast_possible_truncation,
        reason = "test fixtures use small indices that fit the target widths"
    )]
    use vortex_error::VortexResult;
    use vortex_mask::Mask;
    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::VortexSessionExecute;
    use crate::arrays::Primitive;
    use crate::arrays::PrimitiveArray;
    use crate::assert_arrays_eq;
    use crate::builtins::ArrayBuiltins;

    /// Builds a mask array from `selection`, zips `if_true` / `if_false` through it, and
    /// executes the result with a fresh execution context.
    fn execute_zip(
        selection: &[bool],
        if_true: ArrayRef,
        if_false: ArrayRef,
    ) -> VortexResult<ArrayRef> {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        Mask::from_iter(selection.iter().copied())
            .into_array()
            .zip(if_true, if_false)?
            .execute::<ArrayRef>(&mut ctx)
    }

    #[test]
    fn zip_nonnull_spans_mask_chunks() -> VortexResult<()> {
        // 150 rows = two full 64-bit mask words plus a 22-bit tail. Row 64 (the first bit
        // of the second word) is forced on, since it is not a multiple of 3.
        const LEN: usize = 150;
        let picks: Vec<bool> = (0..LEN)
            .map(|row| row == 64 || row.is_multiple_of(3))
            .collect();

        let if_true = PrimitiveArray::from_iter(0..LEN as i64).into_array();
        let if_false =
            PrimitiveArray::from_iter((0..LEN as i64).map(|v| v + 1_000)).into_array();

        let result = execute_zip(&picks, if_true, if_false)?;
        assert!(result.is::<Primitive>());

        let expected = PrimitiveArray::from_iter(picks.iter().enumerate().map(|(row, &pick)| {
            let row = row as i64;
            if pick { row } else { row + 1_000 }
        }))
        .into_array();
        assert_arrays_eq!(result, expected);
        Ok(())
    }

    #[test]
    fn zip_nullable_selects_values_and_validity() -> VortexResult<()> {
        // 130 rows = two full mask words plus a 2-bit tail; both sides contain nulls.
        const LEN: usize = 130;
        let picks: Vec<bool> = (0..LEN).map(|row| row.is_multiple_of(2)).collect();

        let if_true = PrimitiveArray::from_option_iter(
            (0..LEN as i64).map(|v| (v % 4 != 0).then_some(v)),
        )
        .into_array();
        let if_false = PrimitiveArray::from_option_iter(
            (0..LEN as i64).map(|v| (v % 5 != 0).then_some(v + 1_000)),
        )
        .into_array();

        let result = execute_zip(&picks, if_true, if_false)?;
        assert!(result.is::<Primitive>());

        let expected =
            PrimitiveArray::from_option_iter(picks.iter().enumerate().map(|(row, &pick)| {
                let value = row as i64;
                if pick {
                    (!row.is_multiple_of(4)).then_some(value)
                } else {
                    (!row.is_multiple_of(5)).then_some(value + 1_000)
                }
            }))
            .into_array();
        assert_arrays_eq!(result, expected);
        Ok(())
    }
}