use crate::query::df_graph::GraphExecutionContext;
use crate::query::df_graph::common::{arrow_err, compute_plan_properties, execute_subplan};
use crate::query::df_graph::unwind::arrow_to_json_value;
use crate::query::planner::LogicalPlan;
use arrow_array::RecordBatch;
use arrow_array::builder::{Int64Builder, LargeListBuilder};
use arrow_schema::{DataType, Field, Schema, SchemaRef};
use datafusion::common::Result as DFResult;
use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
use datafusion::prelude::SessionContext;
use futures::Stream;
use parking_lot::RwLock;
use std::any::Any;
use std::collections::{HashMap, HashSet};
use std::fmt;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use uni_common::Value;
use uni_common::core::schema::Schema as UniSchema;
use uni_store::storage::manager::StorageManager;
/// Upper bound on how many times `run_cte_loop` evaluates the recursive plan
/// before it stops iterating and returns whatever has been accumulated.
const MAX_ITERATIONS: usize = 1000;
/// Physical operator that evaluates a recursive CTE by repeatedly executing
/// a recursive logical plan until no new values are produced.
///
/// The operator has no child execution plans; both subplans are held as
/// `LogicalPlan`s and run through `execute_subplan` when the operator is
/// executed.
pub struct RecursiveCTEExec {
/// CTE name. Used as the output column name and as the parameter key under
/// which the current working set of VIDs is passed to the recursive plan.
cte_name: String,
/// Plan executed once to produce the starting working set.
initial_plan: LogicalPlan,
/// Plan executed on every iteration with the working set bound as a parameter.
recursive_plan: LogicalPlan,
/// Graph execution context forwarded to `execute_subplan`.
graph_ctx: Arc<GraphExecutionContext>,
/// Shared session context forwarded to `execute_subplan`.
session_ctx: Arc<RwLock<SessionContext>>,
/// Storage manager forwarded to `execute_subplan`.
storage: Arc<StorageManager>,
/// Schema metadata forwarded to `execute_subplan`.
schema_info: Arc<UniSchema>,
/// Query parameters passed to both subplans.
params: HashMap<String, Value>,
/// Output schema: one non-nullable `LargeList<Int64>` column named after the CTE.
output_schema: SchemaRef,
/// Plan properties computed from `output_schema` by `compute_plan_properties`.
properties: PlanProperties,
/// Metrics set from which per-partition baseline metrics are created.
metrics: ExecutionPlanMetricsSet,
}
/// Debug output shows only the CTE name; the remaining fields are omitted.
impl fmt::Debug for RecursiveCTEExec {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { cte_name, .. } = self;
        let mut out = f.debug_struct("RecursiveCTEExec");
        out.field("cte_name", cte_name);
        out.finish()
    }
}
impl RecursiveCTEExec {
    /// Creates the operator for the CTE named `cte_name`.
    ///
    /// The output schema is fixed here: a single non-nullable column named
    /// after the CTE, of type `LargeList<Int64>` with a nullable `"item"`
    /// element field. Plan properties are derived from that schema and a
    /// fresh metrics set is attached.
    #[expect(clippy::too_many_arguments)]
    pub fn new(
        cte_name: String,
        initial_plan: LogicalPlan,
        recursive_plan: LogicalPlan,
        graph_ctx: Arc<GraphExecutionContext>,
        session_ctx: Arc<RwLock<SessionContext>>,
        storage: Arc<StorageManager>,
        schema_info: Arc<UniSchema>,
        params: HashMap<String, Value>,
    ) -> Self {
        let element = Field::new("item", DataType::Int64, true);
        let list_column = Field::new(&cte_name, DataType::LargeList(Arc::new(element)), false);
        let output_schema = Arc::new(Schema::new(vec![list_column]));
        let properties = compute_plan_properties(Arc::clone(&output_schema));

        Self {
            cte_name,
            initial_plan,
            recursive_plan,
            graph_ctx,
            session_ctx,
            storage,
            schema_info,
            params,
            output_schema,
            properties,
            metrics: ExecutionPlanMetricsSet::new(),
        }
    }
}
/// Plan display: the operator name followed by the CTE name, regardless of
/// the requested display format.
impl DisplayAs for RecursiveCTEExec {
    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { cte_name, .. } = self;
        write!(f, "RecursiveCTEExec: {cte_name}")
    }
}
impl ExecutionPlan for RecursiveCTEExec {
    /// Operator name reported to DataFusion.
    fn name(&self) -> &str {
        "RecursiveCTEExec"
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Returns the single-column output schema built in [`RecursiveCTEExec::new`].
    fn schema(&self) -> SchemaRef {
        self.output_schema.clone()
    }

    fn properties(&self) -> &PlanProperties {
        &self.properties
    }

    /// This operator is a leaf: its subplans are logical plans, not child
    /// execution plans, so no children are reported.
    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![]
    }

