use vortex_buffer::BufferMut;
use vortex_error::VortexExpect;
use vortex_error::VortexResult;
use crate::ArrayRef;
use crate::Canonical;
use crate::DynArray;
use crate::ExecutionCtx;
use crate::IntoArray;
use crate::arrays::ListViewArray;
use crate::arrays::PrimitiveArray;
use crate::arrays::StructArray;
use crate::arrays::chunked::vtable::ChunkedArray;
use crate::arrays::listview::ListViewRebuildMode;
use crate::builders::builder_with_capacity;
use crate::builtins::ArrayBuiltins;
use crate::dtype::DType;
use crate::dtype::Nullability;
use crate::dtype::PType;
use crate::dtype::StructFields;
use crate::validity::Validity;
/// Canonicalizes a chunked array into a single [`Canonical`] value.
///
/// * An array with no chunks yields an empty canonical of the array's dtype.
/// * An array with exactly one chunk delegates to executing that chunk.
/// * Otherwise struct arrays are assembled by `pack_struct_chunks`, list arrays by
///   `swizzle_list_chunks`, and every other dtype is appended into a builder that was
///   created with capacity `array.len()`.
///
/// # Errors
///
/// Propagates any error from executing chunks, copying the validity, or appending to
/// the builder.
pub(super) fn _canonicalize(
    array: &ChunkedArray,
    ctx: &mut ExecutionCtx,
) -> VortexResult<Canonical> {
    // Trivial chunk counts need no concatenation work.
    match array.nchunks() {
        0 => return Ok(Canonical::empty(array.dtype())),
        1 => return array.chunks()[0].clone().execute::<Canonical>(ctx),
        _ => {}
    }

    match array.dtype() {
        DType::Struct(struct_dtype, _) => {
            // Validity is taken from the chunked array as a whole.
            let validity = Validity::copy_from_array(&array.clone().into_array())?;
            pack_struct_chunks(array.chunks(), validity, struct_dtype, ctx).map(Canonical::Struct)
        }
        DType::List(elem_dtype, _) => {
            let validity = Validity::copy_from_array(&array.clone().into_array())?;
            swizzle_list_chunks(array.chunks(), validity, elem_dtype, ctx).map(Canonical::List)
        }
        _ => {
            // Every remaining dtype goes through the generic builder path.
            let mut builder = builder_with_capacity(array.dtype(), array.len());
            array.append_to_builder(builder.as_mut(), ctx)?;
            Ok(builder.finish_into_canonical())
        }
    }
}
/// Packs struct-typed chunks into one [`StructArray`].
///
/// Every chunk is first executed into a `StructArray`. Then, for each field of
/// `struct_dtype`, the field arrays found at that position in every executed chunk are
/// wrapped in a `ChunkedArray`. The resulting struct has a length equal to the sum of
/// the chunk lengths and carries the caller-provided `validity`.
///
/// # Errors
///
/// Returns the first error raised while executing a chunk.
///
/// # Panics
///
/// Panics with "Invalid field index" if an executed chunk has fewer fields than
/// `struct_dtype`.
fn pack_struct_chunks(
    chunks: &[ArrayRef],
    validity: Validity,
    struct_dtype: &StructFields,
    ctx: &mut ExecutionCtx,
) -> VortexResult<StructArray> {
    let total_len: usize = chunks.iter().map(|chunk| chunk.len()).sum();

    // Execute each chunk in order, stopping at the first failure.
    let mut struct_chunks: Vec<StructArray> = Vec::with_capacity(chunks.len());
    for chunk in chunks {
        struct_chunks.push(chunk.clone().execute::<StructArray>(ctx)?);
    }

    // One chunked array per field, built from that field's array in every chunk.
    let packed_fields: Vec<_> = struct_dtype
        .fields()
        .enumerate()
        .map(|(field_idx, field_dtype)| {
            let per_chunk: Vec<_> = struct_chunks
                .iter()
                .map(|struct_chunk| {
                    struct_chunk
                        .unmasked_fields()
                        .get(field_idx)
                        .vortex_expect("Invalid field index")
                        .to_array()
                })
                .collect();
            // SAFETY (assumed — confirm against `ChunkedArray::new_unchecked`): each entry
            // is field `field_idx` of a chunk executed as a `StructArray`; presumably all
            // chunks share `struct_dtype`, so every entry has dtype `field_dtype`.
            unsafe { ChunkedArray::new_unchecked(per_chunk, field_dtype.clone()) }.into_array()
        })
        .collect();

    // SAFETY (assumed — confirm against `StructArray::new_unchecked`): there is one packed
    // array per field of `struct_dtype`, and each spans every chunk, so presumably each has
    // length `total_len`.
    Ok(unsafe {
        StructArray::new_unchecked(packed_fields, struct_dtype.clone(), total_len, validity)
    })
}
/// Concatenates list-typed chunks into one [`ListViewArray`].
///
/// Each chunk is executed as a `ListViewArray` and rebuilt with
/// `ListViewRebuildMode::MakeExact`. Its elements are collected into a chunked elements
/// array, its offsets are cast to `u64` and shifted by the number of elements contributed
/// by the preceding chunks, and its sizes are cast to `u64` and copied unchanged. The
/// result carries the caller-provided `validity`.
///
/// # Errors
///
/// Propagates errors from executing a chunk, rebuilding it, or executing the cast
/// offsets/sizes into primitive arrays.
///
/// # Panics
///
/// * If `chunks` is empty (the first chunk is indexed directly).
/// * If the first chunk's dtype is not a list, or its element dtype differs from
///   `elem_dtype`.
/// * If casting a chunk's offsets or sizes to `u64` fails.
fn swizzle_list_chunks(
    chunks: &[ArrayRef],
    validity: Validity,
    elem_dtype: &DType,
    ctx: &mut ExecutionCtx,
) -> VortexResult<ListViewArray> {
    let len: usize = chunks.iter().map(|c| c.len()).sum();

    // Sanity check: the caller-supplied element dtype must agree with the first chunk.
    assert_eq!(
        chunks[0]
            .dtype()
            .as_list_element_opt()
            .vortex_expect("DType was somehow not a list")
            .as_ref(),
        elem_dtype
    );

    let mut list_elements_chunks = Vec::with_capacity(chunks.len());
    // Running count of elements contributed by the chunks processed so far; used to shift
    // each chunk's offsets into the combined elements array.
    let mut num_elements = 0;
    let mut offsets = BufferMut::<u64>::with_capacity(len);
    let mut sizes = BufferMut::<u64>::with_capacity(len);

    for chunk in chunks {
        let chunk_array = chunk.clone().execute::<ListViewArray>(ctx)?;
        let chunk_array = chunk_array.rebuild(ListViewRebuildMode::MakeExact)?;
        list_elements_chunks.push(chunk_array.elements().clone());

        let offsets_arr = chunk_array
            .offsets()
            .to_array()
            .cast(DType::Primitive(PType::U64, Nullability::NonNullable))
            .vortex_expect("Must be able to fit array offsets in u64")
            .execute::<PrimitiveArray>(ctx)?;
        let sizes_arr = chunk_array
            .sizes()
            .to_array()
            .cast(DType::Primitive(PType::U64, Nullability::NonNullable))
            .vortex_expect("Must be able to fit array sizes in u64")
            .execute::<PrimitiveArray>(ctx)?;

        let offsets_slice = offsets_arr.as_slice::<u64>();
        let sizes_slice = sizes_arr.as_slice::<u64>();

        // Offsets are relative to the chunk's own elements; rebase them onto the
        // concatenated elements. Sizes are position-independent and copied as-is.
        offsets.extend(offsets_slice.iter().map(|o| o + num_elements));
        sizes.extend(sizes_slice);
        num_elements += chunk_array.elements().len() as u64;
    }

    // SAFETY (assumed — confirm against `ChunkedArray::new_unchecked`): every pushed array
    // is the elements child of a list chunk; presumably all chunks share `elem_dtype`.
    let chunked_elements =
        unsafe { ChunkedArray::new_unchecked(list_elements_chunks, elem_dtype.clone()) }
            .into_array();
    let offsets = PrimitiveArray::new(offsets.freeze(), Validity::NonNullable).into_array();
    let sizes = PrimitiveArray::new(sizes.freeze(), Validity::NonNullable).into_array();

    // SAFETY (assumed — confirm against `ListViewArray::new_unchecked`): offsets were rebased
    // by the cumulative element count, so they should index into `chunked_elements`.
    // NOTE(review): the zero-copy-to-list flag is presumably justified by each chunk having
    // been rebuilt with `MakeExact` — confirm.
    Ok(unsafe {
        ListViewArray::new_unchecked(chunked_elements, offsets, sizes, validity)
            .with_zero_copy_to_list(true)
    })
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_buffer::buffer;

    use crate::IntoArray;
    use crate::ToCanonical;
    use crate::accessor::ArrayAccessor;
    use crate::arrays::ChunkedArray;
    use crate::arrays::ListArray;
    use crate::arrays::StructArray;
    use crate::arrays::VarBinViewArray;
    use crate::dtype::DType::List;
    use crate::dtype::DType::Primitive;
    use crate::dtype::Nullability::NonNullable;
    use crate::dtype::PType::I32;
    use crate::validity::Validity;

    /// A chunked struct array whose only chunk is itself a chunked struct array must
    /// canonicalize to a struct holding the same string values as the source struct.
    #[test]
    pub fn pack_nested_structs() {
        let names = VarBinViewArray::from_iter_str(["foo", "bar", "baz", "quak"]).into_array();
        let source =
            StructArray::try_new(["a"].into(), vec![names], 4, Validity::NonNullable).unwrap();
        let struct_dtype = source.dtype().clone();

        // Wrap the struct in two levels of chunking.
        let inner = ChunkedArray::try_new(vec![source.clone().into_array()], struct_dtype.clone())
            .unwrap()
            .into_array();
        let outer = ChunkedArray::try_new(vec![inner], struct_dtype)
            .unwrap()
            .into_array();

        let canonical = outer.to_struct();
        let canonical_field = canonical.unmasked_fields()[0].to_varbinview();
        let source_field = source.unmasked_fields()[0].to_varbinview();

        let expected = source_field
            .with_iterator(|it| it.map(|a| a.map(|v| v.to_vec())).collect::<Vec<_>>());
        let actual = canonical_field
            .with_iterator(|it| it.map(|a| a.map(|v| v.to_vec())).collect::<Vec<_>>());

        assert_eq!(expected, actual);
    }

    /// Two single-list chunks must canonicalize into a list view whose first and second
    /// entries equal the first list of each respective chunk.
    #[test]
    pub fn pack_nested_lists() {
        let list_dtype = List(Arc::new(Primitive(I32, NonNullable)), NonNullable);

        let first = ListArray::try_new(
            buffer![1, 2, 3, 4].into_array(),
            buffer![0, 3].into_array(),
            Validity::NonNullable,
        )
        .unwrap();
        let second = ListArray::try_new(
            buffer![5, 6].into_array(),
            buffer![0, 2].into_array(),
            Validity::NonNullable,
        )
        .unwrap();

        let chunks = vec![first.clone().into_array(), second.clone().into_array()];
        let chunked = ChunkedArray::try_new(chunks, list_dtype);
        let canonical = chunked.unwrap().to_listview();

        assert_eq!(first.scalar_at(0).unwrap(), canonical.scalar_at(0).unwrap());
        assert_eq!(second.scalar_at(0).unwrap(), canonical.scalar_at(1).unwrap());
    }
}