use std::any::Any;
use std::fmt::{self, Display, Formatter};
use std::hash::Hash;
use std::sync::Arc;
use datafusion::arrow::array::{Array, BooleanArray, RecordBatch, UInt32Array};
use datafusion::arrow::buffer::{OffsetBuffer, ScalarBuffer};
use datafusion::arrow::compute::{cast, filter, filter_record_batch, take};
use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::common::Result;
use datafusion::logical_expr::ColumnarValue;
use datafusion::physical_plan::PhysicalExpr;
/// Physical expression implementing a Cypher-style list comprehension:
/// `[variable_name IN input_list WHERE predicate | map_expr]`.
///
/// Evaluation flattens the per-row lists into an "inner" batch, binds each
/// element to `variable_name`, optionally filters, maps, and re-assembles a
/// list per input row (see `evaluate`).
#[derive(Debug)]
pub struct ListComprehensionExecExpr {
    // Expression producing the list to iterate; evaluated against the outer batch.
    input_list: Arc<dyn PhysicalExpr>,
    // Per-element projection; evaluated against the flattened inner batch,
    // where the loop variable is available as a column.
    map_expr: Arc<dyn PhysicalExpr>,
    // Optional per-element filter, applied before `map_expr`.
    predicate: Option<Arc<dyn PhysicalExpr>>,
    // Column name the current list element is bound to in the inner batch.
    variable_name: String,
    // Schema of the outer input batch.
    // NOTE(review): stored but never read in this file — confirm it is used elsewhere.
    input_schema: Arc<Schema>,
    // Declared element type of the result list; only consulted by eq/hash here
    // (the output column itself is always LargeBinary — see `data_type`).
    output_item_type: DataType,
    // When true, a companion `<variable_name>._vid` UInt64 column is derived
    // from the loop-variable column before predicate/map evaluation.
    needs_vid_extraction: bool,
}
impl Clone for ListComprehensionExecExpr {
fn clone(&self) -> Self {
Self {
input_list: self.input_list.clone(),
map_expr: self.map_expr.clone(),
predicate: self.predicate.clone(),
variable_name: self.variable_name.clone(),
input_schema: self.input_schema.clone(),
output_item_type: self.output_item_type.clone(),
needs_vid_extraction: self.needs_vid_extraction,
}
}
}
impl ListComprehensionExecExpr {
pub fn new(
input_list: Arc<dyn PhysicalExpr>,
map_expr: Arc<dyn PhysicalExpr>,
predicate: Option<Arc<dyn PhysicalExpr>>,
variable_name: String,
input_schema: Arc<Schema>,
output_item_type: DataType,
needs_vid_extraction: bool,
) -> Self {
Self {
input_list,
map_expr,
predicate,
variable_name,
input_schema,
output_item_type,
needs_vid_extraction,
}
}
}
impl Display for ListComprehensionExecExpr {
    /// Renders as `ListComprehension(var=<name>, list=<expr>)`.
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        write!(
            f,
            "ListComprehension(var={var}, list={list})",
            var = self.variable_name,
            list = self.input_list
        )
    }
}
impl PartialEq for ListComprehensionExecExpr {
    /// Equality compares the by-value fields structurally, but child
    /// expressions only by `Arc` pointer identity — cheap, at the cost of
    /// treating distinct-but-equivalent expression trees as unequal.
    fn eq(&self, other: &Self) -> bool {
        let predicates_match = match (&self.predicate, &other.predicate) {
            (Some(a), Some(b)) => Arc::ptr_eq(a, b),
            (None, None) => true,
            _ => false,
        };
        predicates_match
            && self.variable_name == other.variable_name
            && self.output_item_type == other.output_item_type
            && Arc::ptr_eq(&self.input_list, &other.input_list)
            && Arc::ptr_eq(&self.map_expr, &other.map_expr)
    }
}
impl Eq for ListComprehensionExecExpr {}
impl Hash for ListComprehensionExecExpr {
    /// Hashes only the by-value fields also used by `eq`; the
    /// pointer-compared child expressions are skipped. This keeps the
    /// `eq`/`hash` contract (equal values hash equally) at the cost of more
    /// collisions between comprehensions that differ only in children.
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        (&self.variable_name, &self.output_item_type).hash(state);
    }
}
impl PartialEq<dyn Any> for ListComprehensionExecExpr {
    /// Type-erased comparison: equal only when `other` is the same concrete
    /// type and the `PartialEq<Self>` impl agrees.
    fn eq(&self, other: &dyn Any) -> bool {
        match other.downcast_ref::<Self>() {
            Some(concrete) => self == concrete,
            None => false,
        }
    }
}
impl PhysicalExpr for ListComprehensionExecExpr {
    fn as_any(&self) -> &dyn Any {
        self
    }
    /// The output is always the serialized binary value representation
    /// (`LargeBinary`), regardless of `output_item_type`; the typed result is
    /// re-encoded via `typed_large_list_to_cv_array` at the end of `evaluate`.
    fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
        Ok(DataType::LargeBinary)
    }
    fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
        Ok(true)
    }
    /// Evaluates the comprehension per input row:
    /// 1. normalize the list column to `LargeList`,
    /// 2. flatten it into an "inner" batch (one row per list element; outer
    ///    columns repeated via `take`, loop variable added as a column),
    /// 3. optionally filter with `predicate` and recompute per-row offsets,
    /// 4. apply `map_expr` to the surviving elements,
    /// 5. re-wrap as a `LargeList` and serialize back to a binary value column.
    fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
        let list_val = self.input_list.evaluate(batch)?;
        let list_array = list_val.into_array(batch.num_rows())?;
        // LargeBinary input is the serialized value encoding; decode it to a
        // typed large list first (NOTE(review): helper presumably parses
        // CypherValue blobs — confirm against common::cv_array_to_large_list).
        let list_array = if let DataType::LargeBinary = list_array.data_type() {
            crate::query::df_graph::common::cv_array_to_large_list(
                list_array.as_ref(),
                &DataType::LargeBinary,
            )?
        } else {
            list_array
        };
        // A 32-bit-offset List is widened to LargeList so the downcast below
        // has a single target type to handle.
        let list_array = if let DataType::List(field) = list_array.data_type() {
            let target_type = DataType::LargeList(field.clone());
            cast(&list_array, &target_type).map_err(|e| {
                datafusion::error::DataFusionError::Execution(format!("Cast failed: {}", e))
            })?
        } else {
            list_array
        };
        let large_list = list_array
            .as_any()
            .downcast_ref::<datafusion::arrow::array::LargeListArray>()
            .ok_or_else(|| {
                datafusion::error::DataFusionError::Execution(format!(
                    "Expected LargeListArray, got {:?}",
                    list_array.data_type()
                ))
            })?;
        let values = large_list.values();
        let offsets = large_list.offsets();
        let nulls = large_list.nulls();
        let num_rows = batch.num_rows();
        let num_values = values.len();
        // Build a take-index array with one entry per flattened element,
        // pointing back at its source row, so outer columns can be repeated
        // to align with the element values.
        let mut indices_builder =
            datafusion::arrow::array::UInt32Builder::with_capacity(num_values);
        for row_idx in 0..num_rows {
            // OffsetBuffer has num_rows + 1 entries; [start, end) delimits
            // this row's slice of `values`. Null rows have start == end.
            let start = offsets[row_idx] as usize;
            let end = offsets[row_idx + 1] as usize;
            let len = end - start;
            for _ in 0..len {
                indices_builder.append_value(row_idx as u32);
            }
        }
        let indices = indices_builder.finish();
        // Inner batch columns: every outer column repeated per element...
        let mut inner_columns = Vec::with_capacity(batch.num_columns() + 1);
        for col in batch.columns() {
            let taken = take(col, &indices, None).map_err(|e| {
                datafusion::error::DataFusionError::Execution(format!("Take failed: {}", e))
            })?;
            inner_columns.push(taken);
        }
        // ...plus the loop variable bound to the flattened element values.
        // If an outer column already uses the variable name it is shadowed
        // in place; otherwise the column is appended.
        let mut inner_fields = batch.schema().fields().to_vec();
        let loop_field = Arc::new(Field::new(
            &self.variable_name,
            values.data_type().clone(),
            true,
        ));
        if let Some(pos) = inner_fields
            .iter()
            .position(|f| f.name() == &self.variable_name)
        {
            inner_columns[pos] = values.clone();
            inner_fields[pos] = loop_field;
        } else {
            inner_columns.push(values.clone());
            inner_fields.push(loop_field);
        }
        // Optionally derive a `<var>._vid` UInt64 column from the loop
        // variable so predicate/map exprs can reference vertex ids directly.
        if self.needs_vid_extraction {
            let vid_field_name = format!("{}._vid", self.variable_name);
            if !inner_fields.iter().any(|f| f.name() == &vid_field_name) {
                let vid_field = Arc::new(Field::new(&vid_field_name, DataType::UInt64, true));
                let loop_var_idx = inner_fields
                    .iter()
                    .position(|f| f.name() == &self.variable_name);
                if let Some(idx) = loop_var_idx {
                    let vid_array = super::common::extract_vids_from_cypher_value_column(
                        inner_columns[idx].as_ref(),
                    )?;
                    inner_fields.push(vid_field);
                    inner_columns.push(vid_array);
                }
            }
        }
        let inner_schema = Arc::new(Schema::new(inner_fields));
        let inner_batch = RecordBatch::try_new(inner_schema, inner_columns)?;
        // Apply the predicate (if any) to drop elements, keeping the
        // element→source-row index array in sync via the same mask.
        let (filtered_batch, filtered_indices) = if let Some(pred) = &self.predicate {
            let mask = pred
                .evaluate(&inner_batch)?
                .into_array(inner_batch.num_rows())?;
            // Cast guarantees Boolean, so the downcast unwrap cannot fail.
            let mask = cast(&mask, &DataType::Boolean)?;
            let boolean_mask = mask.as_any().downcast_ref::<BooleanArray>().unwrap();
            let filtered_batch = filter_record_batch(&inner_batch, boolean_mask)?;
            let indices_array: Arc<dyn Array> = Arc::new(indices.clone());
            let filtered_indices = filter(&indices_array, boolean_mask)?;
            let filtered_indices = filtered_indices
                .as_any()
                .downcast_ref::<UInt32Array>()
                .unwrap()
                .clone();
            (filtered_batch, filtered_indices)
        } else {
            (inner_batch, indices.clone())
        };
        let mapped_val = self.map_expr.evaluate(&filtered_batch)?;
        let mapped_array = mapped_val.into_array(filtered_batch.num_rows())?;
        // Rebuild per-row offsets over the surviving elements. This single
        // forward scan relies on `filtered_indices` being sorted ascending,
        // which holds because `indices` was built row-by-row and `filter`
        // preserves order. Rows with no survivors get an empty (not null) slot.
        let new_offsets = if self.predicate.is_some() {
            let num_rows = batch.num_rows();
            let mut new_offsets = Vec::with_capacity(num_rows + 1);
            new_offsets.push(0);
            let indices_slice = filtered_indices.values();
            let mut pos = 0;
            let mut current_len = 0;
            for row_idx in 0..num_rows {
                let mut count = 0;
                while pos < indices_slice.len() && indices_slice[pos] as usize == row_idx {
                    count += 1;
                    pos += 1;
                }
                current_len += count;
                new_offsets.push(current_len);
            }
            OffsetBuffer::new(ScalarBuffer::from(new_offsets))
        } else {
            // No filtering: element counts are unchanged, reuse input offsets.
            offsets.clone()
        };
        // Reassemble the per-row lists (original row-level null bitmap is
        // carried over) and serialize back to the binary value column.
        let new_field = Arc::new(Field::new("item", mapped_array.data_type().clone(), true));
        let new_list = datafusion::arrow::array::LargeListArray::new(
            new_field,
            new_offsets,
            mapped_array,
            nulls.cloned(),
        );
        let cypher_value_array =
            crate::query::df_graph::common::typed_large_list_to_cv_array(&new_list)?;
        Ok(ColumnarValue::Array(cypher_value_array))
    }
    /// NOTE(review): `map_expr` and `predicate` are deliberately(?) not
    /// exposed as children — they are evaluated against the *inner* schema
    /// (outer columns + loop variable), so generic tree rewrites against the
    /// outer schema would likely mis-resolve them. Confirm this is intended,
    /// since optimizer passes will not see those subexpressions.
    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
        vec![&self.input_list]
    }
    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn PhysicalExpr>>,
    ) -> Result<Arc<dyn PhysicalExpr>> {
        // Must mirror `children()`: exactly one child (the input list).
        if children.len() != 1 {
            return Err(datafusion::error::DataFusionError::Internal(
                "ListComprehension requires exactly 1 child (input_list)".to_string(),
            ));
        }
        Ok(Arc::new(Self {
            input_list: children[0].clone(),
            map_expr: self.map_expr.clone(),
            predicate: self.predicate.clone(),
            variable_name: self.variable_name.clone(),
            input_schema: self.input_schema.clone(),
            output_item_type: self.output_item_type.clone(),
            needs_vid_extraction: self.needs_vid_extraction,
        }))
    }
    /// Renders Cypher-like comprehension syntax:
    /// `[var IN list WHERE pred | map]` or `[var IN list | map]`.
    fn fmt_sql(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        if let Some(pred) = &self.predicate {
            write!(
                f,
                "[{} IN {} WHERE {} | {}]",
                self.variable_name, self.input_list, pred, self.map_expr
            )
        } else {
            write!(
                f,
                "[{} IN {} | {}]",
                self.variable_name, self.input_list, self.map_expr
            )
        }
    }
}