use anno::{Identity, IdentityId, TrackRef};
use std::collections::{HashMap, HashSet};
/// Cluster-level evaluation scores for cross-document coreference.
///
/// Values are produced by `InterDocCorefMetrics::compute`; the `Default`
/// implementation yields a value with every field set to zero.
#[derive(Debug, Clone)]
pub struct InterDocCorefMetrics {
/// Sum, over predicted cross-document identities, of the largest overlap
/// with a single gold cluster, divided by `num_total`.
pub cluster_purity: f64,
/// Sum, over gold clusters, of the largest overlap with a single predicted
/// identity, divided by `num_total`.
pub cluster_completeness: f64,
/// `compute` sets this to `predicted.len()` (every supplied identity,
/// whatever its source) whenever at least one track reference is present.
pub num_pred_identities: usize,
/// `compute` sets this to `gold.len()` whenever at least one track
/// reference is present.
pub num_gold_identities: usize,
/// Sum, over predicted cross-document identities, of the number of tracks
/// falling into that identity's best-matching gold cluster.
pub num_correct: usize,
/// Number of distinct track references found in predicted cross-document
/// identities or in gold clusters; the denominator of both scores.
pub num_total: usize,
}
impl InterDocCorefMetrics {
    /// Scores predicted cross-document identities against gold clusters.
    ///
    /// Only identities whose `source` is `IdentitySource::CrossDocCoref`
    /// contribute tracks. Identities with any other source (or none) are
    /// ignored for scoring but are still counted in `num_pred_identities`.
    ///
    /// # Arguments
    ///
    /// * `predicted` - identities produced by the system under evaluation.
    /// * `gold` - reference clusters, each a list of track references.
    ///
    /// # Returns
    ///
    /// * `num_total`: number of distinct track references appearing in a
    ///   predicted cross-document identity or in a gold cluster.
    /// * `num_correct`: for every predicted identity, the size of its largest
    ///   overlap with a single gold cluster, summed over all identities.
    /// * `cluster_purity`: `num_correct / num_total`.
    /// * `cluster_completeness`: the mirror-image sum over gold clusters
    ///   (largest overlap with a single predicted identity) over `num_total`.
    /// * `num_pred_identities` / `num_gold_identities`: the lengths of the two
    ///   input slices.
    ///
    /// When no track reference is present at all, both scores are `0.0` and
    /// `num_correct` / `num_total` are `0`; the two identity counts still
    /// report the input lengths.
    ///
    /// # Notes
    ///
    /// If the same track reference occurs more than once (inside one cluster
    /// or across clusters of the same side), the lookup tables keep only its
    /// last occurrence while the per-cluster tallies count every occurrence.
    /// The scores are therefore only guaranteed to lie in `[0, 1]` when the
    /// clusters on each side are disjoint and free of duplicates.
    #[must_use]
    pub fn compute(predicted: &[Identity], gold: &[Vec<TrackRef>]) -> Self {
        // Track -> predicted identity. Keys borrow from `predicted`, so no
        // `TrackRef` has to be cloned.
        let mut pred_map: HashMap<&TrackRef, IdentityId> = HashMap::new();
        for identity in predicted {
            if let Some(anno::IdentitySource::CrossDocCoref { track_refs }) = &identity.source {
                for track_ref in track_refs {
                    pred_map.insert(track_ref, identity.id);
                }
            }
        }

        // Track -> index of its gold cluster.
        let mut gold_map: HashMap<&TrackRef, usize> = HashMap::new();
        for (idx, cluster) in gold.iter().enumerate() {
            for track_ref in cluster {
                gold_map.insert(track_ref, idx);
            }
        }

        // Every distinct track seen on either side.
        let all_tracks: HashSet<&TrackRef> =
            pred_map.keys().chain(gold_map.keys()).copied().collect();
        let num_total = all_tracks.len();
        if num_total == 0 {
            // Nothing to score, but keep the counts truthful about the input.
            return Self {
                num_pred_identities: predicted.len(),
                num_gold_identities: gold.len(),
                ..Self::default()
            };
        }

        // Purity numerator: per predicted identity, the number of its tracks
        // that land in its best-matching gold cluster. An identity without
        // tracks (or without any gold-known track) contributes 0.
        let mut num_correct = 0_usize;
        for identity in predicted {
            if let Some(anno::IdentitySource::CrossDocCoref { track_refs }) = &identity.source {
                let mut gold_cluster_counts: HashMap<usize, usize> = HashMap::new();
                for track_ref in track_refs {
                    if let Some(&gold_cluster) = gold_map.get(track_ref) {
                        *gold_cluster_counts.entry(gold_cluster).or_default() += 1;
                    }
                }
                num_correct += gold_cluster_counts.values().copied().max().unwrap_or(0);
            }
        }

        // Completeness numerator: per gold cluster, the number of its tracks
        // that were grouped under its best-matching predicted identity.
        let mut num_recovered = 0_usize;
        for cluster in gold {
            let mut pred_identity_counts: HashMap<IdentityId, usize> = HashMap::new();
            for track_ref in cluster {
                if let Some(&pred_identity) = pred_map.get(track_ref) {
                    *pred_identity_counts.entry(pred_identity).or_default() += 1;
                }
            }
            num_recovered += pred_identity_counts.values().copied().max().unwrap_or(0);
        }

        // `num_total` is non-zero here, so both divisions are well defined.
        // Dividing the exact integer tallies avoids the rounding round-trip of
        // computing `(count / len) * len` in floating point.
        let denominator = num_total as f64;
        Self {
            cluster_purity: num_correct as f64 / denominator,
            cluster_completeness: num_recovered as f64 / denominator,
            num_pred_identities: predicted.len(),
            num_gold_identities: gold.len(),
            num_correct,
            num_total,
        }
    }

