1use crate::linking::candidate::CandidateSource;
48use crate::linking::linker::LinkedEntity;
49use crate::{Entity, EntityType, Result};
50use anno_core::{CorefChain, Mention as CorefMention};
51use serde::{Deserialize, Serialize};
52use std::collections::HashMap;
53
54use super::factors::{
55 CorefLinkFactor, CorefLinkWeights, CorefNerFactor, CorefNerWeights, Factor, LinkNerFactor,
56 LinkNerWeights, UnaryCorefFactor, UnaryLinkFactor, UnaryNerFactor, WikipediaKnowledgeStore,
57};
58use super::inference::{BeliefPropagation, InferenceConfig, Marginals};
59use std::sync::Arc;
60
/// Identifies one random variable in the joint factor graph.
///
/// A variable is addressed by the mention it belongs to plus which of the three
/// per-mention decisions (antecedent, semantic type, entity link) it encodes.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct VariableId {
    /// Index of the mention this variable belongs to.
    pub mention_idx: usize,
    /// Which per-mention decision this variable encodes.
    pub var_type: VariableType,
}
73
/// The kind of per-mention decision a [`VariableId`] refers to.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum VariableType {
    /// Choice of antecedent (an earlier mention, or a new cluster).
    Antecedent,
    /// Choice of entity type.
    SemanticType,
    /// Choice of knowledge-base link (or NIL).
    EntityLink,
}
84
/// A variable of the joint model together with its explicit candidates.
#[derive(Debug, Clone)]
pub enum JointVariable {
    /// Coreference antecedent choice for one mention.
    Antecedent {
        /// Mention this variable belongs to.
        mention_idx: usize,
        /// Indices of candidate antecedent mentions. The "new cluster" value is
        /// implicit and occupies the index just past the last candidate.
        candidates: Vec<usize>,
    },
    /// Entity-type choice for one mention.
    SemanticType {
        /// Mention this variable belongs to.
        mention_idx: usize,
        /// Candidate entity types (the full domain; nothing implicit).
        types: Vec<EntityType>,
    },
    /// Knowledge-base link choice for one mention.
    EntityLink {
        /// Mention this variable belongs to.
        mention_idx: usize,
        /// Candidate KB identifiers. The NIL value is implicit and occupies the
        /// index just past the last candidate.
        candidates: Vec<String>,
    },
}
110
111impl JointVariable {
112 pub fn id(&self) -> VariableId {
114 match self {
115 JointVariable::Antecedent { mention_idx, .. } => VariableId {
116 mention_idx: *mention_idx,
117 var_type: VariableType::Antecedent,
118 },
119 JointVariable::SemanticType { mention_idx, .. } => VariableId {
120 mention_idx: *mention_idx,
121 var_type: VariableType::SemanticType,
122 },
123 JointVariable::EntityLink { mention_idx, .. } => VariableId {
124 mention_idx: *mention_idx,
125 var_type: VariableType::EntityLink,
126 },
127 }
128 }
129
130 pub fn domain_size(&self) -> usize {
132 match self {
133 JointVariable::Antecedent { candidates, .. } => candidates.len() + 1, JointVariable::SemanticType { types, .. } => types.len(),
135 JointVariable::EntityLink { candidates, .. } => candidates.len() + 1, }
137 }
138}
139
/// Explicit value domain of a variable, one variant per [`VariableType`].
#[derive(Debug, Clone)]
pub enum VariableDomain {
    /// Possible antecedent values.
    Antecedent(Vec<AntecedentValue>),
    /// Possible entity types.
    SemanticType(Vec<EntityType>),
    /// Possible link values.
    EntityLink(Vec<LinkValue>),
}
150
/// Value of an antecedent variable.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AntecedentValue {
    /// Corefers with the mention at this index.
    Mention(usize),
    /// Starts a new cluster (no antecedent).
    NewCluster,
}
159
/// Value of an entity-link variable.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum LinkValue {
    /// Links to the knowledge-base entry with this identifier.
    KbId(String),
    /// Links to nothing in the knowledge base.
    Nil,
}
168
/// A (possibly partial) assignment of values to the joint variables, keyed by
/// mention index.
#[derive(Debug, Clone, Default)]
pub struct Assignment {
    /// Antecedent choice per mention.
    pub antecedents: HashMap<usize, AntecedentValue>,
    /// Entity type per mention.
    pub types: HashMap<usize, EntityType>,
    /// Link choice per mention.
    pub links: HashMap<usize, LinkValue>,
}
183
184impl Assignment {
185 pub fn get_antecedent(&self, mention_idx: usize) -> Option<AntecedentValue> {
187 self.antecedents.get(&mention_idx).copied()
188 }
189
190 pub fn get_type(&self, mention_idx: usize) -> Option<EntityType> {
192 self.types.get(&mention_idx).cloned()
193 }
194
195 pub fn get_link(&self, mention_idx: usize) -> Option<&LinkValue> {
197 self.links.get(&mention_idx)
198 }
199
200 pub fn set_antecedent(&mut self, mention_idx: usize, value: AntecedentValue) {
202 self.antecedents.insert(mention_idx, value);
203 }
204
205 pub fn set_type(&mut self, mention_idx: usize, value: EntityType) {
207 self.types.insert(mention_idx, value);
208 }
209
210 pub fn set_link(&mut self, mention_idx: usize, value: LinkValue) {
212 self.links.insert(mention_idx, value);
213 }
214}
215
/// Coarse grammatical category of a mention.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MentionKind {
    /// Proper name, e.g. "Barack Obama".
    Proper,
    /// Common-noun mention, e.g. "the president".
    Nominal,
    /// Pronoun, e.g. "he" or "themselves".
    Pronominal,
}

