use vortex_dtype::DType;
use vortex_error::{VortexResult, vortex_panic};
use crate::arrays::{
BoolArray, DecimalArray, ExtensionArray, FixedSizeListArray, ListViewArray,
ListViewRebuildMode, NullArray, PrimitiveArray, StructArray, VarBinViewArray,
};
use crate::builders::builder_with_capacity;
use crate::{Array, ArrayRef, IntoArray};
/// The closed set of canonical array representations.
///
/// Each variant wraps exactly one concrete array type. The typed accessors
/// (`as_*` / `into_*`) and the conversions to `dyn Array` / `ArrayRef` defined
/// in this file dispatch on these variants.
#[derive(Debug, Clone)]
pub enum Canonical {
/// Canonical form held as a [`NullArray`].
Null(NullArray),
/// Canonical form held as a [`BoolArray`].
Bool(BoolArray),
/// Canonical form held as a [`PrimitiveArray`].
Primitive(PrimitiveArray),
/// Canonical form held as a [`DecimalArray`].
Decimal(DecimalArray),
/// Canonical form held as a [`VarBinViewArray`].
VarBinView(VarBinViewArray),
/// Canonical form held as a [`ListViewArray`] (note: the variant is named
/// `List`, but the payload is the list-view array type).
List(ListViewArray),
/// Canonical form held as a [`FixedSizeListArray`].
FixedSizeList(FixedSizeListArray),
/// Canonical form held as a [`StructArray`].
Struct(StructArray),
/// Canonical form held as an [`ExtensionArray`].
Extension(ExtensionArray),
}
impl Canonical {
    /// Builds a canonical array of the given `dtype` without appending any values.
    ///
    /// A builder for `dtype` is created with a capacity of zero and finished
    /// immediately into its canonical form.
    pub fn empty(dtype: &DType) -> Self {
        let capacity = 0;
        builder_with_capacity(dtype, capacity).finish_into_canonical()
    }
}
impl Canonical {
    /// Returns a compacted copy of this canonical array.
    ///
    /// * `VarBinView` — the result of calling `compact_buffers` on the inner array.
    /// * `List` — the inner array rebuilt with
    ///   [`ListViewRebuildMode::MakeZeroCopyToList`].
    /// * every other variant — a plain clone of `self`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `compact_buffers` for the `VarBinView`
    /// variant.
    pub fn compact(&self) -> VortexResult<Canonical> {
        let compacted = match self {
            Canonical::VarBinView(views) => Canonical::VarBinView(views.compact_buffers()?),
            Canonical::List(lists) => {
                Canonical::List(lists.rebuild(ListViewRebuildMode::MakeZeroCopyToList))
            }
            // Nothing to compact for the remaining variants.
            unchanged => unchanged.clone(),
        };
        Ok(compacted)
    }
}
/// Typed accessors for each [`Canonical`] variant.
///
/// `as_*` methods borrow the inner array and `into_*` methods consume `self`
/// and return it by value. Every accessor panics (via `vortex_panic!`) when
/// `self` holds a different variant; the panic message includes the `Debug`
/// rendering of `self`.
impl Canonical {
    /// Borrows the inner [`NullArray`].
    ///
    /// # Panics
    ///
    /// Panics if `self` is not [`Canonical::Null`].
    pub fn as_null(&self) -> &NullArray {
        if let Canonical::Null(a) = self {
            a
        } else {
            vortex_panic!("Cannot get NullArray from {:?}", &self)
        }
    }

    /// Consumes `self` and returns the inner [`NullArray`].
    ///
    /// # Panics
    ///
    /// Panics if `self` is not [`Canonical::Null`].
    pub fn into_null(self) -> NullArray {
        if let Canonical::Null(a) = self {
            a
        } else {
            vortex_panic!("Cannot unwrap NullArray from {:?}", &self)
        }
    }

    /// Borrows the inner [`BoolArray`].
    ///
    /// # Panics
    ///
    /// Panics if `self` is not [`Canonical::Bool`].
    pub fn as_bool(&self) -> &BoolArray {
        if let Canonical::Bool(a) = self {
            a
        } else {
            vortex_panic!("Cannot get BoolArray from {:?}", &self)
        }
    }

