use alloc::vec::Vec;
use core::{fmt::Debug, hash::Hash};
use hashbrown::{HashMap, HashSet};
use crate::algo::floyd_warshall::floyd_warshall_path;
use crate::algo::{dijkstra, min_spanning_tree, BoundedMeasure, Measure};
use crate::data::FromElements;
use crate::graph::{IndexType, NodeIndex, UnGraph};
use crate::visit::{
Data, EdgeRef, GraphBase, GraphProp, IntoEdgeReferences, IntoEdges, IntoNeighbors,
IntoNodeIdentifiers, IntoNodeReferences, NodeCompactIndexable, NodeIndexable, Visitable,
};
use crate::Undirected;
#[cfg(feature = "stable_graph")]
use crate::stable_graph::StableGraph;
type Edge<G> = (<G as GraphBase>::NodeId, <G as GraphBase>::NodeId);
type Subgraph<G> = HashSet<<G as GraphBase>::NodeId>;
fn compute_shortest_path_length<G>(graph: G, source: G::NodeId, target: G::NodeId) -> G::EdgeWeight
where
G: Visitable + IntoEdges,
G::NodeId: Eq + Hash,
G::EdgeWeight: Measure + Copy,
{
let output = dijkstra(graph, source, Some(target), |e| *e.weight());
output[&target]
}
fn compute_metric_closure<G>(
graph: G,
terminals: &[G::NodeId],
) -> HashMap<(usize, usize), G::EdgeWeight>
where
G: Data + IntoNodeReferences + NodeIndexable + Visitable + IntoEdges,
G::EdgeWeight: Copy + Measure,
G::NodeId: PartialOrd + Eq + Hash,
{
let mut closure = HashMap::new();
for (i, node_id_1) in terminals.iter().enumerate() {
for node_id_2 in terminals.iter().skip(i + 1) {
closure.insert(
(graph.to_index(*node_id_1), graph.to_index(*node_id_2)),
compute_shortest_path_length(graph, *node_id_1, *node_id_2),
);
}
}
closure
}
fn subgraph_edges_from_metric_closure<G>(
graph: G,
minimum_spanning_closure: G,
) -> (Vec<Edge<G>>, Subgraph<G>)
where
G: GraphBase
+ NodeCompactIndexable
+ IntoEdgeReferences
+ IntoNodeIdentifiers
+ GraphProp
+ IntoNodeReferences,
G::EdgeWeight: BoundedMeasure + Copy,
G::NodeId: Eq + Hash + Ord + Debug,
{
let mut retained_nodes = HashSet::new();
let mut retained_edges = Vec::new();
let (_, prev) = floyd_warshall_path(graph, |e| *e.weight()).unwrap();
for edge in minimum_spanning_closure.edge_references() {
let target = graph.to_index(edge.target());
let source = graph.to_index(edge.source());
let mut current = target;
while current != source {
if let Some(prev_node) = prev[source][current] {
retained_nodes.insert(graph.from_index(prev_node));
retained_nodes.insert(graph.from_index(current));
retained_edges.push((graph.from_index(prev_node), graph.from_index(current)));
current = prev_node;
}
}
}
(retained_edges, retained_nodes)
}
fn non_terminal_leaves<G>(graph: G, terminals: &[G::NodeId]) -> HashSet<G::NodeId>
where
G: GraphBase + IntoNodeReferences + IntoNodeIdentifiers + IntoNeighbors,
G::NodeId: Hash + Eq + Debug,
G::NodeRef: Eq + Hash,
{
let mut removed_leaves = HashSet::new();
let mut remaining_leaves = graph
.node_identifiers()
.filter(|node_id| {
graph.neighbors(*node_id).collect::<HashSet<_>>().len() == 1
&& !terminals.contains(node_id)
})
.collect::<HashSet<_>>();
while !remaining_leaves.is_empty() {
remaining_leaves = graph
.node_identifiers()
.filter(|node_id| {
!terminals.contains(node_id)
&& !removed_leaves.contains(node_id)
&& (graph
.neighbors(*node_id)
.collect::<HashSet<_>>()
.difference(&removed_leaves))
.collect::<Vec<_>>()
.len()
== 1
})
.collect::<HashSet<_>>();
removed_leaves = removed_leaves
.union(&remaining_leaves)
.cloned()
.collect::<HashSet<_>>();
}
removed_leaves
}
#[cfg(feature = "stable_graph")]
pub fn steiner_tree<N, E, Ix>(
graph: &UnGraph<N, E, Ix>,
terminals: &[NodeIndex<Ix>],
) -> StableGraph<N, E, Undirected, Ix>
where
N: Default + Clone + Eq + Hash + Debug,
E: Copy + Eq + Ord + Measure + BoundedMeasure,
Ix: IndexType,
{
let metric_closure = compute_metric_closure(&graph, terminals);
let metric_closure_graph: UnGraph<N, E, _> = UnGraph::from_edges(
metric_closure
.iter()
.map(|((node1, node2), &weight)| (*node1, *node2, weight)),
);
let minimum_spanning = UnGraph::from_elements(min_spanning_tree(&metric_closure_graph));
let (subgraph_edges, subgraph_nodes) =
subgraph_edges_from_metric_closure(graph, &minimum_spanning);
let mut graph = StableGraph::from(graph.clone());
graph.retain_edges(|graph, e| {
let edge = graph.edge_endpoints(e).unwrap();
subgraph_edges.contains(&(edge.0, edge.1)) || subgraph_edges.contains(&(edge.1, edge.0))
});
graph.retain_nodes(|_, n| subgraph_nodes.contains(&n));
let non_terminal_nodes = non_terminal_leaves(&graph, terminals);
graph.retain_nodes(|_, n| !non_terminal_nodes.contains(&n));
graph
}
#[cfg(test)]
mod test {
use alloc::vec;
use hashbrown::{HashMap, HashSet};
use super::{compute_metric_closure, non_terminal_leaves, subgraph_edges_from_metric_closure};
use crate::graph::NodeIndex;
use crate::{
algo::{min_spanning_tree, EdgeRef, UnGraph},
data::FromElements,
Graph, Undirected,
};
#[test]
fn test_compute_metric_closure() {
let mut graph = Graph::<(), i32, Undirected>::new_undirected();
let a = graph.add_node(());
let b = graph.add_node(());
let c = graph.add_node(());
let d = graph.add_node(());
let e = graph.add_node(());
let f = graph.add_node(());
graph.extend_with_edges([
(a, b, 7),
(a, f, 6),
(b, c, 1),
(b, f, 5),
(c, d, 1),
(c, e, 3),
(d, e, 1),
(d, f, 4),
(e, f, 10),
]);
let terminals = vec![a, c, e, f];
let metric_closure = compute_metric_closure(&graph, &terminals);
let metric_closure_graph: UnGraph<&str, _, _> = UnGraph::from_edges(
metric_closure
.iter()
.map(|((node1, node2), &weight)| (*node1, *node2, weight)),
);
let ref_weights = HashMap::<_, _>::from([
((0, 2), 8),
((0, 4), 10),
((0, 5), 6),
((2, 4), 2),
((2, 5), 5),
((4, 5), 5),
]);
for ((node1, node2), ref_weight) in ref_weights {
assert_eq!(metric_closure[&(node1, node2)], ref_weight);
assert_eq!(
*metric_closure_graph
.edge_weight(
metric_closure_graph
.find_edge(NodeIndex::new(node1), NodeIndex::new(node2))
.unwrap()
)
.unwrap(),
ref_weight
);
}
}
#[test]
fn test_subgraph_from_metric_closure() {
let mut graph = Graph::<(), i32, _>::new_undirected();
let a = graph.add_node(());
let b = graph.add_node(());
let c = graph.add_node(());
let d = graph.add_node(());
let e = graph.add_node(());
let f = graph.add_node(());
graph.extend_with_edges([
(a, b, 7),
(a, f, 6),
(b, c, 1),
(b, f, 5),
(c, d, 1),
(c, e, 3),
(d, e, 1),
(d, f, 4),
(e, f, 10),
]);
let terminals = vec![a, c, e, f];
let metric_closure = compute_metric_closure(&graph, &terminals);
let metric_closure_graph: UnGraph<(), _, _> = UnGraph::from_edges(
metric_closure
.iter()
.map(|((node1, node2), &weight)| (*node1 as u32, *node2 as u32, weight)),
);
let minimum_spanning = UnGraph::from_elements(min_spanning_tree(&metric_closure_graph));
let (subgraph_edges, _subgraph_nodes) =
subgraph_edges_from_metric_closure(&graph, &minimum_spanning);
graph.retain_edges(|graph, e| {
let edge = graph.edge_endpoints(e).unwrap();
subgraph_edges.contains(&(edge.0, edge.1))
});
let mut ref_graph = UnGraph::<(), _>::new_undirected();
let ref_a = ref_graph.add_node(());
let _ = ref_graph.add_node(());
let ref_c = ref_graph.add_node(());
let ref_d = ref_graph.add_node(());
let ref_e = ref_graph.add_node(());
let ref_f = ref_graph.add_node(());
ref_graph.extend_with_edges([
(ref_c, ref_d, 1),
(ref_d, ref_e, 1),
(ref_d, ref_f, 4),
(ref_a, ref_f, 6),
]);
for ref_edge in ref_graph.edge_references() {
let (edge_index, _) = graph
.find_edge_undirected(ref_edge.source(), ref_edge.target())
.unwrap();
let edge_endpoints = graph.edge_endpoints(edge_index).unwrap();
assert_eq!(graph.edge_weight(edge_index).unwrap(), ref_edge.weight());
assert_eq!(edge_endpoints.0, ref_edge.source());
assert_eq!(edge_endpoints.1, ref_edge.target());
}
}
#[test]
fn test_remove_non_terminal_nodes() {
let mut graph = Graph::<(), i32, _>::new_undirected();
let a = graph.add_node(());
let b = graph.add_node(());
let c = graph.add_node(());
let d = graph.add_node(());
let e = graph.add_node(());
let f = graph.add_node(());
graph.extend_with_edges([(a, b, 7), (b, c, 6), (c, d, 1), (d, e, 5), (e, f, 1)]);
let terminals = vec![a, c];
let non_terminal_nodes = non_terminal_leaves(&graph, &terminals);
let non_terminal_refs = HashSet::from([d, e, f]);
assert_eq!(non_terminal_refs, non_terminal_nodes);
}
}