1use derive_more::{Display, Error};
2
3#[derive(Debug, Display, Error)]
4pub enum LSAPError {
5 Invalid,
6 Infeasible,
7}
8
9pub fn get_assigned_cost(
21 nr: usize,
22 nc: usize,
23 cost: &Vec<f64>,
24 maximize: bool,
25) -> Result<f64, LSAPError> {
26 let (rows, cols) = solve(nr, nc, cost, maximize)?;
27 let mut score = 0.0;
28 for i in 0..rows.len() {
29 score += cost[rows[i] * nc + cols[i]];
30 }
31 return Ok(score);
32}
33
34pub fn solve(
45 mut nr: usize,
46 mut nc: usize,
47 cost: &Vec<f64>,
48 maximize: bool,
49) -> Result<(Vec<usize>, Vec<usize>), LSAPError> {
50 if nr == 0 || nc == 0 {
52 return Ok((vec![], vec![]));
53 }
54
55 let transpose = nc < nr;
57
58 let mut temp: Vec<f64>;
60 let surrogated_cost = if transpose || maximize {
61 if transpose {
62 temp = vec![0.0; nc * nr];
63 for i in 0..nr {
64 for j in 0..nc {
65 temp[j * nr + i] = cost[i * nc + j];
66 }
67 }
68
69 std::mem::swap(&mut nr, &mut nc);
70 } else {
71 temp = cost.clone();
72 }
73
74 if maximize {
76 for i in 0..(nr * nc) {
77 temp[i] = -temp[i];
78 }
79 }
80
81 &temp
82 } else {
83 cost
84 };
85
86 for i in 0..(nr * nc) {
88 if surrogated_cost[i].is_nan() || surrogated_cost[i].is_infinite() {
89 return Err(LSAPError::Invalid);
90 }
91 }
92
93 let MINUS_1: usize = nr * nc; let mut u = vec![0.0; nr];
97 let mut v = vec![0.0; nc];
98 let mut shortest_path_costs: Vec<f64> = vec![f64::INFINITY; nc];
99 let mut path: Vec<usize> = vec![MINUS_1; nc];
100 let mut col4row: Vec<usize> = vec![MINUS_1; nr];
101 let mut row4col: Vec<usize> = vec![MINUS_1; nc];
102 let mut SR: Vec<bool> = vec![false; nr];
103 let mut SC: Vec<bool> = vec![false; nc];
104 let mut remaining: Vec<usize> = vec![MINUS_1; nc];
105
106 for cur_row in 0..nr {
108 let (sink, min_val) = augmenting_path(
109 nc,
110 &surrogated_cost,
111 &mut u,
112 &mut v,
113 &mut path,
114 &row4col,
115 &mut shortest_path_costs,
116 cur_row,
117 &mut SR,
118 &mut SC,
119 &mut remaining,
120 MINUS_1,
121 );
122
123 if sink == MINUS_1 {
124 return Err(LSAPError::Infeasible);
125 }
126
127 u[cur_row] += min_val;
129 for i in 0..nr {
130 if SR[i] && i != cur_row {
131 u[i] += min_val - shortest_path_costs[col4row[i]];
132 }
133 }
134
135 for j in 0..nc {
136 if SC[j] {
137 v[j] -= min_val - shortest_path_costs[j];
138 }
139 }
140
141 let mut j = sink;
143 loop {
144 let i = path[j];
145 row4col[j] = i;
146 std::mem::swap(&mut col4row[i], &mut j);
147 if i == cur_row {
148 break;
149 }
150 }
151 }
152
153 let mut a = Vec::with_capacity(nr);
154 let mut b = Vec::with_capacity(nr);
155
156 if transpose {
157 for v in argsort_iter(&col4row) {
158 a.push(col4row[v]);
159 b.push(v);
160 }
161 } else {
162 for i in 0..nr {
163 a.push(i);
164 b.push(col4row[i]);
165 }
166 }
167
168 return Ok((a, b));
169}
170
/// Runs one shortest-augmenting-path search that starts at the free row `i`.
///
/// This is a Dijkstra-style search over the reduced costs `cost[i * nc + j] - u[i] - v[j]`.
/// Each round it settles the unvisited column with the lowest tentative path cost. If that
/// column is assigned, the search continues from the row that owns it. If it is free, the
/// search stops and that column becomes the sink.
///
/// Parameters:
/// - `nc`: number of columns, which is also the row stride of the row-major `cost`.
/// - `u`, `v`: current row and column dual variables. They are only read here.
/// - `path`: records `path[j]`, the row from which column `j` was last relaxed. The
///   caller walks it back from the sink.
/// - `row4col`: current column-to-row assignment. A free column holds `unassigned`.
/// - `shortest_path_costs`: reset to `+inf`, then filled with tentative path costs.
/// - `sr`, `sc`: reset, then set to mark the rows and columns visited by the search.
/// - `remaining`: scratch buffer of at least `nc` entries.
///
/// Returns `(sink, min_val)`, where `min_val` is the length of the path to `sink`.
/// Returns `(unassigned, min_val)` with an infinite `min_val` when no free column is
/// reachable at finite cost.
// The many scratch buffers mirror the reference implementation; `solve` owns and reuses them.
#[allow(clippy::too_many_arguments)]
fn augmenting_path(
    nc: usize,
    cost: &[f64],
    u: &[f64],
    v: &[f64],
    path: &mut [usize],
    row4col: &[usize],
    shortest_path_costs: &mut [f64],
    mut i: usize,
    sr: &mut [bool],
    sc: &mut [bool],
    remaining: &mut [usize],
    unassigned: usize,
) -> (usize, f64) {
    let mut min_val = 0.0;

    // Unvisited columns live in `remaining[..num_remaining]`. Filling it in reverse column
    // order, together with the free-column tie-break below, makes a constant cost matrix
    // resolve to the identity assignment.
    let mut num_remaining = nc;
    for (it, slot) in remaining[..nc].iter_mut().enumerate() {
        *slot = nc - it - 1;
    }

    sr.fill(false);
    sc.fill(false);
    shortest_path_costs.fill(f64::INFINITY);

    let mut sink = unassigned;
    while sink == unassigned {
        let mut index = unassigned;
        let mut lowest = f64::INFINITY;
        sr[i] = true;

        for (it, &j) in remaining[..num_remaining].iter().enumerate() {
            // Tentative cost of reaching column `j` through row `i`.
            let r = min_val + cost[i * nc + j] - u[i] - v[j];
            if r < shortest_path_costs[j] {
                path[j] = i;
                shortest_path_costs[j] = r;
            }

            // On ties prefer a free column: selecting it ends the search immediately.
            if shortest_path_costs[j] < lowest
                || (shortest_path_costs[j] == lowest && row4col[j] == unassigned)
            {
                lowest = shortest_path_costs[j];
                index = it;
            }
        }

        min_val = lowest;
        if min_val.is_infinite() {
            // No unvisited column is reachable at finite cost.
            return (unassigned, min_val);
        }

        let j = remaining[index];
        if row4col[j] == unassigned {
            sink = j;
        } else {
            // `j` is taken; continue the search from the row that currently owns it.
            i = row4col[j];
        }

        // Mark `j` visited and drop it from the unvisited set in O(1) (swap-remove).
        sc[j] = true;
        num_remaining -= 1;
        remaining[index] = remaining[num_remaining];
    }

    (sink, min_val)
}
247
/// Returns the permutation of indices that sorts `v` in ascending order.
///
/// For the result `idx`, `v[idx[0]] <= v[idx[1]] <= ...`. The sort is stable, so equal
/// elements keep their original relative order. Takes a slice, so `&Vec<T>` arguments
/// still work through deref coercion.
fn argsort_iter<T: Ord>(v: &[T]) -> Vec<usize> {
    let mut index: Vec<usize> = (0..v.len()).collect();
    index.sort_by_key(|&i| &v[i]);
    index
}