use crate::algorithms::connectivity::is_bipartite;
use crate::base::{EdgeWeight, Graph, IndexType, Node};
use crate::error::{GraphError, Result};
use std::collections::{HashMap, HashSet};
use std::hash::Hash;
/// Result of [`maximum_bipartite_matching`].
#[derive(Debug, Clone)]
pub struct BipartiteMatching<N: Node> {
    /// Matched pairs as produced by [`maximum_bipartite_matching`]: each key is a
    /// left-side node (colour `0`) and each value is its right-side partner.
    pub matching: HashMap<N, N>,
    /// Number of matched pairs (set to `matching.len()` by the constructor above).
    pub size: usize,
}
/// Finds a maximum-cardinality matching in a bipartite graph with Kuhn's
/// augmenting-path algorithm.
///
/// `coloring` assigns each node to a side: nodes with colour `0` form the left
/// side and are used as search roots, and every other colour is treated as the
/// right side. Edges between two nodes of the same colour are ignored.
///
/// The returned [`BipartiteMatching::matching`] maps each matched left node to
/// its right partner, and `size` is the number of matched pairs. The number of
/// pairs is maximum. Which particular pairs are chosen may vary between runs,
/// because left nodes are visited in `HashMap` iteration order.
#[allow(dead_code)]
pub fn maximum_bipartite_matching<N, E, Ix>(
    graph: &Graph<N, E, Ix>,
    coloring: &HashMap<N, u8>,
) -> BipartiteMatching<N>
where
    N: Node + std::fmt::Debug,
    E: EdgeWeight,
    Ix: petgraph::graph::IndexType,
{
    // Augmenting paths are always started from the left side, so the right-side
    // node list (and any node-index lookup) is not needed.
    let left_nodes: Vec<N> = coloring
        .iter()
        .filter(|&(_, &color)| color == 0)
        .map(|(node, _)| node.clone())
        .collect();

    let mut matching: HashMap<N, N> = HashMap::new();
    let mut reverse_matching: HashMap<N, N> = HashMap::new();
    for left_node in &left_nodes {
        if matching.contains_key(left_node) {
            continue;
        }
        // Each root gets its own visited set: every search is an independent DFS.
        let mut visited = HashSet::new();
        augment_path(
            graph,
            left_node,
            &mut matching,
            &mut reverse_matching,
            &mut visited,
            coloring,
        );
    }

    BipartiteMatching {
        size: matching.len(),
        matching,
    }
}
/// One depth-first search step of Kuhn's algorithm, starting at left node `node`.
///
/// `node` is recorded in `visited` first. Neighbours carrying the same colour as
/// `node` are skipped. A neighbour that no left node owns yet (per
/// `reverse_matching`) is matched to `node` immediately. An owned neighbour is
/// taken over when its current owner has not been visited in this search and can
/// be re-matched elsewhere by a recursive call.
///
/// Returns `true` after updating both `matching` (left -> right) and
/// `reverse_matching` (right -> left). Returns `false` if no augmenting path was
/// found, including when `graph.neighbors(node)` returns an error.
#[allow(dead_code)]
fn augment_path<N, E, Ix>(
    graph: &Graph<N, E, Ix>,
    node: &N,
    matching: &mut HashMap<N, N>,
    reverse_matching: &mut HashMap<N, N>,
    visited: &mut HashSet<N>,
    coloring: &HashMap<N, u8>,
) -> bool
where
    N: Node + std::fmt::Debug,
    E: EdgeWeight,
    Ix: petgraph::graph::IndexType,
{
    visited.insert(node.clone());

    let candidates = match graph.neighbors(node) {
        Ok(list) => list,
        Err(_) => return false,
    };

    let own_color = coloring.get(node);
    for candidate in candidates {
        // Same side of the bipartition: not a usable edge.
        if coloring.get(&candidate) == own_color {
            continue;
        }

        let current_owner = match reverse_matching.get(&candidate).cloned() {
            None => {
                // Free right node: match directly.
                matching.insert(node.clone(), candidate.clone());
                reverse_matching.insert(candidate, node.clone());
                return true;
            }
            Some(owner) => owner,
        };

        if visited.contains(&current_owner) {
            continue;
        }
        // Try to move the current owner elsewhere, freeing `candidate` for `node`.
        if augment_path(
            graph,
            &current_owner,
            matching,
            reverse_matching,
            visited,
            coloring,
        ) {
            matching.insert(node.clone(), candidate.clone());
            reverse_matching.insert(candidate, node.clone());
            return true;
        }
    }

    false
}
/// Computes a minimum-weight perfect matching of a bipartite graph.
///
/// The graph is 2-coloured with [`is_bipartite`]. Nodes of colour `0` form the left
/// partition and all other nodes form the right partition. Edge weights are converted
/// to `f64`. A left/right pair for which `graph.edge_weight` returns an error is
/// treated as having infinite cost (no edge).
///
/// Instances with at most six nodes per side are solved by enumerating every
/// permutation. Larger ones use the Hungarian (Kuhn–Munkres) algorithm in `O(n^3)`.
/// Both strategies are exact.
///
/// # Returns
///
/// `(total_weight, pairs)` with one `(left, right)` pair per left node. A graph with
/// empty partitions yields `(0.0, vec![])`. If the partitions are non-empty but no
/// perfect matching with finite total weight exists, the result is
/// `(f64::INFINITY, vec![])`.
///
/// # Errors
///
/// Returns [`GraphError::InvalidGraph`] if the graph is not bipartite or if the two
/// partitions have different sizes.
///
/// NOTE(review): presumably `is_bipartite` colours each connected component
/// independently, so for a disconnected graph the partition sizes depend on that
/// choice — confirm against its implementation.
#[allow(dead_code)]
pub fn minimum_weight_bipartite_matching<N, E, Ix>(
    graph: &Graph<N, E, Ix>,
) -> Result<(f64, Vec<(N, N)>)>
where
    N: Node + Clone + Hash + Eq + std::fmt::Debug,
    E: EdgeWeight + Into<f64> + Clone,
    Ix: IndexType,
{
    // Largest side size for which permutation enumeration (n! candidates) is used.
    const BRUTE_FORCE_MAX_SIDE: usize = 6;

    let bipartite_result = is_bipartite(graph);
    if !bipartite_result.is_bipartite {
        return Err(GraphError::InvalidGraph(
            "Graph is not bipartite".to_string(),
        ));
    }
    let coloring = bipartite_result.coloring;
    let mut left_nodes = Vec::new();
    let mut right_nodes = Vec::new();
    for (node, &color) in &coloring {
        if color == 0 {
            left_nodes.push(node.clone());
        } else {
            right_nodes.push(node.clone());
        }
    }
    let n_left = left_nodes.len();
    let n_right = right_nodes.len();
    if n_left != n_right {
        return Err(GraphError::InvalidGraph(
            "Bipartite graph must have equal number of nodes in each partition for perfect matching".to_string()
        ));
    }
    if n_left == 0 {
        return Ok((0.0, vec![]));
    }

    // Missing edges stay at +inf so they can never be part of a finite matching.
    let mut cost_matrix = vec![vec![f64::INFINITY; n_right]; n_left];
    for (i, left_node) in left_nodes.iter().enumerate() {
        for (j, right_node) in right_nodes.iter().enumerate() {
            if let Ok(weight) = graph.edge_weight(left_node, right_node) {
                cost_matrix[i][j] = weight.into();
            }
        }
    }

    if n_left <= BRUTE_FORCE_MAX_SIDE {
        return minimum_weight_matching_bruteforce(&left_nodes, &right_nodes, &cost_matrix);
    }

    // Exact O(n^3) solver. A greedy heuristic here would not give the minimum weight.
    match hungarian_assignment(&cost_matrix) {
        Some(assignment) => {
            let total = assignment
                .iter()
                .enumerate()
                .fold(0.0_f64, |acc, (row, &col)| acc + cost_matrix[row][col]);
            let pairs = left_nodes
                .iter()
                .zip(&assignment)
                .map(|(left, &col)| (left.clone(), right_nodes[col].clone()))
                .collect();
            Ok((total, pairs))
        }
        None => Ok((f64::INFINITY, Vec::new())),
    }
}

