use faer::perm::Perm;
use faer::sparse::SparseColMat;
use crate::error::SparseError;
/// Output of [`mc64_matching`]: the matching, its symmetric scaling, and the
/// structural match status.
pub struct Mc64Result {
    /// Row -> column permutation. For structurally unmatched rows the column
    /// is an arbitrarily assigned leftover so the permutation stays complete.
    pub matching: Perm<usize>,
    /// Symmetric scaling factor per row/column index (always positive).
    pub scaling: Vec<f64>,
    /// Number of structurally matched columns (equals `n` when the matrix is
    /// structurally non-singular).
    pub matched: usize,
    /// Per-row flag: `true` when row `i` was structurally matched.
    pub is_matched: Vec<bool>,
}
/// Requested MC64 objective. Only the maximum-product variant is implemented;
/// the enum is `#[non_exhaustive]` so further job kinds can be added without
/// breaking downstream matches.
#[non_exhaustive]
pub enum Mc64Job {
    /// Maximize the product of matched entry magnitudes (solved in log space).
    MaximumProduct,
}
/// Symmetric bipartite cost graph in CSC layout, built by `build_cost_graph`.
struct CostGraph {
    /// Column start offsets (length `n + 1`).
    col_ptr: Vec<usize>,
    /// Row index of each stored entry.
    row_idx: Vec<usize>,
    /// Edge cost per entry: `col_max_log[j] - ln|a_ij|` (always >= 0).
    cost: Vec<f64>,
    /// ln of the largest |a| in each column (`NEG_INFINITY` for empty cols).
    col_max_log: Vec<f64>,
    /// Matrix dimension.
    n: usize,
}
/// Current matching plus row duals, shared by the greedy and Dijkstra phases.
struct MatchingState {
    /// Column matched to each row (`UNMATCHED` when free).
    row_match: Vec<usize>,
    /// Row matched to each column (`UNMATCHED` when free).
    col_match: Vec<usize>,
    /// Row dual variables of the assignment LP.
    u: Vec<f64>,
}
/// Sentinel marking a row/column as currently unmatched.
const UNMATCHED: usize = usize::MAX;
/// Bound on |log scale| before exponentiation so `exp` never overflows to
/// infinity or underflows to zero.
const LOG_SCALE_CLAMP: f64 = 500.0;
/// Computes an MC64-style maximum-product matching and the associated
/// symmetric scaling for a square sparse matrix.
///
/// The stored pattern is treated as one triangle of a symmetric matrix:
/// entries are mirrored when the cost graph is built (see
/// `build_cost_graph`). Returns the row->column matching as a permutation,
/// a per-index scaling vector, the number of matched columns, and a per-row
/// matched flag.
///
/// # Errors
/// - `SparseError::NotSquare` when `nrows != ncols`.
/// - `SparseError::InvalidInput` for an empty matrix or non-finite entries.
pub fn mc64_matching(
    matrix: &SparseColMat<usize, f64>,
    _job: Mc64Job,
) -> Result<Mc64Result, SparseError> {
    let (nrows, ncols) = (matrix.nrows(), matrix.ncols());
    if nrows != ncols {
        return Err(SparseError::NotSquare {
            dims: (nrows, ncols),
        });
    }
    let n = nrows;
    if n == 0 {
        return Err(SparseError::InvalidInput {
            reason: "MC64 requires non-empty matrix".to_string(),
        });
    }
    // Reject NaN/Inf up front: costs are built from ln(|a|) and non-finite
    // values would poison the shortest-path search.
    let symbolic = matrix.symbolic();
    let values = matrix.val();
    for j in 0..n {
        let start = symbolic.col_ptr()[j];
        let end = symbolic.col_ptr()[j + 1];
        for &val in &values[start..end] {
            if !val.is_finite() {
                return Err(SparseError::InvalidInput {
                    reason: "MC64 requires finite matrix entries".to_string(),
                });
            }
        }
    }
    // 1x1 case handled directly: scale so the scaled entry has magnitude 1.
    if n == 1 {
        let has_entry = symbolic.col_ptr()[1] > symbolic.col_ptr()[0];
        let scale = if has_entry {
            let val = values[symbolic.col_ptr()[0]];
            if val.abs() > 0.0 {
                1.0 / val.abs().sqrt()
            } else {
                1.0
            }
        } else {
            1.0
        };
        let fwd: Box<[usize]> = vec![0].into_boxed_slice();
        let inv: Box<[usize]> = vec![0].into_boxed_slice();
        return Ok(Mc64Result {
            matching: Perm::new_checked(fwd, inv, 1),
            scaling: vec![scale],
            matched: if has_entry { 1 } else { 0 },
            is_matched: vec![has_entry],
        });
    }
    // General case: greedy initial matching, then a Dijkstra-based
    // augmentation for every still-unmatched column.
    let graph = build_cost_graph(matrix);
    let mut state = greedy_initial_matching(&graph);
    let mut ds = DijkstraState::new(n);
    ds.init_jperm(&graph, &state);
    for j in 0..n {
        if state.col_match[j] != UNMATCHED {
            continue;
        }
        dijkstra_augment(j, &graph, &mut state, &mut ds);
    }
    let matched = state.col_match.iter().filter(|&&m| m != UNMATCHED).count();
    if matched == n {
        // Perfect matching: duals (u, v) are well defined everywhere.
        #[cfg(debug_assertions)]
        assert_dual_feasibility(&graph, &state);
        let (scaling, fwd, inv) = build_full_match_result(&graph, &state);
        return Ok(Mc64Result {
            matching: Perm::new_checked(fwd.into_boxed_slice(), inv.into_boxed_slice(), n),
            scaling,
            matched,
            is_matched: vec![true; n],
        });
    }
    // Structurally singular path: some rows/columns stay unmatched. The
    // duals of unmatched rows are zeroed (after the feasibility check, which
    // this may invalidate) and the matching is completed into a full
    // permutation for the caller.
    let is_row_matched: Vec<bool> = (0..n).map(|i| state.row_match[i] != UNMATCHED).collect();
    #[cfg(debug_assertions)]
    assert_dual_feasibility(&graph, &state);
    for i in 0..n {
        if state.row_match[i] == UNMATCHED {
            state.u[i] = 0.0;
        }
    }
    let v = compute_column_duals(&graph, &state);
    let scaling = symmetrize_scaling(&state.u, &v, &graph.col_max_log);
    let is_matched = is_row_matched;
    let (fwd, inv) = build_singular_permutation(n, &state, &is_matched);
    Ok(Mc64Result {
        matching: Perm::new_checked(fwd.into_boxed_slice(), inv.into_boxed_slice(), n),
        scaling,
        matched,
        is_matched,
    })
}
/// Recovers the column duals `v` from the matching and the row duals `u`.
///
/// For a matched pair (i, j) the matched edge is tight, so
/// `v[j] = c(i, j) - u[i]`; columns without a match keep `v[j] = 0`.
fn compute_column_duals(graph: &CostGraph, state: &MatchingState) -> Vec<f64> {
    let mut v = vec![0.0_f64; graph.n];
    for (j, slot) in v.iter_mut().enumerate() {
        let row = state.col_match[j];
        if row == UNMATCHED {
            continue;
        }
        // Locate the matched entry (row, j) inside column j's storage.
        let hit = (graph.col_ptr[j]..graph.col_ptr[j + 1]).find(|&k| graph.row_idx[k] == row);
        if let Some(k) = hit {
            *slot = graph.cost[k] - state.u[row];
        }
    }
    v
}
/// Debug-only check that the duals (u, v) are feasible for the assignment LP:
/// every edge touching a matched row must have non-negative slack
/// `c(i, j) - u[i] - v[j]` (up to a small tolerance).
#[cfg(debug_assertions)]
fn assert_dual_feasibility(graph: &CostGraph, state: &MatchingState) {
    let eps = 1e-10;
    let v = compute_column_duals(graph, state);
    for j in 0..graph.n {
        let vj = v[j];
        for idx in graph.col_ptr[j]..graph.col_ptr[j + 1] {
            let i = graph.row_idx[idx];
            // Rows that never got matched carry no meaningful dual; skip.
            if state.row_match[i] == UNMATCHED {
                continue;
            }
            let slack = graph.cost[idx] - state.u[i] - vj;
            debug_assert!(
                slack >= -eps,
                "dual infeasibility: u[{}] + v[{}] - c[{},{}] = {:.6e} > eps",
                i,
                j,
                i,
                j,
                -slack,
            );
        }
    }
}
/// Assembles the scaling plus forward/inverse permutations for a perfect
/// matching. `fwd[i]` is the column matched to row `i`; `inv` is its inverse.
/// Call only when every row is matched (`matched == n`).
fn build_full_match_result(
    graph: &CostGraph,
    state: &MatchingState,
) -> (Vec<f64>, Vec<usize>, Vec<usize>) {
    let n = graph.n;
    let v = compute_column_duals(graph, state);
    let scaling = symmetrize_scaling(&state.u, &v, &graph.col_max_log);
    // With a perfect matching, row_match already is the forward permutation.
    let fwd = state.row_match[..n].to_vec();
    let mut inv = vec![0usize; n];
    for (row, &col) in fwd.iter().enumerate() {
        inv[col] = row;
    }
    (scaling, fwd, inv)
}
/// Completes a partial matching into a full permutation for a structurally
/// singular matrix: matched rows keep their columns, and unmatched rows are
/// handed the leftover columns in ascending order.
fn build_singular_permutation(
    n: usize,
    state: &MatchingState,
    is_matched: &[bool],
) -> (Vec<usize>, Vec<usize>) {
    let mut fwd = vec![0usize; n];
    let mut free_rows: Vec<usize> = Vec::new();
    for (row, slot) in fwd.iter_mut().enumerate() {
        if state.row_match[row] == UNMATCHED {
            free_rows.push(row);
        } else {
            *slot = state.row_match[row];
        }
    }
    // Mark the columns already consumed by matched rows.
    let mut col_taken = vec![false; n];
    for (row, &m) in is_matched.iter().enumerate() {
        if m {
            col_taken[state.row_match[row]] = true;
        }
    }
    // Assign the remaining columns to the unmatched rows, in order.
    let free_cols: Vec<usize> = (0..n).filter(|&col| !col_taken[col]).collect();
    for (k, &row) in free_rows.iter().enumerate() {
        fwd[row] = free_cols[k];
    }
    let mut inv = vec![0usize; n];
    for (row, &col) in fwd.iter().enumerate() {
        inv[col] = row;
    }
    (fwd, inv)
}
/// Builds the symmetric MC64 cost graph (CSC) from the stored pattern.
///
/// Each stored entry (i, j) is mirrored to (j, i), so an upper-triangular
/// input yields the full symmetric pattern. Explicit zeros are dropped. The
/// cost of entry (i, j) is `col_max_log[j] - ln|a_ij|`, which is >= 0 and
/// equals 0 for the column's largest-magnitude entry (maximum-product
/// matching in log space).
fn build_cost_graph(matrix: &SparseColMat<usize, f64>) -> CostGraph {
    let n = matrix.nrows();
    let symbolic = matrix.symbolic();
    let values = matrix.val();
    let col_ptrs = symbolic.col_ptr();
    let row_indices = symbolic.row_idx();
    // Gather per-column (row, |value|) pairs, mirroring off-diagonal entries.
    let mut col_entries: Vec<Vec<(usize, f64)>> = vec![Vec::new(); n];
    for j in 0..n {
        let start = col_ptrs[j];
        let end = col_ptrs[j + 1];
        for k in start..end {
            let i = row_indices[k];
            let abs_val = values[k].abs();
            if abs_val == 0.0 {
                // Explicit zeros carry no matching weight; drop them.
                continue;
            }
            col_entries[j].push((i, abs_val));
            if i != j {
                col_entries[i].push((j, abs_val));
            }
        }
    }
    for entries in &mut col_entries {
        // Sort by row index, breaking ties by ascending magnitude; dedup then
        // keeps the first duplicate, i.e. the SMALLEST magnitude per row.
        // NOTE(review): duplicates only arise when both triangles are stored
        // with differing values; for a maximum-product objective one might
        // expect the largest magnitude to survive — confirm intent.
        entries.sort_by(|a, b| a.0.cmp(&b.0).then_with(|| a.1.total_cmp(&b.1)));
        entries.dedup_by_key(|entry| entry.0);
    }
    // Column maxima in log space (NEG_INFINITY for empty columns).
    let mut col_max_log = vec![f64::NEG_INFINITY; n];
    for j in 0..n {
        for &(_, abs_val) in &col_entries[j] {
            let log_val = abs_val.ln();
            if log_val > col_max_log[j] {
                col_max_log[j] = log_val;
            }
        }
    }
    // Flatten the per-column entry lists into CSC arrays.
    let mut col_ptr = Vec::with_capacity(n + 1);
    let mut row_idx = Vec::new();
    let mut cost = Vec::new();
    col_ptr.push(0);
    for j in 0..n {
        for &(i, abs_val) in &col_entries[j] {
            let c = col_max_log[j] - abs_val.ln();
            row_idx.push(i);
            cost.push(c);
        }
        col_ptr.push(row_idx.len());
    }
    CostGraph {
        col_ptr,
        row_idx,
        cost,
        col_max_log,
        n,
    }
}
/// Builds an initial matching and row duals `u` before the Dijkstra phase.
///
/// Three passes, mirroring the classic MC64 initialization:
/// 1. `u[i]` := minimum cost over row `i` (rows with no entries get 0);
/// 2. cheap assignment: match each row to its cheapest column when that
///    column is free and not too dense;
/// 3. for each still-free column, take its minimum-reduced-cost row; if that
///    row is taken, attempt a single 2-step augmentation through the row's
///    current column.
///
/// (Fix over the previous revision: the local `best_cost_pos` buffer was
/// written but never read, so it is removed.)
fn greedy_initial_matching(graph: &CostGraph) -> MatchingState {
    let n = graph.n;
    let mut row_match = vec![UNMATCHED; n];
    let mut col_match = vec![UNMATCHED; n];
    let mut u = vec![f64::INFINITY; n];
    // Pass 1: per-row minimum cost and the column where it occurs.
    let mut best_col_for_row = vec![UNMATCHED; n];
    for j in 0..n {
        let col_start = graph.col_ptr[j];
        let col_end = graph.col_ptr[j + 1];
        for idx in col_start..col_end {
            let i = graph.row_idx[idx];
            let c = graph.cost[idx];
            if c < u[i] {
                u[i] = c;
                best_col_for_row[i] = j;
            }
        }
    }
    for u_i in &mut u {
        if *u_i == f64::INFINITY {
            // Row has no entries; give it a neutral dual.
            *u_i = 0.0;
        }
    }
    // Pass 2: skip very dense columns so they stay available for rows with
    // fewer options.
    let dense_threshold = if n > 50 { n / 10 } else { n };
    for i in 0..n {
        let j = best_col_for_row[i];
        if j == UNMATCHED {
            continue;
        }
        if col_match[j] != UNMATCHED {
            continue;
        }
        let col_degree = graph.col_ptr[j + 1] - graph.col_ptr[j];
        if col_degree > dense_threshold {
            continue;
        }
        row_match[i] = j;
        col_match[j] = i;
    }
    // Pass 3: one-step augmentation. `search_from[j]` remembers how far into
    // column j we already scanned so each entry is visited at most once.
    let mut d_col = vec![0.0_f64; n];
    let mut search_from = vec![0usize; n];
    search_from[..n].copy_from_slice(&graph.col_ptr[..n]);
    'col_loop: for j in 0..n {
        if col_match[j] != UNMATCHED {
            continue;
        }
        let col_start = graph.col_ptr[j];
        let col_end = graph.col_ptr[j + 1];
        if col_start >= col_end {
            continue;
        }
        // Find the entry of column j with minimum reduced cost, preferring a
        // free row on exact ties.
        let mut best_i = graph.row_idx[col_start];
        let mut best_rc = graph.cost[col_start] - u[best_i];
        let mut best_k = col_start;
        for idx in (col_start + 1)..col_end {
            let i = graph.row_idx[idx];
            let rc = graph.cost[idx] - u[i];
            if rc < best_rc
                || (rc == best_rc && row_match[i] == UNMATCHED && row_match[best_i] != UNMATCHED)
            {
                best_rc = rc;
                best_i = i;
                best_k = idx;
            }
        }
        d_col[j] = best_rc;
        if row_match[best_i] == UNMATCHED {
            row_match[best_i] = j;
            col_match[j] = best_i;
            search_from[j] = best_k + 1;
            continue;
        }
        // Best row is taken: look for a row i (with equally good reduced
        // cost) whose current column jj can be re-matched to a free row.
        for idx in best_k..col_end {
            let i = graph.row_idx[idx];
            let rc = graph.cost[idx] - u[i];
            if rc > best_rc {
                continue;
            }
            let jj = row_match[i];
            if jj == UNMATCHED {
                continue;
            }
            let jj_end = graph.col_ptr[jj + 1];
            for kk in search_from[jj]..jj_end {
                let ii = graph.row_idx[kk];
                if row_match[ii] != UNMATCHED {
                    continue;
                }
                let rc_ii = graph.cost[kk] - u[ii];
                if rc_ii <= d_col[jj] {
                    // Re-match jj -> ii, freeing row i for column j.
                    col_match[jj] = ii;
                    row_match[ii] = jj;
                    search_from[jj] = kk + 1;
                    col_match[j] = i;
                    row_match[i] = j;
                    search_from[j] = idx + 1;
                    continue 'col_loop;
                }
            }
            // Column jj fully scanned; never revisit it.
            search_from[jj] = jj_end;
        }
    }
    MatchingState {
        row_match,
        col_match,
        u,
    }
}
/// Scratch buffers for `dijkstra_augment`, allocated once and reused across
/// all augmentation calls.
struct DijkstraState {
    /// Tentative shortest-path distance per row (`INFINITY` when untouched).
    d: Vec<f64>,
    /// 1-based position of each row inside `q` (0 = not in the heap).
    l: Vec<usize>,
    /// For each matched column, the storage index of its matched entry.
    jperm: Vec<usize>,
    /// Predecessor column on the search tree (`UNMATCHED` at the root).
    pr: Vec<usize>,
    /// For each column, the entry index by which it was reached.
    out: Vec<usize>,
    /// Heap storage at the front plus settled rows packed at the tail.
    q: Vec<usize>,
    /// Reusable buffer of the root column's matched-row entry indices.
    root_edges: Vec<usize>,
}
impl DijkstraState {
    /// Allocates all scratch buffers for an `n x n` problem.
    fn new(n: usize) -> Self {
        Self {
            d: vec![f64::INFINITY; n],
            l: vec![0; n],
            jperm: vec![UNMATCHED; n],
            pr: vec![UNMATCHED; n],
            out: vec![0; n],
            q: vec![0; n],
            root_edges: Vec::new(),
        }
    }
    /// Resets `d`/`l` for every row touched by the last search: the settled
    /// tail `q[low-1..n]` (`low` is 1-based) plus the remaining heap prefix
    /// `q[0..qlen]`.
    fn cleanup_touched(&mut self, low: usize, qlen: usize, n: usize) {
        for k in (low - 1)..n {
            let i = self.q[k];
            self.d[i] = f64::INFINITY;
            self.l[i] = 0;
        }
        for k in 0..qlen {
            let i = self.q[k];
            self.d[i] = f64::INFINITY;
            self.l[i] = 0;
        }
    }
    /// Caches, for each matched column, the storage index of its matched
    /// entry so the matched edge's cost can be read in O(1) later.
    fn init_jperm(&mut self, graph: &CostGraph, state: &MatchingState) {
        let n = graph.n;
        for j in 0..n {
            let matched_row = state.col_match[j];
            if matched_row == UNMATCHED {
                self.jperm[j] = UNMATCHED;
                continue;
            }
            let col_start = graph.col_ptr[j];
            let col_end = graph.col_ptr[j + 1];
            for idx in col_start..col_end {
                if graph.row_idx[idx] == matched_row {
                    self.jperm[j] = idx;
                    break;
                }
            }
        }
    }
}
/// Runs one shortest-augmenting-path search (Dijkstra over reduced costs,
/// MC64-style) from `root_col`, applying the path to the matching and
/// updating the row duals on success.
///
/// `ds.q` serves two roles: positions `0..qlen` form a binary min-heap of
/// candidate rows keyed by `ds.d`, while the tail `q[low-1..n]` (1-based
/// `low`/`up`) holds rows already settled at the current distance front.
/// `csp` is the cost of the cheapest augmenting path found so far, with its
/// final edge recorded in `isp` (entry index) and `jsp` (column).
///
/// Returns `true` when an augmenting path was found and applied.
fn dijkstra_augment(
    root_col: usize,
    graph: &CostGraph,
    state: &mut MatchingState,
    ds: &mut DijkstraState,
) -> bool {
    let n = graph.n;
    let mut csp = f64::INFINITY;
    let mut isp: usize = 0;
    let mut jsp = UNMATCHED;
    let mut qlen: usize = 0;
    let mut low: usize = n + 1;
    let mut up: usize = n + 1;
    let mut dmin = f64::INFINITY;
    ds.pr[root_col] = UNMATCHED;
    let col_start = graph.col_ptr[root_col];
    let col_end = graph.col_ptr[root_col + 1];
    // Seed from the root column: a free row immediately gives a candidate
    // path; matched rows become tentative distances.
    ds.root_edges.clear();
    for idx in col_start..col_end {
        let i = graph.row_idx[idx];
        let dnew = graph.cost[idx] - state.u[i];
        if dnew >= csp {
            continue;
        }
        if state.row_match[i] == UNMATCHED {
            csp = dnew;
            isp = idx;
            jsp = root_col;
        } else {
            if dnew < dmin {
                dmin = dnew;
            }
            ds.d[i] = dnew;
            ds.root_edges.push(idx);
        }
    }
    // Queue the seeded rows: those already at the minimum distance are
    // settled directly into the tail; the rest go into the heap.
    for k in 0..ds.root_edges.len() {
        let idx = ds.root_edges[k];
        let i = graph.row_idx[idx];
        if csp <= ds.d[i] {
            // Already beaten by the best path; forget the distance.
            ds.d[i] = f64::INFINITY;
            continue;
        }
        if ds.d[i] <= dmin {
            low -= 1;
            ds.q[low - 1] = i;
            ds.l[i] = low;
        } else {
            qlen += 1;
            ds.l[i] = qlen;
            ds.q[qlen - 1] = i;
            heap_update_inline(i, &mut ds.q, &ds.d, &mut ds.l);
        }
        let jj = state.row_match[i];
        ds.out[jj] = idx;
        ds.pr[jj] = root_col;
    }
    // Main loop: settle rows in distance order and relax through their
    // matched columns. Bounded by n iterations.
    for _jdum in 0..n {
        if low == up {
            // Settled region exhausted: move the next distance class out of
            // the heap into the tail.
            if qlen == 0 {
                break;
            }
            let top_i = ds.q[0];
            if ds.d[top_i] >= csp {
                break;
            }
            dmin = ds.d[top_i];
            while qlen > 0 {
                let top_i = ds.q[0];
                if ds.d[top_i] > dmin {
                    break;
                }
                let popped = heap_pop_inline(&mut ds.q, &ds.d, &mut ds.l, &mut qlen);
                low -= 1;
                ds.q[low - 1] = popped;
                ds.l[popped] = low;
            }
        }
        // Take the next settled row and scan its matched column.
        let q0 = ds.q[up - 1 - 1];
        let dq0 = ds.d[q0];
        if dq0 >= csp {
            break;
        }
        up -= 1;
        let j = state.row_match[q0];
        debug_assert!(
            ds.jperm[j] != UNMATCHED,
            "jperm[{}] not set for matched column",
            j
        );
        // Path cost up to column j, via its matched (tight) edge.
        let vj = dq0 - graph.cost[ds.jperm[j]] + state.u[q0];
        let col_start_j = graph.col_ptr[j];
        let col_end_j = graph.col_ptr[j + 1];
        for idx in col_start_j..col_end_j {
            let i = graph.row_idx[idx];
            if ds.l[i] >= up {
                // Row already settled; its distance is final.
                continue;
            }
            let dnew = vj + graph.cost[idx] - state.u[i];
            if dnew >= csp {
                continue;
            }
            if state.row_match[i] == UNMATCHED {
                // Shorter path reaching a free row: new best augmenting path.
                csp = dnew;
                isp = idx;
                jsp = j;
            } else {
                let di = ds.d[i];
                if di <= dnew {
                    continue;
                }
                if ds.l[i] >= low {
                    continue;
                }
                ds.d[i] = dnew;
                if dnew <= dmin {
                    // Joins the current distance class: settle it directly.
                    let lpos = ds.l[i];
                    if lpos != 0 {
                        heap_delete_inline(lpos, &mut ds.q, &ds.d, &mut ds.l, &mut qlen);
                    }
                    low -= 1;
                    ds.q[low - 1] = i;
                    ds.l[i] = low;
                } else {
                    // Insert into (or decrease key within) the heap.
                    if ds.l[i] == 0 {
                        qlen += 1;
                        ds.l[i] = qlen;
                        ds.q[qlen - 1] = i;
                    }
                    heap_update_inline(i, &mut ds.q, &ds.d, &mut ds.l);
                }
                let jj = state.row_match[i];
                ds.out[jj] = idx;
                ds.pr[jj] = j;
            }
        }
    }
    if csp == f64::INFINITY {
        // No augmenting path from this column; just reset scratch state.
        ds.cleanup_touched(low, qlen, n);
        return false;
    }
    // Apply the augmenting path, walking predecessors back to the root.
    let mut i = graph.row_idx[isp];
    let mut j = jsp;
    state.row_match[i] = j;
    state.col_match[j] = i;
    ds.jperm[j] = isp;
    loop {
        let jj = ds.pr[j];
        if jj == UNMATCHED {
            break;
        }
        let k = ds.out[j];
        i = graph.row_idx[k];
        state.row_match[i] = jj;
        state.col_match[jj] = i;
        ds.jperm[jj] = k;
        j = jj;
    }
    // Dual update for every settled row: u[i] += d[i] - csp.
    for k in (up - 1)..n {
        let i = ds.q[k];
        state.u[i] = state.u[i] + ds.d[i] - csp;
    }
    ds.cleanup_touched(low, qlen, n);
    true
}
/// Sifts entry `idx` up the 1-indexed binary min-heap `q` after its key
/// `d[idx]` decreased. `pos` maps entry -> 1-based heap position and is kept
/// in sync. A position <= 1 means the entry already sits at the root.
fn heap_update_inline(idx: usize, q: &mut [usize], d: &[f64], pos: &mut [usize]) {
    let start = pos[idx];
    if start <= 1 {
        q[0] = idx;
        return;
    }
    let key = d[idx];
    let mut slot = start;
    loop {
        let parent_slot = slot / 2;
        let parent = q[parent_slot - 1];
        if key >= d[parent] {
            break;
        }
        // Parent is larger: shift it down one level.
        q[slot - 1] = parent;
        pos[parent] = slot;
        slot = parent_slot;
        if slot <= 1 {
            break;
        }
    }
    q[slot - 1] = idx;
    pos[idx] = slot;
}
/// Removes and returns the minimum entry (root) of the 1-indexed min-heap.
fn heap_pop_inline(q: &mut [usize], d: &[f64], pos: &mut [usize], qlen: &mut usize) -> usize {
    let top = q[0];
    heap_delete_inline(1, q, d, pos, qlen);
    top
}
/// Removes the entry at 1-based heap position `pos0` from the min-heap and
/// restores the heap property, shrinking `*qlen` by one.
///
/// The hole is filled with the last element, which is first sifted up and —
/// only if it did not move up — sifted down.
fn heap_delete_inline(
    pos0: usize,
    q: &mut [usize],
    d: &[f64],
    pos: &mut [usize],
    qlen: &mut usize,
) {
    // Deleting the last slot: just shrink, nothing to re-order.
    if *qlen == pos0 {
        *qlen -= 1;
        return;
    }
    let last_idx = q[*qlen - 1];
    let v = d[last_idx];
    *qlen -= 1;
    // Sift the relocated last element up from pos0.
    let mut p = pos0;
    if p > 1 {
        loop {
            let parent = p / 2;
            let parent_idx = q[parent - 1];
            if v >= d[parent_idx] {
                break;
            }
            q[p - 1] = parent_idx;
            pos[parent_idx] = p;
            p = parent;
            if p <= 1 {
                break;
            }
        }
    }
    q[p - 1] = last_idx;
    pos[last_idx] = p;
    if p != pos0 {
        // Moved up: the subtree below pos0 is untouched and still valid.
        return;
    }
    // Did not move up: sift down, always swapping with the smaller child.
    loop {
        let child = 2 * p;
        if child > *qlen {
            break;
        }
        let mut child_d = d[q[child - 1]];
        let mut best_child = child;
        if child < *qlen {
            // Right sibling exists; pick the smaller of the two children.
            let right_d = d[q[child]];
            if child_d > right_d {
                best_child = child + 1;
                child_d = right_d;
            }
        }
        if v <= child_d {
            break;
        }
        let child_idx = q[best_child - 1];
        q[p - 1] = child_idx;
        pos[child_idx] = p;
        p = best_child;
    }
    q[p - 1] = last_idx;
    pos[last_idx] = p;
}
/// Converts dual variables into a symmetric scaling vector.
///
/// The symmetric scale is `exp((u[i] + v[i] - col_max_log[i]) / 2)`, with the
/// log-scale clamped to +/- `LOG_SCALE_CLAMP` so `exp` never overflows to
/// infinity or underflows to zero.
fn symmetrize_scaling(u: &[f64], v: &[f64], col_max_log: &[f64]) -> Vec<f64> {
    (0..u.len())
        .map(|i| {
            let half_log = (u[i] + v[i] - col_max_log[i]) / 2.0;
            half_log.clamp(-LOG_SCALE_CLAMP, LOG_SCALE_CLAMP).exp()
        })
        .collect()
}
/// Duff/Pralet-style fixup for structurally singular problems (test helper):
/// each unmatched index gets a scale so that its largest scaled entry touching
/// a matched partner has magnitude 1, or 1.0 when the index is fully
/// isolated. Matched indices keep their scaling untouched.
#[cfg(test)]
fn duff_pralet_correction(
    matrix: &SparseColMat<usize, f64>,
    scaling: &mut [f64],
    is_matched: &[bool],
) {
    let n = matrix.nrows();
    let sym = matrix.symbolic();
    let vals = matrix.val();
    let cp = sym.col_ptr();
    let ri = sym.row_idx();
    // Snapshot the incoming scaling so updates don't feed into each other.
    let before = scaling.to_vec();
    let mut best_log = vec![f64::NEG_INFINITY; n];
    for col in 0..n {
        for k in cp[col]..cp[col + 1] {
            let row = ri[k];
            let a = vals[k].abs();
            if a == 0.0 {
                continue;
            }
            // Unmatched row meeting a matched column contributes via the
            // column's existing scale...
            if !is_matched[row] && is_matched[col] {
                let cand = a.ln() + before[col].ln();
                if cand > best_log[row] {
                    best_log[row] = cand;
                }
            }
            // ...and symmetrically for an unmatched column meeting a matched
            // row (diagonal excluded to avoid double counting).
            if row != col && !is_matched[col] && is_matched[row] {
                let cand = a.ln() + before[row].ln();
                if cand > best_log[col] {
                    best_log[col] = cand;
                }
            }
        }
    }
    for (idx, s) in scaling.iter_mut().enumerate() {
        if is_matched[idx] {
            continue;
        }
        *s = if best_log[idx] == f64::NEG_INFINITY {
            1.0
        } else {
            (-best_log[idx]).exp()
        };
    }
}
/// Classifies the cycle structure of a permutation given as `matching`.
///
/// Returns `(singletons, two_cycles, longer_cycles)`: the number of fixed
/// points, 2-cycles (swaps), and cycles of length >= 3. `matching` must be a
/// permutation of `0..matching.len()`.
pub fn count_cycles(matching: &[usize]) -> (usize, usize, usize) {
    let n = matching.len();
    let mut seen = vec![false; n];
    let mut singletons = 0;
    let mut two_cycles = 0;
    let mut longer_cycles = 0;
    for start in 0..n {
        if seen[start] {
            continue;
        }
        let next = matching[start];
        if next == start {
            // Fixed point.
            singletons += 1;
            seen[start] = true;
        } else if matching[next] == start {
            // 2-cycle: start <-> next.
            two_cycles += 1;
            seen[start] = true;
            seen[next] = true;
        } else {
            // Walk the cycle, marking members until we wrap back to `start`.
            longer_cycles += 1;
            let mut cur = start;
            while !seen[cur] {
                seen[cur] = true;
                cur = matching[cur];
            }
        }
    }
    (singletons, two_cycles, longer_cycles)
}
#[cfg(test)]
mod tests {
    use super::*;
    use faer::sparse::Triplet;

