use super::{Array, Splitable, new_empty_array, new_null_array};
use crate::bitmap::Bitmap;
use crate::datatypes::{ArrowDataType, Field};
mod builder;
pub use builder::*;
mod ffi;
pub(super) mod fmt;
mod iterator;
use polars_error::{PolarsResult, polars_bail, polars_ensure};
#[cfg(feature = "proptest")]
pub mod proptest;
/// A struct array: a collection of equal-length child arrays ("fields") plus an optional
/// top-level validity bitmap.
///
/// Invariants, as checked by [`StructArray::try_new`]:
/// * the storage type of `dtype` (via `to_storage()`) is `ArrowDataType::Struct`;
/// * there is exactly one child in `values` per field, and each child's dtype equals the
///   dtype of its field;
/// * every child has exactly `length` elements;
/// * `validity`, when present, has length `length`.
#[derive(Clone)]
pub struct StructArray {
/// Logical dtype of the array; its storage type is a struct whose fields describe `values`.
dtype: ArrowDataType,
/// Child arrays, one per field, in field order.
values: Vec<Box<dyn Array>>,
/// Number of rows; every child has this many elements.
length: usize,
/// Top-level validity bitmap; `None` means no row is null at the struct level.
validity: Option<Bitmap>,
}
impl StructArray {
    /// Creates a new [`StructArray`], validating all of its invariants.
    ///
    /// # Errors
    /// Returns a `ComputeError` if:
    /// * the storage type of `dtype` (via `to_storage()`) is not `ArrowDataType::Struct`;
    /// * the number of fields in `dtype` differs from `values.len()`;
    /// * any child's dtype differs from the dtype of its corresponding field;
    /// * any child does not have exactly `length` elements;
    /// * `validity` is present and its length differs from `length`.
    pub fn try_new(
        dtype: ArrowDataType,
        length: usize,
        values: Vec<Box<dyn Array>>,
        validity: Option<Bitmap>,
    ) -> PolarsResult<Self> {
        let fields = Self::try_get_fields(&dtype)?;
        polars_ensure!(
            fields.len() == values.len(),
            ComputeError:
            "a StructArray must have a number of fields in its DataType equal to the number of child values"
        );

        // Every child must carry exactly the dtype declared by its field.
        // NOTE: the names `index`, `dtype` and `child` are captured by the error message.
        for (index, (dtype, child)) in fields
            .iter()
            .map(|a| &a.dtype)
            .zip(values.iter().map(|a| a.dtype()))
            .enumerate()
        {
            if dtype != child {
                polars_bail!(ComputeError:
                    "The children DataTypes of a StructArray must equal the children data types.
However, the field {index} has data type {dtype:?} but the value has data type {child:?}"
                )
            }
        }

        // Every child must have exactly `length` rows.
        // NOTE: the names `index`, `f_length` and `length` are captured by the error message.
        for (index, f_length) in values.iter().map(|f| f.len()).enumerate() {
            if f_length != length {
                polars_bail!(ComputeError: "The children must have the given number of values.
However, the values at index {index} have a length of {f_length}, which is different from given length {length}.")
            }
        }

        if validity
            .as_ref()
            .is_some_and(|validity| validity.len() != length)
        {
            polars_bail!(ComputeError:"The validity length of a StructArray must match its number of elements")
        }

        Ok(Self {
            dtype,
            length,
            values,
            validity,
        })
    }

    /// Creates a new [`StructArray`].
    ///
    /// # Panics
    /// Panics under the same conditions under which [`StructArray::try_new`] returns an error.
    pub fn new(
        dtype: ArrowDataType,
        length: usize,
        values: Vec<Box<dyn Array>>,
        validity: Option<Bitmap>,
    ) -> Self {
        Self::try_new(dtype, length, values, validity).unwrap()
    }

    /// Creates a zero-length [`StructArray`] of `dtype`. It has one empty child per field and
    /// no validity bitmap.
    ///
    /// # Panics
    /// Panics if the storage type of `dtype` (via `to_storage()`) is not `ArrowDataType::Struct`.
    pub fn new_empty(dtype: ArrowDataType) -> Self {
        if let ArrowDataType::Struct(fields) = dtype.to_storage() {
            let values = fields
                .iter()
                .map(|field| new_empty_array(field.dtype().clone()))
                .collect();
            Self::new(dtype, 0, values, None)
        } else {
            panic!("StructArray must be initialized with DataType::Struct");
        }
    }

    /// Creates a [`StructArray`] of `length` rows that are all null.
    ///
    /// Each child is a null array of its field's dtype, and the validity bitmap is zeroed.
    ///
    /// Like `new_empty` and `try_new`, this inspects the storage type of `dtype`
    /// (via `to_storage()`). Any dtype whose storage type is a struct is therefore accepted,
    /// not only `ArrowDataType::Struct` itself.
    ///
    /// # Panics
    /// Panics if the storage type of `dtype` is not `ArrowDataType::Struct`.
    pub fn new_null(dtype: ArrowDataType, length: usize) -> Self {
        if let ArrowDataType::Struct(fields) = dtype.to_storage() {
            let values = fields
                .iter()
                .map(|field| new_null_array(field.dtype().clone(), length))
                .collect();
            Self::new(dtype, length, values, Some(Bitmap::new_zeroed(length)))
        } else {
            panic!("StructArray must be initialized with DataType::Struct");
        }
    }
}
impl StructArray {
    /// Deconstructs this array into its fields, length, child arrays and validity.
    ///
    /// The returned fields are those of the struct storage type of the array's dtype.
    #[must_use]
    pub fn into_data(self) -> (Vec<Field>, usize, Vec<Box<dyn Array>>, Option<Bitmap>) {
        let Self {
            dtype,
            length,
            values,
            validity,
        } = self;
        let fields = match dtype {
            // Common case: take ownership of the fields without cloning.
            ArrowDataType::Struct(fields) => fields,
            // `try_new` accepts any dtype whose `to_storage()` is a struct, so a dtype that is
            // not literally `Struct` is valid here and must not panic. Copy its storage fields.
            other => Self::get_fields(&other).to_vec(),
        };
        (fields, length, values, validity)
    }

