Skip to main content

entelix_rag/corrective/
recipe.rs

1//! `create_corrective_rag_agent` — CRAG (Yan et al. 2024)
2//! topology assembled into a runnable [`Agent`] over the
3//! primitives this module ships.
4//!
5//! ## Topology
6//!
7//! ```text
8//!         ┌─────────┐    ┌──────┐    ┌────────┐
9//! query ─▶│retrieve │───▶│grade │───▶│decide  │
10//!         └─────────┘    └──────┘    └───┬────┘
11//!              ▲                         │
12//!              │                  ┌──────┴──────┐
13//!              │                  ▼             ▼
14//!         ┌────┴────┐         ┌────────┐   ┌─────────┐
15//!         │rewrite  │◀────────│ retry  │   │generate │──▶ END
16//!         └─────────┘         └────────┘   └─────────┘
17//! ```
18//!
19//! Retrieve runs against the operator's [`Retriever`]; every
20//! result is graded by the [`RetrievalGrader`]; `decide` (a pure
21//! state-based router, no LLM call) routes on the fraction of
22//! [`GradeVerdict::Correct`] verdicts vs the configured
23//! [`CragConfig::min_correct_fraction`] threshold; `rewrite`
24//! re-issues the query through the [`QueryRewriter`] and loops back
25//! to retrieve; `generate` produces the final answer over the
26//! operator-supplied generator model. Configurable knobs
27//! ([`CragConfig`]) control the routing threshold and the
28//! rewrite-loop attempt cap.
29//!
30//! The CRAG paper's full three-way decision (Correct vs Ambiguous
31//! vs Incorrect, with a web-search branch on Incorrect) is
32//! intentionally collapsed here to a Correct-vs-not-Correct
33//! routing — the SDK ships no built-in web-search primitive, and
34//! operators wanting the third branch wire it as a fallback inside
35//! their custom [`Retriever`] (try local KB → escalate to web on
36//! empty / low-confidence hits).
37//!
38//! ## When to reach for this recipe
39//!
40//! Use this when the corpus is messy enough that naive
41//! retrieve-then-generate produces low-quality grounded answers
42//! and the operator wants the LLM-driven self-correction loop the
43//! CRAG paper describes. For corpora where retrieval quality is
44//! already high (well-curated technical docs, reference KB), the
45//! plain `IngestionPipeline` + manual `Retriever::retrieve` +
46//! `ChatModel` composition is cheaper and simpler — corrective
47//! routing only pays off when retrieval failures are common
48//! enough to amortise the grader's per-document cost.
49
50use std::sync::Arc;
51
52use entelix_agents::Agent;
53use entelix_core::ir::{ContentPart, Message, Role, SystemPrompt};
54use entelix_core::{ExecutionContext, Result};
55use entelix_graph::{CompiledGraph, StateGraph};
56use entelix_memory::{Document as RetrievedDocument, RetrievalQuery, Retriever};
57use entelix_runnable::{Runnable, RunnableLambda};
58
59use crate::corrective::grader::{GradeVerdict, RetrievalGrader};
60use crate::corrective::rewriter::QueryRewriter;
61
/// Default minimum fraction of retrieved documents that must
/// grade [`GradeVerdict::Correct`] for the recipe to skip rewriting
/// and proceed directly to generation. `0.5` matches the CRAG
/// paper's mid-confidence threshold — raise it when tuning for
/// higher retrieval precision; lower it when tuning for lower
/// model spend (fewer rewrites at the cost of weaker grounding).
pub const DEFAULT_MIN_CORRECT_FRACTION: f32 = 0.5;

/// Default top-k passed into the retriever on every retrieval
/// pass. Operator-overridable via
/// [`CragConfig::with_retrieval_top_k`].
pub const DEFAULT_RETRIEVAL_TOP_K: usize = 5;

/// Default cap on rewrite-loop attempts before the recipe
/// surrenders and generates over whatever was retrieved last.
/// `3` is the CRAG paper's reported sweet spot (retrieval rarely
/// improves beyond the third rewrite).
pub const DEFAULT_MAX_REWRITE_ATTEMPTS: u32 = 3;

