// anno/backends/inference/span.rs
1//! Span representation types and the handshaking matrix for joint extraction.
2//!
3//! - `Span`, `SpanCandidate`, `SpanWindow`: character-offset span types
4//! - `HandshakingMatrix`, `HandshakingCell`: sparse grid for W2NER / TPLinker
5
6use super::registry::{LabelCategory, LabelDefinition, SemanticRegistry};
7use anno_core::{RaggedBatch, SpanCandidate};
8
9// Span Representation
10// =============================================================================
11
/// Configuration for span representation.
///
/// # Research Context (Deep Span Representations, arXiv:2210.04182)
///
/// From "Deep Span Representations for NER":
/// > "Existing span-based NER systems **shallowly aggregate** the token
/// > representations to span representations. However, this typically results
/// > in significant ineffectiveness for **long-span entities**."
///
/// Common span representation strategies:
///
/// | Method | Formula | Pros | Cons |
/// |--------|---------|------|------|
/// | Concat | \[h_i; h_j\] | Simple, fast | Ignores middle tokens |
/// | Pooling | mean(h_i:h_j) | Uses all tokens | Loses boundary info |
/// | Attention | attn(h_i:h_j) | Learnable | Expensive |
/// | GLiNER | FFN(\[h_i; h_j; w\]) | Balanced | Requires width emb |
///
/// # Recommendation (GLiNER Default)
///
/// For most use cases, concatenating first + last token embeddings with
/// a width embedding provides the best tradeoff:
/// - O(N) complexity (vs O(N²) for all-pairs attention)
/// - Captures boundary positions (critical for NER)
/// - Width embedding disambiguates "I" vs "New York City"
#[derive(Debug, Clone)]
pub struct SpanRepConfig {
    /// Hidden dimension of the encoder (per-token embedding size).
    pub hidden_dim: usize,
    /// Maximum span width (in tokens); also the number of rows in the
    /// width-embedding table.
    ///
    /// GLiNER uses K=12: "to keep linear complexity without harming recall."
    /// Wider spans rarely contain coherent entities.
    pub max_width: usize,
    /// Whether to include width embeddings.
    ///
    /// Critical for distinguishing spans of different lengths
    /// with similar boundary tokens.
    pub use_width_embeddings: bool,
    /// Width embedding dimension (typically `hidden_dim / 4`).
    pub width_emb_dim: usize,
}
54
55impl Default for SpanRepConfig {
56 fn default() -> Self {
57 Self {
58 hidden_dim: 768,
59 max_width: 12,
60 use_width_embeddings: true,
61 width_emb_dim: 192, // 768 / 4
62 }
63 }
64}
65
/// Computes span representations from token embeddings.
///
/// # Research Alignment (GLiNER, NAACL 2024)
///
/// From the GLiNER paper (arXiv:2311.08526):
/// > "The representation of a span starting at position i and ending at
/// > position j in the input text, S_ij ∈ R^D, is computed as:
/// > **S_ij = FFN(h_i ⊗ h_j)**
/// > where FFN denotes a two-layer feedforward network, and ⊗ represents
/// > the concatenation operation."
///
/// The paper also notes:
/// > "We set an upper bound to the length (K=12) of the span in order to
/// > keep linear complexity in the size of the input text, without harming recall."
///
/// # Span Representation Formula
///
/// ```text
/// span_emb = FFN(Concat(token[i], token[j], width_emb[j-i]))
///          = W_2 · ReLU(W_1 · [h_i; h_j; w_{j-i}] + b_1) + b_2
/// ```
///
/// where:
/// - h_i = start token embedding
/// - h_j = end token embedding
/// - w_{j-i} = learned width embedding (captures span length)
///
/// This is the "gnarly bit" from GLiNER that enables zero-shot matching.
///
/// # Alternative: Global Pointer (arXiv:2208.03054)
///
/// Instead of enumerating spans, Global Pointer uses RoPE (rotary position
/// embeddings) to predict (start, end) pairs simultaneously:
///
/// ```text
/// score(i, j) = q_i^T * k_j   (where q, k have RoPE applied)
/// ```
///
/// Advantages:
/// - No explicit span enumeration needed
/// - Naturally handles nested entities
/// - More parameter-efficient
///
/// GLiNER-style enumeration is still preferred for zero-shot because
/// it allows pre-computing label embeddings.
pub struct SpanRepresentationLayer {
    /// Configuration
    pub config: SpanRepConfig,
    /// Projection weights: \[input_dim, hidden_dim\]
    ///
    /// NOTE(review): allocated but not applied by the current `forward`
    /// baseline, which averages boundary embeddings instead.
    pub projection_weights: Vec<f32>,
    /// Projection bias: \[hidden_dim\]
    ///
    /// NOTE(review): allocated but not applied by the current `forward`
    /// baseline.
    pub projection_bias: Vec<f32>,
    /// Width embeddings: \[max_width, width_emb_dim\]
    pub width_embeddings: Vec<f32>,
}
121
122impl SpanRepresentationLayer {
123 /// Create a new span representation layer with random initialization.
124 pub fn new(config: SpanRepConfig) -> Self {
125 let input_dim = config.hidden_dim * 2 + config.width_emb_dim;
126
127 Self {
128 projection_weights: vec![0.0f32; input_dim * config.hidden_dim],
129 projection_bias: vec![0.0f32; config.hidden_dim],
130 width_embeddings: vec![0.0f32; config.max_width * config.width_emb_dim],
131 config,
132 }
133 }
134
135 /// Compute span representations from token embeddings.
136 ///
137 /// # Arguments
138 /// * `token_embeddings` - Flattened [num_tokens, hidden_dim]
139 /// * `candidates` - Span candidates with start/end indices
140 ///
141 /// # Returns
142 /// Span embeddings: [num_candidates, hidden_dim]
143 pub fn forward(
144 &self,
145 token_embeddings: &[f32],
146 candidates: &[SpanCandidate],
147 batch: &RaggedBatch,
148 ) -> Vec<f32> {
149 let hidden_dim = self.config.hidden_dim;
150 let width_emb_dim = self.config.width_emb_dim;
151 let max_width = self.config.max_width;
152
153 // Check for overflow in allocation
154 let total_elements = match candidates.len().checked_mul(hidden_dim) {
155 Some(v) => v,
156 None => {
157 log::warn!(
158 "Span embedding allocation overflow: {} candidates * {} hidden_dim, returning empty",
159 candidates.len(), hidden_dim
160 );
161 return vec![];
162 }
163 };
164 let mut span_embeddings = vec![0.0f32; total_elements];
165
166 for (span_idx, candidate) in candidates.iter().enumerate() {
167 // Get document token range
168 let doc_range = match batch.doc_range(candidate.doc_idx as usize) {
169 Some(r) => r,
170 None => continue,
171 };
172
173 // Validate span before computing global indices
174 if candidate.end <= candidate.start {
175 log::warn!(
176 "Invalid span candidate: end ({}) <= start ({})",
177 candidate.end,
178 candidate.start
179 );
180 continue;
181 }
182
183 // Global token indices
184 let start_global = doc_range.start + candidate.start as usize;
185 let end_global = doc_range.start + (candidate.end as usize) - 1; // Safe now that we validated
186
187 // Bounds check - must ensure both start and end slices fit
188 // Use checked arithmetic to prevent overflow
189 let start_byte = match start_global.checked_mul(hidden_dim) {
190 Some(v) => v,
191 None => {
192 log::warn!(
193 "Token index overflow: start_global={} * hidden_dim={}",
194 start_global,
195 hidden_dim
196 );
197 continue;
198 }
199 };
200 let start_end_byte = match (start_global + 1).checked_mul(hidden_dim) {
201 Some(v) => v,
202 None => {
203 log::warn!(
204 "Token index overflow: (start_global+1)={} * hidden_dim={}",
205 start_global + 1,
206 hidden_dim
207 );
208 continue;
209 }
210 };
211 let end_byte = match end_global.checked_mul(hidden_dim) {
212 Some(v) => v,
213 None => {
214 log::warn!(
215 "Token index overflow: end_global={} * hidden_dim={}",
216 end_global,
217 hidden_dim
218 );
219 continue;
220 }
221 };
222 let end_end_byte = match (end_global + 1).checked_mul(hidden_dim) {
223 Some(v) => v,
224 None => {
225 log::warn!(
226 "Token index overflow: (end_global+1)={} * hidden_dim={}",
227 end_global + 1,
228 hidden_dim
229 );
230 continue;
231 }
232 };
233
234 if start_byte >= token_embeddings.len()
235 || start_end_byte > token_embeddings.len()
236 || end_byte >= token_embeddings.len()
237 || end_end_byte > token_embeddings.len()
238 {
239 continue;
240 }
241
242 // Get start and end token embeddings
243 let start_emb = &token_embeddings[start_byte..start_end_byte];
244 let end_emb = &token_embeddings[end_byte..end_end_byte];
245
246 // Optional width embedding (index = span_len - 1).
247 let width_emb = if self.config.use_width_embeddings && width_emb_dim > 0 {
248 let max_width_idx = max_width.saturating_sub(1);
249 let span_len = candidate.width() as usize;
250 let width_idx = span_len.saturating_sub(1).min(max_width_idx);
251
252 let width_start = width_idx.saturating_mul(width_emb_dim);
253 let width_end = width_start.saturating_add(width_emb_dim);
254 if width_end > self.width_embeddings.len() {
255 None
256 } else {
257 Some(&self.width_embeddings[width_start..width_end])
258 }
259 } else {
260 None
261 };
262
263 // Baseline span representation: average of boundary embeddings (+ optional width signal).
264 // This is deterministic and works without learned projection weights.
265 let output_start = span_idx * hidden_dim;
266 for h in 0..hidden_dim {
267 span_embeddings[output_start + h] = (start_emb[h] + end_emb[h]) * 0.5;
268 if let Some(width_emb) = width_emb {
269 if h < width_emb_dim {
270 span_embeddings[output_start + h] += width_emb[h] * 0.1;
271 }
272 }
273 }
274 }
275
276 span_embeddings
277 }
278}
279
280// =============================================================================
281// Handshaking Matrix (TPLinker-style Joint Extraction)
282// =============================================================================
283
/// Result cell in a handshaking matrix.
///
/// A sparse entry of the token-pair grid: the label prediction for the
/// pair `(i, j)` together with its score. Only cells whose score cleared
/// the decoding threshold are materialized.
#[derive(Debug, Clone, Copy)]
pub struct HandshakingCell {
    /// Row index (token i)
    pub i: u32,
    /// Column index (token j)
    pub j: u32,
    /// Predicted label index (into the semantic registry's label list)
    pub label_idx: u16,
    /// Confidence score
    pub score: f32,
}
296
/// Handshaking matrix for joint entity-relation extraction.
///
/// # Research Alignment (W2NER, AAAI 2022)
///
/// From the W2NER paper (arXiv:2112.10070):
/// > "We present a novel alternative by modeling the unified NER as word-word
/// > relation classification, namely W2NER. The architecture resolves the kernel
/// > bottleneck of unified NER by effectively modeling the neighboring relations
/// > between entity words with **Next-Neighboring-Word (NNW)** and
/// > **Tail-Head-Word-* (THW-*)** relations."
///
/// In TPLinker/W2NER, we don't just tag tokens - we tag token PAIRS.
/// The matrix M\[i,j\] contains the label for the span (i, j).
///
/// # Key Relations
///
/// | Relation | Description | Purpose |
/// |----------|-------------|---------|
/// | NNW | Next-Neighboring-Word | Links adjacent tokens within entity |
/// | THW-* | Tail-Head-Word | Links end of one entity to start of next |
///
/// # Benefits
///
/// - Overlapping entities (same token in multiple spans)
/// - Joint entity-relation extraction in one pass
/// - Explicit boundary modeling
/// - Handles flat, nested, AND discontinuous NER in one model
pub struct HandshakingMatrix {
    /// Non-zero cells (sparse representation; only cells at or above the
    /// decoding threshold are stored)
    pub cells: Vec<HandshakingCell>,
    /// Sequence length (the dense grid is seq_len × seq_len)
    pub seq_len: usize,
    /// Number of labels per cell in the dense score tensor
    pub num_labels: usize,
}
332
333impl HandshakingMatrix {
334 /// Create from dense scores with thresholding.
335 ///
336 /// # Arguments
337 /// * `scores` - Dense [seq_len, seq_len, num_labels] scores
338 /// * `threshold` - Minimum score to keep
339 pub fn from_dense(scores: &[f32], seq_len: usize, num_labels: usize, threshold: f32) -> Self {
340 // Performance: Pre-allocate cells vec with estimated capacity
341 // Most matrices have sparse cells (only high-scoring ones), so we estimate conservatively
342 let estimated_capacity = (seq_len * seq_len / 10).min(1000); // ~10% of cells typically pass threshold
343 let mut cells = Vec::with_capacity(estimated_capacity);
344
345 for i in 0..seq_len {
346 for j in i..seq_len {
347 // Upper triangular (i <= j)
348 for l in 0..num_labels {
349 let idx = i * seq_len * num_labels + j * num_labels + l;
350 if idx < scores.len() {
351 let score = scores[idx];
352 if score >= threshold {
353 cells.push(HandshakingCell {
354 i: i as u32,
355 j: j as u32,
356 label_idx: l as u16,
357 score,
358 });
359 }
360 }
361 }
362 }
363 }
364
365 Self {
366 cells,
367 seq_len,
368 num_labels,
369 }
370 }
371
372 /// Decode entities from handshaking matrix.
373 ///
374 /// In W2NER convention, cell (i, j) represents a span where:
375 /// - j is the start token index
376 /// - i is the end token index (inclusive, so we add 1 for exclusive end)
377 pub fn decode_entities<'a>(
378 &self,
379 registry: &'a SemanticRegistry,
380 ) -> Vec<(SpanCandidate, &'a LabelDefinition, f32)> {
381 let mut entities = Vec::new();
382
383 for cell in &self.cells {
384 if let Some(label) = registry.labels.get(cell.label_idx as usize) {
385 if label.category == LabelCategory::Entity {
386 // W2NER: j=start, i=end (inclusive), so span is [j, i+1)
387 entities.push((SpanCandidate::new(0, cell.j, cell.i + 1), label, cell.score));
388 }
389 }
390 }
391
392 // Performance: Use unstable sort (we don't need stable sort here)
393 // Sort by position, then by score (descending)
394 entities.sort_unstable_by(|a, b| {
395 a.0.start
396 .cmp(&b.0.start)
397 .then_with(|| a.0.end.cmp(&b.0.end))
398 .then_with(|| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal))
399 });
400
401 // Performance: Pre-allocate kept vec with estimated capacity
402 // Non-maximum suppression
403 let mut kept = Vec::with_capacity(entities.len().min(32));
404 for (span, label, score) in entities {
405 let overlaps = kept.iter().any(|(s, _, _): &(SpanCandidate, _, _)| {
406 !(span.end <= s.start || s.end <= span.start)
407 });
408 if !overlaps {
409 kept.push((span, label, score));
410 }
411 }
412
413 kept
414 }
415}
416
417// =============================================================================