Skip to main content

anno/backends/w2ner/
decode.rs

1//! W2NER decoding algorithms.
2//!
3//! Standalone pure functions for decoding word-word relation grids into entity spans.
4//! These are separated from the W2NER model struct so they can be used independently
5//! with pre-computed grids (e.g., from external inference or testing).
6//!
7//! # Algorithm reference
8//! - arXiv:2112.10070 §3.3 (Li et al., "Unified Named Entity Recognition as Word-Word
9//!   Relation Classification", AAAI 2022)
10
11use crate::backends::inference::{HandshakingCell, HandshakingMatrix};
12use crate::EntityType;
13
/// One entity emitted by `decode_discontinuous_from_matrix`.
///
/// Tuple fields, in order:
/// 1. the entity-type label;
/// 2. the entity's word-level segments, each a half-open
///    `(word_start, word_end_exclusive)` pair;
/// 3. the score of the THW boundary cell that produced the entity.
pub type DiscontinuousDecodeRow = (String, Vec<(usize, usize)>, f64);
18
/// Word-word relation labels used in a W2NER grid.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum W2NERRelation {
    /// Next-Neighboring-Word: links two words belonging to the same entity.
    NNW,
    /// Tail-Head-Word: marks an entity boundary, stored at (tail, head).
    THW,
    /// The two words are unrelated.
    None,
}

impl W2NERRelation {
    /// Decode a label index: `1` is NNW, `2` is THW, and every other index
    /// (including `0`) maps to `None`.
    #[must_use]
    pub fn from_index(idx: usize) -> Self {
        match idx {
            1 => Self::NNW,
            2 => Self::THW,
            _ => Self::None,
        }
    }

