use crate::skills::grader::{Grader, GraderOutcome};
use crate::skills::task::SkillTask;
use crate::skills::transcript::Transcript;
/// Grader that checks whether a transcript's final output contains a given
/// substring or, in the inverted form, that it does *not* contain it.
#[derive(Debug, Clone)]
pub struct ContainsGrader {
    /// Identifier attached to every outcome this grader produces.
    id: String,
    /// Substring searched for in the final output.
    needle: String,
    /// `true` when the needle must appear, `false` when it must be absent.
    expect_present: bool,
}

impl ContainsGrader {
    /// Builds a grader that passes only when `needle` appears in the final output.
    pub fn present(id: impl Into<String>, needle: impl Into<String>) -> Self {
        Self::with_expectation(id.into(), needle.into(), true)
    }

    /// Builds a grader that passes only when `needle` does NOT appear in the final output.
    pub fn absent(id: impl Into<String>, needle: impl Into<String>) -> Self {
        Self::with_expectation(id.into(), needle.into(), false)
    }

    /// Shared constructor behind [`present`](Self::present) and [`absent`](Self::absent).
    fn with_expectation(id: String, needle: String, expect_present: bool) -> Self {
        Self {
            id,
            needle,
            expect_present,
        }
    }
}
impl Grader for ContainsGrader {
    fn id(&self) -> &str {
        self.id.as_str()
    }

    /// Passes when the presence of `needle` in `transcript.final_output`
    /// agrees with the expectation chosen at construction time; otherwise
    /// fails with a message naming the offending substring.
    fn grade(&self, _task: &SkillTask, transcript: &Transcript) -> GraderOutcome {
        let found = transcript.final_output.contains(&self.needle);
        match (self.expect_present, found) {
            (true, false) => GraderOutcome::fail(
                &self.id,
                format!("substring {:?} not found", self.needle),
            ),
            (false, true) => GraderOutcome::fail(
                &self.id,
                format!("forbidden substring {:?} present", self.needle),
            ),
            // Observation matches the expectation in either direction.
            (true, true) | (false, false) => GraderOutcome::pass(&self.id),
        }
    }
}
/// Grader that requires a named tool to have been called a minimum number of
/// times during the run.
#[derive(Debug, Clone)]
pub struct ToolCallGrader {
    /// Identifier attached to every outcome this grader produces.
    id: String,
    /// Tool name matched against each recorded call's `name`.
    tool: String,
    /// Minimum matching calls required to pass; always at least 1.
    min_invocations: usize,
}

impl ToolCallGrader {
    /// Requires `tool` to be invoked at least once.
    pub fn at_least_once(id: impl Into<String>, tool: impl Into<String>) -> Self {
        Self::at_least(id, tool, 1)
    }

    /// Requires `tool` to be invoked at least `n` times.
    ///
    /// `n` is clamped to a minimum of 1, so `n == 0` behaves like
    /// [`at_least_once`](Self::at_least_once) rather than producing a grader
    /// that can never fail.
    pub fn at_least(id: impl Into<String>, tool: impl Into<String>, n: usize) -> Self {
        Self {
            id: id.into(),
            tool: tool.into(),
            min_invocations: if n == 0 { 1 } else { n },
        }
    }
}
impl Grader for ToolCallGrader {
    fn id(&self) -> &str {
        self.id.as_str()
    }

    /// Counts recorded tool calls whose `name` equals the configured tool and
    /// fails when that count is below `min_invocations`.
    fn grade(&self, _task: &SkillTask, transcript: &Transcript) -> GraderOutcome {
        let invocations = transcript
            .tool_calls
            .iter()
            .fold(0usize, |seen, call| seen + usize::from(call.name == self.tool));

        if invocations < self.min_invocations {
            return GraderOutcome::fail(
                &self.id,
                format!(
                    "tool {:?} invoked {} time(s); required {}",
                    self.tool, invocations, self.min_invocations
                ),
            );
        }
        GraderOutcome::pass(&self.id)
    }
}
/// Resource budget for a single run. Each limit left as `None` is not enforced.
#[derive(Debug, Clone, Default)]
pub struct TranscriptBudget {
    /// Identifier attached to every outcome this grader produces.
    id: String,
    /// Maximum number of turns the transcript may report.
    pub max_turns: Option<usize>,
    /// Maximum number of recorded tool calls.
    pub max_tool_calls: Option<usize>,
    /// Maximum total tokens, compared against the transcript's usage data.
    pub max_total_tokens: Option<u64>,
    /// Maximum spend in US dollars, compared against the transcript's usage data.
    pub max_cost_usd: Option<f64>,
}

