Skip to main content

zer_cluster/
graph.rs

1use std::collections::{HashMap, HashSet, VecDeque};
2
3use petgraph::{
4    graph::{NodeIndex, UnGraph},
5    visit::EdgeRef,
6};
7use zer_core::{record::RecordId, scoring::ScoredPair};
8
9/// Parameters controlling cluster shape after graph construction.
10#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
11pub struct ClusterConfig {
12    /// Clusters larger than this are subjected to star pruning.
13    pub max_cluster_size: usize,
14    /// Edges with weight below this threshold are removed before extracting
15    /// components (weak-edge removal / chain-breaking).
16    pub within_cluster_min: f32,
17}
18
19impl Default for ClusterConfig {
20    fn default() -> Self {
21        Self {
22            max_cluster_size:   50,
23            within_cluster_min: 0.85,
24        }
25    }
26}
27
28/// Undirected similarity graph over records.
29///
30/// Each node is a `RecordId`; each edge weight is the `match_probability` of
31/// the `AutoMatch` pair that connected those two records.
32pub struct ClusterGraph {
33    graph:    UnGraph<RecordId, f32>,
34    node_map: HashMap<RecordId, NodeIndex>,
35}
36
37impl ClusterGraph {
38    pub fn new() -> Self {
39        Self {
40            graph:    UnGraph::new_undirected(),
41            node_map: HashMap::new(),
42        }
43    }
44
45    /// Add `AutoMatch` pairs to the graph. Non-AutoMatch pairs are ignored.
46    pub fn add_pairs(&mut self, pairs: &[ScoredPair]) {
47        for pair in pairs {
48            let a = self.get_or_insert(pair.record_a);
49            let b = self.get_or_insert(pair.record_b);
50            // Avoid duplicate edges, keep the higher-weight one.
51            if let Some(edge) = self.graph.find_edge(a, b) {
52                let w = self.graph.edge_weight_mut(edge).unwrap();
53                if pair.match_probability > *w {
54                    *w = pair.match_probability;
55                }
56            } else {
57                self.graph.add_edge(a, b, pair.match_probability);
58            }
59        }
60    }
61
62    /// Compute clusters using the two-phase chain-breaking algorithm:
63    ///
64    /// 1. **Weak-edge removal**: remove all edges with weight <
65    ///    `config.within_cluster_min` then extract connected components.
66    /// 2. **Star pruning**: for any component whose size exceeds
67    ///    `config.max_cluster_size`, find the hub (highest-degree node in the
68    ///    original graph), remove all non-hub edges below the min threshold,
69    ///    and re-extract components from that sub-graph.
70    ///
71    /// Returns only non-trivial components (size ≥ 2).
72    pub fn compute_clusters(&self, config: &ClusterConfig) -> Vec<Vec<RecordId>> {
73        let pruned = weak_edge_removal(&self.graph, config.within_cluster_min);
74        let mut components = extract_components(&pruned);
75
76        // Star pruning for oversized components.
77        let mut result = Vec::new();
78        for comp in components.drain(..) {
79            if comp.len() <= config.max_cluster_size {
80                if comp.len() >= 2 {
81                    result.push(comp);
82                }
83            } else {
84                let sub = star_prune(&self.graph, &comp, config.within_cluster_min);
85                result.extend(sub.into_iter().filter(|c| c.len() >= 2));
86            }
87        }
88        result
89    }
90
91    fn get_or_insert(&mut self, id: RecordId) -> NodeIndex {
92        if let Some(&idx) = self.node_map.get(&id) {
93            return idx;
94        }
95        let idx = self.graph.add_node(id);
96        self.node_map.insert(id, idx);
97        idx
98    }
99}
100
101impl Default for ClusterGraph {
102    fn default() -> Self {
103        Self::new()
104    }
105}
106
107// ── Graph algorithms ──────────────────────────────────────────────────────────
108
109/// Clone the graph, remove all edges below `min_weight`, and return the result.
110///
111/// Edge indices are removed in descending order to avoid the petgraph `Graph`
112/// index-swap issue: removing edge `i` moves the last edge into slot `i`, so
113/// removing from highest to lowest keeps all lower indices stable.
114fn weak_edge_removal(graph: &UnGraph<RecordId, f32>, min_weight: f32) -> UnGraph<RecordId, f32> {
115    let mut g = graph.clone();
116    let mut weak: Vec<_> = g
117        .edge_indices()
118        .filter(|&e| *g.edge_weight(e).unwrap() < min_weight)
119        .collect();
120    weak.sort_by_key(|e| std::cmp::Reverse(e.index()));
121    for e in weak {
122        g.remove_edge(e);
123    }
124    g
125}
126
127/// BFS-based connected-component extraction.
128///
129/// `petgraph::algo::connected_components()` returns only a count, this
130/// function also yields the actual groups.
131pub(crate) fn extract_components(graph: &UnGraph<RecordId, f32>) -> Vec<Vec<RecordId>> {
132    let mut visited   = HashSet::new();
133    let mut components = Vec::new();
134
135    for start in graph.node_indices() {
136        if !visited.insert(start) {
137            continue;
138        }
139        let mut comp  = vec![graph[start]];
140        let mut queue = VecDeque::from([start]);
141
142        while let Some(node) = queue.pop_front() {
143            for nb in graph.neighbors(node) {
144                if visited.insert(nb) {
145                    comp.push(graph[nb]);
146                    queue.push_back(nb);
147                }
148            }
149        }
150        components.push(comp);
151    }
152    components
153}
154
155/// Star pruning for a single oversized component.
156///
157/// Finds the hub (highest-degree node in the original graph restricted to
158/// `comp`), builds a sub-graph containing only hub-edges with weight ≥
159/// `min_weight`, and returns the resulting sub-components.
160fn star_prune(
161    graph:      &UnGraph<RecordId, f32>,
162    comp:       &[RecordId],
163    min_weight: f32,
164) -> Vec<Vec<RecordId>> {
165    let comp_set: HashSet<RecordId> = comp.iter().copied().collect();
166
167    // Identify node indices in the original graph for this component.
168    let node_indices: Vec<NodeIndex> = graph
169        .node_indices()
170        .filter(|&n| comp_set.contains(&graph[n]))
171        .collect();
172
173    // Find hub: node with most edges to other comp members with weight >= min.
174    let hub = node_indices.iter().max_by_key(|&&n| {
175        graph
176            .edges(n)
177            .filter(|e| {
178                let other = if e.source() == n { e.target() } else { e.source() };
179                comp_set.contains(&graph[other]) && *e.weight() >= min_weight
180            })
181            .count()
182    });
183
184    let Some(&hub_idx) = hub else {
185        return vec![];
186    };
187
188    // Build sub-graph: hub + its qualifying neighbors.
189    let mut sub: UnGraph<RecordId, f32> = UnGraph::new_undirected();
190    let mut sub_map: HashMap<NodeIndex, NodeIndex> = HashMap::new();
191
192    let hub_sub = sub.add_node(graph[hub_idx]);
193    sub_map.insert(hub_idx, hub_sub);
194
195    for edge in graph.edges(hub_idx) {
196        let other = if edge.source() == hub_idx { edge.target() } else { edge.source() };
197        if !comp_set.contains(&graph[other]) || *edge.weight() < min_weight {
198            continue;
199        }
200        let other_sub = *sub_map.entry(other).or_insert_with(|| sub.add_node(graph[other]));
201        sub.add_edge(hub_sub, other_sub, *edge.weight());
202    }
203
204    extract_components(&sub)
205}
206
207// ── Unit tests ────────────────────────────────────────────────────────────────
208
#[cfg(test)]
mod tests {
    use super::*;
    use zer_core::{comparison::ComparisonVector, scoring::MatchBand};
    use zer_core::scoring::ScoredPair;