/// Solves the square assignment problem on `cost` with the Hungarian algorithm
/// (shortest augmenting paths with vertex potentials) in `O(n^3)` time.
///
/// Returns `Some(assignment)`, where `assignment[row]` is the column given to `row`
/// in a minimum-cost perfect assignment. Returns `None` if every perfect assignment
/// needs a non-finite entry. Entries equal to `+inf` or NaN act as forbidden pairs.
#[allow(clippy::needless_range_loop)] // several parallel arrays share the column index
fn hungarian_assignment(cost: &[Vec<f64>]) -> Option<Vec<usize>> {
    let n = cost.len();
    // All arrays are 1-based. Column 0 is a virtual column that anchors the row
    // currently being inserted.
    let mut row_potential = vec![0.0_f64; n + 1];
    let mut col_potential = vec![0.0_f64; n + 1];
    // owner[col] = 1-based row assigned to `col`, or 0 while the column is free.
    let mut owner = vec![0usize; n + 1];
    // way[col] = previous column on the alternating path that reached `col`.
    let mut way = vec![0usize; n + 1];

    for row in 1..=n {
        owner[0] = row;
        let mut col = 0usize;
        let mut min_slack = vec![f64::INFINITY; n + 1];
        let mut visited = vec![false; n + 1];

        // Grow the alternating tree until it reaches a free column.
        loop {
            visited[col] = true;
            let current_row = owner[col];
            let mut delta = f64::INFINITY;
            let mut next_col = 0usize;
            for j in 1..=n {
                if visited[j] {
                    continue;
                }
                // NaN compares false below, so NaN entries behave like +inf.
                let reduced =
                    cost[current_row - 1][j - 1] - row_potential[current_row] - col_potential[j];
                if reduced < min_slack[j] {
                    min_slack[j] = reduced;
                    way[j] = col;
                }
                if min_slack[j] < delta {
                    delta = min_slack[j];
                    next_col = j;
                }
            }
            if !delta.is_finite() {
                // No finite edge leaves the tree, so the rows in it cannot all be
                // matched (Hall's condition fails). This also guards against `-inf`.
                return None;
            }
            for j in 0..=n {
                if visited[j] {
                    row_potential[owner[j]] += delta;
                    col_potential[j] -= delta;
                } else {
                    min_slack[j] -= delta;
                }
            }
            col = next_col;
            if owner[col] == 0 {
                break;
            }
        }

        // Flip the alternating path back to the virtual column.
        while col != 0 {
            let prev = way[col];
            owner[col] = owner[prev];
            col = prev;
        }
    }

    let mut assignment = vec![0usize; n];
    for (col, &row) in owner.iter().enumerate().skip(1) {
        assignment[row - 1] = col - 1;
    }
    Some(assignment)
}
/// Exhaustively finds the cheapest one-to-one assignment of `left_nodes` to
/// `right_nodes` by trying every permutation of the right-side indices.
///
/// `cost_matrix[i][j]` is the cost of pairing `left_nodes[i]` with `right_nodes[j]`.
/// The search is `O(n! · n)`, so it is only suitable for very small `n`.
///
/// Returns the lowest total found, together with its `(left, right)` pairs. When no
/// permutation has a total strictly below `f64::INFINITY` (for example, every
/// assignment uses an infinite or NaN entry), the result is `(f64::INFINITY, vec![])`.
#[allow(dead_code)]
fn minimum_weight_matching_bruteforce<N>(
    left_nodes: &[N],
    right_nodes: &[N],
    cost_matrix: &[Vec<f64>],
) -> Result<(f64, Vec<(N, N)>)>
where
    N: Node + Clone + std::fmt::Debug,
{
    let size = left_nodes.len();
    let mut order: Vec<usize> = (0..size).collect();
    let mut best_total = f64::INFINITY;
    let mut best_pairs: Vec<(N, N)> = Vec::new();

    let mut more_to_try = true;
    while more_to_try {
        // Sum left to right, exactly as the costs are laid out.
        let total = order
            .iter()
            .enumerate()
            .fold(0.0_f64, |acc, (row, &col)| acc + cost_matrix[row][col]);

        if total < best_total {
            best_total = total;
            best_pairs = left_nodes
                .iter()
                .zip(&order)
                .map(|(left, &col)| (left.clone(), right_nodes[col].clone()))
                .collect();
        }

        more_to_try = next_permutation(&mut order);
    }

    Ok((best_total, best_pairs))
}
/// Heuristic assignment: each left node, in order, takes the cheapest right node
/// that is still free.
///
/// This runs in `O(n^2)` but is not guaranteed to be optimal. Ties go to the lowest
/// column index. Entries that are `+inf` or NaN are never chosen, so a left node with
/// no finite option left is skipped and the returned matching may be partial.
///
/// Assumes `right_nodes` and every row of `cost_matrix` have at least
/// `left_nodes.len()` entries.
#[allow(dead_code)]
fn minimum_weight_matching_greedy<N>(
    left_nodes: &[N],
    right_nodes: &[N],
    cost_matrix: &[Vec<f64>],
) -> Result<(f64, Vec<(N, N)>)>
where
    N: Node + Clone + std::fmt::Debug,
{
    let size = left_nodes.len();
    let mut taken = vec![false; size];
    let mut pairs = Vec::new();
    let mut total_cost = 0.0;

    for (row, left) in left_nodes.iter().enumerate() {
        // Strictly-smaller comparison against a running best that starts at +inf:
        // the first minimum wins, and non-finite / NaN costs are never selected.
        let choice = (0..size)
            .filter(|&col| !taken[col])
            .fold(None, |best: Option<(usize, f64)>, col| {
                let threshold = best.map_or(f64::INFINITY, |(_, cost)| cost);
                let cost = cost_matrix[row][col];
                if cost < threshold {
                    Some((col, cost))
                } else {
                    best
                }
            });

        if let Some((col, cost)) = choice {
            taken[col] = true;
            total_cost += cost;
            pairs.push((left.clone(), right_nodes[col].clone()));
        }
    }

    Ok((total_cost, pairs))
}
/// Rearranges `perm` into the next lexicographically greater permutation.
///
/// Returns `true` if `perm` was advanced. Returns `false` and leaves `perm` untouched
/// when it is already the last permutation (non-increasing order). This includes
/// empty and single-element slices, which previously underflowed `len - 1`.
/// Starting from sorted order and calling this until it returns `false` visits every
/// distinct permutation exactly once.
#[allow(dead_code)]
fn next_permutation(perm: &mut [usize]) -> bool {
    // Pivot: rightmost position whose value is smaller than its successor.
    // `windows(2)` yields nothing for slices shorter than 2.
    let pivot = match perm.windows(2).rposition(|pair| pair[0] < pair[1]) {
        Some(k) => k,
        None => return false,
    };

    // Rightmost element greater than the pivot value. It lies after the pivot,
    // because `perm[pivot + 1]` already qualifies.
    let pivot_value = perm[pivot];
    let successor = perm
        .iter()
        .rposition(|&x| x > pivot_value)
        .expect("perm[pivot + 1] is greater than the pivot value");

    perm.swap(pivot, successor);
    // The suffix is non-increasing; reversing it gives its smallest arrangement.
    perm[pivot + 1..].reverse();
    true
}
/// A matching returned by [`maximum_cardinality_matching`] and [`maximal_matching`].
#[derive(Debug, Clone)]
pub struct MaximumMatching<N: Node> {
    /// Matched node pairs; within a matching produced by this module no node appears twice.
    pub matching: Vec<(N, N)>,
    /// Number of matched pairs (both producers set it to `matching.len()`).
    pub size: usize,
}
/// Computes a maximum-cardinality matching of a general (not necessarily
/// bipartite) graph with Edmonds' blossom algorithm, in `O(V^3)` time.
///
/// Adjacency comes from `graph.neighbors` for every node in `graph.nodes()`. Each
/// reported neighbour relation is treated as an undirected edge and self-loops are
/// ignored. A node whose neighbour lookup fails contributes no edges of its own.
/// Each matched pair is reported once, as `(a, b)` with `a` appearing before `b` in
/// `graph.nodes()`. Pairs are ordered by the position of `a`.
#[allow(dead_code)]
pub fn maximum_cardinality_matching<N, E, Ix>(graph: &Graph<N, E, Ix>) -> MaximumMatching<N>
where
    N: Node + Clone + std::fmt::Debug,
    E: EdgeWeight,
    Ix: IndexType,
{
    let nodes: Vec<N> = graph.nodes().into_iter().cloned().collect();
    let n = nodes.len();
    let node_to_idx: HashMap<N, usize> = nodes
        .iter()
        .enumerate()
        .map(|(i, node)| (node.clone(), i))
        .collect();

    // Symmetric index-based adjacency. Duplicate entries are harmless to the search.
    let mut adjacency: Vec<Vec<usize>> = vec![Vec::new(); n];
    for (i, node) in nodes.iter().enumerate() {
        if let Ok(neighbors) = graph.neighbors(node) {
            for neighbor in neighbors {
                match node_to_idx.get(&neighbor) {
                    Some(&j) if j != i => {
                        adjacency[i].push(j);
                        adjacency[j].push(i);
                    }
                    _ => {}
                }
            }
        }
    }

    let mate = blossom_maximum_matching(&adjacency);
    let matching: Vec<(N, N)> = mate
        .iter()
        .enumerate()
        .filter(|&(i, &j)| j != BLOSSOM_NIL && i < j)
        .map(|(i, &j)| (nodes[i].clone(), nodes[j].clone()))
        .collect();

    MaximumMatching {
        size: matching.len(),
        matching,
    }
}

