use crate::query::df_graph::GraphExecutionContext;
use crate::query::df_graph::common::{
arrow_err, column_as_vid_array, compute_plan_properties, edge_struct_fields,
new_node_list_builder,
};
use arrow::compute::take;
use arrow_array::builder::{ListBuilder, StructBuilder, UInt64Builder};
use arrow_array::{Array, ArrayRef, RecordBatch, UInt32Array, UInt64Array};
use arrow_schema::{DataType, Field, Schema, SchemaRef};
use datafusion::common::Result as DFResult;
use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
use futures::{Stream, StreamExt};
use fxhash::FxHashMap;
use std::any::Any;
use std::collections::{HashSet, VecDeque};
use std::fmt;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use uni_common::core::id::Vid;
use uni_store::runtime::l0_visibility;
use uni_store::storage::direction::Direction;
/// Physical operator that augments each input row with a shortest path
/// between the vertices named by `source_column` and `target_column`.
///
/// The output schema is the input schema followed by three columns: a path
/// struct named after `path_variable`, `<path_variable>._path` (list of
/// vertex ids) and `<path_variable>._length` (see `build_schema`).
pub struct GraphShortestPathExec {
/// Child plan whose batches supply the source/target vertex ids.
input: Arc<dyn ExecutionPlan>,
/// Name of the input column read as the path's start vertex id.
source_column: String,
/// Name of the input column read as the path's end vertex id.
target_column: String,
/// Edge type ids passed, one by one, to the neighbor lookup during search.
edge_type_ids: Vec<u32>,
/// Traversal direction handed to the neighbor lookup.
direction: Direction,
/// Variable name used to derive the names of the appended output columns.
path_variable: String,
/// When `true`, every shortest path is emitted (one output row per path);
/// otherwise a single shortest path is emitted per input row.
all_shortest: bool,
/// Graph execution context used for neighbor lookups, warming, timeout
/// checks and the query context.
graph_ctx: Arc<GraphExecutionContext>,
/// Output schema, computed once in `new`.
schema: SchemaRef,
/// Plan properties derived from `schema` via `compute_plan_properties`.
properties: PlanProperties,
/// Metrics set from which per-partition baseline metrics are created.
metrics: ExecutionPlanMetricsSet,
}
/// Debug output lists only the operator's configuration; the child plan,
/// graph context, schema, plan properties and metrics are left out.
impl fmt::Debug for GraphShortestPathExec {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("GraphShortestPathExec");
        out.field("source_column", &self.source_column);
        out.field("target_column", &self.target_column);
        out.field("edge_type_ids", &self.edge_type_ids);
        out.field("direction", &self.direction);
        out.field("path_variable", &self.path_variable);
        out.field("all_shortest", &self.all_shortest);
        out.finish()
    }
}
impl GraphShortestPathExec {
    /// Creates the operator on top of `input`.
    ///
    /// The output schema and plan properties are computed here, once, from
    /// the input schema and `path_variable`.
    #[expect(
        clippy::too_many_arguments,
        reason = "Shortest path requires many parameters"
    )]
    pub fn new(
        input: Arc<dyn ExecutionPlan>,
        source_column: impl Into<String>,
        target_column: impl Into<String>,
        edge_type_ids: Vec<u32>,
        direction: Direction,
        path_variable: impl Into<String>,
        graph_ctx: Arc<GraphExecutionContext>,
        all_shortest: bool,
    ) -> Self {
        let path_variable: String = path_variable.into();
        let schema = Self::build_schema(input.schema(), &path_variable);
        let properties = compute_plan_properties(Arc::clone(&schema));
        Self {
            input,
            source_column: source_column.into(),
            target_column: target_column.into(),
            edge_type_ids,
            direction,
            path_variable,
            all_shortest,
            graph_ctx,
            schema,
            properties,
            metrics: ExecutionPlanMetricsSet::new(),
        }
    }

    /// Builds the output schema: all input fields in their original order,
    /// then the path struct field, then `<path_variable>._path` (nullable
    /// list of nullable `UInt64`) and `<path_variable>._length` (nullable
    /// `UInt64`).
    fn build_schema(input_schema: SchemaRef, path_variable: &str) -> SchemaRef {
        let inherited = input_schema.fields();
        let mut fields: Vec<Field> = Vec::with_capacity(inherited.len() + 3);
        fields.extend(inherited.iter().map(|field| field.as_ref().clone()));

        fields.push(crate::query::df_graph::common::build_path_struct_field(
            path_variable,
        ));

        let vid_item = Arc::new(Field::new("item", DataType::UInt64, true));
        fields.push(Field::new(
            format!("{}._path", path_variable),
            DataType::List(vid_item),
            true,
        ));

        fields.push(Field::new(
            format!("{}._length", path_variable),
            DataType::UInt64,
            true,
        ));

        Arc::new(Schema::new(fields))
    }
}
/// One-line plan description; the same text is produced for every
/// `DisplayFormatType`.
impl DisplayAs for GraphShortestPathExec {
    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            source_column,
            target_column,
            edge_type_ids,
            all_shortest,
            ..
        } = self;
        write!(
            f,
            "GraphShortestPathExec: {} -> {} via {:?} ({})",
            source_column,
            target_column,
            edge_type_ids,
            if *all_shortest { "all" } else { "any" }
        )
    }
}
impl ExecutionPlan for GraphShortestPathExec {
    /// Static operator name.
    fn name(&self) -> &str {
        "GraphShortestPathExec"
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Output schema computed at construction time.
    fn schema(&self) -> SchemaRef {
        self.schema.clone()
    }

    fn properties(&self) -> &PlanProperties {
        &self.properties
    }

    /// The operator has exactly one child: the input plan.
    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.input]
    }

    /// Rebuilds the operator over a replacement child, keeping every other
    /// setting. Fails with a plan error unless exactly one child is given.
    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> DFResult<Arc<dyn ExecutionPlan>> {
        let [only_child] = children.as_slice() else {
            return Err(datafusion::error::DataFusionError::Plan(
                "GraphShortestPathExec requires exactly one child".to_string(),
            ));
        };
        let rebuilt = Self::new(
            Arc::clone(only_child),
            self.source_column.clone(),
            self.target_column.clone(),
            self.edge_type_ids.clone(),
            self.direction,
            self.path_variable.clone(),
            Arc::clone(&self.graph_ctx),
            self.all_shortest,
        );
        Ok(Arc::new(rebuilt))
    }

    /// Starts the child partition and wraps it in a stream that first awaits
    /// the graph context's warming future and then appends path columns to
    /// every batch it reads.
    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> DFResult<SendableRecordBatchStream> {
        let input = self.input.execute(partition, context)?;
        let metrics = BaselineMetrics::new(&self.metrics, partition);
        let warming = self
            .graph_ctx
            .warming_future(self.edge_type_ids.clone(), self.direction);

        let stream = GraphShortestPathStream {
            input,
            source_column: self.source_column.clone(),
            target_column: self.target_column.clone(),
            edge_type_ids: self.edge_type_ids.clone(),
            direction: self.direction,
            all_shortest: self.all_shortest,
            graph_ctx: Arc::clone(&self.graph_ctx),
            schema: Arc::clone(&self.schema),
            state: ShortestPathStreamState::Warming(warming),
            metrics,
        };
        Ok(Box::pin(stream))
    }

    /// Snapshot of the metrics recorded so far.
    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }
}
/// Phases of `GraphShortestPathStream`, advanced by `poll_next`.
enum ShortestPathStreamState {
/// Waiting for the warming future created in `execute` to complete; no
/// input is read while in this state.
Warming(Pin<Box<dyn std::future::Future<Output = DFResult<()>> + Send>>),
/// Pulling batches from the input stream and processing them.
Reading,
/// Terminal state: the stream yields `None` from now on.
Done,
}
/// Stream produced by `GraphShortestPathExec::execute`; holds copies of the
/// operator's settings plus the per-partition input stream and metrics.
struct GraphShortestPathStream {
/// Child stream the batches are read from.
input: SendableRecordBatchStream,
/// Name of the batch column holding start vertex ids.
source_column: String,
/// Name of the batch column holding end vertex ids.
target_column: String,
/// Edge type ids tried for every neighbor lookup.
edge_type_ids: Vec<u32>,
/// Traversal direction passed to every neighbor lookup.
direction: Direction,
/// `true` to emit all shortest paths per row, `false` for a single one.
all_shortest: bool,
/// Graph context used for neighbor lookups, timeout checks and the query context.
graph_ctx: Arc<GraphExecutionContext>,
/// Schema of the batches this stream emits.
schema: SchemaRef,
/// Current phase of the stream (see `ShortestPathStreamState`).
state: ShortestPathStreamState,
/// Per-partition metrics; output row counts are recorded here.
metrics: BaselineMetrics,
}
impl GraphShortestPathStream {
    /// Breadth-first search for one shortest path from `source` to `target`.
    ///
    /// Every edge type in `edge_type_ids` is followed along `direction`.
    /// Returns the vertex sequence including both endpoints (`[source]` when
    /// the endpoints are equal), or `None` when the search runs out of
    /// vertices without reaching `target`.
    ///
    /// The queue stores only vertices; the path is rebuilt from a
    /// predecessor map once `target` is seen, instead of cloning a full path
    /// for every enqueued vertex. The returned path is the same one the
    /// per-entry-path formulation finds, because a vertex's predecessor is
    /// the vertex that first discovered it.
    fn compute_shortest_path(&self, source: Vid, target: Vid) -> Option<Vec<Vid>> {
        if source == target {
            return Some(vec![source]);
        }

        let mut visited: HashSet<Vid> = HashSet::new();
        let mut parent: FxHashMap<Vid, Vid> = FxHashMap::default();
        let mut queue: VecDeque<Vid> = VecDeque::new();
        visited.insert(source);
        queue.push_back(source);

        while let Some(current) = queue.pop_front() {
            for &edge_type in &self.edge_type_ids {
                let neighbors = self
                    .graph_ctx
                    .get_neighbors(current, edge_type, self.direction);
                for (neighbor, _eid) in neighbors {
                    if neighbor == target {
                        // Walk the predecessor chain back to `source`.
                        let mut path = vec![target];
                        let mut node = current;
                        while node != source {
                            path.push(node);
                            // Invariant: every dequeued vertex other than
                            // `source` got a parent when it was enqueued.
                            node = parent[&node];
                        }
                        path.push(source);
                        path.reverse();
                        return Some(path);
                    }
                    // `insert` returns true only on first discovery.
                    if visited.insert(neighbor) {
                        parent.insert(neighbor, current);
                        queue.push_back(neighbor);
                    }
                }
            }
        }
        None
    }

