1use std::collections::BinaryHeap;
7use std::sync::OnceLock;
8
9use graphos_common::types::{EdgeId, NodeId, Value};
10use graphos_common::utils::error::Result;
11use graphos_common::utils::hash::FxHashMap;
12use graphos_core::graph::Direction;
13use graphos_core::graph::lpg::LpgStore;
14
15use super::super::{AlgorithmResult, ParameterDef, ParameterType, Parameters};
16use super::components::UnionFind;
17use super::traits::{GraphAlgorithm, MinScored};
18
19fn extract_weight(store: &LpgStore, edge_id: EdgeId, weight_prop: Option<&str>) -> f64 {
25 if let Some(prop_name) = weight_prop {
26 if let Some(edge) = store.get_edge(edge_id) {
27 if let Some(value) = edge.get_property(prop_name) {
28 return match value {
29 Value::Int64(i) => *i as f64,
30 Value::Float64(f) => *f,
31 _ => 1.0,
32 };
33 }
34 }
35 }
36 1.0
37}
38
39#[derive(Debug, Clone)]
45pub struct MstResult {
46 pub edges: Vec<(NodeId, NodeId, EdgeId, f64)>,
48 pub total_weight: f64,
50}
51
52impl MstResult {
53 pub fn edge_count(&self) -> usize {
55 self.edges.len()
56 }
57
58 pub fn is_spanning_tree(&self, node_count: usize) -> bool {
60 if node_count == 0 {
61 return self.edges.is_empty();
62 }
63 self.edges.len() == node_count - 1
64 }
65}
66
67pub fn kruskal(store: &LpgStore, weight_property: Option<&str>) -> MstResult {
89 let nodes = store.node_ids();
90 let n = nodes.len();
91
92 if n == 0 {
93 return MstResult {
94 edges: Vec::new(),
95 total_weight: 0.0,
96 };
97 }
98
99 let mut node_to_idx: FxHashMap<NodeId, usize> = FxHashMap::default();
101 for (idx, &node) in nodes.iter().enumerate() {
102 node_to_idx.insert(node, idx);
103 }
104
105 let mut edges: Vec<(f64, NodeId, NodeId, EdgeId)> = Vec::new();
107 let mut seen_edges: std::collections::HashSet<(usize, usize)> =
108 std::collections::HashSet::new();
109
110 for &node in &nodes {
111 let i = *node_to_idx.get(&node).unwrap();
112 for (neighbor, edge_id) in store.edges_from(node, Direction::Outgoing) {
113 if let Some(&j) = node_to_idx.get(&neighbor) {
114 let key = if i < j { (i, j) } else { (j, i) };
116 if !seen_edges.contains(&key) {
117 seen_edges.insert(key);
118 let weight = extract_weight(store, edge_id, weight_property);
119 edges.push((weight, node, neighbor, edge_id));
120 }
121 }
122 }
123 }
124
125 edges.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
127
128 let mut uf = UnionFind::new(n);
130
131 let mut mst_edges: Vec<(NodeId, NodeId, EdgeId, f64)> = Vec::new();
132 let mut total_weight = 0.0;
133
134 for (weight, src, dst, edge_id) in edges {
135 let i = *node_to_idx.get(&src).unwrap();
136 let j = *node_to_idx.get(&dst).unwrap();
137
138 if uf.find(i) != uf.find(j) {
139 uf.union(i, j);
140 mst_edges.push((src, dst, edge_id, weight));
141 total_weight += weight;
142
143 if mst_edges.len() == n - 1 {
145 break;
146 }
147 }
148 }
149
150 MstResult {
151 edges: mst_edges,
152 total_weight,
153 }
154}
155
156pub fn prim(store: &LpgStore, weight_property: Option<&str>, start: Option<NodeId>) -> MstResult {
179 let nodes = store.node_ids();
180 let n = nodes.len();
181
182 if n == 0 {
183 return MstResult {
184 edges: Vec::new(),
185 total_weight: 0.0,
186 };
187 }
188
189 let start_node = start.unwrap_or(nodes[0]);
191
192 if store.get_node(start_node).is_none() {
194 return MstResult {
195 edges: Vec::new(),
196 total_weight: 0.0,
197 };
198 }
199
200 let mut in_tree: FxHashMap<NodeId, bool> = FxHashMap::default();
201 let mut mst_edges: Vec<(NodeId, NodeId, EdgeId, f64)> = Vec::new();
202 let mut total_weight = 0.0;
203
204 let mut heap: BinaryHeap<MinScored<f64, (NodeId, NodeId, EdgeId)>> = BinaryHeap::new();
206
207 in_tree.insert(start_node, true);
209
210 for (neighbor, edge_id) in store.edges_from(start_node, Direction::Outgoing) {
212 let weight = extract_weight(store, edge_id, weight_property);
213 heap.push(MinScored::new(weight, (start_node, neighbor, edge_id)));
214 }
215
216 for &other in &nodes {
218 for (neighbor, edge_id) in store.edges_from(other, Direction::Outgoing) {
219 if neighbor == start_node {
220 let weight = extract_weight(store, edge_id, weight_property);
221 heap.push(MinScored::new(weight, (other, start_node, edge_id)));
222 }
223 }
224 }
225
226 while let Some(MinScored(weight, (src, dst, edge_id))) = heap.pop() {
227 if *in_tree.get(&dst).unwrap_or(&false) {
229 continue;
230 }
231
232 in_tree.insert(dst, true);
234 mst_edges.push((src, dst, edge_id, weight));
235 total_weight += weight;
236
237 for (neighbor, new_edge_id) in store.edges_from(dst, Direction::Outgoing) {
239 if !*in_tree.get(&neighbor).unwrap_or(&false) {
240 let new_weight = extract_weight(store, new_edge_id, weight_property);
241 heap.push(MinScored::new(new_weight, (dst, neighbor, new_edge_id)));
242 }
243 }
244
245 for &other in &nodes {
247 if !*in_tree.get(&other).unwrap_or(&false) {
248 for (neighbor, new_edge_id) in store.edges_from(other, Direction::Outgoing) {
249 if neighbor == dst {
250 let new_weight = extract_weight(store, new_edge_id, weight_property);
251 heap.push(MinScored::new(new_weight, (other, dst, new_edge_id)));
252 }
253 }
254 }
255 }
256
257 if mst_edges.len() == n - 1 {
259 break;
260 }
261 }
262
263 MstResult {
264 edges: mst_edges,
265 total_weight,
266 }
267}
268
269static KRUSKAL_PARAMS: OnceLock<Vec<ParameterDef>> = OnceLock::new();
275
276fn kruskal_params() -> &'static [ParameterDef] {
277 KRUSKAL_PARAMS.get_or_init(|| {
278 vec![ParameterDef {
279 name: "weight".to_string(),
280 description: "Edge property name for weights (default: 1.0)".to_string(),
281 param_type: ParameterType::String,
282 required: false,
283 default: None,
284 }]
285 })
286}
287
288pub struct KruskalAlgorithm;
290
291impl GraphAlgorithm for KruskalAlgorithm {
292 fn name(&self) -> &str {
293 "kruskal"
294 }
295
296 fn description(&self) -> &str {
297 "Kruskal's Minimum Spanning Tree algorithm"
298 }
299
300 fn parameters(&self) -> &[ParameterDef] {
301 kruskal_params()
302 }
303
304 fn execute(&self, store: &LpgStore, params: &Parameters) -> Result<AlgorithmResult> {
305 let weight_prop = params.get_string("weight");
306
307 let result = kruskal(store, weight_prop);
308
309 let mut output = AlgorithmResult::new(vec![
310 "source".to_string(),
311 "target".to_string(),
312 "weight".to_string(),
313 "total_weight".to_string(),
314 ]);
315
316 for (src, dst, _edge_id, weight) in result.edges {
317 output.add_row(vec![
318 Value::Int64(src.0 as i64),
319 Value::Int64(dst.0 as i64),
320 Value::Float64(weight),
321 Value::Float64(result.total_weight),
322 ]);
323 }
324
325 Ok(output)
326 }
327}
328
329static PRIM_PARAMS: OnceLock<Vec<ParameterDef>> = OnceLock::new();
331
332fn prim_params() -> &'static [ParameterDef] {
333 PRIM_PARAMS.get_or_init(|| {
334 vec![
335 ParameterDef {
336 name: "weight".to_string(),
337 description: "Edge property name for weights (default: 1.0)".to_string(),
338 param_type: ParameterType::String,
339 required: false,
340 default: None,
341 },
342 ParameterDef {
343 name: "start".to_string(),
344 description: "Starting node ID (optional)".to_string(),
345 param_type: ParameterType::NodeId,
346 required: false,
347 default: None,
348 },
349 ]
350 })
351}
352
353pub struct PrimAlgorithm;
355
356impl GraphAlgorithm for PrimAlgorithm {
357 fn name(&self) -> &str {
358 "prim"
359 }
360
361 fn description(&self) -> &str {
362 "Prim's Minimum Spanning Tree algorithm"
363 }
364
365 fn parameters(&self) -> &[ParameterDef] {
366 prim_params()
367 }
368
369 fn execute(&self, store: &LpgStore, params: &Parameters) -> Result<AlgorithmResult> {
370 let weight_prop = params.get_string("weight");
371 let start = params.get_int("start").map(|id| NodeId::new(id as u64));
372
373 let result = prim(store, weight_prop, start);
374
375 let mut output = AlgorithmResult::new(vec![
376 "source".to_string(),
377 "target".to_string(),
378 "weight".to_string(),
379 "total_weight".to_string(),
380 ]);
381
382 for (src, dst, _edge_id, weight) in result.edges {
383 output.add_row(vec![
384 Value::Int64(src.0 as i64),
385 Value::Int64(dst.0 as i64),
386 Value::Float64(weight),
387 Value::Float64(result.total_weight),
388 ]);
389 }
390
391 Ok(output)
392 }
393}
394
#[cfg(test)]
mod tests {
    use super::*;

