1use crate::{Entity, EntityType, Model, Result};
98use std::collections::HashMap;
99
100#[cfg(feature = "bundled-hmm-params")]
101use std::sync::OnceLock;
102
103#[cfg(feature = "bundled-hmm-params")]
104use serde_json as _;
105
/// Parameters parsed from the bundled `hmm_params.json` file.
#[derive(Debug, Clone)]
struct HmmParams {
    /// State labels in index order (expected to match `HmmNER::states`).
    states: Vec<String>,
    /// Initial state probabilities, parallel to `states`.
    initial: Vec<f64>,
    /// Transition matrix indexed `transitions[from][to]`.
    transitions: Vec<Vec<f64>>,
    /// Raw JSON `backoff` section; converted later by `HmmBackoff::from_params`.
    backoff: serde_json::Value,
}
113
/// Back-off emission statistics used to score words absent from the lexicon.
#[derive(Debug, Clone)]
struct HmmBackoff {
    /// Length-bucket key (see `HmmNER::len_bucket`) -> per-state probability.
    len: HashMap<String, Vec<f64>>,
    /// Boolean surface-feature name -> per-state P(feature present).
    bool_present: HashMap<String, Vec<f64>>,
    /// Sorted feature names, so scoring iterates deterministically.
    bool_keys: Vec<String>,
}
123
/// Tunable knobs for [`HmmNER`].
#[derive(Debug, Clone)]
pub struct HmmConfig {
    /// Additive probability floor used to avoid zero probabilities.
    pub smoothing: f64,
    /// NOTE(review): not read anywhere in this file — Viterbi always decodes
    /// in log space regardless. Confirm intended use elsewhere.
    pub use_log_probs: bool,
    /// Multiplier applied to back-off emission scores of non-"O" states for
    /// unknown words (< 1.0 dampens non-entity-unknowns).
    pub non_o_emission_scale: f64,
    /// When true, adopt bundled initial/transition matrices if available.
    pub use_bundled_dynamics: bool,
}
138
139impl Default for HmmConfig {
140 fn default() -> Self {
141 Self {
142 smoothing: 1e-10,
143 use_log_probs: true,
144 non_o_emission_scale: 0.5,
146 use_bundled_dynamics: true,
149 }
150 }
151}
152
/// Hidden-Markov-model NER tagger over a fixed BIO state set
/// (PER/ORG/LOC/MISC), with a seed emission lexicon, capitalization
/// heuristics, and optional bundled back-off parameters.
#[derive(Debug)]
pub struct HmmNER {
    /// Runtime knobs (smoothing, back-off scaling, bundled dynamics).
    config: HmmConfig,
    /// BIO state labels; "O" sits at index 0.
    states: Vec<String>,
    /// Reverse lookup from label string to state index.
    state_to_idx: HashMap<String, usize>,
    /// Transition probabilities indexed `transitions[from][to]`.
    transitions: Vec<Vec<f64>>,
    /// Initial state distribution, parallel to `states`.
    initial: Vec<f64>,
    /// (state index, lowercase word) -> emission probability lexicon.
    emissions: HashMap<(usize, String), f64>,
    /// Never read in this file; kept for interface stability.
    #[allow(dead_code)] vocab: HashMap<String, usize>,
    /// Feature back-off tables; `Some` only when bundled params loaded.
    backoff: Option<HmmBackoff>,
}
179
180impl HmmNER {
    /// Creates a tagger with the default configuration; bundled parameters
    /// are adopted when available and shape-compatible.
    #[must_use]
    pub fn new() -> Self {
        Self::with_config(HmmConfig::default())
    }

    /// Creates a tagger with the default configuration that ignores bundled
    /// parameters entirely, relying only on the built-in heuristics.
    #[must_use]
    pub fn new_heuristic() -> Self {
        Self::with_config_no_bundled(HmmConfig::default())
    }

    /// Creates a tagger with `config`, allowing bundled parameters.
    #[must_use]
    pub fn with_config(config: HmmConfig) -> Self {
        Self::with_config_internal(config, true)
    }

    /// Creates a tagger with `config`, never loading bundled parameters.
    #[must_use]
    pub fn with_config_no_bundled(config: HmmConfig) -> Self {
        Self::with_config_internal(config, false)
    }
206
    /// Builds the model: fixed BIO state set, heuristic dynamics and seed
    /// emissions, then (optionally) overlays bundled parameters.
    fn with_config_internal(config: HmmConfig, allow_bundled: bool) -> Self {
        // Fixed BIO tag set over PER/ORG/LOC/MISC; "O" must be index 0
        // (emission_prob treats state_idx == 0 as the "O" state).
        let states = vec![
            "O".to_string(),
            "B-PER".to_string(),
            "I-PER".to_string(),
            "B-ORG".to_string(),
            "I-ORG".to_string(),
            "B-LOC".to_string(),
            "I-LOC".to_string(),
            "B-MISC".to_string(),
            "I-MISC".to_string(),
        ];

        let state_to_idx: HashMap<String, usize> = states
            .iter()
            .enumerate()
            .map(|(i, s)| (s.clone(), i))
            .collect();

        let n = states.len();

        let mut transitions = vec![vec![0.0; n]; n];
        Self::init_transitions(&mut transitions, &states, &config);

        // Hand-set priors: sentences mostly start with "O" or a "B-" tag;
        // starting inside an entity ("I-") is effectively forbidden.
        let mut initial = vec![0.0; n];
        for (i, state) in states.iter().enumerate() {
            if state == "O" {
                initial[i] = 0.4; } else if state.starts_with("B-") {
                initial[i] = 0.15; } else if state.starts_with("I-") {
                initial[i] = config.smoothing; }
        }
        Self::normalize(&mut initial);

        let emissions = Self::init_emissions(&states, &state_to_idx);

        let mut m = Self {
            config,
            states,
            state_to_idx,
            transitions,
            initial,
            emissions,
            vocab: HashMap::new(),
            backoff: None,
        };

        // Overlay bundled parameters only when their shape exactly matches
        // our fixed state set (guards against stale/incompatible JSON).
        if allow_bundled {
            if let Some(p) = Self::bundled_params() {
                if p.states == m.states
                    && p.initial.len() == m.states.len()
                    && p.transitions.len() == m.states.len()
                    && p.transitions.iter().all(|r| r.len() == m.states.len())
                {
                    let backoff = HmmBackoff::from_params(&p);
                    m.backoff = Some(backoff);
                    // The env var ANNO_HMM_USE_BUNDLED_DYNAMICS ("1"/"true"/
                    // "yes", case-insensitive) can force bundled dynamics on
                    // in addition to the config flag.
                    let use_dynamics_env = std::env::var("ANNO_HMM_USE_BUNDLED_DYNAMICS")
                        .ok()
                        .is_some_and(|v| {
                            let s = v.trim();
                            s == "1"
                                || s.eq_ignore_ascii_case("true")
                                || s.eq_ignore_ascii_case("yes")
                        });
                    let use_dynamics = m.config.use_bundled_dynamics || use_dynamics_env;
                    if use_dynamics {
                        m.initial = p.initial.clone();
                        m.transitions = p.transitions.clone();
                    }
                }
            }
        }

        m
    }
294
    /// Loads and caches the compile-time bundled parameter file
    /// (`hmm_params.json`) when the `bundled-hmm-params` feature is enabled.
    ///
    /// Returns `None` if the feature is off, or the JSON fails to parse or
    /// has an unexpected shape. Parsing runs at most once per process.
    fn bundled_params() -> Option<HmmParams> {
        #[cfg(feature = "bundled-hmm-params")]
        {
            // OnceLock caches the parse result (including a failed parse,
            // stored as None) for the lifetime of the process.
            static ONCE: OnceLock<Option<HmmParams>> = OnceLock::new();
            return ONCE
                .get_or_init(|| {
                    let s = include_str!("hmm_params.json");
                    let v: serde_json::Value = serde_json::from_str(s).ok()?;
                    // Each field is validated; any missing/mis-typed entry
                    // aborts via `?` and yields None.
                    let states = v
                        .get("states")?
                        .as_array()?
                        .iter()
                        .map(|x| x.as_str().map(|s| s.to_string()))
                        .collect::<Option<Vec<_>>>()?;
                    let initial = v
                        .get("initial")?
                        .as_array()?
                        .iter()
                        .map(|x| x.as_f64())
                        .collect::<Option<Vec<_>>>()?;
                    let transitions = v
                        .get("transitions")?
                        .as_array()?
                        .iter()
                        .map(|row| {
                            row.as_array()?
                                .iter()
                                .map(|x| x.as_f64())
                                .collect::<Option<Vec<_>>>()
                        })
                        .collect::<Option<Vec<_>>>()?;
                    let backoff = v.get("backoff")?.clone();
                    Some(HmmParams {
                        states,
                        initial,
                        transitions,
                        backoff,
                    })
                })
                .clone();
        }
        #[cfg(not(feature = "bundled-hmm-params"))]
        {
            None
        }
    }
341
342 fn init_transitions(trans: &mut [Vec<f64>], states: &[String], config: &HmmConfig) {
344 let n = states.len();
345
346 for i in 0..n {
347 for j in 0..n {
348 let from = &states[i];
349 let to = &states[j];
350
351 if let Some(entity_type) = to.strip_prefix("I-") {
353 let valid_b = format!("B-{}", entity_type);
354 let valid_i = format!("I-{}", entity_type);
355
356 if from == &valid_b || from == &valid_i {
357 trans[i][j] = 0.3; } else {
359 trans[i][j] = config.smoothing; }
361 } else if to.starts_with("B-") {
362 trans[i][j] = 0.1; } else {
364 trans[i][j] = 0.5; }
367 }
368
369 Self::normalize(&mut trans[i]);
371 }
372 }
373
374 fn init_emissions(
379 _states: &[String],
380 state_to_idx: &HashMap<String, usize>,
381 ) -> HashMap<(usize, String), f64> {
382 let mut emissions = HashMap::new();
383
384 let person_indicators = [
386 "john",
388 "mary",
389 "james",
390 "david",
391 "michael",
392 "robert",
393 "william",
394 "richard",
395 "sarah",
396 "jennifer",
397 "elizabeth",
398 "lisa",
399 "marie",
400 "jane",
401 "emily",
402 "anna",
403 "barack",
404 "donald",
405 "joe",
406 "george",
407 "bill",
408 "hillary",
409 "elon",
410 "jeff",
411 "angela",
412 "vladimir",
413 "emmanuel",
414 "xi",
415 "narendra",
416 "justin",
417 "rishi",
418 "steve",
419 "tim",
420 "mark",
421 "satya",
422 "sundar",
423 "sheryl",
424 "sam",
425 "dario",
426 "obama",
428 "biden",
429 "trump",
430 "bush",
431 "clinton",
432 "reagan",
433 "kennedy",
434 "lincoln",
435 "merkel",
436 "macron",
437 "putin",
438 "jinping",
439 "modi",
440 "trudeau",
441 "sunak",
442 "musk",
443 "bezos",
444 "zuckerberg",
445 "gates",
446 "jobs",
447 "wozniak",
448 "cook",
449 "pichai",
450 "nadella",
451 "altman",
452 "amodei",
453 "hassabis",
454 "hinton",
455 "lecun",
456 "bengio",
457 "smith",
458 "johnson",
459 "williams",
460 "brown",
461 "jones",
462 "garcia",
463 "miller",
464 "davis",
465 "mr",
467 "mrs",
468 "ms",
469 "dr",
470 "prof",
471 "sir",
472 "lord",
473 "lady",
474 "president",
475 "ceo",
476 "chairman",
477 "director",
478 "minister",
479 "senator",
480 "mayor",
481 "governor",
482 "chancellor",
483 "prime",
484 "secretary",
485 "ambassador",
486 "general",
487 "admiral",
488 ];
489
490 let org_indicators = [
492 "google",
494 "apple",
495 "microsoft",
496 "amazon",
497 "facebook",
498 "meta",
499 "tesla",
500 "ibm",
501 "intel",
502 "nvidia",
503 "oracle",
504 "cisco",
505 "adobe",
506 "netflix",
507 "uber",
508 "toyota",
509 "honda",
510 "ford",
511 "chevrolet",
512 "bmw",
513 "mercedes",
514 "audi",
515 "inc",
517 "corp",
518 "ltd",
519 "llc",
520 "co",
521 "plc",
522 "gmbh",
523 "ag",
524 "sa",
525 "company",
526 "corporation",
527 "incorporated",
528 "limited",
529 "group",
530 "holdings",
531 "university",
533 "institute",
534 "college",
535 "academy",
536 "school",
537 "hospital",
538 "foundation",
539 "association",
540 "organization",
541 "committee",
542 "council",
543 "department",
544 "ministry",
545 "agency",
546 "bureau",
547 "commission",
548 "fbi",
550 "cia",
551 "nsa",
552 "nasa",
553 "un",
554 "nato",
555 "who",
556 "imf",
557 "eu",
558 "usa",
559 "parliament",
560 "congress",
561 "senate",
562 "house",
563 "court",
564 "bank",
565 ];
566
567 let loc_indicators = [
569 "new",
571 "york",
572 "california",
573 "texas",
574 "florida",
575 "washington",
576 "chicago",
577 "boston",
578 "seattle",
579 "san",
580 "francisco",
581 "los",
582 "angeles",
583 "las",
584 "vegas",
585 "miami",
586 "denver",
587 "atlanta",
588 "phoenix",
589 "dallas",
590 "houston",
591 "portland",
592 "london",
594 "paris",
595 "berlin",
596 "tokyo",
597 "beijing",
598 "moscow",
599 "sydney",
600 "toronto",
601 "vancouver",
602 "rome",
603 "madrid",
604 "amsterdam",
605 "brussels",
606 "vienna",
607 "seoul",
608 "singapore",
609 "hong",
610 "kong",
611 "dubai",
612 "mumbai",
613 "delhi",
614 "united",
616 "states",
617 "america",
618 "china",
619 "russia",
620 "germany",
621 "france",
622 "japan",
623 "india",
624 "brazil",
625 "canada",
626 "australia",
627 "uk",
628 "britain",
629 "italy",
630 "spain",
631 "mexico",
632 "korea",
633 "taiwan",
634 "vietnam",
635 "thailand",
636 "city",
638 "county",
639 "state",
640 "country",
641 "province",
642 "region",
643 "district",
644 "river",
645 "mountain",
646 "lake",
647 "ocean",
648 "sea",
649 "island",
650 "peninsula",
651 "north",
652 "south",
653 "east",
654 "west",
655 "central",
656 "northern",
657 "southern",
658 ];
659
660 for word in person_indicators {
662 let b_idx = state_to_idx["B-PER"];
663 let i_idx = state_to_idx["I-PER"];
664 emissions.insert((b_idx, word.to_string()), 0.4);
665 emissions.insert((i_idx, word.to_string()), 0.25);
666 }
667
668 for word in org_indicators {
669 let b_idx = state_to_idx["B-ORG"];
670 let i_idx = state_to_idx["I-ORG"];
671 emissions.insert((b_idx, word.to_string()), 0.4);
672 emissions.insert((i_idx, word.to_string()), 0.25);
673 }
674
675 for word in loc_indicators {
676 let b_idx = state_to_idx["B-LOC"];
677 let i_idx = state_to_idx["I-LOC"];
678 emissions.insert((b_idx, word.to_string()), 0.4);
679 emissions.insert((i_idx, word.to_string()), 0.25);
680 }
681
682 emissions
683 }
684
685 fn normalize(vec: &mut [f64]) {
687 let sum: f64 = vec.iter().sum();
688 if sum > 0.0 {
689 for v in vec.iter_mut() {
690 *v /= sum;
691 }
692 }
693 }
694
    /// P(word | state). Resolution order:
    /// 1. exact lowercase lexicon hit,
    /// 2. bundled back-off model (naive-Bayes product over a length bucket
    ///    and boolean surface features), when loaded,
    /// 3. hand-written capitalization/suffix heuristics.
    fn emission_prob(&self, state_idx: usize, word: &str) -> f64 {
        let lower = word.to_lowercase();

        if let Some(&prob) = self.emissions.get(&(state_idx, lower.clone())) {
            return prob;
        }

        if let Some(b) = self.backoff.as_ref() {
            // Accumulate log-probs of independent features, floored at 1e-12
            // so a missing table entry cannot produce -inf.
            let lb = Self::len_bucket(word);
            let mut sum_log = 0.0f64;
            if let Some(p) = b.len.get(lb).and_then(|v| v.get(state_idx).copied()) {
                sum_log += p.max(1e-12).ln();
            } else {
                sum_log += (1e-12f64).ln();
            }
            let feats = Self::bool_features(word);
            for k in &b.bool_keys {
                let present = feats.get(k.as_str()).copied().unwrap_or(false);
                // Clamp keeps the complement (1 - p) strictly positive too.
                let p_present = b
                    .bool_present
                    .get(k)
                    .and_then(|v| v.get(state_idx).copied())
                    .unwrap_or(1e-12)
                    .clamp(1e-12, 1.0 - 1e-12);
                let p = if present { p_present } else { 1.0 - p_present };
                sum_log += p.max(1e-12).ln();
            }
            let mut score = sum_log.exp().max(self.config.smoothing);
            // Dampen non-"O" states (index 0 is "O") for unknown words.
            if state_idx != 0 {
                score *= self.config.non_o_emission_scale.max(1e-6);
            }
            return score.max(self.config.smoothing);
        }

        // --- Heuristic fallback (no bundled back-off available) ---
        let state = &self.states[state_idx];
        let is_capitalized = word.chars().next().is_some_and(|c| c.is_uppercase());
        let is_all_caps =
            word.chars().all(|c| c.is_uppercase() || !c.is_alphabetic()) && word.len() > 1;
        let has_digit = word.chars().any(|c| c.is_ascii_digit());
        let is_title_case = is_capitalized && word.len() > 1;

        let org_suffixes = [
            "Inc", "Corp", "Ltd", "LLC", "Co", "Company", "Inc.", "Corp.", "Ltd.",
        ];
        let is_org_suffix = org_suffixes.contains(&word);

        // "O" prefers lowercase words; title-cased tokens are penalized.
        if state == "O" {
            if !is_capitalized {
                return 0.7;
            }
            if has_digit {
                return 0.5;
            }
            if is_title_case {
                return 0.15;
            }
            return 0.4;
        }

        // Entity states reward capitalization, ordered PER > LOC > ORG,
        // with special-cased org suffixes and all-caps acronyms.
        if state.starts_with("B-") || state.starts_with("I-") {
            let entity_type = &state[2..];

            if entity_type == "ORG" && is_org_suffix {
                return 0.8;
            }

            if is_all_caps && entity_type == "ORG" {
                return 0.6;
            }

            if is_title_case && !has_digit {
                if entity_type == "PER" {
                    return 0.55; } else if entity_type == "LOC" {
                    return 0.45; } else if entity_type == "ORG" {
                    return 0.35; }
                return 0.4;
            }

            if is_capitalized && !has_digit {
                return 0.3;
            }

            return self.config.smoothing;
        }

        self.config.smoothing
    }
804
    /// Standard Viterbi decoding in log space; returns the most likely state
    /// index per token. O(n * m^2) for n tokens and m states.
    fn viterbi(&self, words: &[&str]) -> Vec<usize> {
        if words.is_empty() {
            return vec![];
        }

        let n = words.len();
        let m = self.states.len();

        // Map zero probabilities to -inf so they can never win the max.
        let log = |p: f64| if p > 0.0 { p.ln() } else { f64::NEG_INFINITY };

        // dp[t][j]: best log-score of any path ending in state j at token t.
        let mut dp = vec![vec![f64::NEG_INFINITY; m]; n];
        let mut backptr = vec![vec![0usize; m]; n];

        // Initialization from the initial distribution.
        for (j, cell) in dp[0].iter_mut().enumerate().take(m) {
            *cell = log(self.initial[j]) + log(self.emission_prob(j, words[0]));
        }

        // Recursion: pick the best predecessor for each (t, j).
        for t in 1..n {
            for j in 0..m {
                let emit = log(self.emission_prob(j, words[t]));

                for i in 0..m {
                    let trans = log(self.transitions[i][j]);
                    let score = dp[t - 1][i] + trans + emit;

                    if score > dp[t][j] {
                        dp[t][j] = score;
                        backptr[t][j] = i;
                    }
                }
            }
        }

        // Termination: best final state…
        let mut best_state = 0;
        let mut best_score = f64::NEG_INFINITY;
        for (j, &score) in dp[n - 1].iter().enumerate() {
            if score > best_score {
                best_score = score;
                best_state = j;
            }
        }

        // …then walk the back-pointers to recover the full path.
        let mut path = vec![0usize; n];
        path[n - 1] = best_state;
        for t in (0..n - 1).rev() {
            path[t] = backptr[t + 1][path[t + 1]];
        }

        path
    }
862
    /// Converts a Viterbi label sequence back into `Entity` spans with
    /// character offsets into `text`.
    fn decode_entities(&self, text: &str, words: &[&str], labels: &[usize]) -> Vec<Entity> {
        use crate::offset::SpanConverter;

        let converter = SpanConverter::new(text);
        let mut entities = Vec::new();

        let token_positions: Vec<(usize, usize)> = Self::calculate_token_positions(text, words);

        // (start token idx, end token idx, type, accumulated words) of the
        // entity currently being built, if any.
        let mut current: Option<(usize, usize, EntityType, Vec<&str>)> = None;

        for (i, (&label_idx, &word)) in labels.iter().zip(words.iter()).enumerate() {
            let label = &self.states[label_idx];

            if label.starts_with("B-") {
                // A new entity begins: flush any entity in progress first.
                if let Some((start_idx, end_idx, entity_type, entity_words)) = current.take() {
                    Self::push_entity_from_positions(
                        &converter,
                        &token_positions,
                        start_idx,
                        end_idx,
                        &entity_words,
                        entity_type,
                        &mut entities,
                    );
                }

                let entity_type_str = label
                    .strip_prefix("B-")
                    .or_else(|| label.strip_prefix("I-"))
                    .expect("label should start with B- or I-");
                let entity_type = match entity_type_str {
                    "PER" => EntityType::Person,
                    "ORG" => EntityType::Organization,
                    "LOC" => EntityType::Location,
                    other => EntityType::Other(other.to_string()),
                };
                current = Some((i, i, entity_type, vec![word]));
            } else if label.starts_with("I-") && current.is_some() {
                // NOTE(review): the I- tag's entity type is not compared with
                // the open entity's type, so e.g. I-LOC following B-PER
                // extends the PER span — confirm this is intended.
                if let Some((_, ref mut end_idx, _, ref mut entity_words)) = current {
                    entity_words.push(word);
                    *end_idx = i;
                }
            } else {
                // "O" (or a dangling I- with nothing open) closes any entity.
                if let Some((start_idx, end_idx, entity_type, entity_words)) = current.take() {
                    Self::push_entity_from_positions(
                        &converter,
                        &token_positions,
                        start_idx,
                        end_idx,
                        &entity_words,
                        entity_type,
                        &mut entities,
                    );
                }
            }
        }

        // Flush an entity that runs to the end of the input.
        if let Some((start_idx, end_idx, entity_type, entity_words)) = current {
            Self::push_entity_from_positions(
                &converter,
                &token_positions,
                start_idx,
                end_idx,
                &entity_words,
                entity_type,
                &mut entities,
            );
        }

        entities
    }
944
945 fn calculate_token_positions(text: &str, tokens: &[&str]) -> Vec<(usize, usize)> {
947 let mut positions = Vec::with_capacity(tokens.len());
948 let mut byte_pos = 0;
949
950 for token in tokens {
951 if let Some(rel_pos) = text[byte_pos..].find(token) {
953 let start = byte_pos + rel_pos;
954 let end = start + token.len();
955 positions.push((start, end));
956 byte_pos = end; } else {
958 positions.push((byte_pos, byte_pos));
960 }
961 }
962
963 positions
964 }
965
966 fn push_entity_from_positions(
968 converter: &crate::offset::SpanConverter,
969 positions: &[(usize, usize)],
970 start_token_idx: usize,
971 end_token_idx: usize,
972 words: &[&str],
973 entity_type: EntityType,
974 entities: &mut Vec<Entity>,
975 ) {
976 if start_token_idx >= positions.len() || end_token_idx >= positions.len() {
977 return;
978 }
979
980 let byte_start = positions[start_token_idx].0;
981 let byte_end = positions[end_token_idx].1;
982 let char_start = converter.byte_to_char(byte_start);
983 let char_end = converter.byte_to_char(byte_end);
984 let entity_text = words.join(" ");
985
986 entities.push(Entity::new(
987 entity_text,
988 entity_type,
989 char_start,
990 char_end,
991 0.65, ));
993 }
994
    /// Supervised re-estimation from `(words, BIO tags)` sentence pairs.
    ///
    /// Initial and transition probabilities are re-derived from counts with
    /// additive smoothing; emissions become relative frequencies per state.
    /// Tags not in the state set are skipped, and tokens beyond the shorter
    /// of `words`/`tags` are ignored (the two slices are zipped).
    ///
    /// NOTE(review): trained emissions are inserted over the heuristic seed
    /// lexicon, so seed entries for words never seen in training survive —
    /// confirm this mixing is intended.
    pub fn train(&mut self, sentences: &[(&[&str], &[&str])]) {
        let n = self.states.len();
        let mut trans_counts = vec![vec![0usize; n]; n];
        let mut initial_counts = vec![0usize; n];
        let mut emission_counts: HashMap<(usize, String), usize> = HashMap::new();
        let mut state_counts = vec![0usize; n];

        // Counting pass.
        for (words, tags) in sentences {
            if tags.is_empty() {
                continue;
            }

            if let Some(&idx) = self.state_to_idx.get(tags[0]) {
                initial_counts[idx] += 1;
            }

            for (i, (word, tag)) in words.iter().zip(tags.iter()).enumerate() {
                if let Some(&tag_idx) = self.state_to_idx.get(*tag) {
                    *emission_counts
                        .entry((tag_idx, word.to_lowercase()))
                        .or_insert(0) += 1;
                    state_counts[tag_idx] += 1;

                    if i > 0 {
                        if let Some(&prev_idx) = self.state_to_idx.get(tags[i - 1]) {
                            trans_counts[prev_idx][tag_idx] += 1;
                        }
                    }
                }
            }
        }

        // Smoothed MLE for the initial distribution.
        let total_initial: f64 =
            initial_counts.iter().sum::<usize>() as f64 + self.config.smoothing * n as f64;
        for (i, &count) in initial_counts.iter().enumerate() {
            self.initial[i] = (count as f64 + self.config.smoothing) / total_initial;
        }

        // Smoothed MLE per transition row.
        for (i, row) in trans_counts.iter().enumerate().take(n) {
            let total: f64 = row.iter().sum::<usize>() as f64 + self.config.smoothing * n as f64;
            for (j, &count) in row.iter().enumerate().take(n) {
                self.transitions[i][j] = (count as f64 + self.config.smoothing) / total;
            }
        }

        // Unsmoothed relative frequencies for emissions.
        for ((state_idx, word), count) in emission_counts {
            let total = state_counts[state_idx] as f64;
            if total > 0.0 {
                self.emissions
                    .insert((state_idx, word), count as f64 / total);
            }
        }
    }
1058
1059 fn len_bucket(word: &str) -> &'static str {
1060 let n = word.chars().count();
1061 if n <= 1 {
1062 "len:1"
1063 } else if n == 2 {
1064 "len:2"
1065 } else if n == 3 {
1066 "len:3"
1067 } else if (4..=5).contains(&n) {
1068 "len:4_5"
1069 } else if (6..=8).contains(&n) {
1070 "len:6_8"
1071 } else {
1072 "len:9p"
1073 }
1074 }
1075
1076 fn bool_features(word: &str) -> HashMap<&'static str, bool> {
1077 let is_capitalized = word.chars().next().is_some_and(|c| c.is_uppercase());
1078 let is_all_caps = word.chars().all(|c| c.is_uppercase() || !c.is_alphabetic())
1079 && word.chars().count() > 1;
1080 let is_digit = !word.is_empty() && word.chars().all(|c| c.is_ascii_digit());
1081 let has_digit = word.chars().any(|c| c.is_ascii_digit());
1082 let has_hyphen = word.contains('-');
1083 let has_dot = word.contains('.');
1084 let mut m = HashMap::new();
1085 m.insert("is_capitalized", is_capitalized);
1086 m.insert("is_all_caps", is_all_caps);
1087 m.insert("is_digit", is_digit);
1088 m.insert("has_digit", has_digit);
1089 m.insert("has_hyphen", has_hyphen);
1090 m.insert("has_dot", has_dot);
1091 m
1092 }
1093}
1094
1095impl HmmBackoff {
1096 fn from_params(p: &HmmParams) -> Self {
1097 let mut len: HashMap<String, Vec<f64>> = HashMap::new();
1103 let mut bool_present: HashMap<String, Vec<f64>> = HashMap::new();
1104
1105 if let Some(obj) = p.backoff.as_object() {
1106 if let Some(len_obj) = obj.get("len").and_then(|v| v.as_object()) {
1107 for (bucket, distv) in len_obj {
1108 let mut v = vec![1e-12; p.states.len()];
1109 if let Some(dist) = distv.as_object() {
1110 for (i, state) in p.states.iter().enumerate() {
1111 if let Some(x) = dist.get(state).and_then(|x| x.as_f64()) {
1112 v[i] = x;
1113 }
1114 }
1115 }
1116 len.insert(bucket.clone(), v);
1117 }
1118 }
1119 if let Some(bool_obj) = obj.get("bool").and_then(|v| v.as_object()) {
1120 for (feat, distv) in bool_obj {
1121 let mut v = vec![1e-12; p.states.len()];
1122 if let Some(dist) = distv.as_object() {
1123 for (i, state) in p.states.iter().enumerate() {
1124 if let Some(x) = dist.get(state).and_then(|x| x.as_f64()) {
1125 v[i] = x;
1126 }
1127 }
1128 }
1129 bool_present.insert(feat.clone(), v);
1130 }
1131 }
1132 }
1133
1134 let mut bool_keys: Vec<String> = bool_present.keys().cloned().collect();
1135 bool_keys.sort();
1136 Self {
1137 len,
1138 bool_present,
1139 bool_keys,
1140 }
1141 }
1142}
1143
// `Default` delegates to `new()`, i.e. default config with bundled
// parameters allowed.
impl Default for HmmNER {
    fn default() -> Self {
        Self::new()
    }
}
1149
1150impl Model for HmmNER {
1151 fn extract_entities(&self, text: &str, _language: Option<&str>) -> Result<Vec<Entity>> {
1152 if text.trim().is_empty() {
1153 return Ok(vec![]);
1154 }
1155
1156 let words: Vec<&str> = text.split_whitespace().collect();
1157 if words.is_empty() {
1158 return Ok(vec![]);
1159 }
1160
1161 let label_indices = self.viterbi(&words);
1162 let entities = self.decode_entities(text, &words, &label_indices);
1163
1164 Ok(entities)
1165 }
1166
1167 fn supported_types(&self) -> Vec<EntityType> {
1168 vec![
1169 EntityType::Person,
1170 EntityType::Organization,
1171 EntityType::Location,
1172 EntityType::Other("MISC".to_string()),
1173 ]
1174 }
1175
1176 fn is_available(&self) -> bool {
1177 true }
1179}
1180
// Marker-trait impls (both traits defined elsewhere in the crate): `Sealed`
// participates in the crate's sealing scheme; `NamedEntityCapable` flags
// NER capability for `HmmNER`.
impl crate::sealed::Sealed for HmmNER {}
impl crate::NamedEntityCapable for HmmNER {}
1183
#[cfg(test)]
mod tests {
    use super::*;

