use arrow::array::{Array as ArrowArray, BooleanArray, PrimitiveArray};
use arrow::buffer::Buffer;
use arrow::datatypes::{ArrowNativeType, ArrowPrimitiveType};
use ferray_core::array::aliases::Array1;
use ferray_core::{Element, FerrayError, Ix1};
use crate::dtype_map;
/// Links a ferray element type to the Arrow primitive type that stores it.
///
/// Implementors must also be [`Element`] and [`ArrowNativeType`].
pub trait ArrowElement: Element + ArrowNativeType {
/// Arrow primitive type whose `Native` representation is `Self`.
type ArrowType: ArrowPrimitiveType<Native = Self>;
}
// Implements `ArrowElement` for each `rust type, arrow type` pair.
//
// Pairs are separated by `;` (a trailing `;` is optional), so a single
// `impl_arrow_element!(u8, SomeArrowType)` call is accepted as well.
macro_rules! impl_arrow_element {
    ($($rust_ty:ty, $arrow_ty:ty);+ $(;)?) => {
        $(
            impl ArrowElement for $rust_ty {
                type ArrowType = $arrow_ty;
            }
        )+
    };
}

// One entry per supported numeric element type, paired with the Arrow
// primitive type of the same width and signedness.
impl_arrow_element! {
    u8, arrow::datatypes::UInt8Type;
    u16, arrow::datatypes::UInt16Type;
    u32, arrow::datatypes::UInt32Type;
    u64, arrow::datatypes::UInt64Type;
    i8, arrow::datatypes::Int8Type;
    i16, arrow::datatypes::Int16Type;
    i32, arrow::datatypes::Int32Type;
    i64, arrow::datatypes::Int64Type;
    f32, arrow::datatypes::Float32Type;
    f64, arrow::datatypes::Float64Type;
}
/// Conversion of a borrowed ferray array into an Arrow array.
pub trait ToArrow {
/// Arrow array type produced by [`ToArrow::to_arrow`].
type ArrowArray;
/// Builds the Arrow array from `self`, or returns a [`FerrayError`] on failure.
fn to_arrow(&self) -> Result<Self::ArrowArray, FerrayError>;
}
impl<T> ToArrow for Array1<T>
where
    T: ArrowElement,
    T::ArrowType: ArrowPrimitiveType<Native = T>,
{
    type ArrowArray = PrimitiveArray<T::ArrowType>;

    /// Copies the elements of this 1-D array into a new Arrow
    /// [`PrimitiveArray`] that carries no null buffer.
    ///
    /// # Errors
    ///
    /// Propagates the error from `dtype_map::dtype_to_arrow` when the array's
    /// dtype cannot be mapped to an Arrow type.
    fn to_arrow(&self) -> Result<Self::ArrowArray, FerrayError> {
        // The mapped Arrow type itself is not needed; the call only validates
        // that a mapping exists.
        let _ = dtype_map::dtype_to_arrow(self.dtype())?;

        // Use the backing slice when one is available, otherwise fall back to
        // the flattened copy provided by the array.
        let elements: Vec<T> = if let Some(contiguous) = self.as_slice() {
            contiguous.to_vec()
        } else {
            self.to_vec_flat()
        };

        Ok(PrimitiveArray::<T::ArrowType>::new(
            Buffer::from_vec(elements).into(),
            None,
        ))
    }
}
/// Conversion of a borrowed ferray array into an Arrow [`BooleanArray`].
pub trait ToArrowBool {
/// Builds the [`BooleanArray`] from `self`, or returns a [`FerrayError`] on failure.
fn to_arrow(&self) -> Result<BooleanArray, FerrayError>;
}
impl ToArrowBool for Array1<bool> {
    /// Copies the flattened boolean values into a new Arrow [`BooleanArray`].
    ///
    /// This implementation has no failure path of its own and always
    /// returns `Ok`.
    fn to_arrow(&self) -> Result<BooleanArray, FerrayError> {
        let flags = self.to_vec_flat();
        Ok(BooleanArray::from(flags))
    }
}
/// Conversion of an owned Arrow value into a ferray [`Array1`] of `T`.
pub trait FromArrow<T: Element>: Sized {
/// Consumes `self` and returns the resulting [`Array1`], or a [`FerrayError`] on failure.
fn into_ferray(self) -> Result<Array1<T>, FerrayError>;
}
impl<T: ArrowElement> FromArrow<T> for PrimitiveArray<T::ArrowType>
where
    T::ArrowType: ArrowPrimitiveType<Native = T>,
{
    /// Copies the values of this Arrow array into a new 1-D ferray array.
    ///
    /// # Errors
    ///
    /// * `invalid_value` when the Arrow array contains any nulls.
    /// * Whatever `dtype_map::arrow_to_dtype` returns for an unmapped Arrow type.
    /// * `invalid_dtype` when the mapped dtype differs from `T::dtype()`.
    /// * Any error reported by `Array1::from_vec`.
    fn into_ferray(self) -> Result<Array1<T>, FerrayError> {
        // Nulls are checked first so they are reported before dtype problems.
        let nulls = self.null_count();
        if nulls > 0 {
            return Err(FerrayError::invalid_value(format!(
                "Arrow array contains {} null values; ferray arrays do not support nulls",
                nulls
            )));
        }

        let arrow_dt = self.data_type();
        let ferray_dt = dtype_map::arrow_to_dtype(arrow_dt)?;
        let requested = T::dtype();
        if ferray_dt != requested {
            return Err(FerrayError::invalid_dtype(format!(
                "Arrow dtype {arrow_dt:?} maps to ferray {ferray_dt}, but requested {}",
                requested
            )));
        }

        let data: Vec<T> = self.values().to_vec();
        Array1::<T>::from_vec(Ix1::new([data.len()]), data)
    }
}
/// Conversion of an owned Arrow value into a ferray [`Array1`] of `bool`.
pub trait FromArrowBool: Sized {
/// Consumes `self` and returns the resulting boolean [`Array1`], or a [`FerrayError`] on failure.
fn into_ferray_bool(self) -> Result<Array1<bool>, FerrayError>;
}
impl FromArrowBool for BooleanArray {
    /// Copies the values of this Arrow boolean array into a new 1-D ferray array.
    ///
    /// # Errors
    ///
    /// Returns `invalid_value` when the array contains any nulls, and forwards
    /// any error reported by `Array1::from_vec`.
    fn into_ferray_bool(self) -> Result<Array1<bool>, FerrayError> {
        let nulls = self.null_count();
        if nulls > 0 {
            return Err(FerrayError::invalid_value(format!(
                "Arrow BooleanArray contains {} null values; ferray arrays do not support nulls",
                nulls
            )));
        }

        // No nulls remain at this point, so the `false` default is never
        // actually substituted.
        let flags: Vec<bool> = self.iter().map(Option::unwrap_or_default).collect();
        Array1::<bool>::from_vec(Ix1::new([flags.len()]), flags)
    }
}
use ferray_core::array::aliases::{Array2, ArrayD};
use ferray_core::dimension::{Ix2, IxDyn};
/// Splits a 2-D ferray array into one Arrow [`PrimitiveArray`] per column.
///
/// Element `(r, c)` of `a` becomes element `r` of output column `c`. Each
/// returned array has `nrows` entries and no null buffer; the returned vector
/// has `ncols` entries.
///
/// # Errors
///
/// Propagates the error from `dtype_map::dtype_to_arrow` when the array's
/// dtype cannot be mapped to an Arrow type.
pub fn array2_to_arrow_columns<T>(
    a: &Array2<T>,
) -> Result<Vec<PrimitiveArray<T::ArrowType>>, FerrayError>
where
    T: ArrowElement,
    T::ArrowType: ArrowPrimitiveType<Native = T>,
{
    // The mapped Arrow type itself is not needed; the call only validates
    // that a mapping exists.
    let _ = dtype_map::dtype_to_arrow(a.dtype())?;
    let shape = a.shape();
    let (nrows, ncols) = (shape[0], shape[1]);

    // Obtain a single row-major view of the elements up front. When no
    // backing slice is available, the elements are collected once in
    // iteration order; looking each element up with `Iterator::nth` instead
    // would walk the iterator from the start every time and make the whole
    // conversion quadratic in the number of elements.
    let gathered: Vec<T>;
    let flat: &[T] = match a.as_slice() {
        Some(contiguous) => contiguous,
        None => {
            gathered = a.iter().copied().collect();
            gathered.as_slice()
        }
    };

    let columns: Vec<PrimitiveArray<T::ArrowType>> = (0..ncols)
        .map(|c| {
            // Column `c` is every `ncols`-th element starting at offset `c`.
            let col: Vec<T> = (0..nrows).map(|r| flat[r * ncols + c]).collect();
            PrimitiveArray::<T::ArrowType>::new(Buffer::from_vec(col).into(), None)
        })
        .collect();
    Ok(columns)
}
/// Assembles a 2-D ferray array from a slice of equally long Arrow columns.
///
/// Column `c`, element `r` of the input becomes element `(r, c)` of the
/// result, whose shape is `[cols[0].len(), cols.len()]`. An empty `cols`
/// slice yields an array of shape `[0, 0]`.
///
/// # Errors
///
/// Columns are validated in order; for each column the checks run as follows:
///
/// * `shape_mismatch` when its length differs from that of column 0.
/// * `invalid_value` when it contains nulls.
/// * Whatever `dtype_map::arrow_to_dtype` returns for an unmapped Arrow type.
/// * `invalid_dtype` when the mapped dtype differs from `T::dtype()`.
///
/// Any error reported by `Array2::from_vec` is forwarded as well.
pub fn array2_from_arrow_columns<T>(
    cols: &[PrimitiveArray<T::ArrowType>],
) -> Result<Array2<T>, FerrayError>
where
    T: ArrowElement,
    T::ArrowType: ArrowPrimitiveType<Native = T>,
{
    // The first column fixes the row count; with no columns there is nothing
    // to validate and the result is an empty 0x0 array.
    let nrows = match cols.first() {
        Some(first) => first.len(),
        None => return Array2::<T>::from_vec(Ix2::new([0, 0]), Vec::new()),
    };
    let ncols = cols.len();

    for (i, column) in cols.iter().enumerate() {
        let len = column.len();
        if len != nrows {
            return Err(FerrayError::shape_mismatch(format!(
                "array2_from_arrow_columns: column {i} has length {} but column 0 has length {nrows}",
                len
            )));
        }

        let nulls = column.null_count();
        if nulls > 0 {
            return Err(FerrayError::invalid_value(format!(
                "Arrow array column {i} contains {} nulls; ferray arrays do not support nulls",
                nulls
            )));
        }

        let arrow_dt = column.data_type();
        let ferray_dt = dtype_map::arrow_to_dtype(arrow_dt)?;
        let requested = T::dtype();
        if ferray_dt != requested {
            return Err(FerrayError::invalid_dtype(format!(
                "Arrow dtype {arrow_dt:?} on column {i} maps to ferray {ferray_dt}, but requested {}",
                requested
            )));
        }
    }

    // Interleave the columns row by row to produce row-major data.
    let mut data: Vec<T> = Vec::with_capacity(nrows * ncols);
    for r in 0..nrows {
        data.extend(cols.iter().map(|column| column.values()[r]));
    }
    Array2::<T>::from_vec(Ix2::new([nrows, ncols]), data)
}
/// Flattens a dynamic-rank ferray array into a single Arrow [`PrimitiveArray`].
///
/// The result has no null buffer. The shape is not recorded in the Arrow
/// array; callers pass it back to `arrayd_from_arrow_flat` when restoring.
///
/// # Errors
///
/// Propagates the error from `dtype_map::dtype_to_arrow` when the array's
/// dtype cannot be mapped to an Arrow type.
pub fn arrayd_to_arrow_flat<T>(a: &ArrayD<T>) -> Result<PrimitiveArray<T::ArrowType>, FerrayError>
where
    T: ArrowElement,
    T::ArrowType: ArrowPrimitiveType<Native = T>,
{
    // The mapped Arrow type itself is not needed; the call only validates
    // that a mapping exists.
    let _ = dtype_map::dtype_to_arrow(a.dtype())?;

    // Use the backing slice when one is available, otherwise fall back to
    // the flattened copy provided by the array.
    let elements: Vec<T> = if let Some(contiguous) = a.as_slice() {
        contiguous.to_vec()
    } else {
        a.to_vec_flat()
    };

    let values = Buffer::from_vec(elements);
    Ok(PrimitiveArray::<T::ArrowType>::new(values.into(), None))
}
/// Rebuilds a dynamic-rank ferray array from a flat Arrow array and a shape.
///
/// The values of `arr` are copied and handed to `ArrayD::from_vec` together
/// with `shape`. An empty `shape` has an element count of 1.
///
/// # Errors
///
/// Checks run in this order:
///
/// * `invalid_value` when `arr` contains nulls.
/// * Whatever `dtype_map::arrow_to_dtype` returns for an unmapped Arrow type.
/// * `invalid_dtype` when the mapped dtype differs from `T::dtype()`.
/// * `shape_mismatch` when the element count implied by `shape` does not fit
///   in `usize`, or differs from `arr.len()`.
///
/// Any error reported by `ArrayD::from_vec` is forwarded as well.
pub fn arrayd_from_arrow_flat<T>(
    arr: &PrimitiveArray<T::ArrowType>,
    shape: &[usize],
) -> Result<ArrayD<T>, FerrayError>
where
    T: ArrowElement,
    T::ArrowType: ArrowPrimitiveType<Native = T>,
{
    let nulls = arr.null_count();
    if nulls > 0 {
        return Err(FerrayError::invalid_value(format!(
            "Arrow array contains {} null values; ferray arrays do not support nulls",
            nulls
        )));
    }

    let arrow_dt = arr.data_type();
    let ferray_dt = dtype_map::arrow_to_dtype(arrow_dt)?;
    if ferray_dt != T::dtype() {
        return Err(FerrayError::invalid_dtype(format!(
            "Arrow dtype {arrow_dt:?} maps to ferray {ferray_dt}, but requested {}",
            T::dtype()
        )));
    }

    // `shape` is caller-supplied, so the element count is computed with
    // checked arithmetic: a plain product would panic on overflow in debug
    // builds and silently wrap in release builds. A zero-length axis makes
    // the count 0 regardless of the other axes, so it is handled first to
    // avoid reporting a spurious overflow.
    let expected: usize = if shape.contains(&0) {
        0
    } else {
        shape
            .iter()
            .try_fold(1usize, |acc, &dim| acc.checked_mul(dim))
            .ok_or_else(|| {
                FerrayError::shape_mismatch(format!(
                    "arrayd_from_arrow_flat: element count of shape {:?} overflows usize",
                    shape
                ))
            })?
    };

    if arr.len() != expected {
        return Err(FerrayError::shape_mismatch(format!(
            "arrayd_from_arrow_flat: arrow length {} does not match shape {:?} (product {})",
            arr.len(),
            shape,
            expected
        )));
    }

    let data: Vec<T> = arr.values().to_vec();
    ArrayD::<T>::from_vec(IxDyn::new(shape), data)
}
#[cfg(test)]
// Lint relaxations for test code; in particular the round-trip assertions
// compare floating-point values exactly on purpose.
#[allow(
    clippy::float_cmp,
    clippy::unreadable_literal,
    clippy::type_repetition_in_bounds,
    clippy::trait_duplication_in_bounds
)]
mod tests {
    use super::*;

