use crate::query_planning::QueryStrategy;
use crate::Vector;
use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use tracing::debug;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryRewriterConfig {
pub enable_expansion: bool,
pub enable_reduction: bool,
pub enable_parameter_tuning: bool,
pub enable_caching: bool,
pub max_expansion_factor: f32,
pub min_confidence: f32,
pub enable_learning: bool,
}
impl Default for QueryRewriterConfig {
fn default() -> Self {
Self {
enable_expansion: true,
enable_reduction: true,
enable_parameter_tuning: true,
enable_caching: true,
max_expansion_factor: 2.0,
min_confidence: 0.7,
enable_learning: true,
}
}
}
/// Identifies one transformation or recommendation the query rewriter can make.
///
/// Variant order is preserved deliberately: index-based serde formats encode
/// variants by position.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum RewriteRule {
    /// `k` was increased (sparse, low-variance queries).
    ExpandK,
    /// `k` was decreased (dense queries with norm above 1).
    ReduceK,
    /// Umbrella entry for `k` tuning; tracked in rule statistics rather than
    /// pushed into a query's applied-rule list.
    TuneParameters,
    /// An index/search strategy was suggested for the query.
    SuggestIndex,
    /// The query vector was rescaled to unit L2 norm.
    NormalizeQuery,
    /// Flagged for high-variance queries. NOTE(review): the rewriter in this file
    /// does not alter vector values for this rule — confirm intended.
    RemoveOutliers,
    /// Not produced by the rewriter in this file; presumably reserved — confirm.
    BoostDimensions,
    /// Not produced by the rewriter in this file; presumably reserved — confirm.
    ApplyFilters,
}
/// Outcome of rewriting a single nearest-neighbour query.
#[derive(Debug, Clone)]
pub struct RewrittenQuery {
    /// The query exactly as the caller supplied it.
    pub original_vector: Vector,
    /// The query to actually search with; identical to `original_vector` when the
    /// rewrite was rejected or no vector transformation applied.
    pub rewritten_vector: Vector,
    /// The `k` the caller asked for.
    pub original_k: usize,
    /// The `k` the rewriter recommends searching with.
    pub optimized_k: usize,
    /// Rules that were applied, in the order they were applied.
    pub applied_rules: Vec<RewriteRule>,
    /// Recommended search strategy, if one was chosen.
    pub suggested_strategy: Option<QueryStrategy>,
    /// Free-form extra parameters (not populated by the rewriter in this file).
    pub parameters: HashMap<String, String>,
    /// Heuristic confidence in the rewrite; the rewriter clamps it to `[0, 1]`.
    pub confidence: f32,
    /// Heuristic improvement estimate, derived from how much `k` changed.
    pub estimated_improvement: f32,
}
/// Summary statistics of a query vector; every field is zero for an empty vector.
#[derive(Debug, Clone, Default)]
pub struct QueryVectorStatistics {
    /// Number of components.
    pub dimensions: usize,
    /// Euclidean (L2) norm.
    pub norm: f32,
    /// Fraction of components whose magnitude is below `1e-6`.
    pub sparsity: f32,
    /// Population standard deviation of the components.
    pub std_dev: f32,
    /// Arithmetic mean of the components.
    pub mean: f32,
    /// Largest component.
    pub max_value: f32,
    /// Smallest component.
    pub min_value: f32,
}
impl QueryVectorStatistics {
    /// Computes summary statistics over the vector's `f32` components.
    ///
    /// An empty vector yields `QueryVectorStatistics::default()` (all zeros).
    /// `std_dev` is the population standard deviation (divides by `n`), and
    /// `sparsity` is the share of components with magnitude below `1e-6`.
    pub fn from_vector(vector: &Vector) -> Self {
        let values = vector.as_f32();
        if values.is_empty() {
            return Self::default();
        }

        let count = values.len() as f32;
        let mean = values.iter().sum::<f32>() / count;
        let squared_deviations: f32 = values.iter().map(|x| (x - mean).powi(2)).sum();
        let std_dev = (squared_deviations / count).sqrt();
        let norm = values.iter().map(|x| x * x).sum::<f32>().sqrt();

        // One pass for both extremes; same comparison order as two separate folds.
        let (min_value, max_value) = values.iter().fold(
            (f32::INFINITY, f32::NEG_INFINITY),
            |(lo, hi), &x| (lo.min(x), hi.max(x)),
        );

        let zero_like = values.iter().filter(|x| x.abs() < 1e-6).count();

        Self {
            dimensions: values.len(),
            norm,
            sparsity: zero_like as f32 / count,
            std_dev,
            mean,
            max_value,
            min_value,
        }
    }
}
/// Heuristic rewriter for nearest-neighbour queries, with optional result caching
/// and per-rule learning statistics.
pub struct QueryRewriter {
    /// Feature switches and thresholds.
    config: QueryRewriterConfig,
    /// Per-rule application/success counters consulted when scoring confidence.
    rule_stats: HashMap<RewriteRule, RuleStatistics>,
    /// Previously computed rewrites, keyed by a string derived from the query and `k`.
    /// NOTE(review): no eviction in this file — grows until `clear_cache` is called.
    query_cache: HashMap<String, RewrittenQuery>,
}
/// Learning counters kept for a single `RewriteRule`.
#[derive(Debug, Clone, Default)]
pub struct RuleStatistics {
    /// How many times the rewriter recorded this rule as applied.
    pub times_applied: usize,
    /// How many successes callers reported for this rule.
    pub times_successful: usize,
    /// Running mean of the improvements callers reported with each success.
    pub avg_improvement: f64,
}
impl QueryRewriter {
    /// Largest `k` the `ExpandK` rule will produce. Expansion never lowers the
    /// caller's `k`: if `k` already reaches this cap, the rule is skipped.
    const MAX_EXPANDED_K: usize = 1000;

