use super::{AssignmentProblem, AssignmentSolution, AssignmentSolver};
use crate::{Cost, Error, Result, SolverParams, SolverStats, SolverStatus};
use std::time::Instant;
/// Assignment solver backed by the Hungarian algorithm.
///
/// This is a field-less unit type; the solving logic is reached through its
/// [`AssignmentSolver`] implementation.
pub struct HungarianSolver;
impl AssignmentSolver for HungarianSolver {
    /// Returns the fixed identifier of this solver: `"hungarian"`.
    fn name(&self) -> &'static str {
        "hungarian"
    }

    /// Solves `problem` by delegating to the module's Hungarian routine,
    /// forwarding `params` unchanged.
    fn solve(
        &self,
        problem: &AssignmentProblem,
        params: &SolverParams,
    ) -> Result<AssignmentSolution> {
        solve_hungarian(problem, params)
    }
}
/// Solves `problem` with the Hungarian algorithm using `SolverParams::default()`.
///
/// Any error produced by the underlying routine is returned as-is.
pub fn solve(problem: &AssignmentProblem) -> Result<AssignmentSolution> {
    let default_params = SolverParams::default();
    solve_hungarian(problem, &default_params)
}
fn solve_hungarian(problem: &AssignmentProblem, params: &SolverParams) -> Result<AssignmentSolution> {
let start = Instant::now();
let n = problem.num_agents;
let m = problem.num_tasks;
if n == 0 || m == 0 {
return Err(Error::invalid_input("empty problem"));
}
let size = n.max(m);
let mut cost = vec![vec![0i64; size]; size];
let large = problem.costs.iter()
.flat_map(|row| row.iter())
.max()
.copied()
.unwrap_or(0)
.saturating_add(1);
for i in 0..size {
for j in 0..size {
cost[i][j] = if i < n && j < m {
problem.costs[i][j]
} else {
large };
}
}
let mut u = vec![0i64; size + 1];
let mut v = vec![0i64; size + 1];
let mut p = vec![0usize; size + 1];
let mut way = vec![0usize; size + 1];
let mut iterations = 0;
for i in 1..=size {
if params.has_time_limit() && start.elapsed().as_secs_f64() > params.time_limit_seconds {
return Err(Error::timeout(params.time_limit_seconds));
}
p[0] = i;
let mut j0 = 0usize;
let mut minv = vec![i64::MAX; size + 1];
let mut used = vec![false; size + 1];
loop {
iterations += 1;
used[j0] = true;
let i0 = p[j0];
let mut delta = i64::MAX;
let mut j1 = 0usize;
for j in 1..=size {
if !used[j] {
let cur = cost[i0 - 1][j - 1] - u[i0] - v[j];
if cur < minv[j] {
minv[j] = cur;
way[j] = j0;
}
if minv[j] < delta {
delta = minv[j];
j1 = j;
}
}
}
for j in 0..=size {
if used[j] {
u[p[j]] += delta;
v[j] -= delta;
} else {
minv[j] -= delta;
}
}
j0 = j1;
if p[j0] == 0 {
break;
}
}
loop {
let j1 = way[j0];
p[j0] = p[j1];
j0 = j1;
if j0 == 0 {
break;
}
}
}
let mut assignments = vec![0usize; n];
let mut total_cost: Cost = 0;
for j in 1..=size {
if p[j] != 0 && p[j] <= n && j <= m {
let agent = p[j] - 1;
let task = j - 1;
assignments[agent] = task;
total_cost += problem.costs[agent][task];
}
}
let elapsed = start.elapsed().as_secs_f64();
Ok(AssignmentSolution {
assignments,
total_cost,
status: SolverStatus::Optimal,
stats: SolverStats {
solve_time_seconds: elapsed,
iterations,
objective_value: Some(total_cost as f64),
..Default::default()
},
})
}
#[cfg(test)]
mod tests {
    use super::*;

    /// 3x3 instance with a unique optimum: 5 + 3 + 7 = 15.
    #[test]
    fn test_3x3() {
        let problem = AssignmentProblem::from_costs(vec![
            vec![10, 5, 13],
            vec![3, 9, 18],
            vec![14, 8, 7],
        ]);
        let solution = solve(&problem).unwrap();
        assert_eq!(solution.total_cost, 15);
        assert_eq!(solution.assignments, vec![1, 0, 2]);
        assert_eq!(solution.status, SolverStatus::Optimal);
    }

    /// A single agent and task must be paired with each other.
    #[test]
    fn test_1x1() {
        let problem = AssignmentProblem::from_costs(vec![vec![42]]);
        let solution = solve(&problem).unwrap();
        assert_eq!(solution.total_cost, 42);
        assert_eq!(solution.assignments, vec![0]);
    }

    /// Both permutations cost 5, so only the total is checked.
    #[test]
    fn test_2x2() {
        let problem = AssignmentProblem::from_costs(vec![
            vec![1, 2],
            vec![3, 4],
        ]);
        let solution = solve(&problem).unwrap();
        assert_eq!(solution.total_cost, 5);
    }

    /// Negative costs are handled; both permutations cost -5.
    #[test]
    fn test_negative_costs() {
        let problem = AssignmentProblem::from_costs(vec![
            vec![-1, -2],
            vec![-3, -4],
        ]);
        let solution = solve(&problem).unwrap();
        assert_eq!(solution.total_cost, -5);
    }

    /// 4x4 instance whose unique minimum is 7 + 79 + 343 + 463 = 892.
    /// The total is asserted exactly: an upper-bound check would also accept
    /// an incorrectly low (infeasible) total.
    #[test]
    fn test_larger() {
        let problem = AssignmentProblem::from_costs(vec![
            vec![7, 53, 183, 439],
            vec![497, 383, 563, 79],
            vec![627, 343, 773, 959],
            vec![447, 283, 463, 29],
        ]);
        let solution = solve(&problem).unwrap();
        assert_eq!(solution.total_cost, 892);
        assert_eq!(solution.assignments, vec![0, 3, 1, 2]);
    }

    /// The trait implementation reports its name and produces the same result
    /// as the free `solve` function.
    #[test]
    fn test_trait_impl() {
        let problem = AssignmentProblem::from_costs(vec![
            vec![10, 5, 13],
            vec![3, 9, 18],
            vec![14, 8, 7],
        ]);
        let solver = HungarianSolver;
        assert_eq!(solver.name(), "hungarian");
        let solution = solver.solve(&problem, &SolverParams::default()).unwrap();
        assert_eq!(solution.total_cost, 15);
        assert_eq!(solution.assignments, vec![1, 0, 2]);
    }
}