1use crate::core::{Corpus, Identity, IdentityId, IdentitySource, TrackId, TrackRef};
60use std::collections::HashMap;
61
/// Configuration for cross-document coreference resolution.
///
/// Build one with [`Resolver::new`] (or `Default`) and adjust it with the
/// builder methods [`Resolver::with_threshold`] and
/// [`Resolver::require_type_match`].
#[derive(Debug, Clone, PartialEq)]
pub struct Resolver {
    /// Minimum pairwise similarity (inclusive) at which two tracks are merged.
    similarity_threshold: f32,
    /// When `true`, only tracks with equal `entity_type` are compared.
    require_type_match: bool,
}
68
69impl Resolver {
70 pub fn new() -> Self {
72 Self {
73 similarity_threshold: 0.7,
74 require_type_match: true,
75 }
76 }
77
78 pub fn with_threshold(mut self, threshold: f32) -> Self {
80 self.similarity_threshold = threshold;
81 self
82 }
83
84 pub fn require_type_match(mut self, require: bool) -> Self {
86 self.require_type_match = require;
87 self
88 }
89
90 pub fn resolve_inter_doc_coref(
113 &self,
114 corpus: &mut Corpus,
115 similarity_threshold: Option<f32>,
116 require_type_match: Option<bool>,
117 ) -> Vec<IdentityId> {
118 let threshold = similarity_threshold.unwrap_or(self.similarity_threshold);
119 let type_match = require_type_match.unwrap_or(self.require_type_match);
120
121 #[derive(Debug, Clone)]
123 struct TrackData {
124 track_ref: TrackRef,
125 canonical_surface: String,
126 entity_type: Option<crate::TypeLabel>,
127 cluster_confidence: f32,
128 embedding: Option<Vec<f32>>,
129 }
130
131 let mut track_data: Vec<TrackData> = Vec::new();
132 let doc_ids: Vec<String> = corpus.documents().map(|d| d.id.clone()).collect();
134 for doc_id in doc_ids {
135 if let Some(doc) = corpus.get_document(&doc_id) {
136 for track in doc.tracks() {
137 if let Some(track_ref) = doc.track_ref(track.id) {
138 track_data.push(TrackData {
139 track_ref,
140 canonical_surface: track.canonical_surface.clone(),
141 entity_type: track.entity_type.clone(),
142 cluster_confidence: track.cluster_confidence,
143 embedding: track.embedding.clone(),
144 });
145 }
146 }
147 }
148 }
149
150 if track_data.is_empty() {
151 return vec![];
152 }
153
154 let mut union_find: Vec<usize> = (0..track_data.len()).collect();
157
158 fn find(parent: &mut [usize], i: usize) -> usize {
159 if parent[i] != i {
160 parent[i] = find(parent, parent[i]);
161 }
162 parent[i]
163 }
164
165 fn union(parent: &mut [usize], i: usize, j: usize) {
166 let pi = find(parent, i);
167 let pj = find(parent, j);
168 if pi != pj {
169 parent[pi] = pj;
170 }
171 }
172
173 for i in 0..track_data.len() {
175 for j in (i + 1)..track_data.len() {
176 let track_a = &track_data[i];
177 let track_b = &track_data[j];
178
179 if type_match && track_a.entity_type != track_b.entity_type {
181 continue;
182 }
183
184 let similarity =
188 if let (Some(emb_a), Some(emb_b)) = (&track_a.embedding, &track_b.embedding) {
189 embedding_similarity(emb_a, emb_b)
191 } else {
192 string_similarity(&track_a.canonical_surface, &track_b.canonical_surface)
195 };
196
197 if similarity >= threshold {
198 union(&mut union_find, i, j);
199 }
200 }
201 }
202
203 let mut cluster_map: HashMap<usize, Vec<usize>> = HashMap::new();
205 for i in 0..track_data.len() {
206 let root = find(&mut union_find, i);
207 cluster_map.entry(root).or_default().push(i);
208 }
209
210 let mut created_ids = Vec::new();
214 for (_, member_indices) in cluster_map.iter() {
215 if member_indices.is_empty() {
216 continue;
217 }
218
219 let first_idx = member_indices[0];
221 let first_track = &track_data[first_idx];
222
223 let track_refs_in_cluster: Vec<TrackRef> = member_indices
225 .iter()
226 .map(|&idx| track_data[idx].track_ref.clone())
227 .collect();
228
229 let identity = Identity {
231 id: corpus.next_identity_id(), canonical_name: first_track.canonical_surface.clone(),
233 entity_type: first_track.entity_type.clone(),
234 kb_id: None,
235 kb_name: None,
236 description: None,
237 embedding: first_track.embedding.clone(),
238 aliases: Vec::new(),
239 confidence: first_track.cluster_confidence,
240 source: Some(IdentitySource::CrossDocCoref {
241 track_refs: track_refs_in_cluster,
242 }),
243 };
244
245 let identity_id = corpus.add_identity(identity);
246 created_ids.push(identity_id);
247
248 let links: Vec<(String, TrackId)> = member_indices
251 .iter()
252 .map(|&idx| {
253 let track_ref = &track_data[idx].track_ref;
254 (track_ref.doc_id.clone(), track_ref.track_id)
255 })
256 .collect();
257
258 for (doc_id, track_id) in links {
259 if let Some(doc) = corpus.get_document_mut(&doc_id) {
260 doc.link_track_to_identity(track_id, identity_id);
261 } else {
262 log::warn!(
265 "Document '{}' not found when linking track {} to identity {}",
266 doc_id,
267 track_id,
268 identity_id
269 );
270 }
271 }
272 }
273
274 created_ids
275 }
276}
277
278impl Default for Resolver {
279 fn default() -> Self {
280 Self::new()
281 }
282}
283
/// Returns the Jaccard similarity between the word sets of `a` and `b`.
///
/// Each input is split on whitespace. Every word is lowercased, then stripped of
/// a trailing possessive suffix (`'s`, or the typographic `’s`, U+2019) and of
/// any trailing apostrophes (`'` or `’`). Words that become empty after this
/// normalization (e.g. a detached `'s`) are discarded, so they never count as a
/// shared word. The result is `|A ∩ B| / |A ∪ B|` and lies in `[0.0, 1.0]`.
///
/// If neither input contains a word the result is `1.0`; if exactly one does,
/// it is `0.0`.
pub fn string_similarity(a: &str, b: &str) -> f32 {
    use std::collections::HashSet;

    /// Lowercases `w` and strips possessive suffixes and trailing apostrophes.
    fn normalize_word(w: &str) -> String {
        let lower = w.to_lowercase();
        lower
            .trim_end_matches("'s")
            // Typographic apostrophe (U+2019), common in copy-edited text.
            .trim_end_matches("\u{2019}s")
            .trim_end_matches(|c: char| c == '\'' || c == '\u{2019}')
            .to_string()
    }

    /// Collects the non-empty normalized words of `s`.
    fn word_set(s: &str) -> HashSet<String> {
        s.split_whitespace()
            .map(normalize_word)
            .filter(|w| !w.is_empty())
            .collect()
    }

    let words_a = word_set(a);
    let words_b = word_set(b);

    match (words_a.is_empty(), words_b.is_empty()) {
        (true, true) => return 1.0,
        (true, false) | (false, true) => return 0.0,
        (false, false) => {}
    }

    let intersection = words_a.intersection(&words_b).count();
    // Both sets are non-empty here, so the union has at least one element.
    let union = words_a.len() + words_b.len() - intersection;
    intersection as f32 / union as f32
}
332
/// Returns the cosine similarity of `emb_a` and `emb_b`, rescaled from
/// `[-1, 1]` to `[0, 1]` as `(cos + 1) / 2`.
///
/// Identical directions give `1.0`, orthogonal vectors `0.5`, and opposite
/// directions `0.0`.
///
/// Returns `0.0` when the slices differ in length, are empty, either has zero
/// norm, or the cosine is not finite (e.g. an input contains NaN or infinity).
///
/// Sums are accumulated in `f64`. This keeps finite `f32` components of any
/// magnitude from overflowing or underflowing the squared norms, and it reduces
/// rounding error on long vectors. The cosine is clamped to `[-1, 1]` so the
/// result never leaves `[0, 1]`.
pub fn embedding_similarity(emb_a: &[f32], emb_b: &[f32]) -> f32 {
    if emb_a.len() != emb_b.len() || emb_a.is_empty() {
        return 0.0;
    }

    let (dot, sq_norm_a, sq_norm_b) = emb_a.iter().zip(emb_b).fold(
        (0.0f64, 0.0f64, 0.0f64),
        |(dot, sq_a, sq_b), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, sq_a + x * x, sq_b + y * y)
        },
    );

    let norm_a = sq_norm_a.sqrt();
    let norm_b = sq_norm_b.sqrt();
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }

    let cosine = dot / (norm_a * norm_b);
    if !cosine.is_finite() {
        // NaN/infinite inputs: treat as "no similarity" instead of leaking NaN,
        // which would silently fail every threshold comparison.
        return 0.0;
    }

    ((cosine.clamp(-1.0, 1.0) + 1.0) / 2.0) as f32
}
367
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance used for approximate float comparisons.
    const TOLERANCE: f32 = 0.001;

    #[test]
    fn test_string_similarity_identical() {
        let text = "hello world";
        assert_eq!(string_similarity(text, text), 1.0);
    }

    #[test]
    fn test_string_similarity_partial() {
        // One shared word out of two distinct words overall.
        let sim = string_similarity("hello world", "hello");
        assert!(0.0 < sim && sim < 1.0);
        assert!((sim - 0.5).abs() < TOLERANCE);
    }

    #[test]
    fn test_string_similarity_empty() {
        let cases: [(&str, &str, f32); 3] =
            [("", "", 1.0), ("hello", "", 0.0), ("", "hello", 0.0)];
        for &(left, right, expected) in cases.iter() {
            assert_eq!(
                string_similarity(left, right),
                expected,
                "inputs: {:?} / {:?}",
                left,
                right
            );
        }
    }

    #[test]
    fn test_string_similarity_symmetric() {
        let (left, right) = ("hello world", "world peace");
        assert_eq!(string_similarity(left, right), string_similarity(right, left));
    }

    #[test]
    fn test_embedding_similarity_identical() {
        let emb: [f32; 3] = [1.0, 0.0, 0.0];
        assert_eq!(embedding_similarity(&emb, &emb), 1.0);
    }

    #[test]
    fn test_embedding_similarity_orthogonal() {
        // A cosine of 0 maps to the midpoint of the [0, 1] output range.
        let (x, y): ([f32; 2], [f32; 2]) = ([1.0, 0.0], [0.0, 1.0]);
        assert!((embedding_similarity(&x, &y) - 0.5).abs() < TOLERANCE);
    }

    #[test]
    fn test_embedding_similarity_opposite() {
        // A cosine of -1 maps to the bottom of the output range.
        let (x, y): ([f32; 2], [f32; 2]) = ([1.0, 0.0], [-1.0, 0.0]);
        assert!(embedding_similarity(&x, &y).abs() < TOLERANCE);
    }

    #[test]
    fn test_embedding_similarity_mismatched_length() {
        let short: [f32; 2] = [1.0, 0.0];
        let long: [f32; 3] = [1.0, 0.0, 0.0];
        assert_eq!(embedding_similarity(&short, &long), 0.0);
    }

    #[test]
    fn test_embedding_similarity_empty() {
        let empty: [f32; 0] = [];
        assert_eq!(embedding_similarity(&empty, &empty), 0.0);
    }

    #[test]
    fn test_embedding_similarity_zero_norm() {
        let zero: [f32; 2] = [0.0, 0.0];
        let unit: [f32; 2] = [1.0, 0.0];
        assert_eq!(embedding_similarity(&zero, &unit), 0.0);
    }

    #[test]
    fn test_resolver_builder() {
        let resolver = Resolver::new().with_threshold(0.8).require_type_match(false);
        assert_eq!(resolver.similarity_threshold, 0.8);
        assert!(!resolver.require_type_match);
    }

    #[test]
    fn test_resolver_default() {
        let resolver = Resolver::default();
        assert_eq!(resolver.similarity_threshold, 0.7);
        assert!(resolver.require_type_match);
    }
}
462
#[cfg(test)]
mod proptests {
    use super::*;
    use proptest::prelude::*;