    /// Build an `AutoMatch` pair between records `a` and `b` with the given
    /// match probability; the other fields are filler.
    fn auto_match_pair(a: u64, b: u64, prob: f32) -> ScoredPair {
        ScoredPair {
            record_a:          a,
            record_b:          b,
            match_weight:      0.0,
            match_probability: prob,
            vector:            ComparisonVector { record_a: a, record_b: b, levels: vec![] },
            band:              MatchBand::AutoMatch,
        }
    }

    /// Configuration shared by most tests (same values as the defaults).
    fn config() -> ClusterConfig {
        ClusterConfig { max_cluster_size: 50, within_cluster_min: 0.85 }
    }

    #[test]
    fn basic_connected_components() {
        // A-B, B-C → one component of 3
        let mut g = ClusterGraph::new();
        g.add_pairs(&[auto_match_pair(1, 2, 0.95), auto_match_pair(2, 3, 0.95)]);
        let clusters = g.compute_clusters(&config());
        assert_eq!(clusters.len(), 1);
        assert_eq!(clusters[0].len(), 3);
    }

    #[test]
    fn single_pair_one_cluster() {
        let mut g = ClusterGraph::new();
        g.add_pairs(&[auto_match_pair(1, 2, 0.95)]);
        let clusters = g.compute_clusters(&config());
        assert_eq!(clusters.len(), 1);
        assert_eq!(clusters[0].len(), 2);
    }

