Skip to main content

grafeo_engine/
procedures.rs

//! Built-in procedure registry for CALL statement execution.
//!
//! Maps procedure names to [`GraphAlgorithm`] implementations, enabling
//! `CALL grafeo.pagerank({damping: 0.85}) YIELD nodeId, score` from any
//! supported query language (GQL, Cypher, SQL/PGQ).
6
7use std::sync::Arc;
8
9use grafeo_adapters::plugins::algorithms::{
10    ArticulationPointsAlgorithm, BellmanFordAlgorithm, BetweennessCentralityAlgorithm,
11    BfsAlgorithm, BridgesAlgorithm, ClosenessCentralityAlgorithm, ClusteringCoefficientAlgorithm,
12    ConnectedComponentsAlgorithm, DegreeCentralityAlgorithm, DfsAlgorithm, DijkstraAlgorithm,
13    FloydWarshallAlgorithm, GraphAlgorithm, KCoreAlgorithm, KruskalAlgorithm,
14    LabelPropagationAlgorithm, LouvainAlgorithm, MaxFlowAlgorithm, MinCostFlowAlgorithm,
15    PageRankAlgorithm, PrimAlgorithm, SsspAlgorithm, StronglyConnectedComponentsAlgorithm,
16    TopologicalSortAlgorithm,
17};
18use grafeo_adapters::plugins::{AlgorithmResult, ParameterDef, Parameters};
19use grafeo_common::types::Value;
20use hashbrown::HashMap;
21
22use crate::query::plan::LogicalExpression;
23
24/// Registry of built-in procedures backed by graph algorithms.
25pub struct BuiltinProcedures {
26    algorithms: HashMap<String, Arc<dyn GraphAlgorithm>>,
27}
28
29impl BuiltinProcedures {
30    /// Creates a new registry with all built-in algorithms registered.
31    pub fn new() -> Self {
32        let mut algorithms: HashMap<String, Arc<dyn GraphAlgorithm>> = HashMap::new();
33        let register = |map: &mut HashMap<String, Arc<dyn GraphAlgorithm>>,
34                        algo: Arc<dyn GraphAlgorithm>| {
35            map.insert(algo.name().to_string(), algo);
36        };
37
38        // Centrality
39        register(&mut algorithms, Arc::new(PageRankAlgorithm));
40        register(&mut algorithms, Arc::new(BetweennessCentralityAlgorithm));
41        register(&mut algorithms, Arc::new(ClosenessCentralityAlgorithm));
42        register(&mut algorithms, Arc::new(DegreeCentralityAlgorithm));
43
44        // Traversal
45        register(&mut algorithms, Arc::new(BfsAlgorithm));
46        register(&mut algorithms, Arc::new(DfsAlgorithm));
47
48        // Components
49        register(&mut algorithms, Arc::new(ConnectedComponentsAlgorithm));
50        register(
51            &mut algorithms,
52            Arc::new(StronglyConnectedComponentsAlgorithm),
53        );
54        register(&mut algorithms, Arc::new(TopologicalSortAlgorithm));
55
56        // Shortest Path
57        register(&mut algorithms, Arc::new(DijkstraAlgorithm));
58        register(&mut algorithms, Arc::new(SsspAlgorithm));
59        register(&mut algorithms, Arc::new(BellmanFordAlgorithm));
60        register(&mut algorithms, Arc::new(FloydWarshallAlgorithm));
61
62        // Clustering
63        register(&mut algorithms, Arc::new(ClusteringCoefficientAlgorithm));
64
65        // Community
66        register(&mut algorithms, Arc::new(LabelPropagationAlgorithm));
67        register(&mut algorithms, Arc::new(LouvainAlgorithm));
68
69        // MST
70        register(&mut algorithms, Arc::new(KruskalAlgorithm));
71        register(&mut algorithms, Arc::new(PrimAlgorithm));
72
73        // Flow
74        register(&mut algorithms, Arc::new(MaxFlowAlgorithm));
75        register(&mut algorithms, Arc::new(MinCostFlowAlgorithm));
76
77        // Structure
78        register(&mut algorithms, Arc::new(ArticulationPointsAlgorithm));
79        register(&mut algorithms, Arc::new(BridgesAlgorithm));
80        register(&mut algorithms, Arc::new(KCoreAlgorithm));
81
82        Self { algorithms }
83    }
84
85    /// Resolves a procedure name to its algorithm.
86    ///
87    /// Strips `"grafeo."` prefix if present:
88    /// - `["grafeo", "pagerank"]` → looks up `"pagerank"`
89    /// - `["pagerank"]` → looks up `"pagerank"`
90    pub fn get(&self, name: &[String]) -> Option<Arc<dyn GraphAlgorithm>> {
91        let key = resolve_name(name);
92        self.algorithms.get(key).cloned()
93    }
94
95    /// Returns info for all registered procedures.
96    pub fn list(&self) -> Vec<ProcedureInfo> {
97        let mut result: Vec<ProcedureInfo> = self
98            .algorithms
99            .values()
100            .map(|algo| ProcedureInfo {
101                name: format!("grafeo.{}", algo.name()),
102                description: algo.description().to_string(),
103                parameters: algo.parameters().to_vec(),
104                output_columns: output_columns_for(algo.as_ref()),
105            })
106            .collect();
107        result.sort_by(|a, b| a.name.cmp(&b.name));
108        result
109    }
110}
111
112impl Default for BuiltinProcedures {
113    fn default() -> Self {
114        Self::new()
115    }
116}
117
118/// Metadata about a registered procedure.
119pub struct ProcedureInfo {
120    /// Qualified name (e.g., `"grafeo.pagerank"`).
121    pub name: String,
122    /// Description of what the procedure does.
123    pub description: String,
124    /// Parameter definitions.
125    pub parameters: Vec<ParameterDef>,
126    /// Output column names.
127    pub output_columns: Vec<String>,
128}
129
/// Resolves a dotted procedure name to its registry lookup key.
///
/// The key is always the final segment of `parts`; any leading namespace
/// segments are dropped:
/// - `["grafeo", "pagerank"]` → `"pagerank"`
/// - `["pagerank"]` → `"pagerank"`
/// - `[]` → `""` (which matches no registered procedure name in practice)
///
/// NOTE(review): the namespace is not validated. A name such as
/// `["other", "pagerank"]` also resolves to `"pagerank"`, exactly as it did
/// when this function special-cased the `"grafeo"` prefix — that guard and
/// the fallback arm produced the same result. Confirm whether foreign
/// namespaces should be rejected instead.
fn resolve_name(parts: &[String]) -> &str {
    match parts.last() {
        Some(leaf) => leaf.as_str(),
        None => "",
    }
}
140
141/// Infers output column names from an algorithm.
142///
143/// Returns the standard column names for known algorithm categories.
144pub fn output_columns_for_name(algo: &dyn GraphAlgorithm) -> Vec<String> {
145    output_columns_for(algo)
146}
147
148/// Canonical output column names for each algorithm.
149///
150/// These must match the actual column count from each algorithm's `execute()`,
151/// providing user-friendly names (e.g., `"score"` instead of `"pagerank"`).
152fn output_columns_for(algo: &dyn GraphAlgorithm) -> Vec<String> {
153    match algo.name() {
154        "pagerank" => vec!["node_id".into(), "score".into()],
155        "betweenness_centrality" => vec!["node_id".into(), "centrality".into()],
156        "closeness_centrality" => vec!["node_id".into(), "centrality".into()],
157        "degree_centrality" => {
158            vec![
159                "node_id".into(),
160                "in_degree".into(),
161                "out_degree".into(),
162                "total_degree".into(),
163            ]
164        }
165        "bfs" => vec!["node_id".into(), "depth".into()],
166        "dfs" => vec!["node_id".into(), "depth".into()],
167        "connected_components" | "strongly_connected_components" => {
168            vec!["node_id".into(), "component_id".into()]
169        }
170        "topological_sort" => vec!["node_id".into(), "order".into()],
171        "dijkstra" | "sssp" => vec!["node_id".into(), "distance".into()],
172        "bellman_ford" => vec![
173            "node_id".into(),
174            "distance".into(),
175            "has_negative_cycle".into(),
176        ],
177        "floyd_warshall" => vec!["source".into(), "target".into(), "distance".into()],
178        "clustering_coefficient" => {
179            vec![
180                "node_id".into(),
181                "coefficient".into(),
182                "triangle_count".into(),
183            ]
184        }
185        "label_propagation" => vec!["node_id".into(), "community_id".into()],
186        "louvain" => vec!["node_id".into(), "community_id".into(), "modularity".into()],
187        "kruskal" | "prim" => vec!["source".into(), "target".into(), "weight".into()],
188        "max_flow" => {
189            vec![
190                "source".into(),
191                "target".into(),
192                "flow".into(),
193                "max_flow".into(),
194            ]
195        }
196        "min_cost_max_flow" => {
197            vec![
198                "source".into(),
199                "target".into(),
200                "flow".into(),
201                "cost".into(),
202                "max_flow".into(),
203            ]
204        }
205        "articulation_points" => vec!["node_id".into()],
206        "bridges" => vec!["source".into(), "target".into()],
207        "k_core" => vec!["node_id".into(), "core_number".into(), "max_core".into()],
208        _ => vec!["node_id".into(), "value".into()],
209    }
210}
211
212/// Converts logical expression arguments into algorithm [`Parameters`].
213///
214/// Supports two patterns:
215/// 1. **Map literal**: `{damping: 0.85, iterations: 20}` → named parameters
216/// 2. **Positional args**: `(42, 'weight')` → mapped by index to `ParameterDef` names
217pub fn evaluate_arguments(args: &[LogicalExpression], param_defs: &[ParameterDef]) -> Parameters {
218    let mut params = Parameters::new();
219
220    if args.len() == 1
221        && let LogicalExpression::Map(entries) = &args[0]
222    {
223        // Map literal: {damping: 0.85, iterations: 20}
224        for (key, value_expr) in entries {
225            set_param_from_expression(&mut params, key, value_expr);
226        }
227        return params;
228    }
229
230    // Positional arguments: map by index to parameter definitions
231    for (i, arg) in args.iter().enumerate() {
232        if let Some(def) = param_defs.get(i) {
233            set_param_from_expression(&mut params, &def.name, arg);
234        }
235    }
236
237    params
238}
239
240/// Sets a parameter from a `LogicalExpression` constant value.
241fn set_param_from_expression(params: &mut Parameters, name: &str, expr: &LogicalExpression) {
242    match expr {
243        LogicalExpression::Literal(Value::Int64(v)) => params.set_int(name, *v),
244        LogicalExpression::Literal(Value::Float64(v)) => params.set_float(name, *v),
245        LogicalExpression::Literal(Value::String(v)) => {
246            params.set_string(name, AsRef::<str>::as_ref(v));
247        }
248        LogicalExpression::Literal(Value::Bool(v)) => params.set_bool(name, *v),
249        _ => {} // Non-constant expressions are ignored in Phase 1
250    }
251}
252
253/// Builds a `grafeo.procedures()` result listing all registered procedures.
254pub fn procedures_result(registry: &BuiltinProcedures) -> AlgorithmResult {
255    let procedures = registry.list();
256    let mut result = AlgorithmResult::new(vec![
257        "name".into(),
258        "description".into(),
259        "parameters".into(),
260        "output_columns".into(),
261    ]);
262    for proc in procedures {
263        let param_desc: String = proc
264            .parameters
265            .iter()
266            .map(|p| {
267                if p.required {
268                    format!("{} ({:?})", p.name, p.param_type)
269                } else if let Some(ref default) = p.default {
270                    format!("{} ({:?}, default={})", p.name, p.param_type, default)
271                } else {
272                    format!("{} ({:?}, optional)", p.name, p.param_type)
273                }
274            })
275            .collect::<Vec<_>>()
276            .join(", ");
277
278        let columns_desc = proc.output_columns.join(", ");
279
280        result.add_row(vec![
281            Value::from(proc.name.as_str()),
282            Value::from(proc.description.as_str()),
283            Value::from(param_desc.as_str()),
284            Value::from(columns_desc.as_str()),
285        ]);
286    }
287    result
288}
289
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an owned, segmented procedure name from string slices.
    fn segments(parts: &[&str]) -> Vec<String> {
        parts.iter().map(|part| (*part).to_string()).collect()
    }

    /// The registry exposes at least the expected number of algorithms.
    #[test]
    fn test_registry_has_all_algorithms() {
        let count = BuiltinProcedures::new().list().len();
        assert!(
            count >= 22,
            "Expected at least 22 algorithms, got {}",
            count
        );
    }

    /// A `grafeo.`-qualified name resolves to a registered algorithm.
    #[test]
    fn test_resolve_with_namespace() {
        let registry = BuiltinProcedures::new();
        let found = registry.get(&segments(&["grafeo", "pagerank"]));
        assert!(found.is_some());
    }

    /// A bare name resolves to a registered algorithm.
    #[test]
    fn test_resolve_without_namespace() {
        let registry = BuiltinProcedures::new();
        let found = registry.get(&segments(&["pagerank"]));
        assert!(found.is_some());
    }

    /// An unregistered procedure name yields `None`.
    #[test]
    fn test_resolve_unknown() {
        let registry = BuiltinProcedures::new();
        let found = registry.get(&segments(&["grafeo", "nonexistent"]));
        assert!(found.is_none());
    }

    /// A single map literal is unpacked into named parameters.
    #[test]
    fn test_evaluate_map_arguments() {
        let damping = (
            "damping".to_string(),
            LogicalExpression::Literal(Value::Float64(0.85)),
        );
        let max_iterations = (
            "max_iterations".to_string(),
            LogicalExpression::Literal(Value::Int64(20)),
        );
        let args = vec![LogicalExpression::Map(vec![damping, max_iterations])];

        let params = evaluate_arguments(&args, &[]);

        assert_eq!(params.get_float("damping"), Some(0.85));
        assert_eq!(params.get_int("max_iterations"), Some(20));
    }

    /// No arguments means no parameters are set.
    #[test]
    fn test_evaluate_empty_arguments() {
        let params = evaluate_arguments(&[], &[]);
        assert_eq!(params.get_float("damping"), None);
    }

    /// The procedures listing has the expected header and enough rows.
    #[test]
    fn test_procedures_result() {
        let registry = BuiltinProcedures::new();
        let result = procedures_result(&registry);
        assert_eq!(
            result.columns,
            vec!["name", "description", "parameters", "output_columns"]
        );
        assert!(result.rows.len() >= 22);
    }
}
359}