impl MentionKind {
    /// Closed pronoun list consulted by [`MentionKind::from_text`]; entries are
    /// compared against the lowercased mention text.
    const PRONOUNS: &'static [&'static str] = &[
        "he",
        "she",
        "it",
        "they",
        "him",
        "her",
        "them",
        "his",
        "hers",
        "its",
        "their",
        "himself",
        "herself",
        "itself",
        "themselves",
        "who",
        "whom",
        "which",
        "that",
    ];

    /// Classifies a mention from its surface text.
    ///
    /// The whole text is matched case-insensitively against a fixed pronoun list.
    /// Otherwise, a leading uppercase character yields [`MentionKind::Proper`]. All
    /// other text, including the empty string, is [`MentionKind::Nominal`].
    pub fn from_text(text: &str) -> Self {
        let lowered = text.to_lowercase();
        if Self::PRONOUNS.iter().any(|&pronoun| pronoun == lowered) {
            return MentionKind::Pronominal;
        }
        match text.chars().next() {
            Some(first) if first.is_uppercase() => MentionKind::Proper,
            _ => MentionKind::Nominal,
        }
    }

    /// `true` for [`MentionKind::Proper`].
    pub fn is_proper_name(&self) -> bool {
        *self == MentionKind::Proper
    }

    /// `true` for [`MentionKind::Pronominal`].
    pub fn is_pronoun(&self) -> bool {
        *self == MentionKind::Pronominal
    }

    /// `true` for [`MentionKind::Nominal`].
    pub fn is_nominal(&self) -> bool {
        *self == MentionKind::Nominal
    }
}
281
/// Fine-grained relation label between two event mentions.
///
/// [`EventCorefRelation::is_positive`] treats the first four labels as (full or
/// partial) coreference and the last three as non-coreferent or undecided.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EventCorefRelation {
    /// The two mentions refer to the same event.
    Identity,
    /// One mention is a general concept and the other an instance of it.
    ConceptInstance,
    /// One event is a subevent of the other.
    WholeSubevent,
    /// One event is a member of a set of events denoted by the other.
    SetMember,
    /// The events share a topic but do not corefer.
    TopicallyRelated,
    /// The events are unrelated.
    NotRelated,
    /// No decision could be made.
    CannotDecide,
}

impl EventCorefRelation {
    /// Broad coreference test: `Identity`, `ConceptInstance`, `WholeSubevent`
    /// and `SetMember` are positive; everything else is not.
    pub fn is_positive(&self) -> bool {
        match self {
            EventCorefRelation::Identity
            | EventCorefRelation::ConceptInstance
            | EventCorefRelation::WholeSubevent
            | EventCorefRelation::SetMember => true,
            EventCorefRelation::TopicallyRelated
            | EventCorefRelation::NotRelated
            | EventCorefRelation::CannotDecide => false,
        }
    }

    /// Binary collapse in which only `Identity` is positive.
    pub fn to_binary(&self) -> bool {
        *self == EventCorefRelation::Identity
    }

    /// Binary collapse that delegates to [`EventCorefRelation::is_positive`].
    ///
    /// NOTE(review): despite its name this is the *looser* of the two binary
    /// mappings (`to_binary` accepts only `Identity`); confirm the naming is intended.
    pub fn to_strict_binary(&self) -> bool {
        self.is_positive()
    }
}
339
/// A mention rewritten so that it can be understood without its source document.
#[derive(Debug, Clone)]
pub struct DecontextualizedMention {
    /// Text as it appeared in the source document.
    pub original_text: String,
    /// Self-contained rewrite of the text.
    pub decontextualized: String,
    /// Identifier of the source document.
    pub doc_id: String,
    /// Start offset of `original_text` in the source document.
    pub original_start: usize,
    /// End offset of `original_text` in the source document.
    pub original_end: usize,
    /// `(reference, resolved)` pairs recorded via [`DecontextualizedMention::with_resolved`].
    pub resolved_entities: Vec<(String, String)>,
}

impl DecontextualizedMention {
    /// Creates a mention with no resolved references yet.
    pub fn new(
        original_text: impl Into<String>,
        decontextualized: impl Into<String>,
        doc_id: impl Into<String>,
        original_start: usize,
        original_end: usize,
    ) -> Self {
        let original_text = original_text.into();
        let decontextualized = decontextualized.into();
        let doc_id = doc_id.into();
        DecontextualizedMention {
            original_text,
            decontextualized,
            doc_id,
            original_start,
            original_end,
            resolved_entities: Vec::new(),
        }
    }

    /// Builder-style: appends one `(reference, resolved)` pair and returns `self`.
    pub fn with_resolved(
        mut self,
        reference: impl Into<String>,
        resolved: impl Into<String>,
    ) -> Self {
        let pair = (reference.into(), resolved.into());
        self.resolved_entities.push(pair);
        self
    }
}
395
/// An event mention, identified by its trigger word.
#[derive(Debug, Clone)]
pub struct EventMention {
    /// Identifier of this event mention.
    pub id: String,
    /// Trigger text (e.g. the verb "announced").
    pub trigger: String,
    /// Optional event-type label.
    pub event_type: Option<String>,
    /// Identifier of the source document.
    pub doc_id: String,
    /// Start offset of the trigger (char vs. byte offsets not established here — TODO confirm).
    pub start: usize,
    /// End offset of the trigger.
    pub end: usize,
    /// Optional self-contained rewrite of the surrounding context.
    pub decontextualized: Option<DecontextualizedMention>,
}
414
/// A mention as seen by the joint model.
#[derive(Debug, Clone)]
pub struct JointMention {
    /// Position of the mention in the model's mention list.
    pub idx: usize,
    /// Surface text of the mention.
    pub text: String,
    /// Approximate head word (see `JointMention::from_entity`).
    pub head: String,
    /// Start character offset.
    pub start: usize,
    /// End character offset (exclusive).
    pub end: usize,
    /// Grammatical category derived from the surface text.
    pub mention_kind: MentionKind,
    /// Entity type supplied by the upstream detector, if any.
    pub entity_type: Option<EntityType>,
    /// The upstream entity this mention was built from, if any.
    pub entity: Option<Entity>,
}
439
440impl JointMention {
441 pub fn from_entity(idx: usize, entity: &Entity, text: &str) -> Self {
443 let mention_text = text
444 .chars()
445 .skip(entity.start)
446 .take(entity.end - entity.start)
447 .collect::<String>();
448
449 let head = mention_text
450 .split_whitespace()
451 .last()
452 .unwrap_or(&mention_text)
453 .to_string();
454
455 Self {
456 idx,
457 text: mention_text.clone(),
458 head,
459 start: entity.start,
460 end: entity.end,
461 mention_kind: MentionKind::from_text(&mention_text),
462 entity_type: Some(entity.entity_type.clone()),
463 entity: Some(entity.clone()),
464 }
465 }
466}
467
/// Configuration for [`JointModel`].
#[derive(Debug, Clone)]
pub struct JointConfig {
    /// Whether to add a link–NER factor for each mention.
    pub enable_link_ner: bool,
    /// Whether to add coref–NER factors between a mention and its antecedent candidates.
    pub enable_coref_ner: bool,
    /// Whether to add coref–link factors between a mention and its antecedent candidates.
    pub enable_coref_link: bool,

    /// Maximum number of belief-propagation iterations.
    pub max_iterations: usize,
    /// Convergence threshold handed to belief propagation.
    pub convergence_threshold: f64,

    /// Coarse-pruning margin: antecedents scoring more than this below the best
    /// candidate are discarded.
    pub pruning_threshold: f64,
    /// Upper bound on antecedent candidates per mention.
    pub max_antecedent_candidates: usize,

    /// Maximum number of knowledge-base link candidates per mention.
    pub max_link_candidates: usize,