    #[test]
    fn weak_bridge_splits_chain() {
        // A -[0.95]- B -[0.28]- C -[0.95]- D
        // with within_cluster_min = 0.85, the B-C edge is removed
        // → {A,B} and {C,D}
        let mut g = ClusterGraph::new();
        g.add_pairs(&[
            auto_match_pair(1, 2, 0.95), // A-B strong
            auto_match_pair(2, 3, 0.28), // B-C weak bridge
            auto_match_pair(3, 4, 0.95), // C-D strong
        ]);
        let mut clusters = g.compute_clusters(&config());
        clusters.sort_by_key(|c| *c.iter().min().unwrap());
        assert_eq!(clusters.len(), 2, "weak bridge must split chain into 2 clusters");
        assert_eq!(clusters[0].len(), 2);
        assert_eq!(clusters[1].len(), 2);

        let mut c0 = clusters[0].clone(); c0.sort();
        let mut c1 = clusters[1].clone(); c1.sort();
        assert_eq!(c0, vec![1, 2]);
        assert_eq!(c1, vec![3, 4]);
    }

    #[test]
    fn star_pruning_splits_oversized_cluster() {
        // Hub (id=0) connected to 60 satellites with prob 0.95.
        // The component has 61 members > max_cluster_size = 50, so star
        // pruning runs. Every hub edge is >= within_cluster_min, so pruning
        // keeps the hub and all 60 satellites together.
        let cfg = ClusterConfig { max_cluster_size: 50, within_cluster_min: 0.85 };
        let mut g = ClusterGraph::new();
        let pairs: Vec<_> = (1u64..=60).map(|i| auto_match_pair(0, i, 0.95)).collect();
        g.add_pairs(&pairs);

        let clusters = g.compute_clusters(&cfg);
        // Pin the exact outcome rather than merely "something came back".
        assert_eq!(clusters.len(), 1, "the star must come back as a single cluster");
        assert_eq!(clusters[0].len(), 61, "hub plus all 60 satellites");
        assert!(clusters[0].contains(&0), "the hub must be part of the cluster");
    }

    #[test]
    fn two_disconnected_pairs_two_clusters() {
        let mut g = ClusterGraph::new();
        g.add_pairs(&[auto_match_pair(1, 2, 0.95), auto_match_pair(3, 4, 0.95)]);
        let clusters = g.compute_clusters(&config());
        assert_eq!(clusters.len(), 2);
    }
}