1use std::collections::BinaryHeap;
7use std::sync::OnceLock;
8
9use grafeo_common::types::{EdgeId, NodeId, Value};
10use grafeo_common::utils::error::Result;
11use grafeo_common::utils::hash::FxHashMap;
12use grafeo_core::graph::Direction;
13use grafeo_core::graph::lpg::LpgStore;
14
15use super::super::{AlgorithmResult, ParameterDef, ParameterType, Parameters};
16use super::components::UnionFind;
17use super::traits::{GraphAlgorithm, MinScored};
18
19fn extract_weight(store: &LpgStore, edge_id: EdgeId, weight_prop: Option<&str>) -> f64 {
25 if let Some(prop_name) = weight_prop
26 && let Some(edge) = store.get_edge(edge_id)
27 && let Some(value) = edge.get_property(prop_name)
28 {
29 return match value {
30 Value::Int64(i) => *i as f64,
31 Value::Float64(f) => *f,
32 _ => 1.0,
33 };
34 }
35 1.0
36}
37
38#[derive(Debug, Clone)]
44pub struct MstResult {
45 pub edges: Vec<(NodeId, NodeId, EdgeId, f64)>,
47 pub total_weight: f64,
49}
50
51impl MstResult {
52 pub fn edge_count(&self) -> usize {
54 self.edges.len()
55 }
56
57 pub fn is_spanning_tree(&self, node_count: usize) -> bool {
59 if node_count == 0 {
60 return self.edges.is_empty();
61 }
62 self.edges.len() == node_count - 1
63 }
64}
65
66pub fn kruskal(store: &LpgStore, weight_property: Option<&str>) -> MstResult {
88 let nodes = store.node_ids();
89 let n = nodes.len();
90
91 if n == 0 {
92 return MstResult {
93 edges: Vec::new(),
94 total_weight: 0.0,
95 };
96 }
97
98 let mut node_to_idx: FxHashMap<NodeId, usize> = FxHashMap::default();
100 for (idx, &node) in nodes.iter().enumerate() {
101 node_to_idx.insert(node, idx);
102 }
103
104 let mut edges: Vec<(f64, NodeId, NodeId, EdgeId)> = Vec::new();
106 let mut seen_edges: std::collections::HashSet<(usize, usize)> =
107 std::collections::HashSet::new();
108
109 for &node in &nodes {
110 let i = *node_to_idx.get(&node).unwrap();
111 for (neighbor, edge_id) in store.edges_from(node, Direction::Outgoing) {
112 if let Some(&j) = node_to_idx.get(&neighbor) {
113 let key = if i < j { (i, j) } else { (j, i) };
115 if !seen_edges.contains(&key) {
116 seen_edges.insert(key);
117 let weight = extract_weight(store, edge_id, weight_property);
118 edges.push((weight, node, neighbor, edge_id));
119 }
120 }
121 }
122 }
123
124 edges.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
126
127 let mut uf = UnionFind::new(n);
129
130 let mut mst_edges: Vec<(NodeId, NodeId, EdgeId, f64)> = Vec::new();
131 let mut total_weight = 0.0;
132
133 for (weight, src, dst, edge_id) in edges {
134 let i = *node_to_idx.get(&src).unwrap();
135 let j = *node_to_idx.get(&dst).unwrap();
136
137 if uf.find(i) != uf.find(j) {
138 uf.union(i, j);
139 mst_edges.push((src, dst, edge_id, weight));
140 total_weight += weight;
141
142 if mst_edges.len() == n - 1 {
144 break;
145 }
146 }
147 }
148
149 MstResult {
150 edges: mst_edges,
151 total_weight,
152 }
153}
154
155pub fn prim(store: &LpgStore, weight_property: Option<&str>, start: Option<NodeId>) -> MstResult {
178 let nodes = store.node_ids();
179 let n = nodes.len();
180
181 if n == 0 {
182 return MstResult {
183 edges: Vec::new(),
184 total_weight: 0.0,
185 };
186 }
187
188 let start_node = start.unwrap_or(nodes[0]);
190
191 if store.get_node(start_node).is_none() {
193 return MstResult {
194 edges: Vec::new(),
195 total_weight: 0.0,
196 };
197 }
198
199 let mut in_tree: FxHashMap<NodeId, bool> = FxHashMap::default();
200 let mut mst_edges: Vec<(NodeId, NodeId, EdgeId, f64)> = Vec::new();
201 let mut total_weight = 0.0;
202
203 let mut heap: BinaryHeap<MinScored<f64, (NodeId, NodeId, EdgeId)>> = BinaryHeap::new();
205
206 in_tree.insert(start_node, true);
208
209 for (neighbor, edge_id) in store.edges_from(start_node, Direction::Outgoing) {
211 let weight = extract_weight(store, edge_id, weight_property);
212 heap.push(MinScored::new(weight, (start_node, neighbor, edge_id)));
213 }
214
215 for &other in &nodes {
217 for (neighbor, edge_id) in store.edges_from(other, Direction::Outgoing) {
218 if neighbor == start_node {
219 let weight = extract_weight(store, edge_id, weight_property);
220 heap.push(MinScored::new(weight, (other, start_node, edge_id)));
221 }
222 }
223 }
224
225 while let Some(MinScored(weight, (src, dst, edge_id))) = heap.pop() {
226 if *in_tree.get(&dst).unwrap_or(&false) {
228 continue;
229 }
230
231 in_tree.insert(dst, true);
233 mst_edges.push((src, dst, edge_id, weight));
234 total_weight += weight;
235
236 for (neighbor, new_edge_id) in store.edges_from(dst, Direction::Outgoing) {
238 if !*in_tree.get(&neighbor).unwrap_or(&false) {
239 let new_weight = extract_weight(store, new_edge_id, weight_property);
240 heap.push(MinScored::new(new_weight, (dst, neighbor, new_edge_id)));
241 }
242 }
243
244 for &other in &nodes {
246 if !*in_tree.get(&other).unwrap_or(&false) {
247 for (neighbor, new_edge_id) in store.edges_from(other, Direction::Outgoing) {
248 if neighbor == dst {
249 let new_weight = extract_weight(store, new_edge_id, weight_property);
250 heap.push(MinScored::new(new_weight, (other, dst, new_edge_id)));
251 }
252 }
253 }
254 }
255
256 if mst_edges.len() == n - 1 {
258 break;
259 }
260 }
261
262 MstResult {
263 edges: mst_edges,
264 total_weight,
265 }
266}
267
268static KRUSKAL_PARAMS: OnceLock<Vec<ParameterDef>> = OnceLock::new();
274
275fn kruskal_params() -> &'static [ParameterDef] {
276 KRUSKAL_PARAMS.get_or_init(|| {
277 vec![ParameterDef {
278 name: "weight".to_string(),
279 description: "Edge property name for weights (default: 1.0)".to_string(),
280 param_type: ParameterType::String,
281 required: false,
282 default: None,
283 }]
284 })
285}
286
287pub struct KruskalAlgorithm;
289
290impl GraphAlgorithm for KruskalAlgorithm {
291 fn name(&self) -> &str {
292 "kruskal"
293 }
294
295 fn description(&self) -> &str {
296 "Kruskal's Minimum Spanning Tree algorithm"
297 }
298
299 fn parameters(&self) -> &[ParameterDef] {
300 kruskal_params()
301 }
302
303 fn execute(&self, store: &LpgStore, params: &Parameters) -> Result<AlgorithmResult> {
304 let weight_prop = params.get_string("weight");
305
306 let result = kruskal(store, weight_prop);
307
308 let mut output = AlgorithmResult::new(vec![
309 "source".to_string(),
310 "target".to_string(),
311 "weight".to_string(),
312 "total_weight".to_string(),
313 ]);
314
315 for (src, dst, _edge_id, weight) in result.edges {
316 output.add_row(vec![
317 Value::Int64(src.0 as i64),
318 Value::Int64(dst.0 as i64),
319 Value::Float64(weight),
320 Value::Float64(result.total_weight),
321 ]);
322 }
323
324 Ok(output)
325 }
326}
327
328static PRIM_PARAMS: OnceLock<Vec<ParameterDef>> = OnceLock::new();
330
331fn prim_params() -> &'static [ParameterDef] {
332 PRIM_PARAMS.get_or_init(|| {
333 vec![
334 ParameterDef {
335 name: "weight".to_string(),
336 description: "Edge property name for weights (default: 1.0)".to_string(),
337 param_type: ParameterType::String,
338 required: false,
339 default: None,
340 },
341 ParameterDef {
342 name: "start".to_string(),
343 description: "Starting node ID (optional)".to_string(),
344 param_type: ParameterType::NodeId,
345 required: false,
346 default: None,
347 },
348 ]
349 })
350}
351
352pub struct PrimAlgorithm;
354
355impl GraphAlgorithm for PrimAlgorithm {
356 fn name(&self) -> &str {
357 "prim"
358 }
359
360 fn description(&self) -> &str {
361 "Prim's Minimum Spanning Tree algorithm"
362 }
363
364 fn parameters(&self) -> &[ParameterDef] {
365 prim_params()
366 }
367
368 fn execute(&self, store: &LpgStore, params: &Parameters) -> Result<AlgorithmResult> {
369 let weight_prop = params.get_string("weight");
370 let start = params.get_int("start").map(|id| NodeId::new(id as u64));
371
372 let result = prim(store, weight_prop, start);
373
374 let mut output = AlgorithmResult::new(vec![
375 "source".to_string(),
376 "target".to_string(),
377 "weight".to_string(),
378 "total_weight".to_string(),
379 ]);
380
381 for (src, dst, _edge_id, weight) in result.edges {
382 output.add_row(vec![
383 Value::Int64(src.0 as i64),
384 Value::Int64(dst.0 as i64),
385 Value::Float64(weight),
386 Value::Float64(result.total_weight),
387 ]);
388 }
389
390 Ok(output)
391 }
392}
393
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for comparing floating-point tree weights.
    const EPS: f64 = 0.001;

    /// Returns `true` when `actual` lies within `EPS` of `expected`.
    fn approx_eq(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < EPS
    }

    /// Triangle with sides n0-n1 = 1.0, n1-n2 = 2.0 and n0-n2 = 3.0.
    /// Each side is stored as two opposite directed edges.
    fn create_weighted_triangle() -> LpgStore {
        let store = LpgStore::new();

        let n0 = store.create_node(&["Node"]);
        let n1 = store.create_node(&["Node"]);
        let n2 = store.create_node(&["Node"]);

        for (a, b, w) in [(n0, n1, 1.0), (n1, n2, 2.0), (n0, n2, 3.0)] {
            store.create_edge_with_props(a, b, "EDGE", [("weight", Value::Float64(w))]);
            store.create_edge_with_props(b, a, "EDGE", [("weight", Value::Float64(w))]);
        }

        store
    }

    /// Unweighted path n0 - n1 - n2 - n3, each link stored in both directions.
    fn create_simple_chain() -> LpgStore {
        let store = LpgStore::new();

        let chain: Vec<_> = (0..4).map(|_| store.create_node(&["Node"])).collect();
        for link in chain.windows(2) {
            store.create_edge(link[0], link[1], "EDGE");
            store.create_edge(link[1], link[0], "EDGE");
        }

        store
    }

    #[test]
    fn test_kruskal_triangle() {
        let store = create_weighted_triangle();
        let mst = kruskal(&store, Some("weight"));

        assert_eq!(mst.edges.len(), 2);
        assert!(approx_eq(mst.total_weight, 3.0));
    }

    #[test]
    fn test_kruskal_chain() {
        let store = create_simple_chain();
        let mst = kruskal(&store, None);

        assert_eq!(mst.edges.len(), 3);
        assert!(approx_eq(mst.total_weight, 3.0));
    }

    #[test]
    fn test_kruskal_empty() {
        let store = LpgStore::new();
        let mst = kruskal(&store, None);

        assert!(mst.edges.is_empty());
        assert_eq!(mst.total_weight, 0.0);
    }

    #[test]
    fn test_kruskal_single_node() {
        let store = LpgStore::new();
        store.create_node(&["Node"]);

        let mst = kruskal(&store, None);

        assert!(mst.edges.is_empty());
        assert!(mst.is_spanning_tree(1));
    }

    #[test]
    fn test_prim_triangle() {
        let store = create_weighted_triangle();
        let mst = prim(&store, Some("weight"), None);

        assert_eq!(mst.edges.len(), 2);
        assert!(approx_eq(mst.total_weight, 3.0));
    }

    #[test]
    fn test_prim_chain() {
        let store = create_simple_chain();
        let mst = prim(&store, None, None);

        assert_eq!(mst.edges.len(), 3);
    }

    #[test]
    fn test_prim_with_start() {
        let store = create_simple_chain();
        let mst = prim(&store, None, Some(NodeId::new(2)));

        assert_eq!(mst.edges.len(), 3);
    }

    #[test]
    fn test_prim_empty() {
        let store = LpgStore::new();
        let mst = prim(&store, None, None);

        assert!(mst.edges.is_empty());
    }

    #[test]
    fn test_kruskal_prim_same_weight() {
        let store = create_weighted_triangle();

        let by_kruskal = kruskal(&store, Some("weight"));
        let by_prim = prim(&store, Some("weight"), None);

        assert!(approx_eq(by_kruskal.total_weight, by_prim.total_weight));
    }

    #[test]
    fn test_mst_is_spanning_tree() {
        let store = create_simple_chain();
        let mst = kruskal(&store, None);

        assert!(mst.is_spanning_tree(4));
    }
}