use super::Entity;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
pub use super::types::MentionType;
/// One mention of an entity: the `text` covering the half-open span `[start, end)`.
///
/// Equality and hashing compare every field (text, span, head, types), not just
/// the span; use `Mention::span_matches` for a span-only comparison.
///
/// NOTE(review): offsets are plain `usize`s whose unit is fixed by the producer;
/// the unit tests in this file treat them as character offsets (e.g. `"北京"`
/// spans `0..2`). Confirm against whatever builds mentions.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct Mention {
/// Surface text of the mention.
pub text: String,
/// Inclusive start offset of the span.
pub start: usize,
/// Exclusive end offset of the span.
pub end: usize,
/// Optional start of the mention's head span. `span_id_head` only uses the
/// head when both `head_start` and `head_end` are set.
pub head_start: Option<usize>,
/// Optional end of the head span (presumably exclusive, mirroring `end`).
pub head_end: Option<usize>,
/// Optional entity-type label; `From<&Entity>` fills it from the entity's
/// `as_label()` string.
pub entity_type: Option<String>,
/// Optional mention category (e.g. `MentionType::Proper`), which
/// `CorefChain::canonical_mention` uses to prefer proper-name mentions.
pub mention_type: Option<MentionType>,
}
impl Mention {
    /// Creates a mention over `[start, end)` with no head span, entity type or
    /// mention type.
    #[must_use]
    pub fn new(text: impl Into<String>, start: usize, end: usize) -> Self {
        Self {
            text: text.into(),
            start,
            end,
            head_start: None,
            head_end: None,
            entity_type: None,
            mention_type: None,
        }
    }

    /// Creates a mention over `[start, end)` that also records a head span
    /// `(head_start, head_end)`.
    ///
    /// The head span is stored as given; it is not checked against the outer span.
    #[must_use]
    pub fn with_head(
        text: impl Into<String>,
        start: usize,
        end: usize,
        head_start: usize,
        head_end: usize,
    ) -> Self {
        let mut mention = Self::new(text, start, end);
        mention.head_start = Some(head_start);
        mention.head_end = Some(head_end);
        mention
    }

    /// Creates a mention over `[start, end)` tagged with `mention_type`.
    #[must_use]
    pub fn with_type(
        text: impl Into<String>,
        start: usize,
        end: usize,
        mention_type: MentionType,
    ) -> Self {
        let mut mention = Self::new(text, start, end);
        mention.mention_type = Some(mention_type);
        mention
    }

    /// Returns `true` when the two half-open spans share at least one position.
    ///
    /// Adjacent spans (one's `end` equals the other's `start`) do not overlap.
    #[must_use]
    pub fn overlaps(&self, other: &Mention) -> bool {
        // Disjoint iff one span ends at or before the other begins.
        let disjoint = other.end <= self.start || self.end <= other.start;
        !disjoint
    }

    /// Returns `true` when both mentions cover exactly the same `[start, end)`.
    #[must_use]
    pub fn span_matches(&self, other: &Mention) -> bool {
        self.span_id() == other.span_id()
    }

    /// Span length `end - start`, or `0` if `end` does not exceed `start`.
    #[must_use]
    pub fn len(&self) -> usize {
        if self.end > self.start {
            self.end - self.start
        } else {
            0
        }
    }

