// oxirs_vec/query_planning.rs

//! Query planning and cost estimation for vector search operations
//!
//! This module provides intelligent query planning to select the optimal
//! search strategy based on query characteristics and index statistics.

6use anyhow::Result;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
/// Query execution strategy.
///
/// Each variant names one index/search algorithm the planner can choose
/// between. Cost and recall estimates for each variant live in
/// `QueryPlanner::estimate_strategy` / `QueryPlanner::estimate_recall`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum QueryStrategy {
    /// Exhaustive linear scan over every vector (exact, most accurate, slowest)
    ExhaustiveScan,
    /// HNSW (Hierarchical Navigable Small World) approximate graph search
    HnswApproximate,
    /// NSG (Navigating Spreading-out Graph) approximate graph search
    NsgApproximate,
    /// IVF (inverted file) search with coarse quantization
    IvfCoarse,
    /// Product quantization with refinement over compressed codes
    ProductQuantization,
    /// Scalar quantization (per-dimension compression)
    ScalarQuantization,
    /// LSH (locality-sensitive hashing) approximate search
    LocalitySensitiveHashing,
    /// GPU-accelerated search (only usable when `CostModel::gpu_available`)
    GpuAccelerated,
    /// Hybrid strategy: approximate candidate generation plus refinement
    Hybrid,
}
32
/// Cost model for query execution.
///
/// Holds per-operation cost assumptions that the planner combines into a
/// total estimated cost (in microseconds) for each candidate strategy.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CostModel {
    /// Cost per distance computation (CPU microseconds)
    pub distance_computation_cost_us: f64,
    /// Cost per index lookup (CPU microseconds)
    pub index_lookup_cost_us: f64,
    /// Cost per memory access (nanoseconds)
    pub memory_access_cost_ns: f64,
    /// Whether a GPU is available for `QueryStrategy::GpuAccelerated`
    pub gpu_available: bool,
    /// Multiplier applied to the equivalent CPU cost when running on GPU
    /// (e.g. 0.1 means the GPU is assumed ~10x faster)
    pub gpu_cost_multiplier: f64,
}
46
47impl Default for CostModel {
48    fn default() -> Self {
49        Self {
50            distance_computation_cost_us: 0.5,
51            index_lookup_cost_us: 0.1,
52            memory_access_cost_ns: 50.0,
53            gpu_available: false,
54            gpu_cost_multiplier: 0.1, // GPU is 10x faster
55        }
56    }
57}
58
/// Query characteristics supplied by the caller to drive planning.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryCharacteristics {
    /// Number of nearest-neighbor results requested (k)
    pub k: usize,
    /// Vector dimensionality of the query
    pub dimensions: usize,
    /// Minimum acceptable recall, in [0.0, 1.0]
    pub min_recall: f32,
    /// Maximum acceptable latency in milliseconds
    pub max_latency_ms: f64,
    /// Query type (single, batch, or streaming)
    pub query_type: VectorQueryType,
}
73
/// Type of vector query; affects the planner's cost adjustment
/// (batch discount, streaming overhead).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum VectorQueryType {
    /// Single vector query
    Single,
    /// Batch of queries; payload is the number of queries in the batch
    Batch(usize),
    /// Streaming queries (continuous arrival)
    Streaming,
}
84
/// Index statistics used for planning.
///
/// The `avg_*` maps carry observed performance per strategy; when a
/// strategy has no entry, the planner falls back to static estimates.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IndexStatistics {
    /// Total number of vectors in the index
    pub vector_count: usize,
    /// Vector dimensionality of the indexed data
    pub dimensions: usize,
    /// Index types that are built and available for querying
    pub available_indices: Vec<QueryStrategy>,
    /// Average observed query latencies by strategy (milliseconds)
    pub avg_latencies: HashMap<QueryStrategy, f64>,
    /// Average observed recalls by strategy, in [0.0, 1.0]
    pub avg_recalls: HashMap<QueryStrategy, f32>,
}
99
/// Query execution plan produced by `QueryPlanner::plan`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryPlan {
    /// Selected strategy
    pub strategy: QueryStrategy,
    /// Estimated execution cost (microseconds)
    pub estimated_cost_us: f64,
    /// Estimated recall, in [0.0, 1.0]
    pub estimated_recall: f32,
    /// Confidence in the plan (0.0 to 1.0); higher when backed by
    /// historical measurements rather than static estimates
    pub confidence: f32,
    /// Up to three alternative strategies considered, as
    /// (strategy, estimated cost in microseconds, estimated recall)
    pub alternatives: Vec<(QueryStrategy, f64, f32)>,
    /// Recommended strategy-specific parameters (e.g. "ef_search", "nprobe"),
    /// keyed and valued as strings for transport-agnostic consumption
    pub parameters: HashMap<String, String>,
}
116
/// Query planner.
///
/// Selects an execution strategy for a vector search query by combining a
/// static `CostModel` with observed `IndexStatistics`.
pub struct QueryPlanner {
    /// Static per-operation cost assumptions used for estimation.
    cost_model: CostModel,
    /// Statistics about the indexed data and historical performance.
    index_stats: IndexStatistics,
}
122
123impl QueryPlanner {
124    /// Create a new query planner
125    pub fn new(cost_model: CostModel, index_stats: IndexStatistics) -> Self {
126        Self {
127            cost_model,
128            index_stats,
129        }
130    }
131
132    /// Plan optimal query execution strategy
133    pub fn plan(&self, query: &QueryCharacteristics) -> Result<QueryPlan> {
134        let mut candidates = Vec::new();
135
136        // Evaluate each available strategy
137        for strategy in &self.index_stats.available_indices {
138            let (cost, recall) = self.estimate_strategy(*strategy, query);
139            candidates.push((*strategy, cost, recall));
140        }
141
142        // Sort by cost
143        candidates.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
144
145        // Find best strategy that meets recall requirement
146        let best = candidates
147            .iter()
148            .find(|(_, _, recall)| *recall >= query.min_recall)
149            .or_else(|| candidates.first())
150            .ok_or_else(|| anyhow::anyhow!("No suitable strategy found"))?;
151
152        let (strategy, cost, recall) = *best;
153
154        // Generate parameters for selected strategy
155        let parameters = self.generate_parameters(strategy, query);
156
157        // Calculate confidence based on historical data
158        let confidence = self.calculate_confidence(strategy);
159
160        Ok(QueryPlan {
161            strategy,
162            estimated_cost_us: cost,
163            estimated_recall: recall,
164            confidence,
165            alternatives: candidates
166                .iter()
167                .filter(|(s, _, _)| *s != strategy)
168                .take(3)
169                .copied()
170                .collect(),
171            parameters,
172        })
173    }
174
175    /// Estimate cost and recall for a strategy
176    fn estimate_strategy(
177        &self,
178        strategy: QueryStrategy,
179        query: &QueryCharacteristics,
180    ) -> (f64, f32) {
181        let base_cost = match strategy {
182            QueryStrategy::ExhaustiveScan => {
183                // Cost = number of vectors * distance computation cost
184                self.index_stats.vector_count as f64 * self.cost_model.distance_computation_cost_us
185            }
186            QueryStrategy::HnswApproximate => {
187                // Cost ≈ log(N) * M * distance computation
188                let hnsw_complexity = (self.index_stats.vector_count as f64).ln() * 16.0;
189                hnsw_complexity * self.cost_model.distance_computation_cost_us
190            }
191            QueryStrategy::NsgApproximate => {
192                // NSG is typically more efficient than HNSW due to monotonic search
193                // Cost ≈ log(N) * out_degree (typically 32) * distance computation
194                let nsg_complexity = (self.index_stats.vector_count as f64).ln() * 12.0;
195                nsg_complexity * self.cost_model.distance_computation_cost_us
196            }
197            QueryStrategy::IvfCoarse => {
198                // Cost ≈ sqrt(N) * distance computation
199                let ivf_probes = (self.index_stats.vector_count as f64).sqrt();
200                ivf_probes * self.cost_model.distance_computation_cost_us
201            }
202            QueryStrategy::ProductQuantization => {
203                // Lower cost due to compressed distance computations
204                let pq_cost = self.index_stats.vector_count as f64 * 0.1;
205                pq_cost * self.cost_model.distance_computation_cost_us
206            }
207            QueryStrategy::ScalarQuantization => {
208                // Similar to PQ but slightly faster
209                let sq_cost = self.index_stats.vector_count as f64 * 0.08;
210                sq_cost * self.cost_model.distance_computation_cost_us
211            }
212            QueryStrategy::LocalitySensitiveHashing => {
213                // Cost ≈ number of hash tables * bucket size
214                let lsh_cost = 10.0 * 100.0; // Example: 10 tables, 100 vectors per bucket
215                lsh_cost * self.cost_model.distance_computation_cost_us
216            }
217            QueryStrategy::GpuAccelerated => {
218                if self.cost_model.gpu_available {
219                    let cpu_cost = self.index_stats.vector_count as f64
220                        * self.cost_model.distance_computation_cost_us;
221                    cpu_cost * self.cost_model.gpu_cost_multiplier
222                } else {
223                    f64::INFINITY // Not available
224                }
225            }
226            QueryStrategy::Hybrid => {
227                // Combine HNSW + refinement
228                let hnsw_cost = (self.index_stats.vector_count as f64).ln() * 16.0;
229                let refinement_cost = query.k as f64 * 10.0;
230                (hnsw_cost + refinement_cost) * self.cost_model.distance_computation_cost_us
231            }
232        };
233
234        // Adjust for batch queries
235        let cost = match query.query_type {
236            VectorQueryType::Single => base_cost,
237            VectorQueryType::Batch(n) => base_cost * n as f64 * 0.8, // 20% batch efficiency
238            VectorQueryType::Streaming => base_cost * 1.2,           // 20% overhead for streaming
239        };
240
241        // Get historical recall or estimate
242        let recall = self
243            .index_stats
244            .avg_recalls
245            .get(&strategy)
246            .copied()
247            .unwrap_or_else(|| self.estimate_recall(strategy));
248
249        (cost, recall)
250    }
251
252    /// Estimate recall for a strategy
253    fn estimate_recall(&self, strategy: QueryStrategy) -> f32 {
254        match strategy {
255            QueryStrategy::ExhaustiveScan => 1.0,
256            QueryStrategy::HnswApproximate => 0.95,
257            QueryStrategy::NsgApproximate => 0.96, // NSG typically has slightly better recall than HNSW
258            QueryStrategy::IvfCoarse => 0.85,
259            QueryStrategy::ProductQuantization => 0.90,
260            QueryStrategy::ScalarQuantization => 0.92,
261            QueryStrategy::LocalitySensitiveHashing => 0.80,
262            QueryStrategy::GpuAccelerated => 0.95,
263            QueryStrategy::Hybrid => 0.98,
264        }
265    }
266
267    /// Generate recommended parameters for strategy
268    fn generate_parameters(
269        &self,
270        strategy: QueryStrategy,
271        query: &QueryCharacteristics,
272    ) -> HashMap<String, String> {
273        let mut params = HashMap::new();
274
275        match strategy {
276            QueryStrategy::HnswApproximate => {
277                // Adaptive ef_search based on k and recall requirement
278                let ef_search = if query.min_recall >= 0.95 {
279                    (query.k * 4).max(64)
280                } else {
281                    (query.k * 2).max(32)
282                };
283                params.insert("ef_search".to_string(), ef_search.to_string());
284            }
285            QueryStrategy::NsgApproximate => {
286                // NSG search length based on k and recall requirement
287                let search_length = if query.min_recall >= 0.95 {
288                    (query.k * 5).max(50)
289                } else {
290                    (query.k * 3).max(30)
291                };
292                params.insert("search_length".to_string(), search_length.to_string());
293                params.insert("out_degree".to_string(), "32".to_string());
294            }
295            QueryStrategy::IvfCoarse => {
296                let nprobe = if query.min_recall >= 0.90 { 16 } else { 8 };
297                params.insert("nprobe".to_string(), nprobe.to_string());
298            }
299            QueryStrategy::LocalitySensitiveHashing => {
300                params.insert("num_probes".to_string(), "3".to_string());
301            }
302            _ => {}
303        }
304
305        params
306    }
307
308    /// Calculate confidence in plan based on historical data
309    fn calculate_confidence(&self, strategy: QueryStrategy) -> f32 {
310        // Higher confidence if we have historical data
311        if self.index_stats.avg_latencies.contains_key(&strategy) {
312            0.9
313        } else {
314            0.5 // Lower confidence for estimated values
315        }
316    }
317
318    /// Update index statistics with observed performance
319    pub fn update_statistics(&mut self, strategy: QueryStrategy, latency_ms: f64, recall: f32) {
320        self.index_stats.avg_latencies.insert(strategy, latency_ms);
321        self.index_stats.avg_recalls.insert(strategy, recall);
322    }
323
324    /// Update index metadata (vector count, dimensions)
325    pub fn update_index_metadata(&mut self, vector_count: usize, dimensions: usize) {
326        self.index_stats.vector_count = vector_count;
327        self.index_stats.dimensions = dimensions;
328    }
329}
330
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture: 100k vectors, 128 dims, three built indices, and no
    /// historical latency/recall measurements.
    fn create_test_stats() -> IndexStatistics {
        let available_indices = vec![
            QueryStrategy::ExhaustiveScan,
            QueryStrategy::HnswApproximate,
            QueryStrategy::IvfCoarse,
        ];
        IndexStatistics {
            vector_count: 100_000,
            dimensions: 128,
            available_indices,
            avg_latencies: HashMap::new(),
            avg_recalls: HashMap::new(),
        }
    }

    /// Convenience constructor for a single (non-batch) k=10 query.
    fn single_query(min_recall: f32, max_latency_ms: f64) -> QueryCharacteristics {
        QueryCharacteristics {
            k: 10,
            dimensions: 128,
            min_recall,
            max_latency_ms,
            query_type: VectorQueryType::Single,
        }
    }

    #[test]
    fn test_query_planner_creation() {
        // Construction alone must not panic.
        let _planner = QueryPlanner::new(CostModel::default(), create_test_stats());
    }

    #[test]
    fn test_query_planning() {
        let planner = QueryPlanner::new(CostModel::default(), create_test_stats());
        let query = single_query(0.90, 100.0);

        let plan = planner.plan(&query).expect("planning should succeed");
        assert!(plan.estimated_recall >= query.min_recall);
        assert!(!plan.alternatives.is_empty());
    }

    #[test]
    fn test_exhaustive_vs_approximate() {
        let planner = QueryPlanner::new(CostModel::default(), create_test_stats());

        // A high recall requirement with a tight latency budget should
        // steer the planner toward an approximate index, not the scan.
        let plan = planner.plan(&single_query(0.95, 10.0)).unwrap();
        assert_ne!(plan.strategy, QueryStrategy::ExhaustiveScan);
    }

    #[test]
    fn test_batch_query_planning() {
        let planner = QueryPlanner::new(CostModel::default(), create_test_stats());
        let mut query = single_query(0.90, 100.0);
        query.query_type = VectorQueryType::Batch(100);

        let plan = planner.plan(&query).unwrap();
        assert!(plan.estimated_cost_us > 0.0);
    }

    #[test]
    fn test_statistics_update() {
        let mut planner = QueryPlanner::new(CostModel::default(), create_test_stats());
        planner.update_statistics(QueryStrategy::HnswApproximate, 5.0, 0.96);

        let latency = planner
            .index_stats
            .avg_latencies
            .get(&QueryStrategy::HnswApproximate);
        let recall = planner
            .index_stats
            .avg_recalls
            .get(&QueryStrategy::HnswApproximate);
        assert_eq!(latency, Some(&5.0));
        assert_eq!(recall, Some(&0.96));
    }
}
431}