use polars_buffer::Buffer;
use polars_error::{PolarsResult, polars_bail, polars_err};
use super::{Array, Splitable, new_empty_array, new_null_array};
use crate::bitmap::Bitmap;
use crate::datatypes::{ArrowDataType, Field, UnionMode};
use crate::scalar::{Scalar, new_scalar};
mod ffi;
pub(super) mod fmt;
mod iterator;
/// The child fields, the optional explicit type ids and the mode of a `DataType::Union`.
type UnionComponents<'a> = (&'a [Field], Option<&'a [i32]>, UnionMode);

/// Number of non-negative values an `i8` type id can take (`0..=127`); size of the type-id map.
const MAX_TYPE_IDS: usize = 128;

/// Marker stored in the type-id map for ids that do not belong to any child.
const UNMAPPED_TYPE_ID: usize = usize::MAX;

/// An array whose slots each hold a value taken from one of several child arrays.
///
/// `types[i]` selects the child of slot `i`. When `offsets` is present (dense mode),
/// `offsets[i]` is the position of the value inside that child; otherwise (sparse mode)
/// the position is `i + offset`.
#[derive(Clone)]
pub struct UnionArray {
    /// Type id of every slot.
    types: Buffer<i8>,
    /// Set when the dtype declares explicit type ids: maps a type id to the index of its
    /// child in `fields`. Holds one entry per non-negative `i8` value.
    map: Option<[usize; MAX_TYPE_IDS]>,
    /// One child array per field of the union dtype.
    fields: Vec<Box<dyn Array>>,
    /// Dense mode only: per-slot position inside the selected child.
    offsets: Option<Buffer<i32>>,
    dtype: ArrowDataType,
    /// Accumulated slice offset; used to locate values in children when `offsets` is absent.
    offset: usize,
}

impl UnionArray {
    /// Tries to create a new [`UnionArray`].
    ///
    /// # Errors
    /// Returns an error if:
    /// * `dtype` is not (a logical type backed by) `DataType::Union`;
    /// * the number of `fields` differs from the dtype's children or exceeds `i8::MAX`;
    /// * a child's data type differs from the corresponding dtype field;
    /// * `offsets` is set but its length differs from `types`;
    /// * `offsets` is set for a sparse union, or missing for a dense one;
    /// * the dtype declares ids that are not one per field, or not in `[0, 128[`;
    /// * with declared ids, `types` contains a negative value or a value that is not an id;
    /// * without declared ids, `types` contains a value outside `[0, fields.len()[`.
    pub fn try_new(
        dtype: ArrowDataType,
        types: Buffer<i8>,
        fields: Vec<Box<dyn Array>>,
        offsets: Option<Buffer<i32>>,
    ) -> PolarsResult<Self> {
        let (f, ids, mode) = Self::try_get_all(&dtype)?;

        if f.len() != fields.len() {
            polars_bail!(ComputeError: "the number of `fields` must equal the number of children fields in DataType::Union")
        };
        let number_of_fields: i8 = fields.len().try_into().map_err(
            |_| polars_err!(ComputeError: "the number of `fields` cannot be larger than i8::MAX"),
        )?;

        f.iter()
            .map(|a| a.dtype())
            .zip(fields.iter().map(|a| a.dtype()))
            .enumerate()
            .try_for_each(|(index, (dtype, child))| {
                if dtype != child {
                    polars_bail!(ComputeError:
                        "the children DataTypes of a UnionArray must equal the children data types.\n\
                         However, the field {index} has data type {dtype:?} but the value has data type {child:?}"
                    )
                } else {
                    Ok(())
                }
            })?;

        if let Some(offsets) = &offsets {
            if offsets.len() != types.len() {
                polars_bail!(ComputeError:
                    "in a UnionArray, the offsets' length must be equal to the number of types"
                )
            }
        }

        // Dense unions carry offsets, sparse unions do not.
        if offsets.is_none() != mode.is_sparse() {
            polars_bail!(ComputeError:
                "in a sparse UnionArray, the offsets must not be set (and in a dense UnionArray, they must be set)",
            )
        }

        let map = match ids {
            Some(ids) => {
                if ids.len() != fields.len() {
                    polars_bail!(ComputeError:
                        "in a union, when the ids are set, their length must be equal to the number of fields",
                    )
                }
                // One entry per possible non-negative i8 type id, so id 127 is addressable.
                let mut map = [UNMAPPED_TYPE_ID; MAX_TYPE_IDS];
                for (pos, &id) in ids.iter().enumerate() {
                    if !(0..MAX_TYPE_IDS as i32).contains(&id) {
                        polars_bail!(ComputeError:
                            "in a union, when the ids are set, every id must belong to [0, 128[",
                        )
                    }
                    map[id as usize] = pos;
                }
                types.iter().try_for_each(|&type_| {
                    if type_ < 0 {
                        polars_bail!(ComputeError:
                            "in a union, when the ids are set, every type must be >= 0"
                        )
                    }
                    // Unmapped ids would otherwise point at a nonexistent child.
                    if map[type_ as usize] == UNMAPPED_TYPE_ID {
                        polars_bail!(ComputeError:
                            "in a union, when the ids are set, every type must be one of the ids."
                        )
                    }
                    Ok(())
                })?;
                Some(map)
            },
            None => {
                if types
                    .iter()
                    .any(|&type_| type_ < 0 || type_ >= number_of_fields)
                {
                    polars_bail!(ComputeError:
                        "every type in `types` must be non-negative and smaller than the number of fields.",
                    )
                }
                None
            },
        };

        Ok(Self {
            dtype,
            map,
            fields,
            offsets,
            types,
            offset: 0,
        })
    }

