grafeo_engine/query/executor/
procedure_call.rs1use std::sync::Arc;
7
8use grafeo_adapters::plugins::algorithms::GraphAlgorithm;
9use grafeo_adapters::plugins::{AlgorithmResult, Parameters};
10use grafeo_common::types::{LogicalType, Value};
11use grafeo_core::execution::DataChunk;
12use grafeo_core::execution::operators::{Operator, OperatorError, OperatorResult};
13use grafeo_core::graph::lpg::LpgStore;
14
15pub struct ProcedureCallOperator {
21 store: Arc<LpgStore>,
22 algorithm: Arc<dyn GraphAlgorithm>,
23 params: Parameters,
24 yield_columns: Option<Vec<(String, Option<String>)>>,
26 canonical_columns: Vec<String>,
30 result: Option<AlgorithmResult>,
32 row_index: usize,
34 output_columns: Vec<String>,
36 column_indices: Vec<usize>,
38}
39
/// Maximum number of result rows packed into a single output chunk.
const CHUNK_SIZE: usize = 1 << 10;
42
43impl ProcedureCallOperator {
44 pub fn new(
46 store: Arc<LpgStore>,
47 algorithm: Arc<dyn GraphAlgorithm>,
48 params: Parameters,
49 yield_columns: Option<Vec<(String, Option<String>)>>,
50 canonical_columns: Vec<String>,
51 ) -> Self {
52 Self {
53 store,
54 algorithm,
55 params,
56 yield_columns,
57 canonical_columns,
58 result: None,
59 row_index: 0,
60 output_columns: Vec::new(),
61 column_indices: Vec::new(),
62 }
63 }
64
65 fn execute_algorithm(&mut self) -> Result<(), OperatorError> {
67 let result = self
68 .algorithm
69 .execute(&self.store, &self.params)
70 .map_err(|e| OperatorError::Execution(format!("Procedure execution failed: {e}")))?;
71
72 let display_columns = if self.canonical_columns.len() == result.columns.len() {
75 &self.canonical_columns
76 } else {
77 &result.columns
78 };
79
80 if let Some(ref yield_cols) = self.yield_columns {
82 for (field_name, alias) in yield_cols {
83 let idx = display_columns
84 .iter()
85 .position(|c| c == field_name)
86 .ok_or_else(|| {
87 OperatorError::ColumnNotFound(format!(
88 "YIELD column '{}' not found in procedure result (available: {})",
89 field_name,
90 display_columns.join(", ")
91 ))
92 })?;
93 self.column_indices.push(idx);
94 self.output_columns
95 .push(alias.clone().unwrap_or_else(|| field_name.clone()));
96 }
97 } else {
98 self.column_indices = (0..result.columns.len()).collect();
100 self.output_columns = display_columns.clone();
101 }
102
103 self.result = Some(result);
104 Ok(())
105 }
106
107 pub fn output_columns(&self) -> &[String] {
109 &self.output_columns
110 }
111}
112
113impl Operator for ProcedureCallOperator {
114 fn next(&mut self) -> OperatorResult {
115 if self.result.is_none() {
117 self.execute_algorithm()?;
118 }
119
120 let result = self.result.as_ref().unwrap();
121
122 if self.row_index >= result.rows.len() {
123 return Ok(None);
124 }
125
126 let remaining = result.rows.len() - self.row_index;
127 let chunk_rows = remaining.min(CHUNK_SIZE);
128
129 let col_types: Vec<LogicalType> = if !result.rows.is_empty() {
131 self.column_indices
132 .iter()
133 .map(|&idx| value_to_logical_type(&result.rows[0][idx]))
134 .collect()
135 } else {
136 vec![LogicalType::Any; self.column_indices.len()]
137 };
138
139 let mut chunk = DataChunk::with_capacity(&col_types, chunk_rows);
140
141 for row_offset in 0..chunk_rows {
142 let row = &result.rows[self.row_index + row_offset];
143 for (col_idx, &src_idx) in self.column_indices.iter().enumerate() {
144 let value = row.get(src_idx).cloned().unwrap_or(Value::Null);
145 if let Some(col) = chunk.column_mut(col_idx) {
146 col.push_value(value);
147 }
148 }
149 }
150 chunk.set_count(chunk_rows);
151
152 self.row_index += chunk_rows;
153 Ok(Some(chunk))
154 }
155
156 fn reset(&mut self) {
157 self.row_index = 0;
158 }
160
161 fn name(&self) -> &'static str {
162 "ProcedureCall"
163 }
164}
165
166fn value_to_logical_type(value: &Value) -> LogicalType {
168 match value {
169 Value::Null => LogicalType::Any,
170 Value::Bool(_) => LogicalType::Bool,
171 Value::Int64(_) => LogicalType::Int64,
172 Value::Float64(_) => LogicalType::Float64,
173 Value::String(_) => LogicalType::String,
174 _ => LogicalType::Any,
175 }
176}