Skip to main content

anno/backends/inference/
coref.rs

//! Coreference resolution utilities and data structures.
2
3use crate::Entity;
4use std::collections::HashMap;
5
// Coreference Resolution
// =============================================================================
8
/// A coreference cluster: a group of mentions that refer to the same entity.
///
/// Every index stored here points into the entity slice that was passed to
/// the resolver which produced the cluster.
#[derive(Debug, Clone)]
pub struct CoreferenceCluster {
    /// Cluster ID.
    pub id: u64,
    /// Member entities (indices into the entity list).
    pub members: Vec<usize>,
    /// Index of the representative entity (the most informative mention).
    pub representative: usize,
    /// Canonical name, copied from the representative's text.
    pub canonical_name: String,
}
21
/// Configuration for coreference resolution.
#[derive(Debug, Clone)]
pub struct CoreferenceConfig {
    /// Minimum cosine similarity required to link two mentions.
    pub similarity_threshold: f32,
    /// Maximum distance allowed between two mentions for an embedding-based
    /// link; `None` disables the distance check.
    pub max_distance: Option<usize>,
    /// Whether to use string matching between mention texts as a linking signal.
    pub use_string_match: bool,
}
32
33impl Default for CoreferenceConfig {
34    fn default() -> Self {
35        Self {
36            similarity_threshold: 0.85,
37            max_distance: Some(500),
38            use_string_match: true,
39        }
40    }
41}
42
43/// Resolve coreferences between entities using embedding similarity.
44///
45/// # Algorithm
46///
47/// 1. Compute pairwise cosine similarity between entity embeddings
48/// 2. Link entities above threshold (with optional distance constraint)
49/// 3. Build clusters via transitive closure
50/// 4. Select representative (longest/most informative mention)
51///
52/// # Example
53///
54/// Input entities: ["Lynn Conway", "She", "The engineer", "Conway"]
55/// Output clusters: [{0, 1, 2, 3}] with canonical_name = "Lynn Conway"
56pub fn resolve_coreferences(
57    entities: &[Entity],
58    embeddings: &[f32], // [num_entities, hidden_dim]
59    hidden_dim: usize,
60    config: &CoreferenceConfig,
61) -> Vec<CoreferenceCluster> {
62    let n = entities.len();
63    if n == 0 {
64        return vec![];
65    }
66
67    // Union-find for clustering
68    let mut parent: Vec<usize> = (0..n).collect();
69
70    fn find(parent: &mut [usize], i: usize) -> usize {
71        if parent[i] != i {
72            parent[i] = find(parent, parent[i]);
73        }
74        parent[i]
75    }
76
77    fn union(parent: &mut [usize], i: usize, j: usize) {
78        let pi = find(parent, i);
79        let pj = find(parent, j);
80        if pi != pj {
81            parent[pi] = pj;
82        }
83    }
84
85    // Check all pairs
86    for i in 0..n {
87        for j in (i + 1)..n {
88            // String match check (fast path)
89            if config.use_string_match {
90                let text_i = entities[i].text.to_lowercase();
91                let text_j = entities[j].text.to_lowercase();
92                if text_i == text_j || text_i.contains(&text_j) || text_j.contains(&text_i) {
93                    // Same entity type required
94                    if entities[i].entity_type == entities[j].entity_type {
95                        union(&mut parent, i, j);
96                        continue;
97                    }
98                }
99            }
100
101            // Distance check
102            if let Some(max_dist) = config.max_distance {
103                let dist = if entities[i].end <= entities[j].start {
104                    entities[j].start - entities[i].end
105                } else {
106                    entities[i].start.saturating_sub(entities[j].end)
107                };
108                if dist > max_dist {
109                    continue;
110                }
111            }
112
113            // Embedding similarity
114            if embeddings.len() >= (j + 1) * hidden_dim {
115                let emb_i = &embeddings[i * hidden_dim..(i + 1) * hidden_dim];
116                let emb_j = &embeddings[j * hidden_dim..(j + 1) * hidden_dim];
117
118                let similarity = cosine_similarity(emb_i, emb_j);
119
120                if similarity >= config.similarity_threshold {
121                    // Same entity type required
122                    if entities[i].entity_type == entities[j].entity_type {
123                        union(&mut parent, i, j);
124                    }
125                }
126            }
127        }
128    }
129
130    // Build clusters
131    let mut cluster_members: HashMap<usize, Vec<usize>> = HashMap::new();
132    for i in 0..n {
133        let root = find(&mut parent, i);
134        cluster_members.entry(root).or_default().push(i);
135    }
136
137    // Convert to CoreferenceCluster
138    let mut clusters = Vec::new();
139    let mut cluster_id = 0u64;
140
141    for (_root, members) in cluster_members {
142        if members.len() > 1 {
143            // Find representative (longest mention)
144            let representative = *members
145                .iter()
146                .max_by_key(|&&i| entities[i].text.len())
147                .unwrap_or(&members[0]);
148
149            clusters.push(CoreferenceCluster {
150                id: cluster_id,
151                members,
152                representative,
153                canonical_name: entities[representative].text.clone(),
154            });
155            cluster_id += 1;
156        }
157    }
158
159    clusters
160}
161
/// Compute cosine similarity between two vectors.
///
/// Returns a value in [-1.0, 1.0] where:
/// - 1.0 = identical direction
/// - 0.0 = orthogonal
/// - -1.0 = opposite direction
///
/// Returns 0.0 when either vector has zero norm (including empty slices).
///
/// The slices are expected to have the same length. If they differ, the dot
/// product only covers the common prefix while each norm covers its whole
/// slice.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
    let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();

    if norm_a > 0.0 && norm_b > 0.0 {
        // Rounding can push the quotient marginally outside [-1, 1];
        // clamp so the documented range always holds (NaN is passed through).
        (dot / (norm_a * norm_b)).clamp(-1.0, 1.0)
    } else {
        0.0
    }
}