use alloc::collections::BinaryHeap;
use hashbrown::{HashMap, HashSet};
use crate::data::Element;
use crate::prelude::*;
use crate::scored::MinScored;
use crate::unionfind::UnionFind;
use crate::visit::{Data, IntoEdges, IntoNodeReferences, NodeRef};
use crate::visit::{IntoEdgeReferences, NodeIndexable};
/// Computes a minimum spanning forest of `g` with Kruskal's algorithm.
///
/// The returned iterator first yields one [`Element::Node`] per node of `g`, in
/// `node_references` order. After that it yields one [`Element::Edge`] per edge accepted
/// into the forest. Edge endpoints are positions in that node sequence, not graph indices.
///
/// Every edge is loaded into a heap up front. Edge direction plays no role: an accepted
/// edge only merges the union-find sets of its two endpoints.
pub fn min_spanning_tree<G>(g: G) -> MinSpanningTree<G>
where
    G::NodeWeight: Clone,
    G::EdgeWeight: Clone + PartialOrd,
    G: IntoNodeReferences + IntoEdgeReferences + NodeIndexable,
{
    let edge_refs = g.edge_references();
    let mut sort_edges = BinaryHeap::with_capacity(edge_refs.size_hint().0);
    // Insert one edge at a time rather than collect/extend. This keeps ties between
    // equal weights broken exactly as sequential insertion breaks them.
    edge_refs.for_each(|edge| {
        sort_edges.push(MinScored(
            edge.weight().clone(),
            (edge.source(), edge.target()),
        ));
    });

    MinSpanningTree {
        graph: g,
        node_ids: Some(g.node_references()),
        // One union-find slot per possible node index.
        subgraphs: UnionFind::new(g.node_bound()),
        sort_edges,
        node_map: HashMap::new(),
        node_count: 0,
    }
}
/// Lazy iterator behind `min_spanning_tree` (Kruskal's algorithm).
///
/// It first emits every node of the graph. Once the nodes are exhausted, it pops edges
/// from `sort_edges` and emits an edge only when the union-find accepts it as joining
/// two different sets.
#[derive(Debug, Clone)]
pub struct MinSpanningTree<G>
where
    G: Data + IntoNodeReferences,
{
    /// The graph being traversed.
    graph: G,
    /// Nodes that still have to be emitted; set to `None` once the node phase is over.
    node_ids: Option<G::NodeReferences>,
    /// Union-find over node indices (sized by `node_bound`), used to reject edges whose
    /// endpoints are already connected.
    subgraphs: UnionFind<usize>,
    /// Remaining candidate edges as `(weight, (source, target))`. They are wrapped in
    /// `MinScored`, which is expected to make this heap pop the lightest edge first.
    // The full element type is spelled out here on purpose; clippy flags it as complex.
    #[allow(clippy::type_complexity)]
    sort_edges: BinaryHeap<MinScored<G::EdgeWeight, (G::NodeId, G::NodeId)>>,
    /// Maps a node's graph index (`to_index`) to its position in the emitted node sequence.
    node_map: HashMap<usize, usize>,
    /// Number of nodes emitted so far; also the position the next emitted node receives.
    node_count: usize,
}
impl<G> Iterator for MinSpanningTree<G>
where
    G: IntoNodeReferences + NodeIndexable,
    G::NodeWeight: Clone,
    G::EdgeWeight: PartialOrd,
{
    type Item = Element<G::NodeWeight, G::EdgeWeight>;

    /// Yields the next node or spanning-forest edge.
    ///
    /// All nodes come first, in `node_references` order. Then edges are popped from the heap.
    /// An edge is yielded only when `union` on its endpoints returns `true`, which by
    /// union-find convention means they were in different sets. Edges that would close a
    /// cycle are therefore skipped. Edge endpoints are reported as positions in the node
    /// sequence emitted earlier.
    ///
    /// # Panics
    ///
    /// Panics with "Edge references unknown node" if an accepted edge has an endpoint that
    /// was not among the emitted nodes.
    fn next(&mut self) -> Option<Self::Item> {
        let graph = self.graph;

        // Node phase: hand out the next node and record where it landed in the output.
        if let Some(node) = self.node_ids.as_mut().and_then(Iterator::next) {
            let position = self.node_count;
            self.node_map.insert(graph.to_index(node.id()), position);
            self.node_count = position + 1;
            return Some(Element::Node {
                weight: node.weight().clone(),
            });
        }
        // The node iterator is exhausted; drop it so later calls go straight to the edges.
        self.node_ids = None;

        // Edge phase: `?` ends the iteration once the heap runs dry.
        loop {
            let MinScored(weight, (from, to)) = self.sort_edges.pop()?;
            let (from_index, to_index) = (graph.to_index(from), graph.to_index(to));
            if !self.subgraphs.union(from_index, to_index) {
                // The endpoints already share a component, so this edge would form a cycle.
                continue;
            }
            return match (self.node_map.get(&from_index), self.node_map.get(&to_index)) {
                (Some(&source), Some(&target)) => Some(Element::Edge {
                    source,
                    target,
                    weight,
                }),
                _ => panic!("Edge references unknown node"),
            };
        }
    }
}
/// Computes a minimum spanning tree of `g` with Prim's algorithm.
///
/// The tree is grown from the first node produced by `g.node_references()`. During
/// iteration, only the edges reported by `edges(n)` for nodes `n` already in the tree are
/// explored. Nodes that cannot be reached that way from the start node get no tree edge.
///
/// The returned iterator first yields one [`Element::Node`] per node, in
/// `node_references` order. It then yields the tree edges, whose endpoints are positions
/// in that node sequence.
pub fn min_spanning_tree_prim<G>(g: G) -> MinSpanningTreePrim<G>
where
    G::EdgeWeight: PartialOrd,
    G: IntoNodeReferences + IntoEdgeReferences,
{
    let edge_hint = g.edge_references().size_hint().0;
    let mut nodes = g.node_references();
    let node_hint = nodes.size_hint().0;
    // Prim's algorithm needs a root; the first node in iteration order is used.
    let initial_node = nodes.next();

    MinSpanningTreePrim {
        graph: g,
        node_ids: Some(g.node_references()),
        node_map: HashMap::new(),
        node_count: 0,
        sort_edges: BinaryHeap::with_capacity(edge_hint),
        nodes_taken: HashSet::with_capacity(node_hint),
        initial_node,
    }
}
/// Lazy iterator behind `min_spanning_tree_prim` (Prim's algorithm).
///
/// It first emits every node of the graph. It then grows a tree from `initial_node`,
/// repeatedly taking the best queued edge that leads to a node not yet in the tree.
#[derive(Debug, Clone)]
pub struct MinSpanningTreePrim<G>
where
    G: IntoNodeReferences,
{
    /// The graph being traversed.
    graph: G,
    /// Nodes that still have to be emitted; set to `None` once the node phase is over.
    node_ids: Option<G::NodeReferences>,
    /// Maps a node's graph index (`to_index`) to its position in the emitted node sequence.
    node_map: HashMap<usize, usize>,
    /// Number of nodes emitted so far; also the position the next emitted node receives.
    node_count: usize,
    /// Frontier of candidate edges as `(weight, (source, target))`. They are wrapped in
    /// `MinScored`, which is expected to make this heap pop the lightest edge first.
    // The full element type is spelled out here on purpose; clippy flags it as complex.
    #[allow(clippy::type_complexity)]
    sort_edges: BinaryHeap<MinScored<G::EdgeWeight, (G::NodeId, G::NodeId)>>,
    /// Graph indices (`to_index`) of the nodes already added to the tree.
    nodes_taken: HashSet<usize>,
    /// Root of the tree. It is consumed (set to `None`) on the first call past the node phase.
    initial_node: Option<G::NodeRef>,
}
impl<G> Iterator for MinSpanningTreePrim<G>
where
    G: IntoNodeReferences + IntoEdges + NodeIndexable,
    G::NodeWeight: Clone,
    G::EdgeWeight: Clone + PartialOrd,
{
    type Item = Element<G::NodeWeight, G::EdgeWeight>;

    /// Yields the next node or spanning-tree edge.
    ///
    /// All nodes come first, in `node_references` order. On the first call after that, the
    /// root node is marked as taken and its edges are queued. Each later edge is the best
    /// queued edge whose target is not yet in the tree. Once it is taken, the target's own
    /// edges are queued as well. Edge endpoints are positions in the emitted node sequence.
    ///
    /// # Panics
    ///
    /// Panics with "Edge references unknown node" if a chosen edge has an endpoint that was
    /// not among the emitted nodes.
    fn next(&mut self) -> Option<Self::Item> {
        let graph = self.graph;

        // Node phase: hand out the next node and record where it landed in the output.
        if let Some(node) = self.node_ids.as_mut().and_then(Iterator::next) {
            let position = self.node_count;
            self.node_map.insert(graph.to_index(node.id()), position);
            self.node_count = position + 1;
            return Some(Element::Node {
                weight: node.weight().clone(),
            });
        }
        // The node iterator is exhausted; drop it so later calls go straight to the edges.
        self.node_ids = None;

        // Runs on the first edge-phase call only: `take` leaves `None` behind.
        if let Some(start) = self.initial_node.take() {
            let start_id = start.id();
            self.nodes_taken.insert(graph.to_index(start_id));
            let frontier = &mut self.sort_edges;
            graph.edges(start_id).for_each(|edge| {
                frontier.push(MinScored(
                    edge.weight().clone(),
                    (edge.source(), edge.target()),
                ));
            });
        }

        // Every emitted node is already in the tree, so any queued edge would be rejected.
        if self.nodes_taken.len() == self.node_count {
            self.sort_edges.clear();
        }

        // `?` ends the iteration once the frontier runs dry.
        loop {
            let MinScored(weight, (from, to)) = self.sort_edges.pop()?;
            let from_index = graph.to_index(from);
            let to_index = graph.to_index(to);
            // `insert` returns false when `to` was already in the tree: the edge is stale.
            if !self.nodes_taken.insert(to_index) {
                continue;
            }
            let frontier = &mut self.sort_edges;
            graph.edges(to).for_each(|edge| {
                frontier.push(MinScored(
                    edge.weight().clone(),
                    (edge.source(), edge.target()),
                ));
            });
            return match (self.node_map.get(&from_index), self.node_map.get(&to_index)) {
                (Some(&source), Some(&target)) => Some(Element::Edge {
                    source,
                    target,
                    weight,
                }),
                _ => panic!("Edge references unknown node"),
            };
        }
    }
}