use crate::core::{Corpus, Identity, IdentityId, IdentitySource, TrackId, TrackRef};
use std::collections::HashMap;
/// Configuration for cross-document coreference resolution.
///
/// The two fields are the defaults that `resolve_inter_doc_coref` falls back
/// to when the caller does not supply per-call overrides.
#[derive(Debug, Clone)]
pub struct Resolver {
    /// Minimum pairwise similarity at which two tracks are merged.
    similarity_threshold: f32,
    /// When `true`, two tracks are only compared if their entity types are equal.
    require_type_match: bool,
}
impl Resolver {
    /// Creates a resolver with a similarity threshold of `0.7` and entity-type
    /// matching enabled.
    pub fn new() -> Self {
        Self {
            similarity_threshold: 0.7,
            require_type_match: true,
        }
    }

    /// Builder-style setter for the default similarity threshold.
    pub fn with_threshold(mut self, threshold: f32) -> Self {
        self.similarity_threshold = threshold;
        self
    }

    /// Builder-style setter controlling whether two tracks must have equal
    /// entity types before they are compared.
    pub fn require_type_match(mut self, require: bool) -> Self {
        self.require_type_match = require;
        self
    }

    /// Clusters the tracks of every document in `corpus` and creates one
    /// identity per cluster.
    ///
    /// Every pair of tracks is compared: by [`embedding_similarity`] when both
    /// tracks carry an embedding, otherwise by [`string_similarity`] of their
    /// canonical surfaces. Pairs scoring at or above the threshold are merged
    /// with a union-find, so merging is transitive: `a ~ b` and `b ~ c` place
    /// `a`, `b` and `c` in one cluster even if `a` and `c` are dissimilar.
    /// Tracks that match nothing form singleton clusters and still receive an
    /// identity. Tracks for which the document yields no `TrackRef` are skipped.
    ///
    /// For each cluster an [`Identity`] is built from the cluster's first track
    /// (name, type, embedding, confidence), added to the corpus, and every
    /// member track is linked to it. If a member's document can no longer be
    /// found, a warning is logged and that link is skipped.
    ///
    /// # Parameters
    /// * `similarity_threshold` - per-call override; `None` uses the resolver's
    ///   configured threshold.
    /// * `require_type_match` - per-call override; `None` uses the resolver's
    ///   configured setting.
    ///
    /// # Returns
    /// The IDs of the identities created, one per cluster. Clusters are emitted
    /// in the order in which their first track was enumerated from the corpus,
    /// so for a given enumeration order the result is reproducible.
    pub fn resolve_inter_doc_coref(
        &self,
        corpus: &mut Corpus,
        similarity_threshold: Option<f32>,
        require_type_match: Option<bool>,
    ) -> Vec<IdentityId> {
        // Owned snapshot of the track fields needed for clustering, so the
        // corpus can be mutated later without holding borrows into it.
        #[derive(Debug, Clone)]
        struct TrackData {
            track_ref: TrackRef,
            canonical_surface: String,
            entity_type: Option<crate::TypeLabel>,
            cluster_confidence: crate::Confidence,
            embedding: Option<Vec<f32>>,
        }

        // Returns the representative of `i`'s set, compressing the path.
        fn find(parent: &mut [usize], i: usize) -> usize {
            let mut root = i;
            while parent[root] != root {
                root = parent[root];
            }
            // Second pass: point every node on the path directly at the root.
            let mut node = i;
            while parent[node] != root {
                let next = parent[node];
                parent[node] = root;
                node = next;
            }
            root
        }

        // Merges the sets containing `i` and `j` (union by rank).
        fn union(parent: &mut [usize], rank: &mut [usize], i: usize, j: usize) {
            let pi = find(parent, i);
            let pj = find(parent, j);
            if pi == pj {
                return;
            }
            match rank[pi].cmp(&rank[pj]) {
                std::cmp::Ordering::Less => parent[pi] = pj,
                std::cmp::Ordering::Greater => parent[pj] = pi,
                std::cmp::Ordering::Equal => {
                    parent[pi] = pj;
                    rank[pj] += 1;
                }
            }
        }

        let threshold = similarity_threshold.unwrap_or(self.similarity_threshold);
        let type_match = require_type_match.unwrap_or(self.require_type_match);

        // Step 1: snapshot every track of every document.
        let doc_ids: Vec<String> = corpus.documents().map(|d| d.id().to_owned()).collect();
        let mut track_data: Vec<TrackData> = Vec::new();
        for doc_id in doc_ids {
            if let Some(doc) = corpus.get_document(&doc_id) {
                for track in doc.tracks() {
                    if let Some(track_ref) = doc.track_ref(track.id) {
                        track_data.push(TrackData {
                            track_ref,
                            canonical_surface: track.canonical_surface.clone(),
                            entity_type: track.entity_type.clone(),
                            cluster_confidence: track.cluster_confidence,
                            embedding: track.embedding.clone(),
                        });
                    }
                }
            }
        }
        if track_data.is_empty() {
            return Vec::new();
        }

        // Step 2: pairwise comparison, merging similar tracks.
        let track_count = track_data.len();
        let mut parent: Vec<usize> = (0..track_count).collect();
        let mut rank: Vec<usize> = vec![0; track_count];
        for i in 0..track_count {
            for j in (i + 1)..track_count {
                let (track_a, track_b) = (&track_data[i], &track_data[j]);
                if type_match && track_a.entity_type != track_b.entity_type {
                    continue;
                }
                let similarity = match (&track_a.embedding, &track_b.embedding) {
                    (Some(emb_a), Some(emb_b)) => embedding_similarity(emb_a, emb_b),
                    _ => string_similarity(&track_a.canonical_surface, &track_b.canonical_surface),
                };
                if similarity >= threshold {
                    union(&mut parent, &mut rank, i, j);
                }
            }
        }

        // Step 3: group track indices by root. Clusters are stored in
        // first-seen order rather than read back out of a `HashMap`, whose
        // iteration order is arbitrary and would make the assignment of
        // identity IDs differ from run to run.
        let mut slot_of_root: HashMap<usize, usize> = HashMap::new();
        let mut clusters: Vec<Vec<usize>> = Vec::new();
        for i in 0..track_count {
            let root = find(&mut parent, i);
            let slot = *slot_of_root.entry(root).or_insert_with(|| {
                clusters.push(Vec::new());
                clusters.len() - 1
            });
            clusters[slot].push(i);
        }

        // Step 4: create one identity per cluster and link its member tracks.
        let mut created_ids = Vec::with_capacity(clusters.len());
        for members in &clusters {
            // A cluster is only created together with its first member, so it
            // is never empty and indexing `[0]` cannot panic.
            let first_track = &track_data[members[0]];
            let track_refs_in_cluster: Vec<TrackRef> = members
                .iter()
                .map(|&idx| track_data[idx].track_ref.clone())
                .collect();
            let identity = Identity {
                id: corpus.next_identity_id(),
                canonical_name: first_track.canonical_surface.clone(),
                entity_type: first_track.entity_type.clone(),
                kb_id: None,
                kb_name: None,
                description: None,
                embedding: first_track.embedding.clone(),
                aliases: Vec::new(),
                confidence: first_track.cluster_confidence,
                source: Some(IdentitySource::CrossDocCoref {
                    track_refs: track_refs_in_cluster,
                }),
            };
            let identity_id = corpus.add_identity(identity);
            created_ids.push(identity_id);

            // `track_data` is an owned snapshot, so it can be read while the
            // corpus is borrowed mutably for linking.
            for &idx in members {
                let track_ref = &track_data[idx].track_ref;
                if let Some(doc) = corpus.get_document_mut(&track_ref.doc_id) {
                    doc.link_track_to_identity(track_ref.track_id, identity_id);
                } else {
                    log::warn!(
                        "Document '{}' not found when linking track {} to identity {}",
                        track_ref.doc_id,
                        track_ref.track_id,
                        identity_id
                    );
                }
            }
        }
        created_ids
    }
}
impl Default for Resolver {
fn default() -> Self {
Self::new()
}
}
/// Jaccard similarity between the word sets of `a` and `b`.
///
/// Each string is split on whitespace; every word is lowercased and a trailing
/// possessive marker is removed (`'s`, a bare trailing `'`, and their
/// typographic-apostrophe forms `’s` / `’`), so `"Obama's"`, `"Obama’s"` and
/// `"obama"` all normalize to the same word. The score is
/// `|A ∩ B| / |A ∪ B|` over the resulting sets.
///
/// # Returns
/// A value in `[0.0, 1.0]`. Two strings without any words score `1.0`; if
/// exactly one of them has no words the score is `0.0`.
pub fn string_similarity(a: &str, b: &str) -> f32 {
    // Lowercases a word and strips a trailing possessive, accepting both the
    // ASCII apostrophe and the typographic one (U+2019).
    fn normalize_word(word: &str) -> String {
        word.to_lowercase()
            .trim_end_matches("'s")
            .trim_end_matches("\u{2019}s")
            .trim_end_matches(&['\'', '\u{2019}'][..])
            .to_string()
    }

    let words_a: std::collections::HashSet<String> =
        a.split_whitespace().map(normalize_word).collect();
    let words_b: std::collections::HashSet<String> =
        b.split_whitespace().map(normalize_word).collect();

    match (words_a.is_empty(), words_b.is_empty()) {
        (true, true) => return 1.0,
        (true, false) | (false, true) => return 0.0,
        (false, false) => {}
    }

    let shared = words_a.intersection(&words_b).count();
    // |A ∪ B| = |A| + |B| - |A ∩ B|; non-zero because both sets are non-empty.
    let total = words_a.len() + words_b.len() - shared;
    shared as f32 / total as f32
}
/// Cosine similarity of two embeddings, rescaled from `[-1, 1]` to `[0, 1]`.
///
/// Identical directions score `1.0`, orthogonal vectors `0.5` and opposite
/// directions `0.0`.
///
/// # Returns
/// A value in `[0.0, 1.0]`. `0.0` is returned whenever the vectors cannot be
/// meaningfully compared: different lengths, empty input, a zero-norm vector,
/// or a cosine that is not a number (NaN components, or magnitudes large
/// enough to overflow `f32`).
pub fn embedding_similarity(emb_a: &[f32], emb_b: &[f32]) -> f32 {
    if emb_a.is_empty() || emb_a.len() != emb_b.len() {
        return 0.0;
    }

    // One pass accumulating the dot product and both squared norms.
    let (dot_product, sq_a, sq_b) = emb_a.iter().zip(emb_b).fold(
        (0.0_f32, 0.0_f32, 0.0_f32),
        |(dot, sq_a, sq_b), (&x, &y)| (dot + x * y, sq_a + x * x, sq_b + y * y),
    );
    let norm_a = sq_a.sqrt();
    let norm_b = sq_b.sqrt();
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }

    let cosine = dot_product / (norm_a * norm_b);
    if cosine.is_nan() {
        // Same convention as the other "cannot compare" cases above.
        return 0.0;
    }
    // Rounding can leave the cosine a few ULPs outside [-1, 1]; clamp so the
    // rescaled score never leaves [0, 1].
    (cosine.clamp(-1.0, 1.0) + 1.0) / 2.0
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for comparing similarity scores against expected values.
    const EPS: f32 = 0.001;

    // ---- string_similarity -------------------------------------------------

    /// Equal strings have the maximum score.
    #[test]
    fn test_string_similarity_identical() {
        let score = string_similarity("hello world", "hello world");
        assert_eq!(score, 1.0);
    }

    /// One shared word out of two distinct words gives 1/2.
    #[test]
    fn test_string_similarity_partial() {
        let score = string_similarity("hello world", "hello");
        assert!(score > 0.0);
        assert!(score < 1.0);
        assert!((score - 0.5).abs() < EPS);
    }

    /// Two empty strings match; an empty string never matches a non-empty one.
    #[test]
    fn test_string_similarity_empty() {
        let cases = [("", "", 1.0_f32), ("hello", "", 0.0), ("", "hello", 0.0)];
        for &(left, right, expected) in cases.iter() {
            assert_eq!(string_similarity(left, right), expected);
        }
    }

    /// Swapping the arguments does not change the score.
    #[test]
    fn test_string_similarity_symmetric() {
        let (left, right) = ("hello world", "world peace");
        assert_eq!(
            string_similarity(left, right),
            string_similarity(right, left)
        );
    }

    // ---- embedding_similarity ----------------------------------------------

    /// A vector compared with itself scores exactly 1.
    #[test]
    fn test_embedding_similarity_identical() {
        let v = vec![1.0, 0.0, 0.0];
        assert_eq!(embedding_similarity(&v, &v), 1.0);
    }

    /// Orthogonal vectors land in the middle of the range.
    #[test]
    fn test_embedding_similarity_orthogonal() {
        let x_axis = vec![1.0, 0.0];
        let y_axis = vec![0.0, 1.0];
        let score = embedding_similarity(&x_axis, &y_axis);
        assert!((score - 0.5).abs() < EPS);
    }

    /// Opposite vectors land at the bottom of the range.
    #[test]
    fn test_embedding_similarity_opposite() {
        let forward = vec![1.0, 0.0];
        let backward = vec![-1.0, 0.0];
        let score = embedding_similarity(&forward, &backward);
        assert!((score - 0.0).abs() < EPS);
    }

    /// Vectors of different dimensionality are treated as dissimilar.
    #[test]
    fn test_embedding_similarity_mismatched_length() {
        let two_d = vec![1.0, 0.0];
        let three_d = vec![1.0, 0.0, 0.0];
        assert_eq!(embedding_similarity(&two_d, &three_d), 0.0);
    }

    /// Empty vectors are treated as dissimilar.
    #[test]
    fn test_embedding_similarity_empty() {
        let nothing: Vec<f32> = Vec::new();
        let also_nothing: Vec<f32> = Vec::new();
        assert_eq!(embedding_similarity(&nothing, &also_nothing), 0.0);
    }

    /// A zero vector has no direction, so the score is 0.
    #[test]
    fn test_embedding_similarity_zero_norm() {
        let zero = vec![0.0, 0.0];
        let unit = vec![1.0, 0.0];
        assert_eq!(embedding_similarity(&zero, &unit), 0.0);
    }

    // ---- Resolver ----------------------------------------------------------

    /// Builder methods overwrite the corresponding fields.
    #[test]
    fn test_resolver_builder() {
        let configured = Resolver::new()
            .with_threshold(0.8)
            .require_type_match(false);
        assert_eq!(configured.similarity_threshold, 0.8);
        assert!(!configured.require_type_match);
    }

    /// `Default` yields a 0.7 threshold with type matching on.
    #[test]
    fn test_resolver_default() {
        let defaults = Resolver::default();
        assert_eq!(defaults.similarity_threshold, 0.7);
        assert!(defaults.require_type_match);
    }
}
#[cfg(test)]
mod proptests {
    use super::*;
    use proptest::prelude::*;