    /// Returns `true` when the span has no extent (`end <= start`).
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.end <= self.start
    }

    /// The `(start, end)` pair identifying this mention's span.
    #[must_use]
    pub fn span_id(&self) -> (usize, usize) {
        (self.start, self.end)
    }

    /// The head span `(head_start, head_end)` when both ends are set, otherwise
    /// the full `(start, end)` span.
    #[must_use]
    pub fn span_id_head(&self) -> (usize, usize) {
        self.head_start
            .zip(self.head_end)
            .unwrap_or((self.start, self.end))
    }
}
impl std::fmt::Display for Mention {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "\"{}\" [{}-{})", self.text, self.start, self.end)
}
}
/// A coreference chain: a group of mentions clustered together.
///
/// Equality compares the mentions (in order), `cluster_id` and `entity_type`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct CorefChain {
/// Mentions in the chain. `CorefChain::new` and `CorefChain::with_id` sort
/// them by `(start, end)`; the field is public, so direct mutation can break
/// that order.
pub mentions: Vec<Mention>,
/// Optional cluster identifier, set by `CorefChain::with_id` (for example
/// from an entity's canonical id in `entities_to_chains`).
pub cluster_id: Option<super::types::CanonicalId>,
/// Optional entity-type label for the whole chain; no constructor in this
/// file sets it.
pub entity_type: Option<String>,
}
impl CorefChain {
    /// Builds a chain with no `cluster_id` and no `entity_type`.
    ///
    /// Mentions are sorted by `(start, end)`: document order, with the shorter
    /// span first when two mentions share a start offset. The sort is stable, so
    /// mentions with identical spans keep their input order.
    #[must_use]
    pub fn new(mut mentions: Vec<Mention>) -> Self {
        mentions.sort_by_key(|m| (m.start, m.end));
        Self {
            mentions,
            cluster_id: None,
            entity_type: None,
        }
    }

    /// Like [`CorefChain::new`] (same sorting), but also records `cluster_id`.
    #[must_use]
    pub fn with_id(
        mentions: Vec<Mention>,
        cluster_id: impl Into<super::types::CanonicalId>,
    ) -> Self {
        Self {
            cluster_id: Some(cluster_id.into()),
            ..Self::new(mentions)
        }
    }

    /// Builds a chain containing exactly one mention, with no `cluster_id`.
    #[must_use]
    pub fn singleton(mention: Mention) -> Self {
        Self::new(vec![mention])
    }

    /// Number of mentions in the chain.
    #[must_use]
    pub fn len(&self) -> usize {
        self.mentions.len()
    }

