oxirs_vec/
query_rewriter.rs

1//! Query rewriting and optimization for vector search
2//!
3//! This module provides automatic query rewriting and optimization to improve
4//! search performance and accuracy. It transforms queries before execution using:
5//!
6//! - **Query expansion**: Add related terms/vectors
7//! - **Query reduction**: Remove redundant components
8//! - **Parameter tuning**: Optimize search parameters
9//! - **Index selection hints**: Suggest best indices to use
10//! - **Semantic optimization**: Improve semantic relevance
11//!
12//! # Features
13//!
14//! - Rule-based rewriting
15//! - Statistics-driven optimization
16//! - Query plan caching
17//! - Performance prediction
18//! - Automatic parameter tuning
19//!
20//! # Example
21//!
22//! ```rust,ignore
23//! use oxirs_vec::query_rewriter::{QueryRewriter, RewriteRule};
24//! use oxirs_vec::Vector;
25//!
//! let mut rewriter = QueryRewriter::default();
27//!
28//! let query = Vector::new(vec![1.0, 2.0, 3.0]);
29//! let rewritten = rewriter.rewrite(&query, 10).unwrap();
30//!
31//! println!("Original k: 10, Optimized k: {}", rewritten.optimized_k);
32//! ```
33
34use crate::query_planning::QueryStrategy;
35use crate::Vector;
36use anyhow::Result;
37use serde::{Deserialize, Serialize};
38use std::collections::HashMap;
39use tracing::debug;
40
/// Query rewriting configuration.
///
/// Each `enable_*` flag switches one stage of the query rewriter on or off;
/// the two numeric fields bound how aggressive a rewrite may be.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryRewriterConfig {
    /// Enable the query expansion pass.
    pub enable_expansion: bool,
    /// Enable the query reduction pass.
    pub enable_reduction: bool,
    /// Enable parameter tuning (the pass that adjusts `k`).
    pub enable_parameter_tuning: bool,
    /// Enable caching of rewrite results.
    pub enable_caching: bool,
    /// Maximum expansion factor: the multiplier applied to `k` when it is expanded.
    pub max_expansion_factor: f32,
    /// Minimum confidence for rewriting; a rewrite scoring below this is discarded
    /// and the original vector and `k` are kept.
    pub min_confidence: f32,
    /// Enable learning from query performance (per-rule statistics are only
    /// updated while this is set).
    pub enable_learning: bool,
}
59
60impl Default for QueryRewriterConfig {
61    fn default() -> Self {
62        Self {
63            enable_expansion: true,
64            enable_reduction: true,
65            enable_parameter_tuning: true,
66            enable_caching: true,
67            max_expansion_factor: 2.0,
68            min_confidence: 0.7,
69            enable_learning: true,
70        }
71    }
72}
73
/// Rewrite rule.
///
/// Identifies a single transformation the rewriter can report in
/// `RewrittenQuery::applied_rules` and track in its per-rule statistics.
/// The type is `Copy + Eq + Hash`, so it can be used directly as a map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum RewriteRule {
    /// Expand k for low-selectivity queries
    ExpandK,
    /// Reduce k for high-selectivity queries
    ReduceK,
    /// Adjust search parameters based on query characteristics
    TuneParameters,
    /// Suggest better index for query
    SuggestIndex,
    /// Normalize query vector
    NormalizeQuery,
    /// Remove outliers from query vector
    RemoveOutliers,
    /// Boost important dimensions
    // NOTE(review): no pass visible in this module produces this rule — confirm it is used elsewhere.
    BoostDimensions,
    /// Apply query-specific filters
    // NOTE(review): no pass visible in this module produces this rule — confirm it is used elsewhere.
    ApplyFilters,
}
94
/// Rewritten query.
///
/// The result of a rewrite: the caller's original vector and `k` alongside
/// the values the rewriter recommends using instead, plus metadata describing
/// what was changed and how much the rewriter trusts the change.
#[derive(Debug, Clone)]
pub struct RewrittenQuery {
    /// Original query vector, exactly as supplied by the caller
    pub original_vector: Vector,
    /// Rewritten query vector (may be same as original)
    pub rewritten_vector: Vector,
    /// Original k value
    pub original_k: usize,
    /// Optimized k value (equal to `original_k` when no tuning was applied)
    pub optimized_k: usize,
    /// Applied rewrite rules, in the order they were applied
    pub applied_rules: Vec<RewriteRule>,
    /// Suggested query strategy
    pub suggested_strategy: Option<QueryStrategy>,
    /// Optimized parameters
    // NOTE(review): the rewriter in this file only initializes this map as empty.
    pub parameters: HashMap<String, String>,
    /// Confidence in rewrite (0.0 to 1.0)
    pub confidence: f32,
    /// Estimated performance improvement (%)
    pub estimated_improvement: f32,
}
117
/// Query statistics for optimization.
///
/// Summary statistics over the `f32` components of a query vector. The
/// `Default` value (all zeros) is what `from_vector` returns for an empty
/// vector.
#[derive(Debug, Clone, Default)]
pub struct QueryVectorStatistics {
    /// Query vector dimensionality (number of components)
    pub dimensions: usize,
    /// Query vector norm (Euclidean / L2)
    pub norm: f32,
    /// Query vector sparsity (ratio of near-zero values, i.e. |v| < 1e-6)
    pub sparsity: f32,
    /// Standard deviation of components (population form: divides by n)
    pub std_dev: f32,
    /// Mean of components
    pub mean: f32,
    /// Max component value
    pub max_value: f32,
    /// Min component value
    pub min_value: f32,
}
136
137impl QueryVectorStatistics {
138    /// Compute statistics from a query vector
139    pub fn from_vector(vector: &Vector) -> Self {
140        let values = vector.as_f32();
141        let n = values.len() as f32;
142
143        if values.is_empty() {
144            return Self::default();
145        }
146
147        let sum: f32 = values.iter().sum();
148        let mean = sum / n;
149
150        let variance: f32 = values.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / n;
151        let std_dev = variance.sqrt();
152
153        let norm: f32 = values.iter().map(|v| v * v).sum::<f32>().sqrt();
154
155        let max_value = values.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
156        let min_value = values.iter().fold(f32::INFINITY, |a, &b| a.min(b));
157
158        // Sparsity: count near-zero values
159        let threshold = 1e-6;
160        let near_zero_count = values.iter().filter(|&&v| v.abs() < threshold).count();
161        let sparsity = near_zero_count as f32 / n;
162
163        Self {
164            dimensions: values.len(),
165            norm,
166            sparsity,
167            std_dev,
168            mean,
169            max_value,
170            min_value,
171        }
172    }
173}
174
/// Query rewriter.
///
/// Holds the rewriting configuration together with two pieces of mutable
/// state: per-rule statistics and a cache of previously computed rewrites.
pub struct QueryRewriter {
    // Feature switches and thresholds supplied at construction.
    config: QueryRewriterConfig,
    // Per-rule application/success counters.
    rule_stats: HashMap<RewriteRule, RuleStatistics>,
    // Previously computed rewrites, keyed by the string built in `cache_key`.
    query_cache: HashMap<String, RewrittenQuery>,
}
181
/// Statistics for a rewrite rule.
///
/// All counters start at zero (`Default`).
#[derive(Debug, Clone, Default)]
pub struct RuleStatistics {
    /// Number of times an application of the rule was recorded.
    pub times_applied: usize,
    /// Number of times a success was reported for the rule.
    pub times_successful: usize,
    /// Running mean of the `improvement` values reported with each success.
    pub avg_improvement: f64,
}
189
190impl QueryRewriter {
191    /// Create a new query rewriter
192    pub fn new(config: QueryRewriterConfig) -> Self {
193        Self {
194            config,
195            rule_stats: HashMap::new(),
196            query_cache: HashMap::new(),
197        }
198    }
199
200    /// Rewrite a query for optimal performance
201    pub fn rewrite(&mut self, query: &Vector, k: usize) -> Result<RewrittenQuery> {
202        // Check cache first
203        let cache_key = self.cache_key(query, k);
204        if self.config.enable_caching {
205            if let Some(cached) = self.query_cache.get(&cache_key) {
206                debug!("Query cache hit");
207                return Ok(cached.clone());
208            }
209        }
210
211        // Compute query statistics
212        let stats = QueryVectorStatistics::from_vector(query);
213        debug!(
214            "Query stats: dim={}, norm={:.2}, sparsity={:.2}, std_dev={:.2}",
215            stats.dimensions, stats.norm, stats.sparsity, stats.std_dev
216        );
217
218        // Initialize rewritten query
219        let mut rewritten = RewrittenQuery {
220            original_vector: query.clone(),
221            rewritten_vector: query.clone(),
222            original_k: k,
223            optimized_k: k,
224            applied_rules: Vec::new(),
225            suggested_strategy: None,
226            parameters: HashMap::new(),
227            confidence: 1.0,
228            estimated_improvement: 0.0,
229        };
230
231        // Apply rewrite rules
232        if self.config.enable_parameter_tuning {
233            self.tune_k(&mut rewritten, &stats)?;
234        }
235
236        if self.config.enable_expansion {
237            self.apply_expansion(&mut rewritten, &stats)?;
238        }
239
240        if self.config.enable_reduction {
241            self.apply_reduction(&mut rewritten, &stats)?;
242        }
243
244        // Suggest optimal strategy
245        self.suggest_strategy(&mut rewritten, &stats)?;
246
247        // Normalize query if beneficial
248        if self.should_normalize(&stats) {
249            self.normalize_query(&mut rewritten)?;
250        }
251
252        // Calculate confidence
253        rewritten.confidence = self.calculate_confidence(&rewritten);
254
255        // Only apply rewrite if confidence is high enough
256        if rewritten.confidence < self.config.min_confidence {
257            debug!(
258                "Rewrite confidence too low ({:.2}), keeping original query",
259                rewritten.confidence
260            );
261            rewritten.rewritten_vector = query.clone();
262            rewritten.optimized_k = k;
263            rewritten.applied_rules.clear();
264        }
265
266        // Cache result
267        if self.config.enable_caching {
268            self.query_cache.insert(cache_key, rewritten.clone());
269        }
270
271        Ok(rewritten)
272    }
273
274    /// Tune k parameter based on query characteristics
275    fn tune_k(
276        &mut self,
277        rewritten: &mut RewrittenQuery,
278        stats: &QueryVectorStatistics,
279    ) -> Result<()> {
280        let original_k = rewritten.optimized_k;
281        let mut new_k = original_k;
282
283        // If query is very specific (low sparsity, high norm), reduce k
284        if stats.sparsity < 0.1 && stats.norm > 1.0 {
285            new_k = (original_k as f32 * 0.8) as usize;
286            new_k = new_k.max(1);
287            debug!(
288                "Reducing k from {} to {} (high-selectivity query)",
289                original_k, new_k
290            );
291            rewritten.applied_rules.push(RewriteRule::ReduceK);
292        }
293
294        // If query is very general (high sparsity, low variance), increase k
295        if stats.sparsity > 0.5 && stats.std_dev < 0.1 {
296            new_k = (original_k as f32 * self.config.max_expansion_factor) as usize;
297            new_k = new_k.min(1000); // Cap at reasonable value
298            debug!(
299                "Expanding k from {} to {} (low-selectivity query)",
300                original_k, new_k
301            );
302            rewritten.applied_rules.push(RewriteRule::ExpandK);
303        }
304
305        rewritten.optimized_k = new_k;
306        rewritten.estimated_improvement +=
307            (new_k as f32 - original_k as f32).abs() / original_k as f32 * 10.0;
308
309        self.record_rule_application(RewriteRule::TuneParameters);
310
311        Ok(())
312    }
313
314    /// Apply query expansion
315    fn apply_expansion(
316        &mut self,
317        _rewritten: &mut RewrittenQuery,
318        stats: &QueryVectorStatistics,
319    ) -> Result<()> {
320        // Query expansion would add related vectors or boost dimensions
321        // For now, this is a placeholder for future implementation
322        if stats.sparsity > 0.6 {
323            debug!("Query is sparse, expansion could be beneficial");
324        }
325        Ok(())
326    }
327
328    /// Apply query reduction
329    fn apply_reduction(
330        &mut self,
331        rewritten: &mut RewrittenQuery,
332        stats: &QueryVectorStatistics,
333    ) -> Result<()> {
334        // Remove outlier dimensions
335        if stats.std_dev > 2.0 {
336            debug!("High variance detected, considering outlier removal");
337            rewritten.applied_rules.push(RewriteRule::RemoveOutliers);
338            self.record_rule_application(RewriteRule::RemoveOutliers);
339        }
340        Ok(())
341    }
342
343    /// Suggest optimal query strategy
344    fn suggest_strategy(
345        &self,
346        rewritten: &mut RewrittenQuery,
347        stats: &QueryVectorStatistics,
348    ) -> Result<()> {
349        // Suggest strategy based on query characteristics
350        let strategy = if stats.sparsity > 0.7 {
351            // Sparse queries work well with LSH
352            QueryStrategy::LocalitySensitiveHashing
353        } else if stats.dimensions > 512 {
354            // High-dimensional queries benefit from PQ
355            QueryStrategy::ProductQuantization
356        } else if stats.norm > 10.0 {
357            // High-norm queries work well with NSG
358            QueryStrategy::NsgApproximate
359        } else {
360            // Default to HNSW
361            QueryStrategy::HnswApproximate
362        };
363
364        rewritten.suggested_strategy = Some(strategy);
365        rewritten.applied_rules.push(RewriteRule::SuggestIndex);
366
367        Ok(())
368    }
369
370    /// Check if query should be normalized
371    fn should_normalize(&self, stats: &QueryVectorStatistics) -> bool {
372        // Normalize if norm is far from 1.0
373        (stats.norm - 1.0).abs() > 0.1
374    }
375
376    /// Normalize query vector
377    fn normalize_query(&mut self, rewritten: &mut RewrittenQuery) -> Result<()> {
378        let values = rewritten.rewritten_vector.as_f32();
379        let norm: f32 = values.iter().map(|v| v * v).sum::<f32>().sqrt();
380
381        if norm > 1e-6 {
382            let normalized: Vec<f32> = values.iter().map(|v| v / norm).collect();
383            rewritten.rewritten_vector = Vector::new(normalized);
384            rewritten.applied_rules.push(RewriteRule::NormalizeQuery);
385            debug!("Query normalized (original norm: {:.2})", norm);
386            self.record_rule_application(RewriteRule::NormalizeQuery);
387        }
388
389        Ok(())
390    }
391
392    /// Calculate confidence in rewrite
393    fn calculate_confidence(&self, rewritten: &RewrittenQuery) -> f32 {
394        // Base confidence
395        let mut confidence = 1.0;
396
397        // Reduce confidence if many rules were applied
398        confidence -= rewritten.applied_rules.len() as f32 * 0.05;
399
400        // Reduce confidence if k changed dramatically
401        let k_change_ratio =
402            (rewritten.optimized_k as f32 / rewritten.original_k as f32 - 1.0).abs();
403        confidence -= k_change_ratio * 0.2;
404
405        // Confidence from historical rule performance
406        for rule in &rewritten.applied_rules {
407            if let Some(stats) = self.rule_stats.get(rule) {
408                if stats.times_applied > 0 {
409                    let success_rate = stats.times_successful as f32 / stats.times_applied as f32;
410                    confidence *= success_rate;
411                }
412            }
413        }
414
415        confidence.clamp(0.0, 1.0)
416    }
417
418    /// Record rule application for learning
419    fn record_rule_application(&mut self, rule: RewriteRule) {
420        if !self.config.enable_learning {
421            return;
422        }
423
424        self.rule_stats.entry(rule).or_default().times_applied += 1;
425    }
426
427    /// Record rule success for learning
428    pub fn record_rule_success(&mut self, rule: RewriteRule, improvement: f64) {
429        if !self.config.enable_learning {
430            return;
431        }
432
433        let stats = self.rule_stats.entry(rule).or_default();
434
435        stats.times_successful += 1;
436        stats.avg_improvement = (stats.avg_improvement * (stats.times_successful - 1) as f64
437            + improvement)
438            / stats.times_successful as f64;
439    }
440
441    /// Generate cache key
442    fn cache_key(&self, query: &Vector, k: usize) -> String {
443        // Simple hash of query vector + k
444        let values = query.as_f32();
445        let hash: u64 = values
446            .iter()
447            .map(|v| (v * 1000.0) as i32)
448            .fold(0u64, |acc, v| acc.wrapping_mul(31).wrapping_add(v as u64));
449
450        format!("{:x}_{}", hash, k)
451    }
452
    /// Clear query cache.
    ///
    /// Removes every cached rewrite; rule statistics are not touched.
    pub fn clear_cache(&mut self) {
        self.query_cache.clear();
    }
457
    /// Get rule statistics.
    ///
    /// Returns a read-only view of the per-rule counters. Rules that have
    /// never been recorded have no entry in the map.
    pub fn rule_statistics(&self) -> &HashMap<RewriteRule, RuleStatistics> {
        &self.rule_stats
    }
462
    /// Get cache size.
    ///
    /// Returns the number of rewrites currently held in the cache.
    pub fn cache_size(&self) -> usize {
        self.query_cache.len()
    }
467}
468
469impl Default for QueryRewriter {
470    fn default() -> Self {
471        Self::new(QueryRewriterConfig::default())
472    }
473}
474
#[cfg(test)]
mod tests {
    use super::*;

