use std::sync::Arc;
use num_traits::AsPrimitive;
use vortex_dtype::{DType, match_each_integer_ptype, match_each_native_ptype};
use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_ensure};
use crate::arrays::{ListVTable, PrimitiveVTable};
use crate::compute::{min_max, sub_scalar};
use crate::stats::ArrayStats;
use crate::validity::Validity;
use crate::{Array, ArrayRef, IntoArray};
/// A list array stored as one flat `elements` array plus an `offsets` array.
///
/// List `i` covers `elements[offsets[i]..offsets[i + 1]]`, so `offsets` holds one more entry
/// than there are lists.
#[derive(Clone, Debug)]
pub struct ListArray {
/// Always `DType::List(<elements dtype>, <nullability taken from validity>)`.
pub(super) dtype: DType,
/// Flat array holding the values of every list back to back.
pub(super) elements: ArrayRef,
/// Non-nullable, sorted integer offsets into `elements`; never empty.
pub(super) offsets: ArrayRef,
/// Per-list validity; when it reports a length, that length is `offsets.len() - 1`.
pub(super) validity: Validity,
/// Statistics holder for this array; starts out as `Default::default()`.
pub(super) stats_set: ArrayStats,
}
impl ListArray {
/// Builds a `ListArray` from its parts, running the same validation as [`Self::try_new`].
///
/// # Panics
///
/// NOTE(review): `vortex_expect` presumably panics with the message "ListArray new" when
/// `try_new` returns an error — confirm against `VortexExpect`.
pub fn new(elements: ArrayRef, offsets: ArrayRef, validity: Validity) -> Self {
Self::try_new(elements, offsets, validity).vortex_expect("ListArray new")
}
/// Fallible constructor: checks the parts with [`Self::validate`] and only then assembles
/// the array.
///
/// # Errors
///
/// Returns whatever error [`Self::validate`] reports for the given parts.
pub fn try_new(
    elements: ArrayRef,
    offsets: ArrayRef,
    validity: Validity,
) -> VortexResult<Self> {
    Self::validate(&elements, &offsets, &validity).map(|()| {
        // SAFETY: `validate` has just accepted exactly these three components.
        unsafe { Self::new_unchecked(elements, offsets, validity) }
    })
}
/// Assembles a `ListArray` from its parts without validating them in release builds.
///
/// The list dtype is derived from the dtype of `elements` and the nullability of `validity`,
/// and the statistics start out empty (`Default::default()`).
///
/// In builds with `debug_assertions` enabled, the parts are still run through
/// [`Self::validate`] and a failure aborts via `vortex_expect`.
///
/// # Safety
///
/// The caller must pass parts that [`Self::validate`] would accept; release builds do not
/// check this.
pub unsafe fn new_unchecked(elements: ArrayRef, offsets: ArrayRef, validity: Validity) -> Self {
    #[cfg(debug_assertions)]
    Self::validate(&elements, &offsets, &validity)
        .vortex_expect("[Debug Assertion]: Invalid `ListArray` parameters");

    // Compute the dtype first: it borrows `elements` and `validity`, which are moved below.
    let dtype = DType::List(Arc::new(elements.dtype().clone()), validity.nullability());

    Self {
        dtype,
        elements,
        offsets,
        validity,
        stats_set: Default::default(),
    }
}
/// Checks that `elements`, `offsets` and `validity` together form a well-formed list array.
///
/// The checks, in order:
/// 1. `offsets` is non-empty.
/// 2. `offsets` has a non-nullable integer dtype.
/// 3. `offsets` reports itself as sorted.
/// 4. The minimum and maximum offset lie within `[0, last offset]`, and the last offset does
///    not exceed `elements.len()`.
/// 5. If `validity` reports a length, it equals `offsets.len() - 1`.
///
/// # Errors
///
/// Returns an error when any check above fails, when the `is_sorted` statistic cannot be
/// computed, or when `min_max` yields no result for the offsets array.
///
/// # Panics
///
/// NOTE(review): the `vortex_expect` calls below presumably panic instead of returning an
/// error — in particular when `elements.len()` does not fit in the offsets integer type,
/// even if every offset is small. Confirm whether that should be an `Err` instead.
pub fn validate(
elements: &dyn Array,
offsets: &dyn Array,
validity: &Validity,
) -> VortexResult<()> {
// An empty list array is still expected to carry a single offset.
vortex_ensure!(
!offsets.is_empty(),
"Offsets must have at least one element, [0] for an empty list"
);
// Offsets must be integers and may not contain nulls.
vortex_ensure!(
offsets.dtype().is_int() && !offsets.dtype().is_nullable(),
"offsets have invalid type {}",
offsets.dtype()
);
// The dtype was just checked to be an integer type, so reading its ptype is expected to succeed.
let offsets_ptype = offsets.dtype().as_ptype();
// Offsets must be sorted; a missing statistic is treated as an error rather than ignored.
if let Some(is_sorted) = offsets.statistics().compute_is_sorted() {
vortex_ensure!(is_sorted, "offsets must be sorted");
} else {
vortex_bail!("offsets must report is_sorted statistic");
}
// Range checks: compare min/max of the offsets against the last offset and the elements length.
if let Some(min_max) = min_max(offsets)? {
match_each_integer_ptype!(offsets_ptype, |P| {
// `offsets.len() - 1` cannot underflow: emptiness was rejected above.
let max_offset = P::try_from(offsets.scalar_at(offsets.len() - 1))
.vortex_expect("Offsets type must fit offsets values");
// For unsigned `P` the `>= 0` comparisons are trivially true, hence the lint allowances.
#[allow(clippy::absurd_extreme_comparisons, unused_comparisons)]
{
if let Some(min) = min_max.min.as_primitive().as_::<P>() {
vortex_ensure!(
min >= 0 && min <= max_offset,
"offsets minimum {min} outside valid range [0, {max_offset}]"
);
}
if let Some(max) = min_max.max.as_primitive().as_::<P>() {
vortex_ensure!(
max >= 0 && max <= max_offset,
"offsets maximum {max} outside valid range [0, {max_offset}]"
)
}
}
// The last offset must not point past the end of `elements`.
vortex_ensure!(
max_offset
<= P::try_from(elements.len())
.vortex_expect("Offsets type must be able to fit elements length"),
"Max offset {max_offset} is beyond the length of the elements array {}",
elements.len()
);
})
} else {
// No min/max available for this offsets array: reject it instead of skipping the range checks.
vortex_bail!(
"offsets array with encoding {} must support min_max compute function",
offsets.encoding_id()
);
};
// A validity that reports a length must cover exactly one entry per list.
if let Some(validity_len) = validity.maybe_len() {
vortex_ensure!(
validity_len == offsets.len() - 1,
"validity with size {validity_len} does not match array size {}",
offsets.len() - 1
);
}
Ok(())
}
/// Returns the offset stored at position `index`, converted to `usize`.
///
/// `index` may equal `self.len()`, which addresses the final (end) offset.
///
/// # Panics
///
/// Panics when `index > self.len()`.
pub fn offset_at(&self, index: usize) -> usize {
    let len = self.len();
    assert!(index <= len, "Index {index} out of bounds 0..={}", len);

    match self.offsets().as_opt::<PrimitiveVTable>() {
        // Primitive offsets: index the typed slice directly.
        Some(primitive) => {
            match_each_native_ptype!(primitive.ptype(), |P| {
                primitive.as_slice::<P>()[index].as_()
            })
        }
        // Any other encoding: go through a scalar lookup.
        None => self
            .offsets()
            .scalar_at(index)
            .as_primitive()
            .as_::<usize>()
            .vortex_expect("index must fit in usize"),
    }
}
/// Returns the elements of the list at `index`, i.e. the slice of `elements` between
/// offset `index` and offset `index + 1`.
///
/// # Panics
///
/// Panics (via [`Self::offset_at`]) when either offset position is out of bounds.
pub fn list_elements_at(&self, index: usize) -> ArrayRef {
    let span = self.offset_at(index)..self.offset_at(index + 1);
    self.elements().slice(span)
}
/// Returns the part of `elements` that the offsets actually reference: from the first
/// offset up to the last one.
pub fn sliced_elements(&self) -> ArrayRef {
    let referenced = self.offset_at(0)..self.offset_at(self.len());
    self.elements().slice(referenced)
}
/// Returns a reference to the offsets array.
pub fn offsets(&self) -> &ArrayRef {
&self.offsets
}
/// Returns a reference to the flat elements array, including any entries the offsets do
/// not reference (see `sliced_elements` for the referenced range only).
pub fn elements(&self) -> &ArrayRef {
&self.elements
}
/// Produces an equivalent `ListArray` whose elements are trimmed to the referenced range
/// and whose offsets are rebased by subtracting the first offset (via `sub_scalar`).
///
/// When `recurse` is true, the trimmed elements are additionally processed: canonical
/// elements are canonicalized and compacted, and nested list elements have their own
/// offsets reset recursively. Other encodings are left as sliced.
///
/// # Errors
///
/// Propagates errors from compaction, the recursive reset, `sub_scalar`, or the final
/// validation in [`Self::try_new`].
pub fn reset_offsets(&self, recurse: bool) -> VortexResult<Self> {
    let mut elements = self.sliced_elements();

    if recurse {
        if elements.is_canonical() {
            elements = elements.to_canonical().compact()?.into_array();
        } else if let Some(nested) = elements.as_opt::<ListVTable>() {
            elements = nested.reset_offsets(recurse)?.into_array();
        }
    }

    let rebased_offsets = sub_scalar(self.offsets(), self.offsets().scalar_at(0))?;
    Self::try_new(elements, rebased_offsets, self.validity.clone())
}
}