// anno/backends/span_utils.rs
1//! Shared span tensor utilities for GLiNER-family models.
2//!
3//! GLiNER and NuNER both use span-based architectures where the model scores
4//! every possible span (up to `MAX_SPAN_WIDTH` words) against each entity type.
5//! This module provides common utilities for generating span tensors and
6//! decoding span-based outputs.
7//!
8//! # Architecture Overview
9//!
10//! ```text
11//! Input Text: "Steve Jobs founded Apple"
12//!             [0]    [1]     [2]    [3]
13//!
14//! Span Grid (MAX_SPAN_WIDTH=3):
15//!
16//!   Start=0: (0,0) (0,1) (0,2)  → "Steve", "Steve Jobs", "Steve Jobs founded"
17//!   Start=1: (1,1) (1,2) (1,3)  → "Jobs", "Jobs founded", "Jobs founded Apple"
18//!   Start=2: (2,2) (2,3)        → "founded", "founded Apple"
19//!   Start=3: (3,3)              → "Apple"
20//!
21//! Total spans: sum(min(MAX_WIDTH, remaining_words)) for each start position
22//! ```
23//!
24//! # Span Tensor Format
25//!
26//! - `span_idx`: `[num_spans, 2]` - (start_word, end_word) indices for each span
27//! - `span_mask`: `[num_spans]` - boolean mask indicating valid spans
28//!
29//! # Output Decoding
30//!
31//! Model output shape: `[batch, num_words, max_width, num_entity_types]`
32//!
33//! Each cell `output[0][start][width][type_idx]` contains the score for:
34//! - Span starting at word `start`
35//! - With width `width + 1` words (i.e., ends at word `start + width`)
36//! - Being an entity of type `type_idx`
37//!
38//! # References
39//!
40//! - GLiNER paper: "GLiNER: Generalist and Lightweight Model for Named Entity Recognition"
41//! - NuNER: Token-based variant using same span representation
42//! - Community GLiNER implementations (for span layout conventions)
43
44use crate::{Entity, EntityType, Error, Result};
45
/// Default maximum span width (in words) for GLiNER-family models.
///
/// This matches the training configuration of most GLiNER/NuNER models.
/// Spans longer than this are not considered by the model.
pub const DEFAULT_MAX_SPAN_WIDTH: usize = 12;

/// Configuration for span-based NER decoding.
#[derive(Debug, Clone)]
pub struct SpanConfig {
    /// Maximum span width in words.
    ///
    /// Spans wider than this are never scored; decoding also clamps this to
    /// the width dimension reported by the model output shape.
    pub max_span_width: usize,
    /// Confidence threshold for entity extraction.
    ///
    /// A span is kept only when its best class score is strictly greater
    /// than this value.
    pub threshold: f32,
}

