use arrow::array::{Array as _, ArrayRef, ListArray, UInt64Array};
use arrow::buffer::NullBuffer;
use crate::{Error, Transform};
/// Transform that extracts the element at one fixed position from every list
/// of a `ListArray`.
#[derive(Clone, Debug)]
pub(crate) struct GetIndexList {
    /// Zero-based position of the wanted element, counted from the start of each list.
    index: u64,
}

impl GetIndexList {
    /// Creates a transform that extracts the element at position `index` from each list.
    pub fn new(index: u64) -> Self {
        GetIndexList { index }
    }
}
impl Transform for GetIndexList {
    type Source = ListArray;
    type Target = ArrayRef;

    /// Extracts the element at `self.index` from every list in `source`.
    ///
    /// The output has one slot per input row. A slot is null when:
    /// - the input row itself is null,
    /// - the list has `self.index` or fewer elements (index out of bounds), or
    /// - the selected element is null in the child values array.
    ///
    /// When the child values array is empty, an all-null array of the child
    /// data type is returned without gathering anything.
    ///
    /// # Errors
    ///
    /// Propagates (via `?`) any error reported by `arrow::compute::take` or by
    /// rebuilding the array data with the merged null buffer.
    fn transform(&self, source: &ListArray) -> Result<ArrayRef, Error> {
        let row_count = source.len();
        let child = source.values();

        // With no child values nothing can be selected, and the placeholder
        // index used below would not pass the bounds check in `take`.
        if child.is_empty() {
            return Ok(arrow::array::new_null_array(child.data_type(), row_count));
        }

        let offsets = source.offsets();

        // For every row compute the absolute position inside `child` plus a flag
        // telling whether that row yields a value at all. Rows without a value
        // use index 0 as a placeholder (in bounds, since `child` is non-empty);
        // they are masked out again after the gather.
        let (take_indices, row_is_valid): (Vec<u64>, Vec<bool>) = (0..row_count)
            .map(|row| {
                if source.is_null(row) {
                    return (0u64, false);
                }
                let first = offsets[row];
                let list_len = offsets[row + 1] - first;
                if self.index < list_len as u64 {
                    (first as u64 + self.index, true)
                } else {
                    (0u64, false)
                }
            })
            .unzip();

        let take_indices = UInt64Array::from(take_indices);
        let take_options = arrow::compute::TakeOptions { check_bounds: true };
        #[expect(clippy::disallowed_methods)]
        let gathered = arrow::compute::take(child.as_ref(), &take_indices, Some(take_options))?;

        // A slot is valid only if the row produced an index AND the gathered
        // element is itself non-null, so the two validity bitmaps are ANDed.
        let gathered_data = gathered.to_data();
        let row_nulls = NullBuffer::from(row_is_valid);
        let merged_nulls = match gathered_data.nulls() {
            None => row_nulls,
            Some(element_nulls) => NullBuffer::new(element_nulls.inner() & row_nulls.inner()),
        };

        let rebuilt = gathered_data
            .into_builder()
            .nulls(Some(merged_nulls))
            .build()?;
        Ok(arrow::array::make_array(rebuilt))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{AsArray as _, ListArray};
    use arrow::datatypes::Int32Type;

    /// Builds an `Int32` list array from optional rows of optional values.
    fn int32_lists(rows: Vec<Option<Vec<Option<i32>>>>) -> ListArray {
        ListArray::from_iter_primitive::<Int32Type, _, _>(rows)
    }

    /// Runs `GetIndexList` with the given element index over `lists`.
    fn extract(index: u64, lists: &ListArray) -> ArrayRef {
        GetIndexList::new(index).transform(lists).unwrap()
    }

    /// Index 0 returns the first element of every list.
    #[test]
    fn test_get_index_basic() {
        let lists = int32_lists(vec![
            Some(vec![Some(1), Some(2), Some(3)]),
            Some(vec![Some(4), Some(5)]),
        ]);

        let output = extract(0, &lists);
        let ints = output.as_primitive::<Int32Type>();

        assert_eq!(ints.len(), 2);
        assert_eq!(ints.value(0), 1);
        assert_eq!(ints.value(1), 4);
    }

    /// An index beyond the end of every list (including an empty one) yields nulls.
    #[test]
    fn test_get_index_out_of_bounds() {
        let lists = int32_lists(vec![
            Some(vec![Some(1), Some(2)]),
            Some(vec![Some(3)]),
            Some(vec![]),
        ]);

        let output = extract(5, &lists);
        let ints = output.as_primitive::<Int32Type>();

        assert_eq!(ints.len(), 3);
        for row in 0..3 {
            assert!(ints.is_null(row));
        }
    }

    /// Null rows, null elements, and empty lists all come out as null.
    #[test]
    fn test_get_index_with_nulls() {
        let lists = int32_lists(vec![
            Some(vec![Some(1), Some(2)]),
            None,
            Some(vec![Some(3), None, Some(5)]),
            Some(vec![]),
        ]);

        let output = extract(1, &lists);
        let ints = output.as_primitive::<Int32Type>();

        assert_eq!(ints.len(), 4);
        assert_eq!(ints.value(0), 2);
        for row in 1..4 {
            assert!(ints.is_null(row));
        }
    }
}