use std::sync::Arc;
use entelix_agents::Agent;
use entelix_core::ir::{ContentPart, Message, Role, SystemPrompt};
use entelix_core::{ExecutionContext, Result};
use entelix_graph::{CompiledGraph, StateGraph};
use entelix_memory::{Document as RetrievedDocument, RetrievalQuery, Retriever};
use entelix_runnable::{Runnable, RunnableLambda};
use crate::corrective::grader::{GradeVerdict, RetrievalGrader};
use crate::corrective::rewriter::QueryRewriter;
/// Default value for [`CragConfig`]'s `min_correct_fraction`: the share of graded
/// documents that must be graded `Correct` for the router to go straight to generation.
pub const DEFAULT_MIN_CORRECT_FRACTION: f32 = 0.5;
/// Default value for [`CragConfig`]'s `retrieval_top_k`, handed to `RetrievalQuery::new`
/// on every retrieval.
pub const DEFAULT_RETRIEVAL_TOP_K: usize = 5;
/// Default value for [`CragConfig`]'s `max_rewrite_attempts`: the router stops choosing
/// the rewrite node once the state's `attempt` counter reaches this value.
pub const DEFAULT_MAX_REWRITE_ATTEMPTS: u32 = 3;
/// Default system prompt for the generator; it instructs the model to answer only from
/// the supplied documents and to say so when they are insufficient.
pub const DEFAULT_GENERATOR_SYSTEM_PROMPT: &str = "\
You are a helpful assistant. Answer the user's question using only the supplied retrieved \
documents as your evidence base. If the documents don't contain enough information to \
answer with confidence, say so explicitly. Never fabricate facts that the documents do \
not support.";
/// Tunable settings for the corrective-RAG graph assembled by
/// [`build_corrective_rag_graph`].
///
/// Fields are private; use the `with_*` builders to change them and the same-named
/// getters to read them.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub struct CragConfig {
/// Fraction of graded documents that must be `Correct` before the router picks the
/// generate node. The router clamps this to `[0.0, 1.0]` when it reads it.
min_correct_fraction: f32,
/// Second argument passed to `RetrievalQuery::new` for every retrieval.
retrieval_top_k: usize,
/// The router only picks the rewrite node while the state's `attempt` is below this.
max_rewrite_attempts: u32,
/// System prompt for the generator; each of its blocks becomes one `Role::System`
/// message at the front of the generator input.
generator_system_prompt: SystemPrompt,
}
impl CragConfig {
/// Creates a configuration holding the `DEFAULT_*` values (delegates to `Default`).
#[must_use]
pub fn new() -> Self {
Self::default()
}
/// Returns the configuration with `min_correct_fraction` replaced by `fraction`.
///
/// The value is stored as given; the router clamps it to `[0.0, 1.0]` when routing.
#[must_use]
pub const fn with_min_correct_fraction(mut self, fraction: f32) -> Self {
self.min_correct_fraction = fraction;
self
}
/// Returns the configuration with `retrieval_top_k` replaced by `top_k`.
#[must_use]
pub const fn with_retrieval_top_k(mut self, top_k: usize) -> Self {
self.retrieval_top_k = top_k;
self
}
/// Returns the configuration with `max_rewrite_attempts` replaced by `max`.
///
/// With `0` the router never selects the rewrite node.
#[must_use]
pub const fn with_max_rewrite_attempts(mut self, max: u32) -> Self {
self.max_rewrite_attempts = max;
self
}
/// Returns the configuration with the generator system prompt replaced by `prompt`.
#[must_use]
pub fn with_generator_system_prompt(mut self, prompt: SystemPrompt) -> Self {
self.generator_system_prompt = prompt;
self
}
/// The stored (unclamped) minimum fraction of `Correct` documents.
#[must_use]
pub const fn min_correct_fraction(&self) -> f32 {
self.min_correct_fraction
}
/// The `top_k` value handed to the retriever query.
#[must_use]
pub const fn retrieval_top_k(&self) -> usize {
self.retrieval_top_k
}
/// The maximum number of query rewrites the router will allow.
#[must_use]
pub const fn max_rewrite_attempts(&self) -> u32 {
self.max_rewrite_attempts
}
/// The system prompt used when building the generator input.
#[must_use]
pub const fn generator_system_prompt(&self) -> &SystemPrompt {
&self.generator_system_prompt
}
}
impl Default for CragConfig {
    /// Builds a configuration from the published `DEFAULT_*` constants.
    fn default() -> Self {
        let generator_system_prompt = SystemPrompt::text(DEFAULT_GENERATOR_SYSTEM_PROMPT);
        Self {
            generator_system_prompt,
            max_rewrite_attempts: DEFAULT_MAX_REWRITE_ATTEMPTS,
            retrieval_top_k: DEFAULT_RETRIEVAL_TOP_K,
            min_correct_fraction: DEFAULT_MIN_CORRECT_FRACTION,
        }
    }
}
/// State threaded through every node of the corrective-RAG graph.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub struct CorrectiveRagState {
/// The query the run started with. It is passed to the rewriter and embedded in the
/// generator prompt; no node in this file modifies it.
pub original_query: String,
/// The query the retrieve node currently searches with. Starts equal to
/// `original_query`; the rewrite node replaces it.
pub query: String,
/// Queries that have been superseded by a rewrite, oldest first.
pub previous_attempts: Vec<String>,
/// Documents from the latest retrieval paired with their verdicts. The retrieve node
/// fills this with `Ambiguous` placeholders; the grade node replaces the verdicts.
pub graded: Vec<(RetrievedDocument, GradeVerdict)>,
/// Clones of the `graded` documents whose verdict is `Correct`. Cleared by every
/// retrieval.
pub correct_documents: Vec<RetrievedDocument>,
/// Number of rewrites performed so far (saturating counter).
pub attempt: u32,
/// Text of the generator's reply; `None` until the generate node has run.
pub answer: Option<String>,
}
impl CorrectiveRagState {
    /// Creates the initial state for a run: `query` is stored both as the original
    /// query and as the current query, and every other field starts empty / zero.
    #[must_use]
    pub fn from_query(query: impl Into<String>) -> Self {
        let original_query: String = query.into();
        Self {
            query: original_query.clone(),
            original_query,
            attempt: 0,
            answer: None,
            previous_attempts: Vec::new(),
            graded: Vec::new(),
            correct_documents: Vec::new(),
        }
    }
}
/// Assembles and compiles the corrective-RAG state graph.
///
/// Topology: `retrieve` is the entry point and always flows into `grade`. After
/// grading, [`route_after_grade`] picks either `rewrite` (which loops back to
/// `retrieve`) or `generate`, which is the finish point.
///
/// # Errors
///
/// Returns whatever error `StateGraph::compile` reports.
pub fn build_corrective_rag_graph<Ret, G, R, M>(
    retriever: Arc<Ret>,
    grader: G,
    rewriter: R,
    generator: M,
    config: CragConfig,
) -> Result<CompiledGraph<CorrectiveRagState>>
where
    Ret: Retriever + ?Sized + 'static,
    G: RetrievalGrader + 'static,
    R: QueryRewriter + 'static,
    M: Runnable<Vec<Message>, Message> + 'static,
{
    let config = Arc::new(config);

    // Build the four node runnables up front; each takes shared ownership of its
    // collaborator.
    let retrieve_node = make_retriever_node(retriever, &config);
    let grade_node = make_grader_node(Arc::new(grader));
    let rewrite_node = make_rewriter_node(Arc::new(rewriter));
    let generate_node = make_generator_node(Arc::new(generator), &config);

    // The router is the last user of `config`, so it can own the handle outright.
    let router = move |state: &CorrectiveRagState| route_after_grade(state, &config);

    StateGraph::<CorrectiveRagState>::new()
        .add_node(NODE_RETRIEVE, retrieve_node)
        .add_node(NODE_GRADE, grade_node)
        .add_node(NODE_REWRITE, rewrite_node)
        .add_node(NODE_GENERATE, generate_node)
        .set_entry_point(NODE_RETRIEVE)
        .add_finish_point(NODE_GENERATE)
        .add_edge(NODE_RETRIEVE, NODE_GRADE)
        .add_edge(NODE_REWRITE, NODE_RETRIEVE)
        .add_conditional_edges(
            NODE_GRADE,
            router,
            [(NODE_REWRITE, NODE_REWRITE), (NODE_GENERATE, NODE_GENERATE)],
        )
        .compile()
}
/// Builds the `retrieve` node.
///
/// The node queries `retriever` with the state's current `query` and the configured
/// `top_k`, stores every hit in `graded` with an `Ambiguous` placeholder verdict, and
/// empties `correct_documents` so results from an earlier pass do not leak through.
fn make_retriever_node<Ret>(
    retriever: Arc<Ret>,
    config: &Arc<CragConfig>,
) -> impl Runnable<CorrectiveRagState, CorrectiveRagState> + 'static
where
    Ret: Retriever + ?Sized + 'static,
{
    // `top_k` is a plain `usize`, so copy it once instead of keeping the config alive.
    let top_k = config.retrieval_top_k();
    RunnableLambda::new(
        move |mut state: CorrectiveRagState, ctx: ExecutionContext| {
            let retriever = Arc::clone(&retriever);
            async move {
                let request = RetrievalQuery::new(state.query.clone(), top_k);
                let hits = retriever.retrieve(request, &ctx).await?;
                state.correct_documents.clear();
                state.graded = hits
                    .into_iter()
                    .map(|hit| (hit, GradeVerdict::Ambiguous))
                    .collect();
                Ok::<_, _>(state)
            }
        },
    )
}
/// Builds the `grade` node.
///
/// Each document in `graded` is graded, one at a time and in order, against the
/// state's current `query`. The node then stores the fresh verdicts back into `graded`
/// and puts clones of the `Correct` documents into `correct_documents`. The first
/// grading error aborts the node.
fn make_grader_node<G>(
    grader: Arc<G>,
) -> impl Runnable<CorrectiveRagState, CorrectiveRagState> + 'static
where
    G: RetrievalGrader + ?Sized + 'static,
{
    RunnableLambda::new(
        move |mut state: CorrectiveRagState, ctx: ExecutionContext| {
            let grader = Arc::clone(&grader);
            async move {
                // Move the documents out so they can be re-paired with real verdicts.
                let pending = std::mem::take(&mut state.graded);
                let mut graded = Vec::with_capacity(pending.len());
                let mut accepted = Vec::new();
                for (document, _placeholder) in pending {
                    let verdict = grader.grade(&state.query, &document, &ctx).await?;
                    if matches!(verdict, GradeVerdict::Correct) {
                        accepted.push(document.clone());
                    }
                    graded.push((document, verdict));
                }
                state.correct_documents = accepted;
                state.graded = graded;
                Ok::<_, _>(state)
            }
        },
    )
}
/// Builds the `rewrite` node.
///
/// The node records the query that just failed grading in `previous_attempts`, asks
/// `rewriter` for a replacement, installs it as the current `query`, and bumps
/// `attempt` (saturating).
///
/// The failed query is recorded *before* the rewriter is called so the history handed
/// to the rewriter includes the most recent rejected query. Recording it afterwards
/// would hide that query from the rewriter on every pass after the first, letting it
/// propose the same reformulation again. The state produced on success is the same
/// either way.
///
/// A rewriter error aborts the node.
fn make_rewriter_node<R>(
    rewriter: Arc<R>,
) -> impl Runnable<CorrectiveRagState, CorrectiveRagState> + 'static
where
    R: QueryRewriter + ?Sized + 'static,
{
    RunnableLambda::new(
        move |mut state: CorrectiveRagState, ctx: ExecutionContext| {
            let rewriter = Arc::clone(&rewriter);
            async move {
                // The current query is the one the grader just rejected; make it part
                // of the history the rewriter sees.
                state.previous_attempts.push(state.query.clone());
                let new_query = rewriter
                    .rewrite(&state.original_query, &state.previous_attempts, &ctx)
                    .await?;
                state.query = new_query;
                state.attempt = state.attempt.saturating_add(1);
                Ok::<_, _>(state)
            }
        },
    )
}
/// Builds the `generate` node.
///
/// The node assembles the generator input with [`build_generator_prompt`], invokes
/// `generator`, and stores the reply's text (see [`extract_text`]) in `answer`.
/// A generator error aborts the node.
fn make_generator_node<M>(
    generator: Arc<M>,
    config: &Arc<CragConfig>,
) -> impl Runnable<CorrectiveRagState, CorrectiveRagState> + 'static
where
    M: Runnable<Vec<Message>, Message> + ?Sized + 'static,
{
    let config = Arc::clone(config);
    RunnableLambda::new(
        move |mut state: CorrectiveRagState, ctx: ExecutionContext| {
            // Each invocation gets its own handles so the future owns what it uses.
            let (generator, config) = (Arc::clone(&generator), Arc::clone(&config));
            async move {
                let prompt = build_generator_prompt(
                    &state.original_query,
                    &state.correct_documents,
                    &state.graded,
                    &config,
                );
                let reply = generator.invoke(prompt, &ctx).await?;
                state.answer = Some(extract_text(&reply));
                Ok::<_, _>(state)
            }
        },
    )
}
/// Name given to the agent built by [`create_corrective_rag_agent`].
pub const CORRECTIVE_RAG_AGENT_NAME: &str = "corrective-rag";
// Node identifiers used when wiring the graph; the router returns the last two.
/// Node that runs the retriever.
const NODE_RETRIEVE: &str = "retrieve";
/// Node that grades the retrieved documents.
const NODE_GRADE: &str = "grade";
/// Node that rewrites the query.
const NODE_REWRITE: &str = "rewrite";
/// Node that produces the final answer; also the graph's finish point.
const NODE_GENERATE: &str = "generate";
/// Builds the corrective-RAG graph and wraps it in an [`Agent`] named
/// [`CORRECTIVE_RAG_AGENT_NAME`].
///
/// # Errors
///
/// Returns the error from [`build_corrective_rag_graph`] if the graph fails to
/// compile, or the error from the agent builder's `build`.
pub fn create_corrective_rag_agent<Ret, G, R, M>(
    retriever: Arc<Ret>,
    grader: G,
    rewriter: R,
    generator: M,
    config: CragConfig,
) -> Result<Agent<CorrectiveRagState>>
where
    Ret: Retriever + ?Sized + 'static,
    G: RetrievalGrader + 'static,
    R: QueryRewriter + 'static,
    M: Runnable<Vec<Message>, Message> + 'static,
{
    let compiled = build_corrective_rag_graph(retriever, grader, rewriter, generator, config)?;
    let builder = Agent::builder()
        .with_name(CORRECTIVE_RAG_AGENT_NAME)
        .with_runnable(compiled);
    builder.build()
}
/// Chooses the node that follows `grade`.
///
/// Returns the rewrite node's name only when the retrieval is not good enough *and*
/// rewrite budget remains; in every other case it returns the generate node's name.
/// "Good enough" means at least one document is `Correct` and the `Correct` share of
/// graded documents reaches the configured threshold (clamped to `[0.0, 1.0]`).
/// An empty retrieval has a share of `0.0`.
fn route_after_grade(state: &CorrectiveRagState, config: &CragConfig) -> String {
    let graded_count = state.graded.len();
    let correct_count = state.correct_documents.len();

    // Counts are small document tallies, so the lossy usize -> f32 cast is acceptable.
    #[allow(clippy::cast_precision_loss)]
    let correct_share = match graded_count {
        0 => 0.0_f32,
        n => correct_count as f32 / n as f32,
    };

    let required_share = config.min_correct_fraction.clamp(0.0, 1.0);
    let good_enough = correct_count > 0 && correct_share >= required_share;
    let can_rewrite = state.attempt < config.max_rewrite_attempts;

    let next = if !good_enough && can_rewrite {
        NODE_REWRITE
    } else {
        NODE_GENERATE
    };
    next.to_owned()
}
/// Builds the message list handed to the generator.
///
/// Evidence selection, in order of preference:
/// 1. `correct_documents`, when non-empty;
/// 2. otherwise the `Ambiguous` documents from `fallback_graded`;
/// 3. otherwise every document in `fallback_graded`, whatever its verdict.
///
/// The result is one `Role::System` message per block of the configured system prompt,
/// followed by a single `Role::User` message with two text parts: the original query
/// wrapped in `<query>` tags and the formatted evidence wrapped in `<documents>` tags.
fn build_generator_prompt(
    original_query: &str,
    correct_documents: &[RetrievedDocument],
    fallback_graded: &[(RetrievedDocument, GradeVerdict)],
    config: &CragConfig,
) -> Vec<Message> {
    let selected: Vec<&RetrievedDocument> = if correct_documents.is_empty() {
        let ambiguous: Vec<&RetrievedDocument> = fallback_graded
            .iter()
            .filter(|(_, verdict)| matches!(verdict, GradeVerdict::Ambiguous))
            .map(|(document, _)| document)
            .collect();
        if ambiguous.is_empty() {
            // Nothing even ambiguous: fall back to everything that was retrieved.
            fallback_graded.iter().map(|(document, _)| document).collect()
        } else {
            ambiguous
        }
    } else {
        correct_documents.iter().collect()
    };
    let evidence = format_documents(&selected);

    let user_turn = Message::new(
        Role::User,
        vec![
            ContentPart::text(format!("<query>\n{original_query}\n</query>")),
            ContentPart::text(format!("<documents>\n{evidence}\n</documents>")),
        ],
    );

    let messages: Vec<Message> = config
        .generator_system_prompt
        .blocks()
        .iter()
        .map(|block| Message::new(Role::System, vec![ContentPart::text(block.text.clone())]))
        .chain(std::iter::once(user_turn))
        .collect();
    messages
}
/// Renders documents as `[index] content` entries (zero-based index) separated by a
/// blank line. An empty slice yields an empty string.
fn format_documents(docs: &[&RetrievedDocument]) -> String {
    docs.iter()
        .enumerate()
        .map(|(position, document)| format!("[{position}] {}", document.content))
        .collect::<Vec<_>>()
        .join("\n\n")
}
/// Concatenates the text parts of `message`, newline-separated, ignoring non-text
/// parts, and trims surrounding whitespace from the result.
fn extract_text(message: &Message) -> String {
    let joined = message
        .content
        .iter()
        .filter_map(|part| match part {
            ContentPart::Text { text, .. } => Some(text),
            _ => None,
        })
        .fold(String::new(), |mut acc, piece| {
            // Separator only once something has been written.
            if !acc.is_empty() {
                acc.push('\n');
            }
            acc.push_str(piece);
            acc
        });
    joined.trim().to_owned()
}
#[cfg(test)]
mod tests {
use super::*;
use crate::corrective::grader::RetrievalGrader;
use crate::corrective::rewriter::QueryRewriter;
use async_trait::async_trait;
use entelix_memory::Document as RetrievedDocument;
use std::sync::Mutex;
/// Test retriever that serves canned documents per exact query text and records every
/// query it is asked for.
struct StaticRetriever {
    docs_by_query: std::collections::HashMap<String, Vec<RetrievedDocument>>,
    observed_queries: Mutex<Vec<String>>,
}

impl StaticRetriever {
    /// Empty retriever: no canned documents, no recorded queries.
    fn new() -> Self {
        Self {
            observed_queries: Mutex::new(Vec::new()),
            docs_by_query: std::collections::HashMap::new(),
        }
    }