    /// Layered breadth-first search collecting every shortest path from
    /// `source` to `target`.
    ///
    /// Returns `[[source]]` when the endpoints are equal and an empty vector
    /// when `target` is unreachable. Each path lists vertices from `source`
    /// to `target`.
    ///
    /// The layer in which `target` is first reached is processed to the end
    /// so that all of its same-depth predecessors are recorded.
    ///
    /// NOTE(review): a predecessor is recorded once per connecting edge, so
    /// parallel edges (or several matching edge types) between two vertices
    /// yield repeated vertex sequences — confirm this multiplicity is wanted.
    fn compute_all_shortest_paths(&self, source: Vid, target: Vid) -> Vec<Vec<Vid>> {
        if source == target {
            return vec![vec![source]];
        }

        let mut depth: FxHashMap<Vid, u32> = FxHashMap::default();
        let mut predecessors: FxHashMap<Vid, Vec<Vid>> = FxHashMap::default();
        depth.insert(source, 0);

        let mut current_layer: Vec<Vid> = vec![source];
        let mut current_depth = 0u32;
        let mut target_found = false;

        while !current_layer.is_empty() && !target_found {
            current_depth += 1;
            // A vertex is pushed only on first discovery (guarded by `depth`),
            // so a plain Vec holds no duplicates and keeps discovery order,
            // which makes the order of the returned paths reproducible.
            let mut next_layer: Vec<Vid> = Vec::new();
            for &current in &current_layer {
                for &edge_type in &self.edge_type_ids {
                    let neighbors =
                        self.graph_ctx
                            .get_neighbors(current, edge_type, self.direction);
                    for (neighbor, _eid) in neighbors {
                        let known_depth = depth.get(&neighbor).copied();
                        match known_depth {
                            Some(seen) => {
                                // Already discovered: only another parent from
                                // the same layer is an alternative shortest route.
                                if seen == current_depth {
                                    predecessors.entry(neighbor).or_default().push(current);
                                }
                            }
                            None => {
                                depth.insert(neighbor, current_depth);
                                predecessors.entry(neighbor).or_default().push(current);
                                if neighbor == target {
                                    target_found = true;
                                } else {
                                    next_layer.push(neighbor);
                                }
                            }
                        }
                    }
                }
            }
            current_layer = next_layer;
        }

        if !target_found {
            return Vec::new();
        }

        // Enumerate paths by walking predecessor links back from `target`.
        let mut result: Vec<Vec<Vid>> = Vec::new();
        let mut stack: Vec<(Vid, Vec<Vid>)> = vec![(target, vec![target])];
        while let Some((node, mut path)) = stack.pop() {
            if node == source {
                path.reverse();
                result.push(path);
                continue;
            }
            if let Some(preds) = predecessors.get(&node) {
                for &pred in preds {
                    let mut extended = path.clone();
                    extended.push(pred);
                    stack.push((pred, extended));
                }
            }
        }
        result
    }

