use super::hungarian::{
hungarian_match, hungarian_match_instrumented, CostGraph, HungarianStats, Matching,
};
use super::ScalingInfo;
use crate::error::FeralError;
use crate::sparse::csc::CscMatrix;
/// Bound on the magnitude of the exponent passed to `exp` when a scaling
/// factor is derived in `scaling_from_cache`.
///
/// `exp(709.0)` is still a finite `f64` (ln of `f64::MAX` is roughly 709.78)
/// and `exp(-709.0)` is still non-zero, so a clamped exponent can neither
/// overflow to infinity nor underflow to zero.
const LOG_HUGE: f64 = 709.0;
pub(crate) fn matching_perm(matrix: &CscMatrix) -> Result<(Vec<usize>, usize), FeralError> {
let cache = compute_matching(matrix)?;
Ok((cache.perm, cache.n_matched))
}
/// Result of one matching run, kept so that scaling factors can be derived
/// later without repeating the matching.
///
/// All vectors are indexed `0..n`, where `n` is the matrix dimension; for an
/// empty matrix every vector is empty and `n_matched` is zero.
#[derive(Debug, Clone)]
pub(crate) struct Mc64Cache {
/// Matching as returned by `hungarian_match`. An entry of `usize::MAX`
/// marks an index without a partner; any other entry is used as a row
/// index by `scaling_from_cache`.
pub perm: Vec<usize>,
/// First potential vector returned by `hungarian_match`.
/// NOTE(review): presumably the row dual variables; confirm in `hungarian`.
pub u: Vec<f64>,
/// Second potential vector returned by `hungarian_match`.
/// NOTE(review): presumably the column dual variables; confirm in `hungarian`.
pub v: Vec<f64>,
/// Per-column maximum of `ln|a_ij|` over the symmetrically expanded
/// column, or `f64::NEG_INFINITY` for a column with no non-zero entries.
pub cmax: Vec<f64>,
/// Number of matched indices reported by `hungarian_match`.
pub n_matched: usize,
}
/// Process-wide counter of matching computations.
///
/// `compute_matching` increments it once per call on a non-empty matrix
/// (the `n == 0` early return does not touch it), before the cost graph is
/// built. The value is also printed in the optional trace line.
pub static MC64_RECOMPUTE_COUNT: std::sync::atomic::AtomicUsize =
std::sync::atomic::AtomicUsize::new(0);
/// Computes the matching for `matrix` and packages everything needed to
/// derive scaling factors later.
///
/// The matrix is converted to a log-magnitude cost graph by
/// `build_cost_graph`, which is then handed to `hungarian_match`. An empty
/// matrix (`n == 0`) short-circuits to an empty cache and does not count as
/// a recomputation.
///
/// Every call on a non-empty matrix increments [`MC64_RECOMPUTE_COUNT`].
/// When the environment variable `FERAL_MC64_TRACE` is `1` or `on`, one line
/// with the call number, dimension, stored entry count and elapsed time
/// (cost-graph construction plus matching) is written to stderr.
///
/// # Errors
///
/// Forwards any error reported by `build_cost_graph`.
pub(crate) fn compute_matching(matrix: &CscMatrix) -> Result<Mc64Cache, FeralError> {
    let n = matrix.n;
    if n == 0 {
        return Ok(Mc64Cache {
            perm: Vec::new(),
            u: Vec::new(),
            v: Vec::new(),
            cmax: Vec::new(),
            n_matched: 0,
        });
    }

    // `fetch_add` returns the value *before* the increment, so adding one
    // gives this call's own ordinal. Re-reading the counter after the
    // matching has run could report the ordinal of a concurrent call.
    // `wrapping_add` mirrors the wrap-around behaviour of the atomic itself.
    let call_no = MC64_RECOMPUTE_COUNT
        .fetch_add(1, std::sync::atomic::Ordering::Relaxed)
        .wrapping_add(1);

    let trace = matches!(
        std::env::var("FERAL_MC64_TRACE").as_deref(),
        Ok("1") | Ok("on")
    );
    // Only pay for the clock read when tracing is enabled.
    let t0 = trace.then(std::time::Instant::now);

    let (cost_graph, cmax) = build_cost_graph(matrix)?;
    let Matching {
        perm,
        u,
        v,
        n_matched,
    } = hungarian_match(&cost_graph);

    if let Some(t0) = t0 {
        let ms = t0.elapsed().as_secs_f64() * 1e3;
        eprintln!(
            "[feral mc64] call #{} n={} nnz={} matching {:.1} ms",
            call_no,
            n,
            matrix.row_idx.len(),
            ms,
        );
    }

    Ok(Mc64Cache {
        perm,
        u,
        v,
        cmax,
        n_matched,
    })
}
/// Computes the scaling vector for `matrix` in one step.
///
/// Runs [`compute_matching`] and immediately converts the resulting cache
/// with [`scaling_from_cache`]; the cache itself is dropped.
///
/// # Errors
///
/// Forwards any error reported by [`compute_matching`].
pub(crate) fn compute_symmetric(matrix: &CscMatrix) -> Result<(Vec<f64>, ScalingInfo), FeralError> {
    compute_matching(matrix).map(|cache| scaling_from_cache(&cache))
}
/// Runs the instrumented matching on `matrix` and reports its statistics.
///
/// Returns the `HungarianStats` produced by `hungarian_match_instrumented`
/// together with the number of entries in the cost graph. An empty matrix
/// yields default statistics and a count of zero without building a graph.
/// The matching itself is discarded, and [`MC64_RECOMPUTE_COUNT`] is not
/// touched by this function.
///
/// # Errors
///
/// Forwards any error reported by `build_cost_graph`.
pub(crate) fn compute_matching_stats(
    matrix: &CscMatrix,
) -> Result<(HungarianStats, usize), FeralError> {
    if matrix.n == 0 {
        return Ok((HungarianStats::default(), 0));
    }

    let (graph, _) = build_cost_graph(matrix)?;
    let (_, stats) = hungarian_match_instrumented(&graph);
    Ok((stats, graph.row_idx.len()))
}
/// Derives one scaling factor per index from a cached matching.
///
/// For index `i` the factor is `exp((u[i] + v[i] - cmax[i]) / 2)`, with the
/// exponent clamped to `[-LOG_HUGE, LOG_HUGE]`. The factor falls back to
/// `1.0` (identity) whenever
///
/// * `cmax[i]` is not finite (e.g. a column without non-zero entries),
/// * `perm[i]` is `usize::MAX`, or `i` does not occur as a value in `perm`,
/// * the exponent is not finite, or
/// * the exponentiated value is zero or not finite.
///
/// The returned `ScalingInfo` is `Applied` when `n_matched` equals the
/// dimension (including the empty case) and `PartialSingular` with the
/// number of unmatched indices otherwise.
///
/// # Panics
///
/// Panics if `u`, `v` or `cmax` are shorter than `perm`, or if `perm`
/// contains an entry other than `usize::MAX` that is out of range.
pub(crate) fn scaling_from_cache(cache: &Mc64Cache) -> (Vec<f64>, ScalingInfo) {
    let dim = cache.perm.len();
    if dim == 0 {
        return (Vec::new(), ScalingInfo::Applied);
    }

    // Flag every index that some entry of `perm` points at.
    let mut is_target = vec![false; dim];
    for &partner in &cache.perm {
        if partner != usize::MAX {
            is_target[partner] = true;
        }
    }

    // Factor for a single index; every bail-out yields the identity.
    let factor_at = |idx: usize| -> f64 {
        let col_max = cache.cmax[idx];
        if !col_max.is_finite() {
            return 1.0;
        }
        // Both sides of the index must take part in the matching.
        if cache.perm[idx] == usize::MAX || !is_target[idx] {
            return 1.0;
        }
        let exponent = (cache.u[idx] + cache.v[idx] - col_max) / 2.0;
        if !exponent.is_finite() {
            return 1.0;
        }
        let factor = exponent.clamp(-LOG_HUGE, LOG_HUGE).exp();
        if factor == 0.0 || !factor.is_finite() {
            1.0
        } else {
            factor
        }
    };
    let scaling: Vec<f64> = (0..dim).map(factor_at).collect();

    let info = if cache.n_matched == dim {
        ScalingInfo::Applied
    } else {
        ScalingInfo::PartialSingular {
            n_unmatched: dim - cache.n_matched,
        }
    };
    (scaling, info)
}
/// Builds the cost graph for the matching, plus the per-column maxima.
///
/// Every stored non-zero entry `(i, j)` of `matrix` is placed in column `j`
/// and, when `i != j`, mirrored into column `i`, so the graph describes the
/// symmetric expansion of the stored pattern. Entries whose value is exactly
/// `0.0` are skipped.
/// NOTE(review): this assumes only one triangle is stored; if both `(i, j)`
/// and `(j, i)` are present, each would appear twice. Confirm with callers.
///
/// Within each column the entries are ordered by row index (stable sort).
/// The cost of an entry is `cmax[j] - ln|a_ij|`, where `cmax[j]` is the
/// largest `ln|a_ij|` in expanded column `j`; columns with no entries keep
/// `cmax[j] == f64::NEG_INFINITY`.
///
/// # Errors
///
/// The visible implementation always returns `Ok`.
///
/// # Panics
///
/// Panics if `matrix` holds a row index `>= n` or inconsistent array
/// lengths.
fn build_cost_graph(matrix: &CscMatrix) -> Result<(CostGraph, Vec<f64>), FeralError> {
    let dim = matrix.n;

    // Pass 1: how many entries each expanded column will hold.
    let mut per_col = vec![0usize; dim];
    for col in 0..dim {
        for pos in matrix.col_ptr[col]..matrix.col_ptr[col + 1] {
            let row = matrix.row_idx[pos];
            let val = matrix.values[pos];
            if val == 0.0 {
                continue;
            }
            per_col[col] += 1;
            if row != col {
                per_col[row] += 1;
            }
        }
    }

    // Column pointers are the running sum of the counts.
    let mut col_ptr = Vec::with_capacity(dim + 1);
    let mut running = 0usize;
    col_ptr.push(running);
    for &count in &per_col {
        running += count;
        col_ptr.push(running);
    }
    let total = running;

    // Pass 2: scatter each entry (and its mirror image) into place.
    let mut row_idx = vec![0usize; total];
    let mut cost = vec![0.0_f64; total];
    let mut cursor: Vec<usize> = col_ptr[..dim].to_vec();
    let mut place = |into_col: usize, row: usize, weight: f64| {
        let slot = cursor[into_col];
        row_idx[slot] = row;
        cost[slot] = weight;
        cursor[into_col] += 1;
    };
    for col in 0..dim {
        for pos in matrix.col_ptr[col]..matrix.col_ptr[col + 1] {
            let row = matrix.row_idx[pos];
            let val = matrix.values[pos];
            if val == 0.0 {
                continue;
            }
            let log_mag = val.abs().ln();
            place(col, row, log_mag);
            if row != col {
                place(row, col, log_mag);
            }
        }
    }

    // Order every column by row index; the sort is stable, so entries that
    // share a row keep their insertion order.
    let mut scratch: Vec<(usize, f64)> = Vec::new();
    for col in 0..dim {
        let span = col_ptr[col]..col_ptr[col + 1];
        scratch.clear();
        scratch.extend(span.clone().map(|slot| (row_idx[slot], cost[slot])));
        scratch.sort_by(|lhs, rhs| lhs.0.cmp(&rhs.0));
        for (slot, &(row, weight)) in span.zip(scratch.iter()) {
            row_idx[slot] = row;
            cost[slot] = weight;
        }
    }

    // Turn log-magnitudes into costs relative to the column maximum.
    let mut cmax = vec![f64::NEG_INFINITY; dim];
    for (col, col_max) in cmax.iter_mut().enumerate() {
        let column = &mut cost[col_ptr[col]..col_ptr[col + 1]];
        if let Some((&first, rest)) = column.split_first() {
            // Explicit `>` comparison (not `f64::max`) so that a NaN in the
            // first position is carried through unchanged.
            let peak = rest
                .iter()
                .fold(first, |best, &cand| if cand > best { cand } else { best });
            *col_max = peak;
            for entry in column.iter_mut() {
                *entry = peak - *entry;
            }
        }
    }

    Ok((
        CostGraph {
            n: dim,
            col_ptr,
            row_idx,
            cost,
        },
        cmax,
    ))
}
// Unit tests for the scaling derived from the matching.
#[cfg(test)]
mod tests {
use super::*;
// A 3x3 diagonal matrix with entries 2, 3, 5 must be fully matched and
// scaled by 1/sqrt(a_ii), so that the scaled diagonal becomes 1.
// NOTE(review): assumes `from_triplets` takes (n, rows, cols, values) — confirm.
#[test]
fn diagonal_matrix_produces_inverse_sqrt_scaling() {
let csc = CscMatrix::from_triplets(3, &[0, 1, 2], &[0, 1, 2], &[2.0, 3.0, 5.0]).unwrap();
let (s, info) = compute_symmetric(&csc).unwrap();
assert_eq!(info, ScalingInfo::Applied);
let expected = [
1.0 / 2.0_f64.sqrt(),
1.0 / 3.0_f64.sqrt(),
1.0 / 5.0_f64.sqrt(),
];
for i in 0..3 {
assert!(
(s[i] - expected[i]).abs() < 1e-12,
"s[{}] = {}, expected {}",
i,
s[i],
expected[i]
);
}
}
// Hand-built cache in which index 0 has a partner (`perm[0] == 1`) but is
// not itself the partner of any index, and index 1 has no partner at all.
// Both must get the identity factor, even though the raw exponent for
// index 0, (0 + 2 - 1) / 2 = 0.5, would give a factor different from 1.
#[test]
fn unmatched_row_with_matched_column_falls_back_to_identity() {
let cache = Mc64Cache {
perm: vec![1, usize::MAX],
u: vec![0.0, 0.0],
v: vec![2.0, 0.0],
cmax: vec![1.0, 1.0],
n_matched: 1,
};
let (s, info) = scaling_from_cache(&cache);
// One of the two indices is unmatched according to `n_matched`.
assert_eq!(info, ScalingInfo::PartialSingular { n_unmatched: 1 });
assert!(
(s[0] - 1.0).abs() < 1e-12,
"index 0's ROW is unmatched; the contract requires identity \
scaling, got {} (X4)",
s[0]
);
assert!(
(s[1] - 1.0).abs() < 1e-12,
"index 1's column is unmatched; must be identity, got {}",
s[1]
);
}
// A 0x0 matrix takes the early-return path: no scaling entries, and the
// result is still reported as `Applied`.
#[test]
fn empty_matrix_returns_empty_scaling() {
let csc = CscMatrix {
n: 0,
col_ptr: vec![0],
row_idx: vec![],
values: vec![],
};
let (s, info) = compute_symmetric(&csc).unwrap();
assert!(s.is_empty());
assert_eq!(info, ScalingInfo::Applied);
}
}