    /// Registers `docs` as the result for `query` (replacing any earlier entry).
    fn with(mut self, query: &str, docs: Vec<RetrievedDocument>) -> Self {
        self.docs_by_query.insert(query.to_owned(), docs);
        self
    }

    /// Snapshot of the query texts seen so far, in call order.
    fn observed(&self) -> Vec<String> {
        let seen = self.observed_queries.lock().unwrap();
        seen.to_vec()
    }
}

#[async_trait]
impl Retriever for StaticRetriever {
    /// Records the query text, then returns the canned documents for it, or an empty
    /// list when the query is unknown.
    async fn retrieve(
        &self,
        query: RetrievalQuery,
        _ctx: &ExecutionContext,
    ) -> Result<Vec<RetrievedDocument>> {
        {
            let mut seen = self.observed_queries.lock().unwrap();
            seen.push(query.text.clone());
        }
        let hits = match self.docs_by_query.get(&query.text) {
            Some(docs) => docs.clone(),
            None => Vec::new(),
        };
        Ok(hits)
    }
}
/// Test grader that looks up a fixed verdict by document content.
struct ScriptedGrader {
    verdicts: std::collections::HashMap<String, GradeVerdict>,
}

impl ScriptedGrader {
    /// Builds the lookup table from `(content, verdict)` pairs; a later pair with the
    /// same content overrides an earlier one.
    fn new(map: &[(&str, GradeVerdict)]) -> Self {
        let mut verdicts = std::collections::HashMap::new();
        for &(content, verdict) in map {
            verdicts.insert(content.to_owned(), verdict);
        }
        Self { verdicts }
    }
}

#[async_trait]
impl RetrievalGrader for ScriptedGrader {
    fn name(&self) -> &'static str {
        "scripted-grader"
    }