    /// Creates a new [`UnionArray`].
    ///
    /// # Panics
    /// Panics whenever [`UnionArray::try_new`] would return an error.
    pub fn new(
        dtype: ArrowDataType,
        types: Buffer<i8>,
        fields: Vec<Box<dyn Array>>,
        offsets: Option<Buffer<i32>>,
    ) -> Self {
        Self::try_new(dtype, types, fields, offsets).unwrap()
    }

    /// Creates a [`UnionArray`] of `length` slots that all point to the first child,
    /// whose children are null arrays of `length` elements.
    ///
    /// # Panics
    /// Panics if `dtype` is not (backed by) a `DataType::Union`, or if the resulting
    /// array is invalid (see [`UnionArray::try_new`]).
    pub fn new_null(dtype: ArrowDataType, length: usize) -> Self {
        if let ArrowDataType::Union(u) = dtype.to_storage() {
            let fields = u
                .fields
                .iter()
                .map(|x| new_null_array(x.dtype().clone(), length))
                .collect();
            let offsets = if u.mode.is_sparse() {
                None
            } else {
                Some((0..length as i32).collect::<Vec<_>>().into())
            };
            // Point every slot at child 0. With declared ids, that child is addressed by
            // its own id, not necessarily by 0.
            let first_type: i8 = u
                .ids
                .as_ref()
                .and_then(|ids| ids.first())
                .map_or(0, |&id| id as i8);
            let types = vec![first_type; length].into();
            Self::new(dtype, types, fields, offsets)
        } else {
            panic!("Union struct must be created with the corresponding Union DataType")
        }
    }

