pub mod classifier;
pub mod decomposer;
pub mod expander;
pub mod hyde;
pub mod rewriter;
pub use classifier::{ClassificationResult, QueryClassifier, QueryIntent, QueryType};
pub use decomposer::{DecompositionStrategy, QueryDecomposer, SubQuery};
pub use expander::{ExpansionConfig, ExpansionResult, ExpansionStrategy, QueryExpander};
pub use hyde::{HyDEConfig, HyDEGenerator, HyDEResult};
pub use rewriter::{QueryRewriteConfig, QueryRewriter, RewriteResult, RewriteStrategy};
use crate::{EmbeddingProvider, RragResult};
use std::sync::Arc;
/// Orchestrates the query-processing pipeline.
///
/// Runs classification, rewriting, expansion, decomposition and (optionally) HyDE
/// generation on a query. The outputs are combined into a deduplicated list of final
/// queries.
pub struct QueryProcessor {
    /// Produces rewritten variants of the input query.
    rewriter: QueryRewriter,
    /// Produces expanded variants of the input query.
    expander: QueryExpander,
    /// Classifies the query; the resulting intent steers which expansions are kept.
    classifier: QueryClassifier,
    /// Splits the query into sub-queries.
    decomposer: QueryDecomposer,
    /// HyDE generator. It is `None` until an embedding provider is attached via
    /// `with_embedding_provider`, and it is only ever set when HyDE is enabled in
    /// `config`.
    hyde: Option<HyDEGenerator>,
    /// Controls which techniques run and how final queries are filtered and capped.
    config: QueryProcessorConfig,
}
/// Settings controlling which query-processing techniques run and how their output
/// is filtered into the final query list.
#[derive(Debug, Clone)]
pub struct QueryProcessorConfig {
    /// Run the query rewriter.
    pub enable_rewriting: bool,
    /// Run the query expander.
    pub enable_expansion: bool,
    /// Run the query classifier.
    pub enable_classification: bool,
    /// Run the query decomposer.
    pub enable_decomposition: bool,
    /// Run HyDE generation.
    ///
    /// This also requires an embedding provider to be attached through
    /// `QueryProcessor::with_embedding_provider`.
    pub enable_hyde: bool,
    /// Maximum number of entries kept in the final query list.
    pub max_variants: usize,
    /// Minimum confidence a rewrite or HyDE answer must reach to be included in the
    /// final queries.
    ///
    /// The same threshold applies to expansions when there is no classification, or
    /// when the intent is neither factual nor conceptual.
    pub confidence_threshold: f32,
}
impl Default for QueryProcessorConfig {
fn default() -> Self {
Self {
enable_rewriting: true,
enable_expansion: true,
enable_classification: true,
enable_decomposition: true,
enable_hyde: true,
max_variants: 5,
confidence_threshold: 0.7,
}
}
}
/// Full output of `QueryProcessor::process_query`.
///
/// Holds every intermediate technique result alongside the final query list.
#[derive(Debug, Clone)]
pub struct QueryProcessingResult {
    /// The query exactly as passed in.
    pub original_query: String,
    /// All rewrites produced, before any confidence filtering. Empty when rewriting
    /// is disabled.
    pub rewritten_queries: Vec<RewriteResult>,
    /// All expansions produced, before any filtering. Empty when expansion is
    /// disabled.
    pub expanded_queries: Vec<ExpansionResult>,
    /// Classification result; `None` when classification is disabled.
    pub classification: Option<ClassificationResult>,
    /// Sub-queries from decomposition. Empty when decomposition is disabled.
    pub sub_queries: Vec<SubQuery>,
    /// HyDE results. Empty when HyDE is disabled or no embedding provider is attached.
    pub hyde_results: Vec<HyDEResult>,
    /// Deduplicated queries to use for retrieval, capped at
    /// `QueryProcessorConfig::max_variants` entries.
    pub final_queries: Vec<String>,
    /// Timing, techniques applied, confidence scores and warnings.
    pub metadata: QueryProcessingMetadata,
}
/// Diagnostic information collected while processing a single query.
#[derive(Debug, Clone)]
pub struct QueryProcessingMetadata {
    /// Elapsed wall-clock time of `process_query`, in milliseconds.
    ///
    /// Measured with `std::time::Instant`.
    pub processing_time_ms: u64,
    /// Names of the techniques that ran, in execution order.
    ///
    /// Possible entries are "classification", "rewriting", "expansion",
    /// "decomposition" and "hyde".
    pub techniques_applied: Vec<String>,
    /// Confidence per technique, keyed by technique name.
    ///
    /// - "classification" holds the classifier's confidence.
    /// - "expansion" and "hyde" hold the maximum confidence over their results, or
    ///   `0.0` when there were none.
    pub confidence_scores: std::collections::HashMap<String, f32>,
    /// Non-fatal issues encountered during processing.
    pub warnings: Vec<String>,
}
impl QueryProcessor {
    /// Creates a processor from `config`.
    ///
    /// Every sub-component is built with its default configuration. HyDE is not
    /// available until an embedding provider is attached with
    /// [`QueryProcessor::with_embedding_provider`].
    pub fn new(config: QueryProcessorConfig) -> Self {
        Self {
            rewriter: QueryRewriter::new(QueryRewriteConfig::default()),
            expander: QueryExpander::new(ExpansionConfig::default()),
            classifier: QueryClassifier::new(),
            decomposer: QueryDecomposer::new(),
            hyde: None,
            config,
        }
    }

