use super::coref::{CorefChain, Mention};
use super::coref_metrics::{b_cubed_score, ceaf_e_score, ceaf_m_score, lea_score, muc_score};
use super::types::{CorefDocStats, DocumentScale, MetricDivergence};
use serde::{Deserialize, Serialize};
/// Precision / recall / F1 triple for a single coreference metric.
///
/// All values are fractions in `[0, 1]`; the report layer multiplies by 100
/// for display.
#[derive(Debug, Clone, Copy, Default, PartialEq, Serialize, Deserialize)]
pub struct Scores {
    /// Fraction of predicted links/entities that are correct.
    pub precision: f64,
    /// Fraction of gold links/entities that were recovered.
    pub recall: f64,
    /// Harmonic mean of precision and recall.
    pub f1: f64,
}
impl Scores {
    /// Builds a `Scores` from the `(precision, recall, f1)` tuple shape
    /// returned by the metric functions in `coref_metrics`.
    pub fn from_tuple(scores: (f64, f64, f64)) -> Self {
        let (precision, recall, f1) = scores;
        Self {
            precision,
            recall,
            f1,
        }
    }
}
/// Tunable parameters for book-scale coreference analysis.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BookScaleConfig {
    /// Size of each sliding-evaluation window (same units as mention offsets).
    pub window_size: usize,
    /// Overlap between consecutive windows; step = window_size - window_overlap.
    pub window_overlap: usize,
    /// Chains with more mentions than this are "long".
    pub long_chain_threshold: usize,
    /// Chains with at least this many mentions are "short"; below is singleton.
    pub short_chain_threshold: usize,
    /// |MUC F1 - CEAF-e F1| above this flags high metric divergence.
    pub divergence_threshold: f64,
    /// Windowed-minus-full CoNLL F1 gap above this flags a performance drop.
    pub performance_drop_threshold: f64,
}
impl Default for BookScaleConfig {
fn default() -> Self {
Self {
window_size: 1500,
window_overlap: 200,
long_chain_threshold: 10, short_chain_threshold: 2, divergence_threshold: 0.30, performance_drop_threshold: 0.15, }
}
}
/// Full metric panel for one predicted-vs-gold comparison.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CorefEvalScores {
    /// Link-based MUC scores.
    pub muc: Scores,
    /// Mention-based B³ scores.
    pub b_cubed: Scores,
    /// Entity-based CEAF (CEAF-e) scores.
    pub ceaf_e: Scores,
    /// Mention-based CEAF (CEAF-m) scores.
    pub ceaf_m: Scores,
    /// Link-based Entity-Aware (LEA) scores.
    pub lea: Scores,
    /// Conventional CoNLL F1: mean of MUC, B³ and CEAF-e F1.
    pub conll_f1: f64,
}
/// Complete result of one book-scale analysis run
/// (produced by `BookScaleAnalyzer::analyze`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BookScaleAnalysis {
    /// Metrics computed over the entire document at once.
    pub full_doc_eval: CorefEvalScores,
    /// Sliding-window evaluation; `None` when the document is too short
    /// (not longer than two windows).
    pub windowed_eval: Option<WindowedEvaluation>,
    /// Scores broken down by gold chain length.
    pub stratified: StratifiedEvaluation,
    /// Document statistics derived from the gold chains plus caller-supplied length.
    pub doc_stats: CorefDocStats,
    /// Per-metric reliability assessment at this document scale.
    pub reliability: MetricReliability,
    /// Coarse document-length class (Short / Medium / Long / BookScale).
    pub scale: DocumentScale,
    /// Detected scale issues plus remediation recommendations.
    pub diagnostics: BookScaleDiagnostics,
}
impl BookScaleAnalysis {
    /// True when any book-scale diagnostic flag was raised.
    pub fn has_scale_issues(&self) -> bool {
        self.diagnostics.has_issues()
    }