    /// Computes paths for every row of `batch` and returns the batch with the
    /// path columns appended.
    ///
    /// Rows whose source or target id is null, or whose endpoints are not
    /// connected, get null path columns. In all-shortest mode an input row is
    /// repeated once per path found.
    ///
    /// # Errors
    /// Returns an execution error when the source or target column is
    /// missing, and propagates errors from the vid-array conversion, the
    /// `take` kernel and batch construction.
    fn process_batch(&self, batch: RecordBatch) -> DFResult<RecordBatch> {
        let source_col = batch.column_by_name(&self.source_column).ok_or_else(|| {
            datafusion::error::DataFusionError::Execution(format!(
                "Source column '{}' not found",
                self.source_column
            ))
        })?;
        let target_col = batch.column_by_name(&self.target_column).ok_or_else(|| {
            datafusion::error::DataFusionError::Execution(format!(
                "Target column '{}' not found",
                self.target_column
            ))
        })?;
        let source_vid_cow = column_as_vid_array(source_col.as_ref())?;
        let source_vids: &UInt64Array = &source_vid_cow;
        let target_vid_cow = column_as_vid_array(target_col.as_ref())?;
        let target_vids: &UInt64Array = &target_vid_cow;

        // Endpoints of a row, or `None` when either id is null.
        let endpoints = |row: usize| -> Option<(Vid, Vid)> {
            if source_vids.is_null(row) || target_vids.is_null(row) {
                None
            } else {
                Some((
                    Vid::from(source_vids.value(row)),
                    Vid::from(target_vids.value(row)),
                ))
            }
        };

        if !self.all_shortest {
            let paths: Vec<Option<Vec<Vid>>> = (0..batch.num_rows())
                .map(|row| {
                    endpoints(row)
                        .and_then(|(source, target)| self.compute_shortest_path(source, target))
                })
                .collect();
            return self.build_output_batch(&batch, &paths);
        }

        // All-shortest mode: `row_indices[k]` is the input row that output
        // row `k` is copied from.
        let mut row_indices: Vec<u32> = Vec::new();
        let mut all_paths: Vec<Option<Vec<Vid>>> = Vec::new();
        for row in 0..batch.num_rows() {
            let found = endpoints(row)
                .map(|(source, target)| self.compute_all_shortest_paths(source, target))
                .unwrap_or_default();
            // `take` is driven by a UInt32 index array, hence the narrowing
            // cast of the row number.
            let row_index = row as u32;
            if found.is_empty() {
                row_indices.push(row_index);
                all_paths.push(None);
            } else {
                for path in found {
                    row_indices.push(row_index);
                    all_paths.push(Some(path));
                }
            }
        }

        let indices = UInt32Array::from(row_indices);
        let expanded_columns: Vec<ArrayRef> = batch
            .columns()
            .iter()
            .map(|col| {
                take(col.as_ref(), &indices, None).map_err(|e| {
                    datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
                })
            })
            .collect::<DFResult<Vec<_>>>()?;
        let expanded_batch =
            RecordBatch::try_new(batch.schema(), expanded_columns).map_err(arrow_err)?;
        self.build_output_batch(&expanded_batch, &all_paths)
    }

