1use std::borrow::Cow;
112use std::collections::HashMap;
113use std::sync::Arc;
114
115use crate::{Entity, EntityType, Model, Result};
116
117fn method_for_backend_id(backend_id: &str) -> anno_core::ExtractionMethod {
118 match backend_id {
119 "regex" => anno_core::ExtractionMethod::Pattern,
121 "heuristic" => anno_core::ExtractionMethod::Heuristic,
122 "rule" => anno_core::ExtractionMethod::Heuristic,
124 _ => anno_core::ExtractionMethod::Neural,
126 }
127}
128
/// Voting weight for a single backend in the ensemble.
#[derive(Debug, Clone, Copy)]
pub struct BackendWeight {
    /// Baseline weight used when no per-type override applies.
    pub overall: f64,
    /// Optional per-entity-type weights; when `Some`, these fully replace
    /// `overall` in weight lookups (see `EnsembleNER::get_weight`).
    pub per_type: Option<TypeWeights>,
}
143
144impl Default for BackendWeight {
145 fn default() -> Self {
146 Self {
147 overall: 0.5,
148 per_type: None,
149 }
150 }
151}
152
/// Per-entity-type voting weights; any type without a dedicated field
/// falls into `other`.
#[derive(Debug, Clone, Copy, Default)]
pub struct TypeWeights {
    /// Weight for `EntityType::Person`.
    pub person: f64,
    /// Weight for `EntityType::Organization`.
    pub organization: f64,
    /// Weight for `EntityType::Location`.
    pub location: f64,
    /// Weight for `EntityType::Date`.
    pub date: f64,
    /// Weight for `EntityType::Money`.
    pub money: f64,
    /// Fallback weight for all other entity types.
    pub other: f64,
}
172
173impl TypeWeights {
174 fn get(&self, entity_type: &EntityType) -> f64 {
175 match entity_type {
176 EntityType::Person => self.person,
177 EntityType::Organization => self.organization,
178 EntityType::Location => self.location,
179 EntityType::Date => self.date,
180 EntityType::Money => self.money,
181 _ => self.other,
182 }
183 }
184}
185
186fn default_backend_weights() -> HashMap<&'static str, BackendWeight> {
188 let mut weights = HashMap::new();
189
190 weights.insert(
192 "regex",
193 BackendWeight {
194 overall: 0.98,
195 per_type: Some(TypeWeights {
196 date: 0.99,
197 money: 0.99,
198 person: 0.50, organization: 0.50,
200 location: 0.50,
201 other: 0.95, }),
203 },
204 );
205
206 weights.insert(
208 "gliner",
209 BackendWeight {
210 overall: 0.85,
211 per_type: Some(TypeWeights {
212 person: 0.90,
213 organization: 0.85,
214 location: 0.80,
215 date: 0.75,
216 money: 0.70,
217 other: 0.75,
218 }),
219 },
220 );
221 weights.insert(
222 "GLiNER-ONNX",
223 BackendWeight {
224 overall: 0.85,
225 per_type: Some(TypeWeights {
226 person: 0.90,
227 organization: 0.85,
228 location: 0.80,
229 date: 0.75,
230 money: 0.70,
231 other: 0.75,
232 }),
233 },
234 );
235
236 weights.insert(
238 "gliner-candle",
239 BackendWeight {
240 overall: 0.85,
241 per_type: None,
242 },
243 );
244
245 weights.insert(
247 "bert-ner-onnx",
248 BackendWeight {
249 overall: 0.80,
250 per_type: None,
251 },
252 );
253
254 weights.insert(
256 "heuristic",
257 BackendWeight {
258 overall: 0.60,
259 per_type: Some(TypeWeights {
260 person: 0.65, organization: 0.70, location: 0.55, date: 0.40, money: 0.40,
265 other: 0.50,
266 }),
267 },
268 );
269
270 weights
271}
272
/// One backend's vote: an extracted entity plus which backend produced it
/// and how much that backend's opinion counts for this entity type.
#[derive(Debug, Clone)]
struct Candidate {
    /// The entity as extracted by the backend.
    entity: Entity,
    /// Backend id (e.g. "regex", "heuristic"); used in provenance strings.
    source: String,
    /// Weight resolved from the ensemble's weight table for this entity's type.
    backend_weight: f64,
}
284
/// A (start, end) span used to group candidates that refer to the same
/// region of text. `Eq`/`Hash` support exact-span keying; fuzzy matching
/// goes through `overlaps`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct SpanKey {
    start: usize,
    end: usize,
}
297
298impl SpanKey {
299 fn from_entity(e: &Entity) -> Self {
300 Self {
301 start: e.start,
302 end: e.end,
303 }
304 }
305
306 fn overlaps(&self, other: &SpanKey) -> bool {
308 let overlap_start = self.start.max(other.start);
309 let overlap_end = self.end.min(other.end);
310
311 if overlap_start >= overlap_end {
312 return false;
313 }
314
315 let overlap = overlap_end - overlap_start;
316 let smaller_span = (self.end - self.start).min(other.end - other.start);
317
318 (overlap as f64 / smaller_span as f64) > 0.5
320 }
321}
322
/// Weighted-voting ensemble over multiple NER backends.
///
/// Each backend's entities become weighted votes; overlapping spans are
/// grouped and each group is resolved to a single entity with consensus
/// provenance and an agreement-based confidence bonus.
pub struct EnsembleNER {
    /// The underlying models, queried in order.
    backends: Vec<Arc<dyn Model + Send + Sync>>,
    /// Ids positionally paired with `backends`; used for weight lookups.
    backend_ids: Vec<String>,
    /// Per-backend voting weights keyed by backend id.
    weights: HashMap<String, BackendWeight>,
    /// Confidence bonus when 2+ backends agree (scaled 1.5x for 3+).
    agreement_bonus: f64,
    /// Resolved entities below this confidence are dropped.
    min_confidence: f64,
    /// Human-readable name, e.g. "ensemble(regex|heuristic)".
    name: String,
    /// Memoized `&'static` copy of `name` (leaked once) for `Model::name`.
    name_static: std::sync::OnceLock<&'static str>,
}
382
383impl Default for EnsembleNER {
384 fn default() -> Self {
385 Self::new()
386 }
387}
388
impl EnsembleNER {
    /// Build the default ensemble.
    ///
    /// Always includes the regex and heuristic backends; GLiNER backends are
    /// added only when the corresponding cargo feature is enabled AND the
    /// model loads successfully (load failures are silently skipped).
    #[must_use]
    pub fn new() -> Self {
        let mut backends: Vec<Arc<dyn Model + Send + Sync>> = Vec::new();
        let mut backend_ids: Vec<&'static str> = Vec::new();

        backends.push(Arc::new(crate::RegexNER::new()));
        backend_ids.push("regex");

        #[cfg(feature = "onnx")]
        {
            use super::GLiNEROnnx;
            use crate::DEFAULT_GLINER_MODEL;
            if let Ok(gliner) = GLiNEROnnx::new(DEFAULT_GLINER_MODEL) {
                backends.push(Arc::new(gliner));
                backend_ids.push("gliner");
            }
        }

        #[cfg(feature = "candle")]
        {
            use super::GLiNERCandle;
            use crate::DEFAULT_GLINER_MODEL;
            if let Ok(candle) = GLiNERCandle::from_pretrained(DEFAULT_GLINER_MODEL) {
                backends.push(Arc::new(candle));
                backend_ids.push("gliner-candle");
            }
        }

        backends.push(Arc::new(crate::HeuristicNER::new()));
        backend_ids.push("heuristic");

        let name = format!("ensemble({})", backend_ids.join("|"));

        // The default table is keyed by &'static str; convert to owned keys.
        let weights: HashMap<String, BackendWeight> = default_backend_weights()
            .into_iter()
            .map(|(k, v)| (k.to_string(), v))
            .collect();

        Self {
            backends,
            backend_ids: backend_ids.into_iter().map(str::to_string).collect(),
            weights,
            agreement_bonus: 0.10,
            min_confidence: 0.30,
            name,
            name_static: std::sync::OnceLock::new(),
        }
    }

