use super::*;
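// Unit tests for the CRF-based NER: tokenization, word-shape features,
// feature extraction, Viterbi decoding, offset bookkeeping, and
// label-to-entity conversion.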
#[test]
fn test_crf_basic() {
let ner = CrfNER::new();
let entities = ner
.extract_entities("John Smith works at Google in California", None)
.unwrap();
assert!(!entities.is_empty(), "Expected some entities, got none");
}
#[test]
fn test_word_shape() {
assert_eq!(CrfNER::word_shape("John"), "Xx");
assert_eq!(CrfNER::word_shape("USA"), "X");
assert_eq!(CrfNER::word_shape("hello"), "x");
assert_eq!(CrfNER::word_shape("123"), "0");
assert_eq!(CrfNER::word_shape("Hello123"), "Xx0");
}
#[test]
fn test_tokenize() {
let tokens = CrfNER::tokenize("Hello world");
assert_eq!(tokens, vec!["Hello", "world"]);
}
#[test]
fn test_empty_input() {
let ner = CrfNER::new();
let entities = ner.extract_entities("", None).unwrap();
assert!(entities.is_empty());
}
#[test]
fn test_gazetteer_lookup() {
let ner = CrfNER::new();
assert!(ner.gazetteers[&EntityType::Person].contains(&"John".to_string()));
assert!(ner.gazetteers[&EntityType::Location].contains(&"California".to_string()));
assert!(ner.gazetteers[&EntityType::Organization].contains(&"Google".to_string()));
}
#[test]
fn test_viterbi_returns_valid_labels() {
let ner = CrfNER::new();
let tokens = vec!["John", "works", "at", "Google"];
let labels = ner.viterbi_decode(&tokens);
assert_eq!(labels.len(), tokens.len());
for label in &labels {
assert!(ner.labels.contains(label));
}
}
#[test]
fn test_common_verbs_not_in_entities() {
let ner = CrfNER::new();
let entities = ner
.extract_entities("John Smith works at Apple", None)
.unwrap();
let entity_texts: Vec<&str> = entities.iter().map(|e| e.text.as_str()).collect();
for entity_text in &entity_texts {
assert!(
!entity_text.contains("works"),
"Entity '{}' should not contain 'works'",
entity_text
);
}
}
#[test]
#[cfg(not(feature = "bundled-crf-weights"))]
fn test_weights_for_common_words() {
// These expectations target the hand-tuned default weights, so the test is
// compiled out when bundled CRF weights are used instead.
let ner = CrfNER::new();
assert!(
ner.weights.contains_key("word.lower=works:O"),
"Missing weight for word.lower=works:O"
);
assert!(
ner.weights.contains_key("word.lower=works:I-PER"),
"Missing weight for word.lower=works:I-PER"
);
let o_weight = *ner.weights.get("word.lower=works:O").unwrap();
let i_per_weight = *ner.weights.get("word.lower=works:I-PER").unwrap();
assert!(
o_weight > 0.0,
"O weight should be positive, got {}",
o_weight
);
assert!(
i_per_weight < 0.0,
"I-PER weight should be negative, got {}",
i_per_weight
);
}
#[test]
fn test_unicode_char_offsets() {
let ner = CrfNER::new();
let text = "北京 Beijing";
assert_eq!(text.len(), 14, "Expected 14 bytes");
assert_eq!(text.chars().count(), 10, "Expected 10 characters");
let entities = ner.extract_entities(text, None).unwrap();
let char_count = text.chars().count();
for entity in &entities {
assert!(
entity.start() <= entity.end(),
"Invalid span: start {} > end {}",
entity.start(),
entity.end()
);
assert!(
entity.end() <= char_count,
"Entity end {} exceeds char count {} for text {:?}",
entity.end(),
char_count,
text
);
let extracted: String = text
.chars()
.skip(entity.start())
.take(entity.end() - entity.start())
.collect();
assert!(
!extracted.is_empty() || entity.start() == entity.end(),
"Empty extraction for entity at {}..{} in {:?}",
entity.start(),
entity.end(),
text
);
}
}
#[test]
fn test_multilingual_inputs_no_panic_and_valid_spans() {
let ner = CrfNER::new();
let texts = [
"Marie Curie discovered radium in Paris.",
"習近平在北京會見了普京。",
"التقى محمد بن سلمان بالرئيس في الرياض",
"Путин встретился с Си Цзиньпином в Москве.",
"प्रधान मंत्री शर्मा दिल्ली में मिले।",
];
for text in texts {
let entities = ner.extract_entities(text, None).unwrap();
let char_count = text.chars().count();
for e in entities {
assert!(e.start() <= e.end());
assert!(e.end() <= char_count);
let _span: String = text
.chars()
.skip(e.start())
.take(e.end() - e.start())
.collect();
}
}
}
#[test]
fn test_duplicate_entity_offsets() {
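// The second occurrence of "Google" must get its own span rather than the
// byte range of the first occurrence.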
let text = "Google bought Google for $1 billion.";
let tokens: Vec<&str> = text.split_whitespace().collect();
let positions = CrfNER::calculate_token_positions(text, &tokens);
assert_eq!(
positions[0],
(0, 6),
"First 'Google' should be at bytes 0-6"
);
assert_eq!(
positions[2],
(14, 20),
"Second 'Google' should be at bytes 14-20"
);
}
#[test]
fn test_token_positions_unicode() {
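// Positions are byte offsets, so each CJK character contributes 3 bytes.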
let text = "東京 Tokyo 東京";
let tokens: Vec<&str> = text.split_whitespace().collect();
let positions = CrfNER::calculate_token_positions(text, &tokens);
assert_eq!(positions[0], (0, 6), "First '東京' at bytes 0-6");
assert_eq!(positions[1], (7, 12), "Tokyo at bytes 7-12");
assert_eq!(positions[2], (13, 19), "Second '東京' at bytes 13-19");
}
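/// Builds a `CrfNER` with the full BIO label set but only the supplied weights,
/// no gazetteers, and no feature templates, so scoring and decoding can be
/// exercised in isolation.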
fn minimal_crf(weights: HashMap<String, f64>) -> CrfNER {
CrfNER {
weights,
gazetteers: HashMap::new(),
labels: vec![
"O".to_string(),
"B-PER".to_string(),
"I-PER".to_string(),
"B-ORG".to_string(),
"I-ORG".to_string(),
"B-LOC".to_string(),
"I-LOC".to_string(),
"B-MISC".to_string(),
"I-MISC".to_string(),
],
templates: vec![],
}
}
#[test]
fn test_viterbi_empty_input() {
let ner = minimal_crf(HashMap::new());
let labels = ner.viterbi_decode(&[]);
assert!(labels.is_empty(), "Empty tokens should yield empty labels");
}
#[test]
fn test_viterbi_single_token() {
let ner = minimal_crf(HashMap::new());
let labels = ner.viterbi_decode(&["hello"]);
assert_eq!(labels.len(), 1);
assert_eq!(labels[0], "O");
}
#[test]
fn test_viterbi_strong_emission_overrides_o_bias() {
let mut w = HashMap::new();
w.insert("bias:B-PER".to_string(), 20.0);
let ner = minimal_crf(w);
let labels = ner.viterbi_decode(&["Alice"]);
assert_eq!(labels, vec!["B-PER"]);
}
#[test]
fn test_viterbi_bio_transition_constraint() {
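// I-PER is rewarded after B-PER (+5.0) but heavily penalized after O (-50.0),
// so "John Smith" must decode as B-PER followed by I-PER.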
let mut w = HashMap::new();
w.insert("BOS:O".to_string(), 10.0);
w.insert("word.lower=john:B-PER".to_string(), 8.0);
w.insert("word.lower=john:I-PER".to_string(), 6.0);
w.insert("word.lower=smith:I-PER".to_string(), 8.0);
w.insert("word.lower=smith:B-PER".to_string(), 4.0);
w.insert("trans:B-PER->I-PER".to_string(), 5.0);
w.insert("trans:O->I-PER".to_string(), -50.0);
let ner = minimal_crf(w);
let labels = ner.viterbi_decode(&["The", "John", "Smith"]);
assert_eq!(labels.len(), 3);
assert_eq!(
labels[1], "B-PER",
"Expected B-PER after O, not {}",
labels[1]
);
assert_eq!(
labels[2], "I-PER",
"Expected I-PER continuation, not {}",
labels[2]
);
}
#[test]
fn test_viterbi_cross_type_transition_blocked() {
let mut w = HashMap::new();
w.insert("bias:B-PER".to_string(), 15.0);
w.insert("word.lower=inc:I-ORG".to_string(), 5.0);
w.insert("word.lower=inc:O".to_string(), 1.0);
w.insert("trans:B-PER->I-ORG".to_string(), -50.0);
let ner = minimal_crf(w);
let labels = ner.viterbi_decode(&["Alice", "Inc"]);
assert_eq!(labels.len(), 2);
assert_ne!(
labels[1], "I-ORG",
"Cross-type transition B-PER -> I-ORG should be blocked"
);
}
#[test]
fn test_score_label_default_o_bias() {
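// With no matching weights, "O" should still receive a small default bias so
// unknown tokens fall back to the outside label.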
let ner = minimal_crf(HashMap::new());
let features = vec!["some_unknown_feature".to_string()];
let o_score = ner.score_label(&features, "O");
let b_per_score = ner.score_label(&features, "B-PER");
assert!(
o_score > b_per_score,
"O should score higher than B-PER with no matching weights: O={}, B-PER={}",
o_score,
b_per_score
);
assert!(
(o_score - 0.5).abs() < 1e-9,
"O score should be 0.5, got {}",
o_score
);
assert!(
(b_per_score - 0.0).abs() < 1e-9,
"B-PER score should be 0.0, got {}",
b_per_score
);
}
#[test]
fn test_score_label_weight_accumulation() {
let mut w = HashMap::new();
w.insert("feat_a:B-PER".to_string(), 2.0);
w.insert("feat_b:B-PER".to_string(), 3.0);
w.insert("feat_c".to_string(), 4.0);
let ner = minimal_crf(w);
let features = vec![
"feat_a".to_string(),
"feat_b".to_string(),
"feat_c".to_string(),
];
let score = ner.score_label(&features, "B-PER");
assert!(
(score - 7.0).abs() < 1e-9,
"Expected score 7.0, got {}",
score
);
}
#[test]
fn test_extract_features_bos_eos() {
let ner = minimal_crf(HashMap::new());
let tokens = vec!["Hello", "world"];
let feats_first = ner.extract_features(&tokens, 0, "O");
assert!(
feats_first.contains(&"BOS".to_string()),
"First token should have BOS feature"
);
assert!(
!feats_first.contains(&"EOS".to_string()),
"First token should not have EOS"
);
let feats_last = ner.extract_features(&tokens, 1, "O");
assert!(
feats_last.contains(&"EOS".to_string()),
"Last token should have EOS feature"
);
assert!(
!feats_last.contains(&"BOS".to_string()),
"Last token should not have BOS"
);
}
#[test]
fn test_extract_features_single_token_bos_and_eos() {
let ner = minimal_crf(HashMap::new());
let tokens = vec!["Only"];
let feats = ner.extract_features(&tokens, 0, "O");
assert!(feats.contains(&"BOS".to_string()), "Single token needs BOS");
assert!(feats.contains(&"EOS".to_string()), "Single token needs EOS");
}
#[test]
fn test_extract_features_word_identity() {
let ner = minimal_crf(HashMap::new());
let tokens = vec!["John"];
let feats = ner.extract_features(&tokens, 0, "O");
assert!(feats.contains(&"bias".to_string()), "Missing bias feature");
assert!(
feats.contains(&"word.lower=john".to_string()),
"Missing word.lower feature"
);
assert!(
feats.contains(&"word.shape=Xx".to_string()),
"Missing word.shape feature"
);
assert!(
feats.contains(&"word.istitle=True".to_string()),
"Missing word.istitle feature"
);
assert!(
feats.contains(&"word.isupper=False".to_string()),
"Missing word.isupper feature"
);
}
#[test]
fn test_labels_to_entities_simple() {
let ner = minimal_crf(HashMap::new());
let text = "John Smith works at Google";
let tokens: Vec<&str> = text.split_whitespace().collect();
let labels = vec![
"B-PER".to_string(),
"I-PER".to_string(),
"O".to_string(),
"O".to_string(),
"B-ORG".to_string(),
];
let entities = ner.labels_to_entities(text, &tokens, &labels);
assert_eq!(
entities.len(),
2,
"Expected 2 entities, got {}",
entities.len()
);
assert_eq!(entities[0].text, "John Smith");
assert_eq!(entities[0].entity_type, EntityType::Person);
assert_eq!(entities[0].start(), 0);
assert_eq!(entities[0].end(), 10);
assert_eq!(entities[1].text, "Google");
assert_eq!(entities[1].entity_type, EntityType::Organization);
assert_eq!(entities[1].start(), 20);
assert_eq!(entities[1].end(), 26);
}
#[test]
fn test_labels_to_entities_trailing_entity() {
let ner = minimal_crf(HashMap::new());
let text = "lives in Paris";
let tokens: Vec<&str> = text.split_whitespace().collect();
let labels = vec!["O".to_string(), "O".to_string(), "B-LOC".to_string()];
let entities = ner.labels_to_entities(text, &tokens, &labels);
assert_eq!(entities.len(), 1);
assert_eq!(entities[0].text, "Paris");
assert_eq!(entities[0].entity_type, EntityType::Location);
}
#[test]
fn test_labels_to_entities_all_outside() {
let ner = minimal_crf(HashMap::new());
let text = "nothing special here";
let tokens: Vec<&str> = text.split_whitespace().collect();
let labels = vec!["O".to_string(), "O".to_string(), "O".to_string()];
let entities = ner.labels_to_entities(text, &tokens, &labels);
assert!(
entities.is_empty(),
"All-O labels should produce no entities"
);
}
#[test]
fn test_word_shape_compression() {
assert_eq!(CrfNER::word_shape("AAAA"), "X");
assert_eq!(CrfNER::word_shape("aaaa"), "x");
assert_eq!(CrfNER::word_shape("AaBb"), "XxXx");
assert_eq!(CrfNER::word_shape("123-456"), "0-0");
assert_eq!(CrfNER::word_shape("O'Brien"), "X'Xx");
}
#[test]
fn test_viterbi_global_optimality() {
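// Greedily, both tokens would be tagged O (bias 5.0 beats 3.0), but the strong
// B-PER -> B-PER transition (10.0) makes the all-B-PER path score higher overall.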
let mut w = HashMap::new();
w.insert("bias:O".to_string(), 5.0);
w.insert("bias:B-PER".to_string(), 3.0);
w.insert("word.lower=b:B-PER".to_string(), 1.0);
w.insert("trans:B-PER->B-PER".to_string(), 10.0);
let ner = minimal_crf(w);
let labels = ner.viterbi_decode(&["a", "b"]);
assert_eq!(labels.len(), 2);
assert_eq!(
labels,
vec!["B-PER", "B-PER"],
"Viterbi should find the globally optimal path, not the greedy one"
);
}
#[test]
fn test_default_weights_bio_constraints() {
let w = CrfNER::default_weights();
for tag in ["I-PER", "I-ORG", "I-LOC", "I-MISC"] {
let key = format!("trans:O->{}", tag);
let val = w.get(&key).copied().unwrap_or(0.0);
assert!(
val < -5.0,
"O -> {} should be heavily penalized, got {}",
tag,
val
);
}
for (b, i) in [("B-PER", "I-PER"), ("B-ORG", "I-ORG"), ("B-LOC", "I-LOC")] {
let key = format!("trans:{}->{}", b, i);
let val = w.get(&key).copied().unwrap_or(0.0);
assert!(val > 0.0, "{} -> {} should be positive, got {}", b, i, val);
}
for (b, i) in [("B-PER", "I-ORG"), ("B-ORG", "I-LOC"), ("B-LOC", "I-PER")] {
let key = format!("trans:{}->{}", b, i);
let val = w.get(&key).copied().unwrap_or(0.0);
assert!(
val < -5.0,
"{} -> {} should be heavily penalized, got {}",
b,
i,
val
);
}
}
#[test]
fn test_extract_features_prefix_suffix() {
let ner = minimal_crf(HashMap::new());
let tokens = vec!["Johnson"];
let feats = ner.extract_features(&tokens, 0, "O");
assert!(
feats.contains(&"prefix2=jo".to_string()),
"Missing prefix2 feature, got: {:?}",
feats
);
assert!(
feats.contains(&"prefix3=joh".to_string()),
"Missing prefix3 feature"
);
assert!(
feats.contains(&"suffix2=on".to_string()),
"Missing suffix2 feature"
);
assert!(
feats.contains(&"suffix3=son".to_string()),
"Missing suffix3 feature"
);
}
#[test]
fn test_extract_features_no_prefix_suffix_for_short_word() {
let ner = minimal_crf(HashMap::new());
let tokens = vec!["I"];
let feats = ner.extract_features(&tokens, 0, "O");
let has_prefix = feats.iter().any(|f| f.starts_with("prefix"));
let has_suffix = feats.iter().any(|f| f.starts_with("suffix"));
assert!(
!has_prefix,
"Single-char word should have no prefix feature"
);
assert!(
!has_suffix,
"Single-char word should have no suffix feature"
);
}
#[test]
fn test_extract_features_context_words() {
let ner = minimal_crf(HashMap::new());
let tokens = vec!["Dr", "John", "Smith"];
let feats = ner.extract_features(&tokens, 1, "O");
assert!(
feats.contains(&"-1:word.lower=dr".to_string()),
"Missing -1:word.lower feature"
);
assert!(
feats.contains(&"-1:word.istitle=True".to_string()),
"Missing -1:word.istitle for 'Dr'"
);
assert!(
feats.contains(&"-1:word.isupper=False".to_string()),
"Dr is not all-uppercase"
);
assert!(
feats.contains(&"+1:word.lower=smith".to_string()),
"Missing +1:word.lower feature"
);
assert!(
feats.contains(&"+1:word.istitle=True".to_string()),
"Missing +1:word.istitle feature"
);
}
#[test]
fn test_extract_features_digit_word() {
let ner = minimal_crf(HashMap::new());
let tokens = vec!["2024"];
let feats = ner.extract_features(&tokens, 0, "O");
assert!(
feats.contains(&"word.isdigit=True".to_string()),
"All-digit word should have isdigit=True"
);
assert!(
feats.contains(&"word.isupper=False".to_string()),
"Digits are not uppercase"
);
}
#[test]
fn test_extract_features_mixed_word_not_digit() {
let ner = minimal_crf(HashMap::new());
let tokens = vec!["Room42"];
let feats = ner.extract_features(&tokens, 0, "O");
assert!(
feats.contains(&"word.isdigit=False".to_string()),
"Mixed word should have isdigit=False"
);
}
#[test]
fn test_extract_features_gazetteer_match() {
let ner = CrfNER::new();
let tokens = vec!["John", "works"];
let feats = ner.extract_features(&tokens, 0, "O");
assert!(
feats.contains(&"gaz:PER".to_string()),
"John should match Person gazetteer, features: {:?}",
feats
);
}
#[test]
fn test_extract_features_no_gazetteer_match() {
let ner = CrfNER::new();
let tokens = vec!["works"];
let feats = ner.extract_features(&tokens, 0, "O");
let has_gaz = feats.iter().any(|f| f.starts_with("gaz:"));
assert!(
!has_gaz,
"'works' should not match any gazetteer, features: {:?}",
feats
);
}
#[test]
fn test_tokenize_whitespace_variants() {
assert_eq!(
CrfNER::tokenize(" Hello world "),
vec!["Hello", "world"]
);
assert!(CrfNER::tokenize("").is_empty());
assert!(CrfNER::tokenize(" ").is_empty());
assert_eq!(CrfNER::tokenize("single"), vec!["single"]);
}
#[test]
fn test_tokenize_tabs_newlines() {
assert_eq!(
CrfNER::tokenize("Hello\tworld\nfoo"),
vec!["Hello", "world", "foo"]
);
}
#[test]
fn test_labels_to_entities_consecutive_b_tags() {
let ner = minimal_crf(HashMap::new());
let text = "John Mary works";
let tokens: Vec<&str> = text.split_whitespace().collect();
let labels = vec!["B-PER".to_string(), "B-PER".to_string(), "O".to_string()];
let entities = ner.labels_to_entities(text, &tokens, &labels);
assert_eq!(
entities.len(),
2,
"Two consecutive B-PER should yield 2 entities"
);
assert_eq!(entities[0].text, "John");
assert_eq!(entities[1].text, "Mary");
}
#[test]
fn test_labels_to_entities_misc_type() {
let ner = minimal_crf(HashMap::new());
let text = "World Cup";
let tokens: Vec<&str> = text.split_whitespace().collect();
let labels = vec!["B-MISC".to_string(), "I-MISC".to_string()];
let entities = ner.labels_to_entities(text, &tokens, &labels);
assert_eq!(entities.len(), 1);
assert_eq!(entities[0].text, "World Cup");
assert_eq!(
entities[0].entity_type,
EntityType::custom("MISC", EntityCategory::Misc)
);
}
#[test]
fn test_labels_to_entities_empty() {
let ner = minimal_crf(HashMap::new());
let entities = ner.labels_to_entities("", &[], &[]);
assert!(entities.is_empty());
}
#[test]
fn test_calculate_token_positions_empty() {
let positions = CrfNER::calculate_token_positions("", &[]);
assert!(positions.is_empty());
}
#[test]
fn test_calculate_token_positions_single() {
let positions = CrfNER::calculate_token_positions("Hello", &["Hello"]);
assert_eq!(positions, vec![(0, 5)]);
}
#[test]
fn test_calculate_token_positions_with_punctuation() {
let text = "Hello, world!";
let tokens: Vec<&str> = text.split_whitespace().collect();
let positions = CrfNER::calculate_token_positions(text, &tokens);
assert_eq!(positions[0], (0, 6));
assert_eq!(positions[1], (7, 13));
}
#[test]
fn test_word_shape_empty() {
assert_eq!(CrfNER::word_shape(""), "");
}
#[test]
fn test_word_shape_unicode_letters() {
let shape = CrfNER::word_shape("Москва");
assert_eq!(shape, "Xx", "Titlecase Cyrillic should be Xx");
}
#[test]
fn test_viterbi_longer_sequence_length_invariant() {
let ner = CrfNER::new();
let tokens: Vec<&str> = "The quick brown fox jumps over the lazy dog near London"
.split_whitespace()
.collect();
let labels = ner.viterbi_decode(&tokens);
assert_eq!(
labels.len(),
tokens.len(),
"Viterbi must return exactly one label per token"
);
for label in &labels {
assert!(
ner.labels.contains(label),
"Label '{}' not in label set",
label
);
}
}
#[test]
fn test_viterbi_common_words_are_outside() {
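// The heuristic model (hand-tuned default weights) should leave common
// lowercase words untagged.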
let ner = CrfNER::new_heuristic();
let tokens = vec!["the", "quick", "brown", "fox"];
let labels = ner.viterbi_decode(&tokens);
for (tok, label) in tokens.iter().zip(labels.iter()) {
assert_eq!(
label, "O",
"Common lowercase word '{}' should be O, got '{}'",
tok, label
);
}
}
#[test]
fn test_score_label_both_keys_present() {
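// Both the label-qualified key ("myfeat:O") and the bare feature key ("myfeat")
// should contribute to the score for "O".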
let mut w = HashMap::new();
w.insert("myfeat:O".to_string(), 3.0);
w.insert("myfeat".to_string(), 2.0); let ner = minimal_crf(w);
let features = vec!["myfeat".to_string()];
let score = ner.score_label(&features, "O");
assert!((score - 4.5).abs() < 1e-9, "Expected 4.5, got {}", score);
}
#[test]
fn test_new_heuristic_uses_default_weights() {
let heuristic = CrfNER::new_heuristic();
let default_w = CrfNER::default_weights();
assert_eq!(
heuristic.weights.len(),
default_w.len(),
"Heuristic model should have same number of weights as default_weights()"
);
for (key, val) in &default_w {
let got = heuristic.weights.get(key).copied().unwrap_or(f64::NAN);
assert!(
(got - val).abs() < 1e-12,
"Weight mismatch for '{}': expected {}, got {}",
key,
val,
got
);
}
}
#[test]
fn sentence_boundary_detection() {
let text = "Max Planck Institute respectively. Doudna said something.";
let boundaries = super::sentence_boundary_offsets(text);
assert!(
!boundaries.is_empty(),
"Should detect boundary at period before 'Doudna'"
);
}
#[test]
fn crf_no_cross_sentence_span() {
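// An entity that runs past the first sentence's period (around offset 34)
// must be clipped so that it no longer includes "Doudna".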
let mut entities = vec![Entity::new(
"Institute respectively. Doudna",
EntityType::Person,
11,
41,
0.7,
)];
let text = "Max Planck Institute respectively. Doudna said something else here.";
super::clip_entities_at_sentence_boundaries(text, &mut entities);
for e in &entities {
assert!(
e.end() <= 34,
"Entity should not cross sentence boundary: {:?}",
e
);
assert!(
!e.text.contains("Doudna"),
"Entity text should not contain 'Doudna': {:?}",
e
);
}
}