use super::*;
use anno_core::{Confidence, ExtractionMethod};
fn fast_ensemble() -> EnsembleNER {
EnsembleNER::with_backends(vec![
Box::new(crate::RegexNER::new()),
Box::new(crate::HeuristicNER::new()),
])
}
/// Test double that ignores its input and always returns the same entity list.
///
/// Lets tests drive the ensemble with fully controlled backend outputs.
struct FixedBackend {
    /// Identifier reported through `Model::name`.
    name: &'static str,
    /// Entities cloned and returned verbatim from every `extract_entities` call.
    entities: Vec<Entity>,
}

impl FixedBackend {
    fn new(name: &'static str, entities: Vec<Entity>) -> Self {
        Self { name, entities }
    }
}

impl crate::sealed::Sealed for FixedBackend {}

impl crate::Model for FixedBackend {
    fn name(&self) -> &'static str {
        self.name
    }

    /// Returns a clone of the configured entities; `_text` and `_language` are ignored.
    fn extract_entities(
        &self,
        _text: &str,
        _language: Option<Language>,
    ) -> crate::Result<Vec<Entity>> {
        Ok(self.entities.clone())
    }

    /// Distinct entity types of the configured entities, in first-occurrence order.
    ///
    /// Deduplicated with a linear scan instead of collecting through a `HashSet`.
    /// A std `HashSet`'s iteration order is unspecified and varies between
    /// instances, and this fixture is used by determinism tests, so its output
    /// must not depend on hash seeds.
    fn supported_types(&self) -> Vec<EntityType> {
        let mut types: Vec<EntityType> = Vec::new();
        for entity in &self.entities {
            if !types.contains(&entity.entity_type) {
                types.push(entity.entity_type.clone());
            }
        }
        types
    }

    fn is_available(&self) -> bool {
        true
    }
}
struct AlwaysErrBackend {
name: &'static str,
}
impl AlwaysErrBackend {
fn new(name: &'static str) -> Self {
Self { name }
}
}
impl crate::sealed::Sealed for AlwaysErrBackend {}
impl crate::Model for AlwaysErrBackend {
fn name(&self) -> &'static str {
self.name
}
fn extract_entities(
&self,
_text: &str,
_language: Option<Language>,
) -> crate::Result<Vec<Entity>> {
Err(crate::Error::ModelInit(format!(
"AlwaysErrBackend '{}' intentionally failed",
self.name
)))
}
fn supported_types(&self) -> Vec<EntityType> {
vec![]
}
fn is_available(&self) -> bool {
false
}
}
/// Test double whose extraction panics instead of returning.
struct PanicBackend {
    /// Identifier reported through `Model::name` and embedded in the panic text.
    name: &'static str,
}

impl PanicBackend {
    fn new(name: &'static str) -> Self {
        PanicBackend { name }
    }
}

impl crate::sealed::Sealed for PanicBackend {}