    /// Renders the analysis as a human-readable multi-section report:
    /// document stats, full-document metrics, an optional windowed-vs-full
    /// comparison, chain-length stratified scores, per-metric reliability
    /// notes, and any detected issues with recommendations.
    pub fn diagnostic_report(&self) -> String {
        // `write!` appends formatted text in place, avoiding the throwaway
        // `String` that `push_str(&format!(...))` allocates for every line.
        // Writing into a `String` is infallible, so the Results are ignored.
        use std::fmt::Write;
        let mut report = String::new();
        report.push_str("=== Book-Scale Coreference Analysis ===\n\n");
        let _ = write!(report, "Document Scale: {}\n", self.scale);
        let _ = write!(
            report,
            "Document Length: {} chars ({} mentions in {} chains)\n\n",
            self.doc_stats.doc_length, self.doc_stats.mention_count, self.doc_stats.chain_count
        );
        report.push_str("Full-Document Metrics:\n");
        let _ = write!(report, " MUC: {:.1}%\n", self.full_doc_eval.muc.f1 * 100.0);
        let _ = write!(report, " B³: {:.1}%\n", self.full_doc_eval.b_cubed.f1 * 100.0);
        let _ = write!(report, " CEAF-e: {:.1}%\n", self.full_doc_eval.ceaf_e.f1 * 100.0);
        let _ = write!(report, " CoNLL: {:.1}%\n\n", self.full_doc_eval.conll_f1 * 100.0);
        if let Some(ref windowed) = self.windowed_eval {
            report.push_str("Windowed vs Full-Document Comparison:\n");
            let _ = write!(report, " Windowed CoNLL: {:.1}%\n", windowed.avg_conll_f1 * 100.0);
            let _ = write!(report, " Full-Doc CoNLL: {:.1}%\n", self.full_doc_eval.conll_f1 * 100.0);
            let _ = write!(
                report,
                " Performance Drop: {:.1} F1 points\n\n",
                windowed.performance_drop * 100.0
            );
        }
        // NOTE(review): the "(>10)" / "(2-10)" labels hard-code the default
        // BookScaleConfig thresholds; the analysis itself may have been run
        // with other values. Consider carrying the thresholds into this struct.
        report.push_str("Chain-Length Stratified Evaluation:\n");
        let _ = write!(
            report,
            " Long chains (>10): {:.1}% F1 ({} chains)\n",
            self.stratified.long_chains.f1 * 100.0,
            self.stratified.long_chain_count
        );
        let _ = write!(
            report,
            " Short chains (2-10): {:.1}% F1 ({} chains)\n",
            self.stratified.short_chains.f1 * 100.0,
            self.stratified.short_chain_count
        );
        let _ = write!(
            report,
            " Singletons (1): {:.1}% F1 ({} chains)\n\n",
            self.stratified.singletons.f1 * 100.0,
            self.stratified.singleton_count
        );
        report.push_str("Metric Reliability:\n");
        let _ = write!(
            report,
            " MUC: {} ({})\n",
            self.reliability.muc_reliability, self.reliability.muc_note
        );
        let _ = write!(
            report,
            " B³: {} ({})\n",
            self.reliability.b_cubed_reliability, self.reliability.b_cubed_note
        );
        let _ = write!(
            report,
            " CEAF-e: {} ({})\n",
            self.reliability.ceaf_e_reliability, self.reliability.ceaf_e_note
        );
        let _ = write!(
            report,
            " LEA: {} ({})\n\n",
            self.reliability.lea_reliability, self.reliability.lea_note
        );
        if self.has_scale_issues() {
            report.push_str("⚠️ ISSUES DETECTED:\n");
            if self.diagnostics.high_metric_divergence {
                report
                    .push_str(" • High metric divergence - MUC and CEAF disagree significantly\n");
            }
            if self.diagnostics.large_performance_drop {
                report.push_str(
                    " • Large windowed→full performance drop - long-range dependencies failing\n",
                );
            }
            if self.diagnostics.long_chain_dominance {
                report.push_str(" • Long chains dominate - main characters skewing metrics\n");
            }
            if self.diagnostics.singleton_neglect {
                report.push_str(" • Singleton neglect - minor entities being ignored\n");
            }
            report.push_str("\nRECOMMENDATIONS:\n");
            for rec in &self.diagnostics.recommendations {
                let _ = write!(report, " → {}\n", rec);
            }
        } else {
            report.push_str("✓ No significant scale issues detected.\n");
        }
        report
    }
}
/// Results of evaluating overlapping document windows independently.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WindowedEvaluation {
    /// Number of non-empty windows that were scored.
    pub num_windows: usize,
    /// Window size used (copied from the config).
    pub window_size: usize,
    /// Mean per-window CoNLL F1.
    pub avg_conll_f1: f64,
    /// Sample standard deviation of per-window CoNLL F1.
    pub std_conll_f1: f64,
    /// Windowed average minus full-document CoNLL F1 (positive when local,
    /// within-window resolution outperforms full-document resolution).
    pub performance_drop: f64,
    /// Per-window metric panels, in document order.
    pub window_evals: Vec<CorefEvalScores>,
}
/// Scores broken down by gold chain length (long / short / singleton).
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct StratifiedEvaluation {
    /// MUC scores over long chains (> long_chain_threshold mentions).
    pub long_chains: Scores,
    /// MUC scores over short chains (>= short_chain_threshold mentions).
    pub short_chains: Scores,
    /// B³ scores over singleton chains (MUC cannot score singletons).
    pub singletons: Scores,
    /// Count of gold long chains.
    pub long_chain_count: usize,
    /// Count of gold short chains.
    pub short_chain_count: usize,
    /// Count of gold singleton chains.
    pub singleton_count: usize,
}
/// Per-metric reliability verdicts plus short human-readable explanations,
/// assessed for a specific document scale.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MetricReliability {
    pub muc_reliability: ReliabilityLevel,
    /// Explanation for the MUC verdict.
    pub muc_note: String,
    pub b_cubed_reliability: ReliabilityLevel,
    /// Explanation for the B³ verdict.
    pub b_cubed_note: String,
    pub ceaf_e_reliability: ReliabilityLevel,
    /// Explanation for the CEAF-e verdict.
    pub ceaf_e_note: String,
    pub lea_reliability: ReliabilityLevel,
    /// Explanation for the LEA verdict.
    pub lea_note: String,
}
impl Default for MetricReliability {
fn default() -> Self {
Self {
muc_reliability: ReliabilityLevel::Medium,
muc_note: "May be inflated at scale".to_string(),
b_cubed_reliability: ReliabilityLevel::Medium,
b_cubed_note: "Moderate reliability".to_string(),
ceaf_e_reliability: ReliabilityLevel::Medium,
ceaf_e_note: "May collapse at scale".to_string(),
lea_reliability: ReliabilityLevel::High,
lea_note: "Most stable across scales".to_string(),
}
}
}
/// Ordinal trust level for a metric at a given document scale.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ReliabilityLevel {
    High,
    Medium,
    Low,
    Unreliable,
}
impl std::fmt::Display for ReliabilityLevel {
    /// Renders the level as an upper-case label for reports.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            ReliabilityLevel::High => "HIGH",
            ReliabilityLevel::Medium => "MEDIUM",
            ReliabilityLevel::Low => "LOW",
            ReliabilityLevel::Unreliable => "UNRELIABLE",
        };
        f.write_str(label)
    }
}
/// Boolean issue flags plus free-text recommendations produced by
/// `BookScaleAnalyzer::generate_diagnostics`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct BookScaleDiagnostics {
    /// |MUC F1 - CEAF-e F1| exceeded the configured divergence threshold.
    pub high_metric_divergence: bool,
    /// Windowed-minus-full CoNLL gap exceeded the configured threshold.
    pub large_performance_drop: bool,
    /// Long-chain F1 exceeds short-chain F1 by more than 0.20.
    pub long_chain_dominance: bool,
    /// Singleton F1 below 0.50 while gold singletons exist.
    pub singleton_neglect: bool,
    /// Human-readable remediation suggestions.
    pub recommendations: Vec<String>,
}
impl BookScaleDiagnostics {
    /// True if at least one diagnostic flag is set; recommendations alone
    /// (e.g. generic scale advice) do not count as issues.
    pub fn has_issues(&self) -> bool {
        [
            self.high_metric_divergence,
            self.large_performance_drop,
            self.long_chain_dominance,
            self.singleton_neglect,
        ]
        .iter()
        .any(|&flag| flag)
    }
}
/// Runs book-scale coreference evaluation with a fixed configuration.
pub struct BookScaleAnalyzer {
    // Thresholds and window parameters used by every analysis pass.
    config: BookScaleConfig,
}
impl Default for BookScaleAnalyzer {
fn default() -> Self {
Self::new(BookScaleConfig::default())
}
}
impl BookScaleAnalyzer {
    /// Creates an analyzer with the given configuration.
    pub fn new(config: BookScaleConfig) -> Self {
        Self { config }
    }