    /// Triangle with bidirectional edges of weight 1 (n0-n1), 2 (n1-n2), 3 (n0-n2).
    fn create_weighted_triangle() -> LpgStore {
        let store = LpgStore::new();

        let n0 = store.create_node(&["Node"]);
        let n1 = store.create_node(&["Node"]);
        let n2 = store.create_node(&["Node"]);

        store.create_edge_with_props(n0, n1, "EDGE", [("weight", Value::Float64(1.0))]);
        store.create_edge_with_props(n1, n0, "EDGE", [("weight", Value::Float64(1.0))]);
        store.create_edge_with_props(n1, n2, "EDGE", [("weight", Value::Float64(2.0))]);
        store.create_edge_with_props(n2, n1, "EDGE", [("weight", Value::Float64(2.0))]);
        store.create_edge_with_props(n0, n2, "EDGE", [("weight", Value::Float64(3.0))]);
        store.create_edge_with_props(n2, n0, "EDGE", [("weight", Value::Float64(3.0))]);

        store
    }

    /// Unweighted bidirectional chain n0 - n1 - n2 - n3.
    fn create_simple_chain() -> LpgStore {
        let store = LpgStore::new();

        let n0 = store.create_node(&["Node"]);
        let n1 = store.create_node(&["Node"]);
        let n2 = store.create_node(&["Node"]);
        let n3 = store.create_node(&["Node"]);

        store.create_edge(n0, n1, "EDGE");
        store.create_edge(n1, n0, "EDGE");
        store.create_edge(n1, n2, "EDGE");
        store.create_edge(n2, n1, "EDGE");
        store.create_edge(n2, n3, "EDGE");
        store.create_edge(n3, n2, "EDGE");

        store
    }

