use std::any::Any;
use std::fmt::{self, Display, Formatter};
use std::hash::Hash;
use std::sync::Arc;
use datafusion::arrow::array::{Array, RecordBatch};
use datafusion::arrow::compute::cast;
use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::common::Result;
use datafusion::logical_expr::ColumnarValue;
use datafusion::physical_plan::PhysicalExpr;
/// Physical expression for a `reduce(acc = init, var IN list | expr)` clause:
/// folds each row's list into a single value by repeatedly evaluating
/// `reduce_expr` with the running accumulator and the current list element
/// bound as columns.
#[derive(Debug, Clone)]
pub struct ReduceExecExpr {
    // Column name the running accumulator is bound to inside `reduce_expr`.
    accumulator_name: String,
    // Produces the per-row initial accumulator value.
    initial_expr: Arc<dyn PhysicalExpr>,
    // Column name the current list element is bound to inside `reduce_expr`.
    variable_name: String,
    // Produces the per-row list to fold over (List/LargeList, or LargeBinary
    // that is decoded to a large list — see `evaluate`).
    list_expr: Arc<dyn PhysicalExpr>,
    // The folding step; evaluated against an inner schema that extends the
    // input schema with the accumulator and variable columns.
    reduce_expr: Arc<dyn PhysicalExpr>,
    // Schema of the batches this expression is evaluated against.
    input_schema: Arc<Schema>,
    // Result type of the whole reduction.
    output_type: DataType,
}
impl ReduceExecExpr {
pub fn new(
accumulator_name: String,
initial_expr: Arc<dyn PhysicalExpr>,
variable_name: String,
list_expr: Arc<dyn PhysicalExpr>,
reduce_expr: Arc<dyn PhysicalExpr>,
input_schema: Arc<Schema>,
output_type: DataType,
) -> Self {
Self {
accumulator_name,
initial_expr,
variable_name,
list_expr,
reduce_expr,
input_schema,
output_type,
}
}
}
impl Display for ReduceExecExpr {
    /// Renders the expression in its surface syntax:
    /// `reduce(acc = init, var IN list | step)`.
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        write!(f, "reduce({} = {}, ", self.accumulator_name, self.initial_expr)?;
        write!(
            f,
            "{} IN {} | {})",
            self.variable_name, self.list_expr, self.reduce_expr
        )
    }
}
impl PartialEq for ReduceExecExpr {
    /// Structural equality on the names and output type; the sub-expressions
    /// are compared by pointer identity (same `Arc`), not deep equality.
    fn eq(&self, other: &Self) -> bool {
        if self.accumulator_name != other.accumulator_name
            || self.variable_name != other.variable_name
            || self.output_type != other.output_type
        {
            return false;
        }
        Arc::ptr_eq(&self.initial_expr, &other.initial_expr)
            && Arc::ptr_eq(&self.list_expr, &other.list_expr)
            && Arc::ptr_eq(&self.reduce_expr, &other.reduce_expr)
    }
}
impl Eq for ReduceExecExpr {}

