Skip to main content

converge_optimization/graph/
matching.rs

1//! Graph matching algorithms — Hopcroft-Karp maximum bipartite matching.
2
3use std::collections::VecDeque;
4
5use crate::{Error, Result};
6
/// A maximum matching returned by [`bipartite_matching`].
///
/// Each entry of `pairs` is one matched edge written as `(left, right)` with
/// 0-based node indices on each side; [`bipartite_matching`] emits them in
/// ascending order of the left node. `size` records how many edges matched.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Matching {
    /// The matched edges as `(left_node, right_node)`; no node appears twice.
    pub pairs: Vec<(usize, usize)>,
    /// Number of matched edges — the same as `pairs.len()`.
    pub size: usize,
}
15
16/// Find a maximum bipartite matching using Hopcroft-Karp.
17///
18/// The graph has `left_count` nodes on the left side (0..left_count) and
19/// `right_count` nodes on the right side (0..right_count). `edges` is a list
20/// of `(left, right)` pairs; each edge must satisfy `left < left_count` and
21/// `right < right_count`.
22///
23/// Returns the maximum matching: the largest set of edges with no shared
24/// endpoint. Runs in O(E √V) time.
25pub fn bipartite_matching(
26    left_count: usize,
27    right_count: usize,
28    edges: &[(usize, usize)],
29) -> Result<Matching> {
30    for &(l, r) in edges {
31        if l >= left_count {
32            return Err(Error::invalid_input(format!(
33                "left node {l} out of range (left_count={left_count})"
34            )));
35        }
36        if r >= right_count {
37            return Err(Error::invalid_input(format!(
38                "right node {r} out of range (right_count={right_count})"
39            )));
40        }
41    }
42
43    let mut adj: Vec<Vec<usize>> = vec![vec![]; left_count];
44    for &(l, r) in edges {
45        adj[l].push(r);
46    }
47
48    const NONE: usize = usize::MAX;
49    let mut match_left = vec![NONE; left_count];
50    let mut match_right = vec![NONE; right_count];
51    let mut total = 0usize;
52
53    loop {
54        // BFS: build layered graph from free left nodes.
55        let mut dist = vec![NONE; left_count];
56        let mut queue = VecDeque::new();
57
58        for l in 0..left_count {
59            if match_left[l] == NONE {
60                dist[l] = 0;
61                queue.push_back(l);
62            }
63        }
64
65        let mut found_free_right = false;
66        while let Some(l) = queue.pop_front() {
67            for &r in &adj[l] {
68                let nl = match_right[r];
69                if nl == NONE {
70                    found_free_right = true;
71                } else if dist[nl] == NONE {
72                    dist[nl] = dist[l] + 1;
73                    queue.push_back(nl);
74                }
75            }
76        }
77
78        if !found_free_right {
79            break;
80        }
81
82        // DFS: augment along shortest paths found by BFS.
83        for l in 0..left_count {
84            if match_left[l] == NONE && dfs(l, &adj, &mut match_left, &mut match_right, &mut dist) {
85                total += 1;
86            }
87        }
88    }
89
90    let pairs: Vec<(usize, usize)> = (0..left_count)
91        .filter(|&l| match_left[l] != NONE)
92        .map(|l| (l, match_left[l]))
93        .collect();
94
95    Ok(Matching { pairs, size: total })
96}
97
98fn dfs(
99    l: usize,
100    adj: &[Vec<usize>],
101    match_left: &mut [usize],
102    match_right: &mut [usize],
103    dist: &mut [usize],
104) -> bool {
105    const NONE: usize = usize::MAX;
106    for &r in &adj[l] {
107        let nl = match_right[r];
108        if nl == NONE || (dist[nl] == dist[l] + 1 && dfs(nl, adj, match_left, match_right, dist)) {
109            match_left[l] = r;
110            match_right[r] = l;
111            return true;
112        }
113    }
114    dist[l] = NONE; // exhaust this node so it won't be revisited in this phase
115    false
116}
117
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::*;

    /// Asserts that `m` is a valid matching drawn from `edges`: `size` agrees
    /// with `pairs`, every pair is one of the input edges, and no left or right
    /// node appears in more than one pair.
    fn assert_valid_matching(m: &Matching, edges: &[(usize, usize)]) {
        assert_eq!(m.size, m.pairs.len(), "size disagrees with pairs");
        let mut seen_left = HashSet::new();
        let mut seen_right = HashSet::new();
        for &(l, r) in &m.pairs {
            assert!(edges.contains(&(l, r)), "pair ({l}, {r}) is not an input edge");
            assert!(seen_left.insert(l), "left node {l} appears twice");
            assert!(seen_right.insert(r), "right node {r} appears twice");
        }
    }

    /// All edges of the complete bipartite graph K_{left,right}.
    fn complete_bipartite(left: usize, right: usize) -> Vec<(usize, usize)> {
        (0..left).flat_map(|l| (0..right).map(move |r| (l, r))).collect()
    }

    #[test]
    fn empty_graph_produces_empty_matching() {
        let m = bipartite_matching(3, 3, &[]).unwrap();
        assert_eq!(m.size, 0);
        assert!(m.pairs.is_empty());
    }

    #[test]
    fn zero_sized_sides_are_accepted() {
        assert_eq!(bipartite_matching(0, 0, &[]).unwrap().size, 0);
        assert_eq!(bipartite_matching(0, 4, &[]).unwrap().size, 0);
        assert_eq!(bipartite_matching(4, 0, &[]).unwrap().size, 0);
    }

    #[test]
    fn perfect_matching_on_complete_bipartite() {
        // K_{3,3}: every left connects to every right — maximum matching is 3.
        let edges = complete_bipartite(3, 3);
        let m = bipartite_matching(3, 3, &edges).unwrap();
        assert_eq!(m.size, 3);
        assert_valid_matching(&m, &edges);
    }

    #[test]
    fn partial_matching_when_right_side_is_smaller() {
        // 4 left nodes, 2 right nodes — maximum matching is 2.
        let edges = [(0, 0), (1, 0), (2, 1), (3, 1)];
        let m = bipartite_matching(4, 2, &edges).unwrap();
        assert_eq!(m.size, 2);
        assert_valid_matching(&m, &edges);
    }

    #[test]
    fn single_edge_matches() {
        let m = bipartite_matching(1, 1, &[(0, 0)]).unwrap();
        assert_eq!(m.size, 1);
        assert_eq!(m.pairs, vec![(0, 0)]);
    }

    #[test]
    fn disjoint_components_all_matched() {
        // Two disjoint pairs.
        let edges = [(0, 0), (1, 1)];
        let m = bipartite_matching(2, 2, &edges).unwrap();
        assert_eq!(m.size, 2);
        assert_valid_matching(&m, &edges);
    }

    #[test]
    fn augmenting_path_required() {
        // A greedy pass matches 0→0 and then cannot place 1; an augmenting path
        // is needed to reach the only maximum matching 0→1, 1→0.
        let edges = [(0, 0), (0, 1), (1, 0)];
        let m = bipartite_matching(2, 2, &edges).unwrap();
        assert_eq!(m.size, 2);
        assert_eq!(m.pairs, vec![(0, 1), (1, 0)]);
    }

    #[test]
    fn long_alternating_chain_is_fully_matched() {
        // Left i is adjacent to right i+1 (listed first) and right i. If every
        // left node takes its first-listed edge, left n-1 is stranded, and only
        // an augmenting path running the length of the chain reaches the
        // unique perfect matching i→i.
        let n = 8;
        let mut edges = Vec::new();
        for i in 0..n {
            if i + 1 < n {
                edges.push((i, i + 1));
            }
            edges.push((i, i));
        }
        let m = bipartite_matching(n, n, &edges).unwrap();
        assert_eq!(m.pairs, (0..n).map(|i| (i, i)).collect::<Vec<_>>());
        assert_valid_matching(&m, &edges);
    }

    #[test]
    fn duplicate_edges_are_harmless() {
        let edges = [(0, 0), (0, 0), (1, 0)];
        let m = bipartite_matching(2, 1, &edges).unwrap();
        assert_eq!(m.size, 1);
        assert_valid_matching(&m, &edges);
    }

    #[test]
    fn out_of_range_left_node_returns_error() {
        let err = bipartite_matching(2, 2, &[(5, 0)]).unwrap_err();
        assert!(matches!(err, Error::InvalidInput(_)));
    }

    #[test]
    fn out_of_range_right_node_returns_error() {
        let err = bipartite_matching(2, 2, &[(0, 5)]).unwrap_err();
        assert!(matches!(err, Error::InvalidInput(_)));
    }

    #[test]
    fn no_edges_means_no_matches_regardless_of_counts() {
        let m = bipartite_matching(10, 10, &[]).unwrap();
        assert_eq!(m.size, 0);
    }

    #[test]
    fn matching_is_valid_no_shared_endpoints() {
        let edges = complete_bipartite(5, 5);
        let m = bipartite_matching(5, 5, &edges).unwrap();
        assert_eq!(m.size, 5);
        assert_valid_matching(&m, &edges);
    }
}