1use super::Entity;
28use serde::{Deserialize, Serialize};
29use std::collections::{HashMap, HashSet};
30
31pub use super::types::MentionType;
33
34#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
60pub struct Mention {
61 pub text: String,
63 pub start: usize,
65 pub end: usize,
67 pub head_start: Option<usize>,
69 pub head_end: Option<usize>,
71 pub entity_type: Option<String>,
73 pub mention_type: Option<MentionType>,
75}
76
77impl Mention {
78 #[must_use]
91 pub fn new(text: impl Into<String>, start: usize, end: usize) -> Self {
92 Self {
93 text: text.into(),
94 start,
95 end,
96 head_start: None,
97 head_end: None,
98 entity_type: None,
99 mention_type: None,
100 }
101 }
102
103 #[must_use]
113 pub fn with_head(
114 text: impl Into<String>,
115 start: usize,
116 end: usize,
117 head_start: usize,
118 head_end: usize,
119 ) -> Self {
120 Self {
121 text: text.into(),
122 start,
123 end,
124 head_start: Some(head_start),
125 head_end: Some(head_end),
126 entity_type: None,
127 mention_type: None,
128 }
129 }
130
131 #[must_use]
140 pub fn with_type(
141 text: impl Into<String>,
142 start: usize,
143 end: usize,
144 mention_type: MentionType,
145 ) -> Self {
146 Self {
147 text: text.into(),
148 start,
149 end,
150 head_start: None,
151 head_end: None,
152 entity_type: None,
153 mention_type: Some(mention_type),
154 }
155 }
156
157 #[must_use]
159 pub fn overlaps(&self, other: &Mention) -> bool {
160 self.start < other.end && other.start < self.end
161 }
162
163 #[must_use]
165 pub fn span_matches(&self, other: &Mention) -> bool {
166 self.start == other.start && self.end == other.end
167 }
168
169 #[must_use]
171 pub fn len(&self) -> usize {
172 self.end.saturating_sub(self.start)
173 }
174
175 #[must_use]
177 pub fn is_empty(&self) -> bool {
178 self.len() == 0
179 }
180
181 #[must_use]
183 pub fn span_id(&self) -> (usize, usize) {
184 (self.start, self.end)
185 }
186}
187
188impl std::fmt::Display for Mention {
189 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
190 write!(f, "\"{}\" [{}-{})", self.text, self.start, self.end)
191 }
192}
193
194#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
217pub struct CorefChain {
218 pub mentions: Vec<Mention>,
220 pub cluster_id: Option<super::types::CanonicalId>,
222 pub entity_type: Option<String>,
224}
225
226impl CorefChain {
227 #[must_use]
238 pub fn new(mut mentions: Vec<Mention>) -> Self {
239 mentions.sort_by_key(|m| (m.start, m.end));
240 Self {
241 mentions,
242 cluster_id: None,
243 entity_type: None,
244 }
245 }
246
247 #[must_use]
249 pub fn with_id(
250 mut mentions: Vec<Mention>,
251 cluster_id: impl Into<super::types::CanonicalId>,
252 ) -> Self {
253 mentions.sort_by_key(|m| (m.start, m.end));
254 Self {
255 mentions,
256 cluster_id: Some(cluster_id.into()),
257 entity_type: None,
258 }
259 }
260
261 #[must_use]
263 pub fn singleton(mention: Mention) -> Self {
264 Self {
265 mentions: vec![mention],
266 cluster_id: None,
267 entity_type: None,
268 }
269 }
270
271 #[must_use]
273 pub fn len(&self) -> usize {
274 self.mentions.len()
275 }
276
277 #[must_use]
279 pub fn is_empty(&self) -> bool {
280 self.mentions.is_empty()
281 }
282
283 #[must_use]
285 pub fn is_singleton(&self) -> bool {
286 self.mentions.len() == 1
287 }
288
289 #[must_use]
301 pub fn links(&self) -> Vec<(&Mention, &Mention)> {
302 let mut links = Vec::new();
303 for i in 0..self.mentions.len() {
304 for j in (i + 1)..self.mentions.len() {
305 links.push((&self.mentions[i], &self.mentions[j]));
306 }
307 }
308 links
309 }
310
311 #[must_use]
316 pub fn link_count(&self) -> usize {
317 if self.mentions.len() <= 1 {
318 0
319 } else {
320 self.mentions.len() - 1
321 }
322 }
323
324 #[must_use]
326 pub fn all_pairs(&self) -> Vec<(&Mention, &Mention)> {
327 self.links() }
329
330 #[must_use]
332 pub fn contains_span(&self, start: usize, end: usize) -> bool {
333 self.mentions
334 .iter()
335 .any(|m| m.start == start && m.end == end)
336 }
337
338 #[must_use]
340 pub fn first(&self) -> Option<&Mention> {
341 self.mentions.first()
342 }
343
344 #[must_use]
346 pub fn mention_spans(&self) -> HashSet<(usize, usize)> {
347 self.mentions.iter().map(|m| m.span_id()).collect()
348 }
349
350 #[must_use]
355 pub fn canonical_mention(&self) -> Option<&Mention> {
356 let proper = self
358 .mentions
359 .iter()
360 .filter(|m| m.mention_type == Some(MentionType::Proper))
361 .max_by_key(|m| m.text.len());
362
363 if proper.is_some() {
364 return proper;
365 }
366
367 self.mentions.iter().max_by_key(|m| m.text.len())
369 }
370
371 #[must_use]
373 pub fn canonical_id(&self) -> Option<super::types::CanonicalId> {
374 self.cluster_id
375 }
376}
377
378impl std::fmt::Display for CorefChain {
379 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
380 let mentions: Vec<String> = self
381 .mentions
382 .iter()
383 .map(|m| format!("\"{}\"", m.text))
384 .collect();
385 write!(f, "[{}]", mentions.join(", "))
386 }
387}
388
389#[derive(Debug, Clone, Serialize, Deserialize)]
397pub struct CorefDocument {
398 pub text: String,
400 pub doc_id: Option<String>,
402 pub chains: Vec<CorefChain>,
404 pub includes_singletons: bool,
406}
407
408impl CorefDocument {
409 #[must_use]
423 pub fn new(text: impl Into<String>, chains: Vec<CorefChain>) -> Self {
424 Self {
425 text: text.into(),
426 doc_id: None,
427 chains,
428 includes_singletons: false,
429 }
430 }
431
432 #[must_use]
434 pub fn with_id(
435 text: impl Into<String>,
436 doc_id: impl Into<String>,
437 chains: Vec<CorefChain>,
438 ) -> Self {
439 Self {
440 text: text.into(),
441 doc_id: Some(doc_id.into()),
442 chains,
443 includes_singletons: false,
444 }
445 }
446
447 #[must_use]
449 pub fn mention_count(&self) -> usize {
450 self.chains.iter().map(|c| c.len()).sum()
451 }
452
453 #[must_use]
455 pub fn chain_count(&self) -> usize {
456 self.chains.len()
457 }
458
459 #[must_use]
461 pub fn non_singleton_count(&self) -> usize {
462 self.chains.iter().filter(|c| !c.is_singleton()).count()
463 }
464
465 #[must_use]
467 pub fn all_mentions(&self) -> Vec<&Mention> {
468 let mut mentions: Vec<&Mention> = self.chains.iter().flat_map(|c| &c.mentions).collect();
469 mentions.sort_by_key(|m| (m.start, m.end));
470 mentions
471 }
472
473 #[must_use]
475 pub fn find_chain(&self, start: usize, end: usize) -> Option<&CorefChain> {
476 self.chains.iter().find(|c| c.contains_span(start, end))
477 }
478
479 #[must_use]
481 pub fn mention_to_chain_index(&self) -> HashMap<(usize, usize), usize> {
482 let mut index = HashMap::new();
483 for (chain_idx, chain) in self.chains.iter().enumerate() {
484 for mention in &chain.mentions {
485 index.insert(mention.span_id(), chain_idx);
486 }
487 }
488 index
489 }
490
491 #[must_use]
493 pub fn without_singletons(&self) -> Self {
494 Self {
495 text: self.text.clone(),
496 doc_id: self.doc_id.clone(),
497 chains: self
498 .chains
499 .iter()
500 .filter(|c| !c.is_singleton())
501 .cloned()
502 .collect(),
503 includes_singletons: false,
504 }
505 }
506}
507
508impl From<&Entity> for Mention {
513 fn from(entity: &Entity) -> Self {
514 Self {
515 text: entity.text.clone(),
516 start: entity.start,
517 end: entity.end,
518 head_start: None,
519 head_end: None,
520 entity_type: Some(entity.entity_type.as_label().to_string()),
521 mention_type: None,
522 }
523 }
524}
525
526#[must_use]
530pub fn entities_to_chains(entities: &[Entity]) -> Vec<CorefChain> {
531 let mut clusters: HashMap<u64, Vec<Mention>> = HashMap::new();
532 let mut singletons: Vec<Mention> = Vec::new();
533
534 for entity in entities {
535 let mention = Mention::from(entity);
536 if let Some(canonical_id) = entity.canonical_id {
537 clusters
538 .entry(canonical_id.get())
539 .or_default()
540 .push(mention);
541 } else {
542 singletons.push(mention);
543 }
544 }
545
546 let mut chains: Vec<CorefChain> = clusters
547 .into_iter()
548 .map(|(id, mentions)| CorefChain::with_id(mentions, id))
549 .collect();
550
551 for mention in singletons {
553 chains.push(CorefChain::singleton(mention));
554 }
555
556 chains
557}
558
559pub trait CoreferenceResolver: Send + Sync {
611 fn resolve(&self, entities: &[Entity]) -> Vec<Entity>;
623
624 fn resolve_to_chains(&self, entities: &[Entity]) -> Vec<CorefChain> {
630 let resolved = self.resolve(entities);
631 entities_to_chains(&resolved)
632 }
633
634 fn name(&self) -> &'static str;
638}
639
/// Unit tests for mentions, chains, documents and entity-to-chain conversion.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mention_creation() {
        let m = Mention::new("John", 0, 4);
        assert_eq!(m.text, "John");
        assert_eq!(m.start, 0);
        assert_eq!(m.end, 4);
        assert_eq!(m.len(), 4);
    }

    #[test]
    fn test_mention_overlap() {
        let m1 = Mention::new("John Smith", 0, 10);
        let m2 = Mention::new("Smith", 5, 10);
        let m3 = Mention::new("works", 11, 16);

        assert!(m1.overlaps(&m2));
        assert!(!m1.overlaps(&m3));
        assert!(!m2.overlaps(&m3));
    }

    #[test]
    fn test_chain_creation() {
        let mentions = vec![
            Mention::new("John", 0, 4),
            Mention::new("he", 20, 22),
            Mention::new("him", 40, 43),
        ];
        let chain = CorefChain::new(mentions);

        assert_eq!(chain.len(), 3);
        assert!(!chain.is_singleton());
        // Three mentions need two links to be connected.
        assert_eq!(chain.link_count(), 2);
    }

    #[test]
    fn test_chain_links() {
        let mentions = vec![
            Mention::new("a", 0, 1),
            Mention::new("b", 2, 3),
            Mention::new("c", 4, 5),
        ];
        let chain = CorefChain::new(mentions);

        // All pairs of three mentions: (a,b), (a,c), (b,c).
        assert_eq!(chain.all_pairs().len(), 3);
    }

    #[test]
    fn test_singleton_chain() {
        let m = Mention::new("entity", 0, 6);
        let chain = CorefChain::singleton(m);

        assert!(chain.is_singleton());
        assert_eq!(chain.link_count(), 0);
        assert!(chain.all_pairs().is_empty());
    }

    #[test]
    fn test_document() {
        let text = "John went to the store. He bought milk.";
        let chain = CorefChain::new(vec![Mention::new("John", 0, 4), Mention::new("He", 24, 26)]);
        let doc = CorefDocument::new(text, vec![chain]);

        assert_eq!(doc.mention_count(), 2);
        assert_eq!(doc.chain_count(), 1);
        assert_eq!(doc.non_singleton_count(), 1);
    }

    #[test]
    fn test_mention_to_chain_index() {
        let chain1 = CorefChain::new(vec![Mention::new("John", 0, 4), Mention::new("he", 20, 22)]);
        let chain2 = CorefChain::new(vec![
            Mention::new("Mary", 5, 9),
            Mention::new("she", 30, 33),
        ]);
        let doc = CorefDocument::new("text", vec![chain1, chain2]);

        let index = doc.mention_to_chain_index();
        assert_eq!(index.get(&(0, 4)), Some(&0));
        assert_eq!(index.get(&(20, 22)), Some(&0));
        assert_eq!(index.get(&(5, 9)), Some(&1));
        assert_eq!(index.get(&(30, 33)), Some(&1));
    }

    #[test]
    fn test_unicode_mention_offsets() {
        // Two characters, addressed here with offsets 0..2.
        let m = Mention::new("北京", 0, 2);
        assert_eq!(m.len(), 2);
        assert_eq!(m.span_id(), (0, 2));
        assert!(!m.is_empty());
    }

    #[test]
    fn test_zero_length_mention() {
        let m = Mention::new("", 5, 5);
        assert!(m.is_empty());
        assert_eq!(m.len(), 0);
        assert_eq!(m.span_id(), (5, 5));
    }

    #[test]
    fn test_empty_chain() {
        let chain = CorefChain::new(vec![]);
        assert!(chain.is_empty());
        assert_eq!(chain.link_count(), 0);
        assert!(chain.all_pairs().is_empty());
        assert!(chain.first().is_none());
        assert!(chain.canonical_mention().is_none());
    }

    #[test]
    fn test_chain_sorting_out_of_order() {
        let chain = CorefChain::new(vec![
            Mention::new("c", 20, 21),
            Mention::new("a", 0, 1),
            Mention::new("b", 10, 11),
        ]);
        assert_eq!(chain.mentions[0].text, "a");
        assert_eq!(chain.mentions[1].text, "b");
        assert_eq!(chain.mentions[2].text, "c");
    }

    #[test]
    fn test_chain_sorting_ties_broken_by_end() {
        // Same start: the mention with the smaller end sorts first.
        let chain = CorefChain::new(vec![
            Mention::new("John Smith", 0, 10),
            Mention::new("John", 0, 4),
        ]);
        assert_eq!(chain.mentions[0].text, "John");
        assert_eq!(chain.mentions[1].text, "John Smith");
    }

    #[test]
    fn test_entities_to_chains_grouped() {
        use super::super::entity::EntityType;
        use super::super::types::CanonicalId;

        let e1 = super::super::Entity::new("John", EntityType::Person, 0, 4, 0.9)
            .with_canonical_id(1_u64);
        let e2 = super::super::Entity::new("he", EntityType::Person, 20, 22, 0.8)
            .with_canonical_id(1_u64);
        let e3 = super::super::Entity::new("Mary", EntityType::Person, 5, 9, 0.95)
            .with_canonical_id(2_u64);

        let chains = entities_to_chains(&[e1, e2, e3]);

        assert_eq!(chains.len(), 2);

        // Look chains up by id rather than by position.
        let chain1 = chains
            .iter()
            .find(|c| c.cluster_id == Some(CanonicalId::new(1)))
            .expect("chain with id=1");
        assert_eq!(chain1.len(), 2);

        let chain2 = chains
            .iter()
            .find(|c| c.cluster_id == Some(CanonicalId::new(2)))
            .expect("chain with id=2");
        assert_eq!(chain2.len(), 1);
    }

    #[test]
    fn test_entities_to_chains_singletons() {
        use super::super::entity::EntityType;

        // No canonical ids: each entity becomes its own chain.
        let e1 = super::super::Entity::new("Paris", EntityType::Location, 0, 5, 0.9);
        let e2 = super::super::Entity::new("London", EntityType::Location, 10, 16, 0.85);

        let chains = entities_to_chains(&[e1, e2]);
        assert_eq!(chains.len(), 2);
        assert!(chains.iter().all(|c| c.is_singleton()));
    }

    #[test]
    fn test_entities_to_chains_empty() {
        let chains = entities_to_chains(&[]);
        assert!(chains.is_empty());
    }

    #[test]
    fn test_without_singletons_filters() {
        let singleton = CorefChain::singleton(Mention::new("solo", 0, 4));
        let multi = CorefChain::new(vec![
            Mention::new("John", 10, 14),
            Mention::new("he", 20, 22),
        ]);
        let doc = CorefDocument::new("text", vec![singleton, multi]);

        let filtered = doc.without_singletons();
        assert_eq!(filtered.chain_count(), 1);
        assert_eq!(filtered.chains[0].len(), 2);
        assert!(!filtered.includes_singletons);
    }

    #[test]
    fn test_without_singletons_preserves_non_singletons() {
        let c1 = CorefChain::new(vec![Mention::new("a", 0, 1), Mention::new("b", 2, 3)]);
        let c2 = CorefChain::new(vec![
            Mention::new("x", 10, 11),
            Mention::new("y", 12, 13),
            Mention::new("z", 14, 15),
        ]);
        let doc = CorefDocument::new("text", vec![c1.clone(), c2.clone()]);

        let filtered = doc.without_singletons();
        assert_eq!(filtered.chain_count(), 2);
        // The surviving chains must be the originals, unchanged and in order.
        assert_eq!(filtered.chains, vec![c1, c2]);
    }

    #[test]
    fn test_without_singletons_all_singletons() {
        let s1 = CorefChain::singleton(Mention::new("a", 0, 1));
        let s2 = CorefChain::singleton(Mention::new("b", 2, 3));
        let doc = CorefDocument::new("text", vec![s1, s2]);

        let filtered = doc.without_singletons();
        assert!(filtered.chains.is_empty());
    }

    #[test]
    fn test_overlaps_adjacent_non_overlapping() {
        // Spans that only touch at offset 5 do not overlap.
        let m1 = Mention::new("hello", 0, 5);
        let m2 = Mention::new("world", 5, 10);
        assert!(!m1.overlaps(&m2));
        assert!(!m2.overlaps(&m1));
    }

    #[test]
    fn test_overlaps_nested() {
        // A span fully inside another overlaps it, in both directions.
        let outer = Mention::new("the big dog", 0, 10);
        let inner = Mention::new("big", 2, 5);
        assert!(outer.overlaps(&inner));
        assert!(inner.overlaps(&outer));
    }

    #[test]
    fn test_chain_with_id() {
        let chain = CorefChain::with_id(
            vec![Mention::new("John", 0, 4), Mention::new("he", 10, 12)],
            42_u64,
        );
        assert_eq!(
            chain.canonical_id(),
            Some(super::super::types::CanonicalId::new(42))
        );
        assert_eq!(
            chain.cluster_id,
            Some(super::super::types::CanonicalId::new(42))
        );
        assert_eq!(chain.mentions[0].text, "John");
    }
}
915
/// Property-based tests for mention and chain invariants.
#[cfg(test)]
mod proptests {
    #![allow(clippy::unwrap_used)]
    use super::*;
    use proptest::prelude::*;