    /// Consumes `self` and returns the inner [`BoolArray`].
    ///
    /// # Panics
    ///
    /// Panics if `self` is not [`Canonical::Bool`].
    pub fn into_bool(self) -> BoolArray {
        if let Canonical::Bool(a) = self {
            a
        } else {
            vortex_panic!("Cannot unwrap BoolArray from {:?}", &self)
        }
    }

    /// Borrows the inner [`PrimitiveArray`].
    ///
    /// # Panics
    ///
    /// Panics if `self` is not [`Canonical::Primitive`].
    pub fn as_primitive(&self) -> &PrimitiveArray {
        if let Canonical::Primitive(a) = self {
            a
        } else {
            vortex_panic!("Cannot get PrimitiveArray from {:?}", &self)
        }
    }

    /// Consumes `self` and returns the inner [`PrimitiveArray`].
    ///
    /// # Panics
    ///
    /// Panics if `self` is not [`Canonical::Primitive`].
    pub fn into_primitive(self) -> PrimitiveArray {
        if let Canonical::Primitive(a) = self {
            a
        } else {
            vortex_panic!("Cannot unwrap PrimitiveArray from {:?}", &self)
        }
    }

    /// Borrows the inner [`DecimalArray`].
    ///
    /// # Panics
    ///
    /// Panics if `self` is not [`Canonical::Decimal`].
    pub fn as_decimal(&self) -> &DecimalArray {
        if let Canonical::Decimal(a) = self {
            a
        } else {
            vortex_panic!("Cannot get DecimalArray from {:?}", &self)
        }
    }

    /// Consumes `self` and returns the inner [`DecimalArray`].
    ///
    /// # Panics
    ///
    /// Panics if `self` is not [`Canonical::Decimal`].
    pub fn into_decimal(self) -> DecimalArray {
        if let Canonical::Decimal(a) = self {
            a
        } else {
            vortex_panic!("Cannot unwrap DecimalArray from {:?}", &self)
        }
    }

    /// Borrows the inner [`VarBinViewArray`].
    ///
    /// # Panics
    ///
    /// Panics if `self` is not [`Canonical::VarBinView`].
    pub fn as_varbinview(&self) -> &VarBinViewArray {
        if let Canonical::VarBinView(a) = self {
            a
        } else {
            vortex_panic!("Cannot get VarBinViewArray from {:?}", &self)
        }
    }

    /// Consumes `self` and returns the inner [`VarBinViewArray`].
    ///
    /// # Panics
    ///
    /// Panics if `self` is not [`Canonical::VarBinView`].
    pub fn into_varbinview(self) -> VarBinViewArray {
        if let Canonical::VarBinView(a) = self {
            a
        } else {
            vortex_panic!("Cannot unwrap VarBinViewArray from {:?}", &self)
        }
    }

    /// Borrows the inner [`ListViewArray`].
    ///
    /// # Panics
    ///
    /// Panics if `self` is not [`Canonical::List`].
    pub fn as_listview(&self) -> &ListViewArray {
        if let Canonical::List(a) = self {
            a
        } else {
            // The message names the type this accessor actually returns.
            vortex_panic!("Cannot get ListViewArray from {:?}", &self)
        }
    }

    /// Consumes `self` and returns the inner [`ListViewArray`].
    ///
    /// # Panics
    ///
    /// Panics if `self` is not [`Canonical::List`].
    pub fn into_listview(self) -> ListViewArray {
        if let Canonical::List(a) = self {
            a
        } else {
            // The message names the type this accessor actually returns.
            vortex_panic!("Cannot unwrap ListViewArray from {:?}", &self)
        }
    }

    /// Borrows the inner [`FixedSizeListArray`].
    ///
    /// # Panics
    ///
    /// Panics if `self` is not [`Canonical::FixedSizeList`].
    pub fn as_fixed_size_list(&self) -> &FixedSizeListArray {
        if let Canonical::FixedSizeList(a) = self {
            a
        } else {
            vortex_panic!("Cannot get FixedSizeListArray from {:?}", &self)
        }
    }