    /// Build an ensemble from caller-supplied backends.
    ///
    /// Backend ids are taken from each backend's `name()`. Weights come from
    /// the default table; names missing from it fall back to the neutral
    /// 0.50 in `get_weight`.
    #[must_use]
    pub fn with_backends(backends: Vec<Box<dyn Model + Send + Sync>>) -> Self {
        let backend_ids: Vec<String> = backends.iter().map(|b| b.name().to_string()).collect();
        let name = format!("ensemble({})", backend_ids.join("|"));

        let backends: Vec<Arc<dyn Model + Send + Sync>> =
            backends.into_iter().map(Arc::from).collect();

        let weights: HashMap<String, BackendWeight> = default_backend_weights()
            .into_iter()
            .map(|(k, v)| (k.to_string(), v))
            .collect();

        Self {
            backends,
            backend_ids,
            weights,
            agreement_bonus: 0.10,
            min_confidence: 0.30,
            name,
            name_static: std::sync::OnceLock::new(),
        }
    }

    /// Replace the weight table (e.g. with weights from [`WeightLearner`]).
    #[must_use]
    pub fn with_weights(mut self, weights: HashMap<String, BackendWeight>) -> Self {
        self.weights = weights;
        self
    }

    /// Set the confidence bonus applied when multiple backends agree.
    #[must_use]
    pub fn with_agreement_bonus(mut self, bonus: f64) -> Self {
        self.agreement_bonus = bonus;
        self
    }

