// anno/backends/inference/span.rs
//! Span representation types and the handshaking matrix for joint extraction.
//!
//! - `Span`, `SpanCandidate`, `SpanWindow`: character-offset span types
//! - `HandshakingMatrix`, `HandshakingCell`: sparse grid for W2NER / TPLinker

6use super::registry::{LabelCategory, LabelDefinition, SemanticRegistry};
7use anno_core::{RaggedBatch, SpanCandidate};
8
// Span Representation
// =============================================================================
11
/// Configuration for span representation.
///
/// # Research Context (Deep Span Representations, arXiv:2210.04182)
///
/// From "Deep Span Representations for NER":
/// > "Existing span-based NER systems **shallowly aggregate** the token
/// > representations to span representations. However, this typically results
/// > in significant ineffectiveness for **long-span entities**."
///
/// Common span representation strategies:
///
/// | Method | Formula | Pros | Cons |
/// |--------|---------|------|------|
/// | Concat | [h_i; h_j] | Simple, fast | Ignores middle tokens |
/// | Pooling | mean(h_i:h_j) | Uses all tokens | Loses boundary info |
/// | Attention | attn(h_i:h_j) | Learnable | Expensive |
/// | GLiNER | FFN([h_i; h_j; w]) | Balanced | Requires width emb |
///
/// # Recommendation (GLiNER Default)
///
/// For most use cases, concatenating first + last token embeddings with
/// a width embedding gives the best tradeoff:
/// - O(N) complexity (vs O(N²) for all-pairs attention)
/// - Captures boundary positions (critical for NER)
/// - Width embedding disambiguates "I" vs "New York City"
#[derive(Debug, Clone)]
pub struct SpanRepConfig {
    /// Hidden dimension of the encoder.
    pub hidden_dim: usize,
    /// Maximum span width, in tokens.
    ///
    /// GLiNER uses K=12: "to keep linear complexity without harming recall."
    /// Wider spans rarely contain coherent entities.
    pub max_width: usize,
    /// Whether to add a learned width embedding to each span.
    ///
    /// Critical for distinguishing spans of different lengths that share
    /// similar boundary tokens.
    pub use_width_embeddings: bool,
    /// Width embedding dimension (typically `hidden_dim / 4`).
    pub width_emb_dim: usize,
}

