hirn_exec/operators/
graph_traverse.rs1use std::any::Any;
4use std::fmt;
5use std::sync::Arc;
6
7use arrow_array::{Float32Array, RecordBatch, StringArray, UInt32Array};
8use arrow_schema::SchemaRef;
9use datafusion_common::{DataFusionError, Result};
10use datafusion_execution::{SendableRecordBatchStream, TaskContext};
11use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType};
12use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
13use datafusion_physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
14use hirn_core::id::MemoryId;
15use hirn_core::types::{EdgeRelation, Namespace};
16
17use crate::extensions::{GraphTraverseRow, HirnSessionExt};
18
19#[derive(Debug, Clone)]
20pub struct GraphTraverseExec {
21 schema: SchemaRef,
22 properties: PlanProperties,
23 start_id: String,
24 relation_filter: Vec<EdgeRelation>,
25 depth: u32,
26 namespace: Option<String>,
27}
28
29impl GraphTraverseExec {
30 pub fn new(
31 schema: SchemaRef,
32 start_id: String,
33 relation_filter: Vec<EdgeRelation>,
34 depth: u32,
35 namespace: Option<String>,
36 ) -> Self {
37 let properties = PlanProperties::new(
38 datafusion_physical_expr::EquivalenceProperties::new(schema.clone()),
39 datafusion_physical_plan::Partitioning::UnknownPartitioning(1),
40 EmissionType::Final,
41 Boundedness::Bounded,
42 );
43
44 Self {
45 schema,
46 properties,
47 start_id,
48 relation_filter,
49 depth,
50 namespace,
51 }
52 }
53}
54
55impl DisplayAs for GraphTraverseExec {
56 fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
57 write!(
58 f,
59 "GraphTraverseExec: depth={}, namespace={}",
60 self.depth,
61 self.namespace.as_deref().unwrap_or("*")
62 )
63 }
64}
65
66impl ExecutionPlan for GraphTraverseExec {
67 fn name(&self) -> &str {
68 "GraphTraverseExec"
69 }
70
71 fn as_any(&self) -> &dyn Any {
72 self
73 }
74
75 fn schema(&self) -> SchemaRef {
76 self.schema.clone()
77 }
78
79 fn properties(&self) -> &PlanProperties {
80 &self.properties
81 }
82
83 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
84 vec![]
85 }
86
87 fn with_new_children(
88 self: Arc<Self>,
89 children: Vec<Arc<dyn ExecutionPlan>>,
90 ) -> Result<Arc<dyn ExecutionPlan>> {
91 if !children.is_empty() {
92 return Err(DataFusionError::Plan(
93 "GraphTraverseExec is a leaf node and does not accept children".to_string(),
94 ));
95 }
96 Ok(self)
97 }
98
99 fn execute(
100 &self,
101 _partition: usize,
102 context: Arc<TaskContext>,
103 ) -> Result<SendableRecordBatchStream> {
104 let schema = self.schema.clone();
105 let stream_schema = schema.clone();
106 let start_id = self.start_id.clone();
107 let relation_filter = self.relation_filter.clone();
108 let depth = self.depth;
109 let namespace = self.namespace.clone();
110 let ext = context
111 .session_config()
112 .options()
113 .extensions
114 .get::<HirnSessionExt>()
115 .cloned();
116
117 let fut = async move {
118 let Some(ext) = ext else {
119 return Err(DataFusionError::Execution(
120 "GraphTraverseExec requires HirnSessionExt".to_string(),
121 ));
122 };
123 let Some(runtime) = ext.graph_read_runtime() else {
124 return Err(DataFusionError::Execution(
125 "GraphTraverseExec requires a graph read runtime in HirnSessionExt".to_string(),
126 ));
127 };
128
129 let start_id = MemoryId::parse(&start_id)
130 .map_err(|error| DataFusionError::Execution(error.to_string()))?;
131 let requested_namespace = parse_namespace(namespace.as_deref())?;
132 let allowed_namespaces = parse_allowed_namespaces(ext.allowed_namespaces())?;
133 let visible_namespaces =
134 resolve_visible_namespaces(requested_namespace, allowed_namespaces)?;
135 let relation_filter =
136 (!relation_filter.is_empty()).then_some(relation_filter.as_slice());
137
138 let rows = runtime
139 .traverse_graph(
140 &[start_id],
141 depth,
142 ext.config.graph_depth_delegation_threshold,
143 relation_filter,
144 visible_namespaces.as_deref(),
145 )
146 .await
147 .map_err(|error| DataFusionError::Execution(error.to_string()))?;
148
149 build_output_batch(stream_schema, &rows)
150 };
151
152 let stream = futures::stream::once(fut);
153 Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
154 }
155}
156
157fn parse_namespace(namespace: Option<&str>) -> Result<Option<Namespace>> {
158 namespace
159 .map(|value| {
160 Namespace::new(value).map_err(|error| {
161 DataFusionError::Execution(format!(
162 "invalid namespace '{value}' in graph traverse: {error}"
163 ))
164 })
165 })
166 .transpose()
167}
168
169fn parse_allowed_namespaces(
170 allowed_namespaces: Option<&[String]>,
171) -> Result<Option<Vec<Namespace>>> {
172 allowed_namespaces
173 .map(|namespaces| {
174 namespaces
175 .iter()
176 .map(|namespace| {
177 Namespace::new(namespace).map_err(|error| {
178 DataFusionError::Execution(format!(
179 "invalid visible namespace '{namespace}' in graph traverse: {error}"
180 ))
181 })
182 })
183 .collect::<Result<Vec<_>>>()
184 })
185 .transpose()
186}
187
188fn resolve_visible_namespaces(
189 requested_namespace: Option<Namespace>,
190 allowed_namespaces: Option<Vec<Namespace>>,
191) -> Result<Option<Vec<Namespace>>> {
192 match (requested_namespace, allowed_namespaces) {
193 (Some(requested_namespace), Some(allowed_namespaces)) => {
194 if allowed_namespaces.contains(&requested_namespace) {
195 Ok(Some(vec![requested_namespace]))
196 } else {
197 Err(DataFusionError::Execution(format!(
198 "graph traverse cannot access namespace '{}'",
199 requested_namespace.as_str()
200 )))
201 }
202 }
203 (Some(requested_namespace), None) => Ok(Some(vec![requested_namespace])),
204 (None, allowed_namespaces) => Ok(allowed_namespaces),
205 }
206}
207
208fn build_output_batch(schema: SchemaRef, rows: &[GraphTraverseRow]) -> Result<RecordBatch> {
209 let node_ids = StringArray::from(
210 rows.iter()
211 .map(|row| row.node_id.as_str())
212 .collect::<Vec<_>>(),
213 );
214 let depths = UInt32Array::from(rows.iter().map(|row| row.depth).collect::<Vec<_>>());
215 let edge_relations = StringArray::from(vec![None::<&str>; rows.len()]);
216 let edge_weights = Float32Array::from(vec![None::<f32>; rows.len()]);
217
218 Ok(RecordBatch::try_new(
219 schema,
220 vec![
221 Arc::new(node_ids),
222 Arc::new(depths),
223 Arc::new(edge_relations),
224 Arc::new(edge_weights),
225 ],
226 )?)
227}