Skip to main content

anno/backends/
hmm.rs

1//! Hidden Markov Model (HMM) NER backend.
2//!
//! Implements classical statistical NER using HMMs, one of the dominant approaches
//! of the late 1990s before CRFs became popular. HMMs were among the **first
//! statistical approaches** to NER, following earlier rule-based systems.
6//!
7//! # Historical Context
8//!
//! NER was first defined as an evaluation task at MUC-6 (held in 1995 and
//! summarized by Grishman & Sundheim, 1996), covering people, organizations,
//! and locations. Early systems were rule-based (lexicons, hand-crafted
//! patterns). HMMs brought statistical learning to NER:
13//!
14//! ```text
15//! 1987  MUC-1: First IE conference (no formal NER task)
//! 1995  MUC-6: NER formally defined (PER, ORG, LOC)
17//! 1997  Nymble (Bikel et al.): early HMM NER system
18//! 1998  BBN IdentiFinder: HMM-based, MUC-7 benchmark
19//! 2001  CRFs introduced (Lafferty et al.) — HMMs become a common comparison baseline
20//! ```
21//!
22//! # Architecture
23//!
24//! ```text
25//! Input: "John works at Google"
26//!    ↓
27//! ┌─────────────────────────────────────────────────────────┐
28//! │ Hidden States (NER Tags)                                │
29//! │                                                         │
30//! │  B-PER ──> O ──> O ──> B-ORG                           │
31//! │    │       │     │       │                              │
32//! │    ↓       ↓     ↓       ↓                              │
33//! │  John   works   at   Google                             │
34//! │                                                         │
35//! │ Observed Emissions                                      │
36//! └─────────────────────────────────────────────────────────┘
37//!
38//! P(tags | words) ∝ P(tags) × P(words | tags)
39//!                 = ∏ P(tag_i | tag_{i-1}) × P(word_i | tag_i)
40//! ```
41//!
42//! # HMM Components
43//!
//! 1. **States**: BIO tags (O, plus B-/I- for PER, ORG, LOC, and MISC)
45//! 2. **Observations**: Words in the text
46//! 3. **Transition Probabilities**: P(tag_i | tag_{i-1})
47//! 4. **Emission Probabilities**: P(word | tag)
48//! 5. **Initial Probabilities**: P(tag | start)
49//!
50//! # Mathematical Formulation
51//!
52//! HMMs are **generative models** that model the joint probability:
53//!
54//! ```text
55//! P(x, y) = P(y) × P(x | y)
56//!         = P(y_1) × ∏_{t=2}^{T} P(y_t | y_{t-1})    // transitions
57//!                 × ∏_{t=1}^{T} P(x_t | y_t)          // emissions
58//! ```
59//!
60//! Decoding uses the **Viterbi algorithm** (dynamic programming) to find
61//! the most likely state sequence in O(T × |S|²) time.
62//!
63//! # History
64//!
65//! - Rabiner (1989): "A Tutorial on Hidden Markov Models" (foundational)
66//! - Bikel et al. 1997: "Nymble: A High-Performance Learning Name-finder"
67//! - BBN IdentiFinder: One of the first HMM-based NER systems
68//! - Often replaced by CRFs for NER in the 2000s, but still useful as a baseline/teaching model
69//!
70//! # Why HMMs Often Underperform CRFs (for NER)
71//!
72//! | Aspect | HMM | CRF |
73//! |--------|-----|-----|
74//! | Model Type | Generative | Discriminative |
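//! | Features | Word identity (plus coarse word classes) | Arbitrary overlapping features |
//! | Observation context | Current token only | Arbitrary windows |
//! | Label bias | Not affected (generative, jointly normalized) | Avoided via global normalization (unlike MEMMs) |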
78//! | Performance | task-dependent | task-dependent |
79//!
80//! HMMs are typically used with relatively limited emission features. CRFs can use arbitrary
81//! feature functions (capitalization, prefixes/suffixes, gazetteers, etc.) while remaining a
82//! globally normalized conditional model.
83//!
84//! # References
85//!
86//! - Rabiner (1989): "A Tutorial on Hidden Markov Models and Selected
87//!   Applications in Speech Recognition" (Proceedings of IEEE)
88//! - Bikel, Miller, Schwartz, Weischedel (1997): "Nymble: A High-Performance
89//!   Learning Name-finder" (ANLP)
90//! - Bikel, Schwartz, Weischedel (1999): "An Algorithm that Learns What's
91//!   in a Name" (Machine Learning)
92//!
93//! # See Also
94//!
95//! - CRF-style sequence models (`backends/crf.rs`)
96
97use crate::{Entity, EntityType, Model, Result};
98use std::collections::HashMap;
99
100#[cfg(feature = "bundled-hmm-params")]
101use std::sync::OnceLock;
102
103#[cfg(feature = "bundled-hmm-params")]
104use serde_json as _;
105
106#[derive(Debug, Clone)]
107struct HmmParams {
108    states: Vec<String>,
109    initial: Vec<f64>,
110    transitions: Vec<Vec<f64>>,
111    backoff: serde_json::Value,
112}
113
/// Feature-level emission backoff decoded from the bundled parameters.
///
/// Each probability vector holds one entry per state, in the model's `states` order.
#[derive(Debug, Clone)]
struct HmmBackoff {
    /// Length-bucket name -> P(bucket | state), one value per state.
    len: HashMap<String, Vec<f64>>,
    /// Boolean feature name -> P(feature present | state), one value per state.
    bool_present: HashMap<String, Vec<f64>>,
    /// Fixed iteration order of the boolean features; absent features also contribute
    /// their (1 - p) factor, so every key here is visited for every word.
    bool_keys: Vec<String>,
}
123
/// Tunable knobs for the HMM tagger.
#[derive(Debug, Clone)]
pub struct HmmConfig {
    /// Small floor probability used wherever a zero would otherwise appear
    /// (unseen words, disallowed BIO transitions and start states).
    pub smoothing: f64,
    /// Whether to work with log probabilities for numerical stability.
    pub use_log_probs: bool,
    /// Multiplier applied to every non-`O` emission coming from the bundled backoff.
    ///
    /// Below 1.0 suppresses spurious entities; above 1.0 favours recall at the risk of over-tagging.
    pub non_o_emission_scale: f64,
    /// When bundled parameters are available, use their start/transition probabilities
    /// instead of the heuristic ones.
    pub use_bundled_dynamics: bool,
}

impl Default for HmmConfig {
    fn default() -> Self {
        // With bundled params present, halve non-O backoff emissions (fewer false positives)
        // and let the bundled dynamics drive decoding so the trained path is end-to-end.
        let non_o_emission_scale = 0.5;
        let use_bundled_dynamics = true;
        Self {
            use_log_probs: true,
            smoothing: 1e-10,
            non_o_emission_scale,
            use_bundled_dynamics,
        }
    }
}
152
153/// Hidden Markov Model for NER.
154///
155/// This implements a first-order HMM (bigram) for sequence labeling.
156/// Uses the Viterbi algorithm for decoding.
157#[derive(Debug)]
158pub struct HmmNER {
159    /// Configuration.
160    config: HmmConfig,
161    /// State labels (BIO tags).
162    states: Vec<String>,
163    /// State to index mapping.
164    state_to_idx: HashMap<String, usize>,
165    /// Transition probabilities: P(state_j | state_i)
166    /// transitions\[i\]\[j\] = P(j | i)
167    transitions: Vec<Vec<f64>>,
168    /// Initial state probabilities: P(state | start)
169    initial: Vec<f64>,
170    /// Emission probabilities: P(word | state)
171    /// Key: (state_idx, word), Value: probability
172    emissions: HashMap<(usize, String), f64>,
173    /// Vocabulary for unknown word handling.
174    #[allow(dead_code)] // Reserved for OOV handling
175    vocab: HashMap<String, usize>,
176    /// Optional bundled emission backoff (small, trained).
177    backoff: Option<HmmBackoff>,
178}
179
180impl HmmNER {
181    /// Create a new HMM NER model with default parameters.
182    #[must_use]
183    pub fn new() -> Self {
184        Self::with_config(HmmConfig::default())
185    }
186
187    /// Create a new HMM NER model using only heuristic parameters (no bundled params).
188    ///
189    /// This is useful for E2E evaluation comparisons (heuristic vs bundled-trained).
190    #[must_use]
191    pub fn new_heuristic() -> Self {
192        Self::with_config_no_bundled(HmmConfig::default())
193    }
194
195    /// Create with custom configuration.
196    #[must_use]
197    pub fn with_config(config: HmmConfig) -> Self {
198        Self::with_config_internal(config, true)
199    }
200
201    /// Create with custom configuration, skipping bundled params even if the feature is enabled.
202    #[must_use]
203    pub fn with_config_no_bundled(config: HmmConfig) -> Self {
204        Self::with_config_internal(config, false)
205    }
206
207    fn with_config_internal(config: HmmConfig, allow_bundled: bool) -> Self {
208        let states = vec![
209            "O".to_string(),
210            "B-PER".to_string(),
211            "I-PER".to_string(),
212            "B-ORG".to_string(),
213            "I-ORG".to_string(),
214            "B-LOC".to_string(),
215            "I-LOC".to_string(),
216            "B-MISC".to_string(),
217            "I-MISC".to_string(),
218        ];
219
220        let state_to_idx: HashMap<String, usize> = states
221            .iter()
222            .enumerate()
223            .map(|(i, s)| (s.clone(), i))
224            .collect();
225
226        let n = states.len();
227
228        // Initialize transition probabilities with BIO constraints
229        let mut transitions = vec![vec![0.0; n]; n];
230        Self::init_transitions(&mut transitions, &states, &config);
231
232        // Initialize with uniform priors, biased toward O
233        // Initial state distribution - more balanced to allow entities at start
234        let mut initial = vec![0.0; n];
235        for (i, state) in states.iter().enumerate() {
236            if state == "O" {
237                initial[i] = 0.4; // O is common but not dominant
238            } else if state.starts_with("B-") {
239                initial[i] = 0.15; // Entities can start sentences
240            } else if state.starts_with("I-") {
241                initial[i] = config.smoothing; // I- can't start
242            }
243        }
244        Self::normalize(&mut initial);
245
246        // Initialize emission probabilities with heuristics
247        let emissions = Self::init_emissions(&states, &state_to_idx);
248
249        let mut m = Self {
250            config,
251            states,
252            state_to_idx,
253            transitions,
254            initial,
255            emissions,
256            vocab: HashMap::new(),
257            backoff: None,
258        };
259
260        // Optional bundled params (priors + transitions only). These are small enough to ship,
261        // and they don't embed word identity emissions.
262        if allow_bundled {
263            if let Some(p) = Self::bundled_params() {
264                if p.states == m.states
265                    && p.initial.len() == m.states.len()
266                    && p.transitions.len() == m.states.len()
267                    && p.transitions.iter().all(|r| r.len() == m.states.len())
268                {
269                    let backoff = HmmBackoff::from_params(&p);
270                    m.backoff = Some(backoff);
271                    // Prefer bundled dynamics when configured (the default config does),
272                    // since the bundled params are intended to be a real end-to-end baseline.
273                    //
274                    // You can force-enable via env var, or force-disable via config.
275                    let use_dynamics_env = std::env::var("ANNO_HMM_USE_BUNDLED_DYNAMICS")
276                        .ok()
277                        .is_some_and(|v| {
278                            let s = v.trim();
279                            s == "1"
280                                || s.eq_ignore_ascii_case("true")
281                                || s.eq_ignore_ascii_case("yes")
282                        });
283                    let use_dynamics = m.config.use_bundled_dynamics || use_dynamics_env;
284                    if use_dynamics {
285                        m.initial = p.initial.clone();
286                        m.transitions = p.transitions.clone();
287                    }
288                }
289            }
290        }
291
292        m
293    }
294
295    fn bundled_params() -> Option<HmmParams> {
296        #[cfg(feature = "bundled-hmm-params")]
297        {
298            static ONCE: OnceLock<Option<HmmParams>> = OnceLock::new();
299            return ONCE
300                .get_or_init(|| {
301                    let s = include_str!("hmm_params.json");
302                    let v: serde_json::Value = serde_json::from_str(s).ok()?;
303                    let states = v
304                        .get("states")?
305                        .as_array()?
306                        .iter()
307                        .map(|x| x.as_str().map(|s| s.to_string()))
308                        .collect::<Option<Vec<_>>>()?;
309                    let initial = v
310                        .get("initial")?
311                        .as_array()?
312                        .iter()
313                        .map(|x| x.as_f64())
314                        .collect::<Option<Vec<_>>>()?;
315                    let transitions = v
316                        .get("transitions")?
317                        .as_array()?
318                        .iter()
319                        .map(|row| {
320                            row.as_array()?
321                                .iter()
322                                .map(|x| x.as_f64())
323                                .collect::<Option<Vec<_>>>()
324                        })
325                        .collect::<Option<Vec<_>>>()?;
326                    let backoff = v.get("backoff")?.clone();
327                    Some(HmmParams {
328                        states,
329                        initial,
330                        transitions,
331                        backoff,
332                    })
333                })
334                .clone();
335        }
336        #[cfg(not(feature = "bundled-hmm-params"))]
337        {
338            None
339        }
340    }
341
342    /// Initialize transition matrix with BIO constraints.
343    fn init_transitions(trans: &mut [Vec<f64>], states: &[String], config: &HmmConfig) {
344        let n = states.len();
345
346        for i in 0..n {
347            for j in 0..n {
348                let from = &states[i];
349                let to = &states[j];
350
351                // BIO constraints
352                if let Some(entity_type) = to.strip_prefix("I-") {
353                    let valid_b = format!("B-{}", entity_type);
354                    let valid_i = format!("I-{}", entity_type);
355
356                    if from == &valid_b || from == &valid_i {
357                        trans[i][j] = 0.3; // Valid continuation
358                    } else {
359                        trans[i][j] = config.smoothing; // Invalid (very low)
360                    }
361                } else if to.starts_with("B-") {
362                    trans[i][j] = 0.1; // Entities are relatively rare
363                } else {
364                    // O tag
365                    trans[i][j] = 0.5; // Most transitions go to O
366                }
367            }
368
369            // Normalize row
370            Self::normalize(&mut trans[i]);
371        }
372    }
373
374    /// Initialize emission probabilities with comprehensive gazetteers.
375    ///
376    /// These are empirically-tuned emission probabilities based on word lists
377    /// commonly found in NER training data (CoNLL-2003, OntoNotes, etc.).
378    fn init_emissions(
379        _states: &[String],
380        state_to_idx: &HashMap<String, usize>,
381    ) -> HashMap<(usize, String), f64> {
382        let mut emissions = HashMap::new();
383
384        // Comprehensive person indicators (names, titles, honorifics)
385        let person_indicators = [
386            // Common first names
387            "john",
388            "mary",
389            "james",
390            "david",
391            "michael",
392            "robert",
393            "william",
394            "richard",
395            "sarah",
396            "jennifer",
397            "elizabeth",
398            "lisa",
399            "marie",
400            "jane",
401            "emily",
402            "anna",
403            "barack",
404            "donald",
405            "joe",
406            "george",
407            "bill",
408            "hillary",
409            "elon",
410            "jeff",
411            "angela",
412            "vladimir",
413            "emmanuel",
414            "xi",
415            "narendra",
416            "justin",
417            "rishi",
418            "steve",
419            "tim",
420            "mark",
421            "satya",
422            "sundar",
423            "sheryl",
424            "sam",
425            "dario",
426            // Common surnames (political, tech, historical)
427            "obama",
428            "biden",
429            "trump",
430            "bush",
431            "clinton",
432            "reagan",
433            "kennedy",
434            "lincoln",
435            "merkel",
436            "macron",
437            "putin",
438            "jinping",
439            "modi",
440            "trudeau",
441            "sunak",
442            "musk",
443            "bezos",
444            "zuckerberg",
445            "gates",
446            "jobs",
447            "wozniak",
448            "cook",
449            "pichai",
450            "nadella",
451            "altman",
452            "amodei",
453            "hassabis",
454            "hinton",
455            "lecun",
456            "bengio",
457            "smith",
458            "johnson",
459            "williams",
460            "brown",
461            "jones",
462            "garcia",
463            "miller",
464            "davis",
465            // Honorifics and titles
466            "mr",
467            "mrs",
468            "ms",
469            "dr",
470            "prof",
471            "sir",
472            "lord",
473            "lady",
474            "president",
475            "ceo",
476            "chairman",
477            "director",
478            "minister",
479            "senator",
480            "mayor",
481            "governor",
482            "chancellor",
483            "prime",
484            "secretary",
485            "ambassador",
486            "general",
487            "admiral",
488        ];
489
490        // Comprehensive organization indicators
491        let org_indicators = [
492            // Company names
493            "google",
494            "apple",
495            "microsoft",
496            "amazon",
497            "facebook",
498            "meta",
499            "tesla",
500            "ibm",
501            "intel",
502            "nvidia",
503            "oracle",
504            "cisco",
505            "adobe",
506            "netflix",
507            "uber",
508            "toyota",
509            "honda",
510            "ford",
511            "chevrolet",
512            "bmw",
513            "mercedes",
514            "audi",
515            // Suffixes
516            "inc",
517            "corp",
518            "ltd",
519            "llc",
520            "co",
521            "plc",
522            "gmbh",
523            "ag",
524            "sa",
525            "company",
526            "corporation",
527            "incorporated",
528            "limited",
529            "group",
530            "holdings",
531            // Institutional
532            "university",
533            "institute",
534            "college",
535            "academy",
536            "school",
537            "hospital",
538            "foundation",
539            "association",
540            "organization",
541            "committee",
542            "council",
543            "department",
544            "ministry",
545            "agency",
546            "bureau",
547            "commission",
548            // Government/International
549            "fbi",
550            "cia",
551            "nsa",
552            "nasa",
553            "un",
554            "nato",
555            "who",
556            "imf",
557            "eu",
558            "usa",
559            "parliament",
560            "congress",
561            "senate",
562            "house",
563            "court",
564            "bank",
565        ];
566
567        // Comprehensive location indicators
568        let loc_indicators = [
569            // US cities/states
570            "new",
571            "york",
572            "california",
573            "texas",
574            "florida",
575            "washington",
576            "chicago",
577            "boston",
578            "seattle",
579            "san",
580            "francisco",
581            "los",
582            "angeles",
583            "las",
584            "vegas",
585            "miami",
586            "denver",
587            "atlanta",
588            "phoenix",
589            "dallas",
590            "houston",
591            "portland",
592            // World cities
593            "london",
594            "paris",
595            "berlin",
596            "tokyo",
597            "beijing",
598            "moscow",
599            "sydney",
600            "toronto",
601            "vancouver",
602            "rome",
603            "madrid",
604            "amsterdam",
605            "brussels",
606            "vienna",
607            "seoul",
608            "singapore",
609            "hong",
610            "kong",
611            "dubai",
612            "mumbai",
613            "delhi",
614            // Countries/regions
615            "united",
616            "states",
617            "america",
618            "china",
619            "russia",
620            "germany",
621            "france",
622            "japan",
623            "india",
624            "brazil",
625            "canada",
626            "australia",
627            "uk",
628            "britain",
629            "italy",
630            "spain",
631            "mexico",
632            "korea",
633            "taiwan",
634            "vietnam",
635            "thailand",
636            // Geographic terms
637            "city",
638            "county",
639            "state",
640            "country",
641            "province",
642            "region",
643            "district",
644            "river",
645            "mountain",
646            "lake",
647            "ocean",
648            "sea",
649            "island",
650            "peninsula",
651            "north",
652            "south",
653            "east",
654            "west",
655            "central",
656            "northern",
657            "southern",
658        ];
659
660        // Set higher emission probabilities for known indicators
661        for word in person_indicators {
662            let b_idx = state_to_idx["B-PER"];
663            let i_idx = state_to_idx["I-PER"];
664            emissions.insert((b_idx, word.to_string()), 0.4);
665            emissions.insert((i_idx, word.to_string()), 0.25);
666        }
667
668        for word in org_indicators {
669            let b_idx = state_to_idx["B-ORG"];
670            let i_idx = state_to_idx["I-ORG"];
671            emissions.insert((b_idx, word.to_string()), 0.4);
672            emissions.insert((i_idx, word.to_string()), 0.25);
673        }
674
675        for word in loc_indicators {
676            let b_idx = state_to_idx["B-LOC"];
677            let i_idx = state_to_idx["I-LOC"];
678            emissions.insert((b_idx, word.to_string()), 0.4);
679            emissions.insert((i_idx, word.to_string()), 0.25);
680        }
681
682        emissions
683    }
684
685    /// Normalize a probability vector.
686    fn normalize(vec: &mut [f64]) {
687        let sum: f64 = vec.iter().sum();
688        if sum > 0.0 {
689            for v in vec.iter_mut() {
690                *v /= sum;
691            }
692        }
693    }
694
695    /// Get emission probability for (state, word).
696    fn emission_prob(&self, state_idx: usize, word: &str) -> f64 {
697        let lower = word.to_lowercase();
698
699        // Check explicit emissions (known entity names)
700        if let Some(&prob) = self.emissions.get(&(state_idx, lower.clone())) {
701            return prob;
702        }
703
704        // If we have bundled backoff emissions, prefer them over heuristics.
705        // These are compact, trained probabilities over generic word features (no word identity).
706        if let Some(b) = self.backoff.as_ref() {
707            // Emission score uses a naive Bayes factorization:
708            //   P(features | state) = P(len_bucket | state) * Π_f P(f|state)^(present) * (1-P(f|state))^(absent)
709            // We only use the small set of features in the bundled table.
710            let lb = Self::len_bucket(word);
711            let mut sum_log = 0.0f64;
712            if let Some(p) = b.len.get(lb).and_then(|v| v.get(state_idx).copied()) {
713                sum_log += p.max(1e-12).ln();
714            } else {
715                sum_log += (1e-12f64).ln();
716            }
717            let feats = Self::bool_features(word);
718            for k in &b.bool_keys {
719                let present = feats.get(k.as_str()).copied().unwrap_or(false);
720                let p_present = b
721                    .bool_present
722                    .get(k)
723                    .and_then(|v| v.get(state_idx).copied())
724                    .unwrap_or(1e-12)
725                    .clamp(1e-12, 1.0 - 1e-12);
726                let p = if present { p_present } else { 1.0 - p_present };
727                sum_log += p.max(1e-12).ln();
728            }
729            let mut score = sum_log.exp().max(self.config.smoothing);
730            // State 0 is "O" in our state list.
731            if state_idx != 0 {
732                score *= self.config.non_o_emission_scale.max(1e-6);
733            }
734            return score.max(self.config.smoothing);
735        }
736
737        // Heuristic emissions based on word features
738        let state = &self.states[state_idx];
739        let is_capitalized = word.chars().next().is_some_and(|c| c.is_uppercase());
740        let is_all_caps =
741            word.chars().all(|c| c.is_uppercase() || !c.is_alphabetic()) && word.len() > 1;
742        let has_digit = word.chars().any(|c| c.is_ascii_digit());
743        let is_title_case = is_capitalized && word.len() > 1;
744
745        // Check for organization suffixes
746        let org_suffixes = [
747            "Inc", "Corp", "Ltd", "LLC", "Co", "Company", "Inc.", "Corp.", "Ltd.",
748        ];
749        let is_org_suffix = org_suffixes.contains(&word);
750
751        if state == "O" {
752            // Non-capitalized words and digits are likely O
753            if !is_capitalized {
754                return 0.7;
755            }
756            // Capitalized at sentence start - unclear
757            if has_digit {
758                return 0.5;
759            }
760            // Title case words are less likely to be O
761            if is_title_case {
762                return 0.15;
763            }
764            return 0.4;
765        }
766
767        if state.starts_with("B-") || state.starts_with("I-") {
768            let entity_type = &state[2..];
769
770            // Organization suffixes strongly indicate ORG
771            if entity_type == "ORG" && is_org_suffix {
772                return 0.8;
773            }
774
775            // All caps = likely ORG (acronyms like IBM, NASA)
776            if is_all_caps && entity_type == "ORG" {
777                return 0.6;
778            }
779
780            // Title case words are likely entities, but prefer PER for typical names
781            // Most proper nouns starting with capital letters are person names
782            // unless they have organization-specific markers
783            if is_title_case && !has_digit {
784                if entity_type == "PER" {
785                    return 0.55; // Slightly prefer PER over others for title case
786                } else if entity_type == "LOC" {
787                    return 0.45; // Locations are second most common title case
788                } else if entity_type == "ORG" {
789                    return 0.35; // ORGs need more evidence (suffix, acronym)
790                }
791                return 0.4;
792            }
793
794            // Capitalized words at least somewhat likely
795            if is_capitalized && !has_digit {
796                return 0.3;
797            }
798
799            return self.config.smoothing;
800        }
801
802        self.config.smoothing
803    }
804
805    /// Viterbi decoding to find most likely state sequence.
806    fn viterbi(&self, words: &[&str]) -> Vec<usize> {
807        if words.is_empty() {
808            return vec![];
809        }
810
811        let n = words.len();
812        let m = self.states.len();
813
814        // Use log probabilities for numerical stability
815        let log = |p: f64| if p > 0.0 { p.ln() } else { f64::NEG_INFINITY };
816
817        // DP tables
818        let mut dp = vec![vec![f64::NEG_INFINITY; m]; n];
819        let mut backptr = vec![vec![0usize; m]; n];
820
821        // Initialize first position
822        for (j, cell) in dp[0].iter_mut().enumerate().take(m) {
823            *cell = log(self.initial[j]) + log(self.emission_prob(j, words[0]));
824        }
825
826        // Forward pass
827        for t in 1..n {
828            for j in 0..m {
829                let emit = log(self.emission_prob(j, words[t]));
830
831                for i in 0..m {
832                    let trans = log(self.transitions[i][j]);
833                    let score = dp[t - 1][i] + trans + emit;
834
835                    if score > dp[t][j] {
836                        dp[t][j] = score;
837                        backptr[t][j] = i;
838                    }
839                }
840            }
841        }
842
843        // Find best final state
844        let mut best_state = 0;
845        let mut best_score = f64::NEG_INFINITY;
846        for (j, &score) in dp[n - 1].iter().enumerate() {
847            if score > best_score {
848                best_score = score;
849                best_state = j;
850            }
851        }
852
853        // Backtrack
854        let mut path = vec![0usize; n];
855        path[n - 1] = best_state;
856        for t in (0..n - 1).rev() {
857            path[t] = backptr[t + 1][path[t + 1]];
858        }
859
860        path
861    }
862
863    /// Convert BIO labels to entities.
864    ///
865    /// Uses token position tracking to correctly handle duplicate entity texts.
866    /// The previous implementation used `text.find()` which always returned the
867    /// first occurrence, causing incorrect offsets for duplicate entities.
868    fn decode_entities(&self, text: &str, words: &[&str], labels: &[usize]) -> Vec<Entity> {
869        use crate::offset::SpanConverter;
870
871        let converter = SpanConverter::new(text);
872        let mut entities = Vec::new();
873
874        // Track token positions (byte offsets) as we iterate
875        let token_positions: Vec<(usize, usize)> = Self::calculate_token_positions(text, words);
876
877        let mut current: Option<(usize, usize, EntityType, Vec<&str>)> = None;
878
879        for (i, (&label_idx, &word)) in labels.iter().zip(words.iter()).enumerate() {
880            let label = &self.states[label_idx];
881
882            if label.starts_with("B-") {
883                // Save previous entity
884                if let Some((start_idx, end_idx, entity_type, entity_words)) = current.take() {
885                    Self::push_entity_from_positions(
886                        &converter,
887                        &token_positions,
888                        start_idx,
889                        end_idx,
890                        &entity_words,
891                        entity_type,
892                        &mut entities,
893                    );
894                }
895
896                // Start new entity
897                let entity_type_str = label
898                    .strip_prefix("B-")
899                    .or_else(|| label.strip_prefix("I-"))
900                    .expect("label should start with B- or I-");
901                let entity_type = match entity_type_str {
902                    "PER" => EntityType::Person,
903                    "ORG" => EntityType::Organization,
904                    "LOC" => EntityType::Location,
905                    other => EntityType::Other(other.to_string()),
906                };
907                current = Some((i, i, entity_type, vec![word]));
908            } else if label.starts_with("I-") && current.is_some() {
909                if let Some((_, ref mut end_idx, _, ref mut entity_words)) = current {
910                    entity_words.push(word);
911                    *end_idx = i;
912                }
913            } else {
914                // O tag
915                if let Some((start_idx, end_idx, entity_type, entity_words)) = current.take() {
916                    Self::push_entity_from_positions(
917                        &converter,
918                        &token_positions,
919                        start_idx,
920                        end_idx,
921                        &entity_words,
922                        entity_type,
923                        &mut entities,
924                    );
925                }
926            }
927        }
928
929        // Final entity
930        if let Some((start_idx, end_idx, entity_type, entity_words)) = current {
931            Self::push_entity_from_positions(
932                &converter,
933                &token_positions,
934                start_idx,
935                end_idx,
936                &entity_words,
937                entity_type,
938                &mut entities,
939            );
940        }
941
942        entities
943    }
944
945    /// Calculate byte positions for each token in the text.
946    fn calculate_token_positions(text: &str, tokens: &[&str]) -> Vec<(usize, usize)> {
947        let mut positions = Vec::with_capacity(tokens.len());
948        let mut byte_pos = 0;
949
950        for token in tokens {
951            // Find token starting from current position
952            if let Some(rel_pos) = text[byte_pos..].find(token) {
953                let start = byte_pos + rel_pos;
954                let end = start + token.len();
955                positions.push((start, end));
956                byte_pos = end; // Move past this token
957            } else {
958                // Fallback: use current position (shouldn't happen with whitespace tokenization)
959                positions.push((byte_pos, byte_pos));
960            }
961        }
962
963        positions
964    }
965
966    /// Helper to create entity with correct character offsets using token positions.
967    fn push_entity_from_positions(
968        converter: &crate::offset::SpanConverter,
969        positions: &[(usize, usize)],
970        start_token_idx: usize,
971        end_token_idx: usize,
972        words: &[&str],
973        entity_type: EntityType,
974        entities: &mut Vec<Entity>,
975    ) {
976        if start_token_idx >= positions.len() || end_token_idx >= positions.len() {
977            return;
978        }
979
980        let byte_start = positions[start_token_idx].0;
981        let byte_end = positions[end_token_idx].1;
982        let char_start = converter.byte_to_char(byte_start);
983        let char_end = converter.byte_to_char(byte_end);
984        let entity_text = words.join(" ");
985
986        entities.push(Entity::new(
987            entity_text,
988            entity_type,
989            char_start,
990            char_end,
991            0.65, // HMM confidence
992        ));
993    }
994
995    /// Train the HMM from labeled data.
996    ///
997    /// # Arguments
998    /// * `sentences` - List of (words, tags) pairs
999    pub fn train(&mut self, sentences: &[(&[&str], &[&str])]) {
1000        // Count transitions
1001        let n = self.states.len();
1002        let mut trans_counts = vec![vec![0usize; n]; n];
1003        let mut initial_counts = vec![0usize; n];
1004        let mut emission_counts: HashMap<(usize, String), usize> = HashMap::new();
1005        let mut state_counts = vec![0usize; n];
1006
1007        for (words, tags) in sentences {
1008            if tags.is_empty() {
1009                continue;
1010            }
1011
1012            // Initial state
1013            if let Some(&idx) = self.state_to_idx.get(tags[0]) {
1014                initial_counts[idx] += 1;
1015            }
1016
1017            // Transitions and emissions
1018            for (i, (word, tag)) in words.iter().zip(tags.iter()).enumerate() {
1019                if let Some(&tag_idx) = self.state_to_idx.get(*tag) {
1020                    // Emission count
1021                    *emission_counts
1022                        .entry((tag_idx, word.to_lowercase()))
1023                        .or_insert(0) += 1;
1024                    state_counts[tag_idx] += 1;
1025
1026                    // Transition count
1027                    if i > 0 {
1028                        if let Some(&prev_idx) = self.state_to_idx.get(tags[i - 1]) {
1029                            trans_counts[prev_idx][tag_idx] += 1;
1030                        }
1031                    }
1032                }
1033            }
1034        }
1035
1036        // Convert counts to probabilities (with smoothing)
1037        let total_initial: f64 =
1038            initial_counts.iter().sum::<usize>() as f64 + self.config.smoothing * n as f64;
1039        for (i, &count) in initial_counts.iter().enumerate() {
1040            self.initial[i] = (count as f64 + self.config.smoothing) / total_initial;
1041        }
1042
1043        for (i, row) in trans_counts.iter().enumerate().take(n) {
1044            let total: f64 = row.iter().sum::<usize>() as f64 + self.config.smoothing * n as f64;
1045            for (j, &count) in row.iter().enumerate().take(n) {
1046                self.transitions[i][j] = (count as f64 + self.config.smoothing) / total;
1047            }
1048        }
1049
1050        for ((state_idx, word), count) in emission_counts {
1051            let total = state_counts[state_idx] as f64;
1052            if total > 0.0 {
1053                self.emissions
1054                    .insert((state_idx, word), count as f64 / total);
1055            }
1056        }
1057    }
1058
1059    fn len_bucket(word: &str) -> &'static str {
1060        let n = word.chars().count();
1061        if n <= 1 {
1062            "len:1"
1063        } else if n == 2 {
1064            "len:2"
1065        } else if n == 3 {
1066            "len:3"
1067        } else if (4..=5).contains(&n) {
1068            "len:4_5"
1069        } else if (6..=8).contains(&n) {
1070            "len:6_8"
1071        } else {
1072            "len:9p"
1073        }
1074    }
1075
1076    fn bool_features(word: &str) -> HashMap<&'static str, bool> {
1077        let is_capitalized = word.chars().next().is_some_and(|c| c.is_uppercase());
1078        let is_all_caps = word.chars().all(|c| c.is_uppercase() || !c.is_alphabetic())
1079            && word.chars().count() > 1;
1080        let is_digit = !word.is_empty() && word.chars().all(|c| c.is_ascii_digit());
1081        let has_digit = word.chars().any(|c| c.is_ascii_digit());
1082        let has_hyphen = word.contains('-');
1083        let has_dot = word.contains('.');
1084        let mut m = HashMap::new();
1085        m.insert("is_capitalized", is_capitalized);
1086        m.insert("is_all_caps", is_all_caps);
1087        m.insert("is_digit", is_digit);
1088        m.insert("has_digit", has_digit);
1089        m.insert("has_hyphen", has_hyphen);
1090        m.insert("has_dot", has_dot);
1091        m
1092    }
1093}
1094
1095impl HmmBackoff {
1096    fn from_params(p: &HmmParams) -> Self {
1097        // backoff schema:
1098        // {
1099        //   "len": { bucket: { state: prob } },
1100        //   "bool": { feat: { state: p_present } }
1101        // }
1102        let mut len: HashMap<String, Vec<f64>> = HashMap::new();
1103        let mut bool_present: HashMap<String, Vec<f64>> = HashMap::new();
1104
1105        if let Some(obj) = p.backoff.as_object() {
1106            if let Some(len_obj) = obj.get("len").and_then(|v| v.as_object()) {
1107                for (bucket, distv) in len_obj {
1108                    let mut v = vec![1e-12; p.states.len()];
1109                    if let Some(dist) = distv.as_object() {
1110                        for (i, state) in p.states.iter().enumerate() {
1111                            if let Some(x) = dist.get(state).and_then(|x| x.as_f64()) {
1112                                v[i] = x;
1113                            }
1114                        }
1115                    }
1116                    len.insert(bucket.clone(), v);
1117                }
1118            }
1119            if let Some(bool_obj) = obj.get("bool").and_then(|v| v.as_object()) {
1120                for (feat, distv) in bool_obj {
1121                    let mut v = vec![1e-12; p.states.len()];
1122                    if let Some(dist) = distv.as_object() {
1123                        for (i, state) in p.states.iter().enumerate() {
1124                            if let Some(x) = dist.get(state).and_then(|x| x.as_f64()) {
1125                                v[i] = x;
1126                            }
1127                        }
1128                    }
1129                    bool_present.insert(feat.clone(), v);
1130                }
1131            }
1132        }
1133
1134        let mut bool_keys: Vec<String> = bool_present.keys().cloned().collect();
1135        bool_keys.sort();
1136        Self {
1137            len,
1138            bool_present,
1139            bool_keys,
1140        }
1141    }
1142}
1143
1144impl Default for HmmNER {
1145    fn default() -> Self {
1146        Self::new()
1147    }
1148}
1149
1150impl Model for HmmNER {
1151    fn extract_entities(&self, text: &str, _language: Option<&str>) -> Result<Vec<Entity>> {
1152        if text.trim().is_empty() {
1153            return Ok(vec![]);
1154        }
1155
1156        let words: Vec<&str> = text.split_whitespace().collect();
1157        if words.is_empty() {
1158            return Ok(vec![]);
1159        }
1160
1161        let label_indices = self.viterbi(&words);
1162        let entities = self.decode_entities(text, &words, &label_indices);
1163
1164        Ok(entities)
1165    }
1166
1167    fn supported_types(&self) -> Vec<EntityType> {
1168        vec![
1169            EntityType::Person,
1170            EntityType::Organization,
1171            EntityType::Location,
1172            EntityType::Other("MISC".to_string()),
1173        ]
1174    }
1175
1176    fn is_available(&self) -> bool {
1177        true // Always available
1178    }
1179}
1180
// Empty trait impls: `HmmNER` provides no items of its own for either trait.
// `crate::sealed::Sealed` is presumably the sealing supertrait that restricts
// which types may implement the crate's capability traits (confirm).
// `crate::NamedEntityCapable` marks this backend as NER-capable.
impl crate::sealed::Sealed for HmmNER {}
impl crate::NamedEntityCapable for HmmNER {}
1183
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_basic_extraction() {
        let ner = HmmNER::new();
        let text = "John works at Google in California.";
        let char_count = text.chars().count();
        let entities = ner.extract_entities(text, None).unwrap();

        // The heuristic model is not guaranteed to find any particular entity,
        // but whatever it returns must be well-formed.
        for entity in &entities {
            assert!(entity.confidence > 0.0 && entity.confidence <= 1.0);
            assert!(entity.start < entity.end, "entity span must be non-empty");
            assert!(entity.end <= char_count, "entity span must lie within the text");
        }
    }

    #[test]
    fn test_empty_input() {
        let ner = HmmNER::new();
        let entities = ner.extract_entities("", None).unwrap();
        assert!(entities.is_empty());
    }

    #[test]
    fn test_whitespace_only_input() {
        let ner = HmmNER::new();
        let entities = ner.extract_entities(" \t\n  ", None).unwrap();
        assert!(entities.is_empty());
    }

    #[test]
    fn test_viterbi_path_length() {
        let ner = HmmNER::new();
        let words = vec!["John", "works", "at", "Google"];
        let path = ner.viterbi(&words);

        assert_eq!(path.len(), words.len());
    }

    #[test]
    fn test_bio_constraints() {
        let ner = HmmNER::new();

        // I-PER should not follow O with high probability
        let i_per = ner.state_to_idx["I-PER"];
        let o = ner.state_to_idx["O"];
        let b_per = ner.state_to_idx["B-PER"];

        // Transition O -> I-PER should be very low
        assert!(ner.transitions[o][i_per] < 0.01);

        // Transition B-PER -> I-PER should be reasonable
        assert!(ner.transitions[b_per][i_per] > 0.1);
    }

    #[test]
    fn test_emission_heuristics() {
        let ner = HmmNER::new();

        let b_per_idx = ner.state_to_idx["B-PER"];

        // Capitalized word should have higher entity probability
        let cap_prob = ner.emission_prob(b_per_idx, "John");
        let lower_prob = ner.emission_prob(b_per_idx, "john");

        assert!(cap_prob >= lower_prob);
    }

    #[test]
    fn test_training() {
        let mut ner = HmmNER::new();

        let sentences: Vec<(&[&str], &[&str])> = vec![
            (
                &["John", "works", "at", "Google"][..],
                &["B-PER", "O", "O", "B-ORG"][..],
            ),
            (
                &["Mary", "lives", "in", "Paris"][..],
                &["B-PER", "O", "O", "B-LOC"][..],
            ),
        ];

        ner.train(&sentences);

        // After training, transitions should be updated
        let b_per = ner.state_to_idx["B-PER"];
        let o = ner.state_to_idx["O"];

        // B-PER -> O should be high (entities followed by non-entities)
        assert!(ner.transitions[b_per][o] > 0.3);
    }

    #[test]
    fn test_unicode_offsets() {
        let ner = HmmNER::new();
        let text = "北京 Google Inc.";
        let char_count = text.chars().count();

        let entities = ner.extract_entities(text, None).unwrap();

        for entity in &entities {
            assert!(entity.start <= entity.end);
            assert!(entity.end <= char_count);
        }
    }

    #[test]
    fn test_config() {
        let config = HmmConfig {
            smoothing: 1e-5,
            ..Default::default()
        };

        let ner = HmmNER::with_config(config);
        assert_eq!(ner.config.smoothing, 1e-5);
    }

    #[test]
    fn test_supported_types() {
        let ner = HmmNER::new();
        let types = ner.supported_types();

        assert!(types.contains(&EntityType::Person));
        assert!(types.contains(&EntityType::Organization));
        assert!(types.contains(&EntityType::Location));
    }

    /// Test that duplicate entity texts get correct offsets.
    #[test]
    fn test_duplicate_entity_offsets() {
        // Test token position calculation directly
        let text = "Google bought Google for $1 billion.";
        let tokens: Vec<&str> = text.split_whitespace().collect();
        let positions = HmmNER::calculate_token_positions(text, &tokens);

        // First "Google" at byte 0-6
        assert_eq!(
            positions[0],
            (0, 6),
            "First 'Google' should be at bytes 0-6"
        );
        // Second "Google" at byte 14-20
        assert_eq!(
            positions[2],
            (14, 20),
            "Second 'Google' should be at bytes 14-20"
        );
    }

    /// Test token position calculation with Unicode.
    #[test]
    fn test_token_positions_unicode() {
        let text = "東京 Tokyo 東京";
        let tokens: Vec<&str> = text.split_whitespace().collect();
        let positions = HmmNER::calculate_token_positions(text, &tokens);

        // Each 東京 is 6 bytes (2 chars × 3 bytes each)
        assert_eq!(positions[0], (0, 6), "First '東京' at bytes 0-6");
        assert_eq!(positions[1], (7, 12), "Tokyo at bytes 7-12");
        assert_eq!(positions[2], (13, 19), "Second '東京' at bytes 13-19");
    }

    /// Token positions must follow the real separators, not assume single spaces.
    #[test]
    fn test_token_positions_irregular_whitespace() {
        let text = "New   York\tCity";
        let tokens: Vec<&str> = text.split_whitespace().collect();
        let positions = HmmNER::calculate_token_positions(text, &tokens);

        assert_eq!(positions, vec![(0, 3), (6, 10), (11, 15)]);
    }
}