1use std::sync::Arc;
8
9use grafeo_adapters::plugins::algorithms::{
10 ArticulationPointsAlgorithm, BellmanFordAlgorithm, BetweennessCentralityAlgorithm,
11 BfsAlgorithm, BridgesAlgorithm, ClosenessCentralityAlgorithm, ClusteringCoefficientAlgorithm,
12 ConnectedComponentsAlgorithm, DegreeCentralityAlgorithm, DfsAlgorithm, DijkstraAlgorithm,
13 FloydWarshallAlgorithm, GraphAlgorithm, KCoreAlgorithm, KruskalAlgorithm,
14 LabelPropagationAlgorithm, LouvainAlgorithm, MaxFlowAlgorithm, MinCostFlowAlgorithm,
15 PageRankAlgorithm, PrimAlgorithm, SsspAlgorithm, StronglyConnectedComponentsAlgorithm,
16 TopologicalSortAlgorithm,
17};
18use grafeo_adapters::plugins::{AlgorithmResult, ParameterDef, Parameters};
19use grafeo_common::types::Value;
20use hashbrown::HashMap;
21
22use crate::query::plan::LogicalExpression;
23
24pub struct BuiltinProcedures {
26 algorithms: HashMap<String, Arc<dyn GraphAlgorithm>>,
27}
28
29impl BuiltinProcedures {
30 pub fn new() -> Self {
32 let mut algorithms: HashMap<String, Arc<dyn GraphAlgorithm>> = HashMap::new();
33 let register = |map: &mut HashMap<String, Arc<dyn GraphAlgorithm>>,
34 algo: Arc<dyn GraphAlgorithm>| {
35 map.insert(algo.name().to_string(), algo);
36 };
37
38 register(&mut algorithms, Arc::new(PageRankAlgorithm));
40 register(&mut algorithms, Arc::new(BetweennessCentralityAlgorithm));
41 register(&mut algorithms, Arc::new(ClosenessCentralityAlgorithm));
42 register(&mut algorithms, Arc::new(DegreeCentralityAlgorithm));
43
44 register(&mut algorithms, Arc::new(BfsAlgorithm));
46 register(&mut algorithms, Arc::new(DfsAlgorithm));
47
48 register(&mut algorithms, Arc::new(ConnectedComponentsAlgorithm));
50 register(
51 &mut algorithms,
52 Arc::new(StronglyConnectedComponentsAlgorithm),
53 );
54 register(&mut algorithms, Arc::new(TopologicalSortAlgorithm));
55
56 register(&mut algorithms, Arc::new(DijkstraAlgorithm));
58 register(&mut algorithms, Arc::new(SsspAlgorithm));
59 register(&mut algorithms, Arc::new(BellmanFordAlgorithm));
60 register(&mut algorithms, Arc::new(FloydWarshallAlgorithm));
61
62 register(&mut algorithms, Arc::new(ClusteringCoefficientAlgorithm));
64
65 register(&mut algorithms, Arc::new(LabelPropagationAlgorithm));
67 register(&mut algorithms, Arc::new(LouvainAlgorithm));
68
69 register(&mut algorithms, Arc::new(KruskalAlgorithm));
71 register(&mut algorithms, Arc::new(PrimAlgorithm));
72
73 register(&mut algorithms, Arc::new(MaxFlowAlgorithm));
75 register(&mut algorithms, Arc::new(MinCostFlowAlgorithm));
76
77 register(&mut algorithms, Arc::new(ArticulationPointsAlgorithm));
79 register(&mut algorithms, Arc::new(BridgesAlgorithm));
80 register(&mut algorithms, Arc::new(KCoreAlgorithm));
81
82 Self { algorithms }
83 }
84
85 pub fn get(&self, name: &[String]) -> Option<Arc<dyn GraphAlgorithm>> {
91 let key = resolve_name(name);
92 self.algorithms.get(key).cloned()
93 }
94
95 pub fn list(&self) -> Vec<ProcedureInfo> {
97 let mut result: Vec<ProcedureInfo> = self
98 .algorithms
99 .values()
100 .map(|algo| ProcedureInfo {
101 name: format!("grafeo.{}", algo.name()),
102 description: algo.description().to_string(),
103 parameters: algo.parameters().to_vec(),
104 output_columns: output_columns_for(algo.as_ref()),
105 })
106 .collect();
107 result.sort_by(|a, b| a.name.cmp(&b.name));
108 result
109 }
110}
111
112impl Default for BuiltinProcedures {
113 fn default() -> Self {
114 Self::new()
115 }
116}
117
118pub struct ProcedureInfo {
120 pub name: String,
122 pub description: String,
124 pub parameters: Vec<ParameterDef>,
126 pub output_columns: Vec<String>,
128}
129
/// Reduces a dotted procedure name to the bare algorithm key used in the registry.
///
/// Both `grafeo.pagerank` (namespace matched case-insensitively) and
/// `pagerank` resolve to `pagerank`. Any other multi-segment name also
/// resolves to its final segment: the namespace is not validated. An empty
/// name resolves to `""`, which matches no registered algorithm.
fn resolve_name(parts: &[String]) -> &str {
    // Namespaced, bare and fallback forms all select the last segment.
    match parts.split_last() {
        Some((last, _namespace)) => last.as_str(),
        None => "",
    }
}
140
141pub fn output_columns_for_name(algo: &dyn GraphAlgorithm) -> Vec<String> {
145 output_columns_for(algo)
146}
147
148fn output_columns_for(algo: &dyn GraphAlgorithm) -> Vec<String> {
153 match algo.name() {
154 "pagerank" => vec!["node_id".into(), "score".into()],
155 "betweenness_centrality" => vec!["node_id".into(), "centrality".into()],
156 "closeness_centrality" => vec!["node_id".into(), "centrality".into()],
157 "degree_centrality" => {
158 vec![
159 "node_id".into(),
160 "in_degree".into(),
161 "out_degree".into(),
162 "total_degree".into(),
163 ]
164 }
165 "bfs" => vec!["node_id".into(), "depth".into()],
166 "dfs" => vec!["node_id".into(), "depth".into()],
167 "connected_components" | "strongly_connected_components" => {
168 vec!["node_id".into(), "component_id".into()]
169 }
170 "topological_sort" => vec!["node_id".into(), "order".into()],
171 "dijkstra" | "sssp" => vec!["node_id".into(), "distance".into()],
172 "bellman_ford" => vec![
173 "node_id".into(),
174 "distance".into(),
175 "has_negative_cycle".into(),
176 ],
177 "floyd_warshall" => vec!["source".into(), "target".into(), "distance".into()],
178 "clustering_coefficient" => {
179 vec![
180 "node_id".into(),
181 "coefficient".into(),
182 "triangle_count".into(),
183 ]
184 }
185 "label_propagation" => vec!["node_id".into(), "community_id".into()],
186 "louvain" => vec!["node_id".into(), "community_id".into(), "modularity".into()],
187 "kruskal" | "prim" => vec!["source".into(), "target".into(), "weight".into()],
188 "max_flow" => {
189 vec![
190 "source".into(),
191 "target".into(),
192 "flow".into(),
193 "max_flow".into(),
194 ]
195 }
196 "min_cost_max_flow" => {
197 vec![
198 "source".into(),
199 "target".into(),
200 "flow".into(),
201 "cost".into(),
202 "max_flow".into(),
203 ]
204 }
205 "articulation_points" => vec!["node_id".into()],
206 "bridges" => vec!["source".into(), "target".into()],
207 "k_core" => vec!["node_id".into(), "core_number".into(), "max_core".into()],
208 _ => vec!["node_id".into(), "value".into()],
209 }
210}
211
212pub fn evaluate_arguments(args: &[LogicalExpression], param_defs: &[ParameterDef]) -> Parameters {
218 let mut params = Parameters::new();
219
220 if args.len() == 1
221 && let LogicalExpression::Map(entries) = &args[0]
222 {
223 for (key, value_expr) in entries {
225 set_param_from_expression(&mut params, key, value_expr);
226 }
227 return params;
228 }
229
230 for (i, arg) in args.iter().enumerate() {
232 if let Some(def) = param_defs.get(i) {
233 set_param_from_expression(&mut params, &def.name, arg);
234 }
235 }
236
237 params
238}
239
240fn set_param_from_expression(params: &mut Parameters, name: &str, expr: &LogicalExpression) {
242 match expr {
243 LogicalExpression::Literal(Value::Int64(v)) => params.set_int(name, *v),
244 LogicalExpression::Literal(Value::Float64(v)) => params.set_float(name, *v),
245 LogicalExpression::Literal(Value::String(v)) => {
246 params.set_string(name, AsRef::<str>::as_ref(v));
247 }
248 LogicalExpression::Literal(Value::Bool(v)) => params.set_bool(name, *v),
249 _ => {} }
251}
252
253pub fn procedures_result(registry: &BuiltinProcedures) -> AlgorithmResult {
255 let procedures = registry.list();
256 let mut result = AlgorithmResult::new(vec![
257 "name".into(),
258 "description".into(),
259 "parameters".into(),
260 "output_columns".into(),
261 ]);
262 for proc in procedures {
263 let param_desc: String = proc
264 .parameters
265 .iter()
266 .map(|p| {
267 if p.required {
268 format!("{} ({:?})", p.name, p.param_type)
269 } else if let Some(ref default) = p.default {
270 format!("{} ({:?}, default={})", p.name, p.param_type, default)
271 } else {
272 format!("{} ({:?}, optional)", p.name, p.param_type)
273 }
274 })
275 .collect::<Vec<_>>()
276 .join(", ");
277
278 let columns_desc = proc.output_columns.join(", ");
279
280 result.add_row(vec![
281 Value::from(proc.name.as_str()),
282 Value::from(proc.description.as_str()),
283 Value::from(param_desc.as_str()),
284 Value::from(columns_desc.as_str()),
285 ]);
286 }
287 result
288}
289
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_registry_has_all_algorithms() {
        let registry = BuiltinProcedures::new();
        let list = registry.list();
        assert!(
            list.len() >= 22,
            "Expected at least 22 algorithms, got {}",
            list.len()
        );
    }

    #[test]
    fn test_resolve_with_namespace() {
        let registry = BuiltinProcedures::new();
        let name = vec!["grafeo".to_string(), "pagerank".to_string()];
        assert!(registry.get(&name).is_some());
    }

    #[test]
    fn test_resolve_with_mixed_case_namespace() {
        let registry = BuiltinProcedures::new();
        let name = vec!["GraFeo".to_string(), "pagerank".to_string()];
        assert!(registry.get(&name).is_some());
    }

    #[test]
    fn test_resolve_without_namespace() {
        let registry = BuiltinProcedures::new();
        let name = vec!["pagerank".to_string()];
        assert!(registry.get(&name).is_some());
    }

    #[test]
    fn test_resolve_unknown() {
        let registry = BuiltinProcedures::new();
        let name = vec!["grafeo".to_string(), "nonexistent".to_string()];
        assert!(registry.get(&name).is_none());
    }

    #[test]
    fn test_resolve_empty_name() {
        let registry = BuiltinProcedures::new();
        assert!(registry.get(&[]).is_none());
    }

    #[test]
    fn test_list_is_sorted_and_namespaced() {
        let registry = BuiltinProcedures::new();
        let list = registry.list();
        assert!(list.iter().all(|info| info.name.starts_with("grafeo.")));
        assert!(list.windows(2).all(|pair| pair[0].name <= pair[1].name));
    }

    #[test]
    fn test_output_columns_for_pagerank() {
        let registry = BuiltinProcedures::new();
        let algo = registry
            .get(&["pagerank".to_string()])
            .expect("pagerank is a built-in algorithm");
        assert_eq!(
            output_columns_for_name(algo.as_ref()),
            vec!["node_id", "score"]
        );
    }

    #[test]
    fn test_evaluate_map_arguments() {
        let args = vec![LogicalExpression::Map(vec![
            (
                "damping".to_string(),
                LogicalExpression::Literal(Value::Float64(0.85)),
            ),
            (
                "max_iterations".to_string(),
                LogicalExpression::Literal(Value::Int64(20)),
            ),
        ])];
        let params = evaluate_arguments(&args, &[]);
        assert_eq!(params.get_float("damping"), Some(0.85));
        assert_eq!(params.get_int("max_iterations"), Some(20));
    }

    #[test]
    fn test_evaluate_empty_arguments() {
        let params = evaluate_arguments(&[], &[]);
        assert_eq!(params.get_float("damping"), None);
    }

    #[test]
    fn test_evaluate_positional_without_defs_is_dropped() {
        // With no parameter definitions, positional arguments have no name to
        // bind to and are discarded.
        let args = vec![LogicalExpression::Literal(Value::Float64(0.5))];
        let params = evaluate_arguments(&args, &[]);
        assert_eq!(params.get_float("damping"), None);
    }

    #[test]
    fn test_procedures_result() {
        let registry = BuiltinProcedures::new();
        let result = procedures_result(&registry);
        assert_eq!(
            result.columns,
            vec!["name", "description", "parameters", "output_columns"]
        );
        assert!(result.rows.len() >= 22);
        assert_eq!(result.rows.len(), registry.list().len());
    }
}