Skip to main content

grafeo_engine/
procedures.rs

1//! Built-in procedure registry for CALL statement execution.
2//!
3//! Maps procedure names to [`GraphAlgorithm`] implementations, enabling
4//! `CALL grafeo.pagerank({damping: 0.85}) YIELD nodeId, score` from any
5//! supported query language (GQL, Cypher, SQL/PGQ).
6
7use std::sync::Arc;
8
9use grafeo_adapters::plugins::algorithms::{
10    ArticulationPointsAlgorithm, BellmanFordAlgorithm, BetweennessCentralityAlgorithm,
11    BfsAlgorithm, BridgesAlgorithm, ClosenessCentralityAlgorithm, ClusteringCoefficientAlgorithm,
12    ConnectedComponentsAlgorithm, DegreeCentralityAlgorithm, DfsAlgorithm, DijkstraAlgorithm,
13    FloydWarshallAlgorithm, GraphAlgorithm, KCoreAlgorithm, KruskalAlgorithm,
14    LabelPropagationAlgorithm, LouvainAlgorithm, MaxFlowAlgorithm, MinCostFlowAlgorithm,
15    PageRankAlgorithm, PrimAlgorithm, StronglyConnectedComponentsAlgorithm,
16    TopologicalSortAlgorithm,
17};
18use grafeo_adapters::plugins::{AlgorithmResult, ParameterDef, Parameters};
19use grafeo_common::types::Value;
20use hashbrown::HashMap;
21
22use crate::query::plan::LogicalExpression;
23
24/// Registry of built-in procedures backed by graph algorithms.
25pub struct BuiltinProcedures {
26    algorithms: HashMap<String, Arc<dyn GraphAlgorithm>>,
27}
28
29impl BuiltinProcedures {
30    /// Creates a new registry with all built-in algorithms registered.
31    pub fn new() -> Self {
32        let mut algorithms: HashMap<String, Arc<dyn GraphAlgorithm>> = HashMap::new();
33        let register = |map: &mut HashMap<String, Arc<dyn GraphAlgorithm>>,
34                        algo: Arc<dyn GraphAlgorithm>| {
35            map.insert(algo.name().to_string(), algo);
36        };
37
38        // Centrality
39        register(&mut algorithms, Arc::new(PageRankAlgorithm));
40        register(&mut algorithms, Arc::new(BetweennessCentralityAlgorithm));
41        register(&mut algorithms, Arc::new(ClosenessCentralityAlgorithm));
42        register(&mut algorithms, Arc::new(DegreeCentralityAlgorithm));
43
44        // Traversal
45        register(&mut algorithms, Arc::new(BfsAlgorithm));
46        register(&mut algorithms, Arc::new(DfsAlgorithm));
47
48        // Components
49        register(&mut algorithms, Arc::new(ConnectedComponentsAlgorithm));
50        register(
51            &mut algorithms,
52            Arc::new(StronglyConnectedComponentsAlgorithm),
53        );
54        register(&mut algorithms, Arc::new(TopologicalSortAlgorithm));
55
56        // Shortest Path
57        register(&mut algorithms, Arc::new(DijkstraAlgorithm));
58        register(&mut algorithms, Arc::new(BellmanFordAlgorithm));
59        register(&mut algorithms, Arc::new(FloydWarshallAlgorithm));
60
61        // Clustering
62        register(&mut algorithms, Arc::new(ClusteringCoefficientAlgorithm));
63
64        // Community
65        register(&mut algorithms, Arc::new(LabelPropagationAlgorithm));
66        register(&mut algorithms, Arc::new(LouvainAlgorithm));
67
68        // MST
69        register(&mut algorithms, Arc::new(KruskalAlgorithm));
70        register(&mut algorithms, Arc::new(PrimAlgorithm));
71
72        // Flow
73        register(&mut algorithms, Arc::new(MaxFlowAlgorithm));
74        register(&mut algorithms, Arc::new(MinCostFlowAlgorithm));
75
76        // Structure
77        register(&mut algorithms, Arc::new(ArticulationPointsAlgorithm));
78        register(&mut algorithms, Arc::new(BridgesAlgorithm));
79        register(&mut algorithms, Arc::new(KCoreAlgorithm));
80
81        Self { algorithms }
82    }
83
84    /// Resolves a procedure name to its algorithm.
85    ///
86    /// Strips `"grafeo."` prefix if present:
87    /// - `["grafeo", "pagerank"]` → looks up `"pagerank"`
88    /// - `["pagerank"]` → looks up `"pagerank"`
89    pub fn get(&self, name: &[String]) -> Option<Arc<dyn GraphAlgorithm>> {
90        let key = resolve_name(name);
91        self.algorithms.get(key).cloned()
92    }
93
94    /// Returns info for all registered procedures.
95    pub fn list(&self) -> Vec<ProcedureInfo> {
96        let mut result: Vec<ProcedureInfo> = self
97            .algorithms
98            .values()
99            .map(|algo| ProcedureInfo {
100                name: format!("grafeo.{}", algo.name()),
101                description: algo.description().to_string(),
102                parameters: algo.parameters().to_vec(),
103                output_columns: output_columns_for(algo.as_ref()),
104            })
105            .collect();
106        result.sort_by(|a, b| a.name.cmp(&b.name));
107        result
108    }
109}
110
111impl Default for BuiltinProcedures {
112    fn default() -> Self {
113        Self::new()
114    }
115}
116
117/// Metadata about a registered procedure.
118pub struct ProcedureInfo {
119    /// Qualified name (e.g., `"grafeo.pagerank"`).
120    pub name: String,
121    /// Description of what the procedure does.
122    pub description: String,
123    /// Parameter definitions.
124    pub parameters: Vec<ParameterDef>,
125    /// Output column names.
126    pub output_columns: Vec<String>,
127}
128
/// Turns the segments of a dotted procedure name into the registry lookup key.
///
/// The key is always the final segment, so `["grafeo", "pagerank"]` and
/// `["pagerank"]` both yield `"pagerank"`. Leading qualifiers are dropped
/// without being validated: a namespace other than `grafeo`, or more than
/// one qualifier, resolves to the last segment as well. An empty slice
/// yields `""`.
fn resolve_name(parts: &[String]) -> &str {
    match parts.split_last() {
        Some((name, _qualifiers)) => name.as_str(),
        None => "",
    }
}
139
140/// Infers output column names from an algorithm.
141///
142/// Returns the standard column names for known algorithm categories.
143pub fn output_columns_for_name(algo: &dyn GraphAlgorithm) -> Vec<String> {
144    output_columns_for(algo)
145}
146
147/// Canonical output column names for each algorithm.
148///
149/// These must match the actual column count from each algorithm's `execute()`,
150/// providing user-friendly names (e.g., `"score"` instead of `"pagerank"`).
151fn output_columns_for(algo: &dyn GraphAlgorithm) -> Vec<String> {
152    match algo.name() {
153        "pagerank" => vec!["node_id".into(), "score".into()],
154        "betweenness_centrality" => vec!["node_id".into(), "centrality".into()],
155        "closeness_centrality" => vec!["node_id".into(), "centrality".into()],
156        "degree_centrality" => {
157            vec![
158                "node_id".into(),
159                "in_degree".into(),
160                "out_degree".into(),
161                "total_degree".into(),
162            ]
163        }
164        "bfs" => vec!["node_id".into(), "depth".into()],
165        "dfs" => vec!["node_id".into(), "depth".into()],
166        "connected_components" | "strongly_connected_components" => {
167            vec!["node_id".into(), "component_id".into()]
168        }
169        "topological_sort" => vec!["node_id".into(), "order".into()],
170        "dijkstra" => vec!["node_id".into(), "distance".into()],
171        "bellman_ford" => vec![
172            "node_id".into(),
173            "distance".into(),
174            "has_negative_cycle".into(),
175        ],
176        "floyd_warshall" => vec!["source".into(), "target".into(), "distance".into()],
177        "clustering_coefficient" => {
178            vec![
179                "node_id".into(),
180                "coefficient".into(),
181                "triangle_count".into(),
182            ]
183        }
184        "label_propagation" => vec!["node_id".into(), "community_id".into()],
185        "louvain" => vec!["node_id".into(), "community_id".into(), "modularity".into()],
186        "kruskal" | "prim" => vec!["source".into(), "target".into(), "weight".into()],
187        "max_flow" => {
188            vec![
189                "source".into(),
190                "target".into(),
191                "flow".into(),
192                "max_flow".into(),
193            ]
194        }
195        "min_cost_max_flow" => {
196            vec![
197                "source".into(),
198                "target".into(),
199                "flow".into(),
200                "cost".into(),
201                "max_flow".into(),
202            ]
203        }
204        "articulation_points" => vec!["node_id".into()],
205        "bridges" => vec!["source".into(), "target".into()],
206        "k_core" => vec!["node_id".into(), "core_number".into(), "max_core".into()],
207        _ => vec!["node_id".into(), "value".into()],
208    }
209}
210
211/// Converts logical expression arguments into algorithm [`Parameters`].
212///
213/// Supports two patterns:
214/// 1. **Map literal**: `{damping: 0.85, iterations: 20}` → named parameters
215/// 2. **Positional args**: `(42, 'weight')` → mapped by index to `ParameterDef` names
216pub fn evaluate_arguments(args: &[LogicalExpression], param_defs: &[ParameterDef]) -> Parameters {
217    let mut params = Parameters::new();
218
219    if args.len() == 1
220        && let LogicalExpression::Map(entries) = &args[0]
221    {
222        // Map literal: {damping: 0.85, iterations: 20}
223        for (key, value_expr) in entries {
224            set_param_from_expression(&mut params, key, value_expr);
225        }
226        return params;
227    }
228
229    // Positional arguments: map by index to parameter definitions
230    for (i, arg) in args.iter().enumerate() {
231        if let Some(def) = param_defs.get(i) {
232            set_param_from_expression(&mut params, &def.name, arg);
233        }
234    }
235
236    params
237}
238
239/// Sets a parameter from a `LogicalExpression` constant value.
240fn set_param_from_expression(params: &mut Parameters, name: &str, expr: &LogicalExpression) {
241    match expr {
242        LogicalExpression::Literal(Value::Int64(v)) => params.set_int(name, *v),
243        LogicalExpression::Literal(Value::Float64(v)) => params.set_float(name, *v),
244        LogicalExpression::Literal(Value::String(v)) => {
245            params.set_string(name, AsRef::<str>::as_ref(v));
246        }
247        LogicalExpression::Literal(Value::Bool(v)) => params.set_bool(name, *v),
248        _ => {} // Non-constant expressions are ignored in Phase 1
249    }
250}
251
252/// Builds a `grafeo.procedures()` result listing all registered procedures.
253pub fn procedures_result(registry: &BuiltinProcedures) -> AlgorithmResult {
254    let procedures = registry.list();
255    let mut result = AlgorithmResult::new(vec![
256        "name".into(),
257        "description".into(),
258        "parameters".into(),
259        "output_columns".into(),
260    ]);
261    for proc in procedures {
262        let param_desc: String = proc
263            .parameters
264            .iter()
265            .map(|p| {
266                if p.required {
267                    format!("{} ({:?})", p.name, p.param_type)
268                } else if let Some(ref default) = p.default {
269                    format!("{} ({:?}, default={})", p.name, p.param_type, default)
270                } else {
271                    format!("{} ({:?}, optional)", p.name, p.param_type)
272                }
273            })
274            .collect::<Vec<_>>()
275            .join(", ");
276
277        let columns_desc = proc.output_columns.join(", ");
278
279        result.add_row(vec![
280            Value::from(proc.name.as_str()),
281            Value::from(proc.description.as_str()),
282            Value::from(param_desc.as_str()),
283            Value::from(columns_desc.as_str()),
284        ]);
285    }
286    result
287}
288
#[cfg(test)]
mod tests {
    use super::*;

