Skip to main content

anno/backends/bilstm_crf/
mod.rs

1//! BiLSTM-CRF NER backend.
2//!
3//! Implements the dominant neural NER architecture from 2015-2018, before transformers.
4//! This architecture represents the pivotal transition from feature engineering to
5//! representation learning, while retaining the CRF layer's sequence modeling.
6//!
7//! # Historical Context
8//!
9//! The NER field evolved through three eras:
10//!
11//! ```text
12//! Era 1: Rule-based (1987-1997)     - Lexicons, hand-crafted patterns
13//! Era 2: Statistical (1997-2015)    - HMM → MEMM → CRF (feature engineering)
14//! Era 3: Neural (2011-present)      - CNN → BiLSTM-CRF → Transformers
15//! ```
16//!
17//! BiLSTM-CRF bridged statistical and neural approaches:
18//! - **BiLSTM**: Learns features automatically from data (no feature engineering)
19//! - **CRF layer**: Retains structured prediction from statistical era
20//!
21//! Collobert et al. 2011 ("NLP from Scratch") first showed CNNs for NER, but
22//! BiLSTM-CRF (2015) became the dominant architecture until BERT (2018).
23//!
24//! # Why Keep the CRF Layer?
25//!
26//! The BiLSTM produces emission scores for each position, but doesn't model
27//! label dependencies. The CRF layer ensures:
28//! - Valid BIO sequences (no `I-PER` after `O`)
29//! - Learned transition patterns (e.g., `B-ORG` often followed by `I-ORG`)
30//!
31//! ```text
32//! Without CRF:  BiLSTM predicts [B-PER, I-ORG, O, B-LOC]  // invalid!
33//! With CRF:     Viterbi finds   [B-PER, O,     O, B-LOC]  // valid sequence
34//! ```
35//!
36//! # Architecture
37//!
38//! ```text
39//! Input: "John works at Google"
40//!    ↓
41//! ┌─────────────────────────────────────────┐
42//! │ Word Embeddings (GloVe/Word2Vec)        │
43//! │ + Character Embeddings (CNN/LSTM)       │
44//! └─────────────────────────────────────────┘
45//!    ↓
46//! ┌─────────────────────────────────────────┐
47//! │ Bidirectional LSTM                      │
48//! │  Forward:  h₁ → h₂ → h₃ → h₄           │
49//! │  Backward: h₁ ← h₂ ← h₃ ← h₄           │
50//! │  Concat:   [h→;h←] for each position    │
51//! └─────────────────────────────────────────┘
52//!    ↓
53//! ┌─────────────────────────────────────────┐
54//! │ CRF Layer                               │
55//! │  - Emission scores from BiLSTM          │
56//! │  - Transition matrix learned            │
57//! │  - Viterbi decoding for best sequence   │
58//! └─────────────────────────────────────────┘
59//!    ↓
60//! Output: B-PER O O B-ORG
61//! ```
62//!
63//! # Key Papers
64//!
65//! - Collobert et al. 2011: "Natural Language Processing (Almost) from Scratch"
66//! - Huang et al. 2015: "Bidirectional LSTM-CRF Models for Sequence Tagging"
67//! - Lample et al. 2016: "Neural Architectures for Named Entity Recognition"
68//! - Ma & Hovy 2016: "End-to-end Sequence Labeling via Bi-directional LSTM-CNNs-CRF"
69//! - Peters et al. 2018: "Deep Contextualized Word Representations" (ELMo)
70//!
71//! # References
72//!
73//! - Collobert, Weston, Bottou, et al. (2011): "Natural Language Processing
74//!   (Almost) from Scratch" (JMLR) — first neural NER
75//! - Huang, Xu, Yu (2015): "Bidirectional LSTM-CRF Models for Sequence
76//!   Tagging" (arXiv:1508.01991) — introduced BiLSTM-CRF
77//! - Lample, Ballesteros, Subramanian, et al. (2016): "Neural Architectures
78//!   for Named Entity Recognition" (NAACL) — char embeddings
79//! - Ma & Hovy (2016): "End-to-end Sequence Labeling via Bi-directional
80//!   LSTM-CNNs-CRF" (ACL) — CNN char encoder
81//!
82//! # See Also
83//!
84//! - Historical NER baselines (HMM/CRF-era sequence models)
85//!
86//! # Usage
87//!
88//! ```rust
89//! use anno::backends::bilstm_crf::BiLstmCrfNER;
90//! use anno::Model;
91//!
92//! // Create with heuristic weights (no neural inference)
93//! let ner = BiLstmCrfNER::new();
94//! let entities = ner.extract_entities("John works at Google", None).unwrap();
95//! ```
96//!
97//! With ONNX feature enabled, load pre-trained weights:
98//!
99//! ```rust,ignore
100//! // Requires: features = ["onnx"]
101//! let ner = BiLstmCrfNER::from_onnx("path/to/model.onnx")?;
102//! ```
103
104use crate::{Entity, EntityType, Model, Result};
105use std::collections::HashMap;
106
/// Configuration for the BiLSTM-CRF backend.
#[derive(Debug, Clone)]
pub struct BiLstmCrfConfig {
    /// Hidden size for LSTM layers.
    pub hidden_size: usize,
    /// Number of LSTM layers.
    pub num_layers: usize,
    /// Dropout probability.
    pub dropout: f32,
    /// Whether to use character-level embeddings.
    pub use_char_embeddings: bool,
    /// Maximum sequence length.
    pub max_seq_len: usize,
}

