Skip to main content

entelix_rag/corrective/
rewriter.rs

1//! Query rewriter — `(original_query, failed_attempts) → corrected_query`.
2//!
3//! When the [`crate::RetrievalGrader`] tells the corrective recipe
4//! the retrieved batch is mostly off-topic, the recipe asks the
5//! rewriter for a corrected query and re-runs retrieval. The
6//! rewriter sees the original query plus every prior failed
7//! attempt so it doesn't re-emit the same wording the failed
8//! retrieval already used.
9//!
10//! This module ships the *rewriter primitive* (trait + reference
11//! LLM impl). Operators with a heuristic rewriter (HyDE,
12//! query-expansion via thesaurus, BM25 keyword extraction) write
13//! their own [`QueryRewriter`] and bypass the LLM cost entirely.
14
15use std::sync::Arc;
16
17use async_trait::async_trait;
18use entelix_core::ir::{ContentPart, Message, Role};
19use entelix_core::{Error, ExecutionContext, Result};
20use entelix_runnable::Runnable;
21
/// Async trait the corrective-RAG recipe calls when retrieval
/// quality requires another attempt with a different query.
///
/// Implementations may be LLM-driven, heuristic
/// (query-expansion / synonym-bag), classifier-routed, or any
/// hybrid — the recipe takes whatever string comes back and
/// re-runs retrieval with it.
#[async_trait]
pub trait QueryRewriter: Send + Sync {
    /// Stable rewriter identifier — surfaces in audit dashboards
    /// alongside the per-attempt query string.
    fn name(&self) -> &'static str;

    /// Produce a corrected query.
    ///
    /// * `original` — the user's untouched first query.
    /// * `previous_attempts` — every failed retry since, in
    ///   chronological order (oldest first), so the rewriter can
    ///   avoid emitting a query the recipe has already failed on.
    ///   An empty slice means this is the first rewrite and the
    ///   rewriter sees only `original`.
    /// * `ctx` — execution context forwarded to the call.
    ///
    /// # Errors
    ///
    /// Returns an error when no corrected query can be produced.
    /// The reference [`LlmQueryRewriter`] propagates model errors
    /// unchanged and rejects replies that carry no usable text.
    async fn rewrite(
        &self,
        original: &str,
        previous_attempts: &[String],
        ctx: &ExecutionContext,
    ) -> Result<String>;
}
48
/// Default instruction placed ahead of the query in every rewrite
/// prompt (it is the first text part of the user message). Intended
/// to follow the query-rewrite framing of Corrective RAG (CRAG): the
/// model produces one corrected query string with no surrounding
/// explanation.
pub const DEFAULT_REWRITER_INSTRUCTION: &str = "\
You are a query rewriter. Given the user's original query and any prior failed attempts \
(retrieval did not return useful results), produce a single corrected query that captures \
the user's intent in different words. Reply with only the corrected query string — no \
quotes, no explanation, no surrounding text.";