    /// Multiplier applied to `k` by the `ReduceK` rule (result floored at 1).
    const REDUCE_K_FACTOR: f32 = 0.8;

    /// Creates a rewriter with the given configuration, no rule statistics and an empty cache.
    pub fn new(config: QueryRewriterConfig) -> Self {
        Self {
            config,
            rule_stats: HashMap::new(),
            query_cache: HashMap::new(),
        }
    }

    /// Rewrites a nearest-neighbour query (`query`, `k`) using the enabled heuristics.
    ///
    /// Steps, in order:
    /// 1. optional `k` tuning;
    /// 2. optional expansion and reduction analysis;
    /// 3. index-strategy suggestion;
    /// 4. L2 normalization when the norm is not within 0.1 of 1.
    ///
    /// Confidence combines three terms:
    /// - a per-rule penalty;
    /// - a penalty proportional to the relative change in `k`;
    /// - learned per-rule success rates from *earlier* queries.
    ///
    /// If confidence is below `min_confidence`, the original vector and `k` are
    /// returned with no applied rules and no estimated improvement.
    ///
    /// Only kept rewrites are counted in rule statistics. Callers can therefore report
    /// outcomes, via `record_rule_success`, for exactly the rules in `applied_rules`.
    ///
    /// # Errors
    /// Propagates errors from the individual steps; none currently fail.
    pub fn rewrite(&mut self, query: &Vector, k: usize) -> Result<RewrittenQuery> {
        let cache_key = if self.config.enable_caching {
            Some(self.cache_key(query, k))
        } else {
            None
        };
        if let Some(key) = &cache_key {
            if let Some(cached) = self.query_cache.get(key) {
                // The key is a hash; only trust an entry that was built for exactly this input,
                // otherwise a collision would hand back another query's rewritten vector.
                if cached.original_k == k && cached.original_vector.as_f32() == query.as_f32() {
                    debug!("Query cache hit");
                    return Ok(cached.clone());
                }
            }
        }

        let stats = QueryVectorStatistics::from_vector(query);
        debug!(
            "Query stats: dim={}, norm={:.2}, sparsity={:.2}, std_dev={:.2}",
            stats.dimensions, stats.norm, stats.sparsity, stats.std_dev
        );

        let mut rewritten = RewrittenQuery {
            original_vector: query.clone(),
            rewritten_vector: query.clone(),
            original_k: k,
            optimized_k: k,
            applied_rules: Vec::new(),
            suggested_strategy: None,
            parameters: HashMap::new(),
            confidence: 1.0,
            estimated_improvement: 0.0,
        };

        if self.config.enable_parameter_tuning {
            self.tune_k(&mut rewritten, &stats)?;
        }
        if self.config.enable_expansion {
            self.apply_expansion(&mut rewritten, &stats)?;
        }
        if self.config.enable_reduction {
            self.apply_reduction(&mut rewritten, &stats)?;
        }
        self.suggest_strategy(&mut rewritten, &stats)?;
        if self.should_normalize(&stats) {
            self.normalize_query(&mut rewritten)?;
        }

        // Score against history from earlier queries only. Recording this call's applications
        // first would count them as failures before any outcome could be reported.
        rewritten.confidence = self.calculate_confidence(&rewritten);
        if rewritten.confidence < self.config.min_confidence {
            debug!(
                "Rewrite confidence too low ({:.2}), keeping original query",
                rewritten.confidence
            );
            rewritten.rewritten_vector = query.clone();
            rewritten.optimized_k = k;
            rewritten.applied_rules.clear();
            // Nothing was changed, so no improvement should be claimed.
            rewritten.estimated_improvement = 0.0;
        } else {
            self.record_applied_rules(&rewritten);
        }

        if let Some(key) = cache_key {
            self.query_cache.insert(key, rewritten.clone());
        }
        Ok(rewritten)
    }