    /// Builds an n x n sparse matrix from (row, col, value) triplets. Callers
    /// pass upper-triangular patterns; `build_cost_graph` mirrors them.
    fn make_upper_tri(n: usize, entries: &[(usize, usize, f64)]) -> SparseColMat<usize, f64> {
        let triplets: Vec<_> = entries
            .iter()
            .map(|&(i, j, v)| Triplet::new(i, j, v))
            .collect();
        SparseColMat::try_new_from_triplets(n, n, &triplets).unwrap()
    }

    /// Shared 3x3 upper-triangular fixture with a full diagonal.
    fn make_3x3_test() -> SparseColMat<usize, f64> {
        make_upper_tri(
            3,
            &[
                (0, 0, 4.0),
                (0, 1, 2.0),
                (1, 1, 5.0),
                (1, 2, 1.0),
                (2, 2, 3.0),
            ],
        )
    }

    // Cost-graph invariants: per-column entry counts after symmetric
    // expansion, column log-maxima, non-negative costs, and ~0 cost on the
    // dominant (diagonal) entries.
    #[test]
    fn test_build_cost_graph_3x3() {
        let matrix = make_3x3_test();
        let graph = build_cost_graph(&matrix);
        assert_eq!(graph.n, 3);
        let col_count = |j: usize| graph.col_ptr[j + 1] - graph.col_ptr[j];
        assert_eq!(col_count(0), 2, "col 0 should have 2 entries");
        assert_eq!(col_count(1), 3, "col 1 should have 3 entries");
        assert_eq!(col_count(2), 2, "col 2 should have 2 entries");
        assert!((graph.col_max_log[0] - 4.0_f64.ln()).abs() < 1e-12);
        assert!((graph.col_max_log[1] - 5.0_f64.ln()).abs() < 1e-12);
        assert!((graph.col_max_log[2] - 3.0_f64.ln()).abs() < 1e-12);
        for &c in &graph.cost {
            assert!(c >= -1e-14, "cost {} should be non-negative", c);
        }
        for j in 0..3 {
            let col_start = graph.col_ptr[j];
            let col_end = graph.col_ptr[j + 1];
            for idx in col_start..col_end {
                if graph.row_idx[idx] == j {
                    assert!(
                        graph.cost[idx].abs() < 1e-12,
                        "diagonal ({},{}) cost should be ~0, got {}",
                        j,
                        j,
                        graph.cost[idx]
                    );
                }
            }
        }
    }

