use super::EquationInfo;
use std::collections::{HashMap, HashSet, VecDeque};
/// Sentinel meaning "no partner" in the `pair_eq` / `pair_var` tables.
const NIL: usize = usize::MAX;

/// Layer value meaning "not reached by the current BFS layering".
const INF: usize = usize::MAX;

/// Hopcroft–Karp maximum bipartite matching between equations (left side, indices
/// `0..n_equations`) and variables (right side, indices `0..n_variables`).
pub(super) struct HopcroftKarp {
    /// Number of equations (left-side vertices).
    n_equations: usize,
    /// `adj[eq]` lists the variable indices equation `eq` may be matched to, in the order
    /// the search tries them.
    adj: Vec<Vec<usize>>,
    /// `pair_eq[eq]` is the variable currently matched to `eq`, or `NIL`.
    pair_eq: Vec<usize>,
    /// `pair_var[var]` is the equation currently matched to `var`, or `NIL`.
    pair_var: Vec<usize>,
    /// BFS layer of each equation. The extra last slot (`n_equations`) holds the layer at
    /// which a free variable was first reached, i.e. the shortest augmenting-path length.
    dist: Vec<usize>,
}

impl HopcroftKarp {
    /// Creates a solver with an empty matching.
    ///
    /// `adj` must contain one entry per equation (`n_equations` entries), and every variable
    /// index in it must be `< n_variables`. Otherwise the search panics on an out-of-bounds
    /// index.
    pub fn new(n_equations: usize, n_variables: usize, adj: Vec<Vec<usize>>) -> Self {
        Self {
            n_equations,
            adj,
            pair_eq: vec![NIL; n_equations],
            pair_var: vec![NIL; n_variables],
            dist: vec![0; n_equations + 1],
        }
    }

    /// Grows the matching until no augmenting path remains.
    ///
    /// Returns the number of augmentations performed by this call. On a fresh solver this
    /// is the size of the maximum matching.
    pub fn max_matching(&mut self) -> usize {
        let mut augmented = 0;
        while self.bfs() {
            for eq in 0..self.n_equations {
                if self.pair_eq[eq] == NIL && self.dfs(eq) {
                    augmented += 1;
                }
            }
        }
        augmented
    }

    /// Maps the equation reached through a variable to its slot in `dist`.
    ///
    /// The slot is the equation itself, or the extra slot `n_equations` when the variable
    /// is free (`NIL`).
    fn layer_slot(&self, eq: usize) -> usize {
        if eq == NIL { self.n_equations } else { eq }
    }

    /// Layers the graph by a BFS from every free equation.
    ///
    /// Returns `true` when some free variable is reachable, i.e. at least one augmenting
    /// path exists.
    fn bfs(&mut self) -> bool {
        let mut queue = VecDeque::new();
        for eq in 0..self.n_equations {
            if self.pair_eq[eq] == NIL {
                self.dist[eq] = 0;
                queue.push_back(eq);
            } else {
                self.dist[eq] = INF;
            }
        }
        self.dist[self.n_equations] = INF;

        while let Some(eq) = queue.pop_front() {
            // Do not expand layers at or beyond the shortest augmenting path found so far.
            if self.dist[eq] >= self.dist[self.n_equations] {
                continue;
            }
            for &var in &self.adj[eq] {
                let next_eq = self.pair_var[var];
                let slot = self.layer_slot(next_eq);
                if self.dist[slot] == INF {
                    self.dist[slot] = self.dist[eq] + 1;
                    if next_eq != NIL {
                        queue.push_back(next_eq);
                    }
                }
            }
        }
        self.dist[self.n_equations] != INF
    }

