pub mod hungarian;
pub mod auction;
use crate::{Cost, Error, Result, SolverParams, SolverStats, SolverStatus};
use serde::{Deserialize, Serialize};
/// An assignment problem: a cost for every (agent, task) pair, from which a
/// solver picks a task for each agent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AssignmentProblem {
/// Cost matrix indexed as `costs[agent][task]`.
pub costs: Vec<Vec<Cost>>,
/// Number of agents; `from_costs` sets this to the number of rows in `costs`.
pub num_agents: usize,
/// Number of tasks; `validate` requires every row of `costs` to have this length.
pub num_tasks: usize,
}
impl AssignmentProblem {
pub fn from_costs(costs: Vec<Vec<Cost>>) -> Self {
let num_agents = costs.len();
let num_tasks = costs.first().map_or(0, Vec::len);
Self {
costs,
num_agents,
num_tasks,
}
}
pub fn from_flat(costs: Vec<Cost>, n: usize) -> Result<Self> {
if costs.len() != n * n {
return Err(Error::dimension_mismatch(n * n, costs.len()));
}
let matrix: Vec<Vec<Cost>> = costs.chunks(n).map(|c| c.to_vec()).collect();
Ok(Self::from_costs(matrix))
}
pub fn is_square(&self) -> bool {
self.num_agents == self.num_tasks
}
pub fn cost(&self, agent: usize, task: usize) -> Cost {
self.costs[agent][task]
}
pub fn validate(&self) -> Result<()> {
if self.num_agents == 0 {
return Err(Error::invalid_input("no agents"));
}
if self.num_tasks == 0 {
return Err(Error::invalid_input("no tasks"));
}
for row in &self.costs {
if row.len() != self.num_tasks {
return Err(Error::dimension_mismatch(self.num_tasks, row.len()));
}
}
Ok(())
}
}
/// The outcome of solving an [`AssignmentProblem`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AssignmentSolution {
/// `assignments[agent]` is the task assigned to `agent`.
pub assignments: Vec<usize>,
/// Total cost reported by the solver for `assignments`.
pub total_cost: Cost,
/// Termination status reported by the solver (e.g. `SolverStatus::Optimal`).
pub status: SolverStatus,
/// Statistics reported by the solver; see `SolverStats` for its contents.
pub stats: SolverStats,
}
impl AssignmentSolution {
    /// Looks up the task assigned to `agent`.
    ///
    /// Returns `None` when `agent` is not a valid index into `assignments`.
    pub fn task_for_agent(&self, agent: usize) -> Option<usize> {
        let task = self.assignments.get(agent)?;
        Some(*task)
    }

    /// Iterates over `(agent, task)` pairs in increasing agent order.
    pub fn iter(&self) -> impl Iterator<Item = (usize, usize)> + '_ {
        self.assignments.iter().copied().enumerate()
    }
}
/// Interface shared by assignment-problem solvers.
pub trait AssignmentSolver {
/// Solves `problem` using the options in `params`.
///
/// NOTE(review): the module-level `solve_with_params` calls `problem.validate()`
/// before delegating to a solver; confirm whether implementations may assume a
/// validated problem when they are called directly.
fn solve(&self, problem: &AssignmentProblem, params: &SolverParams) -> Result<AssignmentSolution>;
/// Returns a static name for this solver (presumably for logging/reporting — confirm).
fn name(&self) -> &'static str;
}
/// Solves `problem` using the default solver parameters.
///
/// Equivalent to `solve_with_params(problem, &SolverParams::default())`.
///
/// # Errors
///
/// Returns whatever error [`solve_with_params`] returns.
pub fn solve(problem: &AssignmentProblem) -> Result<AssignmentSolution> {
    let default_params = SolverParams::default();
    solve_with_params(problem, &default_params)
}
/// Validates `problem`, then solves it with the Hungarian solver using `params`.
///
/// # Errors
///
/// Returns the error from [`AssignmentProblem::validate`] without invoking the
/// solver if validation fails. Otherwise it returns whatever the Hungarian
/// solver returns.
pub fn solve_with_params(
    problem: &AssignmentProblem,
    params: &SolverParams,
) -> Result<AssignmentSolution> {
    problem
        .validate()
        .and_then(|()| hungarian::HungarianSolver.solve(problem, params))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// 3x3 instance whose unique optimum is agent 0 -> task 1 (5),
    /// agent 1 -> task 0 (3), agent 2 -> task 2 (7), for a total of 15.
    fn sample_problem() -> AssignmentProblem {
        AssignmentProblem::from_costs(vec![
            vec![10, 5, 13],
            vec![3, 9, 18],
            vec![14, 8, 7],
        ])
    }

    #[test]
    fn test_simple_assignment() {
        let problem = sample_problem();
        let solution = solve(&problem).unwrap();
        assert_eq!(solution.status, SolverStatus::Optimal);
        assert_eq!(solution.total_cost, 15);
    }

    #[test]
    fn test_simple_assignment_pairs() {
        let solution = solve(&sample_problem()).unwrap();
        assert_eq!(solution.task_for_agent(0), Some(1));
        assert_eq!(solution.task_for_agent(1), Some(0));
        assert_eq!(solution.task_for_agent(2), Some(2));
        assert_eq!(solution.task_for_agent(3), None);
        let pairs: Vec<(usize, usize)> = solution.iter().collect();
        assert_eq!(pairs, vec![(0, 1), (1, 0), (2, 2)]);
    }

    #[test]
    fn test_from_flat() {
        let costs = vec![1, 2, 3, 4];
        let problem = AssignmentProblem::from_flat(costs, 2).unwrap();
        assert_eq!(problem.num_agents, 2);
        assert_eq!(problem.num_tasks, 2);
        assert_eq!(problem.cost(0, 0), 1);
        assert_eq!(problem.cost(1, 1), 4);
    }

    #[test]
    fn test_from_flat_rejects_wrong_length() {
        assert!(AssignmentProblem::from_flat(vec![1, 2, 3], 2).is_err());
    }

    #[test]
    fn test_validate_rejects_empty_and_ragged() {
        assert!(AssignmentProblem::from_costs(Vec::new()).validate().is_err());
        assert!(AssignmentProblem::from_costs(vec![Vec::new()]).validate().is_err());
        let ragged = AssignmentProblem::from_costs(vec![vec![1, 2], vec![3]]);
        assert!(ragged.validate().is_err());
        // The public entry point must refuse malformed input rather than solve it.
        assert!(solve(&ragged).is_err());
    }
}