use std::collections::{BTreeMap, BTreeSet};
use futures::stream::{self, StreamExt};
use serde::{Deserialize, Serialize};
use tracing::{debug, instrument, warn};
use crate::error::{Error, Result};
use crate::report::{MetricReport, ReliabilityReport};
use crate::skills::grader::{AsyncGrader, GraderOutcome};
use crate::skills::runner::AgentRunner;
use crate::skills::task::SkillTaskSet;
use crate::skills::transcript::Transcript;
/// Evaluation harness that runs skill tasks through an [`AgentRunner`] and scores the
/// resulting transcripts with a set of [`AsyncGrader`]s.
///
/// Build one with [`SkillHarness::new`] and adjust it with the `with_*` builder methods.
pub struct SkillHarness<R: AgentRunner> {
    /// Produces a transcript for a given task and trial index.
    runner: R,
    /// Graders that tasks may reference by id; ids are checked for uniqueness when a run starts.
    graders: Vec<Box<dyn AsyncGrader>>,
    /// Maximum number of trials in flight at once (never below `1` via the constructor/builder).
    concurrency: usize,
    /// How many times each task is executed (never below `1` via the constructor/builder).
    trials: usize,
    /// Threshold forwarded to reliability aggregation; stored without validation.
    pass_threshold: f64,
}
impl<R: AgentRunner> SkillHarness<R> {
    /// Creates a harness that drives `runner` and grades transcripts with `graders`.
    ///
    /// Defaults: one trial per task, one trial in flight at a time, and a pass threshold of
    /// `1.0`. Use the `with_*` builder methods to change them.
    pub fn new(runner: R, graders: Vec<Box<dyn AsyncGrader>>) -> Self {
        Self {
            runner,
            graders,
            trials: 1,
            concurrency: 1,
            pass_threshold: 1.0,
        }
    }

    /// Sets how many times every task is run; values below `1` are raised to `1`.
    #[must_use]
    pub fn with_trials(mut self, trials: usize) -> Self {
        self.trials = trials.max(1);
        self
    }

    /// Sets the maximum number of trials executed concurrently; values below `1` are raised
    /// to `1`.
    #[must_use]
    pub fn with_concurrency(mut self, concurrency: usize) -> Self {
        self.concurrency = concurrency.max(1);
        self
    }

    /// Sets the threshold forwarded to [`ReliabilityReport::from_metric_reports`].
    ///
    /// The value is stored as-is; no range check or clamping happens here.
    #[must_use]
    pub fn with_pass_threshold(mut self, threshold: f64) -> Self {
        self.pass_threshold = threshold;
        self
    }

    /// Returns the configured number of trials per task.
    #[must_use]
    pub fn trials(&self) -> usize {
        self.trials
    }

