use crate::eval::cdcr::{CrossDocCluster, Document};
use crate::eval::cluster_encoder::{ClusterEncoder, MergeScorer};
use crate::eval::coref::{CorefChain, Mention};
use crate::eval::coref_metrics::{conll_f1, CorefScores};
use crate::eval::neural_cluster_encoder::{
CrossContextConfig, UnifiedCrossContextResolver, WindowOutput,
};
use crate::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Benchmarks for coreference resolution that spans more context than a
/// single short document: either many documents (cross-document) or one
/// very long document.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum CrossContextBenchmark {
    /// ECB+ cross-document coreference corpus.
    ECBPlus,
    /// SciCo scientific-concept cross-document coreference.
    SciCo,
    /// LitBank literary long-document coreference.
    LitBank,
    /// BookCoref book-length coreference.
    BookCoref,
    /// Animal Farm single-novel coreference evaluation.
    AnimalFarm,
}
impl CrossContextBenchmark {
    /// Display name of the benchmark.
    pub fn name(&self) -> &'static str {
        match self {
            Self::ECBPlus => "ECB+",
            Self::SciCo => "SciCo",
            Self::LitBank => "LitBank",
            Self::BookCoref => "BookCoref",
            Self::AnimalFarm => "Animal Farm",
        }
    }

    /// True for benchmarks that link mentions across separate documents.
    pub fn is_cross_document(&self) -> bool {
        matches!(self, Self::ECBPlus | Self::SciCo)
    }

    /// True for benchmarks over a single long document.
    pub fn is_long_document(&self) -> bool {
        matches!(self, Self::LitBank | Self::BookCoref | Self::AnimalFarm)
    }

    /// Default resolution window size for this benchmark.
    pub fn recommended_window_size(&self) -> usize {
        match self {
            Self::ECBPlus | Self::SciCo => 512,
            Self::LitBank => 2000,
            Self::BookCoref | Self::AnimalFarm => 4000,
        }
    }

    /// Published xCoRe state-of-the-art CoNLL F1 (percent) for comparison.
    pub fn xcore_sota_f1(&self) -> f64 {
        match self {
            Self::ECBPlus => 40.3,
            Self::SciCo => 34.5,
            Self::LitBank => 78.2,
            Self::BookCoref => 65.0,
            Self::AnimalFarm => 70.0,
        }
    }

    /// Every supported benchmark.
    pub fn all() -> &'static [Self] {
        &[
            Self::ECBPlus,
            Self::SciCo,
            Self::LitBank,
            Self::BookCoref,
            Self::AnimalFarm,
        ]
    }

    /// The cross-document subset of [`Self::all`].
    pub fn cross_document() -> &'static [Self] {
        &[Self::ECBPlus, Self::SciCo]
    }

    /// The long-document subset of [`Self::all`].
    pub fn long_document() -> &'static [Self] {
        &[Self::LitBank, Self::BookCoref, Self::AnimalFarm]
    }
}
/// Configuration for a cross-context evaluation run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CrossContextEvalConfig {
    // Size of each resolution window. The single-window fallback in
    // evaluate_long_document measures documents in chars — presumably this
    // is a character count, TODO confirm against the resolver.
    pub window_size: usize,
    // Overlap between consecutive windows, same unit as `window_size`.
    pub window_overlap: usize,
    // Minimum score at which two clusters are merged by the resolver.
    pub merge_threshold: f32,
    // When true, gold mentions are fed to the resolver instead of predicted ones.
    pub use_gold_mentions: bool,
    // When true, gold cluster assignments are provided (oracle clustering).
    pub use_gold_clusters: bool,
    // Cap on documents per topic; 0 appears to mean "no limit" — not enforced
    // in this file, verify where it is consumed.
    pub max_docs_per_topic: usize,
    // RNG seed for reproducibility.
    pub seed: u64,
}
impl Default for CrossContextEvalConfig {
    /// Defaults: large (4000) windows with 256 overlap, 0.5 merge threshold,
    /// fully predicted pipeline (no gold mentions/clusters), no per-topic cap.
    fn default() -> Self {
        Self {
            window_size: 4000,
            window_overlap: 256,
            merge_threshold: 0.5,
            use_gold_mentions: false,
            use_gold_clusters: false,
            max_docs_per_topic: 0,
            seed: 42,
        }
    }
}
impl CrossContextEvalConfig {
    /// Config using the benchmark's recommended window size; everything else
    /// takes the [`Default`] values.
    pub fn for_benchmark(benchmark: CrossContextBenchmark) -> Self {
        Self {
            window_size: benchmark.recommended_window_size(),
            ..Default::default()
        }
    }