    /// Attaches an embedding provider and builds the HyDE generator from it.
    ///
    /// The generator is only created when `enable_hyde` is set in the config.
    /// Otherwise the provider is dropped and the processor is returned unchanged.
    pub fn with_embedding_provider(
        mut self,
        embedding_provider: Arc<dyn EmbeddingProvider>,
    ) -> Self {
        if self.config.enable_hyde {
            self.hyde = Some(HyDEGenerator::new(
                HyDEConfig::default(),
                embedding_provider,
            ));
        }
        self
    }

    /// Runs every enabled technique on `query` and combines their output.
    ///
    /// Techniques run sequentially in this order: classification, rewriting,
    /// expansion, decomposition, HyDE. Their raw results are returned alongside the
    /// final query list built by `generate_final_queries`.
    ///
    /// # Errors
    ///
    /// Propagates the first error returned by an enabled technique. Techniques after
    /// it are not run.
    pub async fn process_query(&self, query: &str) -> RragResult<QueryProcessingResult> {
        let start_time = std::time::Instant::now();
        let mut techniques_applied = Vec::new();
        let mut confidence_scores = std::collections::HashMap::new();
        let mut warnings = Vec::new();

        let classification = if self.config.enable_classification {
            techniques_applied.push("classification".to_string());
            let result = self.classifier.classify(query).await?;
            confidence_scores.insert("classification".to_string(), result.confidence);
            Some(result)
        } else {
            None
        };

        let rewritten_queries = if self.config.enable_rewriting {
            techniques_applied.push("rewriting".to_string());
            let results = self.rewriter.rewrite(query).await?;
            if results.is_empty() {
                warnings.push("Query rewriting produced no results".to_string());
            }
            results
        } else {
            Vec::new()
        };

        let expanded_queries = if self.config.enable_expansion {
            techniques_applied.push("expansion".to_string());
            let results = self.expander.expand(query).await?;
            // Report the best expansion's confidence; 0.0 when nothing was produced.
            confidence_scores.insert(
                "expansion".to_string(),
                results.iter().map(|r| r.confidence).fold(0.0, f32::max),
            );
            results
        } else {
            Vec::new()
        };

        let sub_queries = if self.config.enable_decomposition {
            techniques_applied.push("decomposition".to_string());
            self.decomposer.decompose(query).await?
        } else {
            Vec::new()
        };

        let hyde_results = match (self.config.enable_hyde, &self.hyde) {
            (true, Some(generator)) => {
                techniques_applied.push("hyde".to_string());
                let results = generator.generate(query).await?;
                confidence_scores.insert(
                    "hyde".to_string(),
                    results.iter().map(|r| r.confidence).fold(0.0, f32::max),
                );
                results
            }
            (true, None) => {
                // HyDE was requested but no embedding provider was attached, so it
                // cannot run; tell the caller instead of skipping it silently.
                warnings
                    .push("HyDE is enabled but no embedding provider was configured".to_string());
                Vec::new()
            }
            (false, _) => Vec::new(),
        };

        let final_queries = self.generate_final_queries(
            query,
            &rewritten_queries,
            &expanded_queries,
            &sub_queries,
            &hyde_results,
            &classification,
        );

        let processing_time = start_time.elapsed().as_millis() as u64;

        Ok(QueryProcessingResult {
            original_query: query.to_string(),
            rewritten_queries,
            expanded_queries,
            classification,
            sub_queries,
            hyde_results,
            final_queries,
            metadata: QueryProcessingMetadata {
                processing_time_ms: processing_time,
                techniques_applied,
                confidence_scores,
                warnings,
            },
        })
    }

    /// Merges technique outputs into a prioritized, deduplicated query list.
    ///
    /// Candidates are collected in priority order:
    /// 1. the original query;
    /// 2. rewrites at or above `confidence_threshold`;
    /// 3. expansions selected by intent: synonym expansions for factual queries,
    ///    semantic expansions for conceptual queries, otherwise those at or above
    ///    the threshold;
    /// 4. all sub-queries;
    /// 5. HyDE answers at or above the threshold.
    ///
    /// Duplicates are removed keeping the first (highest-priority) occurrence. The
    /// list is then cut to `max_variants`, so the lowest-priority candidates are
    /// dropped first.
    fn generate_final_queries(
        &self,
        original: &str,
        rewritten: &[RewriteResult],
        expanded: &[ExpansionResult],
        sub_queries: &[SubQuery],
        hyde: &[HyDEResult],
        classification: &Option<ClassificationResult>,
    ) -> Vec<String> {
        let threshold = self.config.confidence_threshold;

        // Decides whether an expansion is kept, based on the classified intent.
        let keep_expansion = |e: &ExpansionResult| -> bool {
            match classification {
                Some(c) => match c.intent {
                    QueryIntent::Factual => e.expansion_type == ExpansionStrategy::Synonyms,
                    QueryIntent::Conceptual => e.expansion_type == ExpansionStrategy::Semantic,
                    _ => e.confidence >= threshold,
                },
                None => e.confidence >= threshold,
            }
        };

        let mut queries = vec![original.to_string()];
        queries.extend(
            rewritten
                .iter()
                .filter(|r| r.confidence >= threshold)
                .map(|r| r.rewritten_query.clone()),
        );
        queries.extend(
            expanded
                .iter()
                .filter(|&e| keep_expansion(e))
                .map(|e| e.expanded_query.clone()),
        );
        queries.extend(sub_queries.iter().map(|sq| sq.query.clone()));
        queries.extend(
            hyde.iter()
                .filter(|h| h.confidence >= threshold)
                .map(|h| h.hypothetical_answer.clone()),
        );

        // Order-preserving dedup: do not sort first, because sorting would make
        // truncation drop variants alphabetically (possibly the original query)
        // rather than by priority.
        let mut seen = std::collections::HashSet::new();
        queries.retain(|q| seen.insert(q.clone()));
        queries.truncate(self.config.max_variants);
        queries
    }
}