use crate::error::{IoError, Result};
use arrow::array::{
ArrayRef, AsArray, BooleanArray, Float32Array, Float64Array, Int16Array, Int32Array,
Int64Array, Int8Array, PrimitiveArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
};
use arrow::datatypes::{ArrowPrimitiveType, DataType, Field, Schema};
use arrow::record_batch::RecordBatch;
use scirs2_core::ndarray::{Array1, ArrayBase, Data as NdData, Dimension};
use std::sync::Arc;
pub fn ndarray_to_arrow<S, D, T>(array: &ArrayBase<S, D>, column_name: &str) -> Result<RecordBatch>
where
S: NdData<Elem = T>,
D: Dimension,
T: ToArrowArray + Clone,
{
let flat_data: Vec<T> = array.iter().cloned().collect();
let arrow_array = T::to_arrow_array(&flat_data)?;
let data_type = T::arrow_data_type();
let field = Field::new(column_name, data_type, false);
let schema = Arc::new(Schema::new(vec![field]));
RecordBatch::try_new(schema, vec![arrow_array])
.map_err(|e| IoError::ParquetError(format!("Failed to create RecordBatch: {}", e)))
}
pub fn arrow_to_ndarray<T>(batch: &RecordBatch, column_index: usize) -> Result<Array1<T>>
where
T: FromArrowArray,
{
if column_index >= batch.num_columns() {
return Err(IoError::ParquetError(format!(
"Column index {} out of bounds (num_columns={})",
column_index,
batch.num_columns()
)));
}
let column = batch.column(column_index);
T::from_arrow_array(column)
}
pub trait ToArrowArray: Sized {
fn to_arrow_array(data: &[Self]) -> Result<ArrayRef>;
fn arrow_data_type() -> DataType;
}
pub trait FromArrowArray: Sized {
fn from_arrow_array(array: &ArrayRef) -> Result<Array1<Self>>;
}
macro_rules! impl_arrow_conversion {
($rust_type:ty, $arrow_type:ty, $data_type:expr, $array_type:ty) => {
impl ToArrowArray for $rust_type {
fn to_arrow_array(data: &[Self]) -> Result<ArrayRef> {
Ok(Arc::new(<$array_type>::from(Vec::from(data))))
}
fn arrow_data_type() -> DataType {
$data_type
}
}
impl FromArrowArray for $rust_type {
fn from_arrow_array(array: &ArrayRef) -> Result<Array1<Self>> {
let typed_array = array.as_primitive_opt::<$arrow_type>().ok_or_else(|| {
IoError::ParquetError(format!(
"Expected {} array, got {:?}",
stringify!($rust_type),
array.data_type()
))
})?;
let values: Vec<$rust_type> = typed_array.values().iter().copied().collect();
Ok(Array1::from_vec(values))
}
}
};
}
impl_arrow_conversion!(
f64,
arrow::datatypes::Float64Type,
DataType::Float64,
Float64Array
);
impl_arrow_conversion!(
f32,
arrow::datatypes::Float32Type,
DataType::Float32,
Float32Array
);
impl_arrow_conversion!(
i64,
arrow::datatypes::Int64Type,
DataType::Int64,
Int64Array
);
impl_arrow_conversion!(
i32,
arrow::datatypes::Int32Type,
DataType::Int32,
Int32Array
);
impl_arrow_conversion!(
i16,
arrow::datatypes::Int16Type,
DataType::Int16,
Int16Array
);
impl_arrow_conversion!(i8, arrow::datatypes::Int8Type, DataType::Int8, Int8Array);
impl_arrow_conversion!(
u64,
arrow::datatypes::UInt64Type,
DataType::UInt64,
UInt64Array
);
impl_arrow_conversion!(
u32,
arrow::datatypes::UInt32Type,
DataType::UInt32,
UInt32Array
);
impl_arrow_conversion!(
u16,
arrow::datatypes::UInt16Type,
DataType::UInt16,
UInt16Array
);
impl_arrow_conversion!(u8, arrow::datatypes::UInt8Type, DataType::UInt8, UInt8Array);
impl ToArrowArray for bool {
fn to_arrow_array(data: &[Self]) -> Result<ArrayRef> {
Ok(Arc::new(BooleanArray::from(Vec::from(data))))
}
fn arrow_data_type() -> DataType {
DataType::Boolean
}
}
impl FromArrowArray for bool {
fn from_arrow_array(array: &ArrayRef) -> Result<Array1<Self>> {
let bool_array = array.as_boolean_opt().ok_or_else(|| {
IoError::ParquetError(format!(
"Expected Boolean array, got {:?}",
array.data_type()
))
})?;
let values: Vec<bool> = (0..bool_array.len()).map(|i| bool_array.value(i)).collect();
Ok(Array1::from_vec(values))
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_ndarray_to_arrow_f64() {
let arr = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
let batch = ndarray_to_arrow(&arr, "test").expect("Operation failed");
assert_eq!(batch.num_rows(), 4);
assert_eq!(batch.num_columns(), 1);
assert_eq!(batch.schema().field(0).name(), "test");
}
#[test]
fn test_arrow_to_ndarray_f64() {
let arr = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
let batch = ndarray_to_arrow(&arr, "test").expect("Operation failed");
let recovered: Array1<f64> = arrow_to_ndarray(&batch, 0).expect("Operation failed");
assert_eq!(arr, recovered);
}
#[test]
fn test_roundtrip_i32() {
let arr = Array1::from_vec(vec![10i32, 20, 30, 40]);
let batch = ndarray_to_arrow(&arr, "integers").expect("Operation failed");
let recovered: Array1<i32> = arrow_to_ndarray(&batch, 0).expect("Operation failed");
assert_eq!(arr, recovered);
}
#[test]
fn test_roundtrip_bool() {
let arr = Array1::from_vec(vec![true, false, true, false]);
let batch = ndarray_to_arrow(&arr, "booleans").expect("Operation failed");
let recovered: Array1<bool> = arrow_to_ndarray(&batch, 0).expect("Operation failed");
assert_eq!(arr, recovered);
}
}