    /// Set the minimum confidence below which resolved entities are dropped.
    #[must_use]
    pub fn with_min_confidence(mut self, min: f64) -> Self {
        self.min_confidence = min;
        self
    }

    /// Weight for one backend/type pair: the per-type override when present,
    /// else the backend's overall weight, else a neutral 0.50 for backends
    /// absent from the table.
    fn get_weight(&self, backend_name: &str, entity_type: &EntityType) -> f64 {
        if let Some(weight) = self.weights.get(backend_name) {
            if let Some(ref type_weights) = weight.per_type {
                type_weights.get(entity_type)
            } else {
                weight.overall
            }
        } else {
            0.50
        }
    }

    /// Collapse one overlap group of candidates into a single entity.
    ///
    /// Single candidate: keep it, dampen confidence slightly (no
    /// corroboration) and wrap provenance as "ensemble(<source>)" while
    /// preserving the inner method/pattern. Multiple candidates: weighted
    /// vote by entity type with a deterministic, order-independent tie-break,
    /// then an agreement bonus and hierarchical confidence.
    fn resolve_candidates(&self, candidates: Vec<Candidate>) -> Option<Entity> {
        if candidates.is_empty() {
            return None;
        }

        if candidates.len() == 1 {
            let candidate = candidates
                .into_iter()
                .next()
                .expect("candidates.len() == 1 guarantees next() is Some");
            let mut entity = candidate.entity;
            let original_prov = entity.provenance.clone();
            let original_confidence = entity.confidence;
            // Slight penalty: only one backend saw this span.
            entity.confidence *= 0.95;
            entity.provenance = Some(anno_core::Provenance {
                source: std::borrow::Cow::Owned(format!("ensemble({})", candidate.source)),
                // Preserve the inner method (e.g. Pattern for regex) so
                // nested ensembles don't erase provenance detail.
                method: original_prov
                    .as_ref()
                    .map(|p| p.method)
                    .unwrap_or_else(|| method_for_backend_id(&candidate.source)),
                pattern: original_prov.as_ref().and_then(|p| p.pattern.clone()),
                raw_confidence: original_prov
                    .as_ref()
                    .and_then(|p| p.raw_confidence)
                    .or(Some(original_confidence)),
                model_version: None,
                timestamp: None,
            });
            return Some(entity);
        }

        // Bucket votes by entity-type label.
        let mut type_votes: HashMap<String, Vec<&Candidate>> = HashMap::new();
        for c in &candidates {
            let type_key = c.entity.entity_type.as_label().to_string();
            type_votes.entry(type_key).or_default().push(c);
        }

        // Pick the winning type: highest weighted confidence sum, then
        // highest vote count, then lexicographically smallest label — a
        // total order independent of HashMap iteration / input order.
        let mut best_type: Option<(String, f64, usize, Vec<&Candidate>)> = None;
        for (type_key, type_candidates) in &type_votes {
            let weighted_sum: f64 = type_candidates
                .iter()
                .map(|c| c.backend_weight * c.entity.confidence)
                .sum();
            let count = type_candidates.len();

            let should_replace = match &best_type {
                None => true,
                Some((best_key, best_sum, best_count, _)) => {
                    if weighted_sum > *best_sum {
                        true
                    } else if weighted_sum < *best_sum {
                        false
                    } else if count > *best_count {
                        true
                    } else if count < *best_count {
                        false
                    } else {
                        type_key < best_key
                    }
                }
            };

            if should_replace {
                best_type = Some((
                    type_key.clone(),
                    weighted_sum,
                    count,
                    type_candidates.clone(),
                ));
            }
        }

        let (_type_key, weighted_sum, _count, winning_candidates) = best_type?;

        let num_sources = winning_candidates.len();
        let total_weight: f64 = winning_candidates.iter().map(|c| c.backend_weight).sum();

        // Weighted mean of the winning candidates' confidences.
        let base_confidence = if total_weight > 0.0 {
            weighted_sum / total_weight
        } else {
            0.5
        };

        // Full bonus for 2 agreeing sources, 1.5x for 3 or more.
        let agreement_bonus = if num_sources >= 3 {
            self.agreement_bonus * 1.5
        } else if num_sources >= 2 {
            self.agreement_bonus
        } else {
            0.0
        };

        let final_confidence = (base_confidence + agreement_bonus).min(1.0);

        // The most confident winner supplies the representative text/span.
        let best_candidate = winning_candidates.iter().max_by(|a, b| {
            a.entity
                .confidence
                .partial_cmp(&b.entity.confidence)
                .unwrap_or(std::cmp::Ordering::Equal)
        })?;

        let sources: Vec<String> = winning_candidates
            .iter()
            .map(|c| c.source.clone())
            .collect();

        // Linkage: what fraction of ALL candidates voted for the winner.
        let total_candidates = candidates.len() as f32;
        let num_winners = winning_candidates.len() as f32;

        let linkage = if total_candidates > 0.0 {
            (num_winners / total_candidates).min(1.0)
        } else {
            0.5
        };

        let type_score = final_confidence as f32;

        // Boundary: fraction of winners whose span exactly matches the
        // representative's span.
        let reference_span = (best_candidate.entity.start, best_candidate.entity.end);
        let span_agreement_count = winning_candidates
            .iter()
            .filter(|c| c.entity.start == reference_span.0 && c.entity.end == reference_span.1)
            .count();
        let boundary = if num_winners > 0.0 {
            (span_agreement_count as f32 / num_winners).min(1.0)
        } else {
            1.0
        };

        let mut entity = best_candidate.entity.clone();
        entity.confidence = final_confidence;
        entity.hierarchical_confidence = Some(anno_core::HierarchicalConfidence::new(
            linkage, type_score, boundary,
        ));
        entity.provenance = Some(anno_core::Provenance {
            source: Cow::Owned(format!("ensemble({})", sources.join("+"))),
            method: anno_core::ExtractionMethod::Consensus,
            pattern: None,
            raw_confidence: Some(base_confidence),
            model_version: None,
            timestamp: None,
        });

        Some(entity)
    }
}
679
impl Model for EnsembleNER {
    /// Run every backend, group overlapping candidate spans, and resolve each
    /// group to one entity via weighted voting.
    ///
    /// A failing backend is logged at debug level and skipped, so one broken
    /// backend cannot take down the whole ensemble.
    fn extract_entities(&self, text: &str, language: Option<&str>) -> Result<Vec<Entity>> {
        if self.backends.is_empty() {
            return Ok(Vec::new());
        }

        let mut all_candidates: Vec<Candidate> = Vec::new();

        for (i, backend) in self.backends.iter().enumerate() {
            // Ids are positionally paired with backends; fall back to the
            // backend's own name if the two vectors ever get out of sync.
            let backend_id = self
                .backend_ids
                .get(i)
                .cloned()
                .unwrap_or_else(|| backend.name().to_string());

            match backend.extract_entities(text, language) {
                Ok(entities) => {
                    for entity in entities {
                        let weight = self.get_weight(&backend_id, &entity.entity_type);
                        all_candidates.push(Candidate {
                            entity,
                            source: backend_id.clone(),
                            backend_weight: weight,
                        });
                    }
                }
                Err(e) => {
                    log::debug!(
                        "EnsembleNER: Backend {} (id={}) failed: {}",
                        backend.name(),
                        backend_id,
                        e
                    );
                }
            }
        }

        if all_candidates.is_empty() {
            return Ok(Vec::new());
        }

        // Greedy grouping: a candidate joins the first group whose
        // representative (the group's FIRST member) it overlaps by >50%;
        // otherwise it starts a new group.
        let mut span_groups: Vec<Vec<Candidate>> = Vec::new();

        for candidate in all_candidates {
            let span = SpanKey::from_entity(&candidate.entity);

            let mut found_group = false;
            for group in &mut span_groups {
                if let Some(first) = group.first() {
                    let existing_span = SpanKey::from_entity(&first.entity);
                    if span.overlaps(&existing_span) {
                        group.push(candidate.clone());
                        found_group = true;
                        break;
                    }
                }
            }

            if !found_group {
                span_groups.push(vec![candidate]);
            }
        }

        // Resolve each group and drop low-confidence results.
        let mut results: Vec<Entity> = Vec::new();

        for group in span_groups {
            if let Some(entity) = self.resolve_candidates(group) {
                if entity.confidence >= self.min_confidence {
                    results.push(entity);
                }
            }
        }

        results.sort_by_key(|e| (e.start, e.end));

        Ok(results)
    }

