1use std::sync::Arc;
51
52use entelix_agents::Agent;
53use entelix_core::ir::{ContentPart, Message, Role, SystemPrompt};
54use entelix_core::{ExecutionContext, Result};
55use entelix_graph::{CompiledGraph, StateGraph};
56use entelix_memory::{Document as RetrievedDocument, RetrievalQuery, Retriever};
57use entelix_runnable::{Runnable, RunnableLambda};
58
59use crate::corrective::grader::{GradeVerdict, RetrievalGrader};
60use crate::corrective::rewriter::QueryRewriter;
61
/// Default for [`CragConfig::min_correct_fraction`].
///
/// With this value, at least half of the documents from a retrieval pass (and
/// at least one document) must be graded `Correct` for the graph to go straight
/// to generation instead of rewriting the query.
pub const DEFAULT_MIN_CORRECT_FRACTION: f32 = 0.5;
69
/// Default for [`CragConfig::retrieval_top_k`]: the `top_k` passed to the
/// retriever on every retrieval pass.
pub const DEFAULT_RETRIEVAL_TOP_K: usize = 5;
74
/// Default for [`CragConfig::max_rewrite_attempts`]: how many query rewrites
/// the graph may perform before it stops looping and generates from whatever
/// evidence it has.
pub const DEFAULT_MAX_REWRITE_ATTEMPTS: u32 = 3;
80
/// Default system prompt for the generator model.
///
/// It tells the model to answer only from the supplied documents. When the
/// documents are insufficient, the model should say so, and it must never
/// fabricate facts the documents do not support.
pub const DEFAULT_GENERATOR_SYSTEM_PROMPT: &str = "\
You are a helpful assistant. Answer the user's question using only the supplied retrieved \
documents as your evidence base. If the documents don't contain enough information to \
answer with confidence, say so explicitly. Never fabricate facts that the documents do \
not support.";
89
/// Tuning knobs for the corrective-RAG graph built by
/// [`build_corrective_rag_graph`] and [`create_corrective_rag_agent`].
///
/// Start from [`CragConfig::new`] (or `Default`) and adjust it with the
/// `with_*` builder methods.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub struct CragConfig {
    /// Fraction of retrieved documents that must be graded `Correct` before the
    /// graph proceeds to generation. It is clamped into `[0.0, 1.0]` when
    /// routing.
    min_correct_fraction: f32,
    /// `top_k` requested from the retriever on every retrieval pass.
    retrieval_top_k: usize,
    /// Maximum number of query rewrites before the graph falls through to
    /// generation.
    max_rewrite_attempts: u32,
    /// System prompt sent to the generator model.
    generator_system_prompt: SystemPrompt,
}
100
101impl CragConfig {
102 #[must_use]
105 pub fn new() -> Self {
106 Self::default()
107 }
108
109 #[must_use]
115 pub const fn with_min_correct_fraction(mut self, fraction: f32) -> Self {
116 self.min_correct_fraction = fraction;
117 self
118 }
119
120 #[must_use]
122 pub const fn with_retrieval_top_k(mut self, top_k: usize) -> Self {
123 self.retrieval_top_k = top_k;
124 self
125 }
126
127 #[must_use]
132 pub const fn with_max_rewrite_attempts(mut self, max: u32) -> Self {
133 self.max_rewrite_attempts = max;
134 self
135 }
136
137 #[must_use]
140 pub fn with_generator_system_prompt(mut self, prompt: SystemPrompt) -> Self {
141 self.generator_system_prompt = prompt;
142 self
143 }
144
145 #[must_use]
147 pub const fn min_correct_fraction(&self) -> f32 {
148 self.min_correct_fraction
149 }
150
151 #[must_use]
153 pub const fn retrieval_top_k(&self) -> usize {
154 self.retrieval_top_k
155 }
156
157 #[must_use]
159 pub const fn max_rewrite_attempts(&self) -> u32 {
160 self.max_rewrite_attempts
161 }
162
163 #[must_use]
165 pub const fn generator_system_prompt(&self) -> &SystemPrompt {
166 &self.generator_system_prompt
167 }
168}
169
170impl Default for CragConfig {
171 fn default() -> Self {
172 Self {
173 min_correct_fraction: DEFAULT_MIN_CORRECT_FRACTION,
174 retrieval_top_k: DEFAULT_RETRIEVAL_TOP_K,
175 max_rewrite_attempts: DEFAULT_MAX_REWRITE_ATTEMPTS,
176 generator_system_prompt: SystemPrompt::text(DEFAULT_GENERATOR_SYSTEM_PROMPT),
177 }
178 }
179}
180
/// State threaded through every node of the corrective-RAG graph.
///
/// Create it with [`CorrectiveRagState::from_query`]. `#[non_exhaustive]`
/// prevents struct-literal construction outside this crate.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub struct CorrectiveRagState {
    /// The user's question as first supplied. The generator always answers
    /// this text, regardless of rewrites.
    pub original_query: String,
    /// Query used for retrieval; each rewrite replaces it.
    pub query: String,
    /// Queries that have been replaced by a rewrite, oldest first.
    pub previous_attempts: Vec<String>,
    /// Documents from the latest retrieval, each paired with its verdict.
    ///
    /// The retrieve node stages every document as `GradeVerdict::Ambiguous`,
    /// and the grade node then overwrites those with real verdicts.
    pub graded: Vec<(RetrievedDocument, GradeVerdict)>,
    /// Documents from `graded` whose verdict is `GradeVerdict::Correct`.
    /// Cleared on every retrieval and recomputed by the grade node.
    pub correct_documents: Vec<RetrievedDocument>,
    /// Number of rewrites performed so far. The rewrite node increments it
    /// (saturating), and routing compares it against the rewrite budget.
    pub attempt: u32,
    /// Generated answer text; `None` until the generate node has run.
    pub answer: Option<String>,
}
214
215impl CorrectiveRagState {
216 #[must_use]
218 pub fn from_query(query: impl Into<String>) -> Self {
219 let query = query.into();
220 Self {
221 original_query: query.clone(),
222 query,
223 previous_attempts: Vec::new(),
224 graded: Vec::new(),
225 correct_documents: Vec::new(),
226 attempt: 0,
227 answer: None,
228 }
229 }
230}
231
232pub fn build_corrective_rag_graph<Ret, G, R, M>(
237 retriever: Arc<Ret>,
238 grader: G,
239 rewriter: R,
240 generator: M,
241 config: CragConfig,
242) -> Result<CompiledGraph<CorrectiveRagState>>
243where
244 Ret: Retriever + ?Sized + 'static,
245 G: RetrievalGrader + 'static,
246 R: QueryRewriter + 'static,
247 M: Runnable<Vec<Message>, Message> + 'static,
248{
249 let config = Arc::new(config);
250 let generator = Arc::new(generator);
251 let grader = Arc::new(grader);
252 let rewriter = Arc::new(rewriter);
253
254 StateGraph::<CorrectiveRagState>::new()
255 .add_node(NODE_RETRIEVE, make_retriever_node(retriever, &config))
256 .add_node(NODE_GRADE, make_grader_node(grader))
257 .add_node(NODE_REWRITE, make_rewriter_node(rewriter))
258 .add_node(NODE_GENERATE, make_generator_node(generator, &config))
259 .set_entry_point(NODE_RETRIEVE)
260 .add_finish_point(NODE_GENERATE)
261 .add_edge(NODE_RETRIEVE, NODE_GRADE)
262 .add_edge(NODE_REWRITE, NODE_RETRIEVE)
263 .add_conditional_edges(
264 NODE_GRADE,
265 {
266 let config = Arc::clone(&config);
267 move |state: &CorrectiveRagState| route_after_grade(state, &config)
268 },
269 [(NODE_REWRITE, NODE_REWRITE), (NODE_GENERATE, NODE_GENERATE)],
270 )
271 .compile()
272}
273
274fn make_retriever_node<Ret>(
275 retriever: Arc<Ret>,
276 config: &Arc<CragConfig>,
277) -> impl Runnable<CorrectiveRagState, CorrectiveRagState> + 'static
278where
279 Ret: Retriever + ?Sized + 'static,
280{
281 let config = Arc::clone(config);
282 RunnableLambda::new(
283 move |mut state: CorrectiveRagState, ctx: ExecutionContext| {
284 let retriever = Arc::clone(&retriever);
285 let top_k = config.retrieval_top_k;
286 async move {
287 let query = RetrievalQuery::new(state.query.clone(), top_k);
288 let docs = retriever.retrieve(query, &ctx).await?;
289 state.graded = docs
290 .into_iter()
291 .map(|d| (d, GradeVerdict::Ambiguous))
292 .collect();
293 state.correct_documents.clear();
294 Ok::<_, _>(state)
295 }
296 },
297 )
298}
299
300fn make_grader_node<G>(
301 grader: Arc<G>,
302) -> impl Runnable<CorrectiveRagState, CorrectiveRagState> + 'static
303where
304 G: RetrievalGrader + ?Sized + 'static,
305{
306 RunnableLambda::new(
307 move |mut state: CorrectiveRagState, ctx: ExecutionContext| {
308 let grader = Arc::clone(&grader);
309 async move {
310 let mut verdicts = Vec::with_capacity(state.graded.len());
311 for (doc, _) in std::mem::take(&mut state.graded) {
312 let verdict = grader.grade(&state.query, &doc, &ctx).await?;
313 verdicts.push((doc, verdict));
314 }
315 state.correct_documents = verdicts
316 .iter()
317 .filter(|(_, v)| matches!(v, GradeVerdict::Correct))
318 .map(|(d, _)| d.clone())
319 .collect();
320 state.graded = verdicts;
321 Ok::<_, _>(state)
322 }
323 },
324 )
325}
326
327fn make_rewriter_node<R>(
328 rewriter: Arc<R>,
329) -> impl Runnable<CorrectiveRagState, CorrectiveRagState> + 'static
330where
331 R: QueryRewriter + ?Sized + 'static,
332{
333 RunnableLambda::new(
334 move |mut state: CorrectiveRagState, ctx: ExecutionContext| {
335 let rewriter = Arc::clone(&rewriter);
336 async move {
337 let new_query = rewriter
338 .rewrite(&state.original_query, &state.previous_attempts, &ctx)
339 .await?;
340 state.previous_attempts.push(state.query.clone());
341 state.query = new_query;
342 state.attempt = state.attempt.saturating_add(1);
343 Ok::<_, _>(state)
344 }
345 },
346 )
347}
348
349fn make_generator_node<M>(
350 generator: Arc<M>,
351 config: &Arc<CragConfig>,
352) -> impl Runnable<CorrectiveRagState, CorrectiveRagState> + 'static
353where
354 M: Runnable<Vec<Message>, Message> + ?Sized + 'static,
355{
356 let config = Arc::clone(config);
357 RunnableLambda::new(
358 move |mut state: CorrectiveRagState, ctx: ExecutionContext| {
359 let generator = Arc::clone(&generator);
360 let config = Arc::clone(&config);
361 async move {
362 let messages = build_generator_prompt(
363 &state.original_query,
364 &state.correct_documents,
365 &state.graded,
366 &config,
367 );
368 let reply = generator.invoke(messages, &ctx).await?;
369 let text = extract_text(&reply);
370 state.answer = Some(text);
371 Ok::<_, _>(state)
372 }
373 },
374 )
375}
376
/// Name given to the agent returned by [`create_corrective_rag_agent`].
pub const CORRECTIVE_RAG_AGENT_NAME: &str = "corrective-rag";
380
// Graph node identifiers. `route_after_grade` returns these names, so the
// rewrite/generate names also serve as the conditional-edge keys.