    /// Adjusts `optimized_k` from query statistics.
    ///
    /// - Dense queries (sparsity < 0.1, norm > 1) reduce `k` by `REDUCE_K_FACTOR`.
    /// - Sparse, flat queries (sparsity > 0.5, std_dev < 0.1) expand `k` by
    ///   `max_expansion_factor`, capped at `MAX_EXPANDED_K`.
    ///
    /// A rule is recorded only when it actually moves `k` in its direction.
    fn tune_k(&self, rewritten: &mut RewrittenQuery, stats: &QueryVectorStatistics) -> Result<()> {
        let original_k = rewritten.optimized_k;
        if original_k == 0 {
            // Nothing to scale, and the relative-change math below would divide by zero.
            return Ok(());
        }
        let mut new_k = original_k;

        if stats.sparsity < 0.1 && stats.norm > 1.0 {
            let reduced = ((original_k as f32 * Self::REDUCE_K_FACTOR) as usize).max(1);
            if reduced < original_k {
                debug!(
                    "Reducing k from {} to {} (high-selectivity query)",
                    original_k, reduced
                );
                new_k = reduced;
                rewritten.applied_rules.push(RewriteRule::ReduceK);
            }
        }

        if stats.sparsity > 0.5 && stats.std_dev < 0.1 {
            let expanded = ((original_k as f32 * self.config.max_expansion_factor) as usize)
                .min(Self::MAX_EXPANDED_K);
            // Guards against "expanding" to a smaller k (k above the cap, or factor <= 1).
            if expanded > original_k {
                debug!(
                    "Expanding k from {} to {} (low-selectivity query)",
                    original_k, expanded
                );
                new_k = expanded;
                rewritten.applied_rules.push(RewriteRule::ExpandK);
            }
        }

        rewritten.optimized_k = new_k;
        rewritten.estimated_improvement += Self::relative_k_change(original_k, new_k) * 10.0;
        Ok(())
    }

