rexis_rag/query/
rewriter.rs

1//! # Query Rewriter
2//!
3//! Intelligent query rewriting for improved search quality.
4//! Implements multiple rewriting strategies including grammar correction,
5//! clarification, and style normalization.
6
7use crate::RragResult;
8use regex::Regex;
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11
12/// Query rewriter for transforming user queries
13pub struct QueryRewriter {
14    /// Configuration
15    config: QueryRewriteConfig,
16
17    /// Grammar patterns for correction
18    grammar_patterns: Vec<GrammarPattern>,
19
20    /// Query templates for different domains
21    templates: HashMap<String, Vec<String>>,
22
23    /// Common query transformations
24    transformations: Vec<QueryTransformation>,
25}
26
/// Settings that control which rewriting strategies run and which of their
/// results are returned.
#[derive(Debug, Clone)]
pub struct QueryRewriteConfig {
    /// Run the grammar-correction strategy.
    pub enable_grammar_correction: bool,

    /// Run the clarification strategy for vague queries.
    pub enable_clarification: bool,

    /// Run the style-normalization strategy.
    pub enable_style_normalization: bool,

    /// Run the domain-specific transformation rules.
    pub enable_domain_rewriting: bool,

    /// Upper bound on the number of rewrites returned for a single query.
    pub max_rewrites: usize,