impl Hash for ReduceExecExpr {
    /// Hashes only the fields that `eq` compares structurally (names and
    /// output type), so equal values always hash equally; the pointer-equal
    /// sub-expressions are deliberately excluded.
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        // Tuple hashing feeds the fields to the hasher in the same order as
        // hashing them one by one.
        (
            &self.accumulator_name,
            &self.variable_name,
            &self.output_type,
        )
            .hash(state);
    }
}
impl PartialEq<dyn Any> for ReduceExecExpr {
    /// Type-erased comparison: equal only if `other` is also a
    /// `ReduceExecExpr` and the typed comparison succeeds.
    fn eq(&self, other: &dyn Any) -> bool {
        match other.downcast_ref::<Self>() {
            Some(typed) => self == typed,
            None => false,
        }
    }
}
impl PhysicalExpr for ReduceExecExpr {
fn as_any(&self) -> &dyn Any {
self
}
fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
Ok(self.output_type.clone())
}
fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
Ok(true)
}
fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
let list_val = self.list_expr.evaluate(batch)?;
let list_array = list_val.into_array(batch.num_rows())?;
let list_array = if let DataType::LargeBinary = list_array.data_type() {
let element_type = self.output_type.clone();
crate::query::df_graph::common::cv_array_to_large_list(
list_array.as_ref(),
&element_type,
)?
} else {
list_array
};
let list_array = if let DataType::List(field) = list_array.data_type() {
let target_type = DataType::LargeList(field.clone());
cast(&list_array, &target_type).map_err(|e| {
datafusion::error::DataFusionError::Execution(format!("Cast failed: {}", e))
})?
} else {
list_array
};
let large_list = list_array
.as_any()
.downcast_ref::<datafusion::arrow::array::LargeListArray>()
.ok_or_else(|| {
datafusion::error::DataFusionError::Execution("Expected LargeListArray".to_string())
})?;
let offsets = large_list.offsets();
let values = large_list.values();
let init_val = self.initial_expr.evaluate(batch)?;
let mut current_acc = init_val.into_array(batch.num_rows())?;
let mut max_len = 0;
for window in offsets.windows(2) {
let len = (window[1] - window[0]) as usize;
if len > max_len {
max_len = len;
}
}
for i in 0..max_len {
let mut active_indices_builder =
datafusion::arrow::array::UInt32Builder::with_capacity(batch.num_rows());
let mut variable_indices_builder =
datafusion::arrow::array::UInt32Builder::with_capacity(batch.num_rows());
for (row_idx, window) in offsets.windows(2).enumerate() {
let start = window[0] as usize;
let end = window[1] as usize;
let len = end - start;
if i < len {
active_indices_builder.append_value(row_idx as u32);
variable_indices_builder.append_value((start + i) as u32);
}
}
let active_indices = active_indices_builder.finish();
let variable_indices = variable_indices_builder.finish();
if active_indices.is_empty() {
break;
}
let mut inner_columns = Vec::with_capacity(batch.num_columns() + 2);
for col in batch.columns() {
let taken = datafusion::arrow::compute::take(col, &active_indices, None)?;
inner_columns.push(taken);
}
let mut inner_fields = batch.schema().fields().to_vec();
let acc_field = Arc::new(Field::new(
&self.accumulator_name,
current_acc.data_type().clone(),
true,
));
let var_field = Arc::new(Field::new(
&self.variable_name,
values.data_type().clone(),
true,
));
let acc_taken = datafusion::arrow::compute::take(¤t_acc, &active_indices, None)?;
if let Some(pos) = inner_fields
.iter()
.position(|f| f.name() == &self.accumulator_name)
{
inner_columns[pos] = acc_taken;
inner_fields[pos] = acc_field;
} else {
inner_columns.push(acc_taken);
inner_fields.push(acc_field);
}
let var_taken = datafusion::arrow::compute::take(values, &variable_indices, None)?;
if let Some(pos) = inner_fields
.iter()
.position(|f| f.name() == &self.variable_name)
{
inner_columns[pos] = var_taken;
inner_fields[pos] = var_field;
} else {
inner_columns.push(var_taken);
inner_fields.push(var_field);
}
let inner_schema = Arc::new(Schema::new(inner_fields));
let inner_batch = RecordBatch::try_new(inner_schema, inner_columns)?;
let new_acc_val = self.reduce_expr.evaluate(&inner_batch)?;
let new_acc_array = new_acc_val.into_array(inner_batch.num_rows())?;
if active_indices.len() == batch.num_rows() {
current_acc = new_acc_array;
} else {
let mut interleave_indices = Vec::with_capacity(batch.num_rows());
let mut active_map = vec![None; batch.num_rows()];
for (k, &row_idx) in active_indices.values().iter().enumerate() {
active_map[row_idx as usize] = Some(k);
}
for (row_idx, slot) in active_map.iter().enumerate() {
if let Some(k) = slot {
interleave_indices.push((1, *k)); } else {
interleave_indices.push((0, row_idx)); }
}
current_acc = datafusion::arrow::compute::interleave(
&[¤t_acc, &new_acc_array],
&interleave_indices,
)?;
}
}
Ok(ColumnarValue::Array(current_acc))
}
fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
vec![&self.initial_expr, &self.list_expr]
}
fn with_new_children(
self: Arc<Self>,
children: Vec<Arc<dyn PhysicalExpr>>,
) -> Result<Arc<dyn PhysicalExpr>> {
if children.len() != 2 {
return Err(datafusion::error::DataFusionError::Internal(
"Reduce requires 2 children (initial_expr, list_expr)".to_string(),
));
}
Ok(Arc::new(Self {
initial_expr: children[0].clone(),
list_expr: children[1].clone(),
reduce_expr: self.reduce_expr.clone(),
accumulator_name: self.accumulator_name.clone(),
variable_name: self.variable_name.clone(),
input_schema: self.input_schema.clone(),
output_type: self.output_type.clone(),
}))
}
fn fmt_sql(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"reduce({} = {}, {} IN {} | {})",
self.accumulator_name,
self.initial_expr,
self.variable_name,
self.list_expr,
self.reduce_expr
)
}
}