use anno::{Error, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// A metric score constrained to the unit interval `[0.0, 1.0]`.
///
/// Construct via `MetricValue::new` (clamping) or `MetricValue::try_new`
/// (validating). Serialized transparently as a bare `f64`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, PartialOrd)]
#[serde(transparent)]
pub struct MetricValue(f64);
/// Summary statistics for a collection of metric samples: central tendency
/// plus dispersion, as computed by `from_samples`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
pub struct MetricWithVariance {
// Arithmetic mean of the samples.
pub mean: f64,
// Sample standard deviation (n-1 denominator); 0.0 when n <= 1.
pub std_dev: f64,
// Half-width of an approximate 95% confidence interval around the mean.
pub ci_95: f64,
// Smallest sample value (0.0 when empty).
pub min: f64,
// Largest sample value (0.0 when empty).
pub max: f64,
// Number of samples the statistics were computed from.
pub n: usize,
}
impl MetricWithVariance {
    /// Build summary statistics from raw samples.
    ///
    /// Returns the all-zero summary for an empty slice. Uses the sample
    /// (n-1) variance, and an approximate critical value for the 95% CI
    /// half-width: 1.96 (normal quantile) once `n >= 30`, a slightly
    /// wider heuristic multiplier for smaller samples.
    pub fn from_samples(samples: &[f64]) -> Self {
        let n = samples.len();
        if n == 0 {
            return Self::default();
        }
        let count = n as f64;
        let mean = samples.iter().sum::<f64>() / count;
        // Track extrema in a single pass.
        let (mut lo, mut hi) = (f64::INFINITY, f64::NEG_INFINITY);
        for &s in samples {
            lo = lo.min(s);
            hi = hi.max(s);
        }
        let std_dev = if n > 1 {
            let sum_sq: f64 = samples.iter().map(|x| (x - mean).powi(2)).sum();
            (sum_sq / (n - 1) as f64).sqrt()
        } else {
            0.0
        };
        // Rough t-multiplier heuristic; not an exact Student-t quantile.
        let t_value = if n >= 30 { 1.96 } else { 2.0 + 0.1 / count.sqrt() };
        let ci_95 = if n > 1 { t_value * std_dev / count.sqrt() } else { 0.0 };
        Self {
            mean,
            std_dev,
            ci_95,
            min: lo,
            max: hi,
            n,
        }
    }

    /// Render as "mean% ± ci%"; "N/A" when there are no samples.
    pub fn format_with_ci(&self) -> String {
        match self.n {
            0 => "N/A".to_string(),
            _ => format!("{:.1}% ± {:.1}%", self.mean * 100.0, self.ci_95 * 100.0),
        }
    }

    /// Render as "mean% (min%-max%)"; "N/A" when there are no samples.
    pub fn format_with_range(&self) -> String {
        match self.n {
            0 => "N/A".to_string(),
            _ => format!(
                "{:.1}% ({:.1}%-{:.1}%)",
                self.mean * 100.0,
                self.min * 100.0,
                self.max * 100.0
            ),
        }
    }

    /// Relative dispersion (std_dev / mean); 0.0 when the mean is
    /// effectively zero to avoid a division blow-up.
    pub fn coefficient_of_variation(&self) -> f64 {
        if self.mean.abs() < 1e-10 {
            0.0
        } else {
            self.std_dev / self.mean
        }
    }
}
impl Default for MetricWithVariance {
fn default() -> Self {
Self {
mean: 0.0,
std_dev: 0.0,
ci_95: 0.0,
min: 0.0,
max: 0.0,
n: 0,
}
}
}
impl std::fmt::Display for MetricWithVariance {
    // Display mirrors the "mean ± CI" rendering of `format_with_ci`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.format_with_ci())
    }
}
impl MetricValue {
    /// Clamp `value` into `[0.0, 1.0]`.
    ///
    /// `NaN` is mapped to `0.0`: `f64::clamp` propagates `NaN`, which would
    /// otherwise smuggle a non-finite value past the type's unit-interval
    /// invariant (and render as "NaN" via `Display`).
    pub fn new(value: f64) -> Self {
        if value.is_nan() {
            MetricValue(0.0)
        } else {
            MetricValue(value.clamp(0.0, 1.0))
        }
    }

    /// Validating constructor: accepts only values in `[0.0, 1.0]`.
    /// `NaN` fails the range test and is rejected as well.
    ///
    /// # Errors
    /// Returns `Error::InvalidInput` when `value` is out of range.
    pub fn try_new(value: f64) -> Result<Self> {
        if !(0.0..=1.0).contains(&value) {
            return Err(Error::InvalidInput(format!(
                "MetricValue must be in [0.0, 1.0], got {}",
                value
            )));
        }
        Ok(MetricValue(value))
    }