    /// Returns `true` when the chain has no mentions.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.mentions.is_empty()
    }

    /// Returns `true` when the chain has exactly one mention.
    #[must_use]
    pub fn is_singleton(&self) -> bool {
        self.mentions.len() == 1
    }

    /// Returns every unordered pair `(a, b)` where `a` precedes `b` in
    /// `self.mentions`: `n * (n - 1) / 2` pairs for `n` mentions.
    ///
    /// Note: despite the name, this is *all pairs*, not the `n - 1` links
    /// counted by [`CorefChain::link_count`].
    #[must_use]
    pub fn links(&self) -> Vec<(&Mention, &Mention)> {
        let n = self.mentions.len();
        let mut links = Vec::with_capacity(n * n.saturating_sub(1) / 2);
        for (i, earlier) in self.mentions.iter().enumerate() {
            for later in &self.mentions[i + 1..] {
                links.push((earlier, later));
            }
        }
        links
    }

    /// Minimum number of links needed to connect every mention: `len() - 1`,
    /// or `0` for empty and singleton chains.
    ///
    /// This is *not* `links().len()`, which counts all pairs.
    #[must_use]
    pub fn link_count(&self) -> usize {
        self.mentions.len().saturating_sub(1)
    }

    /// Alias for [`CorefChain::links`]: all unordered mention pairs.
    #[must_use]
    pub fn all_pairs(&self) -> Vec<(&Mention, &Mention)> {
        self.links()
    }

    /// Returns `true` if some mention covers exactly `[start, end)`.
    #[must_use]
    pub fn contains_span(&self, start: usize, end: usize) -> bool {
        self.mentions
            .iter()
            .any(|m| m.start == start && m.end == end)
    }

    /// The first mention in `self.mentions`. After `new`/`with_id` this is the
    /// earliest one in the document.
    #[must_use]
    pub fn first(&self) -> Option<&Mention> {
        self.mentions.first()
    }

    /// The set of `(start, end)` spans in the chain.
    #[must_use]
    pub fn mention_spans(&self) -> HashSet<(usize, usize)> {
        self.mentions.iter().map(|m| m.span_id()).collect()
    }

    /// Picks a representative mention.
    ///
    /// This is the longest mention typed `MentionType::Proper`, or, if there is
    /// none, the longest mention overall. Length is the number of `char`s in
    /// `text`, consistent with the character-offset convention used for spans,
    /// so multi-byte scripts are not favoured over ASCII. On ties the later
    /// mention in `self.mentions` wins, because `Iterator::max_by_key` returns
    /// the last maximum. Returns `None` for an empty chain.
    #[must_use]
    pub fn canonical_mention(&self) -> Option<&Mention> {
        self.mentions
            .iter()
            .filter(|m| m.mention_type == Some(MentionType::Proper))
            .max_by_key(|m| m.text.chars().count())
            .or_else(|| self.mentions.iter().max_by_key(|m| m.text.chars().count()))
    }

    /// The chain's cluster identifier, if one was assigned.
    #[must_use]
    pub fn canonical_id(&self) -> Option<super::types::CanonicalId> {
        self.cluster_id
    }
}
impl std::fmt::Display for CorefChain {
    /// Renders the chain as a bracketed, comma-separated list of quoted mention
    /// texts, e.g. `["John", "he"]`; an empty chain renders as `[]`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("[")?;
        for (position, mention) in self.mentions.iter().enumerate() {
            if position > 0 {
                f.write_str(", ")?;
            }
            write!(f, "\"{}\"", mention.text)?;
        }
        f.write_str("]")
    }
}
/// A document's text together with its coreference chains.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CorefDocument {
/// Full document text.
pub text: String,
/// Optional document identifier.
pub doc_id: Option<String>,
/// Coreference chains over `text`.
pub chains: Vec<CorefChain>,
/// Whether `chains` is meant to include singleton chains. The constructors
/// and `without_singletons` in this file always set it to `false`; it is not
/// recomputed from `chains`.
pub includes_singletons: bool,
}
impl CorefDocument {
#[must_use]
pub fn new(text: impl Into<String>, chains: Vec<CorefChain>) -> Self {
Self {
text: text.into(),
doc_id: None,
chains,
includes_singletons: false,
}
}
#[must_use]
pub fn with_id(
text: impl Into<String>,
doc_id: impl Into<String>,
chains: Vec<CorefChain>,
) -> Self {
Self {
text: text.into(),
doc_id: Some(doc_id.into()),
chains,
includes_singletons: false,
}
}
#[must_use]
pub fn mention_count(&self) -> usize {
self.chains.iter().map(|c| c.len()).sum()
}
#[must_use]
pub fn chain_count(&self) -> usize {
self.chains.len()
}
#[must_use]
pub fn non_singleton_count(&self) -> usize {
self.chains.iter().filter(|c| !c.is_singleton()).count()
}
#[must_use]
pub fn all_mentions(&self) -> Vec<&Mention> {
let mut mentions: Vec<&Mention> = self.chains.iter().flat_map(|c| &c.mentions).collect();
mentions.sort_by_key(|m| (m.start, m.end));
mentions
}
#[must_use]
pub fn find_chain(&self, start: usize, end: usize) -> Option<&CorefChain> {
self.chains.iter().find(|c| c.contains_span(start, end))
}
#[must_use]
pub fn mention_to_chain_index(&self) -> HashMap<(usize, usize), usize> {
let mut index = HashMap::new();
for (chain_idx, chain) in self.chains.iter().enumerate() {
for mention in &chain.mentions {
index.insert(mention.span_id(), chain_idx);
}
}
index
}
#[must_use]
pub fn without_singletons(&self) -> Self {
Self {
text: self.text.clone(),
doc_id: self.doc_id.clone(),
chains: self
.chains
.iter()
.filter(|c| !c.is_singleton())
.cloned()
.collect(),
includes_singletons: false,
}
}
}
impl From<&Entity> for Mention {
fn from(entity: &Entity) -> Self {
Self {
text: entity.text.clone(),
start: entity.start(),
end: entity.end(),
head_start: None,
head_end: None,
entity_type: Some(entity.entity_type.as_label().to_string()),
mention_type: entity.mention_type,
}
}
}
/// Groups entities into coreference chains by their `canonical_id`.
///
/// - Entities sharing a canonical id become one chain built with
///   `CorefChain::with_id`, which sorts its mentions by position.
/// - Entities without a canonical id each become a singleton chain.
///
/// The output order is deterministic. Clustered chains come first, in the
/// order their first member appears in `entities`, followed by singleton
/// chains in input order. The previous implementation iterated a `HashMap`,
/// so chain order varied from run to run.
#[must_use]
pub fn entities_to_chains(entities: &[Entity]) -> Vec<CorefChain> {
    // Cluster id -> index into `clusters`; `clusters` preserves first appearance.
    let mut slot_of: HashMap<u64, usize> = HashMap::new();
    let mut clusters: Vec<(u64, Vec<Mention>)> = Vec::new();
    let mut singletons: Vec<Mention> = Vec::new();
    for entity in entities {
        let mention = Mention::from(entity);
        match entity.canonical_id {
            Some(canonical_id) => {
                let id = canonical_id.get();
                let slot = *slot_of.entry(id).or_insert_with(|| {
                    clusters.push((id, Vec::new()));
                    clusters.len() - 1
                });
                clusters[slot].1.push(mention);
            }
            None => singletons.push(mention),
        }
    }
    let mut chains = Vec::with_capacity(clusters.len() + singletons.len());
    chains.extend(
        clusters
            .into_iter()
            .map(|(id, mentions)| CorefChain::with_id(mentions, id)),
    );
    chains.extend(singletons.into_iter().map(CorefChain::singleton));
    chains
}
/// A coreference resolver over detected entities.
///
/// NOTE(review): the default `resolve_to_chains` groups the output of
/// `resolve` by `canonical_id`, so implementations are presumably expected to
/// assign `canonical_id` to coreferent entities. Confirm against implementors.
pub trait CoreferenceResolver: Send + Sync {
    /// Resolves coreference over `entities`, returning the (updated) entities.
    fn resolve(&self, entities: &[Entity]) -> Vec<Entity>;