impl Default for BiLstmCrfConfig {
    /// Default setup: a 2-layer BiLSTM with hidden size 256, dropout 0.5,
    /// character embeddings enabled, and sequences capped at 512 tokens.
    fn default() -> Self {
        BiLstmCrfConfig {
            hidden_size: 256,
            num_layers: 2,
            dropout: 0.5,
            use_char_embeddings: true,
            max_seq_len: 512,
        }
    }
}
133
/// BiLSTM-CRF NER model.
///
/// This implements the classic neural NER architecture that dominated
/// from 2015-2018 before transformer models.
///
/// # Components
///
/// 1. **Word Embeddings**: Pre-trained (GloVe/Word2Vec) or learned
/// 2. **Character Embeddings**: CNN or LSTM over characters (optional)
/// 3. **BiLSTM Encoder**: Bidirectional LSTM for context
/// 4. **CRF Decoder**: Structured prediction with transition constraints
///
/// NOTE(review): in this module, emission scores are always produced by
/// heuristics (`get_emissions`); the `session` field is populated by
/// `from_onnx` but is not consulted by the visible inference path —
/// confirm whether neural inference is wired up elsewhere.
#[derive(Debug)]
pub struct BiLstmCrfNER {
    /// Model configuration.
    #[allow(dead_code)] // Reserved for model serialization
    config: BiLstmCrfConfig,
    /// BIO labels for decoding; a label's position is its id in the DP tables.
    labels: Vec<String>,
    /// Label to index mapping (inverse of `labels`).
    label_to_idx: HashMap<String, usize>,
    /// Transition scores (from CRF layer), indexed `[from][to]`.
    transitions: Vec<Vec<f64>>,
    /// Word vocabulary (word -> embedding index).
    #[allow(dead_code)] // Reserved for embedding lookup
    vocab: HashMap<String, usize>,
    /// ONNX session for inference (when onnx feature enabled).
    #[cfg(feature = "onnx")]
    session: Option<ort::session::Session>,
}
163
164impl BiLstmCrfNER {
165    /// Create a new BiLSTM-CRF model with default configuration.
166    ///
167    /// This creates a model that uses heuristic-based inference
168    /// (no neural weights). For actual neural inference, use
169    /// `from_onnx()` to load pre-trained weights.
170    #[must_use]
171    pub fn new() -> Self {
172        Self::with_config(BiLstmCrfConfig::default())
173    }
174
175    /// Create with custom configuration.
176    #[must_use]
177    pub fn with_config(config: BiLstmCrfConfig) -> Self {
178        let labels = vec![
179            "O".to_string(),
180            "B-PER".to_string(),
181            "I-PER".to_string(),
182            "B-ORG".to_string(),
183            "I-ORG".to_string(),
184            "B-LOC".to_string(),
185            "I-LOC".to_string(),
186            "B-MISC".to_string(),
187            "I-MISC".to_string(),
188        ];
189
190        let label_to_idx: HashMap<String, usize> = labels
191            .iter()
192            .enumerate()
193            .map(|(i, l)| (l.clone(), i))
194            .collect();
195
196        // Initialize transition matrix with sensible defaults
197        // Higher scores for valid BIO transitions
198        let n = labels.len();
199        let mut transitions = vec![vec![0.0; n]; n];
200
201        // BIO constraints: I-X can only follow B-X or I-X
202        for i in 0..n {
203            for j in 0..n {
204                let from_label = &labels[i];
205                let to_label = &labels[j];
206
207                if let Some(entity_type) = to_label.strip_prefix("I-") {
208                    let valid_prev = format!("B-{}", entity_type);
209                    let valid_cont = format!("I-{}", entity_type);
210
211                    if from_label == &valid_prev || from_label == &valid_cont {
212                        transitions[i][j] = 1.0; // Valid transition
213                    } else {
214                        transitions[i][j] = -10.0; // Invalid transition
215                    }
216                } else {
217                    // B-X or O can follow anything
218                    transitions[i][j] = 0.0;
219                }
220            }
221        }
222
223        Self {
224            config,
225            labels,
226            label_to_idx,
227            transitions,
228            vocab: HashMap::new(),
229            #[cfg(feature = "onnx")]
230            session: None,
231        }
232    }
233
234    /// Load from ONNX model file.
235    #[cfg(feature = "onnx")]
236    pub fn from_onnx(model_path: &str) -> Result<Self> {
237        use crate::Error;
238        use ort::session::{builder::GraphOptimizationLevel, Session};
239
240        let session = Session::builder()
241            .map_err(|e| Error::model_init(format!("Failed to create session builder: {}", e)))?
242            .with_optimization_level(GraphOptimizationLevel::Level3)
243            .map_err(|e| Error::model_init(format!("Failed to set optimization level: {}", e)))?
244            .commit_from_file(model_path)
245            .map_err(|e| Error::model_init(format!("Failed to load ONNX model: {}", e)))?;
246
247        let mut model = Self::new();
248        model.session = Some(session);
249        Ok(model)
250    }
251
252    /// Tokenize text into words.
253    fn tokenize(text: &str) -> Vec<&str> {
254        text.split_whitespace().collect()
255    }
256
257    /// Get emission scores for each token.
258    ///
259    /// In a full implementation, this would run the BiLSTM.
260    /// Here we use realistic heuristic features as a fallback,
261    /// combining gazetteers, word shape, and contextual patterns.
262    fn get_emissions(&self, tokens: &[&str]) -> Vec<Vec<f64>> {
263        let n_labels = self.labels.len();
264        let mut emissions = vec![vec![0.0; n_labels]; tokens.len()];
265
266        // Gazetteers for better heuristic accuracy
267        const PERSON_NAMES: &[&str] = &[
268            "john",
269            "mary",
270            "james",
271            "david",
272            "michael",
273            "robert",
274            "william",
275            "richard",
276            "sarah",
277            "jennifer",
278            "elizabeth",
279            "lisa",
280            "marie",
281            "jane",
282            "emily",
283            "anna",
284            "barack",
285            "donald",
286            "joe",
287            "george",
288            "bill",
289            "hillary",
290            "elon",
291            "jeff",
292            "mr",
293            "mrs",
294            "ms",
295            "dr",
296            "prof",
297            "sir",
298            "lord",
299            "president",
300            "ceo",
301        ];
302        const ORG_NAMES: &[&str] = &[
303            "google",
304            "apple",
305            "microsoft",
306            "amazon",
307            "facebook",
308            "meta",
309            "tesla",
310            "ibm",
311            "intel",
312            "nvidia",
313            "oracle",
314            "cisco",
315            "adobe",
316            "netflix",
317            "uber",
318            "university",
319            "institute",
320            "corporation",
321            "company",
322            "inc",
323            "corp",
324            "ltd",
325            "llc",
326            "foundation",
327            "association",
328            "organization",
329            "department",
330            "agency",
331            "fbi",
332            "cia",
333            "nsa",
334            "nasa",
335            "un",
336            "nato",
337            "who",
338            "imf",
339            "eu",
340            "usa",
341        ];
342        const LOC_NAMES: &[&str] = &[
343            "new",
344            "york",
345            "california",
346            "texas",
347            "florida",
348            "london",
349            "paris",
350            "berlin",
351            "tokyo",
352            "beijing",
353            "moscow",
354            "washington",
355            "chicago",
356            "boston",
357            "seattle",
358            "san",
359            "francisco",
360            "los",
361            "angeles",
362            "las",
363            "vegas",
364            "united",
365            "states",
366            "america",
367            "china",
368            "russia",
369            "germany",
370            "france",
371            "japan",
372            "india",
373            "brazil",
374            "city",
375            "county",
376            "state",
377            "country",
378            "river",
379            "mountain",
380            "lake",
381            "ocean",
382        ];
383
384        for (i, token) in tokens.iter().enumerate() {
385            let lower = token.to_lowercase();
386            let is_capitalized = token.chars().next().is_some_and(|c| c.is_uppercase());
387            let is_all_caps = token
388                .chars()
389                .all(|c| c.is_uppercase() || !c.is_alphabetic())
390                && token.len() > 1;
391            let has_digit = token.chars().any(|c| c.is_ascii_digit());
392            let is_first = i == 0;
393
394            // Default: bias toward O (entities are rare)
395            emissions[i][0] = 1.5;
396
397            // Gazetteer matches (strongest signal)
398            if PERSON_NAMES.contains(&lower.as_str()) {
399                emissions[i][self.label_to_idx["B-PER"]] += 2.0;
400                emissions[i][self.label_to_idx["I-PER"]] += 1.0;
401            }
402            if ORG_NAMES.contains(&lower.as_str()) {
403                emissions[i][self.label_to_idx["B-ORG"]] += 2.0;
404                emissions[i][self.label_to_idx["I-ORG"]] += 1.0;
405            }
406            if LOC_NAMES.contains(&lower.as_str()) {
407                emissions[i][self.label_to_idx["B-LOC"]] += 2.0;
408                emissions[i][self.label_to_idx["I-LOC"]] += 1.0;
409            }
410
411            // Capitalization (weaker signal, context-dependent)
412            if is_capitalized && !has_digit && !is_first {
413                emissions[i][self.label_to_idx["B-PER"]] += 0.8;
414                emissions[i][self.label_to_idx["B-ORG"]] += 0.6;
415                emissions[i][self.label_to_idx["B-LOC"]] += 0.5;
416            }
417
418            // Organization suffixes
419            if lower.ends_with("inc.")
420                || lower.ends_with("corp.")
421                || lower.ends_with("ltd.")
422                || lower.ends_with("llc")
423                || lower.ends_with("co.")
424            {
425                emissions[i][self.label_to_idx["B-ORG"]] += 1.5;
426                emissions[i][self.label_to_idx["I-ORG"]] += 1.0;
427            }
428
429            // Acronyms (2-5 uppercase letters)
430            if is_all_caps && token.len() >= 2 && token.len() <= 5 && !has_digit {
431                emissions[i][self.label_to_idx["B-ORG"]] += 1.2;
432            }
433
434            // Honorifics signal person
435            if ["mr.", "mrs.", "ms.", "dr.", "prof."].contains(&lower.as_str()) {
436                emissions[i][self.label_to_idx["B-PER"]] += 1.5;
437            }
438
439            // "The" before proper noun often signals ORG or LOC
440            if i > 0 && tokens[i - 1].to_lowercase() == "the" && is_capitalized {
441                emissions[i][self.label_to_idx["B-ORG"]] += 0.5;
442                emissions[i][self.label_to_idx["B-LOC"]] += 0.3;
443            }
444
445            // Multi-word entity continuation
446            if i > 0 {
447                let prev_cap = tokens[i - 1]
448                    .chars()
449                    .next()
450                    .is_some_and(|c| c.is_uppercase());
451                if prev_cap && is_capitalized && !is_first {
452                    // Likely continuation of entity
453                    emissions[i][self.label_to_idx["I-PER"]] += 0.6;
454                    emissions[i][self.label_to_idx["I-ORG"]] += 0.6;
455                    emissions[i][self.label_to_idx["I-LOC"]] += 0.4;
456                }
457            }
458        }
459
460        emissions
461    }
462
463    /// Viterbi decoding with CRF transitions.
464    fn viterbi_decode(&self, emissions: &[Vec<f64>]) -> Vec<usize> {
465        if emissions.is_empty() {
466            return vec![];
467        }
468
469        let n = emissions.len();
470        let m = self.labels.len();
471
472        // DP tables
473        let mut scores = vec![vec![f64::NEG_INFINITY; m]; n];
474        let mut backpointers = vec![vec![0usize; m]; n];
475
476        // Initialize first position
477        for j in 0..m {
478            scores[0][j] = emissions[0][j];
479        }
480
481        // Forward pass
482        for i in 1..n {
483            for j in 0..m {
484                let mut best_score = f64::NEG_INFINITY;
485                let mut best_prev = 0;
486
487                #[allow(clippy::needless_range_loop)]
488                for k in 0..m {
489                    let score = scores[i - 1][k] + self.transitions[k][j] + emissions[i][j];
490                    if score > best_score {
491                        best_score = score;
492                        best_prev = k;
493                    }
494                }
495
496                scores[i][j] = best_score;
497                backpointers[i][j] = best_prev;
498            }
499        }
500
501        // Backward pass
502        let mut path = vec![0usize; n];
503        let mut best_final = 0;
504        let mut best_score = f64::NEG_INFINITY;
505
506        for (j, &score) in scores[n - 1].iter().enumerate() {
507            if score > best_score {
508                best_score = score;
509                best_final = j;
510            }
511        }
512
513        path[n - 1] = best_final;
514        for i in (0..n - 1).rev() {
515            path[i] = backpointers[i + 1][path[i + 1]];
516        }
517
518        path
519    }
520
521    /// Convert BIO labels to entities.
522    ///
523    /// Uses token position tracking to correctly handle duplicate entity texts.
524    /// The previous implementation used `text.find()` which always returned the
525    /// first occurrence, causing incorrect offsets for duplicate entities.
526    fn labels_to_entities(
527        &self,
528        text: &str,
529        tokens: &[&str],
530        label_indices: &[usize],
531    ) -> Vec<Entity> {
532        use crate::offset::SpanConverter;
533
534        let converter = SpanConverter::new(text);
535        let mut entities = Vec::new();
536
537        // Track token positions (byte offsets) as we iterate
538        let token_positions: Vec<(usize, usize)> = Self::calculate_token_positions(text, tokens);
539
540        let mut current_entity: Option<(usize, usize, EntityType, Vec<&str>)> = None;
541
542        for (i, (&label_idx, &token)) in label_indices.iter().zip(tokens.iter()).enumerate() {
543            let label = &self.labels[label_idx];
544
545            if let Some(entity_suffix) = label.strip_prefix("B-") {
546                // Save previous entity if any
547                if let Some((start_token_idx, end_token_idx, entity_type, words)) =
548                    current_entity.take()
549                {
550                    Self::push_entity_from_positions(
551                        &converter,
552                        &token_positions,
553                        start_token_idx,
554                        end_token_idx,
555                        &words,
556                        entity_type,
557                        &mut entities,
558                    );
559                }
560
561                // Start new entity
562                let entity_type = match entity_suffix {
563                    "PER" => EntityType::Person,
564                    "ORG" => EntityType::Organization,
565                    "LOC" => EntityType::Location,
566                    other => EntityType::Other(other.to_string()),
567                };
568                current_entity = Some((i, i, entity_type, vec![token]));
569            } else if label.starts_with("I-") && current_entity.is_some() {
570                // Continue current entity
571                if let Some((_, ref mut end_idx, _, ref mut words)) = current_entity {
572                    words.push(token);
573                    *end_idx = i;
574                }
575            } else {
576                // O label - save and reset
577                if let Some((start_token_idx, end_token_idx, entity_type, words)) =
578                    current_entity.take()
579                {
580                    Self::push_entity_from_positions(
581                        &converter,
582                        &token_positions,
583                        start_token_idx,
584                        end_token_idx,
585                        &words,
586                        entity_type,
587                        &mut entities,
588                    );
589                }
590            }
591        }
592
593        // Don't forget last entity
594        if let Some((start_token_idx, end_token_idx, entity_type, words)) = current_entity.take() {
595            Self::push_entity_from_positions(
596                &converter,
597                &token_positions,
598                start_token_idx,
599                end_token_idx,
600                &words,
601                entity_type,
602                &mut entities,
603            );
604        }
605
606        entities
607    }
608
609    /// Calculate byte positions for each token in the text.
610    fn calculate_token_positions(text: &str, tokens: &[&str]) -> Vec<(usize, usize)> {
611        let mut positions = Vec::with_capacity(tokens.len());
612        let mut byte_pos = 0;
613
614        for token in tokens {
615            // Find token starting from current position
616            if let Some(rel_pos) = text[byte_pos..].find(token) {
617                let start = byte_pos + rel_pos;
618                let end = start + token.len();
619                positions.push((start, end));
620                byte_pos = end; // Move past this token
621            } else {
622                // Fallback: use current position (shouldn't happen with whitespace tokenization)
623                positions.push((byte_pos, byte_pos));
624            }
625        }
626
627        positions
628    }
629
630    /// Helper to push entity using tracked token positions.
631    fn push_entity_from_positions(
632        converter: &crate::offset::SpanConverter,
633        positions: &[(usize, usize)],
634        start_token_idx: usize,
635        end_token_idx: usize,
636        words: &[&str],
637        entity_type: EntityType,
638        entities: &mut Vec<Entity>,
639    ) {
640        if start_token_idx >= positions.len() || end_token_idx >= positions.len() {
641            return;
642        }
643
644        let byte_start = positions[start_token_idx].0;
645        let byte_end = positions[end_token_idx].1;
646        let char_start = converter.byte_to_char(byte_start);
647        let char_end = converter.byte_to_char(byte_end);
648        let entity_text = words.join(" ");
649
650        entities.push(Entity::new(
651            entity_text,
652            entity_type,
653            char_start,
654            char_end,
655            0.75, // BiLSTM-CRF confidence
656        ));
657    }
658}
659
660impl Default for BiLstmCrfNER {
661    fn default() -> Self {
662        Self::new()
663    }
664}
665
666impl Model for BiLstmCrfNER {
667    fn extract_entities(&self, text: &str, _language: Option<&str>) -> Result<Vec<Entity>> {
668        if text.trim().is_empty() {
669            return Ok(vec![]);
670        }
671
672        let tokens = Self::tokenize(text);
673        if tokens.is_empty() {
674            return Ok(vec![]);
675        }
676
677        // Get emission scores (from BiLSTM or heuristics)
678        let emissions = self.get_emissions(&tokens);
679
680        // Viterbi decode with CRF transitions
681        let label_indices = self.viterbi_decode(&emissions);
682
683        // Convert to entities
684        let entities = self.labels_to_entities(text, &tokens, &label_indices);
685
686        Ok(entities)
687    }
688
689    fn supported_types(&self) -> Vec<EntityType> {
690        vec![
691            EntityType::Person,
692            EntityType::Organization,
693            EntityType::Location,
694            EntityType::Other("MISC".to_string()),
695        ]
696    }
697
698    fn is_available(&self) -> bool {
699        true // Always available with heuristic fallback
700    }
701
702    fn capabilities(&self) -> crate::ModelCapabilities {
703        crate::ModelCapabilities {
704            batch_capable: true,
705            optimal_batch_size: Some(32),
706            ..Default::default()
707        }
708    }
709}
710
// Opt into the crate's sealed-trait pattern and capability markers.
impl crate::sealed::Sealed for BiLstmCrfNER {}
// Marker: this backend produces named entities.
impl crate::NamedEntityCapable for BiLstmCrfNER {}
impl crate::BatchCapable for BiLstmCrfNER {
    // Mirrors `capabilities().optimal_batch_size` in the `Model` impl.
    fn optimal_batch_size(&self) -> Option<usize> {
        Some(32) // BiLSTM benefits from batching
    }
}
718
719#[cfg(test)]
720mod tests;