    /// Appends the three path columns to `input`'s columns and assembles the
    /// output batch.
    ///
    /// `paths` supplies one entry per output row: `Some(vertices)` produces a
    /// path struct (node list plus one edge per consecutive vertex pair), the
    /// vertex-id list and the hop count (`vertices.len() - 1`); `None`
    /// produces nulls in all three columns. The number of entries is also
    /// recorded as the output row count in the metrics.
    fn build_output_batch(
        &self,
        input: &RecordBatch,
        paths: &[Option<Vec<Vid>>],
    ) -> DFResult<RecordBatch> {
        let num_rows = paths.len();
        let query_ctx = self.graph_ctx.query_context();
        let mut columns: Vec<ArrayRef> = input.columns().to_vec();

        let mut nodes_builder = new_node_list_builder();
        let mut rels_builder =
            ListBuilder::new(StructBuilder::from_fields(edge_struct_fields(), num_rows));
        let mut vid_list_builder = ListBuilder::new(UInt64Builder::new());
        let mut path_validity = Vec::with_capacity(num_rows);
        let mut lengths: Vec<Option<u64>> = Vec::with_capacity(num_rows);

        for path in paths {
            match path {
                Some(vids) => {
                    for &vid in vids {
                        super::common::append_node_to_struct(
                            nodes_builder.values(),
                            vid,
                            &query_ctx,
                        );
                    }
                    nodes_builder.append(true);

                    for hop in vids.windows(2) {
                        let (src, dst) = (hop[0], hop[1]);
                        let (eid, type_name) = self.find_edge(src, dst);
                        super::common::append_edge_to_struct(
                            rels_builder.values(),
                            eid,
                            &type_name,
                            src.as_u64(),
                            dst.as_u64(),
                            &query_ctx,
                        );
                    }
                    rels_builder.append(true);
                    path_validity.push(true);

                    let raw_ids: Vec<u64> = vids.iter().map(|vid| vid.as_u64()).collect();
                    vid_list_builder.values().append_slice(&raw_ids);
                    vid_list_builder.append(true);

                    // Hop count; saturating so an empty vertex list cannot
                    // underflow.
                    lengths.push(Some(vids.len().saturating_sub(1) as u64));
                }
                None => {
                    nodes_builder.append(false);
                    rels_builder.append(false);
                    path_validity.push(false);
                    vid_list_builder.append(false);
                    lengths.push(None);
                }
            }
        }

        let nodes_array = Arc::new(nodes_builder.finish()) as ArrayRef;
        let rels_array = Arc::new(rels_builder.finish()) as ArrayRef;
        let path_struct =
            super::common::build_path_struct_array(nodes_array, rels_array, path_validity)?;
        columns.push(Arc::new(path_struct));
        columns.push(Arc::new(vid_list_builder.finish()));
        columns.push(Arc::new(UInt64Array::from(lengths)));

        self.metrics.record_output(num_rows);
        RecordBatch::try_new(Arc::clone(&self.schema), columns).map_err(arrow_err)
    }