    // Generates a test that sends `$values` through ferray -> Arrow -> ferray
    // for element type `$ty` and checks length, shape and contents.
    macro_rules! test_roundtrip {
        ($name:ident, $ty:ty, $values:expr) => {
            #[test]
            fn $name() {
                let expected: Vec<$ty> = $values;
                let n = expected.len();
                let source = Array1::<$ty>::from_vec(Ix1::new([n]), expected.clone()).unwrap();

                let arrow_arr = source.to_arrow().unwrap();
                assert_eq!(arrow_arr.len(), n);

                let restored: Array1<$ty> = arrow_arr.into_ferray().unwrap();
                assert_eq!(restored.shape(), &[n]);
                assert_eq!(restored.as_slice().unwrap(), &expected[..]);
            }
        };
    }

    test_roundtrip!(roundtrip_f64, f64, vec![1.0, 2.5, -4.75, 0.0, f64::MAX]);
    test_roundtrip!(roundtrip_f32, f32, vec![1.0f32, -2.5, 0.0, f32::MIN]);
    test_roundtrip!(roundtrip_i32, i32, vec![0, 1, -1, i32::MAX, i32::MIN]);
    test_roundtrip!(roundtrip_i64, i64, vec![0i64, 42, -99]);
    test_roundtrip!(roundtrip_i8, i8, vec![0i8, 127, -128]);
    test_roundtrip!(roundtrip_i16, i16, vec![0i16, 32767, -32768]);
    test_roundtrip!(roundtrip_u8, u8, vec![0u8, 128, 255]);
    test_roundtrip!(roundtrip_u16, u16, vec![0u16, 1000, 65535]);
    test_roundtrip!(roundtrip_u32, u32, vec![0u32, 1, u32::MAX]);
    test_roundtrip!(roundtrip_u64, u64, vec![0u64, 1, u64::MAX]);

