1use faer::perm::Perm;
29use faer::sparse::SparseColMat;
30
31use crate::error::SparseError;
32
/// Output of [`mc64_matching`].
pub struct Mc64Result {
    /// Permutation whose forward array maps each row `i` to the column it was matched to.
    /// Rows left unmatched (structurally singular input) are assigned the remaining free
    /// columns in ascending order, so this is always a valid permutation.
    pub matching: Perm<usize>,

    /// One positive scaling factor per index, computed as
    /// `exp((u_i + v_i - col_max_log_i) / 2)` with the exponent clamped to
    /// `±LOG_SCALE_CLAMP`; the 1x1 fast path uses `1 / sqrt(|a_00|)` instead.
    pub scaling: Vec<f64>,

    /// Number of matched row/column pairs; equals `n` for a structurally nonsingular matrix.
    pub matched: usize,

    /// `is_matched[i]` is `true` when row `i` received a genuine match (as opposed to a
    /// filler column from the singular-case completion of `matching`).
    pub is_matched: Vec<bool>,
}
59
/// Objective selector for [`mc64_matching`].
///
/// Only the maximum-product objective is implemented; `mc64_matching` currently
/// accepts the value without inspecting it. The enum is `#[non_exhaustive]` so that
/// further MC64 jobs can be added without a breaking change.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum Mc64Job {
    /// Maximise the product of the absolute values of the matched entries.
    MaximumProduct,
}
68
/// Symmetrically expanded sparsity pattern in compressed-column form, carrying the
/// column-relative log cost of every nonzero entry.
struct CostGraph {
    /// Dimension of the (square) matrix.
    n: usize,
    /// Column pointers: entries of column `j` live at `col_ptr[j]..col_ptr[j + 1]`.
    col_ptr: Vec<usize>,
    /// Row index of each stored entry.
    row_idx: Vec<usize>,
    /// Per-entry cost `col_max_log[j] - ln|a_ij|`, which is non-negative.
    cost: Vec<f64>,
    /// Largest `ln|a_ij|` over column `j`; `-inf` when the column has no nonzeros.
    col_max_log: Vec<f64>,
}
85
/// Current (partial) matching together with the row dual variables.
struct MatchingState {
    /// Column matched to each row, or `UNMATCHED`.
    row_match: Vec<usize>,
    /// Row matched to each column, or `UNMATCHED`.
    col_match: Vec<usize>,
    /// Row duals `u_i`; column duals are recovered on demand from matched edges.
    u: Vec<f64>,
}
95
/// Sentinel stored in matching/predecessor arrays for "no partner".
const UNMATCHED: usize = usize::MAX;

