1use crate::backends::inference::{HandshakingCell, HandshakingMatrix};
12use crate::EntityType;
13
/// One row produced by `decode_discontinuous_from_matrix`.
///
/// The tuple holds, in order: the entity label, the entity's token segments
/// as half-open `(start, end)` index pairs, and the score of the boundary
/// cell the entity was decoded from.
pub type DiscontinuousDecodeRow = (
    String,
    Vec<(usize, usize)>,
    f64,
);
18
/// Relation label carried by a word-pair cell of the W2NER grid.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum W2NERRelation {
    /// Next-Neighboring-Word: links two tokens of the same entity (label index 1).
    NNW,
    /// Tail-Head-Word: marks an entity boundary, with the cell's row being
    /// the tail token and its column the head token (label index 2).
    THW,
    /// No relation between the two tokens (label index 0).
    None,
}

impl W2NERRelation {
    /// Converts a numeric label index into a relation.
    ///
    /// `1` maps to [`Self::NNW`] and `2` to [`Self::THW`]. Every other index,
    /// including `0` and any out-of-range value, maps to [`Self::None`].
    #[must_use]
    pub fn from_index(idx: usize) -> Self {
        match idx {
            1 => Self::NNW,
            2 => Self::THW,
            // Index 0 is the explicit "no relation" label; unknown indices
            // are treated the same way rather than rejected.
            _ => Self::None,
        }
    }