    /// Entity types the semantic-type variables range over.
    pub entity_types: Vec<EntityType>,
}
498
499impl Default for JointConfig {
500 fn default() -> Self {
501 Self {
502 enable_link_ner: true,
503 enable_coref_ner: true,
504 enable_coref_link: true,
505
506 max_iterations: 5,
507 convergence_threshold: 1e-4,
508
509 pruning_threshold: 5.0, max_antecedent_candidates: 50,
511
512 max_link_candidates: 20,
513
514 entity_types: vec![
516 EntityType::Person,
517 EntityType::Organization,
518 EntityType::Location,
519 EntityType::Date,
520 EntityType::Time,
521 EntityType::Money,
522 EntityType::Percent,
523 EntityType::Other("MISC".to_string()),
524 ],
525 }
526 }
527}
528
/// Output of [`JointModel::analyze`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JointResult {
    /// Entities with their decoded types.
    pub entities: Vec<Entity>,
    /// Coreference chains (clusters with more than one mention).
    pub chains: Vec<CorefChain>,
    /// Mentions decoded to a non-NIL knowledge-base entry.
    pub links: Vec<LinkedEntity>,
    /// One confidence per entry of `entities`, in the same order.
    pub confidences: Vec<f64>,
}
545
/// Cheap first-pass filter that limits how many antecedents each mention considers.
///
/// Derives `Debug` and `Clone` so it can be logged and copied like the other public
/// configuration types.
#[derive(Debug, Clone)]
pub struct CoarsePruner {
    /// Candidates scoring more than this below the best candidate are dropped.
    pub threshold: f64,
    /// Hard cap on the number of candidates returned.
    pub max_candidates: usize,
    /// Scale of the string-match features.
    pub string_match_weight: f64,
    /// Scale of the log-distance penalty.
    pub distance_weight: f64,
}