    /// Searches for a layered augmenting path from the free equation `root`.
    ///
    /// If a path is found, every edge on it is flipped into the matching and `true` is
    /// returned. Equations proven to be dead ends get `dist = INF`, so later searches in the
    /// same phase skip them.
    ///
    /// The search uses an explicit stack rather than recursion. An augmenting path can be as
    /// long as the current matching, which would overflow the call stack on large systems.
    fn dfs(&mut self, root: usize) -> bool {
        // Each frame is (equation, position in `adj[equation]` of the edge being tried).
        let mut path: Vec<(usize, usize)> = vec![(root, 0)];
        while let Some(&(eq, edge)) = path.last() {
            let Some(&var) = self.adj[eq].get(edge) else {
                // Every edge of `eq` failed: prune it and resume the parent at its next edge.
                self.dist[eq] = INF;
                path.pop();
                if let Some(parent) = path.last_mut() {
                    parent.1 += 1;
                }
                continue;
            };
            let next_eq = self.pair_var[var];
            if self.dist[self.layer_slot(next_eq)] == self.dist[eq] + 1 {
                if next_eq == NIL {
                    // Reached a free variable: flip every edge on the path into the matching.
                    for &(path_eq, path_edge) in &path {
                        let path_var = self.adj[path_eq][path_edge];
                        self.pair_var[path_var] = path_eq;
                        self.pair_eq[path_eq] = path_var;
                    }
                    return true;
                }
                // Layers strictly increase along the path, so an equation never appears twice.
                path.push((next_eq, 0));
            } else if let Some(frame) = path.last_mut() {
                frame.1 += 1;
            }
        }
        false
    }

    /// Returns, for each equation, the variable it is matched to, or `None` if unmatched.
    pub fn get_equation_matching(&self) -> Vec<Option<usize>> {
        self.pair_eq
            .iter()
            .map(|&var| (var != NIL).then_some(var))
            .collect()
    }
}
/// Assigns each equation a distinct variable it may be solved for, matching as many
/// equations as possible (a maximum bipartite matching).
///
/// # Arguments
/// * `eq_infos` - the equations. Equation `i` may be matched to any variable named in
///   `eq_infos[i].all_variables` that also appears in `all_variables` and is not excluded.
/// * `all_variables` - every variable name. Positions in this slice are the variable
///   indices used internally.
/// * `exclude_from_matching` - names that must never be assigned to an equation.
///
/// # Returns
/// A map from equation index (position in `eq_infos`) to the name of its matched variable.
/// Equations that could not be matched are absent.
///
/// The work proceeds in three stages:
/// 1. Variables left with exactly one usable equation are pinned to it, repeatedly, until
///    nothing more can be pinned.
/// 2. Hopcroft–Karp matches the rest. Every unpinned equation lists its `lhs_variable`
///    first, so the search tries it before the equation's other variables.
/// 3. A best-effort repair pass tries to place any variable that is still unmatched. It
///    visits those variables in `all_variables` order, so the outcome is deterministic.
pub(super) fn find_maximum_matching(
    eq_infos: &[EquationInfo],
    all_variables: &[String],
    exclude_from_matching: &HashSet<String>,
) -> HashMap<usize, String> {
    let n_equations = eq_infos.len();
    let n_variables = all_variables.len();
    let var_to_idx: HashMap<&String, usize> = all_variables
        .iter()
        .enumerate()
        .map(|(i, v)| (v, i))
        .collect();

    // For every variable, the equations it may be matched to (each equation listed once).
    let mut reverse_adj: Vec<Vec<usize>> = vec![Vec::new(); n_variables];
    for (eq_idx, info) in eq_infos.iter().enumerate() {
        let mut seen: Vec<usize> = Vec::new();
        for var in &info.all_variables {
            if !exclude_from_matching.contains(var)
                && let Some(&var_idx) = var_to_idx.get(var)
                && !seen.contains(&var_idx)
            {
                seen.push(var_idx);
                reverse_adj[var_idx].push(eq_idx);
            }
        }
    }

    let (forced_eq_to_var, forced_var_to_eq) = pin_single_candidate_variables(&reverse_adj);

    // Variable index an unpinned equation may use, or `None` if the name is excluded,
    // unknown, or already pinned to some equation.
    let usable = |var: &String| -> Option<usize> {
        if exclude_from_matching.contains(var) {
            return None;
        }
        var_to_idx
            .get(var)
            .copied()
            .filter(|idx| !forced_var_to_eq.contains_key(idx))
    };

    // NOTE(review): if `EquationInfo::all_variables` is a hash set, its iteration order (and
    // thus the order in which non-LHS candidates are tried) may vary between runs — confirm.
    let adj: Vec<Vec<usize>> = eq_infos
        .iter()
        .enumerate()
        .map(|(eq_idx, info)| {
            if let Some(&forced_var) = forced_eq_to_var.get(&eq_idx) {
                return vec![forced_var];
            }
            let mut candidates: Vec<usize> = Vec::new();
            // LHS variable first so Hopcroft–Karp tries it before the others.
            for var in info.lhs_variable.iter().chain(&info.all_variables) {
                if let Some(var_idx) = usable(var)
                    && !candidates.contains(&var_idx)
                {
                    candidates.push(var_idx);
                }
            }
            candidates
        })
        .collect();

    let mut hk = HopcroftKarp::new(n_equations, n_variables, adj);
    hk.max_matching();
    let mut result: HashMap<usize, String> = hk
        .get_equation_matching()
        .into_iter()
        .enumerate()
        .filter_map(|(eq_idx, var_idx)| var_idx.map(|v| (eq_idx, all_variables[v].clone())))
        .collect();

    // Unmatched variables in `all_variables` order, deduplicated. Iterating a HashSet
    // difference here would make the repair pass below visit them in random order.
    let unmatched_vars: Vec<String> = {
        let matched: HashSet<&String> = result.values().collect();
        let mut seen: HashSet<&String> = HashSet::new();
        all_variables
            .iter()
            .filter(|v| !matched.contains(v) && seen.insert(*v))
            .cloned()
            .collect()
    };
    if !unmatched_vars.is_empty() {
        result = fix_unmatched_variables(&result, &unmatched_vars, &reverse_adj, all_variables);
    }
    result
}

