1use std::sync::Arc;
8
9use grafeo_adapters::plugins::algorithms::{
10 ArticulationPointsAlgorithm, BellmanFordAlgorithm, BetweennessCentralityAlgorithm,
11 BfsAlgorithm, BridgesAlgorithm, ClosenessCentralityAlgorithm, ClusteringCoefficientAlgorithm,
12 ConnectedComponentsAlgorithm, DegreeCentralityAlgorithm, DfsAlgorithm, DijkstraAlgorithm,
13 FloydWarshallAlgorithm, GraphAlgorithm, KCoreAlgorithm, KruskalAlgorithm,
14 LabelPropagationAlgorithm, LouvainAlgorithm, MaxFlowAlgorithm, MinCostFlowAlgorithm,
15 PageRankAlgorithm, PrimAlgorithm, StronglyConnectedComponentsAlgorithm,
16 TopologicalSortAlgorithm,
17};
18use grafeo_adapters::plugins::{AlgorithmResult, ParameterDef, Parameters};
19use grafeo_common::types::Value;
20use hashbrown::HashMap;
21
22use crate::query::plan::LogicalExpression;
23
24pub struct BuiltinProcedures {
26 algorithms: HashMap<String, Arc<dyn GraphAlgorithm>>,
27}
28
29impl BuiltinProcedures {
30 pub fn new() -> Self {
32 let mut algorithms: HashMap<String, Arc<dyn GraphAlgorithm>> = HashMap::new();
33 let register = |map: &mut HashMap<String, Arc<dyn GraphAlgorithm>>,
34 algo: Arc<dyn GraphAlgorithm>| {
35 map.insert(algo.name().to_string(), algo);
36 };
37
38 register(&mut algorithms, Arc::new(PageRankAlgorithm));
40 register(&mut algorithms, Arc::new(BetweennessCentralityAlgorithm));
41 register(&mut algorithms, Arc::new(ClosenessCentralityAlgorithm));
42 register(&mut algorithms, Arc::new(DegreeCentralityAlgorithm));
43
44 register(&mut algorithms, Arc::new(BfsAlgorithm));
46 register(&mut algorithms, Arc::new(DfsAlgorithm));
47
48 register(&mut algorithms, Arc::new(ConnectedComponentsAlgorithm));
50 register(
51 &mut algorithms,
52 Arc::new(StronglyConnectedComponentsAlgorithm),
53 );
54 register(&mut algorithms, Arc::new(TopologicalSortAlgorithm));
55
56 register(&mut algorithms, Arc::new(DijkstraAlgorithm));
58 register(&mut algorithms, Arc::new(BellmanFordAlgorithm));
59 register(&mut algorithms, Arc::new(FloydWarshallAlgorithm));
60
61 register(&mut algorithms, Arc::new(ClusteringCoefficientAlgorithm));
63
64 register(&mut algorithms, Arc::new(LabelPropagationAlgorithm));
66 register(&mut algorithms, Arc::new(LouvainAlgorithm));
67
68 register(&mut algorithms, Arc::new(KruskalAlgorithm));
70 register(&mut algorithms, Arc::new(PrimAlgorithm));
71
72 register(&mut algorithms, Arc::new(MaxFlowAlgorithm));
74 register(&mut algorithms, Arc::new(MinCostFlowAlgorithm));
75
76 register(&mut algorithms, Arc::new(ArticulationPointsAlgorithm));
78 register(&mut algorithms, Arc::new(BridgesAlgorithm));
79 register(&mut algorithms, Arc::new(KCoreAlgorithm));
80
81 Self { algorithms }
82 }
83
84 pub fn get(&self, name: &[String]) -> Option<Arc<dyn GraphAlgorithm>> {
90 let key = resolve_name(name);
91 self.algorithms.get(key).cloned()
92 }
93
94 pub fn list(&self) -> Vec<ProcedureInfo> {
96 let mut result: Vec<ProcedureInfo> = self
97 .algorithms
98 .values()
99 .map(|algo| ProcedureInfo {
100 name: format!("grafeo.{}", algo.name()),
101 description: algo.description().to_string(),
102 parameters: algo.parameters().to_vec(),
103 output_columns: output_columns_for(algo.as_ref()),
104 })
105 .collect();
106 result.sort_by(|a, b| a.name.cmp(&b.name));
107 result
108 }
109}
110
111impl Default for BuiltinProcedures {
112 fn default() -> Self {
113 Self::new()
114 }
115}
116
117pub struct ProcedureInfo {
119 pub name: String,
121 pub description: String,
123 pub parameters: Vec<ParameterDef>,
125 pub output_columns: Vec<String>,
127}
128
/// Extracts the bare algorithm name from a possibly qualified procedure name.
///
/// Only the final segment matters. `["grafeo", "pagerank"]`, `["pagerank"]`
/// and any other qualification (e.g. `["x", "pagerank"]`) all resolve to
/// `"pagerank"`, so the namespace is not validated. An empty name resolves
/// to `""`.
fn resolve_name(parts: &[String]) -> &str {
    match parts.split_last() {
        Some((last, _namespace)) => last.as_str(),
        None => "",
    }
}
139
140pub fn output_columns_for_name(algo: &dyn GraphAlgorithm) -> Vec<String> {
144 output_columns_for(algo)
145}
146
147fn output_columns_for(algo: &dyn GraphAlgorithm) -> Vec<String> {
152 match algo.name() {
153 "pagerank" => vec!["node_id".into(), "score".into()],
154 "betweenness_centrality" => vec!["node_id".into(), "centrality".into()],
155 "closeness_centrality" => vec!["node_id".into(), "centrality".into()],
156 "degree_centrality" => {
157 vec![
158 "node_id".into(),
159 "in_degree".into(),
160 "out_degree".into(),
161 "total_degree".into(),
162 ]
163 }
164 "bfs" => vec!["node_id".into(), "depth".into()],
165 "dfs" => vec!["node_id".into(), "depth".into()],
166 "connected_components" | "strongly_connected_components" => {
167 vec!["node_id".into(), "component_id".into()]
168 }
169 "topological_sort" => vec!["node_id".into(), "order".into()],
170 "dijkstra" => vec!["node_id".into(), "distance".into()],
171 "bellman_ford" => vec![
172 "node_id".into(),
173 "distance".into(),
174 "has_negative_cycle".into(),
175 ],
176 "floyd_warshall" => vec!["source".into(), "target".into(), "distance".into()],
177 "clustering_coefficient" => {
178 vec![
179 "node_id".into(),
180 "coefficient".into(),
181 "triangle_count".into(),
182 ]
183 }
184 "label_propagation" => vec!["node_id".into(), "community_id".into()],
185 "louvain" => vec!["node_id".into(), "community_id".into(), "modularity".into()],
186 "kruskal" | "prim" => vec!["source".into(), "target".into(), "weight".into()],
187 "max_flow" => {
188 vec![
189 "source".into(),
190 "target".into(),
191 "flow".into(),
192 "max_flow".into(),
193 ]
194 }
195 "min_cost_max_flow" => {
196 vec![
197 "source".into(),
198 "target".into(),
199 "flow".into(),
200 "cost".into(),
201 "max_flow".into(),
202 ]
203 }
204 "articulation_points" => vec!["node_id".into()],
205 "bridges" => vec!["source".into(), "target".into()],
206 "k_core" => vec!["node_id".into(), "core_number".into(), "max_core".into()],
207 _ => vec!["node_id".into(), "value".into()],
208 }
209}
210
211pub fn evaluate_arguments(args: &[LogicalExpression], param_defs: &[ParameterDef]) -> Parameters {
217 let mut params = Parameters::new();
218
219 if args.len() == 1
220 && let LogicalExpression::Map(entries) = &args[0]
221 {
222 for (key, value_expr) in entries {
224 set_param_from_expression(&mut params, key, value_expr);
225 }
226 return params;
227 }
228
229 for (i, arg) in args.iter().enumerate() {
231 if let Some(def) = param_defs.get(i) {
232 set_param_from_expression(&mut params, &def.name, arg);
233 }
234 }
235
236 params
237}
238
239fn set_param_from_expression(params: &mut Parameters, name: &str, expr: &LogicalExpression) {
241 match expr {
242 LogicalExpression::Literal(Value::Int64(v)) => params.set_int(name, *v),
243 LogicalExpression::Literal(Value::Float64(v)) => params.set_float(name, *v),
244 LogicalExpression::Literal(Value::String(v)) => {
245 params.set_string(name, AsRef::<str>::as_ref(v));
246 }
247 LogicalExpression::Literal(Value::Bool(v)) => params.set_bool(name, *v),
248 _ => {} }
250}
251
252pub fn procedures_result(registry: &BuiltinProcedures) -> AlgorithmResult {
254 let procedures = registry.list();
255 let mut result = AlgorithmResult::new(vec![
256 "name".into(),
257 "description".into(),
258 "parameters".into(),
259 "output_columns".into(),
260 ]);
261 for proc in procedures {
262 let param_desc: String = proc
263 .parameters
264 .iter()
265 .map(|p| {
266 if p.required {
267 format!("{} ({:?})", p.name, p.param_type)
268 } else if let Some(ref default) = p.default {
269 format!("{} ({:?}, default={})", p.name, p.param_type, default)
270 } else {
271 format!("{} ({:?}, optional)", p.name, p.param_type)
272 }
273 })
274 .collect::<Vec<_>>()
275 .join(", ");
276
277 let columns_desc = proc.output_columns.join(", ");
278
279 result.add_row(vec![
280 Value::from(proc.name.as_str()),
281 Value::from(proc.description.as_str()),
282 Value::from(param_desc.as_str()),
283 Value::from(columns_desc.as_str()),
284 ]);
285 }
286 result
287}
288
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_registry_has_all_algorithms() {
        let registry = BuiltinProcedures::new();
        let list = registry.list();
        assert!(
            list.len() >= 22,
            "Expected at least 22 algorithms, got {}",
            list.len()
        );
    }

    #[test]
    fn test_resolve_with_namespace() {
        let registry = BuiltinProcedures::new();
        let name = vec!["grafeo".to_string(), "pagerank".to_string()];
        assert!(registry.get(&name).is_some());
    }

    #[test]
    fn test_resolve_with_uppercase_namespace() {
        let registry = BuiltinProcedures::new();
        let name = vec!["GRAFEO".to_string(), "pagerank".to_string()];
        assert!(registry.get(&name).is_some());
    }

    #[test]
    fn test_resolve_without_namespace() {
        let registry = BuiltinProcedures::new();
        let name = vec!["pagerank".to_string()];
        assert!(registry.get(&name).is_some());
    }

    #[test]
    fn test_resolve_unknown() {
        let registry = BuiltinProcedures::new();
        let name = vec!["grafeo".to_string(), "nonexistent".to_string()];
        assert!(registry.get(&name).is_none());
    }

    #[test]
    fn test_resolve_empty_name() {
        let registry = BuiltinProcedures::new();
        let name: Vec<String> = Vec::new();
        assert!(registry.get(&name).is_none());
    }

    #[test]
    fn test_list_is_sorted_and_namespaced() {
        let list = BuiltinProcedures::new().list();
        assert!(list.windows(2).all(|pair| pair[0].name <= pair[1].name));
        assert!(list.iter().all(|info| info.name.starts_with("grafeo.")));
    }

    #[test]
    fn test_evaluate_map_arguments() {
        let args = vec![LogicalExpression::Map(vec![
            (
                "damping".to_string(),
                LogicalExpression::Literal(Value::Float64(0.85)),
            ),
            (
                "max_iterations".to_string(),
                LogicalExpression::Literal(Value::Int64(20)),
            ),
        ])];
        let params = evaluate_arguments(&args, &[]);
        assert_eq!(params.get_float("damping"), Some(0.85));
        assert_eq!(params.get_int("max_iterations"), Some(20));
    }

    #[test]
    fn test_evaluate_empty_arguments() {
        let params = evaluate_arguments(&[], &[]);
        assert_eq!(params.get_float("damping"), None);
    }

    #[test]
    fn test_procedures_result() {
        let registry = BuiltinProcedures::new();
        // Was `®istry`: an `&reg` HTML-entity mangling of `&registry` that
        // does not compile.
        let result = procedures_result(&registry);
        assert_eq!(
            result.columns,
            vec!["name", "description", "parameters", "output_columns"]
        );
        assert!(result.rows.len() >= 22);
        assert_eq!(result.rows.len(), registry.list().len());
    }
}