    /// Creates an empty [`UnionArray`] with empty children.
    ///
    /// # Panics
    /// Panics if `dtype` is not (backed by) a `DataType::Union`.
    pub fn new_empty(dtype: ArrowDataType) -> Self {
        if let ArrowDataType::Union(u) = dtype.to_storage() {
            let fields = u
                .fields
                .iter()
                .map(|x| new_empty_array(x.dtype().clone()))
                .collect();
            let offsets = if u.mode.is_sparse() {
                None
            } else {
                Some(Buffer::default())
            };
            Self {
                dtype,
                // There are no slots, so no type id is ever looked up.
                map: None,
                fields,
                offsets,
                types: Buffer::new(),
                offset: 0,
            }
        } else {
            panic!("Union struct must be created with the corresponding Union DataType")
        }
    }
}
impl UnionArray {
    /// Slices this array in place, keeping `length` slots starting at `offset`.
    ///
    /// # Panics
    /// Panics iff `offset + length > self.len()`.
    #[inline]
    pub fn slice(&mut self, offset: usize, length: usize) {
        // Written without `offset + length`: in release builds that sum can wrap around,
        // pass the check, and hand an out-of-bounds range to `slice_unchecked`.
        assert!(
            offset <= self.len() && length <= self.len() - offset,
            "the offset of the new array cannot exceed the existing length"
        );
        // SAFETY: the assertion above guarantees `offset + length <= self.len()`.
        unsafe { self.slice_unchecked(offset, length) }
    }

    /// Slices this array in place without bounds checks.
    ///
    /// # Safety
    /// The caller must ensure that `offset + length <= self.len()`.
    #[inline]
    pub unsafe fn slice_unchecked(&mut self, offset: usize, length: usize) {
        debug_assert!(offset + length <= self.len());
        // SAFETY: the caller guarantees the range lies within `types` (and `offsets`,
        // which has the same length).
        unsafe { self.types.slice_in_place_unchecked(offset..offset + length) };
        if let Some(offsets) = self.offsets.as_mut() {
            unsafe { offsets.slice_in_place_unchecked(offset..offset + length) }
        }
        // Positions in children are `index + offset` when there is no offsets buffer.
        self.offset += offset;
    }

    impl_sliced!();
    impl_into_array!();
}
impl UnionArray {
    /// Number of slots, i.e. the length of the `types` buffer.
    #[inline]
    pub fn len(&self) -> usize {
        self.types.len()
    }

    /// The per-slot child positions, if this array carries an offsets buffer.
    pub fn offsets(&self) -> Option<&Buffer<i32>> {
        match &self.offsets {
            Some(buffer) => Some(buffer),
            None => None,
        }
    }

    /// The child arrays, one per union field.
    pub fn fields(&self) -> &Vec<Box<dyn Array>> {
        &self.fields
    }

    /// The type id of every slot.
    pub fn types(&self) -> &Buffer<i8> {
        &self.types
    }

    /// Position, inside its child, of the value stored at slot `index`.
    ///
    /// Reads the offsets buffer when present, otherwise uses `index + self.offset`.
    #[inline]
    unsafe fn field_slot_unchecked(&self, index: usize) -> usize {
        match &self.offsets {
            // SAFETY: the caller guarantees `index < self.len()`, which is also the
            // length of the offsets buffer.
            Some(slots) => unsafe { *slots.get_unchecked(index) as usize },
            None => index + self.offset,
        }
    }

    /// Returns `(child index, position inside that child)` for slot `index`.
    ///
    /// # Panics
    /// Panics if `index >= self.len()`.
    #[inline]
    pub fn index(&self, index: usize) -> (usize, usize) {
        assert!(index < self.len());
        // SAFETY: bounds were just checked.
        unsafe { self.index_unchecked(index) }
    }

    /// Returns `(child index, position inside that child)` for slot `index`.
    ///
    /// # Safety
    /// The caller must ensure that `index < self.len()`.
    #[inline]
    pub unsafe fn index_unchecked(&self, index: usize) -> (usize, usize) {
        debug_assert!(index < self.len());
        // SAFETY: the caller guarantees `index < self.len()`.
        let raw_type = unsafe { *self.types.get_unchecked(index) };
        let type_id = raw_type as usize;
        let child = match &self.map {
            // SAFETY: `try_new` only builds a map after checking that every type id
            // indexes it.
            Some(lookup) => unsafe { *lookup.get_unchecked(type_id) },
            None => type_id,
        };
        let slot = unsafe { self.field_slot_unchecked(index) };
        (child, slot)
    }