    /// Union of all backends' supported types, preserving first-seen order.
    fn supported_types(&self) -> Vec<EntityType> {
        let mut types: Vec<EntityType> = Vec::new();
        for backend in &self.backends {
            for t in backend.supported_types() {
                if !types.contains(&t) {
                    types.push(t);
                }
            }
        }
        types
    }

    /// Available when at least one backend is available.
    fn is_available(&self) -> bool {
        self.backends.iter().any(|b| b.is_available())
    }

    /// The trait demands `&'static str`, but the name is built at runtime;
    /// leak one copy via `Box::leak`, memoized in a `OnceLock` so the leak
    /// happens at most once per instance.
    fn name(&self) -> &'static str {
        self.name_static
            .get_or_init(|| Box::leak(self.name.clone().into_boxed_str()))
    }

    fn description(&self) -> &'static str {
        "Ensemble NER: weighted voting across multiple backends"
    }
}
792
// Marker trait impl: the ensemble extracts named entities (no extra methods).
impl crate::NamedEntityCapable for EnsembleNER {}
795
796impl crate::BatchCapable for EnsembleNER {
797 fn optimal_batch_size(&self) -> Option<usize> {
798 Some(8) }
800}
801
802impl crate::StreamingCapable for EnsembleNER {
803 fn recommended_chunk_size(&self) -> usize {
804 8192
805 }
806}
807
/// One gold-labeled span plus what each backend predicted for it.
#[derive(Debug, Clone)]
pub struct WeightTrainingExample {
    /// Surface text of the gold entity.
    pub text: String,
    /// Gold entity type.
    pub gold_type: EntityType,
    /// Gold span start (same offset convention as `Entity::start`).
    pub start: usize,
    /// Gold span end (same offset convention as `Entity::end`).
    pub end: usize,
    /// (backend name, predicted type, confidence) for each backend that
    /// produced a prediction over this span.
    pub predictions: Vec<(String, EntityType, f64)>,
}
826
/// Running correctness tallies for a single backend.
#[derive(Debug, Clone, Default)]
pub struct BackendStats {
    /// Predictions whose type matched the gold type.
    pub correct: usize,
    /// Total predictions evaluated.
    pub total: usize,
    /// Per gold-type-label tallies as (correct, total).
    pub per_type: HashMap<String, (usize, usize)>,
}
837
838impl BackendStats {
839 pub fn precision(&self) -> f64 {
841 if self.total == 0 {
842 0.0
843 } else {
844 self.correct as f64 / self.total as f64
845 }
846 }
847
848 pub fn type_precision(&self, entity_type: &str) -> f64 {
850 if let Some((correct, total)) = self.per_type.get(entity_type) {
851 if *total == 0 {
852 0.0
853 } else {
854 *correct as f64 / *total as f64
855 }
856 } else {
857 0.0
858 }
859 }
860}
861
/// Learns per-backend ensemble weights from labeled examples using
/// Laplace-smoothed precision.
pub struct WeightLearner {
    /// Accumulated tallies keyed by backend name.
    backend_stats: HashMap<String, BackendStats>,
    /// Laplace smoothing constant (pseudo-counts added to correct/total).
    smoothing: f64,
}
890
891impl Default for WeightLearner {
892 fn default() -> Self {
893 Self::new()
894 }
895}
896
897impl WeightLearner {
898 #[must_use]
900 pub fn new() -> Self {
901 Self {
902 backend_stats: HashMap::new(),
903 smoothing: 1.0, }
905 }
906
907 #[must_use]
909 pub fn with_smoothing(mut self, smoothing: f64) -> Self {
910 self.smoothing = smoothing;
911 self
912 }
913
914 pub fn add_example(&mut self, example: &WeightTrainingExample) {
916 for (backend_name, predicted_type, _confidence) in &example.predictions {
917 let stats = self.backend_stats.entry(backend_name.clone()).or_default();
918
919 stats.total += 1;
920 let correct = *predicted_type == example.gold_type;
921 if correct {
922 stats.correct += 1;
923 }
924
925 let type_key = example.gold_type.as_label().to_string();
927 let type_stats = stats.per_type.entry(type_key).or_insert((0, 0));
928 type_stats.1 += 1;
929 if correct {
930 type_stats.0 += 1;
931 }
932 }
933 }
934
935 pub fn add_from_backends(
939 &mut self,
940 text: &str,
941 gold_entities: &[Entity],
942 backends: &[(&str, &dyn Model)],
943 ) {
944 let mut backend_preds: HashMap<String, Vec<Entity>> = HashMap::new();
946 for (name, backend) in backends {
947 if let Ok(entities) = backend.extract_entities(text, None) {
948 backend_preds.insert(name.to_string(), entities);
949 }
950 }
951
952 for gold in gold_entities {
954 let mut example = WeightTrainingExample {
955 text: gold.text.clone(),
956 gold_type: gold.entity_type.clone(),
957 start: gold.start,
958 end: gold.end,
959 predictions: Vec::new(),
960 };
961
962 for (backend_name, entities) in &backend_preds {
963 for pred in entities {
965 if pred.start == gold.start && pred.end == gold.end {
966 example.predictions.push((
967 backend_name.clone(),
968 pred.entity_type.clone(),
969 pred.confidence,
970 ));
971 break;
972 }
973 }
974 }
975
976 if !example.predictions.is_empty() {
977 self.add_example(&example);
978 }
979 }
980 }
981
982 pub fn learn_weights(&self) -> HashMap<String, BackendWeight> {
986 let mut weights = HashMap::new();
987
988 for (backend_name, stats) in &self.backend_stats {
989 let smoothed_precision = (stats.correct as f64 + self.smoothing)
991 / (stats.total as f64 + 2.0 * self.smoothing);
992
993 let mut type_weights = TypeWeights::default();
995 for (type_key, (correct, total)) in &stats.per_type {
996 let type_precision =
997 (*correct as f64 + self.smoothing) / (*total as f64 + 2.0 * self.smoothing);
998
999 match type_key.as_str() {
1000 "PER" | "PERSON" => type_weights.person = type_precision,
1001 "ORG" | "ORGANIZATION" => type_weights.organization = type_precision,
1002 "LOC" | "LOCATION" | "GPE" => type_weights.location = type_precision,
1003 "DATE" => type_weights.date = type_precision,
1004 "MONEY" => type_weights.money = type_precision,
1005 _ => type_weights.other = type_precision,
1006 }
1007 }
1008
1009 weights.insert(
1010 backend_name.clone(),
1011 BackendWeight {
1012 overall: smoothed_precision,
1013 per_type: Some(type_weights),
1014 },
1015 );
1016 }
1017
1018 weights
1019 }
1020
1021 pub fn get_stats(&self, backend_name: &str) -> Option<&BackendStats> {
1023 self.backend_stats.get(backend_name)
1024 }
1025
1026 pub fn backend_names(&self) -> Vec<&String> {
1028 self.backend_stats.keys().collect()
1029 }
1030}
1031
#[cfg(test)]
mod tests {
    use super::*;
    use anno_core::ExtractionMethod;