    // End-to-end extraction should succeed and yield confidences in (0, 1].
    #[test]
    fn test_basic_extraction() {
        let ner = HmmNER::new();
        let entities = ner
            .extract_entities("John works at Google in California.", None)
            .unwrap();

        for entity in &entities {
            assert!(entity.confidence > 0.0 && entity.confidence <= 1.0);
        }
    }

    #[test]
    fn test_empty_input() {
        let ner = HmmNER::new();
        let entities = ner.extract_entities("", None).unwrap();
        assert!(entities.is_empty());
    }

    // Viterbi must emit exactly one label per input token.
    #[test]
    fn test_viterbi_path_length() {
        let ner = HmmNER::new();
        let words = vec!["John", "works", "at", "Google"];
        let path = ner.viterbi(&words);

        assert_eq!(path.len(), words.len());
    }

    // BIO validity: O -> I-PER is (near-)forbidden; B-PER -> I-PER is likely.
    #[test]
    fn test_bio_constraints() {
        let ner = HmmNER::new();

        let i_per = ner.state_to_idx["I-PER"];
        let o = ner.state_to_idx["O"];
        let b_per = ner.state_to_idx["B-PER"];

        assert!(ner.transitions[o][i_per] < 0.01);

        assert!(ner.transitions[b_per][i_per] > 0.1);
    }