    /// Advances the linear congruential generator held in `state` and returns
    /// its new value.
    fn lcg_step(state: &mut u64) -> u64 {
        *state = state.wrapping_mul(1103515245).wrapping_add(12345);
        *state
    }

    /// Draws `dim` successive LCG states from `state`, mapping each through
    /// `to_f32`.
    fn lcg_vec(state: &mut u64, dim: usize, to_f32: impl Fn(u64) -> f32) -> Vec<f32> {
        (0..dim).map(|_| to_f32(lcg_step(state))).collect()
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        /// String similarity always lies in [0, 1].
        #[test]
        fn string_sim_bounded(a in ".*", b in ".*") {
            let sim = string_similarity(&a, &b);
            prop_assert!((0.0..=1.0).contains(&sim));
        }

        /// String similarity does not depend on argument order.
        #[test]
        fn string_sim_symmetric(a in "[a-z ]{0,30}", b in "[a-z ]{0,30}") {
            let forward = string_similarity(&a, &b);
            let backward = string_similarity(&b, &a);
            prop_assert!((forward - backward).abs() < 0.0001);
        }

        /// A non-empty word is fully similar to itself.
        #[test]
        fn string_sim_reflexive(s in "[a-z]{1,20}") {
            let sim = string_similarity(&s, &s);
            prop_assert!((sim - 1.0).abs() < 0.0001);
        }

        /// Embedding similarity of signed vectors always lies in [0, 1].
        #[test]
        fn embedding_sim_bounded(dim in 1usize..50, seed in any::<u64>()) {
            let to_signed_unit = |raw: u64| (raw % 2000) as f32 / 1000.0 - 1.0;
            let mut rng = seed;
            let emb1 = lcg_vec(&mut rng, dim, to_signed_unit);
            let emb2 = lcg_vec(&mut rng, dim, to_signed_unit);

            let sim = embedding_similarity(&emb1, &emb2);
            prop_assert!((0.0..=1.0).contains(&sim),
                "Embedding similarity out of bounds: {}", sim);
        }

        /// Embedding similarity does not depend on argument order.
        #[test]
        fn embedding_sim_symmetric(dim in 1usize..20, seed in any::<u64>()) {
            let to_unit = |raw: u64| (raw % 100) as f32 / 100.0;
            let mut rng = seed;
            let emb1 = lcg_vec(&mut rng, dim, to_unit);
            let emb2 = lcg_vec(&mut rng, dim, to_unit);

            let sim_ab = embedding_similarity(&emb1, &emb2);
            let sim_ba = embedding_similarity(&emb2, &emb1);
            prop_assert!((sim_ab - sim_ba).abs() < 0.0001,
                "Embedding similarity not symmetric: {} vs {}", sim_ab, sim_ba);
        }
    }
}