1pub mod classifier;
72pub mod decomposer;
73pub mod expander;
74pub mod hyde;
75pub mod rewriter;
76
77pub use classifier::{ClassificationResult, QueryClassifier, QueryIntent, QueryType};
78pub use decomposer::{DecompositionStrategy, QueryDecomposer, SubQuery};
79pub use expander::{ExpansionConfig, ExpansionResult, ExpansionStrategy, QueryExpander};
80pub use hyde::{HyDEConfig, HyDEGenerator, HyDEResult};
81pub use rewriter::{QueryRewriteConfig, QueryRewriter, RewriteResult, RewriteStrategy};
82
83use crate::{EmbeddingProvider, RragResult};
84use std::sync::Arc;
85
86pub struct QueryProcessor {
88 rewriter: QueryRewriter,
90
91 expander: QueryExpander,
93
94 classifier: QueryClassifier,
96
97 decomposer: QueryDecomposer,
99
100 hyde: Option<HyDEGenerator>,
102
103 config: QueryProcessorConfig,
105}
106
/// Stage switches and limits for [`QueryProcessor`].
///
/// Configurations compare by value, so two configs can be checked for
/// equality directly.
#[derive(Debug, Clone, PartialEq)]
pub struct QueryProcessorConfig {
    /// Run the rewriting stage.
    pub enable_rewriting: bool,

    /// Run the expansion stage.
    pub enable_expansion: bool,

    /// Run the classification stage.
    pub enable_classification: bool,

    /// Run the decomposition stage.
    pub enable_decomposition: bool,

    /// Run the HyDE stage. It additionally requires a generator, which the
    /// processor only has after it was given an embedding provider.
    pub enable_hyde: bool,

    /// Maximum number of entries kept in the final query list.
    pub max_variants: usize,

    /// Minimum confidence a rewritten query or HyDE answer needs to enter the
    /// final query list. Expanded queries are held to the same threshold unless
    /// the classified intent is `Factual` or `Conceptual`, in which case they
    /// are selected by expansion type instead.
    pub confidence_threshold: f32,
}