    // Strategy for a non-empty mention: start below `max_offset`, length in
    // 1..500, text derived from the start offset.
    fn arb_mention(max_offset: usize) -> impl Strategy<Value = Mention> {
        let starts = 0usize..max_offset;
        let lengths = 1usize..500;
        (starts, lengths).prop_map(|(start, len)| {
            let end = start + len;
            Mention::new(format!("m_{}", start), start, end)
        })
    }

    proptest! {
        // After construction, mentions are ordered by (start, end).
        #[test]
        fn mention_ordering_after_chain_construction(
            mentions in proptest::collection::vec(arb_mention(10000), 1..20),
        ) {
            let chain = CorefChain::new(mentions);
            for pair in chain.mentions.windows(2) {
                let (prev, next) = (&pair[0], &pair[1]);
                prop_assert!(
                    (prev.start, prev.end) <= (next.start, next.end),
                    "mentions must be sorted by (start, end): ({},{}) vs ({},{})",
                    prev.start, prev.end, next.start, next.end
                );
            }
        }

        // A chain built from a non-empty vector keeps every mention.
        #[test]
        fn coref_chain_non_empty(
            mentions in proptest::collection::vec(arb_mention(10000), 1..20),
        ) {
            let expected_len = mentions.len();
            let chain = CorefChain::new(mentions);
            prop_assert!(!chain.is_empty());
            prop_assert_eq!(chain.len(), expected_len);
        }

        // A singleton chain has one mention and no links.
        #[test]
        fn coref_chain_singleton_has_one(start in 0usize..10000, len in 1usize..500) {
            let chain = CorefChain::singleton(Mention::new("x", start, start + len));
            prop_assert!(chain.is_singleton());
            prop_assert_eq!(chain.len(), 1);
            prop_assert_eq!(chain.link_count(), 0);
        }

        // Overlap does not depend on argument order.
        #[test]
        fn mention_overlap_symmetric(
            s1 in 0usize..10000, len1 in 1usize..500,
            s2 in 0usize..10000, len2 in 1usize..500,
        ) {
            let left = Mention::new("a", s1, s1 + len1);
            let right = Mention::new("b", s2, s2 + len2);
            prop_assert_eq!(left.overlaps(&right), right.overlaps(&left));
        }

        // Serializing to JSON and back yields an equal mention.
        #[test]
        fn mention_serde_roundtrip(
            start in 0usize..10000, len in 1usize..500,
        ) {
            let original = Mention::new(format!("mention_{}", start), start, start + len);
            let json = serde_json::to_string(&original).unwrap();
            let restored: Mention = serde_json::from_str(&json).unwrap();
            prop_assert_eq!(&original, &restored);
        }
    }
}