    /// Raw inner score.
    #[inline]
    pub fn get(&self) -> f64 {
        self.0
    }
}
impl Default for MetricValue {
fn default() -> Self {
MetricValue(0.0)
}
}
impl std::fmt::Display for MetricValue {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{:.4}", self.0)
}
}
impl From<f64> for MetricValue {
fn from(value: f64) -> Self {
MetricValue::new(value)
}
}
/// Aggregated pass/fail outcome of a set of named goal checks.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct GoalCheckResult {
// Overall verdict: flipped to false as soon as any check fails.
// NOTE(review): the derived `Default` yields `passed: false`, while
// `GoalCheckResult::new()` starts at `passed: true` — confirm which
// vacuous value callers expect before relying on `Default`.
pub passed: bool,
// Individual checks keyed by check name.
pub checks: HashMap<String, GoalCheck>,
// Optional human-readable summary of the result.
pub summary: Option<String>,
}
impl GoalCheckResult {
    /// Fresh result with no checks; vacuously passing.
    #[must_use]
    pub fn new() -> Self {
        Self {
            passed: true,
            checks: HashMap::new(),
            summary: None,
        }
    }

    /// Record a named check; one failing check makes the overall result fail.
    pub fn add_check(&mut self, name: impl Into<String>, check: GoalCheck) {
        self.passed = self.passed && check.passed;
        self.checks.insert(name.into(), check);
    }

    /// Convenience: record a failing check for `name`.
    pub fn add_failure(&mut self, name: impl Into<String>, actual: f64, threshold: f64) {
        self.add_check(name, GoalCheck::fail(threshold, actual));
    }

    /// Convenience: record a passing check for `name`.
    pub fn add_success(&mut self, name: impl Into<String>, actual: f64, threshold: f64) {
        self.add_check(name, GoalCheck::pass(threshold, actual));
    }

    /// Number of checks that passed.
    pub fn passed_count(&self) -> usize {
        self.checks.values().filter(|c| c.passed).count()
    }

    /// Number of checks that failed.
    pub fn failed_count(&self) -> usize {
        self.checks.len() - self.passed_count()
    }
}
/// A single named goal check: an observed value against a threshold.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GoalCheck {
// Verdict as decided by the caller; this struct does not itself compare
// `actual` against `threshold`.
pub passed: bool,
// Target value the metric was held to.
pub threshold: f64,
// Observed metric value.
pub actual: f64,
// Optional explanation attached via `with_message`.
pub message: Option<String>,
}
impl GoalCheck {
    /// Construct a check with an explicit verdict and no message.
    pub fn new(passed: bool, threshold: f64, actual: f64) -> Self {
        Self {
            passed,
            threshold,
            actual,
            message: None,
        }
    }

    /// A passing check.
    pub fn pass(threshold: f64, actual: f64) -> Self {
        Self::new(true, threshold, actual)
    }

    /// A failing check.
    pub fn fail(threshold: f64, actual: f64) -> Self {
        Self::new(false, threshold, actual)
    }

    /// Attach an explanatory message (builder style).
    #[must_use]
    pub fn with_message(mut self, msg: impl Into<String>) -> Self {
        self.message = Some(msg.into());
        self
    }
}
/// Quantifies train→eval label overlap, used to flag evaluation setups
/// where "zero-shot" scores are likely inflated by familiar labels.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LabelShift {
// Fraction of eval types also present verbatim in training (0.0-1.0).
pub overlap_ratio: f64,
// Soft similarity of eval types to training types
// (string-based by default, embedding-based when available).
pub familiarity: f64,
// Eval types with no exact match in the training set.
pub true_zero_shot_types: Vec<String>,
// "low" | "medium" | "high", derived from overlap and familiarity.
pub transfer_difficulty: String,
}
impl LabelShift {
    /// True when overlap or familiarity is high enough that headline
    /// scores likely overstate genuine zero-shot ability.
    #[must_use]
    pub fn is_inflated(&self) -> bool {
        self.overlap_ratio > 0.8 || self.familiarity > 0.85
    }

    /// Count of eval types never seen (verbatim) in training.
    #[must_use]
    pub fn true_zero_shot_count(&self) -> usize {
        self.true_zero_shot_types.len()
    }