    /// Two disconnected components: a -> b (weight 1) and c -> d (weight 2).
    fn create_two_components() -> LpgStore {
        let store = LpgStore::new();

        let a = store.create_node(&["Node"]);
        let b = store.create_node(&["Node"]);
        let c = store.create_node(&["Node"]);
        let d = store.create_node(&["Node"]);

        store.create_edge_with_props(a, b, "EDGE", [("weight", Value::Float64(1.0))]);
        store.create_edge_with_props(c, d, "EDGE", [("weight", Value::Float64(2.0))]);

        store
    }

    #[test]
    fn test_kruskal_triangle() {
        let store = create_weighted_triangle();
        let result = kruskal(&store, Some("weight"));

        assert_eq!(result.edges.len(), 2);

        assert!((result.total_weight - 3.0).abs() < 0.001);
    }

    #[test]
    fn test_kruskal_chain() {
        let store = create_simple_chain();
        let result = kruskal(&store, None);

        assert_eq!(result.edges.len(), 3);

        assert!((result.total_weight - 3.0).abs() < 0.001);
    }

    #[test]
    fn test_kruskal_empty() {
        let store = LpgStore::new();
        let result = kruskal(&store, None);

        assert!(result.edges.is_empty());
        assert_eq!(result.total_weight, 0.0);
    }

