// rivrs_sparse/ordering/matching.rs
//! MC64 weighted bipartite matching and symmetric scaling.
//!
//! Implements the Duff & Koster (2001) Algorithm MPD — Dijkstra-based shortest
//! augmenting path matching on a logarithmic cost graph — with the MC64SYM
//! symmetric scaling from Duff & Pralet (2005).
//!
//! MC64 preprocessing improves diagonal dominance of sparse symmetric indefinite
//! matrices, reducing delayed pivots in subsequent APTP factorization. The algorithm
//! produces:
//! - A matching permutation decomposing into singletons and 2-cycles
//! - Symmetric scaling factors such that the scaled matrix has unit matched diagonal
//!   and off-diagonal entries bounded by 1 in absolute value
//!
//! # Algorithm
//!
//! 1. Build a bipartite cost graph with `c[i,j] = log(col_max_j) - log|a[i,j]|`
//! 2. Compute initial dual variables and greedy matching (~80% cardinality)
//! 3. For each unmatched column, find shortest augmenting path via Dijkstra
//! 4. Symmetrize scaling: `s[i] = exp(-(u[i] + v[i]) / 2)`
//!
//! # References
//!
//! - Duff & Koster (2001), "On Algorithms for Permuting Large Entries to the
//!   Diagonal of a Sparse Matrix", SIAM J. Matrix Anal. Appl. 22(4)
//! - Duff & Pralet (2005), "Strategies for Scaling and Pivoting for Sparse
//!   Symmetric Indefinite Problems", RAL Technical Report
27
use faer::perm::Perm;
use faer::sparse::SparseColMat;

use crate::error::SparseError;
/// Result of MC64 matching-based preprocessing.
pub struct Mc64Result {
    /// Matching permutation: row i is matched to column σ(i).
    /// For symmetric matrices: decomposes into singletons (σ(i)=i)
    /// and 2-cycles (σ(i)=j, σ(j)=i).
    /// NOTE: This is NOT a fill-reducing ordering. Do not use directly
    /// with `SymmetricOrdering::Custom`.
    pub matching: Perm<usize>,

    /// Symmetric scaling factors (linear domain).
    /// `A_scaled[i,j] = scaling[i] * A[i,j] * scaling[j]`
    /// The scaled matrix has unit diagonal and off-diagonals <= 1
    /// in absolute value for matched entries. (For structurally singular
    /// inputs this bound holds only for entries incident to matched rows.)
    pub scaling: Vec<f64>,

    /// Number of matched entries. Equals n for structurally nonsingular
    /// matrices. Less than n indicates structural singularity.
    pub matched: usize,

    /// Per-index matched status: `is_matched[i]` is true if index i
    /// participates in the matching (as row OR column for symmetric matrices).
    /// For structurally nonsingular matrices, all entries are true.
    /// Used by downstream cycle decomposition to distinguish singletons
    /// (matched to self) from unmatched indices.
    pub is_matched: Vec<bool>,
}
59
/// Optimization objective for the matching.
///
/// Marked `#[non_exhaustive]` so additional MC64 job types (e.g. maximum
/// cardinality or bottleneck variants) can be added without a breaking change.
#[non_exhaustive]
pub enum Mc64Job {
    /// Maximize the product of diagonal entry magnitudes.
    /// Equivalent to minimizing sum of `-log|a_ij|` costs.
    /// Default for APTP preprocessing.
    MaximumProduct,
}
68
/// Bipartite cost graph with logarithmic edge costs.
///
/// Built from the input matrix by expanding the upper triangle to full
/// symmetric storage and computing `c[i,j] = log(col_max_j) - log|a[i,j]|`.
/// Costs are non-negative by construction since `|a[i,j]| <= col_max_j`.
struct CostGraph {
    /// CSC column pointers for the full (symmetrized) matrix.
    col_ptr: Vec<usize>,
    /// CSC row indices.
    row_idx: Vec<usize>,
    /// Edge costs: `c[i,j] = col_max_log[j] - log|a[i,j]|` (non-negative).
    cost: Vec<f64>,
    /// Column maxima in log domain: `log(max_k |a[k,j]|)`.
    /// `NEG_INFINITY` for structurally empty columns.
    col_max_log: Vec<f64>,
    /// Matrix dimension.
    n: usize,
}
85
/// Working state during the matching algorithm.
struct MatchingState {
    /// For each row i, the column it is matched to (`UNMATCHED` if unmatched).
    row_match: Vec<usize>,
    /// For each column j, the row it is matched to (`UNMATCHED` if unmatched).
    col_match: Vec<usize>,
    /// Row dual variables (log domain). Column duals are recovered on demand
    /// via complementary slackness (see `compute_column_duals`).
    u: Vec<f64>,
}
95
/// Sentinel marking an unmatched row/column in the matching arrays.
const UNMATCHED: usize = usize::MAX;