    /// Runs [`CoreferenceResolver::resolve`] and groups the result into chains
    /// with `entities_to_chains`.
    fn resolve_to_chains(&self, entities: &[Entity]) -> Vec<CorefChain> {
        entities_to_chains(&self.resolve(entities))
    }

    /// A static, human-readable name for this resolver.
    fn name(&self) -> &'static str;
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mention_creation() {
        let m = Mention::new("John", 0, 4);
        assert_eq!(m.text, "John");
        assert_eq!(m.start, 0);
        assert_eq!(m.end, 4);
        assert_eq!(m.len(), 4);
    }

    #[test]
    fn test_mention_overlap() {
        let m1 = Mention::new("John Smith", 0, 10);
        let m2 = Mention::new("Smith", 5, 10);
        let m3 = Mention::new("works", 11, 16);
        assert!(m1.overlaps(&m2));
        assert!(!m1.overlaps(&m3));
        assert!(!m2.overlaps(&m3));
    }

    #[test]
    fn test_chain_creation() {
        let mentions = vec![
            Mention::new("John", 0, 4),
            Mention::new("he", 20, 22),
            Mention::new("him", 40, 43),
        ];
        let chain = CorefChain::new(mentions);
        assert_eq!(chain.len(), 3);
        assert!(!chain.is_singleton());
        // n - 1 links for n mentions.
        assert_eq!(chain.link_count(), 2);
    }

    #[test]
    fn test_chain_links() {
        let mentions = vec![
            Mention::new("a", 0, 1),
            Mention::new("b", 2, 3),
            Mention::new("c", 4, 5),
        ];
        let chain = CorefChain::new(mentions);
        // All unordered pairs: 3 * 2 / 2.
        assert_eq!(chain.all_pairs().len(), 3);
    }

    #[test]
    fn test_singleton_chain() {
        let m = Mention::new("entity", 0, 6);
        let chain = CorefChain::singleton(m);
        assert!(chain.is_singleton());
        assert_eq!(chain.link_count(), 0);
        assert!(chain.all_pairs().is_empty());
    }

