Skip to main content

hirn_exec/operators/
graph_traverse.rs

//! `GraphTraverseExec` — DataFusion operator for graph traversal reads.
2
3use std::any::Any;
4use std::fmt;
5use std::sync::Arc;
6
7use arrow_array::{Float32Array, RecordBatch, StringArray, UInt32Array};
8use arrow_schema::SchemaRef;
9use datafusion_common::{DataFusionError, Result};
10use datafusion_execution::{SendableRecordBatchStream, TaskContext};
11use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType};
12use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
13use datafusion_physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
14use hirn_core::id::MemoryId;
15use hirn_core::types::{EdgeRelation, Namespace};
16
17use crate::extensions::{GraphTraverseRow, HirnSessionExt};
18
19#[derive(Debug, Clone)]
20pub struct GraphTraverseExec {
21    schema: SchemaRef,
22    properties: PlanProperties,
23    start_id: String,
24    relation_filter: Vec<EdgeRelation>,
25    depth: u32,
26    namespace: Option<String>,
27}
28
29impl GraphTraverseExec {
30    pub fn new(
31        schema: SchemaRef,
32        start_id: String,
33        relation_filter: Vec<EdgeRelation>,
34        depth: u32,
35        namespace: Option<String>,
36    ) -> Self {
37        let properties = PlanProperties::new(
38            datafusion_physical_expr::EquivalenceProperties::new(schema.clone()),
39            datafusion_physical_plan::Partitioning::UnknownPartitioning(1),
40            EmissionType::Final,
41            Boundedness::Bounded,
42        );
43
44        Self {
45            schema,
46            properties,
47            start_id,
48            relation_filter,
49            depth,
50            namespace,
51        }
52    }
53}
54
55impl DisplayAs for GraphTraverseExec {
56    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
57        write!(
58            f,
59            "GraphTraverseExec: depth={}, namespace={}",
60            self.depth,
61            self.namespace.as_deref().unwrap_or("*")
62        )
63    }
64}
65
66impl ExecutionPlan for GraphTraverseExec {
67    fn name(&self) -> &str {
68        "GraphTraverseExec"
69    }
70
71    fn as_any(&self) -> &dyn Any {
72        self
73    }
74
75    fn schema(&self) -> SchemaRef {
76        self.schema.clone()
77    }
78
79    fn properties(&self) -> &PlanProperties {
80        &self.properties
81    }
82
83    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
84        vec![]
85    }
86
87    fn with_new_children(
88        self: Arc<Self>,
89        children: Vec<Arc<dyn ExecutionPlan>>,
90    ) -> Result<Arc<dyn ExecutionPlan>> {
91        if !children.is_empty() {
92            return Err(DataFusionError::Plan(
93                "GraphTraverseExec is a leaf node and does not accept children".to_string(),
94            ));
95        }
96        Ok(self)
97    }
98
99    fn execute(
100        &self,
101        _partition: usize,
102        context: Arc<TaskContext>,
103    ) -> Result<SendableRecordBatchStream> {
104        let schema = self.schema.clone();
105        let stream_schema = schema.clone();
106        let start_id = self.start_id.clone();
107        let relation_filter = self.relation_filter.clone();
108        let depth = self.depth;
109        let namespace = self.namespace.clone();
110        let ext = context
111            .session_config()
112            .options()
113            .extensions
114            .get::<HirnSessionExt>()
115            .cloned();
116
117        let fut = async move {
118            let Some(ext) = ext else {
119                return Err(DataFusionError::Execution(
120                    "GraphTraverseExec requires HirnSessionExt".to_string(),
121                ));
122            };
123            let Some(runtime) = ext.graph_read_runtime() else {
124                return Err(DataFusionError::Execution(
125                    "GraphTraverseExec requires a graph read runtime in HirnSessionExt".to_string(),
126                ));
127            };
128
129            let start_id = MemoryId::parse(&start_id)
130                .map_err(|error| DataFusionError::Execution(error.to_string()))?;
131            let requested_namespace = parse_namespace(namespace.as_deref())?;
132            let allowed_namespaces = parse_allowed_namespaces(ext.allowed_namespaces())?;
133            let visible_namespaces =
134                resolve_visible_namespaces(requested_namespace, allowed_namespaces)?;
135            let relation_filter =
136                (!relation_filter.is_empty()).then_some(relation_filter.as_slice());
137
138            let rows = runtime
139                .traverse_graph(
140                    &[start_id],
141                    depth,
142                    ext.config.graph_depth_delegation_threshold,
143                    relation_filter,
144                    visible_namespaces.as_deref(),
145                )
146                .await
147                .map_err(|error| DataFusionError::Execution(error.to_string()))?;
148
149            build_output_batch(stream_schema, &rows)
150        };
151
152        let stream = futures::stream::once(fut);
153        Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
154    }
155}
156
157fn parse_namespace(namespace: Option<&str>) -> Result<Option<Namespace>> {
158    namespace
159        .map(|value| {
160            Namespace::new(value).map_err(|error| {
161                DataFusionError::Execution(format!(
162                    "invalid namespace '{value}' in graph traverse: {error}"
163                ))
164            })
165        })
166        .transpose()
167}
168
169fn parse_allowed_namespaces(
170    allowed_namespaces: Option<&[String]>,
171) -> Result<Option<Vec<Namespace>>> {
172    allowed_namespaces
173        .map(|namespaces| {
174            namespaces
175                .iter()
176                .map(|namespace| {
177                    Namespace::new(namespace).map_err(|error| {
178                        DataFusionError::Execution(format!(
179                            "invalid visible namespace '{namespace}' in graph traverse: {error}"
180                        ))
181                    })
182                })
183                .collect::<Result<Vec<_>>>()
184        })
185        .transpose()
186}
187
188fn resolve_visible_namespaces(
189    requested_namespace: Option<Namespace>,
190    allowed_namespaces: Option<Vec<Namespace>>,
191) -> Result<Option<Vec<Namespace>>> {
192    match (requested_namespace, allowed_namespaces) {
193        (Some(requested_namespace), Some(allowed_namespaces)) => {
194            if allowed_namespaces.contains(&requested_namespace) {
195                Ok(Some(vec![requested_namespace]))
196            } else {
197                Err(DataFusionError::Execution(format!(
198                    "graph traverse cannot access namespace '{}'",
199                    requested_namespace.as_str()
200                )))
201            }
202        }
203        (Some(requested_namespace), None) => Ok(Some(vec![requested_namespace])),
204        (None, allowed_namespaces) => Ok(allowed_namespaces),
205    }
206}
207
208fn build_output_batch(schema: SchemaRef, rows: &[GraphTraverseRow]) -> Result<RecordBatch> {
209    let node_ids = StringArray::from(
210        rows.iter()
211            .map(|row| row.node_id.as_str())
212            .collect::<Vec<_>>(),
213    );
214    let depths = UInt32Array::from(rows.iter().map(|row| row.depth).collect::<Vec<_>>());
215    let edge_relations = StringArray::from(vec![None::<&str>; rows.len()]);
216    let edge_weights = Float32Array::from(vec![None::<f32>; rows.len()]);
217
218    Ok(RecordBatch::try_new(
219        schema,
220        vec![
221            Arc::new(node_ids),
222            Arc::new(depths),
223            Arc::new(edge_relations),
224            Arc::new(edge_weights),
225        ],
226    )?)
227}