/// Default system prompt the generator node prepends to every
/// answer-generation call. Vendor-neutral, focused on grounded
/// answer style.
pub const DEFAULT_GENERATOR_SYSTEM_PROMPT: &str = "\
You are a helpful assistant. Answer the user's question using only the supplied retrieved \
documents as your evidence base. If the documents don't contain enough information to \
answer with confidence, say so explicitly. Never fabricate facts that the documents do \
not support.";
89
/// Operator-tunable knobs for the corrective-RAG recipe. Construct
/// via [`Self::new`] or [`Self::default`]; chain `with_*` setters.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub struct CragConfig {
    // Fraction of `Correct` verdicts required to route straight to
    // generation; clamped to [0.0, 1.0] at decision time.
    min_correct_fraction: f32,
    // Top-k forwarded to the retriever on every retrieval pass.
    retrieval_top_k: usize,
    // Cap on rewrite passes before surrendering to generation.
    max_rewrite_attempts: u32,
    // System prompt the generator node prepends to every call.
    generator_system_prompt: SystemPrompt,
}
100
101impl CragConfig {
102    /// Build with the default thresholds + retrieval top-k +
103    /// system prompt.
104    #[must_use]
105    pub fn new() -> Self {
106        Self::default()
107    }
108
109    /// Override the minimum fraction of `Correct` verdicts the
110    /// recipe needs before generating without a rewrite. Values
111    /// are clamped to `[0.0, 1.0]` at decision time — operators
112    /// supplying out-of-range values get the clamped value, no
113    /// error.
114    #[must_use]
115    pub const fn with_min_correct_fraction(mut self, fraction: f32) -> Self {
116        self.min_correct_fraction = fraction;
117        self
118    }
119
120    /// Override the retrieval top-k.
121    #[must_use]
122    pub const fn with_retrieval_top_k(mut self, top_k: usize) -> Self {
123        self.retrieval_top_k = top_k;
124        self
125    }
126
127    /// Override the rewrite-loop attempt cap. After this many
128    /// rewrites, the recipe generates over whatever the last
129    /// retrieval returned (even if every document graded
130    /// `Incorrect`) — surrender beats infinite loop.
131    #[must_use]
132    pub const fn with_max_rewrite_attempts(mut self, max: u32) -> Self {
133        self.max_rewrite_attempts = max;
134        self
135    }
136
137    /// Override the system prompt the generator node uses. Default
138    /// is [`DEFAULT_GENERATOR_SYSTEM_PROMPT`].
139    #[must_use]
140    pub fn with_generator_system_prompt(mut self, prompt: SystemPrompt) -> Self {
141        self.generator_system_prompt = prompt;
142        self
143    }
144
145    /// Effective minimum-correct fraction.
146    #[must_use]
147    pub const fn min_correct_fraction(&self) -> f32 {
148        self.min_correct_fraction
149    }
150
151    /// Effective retrieval top-k.
152    #[must_use]
153    pub const fn retrieval_top_k(&self) -> usize {
154        self.retrieval_top_k
155    }
156
157    /// Effective rewrite attempt cap.
158    #[must_use]
159    pub const fn max_rewrite_attempts(&self) -> u32 {
160        self.max_rewrite_attempts
161    }
162
163    /// Borrow the configured generator system prompt.
164    #[must_use]
165    pub const fn generator_system_prompt(&self) -> &SystemPrompt {
166        &self.generator_system_prompt
167    }
168}
169
170impl Default for CragConfig {
171    fn default() -> Self {
172        Self {
173            min_correct_fraction: DEFAULT_MIN_CORRECT_FRACTION,
174            retrieval_top_k: DEFAULT_RETRIEVAL_TOP_K,
175            max_rewrite_attempts: DEFAULT_MAX_REWRITE_ATTEMPTS,
176            generator_system_prompt: SystemPrompt::text(DEFAULT_GENERATOR_SYSTEM_PROMPT),
177        }
178    }
179}
180
/// State the corrective-RAG graph drives across nodes. Carries
/// the original + current query, the rewrite history, the last
/// retrieval batch + verdicts, the surviving correct subset, and
/// the terminal answer.
///
/// `attempt` counts rewrite passes — `0` is the original query,
/// `n > 0` is the n-th rewrite. Compared against
/// [`CragConfig::max_rewrite_attempts`] before each rewrite to
/// short-circuit the loop.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub struct CorrectiveRagState {
    /// Original user query. Untouched across the run.
    pub original_query: String,
    /// Current query being retrieved against. Updated by each
    /// rewrite pass.
    pub query: String,
    /// Every prior failed query attempt, oldest first. The
    /// rewriter sees this so it doesn't re-emit a previous
    /// attempt.
    pub previous_attempts: Vec<String>,
    /// Per-document grade from the most recent grading pass.
    /// The retrieve node seeds these pairs with placeholder
    /// [`GradeVerdict::Ambiguous`] verdicts; the grade node
    /// replaces them with real verdicts.
    pub graded: Vec<(RetrievedDocument, GradeVerdict)>,
    /// Filtered subset of `graded` containing only documents
    /// with [`GradeVerdict::Correct`]. The generator node sees
    /// this as evidence base.
    pub correct_documents: Vec<RetrievedDocument>,
    /// Number of rewrite attempts performed so far. `0` before
    /// the first rewrite.
    pub attempt: u32,
    /// Final answer text. `None` until the generator runs.
    pub answer: Option<String>,
}
214
215impl CorrectiveRagState {
216    /// Seed a fresh state with the user's query.
217    #[must_use]
218    pub fn from_query(query: impl Into<String>) -> Self {
219        let query = query.into();
220        Self {
221            original_query: query.clone(),
222            query,
223            previous_attempts: Vec::new(),
224            graded: Vec::new(),
225            correct_documents: Vec::new(),
226            attempt: 0,
227            answer: None,
228        }
229    }
230}
231
232/// Compile the corrective-RAG graph from operator-supplied
233/// primitives. Use this when you need to embed the graph as a
234/// node in a larger [`StateGraph`]; for a ready-to-execute
235/// agent, prefer [`create_corrective_rag_agent`].
236pub fn build_corrective_rag_graph<Ret, G, R, M>(
237    retriever: Arc<Ret>,
238    grader: G,
239    rewriter: R,
240    generator: M,
241    config: CragConfig,
242) -> Result<CompiledGraph<CorrectiveRagState>>
243where
244    Ret: Retriever + ?Sized + 'static,
245    G: RetrievalGrader + 'static,
246    R: QueryRewriter + 'static,
247    M: Runnable<Vec<Message>, Message> + 'static,
248{
249    let config = Arc::new(config);
250    let generator = Arc::new(generator);
251    let grader = Arc::new(grader);
252    let rewriter = Arc::new(rewriter);
253
254    StateGraph::<CorrectiveRagState>::new()
255        .add_node(NODE_RETRIEVE, make_retriever_node(retriever, &config))
256        .add_node(NODE_GRADE, make_grader_node(grader))
257        .add_node(NODE_REWRITE, make_rewriter_node(rewriter))
258        .add_node(NODE_GENERATE, make_generator_node(generator, &config))
259        .set_entry_point(NODE_RETRIEVE)
260        .add_finish_point(NODE_GENERATE)
261        .add_edge(NODE_RETRIEVE, NODE_GRADE)
262        .add_edge(NODE_REWRITE, NODE_RETRIEVE)
263        .add_conditional_edges(
264            NODE_GRADE,
265            {
266                let config = Arc::clone(&config);
267                move |state: &CorrectiveRagState| route_after_grade(state, &config)
268            },
269            [(NODE_REWRITE, NODE_REWRITE), (NODE_GENERATE, NODE_GENERATE)],
270        )
271        .compile()
272}
273
274fn make_retriever_node<Ret>(
275    retriever: Arc<Ret>,
276    config: &Arc<CragConfig>,
277) -> impl Runnable<CorrectiveRagState, CorrectiveRagState> + 'static
278where
279    Ret: Retriever + ?Sized + 'static,
280{
281    let config = Arc::clone(config);
282    RunnableLambda::new(
283        move |mut state: CorrectiveRagState, ctx: ExecutionContext| {
284            let retriever = Arc::clone(&retriever);
285            let top_k = config.retrieval_top_k;
286            async move {
287                let query = RetrievalQuery::new(state.query.clone(), top_k);
288                let docs = retriever.retrieve(query, &ctx).await?;
289                state.graded = docs
290                    .into_iter()
291                    .map(|d| (d, GradeVerdict::Ambiguous))
292                    .collect();
293                state.correct_documents.clear();
294                Ok::<_, _>(state)
295            }
296        },
297    )
298}
299
300fn make_grader_node<G>(
301    grader: Arc<G>,
302) -> impl Runnable<CorrectiveRagState, CorrectiveRagState> + 'static
303where
304    G: RetrievalGrader + ?Sized + 'static,
305{
306    RunnableLambda::new(
307        move |mut state: CorrectiveRagState, ctx: ExecutionContext| {
308            let grader = Arc::clone(&grader);
309            async move {
310                let mut verdicts = Vec::with_capacity(state.graded.len());
311                for (doc, _) in std::mem::take(&mut state.graded) {
312                    let verdict = grader.grade(&state.query, &doc, &ctx).await?;
313                    verdicts.push((doc, verdict));
314                }
315                state.correct_documents = verdicts
316                    .iter()
317                    .filter(|(_, v)| matches!(v, GradeVerdict::Correct))
318                    .map(|(d, _)| d.clone())
319                    .collect();
320                state.graded = verdicts;
321                Ok::<_, _>(state)
322            }
323        },
324    )
325}
326
327fn make_rewriter_node<R>(
328    rewriter: Arc<R>,
329) -> impl Runnable<CorrectiveRagState, CorrectiveRagState> + 'static
330where
331    R: QueryRewriter + ?Sized + 'static,
332{
333    RunnableLambda::new(
334        move |mut state: CorrectiveRagState, ctx: ExecutionContext| {
335            let rewriter = Arc::clone(&rewriter);
336            async move {
337                let new_query = rewriter
338                    .rewrite(&state.original_query, &state.previous_attempts, &ctx)
339                    .await?;
340                state.previous_attempts.push(state.query.clone());
341                state.query = new_query;
342                state.attempt = state.attempt.saturating_add(1);
343                Ok::<_, _>(state)
344            }
345        },
346    )
347}
348
349fn make_generator_node<M>(
350    generator: Arc<M>,
351    config: &Arc<CragConfig>,
352) -> impl Runnable<CorrectiveRagState, CorrectiveRagState> + 'static
353where
354    M: Runnable<Vec<Message>, Message> + ?Sized + 'static,
355{
356    let config = Arc::clone(config);
357    RunnableLambda::new(
358        move |mut state: CorrectiveRagState, ctx: ExecutionContext| {
359            let generator = Arc::clone(&generator);
360            let config = Arc::clone(&config);
361            async move {
362                let messages = build_generator_prompt(
363                    &state.original_query,
364                    &state.correct_documents,
365                    &state.graded,
366                    &config,
367                );
368                let reply = generator.invoke(messages, &ctx).await?;
369                let text = extract_text(&reply);
370                state.answer = Some(text);
371                Ok::<_, _>(state)
372            }
373        },
374    )
375}
376
/// Stable agent name surfaced on every emitted
/// [`entelix_agents::AgentEvent`] and OTel `entelix.agent.run` span.
pub const CORRECTIVE_RAG_AGENT_NAME: &str = "corrective-rag";

