use std::sync::Arc;
use arrow::array::as_struct_array;
use arrow_array::{
Array, ArrayRef, ArrowNumericType, FixedSizeBinaryArray, FixedSizeListArray, GenericListArray,
OffsetSizeTrait, PrimitiveArray, RecordBatch, StructArray, UInt8Array,
};
use arrow_data::ArrayDataBuilder;
use arrow_schema::{DataType, Field, FieldRef, Fields, Schema};
mod kernels;
pub mod linalg;
use crate::error::{Error, Result};
pub use kernels::*;
pub mod schema;
pub use schema::*;
pub mod bfloat16;
pub mod json;
pub trait DataTypeExt {
fn is_binary_like(&self) -> bool;
fn is_struct(&self) -> bool;
fn is_fixed_stride(&self) -> bool;
fn is_dictionary(&self) -> bool;
fn byte_width(&self) -> usize;
}
impl DataTypeExt for DataType {
fn is_binary_like(&self) -> bool {
use DataType::*;
matches!(self, Utf8 | Binary | LargeUtf8 | LargeBinary)
}
fn is_struct(&self) -> bool {
matches!(self, Self::Struct(_))
}
fn is_fixed_stride(&self) -> bool {
use DataType::*;
matches!(
self,
Boolean
| UInt8
| UInt16
| UInt32
| UInt64
| Int8
| Int16
| Int32
| Int64
| Float16
| Float32
| Float64
| Decimal128(_, _)
| Decimal256(_, _)
| FixedSizeList(_, _)
| FixedSizeBinary(_)
| Duration(_)
| Timestamp(_, _)
| Date32
| Date64
| Time32(_)
| Time64(_)
)
}
fn is_dictionary(&self) -> bool {
matches!(self, Self::Dictionary(_, _))
}
fn byte_width(&self) -> usize {
match self {
Self::Int8 => 1,
Self::Int16 => 2,
Self::Int32 => 4,
Self::Int64 => 8,
Self::UInt8 => 1,
Self::UInt16 => 2,
Self::UInt32 => 4,
Self::UInt64 => 8,
Self::Float16 => 2,
Self::Float32 => 4,
Self::Float64 => 8,
Self::Date32 => 4,
Self::Date64 => 8,
Self::Time32(_) => 4,
Self::Time64(_) => 8,
Self::Timestamp(_, _) => 8,
Self::Duration(_) => 8,
Self::Decimal128(_, _) => 16,
Self::Decimal256(_, _) => 32,
Self::FixedSizeBinary(s) => *s as usize,
Self::FixedSizeList(dt, s) => *s as usize * dt.data_type().byte_width(),
_ => panic!("Does not support get byte width on type {self}"),
}
}
}
pub fn try_new_generic_list_array<T: Array, Offset: ArrowNumericType>(
values: T,
offsets: &PrimitiveArray<Offset>,
) -> Result<GenericListArray<Offset::Native>>
where
Offset::Native: OffsetSizeTrait,
{
let data_type = if Offset::Native::IS_LARGE {
DataType::LargeList(Arc::new(Field::new(
"item",
values.data_type().clone(),
true,
)))
} else {
DataType::List(Arc::new(Field::new(
"item",
values.data_type().clone(),
true,
)))
};
let data = ArrayDataBuilder::new(data_type)
.len(offsets.len() - 1)
.add_buffer(offsets.into_data().buffers()[0].clone())
.add_child_data(values.into_data())
.build()?;
Ok(GenericListArray::from(data))
}
/// Builds a `FixedSizeList` [`DataType`] holding `list_width` values of `inner_type`.
///
/// The child field is named `"item"` and is nullable.
pub fn fixed_size_list_type(list_width: i32, inner_type: DataType) -> DataType {
    let item = Field::new("item", inner_type, true);
    DataType::FixedSizeList(Arc::new(item), list_width)
}
pub trait FixedSizeListArrayExt {
fn try_new_from_values<T: Array + 'static>(
values: T,
list_size: i32,
) -> Result<FixedSizeListArray>;
}
impl FixedSizeListArrayExt for FixedSizeListArray {
fn try_new_from_values<T: Array + 'static>(values: T, list_size: i32) -> Result<Self> {
let field = Arc::new(Field::new("item", values.data_type().clone(), true));
let values = Arc::new(values);
Ok(Self::try_new(field, list_size, values, None)?)
}
}
/// Downcasts `arr` to a [`FixedSizeListArray`].
///
/// # Panics
/// Panics, naming the actual data type, if `arr` is not a `FixedSizeListArray`.
pub fn as_fixed_size_list_array(arr: &dyn Array) -> &FixedSizeListArray {
    arr.as_any()
        .downcast_ref::<FixedSizeListArray>()
        .unwrap_or_else(|| {
            panic!(
                "Expected a FixedSizeListArray, got an array of type {}",
                arr.data_type()
            )
        })
}
pub trait FixedSizeBinaryArrayExt {
fn try_new_from_values(values: &UInt8Array, stride: i32) -> Result<FixedSizeBinaryArray>;
}
impl FixedSizeBinaryArrayExt for FixedSizeBinaryArray {
fn try_new_from_values(values: &UInt8Array, stride: i32) -> Result<Self> {
let data_type = DataType::FixedSizeBinary(stride);
let data = ArrayDataBuilder::new(data_type)
.len(values.len() / stride as usize)
.add_buffer(values.into_data().buffers()[0].clone())
.build()?;
Ok(Self::from(data))
}
}
/// Downcasts `arr` to a [`FixedSizeBinaryArray`].
///
/// # Panics
/// Panics, naming the actual data type, if `arr` is not a `FixedSizeBinaryArray`.
pub fn as_fixed_size_binary_array(arr: &dyn Array) -> &FixedSizeBinaryArray {
    arr.as_any()
        .downcast_ref::<FixedSizeBinaryArray>()
        .unwrap_or_else(|| {
            panic!(
                "Expected a FixedSizeBinaryArray, got an array of type {}",
                arr.data_type()
            )
        })
}
/// Extension methods on [`RecordBatch`] for adding, removing, merging and projecting columns.
pub trait RecordBatchExt {
    /// Returns a new batch with `arr` appended as the last column, described by `field`.
    ///
    /// The schema-level metadata of `self` is carried over.
    ///
    /// # Errors
    /// Propagates any error from [`RecordBatch::try_new`], e.g. when `arr` has a
    /// different row count than the batch.
    fn try_with_column(&self, field: Field, arr: ArrayRef) -> Result<RecordBatch>;

