Skip to main content

grafeo_engine/query/executor/
procedure_call.rs

1//! Physical operator for CALL procedure execution.
2//!
3//! Wraps a [`GraphAlgorithm`] and produces [`DataChunk`]s from its result,
4//! with optional YIELD column filtering and aliasing.
5
6use std::sync::Arc;
7
8use grafeo_adapters::plugins::algorithms::GraphAlgorithm;
9use grafeo_adapters::plugins::{AlgorithmResult, Parameters};
10use grafeo_common::types::{LogicalType, Value};
11use grafeo_core::execution::DataChunk;
12use grafeo_core::execution::operators::{Operator, OperatorError, OperatorResult};
13use grafeo_core::graph::lpg::LpgStore;
14
/// Physical operator that executes a graph algorithm and yields its results.
///
/// On the first call to [`next()`](Operator::next), the algorithm is executed and
/// the full result is cached. Subsequent calls yield rows in chunks of
/// `CHUNK_SIZE` until exhausted.
///
/// [`reset()`](Operator::reset) only rewinds the row cursor; the cached result is
/// kept, so re-iterating does not run the algorithm again.
pub struct ProcedureCallOperator {
    /// Graph store passed to the algorithm when it is executed.
    store: Arc<LpgStore>,
    /// The algorithm to run; invoked lazily while no result is cached.
    algorithm: Arc<dyn GraphAlgorithm>,
    /// Parameters passed to the algorithm alongside the store.
    params: Parameters,
    /// YIELD items: (original_column, alias). `None` means yield all columns.
    yield_columns: Option<Vec<(String, Option<String>)>>,
    /// Canonical column names from the procedure registry (e.g., `["node_id", "score"]`
    /// for PageRank, even though the algorithm internally names it `"pagerank"`).
    /// Used to remap algorithm result columns for YIELD matching.
    ///
    /// Only applied when its length equals the number of result columns;
    /// otherwise the algorithm's own column names are used.
    canonical_columns: Vec<String>,
    /// Cached algorithm result (populated on first next()).
    result: Option<AlgorithmResult>,
    /// Current row position in the cached result.
    row_index: usize,
    /// Output column names (resolved after first next()).
    /// Each entry is the YIELD alias when one was given, else the column name.
    output_columns: Vec<String>,
    /// Column indices to extract from each result row (resolved after YIELD filtering).
    /// Position `i` here feeds output column `i`.
    column_indices: Vec<usize>,
}
39
/// Maximum number of rows per [`DataChunk`] emitted by the operator.
///
/// The last chunk of a result holds fewer rows when the row count is not a
/// multiple of this value.
const CHUNK_SIZE: usize = 1024;
42
43impl ProcedureCallOperator {
44    /// Creates a new procedure call operator.
45    pub fn new(
46        store: Arc<LpgStore>,
47        algorithm: Arc<dyn GraphAlgorithm>,
48        params: Parameters,
49        yield_columns: Option<Vec<(String, Option<String>)>>,
50        canonical_columns: Vec<String>,
51    ) -> Self {
52        Self {
53            store,
54            algorithm,
55            params,
56            yield_columns,
57            canonical_columns,
58            result: None,
59            row_index: 0,
60            output_columns: Vec::new(),
61            column_indices: Vec::new(),
62        }
63    }
64
65    /// Executes the algorithm and resolves YIELD column mapping.
66    fn execute_algorithm(&mut self) -> Result<(), OperatorError> {
67        let result = self
68            .algorithm
69            .execute(&self.store, &self.params)
70            .map_err(|e| OperatorError::Execution(format!("Procedure execution failed: {e}")))?;
71
72        // Use canonical column names if available (same length as result columns),
73        // otherwise fall back to the algorithm's own column names.
74        let display_columns = if self.canonical_columns.len() == result.columns.len() {
75            &self.canonical_columns
76        } else {
77            &result.columns
78        };
79
80        // Resolve YIELD columns → indices (matching against canonical names)
81        if let Some(ref yield_cols) = self.yield_columns {
82            for (field_name, alias) in yield_cols {
83                let idx = display_columns
84                    .iter()
85                    .position(|c| c == field_name)
86                    .ok_or_else(|| {
87                        OperatorError::ColumnNotFound(format!(
88                            "YIELD column '{}' not found in procedure result (available: {})",
89                            field_name,
90                            display_columns.join(", ")
91                        ))
92                    })?;
93                self.column_indices.push(idx);
94                self.output_columns
95                    .push(alias.clone().unwrap_or_else(|| field_name.clone()));
96            }
97        } else {
98            // No YIELD: return all columns with canonical names
99            self.column_indices = (0..result.columns.len()).collect();
100            self.output_columns = display_columns.clone();
101        }
102
103        self.result = Some(result);
104        Ok(())
105    }
106
107    /// Returns the output column names (available after first next() call).
108    pub fn output_columns(&self) -> &[String] {
109        &self.output_columns
110    }
111}
112
113impl Operator for ProcedureCallOperator {
114    fn next(&mut self) -> OperatorResult {
115        // Lazy execution: run algorithm on first call
116        if self.result.is_none() {
117            self.execute_algorithm()?;
118        }
119
120        let result = self.result.as_ref().unwrap();
121
122        if self.row_index >= result.rows.len() {
123            return Ok(None);
124        }
125
126        let remaining = result.rows.len() - self.row_index;
127        let chunk_rows = remaining.min(CHUNK_SIZE);
128
129        // Build column types from first row
130        let col_types: Vec<LogicalType> = if !result.rows.is_empty() {
131            self.column_indices
132                .iter()
133                .map(|&idx| value_to_logical_type(&result.rows[0][idx]))
134                .collect()
135        } else {
136            vec![LogicalType::Any; self.column_indices.len()]
137        };
138
139        let mut chunk = DataChunk::with_capacity(&col_types, chunk_rows);
140
141        for row_offset in 0..chunk_rows {
142            let row = &result.rows[self.row_index + row_offset];
143            for (col_idx, &src_idx) in self.column_indices.iter().enumerate() {
144                let value = row.get(src_idx).cloned().unwrap_or(Value::Null);
145                if let Some(col) = chunk.column_mut(col_idx) {
146                    col.push_value(value);
147                }
148            }
149        }
150        chunk.set_count(chunk_rows);
151
152        self.row_index += chunk_rows;
153        Ok(Some(chunk))
154    }
155
156    fn reset(&mut self) {
157        self.row_index = 0;
158        // Keep the cached result for re-iteration
159    }
160
161    fn name(&self) -> &'static str {
162        "ProcedureCall"
163    }
164}
165
166/// Maps a `Value` to its `LogicalType`.
167fn value_to_logical_type(value: &Value) -> LogicalType {
168    match value {
169        Value::Null => LogicalType::Any,
170        Value::Bool(_) => LogicalType::Bool,
171        Value::Int64(_) => LogicalType::Int64,
172        Value::Float64(_) => LogicalType::Float64,
173        Value::String(_) => LogicalType::String,
174        _ => LogicalType::Any,
175    }
176}