    /// Consumes `self` and returns the inner [`FixedSizeListArray`].
    ///
    /// # Panics
    ///
    /// Panics if `self` is not [`Canonical::FixedSizeList`].
    pub fn into_fixed_size_list(self) -> FixedSizeListArray {
        if let Canonical::FixedSizeList(a) = self {
            a
        } else {
            vortex_panic!("Cannot unwrap FixedSizeListArray from {:?}", &self)
        }
    }

    /// Borrows the inner [`StructArray`].
    ///
    /// # Panics
    ///
    /// Panics if `self` is not [`Canonical::Struct`].
    pub fn as_struct(&self) -> &StructArray {
        if let Canonical::Struct(a) = self {
            a
        } else {
            vortex_panic!("Cannot get StructArray from {:?}", &self)
        }
    }

    /// Consumes `self` and returns the inner [`StructArray`].
    ///
    /// # Panics
    ///
    /// Panics if `self` is not [`Canonical::Struct`].
    pub fn into_struct(self) -> StructArray {
        if let Canonical::Struct(a) = self {
            a
        } else {
            vortex_panic!("Cannot unwrap StructArray from {:?}", &self)
        }
    }

    /// Borrows the inner [`ExtensionArray`].
    ///
    /// # Panics
    ///
    /// Panics if `self` is not [`Canonical::Extension`].
    pub fn as_extension(&self) -> &ExtensionArray {
        if let Canonical::Extension(a) = self {
            a
        } else {
            vortex_panic!("Cannot get ExtensionArray from {:?}", &self)
        }
    }

    /// Consumes `self` and returns the inner [`ExtensionArray`].
    ///
    /// # Panics
    ///
    /// Panics if `self` is not [`Canonical::Extension`].
    pub fn into_extension(self) -> ExtensionArray {
        if let Canonical::Extension(a) = self {
            a
        } else {
            vortex_panic!("Cannot unwrap ExtensionArray from {:?}", &self)
        }
    }
}
/// Views any [`Canonical`] variant as a `dyn Array` by borrowing the wrapped array.
impl AsRef<dyn Array> for Canonical {
    fn as_ref(&self) -> &(dyn Array + 'static) {
        // Each arm forwards to the wrapped array's own `as_ref`.
        match self {
            Canonical::Null(inner) => inner.as_ref(),
            Canonical::Bool(inner) => inner.as_ref(),
            Canonical::Primitive(inner) => inner.as_ref(),
            Canonical::Decimal(inner) => inner.as_ref(),
            Canonical::VarBinView(inner) => inner.as_ref(),
            Canonical::List(inner) => inner.as_ref(),
            Canonical::FixedSizeList(inner) => inner.as_ref(),
            Canonical::Struct(inner) => inner.as_ref(),
            Canonical::Extension(inner) => inner.as_ref(),
        }
    }
}
/// Converts any [`Canonical`] variant into an [`ArrayRef`] by converting the
/// wrapped array.
impl IntoArray for Canonical {
    fn into_array(self) -> ArrayRef {
        // Each arm forwards to the wrapped array's own `into_array`.
        match self {
            Canonical::Null(inner) => inner.into_array(),
            Canonical::Bool(inner) => inner.into_array(),
            Canonical::Primitive(inner) => inner.into_array(),
            Canonical::Decimal(inner) => inner.into_array(),
            Canonical::VarBinView(inner) => inner.into_array(),
            Canonical::List(inner) => inner.into_array(),
            Canonical::FixedSizeList(inner) => inner.into_array(),
            Canonical::Struct(inner) => inner.into_array(),
            Canonical::Extension(inner) => inner.into_array(),
        }
    }
}
/// Convenience conversions from an array to a specific canonical array type.
///
/// The blanket implementation in this file (for every `A: Array + ?Sized`)
/// calls `to_canonical()` and then the matching `Canonical::into_*` accessor,
/// so each method panics if the canonical form is a different variant than
/// the one requested.
pub trait ToCanonical {
/// Returns the canonical form as a [`NullArray`].
fn to_null(&self) -> NullArray;
/// Returns the canonical form as a [`BoolArray`].
fn to_bool(&self) -> BoolArray;
/// Returns the canonical form as a [`PrimitiveArray`].
fn to_primitive(&self) -> PrimitiveArray;
/// Returns the canonical form as a [`DecimalArray`].
fn to_decimal(&self) -> DecimalArray;
/// Returns the canonical form as a [`StructArray`].
fn to_struct(&self) -> StructArray;
/// Returns the canonical form as a [`ListViewArray`].
fn to_listview(&self) -> ListViewArray;
/// Returns the canonical form as a [`FixedSizeListArray`].
fn to_fixed_size_list(&self) -> FixedSizeListArray;
/// Returns the canonical form as a [`VarBinViewArray`].
fn to_varbinview(&self) -> VarBinViewArray;
/// Returns the canonical form as an [`ExtensionArray`].
fn to_extension(&self) -> ExtensionArray;
}
/// Blanket implementation: canonicalize the array, then unwrap the requested
/// variant. Each method panics if the canonical form is a different variant.
impl<A: Array + ?Sized> ToCanonical for A {
    fn to_null(&self) -> NullArray {
        let canonical = self.to_canonical();
        canonical.into_null()
    }