    /// Runs the full pipeline: full-document metrics, windowed evaluation
    /// (only for documents longer than two windows), chain-length
    /// stratification, metric-reliability assessment, and diagnostics.
    ///
    /// `doc_length` is in the same units as mention offsets; the caller must
    /// supply it because gold chains alone cannot determine total length.
    pub fn analyze(
        &self,
        predicted: &[CorefChain],
        gold: &[CorefChain],
        doc_length: usize,
    ) -> BookScaleAnalysis {
        let full_doc_eval = self.evaluate_chains(predicted, gold);
        let mut doc_stats = CorefDocStats::from_chains(gold);
        doc_stats.doc_length = doc_length;
        let scale = doc_stats.scale_classification();
        // Windowing is only meaningful when at least two windows fit.
        let windowed_eval = if doc_length > self.config.window_size * 2 {
            Some(self.compute_windowed_eval(predicted, gold, doc_length, full_doc_eval.conll_f1))
        } else {
            None
        };
        let stratified = self.compute_stratified_eval(predicted, gold);
        let reliability = self.assess_reliability(&full_doc_eval, &stratified, scale);
        let diagnostics =
            self.generate_diagnostics(&full_doc_eval, windowed_eval.as_ref(), &stratified, scale);
        BookScaleAnalysis {
            full_doc_eval,
            windowed_eval,
            stratified,
            doc_stats,
            reliability,
            scale,
            diagnostics,
        }
    }

    /// Scores `predicted` against `gold` with MUC, B³, CEAF-e, CEAF-m and
    /// LEA; CoNLL F1 is the conventional mean of MUC, B³ and CEAF-e F1.
    fn evaluate_chains(&self, predicted: &[CorefChain], gold: &[CorefChain]) -> CorefEvalScores {
        let muc = Scores::from_tuple(muc_score(predicted, gold));
        let b_cubed = Scores::from_tuple(b_cubed_score(predicted, gold));
        let ceaf_e = Scores::from_tuple(ceaf_e_score(predicted, gold));
        let ceaf_m = Scores::from_tuple(ceaf_m_score(predicted, gold));
        let lea = Scores::from_tuple(lea_score(predicted, gold));
        let conll_f1 = (muc.f1 + b_cubed.f1 + ceaf_e.f1) / 3.0;
        CorefEvalScores {
            muc,
            b_cubed,
            ceaf_e,
            ceaf_m,
            lea,
            conll_f1,
        }
    }

    /// Mean and sample standard deviation (n - 1 denominator).
    /// Returns (0.0, 0.0) for an empty slice; std is 0.0 for one sample.
    fn mean_and_std(values: &[f64]) -> (f64, f64) {
        if values.is_empty() {
            return (0.0, 0.0);
        }
        let mean = values.iter().sum::<f64>() / values.len() as f64;
        let std = if values.len() > 1 {
            let variance = values.iter().map(|x| (x - mean).powi(2)).sum::<f64>()
                / (values.len() - 1) as f64;
            variance.sqrt()
        } else {
            0.0
        };
        (mean, std)
    }

