use crate::eval::synthetic::AnnotatedExample;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
/// Aggregate statistics describing a single annotated dataset.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatasetStats {
    /// Number of annotated examples in the dataset.
    pub num_examples: usize,
    /// Total number of entity mentions across all examples.
    pub num_entities: usize,
    /// Relative frequency of each entity type; values sum to 1.0 when
    /// entities are present.
    pub type_distribution: HashMap<String, f64>,
    /// Mean number of entity mentions per example.
    pub avg_entities_per_example: f64,
    /// Number of distinct lowercased whitespace-separated tokens.
    pub vocab_size: usize,
    /// Distribution of entity lengths, measured in whitespace tokens.
    pub entity_length_stats: LengthStats,
    /// Number of distinct lowercased entity surface forms.
    pub unique_entity_texts: usize,
    /// `unique_entity_texts / num_entities`; 1.0 means no repeated entities
    /// (also the value reported for an empty dataset).
    pub entity_diversity: f64,
}
/// Descriptive statistics over a collection of lengths (in tokens).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LengthStats {
    /// Arithmetic mean length.
    pub mean: f64,
    /// Median length.
    pub median: f64,
    /// Population standard deviation.
    pub std_dev: f64,
    /// Smallest observed length (0 when there are no observations).
    pub min: usize,
    /// Largest observed length (0 when there are no observations).
    pub max: usize,
}
/// Result of comparing two datasets for transfer-learning compatibility.
///
/// Dataset A is treated as the source and dataset B as the target
/// (matching the wording used in `recommendations`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatasetComparison {
    /// Statistics for dataset A (source).
    pub stats_a: DatasetStats,
    /// Statistics for dataset B (target).
    pub stats_b: DatasetStats,
    /// Jensen-Shannon divergence between type distributions, in [0, 1].
    pub type_divergence: f64,
    /// Jaccard overlap of lowercased token vocabularies, in [0, 1].
    pub vocab_overlap: f64,
    /// Jaccard overlap of lowercased entity surface forms, in [0, 1].
    pub entity_text_overlap: f64,
    /// Entity types that occur only in dataset A.
    pub types_only_in_a: Vec<String>,
    /// Entity types that occur only in dataset B.
    pub types_only_in_b: Vec<String>,
    /// Weighted blend of divergence and (1 - overlap) signals, in [0, 1].
    pub estimated_domain_gap: f64,
    /// Human-readable transfer-learning guidance derived from the metrics.
    pub recommendations: Vec<String>,
}
/// Compute summary statistics for an annotated dataset.
///
/// Tokens and entity texts are lowercased before counting, so `vocab_size`,
/// `unique_entity_texts`, and `entity_diversity` are case-insensitive.
/// An empty slice yields all-zero statistics with `entity_diversity = 1.0`.
pub fn compute_stats(examples: &[AnnotatedExample]) -> DatasetStats {
    if examples.is_empty() {
        return DatasetStats {
            num_examples: 0,
            num_entities: 0,
            type_distribution: HashMap::new(),
            avg_entities_per_example: 0.0,
            vocab_size: 0,
            entity_length_stats: LengthStats {
                mean: 0.0,
                median: 0.0,
                std_dev: 0.0,
                min: 0,
                max: 0,
            },
            unique_entity_texts: 0,
            // Vacuously "fully diverse": no entity is repeated.
            entity_diversity: 1.0,
        };
    }
    let mut type_counts: HashMap<String, usize> = HashMap::new();
    let mut vocab: HashSet<String> = HashSet::new();
    let mut entity_texts: HashSet<String> = HashSet::new();
    let mut entity_lengths: Vec<usize> = Vec::new();
    let mut total_entities = 0;
    for example in examples {
        for token in example.text.split_whitespace() {
            vocab.insert(token.to_lowercase());
        }
        for entity in &example.entities {
            total_entities += 1;
            *type_counts
                .entry(entity.entity_type.to_string())
                .or_insert(0) += 1;
            entity_texts.insert(entity.text.to_lowercase());
            // Entities whose text has no whitespace tokens still count as length 1.
            let token_count = entity.text.split_whitespace().count().max(1);
            entity_lengths.push(token_count);
        }
    }
    // Normalize raw counts into a probability distribution over types.
    let type_distribution: HashMap<String, f64> = type_counts
        .iter()
        .map(|(t, c)| (t.clone(), *c as f64 / total_entities.max(1) as f64))
        .collect();
    let entity_length_stats = if entity_lengths.is_empty() {
        LengthStats {
            mean: 0.0,
            median: 0.0,
            std_dev: 0.0,
            min: 0,
            max: 0,
        }
    } else {
        let mut sorted = entity_lengths.clone();
        sorted.sort_unstable();
        let mean = entity_lengths.iter().sum::<usize>() as f64 / entity_lengths.len() as f64;
        // True median: for even-length inputs, average the two middle values
        // (previously the upper median was used, biasing the statistic upward).
        let mid = sorted.len() / 2;
        let median = if sorted.len() % 2 == 0 {
            (sorted[mid - 1] + sorted[mid]) as f64 / 2.0
        } else {
            sorted[mid] as f64
        };
        // Population variance (divide by n, not n - 1).
        let variance = entity_lengths
            .iter()
            .map(|&l| (l as f64 - mean).powi(2))
            .sum::<f64>()
            / entity_lengths.len() as f64;
        let std_dev = variance.sqrt();
        LengthStats {
            mean,
            median,
            std_dev,
            min: *sorted.first().unwrap_or(&0),
            max: *sorted.last().unwrap_or(&0),
        }
    };
    DatasetStats {
        num_examples: examples.len(),
        num_entities: total_entities,
        type_distribution,
        avg_entities_per_example: total_entities as f64 / examples.len() as f64,
        vocab_size: vocab.len(),
        entity_length_stats,
        unique_entity_texts: entity_texts.len(),
        // Fraction of entity mentions that are distinct surface forms.
        entity_diversity: entity_texts.len() as f64 / total_entities.max(1) as f64,
    }
}
/// Compare two annotated datasets (A = source, B = target) and produce
/// overlap metrics, a type-distribution divergence, an estimated domain
/// gap, and textual recommendations for transfer learning.
pub fn compare_datasets(a: &[AnnotatedExample], b: &[AnnotatedExample]) -> DatasetComparison {
    // Case-insensitive whitespace-token vocabulary of a dataset.
    fn token_vocab(examples: &[AnnotatedExample]) -> HashSet<String> {
        examples
            .iter()
            .flat_map(|e| e.text.split_whitespace().map(|t| t.to_lowercase()))
            .collect()
    }
    // Case-insensitive set of entity surface forms in a dataset.
    fn entity_set(examples: &[AnnotatedExample]) -> HashSet<String> {
        examples
            .iter()
            .flat_map(|e| e.entities.iter().map(|ent| ent.text.to_lowercase()))
            .collect()
    }
    // Jaccard similarity; two empty sets are defined as fully overlapping.
    fn jaccard(x: &HashSet<String>, y: &HashSet<String>) -> f64 {
        let union = x.union(y).count();
        if union == 0 {
            1.0
        } else {
            x.intersection(y).count() as f64 / union as f64
        }
    }

    let stats_a = compute_stats(a);
    let stats_b = compute_stats(b);

    let vocab_overlap = jaccard(&token_vocab(a), &token_vocab(b));
    let entity_text_overlap = jaccard(&entity_set(a), &entity_set(b));

    let type_divergence =
        jensen_shannon_divergence(&stats_a.type_distribution, &stats_b.type_distribution);

    let types_a: HashSet<&String> = stats_a.type_distribution.keys().collect();
    let types_b: HashSet<&String> = stats_b.type_distribution.keys().collect();
    let types_only_in_a: Vec<String> = types_a.difference(&types_b).map(|s| (*s).clone()).collect();
    let types_only_in_b: Vec<String> = types_b.difference(&types_a).map(|s| (*s).clone()).collect();

    // Weighted blend of the three gap signals; weights sum to 1.0.
    let estimated_domain_gap =
        0.4 * type_divergence + 0.3 * (1.0 - vocab_overlap) + 0.3 * (1.0 - entity_text_overlap);

    let recommendations = generate_recommendations(
        type_divergence,
        vocab_overlap,
        entity_text_overlap,
        &types_only_in_a,
        &types_only_in_b,
    );

    DatasetComparison {
        stats_a,
        stats_b,
        type_divergence,
        vocab_overlap,
        entity_text_overlap,
        types_only_in_a,
        types_only_in_b,
        estimated_domain_gap,
        recommendations,
    }
}
/// Jensen-Shannon divergence between two discrete distributions keyed by
/// label, normalized by ln(2) so the result lies in [0, 1] for probability
/// distributions (0 = identical, 1 = disjoint support).
///
/// Labels absent from one distribution are treated as probability 0.
/// Two empty maps yield 0.0.
fn jensen_shannon_divergence(p: &HashMap<String, f64>, q: &HashMap<String, f64>) -> f64 {
    let support: HashSet<&String> = p.keys().chain(q.keys()).collect();
    if support.is_empty() {
        return 0.0;
    }
    // KL(d || m) over the joint support, where m is the pointwise mixture
    // (p + q) / 2. Terms with d(k) <= 0 contribute nothing by the usual
    // 0 * ln(0) = 0 convention; whenever d(k) > 0 the mixture is > 0 too,
    // so the ratio is well-defined.
    let kl_to_mixture = |d: &HashMap<String, f64>| -> f64 {
        support
            .iter()
            .filter_map(|k| {
                let d_val = d.get(*k).copied().unwrap_or(0.0);
                if d_val <= 0.0 {
                    return None;
                }
                let mix = (p.get(*k).copied().unwrap_or(0.0)
                    + q.get(*k).copied().unwrap_or(0.0))
                    / 2.0;
                Some(d_val * (d_val / mix).ln())
            })
            .sum()
    };
    // Average the two KL terms, then convert from nats to bits.
    ((kl_to_mixture(p) + kl_to_mixture(q)) / 2.0) / 2.0_f64.ln()
}
/// Translate comparison metrics into human-readable transfer-learning advice.
///
/// Each triggered heuristic contributes one message, in a fixed order; if
/// nothing triggers, a single "compatible" message is returned so the list
/// is never empty.
fn generate_recommendations(
    type_div: f64,
    vocab_overlap: f64,
    entity_overlap: f64,
    types_only_a: &[String],
    types_only_b: &[String],
) -> Vec<String> {
    // Type-distribution divergence tiers, checked highest first.
    let type_note = if type_div > 0.5 {
        Some("High type distribution divergence - consider domain adaptation".to_string())
    } else if type_div > 0.2 {
        Some("Moderate type divergence - transfer learning may require fine-tuning".to_string())
    } else {
        None
    };
    let vocab_note = (vocab_overlap < 0.3)
        .then(|| "Low vocabulary overlap - domains use different terminology".to_string());
    let entity_note = (entity_overlap < 0.1)
        .then(|| "Very few shared entities - gazetteer transfer unlikely to help".to_string());
    let only_a_note = (!types_only_a.is_empty()).then(|| {
        format!(
            "Types in source only: {:?} - target may not need these",
            types_only_a
        )
    });
    let only_b_note = (!types_only_b.is_empty()).then(|| {
        format!(
            "Types in target only: {:?} - source cannot help with these",
            types_only_b
        )
    });
    let advice: Vec<String> = [type_note, vocab_note, entity_note, only_a_note, only_b_note]
        .into_iter()
        .flatten()
        .collect();
    if advice.is_empty() {
        vec!["Datasets appear compatible for transfer learning".to_string()]
    } else {
        advice
    }
}
/// Heuristically estimate learning difficulty from dataset statistics.
///
/// Each triggered heuristic adjusts a raw score and records a factor
/// string. The raw score buckets the difficulty tier; the stored `score`
/// is clamped to [0, 1] (the raw value can go slightly negative when low
/// diversity subtracts from it).
pub fn estimate_difficulty(stats: &DatasetStats) -> DifficultyEstimate {
    let mut factors: Vec<String> = Vec::new();
    let mut raw: f64 = 0.0;

    // Label-space size: more entity types are harder to learn.
    let type_count = stats.type_distribution.len();
    if type_count > 10 {
        factors.push("Many entity types (>10)".into());
        raw += 0.2;
    } else if type_count > 5 {
        factors.push("Moderate entity types (5-10)".into());
        raw += 0.1;
    }

    // Long or highly variable spans are harder to delimit exactly.
    if stats.entity_length_stats.mean > 3.0 {
        factors.push("Long average entity length (>3 tokens)".into());
        raw += 0.2;
    }
    if stats.entity_length_stats.std_dev > 2.0 {
        factors.push("High entity length variance".into());
        raw += 0.1;
    }

    // Diversity: mostly-unique mentions are harder; heavy repetition helps.
    if stats.entity_diversity > 0.9 {
        factors.push("High entity diversity (few repeated entities)".into());
        raw += 0.2;
    } else if stats.entity_diversity < 0.3 {
        factors.push("Low entity diversity (model can memorize)".into());
        raw -= 0.1;
    }

    // Sparse annotation gives the model few positive examples.
    if stats.avg_entities_per_example < 1.0 {
        factors.push("Few entities per example (<1 avg)".into());
        raw += 0.1;
    }

    // Bucket the raw score into a coarse tier.
    let difficulty = if raw < 0.2 {
        EstimatedDifficulty::Easy
    } else if raw < 0.4 {
        EstimatedDifficulty::Medium
    } else if raw < 0.6 {
        EstimatedDifficulty::Hard
    } else {
        EstimatedDifficulty::VeryHard
    };

    DifficultyEstimate {
        difficulty,
        score: raw.clamp(0.0, 1.0),
        factors,
    }
}
/// Coarse difficulty tier derived from the heuristic score in
/// `estimate_difficulty`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum EstimatedDifficulty {
    /// Score below 0.2.
    Easy,
    /// Score in [0.2, 0.4).
    Medium,
    /// Score in [0.4, 0.6).
    Hard,
    /// Score 0.6 or above.
    VeryHard,
}
/// Output of `estimate_difficulty`: a tier, a clamped numeric score, and
/// the list of heuristics that fired.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DifficultyEstimate {
    /// Coarse difficulty bucket.
    pub difficulty: EstimatedDifficulty,
    /// Heuristic score clamped to [0, 1].
    pub score: f64,
    /// Human-readable descriptions of the heuristics that contributed.
    pub factors: Vec<String>,
}
/// Discourse-level statistics for a dataset.
/// Only available with the `discourse` feature.
#[cfg(feature = "discourse")]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DiscourseStats {
    /// Occurrences of abstract-anaphor surface patterns ("this", "that",
    /// "these", "those", "it" with adjacent space/punctuation).
    pub abstract_anaphor_count: usize,
    /// Total event triggers found by the discourse event extractor.
    pub event_trigger_count: usize,
    /// Whitespace tokens classified as shell nouns.
    pub shell_noun_count: usize,
    /// Mean whitespace tokens per sentence across the dataset.
    pub avg_sentence_length: f64,
    /// Number of examples containing more than one sentence.
    pub multi_sentence_examples: usize,
    /// Weighted blend of the above signals, clamped to [0, 1].
    pub discourse_complexity: f64,
}
/// Compute discourse-level statistics (abstract anaphora, event triggers,
/// shell nouns, sentence structure) over a dataset.
///
/// Anaphor detection is purely lexical substring matching on lowercased
/// text; event and shell-noun detection delegate to the `discourse` module.
/// Only available with the `discourse` feature.
#[cfg(feature = "discourse")]
pub fn compute_discourse_stats(examples: &[AnnotatedExample]) -> DiscourseStats {
    use crate::discourse::{classify_shell_noun, DiscourseScope, EventExtractor};
    // Empty input: all counts zero, complexity zero.
    if examples.is_empty() {
        return DiscourseStats {
            abstract_anaphor_count: 0,
            event_trigger_count: 0,
            shell_noun_count: 0,
            avg_sentence_length: 0.0,
            multi_sentence_examples: 0,
            discourse_complexity: 0.0,
        };
    }
    let extractor = EventExtractor::default();
    let mut abstract_anaphor_count = 0;
    let mut event_trigger_count = 0;
    let mut shell_noun_count = 0;
    let mut total_sentences = 0;
    let mut multi_sentence_examples = 0;
    // Surface cues for abstract anaphora. Trailing space/punctuation keeps
    // matches word-boundary-ish ("this " won't match "thistle"); the
    // patterns are mutually exclusive at any position, so summing per-pattern
    // counts never double-counts an occurrence.
    // NOTE(review): " it " requires a leading space, so "It" at the start of
    // a text is not counted - confirm this is intentional.
    let anaphor_patterns = [
        "this ", "that ", "these ", "those ", "this.", "that.", "this,", "that,", " it ", " it.",
        " it,",
    ];
    for example in examples {
        let text_lower = example.text.to_lowercase();
        for pattern in &anaphor_patterns {
            abstract_anaphor_count += text_lower.matches(pattern).count();
        }
        // Event triggers via the project's extractor.
        let events = extractor.extract(&example.text);
        event_trigger_count += events.len();
        for word in example.text.split_whitespace() {
            // Strip surrounding non-alphabetic characters before lookup.
            let word_clean = word.trim_matches(|c: char| !c.is_alphabetic());
            if classify_shell_noun(word_clean).is_some() {
                shell_noun_count += 1;
            }
        }
        let scope = DiscourseScope::analyze(&example.text);
        // Guard against a zero sentence count from the analyzer.
        let num_sentences = scope.sentence_count().max(1);
        total_sentences += num_sentences;
        if num_sentences > 1 {
            multi_sentence_examples += 1;
        }
    }
    // Mean tokens per sentence across the whole dataset.
    let avg_sentence_length = examples
        .iter()
        .map(|e| e.text.split_whitespace().count())
        .sum::<usize>() as f64
        / total_sentences.max(1) as f64;
    // Weighted blend of per-example rates, each capped at 1.0 before
    // weighting (shell nouns are halved first); weights sum to 1.0.
    let complexity = ((abstract_anaphor_count as f64 / examples.len() as f64).min(1.0) * 0.3
        + (event_trigger_count as f64 / examples.len() as f64).min(1.0) * 0.3
        + (shell_noun_count as f64 / examples.len() as f64 / 2.0).min(1.0) * 0.2
        + (multi_sentence_examples as f64 / examples.len() as f64) * 0.2)
        .clamp(0.0, 1.0);
    DiscourseStats {
        abstract_anaphor_count,
        event_trigger_count,
        shell_noun_count,
        avg_sentence_length,
        multi_sentence_examples,
        discourse_complexity: complexity,
    }
}
/// Dataset comparison extended with discourse-level statistics.
/// Only available with the `discourse` feature.
#[cfg(feature = "discourse")]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExtendedDatasetComparison {
    /// The standard (non-discourse) comparison.
    pub basic: DatasetComparison,
    /// Discourse statistics for dataset A (source).
    pub discourse_a: DiscourseStats,
    /// Discourse statistics for dataset B (target).
    pub discourse_b: DiscourseStats,
    /// Absolute difference between the two discourse complexity scores.
    pub discourse_gap: f64,
    /// Discourse-specific transfer-learning guidance.
    pub discourse_recommendations: Vec<String>,
}
/// Compare two datasets including discourse-level characteristics.
///
/// Runs the standard comparison, computes discourse statistics for both
/// sides, reports the absolute complexity gap, and emits discourse-specific
/// recommendations (or a single "similar" message when none apply).
/// Only available with the `discourse` feature.
#[cfg(feature = "discourse")]
pub fn compare_datasets_extended(
    a: &[AnnotatedExample],
    b: &[AnnotatedExample],
) -> ExtendedDatasetComparison {
    let basic = compare_datasets(a, b);
    let discourse_a = compute_discourse_stats(a);
    let discourse_b = compute_discourse_stats(b);
    let discourse_gap = (discourse_a.discourse_complexity - discourse_b.discourse_complexity).abs();

    let mut notes: Vec<String> = Vec::new();
    if discourse_gap > 0.3 {
        notes.push(
            "Significant discourse complexity difference - models may struggle with transfer"
                .into(),
        );
    }
    if discourse_a.event_trigger_count > 0 && discourse_b.event_trigger_count == 0 {
        notes.push(
            "Source has event triggers but target doesn't - event extraction may not transfer"
                .into(),
        );
    }
    // "More than double" heuristic for anaphora imbalance between sides.
    if discourse_a.abstract_anaphor_count > discourse_b.abstract_anaphor_count * 2 {
        notes.push("Source has more abstract anaphora - coreference may not generalize".into());
    }
    if discourse_a.multi_sentence_examples > 0 && discourse_b.multi_sentence_examples == 0 {
        notes.push("Target is single-sentence only - cross-sentence phenomena won't appear".into());
    }
    if notes.is_empty() {
        notes.push("Discourse characteristics are similar between datasets".into());
    }

    ExtendedDatasetComparison {
        basic,
        discourse_a,
        discourse_b,
        discourse_gap,
        discourse_recommendations: notes,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an `AnnotatedExample` from raw text plus (entity_text, type) pairs.
    /// Entity spans are located with `str::find`; pairs whose text does not
    /// occur in `text` are silently skipped.
    fn make_example(text: &str, entities: Vec<(&str, &str)>) -> AnnotatedExample {
        use crate::eval::datasets::GoldEntity;
        use crate::eval::synthetic::{Difficulty, Domain};
        use anno::{EntityCategory, EntityType};
        let mut gold_entities = Vec::new();
        for (entity_text, entity_type_str) in entities {
            if let Some(start) = text.find(entity_text) {
                // Map the shorthand tag to a concrete entity type; unknown
                // tags become custom Misc types.
                let entity_type = match entity_type_str {
                    "PER" => EntityType::Person,
                    "ORG" => EntityType::Organization,
                    "LOC" => EntityType::Location,
                    _ => EntityType::custom(entity_type_str, EntityCategory::Misc),
                };
                gold_entities.push(GoldEntity::new(entity_text, entity_type, start));
            }
        }
        AnnotatedExample {
            text: text.to_string(),
            entities: gold_entities,
            domain: Domain::News,
            difficulty: Difficulty::Easy,
        }
    }

    // Empty input produces zeroed stats.
    #[test]
    fn test_compute_stats_empty() {
        let stats = compute_stats(&[]);
        assert_eq!(stats.num_examples, 0);
        assert_eq!(stats.num_entities, 0);
    }

    // Counts, average, and type keys over a small two-example dataset.
    #[test]
    fn test_compute_stats_basic() {
        let examples = vec![
            make_example(
                "John works at Google.",
                vec![("John", "PER"), ("Google", "ORG")],
            ),
            make_example(
                "Paris is in France.",
                vec![("Paris", "LOC"), ("France", "LOC")],
            ),
        ];
        let stats = compute_stats(&examples);
        assert_eq!(stats.num_examples, 2);
        assert_eq!(stats.num_entities, 4);
        assert_eq!(stats.avg_entities_per_example, 2.0);
        assert!(stats.type_distribution.contains_key("PER"));
        assert!(stats.type_distribution.contains_key("ORG"));
        assert!(stats.type_distribution.contains_key("LOC"));
    }

    // A dataset compared with itself: zero divergence, full overlap.
    #[test]
    fn test_compare_identical_datasets() {
        let examples = vec![make_example(
            "John works at Google.",
            vec![("John", "PER"), ("Google", "ORG")],
        )];
        let comparison = compare_datasets(&examples, &examples);
        assert!(comparison.type_divergence < 0.01);
        assert!((comparison.vocab_overlap - 1.0).abs() < 0.01);
        assert!((comparison.entity_text_overlap - 1.0).abs() < 0.01);
    }

    // Disjoint types and entities: high divergence, low overlap.
    #[test]
    fn test_compare_different_datasets() {
        let a = vec![make_example("John works at Google.", vec![("John", "PER")])];
        let b = vec![make_example("Paris is beautiful.", vec![("Paris", "LOC")])];
        let comparison = compare_datasets(&a, &b);
        assert!(comparison.type_divergence > 0.5);
        assert!(comparison.vocab_overlap < 0.5);
        assert!((comparison.entity_text_overlap - 0.0).abs() < 0.01);
    }

    // JSD of a distribution with itself is ~0.
    #[test]
    fn test_jensen_shannon_identical() {
        let mut p = HashMap::new();
        p.insert("A".into(), 0.5);
        p.insert("B".into(), 0.5);
        let js = jensen_shannon_divergence(&p, &p);
        assert!(js < 0.01);
    }

    // JSD of disjoint distributions is ~1 after ln(2) normalization.
    #[test]
    fn test_jensen_shannon_disjoint() {
        let mut p = HashMap::new();
        p.insert("A".into(), 1.0);
        let mut q = HashMap::new();
        q.insert("B".into(), 1.0);
        let js = jensen_shannon_divergence(&p, &q);
        assert!(js > 0.9);
    }

    // Long, diverse entities should never score easier than short,
    // repeated ones.
    #[test]
    fn test_difficulty_estimation() {
        let easy_examples = vec![
            make_example("John works here.", vec![("John", "PER")]),
            make_example("John went home.", vec![("John", "PER")]),
        ];
        let hard_examples = vec![make_example(
            "International Business Machines Corporation announced.",
            vec![("International Business Machines Corporation", "ORG")],
        )];
        let easy_stats = compute_stats(&easy_examples);
        let hard_stats = compute_stats(&hard_examples);
        let easy_diff = estimate_difficulty(&easy_stats);
        let hard_diff = estimate_difficulty(&hard_stats);
        assert!(hard_diff.score >= easy_diff.score);
    }

    // Empty input produces zeroed discourse counts.
    #[test]
    #[cfg(feature = "discourse")]
    fn test_discourse_stats_empty() {
        let stats = compute_discourse_stats(&[]);
        assert_eq!(stats.abstract_anaphor_count, 0);
        assert_eq!(stats.event_trigger_count, 0);
        assert_eq!(stats.shell_noun_count, 0);
    }

    // "This"/"That" sentence-initial anaphors and event verbs are counted.
    #[test]
    #[cfg(feature = "discourse")]
    fn test_discourse_stats_with_anaphors() {
        let examples = vec![
            make_example(
                "Russia invaded Ukraine. This caused inflation.",
                vec![("Russia", "LOC")],
            ),
            make_example(
                "The merger was announced. That surprised investors.",
                vec![],
            ),
        ];
        let stats = compute_discourse_stats(&examples);
        assert!(
            stats.abstract_anaphor_count >= 2,
            "Should detect abstract anaphors"
        );
        assert!(
            stats.event_trigger_count >= 2,
            "Should detect event triggers"
        );
        assert_eq!(stats.multi_sentence_examples, 2);
    }

    // "problem", "fact", "situation" are expected shell nouns.
    #[test]
    #[cfg(feature = "discourse")]
    fn test_discourse_stats_with_shell_nouns() {
        let examples = vec![
            make_example("This problem is serious.", vec![]),
            make_example("The fact is clear.", vec![]),
            make_example("The situation is complex.", vec![]),
        ];
        let stats = compute_discourse_stats(&examples);
        assert!(stats.shell_noun_count >= 3, "Should detect shell nouns");
    }

    // Multi-sentence, anaphor-rich text scores higher discourse complexity.
    #[test]
    #[cfg(feature = "discourse")]
    fn test_extended_comparison() {
        let simple = vec![make_example(
            "John works at Google.",
            vec![("John", "PER"), ("Google", "ORG")],
        )];
        let complex = vec![
            make_example(
                "Russia invaded Ukraine in 2022. This caused a global energy crisis. The situation remains tense.",
                vec![("Russia", "LOC"), ("Ukraine", "LOC")]
            ),
        ];
        let comparison = compare_datasets_extended(&simple, &complex);
        assert!(
            comparison.discourse_b.discourse_complexity
                > comparison.discourse_a.discourse_complexity,
            "Complex dataset should have higher discourse complexity"
        );
        assert!(comparison.discourse_gap > 0.0);
    }

    // Complexity stays within its documented [0, 1] clamp even for
    // signal-dense input.
    #[test]
    #[cfg(feature = "discourse")]
    fn test_discourse_complexity_bounds() {
        let examples = vec![
            make_example(
                "This problem happened. That event occurred. This situation developed. The fact emerged.",
                vec![]
            ),
        ];
        let stats = compute_discourse_stats(&examples);
        assert!(stats.discourse_complexity >= 0.0);
        assert!(stats.discourse_complexity <= 1.0);
    }
}