use crate::error::{VisionError, VisionResult};
/// Solves the linear assignment problem: finds a minimum-total-cost matching
/// between workers (rows) and jobs (columns) of a dense, row-major
/// `n_workers x n_jobs` cost matrix.
///
/// Rectangular inputs are padded to a square `n x n` matrix, with
/// `n = max(n_workers, n_jobs)`, whose extra cells all hold one constant
/// dummy cost. Every perfect matching of the padded matrix uses the same
/// number of dummy cells, so that constant does not change which real
/// assignment is optimal. The padding costs `O(n^2)` memory and the solver
/// `O(n^3)` time, so strongly unbalanced shapes are comparatively expensive.
///
/// Returns one entry per worker: the 0-based job assigned to it, or
/// `usize::MAX` when the worker is left unassigned (which can only happen
/// when `n_workers > n_jobs`). No job is assigned twice.
///
/// # Errors
///
/// - [`VisionError::EmptyInput`] if `n_workers == 0` or `n_jobs == 0`.
/// - [`VisionError::DimensionMismatch`] if `cost.len() != n_workers * n_jobs`;
///   if that product overflows `usize`, `expected` is reported as `usize::MAX`.
/// - [`VisionError::NonFinite`] if any cost is NaN, `+inf` or `-inf`.
/// - [`VisionError::Internal`] if the solver finds no augmenting path, which
///   indicates a bug rather than bad input.
pub fn hungarian(cost: &[f32], n_workers: usize, n_jobs: usize) -> VisionResult<Vec<usize>> {
    if n_workers == 0 {
        return Err(VisionError::EmptyInput("hungarian: n_workers=0"));
    }
    if n_jobs == 0 {
        return Err(VisionError::EmptyInput("hungarian: n_jobs=0"));
    }
    // `checked_mul`: a wrapped product (release builds) could otherwise match
    // `cost.len()` by accident. A `&[f32]` can never hold `usize::MAX`
    // elements, so the fallback always reports a mismatch.
    let expected = n_workers.checked_mul(n_jobs).unwrap_or(usize::MAX);
    if cost.len() != expected {
        return Err(VisionError::DimensionMismatch {
            expected,
            got: cost.len(),
        });
    }
    // Reject NaN and *both* infinities: a `-inf` cost drives the dual
    // potentials to infinity, and `inf - inf` then turns reduced costs into
    // NaN, which silently corrupts the augmenting-path search.
    if cost.iter().any(|c| !c.is_finite()) {
        return Err(VisionError::NonFinite("hungarian: cost contains NaN/inf"));
    }
    let n = n_workers.max(n_jobs);
    // Cost of the padded cells. Its exact value does not affect optimality
    // (see the doc comment); it is simply chosen far above every real cost.
    let max_real = cost
        .iter()
        .copied()
        .fold(0.0f64, |acc, c| acc.max(f64::from(c)));
    let dummy = max_real.abs().max(1.0) * 1.0e6 + 1.0e6;
    let mut square = vec![dummy; n * n];
    // Copy each real row into the top-left `n_workers x n_jobs` corner.
    for (row, src) in cost.chunks_exact(n_jobs).enumerate() {
        let start = row * n;
        for (dst, &c) in square[start..start + n_jobs].iter_mut().zip(src) {
            *dst = f64::from(c);
        }
    }
    let assign_square = solve_square_kuhn_munkres(&square, n)?;
    // Padded workers are dropped; real workers matched to a padded job are
    // reported as unassigned.
    let assignment: Vec<usize> = assign_square[..n_workers]
        .iter()
        .map(|&j| if j < n_jobs { j } else { usize::MAX })
        .collect();
    Ok(assignment)
}
/// Runs [`hungarian`] and returns the optimal matching as `(query, target)`
/// pairs, in ascending query order.
///
/// Queries that end up unassigned (only possible when
/// `n_queries > n_targets`) are left out of the result.
///
/// # Errors
///
/// Propagates every error returned by [`hungarian`].
pub fn exact_bipartite_match(
    cost: &[f32],
    n_queries: usize,
    n_targets: usize,
) -> VisionResult<Vec<(usize, usize)>> {
    let per_query = hungarian(cost, n_queries, n_targets)?;
    let matched: Vec<(usize, usize)> = per_query
        .into_iter()
        .enumerate()
        .filter(|&(_, target)| target != usize::MAX)
        .collect();
    Ok(matched)
}
/// Solves the square `n x n` assignment problem on a row-major `f64` cost
/// matrix, minimizing the total cost.
///
/// Shortest-augmenting-path formulation of the Hungarian (Kuhn–Munkres)
/// algorithm with row/column potentials: rows are inserted one per phase,
/// and each phase augments the matching along a path of minimum reduced
/// cost. The three nested loops bound the work at `O(n^3)`.
///
/// Returns `assignment` with `assignment[row] = col` (both 0-based);
/// `n == 0` yields an empty vector.
///
/// # Errors
///
/// - `VisionError::DimensionMismatch` if `cost.len() != n * n`.
/// - `VisionError::Internal` if no augmenting path is found. That should not
///   happen for finite costs; non-finite costs can poison the potentials.
fn solve_square_kuhn_munkres(cost: &[f64], n: usize) -> VisionResult<Vec<usize>> {
if cost.len() != n * n {
return Err(VisionError::DimensionMismatch {
expected: n * n,
got: cost.len(),
});
}
if n == 0 {
return Ok(Vec::new());
}
let inf = f64::INFINITY;
// All arrays are 1-based; index 0 is a sentinel. `u` holds row potentials
// and `v` column potentials, so `cost[row][col] - u[row] - v[col]` is the
// reduced cost of an edge.
let mut u = vec![0.0f64; n + 1];
let mut v = vec![0.0f64; n + 1];
// `p[j]` = row currently matched to column `j` (0 = free). `p[0]` holds the
// row being inserted in the current phase.
let mut p = vec![0usize; n + 1];
// `way[j]` = previous column on the alternating path that reached column
// `j`; used to walk back and flip the matching once a free column is found.
let mut way = vec![0usize; n + 1];
for i in 1..=n {
// Phase `i`: insert row `i` by attaching it to the virtual column 0.
p[0] = i;
let mut j0 = 0usize;
// `minv[j]` = smallest reduced cost seen so far for reaching column `j`
// from the rows visited in this phase; `used[j]` marks visited columns.
let mut minv = vec![inf; n + 1];
let mut used = vec![false; n + 1];
loop {
used[j0] = true;
let i0 = p[j0];
let mut delta = inf;
let mut j1 = 0usize;
// Relax every unvisited column from row `i0`, and pick the unvisited
// column with the smallest `minv` as the next one to visit (`j1`).
for j in 1..=n {
if !used[j] {
let row = i0 - 1;
let col = j - 1;
let cur = cost[row * n + col] - u[i0] - v[j];
if cur < minv[j] {
minv[j] = cur;
way[j] = j0;
}
if minv[j] < delta {
delta = minv[j];
j1 = j;
}
}
}
// Expected unreachable for finite costs: some column is always still
// unvisited here, and its `minv` is then finite.
if delta == inf {
return Err(VisionError::Internal(
"hungarian: no augmenting path (algorithm bug)".into(),
));
}
// Shift potentials by `delta`. For visited columns, `u[p[j]]` and `v[j]`
// move in opposite directions, so matched edges keep their reduced cost;
// the pending `minv` of unvisited columns drops by `delta`.
for j in 0..=n {
if used[j] {
u[p[j]] += delta;
v[j] -= delta;
} else {
minv[j] -= delta;
}
}
j0 = j1;
// A free column ends the search: an augmenting path has been found.
if p[j0] == 0 {
break;
}
}
// Augment: follow `way` back to column 0, moving each row on the path one
// column along it.
loop {
let j1 = way[j0];
p[j0] = p[j1];
j0 = j1;
if j0 == 0 {
break;
}
}
}
// Convert the 1-based column->row table `p` into a 0-based row->column
// assignment.
let mut assignment = vec![0usize; n];
for j in 1..=n {
// Defensive range check; after all `n` phases every column should be
// matched to a row in `1..=n`.
if p[j] >= 1 && p[j] <= n {
assignment[p[j] - 1] = j - 1;
}
}
Ok(assignment)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::detection::set_match::bipartite_match;
    use crate::handle::LcgRng;

    /// Sum of `cost[w][assignment[w]]` over assigned workers, accumulated in
    /// `f64`; entries equal to `usize::MAX` (unassigned) are skipped.
    fn total_cost(cost: &[f32], n_jobs: usize, assignment: &[usize]) -> f32 {
        let mut s = 0.0f64;
        for (w, &j) in assignment.iter().enumerate() {
            if j != usize::MAX {
                s += cost[w * n_jobs + j] as f64;
            }
        }
        s as f32
    }

    /// `n x n` row-major matrix of pseudo-random costs in roughly `[0, scale]`,
    /// drawn one value per cell from `rng`.
    fn random_costs(rng: &mut LcgRng, n: usize, scale: f32) -> Vec<f32> {
        (0..n * n)
            .map(|_| (rng.next_u32() as f32) / 4_294_967_296.0 * scale)
            .collect()
    }

    /// Exhaustive minimum total cost over all permutations of a square matrix.
    fn brute_force_min(cost: &[f32], n: usize) -> f64 {
        let mut perm: Vec<usize> = (0..n).collect();
        let mut best = f64::INFINITY;
        permute(&mut perm, 0, cost, n, &mut best);
        best
    }

    /// Enumerates permutations of `perm[idx..]` in place, lowering `best` to the
    /// smallest `sum_i cost[i][perm[i]]` encountered.
    fn permute(perm: &mut [usize], idx: usize, cost: &[f32], n: usize, best: &mut f64) {
        if idx == perm.len() {
            let s: f64 = (0..n).map(|i| cost[i * n + perm[i]] as f64).sum();
            if s < *best {
                *best = s;
            }
            return;
        }
        for k in idx..perm.len() {
            perm.swap(idx, k);
            permute(perm, idx + 1, cost, n, best);
            perm.swap(idx, k);
        }
    }

    /// Converts greedy `(query, target)` pairs into a per-query assignment vector.
    fn pairs_to_assignment(pairs: &[(usize, usize)], n: usize) -> Vec<usize> {
        let mut out = vec![usize::MAX; n];
        for &(q, t) in pairs {
            out[q] = t;
        }
        out
    }

    #[test]
    fn identity_diagonal_3x3() {
        #[rustfmt::skip]
        let cost = vec![
            0.0f32, 1.0, 1.0,
            1.0, 0.0, 1.0,
            1.0, 1.0, 0.0,
        ];
        let a = hungarian(&cost, 3, 3).expect("ok");
        assert_eq!(a, vec![0, 1, 2]);
        assert!(total_cost(&cost, 3, &a).abs() < 1e-6);
    }

    #[test]
    fn permutation_assignment() {
        #[rustfmt::skip]
        let cost = vec![
            1.0f32, 0.0, 1.0,
            1.0, 1.0, 0.0,
            0.0, 1.0, 1.0,
        ];
        let a = hungarian(&cost, 3, 3).expect("ok");
        assert_eq!(a, vec![1, 2, 0]);
        assert!(total_cost(&cost, 3, &a).abs() < 1e-6);
    }

    #[test]
    fn uniform_cost_full_assignment() {
        let cost = vec![3.0f32; 16];
        let a = hungarian(&cost, 4, 4).expect("ok");
        assert_eq!(a.len(), 4);
        let cols: std::collections::HashSet<usize> = a.iter().copied().collect();
        assert_eq!(cols.len(), 4, "all cols distinct: {a:?}");
        assert!((total_cost(&cost, 4, &a) - 12.0).abs() < 1e-5);
    }

    #[test]
    fn tiny_1x1() {
        let cost = vec![3.25f32];
        let a = hungarian(&cost, 1, 1).expect("ok");
        assert_eq!(a, vec![0]);
    }

    #[test]
    fn hand_2x2_known_optimum() {
        #[rustfmt::skip]
        let cost = vec![
            4.0f32, 1.0,
            2.0, 5.0,
        ];
        let a = hungarian(&cost, 2, 2).expect("ok");
        assert_eq!(a, vec![1, 0]);
        assert!((total_cost(&cost, 2, &a) - 3.0).abs() < 1e-6);
    }

    #[test]
    fn hand_3x3_known_optimum() {
        #[rustfmt::skip]
        let cost = vec![
            2.0f32, 3.0, 3.0,
            3.0, 2.0, 3.0,
            3.0, 3.0, 2.0,
        ];
        let a = hungarian(&cost, 3, 3).expect("ok");
        assert!((total_cost(&cost, 3, &a) - 6.0).abs() < 1e-6);
    }

    #[test]
    fn matches_brute_force_and_beats_greedy_random_5x5() {
        let mut rng = LcgRng::new(20_240_521);
        let n = 5;
        let cost = random_costs(&mut rng, n, 10.0);
        let hung = hungarian(&cost, n, n).expect("ok");
        let h_cost = total_cost(&cost, n, &hung);
        let greedy = bipartite_match(&cost, n, n).expect("ok");
        let g_cost = total_cost(&cost, n, &pairs_to_assignment(&greedy, n));
        assert!(
            h_cost <= g_cost + 1e-5,
            "Hungarian {h_cost} must beat or equal greedy {g_cost}"
        );
        let best = brute_force_min(&cost, n);
        assert!(
            (h_cost as f64 - best).abs() < 1e-4,
            "Hungarian {h_cost} did not match brute-force {best}"
        );
    }

    #[test]
    fn rectangular_more_jobs_than_workers() {
        #[rustfmt::skip]
        let cost = vec![
            1.0f32, 5.0, 10.0,
            10.0, 1.0, 5.0,
        ];
        let a = hungarian(&cost, 2, 3).expect("ok");
        assert_eq!(a.len(), 2);
        assert!(a.iter().all(|&j| j != usize::MAX));
        let cols: std::collections::HashSet<usize> = a.iter().copied().collect();
        assert_eq!(cols.len(), 2);
        assert!((total_cost(&cost, 3, &a) - 2.0).abs() < 1e-5);
    }

    #[test]
    fn rectangular_more_workers_than_jobs() {
        #[rustfmt::skip]
        let cost = vec![
            1.0f32, 10.0,
            10.0, 1.0,
            5.0, 5.0,
        ];
        let a = hungarian(&cost, 3, 2).expect("ok");
        assert_eq!(a.len(), 3);
        let unmatched = a.iter().filter(|&&j| j == usize::MAX).count();
        assert_eq!(unmatched, 1);
        let matched: Vec<(usize, usize)> = a
            .iter()
            .enumerate()
            .filter_map(|(w, &j)| if j == usize::MAX { None } else { Some((w, j)) })
            .collect();
        assert_eq!(matched.len(), 2);
        let c: f64 = matched.iter().map(|&(w, j)| cost[w * 2 + j] as f64).sum();
        assert!((c - 2.0).abs() < 1e-5, "expected cost 2, got {c}");
    }

    #[test]
    fn deterministic_results() {
        let mut rng = LcgRng::new(7);
        let n = 6;
        let cost = random_costs(&mut rng, n, 5.0);
        let a1 = hungarian(&cost, n, n).expect("ok");
        let a2 = hungarian(&cost, n, n).expect("ok");
        assert_eq!(a1, a2);
    }

    #[test]
    fn permutation_invariance_of_total_cost() {
        let mut rng = LcgRng::new(9);
        let n = 5;
        let cost = random_costs(&mut rng, n, 7.0);
        let base = hungarian(&cost, n, n).expect("ok");
        let base_cost = total_cost(&cost, n, &base);
        let perm = [2usize, 4, 1, 0, 3];
        let mut permuted = vec![0.0f32; n * n];
        for i in 0..n {
            for j in 0..n {
                permuted[perm[i] * n + perm[j]] = cost[i * n + j];
            }
        }
        let permuted_assign = hungarian(&permuted, n, n).expect("ok");
        let permuted_cost = total_cost(&permuted, n, &permuted_assign);
        assert!(
            (base_cost - permuted_cost).abs() < 1e-4,
            "permutation invariance: {base_cost} vs {permuted_cost}"
        );
    }

    #[test]
    fn hungarian_le_greedy_random() {
        let mut rng = LcgRng::new(11);
        for trial in 0..20 {
            let n = 3 + (trial % 4);
            let cost = random_costs(&mut rng, n, 10.0);
            let h = hungarian(&cost, n, n).expect("ok");
            let hc = total_cost(&cost, n, &h);
            let g = bipartite_match(&cost, n, n).expect("ok");
            let gc = total_cost(&cost, n, &pairs_to_assignment(&g, n));
            assert!(
                hc <= gc + 1e-4,
                "trial {trial}: Hungarian {hc} > greedy {gc}"
            );
        }
    }

    #[test]
    fn empty_workers_errors() {
        let c: Vec<f32> = vec![];
        assert!(matches!(
            hungarian(&c, 0, 3),
            Err(VisionError::EmptyInput(_))
        ));
    }

    #[test]
    fn empty_jobs_errors() {
        let c: Vec<f32> = vec![];
        assert!(matches!(
            hungarian(&c, 3, 0),
            Err(VisionError::EmptyInput(_))
        ));
    }

    #[test]
    fn cost_length_mismatch_errors() {
        let c = vec![0.0f32; 5];
        let r = hungarian(&c, 3, 3);
        assert!(matches!(
            r,
            Err(VisionError::DimensionMismatch {
                expected: 9,
                got: 5
            })
        ));
    }

    #[test]
    fn nan_cost_errors() {
        let mut c = vec![0.0f32; 9];
        c[4] = f32::NAN;
        let r = hungarian(&c, 3, 3);
        assert!(matches!(r, Err(VisionError::NonFinite(_))));
    }

    #[test]
    fn inf_cost_errors() {
        let mut c = vec![0.0f32; 9];
        c[4] = f32::INFINITY;
        let r = hungarian(&c, 3, 3);
        assert!(matches!(r, Err(VisionError::NonFinite(_))));
    }

    #[test]
    fn exact_bipartite_match_returns_pairs() {
        #[rustfmt::skip]
        let cost = vec![
            0.0f32, 1.0,
            1.0, 0.0,
        ];
        let pairs = exact_bipartite_match(&cost, 2, 2).expect("ok");
        assert_eq!(pairs.len(), 2);
        let mut sorted = pairs.clone();
        sorted.sort_unstable();
        assert_eq!(sorted, vec![(0, 0), (1, 1)]);
    }

    #[test]
    fn exact_bipartite_match_drops_unmatched() {
        #[rustfmt::skip]
        let cost = vec![
            0.0f32, 5.0,
            5.0, 0.0,
            2.5, 2.5,
        ];
        let pairs = exact_bipartite_match(&cost, 3, 2).expect("ok");
        assert_eq!(pairs.len(), 2);
        for &(q, t) in &pairs {
            assert!(q < 3);
            assert!(t < 2);
        }
    }

    #[test]
    fn negative_costs_handled() {
        #[rustfmt::skip]
        let cost = vec![
            -5.0f32, -1.0,
            -1.0, -5.0,
        ];
        let a = hungarian(&cost, 2, 2).expect("ok");
        assert_eq!(a, vec![0, 1]);
        assert!((total_cost(&cost, 2, &a) + 10.0).abs() < 1e-5);
    }

    #[test]
    fn large_8x8_random_optimal() {
        let mut rng = LcgRng::new(2026);
        let n = 8;
        let cost = random_costs(&mut rng, n, 100.0);
        let h = hungarian(&cost, n, n).expect("ok");
        let hc = total_cost(&cost, n, &h);
        let cols: std::collections::HashSet<usize> = h.iter().copied().collect();
        assert_eq!(cols.len(), n);
        let g = bipartite_match(&cost, n, n).expect("ok");
        let gc = total_cost(&cost, n, &pairs_to_assignment(&g, n));
        assert!(hc <= gc + 1e-4, "8x8: hungarian {hc} > greedy {gc}");
        // 8! = 40320 permutations is cheap enough to confirm true optimality,
        // which the test name promises.
        let best = brute_force_min(&cost, n);
        assert!(
            (hc as f64 - best).abs() < 1e-3,
            "8x8: hungarian {hc} != brute-force {best}"
        );
    }
}