    #[test]
    fn test_document() {
        let text = "John went to the store. He bought milk.";
        let chain = CorefChain::new(vec![Mention::new("John", 0, 4), Mention::new("He", 24, 26)]);
        let doc = CorefDocument::new(text, vec![chain]);
        assert_eq!(doc.mention_count(), 2);
        assert_eq!(doc.chain_count(), 1);
        assert_eq!(doc.non_singleton_count(), 1);
    }

    #[test]
    fn test_mention_to_chain_index() {
        let chain1 = CorefChain::new(vec![Mention::new("John", 0, 4), Mention::new("he", 20, 22)]);
        let chain2 = CorefChain::new(vec![
            Mention::new("Mary", 5, 9),
            Mention::new("she", 30, 33),
        ]);
        let doc = CorefDocument::new("text", vec![chain1, chain2]);
        let index = doc.mention_to_chain_index();
        assert_eq!(index.get(&(0, 4)), Some(&0));
        assert_eq!(index.get(&(20, 22)), Some(&0));
        assert_eq!(index.get(&(5, 9)), Some(&1));
        assert_eq!(index.get(&(30, 33)), Some(&1));
    }

    #[test]
    fn test_unicode_mention_offsets() {
        // Offsets are character offsets: two CJK characters span 0..2.
        let m = Mention::new("北京", 0, 2);
        assert_eq!(m.len(), 2);
        assert_eq!(m.span_id(), (0, 2));
        assert!(!m.is_empty());
    }

    #[test]
    fn test_zero_length_mention() {
        let m = Mention::new("", 5, 5);
        assert!(m.is_empty());
        assert_eq!(m.len(), 0);
        assert_eq!(m.span_id(), (5, 5));
    }

    #[test]
    fn test_empty_chain() {
        let chain = CorefChain::new(vec![]);
        assert!(chain.is_empty());
        assert_eq!(chain.link_count(), 0);
        assert!(chain.all_pairs().is_empty());
        assert!(chain.first().is_none());
        assert!(chain.canonical_mention().is_none());
    }

    #[test]
    fn test_chain_sorting_out_of_order() {
        let chain = CorefChain::new(vec![
            Mention::new("c", 20, 21),
            Mention::new("a", 0, 1),
            Mention::new("b", 10, 11),
        ]);
        assert_eq!(chain.mentions[0].text, "a");
        assert_eq!(chain.mentions[1].text, "b");
        assert_eq!(chain.mentions[2].text, "c");
    }

    #[test]
    fn test_chain_sorting_ties_broken_by_end() {
        let chain = CorefChain::new(vec![
            Mention::new("John Smith", 0, 10),
            Mention::new("John", 0, 4),
        ]);
        assert_eq!(chain.mentions[0].text, "John");
        assert_eq!(chain.mentions[1].text, "John Smith");
    }

    #[test]
    fn test_entities_to_chains_grouped() {
        use super::super::entity::EntityType;
        use super::super::types::CanonicalId;
        let e1 = super::super::Entity::new("John", EntityType::Person, 0, 4, 0.9)
            .with_canonical_id(1_u64);
        let e2 = super::super::Entity::new("he", EntityType::Person, 20, 22, 0.8)
            .with_canonical_id(1_u64);
        let e3 = super::super::Entity::new("Mary", EntityType::Person, 5, 9, 0.95)
            .with_canonical_id(2_u64);
        let chains = entities_to_chains(&[e1, e2, e3]);
        assert_eq!(chains.len(), 2);
        let chain1 = chains
            .iter()
            .find(|c| c.cluster_id == Some(CanonicalId::new(1)))
            .expect("chain with id=1");
        assert_eq!(chain1.len(), 2);
        let chain2 = chains
            .iter()
            .find(|c| c.cluster_id == Some(CanonicalId::new(2)))
            .expect("chain with id=2");
        assert_eq!(chain2.len(), 1);
    }

