use vortex_error::VortexExpect;
use vortex_error::VortexResult;
use crate::ArrayRef;
use crate::IntoArray;
use crate::array::ArrayView;
use crate::arrays::List;
use crate::arrays::ListArray;
use crate::arrays::Primitive;
use crate::arrays::PrimitiveArray;
use crate::arrays::dict::TakeExecute;
use crate::arrays::list::ListArrayExt;
use crate::arrays::primitive::PrimitiveArrayExt;
use crate::builders::ArrayBuilder;
use crate::builders::PrimitiveBuilder;
use crate::dtype::IntegerPType;
use crate::dtype::Nullability;
use crate::executor::ExecutionCtx;
use crate::match_each_integer_ptype;
use crate::match_smallest_offset_type;
impl TakeExecute for List {
    /// Gathers the lists of `array` selected by `indices` into a new array.
    ///
    /// The indices are first executed into a [`PrimitiveArray`]. The call is then dispatched to
    /// [`_take`], specialised over three integer types:
    ///
    /// * `O` - the integer type of the source offsets,
    /// * `I` - the integer type of the indices,
    /// * `OutputOffsetType` - picked by `match_smallest_offset_type!` from the saturating product
    ///   of the element count and the index count.
    ///
    /// This implementation always returns `Ok(Some(..))` when no error occurs.
    #[expect(clippy::cognitive_complexity)]
    fn take(
        array: ArrayView<'_, List>,
        indices: &ArrayRef,
        ctx: &mut ExecutionCtx,
    ) -> VortexResult<Option<ArrayRef>> {
        let primitive_indices = indices.clone().execute::<PrimitiveArray>(ctx)?;

        // Size estimate used only to choose the width of the output offsets.
        let element_count = array.elements().len();
        let index_count = primitive_indices.len();
        let total_approx = element_count.saturating_mul(index_count);

        match_each_integer_ptype!(array.offsets().dtype().as_ptype(), |O| {
            match_each_integer_ptype!(primitive_indices.ptype(), |I| {
                match_smallest_offset_type!(total_approx, |OutputOffsetType| {
                    {
                        let taken = _take::<I, O, OutputOffsetType>(
                            array,
                            primitive_indices.as_view(),
                            ctx,
                        )?;
                        Ok(Some(taken))
                    }
                })
            })
        })
    }
}
/// Takes the lists at `indices_array` out of `array` when neither side contains nulls.
///
/// If the index validity mask or the list validity mask is not all-true, the call is forwarded
/// to [`_take_nullable`]. Otherwise the function:
///
/// 1. walks the indices, and for each selected list appends every element position
///    `start..stop` (read from the source offsets) to `elements_to_take`,
/// 2. accumulates the running output offset in `OutputOffsetType`,
/// 3. takes the collected positions from the source elements and assembles a new [`ListArray`],
///    whose validity is the source validity taken at the same indices.
///
/// Type parameters: `I` is the index integer type, `O` the source offset integer type, and
/// `OutputOffsetType` the integer type used for the offsets of the result.
///
/// # Panics
///
/// Slice indexing panics if an index (after conversion to `usize`) does not address a valid
/// `offsets[i]` / `offsets[i + 1]` pair. `vortex_expect` panics if a list length cannot be
/// represented in `O` or in `OutputOffsetType`.
fn _take<I: IntegerPType, O: IntegerPType, OutputOffsetType: IntegerPType>(
    array: ArrayView<'_, List>,
    indices_array: ArrayView<'_, Primitive>,
    ctx: &mut ExecutionCtx,
) -> VortexResult<ArrayRef> {
    let data_validity = array.list_validity_mask();
    let indices_validity = indices_array.validity_mask();
    if !indices_validity.all_true() || !data_validity.all_true() {
        return _take_nullable::<I, O, OutputOffsetType>(array, indices_array, ctx);
    }

    let offsets_array = array.offsets().clone().execute::<PrimitiveArray>(ctx)?;
    let offsets: &[O] = offsets_array.as_slice();
    let indices: &[I] = indices_array.as_slice();

    // The result holds one leading zero plus one end offset per taken list.
    let mut new_offsets = PrimitiveBuilder::<OutputOffsetType>::with_capacity(
        Nullability::NonNullable,
        indices.len() + 1,
    );
    // Initial guess only; the exact amount is reserved per list inside the loop.
    let mut elements_to_take =
        PrimitiveBuilder::<O>::with_capacity(Nullability::NonNullable, 2 * indices.len());

    let mut current_offset = OutputOffsetType::zero();
    new_offsets.append_zero();

    for &data_idx in indices {
        let data_idx: usize = data_idx.as_();
        let start = offsets[data_idx];
        let stop = offsets[data_idx + 1];

        // Only the difference is converted to `usize`; `start` and `stop` themselves stay in `O`.
        let additional: usize = (stop - start).as_();

        elements_to_take.reserve_exact(additional);
        for i in 0..additional {
            elements_to_take.append_value(start + O::from_usize(i).vortex_expect("i < additional"));
        }

        current_offset +=
            OutputOffsetType::from_usize(additional).vortex_expect("offset conversion");
        new_offsets.append_value(current_offset);
    }

    let elements_to_take = elements_to_take.finish();
    let new_offsets = new_offsets.finish();
    let new_elements = array.elements().take(elements_to_take)?;

    Ok(ListArray::try_new(
        new_elements,
        new_offsets,
        array.validity()?.take(indices_array.array())?,
    )?
    .into_array())
}
/// Takes the lists at `indices_array` out of `array`, tolerating null indices and null lists.
///
/// For every index position:
///
/// * if the index itself is null, or the list it points at is null, no elements are collected
///   and the current offset is repeated, giving that output slot a zero-length range;
/// * otherwise every element position `start..stop` of the selected list is appended to
///   `elements_to_take` and the running offset advances by the list length.
///
/// The collected positions are then taken from the source elements and combined with the new
/// offsets into a [`ListArray`]. Its validity is produced by taking the source validity at
/// `indices_array`.
///
/// Type parameters: `I` is the index integer type, `O` the source offset integer type, and
/// `OutputOffsetType` the integer type used for the offsets of the result.
///
/// # Panics
///
/// Slice indexing panics if a valid index (after conversion to `usize`) does not address a
/// valid `offsets[i]` / `offsets[i + 1]` pair. `vortex_expect` panics if a list length cannot be
/// represented in `O` or in `OutputOffsetType`.
fn _take_nullable<I: IntegerPType, O: IntegerPType, OutputOffsetType: IntegerPType>(
    array: ArrayView<'_, List>,
    indices_array: ArrayView<'_, Primitive>,
    ctx: &mut ExecutionCtx,
) -> VortexResult<ArrayRef> {
    let offsets_array = array.offsets().clone().execute::<PrimitiveArray>(ctx)?;
    let offsets: &[O] = offsets_array.as_slice();
    let indices: &[I] = indices_array.as_slice();
    let data_validity = array.list_validity_mask();
    let indices_validity = indices_array.validity_mask();

    // The result holds one leading zero plus one end offset per index position.
    let mut new_offsets = PrimitiveBuilder::<OutputOffsetType>::with_capacity(
        Nullability::NonNullable,
        indices.len() + 1,
    );
    // Initial guess only; the exact amount is reserved per list inside the loop.
    let mut elements_to_take =
        PrimitiveBuilder::<O>::with_capacity(Nullability::NonNullable, 2 * indices.len());

    let mut current_offset = OutputOffsetType::zero();
    new_offsets.append_zero();

    for (idx, &data_idx) in indices.iter().enumerate() {
        // Null index: emit an empty range. The index value is not inspected in this case.
        if !indices_validity.value(idx) {
            new_offsets.append_value(current_offset);
            continue;
        }

        let data_idx: usize = data_idx.as_();

        // Null source list: emit an empty range and skip its elements.
        if !data_validity.value(data_idx) {
            new_offsets.append_value(current_offset);
            continue;
        }

        let start = offsets[data_idx];
        let stop = offsets[data_idx + 1];

        // Only the difference is converted to `usize`; `start` and `stop` themselves stay in `O`.
        let additional: usize = (stop - start).as_();

        elements_to_take.reserve_exact(additional);
        for i in 0..additional {
            elements_to_take.append_value(start + O::from_usize(i).vortex_expect("i < additional"));
        }

        current_offset +=
            OutputOffsetType::from_usize(additional).vortex_expect("offset conversion");
        new_offsets.append_value(current_offset);
    }

    let elements_to_take = elements_to_take.finish();
    let new_offsets = new_offsets.finish();
    let new_elements = array.elements().take(elements_to_take)?;

    Ok(ListArray::try_new(
        new_elements,
        new_offsets,
        array.validity()?.take(indices_array.array())?,
    )?
    .into_array())
}
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use rstest::rstest;
    use vortex_buffer::buffer;

    use crate::IntoArray as _;
    use crate::ToCanonical;
    use crate::arrays::BoolArray;
    use crate::arrays::ListArray;
    use crate::arrays::PrimitiveArray;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType::I32;
    use crate::scalar::Scalar;
    use crate::validity::Validity;

    /// Dtype of a list of non-nullable `i32` whose outer nullability is `list_nullability`.
    fn i32_list_dtype(list_nullability: Nullability) -> DType {
        DType::List(
            Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
            list_nullability,
        )
    }

    /// Taking with a null index from a list array that itself contains a null list.
    #[test]
    fn nullable_take() {
        let source = ListArray::try_new(
            buffer![0i32, 5, 3, 4].into_array(),
            buffer![0, 2, 3, 4, 4].into_array(),
            Validity::Array(BoolArray::from_iter(vec![true, true, false, true]).into_array()),
        )
        .unwrap()
        .into_array();
        let positions =
            PrimitiveArray::from_option_iter(vec![Some(0), None, Some(1), Some(3)]).into_array();

        let taken = source.take(positions).unwrap();
        assert_eq!(taken.dtype(), &i32_list_dtype(Nullability::Nullable));

        let view = taken.to_listview();
        assert_eq!(view.len(), 4);

        let element_dtype: Arc<DType> = Arc::new(I32.into());

        // Slot 0 <- list 0 = [0, 5]
        assert!(view.is_valid(0).unwrap());
        assert_eq!(
            view.scalar_at(0).unwrap(),
            Scalar::list(
                Arc::clone(&element_dtype),
                vec![0i32.into(), 5.into()],
                Nullability::Nullable
            )
        );

        // Slot 1 <- null index
        assert!(view.is_invalid(1).unwrap());

        // Slot 2 <- list 1 = [3]
        assert!(view.is_valid(2).unwrap());
        assert_eq!(
            view.scalar_at(2).unwrap(),
            Scalar::list(
                Arc::clone(&element_dtype),
                vec![3i32.into()],
                Nullability::Nullable
            )
        );

        // Slot 3 <- list 3 = []
        assert!(view.is_valid(3).unwrap());
        assert_eq!(
            view.scalar_at(3).unwrap(),
            Scalar::list(element_dtype, vec![], Nullability::Nullable)
        );
    }

    /// Nullable indices applied to a non-nullable list array yield a nullable list dtype.
    #[test]
    fn change_validity() {
        let source = ListArray::try_new(
            buffer![0i32, 5, 3, 4].into_array(),
            buffer![0, 2, 3].into_array(),
            Validity::NonNullable,
        )
        .unwrap()
        .into_array();
        let positions = PrimitiveArray::from_option_iter(vec![Some(0), Some(1), None]).into_array();

        let taken = source.take(positions).unwrap();
        assert_eq!(taken.dtype(), &i32_list_dtype(Nullability::Nullable));
    }

    /// Taking with non-nullable indices from a non-nullable list array, out of order.
    #[test]
    fn non_nullable_take() {
        let source = ListArray::try_new(
            buffer![0i32, 5, 3, 4].into_array(),
            buffer![0, 2, 3, 3, 4].into_array(),
            Validity::NonNullable,
        )
        .unwrap()
        .into_array();
        let positions = buffer![1, 0, 2].into_array();

        let taken = source.take(positions).unwrap();
        assert_eq!(taken.dtype(), &i32_list_dtype(Nullability::NonNullable));

        let view = taken.to_listview();
        assert_eq!(view.len(), 3);

        let element_dtype: Arc<DType> = Arc::new(I32.into());

        // Slot 0 <- list 1 = [3]
        assert!(view.is_valid(0).unwrap());
        assert_eq!(
            view.scalar_at(0).unwrap(),
            Scalar::list(
                Arc::clone(&element_dtype),
                vec![3i32.into()],
                Nullability::NonNullable
            )
        );

        // Slot 1 <- list 0 = [0, 5]
        assert!(view.is_valid(1).unwrap());
        assert_eq!(
            view.scalar_at(1).unwrap(),
            Scalar::list(
                Arc::clone(&element_dtype),
                vec![0i32.into(), 5.into()],
                Nullability::NonNullable
            )
        );

        // Slot 2 <- list 2 = []
        assert!(view.is_valid(2).unwrap());
        assert_eq!(
            view.scalar_at(2).unwrap(),
            Scalar::list(element_dtype, vec![], Nullability::NonNullable)
        );
    }

    /// An empty, nullable index array produces an empty result with a nullable list dtype.
    #[test]
    fn test_take_empty_array() {
        let source = ListArray::try_new(
            buffer![0i32, 5, 3, 4].into_array(),
            buffer![0].into_array(),
            Validity::NonNullable,
        )
        .unwrap()
        .into_array();
        let positions = PrimitiveArray::empty::<i32>(Nullability::Nullable).into_array();

        let taken = source.take(positions).unwrap();
        assert_eq!(taken.dtype(), &i32_list_dtype(Nullability::Nullable));
        assert_eq!(taken.len(), 0);
    }

    /// Runs the shared take conformance suite over a range of list array shapes.
    #[rstest]
    // Non-nullable lists of mixed lengths, including one empty list.
    #[case(ListArray::try_new(
        buffer![0i32, 1, 2, 3, 4, 5].into_array(),
        buffer![0, 2, 3, 5, 5, 6].into_array(),
        Validity::NonNullable,
    ).unwrap())]
    // Lists with an explicit validity array containing a null.
    #[case(ListArray::try_new(
        buffer![10i32, 20, 30, 40, 50].into_array(),
        buffer![0, 2, 3, 4, 5].into_array(),
        Validity::Array(BoolArray::from_iter(vec![true, false, true, true]).into_array()),
    ).unwrap())]
    #[case(ListArray::try_new(
        buffer![1i32, 2, 3].into_array(),
        buffer![0, 0, 2, 2, 3].into_array(), // First and third are empty
        Validity::NonNullable,
    ).unwrap())]
    // A single list.
    #[case(ListArray::try_new(
        buffer![42i32, 43].into_array(),
        buffer![0, 2].into_array(),
        Validity::NonNullable,
    ).unwrap())]
    // Fifty lists with u64 offsets and lengths cycling through 1, 2, 3, 4, 0.
    #[case({
        let elements = buffer![0i32..200].into_array();
        let mut running_total = 0u64;
        let mut offsets = vec![running_total];
        for i in 1..=50u64 {
            running_total += i % 5; // Variable length lists
            offsets.push(running_total);
        }
        ListArray::try_new(
            elements,
            PrimitiveArray::from_iter(offsets).into_array(),
            Validity::NonNullable,
        ).unwrap()
    })]
    // Non-nullable lists whose elements are nullable.
    #[case(ListArray::try_new(
        PrimitiveArray::from_option_iter([Some(1i32), None, Some(3), Some(4), None]).into_array(),
        buffer![0, 2, 3, 5].into_array(),
        Validity::NonNullable,
    ).unwrap())]
    fn test_take_list_conformance(#[case] list: ListArray) {
        test_take_conformance(&list.into_array());
    }

    /// Taking a 200-element list twice needs offsets up to 400, beyond the `u8` source offsets.
    #[test]
    fn test_u64_offset_accumulation_non_nullable() {
        let values = buffer![0i32; 200].into_array();
        let u8_offsets = buffer![0u8, 200].into_array();
        let source = ListArray::try_new(values, u8_offsets, Validity::NonNullable)
            .unwrap()
            .into_array();
        let positions = buffer![0u8, 0].into_array();

        let taken = source.take(positions).unwrap();
        assert_eq!(taken.len(), 2);

        let view = taken.to_listview();
        assert_eq!(view.len(), 2);
        assert!(view.is_valid(0).unwrap());
        assert!(view.is_valid(1).unwrap());
    }

    /// Same as above on the nullable path: 150 elements taken twice needs offsets up to 300.
    #[test]
    fn test_u64_offset_accumulation_nullable() {
        let values = buffer![0i32; 150].into_array();
        let u8_offsets = buffer![0u8, 150, 150].into_array();
        let list_validity = BoolArray::from_iter(vec![true, false]).into_array();
        let source = ListArray::try_new(values, u8_offsets, Validity::Array(list_validity))
            .unwrap()
            .into_array();
        let positions =
            PrimitiveArray::from_option_iter(vec![Some(0u8), None, Some(0u8)]).into_array();

        let taken = source.take(positions).unwrap();
        assert_eq!(taken.len(), 3);

        let view = taken.to_listview();
        assert_eq!(view.len(), 3);
        assert!(view.is_valid(0).unwrap());
        assert!(view.is_invalid(1).unwrap());
        assert!(view.is_valid(2).unwrap());
    }

    /// Regression: taking more indices (4) than source lists (2) with an array-backed validity
    /// must succeed and yield one output slot per index.
    #[test]
    fn test_take_validity_length_mismatch_regression() {
        let source = ListArray::try_new(
            buffer![1i32, 2, 3, 4].into_array(),
            buffer![0, 2, 4].into_array(),
            Validity::Array(BoolArray::from_iter(vec![true, true]).into_array()),
        )
        .unwrap()
        .into_array();
        let positions = buffer![0u32, 1, 0, 1].into_array();

        let taken = source.take(positions).unwrap();
        assert_eq!(taken.len(), 4);
    }
}