    /// Accepts only an empty child list and returns `self` unchanged.
    ///
    /// # Errors
    /// Returns `DataFusionError::Plan` when any child is supplied.
    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> DFResult<Arc<dyn ExecutionPlan>> {
        if !children.is_empty() {
            return Err(datafusion::error::DataFusionError::Plan(
                "RecursiveCTEExec has no children".to_string(),
            ));
        }
        Ok(self)
    }

    /// Builds a stream that runs the CTE loop and yields its single result batch.
    ///
    /// Nothing is executed here: the loop is wrapped in a boxed future that
    /// only makes progress when the returned stream is polled. `partition` is
    /// used solely to label the baseline metrics.
    fn execute(
        &self,
        partition: usize,
        _context: Arc<TaskContext>,
    ) -> DFResult<SendableRecordBatchStream> {
        let metrics = BaselineMetrics::new(&self.metrics, partition);

        // The future must own everything it touches, so clone the operator
        // state into it (the `Arc`s are cheap; plans and params are copied).
        let cte_name = self.cte_name.clone();
        let initial_plan = self.initial_plan.clone();
        let recursive_plan = self.recursive_plan.clone();
        let graph_ctx = self.graph_ctx.clone();
        let session_ctx = self.session_ctx.clone();
        let storage = self.storage.clone();
        let schema_info = self.schema_info.clone();
        let params = self.params.clone();
        let output_schema = self.output_schema.clone();

        let fut = async move {
            run_cte_loop(
                &cte_name,
                &initial_plan,
                &recursive_plan,
                &graph_ctx,
                &session_ctx,
                &storage,
                &schema_info,
                &params,
                &output_schema,
            )
            .await
        };

        Ok(Box::pin(RecursiveCTEStream {
            state: RecursiveCTEStreamState::Running(Box::pin(fut)),
            schema: self.output_schema.clone(),
            metrics,
        }))
    }