    /// Cheap two-backend ensemble (regex + heuristic) used by most tests to
    /// avoid loading neural models.
    fn fast_ensemble() -> EnsembleNER {
        EnsembleNER::with_backends(vec![
            Box::new(crate::RegexNER::new()),
            Box::new(crate::HeuristicNER::new()),
        ])
    }

    // Guards against drift between EnsembleNER::new()'s advertised backend
    // ids and the keys of default_backend_weights().
    #[test]
    fn test_new_backend_ids_have_weights() {
        let ner = EnsembleNER::new();

        assert!(
            !ner.backend_ids.is_empty(),
            "EnsembleNER::new() should have at least one backend"
        );

        for id in &ner.backend_ids {
            assert!(
                ner.weights.contains_key(id),
                "EnsembleNER::new(): missing weight for backend id={:?}. This usually means the ensemble's advertised IDs drifted from default_backend_weights keys.",
                id
            );
        }
    }

    #[test]
    fn test_ensemble_basic() {
        let ner = fast_ensemble();
        let entities = ner
            .extract_entities("Tim Cook is the CEO of Apple Inc.", None)
            .unwrap();

        assert!(!entities.is_empty(), "Should extract entities");

        for e in &entities {
            assert!(
                e.provenance.is_some(),
                "All entities should have provenance"
            );
        }
    }