    /// Runs every task `self.trials` times, grades each transcript, and aggregates the scores.
    ///
    /// At most `self.concurrency` trials are in flight at once. Within a trial, the task's
    /// graders are awaited one after another. The returned report contains:
    /// - one [`MetricReport`] per referenced grader per trial,
    /// - a [`ReliabilityReport`] per grader whose per-trial reports are complete and cover the
    ///   same tasks (graders whose aggregation is skipped or fails are logged and omitted),
    /// - the full trial log sorted by `(trial, task_id)`.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Config`] when the task set is empty, when two graders share an id, or
    /// when a task references a grader id that is not registered. If any trial fails (for
    /// example, the runner returns an error), the first failure in completion order is
    /// returned and no report is produced. All trials are still driven to completion before
    /// that error surfaces.
    #[instrument(skip_all, fields(suite = %tasks.id, n_tasks = tasks.tasks.len(), trials = self.trials))]
    pub async fn run(&self, tasks: &SkillTaskSet) -> Result<SkillEvalReport> {
        if tasks.is_empty() {
            return Err(Error::Config("skill task set is empty".into()));
        }

        // Index graders by id. Collecting into a map silently drops duplicates, so a size
        // mismatch with the source vector means two graders share an id.
        let registry: BTreeMap<&str, &dyn AsyncGrader> =
            self.graders.iter().map(|g| (g.id(), g.as_ref())).collect();
        if registry.len() != self.graders.len() {
            return Err(Error::Config(
                "duplicate grader id in registry; ids must be unique".into(),
            ));
        }

        // Reject unknown grader references before any (potentially expensive) trial runs, and
        // remember which graders are actually used so aggregation only covers those.
        let mut referenced: BTreeSet<&str> = BTreeSet::new();
        for task in &tasks.tasks {
            for gid in &task.graders {
                if !registry.contains_key(gid.as_str()) {
                    return Err(Error::Config(format!(
                        "task {:?} references unknown grader {:?}",
                        task.id, gid
                    )));
                }
                referenced.insert(gid.as_str());
            }
        }

        // One work item per (task index, trial index) pair.
        let work: Vec<(usize, usize)> = (0..tasks.tasks.len())
            .flat_map(|ti| (0..self.trials).map(move |tr| (ti, tr)))
            .collect();

        // Borrow the registry once so each `async move` block only copies the reference.
        let registry_ref = &registry;
        let outcomes = stream::iter(work.into_iter().map(move |(ti, tr)| {
            let task = tasks.tasks.get(ti);
            async move {
                let Some(task) = task else {
                    return Err(Error::Config(format!("task index {ti} out of bounds")));
                };
                debug!(task_id = %task.id, trial = tr, "running trial");
                let transcript = self.runner.run(task, tr).await?;
                let mut per_grader: BTreeMap<String, GraderOutcome> = BTreeMap::new();
                for gid in &task.graders {
                    // Every id was validated above; this guard only protects against the
                    // validation and execution phases drifting apart in future edits.
                    let Some(grader) = registry_ref.get(gid.as_str()) else {
                        return Err(Error::Config(format!(
                            "grader {gid:?} disappeared between validation and run"
                        )));
                    };
                    let outcome = grader.grade(task, &transcript).await;
                    per_grader.insert(gid.clone(), outcome);
                }
                Ok::<_, Error>(TrialRow {
                    task_id: task.id.clone(),
                    trial: tr,
                    transcript,
                    outcomes: per_grader,
                })
            }
        }))
        .buffer_unordered(self.concurrency)
        .collect::<Vec<_>>()
        .await;

        // Results arrive in completion order; the first `Err` in that order aborts the run.
        let mut rows = outcomes.into_iter().collect::<Result<Vec<TrialRow>>>()?;
        // `buffer_unordered` yields rows nondeterministically; sort for a stable report.
        rows.sort_by(|a, b| {
            a.trial
                .cmp(&b.trial)
                .then_with(|| a.task_id.cmp(&b.task_id))
        });

        let per_grader = Self::aggregate_per_grader(&rows, &referenced, self.trials);
        let reliability = self.aggregate_reliability(&per_grader);

        Ok(SkillEvalReport {
            suite_id: tasks.id.clone(),
            n_tasks: tasks.tasks.len(),
            trials: self.trials,
            pass_threshold: self.pass_threshold,
            per_grader,
            reliability,
            trials_log: rows,
        })
    }

    /// Builds, for each referenced grader, one [`MetricReport`] per trial (in trial order).
    ///
    /// A trial in which no row carries an outcome for a grader contributes no report, so a
    /// grader's vector can be shorter than `trials`.
    fn aggregate_per_grader(
        rows: &[TrialRow],
        referenced: &BTreeSet<&str>,
        trials: usize,
    ) -> BTreeMap<String, Vec<MetricReport>> {
        let mut per_grader_reports: BTreeMap<String, Vec<MetricReport>> = BTreeMap::new();
        for trial in 0..trials {
            let trial_rows: Vec<&TrialRow> = rows.iter().filter(|r| r.trial == trial).collect();
            for &gid in referenced {
                let pairs: Vec<(String, f64)> = trial_rows
                    .iter()
                    .filter_map(|row| {
                        row.outcomes
                            .get(gid)
                            .map(|o| (row.task_id.clone(), o.score))
                    })
                    .collect();
                if pairs.is_empty() {
                    continue;
                }
                let report = MetricReport::from_per_query(gid.to_string(), pairs);
                per_grader_reports
                    .entry(gid.to_string())
                    .or_default()
                    .push(report);
            }
        }
        per_grader_reports
    }