    // Diagonal entries must survive graph construction.
    #[test]
    fn test_build_cost_graph_includes_diagonal() {
        let matrix = make_upper_tri(2, &[(0, 0, 3.0), (0, 1, 1.0), (1, 1, 2.0)]);
        let graph = build_cost_graph(&matrix);
        let mut has_diag = [false; 2];
        for (j, diag) in has_diag.iter_mut().enumerate() {
            let col_start = graph.col_ptr[j];
            let col_end = graph.col_ptr[j + 1];
            for idx in col_start..col_end {
                if graph.row_idx[idx] == j {
                    *diag = true;
                }
            }
        }
        assert!(has_diag[0], "diagonal (0,0) missing");
        assert!(has_diag[1], "diagonal (1,1) missing");
    }

    // Off-diagonal entries must appear in both their own column and the
    // mirrored column.
    #[test]
    fn test_build_cost_graph_symmetric_expansion() {
        let matrix = make_upper_tri(
            3,
            &[
                (0, 0, 1.0),
                (0, 1, 2.0),
                (1, 1, 3.0),
                (1, 2, 4.0),
                (2, 2, 5.0),
            ],
        );
        let graph = build_cost_graph(&matrix);
        let has_entry = |col: usize, row: usize| -> bool {
            let start = graph.col_ptr[col];
            let end = graph.col_ptr[col + 1];
            graph.row_idx[start..end].contains(&row)
        };
        assert!(has_entry(0, 1), "symmetric entry (1,0) should exist");
        assert!(has_entry(1, 0), "entry (0,1) should exist");
        assert!(
            has_entry(2, 1),
            "symmetric entry (1,2) should exist in col 2"
        );
        assert!(has_entry(1, 2), "entry (2,1) should exist in col 1");
    }