impl TranscriptBudget {
    /// Creates a budget with the given id and no limits set. Configure limits
    /// by assigning the public fields or with struct-update syntax.
    pub fn new(id: impl Into<String>) -> Self {
        Self {
            id: id.into(),
            max_turns: None,
            max_tool_calls: None,
            max_total_tokens: None,
            max_cost_usd: None,
        }
    }
}
impl Grader for TranscriptBudget {
    fn id(&self) -> &str {
        &self.id
    }

    /// Checks the transcript against every configured limit and reports all
    /// violations together.
    ///
    /// A limit set to `None` is not enforced. A limit is also skipped when the
    /// transcript lacks the data it needs (`turns` is `None`, there is no
    /// `usage`, or `usage` has no `cost_usd`), so an unrecorded metric never
    /// fails the budget. When there are violations, they are joined with
    /// `"; "` into a single failure message.
    fn grade(&self, _task: &SkillTask, transcript: &Transcript) -> GraderOutcome {
        let mut violations = Vec::new();
        if let (Some(limit), Some(turns)) = (self.max_turns, transcript.turns)
            && turns > limit
        {
            violations.push(format!("turns {turns} > {limit}"));
        }
        if let Some(limit) = self.max_tool_calls {
            let n = transcript.tool_calls.len();
            if n > limit {
                violations.push(format!("tool_calls {n} > {limit}"));
            }
        }
        if let (Some(limit), Some(usage)) = (self.max_total_tokens, transcript.usage.as_ref()) {
            let total = usage.total_tokens();
            if total > limit {
                violations.push(format!("total_tokens {total} > {limit}"));
            }
        }
        if let (Some(limit), Some(cost)) = (
            self.max_cost_usd,
            transcript.usage.as_ref().and_then(|u| u.cost_usd),
        ) {
            // Every comparison with NaN is false, so a bare `cost > limit`
            // would let a NaN cost (e.g. from a broken price calculation)
            // pass the budget silently. Report it as a violation instead.
            if cost.is_nan() {
                violations.push(format!("cost_usd is NaN (limit {limit})"));
            } else if cost > limit {
                violations.push(format!("cost_usd {cost} > {limit}"));
            }
        }
        if violations.is_empty() {
            GraderOutcome::pass(&self.id)
        } else {
            GraderOutcome::fail(&self.id, violations.join("; "))
        }
    }
}
/// Grader that compares the skill a task expected to trigger with the skill
/// the runner recorded as invoked.
#[derive(Debug, Clone)]
pub struct TriggerGrader {
    /// Identifier attached to every outcome this grader produces.
    id: String,
}

impl TriggerGrader {
    /// Creates a trigger grader reporting under `id`.
    pub fn new(id: impl Into<String>) -> Self {
        let id = id.into();
        Self { id }
    }
}

impl Default for TriggerGrader {
    /// Uses the id `"trigger"`.
    fn default() -> Self {
        Self::new(String::from("trigger"))
    }
}
impl Grader for TriggerGrader {
    fn id(&self) -> &str {
        self.id.as_str()
    }

    /// Compares `task.should_trigger` with `transcript.skill_invoked`.
    ///
    /// - Expected and observed skills are equal: pass.
    /// - A skill was expected but a different one, or none, was recorded: fail.
    /// - No skill was expected but one was recorded: fail.
    /// - Neither side carries a skill: skipped, because an absent observation
    ///   cannot be told apart from the runner not recording routing.
    fn grade(&self, task: &SkillTask, transcript: &Transcript) -> GraderOutcome {
        let observed = transcript.skill_invoked.as_deref();

        let Some(expected) = task.should_trigger.as_deref() else {
            return match observed {
                Some(actual) => GraderOutcome::fail(
                    &self.id,
                    format!("expected no skill to fire, got {actual:?}"),
                ),
                None => {
                    GraderOutcome::skipped(&self.id, "no expected skill and no observed routing")
                }
            };
        };

        match observed {
            Some(actual) if actual == expected => GraderOutcome::pass(&self.id),
            Some(actual) => GraderOutcome::fail(
                &self.id,
                format!("expected skill {expected:?}, got {actual:?}"),
            ),
            None => GraderOutcome::fail(
                &self.id,
                format!("expected skill {expected:?}, runner did not record routing"),
            ),
        }
    }
}