    /// Rewrites whose confidence is below this value are discarded.
    pub min_confidence: f32,
}
48
49impl Default for QueryRewriteConfig {
50    fn default() -> Self {
51        Self {
52            enable_grammar_correction: true,
53            enable_clarification: true,
54            enable_style_normalization: true,
55            enable_domain_rewriting: true,
56            max_rewrites: 3,
57            min_confidence: 0.6,
58        }
59    }
60}
61
62/// Rewriting strategies
63#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
64pub enum RewriteStrategy {
65    /// Grammar and spelling correction
66    GrammarCorrection,
67    /// Add clarifying information
68    Clarification,
69    /// Normalize writing style
70    StyleNormalization,
71    /// Domain-specific transformations
72    DomainSpecific,
73    /// Template-based rewriting
74    TemplateBasedRewriting,
75}
76
77/// Grammar pattern for correction
78struct GrammarPattern {
79    /// Pattern to match
80    pattern: Regex,
81    /// Replacement template
82    replacement: String,
83    /// Confidence score
84    confidence: f32,
85}
86
87/// Query transformation rule
88struct QueryTransformation {
89    /// Name of the transformation
90    name: String,
91    /// Function to apply transformation
92    transform: fn(&str) -> Option<String>,
93    /// Confidence score
94    confidence: f32,
95    /// Strategy type
96    strategy: RewriteStrategy,
97}
98
99/// Result of query rewriting
100#[derive(Debug, Clone)]
101pub struct RewriteResult {
102    /// Original query
103    pub original_query: String,
104
105    /// Rewritten query
106    pub rewritten_query: String,
107
108    /// Strategy used for rewriting
109    pub strategy: RewriteStrategy,
110
111    /// Confidence score (0.0 to 1.0)
112    pub confidence: f32,
113
114    /// Explanation of the rewrite
115    pub explanation: String,
116}
117
118impl QueryRewriter {
119    /// Create a new query rewriter
120    pub fn new(config: QueryRewriteConfig) -> Self {
121        let grammar_patterns = Self::init_grammar_patterns();
122        let templates = Self::init_templates();
123        let transformations = Self::init_transformations();
124
125        Self {
126            config,
127            grammar_patterns,
128            templates,
129            transformations,
130        }
131    }
132
133    /// Rewrite a query using all enabled strategies
134    pub async fn rewrite(&self, query: &str) -> RragResult<Vec<RewriteResult>> {
135        let mut results = Vec::new();
136
137        // Apply grammar correction
138        if self.config.enable_grammar_correction {
139            if let Some(result) = self.apply_grammar_correction(query) {
140                if result.confidence >= self.config.min_confidence {
141                    results.push(result);
142                }
143            }
144        }
145
146        // Apply clarification
147        if self.config.enable_clarification {
148            if let Some(result) = self.apply_clarification(query) {
149                if result.confidence >= self.config.min_confidence {
150                    results.push(result);
151                }
152            }
153        }
154
155        // Apply style normalization
156        if self.config.enable_style_normalization {
157            if let Some(result) = self.apply_style_normalization(query) {
158                if result.confidence >= self.config.min_confidence {
159                    results.push(result);
160                }
161            }
162        }
163
164        // Apply domain-specific rewriting
165        if self.config.enable_domain_rewriting {
166            let domain_results = self.apply_domain_rewriting(query);
167            results.extend(
168                domain_results
169                    .into_iter()
170                    .filter(|r| r.confidence >= self.config.min_confidence),
171            );
172        }
173
174        // Limit results
175        results.truncate(self.config.max_rewrites);
176
177        Ok(results)
178    }
179
180    /// Apply grammar correction patterns
181    fn apply_grammar_correction(&self, query: &str) -> Option<RewriteResult> {
182        for pattern in &self.grammar_patterns {
183            if let Some(rewritten) = pattern.apply(query) {
184                if rewritten != query {
185                    return Some(RewriteResult {
186                        original_query: query.to_string(),
187                        rewritten_query: rewritten,
188                        strategy: RewriteStrategy::GrammarCorrection,
189                        confidence: pattern.confidence,
190                        explanation: "Applied grammar correction".to_string(),
191                    });
192                }
193            }
194        }
195        None
196    }
197
198    /// Apply query clarification
199    fn apply_clarification(&self, query: &str) -> Option<RewriteResult> {
200        // Check if query is too vague or ambiguous
201        if self.is_vague_query(query) {
202            let clarified = self.clarify_query(query);
203            if let Some(clarified_query) = clarified {
204                return Some(RewriteResult {
205                    original_query: query.to_string(),
206                    rewritten_query: clarified_query,
207                    strategy: RewriteStrategy::Clarification,
208                    confidence: 0.7,
209                    explanation: "Added clarifying information to vague query".to_string(),
210                });
211            }
212        }
213        None
214    }
215
216    /// Apply style normalization
217    fn apply_style_normalization(&self, query: &str) -> Option<RewriteResult> {
218        let normalized = self.normalize_style(query);
219        if normalized != query {
220            Some(RewriteResult {
221                original_query: query.to_string(),
222                rewritten_query: normalized,
223                strategy: RewriteStrategy::StyleNormalization,
224                confidence: 0.8,
225                explanation: "Normalized query style".to_string(),
226            })
227        } else {
228            None
229        }
230    }
231
232    /// Apply domain-specific rewriting
233    fn apply_domain_rewriting(&self, query: &str) -> Vec<RewriteResult> {
234        let mut results = Vec::new();
235
236        // Apply transformations
237        for transformation in &self.transformations {
238            if let Some(transformed) = (transformation.transform)(query) {
239                if transformed != query {
240                    results.push(RewriteResult {
241                        original_query: query.to_string(),
242                        rewritten_query: transformed,
243                        strategy: transformation.strategy.clone(),
244                        confidence: transformation.confidence,
245                        explanation: format!("Applied {}", transformation.name),
246                    });
247                }
248            }
249        }
250
251        results
252    }
253
254    /// Check if a query is too vague
255    fn is_vague_query(&self, query: &str) -> bool {
256        let vague_patterns = [
257            r"^(what|how|why|when|where)\s+is\s+\w+\?*$",
258            r"^(tell me about|about|info on)\s+\w+\?*$",
259            r"^\w{1,3}\?*$", // Very short queries
260        ];
261
262        let query_lower = query.to_lowercase();
263        for pattern in &vague_patterns {
264            if Regex::new(pattern).unwrap().is_match(&query_lower) {
265                return true;
266            }
267        }
268
269        false
270    }
271
272    /// Clarify a vague query
273    fn clarify_query(&self, query: &str) -> Option<String> {
274        let query_lower = query.to_lowercase();
275
276        // Add context based on common patterns
277        if query_lower.starts_with("what is") {
278            return Some(format!(
279                "{} and how does it work?",
280                query.trim_end_matches('?')
281            ));
282        }
283
284        if query_lower.starts_with("how") {
285            return Some(format!("{} step by step", query.trim_end_matches('?')));
286        }
287
288        if query_lower.starts_with("tell me about") {
289            return Some(query_lower.replace("tell me about", "explain the concept of"));
290        }
291
292        None
293    }
294
295    /// Normalize query style
296    fn normalize_style(&self, query: &str) -> String {
297        let mut normalized = query.to_string();
298
299        // Remove excessive punctuation
300        normalized = Regex::new(r"[!]{2,}")
301            .unwrap()
302            .replace_all(&normalized, "!")
303            .to_string();
304        normalized = Regex::new(r"[?]{2,}")
305            .unwrap()
306            .replace_all(&normalized, "?")
307            .to_string();
308
309        // Fix spacing
310        normalized = Regex::new(r"\s+")
311            .unwrap()
312            .replace_all(&normalized, " ")
313            .to_string();
314
315        // Capitalize first letter
316        if let Some(first_char) = normalized.chars().next() {
317            normalized = first_char.to_uppercase().collect::<String>() + &normalized[1..];
318        }
319
320        // Ensure questions end with question mark
321        if self.is_question(&normalized) && !normalized.ends_with('?') {
322            normalized.push('?');
323        }
324
325        normalized.trim().to_string()
326    }
327
328    /// Check if query is a question
329    fn is_question(&self, query: &str) -> bool {
330        let question_words = [
331            "what", "how", "why", "when", "where", "who", "which", "can", "is", "are", "do", "does",
332        ];
333        let query_lower = query.to_lowercase();
334        question_words
335            .iter()
336            .any(|&word| query_lower.starts_with(word))
337    }
338
339    /// Initialize grammar patterns
340    fn init_grammar_patterns() -> Vec<GrammarPattern> {
341        vec![
342            GrammarPattern {
343                pattern: Regex::new(r"\bteh\b").unwrap(),
344                replacement: "the".to_string(),
345                confidence: 0.9,
346            },
347            GrammarPattern {
348                pattern: Regex::new(r"\badn\b").unwrap(),
349                replacement: "and".to_string(),
350                confidence: 0.9,
351            },
352            GrammarPattern {
353                pattern: Regex::new(r"\bwat\b").unwrap(),
354                replacement: "what".to_string(),
355                confidence: 0.8,
356            },
357            // Add more patterns as needed
358        ]
359    }
360
361    /// Initialize query templates
362    fn init_templates() -> HashMap<String, Vec<String>> {
363        let mut templates = HashMap::new();
364
365        templates.insert(
366            "technical".to_string(),
367            vec![
368                "How does {concept} work?".to_string(),
369                "What are the key features of {concept}?".to_string(),
370                "Explain {concept} in detail".to_string(),
371            ],
372        );
373
374        templates.insert(
375            "comparison".to_string(),
376            vec![
377                "Compare {item1} and {item2}".to_string(),
378                "What are the differences between {item1} and {item2}?".to_string(),
379                "{item1} vs {item2} pros and cons".to_string(),
380            ],
381        );
382
383        templates
384    }
385
386    /// Initialize transformations
387    fn init_transformations() -> Vec<QueryTransformation> {
388        vec![
389            QueryTransformation {
390                name: "Convert abbreviations".to_string(),
391                transform: |query| {
392                    let mut result = query.to_string();
393                    let abbreviations = [
394                        ("ML", "machine learning"),
395                        ("AI", "artificial intelligence"),
396                        ("NLP", "natural language processing"),
397                        ("API", "application programming interface"),
398                        ("UI", "user interface"),
399                        ("UX", "user experience"),
400                    ];
401
402                    for (abbr, full) in &abbreviations {
403                        result = result.replace(abbr, full);
404                    }
405
406                    if result != query {
407                        Some(result)
408                    } else {
409                        None
410                    }
411                },
412                confidence: 0.8,
413                strategy: RewriteStrategy::DomainSpecific,
414            },
415            QueryTransformation {
416                name: "Add technical context".to_string(),
417                transform: |query| {
418                    let tech_terms = ["algorithm", "framework", "library", "system"];
419                    if tech_terms
420                        .iter()
421                        .any(|term| query.to_lowercase().contains(term))
422                    {
423                        Some(format!("{} implementation and usage", query))
424                    } else {
425                        None
426                    }
427                },
428                confidence: 0.6,
429                strategy: RewriteStrategy::DomainSpecific,
430            },
431        ]
432    }
433}
434
435impl GrammarPattern {
436    /// Apply the pattern to a query
437    fn apply(&self, query: &str) -> Option<String> {
438        if self.pattern.is_match(query) {
439            Some(
440                self.pattern
441                    .replace_all(query, &self.replacement)
442                    .to_string(),
443            )
444        } else {
445            None
446        }
447    }
448}
449
#[cfg(test)]
mod tests {
    use super::*;