    #[test]
    fn test_entities_to_chains_singletons() {
        use super::super::entity::EntityType;
        let e1 = super::super::Entity::new("Paris", EntityType::Location, 0, 5, 0.9);
        let e2 = super::super::Entity::new("London", EntityType::Location, 10, 16, 0.85);
        let chains = entities_to_chains(&[e1, e2]);
        assert_eq!(chains.len(), 2);
        assert!(chains.iter().all(|c| c.is_singleton()));
    }

    #[test]
    fn test_entities_to_chains_empty() {
        let chains = entities_to_chains(&[]);
        assert!(chains.is_empty());
    }

    #[test]
    fn test_without_singletons_filters() {
        let singleton = CorefChain::singleton(Mention::new("solo", 0, 4));
        let multi = CorefChain::new(vec![
            Mention::new("John", 10, 14),
            Mention::new("he", 20, 22),
        ]);
        let doc = CorefDocument::new("text", vec![singleton, multi]);
        let filtered = doc.without_singletons();
        assert_eq!(filtered.chain_count(), 1);
        assert_eq!(filtered.chains[0].len(), 2);
        assert!(!filtered.includes_singletons);
    }

    #[test]
    fn test_without_singletons_preserves_non_singletons() {
        let c1 = CorefChain::new(vec![Mention::new("a", 0, 1), Mention::new("b", 2, 3)]);
        let c2 = CorefChain::new(vec![
            Mention::new("x", 10, 11),
            Mention::new("y", 12, 13),
            Mention::new("z", 14, 15),
        ]);
        let doc = CorefDocument::new("text", vec![c1.clone(), c2.clone()]);
        let filtered = doc.without_singletons();
        assert_eq!(filtered.chain_count(), 2);
        // The surviving chains must be unchanged and in their original order.
        assert_eq!(filtered.chains, vec![c1, c2]);
    }

    #[test]
    fn test_without_singletons_all_singletons() {
        let s1 = CorefChain::singleton(Mention::new("a", 0, 1));
        let s2 = CorefChain::singleton(Mention::new("b", 2, 3));
        let doc = CorefDocument::new("text", vec![s1, s2]);
        let filtered = doc.without_singletons();
        assert!(filtered.chains.is_empty());
    }

    #[test]
    fn test_overlaps_adjacent_non_overlapping() {
        let m1 = Mention::new("hello", 0, 5);
        let m2 = Mention::new("world", 5, 10);
        assert!(!m1.overlaps(&m2));
        assert!(!m2.overlaps(&m1));
    }

    #[test]
    fn test_overlaps_nested() {
        let outer = Mention::new("the big dog", 0, 10);
        let inner = Mention::new("big", 2, 5);
        assert!(outer.overlaps(&inner));
        assert!(inner.overlaps(&outer));
    }

    #[test]
    fn test_chain_with_id() {
        let chain = CorefChain::with_id(
            vec![Mention::new("John", 0, 4), Mention::new("he", 10, 12)],
            42_u64,
        );
        assert_eq!(
            chain.canonical_id(),
            Some(super::super::types::CanonicalId::new(42))
        );
        assert_eq!(
            chain.cluster_id,
            Some(super::super::types::CanonicalId::new(42))
        );
        assert_eq!(chain.mentions[0].text, "John");
    }

    #[test]
    fn test_mention_display() {
        let m = Mention::new("John", 0, 4);
        assert_eq!(m.to_string(), "\"John\" [0-4)");
    }

    #[test]
    fn test_chain_display() {
        let chain = CorefChain::new(vec![Mention::new("he", 10, 12), Mention::new("John", 0, 4)]);
        assert_eq!(chain.to_string(), "[\"John\", \"he\"]");
        assert_eq!(CorefChain::new(vec![]).to_string(), "[]");
    }

    #[test]
    fn test_span_id_head() {
        let plain = Mention::new("the big dog", 0, 11);
        assert_eq!(plain.span_id_head(), (0, 11));
        let headed = Mention::with_head("the big dog", 0, 11, 8, 11);
        assert_eq!(headed.span_id_head(), (8, 11));
        assert_eq!(headed.span_id(), (0, 11));
    }