    /// Returns the scripted verdict for the document's content, or `Ambiguous` when
    /// the content was not scripted. The query is ignored.
    async fn grade(
        &self,
        _query: &str,
        doc: &RetrievedDocument,
        _ctx: &ExecutionContext,
    ) -> Result<GradeVerdict> {
        let verdict = match self.verdicts.get(&doc.content) {
            Some(scripted) => *scripted,
            None => GradeVerdict::Ambiguous,
        };
        Ok(verdict)
    }
}
/// Test rewriter that hands out a fixed sequence of replies, one per call.
struct ScriptedRewriter {
    /// Remaining replies stored in reverse so the next one is at the end.
    replies: Mutex<Vec<String>>,
}

impl ScriptedRewriter {
    /// `replies` are returned in the order given.
    fn new(replies: &[&str]) -> Self {
        let mut stack: Vec<String> = replies.iter().map(|reply| (*reply).to_owned()).collect();
        stack.reverse();
        Self {
            replies: Mutex::new(stack),
        }
    }
}

#[async_trait]
impl QueryRewriter for ScriptedRewriter {
    fn name(&self) -> &'static str {
        "scripted-rewriter"
    }

    /// Returns the next scripted reply, or `"<exhausted>"` once the script has run
    /// out. All arguments are ignored.
    async fn rewrite(
        &self,
        _original: &str,
        _previous: &[String],
        _ctx: &ExecutionContext,
    ) -> Result<String> {
        let next = self.replies.lock().unwrap().pop();
        Ok(match next {
            Some(reply) => reply,
            None => "<exhausted>".to_owned(),
        })
    }
}
/// Test generator that records every prompt it receives and always answers with the
/// same text. Clones share the same recording buffer.
#[derive(Clone)]
struct CapturingGenerator {
    observed: Arc<Mutex<Vec<Vec<Message>>>>,
    reply: String,
}