/// Stable rewriter identifier returned by
/// [`LlmQueryRewriter`]'s [`QueryRewriter::name`].
const LLM_REWRITER_NAME: &str = "llm-query-rewriter";
60
61/// Builder for [`LlmQueryRewriter`].
62pub struct LlmQueryRewriterBuilder<M> {
63    model: Arc<M>,
64    instruction: String,
65}
66
67impl<M> LlmQueryRewriterBuilder<M>
68where
69    M: Runnable<Vec<Message>, Message> + 'static,
70{
71    /// Override the operator-facing instruction. Default matches
72    /// [`DEFAULT_REWRITER_INSTRUCTION`] verbatim.
73    #[must_use]
74    pub fn with_instruction(mut self, instruction: impl Into<String>) -> Self {
75        self.instruction = instruction.into();
76        self
77    }
78
79    /// Finalise into a runnable rewriter.
80    #[must_use]
81    pub fn build(self) -> LlmQueryRewriter<M> {
82        LlmQueryRewriter {
83            model: self.model,
84            instruction: Arc::from(self.instruction),
85        }
86    }
87}
88
89/// Reference LLM-driven [`QueryRewriter`]. Asks the supplied
90/// `Runnable<Vec<Message>, Message>` model for a corrected
91/// query, then trims surrounding whitespace and quote marks.
92pub struct LlmQueryRewriter<M> {
93    model: Arc<M>,
94    instruction: Arc<str>,
95}
96
97impl<M> LlmQueryRewriter<M>
98where
99    M: Runnable<Vec<Message>, Message> + 'static,
100{
101    /// Start a builder bound to the supplied model.
102    #[must_use]
103    pub fn builder(model: Arc<M>) -> LlmQueryRewriterBuilder<M> {
104        LlmQueryRewriterBuilder {
105            model,
106            instruction: DEFAULT_REWRITER_INSTRUCTION.to_owned(),
107        }
108    }
109
110    /// Build the user message that frames one rewrite call. Three
111    /// text parts: instruction, original query, prior attempts.
112    /// Single-message shape so any
113    /// `Runnable<Vec<Message>, Message>` impl executes without
114    /// recipe-side wiring.
115    fn build_prompt(&self, original: &str, previous_attempts: &[String]) -> Vec<Message> {
116        let prior_block = if previous_attempts.is_empty() {
117            "(none)".to_owned()
118        } else {
119            previous_attempts
120                .iter()
121                .enumerate()
122                .map(|(idx, attempt)| format!("attempt {}: {attempt}", idx + 1))
123                .collect::<Vec<_>>()
124                .join("\n")
125        };
126        vec![Message::new(
127            Role::User,
128            vec![
129                ContentPart::text(self.instruction.to_string()),
130                ContentPart::text(format!("<original>\n{original}\n</original>")),
131                ContentPart::text(format!(
132                    "<failed_attempts>\n{prior_block}\n</failed_attempts>"
133                )),
134            ],
135        )]
136    }
137}
138
139impl<M> Clone for LlmQueryRewriter<M> {
140    fn clone(&self) -> Self {
141        Self {
142            model: Arc::clone(&self.model),
143            instruction: Arc::clone(&self.instruction),
144        }
145    }
146}
147
148impl<M> std::fmt::Debug for LlmQueryRewriter<M> {
149    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
150        f.debug_struct("LlmQueryRewriter").finish_non_exhaustive()
151    }
152}
153
154#[async_trait]
155impl<M> QueryRewriter for LlmQueryRewriter<M>
156where
157    M: Runnable<Vec<Message>, Message> + 'static,
158{
159    fn name(&self) -> &'static str {
160        LLM_REWRITER_NAME
161    }
162
163    async fn rewrite(
164        &self,
165        original: &str,
166        previous_attempts: &[String],
167        ctx: &ExecutionContext,
168    ) -> Result<String> {
169        let prompt = self.build_prompt(original, previous_attempts);
170        let reply = self.model.invoke(prompt, ctx).await?;
171        let cleaned = clean_reply(&reply);
172        if cleaned.is_empty() {
173            return Err(Error::invalid_request(
174                "LlmQueryRewriter: model returned no text — rewrite failed",
175            ));
176        }
177        Ok(cleaned)
178    }
179}
180
181/// Strip surrounding whitespace + quote marks from the model's
182/// reply. Pulls every `Text` part out of the message and
183/// concatenates with single newlines; non-text parts (tool-use,
184/// image-output) are skipped — a rewriter that emits tool calls
185/// is a misconfiguration we silently degrade rather than fail
186/// on.
187fn clean_reply(message: &Message) -> String {
188    let mut buf = String::new();
189    for part in &message.content {
190        if let ContentPart::Text { text, .. } = part {
191            if !buf.is_empty() {
192                buf.push('\n');
193            }
194            buf.push_str(text);
195        }
196    }
197    let trimmed = buf.trim();
198    // Strip a single layer of surrounding quotes the model might
199    // emit despite the instruction. Done after `trim` so
200    // whitespace inside the quote pair survives.
201    let stripped = trimmed
202        .strip_prefix('"')
203        .and_then(|s| s.strip_suffix('"'))
204        .or_else(|| {
205            trimmed
206                .strip_prefix('\'')
207                .and_then(|s| s.strip_suffix('\''))
208        })
209        .unwrap_or(trimmed);
210    stripped.to_owned()
211}
212
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Mutex;

    /// One-part assistant reply carrying `text`.
    fn assistant(text: &str) -> Message {
        Message::new(Role::Assistant, vec![ContentPart::text(text)])
    }

    /// Text of the `<failed_attempts>` part (third part of the first
    /// message) of a captured prompt.
    fn failed_attempts_text(prompt: &[Message]) -> String {
        match &prompt[0].content[2] {
            ContentPart::Text { text, .. } => text.clone(),
            _ => panic!("third part must be Text"),
        }
    }

    /// Model stub that replays canned replies in order and records
    /// every prompt it receives, so tests can pin the prompt shape.
    struct ScriptedModel {
        script: Mutex<Vec<Result<Message>>>,
        observed: Mutex<Vec<Vec<Message>>>,
    }

    impl ScriptedModel {
        fn new(replies: Vec<Message>) -> Self {
            // Kept reversed so `pop` hands replies back in the order given.
            let mut script: Vec<Result<Message>> = replies.into_iter().map(Ok).collect();
            script.reverse();
            Self {
                script: Mutex::new(script),
                observed: Mutex::new(Vec::new()),
            }
        }

        fn observed(&self) -> Vec<Vec<Message>> {
            let guard = self.observed.lock().unwrap();
            guard.clone()
        }
    }

    #[async_trait]
    impl Runnable<Vec<Message>, Message> for ScriptedModel {
        async fn invoke(&self, input: Vec<Message>, _ctx: &ExecutionContext) -> Result<Message> {
            self.observed.lock().unwrap().push(input);
            let next = self.script.lock().unwrap().pop();
            next.expect("script exhausted")
        }
    }

    /// Run a first-attempt rewrite of `"alpha"` against a model that
    /// answers with `reply`.
    async fn rewrite_reply(reply: &str) -> Result<String> {
        let model = Arc::new(ScriptedModel::new(vec![assistant(reply)]));
        let rewriter = LlmQueryRewriter::builder(model).build();
        let ctx = ExecutionContext::new();
        let outcome = rewriter.rewrite("alpha", &[], &ctx).await;
        outcome
    }

    #[tokio::test]
    async fn first_attempt_sees_only_original_query() {
        let model = Arc::new(ScriptedModel::new(vec![assistant("alpha letter explanation")]));
        let rewriter = LlmQueryRewriter::builder(Arc::clone(&model)).build();
        let ctx = ExecutionContext::new();
        let out = rewriter.rewrite("what is alpha?", &[], &ctx).await.unwrap();
        assert_eq!(out, "alpha letter explanation");

        // With no prior attempts the third part must say "(none)"
        // rather than leaving `<failed_attempts>` empty, which the
        // model could mistake for a malformed prompt.
        let prompts = model.observed();
        assert!(failed_attempts_text(&prompts[0]).contains("(none)"));
    }

    #[tokio::test]
    async fn subsequent_attempts_carry_prior_history() {
        let model = Arc::new(ScriptedModel::new(vec![assistant(
            "what does alpha denote in linear algebra?",
        )]));
        let rewriter = LlmQueryRewriter::builder(Arc::clone(&model)).build();
        let history = ["alpha?".to_owned(), "alpha letter".to_owned()];
        let ctx = ExecutionContext::new();
        rewriter.rewrite("alpha", &history, &ctx).await.unwrap();

        let prompts = model.observed();
        let attempts = failed_attempts_text(&prompts[0]);
        assert!(attempts.contains("attempt 1: alpha?"));
        assert!(attempts.contains("attempt 2: alpha letter"));
    }

    #[tokio::test]
    async fn double_quotes_stripped_from_reply() {
        let out = rewrite_reply("\"alpha definition with examples\"").await.unwrap();
        assert_eq!(out, "alpha definition with examples");
    }

    #[tokio::test]
    async fn single_quotes_stripped_from_reply() {
        let out = rewrite_reply("'alpha primer'").await.unwrap();
        assert_eq!(out, "alpha primer");
    }

    #[tokio::test]
    async fn whitespace_around_reply_trimmed() {
        let out = rewrite_reply("   alpha primer\n").await.unwrap();
        assert_eq!(out, "alpha primer");
    }

    #[tokio::test]
    async fn empty_reply_surfaces_invalid_request_error() {
        // A model that produces no usable text is a structural
        // failure — silently accepting an empty rewrite would loop
        // the corrective recipe on the same retrieval.
        let err = rewrite_reply("   \n  ").await.unwrap_err();
        assert!(matches!(err, Error::InvalidRequest(_)));
    }

    #[tokio::test]
    async fn model_error_propagates() {
        // Model stub that always fails with a transient HTTP 503.
        struct FailingModel;

        #[async_trait]
        impl Runnable<Vec<Message>, Message> for FailingModel {
            async fn invoke(
                &self,
                _input: Vec<Message>,
                _ctx: &ExecutionContext,
            ) -> Result<Message> {
                Err(Error::provider_http(503, "transient"))
            }
        }

        let rewriter = LlmQueryRewriter::builder(Arc::new(FailingModel)).build();
        let ctx = ExecutionContext::new();
        let err = rewriter.rewrite("alpha", &[], &ctx).await.unwrap_err();
        let is_http_503 = matches!(
            err,
            Error::Provider {
                kind: entelix_core::ProviderErrorKind::Http(503),
                ..
            }
        );
        assert!(is_http_503);
    }
}