    /// Folds each grader's per-trial reports into a [`ReliabilityReport`].
    ///
    /// A grader is skipped (with a warning) when it is missing reports for some trials, when
    /// the set of task ids differs between its trials, or when reliability construction
    /// returns an error. Skipped graders are simply absent from the result.
    fn aggregate_reliability(
        &self,
        per_grader_reports: &BTreeMap<String, Vec<MetricReport>>,
    ) -> BTreeMap<String, ReliabilityReport> {
        let mut reliability: BTreeMap<String, ReliabilityReport> = BTreeMap::new();
        for (gid, trial_reports) in per_grader_reports {
            if trial_reports.len() < self.trials {
                warn!(
                    grader = %gid,
                    have = trial_reports.len(),
                    want = self.trials,
                    "skipping reliability aggregation: not all trials produced this grader"
                );
                continue;
            }
            // Reliability compares the same tasks across trials, so every trial must have
            // scored exactly the same set of task ids.
            let first_task_set: BTreeSet<&str> = trial_reports
                .first()
                .map(|r| r.per_query.iter().map(|(q, _)| q.as_str()).collect())
                .unwrap_or_default();
            let consistent = trial_reports.iter().all(|r| {
                let set: BTreeSet<&str> = r.per_query.iter().map(|(q, _)| q.as_str()).collect();
                set == first_task_set
            });
            if !consistent {
                warn!(grader = %gid, "skipping reliability: task coverage varies across trials");
                continue;
            }
            match ReliabilityReport::from_metric_reports(
                gid.clone(),
                self.pass_threshold,
                self.trials,
                trial_reports,
            ) {
                Ok(r) => {
                    reliability.insert(gid.clone(), r);
                }
                Err(err) => {
                    warn!(grader = %gid, %err, "reliability aggregation failed");
                }
            }
        }
        reliability
    }
}
/// Result of one (task, trial) execution: the runner's transcript plus every grader outcome.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrialRow {
    /// Id of the task that was run (cloned from the task definition).
    pub task_id: String,
    /// Zero-based trial index, in `0..trials`.
    pub trial: usize,
    /// Transcript returned by the agent runner for this trial.
    pub transcript: Transcript,
    /// Outcome keyed by grader id, one entry per grader listed on the task.
    pub outcomes: BTreeMap<String, GraderOutcome>,
}
/// Aggregated output of [`SkillHarness::run`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SkillEvalReport {
    /// Id of the evaluated task set.
    pub suite_id: String,
    /// Number of tasks in the task set.
    pub n_tasks: usize,
    /// Number of trials run per task.
    pub trials: usize,
    /// Threshold the harness was configured with (forwarded to reliability aggregation).
    pub pass_threshold: f64,
    /// Per grader id: one metric report per trial, in trial order, for the trials in which
    /// the grader produced outcomes.
    pub per_grader: BTreeMap<String, Vec<MetricReport>>,
    /// Per grader id: the cross-trial reliability summary. Graders whose aggregation was
    /// skipped or failed are absent.
    pub reliability: BTreeMap<String, ReliabilityReport>,
    /// Every trial row, sorted by trial index and then by task id.
    pub trials_log: Vec<TrialRow>,
}
impl SkillEvalReport {
    /// Returns the mean pass rate from the reliability report of `grader_id`.
    ///
    /// Yields `None` when no reliability report exists for that grader, for instance
    /// because its aggregation was skipped or failed during the run.
    #[must_use]
    pub fn mean_pass_rate(&self, grader_id: &str) -> Option<f64> {
        let report = self.reliability.get(grader_id)?;
        Some(report.mean_pass_rate)
    }
}