rexis_rag/query/
mod.rs

1//! # Query Processing Module
2//!
3//! Advanced query understanding, transformation, and optimization for RAG systems.
4//!
5//! This module provides state-of-the-art query processing techniques to improve
6//! retrieval accuracy and relevance. It implements multiple strategies for query
7//! enhancement including rewriting, expansion, decomposition, and hypothetical
8//! document generation.
9//!
10//! ## Features
11//!
12//! - **Query Rewriting**: Transform queries for better retrieval
13//! - **Query Expansion**: Add synonyms and related terms
14//! - **Query Classification**: Understand query intent and type
15//! - **Query Decomposition**: Break complex queries into sub-queries
16//! - **HyDE**: Generate hypothetical documents for improved retrieval
17//!
18//! ## Examples
19//!
20//! ### Query Rewriting
21//! ```rust
//! use rrag::query::{QueryRewriteConfig, QueryRewriter};
//!
//! # async fn example() -> rrag::RragResult<()> {
//! let rewriter = QueryRewriter::new(QueryRewriteConfig::default());
//!
//! // `rewrite` yields a collection of `RewriteResult`s, each carrying the
//! // rewritten text and a confidence score.
//! let rewritten = rewriter.rewrite("What's RAG?").await?;
//! for result in &rewritten {
//!     tracing::debug!("{} (confidence {})", result.rewritten_query, result.confidence);
//! }
31//! # Ok(())
32//! # }
33//! ```
34//!
35//! ### Query Decomposition
36//! ```rust
37//! use rrag::query::{QueryDecomposer, DecompositionStrategy};
38//!
39//! # async fn example() -> rrag::RragResult<()> {
40//! let decomposer = QueryDecomposer::new();
41//!
42//! let query = "Compare the performance of BERT and GPT-3 on sentiment analysis";
43//! let sub_queries = decomposer.decompose(query).await?;
44//!
45//! // Results in sub-queries like:
46//! // - "BERT performance on sentiment analysis"
47//! // - "GPT-3 performance on sentiment analysis"
48//! // - "Comparison between BERT and GPT-3"
49//! # Ok(())
50//! # }
51//! ```
52//!
53//! ### HyDE (Hypothetical Document Embeddings)
54//! ```rust
//! use rrag::query::{HyDEConfig, HyDEGenerator};
//! use rrag::EmbeddingProvider;
//! use std::sync::Arc;
//!
//! # async fn example(embedding_provider: Arc<dyn EmbeddingProvider>) -> rrag::RragResult<()> {
//! let hyde = HyDEGenerator::new(HyDEConfig::default(), embedding_provider);
//!
//! let query = "How does photosynthesis work?";
//! let hypothetical_docs = hyde.generate(query).await?;
//!
//! // Use hypothetical documents for retrieval
//! for doc in &hypothetical_docs {
//!     tracing::debug!("Hypothetical answer: {}", doc.hypothetical_answer);
//! }
67//! # Ok(())
68//! # }
69//! ```
70
71pub mod classifier;
72pub mod decomposer;
73pub mod expander;
74pub mod hyde;
75pub mod rewriter;
76
77pub use classifier::{ClassificationResult, QueryClassifier, QueryIntent, QueryType};
78pub use decomposer::{DecompositionStrategy, QueryDecomposer, SubQuery};
79pub use expander::{ExpansionConfig, ExpansionResult, ExpansionStrategy, QueryExpander};
80pub use hyde::{HyDEConfig, HyDEGenerator, HyDEResult};
81pub use rewriter::{QueryRewriteConfig, QueryRewriter, RewriteResult, RewriteStrategy};
82
83use crate::{EmbeddingProvider, RragResult};
84use std::sync::Arc;
85
86/// Main query processor that orchestrates all query enhancement techniques
87pub struct QueryProcessor {
88    /// Query rewriter for transforming queries
89    rewriter: QueryRewriter,
90
91    /// Query expander for adding related terms
92    expander: QueryExpander,
93
94    /// Query classifier for intent detection
95    classifier: QueryClassifier,
96
97    /// Query decomposer for breaking down complex queries
98    decomposer: QueryDecomposer,
99
100    /// HyDE generator for hypothetical document embeddings
101    hyde: Option<HyDEGenerator>,
102
103    /// Configuration
104    config: QueryProcessorConfig,
105}
106
/// Switches and limits that steer a [`QueryProcessor`].
///
/// Each `enable_*` flag gates one stage of `process_query`; the remaining
/// fields control how stage outputs are merged into the final query list.
#[derive(Debug, Clone)]
pub struct QueryProcessorConfig {
    /// Run the rewriting stage.
    pub enable_rewriting: bool,

