use crate::errcode::{ERROR_SINGULAR, WARNING_SINGULAR};
use crate::matrix_util::half_to_full;
use crate::postproc::match_postproc;
use std::iter::zip;
/// Options controlling Hungarian-matching based scaling.
#[derive(Default, Debug, Clone, Copy)]
pub struct HungarianOptions {
    /// Not read by the code in this section.
    /// NOTE(review): presumably the index base (0 or 1) of caller arrays — confirm.
    pub array_base: usize,
    /// When `true`, a structurally singular matrix only raises a warning flag and
    /// scaling continues; when `false` an error flag is set and the identity
    /// (all-zero log) scaling is returned.
    pub scale_if_singular: bool,
}
/// Status information reported by the Hungarian scaling routines.
#[derive(Default, Debug, Clone, Copy)]
pub struct HungarianInform {
    /// 0 on success, otherwise a warning/error code (e.g. singular-matrix
    /// warning or error); post-processing may also update it.
    pub flag: i32,
    /// Reset to 0 at the start of scaling; not otherwise written in this section.
    pub stat: i32,
    /// Number of matched row/column pairs found by the matching.
    pub matched: usize,
}
/// Computes a symmetric scaling of an `n x n` symmetric matrix from a
/// minimum-cost (Hungarian) matching on the log-magnitudes of its entries.
///
/// * `ptr`/`row`/`val` — the stored part of the matrix, compressed by column
///   (`ptr` has `n + 1` offsets).
/// * `scaling` — receives `n` factors; entry `i` is `exp((r_i + c_i) / 2)`,
///   i.e. the geometric mean of the row and column factors computed internally.
/// * `match_result` — optional buffer of length `n` that receives the matching;
///   a temporary buffer is used when `None`.
///
/// `inform.flag` is reset to 0 and then updated by the underlying routine.
pub fn hungarian_scale_sym(
    n: usize,
    ptr: &[usize],
    row: &[usize],
    val: &[f64],
    scaling: &mut [f64],
    options: &HungarianOptions,
    inform: &mut HungarianInform,
    match_result: Option<&mut [i32]>,
) {
    inform.flag = 0;
    let mut row_log = vec![0.0; n];
    let mut col_log = vec![0.0; n];

    // Write the matching into the caller's buffer if one was supplied,
    // otherwise into a scratch vector that is dropped on return.
    let mut scratch: Vec<i32> = Vec::new();
    let matching: &mut [i32] = match match_result {
        Some(buf) => buf,
        None => {
            scratch.resize(n, 0);
            scratch.as_mut_slice()
        }
    };

    hungarian_wrapper(
        true,
        n,
        n,
        ptr,
        row,
        val,
        matching,
        &mut row_log,
        &mut col_log,
        options,
        inform,
    );

    // Average the row and column log-scalings, then leave the log domain.
    for (s, (r, c)) in scaling[..n]
        .iter_mut()
        .zip(row_log.iter().zip(col_log.iter()))
    {
        *s = ((r + c) / 2.0).exp();
    }
}
/// Computes row and column scalings of an `m x n` matrix from a minimum-cost
/// (Hungarian) matching on the log-magnitudes of its entries.
///
/// * `ptr`/`row`/`val` — the matrix compressed by column (`ptr` has `n + 1` offsets).
/// * `rscaling` (length `m`) / `cscaling` (length `n`) — receive the row and
///   column scaling factors (exponentials of the log-domain values computed
///   internally).
/// * `match_out` — optional buffer of length `m` receiving the matching; a
///   temporary buffer is used when `None`.
///
/// `inform.flag` is reset to 0 and then updated by the underlying routine.
pub fn hungarian_scale_unsym(
    m: usize,
    n: usize,
    ptr: &[usize],
    row: &[usize],
    val: &[f64],
    rscaling: &mut [f64],
    cscaling: &mut [f64],
    options: &HungarianOptions,
    inform: &mut HungarianInform,
    match_out: Option<&mut [i32]>,
) {
    inform.flag = 0;

    // Prefer the caller's matching buffer; fall back to a local scratch vector.
    let mut scratch: Vec<i32> = Vec::new();
    let matching: &mut [i32] = match match_out {
        Some(buf) => buf,
        None => {
            scratch.resize(m, 0);
            scratch.as_mut_slice()
        }
    };

    hungarian_wrapper(
        false, m, n, ptr, row, val, matching, rscaling, cscaling, options, inform,
    );

    // Convert both log-domain scalings to multiplicative factors.
    rscaling
        .iter_mut()
        .chain(cscaling.iter_mut())
        .for_each(|x| *x = x.exp());
}
/// Shared core of the symmetric and unsymmetric Hungarian scaling routines.
///
/// The input `m x n` matrix is compressed by column with 0-based offsets: the
/// entries of column `i` are `row[k]`/`val[k]` for `k in ptr[i]..ptr[i + 1]`.
/// When `sym` is true the stored entries are expanded with `half_to_full`
/// (NOTE(review): presumably one triangle of a symmetric matrix — confirm).
///
/// Steps:
/// 1. Copy the matrix without explicit zeros and replace values by
///    `cost_ij = max_k log|a_kj| - log|a_ij|` (column maximum `cmax[j]`).
/// 2. Find a minimum-cost matching with `hungarian_match`; the matching is written
///    to `match_result` and its size to `inform.matched`.
/// 3. If the matching is incomplete, set `inform.flag` to `WARNING_SINGULAR`
///    (`options.scale_if_singular`) or to `ERROR_SINGULAR`, in which case
///    `rscaling`/`cscaling` are zeroed and the function returns.
/// 4. Unsymmetric, or symmetric with a complete matching: `rscaling = dualu`,
///    `cscaling = dualv - cmax`, then `match_postproc` is applied.
/// 5. Symmetric and rank deficient: re-match the principal submatrix of matched
///    indices, use the averaged duals for matched indices, and give every
///    unmatched index `i` the value `-max_k(log|a_ik| + s_k)` over matched
///    neighbours `k` (or 0 if it has none). Both outputs get the same values.
///
/// All scalings are written in the log domain.
///
/// # Panics
/// If `ptr.len() != n + 1`.
fn hungarian_wrapper(
    sym: bool,
    m: usize,
    n: usize,
    ptr: &[usize],
    row: &[usize],
    val: &[f64],
    match_result: &mut [i32],
    rscaling: &mut [f64],
    cscaling: &mut [f64],
    options: &HungarianOptions,
    inform: &mut HungarianInform,
) {
    assert_eq!(ptr.len(), n + 1);
    inform.flag = 0;
    inform.stat = 0;
    // `ptr` holds 0-based offsets, so `ptr[n]` is the number of stored entries
    // (the former `ptr[n] - 1` was a leftover 1-based offset that underflowed for
    // an empty matrix and under-allocated). Doubling leaves room for
    // `half_to_full` to mirror a symmetric triangle.
    let mut ne = 2 * ptr[n];
    let mut ptr2 = vec![0; n + 1];
    let mut row2 = vec![0; ne];
    let mut val2 = vec![0.0; ne];
    let mut iw = vec![0; 5 * n];
    let mut dualu = vec![0.0; m];
    let mut dualv = vec![0.0; n];
    let mut cmax = vec![0.0; n];

    // Copy the matrix, dropping explicit zeros and keeping log|a_ij|.
    let mut klong = 0;
    for i in 0..n {
        ptr2[i] = klong;
        for jlong in ptr[i]..ptr[i + 1] {
            if val[jlong] == 0.0 {
                continue;
            }
            row2[klong] = row[jlong];
            val2[klong] = val[jlong].abs();
            klong += 1;
        }
        // Logs are taken in a separate pass so the loop can vectorise.
        for v in &mut val2[ptr2[i]..klong] {
            *v = v.ln();
        }
    }
    ptr2[n] = klong;
    if sym {
        half_to_full(n, &mut row2, &mut ptr2, &mut iw, Some(&mut val2), true);
    }

    // Cost of entry (i, j) becomes max_k log|a_kj| - log|a_ij| >= 0.
    for i in 0..n {
        let col = &mut val2[ptr2[i]..ptr2[i + 1]];
        if col.is_empty() {
            // Empty column: a -inf maximum would make the column scaling +inf below.
            cmax[i] = 0.0;
            continue;
        }
        let colmax = col.iter().copied().fold(f64::NEG_INFINITY, f64::max);
        cmax[i] = colmax;
        for v in col.iter_mut() {
            *v = colmax - *v;
        }
    }

    hungarian_match(
        m,
        n,
        &ptr2,
        &row2,
        &val2,
        match_result,
        &mut inform.matched,
        &mut dualu,
        &mut dualv,
    );
    if inform.matched != usize::min(m, n) {
        if options.scale_if_singular {
            inform.flag = WARNING_SINGULAR;
        } else {
            // Report the error and return the identity scaling (log 1 = 0).
            inform.flag = ERROR_SINGULAR;
            rscaling.fill(0.0);
            cscaling.fill(0.0);
            return;
        }
    }
    if !sym || inform.matched == n {
        rscaling.copy_from_slice(&dualu[..m]);
        for (c, (dv, &max)) in zip(cscaling.iter_mut(), zip(dualv, &cmax)) {
            *c = dv - max;
        }
        match_postproc(
            m,
            n,
            ptr,
            row,
            val,
            rscaling,
            cscaling,
            inform.matched,
            match_result,
            &mut inform.flag,
        );
        return;
    }

    // Symmetric and structurally rank deficient: the matching need not be
    // symmetric, so re-match the principal submatrix of matched indices.
    let mut old_to_new = vec![0; n];
    let mut new_to_old = vec![0; n];
    let mut cperm = vec![0; n];
    // Kept indices are renumbered 0, 1, ...; unmatched ones are encoded as
    // -(j + 1) for j = matched, matched + 1, ... (the same encoding
    // `hungarian_match` uses). The encoding is never 0, so it cannot be
    // confused with kept index 0.
    let mut j = inform.matched;
    let mut k = 0;
    for i in 0..m {
        if match_result[i] < 0 {
            old_to_new[i] = -(j as i32) - 1;
            j += 1;
        } else {
            old_to_new[i] = k as i32;
            new_to_old[k] = i;
            k += 1;
        }
    }

    // Compact ptr2/row2/val2 in place to the kept rows and columns.
    ne = 0;
    k = 0;
    ptr2[0] = 0;
    let mut j2 = 0;
    for i in 0..n {
        let j1 = j2;
        // Read before ptr2[k] (k <= i + 1) is overwritten below.
        j2 = ptr2[i + 1];
        if match_result[i] < 0 {
            continue;
        }
        k += 1;
        for jlong in j1..j2 {
            let jj = row2[jlong];
            if match_result[jj] < 0 {
                continue;
            }
            row2[ne] = old_to_new[jj] as usize;
            val2[ne] = val2[jlong];
            ne += 1;
        }
        ptr2[k] = ne;
    }
    let nn = k;
    hungarian_match(
        nn,
        nn,
        &ptr2,
        &row2,
        &val2,
        &mut cperm,
        &mut inform.matched,
        &mut dualu,
        &mut dualv,
    );

    // Matched indices: average of the duals, corrected by the original column max.
    for i in 0..n {
        let j = old_to_new[i];
        if j < 0 {
            rscaling[i] = f64::NEG_INFINITY;
        } else {
            rscaling[i] = (dualu[j as usize] + dualv[j as usize] - cmax[i]) / 2.0;
        }
    }
    match_result.fill(-1);
    for i in 0..nn {
        // NOTE(review): cperm[i] is a column in the compacted numbering and is
        // stored without mapping back through new_to_old — confirm intended.
        let j = cperm[i];
        match_result[new_to_old[i]] = j;
    }
    for i in 0..n {
        if match_result[i] == -1 {
            match_result[i] = old_to_new[i];
        }
    }

    // Unmatched indices i: s_i = -max_{k matched} (log|a_ik| + s_k), or 0 if none.
    // `cscale` is a snapshot used only to tell matched (finite) from unmatched.
    let mut cscale = vec![0.0; n];
    cscale.copy_from_slice(&rscaling[..n]);
    for i in 0..n {
        for jlong in ptr[i]..ptr[i + 1] {
            let k = row[jlong];
            if cscale[i] == f64::NEG_INFINITY && cscale[k] != f64::NEG_INFINITY {
                rscaling[i] = f64::max(rscaling[i], val[jlong].abs().ln() + rscaling[k]);
            }
            if cscale[k] == f64::NEG_INFINITY && cscale[i] != f64::NEG_INFINITY {
                rscaling[k] = f64::max(rscaling[k], val[jlong].abs().ln() + rscaling[i]);
            }
        }
    }
    for i in 0..n {
        if cscale[i] != f64::NEG_INFINITY {
            continue;
        }
        if rscaling[i] == f64::NEG_INFINITY {
            rscaling[i] = 0.0;
        } else {
            rscaling[i] = -rscaling[i];
        }
    }
    // Symmetric: callers average row and column scalings, so make them equal.
    cscaling.copy_from_slice(&rscaling[..n]);
}
fn hungarian_init_heuristic(
m: usize,
n: usize,
ptr: &[usize],
row: &[usize],
val: &[f64],
num: &mut usize,
iperm: &mut [i32],
jperm: &mut [usize],
dualu: &mut [f64],
d: &mut [f64],
l: &mut [usize],
search_from: &mut [usize],
) {
assert_eq!(ptr.len(), n + 1);
assert_eq!(row.len(), n);
assert_eq!(val.len(), n);
assert_eq!(dualu.len(), m);
assert_eq!(d.len(), n);
assert_eq!(l.len(), m);
assert_eq!(search_from.len(), n);
dualu.fill(f64::INFINITY);
l.fill(0);
for j in 0..n {
for k in ptr[j]..ptr[j + 1] {
let i = row[k];
if val[k] > dualu[i] {
continue;
}
dualu[i] = val[k]; iperm[i] = j as i32; l[i] = k; }
}
for i in 0..m {
let j = iperm[i] as usize; if j == 0 {
continue; }
iperm[i] = 0;
if jperm[j] != 0 {
continue; }
if (ptr[j + 1] - ptr[j] > m / 10) && (m > 50) {
continue;
}
*num += 1;
iperm[i] = j as i32;
jperm[j] = l[i];
}
if *num == usize::min(m, n) {
return;
}
d.fill(0.0);
search_from.copy_from_slice(&ptr[..n]);
'improve_assign: for j in 0..n {
if jperm[j] != 0 {
continue; }
if ptr[j] >= ptr[j + 1] {
continue; }
let mut i0 = row[ptr[j]];
let mut vj = val[ptr[j]] - dualu[i0];
let mut k0 = ptr[j];
for k in (ptr[j] + 1)..ptr[j + 1] {
let i = row[k];
let di = val[k] - dualu[i];
if di > vj {
continue;
}
if (di == vj) && (di != f64::INFINITY) {
if (iperm[i] != 0) || (iperm[i0] == 0) {
continue;
}
}
vj = di;
i0 = i;
k0 = k;
}
d[j] = vj;
if iperm[i0] == 0 {
*num += 1;
jperm[j] = k0;
iperm[i0] = j as i32;
search_from[j] = k0 + 1;
continue;
}
for k in k0..ptr[j + 1] {
let i = row[k];
if (val[k] - dualu[i]) > vj {
continue; }
let jj = iperm[i] as usize;
for kk in search_from[jj]..ptr[jj + 1] {
let ii = row[kk];
if iperm[ii] > 0 {
continue; }
if (val[kk] - dualu[ii]) <= d[jj] {
jperm[jj] = kk;
iperm[ii] = jj as i32;
search_from[jj] = kk + 1;
*num += 1;
jperm[j] = k;
iperm[i] = j as i32;
search_from[j] = k + 1;
continue 'improve_assign;
}
}
search_from[jj] = ptr[jj + 1];
}
}
}
fn hungarian_match(
m: usize, n: usize, ptr: &[usize], row: &[usize], val: &[f64], iperm: &mut [i32], num: &mut usize, dualu: &mut [f64], dualv: &mut [f64], ) {
let mut jperm = vec![0; n];
let mut out = vec![0; n];
let mut pr = vec![0; n];
let mut q = vec![0; m];
let mut longwork = vec![0; m];
let mut l = vec![0; m];
let mut d = vec![f64::INFINITY; usize::max(m, n)];
hungarian_init_heuristic(
m,
n,
ptr,
row,
val,
num,
iperm,
&mut jperm,
dualu,
&mut d,
&mut longwork,
&mut out,
);
if *num == usize::min(m, n) {
return; }
for jord in 0..n {
if jperm[jord] != 0 {
continue;
}
let mut dmin = f64::INFINITY;
let mut qlen = 0;
let mut low = m;
let mut up = m;
let mut csp = f64::INFINITY;
let j = jord;
let mut isp = 0;
let mut jsp = 0;
pr[j] = !0;
for klong in ptr[j]..ptr[j + 1] {
let i = row[klong];
let dnew = val[klong] - dualu[i];
if dnew >= csp {
continue;
}
if iperm[i] == -1 {
csp = dnew;
isp = klong;
jsp = j;
} else {
if dnew < dmin {
dmin = dnew;
}
d[i] = dnew;
qlen += 1;
longwork[qlen - 1] = klong;
}
}
let q0 = qlen;
qlen = 0;
for kk in 0..q0 {
let klong = longwork[kk];
let i = row[klong];
if csp <= d[i] {
d[i] = f64::INFINITY;
continue;
}
if d[i] <= dmin {
low -= 1;
q[low] = i;
l[i] = low;
} else {
qlen += 1;
l[i] = qlen;
heap_update(i, m, &mut q, &mut d, &mut l);
}
let jj = iperm[i] as usize;
out[jj] = klong;
pr[jj] = j;
}
for _ in 0..*num {
if low == up {
if qlen == 0 {
break;
}
let i = q[0]; if d[i] >= csp {
break;
}
dmin = d[i];
while qlen > 0 {
let i = q[0]; if d[i] > dmin {
break;
}
let i = heap_pop(&mut qlen, m, &mut q, &mut d, &mut l);
low -= 1;
q[low] = i;
l[i] = low;
}
}
let q0 = q[up - 1];
let dq0 = d[q0];
if dq0 >= csp {
break;
}
up -= 1;
let j = iperm[q0] as usize;
let vj = dq0 - val[jperm[j]] + dualu[q0];
for klong in ptr[j]..ptr[j + 1] {
let i = row[klong];
if l[i] >= up {
continue;
}
let dnew = vj + val[klong] - dualu[i];
if dnew >= csp {
continue;
}
if iperm[i] == -1 {
csp = dnew;
isp = klong;
jsp = j;
} else {
let di = d[i];
if di <= dnew {
continue;
}
if l[i] >= low {
continue;
}
d[i] = dnew;
if dnew <= dmin {
let lpos = l[i];
if lpos != 0 {
heap_delete(lpos, &mut qlen, m, &mut q, &mut d, &mut l);
}
low -= 1;
q[low] = i;
l[i] = low;
} else {
if l[i] == 0 {
qlen += 1;
l[i] = qlen;
}
heap_update(i, m, &mut q, &mut d, &mut l);
}
let jj = iperm[i] as usize;
out[jj] = klong;
pr[jj] = j;
}
}
}
if csp == f64::INFINITY {
continue; }
*num += 1;
let mut i = row[isp];
iperm[i] = jsp as i32;
jperm[jsp] = isp;
let mut j = jsp;
for _ in 0..*num {
let jj = pr[j];
if jj == !0 {
break;
}
let klong = out[j];
i = row[klong];
iperm[i] = jj as i32;
jperm[jj] = klong;
j = jj;
}
for kk in up..m {
let i = q[kk];
dualu[i] = dualu[i] + d[i] - csp;
}
for kk in low..m {
let i = q[kk];
d[i] = f64::INFINITY;
l[i] = 0;
}
for kk in 0..qlen {
let i = q[kk];
d[i] = f64::INFINITY;
l[i] = 0;
}
}
for j in 0..n {
let klong = jperm[j];
if klong != 0 {
dualv[j] = val[klong] - dualu[row[klong]];
} else {
dualv[j] = 0.0;
}
}
for i in 0..m {
if iperm[i] == -1 {
dualu[i] = 0.0;
}
}
if *num != usize::min(m, n) {
jperm.fill(0);
let mut k = 0;
for i in 0..m {
if iperm[i] == -1 {
k += 1;
out[k - 1] = i;
} else {
let j = iperm[i] as usize;
jperm[j] = i;
}
}
k = 0;
for j in 0..n {
if jperm[j] != 0 {
continue;
}
k += 1;
let jdum = out[k - 1];
iperm[jdum] = -(j as i32) - 1;
}
}
}
fn heap_update(idx: usize, _n: usize, q: &mut [usize], val: &[f64], l: &mut [usize]) {
let mut pos = l[idx];
if pos <= 1 {
q[pos - 1] = idx;
return;
}
let v = val[idx];
while pos > 1 {
let parent_pos = pos / 2;
let parent_idx = q[parent_pos - 1];
if v >= val[parent_idx] {
break;
}
q[pos - 1] = parent_idx;
l[parent_idx] = pos;
pos = parent_pos;
}
q[pos - 1] = idx;
l[idx] = pos;
}
fn heap_pop(qlen: &mut usize, n: usize, q: &mut [usize], val: &[f64], l: &mut [usize]) -> usize {
let root = q[0];
heap_delete(1, qlen, n, q, val, l);
root
}
fn heap_delete(
pos0: usize,
qlen: &mut usize,
_n: usize,
q: &mut [usize],
d: &[f64],
l: &mut [usize],
) {
if *qlen == pos0 {
*qlen -= 1;
return;
}
let idx = q[*qlen - 1];
let v = d[idx];
*qlen -= 1; let mut pos = pos0;
if pos > 1 {
loop {
let parent = pos / 2;
let qk = q[parent - 1];
if v >= d[qk] {
break;
}
q[pos - 1] = qk;
l[qk] = pos;
pos = parent;
if pos <= 1 {
break;
}
}
}
q[pos - 1] = idx;
l[idx] = pos;
if pos != pos0 {
return; }
loop {
let mut child = 2 * pos;
if child > *qlen {
break;
}
let mut dk = d[q[child - 1]];
if child < *qlen {
let dr = d[q[child]];
if dk > dr {
child += 1;
dk = dr;
}
}
if v <= dk {
break;
}
let qk = q[child - 1];
q[pos - 1] = qk;
l[qk] = pos;
pos = child;
}
q[pos - 1] = idx;
l[idx] = pos;
}