    /// Expansion analysis. It currently only logs when the query is sparse; it
    /// never modifies the query.
    fn apply_expansion(
        &self,
        _rewritten: &mut RewrittenQuery,
        stats: &QueryVectorStatistics,
    ) -> Result<()> {
        if stats.sparsity > 0.6 {
            debug!("Query is sparse, expansion could be beneficial");
        }
        Ok(())
    }

    /// Flags high-variance queries (std_dev > 2) with `RemoveOutliers`.
    ///
    /// NOTE(review): no component values are changed here, although the rule name
    /// suggests they would be — confirm intended.
    fn apply_reduction(
        &self,
        rewritten: &mut RewrittenQuery,
        stats: &QueryVectorStatistics,
    ) -> Result<()> {
        if stats.std_dev > 2.0 {
            debug!("High variance detected, considering outlier removal");
            rewritten.applied_rules.push(RewriteRule::RemoveOutliers);
        }
        Ok(())
    }

    /// Picks a search strategy from sparsity, dimensionality and norm, in that priority order.
    fn suggest_strategy(
        &self,
        rewritten: &mut RewrittenQuery,
        stats: &QueryVectorStatistics,
    ) -> Result<()> {
        let strategy = if stats.sparsity > 0.7 {
            QueryStrategy::LocalitySensitiveHashing
        } else if stats.dimensions > 512 {
            QueryStrategy::ProductQuantization
        } else if stats.norm > 10.0 {
            QueryStrategy::NsgApproximate
        } else {
            QueryStrategy::HnswApproximate
        };
        rewritten.suggested_strategy = Some(strategy);
        rewritten.applied_rules.push(RewriteRule::SuggestIndex);
        Ok(())
    }

    /// True when the L2 norm differs from 1 by more than 0.1.
    fn should_normalize(&self, stats: &QueryVectorStatistics) -> bool {
        (stats.norm - 1.0).abs() > 0.1
    }

    /// Rescales the rewritten vector to unit L2 norm. Near-zero vectors are left untouched.
    fn normalize_query(&self, rewritten: &mut RewrittenQuery) -> Result<()> {
        let values = rewritten.rewritten_vector.as_f32();
        let norm: f32 = values.iter().map(|v| v * v).sum::<f32>().sqrt();
        if norm > 1e-6 {
            let normalized: Vec<f32> = values.iter().map(|v| v / norm).collect();
            rewritten.rewritten_vector = Vector::new(normalized);
            rewritten.applied_rules.push(RewriteRule::NormalizeQuery);
            debug!("Query normalized (original norm: {:.2})", norm);
        }
        Ok(())
    }

    /// `|new_k - original_k| / original_k`, with the denominator floored at 1 so `k == 0`
    /// cannot produce NaN/inf.
    fn relative_k_change(original_k: usize, new_k: usize) -> f32 {
        (new_k as f32 - original_k as f32).abs() / original_k.max(1) as f32
    }

    /// Scores a rewrite in `[0, 1]`.
    ///
    /// - Subtracts 0.05 per applied rule.
    /// - Subtracts 0.2 times the relative change in `k`.
    /// - Multiplies by each applied rule's historical success rate, where one exists.
    ///   Applications without a reported success count against the rule.
    fn calculate_confidence(&self, rewritten: &RewrittenQuery) -> f32 {
        let mut confidence = 1.0_f32 - rewritten.applied_rules.len() as f32 * 0.05;
        confidence -= Self::relative_k_change(rewritten.original_k, rewritten.optimized_k) * 0.2;
        for rule in &rewritten.applied_rules {
            if let Some(stats) = self.rule_stats.get(rule) {
                if stats.times_applied > 0 {
                    // Callers may report more successes than recorded applications; a rate
                    // above 1 would otherwise cancel out penalties from other rules.
                    let success_rate = (stats.times_successful as f32
                        / stats.times_applied as f32)
                        .min(1.0);
                    confidence *= success_rate;
                }
            }
        }
        confidence.clamp(0.0, 1.0)
    }