/// Clamp bound for log-domain scaling to prevent overflow/underflow in `exp()`.
const LOG_SCALE_CLAMP: f64 = 500.0;
100
/// Compute MC64 weighted bipartite matching and symmetric scaling for a sparse
/// symmetric matrix.
///
/// # Arguments
///
/// * `matrix` — Sparse symmetric matrix in upper-triangular CSC format (faer
///   convention). Must be square. Numeric values are required.
/// * `job` — Optimization objective. Currently only `Mc64Job::MaximumProduct`.
///
/// # Returns
///
/// * `Ok(Mc64Result)` — Matching permutation, scaling factors, and match count.
/// * `Err(SparseError::NotSquare)` — Matrix is not square.
/// * `Err(SparseError::InvalidInput)` — Zero dimension or non-finite entries.
/// * `Err(SparseError::AnalysisFailure)` — Internal algorithm failure.
///
/// # Algorithm References
///
/// - Duff & Koster (2001), Algorithm MPD
/// - Duff & Pralet (2005), MC64SYM
pub fn mc64_matching(
    matrix: &SparseColMat<usize, f64>,
    _job: Mc64Job,
) -> Result<Mc64Result, SparseError> {
    let (nrows, ncols) = (matrix.nrows(), matrix.ncols());

    // Input validation
    if nrows != ncols {
        return Err(SparseError::NotSquare {
            dims: (nrows, ncols),
        });
    }
    let n = nrows;

    if n == 0 {
        return Err(SparseError::InvalidInput {
            reason: "MC64 requires non-empty matrix".to_string(),
        });
    }

    // Check for non-finite entries. NaN/Inf values would poison the log-domain
    // cost graph and silently corrupt the shortest-path computation.
    let symbolic = matrix.symbolic();
    let values = matrix.val();
    for j in 0..n {
        let start = symbolic.col_ptr()[j];
        let end = symbolic.col_ptr()[j + 1];
        for &val in &values[start..end] {
            if !val.is_finite() {
                return Err(SparseError::InvalidInput {
                    reason: "MC64 requires finite matrix entries".to_string(),
                });
            }
        }
    }

    // Trivial case: n=1 — the identity is the only possible matching;
    // scale = 1/sqrt(|a00|) makes the scaled diagonal exactly 1.
    if n == 1 {
        let has_entry = symbolic.col_ptr()[1] > symbolic.col_ptr()[0];
        let scale = if has_entry {
            let val = values[symbolic.col_ptr()[0]];
            if val.abs() > 0.0 {
                1.0 / val.abs().sqrt()
            } else {
                1.0
            }
        } else {
            1.0
        };
        let fwd: Box<[usize]> = vec![0].into_boxed_slice();
        let inv: Box<[usize]> = vec![0].into_boxed_slice();
        return Ok(Mc64Result {
            matching: Perm::new_checked(fwd, inv, 1),
            scaling: vec![scale],
            matched: if has_entry { 1 } else { 0 },
            is_matched: vec![has_entry],
        });
    }

    // Build cost graph (upper triangle expanded to full symmetric storage)
    let graph = build_cost_graph(matrix);

    // Greedy initial matching (typically reaches ~80% cardinality cheaply)
    let mut state = greedy_initial_matching(&graph);

    // Persistent Dijkstra state (reused across augmentations; avoids O(n) re-init)
    let mut ds = DijkstraState::new(n);
    ds.init_jperm(&graph, &state);

    // Augment unmatched columns via Dijkstra
    for j in 0..n {
        if state.col_match[j] != UNMATCHED {
            continue;
        }
        dijkstra_augment(j, &graph, &mut state, &mut ds);
    }

    // Count matched
    let matched = state.col_match.iter().filter(|&&m| m != UNMATCHED).count();

    if matched == n {
        // Full matching — compute scaling directly from dual variables
        #[cfg(debug_assertions)]
        assert_dual_feasibility(&graph, &state);
        let (scaling, fwd, inv) = build_full_match_result(&graph, &state);
        return Ok(Mc64Result {
            matching: Perm::new_checked(fwd.into_boxed_slice(), inv.into_boxed_slice(), n),
            scaling,
            matched,
            is_matched: vec![true; n],
        });
    }

    // Structural singularity: matched < n
    let is_row_matched: Vec<bool> = (0..n).map(|i| state.row_match[i] != UNMATCHED).collect();

    // Check dual feasibility for matched rows before modifying duals
    #[cfg(debug_assertions)]
    assert_dual_feasibility(&graph, &state);

    // Zero unmatched row duals, then compute scaling using the full-match formula.
    //
    // For partial matching, the duals from the Hungarian algorithm are not globally
    // feasible (u[i] + v[j] <= c[i,j] may be violated for unmatched rows). Zeroing
    // unmatched duals is a standard heuristic from Duff & Koster (2001) §4 that
    // produces a well-defined scaling even for structurally singular matrices.
    // The |s*a*s| <= 1 property only holds for entries incident to matched rows.
    for i in 0..n {
        if state.row_match[i] == UNMATCHED {
            state.u[i] = 0.0;
        }
    }
    let v = compute_column_duals(&graph, &state);
    let scaling = symmetrize_scaling(&state.u, &v, &graph.col_max_log);

    // is_matched uses row-only: for condensation pipeline, what matters is whether
    // row i has a real matching edge (not a fake assignment from build_singular_permutation)
    let is_matched = is_row_matched;

    // Build permutation: matched rows keep their matched column, unmatched get remaining
    let (fwd, inv) = build_singular_permutation(n, &state, &is_matched);

    Ok(Mc64Result {
        matching: Perm::new_checked(fwd.into_boxed_slice(), inv.into_boxed_slice(), n),
        scaling,
        matched,
        is_matched,
    })
}
249
250/// Compute column dual variables via complementary slackness.
251///
252/// For each matched column j: `v[j] = c[matched_row, j] - u[matched_row]`.
253/// Unmatched columns get `v[j] = 0.0`.
254/// This is the standard post-matching dual computation from complementary slackness.
255fn compute_column_duals(graph: &CostGraph, state: &MatchingState) -> Vec<f64> {
256    let n = graph.n;
257    let mut v = vec![0.0_f64; n];
258    for (j, v_j) in v.iter_mut().enumerate() {
259        let i = state.col_match[j];
260        if i != UNMATCHED {
261            let col_start = graph.col_ptr[j];
262            let col_end = graph.col_ptr[j + 1];
263            for idx in col_start..col_end {
264                if graph.row_idx[idx] == i {
265                    *v_j = graph.cost[idx] - state.u[i];
266                    break;
267                }
268            }
269        }
270    }
271    v
272}
273
274/// Verify dual feasibility: `u[i] + v[j] <= c[i,j] + eps` for all edges.
275///
276/// This is a mathematical invariant of the Hungarian algorithm. Violations
277/// indicate a bug in the Dijkstra augmentation or dual update logic.
278/// Only runs in debug builds to avoid overhead in release.
279#[cfg(debug_assertions)]
280fn assert_dual_feasibility(graph: &CostGraph, state: &MatchingState) {
281    let eps = 1e-10;
282    let v = compute_column_duals(graph, state);
283    let n = graph.n;
284
285    for (j, &vj) in v.iter().enumerate().take(n) {
286        let col_start = graph.col_ptr[j];
287        let col_end = graph.col_ptr[j + 1];
288        for idx in col_start..col_end {
289            let i = graph.row_idx[idx];
290            // Only check matched rows — unmatched rows have uninitialized duals
291            if state.row_match[i] == UNMATCHED {
292                continue;
293            }
294            let slack = graph.cost[idx] - state.u[i] - vj;
295            debug_assert!(
296                slack >= -eps,
297                "dual infeasibility: u[{}] + v[{}] - c[{},{}] = {:.6e} > eps",
298                i,
299                j,
300                i,
301                j,
302                -slack,
303            );
304        }
305    }
306}
307
308/// Build scaling, forward and inverse permutation for a full matching result.
309fn build_full_match_result(
310    graph: &CostGraph,
311    state: &MatchingState,
312) -> (Vec<f64>, Vec<usize>, Vec<usize>) {
313    let n = graph.n;
314    let v = compute_column_duals(graph, state);
315    let scaling = symmetrize_scaling(&state.u, &v, &graph.col_max_log);
316
317    let mut fwd = vec![0usize; n];
318    for (i, fwd_i) in fwd.iter_mut().enumerate() {
319        *fwd_i = state.row_match[i];
320    }
321
322    let mut inv = vec![0usize; n];
323    for (i, &f) in fwd.iter().enumerate() {
324        inv[f] = i;
325    }
326
327    (scaling, fwd, inv)
328}
329
330/// Build permutation arrays for structurally singular case.
331/// Matched rows keep their matched column; unmatched rows get remaining columns.
332fn build_singular_permutation(
333    n: usize,
334    state: &MatchingState,
335    is_matched: &[bool],
336) -> (Vec<usize>, Vec<usize>) {
337    let mut fwd = vec![0usize; n];
338    let mut unmatched_rows: Vec<usize> = Vec::new();
339
340    for (i, fwd_i) in fwd.iter_mut().enumerate() {
341        if state.row_match[i] != UNMATCHED {
342            *fwd_i = state.row_match[i];
343        } else {
344            unmatched_rows.push(i);
345        }
346    }
347
348    let mut used_cols = vec![false; n];
349    for (i, &matched) in is_matched.iter().enumerate() {
350        if matched {
351            used_cols[state.row_match[i]] = true;
352        }
353    }
354    let free_cols: Vec<usize> = (0..n).filter(|&j| !used_cols[j]).collect();
355    for (idx, &i) in unmatched_rows.iter().enumerate() {
356        fwd[i] = free_cols[idx];
357    }
358
359    let mut inv = vec![0usize; n];
360    for (i, &f) in fwd.iter().enumerate() {
361        inv[f] = i;
362    }
363
364    (fwd, inv)
365}
366
367/// Build the bipartite cost graph from a symmetric sparse matrix.
368///
369/// Expands the upper-triangular CSC input to full symmetric storage and computes
370/// logarithmic costs: `c[i,j] = log(col_max_j) - log|a[i,j]|`.
371fn build_cost_graph(matrix: &SparseColMat<usize, f64>) -> CostGraph {
372    let n = matrix.nrows();
373    let symbolic = matrix.symbolic();
374    let values = matrix.val();
375    let col_ptrs = symbolic.col_ptr();
376    let row_indices = symbolic.row_idx();
377
378    // Step 1: Collect all (row, col, |value|) entries, expanding upper triangle to full.
379    // We store entries as (col, row, abs_val) grouped by column.
380    let mut col_entries: Vec<Vec<(usize, f64)>> = vec![Vec::new(); n];
381
382    for j in 0..n {
383        let start = col_ptrs[j];
384        let end = col_ptrs[j + 1];
385        for k in start..end {
386            let i = row_indices[k];
387            let abs_val = values[k].abs();
388            if abs_val == 0.0 {
389                continue; // Skip explicit zeros
390            }
391            // Add (i, j) entry
392            col_entries[j].push((i, abs_val));
393            // If off-diagonal, also add (j, i) for symmetry
394            if i != j {
395                col_entries[i].push((j, abs_val));
396            }
397        }
398    }
399
400    // Sort each column's entries by row index and dedup.
401    // Dedup is needed when the input is full symmetric CSC (both triangles stored):
402    // the mirroring loop above would double-count off-diagonal entries.
403    // For upper-triangular input, dedup is a no-op.
404    for entries in &mut col_entries {
405        entries.sort_by(|a, b| a.0.cmp(&b.0).then_with(|| a.1.total_cmp(&b.1)));
406        entries.dedup_by_key(|entry| entry.0);
407    }
408
409    // Step 2: Compute column maxima in log domain
410    let mut col_max_log = vec![f64::NEG_INFINITY; n];
411    for j in 0..n {
412        for &(_, abs_val) in &col_entries[j] {
413            let log_val = abs_val.ln();
414            if log_val > col_max_log[j] {
415                col_max_log[j] = log_val;
416            }
417        }
418    }
419
420    // Step 3: Build CSC with costs = col_max_log[j] - log|a[i,j]|
421    let mut col_ptr = Vec::with_capacity(n + 1);
422    let mut row_idx = Vec::new();
423    let mut cost = Vec::new();
424
425    col_ptr.push(0);
426    for j in 0..n {
427        for &(i, abs_val) in &col_entries[j] {
428            let c = col_max_log[j] - abs_val.ln();
429            row_idx.push(i);
430            cost.push(c);
431        }
432        col_ptr.push(row_idx.len());
433    }
434
435    CostGraph {
436        col_ptr,
437        row_idx,
438        cost,
439        col_max_log,
440        n,
441    }
442}
443
444/// Compute initial matching and dual variables using the greedy heuristic.
445///
446/// Following Duff & Koster (2001, Section 4): compute initial dual variables
447/// from column-minimum and row-minimum passes, then build greedy matching
448/// from zero-reduced-cost edges with a secondary 2-pass rearrangement.
449fn greedy_initial_matching(graph: &CostGraph) -> MatchingState {
450    let n = graph.n;
451    let mut row_match = vec![UNMATCHED; n];
452    let mut col_match = vec![UNMATCHED; n];
453    let mut u = vec![f64::INFINITY; n];
454
455    // Pass 1: For each row i, find minimum cost across all columns
456    // u[i] = min_j c[i,j]
457    // Also record which column j achieves the minimum, and the position
458    let mut best_col_for_row = vec![UNMATCHED; n]; // column with min cost for row i
459    let mut best_cost_pos = vec![0usize; n]; // index into graph arrays
460
461    for j in 0..n {
462        let col_start = graph.col_ptr[j];
463        let col_end = graph.col_ptr[j + 1];
464        for idx in col_start..col_end {
465            let i = graph.row_idx[idx];
466            let c = graph.cost[idx];
467            if c < u[i] {
468                u[i] = c;
469                best_col_for_row[i] = j;
470                best_cost_pos[i] = idx;
471            }
472        }
473    }
474
475    // For rows with no entries, set u[i] = 0
476    for u_i in &mut u {
477        if *u_i == f64::INFINITY {
478            *u_i = 0.0;
479        }
480    }
481
482    // Pass 1 matching: match row i to its best column if that column is unmatched
483    // Avoid matching on very dense columns (dense columns are better handled by Dijkstra)
484    let dense_threshold = if n > 50 { n / 10 } else { n };
485    for i in 0..n {
486        let j = best_col_for_row[i];
487        if j == UNMATCHED {
488            continue;
489        }
490        if col_match[j] != UNMATCHED {
491            continue;
492        }
493        let col_degree = graph.col_ptr[j + 1] - graph.col_ptr[j];
494        if col_degree > dense_threshold {
495            continue;
496        }
497        row_match[i] = j;
498        col_match[j] = i;
499    }
500
501    // Pass 2: For each unmatched column, find cheapest reduced-cost assignment
502    // d[j] = min_i (c[i,j] - u[i]) for column j
503    // If cheapest row is unmatched, assign directly
504    // If cheapest row is matched to column jj, try length-2 augmentation
505    let mut d_col = vec![0.0_f64; n]; // current min reduced cost for column j
506    let mut search_from = vec![0usize; n]; // track search position per column
507    search_from[..n].copy_from_slice(&graph.col_ptr[..n]);
508
509    'col_loop: for j in 0..n {
510        if col_match[j] != UNMATCHED {
511            continue;
512        }
513        let col_start = graph.col_ptr[j];
514        let col_end = graph.col_ptr[j + 1];
515        if col_start >= col_end {
516            continue; // empty column
517        }
518
519        // Find row with smallest reduced cost in column j
520        let mut best_i = graph.row_idx[col_start];
521        let mut best_rc = graph.cost[col_start] - u[best_i];
522        let mut best_k = col_start;
523
524        for idx in (col_start + 1)..col_end {
525            let i = graph.row_idx[idx];
526            let rc = graph.cost[idx] - u[i];
527            if rc < best_rc
528                || (rc == best_rc && row_match[i] == UNMATCHED && row_match[best_i] != UNMATCHED)
529            {
530                best_rc = rc;
531                best_i = i;
532                best_k = idx;
533            }
534        }
535
536        d_col[j] = best_rc;
537
538        // If best row is unmatched, assign directly
539        if row_match[best_i] == UNMATCHED {
540            row_match[best_i] = j;
541            col_match[j] = best_i;
542            search_from[j] = best_k + 1;
543            continue;
544        }
545
546        // Try 2-augmentation: for each tied row i matched to column jj,
547        // scan column jj for an unmatched row ii
548        for idx in best_k..col_end {
549            let i = graph.row_idx[idx];
550            let rc = graph.cost[idx] - u[i];
551            if rc > best_rc {
552                continue;
553            }
554            let jj = row_match[i];
555            if jj == UNMATCHED {
556                continue;
557            }
558
559            // Scan column jj for unmatched row
560            let jj_end = graph.col_ptr[jj + 1];
561            for kk in search_from[jj]..jj_end {
562                let ii = graph.row_idx[kk];
563                if row_match[ii] != UNMATCHED {
564                    continue;
565                }
566                let rc_ii = graph.cost[kk] - u[ii];
567                if rc_ii <= d_col[jj] {
568                    // Augment: (i,j) and (ii,jj)
569                    col_match[jj] = ii;
570                    row_match[ii] = jj;
571                    search_from[jj] = kk + 1;
572                    col_match[j] = i;
573                    row_match[i] = j;
574                    search_from[j] = idx + 1;
575                    continue 'col_loop;
576                }
577            }
578            search_from[jj] = jj_end;
579        }
580    }
581
582    MatchingState {
583        row_match,
584        col_match,
585        u,
586    }
587}
588
/// Persistent state across Dijkstra augmentations.
///
/// Arrays are allocated once and selectively reset for touched elements after each
/// augmentation. This avoids O(n) re-initialization per augmentation and, critically,
/// ensures `jperm` is maintained incrementally across successive shortest-path searches.
struct DijkstraState {
    /// d[i] = shortest distance from current root column to row i (INFINITY = untouched)
    d: Vec<f64>,
    /// l[i] = position tracking (0 = not touched, 1..qlen = heap position,
    /// low..up-1 = in Q1, up.. = finalized). Positions are 1-based.
    l: Vec<usize>,
    /// jperm[j] = edge index of matched edge in column j (for O(1) vj computation);
    /// `UNMATCHED` for unmatched columns.
    jperm: Vec<usize>,
    /// pr[j] = parent column along shortest path tree (`UNMATCHED` = root sentinel)
    pr: Vec<usize>,
    /// out[j] = edge index of the discovery edge for column j's traceback
    out: Vec<usize>,
    /// q[0..m] = shared array for heap (0..qlen), Q1 (low..up), finalized (up..m)
    q: Vec<usize>,
    /// Scratch buffer for root column edge indices (avoids per-augmentation allocation).
    root_edges: Vec<usize>,
}
611
612impl DijkstraState {
613    fn new(n: usize) -> Self {
614        Self {
615            d: vec![f64::INFINITY; n],
616            l: vec![0; n],
617            jperm: vec![UNMATCHED; n],
618            pr: vec![UNMATCHED; n],
619            out: vec![0; n],
620            q: vec![0; n],
621            root_edges: Vec::new(),
622        }
623    }
624
625    /// Reset touched rows: Q1/finalized region [low-1..n) and heap region [0..qlen).
626    fn cleanup_touched(&mut self, low: usize, qlen: usize, n: usize) {
627        for k in (low - 1)..n {
628            let i = self.q[k];
629            self.d[i] = f64::INFINITY;
630            self.l[i] = 0;
631        }
632        for k in 0..qlen {
633            let i = self.q[k];
634            self.d[i] = f64::INFINITY;
635            self.l[i] = 0;
636        }
637    }
638
639    /// Initialize jperm from current matching state.
640    fn init_jperm(&mut self, graph: &CostGraph, state: &MatchingState) {
641        let n = graph.n;
642        for j in 0..n {
643            let matched_row = state.col_match[j];
644            if matched_row == UNMATCHED {
645                self.jperm[j] = UNMATCHED;
646                continue;
647            }
648            let col_start = graph.col_ptr[j];
649            let col_end = graph.col_ptr[j + 1];
650            for idx in col_start..col_end {
651                if graph.row_idx[idx] == matched_row {
652                    self.jperm[j] = idx;
653                    break;
654                }
655            }
656        }
657    }
658}
659
/// Find shortest augmenting path from unmatched column `root_col` using Dijkstra
/// on reduced costs and augment the matching if a path is found.
///
/// Returns `true` if augmenting path found, `false` if no path exists.
///
/// Uses Dijkstra on reduced costs with an indexed binary min-heap for
/// O((m + n) log n) augmentation (Duff & Koster 2001, Algorithm MPD).
///
/// Key data structures:
/// - `ds.d[i]`: shortest distance from root column to row i
/// - `ds.l[i]`: position tracking (0=unseen, 1..qlen=heap, low..up-1=Q1, up..=finalized)
/// - `ds.jperm[j]`: edge index of matched edge in column j
/// - `ds.out[j]`: edge index of the discovery edge for column j's traceback
/// - `ds.pr[j]`: parent column in shortest-path tree
fn dijkstra_augment(
    root_col: usize,
    graph: &CostGraph,
    state: &mut MatchingState,
    ds: &mut DijkstraState,
) -> bool {
    let n = graph.n;

    // csp = cost of shortest augmenting path found so far
    let mut csp = f64::INFINITY;
    let mut isp: usize = 0; // edge index at end of shortest augmenting path
    let mut jsp = UNMATCHED; // column entering that row

    // Heap region: q[0..qlen] (indexed min-heap by d[q[i]])
    let mut qlen: usize = 0;
    // Q1 region: q[low-1..up-1] (rows with d == dmin, 0-indexed)
    // Finalized region: q[up-1..m-1]
    // Using 1-based indices for Q1/finalized region tracking, convert at array access
    let mut low: usize = n + 1; // 1-based, initially empty Q1
    let mut up: usize = n + 1; // 1-based, initially empty finalized
    let mut dmin = f64::INFINITY;

    // Scan root column: compute initial reduced costs for adjacent rows
    ds.pr[root_col] = UNMATCHED; // sentinel — path traceback stops here
    let col_start = graph.col_ptr[root_col];
    let col_end = graph.col_ptr[root_col + 1];

    // First pass: compute d[i] for all rows in root column, collect edge refs
    ds.root_edges.clear();
    for idx in col_start..col_end {
        let i = graph.row_idx[idx];
        let dnew = graph.cost[idx] - state.u[i];
        if dnew >= csp {
            continue;
        }
        if state.row_match[i] == UNMATCHED {
            // Direct augmenting edge: root column reaches a free row.
            csp = dnew;
            isp = idx;
            jsp = root_col;
        } else {
            if dnew < dmin {
                dmin = dnew;
            }
            ds.d[i] = dnew;
            ds.root_edges.push(idx);
        }
    }

    // Second pass: partition root-column rows into Q1 (d == dmin) and heap
    for k in 0..ds.root_edges.len() {
        let idx = ds.root_edges[k];
        let i = graph.row_idx[idx];
        if csp <= ds.d[i] {
            // Already beaten by the direct augmenting edge — drop the row.
            ds.d[i] = f64::INFINITY;
            continue;
        }
        if ds.d[i] <= dmin {
            // Add to Q1
            low -= 1;
            ds.q[low - 1] = i; // 0-indexed array access
            ds.l[i] = low; // 1-based position in Q1/finalized region
        } else {
            // Add to heap (q is pre-allocated to size n, so qlen never exceeds q.len())
            qlen += 1;
            ds.l[i] = qlen; // 1-based heap position
            ds.q[qlen - 1] = i;
            // Sift up in heap
            heap_update_inline(i, &mut ds.q, &ds.d, &mut ds.l);
        }
        // Update tree
        let jj = state.row_match[i];
        ds.out[jj] = idx;
        ds.pr[jj] = root_col;
    }

    // Main Dijkstra loop: expand frontier until augmenting path found or exhausted
    for _jdum in 0..n {
        // If Q1 is empty, extract from heap
        if low == up {
            if qlen == 0 {
                break;
            }
            let top_i = ds.q[0];
            if ds.d[top_i] >= csp {
                break;
            }
            dmin = ds.d[top_i];
            // Extract all rows with d == dmin into Q1
            while qlen > 0 {
                let top_i = ds.q[0];
                if ds.d[top_i] > dmin {
                    break;
                }
                // Pop from heap
                let popped = heap_pop_inline(&mut ds.q, &ds.d, &mut ds.l, &mut qlen);
                low -= 1;
                ds.q[low - 1] = popped;
                ds.l[popped] = low;
            }
        }

        // q0 is row from Q1 with distance dmin
        let q0 = ds.q[up - 1 - 1]; // up-1 in 1-based = index (up-2) in 0-based
        let dq0 = ds.d[q0];
        if dq0 >= csp {
            break;
        }
        up -= 1; // Move q0 from Q1 to finalized

        // Scan column matched with row q0
        let j = state.row_match[q0];
        // q0 must be matched (greedy ensures only matched rows enter the heap)

        // Compute vj using jperm for O(1) matched-edge cost lookup
        debug_assert!(
            ds.jperm[j] != UNMATCHED,
            "jperm[{}] not set for matched column",
            j
        );
        let vj = dq0 - graph.cost[ds.jperm[j]] + state.u[q0];

        let col_start_j = graph.col_ptr[j];
        let col_end_j = graph.col_ptr[j + 1];
        for idx in col_start_j..col_end_j {
            let i = graph.row_idx[idx];

            // Skip finalized rows (l[i] >= up)
            if ds.l[i] >= up {
                continue;
            }

            let dnew = vj + graph.cost[idx] - state.u[i];

            if dnew >= csp {
                continue;
            }

            if state.row_match[i] == UNMATCHED {
                // Found a shorter augmenting path ending at free row i.
                csp = dnew;
                isp = idx;
                jsp = j;
            } else {
                // Skip if not improving
                let di = ds.d[i];
                if di <= dnew {
                    continue;
                }
                // Skip if already in Q1 (l[i] >= low)
                if ds.l[i] >= low {
                    continue;
                }

                ds.d[i] = dnew;
                if dnew <= dmin {
                    // Move to Q1
                    let lpos = ds.l[i];
                    if lpos != 0 {
                        // Delete from heap first
                        heap_delete_inline(lpos, &mut ds.q, &ds.d, &mut ds.l, &mut qlen);
                    }
                    low -= 1;
                    ds.q[low - 1] = i;
                    ds.l[i] = low;
                } else {
                    if ds.l[i] == 0 {
                        // New entry in heap (q is pre-allocated to size n)
                        qlen += 1;
                        ds.l[i] = qlen;
                        ds.q[qlen - 1] = i;
                    }
                    // d[i] decreased — sift up
                    heap_update_inline(i, &mut ds.q, &ds.d, &mut ds.l);
                }
                // Update tree
                let jj = state.row_match[i];
                ds.out[jj] = idx;
                ds.pr[jj] = j;
            }
        }
    }

    // If csp = INFINITY, no augmenting path found
    if csp == f64::INFINITY {
        // Reset d[] and l[] for rows visited this augmentation
        ds.cleanup_touched(low, qlen, n);
        return false;
    }

    // Augment matching along path: flip edges from the terminal free row
    // back to the root column following pr[]/out[].
    let mut i = graph.row_idx[isp];
    let mut j = jsp;
    state.row_match[i] = j;
    state.col_match[j] = i;
    ds.jperm[j] = isp;

    loop {
        let jj = ds.pr[j];
        if jj == UNMATCHED {
            break;
        }
        let k = ds.out[j];
        i = graph.row_idx[k];
        state.row_match[i] = jj;
        state.col_match[jj] = i;
        ds.jperm[jj] = k;
        j = jj;
    }

    // Update dual variables for finalized rows (in q[up-1..n-1]):
    // u[i] += d[i] - csp keeps reduced costs non-negative after augmentation.
    for k in (up - 1)..n {
        let i = ds.q[k];
        state.u[i] = state.u[i] + ds.d[i] - csp;
    }

    // Reset state for all touched rows: Q1 + finalized + heap
    ds.cleanup_touched(low, qlen, n);

    true
}
893
/// Sift row `idx` toward the root after its distance value decreased.
/// Heap positions stored in `pos` are 1-based; `q` holds the heap itself.
fn heap_update_inline(idx: usize, q: &mut [usize], d: &[f64], pos: &mut [usize]) {
    let mut slot = pos[idx]; // 1-based position of `idx` within `q`
    if slot <= 1 {
        // Already at (or designated for) the root slot.
        q[0] = idx;
        return;
    }
    let key = d[idx];
    loop {
        let parent_slot = slot / 2;
        let parent_row = q[parent_slot - 1];
        if key >= d[parent_row] {
            break;
        }
        // Shift the parent one level down and keep climbing.
        q[slot - 1] = parent_row;
        pos[parent_row] = slot;
        slot = parent_slot;
        if slot <= 1 {
            break;
        }
    }
    q[slot - 1] = idx;
    pos[idx] = slot;
}
916
917/// Inline heap_pop: extract minimum element from heap.
918/// Returns the row index of the minimum element.
919fn heap_pop_inline(q: &mut [usize], d: &[f64], pos: &mut [usize], qlen: &mut usize) -> usize {
920    let result = q[0];
921    heap_delete_inline(1, q, d, pos, qlen);
922    result
923}
924
/// Delete the element at 1-based heap position `pos0` and restore the
/// min-heap property on `q[0..*qlen]`.
///
/// * `q`    — heap array of row indices (0-based storage, 1-based positions)
/// * `d`    — key values; the heap is ordered by `d[q[..]]` ascending
/// * `pos`  — 1-based position of each row inside `q`, maintained here
/// * `qlen` — current heap length; decremented by one on every call
///
/// The deleted element's own `pos` entry is NOT cleared here; callers appear
/// to reset touched entries separately (see `cleanup_touched` usage) —
/// NOTE(review): confirm before reusing this helper elsewhere.
fn heap_delete_inline(
    pos0: usize,
    q: &mut [usize],
    d: &[f64],
    pos: &mut [usize],
    qlen: &mut usize,
) {
    // Deleting the last slot: just shrink; nothing to re-heapify.
    if *qlen == pos0 {
        *qlen -= 1;
        return;
    }

    // The last element is re-inserted at the freed position `pos0`.
    let last_idx = q[*qlen - 1];
    let v = d[last_idx];
    *qlen -= 1;
    let mut p = pos0;

    // Try to move up
    if p > 1 {
        loop {
            let parent = p / 2;
            let parent_idx = q[parent - 1];
            if v >= d[parent_idx] {
                break;
            }
            // Shift the parent one level down and keep climbing.
            q[p - 1] = parent_idx;
            pos[parent_idx] = p;
            p = parent;
            if p <= 1 {
                break;
            }
        }
    }
    q[p - 1] = last_idx;
    pos[last_idx] = p;
    if p != pos0 {
        return; // Moved up — the subtree below pos0 is untouched.
    }

    // Sift down: the replacement did not belong above pos0, so push it
    // down until both children have keys >= v.
    loop {
        let child = 2 * p;
        if child > *qlen {
            break;
        }
        // Pick the smaller of the (up to two) children.
        let mut child_d = d[q[child - 1]];
        let mut best_child = child;
        if child < *qlen {
            let right_d = d[q[child]]; // q[child+1-1] = q[child]
            if child_d > right_d {
                best_child = child + 1;
                child_d = right_d;
            }
        }
        if v <= child_d {
            break;
        }
        let child_idx = q[best_child - 1];
        q[p - 1] = child_idx;
        pos[child_idx] = p;
        p = best_child;
    }
    q[p - 1] = last_idx;
    pos[last_idx] = p;
}
991
992/// Compute symmetric scaling factors from dual variables.
993///
994/// Following Duff & Pralet (2005), MC64SYM:
995/// `scaling[i] = exp((u[i] + v[i] - col_max_log[i]) / 2)`
996///
997/// where u, v are row/column dual variables in cost domain and col_max_log
998/// are the column maxima used in cost construction.
999fn symmetrize_scaling(u: &[f64], v: &[f64], col_max_log: &[f64]) -> Vec<f64> {
1000    let n = u.len();
1001    let mut scaling = Vec::with_capacity(n);
1002
1003    for i in 0..n {
1004        // row contribution u[i], column contribution (v[i] - col_max_log[i]),
1005        // symmetrized: exp((u[i] + v[i] - col_max_log[i]) / 2)
1006        let log_scale = (u[i] + v[i] - col_max_log[i]) / 2.0;
1007
1008        // Clamp to avoid overflow/underflow in exp
1009        let clamped = log_scale.clamp(-LOG_SCALE_CLAMP, LOG_SCALE_CLAMP);
1010        scaling.push(clamped.exp());
1011    }
1012
1013    scaling
1014}
1015
/// Apply Duff-Pralet correction for unmatched indices in structurally singular matrices.
///
/// For unmatched index i: `scaling[i] = 1.0 / max_k |a[i,k] * scaling[k]|` over
/// matched k. Convention: `1/0 = 1.0`.
///
/// Retained for unit tests of the correction logic in linear-space.
#[cfg(test)]
fn duff_pralet_correction(
    matrix: &SparseColMat<usize, f64>,
    scaling: &mut [f64],
    is_matched: &[bool],
) {
    let n = matrix.nrows();
    let symbolic = matrix.symbolic();
    let values = matrix.val();
    let col_ptrs = symbolic.col_ptr();
    let row_indices = symbolic.row_idx();

    // Scan phase: for each unmatched index i, accumulate
    // log(max_k |a[i,k] * scaling[k]|) over its matched neighbors k.
    // Only matched entries of `scaling` are read here, and only unmatched
    // entries are written in the correction phase below, so no defensive
    // copy of `scaling` is needed (the previous `to_vec()` was redundant).
    let mut log_max = vec![f64::NEG_INFINITY; n];

    for j in 0..n {
        let start = col_ptrs[j];
        let end = col_ptrs[j + 1];
        for k in start..end {
            let i = row_indices[k];
            let abs_val = values[k].abs();
            if abs_val == 0.0 {
                // Explicit zeros contribute nothing to the max.
                continue;
            }
            // Entry (i, j) in upper triangle.
            // If i is unmatched and j is matched:
            if !is_matched[i] && is_matched[j] {
                let contrib = abs_val.ln() + scaling[j].ln();
                if contrib > log_max[i] {
                    log_max[i] = contrib;
                }
            }
            // If j is unmatched and i is matched (implicit symmetric entry):
            if i != j && !is_matched[j] && is_matched[i] {
                let contrib = abs_val.ln() + scaling[i].ln();
                if contrib > log_max[j] {
                    log_max[j] = contrib;
                }
            }
        }
    }

    // Correction phase: overwrite scaling for unmatched indices only.
    for i in 0..n {
        if is_matched[i] {
            continue;
        }
        if log_max[i] == f64::NEG_INFINITY {
            // Isolated row: no matched neighbors — apply the 1/0 = 1 convention.
            scaling[i] = 1.0;
        } else {
            scaling[i] = (-log_max[i]).exp();
        }
    }
}
1081
/// Count singletons, 2-cycles, and longer cycles in a matching permutation.
///
/// For symmetric matrices, the optimal matching decomposes into singletons
/// (σ(i)=i) and 2-cycles (σ(i)=j, σ(j)=i, i≠j). However, the asymmetric
/// cost graph can produce longer cycles that are genuinely optimal (see
/// `dev/mc64-scaling-notes.md`). Returns `(singletons, two_cycles, longer_cycles)`.
///
/// `matching` must be a permutation of `0..matching.len()`.
pub fn count_cycles(matching: &[usize]) -> (usize, usize, usize) {
    let n = matching.len();
    let mut seen = vec![false; n];
    let (mut singles, mut pairs, mut longer) = (0, 0, 0);

    for start in 0..n {
        if seen[start] {
            continue;
        }
        let image = matching[start];
        if image == start {
            // Fixed point of the permutation.
            seen[start] = true;
            singles += 1;
        } else if matching[image] == start {
            // Transposition (start, image).
            seen[start] = true;
            seen[image] = true;
            pairs += 1;
        } else {
            // Cycle of length >= 3: walk it, marking every member.
            longer += 1;
            let mut cur = start;
            loop {
                seen[cur] = true;
                cur = matching[cur];
                if cur == start {
                    break;
                }
            }
        }
    }

    (singles, pairs, longer)
}
1123
1124#[cfg(test)]
1125mod tests {
1126    use super::*;
1127    use faer::sparse::Triplet;
1128
1129    /// Helper: create a symmetric upper-triangular matrix from entries.
1130    /// Only entries with i <= j should be provided.
1131    fn make_upper_tri(n: usize, entries: &[(usize, usize, f64)]) -> SparseColMat<usize, f64> {
1132        let triplets: Vec<_> = entries
1133            .iter()
1134            .map(|&(i, j, v)| Triplet::new(i, j, v))
1135            .collect();
1136        SparseColMat::try_new_from_triplets(n, n, &triplets).unwrap()
1137    }
1138
1139    /// 3x3 symmetric matrix (upper triangle):
1140    /// [4  2  0]
1141    /// [2  5  1]
1142    /// [0  1  3]
1143    fn make_3x3_test() -> SparseColMat<usize, f64> {
1144        make_upper_tri(
1145            3,
1146            &[
1147                (0, 0, 4.0),
1148                (0, 1, 2.0),
1149                (1, 1, 5.0),
1150                (1, 2, 1.0),
1151                (2, 2, 3.0),
1152            ],
1153        )
1154    }
1155
1156    // ---- build_cost_graph tests ----
1157
1158    #[test]
1159    fn test_build_cost_graph_3x3() {
1160        let matrix = make_3x3_test();
1161        let graph = build_cost_graph(&matrix);
1162
1163        assert_eq!(graph.n, 3);
1164
1165        // Full symmetric matrix has entries:
1166        // col 0: rows [0, 1] (diagonal + (0,1))
1167        // col 1: rows [0, 1, 2] (symmetric (0,1) + diagonal + (1,2))
1168        // col 2: rows [1, 2] (symmetric (1,2) + diagonal)
1169
1170        // Check column counts
1171        let col_count = |j: usize| graph.col_ptr[j + 1] - graph.col_ptr[j];
1172        assert_eq!(col_count(0), 2, "col 0 should have 2 entries");
1173        assert_eq!(col_count(1), 3, "col 1 should have 3 entries");
1174        assert_eq!(col_count(2), 2, "col 2 should have 2 entries");
1175
1176        // Column maxima in log domain:
1177        // col 0: max(|4|, |2|) = 4, log(4) = 1.386...
1178        // col 1: max(|2|, |5|, |1|) = 5, log(5) = 1.609...
1179        // col 2: max(|1|, |3|) = 3, log(3) = 1.099...
1180        assert!((graph.col_max_log[0] - 4.0_f64.ln()).abs() < 1e-12);
1181        assert!((graph.col_max_log[1] - 5.0_f64.ln()).abs() < 1e-12);
1182        assert!((graph.col_max_log[2] - 3.0_f64.ln()).abs() < 1e-12);
1183
1184        // All costs should be non-negative
1185        for &c in &graph.cost {
1186            assert!(c >= -1e-14, "cost {} should be non-negative", c);
1187        }
1188
1189        // Diagonal entries should have cost = col_max_log - log|diag|
1190        // (0,0): log(4) - log(4) = 0
1191        // (1,1): log(5) - log(5) = 0
1192        // (2,2): log(3) - log(3) = 0
1193        for j in 0..3 {
1194            let col_start = graph.col_ptr[j];
1195            let col_end = graph.col_ptr[j + 1];
1196            for idx in col_start..col_end {
1197                if graph.row_idx[idx] == j {
1198                    assert!(
1199                        graph.cost[idx].abs() < 1e-12,
1200                        "diagonal ({},{}) cost should be ~0, got {}",
1201                        j,
1202                        j,
1203                        graph.cost[idx]
1204                    );
1205                }
1206            }
1207        }
1208    }
1209
1210    #[test]
1211    fn test_build_cost_graph_includes_diagonal() {
1212        let matrix = make_upper_tri(2, &[(0, 0, 3.0), (0, 1, 1.0), (1, 1, 2.0)]);
1213        let graph = build_cost_graph(&matrix);
1214
1215        // Both diagonals should be present
1216        let mut has_diag = [false; 2];
1217        for (j, diag) in has_diag.iter_mut().enumerate() {
1218            let col_start = graph.col_ptr[j];
1219            let col_end = graph.col_ptr[j + 1];
1220            for idx in col_start..col_end {
1221                if graph.row_idx[idx] == j {
1222                    *diag = true;
1223                }
1224            }
1225        }
1226        assert!(has_diag[0], "diagonal (0,0) missing");
1227        assert!(has_diag[1], "diagonal (1,1) missing");
1228    }
1229
1230    #[test]
1231    fn test_build_cost_graph_symmetric_expansion() {
1232        // Upper triangle only: entries (0,1) and (1,2)
1233        let matrix = make_upper_tri(
1234            3,
1235            &[
1236                (0, 0, 1.0),
1237                (0, 1, 2.0),
1238                (1, 1, 3.0),
1239                (1, 2, 4.0),
1240                (2, 2, 5.0),
1241            ],
1242        );
1243        let graph = build_cost_graph(&matrix);
1244
1245        // Check that (1,0) and (2,1) appear in the expanded graph
1246        let has_entry = |col: usize, row: usize| -> bool {
1247            let start = graph.col_ptr[col];
1248            let end = graph.col_ptr[col + 1];
1249            graph.row_idx[start..end].contains(&row)
1250        };
1251
1252        assert!(has_entry(0, 1), "symmetric entry (1,0) should exist");
1253        assert!(has_entry(1, 0), "entry (0,1) should exist");
1254        assert!(
1255            has_entry(2, 1),
1256            "symmetric entry (1,2) should exist in col 2"
1257        );
1258        assert!(has_entry(1, 2), "entry (2,1) should exist in col 1");
1259    }
1260
1261    // ---- greedy_initial_matching tests ----
1262
1263    #[test]
1264    fn test_greedy_matching_4x4() {
1265        // 4x4 matrix where greedy can find at least 3 matches
1266        let matrix = make_upper_tri(
1267            4,
1268            &[
1269                (0, 0, 10.0),
1270                (0, 1, 1.0),
1271                (1, 1, 8.0),
1272                (1, 2, 2.0),
1273                (2, 2, 6.0),
1274                (2, 3, 3.0),
1275                (3, 3, 5.0),
1276            ],
1277        );
1278        let graph = build_cost_graph(&matrix);
1279        let state = greedy_initial_matching(&graph);
1280
1281        // Count matched
1282        let matched_count = state.row_match.iter().filter(|&&m| m != UNMATCHED).count();
1283        assert!(
1284            matched_count >= 3,
1285            "greedy should match at least 3 of 4, got {}",
1286            matched_count
1287        );
1288
1289        // Verify dual feasibility: u[i] + v[j] <= c[i,j] for all edges
1290        // Note: v is not yet computed in greedy, but u should be valid
1291        // u[i] should be non-negative for cost graphs
1292        for &ui in &state.u {
1293            assert!(ui.is_finite(), "dual u should be finite");
1294        }
1295
1296        // Verify matched edges have zero reduced cost (approximately)
1297        for i in 0..4 {
1298            let j = state.row_match[i];
1299            if j == UNMATCHED {
1300                continue;
1301            }
1302            let col_start = graph.col_ptr[j];
1303            let col_end = graph.col_ptr[j + 1];
1304            for idx in col_start..col_end {
1305                if graph.row_idx[idx] == i {
1306                    break;
1307                }
1308            }
1309        }
1310    }
1311
1312    // ---- dijkstra_augment tests ----
1313
1314    #[test]
1315    fn test_dijkstra_augment_3x3() {
1316        // 3x3 with one unmatched column; verify augmenting path
1317        let matrix = make_upper_tri(
1318            3,
1319            &[
1320                (0, 0, 5.0),
1321                (0, 1, 3.0),
1322                (0, 2, 1.0),
1323                (1, 1, 4.0),
1324                (1, 2, 2.0),
1325                (2, 2, 6.0),
1326            ],
1327        );
1328        let graph = build_cost_graph(&matrix);
1329        let mut state = greedy_initial_matching(&graph);
1330        let mut ds = DijkstraState::new(3);
1331        ds.init_jperm(&graph, &state);
1332
1333        let initial_matched = state.col_match.iter().filter(|&&m| m != UNMATCHED).count();
1334
1335        // Try to augment each unmatched column
1336        let mut augmented = false;
1337        for j in 0..3 {
1338            if state.col_match[j] == UNMATCHED && dijkstra_augment(j, &graph, &mut state, &mut ds) {
1339                augmented = true;
1340            }
1341        }
1342
1343        let final_matched = state.col_match.iter().filter(|&&m| m != UNMATCHED).count();
1344
1345        // Should have found augmenting path if initial matching wasn't perfect
1346        if initial_matched < 3 {
1347            assert!(augmented, "should find augmenting path");
1348            assert!(
1349                final_matched > initial_matched,
1350                "matching size should increase"
1351            );
1352        }
1353
1354        // Dual feasibility: u[i] should remain finite
1355        for &ui in &state.u {
1356            assert!(ui.is_finite(), "dual u should be finite after augmentation");
1357        }
1358    }
1359
1360    // ---- symmetrize_scaling tests ----
1361
1362    #[test]
1363    fn test_symmetrize_scaling_known_duals() {
1364        // Known dual values; verify scaling = exp((u + v - col_max_log) / 2)
1365        let u = vec![0.5, 1.0, 0.0];
1366        let v = vec![0.2, 0.3, 0.8];
1367        let col_max_log = vec![1.0, 1.5, 0.5];
1368
1369        let scaling = symmetrize_scaling(&u, &v, &col_max_log);
1370
1371        for i in 0..3 {
1372            let expected = ((u[i] + v[i] - col_max_log[i]) / 2.0).exp();
1373            assert!(
1374                (scaling[i] - expected).abs() < 1e-12,
1375                "scaling[{}] = {}, expected {}",
1376                i,
1377                scaling[i],
1378                expected
1379            );
1380        }
1381    }
1382
1383    #[test]
1384    fn test_symmetrize_scaling_positive() {
1385        let u = vec![1.0, -0.5, 2.0];
1386        let v = vec![0.5, 1.5, -1.0];
1387        let col_max_log = vec![0.0, 0.0, 0.0];
1388
1389        let scaling = symmetrize_scaling(&u, &v, &col_max_log);
1390
1391        for (i, &s) in scaling.iter().enumerate() {
1392            assert!(s > 0.0, "scaling[{}] = {} should be positive", i, s);
1393            assert!(s.is_finite(), "scaling[{}] should be finite", i);
1394        }
1395    }
1396
1397    // ---- end-to-end mc64_matching tests ----
1398
1399    #[test]
1400    fn test_mc64_diagonal_identity() {
1401        // Diagonal matrix: identity matching, scaling = 1/sqrt(diag)
1402        let matrix = make_upper_tri(3, &[(0, 0, 4.0), (1, 1, 9.0), (2, 2, 1.0)]);
1403
1404        let result = mc64_matching(&matrix, Mc64Job::MaximumProduct).unwrap();
1405        assert_eq!(result.matched, 3);
1406
1407        // Matching should be identity
1408        let (fwd, _) = result.matching.as_ref().arrays();
1409        for (i, &f) in fwd.iter().enumerate() {
1410            assert_eq!(f, i, "diagonal matrix matching should be identity");
1411        }
1412
1413        // All scaling factors positive and finite
1414        for (i, &s) in result.scaling.iter().enumerate() {
1415            assert!(s > 0.0, "scaling[{}] should be positive", i);
1416            assert!(s.is_finite(), "scaling[{}] should be finite", i);
1417        }
1418    }
1419
1420    #[test]
1421    fn test_mc64_tridiagonal_indefinite() {
1422        // 4x4 tridiagonal indefinite:
1423        // [ 2  -1   0   0]
1424        // [-1  -3   2   0]
1425        // [ 0   2   1  -1]
1426        // [ 0   0  -1  -4]
1427        let matrix = make_upper_tri(
1428            4,
1429            &[
1430                (0, 0, 2.0),
1431                (0, 1, -1.0),
1432                (1, 1, -3.0),
1433                (1, 2, 2.0),
1434                (2, 2, 1.0),
1435                (2, 3, -1.0),
1436                (3, 3, -4.0),
1437            ],
1438        );
1439
1440        let result = mc64_matching(&matrix, Mc64Job::MaximumProduct).unwrap();
1441        assert_eq!(result.matched, 4);
1442
1443        // Verify SPRAL scaling properties
1444        verify_scaling_properties(&matrix, &result);
1445    }
1446
1447    #[test]
1448    fn test_mc64_arrow_indefinite() {
1449        // 5x5 arrow: dense first row/col, indefinite
1450        // [10  1  1  1  1]
1451        // [ 1 -3  0  0  0]
1452        // [ 1  0  5  0  0]
1453        // [ 1  0  0 -2  0]
1454        // [ 1  0  0  0  4]
1455        let matrix = make_upper_tri(
1456            5,
1457            &[
1458                (0, 0, 10.0),
1459                (0, 1, 1.0),
1460                (0, 2, 1.0),
1461                (0, 3, 1.0),
1462                (0, 4, 1.0),
1463                (1, 1, -3.0),
1464                (2, 2, 5.0),
1465                (3, 3, -2.0),
1466                (4, 4, 4.0),
1467            ],
1468        );
1469
1470        let result = mc64_matching(&matrix, Mc64Job::MaximumProduct).unwrap();
1471        assert_eq!(result.matched, 5);
1472
1473        verify_scaling_properties(&matrix, &result);
1474    }
1475
1476    #[test]
1477    fn test_mc64_trivial_1x1() {
1478        let matrix = make_upper_tri(1, &[(0, 0, 7.0)]);
1479        let result = mc64_matching(&matrix, Mc64Job::MaximumProduct).unwrap();
1480        assert_eq!(result.matched, 1);
1481        assert_eq!(result.scaling.len(), 1);
1482        assert!(result.scaling[0] > 0.0);
1483    }
1484
1485    #[test]
1486    fn test_mc64_trivial_2x2() {
1487        let matrix = make_upper_tri(2, &[(0, 0, 3.0), (0, 1, 1.0), (1, 1, 5.0)]);
1488        let result = mc64_matching(&matrix, Mc64Job::MaximumProduct).unwrap();
1489        assert_eq!(result.matched, 2);
1490
1491        verify_scaling_properties(&matrix, &result);
1492    }
1493
1494    #[test]
1495    fn test_mc64_not_square_error() {
1496        // Non-square matrix should error
1497        let triplets = vec![Triplet::new(0, 0, 1.0), Triplet::new(0, 1, 2.0)];
1498        let matrix = SparseColMat::try_new_from_triplets(2, 3, &triplets).unwrap();
1499        let result = mc64_matching(&matrix, Mc64Job::MaximumProduct);
1500        assert!(matches!(result, Err(SparseError::NotSquare { .. })));
1501    }
1502
1503    #[test]
1504    fn test_mc64_zero_dim_error() {
1505        let triplets: Vec<Triplet<usize, usize, f64>> = vec![];
1506        let matrix = SparseColMat::try_new_from_triplets(0, 0, &triplets).unwrap();
1507        let result = mc64_matching(&matrix, Mc64Job::MaximumProduct);
1508        assert!(matches!(result, Err(SparseError::InvalidInput { .. })));
1509    }
1510
1511    #[test]
1512    fn test_count_cycles_identity() {
1513        let matching = vec![0, 1, 2, 3];
1514        let (s, c, l) = count_cycles(&matching);
1515        assert_eq!(s, 4);
1516        assert_eq!(c, 0);
1517        assert_eq!(l, 0);
1518    }
1519
1520    #[test]
1521    fn test_count_cycles_two_swaps() {
1522        let matching = vec![1, 0, 3, 2];
1523        let (s, c, l) = count_cycles(&matching);
1524        assert_eq!(s, 0);
1525        assert_eq!(c, 2);
1526        assert_eq!(l, 0);
1527    }
1528
1529    #[test]
1530    fn test_count_cycles_mixed() {
1531        let matching = vec![0, 2, 1, 3, 4];
1532        let (s, c, l) = count_cycles(&matching);
1533        assert_eq!(s, 3); // 0, 3, 4 are singletons
1534        assert_eq!(c, 1); // (1,2) is a 2-cycle
1535        assert_eq!(l, 0);
1536    }
1537
1538    #[test]
1539    fn test_count_cycles_longer_cycle() {
1540        // 3-cycle: 0→1→2→0, plus singleton 3
1541        let matching = vec![1, 2, 0, 3];
1542        let (s, c, l) = count_cycles(&matching);
1543        assert_eq!(s, 1); // 3 is singleton
1544        assert_eq!(c, 0);
1545        assert_eq!(l, 1); // one 3-cycle
1546    }
1547
1548    /// Verify SPRAL scaling properties for a matching result.
1549    /// Delegates to the shared helper in testing::mc64_validation.
1550    fn verify_scaling_properties(matrix: &SparseColMat<usize, f64>, result: &Mc64Result) {
1551        use crate::testing::verify_spral_scaling_properties;
1552        verify_spral_scaling_properties("unit_test", matrix, result);
1553    }
1554
1555    // ---- duff_pralet_correction tests ----
1556
1557    #[test]
1558    fn test_duff_pralet_4x4_singular() {
1559        // 4x4 structurally singular: row/col 3 has no matching
1560        // [4  2  0  1]
1561        // [2  5  1  0]
1562        // [0  1  3  0]
1563        // [1  0  0  0]  <- no diagonal, sparse connections
1564        let matrix = make_upper_tri(
1565            4,
1566            &[
1567                (0, 0, 4.0),
1568                (0, 1, 2.0),
1569                (0, 3, 1.0),
1570                (1, 1, 5.0),
1571                (1, 2, 1.0),
1572                (2, 2, 3.0),
1573            ],
1574        );
1575
1576        let mut scaling = vec![0.5, 0.4, 0.6, 0.0]; // pre-set for matched
1577        let is_matched = vec![true, true, true, false]; // row 3 unmatched
1578
1579        duff_pralet_correction(&matrix, &mut scaling, &is_matched);
1580
1581        // Row 3 connects to row 0 via entry (0,3)=1.0
1582        // scaling[3] = 1.0 / max_k |a[3,k] * scaling[k]| over matched k
1583        // Only connection is (0,3)=1.0, scaling[0]=0.5
1584        // So scaling[3] = 1.0 / |1.0 * 0.5| = 2.0
1585        assert!(scaling[3] > 0.0, "unmatched scaling should be positive");
1586        assert!(scaling[3].is_finite(), "unmatched scaling should be finite");
1587
1588        // Matched scaling should be unchanged
1589        assert!((scaling[0] - 0.5).abs() < 1e-12);
1590        assert!((scaling[1] - 0.4).abs() < 1e-12);
1591        assert!((scaling[2] - 0.6).abs() < 1e-12);
1592    }
1593
1594    #[test]
1595    fn test_duff_pralet_isolated_row() {
1596        // Row with no connections to matched set → scaling = 1.0
1597        let matrix = make_upper_tri(
1598            3,
1599            &[
1600                (0, 0, 4.0),
1601                (1, 1, 5.0),
1602                // No entries connecting row 2 to rows 0 or 1
1603                (2, 2, 3.0),
1604            ],
1605        );
1606
1607        let mut scaling = vec![0.5, 0.4, 0.0];
1608        // Only diagonal entries; if row 2 is unmatched but its only entry is (2,2),
1609        // and column 2 is unmatched too, then no matched connections
1610        let is_matched = vec![true, true, false];
1611
1612        duff_pralet_correction(&matrix, &mut scaling, &is_matched);
1613
1614        // Row 2 has no connections to matched indices (only diagonal which is unmatched)
1615        assert_eq!(
1616            scaling[2], 1.0,
1617            "isolated unmatched row should get scaling 1.0"
1618        );
1619    }
1620
1621    #[test]
1622    fn test_duff_pralet_all_positive() {
1623        let matrix = make_upper_tri(
1624            4,
1625            &[
1626                (0, 0, 4.0),
1627                (0, 1, 2.0),
1628                (0, 3, 1.0),
1629                (1, 1, 5.0),
1630                (1, 2, 1.0),
1631                (2, 2, 3.0),
1632            ],
1633        );
1634
1635        let mut scaling = vec![0.5, 0.4, 0.6, 0.0];
1636        let is_matched = vec![true, true, true, false];
1637
1638        duff_pralet_correction(&matrix, &mut scaling, &is_matched);
1639
1640        for (i, &s) in scaling.iter().enumerate() {
1641            assert!(s > 0.0, "scaling[{}] = {} should be positive", i, s);
1642            assert!(s.is_finite(), "scaling[{}] = {} should be finite", i, s);
1643        }
1644    }
1645
1646    // ---- structurally singular mc64_matching ----
1647
1648    #[test]
1649    fn test_mc64_singular_zero_diagonal() {
1650        // Structurally singular: no diagonal entries, forcing partial matching
1651        // [0  5  0  0]
1652        // [5  0  0  0]
1653        // [0  0  0  3]
1654        // [0  0  3  0]
1655        let matrix = make_upper_tri(4, &[(0, 1, 5.0), (2, 3, 3.0)]);
1656
1657        let result = mc64_matching(&matrix, Mc64Job::MaximumProduct).unwrap();
1658
1659        // Should find matching (even if not perfect)
1660        for (i, &s) in result.scaling.iter().enumerate() {
1661            assert!(s > 0.0, "scaling[{}] should be positive", i);
1662            assert!(s.is_finite(), "scaling[{}] should be finite", i);
1663        }
1664
1665        // Matching should be valid permutation
1666        let (fwd, _) = result.matching.as_ref().arrays();
1667        let mut seen = [false; 4];
1668        for &f in fwd {
1669            assert!(!seen[f], "duplicate in matching");
1670            seen[f] = true;
1671        }
1672    }
1673
1674    // ---- Issue #8: NaN/Inf input rejection ----
1675
1676    #[test]
1677    fn test_mc64_nan_entry_error() {
1678        let triplets = vec![
1679            Triplet::new(0, 0, 4.0),
1680            Triplet::new(0, 1, f64::NAN),
1681            Triplet::new(1, 1, 5.0),
1682        ];
1683        let matrix = SparseColMat::try_new_from_triplets(2, 2, &triplets).unwrap();
1684        let result = mc64_matching(&matrix, Mc64Job::MaximumProduct);
1685        assert!(
1686            matches!(result, Err(SparseError::InvalidInput { .. })),
1687            "NaN entry should produce InvalidInput error"
1688        );
1689    }
1690
1691    #[test]
1692    fn test_mc64_inf_entry_error() {
1693        let triplets = vec![
1694            Triplet::new(0, 0, 4.0),
1695            Triplet::new(0, 1, f64::INFINITY),
1696            Triplet::new(1, 1, 5.0),
1697        ];
1698        let matrix = SparseColMat::try_new_from_triplets(2, 2, &triplets).unwrap();
1699        let result = mc64_matching(&matrix, Mc64Job::MaximumProduct);
1700        assert!(
1701            matches!(result, Err(SparseError::InvalidInput { .. })),
1702            "Inf entry should produce InvalidInput error"
1703        );
1704    }
1705
1706    // ---- Issue #10: greedy matching quality on diagonal matrix ----
1707
1708    #[test]
1709    fn test_greedy_matching_diagonal_perfect() {
1710        // Diagonal matrix: greedy should achieve perfect matching without Dijkstra
1711        let matrix = make_upper_tri(4, &[(0, 0, 10.0), (1, 1, 20.0), (2, 2, 5.0), (3, 3, 15.0)]);
1712        let graph = build_cost_graph(&matrix);
1713        let state = greedy_initial_matching(&graph);
1714
1715        let matched_count = state.row_match.iter().filter(|&&m| m != UNMATCHED).count();
1716        assert_eq!(
1717            matched_count, 4,
1718            "greedy should perfectly match a diagonal matrix"
1719        );
1720
1721        // All should be identity matching (row i matched to col i)
1722        for (i, &j) in state.row_match.iter().enumerate() {
1723            assert_eq!(
1724                j, i,
1725                "diagonal greedy: row {} should match col {}, got {}",
1726                i, i, j
1727            );
1728        }
1729    }
1730
1731    // ---- Issue #11: negative diagonal matrix ----
1732
1733    #[test]
1734    fn test_mc64_negative_diagonal() {
1735        // All-negative diagonal: tests abs() path in cost computation
1736        let matrix = make_upper_tri(3, &[(0, 0, -10.0), (1, 1, -20.0), (2, 2, -5.0)]);
1737        let result = mc64_matching(&matrix, Mc64Job::MaximumProduct).unwrap();
1738        assert_eq!(result.matched, 3);
1739
1740        // Should produce identity matching (diagonals are the only entries)
1741        let (fwd, _) = result.matching.as_ref().arrays();
1742        for (i, &f) in fwd.iter().enumerate() {
1743            assert_eq!(f, i, "negative diagonal should give identity matching");
1744        }
1745
1746        verify_scaling_properties(&matrix, &result);
1747    }
1748
1749    // ---- Issue #12: unmatched index placement in singular permutation ----
1750
1751    #[test]
1752    fn test_singular_unmatched_permutation_valid() {
1753        // 3x3 matrix where only rows 0,1 can match (via off-diagonal)
1754        // Row 2 has no connections except diagonal (which doesn't help with bipartite matching
1755        // when all entries are off-diagonal for the matched pairs)
1756        // [0  5  0]
1757        // [5  0  0]
1758        // [0  0  0]  <- isolated, can only be "placed" in remaining slot
1759        let matrix = make_upper_tri(3, &[(0, 1, 5.0)]);
1760        let result = mc64_matching(&matrix, Mc64Job::MaximumProduct).unwrap();
1761
1762        // Matching must be a valid permutation regardless of match count
1763        let (fwd, inv) = result.matching.as_ref().arrays();
1764        let mut seen = [false; 3];
1765        for &f in fwd {
1766            assert!(f < 3, "fwd index out of range");
1767            assert!(!seen[f], "duplicate in fwd");
1768            seen[f] = true;
1769        }
1770        // fwd and inv should be consistent
1771        for i in 0..3 {
1772            assert_eq!(fwd[inv[i]], i, "fwd[inv[{}]] != {}", i, i);
1773        }
1774    }
1775
1776    #[test]
1777    fn test_second_matching_improves_scaling() {
1778        // Structurally singular: 6x6, row 5 unmatched
1779        // The second matching should produce better duals than the first partial matching
1780        let matrix = make_upper_tri(
1781            6,
1782            &[
1783                (0, 0, 10.0),
1784                (0, 1, 1.0),
1785                (1, 1, 8.0),
1786                (1, 2, 2.0),
1787                (2, 2, 6.0),
1788                (2, 3, 3.0),
1789                (3, 3, 5.0),
1790                (3, 4, 1.0),
1791                (4, 4, 7.0),
1792                (0, 5, 0.1), // Weak connection to row 5
1793            ],
1794        );
1795
1796        let result = mc64_matching(&matrix, Mc64Job::MaximumProduct).unwrap();
1797
1798        // All scaling factors should be positive and finite
1799        for (i, &s) in result.scaling.iter().enumerate() {
1800            assert!(s > 0.0, "scaling[{}] should be positive, got {}", i, s);
1801            assert!(s.is_finite(), "scaling[{}] should be finite, got {}", i, s);
1802        }
1803
1804        // Matched entries should have good scaling (scaled diagonal close to 1)
1805        // Check diagonal entries that exist: |s_i * a_ii * s_i| should be close to 1
1806        let symbolic = matrix.symbolic();
1807        let values = matrix.val();
1808        for j in 0..5 {
1809            // Only check matched indices 0..4
1810            let start = symbolic.col_ptr()[j];
1811            let end = symbolic.col_ptr()[j + 1];
1812            for (k, &row) in symbolic.row_idx()[start..end].iter().enumerate() {
1813                let i = row;
1814                if i == j {
1815                    let scaled = result.scaling[i] * values[start + k].abs() * result.scaling[j];
1816                    assert!(
1817                        scaled <= 1.0 + 1e-10,
1818                        "scaled diagonal ({},{}) = {:.6e} should be <= 1",
1819                        i,
1820                        j,
1821                        scaled
1822                    );
1823                }
1824            }
1825        }
1826    }
1827
1828    #[test]
1829    fn test_is_matched_uses_row_only() {
1830        // Test that is_matched is based on row matching, not OR of row+col.
1831        // Create a matrix where row 3 is unmatched but column 3 is matched.
1832        // [0  5  0  1]
1833        // [5  0  0  0]
1834        // [0  0  4  0]
1835        // [1  0  0  0]  <- row 3 may be unmatched, but col 3 matched by row 0
1836        let matrix = make_upper_tri(4, &[(0, 1, 5.0), (0, 3, 1.0), (2, 2, 4.0)]);
1837
1838        let result = mc64_matching(&matrix, Mc64Job::MaximumProduct).unwrap();
1839
1840        // Verify is_matched consistency: if is_matched[i] is true, row i must
1841        // have a real matching edge
1842        let (fwd, _) = result.matching.as_ref().arrays();
1843        for (i, &fi) in fwd.iter().enumerate().take(4) {
1844            if result.is_matched[i] {
1845                // Row i claims to be matched — fwd[i] should point to a column
1846                // that was actually matched to row i (not an arbitrary assignment)
1847                let j = fi;
1848                assert!(
1849                    j < 4,
1850                    "matched row {} should map to valid column, got {}",
1851                    i,
1852                    j
1853                );
1854            }
1855        }
1856
1857        // All scaling should be positive and finite
1858        for (i, &s) in result.scaling.iter().enumerate() {
1859            assert!(s > 0.0, "scaling[{}] positive", i);
1860            assert!(s.is_finite(), "scaling[{}] finite", i);
1861        }
1862    }
1863}