    /// Oracle setting: both gold mentions and gold clusters are supplied.
    pub fn oracle() -> Self {
        Self {
            use_gold_mentions: true,
            use_gold_clusters: true,
            ..Default::default()
        }
    }

    /// Gold cluster assignments with predicted mentions.
    // NOTE(review): stepwise_error_analysis maps its "gold mentions" step to
    // this constructor even though `use_gold_mentions` is false here — confirm
    // which flag combination that step is meant to exercise.
    pub fn gold_clusters() -> Self {
        Self {
            use_gold_mentions: false,
            use_gold_clusters: true,
            ..Default::default()
        }
    }
}
/// Aggregate results of one evaluation run (cross-document or long-document).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CrossContextEvalResults {
    // Human-readable label for the run ("Cross-Document" / "Long-Document" / benchmark name).
    pub benchmark: String,
    // The configuration that produced these results.
    pub config: CrossContextEvalConfig,
    // MUC link-based scores.
    pub muc: CorefScores,
    // B³ mention-based scores.
    pub b_cubed: CorefScores,
    // CEAF entity-alignment scores.
    pub ceaf_e: CorefScores,
    // LEA link-based entity-aware scores.
    pub lea: CorefScores,
    // CoNLL F1 (average of MUC/B³/CEAF-e), stored as a fraction in [0, 1].
    pub conll_f1: f64,
    // Number of evaluated contexts: documents (cross-doc) or long documents.
    pub num_contexts: usize,
    pub num_gold_clusters: usize,
    pub num_pred_clusters: usize,
    // Mean mentions per predicted chain (0.0 when nothing was predicted).
    pub avg_cluster_size: f64,
    // Wall-clock time of the run in milliseconds.
    pub time_ms: f64,
    // Per-topic breakdown; populated only by cross-document evaluation.
    pub per_topic: Option<HashMap<String, TopicResults>>,
    // Per-document breakdown; populated only by long-document evaluation.
    pub per_document: Option<HashMap<String, DocumentResults>>,
}
impl CrossContextEvalResults {
    /// One-line human-readable summary of the headline scores, as percentages.
    pub fn summary(&self) -> String {
        // Scores are stored as fractions; render them as percentages.
        let pct = |v: f64| v * 100.0;
        format!(
            "{}: CoNLL F1 = {:.1}% (MUC: {:.1}, B³: {:.1}, CEAF: {:.1})",
            self.benchmark,
            pct(self.conll_f1),
            pct(self.muc.f1),
            pct(self.b_cubed.f1),
            pct(self.ceaf_e.f1),
        )
    }
}
/// Per-topic scores for a cross-document evaluation run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TopicResults {
    pub topic_id: String,
    pub num_documents: usize,
    // CoNLL F1 for this topic only, as a fraction in [0, 1].
    pub conll_f1: f64,
    pub num_gold_clusters: usize,
    pub num_pred_clusters: usize,
}
/// Per-document scores for a long-document evaluation run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DocumentResults {
    pub doc_id: String,
    // Approximate token count (see LongDocument::approx_tokens).
    pub num_tokens: usize,
    // Number of windows the document was resolved over (1 for the fallback).
    pub num_windows: usize,
    // CoNLL F1 for this document only, as a fraction in [0, 1].
    pub conll_f1: f64,
    pub num_gold_chains: usize,
    pub num_pred_chains: usize,
}
/// A topic: a group of related documents plus their gold cross-document clusters.
#[derive(Debug, Clone)]
pub struct Topic {
    pub id: String,
    pub documents: Vec<Document>,
    pub gold_clusters: Vec<CrossDocCluster>,
}
impl Topic {
pub fn new(id: &str) -> Self {
Self {
id: id.to_string(),
documents: Vec::new(),
gold_clusters: Vec::new(),
}
}
pub fn add_document(&mut self, doc: Document) {
self.documents.push(doc);
}
pub fn add_gold_cluster(&mut self, cluster: CrossDocCluster) {
self.gold_clusters.push(cluster);
}
}
/// A single long document with its gold coreference chains and, optionally,
/// precomputed resolution windows.
#[derive(Debug, Clone)]
pub struct LongDocument {
    pub id: String,
    pub text: String,
    pub gold_chains: Vec<CorefChain>,
    // Precomputed windows; None (or empty) triggers the single-window fallback
    // in evaluate_long_document.
    pub windows: Option<Vec<WindowOutput>>,
}
impl LongDocument {
pub fn new(id: &str, text: &str, gold_chains: Vec<CorefChain>) -> Self {
Self {
id: id.to_string(),
text: text.to_string(),
gold_chains,
windows: None,
}
}
pub fn char_len(&self) -> usize {
self.text.chars().count()
}
pub fn approx_tokens(&self) -> usize {
self.text.len() / 5
}
}
/// Evaluates cross-document coreference over topic-grouped documents.
///
/// For each topic, converts the gold cross-document clusters to coref chains,
/// runs the resolver over the topic's documents, converts the predicted
/// clusters the same way, and scores per topic; all chains are then pooled
/// and scored corpus-wide with MUC / B³ / CEAF-e / LEA / CoNLL F1.
///
/// # Errors
/// Returns any error propagated from the resolver pipeline (the body itself
/// only constructs `Ok`).
pub fn evaluate_cross_document<E: ClusterEncoder + Clone, S: MergeScorer + Clone>(
    topics: &[Topic],
    encoder: E,
    scorer: S,
    config: &CrossContextEvalConfig,
) -> Result<CrossContextEvalResults> {
    let start = std::time::Instant::now();
    let resolver_config = CrossContextConfig {
        window_size: config.window_size,
        window_overlap: config.window_overlap,
        merge_threshold: config.merge_threshold,
    };
    let resolver = UnifiedCrossContextResolver::new(encoder, scorer, resolver_config);
    let mut all_gold_chains: Vec<CorefChain> = Vec::new();
    let mut all_pred_chains: Vec<CorefChain> = Vec::new();
    let mut per_topic = HashMap::new();
    let mut total_gold_clusters = 0;
    let mut total_pred_clusters = 0;
    for topic in topics {
        // Gold and predicted clusters are both mapped through the same
        // cluster->chain conversion so they are scored on equal footing.
        let topic_gold_chains =
            cross_doc_clusters_to_chains(&topic.gold_clusters, &topic.documents);
        total_gold_clusters += topic.gold_clusters.len();
        let pred_clusters = resolver.resolve_documents(&topic.documents);
        total_pred_clusters += pred_clusters.len();
        let topic_pred_chains = cross_doc_clusters_to_chains(&pred_clusters, &topic.documents);
        let topic_f1 = conll_f1(&topic_gold_chains, &topic_pred_chains);
        per_topic.insert(
            topic.id.clone(),
            TopicResults {
                topic_id: topic.id.clone(),
                num_documents: topic.documents.len(),
                conll_f1: topic_f1,
                num_gold_clusters: topic.gold_clusters.len(),
                num_pred_clusters: pred_clusters.len(),
            },
        );
        // Pool chains for corpus-level (micro) scoring below.
        all_gold_chains.extend(topic_gold_chains);
        all_pred_chains.extend(topic_pred_chains);
    }
    // NOTE(review): the *_score helpers are called as (pred, gold) while
    // conll_f1 is called as (gold, pred) — assumed to match each function's
    // signature; verify against coref_metrics.
    let (muc_p, muc_r, muc_f1) =
        crate::eval::coref_metrics::muc_score(&all_pred_chains, &all_gold_chains);
    let (b3_p, b3_r, b3_f1) =
        crate::eval::coref_metrics::b_cubed_score(&all_pred_chains, &all_gold_chains);
    let (ceaf_p, ceaf_r, ceaf_f1) =
        crate::eval::coref_metrics::ceaf_e_score(&all_pred_chains, &all_gold_chains);
    let (lea_p, lea_r, lea_f1) =
        crate::eval::coref_metrics::lea_score(&all_pred_chains, &all_gold_chains);
    let conll = conll_f1(&all_gold_chains, &all_pred_chains);
    let num_contexts: usize = topics.iter().map(|t| t.documents.len()).sum();
    let total_mentions: usize = all_pred_chains.iter().map(|c| c.len()).sum();
    // Guard against division by zero when nothing was predicted.
    let avg_cluster_size = if !all_pred_chains.is_empty() {
        total_mentions as f64 / all_pred_chains.len() as f64
    } else {
        0.0
    };
    Ok(CrossContextEvalResults {
        benchmark: "Cross-Document".to_string(),
        config: config.clone(),
        muc: CorefScores::from_tuple((muc_p, muc_r, muc_f1)),
        b_cubed: CorefScores::from_tuple((b3_p, b3_r, b3_f1)),
        ceaf_e: CorefScores::from_tuple((ceaf_p, ceaf_r, ceaf_f1)),
        lea: CorefScores::from_tuple((lea_p, lea_r, lea_f1)),
        conll_f1: conll,
        num_contexts,
        num_gold_clusters: total_gold_clusters,
        num_pred_clusters: total_pred_clusters,
        avg_cluster_size,
        time_ms: start.elapsed().as_millis() as f64,
        per_topic: Some(per_topic),
        per_document: None,
    })
}
/// Evaluates long-document coreference over a set of documents.
///
/// Each document is resolved over its precomputed windows; a document with no
/// windows falls back to a single window spanning the whole text (seeded with
/// the gold chains when `config.use_gold_mentions` is set). Predicted chains
/// are scored against gold chains per document and pooled corpus-wide.
///
/// # Errors
/// Returns any error propagated from the resolver pipeline (the body itself
/// only constructs `Ok`).
pub fn evaluate_long_document<E: ClusterEncoder + Clone, S: MergeScorer + Clone>(
    documents: &[LongDocument],
    encoder: E,
    scorer: S,
    config: &CrossContextEvalConfig,
) -> Result<CrossContextEvalResults> {
    let start = std::time::Instant::now();
    let resolver_config = CrossContextConfig {
        window_size: config.window_size,
        window_overlap: config.window_overlap,
        merge_threshold: config.merge_threshold,
    };
    let resolver = UnifiedCrossContextResolver::new(encoder, scorer, resolver_config);
    let mut all_gold_chains: Vec<CorefChain> = Vec::new();
    let mut all_pred_chains: Vec<CorefChain> = Vec::new();
    let mut per_document = HashMap::new();
    for doc in documents {
        // Normalize to a non-empty window list so the scoring/recording logic
        // below is shared between the precomputed and fallback paths.
        // (The original duplicated ~30 lines for the empty-windows case.)
        let windows = match doc.windows.clone() {
            Some(w) if !w.is_empty() => w,
            _ => vec![WindowOutput::new(
                0,
                0,
                doc.char_len(),
                if config.use_gold_mentions {
                    doc.gold_chains.clone()
                } else {
                    Vec::new()
                },
            )],
        };
        let pred_chains = resolver.resolve_long_document_windows(&windows);
        let doc_f1 = conll_f1(&doc.gold_chains, &pred_chains);
        per_document.insert(
            doc.id.clone(),
            DocumentResults {
                doc_id: doc.id.clone(),
                num_tokens: doc.approx_tokens(),
                num_windows: windows.len(),
                conll_f1: doc_f1,
                num_gold_chains: doc.gold_chains.len(),
                num_pred_chains: pred_chains.len(),
            },
        );
        // Pool chains for corpus-level (micro) scoring below.
        all_gold_chains.extend(doc.gold_chains.clone());
        all_pred_chains.extend(pred_chains);
    }
    // NOTE(review): the *_score helpers are called as (pred, gold) while
    // conll_f1 is called as (gold, pred) — assumed to match each function's
    // signature; verify against coref_metrics.
    let (muc_p, muc_r, muc_f1) =
        crate::eval::coref_metrics::muc_score(&all_pred_chains, &all_gold_chains);
    let (b3_p, b3_r, b3_f1) =
        crate::eval::coref_metrics::b_cubed_score(&all_pred_chains, &all_gold_chains);
    let (ceaf_p, ceaf_r, ceaf_f1) =
        crate::eval::coref_metrics::ceaf_e_score(&all_pred_chains, &all_gold_chains);
    let (lea_p, lea_r, lea_f1) =
        crate::eval::coref_metrics::lea_score(&all_pred_chains, &all_gold_chains);
    let conll = conll_f1(&all_gold_chains, &all_pred_chains);
    let total_mentions: usize = all_pred_chains.iter().map(|c| c.len()).sum();
    // Guard against division by zero when nothing was predicted.
    let avg_cluster_size = if !all_pred_chains.is_empty() {
        total_mentions as f64 / all_pred_chains.len() as f64
    } else {
        0.0
    };
    Ok(CrossContextEvalResults {
        benchmark: "Long-Document".to_string(),
        config: config.clone(),
        muc: CorefScores::from_tuple((muc_p, muc_r, muc_f1)),
        b_cubed: CorefScores::from_tuple((b3_p, b3_r, b3_f1)),
        ceaf_e: CorefScores::from_tuple((ceaf_p, ceaf_r, ceaf_f1)),
        lea: CorefScores::from_tuple((lea_p, lea_r, lea_f1)),
        conll_f1: conll,
        num_contexts: documents.len(),
        num_gold_clusters: all_gold_chains.len(),
        num_pred_clusters: all_pred_chains.len(),
        avg_cluster_size,
        time_ms: start.elapsed().as_millis() as f64,
        per_topic: None,
        per_document: Some(per_document),
    })
}
fn cross_doc_clusters_to_chains(
clusters: &[CrossDocCluster],
docs: &[Document],
) -> Vec<CorefChain> {
clusters
.iter()
.map(|cluster| {
let mentions: Vec<Mention> = cluster
.mentions
.iter()
.filter_map(|(doc_id, entity_idx)| {
let doc = docs.iter().find(|d| &d.id == doc_id)?;
let entity = doc.entities.get(*entity_idx)?;
Some(Mention {
text: entity.text.clone(),
start: entity.start(),
end: entity.end(),
head_start: None,
head_end: None,
entity_type: Some(entity.entity_type.as_label().to_string()),
mention_type: None,
})
})
.collect();
CorefChain::new(mentions)
})
.filter(|c| !c.is_empty())
.collect()
}
/// Ablation steps for isolating where pipeline error comes from:
/// full pipeline, gold mentions, or gold mentions plus gold clusters.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum StepwiseAnalysis {
    /// End-to-end predicted pipeline.
    FullPipeline,
    /// Gold mentions provided; clustering predicted.
    GoldMentions,
    /// Gold mentions and gold clusters provided (oracle upper bound).
    GoldMentionsAndClusters,
}
impl StepwiseAnalysis {
pub fn description(&self) -> &'static str {
match self {
Self::FullPipeline => "xCoRe (full pipeline)",
Self::GoldMentions => "xCoRe (gold mentions)",
Self::GoldMentionsAndClusters => "xCoRe (gold mentions & clusters)",
}
}
}
/// Runs the three stepwise ablations for a benchmark and returns one result
/// per analysis step.
///
/// Cross-document benchmarks are evaluated over `topics`; long-document
/// benchmarks over `documents` — the other slice is ignored for that branch.
///
/// # Errors
/// Propagates any error from the underlying evaluation functions.
pub fn stepwise_error_analysis<E: ClusterEncoder + Clone, S: MergeScorer + Clone>(
    benchmark: CrossContextBenchmark,
    topics: &[Topic],
    documents: &[LongDocument],
    encoder: E,
    scorer: S,
) -> Result<HashMap<StepwiseAnalysis, CrossContextEvalResults>> {
    let mut results = HashMap::new();
    for analysis in [
        StepwiseAnalysis::FullPipeline,
        StepwiseAnalysis::GoldMentions,
        StepwiseAnalysis::GoldMentionsAndClusters,
    ] {
        let config = match analysis {
            StepwiseAnalysis::FullPipeline => CrossContextEvalConfig::for_benchmark(benchmark),
            // NOTE(review): the "gold mentions" step uses gold_clusters(),
            // which sets use_gold_mentions = false and use_gold_clusters =
            // true — this looks inverted relative to the step's name; confirm
            // the intended flag combination.
            StepwiseAnalysis::GoldMentions => CrossContextEvalConfig::gold_clusters(),
            StepwiseAnalysis::GoldMentionsAndClusters => CrossContextEvalConfig::oracle(),
        };
        let eval_result = if benchmark.is_cross_document() {
            evaluate_cross_document(topics, encoder.clone(), scorer.clone(), &config)?
        } else {
            evaluate_long_document(documents, encoder.clone(), scorer.clone(), &config)?
        };
        results.insert(analysis, eval_result);
    }
    Ok(results)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::eval::cluster_encoder::{CosineMergeScorer, HeuristicClusterEncoder};
    use anno::{Entity, EntityType};

    // Cross-document / long-document classification partitions the benchmarks.
    #[test]
    fn test_benchmark_properties() {
        assert!(CrossContextBenchmark::ECBPlus.is_cross_document());
        assert!(!CrossContextBenchmark::ECBPlus.is_long_document());
        assert!(!CrossContextBenchmark::LitBank.is_cross_document());
        assert!(CrossContextBenchmark::LitBank.is_long_document());
        assert_eq!(CrossContextBenchmark::all().len(), 5);
        assert_eq!(CrossContextBenchmark::cross_document().len(), 2);
        assert_eq!(CrossContextBenchmark::long_document().len(), 3);
    }

    // Pins the published SOTA reference numbers.
    #[test]
    fn test_benchmark_sota() {
        assert!((CrossContextBenchmark::ECBPlus.xcore_sota_f1() - 40.3).abs() < 0.1);
        assert!((CrossContextBenchmark::LitBank.xcore_sota_f1() - 78.2).abs() < 0.1);
    }

    // Default config uses 4000 windows and the fully predicted pipeline.
    #[test]
    fn test_eval_config_default() {
        let config = CrossContextEvalConfig::default();
        assert_eq!(config.window_size, 4000);
        assert!(!config.use_gold_mentions);
    }

    // for_benchmark picks up the benchmark's recommended window size.
    #[test]
    fn test_eval_config_for_benchmark() {
        let config = CrossContextEvalConfig::for_benchmark(CrossContextBenchmark::ECBPlus);
        assert_eq!(config.window_size, 512);
    }

    #[test]
    fn test_topic_creation() {
        let mut topic = Topic::new("topic_1");
        topic.add_document(Document::new("doc1", "Obama visited Paris."));
        topic.add_document(Document::new("doc2", "The president met leaders."));
        assert_eq!(topic.documents.len(), 2);
    }

    #[test]
    fn test_long_document_creation() {
        use anno::MentionType;
        // Local helper: a minimal proper-noun mention with no head span.
        fn new_mention(text: &str, start: usize, end: usize) -> Mention {
            Mention {
                text: text.to_string(),
                start,
                end,
                head_start: None,
                head_end: None,
                entity_type: None,
                mention_type: Some(MentionType::Proper),
            }
        }
        let chains = vec![CorefChain::new(vec![
            new_mention("Obama", 0, 5),
            new_mention("he", 100, 102),
        ])];
        let doc = LongDocument::new(
            "book1",
            "Obama went to Paris. ".repeat(100).as_str(),
            chains,
        );
        // 2100 bytes of text / 5 ≈ 420 approximate tokens.
        assert!(doc.approx_tokens() > 100);
        assert_eq!(doc.gold_chains.len(), 1);
    }

    // An empty topic list must evaluate cleanly to zero contexts.
    #[test]
    fn test_evaluate_cross_document_empty() {
        let encoder = HeuristicClusterEncoder::new(64);
        let scorer = CosineMergeScorer::new();
        let config = CrossContextEvalConfig::default();
        let topics: Vec<Topic> = vec![];
        let results = evaluate_cross_document(&topics, encoder, scorer, &config).unwrap();
        assert_eq!(results.num_contexts, 0);
    }

    #[test]
    fn test_evaluate_cross_document_single_topic() {
        let encoder = HeuristicClusterEncoder::new(64);
        let scorer = CosineMergeScorer::new();
        let config = CrossContextEvalConfig::default();
        let mut topic = Topic::new("topic_1");
        topic.add_document(
            Document::new("doc1", "Obama visited France.").with_entities(vec![Entity::new(
                "Obama",
                EntityType::Person,
                0,
                5,
                0.9,
            )]),
        );
        topic.add_document(
            Document::new("doc2", "The president met Macron.").with_entities(vec![
                Entity::new("The president", EntityType::Person, 0, 13, 0.8),
                Entity::new("Macron", EntityType::Person, 18, 24, 0.9),
            ]),
        );
        let results = evaluate_cross_document(&[topic], encoder, scorer, &config).unwrap();
        // num_contexts counts documents across all topics.
        assert_eq!(results.num_contexts, 2);
        assert!(results.per_topic.is_some());
    }

    #[test]
    fn test_stepwise_analysis_types() {
        assert_eq!(
            StepwiseAnalysis::FullPipeline.description(),
            "xCoRe (full pipeline)"
        );
        assert_eq!(
            StepwiseAnalysis::GoldMentions.description(),
            "xCoRe (gold mentions)"
        );
    }

    // summary() renders the stored fractional scores as percentages.
    #[test]
    fn test_results_summary() {
        let results = CrossContextEvalResults {
            benchmark: "Test".to_string(),
            config: CrossContextEvalConfig::default(),
            muc: CorefScores::new(0.8, 0.7),
            b_cubed: CorefScores::new(0.75, 0.65),
            ceaf_e: CorefScores::new(0.7, 0.6),
            lea: CorefScores::new(0.72, 0.62),
            conll_f1: 0.70,
            num_contexts: 10,
            num_gold_clusters: 50,
            num_pred_clusters: 45,
            avg_cluster_size: 2.5,
            time_ms: 100.0,
            per_topic: None,
            per_document: None,
        };
        let summary = results.summary();
        assert!(summary.contains("70.0%"));
    }

    // End-to-end sanity check over two topics with gold clusters: scores are
    // valid fractions and the per-topic breakdown is populated.
    #[test]
    fn test_evaluate_cross_document_with_synthetic_data() {
        let encoder = HeuristicClusterEncoder::new(64);
        let scorer = CosineMergeScorer::new();
        let config = CrossContextEvalConfig::default();
        let mut topic1 = Topic::new("politics");
        topic1.add_document(
            Document::new("doc1", "Obama visited France yesterday.").with_entities(vec![
                Entity::new("Obama", EntityType::Person, 0, 5, 0.9),
                Entity::new("France", EntityType::Location, 14, 20, 0.9),
            ]),
        );
        topic1.add_document(
            Document::new("doc2", "The president arrived in Paris.").with_entities(vec![
                Entity::new("The president", EntityType::Person, 0, 13, 0.8),
                Entity::new("Paris", EntityType::Location, 25, 30, 0.9),
            ]),
        );
        topic1.add_document(
            Document::new("doc3", "Barack Obama met Macron in France.").with_entities(vec![
                Entity::new("Barack Obama", EntityType::Person, 0, 12, 0.95),
                Entity::new("Macron", EntityType::Person, 17, 23, 0.9),
                Entity::new("France", EntityType::Location, 27, 33, 0.9),
            ]),
        );
        // Gold: Obama across docs 1-3 (entity index 0 in each).
        let mut obama_cluster = crate::eval::cdcr::CrossDocCluster::new(0u64, "Obama");
        obama_cluster.mentions = vec![
            ("doc1".to_string(), 0),
            ("doc2".to_string(), 0),
            ("doc3".to_string(), 0),
        ];
        // Gold: France in docs 1, 2 (as Paris slot 1) and 3 (index 2).
        let mut france_cluster = crate::eval::cdcr::CrossDocCluster::new(1u64, "France");
        france_cluster.mentions = vec![
            ("doc1".to_string(), 1),
            ("doc2".to_string(), 1),
            ("doc3".to_string(), 2),
        ];
        topic1.add_gold_cluster(obama_cluster);
        topic1.add_gold_cluster(france_cluster);
        let mut topic2 = Topic::new("tech");
        topic2.add_document(
            Document::new("doc4", "Apple released new products.").with_entities(vec![Entity::new(
                "Apple",
                EntityType::Organization,
                0,
                5,
                0.9,
            )]),
        );
        topic2.add_document(
            Document::new("doc5", "The company expanded in Asia.").with_entities(vec![
                Entity::new("The company", EntityType::Organization, 0, 11, 0.8),
                Entity::new("Asia", EntityType::Location, 24, 28, 0.9),
            ]),
        );
        topic2.add_document(
            Document::new("doc6", "Apple Inc announced quarterly results.").with_entities(vec![
                Entity::new("Apple Inc", EntityType::Organization, 0, 9, 0.9),
            ]),
        );
        let mut apple_cluster = crate::eval::cdcr::CrossDocCluster::new(0u64, "Apple");
        apple_cluster.mentions = vec![
            ("doc4".to_string(), 0),
            ("doc5".to_string(), 0),
            ("doc6".to_string(), 0),
        ];
        topic2.add_gold_cluster(apple_cluster);
        let results = evaluate_cross_document(&[topic1, topic2], encoder, scorer, &config).unwrap();
        // All scores must be valid fractions.
        assert!(results.conll_f1 >= 0.0 && results.conll_f1 <= 1.0);
        assert!(results.muc.f1 >= 0.0 && results.muc.f1 <= 1.0);
        assert!(results.b_cubed.f1 >= 0.0 && results.b_cubed.f1 <= 1.0);
        assert!(results.ceaf_e.f1 >= 0.0 && results.ceaf_e.f1 <= 1.0);
        assert_eq!(results.num_contexts, 6);
        assert!(results.per_topic.is_some());
        let per_topic = results.per_topic.as_ref().unwrap();
        assert_eq!(per_topic.len(), 2);
        assert!(per_topic.contains_key("politics"));
        assert!(per_topic.contains_key("tech"));
        assert_eq!(results.num_gold_clusters, 3);
        assert!(results.num_pred_clusters > 0);
    }

    // Long-document path with no precomputed windows: the single-window
    // fallback runs and the per-document breakdown is populated.
    #[test]
    fn test_evaluate_long_document_with_gold_mentions() {
        let encoder = HeuristicClusterEncoder::new(64);
        let scorer = CosineMergeScorer::new();
        let config = CrossContextEvalConfig {
            use_gold_mentions: true,
            ..CrossContextEvalConfig::default()
        };
        let chains = vec![
            CorefChain::new(vec![
                Mention::new("Obama", 0, 5),
                Mention::new("he", 50, 52),
            ]),
            CorefChain::new(vec![
                Mention::new("France", 14, 20),
                Mention::new("the country", 60, 71),
            ]),
        ];
        let doc = LongDocument::new("long_doc", &"Obama visited France. ".repeat(10), chains);
        let results = evaluate_long_document(&[doc], encoder, scorer, &config).unwrap();
        assert_eq!(results.num_contexts, 1);
        assert!(results.per_document.is_some());
        let per_doc = results.per_document.as_ref().unwrap();
        assert_eq!(per_doc.len(), 1);
        assert!(per_doc.contains_key("long_doc"));
    }
}