anno/backends/crf.rs

//! Conditional Random Field (CRF) NER backend.
//!
//! Implements classical statistical NER using CRF sequence labeling.
//! This provides a lightweight, interpretable baseline that requires no
//! external ML runtime or GPU acceleration.
//!
//! # History
//!
//! CRF-based NER was a common baseline throughout the 2000s (pre-neural sequence labeling):
//! - Lafferty et al. 2001: Introduced CRFs for sequence labeling (ICML)
//! - McCallum & Li 2003: Applied CRFs to NER
//! - Stanford NER (2003-2014): CRF-based, still widely used
//! - State of the art until neural methods (BiLSTM-CRF, 2015+)
//!
//! # Why CRF Beat Previous Methods
//!
//! CRFs solved the **label bias problem** that plagued MEMMs (Maximum Entropy
//! Markov Models, McCallum et al. 2000):
//!
//! ```text
//! Label bias: an MEMM normalizes each state's outgoing transition distribution
//! locally, so a state with few successors must pass almost all of its probability
//! mass on to them whatever the observation says. In the extreme, a state with a
//! single successor ignores the input entirely.
//!
//! HMM:   Generative model     P(x,y) = P(y) × P(x|y)
//! MEMM:  Local discriminative P(y_t|y_{t-1}, x)  ← label bias here
//! CRF:   Global discriminative P(y|x) = (1/Z) exp(∑ features × weights)
//!                                       ↑ normalizes over entire sequence
//! ```
//!
//! A CRF models the conditional probability of the entire label sequence given
//! the observation sequence, using global normalization:
//!
//! ```text
//! P(y|x) = (1/Z(x)) × exp( ∑_t ∑_k λ_k × f_k(y_t, y_{t-1}, x, t) )
//!
//! where:
//!   - Z(x) = ∑_{y'} exp( ∑_t ∑_k λ_k × f_k(y'_t, y'_{t-1}, x, t) ) is the partition
//!     function, summing over every possible label sequence y'
//!   - f_k are feature functions
//!   - λ_k are learned weights
//! ```
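//!
//! As a rough illustration of global normalization, the sketch below scores every label
//! sequence of a toy input and normalizes over all of them. The `emit[t][j]` and
//! `trans[k][j]` tables are made-up stand-ins for the weighted feature sums above (emission
//! and transition parts), and the helper names are hypothetical. Enumeration is exponential
//! in sequence length; practical CRFs compute `Z(x)` with the forward algorithm instead.
//!
//! ```rust
//! /// Unnormalized log-score: emission scores plus transitions between consecutive labels.
//! fn sequence_score(emit: &[Vec<f64>], trans: &[Vec<f64>], labels: &[usize]) -> f64 {
//!     let mut s = 0.0;
//!     for (t, &y) in labels.iter().enumerate() {
//!         s += emit[t][y];
//!         if t > 0 {
//!             s += trans[labels[t - 1]][y];
//!         }
//!     }
//!     s
//! }
//!
//! /// Every label sequence of length `n` over `m` labels (m^n of them).
//! fn all_sequences(n: usize, m: usize) -> Vec<Vec<usize>> {
//!     let mut seqs: Vec<Vec<usize>> = vec![Vec::new()];
//!     for _ in 0..n {
//!         let mut next = Vec::with_capacity(seqs.len() * m);
//!         for prefix in &seqs {
//!             for y in 0..m {
//!                 let mut s = prefix.clone();
//!                 s.push(y);
//!                 next.push(s);
//!             }
//!         }
//!         seqs = next;
//!     }
//!     seqs
//! }
//!
//! /// P(y|x) = exp(score(y)) / Z(x), with Z(x) summed over all sequences.
//! fn sequence_probability(emit: &[Vec<f64>], trans: &[Vec<f64>], labels: &[usize]) -> f64 {
//!     let scores: Vec<f64> = all_sequences(emit.len(), trans.len())
//!         .iter()
//!         .map(|y| sequence_score(emit, trans, y))
//!         .collect();
//!     // log Z(x) via log-sum-exp for numerical stability.
//!     let max = scores.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
//!     let log_z = max + scores.iter().map(|s| (s - max).exp()).sum::<f64>().ln();
//!     (sequence_score(emit, trans, labels) - log_z).exp()
//! }
//! ```
//!
//! Summing `sequence_probability` over all sequences gives 1, which is exactly the guarantee
//! the partition function provides; with all transition scores at zero the probability
//! factorizes into independent per-position softmaxes.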
//!
//! # References
//!
//! - Lafferty, McCallum, Pereira (2001): "Conditional Random Fields:
//!   Probabilistic Models for Segmenting and Labeling Sequence Data" (ICML)
//! - McCallum & Li (2003): "Early Results for Named Entity Recognition with
//!   Conditional Random Fields, Feature Induction and Web-Enhanced Lexicons" (CoNLL)
//! - Finkel, Grenager, Manning (2005): "Incorporating Non-local Information
//!   into Information Extraction Systems by Gibbs Sampling" (ACL)
//!
//! # See Also
//!
//! - Historical NER baselines (HMM/CRF-era sequence models)
//!
//! # Feature Templates
//!
//! The CRF uses the following feature templates (matching `scripts/train_crf_weights.py`):
//!
//! ```text
//! - bias           : Always-on feature for label-specific bias
//! - word.lower     : Lowercased current word
//! - word.shape     : Compressed word shape pattern (e.g. Xx, X, x, 0)
//! - word.isdigit   : Whether word is all digits
//! - word.istitle   : Whether word is titlecased
//! - word.isupper   : Whether word is all uppercase
//! - prefix{2,3}    : First 2 or 3 characters (lowercased)
//! - suffix{2,3}    : Last 2 or 3 characters (lowercased)
//! - -1:word.*      : Previous word features
//! - +1:word.*      : Next word features
//! - BOS/EOS        : Sentence boundary markers
//! - gaz:{TYPE}     : Word matches a built-in gazetteer entry (kept for backwards compatibility)
//! ```
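//!
//! A simplified sketch of the feature strings these templates produce for one token. It
//! covers only a subset of the templates (`bias`, `word.lower`, `word.shape`, `suffix3`, and
//! the `±1` word context), and `shape` / `token_features` are hypothetical helpers written
//! for this example, not the crate's private extractor.
//!
//! ```rust
//! /// Compressed shape: uppercase -> X, lowercase -> x, ASCII digit -> 0, other chars kept,
//! /// then runs of the same symbol collapsed.
//! fn shape(word: &str) -> String {
//!     let mut out = String::new();
//!     for c in word.chars() {
//!         let s = if c.is_uppercase() {
//!             'X'
//!         } else if c.is_lowercase() {
//!             'x'
//!         } else if c.is_ascii_digit() {
//!             '0'
//!         } else {
//!             c
//!         };
//!         if !out.ends_with(s) {
//!             out.push(s);
//!         }
//!     }
//!     out
//! }
//!
//! /// Features for `tokens[i]`, using the same string formats as the template list.
//! fn token_features(tokens: &[&str], i: usize) -> Vec<String> {
//!     let w = tokens[i];
//!     let mut f = vec![
//!         "bias".to_string(),
//!         format!("word.lower={}", w.to_lowercase()),
//!         format!("word.shape={}", shape(w)),
//!     ];
//!     let chars: Vec<char> = w.chars().collect();
//!     if chars.len() >= 3 {
//!         let suffix: String = chars[chars.len() - 3..].iter().collect();
//!         f.push(format!("suffix3={}", suffix.to_lowercase()));
//!     }
//!     // Previous word, or BOS at the start of the sentence.
//!     match i.checked_sub(1) {
//!         Some(p) => f.push(format!("-1:word.lower={}", tokens[p].to_lowercase())),
//!         None => f.push("BOS".to_string()),
//!     }
//!     // Next word, or EOS at the end of the sentence.
//!     match tokens.get(i + 1) {
//!         Some(next) => f.push(format!("+1:word.lower={}", next.to_lowercase())),
//!         None => f.push("EOS".to_string()),
//!     }
//!     f
//! }
//! ```
//!
//! For `["Dr", "Smith", "works"]` at position 1 this yields `bias`, `word.lower=smith`,
//! `word.shape=Xx`, `suffix3=ith`, `-1:word.lower=dr`, and `+1:word.lower=works`. Note that
//! run compression maps both "IBM" and "I" to the shape `X`.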
//!
//! # Performance
//!
//! Performance depends on weights, tokenization, and dataset; use the eval harness
//! for quantitative results. Relative expectations:
//!
//! | Configuration | Relative accuracy | Notes                               |
//! |---------------|-------------------|-------------------------------------|
//! | Heuristic     | Lower             | Hand-tuned, always available        |
//! | Trained       | Higher            | From `scripts/train_crf_weights.py` |
//! | Neural        | Highest           | For comparison (GLiNER, BERT)       |
//!
//! # Usage
//!
//! ```rust
//! use anno::CrfNER;
//! use anno::Model;
//!
//! // Default weights: bundled trained weights when the `bundled-crf-weights`
//! // feature is enabled, otherwise the hand-tuned heuristics
//! let ner = CrfNER::new();
//! let entities = ner.extract_entities("John Smith works at Google", None)?;
//!
//! // Or load trained weights for better accuracy
//! // let ner = CrfNER::with_weights("crf_weights.json")?;
//! # Ok::<(), anno::Error>(())
//! ```
//!
//! # Training Weights
//!
//! To train weights on CoNLL-2003:
//!
//! ```bash
//! uv run scripts/train_crf_weights.py
//! ```
//!
//! This produces `crf_weights.json`, which can be loaded with `CrfNER::with_weights()`.
//!
//! Note: CoNLL-2003’s English text is derived from Reuters (RCV1) and is commonly treated as
//! redistribution-restricted. The CoNLL site notes that, “because of copyright reasons we only
//! make available the annotations”, and that separate access to the Reuters corpus is needed
//! to build the full dataset: `http://www.clips.uantwerpen.be/conll2003/ner/`.
//!
//! Practical consequence: `anno` includes a training script, but it does not ship a CoNLL-trained
//! `crf_weights.json` out of the box.
//!
//! # Advantages Over Neural Methods
//!
//! - **Interpretable**: Features and weights are human-readable
//! - **Fast training**: CPU-only training; typically faster to iterate than neural training loops
//! - **No ML runtime**: Pure Rust, no ONNX/Candle required
//! - **Deterministic**: Same input always produces same output
//! - **Small footprint**: Small weights file compared to neural model artifacts

use crate::{Entity, EntityType, Model, Result};
use std::collections::HashMap;
#[cfg(feature = "bundled-crf-weights")]
use std::sync::OnceLock;

/// CRF-based NER model.
///
/// Uses hand-crafted features and Viterbi sequence decoding for named entity recognition.
/// This is a pure-Rust implementation that doesn't require an external ML runtime.
pub struct CrfNER {
    /// Feature weights (trained, bundled, or hand-tuned heuristics). Keys are
    /// `{feature}:{label}` for emissions and `trans:{prev}->{label}` for transitions.
    weights: HashMap<String, f64>,
    /// Entity type gazetteer lists
    gazetteers: HashMap<EntityType, Vec<String>>,
    /// Label set (BIO tagging)
    labels: Vec<String>,
    /// Feature templates. Currently only `InGazetteer` entries affect `extract_features`;
    /// the other features are fixed to match the training script.
    templates: Vec<FeatureTemplate>,
}

/// Feature template for CRF
#[derive(Debug, Clone)]
pub enum FeatureTemplate {
    /// Current word
    Word,
    /// Word at offset
    WordAt(i32),
    /// Compressed word shape (e.g. Xx, X, x, 0)
    Shape,
    /// Shape at offset
    ShapeAt(i32),
    /// Prefix of length n
    Prefix(usize),
    /// Suffix of length n
    Suffix(usize),
    /// Is in gazetteer for entity type
    InGazetteer(EntityType),
    /// Previous label
    PrevLabel,
    /// Bigram: current + previous label
    LabelBigram,
    /// Word + Label combination
    WordLabel,
}

impl Default for CrfNER {
    fn default() -> Self {
        Self::new()
    }
}

impl CrfNER {
    /// Create a new CRF NER model with default features.
    ///
    /// Uses the bundled trained weights when the `bundled-crf-weights` feature is enabled and
    /// they parse successfully; otherwise falls back to the hand-tuned heuristic weights.
    #[must_use]
    pub fn new() -> Self {
        // Default gazetteers (common names, locations, organizations)
        let mut gazetteers = HashMap::new();

        gazetteers.insert(
            EntityType::Person,
            vec![
                // Common first names
                "John",
                "Mary",
                "James",
                "Robert",
                "Michael",
                "David",
                "William",
                "Richard",
                "Joseph",
                "Thomas",
                "Elizabeth",
                "Jennifer",
                "Linda",
                "Barbara",
                "Susan",
                "Jessica",
                "Sarah",
                "Karen",
                "Nancy",
                "Margaret",
                // Titles that precede names
                "Dr",
                "Mr",
                "Mrs",
                "Ms",
                "Prof",
                "President",
                "CEO",
                "Senator",
            ]
            .into_iter()
            .map(String::from)
            .collect(),
        );

        gazetteers.insert(
            EntityType::Location,
            vec![
                // Countries
                "USA",
                "UK",
                "France",
                "Germany",
                "China",
                "Japan",
                "India",
                "Brazil",
                "Canada",
                "Australia",
                "Russia",
                "Italy",
                "Spain",
                "Mexico",
                // US States
                "California",
                "Texas",
                "Florida",
                // Multi-word entry: lookup compares single tokens, so this matches
                // only if the tokenizer yields "New York" as one token.
                "New York",
                "Illinois",
                "Pennsylvania",
                // Major cities
                "London",
                "Paris",
                "Tokyo",
                "Beijing",
                "Moscow",
                "Berlin",
                "Rome",
                "Madrid",
                "Sydney",
                "Toronto",
                "Mumbai",
                "Shanghai",
                "Seoul",
            ]
            .into_iter()
            .map(String::from)
            .collect(),
        );

        gazetteers.insert(
            EntityType::Organization,
            vec![
                // Companies
                "Google",
                "Apple",
                "Microsoft",
                "Amazon",
                "Facebook",
                "Tesla",
                "IBM",
                "Intel",
                "Oracle",
                "Cisco",
                "Samsung",
                "Sony",
                "Toyota",
                "Honda",
                // Suffixes
                "Inc",
                "Corp",
                "LLC",
                "Ltd",
                "Company",
                "Corporation",
                "Group",
                // Organizations
                "UN",
                "NATO",
                "WHO",
                "FBI",
                "CIA",
                "NASA",
                "EU",
                "OPEC",
            ]
            .into_iter()
            .map(String::from)
            .collect(),
        );

        // Standard BIO labels
        let labels = vec![
            "O".to_string(),
            "B-PER".to_string(),
            "I-PER".to_string(),
            "B-ORG".to_string(),
            "I-ORG".to_string(),
            "B-LOC".to_string(),
            "I-LOC".to_string(),
            "B-MISC".to_string(),
            "I-MISC".to_string(),
        ];

        // Default feature templates
        let templates = vec![
            FeatureTemplate::Word,
            FeatureTemplate::WordAt(-1),
            FeatureTemplate::WordAt(1),
            FeatureTemplate::Shape,
            FeatureTemplate::ShapeAt(-1),
            FeatureTemplate::ShapeAt(1),
            FeatureTemplate::Prefix(2),
            FeatureTemplate::Prefix(3),
            FeatureTemplate::Suffix(2),
            FeatureTemplate::Suffix(3),
            FeatureTemplate::InGazetteer(EntityType::Person),
            FeatureTemplate::InGazetteer(EntityType::Location),
            FeatureTemplate::InGazetteer(EntityType::Organization),
            FeatureTemplate::PrevLabel,
        ];

        // Initialize with shipped trained weights when available; fall back to heuristics.
        let weights = Self::shipped_weights().unwrap_or_else(Self::default_weights);

        Self {
            weights,
            gazetteers,
            labels,
            templates,
        }
    }

    /// Create a CRF NER model using only the built-in heuristic weight table.
    ///
    /// This is useful for E2E evaluation comparisons (heuristic vs bundled-trained) and for
    /// builds that want deterministic behavior without any bundled assets.
    #[must_use]
    pub fn new_heuristic() -> Self {
        let mut m = Self::new();
        m.weights = Self::default_weights();
        m
    }

    fn shipped_weights() -> Option<HashMap<String, f64>> {
        #[cfg(feature = "bundled-crf-weights")]
        {
            static ONCE: OnceLock<Option<HashMap<String, f64>>> = OnceLock::new();
            return ONCE
                .get_or_init(|| {
                    // Keep this lightweight and robust:
                    // - parsing failure should not break the backend (fall back to heuristics)
                    let s = include_str!("crf_weights.json");
                    serde_json::from_str::<HashMap<String, f64>>(s).ok()
                })
                .clone();
        }
        #[cfg(not(feature = "bundled-crf-weights"))]
        {
            None
        }
    }

    /// Load weights from a JSON file.
    ///
    /// Keys are `{feature}:{label}` for emission weights and `trans:{prev}->{label}` for
    /// transition weights; feature strings must match those built by `extract_features`.
    ///
    /// # Example JSON format
    /// ```json
    /// {
    ///     "gaz:PER:B-PER": 2.5,
    ///     "word.shape=Xx:B-PER": 1.5,
    ///     "trans:B-PER->I-PER": 1.0
    /// }
    /// ```
    pub fn load_weights(path: &str) -> Result<HashMap<String, f64>> {
        let content = std::fs::read_to_string(path).map_err(|e| {
            crate::Error::invalid_input(format!("Failed to read weights file: {}", e))
        })?;
        let weights: HashMap<String, f64> = serde_json::from_str(&content).map_err(|e| {
            crate::Error::invalid_input(format!("Failed to parse weights JSON: {}", e))
        })?;
        Ok(weights)
    }

    /// Create a CRF model that uses the weights in the JSON file at `path` instead of the
    /// bundled or heuristic defaults.
    pub fn with_weights(path: &str) -> Result<Self> {
        let weights = Self::load_weights(path)?;
        let mut model = Self::new();
        model.weights = weights;
        Ok(model)
    }

    /// Create default heuristic weights for common features.
    ///
    /// These weights are hand-tuned heuristics, not learned from data.
    /// For better accuracy, train weights using `scripts/train_crf_weights.py`
    /// and load them with `CrfNER::with_weights("crf_weights.json")`.
    fn default_weights() -> HashMap<String, f64> {
        let mut w = HashMap::new();

        // Strong bias toward O (outside) by default - entities are rare
        w.insert("bias:O".to_string(), 3.0);

        // Extra strong bias toward O for lowercase words
        // (the matching `word.shape=x:B-*` penalties are set further below)
        w.insert("word.shape=x:O".to_string(), 2.5);
        w.insert("word.shape=x:I-PER".to_string(), -2.0);

        // Gazetteer features are very strong signals for B- tags
        w.insert("gaz:PER:B-PER".to_string(), 4.0);
        w.insert("gaz:LOC:B-LOC".to_string(), 4.0);
        w.insert("gaz:ORG:B-ORG".to_string(), 4.0);

        // Capitalization patterns - mostly for B- tags
        w.insert("word.shape=Xx:B-PER".to_string(), 2.0);
        w.insert("word.shape=Xx:B-LOC".to_string(), 1.5);
        w.insert("word.shape=Xx:B-ORG".to_string(), 1.5);
        w.insert("word.shape=Xx:I-PER".to_string(), 1.0); // Continue if already in entity
        w.insert("word.shape=Xx:I-ORG".to_string(), 1.0);
        // All-caps acronyms like IBM, NASA. Shapes are run-compressed ("IBM" -> "X"),
        // so a "word.shape=XX" key would never match.
        w.insert("word.shape=X:B-ORG".to_string(), 2.5);

        // Lowercase words are unlikely to be entity starts
        w.insert("word.shape=x:B-PER".to_string(), -2.0);
        w.insert("word.shape=x:B-ORG".to_string(), -2.0);
        w.insert("word.shape=x:B-LOC".to_string(), -2.0);

        // Common words that are NOT entities - strongly bias toward O
        for word in [
            "the",
            "a",
            "an",
            "of",
            "in",
            "at",
            "to",
            "and",
            "or",
            "is",
            "was",
            "were",
            "be",
            "been",
            "being",
            "have",
            "has",
            "had",
            "do",
            "does",
            "did",
            "will",
            "would",
            "could",
            "should",
            "may",
            "might",
            "must",
            "can",
            "won",
            "works",
            "worked",
            "working",
            "serves",
            "served",
            "announced",
            "said",
            "made",
            "that",
            "this",
            "which",
            "for",
            "with",
            "as",
            "by",
            "on",
            "from",
            "into",
            "through",
            "during",
            "before",
            "after",
            "above",
            "below",
            "between",
            "under",
            "again",
            "further",
            "then",
            "once",
            "here",
            "there",
            "when",
            "where",
            "why",
            "how",
            "all",
            "each",
            "few",
            "more",
            "most",
            "other",
            "some",
            "such",
            "no",
            "not",
            "only",
            "own",
            "same",
            "so",
            "than",
            "too",
            "very",
        ] {
            w.insert(format!("word.lower={}:O", word), 5.0);
            w.insert(format!("word.lower={}:B-PER", word), -5.0);
            w.insert(format!("word.lower={}:B-ORG", word), -5.0);
            w.insert(format!("word.lower={}:B-LOC", word), -5.0);
            w.insert(format!("word.lower={}:I-PER", word), -4.0);
            w.insert(format!("word.lower={}:I-ORG", word), -4.0);
            w.insert(format!("word.lower={}:I-LOC", word), -4.0);
        }

        // Common suffixes for organizations.
        // Note: we intentionally do not add a suffix4 feature here; keep default weights aligned
        // with `scripts/train_crf_weights.py` (prefix/suffix lengths 2-3).
        w.insert("suffix3=inc:B-ORG".to_string(), 3.0);
        w.insert("suffix3=ltd:B-ORG".to_string(), 3.0);
        w.insert("suffix3=llc:B-ORG".to_string(), 3.0);

        // Context words that suggest entities
        w.insert("-1:word.lower=dr:B-PER".to_string(), 2.5);
        w.insert("-1:word.lower=mr:B-PER".to_string(), 2.5);
        w.insert("-1:word.lower=mrs:B-PER".to_string(), 2.5);
        w.insert("-1:word.lower=ms:B-PER".to_string(), 2.5);
        w.insert("-1:word.lower=prof:B-PER".to_string(), 2.5);
        w.insert("-1:word.lower=president:B-PER".to_string(), 2.0);
        w.insert("-1:word.lower=ceo:B-PER".to_string(), 2.0);

        // Location context
        w.insert("-1:word.lower=in:B-LOC".to_string(), 1.5);
        w.insert("-1:word.lower=at:B-LOC".to_string(), 1.5);
        w.insert("-1:word.lower=from:B-LOC".to_string(), 1.5);
        w.insert("-1:word.lower=of:B-LOC".to_string(), 1.0);
        w.insert("-1:word.lower=of:B-ORG".to_string(), 1.0);

        // Transition features (BIO constraints) - very important
        // Valid transitions
        w.insert("trans:B-PER->I-PER".to_string(), 3.0);
        w.insert("trans:B-ORG->I-ORG".to_string(), 3.0);
        w.insert("trans:B-LOC->I-LOC".to_string(), 3.0);
        w.insert("trans:I-PER->I-PER".to_string(), 2.0);
        w.insert("trans:I-ORG->I-ORG".to_string(), 2.0);
        w.insert("trans:I-LOC->I-LOC".to_string(), 2.0);

        // End entity transitions
        w.insert("trans:B-PER->O".to_string(), 0.0);
        w.insert("trans:B-ORG->O".to_string(), 0.0);
        w.insert("trans:B-LOC->O".to_string(), 0.0);
        w.insert("trans:I-PER->O".to_string(), 0.0);
        w.insert("trans:I-ORG->O".to_string(), 0.0);
        w.insert("trans:I-LOC->O".to_string(), 0.0);

        // Invalid transitions (strongly penalize)
        w.insert("trans:O->I-PER".to_string(), -10.0);
        w.insert("trans:O->I-ORG".to_string(), -10.0);
        w.insert("trans:O->I-LOC".to_string(), -10.0);
        w.insert("trans:O->I-MISC".to_string(), -10.0);

        // Cross-type B- -> I- transitions are invalid
        // (cross-type I- -> I- pairs are not penalized here)
        w.insert("trans:B-PER->I-ORG".to_string(), -10.0);
        w.insert("trans:B-PER->I-LOC".to_string(), -10.0);
        w.insert("trans:B-ORG->I-PER".to_string(), -10.0);
        w.insert("trans:B-ORG->I-LOC".to_string(), -10.0);
        w.insert("trans:B-LOC->I-PER".to_string(), -10.0);
        w.insert("trans:B-LOC->I-ORG".to_string(), -10.0);

        w
    }
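
    /// Approximate Python `str.isdigit()` / `c.isdigit()` as used in
    /// `scripts/train_crf_weights.py`.
    ///
    /// Note: this accepts only ASCII `0`-`9` (`char::is_digit(10)` is equivalent to
    /// `is_ascii_digit`). Python's `isdigit()` also accepts other Unicode digits, so tokens
    /// containing them can produce different `word.isdigit` / `word.shape` features than
    /// the training script.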
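    ///
    /// A quick standalone check of the difference, using a hypothetical `digit_checks`
    /// helper: `is_digit(10)` and `is_ascii_digit` agree, while the broader `is_numeric`
    /// also accepts non-ASCII digits such as the Arabic-Indic `٣` and the superscript `²`.
    ///
    /// ```rust
    /// /// Returns `(is_digit(10), is_ascii_digit(), is_numeric())` for one character.
    /// fn digit_checks(c: char) -> (bool, bool, bool) {
    ///     (c.is_digit(10), c.is_ascii_digit(), c.is_numeric())
    /// }
    /// ```
    ///
    /// For `'7'` all three are `true`; for `'٣'` and `'²'` only `is_numeric` is `true`.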
    fn is_digit_py(c: char) -> bool {
        c.is_ascii_digit()
    }

    /// Compute a compressed word shape: uppercase -> `X`, lowercase -> `x`, digit -> `0`,
    /// other characters kept, then runs of the same symbol collapsed
    /// (e.g. "John" -> "Xx", "USA" -> "X", "2024" -> "0", "U.S." -> "X.X.").
    fn word_shape(word: &str) -> String {
        let mut shape = String::with_capacity(word.len());
        for c in word.chars() {
            let s = if c.is_uppercase() {
                'X'
            } else if c.is_lowercase() {
                'x'
            } else if Self::is_digit_py(c) {
                '0'
            } else {
                c
            };
            // Compress repeated symbols
            if !shape.ends_with(s) {
                shape.push(s);
            }
        }
        shape
    }

    /// Extract features for a token at given position.
    ///
    /// Feature strings follow the format of `scripts/train_crf_weights.py` so that trained
    /// weights can be looked up. They must match exactly between training and inference:
    /// a feature string that differs in any way never finds its trained weight.
    ///
    /// `_prev_label` is currently unused; features depend only on the tokens.
    ///
    /// # Feature Types
    ///
    /// - `bias` - Always present, allows label-specific bias
    /// - `word.lower={word}` - Lowercased word identity
    /// - `word.shape={shape}` - Compressed word shape (e.g. Xx, X, x, 0)
    /// - `word.isdigit={bool}` - Whether all digits
    /// - `word.istitle={bool}` - Whether titlecase
    /// - `word.isupper={bool}` - Whether all uppercase
    /// - `prefix{n}={chars}` - First n characters (n = 2, 3; lowercased)
    /// - `suffix{n}={chars}` - Last n characters (n = 2, 3; lowercased)
    /// - `-1:word.{lower,istitle,isupper,shape}` - Previous word features
    /// - `+1:word.{lower,istitle,isupper,shape}` - Next word features
    /// - `BOS` / `EOS` - Beginning/end of sentence markers
    /// - `gaz:{TYPE}` - Word matches a built-in gazetteer entry (ASCII case-insensitive)
    ///
    /// `{bool}` is rendered Python-style as `True` / `False`.
    fn extract_features(&self, tokens: &[&str], pos: usize, _prev_label: &str) -> Vec<String> {
        let mut features = Vec::with_capacity(20);
        let word = tokens[pos];

        fn bool_py(v: bool) -> &'static str {
            if v {
                "True"
            } else {
                "False"
            }
        }

        // Bias feature (always present)
        features.push("bias".to_string());

        // Word identity features
        features.push(format!("word.lower={}", word.to_lowercase()));
        features.push(format!("word.shape={}", Self::word_shape(word)));
        features.push(format!(
            "word.isdigit={}",
            // Approximates Python `str.isdigit()` as used in `scripts/train_crf_weights.py`:
            // - empty string -> False
            // - only ASCII digits count (see `is_digit_py`)
            bool_py(!word.is_empty() && word.chars().all(Self::is_digit_py))
        ));
        features.push(format!(
            "word.istitle={}",
            bool_py(
                word.chars().next().is_some_and(|c| c.is_uppercase())
                    && word.chars().skip(1).all(|c| c.is_lowercase())
            )
        ));
        features.push(format!(
            "word.isupper={}",
            bool_py(word.chars().all(|c| c.is_uppercase()))
        ));

        // Prefix/suffix features
        let chars: Vec<char> = word.chars().collect();
        if chars.len() >= 2 {
            let prefix2: String = chars[..2].iter().collect();
            let suffix2: String = chars[chars.len() - 2..].iter().collect();
            features.push(format!("prefix2={}", prefix2.to_lowercase()));
            features.push(format!("suffix2={}", suffix2.to_lowercase()));
        }
        if chars.len() >= 3 {
            let prefix3: String = chars[..3].iter().collect();
            let suffix3: String = chars[chars.len() - 3..].iter().collect();
            features.push(format!("prefix3={}", prefix3.to_lowercase()));
            features.push(format!("suffix3={}", suffix3.to_lowercase()));
        }

        // Context features (previous word)
        if pos > 0 {
            let prev_word = tokens[pos - 1];
            features.push(format!("-1:word.lower={}", prev_word.to_lowercase()));
            features.push(format!(
                "-1:word.istitle={}",
                bool_py(
                    prev_word.chars().next().is_some_and(|c| c.is_uppercase())
                        && prev_word.chars().skip(1).all(|c| c.is_lowercase())
                )
            ));
            features.push(format!(
                "-1:word.isupper={}",
                bool_py(prev_word.chars().all(|c| c.is_uppercase()))
            ));
            features.push(format!("-1:word.shape={}", Self::word_shape(prev_word)));
        } else {
            features.push("BOS".to_string());
        }

        // Context features (next word)
        if pos + 1 < tokens.len() {
            let next_word = tokens[pos + 1];
            features.push(format!("+1:word.lower={}", next_word.to_lowercase()));
            features.push(format!(
                "+1:word.istitle={}",
                bool_py(
                    next_word.chars().next().is_some_and(|c| c.is_uppercase())
                        && next_word.chars().skip(1).all(|c| c.is_lowercase())
                )
            ));
            features.push(format!(
                "+1:word.isupper={}",
                bool_py(next_word.chars().all(|c| c.is_uppercase()))
            ));
            features.push(format!("+1:word.shape={}", Self::word_shape(next_word)));
        } else {
            features.push("EOS".to_string());
        }

        // Gazetteer features (kept for backwards compatibility)
        for template in &self.templates {
            if let FeatureTemplate::InGazetteer(entity_type) = template {
                if let Some(gaz) = self.gazetteers.get(entity_type) {
                    if gaz.iter().any(|g| g.eq_ignore_ascii_case(word)) {
                        features.push(format!("gaz:{}", entity_type.as_label()));
                    }
                }
            }
        }

        features
    }

    /// Emission score of `label` for one token's feature strings, using the loaded weights.
    /// Transition scores are added separately during decoding.
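    ///
    /// A standalone sketch of the same keying rules, with a hypothetical `emission_score`
    /// helper: each feature `f` adds the weight stored under `"{f}:{label}"`, a bare `"{f}"`
    /// key adds half its weight, and `O` always gets an extra `+0.5`.
    ///
    /// ```rust
    /// use std::collections::HashMap;
    ///
    /// fn emission_score(weights: &HashMap<String, f64>, features: &[&str], label: &str) -> f64 {
    ///     let mut score = 0.0;
    ///     for feat in features {
    ///         // Label-specific weight, e.g. "word.shape=Xx:B-PER".
    ///         if let Some(&w) = weights.get(&format!("{feat}:{label}")) {
    ///             score += w;
    ///         }
    ///         // Label-independent weight for the bare feature, at half strength.
    ///         if let Some(&w) = weights.get(*feat) {
    ///             score += 0.5 * w;
    ///         }
    ///     }
    ///     // Fixed bias toward O, applied regardless of which features matched.
    ///     if label == "O" {
    ///         score += 0.5;
    ///     }
    ///     score
    /// }
    /// ```
    ///
    /// With weights `bias:O = 3.0`, `word.shape=Xx:B-PER = 2.0`, `word.shape=Xx = 1.0`, and
    /// `gaz:PER:B-PER = 4.0`, the features `bias`, `word.shape=Xx`, `gaz:PER` score `6.5`
    /// for `B-PER` and `4.0` for `O`.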
    fn score_label(&self, features: &[String], label: &str) -> f64 {
        let mut score = 0.0;
        let debug = std::env::var("CRF_DEBUG").is_ok();

        if debug && label == "I-PER" {
            eprintln!("  Features for I-PER: {:?}", features);
        }

        for feat in features {
            let key = format!("{}:{}", feat, label);
            if let Some(&w) = self.weights.get(&key) {
                if debug && w.abs() > 0.1 {
                    eprintln!("  CRF: {} -> {:.2}", key, w);
                }
                score += w;
            }
            // Label-independent weight for the bare feature, counted at half strength
            if let Some(&w) = self.weights.get(feat) {
                score += w * 0.5;
            }
        }
        // Small fixed bias toward O, added regardless of which features matched
        if label == "O" {
            score += 0.5;
        }
        score
    }

    /// Viterbi decoding to find the highest-scoring label sequence.
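    ///
    /// The same dynamic program in standalone form, as a sketch over plain score tables: the
    /// hypothetical `emit[t][j]` stands in for the emission score of label `j` at position
    /// `t`, and `trans[k][j]` for the `trans:{k}->{j}` weight. For each position and label it
    /// keeps the best score of any path ending there plus a backpointer, then follows the
    /// backpointers back from the best final label.
    ///
    /// ```rust
    /// fn viterbi(emit: &[Vec<f64>], trans: &[Vec<f64>]) -> Vec<usize> {
    ///     let n = emit.len();
    ///     if n == 0 {
    ///         return vec![];
    ///     }
    ///     let m = emit[0].len();
    ///     let mut score = vec![vec![f64::NEG_INFINITY; m]; n];
    ///     let mut back = vec![vec![0usize; m]; n];
    ///     // Position 0 has no incoming transition.
    ///     score[0].copy_from_slice(&emit[0]);
    ///     for t in 1..n {
    ///         for j in 0..m {
    ///             for k in 0..m {
    ///                 let s = score[t - 1][k] + trans[k][j] + emit[t][j];
    ///                 if s > score[t][j] {
    ///                     score[t][j] = s;
    ///                     back[t][j] = k;
    ///                 }
    ///             }
    ///         }
    ///     }
    ///     // Best final label, then follow backpointers to the start.
    ///     let mut best = 0;
    ///     for j in 1..m {
    ///         if score[n - 1][j] > score[n - 1][best] {
    ///             best = j;
    ///         }
    ///     }
    ///     let mut path = vec![best; n];
    ///     for t in (1..n).rev() {
    ///         path[t - 1] = back[t][path[t]];
    ///     }
    ///     path
    /// }
    /// ```
    ///
    /// With labels `[O, B, I]`, `trans[O][I] = -10.0`, all other transitions `0.0`, and
    /// emissions `[[1.0, 0.5, 0.0], [0.0, 0.0, 1.5]]`, the decoded path is `[B, I]`
    /// (score 2.0): starting with the locally better `O` would require the penalized
    /// `O -> I` transition.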
    fn viterbi_decode(&self, tokens: &[&str]) -> Vec<String> {
        if tokens.is_empty() {
            return vec![];
        }

        let n = tokens.len();
        let m = self.labels.len();

        // Dynamic programming tables
        let mut scores = vec![vec![f64::NEG_INFINITY; m]; n];
        let mut backpointers = vec![vec![0usize; m]; n];

        // Initialize first position
        let features = self.extract_features(tokens, 0, "O");
        for (j, label) in self.labels.iter().enumerate() {
            scores[0][j] = self.score_label(&features, label);
        }

        // Forward pass. Features do not depend on the previous label, so they are
        // extracted once per position and each emission score is computed once per label.
        for i in 1..n {
            let features = self.extract_features(tokens, i, "O");
            for (j, label) in self.labels.iter().enumerate() {
                let emission = self.score_label(&features, label);
                let mut best_score = f64::NEG_INFINITY;
                let mut best_prev = 0;

                for (k, prev_label) in self.labels.iter().enumerate() {
                    let trans_key = format!("trans:{}->{}", prev_label, label);
                    let trans_score = self.weights.get(&trans_key).copied().unwrap_or(0.0);
                    let score = scores[i - 1][k] + emission + trans_score;

                    if score > best_score {
                        best_score = score;
                        best_prev = k;
                    }
                }

                scores[i][j] = best_score;
                backpointers[i][j] = best_prev;
            }
        }

        // Backward pass to recover best path
        let mut path = vec![0usize; n];
        let mut best_final = 0;
        let mut best_score = f64::NEG_INFINITY;
        for (j, &score) in scores[n - 1].iter().enumerate() {
            if score > best_score {
                best_score = score;
                best_final = j;
            }
        }
        path[n - 1] = best_final;

        for i in (0..n - 1).rev() {
            path[i] = backpointers[i + 1][path[i + 1]];
        }

        path.iter().map(|&j| self.labels[j].clone()).collect()
    }

    /// Convert BIO labels to entities.
    ///
    /// Entity offsets are CHARACTER offsets, not byte offsets; `SpanConverter` performs the
    /// byte-to-char conversion.
    ///
    /// Each token's byte span is computed up front, so repeated entity texts get the offsets
    /// of their own occurrence. (A `text.find()` lookup, as in an earlier implementation,
    /// always returned the first occurrence and produced wrong offsets for duplicates.)
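    ///
    /// A minimal sketch of the BIO grouping step on its own, returning token-index spans
    /// instead of `Entity` values; the `(type, first, last)` tuple and the `bio_spans` name
    /// are hypothetical simplifications. Like the implementation below, a `B-` tag closes any
    /// open entity, an `I-` tag extends whatever entity is open (without checking its type),
    /// and an `I-` tag with nothing open is dropped.
    ///
    /// ```rust
    /// fn bio_spans(tags: &[&str]) -> Vec<(String, usize, usize)> {
    ///     let mut spans = Vec::new();
    ///     let mut open: Option<(String, usize, usize)> = None;
    ///     for (i, tag) in tags.iter().enumerate() {
    ///         if let Some(ty) = tag.strip_prefix("B-") {
    ///             // Close any open entity and start a new one.
    ///             spans.extend(open.take());
    ///             open = Some((ty.to_string(), i, i));
    ///         } else if tag.starts_with("I-") {
    ///             // Extend the open entity; with nothing open the tag is dropped.
    ///             if let Some((_, _, end)) = open.as_mut() {
    ///                 *end = i;
    ///             }
    ///         } else {
    ///             // O tag: close any open entity.
    ///             spans.extend(open.take());
    ///         }
    ///     }
    ///     spans.extend(open);
    ///     spans
    /// }
    /// ```
    ///
    /// For `B-PER I-PER O B-LOC I-ORG O I-PER B-ORG` this gives `(PER, 0, 1)`, `(LOC, 3, 4)`,
    /// and `(ORG, 7, 7)`; the stray `I-PER` at index 6 produces nothing.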
    fn labels_to_entities(&self, text: &str, tokens: &[&str], labels: &[String]) -> Vec<Entity> {
        use crate::offset::SpanConverter;

        let mut entities = Vec::new();

        // Build converter once for all byte-to-char conversions
        let converter = SpanConverter::new(text);

        // Byte offsets of each token, computed up front
        let token_positions: Vec<(usize, usize)> = Self::calculate_token_positions(text, tokens);

        let mut current_entity: Option<(usize, usize, EntityType, Vec<&str>)> = None;

        for (i, (token, label)) in tokens.iter().zip(labels.iter()).enumerate() {
            if label.starts_with("B-") {
                // Save previous entity if any
                if let Some((start_idx, end_idx, entity_type, words)) = current_entity.take() {
                    Self::push_entity_from_positions(
                        &converter,
                        &token_positions,
                        start_idx,
                        end_idx,
                        &words,
                        entity_type,
                        &mut entities,
                    );
                }

                // Start new entity
                let entity_type = match label.as_str() {
                    "B-PER" => EntityType::Person,
                    "B-ORG" => EntityType::Organization,
                    "B-LOC" => EntityType::Location,
                    _ => EntityType::Other("MISC".to_string()),
                };
                current_entity = Some((i, i, entity_type, vec![token]));
            } else if label.starts_with("I-") {
                // Extend the open entity (its type is not checked);
                // an I- tag with no open entity is dropped
                if let Some((_, ref mut end_idx, _, ref mut words)) = current_entity {
                    words.push(token);
                    *end_idx = i;
                }
            } else {
                // O label - save and reset
                if let Some((start_idx, end_idx, entity_type, words)) = current_entity.take() {
                    Self::push_entity_from_positions(
                        &converter,
                        &token_positions,
                        start_idx,
                        end_idx,
                        &words,
                        entity_type,
                        &mut entities,
                    );
                }
            }
        }

        // Don't forget last entity
908        if let Some((start_idx, end_idx, entity_type, words)) = current_entity.take() {
909            Self::push_entity_from_positions(
910                &converter,
911                &token_positions,
912                start_idx,
913                end_idx,
914                &words,
915                entity_type,
916                &mut entities,
917            );
918        }
919
920        entities
921    }
922
    /// Compute the `(start, end)` byte span of each token in `text`.
    ///
    /// Each token is searched for starting just past the previous match, so
    /// repeated tokens map to distinct positions.
    fn calculate_token_positions(text: &str, tokens: &[&str]) -> Vec<(usize, usize)> {
        let mut positions = Vec::with_capacity(tokens.len());
        let mut byte_pos = 0;

        for token in tokens {
            // Find token starting from current position
            if let Some(rel_pos) = text[byte_pos..].find(token) {
                let start = byte_pos + rel_pos;
                let end = start + token.len();
                positions.push((start, end));
                byte_pos = end; // Move past this token
            } else {
                // Fallback: empty span at the current position (shouldn't happen
                // with whitespace tokenization)
                positions.push((byte_pos, byte_pos));
            }
        }

        positions
    }

    /// Push an entity spanning tokens `start_token_idx..=end_token_idx`, converting
    /// the byte span recorded in `positions` to character offsets.
    ///
    /// Out-of-range token indices are ignored. The entity text is `words` joined
    /// with single spaces, so it can differ from the source slice when tokens are
    /// separated by other whitespace (tabs, newlines, repeated spaces).
    fn push_entity_from_positions(
        converter: &crate::offset::SpanConverter,
        positions: &[(usize, usize)],
        start_token_idx: usize,
        end_token_idx: usize,
        words: &[&str],
        entity_type: EntityType,
        entities: &mut Vec<Entity>,
    ) {
        if start_token_idx >= positions.len() || end_token_idx >= positions.len() {
            return;
        }

        let byte_start = positions[start_token_idx].0;
        let byte_end = positions[end_token_idx].1;
        let char_start = converter.byte_to_char(byte_start);
        let char_end = converter.byte_to_char(byte_end);
        let entity_text = words.join(" ");

        entities.push(Entity::new(
            &entity_text,
            entity_type,
            char_start,
            char_end,
            0.7, // Fixed confidence: CRF scores are hard to calibrate
        ));
    }

    /// Simple whitespace tokenizer.
    fn tokenize(text: &str) -> Vec<&str> {
        text.split_whitespace().collect()
    }
}

impl Model for CrfNER {
    fn extract_entities(&self, text: &str, _language: Option<&str>) -> Result<Vec<Entity>> {
        if text.trim().is_empty() {
            return Ok(vec![]);
        }

        let tokens = Self::tokenize(text);
        if tokens.is_empty() {
            return Ok(vec![]);
        }

        let labels = self.viterbi_decode(&tokens);
        let entities = self.labels_to_entities(text, &tokens, &labels);

        Ok(entities)
    }

    fn supported_types(&self) -> Vec<EntityType> {
        vec![
            EntityType::Person,
            EntityType::Organization,
            EntityType::Location,
            EntityType::Other("MISC".to_string()),
        ]
    }

    fn is_available(&self) -> bool {
        true // Always available (no external dependencies)
    }

    fn name(&self) -> &'static str {
        "crf"
    }

    fn description(&self) -> &'static str {
        "CRF-based NER (classical statistical method)"
    }
}

impl crate::NamedEntityCapable for CrfNER {}

impl crate::BatchCapable for CrfNER {
    fn optimal_batch_size(&self) -> Option<usize> {
        Some(32) // CRF is fast, can handle batches
    }
}

impl crate::StreamingCapable for CrfNER {
    fn recommended_chunk_size(&self) -> usize {
        4096 // Smaller chunks since CRF is token-based
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_crf_basic() {
        let ner = CrfNER::new();
        let entities = ner
            .extract_entities("John Smith works at Google in California", None)
            .unwrap();

        // With our default heuristic weights + gazetteers, we should usually get some entities.
        // (Trained weights will do better, but defaults should not be totally dead.)
        assert!(!entities.is_empty(), "Expected some entities, got none");
    }

    #[test]
    fn test_word_shape() {
        assert_eq!(CrfNER::word_shape("John"), "Xx");
        assert_eq!(CrfNER::word_shape("USA"), "X");
        assert_eq!(CrfNER::word_shape("hello"), "x");
        assert_eq!(CrfNER::word_shape("123"), "0");
        assert_eq!(CrfNER::word_shape("Hello123"), "Xx0");
    }

    #[test]
    fn test_tokenize() {
        let tokens = CrfNER::tokenize("Hello world");
        assert_eq!(tokens, vec!["Hello", "world"]);
    }

    #[test]
    fn test_empty_input() {
        let ner = CrfNER::new();
        let entities = ner.extract_entities("", None).unwrap();
        assert!(entities.is_empty());
    }

    #[test]
    fn test_gazetteer_lookup() {
        let ner = CrfNER::new();

        // Gazetteer should contain common entities
        assert!(ner.gazetteers[&EntityType::Person].contains(&"John".to_string()));
        assert!(ner.gazetteers[&EntityType::Location].contains(&"California".to_string()));
        assert!(ner.gazetteers[&EntityType::Organization].contains(&"Google".to_string()));
    }

    #[test]
    fn test_viterbi_returns_valid_labels() {
        let ner = CrfNER::new();
        let tokens = vec!["John", "works", "at", "Google"];
        let labels = ner.viterbi_decode(&tokens);

        assert_eq!(labels.len(), tokens.len());
        for label in &labels {
            assert!(ner.labels.contains(label));
        }
    }

    #[test]
    fn test_common_verbs_not_in_entities() {
        let ner = CrfNER::new();

        // Test that common verbs don't get tagged as part of entities
        let entities = ner
            .extract_entities("John Smith works at Apple", None)
            .unwrap();

        // Ideally finds "John Smith" and "Apple"; this test only checks that no
        // entity includes "works".
        let entity_texts: Vec<&str> = entities.iter().map(|e| e.text.as_str()).collect();
        for entity_text in &entity_texts {
            assert!(
                !entity_text.contains("works"),
                "Entity '{}' should not contain 'works'",
                entity_text
            );
        }
    }

    // This test asserts properties of the heuristic weight table, which includes
    // many `word.lower=...` entries. The bundled trained weights are intentionally
    // compact and omit token-identity features to keep the shipped file size small,
    // so the test is only compiled when the `bundled-crf-weights` feature is off.
    #[test]
    #[cfg(not(feature = "bundled-crf-weights"))]
    fn test_weights_for_common_words() {
        let ner = CrfNER::new();

        // Check that weights exist for a common non-entity word ("works")
        assert!(
            ner.weights.contains_key("word.lower=works:O"),
            "Missing weight for word.lower=works:O"
        );
        assert!(
            ner.weights.contains_key("word.lower=works:I-PER"),
            "Missing weight for word.lower=works:I-PER"
        );

        // Check that the O weight is positive and the I-PER weight is negative
        let o_weight = *ner.weights.get("word.lower=works:O").unwrap();
        let i_per_weight = *ner.weights.get("word.lower=works:I-PER").unwrap();
        assert!(
            o_weight > 0.0,
            "O weight should be positive, got {}",
            o_weight
        );
        assert!(
            i_per_weight < 0.0,
            "I-PER weight should be negative, got {}",
            i_per_weight
        );
    }

    #[test]
    fn test_unicode_char_offsets() {
        // Test that entity offsets are character-based, not byte-based
        let ner = CrfNER::new();

        // "北京" is 2 chars, 6 bytes. "Beijing" is 7 chars, 7 bytes.
        // Text "北京 Beijing" is 10 chars, 14 bytes.
        let text = "北京 Beijing";
        assert_eq!(text.len(), 14, "Expected 14 bytes");
        assert_eq!(text.chars().count(), 10, "Expected 10 characters");

        let entities = ner.extract_entities(text, None).unwrap();

        // Regardless of what entities are found, check all offsets are valid char offsets
        let char_count = text.chars().count();
        for entity in &entities {
            assert!(
                entity.start <= entity.end,
                "Invalid span: start {} > end {}",
                entity.start,
                entity.end
            );
            assert!(
                entity.end <= char_count,
                "Entity end {} exceeds char count {} for text {:?}",
                entity.end,
                char_count,
                text
            );

            // Also verify we can extract the text at those offsets
            let extracted: String = text
                .chars()
                .skip(entity.start)
                .take(entity.end - entity.start)
                .collect();
            assert!(
                !extracted.is_empty() || entity.start == entity.end,
                "Empty extraction for entity at {}..{} in {:?}",
                entity.start,
                entity.end,
                text
            );
        }
    }

    #[test]
    fn test_multilingual_inputs_no_panic_and_valid_spans() {
        let ner = CrfNER::new();
        let texts = [
            // Latin
            "Marie Curie discovered radium in Paris.",
            // CJK
            "習近平在北京會見了普京。",
            // Arabic (RTL)
            "التقى محمد بن سلمان بالرئيس في الرياض",
            // Cyrillic
            "Путин встретился с Си Цзиньпином в Москве.",
            // Devanagari
            "प्रधान मंत्री शर्मा दिल्ली में मिले।",
        ];

        for text in texts {
            let entities = ner.extract_entities(text, None).unwrap();
            let char_count = text.chars().count();
            for e in entities {
                assert!(e.start <= e.end);
                assert!(e.end <= char_count);
                let _span: String = text.chars().skip(e.start).take(e.end - e.start).collect();
            }
        }
    }

    /// Test that duplicate entity texts get correct offsets.
    #[test]
    fn test_duplicate_entity_offsets() {
        // Test token position calculation directly
        let text = "Google bought Google for $1 billion.";
        let tokens: Vec<&str> = text.split_whitespace().collect();
        let positions = CrfNER::calculate_token_positions(text, &tokens);

        // First "Google" at byte 0-6
        assert_eq!(
            positions[0],
            (0, 6),
            "First 'Google' should be at bytes 0-6"
        );
        // Second "Google" at byte 14-20
        assert_eq!(
            positions[2],
            (14, 20),
            "Second 'Google' should be at bytes 14-20"
        );
    }

    /// Test token position calculation with Unicode.
    #[test]
    fn test_token_positions_unicode() {
        let text = "東京 Tokyo 東京";
        let tokens: Vec<&str> = text.split_whitespace().collect();
        let positions = CrfNER::calculate_token_positions(text, &tokens);

        // Each 東京 is 6 bytes (2 chars × 3 bytes each)
        assert_eq!(positions[0], (0, 6), "First '東京' at bytes 0-6");
        assert_eq!(positions[1], (7, 12), "Tokyo at bytes 7-12");
        assert_eq!(positions[2], (13, 19), "Second '東京' at bytes 13-19");
    }
}
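`viterbi_decode` fills a score table and a backpointer table, then recovers the best label sequence by backtracking from the highest-scoring final label. The sketch below shows the same recurrence on plain `Vec<Vec<f64>>` inputs so it can run on its own. The function name `viterbi`, the `emit`/`trans` score layout, and the toy scores are assumptions for illustration, not this crate's API or data.

```rust
/// Hypothetical standalone Viterbi decoder (not part of this crate).
/// `emit[t][j]`: score of label `j` at position `t`.
/// `trans[k][j]`: score of moving from label `k` to label `j`.
/// Returns the index of the best label at each position.
fn viterbi(emit: &[Vec<f64>], trans: &[Vec<f64>]) -> Vec<usize> {
    let n = emit.len();
    if n == 0 {
        return Vec::new();
    }
    let m = emit[0].len();
    let mut scores = vec![vec![f64::NEG_INFINITY; m]; n];
    let mut back = vec![vec![0usize; m]; n];
    scores[0] = emit[0].clone();

    // Forward pass: best score of any path ending in label j at position i.
    for i in 1..n {
        for j in 0..m {
            for k in 0..m {
                let s = scores[i - 1][k] + trans[k][j] + emit[i][j];
                if s > scores[i][j] {
                    scores[i][j] = s;
                    back[i][j] = k;
                }
            }
        }
    }

    // Backtrack from the best final label along the stored backpointers.
    let mut path = vec![0usize; n];
    let mut best = f64::NEG_INFINITY;
    for (j, &s) in scores[n - 1].iter().enumerate() {
        if s > best {
            best = s;
            path[n - 1] = j;
        }
    }
    for i in (0..n - 1).rev() {
        path[i] = back[i + 1][path[i + 1]];
    }
    path
}

fn main() {
    // Toy labels: 0 = O, 1 = B-PER, 2 = I-PER.
    let emit = vec![
        vec![0.0, 2.0, 0.5], // "John"
        vec![0.5, 0.3, 1.0], // "Smith"
        vec![2.0, 0.0, 0.0], // "works"
    ];
    // O -> I-PER is heavily penalised; B-PER -> I-PER is rewarded.
    let trans = vec![
        vec![0.0, 0.0, -5.0],
        vec![0.0, -1.0, 1.0],
        vec![0.0, -1.0, 0.5],
    ];
    println!("{:?}", viterbi(&emit, &trans)); // prints [1, 2, 0]
}
```

With these toy scores the decoded path is `B-PER I-PER O`.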
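`labels_to_entities` walks the BIO labels once. A `B-` label opens a span, an `I-` label extends the open span, and anything else closes it. The minimal sketch below isolates that state machine over token indices and keeps the module's two choices visible: an orphan `I-` is dropped, and `I-` extends the open span whatever its type. `Span` and `bio_to_spans` are hypothetical names, and character offsets are left out.

```rust
/// Hypothetical span type: inclusive token indices plus the tag suffix (e.g. "PER").
#[derive(Debug, PartialEq)]
struct Span {
    start: usize,
    end: usize,
    kind: String,
}

/// Decode BIO tags into token-index spans, mirroring `labels_to_entities`:
/// `B-X` opens a span, `I-*` extends the open span (if any), anything else closes it.
fn bio_to_spans(labels: &[&str]) -> Vec<Span> {
    let mut spans = Vec::new();
    let mut current: Option<Span> = None;
    for (i, label) in labels.iter().enumerate() {
        if let Some(kind) = label.strip_prefix("B-") {
            // A new B- tag closes any open span before starting its own.
            spans.extend(current.take());
            current = Some(Span {
                start: i,
                end: i,
                kind: kind.to_string(),
            });
        } else if label.starts_with("I-") {
            // Orphan I- tags (no open span) are ignored.
            if let Some(span) = current.as_mut() {
                span.end = i;
            }
        } else {
            spans.extend(current.take());
        }
    }
    // Flush a span that runs to the end of the sequence.
    spans.extend(current.take());
    spans
}

fn main() {
    let labels = ["B-PER", "I-PER", "O", "O", "B-ORG", "I-LOC", "B-LOC"];
    for span in bio_to_spans(&labels) {
        println!("{:?}", span);
    }
}
```

Here `I-LOC` extends the `ORG` span (tokens 4 to 5), matching how the module treats a mismatched `I-` label.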
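`calculate_token_positions` searches for each token starting just past the previous match, which is what gives repeated tokens distinct offsets. The sketch below reproduces that forward search in isolation and checks it against the same inputs as `test_duplicate_entity_offsets` and `test_token_positions_unicode`. The function name `token_byte_spans` is hypothetical.

```rust
/// Hypothetical standalone version of the forward search used by
/// `calculate_token_positions`: each token is looked up starting at the end
/// of the previous match, so repeated tokens get distinct byte spans.
fn token_byte_spans(text: &str, tokens: &[&str]) -> Vec<(usize, usize)> {
    let mut spans = Vec::with_capacity(tokens.len());
    let mut pos = 0;
    for &token in tokens {
        match text[pos..].find(token) {
            Some(rel) => {
                let start = pos + rel;
                let end = start + token.len();
                spans.push((start, end));
                // Resume after this token; `end` is always a char boundary.
                pos = end;
            }
            // Not found: record an empty span at the current position.
            None => spans.push((pos, pos)),
        }
    }
    spans
}

fn main() {
    let text = "Google bought Google for $1 billion.";
    let tokens: Vec<&str> = text.split_whitespace().collect();
    let spans = token_byte_spans(text, &tokens);
    // A naive `text.find("Google")` returns 0 for both occurrences.
    assert_eq!(text.find("Google"), Some(0));
    println!("{:?}", &spans[..3]); // prints [(0, 6), (7, 13), (14, 20)]
}
```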
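`push_entity_from_positions` passes byte offsets to `crate::offset::SpanConverter::byte_to_char`, whose implementation is not part of this listing. As an illustration only, the hypothetical `ByteToChar` table below precomputes, for every byte offset, how many characters start before it. After one pass over the text, each lookup is a single index. This is not a description of how `SpanConverter` actually works.

```rust
/// Hypothetical byte-to-char offset table (an illustration, not
/// `crate::offset::SpanConverter`). Entry `b` holds the number of chars that
/// start before byte `b`; entry `text.len()` holds the total char count.
struct ByteToChar {
    table: Vec<usize>,
}

impl ByteToChar {
    fn new(text: &str) -> Self {
        let mut table = Vec::with_capacity(text.len() + 1);
        let mut count = 0;
        for b in 0..=text.len() {
            table.push(count);
            // A char starts at every boundary except the end of the string.
            if b < text.len() && text.is_char_boundary(b) {
                count += 1;
            }
        }
        ByteToChar { table }
    }

    /// For a byte offset on a char boundary, this is its char offset.
    fn byte_to_char(&self, byte: usize) -> usize {
        self.table[byte]
    }
}

fn main() {
    let text = "北京 Beijing";
    let conv = ByteToChar::new(text);
    let start = text.find("Beijing").unwrap(); // byte 7
    let end = start + "Beijing".len(); // byte 14
    println!("{} {}", conv.byte_to_char(start), conv.byte_to_char(end)); // prints: 3 10
}
```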
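`test_word_shape` pins down the shape feature. Uppercase letters map to `X`, lowercase to `x`, and digits to `0`, and consecutive repeats collapse, so `"Hello123"` becomes `"Xx0"`. `word_shape` itself is defined earlier in the file and is not shown here. The sketch below is a hypothetical reimplementation that satisfies those five assertions. Keeping other characters (punctuation, symbols) unchanged is an assumption, not something the tests confirm.

```rust
/// Hypothetical word-shape function consistent with `test_word_shape`:
/// uppercase -> 'X', lowercase -> 'x', digit -> '0', anything else kept as-is
/// (an assumption), then runs of the same shape character are collapsed.
fn word_shape(word: &str) -> String {
    let mut shape = String::new();
    let mut last: Option<char> = None;
    for c in word.chars() {
        let s = if c.is_uppercase() {
            'X'
        } else if c.is_lowercase() {
            'x'
        } else if c.is_ascii_digit() {
            '0'
        } else {
            c
        };
        // Skip repeats so "John" becomes "Xx" rather than "Xxxx".
        if last != Some(s) {
            shape.push(s);
            last = Some(s);
        }
    }
    shape
}

fn main() {
    for w in ["John", "USA", "hello", "123", "Hello123"] {
        println!("{w} -> {}", word_shape(w));
    }
}
```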