    #[test]
    fn test_span_overlap() {
        let span1 = SpanKey { start: 0, end: 10 };
        let span2 = SpanKey { start: 3, end: 15 }; let span3 = SpanKey { start: 20, end: 30 };

        assert!(span1.overlaps(&span2), "Overlapping spans should match");
        assert!(
            !span1.overlaps(&span3),
            "Non-overlapping spans should not match"
        );
    }

    #[test]
    fn test_backend_weights() {
        let weights = default_backend_weights();

        assert!(weights["regex"].overall > 0.9);

        assert!(weights["gliner"].overall > 0.8);

        assert!(weights["heuristic"].overall < 0.7);
    }

    #[test]
    fn test_type_specific_weights() {
        let weights = default_backend_weights();

        let pattern_date = weights["regex"].per_type.as_ref().unwrap().date;
        let heuristic_date = weights["heuristic"].per_type.as_ref().unwrap().date;
        assert!(pattern_date > heuristic_date);

        let heuristic_org = weights["heuristic"].per_type.as_ref().unwrap().organization;
        assert!(heuristic_org > 0.6);
    }

    #[test]
    fn test_agreement_bonus() {
        let ner = fast_ensemble().with_agreement_bonus(0.15);
        assert!((ner.agreement_bonus - 0.15).abs() < 0.001);
    }