    // Capitalization must never make a B-PER emission less likely.
    #[test]
    fn test_emission_heuristics() {
        let ner = HmmNER::new();

        let _o_idx = ner.state_to_idx["O"];
        let b_per_idx = ner.state_to_idx["B-PER"];

        let cap_prob = ner.emission_prob(b_per_idx, "John");
        let lower_prob = ner.emission_prob(b_per_idx, "john");

        assert!(cap_prob >= lower_prob);
    }

    // Supervised counts should reshape the transition matrix.
    #[test]
    fn test_training() {
        let mut ner = HmmNER::new();

        let sentences: Vec<(&[&str], &[&str])> = vec![
            (
                &["John", "works", "at", "Google"][..],
                &["B-PER", "O", "O", "B-ORG"][..],
            ),
            (
                &["Mary", "lives", "in", "Paris"][..],
                &["B-PER", "O", "O", "B-LOC"][..],
            ),
        ];

        ner.train(&sentences);

        let b_per = ner.state_to_idx["B-PER"];
        let o = ner.state_to_idx["O"];

        assert!(ner.transitions[b_per][o] > 0.3);
    }

    // Char offsets must stay within the char count of multibyte text.
    #[test]
    fn test_unicode_offsets() {
        let ner = HmmNER::new();
        let text = "北京 Google Inc.";
        let char_count = text.chars().count();

        let entities = ner.extract_entities(text, None).unwrap();

        for entity in &entities {
            assert!(entity.start <= entity.end);
            assert!(entity.end <= char_count);
        }
    }