    /// A misspelled query containing an abbreviation yields both a grammar
    /// correction and an abbreviation expansion.
    #[tokio::test]
    async fn test_query_rewriter() {
        let rewriter = QueryRewriter::new(QueryRewriteConfig::default());

        let results = rewriter.rewrite("wat is ML?").await.unwrap();
        assert!(!results.is_empty());

        // Each strategy rewrites the original query independently, so the
        // grammar result fixes "wat" but leaves "ML" as typed...
        let grammar_corrected = results
            .iter()
            .find(|r| r.strategy == RewriteStrategy::GrammarCorrection)
            .expect("grammar correction should produce a rewrite");
        assert_eq!(grammar_corrected.original_query, "wat is ML?");
        assert_eq!(grammar_corrected.rewritten_query, "what is ML?");

        // ...while the domain result expands "ML" but leaves "wat" as typed.
        let expanded = results
            .iter()
            .find(|r| r.strategy == RewriteStrategy::DomainSpecific)
            .expect("abbreviation expansion should produce a rewrite");
        assert_eq!(expanded.rewritten_query, "wat is machine learning?");
    }

    /// Repeated spaces and question marks are collapsed and the first letter
    /// is capitalized.
    #[tokio::test]
    async fn test_style_normalization() {
        let rewriter = QueryRewriter::new(QueryRewriteConfig::default());

        let results = rewriter.rewrite("how   does  this work???").await.unwrap();
        let normalized = results
            .iter()
            .find(|r| r.strategy == RewriteStrategy::StyleNormalization);

        assert!(normalized.is_some());
        assert_eq!(normalized.unwrap().rewritten_query, "How does this work?");
    }
}