    /// Maps a raw generator state to a component in `[-1.0, 1.0)`.
    fn signed_component(raw: u64) -> f32 {
        (raw % 2000) as f32 / 1000.0 - 1.0
    }

    /// Maps a raw generator state to a component in `[0.0, 1.0)`.
    fn unsigned_component(raw: u64) -> f32 {
        (raw % 100) as f32 / 100.0
    }

    /// Produces `len` components from a linear congruential generator,
    /// advancing `state` once per component so consecutive calls continue the
    /// same sequence.
    fn lcg_vector(state: &mut u64, len: usize, to_component: fn(u64) -> f32) -> Vec<f32> {
        let mut out = Vec::with_capacity(len);
        for _ in 0..len {
            *state = state.wrapping_mul(1103515245).wrapping_add(12345);
            out.push(to_component(*state));
        }
        out
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        /// String similarity never leaves the unit interval.
        #[test]
        fn string_sim_bounded(a in ".*", b in ".*") {
            let score = string_similarity(&a, &b);
            prop_assert!((0.0..=1.0).contains(&score));
        }

        /// String similarity does not depend on argument order.
        #[test]
        fn string_sim_symmetric(a in "[a-z ]{0,30}", b in "[a-z ]{0,30}") {
            let forward = string_similarity(&a, &b);
            let backward = string_similarity(&b, &a);
            prop_assert!((forward - backward).abs() < 0.0001);
        }

        /// A single word compared with itself scores 1.
        #[test]
        fn string_sim_reflexive(s in "[a-z]{1,20}") {
            let score = string_similarity(&s, &s);
            prop_assert!((score - 1.0).abs() < 0.0001);
        }

        /// Embedding similarity never leaves the unit interval.
        #[test]
        fn embedding_sim_bounded(dim in 1usize..50, seed in any::<u64>()) {
            let mut state = seed;
            let emb1 = lcg_vector(&mut state, dim, signed_component);
            let emb2 = lcg_vector(&mut state, dim, signed_component);
            let sim = embedding_similarity(&emb1, &emb2);
            prop_assert!((0.0..=1.0).contains(&sim),
                "Embedding similarity out of bounds: {}", sim);
        }

        /// Embedding similarity does not depend on argument order.
        #[test]
        fn embedding_sim_symmetric(dim in 1usize..20, seed in any::<u64>()) {
            let mut state = seed;
            let emb1 = lcg_vector(&mut state, dim, unsigned_component);
            let emb2 = lcg_vector(&mut state, dim, unsigned_component);
            let sim_ab = embedding_similarity(&emb1, &emb2);
            let sim_ba = embedding_similarity(&emb2, &emb1);
            prop_assert!((sim_ab - sim_ba).abs() < 0.0001,
                "Embedding similarity not symmetric: {} vs {}", sim_ab, sim_ba);
        }
    }
}