    /// Returns a snapshot of the metrics recorded so far.
    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }
}
/// Flattens record batches into one `Value` per row, in batch and row order.
///
/// Rows of a single-column batch become the bare cell value. Rows of any
/// other batch become a `Value::Map` keyed by column name.
fn batches_to_values(batches: &[RecordBatch]) -> Vec<Value> {
    let mut rows = Vec::new();
    for batch in batches {
        let schema = batch.schema();
        let width = batch.num_columns();
        rows.extend((0..batch.num_rows()).map(|row| {
            if width == 1 {
                return arrow_to_json_value(batch.column(0).as_ref(), row);
            }
            let entries = (0..width).map(|col| {
                let name = schema.field(col).name().clone();
                (name, arrow_to_json_value(batch.column(col).as_ref(), row))
            });
            Value::Map(entries.collect())
        }));
    }
    rows
}
/// Builds the deduplication key for a value: its `Debug` rendering.
///
/// NOTE(review): this assumes the `Debug` output of equal values is identical.
/// If `Value::Map` is backed by an unordered map, equal maps could render
/// differently and escape deduplication — confirm against `Value`'s definition.
fn value_key(val: &Value) -> String {
format!("{val:?}")
}
/// Extracts a vertex id from a row value, if it has one.
///
/// * `Value::Int` is reinterpreted as `u64` with an `as` cast (negative
///   values wrap rather than being rejected).
/// * `Value::Map` uses the first entry, in the map's iteration order, whose
///   key is `"_vid"` or ends with `"._vid"`; the result is that entry's
///   `as_u64()`, even when it is `None`.
/// * Any other variant defers to `Value::as_u64`.
fn extract_vid(val: &Value) -> Option<u64> {
    match val {
        Value::Int(raw) => Some(*raw as u64),
        Value::Map(entries) => entries
            .iter()
            .find(|(key, _)| key.as_str() == "_vid" || key.ends_with("._vid"))
            .and_then(|(_, vid)| vid.as_u64()),
        other => other.as_u64(),
    }
}
/// Evaluates a recursive CTE to a fixpoint and packs the result into one batch.
///
/// Steps:
/// 1. Run `initial_plan`; its rows form both the first working set and the
///    start of the accumulated result.
/// 2. Up to `MAX_ITERATIONS` times: bind the working set's VIDs as a list
///    parameter named `cte_name`, run `recursive_plan`, and keep only rows
///    whose [`value_key`] has not been seen before. Those rows are appended
///    to the result and become the next working set.
/// 3. Stop early when the working set, the recursive output, or the set of
///    new rows is empty.
///
/// Returns a single-row batch with `output_schema`, whose list cell holds the
/// VID of every accumulated row. Rows with no extractable VID are skipped.
///
/// # Errors
/// Propagates errors from `execute_subplan` and from `RecordBatch::try_new`.
#[expect(clippy::too_many_arguments)]
async fn run_cte_loop(
    cte_name: &str,
    initial_plan: &LogicalPlan,
    recursive_plan: &LogicalPlan,
    graph_ctx: &Arc<GraphExecutionContext>,
    session_ctx: &Arc<RwLock<SessionContext>>,
    storage: &Arc<StorageManager>,
    schema_info: &Arc<UniSchema>,
    params: &HashMap<String, Value>,
    output_schema: &SchemaRef,
) -> DFResult<RecordBatch> {
    let anchor_batches = execute_subplan(
        initial_plan,
        params,
        &HashMap::new(),
        graph_ctx,
        session_ctx,
        storage,
        schema_info,
    )
    .await?;

    let mut working_values = batches_to_values(&anchor_batches);
    let mut result_values = working_values.clone();
    let mut seen: HashSet<String> = working_values.iter().map(value_key).collect();

    // Clone the caller's parameters once. Each iteration overwrites only the
    // `cte_name` entry, so the map never needs to be rebuilt from scratch.
    let mut next_params = params.clone();

    // NOTE(review): reaching MAX_ITERATIONS ends the loop silently and returns
    // the partial result; confirm this is the intended behaviour.
    for _ in 0..MAX_ITERATIONS {
        if working_values.is_empty() {
            break;
        }

        // Only VIDs are handed to the recursive step; values without one are
        // dropped from the parameter list.
        let vid_list = Value::List(
            working_values
                .iter()
                .filter_map(|v| extract_vid(v).map(|vid| Value::Int(vid as i64)))
                .collect(),
        );
        next_params.insert(cte_name.to_string(), vid_list);

        let recursive_batches = execute_subplan(
            recursive_plan,
            &next_params,
            &HashMap::new(),
            graph_ctx,
            session_ctx,
            storage,
            schema_info,
        )
        .await?;

        let next_values = batches_to_values(&recursive_batches);
        if next_values.is_empty() {
            break;
        }

        // `HashSet::insert` returns true only for keys not seen before, so
        // this both filters duplicates and records the new keys.
        let new_values: Vec<Value> = next_values
            .into_iter()
            .filter(|val| seen.insert(value_key(val)))
            .collect();
        if new_values.is_empty() {
            break;
        }

        result_values.extend(new_values.iter().cloned());
        working_values = new_values;
    }

    // Emit exactly one list row containing every VID found.
    let mut list_builder = LargeListBuilder::new(Int64Builder::new());
    for val in &result_values {
        if let Some(vid) = extract_vid(val) {
            list_builder.values().append_value(vid as i64);
        }
    }
    list_builder.append(true);

    let array = Arc::new(list_builder.finish());
    RecordBatch::try_new(output_schema.clone(), vec![array]).map_err(arrow_err)
}
/// Progress of a `RecursiveCTEStream`.
enum RecursiveCTEStreamState {
/// The boxed future computing the single result batch has not finished yet.
Running(Pin<Box<dyn std::future::Future<Output = DFResult<RecordBatch>> + Send>>),
/// The future's result (batch or error) has been yielded; the stream is exhausted.
Done,
}
/// Record-batch stream that yields at most one item: the outcome of the
/// future held in `state`.
struct RecursiveCTEStream {
/// Whether the result future is still pending or has already been yielded.
state: RecursiveCTEStreamState,
/// Schema reported through `RecordBatchStream::schema`.
schema: SchemaRef,
/// Baseline metrics; the output row count is recorded when the batch is produced.
metrics: BaselineMetrics,
}
impl Stream for RecursiveCTEStream {
    type Item = DFResult<RecordBatch>;

    /// Polls the result future. Once it resolves, yields its outcome (batch or
    /// error) exactly once; every later poll returns `None`.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();

        let RecursiveCTEStreamState::Running(fut) = &mut this.state else {
            return Poll::Ready(None);
        };

        let outcome = match fut.as_mut().poll(cx) {
            Poll::Pending => return Poll::Pending,
            Poll::Ready(outcome) => outcome,
        };

        // Success or failure, the future is finished and must not be polled again.
        this.state = RecursiveCTEStreamState::Done;
        if let Ok(batch) = &outcome {
            this.metrics.record_output(batch.num_rows());
        }
        Poll::Ready(Some(outcome))
    }
}
impl RecordBatchStream for RecursiveCTEStream {
fn schema(&self) -> SchemaRef {
self.schema.clone()
}
}