    /// Statistics of a small dense vector report the right dimensionality and
    /// strictly positive norm and spread.
    #[test]
    fn test_query_statistics() {
        let vector = Vector::new(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let stats = QueryVectorStatistics::from_vector(&vector);

        assert_eq!(stats.dimensions, 5);
        assert!(stats.norm > 0.0);
        assert!(stats.std_dev > 0.0);
    }

    /// Smoke test: a rewriter can be constructed from the default config.
    #[test]
    fn test_query_rewriter_creation() {
        let config = QueryRewriterConfig::default();
        let _rewriter = QueryRewriter::new(config);
    }

    /// A rewrite preserves the caller's `k` in `original_k` and yields a
    /// non-negative confidence.
    #[test]
    fn test_query_rewrite() {
        let config = QueryRewriterConfig {
            min_confidence: 0.5, // Lower threshold for test
            ..Default::default()
        };
        let mut rewriter = QueryRewriter::new(config);

        let query = Vector::new(vec![1.0, 2.0, 3.0, 4.0]);
        let result = rewriter.rewrite(&query, 10).unwrap();

        assert_eq!(result.original_k, 10);
        // Confidence must never be negative
        assert!(result.confidence >= 0.0);
    }

    /// The rewritten vector is either unit length (normalization reported) or
    /// untouched (normalization not reported); which branch runs depends on
    /// the confidence computed for the rewrite.
    #[test]
    fn test_normalize_query() {
        let config = QueryRewriterConfig {
            min_confidence: 0.5, // Lower threshold to allow normalization
            ..Default::default()
        };
        let mut rewriter = QueryRewriter::new(config);

        // Create a query with non-unit norm
        let query = Vector::new(vec![3.0, 4.0]); // norm = 5.0
        let result = rewriter.rewrite(&query, 10).unwrap();

        // Measure the norm of whatever vector the rewriter returned
        let normalized_values = result.rewritten_vector.as_f32();
        let norm: f32 = normalized_values.iter().map(|v| v * v).sum::<f32>().sqrt();

        // If normalization was applied, norm should be ~1.0
        if result.applied_rules.contains(&RewriteRule::NormalizeQuery) {
            assert!(
                (norm - 1.0).abs() < 0.01,
                "Expected norm close to 1.0, got {}",
                norm
            );
        } else {
            // Otherwise it should be the original norm
            assert!(
                (norm - 5.0).abs() < 0.01,
                "Expected original norm ~5.0, got {}",
                norm
            );
        }
    }

    /// For a sparse query, tuning must never shrink `k`.
    #[test]
    fn test_k_tuning_sparse_query() {
        let config = QueryRewriterConfig {
            enable_parameter_tuning: true,
            ..Default::default()
        };
        let mut rewriter = QueryRewriter::new(config);

        // Create a sparse query (many zeros)
        let query = Vector::new(vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0]);
        let result = rewriter.rewrite(&query, 10).unwrap();

        // k may stay the same or grow, but must not drop below the original
        assert!(result.optimized_k >= result.original_k);
    }