/// Marker for "no vertex" in the index arrays of the blossom algorithm.
const BLOSSOM_NIL: usize = usize::MAX;

/// Edmonds' blossom algorithm on an index-based, symmetric adjacency list.
///
/// Returns `mate`, where `mate[v]` is the vertex matched to `v`, or [`BLOSSOM_NIL`]
/// if `v` is unmatched.
fn blossom_maximum_matching(adjacency: &[Vec<usize>]) -> Vec<usize> {
    let n = adjacency.len();
    let mut mate = vec![BLOSSOM_NIL; n];
    let mut parent = vec![BLOSSOM_NIL; n];
    let mut base: Vec<usize> = (0..n).collect();

    for root in 0..n {
        if mate[root] != BLOSSOM_NIL {
            continue;
        }
        if let Some(end) =
            blossom_find_augmenting_path(adjacency, &mate, &mut parent, &mut base, root)
        {
            // Flip matched/unmatched edges along the path from `end` back to `root`.
            let mut v = end;
            while v != BLOSSOM_NIL {
                let pv = parent[v];
                let next = mate[pv];
                mate[v] = pv;
                mate[pv] = v;
                v = next;
            }
        }
    }
    mate
}

/// Grows an alternating BFS tree from the unmatched vertex `root`, contracting odd
/// cycles (blossoms) as they are found. Returns the free vertex at the end of an
/// augmenting path. The path itself is encoded in `parent` together with `mate`.
fn blossom_find_augmenting_path(
    adjacency: &[Vec<usize>],
    mate: &[usize],
    parent: &mut [usize],
    base: &mut [usize],
    root: usize,
) -> Option<usize> {
    let n = adjacency.len();
    let mut in_tree = vec![false; n];
    parent.fill(BLOSSOM_NIL);
    for (v, b) in base.iter_mut().enumerate() {
        *b = v;
    }

    in_tree[root] = true;
    let mut queue = std::collections::VecDeque::new();
    queue.push_back(root);

    while let Some(v) = queue.pop_front() {
        for &to in &adjacency[v] {
            if base[v] == base[to] || mate[v] == to {
                continue;
            }
            // `to` is an "even" (outer) vertex: the root, or matched to a vertex that
            // already has a tree parent. The edge then closes an odd cycle.
            let to_is_even =
                to == root || (mate[to] != BLOSSOM_NIL && parent[mate[to]] != BLOSSOM_NIL);
            if to_is_even {
                let blossom_base = blossom_lca(mate, parent, base, v, to);
                let mut in_blossom = vec![false; n];
                blossom_mark_path(mate, parent, base, &mut in_blossom, v, blossom_base, to);
                blossom_mark_path(mate, parent, base, &mut in_blossom, to, blossom_base, v);
                for u in 0..n {
                    if in_blossom[base[u]] {
                        base[u] = blossom_base;
                        if !in_tree[u] {
                            in_tree[u] = true;
                            queue.push_back(u);
                        }
                    }
                }
            } else if parent[to] == BLOSSOM_NIL {
                parent[to] = v;
                if mate[to] == BLOSSOM_NIL {
                    return Some(to);
                }
                let partner = mate[to];
                in_tree[partner] = true;
                queue.push_back(partner);
            }
        }
    }
    None
}

