1use std::sync::Arc;
16
17use async_trait::async_trait;
18use entelix_core::ir::{ContentPart, Message, Role};
19use entelix_core::{Error, ExecutionContext, Result};
20use entelix_runnable::Runnable;
21
22#[async_trait]
29pub trait QueryRewriter: Send + Sync {
30 fn name(&self) -> &'static str;
33
34 async fn rewrite(
42 &self,
43 original: &str,
44 previous_attempts: &[String],
45 ctx: &ExecutionContext,
46 ) -> Result<String>;
47}
48
/// Instruction text placed at the top of every prompt built by
/// [`LlmQueryRewriter`], unless replaced through
/// [`LlmQueryRewriterBuilder::with_instruction`].
pub const DEFAULT_REWRITER_INSTRUCTION: &str = concat!(
    "You are a query rewriter. Given the user's original query and any prior failed attempts ",
    "(retrieval did not return useful results), produce a single corrected query that captures ",
    "the user's intent in different words. Reply with only the corrected query string — no ",
    "quotes, no explanation, no surrounding text.",
);

/// Name reported by [`LlmQueryRewriter`] through [`QueryRewriter::name`].
const LLM_REWRITER_NAME: &str = "llm-query-rewriter";
60
61pub struct LlmQueryRewriterBuilder<M> {
63 model: Arc<M>,
64 instruction: String,
65}
66
67impl<M> LlmQueryRewriterBuilder<M>
68where
69 M: Runnable<Vec<Message>, Message> + 'static,
70{
71 #[must_use]
74 pub fn with_instruction(mut self, instruction: impl Into<String>) -> Self {
75 self.instruction = instruction.into();
76 self
77 }
78
79 #[must_use]
81 pub fn build(self) -> LlmQueryRewriter<M> {
82 LlmQueryRewriter {
83 model: self.model,
84 instruction: Arc::from(self.instruction),
85 }
86 }
87}
88
89pub struct LlmQueryRewriter<M> {
93 model: Arc<M>,
94 instruction: Arc<str>,
95}
96
97impl<M> LlmQueryRewriter<M>
98where
99 M: Runnable<Vec<Message>, Message> + 'static,
100{
101 #[must_use]
103 pub fn builder(model: Arc<M>) -> LlmQueryRewriterBuilder<M> {
104 LlmQueryRewriterBuilder {
105 model,
106 instruction: DEFAULT_REWRITER_INSTRUCTION.to_owned(),
107 }
108 }
109
110 fn build_prompt(&self, original: &str, previous_attempts: &[String]) -> Vec<Message> {
116 let prior_block = if previous_attempts.is_empty() {
117 "(none)".to_owned()
118 } else {
119 previous_attempts
120 .iter()
121 .enumerate()
122 .map(|(idx, attempt)| format!("attempt {}: {attempt}", idx + 1))
123 .collect::<Vec<_>>()
124 .join("\n")
125 };
126 vec![Message::new(
127 Role::User,
128 vec![
129 ContentPart::text(self.instruction.to_string()),
130 ContentPart::text(format!("<original>\n{original}\n</original>")),
131 ContentPart::text(format!(
132 "<failed_attempts>\n{prior_block}\n</failed_attempts>"
133 )),
134 ],
135 )]
136 }
137}
138
139impl<M> Clone for LlmQueryRewriter<M> {
140 fn clone(&self) -> Self {
141 Self {
142 model: Arc::clone(&self.model),
143 instruction: Arc::clone(&self.instruction),
144 }
145 }
146}
147
148impl<M> std::fmt::Debug for LlmQueryRewriter<M> {
149 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
150 f.debug_struct("LlmQueryRewriter").finish_non_exhaustive()
151 }
152}
153
154#[async_trait]
155impl<M> QueryRewriter for LlmQueryRewriter<M>
156where
157 M: Runnable<Vec<Message>, Message> + 'static,
158{
159 fn name(&self) -> &'static str {
160 LLM_REWRITER_NAME
161 }
162
163 async fn rewrite(
164 &self,
165 original: &str,
166 previous_attempts: &[String],
167 ctx: &ExecutionContext,
168 ) -> Result<String> {
169 let prompt = self.build_prompt(original, previous_attempts);
170 let reply = self.model.invoke(prompt, ctx).await?;
171 let cleaned = clean_reply(&reply);
172 if cleaned.is_empty() {
173 return Err(Error::invalid_request(
174 "LlmQueryRewriter: model returned no text — rewrite failed",
175 ));
176 }
177 Ok(cleaned)
178 }
179}
180
181fn clean_reply(message: &Message) -> String {
188 let mut buf = String::new();
189 for part in &message.content {
190 if let ContentPart::Text { text, .. } = part {
191 if !buf.is_empty() {
192 buf.push('\n');
193 }
194 buf.push_str(text);
195 }
196 }
197 let trimmed = buf.trim();
198 let stripped = trimmed
202 .strip_prefix('"')
203 .and_then(|s| s.strip_suffix('"'))
204 .or_else(|| {
205 trimmed
206 .strip_prefix('\'')
207 .and_then(|s| s.strip_suffix('\''))
208 })
209 .unwrap_or(trimmed);
210 stripped.to_owned()
211}
212
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Mutex;

    /// Assistant message holding a single text part.
    fn assistant(text: &str) -> Message {
        let parts = vec![ContentPart::text(text)];
        Message::new(Role::Assistant, parts)
    }

    /// Fresh execution context for one call.
    fn ctx() -> ExecutionContext {
        ExecutionContext::new()
    }

    /// Test double that replays canned replies in order and records every
    /// prompt it receives.
    struct ScriptedModel {
        /// Pending replies, stored back-to-front so `pop` yields the next one.
        script: Mutex<Vec<Result<Message>>>,
        /// Prompts received so far, oldest first.
        observed: Mutex<Vec<Vec<Message>>>,
    }

    impl ScriptedModel {
        fn new(replies: Vec<Message>) -> Self {
            let mut script: Vec<Result<Message>> = replies.into_iter().map(Ok).collect();
            script.reverse();
            Self {
                script: Mutex::new(script),
                observed: Mutex::default(),
            }
        }

        fn observed(&self) -> Vec<Vec<Message>> {
            let guard = self.observed.lock().unwrap();
            Vec::clone(&guard)
        }
    }

    #[async_trait]
    impl Runnable<Vec<Message>, Message> for ScriptedModel {
        async fn invoke(&self, input: Vec<Message>, _ctx: &ExecutionContext) -> Result<Message> {
            self.observed.lock().unwrap().push(input);
            let next = self.script.lock().unwrap().pop();
            next.expect("script exhausted")
        }
    }

    /// Assert that the third content part of the first recorded prompt (the
    /// failed-attempts block) contains every string in `needles`.
    fn assert_failed_attempts_contain(prompts: &[Vec<Message>], needles: &[&str]) {
        match &prompts[0][0].content[2] {
            ContentPart::Text { text, .. } => {
                for needle in needles {
                    assert!(text.contains(*needle));
                }
            }
            _ => panic!("third part must be Text"),
        }
    }

    /// Rewrite `"alpha"` once against a model scripted to answer `reply`.
    async fn rewrite_with_reply(reply: &str) -> Result<String> {
        let model = Arc::new(ScriptedModel::new(vec![assistant(reply)]));
        let rewriter = LlmQueryRewriter::builder(model).build();
        let context = ctx();
        rewriter.rewrite("alpha", &[], &context).await
    }

    #[tokio::test]
    async fn first_attempt_sees_only_original_query() {
        let model = Arc::new(ScriptedModel::new(vec![assistant("alpha letter explanation")]));
        let rewriter = LlmQueryRewriter::builder(Arc::clone(&model)).build();

        let out = rewriter.rewrite("what is alpha?", &[], &ctx()).await.unwrap();
        assert_eq!(out, "alpha letter explanation");

        assert_failed_attempts_contain(&model.observed(), &["(none)"]);
    }

    #[tokio::test]
    async fn subsequent_attempts_carry_prior_history() {
        let model = Arc::new(ScriptedModel::new(vec![assistant(
            "what does alpha denote in linear algebra?",
        )]));
        let rewriter = LlmQueryRewriter::builder(Arc::clone(&model)).build();
        let prior = ["alpha?", "alpha letter"].map(String::from);

        rewriter.rewrite("alpha", &prior, &ctx()).await.unwrap();

        assert_failed_attempts_contain(
            &model.observed(),
            &["attempt 1: alpha?", "attempt 2: alpha letter"],
        );
    }

    #[tokio::test]
    async fn double_quotes_stripped_from_reply() {
        let out = rewrite_with_reply("\"alpha definition with examples\"")
            .await
            .unwrap();
        assert_eq!(out, "alpha definition with examples");
    }

    #[tokio::test]
    async fn single_quotes_stripped_from_reply() {
        let out = rewrite_with_reply("'alpha primer'").await.unwrap();
        assert_eq!(out, "alpha primer");
    }

    #[tokio::test]
    async fn whitespace_around_reply_trimmed() {
        let out = rewrite_with_reply(" alpha primer\n").await.unwrap();
        assert_eq!(out, "alpha primer");
    }

    #[tokio::test]
    async fn empty_reply_surfaces_invalid_request_error() {
        // A whitespace-only reply cleans down to nothing and must be rejected.
        let err = rewrite_with_reply(" \n ").await.unwrap_err();
        assert!(matches!(err, Error::InvalidRequest(_)));
    }

    #[tokio::test]
    async fn model_error_propagates() {
        /// Model that always fails with a transient HTTP 503.
        struct FailingModel;

        #[async_trait]
        impl Runnable<Vec<Message>, Message> for FailingModel {
            async fn invoke(
                &self,
                _input: Vec<Message>,
                _ctx: &ExecutionContext,
            ) -> Result<Message> {
                Err(Error::provider_http(503, "transient"))
            }
        }

        let rewriter = LlmQueryRewriter::builder(Arc::new(FailingModel)).build();
        let outcome = rewriter.rewrite("alpha", &[], &ctx()).await;
        assert!(matches!(
            outcome,
            Err(Error::Provider {
                kind: entelix_core::ProviderErrorKind::Http(503),
                ..
            })
        ));
    }
}