use std::ops::BitAnd;
use std::ops::BitOr;
use std::ops::Not;
use vortex_buffer::Buffer;
use vortex_buffer::BufferMut;
use vortex_error::VortexResult;
use vortex_error::vortex_ensure;
use vortex_mask::Mask;
use crate::ArrayRef;
use crate::ExecutionCtx;
use crate::IntoArray;
use crate::array::ArrayView;
use crate::arrays::Chunked;
use crate::arrays::ChunkedArray;
use crate::arrays::ListView;
use crate::arrays::ListViewArray;
use crate::arrays::chunked::ChunkedArrayExt;
use crate::arrays::listview::ListViewArrayExt;
use crate::builtins::ArrayBuiltins;
use crate::dtype::DType;
use crate::dtype::Nullability;
use crate::dtype::PType;
use crate::scalar_fn::fns::zip::ZipKernel;
use crate::validity::Validity;
impl ZipKernel for ListView {
fn zip(
if_true: ArrayView<'_, ListView>,
if_false: &ArrayRef,
mask: &ArrayRef,
ctx: &mut ExecutionCtx,
) -> VortexResult<Option<ArrayRef>> {
let Some(if_false) = if_false.as_opt::<ListView>() else {
return Ok(None);
};
let mask = mask.try_to_mask_fill_null_false(ctx)?;
match &mask {
Mask::AllTrue(_) | Mask::AllFalse(_) => return Ok(None),
Mask::Values(_) => {}
}
let len = if_true.len();
let result_elements_dtype = if_true
.elements()
.dtype()
.union_nullability(if_false.elements().dtype().nullability());
let true_elements = if_true.elements().cast(result_elements_dtype.clone())?;
let false_elements = if_false.elements().cast(result_elements_dtype.clone())?;
let false_shift = true_elements.len() as u64;
let mut chunks = Vec::with_capacity(2);
push_element_chunks(true_elements, &mut chunks);
push_element_chunks(false_elements, &mut chunks);
let elements = ChunkedArray::try_new(chunks, result_elements_dtype)?.into_array();
let true_offsets = to_u64(if_true.offsets(), ctx)?;
let true_sizes = to_u64(if_true.sizes(), ctx)?;
let false_offsets = to_u64(if_false.offsets(), ctx)?;
let false_sizes = to_u64(if_false.sizes(), ctx)?;
let mut offsets = BufferMut::<u64>::with_capacity(len);
let mut sizes = BufferMut::<u64>::with_capacity(len);
for ((idx, (out_offsets, out_sizes)), selected) in offsets
.spare_capacity_mut()
.iter_mut()
.zip(sizes.spare_capacity_mut().iter_mut())
.take(len)
.enumerate()
.zip(mask.iter())
{
if selected {
out_offsets.write(true_offsets[idx]);
out_sizes.write(true_sizes[idx]);
} else {
out_offsets.write(false_offsets[idx] + false_shift);
out_sizes.write(false_sizes[idx]);
}
}
unsafe {
offsets.set_len(len);
sizes.set_len(len);
}
let validity = zip_validity(if_true.validity()?, if_false.validity()?, &mask, ctx)?;
Ok(Some(
ListViewArray::try_new(
elements,
offsets.freeze().into_array(),
sizes.freeze().into_array(),
validity,
)?
.into_array(),
))
}
}
/// Appends `array` to `chunks`. If `array` is itself a chunked array, its individual chunks
/// are appended instead, so `chunks` never ends up holding a nested chunked array.
fn push_element_chunks(array: ArrayRef, chunks: &mut Vec<ArrayRef>) {
    if let Some(chunked) = array.as_opt::<Chunked>() {
        chunks.extend(chunked.iter_chunks().cloned());
        return;
    }
    chunks.push(array);
}
/// Casts `array` (an offsets or sizes child in this file) to non-nullable `u64` and executes
/// it into a contiguous [`Buffer<u64>`].
fn to_u64(array: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<Buffer<u64>> {
    let target = DType::Primitive(PType::U64, Nullability::NonNullable);
    let as_u64 = array.clone().cast(target)?;
    as_u64.execute::<Buffer<u64>>(ctx)
}
/// Selects per-row validity for a zip.
///
/// Row `i` takes `if_true`'s validity where `mask` is set at `i`, and `if_false`'s validity
/// otherwise. When both sides carry the same constant validity (`NonNullable`, `AllValid` or
/// `AllInvalid`), that validity is returned without materializing anything. Every other
/// pairing is expanded to per-row masks and blended. The blend is converted back with the
/// union of both sides' nullability.
fn zip_validity(
    if_true: Validity,
    if_false: Validity,
    mask: &Mask,
    ctx: &mut ExecutionCtx,
) -> VortexResult<Validity> {
    let shared_constant = match (&if_true, &if_false) {
        (Validity::NonNullable, Validity::NonNullable) => Some(Validity::NonNullable),
        (Validity::AllValid, Validity::AllValid) => Some(Validity::AllValid),
        (Validity::AllInvalid, Validity::AllInvalid) => Some(Validity::AllInvalid),
        _ => None,
    };
    if let Some(validity) = shared_constant {
        return Ok(validity);
    }

    let row_count = mask.len();
    let true_valid = if_true.execute_mask(row_count, ctx)?;
    let false_valid = if_false.execute_mask(row_count, ctx)?;
    // Keep `if_true`'s bits where the mask selects it and `if_false`'s bits everywhere else.
    let kept_from_true = true_valid.bitand(mask);
    let kept_from_false = false_valid.bitand(&mask.not());
    Ok(Validity::from_mask(
        kept_from_true.bitor(&kept_from_false),
        if_true.nullability() | if_false.nullability(),
    ))
}
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_error::VortexResult;
    use vortex_mask::Mask;
    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::VortexSessionExecute;
    use crate::arrays::BoolArray;
    use crate::arrays::Chunked;
    use crate::arrays::ChunkedArray;
    use crate::arrays::ListView;
    use crate::arrays::ListViewArray;
    use crate::arrays::chunked::ChunkedArrayExt;
    use crate::arrays::listview::ListViewArrayExt;
    use crate::assert_arrays_eq;
    use crate::builtins::ArrayBuiltins;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::validity::Validity;

    /// Builds a list-view array from its parts (panicking on invalid input) and erases it to
    /// an [`ArrayRef`].
    fn list_view(
        elements: ArrayRef,
        offsets: ArrayRef,
        sizes: ArrayRef,
        validity: Validity,
    ) -> ArrayRef {
        let array = ListViewArray::try_new(elements, offsets, sizes, validity).unwrap();
        array.into_array()
    }

    /// Evaluates `mask.zip(if_true, if_false)` to completion in a fresh legacy-session
    /// execution context.
    fn run_zip(mask: Mask, if_true: ArrayRef, if_false: ArrayRef) -> VortexResult<ArrayRef> {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let zipped = mask.into_array().zip(if_true, if_false)?;
        zipped.execute::<ArrayRef>(&mut ctx)
    }

    /// Rows come from `if_true` where the mask is set and from `if_false` elsewhere, and the
    /// result stays a list view.
    #[test]
    fn zip_selects_lists() -> VortexResult<()> {
        let if_true = list_view(
            buffer![1i32, 2, 3, 4, 5, 6].into_array(),
            buffer![0u32, 2, 3].into_array(),
            buffer![2u32, 1, 3].into_array(),
            Validity::NonNullable,
        );
        let if_false = list_view(
            buffer![10i32, 20, 21, 30].into_array(),
            buffer![0u32, 1, 3].into_array(),
            buffer![1u32, 2, 1].into_array(),
            Validity::NonNullable,
        );

        let result = run_zip(Mask::from_iter([true, false, true]), if_true, if_false)?;
        assert!(result.is::<ListView>());

        let expected = list_view(
            buffer![1i32, 2, 20, 21, 4, 5, 6].into_array(),
            buffer![0u32, 2, 4].into_array(),
            buffer![2u32, 2, 3].into_array(),
            Validity::NonNullable,
        );
        assert_arrays_eq!(result, expected);
        Ok(())
    }

    /// Each row's validity follows the branch the mask picked for that row.
    #[test]
    fn zip_selects_validity() -> VortexResult<()> {
        let if_true = list_view(
            buffer![1i32, 2].into_array(),
            buffer![0u32, 1, 1].into_array(),
            buffer![1u32, 0, 1].into_array(),
            Validity::Array(BoolArray::from_iter([true, false, true]).into_array()),
        );
        let if_false = list_view(
            buffer![10i32, 20].into_array(),
            buffer![0u32, 1, 2].into_array(),
            buffer![1u32, 1, 0].into_array(),
            Validity::Array(BoolArray::from_iter([true, true, false]).into_array()),
        );

        let result = run_zip(Mask::from_iter([false, true, true]), if_true, if_false)?;

        let expected = list_view(
            buffer![10i32, 2].into_array(),
            buffer![0u32, 1, 1].into_array(),
            buffer![1u32, 0, 1].into_array(),
            Validity::Array(BoolArray::from_iter([true, false, true]).into_array()),
        );
        assert_arrays_eq!(result, expected);
        Ok(())
    }

    /// Out-of-order offsets are honoured, and a nullable `if_false` widens the result's
    /// nullability.
    #[test]
    fn zip_out_of_order_offsets_and_widening() -> VortexResult<()> {
        let if_true = list_view(
            buffer![7i32, 8, 9, 5, 6].into_array(),
            buffer![3u32, 0, 1].into_array(),
            buffer![2u32, 1, 2].into_array(),
            Validity::NonNullable,
        );
        let if_false = list_view(
            buffer![100i32, 200, 201].into_array(),
            buffer![0u32, 1, 1].into_array(),
            buffer![1u32, 0, 2].into_array(),
            Validity::Array(BoolArray::from_iter([true, false, true]).into_array()),
        );

        let result = run_zip(Mask::from_iter([true, true, false]), if_true, if_false)?;
        assert!(result.is::<ListView>());

        let expected = list_view(
            buffer![5i32, 6, 7, 200, 201].into_array(),
            buffer![0u32, 2, 3].into_array(),
            buffer![2u32, 1, 2].into_array(),
            Validity::AllValid,
        );
        assert_arrays_eq!(result, expected);
        Ok(())
    }

    /// A chunked `if_true` element array is spliced into the result's chunk list rather than
    /// nested inside it.
    #[test]
    fn zip_flattens_chunked_elements() -> VortexResult<()> {
        let chunked_elements = ChunkedArray::try_new(
            vec![buffer![1i32, 2].into_array(), buffer![3i32].into_array()],
            DType::Primitive(PType::I32, Nullability::NonNullable),
        )?
        .into_array();
        let if_true = list_view(
            chunked_elements,
            buffer![0u32, 2].into_array(),
            buffer![2u32, 1].into_array(),
            Validity::NonNullable,
        );
        let if_false = list_view(
            buffer![10i32, 20].into_array(),
            buffer![0u32, 1].into_array(),
            buffer![1u32, 1].into_array(),
            Validity::NonNullable,
        );

        let result = run_zip(Mask::from_iter([true, false]), if_true, if_false)?;

        let result_lv = result
            .as_opt::<ListView>()
            .expect("zip keeps the list-view encoding");
        let result_elements = result_lv.elements();
        let chunked = result_elements
            .as_opt::<Chunked>()
            .expect("zip concatenates elements into a chunked array");
        assert!(
            !chunked.iter_chunks().any(|chunk| chunk.is::<Chunked>()),
            "chunked elements must be flattened, not nested",
        );

        let expected = list_view(
            buffer![1i32, 2, 20].into_array(),
            buffer![0u32, 2].into_array(),
            buffer![2u32, 1].into_array(),
            Validity::NonNullable,
        );
        assert_arrays_eq!(result, expected);
        Ok(())
    }
}