/// Single source of truth for the CRAG graph node identifiers —
/// `add_node` / `add_edge` / `add_conditional_edges` / the router
/// closure all reference these constants so a typo surfaces at
/// compile time rather than as a silent dispatch miss.
const NODE_RETRIEVE: &str = "retrieve";
const NODE_GRADE: &str = "grade";
const NODE_REWRITE: &str = "rewrite";
const NODE_GENERATE: &str = "generate";
389
390/// Build a ready-to-execute corrective-RAG [`Agent`]. Wraps
391/// [`build_corrective_rag_graph`] in the standard `Agent<S>` shape
392/// so the full lifecycle (`AgentEvent` stream, sink fan-out,
393/// observer hooks, supervisor handoff) integrates uniformly with
394/// every other recipe (`create_react_agent`,
395/// `create_supervisor_agent`, `create_chat_agent`).
396///
397/// `Agent::execute` returns
398/// [`AgentRunResult<CorrectiveRagState>`](entelix_agents::AgentRunResult)
399/// — the standard envelope carrying the terminal state plus a
400/// frozen `RunBudget` snapshot. Seed the input via
401/// [`CorrectiveRagState::from_query`].
402///
403/// Operators embedding the CRAG graph as a node in a larger
404/// `StateGraph<S>` reach for [`build_corrective_rag_graph`]
405/// directly and skip the `Agent` wrapper.
406pub fn create_corrective_rag_agent<Ret, G, R, M>(
407    retriever: Arc<Ret>,
408    grader: G,
409    rewriter: R,
410    generator: M,
411    config: CragConfig,
412) -> Result<Agent<CorrectiveRagState>>
413where
414    Ret: Retriever + ?Sized + 'static,
415    G: RetrievalGrader + 'static,
416    R: QueryRewriter + 'static,
417    M: Runnable<Vec<Message>, Message> + 'static,
418{
419    let graph = build_corrective_rag_graph(retriever, grader, rewriter, generator, config)?;
420    Agent::builder()
421        .with_name(CORRECTIVE_RAG_AGENT_NAME)
422        .with_runnable(graph)
423        .build()
424}
425
426/// Route after a `grade` pass — decide whether to proceed to
427/// `generate` or loop back through `rewrite`.
428///
429/// Routing rules (CRAG paper):
430/// - No documents at all → `rewrite` (unless attempt cap hit).
431/// - Fraction of `Correct` verdicts ≥ `min_correct_fraction` →
432///   `generate` (we have enough grounded evidence).
433/// - Otherwise → `rewrite` (if budget remains) or `generate`
434///   (when the rewrite budget is exhausted; surrender beats
435///   infinite loop).
436fn route_after_grade(state: &CorrectiveRagState, config: &CragConfig) -> String {
437    let total = state.graded.len();
438    let correct = state.correct_documents.len();
439    let fraction = if total == 0 {
440        0.0_f32
441    } else {
442        // Cast is benign — chunk counts above ~16M are
443        // pathological for an in-flight retrieval batch and the
444        // operator's top_k caps it well below f32 precision loss.
445        #[allow(clippy::cast_precision_loss)]
446        let n = total as f32;
447        #[allow(clippy::cast_precision_loss)]
448        let k = correct as f32;
449        k / n
450    };
451    let threshold = config.min_correct_fraction.clamp(0.0, 1.0);
452    let budget_remaining = state.attempt < config.max_rewrite_attempts;
453    if fraction >= threshold && correct > 0 {
454        NODE_GENERATE.to_owned()
455    } else if budget_remaining {
456        NODE_REWRITE.to_owned()
457    } else {
458        NODE_GENERATE.to_owned()
459    }
460}
461
462/// Build the user message the generator node sends to the model.
463/// Three text parts: the original user query, the correct
464/// documents (one block each), and a fallback context list when
465/// every retrieval graded poorly so the generator at least sees
466/// what the corpus surfaced.
467fn build_generator_prompt(
468    original_query: &str,
469    correct_documents: &[RetrievedDocument],
470    fallback_graded: &[(RetrievedDocument, GradeVerdict)],
471    config: &CragConfig,
472) -> Vec<Message> {
473    let evidence: String = if correct_documents.is_empty() {
474        // No `Correct` verdicts — generator runs over whatever
475        // graded best. Prefer `Ambiguous` over `Incorrect`; that
476        // ordering is the CRAG paper's degraded-mode behaviour.
477        let mut docs: Vec<&RetrievedDocument> = fallback_graded
478            .iter()
479            .filter_map(|(d, v)| matches!(v, GradeVerdict::Ambiguous).then_some(d))
480            .collect();
481        if docs.is_empty() {
482            docs = fallback_graded.iter().map(|(d, _)| d).collect();
483        }
484        format_documents(&docs)
485    } else {
486        let docs: Vec<&RetrievedDocument> = correct_documents.iter().collect();
487        format_documents(&docs)
488    };
489
490    let mut messages: Vec<Message> = config
491        .generator_system_prompt
492        .blocks()
493        .iter()
494        .map(|b| Message::new(Role::System, vec![ContentPart::text(b.text.clone())]))
495        .collect();
496    messages.push(Message::new(
497        Role::User,
498        vec![
499            ContentPart::text(format!("<query>\n{original_query}\n</query>")),
500            ContentPart::text(format!("<documents>\n{evidence}\n</documents>")),
501        ],
502    ));
503    messages
504}
505
506/// Render a slice of documents as one concatenated text block,
507/// each delimited by a blank line so the model can tell them
508/// apart without the recipe imposing JSON structure.
509fn format_documents(docs: &[&RetrievedDocument]) -> String {
510    use std::fmt::Write as _;
511    let mut out = String::new();
512    for (idx, doc) in docs.iter().enumerate() {
513        if idx > 0 {
514            out.push_str("\n\n");
515        }
516        write!(&mut out, "[{idx}] {}", doc.content).expect("writing to String never fails");
517    }
518    out
519}
520
521/// Pull the assistant reply text out of the model's [`Message`].
522/// Concatenates every [`ContentPart::Text`] part with single
523/// newline separators; non-text parts are skipped.
524fn extract_text(message: &Message) -> String {
525    let mut buf = String::new();
526    for part in &message.content {
527        if let ContentPart::Text { text, .. } = part {
528            if !buf.is_empty() {
529                buf.push('\n');
530            }
531            buf.push_str(text);
532        }
533    }
534    buf.trim().to_owned()
535}
536
537#[cfg(test)]
538mod tests {
539    use super::*;
540    use crate::corrective::grader::RetrievalGrader;
541    use crate::corrective::rewriter::QueryRewriter;
542    use async_trait::async_trait;
543    use entelix_memory::Document as RetrievedDocument;
544    use std::sync::Mutex;
545
546    /// Static retriever: returns a pre-canned doc set per query
547    /// and tracks every query it saw.
548    struct StaticRetriever {
549        docs_by_query: std::collections::HashMap<String, Vec<RetrievedDocument>>,
550        observed_queries: Mutex<Vec<String>>,
551    }
552
553    impl StaticRetriever {
554        fn new() -> Self {
555            Self {
556                docs_by_query: std::collections::HashMap::new(),
557                observed_queries: Mutex::new(Vec::new()),
558            }
559        }
560
561        fn with(mut self, query: &str, docs: Vec<RetrievedDocument>) -> Self {
562            self.docs_by_query.insert(query.to_owned(), docs);
563            self
564        }
565
566        fn observed(&self) -> Vec<String> {
567            self.observed_queries.lock().unwrap().clone()
568        }
569    }
570
571    #[async_trait]
572    impl Retriever for StaticRetriever {
573        async fn retrieve(
574            &self,
575            query: RetrievalQuery,
576            _ctx: &ExecutionContext,
577        ) -> Result<Vec<RetrievedDocument>> {
578            self.observed_queries
579                .lock()
580                .unwrap()
581                .push(query.text.clone());
582            Ok(self
583                .docs_by_query
584                .get(&query.text)
585                .cloned()
586                .unwrap_or_default())
587        }
588    }
589
590    /// Verdict-scripted grader: returns pre-canned verdicts per
591    /// document content.
592    struct ScriptedGrader {
593        verdicts: std::collections::HashMap<String, GradeVerdict>,
594    }
595
596    impl ScriptedGrader {
597        fn new(map: &[(&str, GradeVerdict)]) -> Self {
598            Self {
599                verdicts: map.iter().map(|(k, v)| ((*k).to_owned(), *v)).collect(),
600            }
601        }
602    }
603
604    #[async_trait]
605    impl RetrievalGrader for ScriptedGrader {
606        fn name(&self) -> &'static str {
607            "scripted-grader"
608        }
609        async fn grade(
610            &self,
611            _query: &str,
612            doc: &RetrievedDocument,
613            _ctx: &ExecutionContext,
614        ) -> Result<GradeVerdict> {
615            Ok(self
616                .verdicts
617                .get(&doc.content)
618                .copied()
619                .unwrap_or(GradeVerdict::Ambiguous))
620        }
621    }
622
623    /// Rewriter that returns the next pre-canned string per call.
624    struct ScriptedRewriter {
625        replies: Mutex<Vec<String>>,
626    }
627
628    impl ScriptedRewriter {
629        fn new(replies: &[&str]) -> Self {
630            Self {
631                replies: Mutex::new(replies.iter().map(|s| (*s).to_owned()).rev().collect()),
632            }
633        }
634    }
635
636    #[async_trait]
637    impl QueryRewriter for ScriptedRewriter {
638        fn name(&self) -> &'static str {
639            "scripted-rewriter"
640        }
641        async fn rewrite(
642            &self,
643            _original: &str,
644            _previous: &[String],
645            _ctx: &ExecutionContext,
646        ) -> Result<String> {
647            Ok(self
648                .replies
649                .lock()
650                .unwrap()
651                .pop()
652                .unwrap_or_else(|| "<exhausted>".to_owned()))
653        }
654    }
655
656    /// Generator that records the messages it received and replies
657    /// with a fixed answer. Cheap to clone — the observed-prompts
658    /// log lives under an `Arc<Mutex<...>>` so the test can keep
659    /// one handle for inspection while the recipe moves another
660    /// into the graph.
661    #[derive(Clone)]
662    struct CapturingGenerator {
663        observed: Arc<Mutex<Vec<Vec<Message>>>>,
664        reply: String,
665    }
666
667    impl CapturingGenerator {
668        fn new(reply: &str) -> Self {
669            Self {
670                observed: Arc::new(Mutex::new(Vec::new())),
671                reply: reply.to_owned(),
672            }
673        }
674        fn observed(&self) -> Vec<Vec<Message>> {
675            self.observed.lock().unwrap().clone()
676        }
677    }
678
679    #[async_trait]
680    impl Runnable<Vec<Message>, Message> for CapturingGenerator {
681        async fn invoke(&self, input: Vec<Message>, _ctx: &ExecutionContext) -> Result<Message> {
682            self.observed.lock().unwrap().push(input);
683            Ok(Message::new(
684                Role::Assistant,
685                vec![ContentPart::text(self.reply.clone())],
686            ))
687        }
688    }
689
    /// Shorthand: build a `RetrievedDocument` from raw content text.
    fn doc(content: &str) -> RetrievedDocument {
        RetrievedDocument::new(content)
    }