/// Defaults: `DEFAULT_MAX_SPAN_WIDTH` (12 words) and a 0.5 score threshold.
impl Default for SpanConfig {
    fn default() -> Self {
        Self {
            max_span_width: DEFAULT_MAX_SPAN_WIDTH,
            threshold: 0.5,
        }
    }
}
69
70/// Generate span tensors for ONNX model input.
71///
72/// Creates the `span_idx` and `span_mask` tensors required by GLiNER-family models.
73///
74/// # Arguments
75///
76/// * `num_words` - Number of words in the input text
77/// * `max_width` - Maximum span width to consider
78///
79/// # Returns
80///
81/// A tuple of:
82/// - `span_idx`: Flattened `[num_spans * 2]` array of (start, end) pairs
83/// - `span_mask`: `[num_spans]` boolean mask of valid spans
84///
85/// # Example
86///
87/// ```rust
88/// use anno::backends::span_utils::make_span_tensors;
89///
90/// let (span_idx, span_mask) = make_span_tensors(4, 3);
91///
92/// // First span: (0, 0) -> "word 0"
93/// assert_eq!(span_idx[0], 0); // start
94/// assert_eq!(span_idx[1], 0); // end (exclusive would be 1, but GLiNER uses inclusive)
95/// assert!(span_mask[0]);
96/// ```
97pub fn make_span_tensors(num_words: usize, max_width: usize) -> (Vec<i64>, Vec<bool>) {
98    // Calculate total number of spans with overflow protection
99    let num_spans = match num_words.checked_mul(max_width) {
100        Some(v) => v,
101        None => {
102            log::warn!(
103                "[span_utils] Span count overflow: {} words * {} max_width, returning empty",
104                num_words,
105                max_width
106            );
107            return (Vec::new(), Vec::new());
108        }
109    };
110
111    let span_idx_len = match num_spans.checked_mul(2) {
112        Some(v) => v,
113        None => {
114            log::warn!(
115                "[span_utils] Span idx length overflow: {} * 2, returning empty",
116                num_spans
117            );
118            return (Vec::new(), Vec::new());
119        }
120    };
121
122    let mut span_idx: Vec<i64> = vec![0; span_idx_len];
123    let mut span_mask: Vec<bool> = vec![false; num_spans];
124
125    for start in 0..num_words {
126        let remaining_width = num_words - start;
127        let actual_max_width = max_width.min(remaining_width);
128
129        for width in 0..actual_max_width {
130            // Calculate linear index with overflow protection
131            let dim = match start.checked_mul(max_width) {
132                Some(v) => match v.checked_add(width) {
133                    Some(d) => d,
134                    None => continue,
135                },
136                None => continue,
137            };
138
139            // Bounds check before array access
140            if let Some(dim2) = dim.checked_mul(2) {
141                if dim2 + 1 < span_idx_len && dim < num_spans {
142                    span_idx[dim2] = start as i64;
143                    // End offset: start + width gives the last word index (inclusive)
144                    span_idx[dim2 + 1] = (start + width) as i64;
145                    span_mask[dim] = true;
146                }
147            }
148        }
149    }
150
151    (span_idx, span_mask)
152}
153
154/// Calculate word positions (byte offsets) in the original text.
155///
156/// Maps word indices to their (start, end) byte positions in the source text.
157///
158/// # Arguments
159///
160/// * `text` - The original text
161/// * `words` - Whitespace-split words
162///
163/// # Returns
164///
165/// A vector of (start_byte, end_byte) positions for each word.
166///
167/// # Errors
168///
169/// Returns an error if any word cannot be found at the expected position.
170pub fn calculate_word_positions(text: &str, words: &[&str]) -> Result<Vec<(usize, usize)>> {
171    let mut positions = Vec::with_capacity(words.len());
172    let mut pos = 0;
173
174    for (idx, word) in words.iter().enumerate() {
175        // Find word starting from current position
176        if let Some(rel_start) = text[pos..].find(word) {
177            let abs_start = pos + rel_start;
178            let abs_end = abs_start + word.len();
179
180            // Validate: words should be in order
181            if !positions.is_empty() {
182                let (_prev_start, prev_end) = positions[positions.len() - 1];
183                if abs_start < prev_end {
184                    log::warn!(
185                        "[span_utils] Word '{}' at {} overlaps with previous word ending at {}",
186                        word,
187                        abs_start,
188                        prev_end
189                    );
190                }
191            }
192
193            positions.push((abs_start, abs_end));
194            pos = abs_end;
195        } else {
196            return Err(Error::Parse(format!(
197                "Word '{}' (index {}) not found in text starting at position {}",
198                word, idx, pos
199            )));
200        }
201    }
202
203    Ok(positions)
204}
205
/// Extract entity span from text given word positions.
///
/// # Arguments
///
/// * `text` - The original text
/// * `word_positions` - Byte positions of each word
/// * `start_word` - Starting word index (inclusive)
/// * `end_word` - Ending word index (inclusive)
///
/// # Returns
///
/// The text span and its byte range, or `None` if the word indices are out of
/// bounds, the range is inverted, or the byte range does not fall on UTF-8
/// character boundaries of `text`.
pub fn extract_span<'a>(
    text: &'a str,
    word_positions: &[(usize, usize)],
    start_word: usize,
    end_word: usize,
) -> Option<(&'a str, usize, usize)> {
    let start_pos = word_positions.get(start_word)?.0;
    let end_pos = word_positions.get(end_word)?.1;

    if start_pos > end_pos {
        return None;
    }

    // `str::get` returns None (instead of panicking like `text[a..b]`) when
    // the range is out of bounds or splits a multi-byte UTF-8 character —
    // possible if the positions were not derived from this exact text.
    let span = text.get(start_pos..end_pos)?;
    Some((span, start_pos, end_pos))
}
233
234/// Decode span-based model output into entities.
235///
236/// This is the core decoding function for GLiNER-family models.
237///
238/// # Arguments
239///
240/// * `output_data` - Flattened model output tensor
241/// * `shape` - Output shape `[batch, num_words, max_width, num_classes]`
242/// * `text` - Original input text
243/// * `text_words` - Whitespace-split words
244/// * `entity_types` - Entity type labels
245/// * `config` - Decoding configuration
246///
247/// # Returns
248///
249/// A vector of extracted entities.
250///
251/// # Output Format
252///
253/// The model output has shape `[batch, num_words, max_width, num_classes]`:
254/// - `batch`: Always 1 for single-text inference
255/// - `num_words`: Number of words in the input
256/// - `max_width`: Maximum span width (DEFAULT_MAX_SPAN_WIDTH)
257/// - `num_classes`: Number of entity types
258///
259/// Each cell contains the score for a (start, width, type) triple.
260pub fn decode_span_output(
261    output_data: &[f32],
262    shape: &[i64],
263    text: &str,
264    text_words: &[&str],
265    entity_types: &[&str],
266    config: &SpanConfig,
267) -> Result<Vec<Entity>> {
268    // Validate shape
269    if shape.len() < 3 {
270        return Err(Error::Parse(format!(
271            "Expected at least 3D output, got shape {:?}",
272            shape
273        )));
274    }
275
276    // Parse shape dimensions
277    let (out_num_words, out_max_width, num_classes) = if shape.len() == 4 {
278        // Standard GLiNER format: [batch, num_words, max_width, num_classes]
279        (shape[1] as usize, shape[2] as usize, shape[3] as usize)
280    } else if shape.len() == 3 {
281        // Squeezed batch dimension: [num_words, max_width, num_classes]
282        (shape[0] as usize, shape[1] as usize, shape[2] as usize)
283    } else {
284        return Err(Error::Parse(format!(
285            "Unexpected output shape: {:?}",
286            shape
287        )));
288    };
289
290    log::debug!(
291        "[span_utils] Decoding: words={}, max_width={}, classes={}, data_len={}",
292        out_num_words,
293        out_max_width,
294        num_classes,
295        output_data.len()
296    );
297
298    // Calculate word positions
299    let word_positions = calculate_word_positions(text, text_words)?;
300    // `calculate_word_positions` (and most tokenizer/regex style tooling) yields byte offsets.
301    // `Entity` offsets are defined as character offsets, so convert at construction time.
302    let span_converter = crate::offset::SpanConverter::new(text);
303
304    // Validate dimensions match
305    let num_text_words = text_words.len();
306    if out_num_words < num_text_words {
307        log::warn!(
308            "[span_utils] Output has fewer words ({}) than input ({})",
309            out_num_words,
310            num_text_words
311        );
312    }
313
314    let mut entities = Vec::with_capacity(32);
315
316    // Iterate over all valid spans
317    for start in 0..num_text_words.min(out_num_words) {
318        for width in 0..config.max_span_width.min(out_max_width) {
319            let end = start + width;
320            if end >= num_text_words {
321                break;
322            }
323
324            // Find best entity type for this span
325            let base_idx = (start * out_max_width * num_classes) + (width * num_classes);
326
327            let mut best_score = config.threshold;
328            let mut best_type_idx = None;
329
330            for type_idx in 0..num_classes.min(entity_types.len()) {
331                let score = output_data.get(base_idx + type_idx).copied().unwrap_or(0.0);
332
333                if score > best_score {
334                    best_score = score;
335                    best_type_idx = Some(type_idx);
336                }
337            }
338
339            // Create entity if score exceeds threshold
340            if let Some(type_idx) = best_type_idx {
341                if let Some((span_text, start_byte, end_byte)) =
342                    extract_span(text, &word_positions, start, end)
343                {
344                    let entity_type = map_label_to_entity_type(entity_types[type_idx]);
345                    let mut entity = Entity::new(
346                        span_text,
347                        entity_type,
348                        span_converter.byte_to_char(start_byte),
349                        span_converter.byte_to_char(end_byte),
350                        best_score as f64,
351                    );
352                    entity.provenance =
353                        Some(crate::Provenance::ml("span-decoder", best_score as f64));
354                    entities.push(entity);
355                }
356            }
357        }
358    }
359
360    // Sort by position and remove overlaps (keep highest confidence)
361    entities.sort_by(|a, b| {
362        a.start
363            .cmp(&b.start)
364            // Use total ordering to avoid `partial_cmp` returning None on NaN.
365            .then_with(|| b.confidence.total_cmp(&a.confidence))
366    });
367
368    // Remove overlapping entities (keep first = highest confidence due to sort)
369    let mut filtered = Vec::with_capacity(entities.len());
370    for entity in entities {
371        let overlaps = filtered
372            .iter()
373            .any(|e: &Entity| ranges_overlap(e.start, e.end, entity.start, entity.end));
374        if !overlaps {
375            filtered.push(entity);
376        }
377    }
378
379    Ok(filtered)
380}
381
/// Check if two half-open ranges `[start, end)` overlap.
#[inline]
fn ranges_overlap(start1: usize, end1: usize, start2: usize, end2: usize) -> bool {
    // Disjoint iff one range ends at or before the other begins.
    !(end1 <= start2 || end2 <= start1)
}
387
388/// Map entity type label string to EntityType enum.
389///
390/// Handles common label variations (case-insensitive).
391pub fn map_label_to_entity_type(label: &str) -> EntityType {
392    match label.to_lowercase().as_str() {
393        "person" | "per" => EntityType::Person,
394        "organization" | "org" | "company" | "corp" => EntityType::Organization,
395        "location" | "loc" | "place" | "gpe" => EntityType::Location,
396        "date" => EntityType::Date,
397        "datetime" => EntityType::Date,
398        "time" => EntityType::Time,
399        "money" | "currency" => EntityType::Money,
400        "monetary" => EntityType::Money,
401        "percent" | "percentage" => EntityType::Percent,
402        "email" => EntityType::Email,
403        "phone" => EntityType::Phone,
404        "url" => EntityType::Url,
405        "quantity" => EntityType::Quantity,
406        "measure" => EntityType::Quantity,
407        "cardinal" => EntityType::Cardinal,
408        "number" | "num" => EntityType::Cardinal,
409        "ordinal" => EntityType::Ordinal,
410        "event" => EntityType::Other("EVENT".to_string()),
411        "product" | "prod" => EntityType::Other("PRODUCT".to_string()),
412        "work_of_art" | "work" => EntityType::Other("WORK_OF_ART".to_string()),
413        "law" | "legal" => EntityType::Other("LAW".to_string()),
414        "language" | "lang" => EntityType::Other("LANGUAGE".to_string()),
415        "norp" => EntityType::Other("NORP".to_string()), // Nationalities, religions, political groups
416        "fac" | "facility" => EntityType::Other("FACILITY".to_string()),
417        // Fine-grained / CNER-inspired labels
418        "animal" => EntityType::Other("ANIMAL".to_string()),
419        "biology" => EntityType::Other("BIOLOGY".to_string()),
420        "celestial" => EntityType::Other("CELESTIAL".to_string()),
421        "culture" => EntityType::Other("CULTURE".to_string()),
422        "discipline" => EntityType::Other("DISCIPLINE".to_string()),
423        "disease" => EntityType::Other("DISEASE".to_string()),
424        "feeling" => EntityType::Other("FEELING".to_string()),
425        "food" => EntityType::Other("FOOD".to_string()),
426        "group" => EntityType::Other("GROUP".to_string()),
427        "instrument" => EntityType::Other("INSTRUMENT".to_string()),
428        "media" => EntityType::Other("MEDIA".to_string()),
429        "asset" => EntityType::Other("ASSET".to_string()),
430        "artifact" => EntityType::Other("ARTIFACT".to_string()),
431        "part" => EntityType::Other("PART".to_string()),
432        "physical_phenomenon" | "physical" => EntityType::Other("PHYSICAL_PHENOMENON".to_string()),
433        "plant" => EntityType::Other("PLANT".to_string()),
434        "property" => EntityType::Other("PROPERTY".to_string()),
435        "psych" => EntityType::Other("PSYCH".to_string()),
436        "relation" => EntityType::Other("RELATION".to_string()),
437        "struct" => EntityType::Other("STRUCT".to_string()),
438        "substance" => EntityType::Other("SUBSTANCE".to_string()),
439        "super" | "supernatural" => EntityType::Other("SUPER".to_string()),
440        "vehicle" | "vehi" => EntityType::Other("VEHICLE".to_string()),
441        _ => EntityType::Other(label.to_uppercase()),
442    }
443}
444
// Unit tests for the pure helpers in this module. The tensor layout
// expectations here mirror the GLiNER span-grid convention documented at the
// top of the file.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_make_span_tensors_basic() {
        let (span_idx, span_mask) = make_span_tensors(3, 2);

        // 3 words * 2 max_width = 6 spans
        assert_eq!(span_mask.len(), 6);
        assert_eq!(span_idx.len(), 12);

        // First span: word 0, width 0 → (0, 0)
        assert!(span_mask[0]);
        assert_eq!(span_idx[0], 0);
        assert_eq!(span_idx[1], 0);

        // Second span: word 0, width 1 → (0, 1)
        assert!(span_mask[1]);
        assert_eq!(span_idx[2], 0);
        assert_eq!(span_idx[3], 1);
    }

    #[test]
    fn test_make_span_tensors_overflow_protection() {
        // Very large input shouldn't panic
        // (usize::MAX / 2) * 12 overflows checked_mul, triggering the guard.
        let (span_idx, span_mask) = make_span_tensors(usize::MAX / 2, DEFAULT_MAX_SPAN_WIDTH);
        // Should return empty due to overflow
        assert!(span_idx.is_empty());
        assert!(span_mask.is_empty());
    }

    #[test]
    fn test_calculate_word_positions() {
        let text = "Steve Jobs founded Apple";
        let words: Vec<&str> = text.split_whitespace().collect();

        let positions = calculate_word_positions(text, &words).unwrap();

        // Byte offsets of each word; end offsets are exclusive.
        assert_eq!(positions.len(), 4);
        assert_eq!(positions[0], (0, 5)); // "Steve"
        assert_eq!(positions[1], (6, 10)); // "Jobs"
        assert_eq!(positions[2], (11, 18)); // "founded"
        assert_eq!(positions[3], (19, 24)); // "Apple"
    }

    #[test]
    fn test_extract_span() {
        let text = "Steve Jobs founded Apple";
        let positions = vec![(0, 5), (6, 10), (11, 18), (19, 24)];

        // Single word span
        let (span, start, end) = extract_span(text, &positions, 0, 0).unwrap();
        assert_eq!(span, "Steve");
        assert_eq!((start, end), (0, 5));

        // Two-word span
        let (span, start, end) = extract_span(text, &positions, 0, 1).unwrap();
        assert_eq!(span, "Steve Jobs");
        assert_eq!((start, end), (0, 10));

        // Three-word span
        let (span, start, end) = extract_span(text, &positions, 1, 3).unwrap();
        assert_eq!(span, "Jobs founded Apple");
        assert_eq!((start, end), (6, 24));
    }

    #[test]
    fn test_map_label_to_entity_type() {
        // Matching is case-insensitive; common aliases map to the same variant.
        assert_eq!(map_label_to_entity_type("person"), EntityType::Person);
        assert_eq!(map_label_to_entity_type("PER"), EntityType::Person);
        assert_eq!(
            map_label_to_entity_type("organization"),
            EntityType::Organization
        );
        assert_eq!(map_label_to_entity_type("ORG"), EntityType::Organization);
        assert_eq!(map_label_to_entity_type("location"), EntityType::Location);
        assert_eq!(map_label_to_entity_type("GPE"), EntityType::Location);
        // Unknown labels fall through to Other(uppercased).
        assert_eq!(
            map_label_to_entity_type("custom_type"),
            EntityType::Other("CUSTOM_TYPE".to_string())
        );
    }

    #[test]
    fn test_ranges_overlap() {
        assert!(ranges_overlap(0, 10, 5, 15)); // Partial overlap
        assert!(ranges_overlap(0, 10, 0, 5)); // Contained
        assert!(ranges_overlap(5, 15, 0, 10)); // Partial overlap (reversed)
        assert!(!ranges_overlap(0, 5, 10, 15)); // No overlap
        assert!(!ranges_overlap(0, 5, 5, 10)); // Adjacent (not overlapping)
    }
}
537}