    #[test]
    fn test_canonical_mention_prefers_proper() {
        let chain = CorefChain::new(vec![
            Mention::with_type("Obama", 0, 5, MentionType::Proper),
            Mention::new("the former president", 10, 30),
        ]);
        assert_eq!(
            chain.canonical_mention().map(|m| m.text.as_str()),
            Some("Obama")
        );
        let untyped = CorefChain::new(vec![
            Mention::new("he", 0, 2),
            Mention::new("the senator", 5, 16),
        ]);
        assert_eq!(
            untyped.canonical_mention().map(|m| m.text.as_str()),
            Some("the senator")
        );
    }

    #[test]
    fn test_find_chain() {
        let chain = CorefChain::new(vec![Mention::new("John", 0, 4), Mention::new("he", 20, 22)]);
        let doc = CorefDocument::new("text", vec![chain]);
        assert!(doc.find_chain(20, 22).is_some());
        assert!(doc.find_chain(0, 3).is_none());
    }
}
// Property-based tests (proptest) for `Mention` and `CorefChain` invariants:
// sorting on construction, length and singleton bookkeeping, symmetry of
// `Mention::overlaps`, and JSON round-tripping of `Mention`.
#[cfg(test)]
mod proptests {
// `unwrap()` is allowed here: a serde failure should fail the test loudly.
#![allow(clippy::unwrap_used)]
use super::*;
use proptest::prelude::*;
// Generates a mention starting anywhere in `0..max_offset` with a length in
// `1..500`, so `start < end` always holds. The text is a label derived from
// `start`, not a slice of any real document.
fn arb_mention(max_offset: usize) -> impl Strategy<Value = Mention> {
(0usize..max_offset, 1usize..500)
.prop_map(|(start, len)| Mention::new(format!("m_{}", start), start, start + len))
}
proptest! {
// `CorefChain::new` must leave mentions sorted by `(start, end)`.
#[test]
fn mention_ordering_after_chain_construction(
mentions in proptest::collection::vec(arb_mention(10000), 1..20),
) {
let chain = CorefChain::new(mentions);
for w in chain.mentions.windows(2) {
prop_assert!(
(w[0].start, w[0].end) <= (w[1].start, w[1].end),
"mentions must be sorted by (start, end): ({},{}) vs ({},{})",
w[0].start, w[0].end, w[1].start, w[1].end
);
}
}
// Construction neither drops nor duplicates mentions.
#[test]
fn coref_chain_non_empty(
mentions in proptest::collection::vec(arb_mention(10000), 1..20),
) {
let n = mentions.len();
let chain = CorefChain::new(mentions);
prop_assert!(!chain.is_empty());
prop_assert_eq!(chain.len(), n);
}
// A singleton chain has one mention and zero links.
#[test]
fn coref_chain_singleton_has_one(start in 0usize..10000, len in 1usize..500) {
let m = Mention::new("x", start, start + len);
let chain = CorefChain::singleton(m);
prop_assert!(chain.is_singleton());
prop_assert_eq!(chain.len(), 1);
prop_assert_eq!(chain.link_count(), 0);
}
// `overlaps` is symmetric for non-empty spans.
#[test]
fn mention_overlap_symmetric(
s1 in 0usize..10000, len1 in 1usize..500,
s2 in 0usize..10000, len2 in 1usize..500,
) {
let m1 = Mention::new("a", s1, s1 + len1);
let m2 = Mention::new("b", s2, s2 + len2);
prop_assert_eq!(m1.overlaps(&m2), m2.overlaps(&m1));
}
// Serializing to JSON and back yields an equal mention.
#[test]
fn mention_serde_roundtrip(
start in 0usize..10000, len in 1usize..500,
) {
let m = Mention::new(format!("mention_{}", start), start, start + len);
let json = serde_json::to_string(&m).unwrap();
let m2: Mention = serde_json::from_str(&json).unwrap();
prop_assert_eq!(&m, &m2);
}
}
}