693
    /// Every document grades `Correct` on the first retrieval, so
    /// the run goes retrieve → grade → generate with zero rewrites
    /// and exactly one retrieval pass.
    #[tokio::test]
    async fn happy_path_generates_directly_when_all_correct() {
        let retriever = Arc::new(StaticRetriever::new().with(
            "what is alpha?",
            vec![
                doc("alpha is the first letter"),
                doc("alpha is also a particle"),
            ],
        ));
        let grader = ScriptedGrader::new(&[
            ("alpha is the first letter", GradeVerdict::Correct),
            ("alpha is also a particle", GradeVerdict::Correct),
        ]);
        // The rewriter script must never be consumed on this path.
        let rewriter = ScriptedRewriter::new(&["never used"]);
        let generator = CapturingGenerator::new("Alpha is the first letter.");

        let agent = create_corrective_rag_agent(
            Arc::clone(&retriever),
            grader,
            rewriter,
            generator.clone(),
            CragConfig::new(),
        )
        .unwrap();
        let final_state = agent
            .execute(
                CorrectiveRagState::from_query("what is alpha?"),
                &ExecutionContext::new(),
            )
            .await
            .unwrap()
            .state;

        assert_eq!(
            final_state.answer.as_deref(),
            Some("Alpha is the first letter.")
        );
        assert_eq!(
            final_state.attempt, 0,
            "no rewrite when retrieval is correct"
        );
        // Exactly one retrieval, against the original query.
        assert_eq!(retriever.observed(), vec!["what is alpha?".to_owned()]);
    }
