use std::sync::Arc;
use arrow::array::{Array, ArrayRef, ArrowPrimitiveType, PrimitiveArray};
use arrow::compute::cast;
use arrow::datatypes::Field;
use crate::{Error, Transform};
/// Zero-sized transform marker describing a cast from source type `S` to target type `T`.
///
/// The type parameters are only tracked through [`std::marker::PhantomData`]; no value of
/// `S` or `T` is ever stored.
pub struct PrimitiveCast<S, T> {
    _phantom: std::marker::PhantomData<(S, T)>,
}

// `Clone` and `Default` are written by hand on purpose: `#[derive(..)]` would add
// `S: Clone + Default` / `T: Clone + Default` bounds even though only `PhantomData` is held,
// making the impls unavailable for parameters that lack those traits.
impl<S, T> Clone for PrimitiveCast<S, T> {
    fn clone(&self) -> Self {
        Self {
            _phantom: std::marker::PhantomData,
        }
    }
}

impl<S, T> Default for PrimitiveCast<S, T> {
    fn default() -> Self {
        Self {
            _phantom: std::marker::PhantomData,
        }
    }
}
impl<S, T> PrimitiveCast<PrimitiveArray<S>, PrimitiveArray<T>>
where
    S: ArrowPrimitiveType,
    T: ArrowPrimitiveType,
{
    /// Creates a cast from `PrimitiveArray<S>` to `PrimitiveArray<T>`.
    ///
    /// The value carries no runtime state; only the type parameters matter.
    pub fn new() -> Self {
        let _phantom = std::marker::PhantomData;
        Self { _phantom }
    }
}
impl<S, T> Transform for PrimitiveCast<PrimitiveArray<S>, PrimitiveArray<T>>
where
S: ArrowPrimitiveType,
T: ArrowPrimitiveType,
{
type Source = PrimitiveArray<S>;
type Target = PrimitiveArray<T>;
fn transform(&self, source: &PrimitiveArray<S>) -> Result<PrimitiveArray<T>, Error> {
let source_ref: &dyn Array = source;
let target_type = T::DATA_TYPE;
let casted = cast(source_ref, &target_type)?;
DowncastRef::<T>::new().transform(&casted)
}
}
/// Zero-sized transform marker for downcasting a type-erased array to a concrete
/// array parameterised by `T`.
///
/// `T` is only tracked through [`std::marker::PhantomData`]; no value of `T` is stored.
pub struct DowncastRef<T> {
    _phantom: std::marker::PhantomData<T>,
}

// `Clone` and `Default` are written by hand on purpose: `#[derive(..)]` would add
// `T: Clone` / `T: Default` bounds even though only `PhantomData` is held, making the
// impls unavailable for marker types that implement neither trait.
impl<T> Clone for DowncastRef<T> {
    fn clone(&self) -> Self {
        Self {
            _phantom: std::marker::PhantomData,
        }
    }
}

impl<T> Default for DowncastRef<T> {
    fn default() -> Self {
        Self {
            _phantom: std::marker::PhantomData,
        }
    }
}
impl<T> DowncastRef<T> {
    /// Creates a downcast marker for `T`.
    ///
    /// The value carries no runtime state; only the type parameter matters.
    pub fn new() -> Self {
        let _phantom = std::marker::PhantomData;
        Self { _phantom }
    }
}
impl<T> Transform for DowncastRef<T>
where
T: ArrowPrimitiveType,
{
type Source = ArrayRef;
type Target = PrimitiveArray<T>;
fn transform(&self, source: &ArrayRef) -> Result<PrimitiveArray<T>, Error> {
source
.as_any()
.downcast_ref::<PrimitiveArray<T>>()
.ok_or_else(|| Error::TypeMismatch {
expected: std::any::type_name::<PrimitiveArray<T>>().to_owned(),
actual: source.data_type().clone(),
context: "downcast_ref".to_owned(),
})
.cloned()
}
}
/// Transform that turns a variable-size list array into a fixed-size list array.
#[derive(Clone)]
pub struct ListToFixedSizeList {
    /// Number of child elements every non-null list entry must contain.
    value_length: i32,
}
impl ListToFixedSizeList {
pub fn new(value_length: i32) -> Self {
Self { value_length }
}
}
/// Reinterprets a `ListArray` as a `FixedSizeListArray` of `value_length` elements per entry.
impl Transform for ListToFixedSizeList {
    type Source = arrow::array::ListArray;
    type Target = arrow::array::FixedSizeListArray;

    /// Checks that every non-null list entry spans exactly `value_length` child
    /// elements, then rebuilds the data as a fixed-size list that reuses the source's
    /// child values and validity.
    ///
    /// Null entries are skipped by the length check.
    /// NOTE(review): a null entry whose offsets do not span `value_length` child slots
    /// leaves the child length out of step with the entry count; presumably
    /// `FixedSizeListArray::try_new` rejects that — confirm against callers' inputs.
    ///
    /// # Errors
    ///
    /// * [`Error::UnexpectedListValueLength`] for the first non-null entry whose length
    ///   differs from `value_length`.
    /// * Any error returned by `FixedSizeListArray::try_new`.
    fn transform(&self, source: &Self::Source) -> Result<Self::Target, Error> {
        let offsets = source.value_offsets();
        // NOTE(review): a negative `value_length` wraps in this cast, so the reported
        // `expected` would be meaningless; `try_new` is left to reject such sizes.
        let expected_length = self.value_length as usize;

        // Each adjacent offset pair delimits one list entry.
        for (list_index, bounds) in offsets.windows(2).enumerate() {
            if !source.is_valid(list_index) {
                continue;
            }
            let list_length = bounds[1] as usize - bounds[0] as usize;
            if list_length != expected_length {
                return Err(Error::UnexpectedListValueLength {
                    expected: expected_length,
                    actual: list_length,
                });
            }
        }

        // The offsets need not start at zero nor reach the end of the child array
        // (e.g. when the list array is a slice), so keep only the child range they
        // actually address. Passing the whole child would shift or add rows.
        let child_start = offsets.first().map_or(0, |&offset| offset as usize);
        let child_end = offsets.last().map_or(child_start, |&offset| offset as usize);
        let values = source.values().slice(child_start, child_end - child_start);

        // The element field must be nullable whenever the child itself holds nulls;
        // otherwise construction fails for a null-free outer list with null elements.
        let field = Arc::new(Field::new_list_field(
            source.value_type().clone(),
            source.is_nullable() || values.is_nullable(),
        ));

        Ok(arrow::array::FixedSizeListArray::try_new(
            field,
            self.value_length,
            values,
            source.nulls().cloned(),
        )?)
    }
}