    /// Looks up an edge leading from `src` to `dst`, trying the configured
    /// edge types in order and returning the first match together with its
    /// type name (empty when the name lookup yields nothing).
    ///
    /// When no edge is found, falls back to edge id `0` with an empty type
    /// name rather than reporting an error.
    fn find_edge(&self, src: Vid, dst: Vid) -> (uni_common::core::id::Eid, String) {
        let query_ctx = self.graph_ctx.query_context();
        for &edge_type in &self.edge_type_ids {
            let neighbors = self.graph_ctx.get_neighbors(src, edge_type, self.direction);
            for (neighbor, eid) in neighbors {
                if neighbor != dst {
                    continue;
                }
                let type_name = l0_visibility::get_edge_type(eid, &query_ctx).unwrap_or_default();
                return (eid, type_name);
            }
        }
        (uni_common::core::id::Eid::from(0u64), String::new())
    }
}
impl Stream for GraphShortestPathStream {
type Item = DFResult<RecordBatch>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
loop {
let state = std::mem::replace(&mut self.state, ShortestPathStreamState::Done);
match state {
ShortestPathStreamState::Warming(mut fut) => match fut.as_mut().poll(cx) {
Poll::Ready(Ok(())) => {
self.state = ShortestPathStreamState::Reading;
}
Poll::Ready(Err(e)) => {
self.state = ShortestPathStreamState::Done;
return Poll::Ready(Some(Err(e)));
}
Poll::Pending => {
self.state = ShortestPathStreamState::Warming(fut);
return Poll::Pending;
}
},
ShortestPathStreamState::Reading => {
if let Err(e) = self.graph_ctx.check_timeout() {
return Poll::Ready(Some(Err(
datafusion::error::DataFusionError::Execution(e.to_string()),
)));
}
match self.input.poll_next_unpin(cx) {
Poll::Ready(Some(Ok(batch))) => {
let result = self.process_batch(batch);
self.state = ShortestPathStreamState::Reading;
return Poll::Ready(Some(result));
}
Poll::Ready(Some(Err(e))) => {
self.state = ShortestPathStreamState::Done;
return Poll::Ready(Some(Err(e)));
}
Poll::Ready(None) => {
self.state = ShortestPathStreamState::Done;
return Poll::Ready(None);
}
Poll::Pending => {
self.state = ShortestPathStreamState::Reading;
return Poll::Pending;
}
}
}
ShortestPathStreamState::Done => {
return Poll::Ready(None);
}
}
}
}
}
impl RecordBatchStream for GraphShortestPathStream {
fn schema(&self) -> SchemaRef {
Arc::clone(&self.schema)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Input schema with the two non-nullable vid columns, followed by any
    /// `extra` fields.
    fn endpoint_schema(extra: Vec<Field>) -> SchemaRef {
        let mut fields = vec![
            Field::new("_source_vid", DataType::UInt64, false),
            Field::new("_target_vid", DataType::UInt64, false),
        ];
        fields.extend(extra);
        Arc::new(Schema::new(fields))
    }

    /// The three path columns are appended after the input columns, in order.
    #[test]
    fn test_shortest_path_schema() {
        let output_schema = GraphShortestPathExec::build_schema(endpoint_schema(Vec::new()), "p");
        assert_eq!(output_schema.fields().len(), 5);
        let names: Vec<&str> = output_schema
            .fields()
            .iter()
            .map(|field| field.name().as_str())
            .collect();
        assert_eq!(
            names,
            ["_source_vid", "_target_vid", "p", "p._path", "p._length"]
        );
    }

    /// Input columns other than the endpoints are kept in the output.
    #[test]
    fn test_shortest_path_schema_with_extra_input_fields() {
        let input_schema = endpoint_schema(vec![Field::new("extra_col", DataType::Utf8, true)]);
        let output_schema = GraphShortestPathExec::build_schema(input_schema, "route");
        let has_field = |name: &str| output_schema.field_with_name(name).is_ok();
        assert!(
            has_field("extra_col"),
            "Extra input columns should pass through"
        );
        assert!(has_field("route"), "Path variable should be in output");
        assert!(
            has_field("route._length"),
            "Path length should be in output"
        );
    }

    /// An empty path variable still produces a schema with the appended columns.
    #[test]
    fn test_shortest_path_schema_empty_path_var() {
        let output_schema = GraphShortestPathExec::build_schema(endpoint_schema(Vec::new()), "");
        assert!(output_schema.fields().len() >= 4);
    }
}