    /// Evaluates overlapping windows (size `window_size`, stepping by
    /// `window_size - window_overlap`) and averages per-window CoNLL F1.
    ///
    /// `full_doc_conll_f1` is the full-document CoNLL F1 already computed by
    /// `analyze`; it is passed in so the expensive five-metric evaluation is
    /// not redundantly rerun here. `performance_drop` is windowed average
    /// minus full-document F1 (positive when local resolution scores higher).
    fn compute_windowed_eval(
        &self,
        predicted: &[CorefChain],
        gold: &[CorefChain],
        doc_length: usize,
        full_doc_conll_f1: f64,
    ) -> WindowedEvaluation {
        let step = self
            .config
            .window_size
            .saturating_sub(self.config.window_overlap);
        let mut window_evals = Vec::new();
        let mut offset = 0;
        while offset < doc_length {
            let window_end = (offset + self.config.window_size).min(doc_length);
            let pred_window = self.filter_to_window(predicted, offset, window_end);
            let gold_window = self.filter_to_window(gold, offset, window_end);
            // Windows with no mentions on either side are skipped entirely,
            // so they neither help nor hurt the average.
            if !pred_window.is_empty() || !gold_window.is_empty() {
                window_evals.push(self.evaluate_chains(&pred_window, &gold_window));
            }
            if window_end >= doc_length {
                break;
            }
            // `step` is 0 when overlap >= window size; max(1) keeps progress.
            offset += step.max(1);
        }
        let conll_scores: Vec<f64> = window_evals.iter().map(|e| e.conll_f1).collect();
        let (avg_conll_f1, std_conll_f1) = Self::mean_and_std(&conll_scores);
        let performance_drop = avg_conll_f1 - full_doc_conll_f1;
        WindowedEvaluation {
            num_windows: window_evals.len(),
            window_size: self.config.window_size,
            avg_conll_f1,
            std_conll_f1,
            performance_drop,
            window_evals,
        }
    }

    /// Restricts each chain to mentions fully contained in `[start, end]`;
    /// mentions straddling a window boundary are dropped, and chains with no
    /// surviving mentions are omitted. Cluster id and entity type are copied
    /// so identities stay comparable across windows.
    fn filter_to_window(&self, chains: &[CorefChain], start: usize, end: usize) -> Vec<CorefChain> {
        chains
            .iter()
            .filter_map(|chain| {
                let filtered_mentions: Vec<Mention> = chain
                    .mentions
                    .iter()
                    .filter(|m| m.start >= start && m.end <= end)
                    .cloned()
                    .collect();
                if filtered_mentions.is_empty() {
                    None
                } else {
                    let mut new_chain = CorefChain::new(filtered_mentions);
                    new_chain.cluster_id = chain.cluster_id;
                    new_chain.entity_type = chain.entity_type.clone();
                    Some(new_chain)
                }
            })
            .collect()
    }

    /// Scores long, short and singleton strata separately. Long/short use
    /// MUC; singletons use B³ because MUC cannot score link-free entities.
    /// Counts come from the gold-side stratification.
    ///
    /// NOTE(review): predicted and gold chains are stratified independently,
    /// so a gold-long entity predicted as several short chains is compared
    /// imperfectly across strata — treat these as diagnostics, not metrics.
    fn compute_stratified_eval(
        &self,
        predicted: &[CorefChain],
        gold: &[CorefChain],
    ) -> StratifiedEvaluation {
        let (pred_long, pred_short, pred_singleton) = self.stratify_chains(predicted);
        let (gold_long, gold_short, gold_singleton) = self.stratify_chains(gold);
        let long_chains = if !pred_long.is_empty() || !gold_long.is_empty() {
            Scores::from_tuple(muc_score(&pred_long, &gold_long))
        } else {
            Scores::default()
        };
        let short_chains = if !pred_short.is_empty() || !gold_short.is_empty() {
            Scores::from_tuple(muc_score(&pred_short, &gold_short))
        } else {
            Scores::default()
        };
        let singletons = if !pred_singleton.is_empty() || !gold_singleton.is_empty() {
            Scores::from_tuple(b_cubed_score(&pred_singleton, &gold_singleton))
        } else {
            Scores::default()
        };
        StratifiedEvaluation {
            long_chains,
            short_chains,
            singletons,
            long_chain_count: gold_long.len(),
            short_chain_count: gold_short.len(),
            singleton_count: gold_singleton.len(),
        }
    }

    /// Buckets chains by mention count: (> long_chain_threshold,
    /// >= short_chain_threshold, remainder). With the default thresholds the
    /// remainder bucket holds singletons (single-mention chains).
    fn stratify_chains(
        &self,
        chains: &[CorefChain],
    ) -> (Vec<CorefChain>, Vec<CorefChain>, Vec<CorefChain>) {
        let mut long = Vec::new();
        let mut short = Vec::new();
        let mut singleton = Vec::new();
        for chain in chains {
            let len = chain.len();
            if len > self.config.long_chain_threshold {
                long.push(chain.clone());
            } else if len >= self.config.short_chain_threshold {
                short.push(chain.clone());
            } else {
                singleton.push(chain.clone());
            }
        }
        (long, short, singleton)
    }

