use crate::query::df_graph::GraphExecutionContext;
use crate::query::df_graph::common::{
arrow_err, collect_all_partitions, compute_plan_properties, execute_subplan, extract_row_params,
};
use crate::query::planner::LogicalPlan;
use arrow_array::builder::{
BooleanBuilder, Float64Builder, Int32Builder, Int64Builder, StringBuilder, UInt64Builder,
};
use arrow_array::{ArrayRef, RecordBatch};
use arrow_schema::{DataType, SchemaRef};
use datafusion::common::Result as DFResult;
use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
use datafusion::prelude::SessionContext;
use futures::Stream;
use parking_lot::RwLock;
use std::any::Any;
use std::collections::HashMap;
use std::collections::hash_map::DefaultHasher;
use std::fmt;
use std::hash::{Hash, Hasher};
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use uni_common::Value;
use uni_common::core::schema::Schema as UniSchema;
use uni_cypher::ast::{Expr, UnaryOp};
use uni_store::storage::manager::StorageManager;
/// Physical operator implementing a correlated apply (nested-loop subquery):
/// for each row produced by `input_exec`, `subquery_plan` is executed with
/// that row's values bound as parameters, and the subquery's rows are
/// cross-joined onto the input row. See `run_apply` for the execution paths.
pub struct GraphApplyExec {
    // Upstream plan producing the driving (outer) rows.
    input_exec: Arc<dyn ExecutionPlan>,
    // Logical plan executed once per (filtered) input row.
    subquery_plan: LogicalPlan,
    // Optional predicate applied to input rows before the subquery runs.
    input_filter: Option<Expr>,
    graph_ctx: Arc<GraphExecutionContext>,
    session_ctx: Arc<RwLock<SessionContext>>,
    storage: Arc<StorageManager>,
    schema_info: Arc<UniSchema>,
    // Outer query parameters merged into each per-row parameter set.
    params: HashMap<String, Value>,
    // Schema of the joined (input + subquery) output rows.
    output_schema: SchemaRef,
    properties: PlanProperties,
    metrics: ExecutionPlanMetricsSet,
}
impl fmt::Debug for GraphApplyExec {
    /// Minimal debug output: most fields are large or not `Debug`, so only
    /// the presence of an input filter is reported.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut dbg = f.debug_struct("GraphApplyExec");
        dbg.field("has_input_filter", &self.input_filter.is_some());
        dbg.finish()
    }
}
impl GraphApplyExec {
    /// Builds a new apply operator that will run `subquery_plan` once per
    /// (optionally `input_filter`-filtered) row of `input_exec`, producing
    /// rows shaped by `output_schema`.
    #[expect(clippy::too_many_arguments)]
    pub fn new(
        input_exec: Arc<dyn ExecutionPlan>,
        subquery_plan: LogicalPlan,
        input_filter: Option<Expr>,
        graph_ctx: Arc<GraphExecutionContext>,
        session_ctx: Arc<RwLock<SessionContext>>,
        storage: Arc<StorageManager>,
        schema_info: Arc<UniSchema>,
        params: HashMap<String, Value>,
        output_schema: SchemaRef,
    ) -> Self {
        Self {
            // Derived state is computed up front from the output schema.
            properties: compute_plan_properties(output_schema.clone()),
            metrics: ExecutionPlanMetricsSet::new(),
            input_exec,
            subquery_plan,
            input_filter,
            graph_ctx,
            session_ctx,
            storage,
            schema_info,
            params,
            output_schema,
        }
    }
}
impl DisplayAs for GraphApplyExec {
    /// One-line plan display; only notes whether an input filter is set.
    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let filter_label = match &self.input_filter {
            Some(_) => "yes",
            None => "none",
        };
        write!(f, "GraphApplyExec: filter={filter_label}")
    }
}
impl ExecutionPlan for GraphApplyExec {
    fn name(&self) -> &str {
        "GraphApplyExec"
    }
    fn as_any(&self) -> &dyn Any {
        self
    }
    fn schema(&self) -> SchemaRef {
        self.output_schema.clone()
    }
    fn properties(&self) -> &PlanProperties {
        &self.properties
    }
    /// Reported as a leaf: `input_exec` is driven internally by `run_apply`
    /// rather than being exposed to the DataFusion optimizer as a child.
    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![]
    }
    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> DFResult<Arc<dyn ExecutionPlan>> {
        if !children.is_empty() {
            return Err(datafusion::error::DataFusionError::Plan(
                "GraphApplyExec has no children".to_string(),
            ));
        }
        Ok(self)
    }
    /// Returns a stream that yields exactly one batch: all apply work runs
    /// lazily inside the `run_apply` future the first time it is polled.
    fn execute(
        &self,
        partition: usize,
        _context: Arc<TaskContext>,
    ) -> DFResult<SendableRecordBatchStream> {
        let metrics = BaselineMetrics::new(&self.metrics, partition);
        // Clone everything the future captures so it is 'static + Send.
        let input_exec = self.input_exec.clone();
        let subquery_plan = self.subquery_plan.clone();
        let input_filter = self.input_filter.clone();
        let graph_ctx = self.graph_ctx.clone();
        let session_ctx = self.session_ctx.clone();
        let storage = self.storage.clone();
        let schema_info = self.schema_info.clone();
        let params = self.params.clone();
        let output_schema = self.output_schema.clone();
        let fut = async move {
            run_apply(
                input_exec,
                &subquery_plan,
                input_filter.as_ref(),
                &graph_ctx,
                &session_ctx,
                &storage,
                &schema_info,
                // Fixed: was the mojibake `¶ms` (a mis-decoded `&params`),
                // which does not compile.
                &params,
                &output_schema,
            )
            .await
        };
        Ok(Box::pin(ApplyStream {
            state: ApplyStreamState::Running(Box::pin(fut)),
            schema: self.output_schema.clone(),
            metrics,
        }))
    }
    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }
}
/// Flattens every row of every batch into a name → value map.
fn batches_to_row_maps(batches: &[RecordBatch]) -> Vec<HashMap<String, Value>> {
    let mut rows = Vec::new();
    for batch in batches {
        for row_idx in 0..batch.num_rows() {
            rows.push(extract_row_params(batch, row_idx));
        }
    }
    rows
}
/// Evaluates a boolean filter expression against a flattened row map.
/// Expressions that cannot be resolved to a boolean evaluate to `false`.
fn evaluate_filter(filter: &Expr, row: &HashMap<String, Value>) -> bool {
    use uni_cypher::ast::BinaryOp;
    match filter {
        Expr::BinaryOp {
            left,
            op: BinaryOp::And,
            right,
        } => evaluate_filter(left, row) && evaluate_filter(right, row),
        Expr::BinaryOp {
            left,
            op: BinaryOp::Or,
            right,
        } => evaluate_filter(left, row) || evaluate_filter(right, row),
        // Any other binary operator is treated as a value comparison.
        Expr::BinaryOp { left, op, right } => {
            let lhs = resolve_expr_value(left, row);
            let rhs = resolve_expr_value(right, row);
            evaluate_comparison(op, &lhs, &rhs)
        }
        Expr::UnaryOp {
            op: UnaryOp::Not,
            expr,
        } => !evaluate_filter(expr, row),
        other => resolve_expr_value(other, row).as_bool().unwrap_or(false),
    }
}
/// Resolves a literal, variable, or `var.prop` property access against the
/// row map; anything else resolves to `Value::Null`.
fn resolve_expr_value(expr: &Expr, row: &HashMap<String, Value>) -> Value {
    let lookup = |key: &str| row.get(key).cloned().unwrap_or(Value::Null);
    match expr {
        Expr::Literal(lit) => lit.to_value(),
        Expr::Variable(name) => lookup(name),
        // Properties are stored flattened as "<var>.<key>" columns.
        Expr::Property(base_expr, key) => match base_expr.as_ref() {
            Expr::Variable(var) => lookup(&format!("{var}.{key}")),
            _ => Value::Null,
        },
        _ => Value::Null,
    }
}
/// Orders two values, widening ints to floats for mixed numeric pairs.
/// Returns `None` for incomparable type combinations (including NaN floats).
fn compare_values(a: &Value, b: &Value) -> Option<std::cmp::Ordering> {
    match (a, b) {
        (Value::Int(x), Value::Int(y)) => Some(x.cmp(y)),
        (Value::String(x), Value::String(y)) => Some(x.cmp(y)),
        (Value::Float(x), Value::Float(y)) => x.partial_cmp(y),
        (Value::Int(x), Value::Float(y)) => (*x as f64).partial_cmp(y),
        (Value::Float(x), Value::Int(y)) => x.partial_cmp(&(*y as f64)),
        _ => None,
    }
}
/// Applies a comparison operator; ordering-based operators yield `false`
/// when the operands are incomparable, as do unsupported operators.
fn evaluate_comparison(op: &uni_cypher::ast::BinaryOp, left: &Value, right: &Value) -> bool {
    use std::cmp::Ordering::{Greater, Less};
    use uni_cypher::ast::BinaryOp;
    // Equality works structurally on any value pair.
    if matches!(op, BinaryOp::Eq) {
        return left == right;
    }
    if matches!(op, BinaryOp::NotEq) {
        return left != right;
    }
    let Some(ord) = compare_values(left, right) else {
        return false;
    };
    match op {
        BinaryOp::Lt => ord == Less,
        BinaryOp::LtEq => ord != Greater,
        BinaryOp::Gt => ord == Greater,
        BinaryOp::GtEq => ord != Less,
        _ => false,
    }
}
/// Builds one typed column from row maps: `extract` pulls a native value out
/// of each row's `Value` (a `None` — missing key or failed conversion —
/// becomes a null slot).
fn build_column<B, T>(
    rows: &[HashMap<String, Value>],
    col_name: &str,
    mut builder: B,
    extract: impl Fn(&Value) -> Option<T>,
) -> ArrayRef
where
    B: arrow_array::builder::ArrayBuilder,
    B: PrimitiveAppend<T>,
{
    for row in rows {
        match row.get(col_name).and_then(&extract) {
            Some(v) => builder.append_typed_value(v),
            None => builder.append_typed_null(),
        }
    }
    // `finish_to_array` already returns an ArrayRef; the previous
    // `Arc::new(...)` double-wrapped it as Arc<Arc<dyn Array>>.
    builder.finish_to_array()
}
/// Uniform append/finish interface over typed Arrow builders, letting
/// `build_column` and `single_row_array` stay generic over the native type.
trait PrimitiveAppend<T> {
    // Appends one non-null value of the native type.
    fn append_typed_value(&mut self, val: T);
    // Appends a null slot.
    fn append_typed_null(&mut self);
    // Consumes the builder and returns the finished array.
    fn finish_to_array(self) -> ArrayRef;
}
/// Implements `PrimitiveAppend` for a concrete Arrow builder type.
/// The previous version also took an `$array:ty` argument that was never
/// used in the expansion; it has been removed along with the dead third
/// arguments at every call site.
macro_rules! impl_primitive_append {
    ($builder:ty, $native:ty) => {
        impl PrimitiveAppend<$native> for $builder {
            fn append_typed_value(&mut self, val: $native) {
                self.append_value(val);
            }
            fn append_typed_null(&mut self) {
                self.append_null();
            }
            fn finish_to_array(mut self) -> ArrayRef {
                Arc::new(self.finish()) as ArrayRef
            }
        }
    };
}
impl_primitive_append!(UInt64Builder, u64);
impl_primitive_append!(Int64Builder, i64);
impl_primitive_append!(Int32Builder, i32);
impl_primitive_append!(Float64Builder, f64);
impl_primitive_append!(BooleanBuilder, bool);
/// Materializes value-map rows into a `RecordBatch` matching `schema`.
///
/// Each output column is built by field name; missing or type-incompatible
/// values become nulls. Unrecognized data types fall back to a UTF-8 column
/// via `Display` formatting.
fn rows_to_batch(rows: &[HashMap<String, Value>], schema: &SchemaRef) -> DFResult<RecordBatch> {
    if rows.is_empty() {
        return Ok(RecordBatch::new_empty(schema.clone()));
    }
    let num_rows = rows.len();
    let mut columns: Vec<ArrayRef> = Vec::with_capacity(schema.fields().len());
    for field in schema.fields() {
        let col_name = field.name();
        let col = match field.data_type() {
            // Signed ints are accepted into UInt64 columns via an `as` cast.
            DataType::UInt64 => build_column(
                rows,
                col_name,
                UInt64Builder::with_capacity(num_rows),
                |v| v.as_u64().or_else(|| v.as_i64().map(|i| i as u64)),
            ),
            DataType::Int64 => build_column(
                rows,
                col_name,
                Int64Builder::with_capacity(num_rows),
                Value::as_i64,
            ),
            DataType::Int32 => {
                build_column(rows, col_name, Int32Builder::with_capacity(num_rows), |v| {
                    v.as_i64().map(|i| i as i32)
                })
            }
            DataType::Float64 => build_column(
                rows,
                col_name,
                Float64Builder::with_capacity(num_rows),
                Value::as_f64,
            ),
            DataType::Boolean => build_column(
                rows,
                col_name,
                BooleanBuilder::with_capacity(num_rows),
                Value::as_bool,
            ),
            // Binary columns carry values in the project's cypher-value
            // binary encoding.
            DataType::LargeBinary => {
                let mut builder = arrow_array::builder::LargeBinaryBuilder::with_capacity(
                    num_rows,
                    num_rows * 64,
                );
                for row in rows {
                    match row.get(col_name) {
                        Some(val) if !val.is_null() => {
                            let cv_bytes = uni_common::cypher_value_codec::encode(val);
                            builder.append_value(&cv_bytes);
                        }
                        _ => builder.append_null(),
                    }
                }
                Arc::new(builder.finish()) as ArrayRef
            }
            // Only string lists are supported; non-string items inside a
            // list are stringified via Display.
            DataType::List(inner_field) if inner_field.data_type() == &DataType::Utf8 => {
                let mut builder = arrow_array::builder::ListBuilder::new(StringBuilder::new());
                for row in rows {
                    match row.get(col_name) {
                        Some(Value::List(items)) => {
                            for item in items {
                                match item {
                                    Value::String(s) => builder.values().append_value(s),
                                    Value::Null => builder.values().append_null(),
                                    other => builder.values().append_value(format!("{other}")),
                                }
                            }
                            builder.append(true);
                        }
                        // Missing or non-list values become a null list.
                        _ => builder.append_null(),
                    }
                }
                Arc::new(builder.finish()) as ArrayRef
            }
            DataType::Null => Arc::new(arrow_array::NullArray::new(num_rows)) as ArrayRef,
            // Fallback: stringify everything else into a Utf8 column.
            _ => {
                let mut builder = StringBuilder::with_capacity(num_rows, num_rows * 32);
                for row in rows {
                    match row.get(col_name) {
                        Some(Value::Null) | None => builder.append_null(),
                        Some(Value::String(s)) => builder.append_value(s),
                        Some(other) => builder.append_value(format!("{other}")),
                    }
                }
                Arc::new(builder.finish()) as ArrayRef
            }
        };
        columns.push(col);
    }
    RecordBatch::try_new(schema.clone(), columns).map_err(arrow_err)
}
/// Extracts one row of a batch as a set of length-1 array slices
/// (zero-copy: `slice` shares the underlying buffers).
fn slice_row(batch: &RecordBatch, row_idx: usize) -> Vec<ArrayRef> {
    let mut single_row = Vec::with_capacity(batch.num_columns());
    for col in batch.columns() {
        single_row.push(col.slice(row_idx, 1));
    }
    single_row
}
/// Walks down through row-shaping wrappers (project/filter/sort/limit/
/// distinct) to check whether the plan bottoms out in a procedure call.
fn is_procedure_call(plan: &LogicalPlan) -> bool {
    let mut current = plan;
    loop {
        match current {
            LogicalPlan::ProcedureCall { .. } => return true,
            LogicalPlan::Project { input, .. }
            | LogicalPlan::Filter { input, .. }
            | LogicalPlan::Sort { input, .. }
            | LogicalPlan::Limit { input, .. }
            | LogicalPlan::Distinct { input } => current = input,
            _ => return false,
        }
    }
}
fn hash_row_params(params: &HashMap<String, Value>) -> u64 {
let mut hasher = DefaultHasher::new();
let mut entries: Vec<_> = params.iter().collect();
entries.sort_unstable_by_key(|(k, _)| *k);
for (key, val) in entries {
key.hash(&mut hasher);
format!("{val:?}").hash(&mut hasher);
}
hasher.finish()
}
/// Batching pays off only with at least two rows, and requires a `._vid`
/// column so results can be re-matched to input rows afterwards.
fn is_batch_eligible(filtered_entries: &[(&RecordBatch, usize, HashMap<String, Value>)]) -> bool {
    filtered_entries.len() >= 2
        && filtered_entries
            .iter()
            .any(|(_, _, row_params)| row_params.keys().any(|k| k.ends_with("._vid")))
}
/// Drives the correlated apply.
///
/// Collects all input partitions, filters rows with `input_filter`, then
/// executes `subquery_plan` and cross-joins its rows onto the surviving
/// input rows. Three paths:
/// 1. No surviving rows: run the subquery once, uncorrelated.
/// 2. Batched (≥2 rows carrying `._vid` columns, filter present, not a
///    procedure call): collect all VIDs into list-valued params, run the
///    subquery once, then re-associate results with input rows by VID.
/// 3. Row-at-a-time: one subquery execution per input row, memoized by a
///    hash of the row's parameters so duplicate rows reuse prior results.
///
/// Fixes vs. the previous version: the mojibake `¶ms_hash` (a mis-decoded
/// `&params_hash`, which did not compile); `is_proc_call` is reused instead
/// of re-walking the plan per row; the input row is sliced only after the
/// empty-result check.
#[expect(clippy::too_many_arguments)]
async fn run_apply(
    input_exec: Arc<dyn ExecutionPlan>,
    subquery_plan: &LogicalPlan,
    input_filter: Option<&Expr>,
    graph_ctx: &Arc<GraphExecutionContext>,
    session_ctx: &Arc<RwLock<SessionContext>>,
    storage: &Arc<StorageManager>,
    schema_info: &Arc<UniSchema>,
    params: &HashMap<String, Value>,
    output_schema: &SchemaRef,
) -> DFResult<RecordBatch> {
    let apply_start = std::time::Instant::now();
    let is_proc_call = is_procedure_call(subquery_plan);
    tracing::debug!("run_apply: is_procedure_call={}", is_proc_call);
    let task_ctx = session_ctx.read().task_ctx();
    let input_batches = collect_all_partitions(&input_exec, task_ctx).await?;
    // Pair each surviving row with its pre-extracted parameter map.
    let mut filtered_entries: Vec<(&RecordBatch, usize, HashMap<String, Value>)> = Vec::new();
    for batch in &input_batches {
        for row_idx in 0..batch.num_rows() {
            let row_params = extract_row_params(batch, row_idx);
            if let Some(filter) = input_filter
                && !evaluate_filter(filter, &row_params)
            {
                continue;
            }
            filtered_entries.push((batch, row_idx, row_params));
        }
    }
    tracing::debug!(
        "run_apply: filtered_entries count = {}",
        filtered_entries.len()
    );
    if filtered_entries.is_empty() {
        // No correlated rows: still execute the subquery once, uncorrelated.
        let sub_batches = execute_subplan(
            subquery_plan,
            params,
            &HashMap::new(),
            graph_ctx,
            session_ctx,
            storage,
            schema_info,
        )
        .await?;
        let sub_rows = batches_to_row_maps(&sub_batches);
        return rows_to_batch(&sub_rows, output_schema);
    }
    let has_filter = input_filter.is_some();
    if is_batch_eligible(&filtered_entries) && !is_proc_call && has_filter {
        tracing::debug!("run_apply: batching eligible, attempting batch execution");
        // Gather every `._vid` value per column so the subquery can be run
        // once with list-valued parameters covering all input rows.
        let mut vid_values: HashMap<String, Vec<Value>> = HashMap::new();
        for (_, _, row_params) in &filtered_entries {
            for (key, value) in row_params {
                if key.ends_with("._vid") {
                    vid_values
                        .entry(key.clone())
                        .or_default()
                        .push(value.clone());
                }
            }
        }
        let mut batched_params = params.clone();
        for (key, values) in &vid_values {
            batched_params.insert(key.clone(), Value::List(values.clone()));
        }
        // Non-VID row params are taken from the first row only; batching
        // assumes they do not vary across rows.
        if let Some((_, _, first_row_params)) = filtered_entries.first() {
            for (key, value) in first_row_params {
                if !key.ends_with("._vid") {
                    batched_params
                        .entry(key.clone())
                        .or_insert_with(|| value.clone());
                }
            }
        }
        let subplan_start = std::time::Instant::now();
        let sub_batches = execute_subplan(
            subquery_plan,
            &batched_params,
            &HashMap::new(),
            graph_ctx,
            session_ctx,
            storage,
            schema_info,
        )
        .await?;
        let subplan_elapsed = subplan_start.elapsed();
        tracing::debug!(
            "run_apply: batch execute_subplan took {:?}",
            subplan_elapsed
        );
        let sub_rows = batches_to_row_maps(&sub_batches);
        // Index subquery rows by VID so each input row finds its matches.
        let mut sub_index: HashMap<i64, Vec<&HashMap<String, Value>>> = HashMap::new();
        let vid_key = vid_values.keys().next().expect("at least one VID key");
        for sub_row in &sub_rows {
            if let Some(Value::Int(vid)) = sub_row.get(vid_key) {
                sub_index.entry(*vid).or_default().push(sub_row);
            }
        }
        let input_schema = input_batches[0].schema();
        let num_input_cols = input_schema.fields().len();
        let num_output_cols = output_schema.fields().len();
        let mut column_arrays: Vec<Vec<ArrayRef>> = vec![Vec::new(); num_output_cols];
        for (batch, row_idx, row_params) in &filtered_entries {
            let input_vid = if let Some(Value::Int(vid)) = row_params.get(vid_key) {
                *vid
            } else {
                continue;
            };
            if let Some(matching_sub_rows) = sub_index.get(&input_vid) {
                let input_row_arrays = slice_row(batch, *row_idx);
                for sub_row in matching_sub_rows {
                    append_cross_join_row(
                        &mut column_arrays,
                        &input_row_arrays,
                        sub_row,
                        output_schema,
                        num_input_cols,
                    )?;
                }
            }
        }
        let result = concat_column_arrays(&column_arrays, output_schema);
        let apply_elapsed = apply_start.elapsed();
        tracing::debug!(
            "run_apply: completed (batched) in {:?}, 1 subplan execution",
            apply_elapsed
        );
        return result;
    }
    // Row-at-a-time path: one subquery execution per distinct row-param set.
    let input_schema = input_batches[0].schema();
    let num_input_cols = input_schema.fields().len();
    let num_output_cols = output_schema.fields().len();
    let mut column_arrays: Vec<Vec<ArrayRef>> = vec![Vec::new(); num_output_cols];
    let mut total_subplan_time = std::time::Duration::ZERO;
    let mut subplan_executions = 0;
    let mut subplan_cache: HashMap<u64, Vec<HashMap<String, Value>>> = HashMap::new();
    let mut cache_hits = 0;
    for (batch, row_idx, row_params) in &filtered_entries {
        // Procedure calls receive row values as "outer" bindings; everything
        // else gets them merged directly into the parameter map.
        let (sub_params, sub_outer_values) = if is_proc_call {
            (params.clone(), row_params.clone())
        } else {
            let mut merged = params.clone();
            merged.extend(row_params.clone());
            (merged, HashMap::new())
        };
        let params_hash = hash_row_params(row_params);
        let sub_rows = if let Some(cached_rows) = subplan_cache.get(&params_hash) {
            cache_hits += 1;
            tracing::debug!(
                "run_apply: cache hit for params hash {}, skipping execute_subplan",
                params_hash
            );
            cached_rows.clone()
        } else {
            let subplan_start = std::time::Instant::now();
            let sub_batches = execute_subplan(
                subquery_plan,
                &sub_params,
                &sub_outer_values,
                graph_ctx,
                session_ctx,
                storage,
                schema_info,
            )
            .await?;
            let subplan_elapsed = subplan_start.elapsed();
            total_subplan_time += subplan_elapsed;
            subplan_executions += 1;
            tracing::debug!(
                "run_apply: execute_subplan #{} took {:?}",
                subplan_executions,
                subplan_elapsed
            );
            let rows = batches_to_row_maps(&sub_batches);
            subplan_cache.insert(params_hash, rows.clone());
            rows
        };
        if sub_rows.is_empty() {
            continue;
        }
        let input_row_arrays = slice_row(batch, *row_idx);
        for sub_row in &sub_rows {
            append_cross_join_row(
                &mut column_arrays,
                &input_row_arrays,
                sub_row,
                output_schema,
                num_input_cols,
            )?;
        }
    }
    let result = concat_column_arrays(&column_arrays, output_schema);
    let apply_elapsed = apply_start.elapsed();
    tracing::debug!(
        "run_apply: completed in {:?}, {} subplan executions, {} cache hits, {:?} total subplan time",
        apply_elapsed,
        subplan_executions,
        cache_hits,
        total_subplan_time
    );
    result
}
/// Builds a one-element array from an optional native value
/// (`None` produces a single null slot).
fn single_row_array<B, T>(mut builder: B, val: Option<T>) -> ArrayRef
where
    B: PrimitiveAppend<T>,
{
    if let Some(v) = val {
        builder.append_typed_value(v);
    } else {
        builder.append_typed_null();
    }
    builder.finish_to_array()
}
/// Converts a single `Value` into a one-row Arrow array of `data_type`.
/// Mirrors the per-type conversions in `rows_to_batch`; unrecognized types
/// fall back to a UTF-8 array via `Display` formatting.
fn value_to_single_row_array(val: &Value, data_type: &DataType) -> DFResult<ArrayRef> {
    Ok(match data_type {
        // Signed ints are accepted into UInt64 via an `as` cast.
        DataType::UInt64 => single_row_array(
            UInt64Builder::with_capacity(1),
            val.as_u64().or_else(|| val.as_i64().map(|v| v as u64)),
        ),
        DataType::Int64 => single_row_array(Int64Builder::with_capacity(1), val.as_i64()),
        DataType::Int32 => single_row_array(
            Int32Builder::with_capacity(1),
            val.as_i64().map(|v| v as i32),
        ),
        DataType::Float64 => single_row_array(Float64Builder::with_capacity(1), val.as_f64()),
        DataType::Boolean => single_row_array(BooleanBuilder::with_capacity(1), val.as_bool()),
        DataType::Null => Arc::new(arrow_array::NullArray::new(1)) as ArrayRef,
        // Fallback: stringify any other value into a Utf8 array.
        _ => {
            let mut b = StringBuilder::with_capacity(1, 64);
            match val {
                Value::Null => b.append_null(),
                Value::String(s) => b.append_value(s),
                other => b.append_value(format!("{other}")),
            }
            Arc::new(b.finish()) as ArrayRef
        }
    })
}
/// Appends one joined output row: the pre-sliced input-row arrays fill the
/// first `num_input_cols` columns, and the subquery row fills the remaining
/// output columns by field name (missing values become nulls).
fn append_cross_join_row(
    column_arrays: &mut [Vec<ArrayRef>],
    input_row_arrays: &[ArrayRef],
    sub_row: &HashMap<String, Value>,
    output_schema: &SchemaRef,
    num_input_cols: usize,
) -> DFResult<()> {
    // Input columns are copied through as single-row array slices.
    for (col_idx, arr) in input_row_arrays.iter().enumerate() {
        column_arrays[col_idx].push(arr.clone());
    }
    let num_output_cols = output_schema.fields().len();
    // Subquery columns occupy the tail of the output schema.
    for (col_arr, field) in column_arrays[num_input_cols..num_output_cols]
        .iter_mut()
        .zip(output_schema.fields()[num_input_cols..num_output_cols].iter())
    {
        let col_name = field.name();
        let val = sub_row.get(col_name).cloned().unwrap_or(Value::Null);
        let arr = value_to_single_row_array(&val, field.data_type())?;
        col_arr.push(arr);
    }
    Ok(())
}
/// Concatenates the accumulated per-column single-row arrays into the final
/// `RecordBatch`.
fn concat_column_arrays(
    column_arrays: &[Vec<ArrayRef>],
    output_schema: &SchemaRef,
) -> DFResult<RecordBatch> {
    // No columns at all (field-less schema) or no rows accumulated: emit an
    // empty batch. The previous `column_arrays[0]` indexing panicked when
    // the outer slice itself was empty.
    if column_arrays.first().map_or(true, |arrays| arrays.is_empty()) {
        return Ok(RecordBatch::new_empty(output_schema.clone()));
    }
    let mut final_columns: Vec<ArrayRef> = Vec::with_capacity(column_arrays.len());
    for arrays in column_arrays {
        let refs: Vec<&dyn arrow_array::Array> = arrays.iter().map(|a| a.as_ref()).collect();
        let concatenated = arrow::compute::concat(&refs).map_err(arrow_err)?;
        final_columns.push(concatenated);
    }
    RecordBatch::try_new(output_schema.clone(), final_columns).map_err(arrow_err)
}
/// State machine for `ApplyStream`: the entire apply runs as one future
/// that resolves to a single `RecordBatch`.
enum ApplyStreamState {
    // The `run_apply` future is still executing.
    Running(Pin<Box<dyn std::future::Future<Output = DFResult<RecordBatch>> + Send>>),
    // The single batch (or error) has been emitted; the stream is exhausted.
    Done,
}
/// Single-batch `RecordBatchStream` adapter over the `run_apply` future.
struct ApplyStream {
    state: ApplyStreamState,
    // Output schema reported to DataFusion via `RecordBatchStream::schema`.
    schema: SchemaRef,
    // Baseline metrics; output row count is recorded when the batch is ready.
    metrics: BaselineMetrics,
}
impl Stream for ApplyStream {
    type Item = DFResult<RecordBatch>;
    /// Polls the inner future; on completion the stream yields exactly one
    /// item (batch or error), transitions to `Done`, and then ends.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        match &mut self.state {
            ApplyStreamState::Running(fut) => match fut.as_mut().poll(cx) {
                Poll::Ready(Ok(batch)) => {
                    // Record output rows before handing the batch downstream.
                    self.metrics.record_output(batch.num_rows());
                    self.state = ApplyStreamState::Done;
                    Poll::Ready(Some(Ok(batch)))
                }
                Poll::Ready(Err(e)) => {
                    self.state = ApplyStreamState::Done;
                    Poll::Ready(Some(Err(e)))
                }
                Poll::Pending => Poll::Pending,
            },
            ApplyStreamState::Done => Poll::Ready(None),
        }
    }
}
impl RecordBatchStream for ApplyStream {
fn schema(&self) -> SchemaRef {
self.schema.clone()
}
}