    // Booleans use the dedicated bool traits rather than the primitive path.
    #[test]
    fn roundtrip_bool() {
        let expected = vec![true, false, true, true, false];
        let n = expected.len();
        let source = Array1::<bool>::from_vec(Ix1::new([n]), expected.clone()).unwrap();

        let arrow_arr = source.to_arrow().unwrap();
        assert_eq!(arrow_arr.len(), n);

        let restored = arrow_arr.into_ferray_bool().unwrap();
        assert_eq!(restored.as_slice().unwrap(), &expected[..]);
    }

    #[test]
    fn empty_array_roundtrip() {
        let source = Array1::<f64>::from_vec(Ix1::new([0]), vec![]).unwrap();

        let arrow_arr = source.to_arrow().unwrap();
        assert_eq!(arrow_arr.len(), 0);

        let restored: Array1<f64> = arrow_arr.into_ferray().unwrap();
        assert_eq!(restored.shape(), &[0]);
    }

    // A primitive Arrow array containing a null must be refused, and the
    // error text must mention nulls.
    #[test]
    fn arrow_with_nulls_rejected() {
        let with_null =
            PrimitiveArray::<arrow::datatypes::Float64Type>::from(vec![Some(1.0), None, Some(3.0)]);

        let outcome: Result<Array1<f64>, _> = with_null.into_ferray();
        assert!(outcome.is_err());

        let err_msg = outcome.unwrap_err().to_string();
        assert!(err_msg.contains("null"));
    }

