anno/backends/inference/
coref.rs1use crate::Entity;
4use std::collections::HashMap;
5
/// A group of entity mentions that were resolved to the same referent.
///
/// All indices stored here refer to positions in the entity slice that was
/// handed to `resolve_coreferences`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CoreferenceCluster {
    /// Identifier of the cluster; `resolve_coreferences` numbers clusters
    /// consecutively starting at 0.
    pub id: u64,
    /// Indices of the entities that belong to this cluster, in ascending order.
    pub members: Vec<usize>,
    /// Index of the member chosen to stand for the whole cluster (the member
    /// with the longest text).
    pub representative: usize,
    /// Text of the representative entity.
    pub canonical_name: String,
}
21
/// Tunable parameters for `resolve_coreferences`.
#[derive(Debug, Clone, PartialEq)]
pub struct CoreferenceConfig {
    /// Minimum cosine similarity (inclusive) between two entity embeddings for
    /// the entities to be linked.
    pub similarity_threshold: f32,
    /// Largest allowed gap between two entity spans for the embedding
    /// comparison to be attempted, measured in whatever unit the entities'
    /// `start`/`end` offsets use. `None` disables the limit. String matches are
    /// not subject to this limit.
    pub max_distance: Option<usize>,
    /// When `true`, entities of the same type are linked if their lowercased
    /// texts are equal or one contains the other.
    pub use_string_match: bool,
}
32
33impl Default for CoreferenceConfig {
34 fn default() -> Self {
35 Self {
36 similarity_threshold: 0.85,
37 max_distance: Some(500),
38 use_string_match: true,
39 }
40 }
41}
42
43pub fn resolve_coreferences(
57 entities: &[Entity],
58 embeddings: &[f32], hidden_dim: usize,
60 config: &CoreferenceConfig,
61) -> Vec<CoreferenceCluster> {
62 let n = entities.len();
63 if n == 0 {
64 return vec![];
65 }
66
67 let mut parent: Vec<usize> = (0..n).collect();
69
70 fn find(parent: &mut [usize], i: usize) -> usize {
71 if parent[i] != i {
72 parent[i] = find(parent, parent[i]);
73 }
74 parent[i]
75 }
76
77 fn union(parent: &mut [usize], i: usize, j: usize) {
78 let pi = find(parent, i);
79 let pj = find(parent, j);
80 if pi != pj {
81 parent[pi] = pj;
82 }
83 }
84
85 for i in 0..n {
87 for j in (i + 1)..n {
88 if config.use_string_match {
90 let text_i = entities[i].text.to_lowercase();
91 let text_j = entities[j].text.to_lowercase();
92 if text_i == text_j || text_i.contains(&text_j) || text_j.contains(&text_i) {
93 if entities[i].entity_type == entities[j].entity_type {
95 union(&mut parent, i, j);
96 continue;
97 }
98 }
99 }
100
101 if let Some(max_dist) = config.max_distance {
103 let dist = if entities[i].end <= entities[j].start {
104 entities[j].start - entities[i].end
105 } else {
106 entities[i].start.saturating_sub(entities[j].end)
107 };
108 if dist > max_dist {
109 continue;
110 }
111 }
112
113 if embeddings.len() >= (j + 1) * hidden_dim {
115 let emb_i = &embeddings[i * hidden_dim..(i + 1) * hidden_dim];
116 let emb_j = &embeddings[j * hidden_dim..(j + 1) * hidden_dim];
117
118 let similarity = cosine_similarity(emb_i, emb_j);
119
120 if similarity >= config.similarity_threshold {
121 if entities[i].entity_type == entities[j].entity_type {
123 union(&mut parent, i, j);
124 }
125 }
126 }
127 }
128 }
129
130 let mut cluster_members: HashMap<usize, Vec<usize>> = HashMap::new();
132 for i in 0..n {
133 let root = find(&mut parent, i);
134 cluster_members.entry(root).or_default().push(i);
135 }
136
137 let mut clusters = Vec::new();
139 let mut cluster_id = 0u64;
140
141 for (_root, members) in cluster_members {
142 if members.len() > 1 {
143 let representative = *members
145 .iter()
146 .max_by_key(|&&i| entities[i].text.len())
147 .unwrap_or(&members[0]);
148
149 clusters.push(CoreferenceCluster {
150 id: cluster_id,
151 members,
152 representative,
153 canonical_name: entities[representative].text.clone(),
154 });
155 cluster_id += 1;
156 }
157 }
158
159 clusters
160}
161
/// Computes the cosine similarity of two vectors.
///
/// The dot product is taken over the element pairs produced by zipping `a`
/// with `b` (i.e. the shared prefix when the lengths differ), while each norm
/// is taken over the whole of its slice.
///
/// Returns `0.0` unless both Euclidean norms are strictly positive, which
/// covers empty and all-zero inputs.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    // Euclidean length of a vector.
    let magnitude = |v: &[f32]| -> f32 { v.iter().map(|c| c * c).sum::<f32>().sqrt() };

    let len_a = magnitude(a);
    let len_b = magnitude(b);

    if len_a > 0.0 && len_b > 0.0 {
        let inner: f32 = a.iter().zip(b.iter()).map(|(p, q)| p * q).sum();
        inner / (len_a * len_b)
    } else {
        0.0
    }
}