    // The greedy phase alone should match most of a diagonally dominant
    // matrix and leave finite duals.
    #[test]
    fn test_greedy_matching_4x4() {
        let matrix = make_upper_tri(
            4,
            &[
                (0, 0, 10.0),
                (0, 1, 1.0),
                (1, 1, 8.0),
                (1, 2, 2.0),
                (2, 2, 6.0),
                (2, 3, 3.0),
                (3, 3, 5.0),
            ],
        );
        let graph = build_cost_graph(&matrix);
        let state = greedy_initial_matching(&graph);
        let matched_count = state.row_match.iter().filter(|&&m| m != UNMATCHED).count();
        assert!(
            matched_count >= 3,
            "greedy should match at least 3 of 4, got {}",
            matched_count
        );
        for &ui in &state.u {
            assert!(ui.is_finite(), "dual u should be finite");
        }
        // NOTE(review): this scan only locates each matched entry and asserts
        // nothing about it — consider asserting the entry was actually found.
        for i in 0..4 {
            let j = state.row_match[i];
            if j == UNMATCHED {
                continue;
            }
            let col_start = graph.col_ptr[j];
            let col_end = graph.col_ptr[j + 1];
            for idx in col_start..col_end {
                if graph.row_idx[idx] == i {
                    break;
                }
            }
        }
    }