    /// Derive label-shift statistics from raw train/eval type lists.
    #[must_use]
    pub fn from_type_sets(train_types: &[String], eval_types: &[String]) -> Self {
        use std::collections::HashSet;
        let train: HashSet<&String> = train_types.iter().collect();
        let eval: HashSet<&String> = eval_types.iter().collect();
        // Overlap is measured against the raw (non-deduplicated) eval list length.
        let overlap_ratio = match eval_types.len() {
            0 => 0.0,
            total => eval.intersection(&train).count() as f64 / total as f64,
        };
        let true_zero_shot_types: Vec<String> =
            eval.difference(&train).map(|s| (*s).clone()).collect();
        let familiarity = compute_string_based_familiarity(train_types, eval_types);
        // Difficulty buckets mirror the `is_inflated` thresholds at the top end.
        let transfer_difficulty = if overlap_ratio > 0.8 || familiarity > 0.85 {
            "low"
        } else if overlap_ratio > 0.4 || familiarity > 0.5 {
            "medium"
        } else {
            "high"
        }
        .to_string();
        Self {
            overlap_ratio,
            familiarity,
            true_zero_shot_types,
            transfer_difficulty,
        }
    }

    /// Like `from_type_sets`, but replaces the string-based familiarity
    /// with an embedding-based one when embeddings are obtainable.
    #[must_use]
    pub fn from_type_sets_with_embeddings<F>(
        train_types: &[String],
        eval_types: &[String],
        embedding_fn: F,
    ) -> Self
    where
        F: Fn(&str) -> Option<Vec<f32>>,
    {
        let mut shift = Self::from_type_sets(train_types, eval_types);
        if let Some(score) =
            compute_embedding_based_familiarity(train_types, eval_types, &embedding_fn)
        {
            shift.familiarity = score;
        }
        shift
    }
}
impl std::fmt::Display for LabelShift {
    // Compact one-line rendering for logs and reports.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "LabelShift(overlap={:.0}%, familiarity={:.2}, zero-shot={}, difficulty={})",
            self.overlap_ratio * 100.0,
            self.familiarity,
            self.true_zero_shot_count(),
            self.transfer_difficulty
        )
    }
}
pub use anno::metrics::types::CorefChainStats;
/// Coarse document-length bucket, classified by token count
/// (see `from_tokens` for the exact boundaries).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum DocumentScale {
// Up to 2k tokens.
#[default]
Short,
// 2k-10k tokens.
Medium,
// 10k-50k tokens.
Long,
// More than 50k tokens.
BookScale,
}
impl DocumentScale {
    /// Bucket a document by its token count.
    #[must_use]
    pub fn from_tokens(token_count: usize) -> Self {
        if token_count <= 2000 {
            Self::Short
        } else if token_count <= 10000 {
            Self::Medium
        } else if token_count <= 50000 {
            Self::Long
        } else {
            Self::BookScale
        }
    }

    /// True only for the largest bucket.
    #[must_use]
    pub fn is_book_scale(&self) -> bool {
        *self == Self::BookScale
    }

    /// Long and book-scale documents are flagged as potentially
    /// unreliable for standard metrics.
    #[must_use]
    pub fn metrics_may_be_unreliable(&self) -> bool {
        matches!(self, Self::Long | Self::BookScale)
    }

    /// Expected absolute score drop relative to short documents.
    #[must_use]
    pub fn expected_degradation(&self) -> f64 {
        match self {
            Self::Short => 0.0,
            Self::Medium => 0.05,
            Self::Long => 0.10,
            Self::BookScale => 0.15,
        }
    }
}
impl std::fmt::Display for DocumentScale {
    // Human-readable label including the token-count range.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Short => "Short (<2k tokens)",
            Self::Medium => "Medium (2k-10k tokens)",
            Self::Long => "Long (10k-50k tokens)",
            Self::BookScale => "Book-scale (>50k tokens)",
        };
        f.write_str(label)
    }
}
/// The three standard coreference F1 scores plus their pairwise absolute
/// gaps, used to spot when one metric tells a different story.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, Default)]
pub struct MetricDivergence {
pub muc_f1: f64,
pub b3_f1: f64,
pub ceaf_e_f1: f64,
// |muc_f1 - ceaf_e_f1|
pub muc_ceaf_divergence: f64,
// |muc_f1 - b3_f1|
pub muc_b3_divergence: f64,
// |b3_f1 - ceaf_e_f1|
pub b3_ceaf_divergence: f64,
}
impl MetricDivergence {
    /// Record the three F1 scores and precompute their pairwise gaps.
    #[must_use]
    pub fn from_scores(muc_f1: f64, b3_f1: f64, ceaf_e_f1: f64) -> Self {
        let gap = |a: f64, b: f64| (a - b).abs();
        Self {
            muc_f1,
            b3_f1,
            ceaf_e_f1,
            muc_ceaf_divergence: gap(muc_f1, ceaf_e_f1),
            muc_b3_divergence: gap(muc_f1, b3_f1),
            b3_ceaf_divergence: gap(b3_f1, ceaf_e_f1),
        }
    }

