use clap::{ArgAction, Parser, ValueEnum};
#[cfg(feature = "eval")]
use anno_eval::eval::loader::DatasetId;
#[cfg(feature = "eval")]
use anno_eval::eval::task_evaluator::TaskEvaluator;
#[cfg(feature = "eval")]
use anno_eval::eval::task_mapping::Task;
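/// Arguments for the `benchmark` subcommand: a comprehensive task/dataset/backend
/// evaluation sweep.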
#[derive(Parser, Debug)]
pub struct BenchmarkArgs {
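/// Benchmark preset supplying default tasks, datasets, and backends; explicit
/// --tasks/--datasets/--backends flags override it.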
#[arg(long, value_enum)]
pub profile: Option<BenchmarkProfile>,
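/// Comma-separated seeds; the benchmark is re-run once per seed with output
/// paths suffixed by `-seed{N}`.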
#[arg(long, value_delimiter = ',')]
pub seeds: Option<Vec<u64>>,
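/// Run heavy (GLiNER) backends in separate subprocesses; their JSON results are
/// merged back when --output-json is set. Enabled by default; pass
/// `--split-heavy-backends=false` to disable.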
#[arg(long, default_value_t = true, action = ArgAction::Set)]
pub split_heavy_backends: bool,
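/// Tasks to evaluate (ner, coref, relation); defaults to the profile's tasks or,
/// without a profile, all tasks.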
#[arg(short, long, value_delimiter = ',')]
pub tasks: Option<Vec<String>>,
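/// Datasets to evaluate; defaults to the profile's datasets or, without a
/// profile, all suitable datasets.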
#[arg(short, long, value_delimiter = ',')]
pub datasets: Option<Vec<String>>,
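/// Backends to evaluate; defaults to the profile's backends or, without a
/// profile, all compatible backends.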
#[arg(short, long, value_delimiter = ',')]
pub backends: Option<Vec<String>>,
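/// Maximum number of examples to evaluate per dataset.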
#[arg(short, long)]
pub max_examples: Option<usize>,
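/// Random seed forwarded to the evaluator.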
#[arg(long)]
pub seed: Option<u64>,
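/// Only use already-cached resources (forwarded to the evaluator's
/// `require_cached` setting).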
#[arg(long)]
pub cached_only: bool,
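/// Path for the Markdown report; if omitted, the report is printed to stdout.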
#[arg(short, long, value_name = "PATH")]
pub output: Option<String>,
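/// Path for the JSON results file; also used to merge heavy-backend results when
/// backends are split across subprocesses.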
#[arg(long, value_name = "PATH")]
pub output_json: Option<String>,
}
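/// Named benchmark presets; see `profile_defaults` for the tasks, datasets, and
/// backends each one expands to.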
#[derive(Debug, Clone, Copy, ValueEnum)]
pub enum BenchmarkProfile {
NerStandard,
NerZeroshotMultilingual,
CorefStandard,
RelationStandard,
}
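/// Inserts `suffix` before the file extension of `p`, or appends it when there is
/// no extension, e.g. `suffix_path("report.json", "-seed42")` yields
/// `"report-seed42.json"`.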
fn suffix_path(p: &str, suffix: &str) -> String {
// Only treat the trailing ".ext" as an extension when the dot is in the last
// path component, so a dot inside a directory name is never split on.
match p.rsplit_once('.') {
Some((stem, ext)) if !ext.contains('/') && !ext.contains('\\') => {
format!("{stem}{suffix}.{ext}")
}
_ => format!("{p}{suffix}"),
}
}
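/// GLiNER-family backends treated as "heavy": with `--split-heavy-backends` each
/// of these runs in its own subprocess.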
fn is_heavy_backend(name: &str) -> bool {
matches!(
name,
"gliner_onnx" | "gliner_multitask" | "gliner_poly" | "gliner_pii" | "gliner_relex"
)
}
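/// Default `(tasks, datasets, backends)` triple for each benchmark profile.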
fn profile_defaults(profile: BenchmarkProfile) -> (Vec<String>, Vec<String>, Vec<String>) {
match profile {
BenchmarkProfile::NerStandard => (
vec!["ner".to_string()],
vec![
"WikiGold".to_string(),
"Wnut17".to_string(),
"CoNLL2003Sample".to_string(),
],
vec![
"bert_onnx".to_string(),
"stacked".to_string(),
"heuristic".to_string(),
],
),
BenchmarkProfile::NerZeroshotMultilingual => (
vec!["ner".to_string()],
vec![
"WikiANN".to_string(),
"MasakhaNER".to_string(),
"MultiNERD".to_string(),
"MultiCoNERv2".to_string(),
],
vec!["gliner_onnx".to_string(), "nuner".to_string()],
),
BenchmarkProfile::CorefStandard => (
vec!["coref".to_string()],
vec!["GAP".to_string()],
vec!["coref_resolver".to_string()],
),
BenchmarkProfile::RelationStandard => (
vec!["relation".to_string()],
vec!["DocRED".to_string()],
vec!["tplinker".to_string()],
),
}
}
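/// Runs the `benchmark` subcommand. Returns an error when the binary was built
/// without the `eval` feature.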
pub fn run(args: BenchmarkArgs) -> Result<(), String> {
#[cfg(not(feature = "eval"))]
{
let _ = args;
Err("Benchmark command requires --features eval".to_string())
}
#[cfg(feature = "eval")]
{
println!("=== Comprehensive Task-Dataset-Backend Evaluation ===\n");
if let Some(seeds) = &args.seeds {
if seeds.is_empty() {
return Err("--seeds was provided but empty".to_string());
}
for seed in seeds {
let per = BenchmarkArgs {
profile: args.profile,
seeds: None,
split_heavy_backends: args.split_heavy_backends,
tasks: args.tasks.clone(),
datasets: args.datasets.clone(),
backends: args.backends.clone(),
max_examples: args.max_examples,
seed: Some(*seed),
cached_only: args.cached_only,
output: args
.output
.as_ref()
.map(|p| suffix_path(p, &format!("-seed{seed}"))),
output_json: args
.output_json
.as_ref()
.map(|p| suffix_path(p, &format!("-seed{seed}"))),
};
run(per)?;
}
return Ok(());
}
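// Resolve tasks, datasets, and backends: explicit CLI flags take precedence,
// then profile defaults, then "all".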
let (profile_tasks, profile_datasets, profile_backends) = match args.profile {
Some(p) => profile_defaults(p),
None => (Vec::new(), Vec::new(), Vec::new()),
};
let task_strs = args.tasks.clone().unwrap_or(profile_tasks);
let dataset_strs = args.datasets.clone().unwrap_or(profile_datasets);
let backend_strs = args.backends.clone().unwrap_or(profile_backends);
let tasks = if !task_strs.is_empty() {
let mut parsed = Vec::new();
for t in task_strs {
match t.to_lowercase().as_str() {
"ner" | "ner_task" => parsed.push(Task::NER),
"coref" | "coreference" | "intradoc_coref" => parsed.push(Task::IntraDocCoref),
"relation" | "relation_extraction" => parsed.push(Task::RelationExtraction),
other => {
return Err(format!(
"Unknown task '{other}'. Valid tasks: ner, coref, relation"
));
}
}
}
parsed
} else {
Task::all().to_vec()
};
let datasets = if !dataset_strs.is_empty() {
let mut parsed = Vec::new();
for d in dataset_strs {
let dataset_id: DatasetId = d
.parse()
.map_err(|e| format!("Invalid dataset '{}': {}", d, e))?;
parsed.push(dataset_id);
}
parsed
} else {
vec![]
};
let backends = backend_strs;
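// Heavy/light split: run each heavy backend in its own subprocess (with
// per-backend output suffixes) and keep the light backends in-process.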
if args.split_heavy_backends && backends.len() > 1 {
let (heavy, light): (Vec<String>, Vec<String>) =
backends.into_iter().partition(|b| is_heavy_backend(b));
for hb in &heavy {
let exe = std::env::current_exe()
.map_err(|e| format!("Failed to locate current executable: {}", e))?;
let mut cmd = std::process::Command::new(exe);
cmd.arg("benchmark");
if let Some(profile) = args.profile {
cmd.arg("--profile")
.arg(profile.to_possible_value().unwrap().get_name());
}
cmd.arg("--tasks")
.arg(tasks.iter().map(|t| t.code()).collect::<Vec<_>>().join(","));
if !datasets.is_empty() {
cmd.arg("--datasets").arg(
datasets
.iter()
.map(|d| d.to_string())
.collect::<Vec<_>>()
.join(","),
);
}
cmd.arg("--backends").arg(hb);
if let Some(max) = args.max_examples {
cmd.arg("--max-examples").arg(max.to_string());
}
if let Some(seed) = args.seed {
cmd.arg("--seed").arg(seed.to_string());
}
if args.cached_only {
cmd.arg("--cached-only");
}
if let Some(out) = &args.output {
cmd.arg("--output").arg(suffix_path(out, &format!("-{hb}")));
}
if let Some(outj) = &args.output_json {
cmd.arg("--output-json")
.arg(suffix_path(outj, &format!("-{hb}")));
}
cmd.arg("--split-heavy-backends=false");
let status = cmd
.status()
.map_err(|e| format!("Failed to spawn benchmark subprocess: {}", e))?;
if !status.success() {
return Err(format!("Heavy backend subprocess failed for '{hb}'"));
}
}
if light.is_empty() {
return Ok(());
}
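// Without a JSON artifact path the heavy-backend results cannot be merged
// back, so just run the light backends and return.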
if args.output_json.is_none() {
let per = BenchmarkArgs {
profile: args.profile,
seeds: None,
split_heavy_backends: false,
tasks: Some(tasks.iter().map(|t| t.code().to_string()).collect()),
datasets: if datasets.is_empty() {
None
} else {
Some(datasets.iter().map(|d| d.to_string()).collect())
},
backends: Some(light),
max_examples: args.max_examples,
seed: args.seed,
cached_only: args.cached_only,
output: args.output.clone(),
output_json: args.output_json.clone(),
};
return run(per);
}
use anno_eval::eval::config_builder::TaskEvalConfigBuilder;
use anno_eval::eval::task_evaluator::{
ComprehensiveEvalResults, EvalSummary, TaskEvalResult,
};
use std::collections::HashSet;
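// Rebuilds the summary after heavy-backend results have been merged into the
// in-process results.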
fn summarize(results: &[TaskEvalResult]) -> EvalSummary {
let skipped = results.iter().filter(|r| r.is_skipped()).count();
let failed = results
.iter()
.filter(|r| !r.success && !r.is_skipped())
.count();
let mut tasks: Vec<Task> = Vec::new();
let mut datasets: Vec<DatasetId> = Vec::new();
let mut backends: Vec<String> = Vec::new();
for r in results {
if !tasks.contains(&r.task) {
tasks.push(r.task);
}
if !datasets.contains(&r.dataset) {
datasets.push(r.dataset);
}
if !backends.contains(&r.backend) {
backends.push(r.backend.clone());
}
}
EvalSummary {
total_combinations: results.len(),
successful: results.iter().filter(|r| r.success).count(),
failed,
skipped,
tasks,
datasets,
backends,
}
}
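// Run the light backends in-process, then fold each heavy backend's JSON
// artifact into the combined results.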
let evaluator =
TaskEvaluator::new().map_err(|e| format!("Failed to create evaluator: {}", e))?;
let mut builder = TaskEvalConfigBuilder::new()
.with_tasks(tasks.clone())
.with_datasets(datasets.clone())
.with_backends(light)
.require_cached(args.cached_only)
.with_confidence_intervals(true)
.with_familiarity(true);
if let Some(max) = args.max_examples {
if max > 0 {
builder = builder.with_max_examples(max);
}
}
if let Some(seed) = args.seed {
builder = builder.with_seed(seed);
}
let config = builder.build();
println!("Running comprehensive evaluation...");
println!("Tasks: {:?}", config.tasks);
if !config.datasets.is_empty() {
println!("Datasets: {:?}", config.datasets);
} else {
println!("Datasets: all suitable datasets");
}
if !config.backends.is_empty() {
println!("Backends: {:?}", config.backends);
} else {
println!("Backends: all compatible backends");
}
if let Some(max) = config.max_examples {
println!("Max examples per dataset: {}", max);
}
if let Some(seed) = config.seed {
println!("Random seed: {}", seed);
}
println!();
let mut combined = evaluator
.evaluate_all(config)
.map_err(|e| format!("Evaluation failed: {}", e))?;
let json_path = args
.output_json
.as_ref()
.expect("checked output_json Some above");
let mut seen: HashSet<(Task, DatasetId, String)> = combined
.results
.iter()
.map(|r| (r.task, r.dataset, r.backend.clone()))
.collect();
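// Merge heavy-backend artifacts, skipping any (task, dataset, backend)
// combination already present.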
for hb in &heavy {
let p = suffix_path(json_path, &format!("-{hb}"));
let raw = std::fs::read_to_string(&p)
.map_err(|e| format!("Failed to read heavy JSON artifact {}: {}", p, e))?;
let parsed: ComprehensiveEvalResults = serde_json::from_str(&raw)
.map_err(|e| format!("Failed to parse heavy JSON artifact {}: {}", p, e))?;
for r in parsed.results {
let key = (r.task, r.dataset, r.backend.clone());
if seen.insert(key) {
combined.results.push(r);
}
}
}
combined.summary = summarize(&combined.results);
println!("=== Evaluation Summary ===");
println!(
"Total combinations: {}",
combined.summary.total_combinations
);
println!("Successful: {}", combined.summary.successful);
println!(
"Skipped (feature not available): {}",
combined.summary.skipped
);
println!("Failed (actual errors): {}", combined.summary.failed);
println!("\nTasks evaluated: {}", combined.summary.tasks.len());
println!("Datasets used: {}", combined.summary.datasets.len());
println!("Backends tested: {}", combined.summary.backends.len());
println!();
let json = serde_json::to_string_pretty(&combined)
.map_err(|e| format!("Failed to serialize results as JSON: {}", e))?;
std::fs::write(json_path, json)
.map_err(|e| format!("Failed to write JSON report to {}: {}", json_path, e))?;
println!("JSON saved to: {}", json_path);
let report = combined.to_markdown();
if let Some(output_path) = &args.output {
std::fs::write(output_path, &report)
.map_err(|e| format!("Failed to write report to {}: {}", output_path, e))?;
println!("Report saved to: {}", output_path);
} else {
println!("=== Markdown Report ===");
println!("{}", report);
}
return Ok(());
}
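// No heavy/light split: evaluate everything in a single process.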
let evaluator =
TaskEvaluator::new().map_err(|e| format!("Failed to create evaluator: {}", e))?;
use anno_eval::eval::config_builder::TaskEvalConfigBuilder;
let mut builder = TaskEvalConfigBuilder::new()
.with_tasks(tasks)
.with_datasets(datasets)
.with_backends(backends)
.require_cached(args.cached_only)
.with_confidence_intervals(true)
.with_familiarity(true);
if let Some(max) = args.max_examples {
if max > 0 {
builder = builder.with_max_examples(max);
}
}
if let Some(seed) = args.seed {
builder = builder.with_seed(seed);
}
let config = builder.build();
println!("Running comprehensive evaluation...");
println!("Tasks: {:?}", config.tasks);
if !config.datasets.is_empty() {
println!("Datasets: {:?}", config.datasets);
} else {
println!("Datasets: all suitable datasets");
}
if !config.backends.is_empty() {
println!("Backends: {:?}", config.backends);
} else {
println!("Backends: all compatible backends");
}
if let Some(max) = config.max_examples {
println!("Max examples per dataset: {}", max);
}
if let Some(seed) = config.seed {
println!("Random seed: {}", seed);
}
println!();
let results = evaluator
.evaluate_all(config)
.map_err(|e| format!("Evaluation failed: {}", e))?;
println!("=== Evaluation Summary ===");
println!("Total combinations: {}", results.summary.total_combinations);
println!("Successful: {}", results.summary.successful);
println!(
"Skipped (feature not available): {}",
results.summary.skipped
);
println!("Failed (actual errors): {}", results.summary.failed);
println!("\nTasks evaluated: {}", results.summary.tasks.len());
println!("Datasets used: {}", results.summary.datasets.len());
println!("Backends tested: {}", results.summary.backends.len());
println!();
let report = results.to_markdown();
if let Some(json_path) = &args.output_json {
let json = serde_json::to_string_pretty(&results)
.map_err(|e| format!("Failed to serialize results as JSON: {}", e))?;
std::fs::write(json_path, json)
.map_err(|e| format!("Failed to write JSON report to {}: {}", json_path, e))?;
println!("JSON saved to: {}", json_path);
}
if let Some(output_path) = &args.output {
std::fs::write(output_path, &report)
.map_err(|e| format!("Failed to write report to {}: {}", output_path, e))?;
println!("Report saved to: {}", output_path);
} else {
println!("=== Markdown Report ===");
println!("{}", report);
}
Ok(())
}
}