    /// Counts one application for every rule kept in `rewritten`.
    /// `TuneParameters` is also counted whenever `k` was changed.
    fn record_applied_rules(&mut self, rewritten: &RewrittenQuery) {
        if rewritten.optimized_k != rewritten.original_k {
            self.record_rule_application(RewriteRule::TuneParameters);
        }
        for &rule in &rewritten.applied_rules {
            self.record_rule_application(rule);
        }
    }

    /// Increments `times_applied` for `rule` (no-op when learning is disabled).
    fn record_rule_application(&mut self, rule: RewriteRule) {
        if !self.config.enable_learning {
            return;
        }
        self.rule_stats.entry(rule).or_default().times_applied += 1;
    }

    /// Reports that `rule` helped, folding `improvement` into the rule's running mean.
    /// No-op when learning is disabled.
    pub fn record_rule_success(&mut self, rule: RewriteRule, improvement: f64) {
        if !self.config.enable_learning {
            return;
        }
        let stats = self.rule_stats.entry(rule).or_default();
        stats.times_successful += 1;
        stats.avg_improvement = (stats.avg_improvement * (stats.times_successful - 1) as f64
            + improvement)
            / stats.times_successful as f64;
    }

    /// Builds the cache key from the exact bit patterns of the components, the
    /// dimensionality and `k`.
    ///
    /// Hashing is lossy, so `rewrite` re-checks the cached input before trusting a hit.
    fn cache_key(&self, query: &Vector, k: usize) -> String {
        use std::hash::{Hash, Hasher};

        let values = query.as_f32();
        let mut hasher = std::collections::hash_map::DefaultHasher::new();
        // Length is hashed so that e.g. [0.0, x] and [x] cannot collide.
        values.len().hash(&mut hasher);
        for v in values.iter() {
            v.to_bits().hash(&mut hasher);
        }
        format!("{:016x}_{}_{}", hasher.finish(), values.len(), k)
    }

    /// Drops every cached rewrite.
    pub fn clear_cache(&mut self) {
        self.query_cache.clear();
    }

    /// Per-rule learning statistics gathered so far.
    pub fn rule_statistics(&self) -> &HashMap<RewriteRule, RuleStatistics> {
        &self.rule_stats
    }