    /// Harmonic mean of `cluster_purity` and `cluster_completeness`.
    ///
    /// Returns `0.0` when the two scores sum to zero, which avoids a division
    /// by zero.
    #[must_use]
    pub fn f1(&self) -> f64 {
        let sum = self.cluster_purity + self.cluster_completeness;
        if sum == 0.0 {
            0.0
        } else {
            2.0 * self.cluster_purity * self.cluster_completeness / sum
        }
    }
}
impl Default for InterDocCorefMetrics {
    /// Builds the all-zero metrics value: both scores are `0.0` and every
    /// counter is `0`.
    fn default() -> Self {
        let zero_score = 0.0_f64;
        let zero_count = 0_usize;
        Self {
            num_total: zero_count,
            num_correct: zero_count,
            num_gold_identities: zero_count,
            num_pred_identities: zero_count,
            cluster_completeness: zero_score,
            cluster_purity: zero_score,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use anno::coalesce::Resolver;
    use anno::{GroundedDocument, Location, Signal, Track, TrackId};

    /// Builds a three-document corpus mentioning "Apple" and "Microsoft",
    /// runs cross-document resolution on it, and returns the corpus together
    /// with the gold clustering (one cluster per company).
    ///
    /// Signal spans follow the convention used for doc1 (`"Apple"` is
    /// `0..5`, `"Microsoft"` is `10..19`): start offset inclusive, end offset
    /// exclusive, so each span must equal the length of the text it covers.
    fn create_test_corpus() -> (anno::Corpus, Vec<Vec<TrackRef>>) {
        let mut corpus = anno::Corpus::new();

        // doc1 mentions both companies.
        let mut doc1 = GroundedDocument::new("doc1", "Apple and Microsoft");
        let s1 = doc1.add_signal(Signal::new(0, Location::text(0, 5), "Apple", "Org", 0.9));
        let s2 = doc1.add_signal(Signal::new(
            1,
            Location::text(10, 19),
            "Microsoft",
            "Org",
            0.9,
        ));
        let mut track1 = Track::new(0, "Apple");
        track1.add_signal(s1, 0);
        let mut track2 = Track::new(1, "Microsoft");
        track2.add_signal(s2, 0);
        doc1.add_track(track1);
        doc1.add_track(track2);
        corpus.add_document(doc1);

        // doc2 mentions Apple only. "Apple Inc" is 9 characters long, so the
        // span ends at 9 (the previous end of 10 ran past the text).
        let mut doc2 = GroundedDocument::new("doc2", "Apple Inc");
        let s3 = doc2.add_signal(Signal::new(
            0,
            Location::text(0, 9),
            "Apple Inc",
            "Org",
            0.9,
        ));
        let mut track3 = Track::new(0, "Apple Inc");
        track3.add_signal(s3, 0);
        doc2.add_track(track3);
        corpus.add_document(doc2);

        // doc3 mentions Microsoft only. "Microsoft Corp" is 14 characters
        // long, so the span ends at 14 (the previous end of 13 cut off the
        // final character).
        let mut doc3 = GroundedDocument::new("doc3", "Microsoft Corp");
        let s4 = doc3.add_signal(Signal::new(
            0,
            Location::text(0, 14),
            "Microsoft Corp",
            "Org",
            0.9,
        ));
        let mut track4 = Track::new(0, "Microsoft Corp");
        track4.add_signal(s4, 0);
        doc3.add_track(track4);
        corpus.add_document(doc3);

        // Resolution stores the identities on the corpus; the returned ids
        // are not needed because the test reads them back from the corpus.
        let resolver = Resolver::new().with_threshold(0.3).require_type_match(true);
        let _identity_ids = resolver.resolve_inter_doc_coref(&mut corpus, None, None);

        let gold = vec![
            // Apple: doc1 track 0 + doc2 track 0.
            vec![
                TrackRef {
                    doc_id: "doc1".to_string(),
                    track_id: TrackId::new(0),
                },
                TrackRef {
                    doc_id: "doc2".to_string(),
                    track_id: TrackId::new(0),
                },
            ],
            // Microsoft: doc1 track 1 + doc3 track 0.
            vec![
                TrackRef {
                    doc_id: "doc1".to_string(),
                    track_id: TrackId::new(1),
                },
                TrackRef {
                    doc_id: "doc3".to_string(),
                    track_id: TrackId::new(0),
                },
            ],
        ];
        (corpus, gold)
    }

    /// Scores must stay within `[0, 1]` and the counts must reflect the
    /// inputs, whatever clustering the resolver produced.
    #[test]
    fn test_inter_doc_coref_metrics_basic() {
        let (corpus, gold) = create_test_corpus();
        let predicted: Vec<_> = corpus
            .identities()
            .values()
            .filter(|identity| {
                matches!(
                    identity.source,
                    Some(anno::IdentitySource::CrossDocCoref { .. })
                )
            })
            .cloned()
            .collect();

        let metrics = InterDocCorefMetrics::compute(&predicted, &gold);

        assert!(metrics.cluster_purity >= 0.0 && metrics.cluster_purity <= 1.0);
        assert!(metrics.cluster_completeness >= 0.0 && metrics.cluster_completeness <= 1.0);
        assert!(metrics.f1() >= 0.0 && metrics.f1() <= 1.0);
        assert_eq!(metrics.num_pred_identities, predicted.len());
        assert_eq!(metrics.num_gold_identities, 2);
        // The gold clustering alone contributes four distinct tracks.
        assert!(metrics.num_total >= 4);
    }

    /// With no predictions and no gold clusters every score is zero.
    #[test]
    fn test_inter_doc_coref_metrics_empty() {
        let metrics = InterDocCorefMetrics::compute(&[], &[]);
        assert_eq!(metrics.cluster_purity, 0.0);
        assert_eq!(metrics.cluster_completeness, 0.0);
        assert_eq!(metrics.f1(), 0.0);
        assert_eq!(metrics.num_total, 0);
    }
}