impl Default for CoarsePruner {
    /// Margin 5.0, at most 50 candidates, string weight 2.0, distance weight 0.1.
    fn default() -> Self {
        Self {
            threshold: 5.0,
            max_candidates: 50,
            string_match_weight: 2.0,
            distance_weight: 0.1,
        }
    }
}
576
577impl CoarsePruner {
578 pub fn prune_candidates(&self, mention_idx: usize, mentions: &[JointMention]) -> Vec<usize> {
580 if mention_idx == 0 {
581 return vec![];
582 }
583
584 let mention = &mentions[mention_idx];
585
586 let mut scored: Vec<(usize, f64)> = (0..mention_idx)
588 .map(|ante_idx| {
589 let score = self.score_pair(mention, &mentions[ante_idx], mention_idx - ante_idx);
590 (ante_idx, score)
591 })
592 .collect();
593
594 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
596
597 if scored.is_empty() {
598 return vec![];
599 }
600
601 let best_score = scored[0].1;
603
604 scored
606 .into_iter()
607 .take_while(|(_, score)| best_score - *score <= self.threshold)
608 .take(self.max_candidates)
609 .map(|(idx, _)| idx)
610 .collect()
611 }
612
613 fn score_pair(
615 &self,
616 mention: &JointMention,
617 antecedent: &JointMention,
618 distance: usize,
619 ) -> f64 {
620 let mut score = 0.0;
621
622 let m_lower = mention.text.to_lowercase();
624 let a_lower = antecedent.text.to_lowercase();
625 let m_head = mention.head.to_lowercase();
626 let a_head = antecedent.head.to_lowercase();
627
628 if m_lower == a_lower {
630 score += self.string_match_weight * 1.0;
631 }
632 else if m_head == a_head {
634 score += self.string_match_weight * 0.6;
635 }
636 else if m_lower.contains(&a_lower) || a_lower.contains(&m_lower) {
638 score += self.string_match_weight * 0.3;
639 }
640
641 match (mention.mention_kind, antecedent.mention_kind) {
643 (MentionKind::Pronominal, MentionKind::Proper) => score += 0.5,
645 (a, b) if a == b => score += 0.3,
647 _ => {}
648 }
649
650 score -= self.distance_weight * (distance as f64 + 1.0).ln();
652
653 score
654 }
655}
656
/// Joint NER + coreference + entity-linking model.
///
/// Builds a factor graph over per-mention antecedent, type and link variables and
/// decodes it with belief propagation. Optional providers supply unary scores;
/// without them built-in heuristics are used.
pub struct JointModel {
    /// Model configuration.
    config: JointConfig,
    /// Coarse antecedent pruner, parameterised from `config` in `JointModel::new`.
    pruner: CoarsePruner,
    /// Optional knowledge store handed to the coref–link and link–NER factors.
    knowledge_store: Option<Arc<WikipediaKnowledgeStore>>,
    /// Optional source of NER type scores.
    ner_provider: Option<Arc<dyn NerScoreProvider>>,
    /// Optional source of antecedent scores.
    coref_provider: Option<Arc<dyn CorefScoreProvider>>,
    /// Optional source of knowledge-base link candidates.
    link_provider: Option<Arc<dyn LinkScoreProvider>>,
}
678
679impl Default for JointModel {
680 fn default() -> Self {
681 Self::new(JointConfig::default()).expect("default config should always succeed")
682 }
683}
684
685impl JointModel {
686 pub fn new(config: JointConfig) -> Result<Self> {
688 let pruner = CoarsePruner {
689 threshold: config.pruning_threshold,
690 max_candidates: config.max_antecedent_candidates,
691 ..Default::default()
692 };
693
694 Ok(Self {
695 config,
696 pruner,
697 knowledge_store: None,
698 ner_provider: None,
699 coref_provider: None,
700 link_provider: None,
701 })
702 }
703
704 pub fn with_knowledge(mut self, store: Arc<WikipediaKnowledgeStore>) -> Self {
706 self.knowledge_store = Some(store);
707 self
708 }
709
710 pub fn with_ner_provider(mut self, provider: Arc<dyn NerScoreProvider>) -> Self {
712 self.ner_provider = Some(provider);
713 self
714 }
715
716 pub fn with_coref_provider(mut self, provider: Arc<dyn CorefScoreProvider>) -> Self {
718 self.coref_provider = Some(provider);
719 self
720 }
721
722 pub fn with_link_provider(mut self, provider: Arc<dyn LinkScoreProvider>) -> Self {
724 self.link_provider = Some(provider);
725 self
726 }
727
728 pub fn analyze(&self, text: &str, entities: &[Entity]) -> Result<JointResult> {
730 let mentions: Vec<JointMention> = entities
732 .iter()
733 .enumerate()
734 .map(|(i, e)| JointMention::from_entity(i, e, text))
735 .collect();
736
737 if mentions.is_empty() {
738 return Ok(JointResult {
739 entities: vec![],
740 chains: vec![],
741 links: vec![],
742 confidences: vec![],
743 });
744 }
745
746 let variables = self.build_variables(&mentions);
748
749 let factors = self.build_factors(&mentions, &variables);
751
752 let inference_config = InferenceConfig {
754 max_iterations: self.config.max_iterations,
755 convergence_threshold: self.config.convergence_threshold,
756 ..Default::default()
757 };
758 let mut bp = BeliefPropagation::new(factors, variables.clone(), inference_config);
759 let marginals = bp.run();
760
761 let (entities_out, chains, links, confidences) =
763 self.decode(&mentions, &variables, &marginals);
764
765 Ok(JointResult {
766 entities: entities_out,
767 chains,
768 links,
769 confidences,
770 })
771 }
772
773 fn build_variables(&self, mentions: &[JointMention]) -> Vec<JointVariable> {
775 let mut variables = Vec::new();
776
777 for (i, _mention) in mentions.iter().enumerate() {
778 if i > 0 {
780 let pruned = self.pruner.prune_candidates(i, mentions);
781 variables.push(JointVariable::Antecedent {
782 mention_idx: i,
783 candidates: pruned,
784 });
785 }
786
787 variables.push(JointVariable::SemanticType {
789 mention_idx: i,
790 types: self.config.entity_types.clone(),
791 });
792
793 let link_candidates: Vec<String> = vec![];
796 variables.push(JointVariable::EntityLink {
797 mention_idx: i,
798 candidates: link_candidates,
799 });
800 }
801
802 variables
803 }
804
805 fn build_factors(
807 &self,
808 mentions: &[JointMention],
809 _variables: &[JointVariable],
810 ) -> Vec<Box<dyn Factor>> {
811 let mut factors: Vec<Box<dyn Factor>> = Vec::new();
812
813 for mention in mentions {
814 let i = mention.idx;
815
816 let type_scores: Vec<(EntityType, f64)> = if let Some(ref provider) = self.ner_provider
818 {
819 provider.type_scores(mention, mention.text.as_str())
820 } else {
821 let original_type = mention.entity.as_ref().map(|e| &e.entity_type);
822 self.config
823 .entity_types
824 .iter()
825 .map(|t| {
826 let score = if original_type == Some(t) {
827 10.0 } else {
829 -5.0 };
831 (t.clone(), score)
832 })
833 .collect()
834 };
835 factors.push(Box::new(UnaryNerFactor::new(i, type_scores)));
836
837 if i > 0 {
839 let candidates: Vec<usize> =
840 (0..i).take(self.config.max_antecedent_candidates).collect();
841 let coref_scores: Vec<(AntecedentValue, f64)> =
842 if let Some(ref provider) = self.coref_provider {
843 let cand_refs: Vec<&JointMention> =
845 candidates.iter().map(|&idx| &mentions[idx]).collect();
846 provider.antecedent_scores(mention, &cand_refs, mention.text.as_str())
847 } else {
848 let mut scores: Vec<(AntecedentValue, f64)> = candidates
849 .iter()
850 .map(|&ante| {
851 let ante_mention = &mentions[ante];
852 let head_match = if mention.head.to_lowercase()
853 == ante_mention.head.to_lowercase()
854 {
855 2.0
856 } else {
857 0.0
858 };
859 let distance_penalty = -0.1 * (i - ante) as f64;
860 (
861 AntecedentValue::Mention(ante),
862 head_match + distance_penalty,
863 )
864 })
865 .collect();
866 scores.push((AntecedentValue::NewCluster, 0.0));
867 scores
868 };
869 factors.push(Box::new(UnaryCorefFactor::new(i, coref_scores)));
870 }
871
872 let link_candidates_raw = if let Some(ref provider) = self.link_provider {
874 provider.link_candidates(mention, mention.text.as_str())
875 } else {
876 vec![]
877 };
878 let link_candidates: Vec<(LinkValue, f64)> = link_candidates_raw
879 .into_iter()
880 .map(|(id, score)| {
881 let lv = if id == "NIL" {
882 LinkValue::Nil
883 } else {
884 LinkValue::KbId(id)
885 };
886 (lv, score)
887 })
888 .collect();
889 factors.push(Box::new(UnaryLinkFactor::new(i, link_candidates)));
890 }
891
892 for mention in mentions {
894 let i = mention.idx;
895
896 if i > 0 {
897 let candidates: Vec<usize> =
898 (0..i).take(self.config.max_antecedent_candidates).collect();
899
900 for &ante in &candidates {
901 if self.config.enable_coref_ner {
903 factors.push(Box::new(CorefNerFactor::new(
904 i,
905 ante,
906 CorefNerWeights::default(),
907 )));
908 }
909
910 if self.config.enable_coref_link {
912 let mut factor = CorefLinkFactor::new(i, ante, CorefLinkWeights::default());
913 if let Some(ref store) = self.knowledge_store {
914 factor = factor.with_knowledge(store.clone());
915 }
916 factors.push(Box::new(factor));
917 }
918 }
919 }
920
921 if self.config.enable_link_ner {
923 let mut factor = LinkNerFactor::new(i, LinkNerWeights::default());
924 if let Some(ref store) = self.knowledge_store {
925 factor = factor.with_knowledge(store.clone());
926 }
927 factors.push(Box::new(factor));
928 }
929 }
930
931 factors
932 }
933
934 fn decode(
936 &self,
937 mentions: &[JointMention],
938 variables: &[JointVariable],
939 marginals: &Marginals,
940 ) -> (Vec<Entity>, Vec<CorefChain>, Vec<LinkedEntity>, Vec<f64>) {
941 let mut entities = Vec::new();
942 let mut links = Vec::new();
943 let mut confidences = Vec::new();
944 let mut antecedents: HashMap<usize, AntecedentValue> = HashMap::new();
945
946 for var in variables {
947 let var_id = var.id();
948 if let Some(best_idx) = marginals.argmax(&var_id) {
949 let prob = marginals.prob(&var_id, best_idx).unwrap_or(0.0);
950
951 match var {
952 JointVariable::Antecedent {
953 mention_idx,
954 candidates,
955 } => {
956 let value = if best_idx < candidates.len() {
957 AntecedentValue::Mention(candidates[best_idx])
958 } else {
959 AntecedentValue::NewCluster
960 };
961 antecedents.insert(*mention_idx, value);
962 }
963 JointVariable::SemanticType {
964 mention_idx, types, ..
965 } => {
966 let m = &mentions[*mention_idx];
967 let (entity_type, conf) = if let Some(inferred_type) = types.get(best_idx) {
969 if prob > 0.3 {
971 (inferred_type.clone(), prob)
972 } else if let Some(original) = &m.entity {
973 (original.entity_type.clone(), original.confidence)
975 } else {
976 (inferred_type.clone(), prob)
977 }
978 } else if let Some(original) = &m.entity {
979 (original.entity_type.clone(), original.confidence)
981 } else {
982 continue;
984 };
985 entities.push(Entity::new(&m.text, entity_type, m.start, m.end, conf));
986 confidences.push(conf);
987 }
988 JointVariable::EntityLink {
989 mention_idx,
990 candidates,
991 } => {
992 let link_value = if best_idx < candidates.len() {
993 LinkValue::KbId(candidates[best_idx].clone())
994 } else {
995 LinkValue::Nil
996 };
997 if let LinkValue::KbId(kb_id) = link_value {
998 let m = &mentions[*mention_idx];
999 links.push(LinkedEntity {
1000 mention_text: m.text.clone(),
1001 start: m.start,
1002 end: m.end,
1003 kb_id: Some(kb_id),
1004 source: CandidateSource::Wikidata,
1005 label: None,
1006 iri: None,
1007 confidence: prob,
1008 is_nil: false,
1009 nil_reason: None,
1010 nil_action: None,
1011 alternatives: Vec::new(),
1012 });
1013 }
1014 }
1015 }
1016 }
1017 }
1018
1019 let chains = self.build_chains(&antecedents, mentions);
1021
1022 (entities, chains, links, confidences)
1023 }
1024
1025 fn build_chains(
1027 &self,
1028 antecedents: &HashMap<usize, AntecedentValue>,
1029 mentions: &[JointMention],
1030 ) -> Vec<CorefChain> {
1031 let n_mentions = mentions.len();
1032 let mut parent: Vec<usize> = (0..n_mentions).collect();
1034
1035 fn find(parent: &mut [usize], i: usize) -> usize {
1036 if parent[i] != i {
1037 parent[i] = find(parent, parent[i]);
1038 }
1039 parent[i]
1040 }
1041
1042 fn union(parent: &mut [usize], i: usize, j: usize) {
1043 let pi = find(parent, i);
1044 let pj = find(parent, j);
1045 if pi != pj {
1046 parent[pi] = pj;
1047 }
1048 }
1049
1050 for (&mention_idx, &ante_value) in antecedents {
1052 if let AntecedentValue::Mention(ante_idx) = ante_value {
1053 union(&mut parent, mention_idx, ante_idx);
1054 }
1055 }
1056
1057 let mut clusters: HashMap<usize, Vec<usize>> = HashMap::new();
1059 for i in 0..n_mentions {
1060 let root = find(&mut parent, i);
1061 clusters.entry(root).or_default().push(i);
1062 }
1063
1064 clusters
1066 .into_iter()
1067 .filter(|(_, members)| members.len() > 1) .enumerate()
1069 .map(|(chain_id, (_, mut members))| {
1070 members.sort();
1071 let coref_mentions: Vec<CorefMention> = members
1072 .iter()
1073 .map(|&idx| {
1074 let m = &mentions[idx];
1075 CorefMention {
1076 text: m.text.clone(),
1077 start: m.start,
1078 end: m.end,
1079 head_start: None,
1080 head_end: None,
1081 entity_type: m.entity.as_ref().map(|e| format!("{:?}", e.entity_type)),
1082 mention_type: None,
1083 }
1084 })
1085 .collect();
1086 CorefChain {
1087 cluster_id: Some(anno_core::CanonicalId::new(chain_id as u64)),
1088 mentions: coref_mentions,
1089 entity_type: None,
1090 }
1091 })
1092 .collect()
1093 }
1094
1095 pub fn config(&self) -> &JointConfig {
1097 &self.config
1098 }
1099
1100 pub fn extract_entities_from_mentions(
1105 &self,
1106 text: &str,
1107 mentions: &[JointMention],
1108 ) -> Result<Vec<Entity>> {
1109 let entities: Vec<Entity> = mentions.iter().filter_map(|m| m.entity.clone()).collect();
1110
1111 let result = self.analyze(text, &entities)?;
1112 Ok(result.entities)
1113 }
1114}
1115
1116impl crate::Model for JointModel {
1125 fn extract_entities(&self, text: &str, _language: Option<&str>) -> Result<Vec<Entity>> {
1126 let initial_entities = detect_mentions_heuristic(text);
1129 let result = self.analyze(text, &initial_entities)?;
1130 Ok(result.entities)
1131 }
1132
1133 fn supported_types(&self) -> Vec<EntityType> {
1134 self.config.entity_types.clone()
1135 }
1136
1137 fn is_available(&self) -> bool {
1138 true
1139 }
1140
1141 fn name(&self) -> &'static str {
1142 "joint-model"
1143 }
1144
1145 fn description(&self) -> &'static str {
1146 "Joint Entity Analysis: NER + Coreference + Entity Linking (Durrett & Klein 2014)"
1147 }
1148}
1149
1150impl anno_core::CoreferenceResolver for JointModel {
1152 fn resolve(&self, entities: &[Entity]) -> Vec<Entity> {
1153 if entities.is_empty() {
1154 return vec![];
1155 }
1156
1157 let max_end = entities.iter().map(|e| e.end).max().unwrap_or(0);
1160 let text = " ".repeat(max_end + 1);
1161
1162 match self.analyze(&text, entities) {
1163 Ok(result) => {
1164 let mut resolved = entities.to_vec();
1166
1167 for chain in &result.chains {
1168 let cluster_id = chain.cluster_id.unwrap_or(anno_core::CanonicalId::ZERO);
1169 for mention in &chain.mentions {
1170 for entity in &mut resolved {
1172 if entity.start == mention.start && entity.end == mention.end {
1173 entity.canonical_id = Some(cluster_id);
1174 }
1175 }
1176 }
1177 }
1178
1179 let mut next_id = anno_core::CanonicalId::new(result.chains.len() as u64);
1181 for entity in &mut resolved {
1182 if entity.canonical_id.is_none() {
1183 entity.canonical_id = Some(next_id);
1184 next_id += 1;
1185 }
1186 }
1187
1188 resolved
1189 }
1190 Err(_) => entities.to_vec(),
1191 }
1192 }
1193
1194 fn name(&self) -> &'static str {
1195 "joint-model-coref"
1196 }
1197}
1198
/// Builder for [`JointModel`].
///
/// `Default` starts from [`JointConfig::default`] with no providers or knowledge
/// store. Each `with_*` / `enable_*` method overrides one setting, and `build`
/// assembles the model.
#[derive(Clone, Default)]
pub struct JointModelBuilder {
    /// Configuration being built.
    config: JointConfig,
    /// Optional knowledge store to attach.
    knowledge_store: Option<Arc<WikipediaKnowledgeStore>>,
    /// Optional NER score provider to attach.
    ner_provider: Option<Arc<dyn NerScoreProvider>>,
    /// Optional coreference score provider to attach.
    coref_provider: Option<Arc<dyn CorefScoreProvider>>,
    /// Optional link score provider to attach.
    link_provider: Option<Arc<dyn LinkScoreProvider>>,
}
1227
1228impl JointModelBuilder {
1229 pub fn new() -> Self {
1231 Self::default()
1232 }
1233
1234 pub fn with_max_iterations(mut self, max_iterations: usize) -> Self {
1236 self.config.max_iterations = max_iterations;
1237 self
1238 }
1239
1240 pub fn with_convergence_threshold(mut self, threshold: f64) -> Self {
1242 self.config.convergence_threshold = threshold;
1243 self
1244 }
1245
1246 pub fn with_pruning_threshold(mut self, threshold: f64) -> Self {
1248 self.config.pruning_threshold = threshold;
1249 self
1250 }
1251
1252 pub fn with_max_antecedent_candidates(mut self, max: usize) -> Self {
1254 self.config.max_antecedent_candidates = max;
1255 self
1256 }
1257
1258 pub fn with_max_link_candidates(mut self, max: usize) -> Self {
1260 self.config.max_link_candidates = max;
1261 self
1262 }
1263
1264 pub fn enable_link_ner(mut self, enable: bool) -> Self {
1266 self.config.enable_link_ner = enable;
1267 self
1268 }
1269
1270 pub fn enable_coref_ner(mut self, enable: bool) -> Self {
1272 self.config.enable_coref_ner = enable;
1273 self
1274 }
1275
1276 pub fn enable_coref_link(mut self, enable: bool) -> Self {
1278 self.config.enable_coref_link = enable;
1279 self
1280 }
1281
1282 pub fn with_entity_types(mut self, types: Vec<EntityType>) -> Self {
1284 self.config.entity_types = types;
1285 self
1286 }
1287
1288 pub fn with_knowledge(mut self, store: Arc<WikipediaKnowledgeStore>) -> Self {
1290 self.knowledge_store = Some(store);
1291 self
1292 }
1293
1294 pub fn with_ner_provider(mut self, provider: Arc<dyn NerScoreProvider>) -> Self {
1296 self.ner_provider = Some(provider);
1297 self
1298 }
1299
1300 pub fn with_coref_provider(mut self, provider: Arc<dyn CorefScoreProvider>) -> Self {
1302 self.coref_provider = Some(provider);
1303 self
1304 }
1305
1306 pub fn with_link_provider(mut self, provider: Arc<dyn LinkScoreProvider>) -> Self {
1308 self.link_provider = Some(provider);
1309 self
1310 }
1311
1312 pub fn build(self) -> Result<JointModel> {
1314 let mut model = JointModel::new(self.config)?;
1315 if let Some(store) = self.knowledge_store {
1316 model = model.with_knowledge(store);
1317 }
1318 if let Some(ner) = self.ner_provider {
1319 model = model.with_ner_provider(ner);
1320 }
1321 if let Some(coref) = self.coref_provider {
1322 model = model.with_coref_provider(coref);
1323 }
1324 if let Some(link) = self.link_provider {
1325 model = model.with_link_provider(link);
1326 }
1327 Ok(model)
1328 }
1329}
1330
/// Supplies unary scores for a mention's semantic-type variable.
pub trait NerScoreProvider: Send + Sync {
    /// Returns a score for each candidate entity type of `mention`.
    ///
    /// Scores are forwarded unchanged to the unary NER factor. Higher presumably means
    /// more likely — TODO confirm the expected scale (e.g. log-potentials).
    fn type_scores(&self, mention: &JointMention, text: &str) -> Vec<(EntityType, f64)>;
}
1344
/// Supplies unary scores for a mention's antecedent variable.
pub trait CorefScoreProvider: Send + Sync {
    /// Scores each candidate antecedent of `mention` (and, presumably,
    /// [`AntecedentValue::NewCluster`]).
    ///
    /// `candidates` are the earlier mentions under consideration. The returned pairs
    /// are forwarded unchanged to the unary coreference factor — TODO confirm the
    /// expected order and scale.
    fn antecedent_scores(
        &self,
        mention: &JointMention,
        candidates: &[&JointMention],
        text: &str,
    ) -> Vec<(AntecedentValue, f64)>;
}
1359
/// Supplies knowledge-base link candidates for a mention.
pub trait LinkScoreProvider: Send + Sync {
    /// Returns `(kb_id, score)` pairs for `mention`. The id `"NIL"` is treated by
    /// the joint model as the NIL (no-link) value rather than a KB identifier.
    fn link_candidates(&self, mention: &JointMention, text: &str) -> Vec<(String, f64)>;
}
1369
1370fn detect_mentions_heuristic(text: &str) -> Vec<Entity> {
1379 let mut entities = Vec::new();
1380
1381 let mut in_name = false;
1384 let mut name_start_char = 0;
1385 let mut char_pos = 0;
1386
1387 let chars: Vec<char> = text.chars().collect();
1388
1389 for c in &chars {
1390 if c.is_whitespace() || c.is_ascii_punctuation() {
1391 if in_name {
1392 let name_text: String = chars[name_start_char..char_pos].iter().collect();
1394
1395 if name_text.chars().count() > 1 {
1396 entities.push(Entity::new(
1397 &name_text,
1398 EntityType::Other("MENTION".to_string()),
1399 name_start_char,
1400 char_pos,
1401 0.5,
1402 ));
1403 }
1404 in_name = false;
1405 }
1406 } else if c.is_uppercase() && !in_name {
1407 in_name = true;
1409 name_start_char = char_pos;
1410 }
1411
1412 char_pos += 1;
1413 }
1414
1415 if in_name {
1417 let name_text: String = chars[name_start_char..char_pos].iter().collect();
1418
1419 if name_text.chars().count() > 1 {
1420 entities.push(Entity::new(
1421 &name_text,
1422 EntityType::Other("MENTION".to_string()),
1423 name_start_char,
1424 char_pos,
1425 0.5,
1426 ));
1427 }
1428 }
1429
1430 entities
1431}
1432
1433#[cfg(test)]
1434mod tests {
1435 use super::*;
1436
1437 #[test]
1438 fn test_variable_id() {
1439 let id = VariableId {
1440 mention_idx: 0,
1441 var_type: VariableType::Antecedent,
1442 };
1443 assert_eq!(id.mention_idx, 0);
1444 }
1445
1446 #[test]
1447 fn test_assignment() {
1448 let mut assignment = Assignment::default();
1449 assignment.set_antecedent(1, AntecedentValue::Mention(0));
1450 assignment.set_type(0, EntityType::Person);
1451 assignment.set_link(0, LinkValue::KbId("Q42".to_string()));
1452
1453 assert_eq!(
1454 assignment.get_antecedent(1),
1455 Some(AntecedentValue::Mention(0))
1456 );
1457 assert_eq!(assignment.get_type(0), Some(EntityType::Person));
1458 assert_eq!(
1459 assignment.get_link(0),
1460 Some(&LinkValue::KbId("Q42".to_string()))
1461 );
1462 }
1463
1464 #[test]
1465 fn test_joint_config_default() {
1466 let config = JointConfig::default();
1467 assert!(config.enable_link_ner);
1468 assert!(config.enable_coref_ner);
1469 assert!(config.enable_coref_link);
1470 assert_eq!(config.max_iterations, 5);
1471 }
1472
1473 #[test]
1474 fn test_joint_model_creation() {
1475 let model = JointModel::new(JointConfig::default());
1476 assert!(model.is_ok());
1477 }
1478
1479 #[test]
1480 fn test_mention_kind_detection() {
1481 assert_eq!(MentionKind::from_text("he"), MentionKind::Pronominal);
1482 assert_eq!(MentionKind::from_text("She"), MentionKind::Pronominal);
1483 assert_eq!(MentionKind::from_text("Barack Obama"), MentionKind::Proper);
1484 assert_eq!(
1485 MentionKind::from_text("the president"),
1486 MentionKind::Nominal
1487 );
1488 }
1489
1490 #[test]
1491 fn test_joint_model_analyze_empty() {
1492 let model = JointModel::new(JointConfig::default()).unwrap();
1493 let result = model.analyze("Hello world", &[]).unwrap();
1494
1495 assert!(result.entities.is_empty());
1496 assert!(result.chains.is_empty());
1497 }
1498
1499 #[test]
1500 fn test_joint_model_analyze_single_entity() {
1501 let model = JointModel::new(JointConfig::default()).unwrap();
1502 let entities = vec![Entity::new("Obama", EntityType::Person, 0, 5, 0.9)];
1503
1504 let result = model.analyze("Obama was here.", &entities).unwrap();
1505 assert!(!result.entities.is_empty());
1506 }
1507
1508 #[test]
1509 fn test_coarse_pruner() {
1510 let pruner = CoarsePruner::default();
1511
1512 let mentions = vec![
1513 JointMention {
1514 idx: 0,
1515 text: "Barack Obama".to_string(),
1516 head: "Obama".to_string(),
1517 start: 0,
1518 end: 12,
1519 mention_kind: MentionKind::Proper,
1520 entity_type: Some(EntityType::Person),
1521 entity: None,
1522 },
1523 JointMention {
1524 idx: 1,
1525 text: "France".to_string(),
1526 head: "France".to_string(),
1527 start: 21,
1528 end: 27,
1529 mention_kind: MentionKind::Proper,
1530 entity_type: Some(EntityType::Location),
1531 entity: None,
1532 },
1533 JointMention {
1534 idx: 2,
1535 text: "Obama".to_string(),
1536 head: "Obama".to_string(),
1537 start: 40,
1538 end: 45,
1539 mention_kind: MentionKind::Proper,
1540 entity_type: Some(EntityType::Person),
1541 entity: None,
1542 },
1543 ];
1544
1545 let candidates = pruner.prune_candidates(2, &mentions);
1546 assert!(!candidates.is_empty());
1548 assert!(candidates.contains(&0));
1550 }
1551
1552 #[test]
1557 fn test_event_coref_relation_is_positive() {
1558 assert!(EventCorefRelation::Identity.is_positive());
1560 assert!(EventCorefRelation::ConceptInstance.is_positive());
1561 assert!(EventCorefRelation::WholeSubevent.is_positive());
1562 assert!(EventCorefRelation::SetMember.is_positive());
1563
1564 assert!(!EventCorefRelation::TopicallyRelated.is_positive());
1566 assert!(!EventCorefRelation::NotRelated.is_positive());
1567 assert!(!EventCorefRelation::CannotDecide.is_positive());
1568 }
1569
1570 #[test]
1571 fn test_event_coref_relation_to_binary() {
1572 assert!(EventCorefRelation::Identity.to_binary());
1574 assert!(!EventCorefRelation::ConceptInstance.to_binary());
1575 assert!(!EventCorefRelation::WholeSubevent.to_binary());
1576 assert!(!EventCorefRelation::SetMember.to_binary());
1577 assert!(!EventCorefRelation::NotRelated.to_binary());
1578 }
1579
1580 #[test]
1581 fn test_event_coref_relation_to_strict_binary() {
1582 assert!(EventCorefRelation::Identity.to_strict_binary());
1584 assert!(EventCorefRelation::ConceptInstance.to_strict_binary());
1585 assert!(EventCorefRelation::WholeSubevent.to_strict_binary());
1586 assert!(EventCorefRelation::SetMember.to_strict_binary());
1587 assert!(!EventCorefRelation::NotRelated.to_strict_binary());
1588 assert!(!EventCorefRelation::TopicallyRelated.to_strict_binary());
1589 }
1590
/// Construction stores the provided texts and document id, and
/// `with_resolved` records exactly one (surface, resolution) pair.
#[test]
fn test_decontextualized_mention() {
    let mention = DecontextualizedMention::new("it", "Apple Inc.", "doc001", 10, 12)
        .with_resolved("it", "Apple Inc.");

    assert_eq!(mention.original_text, "it");
    assert_eq!(mention.decontextualized, "Apple Inc.");
    assert_eq!(mention.doc_id, "doc001");

    assert_eq!(mention.resolved_entities.len(), 1);
    let (surface, resolved) = &mention.resolved_entities[0];
    assert_eq!(surface, "it");
    assert_eq!(resolved, "Apple Inc.");
}
1605
/// An `EventMention` keeps its identifying fields and an optional
/// decontextualized rewrite of its sentence.
#[test]
fn test_event_mention() {
    let context = DecontextualizedMention::new(
        "The company announced it yesterday",
        "Apple Inc. announced the new iPhone on March 15, 2024",
        "doc001",
        0,
        34,
    );

    let event = EventMention {
        id: "e001".to_string(),
        trigger: "announced".to_string(),
        event_type: Some("Communication".to_string()),
        doc_id: "doc001".to_string(),
        start: 15,
        end: 24,
        decontextualized: Some(context),
    };

    assert_eq!(event.id, "e001");
    assert_eq!(event.trigger, "announced");

    let decon = event
        .decontextualized
        .expect("event should carry its decontextualized sentence");
    assert!(decon.decontextualized.contains("Apple Inc."));
}
1630
/// `JointModel` exposes the expected metadata through the `Model` trait.
#[test]
fn test_model_trait_implementation() {
    use crate::Model;

    let model = JointModel::default();

    // Call `name` through the trait explicitly: `JointModel` also implements
    // `CoreferenceResolver`, whose `name()` reports "joint-model-coref".
    assert_eq!(Model::name(&model), "joint-model");
    assert!(model.description().contains("Durrett"));
    assert!(model.is_available());
    assert!(!model.supported_types().is_empty());
}
1649
/// Smoke test for `Model::extract_entities` on a plain ASCII sentence.
///
/// Extraction must succeed, and every returned entity must have a span
/// that is not inverted and does not run past the end of the input.
#[test]
fn test_model_extract_entities_simple() {
    use crate::Model;

    let model = JointModel::default();

    let text = "John Smith visited New York";
    // Offsets are compared against the character count, matching the
    // convention used by the heuristic-detection tests in this module.
    let char_count = text.chars().count();

    let entities = model.extract_entities(text, None).unwrap();

    // Previously the result was discarded (`let _ = entities;`), so the test
    // only checked that extraction did not error.
    for entity in &entities {
        assert!(
            entity.start <= entity.end,
            "inverted span {}..{}",
            entity.start,
            entity.end
        );
        assert!(
            entity.end <= char_count,
            "span {}..{} exceeds text length {}",
            entity.start,
            entity.end,
            char_count
        );
    }
}
1663
/// `JointModel` reports its coreference identifier and resolves an empty
/// input to an empty output.
#[test]
fn test_coref_resolver_trait_implementation() {
    use anno_core::CoreferenceResolver;

    let model = JointModel::default();

    // Fully-qualified to disambiguate from `Model::name` ("joint-model").
    assert_eq!(CoreferenceResolver::name(&model), "joint-model-coref");

    let resolved = model.resolve(&[]);
    assert!(resolved.is_empty());
}
1677
/// Resolution returns one entity per input, and each one carries a
/// `canonical_id`.
#[test]
fn test_coref_resolver_assigns_canonical_ids() {
    use anno_core::CoreferenceResolver;

    let model = JointModel::default();

    let input = vec![
        Entity::new("John", EntityType::Person, 0, 4, 0.9),
        Entity::new("he", EntityType::Person, 10, 12, 0.8),
        Entity::new("Microsoft", EntityType::Organization, 20, 29, 0.95),
    ];

    let resolved = model.resolve(&input);
    assert_eq!(resolved.len(), 3);

    for (position, entity) in resolved.iter().enumerate() {
        assert!(
            entity.canonical_id.is_some(),
            "resolved entity #{} has no canonical_id",
            position
        );
    }
}
1698
/// A builder with no overrides yields the default configuration.
#[test]
fn test_builder_default() {
    let model = JointModelBuilder::new()
        .build()
        .expect("default builder configuration should build");
    let config = model.config();

    assert_eq!(config.max_iterations, 5);

    // All three pairwise-interaction flags default to enabled.
    assert!(config.enable_link_ner);
    assert!(config.enable_coref_ner);
    assert!(config.enable_coref_link);
}
1710
/// Every fluent setter on `JointModelBuilder` is reflected in the built
/// model's configuration.
#[test]
fn test_builder_fluent_api() {
    let builder = JointModelBuilder::new()
        .with_max_iterations(50)
        .with_convergence_threshold(1e-6)
        .with_pruning_threshold(0.5)
        .with_max_antecedent_candidates(100)
        .with_max_link_candidates(20)
        .enable_link_ner(false)
        .enable_coref_ner(true)
        .enable_coref_link(false);
    let model = builder
        .build()
        .expect("builder with explicit settings should build");
    let config = model.config();

    // Inference and pruning limits.
    assert_eq!(config.max_iterations, 50);
    assert!((config.convergence_threshold - 1e-6).abs() < 1e-10);
    assert!((config.pruning_threshold - 0.5).abs() < 1e-10);
    assert_eq!(config.max_antecedent_candidates, 100);
    assert_eq!(config.max_link_candidates, 20);

    // Pairwise-interaction toggles.
    assert!(!config.enable_link_ner);
    assert!(config.enable_coref_ner);
    assert!(!config.enable_coref_link);
}
1735
/// A custom entity-type list is passed through to the configuration unchanged.
#[test]
fn test_builder_with_entity_types() {
    let expected = vec![EntityType::Person, EntityType::Organization];

    let model = JointModelBuilder::new()
        .with_entity_types(expected.clone())
        .build()
        .expect("builder with custom entity types should build");

    assert_eq!(model.config().entity_types, expected);
}
1747
/// The heuristic detector finds at least one mention in a sentence of proper
/// names. Every span must be non-empty and lie within the text's character
/// length.
#[test]
fn test_heuristic_mention_detection() {
    let text = "Barack Obama met Angela Merkel in Berlin";
    let char_len = text.chars().count();

    let mentions = detect_mentions_heuristic(text);
    assert!(!mentions.is_empty(), "expected at least one mention");

    for m in &mentions {
        assert!(m.start < m.end, "empty or inverted span {}..{}", m.start, m.end);
        assert!(m.end <= char_len, "span {}..{} out of bounds", m.start, m.end);
    }
}
1764
/// Heuristic detection on text with non-ASCII letters must produce spans
/// bounded by the *character* count, not the byte length.
///
/// For this input the byte length (34) exceeds the character count (31), so
/// a byte-offset bug on the trailing "München" would trip the bound check.
#[test]
fn test_heuristic_mention_detection_unicode() {
    let text = "François Müller visited München";
    // Hoisted: the character count is loop-invariant.
    let char_count = text.chars().count();

    let entities = detect_mentions_heuristic(text);

    // Without this check the loop below passes vacuously when nothing is detected.
    assert!(!entities.is_empty(), "expected mentions in capitalised Unicode text");

    for entity in &entities {
        // Same strictness as the ASCII detection test: spans are non-empty.
        assert!(
            entity.start < entity.end,
            "empty or inverted span {}..{}",
            entity.start,
            entity.end
        );
        assert!(
            entity.end <= char_count,
            "span {}..{} exceeds character count {}",
            entity.start,
            entity.end,
            char_count
        );
    }
}
1778
/// Joint extraction over pre-detected mentions (two names plus a pronoun)
/// succeeds and yields at least one entity.
#[test]
fn test_extract_entities_from_mentions() {
    let model = JointModel::default();
    let text = "John Smith visited New York. He liked the city.";

    // Seed entities; their position in this array becomes the mention index.
    let seeds = [
        Entity::new("John Smith", EntityType::Person, 0, 10, 0.9),
        Entity::new("New York", EntityType::Location, 19, 27, 0.85),
        Entity::new("He", EntityType::Person, 29, 31, 0.7),
    ];
    let mentions: Vec<_> = seeds
        .iter()
        .enumerate()
        .map(|(idx, seed)| JointMention::from_entity(idx, seed, text))
        .collect();

    let entities = model
        .extract_entities_from_mentions(text, &mentions)
        .expect("joint extraction over pre-detected mentions should succeed");
    assert!(!entities.is_empty());
}
1805}