impl Default for QueryProcessorConfig {
    /// Enables every stage, keeps at most 5 final queries and uses a
    /// confidence threshold of 0.7.
    fn default() -> Self {
        Self {
            enable_rewriting: true,
            enable_expansion: true,
            enable_classification: true,
            enable_decomposition: true,
            enable_hyde: true,
            max_variants: 5,
            confidence_threshold: 0.7,
        }
    }
}
145
146#[derive(Debug, Clone)]
148pub struct QueryProcessingResult {
149 pub original_query: String,
151
152 pub rewritten_queries: Vec<RewriteResult>,
154
155 pub expanded_queries: Vec<ExpansionResult>,
157
158 pub classification: Option<ClassificationResult>,
160
161 pub sub_queries: Vec<SubQuery>,
163
164 pub hyde_results: Vec<HyDEResult>,
166
167 pub final_queries: Vec<String>,
169
170 pub metadata: QueryProcessingMetadata,
172}
173
/// Bookkeeping recorded while a query is processed.
///
/// `Default` yields an empty record (zero time, no techniques, scores or
/// warnings); records compare by value.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct QueryProcessingMetadata {
    /// Wall-clock time spent processing the query, in milliseconds.
    pub processing_time_ms: u64,

    /// Names of the stages that ran, in execution order
    /// (`"classification"`, `"rewriting"`, `"expansion"`, `"decomposition"`, `"hyde"`).
    pub techniques_applied: Vec<String>,

    /// Confidence per stage name. `"classification"` holds the classifier's
    /// confidence; `"expansion"` and `"hyde"` hold the highest confidence among
    /// that stage's results (0.0 when the stage returned nothing).
    pub confidence_scores: std::collections::HashMap<String, f32>,

    /// Non-fatal observations, e.g. that rewriting produced no results.
    pub warnings: Vec<String>,
}
189
190impl QueryProcessor {
191 pub fn new(config: QueryProcessorConfig) -> Self {
193 let rewriter = QueryRewriter::new(QueryRewriteConfig::default());
194 let expander = QueryExpander::new(ExpansionConfig::default());
195 let classifier = QueryClassifier::new();
196 let decomposer = QueryDecomposer::new();
197
198 Self {
199 rewriter,
200 expander,
201 classifier,
202 decomposer,
203 hyde: None,
204 config,
205 }
206 }
207
208 pub fn with_embedding_provider(
210 mut self,
211 embedding_provider: Arc<dyn EmbeddingProvider>,
212 ) -> Self {
213 if self.config.enable_hyde {
214 self.hyde = Some(HyDEGenerator::new(
215 HyDEConfig::default(),
216 embedding_provider,
217 ));
218 }
219 self
220 }
221
222 pub async fn process_query(&self, query: &str) -> RragResult<QueryProcessingResult> {
224 let start_time = std::time::Instant::now();
225 let mut techniques_applied = Vec::new();
226 let mut confidence_scores = std::collections::HashMap::new();
227 let mut warnings = Vec::new();
228
229 let classification = if self.config.enable_classification {
231 techniques_applied.push("classification".to_string());
232 let result = self.classifier.classify(query).await?;
233 confidence_scores.insert("classification".to_string(), result.confidence);
234 Some(result)
235 } else {
236 None
237 };
238
239 let rewritten_queries = if self.config.enable_rewriting {
241 techniques_applied.push("rewriting".to_string());
242 let results = self.rewriter.rewrite(query).await?;
243 if results.is_empty() {
244 warnings.push("Query rewriting produced no results".to_string());
245 }
246 results
247 } else {
248 Vec::new()
249 };
250
251 let expanded_queries = if self.config.enable_expansion {
253 techniques_applied.push("expansion".to_string());
254 let results = self.expander.expand(query).await?;
255 confidence_scores.insert(
256 "expansion".to_string(),
257 results.iter().map(|r| r.confidence).fold(0.0, f32::max),
258 );
259 results
260 } else {
261 Vec::new()
262 };
263
264 let sub_queries = if self.config.enable_decomposition {
266 techniques_applied.push("decomposition".to_string());
267 self.decomposer.decompose(query).await?
268 } else {
269 Vec::new()
270 };
271
272 let hyde_results = if self.config.enable_hyde && self.hyde.is_some() {
274 techniques_applied.push("hyde".to_string());
275 let results = self.hyde.as_ref().unwrap().generate(query).await?;
276 confidence_scores.insert(
277 "hyde".to_string(),
278 results.iter().map(|r| r.confidence).fold(0.0, f32::max),
279 );
280 results
281 } else {
282 Vec::new()
283 };
284
285 let final_queries = self.generate_final_queries(
287 query,
288 &rewritten_queries,
289 &expanded_queries,
290 &sub_queries,
291 &hyde_results,
292 &classification,
293 );
294
295 let processing_time = start_time.elapsed().as_millis() as u64;
296
297 Ok(QueryProcessingResult {
298 original_query: query.to_string(),
299 rewritten_queries,
300 expanded_queries,
301 classification,
302 sub_queries,
303 hyde_results,
304 final_queries,
305 metadata: QueryProcessingMetadata {
306 processing_time_ms: processing_time,
307 techniques_applied,
308 confidence_scores,
309 warnings,
310 },
311 })
312 }
313
314 fn generate_final_queries(
316 &self,
317 original: &str,
318 rewritten: &[RewriteResult],
319 expanded: &[ExpansionResult],
320 sub_queries: &[SubQuery],
321 hyde: &[HyDEResult],
322 classification: &Option<ClassificationResult>,
323 ) -> Vec<String> {
324 let mut queries = Vec::new();
325
326 queries.push(original.to_string());
328
329 for rewrite in rewritten {
331 if rewrite.confidence >= self.config.confidence_threshold {
332 queries.push(rewrite.rewritten_query.clone());
333 }
334 }
335
336 if let Some(classification) = classification {
338 match classification.intent {
339 QueryIntent::Factual => {
340 queries.extend(
342 expanded
343 .iter()
344 .filter(|e| e.expansion_type == ExpansionStrategy::Synonyms)
345 .map(|e| e.expanded_query.clone()),
346 );
347 }
348 QueryIntent::Conceptual => {
349 queries.extend(
351 expanded
352 .iter()
353 .filter(|e| e.expansion_type == ExpansionStrategy::Semantic)
354 .map(|e| e.expanded_query.clone()),
355 );
356 }
357 _ => {
358 queries.extend(
360 expanded
361 .iter()
362 .filter(|e| e.confidence >= self.config.confidence_threshold)
363 .map(|e| e.expanded_query.clone()),
364 );
365 }
366 }
367 } else {
368 queries.extend(
369 expanded
370 .iter()
371 .filter(|e| e.confidence >= self.config.confidence_threshold)
372 .map(|e| e.expanded_query.clone()),
373 );
374 }
375
376 queries.extend(sub_queries.iter().map(|sq| sq.query.clone()));
378
379 queries.extend(
381 hyde.iter()
382 .filter(|h| h.confidence >= self.config.confidence_threshold)
383 .map(|h| h.hypothetical_answer.clone()),
384 );
385
386 queries.sort();
388 queries.dedup();
389 queries.truncate(self.config.max_variants);
390
391 queries
392 }
393}