impl Default for SpanRepConfig {
    /// Defaults: 768-dim encoder, GLiNER's K=12 span cap, width embeddings
    /// enabled and sized at a quarter of the hidden dimension.
    fn default() -> Self {
        let hidden_dim = 768;
        Self {
            hidden_dim,
            max_width: 12,
            use_width_embeddings: true,
            width_emb_dim: hidden_dim / 4, // 192
        }
    }
}
65
/// Computes span representations from token embeddings.
///
/// # Research Alignment (GLiNER, NAACL 2024)
///
/// From the GLiNER paper (arXiv:2311.08526):
/// > "The representation of a span starting at position i and ending at
/// > position j in the input text, S_ij ∈ R^D, is computed as:
/// > **S_ij = FFN(h_i ⊗ h_j)**
/// > where FFN denotes a two-layer feedforward network, and ⊗ represents
/// > the concatenation operation."
///
/// The paper also notes:
/// > "We set an upper bound to the length (K=12) of the span in order to
/// > keep linear complexity in the size of the input text, without harming recall."
///
/// # Span Representation Formula
///
/// ```text
/// span_emb = FFN(Concat(token[i], token[j], width_emb[j-i]))
///          = W_2 · ReLU(W_1 · [h_i; h_j; w_{j-i}] + b_1) + b_2
/// ```
///
/// where:
/// - h_i = start token embedding
/// - h_j = end token embedding
/// - w_{j-i} = learned width embedding (captures span length)
///
/// This is the "gnarly bit" from GLiNER that enables zero-shot matching.
///
/// # Alternative: Global Pointer (arXiv:2208.03054)
///
/// Instead of enumerating spans, Global Pointer uses RoPE (rotary position
/// embeddings) to predict (start, end) pairs simultaneously:
///
/// ```text
/// score(i, j) = q_i^T * k_j    (where q, k have RoPE applied)
/// ```
///
/// Advantages:
/// - No explicit span enumeration needed
/// - Naturally handles nested entities
/// - More parameter-efficient
///
/// GLiNER-style enumeration is still preferred for zero-shot because
/// it allows pre-computing label embeddings.
pub struct SpanRepresentationLayer {
    /// Configuration (dims, max span width, width-embedding switch)
    pub config: SpanRepConfig,
    /// Projection weights, flattened `[input_dim, hidden_dim]` where
    /// `input_dim = 2 * hidden_dim + width_emb_dim` (see `new`)
    pub projection_weights: Vec<f32>,
    /// Projection bias: `[hidden_dim]`
    pub projection_bias: Vec<f32>,
    /// Width embeddings, flattened `[max_width, width_emb_dim]`
    pub width_embeddings: Vec<f32>,
}
121
122impl SpanRepresentationLayer {
123    /// Create a new span representation layer with random initialization.
124    pub fn new(config: SpanRepConfig) -> Self {
125        let input_dim = config.hidden_dim * 2 + config.width_emb_dim;
126
127        Self {
128            projection_weights: vec![0.0f32; input_dim * config.hidden_dim],
129            projection_bias: vec![0.0f32; config.hidden_dim],
130            width_embeddings: vec![0.0f32; config.max_width * config.width_emb_dim],
131            config,
132        }
133    }
134
135    /// Compute span representations from token embeddings.
136    ///
137    /// # Arguments
138    /// * `token_embeddings` - Flattened [num_tokens, hidden_dim]
139    /// * `candidates` - Span candidates with start/end indices
140    ///
141    /// # Returns
142    /// Span embeddings: [num_candidates, hidden_dim]
143    pub fn forward(
144        &self,
145        token_embeddings: &[f32],
146        candidates: &[SpanCandidate],
147        batch: &RaggedBatch,
148    ) -> Vec<f32> {
149        let hidden_dim = self.config.hidden_dim;
150        let width_emb_dim = self.config.width_emb_dim;
151        let max_width = self.config.max_width;
152
153        // Check for overflow in allocation
154        let total_elements = match candidates.len().checked_mul(hidden_dim) {
155            Some(v) => v,
156            None => {
157                log::warn!(
158                    "Span embedding allocation overflow: {} candidates * {} hidden_dim, returning empty",
159                    candidates.len(), hidden_dim
160                );
161                return vec![];
162            }
163        };
164        let mut span_embeddings = vec![0.0f32; total_elements];
165
166        for (span_idx, candidate) in candidates.iter().enumerate() {
167            // Get document token range
168            let doc_range = match batch.doc_range(candidate.doc_idx as usize) {
169                Some(r) => r,
170                None => continue,
171            };
172
173            // Validate span before computing global indices
174            if candidate.end <= candidate.start {
175                log::warn!(
176                    "Invalid span candidate: end ({}) <= start ({})",
177                    candidate.end,
178                    candidate.start
179                );
180                continue;
181            }
182
183            // Global token indices
184            let start_global = doc_range.start + candidate.start as usize;
185            let end_global = doc_range.start + (candidate.end as usize) - 1; // Safe now that we validated
186
187            // Bounds check - must ensure both start and end slices fit
188            // Use checked arithmetic to prevent overflow
189            let start_byte = match start_global.checked_mul(hidden_dim) {
190                Some(v) => v,
191                None => {
192                    log::warn!(
193                        "Token index overflow: start_global={} * hidden_dim={}",
194                        start_global,
195                        hidden_dim
196                    );
197                    continue;
198                }
199            };
200            let start_end_byte = match (start_global + 1).checked_mul(hidden_dim) {
201                Some(v) => v,
202                None => {
203                    log::warn!(
204                        "Token index overflow: (start_global+1)={} * hidden_dim={}",
205                        start_global + 1,
206                        hidden_dim
207                    );
208                    continue;
209                }
210            };
211            let end_byte = match end_global.checked_mul(hidden_dim) {
212                Some(v) => v,
213                None => {
214                    log::warn!(
215                        "Token index overflow: end_global={} * hidden_dim={}",
216                        end_global,
217                        hidden_dim
218                    );
219                    continue;
220                }
221            };
222            let end_end_byte = match (end_global + 1).checked_mul(hidden_dim) {
223                Some(v) => v,
224                None => {
225                    log::warn!(
226                        "Token index overflow: (end_global+1)={} * hidden_dim={}",
227                        end_global + 1,
228                        hidden_dim
229                    );
230                    continue;
231                }
232            };
233
234            if start_byte >= token_embeddings.len()
235                || start_end_byte > token_embeddings.len()
236                || end_byte >= token_embeddings.len()
237                || end_end_byte > token_embeddings.len()
238            {
239                continue;
240            }
241
242            // Get start and end token embeddings
243            let start_emb = &token_embeddings[start_byte..start_end_byte];
244            let end_emb = &token_embeddings[end_byte..end_end_byte];
245
246            // Optional width embedding (index = span_len - 1).
247            let width_emb = if self.config.use_width_embeddings && width_emb_dim > 0 {
248                let max_width_idx = max_width.saturating_sub(1);
249                let span_len = candidate.width() as usize;
250                let width_idx = span_len.saturating_sub(1).min(max_width_idx);
251
252                let width_start = width_idx.saturating_mul(width_emb_dim);
253                let width_end = width_start.saturating_add(width_emb_dim);
254                if width_end > self.width_embeddings.len() {
255                    None
256                } else {
257                    Some(&self.width_embeddings[width_start..width_end])
258                }
259            } else {
260                None
261            };
262
263            // Baseline span representation: average of boundary embeddings (+ optional width signal).
264            // This is deterministic and works without learned projection weights.
265            let output_start = span_idx * hidden_dim;
266            for h in 0..hidden_dim {
267                span_embeddings[output_start + h] = (start_emb[h] + end_emb[h]) * 0.5;
268                if let Some(width_emb) = width_emb {
269                    if h < width_emb_dim {
270                        span_embeddings[output_start + h] += width_emb[h] * 0.1;
271                    }
272                }
273            }
274        }
275
276        span_embeddings
277    }
278}
279
// =============================================================================
// Handshaking Matrix (TPLinker-style Joint Extraction)
// =============================================================================
283
/// Result cell in a handshaking matrix.
///
/// One entry in the sparse grid: only cells whose score cleared the decoding
/// threshold in `HandshakingMatrix::from_dense` are materialized.
#[derive(Debug, Clone, Copy)]
pub struct HandshakingCell {
    /// Row index (token i)
    pub i: u32,
    /// Column index (token j)
    pub j: u32,
    /// Predicted label index (into the label set used at decode time)
    pub label_idx: u16,
    /// Confidence score for this (i, j, label) assignment
    pub score: f32,
}
296
/// Handshaking matrix for joint entity-relation extraction.
///
/// # Research Alignment (W2NER, AAAI 2022)
///
/// From the W2NER paper (arXiv:2112.10070):
/// > "We present a novel alternative by modeling the unified NER as word-word
/// > relation classification, namely W2NER. The architecture resolves the kernel
/// > bottleneck of unified NER by effectively modeling the neighboring relations
/// > between entity words with **Next-Neighboring-Word (NNW)** and
/// > **Tail-Head-Word-* (THW-*)** relations."
///
/// In TPLinker/W2NER, we don't just tag tokens - we tag token PAIRS.
/// The matrix M\[i,j\] contains the label for the span (i, j).
///
/// # Key Relations
///
/// | Relation | Description | Purpose |
/// |----------|-------------|---------|
/// | NNW | Next-Neighboring-Word | Links adjacent tokens within entity |
/// | THW-* | Tail-Head-Word | Links end of one entity to start of next |
///
/// # Benefits
///
/// - Overlapping entities (same token in multiple spans)
/// - Joint entity-relation extraction in one pass
/// - Explicit boundary modeling
/// - Handles flat, nested, AND discontinuous NER in one model
pub struct HandshakingMatrix {
    /// Non-zero cells (sparse representation; only above-threshold scores)
    pub cells: Vec<HandshakingCell>,
    /// Sequence length (the matrix is conceptually `seq_len x seq_len`)
    pub seq_len: usize,
    /// Number of labels scored per cell
    pub num_labels: usize,
}
332
333impl HandshakingMatrix {
334    /// Create from dense scores with thresholding.
335    ///
336    /// # Arguments
337    /// * `scores` - Dense [seq_len, seq_len, num_labels] scores
338    /// * `threshold` - Minimum score to keep
339    pub fn from_dense(scores: &[f32], seq_len: usize, num_labels: usize, threshold: f32) -> Self {
340        // Performance: Pre-allocate cells vec with estimated capacity
341        // Most matrices have sparse cells (only high-scoring ones), so we estimate conservatively
342        let estimated_capacity = (seq_len * seq_len / 10).min(1000); // ~10% of cells typically pass threshold
343        let mut cells = Vec::with_capacity(estimated_capacity);
344
345        for i in 0..seq_len {
346            for j in i..seq_len {
347                // Upper triangular (i <= j)
348                for l in 0..num_labels {
349                    let idx = i * seq_len * num_labels + j * num_labels + l;
350                    if idx < scores.len() {
351                        let score = scores[idx];
352                        if score >= threshold {
353                            cells.push(HandshakingCell {
354                                i: i as u32,
355                                j: j as u32,
356                                label_idx: l as u16,
357                                score,
358                            });
359                        }
360                    }
361                }
362            }
363        }
364
365        Self {
366            cells,
367            seq_len,
368            num_labels,
369        }
370    }
371
372    /// Decode entities from handshaking matrix.
373    ///
374    /// In W2NER convention, cell (i, j) represents a span where:
375    /// - j is the start token index
376    /// - i is the end token index (inclusive, so we add 1 for exclusive end)
377    pub fn decode_entities<'a>(
378        &self,
379        registry: &'a SemanticRegistry,
380    ) -> Vec<(SpanCandidate, &'a LabelDefinition, f32)> {
381        let mut entities = Vec::new();
382
383        for cell in &self.cells {
384            if let Some(label) = registry.labels.get(cell.label_idx as usize) {
385                if label.category == LabelCategory::Entity {
386                    // W2NER: j=start, i=end (inclusive), so span is [j, i+1)
387                    entities.push((SpanCandidate::new(0, cell.j, cell.i + 1), label, cell.score));
388                }
389            }
390        }
391
392        // Performance: Use unstable sort (we don't need stable sort here)
393        // Sort by position, then by score (descending)
394        entities.sort_unstable_by(|a, b| {
395            a.0.start
396                .cmp(&b.0.start)
397                .then_with(|| a.0.end.cmp(&b.0.end))
398                .then_with(|| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal))
399        });
400
401        // Performance: Pre-allocate kept vec with estimated capacity
402        // Non-maximum suppression
403        let mut kept = Vec::with_capacity(entities.len().min(32));
404        for (span, label, score) in entities {
405            let overlaps = kept.iter().any(|(s, _, _): &(SpanCandidate, _, _)| {
406                !(span.end <= s.start || s.end <= span.start)
407            });
408            if !overlaps {
409                kept.push((span, label, score));
410            }
411        }
412
413        kept
414    }
415}
416
417// =============================================================================