    /// Number of cached rewrites.
    pub fn cache_size(&self) -> usize {
        self.query_cache.len()
    }
}
impl Default for QueryRewriter {
fn default() -> Self {
Self::new(QueryRewriterConfig::default())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts two floats agree to within 1e-4.
    fn assert_close(actual: f32, expected: f32) {
        assert!(
            (actual - expected).abs() < 1e-4,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    #[test]
    fn test_query_statistics() {
        let vector = Vector::new(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let stats = QueryVectorStatistics::from_vector(&vector);
        assert_eq!(stats.dimensions, 5);
        assert_close(stats.mean, 3.0);
        // Population variance of 1..=5 is 2.
        assert_close(stats.std_dev, 2.0_f32.sqrt());
        assert_close(stats.norm, 55.0_f32.sqrt());
        assert_close(stats.sparsity, 0.0);
        assert_close(stats.max_value, 5.0);
        assert_close(stats.min_value, 1.0);
    }

    #[test]
    fn test_query_rewriter_creation() {
        let config = QueryRewriterConfig::default();
        let rewriter = QueryRewriter::new(config);
        assert_eq!(rewriter.cache_size(), 0);
        assert!(rewriter.rule_statistics().is_empty());
    }

    #[test]
    fn test_query_rewrite() -> Result<()> {
        let config = QueryRewriterConfig {
            min_confidence: 0.5,
            ..Default::default()
        };
        let mut rewriter = QueryRewriter::new(config);
        let query = Vector::new(vec![1.0, 2.0, 3.0, 4.0]);
        let result = rewriter.rewrite(&query, 10)?;
        assert_eq!(result.original_k, 10);
        assert!((0.0..=1.0).contains(&result.confidence));
        // A dense query is never expanded.
        assert!(result.optimized_k <= result.original_k);
        Ok(())
    }

    #[test]
    fn test_normalize_query() -> Result<()> {
        let config = QueryRewriterConfig {
            min_confidence: 0.5,
            ..Default::default()
        };
        let mut rewriter = QueryRewriter::new(config);
        let query = Vector::new(vec![3.0, 4.0]);
        let result = rewriter.rewrite(&query, 10)?;
        let normalized_values = result.rewritten_vector.as_f32();
        let norm: f32 = normalized_values.iter().map(|v| v * v).sum::<f32>().sqrt();
        if result.applied_rules.contains(&RewriteRule::NormalizeQuery) {
            assert!(
                (norm - 1.0).abs() < 0.01,
                "Expected norm close to 1.0, got {}",
                norm
            );
        } else {
            assert!(
                (norm - 5.0).abs() < 0.01,
                "Expected original norm ~5.0, got {}",
                norm
            );
        }
        Ok(())
    }

    #[test]
    fn test_k_tuning_sparse_query() -> Result<()> {
        let config = QueryRewriterConfig {
            enable_parameter_tuning: true,
            ..Default::default()
        };
        let mut rewriter = QueryRewriter::new(config);
        // Sparse but not flat (std_dev ~0.3), so k must not be reduced.
        let query = Vector::new(vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0]);
        let result = rewriter.rewrite(&query, 10)?;
        assert!(result.optimized_k >= result.original_k);
        Ok(())
    }

    #[test]
    fn test_k_expansion_sparse_low_variance_query() -> Result<()> {
        let config = QueryRewriterConfig {
            min_confidence: 0.5,
            ..Default::default()
        };
        let mut rewriter = QueryRewriter::new(config);
        // One unit component among 200: sparsity 0.995, std_dev ~0.07, norm exactly 1,
        // so only the expansion branch of k tuning fires (no reduction, no normalization).
        let mut values = vec![0.0_f32; 200];
        values[0] = 1.0;
        let result = rewriter.rewrite(&Vector::new(values), 10)?;
        assert_eq!(result.optimized_k, 20);
        assert!(result.applied_rules.contains(&RewriteRule::ExpandK));
        assert!(matches!(
            result.suggested_strategy,
            Some(QueryStrategy::LocalitySensitiveHashing)
        ));
        Ok(())
    }

    #[test]
    fn test_caching() -> Result<()> {
        let config = QueryRewriterConfig {
            enable_caching: true,
            ..Default::default()
        };
        let mut rewriter = QueryRewriter::new(config);
        let query = Vector::new(vec![1.0, 2.0, 3.0]);
        let result1 = rewriter.rewrite(&query, 10)?;
        assert_eq!(rewriter.cache_size(), 1);
        let result2 = rewriter.rewrite(&query, 10)?;
        assert_eq!(rewriter.cache_size(), 1);
        assert_eq!(result1.optimized_k, result2.optimized_k);
        let _result3 = rewriter.rewrite(&query, 20)?;
        assert_eq!(rewriter.cache_size(), 2);
        Ok(())
    }

    #[test]
    fn test_rule_learning() -> Result<()> {
        let config = QueryRewriterConfig {
            enable_learning: true,
            ..Default::default()
        };
        let mut rewriter = QueryRewriter::new(config);
        let query = Vector::new(vec![1.0, 2.0, 3.0]);
        rewriter.rewrite(&query, 10)?;
        rewriter.record_rule_success(RewriteRule::NormalizeQuery, 0.15);
        let stats = rewriter.rule_statistics();
        let normalize = stats
            .get(&RewriteRule::NormalizeQuery)
            .expect("success was just recorded for NormalizeQuery");
        assert_eq!(normalize.times_successful, 1);
        assert!((normalize.avg_improvement - 0.15).abs() < 1e-12);
        Ok(())
    }
}