    /// Heuristic per-metric reliability at this scale: MUC is downgraded as
    /// it diverges from CEAF-e, CEAF-e is downgraded on long/book-scale
    /// documents, B³ is downgraded on large MUC-B³ divergence, and LEA is
    /// always rated HIGH.
    fn assess_reliability(
        &self,
        eval: &CorefEvalScores,
        _stratified: &StratifiedEvaluation,
        scale: DocumentScale,
    ) -> MetricReliability {
        let divergence =
            MetricDivergence::from_scores(eval.muc.f1, eval.b_cubed.f1, eval.ceaf_e.f1);
        let (muc_rel, muc_note) = if divergence.muc_ceaf_divergence > 0.40 {
            (
                ReliabilityLevel::Low,
                "Severely inflated due to long chains".to_string(),
            )
        } else if divergence.muc_ceaf_divergence > 0.25 {
            (
                ReliabilityLevel::Medium,
                "May be inflated at this scale".to_string(),
            )
        } else {
            (ReliabilityLevel::High, "Reliable at this scale".to_string())
        };
        let (ceaf_rel, ceaf_note) = match scale {
            DocumentScale::BookScale => (
                ReliabilityLevel::Low,
                "Known to collapse at book scale".to_string(),
            ),
            DocumentScale::Long => (
                ReliabilityLevel::Medium,
                "May underestimate at this length".to_string(),
            ),
            _ => (ReliabilityLevel::High, "Reliable at this scale".to_string()),
        };
        let (b3_rel, b3_note) = if divergence.muc_b3_divergence > 0.30 {
            (
                ReliabilityLevel::Medium,
                "Moderate divergence from MUC".to_string(),
            )
        } else {
            (ReliabilityLevel::High, "Stable metric".to_string())
        };
        let (lea_rel, lea_note) = (
            ReliabilityLevel::High,
            "Most stable across document scales".to_string(),
        );
        MetricReliability {
            muc_reliability: muc_rel,
            muc_note,
            b_cubed_reliability: b3_rel,
            b_cubed_note: b3_note,
            ceaf_e_reliability: ceaf_rel,
            ceaf_e_note: ceaf_note,
            lea_reliability: lea_rel,
            lea_note,
        }
    }

