lsap/
lib.rs

1use derive_more::{Display, Error};
2
3#[derive(Debug, Display, Error)]
4pub enum LSAPError {
5    Invalid,
6    Infeasible,
7}
8
9// impl std::fmt::Display for LSAPError {
10//     // This trait requires `fmt` with this exact signature.
11//     fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
12//         // Write strictly the first element into the supplied output
13//         // stream: `f`. Returns `fmt::Result` which indicates whether the
14//         // operation succeeded or failed. Note that `write!` uses syntax which
15//         // is very similar to `println!`.
16//         write!(f, "{:?}", self)
17//     }
18// }
19
20pub fn get_assigned_cost(
21    nr: usize,
22    nc: usize,
23    cost: &Vec<f64>,
24    maximize: bool,
25) -> Result<f64, LSAPError> {
26    let (rows, cols) = solve(nr, nc, cost, maximize)?;
27    let mut score = 0.0;
28    for i in 0..rows.len() {
29        score += cost[rows[i] * nc + cols[i]];
30    }
31    return Ok(score);
32}
33
34/// Solve the linear sum assignment problem and return a tuple of vectors containing the assigned
35///
36/// The implementation is translated from the C++ code from [Scipy](https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.linear_sum_assignment.html).
37///
38/// # Arguments
39///
40/// * `nr` - number of rows in the cost matrix
41/// * `nc` - number of columns in the cost matrix
42/// * `cost` - cost matrix flattened into a vector such that item at row i, column j can be accessed via cost[i * nc + j]
43/// * `maximize` - if true, solve the maximization problem instead of the minimization problem
44pub fn solve(
45    mut nr: usize,
46    mut nc: usize,
47    cost: &Vec<f64>,
48    maximize: bool,
49) -> Result<(Vec<usize>, Vec<usize>), LSAPError> {
50    // handle trivial inputs
51    if nr == 0 || nc == 0 {
52        return Ok((vec![], vec![]));
53    }
54
55    // tall rectangular cost matrix must be transposed
56    let transpose = nc < nr;
57
58    // make a copy of the cost matrix if we need to modify it
59    let mut temp: Vec<f64>;
60    let surrogated_cost = if transpose || maximize {
61        if transpose {
62            temp = vec![0.0; nc * nr];
63            for i in 0..nr {
64                for j in 0..nc {
65                    temp[j * nr + i] = cost[i * nc + j];
66                }
67            }
68
69            std::mem::swap(&mut nr, &mut nc);
70        } else {
71            temp = cost.clone();
72        }
73
74        // negate cost matrix for maximization
75        if maximize {
76            for i in 0..(nr * nc) {
77                temp[i] = -temp[i];
78            }
79        }
80
81        &temp
82    } else {
83        cost
84    };
85
86    // test for NaN and -inf entries
87    for i in 0..(nr * nc) {
88        if surrogated_cost[i].is_nan() || surrogated_cost[i].is_infinite() {
89            return Err(LSAPError::Invalid);
90        }
91    }
92
93    // initialize variables
94    let MINUS_1: usize = nr * nc; // use this to represent -1 in the C++ code, it has the same effect
95
96    let mut u = vec![0.0; nr];
97    let mut v = vec![0.0; nc];
98    let mut shortest_path_costs: Vec<f64> = vec![f64::INFINITY; nc];
99    let mut path: Vec<usize> = vec![MINUS_1; nc];
100    let mut col4row: Vec<usize> = vec![MINUS_1; nr];
101    let mut row4col: Vec<usize> = vec![MINUS_1; nc];
102    let mut SR: Vec<bool> = vec![false; nr];
103    let mut SC: Vec<bool> = vec![false; nc];
104    let mut remaining: Vec<usize> = vec![MINUS_1; nc];
105
106    // iteratively build the solution
107    for cur_row in 0..nr {
108        let (sink, min_val) = augmenting_path(
109            nc,
110            &surrogated_cost,
111            &mut u,
112            &mut v,
113            &mut path,
114            &row4col,
115            &mut shortest_path_costs,
116            cur_row,
117            &mut SR,
118            &mut SC,
119            &mut remaining,
120            MINUS_1,
121        );
122
123        if sink == MINUS_1 {
124            return Err(LSAPError::Infeasible);
125        }
126
127        // update dual variables
128        u[cur_row] += min_val;
129        for i in 0..nr {
130            if SR[i] && i != cur_row {
131                u[i] += min_val - shortest_path_costs[col4row[i]];
132            }
133        }
134
135        for j in 0..nc {
136            if SC[j] {
137                v[j] -= min_val - shortest_path_costs[j];
138            }
139        }
140
141        // augment previous solution
142        let mut j = sink;
143        loop {
144            let i = path[j];
145            row4col[j] = i;
146            std::mem::swap(&mut col4row[i], &mut j);
147            if i == cur_row {
148                break;
149            }
150        }
151    }
152
153    let mut a = Vec::with_capacity(nr);
154    let mut b = Vec::with_capacity(nr);
155
156    if transpose {
157        for v in argsort_iter(&col4row) {
158            a.push(col4row[v]);
159            b.push(v);
160        }
161    } else {
162        for i in 0..nr {
163            a.push(i);
164            b.push(col4row[i]);
165        }
166    }
167
168    return Ok((a, b));
169}
170
/// Search for a shortest augmenting path starting at the unassigned row `i` (the inner step of
/// Crouse's shortest-augmenting-path algorithm).
///
/// * `nc` - number of columns; `cost` is row-major with stride `nc`
/// * `cost` - working cost matrix; reduced costs are computed as `cost[i][j] - u[i] - v[j]`
/// * `u`, `v` - current row / column dual variables (read only here)
/// * `path` - written: `path[j]` is the row from which column `j` was reached most cheaply
/// * `row4col` - current matching; `row4col[j] == unassigned` means column `j` is free
/// * `shortest_path_costs` - reset to `+inf`, then filled with tentative path costs per column
/// * `sr`, `sc` - reset, then set for every row / column visited by the search
/// * `remaining` - scratch space of length at least `nc`
/// * `unassigned` - sentinel meaning "no row/column"; must match the one used in `row4col`
///
/// Returns `(sink, min_val)`: the free column that ends the path, and the path's cost. When no
/// column is reachable at finite cost, `sink` is `unassigned` (and `min_val` is infinite).
// The argument list mirrors SciPy's C++ `augmenting_path` so the translation stays comparable.
#[allow(clippy::too_many_arguments)]
fn augmenting_path(
    nc: usize,
    cost: &[f64],
    u: &[f64],
    v: &[f64],
    path: &mut [usize],
    row4col: &[usize],
    shortest_path_costs: &mut [f64],
    mut i: usize,
    sr: &mut [bool],
    sc: &mut [bool],
    remaining: &mut [usize],
    unassigned: usize,
) -> (usize, f64) {
    let mut min_val = 0.0;

    // Crouse's pseudocode uses set complements to keep track of the columns not yet scanned.
    // Here they are the first `num_remaining` entries of `remaining`; a scanned column is removed
    // by overwriting its slot with the last live entry.
    let mut num_remaining = nc;
    // Filling this up in reverse order ensures that the solution of a
    // constant cost matrix is the identity matrix (c.f. #11602).
    for (it, slot) in remaining[..nc].iter_mut().enumerate() {
        *slot = nc - it - 1;
    }

    sr.fill(false);
    sc.fill(false);
    shortest_path_costs.fill(f64::INFINITY);

    // find shortest augmenting path
    let mut sink = unassigned;
    while sink == unassigned {
        let mut index = unassigned;
        let mut lowest = f64::INFINITY;
        sr[i] = true;

        for it in 0..num_remaining {
            let j = remaining[it];

            // cost of reaching column j through row i, in reduced-cost terms
            let r = min_val + cost[i * nc + j] - u[i] - v[j];
            if r < shortest_path_costs[j] {
                path[j] = i;
                shortest_path_costs[j] = r;
            }

            // When multiple nodes have the minimum cost, we select one which
            // gives us a new sink node. This is particularly important for
            // integer cost matrices with small co-efficients.
            if shortest_path_costs[j] < lowest
                || (shortest_path_costs[j] == lowest && row4col[j] == unassigned)
            {
                lowest = shortest_path_costs[j];
                index = it;
            }
        }

        min_val = lowest;
        if min_val.is_infinite() {
            // infeasible cost matrix; `min_val` is returned but callers ignore it
            return (unassigned, min_val);
        }

        let j = remaining[index];
        if row4col[j] == unassigned {
            // column j is free: the augmenting path ends here
            sink = j;
        } else {
            // column j is taken: continue the search from the row that owns it
            i = row4col[j];
        }

        sc[j] = true;
        num_remaining -= 1;
        remaining[index] = remaining[num_remaining];
    }

    // SciPy writes min_val through an out-pointer; here it is returned alongside the sink
    (sink, min_val)
}
247
/// Return the permutation of indices that sorts `v` in ascending order (an "argsort").
///
/// The sort is stable, so equal elements keep their original relative order. Despite the name,
/// the result is a collected `Vec<usize>`, not a lazy iterator.
fn argsort_iter<T: Ord>(v: &[T]) -> Vec<usize> {
    let mut index: Vec<usize> = (0..v.len()).collect();
    index.sort_by_key(|&i| &v[i]);
    index
}
252}