use super::common::{GraphView, NodeId};
use std::collections::{HashSet, BinaryHeap};
use std::cmp::Ordering;
/// Result of a minimum-spanning-tree computation.
#[derive(Debug, Clone, PartialEq)]
pub struct MSTResult {
    /// Sum of the weights of every edge in `edges`.
    pub total_weight: f64,
    /// Tree edges as `(from, to, weight)`, expressed in external [`NodeId`]s
    /// rather than internal graph indices.
    pub edges: Vec<(NodeId, NodeId, f64)>,
}
/// Heap entry for a candidate edge `source -> target` during Prim's algorithm.
///
/// `BinaryHeap` is a max-heap, so [`Ord`] is reversed on `weight` to make the
/// lightest edge pop first. Weights are compared with [`f64::total_cmp`], which
/// is a total order even for NaN, so a bad weight cannot corrupt the heap.
/// Weight ties are broken on the endpoints, which keeps `Ord` and `PartialEq`
/// consistent as the `Ord` contract requires.
#[derive(Copy, Clone, Debug)]
struct EdgeState {
    /// Edge weight; lower weights are popped first.
    weight: f64,
    /// Index of the endpoint that is already in the tree.
    source: usize,
    /// Index of the endpoint this edge would add to the tree.
    target: usize,
}