/// Entry node: retrieves documents for the current query.
const NODE_RETRIEVE: &str = "retrieve";
/// Grades every retrieved document.
const NODE_GRADE: &str = "grade";
/// Rewrites the query and loops back to retrieval.
const NODE_REWRITE: &str = "rewrite";
/// Finish node: generates the final answer.
const NODE_GENERATE: &str = "generate";
389
390pub fn create_corrective_rag_agent<Ret, G, R, M>(
407 retriever: Arc<Ret>,
408 grader: G,
409 rewriter: R,
410 generator: M,
411 config: CragConfig,
412) -> Result<Agent<CorrectiveRagState>>
413where
414 Ret: Retriever + ?Sized + 'static,
415 G: RetrievalGrader + 'static,
416 R: QueryRewriter + 'static,
417 M: Runnable<Vec<Message>, Message> + 'static,
418{
419 let graph = build_corrective_rag_graph(retriever, grader, rewriter, generator, config)?;
420 Agent::builder()
421 .with_name(CORRECTIVE_RAG_AGENT_NAME)
422 .with_runnable(graph)
423 .build()
424}
425
426fn route_after_grade(state: &CorrectiveRagState, config: &CragConfig) -> String {
437 let total = state.graded.len();
438 let correct = state.correct_documents.len();
439 let fraction = if total == 0 {
440 0.0_f32
441 } else {
442 #[allow(clippy::cast_precision_loss)]
446 let n = total as f32;
447 #[allow(clippy::cast_precision_loss)]
448 let k = correct as f32;
449 k / n
450 };
451 let threshold = config.min_correct_fraction.clamp(0.0, 1.0);
452 let budget_remaining = state.attempt < config.max_rewrite_attempts;
453 if fraction >= threshold && correct > 0 {
454 NODE_GENERATE.to_owned()
455 } else if budget_remaining {
456 NODE_REWRITE.to_owned()
457 } else {
458 NODE_GENERATE.to_owned()
459 }
460}
461
462fn build_generator_prompt(
468 original_query: &str,
469 correct_documents: &[RetrievedDocument],
470 fallback_graded: &[(RetrievedDocument, GradeVerdict)],
471 config: &CragConfig,
472) -> Vec<Message> {
473 let evidence: String = if correct_documents.is_empty() {
474 let mut docs: Vec<&RetrievedDocument> = fallback_graded
478 .iter()
479 .filter_map(|(d, v)| matches!(v, GradeVerdict::Ambiguous).then_some(d))
480 .collect();
481 if docs.is_empty() {
482 docs = fallback_graded.iter().map(|(d, _)| d).collect();
483 }
484 format_documents(&docs)
485 } else {
486 let docs: Vec<&RetrievedDocument> = correct_documents.iter().collect();
487 format_documents(&docs)
488 };
489
490 let mut messages: Vec<Message> = config
491 .generator_system_prompt
492 .blocks()
493 .iter()
494 .map(|b| Message::new(Role::System, vec![ContentPart::text(b.text.clone())]))
495 .collect();
496 messages.push(Message::new(
497 Role::User,
498 vec![
499 ContentPart::text(format!("<query>\n{original_query}\n</query>")),
500 ContentPart::text(format!("<documents>\n{evidence}\n</documents>")),
501 ],
502 ));
503 messages
504}
505
506fn format_documents(docs: &[&RetrievedDocument]) -> String {
510 use std::fmt::Write as _;
511 let mut out = String::new();
512 for (idx, doc) in docs.iter().enumerate() {
513 if idx > 0 {
514 out.push_str("\n\n");
515 }
516 write!(&mut out, "[{idx}] {}", doc.content).expect("writing to String never fails");
517 }
518 out
519}
520
521fn extract_text(message: &Message) -> String {
525 let mut buf = String::new();
526 for part in &message.content {
527 if let ContentPart::Text { text, .. } = part {
528 if !buf.is_empty() {
529 buf.push('\n');
530 }
531 buf.push_str(text);
532 }
533 }
534 buf.trim().to_owned()
535}
536
#[cfg(test)]
mod tests {
    use super::*;
    use crate::corrective::grader::RetrievalGrader;
    use crate::corrective::rewriter::QueryRewriter;
    use async_trait::async_trait;
    use entelix_memory::Document as RetrievedDocument;
    use std::sync::Mutex;