    fn to_bool(&self) -> BoolArray {
        let canonical = self.to_canonical();
        canonical.into_bool()
    }

    fn to_primitive(&self) -> PrimitiveArray {
        let canonical = self.to_canonical();
        canonical.into_primitive()
    }

    fn to_decimal(&self) -> DecimalArray {
        let canonical = self.to_canonical();
        canonical.into_decimal()
    }

    fn to_struct(&self) -> StructArray {
        let canonical = self.to_canonical();
        canonical.into_struct()
    }

    fn to_listview(&self) -> ListViewArray {
        let canonical = self.to_canonical();
        canonical.into_listview()
    }

    fn to_fixed_size_list(&self) -> FixedSizeListArray {
        let canonical = self.to_canonical();
        canonical.into_fixed_size_list()
    }

    fn to_varbinview(&self) -> VarBinViewArray {
        let canonical = self.to_canonical();
        canonical.into_varbinview()
    }

    fn to_extension(&self) -> ExtensionArray {
        let canonical = self.to_canonical();
        canonical.into_extension()
    }
}
impl From<Canonical> for ArrayRef {
fn from(value: Canonical) -> Self {
match value {
Canonical::Null(a) => a.into_array(),
Canonical::Bool(a) => a.into_array(),
Canonical::Primitive(a) => a.into_array(),
Canonical::Decimal(a) => a.into_array(),
Canonical::Struct(a) => a.into_array(),
Canonical::List(a) => a.into_array(),
Canonical::FixedSizeList(a) => a.into_array(),
Canonical::VarBinView(a) => a.into_array(),
Canonical::Extension(a) => a.into_array(),
}
}
}
// Tests that convert Vortex arrays to Arrow arrays (and back) and check the
// resulting Arrow types and values.
#[cfg(test)]
mod test {
use std::sync::Arc;
use arrow_array::cast::AsArray;
use arrow_array::types::{Int32Type, Int64Type, UInt64Type};
use arrow_array::{
Array as ArrowArray, ArrayRef as ArrowArrayRef, ListArray as ArrowListArray,
PrimitiveArray as ArrowPrimitiveArray, StringArray, StringViewArray,
StructArray as ArrowStructArray,
};
use arrow_buffer::{NullBufferBuilder, OffsetBuffer};
use arrow_schema::{DataType, Field};
use vortex_buffer::buffer;
use crate::arrays::{ConstantArray, StructArray};
use crate::arrow::{FromArrowArray, IntoArrowArray};
use crate::{ArrayRef, IntoArray};
// A struct whose field "b" is itself a struct containing a `ConstantArray`
// must convert to Arrow with concrete primitive columns at both levels.
#[test]
fn test_canonicalize_nested_struct() {
// Outer struct: "a" is a one-element u64 buffer, "b" is a nested struct
// with a single field "inner_a" holding the constant 100i64 (length 1).
let nested_struct_array = StructArray::from_fields(&[
("a", buffer![1u64].into_array()),
(
"b",
StructArray::from_fields(&[(
"inner_a",
ConstantArray::new(100i64, 1).into_array(),
)])
.unwrap()
.into_array(),
),
])
.unwrap();
// Convert to Arrow and require the result to be an Arrow struct array.
let arrow_struct = nested_struct_array
.into_array()
.into_arrow_preferred()
.unwrap()
.as_any()
.downcast_ref::<ArrowStructArray>()
.cloned()
.unwrap();
// Column 0 ("a") must be an Arrow UInt64 primitive array.
assert!(
arrow_struct
.column(0)
.as_any()
.downcast_ref::<ArrowPrimitiveArray<UInt64Type>>()
.is_some()
);
// Column 1 ("b") must itself be an Arrow struct array.
let inner_struct = arrow_struct
.column(1)
.clone()
.as_any()
.downcast_ref::<ArrowStructArray>()
.cloned()
.unwrap();
// The nested constant must surface as an Int64 primitive array equal to [100].
let inner_a = inner_struct
.column(0)
.as_any()
.downcast_ref::<ArrowPrimitiveArray<Int64Type>>();
assert!(inner_a.is_some());
assert_eq!(
inner_a.cloned().unwrap(),
ArrowPrimitiveArray::from_iter([100i64])
);
}
// An Arrow struct array with struct-level and field-level nulls must survive
// an Arrow -> Vortex -> Arrow round trip unchanged.
#[test]
fn roundtrip_struct() {
// Struct-level validity for 6 rows: rows 0-3 valid, row 4 null, row 5 valid.
let mut nulls = NullBufferBuilder::new(6);
nulls.append_n_non_nulls(4);
nulls.append_null();
nulls.append_non_null();
let names = Arc::new(StringViewArray::from_iter(vec![
Some("Joseph"),
None,
Some("Angela"),
Some("Mikhail"),
None,
None,
]));
let ages = Arc::new(ArrowPrimitiveArray::<Int32Type>::from(vec![
Some(25),
Some(31),
None,
Some(57),
None,
None,
]));
let arrow_struct = ArrowStructArray::new(
vec![
Arc::new(Field::new("name", DataType::Utf8View, true)),
Arc::new(Field::new("age", DataType::Int32, true)),
]
.into(),
vec![names, ages],
nulls.finish(),
);
// NOTE(review): the second argument to `from_arrow` is presumably the
// nullability flag — confirm against `FromArrowArray`.
let vortex_struct = ArrayRef::from_arrow(&arrow_struct, true);
assert_eq!(
&arrow_struct,
vortex_struct.into_arrow_preferred().unwrap().as_struct()
);
}
// An Arrow list array must survive an Arrow -> Vortex -> Arrow round trip
// when converted back using its original Arrow data type.
#[test]
fn roundtrip_list() {
let names = Arc::new(StringArray::from_iter(vec![
Some("Joseph"),
Some("Angela"),
Some("Mikhail"),
]));
// Three lists with lengths 0, 2 and 1 over the three names above.
let arrow_list = ArrowListArray::new(
Arc::new(Field::new_list_field(DataType::Utf8, true)),
OffsetBuffer::from_lengths(vec![0, 2, 1]),
names,
None,
);
let list_data_type = arrow_list.data_type();
let vortex_list = ArrayRef::from_arrow(&arrow_list, true);
// Request the original Arrow data type explicitly on the way back.
let rt_arrow_list = vortex_list.into_arrow(list_data_type).unwrap();
assert_eq!(
(Arc::new(arrow_list.clone()) as ArrowArrayRef).as_ref(),
rt_arrow_list.as_ref()
);
}
}