/// Lowest common ancestor, by blossom base, of vertices `a` and `b` in the
/// alternating tree.
fn blossom_lca(mate: &[usize], parent: &[usize], base: &[usize], mut a: usize, mut b: usize) -> usize {
    let mut on_path = vec![false; mate.len()];
    loop {
        a = base[a];
        on_path[a] = true;
        if mate[a] == BLOSSOM_NIL {
            break;
        }
        a = parent[mate[a]];
    }
    loop {
        b = base[b];
        if on_path[b] {
            return b;
        }
        b = parent[mate[b]];
    }
}

/// Marks the blossom bases on the path from `v` up to `blossom_base`. It also
/// re-points `parent` links so the contracted blossom can later be traversed in
/// either direction.
fn blossom_mark_path(
    mate: &[usize],
    parent: &mut [usize],
    base: &[usize],
    in_blossom: &mut [bool],
    mut v: usize,
    blossom_base: usize,
    mut child: usize,
) {
    while base[v] != blossom_base {
        in_blossom[base[v]] = true;
        in_blossom[base[mate[v]]] = true;
        parent[v] = child;
        child = mate[v];
        v = parent[mate[v]];
    }
}
/// Computes a maximal (not necessarily maximum) matching by scanning edges greedily.
///
/// Edges are visited in the order returned by `graph.edges()`. An edge is taken
/// whenever neither endpoint is already matched. Self-loops are skipped, because an
/// edge from a node to itself cannot belong to a matching. The result is maximal: no
/// further non-loop edge can be added without sharing an endpoint.
#[allow(dead_code)]
pub fn maximal_matching<N, E, Ix>(graph: &Graph<N, E, Ix>) -> MaximumMatching<N>
where
    N: Node + Clone + std::fmt::Debug,
    E: EdgeWeight,
    Ix: IndexType,
{
    let mut matching = Vec::new();
    let mut matched_nodes = HashSet::new();
    for edge in graph.edges() {
        if edge.source == edge.target {
            continue;
        }
        if matched_nodes.contains(&edge.source) || matched_nodes.contains(&edge.target) {
            continue;
        }
        matching.push((edge.source.clone(), edge.target.clone()));
        matched_nodes.insert(edge.source);
        matched_nodes.insert(edge.target);
    }
    MaximumMatching {
        size: matching.len(),
        matching,
    }
}
/// Solves the stable marriage problem with the Gale–Shapley algorithm, with the
/// left side proposing.
///
/// `left_prefs[i]` lists right indices in left `i`'s order of preference (most
/// preferred first). `right_prefs[j]` lists left indices the same way. Both sides
/// must have the same size `n`, and every list must be a permutation of `0..n`.
///
/// # Returns
///
/// `(left, right)` pairs ordered by left index. An empty input yields an empty list.
///
/// # Errors
///
/// Returns [`GraphError::InvalidGraph`] if the two sides differ in size, or if any
/// preference list has the wrong length or is not a permutation of `0..n`.
#[allow(dead_code)]
pub fn stable_marriage(
    left_prefs: &[Vec<usize>],
    right_prefs: &[Vec<usize>],
) -> Result<Vec<(usize, usize)>> {
    let size = left_prefs.len();
    if right_prefs.len() != size {
        return Err(GraphError::InvalidGraph(
            "Left and right sets must have equal size".to_string(),
        ));
    }
    if size == 0 {
        return Ok(Vec::new());
    }

    // Checks every list in `lists`, in order. `side` is the label used in messages.
    let validate = |side: &str, lists: &[Vec<usize>]| -> Result<()> {
        for (i, prefs) in lists.iter().enumerate() {
            if prefs.len() != size {
                return Err(GraphError::InvalidGraph(format!(
                    "{side} preference list {i} has wrong length"
                )));
            }
            // With length `size`, "all in range and pairwise distinct" is exactly
            // "is a permutation of 0..size".
            let mut seen = vec![false; size];
            let is_permutation = prefs
                .iter()
                .all(|&p| p < size && !std::mem::replace(&mut seen[p], true));
            if !is_permutation {
                return Err(GraphError::InvalidGraph(format!(
                    "{side} preference list {i} is not a valid permutation"
                )));
            }
        }
        Ok(())
    };
    validate("Left", left_prefs)?;
    validate("Right", right_prefs)?;

    // rank[r][l] = position of left `l` in right `r`'s list (lower is better).
    let mut rank = vec![vec![0usize; size]; size];
    for (r, prefs) in right_prefs.iter().enumerate() {
        for (position, &l) in prefs.iter().enumerate() {
            rank[r][l] = position;
        }
    }

    let mut engaged_to: Vec<Option<usize>> = vec![None; size]; // left -> right
    let mut suitor: Vec<Option<usize>> = vec![None; size]; // right -> left
    let mut next_choice = vec![0usize; size];
    let mut queue: std::collections::VecDeque<usize> = (0..size).collect();

    while let Some(proposer) = queue.pop_front() {
        let target = match left_prefs[proposer].get(next_choice[proposer]) {
            Some(&right) => right,
            None => continue, // proposer has exhausted their list
        };
        next_choice[proposer] += 1;

        match suitor[target] {
            None => {
                engaged_to[proposer] = Some(target);
                suitor[target] = Some(proposer);
            }
            Some(current) if rank[target][proposer] >= rank[target][current] => {
                // Rejected: try the next choice later.
                queue.push_back(proposer);
            }
            Some(current) => {
                // `target` trades up; the displaced left re-enters the queue.
                engaged_to[proposer] = Some(target);
                suitor[target] = Some(proposer);
                engaged_to[current] = None;
                queue.push_back(current);
            }
        }
    }

    Ok(engaged_to
        .iter()
        .enumerate()
        .filter_map(|(left, partner)| partner.map(|right| (left, right)))
        .collect())
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::Result as GraphResult;
    use crate::generators::create_graph;

    #[test]
    fn test_maximum_bipartite_matching() -> GraphResult<()> {
        let mut graph = create_graph::<&str, ()>();
        graph.add_edge("A", "1", ())?;
        graph.add_edge("A", "2", ())?;
        graph.add_edge("B", "2", ())?;
        graph.add_edge("B", "3", ())?;
        graph.add_edge("C", "3", ())?;
        let mut coloring = HashMap::new();
        coloring.insert("A", 0);
        coloring.insert("B", 0);
        coloring.insert("C", 0);
        coloring.insert("1", 1);
        coloring.insert("2", 1);
        coloring.insert("3", 1);
        let matching = maximum_bipartite_matching(&graph, &coloring);
        assert_eq!(matching.size, 3);
        let mut used_right = HashSet::new();
        for right in matching.matching.values() {
            assert!(!used_right.contains(right));
            used_right.insert(right);
        }
        Ok(())
    }

    #[test]
    fn test_minimum_weight_bipartite_matching() -> GraphResult<()> {
        let mut graph = create_graph::<&str, f64>();
        graph.add_edge("A", "1", 1.0)?;
        graph.add_edge("A", "2", 3.0)?;
        graph.add_edge("B", "1", 2.0)?;
        graph.add_edge("B", "2", 1.0)?;
        let (total_weight, matching) = minimum_weight_bipartite_matching(&graph)?;
        assert_eq!(total_weight, 2.0);
        assert_eq!(matching.len(), 2);
        Ok(())
    }

    #[test]
    fn test_maximum_cardinality_matching() {
        let mut graph = create_graph::<&str, ()>();
        graph.add_edge("A", "B", ()).expect("Operation failed");
        graph.add_edge("C", "D", ()).expect("Operation failed");
        graph.add_edge("E", "F", ()).expect("Operation failed");
        let matching = maximum_cardinality_matching(&graph);
        assert_eq!(matching.size, 3);
        assert_eq!(matching.matching.len(), 3);
        let mut matched_nodes = HashSet::new();
        for (u, v) in &matching.matching {
            assert!(!matched_nodes.contains(u));
            assert!(!matched_nodes.contains(v));
            matched_nodes.insert(u);
            matched_nodes.insert(v);
        }
    }

    #[test]
    fn test_maximal_matching() {
        let mut graph = create_graph::<i32, ()>();
        graph.add_edge(1, 2, ()).expect("Operation failed");
        graph.add_edge(2, 3, ()).expect("Operation failed");
        graph.add_edge(3, 1, ()).expect("Operation failed");
        let matching = maximal_matching(&graph);
        assert_eq!(matching.size, 1);
        assert_eq!(matching.matching.len(), 1);
        let mut matched_nodes = HashSet::new();
        for (u, v) in &matching.matching {
            assert!(!matched_nodes.contains(u));
            assert!(!matched_nodes.contains(v));
            matched_nodes.insert(u);
            matched_nodes.insert(v);
        }
    }

    #[test]
    fn test_stable_marriage() -> GraphResult<()> {
        // left_prefs[i] ranks right indices; right_prefs[j] ranks left indices.
        let left_prefs = vec![vec![0, 1, 2], vec![1, 0, 2], vec![0, 1, 2]];
        let right_prefs = vec![vec![2, 1, 0], vec![0, 2, 1], vec![0, 1, 2]];
        let matching = stable_marriage(&left_prefs, &right_prefs)?;
        assert_eq!(matching.len(), 3);
        let mut matched_left = HashSet::new();
        let mut matched_right = HashSet::new();
        for (left, right) in &matching {
            assert!(!matched_left.contains(left));
            assert!(!matched_right.contains(right));
            matched_left.insert(*left);
            matched_right.insert(*right);
        }
        // Left-proposing Gale–Shapley always yields the left-optimal stable matching.
        assert_eq!(matching, vec![(0, 1), (1, 2), (2, 0)]);
        Ok(())
    }

    #[test]
    fn test_stable_marriage_empty() -> GraphResult<()> {
        let left_prefs: Vec<Vec<usize>> = vec![];
        let right_prefs: Vec<Vec<usize>> = vec![];
        let matching = stable_marriage(&left_prefs, &right_prefs)?;
        assert_eq!(matching.len(), 0);
        Ok(())
    }

    #[test]
    fn test_stable_marriage_invalid_input() {
        // Sides of different sizes.
        let left_prefs = vec![vec![0]];
        let right_prefs = vec![vec![0], vec![1]];
        assert!(stable_marriage(&left_prefs, &right_prefs).is_err());

        // Preference list longer than the number of participants (n = 1).
        let left_prefs = vec![vec![0, 0]];
        let right_prefs = vec![vec![0, 1]];
        assert!(stable_marriage(&left_prefs, &right_prefs).is_err());

        // Correct length but a repeated entry: not a permutation.
        let left_prefs = vec![vec![0, 0], vec![0, 1]];
        let right_prefs = vec![vec![0, 1], vec![1, 0]];
        assert!(matches!(
            stable_marriage(&left_prefs, &right_prefs),
            Err(GraphError::InvalidGraph(_))
        ));

        // Correct length but an out-of-range index.
        let left_prefs = vec![vec![0, 2], vec![0, 1]];
        let right_prefs = vec![vec![0, 1], vec![1, 0]];
        assert!(matches!(
            stable_marriage(&left_prefs, &right_prefs),
            Err(GraphError::InvalidGraph(_))
        ));

        // Left side valid, right side has a duplicate.
        let left_prefs = vec![vec![0, 1], vec![1, 0]];
        let right_prefs = vec![vec![1, 1], vec![0, 1]];
        assert!(matches!(
            stable_marriage(&left_prefs, &right_prefs),
            Err(GraphError::InvalidGraph(_))
        ));
    }
}