    #[test]
    fn test_config() {
        let config = HmmConfig {
            smoothing: 1e-5,
            ..Default::default()
        };

        let ner = HmmNER::with_config(config);
        assert_eq!(ner.config.smoothing, 1e-5);
    }

    #[test]
    fn test_supported_types() {
        let ner = HmmNER::new();
        let types = ner.supported_types();

        assert!(types.contains(&EntityType::Person));
        assert!(types.contains(&EntityType::Organization));
        assert!(types.contains(&EntityType::Location));
    }

    // Repeated tokens must map to successive byte offsets, not all to the
    // first occurrence.
    #[test]
    fn test_duplicate_entity_offsets() {
        let text = "Google bought Google for $1 billion.";
        let tokens: Vec<&str> = text.split_whitespace().collect();
        let positions = HmmNER::calculate_token_positions(text, &tokens);

        assert_eq!(
            positions[0],
            (0, 6),
            "First 'Google' should be at bytes 0-6"
        );
        assert_eq!(
            positions[2],
            (14, 20),
            "Second 'Google' should be at bytes 14-20"
        );
    }

    // Byte offsets with multibyte (CJK) tokens.
    #[test]
    fn test_token_positions_unicode() {
        let text = "東京 Tokyo 東京";
        let tokens: Vec<&str> = text.split_whitespace().collect();
        let positions = HmmNER::calculate_token_positions(text, &tokens);

        assert_eq!(positions[0], (0, 6), "First '東京' at bytes 0-6");
        assert_eq!(positions[1], (7, 12), "Tokyo at bytes 7-12");
        assert_eq!(positions[2], (13, 19), "Second '東京' at bytes 13-19");
    }
}