    /// Builds a batch whose columns are the children of `arr`.
    ///
    /// The schema uses the fields of `arr` and the schema-level metadata of `self`.
    /// Any top-level null buffer of `arr` is not carried over.
    fn try_new_from_struct_array(&self, arr: StructArray) -> Result<RecordBatch>;

    /// Merges the columns of `other` into `self`.
    ///
    /// Columns of `self` come first, followed by columns that only exist in `other`.
    /// Struct columns present on both sides are merged recursively; for any other
    /// name collision the column from `self` is kept.
    ///
    /// # Errors
    /// Returns [`Error::Arrow`] if the two batches have different row counts.
    fn merge(&self, other: &RecordBatch) -> Result<RecordBatch>;

    /// Returns a copy of the batch without any top-level column named `name`.
    ///
    /// If no column has that name the batch is returned with the same columns.
    fn drop_column(&self, name: &str) -> Result<RecordBatch>;

    /// Looks up a (possibly nested) column by a dot-separated path such as `"b.c"`.
    ///
    /// Returns `None` if any path component is missing or an intermediate column is
    /// not a struct. Field names that themselves contain `.` cannot be addressed.
    fn column_by_qualified_name(&self, name: &str) -> Option<&ArrayRef>;

    /// Returns a batch with only the fields of `schema`, in `schema`'s order,
    /// recursing into struct fields. Schema-level metadata is taken from `self`.
    ///
    /// # Errors
    /// Returns [`Error::Arrow`] if a field of `schema` does not exist in the batch.
    fn project_by_schema(&self, schema: &Schema) -> Result<RecordBatch>;
}

impl RecordBatchExt for RecordBatch {
    fn try_with_column(&self, field: Field, arr: ArrayRef) -> Result<Self> {
        let schema = self.schema();
        let mut new_fields: Vec<FieldRef> = schema.fields.iter().cloned().collect();
        new_fields.push(FieldRef::new(field));
        let new_schema = Arc::new(Schema::new_with_metadata(
            Fields::from(new_fields.as_slice()),
            schema.metadata.clone(),
        ));
        let mut new_columns = self.columns().to_vec();
        new_columns.push(arr);
        Ok(Self::try_new(new_schema, new_columns)?)
    }

