use crate::error_taxonomy::helpers as error_helpers;
use std::path::Path;
use std::sync::Arc;
use tenflowers_core::{Result, Tensor, TensorError};
#[cfg(feature = "parquet")]
use arrow::array::*;
#[cfg(feature = "parquet")]
use arrow::compute;
#[cfg(feature = "parquet")]
use arrow::compute::kernels::cmp::{eq, gt, lt};
#[cfg(feature = "parquet")]
use arrow::datatypes::{DataType as ArrowDataType, Field, Schema};
#[cfg(feature = "parquet")]
use arrow::record_batch::RecordBatch;
#[cfg(feature = "parquet")]
use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder;
#[cfg(feature = "parquet")]
use parquet::file::metadata::RowGroupMetaData;
#[cfg(feature = "parquet")]
use parquet::file::statistics::Statistics;
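/// A filter expression that can be pushed down when streaming Parquet data and
/// evaluated against Arrow record batches.
///
/// Simple comparisons are built with the constructor helpers and combined with
/// [`ArrowPredicate::and`], [`ArrowPredicate::or`], and [`ArrowPredicate::not`].
/// For example (column names are illustrative):
///
/// ```ignore
/// // Keep rows where 18 < age < 65 and the country column is non-null.
/// let pred = ArrowPredicate::and(vec![
///     ArrowPredicate::gt("age", ArrowValue::Int32(18)),
///     ArrowPredicate::lt("age", ArrowValue::Int32(65)),
///     ArrowPredicate::is_not_null("country"),
/// ]);
/// ```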
#[derive(Debug, Clone)]
pub enum ArrowPredicate {
Equals(String, ArrowValue),
NotEquals(String, ArrowValue),
GreaterThan(String, ArrowValue),
LessThan(String, ArrowValue),
In(String, Vec<ArrowValue>),
IsNull(String),
IsNotNull(String),
And(Vec<ArrowPredicate>),
Or(Vec<ArrowPredicate>),
Not(Box<ArrowPredicate>),
}
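/// A typed literal value that predicates compare column values against.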
#[derive(Debug, Clone, PartialEq, PartialOrd)]
pub enum ArrowValue {
Int32(i32),
Int64(i64),
Float32(f32),
Float64(f64),
String(String),
Bool(bool),
}
impl ArrowPredicate {
pub fn eq(column: impl Into<String>, value: ArrowValue) -> Self {
Self::Equals(column.into(), value)
}
pub fn ne(column: impl Into<String>, value: ArrowValue) -> Self {
Self::NotEquals(column.into(), value)
}
pub fn gt(column: impl Into<String>, value: ArrowValue) -> Self {
Self::GreaterThan(column.into(), value)
}
pub fn lt(column: impl Into<String>, value: ArrowValue) -> Self {
Self::LessThan(column.into(), value)
}
pub fn is_in(column: impl Into<String>, values: Vec<ArrowValue>) -> Self {
Self::In(column.into(), values)
}
pub fn is_null(column: impl Into<String>) -> Self {
Self::IsNull(column.into())
}
pub fn is_not_null(column: impl Into<String>) -> Self {
Self::IsNotNull(column.into())
}
pub fn and(predicates: Vec<ArrowPredicate>) -> Self {
Self::And(predicates)
}
pub fn or(predicates: Vec<ArrowPredicate>) -> Self {
Self::Or(predicates)
}
pub fn not(predicate: ArrowPredicate) -> Self {
Self::Not(Box::new(predicate))
}
}
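/// Min/max, null-count, and row-count statistics for a single column of a chunk
/// (for example a Parquet row group), used to prune chunks before reading them.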
#[cfg(feature = "parquet")]
#[derive(Debug, Clone)]
pub struct ArrowStatistics {
pub column_name: String,
pub min: Option<ArrowValue>,
pub max: Option<ArrowValue>,
pub null_count: usize,
pub distinct_count: Option<usize>,
pub row_count: usize,
}
#[cfg(feature = "parquet")]
impl ArrowStatistics {
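/// Returns `true` when the statistics prove that no row in the chunk can satisfy
/// `predicate`, so the chunk may be skipped without reading it. Returns `false`
/// whenever this cannot be proven, including when the predicate refers to a
/// different column or the relevant statistics are missing.
///
/// ```ignore
/// let stats = ArrowStatistics {
///     column_name: "age".to_string(),
///     min: Some(ArrowValue::Int32(30)),
///     max: Some(ArrowValue::Int32(50)),
///     null_count: 0,
///     distinct_count: None,
///     row_count: 1_000,
/// };
/// // No row can satisfy `age < 20`, so the chunk can be skipped.
/// assert!(stats.can_skip_with_predicate(&ArrowPredicate::lt("age", ArrowValue::Int32(20))));
/// ```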
pub fn can_skip_with_predicate(&self, predicate: &ArrowPredicate) -> bool {
match predicate {
ArrowPredicate::Equals(col, val) if col == &self.column_name => {
if let (Some(min), Some(max)) = (&self.min, &self.max) {
val < min || val > max
} else {
false
}
}
ArrowPredicate::GreaterThan(col, val) if col == &self.column_name => {
if let Some(max) = &self.max {
max <= val
} else {
false
}
}
ArrowPredicate::LessThan(col, val) if col == &self.column_name => {
if let Some(min) = &self.min {
min >= val
} else {
false
}
}
ArrowPredicate::IsNull(col) if col == &self.column_name => {
self.null_count == 0
}
ArrowPredicate::IsNotNull(col) if col == &self.column_name => {
self.null_count == self.row_count
}
_ => false,
}
}
}
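/// Configuration for [`StreamingArrowReader`].
///
/// Note: the reader in this module currently applies only `batch_size` and
/// `predicates`; `projection`, `memory_limit`, and `enable_statistics_pruning`
/// are stored but not yet honored while reading.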
#[derive(Debug, Clone)]
pub struct StreamingArrowConfig {
pub batch_size: usize,
pub predicates: Vec<ArrowPredicate>,
pub projection: Option<Vec<String>>,
pub memory_limit: Option<usize>,
pub enable_statistics_pruning: bool,
}
impl Default for StreamingArrowConfig {
fn default() -> Self {
Self {
batch_size: 8192,
predicates: Vec::new(),
projection: None,
memory_limit: None,
enable_statistics_pruning: true,
}
}
}
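/// A pull-based Parquet reader that yields Arrow [`RecordBatch`]es of at most
/// `batch_size` rows, filtering each batch with the configured predicates.
///
/// ```ignore
/// // Illustrative usage; the path is a placeholder.
/// let mut reader = StreamingArrowReader::new("data.parquet", StreamingArrowConfig::default())?;
/// while let Some(batch) = reader.read_next_batch()? {
///     println!("read {} rows", batch.num_rows());
/// }
/// ```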
#[cfg(feature = "parquet")]
pub struct StreamingArrowReader {
path: std::path::PathBuf,
config: StreamingArrowConfig,
current_batch: usize,
total_batches: usize,
schema: Arc<Schema>,
}
#[cfg(feature = "parquet")]
impl StreamingArrowReader {
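/// Opens the Parquet file at `path`, reads its metadata, and derives the schema
/// and an upper bound on the number of record batches it will produce.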
pub fn new(path: impl AsRef<Path>, config: StreamingArrowConfig) -> Result<Self> {
let path = path.as_ref().to_path_buf();
if !path.exists() {
return Err(error_helpers::file_not_found(
"StreamingArrowReader::new",
&path,
));
}
let file = std::fs::File::open(&path)
.map_err(|_| error_helpers::file_not_found("StreamingArrowReader::new", &path))?;
let builder = ParquetRecordBatchReaderBuilder::try_new(file)
.map_err(|e| TensorError::io_error_simple(format!("Failed to open Parquet: {}", e)))?;
let metadata = builder.metadata();
let schema = builder.schema();
// `num_row_groups()` is not the number of record batches: each row group is
// split into batches of at most `batch_size` rows. Compute an upper bound on
// the batch count so iteration is never cut short.
let total_batches = metadata
.row_groups()
.iter()
.map(|rg| (rg.num_rows().max(0) as usize).div_ceil(config.batch_size.max(1)))
.sum::<usize>();
Ok(Self {
path,
config,
current_batch: 0,
total_batches,
schema: schema.clone(),
})
}
pub fn schema(&self) -> &Arc<Schema> {
&self.schema
}
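/// Upper bound on the number of record batches the file will yield, derived from
/// per-row-group row counts and `batch_size`; the actual count may be lower.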
pub fn total_batches(&self) -> usize {
self.total_batches
}
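/// Reads the next record batch, applying any configured predicates, or returns
/// `Ok(None)` once the file is exhausted.
///
/// The file is reopened and already-returned batches are skipped on every call,
/// which keeps no reader state between calls at the cost of re-decoding earlier
/// batches.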
pub fn read_next_batch(&mut self) -> Result<Option<RecordBatch>> {
if self.current_batch >= self.total_batches {
return Ok(None);
}
let file = std::fs::File::open(&self.path).map_err(|_| {
error_helpers::file_not_found("StreamingArrowReader::read_next_batch", &self.path)
})?;
let builder = ParquetRecordBatchReaderBuilder::try_new(file)
.map_err(|e| TensorError::io_error_simple(format!("Failed to open Parquet: {}", e)))?;
let builder = builder.with_batch_size(self.config.batch_size);
let mut reader = builder
.build()
.map_err(|e| TensorError::io_error_simple(format!("Failed to build reader: {}", e)))?;
for _ in 0..self.current_batch {
if reader.next().is_none() {
return Ok(None);
}
}
let batch = match reader.next() {
Some(Ok(batch)) => batch,
Some(Err(e)) => {
return Err(TensorError::io_error_simple(format!(
"Failed to read batch: {}",
e
)));
}
None => return Ok(None),
};
self.current_batch += 1;
let batch = if !self.config.predicates.is_empty() {
self.apply_predicates(batch)?
} else {
batch
};
Ok(Some(batch))
}
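/// Rewinds the reader so the next call to `read_next_batch` starts from the
/// first batch again.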
pub fn reset(&mut self) {
self.current_batch = 0;
}
fn apply_predicates(&self, batch: RecordBatch) -> Result<RecordBatch> {
let mut mask: Option<BooleanArray> = None;
for predicate in &self.config.predicates {
let pred_mask = self.evaluate_predicate(&batch, predicate)?;
mask = match mask {
None => Some(pred_mask),
Some(existing) => {
Some(compute::and(&existing, &pred_mask).map_err(|e| {
TensorError::unsupported_operation_simple(format!(
"Failed to combine predicates: {}",
e
))
})?)
}
};
}
if let Some(mask) = mask {
compute::filter_record_batch(&batch, &mask).map_err(|e| {
TensorError::unsupported_operation_simple(format!("Failed to filter batch: {}", e))
})
} else {
Ok(batch)
}
}
fn evaluate_predicate(
&self,
batch: &RecordBatch,
predicate: &ArrowPredicate,
) -> Result<BooleanArray> {
match predicate {
ArrowPredicate::Equals(column, value) => self.evaluate_equals(batch, column, value),
ArrowPredicate::NotEquals(column, value) => {
let equals = self.evaluate_equals(batch, column, value)?;
Ok(compute::not(&equals).map_err(|e| {
TensorError::unsupported_operation_simple(format!("Failed to negate: {}", e))
})?)
}
ArrowPredicate::GreaterThan(column, value) => {
self.evaluate_greater_than(batch, column, value)
}
ArrowPredicate::LessThan(column, value) => {
self.evaluate_less_than(batch, column, value)
}
ArrowPredicate::IsNull(column) => {
let col = batch.column_by_name(column).ok_or_else(|| {
error_helpers::schema_mismatch("evaluate_predicate", column, "column not found")
})?;
Ok(compute::is_null(col.as_ref()).map_err(|e| {
TensorError::unsupported_operation_simple(format!(
"Failed to check null: {}",
e
))
})?)
}
ArrowPredicate::IsNotNull(column) => {
let col = batch.column_by_name(column).ok_or_else(|| {
error_helpers::schema_mismatch("evaluate_predicate", column, "column not found")
})?;
Ok(compute::is_not_null(col.as_ref()).map_err(|e| {
TensorError::unsupported_operation_simple(format!(
"Failed to check not null: {}",
e
))
})?)
}
ArrowPredicate::And(predicates) => {
let mut result = None;
for pred in predicates {
let mask = self.evaluate_predicate(batch, pred)?;
result = match result {
None => Some(mask),
Some(existing) => Some(compute::and(&existing, &mask).map_err(|e| {
TensorError::unsupported_operation_simple(format!(
"Failed to AND predicates: {}",
e
))
})?),
};
}
result.ok_or_else(|| {
TensorError::unsupported_operation_simple("Empty AND predicate".to_string())
})
}
ArrowPredicate::Or(predicates) => {
let mut result = None;
for pred in predicates {
let mask = self.evaluate_predicate(batch, pred)?;
result = match result {
None => Some(mask),
Some(existing) => Some(compute::or(&existing, &mask).map_err(|e| {
TensorError::unsupported_operation_simple(format!(
"Failed to OR predicates: {}",
e
))
})?),
};
}
result.ok_or_else(|| {
TensorError::unsupported_operation_simple("Empty OR predicate".to_string())
})
}
ArrowPredicate::Not(pred) => {
let mask = self.evaluate_predicate(batch, pred)?;
Ok(compute::not(&mask).map_err(|e| {
TensorError::unsupported_operation_simple(format!("Failed to NOT: {}", e))
})?)
}
ArrowPredicate::In(..) => Err(TensorError::unsupported_operation_simple(
"IN predicate evaluation is not yet implemented".to_string(),
)),
}
}
fn evaluate_equals(
&self,
batch: &RecordBatch,
column: &str,
value: &ArrowValue,
) -> Result<BooleanArray> {
let col = batch.column_by_name(column).ok_or_else(|| {
error_helpers::schema_mismatch("evaluate_equals", column, "column not found")
})?;
match value {
ArrowValue::Int32(v) => {
let arr = col.as_any().downcast_ref::<Int32Array>().ok_or_else(|| {
TensorError::unsupported_operation_simple("Type mismatch".to_string())
})?;
// Wrap the literal in a one-element `Scalar` so the Datum-based `eq` kernel
// broadcasts it across the column instead of failing on a length mismatch.
let result = eq(arr, &Scalar::new(Int32Array::from(vec![*v]))).map_err(|e| {
TensorError::unsupported_operation_simple(format!("Failed to compare: {}", e))
})?;
Ok(result)
}
ArrowValue::Int64(v) => {
let arr = col.as_any().downcast_ref::<Int64Array>().ok_or_else(|| {
TensorError::unsupported_operation_simple("Type mismatch".to_string())
})?;
let result = eq(arr, &Scalar::new(Int64Array::from(vec![*v]))).map_err(|e| {
TensorError::unsupported_operation_simple(format!("Failed to compare: {}", e))
})?;
Ok(result)
}
ArrowValue::Float64(v) => {
let arr = col.as_any().downcast_ref::<Float64Array>().ok_or_else(|| {
TensorError::unsupported_operation_simple("Type mismatch".to_string())
})?;
let result = eq(arr, &Scalar::new(Float64Array::from(vec![*v]))).map_err(|e| {
TensorError::unsupported_operation_simple(format!("Failed to compare: {}", e))
})?;
Ok(result)
}
_ => Err(TensorError::unsupported_operation_simple(
"Value type not yet implemented".to_string(),
)),
}
}
fn evaluate_greater_than(
&self,
batch: &RecordBatch,
column: &str,
value: &ArrowValue,
) -> Result<BooleanArray> {
let col = batch.column_by_name(column).ok_or_else(|| {
error_helpers::schema_mismatch("evaluate_greater_than", column, "column not found")
})?;
match value {
ArrowValue::Int32(v) => {
let arr = col.as_any().downcast_ref::<Int32Array>().ok_or_else(|| {
TensorError::unsupported_operation_simple("Type mismatch".to_string())
})?;
// Broadcast the literal as a `Scalar`, as in `evaluate_equals`.
let result = gt(arr, &Scalar::new(Int32Array::from(vec![*v]))).map_err(|e| {
TensorError::unsupported_operation_simple(format!("Failed to compare: {}", e))
})?;
Ok(result)
}
ArrowValue::Float64(v) => {
let arr = col.as_any().downcast_ref::<Float64Array>().ok_or_else(|| {
TensorError::unsupported_operation_simple("Type mismatch".to_string())
})?;
let result = gt(arr, &Scalar::new(Float64Array::from(vec![*v]))).map_err(|e| {
TensorError::unsupported_operation_simple(format!("Failed to compare: {}", e))
})?;
Ok(result)
}
_ => Err(TensorError::unsupported_operation_simple(
"Value type not yet implemented".to_string(),
)),
}
}
fn evaluate_less_than(
&self,
batch: &RecordBatch,
column: &str,
value: &ArrowValue,
) -> Result<BooleanArray> {
let col = batch.column_by_name(column).ok_or_else(|| {
error_helpers::schema_mismatch("evaluate_less_than", column, "column not found")
})?;
match value {
ArrowValue::Int32(v) => {
let arr = col.as_any().downcast_ref::<Int32Array>().ok_or_else(|| {
TensorError::unsupported_operation_simple("Type mismatch".to_string())
})?;
// Broadcast the literal as a `Scalar`, as in `evaluate_equals`.
let result = lt(arr, &Scalar::new(Int32Array::from(vec![*v]))).map_err(|e| {
TensorError::unsupported_operation_simple(format!("Failed to compare: {}", e))
})?;
Ok(result)
}
ArrowValue::Float64(v) => {
let arr = col.as_any().downcast_ref::<Float64Array>().ok_or_else(|| {
TensorError::unsupported_operation_simple("Type mismatch".to_string())
})?;
let result = lt(arr, &Scalar::new(Float64Array::from(vec![*v]))).map_err(|e| {
TensorError::unsupported_operation_simple(format!("Failed to compare: {}", e))
})?;
Ok(result)
}
_ => Err(TensorError::unsupported_operation_simple(
"Value type not yet implemented".to_string(),
)),
}
}
}
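/// A cheap, typed view over a shared Arrow array. Slicing adjusts only the
/// `offset`/`len` bookkeeping; the underlying buffer is reference-counted and
/// never copied.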
#[cfg(feature = "parquet")]
pub struct ArrowBuffer<T> {
data: Arc<dyn Array>,
offset: usize,
len: usize,
_phantom: std::marker::PhantomData<T>,
}
#[cfg(feature = "parquet")]
impl<T> ArrowBuffer<T> {
pub fn from_array(array: Arc<dyn Array>) -> Self {
let len = array.len();
Self {
data: array,
offset: 0,
len,
_phantom: std::marker::PhantomData,
}
}
pub fn len(&self) -> usize {
self.len
}
pub fn is_empty(&self) -> bool {
self.len == 0
}
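/// Returns a sub-view of `length` elements starting at `offset` (relative to this
/// view), without copying the underlying data.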
pub fn slice(&self, offset: usize, length: usize) -> Result<Self> {
if offset + length > self.len {
return Err(TensorError::invalid_argument(
"Slice out of bounds".to_string(),
));
}
Ok(Self {
data: self.data.clone(),
offset: self.offset + offset,
len: length,
_phantom: std::marker::PhantomData,
})
}
}
#[cfg(test)]
#[cfg(feature = "parquet")]
mod tests {
use super::*;
#[test]
fn test_arrow_predicate_creation() {
let pred = ArrowPredicate::eq("age", ArrowValue::Int32(25));
match pred {
ArrowPredicate::Equals(col, val) => {
assert_eq!(col, "age");
assert_eq!(val, ArrowValue::Int32(25));
}
_ => panic!("Wrong predicate type"),
}
}
#[test]
fn test_arrow_predicate_and() {
let pred1 = ArrowPredicate::gt("age", ArrowValue::Int32(18));
let pred2 = ArrowPredicate::lt("age", ArrowValue::Int32(65));
let combined = ArrowPredicate::and(vec![pred1, pred2]);
match combined {
ArrowPredicate::And(preds) => {
assert_eq!(preds.len(), 2);
}
_ => panic!("Wrong predicate type"),
}
}
#[test]
fn test_streaming_config_default() {
let config = StreamingArrowConfig::default();
assert_eq!(config.batch_size, 8192);
assert!(config.predicates.is_empty());
assert!(config.projection.is_none());
assert!(config.enable_statistics_pruning);
}
#[test]
fn test_arrow_value_equality() {
assert_eq!(ArrowValue::Int32(42), ArrowValue::Int32(42));
assert_ne!(ArrowValue::Int32(42), ArrowValue::Int32(43));
assert_eq!(ArrowValue::Float64(2.5), ArrowValue::Float64(2.5));
}
}