/// Pins every variable that has exactly one equation left that is not yet pinned, and
/// repeats until a full pass pins nothing new.
///
/// `reverse_adj[var]` lists the equations variable `var` may be matched to. Returns the
/// pinned pairs as `(equation -> variable, variable -> equation)` maps. Variables are
/// scanned in index order, which decides ties when two variables compete for the same
/// equation.
///
/// Matching a variable to its only remaining equation never reduces the size of the
/// maximum matching (the Karp–Sipser degree-one rule).
fn pin_single_candidate_variables(
    reverse_adj: &[Vec<usize>],
) -> (HashMap<usize, usize>, HashMap<usize, usize>) {
    let mut eq_to_var: HashMap<usize, usize> = HashMap::new();
    let mut var_to_eq: HashMap<usize, usize> = HashMap::new();
    let mut changed = true;
    while changed {
        changed = false;
        for (var_idx, var_eqs) in reverse_adj.iter().enumerate() {
            if var_to_eq.contains_key(&var_idx) {
                continue;
            }
            // The sole equation not yet pinned, found without allocating a Vec.
            let sole_eq = {
                let mut open = var_eqs
                    .iter()
                    .copied()
                    .filter(|eq| !eq_to_var.contains_key(eq));
                match (open.next(), open.next()) {
                    (Some(eq_idx), None) => Some(eq_idx),
                    _ => None,
                }
            };
            if let Some(eq_idx) = sole_eq {
                eq_to_var.insert(eq_idx, var_idx);
                var_to_eq.insert(var_idx, eq_idx);
                changed = true;
            }
        }
    }
    (eq_to_var, var_to_eq)
}
/// Best-effort repair pass that tries to give each variable in `unmatched_vars` an equation.
///
/// Variables are processed in slice order. For each one, the equations that mention it
/// (`reverse_adj[var_idx]`, in order) are tried:
/// 1. An equation with no variable in the matching takes the unmatched variable directly.
/// 2. Otherwise, the equation's current variable may also appear in another equation that
///    has no variable yet. In that case the current variable moves there and the unmatched
///    variable takes its place.
///
/// The first candidate that succeeds wins. Some variables are skipped, so no variable is
/// ever assigned to two equations:
/// - names not found in `all_variables`;
/// - variables that already own an equation, including repeated entries in
///   `unmatched_vars`.
///
/// `reverse_adj` is indexed by position in `all_variables`. `initial_matching` is left
/// untouched; the repaired matching is returned.
fn fix_unmatched_variables(
    initial_matching: &HashMap<usize, String>,
    unmatched_vars: &[String],
    reverse_adj: &[Vec<usize>],
    all_variables: &[String],
) -> HashMap<usize, String> {
    let mut result = initial_matching.clone();
    let var_to_idx: HashMap<&String, usize> = all_variables
        .iter()
        .enumerate()
        .map(|(i, v)| (v, i))
        .collect();
    // Indices of variables that currently own an equation.
    let mut matched_var_idxs: HashSet<usize> = result
        .values()
        .filter_map(|name| var_to_idx.get(name).copied())
        .collect();

    for unmatched_var in unmatched_vars {
        let Some(&unmatched_var_idx) = var_to_idx.get(unmatched_var) else {
            continue;
        };
        if matched_var_idxs.contains(&unmatched_var_idx) {
            continue;
        }
        let Some(candidate_eqs) = reverse_adj.get(unmatched_var_idx) else {
            continue;
        };
        for &candidate_eq in candidate_eqs {
            let current_var_idx = match result.get(&candidate_eq) {
                None => {
                    // Free equation: take it directly.
                    result.insert(candidate_eq, unmatched_var.clone());
                    matched_var_idxs.insert(unmatched_var_idx);
                    break;
                }
                Some(name) => match var_to_idx.get(name) {
                    Some(&idx) => idx,
                    None => continue,
                },
            };
            // Look for another free equation the current owner could move to.
            let free_eq = reverse_adj.get(current_var_idx).and_then(|eqs| {
                eqs.iter()
                    .copied()
                    .find(|&eq| eq != candidate_eq && !result.contains_key(&eq))
            });
            if let Some(free_eq) = free_eq {
                result.insert(candidate_eq, unmatched_var.clone());
                result.insert(free_eq, all_variables[current_var_idx].clone());
                matched_var_idxs.insert(unmatched_var_idx);
                break;
            }
        }
    }
    result
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::ast::{ComponentRefPart, ComponentReference, Equation, Expression, Token};

    /// Builds a non-local, unsubscripted component reference to `name`.
    fn make_var(name: &str) -> Expression {
        Expression::ComponentReference(ComponentReference {
            local: false,
            parts: vec![ComponentRefPart {
                ident: Token {
                    text: name.to_string(),
                    ..Default::default()
                },
                subs: None,
            }],
        })
    }

    /// Checks that `matching` is a valid matching over `adj`.
    ///
    /// It must have one entry per equation, every matched pair must be an edge of `adj`, and
    /// no variable may be used twice. Returns the number of matched equations.
    fn assert_valid_matching(adj: &[Vec<usize>], matching: &[Option<usize>]) -> usize {
        assert_eq!(matching.len(), adj.len(), "One entry per equation expected");
        let mut used = HashSet::new();
        for (eq, var) in matching.iter().enumerate() {
            if let Some(var) = *var {
                assert!(
                    adj[eq].contains(&var),
                    "equation {eq} matched to non-adjacent variable {var}"
                );
                assert!(
                    used.insert(var),
                    "variable {var} matched to more than one equation"
                );
            }
        }
        used.len()
    }

    #[test]
    fn test_hopcroft_karp_simple_matching() {
        let adj = vec![vec![0], vec![1], vec![2]];
        let mut hk = HopcroftKarp::new(3, 3, adj.clone());
        let matching_size = hk.max_matching();
        assert_eq!(matching_size, 3, "Should find perfect matching of size 3");
        let matching = hk.get_equation_matching();
        assert_eq!(assert_valid_matching(&adj, &matching), matching_size);
        assert_eq!(matching[0], Some(0));
        assert_eq!(matching[1], Some(1));
        assert_eq!(matching[2], Some(2));
    }

    #[test]
    fn test_hopcroft_karp_requires_augmenting_path() {
        let adj = vec![vec![0, 1], vec![0]];
        let mut hk = HopcroftKarp::new(2, 2, adj.clone());
        let matching_size = hk.max_matching();
        assert_eq!(matching_size, 2, "Should find perfect matching of size 2");
        let matching = hk.get_equation_matching();
        assert_eq!(assert_valid_matching(&adj, &matching), matching_size);
        assert_eq!(matching[0], Some(1));
        assert_eq!(matching[1], Some(0));
    }

    #[test]
    fn test_hopcroft_karp_incomplete_matching() {
        let adj = vec![vec![0], vec![0], vec![1]];
        let mut hk = HopcroftKarp::new(3, 2, adj.clone());
        let matching_size = hk.max_matching();
        assert_eq!(matching_size, 2, "Should find matching of size 2");
        let matching = hk.get_equation_matching();
        assert_eq!(assert_valid_matching(&adj, &matching), matching_size);
    }

    #[test]
    fn test_hopcroft_karp_complex_augmenting() {
        let adj = vec![vec![0, 1], vec![0, 2], vec![1, 2]];
        let mut hk = HopcroftKarp::new(3, 3, adj.clone());
        let matching_size = hk.max_matching();
        assert_eq!(matching_size, 3, "Should find perfect matching of size 3");
        let matching = hk.get_equation_matching();
        assert_eq!(assert_valid_matching(&adj, &matching), matching_size);
    }

    #[test]
    fn test_hopcroft_karp_empty() {
        let adj: Vec<Vec<usize>> = vec![];
        let mut hk = HopcroftKarp::new(0, 0, adj);
        let matching_size = hk.max_matching();
        assert_eq!(matching_size, 0);
        assert!(hk.get_equation_matching().is_empty());
    }

    #[test]
    fn test_hopcroft_karp_no_edges() {
        let adj = vec![vec![], vec![], vec![]];
        let mut hk = HopcroftKarp::new(3, 3, adj.clone());
        let matching_size = hk.max_matching();
        assert_eq!(matching_size, 0, "No matching possible without edges");
        let matching = hk.get_equation_matching();
        assert_eq!(assert_valid_matching(&adj, &matching), 0);
    }

    #[test]
    fn test_find_maximum_matching_integration() {
        let eq_infos = vec![
            EquationInfo {
                equation: Equation::Simple {
                    lhs: make_var("x"),
                    rhs: make_var("y"),
                },
                all_variables: ["x".to_string(), "y".to_string()].into_iter().collect(),
                lhs_variable: Some("x".to_string()),
                is_derivative: false,
                matched_variable: None,
            },
            EquationInfo {
                equation: Equation::Simple {
                    lhs: make_var("y"),
                    rhs: make_var("z"),
                },
                all_variables: ["y".to_string(), "z".to_string()].into_iter().collect(),
                lhs_variable: Some("y".to_string()),
                is_derivative: false,
                matched_variable: None,
            },
        ];
        let all_variables = vec!["x".to_string(), "y".to_string(), "z".to_string()];
        let matching = find_maximum_matching(&eq_infos, &all_variables, &HashSet::new());
        assert_eq!(matching.len(), 2, "Both equations should be matched");
        assert!(matching.contains_key(&0));
        assert!(matching.contains_key(&1));
        let distinct: HashSet<&String> = matching.values().collect();
        assert_eq!(distinct.len(), 2, "Each equation needs its own variable");
        assert!(["x", "y"].contains(&matching[&0].as_str()));
        assert!(["y", "z"].contains(&matching[&1].as_str()));
    }

    #[test]
    fn test_find_maximum_matching_respects_exclusions() {
        let eq_infos = vec![EquationInfo {
            equation: Equation::Simple {
                lhs: make_var("x"),
                rhs: make_var("y"),
            },
            all_variables: ["x".to_string(), "y".to_string()].into_iter().collect(),
            lhs_variable: Some("x".to_string()),
            is_derivative: false,
            matched_variable: None,
        }];
        let all_variables = vec!["x".to_string(), "y".to_string()];
        let exclude: HashSet<String> = ["x".to_string()].into_iter().collect();
        let matching = find_maximum_matching(&eq_infos, &all_variables, &exclude);
        assert_eq!(matching.len(), 1);
        assert_eq!(matching[&0], "y", "Excluded LHS variable must not be matched");
    }

    #[test]
    fn test_find_maximum_matching_empty() {
        let matching = find_maximum_matching(&[], &[], &HashSet::new());
        assert!(matching.is_empty());
    }
}