use num_traits::{Bounded, Signed, Zero};
use fixedbitset::FixedBitSet;
use square_matrix::SquareMatrix;
use std::iter::Sum;
/// Compute a maximum-weight perfect matching of a square weight matrix
/// using the Kuhn–Munkres (Hungarian) algorithm.
///
/// Rows and columns are the two sides of a complete bipartite graph, and
/// `weights[&(row, col)]` is the weight of the edge between `row` and `col`.
///
/// Returns `(total, assignments)` where `assignments[row]` is the column
/// matched to `row` (every column is used exactly once) and `total` is the
/// sum of the final row and column labels, which equals the total weight of
/// the returned matching. An empty matrix yields `(0, vec![])`.
///
/// Each of the `n` rows triggers one augmenting-path search of at most `n`
/// steps, and each step scans `O(n)` columns, for `O(n³)` overall.
///
/// # Panics
///
/// Label and slack arithmetic (`lx + ly - weight`) is unchecked. With extreme
/// weights it may overflow `C`; for primitive integers this panics in debug
/// builds.
pub fn kuhn_munkres<C>(weights: &SquareMatrix<C>) -> (C, Vec<usize>)
where
    C: Bounded + Sum<C> + Zero + Signed + Ord + Copy,
{
    // Rows are called `x` and columns `y`. `xy[x]` is the column matched to
    // row `x`, and `yx[y]` is the row matched to column `y`.
    let n = weights.size;
    let mut xy: Vec<Option<usize>> = vec![None; n];
    let mut yx: Vec<Option<usize>> = vec![None; n];
    // Feasible starting labelling. Every row gets its maximum weight and every
    // column gets zero, so `lx[x] + ly[y] >= weights[&(x, y)]` holds for all edges.
    let mut lx: Vec<C> = (0..n)
        .map(|row| {
            (0..n)
                .map(|col| weights[&(row, col)])
                .max()
                .expect("rows of a non-empty square matrix are non-empty")
        })
        .collect();
    let mut ly: Vec<C> = vec![Zero::zero(); n];
    // Scratch buffers, reset for every root and reused to avoid reallocating:
    // - `s`: rows currently in the alternating tree;
    // - `alternating[y]`: `Some(x)` once column `y` has joined the tree via row `x`;
    // - `slack[y]`: smallest `lx[x] + ly[y] - weights[&(x, y)]` over rows `x` in `s`;
    // - `slackx[y]`: the row achieving `slack[y]`.
    let mut s = FixedBitSet::with_capacity(n);
    let mut alternating: Vec<Option<usize>> = Vec::with_capacity(n);
    let mut slack: Vec<C> = vec![Zero::zero(); n];
    let mut slackx: Vec<usize> = Vec::with_capacity(n);
    for root in 0..n {
        alternating.clear();
        alternating.resize(n, None);
        s.clear();
        s.insert(root);
        for y in 0..n {
            slack[y] = lx[root] + ly[y] - weights[&(root, y)];
        }
        slackx.clear();
        slackx.resize(n, root);
        // Grow the alternating tree until it reaches an unmatched column. That
        // column is the end of an augmenting path starting at `root`.
        let free_column = loop {
            // Pick the column outside the tree with the smallest slack.
            let mut delta = Bounded::max_value();
            let mut x = 0;
            let mut y = 0;
            for yy in 0..n {
                if alternating[yy].is_none() && slack[yy] < delta {
                    delta = slack[yy];
                    x = slackx[yy];
                    y = yy;
                }
            }
            debug_assert!(s.contains(x));
            // Shift the labels so that edge (x, y) becomes tight. This keeps the
            // labelling feasible and keeps edges inside the tree tight.
            if delta > Zero::zero() {
                for x in s.ones() {
                    lx[x] = lx[x] - delta;
                }
                for y in 0..n {
                    if alternating[y].is_some() {
                        ly[y] = ly[y] + delta;
                    } else {
                        slack[y] = slack[y] - delta;
                    }
                }
            }
            debug_assert!(lx[x] + ly[y] == weights[&(x, y)]);
            alternating[y] = Some(x);
            // An unmatched column ends the augmenting path. Otherwise, follow
            // its matching edge back to a row and add that row to the tree.
            let x = match yx[y] {
                None => break y,
                Some(x) => x,
            };
            debug_assert!(!s.contains(x));
            s.insert(x);
            // The newly added row may offer a smaller slack for columns
            // that are still outside the tree.
            for y in 0..n {
                if alternating[y].is_none() {
                    let alternate_slack = lx[x] + ly[y] - weights[&(x, y)];
                    if slack[y] > alternate_slack {
                        slack[y] = alternate_slack;
                        slackx[y] = x;
                    }
                }
            }
        };
        // Flip matched/unmatched edges along the augmenting path back to `root`.
        let mut next = Some(free_column);
        while let Some(y) = next {
            let x = alternating[y].expect("every column on the path was reached from a row");
            next = xy[x];
            yx[y] = Some(x);
            xy[x] = Some(y);
        }
    }
    let total = lx.into_iter().sum::<C>() + ly.into_iter().sum::<C>();
    let assignments = xy
        .into_iter()
        .map(|col| col.expect("every row is matched after n augmentations"))
        .collect();
    (total, assignments)
}
/// Compute a minimum-weight perfect matching of a square weight matrix.
///
/// The matrix is negated and passed to [`kuhn_munkres`], which maximises the
/// total weight. The resulting total is negated back. `Neg` for `SquareMatrix`
/// is defined outside this file and is assumed to be element-wise (TODO:
/// confirm). The returned `assignments[row]` is the column matched to `row`.
///
/// NOTE(review): negating `C::min_value()` (e.g. `i32::MIN`) overflows, so
/// such weights are presumably unsupported. Confirm with callers.
pub fn kuhn_munkres_min<C>(weights: &SquareMatrix<C>) -> (C, Vec<usize>)
where
    C: Bounded + Sum<C> + Zero + Signed + Ord + Copy,
{
    let negated_weights = -weights.clone();
    let (negated_total, assignments) = kuhn_munkres(&negated_weights);
    (-negated_total, assignments)
}