    /// Run the expansion stage.
    pub enable_expansion: bool,

    /// Run the intent-classification stage.
    pub enable_classification: bool,

    /// Run the decomposition stage.
    pub enable_decomposition: bool,

    /// Run the HyDE stage (additionally requires that a HyDE generator has
    /// been installed on the processor).
    pub enable_hyde: bool,

    /// Upper bound on the number of entries in the final query list.
    pub max_variants: usize,

    /// Minimum confidence a rewrite, expansion or HyDE answer needs in order
    /// to be considered for the final query list.
    pub confidence_threshold: f32,
}
131
132impl Default for QueryProcessorConfig {
133    fn default() -> Self {
134        Self {
135            enable_rewriting: true,
136            enable_expansion: true,
137            enable_classification: true,
138            enable_decomposition: true,
139            enable_hyde: true,
140            max_variants: 5,
141            confidence_threshold: 0.7,
142        }
143    }
144}
145
146/// Complete query processing result
147#[derive(Debug, Clone)]
148pub struct QueryProcessingResult {
149    /// Original query
150    pub original_query: String,
151
152    /// Rewritten queries
153    pub rewritten_queries: Vec<RewriteResult>,
154
155    /// Expanded queries with additional terms
156    pub expanded_queries: Vec<ExpansionResult>,
157
158    /// Query classification results
159    pub classification: Option<ClassificationResult>,
160
161    /// Decomposed sub-queries
162    pub sub_queries: Vec<SubQuery>,
163
164    /// HyDE generated hypothetical documents
165    pub hyde_results: Vec<HyDEResult>,
166
167    /// Final optimized queries for retrieval
168    pub final_queries: Vec<String>,
169
170    /// Processing metadata
171    pub metadata: QueryProcessingMetadata,
172}
173
/// Bookkeeping recorded while a query is being processed.
#[derive(Debug, Clone)]
pub struct QueryProcessingMetadata {
    /// Elapsed processing time, in milliseconds.
    pub processing_time_ms: u64,

    /// Names of the techniques that were applied (for example
    /// `"classification"` or `"rewriting"`), in the order they ran.
    pub techniques_applied: Vec<String>,

    /// Confidence score per technique, keyed by technique name.
    pub confidence_scores: std::collections::HashMap<String, f32>,

