//! Evaluation toolkit for `anno` models: NER, coreference, relation, event, and
//! discontinuous-entity scoring, plus dataset loaders, bias probes, calibration,
//! robustness analyses, and report generation (many behind the `eval` and
//! `eval-bias` features).

use anno::{EntityCategory, EntityType};
use anno::{Error, Model, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::Path;
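/// Evaluation task definition: which annotations are scored and how.
///
/// Each variant carries its task-specific configuration (label or relation
/// inventories, matching mode, coreference metrics). The `Default`
/// implementation below selects strict-mode NER over the CoNLL-2003 label set
/// (`PER`, `ORG`, `LOC`, `MISC`).
///
/// Illustrative construction (a sketch; adjust the import path to wherever
/// this crate lives in your workspace):
///
/// ```ignore
/// let task = EvalTask::NER {
///     labels: vec!["PER".into(), "ORG".into()],
///     mode: EvalMode::Strict,
/// };
/// ```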
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum EvalTask {
NER {
labels: Vec<String>,
mode: EvalMode,
},
RelationExtraction {
relations: Vec<String>,
require_entity_match: bool,
},
Coreference {
metrics: Vec<CorefMetric>,
},
DiscontinuousNER {
labels: Vec<String>,
},
EventExtraction {
event_types: Vec<String>,
argument_roles: Vec<String>,
},
}
impl Default for EvalTask {
fn default() -> Self {
EvalTask::NER {
labels: vec![
"PER".to_string(),
"ORG".to_string(),
"LOC".to_string(),
"MISC".to_string(),
],
mode: EvalMode::Strict,
}
}
}
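/// Coreference metrics that can be requested for an [`EvalTask::Coreference`] run.
///
/// The variants cover the standard families: link-based MUC, mention-based
/// B-cubed, entity-alignment CEAF (entity- and mention-weighted), link-based
/// entity-aware LEA, BLANC, and the CoNLL score (conventionally the unweighted
/// average of MUC, B-cubed, and CEAF-e F1).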
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CorefMetric {
MUC,
BCubed,
CEAFe,
CEAFm,
LEA,
BLANC,
CoNLL,
}
pub use bio_adapter::BioScheme;
pub use modes::EvalConfig as ModeConfig;
pub use modes::EvalMode;
pub mod cluster_encoder {
pub use anno::metrics::cluster_encoder::*;
}
pub mod coref_metrics {
pub use anno::metrics::coref_metrics::*;
}
#[cfg(feature = "discourse")]
pub mod abstract_anaphora;
pub mod advanced_evaluator;
pub mod advanced_harness;
pub mod analysis;
pub mod backend_factory;
pub mod benchmark;
pub mod bio_adapter;
pub mod book_scale;
pub mod cdcr;
pub mod coref;
pub mod coref_loader;
pub mod cross_context_eval;
pub mod coref_resolver {
pub use anno::metrics::coref_resolver::*;
}
pub mod dataset;
pub mod dataset_metadata;
pub mod dataset_registry;
pub mod dataset_spec;
pub mod datasets;
pub mod discontinuous;
#[cfg(feature = "discourse")]
pub mod discourse_deixis;
pub mod evaluator;
pub mod harness;
pub mod history;
pub mod incremental_coref;
pub mod inter_doc_coref;
pub mod loader;
pub mod metrics;
pub mod modes;
pub mod ner_metrics;
pub mod neural_cluster_encoder;
pub mod prelude;
pub mod relation;
pub mod report;
pub mod sampling;
pub mod shell_nouns;
pub mod synthetic;
pub mod synthetic_gen;
pub mod task_evaluator;
#[cfg(feature = "eval-profiling")]
pub mod profiling;
pub mod task_mapping;
pub mod types;
pub mod validation;
pub mod visual;
#[cfg(feature = "eval-bias")]
pub mod bias_config;
#[cfg(feature = "eval-bias")]
pub mod demographic_bias;
#[cfg(feature = "eval-bias")]
pub mod gender_bias;
#[cfg(feature = "eval-bias")]
pub mod length_bias;
#[cfg(feature = "eval-bias")]
pub mod temporal_bias;
#[cfg(feature = "eval")]
pub mod active_learning;
#[cfg(feature = "eval")]
pub mod calibration;
#[cfg(feature = "eval")]
pub mod dataset_comparison;
#[cfg(feature = "eval")]
pub mod dataset_quality;
#[cfg(feature = "eval")]
pub mod drift;
#[cfg(feature = "eval")]
pub mod ensemble;
#[cfg(feature = "eval")]
pub mod error_analysis;
#[cfg(feature = "eval")]
pub mod few_shot;
#[cfg(feature = "eval")]
pub mod learning_curve;
#[cfg(feature = "eval")]
pub mod long_tail;
#[cfg(feature = "eval")]
pub mod low_resource;
#[cfg(feature = "eval")]
pub mod ood_detection;
#[cfg(feature = "eval")]
pub mod robustness;
#[cfg(feature = "eval")]
pub mod threshold_analysis;
#[cfg(feature = "eval")]
pub mod annotator;
#[cfg(feature = "eval")]
pub mod bridging;
#[cfg(feature = "eval")]
pub mod multi_run;
#[cfg(feature = "eval")]
pub mod ranking;
#[cfg(all(feature = "eval", test))]
mod property_tests;
pub use datasets::GoldEntity;
pub use dataset_registry::{AnnotationScheme, DataFormat, DatasetId as RegistryDatasetId};
pub use datasets::DatasetMetadata;
pub use loader::{DatasetLoader, LoadableDatasetId, LoadedDataset};
#[cfg(test)]
mod registry_exports;
pub use dataset::{AnnotatedExample, DatasetStats, Difficulty, Domain, NERDataset};
pub use evaluator::*;
pub use harness::{
BackendAggregateResult, BackendDatasetResult, BackendRegistry, DatasetStatsSummary, EvalConfig,
EvalHarness, EvalResults,
};
pub use metrics::*;
pub use types::{
CorefChainStats, CorefDocStats, DocumentScale, GoalCheck, GoalCheckResult, LabelShift,
MetricDivergence, MetricValue, MetricWithVariance,
};
pub use validation::*;
pub use coref::{CorefChain, CorefDocument, Mention, MentionType};
pub use coref_loader::{
adversarial_coref_examples, synthetic_coref_dataset, CorefLoader, GapExample,
};
pub use coref_metrics::{
b_cubed_score, blanc_score, ceaf_e_score, ceaf_m_score, compare_systems, conll_f1, lea_score,
muc_score, AggregateCorefEvaluation, CorefEvaluation, CorefScores, SignificanceTest,
};
pub use book_scale::{
BookScaleAnalysis, BookScaleAnalyzer, BookScaleConfig, BookScaleDiagnostics, CorefEvalScores,
MetricReliability, MultiBookReport, PerBookEvaluation, ReliabilityLevel, Scores,
StratifiedEvaluation, WindowedEvaluation,
};
pub use coref_resolver::{CorefConfig, CoreferenceResolver, SimpleCorefResolver};
pub use inter_doc_coref::InterDocCorefMetrics;
pub use cdcr::{
comprehensive_cdcr_dataset,
financial_news_dataset,
political_news_dataset,
science_news_dataset,
sports_news_dataset,
tech_news_dataset,
CDCRConfig,
CDCRMetrics,
CDCRResolver,
CrossDocCluster,
Document,
LSHBlocker,
MentionRef,
};
pub use discontinuous::{
evaluate_discontinuous_ner, DiscontinuousEvalConfig, DiscontinuousGold,
DiscontinuousNERMetrics, TypeMetrics as DiscontinuousTypeMetrics,
};
pub use relation::{
evaluate_relations, RelationEvalConfig, RelationGold, RelationMetrics, RelationPrediction,
RelationTypeMetrics,
};
pub use advanced_evaluator::{
evaluator_for_task, DiscontinuousEvaluator, EvalResults as AdvancedEvalResults,
RelationEvaluator, TaskEvaluator,
};
pub use visual::{
evaluate_visual_ner, synthetic_visual_examples, BoundingBox, VisualEvalConfig, VisualGold,
VisualNERMetrics, VisualPrediction, VisualTypeMetrics,
};
pub use advanced_harness::{
evaluate_discontinuous_gold_vs_gold, evaluate_discontinuous_synthetic,
evaluate_relations_gold_vs_gold, evaluate_relations_synthetic, evaluate_visual_gold_vs_gold,
synthetic_dataset_stats, AdvancedTaskResults, ModelResult, SyntheticDatasetStats,
};
#[cfg(feature = "eval-bias")]
pub use gender_bias::{
create_comprehensive_bias_templates, create_neopronoun_templates, create_winobias_templates,
occupation_stereotype, GenderBiasEvaluator, GenderBiasResults, OccupationBiasMetrics,
PronounGender, StereotypeType, WinoBiasExample,
};
#[cfg(feature = "eval-bias")]
pub use bias_config::{
BiasDatasetConfig, DistributionValidation, FrequencyWeightedResults, StatisticalBiasResults,
};
#[cfg(feature = "eval-bias")]
pub use demographic_bias::{
create_diverse_location_dataset, create_diverse_name_dataset, DemographicBiasEvaluator,
DemographicBiasResults, Ethnicity, Gender, LocationExample, LocationType, NameExample,
NameFrequency, NameResult, Region, RegionalBiasResults, Script,
};
#[cfg(feature = "eval-bias")]
pub use temporal_bias::{
create_temporal_name_dataset, Decade, TemporalBiasEvaluator, TemporalBiasResults,
TemporalGender, TemporalNameExample,
};
#[cfg(feature = "eval-bias")]
pub use length_bias::{
create_length_varied_dataset, EntityLengthEvaluator, LengthBiasResults, LengthBucket,
LengthTestExample, WordCountBucket,
};
#[cfg(feature = "eval")]
pub use calibration::{
calibration_grade, confidence_entropy, confidence_gap_grade, confidence_variance,
CalibrationEvaluator, CalibrationResults, EntropyFilter, ReliabilityBin, ThresholdMetrics,
};
#[cfg(feature = "eval")]
pub use robustness::{
robustness_grade, Perturbation, PerturbationMetrics, RobustnessEvaluator, RobustnessResults,
};
#[cfg(feature = "eval")]
pub use ood_detection::{
ood_rate_grade, OODAnalysisResults, OODConfig, OODDetector, OODStatus, VocabCoverageStats,
};
#[cfg(feature = "eval")]
pub use dataset_quality::{
check_leakage, entity_imbalance_ratio, DatasetQualityAnalyzer, DifficultyMetrics,
QualityReport, ReliabilityMetrics, ValidityMetrics,
};
#[cfg(feature = "eval")]
pub use learning_curve::{
suggested_train_sizes, CurveFitParams, DataPoint, LearningCurveAnalysis, LearningCurveAnalyzer,
SampleEfficiencyMetrics,
};
#[cfg(feature = "eval")]
pub use ensemble::{
agreement_grade, kappa_interpretation, DisagreementDetail, EnsembleAnalysisResults,
EnsembleAnalyzer, ModelPrediction, SingleExampleAnalysis,
};
#[cfg(feature = "eval")]
pub use dataset_comparison::{
compare_datasets, compute_stats, estimate_difficulty, DatasetComparison,
DatasetStats as ComparisonStats, DifficultyEstimate, LengthStats,
};
#[cfg(feature = "eval")]
pub use drift::{
ConfidenceDrift, DistributionDrift, DriftConfig, DriftDetector, DriftReport, DriftWindow,
VocabularyDrift,
};
#[cfg(feature = "eval")]
pub use active_learning::{
entities_to_candidates, estimate_budget, export_annotation_priority, rank_for_annotation,
select_for_annotation, ActiveLearner, Candidate, SamplingStrategy, ScoreStats, SelectionResult,
};
#[cfg(feature = "eval")]
pub use error_analysis::{
EntityInfo, ErrorAnalyzer, ErrorCategory, ErrorInstance, ErrorPattern, ErrorReport,
PredictedEntity, TypeErrorStats,
};
#[cfg(feature = "eval")]
pub use threshold_analysis::{
format_threshold_table, interpret_curve, PredictionWithConfidence, ThresholdAnalyzer,
ThresholdCurve, ThresholdPoint,
};
pub use report::{
BiasSummary, CalibrationSummary, CoreMetrics, DataQualitySummary, DemographicBiasMetrics,
ErrorSummary, EvalReport, GenderBiasMetrics, LengthBiasMetrics, Priority, Recommendation,
RecommendationCategory, ReportBuilder, SimpleGoldEntity, TestCase,
TypeMetrics as ReportTypeMetrics,
};
pub mod unified_evaluator;
pub use unified_evaluator::{EvalMetadata, EvalSystem, UnifiedEvalResults};
#[cfg(feature = "eval-bias")]
pub use unified_evaluator::BiasEvalResults;
#[cfg(feature = "eval")]
pub use unified_evaluator::CalibrationEvalResults;
#[cfg(feature = "eval")]
pub use unified_evaluator::DataQualityEvalResults;
#[cfg(feature = "eval")]
pub use unified_evaluator::StandardEvalResults;
pub mod backend_name;
pub use backend_name::BackendName;
pub mod config_builder;
#[cfg(feature = "eval-bias")]
pub use config_builder::BiasDatasetConfigBuilder;
#[cfg(feature = "eval")]
pub use config_builder::TaskEvalConfigBuilder;
#[cfg(feature = "eval")]
pub use few_shot::{
simulate_few_shot_task, FewShotEvaluator, FewShotGold, FewShotPrediction, FewShotResults,
FewShotTask, FewShotTaskResults, SupportExample,
};
#[cfg(feature = "eval")]
pub use long_tail::{
format_long_tail_results, EntityFrequency, FrequencyBucket, FrequencySplit, LongTailAnalyzer,
LongTailResults, TypePerformance,
};
pub use analysis::{
build_confusion_matrix, compare_ner_systems, ConfusionMatrix, ErrorAnalysis, ErrorType,
NERError, NERSignificanceTest,
};
pub use sampling::{multi_seed_eval, stratified_sample, stratified_sample_ner};
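/// Per-entity-type precision, recall, and F1, together with the raw counts
/// they were computed from (`found` = predicted, `expected` = gold,
/// `correct` = matched).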
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TypeMetrics {
pub precision: f64,
pub recall: f64,
pub f1: f64,
pub found: usize,
pub expected: usize,
pub correct: usize,
}
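/// Aggregate NER evaluation results over a test set.
///
/// `precision`, `recall`, and `f1` are micro-averaged over all entities.
/// `macro_f1` is the unweighted mean of per-type F1, and `weighted_f1` weights
/// each type by its gold entity count; both default to `None` when
/// deserializing older results that lack them.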
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NEREvaluationResults {
pub precision: f64,
pub recall: f64,
pub f1: f64,
#[serde(default)]
pub macro_f1: Option<f64>,
#[serde(default)]
pub weighted_f1: Option<f64>,
pub per_type: HashMap<String, TypeMetrics>,
pub tokens_per_second: f64,
pub found: usize,
pub expected: usize,
#[serde(default)]
pub metadata: Option<EvaluationMetadata>,
}
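/// Optional provenance attached to [`NEREvaluationResults`]: which dataset and
/// model were evaluated, when, under which matching mode, and with which
/// `anno` version.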
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct EvaluationMetadata {
pub dataset_name: Option<String>,
pub dataset_format: Option<String>,
pub dataset_version: Option<String>,
pub num_test_cases: usize,
pub total_gold_entities: Option<usize>,
pub timestamp: Option<String>,
pub model_info: Option<String>,
pub model_version: Option<String>,
pub matching_mode: Option<String>,
pub anno_version: Option<String>,
}
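/// Canonical string label for an [`EntityType`], e.g. `EntityType::Person`
/// becomes `"PER"`. Thin wrapper over `EntityType::as_label`.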
pub fn entity_type_to_string(et: &EntityType) -> String {
et.as_label().to_string()
}
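/// Returns `true` when two entity types are equal, share a label
/// (case-insensitively), or belong to a fixed set of common label aliases
/// such as `PERSON`/`PER`, `ORGANIZATION`/`ORG`, `LOCATION`/`LOC`/`GPE`,
/// and `MISC`/`MISCELLANEOUS`/`OTHER`.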
pub fn entity_type_matches(a: &EntityType, b: &EntityType) -> bool {
if a == b {
return true;
}
let a_label = a.as_label().to_uppercase();
let b_label = b.as_label().to_uppercase();
if a_label == b_label {
return true;
}
matches!(
(a_label.as_str(), b_label.as_str()),
("PERSON", "PER") | ("PER", "PERSON")
| ("ORGANIZATION", "ORG") | ("ORG", "ORGANIZATION")
| ("ORGANIZATION", "CORPORATION") | ("CORPORATION", "ORGANIZATION")
| ("ORG", "CORPORATION") | ("CORPORATION", "ORG")
| ("ORGANIZATION", "COMPANY") | ("COMPANY", "ORGANIZATION")
| ("LOCATION", "LOC") | ("LOC", "LOCATION")
| ("LOCATION", "GPE") | ("GPE", "LOCATION")
| ("LOC", "GPE") | ("GPE", "LOC")
| ("MISC", "MISCELLANEOUS") | ("MISCELLANEOUS", "MISC")
| ("MISC", "OTHER") | ("OTHER", "MISC")
)
}
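/// Loads a CoNLL-2003 formatted file into `(sentence text, gold entities)` pairs.
///
/// The loader expects whitespace-separated columns with the NER tag in the
/// fourth column (`word POS chunk NER`), blank lines between sentences, and
/// IOB-style `B-`/`I-` tag prefixes. Character offsets are assigned by joining
/// tokens with single spaces, and every sentence is validated against its
/// reconstructed text before the cases are returned.
///
/// Minimal usage sketch (the path is a placeholder):
///
/// ```ignore
/// let cases = load_conll2003("data/conll2003/eng.testa")?;
/// println!("loaded {} sentences", cases.len());
/// ```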
pub fn load_conll2003<P: AsRef<Path>>(path: P) -> Result<Vec<(String, Vec<GoldEntity>)>> {
let content = std::fs::read_to_string(path.as_ref()).map_err(Error::Io)?;
let mut test_cases: Vec<(String, Vec<GoldEntity>)> = Vec::new();
let mut current_text = String::new();
let mut current_entities: Vec<GoldEntity> = Vec::new();
let mut char_offset = 0;
for line in content.lines() {
if line.trim().is_empty() {
if !current_text.is_empty() {
test_cases.push((current_text.clone(), current_entities.clone()));
}
current_text.clear();
current_entities.clear();
char_offset = 0;
continue;
}
let parts: Vec<&str> = line.split_whitespace().collect();
// Skip malformed rows and document boundary markers.
if parts.len() < 4 || parts[0] == "-DOCSTART-" {
continue;
}
let word = parts[0];
let ner_tag = parts[3];
if !current_text.is_empty() {
current_text.push(' ');
char_offset += 1;
}
let word_start = char_offset;
current_text.push_str(word);
char_offset += word.chars().count();
let word_end = char_offset;
if ner_tag != "O" {
let (prefix, entity_type_str) = if let Some(dash_pos) = ner_tag.find('-') {
(&ner_tag[..dash_pos], &ner_tag[dash_pos + 1..])
} else {
continue;
};
let entity_type = match entity_type_str {
"PER" => EntityType::Person,
"ORG" => EntityType::Organization,
"LOC" => EntityType::Location,
"MISC" => EntityType::custom("misc", EntityCategory::Misc),
"DATE" => EntityType::Date,
"MONEY" => EntityType::Money,
"PERCENT" => EntityType::Percent,
_ => continue,
};
if prefix == "B" {
current_entities.push(GoldEntity::with_span(
word,
entity_type,
word_start,
word_end,
));
} else if prefix == "I" {
if let Some(last) = current_entities.last_mut() {
if entity_type_matches(&last.entity_type, &entity_type) {
last.text.push(' ');
last.text.push_str(word);
last.end = word_end;
} else {
current_entities.push(GoldEntity::with_span(
word,
entity_type,
word_start,
word_end,
));
}
}
}
}
}
if !current_text.is_empty() {
test_cases.push((current_text, current_entities));
}
for (text, entities) in &test_cases {
let validation_result = validation::validate_ground_truth_entities(text, entities, false);
if !validation_result.is_valid {
return Err(Error::InvalidInput(format!(
"Invalid entities in CoNLL dataset: {}",
validation_result.errors.join("; ")
)));
}
}
Ok(test_cases)
}
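/// Evaluates a [`Model`] against gold test cases using the standard NER
/// evaluator, with no label normalization. See
/// [`evaluate_ner_model_with_mapper`] for the label-remapping variant.
///
/// Hedged sketch (how `model` is constructed depends on the backend in use):
///
/// ```ignore
/// let test_cases = load_conll2003("data/conll2003/eng.testa")?;
/// let results = evaluate_ner_model(&model, &test_cases)?;
/// println!("P={:.3} R={:.3} F1={:.3}", results.precision, results.recall, results.f1);
/// ```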
pub fn evaluate_ner_model(
model: &dyn Model,
test_cases: &[(String, Vec<GoldEntity>)],
) -> Result<NEREvaluationResults> {
evaluate_ner_model_with_mapper(model, test_cases, None)
}
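/// Like [`evaluate_ner_model`], but optionally normalizes gold entity labels
/// through an [`anno::TypeMapper`] first, so datasets with differing tag sets
/// can be scored against the same model output.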
pub fn evaluate_ner_model_with_mapper(
model: &dyn Model,
test_cases: &[(String, Vec<GoldEntity>)],
type_mapper: Option<&anno::TypeMapper>,
) -> Result<NEREvaluationResults> {
let evaluator = evaluator::StandardNEREvaluator::new();
if test_cases.is_empty() {
return Ok(NEREvaluationResults {
precision: 0.0,
recall: 0.0,
f1: 0.0,
macro_f1: None,
weighted_f1: None,
per_type: HashMap::new(),
tokens_per_second: 0.0,
found: 0,
expected: 0,
metadata: Some(EvaluationMetadata {
num_test_cases: 0,
total_gold_entities: Some(0),
timestamp: Some(chrono::Utc::now().to_rfc3339()),
anno_version: Some(env!("CARGO_PKG_VERSION").to_string()),
..Default::default()
}),
});
}
let mut query_metrics = Vec::new();
for (i, (text, ground_truth)) in test_cases.iter().enumerate() {
let test_case_id = format!("test_case_{}", i);
let normalized_truth: Vec<GoldEntity>;
let truth_ref = if let Some(mapper) = type_mapper {
normalized_truth = ground_truth
.iter()
.map(|e| GoldEntity {
text: e.text.clone(),
entity_type: mapper.normalize(e.entity_type.as_label()),
original_label: e.original_label.clone(),
start: e.start,
end: e.end,
})
.collect();
&normalized_truth
} else {
ground_truth
};
let metrics = evaluator.evaluate_test_case(model, text, truth_ref, Some(&test_case_id))?;
query_metrics.push(metrics);
}
let aggregate = evaluator.aggregate(&query_metrics)?;
let macro_f1 = if aggregate.per_type.is_empty() {
None
} else {
let sum: f64 = aggregate.per_type.values().map(|m| m.f1).sum();
Some(sum / aggregate.per_type.len() as f64)
};
let weighted_f1 = if aggregate.per_type.is_empty() || aggregate.total_expected == 0 {
None
} else {
let weighted_sum: f64 = aggregate
.per_type
.values()
.map(|m| m.f1 * m.expected as f64)
.sum();
Some(weighted_sum / aggregate.total_expected as f64)
};
Ok(NEREvaluationResults {
precision: aggregate.precision.get(),
recall: aggregate.recall.get(),
f1: aggregate.f1.get(),
macro_f1,
weighted_f1,
per_type: aggregate.per_type,
tokens_per_second: aggregate.tokens_per_second,
found: aggregate.total_found,
expected: aggregate.total_expected,
metadata: Some(EvaluationMetadata {
num_test_cases: aggregate.num_test_cases,
total_gold_entities: Some(aggregate.total_expected),
timestamp: Some(chrono::Utc::now().to_rfc3339()),
anno_version: Some(env!("CARGO_PKG_VERSION").to_string()),
..Default::default()
}),
})
}
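/// Runs [`evaluate_ner_model`] for each named model over the same test cases
/// and returns the results keyed by model name.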
pub fn compare_ner_models(
models: &[(&str, &dyn Model)],
test_cases: &[(String, Vec<GoldEntity>)],
) -> Result<HashMap<String, NEREvaluationResults>> {
let mut results = HashMap::new();
for (name, model) in models {
log::info!("Evaluating {}...", name);
let result = evaluate_ner_model(*model, test_cases)?;
results.insert(name.to_string(), result);
}
Ok(results)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_entity_type_to_string() {
assert_eq!(entity_type_to_string(&EntityType::Person), "PER");
assert_eq!(entity_type_to_string(&EntityType::Organization), "ORG");
assert_eq!(entity_type_to_string(&EntityType::Location), "LOC");
}
#[test]
fn test_entity_type_matches() {
assert!(entity_type_matches(
&EntityType::Person,
&EntityType::Person
));
assert!(!entity_type_matches(
&EntityType::Person,
&EntityType::Organization
));
}
}