    #[test]
    fn bool_arrow_with_nulls_rejected() {
        let with_null = BooleanArray::from(vec![Some(true), None, Some(false)]);
        assert!(with_null.into_ferray_bool().is_err());
    }

    // Arrow types without a ferray counterpart are reported by the dtype map.
    #[test]
    fn dtype_mismatch_arrow_to_dtype() {
        let outcome = dtype_map::arrow_to_dtype(&arrow::datatypes::DataType::Utf8);
        assert!(outcome.is_err());

        let msg = outcome.unwrap_err().to_string();
        assert!(msg.contains("no ferray equivalent"), "got: {msg}");
    }

    // Compares raw bit patterns so that NaN and signed zero are covered too.
    #[test]
    fn bit_identical_roundtrip() {
        let original: Vec<f64> = vec![
            1.0,
            -0.0,
            f64::INFINITY,
            f64::NEG_INFINITY,
            f64::NAN,
            1.23456789012345e-300,
            9.87654321098765e+300,
        ];
        let n = original.len();
        let source = Array1::<f64>::from_vec(Ix1::new([n]), original.clone()).unwrap();

        let restored: Array1<f64> = source.to_arrow().unwrap().into_ferray().unwrap();
        let restored_slice = restored.as_slice().unwrap();

        for (i, (a, b)) in original.iter().zip(restored_slice.iter()).enumerate() {
            assert_eq!(
                a.to_bits(),
                b.to_bits(),
                "bit mismatch at index {i}: {a} vs {b}"
            );
        }
    }