    /// Flags scale issues (metric divergence, windowed→full drop, long-chain
    /// dominance, singleton neglect) and collects recommendations, including
    /// generic advice keyed on document scale.
    fn generate_diagnostics(
        &self,
        eval: &CorefEvalScores,
        windowed: Option<&WindowedEvaluation>,
        stratified: &StratifiedEvaluation,
        scale: DocumentScale,
    ) -> BookScaleDiagnostics {
        let mut diagnostics = BookScaleDiagnostics::default();
        let divergence = (eval.muc.f1 - eval.ceaf_e.f1).abs();
        if divergence > self.config.divergence_threshold {
            diagnostics.high_metric_divergence = true;
            diagnostics
                .recommendations
                .push("Use LEA or stratified metrics instead of CoNLL F1".to_string());
        }
        if let Some(w) = windowed {
            if w.performance_drop > self.config.performance_drop_threshold {
                diagnostics.large_performance_drop = true;
                diagnostics.recommendations.push(
                    "Consider incremental/streaming coref approach (Longdoc-style)".to_string(),
                );
            }
        }
        let total_chains =
            stratified.long_chain_count + stratified.short_chain_count + stratified.singleton_count;
        // Only meaningful when there are gold chains at all; the 0.20 gap is
        // a fixed heuristic margin, not a config value.
        if total_chains > 0 && stratified.long_chains.f1 > stratified.short_chains.f1 + 0.20 {
            diagnostics.long_chain_dominance = true;
            diagnostics
                .recommendations
                .push("Report per-chain-length metrics separately".to_string());
        }
        if stratified.singleton_count > 0 && stratified.singletons.f1 < 0.50 {
            diagnostics.singleton_neglect = true;
            diagnostics
                .recommendations
                .push("System may be ignoring minor entities".to_string());
        }
        match scale {
            DocumentScale::BookScale => {
                diagnostics
                    .recommendations
                    .push("Consider BOOKCOREF-style windowed+grouped evaluation".to_string());
            }
            DocumentScale::Long => {
                diagnostics
                    .recommendations
                    .push("Monitor for metric divergence as length increases".to_string());
            }
            _ => {}
        }
        diagnostics
    }
}
/// Evaluation results for a single book in a multi-book run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PerBookEvaluation {
    /// Stable identifier; used as the display name when `title` is absent.
    pub book_id: String,
    pub title: Option<String>,
    pub author: Option<String>,
    pub token_count: usize,
    /// Full-document metric panel for this book.
    pub full_doc: CorefEvalScores,
    /// Windowed evaluation, when one was run for this book.
    pub windowed: Option<WindowedEvaluation>,
    /// Coarse document-length class of this book.
    pub scale: DocumentScale,
}
/// Corpus-level report: per-book results plus aggregate statistics.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultiBookReport {
    pub books: Vec<PerBookEvaluation>,
    pub aggregate: AggregateStats,
}
/// Aggregates computed across all books in a `MultiBookReport`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AggregateStats {
    pub total_books: usize,
    /// Mean full-document CoNLL F1 across books (0.0 when there are none).
    pub mean_conll_f1: f64,
    /// Sample standard deviation of CoNLL F1 (0.0 for fewer than two books).
    pub std_conll_f1: f64,
    /// Mean windowed performance drop over books that have windowed results.
    pub mean_performance_drop: f64,
    /// Books showing high MUC/CEAF divergence or a large performance drop.
    pub books_with_issues: usize,
}
impl MultiBookReport {
    /// Builds a report from per-book evaluations, computing the corpus-level
    /// aggregates: mean/std of full-document CoNLL F1, mean windowed
    /// performance drop, and a count of books showing scale issues.
    pub fn from_books(books: Vec<PerBookEvaluation>) -> Self {
        let total_books = books.len();
        let conll_scores: Vec<f64> = books.iter().map(|b| b.full_doc.conll_f1).collect();
        let mean_conll_f1 = if !conll_scores.is_empty() {
            conll_scores.iter().sum::<f64>() / conll_scores.len() as f64
        } else {
            0.0
        };
        // Sample standard deviation (n - 1 denominator); 0.0 for < 2 books.
        let std_conll_f1 = if conll_scores.len() > 1 {
            let variance = conll_scores
                .iter()
                .map(|x| (x - mean_conll_f1).powi(2))
                .sum::<f64>()
                / (conll_scores.len() - 1) as f64;
            variance.sqrt()
        } else {
            0.0
        };
        // Only books that actually have windowed results contribute here.
        let performance_drops: Vec<f64> = books
            .iter()
            .filter_map(|b| b.windowed.as_ref().map(|w| w.performance_drop))
            .collect();
        let mean_performance_drop = if !performance_drops.is_empty() {
            performance_drops.iter().sum::<f64>() / performance_drops.len() as f64
        } else {
            0.0
        };
        // NOTE(review): 0.30 / 0.15 duplicate the BookScaleConfig defaults
        // (divergence_threshold / performance_drop_threshold); keep in sync
        // if the config defaults ever change.
        let books_with_issues = books
            .iter()
            .filter(|b| {
                let divergence = (b.full_doc.muc.f1 - b.full_doc.ceaf_e.f1).abs();
                divergence > 0.30
                    || b.windowed
                        .as_ref()
                        .map(|w| w.performance_drop > 0.15)
                        .unwrap_or(false)
            })
            .count();
        let aggregate = AggregateStats {
            total_books,
            mean_conll_f1,
            std_conll_f1,
            mean_performance_drop,
            books_with_issues,
        };
        Self { books, aggregate }
    }