    /// MUC and CEAF-e disagree by more than 20 points.
    #[must_use]
    pub fn has_high_divergence(&self) -> bool {
        self.muc_ceaf_divergence > 0.20
    }

    /// MUC sits more than 15 points above CEAF-e.
    #[must_use]
    pub fn muc_likely_inflated(&self) -> bool {
        self.muc_f1 > self.ceaf_e_f1 + 0.15
    }

    /// CEAF-e sits well below both of the other metrics.
    #[must_use]
    pub fn ceaf_likely_collapsed(&self) -> bool {
        self.ceaf_e_f1 < self.b3_f1 - 0.15 && self.ceaf_e_f1 < self.muc_f1 - 0.20
    }

    /// Advise which metric to trust given the divergence pattern.
    #[must_use]
    pub fn most_reliable_metric(&self) -> &'static str {
        match (self.muc_likely_inflated(), self.ceaf_likely_collapsed()) {
            (true, true) => "B³ (MUC inflated, CEAF-e collapsed)",
            (true, false) => "B³ or CEAF-e (MUC inflated)",
            (false, true) => "MUC or B³ (CEAF-e collapsed)",
            (false, false) => "CoNLL F1 (metrics agree)",
        }
    }
}
/// Per-document coreference statistics: chain sizes, entity spreads,
/// and a heuristic breakdown of mention surface forms.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, Default)]
pub struct CorefDocStats {
// Document length in tokens. NOTE(review): `from_chains` leaves this at 0;
// presumably the caller fills it in separately — confirm before relying
// on `scale_classification`.
pub doc_length: usize,
// Number of coreference chains.
pub chain_count: usize,
// Total mentions across all chains.
pub mention_count: usize,
pub avg_chain_length: f64,
pub max_chain_length: usize,
// Spread = offset distance from a chain's earliest mention start to its
// latest mention end; singleton chains count as 0.
pub avg_entity_spread: usize,
pub max_entity_spread: usize,
pub median_entity_spread: usize,
// Fractions of mentions classified by surface form (heuristic).
pub pronoun_ratio: f64,
pub proper_ratio: f64,
pub nominal_ratio: f64,
// Fraction of chains with exactly one mention.
pub singleton_ratio: f64,
}
impl CorefDocStats {
/// Compute document statistics from a set of coreference chains.
///
/// Returns the all-zero default for an empty chain list. `doc_length`
/// cannot be derived from chains alone and is left at 0 here; callers
/// set it separately.
#[must_use]
pub fn from_chains(chains: &[crate::eval::coref::CorefChain]) -> Self {
if chains.is_empty() {
return Self::default();
}
let chain_count = chains.len();
let mention_count: usize = chains.iter().map(|c| c.mentions.len()).sum();
let avg_chain_length = mention_count as f64 / chain_count as f64;
let max_chain_length = chains.iter().map(|c| c.mentions.len()).max().unwrap_or(0);
let singleton_count = chains.iter().filter(|c| c.mentions.len() == 1).count();
let singleton_ratio = singleton_count as f64 / chain_count as f64;
// Spread of a chain = distance from its earliest mention start to its
// latest mention end; singletons contribute 0 (and dilute the average).
let mut spreads: Vec<usize> = Vec::with_capacity(chain_count);
for chain in chains {
if chain.mentions.len() <= 1 {
spreads.push(0);
continue;
}
let first_start = chain.mentions.iter().map(|m| m.start).min().unwrap_or(0);
let last_end = chain.mentions.iter().map(|m| m.end).max().unwrap_or(0);
let spread = last_end.saturating_sub(first_start);
spreads.push(spread);
}
// Integer division: the average spread truncates toward zero.
// (`spreads` is never empty here since `chains` is non-empty.)
let avg_entity_spread = if !spreads.is_empty() {
spreads.iter().sum::<usize>() / spreads.len()
} else {
0
};
let max_entity_spread = spreads.iter().copied().max().unwrap_or(0);
// Sort only for the median; avg and max were already taken above.
spreads.sort_unstable();
let median_entity_spread = if spreads.is_empty() {
0
} else {
spreads[spreads.len() / 2]
};
// Heuristic mention typing: a closed English pronoun list, then leading
// capitalization as a proxy for proper nouns, everything else nominal.
let mut pronoun_count = 0usize;
let mut proper_count = 0usize;
let mut nominal_count = 0usize;
for chain in chains {
for mention in &chain.mentions {
let text_lower = mention.text.to_lowercase();
let is_pronoun = matches!(
text_lower.as_str(),
"he" | "she"
| "it"
| "they"
| "him"
| "her"
| "them"
| "his"
| "hers"
| "its"
| "their"
| "i"
| "me"
| "we"
| "us"
| "you"
);
if is_pronoun {
pronoun_count += 1;
} else if mention
.text
.chars()
.next()
.is_some_and(|c| c.is_uppercase())
{
proper_count += 1;
} else {
nominal_count += 1;
}
}
}
// max(1) guards against division by zero if every chain were empty.
let total_mentions = mention_count.max(1) as f64;
let pronoun_ratio = pronoun_count as f64 / total_mentions;
let proper_ratio = proper_count as f64 / total_mentions;
let nominal_ratio = nominal_count as f64 / total_mentions;
Self {
// doc_length cannot be derived from chains; callers set it separately.
doc_length: 0, chain_count,
mention_count,
avg_chain_length,
max_chain_length,
avg_entity_spread,
max_entity_spread,
median_entity_spread,
pronoun_ratio,
proper_ratio,
nominal_ratio,
singleton_ratio,
}
}
/// Classify by `doc_length` (meaningful only once the caller has set it).
#[must_use]
pub fn scale_classification(&self) -> DocumentScale {
DocumentScale::from_tokens(self.doc_length)
}
/// Heuristic thresholds for "entities span book-scale distances".
#[must_use]
pub fn has_book_scale_spread(&self) -> bool {
self.avg_entity_spread > 5000 || self.max_entity_spread > 20000
}
/// One-line human-readable summary of the headline counts.
#[must_use]
pub fn format_summary(&self) -> String {
format!(
"Chains: {}, Mentions: {}, Avg length: {:.1}, Spread: avg={} max={}",
self.chain_count,
self.mention_count,
self.avg_chain_length,
self.avg_entity_spread,
self.max_entity_spread,
)
}
}
fn compute_string_based_familiarity(train_types: &[String], eval_types: &[String]) -> f64 {
if eval_types.is_empty() {
return 0.0;
}
let mut total_similarity = 0.0;
let mut counts = std::collections::HashMap::<String, usize>::new();
for eval_type in eval_types {
*counts.entry(eval_type.clone()).or_insert(0) += 1;
}
let total_eval_count = eval_types.len() as f64;
for (eval_type, freq) in counts {
let max_sim = train_types
.iter()
.map(|train_type| string_similarity(&eval_type, train_type))
.fold(0.0, f64::max);
let weight = freq as f64 / total_eval_count;
total_similarity += max_sim * weight;
}
total_similarity
}
fn compute_embedding_based_familiarity<F>(
train_types: &[String],
eval_types: &[String],
embedding_fn: &F,
) -> Option<f64>
where
F: Fn(&str) -> Option<Vec<f32>>,
{
if eval_types.is_empty() {
return Some(0.0);
}
let train_embeddings: Vec<(String, Vec<f32>)> = train_types
.iter()
.filter_map(|t| embedding_fn(t).map(|e| (t.clone(), e)))
.collect();
if train_embeddings.is_empty() {
return None; }
let mut counts = std::collections::HashMap::<String, usize>::new();
for eval_type in eval_types {
*counts.entry(eval_type.clone()).or_insert(0) += 1;
}
let total_eval_count = eval_types.len() as f64;
let mut total_similarity = 0.0;
for (eval_type, freq) in counts {
if let Some(eval_emb) = embedding_fn(&eval_type) {
let max_sim = train_embeddings
.iter()
.map(|(_, train_emb)| cosine_similarity(&eval_emb, train_emb))
.fold(0.0, f64::max);
let weight = freq as f64 / total_eval_count;
total_similarity += max_sim * weight;
} else {
let max_sim = train_types
.iter()
.map(|train_type| string_similarity(&eval_type, train_type))
.fold(0.0, f64::max);
let weight = freq as f64 / total_eval_count;
total_similarity += max_sim * weight;
}
}
Some(total_similarity)
}
/// Cosine similarity between two equal-length vectors, in `[-1.0, 1.0]`.
///
/// Returns 0.0 on length mismatch or when either vector has zero norm.
/// The previous implementation returned the raw dot product, which equals
/// cosine similarity only for unit-normalized inputs and could push the
/// downstream familiarity score outside `[0, 1]` for unnormalized
/// embeddings; this version divides by the product of the norms.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f64 {
    if a.len() != b.len() {
        return 0.0;
    }
    // Accumulate in f64 for precision.
    let mut dot = 0.0f64;
    let mut norm_a = 0.0f64;
    let mut norm_b = 0.0f64;
    for (&x, &y) in a.iter().zip(b.iter()) {
        let (x, y) = (f64::from(x), f64::from(y));
        dot += x * y;
        norm_a += x * x;
        norm_b += y * y;
    }
    let denom = norm_a.sqrt() * norm_b.sqrt();
    if denom == 0.0 {
        0.0
    } else {
        dot / denom
    }
}
/// Case-insensitive string similarity in `[0.0, 1.0]`.
///
/// Exact (case-folded) match scores 1.0, substring containment scores a
/// flat 0.8, otherwise length-normalized Levenshtein similarity.
fn string_similarity(a: &str, b: &str) -> f64 {
    let left = a.to_lowercase();
    let right = b.to_lowercase();
    if left == right {
        return 1.0;
    }
    if left.contains(&right) || right.contains(&left) {
        return 0.8;
    }
    // NOTE(review): `len()` is a byte length; for non-ASCII labels this may
    // disagree with the unit of `anno::edit_distance::levenshtein` — confirm
    // whether that distance is byte- or char-based.
    let max_len = left.len().max(right.len());
    if max_len == 0 {
        return 1.0;
    }
    let edits = levenshtein_distance(&left, &right);
    1.0 - (edits as f64 / max_len as f64)
}
/// Edit distance between `a` and `b`; thin wrapper over the shared
/// `anno::edit_distance::levenshtein` helper.
fn levenshtein_distance(a: &str, b: &str) -> usize {
anno::edit_distance::levenshtein(a, b)
}
#[cfg(test)]
mod tests {
// Unit tests covering MetricValue, MetricWithVariance, GoalCheck(Result),
// LabelShift familiarity, DocumentScale, MetricDivergence, and CorefDocStats.
use super::*;
// --- LabelShift / familiarity ---
#[test]
fn test_familiarity_computation() {
let train_types = vec![
"person".to_string(),
"organization".to_string(),
"location".to_string(),
];
let eval_types = vec![
"PERSON".to_string(), "ORG".to_string(), "DISEASE".to_string(), ];
let shift = LabelShift::from_type_sets(&train_types, &eval_types);
assert!(shift.familiarity > 0.0, "Should have non-zero familiarity");
assert!(
!shift.true_zero_shot_types.is_empty(),
"Should have at least 1 true zero-shot type"
);
assert!(shift.true_zero_shot_types.contains(&"DISEASE".to_string()));
}
#[test]
fn test_familiarity_inflation_detection() {
// Same labels up to casing: string similarity should score these high.
let train_types = vec![
"person".to_string(),
"organization".to_string(),
"location".to_string(),
];
let eval_types = vec![
"PERSON".to_string(),
"ORGANIZATION".to_string(),
"LOCATION".to_string(),
];
let shift = LabelShift::from_type_sets(&train_types, &eval_types);
assert!(shift.familiarity > 0.5, "Should have high familiarity");
}
#[test]
fn test_label_shift_zero_shot_types() {
let train_types = vec!["person".to_string()];
let eval_types = vec![
"person".to_string(),
"disease".to_string(),
"drug".to_string(),
];
let shift = LabelShift::from_type_sets(&train_types, &eval_types);
assert_eq!(shift.true_zero_shot_types.len(), 2);
assert!(shift.true_zero_shot_types.contains(&"disease".to_string()));
assert!(shift.true_zero_shot_types.contains(&"drug".to_string()));
}
// --- MetricValue ---
#[test]
fn test_metric_value_clamping() {
assert_eq!(MetricValue::new(0.5).get(), 0.5);
assert_eq!(MetricValue::new(-0.5).get(), 0.0);
assert_eq!(MetricValue::new(1.5).get(), 1.0);
}
#[test]
fn test_metric_value_try_new() {
assert!(MetricValue::try_new(0.5).is_ok());
assert!(MetricValue::try_new(-0.1).is_err());
assert!(MetricValue::try_new(1.1).is_err());
}
// --- GoalCheckResult ---
#[test]
fn test_goal_check_result() {
let mut result = GoalCheckResult::new();
assert!(result.passed);
result.add_check("precision", GoalCheck::pass(0.8, 0.85));
assert!(result.passed);
result.add_check("recall", GoalCheck::fail(0.9, 0.75));
assert!(!result.passed);
assert_eq!(result.passed_count(), 1);
assert_eq!(result.failed_count(), 1);
}
// --- MetricWithVariance ---
#[test]
fn test_metric_with_variance_from_samples() {
let samples = vec![0.85, 0.87, 0.82, 0.88, 0.84];
let m = MetricWithVariance::from_samples(&samples);
assert!((m.mean - 0.852).abs() < 0.001);
assert_eq!(m.n, 5);
assert!((m.min - 0.82).abs() < 0.001);
assert!((m.max - 0.88).abs() < 0.001);
assert!(m.std_dev > 0.0);
assert!(m.ci_95 > 0.0);
}
#[test]
fn test_metric_with_variance_empty() {
let m = MetricWithVariance::from_samples(&[]);
assert_eq!(m.n, 0);
assert_eq!(m.mean, 0.0);
assert_eq!(m.format_with_ci(), "N/A");
}
#[test]
fn test_metric_with_variance_single() {
// A single sample has no dispersion: std_dev and CI are zero.
let m = MetricWithVariance::from_samples(&[0.9]);
assert!((m.mean - 0.9).abs() < 0.001);
assert_eq!(m.std_dev, 0.0);
assert_eq!(m.ci_95, 0.0);
assert_eq!(m.n, 1);
}
#[test]
fn test_metric_with_variance_format() {
let samples = vec![0.85, 0.87, 0.82, 0.88, 0.84];
let m = MetricWithVariance::from_samples(&samples);
let formatted = m.format_with_ci();
assert!(formatted.contains("%"));
assert!(formatted.contains("±"));
let range = m.format_with_range();
assert!(range.contains("82.0%"));
assert!(range.contains("88.0%"));
}
// --- DocumentScale ---
#[test]
fn test_document_scale_classification() {
// Boundary values: each bucket is inclusive of its upper bound.
assert_eq!(DocumentScale::from_tokens(500), DocumentScale::Short);
assert_eq!(DocumentScale::from_tokens(2000), DocumentScale::Short);
assert_eq!(DocumentScale::from_tokens(2001), DocumentScale::Medium);
assert_eq!(DocumentScale::from_tokens(5000), DocumentScale::Medium);
assert_eq!(DocumentScale::from_tokens(10000), DocumentScale::Medium);
assert_eq!(DocumentScale::from_tokens(10001), DocumentScale::Long);
assert_eq!(DocumentScale::from_tokens(30000), DocumentScale::Long);
assert_eq!(DocumentScale::from_tokens(50000), DocumentScale::Long);
assert_eq!(DocumentScale::from_tokens(50001), DocumentScale::BookScale);
assert_eq!(DocumentScale::from_tokens(100000), DocumentScale::BookScale);
}
#[test]
fn test_document_scale_is_book_scale() {
assert!(!DocumentScale::Short.is_book_scale());
assert!(!DocumentScale::Medium.is_book_scale());
assert!(!DocumentScale::Long.is_book_scale());
assert!(DocumentScale::BookScale.is_book_scale());
}
#[test]
fn test_document_scale_metrics_reliability() {
assert!(!DocumentScale::Short.metrics_may_be_unreliable());
assert!(!DocumentScale::Medium.metrics_may_be_unreliable());
assert!(DocumentScale::Long.metrics_may_be_unreliable());
assert!(DocumentScale::BookScale.metrics_may_be_unreliable());
}
#[test]
fn test_document_scale_expected_degradation() {
assert!((DocumentScale::Short.expected_degradation() - 0.0).abs() < 0.001);
assert!((DocumentScale::Medium.expected_degradation() - 0.05).abs() < 0.001);
assert!((DocumentScale::Long.expected_degradation() - 0.10).abs() < 0.001);
assert!((DocumentScale::BookScale.expected_degradation() - 0.15).abs() < 0.001);
}
#[test]
fn test_document_scale_display() {
assert!(DocumentScale::Short.to_string().contains("Short"));
assert!(DocumentScale::BookScale.to_string().contains("Book-scale"));
}
// --- MetricDivergence ---
#[test]
fn test_metric_divergence_computation() {
let divergence = MetricDivergence::from_scores(0.90, 0.65, 0.45);
assert!((divergence.muc_f1 - 0.90).abs() < 0.001);
assert!((divergence.b3_f1 - 0.65).abs() < 0.001);
assert!((divergence.ceaf_e_f1 - 0.45).abs() < 0.001);
assert!((divergence.muc_ceaf_divergence - 0.45).abs() < 0.001);
}
#[test]
fn test_metric_divergence_high_divergence_detection() {
let high = MetricDivergence::from_scores(0.90, 0.70, 0.50);
assert!(high.has_high_divergence());
let low = MetricDivergence::from_scores(0.80, 0.75, 0.70);
assert!(!low.has_high_divergence());
}
#[test]
fn test_metric_divergence_muc_inflation() {
let inflated = MetricDivergence::from_scores(0.90, 0.70, 0.50);
assert!(inflated.muc_likely_inflated());
let not_inflated = MetricDivergence::from_scores(0.80, 0.75, 0.70);
assert!(!not_inflated.muc_likely_inflated());
}
#[test]
fn test_metric_divergence_ceaf_collapse() {
let collapsed = MetricDivergence::from_scores(0.90, 0.70, 0.40);
assert!(collapsed.ceaf_likely_collapsed());
let not_collapsed = MetricDivergence::from_scores(0.80, 0.75, 0.70);
assert!(!not_collapsed.ceaf_likely_collapsed());
}
#[test]
fn test_metric_divergence_recommendation() {
let both_bad = MetricDivergence::from_scores(0.90, 0.65, 0.40);
assert!(both_bad.most_reliable_metric().contains("B³"));
let agree = MetricDivergence::from_scores(0.75, 0.73, 0.71);
assert!(agree.most_reliable_metric().contains("CoNLL"));
}
// --- CorefDocStats ---
#[test]
fn test_coref_doc_stats_default() {
let stats = CorefDocStats::default();
assert_eq!(stats.chain_count, 0);
assert_eq!(stats.mention_count, 0);
assert_eq!(stats.avg_entity_spread, 0);
assert_eq!(stats.max_entity_spread, 0);
}
#[test]
fn test_coref_doc_stats_scale_classification() {
let mut stats = CorefDocStats {
doc_length: 1000,
..Default::default()
};
assert_eq!(stats.scale_classification(), DocumentScale::Short);
stats.doc_length = 5000;
assert_eq!(stats.scale_classification(), DocumentScale::Medium);
stats.doc_length = 30000;
assert_eq!(stats.scale_classification(), DocumentScale::Long);
stats.doc_length = 100000;
assert_eq!(stats.scale_classification(), DocumentScale::BookScale);
}
#[test]
fn test_coref_doc_stats_book_scale_spread() {
let mut stats = CorefDocStats {
avg_entity_spread: 1000,
max_entity_spread: 5000,
..Default::default()
};
assert!(!stats.has_book_scale_spread());
stats.avg_entity_spread = 6000;
stats.max_entity_spread = 10000;
assert!(stats.has_book_scale_spread());
stats.avg_entity_spread = 2000;
stats.max_entity_spread = 25000;
assert!(stats.has_book_scale_spread());
}
#[test]
fn test_coref_doc_stats_format_summary() {
let stats = CorefDocStats {
chain_count: 159,
mention_count: 13178,
avg_chain_length: 82.9,
avg_entity_spread: 17529,
max_entity_spread: 115369,
..Default::default()
};
let summary = stats.format_summary();
assert!(summary.contains("159"));
assert!(summary.contains("13178"));
assert!(summary.contains("17529"));
assert!(summary.contains("115369"));
}
#[test]
fn test_coref_doc_stats_from_chains() {
use crate::eval::coref::{CorefChain, Mention};
let chains = vec![
CorefChain::new(vec![
Mention::new("John", 0, 4),
Mention::new("he", 20, 22),
Mention::new("him", 50, 53),
]),
CorefChain::new(vec![
Mention::new("Mary", 5, 9),
Mention::new("she", 30, 33),
]),
CorefChain::new(vec![Mention::new("London", 60, 66)]),
];
let stats = CorefDocStats::from_chains(&chains);
assert_eq!(stats.chain_count, 3);
assert_eq!(stats.mention_count, 6);
assert!((stats.avg_chain_length - 2.0).abs() < 0.01);
assert_eq!(stats.max_chain_length, 3);
assert!(stats.avg_entity_spread > 0);
assert!(stats.max_entity_spread >= 53);
assert!((stats.singleton_ratio - 0.333).abs() < 0.01);
}
#[test]
fn test_coref_doc_stats_mention_type_ratios() {
use crate::eval::coref::{CorefChain, Mention};
let chains = vec![
CorefChain::new(vec![
Mention::new("John", 0, 4), Mention::new("he", 10, 12), Mention::new("him", 20, 23), ]),
CorefChain::new(vec![
Mention::new("Mary", 30, 34), Mention::new("she", 40, 43), ]),
];
let stats = CorefDocStats::from_chains(&chains);
assert!(stats.pronoun_ratio > 0.5, "Should have majority pronouns");
assert!(stats.proper_ratio > 0.3, "Should have some proper nouns");
assert_eq!(stats.mention_count, 5);
}
}