    /// Encode this relation as its label index; the inverse of
    /// [`Self::from_index`] for indices `0..=2`.
    #[must_use]
    pub fn to_index(self) -> usize {
        match self {
            Self::NNW => 1,
            Self::THW => 2,
            Self::None => 0,
        }
    }
}
52
53// =============================================================================
54// Decode algorithms
55// =============================================================================
56
57/// Decode contiguous entity spans from a handshaking matrix.
58///
59/// Finds all THW(tail, head) cells above `threshold`, sorts by start position,
60/// and optionally removes nested spans (outermost wins when `allow_nested` is false).
61///
62/// Returns `Vec<(word_start, word_end_exclusive, score)>`.
63#[must_use]
64pub fn decode_from_matrix(
65    matrix: &HandshakingMatrix,
66    tokens: &[&str],
67    entity_type_idx: usize,
68    threshold: f32,
69    allow_nested: bool,
70) -> Vec<(usize, usize, f64)> {
71    let mut entities = Vec::with_capacity(16);
72
73    for cell in &matrix.cells {
74        let relation = W2NERRelation::from_index(cell.label_idx as usize);
75        if relation == W2NERRelation::THW && cell.score >= threshold {
76            let tail = cell.i as usize;
77            let head = cell.j as usize;
78            if head <= tail && head < tokens.len() && tail < tokens.len() {
79                entities.push((head, tail + 1, cell.score as f64));
80            }
81        }
82    }
83
84    entities.sort_unstable_by(|a, b| a.0.cmp(&b.0).then_with(|| (b.1 - b.0).cmp(&(a.1 - a.0))));
85
86    if !allow_nested {
87        entities = remove_nested(&entities);
88    }
89
90    let _ = entity_type_idx;
91    entities
92}
93
94/// Decode discontinuous entity spans using the full NNW+THW algorithm (§3.3).
95///
96/// THW cells identify entity boundaries; NNW cells identify adjacent-word connections
97/// within the same entity. Gaps in the NNW chain produce disjoint sub-spans.
98///
99/// `first_label` is used as the entity-type string when the model uses a single grid;
100/// pass an empty string to fall back to `"ENTITY"`.
101#[must_use]
102pub fn decode_discontinuous_from_matrix(
103    matrix: &HandshakingMatrix,
104    tokens: &[&str],
105    threshold: f32,
106    first_label: &str,
107) -> Vec<DiscontinuousDecodeRow> {
108    let n = tokens.len();
109
110    let mut entity_boundaries: Vec<(usize, usize, f64)> = Vec::new();
111    for cell in &matrix.cells {
112        if W2NERRelation::from_index(cell.label_idx as usize) == W2NERRelation::THW
113            && cell.score >= threshold
114        {
115            let tail = cell.i as usize;
116            let head = cell.j as usize;
117            if head <= tail && tail < n {
118                entity_boundaries.push((head, tail, cell.score as f64));
119            }
120        }
121    }
122
123    let mut nnw: std::collections::HashSet<(usize, usize)> = std::collections::HashSet::new();
124    for cell in &matrix.cells {
125        if W2NERRelation::from_index(cell.label_idx as usize) == W2NERRelation::NNW
126            && cell.score >= threshold
127        {
128            let a = cell.i as usize;
129            let b = cell.j as usize;
130            nnw.insert((a, b));
131            nnw.insert((b, a));
132        }
133    }
134
135    let mut results: Vec<DiscontinuousDecodeRow> = Vec::new();
136    let type_label = if first_label.is_empty() {
137        "ENTITY".to_string()
138    } else {
139        first_label.to_string()
140    };
141
142    for (head, tail, score) in entity_boundaries {
143        let mut segments: Vec<(usize, usize)> = Vec::new();
144        let mut seg_start = head;
145        for i in head..tail {
146            let j = i + 1;
147            if !nnw.contains(&(i, j)) {
148                segments.push((seg_start, i + 1));
149                seg_start = j;
150            }
151        }
152        segments.push((seg_start, tail + 1));
153        results.push((type_label.clone(), segments, score));
154    }
155
156    results.sort_unstable_by(|a, b| {
157        let a_start = a.1.first().map(|s| s.0).unwrap_or(usize::MAX);
158        let b_start = b.1.first().map(|s| s.0).unwrap_or(usize::MAX);
159        let a_len: usize = a.1.iter().map(|(s, e)| e - s).sum();
160        let b_len: usize = b.1.iter().map(|(s, e)| e - s).sum();
161        a_start.cmp(&b_start).then_with(|| b_len.cmp(&a_len))
162    });
163
164    results
165}
166
167/// Convert a dense `[seq_len × seq_len × num_relations]` grid to a sparse matrix.
168///
169/// Cells with `rel == 0` (None relation) or score below `threshold` are dropped.
170#[must_use]
171pub fn grid_to_matrix(
172    grid: &[f32],
173    seq_len: usize,
174    num_relations: usize,
175    threshold: f32,
176) -> HandshakingMatrix {
177    let mut cells = Vec::new();
178    for i in 0..seq_len {
179        for j in 0..seq_len {
180            for rel in 0..num_relations {
181                let idx = i * seq_len * num_relations + j * num_relations + rel;
182                if let Some(&score) = grid.get(idx) {
183                    if score >= threshold && rel > 0 {
184                        cells.push(HandshakingCell {
185                            i: i as u32,
186                            j: j as u32,
187                            label_idx: rel as u16,
188                            score,
189                        });
190                    }
191                }
192            }
193        }
194    }
195    HandshakingMatrix {
196        cells,
197        seq_len,
198        num_labels: num_relations,
199    }
200}
201
/// Greedily keep non-overlapping spans, scanning `entities` in the given order.
///
/// A span is kept when its start is at or after the end of the most recently kept
/// span; otherwise it is dropped. This drops partially overlapping spans as well as
/// strictly nested ones.
///
/// The input is expected to be sorted by start ascending and, for equal starts, by
/// length descending (as `decode_from_matrix` produces). Under that order the
/// outermost span at each start wins. Other orders are not re-sorted.
pub(crate) fn remove_nested(entities: &[(usize, usize, f64)]) -> Vec<(usize, usize, f64)> {
    let mut frontier = 0usize;
    entities
        .iter()
        .copied()
        .filter(|&(start, end, _)| {
            let keep = start >= frontier;
            if keep {
                frontier = end;
            }
            keep
        })
        .collect()
}
214
215/// Map a label string to the canonical `EntityType`.
216#[must_use]
217pub fn map_label_to_entity_type(label: &str) -> EntityType {
218    match label.to_uppercase().as_str() {
219        "PER" | "PERSON" => EntityType::Person,
220        "ORG" | "ORGANIZATION" => EntityType::Organization,
221        "LOC" | "LOCATION" | "GPE" => EntityType::Location,
222        "DATE" => EntityType::Date,
223        "TIME" => EntityType::Time,
224        "MONEY" => EntityType::Money,
225        "PERCENT" => EntityType::Percent,
226        "MISC" => EntityType::Other("MISC".to_string()),
227        _ => EntityType::Other(label.to_string()),
228    }
229}
230
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backends::inference::{HandshakingCell, HandshakingMatrix};

    /// Build a sparse cell carrying `rel`'s label index.
    fn cell(i: u32, j: u32, rel: W2NERRelation, score: f32) -> HandshakingCell {
        HandshakingCell {
            i,
            j,
            label_idx: rel.to_index() as u16,
            score,
        }
    }

    /// Build a three-label (None/NNW/THW) matrix from `cells`.
    fn mat(cells: Vec<HandshakingCell>, seq_len: usize) -> HandshakingMatrix {
        HandshakingMatrix {
            cells,
            seq_len,
            num_labels: 3,
        }
    }

    #[test]
    fn relation_index_round_trip_and_fallback() {
        for &rel in &[W2NERRelation::None, W2NERRelation::NNW, W2NERRelation::THW] {
            assert_eq!(W2NERRelation::from_index(rel.to_index()), rel);
        }
        // Unknown indices fall back to None rather than panicking.
        assert_eq!(W2NERRelation::from_index(7), W2NERRelation::None);
    }

    #[test]
    fn decode_single_contiguous_entity() {
        // THW(tail=2, head=0) → entity spans words 0..=2
        let tokens = ["New", "York", "City"];
        let m = mat(vec![cell(2, 0, W2NERRelation::THW, 0.9)], 3);
        let result = decode_from_matrix(&m, &tokens, 0, 0.5, true);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].0, 0); // word start
        assert_eq!(result[0].1, 3); // word end (exclusive)
    }

    #[test]
    fn decode_drops_invalid_and_low_score_cells() {
        let tokens = ["a", "b", "c"];
        let m = mat(
            vec![
                cell(1, 0, W2NERRelation::THW, 0.4), // below threshold
                cell(0, 2, W2NERRelation::THW, 0.9), // head after tail
                cell(5, 0, W2NERRelation::THW, 0.9), // tail out of range
                cell(1, 0, W2NERRelation::NNW, 0.9), // not a boundary cell
                cell(2, 2, W2NERRelation::THW, 0.7), // valid single-word entity
            ],
            3,
        );
        let result = decode_from_matrix(&m, &tokens, 0, 0.5, true);
        assert_eq!(result.len(), 1);
        assert_eq!((result[0].0, result[0].1), (2, 3));
    }

    #[test]
    fn decode_removes_nested_when_disabled() {
        let tokens = ["The", "University", "of", "California"];
        // outer: THW(3,0), inner: THW(3,1)
        let m = mat(
            vec![
                cell(3, 0, W2NERRelation::THW, 0.8),
                cell(3, 1, W2NERRelation::THW, 0.9),
            ],
            4,
        );
        let nested = decode_from_matrix(&m, &tokens, 0, 0.5, true);
        assert_eq!(nested.len(), 2, "should keep both when nested=true");

        let flat = decode_from_matrix(&m, &tokens, 0, 0.5, false);
        assert_eq!(flat.len(), 1, "should keep only outer when nested=false");
    }

    #[test]
    fn remove_nested_drops_partial_overlaps_too() {
        let spans = [(0, 3, 0.9), (2, 5, 0.8), (5, 6, 0.7)];
        let kept = remove_nested(&spans);
        assert_eq!(kept, vec![(0, 3, 0.9), (5, 6, 0.7)]);
    }

    #[test]
    fn decode_discontinuous_splits_on_nnw_gap() {
        // Entity: head=0, tail=3, but no NNW between words 1-2 → two segments
        let tokens = ["severe", "pain", "in", "abdomen"];
        let m = mat(
            vec![
                cell(3, 0, W2NERRelation::THW, 0.8),
                cell(0, 1, W2NERRelation::NNW, 0.8),
                // no NNW between 1-2
                cell(2, 3, W2NERRelation::NNW, 0.8),
            ],
            4,
        );
        let result = decode_discontinuous_from_matrix(&m, &tokens, 0.5, "SYMPTOM");
        assert_eq!(result.len(), 1);
        let (label, spans, _score) = &result[0];
        assert_eq!(label, "SYMPTOM");
        assert_eq!(
            spans.len(),
            2,
            "expected 2 disjoint segments; got {}",
            spans.len()
        );
        assert_eq!(spans[0], (0, 2)); // words 0-1
        assert_eq!(spans[1], (2, 4)); // words 2-3
    }

    #[test]
    fn decode_discontinuous_single_word_and_fallback_label() {
        let tokens = ["aspirin"];
        let m = mat(vec![cell(0, 0, W2NERRelation::THW, 0.9)], 1);
        let result = decode_discontinuous_from_matrix(&m, &tokens, 0.5, "");
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].0, "ENTITY");
        assert_eq!(result[0].1, vec![(0, 1)]);
    }

    #[test]
    fn decode_discontinuous_accepts_either_nnw_orientation() {
        let tokens = ["New", "York", "City"];
        let m = mat(
            vec![
                cell(2, 0, W2NERRelation::THW, 0.9),
                cell(0, 1, W2NERRelation::NNW, 0.9),
                cell(2, 1, W2NERRelation::NNW, 0.9), // stored as (2, 1), still links 1-2
            ],
            3,
        );
        let result = decode_discontinuous_from_matrix(&m, &tokens, 0.5, "LOC");
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].1, vec![(0, 3)]);
    }

    #[test]
    fn decode_discontinuous_orders_by_start_then_length() {
        let tokens = ["a", "b", "c", "d"];
        let m = mat(
            vec![
                cell(3, 2, W2NERRelation::THW, 0.9),
                cell(1, 0, W2NERRelation::THW, 0.9),
                cell(3, 0, W2NERRelation::THW, 0.9),
            ],
            4,
        );
        let result = decode_discontinuous_from_matrix(&m, &tokens, 0.5, "X");
        let start_and_len: Vec<(usize, usize)> = result
            .iter()
            .map(|(_, segs, _)| (segs[0].0, segs.iter().map(|(s, e)| e - s).sum::<usize>()))
            .collect();
        assert_eq!(start_and_len, vec![(0, 4), (0, 2), (2, 2)]);
    }

    #[test]
    fn grid_to_matrix_filters_none_and_below_threshold() {
        // 2×2×3 grid: only rel=2 (THW) at (0,1) with score 0.9 should survive
        let mut grid = vec![0.0f32; 2 * 2 * 3];
        grid[5] = 0.9; // (i=0,j=1,rel=2)
        grid[4] = 0.2; // below threshold
        let m = grid_to_matrix(&grid, 2, 3, 0.5);
        assert_eq!(m.cells.len(), 1);
        assert_eq!(m.cells[0].label_idx, 2);
    }

    #[test]
    fn map_label_person_org_loc() {
        use crate::EntityType;
        assert_eq!(map_label_to_entity_type("PER"), EntityType::Person);
        assert_eq!(map_label_to_entity_type("ORG"), EntityType::Organization);
        assert_eq!(map_label_to_entity_type("GPE"), EntityType::Location);
        assert!(matches!(
            map_label_to_entity_type("CUSTOM"),
            EntityType::Other(_)
        ));
    }
}