converge_optimization/assignment/
mod.rs1pub mod auction;
41pub mod hungarian;
42
43use crate::{Cost, Error, Result, SolverParams, SolverStats, SolverStatus};
44use serde::{Deserialize, Serialize};
45
46#[derive(Debug, Clone, Serialize, Deserialize)]
48pub struct AssignmentProblem {
49 pub costs: Vec<Vec<Cost>>,
51 pub num_agents: usize,
53 pub num_tasks: usize,
55}
56
57impl AssignmentProblem {
58 pub fn from_costs(costs: Vec<Vec<Cost>>) -> Self {
60 let num_agents = costs.len();
61 let num_tasks = costs.first().map_or(0, Vec::len);
62 Self {
63 costs,
64 num_agents,
65 num_tasks,
66 }
67 }
68
69 pub fn from_flat(costs: Vec<Cost>, n: usize) -> Result<Self> {
71 if costs.len() != n * n {
72 return Err(Error::dimension_mismatch(n * n, costs.len()));
73 }
74 let matrix: Vec<Vec<Cost>> = costs.chunks(n).map(|c| c.to_vec()).collect();
75 Ok(Self::from_costs(matrix))
76 }
77
78 pub fn is_square(&self) -> bool {
80 self.num_agents == self.num_tasks
81 }
82
83 pub fn cost(&self, agent: usize, task: usize) -> Cost {
85 self.costs[agent][task]
86 }
87
88 pub fn validate(&self) -> Result<()> {
90 if self.num_agents == 0 {
91 return Err(Error::invalid_input("no agents"));
92 }
93 if self.num_tasks == 0 {
94 return Err(Error::invalid_input("no tasks"));
95 }
96 for row in &self.costs {
97 if row.len() != self.num_tasks {
98 return Err(Error::dimension_mismatch(self.num_tasks, row.len()));
99 }
100 }
101 Ok(())
102 }
103}
104
105#[derive(Debug, Clone, Serialize, Deserialize)]
107pub struct AssignmentSolution {
108 pub assignments: Vec<usize>,
110 pub total_cost: Cost,
112 pub status: SolverStatus,
114 pub stats: SolverStats,
116}
117
118impl AssignmentSolution {
119 pub fn task_for_agent(&self, agent: usize) -> Option<usize> {
121 self.assignments.get(agent).copied()
122 }
123
124 pub fn iter(&self) -> impl Iterator<Item = (usize, usize)> + '_ {
126 self.assignments.iter().enumerate().map(|(a, &t)| (a, t))
127 }
128}
129
130pub trait AssignmentSolver {
132 fn solve(
134 &self,
135 problem: &AssignmentProblem,
136 params: &SolverParams,
137 ) -> Result<AssignmentSolution>;
138
139 fn name(&self) -> &'static str;
141}
142
143pub fn solve(problem: &AssignmentProblem) -> Result<AssignmentSolution> {
145 solve_with_params(problem, &SolverParams::default())
146}
147
148pub fn solve_with_params(
150 problem: &AssignmentProblem,
151 params: &SolverParams,
152) -> Result<AssignmentSolution> {
153 problem.validate()?;
154 hungarian::HungarianSolver.solve(problem, params)
155}
156
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_simple_assignment() {
        let problem =
            AssignmentProblem::from_costs(vec![vec![10, 5, 13], vec![3, 9, 18], vec![14, 8, 7]]);

        let solution = solve(&problem).unwrap();
        assert_eq!(solution.status, SolverStatus::Optimal);
        assert_eq!(solution.total_cost, 15);

        // 5 + 3 + 7 is the unique optimum, so the pairing itself is fully determined.
        let pairs: Vec<(usize, usize)> = solution.iter().collect();
        assert_eq!(pairs, vec![(0, 1), (1, 0), (2, 2)]);
        assert_eq!(solution.task_for_agent(problem.num_agents), None);

        // The reported total must agree with the costs of the reported pairs.
        let recomputed: Cost = pairs.iter().map(|&(a, t)| problem.cost(a, t)).sum();
        assert_eq!(recomputed, solution.total_cost);
    }

    #[test]
    fn test_from_flat() {
        let costs = vec![1, 2, 3, 4];
        let problem = AssignmentProblem::from_flat(costs, 2).unwrap();
        assert_eq!(problem.num_agents, 2);
        assert_eq!(problem.num_tasks, 2);
        assert_eq!(problem.cost(0, 0), 1);
        assert_eq!(problem.cost(1, 1), 4);
    }

    #[test]
    fn test_from_flat_rejects_wrong_length() {
        assert!(AssignmentProblem::from_flat(vec![1, 2, 3], 2).is_err());
    }

    #[test]
    fn test_ragged_rows_are_rejected() {
        let problem = AssignmentProblem::from_costs(vec![vec![1, 2], vec![3]]);
        assert!(problem.validate().is_err());
        assert!(solve(&problem).is_err());
    }

    #[test]
    fn test_empty_problem_is_rejected() {
        let problem = AssignmentProblem::from_costs(Vec::new());
        assert!(problem.validate().is_err());
        assert!(solve(&problem).is_err());
    }
}