    // Augmentation should grow an imperfect greedy matching.
    #[test]
    fn test_dijkstra_augment_3x3() {
        let matrix = make_upper_tri(
            3,
            &[
                (0, 0, 5.0),
                (0, 1, 3.0),
                (0, 2, 1.0),
                (1, 1, 4.0),
                (1, 2, 2.0),
                (2, 2, 6.0),
            ],
        );
        let graph = build_cost_graph(&matrix);
        let mut state = greedy_initial_matching(&graph);
        let mut ds = DijkstraState::new(3);
        ds.init_jperm(&graph, &state);
        let initial_matched = state.col_match.iter().filter(|&&m| m != UNMATCHED).count();
        let mut augmented = false;
        for j in 0..3 {
            if state.col_match[j] == UNMATCHED && dijkstra_augment(j, &graph, &mut state, &mut ds) {
                augmented = true;
            }
        }
        let final_matched = state.col_match.iter().filter(|&&m| m != UNMATCHED).count();
        if initial_matched < 3 {
            assert!(augmented, "should find augmenting path");
            assert!(
                final_matched > initial_matched,
                "matching size should increase"
            );
        }
        for &ui in &state.u {
            assert!(ui.is_finite(), "dual u should be finite after augmentation");
        }
    }

    // Scaling formula pinned against hand-computed duals.
    #[test]
    fn test_symmetrize_scaling_known_duals() {
        let u = vec![0.5, 1.0, 0.0];
        let v = vec![0.2, 0.3, 0.8];
        let col_max_log = vec![1.0, 1.5, 0.5];
        let scaling = symmetrize_scaling(&u, &v, &col_max_log);
        for i in 0..3 {
            let expected = ((u[i] + v[i] - col_max_log[i]) / 2.0).exp();
            assert!(
                (scaling[i] - expected).abs() < 1e-12,
                "scaling[{}] = {}, expected {}",
                i,
                scaling[i],
                expected
            );
        }
    }