    /// Human-readable warnings or notes raised during processing.
    pub warnings: Vec<String>,
}
189
190impl QueryProcessor {
191    /// Create a new query processor
192    pub fn new(config: QueryProcessorConfig) -> Self {
193        let rewriter = QueryRewriter::new(QueryRewriteConfig::default());
194        let expander = QueryExpander::new(ExpansionConfig::default());
195        let classifier = QueryClassifier::new();
196        let decomposer = QueryDecomposer::new();
197
198        Self {
199            rewriter,
200            expander,
201            classifier,
202            decomposer,
203            hyde: None,
204            config,
205        }
206    }
207
208    /// Create with embedding provider for HyDE support
209    pub fn with_embedding_provider(
210        mut self,
211        embedding_provider: Arc<dyn EmbeddingProvider>,
212    ) -> Self {
213        if self.config.enable_hyde {
214            self.hyde = Some(HyDEGenerator::new(
215                HyDEConfig::default(),
216                embedding_provider,
217            ));
218        }
219        self
220    }
221
222    /// Process a query through all enabled techniques
223    pub async fn process_query(&self, query: &str) -> RragResult<QueryProcessingResult> {
224        let start_time = std::time::Instant::now();
225        let mut techniques_applied = Vec::new();
226        let mut confidence_scores = std::collections::HashMap::new();
227        let mut warnings = Vec::new();
228
229        // 1. Classify the query intent
230        let classification = if self.config.enable_classification {
231            techniques_applied.push("classification".to_string());
232            let result = self.classifier.classify(query).await?;
233            confidence_scores.insert("classification".to_string(), result.confidence);
234            Some(result)
235        } else {
236            None
237        };
238
239        // 2. Rewrite the query
240        let rewritten_queries = if self.config.enable_rewriting {
241            techniques_applied.push("rewriting".to_string());
242            let results = self.rewriter.rewrite(query).await?;
243            if results.is_empty() {
244                warnings.push("Query rewriting produced no results".to_string());
245            }
246            results
247        } else {
248            Vec::new()
249        };
250
251        // 3. Expand the query with synonyms and related terms
252        let expanded_queries = if self.config.enable_expansion {
253            techniques_applied.push("expansion".to_string());
254            let results = self.expander.expand(query).await?;
255            confidence_scores.insert(
256                "expansion".to_string(),
257                results.iter().map(|r| r.confidence).fold(0.0, f32::max),
258            );
259            results
260        } else {
261            Vec::new()
262        };
263
264        // 4. Decompose complex queries
265        let sub_queries = if self.config.enable_decomposition {
266            techniques_applied.push("decomposition".to_string());
267            self.decomposer.decompose(query).await?
268        } else {
269            Vec::new()
270        };
271
272        // 5. Generate HyDE hypothetical documents
273        let hyde_results = if self.config.enable_hyde && self.hyde.is_some() {
274            techniques_applied.push("hyde".to_string());
275            let results = self.hyde.as_ref().unwrap().generate(query).await?;
276            confidence_scores.insert(
277                "hyde".to_string(),
278                results.iter().map(|r| r.confidence).fold(0.0, f32::max),
279            );
280            results
281        } else {
282            Vec::new()
283        };
284
285        // 6. Generate final optimized queries
286        let final_queries = self.generate_final_queries(
287            query,
288            &rewritten_queries,
289            &expanded_queries,
290            &sub_queries,
291            &hyde_results,
292            &classification,
293        );
294
295        let processing_time = start_time.elapsed().as_millis() as u64;
296
297        Ok(QueryProcessingResult {
298            original_query: query.to_string(),
299            rewritten_queries,
300            expanded_queries,
301            classification,
302            sub_queries,
303            hyde_results,
304            final_queries,
305            metadata: QueryProcessingMetadata {
306                processing_time_ms: processing_time,
307                techniques_applied,
308                confidence_scores,
309                warnings,
310            },
311        })
312    }
313
314    /// Generate final optimized queries from all processing results
315    fn generate_final_queries(
316        &self,
317        original: &str,
318        rewritten: &[RewriteResult],
319        expanded: &[ExpansionResult],
320        sub_queries: &[SubQuery],
321        hyde: &[HyDEResult],
322        classification: &Option<ClassificationResult>,
323    ) -> Vec<String> {
324        let mut queries = Vec::new();
325
326        // Always include the original query
327        queries.push(original.to_string());
328
329        // Add high-confidence rewritten queries
330        for rewrite in rewritten {
331            if rewrite.confidence >= self.config.confidence_threshold {
332                queries.push(rewrite.rewritten_query.clone());
333            }
334        }
335
336        // Add expanded queries based on intent
337        if let Some(classification) = classification {
338            match classification.intent {
339                QueryIntent::Factual => {
340                    // For factual queries, prefer exact matches
341                    queries.extend(
342                        expanded
343                            .iter()
344                            .filter(|e| e.expansion_type == ExpansionStrategy::Synonyms)
345                            .map(|e| e.expanded_query.clone()),
346                    );
347                }
348                QueryIntent::Conceptual => {
349                    // For conceptual queries, prefer broader expansions
350                    queries.extend(
351                        expanded
352                            .iter()
353                            .filter(|e| e.expansion_type == ExpansionStrategy::Semantic)
354                            .map(|e| e.expanded_query.clone()),
355                    );
356                }
357                _ => {
358                    // Default: include all high-confidence expansions
359                    queries.extend(
360                        expanded
361                            .iter()
362                            .filter(|e| e.confidence >= self.config.confidence_threshold)
363                            .map(|e| e.expanded_query.clone()),
364                    );
365                }
366            }
367        } else {
368            queries.extend(
369                expanded
370                    .iter()
371                    .filter(|e| e.confidence >= self.config.confidence_threshold)
372                    .map(|e| e.expanded_query.clone()),
373            );
374        }
375
376        // Add sub-queries for complex queries
377        queries.extend(sub_queries.iter().map(|sq| sq.query.clone()));
378
379        // Add HyDE queries for semantic search
380        queries.extend(
381            hyde.iter()
382                .filter(|h| h.confidence >= self.config.confidence_threshold)
383                .map(|h| h.hypothetical_answer.clone()),
384        );
385
386        // Deduplicate and limit
387        queries.sort();
388        queries.dedup();
389        queries.truncate(self.config.max_variants);
390
391        queries
392    }
393}