    fn try_new_from_struct_array(&self, arr: StructArray) -> Result<Self> {
        let schema = Arc::new(Schema::new_with_metadata(
            arr.fields().to_vec(),
            self.schema().metadata.clone(),
        ));
        Ok(Self::try_new(schema, arr.columns().to_vec())?)
    }

    fn merge(&self, other: &Self) -> Result<Self> {
        if self.num_rows() != other.num_rows() {
            return Err(Error::Arrow {
                message: format!(
                    "Attempt to merge two RecordBatch with different sizes: {} != {}",
                    self.num_rows(),
                    other.num_rows()
                ),
            });
        }
        let left_struct_array: StructArray = self.clone().into();
        let right_struct_array: StructArray = other.clone().into();
        self.try_new_from_struct_array(merge(&left_struct_array, &right_struct_array)?)
    }

    fn drop_column(&self, name: &str) -> Result<Self> {
        // Fetch the schema once rather than once per column and per access.
        let schema = self.schema();
        let (fields, columns): (Vec<FieldRef>, Vec<ArrayRef>) = schema
            .fields
            .iter()
            .zip(self.columns())
            .filter(|(field, _)| field.name() != name)
            .map(|(field, column)| (field.clone(), column.clone()))
            .unzip();
        Ok(Self::try_new(
            Arc::new(Schema::new_with_metadata(
                Fields::from(fields.as_slice()),
                schema.metadata().clone(),
            )),
            columns,
        )?)
    }

    fn column_by_qualified_name(&self, name: &str) -> Option<&ArrayRef> {
        let mut components = name.split('.');
        // `split` always yields at least one item (possibly ""), so this `?` only
        // guards the type, never a real early return.
        let root = components.next()?;
        let nested: Vec<&str> = components.collect();
        self.column_by_name(root)
            .and_then(|arr| get_sub_array(arr, &nested))
    }