    /// Returns the numeric label index of this relation.
    ///
    /// This is the inverse of [`Self::from_index`] for the indices `0..=2`.
    #[must_use]
    pub fn to_index(self) -> usize {
        match self {
            Self::THW => 2,
            Self::NNW => 1,
            Self::None => 0,
        }
    }
}
52
53#[must_use]
64pub fn decode_from_matrix(
65 matrix: &HandshakingMatrix,
66 tokens: &[&str],
67 entity_type_idx: usize,
68 threshold: f32,
69 allow_nested: bool,
70) -> Vec<(usize, usize, f64)> {
71 let mut entities = Vec::with_capacity(16);
72
73 for cell in &matrix.cells {
74 let relation = W2NERRelation::from_index(cell.label_idx as usize);
75 if relation == W2NERRelation::THW && cell.score >= threshold {
76 let tail = cell.i as usize;
77 let head = cell.j as usize;
78 if head <= tail && head < tokens.len() && tail < tokens.len() {
79 entities.push((head, tail + 1, cell.score as f64));
80 }
81 }
82 }
83
84 entities.sort_unstable_by(|a, b| a.0.cmp(&b.0).then_with(|| (b.1 - b.0).cmp(&(a.1 - a.0))));
85
86 if !allow_nested {
87 entities = remove_nested(&entities);
88 }
89
90 let _ = entity_type_idx;
91 entities
92}
93
94#[must_use]
102pub fn decode_discontinuous_from_matrix(
103 matrix: &HandshakingMatrix,
104 tokens: &[&str],
105 threshold: f32,
106 first_label: &str,
107) -> Vec<DiscontinuousDecodeRow> {
108 let n = tokens.len();
109
110 let mut entity_boundaries: Vec<(usize, usize, f64)> = Vec::new();
111 for cell in &matrix.cells {
112 if W2NERRelation::from_index(cell.label_idx as usize) == W2NERRelation::THW
113 && cell.score >= threshold
114 {
115 let tail = cell.i as usize;
116 let head = cell.j as usize;
117 if head <= tail && tail < n {
118 entity_boundaries.push((head, tail, cell.score as f64));
119 }
120 }
121 }
122
123 let mut nnw: std::collections::HashSet<(usize, usize)> = std::collections::HashSet::new();
124 for cell in &matrix.cells {
125 if W2NERRelation::from_index(cell.label_idx as usize) == W2NERRelation::NNW
126 && cell.score >= threshold
127 {
128 let a = cell.i as usize;
129 let b = cell.j as usize;
130 nnw.insert((a, b));
131 nnw.insert((b, a));
132 }
133 }
134
135 let mut results: Vec<DiscontinuousDecodeRow> = Vec::new();
136 let type_label = if first_label.is_empty() {
137 "ENTITY".to_string()
138 } else {
139 first_label.to_string()
140 };
141
142 for (head, tail, score) in entity_boundaries {
143 let mut segments: Vec<(usize, usize)> = Vec::new();
144 let mut seg_start = head;
145 for i in head..tail {
146 let j = i + 1;
147 if !nnw.contains(&(i, j)) {
148 segments.push((seg_start, i + 1));
149 seg_start = j;
150 }
151 }
152 segments.push((seg_start, tail + 1));
153 results.push((type_label.clone(), segments, score));
154 }
155
156 results.sort_unstable_by(|a, b| {
157 let a_start = a.1.first().map(|s| s.0).unwrap_or(usize::MAX);
158 let b_start = b.1.first().map(|s| s.0).unwrap_or(usize::MAX);
159 let a_len: usize = a.1.iter().map(|(s, e)| e - s).sum();
160 let b_len: usize = b.1.iter().map(|(s, e)| e - s).sum();
161 a_start.cmp(&b_start).then_with(|| b_len.cmp(&a_len))
162 });
163
164 results
165}
166
167#[must_use]
171pub fn grid_to_matrix(
172 grid: &[f32],
173 seq_len: usize,
174 num_relations: usize,
175 threshold: f32,
176) -> HandshakingMatrix {
177 let mut cells = Vec::new();
178 for i in 0..seq_len {
179 for j in 0..seq_len {
180 for rel in 0..num_relations {
181 let idx = i * seq_len * num_relations + j * num_relations + rel;
182 if let Some(&score) = grid.get(idx) {
183 if score >= threshold && rel > 0 {
184 cells.push(HandshakingCell {
185 i: i as u32,
186 j: j as u32,
187 label_idx: rel as u16,
188 score,
189 });
190 }
191 }
192 }
193 }
194 }
195 HandshakingMatrix {
196 cells,
197 seq_len,
198 num_labels: num_relations,
199 }
200}
201
/// Greedily keeps spans that do not start before the end of the last kept span.
///
/// `entities` holds `(start, end, score)` triples and is scanned in the given
/// order, so callers are expected to pass it already sorted. A span is kept
/// when its `start` is at or beyond the `end` of the most recently kept span
/// (initially `0`); every other span is dropped.
pub(crate) fn remove_nested(entities: &[(usize, usize, f64)]) -> Vec<(usize, usize, f64)> {
    // End (exclusive) of the most recently kept span.
    let mut frontier = 0;
    entities
        .iter()
        .copied()
        .filter(|&(start, end, _)| {
            let keep = start >= frontier;
            if keep {
                frontier = end;
            }
            keep
        })
        .collect()
}
214
215#[must_use]
217pub fn map_label_to_entity_type(label: &str) -> EntityType {
218 match label.to_uppercase().as_str() {
219 "PER" | "PERSON" => EntityType::Person,
220 "ORG" | "ORGANIZATION" => EntityType::Organization,
221 "LOC" | "LOCATION" | "GPE" => EntityType::Location,
222 "DATE" => EntityType::Date,
223 "TIME" => EntityType::Time,
224 "MONEY" => EntityType::Money,
225 "PERCENT" => EntityType::Percent,
226 "MISC" => EntityType::Other("MISC".to_string()),
227 _ => EntityType::Other(label.to_string()),
228 }
229}
230
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backends::inference::{HandshakingCell, HandshakingMatrix};

    // Builds a single grid cell carrying `rel` at position (i, j).
    fn cell(i: u32, j: u32, rel: W2NERRelation, score: f32) -> HandshakingCell {
        let label_idx = rel.to_index() as u16;
        HandshakingCell {
            i,
            j,
            label_idx,
            score,
        }
    }

    // Wraps cells in a matrix that uses the three W2NER relation labels.
    fn mat(cells: Vec<HandshakingCell>, seq_len: usize) -> HandshakingMatrix {
        HandshakingMatrix {
            cells,
            seq_len,
            num_labels: 3,
        }
    }

    #[test]
    fn decode_single_contiguous_entity() {
        let tokens = ["New", "York", "City"];
        // THW at (tail = 2, head = 0) covers all three tokens.
        let thw = cell(2, 0, W2NERRelation::THW, 0.9);
        let m = mat(vec![thw], 3);

        let result = decode_from_matrix(&m, &tokens, 0, 0.5, true);

        assert_eq!(result.len(), 1);
        // Start is the head token; end is exclusive.
        assert_eq!(result[0].0, 0);
        assert_eq!(result[0].1, 3);
    }

    #[test]
    fn decode_removes_nested_when_disabled() {
        let tokens = ["The", "University", "of", "California"];
        // Two spans sharing a tail: tokens 0..4 and the nested tokens 1..4.
        let outer = cell(3, 0, W2NERRelation::THW, 0.8);
        let inner = cell(3, 1, W2NERRelation::THW, 0.9);
        let m = mat(vec![outer, inner], 4);

        let nested = decode_from_matrix(&m, &tokens, 0, 0.5, true);
        assert_eq!(nested.len(), 2, "should keep both when nested=true");

        let flat = decode_from_matrix(&m, &tokens, 0, 0.5, false);
        assert_eq!(flat.len(), 1, "should keep only outer when nested=false");
    }

    #[test]
    fn decode_discontinuous_splits_on_nnw_gap() {
        let tokens = ["severe", "pain", "in", "abdomen"];
        // The boundary spans tokens 0..=3; NNW links exist for (0, 1) and
        // (2, 3) but not for (1, 2), so the span splits between 1 and 2.
        let cells = vec![
            cell(3, 0, W2NERRelation::THW, 0.8),
            cell(0, 1, W2NERRelation::NNW, 0.8),
            cell(2, 3, W2NERRelation::NNW, 0.8),
        ];
        let m = mat(cells, 4);

        let result = decode_discontinuous_from_matrix(&m, &tokens, 0.5, "SYMPTOM");

        assert_eq!(result.len(), 1);
        let (label, spans, _score) = &result[0];
        assert_eq!(label, "SYMPTOM");
        assert_eq!(
            spans.len(),
            2,
            "expected 2 disjoint segments; got {}",
            spans.len()
        );
        assert_eq!(spans[0], (0, 2));
        assert_eq!(spans[1], (2, 4));
    }

    #[test]
    fn grid_to_matrix_filters_none_and_below_threshold() {
        // Layout is [seq_len = 2, seq_len = 2, relations = 3], flattened.
        let mut grid = vec![0.0f32; 2 * 2 * 3];
        // Flat index 5 is (i = 0, j = 1, rel = 2): above the threshold.
        grid[5] = 0.9;
        // Flat index 4 is (i = 0, j = 1, rel = 1): below the threshold.
        grid[4] = 0.2;

        let m = grid_to_matrix(&grid, 2, 3, 0.5);

        assert_eq!(m.cells.len(), 1);
        assert_eq!(m.cells[0].label_idx, 2);
    }

    #[test]
    fn map_label_person_org_loc() {
        use crate::EntityType;

        assert_eq!(map_label_to_entity_type("PER"), EntityType::Person);
        assert_eq!(map_label_to_entity_type("ORG"), EntityType::Organization);
        assert_eq!(map_label_to_entity_type("GPE"), EntityType::Location);
        // Unknown labels fall back to the catch-all variant.
        assert!(matches!(
            map_label_to_entity_type("CUSTOM"),
            EntityType::Other(_)
        ));
    }
}