use std::pin::Pin;
use std::sync::Arc;
use rig::vector_store::request::Filter;
use rig::vector_store::{VectorSearchRequest, VectorStoreIndexDyn};
use serde_json::Value;
use crate::skills::grader::{AsyncGrader, GraderOutcome};
use crate::skills::task::SkillTask;
use crate::skills::transcript::Transcript;
/// Shared callback that derives the retrieval query string from a task and its transcript.
pub type GroundednessQueryFn = Arc<dyn Fn(&SkillTask, &Transcript) -> String + Send + Sync>;
/// Shared callback that scores an answer (first argument) against the retrieved context
/// strings (second argument) and returns the score as an `f64`.
pub type GroundednessScorerFn = Arc<dyn Fn(&str, &[String]) -> f64 + Send + Sync>;
/// Shared callback that turns a retrieved JSON document into the text used as context.
pub type DocumentExtractorFn = Arc<dyn Fn(&Value) -> String + Send + Sync>;
/// Builds the default retrieval-query callback.
///
/// The callback returns the transcript's final output, unmodified, unless that output is
/// empty or whitespace-only; in that case it returns the task prompt instead.
pub fn default_query_fn() -> GroundednessQueryFn {
    Arc::new(|task, transcript| {
        let answer = &transcript.final_output;
        // Trimming is only used for the emptiness test; the untrimmed text is what gets returned.
        match answer.trim() {
            "" => task.prompt.clone(),
            _ => answer.clone(),
        }
    })
}
/// Builds the default document-to-text callback.
///
/// When the document has a string value under the `"content"` key, that string is returned.
/// Otherwise the callback falls back to the document's `to_string()` rendering.
pub fn default_document_extractor() -> DocumentExtractorFn {
    Arc::new(|doc| match doc.get("content").and_then(Value::as_str) {
        Some(text) => text.to_owned(),
        None => doc.to_string(),
    })
}
/// Builds the default groundedness scorer.
///
/// The score is the fraction of distinct answer tokens that also occur among the tokens of
/// the retrieved contexts. An answer with no tokens scores `0.0`.
pub fn default_scorer() -> GroundednessScorerFn {
    Arc::new(|answer, contexts| {
        let wanted = tokenize_unique(answer);
        if wanted.is_empty() {
            // Nothing to ground; this also avoids dividing by zero below.
            return 0.0;
        }

        // Contexts are joined with a space so tokens from adjacent contexts never fuse.
        let available = tokenize_unique(&contexts.join(" "));
        let matched = wanted.intersection(&available).count();

        matched as f64 / wanted.len() as f64
    })
}
/// Splits `s` into its set of distinct, lowercased alphanumeric tokens.
///
/// Any character that is not alphanumeric (per [`char::is_alphanumeric`]) acts as a
/// separator, and empty fragments are discarded. Classification and lowercasing are
/// Unicode-aware, so accented words stay whole and non-Latin text still produces tokens.
/// For pure-ASCII input the result matches ASCII-only tokenization.
///
/// Returns a sorted set, so duplicates collapse and iteration order is deterministic.
fn tokenize_unique(s: &str) -> std::collections::BTreeSet<String> {
    // Split on the original text first, then lowercase each token. Lowercasing can expand a
    // character into several (including combining marks), which must not create new
    // separators.
    s.split(|c: char| !c.is_alphanumeric())
        .filter(|token| !token.is_empty())
        .map(str::to_lowercase)
        .collect()
}
/// Grader that retrieves documents from a vector store and scores how well the transcript's
/// final output is supported by them.
pub struct RetrievalGroundednessGrader {
/// Identifier returned by `id()` and attached to every outcome this grader produces.
id: String,
/// Vector store queried for supporting documents.
store: Arc<dyn VectorStoreIndexDyn>,
/// Number of samples requested from the store per query (set to 5 by `new`).
k: u64,
/// Minimum clamped score required for the outcome to pass (set to 0.5 by `new`).
pass_threshold: f64,
/// Produces the retrieval query from the task and transcript.
query_fn: GroundednessQueryFn,
/// Scores the final output against the extracted context strings.
scorer: GroundednessScorerFn,
/// Converts each retrieved document into a context string.
extractor: DocumentExtractorFn,
}
impl RetrievalGroundednessGrader {
    /// Creates a grader with the given identifier and vector store.
    ///
    /// Starts with `k = 5`, a pass threshold of `0.5`, and the default query, scorer, and
    /// document-extractor callbacks. Each can be replaced with the `with_*` builders.
    pub fn new(id: impl Into<String>, store: Arc<dyn VectorStoreIndexDyn>) -> Self {
        let id = id.into();
        let query_fn = default_query_fn();
        let scorer = default_scorer();
        let extractor = default_document_extractor();
        Self {
            id,
            store,
            k: 5,
            pass_threshold: 0.5,
            query_fn,
            scorer,
            extractor,
        }
    }

    /// Sets how many samples are requested from the vector store per query.
    pub fn with_k(self, k: u64) -> Self {
        Self { k, ..self }
    }

    /// Sets the minimum score required for an outcome to pass.
    pub fn with_pass_threshold(self, threshold: f64) -> Self {
        Self {
            pass_threshold: threshold,
            ..self
        }
    }

    /// Replaces the callback that derives the retrieval query.
    pub fn with_query_fn<F>(self, f: F) -> Self
    where
        F: Fn(&SkillTask, &Transcript) -> String + Send + Sync + 'static,
    {
        Self {
            query_fn: Arc::new(f),
            ..self
        }
    }

    /// Replaces the callback that scores the answer against the retrieved contexts.
    pub fn with_scorer<F>(self, f: F) -> Self
    where
        F: Fn(&str, &[String]) -> f64 + Send + Sync + 'static,
    {
        Self {
            scorer: Arc::new(f),
            ..self
        }
    }

    /// Replaces the callback that converts a retrieved document into context text.
    pub fn with_document_extractor<F>(self, f: F) -> Self
    where
        F: Fn(&Value) -> String + Send + Sync + 'static,
    {
        Self {
            extractor: Arc::new(f),
            ..self
        }
    }
}
impl AsyncGrader for RetrievalGroundednessGrader {
fn id(&self) -> &str {
&self.id
}
fn grade<'a>(
&'a self,
task: &'a SkillTask,
transcript: &'a Transcript,
) -> Pin<Box<dyn std::future::Future<Output = GraderOutcome> + Send + 'a>> {
let id = self.id.clone();
let threshold = self.pass_threshold;
let k = self.k;
let store = self.store.clone();
let query = (self.query_fn)(task, transcript);
let scorer = self.scorer.clone();
let extractor = self.extractor.clone();
let answer = transcript.final_output.clone();
Box::pin(async move {
if query.trim().is_empty() {
return GraderOutcome::skipped(id, "empty retrieval query");
}
let req: VectorSearchRequest<Filter<Value>> = VectorSearchRequest::builder()
.query(query)
.samples(k)
.build();
let hits = match store.top_n(req).await {
Ok(hits) => hits,
Err(err) => {
return GraderOutcome::fail(id, format!("retrieval error: {err}"));
}
};
let contexts: Vec<String> = hits.iter().map(|(_, _, doc)| extractor(doc)).collect();
let raw = scorer(&answer, &contexts).clamp(0.0, 1.0);
let passed = raw >= threshold;
let score = if passed { 1.0 } else { 0.0 };
let notes = format!("grounded_score={raw:.4}; k={k}");
GraderOutcome {
id,
score,
passed,
notes,
}
})
}
}