    /// Retriever stub. It returns a fixed document list for each exact query
    /// string (empty for unknown queries) and records every query it receives,
    /// in call order.
    struct StaticRetriever {
        docs_by_query: std::collections::HashMap<String, Vec<RetrievedDocument>>,
        observed_queries: Mutex<Vec<String>>,
    }

    impl StaticRetriever {
        fn new() -> Self {
            Self {
                docs_by_query: std::collections::HashMap::new(),
                observed_queries: Mutex::new(Vec::new()),
            }
        }

        /// Registers the documents returned for `query`.
        fn with(mut self, query: &str, docs: Vec<RetrievedDocument>) -> Self {
            self.docs_by_query.insert(query.to_owned(), docs);
            self
        }

        /// Snapshot of the queries seen so far.
        fn observed(&self) -> Vec<String> {
            self.observed_queries.lock().unwrap().clone()
        }
    }

    #[async_trait]
    impl Retriever for StaticRetriever {
        async fn retrieve(
            &self,
            query: RetrievalQuery,
            _ctx: &ExecutionContext,
        ) -> Result<Vec<RetrievedDocument>> {
            self.observed_queries
                .lock()
                .unwrap()
                .push(query.text.clone());
            Ok(self
                .docs_by_query
                .get(&query.text)
                .cloned()
                .unwrap_or_default())
        }
    }

