/// Solves the square assignment problem for `cost`, where `cost[i][j]` is the
/// cost of giving row `i` column `j`, using the Hungarian method with row and
/// column potentials (O(n^3) for an `n` x `n` matrix).
///
/// On success the returned vector has one entry per row: `result[i]` is the
/// zero-based column assigned to row `i`, and every column is used exactly
/// once.
///
/// Returns:
/// * `Some(vec![])` for an empty matrix;
/// * `None` when any row's length differs from the number of rows;
/// * `None` when the search for some row finds no usable column, which
///   happens when every cost it can reach is `+inf` or NaN. Such cells are
///   never selected because they never compare below the running minimum, so
///   `+inf` may be used to mark individual forbidden pairs as long as a
///   complete assignment still exists.
#[allow(dead_code)]
pub fn solve(cost: &[Vec<f32>]) -> Option<Vec<usize>> {
    let n = cost.len();
    if n == 0 {
        return Some(Vec::new());
    }
    if cost.iter().any(|row| row.len() != n) {
        return None;
    }
    let inf = f32::INFINITY;
    // All working arrays are 1-based; index 0 is a sentinel "virtual" column
    // that temporarily holds the row currently being inserted.
    let mut u = vec![0.0_f32; n + 1]; // row potentials
    let mut v = vec![0.0_f32; n + 1]; // column potentials
    let mut p = vec![0_usize; n + 1]; // p[j] = row matched to column j (0 = free)
    let mut way = vec![0_usize; n + 1]; // predecessor column on the augmenting path
    // Scratch buffers, allocated once and reset for every row.
    let mut minv = vec![inf; n + 1];
    let mut used = vec![false; n + 1];
    for i in 1..=n {
        p[0] = i;
        let mut j0 = 0_usize;
        minv.fill(inf);
        used.fill(false);
        loop {
            used[j0] = true;
            let i0 = p[j0];
            let mut delta = inf;
            let mut j1 = 0_usize;
            for j in 1..=n {
                if !used[j] {
                    // Reduced cost of the edge (i0, j) under the current potentials.
                    let cur = cost[i0 - 1][j - 1] - u[i0] - v[j];
                    if cur < minv[j] {
                        minv[j] = cur;
                        way[j] = j0;
                    }
                    if minv[j] < delta {
                        delta = minv[j];
                        j1 = j;
                    }
                }
            }
            if j1 == 0 {
                // No unvisited column has a comparable (non-`+inf`, non-NaN)
                // reduced cost, so the path cannot be extended. Without this
                // guard `j0` would fall back to the sentinel column and the
                // loop would never terminate.
                return None;
            }
            for j in 0..=n {
                if used[j] {
                    u[p[j]] += delta;
                    v[j] -= delta;
                } else {
                    minv[j] -= delta;
                }
            }
            j0 = j1;
            if p[j0] == 0 {
                // Reached a free column: an augmenting path exists.
                break;
            }
        }
        // Walk the path back to the sentinel, shifting each row one column along.
        loop {
            let j1 = way[j0];
            p[j0] = p[j1];
            j0 = j1;
            if j0 == 0 {
                break;
            }
        }
    }
    // Convert the column -> row matching into a zero-based row -> column vector.
    let mut result = vec![0_usize; n];
    for j in 1..=n {
        if p[j] > 0 {
            result[p[j] - 1] = j - 1;
        }
    }
    Some(result)
}
#[cfg(test)]
mod tests {
    use super::*;

    // Degenerate sizes.

    #[test]
    fn empty_matrix_returns_empty_assignment() {
        let no_rows: [Vec<f32>; 0] = [];
        let columns = solve(&no_rows).expect("empty matrix is valid");
        assert!(columns.is_empty());
    }

    #[test]
    fn one_by_one_matrix_returns_self() {
        let single = [vec![3.5_f32]];
        let columns = solve(&single).expect("1x1 valid");
        assert_eq!(columns, vec![0]);
    }

    // Matrices whose zero entries spell out a unique cheapest permutation.

    #[test]
    fn diagonal_zero_matrix_returns_identity() {
        let size = 3;
        // 10.0 everywhere except 0.0 on the main diagonal.
        let cost: Vec<Vec<f32>> = (0..size)
            .map(|r| {
                (0..size)
                    .map(|c| if r == c { 0.0_f32 } else { 10.0 })
                    .collect()
            })
            .collect();
        let columns = solve(&cost).expect("3x3 valid");
        assert_eq!(columns, vec![0, 1, 2]);
    }

    #[test]
    fn anti_diagonal_zero_matrix_returns_reverse_permutation() {
        let cost = [
            vec![10.0_f32, 10.0, 0.0],
            vec![10.0, 0.0, 10.0],
            vec![0.0, 10.0, 10.0],
        ];
        let columns = solve(&cost).expect("3x3 valid");
        assert_eq!(columns, vec![2, 1, 0]);
    }

    #[test]
    fn permutation_matrix_recovered() {
        let cost = [
            vec![5.0_f32, 0.0, 5.0],
            vec![5.0, 5.0, 0.0],
            vec![0.0, 5.0, 5.0],
        ];
        let columns = solve(&cost).expect("3x3 valid");
        assert_eq!(columns, vec![1, 2, 0]);
    }

    // Input validation and less friendly cost values.

    #[test]
    fn rejects_non_square_matrix() {
        // Three rows of two columns each.
        let ragged = [vec![1.0_f32, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]];
        let outcome = solve(&ragged);
        assert!(outcome.is_none());
    }

    #[test]
    fn handles_negative_costs() {
        let cost = [vec![-1.0_f32, -3.0], vec![-2.0, -5.0]];
        let columns = solve(&cost).expect("2x2 valid");
        assert_eq!(columns, vec![0, 1]);
    }

    #[test]
    fn cost_matrix_with_repeated_rows_still_assigns_unique_columns() {
        // Every row is identical, so only uniqueness of the columns is checked.
        let cost = vec![vec![1.0_f32, 2.0, 3.0]; 3];
        let mut columns = solve(&cost).expect("3x3 valid");
        columns.sort_unstable();
        assert_eq!(columns, vec![0, 1, 2], "must be a permutation");
    }
}