1use crate::query_planning::QueryStrategy;
35use crate::Vector;
36use anyhow::Result;
37use serde::{Deserialize, Serialize};
38use std::collections::HashMap;
39use tracing::debug;
40
41#[derive(Debug, Clone, Serialize, Deserialize)]
43pub struct QueryRewriterConfig {
44 pub enable_expansion: bool,
46 pub enable_reduction: bool,
48 pub enable_parameter_tuning: bool,
50 pub enable_caching: bool,
52 pub max_expansion_factor: f32,
54 pub min_confidence: f32,
56 pub enable_learning: bool,
58}
59
60impl Default for QueryRewriterConfig {
61 fn default() -> Self {
62 Self {
63 enable_expansion: true,
64 enable_reduction: true,
65 enable_parameter_tuning: true,
66 enable_caching: true,
67 max_expansion_factor: 2.0,
68 min_confidence: 0.7,
69 enable_learning: true,
70 }
71 }
72}
73
/// Identifies a rewrite transformation; used both in
/// `RewrittenQuery::applied_rules` and as the key of the learning statistics.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum RewriteRule {
    /// k was increased for a query that looks weakly selective.
    ExpandK,
    /// k was decreased for a query that looks highly selective.
    ReduceK,
    /// Parameter tuning in general.
    /// NOTE(review): never pushed into `applied_rules` by `QueryRewriter`.
    TuneParameters,
    /// A search strategy was suggested (`RewrittenQuery::suggested_strategy`).
    SuggestIndex,
    /// The query vector was L2-normalized.
    NormalizeQuery,
    /// High variance (std_dev > 2.0) was detected.
    /// NOTE(review): the vector itself is not modified by this rule yet.
    RemoveOutliers,
    /// NOTE(review): not produced by any rewrite pass visible here.
    BoostDimensions,
    /// NOTE(review): not produced by any rewrite pass visible here.
    ApplyFilters,
}
94
/// Result of `QueryRewriter::rewrite`: the original query alongside the
/// rewritten one and bookkeeping about how it was produced.
#[derive(Debug, Clone)]
pub struct RewrittenQuery {
    /// The query exactly as passed in.
    pub original_vector: Vector,
    /// The query to actually search with (possibly normalized).
    pub rewritten_vector: Vector,
    /// The k requested by the caller.
    pub original_k: usize,
    /// The k to actually search with (possibly tuned).
    pub optimized_k: usize,
    /// Rules that contributed to this rewrite, in application order.
    pub applied_rules: Vec<RewriteRule>,
    /// Suggested search strategy, if any pass proposed one.
    pub suggested_strategy: Option<QueryStrategy>,
    /// Extra string parameters for the search.
    /// NOTE(review): no pass in this file populates it yet.
    pub parameters: HashMap<String, String>,
    /// Confidence in the rewrite; `calculate_confidence` clamps it to [0, 1].
    pub confidence: f32,
    /// Heuristic improvement score accumulated by the passes
    /// (k tuning adds |Δk| / k × 10).
    pub estimated_improvement: f32,
}
117
/// Summary statistics of a query vector, used to drive the rewrite heuristics.
/// An empty vector yields the all-zero `Default`.
#[derive(Debug, Clone, Default)]
pub struct QueryVectorStatistics {
    /// Number of components.
    pub dimensions: usize,
    /// Euclidean (L2) norm.
    pub norm: f32,
    /// Fraction of components whose absolute value is below 1e-6.
    pub sparsity: f32,
    /// Population standard deviation of the components.
    pub std_dev: f32,
    /// Arithmetic mean of the components.
    pub mean: f32,
    /// Largest component.
    pub max_value: f32,
    /// Smallest component.
    pub min_value: f32,
}
136
137impl QueryVectorStatistics {
138 pub fn from_vector(vector: &Vector) -> Self {
140 let values = vector.as_f32();
141 let n = values.len() as f32;
142
143 if values.is_empty() {
144 return Self::default();
145 }
146
147 let sum: f32 = values.iter().sum();
148 let mean = sum / n;
149
150 let variance: f32 = values.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / n;
151 let std_dev = variance.sqrt();
152
153 let norm: f32 = values.iter().map(|v| v * v).sum::<f32>().sqrt();
154
155 let max_value = values.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
156 let min_value = values.iter().fold(f32::INFINITY, |a, &b| a.min(b));
157
158 let threshold = 1e-6;
160 let near_zero_count = values.iter().filter(|&&v| v.abs() < threshold).count();
161 let sparsity = near_zero_count as f32 / n;
162
163 Self {
164 dimensions: values.len(),
165 norm,
166 sparsity,
167 std_dev,
168 mean,
169 max_value,
170 min_value,
171 }
172 }
173}
174
/// Heuristic query rewriter: tunes k, normalizes the vector, suggests a search
/// strategy and optionally learns per-rule success rates.
pub struct QueryRewriter {
    /// Behaviour switches and thresholds.
    config: QueryRewriterConfig,
    /// Per-rule application/success counters used when learning is enabled.
    rule_stats: HashMap<RewriteRule, RuleStatistics>,
    /// Memoized rewrite results keyed by `cache_key`; entries are never evicted
    /// except via `clear_cache`.
    query_cache: HashMap<String, RewrittenQuery>,
}
181
/// Learning counters for a single `RewriteRule`.
#[derive(Debug, Clone, Default)]
pub struct RuleStatistics {
    /// How many times the rule was recorded as applied.
    pub times_applied: usize,
    /// How many times `record_rule_success` reported the rule as successful.
    pub times_successful: usize,
    /// Running mean of the `improvement` values passed to `record_rule_success`.
    pub avg_improvement: f64,
}
189
190impl QueryRewriter {
191 pub fn new(config: QueryRewriterConfig) -> Self {
193 Self {
194 config,
195 rule_stats: HashMap::new(),
196 query_cache: HashMap::new(),
197 }
198 }
199
200 pub fn rewrite(&mut self, query: &Vector, k: usize) -> Result<RewrittenQuery> {
202 let cache_key = self.cache_key(query, k);
204 if self.config.enable_caching {
205 if let Some(cached) = self.query_cache.get(&cache_key) {
206 debug!("Query cache hit");
207 return Ok(cached.clone());
208 }
209 }
210
211 let stats = QueryVectorStatistics::from_vector(query);
213 debug!(
214 "Query stats: dim={}, norm={:.2}, sparsity={:.2}, std_dev={:.2}",
215 stats.dimensions, stats.norm, stats.sparsity, stats.std_dev
216 );
217
218 let mut rewritten = RewrittenQuery {
220 original_vector: query.clone(),
221 rewritten_vector: query.clone(),
222 original_k: k,
223 optimized_k: k,
224 applied_rules: Vec::new(),
225 suggested_strategy: None,
226 parameters: HashMap::new(),
227 confidence: 1.0,
228 estimated_improvement: 0.0,
229 };
230
231 if self.config.enable_parameter_tuning {
233 self.tune_k(&mut rewritten, &stats)?;
234 }
235
236 if self.config.enable_expansion {
237 self.apply_expansion(&mut rewritten, &stats)?;
238 }
239
240 if self.config.enable_reduction {
241 self.apply_reduction(&mut rewritten, &stats)?;
242 }
243
244 self.suggest_strategy(&mut rewritten, &stats)?;
246
247 if self.should_normalize(&stats) {
249 self.normalize_query(&mut rewritten)?;
250 }
251
252 rewritten.confidence = self.calculate_confidence(&rewritten);
254
255 if rewritten.confidence < self.config.min_confidence {
257 debug!(
258 "Rewrite confidence too low ({:.2}), keeping original query",
259 rewritten.confidence
260 );
261 rewritten.rewritten_vector = query.clone();
262 rewritten.optimized_k = k;
263 rewritten.applied_rules.clear();
264 }
265
266 if self.config.enable_caching {
268 self.query_cache.insert(cache_key, rewritten.clone());
269 }
270
271 Ok(rewritten)
272 }
273
274 fn tune_k(
276 &mut self,
277 rewritten: &mut RewrittenQuery,
278 stats: &QueryVectorStatistics,
279 ) -> Result<()> {
280 let original_k = rewritten.optimized_k;
281 let mut new_k = original_k;
282
283 if stats.sparsity < 0.1 && stats.norm > 1.0 {
285 new_k = (original_k as f32 * 0.8) as usize;
286 new_k = new_k.max(1);
287 debug!(
288 "Reducing k from {} to {} (high-selectivity query)",
289 original_k, new_k
290 );
291 rewritten.applied_rules.push(RewriteRule::ReduceK);
292 }
293
294 if stats.sparsity > 0.5 && stats.std_dev < 0.1 {
296 new_k = (original_k as f32 * self.config.max_expansion_factor) as usize;
297 new_k = new_k.min(1000); debug!(
299 "Expanding k from {} to {} (low-selectivity query)",
300 original_k, new_k
301 );
302 rewritten.applied_rules.push(RewriteRule::ExpandK);
303 }
304
305 rewritten.optimized_k = new_k;
306 rewritten.estimated_improvement +=
307 (new_k as f32 - original_k as f32).abs() / original_k as f32 * 10.0;
308
309 self.record_rule_application(RewriteRule::TuneParameters);
310
311 Ok(())
312 }
313
314 fn apply_expansion(
316 &mut self,
317 _rewritten: &mut RewrittenQuery,
318 stats: &QueryVectorStatistics,
319 ) -> Result<()> {
320 if stats.sparsity > 0.6 {
323 debug!("Query is sparse, expansion could be beneficial");
324 }
325 Ok(())
326 }
327
328 fn apply_reduction(
330 &mut self,
331 rewritten: &mut RewrittenQuery,
332 stats: &QueryVectorStatistics,
333 ) -> Result<()> {
334 if stats.std_dev > 2.0 {
336 debug!("High variance detected, considering outlier removal");
337 rewritten.applied_rules.push(RewriteRule::RemoveOutliers);
338 self.record_rule_application(RewriteRule::RemoveOutliers);
339 }
340 Ok(())
341 }
342
343 fn suggest_strategy(
345 &self,
346 rewritten: &mut RewrittenQuery,
347 stats: &QueryVectorStatistics,
348 ) -> Result<()> {
349 let strategy = if stats.sparsity > 0.7 {
351 QueryStrategy::LocalitySensitiveHashing
353 } else if stats.dimensions > 512 {
354 QueryStrategy::ProductQuantization
356 } else if stats.norm > 10.0 {
357 QueryStrategy::NsgApproximate
359 } else {
360 QueryStrategy::HnswApproximate
362 };
363
364 rewritten.suggested_strategy = Some(strategy);
365 rewritten.applied_rules.push(RewriteRule::SuggestIndex);
366
367 Ok(())
368 }
369
370 fn should_normalize(&self, stats: &QueryVectorStatistics) -> bool {
372 (stats.norm - 1.0).abs() > 0.1
374 }
375
376 fn normalize_query(&mut self, rewritten: &mut RewrittenQuery) -> Result<()> {
378 let values = rewritten.rewritten_vector.as_f32();
379 let norm: f32 = values.iter().map(|v| v * v).sum::<f32>().sqrt();
380
381 if norm > 1e-6 {
382 let normalized: Vec<f32> = values.iter().map(|v| v / norm).collect();
383 rewritten.rewritten_vector = Vector::new(normalized);
384 rewritten.applied_rules.push(RewriteRule::NormalizeQuery);
385 debug!("Query normalized (original norm: {:.2})", norm);
386 self.record_rule_application(RewriteRule::NormalizeQuery);
387 }
388
389 Ok(())
390 }
391
392 fn calculate_confidence(&self, rewritten: &RewrittenQuery) -> f32 {
394 let mut confidence = 1.0;
396
397 confidence -= rewritten.applied_rules.len() as f32 * 0.05;
399
400 let k_change_ratio =
402 (rewritten.optimized_k as f32 / rewritten.original_k as f32 - 1.0).abs();
403 confidence -= k_change_ratio * 0.2;
404
405 for rule in &rewritten.applied_rules {
407 if let Some(stats) = self.rule_stats.get(rule) {
408 if stats.times_applied > 0 {
409 let success_rate = stats.times_successful as f32 / stats.times_applied as f32;
410 confidence *= success_rate;
411 }
412 }
413 }
414
415 confidence.clamp(0.0, 1.0)
416 }
417
418 fn record_rule_application(&mut self, rule: RewriteRule) {
420 if !self.config.enable_learning {
421 return;
422 }
423
424 self.rule_stats.entry(rule).or_default().times_applied += 1;
425 }
426
427 pub fn record_rule_success(&mut self, rule: RewriteRule, improvement: f64) {
429 if !self.config.enable_learning {
430 return;
431 }
432
433 let stats = self.rule_stats.entry(rule).or_default();
434
435 stats.times_successful += 1;
436 stats.avg_improvement = (stats.avg_improvement * (stats.times_successful - 1) as f64
437 + improvement)
438 / stats.times_successful as f64;
439 }
440
441 fn cache_key(&self, query: &Vector, k: usize) -> String {
443 let values = query.as_f32();
445 let hash: u64 = values
446 .iter()
447 .map(|v| (v * 1000.0) as i32)
448 .fold(0u64, |acc, v| acc.wrapping_mul(31).wrapping_add(v as u64));
449
450 format!("{:x}_{}", hash, k)
451 }
452
453 pub fn clear_cache(&mut self) {
455 self.query_cache.clear();
456 }
457
458 pub fn rule_statistics(&self) -> &HashMap<RewriteRule, RuleStatistics> {
460 &self.rule_stats
461 }
462
463 pub fn cache_size(&self) -> usize {
465 self.query_cache.len()
466 }
467}
468
469impl Default for QueryRewriter {
470 fn default() -> Self {
471 Self::new(QueryRewriterConfig::default())
472 }
473}
474
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_query_statistics() {
        let vector = Vector::new(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let stats = QueryVectorStatistics::from_vector(&vector);

        assert_eq!(stats.dimensions, 5);
        assert!(stats.norm > 0.0);
        assert!(stats.std_dev > 0.0);
        assert!((stats.mean - 3.0).abs() < 1e-6);
        assert_eq!(stats.min_value, 1.0);
        assert_eq!(stats.max_value, 5.0);
        assert_eq!(stats.sparsity, 0.0);
    }

    #[test]
    fn test_query_rewriter_creation() {
        let config = QueryRewriterConfig::default();
        let rewriter = QueryRewriter::new(config);
        assert_eq!(rewriter.cache_size(), 0);
        assert!(rewriter.rule_statistics().is_empty());
    }

    #[test]
    fn test_query_rewrite() {
        let config = QueryRewriterConfig {
            min_confidence: 0.5,
            ..Default::default()
        };
        let mut rewriter = QueryRewriter::new(config);

        let query = Vector::new(vec![1.0, 2.0, 3.0, 4.0]);
        let result = rewriter.rewrite(&query, 10).unwrap();

        assert_eq!(result.original_k, 10);
        assert!(result.confidence >= 0.0);
    }

    #[test]
    fn test_normalize_query() {
        let config = QueryRewriterConfig {
            min_confidence: 0.5,
            ..Default::default()
        };
        let mut rewriter = QueryRewriter::new(config);

        // |(3, 4)| = 5, so normalization should trigger.
        let query = Vector::new(vec![3.0, 4.0]);
        let result = rewriter.rewrite(&query, 10).unwrap();

        let normalized_values = result.rewritten_vector.as_f32();
        let norm: f32 = normalized_values.iter().map(|v| v * v).sum::<f32>().sqrt();

        if result.applied_rules.contains(&RewriteRule::NormalizeQuery) {
            assert!(
                (norm - 1.0).abs() < 0.01,
                "Expected norm close to 1.0, got {}",
                norm
            );
        } else {
            // The rewrite was rejected for low confidence: original vector kept.
            assert!(
                (norm - 5.0).abs() < 0.01,
                "Expected original norm ~5.0, got {}",
                norm
            );
        }
    }

    #[test]
    fn test_k_tuning_sparse_query() {
        let config = QueryRewriterConfig {
            enable_parameter_tuning: true,
            ..Default::default()
        };
        let mut rewriter = QueryRewriter::new(config);

        let query = Vector::new(vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0]);
        let result = rewriter.rewrite(&query, 10).unwrap();

        // A sparse query must never have its k reduced.
        assert!(result.optimized_k >= result.original_k);
    }

    #[test]
    fn test_caching() {
        let config = QueryRewriterConfig {
            enable_caching: true,
            ..Default::default()
        };
        let mut rewriter = QueryRewriter::new(config);

        let query = Vector::new(vec![1.0, 2.0, 3.0]);

        let _result1 = rewriter.rewrite(&query, 10).unwrap();
        assert_eq!(rewriter.cache_size(), 1);

        // Same (query, k): served from cache, no new entry.
        let _result2 = rewriter.rewrite(&query, 10).unwrap();
        assert_eq!(rewriter.cache_size(), 1);

        // Different k: separate entry.
        let _result3 = rewriter.rewrite(&query, 20).unwrap();
        assert_eq!(rewriter.cache_size(), 2);

        rewriter.clear_cache();
        assert_eq!(rewriter.cache_size(), 0);
    }

    #[test]
    fn test_caching_disabled() {
        let config = QueryRewriterConfig {
            enable_caching: false,
            ..Default::default()
        };
        let mut rewriter = QueryRewriter::new(config);

        rewriter.rewrite(&Vector::new(vec![1.0, 2.0, 3.0]), 10).unwrap();
        assert_eq!(rewriter.cache_size(), 0);
    }

    #[test]
    fn test_rule_learning() {
        let config = QueryRewriterConfig {
            enable_learning: true,
            ..Default::default()
        };
        let mut rewriter = QueryRewriter::new(config);

        let query = Vector::new(vec![1.0, 2.0, 3.0]);
        rewriter.rewrite(&query, 10).unwrap();

        rewriter.record_rule_success(RewriteRule::NormalizeQuery, 0.15);

        // `record_rule_success` always creates the entry, so presence alone
        // proves nothing; check the counters it maintains.
        let stats = rewriter.rule_statistics();
        let normalize = stats
            .get(&RewriteRule::NormalizeQuery)
            .expect("NormalizeQuery statistics should exist");
        assert_eq!(normalize.times_successful, 1);
        assert!((normalize.avg_improvement - 0.15).abs() < 1e-12);
    }

    #[test]
    fn test_rule_success_running_average() {
        let mut rewriter = QueryRewriter::default();

        rewriter.record_rule_success(RewriteRule::ReduceK, 0.1);
        rewriter.record_rule_success(RewriteRule::ReduceK, 0.3);

        let stats = &rewriter.rule_statistics()[&RewriteRule::ReduceK];
        assert_eq!(stats.times_successful, 2);
        assert!((stats.avg_improvement - 0.2).abs() < 1e-12);
    }

    #[test]
    fn test_learning_disabled_records_nothing() {
        let config = QueryRewriterConfig {
            enable_learning: false,
            ..Default::default()
        };
        let mut rewriter = QueryRewriter::new(config);

        rewriter.rewrite(&Vector::new(vec![1.0, 2.0, 3.0]), 10).unwrap();
        rewriter.record_rule_success(RewriteRule::NormalizeQuery, 0.15);

        assert!(rewriter.rule_statistics().is_empty());
    }
}