1use crate::RragResult;
8use regex::Regex;
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11
12pub struct QueryRewriter {
14 config: QueryRewriteConfig,
16
17 grammar_patterns: Vec<GrammarPattern>,
19
20 templates: HashMap<String, Vec<String>>,
22
23 transformations: Vec<QueryTransformation>,
25}
26
/// Settings that control which rewrite steps [`QueryRewriter::rewrite`] runs
/// and which of the resulting candidates it returns.
#[derive(Debug, Clone)]
pub struct QueryRewriteConfig {
    /// Run the grammar-correction step.
    pub enable_grammar_correction: bool,

    /// Run the clarification step for queries detected as vague.
    pub enable_clarification: bool,

    /// Run the style-normalization step.
    pub enable_style_normalization: bool,

    /// Run the domain-specific transformation step.
    pub enable_domain_rewriting: bool,

    /// Upper bound on the number of rewrites returned; extra candidates are
    /// dropped from the end of the list.
    pub max_rewrites: usize,

    /// Candidates whose confidence is below this value are discarded.
    pub min_confidence: f32,
}
48
49impl Default for QueryRewriteConfig {
50 fn default() -> Self {
51 Self {
52 enable_grammar_correction: true,
53 enable_clarification: true,
54 enable_style_normalization: true,
55 enable_domain_rewriting: true,
56 max_rewrites: 3,
57 min_confidence: 0.6,
58 }
59 }
60}
61
62#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
64pub enum RewriteStrategy {
65 GrammarCorrection,
67 Clarification,
69 StyleNormalization,
71 DomainSpecific,
73 TemplateBasedRewriting,
75}
76
77struct GrammarPattern {
79 pattern: Regex,
81 replacement: String,
83 confidence: f32,
85}
86
87struct QueryTransformation {
89 name: String,
91 transform: fn(&str) -> Option<String>,
93 confidence: f32,
95 strategy: RewriteStrategy,
97}
98
99#[derive(Debug, Clone)]
101pub struct RewriteResult {
102 pub original_query: String,
104
105 pub rewritten_query: String,
107
108 pub strategy: RewriteStrategy,
110
111 pub confidence: f32,
113
114 pub explanation: String,
116}
117
118impl QueryRewriter {
119 pub fn new(config: QueryRewriteConfig) -> Self {
121 let grammar_patterns = Self::init_grammar_patterns();
122 let templates = Self::init_templates();
123 let transformations = Self::init_transformations();
124
125 Self {
126 config,
127 grammar_patterns,
128 templates,
129 transformations,
130 }
131 }
132
133 pub async fn rewrite(&self, query: &str) -> RragResult<Vec<RewriteResult>> {
135 let mut results = Vec::new();
136
137 if self.config.enable_grammar_correction {
139 if let Some(result) = self.apply_grammar_correction(query) {
140 if result.confidence >= self.config.min_confidence {
141 results.push(result);
142 }
143 }
144 }
145
146 if self.config.enable_clarification {
148 if let Some(result) = self.apply_clarification(query) {
149 if result.confidence >= self.config.min_confidence {
150 results.push(result);
151 }
152 }
153 }
154
155 if self.config.enable_style_normalization {
157 if let Some(result) = self.apply_style_normalization(query) {
158 if result.confidence >= self.config.min_confidence {
159 results.push(result);
160 }
161 }
162 }
163
164 if self.config.enable_domain_rewriting {
166 let domain_results = self.apply_domain_rewriting(query);
167 results.extend(
168 domain_results
169 .into_iter()
170 .filter(|r| r.confidence >= self.config.min_confidence),
171 );
172 }
173
174 results.truncate(self.config.max_rewrites);
176
177 Ok(results)
178 }
179
180 fn apply_grammar_correction(&self, query: &str) -> Option<RewriteResult> {
182 for pattern in &self.grammar_patterns {
183 if let Some(rewritten) = pattern.apply(query) {
184 if rewritten != query {
185 return Some(RewriteResult {
186 original_query: query.to_string(),
187 rewritten_query: rewritten,
188 strategy: RewriteStrategy::GrammarCorrection,
189 confidence: pattern.confidence,
190 explanation: "Applied grammar correction".to_string(),
191 });
192 }
193 }
194 }
195 None
196 }
197
198 fn apply_clarification(&self, query: &str) -> Option<RewriteResult> {
200 if self.is_vague_query(query) {
202 let clarified = self.clarify_query(query);
203 if let Some(clarified_query) = clarified {
204 return Some(RewriteResult {
205 original_query: query.to_string(),
206 rewritten_query: clarified_query,
207 strategy: RewriteStrategy::Clarification,
208 confidence: 0.7,
209 explanation: "Added clarifying information to vague query".to_string(),
210 });
211 }
212 }
213 None
214 }
215
216 fn apply_style_normalization(&self, query: &str) -> Option<RewriteResult> {
218 let normalized = self.normalize_style(query);
219 if normalized != query {
220 Some(RewriteResult {
221 original_query: query.to_string(),
222 rewritten_query: normalized,
223 strategy: RewriteStrategy::StyleNormalization,
224 confidence: 0.8,
225 explanation: "Normalized query style".to_string(),
226 })
227 } else {
228 None
229 }
230 }
231
232 fn apply_domain_rewriting(&self, query: &str) -> Vec<RewriteResult> {
234 let mut results = Vec::new();
235
236 for transformation in &self.transformations {
238 if let Some(transformed) = (transformation.transform)(query) {
239 if transformed != query {
240 results.push(RewriteResult {
241 original_query: query.to_string(),
242 rewritten_query: transformed,
243 strategy: transformation.strategy.clone(),
244 confidence: transformation.confidence,
245 explanation: format!("Applied {}", transformation.name),
246 });
247 }
248 }
249 }
250
251 results
252 }
253
254 fn is_vague_query(&self, query: &str) -> bool {
256 let vague_patterns = [
257 r"^(what|how|why|when|where)\s+is\s+\w+\?*$",
258 r"^(tell me about|about|info on)\s+\w+\?*$",
259 r"^\w{1,3}\?*$", ];
261
262 let query_lower = query.to_lowercase();
263 for pattern in &vague_patterns {
264 if Regex::new(pattern).unwrap().is_match(&query_lower) {
265 return true;
266 }
267 }
268
269 false
270 }
271
272 fn clarify_query(&self, query: &str) -> Option<String> {
274 let query_lower = query.to_lowercase();
275
276 if query_lower.starts_with("what is") {
278 return Some(format!(
279 "{} and how does it work?",
280 query.trim_end_matches('?')
281 ));
282 }
283
284 if query_lower.starts_with("how") {
285 return Some(format!("{} step by step", query.trim_end_matches('?')));
286 }
287
288 if query_lower.starts_with("tell me about") {
289 return Some(query_lower.replace("tell me about", "explain the concept of"));
290 }
291
292 None
293 }
294
295 fn normalize_style(&self, query: &str) -> String {
297 let mut normalized = query.to_string();
298
299 normalized = Regex::new(r"[!]{2,}")
301 .unwrap()
302 .replace_all(&normalized, "!")
303 .to_string();
304 normalized = Regex::new(r"[?]{2,}")
305 .unwrap()
306 .replace_all(&normalized, "?")
307 .to_string();
308
309 normalized = Regex::new(r"\s+")
311 .unwrap()
312 .replace_all(&normalized, " ")
313 .to_string();
314
315 if let Some(first_char) = normalized.chars().next() {
317 normalized = first_char.to_uppercase().collect::<String>() + &normalized[1..];
318 }
319
320 if self.is_question(&normalized) && !normalized.ends_with('?') {
322 normalized.push('?');
323 }
324
325 normalized.trim().to_string()
326 }
327
328 fn is_question(&self, query: &str) -> bool {
330 let question_words = [
331 "what", "how", "why", "when", "where", "who", "which", "can", "is", "are", "do", "does",
332 ];
333 let query_lower = query.to_lowercase();
334 question_words
335 .iter()
336 .any(|&word| query_lower.starts_with(word))
337 }
338
339 fn init_grammar_patterns() -> Vec<GrammarPattern> {
341 vec![
342 GrammarPattern {
343 pattern: Regex::new(r"\bteh\b").unwrap(),
344 replacement: "the".to_string(),
345 confidence: 0.9,
346 },
347 GrammarPattern {
348 pattern: Regex::new(r"\badn\b").unwrap(),
349 replacement: "and".to_string(),
350 confidence: 0.9,
351 },
352 GrammarPattern {
353 pattern: Regex::new(r"\bwat\b").unwrap(),
354 replacement: "what".to_string(),
355 confidence: 0.8,
356 },
357 ]
359 }
360
/// Builds the built-in query templates, grouped under the keys `"technical"`
/// and `"comparison"`. Placeholders are written in braces (`{concept}`,
/// `{item1}`, `{item2}`) and are stored verbatim.
fn init_templates() -> HashMap<String, Vec<String>> {
    // (group name, template phrasings)
    let groups: [(&str, [&str; 3]); 2] = [
        (
            "technical",
            [
                "How does {concept} work?",
                "What are the key features of {concept}?",
                "Explain {concept} in detail",
            ],
        ),
        (
            "comparison",
            [
                "Compare {item1} and {item2}",
                "What are the differences between {item1} and {item2}?",
                "{item1} vs {item2} pros and cons",
            ],
        ),
    ];

    groups
        .iter()
        .map(|&(group, phrasings)| {
            let owned: Vec<String> = phrasings.iter().map(|&text| text.to_string()).collect();
            (group.to_string(), owned)
        })
        .collect()
}
385
386 fn init_transformations() -> Vec<QueryTransformation> {
388 vec![
389 QueryTransformation {
390 name: "Convert abbreviations".to_string(),
391 transform: |query| {
392 let mut result = query.to_string();
393 let abbreviations = [
394 ("ML", "machine learning"),
395 ("AI", "artificial intelligence"),
396 ("NLP", "natural language processing"),
397 ("API", "application programming interface"),
398 ("UI", "user interface"),
399 ("UX", "user experience"),
400 ];
401
402 for (abbr, full) in &abbreviations {
403 result = result.replace(abbr, full);
404 }
405
406 if result != query {
407 Some(result)
408 } else {
409 None
410 }
411 },
412 confidence: 0.8,
413 strategy: RewriteStrategy::DomainSpecific,
414 },
415 QueryTransformation {
416 name: "Add technical context".to_string(),
417 transform: |query| {
418 let tech_terms = ["algorithm", "framework", "library", "system"];
419 if tech_terms
420 .iter()
421 .any(|term| query.to_lowercase().contains(term))
422 {
423 Some(format!("{} implementation and usage", query))
424 } else {
425 None
426 }
427 },
428 confidence: 0.6,
429 strategy: RewriteStrategy::DomainSpecific,
430 },
431 ]
432 }
433}
434
435impl GrammarPattern {
436 fn apply(&self, query: &str) -> Option<String> {
438 if self.pattern.is_match(query) {
439 Some(
440 self.pattern
441 .replace_all(query, &self.replacement)
442 .to_string(),
443 )
444 } else {
445 None
446 }
447 }
448}
449
#[cfg(test)]
mod tests {
    use super::*;

    /// A query with a known typo yields at least one rewrite, one of which is
    /// a grammar correction.
    #[tokio::test]
    async fn test_query_rewriter() {
        let rewriter = QueryRewriter::new(QueryRewriteConfig::default());

        let results = rewriter.rewrite("wat is ML?").await.unwrap();
        assert!(!results.is_empty());

        let has_grammar_fix = results
            .iter()
            .any(|r| r.strategy == RewriteStrategy::GrammarCorrection);
        assert!(has_grammar_fix);
    }

    /// Repeated question marks are collapsed and the first letter is
    /// capitalized by the style-normalization step.
    #[tokio::test]
    async fn test_style_normalization() {
        let rewriter = QueryRewriter::new(QueryRewriteConfig::default());

        let results = rewriter.rewrite("how does this work???").await.unwrap();
        let first_normalized = results
            .iter()
            .filter(|r| r.strategy == RewriteStrategy::StyleNormalization)
            .map(|r| r.rewritten_query.as_str())
            .next();

        assert!(first_normalized.is_some());
        assert_eq!(first_normalized, Some("How does this work?"));
    }
}