impl PartialEq for EdgeState {
    fn eq(&self, other: &Self) -> bool {
        // Defined via `cmp` so equality always agrees with the ordering.
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for EdgeState {}

impl Ord for EdgeState {
    fn cmp(&self, other: &Self) -> Ordering {
        // Every comparison is reversed (`other` vs `self`) to turn std's
        // max-heap into a min-heap.
        other
            .weight
            .total_cmp(&self.weight)
            .then_with(|| other.source.cmp(&self.source))
            .then_with(|| other.target.cmp(&self.target))
    }
}

impl PartialOrd for EdgeState {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
/// Computes a minimum spanning tree with (lazy) Prim's algorithm.
///
/// Edges are treated as undirected: for each tree node both its outgoing and
/// its incoming edges are offered as candidates (see `add_edges`). The tree is
/// grown from internal index 0, so for a disconnected graph the result spans
/// only the component containing that node.
///
/// Returns the total weight and the chosen edges as
/// `(tree_endpoint, new_endpoint, weight)`, in the order they were added.
/// An empty graph yields an empty result.
pub fn prim_mst(view: &GraphView) -> MSTResult {
    if view.node_count == 0 {
        return MSTResult { total_weight: 0.0, edges: Vec::new() };
    }

    let start_idx = 0;
    let mut visited = HashSet::with_capacity(view.node_count);
    let mut heap = BinaryHeap::new();
    // A spanning tree over n nodes has at most n - 1 edges (node_count > 0 here).
    let mut mst_edges = Vec::with_capacity(view.node_count - 1);
    let mut total_weight = 0.0;

    visited.insert(start_idx);
    add_edges(view, start_idx, &mut heap, &visited);

    while let Some(EdgeState { weight, source, target }) = heap.pop() {
        // Lazy deletion: entries whose target joined the tree after they were
        // pushed are stale and simply skipped.
        if !visited.insert(target) {
            continue;
        }
        mst_edges.push((
            view.index_to_node[source],
            view.index_to_node[target],
            weight,
        ));
        total_weight += weight;

        // Every node is in the tree: anything left in the heap is stale, so
        // stop instead of expanding the last node and draining the heap.
        if visited.len() == view.node_count {
            break;
        }
        add_edges(view, target, &mut heap, &visited);
    }

    MSTResult {
        total_weight,
        edges: mst_edges,
    }
}
/// Pushes onto `heap` every edge that links tree node `u` to a node not yet in
/// `visited`, treating the graph as undirected.
///
/// - Outgoing edges `u -> v` take their weight from `view.weights(u)` at the
///   edge's position in `u`'s successor list.
/// - Incoming edges `v -> u` take their weight from `view.weights(v)` at the
///   position of `u` in `v`'s successor list.
///
/// Either way the entry is recorded as `source: u, target: v`. When the view
/// reports no weights, every edge weighs `1.0`.
fn add_edges(
    view: &GraphView,
    u: usize,
    heap: &mut BinaryHeap<EdgeState>,
    visited: &HashSet<usize>,
) {
    // Loop-invariant: look up u's weights once rather than per edge.
    let u_weights = view.weights(u);
    let u_out = view.successors(u);
    for (i, &v) in u_out.iter().enumerate() {
        if !visited.contains(&v) {
            let weight = u_weights.as_ref().map(|w| w[i]).unwrap_or(1.0);
            heap.push(EdgeState { weight, source: u, target: v });
        }
    }

    let u_in = view.predecessors(u);
    for &v in u_in.iter() {
        if visited.contains(&v) {
            continue;
        }
        let v_weights = view.weights(v);
        let v_out = view.successors(v);
        // Offer every parallel `v -> u` edge (not just the first one), matching
        // the outgoing loop above; the heap surfaces the cheapest.
        for (idx, &x) in v_out.iter().enumerate() {
            if x == u {
                let weight = v_weights.as_ref().map(|w| w[idx]).unwrap_or(1.0);
                heap.push(EdgeState { weight, source: u, target: v });
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Undirected triangle 1-2 (1.0), 2-3 (2.0), 1-3 (10.0), with each
    /// undirected edge stored as two directed ones. The MST must skip 1-3.
    #[test]
    fn test_prim_mst() {
        let node_count = 3;
        let index_to_node = vec![1, 2, 3];
        let mut node_to_index = HashMap::new();
        node_to_index.insert(1, 0);
        node_to_index.insert(2, 1);
        node_to_index.insert(3, 2);
        let mut outgoing = vec![vec![]; 3];
        let mut incoming = vec![vec![]; 3];
        let mut weights = vec![vec![]; 3];
        outgoing[0].push(1); incoming[1].push(0); weights[0].push(1.0);
        outgoing[1].push(0); incoming[0].push(1); weights[1].push(1.0);
        outgoing[1].push(2); incoming[2].push(1); weights[1].push(2.0);
        outgoing[2].push(1); incoming[1].push(2); weights[2].push(2.0);
        outgoing[0].push(2); incoming[2].push(0); weights[0].push(10.0);
        outgoing[2].push(0); incoming[0].push(2); weights[2].push(10.0);
        let view = GraphView::from_adjacency_list(
            node_count,
            index_to_node,
            node_to_index,
            outgoing,
            incoming,
            Some(weights),
        );
        let result = prim_mst(&view);
        assert_eq!(result.total_weight, 3.0);
        assert_eq!(result.edges.len(), 2);
        // Edges are reported in insertion order, in external node ids.
        assert_eq!(result.edges, vec![(1, 2, 1.0), (2, 3, 2.0)]);
    }

    #[test]
    fn test_prim_mst_empty_graph() {
        let view = GraphView::from_adjacency_list(
            0,
            Vec::new(),
            HashMap::new(),
            Vec::new(),
            Vec::new(),
            None,
        );
        let result = prim_mst(&view);
        assert_eq!(result.total_weight, 0.0);
        assert!(result.edges.is_empty());
    }

    /// Nodes 1 and 2 are linked without explicit weights; node 3 is isolated.
    /// The tree grows from index 0, so only that component is spanned, and the
    /// missing weights default to 1.0.
    #[test]
    fn test_prim_mst_unweighted_disconnected() {
        let mut node_to_index = HashMap::new();
        node_to_index.insert(1, 0);
        node_to_index.insert(2, 1);
        node_to_index.insert(3, 2);
        let outgoing = vec![vec![1], vec![0], vec![]];
        let incoming = vec![vec![1], vec![0], vec![]];
        let view = GraphView::from_adjacency_list(
            3,
            vec![1, 2, 3],
            node_to_index,
            outgoing,
            incoming,
            None,
        );
        let result = prim_mst(&view);
        assert_eq!(result.total_weight, 1.0);
        assert_eq!(result.edges, vec![(1, 2, 1.0)]);
    }
}