Skip to main content

anno/backends/
bilstm_crf.rs

1//! BiLSTM-CRF NER backend.
2//!
3//! Implements the dominant neural NER architecture from 2015-2018, before transformers.
4//! This architecture represents the pivotal transition from feature engineering to
5//! representation learning, while retaining the CRF layer's sequence modeling.
6//!
7//! # Historical Context
8//!
9//! The NER field evolved through three eras:
10//!
11//! ```text
12//! Era 1: Rule-based (1987-1997)     - Lexicons, hand-crafted patterns
13//! Era 2: Statistical (1997-2015)    - HMM → MEMM → CRF (feature engineering)
14//! Era 3: Neural (2011-present)      - CNN → BiLSTM-CRF → Transformers
15//! ```
16//!
17//! BiLSTM-CRF bridged statistical and neural approaches:
18//! - **BiLSTM**: Learns features automatically from data (no feature engineering)
19//! - **CRF layer**: Retains structured prediction from statistical era
20//!
21//! Collobert et al. 2011 ("NLP from Scratch") first showed CNNs for NER, but
22//! BiLSTM-CRF (2015) became the dominant architecture until BERT (2018).
23//!
24//! # Why Keep the CRF Layer?
25//!
26//! The BiLSTM produces emission scores for each position, but doesn't model
27//! label dependencies. The CRF layer ensures:
28//! - Valid BIO sequences (no `I-PER` after `O`)
29//! - Learned transition patterns (e.g., `B-ORG` often followed by `I-ORG`)
30//!
31//! ```text
32//! Without CRF:  BiLSTM predicts [B-PER, I-ORG, O, B-LOC]  // invalid!
33//! With CRF:     Viterbi finds   [B-PER, O,     O, B-LOC]  // valid sequence
34//! ```
35//!
36//! # Architecture
37//!
38//! ```text
39//! Input: "John works at Google"
40//!    ↓
41//! ┌─────────────────────────────────────────┐
42//! │ Word Embeddings (GloVe/Word2Vec)        │
43//! │ + Character Embeddings (CNN/LSTM)       │
44//! └─────────────────────────────────────────┘
45//!    ↓
46//! ┌─────────────────────────────────────────┐
47//! │ Bidirectional LSTM                      │
48//! │  Forward:  h₁ → h₂ → h₃ → h₄           │
49//! │  Backward: h₁ ← h₂ ← h₃ ← h₄           │
50//! │  Concat:   [h→;h←] for each position    │
51//! └─────────────────────────────────────────┘
52//!    ↓
53//! ┌─────────────────────────────────────────┐
54//! │ CRF Layer                               │
55//! │  - Emission scores from BiLSTM          │
56//! │  - Transition matrix learned            │
57//! │  - Viterbi decoding for best sequence   │
58//! └─────────────────────────────────────────┘
59//!    ↓
60//! Output: B-PER O O B-ORG
61//! ```
62//!
63//! # Key Papers
64//!
65//! - Collobert et al. 2011: "Natural Language Processing (Almost) from Scratch"
66//! - Huang et al. 2015: "Bidirectional LSTM-CRF Models for Sequence Tagging"
67//! - Lample et al. 2016: "Neural Architectures for Named Entity Recognition"
68//! - Ma & Hovy 2016: "End-to-end Sequence Labeling via Bi-directional LSTM-CNNs-CRF"
69//! - Peters et al. 2018: "Deep Contextualized Word Representations" (ELMo)
70//!
71//! # References
72//!
73//! - Collobert, Weston, Bottou, et al. (2011): "Natural Language Processing
74//!   (Almost) from Scratch" (JMLR) — first neural NER
75//! - Huang, Xu, Yu (2015): "Bidirectional LSTM-CRF Models for Sequence
76//!   Tagging" (arXiv:1508.01991) — introduced BiLSTM-CRF
77//! - Lample, Ballesteros, Subramanian, et al. (2016): "Neural Architectures
78//!   for Named Entity Recognition" (NAACL) — char embeddings
79//! - Ma & Hovy (2016): "End-to-end Sequence Labeling via Bi-directional
80//!   LSTM-CNNs-CRF" (ACL) — CNN char encoder
81//!
82//! # See Also
83//!
84//! - Historical NER baselines (HMM/CRF-era sequence models)
85//!
86//! # Usage
87//!
88//! ```rust
89//! use anno::backends::bilstm_crf::BiLstmCrfNER;
90//! use anno::Model;
91//!
92//! // Create with heuristic weights (no neural inference)
93//! let ner = BiLstmCrfNER::new();
94//! let entities = ner.extract_entities("John works at Google", None).unwrap();
95//! ```
96//!
97//! With ONNX feature enabled, load pre-trained weights:
98//!
99//! ```rust,ignore
100//! // Requires: features = ["onnx"]
101//! let ner = BiLstmCrfNER::from_onnx("path/to/model.onnx")?;
102//! ```
103
104use crate::{Entity, EntityType, Model, Result};
105use std::collections::HashMap;
106
/// BiLSTM-CRF configuration.
///
/// Plain data holder for the network hyper-parameters. Use
/// [`BiLstmCrfConfig::default`] for the standard settings and struct-update
/// syntax to override individual fields.
///
/// `PartialEq` is derived so configurations can be compared directly
/// (for example in `assert_eq!`). `Eq` is intentionally not derived because
/// `dropout` is an `f32`.
#[derive(Debug, Clone, PartialEq)]
pub struct BiLstmCrfConfig {
    /// Hidden size for LSTM layers.
    pub hidden_size: usize,
    /// Number of LSTM layers.
    pub num_layers: usize,
    /// Dropout probability.
    pub dropout: f32,
    /// Whether to use character-level embeddings.
    pub use_char_embeddings: bool,
    /// Maximum sequence length.
    pub max_seq_len: usize,
}