    fn project_by_schema(&self, schema: &Schema) -> Result<Self> {
        let struct_array: StructArray = self.clone().into();
        self.try_new_from_struct_array(project(&struct_array, schema.fields())?)
    }
}
fn project(struct_array: &StructArray, fields: &Fields) -> Result<StructArray> {
let mut columns: Vec<ArrayRef> = vec![];
for field in fields.iter() {
if let Some(col) = struct_array.column_by_name(field.name()) {
match field.data_type() {
DataType::Struct(subfields) => {
let projected = project(as_struct_array(col), subfields)?;
columns.push(Arc::new(projected));
}
_ => {
columns.push(col.clone());
}
}
} else {
return Err(Error::Arrow {
message: format!("field {} does not exist in the RecordBatch", field.name()),
});
}
}
Ok(StructArray::from(
fields.iter().cloned().zip(columns).collect::<Vec<_>>(),
))
}
/// Recursively merges two struct arrays column by column.
///
/// Output order: every field of `left_struct_array` in its original order, followed
/// by the fields that exist only in `right_struct_array` in their original order.
/// When a field name exists on both sides and both fields are `Struct`, the two
/// children are merged recursively. The merged field keeps the left field's name,
/// nullability and metadata. For any other name collision the left column wins and
/// the right column is discarded.
///
/// Assumes both inputs have the same length. The `RecordBatch` caller checks this at
/// the top level, and nested children share their parent's length.
///
/// NOTE(review): the null buffers of the input struct arrays are not carried into the
/// result, so nested struct columns with null slots come out as non-null structs.
/// Confirm the intended rule (left-only, or union) before changing it.
///
/// # Errors
/// Returns [`Error::Arrow`] if the merged struct array cannot be assembled.
fn merge(left_struct_array: &StructArray, right_struct_array: &StructArray) -> Result<StructArray> {
    let left_fields = left_struct_array.fields();
    let right_fields = right_struct_array.fields();
    let right_columns = right_struct_array.columns();

    let capacity = left_fields.len() + right_fields.len();
    let mut fields: Vec<FieldRef> = Vec::with_capacity(capacity);
    let mut columns: Vec<ArrayRef> = Vec::with_capacity(capacity);

    for (left_field, left_column) in left_fields.iter().zip(left_struct_array.columns()) {
        let right_match = right_fields
            .iter()
            .zip(right_columns)
            .find(|(f, _)| f.name() == left_field.name());
        match right_match {
            Some((right_field, right_column))
                if matches!(left_field.data_type(), DataType::Struct(_))
                    && matches!(right_field.data_type(), DataType::Struct(_)) =>
            {
                let merged_sub_array = merge(
                    as_struct_array(left_column),
                    as_struct_array(right_column),
                )?;
                // Only the data type changes; metadata and other field settings of the
                // left field are preserved (a fresh `Field::new` would drop them).
                let merged_field = left_field
                    .as_ref()
                    .clone()
                    .with_data_type(merged_sub_array.data_type().clone());
                fields.push(Arc::new(merged_field));
                columns.push(Arc::new(merged_sub_array));
            }
            _ => {
                fields.push(Arc::clone(left_field));
                columns.push(Arc::clone(left_column));
            }
        }
    }

    for (right_field, right_column) in right_fields.iter().zip(right_columns) {
        if !left_fields.iter().any(|f| f.name() == right_field.name()) {
            fields.push(Arc::clone(right_field));
            columns.push(Arc::clone(right_column));
        }
    }

    let zipped: Vec<(FieldRef, ArrayRef)> = fields.into_iter().zip(columns).collect();
    StructArray::try_from(zipped).map_err(|e| Error::Arrow {
        message: format!("Failed to merge RecordBatch: {}", e),
    })
}
/// Walks a path of child names down through nested struct arrays, starting at `array`.
///
/// An empty `components` slice yields `array` itself. Returns `None` as soon as a
/// non-struct array must be descended into, or a named child does not exist.
fn get_sub_array<'a>(array: &'a ArrayRef, components: &[&str]) -> Option<&'a ArrayRef> {
    let mut current = array;
    for component in components {
        if !matches!(current.data_type(), DataType::Struct(_)) {
            return None;
        }
        current = as_struct_array(current.as_ref()).column_by_name(component)?;
    }
    Some(current)
}
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::{ArrayRef, Int32Array, StringArray, StructArray};
    use arrow_schema::{DataType, Field};

    /// Wraps `column` in a one-child `StructArray` whose nullable child field is `name`.
    fn single_child_struct(name: &str, data_type: DataType, column: ArrayRef) -> ArrayRef {
        Arc::new(StructArray::from(vec![(
            Arc::new(Field::new(name, data_type, true)),
            column,
        )]))
    }

    /// Merging `{a, b: {c}}` with `{e, b: {d}}` must give `{a, b: {c, d}}` followed by `e`.
    #[test]
    fn test_merge_recursive() {
        let a: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), Some(2), Some(3)]));
        let e: ArrayRef = Arc::new(Int32Array::from(vec![Some(4), Some(5), Some(6)]));
        let c: ArrayRef = Arc::new(Int32Array::from(vec![Some(7), Some(8), Some(9)]));
        let d: ArrayRef = Arc::new(StringArray::from(vec![Some("a"), Some("b"), Some("c")]));

        let left_batch = RecordBatch::try_new(
            Arc::new(Schema::new(vec![
                Field::new("a", DataType::Int32, true),
                Field::new(
                    "b",
                    DataType::Struct(vec![Field::new("c", DataType::Int32, true)].into()),
                    true,
                ),
            ])),
            vec![
                a.clone(),
                single_child_struct("c", DataType::Int32, c.clone()),
            ],
        )
        .unwrap();

        let right_batch = RecordBatch::try_new(
            Arc::new(Schema::new(vec![
                Field::new("e", DataType::Int32, true),
                Field::new(
                    "b",
                    DataType::Struct(vec![Field::new("d", DataType::Utf8, true)].into()),
                    true,
                ),
            ])),
            vec![
                e.clone(),
                single_child_struct("d", DataType::Utf8, d.clone()),
            ],
        )
        .unwrap();

        let expected_b: ArrayRef = Arc::new(StructArray::from(vec![
            (Arc::new(Field::new("c", DataType::Int32, true)), c),
            (Arc::new(Field::new("d", DataType::Utf8, true)), d),
        ]));
        let expected_batch = RecordBatch::try_new(
            Arc::new(Schema::new(vec![
                Field::new("a", DataType::Int32, true),
                Field::new(
                    "b",
                    DataType::Struct(
                        vec![
                            Field::new("c", DataType::Int32, true),
                            Field::new("d", DataType::Utf8, true),
                        ]
                        .into(),
                    ),
                    true,
                ),
                Field::new("e", DataType::Int32, true),
            ])),
            vec![a, expected_b, e],
        )
        .unwrap();

        let merged = left_batch.merge(&right_batch).unwrap();
        assert_eq!(merged, expected_batch);
    }
}