// anno/backends/bilstm_crf/mod.rs
1//! BiLSTM-CRF NER backend.
2//!
3//! Implements the dominant neural NER architecture from 2015-2018, before transformers.
4//! This architecture represents the pivotal transition from feature engineering to
5//! representation learning, while retaining the CRF layer's sequence modeling.
6//!
7//! # Historical Context
8//!
9//! The NER field evolved through three eras:
10//!
11//! ```text
12//! Era 1: Rule-based (1987-1997) - Lexicons, hand-crafted patterns
13//! Era 2: Statistical (1997-2015) - HMM → MEMM → CRF (feature engineering)
14//! Era 3: Neural (2011-present) - CNN → BiLSTM-CRF → Transformers
15//! ```
16//!
17//! BiLSTM-CRF bridged statistical and neural approaches:
18//! - **BiLSTM**: Learns features automatically from data (no feature engineering)
19//! - **CRF layer**: Retains structured prediction from statistical era
20//!
21//! Collobert et al. 2011 ("NLP from Scratch") first showed CNNs for NER, but
22//! BiLSTM-CRF (2015) became the dominant architecture until BERT (2018).
23//!
24//! # Why Keep the CRF Layer?
25//!
26//! The BiLSTM produces emission scores for each position, but doesn't model
27//! label dependencies. The CRF layer ensures:
28//! - Valid BIO sequences (no `I-PER` after `O`)
29//! - Learned transition patterns (e.g., `B-ORG` often followed by `I-ORG`)
30//!
31//! ```text
32//! Without CRF: BiLSTM predicts [B-PER, I-ORG, O, B-LOC] // invalid!
33//! With CRF: Viterbi finds [B-PER, O, O, B-LOC] // valid sequence
34//! ```
35//!
36//! # Architecture
37//!
38//! ```text
39//! Input: "John works at Google"
40//! ↓
41//! ┌─────────────────────────────────────────┐
42//! │ Word Embeddings (GloVe/Word2Vec) │
43//! │ + Character Embeddings (CNN/LSTM) │
44//! └─────────────────────────────────────────┘
45//! ↓
46//! ┌─────────────────────────────────────────┐
47//! │ Bidirectional LSTM │
48//! │ Forward: h₁ → h₂ → h₃ → h₄ │
49//! │ Backward: h₁ ← h₂ ← h₃ ← h₄ │
50//! │ Concat: [h→;h←] for each position │
51//! └─────────────────────────────────────────┘
52//! ↓
53//! ┌─────────────────────────────────────────┐
54//! │ CRF Layer │
55//! │ - Emission scores from BiLSTM │
56//! │ - Transition matrix learned │
57//! │ - Viterbi decoding for best sequence │
58//! └─────────────────────────────────────────┘
59//! ↓
60//! Output: B-PER O O B-ORG
61//! ```
62//!
63//! # Key Papers
64//!
65//! - Collobert et al. 2011: "Natural Language Processing (Almost) from Scratch"
66//! - Huang et al. 2015: "Bidirectional LSTM-CRF Models for Sequence Tagging"
67//! - Lample et al. 2016: "Neural Architectures for Named Entity Recognition"
68//! - Ma & Hovy 2016: "End-to-end Sequence Labeling via Bi-directional LSTM-CNNs-CRF"
69//! - Peters et al. 2018: "Deep Contextualized Word Representations" (ELMo)
70//!
71//! # References
72//!
73//! - Collobert, Weston, Bottou, et al. (2011): "Natural Language Processing
74//! (Almost) from Scratch" (JMLR) — first neural NER
75//! - Huang, Xu, Yu (2015): "Bidirectional LSTM-CRF Models for Sequence
76//! Tagging" (arXiv:1508.01991) — introduced BiLSTM-CRF
77//! - Lample, Ballesteros, Subramanian, et al. (2016): "Neural Architectures
78//! for Named Entity Recognition" (NAACL) — char embeddings
79//! - Ma & Hovy (2016): "End-to-end Sequence Labeling via Bi-directional
80//! LSTM-CNNs-CRF" (ACL) — CNN char encoder
81//!
82//! # See Also
83//!
84//! - Historical NER baselines (HMM/CRF-era sequence models)
85//!
86//! # Usage
87//!
88//! ```rust
89//! use anno::backends::bilstm_crf::BiLstmCrfNER;
90//! use anno::Model;
91//!
92//! // Create with heuristic weights (no neural inference)
93//! let ner = BiLstmCrfNER::new();
94//! let entities = ner.extract_entities("John works at Google", None).unwrap();
95//! ```
96//!
97//! With ONNX feature enabled, load pre-trained weights:
98//!
99//! ```rust,ignore
100//! // Requires: features = ["onnx"]
101//! let ner = BiLstmCrfNER::from_onnx("path/to/model.onnx")?;
102//! ```
103
104use crate::{Entity, EntityType, Model, Result};
105use std::collections::HashMap;
106
/// BiLSTM-CRF configuration.
///
/// Hyperparameters describing the intended neural encoder.
/// NOTE(review): the heuristic inference path in this module does not read
/// these values — the `config` field on `BiLstmCrfNER` is marked
/// `#[allow(dead_code)]` ("reserved for model serialization") — so they only
/// take effect once real weights/serialization are wired up.
#[derive(Debug, Clone)]
pub struct BiLstmCrfConfig {
    /// Hidden size for LSTM layers.
    pub hidden_size: usize,
    /// Number of LSTM layers.
    pub num_layers: usize,
    /// Dropout probability.
    pub dropout: f32,
    /// Whether to use character-level embeddings.
    pub use_char_embeddings: bool,
    /// Maximum sequence length.
    pub max_seq_len: usize,
}
121
122impl Default for BiLstmCrfConfig {
123 fn default() -> Self {
124 Self {
125 hidden_size: 256,
126 num_layers: 2,
127 dropout: 0.5,
128 use_char_embeddings: true,
129 max_seq_len: 512,
130 }
131 }
132}
133
/// BiLSTM-CRF NER model.
///
/// This implements the classic neural NER architecture that dominated
/// from 2015-2018 before transformer models.
///
/// # Components
///
/// 1. **Word Embeddings**: Pre-trained (GloVe/Word2Vec) or learned
/// 2. **Character Embeddings**: CNN or LSTM over characters (optional)
/// 3. **BiLSTM Encoder**: Bidirectional LSTM for context
/// 4. **CRF Decoder**: Structured prediction with transition constraints
///
/// NOTE(review): only the CRF half (transition matrix + Viterbi) is realized
/// here; emission scores come from heuristic features, and the loaded ONNX
/// `session` is not consulted by `extract_entities` — confirm intended wiring.
#[derive(Debug)]
pub struct BiLstmCrfNER {
    /// Model configuration.
    #[allow(dead_code)] // Reserved for model serialization
    config: BiLstmCrfConfig,
    /// BIO labels for decoding; index = label id produced by the decoder.
    labels: Vec<String>,
    /// Label to index mapping (inverse of `labels`).
    label_to_idx: HashMap<String, usize>,
    /// Transition scores (from CRF layer), indexed `transitions[from][to]`.
    transitions: Vec<Vec<f64>>,
    /// Word vocabulary (word -> embedding index).
    #[allow(dead_code)] // Reserved for embedding lookup
    vocab: HashMap<String, usize>,
    /// ONNX session for inference (when onnx feature enabled).
    #[cfg(feature = "onnx")]
    session: Option<ort::session::Session>,
}
163
164impl BiLstmCrfNER {
165 /// Create a new BiLSTM-CRF model with default configuration.
166 ///
167 /// This creates a model that uses heuristic-based inference
168 /// (no neural weights). For actual neural inference, use
169 /// `from_onnx()` to load pre-trained weights.
170 #[must_use]
171 pub fn new() -> Self {
172 Self::with_config(BiLstmCrfConfig::default())
173 }
174
175 /// Create with custom configuration.
176 #[must_use]
177 pub fn with_config(config: BiLstmCrfConfig) -> Self {
178 let labels = vec![
179 "O".to_string(),
180 "B-PER".to_string(),
181 "I-PER".to_string(),
182 "B-ORG".to_string(),
183 "I-ORG".to_string(),
184 "B-LOC".to_string(),
185 "I-LOC".to_string(),
186 "B-MISC".to_string(),
187 "I-MISC".to_string(),
188 ];
189
190 let label_to_idx: HashMap<String, usize> = labels
191 .iter()
192 .enumerate()
193 .map(|(i, l)| (l.clone(), i))
194 .collect();
195
196 // Initialize transition matrix with sensible defaults
197 // Higher scores for valid BIO transitions
198 let n = labels.len();
199 let mut transitions = vec![vec![0.0; n]; n];
200
201 // BIO constraints: I-X can only follow B-X or I-X
202 for i in 0..n {
203 for j in 0..n {
204 let from_label = &labels[i];
205 let to_label = &labels[j];
206
207 if let Some(entity_type) = to_label.strip_prefix("I-") {
208 let valid_prev = format!("B-{}", entity_type);
209 let valid_cont = format!("I-{}", entity_type);
210
211 if from_label == &valid_prev || from_label == &valid_cont {
212 transitions[i][j] = 1.0; // Valid transition
213 } else {
214 transitions[i][j] = -10.0; // Invalid transition
215 }
216 } else {
217 // B-X or O can follow anything
218 transitions[i][j] = 0.0;
219 }
220 }
221 }
222
223 Self {
224 config,
225 labels,
226 label_to_idx,
227 transitions,
228 vocab: HashMap::new(),
229 #[cfg(feature = "onnx")]
230 session: None,
231 }
232 }
233
    /// Load from ONNX model file.
    ///
    /// Builds an `ort` session with full graph optimization and attaches it
    /// to a model otherwise identical to `BiLstmCrfNER::new()`.
    ///
    /// # Errors
    ///
    /// Returns a model-init error if the session builder cannot be created,
    /// the optimization level cannot be applied, or the file fails to load.
    ///
    /// NOTE(review): the session is stored but `extract_entities` still uses
    /// the heuristic emission path; confirm where ONNX inference is invoked.
    #[cfg(feature = "onnx")]
    pub fn from_onnx(model_path: &str) -> Result<Self> {
        use crate::Error;
        use ort::session::{builder::GraphOptimizationLevel, Session};

        let session = Session::builder()
            .map_err(|e| Error::model_init(format!("Failed to create session builder: {}", e)))?
            .with_optimization_level(GraphOptimizationLevel::Level3)
            .map_err(|e| Error::model_init(format!("Failed to set optimization level: {}", e)))?
            .commit_from_file(model_path)
            .map_err(|e| Error::model_init(format!("Failed to load ONNX model: {}", e)))?;

        // Start from the default (heuristic) model and attach the session.
        let mut model = Self::new();
        model.session = Some(session);
        Ok(model)
    }
251
252 /// Tokenize text into words.
253 fn tokenize(text: &str) -> Vec<&str> {
254 text.split_whitespace().collect()
255 }
256
257 /// Get emission scores for each token.
258 ///
259 /// In a full implementation, this would run the BiLSTM.
260 /// Here we use realistic heuristic features as a fallback,
261 /// combining gazetteers, word shape, and contextual patterns.
262 fn get_emissions(&self, tokens: &[&str]) -> Vec<Vec<f64>> {
263 let n_labels = self.labels.len();
264 let mut emissions = vec![vec![0.0; n_labels]; tokens.len()];
265
266 // Gazetteers for better heuristic accuracy
267 const PERSON_NAMES: &[&str] = &[
268 "john",
269 "mary",
270 "james",
271 "david",
272 "michael",
273 "robert",
274 "william",
275 "richard",
276 "sarah",
277 "jennifer",
278 "elizabeth",
279 "lisa",
280 "marie",
281 "jane",
282 "emily",
283 "anna",
284 "barack",
285 "donald",
286 "joe",
287 "george",
288 "bill",
289 "hillary",
290 "elon",
291 "jeff",
292 "mr",
293 "mrs",
294 "ms",
295 "dr",
296 "prof",
297 "sir",
298 "lord",
299 "president",
300 "ceo",
301 ];
302 const ORG_NAMES: &[&str] = &[
303 "google",
304 "apple",
305 "microsoft",
306 "amazon",
307 "facebook",
308 "meta",
309 "tesla",
310 "ibm",
311 "intel",
312 "nvidia",
313 "oracle",
314 "cisco",
315 "adobe",
316 "netflix",
317 "uber",
318 "university",
319 "institute",
320 "corporation",
321 "company",
322 "inc",
323 "corp",
324 "ltd",
325 "llc",
326 "foundation",
327 "association",
328 "organization",
329 "department",
330 "agency",
331 "fbi",
332 "cia",
333 "nsa",
334 "nasa",
335 "un",
336 "nato",
337 "who",
338 "imf",
339 "eu",
340 "usa",
341 ];
342 const LOC_NAMES: &[&str] = &[
343 "new",
344 "york",
345 "california",
346 "texas",
347 "florida",
348 "london",
349 "paris",
350 "berlin",
351 "tokyo",
352 "beijing",
353 "moscow",
354 "washington",
355 "chicago",
356 "boston",
357 "seattle",
358 "san",
359 "francisco",
360 "los",
361 "angeles",
362 "las",
363 "vegas",
364 "united",
365 "states",
366 "america",
367 "china",
368 "russia",
369 "germany",
370 "france",
371 "japan",
372 "india",
373 "brazil",
374 "city",
375 "county",
376 "state",
377 "country",
378 "river",
379 "mountain",
380 "lake",
381 "ocean",
382 ];
383
384 for (i, token) in tokens.iter().enumerate() {
385 let lower = token.to_lowercase();
386 let is_capitalized = token.chars().next().is_some_and(|c| c.is_uppercase());
387 let is_all_caps = token
388 .chars()
389 .all(|c| c.is_uppercase() || !c.is_alphabetic())
390 && token.len() > 1;
391 let has_digit = token.chars().any(|c| c.is_ascii_digit());
392 let is_first = i == 0;
393
394 // Default: bias toward O (entities are rare)
395 emissions[i][0] = 1.5;
396
397 // Gazetteer matches (strongest signal)
398 if PERSON_NAMES.contains(&lower.as_str()) {
399 emissions[i][self.label_to_idx["B-PER"]] += 2.0;
400 emissions[i][self.label_to_idx["I-PER"]] += 1.0;
401 }
402 if ORG_NAMES.contains(&lower.as_str()) {
403 emissions[i][self.label_to_idx["B-ORG"]] += 2.0;
404 emissions[i][self.label_to_idx["I-ORG"]] += 1.0;
405 }
406 if LOC_NAMES.contains(&lower.as_str()) {
407 emissions[i][self.label_to_idx["B-LOC"]] += 2.0;
408 emissions[i][self.label_to_idx["I-LOC"]] += 1.0;
409 }
410
411 // Capitalization (weaker signal, context-dependent)
412 if is_capitalized && !has_digit && !is_first {
413 emissions[i][self.label_to_idx["B-PER"]] += 0.8;
414 emissions[i][self.label_to_idx["B-ORG"]] += 0.6;
415 emissions[i][self.label_to_idx["B-LOC"]] += 0.5;
416 }
417
418 // Organization suffixes
419 if lower.ends_with("inc.")
420 || lower.ends_with("corp.")
421 || lower.ends_with("ltd.")
422 || lower.ends_with("llc")
423 || lower.ends_with("co.")
424 {
425 emissions[i][self.label_to_idx["B-ORG"]] += 1.5;
426 emissions[i][self.label_to_idx["I-ORG"]] += 1.0;
427 }
428
429 // Acronyms (2-5 uppercase letters)
430 if is_all_caps && token.len() >= 2 && token.len() <= 5 && !has_digit {
431 emissions[i][self.label_to_idx["B-ORG"]] += 1.2;
432 }
433
434 // Honorifics signal person
435 if ["mr.", "mrs.", "ms.", "dr.", "prof."].contains(&lower.as_str()) {
436 emissions[i][self.label_to_idx["B-PER"]] += 1.5;
437 }
438
439 // "The" before proper noun often signals ORG or LOC
440 if i > 0 && tokens[i - 1].to_lowercase() == "the" && is_capitalized {
441 emissions[i][self.label_to_idx["B-ORG"]] += 0.5;
442 emissions[i][self.label_to_idx["B-LOC"]] += 0.3;
443 }
444
445 // Multi-word entity continuation
446 if i > 0 {
447 let prev_cap = tokens[i - 1]
448 .chars()
449 .next()
450 .is_some_and(|c| c.is_uppercase());
451 if prev_cap && is_capitalized && !is_first {
452 // Likely continuation of entity
453 emissions[i][self.label_to_idx["I-PER"]] += 0.6;
454 emissions[i][self.label_to_idx["I-ORG"]] += 0.6;
455 emissions[i][self.label_to_idx["I-LOC"]] += 0.4;
456 }
457 }
458 }
459
460 emissions
461 }
462
463 /// Viterbi decoding with CRF transitions.
464 fn viterbi_decode(&self, emissions: &[Vec<f64>]) -> Vec<usize> {
465 if emissions.is_empty() {
466 return vec![];
467 }
468
469 let n = emissions.len();
470 let m = self.labels.len();
471
472 // DP tables
473 let mut scores = vec![vec![f64::NEG_INFINITY; m]; n];
474 let mut backpointers = vec![vec![0usize; m]; n];
475
476 // Initialize first position
477 for j in 0..m {
478 scores[0][j] = emissions[0][j];
479 }
480
481 // Forward pass
482 for i in 1..n {
483 for j in 0..m {
484 let mut best_score = f64::NEG_INFINITY;
485 let mut best_prev = 0;
486
487 #[allow(clippy::needless_range_loop)]
488 for k in 0..m {
489 let score = scores[i - 1][k] + self.transitions[k][j] + emissions[i][j];
490 if score > best_score {
491 best_score = score;
492 best_prev = k;
493 }
494 }
495
496 scores[i][j] = best_score;
497 backpointers[i][j] = best_prev;
498 }
499 }
500
501 // Backward pass
502 let mut path = vec![0usize; n];
503 let mut best_final = 0;
504 let mut best_score = f64::NEG_INFINITY;
505
506 for (j, &score) in scores[n - 1].iter().enumerate() {
507 if score > best_score {
508 best_score = score;
509 best_final = j;
510 }
511 }
512
513 path[n - 1] = best_final;
514 for i in (0..n - 1).rev() {
515 path[i] = backpointers[i + 1][path[i + 1]];
516 }
517
518 path
519 }
520
    /// Convert BIO labels to entities.
    ///
    /// Scans the decoded labels once, merging each `B-X` and its following
    /// run of `I-*` labels into a single entity span; `O` closes any open
    /// span, and the final open span is flushed after the loop.
    ///
    /// Uses token position tracking to correctly handle duplicate entity texts.
    /// The previous implementation used `text.find()` which always returned the
    /// first occurrence, causing incorrect offsets for duplicate entities.
    fn labels_to_entities(
        &self,
        text: &str,
        tokens: &[&str],
        label_indices: &[usize],
    ) -> Vec<Entity> {
        use crate::offset::SpanConverter;

        let converter = SpanConverter::new(text);
        let mut entities = Vec::new();

        // Track token positions (byte offsets) as we iterate
        let token_positions: Vec<(usize, usize)> = Self::calculate_token_positions(text, tokens);

        // Open entity: (start token idx, end token idx, type, collected words).
        let mut current_entity: Option<(usize, usize, EntityType, Vec<&str>)> = None;

        for (i, (&label_idx, &token)) in label_indices.iter().zip(tokens.iter()).enumerate() {
            let label = &self.labels[label_idx];

            if let Some(entity_suffix) = label.strip_prefix("B-") {
                // Save previous entity if any
                if let Some((start_token_idx, end_token_idx, entity_type, words)) =
                    current_entity.take()
                {
                    Self::push_entity_from_positions(
                        &converter,
                        &token_positions,
                        start_token_idx,
                        end_token_idx,
                        &words,
                        entity_type,
                        &mut entities,
                    );
                }

                // Start new entity
                let entity_type = match entity_suffix {
                    "PER" => EntityType::Person,
                    "ORG" => EntityType::Organization,
                    "LOC" => EntityType::Location,
                    other => EntityType::Other(other.to_string()),
                };
                current_entity = Some((i, i, entity_type, vec![token]));
            } else if label.starts_with("I-") && current_entity.is_some() {
                // Continue current entity.
                // NOTE(review): the I- suffix type is not checked against the
                // open entity's type, so a stray `I-ORG` after `B-PER` would be
                // absorbed; the -10.0 CRF transition penalty makes this rare
                // but not impossible — confirm acceptable.
                if let Some((_, ref mut end_idx, _, ref mut words)) = current_entity {
                    words.push(token);
                    *end_idx = i;
                }
            } else {
                // O label (an orphan I- with no open entity also lands here) -
                // save and reset
                if let Some((start_token_idx, end_token_idx, entity_type, words)) =
                    current_entity.take()
                {
                    Self::push_entity_from_positions(
                        &converter,
                        &token_positions,
                        start_token_idx,
                        end_token_idx,
                        &words,
                        entity_type,
                        &mut entities,
                    );
                }
            }
        }

        // Don't forget last entity
        if let Some((start_token_idx, end_token_idx, entity_type, words)) = current_entity.take() {
            Self::push_entity_from_positions(
                &converter,
                &token_positions,
                start_token_idx,
                end_token_idx,
                &words,
                entity_type,
                &mut entities,
            );
        }

        entities
    }
608
609 /// Calculate byte positions for each token in the text.
610 fn calculate_token_positions(text: &str, tokens: &[&str]) -> Vec<(usize, usize)> {
611 let mut positions = Vec::with_capacity(tokens.len());
612 let mut byte_pos = 0;
613
614 for token in tokens {
615 // Find token starting from current position
616 if let Some(rel_pos) = text[byte_pos..].find(token) {
617 let start = byte_pos + rel_pos;
618 let end = start + token.len();
619 positions.push((start, end));
620 byte_pos = end; // Move past this token
621 } else {
622 // Fallback: use current position (shouldn't happen with whitespace tokenization)
623 positions.push((byte_pos, byte_pos));
624 }
625 }
626
627 positions
628 }
629
630 /// Helper to push entity using tracked token positions.
631 fn push_entity_from_positions(
632 converter: &crate::offset::SpanConverter,
633 positions: &[(usize, usize)],
634 start_token_idx: usize,
635 end_token_idx: usize,
636 words: &[&str],
637 entity_type: EntityType,
638 entities: &mut Vec<Entity>,
639 ) {
640 if start_token_idx >= positions.len() || end_token_idx >= positions.len() {
641 return;
642 }
643
644 let byte_start = positions[start_token_idx].0;
645 let byte_end = positions[end_token_idx].1;
646 let char_start = converter.byte_to_char(byte_start);
647 let char_end = converter.byte_to_char(byte_end);
648 let entity_text = words.join(" ");
649
650 entities.push(Entity::new(
651 entity_text,
652 entity_type,
653 char_start,
654 char_end,
655 0.75, // BiLSTM-CRF confidence
656 ));
657 }
658}
659
660impl Default for BiLstmCrfNER {
661 fn default() -> Self {
662 Self::new()
663 }
664}
665
666impl Model for BiLstmCrfNER {
667 fn extract_entities(&self, text: &str, _language: Option<&str>) -> Result<Vec<Entity>> {
668 if text.trim().is_empty() {
669 return Ok(vec![]);
670 }
671
672 let tokens = Self::tokenize(text);
673 if tokens.is_empty() {
674 return Ok(vec![]);
675 }
676
677 // Get emission scores (from BiLSTM or heuristics)
678 let emissions = self.get_emissions(&tokens);
679
680 // Viterbi decode with CRF transitions
681 let label_indices = self.viterbi_decode(&emissions);
682
683 // Convert to entities
684 let entities = self.labels_to_entities(text, &tokens, &label_indices);
685
686 Ok(entities)
687 }
688
689 fn supported_types(&self) -> Vec<EntityType> {
690 vec![
691 EntityType::Person,
692 EntityType::Organization,
693 EntityType::Location,
694 EntityType::Other("MISC".to_string()),
695 ]
696 }
697
698 fn is_available(&self) -> bool {
699 true // Always available with heuristic fallback
700 }
701
702 fn capabilities(&self) -> crate::ModelCapabilities {
703 crate::ModelCapabilities {
704 batch_capable: true,
705 optimal_batch_size: Some(32),
706 ..Default::default()
707 }
708 }
709}
710
// Opt into the crate's sealed capability-trait hierarchy.
impl crate::sealed::Sealed for BiLstmCrfNER {}
// Marker: advertises plain named-entity extraction.
impl crate::NamedEntityCapable for BiLstmCrfNER {}
impl crate::BatchCapable for BiLstmCrfNER {
    /// Preferred batch size; matches `capabilities().optimal_batch_size`.
    fn optimal_batch_size(&self) -> Option<usize> {
        Some(32) // BiLSTM benefits from batching
    }
}
718
// Unit tests live in the sibling `tests.rs` file of this module.
#[cfg(test)]
mod tests;