    /// The cache holds one entry per distinct `(query, k)` pair.
    #[test]
    fn test_caching() {
        let config = QueryRewriterConfig {
            enable_caching: true,
            ..Default::default()
        };
        let mut rewriter = QueryRewriter::new(config);

        let query = Vector::new(vec![1.0, 2.0, 3.0]);

        // First call
        let _result1 = rewriter.rewrite(&query, 10).unwrap();
        assert_eq!(rewriter.cache_size(), 1);

        // Second call (should hit cache)
        let _result2 = rewriter.rewrite(&query, 10).unwrap();
        assert_eq!(rewriter.cache_size(), 1);

        // Different k (should miss cache)
        let _result3 = rewriter.rewrite(&query, 20).unwrap();
        assert_eq!(rewriter.cache_size(), 2);
    }

    /// Reporting a success creates (or updates) the statistics entry for the
    /// rule.
    #[test]
    fn test_rule_learning() {
        let config = QueryRewriterConfig {
            enable_learning: true,
            ..Default::default()
        };
        let mut rewriter = QueryRewriter::new(config);

        let query = Vector::new(vec![1.0, 2.0, 3.0]);
        rewriter.rewrite(&query, 10).unwrap();

        // Record success
        rewriter.record_rule_success(RewriteRule::NormalizeQuery, 0.15);

        let stats = rewriter.rule_statistics();
        assert!(stats.contains_key(&RewriteRule::NormalizeQuery));
    }
}