/// Bound on the natural-log scaling exponent before `exp` is applied, keeping every
/// scaling factor finite and non-zero in `f64`.
const LOG_SCALE_CLAMP: f64 = 500.0;
100
101pub fn mc64_matching(
122 matrix: &SparseColMat<usize, f64>,
123 _job: Mc64Job,
124) -> Result<Mc64Result, SparseError> {
125 let (nrows, ncols) = (matrix.nrows(), matrix.ncols());
126
127 if nrows != ncols {
129 return Err(SparseError::NotSquare {
130 dims: (nrows, ncols),
131 });
132 }
133 let n = nrows;
134
135 if n == 0 {
136 return Err(SparseError::InvalidInput {
137 reason: "MC64 requires non-empty matrix".to_string(),
138 });
139 }
140
141 let symbolic = matrix.symbolic();
143 let values = matrix.val();
144 for j in 0..n {
145 let start = symbolic.col_ptr()[j];
146 let end = symbolic.col_ptr()[j + 1];
147 for &val in &values[start..end] {
148 if !val.is_finite() {
149 return Err(SparseError::InvalidInput {
150 reason: "MC64 requires finite matrix entries".to_string(),
151 });
152 }
153 }
154 }
155
156 if n == 1 {
158 let has_entry = symbolic.col_ptr()[1] > symbolic.col_ptr()[0];
159 let scale = if has_entry {
160 let val = values[symbolic.col_ptr()[0]];
161 if val.abs() > 0.0 {
162 1.0 / val.abs().sqrt()
163 } else {
164 1.0
165 }
166 } else {
167 1.0
168 };
169 let fwd: Box<[usize]> = vec![0].into_boxed_slice();
170 let inv: Box<[usize]> = vec![0].into_boxed_slice();
171 return Ok(Mc64Result {
172 matching: Perm::new_checked(fwd, inv, 1),
173 scaling: vec![scale],
174 matched: if has_entry { 1 } else { 0 },
175 is_matched: vec![has_entry],
176 });
177 }
178
179 let graph = build_cost_graph(matrix);
181
182 let mut state = greedy_initial_matching(&graph);
184
185 let mut ds = DijkstraState::new(n);
187 ds.init_jperm(&graph, &state);
188
189 for j in 0..n {
191 if state.col_match[j] != UNMATCHED {
192 continue;
193 }
194 dijkstra_augment(j, &graph, &mut state, &mut ds);
195 }
196
197 let matched = state.col_match.iter().filter(|&&m| m != UNMATCHED).count();
199
200 if matched == n {
201 #[cfg(debug_assertions)]
203 assert_dual_feasibility(&graph, &state);
204 let (scaling, fwd, inv) = build_full_match_result(&graph, &state);
205 return Ok(Mc64Result {
206 matching: Perm::new_checked(fwd.into_boxed_slice(), inv.into_boxed_slice(), n),
207 scaling,
208 matched,
209 is_matched: vec![true; n],
210 });
211 }
212
213 let is_row_matched: Vec<bool> = (0..n).map(|i| state.row_match[i] != UNMATCHED).collect();
215
216 #[cfg(debug_assertions)]
218 assert_dual_feasibility(&graph, &state);
219
220 for i in 0..n {
228 if state.row_match[i] == UNMATCHED {
229 state.u[i] = 0.0;
230 }
231 }
232 let v = compute_column_duals(&graph, &state);
233 let scaling = symmetrize_scaling(&state.u, &v, &graph.col_max_log);
234
235 let is_matched = is_row_matched;
238
239 let (fwd, inv) = build_singular_permutation(n, &state, &is_matched);
241
242 Ok(Mc64Result {
243 matching: Perm::new_checked(fwd.into_boxed_slice(), inv.into_boxed_slice(), n),
244 scaling,
245 matched,
246 is_matched,
247 })
248}
249
250fn compute_column_duals(graph: &CostGraph, state: &MatchingState) -> Vec<f64> {
256 let n = graph.n;
257 let mut v = vec![0.0_f64; n];
258 for (j, v_j) in v.iter_mut().enumerate() {
259 let i = state.col_match[j];
260 if i != UNMATCHED {
261 let col_start = graph.col_ptr[j];
262 let col_end = graph.col_ptr[j + 1];
263 for idx in col_start..col_end {
264 if graph.row_idx[idx] == i {
265 *v_j = graph.cost[idx] - state.u[i];
266 break;
267 }
268 }
269 }
270 }
271 v
272}
273
274#[cfg(debug_assertions)]
280fn assert_dual_feasibility(graph: &CostGraph, state: &MatchingState) {
281 let eps = 1e-10;
282 let v = compute_column_duals(graph, state);
283 let n = graph.n;
284
285 for (j, &vj) in v.iter().enumerate().take(n) {
286 let col_start = graph.col_ptr[j];
287 let col_end = graph.col_ptr[j + 1];
288 for idx in col_start..col_end {
289 let i = graph.row_idx[idx];
290 if state.row_match[i] == UNMATCHED {
292 continue;
293 }
294 let slack = graph.cost[idx] - state.u[i] - vj;
295 debug_assert!(
296 slack >= -eps,
297 "dual infeasibility: u[{}] + v[{}] - c[{},{}] = {:.6e} > eps",
298 i,
299 j,
300 i,
301 j,
302 -slack,
303 );
304 }
305 }
306}
307
308fn build_full_match_result(
310 graph: &CostGraph,
311 state: &MatchingState,
312) -> (Vec<f64>, Vec<usize>, Vec<usize>) {
313 let n = graph.n;
314 let v = compute_column_duals(graph, state);
315 let scaling = symmetrize_scaling(&state.u, &v, &graph.col_max_log);
316
317 let mut fwd = vec![0usize; n];
318 for (i, fwd_i) in fwd.iter_mut().enumerate() {
319 *fwd_i = state.row_match[i];
320 }
321
322 let mut inv = vec![0usize; n];
323 for (i, &f) in fwd.iter().enumerate() {
324 inv[f] = i;
325 }
326
327 (scaling, fwd, inv)
328}
329
330fn build_singular_permutation(
333 n: usize,
334 state: &MatchingState,
335 is_matched: &[bool],
336) -> (Vec<usize>, Vec<usize>) {
337 let mut fwd = vec![0usize; n];
338 let mut unmatched_rows: Vec<usize> = Vec::new();
339
340 for (i, fwd_i) in fwd.iter_mut().enumerate() {
341 if state.row_match[i] != UNMATCHED {
342 *fwd_i = state.row_match[i];
343 } else {
344 unmatched_rows.push(i);
345 }
346 }
347
348 let mut used_cols = vec![false; n];
349 for (i, &matched) in is_matched.iter().enumerate() {
350 if matched {
351 used_cols[state.row_match[i]] = true;
352 }
353 }
354 let free_cols: Vec<usize> = (0..n).filter(|&j| !used_cols[j]).collect();
355 for (idx, &i) in unmatched_rows.iter().enumerate() {
356 fwd[i] = free_cols[idx];
357 }
358
359 let mut inv = vec![0usize; n];
360 for (i, &f) in fwd.iter().enumerate() {
361 inv[f] = i;
362 }
363
364 (fwd, inv)
365}
366
367fn build_cost_graph(matrix: &SparseColMat<usize, f64>) -> CostGraph {
372 let n = matrix.nrows();
373 let symbolic = matrix.symbolic();
374 let values = matrix.val();
375 let col_ptrs = symbolic.col_ptr();
376 let row_indices = symbolic.row_idx();
377
378 let mut col_entries: Vec<Vec<(usize, f64)>> = vec![Vec::new(); n];
381
382 for j in 0..n {
383 let start = col_ptrs[j];
384 let end = col_ptrs[j + 1];
385 for k in start..end {
386 let i = row_indices[k];
387 let abs_val = values[k].abs();
388 if abs_val == 0.0 {
389 continue; }
391 col_entries[j].push((i, abs_val));
393 if i != j {
395 col_entries[i].push((j, abs_val));
396 }
397 }
398 }
399
400 for entries in &mut col_entries {
405 entries.sort_by(|a, b| a.0.cmp(&b.0).then_with(|| a.1.total_cmp(&b.1)));
406 entries.dedup_by_key(|entry| entry.0);
407 }
408
409 let mut col_max_log = vec![f64::NEG_INFINITY; n];
411 for j in 0..n {
412 for &(_, abs_val) in &col_entries[j] {
413 let log_val = abs_val.ln();
414 if log_val > col_max_log[j] {
415 col_max_log[j] = log_val;
416 }
417 }
418 }
419
420 let mut col_ptr = Vec::with_capacity(n + 1);
422 let mut row_idx = Vec::new();
423 let mut cost = Vec::new();
424
425 col_ptr.push(0);
426 for j in 0..n {
427 for &(i, abs_val) in &col_entries[j] {
428 let c = col_max_log[j] - abs_val.ln();
429 row_idx.push(i);
430 cost.push(c);
431 }
432 col_ptr.push(row_idx.len());
433 }
434
435 CostGraph {
436 col_ptr,
437 row_idx,
438 cost,
439 col_max_log,
440 n,
441 }
442}
443
444fn greedy_initial_matching(graph: &CostGraph) -> MatchingState {
450 let n = graph.n;
451 let mut row_match = vec![UNMATCHED; n];
452 let mut col_match = vec![UNMATCHED; n];
453 let mut u = vec![f64::INFINITY; n];
454
455 let mut best_col_for_row = vec![UNMATCHED; n]; let mut best_cost_pos = vec![0usize; n]; for j in 0..n {
462 let col_start = graph.col_ptr[j];
463 let col_end = graph.col_ptr[j + 1];
464 for idx in col_start..col_end {
465 let i = graph.row_idx[idx];
466 let c = graph.cost[idx];
467 if c < u[i] {
468 u[i] = c;
469 best_col_for_row[i] = j;
470 best_cost_pos[i] = idx;
471 }
472 }
473 }
474
475 for u_i in &mut u {
477 if *u_i == f64::INFINITY {
478 *u_i = 0.0;
479 }
480 }
481
482 let dense_threshold = if n > 50 { n / 10 } else { n };
485 for i in 0..n {
486 let j = best_col_for_row[i];
487 if j == UNMATCHED {
488 continue;
489 }
490 if col_match[j] != UNMATCHED {
491 continue;
492 }
493 let col_degree = graph.col_ptr[j + 1] - graph.col_ptr[j];
494 if col_degree > dense_threshold {
495 continue;
496 }
497 row_match[i] = j;
498 col_match[j] = i;
499 }
500
501 let mut d_col = vec![0.0_f64; n]; let mut search_from = vec![0usize; n]; search_from[..n].copy_from_slice(&graph.col_ptr[..n]);
508
509 'col_loop: for j in 0..n {
510 if col_match[j] != UNMATCHED {
511 continue;
512 }
513 let col_start = graph.col_ptr[j];
514 let col_end = graph.col_ptr[j + 1];
515 if col_start >= col_end {
516 continue; }
518
519 let mut best_i = graph.row_idx[col_start];
521 let mut best_rc = graph.cost[col_start] - u[best_i];
522 let mut best_k = col_start;
523
524 for idx in (col_start + 1)..col_end {
525 let i = graph.row_idx[idx];
526 let rc = graph.cost[idx] - u[i];
527 if rc < best_rc
528 || (rc == best_rc && row_match[i] == UNMATCHED && row_match[best_i] != UNMATCHED)
529 {
530 best_rc = rc;
531 best_i = i;
532 best_k = idx;
533 }
534 }
535
536 d_col[j] = best_rc;
537
538 if row_match[best_i] == UNMATCHED {
540 row_match[best_i] = j;
541 col_match[j] = best_i;
542 search_from[j] = best_k + 1;
543 continue;
544 }
545
546 for idx in best_k..col_end {
549 let i = graph.row_idx[idx];
550 let rc = graph.cost[idx] - u[i];
551 if rc > best_rc {
552 continue;
553 }
554 let jj = row_match[i];
555 if jj == UNMATCHED {
556 continue;
557 }
558
559 let jj_end = graph.col_ptr[jj + 1];
561 for kk in search_from[jj]..jj_end {
562 let ii = graph.row_idx[kk];
563 if row_match[ii] != UNMATCHED {
564 continue;
565 }
566 let rc_ii = graph.cost[kk] - u[ii];
567 if rc_ii <= d_col[jj] {
568 col_match[jj] = ii;
570 row_match[ii] = jj;
571 search_from[jj] = kk + 1;
572 col_match[j] = i;
573 row_match[i] = j;
574 search_from[j] = idx + 1;
575 continue 'col_loop;
576 }
577 }
578 search_from[jj] = jj_end;
579 }
580 }
581
582 MatchingState {
583 row_match,
584 col_match,
585 u,
586 }
587}
588
/// Reusable workspace for [`dijkstra_augment`], sized for an `n x n` problem.
struct DijkstraState {
    /// Shared row storage: a binary min-heap in `q[0..qlen]`, and rows at the current
    /// minimum distance / already finalised packed at the back (`q[low-1..n]`).
    q: Vec<usize>,
    /// Position of each row inside `q` (1-based heap slot or back-region slot); `0` when
    /// the row is in neither.
    l: Vec<usize>,
    /// Tentative shortest-path distance per row; `INFINITY` when untouched.
    d: Vec<f64>,
    /// Predecessor column in the shortest-path tree, per column (`UNMATCHED` at the root).
    pr: Vec<usize>,
    /// Per column, the graph edge index through which that column's row was reached.
    out: Vec<usize>,
    /// Per column, the graph edge index of its matched edge, or `UNMATCHED`.
    jperm: Vec<usize>,
    /// Scratch list of the root column's edges into matched rows.
    root_edges: Vec<usize>,
}
611
612impl DijkstraState {
613 fn new(n: usize) -> Self {
614 Self {
615 d: vec![f64::INFINITY; n],
616 l: vec![0; n],
617 jperm: vec![UNMATCHED; n],
618 pr: vec![UNMATCHED; n],
619 out: vec![0; n],
620 q: vec![0; n],
621 root_edges: Vec::new(),
622 }
623 }
624
625 fn cleanup_touched(&mut self, low: usize, qlen: usize, n: usize) {
627 for k in (low - 1)..n {
628 let i = self.q[k];
629 self.d[i] = f64::INFINITY;
630 self.l[i] = 0;
631 }
632 for k in 0..qlen {
633 let i = self.q[k];
634 self.d[i] = f64::INFINITY;
635 self.l[i] = 0;
636 }
637 }
638
639 fn init_jperm(&mut self, graph: &CostGraph, state: &MatchingState) {
641 let n = graph.n;
642 for j in 0..n {
643 let matched_row = state.col_match[j];
644 if matched_row == UNMATCHED {
645 self.jperm[j] = UNMATCHED;
646 continue;
647 }
648 let col_start = graph.col_ptr[j];
649 let col_end = graph.col_ptr[j + 1];
650 for idx in col_start..col_end {
651 if graph.row_idx[idx] == matched_row {
652 self.jperm[j] = idx;
653 break;
654 }
655 }
656 }
657 }
658}
659
/// Searches for a shortest augmenting path starting at the unmatched column `root_col`
/// and, if one is found, flips the matching along it and updates the row duals `u`.
///
/// Distances are reduced costs: the first hop is `c_ij - u_i`, and a later hop through
/// matched column `j` adds `c_ij - u_i` on top of `vj`, the distance at which `j` was
/// reached. The search stops as soon as no queued row is closer than `csp`, the best
/// path length to an unmatched row found so far.
///
/// Row storage in `ds.q`: a 1-based binary min-heap occupies `q[0..qlen]`. Rows at the
/// current minimum distance `dmin` occupy `q[low-1..up-1]`. Finalised rows occupy
/// `q[up-1..n]`. `ds.l[i]` holds the row's slot (0 when not queued).
///
/// Returns `true` if the matching grew by one, `false` if no augmenting path exists.
/// All touched entries of `ds.d` / `ds.l` are reset before returning.
fn dijkstra_augment(
    root_col: usize,
    graph: &CostGraph,
    state: &mut MatchingState,
    ds: &mut DijkstraState,
) -> bool {
    let n = graph.n;

    // Best augmenting-path length found so far; `isp` is its final edge (into an
    // unmatched row) and `jsp` the column that edge leaves from.
    let mut csp = f64::INFINITY;
    let mut isp: usize = 0;
    let mut jsp = UNMATCHED;
    let mut qlen: usize = 0;
    // 1-based boundaries of the back region; both start just past the end (empty).
    let mut low: usize = n + 1;
    let mut up: usize = n + 1;
    let mut dmin = f64::INFINITY;

    ds.pr[root_col] = UNMATCHED;
    let col_start = graph.col_ptr[root_col];
    let col_end = graph.col_ptr[root_col + 1];

    // First pass over the root column: record candidate matched rows and the cheapest
    // direct edge to an unmatched row.
    ds.root_edges.clear();
    for idx in col_start..col_end {
        let i = graph.row_idx[idx];
        let dnew = graph.cost[idx] - state.u[i];
        if dnew >= csp {
            continue;
        }
        if state.row_match[i] == UNMATCHED {
            csp = dnew;
            isp = idx;
            jsp = root_col;
        } else {
            if dnew < dmin {
                dmin = dnew;
            }
            ds.d[i] = dnew;
            ds.root_edges.push(idx);
        }
    }

    // Second pass: csp may have shrunk during the first pass, so drop rows that are no
    // longer worth exploring; queue the rest (min-distance rows go to the back region).
    for k in 0..ds.root_edges.len() {
        let idx = ds.root_edges[k];
        let i = graph.row_idx[idx];
        if csp <= ds.d[i] {
            ds.d[i] = f64::INFINITY;
            continue;
        }
        if ds.d[i] <= dmin {
            low -= 1;
            ds.q[low - 1] = i;
            ds.l[i] = low;
        } else {
            qlen += 1;
            ds.l[i] = qlen;
            ds.q[qlen - 1] = i;
            heap_update_inline(i, &mut ds.q, &ds.d, &mut ds.l);
        }
        // Row i is matched to column jj; jj is reached from the root via edge idx.
        let jj = state.row_match[i];
        ds.out[jj] = idx;
        ds.pr[jj] = root_col;
    }

    // Main Dijkstra loop; each iteration finalises at most one row.
    for _jdum in 0..n {
        if low == up {
            // Back region exhausted: refill it with every heap row at the new minimum.
            if qlen == 0 {
                break;
            }
            let top_i = ds.q[0];
            if ds.d[top_i] >= csp {
                break;
            }
            dmin = ds.d[top_i];
            while qlen > 0 {
                let top_i = ds.q[0];
                if ds.d[top_i] > dmin {
                    break;
                }
                let popped = heap_pop_inline(&mut ds.q, &ds.d, &mut ds.l, &mut qlen);
                low -= 1;
                ds.q[low - 1] = popped;
                ds.l[popped] = low;
            }
        }

        // Finalise the row just below the finalised region (1-based slot up - 1).
        let q0 = ds.q[up - 1 - 1];
        let dq0 = ds.d[q0];
        if dq0 >= csp {
            break;
        }
        up -= 1;
        let j = state.row_match[q0];
        debug_assert!(
            ds.jperm[j] != UNMATCHED,
            "jperm[{}] not set for matched column",
            j
        );
        // Distance offset for column j: dq0 minus the reduced cost of j's matched edge.
        let vj = dq0 - graph.cost[ds.jperm[j]] + state.u[q0];

        let col_start_j = graph.col_ptr[j];
        let col_end_j = graph.col_ptr[j + 1];
        for idx in col_start_j..col_end_j {
            let i = graph.row_idx[idx];

            // Already finalised.
            if ds.l[i] >= up {
                continue;
            }

            let dnew = vj + graph.cost[idx] - state.u[i];

            if dnew >= csp {
                continue;
            }

            if state.row_match[i] == UNMATCHED {
                csp = dnew;
                isp = idx;
                jsp = j;
            } else {
                let di = ds.d[i];
                if di <= dnew {
                    continue;
                }
                // Already in the minimum-distance back region.
                if ds.l[i] >= low {
                    continue;
                }

                ds.d[i] = dnew;
                if dnew <= dmin {
                    // Promote to the back region, removing it from the heap if queued.
                    let lpos = ds.l[i];
                    if lpos != 0 {
                        heap_delete_inline(lpos, &mut ds.q, &ds.d, &mut ds.l, &mut qlen);
                    }
                    low -= 1;
                    ds.q[low - 1] = i;
                    ds.l[i] = low;
                } else {
                    // Insert into the heap (if new) and restore heap order (decrease-key).
                    if ds.l[i] == 0 {
                        qlen += 1;
                        ds.l[i] = qlen;
                        ds.q[qlen - 1] = i;
                    }
                    heap_update_inline(i, &mut ds.q, &ds.d, &mut ds.l);
                }
                let jj = state.row_match[i];
                ds.out[jj] = idx;
                ds.pr[jj] = j;
            }
        }
    }

    if csp == f64::INFINITY {
        // No augmenting path from root_col.
        ds.cleanup_touched(low, qlen, n);
        return false;
    }

    // Flip the matching along the path, walking predecessor columns back to the root.
    let mut i = graph.row_idx[isp];
    let mut j = jsp;
    state.row_match[i] = j;
    state.col_match[j] = i;
    ds.jperm[j] = isp;

    loop {
        let jj = ds.pr[j];
        if jj == UNMATCHED {
            break;
        }
        let k = ds.out[j];
        i = graph.row_idx[k];
        state.row_match[i] = jj;
        state.col_match[jj] = i;
        ds.jperm[jj] = k;
        j = jj;
    }

    // Dual update for finalised rows only: u_i += d_i - csp.
    for k in (up - 1)..n {
        let i = ds.q[k];
        state.u[i] = state.u[i] + ds.d[i] - csp;
    }

    ds.cleanup_touched(low, qlen, n);

    true
}
893
/// Sifts row `idx` upward in the 1-based min-heap stored in `q` (keys `d`, slots `pos`)
/// after its key decreased. `pos[idx]` must already hold the row's current slot.
fn heap_update_inline(idx: usize, q: &mut [usize], d: &[f64], pos: &mut [usize]) {
    let start = pos[idx];
    if start <= 1 {
        q[0] = idx;
        return;
    }

    let key = d[idx];
    let mut hole = start;
    while hole > 1 {
        let parent_node = q[hole / 2 - 1];
        if key >= d[parent_node] {
            break;
        }
        q[hole - 1] = parent_node;
        pos[parent_node] = hole;
        hole /= 2;
    }
    q[hole - 1] = idx;
    pos[idx] = hole;
}
916
917fn heap_pop_inline(q: &mut [usize], d: &[f64], pos: &mut [usize], qlen: &mut usize) -> usize {
920 let result = q[0];
921 heap_delete_inline(1, q, d, pos, qlen);
922 result
923}
924
/// Deletes the element at 1-based slot `pos0` from the min-heap in `q[0..*qlen]`.
///
/// The last element is moved into the hole and sifted up, or down if it did not move
/// up. `*qlen` shrinks by one. The deleted row's own `pos` entry is left untouched; the
/// caller is responsible for it.
fn heap_delete_inline(
    pos0: usize,
    q: &mut [usize],
    d: &[f64],
    pos: &mut [usize],
    qlen: &mut usize,
) {
    // Removing the last slot needs no restructuring.
    if pos0 == *qlen {
        *qlen -= 1;
        return;
    }

    let moved = q[*qlen - 1];
    let key = d[moved];
    *qlen -= 1;
    let len = *qlen;

    // Sift up while strictly smaller than the parent.
    let mut slot = pos0;
    while slot > 1 {
        let up = q[slot / 2 - 1];
        if key >= d[up] {
            break;
        }
        q[slot - 1] = up;
        pos[up] = slot;
        slot /= 2;
    }
    q[slot - 1] = moved;
    pos[moved] = slot;
    if slot != pos0 {
        return;
    }

    // Otherwise sift down toward the smaller child.
    loop {
        let left = 2 * slot;
        if left > len {
            break;
        }
        let right = left + 1;
        let (next, next_key) = if left < len && d[q[left - 1]] > d[q[right - 1]] {
            (right, d[q[right - 1]])
        } else {
            (left, d[q[left - 1]])
        };
        if key <= next_key {
            break;
        }
        let down = q[next - 1];
        q[slot - 1] = down;
        pos[down] = slot;
        slot = next;
    }
    q[slot - 1] = moved;
    pos[moved] = slot;
}
991
992fn symmetrize_scaling(u: &[f64], v: &[f64], col_max_log: &[f64]) -> Vec<f64> {
1000 let n = u.len();
1001 let mut scaling = Vec::with_capacity(n);
1002
1003 for i in 0..n {
1004 let log_scale = (u[i] + v[i] - col_max_log[i]) / 2.0;
1007
1008 let clamped = log_scale.clamp(-LOG_SCALE_CLAMP, LOG_SCALE_CLAMP);
1010 scaling.push(clamped.exp());
1011 }
1012
1013 scaling
1014}
1015
/// Test-only Duff–Pralet style correction for unmatched indices.
///
/// For each unmatched index `i`, take the largest `ln|a_ij| + ln(s_j)` over matched
/// neighbours `j`, looking at both the stored entry and its mirror, and set
/// `s_i = exp(-max)`. If `i` has no matched neighbour, `s_i = 1.0`. Scaling factors of
/// matched indices are left unchanged, and the maxima use a snapshot of the input
/// scaling.
#[cfg(test)]
fn duff_pralet_correction(
    matrix: &SparseColMat<usize, f64>,
    scaling: &mut [f64],
    is_matched: &[bool],
) {
    let n = matrix.nrows();
    let structure = matrix.symbolic();
    let col_ptrs = structure.col_ptr();
    let row_indices = structure.row_idx();
    let entries = matrix.val();

    let base = scaling.to_vec();
    let mut strongest = vec![f64::NEG_INFINITY; n];

    {
        let mut offer = |target: usize, source: usize, magnitude: f64| {
            let candidate = magnitude.ln() + base[source].ln();
            if candidate > strongest[target] {
                strongest[target] = candidate;
            }
        };

        for col in 0..n {
            for k in col_ptrs[col]..col_ptrs[col + 1] {
                let row = row_indices[k];
                let magnitude = entries[k].abs();
                if magnitude == 0.0 {
                    continue;
                }
                if !is_matched[row] && is_matched[col] {
                    offer(row, col, magnitude);
                }
                if row != col && !is_matched[col] && is_matched[row] {
                    offer(col, row, magnitude);
                }
            }
        }
    }

    for i in 0..n {
        if is_matched[i] {
            continue;
        }
        scaling[i] = if strongest[i] == f64::NEG_INFINITY {
            1.0
        } else {
            (-strongest[i]).exp()
        };
    }
}
1081
/// Counts the cycle structure of a permutation given as `matching[i] = image of i`.
///
/// Returns `(singletons, two_cycles, longer_cycles)`: fixed points, transpositions, and
/// cycles of length three or more.
///
/// The cycle walk stops at the first already-visited index. A valid permutation
/// therefore gets exactly the same counts as a walk back to the start. An input that is
/// not a permutation (e.g. `[1, 2, 2]`) terminates instead of looping forever, but its
/// counts are not meaningful.
///
/// # Panics
/// If any entry is `>= matching.len()`.
pub fn count_cycles(matching: &[usize]) -> (usize, usize, usize) {
    let n = matching.len();
    let mut visited = vec![false; n];
    let (mut singletons, mut two_cycles, mut longer_cycles) = (0, 0, 0);

    for start in 0..n {
        if visited[start] {
            continue;
        }
        let next = matching[start];
        if next == start {
            singletons += 1;
            visited[start] = true;
        } else if matching[next] == start {
            two_cycles += 1;
            visited[start] = true;
            visited[next] = true;
        } else {
            longer_cycles += 1;
            let mut k = start;
            while !visited[k] {
                visited[k] = true;
                k = matching[k];
            }
        }
    }

    (singletons, two_cycles, longer_cycles)
}
1123
1124#[cfg(test)]
1125mod tests {
1126 use super::*;
1127 use faer::sparse::Triplet;
1128
1129 fn make_upper_tri(n: usize, entries: &[(usize, usize, f64)]) -> SparseColMat<usize, f64> {
1132 let triplets: Vec<_> = entries
1133 .iter()
1134 .map(|&(i, j, v)| Triplet::new(i, j, v))
1135 .collect();
1136 SparseColMat::try_new_from_triplets(n, n, &triplets).unwrap()
1137 }
1138
1139 fn make_3x3_test() -> SparseColMat<usize, f64> {
1144 make_upper_tri(
1145 3,
1146 &[
1147 (0, 0, 4.0),
1148 (0, 1, 2.0),
1149 (1, 1, 5.0),
1150 (1, 2, 1.0),
1151 (2, 2, 3.0),
1152 ],
1153 )
1154 }
1155
1156 #[test]
1159 fn test_build_cost_graph_3x3() {
1160 let matrix = make_3x3_test();
1161 let graph = build_cost_graph(&matrix);
1162
1163 assert_eq!(graph.n, 3);
1164
1165 let col_count = |j: usize| graph.col_ptr[j + 1] - graph.col_ptr[j];
1172 assert_eq!(col_count(0), 2, "col 0 should have 2 entries");
1173 assert_eq!(col_count(1), 3, "col 1 should have 3 entries");
1174 assert_eq!(col_count(2), 2, "col 2 should have 2 entries");
1175
1176 assert!((graph.col_max_log[0] - 4.0_f64.ln()).abs() < 1e-12);
1181 assert!((graph.col_max_log[1] - 5.0_f64.ln()).abs() < 1e-12);
1182 assert!((graph.col_max_log[2] - 3.0_f64.ln()).abs() < 1e-12);
1183
1184 for &c in &graph.cost {
1186 assert!(c >= -1e-14, "cost {} should be non-negative", c);
1187 }
1188
1189 for j in 0..3 {
1194 let col_start = graph.col_ptr[j];
1195 let col_end = graph.col_ptr[j + 1];
1196 for idx in col_start..col_end {
1197 if graph.row_idx[idx] == j {
1198 assert!(
1199 graph.cost[idx].abs() < 1e-12,
1200 "diagonal ({},{}) cost should be ~0, got {}",
1201 j,
1202 j,
1203 graph.cost[idx]
1204 );
1205 }
1206 }
1207 }
1208 }
1209
1210 #[test]
1211 fn test_build_cost_graph_includes_diagonal() {
1212 let matrix = make_upper_tri(2, &[(0, 0, 3.0), (0, 1, 1.0), (1, 1, 2.0)]);
1213 let graph = build_cost_graph(&matrix);
1214
1215 let mut has_diag = [false; 2];
1217 for (j, diag) in has_diag.iter_mut().enumerate() {
1218 let col_start = graph.col_ptr[j];
1219 let col_end = graph.col_ptr[j + 1];
1220 for idx in col_start..col_end {
1221 if graph.row_idx[idx] == j {
1222 *diag = true;
1223 }
1224 }
1225 }
1226 assert!(has_diag[0], "diagonal (0,0) missing");
1227 assert!(has_diag[1], "diagonal (1,1) missing");
1228 }
1229
1230 #[test]
1231 fn test_build_cost_graph_symmetric_expansion() {
1232 let matrix = make_upper_tri(
1234 3,
1235 &[
1236 (0, 0, 1.0),
1237 (0, 1, 2.0),
1238 (1, 1, 3.0),
1239 (1, 2, 4.0),
1240 (2, 2, 5.0),
1241 ],
1242 );
1243 let graph = build_cost_graph(&matrix);
1244
1245 let has_entry = |col: usize, row: usize| -> bool {
1247 let start = graph.col_ptr[col];
1248 let end = graph.col_ptr[col + 1];
1249 graph.row_idx[start..end].contains(&row)
1250 };
1251
1252 assert!(has_entry(0, 1), "symmetric entry (1,0) should exist");
1253 assert!(has_entry(1, 0), "entry (0,1) should exist");
1254 assert!(
1255 has_entry(2, 1),
1256 "symmetric entry (1,2) should exist in col 2"
1257 );
1258 assert!(has_entry(1, 2), "entry (2,1) should exist in col 1");
1259 }
1260
1261 #[test]
1264 fn test_greedy_matching_4x4() {
1265 let matrix = make_upper_tri(
1267 4,
1268 &[
1269 (0, 0, 10.0),
1270 (0, 1, 1.0),
1271 (1, 1, 8.0),
1272 (1, 2, 2.0),
1273 (2, 2, 6.0),
1274 (2, 3, 3.0),
1275 (3, 3, 5.0),
1276 ],
1277 );
1278 let graph = build_cost_graph(&matrix);
1279 let state = greedy_initial_matching(&graph);
1280
1281 let matched_count = state.row_match.iter().filter(|&&m| m != UNMATCHED).count();
1283 assert!(
1284 matched_count >= 3,
1285 "greedy should match at least 3 of 4, got {}",
1286 matched_count
1287 );
1288
1289 for &ui in &state.u {
1293 assert!(ui.is_finite(), "dual u should be finite");
1294 }
1295
1296 for i in 0..4 {
1298 let j = state.row_match[i];
1299 if j == UNMATCHED {
1300 continue;
1301 }
1302 let col_start = graph.col_ptr[j];
1303 let col_end = graph.col_ptr[j + 1];
1304 for idx in col_start..col_end {
1305 if graph.row_idx[idx] == i {
1306 break;
1307 }
1308 }
1309 }
1310 }
1311
1312 #[test]
1315 fn test_dijkstra_augment_3x3() {
1316 let matrix = make_upper_tri(
1318 3,
1319 &[
1320 (0, 0, 5.0),
1321 (0, 1, 3.0),
1322 (0, 2, 1.0),
1323 (1, 1, 4.0),
1324 (1, 2, 2.0),
1325 (2, 2, 6.0),
1326 ],
1327 );
1328 let graph = build_cost_graph(&matrix);
1329 let mut state = greedy_initial_matching(&graph);
1330 let mut ds = DijkstraState::new(3);
1331 ds.init_jperm(&graph, &state);
1332
1333 let initial_matched = state.col_match.iter().filter(|&&m| m != UNMATCHED).count();
1334
1335 let mut augmented = false;
1337 for j in 0..3 {
1338 if state.col_match[j] == UNMATCHED && dijkstra_augment(j, &graph, &mut state, &mut ds) {
1339 augmented = true;
1340 }
1341 }
1342
1343 let final_matched = state.col_match.iter().filter(|&&m| m != UNMATCHED).count();
1344
1345 if initial_matched < 3 {
1347 assert!(augmented, "should find augmenting path");
1348 assert!(
1349 final_matched > initial_matched,
1350 "matching size should increase"
1351 );
1352 }
1353
1354 for &ui in &state.u {
1356 assert!(ui.is_finite(), "dual u should be finite after augmentation");
1357 }
1358 }
1359
1360 #[test]
1363 fn test_symmetrize_scaling_known_duals() {
1364 let u = vec![0.5, 1.0, 0.0];
1366 let v = vec![0.2, 0.3, 0.8];
1367 let col_max_log = vec![1.0, 1.5, 0.5];
1368
1369 let scaling = symmetrize_scaling(&u, &v, &col_max_log);
1370
1371 for i in 0..3 {
1372 let expected = ((u[i] + v[i] - col_max_log[i]) / 2.0).exp();
1373 assert!(
1374 (scaling[i] - expected).abs() < 1e-12,
1375 "scaling[{}] = {}, expected {}",
1376 i,
1377 scaling[i],
1378 expected
1379 );
1380 }
1381 }
1382
1383 #[test]
1384 fn test_symmetrize_scaling_positive() {
1385 let u = vec![1.0, -0.5, 2.0];
1386 let v = vec![0.5, 1.5, -1.0];
1387 let col_max_log = vec![0.0, 0.0, 0.0];
1388
1389 let scaling = symmetrize_scaling(&u, &v, &col_max_log);
1390
1391 for (i, &s) in scaling.iter().enumerate() {
1392 assert!(s > 0.0, "scaling[{}] = {} should be positive", i, s);
1393 assert!(s.is_finite(), "scaling[{}] should be finite", i);
1394 }
1395 }
1396
1397 #[test]
1400 fn test_mc64_diagonal_identity() {
1401 let matrix = make_upper_tri(3, &[(0, 0, 4.0), (1, 1, 9.0), (2, 2, 1.0)]);
1403
1404 let result = mc64_matching(&matrix, Mc64Job::MaximumProduct).unwrap();
1405 assert_eq!(result.matched, 3);
1406
1407 let (fwd, _) = result.matching.as_ref().arrays();
1409 for (i, &f) in fwd.iter().enumerate() {
1410 assert_eq!(f, i, "diagonal matrix matching should be identity");
1411 }
1412
1413 for (i, &s) in result.scaling.iter().enumerate() {
1415 assert!(s > 0.0, "scaling[{}] should be positive", i);
1416 assert!(s.is_finite(), "scaling[{}] should be finite", i);
1417 }
1418 }
1419
1420 #[test]
1421 fn test_mc64_tridiagonal_indefinite() {
1422 let matrix = make_upper_tri(
1428 4,
1429 &[
1430 (0, 0, 2.0),
1431 (0, 1, -1.0),
1432 (1, 1, -3.0),
1433 (1, 2, 2.0),
1434 (2, 2, 1.0),
1435 (2, 3, -1.0),
1436 (3, 3, -4.0),
1437 ],
1438 );
1439
1440 let result = mc64_matching(&matrix, Mc64Job::MaximumProduct).unwrap();
1441 assert_eq!(result.matched, 4);
1442
1443 verify_scaling_properties(&matrix, &result);
1445 }
1446
1447 #[test]
1448 fn test_mc64_arrow_indefinite() {
1449 let matrix = make_upper_tri(
1456 5,
1457 &[
1458 (0, 0, 10.0),
1459 (0, 1, 1.0),
1460 (0, 2, 1.0),
1461 (0, 3, 1.0),
1462 (0, 4, 1.0),
1463 (1, 1, -3.0),
1464 (2, 2, 5.0),
1465 (3, 3, -2.0),
1466 (4, 4, 4.0),
1467 ],
1468 );
1469
1470 let result = mc64_matching(&matrix, Mc64Job::MaximumProduct).unwrap();
1471 assert_eq!(result.matched, 5);
1472
1473 verify_scaling_properties(&matrix, &result);
1474 }
1475
1476 #[test]
1477 fn test_mc64_trivial_1x1() {
1478 let matrix = make_upper_tri(1, &[(0, 0, 7.0)]);
1479 let result = mc64_matching(&matrix, Mc64Job::MaximumProduct).unwrap();
1480 assert_eq!(result.matched, 1);
1481 assert_eq!(result.scaling.len(), 1);
1482 assert!(result.scaling[0] > 0.0);
1483 }
1484
1485 #[test]
1486 fn test_mc64_trivial_2x2() {
1487 let matrix = make_upper_tri(2, &[(0, 0, 3.0), (0, 1, 1.0), (1, 1, 5.0)]);
1488 let result = mc64_matching(&matrix, Mc64Job::MaximumProduct).unwrap();
1489 assert_eq!(result.matched, 2);
1490
1491 verify_scaling_properties(&matrix, &result);
1492 }
1493
1494 #[test]
1495 fn test_mc64_not_square_error() {
1496 let triplets = vec![Triplet::new(0, 0, 1.0), Triplet::new(0, 1, 2.0)];
1498 let matrix = SparseColMat::try_new_from_triplets(2, 3, &triplets).unwrap();
1499 let result = mc64_matching(&matrix, Mc64Job::MaximumProduct);
1500 assert!(matches!(result, Err(SparseError::NotSquare { .. })));
1501 }
1502
1503 #[test]
1504 fn test_mc64_zero_dim_error() {
1505 let triplets: Vec<Triplet<usize, usize, f64>> = vec![];
1506 let matrix = SparseColMat::try_new_from_triplets(0, 0, &triplets).unwrap();
1507 let result = mc64_matching(&matrix, Mc64Job::MaximumProduct);
1508 assert!(matches!(result, Err(SparseError::InvalidInput { .. })));
1509 }
1510
1511 #[test]
1512 fn test_count_cycles_identity() {
1513 let matching = vec![0, 1, 2, 3];
1514 let (s, c, l) = count_cycles(&matching);
1515 assert_eq!(s, 4);
1516 assert_eq!(c, 0);
1517 assert_eq!(l, 0);
1518 }
1519
1520 #[test]
1521 fn test_count_cycles_two_swaps() {
1522 let matching = vec![1, 0, 3, 2];
1523 let (s, c, l) = count_cycles(&matching);
1524 assert_eq!(s, 0);
1525 assert_eq!(c, 2);
1526 assert_eq!(l, 0);
1527 }
1528
1529 #[test]
1530 fn test_count_cycles_mixed() {
1531 let matching = vec![0, 2, 1, 3, 4];
1532 let (s, c, l) = count_cycles(&matching);
1533 assert_eq!(s, 3); assert_eq!(c, 1); assert_eq!(l, 0);
1536 }
1537
1538 #[test]
1539 fn test_count_cycles_longer_cycle() {
1540 let matching = vec![1, 2, 0, 3];
1542 let (s, c, l) = count_cycles(&matching);
1543 assert_eq!(s, 1); assert_eq!(c, 0);
1545 assert_eq!(l, 1); }
1547
1548 fn verify_scaling_properties(matrix: &SparseColMat<usize, f64>, result: &Mc64Result) {
1551 use crate::testing::verify_spral_scaling_properties;
1552 verify_spral_scaling_properties("unit_test", matrix, result);
1553 }
1554
1555 #[test]
1558 fn test_duff_pralet_4x4_singular() {
1559 let matrix = make_upper_tri(
1565 4,
1566 &[
1567 (0, 0, 4.0),
1568 (0, 1, 2.0),
1569 (0, 3, 1.0),
1570 (1, 1, 5.0),
1571 (1, 2, 1.0),
1572 (2, 2, 3.0),
1573 ],
1574 );
1575
1576 let mut scaling = vec![0.5, 0.4, 0.6, 0.0]; let is_matched = vec![true, true, true, false]; duff_pralet_correction(&matrix, &mut scaling, &is_matched);
1580
1581 assert!(scaling[3] > 0.0, "unmatched scaling should be positive");
1586 assert!(scaling[3].is_finite(), "unmatched scaling should be finite");
1587
1588 assert!((scaling[0] - 0.5).abs() < 1e-12);
1590 assert!((scaling[1] - 0.4).abs() < 1e-12);
1591 assert!((scaling[2] - 0.6).abs() < 1e-12);
1592 }
1593
1594 #[test]
1595 fn test_duff_pralet_isolated_row() {
1596 let matrix = make_upper_tri(
1598 3,
1599 &[
1600 (0, 0, 4.0),
1601 (1, 1, 5.0),
1602 (2, 2, 3.0),
1604 ],
1605 );
1606
1607 let mut scaling = vec![0.5, 0.4, 0.0];
1608 let is_matched = vec![true, true, false];
1611
1612 duff_pralet_correction(&matrix, &mut scaling, &is_matched);
1613
1614 assert_eq!(
1616 scaling[2], 1.0,
1617 "isolated unmatched row should get scaling 1.0"
1618 );
1619 }
1620
1621 #[test]
1622 fn test_duff_pralet_all_positive() {
1623 let matrix = make_upper_tri(
1624 4,
1625 &[
1626 (0, 0, 4.0),
1627 (0, 1, 2.0),
1628 (0, 3, 1.0),
1629 (1, 1, 5.0),
1630 (1, 2, 1.0),
1631 (2, 2, 3.0),
1632 ],
1633 );
1634
1635 let mut scaling = vec![0.5, 0.4, 0.6, 0.0];
1636 let is_matched = vec![true, true, true, false];
1637
1638 duff_pralet_correction(&matrix, &mut scaling, &is_matched);
1639
1640 for (i, &s) in scaling.iter().enumerate() {
1641 assert!(s > 0.0, "scaling[{}] = {} should be positive", i, s);
1642 assert!(s.is_finite(), "scaling[{}] = {} should be finite", i, s);
1643 }
1644 }
1645
1646 #[test]
1649 fn test_mc64_singular_zero_diagonal() {
1650 let matrix = make_upper_tri(4, &[(0, 1, 5.0), (2, 3, 3.0)]);
1656
1657 let result = mc64_matching(&matrix, Mc64Job::MaximumProduct).unwrap();
1658
1659 for (i, &s) in result.scaling.iter().enumerate() {
1661 assert!(s > 0.0, "scaling[{}] should be positive", i);
1662 assert!(s.is_finite(), "scaling[{}] should be finite", i);
1663 }
1664
1665 let (fwd, _) = result.matching.as_ref().arrays();
1667 let mut seen = [false; 4];
1668 for &f in fwd {
1669 assert!(!seen[f], "duplicate in matching");
1670 seen[f] = true;
1671 }
1672 }
1673
1674 #[test]
1677 fn test_mc64_nan_entry_error() {
1678 let triplets = vec![
1679 Triplet::new(0, 0, 4.0),
1680 Triplet::new(0, 1, f64::NAN),
1681 Triplet::new(1, 1, 5.0),
1682 ];
1683 let matrix = SparseColMat::try_new_from_triplets(2, 2, &triplets).unwrap();
1684 let result = mc64_matching(&matrix, Mc64Job::MaximumProduct);
1685 assert!(
1686 matches!(result, Err(SparseError::InvalidInput { .. })),
1687 "NaN entry should produce InvalidInput error"
1688 );
1689 }
1690
1691 #[test]
1692 fn test_mc64_inf_entry_error() {
1693 let triplets = vec![
1694 Triplet::new(0, 0, 4.0),
1695 Triplet::new(0, 1, f64::INFINITY),
1696 Triplet::new(1, 1, 5.0),
1697 ];
1698 let matrix = SparseColMat::try_new_from_triplets(2, 2, &triplets).unwrap();
1699 let result = mc64_matching(&matrix, Mc64Job::MaximumProduct);
1700 assert!(
1701 matches!(result, Err(SparseError::InvalidInput { .. })),
1702 "Inf entry should produce InvalidInput error"
1703 );
1704 }
1705
1706 #[test]
1709 fn test_greedy_matching_diagonal_perfect() {
1710 let matrix = make_upper_tri(4, &[(0, 0, 10.0), (1, 1, 20.0), (2, 2, 5.0), (3, 3, 15.0)]);
1712 let graph = build_cost_graph(&matrix);
1713 let state = greedy_initial_matching(&graph);
1714
1715 let matched_count = state.row_match.iter().filter(|&&m| m != UNMATCHED).count();
1716 assert_eq!(
1717 matched_count, 4,
1718 "greedy should perfectly match a diagonal matrix"
1719 );
1720
1721 for (i, &j) in state.row_match.iter().enumerate() {
1723 assert_eq!(
1724 j, i,
1725 "diagonal greedy: row {} should match col {}, got {}",
1726 i, i, j
1727 );
1728 }
1729 }
1730
1731 #[test]
1734 fn test_mc64_negative_diagonal() {
1735 let matrix = make_upper_tri(3, &[(0, 0, -10.0), (1, 1, -20.0), (2, 2, -5.0)]);
1737 let result = mc64_matching(&matrix, Mc64Job::MaximumProduct).unwrap();
1738 assert_eq!(result.matched, 3);
1739
1740 let (fwd, _) = result.matching.as_ref().arrays();
1742 for (i, &f) in fwd.iter().enumerate() {
1743 assert_eq!(f, i, "negative diagonal should give identity matching");
1744 }
1745
1746 verify_scaling_properties(&matrix, &result);
1747 }
1748
1749 #[test]
1752 fn test_singular_unmatched_permutation_valid() {
1753 let matrix = make_upper_tri(3, &[(0, 1, 5.0)]);
1760 let result = mc64_matching(&matrix, Mc64Job::MaximumProduct).unwrap();
1761
1762 let (fwd, inv) = result.matching.as_ref().arrays();
1764 let mut seen = [false; 3];
1765 for &f in fwd {
1766 assert!(f < 3, "fwd index out of range");
1767 assert!(!seen[f], "duplicate in fwd");
1768 seen[f] = true;
1769 }
1770 for i in 0..3 {
1772 assert_eq!(fwd[inv[i]], i, "fwd[inv[{}]] != {}", i, i);
1773 }
1774 }
1775
1776 #[test]
1777 fn test_second_matching_improves_scaling() {
1778 let matrix = make_upper_tri(
1781 6,
1782 &[
1783 (0, 0, 10.0),
1784 (0, 1, 1.0),
1785 (1, 1, 8.0),
1786 (1, 2, 2.0),
1787 (2, 2, 6.0),
1788 (2, 3, 3.0),
1789 (3, 3, 5.0),
1790 (3, 4, 1.0),
1791 (4, 4, 7.0),
1792 (0, 5, 0.1), ],
1794 );
1795
1796 let result = mc64_matching(&matrix, Mc64Job::MaximumProduct).unwrap();
1797
1798 for (i, &s) in result.scaling.iter().enumerate() {
1800 assert!(s > 0.0, "scaling[{}] should be positive, got {}", i, s);
1801 assert!(s.is_finite(), "scaling[{}] should be finite, got {}", i, s);
1802 }
1803
1804 let symbolic = matrix.symbolic();
1807 let values = matrix.val();
1808 for j in 0..5 {
1809 let start = symbolic.col_ptr()[j];
1811 let end = symbolic.col_ptr()[j + 1];
1812 for (k, &row) in symbolic.row_idx()[start..end].iter().enumerate() {
1813 let i = row;
1814 if i == j {
1815 let scaled = result.scaling[i] * values[start + k].abs() * result.scaling[j];
1816 assert!(
1817 scaled <= 1.0 + 1e-10,
1818 "scaled diagonal ({},{}) = {:.6e} should be <= 1",
1819 i,
1820 j,
1821 scaled
1822 );
1823 }
1824 }
1825 }
1826 }
1827
1828 #[test]
1829 fn test_is_matched_uses_row_only() {
1830 let matrix = make_upper_tri(4, &[(0, 1, 5.0), (0, 3, 1.0), (2, 2, 4.0)]);
1837
1838 let result = mc64_matching(&matrix, Mc64Job::MaximumProduct).unwrap();
1839
1840 let (fwd, _) = result.matching.as_ref().arrays();
1843 for (i, &fi) in fwd.iter().enumerate().take(4) {
1844 if result.is_matched[i] {
1845 let j = fi;
1848 assert!(
1849 j < 4,
1850 "matched row {} should map to valid column, got {}",
1851 i,
1852 j
1853 );
1854 }
1855 }
1856
1857 for (i, &s) in result.scaling.iter().enumerate() {
1859 assert!(s > 0.0, "scaling[{}] positive", i);
1860 assert!(s.is_finite(), "scaling[{}] finite", i);
1861 }
1862 }
1863}