    // Scaling must stay positive and finite for mixed-sign duals.
    #[test]
    fn test_symmetrize_scaling_positive() {
        let u = vec![1.0, -0.5, 2.0];
        let v = vec![0.5, 1.5, -1.0];
        let col_max_log = vec![0.0, 0.0, 0.0];
        let scaling = symmetrize_scaling(&u, &v, &col_max_log);
        for (i, &s) in scaling.iter().enumerate() {
            assert!(s > 0.0, "scaling[{}] = {} should be positive", i, s);
            assert!(s.is_finite(), "scaling[{}] should be finite", i);
        }
    }

    // A diagonal matrix must get the identity matching.
    #[test]
    fn test_mc64_diagonal_identity() {
        let matrix = make_upper_tri(3, &[(0, 0, 4.0), (1, 1, 9.0), (2, 2, 1.0)]);
        let result = mc64_matching(&matrix, Mc64Job::MaximumProduct).unwrap();
        assert_eq!(result.matched, 3);
        let (fwd, _) = result.matching.as_ref().arrays();
        for (i, &f) in fwd.iter().enumerate() {
            assert_eq!(f, i, "diagonal matrix matching should be identity");
        }
        for (i, &s) in result.scaling.iter().enumerate() {
            assert!(s > 0.0, "scaling[{}] should be positive", i);
            assert!(s.is_finite(), "scaling[{}] should be finite", i);
        }
    }