impl crate::Model for PanicBackend {
    fn name(&self) -> &'static str {
        self.name
    }

    /// Always panics; exercises the ensemble's handling of panicking backends.
    fn extract_entities(
        &self,
        _text: &str,
        _language: Option<Language>,
    ) -> crate::Result<Vec<Entity>> {
        panic!("PanicBackend '{}' intentionally panicked", self.name)
    }

    fn supported_types(&self) -> Vec<EntityType> {
        Vec::new()
    }

    fn is_available(&self) -> bool {
        false
    }
}
/// Every backend id advertised by the default ensemble must have a weight entry.
#[test]
fn test_new_backend_ids_have_weights() {
    let ensemble = EnsembleNER::new();
    let ids = &ensemble.backend_ids;
    assert!(
        !ids.is_empty(),
        "EnsembleNER::new() should have at least one backend"
    );
    for backend_id in ids {
        let has_weight = ensemble.weights.contains_key(backend_id);
        assert!(
            has_weight,
            "EnsembleNER::new(): missing weight for backend id={:?}. This usually means the ensemble's advertised IDs drifted from default_backend_weights keys.",
            backend_id
        );
    }
}
/// The fast ensemble finds entities in a simple sentence, and each one carries provenance.
#[test]
fn test_ensemble_basic() {
    let text = "Tim Cook is the CEO of Apple Inc.";
    let found = fast_ensemble().extract_entities(text, None).unwrap();
    assert!(!found.is_empty(), "Should extract entities");
    assert!(
        found.iter().all(|entity| entity.provenance.is_some()),
        "All entities should have provenance"
    );
}
/// Partially overlapping spans match; spans with a gap between them do not.
#[test]
fn test_span_overlap() {
    let first = SpanKey { start: 0, end: 10 };
    let partial = SpanKey { start: 3, end: 15 };
    let disjoint = SpanKey { start: 20, end: 30 };

    assert!(first.overlaps(&partial), "Overlapping spans should match");
    assert!(
        !first.overlaps(&disjoint),
        "Non-overlapping spans should not match"
    );
}
/// Sanity-checks the relative ordering of default overall backend weights.
#[test]
fn test_backend_weights() {
    let table = default_backend_weights();

    let regex_overall = table["regex"].overall;
    assert!(regex_overall > 0.9);

    let gliner_overall = table["gliner"].overall;
    assert!(gliner_overall > 0.8);

    let heuristic_overall = table["heuristic"].overall;
    assert!(heuristic_overall < 0.7);
}
/// Regex should be trusted more than heuristics for dates, while heuristics
/// keep a reasonable weight for organizations.
#[test]
fn test_type_specific_weights() {
    let table = default_backend_weights();
    let regex_types = table["regex"].per_type.as_ref().unwrap();
    let heuristic_types = table["heuristic"].per_type.as_ref().unwrap();

    assert!(regex_types.date > heuristic_types.date);
    assert!(heuristic_types.organization > 0.6);
}
/// `with_agreement_bonus` stores the configured bonus.
#[test]
fn test_agreement_bonus() {
    let configured = fast_ensemble().with_agreement_bonus(0.15);
    let drift = (configured.agreement_bonus - 0.15).abs();
    assert!(drift < 0.001);
}
/// A backend that is right more often should learn a higher overall weight.
///
/// Both backends label "Apple" correctly; only gliner labels "Paris" correctly.
#[test]
fn test_weight_learner_basic() {
    let examples = vec![
        WeightTrainingExample {
            text: "Apple".to_string(),
            gold_type: EntityType::Organization,
            start: 0,
            end: 5,
            predictions: vec![
                (
                    "heuristic".to_string(),
                    EntityType::Organization,
                    Confidence::new(0.8),
                ),
                (
                    "gliner".to_string(),
                    EntityType::Organization,
                    Confidence::new(0.9),
                ),
            ],
        },
        WeightTrainingExample {
            text: "Paris".to_string(),
            gold_type: EntityType::Location,
            start: 0,
            end: 5,
            predictions: vec![
                (
                    "heuristic".to_string(),
                    EntityType::Person,
                    Confidence::new(0.6),
                ),
                (
                    "gliner".to_string(),
                    EntityType::Location,
                    Confidence::new(0.85),
                ),
            ],
        },
    ];

    let mut learner = WeightLearner::new();
    for example in &examples {
        learner.add_example(example);
    }

    let weights = learner.learn_weights();
    let overall_for = |backend: &str| weights.get(backend).map(|w| w.overall).unwrap_or(0.0);
    let gliner_weight = overall_for("gliner");
    let heuristic_weight = overall_for("heuristic");

    assert!(
        gliner_weight > heuristic_weight,
        "GLiNER should have higher weight (was {} vs {})",
        gliner_weight,
        heuristic_weight
    );
}
/// Overall and per-type precision are correct/total ratios; unseen types are expected to be ~0.
#[test]
fn test_backend_stats() {
    let mut stats = BackendStats {
        correct: 8,
        total: 10,
        ..Default::default()
    };
    // 5 of 6 PER predictions were correct.
    stats.per_type.insert(String::from("PER"), (5, 6));

    let overall = stats.precision();
    assert!((overall - 0.8).abs() < 0.01);

    let person = stats.type_precision("PER");
    assert!((person - 0.833).abs() < 0.01);

    let unseen = stats.type_precision("ORG");
    assert!(unseen.abs() < 0.01);
}
/// Empty input yields no entities.
#[test]
fn test_empty_text() {
    let found = fast_ensemble().extract_entities("", None).unwrap();
    assert!(found.is_empty());
}
/// Whitespace-only input yields no entities.
#[test]
fn test_whitespace_only_text() {
    let blank = " \t\n ";
    let found = fast_ensemble().extract_entities(blank, None).unwrap();
    assert!(found.is_empty());
}
/// Two equally weighted candidates on the same span must resolve identically
/// regardless of input order, picking the lexicographically smallest label.
#[test]
fn test_resolve_candidates_tie_break_is_order_independent() {
    let ner = fast_ensemble();
    let span_text = "Apple";
    let (start, end) = (0, 5);

    let heuristic_candidate = |entity: Entity| Candidate {
        entity,
        source: "heuristic".to_string(),
        backend_weight: 1.0,
    };
    let person = heuristic_candidate(Entity::new(span_text, EntityType::Person, start, end, 0.5));
    let org = heuristic_candidate(Entity::new(
        span_text,
        EntityType::Organization,
        start,
        end,
        0.5,
    ));

    let forward = ner
        .resolve_candidates(vec![person.clone(), org.clone()])
        .expect("should resolve");
    let reversed = ner
        .resolve_candidates(vec![org, person])
        .expect("should resolve");

    assert_eq!(
        forward.entity_type, reversed.entity_type,
        "tie resolution should not depend on candidate order"
    );

    let chosen = forward.entity_type.as_label().to_string();
    let person_label = EntityType::Person.as_label().to_string();
    let org_label = EntityType::Organization.as_label().to_string();
    let expected = person_label.min(org_label);
    assert_eq!(
        chosen, expected,
        "tie-break should choose lexicographically smallest type label"
    );
}
/// With a single regex backend, the ensemble keeps the backend's own
/// extraction method and pattern name in provenance.
#[test]
fn test_single_source_preserves_underlying_method_and_pattern() {
    let regex_only = EnsembleNER::with_backends(vec![Box::new(crate::RegexNER::new())]);
    let found = regex_only
        .extract_entities("Contact test@email.com on 2024-01-15", None)
        .expect("extract");
    assert!(!found.is_empty());

    let email = found
        .iter()
        .find(|entity| entity.text == "test@email.com")
        .expect("email entity should exist");
    let provenance = email.provenance.as_ref().expect("provenance");

    assert_eq!(provenance.method, ExtractionMethod::Pattern);
    assert!(
        provenance.pattern.is_some(),
        "expected to preserve regex pattern name"
    );
}
/// An ensemble wrapping a single-backend ensemble still reports the innermost
/// backend's extraction method.
#[test]
fn test_nested_single_source_preserves_inner_method() {
    let outer = EnsembleNER::with_backends(vec![Box::new(EnsembleNER::with_backends(vec![
        Box::new(crate::HeuristicNER::new()),
    ]))]);

    let found = outer
        .extract_entities("John Smith visited Paris.", None)
        .expect("extract");
    assert!(!found.is_empty());

    for entity in found.iter() {
        let provenance = entity.provenance.as_ref().expect("provenance");
        assert_eq!(
            provenance.method,
            ExtractionMethod::Heuristic,
            "expected outer to preserve inner method"
        );
    }
}
/// A non-empty span overlaps itself.
#[test]
fn test_span_key_self_overlap() {
    let span = SpanKey { start: 0, end: 10 };
    let overlaps_self = span.overlaps(&span);
    assert!(overlaps_self, "Span should overlap with itself");
}
/// Spans that only touch at a boundary (`end == start`) are expected not to overlap.
#[test]
fn test_span_key_adjacent_no_overlap() {
    let left = SpanKey { start: 0, end: 10 };
    let right = SpanKey { start: 10, end: 20 };
    let touching_overlaps = left.overlaps(&right);
    assert!(!touching_overlaps, "Adjacent spans should not overlap");
}
/// A span nested inside another overlaps it, checked in both directions.
#[test]
fn test_span_key_contained() {
    let enclosing = SpanKey { start: 0, end: 20 };
    let nested = SpanKey { start: 5, end: 15 };
    assert!(enclosing.overlaps(&nested), "Contained spans should overlap");
    assert!(nested.overlaps(&enclosing), "Overlap should be symmetric");
}
/// Default (empty) stats report ~0 precision overall and for any type.
#[test]
fn test_backend_stats_empty() {
    let empty = BackendStats::default();
    assert!(empty.precision().abs() < 0.001);
    assert!(empty.type_precision("ANY").abs() < 0.001);
}
/// A learner with no examples learns no weights.
#[test]
fn test_weight_learner_empty() {
    let learned = WeightLearner::new().learn_weights();
    assert!(
        learned.is_empty(),
        "empty learner should return empty weights"
    );
}
/// Passing an explicit language hint still yields entities.
#[test]
fn test_ensemble_with_language() {
    let hint = Some(Language::English);
    let found = fast_ensemble()
        .extract_entities("Tim Cook is the CEO of Apple.", hint)
        .unwrap();
    assert!(
        !found.is_empty(),
        "Should find entities with language hint"
    );
}
/// `TypeWeights` can be constructed field-by-field and its fields read back.
#[test]
fn test_type_weights_structure() {
    let per_type = TypeWeights {
        person: 0.9,
        location: 0.85,
        organization: 0.88,
        date: 0.95,
        money: 0.8,
        other: 0.7,
    };
    let person_positive = per_type.person > 0.0;
    let date_above_other = per_type.date > per_type.other;
    assert!(person_positive);
    assert!(date_above_other);
}
/// `BackendWeight` can hold an overall weight plus optional per-type weights.
#[test]
fn test_backend_weight_structure() {
    let per_type = TypeWeights {
        person: 0.9,
        location: 0.88,
        organization: 0.87,
        date: 0.92,
        money: 0.85,
        other: 0.75,
    };
    let backend = BackendWeight {
        overall: 0.85,
        per_type: Some(per_type),
    };
    assert!(backend.overall > 0.0);
    assert!(backend.per_type.is_some());
}
/// Non-Latin input must not produce out-of-range confidences.
#[test]
fn test_unicode_extraction() {
    let found = EnsembleNER::new()
        .extract_entities("東京で会議がありました。", None)
        .unwrap();
    assert!(found
        .iter()
        .all(|e| e.confidence >= 0.0 && e.confidence <= 1.0));
}
/// Every entity from the default ensemble carries provenance with a non-empty source.
#[test]
fn test_ensemble_provenance_tracking() {
    let found = EnsembleNER::new()
        .extract_entities("Barack Obama visited Paris yesterday.", None)
        .unwrap();
    for entity in found.iter() {
        let provenance = match entity.provenance.as_ref() {
            Some(p) => p,
            None => panic!(
                "Entity '{}' ({:?}) at {}..{} has no provenance",
                entity.text,
                entity.entity_type,
                entity.start(),
                entity.end()
            ),
        };
        assert!(!provenance.source.is_empty());
    }
}
/// Builds an ensemble of three fixed-output backends whose entities occupy
/// disjoint 4-character spans between offsets 0 and 29, so no two backends
/// compete for the same span.
fn determinism_ensemble() -> EnsembleNER {
    EnsembleNER::with_backends(vec![
        Box::new(FixedBackend::new(
            "backend-a",
            vec![
                Entity::new("aaaa", EntityType::Person, 0, 4, 0.80),
                Entity::new("bbbb", EntityType::Organization, 5, 9, 0.75),
                Entity::new("cccc", EntityType::Date, 10, 14, 0.90),
            ],
        )),
        Box::new(FixedBackend::new(
            "backend-b",
            vec![
                Entity::new("dddd", EntityType::Person, 15, 19, 0.70),
                Entity::new("eeee", EntityType::Location, 20, 24, 0.65),
            ],
        )),
        Box::new(FixedBackend::new(
            "backend-c",
            vec![Entity::new("ffff", EntityType::Money, 25, 29, 0.85)],
        )),
    ])
}
/// Repeated runs over the same input must produce the same
/// `(start, end, label, confidence)` sequence, in the same order.
#[test]
fn test_parallel_execution_is_deterministic() {
    let text = "aaaa bbbb cccc dddd eeee ffff";
    let ner = determinism_ensemble();

    // Reduce each entity to the fields that must stay stable across runs.
    let fingerprint = |entities: Vec<Entity>| -> Vec<(usize, usize, String, Confidence)> {
        entities
            .into_iter()
            .map(|e| {
                (
                    e.start(),
                    e.end(),
                    e.entity_type.as_label().to_string(),
                    e.confidence,
                )
            })
            .collect()
    };

    let reference = fingerprint(
        ner.extract_entities(text, None)
            .expect("first run should succeed"),
    );
    assert!(
        !reference.is_empty(),
        "determinism ensemble should produce at least one entity"
    );

    for run in 1..10_usize {
        let outcome = ner
            .extract_entities(text, None)
            .unwrap_or_else(|e| panic!("run {} failed: {}", run, e));
        let result = fingerprint(outcome);

        assert_eq!(
            result.len(),
            reference.len(),
            "run {} produced {} entities, expected {}",
            run,
            result.len(),
            reference.len()
        );
        for (idx, (got, want)) in result.iter().zip(&reference).enumerate() {
            assert_eq!(
                got, want,
                "run {} entity[{}]: got {:?}, want {:?}",
                run, idx, got, want
            );
        }
    }
}
/// Two backends disagreeing on the type of the same span must resolve to a
/// single entity, with the same type on every run.
#[test]
fn test_parallel_determinism_with_overlapping_spans() {
    let text = "Apple";
    let ner = EnsembleNER::with_backends(vec![
        Box::new(FixedBackend::new(
            "gliner",
            vec![Entity::new("Apple", EntityType::Organization, 0, 5, 0.80)],
        )),
        Box::new(FixedBackend::new(
            "heuristic",
            vec![Entity::new("Apple", EntityType::Person, 0, 5, 0.80)],
        )),
    ]);

    let first = ner.extract_entities(text, None).expect("first run");
    assert_eq!(first.len(), 1, "should resolve to exactly one entity");
    let reference_type = first[0].entity_type.clone();

    for run in 1..10_usize {
        let resolved = ner
            .extract_entities(text, None)
            .unwrap_or_else(|e| panic!("run {} failed: {}", run, e));
        assert_eq!(
            resolved.len(),
            1,
            "run {} produced {} entities, expected 1",
            run,
            resolved.len()
        );
        assert_eq!(
            resolved[0].entity_type, reference_type,
            "run {} resolved to {:?}, expected {:?}",
            run, resolved[0].entity_type, reference_type
        );
    }
}
/// A failing backend between two healthy ones is skipped; both healthy
/// backends' entities still appear in the output.
#[test]
fn test_failing_backend_is_skipped_and_others_produce_results() {
    let ner = EnsembleNER::with_backends(vec![
        Box::new(FixedBackend::new(
            "good-a",
            vec![Entity::new("Paris", EntityType::Location, 0, 5, 0.85)],
        )),
        Box::new(AlwaysErrBackend::new("always-err")),
        Box::new(FixedBackend::new(
            "good-b",
            vec![Entity::new("March", EntityType::Date, 6, 11, 0.90)],
        )),
    ]);

    let entities = ner
        .extract_entities("Paris March", None)
        .expect("ensemble should not propagate backend errors");
    assert_eq!(
        entities.len(),
        2,
        "expected 2 entities from healthy backends, got: {:?}",
        entities
            .iter()
            .map(|e| format!("{}:{:?}", e.text, e.entity_type))
            .collect::<Vec<_>>()
    );

    let texts: Vec<&str> = entities.iter().map(|e| e.text.as_str()).collect();
    for expected in ["Paris", "March"] {
        assert!(
            texts.contains(&expected),
            "expected '{}' in output, got {:?}",
            expected,
            texts
        );
    }
}
/// When every backend fails, the ensemble returns `Ok` with no entities
/// instead of an error.
#[test]
fn test_all_backends_fail_returns_empty() {
    let ner = EnsembleNER::with_backends(vec![
        Box::new(AlwaysErrBackend::new("err-1")),
        Box::new(AlwaysErrBackend::new("err-2")),
    ]);

    let outcome = ner.extract_entities("Anything at all", None);
    assert!(
        outcome.is_ok(),
        "ensemble should return Ok even when all backends fail"
    );
    let entities = outcome.unwrap();
    assert!(
        entities.is_empty(),
        "ensemble should return empty vec when all backends fail"
    );
}
/// Checks two cases.
///
/// 1. A lone failing backend yields `Ok` with no entities; the error is not propagated.
/// 2. The same failing backend paired with one healthy backend still surfaces
///    the healthy backend's entity, which is the scenario the test name describes.
#[test]
fn test_single_failing_backend_with_single_good_backend() {
    let text = "Tim Cook";

    // Case 1: failing backend on its own.
    let only_err = EnsembleNER::with_backends(vec![Box::new(AlwaysErrBackend::new("only-err"))]);
    let result = only_err.extract_entities(text, None);
    assert!(
        result.is_ok(),
        "single failing backend must not propagate as Err"
    );
    assert!(result.unwrap().is_empty());

    // Case 2: failing backend next to a single healthy backend.
    let mixed = EnsembleNER::with_backends(vec![
        Box::new(AlwaysErrBackend::new("only-err")),
        Box::new(FixedBackend::new(
            "only-good",
            vec![Entity::new("Tim Cook", EntityType::Person, 0, 8, 0.85)],
        )),
    ]);
    let entities = mixed
        .extract_entities(text, None)
        .expect("failing backend must not propagate as Err alongside a healthy backend");
    let texts: Vec<&str> = entities.iter().map(|e| e.text.as_str()).collect();
    assert!(
        texts.contains(&"Tim Cook"),
        "expected the healthy backend's entity 'Tim Cook', got {:?}",
        texts
    );
}
/// Adding a failing backend must not change the count or the confidence of
/// entities produced by a healthy backend.
#[test]
fn test_error_backend_does_not_affect_confidence_of_good_results() {
    let text = "London";
    let entity = Entity::new("London", EntityType::Location, 0, 6, 0.80);
    let solo = EnsembleNER::with_backends(vec![Box::new(FixedBackend::new(
        "solo",
        vec![entity.clone()],
    ))]);
    let with_err = EnsembleNER::with_backends(vec![
        Box::new(FixedBackend::new("solo", vec![entity])),
        Box::new(AlwaysErrBackend::new("noise")),
    ]);

    let solo_result = solo.extract_entities(text, None).unwrap();
    let err_result = with_err.extract_entities(text, None).unwrap();
    assert_eq!(
        solo_result.len(),
        err_result.len(),
        "entity count should be identical regardless of failing backend"
    );
    // Guard the indexing below: if both runs returned nothing, the length check
    // above passes vacuously and `[0]` would panic with an opaque index error.
    assert!(
        !solo_result.is_empty(),
        "solo backend should return its fixed entity"
    );
    assert!(
        (solo_result[0].confidence - err_result[0].confidence).abs() < 1e-9,
        "confidence differed: solo={} with_err={}",
        solo_result[0].confidence,
        err_result[0].confidence
    );
}
/// A panicking backend is expected to surface as `Error::Inference` whose
/// message mentions the panic, not as an unwound test panic.
#[test]
fn test_panicking_backend_returns_inference_error() {
    let ner = EnsembleNER::with_backends(vec![Box::new(PanicBackend::new("boom"))]);
    match ner.extract_entities("any text", None) {
        Err(crate::Error::Inference(msg)) => assert!(
            msg.contains("panicked"),
            "expected panic-related message, got: {msg}"
        ),
        Err(other) => panic!("expected Err(Error::Inference), got {other:?}"),
        Ok(v) => panic!("expected Err(Error::Inference), got Ok({v:?})"),
    }
}