    /// Formats a fixed-width table: header, one row per book, and a MEAN
    /// footer. Titles are truncated to 28 chars; a missing title falls back
    /// to `book_id`.
    pub fn format_table(&self) -> String {
        // `write!` appends in place instead of allocating a temporary String
        // per row; writing into a `String` is infallible, so Results are ignored.
        use std::fmt::Write;
        let mut table = String::new();
        let _ = write!(
            table,
            "{:<30} {:>8} {:>8} {:>8} {:>8} {:>8}\n",
            "Book", "Tokens", "MUC", "B³", "CEAF", "CoNLL"
        );
        let _ = write!(table, "{}\n", "-".repeat(78));
        for book in &self.books {
            let title = book
                .title
                .as_deref()
                .unwrap_or(&book.book_id)
                .chars()
                .take(28)
                .collect::<String>();
            let _ = write!(
                table,
                "{:<30} {:>8} {:>7.1}% {:>7.1}% {:>7.1}% {:>7.1}%\n",
                title,
                book.token_count,
                book.full_doc.muc.f1 * 100.0,
                book.full_doc.b_cubed.f1 * 100.0,
                book.full_doc.ceaf_e.f1 * 100.0,
                book.full_doc.conll_f1 * 100.0,
            );
        }
        let _ = write!(table, "{}\n", "-".repeat(78));
        let _ = write!(
            table,
            "{:<30} {:>8} {:>7.1}% ±{:.1}\n",
            "MEAN",
            "",
            self.aggregate.mean_conll_f1 * 100.0,
            self.aggregate.std_conll_f1 * 100.0
        );
        table
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // Helper: builds a CorefChain from (text, start, end) mention tuples.
    fn make_chain(mentions: Vec<(&str, usize, usize)>) -> CorefChain {
        let m: Vec<Mention> = mentions
            .into_iter()
            .map(|(text, start, end)| Mention::new(text, start, end))
            .collect();
        CorefChain::new(m)
    }
    // A 1-, 3- and 15-mention chain land in singleton/short/long buckets
    // under the default thresholds (short >= 2, long > 10).
    #[test]
    fn test_stratify_chains() {
        let config = BookScaleConfig::default();
        let analyzer = BookScaleAnalyzer::new(config);
        let chains = vec![
            make_chain(vec![("a", 0, 1)]),
            make_chain(vec![("b", 0, 1), ("c", 2, 3), ("d", 4, 5)]),
            make_chain((0..15).map(|i| ("x", i * 10, i * 10 + 1)).collect()),
        ];
        let (long, short, single) = analyzer.stratify_chains(&chains);
        assert_eq!(single.len(), 1);
        assert_eq!(short.len(), 1);
        assert_eq!(long.len(), 1);
    }
    // Large MUC-vs-CEAF-e spread (0.9 vs 0.4) should downgrade MUC below HIGH.
    #[test]
    fn test_reliability_assessment() {
        let config = BookScaleConfig::default();
        let analyzer = BookScaleAnalyzer::new(config);
        let eval = CorefEvalScores {
            muc: Scores {
                precision: 0.9,
                recall: 0.9,
                f1: 0.9,
            },
            b_cubed: Scores {
                precision: 0.7,
                recall: 0.7,
                f1: 0.7,
            },
            ceaf_e: Scores {
                precision: 0.4,
                recall: 0.4,
                f1: 0.4,
            },
            ceaf_m: Scores::default(),
            lea: Scores::default(),
            conll_f1: 0.67,
        };
        let stratified = StratifiedEvaluation::default();
        let reliability = analyzer.assess_reliability(&eval, &stratified, DocumentScale::BookScale);
        assert!(matches!(
            reliability.muc_reliability,
            ReliabilityLevel::Low | ReliabilityLevel::Medium
        ));
    }
    // |0.93 - 0.33| = 0.60 exceeds the 0.30 divergence threshold, so the
    // high-divergence flag must be raised.
    #[test]
    fn test_diagnostics_generation() {
        let config = BookScaleConfig::default();
        let analyzer = BookScaleAnalyzer::new(config);
        let eval = CorefEvalScores {
            muc: Scores {
                precision: 0.93,
                recall: 0.93,
                f1: 0.93,
            },
            b_cubed: Scores {
                precision: 0.62,
                recall: 0.62,
                f1: 0.62,
            },
            ceaf_e: Scores {
                precision: 0.33,
                recall: 0.33,
                f1: 0.33,
            },
            ceaf_m: Scores::default(),
            lea: Scores::default(),
            conll_f1: 0.63,
        };
        let windowed = WindowedEvaluation {
            num_windows: 10,
            window_size: 1500,
            avg_conll_f1: 0.78,
            std_conll_f1: 0.05,
            performance_drop: 0.15,
            window_evals: vec![],
        };
        let stratified = StratifiedEvaluation::default();
        let diagnostics = analyzer.generate_diagnostics(
            &eval,
            Some(&windowed),
            &stratified,
            DocumentScale::BookScale,
        );
        assert!(diagnostics.high_metric_divergence);
        assert!(diagnostics.has_issues());
    }
    // Two-book aggregation: verifies counts, mean, and table rendering.
    #[test]
    fn test_multi_book_report() {
        let books = vec![
            PerBookEvaluation {
                book_id: "animal_farm".to_string(),
                title: Some("Animal Farm".to_string()),
                author: Some("George Orwell".to_string()),
                token_count: 29853,
                full_doc: CorefEvalScores {
                    muc: Scores {
                        precision: 0.9,
                        recall: 0.9,
                        f1: 0.9,
                    },
                    b_cubed: Scores {
                        precision: 0.6,
                        recall: 0.6,
                        f1: 0.6,
                    },
                    ceaf_e: Scores {
                        precision: 0.5,
                        recall: 0.5,
                        f1: 0.5,
                    },
                    ceaf_m: Scores::default(),
                    lea: Scores::default(),
                    conll_f1: 0.67,
                },
                windowed: None,
                scale: DocumentScale::Long,
            },
            PerBookEvaluation {
                book_id: "pride_prejudice".to_string(),
                title: Some("Pride and Prejudice".to_string()),
                author: Some("Jane Austen".to_string()),
                token_count: 121869,
                full_doc: CorefEvalScores {
                    muc: Scores {
                        precision: 0.85,
                        recall: 0.85,
                        f1: 0.85,
                    },
                    b_cubed: Scores {
                        precision: 0.55,
                        recall: 0.55,
                        f1: 0.55,
                    },
                    ceaf_e: Scores {
                        precision: 0.35,
                        recall: 0.35,
                        f1: 0.35,
                    },
                    ceaf_m: Scores::default(),
                    lea: Scores::default(),
                    conll_f1: 0.58,
                },
                windowed: None,
                scale: DocumentScale::BookScale,
            },
        ];
        let report = MultiBookReport::from_books(books);
        assert_eq!(report.aggregate.total_books, 2);
        assert!(report.aggregate.mean_conll_f1 > 0.5);
        let table = report.format_table();
        assert!(table.contains("Animal Farm"));
        assert!(table.contains("Pride and Prejudice"));
    }
    // Pins the token-count boundaries of the scale classification
    // (defined in super::types).
    #[test]
    fn test_document_scale_classification() {
        assert_eq!(DocumentScale::from_tokens(100), DocumentScale::Short);
        assert_eq!(DocumentScale::from_tokens(2000), DocumentScale::Short);
        assert_eq!(DocumentScale::from_tokens(5000), DocumentScale::Medium);
        assert_eq!(DocumentScale::from_tokens(30000), DocumentScale::Long);
        assert_eq!(DocumentScale::from_tokens(100000), DocumentScale::BookScale);
    }
    // Empty input yields three empty buckets.
    #[test]
    fn test_empty_chains_stratification() {
        let config = BookScaleConfig::default();
        let analyzer = BookScaleAnalyzer::new(config);
        let chains: Vec<CorefChain> = vec![];
        let (long, short, single) = analyzer.stratify_chains(&chains);
        assert!(long.is_empty());
        assert!(short.is_empty());
        assert!(single.is_empty());
    }
    // One-mention chains all fall in the singleton bucket.
    #[test]
    fn test_all_singletons() {
        let config = BookScaleConfig::default();
        let analyzer = BookScaleAnalyzer::new(config);
        let chains = vec![
            make_chain(vec![("a", 0, 1)]),
            make_chain(vec![("b", 10, 11)]),
            make_chain(vec![("c", 20, 21)]),
        ];
        let (long, short, single) = analyzer.stratify_chains(&chains);
        assert!(long.is_empty());
        assert!(short.is_empty());
        assert_eq!(single.len(), 3);
    }
    // 20- and 25-mention chains both exceed the long threshold (10).
    #[test]
    fn test_all_long_chains() {
        let config = BookScaleConfig::default();
        let analyzer = BookScaleAnalyzer::new(config);
        let chains = vec![
            make_chain((0..20).map(|i| ("x", i * 10, i * 10 + 1)).collect()),
            make_chain((0..25).map(|i| ("y", i * 10 + 5, i * 10 + 6)).collect()),
        ];
        let (long, short, single) = analyzer.stratify_chains(&chains);
        assert_eq!(long.len(), 2);
        assert!(short.is_empty());
        assert!(single.is_empty());
    }
    // Default-derived Scores must be all zeros.
    #[test]
    fn test_scores_default() {
        let scores = Scores::default();
        assert!((scores.precision - 0.0).abs() < 0.001);
        assert!((scores.recall - 0.0).abs() < 0.001);
        assert!((scores.f1 - 0.0).abs() < 0.001);
    }
    // Documents the CoNLL convention: mean of MUC, B³ and CEAF-e F1.
    #[test]
    fn test_coref_eval_scores_conll_average() {
        let eval = CorefEvalScores {
            muc: Scores {
                precision: 0.8,
                recall: 0.8,
                f1: 0.8,
            },
            b_cubed: Scores {
                precision: 0.7,
                recall: 0.7,
                f1: 0.7,
            },
            ceaf_e: Scores {
                precision: 0.6,
                recall: 0.6,
                f1: 0.6,
            },
            ceaf_m: Scores::default(),
            lea: Scores::default(),
            conll_f1: 0.7,
        };
        let expected_conll = (0.8 + 0.7 + 0.6) / 3.0;
        assert!((eval.conll_f1 - expected_conll).abs() < 0.001);
    }
    // Smoke check on WindowedEvaluation construction.
    #[test]
    fn test_windowed_evaluation_performance_drop() {
        let windowed = WindowedEvaluation {
            num_windows: 5,
            window_size: 1000,
            avg_conll_f1: 0.80,
            std_conll_f1: 0.03,
            performance_drop: 0.15,
            window_evals: vec![],
        };
        assert!(windowed.performance_drop > 0.0);
        assert_eq!(windowed.num_windows, 5);
    }
    // Well-agreeing metrics on a short doc: just exercises the diagnostics
    // path without asserting specific flags.
    #[test]
    fn test_diagnostics_no_issues_for_short_doc() {
        let config = BookScaleConfig::default();
        let analyzer = BookScaleAnalyzer::new(config);
        let eval = CorefEvalScores {
            muc: Scores {
                precision: 0.85,
                recall: 0.85,
                f1: 0.85,
            },
            b_cubed: Scores {
                precision: 0.82,
                recall: 0.82,
                f1: 0.82,
            },
            ceaf_e: Scores {
                precision: 0.80,
                recall: 0.80,
                f1: 0.80,
            },
            ceaf_m: Scores::default(),
            lea: Scores::default(),
            conll_f1: 0.82,
        };
        let stratified = StratifiedEvaluation::default();
        let diagnostics = analyzer.generate_diagnostics(
            &eval,
            None,
            &stratified,
            DocumentScale::Short,
        );
        let _ = diagnostics.high_metric_divergence;
        let _ = diagnostics.has_issues();
    }
    // Zero books must aggregate without panicking (no division by zero).
    #[test]
    fn test_multi_book_report_empty() {
        let books: Vec<PerBookEvaluation> = vec![];
        let report = MultiBookReport::from_books(books);
        assert_eq!(report.aggregate.total_books, 0);
        assert!(report.books.is_empty());
    }
    // Construction smoke test for a book-scale PerBookEvaluation.
    #[test]
    fn test_per_book_evaluation_scale() {
        let book = PerBookEvaluation {
            book_id: "test".to_string(),
            title: Some("Test Book".to_string()),
            author: None,
            token_count: 200000,
            full_doc: CorefEvalScores::default(),
            windowed: None,
            scale: DocumentScale::BookScale,
        };
        assert!(book.token_count > 100000);
        assert_eq!(book.scale, DocumentScale::BookScale);
    }
}