737
    /// First retrieval grades `Incorrect`, triggering one rewrite;
    /// the rewritten query retrieves a `Correct` document and the
    /// run proceeds to generation.
    #[tokio::test]
    async fn incorrect_retrieval_triggers_rewrite_and_loops_back() {
        let retriever = Arc::new(
            StaticRetriever::new()
                .with("alpha?", vec![doc("totally off-topic")])
                .with(
                    "what is alpha letter?",
                    vec![doc("alpha is the first letter")],
                ),
        );
        let grader = ScriptedGrader::new(&[
            ("totally off-topic", GradeVerdict::Incorrect),
            ("alpha is the first letter", GradeVerdict::Correct),
        ]);
        let rewriter = ScriptedRewriter::new(&["what is alpha letter?"]);
        let generator = CapturingGenerator::new("Final.");

        let agent = create_corrective_rag_agent(
            Arc::clone(&retriever),
            grader,
            rewriter,
            generator.clone(),
            CragConfig::new(),
        )
        .unwrap();
        let final_state = agent
            .execute(
                CorrectiveRagState::from_query("alpha?"),
                &ExecutionContext::new(),
            )
            .await
            .unwrap()
            .state;

        assert_eq!(final_state.attempt, 1, "exactly one rewrite happened");
        // Retrieval order proves the loop went back through retrieve.
        assert_eq!(
            retriever.observed(),
            vec!["alpha?".to_owned(), "what is alpha letter?".to_owned()]
        );
        assert_eq!(final_state.answer.as_deref(), Some("Final."));
    }
