#[cfg(feature = "distributed")]
use arrow::{
array::{Array, ArrayRef, BooleanArray, Float64Array, Int64Array, StringArray},
compute,
    datatypes::{DataType, Field, Schema},
record_batch::RecordBatch,
};
use crate::core::error::{Error, Result};
use crate::dataframe::DataFrame;
use crate::series::base::Series;
#[cfg(feature = "distributed")]
use std::sync::Arc;
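/// Converts between this crate's [`DataFrame`] and Arrow [`RecordBatch`]es,
/// and runs a small set of Arrow compute kernels over the converted data.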
#[cfg(feature = "distributed")]
pub struct ArrowConverter;
#[cfg(feature = "distributed")]
impl ArrowConverter {
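    /// Converts a `DataFrame` into a single Arrow `RecordBatch`, inferring an
    /// Arrow type for each column via the name heuristic in
    /// `infer_arrow_type`. Every field is marked nullable.
    ///
    /// A minimal usage sketch (`ignore`d because it assumes a populated
    /// `DataFrame` built elsewhere in the crate):
    ///
    /// ```ignore
    /// let batch = ArrowConverter::dataframe_to_record_batch(&df)?;
    /// assert_eq!(batch.num_rows(), df.row_count());
    /// ```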
pub fn dataframe_to_record_batch(df: &DataFrame) -> Result<RecordBatch> {
let mut fields = Vec::new();
let mut arrays: Vec<ArrayRef> = Vec::new();
let column_names = df.column_names();
for column_name in &column_names {
let arrow_type = Self::infer_arrow_type(df, column_name)?;
let field = Field::new(column_name, arrow_type.clone(), true);
fields.push(field);
let array = Self::series_to_arrow_array(df, column_name, &arrow_type)?;
arrays.push(array);
}
let schema = Arc::new(Schema::new(fields));
RecordBatch::try_new(schema, arrays)
.map_err(|e| Error::InvalidOperation(format!("Failed to create RecordBatch: {}", e)))
}
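    /// Splits each `DataFrame` into `RecordBatch`es of at most `batch_size`
    /// rows (default 1024), counting batches with ceiling division: a
    /// 2500-row frame at the default size yields three batches of 1024,
    /// 1024, and 452 rows.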
pub fn dataframes_to_record_batches(
dataframes: &[DataFrame],
batch_size: Option<usize>,
) -> Result<Vec<RecordBatch>> {
let mut batches = Vec::new();
let batch_size = batch_size.unwrap_or(1024);
for df in dataframes {
if df.row_count() <= batch_size {
batches.push(Self::dataframe_to_record_batch(df)?);
} else {
let num_batches = (df.row_count() + batch_size - 1) / batch_size;
for i in 0..num_batches {
let start = i * batch_size;
let end = std::cmp::min(start + batch_size, df.row_count());
let batch_df = Self::slice_dataframe(df, start, end)?;
batches.push(Self::dataframe_to_record_batch(&batch_df)?);
}
}
}
Ok(batches)
}
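    /// Rebuilds a `DataFrame` from a `RecordBatch`, preserving the schema's
    /// column order. All values are stringified; nulls become the literal
    /// string "null".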
    pub fn record_batch_to_dataframe(batch: &RecordBatch) -> Result<DataFrame> {
        let mut df = DataFrame::new();
        let schema = batch.schema();
        // Schema order is the column order, so columns can be added directly
        // instead of being staged in a map and reordered afterwards.
        for (i, field) in schema.fields().iter().enumerate() {
            let column_name = field.name().clone();
            let series = Self::arrow_array_to_series(batch.column(i), &column_name)?;
            df.add_column(column_name, series)?;
        }
        Ok(df)
    }
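    /// Guesses an Arrow type from the column name: "id"/"count" map to
    /// `Int64`, "rate"/"price"/"score" to `Float64`, "flag"/"is_" to
    /// `Boolean`, and everything else to `Utf8`.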
    fn infer_arrow_type(_df: &DataFrame, column_name: &str) -> Result<DataType> {
        // Name-based heuristic: Series values are stored as strings, so the
        // type is inferred from the column name rather than sampled from the
        // data. The DataFrame parameter is kept so a value-sampling pass can
        // be added later without changing the signature.
        if column_name.contains("id") || column_name.contains("count") {
            Ok(DataType::Int64)
        } else if column_name.contains("rate")
            || column_name.contains("price")
            || column_name.contains("score")
        {
            Ok(DataType::Float64)
        } else if column_name.contains("flag") || column_name.contains("is_") {
            Ok(DataType::Boolean)
        } else {
            Ok(DataType::Utf8)
        }
    }
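    /// Builds an Arrow array of the requested type with one entry per row.
    ///
    /// NOTE: the values are synthesized placeholders derived from the row
    /// index; the actual `Series` contents are not read yet.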
    fn series_to_arrow_array(
        df: &DataFrame,
        _column_name: &str,
        arrow_type: &DataType,
    ) -> Result<ArrayRef> {
        let row_count = df.row_count();
        // Placeholder data: each arm fabricates values from the row index so
        // the array has the right length and type. Reading real Series values
        // requires a typed accessor on Series, which stores strings.
        match arrow_type {
            DataType::Int64 => {
                let values: Vec<Option<i64>> = (0..row_count).map(|i| Some(i as i64)).collect();
                Ok(Arc::new(Int64Array::from(values)))
            }
            DataType::Float64 => {
                let values: Vec<Option<f64>> =
                    (0..row_count).map(|i| Some(i as f64 * 1.5)).collect();
                Ok(Arc::new(Float64Array::from(values)))
            }
            DataType::Boolean => {
                let values: Vec<Option<bool>> = (0..row_count).map(|i| Some(i % 2 == 0)).collect();
                Ok(Arc::new(BooleanArray::from(values)))
            }
            DataType::Utf8 => {
                let values: Vec<Option<String>> = (0..row_count)
                    .map(|i| Some(format!("value_{}", i)))
                    .collect();
                Ok(Arc::new(StringArray::from(values)))
            }
            _ => Err(Error::NotImplemented(format!(
                "Arrow type {:?} not yet supported",
                arrow_type
            ))),
        }
    }
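    /// Converts an Arrow array back into a string-typed `Series`, stringifying
    /// each value and encoding nulls as the literal "null".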
    fn arrow_array_to_series(array: &dyn Array, column_name: &str) -> Result<Series<String>> {
        // Every supported type is handled identically apart from the downcast
        // target, so the per-type branches share one macro body.
        macro_rules! stringify_array {
            ($array_ty:ty) => {{
                let arr = array.as_any().downcast_ref::<$array_ty>().ok_or_else(|| {
                    Error::InvalidOperation(format!(
                        "Failed to downcast to {}",
                        stringify!($array_ty)
                    ))
                })?;
                let values: Vec<String> = (0..arr.len())
                    .map(|i| {
                        if arr.is_null(i) {
                            "null".to_string()
                        } else {
                            arr.value(i).to_string()
                        }
                    })
                    .collect();
                Series::new(values, Some(column_name.to_string()))
            }};
        }
        match array.data_type() {
            DataType::Int64 => stringify_array!(Int64Array),
            DataType::Float64 => stringify_array!(Float64Array),
            DataType::Utf8 => stringify_array!(StringArray),
            DataType::Boolean => stringify_array!(BooleanArray),
            _ => Err(Error::NotImplemented(format!(
                "Arrow type {:?} conversion not implemented",
                array.data_type()
            ))),
        }
    }
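    /// Produces a `DataFrame` covering rows `start..end` of `df`.
    ///
    /// NOTE: like `series_to_arrow_array`, the row contents are placeholders
    /// keyed to the global row index; only the slice length is faithful.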
    fn slice_dataframe(df: &DataFrame, start: usize, end: usize) -> Result<DataFrame> {
        let mut result_df = DataFrame::new();
        for column_name in df.column_names() {
            // Placeholder contents: real slicing would copy the source rows.
            let values: Vec<String> = (start..end).map(|i| format!("row_{}", i)).collect();
            let series = Series::new(values, Some(column_name.clone()))?;
            result_df.add_column(column_name, series)?;
        }
        Ok(result_df)
    }
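    /// Converts `df` to a `RecordBatch` and dispatches the requested Arrow
    /// compute operation.
    ///
    /// A minimal sketch (assumes a populated `DataFrame` `df` with a column
    /// named "count", which the name heuristic maps to `Int64`):
    ///
    /// ```ignore
    /// let sums = ArrowConverter::compute_with_arrow(
    ///     &df,
    ///     ArrowOperation::Sum("count".to_string()),
    /// )?;
    /// ```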
pub fn compute_with_arrow(df: &DataFrame, operation: ArrowOperation) -> Result<DataFrame> {
let record_batch = Self::dataframe_to_record_batch(df)?;
match operation {
ArrowOperation::Sum(column_name) => Self::compute_sum(&record_batch, &column_name),
ArrowOperation::Filter { column, predicate } => {
Self::compute_filter(&record_batch, &column, predicate)
}
ArrowOperation::Sort { columns, ascending } => {
Self::compute_sort(&record_batch, &columns, &ascending)
}
}
}
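    /// Sums a numeric column with `arrow::compute::sum` and returns the total
    /// as a single-row `DataFrame` with one column named "sum".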
fn compute_sum(batch: &RecordBatch, column_name: &str) -> Result<DataFrame> {
let schema = batch.schema();
let column_index = schema
.fields()
.iter()
.position(|f| f.name() == column_name)
.ok_or_else(|| Error::ColumnNotFound(column_name.to_string()))?;
let array = batch.column(column_index);
match array.data_type() {
DataType::Int64 => {
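                // Downcast cannot fail: this arm only runs for Int64 data.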
let arr = array.as_any().downcast_ref::<Int64Array>().unwrap();
let sum = compute::sum(arr)
.ok_or_else(|| Error::Computation("Sum computation failed".to_string()))?;
let result_series = Series::new(vec![sum.to_string()], Some("sum".to_string()))?;
let mut result_df = DataFrame::new();
result_df.add_column("sum".to_string(), result_series)?;
Ok(result_df)
}
DataType::Float64 => {
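                // Downcast cannot fail: this arm only runs for Float64 data.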
let arr = array.as_any().downcast_ref::<Float64Array>().unwrap();
let sum = compute::sum(arr)
.ok_or_else(|| Error::Computation("Sum computation failed".to_string()))?;
let result_series = Series::new(vec![sum.to_string()], Some("sum".to_string()))?;
let mut result_df = DataFrame::new();
result_df.add_column("sum".to_string(), result_series)?;
Ok(result_df)
}
_ => Err(Error::InvalidOperation(format!(
"Sum not supported for type {:?}",
array.data_type()
))),
}
}
    fn compute_filter(
        batch: &RecordBatch,
        _column: &str,
        _predicate: FilterPredicate,
    ) -> Result<DataFrame> {
        // Pass-through stub: predicate evaluation (e.g. via
        // arrow::compute::filter) is not wired up yet, so the batch is
        // returned unchanged.
        Self::record_batch_to_dataframe(batch)
    }
    fn compute_sort(
        batch: &RecordBatch,
        _columns: &[String],
        _ascending: &[bool],
    ) -> Result<DataFrame> {
        // Pass-through stub: sorting (e.g. via arrow::compute::lexsort) is
        // not wired up yet, so the batch is returned unchanged.
        Self::record_batch_to_dataframe(batch)
    }
}
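/// Arrow compute operations supported by [`ArrowConverter::compute_with_arrow`].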
#[cfg(feature = "distributed")]
pub enum ArrowOperation {
Sum(String),
Filter {
column: String,
predicate: FilterPredicate,
},
Sort {
columns: Vec<String>,
ascending: Vec<bool>,
},
}
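/// Comparison carried by [`ArrowOperation::Filter`]; numeric variants hold an
/// `f64` threshold, equality variants hold a string representation.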
#[cfg(feature = "distributed")]
pub enum FilterPredicate {
GreaterThan(f64),
LessThan(f64),
EqualTo(String),
NotEqualTo(String),
}
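/// Extension trait exposing the Arrow conversions directly on `DataFrame`.
///
/// A round-trip sketch (assumes a populated `DataFrame` `df`):
///
/// ```ignore
/// let batch = df.to_arrow()?;
/// let restored = DataFrame::from_arrow(&batch)?;
/// assert_eq!(restored.column_names(), df.column_names());
/// ```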
pub trait ArrowIntegration {
#[cfg(feature = "distributed")]
fn to_arrow(&self) -> Result<RecordBatch>;
#[cfg(feature = "distributed")]
fn from_arrow(batch: &RecordBatch) -> Result<Self>
where
Self: Sized;
#[cfg(feature = "distributed")]
fn compute_arrow(&self, operation: ArrowOperation) -> Result<Self>
where
Self: Sized;
}
impl ArrowIntegration for DataFrame {
#[cfg(feature = "distributed")]
fn to_arrow(&self) -> Result<RecordBatch> {
ArrowConverter::dataframe_to_record_batch(self)
}
#[cfg(feature = "distributed")]
fn from_arrow(batch: &RecordBatch) -> Result<Self> {
ArrowConverter::record_batch_to_dataframe(batch)
}
#[cfg(feature = "distributed")]
fn compute_arrow(&self, operation: ArrowOperation) -> Result<Self> {
ArrowConverter::compute_with_arrow(self, operation)
}
}
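/// Arrow Flight transport scaffolding. The connector below is a stub: it logs
/// what a real Flight client would do but performs no network I/O.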
#[cfg(feature = "distributed")]
pub mod flight {
use super::*;
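    /// Placeholder client pinned to a single Flight endpoint URL.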
pub struct FlightConnector {
endpoint: String,
}
impl FlightConnector {
pub fn new(endpoint: String) -> Self {
Self { endpoint }
}
        pub async fn send_dataframe(&self, df: &DataFrame, path: &str) -> Result<()> {
            // Stub: converts the frame to validate it, then logs instead of
            // issuing a Flight DoPut call.
            let record_batch = df.to_arrow()?;
            println!("Sending DataFrame to {} at path {}", self.endpoint, path);
            println!("RecordBatch schema: {}", record_batch.schema());
            Ok(())
        }
        pub async fn receive_dataframe(&self, path: &str) -> Result<DataFrame> {
            // Stub: returns a canned single-column frame instead of issuing a
            // Flight DoGet call.
            println!(
                "Receiving DataFrame from {} at path {}",
                self.endpoint, path
            );
            let mut df = DataFrame::new();
            let series = Series::new(vec!["remote_data".to_string()], Some("data".to_string()))?;
            df.add_column("data".to_string(), series)?;
            Ok(df)
        }
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
#[cfg(feature = "distributed")]
fn test_arrow_integration() {
let series1 = Series::new(
vec!["1".to_string(), "2".to_string(), "3".to_string()],
Some("numbers".to_string()),
)
.unwrap();
let series2 = Series::new(
vec!["a".to_string(), "b".to_string(), "c".to_string()],
Some("letters".to_string()),
)
.unwrap();
let mut df = DataFrame::new();
df.add_column("numbers".to_string(), series1).unwrap();
df.add_column("letters".to_string(), series2).unwrap();
let record_batch = df.to_arrow().unwrap();
assert_eq!(record_batch.num_columns(), 2);
assert_eq!(record_batch.num_rows(), 3);
let df2 = DataFrame::from_arrow(&record_batch).unwrap();
assert_eq!(df2.column_names(), df.column_names());
}
    #[test]
    fn test_arrow_integration_trait() {
        let series = Series::new(vec!["test".to_string()], Some("col".to_string())).unwrap();
        let mut df = DataFrame::new();
        df.add_column("col".to_string(), series).unwrap();
        // With the feature enabled, conversion should succeed; without it,
        // the trait has no methods and there is nothing to exercise.
        #[cfg(feature = "distributed")]
        {
            assert!(df.to_arrow().is_ok());
        }
    }
}