impl CapturingGenerator {
    /// Generator that will answer every prompt with `reply`.
    fn new(reply: &str) -> Self {
        let observed = Arc::new(Mutex::new(Vec::new()));
        Self {
            observed,
            reply: reply.to_owned(),
        }
    }

    /// Snapshot of all prompts received so far, in call order.
    fn observed(&self) -> Vec<Vec<Message>> {
        let prompts = self.observed.lock().unwrap();
        Vec::clone(&prompts)
    }
}

#[async_trait]
impl Runnable<Vec<Message>, Message> for CapturingGenerator {
    /// Records `input`, then returns an assistant message carrying the fixed reply.
    async fn invoke(&self, input: Vec<Message>, _ctx: &ExecutionContext) -> Result<Message> {
        {
            let mut prompts = self.observed.lock().unwrap();
            prompts.push(input);
        }
        let answer = Message::new(
            Role::Assistant,
            vec![ContentPart::text(self.reply.clone())],
        );
        Ok(answer)
    }
}
/// Test shorthand: builds a retrieved document from `content` via
/// `RetrievedDocument::new`.
fn doc(content: &str) -> RetrievedDocument {
RetrievedDocument::new(content)
}
/// When every retrieved document is graded `Correct`, the agent answers after a single
/// retrieval and never rewrites the query.
#[tokio::test]
async fn happy_path_generates_directly_when_all_correct() {
    let question = "what is alpha?";
    let store = StaticRetriever::new().with(
        question,
        vec![
            doc("alpha is the first letter"),
            doc("alpha is also a particle"),
        ],
    );
    let retriever = Arc::new(store);
    let grader = ScriptedGrader::new(&[
        ("alpha is the first letter", GradeVerdict::Correct),
        ("alpha is also a particle", GradeVerdict::Correct),
    ]);
    let rewriter = ScriptedRewriter::new(&["never used"]);
    let generator = CapturingGenerator::new("Alpha is the first letter.");

    let agent = create_corrective_rag_agent(
        Arc::clone(&retriever),
        grader,
        rewriter,
        generator.clone(),
        CragConfig::new(),
    )
    .unwrap();
    let outcome = agent
        .execute(
            CorrectiveRagState::from_query(question),
            &ExecutionContext::new(),
        )
        .await
        .unwrap();
    let final_state = outcome.state;

    assert_eq!(
        final_state.answer.as_deref(),
        Some("Alpha is the first letter.")
    );
    assert_eq!(
        final_state.attempt, 0,
        "no rewrite when retrieval is correct"
    );
    assert_eq!(retriever.observed(), vec![question.to_owned()]);
}
/// An `Incorrect` first retrieval sends the run through one rewrite; the rewritten
/// query is retrieved next and its `Correct` document leads to generation.
#[tokio::test]
async fn incorrect_retrieval_triggers_rewrite_and_loops_back() {
    let first_query = "alpha?";
    let rewritten_query = "what is alpha letter?";
    let store = StaticRetriever::new()
        .with(first_query, vec![doc("totally off-topic")])
        .with(rewritten_query, vec![doc("alpha is the first letter")]);
    let retriever = Arc::new(store);
    let grader = ScriptedGrader::new(&[
        ("totally off-topic", GradeVerdict::Incorrect),
        ("alpha is the first letter", GradeVerdict::Correct),
    ]);
    let rewriter = ScriptedRewriter::new(&[rewritten_query]);
    let generator = CapturingGenerator::new("Final.");

    let agent = create_corrective_rag_agent(
        Arc::clone(&retriever),
        grader,
        rewriter,
        generator.clone(),
        CragConfig::new(),
    )
    .unwrap();
    let outcome = agent
        .execute(
            CorrectiveRagState::from_query(first_query),
            &ExecutionContext::new(),
        )
        .await
        .unwrap();
    let final_state = outcome.state;

    assert_eq!(final_state.attempt, 1, "exactly one rewrite happened");
    assert_eq!(
        retriever.observed(),
        vec![first_query.to_owned(), rewritten_query.to_owned()]
    );
    assert_eq!(final_state.answer.as_deref(), Some("Final."));
}
/// With a rewrite budget of 2 and nothing but `Incorrect` documents, the run rewrites
/// exactly twice (three retrievals in total) and then generates anyway.
#[tokio::test]
async fn rewrite_budget_caps_loop_and_surrenders_to_generate() {
    let store = StaticRetriever::new()
        .with("q0", vec![doc("bad-0")])
        .with("q1", vec![doc("bad-1")])
        .with("q2", vec![doc("bad-2")])
        .with("q3", vec![doc("bad-3")]);
    let retriever = Arc::new(store);
    let grader = ScriptedGrader::new(&[
        ("bad-0", GradeVerdict::Incorrect),
        ("bad-1", GradeVerdict::Incorrect),
        ("bad-2", GradeVerdict::Incorrect),
        ("bad-3", GradeVerdict::Incorrect),
    ]);
    // One more scripted rewrite than the budget allows; "q3" must never be used.
    let rewriter = ScriptedRewriter::new(&["q1", "q2", "q3"]);
    let generator = CapturingGenerator::new("Surrendered.");
    let config = CragConfig::new().with_max_rewrite_attempts(2);

    let agent = create_corrective_rag_agent(
        Arc::clone(&retriever),
        grader,
        rewriter,
        generator.clone(),
        config,
    )
    .unwrap();
    let outcome = agent
        .execute(
            CorrectiveRagState::from_query("q0"),
            &ExecutionContext::new(),
        )
        .await
        .unwrap();
    let final_state = outcome.state;

    assert_eq!(final_state.attempt, 2, "rewrite budget = 2");
    let expected_queries = vec!["q0".to_owned(), "q1".to_owned(), "q2".to_owned()];
    assert_eq!(retriever.observed(), expected_queries);
    assert_eq!(final_state.answer.as_deref(), Some("Surrendered."));
}
/// A retrieval that returns no documents counts as a failed retrieval: the run
/// rewrites once and answers from the second retrieval.
#[tokio::test]
async fn empty_retrieval_loops_to_rewrite_when_budget_remains() {
    let store = StaticRetriever::new()
        .with("q0", vec![])
        .with("q1", vec![doc("alpha is the first letter")]);
    let retriever = Arc::new(store);
    let grader = ScriptedGrader::new(&[("alpha is the first letter", GradeVerdict::Correct)]);
    let rewriter = ScriptedRewriter::new(&["q1"]);
    let generator = CapturingGenerator::new("Answered.");

    let agent = create_corrective_rag_agent(
        Arc::clone(&retriever),
        grader,
        rewriter,
        generator.clone(),
        CragConfig::new(),
    )
    .unwrap();
    let outcome = agent
        .execute(
            CorrectiveRagState::from_query("q0"),
            &ExecutionContext::new(),
        )
        .await
        .unwrap();
    let final_state = outcome.state;

    assert_eq!(final_state.attempt, 1);
    assert_eq!(final_state.answer.as_deref(), Some("Answered."));
}
/// With two of three documents graded `Correct` (above the default 0.5 threshold), the
/// run generates immediately and the prompt's documents section contains only the
/// `Correct` documents.
#[tokio::test]
async fn generator_sees_only_correct_documents_when_mixed_batch() {
    let store = StaticRetriever::new().with(
        "alpha?",
        vec![
            doc("alpha is the first letter"),
            doc("alpha is unrelated stuff"),
            doc("more about alpha letter"),
        ],
    );
    let retriever = Arc::new(store);
    let grader = ScriptedGrader::new(&[
        ("alpha is the first letter", GradeVerdict::Correct),
        ("alpha is unrelated stuff", GradeVerdict::Incorrect),
        ("more about alpha letter", GradeVerdict::Correct),
    ]);
    let rewriter = ScriptedRewriter::new(&["unused"]);
    let generator = CapturingGenerator::new("Answered.");

    let agent = create_corrective_rag_agent(
        Arc::clone(&retriever),
        grader,
        rewriter,
        generator.clone(),
        CragConfig::new(),
    )
    .unwrap();
    let outcome = agent
        .execute(
            CorrectiveRagState::from_query("alpha?"),
            &ExecutionContext::new(),
        )
        .await
        .unwrap();
    let final_state = outcome.state;
    assert_eq!(final_state.attempt, 0, "2/3 correct = above 0.5 → generate");

    // The generator clone shares its capture buffer with the one inside the agent.
    let captured = generator.observed();
    assert_eq!(captured.len(), 1);
    let user_turn = captured[0]
        .iter()
        .rfind(|message| matches!(message.role, Role::User))
        .unwrap();
    let documents_section = user_turn
        .content
        .iter()
        .find_map(|part| match part {
            ContentPart::Text { text, .. } if text.contains("documents") => Some(text.clone()),
            _ => None,
        })
        .unwrap();

    assert!(documents_section.contains("alpha is the first letter"));
    assert!(documents_section.contains("more about alpha letter"));
    assert!(!documents_section.contains("alpha is unrelated stuff"));
}
/// `CragConfig::default()` exposes exactly the published `DEFAULT_*` constants.
#[test]
fn config_defaults_match_published_constants() {
    let defaults = CragConfig::default();
    let fraction_gap = (defaults.min_correct_fraction() - DEFAULT_MIN_CORRECT_FRACTION).abs();
    assert!(fraction_gap < f32::EPSILON);
    assert_eq!(defaults.retrieval_top_k(), DEFAULT_RETRIEVAL_TOP_K);
    assert_eq!(defaults.max_rewrite_attempts(), DEFAULT_MAX_REWRITE_ATTEMPTS);
}
/// An out-of-range threshold (2.5) is clamped to 1.0 by the router, so a fully correct
/// batch still routes to `generate` even with no rewrite budget.
#[test]
fn min_correct_fraction_clamped_during_routing() {
    let config = CragConfig::new()
        .with_min_correct_fraction(2.5)
        .with_max_rewrite_attempts(0);

    let mut state = CorrectiveRagState::from_query("q");
    state.graded = vec![(doc("d"), GradeVerdict::Correct)];
    state.correct_documents = vec![doc("d")];

    assert_eq!(route_after_grade(&state, &config), "generate");
}
}