// converge_optimization/graph/matching.rs
use std::collections::VecDeque;

use crate::{Error, Result};

/// Outcome of a bipartite matching computation: the matched vertex pairs and
/// how many there are.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Matching {
    /// Matched `(left, right)` vertex pairs. `bipartite_matching` emits them in
    /// ascending order of the left vertex.
    pub pairs: Vec<(usize, usize)>,
    /// Number of matched pairs; `bipartite_matching` always sets this to
    /// `pairs.len()`.
    pub size: usize,
}

16pub fn bipartite_matching(
26 left_count: usize,
27 right_count: usize,
28 edges: &[(usize, usize)],
29) -> Result<Matching> {
30 for &(l, r) in edges {
31 if l >= left_count {
32 return Err(Error::invalid_input(format!(
33 "left node {l} out of range (left_count={left_count})"
34 )));
35 }
36 if r >= right_count {
37 return Err(Error::invalid_input(format!(
38 "right node {r} out of range (right_count={right_count})"
39 )));
40 }
41 }
42
43 let mut adj: Vec<Vec<usize>> = vec![vec![]; left_count];
44 for &(l, r) in edges {
45 adj[l].push(r);
46 }
47
48 const NONE: usize = usize::MAX;
49 let mut match_left = vec![NONE; left_count];
50 let mut match_right = vec![NONE; right_count];
51 let mut total = 0usize;
52
53 loop {
54 let mut dist = vec![NONE; left_count];
56 let mut queue = VecDeque::new();
57
58 for l in 0..left_count {
59 if match_left[l] == NONE {
60 dist[l] = 0;
61 queue.push_back(l);
62 }
63 }
64
65 let mut found_free_right = false;
66 while let Some(l) = queue.pop_front() {
67 for &r in &adj[l] {
68 let nl = match_right[r];
69 if nl == NONE {
70 found_free_right = true;
71 } else if dist[nl] == NONE {
72 dist[nl] = dist[l] + 1;
73 queue.push_back(nl);
74 }
75 }
76 }
77
78 if !found_free_right {
79 break;
80 }
81
82 for l in 0..left_count {
84 if match_left[l] == NONE && dfs(l, &adj, &mut match_left, &mut match_right, &mut dist) {
85 total += 1;
86 }
87 }
88 }
89
90 let pairs: Vec<(usize, usize)> = (0..left_count)
91 .filter(|&l| match_left[l] != NONE)
92 .map(|l| (l, match_left[l]))
93 .collect();
94
95 Ok(Matching { pairs, size: total })
96}
97
/// Depth-first search for an augmenting path that starts at left vertex `l`
/// and stays inside the layered graph described by `dist`.
///
/// Each right neighbour `r` of `l` is tried in adjacency order:
/// - If `r` is unmatched, the path ends there.
/// - Otherwise the search recurses into `r`'s current partner, but only when
///   that partner sits exactly one layer deeper than `l`
///   (`dist[partner] == dist[l] + 1`).
///
/// On success, `l` and `r` are matched to each other (deeper vertices were
/// already rematched by their own recursive calls) and `true` is returned. On
/// failure, `dist[l]` is set to `usize::MAX` so that no later search using the
/// same `dist` descends into `l` again, and `false` is returned.
///
/// `match_left` / `match_right` hold the partner index, or `usize::MAX` when
/// unmatched. Each recursive step goes one layer deeper, so the recursion depth
/// is bounded by the number of distinct layers in `dist`.
fn dfs(
    l: usize,
    adj: &[Vec<usize>],
    match_left: &mut [usize],
    match_right: &mut [usize],
    dist: &mut [usize],
) -> bool {
    const NONE: usize = usize::MAX;

    for &r in adj[l].iter() {
        let partner = match_right[r];
        let reaches_free_right = match partner {
            NONE => true,
            _ => {
                dist[partner] == dist[l] + 1
                    && dfs(partner, adj, match_left, match_right, dist)
            }
        };
        if reaches_free_right {
            match_right[r] = l;
            match_left[l] = r;
            return true;
        }
    }

    // Dead end: drop `l` from the layered graph for subsequent searches.
    dist[l] = NONE;
    false
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Asserts that `m` is a valid matching for `edges`: `size` agrees with
    /// `pairs`, every pair is one of the input edges, and no vertex on either
    /// side appears twice.
    fn assert_valid_matching(m: &Matching, edges: &[(usize, usize)]) {
        assert_eq!(m.size, m.pairs.len(), "size must equal the number of pairs");
        let edge_set: HashSet<(usize, usize)> = edges.iter().copied().collect();
        let mut seen_left = HashSet::new();
        let mut seen_right = HashSet::new();
        for &(l, r) in &m.pairs {
            assert!(edge_set.contains(&(l, r)), "pair ({l}, {r}) is not an input edge");
            assert!(seen_left.insert(l), "left node {l} appears twice");
            assert!(seen_right.insert(r), "right node {r} appears twice");
        }
    }

    /// Exhaustive maximum-matching size, used as an oracle on tiny graphs.
    fn brute_force_size(left_count: usize, right_count: usize, edges: &[(usize, usize)]) -> usize {
        fn best_from(l: usize, adj: &[Vec<usize>], used: &mut [bool]) -> usize {
            if l == adj.len() {
                return 0;
            }
            // Option 1: leave `l` unmatched.
            let mut best = best_from(l + 1, adj, used);
            // Option 2: match `l` to any still-unused neighbour.
            for &r in &adj[l] {
                if !used[r] {
                    used[r] = true;
                    best = best.max(1 + best_from(l + 1, adj, used));
                    used[r] = false;
                }
            }
            best
        }

        let mut adj: Vec<Vec<usize>> = vec![Vec::new(); left_count];
        for &(l, r) in edges {
            adj[l].push(r);
        }
        best_from(0, &adj, &mut vec![false; right_count])
    }

    #[test]
    fn empty_graph_produces_empty_matching() {
        let m = bipartite_matching(3, 3, &[]).unwrap();
        assert_eq!(m.size, 0);
        assert!(m.pairs.is_empty());
    }

    #[test]
    fn zero_sized_sides_produce_empty_matching() {
        let m = bipartite_matching(0, 0, &[]).unwrap();
        assert_eq!(m.size, 0);
        assert!(m.pairs.is_empty());
    }

    #[test]
    fn perfect_matching_on_complete_bipartite() {
        let edges: Vec<(usize, usize)> = (0..3).flat_map(|l| (0..3).map(move |r| (l, r))).collect();
        let m = bipartite_matching(3, 3, &edges).unwrap();
        assert_eq!(m.size, 3);
        assert_eq!(m.pairs.len(), 3);
        let mut lefts: Vec<usize> = m.pairs.iter().map(|&(l, _)| l).collect();
        let mut rights: Vec<usize> = m.pairs.iter().map(|&(_, r)| r).collect();
        lefts.sort_unstable();
        rights.sort_unstable();
        lefts.dedup();
        rights.dedup();
        assert_eq!(lefts.len(), 3);
        assert_eq!(rights.len(), 3);
        assert_valid_matching(&m, &edges);
    }

    #[test]
    fn partial_matching_when_right_side_is_smaller() {
        let edges = [(0, 0), (1, 0), (2, 1), (3, 1)];
        let m = bipartite_matching(4, 2, &edges).unwrap();
        assert_eq!(m.size, 2);
        assert_valid_matching(&m, &edges);
    }

    #[test]
    fn single_edge_matches() {
        let m = bipartite_matching(1, 1, &[(0, 0)]).unwrap();
        assert_eq!(m.size, 1);
        assert_eq!(m.pairs, vec![(0, 0)]);
    }

    #[test]
    fn disjoint_components_all_matched() {
        let edges = [(0, 0), (1, 1)];
        let m = bipartite_matching(2, 2, &edges).unwrap();
        assert_eq!(m.size, 2);
        assert_valid_matching(&m, &edges);
    }

    #[test]
    fn augmenting_path_required() {
        let edges = [(0, 0), (0, 1), (1, 0)];
        let m = bipartite_matching(2, 2, &edges).unwrap();
        assert_eq!(m.size, 2);
        assert_valid_matching(&m, &edges);
    }

    #[test]
    fn long_augmenting_path_is_found() {
        // Left i lists right i + 1 before right i. An adjacency-order search
        // therefore tends to pair l_i with r_{i+1} first and strand the last
        // left vertex. The only perfect matching is l_i -> r_i, reached through
        // a single augmenting path of length 2n - 1.
        let n = 64;
        let mut edges = Vec::new();
        for i in 0..n {
            if i + 1 < n {
                edges.push((i, i + 1));
            }
            edges.push((i, i));
        }
        let m = bipartite_matching(n, n, &edges).unwrap();
        assert_valid_matching(&m, &edges);
        assert_eq!(m.size, n);
        let expected: Vec<(usize, usize)> = (0..n).map(|i| (i, i)).collect();
        assert_eq!(m.pairs, expected);
    }

    #[test]
    fn duplicate_edges_do_not_inflate_matching() {
        let edges = [(0, 0), (0, 0), (1, 0), (1, 0)];
        let m = bipartite_matching(2, 1, &edges).unwrap();
        assert_valid_matching(&m, &edges);
        assert_eq!(m.size, 1);
    }

    #[test]
    fn out_of_range_left_node_returns_error() {
        let err = bipartite_matching(2, 2, &[(5, 0)]).unwrap_err();
        assert!(matches!(err, Error::InvalidInput(_)));
    }

    #[test]
    fn out_of_range_right_node_returns_error() {
        let err = bipartite_matching(2, 2, &[(0, 5)]).unwrap_err();
        assert!(matches!(err, Error::InvalidInput(_)));
    }

    #[test]
    fn no_edges_means_no_matches_regardless_of_counts() {
        let m = bipartite_matching(10, 10, &[]).unwrap();
        assert_eq!(m.size, 0);
    }

    #[test]
    fn matching_is_valid_no_shared_endpoints() {
        let edges: Vec<(usize, usize)> = (0..5).flat_map(|l| (0..5).map(move |r| (l, r))).collect();
        let m = bipartite_matching(5, 5, &edges).unwrap();
        assert_eq!(m.size, 5);
        assert_valid_matching(&m, &edges);
    }

    #[test]
    fn agrees_with_brute_force_on_small_random_graphs() {
        // Fixed-seed 64-bit LCG so any failure is reproducible.
        let mut state: u64 = 0x2545_f491_4f6c_dd1d;
        let mut below = |bound: usize| -> usize {
            state = state
                .wrapping_mul(6_364_136_223_846_793_005)
                .wrapping_add(1_442_695_040_888_963_407);
            ((state >> 33) as usize) % bound
        };
        for _ in 0..200 {
            let left_count = below(6);
            let right_count = below(6);
            let mut edges = Vec::new();
            if left_count > 0 && right_count > 0 {
                for _ in 0..below(12) {
                    let l = below(left_count);
                    let r = below(right_count);
                    edges.push((l, r));
                }
            }
            let m = bipartite_matching(left_count, right_count, &edges).unwrap();
            assert_valid_matching(&m, &edges);
            assert_eq!(
                m.size,
                brute_force_size(left_count, right_count, &edges),
                "edges = {edges:?}"
            );
        }
    }
}