779
    /// With every retrieval grading `Incorrect`, the loop must stop
    /// at `max_rewrite_attempts` and generate anyway (surrender).
    #[tokio::test]
    async fn rewrite_budget_caps_loop_and_surrenders_to_generate() {
        // Every retrieval grades Incorrect, every rewrite produces
        // a different (still-bad) query. The recipe must cap at
        // max_rewrite_attempts and generate over whatever was
        // last retrieved.
        let retriever = Arc::new(
            StaticRetriever::new()
                .with("q0", vec![doc("bad-0")])
                .with("q1", vec![doc("bad-1")])
                .with("q2", vec![doc("bad-2")])
                .with("q3", vec![doc("bad-3")]),
        );
        let grader = ScriptedGrader::new(&[
            ("bad-0", GradeVerdict::Incorrect),
            ("bad-1", GradeVerdict::Incorrect),
            ("bad-2", GradeVerdict::Incorrect),
            ("bad-3", GradeVerdict::Incorrect),
        ]);
        let rewriter = ScriptedRewriter::new(&["q1", "q2", "q3"]);
        let generator = CapturingGenerator::new("Surrendered.");

        let agent = create_corrective_rag_agent(
            Arc::clone(&retriever),
            grader,
            rewriter,
            generator.clone(),
            // Deliberately lower than the default (3) to keep the
            // scripted fixture small.
            CragConfig::new().with_max_rewrite_attempts(2),
        )
        .unwrap();
        let final_state = agent
            .execute(
                CorrectiveRagState::from_query("q0"),
                &ExecutionContext::new(),
            )
            .await
            .unwrap()
            .state;

        assert_eq!(final_state.attempt, 2, "rewrite budget = 2");
        // Retrievals: original (q0), rewrite-1 (q1), rewrite-2 (q2).
        // Generator runs after the third retrieval — we never saw q3.
        assert_eq!(
            retriever.observed(),
            vec!["q0".to_owned(), "q1".to_owned(), "q2".to_owned()]
        );
        assert_eq!(final_state.answer.as_deref(), Some("Surrendered."));
    }
