use std::sync::Arc;
use async_trait::async_trait;
use entelix_core::ir::{ContentPart, Message, Role};
use entelix_core::{ExecutionContext, Result};
use entelix_memory::Document as RetrievedDocument;
use entelix_runnable::Runnable;
use serde::{Deserialize, Serialize};
/// Verdict a [`RetrievalGrader`] assigns to one retrieved document for a query.
///
/// Serialized in `snake_case` (`"correct"`, `"ambiguous"`, `"incorrect"`), the same
/// spelling returned by [`GradeVerdict::as_label`]. The enum is `#[non_exhaustive]`, so
/// matches outside this crate need a wildcard arm.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum GradeVerdict {
/// The document directly answers the query.
Correct,
/// The document is on-topic but does not directly answer the query.
Ambiguous,
/// The document does not answer the query or is off-topic.
Incorrect,
}
impl GradeVerdict {
    /// Returns `true` only for [`GradeVerdict::Correct`].
    ///
    /// Both `Ambiguous` and `Incorrect` are treated as non-actionable.
    #[must_use]
    pub const fn is_actionable(self) -> bool {
        match self {
            Self::Correct => true,
            Self::Ambiguous | Self::Incorrect => false,
        }
    }

    /// Returns the lowercase label for this verdict.
    ///
    /// The label is the word the grading model is asked to reply with.
    #[must_use]
    pub const fn as_label(self) -> &'static str {
        match self {
            Self::Incorrect => "incorrect",
            Self::Ambiguous => "ambiguous",
            Self::Correct => "correct",
        }
    }
}
/// Grades how well a single retrieved document answers a query.
///
/// The `Send + Sync` supertraits allow a grader to be shared across threads.
#[async_trait]
pub trait RetrievalGrader: Send + Sync {
/// Static identifier for this grader implementation.
fn name(&self) -> &'static str;
/// Grades `document` against `query`, returning a [`GradeVerdict`].
///
/// # Errors
///
/// Returns an error when grading itself fails. For example, the LLM-backed
/// implementation in this module propagates model invocation errors.
async fn grade(
&self,
query: &str,
document: &RetrievedDocument,
ctx: &ExecutionContext,
) -> Result<GradeVerdict>;
}
/// Default instruction sent to the grading model by [`LlmRetrievalGrader`].
///
/// It asks the model to reply with exactly one verdict label. Override it with
/// `LlmRetrievalGraderBuilder::with_instruction`. Replies that contain no recognised
/// label are graded as [`GradeVerdict::Ambiguous`].
pub const DEFAULT_GRADER_INSTRUCTION: &str = "\
You are a retrieval grader. Given a user query and one retrieved document, decide whether \
the document answers the query. Reply with exactly one of: `correct` (the document directly \
answers), `ambiguous` (the document is on-topic but does not directly answer), or `incorrect` \
(the document does not answer or is off-topic). Reply with only the single label — no \
explanation, no quotes, no surrounding text.";
// Identifier returned by `LlmRetrievalGrader`'s `RetrievalGrader::name`.
const LLM_GRADER_NAME: &str = "llm-retrieval-grader";
/// Builder for [`LlmRetrievalGrader`], obtained from `LlmRetrievalGrader::builder`.
pub struct LlmRetrievalGraderBuilder<M> {
// Shared handle to the chat model that performs the grading.
model: Arc<M>,
// Instruction placed first in every grading prompt; starts as DEFAULT_GRADER_INSTRUCTION.
instruction: String,
}
impl<M> LlmRetrievalGraderBuilder<M>
where
    M: Runnable<Vec<Message>, Message> + 'static,
{
    /// Replaces the grading instruction that is sent ahead of the query and document.
    #[must_use]
    pub fn with_instruction(self, instruction: impl Into<String>) -> Self {
        Self {
            instruction: instruction.into(),
            ..self
        }
    }

    /// Finishes the builder, freezing the instruction into a shared `Arc<str>`.
    #[must_use]
    pub fn build(self) -> LlmRetrievalGrader<M> {
        let Self { model, instruction } = self;
        LlmRetrievalGrader {
            model,
            instruction: instruction.into(),
        }
    }
}
/// [`RetrievalGrader`] that asks a chat model to label each document as
/// `correct`, `ambiguous` or `incorrect`.
///
/// Construct it with `LlmRetrievalGrader::builder`. Cloning only bumps reference
/// counts, because both fields are `Arc`s.
pub struct LlmRetrievalGrader<M> {
// Chat model invoked once per `grade` call.
model: Arc<M>,
// Instruction text sent as the first content part of every prompt.
instruction: Arc<str>,
}
impl<M> LlmRetrievalGrader<M>
where
    M: Runnable<Vec<Message>, Message> + 'static,
{
    /// Starts building a grader around `model`.
    ///
    /// The builder starts out using [`DEFAULT_GRADER_INSTRUCTION`]; change it with
    /// [`LlmRetrievalGraderBuilder::with_instruction`].
    #[must_use]
    pub fn builder(model: Arc<M>) -> LlmRetrievalGraderBuilder<M> {
        let instruction = String::from(DEFAULT_GRADER_INSTRUCTION);
        LlmRetrievalGraderBuilder { model, instruction }
    }

    /// Assembles the single user turn sent to the model.
    ///
    /// The turn contains three text parts, in order: the instruction, the query wrapped
    /// in `<query>` tags, and the document content wrapped in `<document>` tags.
    fn build_prompt(&self, query: &str, document: &RetrievedDocument) -> Vec<Message> {
        let instruction_part = ContentPart::text(self.instruction.to_string());
        let query_part = ContentPart::text(format!("<query>\n{query}\n</query>"));
        // NOTE(review): document text is embedded verbatim. Content containing
        // `</document>` could confuse the model about where the document ends — confirm
        // whether retrieved content needs escaping.
        let document_part =
            ContentPart::text(format!("<document>\n{}\n</document>", document.content));
        let user_turn = Message::new(
            Role::User,
            vec![instruction_part, query_part, document_part],
        );
        vec![user_turn]
    }
}
/// Hand-written so that cloning does not require `M: Clone`, which `#[derive(Clone)]`
/// would demand. Only the two `Arc` handles are duplicated; the model and the
/// instruction stay shared.
impl<M> Clone for LlmRetrievalGrader<M> {
    fn clone(&self) -> Self {
        let model = Arc::clone(&self.model);
        let instruction = Arc::clone(&self.instruction);
        Self { model, instruction }
    }
}
/// Hand-written so that `M` does not need to implement `Debug`.
///
/// No fields are shown, so the output is `LlmRetrievalGrader { .. }`.
impl<M> std::fmt::Debug for LlmRetrievalGrader<M> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("LlmRetrievalGrader");
        out.finish_non_exhaustive()
    }
}
#[async_trait]
impl<M> RetrievalGrader for LlmRetrievalGrader<M>
where
    M: Runnable<Vec<Message>, Message> + 'static,
{
    /// Returns the constant `"llm-retrieval-grader"`.
    fn name(&self) -> &'static str {
        LLM_GRADER_NAME
    }

    /// Invokes the model once with the built prompt and parses its reply into a verdict.
    ///
    /// # Errors
    ///
    /// Any error returned by the model invocation is passed through unchanged.
    async fn grade(
        &self,
        query: &str,
        document: &RetrievedDocument,
        ctx: &ExecutionContext,
    ) -> Result<GradeVerdict> {
        let messages = self.build_prompt(query, document);
        self.model
            .invoke(messages, ctx)
            .await
            .map(|reply| parse_verdict(&reply))
    }
}
/// Maps a grading model's reply onto a [`GradeVerdict`].
///
/// How the reply is read:
/// - The text parts of `message` are concatenated without a separator, so a label split
///   across parts is still recognised. Non-text parts are ignored.
/// - The result is lowercased and split into alphanumeric words.
/// - Labels count only as whole words. Replies such as `"correctly"`, `"correction"` or
///   `"uncorrected"` therefore no longer yield the actionable `Correct` verdict, which a
///   plain substring search would produce.
///
/// If several labels appear, the least actionable one wins: `incorrect` beats
/// `ambiguous`, which beats `correct`. A reply with no recognised label degrades to
/// [`GradeVerdict::Ambiguous`], so an unparseable reply is never actionable.
fn parse_verdict(message: &Message) -> GradeVerdict {
    let mut text = String::new();
    for part in &message.content {
        if let ContentPart::Text { text: t, .. } = part {
            text.push_str(t);
        }
    }
    let lower = text.to_lowercase();
    // Whole-word test: punctuation, whitespace, quotes and backticks all act as separators.
    let mentions = |label: &str| {
        lower
            .split(|c: char| !c.is_alphanumeric())
            .any(|word| word == label)
    };
    if mentions("incorrect") {
        GradeVerdict::Incorrect
    } else if mentions("ambiguous") {
        GradeVerdict::Ambiguous
    } else if mentions("correct") {
        GradeVerdict::Correct
    } else {
        GradeVerdict::Ambiguous
    }
}
// Unit tests for the verdict helpers, reply parsing, and `LlmRetrievalGrader` dispatch
// and error propagation.
#[cfg(test)]
mod tests {
use super::*;
use std::sync::Mutex;
// Builds a retrieved document with the given content.
fn doc(content: &str) -> RetrievedDocument {
RetrievedDocument::new(content)
}
// Builds an assistant message holding a single text part.
fn assistant(text: &str) -> Message {
Message::new(Role::Assistant, vec![ContentPart::text(text)])
}
// Test double that returns scripted replies in the order they were supplied.
struct ScriptedModel {
// Stored reversed so that `pop` hands replies out first-in, first-out.
script: Mutex<Vec<Result<Message>>>,
}
impl ScriptedModel {
fn new(replies: Vec<Message>) -> Self {
Self {
script: Mutex::new(replies.into_iter().map(Ok).rev().collect()),
}
}
}
#[async_trait]
impl Runnable<Vec<Message>, Message> for ScriptedModel {
// Ignores the prompt. Panics if called more times than replies were scripted.
// The std mutex guard is a temporary dropped within this statement; there is no await.
async fn invoke(&self, _input: Vec<Message>, _ctx: &ExecutionContext) -> Result<Message> {
self.script.lock().unwrap().pop().expect("script exhausted")
}
}
#[test]
fn verdict_is_actionable_only_for_correct() {
assert!(GradeVerdict::Correct.is_actionable());
assert!(!GradeVerdict::Ambiguous.is_actionable());
assert!(!GradeVerdict::Incorrect.is_actionable());
}
#[test]
fn verdict_label_round_trips() {
assert_eq!(GradeVerdict::Correct.as_label(), "correct");
assert_eq!(GradeVerdict::Ambiguous.as_label(), "ambiguous");
assert_eq!(GradeVerdict::Incorrect.as_label(), "incorrect");
}
#[test]
fn parser_accepts_canonical_lowercase() {
assert_eq!(parse_verdict(&assistant("correct")), GradeVerdict::Correct);
assert_eq!(
parse_verdict(&assistant("ambiguous")),
GradeVerdict::Ambiguous
);
assert_eq!(
parse_verdict(&assistant("incorrect")),
GradeVerdict::Incorrect
);
}
// Models often add punctuation, padding or capitalisation around the label.
#[test]
fn parser_tolerates_whitespace_punctuation_and_case() {
assert_eq!(parse_verdict(&assistant("Correct.")), GradeVerdict::Correct);
assert_eq!(
parse_verdict(&assistant(" AMBIGUOUS\n")),
GradeVerdict::Ambiguous
);
assert_eq!(
parse_verdict(&assistant("Verdict: incorrect")),
GradeVerdict::Incorrect
);
}
// "incorrect" contains "correct"; it must still parse as Incorrect, not Correct.
#[test]
fn parser_disambiguates_incorrect_from_correct() {
let reply = assistant("incorrect");
assert_eq!(parse_verdict(&reply), GradeVerdict::Incorrect);
}
// Replies without a recognisable label fall back to the non-actionable Ambiguous.
#[test]
fn parser_degrades_unknown_reply_to_ambiguous() {
assert_eq!(
parse_verdict(&assistant("the document looks fine to me")),
GradeVerdict::Ambiguous
);
assert_eq!(parse_verdict(&assistant("")), GradeVerdict::Ambiguous);
}
// End to end: the grader calls the model once and parses its reply.
#[tokio::test]
async fn grader_dispatches_through_model_and_returns_parsed_verdict() {
let model = Arc::new(ScriptedModel::new(vec![assistant("correct")]));
let grader = LlmRetrievalGrader::builder(model).build();
let verdict = grader
.grade(
"alpha?",
&doc("alpha is the first letter"),
&ExecutionContext::new(),
)
.await
.unwrap();
assert_eq!(verdict, GradeVerdict::Correct);
assert_eq!(grader.name(), LLM_GRADER_NAME);
}
// A model failure must surface from `grade` with its provider error kind intact.
#[tokio::test]
async fn grader_propagates_model_error() {
// Always fails with a provider HTTP 503 error.
struct FailingModel;
#[async_trait]
impl Runnable<Vec<Message>, Message> for FailingModel {
async fn invoke(
&self,
_input: Vec<Message>,
_ctx: &ExecutionContext,
) -> Result<Message> {
Err(entelix_core::Error::provider_http(503, "transient"))
}
}
let grader = LlmRetrievalGrader::builder(Arc::new(FailingModel)).build();
let err = grader
.grade("query", &doc("text"), &ExecutionContext::new())
.await
.unwrap_err();
assert!(matches!(
err,
entelix_core::Error::Provider {
kind: entelix_core::ProviderErrorKind::Http(503),
..
}
));
}
}