Skip to main content

anno_core/coalesce/
resolver.rs

1//! # Batch Entity Resolution with Union-Find
2//!
3//! This module provides the core batch entity resolution algorithm using the
4//! **disjoint-set (union-find)** data structure with path compression.
5//!
6//! ## Algorithm Overview
7//!
8//! 1. **Collect** all tracks from all documents in the corpus
9//! 2. **Compare** pairs: compute similarity (embedding cosine or string Jaccard)
10//! 3. **Cluster** using union-find: if \( \text{sim}(t_i, t_j) \geq \theta \), merge
11//! 4. **Create** identities: one per cluster, linked to constituent tracks
12//!
13//! ## Complexity Analysis
14//!
15//! Let \( n \) = number of tracks, \( m \) = number of merges performed.
16//!
17//! - **Pairwise comparison:** \( O(n^2) \) — the bottleneck for large corpora
18//! - **Union-find operations:** \( O(m \cdot \alpha(n)) \) where \( \alpha \) is
19//!   the inverse Ackermann function
20//!
21//! For all practical \( n \), \( \alpha(n) \leq 4 \), so union-find is effectively
22//! \( O(m) \). The overall complexity is dominated by pairwise comparison.
23//!
24//! ## The Inverse Ackermann Function
25//!
26//! The Ackermann function \( A(m, n) \) grows faster than any primitive recursive
27//! function. Its inverse \( \alpha(n) \) grows so slowly that:
28//!
//! - \( \alpha(n) \leq 4 \) for every \( n \leq 10^{80} \) (roughly the number of atoms in the observable universe), and for vastly larger \( n \) as well
30//!
31//! Tarjan (1975) proved this bound is tight for union-find with path compression
32//! and union-by-rank.
33//!
34//! ## Similarity Metrics
35//!
36//! Two metrics are provided:
37//!
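//! - **Embedding similarity** (cosine): \( \cos(\mathbf{a}, \mathbf{b}) = \frac{\mathbf{a} \cdot \mathbf{b}}{\lVert\mathbf{a}\rVert \, \lVert\mathbf{b}\rVert} \)
//! - **String similarity** (Jaccard on words): \( J(A, B) = \frac{|A \cap B|}{|A \cup B|} \)
//!
//! If both tracks have embeddings, cosine is used; otherwise, Jaccard on word sets.
//! The cosine is rescaled from \( [-1, 1] \) to \( [0, 1] \) as \( (\cos + 1) / 2 \) before it is
//! compared with \( \theta \), so orthogonal embeddings score 0.5.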
42//!
43//! ## Example
44//!
45//! ```
46//! use anno_core::coalesce::Resolver;
47//! use anno_core::Corpus;
48//!
49//! let resolver = Resolver::new()
50//!     .with_threshold(0.7)
51//!     .require_type_match(true);
52//!
53//! let mut corpus = Corpus::new();
54//! // ... add documents with tracks ...
55//!
56//! let identities = resolver.resolve_inter_doc_coref(&mut corpus, None, None);
57//! ```
58
59use crate::core::{Corpus, Identity, IdentityId, IdentitySource, TrackId, TrackRef};
60use std::collections::HashMap;
61
/// Coalescer for inter-document entity resolution.
///
/// Stores the default clustering settings used by `resolve_inter_doc_coref`; both
/// settings can be overridden on a per-call basis.
#[derive(Clone, Debug)]
pub struct Resolver {
    /// Minimum pairwise similarity at which two tracks are merged.
    similarity_threshold: f32,
    /// When `true`, only tracks whose entity types are equal may be merged.
    require_type_match: bool,
}
68
69impl Resolver {
70    /// Create a new resolver with default settings.
71    pub fn new() -> Self {
72        Self {
73            similarity_threshold: 0.7,
74            require_type_match: true,
75        }
76    }
77
78    /// Create a new resolver with custom settings.
79    pub fn with_threshold(mut self, threshold: f32) -> Self {
80        self.similarity_threshold = threshold;
81        self
82    }
83
84    /// Set whether to require entity type match for clustering.
85    pub fn require_type_match(mut self, require: bool) -> Self {
86        self.require_type_match = require;
87        self
88    }
89
90    /// Coalesce inter-document entities across all documents in a corpus.
91    ///
92    /// This method clusters tracks from different documents that refer to the same
93    /// real-world entity, creating `Identity` instances without KB links.
94    ///
95    /// # Algorithm
96    ///
97    /// 1. Extract all tracks from all documents
98    /// 2. Compute track embeddings (if available) or use string similarity
99    /// 3. Cluster tracks using similarity threshold
100    /// 4. Create Identity for each cluster
101    /// 5. Link tracks to identities
102    ///
103    /// # Parameters
104    ///
105    /// * `corpus` - The corpus containing documents to resolve
106    /// * `similarity_threshold` - Minimum similarity (0.0-1.0) to cluster tracks
107    /// * `require_type_match` - Only cluster tracks with same entity type
108    ///
109    /// # Returns
110    ///
111    /// Vector of created identities, each linked to tracks from multiple documents.
112    pub fn resolve_inter_doc_coref(
113        &self,
114        corpus: &mut Corpus,
115        similarity_threshold: Option<f32>,
116        require_type_match: Option<bool>,
117    ) -> Vec<IdentityId> {
118        let threshold = similarity_threshold.unwrap_or(self.similarity_threshold);
119        let type_match = require_type_match.unwrap_or(self.require_type_match);
120
121        // 1. Collect all track data (clone what we need to avoid borrow conflicts)
122        #[derive(Debug, Clone)]
123        struct TrackData {
124            track_ref: TrackRef,
125            canonical_surface: String,
126            entity_type: Option<crate::TypeLabel>,
127            cluster_confidence: f32,
128            embedding: Option<Vec<f32>>,
129        }
130
131        let mut track_data: Vec<TrackData> = Vec::new();
132        // Collect document IDs first to avoid borrow checker issues
133        let doc_ids: Vec<String> = corpus.documents().map(|d| d.id.clone()).collect();
134        for doc_id in doc_ids {
135            if let Some(doc) = corpus.get_document(&doc_id) {
136                for track in doc.tracks() {
137                    if let Some(track_ref) = doc.track_ref(track.id) {
138                        track_data.push(TrackData {
139                            track_ref,
140                            canonical_surface: track.canonical_surface.clone(),
141                            entity_type: track.entity_type.clone(),
142                            cluster_confidence: track.cluster_confidence,
143                            embedding: track.embedding.clone(),
144                        });
145                    }
146                }
147            }
148        }
149
150        if track_data.is_empty() {
151            return vec![];
152        }
153
154        // 2. Cluster tracks using string similarity or embeddings
155        // Uses embeddings if available (from track.embedding), otherwise falls back to string similarity
156        let mut union_find: Vec<usize> = (0..track_data.len()).collect();
157
158        fn find(parent: &mut [usize], i: usize) -> usize {
159            if parent[i] != i {
160                parent[i] = find(parent, parent[i]);
161            }
162            parent[i]
163        }
164
165        fn union(parent: &mut [usize], i: usize, j: usize) {
166            let pi = find(parent, i);
167            let pj = find(parent, j);
168            if pi != pj {
169                parent[pi] = pj;
170            }
171        }
172
173        // Compare all pairs
174        for i in 0..track_data.len() {
175            for j in (i + 1)..track_data.len() {
176                let track_a = &track_data[i];
177                let track_b = &track_data[j];
178
179                // Type check
180                if type_match && track_a.entity_type != track_b.entity_type {
181                    continue;
182                }
183
184                // Compute similarity: prefer embeddings if BOTH available, fallback to string similarity
185                // Edge case: If only one track has an embedding, we can't compare embeddings directly,
186                // so we fall back to string similarity for consistency.
187                let similarity =
188                    if let (Some(emb_a), Some(emb_b)) = (&track_a.embedding, &track_b.embedding) {
189                        // Both have embeddings: use cosine similarity
190                        embedding_similarity(emb_a, emb_b)
191                    } else {
192                        // One or both missing embeddings: fallback to string similarity
193                        // This handles: (Some, None), (None, Some), (None, None)
194                        string_similarity(&track_a.canonical_surface, &track_b.canonical_surface)
195                    };
196
197                if similarity >= threshold {
198                    union(&mut union_find, i, j);
199                }
200            }
201        }
202
203        // 3. Build clusters
204        let mut cluster_map: HashMap<usize, Vec<usize>> = HashMap::new();
205        for i in 0..track_data.len() {
206            let root = find(&mut union_find, i);
207            cluster_map.entry(root).or_default().push(i);
208        }
209
210        // 4. Create identities for each cluster
211        // Note: Singleton clusters (clusters with only one track) still create identities.
212        // This allows tracking entities that appear only once across documents.
213        let mut created_ids = Vec::new();
214        for (_, member_indices) in cluster_map.iter() {
215            if member_indices.is_empty() {
216                continue;
217            }
218
219            // Safe: we just checked is_empty() above, so member_indices[0] is valid
220            let first_idx = member_indices[0];
221            let first_track = &track_data[first_idx];
222
223            // Collect all track refs in this cluster
224            let track_refs_in_cluster: Vec<TrackRef> = member_indices
225                .iter()
226                .map(|&idx| track_data[idx].track_ref.clone())
227                .collect();
228
229            // Create identity
230            let identity = Identity {
231                id: corpus.next_identity_id(), // Will be set by add_identity
232                canonical_name: first_track.canonical_surface.clone(),
233                entity_type: first_track.entity_type.clone(),
234                kb_id: None,
235                kb_name: None,
236                description: None,
237                embedding: first_track.embedding.clone(),
238                aliases: Vec::new(),
239                confidence: first_track.cluster_confidence,
240                source: Some(IdentitySource::CrossDocCoref {
241                    track_refs: track_refs_in_cluster,
242                }),
243            };
244
245            let identity_id = corpus.add_identity(identity);
246            created_ids.push(identity_id);
247
248            // 5. Link tracks to identity
249            // Collect doc_id and track_id pairs first to avoid borrow conflicts
250            let links: Vec<(String, TrackId)> = member_indices
251                .iter()
252                .map(|&idx| {
253                    let track_ref = &track_data[idx].track_ref;
254                    (track_ref.doc_id.clone(), track_ref.track_id)
255                })
256                .collect();
257
258            for (doc_id, track_id) in links {
259                if let Some(doc) = corpus.get_document_mut(&doc_id) {
260                    doc.link_track_to_identity(track_id, identity_id);
261                } else {
262                    // Document was removed or doesn't exist - this is a data consistency issue
263                    // Log warning but continue with other tracks
264                    log::warn!(
265                        "Document '{}' not found when linking track {} to identity {}",
266                        doc_id,
267                        track_id,
268                        identity_id
269                    );
270                }
271            }
272        }
273
274        created_ids
275    }
276}
277
278impl Default for Resolver {
279    fn default() -> Self {
280        Self::new()
281    }
282}
283
/// Compute string similarity using Jaccard similarity on word sets.
///
/// Returns a value in [0.0, 1.0], where 1.0 means the word sets are identical.
///
/// Each string is split on whitespace, and every word is then normalized:
/// - it is lowercased;
/// - a trailing possessive (`'s` or the typographic `’s`) is stripped;
/// - any trailing apostrophes (`'` or `’`) are stripped.
///
/// Words that normalize to the empty string (a bare `'s`, for example) carry no identity
/// information. They are dropped so they cannot inflate the overlap between unrelated names.
///
/// Two strings without any words are considered identical (1.0). A string without words
/// compared against one with words scores 0.0.
///
/// # Example
///
/// ```rust
/// use anno_core::coalesce::resolver::string_similarity;
///
/// let sim = string_similarity("Lynn Conway", "Lynn Conway");
/// assert_eq!(sim, 1.0);
///
/// let sim = string_similarity("Lynn Conway", "Conway");
/// assert!(sim > 0.0); // "Conway" shares one word with "Lynn Conway"
/// ```
pub fn string_similarity(a: &str, b: &str) -> f32 {
    /// Lowercase `w` and strip trailing possessives and apostrophes (ASCII and typographic).
    fn normalize_word(w: &str) -> String {
        let lower = w.to_lowercase();
        lower
            .trim_end_matches("'s")
            .trim_end_matches("\u{2019}s")
            .trim_end_matches(|c: char| c == '\'' || c == '\u{2019}')
            .to_string()
    }

    /// Collect the set of non-empty normalized words in `s`.
    fn word_set(s: &str) -> std::collections::HashSet<String> {
        s.split_whitespace()
            .map(normalize_word)
            .filter(|w| !w.is_empty())
            .collect()
    }

    let words_a = word_set(a);
    let words_b = word_set(b);

    if words_a.is_empty() || words_b.is_empty() {
        // Two wordless strings are identical; a wordless string shares nothing with a
        // string that has words.
        return if words_a.is_empty() && words_b.is_empty() {
            1.0
        } else {
            0.0
        };
    }

    let intersection = words_a.intersection(&words_b).count();
    // Inclusion–exclusion; both sets are non-empty here, so `union >= 1`.
    let union = words_a.len() + words_b.len() - intersection;
    intersection as f32 / union as f32
}
332
/// Compute embedding similarity using cosine similarity.
///
/// Returns a value in [0.0, 1.0], where 1.0 means the vectors point in the same direction.
///
/// Formula: `cosine(a, b) = (a · b) / (||a|| × ||b||)`, rescaled from [-1, 1] to [0, 1] as
/// `(cosine + 1) / 2`. Orthogonal vectors therefore score 0.5 and opposite vectors 0.0.
/// The measure depends on the angle between the vectors, not their magnitude, which makes
/// it suitable for embeddings.
///
/// Degenerate inputs return 0.0:
/// - vectors of different lengths;
/// - empty vectors;
/// - a zero vector;
/// - components that make the cosine non-finite (NaN or infinity).
///
/// Sums are accumulated in `f64`, so very large or very small components do not overflow
/// or underflow when squared. The cosine is clamped to [-1, 1], so rounding error can never
/// push the result outside [0.0, 1.0].
///
/// # Example
///
/// ```rust
/// use anno_core::coalesce::resolver::embedding_similarity;
///
/// let emb1 = vec![1.0, 0.0, 0.0];
/// let emb2 = vec![1.0, 0.0, 0.0];
/// let sim = embedding_similarity(&emb1, &emb2);
/// assert_eq!(sim, 1.0);
/// ```
pub fn embedding_similarity(emb_a: &[f32], emb_b: &[f32]) -> f32 {
    if emb_a.len() != emb_b.len() || emb_a.is_empty() {
        return 0.0;
    }

    // One pass computing the dot product and both squared norms.
    let (dot_product, sq_norm_a, sq_norm_b) = emb_a.iter().zip(emb_b).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(dot, sq_a, sq_b), (&a, &b)| {
            let (a, b) = (f64::from(a), f64::from(b));
            (dot + a * b, sq_a + a * a, sq_b + b * b)
        },
    );

    if sq_norm_a == 0.0 || sq_norm_b == 0.0 {
        return 0.0;
    }

    let cosine = dot_product / (sq_norm_a.sqrt() * sq_norm_b.sqrt());
    if !cosine.is_finite() {
        return 0.0;
    }

    // Map [-1, 1] onto [0, 1]; the clamp absorbs rounding error at the extremes.
    ((cosine.clamp(-1.0, 1.0) + 1.0) / 2.0) as f32
}
367
#[cfg(test)]
mod tests {
    use super::*;

    /// True when `actual` is within 0.001 of `expected`.
    fn approx(actual: f32, expected: f32) -> bool {
        (actual - expected).abs() < 0.001
    }

    #[test]
    fn test_string_similarity_identical() {
        let sim = string_similarity("hello world", "hello world");
        assert_eq!(sim, 1.0);
    }

    #[test]
    fn test_string_similarity_partial() {
        // {"hello", "world"} vs {"hello"}: |∩| = 1, |∪| = 2, so Jaccard = 0.5.
        let sim = string_similarity("hello world", "hello");
        assert!(sim > 0.0 && sim < 1.0);
        assert!(approx(sim, 0.5));
    }

    #[test]
    fn test_string_similarity_empty() {
        let cases: [(&str, &str, f32); 3] = [("", "", 1.0), ("hello", "", 0.0), ("", "hello", 0.0)];
        for (a, b, expected) in cases {
            assert_eq!(string_similarity(a, b), expected);
        }
    }

    #[test]
    fn test_string_similarity_symmetric() {
        let forward = string_similarity("hello world", "world peace");
        let backward = string_similarity("world peace", "hello world");
        assert_eq!(forward, backward);
    }

    #[test]
    fn test_embedding_similarity_identical() {
        let unit_x: [f32; 3] = [1.0, 0.0, 0.0];
        assert_eq!(embedding_similarity(&unit_x, &unit_x), 1.0);
    }

    #[test]
    fn test_embedding_similarity_orthogonal() {
        // Orthogonal vectors have cosine 0, which rescales to 0.5.
        let sim = embedding_similarity(&[1.0, 0.0], &[0.0, 1.0]);
        assert!(approx(sim, 0.5));
    }

    #[test]
    fn test_embedding_similarity_opposite() {
        // Opposite vectors have cosine -1, which rescales to 0.0.
        let sim = embedding_similarity(&[1.0, 0.0], &[-1.0, 0.0]);
        assert!(approx(sim, 0.0));
    }

    #[test]
    fn test_embedding_similarity_mismatched_length() {
        assert_eq!(embedding_similarity(&[1.0, 0.0], &[1.0, 0.0, 0.0]), 0.0);
    }

    #[test]
    fn test_embedding_similarity_empty() {
        let lhs: [f32; 0] = [];
        let rhs: [f32; 0] = [];
        assert_eq!(embedding_similarity(&lhs, &rhs), 0.0);
    }

    #[test]
    fn test_embedding_similarity_zero_norm() {
        assert_eq!(embedding_similarity(&[0.0, 0.0], &[1.0, 0.0]), 0.0);
    }

    #[test]
    fn test_resolver_builder() {
        // The default is `true`, so setting `false` proves the setter takes effect.
        let resolver = Resolver::new()
            .with_threshold(0.8)
            .require_type_match(false);
        assert_eq!(resolver.similarity_threshold, 0.8);
        assert!(!resolver.require_type_match);
    }

    #[test]
    fn test_resolver_default() {
        let Resolver {
            similarity_threshold,
            require_type_match,
        } = Resolver::default();
        assert_eq!(similarity_threshold, 0.7);
        assert!(require_type_match);
    }
}
462
#[cfg(test)]
mod proptests {
    use super::*;
    use proptest::prelude::*;

    /// Produce `dim` values from the linear congruential recurrence
    /// `state = state * 1103515245 + 12345` (wrapping).
    ///
    /// `state` is advanced once per value, and each new state is mapped through `to_value`.
    /// Passing the same `state` to consecutive calls continues one shared sequence.
    fn lcg_vector(state: &mut u64, dim: usize, to_value: impl Fn(u64) -> f32) -> Vec<f32> {
        let mut values = Vec::with_capacity(dim);
        for _ in 0..dim {
            *state = state.wrapping_mul(1103515245).wrapping_add(12345);
            values.push(to_value(*state));
        }
        values
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        /// String similarity always lies in [0, 1].
        #[test]
        fn string_sim_bounded(a in ".*", b in ".*") {
            let score = string_similarity(&a, &b);
            prop_assert!((0.0..=1.0).contains(&score));
        }

        /// Swapping the arguments never changes the string similarity.
        #[test]
        fn string_sim_symmetric(a in "[a-z ]{0,30}", b in "[a-z ]{0,30}") {
            let forward = string_similarity(&a, &b);
            let backward = string_similarity(&b, &a);
            prop_assert!((forward - backward).abs() < 0.0001);
        }

        /// A non-empty word compared with itself scores 1.0.
        #[test]
        fn string_sim_reflexive(s in "[a-z]{1,20}") {
            let score = string_similarity(&s, &s);
            prop_assert!((score - 1.0).abs() < 0.0001);
        }

        /// Rescaled cosine similarity stays in [0, 1] for components in [-1.0, 0.999].
        #[test]
        fn embedding_sim_bounded(dim in 1usize..50, seed in any::<u64>()) {
            let signed = |x: u64| (x % 2000) as f32 / 1000.0 - 1.0;
            let mut state = seed;
            let emb1 = lcg_vector(&mut state, dim, signed);
            let emb2 = lcg_vector(&mut state, dim, signed);

            let sim = embedding_similarity(&emb1, &emb2);
            prop_assert!((0.0..=1.0).contains(&sim),
                "Embedding similarity out of bounds: {}", sim);
        }

        /// Swapping the arguments never changes the embedding similarity.
        #[test]
        fn embedding_sim_symmetric(dim in 1usize..20, seed in any::<u64>()) {
            let unit = |x: u64| (x % 100) as f32 / 100.0;
            let mut state = seed;
            let emb1 = lcg_vector(&mut state, dim, unit);
            let emb2 = lcg_vector(&mut state, dim, unit);

            let sim_ab = embedding_similarity(&emb1, &emb2);
            let sim_ba = embedding_similarity(&emb2, &emb1);
            prop_assert!((sim_ab - sim_ba).abs() < 0.0001,
                "Embedding similarity not symmetric: {} vs {}", sim_ab, sim_ba);
        }
    }
}