828
    /// A retrieval returning zero documents must route to rewrite
    /// (not straight to generate) while rewrite budget remains.
    #[tokio::test]
    async fn empty_retrieval_loops_to_rewrite_when_budget_remains() {
        let retriever = Arc::new(
            StaticRetriever::new()
                .with("q0", vec![])
                .with("q1", vec![doc("alpha is the first letter")]),
        );
        let grader = ScriptedGrader::new(&[("alpha is the first letter", GradeVerdict::Correct)]);
        let rewriter = ScriptedRewriter::new(&["q1"]);
        let generator = CapturingGenerator::new("Answered.");

        let agent = create_corrective_rag_agent(
            Arc::clone(&retriever),
            grader,
            rewriter,
            generator.clone(),
            CragConfig::new(),
        )
        .unwrap();
        let final_state = agent
            .execute(
                CorrectiveRagState::from_query("q0"),
                &ExecutionContext::new(),
            )
            .await
            .unwrap()
            .state;
        // One rewrite recovered the run; q1 graded Correct.
        assert_eq!(final_state.attempt, 1);
        assert_eq!(final_state.answer.as_deref(), Some("Answered."));
    }
859
860    #[tokio::test]
861    async fn generator_sees_only_correct_documents_when_mixed_batch() {
862        let retriever = Arc::new(StaticRetriever::new().with(
863            "alpha?",
864            vec![
865                doc("alpha is the first letter"),
866                doc("alpha is unrelated stuff"),
867                doc("more about alpha letter"),
868            ],
869        ));
870        let grader = ScriptedGrader::new(&[
871            ("alpha is the first letter", GradeVerdict::Correct),
872            ("alpha is unrelated stuff", GradeVerdict::Incorrect),
873            ("more about alpha letter", GradeVerdict::Correct),
874        ]);
875        let rewriter = ScriptedRewriter::new(&["unused"]);
876        let generator = CapturingGenerator::new("Answered.");
877
878        let agent = create_corrective_rag_agent(
879            Arc::clone(&retriever),
880            grader,
881            rewriter,
882            generator.clone(),
883            CragConfig::new(),
884        )
885        .unwrap();
886        let final_state = agent
887            .execute(
888                CorrectiveRagState::from_query("alpha?"),
889                &ExecutionContext::new(),
890            )
891            .await
892            .unwrap()
893            .state;
894
895        assert_eq!(final_state.attempt, 0, "2/3 correct = above 0.5 → generate");
896        let prompt = generator.observed();
897        assert_eq!(prompt.len(), 1);
898        // Last user message has documents — verify the Incorrect
899        // doc is filtered out and only the two Correct ones land.
900        let user_msg = prompt[0]
901            .iter()
902            .rfind(|m| matches!(m.role, Role::User))
903            .unwrap();
904        let docs_part = user_msg
905            .content
906            .iter()
907            .find_map(|p| match p {
908                ContentPart::Text { text, .. } if text.contains("documents") => Some(text.clone()),
909                _ => None,
910            })
911            .unwrap();
912        assert!(docs_part.contains("alpha is the first letter"));
913        assert!(docs_part.contains("more about alpha letter"));
914        assert!(!docs_part.contains("alpha is unrelated stuff"));
915    }
916
917    #[test]
918    fn config_defaults_match_published_constants() {
919        let cfg = CragConfig::default();
920        assert!((cfg.min_correct_fraction() - DEFAULT_MIN_CORRECT_FRACTION).abs() < f32::EPSILON);
921        assert_eq!(cfg.retrieval_top_k(), DEFAULT_RETRIEVAL_TOP_K);
922        assert_eq!(cfg.max_rewrite_attempts(), DEFAULT_MAX_REWRITE_ATTEMPTS);
923    }
924
925    #[test]
926    fn min_correct_fraction_clamped_during_routing() {
927        // Out-of-range fractions don't crash routing — they
928        // clamp to [0, 1] before the comparison.
929        let cfg = CragConfig::new()
930            .with_min_correct_fraction(2.5)
931            .with_max_rewrite_attempts(0);
932        let state = CorrectiveRagState {
933            original_query: "q".into(),
934            query: "q".into(),
935            previous_attempts: vec![],
936            graded: vec![(doc("d"), GradeVerdict::Correct)],
937            correct_documents: vec![doc("d")],
938            attempt: 0,
939            answer: None,
940        };
941        // 1/1 = 1.0; even with threshold > 1.0 the clamp brings
942        // it to 1.0 and 1.0 ≥ 1.0 so we route to generate.
943        assert_eq!(route_after_grade(&state, &cfg), "generate");
944    }
945}