use foldhash::fast::RandomState;
use foldhash::HashSet;
use hashbrown::HashMap;
use indexmap::IndexSet;
use petgraph::data::Build;
use petgraph::prelude::*;
use petgraph::visit::{
Data, EdgeCount, EdgeRef, GraphBase, IntoEdgeReferences, IntoNodeIdentifiers,
};
use rustworkx_core::err::ContractError;
use rustworkx_core::graph_ext::contraction::can_contract;
use rustworkx_core::graph_ext::*;
use std::convert::Infallible;
use std::fmt::Debug;
use std::hash::Hash;
/// Runs the shared generic contraction tests against `DiGraphMap<char, usize>`.
///
/// NOTE(review): `common_test!` is not defined in this file; presumably it
/// expands to a `#[test]` fn named after its first argument that calls the
/// generic function of the same name in the parent module with `G` — confirm.
mod graph_map {
use petgraph::prelude::*;
// Concrete graph type the generic tests are instantiated with.
type G = DiGraphMap<char, usize>;
common_test!(test_cycle_check_enabled, G);
common_test!(test_cycle_check_disabled, G);
common_test!(test_empty_nodes, G);
common_test!(test_unknown_nodes, G);
common_test!(test_cycle_path_len_gt_1, G);
common_test!(test_multiple_paths_would_cycle, G);
common_test!(test_replace_node_no_neighbors, G);
common_test!(test_keep_edges_multigraph, G);
common_test!(test_collapse_parallel_edges, G);
common_test!(test_replace_all_nodes, G);
}
/// Runs the shared generic contraction tests against `StableDiGraph<char, usize>`.
///
/// NOTE(review): `common_test!` is not defined in this file; presumably it
/// expands to a `#[test]` fn named after its first argument that calls the
/// generic function of the same name in the parent module with `G` — confirm.
mod stable_graph {
use petgraph::prelude::*;
// Concrete graph type the generic tests are instantiated with.
type G = StableDiGraph<char, usize>;
common_test!(test_cycle_check_enabled, G);
common_test!(test_cycle_check_disabled, G);
common_test!(test_empty_nodes, G);
common_test!(test_unknown_nodes, G);
common_test!(test_cycle_path_len_gt_1, G);
common_test!(test_multiple_paths_would_cycle, G);
common_test!(test_replace_node_no_neighbors, G);
common_test!(test_keep_edges_multigraph, G);
common_test!(test_collapse_parallel_edges, G);
common_test!(test_replace_all_nodes, G);
}
/// Contracting both ends of the path `a -> b -> c` with cycle checking enabled
/// must fail with [`ContractError::DAGWouldCycle`]: the merged node would sit
/// on both sides of `b` (`m -> b -> m`).
pub fn test_cycle_check_enabled<G>()
where
    G: Default
        + Data<NodeWeight = char, EdgeWeight = usize>
        + Build
        + ContractNodesDirected<Error = ContractError>,
    G::NodeId: Debug,
{
    let mut graph = G::default();
    let (a, b, c) = (graph.add_node('a'), graph.add_node('b'), graph.add_node('c'));
    graph.add_edge(a, b, 1);
    graph.add_edge(b, c, 2);

    let err = graph
        .contract_nodes([a, c], 'm', true)
        .expect_err("Cycle should cause return error.");
    assert!(matches!(err, ContractError::DAGWouldCycle));
}
/// Uses the same `a -> b -> c` input as the cycle-enabled case. With cycle
/// checking turned off, contracting `a` and `c` must succeed.
fn test_cycle_check_disabled<G>()
where
    G: Default
        + Data<NodeWeight = char, EdgeWeight = usize>
        + Build
        + ContractNodesDirected<Error = ContractError>,
    G::NodeId: Debug,
{
    let mut graph = G::default();
    let a = graph.add_node('a');
    let b = graph.add_node('b');
    let c = graph.add_node('c');
    for (src, dst, weight) in [(a, b, 1), (b, c, 2)] {
        graph.add_edge(src, dst, weight);
    }

    graph
        .contract_nodes([a, c], 'm', false)
        .expect("No error should be raised for a cycle when cycle check is disabled.");
}
/// Contracting an empty node set still succeeds, and an initially empty graph
/// ends up holding exactly one node afterwards.
fn test_empty_nodes<G>()
where
    G: Default
        + Data<NodeWeight = char, EdgeWeight = usize>
        + Build
        + ContractNodesDirected<Error = ContractError>,
    G::NodeId: Debug,
{
    let mut graph = G::default();
    let no_nodes: [G::NodeId; 0] = [];
    graph.contract_nodes(no_nodes, 'm', false).unwrap();
    assert_eq!(1, graph.node_count());
}
/// Contracting node ids that were already removed from the graph must not
/// fail. Afterwards only the surviving `a` and the newly added node remain.
fn test_unknown_nodes<G>()
where
    G: Default
        + Data<NodeWeight = char, EdgeWeight = usize>
        + Build
        + ContractNodesDirected<Error = ContractError>
        + NodeRemovable,
    G::NodeId: Debug + Copy,
{
    let mut graph = G::default();
    let a = graph.add_node('a');
    let b = graph.add_node('b');
    let c = graph.add_node('c');
    graph.add_edge(a, b, 1);
    graph.add_edge(b, c, 2);

    // Remove `b` and `c` so the ids handed to `contract_nodes` are stale.
    for stale in [b, c] {
        graph.remove_node(stale);
    }

    graph.contract_nodes([b, c], 'm', false).unwrap();
    assert_eq!(graph.node_count(), 2);
}
fn test_cycle_path_len_gt_1<G>()
where
G: Default
+ Data<NodeWeight = char, EdgeWeight = usize>
+ Build
+ ContractNodesDirected<Error = ContractError>
+ NodeRemovable,
G::NodeId: Debug + Copy,
{
let mut dag = G::default();
let a = dag.add_node('a');
let b = dag.add_node('b');
let c = dag.add_node('c');
let d = dag.add_node('d');
dag.add_edge(a, b, 1);
dag.add_edge(b, c, 2);
dag.add_edge(c, d, 3);
dag.add_edge(a, d, 4);
dag.contract_nodes([a, d], 'm', true)
.expect_err("Cycle should be detected.");
}
/// Contracting `a`, `d` and `f` in the DAG `a -> b -> c -> d`, `b -> e -> f`
/// would close two cycles (`m -> b -> c -> m` and `m -> b -> e -> m`).
///
/// With cycle checking enabled, the contraction must be rejected. With it
/// disabled, it must succeed, and then:
/// - every edge touching a contracted node is re-attached to the returned
///   replacement node;
/// - `b`, `c` and `e` keep their original ids.
fn test_multiple_paths_would_cycle<G>()
where
    G: Default
        + Data<NodeWeight = char, EdgeWeight = usize>
        + Build
        + ContractNodesDirected<Error = ContractError>,
    for<'b> &'b G: GraphBase<NodeId = G::NodeId>
        + Data<EdgeWeight = G::EdgeWeight>
        + IntoEdgeReferences
        + IntoNodeIdentifiers,
    G::NodeId: Eq + Hash + Debug + Copy,
{
    let mut dag = G::default();
    let a = dag.add_node('a');
    let b = dag.add_node('b');
    let c = dag.add_node('c');
    let d = dag.add_node('d');
    let e = dag.add_node('e');
    let f = dag.add_node('f');
    dag.add_edge(a, b, 1);
    dag.add_edge(b, c, 2);
    dag.add_edge(c, d, 3);
    dag.add_edge(b, e, 4);
    dag.add_edge(e, f, 5);
    let result = dag.contract_nodes([a, d, f], 'm', true);
    match result.expect_err("Cycle should cause return error.") {
        ContractError::DAGWouldCycle => (),
    }
    let m = dag
        .contract_nodes([a, d, f], 'm', false)
        .expect("Contraction should be allowed without cycle check.");
    let edge_refs: Vec<_> = dag.edge_references().collect();
    assert_eq!(edge_refs.len(), 5, "Missing expected edge!");
    // Map each expected node label to the id actually seen at that edge
    // endpoint. The labels are plain chars, not pattern bindings, so they do
    // not shadow the real ids. That lets the observed ids be checked against
    // the real ones below.
    let mut seen = HashMap::new();
    for edge_ref in edge_refs {
        let (source_label, target_label) = match *edge_ref.weight() {
            1 => ('m', 'b'),
            2 => ('b', 'c'),
            3 => ('c', 'm'),
            4 => ('b', 'e'),
            5 => ('e', 'm'),
            w => panic!("Unexpected edge weight: {w}"),
        };
        let (source, target) = (edge_ref.source(), edge_ref.target());
        assert_eq!(*seen.entry(source_label).or_insert(source), source);
        assert_eq!(*seen.entry(target_label).or_insert(target), target);
    }
    assert_eq!(seen.len(), 4, "Missing expected node!");
    assert_eq!(seen[&'m'], m, "Edges should attach to the returned node.");
    assert_eq!(seen[&'b'], b);
    assert_eq!(seen[&'c'], c);
    assert_eq!(seen[&'e'], e);
}
/// Contracting a single isolated node succeeds (even with cycle checking on)
/// and leaves exactly one node in the graph.
fn test_replace_node_no_neighbors<G>()
where
    G: Default
        + Data<NodeWeight = char, EdgeWeight = usize>
        + Build
        + ContractNodesDirected<Error = ContractError>,
    G::NodeId: Debug,
{
    let mut graph = G::default();
    let lone = graph.add_node('a');
    graph.contract_nodes([lone], 'm', true).unwrap();
    assert_eq!(1, graph.node_count());
}
/// Contracting `b` and `c` in `a -> b`, `c -> a` merges them into `m`, giving
/// `a -> m` and `m -> a`.
///
/// That 2-cycle must be refused when cycle checking is on. When it is off, the
/// contraction succeeds and both edges are kept on the returned node with
/// their original weights.
fn test_keep_edges_multigraph<G>()
where
    G: Default
        + Data<NodeWeight = char, EdgeWeight = usize>
        + Build
        + ContractNodesDirected<Error = ContractError>,
    for<'b> &'b G: GraphBase<NodeId = G::NodeId>
        + Data<EdgeWeight = G::EdgeWeight>
        + IntoEdgeReferences
        + IntoNodeIdentifiers,
    G::NodeId: Eq + Hash + Debug + Copy,
{
    let mut dag = G::default();
    let a = dag.add_node('a');
    let b = dag.add_node('b');
    let c = dag.add_node('c');
    dag.add_edge(a, b, 1);
    dag.add_edge(c, a, 2);

    let rejected = dag.contract_nodes([b, c], 'm', true);
    assert!(matches!(
        rejected.expect_err("Cycle should cause return error."),
        ContractError::DAGWouldCycle
    ));

    let merged = dag
        .contract_nodes([b, c], 'm', false)
        .expect("Contraction should be allowed without cycle check.");
    assert_eq!(dag.node_count(), 2);

    let mut actual: HashSet<_> = HashSet::default();
    for edge in dag.edge_references() {
        actual.insert((edge.source(), edge.target(), *edge.weight()));
    }
    let expected = HashSet::from_iter([(a, merged, 1), (merged, a, 2)]);
    assert_eq!(actual, expected);
}
/// `contract_nodes_simple` must merge the parallel edges that contraction
/// creates, combining their weights with the supplied closure.
///
/// Contracting `b`, `c` and `d` in the diamond `a -> {b, c, d} -> e` leaves:
/// - exactly one `a -> m` edge, weight `1 + 2 + 3`;
/// - exactly one `m -> e` edge, weight `4 + 5 + 6`.
fn test_collapse_parallel_edges<G>()
where
    G: Default + Data<NodeWeight = char, EdgeWeight = usize> + Build + ContractNodesSimpleDirected,
    for<'b> &'b G: GraphBase<NodeId = G::NodeId>
        + Data<EdgeWeight = G::EdgeWeight>
        + IntoEdgeReferences
        + IntoNodeIdentifiers,
    G::NodeId: Eq + Hash + Debug + Copy,
{
    let mut dag = G::default();
    let a = dag.add_node('a');
    let b = dag.add_node('b');
    let c = dag.add_node('c');
    let d = dag.add_node('d');
    let e = dag.add_node('e');
    dag.add_edge(a, b, 1);
    dag.add_edge(a, c, 2);
    dag.add_edge(a, d, 3);
    dag.add_edge(b, e, 4);
    dag.add_edge(c, e, 5);
    dag.add_edge(d, e, 6);
    let m = dag
        .contract_nodes_simple([b, c, d], 'm', true, |w1, w2| {
            Ok::<usize, Infallible>(w1 + w2)
        })
        .unwrap();
    assert_eq!(dag.node_count(), 3);
    let edges: Vec<_> = dag
        .edge_references()
        .map(|edge| (edge.source(), edge.target(), *edge.weight()))
        .collect();
    // Check the raw count first: collecting straight into a set would hide
    // uncollapsed duplicate edges that happen to carry equal weights.
    assert_eq!(edges.len(), 2, "Parallel edges should have been collapsed.");
    let edges: HashSet<_> = edges.into_iter().collect();
    let expected = HashSet::from_iter([(a, m, 6), (m, e, 15)]);
    assert_eq!(edges, expected);
}
/// Contracting every node of the diamond `a -> {b, c, d} -> e` leaves a graph
/// with a single node and no edges.
fn test_replace_all_nodes<G>()
where
    G: Default
        + Data<NodeWeight = char, EdgeWeight = usize>
        + Build
        + ContractNodesDirected
        + EdgeCount,
    for<'b> &'b G: GraphBase<NodeId = G::NodeId>
        + Data<EdgeWeight = G::EdgeWeight>
        + IntoEdgeReferences
        + IntoNodeIdentifiers,
    G::NodeId: Eq + Hash + Debug + Copy,
{
    let mut dag = G::default();
    let a = dag.add_node('a');
    let b = dag.add_node('b');
    let c = dag.add_node('c');
    let d = dag.add_node('d');
    let e = dag.add_node('e');
    let diamond = [(a, b, 1), (a, c, 2), (a, d, 3), (b, e, 4), (c, e, 5), (d, e, 6)];
    for (src, dst, weight) in diamond {
        dag.add_edge(src, dst, weight);
    }

    // Snapshot the ids first: the iterator borrows `dag` immutably.
    let every_node: Vec<_> = dag.node_identifiers().collect();
    dag.contract_nodes(every_node, 'm', true).unwrap();

    assert_eq!(dag.node_count(), 1);
    assert_eq!(dag.edge_count(), 0);
}
/// `can_contract` must allow merging the adjacent nodes `b` and `c` on the
/// path `a -> b -> c`.
#[test]
fn test_can_contract_without_cycle_true() {
    let mut graph = StableDiGraph::<&str, ()>::default();
    let a = graph.add_node("a");
    let b = graph.add_node("b");
    let c = graph.add_node("c");
    graph.add_edge(a, b, ());
    graph.add_edge(b, c, ());

    let to_merge: IndexSet<_, RandomState> = [b, c].iter().copied().collect();
    assert!(can_contract(&graph, &to_merge));
}
/// `can_contract` must refuse to merge `a` and `c` in `a -> b -> c -> a`.
///
/// `c` is reachable from `a` only through `b`, which is outside the merged
/// set, so merging them would create a cycle.
#[test]
fn test_can_contract_without_cycle_false() {
    let mut graph = StableDiGraph::<&str, ()>::default();
    let a = graph.add_node("a");
    let b = graph.add_node("b");
    let c = graph.add_node("c");
    graph.add_edge(a, b, ());
    graph.add_edge(b, c, ());
    graph.add_edge(c, a, ());

    let to_merge: IndexSet<_, RandomState> = [a, c].iter().copied().collect();
    assert!(!can_contract(&graph, &to_merge));
}