    #[test]
    fn test_kruskal_single_node() {
        let store = LpgStore::new();
        store.create_node(&["Node"]);

        let result = kruskal(&store, None);

        assert!(result.edges.is_empty());
        assert!(result.is_spanning_tree(1));
    }

    #[test]
    fn test_kruskal_disconnected_forest() {
        let store = create_two_components();
        let result = kruskal(&store, Some("weight"));

        // One edge per component: a spanning forest, not a spanning tree.
        assert_eq!(result.edge_count(), 2);
        assert!((result.total_weight - 3.0).abs() < 0.001);
        assert!(!result.is_spanning_tree(4));
    }

    #[test]
    fn test_prim_triangle() {
        let store = create_weighted_triangle();
        let result = prim(&store, Some("weight"), None);

        assert_eq!(result.edges.len(), 2);

        assert!((result.total_weight - 3.0).abs() < 0.001);
    }

    #[test]
    fn test_prim_chain() {
        let store = create_simple_chain();
        let result = prim(&store, None, None);

        assert_eq!(result.edges.len(), 3);
    }

    #[test]
    fn test_prim_with_start() {
        let store = create_simple_chain();
        let result = prim(&store, None, Some(NodeId::new(2)));

        assert_eq!(result.edges.len(), 3);
    }

    #[test]
    fn test_prim_empty() {
        let store = LpgStore::new();
        let result = prim(&store, None, None);

        assert!(result.edges.is_empty());
    }

    #[test]
    fn test_prim_single_node() {
        let store = LpgStore::new();
        store.create_node(&["Node"]);

        let result = prim(&store, None, None);

        assert!(result.edges.is_empty());
        assert!(result.is_spanning_tree(1));
    }

    #[test]
    fn test_prim_unknown_start() {
        let store = create_simple_chain();
        let result = prim(&store, None, Some(NodeId::new(9999)));

        assert!(result.edges.is_empty());
        assert_eq!(result.total_weight, 0.0);
    }

    #[test]
    fn test_kruskal_prim_same_weight() {
        let store = create_weighted_triangle();

        let kruskal_result = kruskal(&store, Some("weight"));
        let prim_result = prim(&store, Some("weight"), None);

        assert!((kruskal_result.total_weight - prim_result.total_weight).abs() < 0.001);
    }

    #[test]
    fn test_mst_is_spanning_tree() {
        let store = create_simple_chain();
        let result = kruskal(&store, None);

        assert!(result.is_spanning_tree(4));
    }

    #[test]
    fn test_is_spanning_tree_zero_nodes() {
        let empty = MstResult {
            edges: Vec::new(),
            total_weight: 0.0,
        };

        assert!(empty.is_spanning_tree(0));
        assert!(!empty.is_spanning_tree(2));
    }
}