1use crate::{Entity, EntityType, Model, Result};
124use std::collections::HashMap;
125#[cfg(feature = "bundled-crf-weights")]
126use std::sync::OnceLock;
127
/// Linear-chain CRF-style named-entity recognizer.
///
/// Scores (feature, label) pairs against a flat weight table and decodes the
/// best BIO label sequence with Viterbi.
pub struct CrfNER {
    /// Weights keyed as `"<feature>:<label>"` for emissions and
    /// `"trans:<prev>-><curr>"` for transitions.
    weights: HashMap<String, f64>,
    /// Known surface forms per entity type; matched case-insensitively (ASCII).
    gazetteers: HashMap<EntityType, Vec<String>>,
    /// BIO label inventory used by the decoder (e.g. "O", "B-PER", "I-PER").
    labels: Vec<String>,
    /// Feature templates; only `InGazetteer` variants are consulted directly
    /// by the feature extractor (the rest are hard-coded there).
    templates: Vec<FeatureTemplate>,
}
142
/// Feature template kinds for the CRF feature extractor.
///
/// NOTE(review): `extract_features` currently hard-codes its word/shape/affix
/// features and only consults the `InGazetteer` variants; the other variants
/// describe the intended template set — confirm before relying on them.
#[derive(Debug, Clone)]
pub enum FeatureTemplate {
    /// Current word identity.
    Word,
    /// Word at the given relative offset.
    WordAt(i32),
    /// Capitalization/digit shape of the current word.
    Shape,
    /// Shape at the given relative offset.
    ShapeAt(i32),
    /// First `n` characters of the word.
    Prefix(usize),
    /// Last `n` characters of the word.
    Suffix(usize),
    /// Membership of the word in the gazetteer for this entity type.
    InGazetteer(EntityType),
    /// Previous predicted label.
    PrevLabel,
    /// Label-bigram feature.
    LabelBigram,
    /// Conjunction of word and label.
    WordLabel,
}
167
// `CrfNER::default()` is the fully-initialized model (gazetteers, labels,
// templates, and whatever weights `new` selects).
impl Default for CrfNER {
    fn default() -> Self {
        Self::new()
    }
}
173
174impl CrfNER {
175 #[must_use]
177 pub fn new() -> Self {
178 let mut gazetteers = HashMap::new();
180
181 gazetteers.insert(
182 EntityType::Person,
183 vec![
184 "John",
186 "Mary",
187 "James",
188 "Robert",
189 "Michael",
190 "David",
191 "William",
192 "Richard",
193 "Joseph",
194 "Thomas",
195 "Elizabeth",
196 "Jennifer",
197 "Linda",
198 "Barbara",
199 "Susan",
200 "Jessica",
201 "Sarah",
202 "Karen",
203 "Nancy",
204 "Margaret",
205 "Dr",
207 "Mr",
208 "Mrs",
209 "Ms",
210 "Prof",
211 "President",
212 "CEO",
213 "Senator",
214 ]
215 .into_iter()
216 .map(String::from)
217 .collect(),
218 );
219
220 gazetteers.insert(
221 EntityType::Location,
222 vec![
223 "USA",
225 "UK",
226 "France",
227 "Germany",
228 "China",
229 "Japan",
230 "India",
231 "Brazil",
232 "Canada",
233 "Australia",
234 "Russia",
235 "Italy",
236 "Spain",
237 "Mexico",
238 "California",
240 "Texas",
241 "Florida",
242 "New York",
243 "Illinois",
244 "Pennsylvania",
245 "London",
247 "Paris",
248 "Tokyo",
249 "Beijing",
250 "Moscow",
251 "Berlin",
252 "Rome",
253 "Madrid",
254 "Sydney",
255 "Toronto",
256 "Mumbai",
257 "Shanghai",
258 "Seoul",
259 ]
260 .into_iter()
261 .map(String::from)
262 .collect(),
263 );
264
265 gazetteers.insert(
266 EntityType::Organization,
267 vec![
268 "Google",
270 "Apple",
271 "Microsoft",
272 "Amazon",
273 "Facebook",
274 "Tesla",
275 "IBM",
276 "Intel",
277 "Oracle",
278 "Cisco",
279 "Samsung",
280 "Sony",
281 "Toyota",
282 "Honda",
283 "Inc",
285 "Corp",
286 "LLC",
287 "Ltd",
288 "Company",
289 "Corporation",
290 "Group",
291 "UN",
293 "NATO",
294 "WHO",
295 "FBI",
296 "CIA",
297 "NASA",
298 "EU",
299 "OPEC",
300 ]
301 .into_iter()
302 .map(String::from)
303 .collect(),
304 );
305
306 let labels = vec![
308 "O".to_string(),
309 "B-PER".to_string(),
310 "I-PER".to_string(),
311 "B-ORG".to_string(),
312 "I-ORG".to_string(),
313 "B-LOC".to_string(),
314 "I-LOC".to_string(),
315 "B-MISC".to_string(),
316 "I-MISC".to_string(),
317 ];
318
319 let templates = vec![
321 FeatureTemplate::Word,
322 FeatureTemplate::WordAt(-1),
323 FeatureTemplate::WordAt(1),
324 FeatureTemplate::Shape,
325 FeatureTemplate::ShapeAt(-1),
326 FeatureTemplate::ShapeAt(1),
327 FeatureTemplate::Prefix(2),
328 FeatureTemplate::Prefix(3),
329 FeatureTemplate::Suffix(2),
330 FeatureTemplate::Suffix(3),
331 FeatureTemplate::InGazetteer(EntityType::Person),
332 FeatureTemplate::InGazetteer(EntityType::Location),
333 FeatureTemplate::InGazetteer(EntityType::Organization),
334 FeatureTemplate::PrevLabel,
335 ];
336
337 let weights = Self::shipped_weights().unwrap_or_else(Self::default_weights);
339
340 Self {
341 weights,
342 gazetteers,
343 labels,
344 templates,
345 }
346 }
347
348 #[must_use]
353 pub fn new_heuristic() -> Self {
354 let mut m = Self::new();
355 m.weights = Self::default_weights();
356 m
357 }
358
    /// Returns the weights bundled into the binary, if any.
    ///
    /// With the `bundled-crf-weights` feature enabled this parses the
    /// embedded `crf_weights.json` at most once per process (memoized in a
    /// `OnceLock`) and returns a clone of the result on each call; a parse
    /// failure is memoized as `None`. Without the feature this always
    /// returns `None`.
    fn shipped_weights() -> Option<HashMap<String, f64>> {
        #[cfg(feature = "bundled-crf-weights")]
        {
            // Parse the embedded JSON only on first use.
            static ONCE: OnceLock<Option<HashMap<String, f64>>> = OnceLock::new();
            return ONCE
                .get_or_init(|| {
                    let s = include_str!("crf_weights.json");
                    serde_json::from_str::<HashMap<String, f64>>(s).ok()
                })
                .clone();
        }
        #[cfg(not(feature = "bundled-crf-weights"))]
        {
            None
        }
    }
377
378 pub fn load_weights(path: &str) -> Result<HashMap<String, f64>> {
389 let content = std::fs::read_to_string(path).map_err(|e| {
390 crate::Error::invalid_input(format!("Failed to read weights file: {}", e))
391 })?;
392 let weights: HashMap<String, f64> = serde_json::from_str(&content).map_err(|e| {
393 crate::Error::invalid_input(format!("Failed to parse weights JSON: {}", e))
394 })?;
395 Ok(weights)
396 }
397
398 pub fn with_weights(path: &str) -> Result<Self> {
400 let weights = Self::load_weights(path)?;
401 let mut model = Self::new();
402 model.weights = weights;
403 Ok(model)
404 }
405
406 fn default_weights() -> HashMap<String, f64> {
412 let mut w = HashMap::new();
413
414 w.insert("bias:O".to_string(), 3.0);
416
417 w.insert("word.shape=x:O".to_string(), 2.5);
419 w.insert("word.shape=x:B-PER".to_string(), -3.0);
420 w.insert("word.shape=x:I-PER".to_string(), -2.0);
421
422 w.insert("gaz:PER:B-PER".to_string(), 4.0);
424 w.insert("gaz:LOC:B-LOC".to_string(), 4.0);
425 w.insert("gaz:ORG:B-ORG".to_string(), 4.0);
426
427 w.insert("word.shape=Xx:B-PER".to_string(), 2.0);
429 w.insert("word.shape=Xx:B-LOC".to_string(), 1.5);
430 w.insert("word.shape=Xx:B-ORG".to_string(), 1.5);
431 w.insert("word.shape=Xx:I-PER".to_string(), 1.0); w.insert("word.shape=Xx:I-ORG".to_string(), 1.0);
433 w.insert("word.shape=XX:B-ORG".to_string(), 2.5); w.insert("word.shape=x:B-PER".to_string(), -2.0);
437 w.insert("word.shape=x:B-ORG".to_string(), -2.0);
438 w.insert("word.shape=x:B-LOC".to_string(), -2.0);
439
440 for word in [
442 "the",
443 "a",
444 "an",
445 "of",
446 "in",
447 "at",
448 "to",
449 "and",
450 "or",
451 "is",
452 "was",
453 "were",
454 "be",
455 "been",
456 "being",
457 "have",
458 "has",
459 "had",
460 "do",
461 "does",
462 "did",
463 "will",
464 "would",
465 "could",
466 "should",
467 "may",
468 "might",
469 "must",
470 "can",
471 "won",
472 "works",
473 "worked",
474 "working",
475 "serves",
476 "served",
477 "announced",
478 "said",
479 "made",
480 "that",
481 "this",
482 "which",
483 "for",
484 "with",
485 "as",
486 "by",
487 "on",
488 "from",
489 "into",
490 "through",
491 "during",
492 "before",
493 "after",
494 "above",
495 "below",
496 "between",
497 "under",
498 "again",
499 "further",
500 "then",
501 "once",
502 "here",
503 "there",
504 "when",
505 "where",
506 "why",
507 "how",
508 "all",
509 "each",
510 "few",
511 "more",
512 "most",
513 "other",
514 "some",
515 "such",
516 "no",
517 "not",
518 "only",
519 "own",
520 "same",
521 "so",
522 "than",
523 "too",
524 "very",
525 ] {
526 w.insert(format!("word.lower={}:O", word), 5.0);
527 w.insert(format!("word.lower={}:B-PER", word), -5.0);
528 w.insert(format!("word.lower={}:B-ORG", word), -5.0);
529 w.insert(format!("word.lower={}:B-LOC", word), -5.0);
530 w.insert(format!("word.lower={}:I-PER", word), -4.0);
531 w.insert(format!("word.lower={}:I-ORG", word), -4.0);
532 w.insert(format!("word.lower={}:I-LOC", word), -4.0);
533 }
534
535 w.insert("suffix3=inc:B-ORG".to_string(), 3.0);
537 w.insert("suffix3=ltd:B-ORG".to_string(), 3.0);
540 w.insert("suffix3=llc:B-ORG".to_string(), 3.0);
541
542 w.insert("-1:word.lower=dr:B-PER".to_string(), 2.5);
544 w.insert("-1:word.lower=mr:B-PER".to_string(), 2.5);
545 w.insert("-1:word.lower=mrs:B-PER".to_string(), 2.5);
546 w.insert("-1:word.lower=ms:B-PER".to_string(), 2.5);
547 w.insert("-1:word.lower=prof:B-PER".to_string(), 2.5);
548 w.insert("-1:word.lower=president:B-PER".to_string(), 2.0);
549 w.insert("-1:word.lower=ceo:B-PER".to_string(), 2.0);
550
551 w.insert("-1:word.lower=in:B-LOC".to_string(), 1.5);
553 w.insert("-1:word.lower=at:B-LOC".to_string(), 1.5);
554 w.insert("-1:word.lower=from:B-LOC".to_string(), 1.5);
555 w.insert("-1:word.lower=of:B-LOC".to_string(), 1.0);
556 w.insert("-1:word.lower=of:B-ORG".to_string(), 1.0);
557
558 w.insert("trans:B-PER->I-PER".to_string(), 3.0);
561 w.insert("trans:B-ORG->I-ORG".to_string(), 3.0);
562 w.insert("trans:B-LOC->I-LOC".to_string(), 3.0);
563 w.insert("trans:I-PER->I-PER".to_string(), 2.0);
564 w.insert("trans:I-ORG->I-ORG".to_string(), 2.0);
565 w.insert("trans:I-LOC->I-LOC".to_string(), 2.0);
566
567 w.insert("trans:B-PER->O".to_string(), 0.0);
569 w.insert("trans:B-ORG->O".to_string(), 0.0);
570 w.insert("trans:B-LOC->O".to_string(), 0.0);
571 w.insert("trans:I-PER->O".to_string(), 0.0);
572 w.insert("trans:I-ORG->O".to_string(), 0.0);
573 w.insert("trans:I-LOC->O".to_string(), 0.0);
574
575 w.insert("trans:O->I-PER".to_string(), -10.0);
577 w.insert("trans:O->I-ORG".to_string(), -10.0);
578 w.insert("trans:O->I-LOC".to_string(), -10.0);
579 w.insert("trans:O->I-MISC".to_string(), -10.0);
580
581 w.insert("trans:B-PER->I-ORG".to_string(), -10.0);
583 w.insert("trans:B-PER->I-LOC".to_string(), -10.0);
584 w.insert("trans:B-ORG->I-PER".to_string(), -10.0);
585 w.insert("trans:B-ORG->I-LOC".to_string(), -10.0);
586 w.insert("trans:B-LOC->I-PER".to_string(), -10.0);
587 w.insert("trans:B-LOC->I-ORG".to_string(), -10.0);
588
589 w
590 }
591
592 #[allow(clippy::is_digit_ascii_radix)]
597 fn is_digit_py(c: char) -> bool {
598 c.is_digit(10)
599 }
600
601 fn word_shape(word: &str) -> String {
603 word.chars()
604 .map(|c| {
605 if c.is_uppercase() {
606 'X'
607 } else if c.is_lowercase() {
608 'x'
609 } else if Self::is_digit_py(c) {
610 '0'
611 } else {
612 c
613 }
614 })
615 .collect::<String>()
616 .chars()
618 .fold(String::new(), |mut acc, c| {
619 if !acc.ends_with(&c.to_string()) {
620 acc.push(c);
621 }
622 acc
623 })
624 }
625
626 fn extract_features(&self, tokens: &[&str], pos: usize, _prev_label: &str) -> Vec<String> {
646 let mut features = Vec::with_capacity(20);
647 let word = tokens[pos];
648
649 fn bool_py(v: bool) -> &'static str {
650 if v {
651 "True"
652 } else {
653 "False"
654 }
655 }
656
657 features.push("bias".to_string());
659
660 features.push(format!("word.lower={}", word.to_lowercase()));
662 features.push(format!("word.shape={}", Self::word_shape(word)));
663 features.push(format!(
664 "word.isdigit={}",
665 bool_py(!word.is_empty() && word.chars().all(Self::is_digit_py))
669 ));
670 features.push(format!(
671 "word.istitle={}",
672 bool_py(
673 word.chars().next().is_some_and(|c| c.is_uppercase())
674 && word.chars().skip(1).all(|c| c.is_lowercase())
675 )
676 ));
677 features.push(format!(
678 "word.isupper={}",
679 bool_py(word.chars().all(|c| c.is_uppercase()))
680 ));
681
682 let chars: Vec<char> = word.chars().collect();
684 if chars.len() >= 2 {
685 let prefix2: String = chars[..2].iter().collect();
686 let suffix2: String = chars[chars.len() - 2..].iter().collect();
687 features.push(format!("prefix2={}", prefix2.to_lowercase()));
688 features.push(format!("suffix2={}", suffix2.to_lowercase()));
689 }
690 if chars.len() >= 3 {
691 let prefix3: String = chars[..3].iter().collect();
692 let suffix3: String = chars[chars.len() - 3..].iter().collect();
693 features.push(format!("prefix3={}", prefix3.to_lowercase()));
694 features.push(format!("suffix3={}", suffix3.to_lowercase()));
695 }
696
697 if pos > 0 {
699 let prev_word = tokens[pos - 1];
700 features.push(format!("-1:word.lower={}", prev_word.to_lowercase()));
701 features.push(format!(
702 "-1:word.istitle={}",
703 bool_py(
704 prev_word.chars().next().is_some_and(|c| c.is_uppercase())
705 && prev_word.chars().skip(1).all(|c| c.is_lowercase())
706 )
707 ));
708 features.push(format!(
709 "-1:word.isupper={}",
710 bool_py(prev_word.chars().all(|c| c.is_uppercase()))
711 ));
712 features.push(format!("-1:word.shape={}", Self::word_shape(prev_word)));
713 } else {
714 features.push("BOS".to_string());
715 }
716
717 if pos + 1 < tokens.len() {
719 let next_word = tokens[pos + 1];
720 features.push(format!("+1:word.lower={}", next_word.to_lowercase()));
721 features.push(format!(
722 "+1:word.istitle={}",
723 bool_py(
724 next_word.chars().next().is_some_and(|c| c.is_uppercase())
725 && next_word.chars().skip(1).all(|c| c.is_lowercase())
726 )
727 ));
728 features.push(format!(
729 "+1:word.isupper={}",
730 bool_py(next_word.chars().all(|c| c.is_uppercase()))
731 ));
732 features.push(format!("+1:word.shape={}", Self::word_shape(next_word)));
733 } else {
734 features.push("EOS".to_string());
735 }
736
737 for template in &self.templates {
739 if let FeatureTemplate::InGazetteer(entity_type) = template {
740 if let Some(gaz) = self.gazetteers.get(entity_type) {
741 if gaz.iter().any(|g| g.eq_ignore_ascii_case(word)) {
742 features.push(format!("gaz:{}", entity_type.as_label()));
743 }
744 }
745 }
746 }
747
748 features
749 }
750
751 fn score_label(&self, features: &[String], label: &str) -> f64 {
753 let mut score = 0.0;
754 let debug = std::env::var("CRF_DEBUG").is_ok();
755
756 if debug && label == "I-PER" {
757 eprintln!(" Features for I-PER: {:?}", features);
758 }
759
760 for feat in features {
761 let key = format!("{}:{}", feat, label);
762 if let Some(&w) = self.weights.get(&key) {
763 if debug && w.abs() > 0.1 {
764 eprintln!(" CRF: {} -> {:.2}", key, w);
765 }
766 score += w;
767 }
768 if let Some(&w) = self.weights.get(feat) {
770 score += w * 0.5;
771 }
772 }
773 if label == "O" {
775 score += 0.5; }
777 score
778 }
779
780 fn viterbi_decode(&self, tokens: &[&str]) -> Vec<String> {
782 if tokens.is_empty() {
783 return vec![];
784 }
785
786 let n = tokens.len();
787 let m = self.labels.len();
788
789 let mut scores = vec![vec![f64::NEG_INFINITY; m]; n];
791 let mut backpointers = vec![vec![0usize; m]; n];
792
793 let features = self.extract_features(tokens, 0, "O");
795 for (j, label) in self.labels.iter().enumerate() {
796 scores[0][j] = self.score_label(&features, label);
797 }
798
799 for i in 1..n {
801 for (j, label) in self.labels.iter().enumerate() {
802 let mut best_score = f64::NEG_INFINITY;
803 let mut best_prev = 0;
804
805 for (k, prev_label) in self.labels.iter().enumerate() {
806 let features = self.extract_features(tokens, i, prev_label);
807 let trans_key = format!("trans:{}->{}", prev_label, label);
808 let trans_score = self.weights.get(&trans_key).copied().unwrap_or(0.0);
809 let score = scores[i - 1][k] + self.score_label(&features, label) + trans_score;
810
811 if score > best_score {
812 best_score = score;
813 best_prev = k;
814 }
815 }
816
817 scores[i][j] = best_score;
818 backpointers[i][j] = best_prev;
819 }
820 }
821
822 let mut path = vec![0usize; n];
824 let mut best_final = 0;
825 let mut best_score = f64::NEG_INFINITY;
826 for (j, &score) in scores[n - 1].iter().enumerate() {
827 if score > best_score {
828 best_score = score;
829 best_final = j;
830 }
831 }
832 path[n - 1] = best_final;
833
834 for i in (0..n - 1).rev() {
835 path[i] = backpointers[i + 1][path[i + 1]];
836 }
837
838 path.iter().map(|&j| self.labels[j].clone()).collect()
839 }
840
    /// Converts a BIO label sequence into `Entity` values with char-based
    /// offsets into `text`.
    ///
    /// Walks tokens and labels in lockstep, accumulating the entity in
    /// progress as (start token idx, end token idx, type, words) and
    /// flushing it when a new `B-` tag starts, an `O` tag ends it, or the
    /// sequence runs out.
    fn labels_to_entities(&self, text: &str, tokens: &[&str], labels: &[String]) -> Vec<Entity> {
        use crate::offset::SpanConverter;

        let mut entities = Vec::new();

        // Byte→char offset converter for the original text.
        let converter = SpanConverter::new(text);

        // Byte span of each token within `text`.
        let token_positions: Vec<(usize, usize)> = Self::calculate_token_positions(text, tokens);

        // Entity under construction: (start idx, end idx, type, words).
        let mut current_entity: Option<(usize, usize, EntityType, Vec<&str>)> = None;

        for (i, (token, label)) in tokens.iter().zip(labels.iter()).enumerate() {
            if label.starts_with("B-") {
                // A new entity begins; flush any entity in progress first.
                if let Some((start_idx, end_idx, entity_type, words)) = current_entity.take() {
                    Self::push_entity_from_positions(
                        &converter,
                        &token_positions,
                        start_idx,
                        end_idx,
                        &words,
                        entity_type,
                        &mut entities,
                    );
                }

                // Unknown B- tags (e.g. B-MISC) fall through to Other("MISC").
                let entity_type = match label.as_str() {
                    "B-PER" => EntityType::Person,
                    "B-ORG" => EntityType::Organization,
                    "B-LOC" => EntityType::Location,
                    _ => EntityType::Other("MISC".to_string()),
                };
                current_entity = Some((i, i, entity_type, vec![token]));
            } else if label.starts_with("I-") {
                // Extend the open entity. NOTE(review): the I- tag's type is
                // not checked against the open entity's type (an I-ORG could
                // extend a PER span), and an I- tag with no open entity is
                // silently dropped; the Viterbi transition penalties make
                // both cases unlikely but not impossible — confirm intended.
                if let Some((_, ref mut end_idx, _, ref mut words)) = current_entity {
                    words.push(token);
                    *end_idx = i;
                }
            } else {
                // "O" (or any other tag) terminates the open entity.
                if let Some((start_idx, end_idx, entity_type, words)) = current_entity.take() {
                    Self::push_entity_from_positions(
                        &converter,
                        &token_positions,
                        start_idx,
                        end_idx,
                        &words,
                        entity_type,
                        &mut entities,
                    );
                }
            }
        }

        // Flush an entity that runs to the end of the sequence.
        if let Some((start_idx, end_idx, entity_type, words)) = current_entity.take() {
            Self::push_entity_from_positions(
                &converter,
                &token_positions,
                start_idx,
                end_idx,
                &words,
                entity_type,
                &mut entities,
            );
        }

        entities
    }