    #[test]
    fn test_weight_learner_basic() {
        let mut learner = WeightLearner::new();

        // gliner: 2/2 correct; heuristic: 1/2 correct — gliner must come
        // out with the higher learned overall weight.
        learner.add_example(&WeightTrainingExample {
            text: "Apple".to_string(),
            gold_type: EntityType::Organization,
            start: 0,
            end: 5,
            predictions: vec![
                ("heuristic".to_string(), EntityType::Organization, 0.8),
                ("gliner".to_string(), EntityType::Organization, 0.9),
            ],
        });

        learner.add_example(&WeightTrainingExample {
            text: "Paris".to_string(),
            gold_type: EntityType::Location,
            start: 0,
            end: 5,
            predictions: vec![
                ("heuristic".to_string(), EntityType::Person, 0.6), ("gliner".to_string(), EntityType::Location, 0.85),
            ],
        });

        let weights = learner.learn_weights();

        let gliner_weight = weights.get("gliner").map(|w| w.overall).unwrap_or(0.0);
        let heuristic_weight = weights.get("heuristic").map(|w| w.overall).unwrap_or(0.0);

        assert!(
            gliner_weight > heuristic_weight,
            "GLiNER should have higher weight (was {} vs {})",
            gliner_weight,
            heuristic_weight
        );
    }

    #[test]
    fn test_backend_stats() {
        let mut stats = BackendStats {
            correct: 8,
            total: 10,
            ..Default::default()
        };
        stats.per_type.insert("PER".to_string(), (5, 6));

        assert!((stats.precision() - 0.8).abs() < 0.01);
        assert!((stats.type_precision("PER") - 0.833).abs() < 0.01);
        assert!((stats.type_precision("ORG") - 0.0).abs() < 0.01); }

    #[test]
    fn test_empty_text() {
        let ner = fast_ensemble();
        let entities = ner.extract_entities("", None).unwrap();
        assert!(entities.is_empty());
    }

    #[test]
    fn test_whitespace_only_text() {
        let ner = fast_ensemble();
        let entities = ner.extract_entities(" \t\n ", None).unwrap();
        assert!(entities.is_empty());
    }

    // Equal weighted sum AND equal vote count must resolve to the
    // lexicographically smallest type label regardless of input order.
    #[test]
    fn test_resolve_candidates_tie_break_is_order_independent() {
        let ner = fast_ensemble();
        let span_text = "Apple";
        let span = (0, 5);

        let e_person = Entity::new(span_text, EntityType::Person, span.0, span.1, 0.5);
        let e_org = Entity::new(span_text, EntityType::Organization, span.0, span.1, 0.5);

        let c1 = Candidate {
            entity: e_person,
            source: "heuristic".to_string(),
            backend_weight: 1.0,
        };
        let c2 = Candidate {
            entity: e_org,
            source: "heuristic".to_string(),
            backend_weight: 1.0,
        };

        let out_a = ner
            .resolve_candidates(vec![c1.clone(), c2.clone()])
            .expect("should resolve");
        let out_b = ner
            .resolve_candidates(vec![c2, c1])
            .expect("should resolve");

        assert_eq!(
            out_a.entity_type, out_b.entity_type,
            "tie resolution should not depend on candidate order"
        );

        let key_a = out_a.entity_type.as_label().to_string();
        let person_key = EntityType::Person.as_label().to_string();
        let org_key = EntityType::Organization.as_label().to_string();
        let expected = std::cmp::min(person_key, org_key);
        assert_eq!(
            key_a, expected,
            "tie-break should choose lexicographically smallest type label"
        );
    }

