use std::sync::Arc;
use crate::ragas::{RagasInputs, RagasMetric, RagasScore};
use crate::skills::grader::{AsyncGrader, GraderOutcome};
use crate::skills::task::SkillTask;
use crate::skills::transcript::Transcript;
/// Shared, thread-safe callback that turns a skill task and the transcript of its run into
/// the [`RagasInputs`] handed to a RAGAS metric.
///
/// Stored behind an [`Arc`] so a grader can hold it without knowing the concrete closure type.
pub type RagasInputsFn =
    Arc<dyn Fn(&SkillTask, &Transcript) -> RagasInputs + Send + Sync>;
/// Grader that delegates judgement to a RAGAS metric `M` and turns the metric's score into a
/// binary pass/fail outcome.
///
/// Construct it with `RagasJudgeGrader::new` and tune it with the `with_*` builder methods.
pub struct RagasJudgeGrader<M> {
    /// Identifier reported on every outcome this grader produces.
    id: String,
    /// The metric asked to score each transcript.
    metric: M,
    /// Inclusive lower bound on the (clamped) judge score required for a pass.
    pass_threshold: f64,
    /// Builds the metric inputs from the task and transcript being graded.
    inputs_fn: RagasInputsFn,
}
impl<M: RagasMetric> RagasJudgeGrader<M> {
    /// Threshold applied until `with_pass_threshold` replaces it.
    const DEFAULT_PASS_THRESHOLD: f64 = 0.5;

    /// Creates a grader named `id` that scores transcripts with `metric`.
    ///
    /// The grader starts with a pass threshold of `0.5` and builds its metric inputs with
    /// [`default_inputs_fn`].
    pub fn new(id: impl Into<String>, metric: M) -> Self {
        let id: String = id.into();
        Self {
            id,
            metric,
            pass_threshold: Self::DEFAULT_PASS_THRESHOLD,
            inputs_fn: default_inputs_fn(),
        }
    }

    /// Returns the grader with its pass threshold replaced by `threshold`.
    ///
    /// During grading the judge score is clamped to `[0.0, 1.0]` and passes when it is
    /// `>= threshold`. A threshold `<= 0.0` therefore passes every scored run, one above
    /// `1.0` passes none, and a NaN threshold never passes because NaN comparisons are false.
    #[must_use]
    pub fn with_pass_threshold(self, threshold: f64) -> Self {
        Self {
            pass_threshold: threshold,
            ..self
        }
    }

    /// Returns the grader with a custom function for building the metric inputs.
    ///
    /// Use this when the metric needs data the default builder leaves empty, such as
    /// retrieved context or a reference answer.
    #[must_use]
    pub fn with_inputs_fn<F>(self, f: F) -> Self
    where
        F: Fn(&SkillTask, &Transcript) -> RagasInputs + Send + Sync + 'static,
    {
        Self {
            inputs_fn: Arc::new(f),
            ..self
        }
    }
}
pub fn default_inputs_fn() -> RagasInputsFn {
Arc::new(|task: &SkillTask, transcript: &Transcript| RagasInputs {
query_id: task.id.clone(),
query: task.prompt.clone(),
answer: Some(transcript.final_output.clone()),
context: Vec::new(),
reference_answer: None,
})
}
impl<M: RagasMetric + 'static> AsyncGrader for RagasJudgeGrader<M> {
fn id(&self) -> &str {
&self.id
}
fn grade<'a>(
&'a self,
task: &'a SkillTask,
transcript: &'a Transcript,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = GraderOutcome> + Send + 'a>> {
let id = self.id.clone();
let threshold = self.pass_threshold;
let inputs = (self.inputs_fn)(task, transcript);
Box::pin(async move {
match self.metric.score(&inputs).await {
Ok(RagasScore {
value: Some(v),
rationales,
}) => {
let raw = v.clamp(0.0, 1.0);
let passed = raw >= threshold;
let score = if passed { 1.0 } else { 0.0 };
let mut notes = format!("judge_score={raw:.4}");
if !rationales.is_empty() {
notes.push_str("; ");
notes.push_str(&rationales.join("; "));
}
GraderOutcome {
id,
score,
passed,
notes,
}
}
Ok(RagasScore {
value: None,
rationales,
}) => GraderOutcome::skipped(id, rationales.join("; ")),
Err(err) => GraderOutcome::fail(id, format!("judge error: {err}")),
}
})
}
}