    // Indefinite tridiagonal: full structural match plus SPRAL-style checks.
    #[test]
    fn test_mc64_tridiagonal_indefinite() {
        let matrix = make_upper_tri(
            4,
            &[
                (0, 0, 2.0),
                (0, 1, -1.0),
                (1, 1, -3.0),
                (1, 2, 2.0),
                (2, 2, 1.0),
                (2, 3, -1.0),
                (3, 3, -4.0),
            ],
        );
        let result = mc64_matching(&matrix, Mc64Job::MaximumProduct).unwrap();
        assert_eq!(result.matched, 4);
        verify_scaling_properties(&matrix, &result);
    }

    // Indefinite arrow matrix: full structural match plus scaling checks.
    #[test]
    fn test_mc64_arrow_indefinite() {
        let matrix = make_upper_tri(
            5,
            &[
                (0, 0, 10.0),
                (0, 1, 1.0),
                (0, 2, 1.0),
                (0, 3, 1.0),
                (0, 4, 1.0),
                (1, 1, -3.0),
                (2, 2, 5.0),
                (3, 3, -2.0),
                (4, 4, 4.0),
            ],
        );
        let result = mc64_matching(&matrix, Mc64Job::MaximumProduct).unwrap();
        assert_eq!(result.matched, 5);
        verify_scaling_properties(&matrix, &result);
    }

    // Exercises the dedicated n == 1 fast path.
    #[test]
    fn test_mc64_trivial_1x1() {
        let matrix = make_upper_tri(1, &[(0, 0, 7.0)]);
        let result = mc64_matching(&matrix, Mc64Job::MaximumProduct).unwrap();
        assert_eq!(result.matched, 1);
        assert_eq!(result.scaling.len(), 1);
        assert!(result.scaling[0] > 0.0);
    }

    // Smallest case that goes through the general (graph-based) path.
    #[test]
    fn test_mc64_trivial_2x2() {
        let matrix = make_upper_tri(2, &[(0, 0, 3.0), (0, 1, 1.0), (1, 1, 5.0)]);
        let result = mc64_matching(&matrix, Mc64Job::MaximumProduct).unwrap();
        assert_eq!(result.matched, 2);
        verify_scaling_properties(&matrix, &result);
    }

    // Non-square input must be rejected.
    #[test]
    fn test_mc64_not_square_error() {
        let triplets = vec![Triplet::new(0, 0, 1.0), Triplet::new(0, 1, 2.0)];
        let matrix = SparseColMat::try_new_from_triplets(2, 3, &triplets).unwrap();
        let result = mc64_matching(&matrix, Mc64Job::MaximumProduct);
        assert!(matches!(result, Err(SparseError::NotSquare { .. })));
    }

    // 0x0 input must be rejected as invalid.
    #[test]
    fn test_mc64_zero_dim_error() {
        let triplets: Vec<Triplet<usize, usize, f64>> = vec![];
        let matrix = SparseColMat::try_new_from_triplets(0, 0, &triplets).unwrap();
        let result = mc64_matching(&matrix, Mc64Job::MaximumProduct);
        assert!(matches!(result, Err(SparseError::InvalidInput { .. })));
    }

    // count_cycles: identity permutation is all singletons.
    #[test]
    fn test_count_cycles_identity() {
        let matching = vec![0, 1, 2, 3];
        let (s, c, l) = count_cycles(&matching);
        assert_eq!(s, 4);
        assert_eq!(c, 0);
        assert_eq!(l, 0);
    }

    // count_cycles: two disjoint swaps.
    #[test]
    fn test_count_cycles_two_swaps() {
        let matching = vec![1, 0, 3, 2];
        let (s, c, l) = count_cycles(&matching);
        assert_eq!(s, 0);
        assert_eq!(c, 2);
        assert_eq!(l, 0);
    }

    // count_cycles: singletons and one swap mixed.
    #[test]
    fn test_count_cycles_mixed() {
        let matching = vec![0, 2, 1, 3, 4];
        let (s, c, l) = count_cycles(&matching);
        assert_eq!(s, 3);
        assert_eq!(c, 1);
        assert_eq!(l, 0);
    }

    // count_cycles: a 3-cycle counts as a longer cycle.
    #[test]
    fn test_count_cycles_longer_cycle() {
        let matching = vec![1, 2, 0, 3];
        let (s, c, l) = count_cycles(&matching);
        assert_eq!(s, 1);
        assert_eq!(c, 0);
        assert_eq!(l, 1);
    }

    /// Delegates to the shared SPRAL-comparison scaling checks.
    fn verify_scaling_properties(matrix: &SparseColMat<usize, f64>, result: &Mc64Result) {
        use crate::testing::verify_spral_scaling_properties;
        verify_spral_scaling_properties("unit_test", matrix, result);
    }

    // Unmatched index 3 gets a corrected positive scale; matched entries are
    // left untouched.
    #[test]
    fn test_duff_pralet_4x4_singular() {
        let matrix = make_upper_tri(
            4,
            &[
                (0, 0, 4.0),
                (0, 1, 2.0),
                (0, 3, 1.0),
                (1, 1, 5.0),
                (1, 2, 1.0),
                (2, 2, 3.0),
            ],
        );
        let mut scaling = vec![0.5, 0.4, 0.6, 0.0];
        let is_matched = vec![true, true, true, false];
        duff_pralet_correction(&matrix, &mut scaling, &is_matched);
        assert!(scaling[3] > 0.0, "unmatched scaling should be positive");
        assert!(scaling[3].is_finite(), "unmatched scaling should be finite");
        assert!((scaling[0] - 0.5).abs() < 1e-12);
        assert!((scaling[1] - 0.4).abs() < 1e-12);
        assert!((scaling[2] - 0.6).abs() < 1e-12);
    }

    // A fully isolated unmatched index falls back to scale 1.0.
    #[test]
    fn test_duff_pralet_isolated_row() {
        let matrix = make_upper_tri(
            3,
            &[
                (0, 0, 4.0),
                (1, 1, 5.0),
                (2, 2, 3.0),
            ],
        );
        let mut scaling = vec![0.5, 0.4, 0.0];
        let is_matched = vec![true, true, false];
        duff_pralet_correction(&matrix, &mut scaling, &is_matched);
        assert_eq!(
            scaling[2], 1.0,
            "isolated unmatched row should get scaling 1.0"
        );
    }

    // After correction every scale must be positive and finite.
    #[test]
    fn test_duff_pralet_all_positive() {
        let matrix = make_upper_tri(
            4,
            &[
                (0, 0, 4.0),
                (0, 1, 2.0),
                (0, 3, 1.0),
                (1, 1, 5.0),
                (1, 2, 1.0),
                (2, 2, 3.0),
            ],
        );
        let mut scaling = vec![0.5, 0.4, 0.6, 0.0];
        let is_matched = vec![true, true, true, false];
        duff_pralet_correction(&matrix, &mut scaling, &is_matched);
        for (i, &s) in scaling.iter().enumerate() {
            assert!(s > 0.0, "scaling[{}] = {} should be positive", i, s);
            assert!(s.is_finite(), "scaling[{}] = {} should be finite", i, s);
        }
    }