922
923 fn calculate_token_positions(text: &str, tokens: &[&str]) -> Vec<(usize, usize)> {
925 let mut positions = Vec::with_capacity(tokens.len());
926 let mut byte_pos = 0;
927
928 for token in tokens {
929 if let Some(rel_pos) = text[byte_pos..].find(token) {
931 let start = byte_pos + rel_pos;
932 let end = start + token.len();
933 positions.push((start, end));
934 byte_pos = end; } else {
936 positions.push((byte_pos, byte_pos));
938 }
939 }
940
941 positions
942 }
943
944 fn push_entity_from_positions(
946 converter: &crate::offset::SpanConverter,
947 positions: &[(usize, usize)],
948 start_token_idx: usize,
949 end_token_idx: usize,
950 words: &[&str],
951 entity_type: EntityType,
952 entities: &mut Vec<Entity>,
953 ) {
954 if start_token_idx >= positions.len() || end_token_idx >= positions.len() {
955 return;
956 }
957
958 let byte_start = positions[start_token_idx].0;
959 let byte_end = positions[end_token_idx].1;
960 let char_start = converter.byte_to_char(byte_start);
961 let char_end = converter.byte_to_char(byte_end);
962 let entity_text = words.join(" ");
963
964 entities.push(Entity::new(
965 &entity_text,
966 entity_type,
967 char_start,
968 char_end,
969 0.7, ));
971 }
972
973 fn tokenize(text: &str) -> Vec<&str> {
975 text.split_whitespace().collect()
976 }
977}
978
979impl Model for CrfNER {
980 fn extract_entities(&self, text: &str, _language: Option<&str>) -> Result<Vec<Entity>> {
981 if text.trim().is_empty() {
982 return Ok(vec![]);
983 }
984
985 let tokens = Self::tokenize(text);
986 if tokens.is_empty() {
987 return Ok(vec![]);
988 }
989
990 let labels = self.viterbi_decode(&tokens);
991 let entities = self.labels_to_entities(text, &tokens, &labels);
992
993 Ok(entities)
994 }
995
996 fn supported_types(&self) -> Vec<EntityType> {
997 vec![
998 EntityType::Person,
999 EntityType::Organization,
1000 EntityType::Location,
1001 EntityType::Other("MISC".to_string()),
1002 ]
1003 }
1004
1005 fn is_available(&self) -> bool {
1006 true }
1008
1009 fn name(&self) -> &'static str {
1010 "crf"
1011 }
1012
1013 fn description(&self) -> &'static str {
1014 "CRF-based NER (classical statistical method)"
1015 }
1016}
1017
// Marker impl: opts CrfNER into the named-entity capability.
impl crate::NamedEntityCapable for CrfNER {}
1019
impl crate::BatchCapable for CrfNER {
    fn optimal_batch_size(&self) -> Option<usize> {
        // Fixed suggestion; the model itself has no per-batch state.
        Some(32)
    }
}
1025
impl crate::StreamingCapable for CrfNER {
    fn recommended_chunk_size(&self) -> usize {
        // Chunk size in bytes for streaming callers.
        4096
    }
}
1031
1032#[cfg(test)]
1033mod tests {
1034 use super::*;
1035
    // Smoke test: the default model finds at least one entity in a sentence
    // full of gazetteer names.
    #[test]
    fn test_crf_basic() {
        let ner = CrfNER::new();
        let entities = ner
            .extract_entities("John Smith works at Google in California", None)
            .unwrap();

        assert!(!entities.is_empty(), "Expected some entities, got none");
    }
1047
    // `word_shape` maps case/digit classes and collapses repeated runs.
    #[test]
    fn test_word_shape() {
        assert_eq!(CrfNER::word_shape("John"), "Xx");
        assert_eq!(CrfNER::word_shape("USA"), "X");
        assert_eq!(CrfNER::word_shape("hello"), "x");
        assert_eq!(CrfNER::word_shape("123"), "0");
        assert_eq!(CrfNER::word_shape("Hello123"), "Xx0");
    }
1056
    // Tokenization is plain whitespace splitting.
    #[test]
    fn test_tokenize() {
        let tokens = CrfNER::tokenize("Hello world");
        assert_eq!(tokens, vec!["Hello", "world"]);
    }
1062
    // Empty text must yield no entities (and not panic).
    #[test]
    fn test_empty_input() {
        let ner = CrfNER::new();
        let entities = ner.extract_entities("", None).unwrap();
        assert!(entities.is_empty());
    }
1069
    // The built-in gazetteers contain the expected seed entries.
    #[test]
    fn test_gazetteer_lookup() {
        let ner = CrfNER::new();

        assert!(ner.gazetteers[&EntityType::Person].contains(&"John".to_string()));
        assert!(ner.gazetteers[&EntityType::Location].contains(&"California".to_string()));
        assert!(ner.gazetteers[&EntityType::Organization].contains(&"Google".to_string()));
    }
1079
    // Viterbi must emit exactly one label per token, all drawn from the
    // model's label inventory.
    #[test]
    fn test_viterbi_returns_valid_labels() {
        let ner = CrfNER::new();
        let tokens = vec!["John", "works", "at", "Google"];
        let labels = ner.viterbi_decode(&tokens);

        assert_eq!(labels.len(), tokens.len());
        for label in &labels {
            assert!(ner.labels.contains(label));
        }
    }
1091
    // Stop-word weights should keep common verbs like "works" out of any
    // extracted entity span.
    #[test]
    fn test_common_verbs_not_in_entities() {
        let ner = CrfNER::new();

        let entities = ner
            .extract_entities("John Smith works at Apple", None)
            .unwrap();

        let entity_texts: Vec<&str> = entities.iter().map(|e| e.text.as_str()).collect();
        for entity_text in &entity_texts {
            assert!(
                !entity_text.contains("works"),
                "Entity '{}' should not contain 'works'",
                entity_text
            );
        }
    }
1111
    // Heuristic weights must prefer "O" for common verbs like "works" and
    // penalize person-continuation labels for them.
    #[test]
    fn test_weights_for_common_words() {
        // With bundled weights the heuristic table is not in use; skip.
        #[cfg(feature = "bundled-crf-weights")]
        {
            return;
        }

        let ner = CrfNER::new();

        assert!(
            ner.weights.contains_key("word.lower=works:O"),
            "Missing weight for word.lower=works:O"
        );
        assert!(
            ner.weights.contains_key("word.lower=works:I-PER"),
            "Missing weight for word.lower=works:I-PER"
        );

        let o_weight = *ner.weights.get("word.lower=works:O").unwrap();
        let i_per_weight = *ner.weights.get("word.lower=works:I-PER").unwrap();
        assert!(
            o_weight > 0.0,
            "O weight should be positive, got {}",
            o_weight
        );
        assert!(
            i_per_weight < 0.0,
            "I-PER weight should be negative, got {}",
            i_per_weight
        );
    }
1148
    // Entity offsets are char-based; multi-byte text must stay in range and
    // slice cleanly by chars.
    #[test]
    fn test_unicode_char_offsets() {
        let ner = CrfNER::new();

        // "北京" is 6 bytes but only 2 chars; offsets must count chars.
        let text = "北京 Beijing";
        assert_eq!(text.len(), 14, "Expected 14 bytes");
        assert_eq!(text.chars().count(), 10, "Expected 10 characters");

        let entities = ner.extract_entities(text, None).unwrap();

        let char_count = text.chars().count();
        for entity in &entities {
            assert!(
                entity.start <= entity.end,
                "Invalid span: start {} > end {}",
                entity.start,
                entity.end
            );
            assert!(
                entity.end <= char_count,
                "Entity end {} exceeds char count {} for text {:?}",
                entity.end,
                char_count,
                text
            );

            // Char-based slicing must reproduce a non-empty span.
            let extracted: String = text
                .chars()
                .skip(entity.start)
                .take(entity.end - entity.start)
                .collect();
            assert!(
                !extracted.is_empty() || entity.start == entity.end,
                "Empty extraction for entity at {}..{} in {:?}",
                entity.start,
                entity.end,
                text
            );
        }
    }
1194
    // Mixed-script inputs (Latin, CJK, Arabic, Cyrillic, Devanagari) must
    // not panic and must produce spans within char bounds.
    #[test]
    fn test_multilingual_inputs_no_panic_and_valid_spans() {
        let ner = CrfNER::new();
        let texts = [
            "Marie Curie discovered radium in Paris.",
            "習近平在北京會見了普京。",
            "التقى محمد بن سلمان بالرئيس في الرياض",
            "Путин встретился с Си Цзиньпином в Москве.",
            "प्रधान मंत्री शर्मा दिल्ली में मिले।",
        ];

        for text in texts {
            let entities = ner.extract_entities(text, None).unwrap();
            let char_count = text.chars().count();
            for e in entities {
                assert!(e.start <= e.end);
                assert!(e.end <= char_count);
                // Char-based slicing must not panic.
                let _span: String = text.chars().skip(e.start).take(e.end - e.start).collect();
            }
        }
    }
1221
    // Repeated tokens must map to successive occurrences, not all to the
    // first match.
    #[test]
    fn test_duplicate_entity_offsets() {
        let text = "Google bought Google for $1 billion.";
        let tokens: Vec<&str> = text.split_whitespace().collect();
        let positions = CrfNER::calculate_token_positions(text, &tokens);

        assert_eq!(
            positions[0],
            (0, 6),
            "First 'Google' should be at bytes 0-6"
        );
        assert_eq!(
            positions[2],
            (14, 20),
            "Second 'Google' should be at bytes 14-20"
        );
    }
1243
    // Token positions are byte offsets, even for multi-byte scripts.
    #[test]
    fn test_token_positions_unicode() {
        let text = "東京 Tokyo 東京";
        let tokens: Vec<&str> = text.split_whitespace().collect();
        let positions = CrfNER::calculate_token_positions(text, &tokens);

        assert_eq!(positions[0], (0, 6), "First '東京' at bytes 0-6");
        assert_eq!(positions[1], (7, 12), "Tokyo at bytes 7-12");
        assert_eq!(positions[2], (13, 19), "Second '東京' at bytes 13-19");
    }
1256}