    /// Grader stub keyed on document content. Unknown documents grade as
    /// `Ambiguous`, and the query is ignored.
    struct ScriptedGrader {
        verdicts: std::collections::HashMap<String, GradeVerdict>,
    }

    impl ScriptedGrader {
        fn new(map: &[(&str, GradeVerdict)]) -> Self {
            Self {
                verdicts: map.iter().map(|(k, v)| ((*k).to_owned(), *v)).collect(),
            }
        }
    }

    #[async_trait]
    impl RetrievalGrader for ScriptedGrader {
        fn name(&self) -> &'static str {
            "scripted-grader"
        }
        async fn grade(
            &self,
            _query: &str,
            doc: &RetrievedDocument,
            _ctx: &ExecutionContext,
        ) -> Result<GradeVerdict> {
            Ok(self
                .verdicts
                .get(&doc.content)
                .copied()
                .unwrap_or(GradeVerdict::Ambiguous))
        }
    }

    /// Rewriter stub. It hands out the scripted replies in the given order,
    /// then `"<exhausted>"` once they run out. Its inputs are ignored.
    struct ScriptedRewriter {
        replies: Mutex<Vec<String>>,
    }

    impl ScriptedRewriter {
        fn new(replies: &[&str]) -> Self {
            Self {
                // Stored reversed so that `pop()` yields replies in the given order.
                replies: Mutex::new(replies.iter().map(|s| (*s).to_owned()).rev().collect()),
            }
        }
    }

    #[async_trait]
    impl QueryRewriter for ScriptedRewriter {
        fn name(&self) -> &'static str {
            "scripted-rewriter"
        }
        async fn rewrite(
            &self,
            _original: &str,
            _previous: &[String],
            _ctx: &ExecutionContext,
        ) -> Result<String> {
            Ok(self
                .replies
                .lock()
                .unwrap()
                .pop()
                .unwrap_or_else(|| "<exhausted>".to_owned()))
        }
    }

    /// Generator stub. It records every prompt it receives and always replies
    /// with the same assistant text. Clones share the recorded prompts through
    /// the `Arc`, so the test can keep a clone and inspect what the agent sent.
    #[derive(Clone)]
    struct CapturingGenerator {
        observed: Arc<Mutex<Vec<Vec<Message>>>>,
        reply: String,
    }

    impl CapturingGenerator {
        fn new(reply: &str) -> Self {
            Self {
                observed: Arc::new(Mutex::new(Vec::new())),
                reply: reply.to_owned(),
            }
        }
        fn observed(&self) -> Vec<Vec<Message>> {
            self.observed.lock().unwrap().clone()
        }
    }

    #[async_trait]
    impl Runnable<Vec<Message>, Message> for CapturingGenerator {
        async fn invoke(&self, input: Vec<Message>, _ctx: &ExecutionContext) -> Result<Message> {
            self.observed.lock().unwrap().push(input);
            Ok(Message::new(
                Role::Assistant,
                vec![ContentPart::text(self.reply.clone())],
            ))
        }
    }

    /// Shorthand for a document with the given content.
    fn doc(content: &str) -> RetrievedDocument {
        RetrievedDocument::new(content)
    }

    // Every retrieved document is Correct, so the graph goes
    // retrieve -> grade -> generate without rewriting.
    #[tokio::test]
    async fn happy_path_generates_directly_when_all_correct() {
        let retriever = Arc::new(StaticRetriever::new().with(
            "what is alpha?",
            vec![
                doc("alpha is the first letter"),
                doc("alpha is also a particle"),
            ],
        ));
        let grader = ScriptedGrader::new(&[
            ("alpha is the first letter", GradeVerdict::Correct),
            ("alpha is also a particle", GradeVerdict::Correct),
        ]);
        let rewriter = ScriptedRewriter::new(&["never used"]);
        let generator = CapturingGenerator::new("Alpha is the first letter.");

        let agent = create_corrective_rag_agent(
            Arc::clone(&retriever),
            grader,
            rewriter,
            generator.clone(),
            CragConfig::new(),
        )
        .unwrap();
        let final_state = agent
            .execute(
                CorrectiveRagState::from_query("what is alpha?"),
                &ExecutionContext::new(),
            )
            .await
            .unwrap()
            .state;

        assert_eq!(
            final_state.answer.as_deref(),
            Some("Alpha is the first letter.")
        );
        assert_eq!(
            final_state.attempt, 0,
            "no rewrite when retrieval is correct"
        );
        assert_eq!(retriever.observed(), vec!["what is alpha?".to_owned()]);
    }

    // The first retrieval grades Incorrect. One rewrite produces a query whose
    // results grade Correct, and the graph then generates.
    #[tokio::test]
    async fn incorrect_retrieval_triggers_rewrite_and_loops_back() {
        let retriever = Arc::new(
            StaticRetriever::new()
                .with("alpha?", vec![doc("totally off-topic")])
                .with(
                    "what is alpha letter?",
                    vec![doc("alpha is the first letter")],
                ),
        );
        let grader = ScriptedGrader::new(&[
            ("totally off-topic", GradeVerdict::Incorrect),
            ("alpha is the first letter", GradeVerdict::Correct),
        ]);
        let rewriter = ScriptedRewriter::new(&["what is alpha letter?"]);
        let generator = CapturingGenerator::new("Final.");

        let agent = create_corrective_rag_agent(
            Arc::clone(&retriever),
            grader,
            rewriter,
            generator.clone(),
            CragConfig::new(),
        )
        .unwrap();
        let final_state = agent
            .execute(
                CorrectiveRagState::from_query("alpha?"),
                &ExecutionContext::new(),
            )
            .await
            .unwrap()
            .state;

        assert_eq!(final_state.attempt, 1, "exactly one rewrite happened");
        assert_eq!(
            retriever.observed(),
            vec!["alpha?".to_owned(), "what is alpha letter?".to_owned()]
        );
        assert_eq!(final_state.answer.as_deref(), Some("Final."));
    }

    #[tokio::test]
    async fn rewrite_budget_caps_loop_and_surrenders_to_generate() {
        // Every retrieval grades Incorrect, so only the rewrite budget (2) can
        // stop the loop. "q3" is scripted but must never be retrieved.
        let retriever = Arc::new(
            StaticRetriever::new()
                .with("q0", vec![doc("bad-0")])
                .with("q1", vec![doc("bad-1")])
                .with("q2", vec![doc("bad-2")])
                .with("q3", vec![doc("bad-3")]),
        );
        let grader = ScriptedGrader::new(&[
            ("bad-0", GradeVerdict::Incorrect),
            ("bad-1", GradeVerdict::Incorrect),
            ("bad-2", GradeVerdict::Incorrect),
            ("bad-3", GradeVerdict::Incorrect),
        ]);
        let rewriter = ScriptedRewriter::new(&["q1", "q2", "q3"]);
        let generator = CapturingGenerator::new("Surrendered.");

        let agent = create_corrective_rag_agent(
            Arc::clone(&retriever),
            grader,
            rewriter,
            generator.clone(),
            CragConfig::new().with_max_rewrite_attempts(2),
        )
        .unwrap();
        let final_state = agent
            .execute(
                CorrectiveRagState::from_query("q0"),
                &ExecutionContext::new(),
            )
            .await
            .unwrap()
            .state;

        assert_eq!(final_state.attempt, 2, "rewrite budget = 2");
        assert_eq!(
            // The original query plus one retrieval per rewrite: three in total.
            retriever.observed(),
            vec!["q0".to_owned(), "q1".to_owned(), "q2".to_owned()]
        );
        assert_eq!(final_state.answer.as_deref(), Some("Surrendered."));
    }

    // An empty retrieval counts as fraction 0.0 and routes to rewrite while
    // rewrite budget remains.
    #[tokio::test]
    async fn empty_retrieval_loops_to_rewrite_when_budget_remains() {
        let retriever = Arc::new(
            StaticRetriever::new()
                .with("q0", vec![])
                .with("q1", vec![doc("alpha is the first letter")]),
        );
        let grader = ScriptedGrader::new(&[("alpha is the first letter", GradeVerdict::Correct)]);
        let rewriter = ScriptedRewriter::new(&["q1"]);
        let generator = CapturingGenerator::new("Answered.");

        let agent = create_corrective_rag_agent(
            Arc::clone(&retriever),
            grader,
            rewriter,
            generator.clone(),
            CragConfig::new(),
        )
        .unwrap();
        let final_state = agent
            .execute(
                CorrectiveRagState::from_query("q0"),
                &ExecutionContext::new(),
            )
            .await
            .unwrap()
            .state;
        assert_eq!(final_state.attempt, 1);
        assert_eq!(final_state.answer.as_deref(), Some("Answered."));
    }

    // With 2 of 3 documents Correct (above the default 0.5), the graph
    // generates immediately, and only the Correct documents reach the prompt.
    #[tokio::test]
    async fn generator_sees_only_correct_documents_when_mixed_batch() {
        let retriever = Arc::new(StaticRetriever::new().with(
            "alpha?",
            vec![
                doc("alpha is the first letter"),
                doc("alpha is unrelated stuff"),
                doc("more about alpha letter"),
            ],
        ));
        let grader = ScriptedGrader::new(&[
            ("alpha is the first letter", GradeVerdict::Correct),
            ("alpha is unrelated stuff", GradeVerdict::Incorrect),
            ("more about alpha letter", GradeVerdict::Correct),
        ]);
        let rewriter = ScriptedRewriter::new(&["unused"]);
        let generator = CapturingGenerator::new("Answered.");

        let agent = create_corrective_rag_agent(
            Arc::clone(&retriever),
            grader,
            rewriter,
            generator.clone(),
            CragConfig::new(),
        )
        .unwrap();
        let final_state = agent
            .execute(
                CorrectiveRagState::from_query("alpha?"),
                &ExecutionContext::new(),
            )
            .await
            .unwrap()
            .state;

        assert_eq!(final_state.attempt, 0, "2/3 correct = above 0.5 → generate");
        let prompt = generator.observed();
        assert_eq!(prompt.len(), 1);
        let user_msg = prompt[0]
            // The evidence lives in the last User message of the prompt.
            .iter()
            .rfind(|m| matches!(m.role, Role::User))
            .unwrap();
        let docs_part = user_msg
            .content
            .iter()
            .find_map(|p| match p {
                ContentPart::Text { text, .. } if text.contains("documents") => Some(text.clone()),
                _ => None,
            })
            .unwrap();
        assert!(docs_part.contains("alpha is the first letter"));
        assert!(docs_part.contains("more about alpha letter"));
        assert!(!docs_part.contains("alpha is unrelated stuff"));
    }

    #[test]
    fn config_defaults_match_published_constants() {
        let cfg = CragConfig::default();
        assert!((cfg.min_correct_fraction() - DEFAULT_MIN_CORRECT_FRACTION).abs() < f32::EPSILON);
        assert_eq!(cfg.retrieval_top_k(), DEFAULT_RETRIEVAL_TOP_K);
        assert_eq!(cfg.max_rewrite_attempts(), DEFAULT_MAX_REWRITE_ATTEMPTS);
    }

    #[test]
    fn min_correct_fraction_clamped_during_routing() {
        let cfg = CragConfig::new()
            // 2.5 is clamped to 1.0 at routing time. With no rewrite budget, a
            // 1/1-correct batch must still route to generate.
            .with_min_correct_fraction(2.5)
            .with_max_rewrite_attempts(0);
        let state = CorrectiveRagState {
            original_query: "q".into(),
            query: "q".into(),
            previous_attempts: vec![],
            graded: vec![(doc("d"), GradeVerdict::Correct)],
            correct_documents: vec![doc("d")],
            attempt: 0,
            answer: None,
        };
        assert_eq!(route_after_grade(&state, &cfg), "generate");
    }
}