use crate::csc::CscMatrix;
/// Runs one augmenting-path search of Kuhn's bipartite-matching algorithm
/// starting from dual vertex `dual_idx`, flipping the matching along the path
/// if one is found.
///
/// * `dual_neighbors[d]` lists the primal vertices dual `d` may be matched to,
///   in the order they should be tried.
/// * `match_dual_to_primal` / `match_primal_to_dual` hold the current matching
///   in both directions and are updated in place when the search succeeds.
/// * `visited[p]` marks primal vertices that must not be explored again; every
///   primal vertex this call explores gets marked. Clearing the marks is left
///   to the caller.
///
/// Returns `true` if the matching grew by one edge. Returns `false` otherwise,
/// in which case neither matching array is modified.
///
/// The depth-first search uses an explicit, heap-allocated stack. It explores
/// vertices in exactly the same order as the natural recursive formulation.
/// Augmenting paths can be as long as the matching itself, so recursion would
/// overflow the thread stack on large problems.
fn augment(
    dual_idx: usize,
    dual_neighbors: &[Vec<usize>],
    match_dual_to_primal: &mut [Option<usize>],
    match_primal_to_dual: &mut [Option<usize>],
    visited: &mut [bool],
) -> bool {
    // `stack[k] = (dual vertex, index of its next neighbour to try)`.
    let mut stack: Vec<(usize, usize)> = vec![(dual_idx, 0)];
    // `path[k]` is the primal vertex tentatively assigned to `stack[k].0`,
    // i.e. the one through which `stack[k + 1]` was reached.
    // Invariant at the loop head: `path.len() + 1 == stack.len()`.
    let mut path: Vec<usize> = Vec::new();

    while let Some(&(dual, next)) = stack.last() {
        let neighbors = &dual_neighbors[dual];
        if next >= neighbors.len() {
            // Dead end: backtrack. The parent's tentative primal failed, and
            // the parent resumes with its next neighbour. (Popping the root
            // leaves `path` empty and ends the loop.)
            stack.pop();
            path.pop();
            continue;
        }

        let primal = neighbors[next];
        let top = stack.len() - 1;
        stack[top].1 += 1;
        if visited[primal] {
            continue;
        }
        visited[primal] = true;
        path.push(primal);

        match match_primal_to_dual[primal] {
            // Free primal vertex: flip every edge along the alternating path.
            None => {
                for (&(d, _), &p) in stack.iter().zip(&path) {
                    match_dual_to_primal[d] = Some(p);
                    match_primal_to_dual[p] = Some(d);
                }
                return true;
            }
            // Already matched: try to re-route its current partner.
            Some(partner) => stack.push((partner, 0)),
        }
    }
    false
}
/// Computes a fill-reducing symmetric ordering for a KKT matrix that keeps
/// matched primal/dual variable pairs next to each other.
///
/// Indices `0..n_primal` of `csc` are treated as primal variables and
/// `n_primal..csc.n` as dual variables. The ordering is built in four steps:
///
/// 1. For every dual column, collect the primal rows it stores with a nonzero,
///    non-NaN value, sorted by decreasing magnitude.
/// 2. Compute a maximum-cardinality primal/dual matching with augmenting
///    paths, trying each dual's candidates in that order.
/// 3. Collapse every matched pair into one node, symmetrise the sparsity
///    pattern of the compressed graph (dropping self loops) and order it with
///    AMD. The natural order is used instead when the graph has at most two
///    nodes or no edges, or when AMD reports an error.
/// 4. Expand the compressed ordering, emitting each pair as `primal, dual`.
///
/// Only entries stored in dual columns with primal row indices feed the
/// matching. NOTE(review): if callers store just the lower triangle, no pairs
/// are formed and every variable becomes its own node — confirm the storage
/// convention.
///
/// When there are no primal or no dual variables (including
/// `n_primal >= csc.n`), this defers to `super::amd::amd_ordering`.
///
/// # Returns
/// `(perm, perm_inv)` with `perm[new] = old` and `perm_inv[old] = new`.
///
/// # Panics
/// Panics if a stored row index is `>= csc.n` (out-of-bounds indexing) or if
/// an internal size invariant is violated.
pub fn kkt_matching_ordering(csc: &CscMatrix, n_primal: usize) -> (Vec<usize>, Vec<usize>) {
    let dim = csc.n;
    // `saturating_sub` guards against `n_primal > dim`. A plain subtraction
    // would panic in debug builds and wrap to a huge allocation in release.
    // Such input has no dual block, so it takes the plain-AMD path.
    let m_dual = dim.saturating_sub(n_primal);
    if m_dual == 0 || n_primal == 0 {
        return super::amd::amd_ordering(csc);
    }

    // --- 1. Candidate primal partners for each dual, heaviest first. ---
    let dual_neighbors: Vec<Vec<usize>> = (0..m_dual)
        .map(|dual_idx| {
            let col = n_primal + dual_idx;
            let mut neighbors: Vec<(usize, f64)> = Vec::new();
            for idx in csc.col_ptr[col]..csc.col_ptr[col + 1] {
                let row = csc.row_idx[idx];
                let weight = csc.vals[idx].abs();
                // `weight > 0.0` rejects explicit zeros and NaN, so the
                // comparator below never sees an unordered pair.
                if row < n_primal && weight > 0.0 {
                    neighbors.push((row, weight));
                }
            }
            // Stable sort: equal weights keep their storage order.
            neighbors.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
            neighbors.into_iter().map(|(row, _)| row).collect()
        })
        .collect();

    // --- 2. Maximum-cardinality matching (Kuhn's augmenting paths). ---
    let mut match_dual_to_primal: Vec<Option<usize>> = vec![None; m_dual];
    let mut match_primal_to_dual: Vec<Option<usize>> = vec![None; n_primal];
    // One `visited` buffer is shared across searches. Primal vertices reached
    // by a *failed* search are all matched, and so is every vertex reachable
    // from them. They therefore cannot lie on an augmenting path until the
    // matching changes, and the marks only need clearing after a success.
    // This produces the same matching as clearing before every search, but
    // avoids an O(n_primal) allocation per dual.
    let mut visited = vec![false; n_primal];
    for dual_idx in 0..m_dual {
        if augment(
            dual_idx,
            &dual_neighbors,
            &mut match_dual_to_primal,
            &mut match_primal_to_dual,
            &mut visited,
        ) {
            visited.fill(false);
        }
    }

    // --- 3a. Compress: each matched pair becomes a single node. ---
    // Node numbering: matched pairs first (in dual order), then unmatched
    // primals, then unmatched duals.
    let num_matched = match_dual_to_primal.iter().filter(|m| m.is_some()).count();
    let compressed_n = dim - num_matched;
    let mut orig_to_compressed = vec![0usize; dim];
    let mut compressed_to_orig: Vec<Vec<usize>> = Vec::with_capacity(compressed_n);
    for (dual_idx, &matched) in match_dual_to_primal.iter().enumerate() {
        if let Some(primal_idx) = matched {
            let node = compressed_to_orig.len();
            let dual_col = n_primal + dual_idx;
            orig_to_compressed[primal_idx] = node;
            orig_to_compressed[dual_col] = node;
            compressed_to_orig.push(vec![primal_idx, dual_col]);
        }
    }
    for (primal_idx, matched) in match_primal_to_dual.iter().enumerate() {
        if matched.is_none() {
            orig_to_compressed[primal_idx] = compressed_to_orig.len();
            compressed_to_orig.push(vec![primal_idx]);
        }
    }
    for (dual_idx, matched) in match_dual_to_primal.iter().enumerate() {
        if matched.is_none() {
            let dual_col = n_primal + dual_idx;
            orig_to_compressed[dual_col] = compressed_to_orig.len();
            compressed_to_orig.push(vec![dual_col]);
        }
    }
    assert_eq!(compressed_to_orig.len(), compressed_n);

    // --- 3b. Symmetric adjacency of the compressed graph, no self loops. ---
    let mut adj: Vec<Vec<usize>> = vec![Vec::new(); compressed_n];
    for (j, &cj) in orig_to_compressed.iter().enumerate() {
        for idx in csc.col_ptr[j]..csc.col_ptr[j + 1] {
            let ci = orig_to_compressed[csc.row_idx[idx]];
            if ci != cj {
                adj[cj].push(ci);
                adj[ci].push(cj);
            }
        }
    }
    for neighbors in &mut adj {
        neighbors.sort_unstable();
        neighbors.dedup();
    }

    // CSC pattern in the `i64` index type expected by `amd::order`.
    let mut comp_col_ptr: Vec<i64> = Vec::with_capacity(compressed_n + 1);
    let mut comp_row_idx: Vec<i64> = Vec::new();
    comp_col_ptr.push(0);
    for neighbors in &adj {
        comp_row_idx.extend(neighbors.iter().map(|&c| c as i64));
        comp_col_ptr.push(comp_row_idx.len() as i64);
    }
    let comp_nnz = comp_row_idx.len();

    // --- 3c. Order the compressed graph. ---
    let comp_perm: Vec<usize> = if compressed_n <= 2 || comp_nnz == 0 {
        (0..compressed_n).collect()
    } else {
        let control = amd::Control::default();
        match amd::order::<i64>(compressed_n as i64, &comp_col_ptr, &comp_row_idx, &control) {
            Ok((perm_i64, _, _)) => perm_i64.iter().map(|&x| x as usize).collect(),
            // Best effort: the identity is still a valid permutation, so it is
            // used instead of failing the whole ordering.
            Err(_) => (0..compressed_n).collect(),
        }
    };

    // --- 4. Expand supernodes back to original indices. ---
    let perm: Vec<usize> = comp_perm
        .iter()
        .flat_map(|&node| compressed_to_orig[node].iter().copied())
        .collect();
    assert_eq!(perm.len(), dim);
    let mut perm_inv = vec![0usize; dim];
    for (new, &old) in perm.iter().enumerate() {
        perm_inv[old] = new;
    }
    (perm, perm_inv)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::coo::CooMatrix;

    /// Asserts that `perm` holds every index in `0..n` exactly once.
    fn assert_is_permutation_of(perm: &[usize], n: usize) {
        assert_eq!(perm.len(), n);
        let mut ordered = perm.to_vec();
        ordered.sort();
        let expected: Vec<usize> = (0..n).collect();
        assert_eq!(ordered, expected);
    }

    /// Distance between two positions in an ordering.
    fn gap(a: usize, b: usize) -> usize {
        a.max(b) - a.min(b)
    }

    #[test]
    fn test_kkt_matching_simple() {
        // Dimension 3 with two primal variables.
        let coo = CooMatrix::new(
            3,
            vec![0, 1, 0, 1],
            vec![0, 1, 2, 2],
            vec![2.0, 2.0, 1.0, 1.0],
        )
        .unwrap();
        let csc = CscMatrix::from_coo(&coo);
        let (perm, perm_inv) = kkt_matching_ordering(&csc, 2);

        assert_is_permutation_of(&perm, 3);
        for (position, &original) in perm.iter().enumerate() {
            assert_eq!(perm_inv[original], position);
        }
    }

    #[test]
    fn test_kkt_matching_preserves_adjacency() {
        // Dimension 5 with three primal variables.
        let coo = CooMatrix::new(
            5,
            vec![0, 1, 2, 0, 1],
            vec![0, 1, 2, 3, 4],
            vec![2.0, 2.0, 2.0, 3.0, 4.0],
        )
        .unwrap();
        let csc = CscMatrix::from_coo(&coo);
        let (perm, perm_inv) = kkt_matching_ordering(&csc, 3);

        assert_is_permutation_of(&perm, 5);

        let (pos0, pos3) = (perm_inv[0], perm_inv[3]);
        assert!(gap(pos0, pos3) == 1,
            "Primal 0 and dual 0 should be adjacent: pos0={}, pos3={}", pos0, pos3);

        let (pos1, pos4) = (perm_inv[1], perm_inv[4]);
        assert!(gap(pos1, pos4) == 1,
            "Primal 1 and dual 1 should be adjacent: pos1={}, pos4={}", pos1, pos4);
    }
}