    /// Slices this array in place to `length` rows starting at `offset`.
    ///
    /// # Panics
    /// Panics if `offset + length` exceeds `self.len()`, including when the sum overflows.
    pub fn slice(&mut self, offset: usize, length: usize) {
        // `checked_add` stops a wrapped sum from passing the bound check in release builds.
        assert!(
            offset
                .checked_add(length)
                .is_some_and(|end| end <= self.len()),
            "offset + length may not exceed length of array"
        );
        // SAFETY: the assertion above guarantees `offset + length <= self.len()`.
        unsafe { self.slice_unchecked(offset, length) }
    }

    /// Slices this array in place to `length` rows starting at `offset`, without bounds checks.
    ///
    /// The validity bitmap is dropped (set to `None`) when the sliced range contains no nulls.
    ///
    /// # Safety
    /// The caller must ensure `offset + length <= self.len()`.
    pub unsafe fn slice_unchecked(&mut self, offset: usize, length: usize) {
        self.validity = self
            .validity
            .take()
            .map(|bitmap| bitmap.sliced_unchecked(offset, length))
            .filter(|bitmap| bitmap.unset_bits() > 0);
        self.values
            .iter_mut()
            .for_each(|x| x.slice_unchecked(offset, length));
        self.length = length;
    }

    impl_sliced!();
    impl_mut_validity!();
    impl_into_array!();
}
impl StructArray {
    /// Returns the number of rows in this array.
    ///
    /// In debug builds, this also checks that every child array has exactly that many rows.
    #[inline]
    pub fn len(&self) -> usize {
        let expected = self.length;
        self.values.iter().for_each(|child| {
            debug_assert_eq!(
                child.len(),
                expected,
                "StructArray invariant: each array has same length"
            )
        });
        expected
    }

    /// Returns the top-level validity bitmap, if one is present.
    #[inline]
    pub fn validity(&self) -> Option<&Bitmap> {
        Option::as_ref(&self.validity)
    }

    /// Returns the child arrays, in field order.
    pub fn values(&self) -> &[Box<dyn Array>] {
        self.values.as_slice()
    }

    /// Returns the fields of this array's struct storage type.
    pub fn fields(&self) -> &[Field] {
        let declared = Self::get_fields(&self.dtype);
        debug_assert_eq!(self.values.len(), declared.len());
        declared
    }
}
impl StructArray {
    /// Returns the fields of the storage type of `dtype`.
    ///
    /// Returns a `ComputeError` if that storage type (via `to_storage()`) is not
    /// `ArrowDataType::Struct`.
    pub(crate) fn try_get_fields(dtype: &ArrowDataType) -> PolarsResult<&[Field]> {
        let ArrowDataType::Struct(fields) = dtype.to_storage() else {
            polars_bail!(ComputeError: "Struct array must be created with a DataType whose physical type is Struct")
        };
        Ok(fields)
    }

    /// Returns the fields of the storage type of `dtype`.
    ///
    /// # Panics
    /// Panics if the storage type of `dtype` is not `ArrowDataType::Struct`.
    pub fn get_fields(dtype: &ArrowDataType) -> &[Field] {
        let looked_up = Self::try_get_fields(dtype);
        looked_up.unwrap()
    }
}
impl Array for StructArray {
    impl_common_array!();

    /// Returns the top-level validity bitmap, if one is present.
    fn validity(&self) -> Option<&Bitmap> {
        Option::as_ref(&self.validity)
    }

    /// Returns a boxed clone of this array whose validity is replaced by `validity`.
    #[inline]
    fn with_validity(&self, validity: Option<Bitmap>) -> Box<dyn Array> {
        // Resolves to the inherent by-value `with_validity`, which returns `Self`.
        let replaced: Self = self.clone().with_validity(validity);
        Box::new(replaced)
    }
}
impl Splitable for StructArray {
fn check_bound(&self, offset: usize) -> bool {
offset <= self.len()
}
unsafe fn _split_at_unchecked(&self, offset: usize) -> (Self, Self) {
let (lhs_validity, rhs_validity) = unsafe { self.validity.split_at_unchecked(offset) };
let mut lhs_values = Vec::with_capacity(self.values.len());
let mut rhs_values = Vec::with_capacity(self.values.len());
for v in self.values.iter() {
let (lhs, rhs) = unsafe { v.split_at_boxed_unchecked(offset) };
lhs_values.push(lhs);
rhs_values.push(rhs);
}
(
Self {
dtype: self.dtype.clone(),
length: offset,
values: lhs_values,
validity: lhs_validity,
},
Self {
dtype: self.dtype.clone(),
length: self.length - offset,
values: rhs_values,
validity: rhs_validity,
},
)
}
}