    /// Lower bound on the number of algorithms the registry should expose.
    const MIN_BUILTIN_COUNT: usize = 22;

    /// Builds an owned, segmented procedure name from string slices.
    fn segments(parts: &[&str]) -> Vec<String> {
        parts.iter().map(|part| (*part).to_string()).collect()
    }

    #[test]
    fn test_registry_has_all_algorithms() {
        let count = BuiltinProcedures::new().list().len();
        assert!(
            count >= MIN_BUILTIN_COUNT,
            "Expected at least 22 algorithms, got {}",
            count
        );
    }

    #[test]
    fn test_resolve_with_namespace() {
        let registry = BuiltinProcedures::new();
        let found = registry.get(&segments(&["grafeo", "pagerank"]));
        assert!(found.is_some());
    }

    #[test]
    fn test_resolve_without_namespace() {
        let registry = BuiltinProcedures::new();
        let found = registry.get(&segments(&["pagerank"]));
        assert!(found.is_some());
    }

    #[test]
    fn test_resolve_unknown() {
        let registry = BuiltinProcedures::new();
        let found = registry.get(&segments(&["grafeo", "nonexistent"]));
        assert!(found.is_none());
    }

    #[test]
    fn test_evaluate_map_arguments() {
        // A single map literal is interpreted as named parameters.
        let entries = vec![
            (
                "damping".to_string(),
                LogicalExpression::Literal(Value::Float64(0.85)),
            ),
            (
                "max_iterations".to_string(),
                LogicalExpression::Literal(Value::Int64(20)),
            ),
        ];
        let args = vec![LogicalExpression::Map(entries)];

        let params = evaluate_arguments(&args, &[]);

        assert_eq!(params.get_float("damping"), Some(0.85));
        assert_eq!(params.get_int("max_iterations"), Some(20));
    }

    #[test]
    fn test_evaluate_empty_arguments() {
        // No arguments means no parameters are set.
        let params = evaluate_arguments(&[], &[]);
        assert_eq!(params.get_float("damping"), None);
    }

    #[test]
    fn test_procedures_result() {
        let result = procedures_result(&BuiltinProcedures::new());
        assert_eq!(
            result.columns,
            vec!["name", "description", "parameters", "output_columns"]
        );
        assert!(result.rows.len() >= MIN_BUILTIN_COUNT);
    }
}