1use zer_core::{
2 entity::{Entity, EntityId, EntityMember, ResolutionMethod},
3 record::RecordId,
4 scoring::{ModelParams, ScoredPair},
5 traits::Clusterer,
6};
7
8use crate::{
9 graph::{ClusterConfig, ClusterGraph},
10 threshold::partition_by_band,
11};
12
13pub struct ConnectedComponentsClusterer {
22 pub config: ClusterConfig,
23}
24
25impl Default for ConnectedComponentsClusterer {
26 fn default() -> Self {
27 Self { config: ClusterConfig::default() }
28 }
29}
30
31impl Clusterer for ConnectedComponentsClusterer {
32 fn cluster(&self, pairs: &[ScoredPair], params: &ModelParams) -> Vec<Entity> {
33 let banded = partition_by_band(pairs.to_vec(), params);
34
35 let mut graph = ClusterGraph::new();
36 graph.add_pairs(&banded.auto_match);
37
38 let components = graph.compute_clusters(&self.config);
39
40 components
41 .into_iter()
42 .enumerate()
43 .map(|(idx, members)| {
44 let entity_members = members
45 .iter()
46 .map(|&rid| EntityMember {
47 record_id: rid,
48 score: best_score_in_cluster(rid, &banded.auto_match),
49 method: ResolutionMethod::AutoMatch,
50 source: None,
51 })
52 .collect();
53
54 Entity {
55 id: idx as EntityId + 1,
58 members: entity_members,
59 }
60 })
61 .collect()
62 }
63}
64
65fn best_score_in_cluster(record_id: RecordId, pairs: &[ScoredPair]) -> f32 {
68 pairs
69 .iter()
70 .filter(|p| p.record_a == record_id || p.record_b == record_id)
71 .map(|p| p.match_probability)
72 .fold(0.0_f32, f32::max)
73}
74
#[cfg(test)]
mod tests {
    use super::*;
    use zer_core::{comparison::ComparisonVector, scoring::MatchBand};

    /// Model parameters with empty `m`/`u` vectors. The thresholds (0.2 / 0.8) are chosen
    /// so that the probabilities used below sit clearly inside their intended bands.
    fn params() -> ModelParams {
        ModelParams {
            m: vec![],
            u: vec![],
            log_prior_odds: 0.0,
            upper_threshold: 0.8,
            lower_threshold: 0.2,
        }
    }

    /// Builds a scored pair between records `a` and `b` with the given probability and
    /// band. `match_weight` and the comparison levels are placeholders.
    fn pair(a: u64, b: u64, prob: f32, band: MatchBand) -> ScoredPair {
        ScoredPair {
            record_a: a,
            record_b: b,
            match_weight: 0.0,
            match_probability: prob,
            vector: ComparisonVector { record_a: a, record_b: b, levels: vec![] },
            band,
        }
    }

    #[test]
    fn empty_pairs_returns_empty() {
        let clusterer = ConnectedComponentsClusterer::default();
        let entities = clusterer.cluster(&[], &params());
        assert!(entities.is_empty());
    }

    #[test]
    fn two_matched_pairs_form_one_entity() {
        let clusterer = ConnectedComponentsClusterer::default();
        let pairs = vec![
            pair(1, 2, 0.95, MatchBand::AutoMatch),
            pair(2, 3, 0.95, MatchBand::AutoMatch),
        ];
        let entities = clusterer.cluster(&pairs, &params());
        assert_eq!(entities.len(), 1);
        assert_eq!(entities[0].members.len(), 3);
    }

    #[test]
    fn auto_rejected_pairs_ignored() {
        let clusterer = ConnectedComponentsClusterer::default();
        let pairs = vec![
            pair(1, 2, 0.95, MatchBand::AutoMatch),
            pair(3, 4, 0.05, MatchBand::AutoReject),
        ];
        let entities = clusterer.cluster(&pairs, &params());
        assert_eq!(entities.len(), 1);
        let rids: Vec<_> = entities[0].members.iter().map(|m| m.record_id).collect();
        assert!(rids.contains(&1));
        assert!(rids.contains(&2));
        assert!(!rids.contains(&3));
        assert!(!rids.contains(&4));
    }

    #[test]
    fn disjoint_pairs_form_separate_entities_with_sequential_ids() {
        let clusterer = ConnectedComponentsClusterer::default();
        let pairs = vec![
            pair(1, 2, 0.95, MatchBand::AutoMatch),
            pair(3, 4, 0.95, MatchBand::AutoMatch),
        ];
        let entities = clusterer.cluster(&pairs, &params());
        assert_eq!(entities.len(), 2);
        assert!(entities.iter().all(|e| e.members.len() == 2));

        // Ids are 1-based positions, independent of component ordering.
        let mut ids: Vec<_> = entities.iter().map(|e| e.id).collect();
        ids.sort();
        assert_eq!(ids, vec![1, 2]);
    }

    #[test]
    fn members_are_marked_auto_match() {
        let clusterer = ConnectedComponentsClusterer::default();
        let pairs = vec![pair(1, 2, 0.95, MatchBand::AutoMatch)];
        let entities = clusterer.cluster(&pairs, &params());
        assert_eq!(entities.len(), 1);
        assert!(entities[0]
            .members
            .iter()
            .all(|m| matches!(m.method, ResolutionMethod::AutoMatch)));
    }

    #[test]
    fn members_get_correct_scores() {
        let clusterer = ConnectedComponentsClusterer::default();
        let pairs = vec![
            pair(1, 2, 0.92, MatchBand::AutoMatch),
            pair(1, 3, 0.88, MatchBand::AutoMatch),
        ];
        let entities = clusterer.cluster(&pairs, &params());
        assert_eq!(entities.len(), 1);

        let score_of = |rid| {
            entities[0]
                .members
                .iter()
                .find(|m| m.record_id == rid)
                .expect("record should be a member of the entity")
                .score
        };
        assert!((score_of(1) - 0.92).abs() < 1e-5, "record 1 best score is 0.92");
        assert!((score_of(2) - 0.92).abs() < 1e-5, "record 2 only appears in the 0.92 pair");
        assert!((score_of(3) - 0.88).abs() < 1e-5, "record 3 only appears in the 0.88 pair");
    }
}