    // Structurally singular (zero diagonal) input still yields a valid
    // permutation and positive finite scaling.
    #[test]
    fn test_mc64_singular_zero_diagonal() {
        let matrix = make_upper_tri(4, &[(0, 1, 5.0), (2, 3, 3.0)]);
        let result = mc64_matching(&matrix, Mc64Job::MaximumProduct).unwrap();
        for (i, &s) in result.scaling.iter().enumerate() {
            assert!(s > 0.0, "scaling[{}] should be positive", i);
            assert!(s.is_finite(), "scaling[{}] should be finite", i);
        }
        let (fwd, _) = result.matching.as_ref().arrays();
        let mut seen = [false; 4];
        for &f in fwd {
            assert!(!seen[f], "duplicate in matching");
            seen[f] = true;
        }
    }

    // NaN entries must be rejected before any matching work.
    #[test]
    fn test_mc64_nan_entry_error() {
        let triplets = vec![
            Triplet::new(0, 0, 4.0),
            Triplet::new(0, 1, f64::NAN),
            Triplet::new(1, 1, 5.0),
        ];
        let matrix = SparseColMat::try_new_from_triplets(2, 2, &triplets).unwrap();
        let result = mc64_matching(&matrix, Mc64Job::MaximumProduct);
        assert!(
            matches!(result, Err(SparseError::InvalidInput { .. })),
            "NaN entry should produce InvalidInput error"
        );
    }

    // Infinite entries must be rejected the same way.
    #[test]
    fn test_mc64_inf_entry_error() {
        let triplets = vec![
            Triplet::new(0, 0, 4.0),
            Triplet::new(0, 1, f64::INFINITY),
            Triplet::new(1, 1, 5.0),
        ];
        let matrix = SparseColMat::try_new_from_triplets(2, 2, &triplets).unwrap();
        let result = mc64_matching(&matrix, Mc64Job::MaximumProduct);
        assert!(
            matches!(result, Err(SparseError::InvalidInput { .. })),
            "Inf entry should produce InvalidInput error"
        );
    }

    // Greedy alone perfectly matches a diagonal matrix, identically.
    #[test]
    fn test_greedy_matching_diagonal_perfect() {
        let matrix = make_upper_tri(4, &[(0, 0, 10.0), (1, 1, 20.0), (2, 2, 5.0), (3, 3, 15.0)]);
        let graph = build_cost_graph(&matrix);
        let state = greedy_initial_matching(&graph);
        let matched_count = state.row_match.iter().filter(|&&m| m != UNMATCHED).count();
        assert_eq!(
            matched_count, 4,
            "greedy should perfectly match a diagonal matrix"
        );
        for (i, &j) in state.row_match.iter().enumerate() {
            assert_eq!(
                j, i,
                "diagonal greedy: row {} should match col {}, got {}",
                i, i, j
            );
        }
    }

    // Signs must not matter: matching works on magnitudes.
    #[test]
    fn test_mc64_negative_diagonal() {
        let matrix = make_upper_tri(3, &[(0, 0, -10.0), (1, 1, -20.0), (2, 2, -5.0)]);
        let result = mc64_matching(&matrix, Mc64Job::MaximumProduct).unwrap();
        assert_eq!(result.matched, 3);
        let (fwd, _) = result.matching.as_ref().arrays();
        for (i, &f) in fwd.iter().enumerate() {
            assert_eq!(f, i, "negative diagonal should give identity matching");
        }
        verify_scaling_properties(&matrix, &result);
    }

    // build_singular_permutation must produce a bijection with a consistent
    // inverse even with several unmatched rows.
    #[test]
    fn test_singular_unmatched_permutation_valid() {
        let matrix = make_upper_tri(3, &[(0, 1, 5.0)]);
        let result = mc64_matching(&matrix, Mc64Job::MaximumProduct).unwrap();
        let (fwd, inv) = result.matching.as_ref().arrays();
        let mut seen = [false; 3];
        for &f in fwd {
            assert!(f < 3, "fwd index out of range");
            assert!(!seen[f], "duplicate in fwd");
            seen[f] = true;
        }
        for i in 0..3 {
            assert_eq!(fwd[inv[i]], i, "fwd[inv[{}]] != {}", i, i);
        }
    }

    // End-to-end scaling sanity: positivity and scaled diagonal <= 1.
    #[test]
    fn test_second_matching_improves_scaling() {
        let matrix = make_upper_tri(
            6,
            &[
                (0, 0, 10.0),
                (0, 1, 1.0),
                (1, 1, 8.0),
                (1, 2, 2.0),
                (2, 2, 6.0),
                (2, 3, 3.0),
                (3, 3, 5.0),
                (3, 4, 1.0),
                (4, 4, 7.0),
                (0, 5, 0.1),
            ],
        );
        let result = mc64_matching(&matrix, Mc64Job::MaximumProduct).unwrap();
        for (i, &s) in result.scaling.iter().enumerate() {
            assert!(s > 0.0, "scaling[{}] should be positive, got {}", i, s);
            assert!(s.is_finite(), "scaling[{}] should be finite, got {}", i, s);
        }
        let symbolic = matrix.symbolic();
        let values = matrix.val();
        for j in 0..5 {
            let start = symbolic.col_ptr()[j];
            let end = symbolic.col_ptr()[j + 1];
            for (k, &row) in symbolic.row_idx()[start..end].iter().enumerate() {
                let i = row;
                if i == j {
                    let scaled = result.scaling[i] * values[start + k].abs() * result.scaling[j];
                    assert!(
                        scaled <= 1.0 + 1e-10,
                        "scaled diagonal ({},{}) = {:.6e} should be <= 1",
                        i,
                        j,
                        scaled
                    );
                }
            }
        }
    }

    // is_matched is row-based: matched rows map to in-range columns, and the
    // scaling stays positive/finite for a partially matched pattern.
    #[test]
    fn test_is_matched_uses_row_only() {
        let matrix = make_upper_tri(4, &[(0, 1, 5.0), (0, 3, 1.0), (2, 2, 4.0)]);
        let result = mc64_matching(&matrix, Mc64Job::MaximumProduct).unwrap();
        let (fwd, _) = result.matching.as_ref().arrays();
        for (i, &fi) in fwd.iter().enumerate().take(4) {
            if result.is_matched[i] {
                let j = fi;
                assert!(
                    j < 4,
                    "matched row {} should map to valid column, got {}",
                    i,
                    j
                );
            }
        }
        for (i, &s) in result.scaling.iter().enumerate() {
            assert!(s > 0.0, "scaling[{}] positive", i);
            assert!(s.is_finite(), "scaling[{}] finite", i);
        }
    }
}