    // The single-candidate path must wrap provenance without erasing the
    // inner extraction method or pattern.
    #[test]
    fn test_single_source_preserves_underlying_method_and_pattern() {
        let ner = EnsembleNER::with_backends(vec![Box::new(crate::RegexNER::new())]);
        let text = "Contact test@email.com on 2024-01-15";
        let entities = ner.extract_entities(text, None).expect("extract");
        assert!(!entities.is_empty());

        let email = entities
            .iter()
            .find(|e| e.text == "test@email.com")
            .expect("email entity should exist");
        let prov = email.provenance.as_ref().expect("provenance");

        assert_eq!(prov.method, ExtractionMethod::Pattern);
        assert!(
            prov.pattern.is_some(),
            "expected to preserve regex pattern name"
        );
    }

    #[test]
    fn test_nested_single_source_preserves_inner_method() {
        let inner = EnsembleNER::with_backends(vec![Box::new(crate::HeuristicNER::new())]);
        let outer = EnsembleNER::with_backends(vec![Box::new(inner)]);

        let text = "John Smith visited Paris.";
        let entities = outer.extract_entities(text, None).expect("extract");
        assert!(!entities.is_empty());

        for e in &entities {
            let prov = e.provenance.as_ref().expect("provenance");
            assert_eq!(
                prov.method,
                ExtractionMethod::Heuristic,
                "expected outer to preserve inner method"
            );
        }
    }

    #[test]
    fn test_span_key_self_overlap() {
        let span = SpanKey { start: 0, end: 10 };
        assert!(span.overlaps(&span), "Span should overlap with itself");
    }

    #[test]
    fn test_span_key_adjacent_no_overlap() {
        let span1 = SpanKey { start: 0, end: 10 };
        let span2 = SpanKey { start: 10, end: 20 };
        assert!(!span1.overlaps(&span2), "Adjacent spans should not overlap");
    }

    #[test]
    fn test_span_key_contained() {
        let outer = SpanKey { start: 0, end: 20 };
        let inner = SpanKey { start: 5, end: 15 };
        assert!(outer.overlaps(&inner), "Contained spans should overlap");
        assert!(inner.overlaps(&outer), "Overlap should be symmetric");
    }

    #[test]
    fn test_backend_stats_empty() {
        let stats = BackendStats::default();
        assert!((stats.precision() - 0.0).abs() < 0.001);
        assert!((stats.type_precision("ANY") - 0.0).abs() < 0.001);
    }

    // An untrained learner must not panic; the resulting table may be empty.
    #[test]
    fn test_weight_learner_empty() {
        let learner = WeightLearner::new();
        let weights = learner.learn_weights();
        let _ = weights.len();
    }

    #[test]
    fn test_ensemble_with_language() {
        let ner = fast_ensemble();

        let entities = ner
            .extract_entities("Tim Cook is the CEO of Apple.", Some("en"))
            .unwrap();

        assert!(
            !entities.is_empty(),
            "Should find entities with language hint"
        );
    }

    #[test]
    fn test_type_weights_structure() {
        let weights = TypeWeights {
            person: 0.9,
            location: 0.85,
            organization: 0.88,
            date: 0.95,
            money: 0.8,
            other: 0.7,
        };

        assert!(weights.person > 0.0);
        assert!(weights.date > weights.other);
    }

    #[test]
    fn test_backend_weight_structure() {
        let weight = BackendWeight {
            overall: 0.85,
            per_type: Some(TypeWeights {
                person: 0.9,
                location: 0.88,
                organization: 0.87,
                date: 0.92,
                money: 0.85,
                other: 0.75,
            }),
        };

        assert!(weight.overall > 0.0);
        assert!(weight.per_type.is_some());
    }

    // Multibyte input must not panic and confidences must stay in [0, 1].
    #[test]
    fn test_unicode_extraction() {
        let ner = EnsembleNER::new();
        let entities = ner
            .extract_entities("東京で会議がありました。", None)
            .unwrap();

        for e in &entities {
            assert!(e.confidence >= 0.0 && e.confidence <= 1.0);
        }
    }

    #[test]
    fn test_ensemble_provenance_tracking() {
        let ner = EnsembleNER::new();
        let entities = ner
            .extract_entities("Barack Obama visited Paris yesterday.", None)
            .unwrap();

        for e in &entities {
            assert!(
                e.provenance.is_some(),
                "Entity '{}' ({:?}) at {}..{} has no provenance",
                e.text,
                e.entity_type,
                e.start,
                e.end
            );
            let prov = e.provenance.as_ref().unwrap();
            assert!(!prov.source.is_empty());
        }
    }
}