    /// Returns the value at slot `index` as a scalar.
    ///
    /// # Panics
    /// Panics if `index >= self.len()`.
    pub fn value(&self, index: usize) -> Box<dyn Scalar> {
        assert!(index < self.len());
        // SAFETY: bounds were just checked.
        unsafe { self.value_unchecked(index) }
    }

    /// Returns the value at slot `index` as a scalar.
    ///
    /// # Safety
    /// The caller must ensure that `index < self.len()`.
    pub unsafe fn value_unchecked(&self, index: usize) -> Box<dyn Scalar> {
        debug_assert!(index < self.len());
        let (type_, slot) = unsafe { self.index_unchecked(index) };
        debug_assert!(type_ < self.fields.len());
        // SAFETY: child indices are validated against `fields.len()` at construction.
        let field = unsafe { self.fields.get_unchecked(type_) };
        new_scalar(&**field, slot)
    }
}
impl Array for UnionArray {
    impl_common_array!();

    /// Always panics: a validity bitmap cannot be attached to a union array.
    fn with_validity(&self, _validity: Option<Bitmap>) -> Box<dyn Array> {
        panic!("cannot set validity of a union array")
    }

    /// Always `None`: this array stores no validity bitmap of its own.
    fn validity(&self) -> Option<&Bitmap> {
        None
    }
}
impl UnionArray {
    /// Extracts fields, optional ids and mode from a (possibly logical) union dtype.
    ///
    /// # Errors
    /// Errors if the storage type of `dtype` is not `DataType::Union`.
    fn try_get_all(dtype: &ArrowDataType) -> PolarsResult<UnionComponents<'_>> {
        if let ArrowDataType::Union(union) = dtype.to_storage() {
            let ids = union.ids.as_ref().map(|declared| declared.as_ref());
            Ok((&union.fields, ids, union.mode))
        } else {
            polars_bail!(ComputeError:
                "The UnionArray requires a logical type of DataType::Union",
            )
        }
    }

    /// Like [`Self::try_get_all`], but panics on a non-union dtype.
    fn get_all(dtype: &ArrowDataType) -> UnionComponents<'_> {
        Self::try_get_all(dtype).unwrap()
    }

    /// Returns the child fields of a union dtype.
    ///
    /// # Panics
    /// Panics if `dtype` is not (backed by) a union.
    pub fn get_fields(dtype: &ArrowDataType) -> &[Field] {
        let (fields, _, _) = Self::get_all(dtype);
        fields
    }

    /// Returns whether a union dtype is in sparse mode.
    ///
    /// # Panics
    /// Panics if `dtype` is not (backed by) a union.
    pub fn is_sparse(dtype: &ArrowDataType) -> bool {
        let (_, _, mode) = Self::get_all(dtype);
        mode.is_sparse()
    }
}
impl Splitable for UnionArray {
fn check_bound(&self, offset: usize) -> bool {
offset <= self.len()
}
unsafe fn _split_at_unchecked(&self, offset: usize) -> (Self, Self) {
let (lhs_types, rhs_types) = unsafe { self.types.split_at_unchecked(offset) };
let (lhs_offsets, rhs_offsets) = self.offsets.as_ref().map_or((None, None), |v| {
let (lhs, rhs) = unsafe { v.split_at_unchecked(offset) };
(Some(lhs), Some(rhs))
});
(
Self {
types: lhs_types,
map: self.map,
fields: self.fields.clone(),
offsets: lhs_offsets,
dtype: self.dtype.clone(),
offset: self.offset,
},
Self {
types: rhs_types,
map: self.map,
fields: self.fields.clone(),
offsets: rhs_offsets,
dtype: self.dtype.clone(),
offset: self.offset + offset,
},
)
}
}