impl Default for BiLstmCrfConfig {
    /// Returns the default settings: 2 layers of 256 hidden units, dropout
    /// 0.5, character embeddings enabled, and a maximum sequence length of 512.
    fn default() -> Self {
        Self {
            hidden_size: 256,
            num_layers: 2,
            dropout: 0.5,
            use_char_embeddings: true,
            max_seq_len: 512,
        }
    }
}
133
134/// BiLSTM-CRF NER model.
135///
136/// This implements the classic neural NER architecture that dominated
137/// from 2015-2018 before transformer models.
138///
139/// # Components
140///
141/// 1. **Word Embeddings**: Pre-trained (GloVe/Word2Vec) or learned
142/// 2. **Character Embeddings**: CNN or LSTM over characters (optional)
143/// 3. **BiLSTM Encoder**: Bidirectional LSTM for context
144/// 4. **CRF Decoder**: Structured prediction with transition constraints
145#[derive(Debug)]
146pub struct BiLstmCrfNER {
147    /// Model configuration.
148    #[allow(dead_code)] // Reserved for model serialization
149    config: BiLstmCrfConfig,
150    /// BIO labels for decoding.
151    labels: Vec<String>,
152    /// Label to index mapping.
153    label_to_idx: HashMap<String, usize>,
154    /// Transition scores (from CRF layer).
155    transitions: Vec<Vec<f64>>,
156    /// Word vocabulary (word -> embedding index).
157    #[allow(dead_code)] // Reserved for embedding lookup
158    vocab: HashMap<String, usize>,
159    /// ONNX session for inference (when onnx feature enabled).
160    #[cfg(feature = "onnx")]
161    session: Option<ort::session::Session>,
162}
163
164impl BiLstmCrfNER {
165    /// Create a new BiLSTM-CRF model with default configuration.
166    ///
167    /// This creates a model that uses heuristic-based inference
168    /// (no neural weights). For actual neural inference, use
169    /// `from_onnx()` to load pre-trained weights.
170    #[must_use]
171    pub fn new() -> Self {
172        Self::with_config(BiLstmCrfConfig::default())
173    }
174
175    /// Create with custom configuration.
176    #[must_use]
177    pub fn with_config(config: BiLstmCrfConfig) -> Self {
178        let labels = vec![
179            "O".to_string(),
180            "B-PER".to_string(),
181            "I-PER".to_string(),
182            "B-ORG".to_string(),
183            "I-ORG".to_string(),
184            "B-LOC".to_string(),
185            "I-LOC".to_string(),
186            "B-MISC".to_string(),
187            "I-MISC".to_string(),
188        ];
189
190        let label_to_idx: HashMap<String, usize> = labels
191            .iter()
192            .enumerate()
193            .map(|(i, l)| (l.clone(), i))
194            .collect();
195
196        // Initialize transition matrix with sensible defaults
197        // Higher scores for valid BIO transitions
198        let n = labels.len();
199        let mut transitions = vec![vec![0.0; n]; n];
200
201        // BIO constraints: I-X can only follow B-X or I-X
202        for i in 0..n {
203            for j in 0..n {
204                let from_label = &labels[i];
205                let to_label = &labels[j];
206
207                if let Some(entity_type) = to_label.strip_prefix("I-") {
208                    let valid_prev = format!("B-{}", entity_type);
209                    let valid_cont = format!("I-{}", entity_type);
210
211                    if from_label == &valid_prev || from_label == &valid_cont {
212                        transitions[i][j] = 1.0; // Valid transition
213                    } else {
214                        transitions[i][j] = -10.0; // Invalid transition
215                    }
216                } else {
217                    // B-X or O can follow anything
218                    transitions[i][j] = 0.0;
219                }
220            }
221        }
222
223        Self {
224            config,
225            labels,
226            label_to_idx,
227            transitions,
228            vocab: HashMap::new(),
229            #[cfg(feature = "onnx")]
230            session: None,
231        }
232    }
233
/// Create a model backed by an ONNX file.
///
/// Builds an ONNX Runtime session at optimization level 3 from
/// `model_path` and stores it in an otherwise default model.
///
/// # Errors
///
/// Returns a model-initialization error if the session builder cannot be
/// created, the optimization level cannot be applied, or the file cannot
/// be loaded.
#[cfg(feature = "onnx")]
pub fn from_onnx(model_path: &str) -> Result<Self> {
    use crate::Error;
    use ort::session::{builder::GraphOptimizationLevel, Session};

    let session = Session::builder()
        .map_err(|e| Error::model_init(format!("Failed to create session builder: {e}")))?
        .with_optimization_level(GraphOptimizationLevel::Level3)
        .map_err(|e| Error::model_init(format!("Failed to set optimization level: {e}")))?
        .commit_from_file(model_path)
        .map_err(|e| Error::model_init(format!("Failed to load ONNX model: {e}")))?;

    // Everything except the session comes from the default model.
    Ok(Self {
        session: Some(session),
        ..Self::new()
    })
}
251
/// Split `text` into whitespace-delimited tokens.
///
/// Runs of whitespace act as a single separator and never produce empty
/// tokens. Punctuation stays attached to the neighbouring word.
fn tokenize(text: &str) -> Vec<&str> {
    text.split(char::is_whitespace)
        .filter(|piece| !piece.is_empty())
        .collect()
}
256
257    /// Get emission scores for each token.
258    ///
259    /// In a full implementation, this would run the BiLSTM.
260    /// Here we use realistic heuristic features as a fallback,
261    /// combining gazetteers, word shape, and contextual patterns.
262    fn get_emissions(&self, tokens: &[&str]) -> Vec<Vec<f64>> {
263        let n_labels = self.labels.len();
264        let mut emissions = vec![vec![0.0; n_labels]; tokens.len()];
265
266        // Gazetteers for better heuristic accuracy
267        const PERSON_NAMES: &[&str] = &[
268            "john",
269            "mary",
270            "james",
271            "david",
272            "michael",
273            "robert",
274            "william",
275            "richard",
276            "sarah",
277            "jennifer",
278            "elizabeth",
279            "lisa",
280            "marie",
281            "jane",
282            "emily",
283            "anna",
284            "barack",
285            "donald",
286            "joe",
287            "george",
288            "bill",
289            "hillary",
290            "elon",
291            "jeff",
292            "mr",
293            "mrs",
294            "ms",
295            "dr",
296            "prof",
297            "sir",
298            "lord",
299            "president",
300            "ceo",
301        ];
302        const ORG_NAMES: &[&str] = &[
303            "google",
304            "apple",
305            "microsoft",
306            "amazon",
307            "facebook",
308            "meta",
309            "tesla",
310            "ibm",
311            "intel",
312            "nvidia",
313            "oracle",
314            "cisco",
315            "adobe",
316            "netflix",
317            "uber",
318            "university",
319            "institute",
320            "corporation",
321            "company",
322            "inc",
323            "corp",
324            "ltd",
325            "llc",
326            "foundation",
327            "association",
328            "organization",
329            "department",
330            "agency",
331            "fbi",
332            "cia",
333            "nsa",
334            "nasa",
335            "un",
336            "nato",
337            "who",
338            "imf",
339            "eu",
340            "usa",
341        ];
342        const LOC_NAMES: &[&str] = &[
343            "new",
344            "york",
345            "california",
346            "texas",
347            "florida",
348            "london",
349            "paris",
350            "berlin",
351            "tokyo",
352            "beijing",
353            "moscow",
354            "washington",
355            "chicago",
356            "boston",
357            "seattle",
358            "san",
359            "francisco",
360            "los",
361            "angeles",
362            "las",
363            "vegas",
364            "united",
365            "states",
366            "america",
367            "china",
368            "russia",
369            "germany",
370            "france",
371            "japan",
372            "india",
373            "brazil",
374            "city",
375            "county",
376            "state",
377            "country",
378            "river",
379            "mountain",
380            "lake",
381            "ocean",
382        ];
383
384        for (i, token) in tokens.iter().enumerate() {
385            let lower = token.to_lowercase();
386            let is_capitalized = token.chars().next().is_some_and(|c| c.is_uppercase());
387            let is_all_caps = token
388                .chars()
389                .all(|c| c.is_uppercase() || !c.is_alphabetic())
390                && token.len() > 1;
391            let has_digit = token.chars().any(|c| c.is_ascii_digit());
392            let is_first = i == 0;
393
394            // Default: bias toward O (entities are rare)
395            emissions[i][0] = 1.5;
396
397            // Gazetteer matches (strongest signal)
398            if PERSON_NAMES.contains(&lower.as_str()) {
399                emissions[i][self.label_to_idx["B-PER"]] += 2.0;
400                emissions[i][self.label_to_idx["I-PER"]] += 1.0;
401            }
402            if ORG_NAMES.contains(&lower.as_str()) {
403                emissions[i][self.label_to_idx["B-ORG"]] += 2.0;
404                emissions[i][self.label_to_idx["I-ORG"]] += 1.0;
405            }
406            if LOC_NAMES.contains(&lower.as_str()) {
407                emissions[i][self.label_to_idx["B-LOC"]] += 2.0;
408                emissions[i][self.label_to_idx["I-LOC"]] += 1.0;
409            }
410
411            // Capitalization (weaker signal, context-dependent)
412            if is_capitalized && !has_digit && !is_first {
413                emissions[i][self.label_to_idx["B-PER"]] += 0.8;
414                emissions[i][self.label_to_idx["B-ORG"]] += 0.6;
415                emissions[i][self.label_to_idx["B-LOC"]] += 0.5;
416            }
417
418            // Organization suffixes
419            if lower.ends_with("inc.")
420                || lower.ends_with("corp.")
421                || lower.ends_with("ltd.")
422                || lower.ends_with("llc")
423                || lower.ends_with("co.")
424            {
425                emissions[i][self.label_to_idx["B-ORG"]] += 1.5;
426                emissions[i][self.label_to_idx["I-ORG"]] += 1.0;
427            }
428
429            // Acronyms (2-5 uppercase letters)
430            if is_all_caps && token.len() >= 2 && token.len() <= 5 && !has_digit {
431                emissions[i][self.label_to_idx["B-ORG"]] += 1.2;
432            }
433
434            // Honorifics signal person
435            if ["mr.", "mrs.", "ms.", "dr.", "prof."].contains(&lower.as_str()) {
436                emissions[i][self.label_to_idx["B-PER"]] += 1.5;
437            }
438
439            // "The" before proper noun often signals ORG or LOC
440            if i > 0 && tokens[i - 1].to_lowercase() == "the" && is_capitalized {
441                emissions[i][self.label_to_idx["B-ORG"]] += 0.5;
442                emissions[i][self.label_to_idx["B-LOC"]] += 0.3;
443            }
444
445            // Multi-word entity continuation
446            if i > 0 {
447                let prev_cap = tokens[i - 1]
448                    .chars()
449                    .next()
450                    .is_some_and(|c| c.is_uppercase());
451                if prev_cap && is_capitalized && !is_first {
452                    // Likely continuation of entity
453                    emissions[i][self.label_to_idx["I-PER"]] += 0.6;
454                    emissions[i][self.label_to_idx["I-ORG"]] += 0.6;
455                    emissions[i][self.label_to_idx["I-LOC"]] += 0.4;
456                }
457            }
458        }
459
460        emissions
461    }
462
463    /// Viterbi decoding with CRF transitions.
464    fn viterbi_decode(&self, emissions: &[Vec<f64>]) -> Vec<usize> {
465        if emissions.is_empty() {
466            return vec![];
467        }
468
469        let n = emissions.len();
470        let m = self.labels.len();
471
472        // DP tables
473        let mut scores = vec![vec![f64::NEG_INFINITY; m]; n];
474        let mut backpointers = vec![vec![0usize; m]; n];
475
476        // Initialize first position
477        for j in 0..m {
478            scores[0][j] = emissions[0][j];
479        }
480
481        // Forward pass
482        for i in 1..n {
483            for j in 0..m {
484                let mut best_score = f64::NEG_INFINITY;
485                let mut best_prev = 0;
486
487                #[allow(clippy::needless_range_loop)]
488                for k in 0..m {
489                    let score = scores[i - 1][k] + self.transitions[k][j] + emissions[i][j];
490                    if score > best_score {
491                        best_score = score;
492                        best_prev = k;
493                    }
494                }
495
496                scores[i][j] = best_score;
497                backpointers[i][j] = best_prev;
498            }
499        }
500
501        // Backward pass
502        let mut path = vec![0usize; n];
503        let mut best_final = 0;
504        let mut best_score = f64::NEG_INFINITY;
505
506        for (j, &score) in scores[n - 1].iter().enumerate() {
507            if score > best_score {
508                best_score = score;
509                best_final = j;
510            }
511        }
512
513        path[n - 1] = best_final;
514        for i in (0..n - 1).rev() {
515            path[i] = backpointers[i + 1][path[i + 1]];
516        }
517
518        path
519    }
520
521    /// Convert BIO labels to entities.
522    ///
523    /// Uses token position tracking to correctly handle duplicate entity texts.
524    /// The previous implementation used `text.find()` which always returned the
525    /// first occurrence, causing incorrect offsets for duplicate entities.
526    fn labels_to_entities(
527        &self,
528        text: &str,
529        tokens: &[&str],
530        label_indices: &[usize],
531    ) -> Vec<Entity> {
532        use crate::offset::SpanConverter;
533
534        let converter = SpanConverter::new(text);
535        let mut entities = Vec::new();
536
537        // Track token positions (byte offsets) as we iterate
538        let token_positions: Vec<(usize, usize)> = Self::calculate_token_positions(text, tokens);
539
540        let mut current_entity: Option<(usize, usize, EntityType, Vec<&str>)> = None;
541
542        for (i, (&label_idx, &token)) in label_indices.iter().zip(tokens.iter()).enumerate() {
543            let label = &self.labels[label_idx];
544
545            if let Some(entity_suffix) = label.strip_prefix("B-") {
546                // Save previous entity if any
547                if let Some((start_token_idx, end_token_idx, entity_type, words)) =
548                    current_entity.take()
549                {
550                    Self::push_entity_from_positions(
551                        &converter,
552                        &token_positions,
553                        start_token_idx,
554                        end_token_idx,
555                        &words,
556                        entity_type,
557                        &mut entities,
558                    );
559                }
560
561                // Start new entity
562                let entity_type = match entity_suffix {
563                    "PER" => EntityType::Person,
564                    "ORG" => EntityType::Organization,
565                    "LOC" => EntityType::Location,
566                    other => EntityType::Other(other.to_string()),
567                };
568                current_entity = Some((i, i, entity_type, vec![token]));
569            } else if label.starts_with("I-") && current_entity.is_some() {
570                // Continue current entity
571                if let Some((_, ref mut end_idx, _, ref mut words)) = current_entity {
572                    words.push(token);
573                    *end_idx = i;
574                }
575            } else {
576                // O label - save and reset
577                if let Some((start_token_idx, end_token_idx, entity_type, words)) =
578                    current_entity.take()
579                {
580                    Self::push_entity_from_positions(
581                        &converter,
582                        &token_positions,
583                        start_token_idx,
584                        end_token_idx,
585                        &words,
586                        entity_type,
587                        &mut entities,
588                    );
589                }
590            }
591        }
592
593        // Don't forget last entity
594        if let Some((start_token_idx, end_token_idx, entity_type, words)) = current_entity.take() {
595            Self::push_entity_from_positions(
596                &converter,
597                &token_positions,
598                start_token_idx,
599                end_token_idx,
600                &words,
601                entity_type,
602                &mut entities,
603            );
604        }
605
606        entities
607    }
608
/// Locate each token in `text` and return its `(start, end)` byte range.
///
/// Tokens are searched for left to right, each search starting where the
/// previous match ended, so repeated tokens resolve to successive
/// occurrences. A token that cannot be found yields an empty range at the
/// current search position and does not advance it.
fn calculate_token_positions(text: &str, tokens: &[&str]) -> Vec<(usize, usize)> {
    let mut cursor = 0usize;

    tokens
        .iter()
        .map(|&tok| match text[cursor..].find(tok) {
            Some(offset) => {
                let start = cursor + offset;
                // Resume the next search just past this token.
                cursor = start + tok.len();
                (start, cursor)
            }
            // Not expected for tokens produced by whitespace splitting.
            None => (cursor, cursor),
        })
        .collect()
}
629
630    /// Helper to push entity using tracked token positions.
631    fn push_entity_from_positions(
632        converter: &crate::offset::SpanConverter,
633        positions: &[(usize, usize)],
634        start_token_idx: usize,
635        end_token_idx: usize,
636        words: &[&str],
637        entity_type: EntityType,
638        entities: &mut Vec<Entity>,
639    ) {
640        if start_token_idx >= positions.len() || end_token_idx >= positions.len() {
641            return;
642        }
643
644        let byte_start = positions[start_token_idx].0;
645        let byte_end = positions[end_token_idx].1;
646        let char_start = converter.byte_to_char(byte_start);
647        let char_end = converter.byte_to_char(byte_end);
648        let entity_text = words.join(" ");
649
650        entities.push(Entity::new(
651            entity_text,
652            entity_type,
653            char_start,
654            char_end,
655            0.75, // BiLSTM-CRF confidence
656        ));
657    }
658}
659
660impl Default for BiLstmCrfNER {
661    fn default() -> Self {
662        Self::new()
663    }
664}
665
666impl Model for BiLstmCrfNER {
667    fn extract_entities(&self, text: &str, _language: Option<&str>) -> Result<Vec<Entity>> {
668        if text.trim().is_empty() {
669            return Ok(vec![]);
670        }
671
672        let tokens = Self::tokenize(text);
673        if tokens.is_empty() {
674            return Ok(vec![]);
675        }
676
677        // Get emission scores (from BiLSTM or heuristics)
678        let emissions = self.get_emissions(&tokens);
679
680        // Viterbi decode with CRF transitions
681        let label_indices = self.viterbi_decode(&emissions);
682
683        // Convert to entities
684        let entities = self.labels_to_entities(text, &tokens, &label_indices);
685
686        Ok(entities)
687    }
688
689    fn supported_types(&self) -> Vec<EntityType> {
690        vec![
691            EntityType::Person,
692            EntityType::Organization,
693            EntityType::Location,
694            EntityType::Other("MISC".to_string()),
695        ]
696    }
697
698    fn is_available(&self) -> bool {
699        true // Always available with heuristic fallback
700    }
701}
702
703impl crate::sealed::Sealed for BiLstmCrfNER {}
704impl crate::NamedEntityCapable for BiLstmCrfNER {}
705impl crate::BatchCapable for BiLstmCrfNER {
706    fn optimal_batch_size(&self) -> Option<usize> {
707        Some(32) // BiLSTM benefits from batching
708    }
709}
710
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns `true` when `path` is a well-formed BIO sequence: every
    /// `I-X` must directly follow `B-X` or `I-X` of the same type, which
    /// also rules out a leading `I-X`.
    fn is_valid_bio(labels: &[String], path: &[usize]) -> bool {
        path.iter().enumerate().all(|(i, &idx)| {
            let Some(kind) = labels[idx].strip_prefix("I-") else {
                return true;
            };
            i > 0 && {
                let prev = labels[path[i - 1]].as_str();
                prev.strip_prefix("B-")
                    .or_else(|| prev.strip_prefix("I-"))
                    == Some(kind)
            }
        })
    }

    #[test]
    fn test_basic_extraction() {
        let ner = BiLstmCrfNER::new();
        let text = "John Smith works at Google Inc.";
        let char_count = text.chars().count();
        let entities = ner.extract_entities(text, None).unwrap();

        // Exact results depend on heuristic tuning, so only check invariants
        // that must hold for whatever entities are returned.
        for entity in &entities {
            assert!(entity.confidence > 0.0 && entity.confidence <= 1.0);
            assert!(entity.start <= entity.end);
            assert!(entity.end <= char_count);
        }
    }

    #[test]
    fn test_empty_input() {
        let ner = BiLstmCrfNER::new();
        let entities = ner.extract_entities("", None).unwrap();
        assert!(entities.is_empty());
    }

    #[test]
    fn test_whitespace_only() {
        let ner = BiLstmCrfNER::new();
        let entities = ner.extract_entities("   \n\t  ", None).unwrap();
        assert!(entities.is_empty());
    }

    #[test]
    fn test_viterbi_respects_bio_constraints() {
        let ner = BiLstmCrfNER::new();

        // Taken position by position, these emissions prefer O followed by
        // I-PER, which is not a valid BIO sequence.
        let emissions = vec![
            vec![0.5, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], // O preferred
            vec![0.1, 0.1, 0.8, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], // I-PER has high score
        ];

        let path = ner.viterbi_decode(&emissions);
        assert_eq!(path.len(), emissions.len());

        // Checked unconditionally: the earlier version of this test only
        // asserted when the first label was O, which this input never
        // produces, so the assertion was never executed.
        let decoded: Vec<&str> = path.iter().map(|&idx| ner.labels[idx].as_str()).collect();
        assert!(
            is_valid_bio(&ner.labels, &path),
            "Invalid BIO sequence: {:?}",
            decoded
        );

        // The transition scores turn the per-position preference into a
        // proper begin/inside pair.
        assert_eq!(decoded, ["B-PER", "I-PER"]);
    }

    #[test]
    fn test_unicode_offsets() {
        let ner = BiLstmCrfNER::new();
        let text = "北京 Google Inc.";
        let char_count = text.chars().count();

        let entities = ner.extract_entities(text, None).unwrap();

        for entity in &entities {
            assert!(entity.start <= entity.end);
            assert!(entity.end <= char_count);
        }
    }

    #[test]
    fn test_config() {
        let config = BiLstmCrfConfig {
            hidden_size: 512,
            num_layers: 3,
            dropout: 0.3,
            use_char_embeddings: false,
            max_seq_len: 256,
        };

        // The config is moved into the model; no clone is needed.
        let ner = BiLstmCrfNER::with_config(config);
        assert_eq!(ner.config.hidden_size, 512);
        assert_eq!(ner.config.num_layers, 3);
        assert_eq!(ner.config.dropout, 0.3);
        assert!(!ner.config.use_char_embeddings);
        assert_eq!(ner.config.max_seq_len, 256);
    }

    #[test]
    fn test_transition_matrix_shape() {
        let ner = BiLstmCrfNER::new();
        let n = ner.labels.len();

        assert_eq!(ner.transitions.len(), n);
        for row in &ner.transitions {
            assert_eq!(row.len(), n);
        }
    }

    #[test]
    fn test_supported_types() {
        let ner = BiLstmCrfNER::new();
        let types = ner.supported_types();

        assert!(types.contains(&EntityType::Person));
        assert!(types.contains(&EntityType::Organization));
        assert!(types.contains(&EntityType::Location));
        assert!(types.contains(&EntityType::Other("MISC".to_string())));
    }

    /// Test that duplicate entity texts get correct offsets.
    ///
    /// Guards against locating entities with `text.find()`, which always
    /// returns the first occurrence and would give every repeated entity
    /// the offsets of the first one.
    #[test]
    fn test_duplicate_entity_offsets() {
        let ner = BiLstmCrfNER::new();

        // Text with "Google" appearing twice
        let text = "Google bought Google for $1 billion.";

        // Test token position calculation directly
        let tokens: Vec<&str> = text.split_whitespace().collect();
        let positions = BiLstmCrfNER::calculate_token_positions(text, &tokens);

        // "Google" appears at indices 0 and 2 in tokens
        assert_eq!(
            positions[0],
            (0, 6),
            "First 'Google' should be at bytes 0-6"
        );
        // Second "Google" at byte 14-20 (after "Google bought ")
        assert_eq!(
            positions[2],
            (14, 20),
            "Second 'Google' should be at bytes 14-20"
        );

        // Also test with the full extraction: whatever Google entities are
        // found, no two of them may share a start offset.
        let entities = ner.extract_entities(text, None).unwrap();
        let starts: Vec<_> = entities
            .iter()
            .filter(|e| e.text.contains("Google"))
            .map(|e| e.start)
            .collect();

        for (n, start) in starts.iter().enumerate() {
            assert!(
                !starts[..n].contains(start),
                "Duplicate entities should have different start positions"
            );
        }
    }

    /// Test token position calculation with Unicode.
    #[test]
    fn test_token_positions_unicode() {
        let text = "東京 Tokyo 東京 Osaka";
        let tokens: Vec<&str> = text.split_whitespace().collect();
        let positions = BiLstmCrfNER::calculate_token_positions(text, &tokens);

        // Each 東京 is 6 bytes (2 chars × 3 bytes each)
        assert_eq!(positions[0], (0, 6), "First '東京' at bytes 0-6");
        assert_eq!(positions[1], (7, 12), "Tokyo at bytes 7-12");
        assert_eq!(positions[2], (13, 19), "Second '東京' at bytes 13-19");
        assert_eq!(positions[3], (20, 25), "Osaka at bytes 20-25");
    }
}