    // 2x3 input: checks the first and last column contents, then the way back.
    #[test]
    fn array2_arrow_columns_roundtrip_f64() {
        let expected = vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0];
        let matrix = Array2::<f64>::from_vec(Ix2::new([2, 3]), expected.clone()).unwrap();

        let cols = array2_to_arrow_columns(&matrix).unwrap();
        assert_eq!(cols.len(), 3);
        assert!(cols.iter().all(|column| column.len() == 2));
        assert_eq!(&cols[0].values()[..], &[1.0_f64, 4.0][..]);
        assert_eq!(&cols[2].values()[..], &[3.0_f64, 6.0][..]);

        let restored = array2_from_arrow_columns::<f64>(&cols).unwrap();
        assert_eq!(restored.shape(), &[2, 3]);
        assert_eq!(restored.as_slice().unwrap(), &expected[..]);
    }

    #[test]
    fn array2_arrow_columns_roundtrip_i32() {
        let expected = vec![1i32, 2, 3, 4, 5, 6, 7, 8];
        let matrix = Array2::<i32>::from_vec(Ix2::new([4, 2]), expected.clone()).unwrap();

        let cols = array2_to_arrow_columns(&matrix).unwrap();
        assert_eq!(cols.len(), 2);

        let restored = array2_from_arrow_columns::<i32>(&cols).unwrap();
        assert_eq!(restored.shape(), &[4, 2]);
        assert_eq!(restored.as_slice().unwrap(), &expected[..]);
    }

    #[test]
    fn array2_from_arrow_columns_inconsistent_length_errors() {
        // Builds a null-free f64 Arrow column from the given values.
        let f64_column = |values: Vec<f64>| {
            PrimitiveArray::<arrow::datatypes::Float64Type>::new(
                Buffer::from_vec(values).into(),
                None,
            )
        };
        let three_long = f64_column(vec![1.0_f64, 2.0, 3.0]);
        let two_long = f64_column(vec![4.0_f64, 5.0]);

        assert!(array2_from_arrow_columns::<f64>(&[three_long, two_long]).is_err());
    }

    #[test]
    fn array2_from_arrow_columns_empty_is_empty() {
        let no_columns: Vec<PrimitiveArray<arrow::datatypes::Float64Type>> = Vec::new();
        let restored = array2_from_arrow_columns::<f64>(&no_columns).unwrap();
        assert_eq!(restored.shape(), &[0, 0]);
    }

    #[test]
    fn arrayd_flat_roundtrip() {
        let expected = vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let cube = ArrayD::<f64>::from_vec(IxDyn::new(&[2, 2, 2]), expected.clone()).unwrap();

        let flat = arrayd_to_arrow_flat(&cube).unwrap();
        assert_eq!(flat.len(), 8);

        let restored = arrayd_from_arrow_flat::<f64>(&flat, &[2, 2, 2]).unwrap();
        assert_eq!(restored.shape(), &[2, 2, 2]);
        assert_eq!(restored.to_vec_flat(), expected);
    }

    // The flat Arrow array may be restored with any shape of matching size.
    #[test]
    fn arrayd_flat_reshape_to_different_rank() {
        let expected = vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0];
        let vector = ArrayD::<f64>::from_vec(IxDyn::new(&[6]), expected.clone()).unwrap();

        let flat = arrayd_to_arrow_flat(&vector).unwrap();
        let reshaped = arrayd_from_arrow_flat::<f64>(&flat, &[2, 3]).unwrap();
        assert_eq!(reshaped.shape(), &[2, 3]);
        assert_eq!(reshaped.to_vec_flat(), expected);
    }

    #[test]
    fn arrayd_from_arrow_flat_wrong_shape_errors() {
        let values = vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0];
        let vector = ArrayD::<f64>::from_vec(IxDyn::new(&[6]), values).unwrap();

        let flat = arrayd_to_arrow_flat(&vector).unwrap();
        assert!(arrayd_from_arrow_flat::<f64>(&flat, &[2, 4]).is_err());
    }
}