//! Cost-based query planning for vector similarity search: given a cost model
//! and per-index statistics, pick the execution strategy that meets a query's
//! recall target at the lowest estimated cost.

use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

/// Execution strategies the planner can choose between.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum QueryStrategy {
    /// Brute-force scan over every stored vector; exact but O(n).
    ExhaustiveScan,
    /// Hierarchical Navigable Small World (HNSW) graph search.
    HnswApproximate,
    /// Navigating Spreading-out Graph (NSG) search.
    NsgApproximate,
    /// Inverted-file (IVF) index probing a subset of coarse clusters.
    IvfCoarse,
    /// Search over product-quantized codes.
    ProductQuantization,
    /// Search over scalar-quantized vectors.
    ScalarQuantization,
    /// Candidate generation via locality-sensitive hashing.
    LocalitySensitiveHashing,
    /// Brute-force scan offloaded to a GPU, when one is available.
    GpuAccelerated,
    /// Approximate candidate generation followed by exact refinement.
    Hybrid,
}

/// Tunable cost constants used to price each strategy.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CostModel {
    /// Cost of a single distance computation, in microseconds.
    pub distance_computation_cost_us: f64,
    /// Cost of a single index lookup, in microseconds.
    pub index_lookup_cost_us: f64,
    /// Cost of a single memory access, in nanoseconds.
    pub memory_access_cost_ns: f64,
    /// Whether a GPU is available for brute-force offload.
    pub gpu_available: bool,
    /// Multiplier applied to the CPU scan cost when running on the GPU.
    pub gpu_cost_multiplier: f64,
}

impl Default for CostModel {
    fn default() -> Self {
        Self {
            distance_computation_cost_us: 0.5,
            index_lookup_cost_us: 0.1,
            memory_access_cost_ns: 50.0,
            gpu_available: false,
            gpu_cost_multiplier: 0.1,
        }
    }
}

/// A query's shape and its service-level targets.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryCharacteristics {
    /// Number of nearest neighbours requested.
    pub k: usize,
    /// Dimensionality of the query vector.
    pub dimensions: usize,
    /// Minimum acceptable recall.
    pub min_recall: f32,
    /// Latency budget, in milliseconds.
    pub max_latency_ms: f64,
    /// Whether the query arrives alone, in a batch, or on a stream.
    pub query_type: VectorQueryType,
}

/// How queries arrive at the planner.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum VectorQueryType {
    /// A single standalone query.
    Single,
    /// A batch of the given number of queries planned together.
    Batch(usize),
    /// A continuous stream of queries.
    Streaming,
}

/// Statistics about the indexed collection and previously observed strategy
/// performance.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IndexStatistics {
    /// Number of vectors in the collection.
    pub vector_count: usize,
    /// Dimensionality of the stored vectors.
    pub dimensions: usize,
    /// Strategies for which an index (or execution path) exists.
    pub available_indices: Vec<QueryStrategy>,
    /// Recorded latency per strategy, in milliseconds.
    pub avg_latencies: HashMap<QueryStrategy, f64>,
    /// Recorded recall per strategy.
    pub avg_recalls: HashMap<QueryStrategy, f32>,
}

/// The plan chosen for a query, with cost and recall estimates.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryPlan {
    /// Strategy selected for execution.
    pub strategy: QueryStrategy,
    /// Estimated execution cost, in microseconds.
    pub estimated_cost_us: f64,
    /// Estimated recall of the selected strategy.
    pub estimated_recall: f32,
    /// Confidence in the estimates; higher when backed by measurements.
    pub confidence: f32,
    /// Up to three runner-up strategies as (strategy, cost_us, recall).
    pub alternatives: Vec<(QueryStrategy, f64, f32)>,
    /// Strategy-specific search parameters (e.g. "ef_search", "nprobe").
    pub parameters: HashMap<String, String>,
}

/// Cost-based planner that selects among the available index strategies.
pub struct QueryPlanner {
    cost_model: CostModel,
    index_stats: IndexStatistics,
}

impl QueryPlanner {
    pub fn new(cost_model: CostModel, index_stats: IndexStatistics) -> Self {
        Self {
            cost_model,
            index_stats,
        }
    }

    /// Build a plan for `query`: the cheapest available strategy that meets
    /// the recall target, with runner-up strategies listed as alternatives.
    pub fn plan(&self, query: &QueryCharacteristics) -> Result<QueryPlan> {
        let mut candidates = Vec::new();

        // Price every strategy the index layer reports as available.
        for strategy in &self.index_stats.available_indices {
            let (cost, recall) = self.estimate_strategy(*strategy, query);
            candidates.push((*strategy, cost, recall));
        }

        // Cheapest first.
        candidates.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));

        // Take the cheapest candidate that meets the recall target, falling
        // back to the overall cheapest if none does.
        let best = candidates
            .iter()
            .find(|(_, _, recall)| *recall >= query.min_recall)
            .or_else(|| candidates.first())
            .ok_or_else(|| anyhow::anyhow!("No suitable strategy found"))?;

        let (strategy, cost, recall) = *best;

        let parameters = self.generate_parameters(strategy, query);

        let confidence = self.calculate_confidence(strategy);

        Ok(QueryPlan {
            strategy,
            estimated_cost_us: cost,
            estimated_recall: recall,
            confidence,
            alternatives: candidates
                .iter()
                .filter(|(s, _, _)| *s != strategy)
                .take(3)
                .copied()
                .collect(),
            parameters,
        })
    }

    /// Estimate (cost in microseconds, expected recall) for running `strategy`
    /// against the current index.
    fn estimate_strategy(
        &self,
        strategy: QueryStrategy,
        query: &QueryCharacteristics,
    ) -> (f64, f32) {
        let base_cost = match strategy {
            QueryStrategy::ExhaustiveScan => {
                // One distance computation per stored vector.
                self.index_stats.vector_count as f64 * self.cost_model.distance_computation_cost_us
            }
            QueryStrategy::HnswApproximate => {
                // Graph search modeled as ln(n) hops with a branching factor of 16.
                let hnsw_complexity = (self.index_stats.vector_count as f64).ln() * 16.0;
                hnsw_complexity * self.cost_model.distance_computation_cost_us
            }
            QueryStrategy::NsgApproximate => {
                // Sparser graph: ln(n) hops with a branching factor of 12.
                let nsg_complexity = (self.index_stats.vector_count as f64).ln() * 12.0;
                nsg_complexity * self.cost_model.distance_computation_cost_us
            }
            QueryStrategy::IvfCoarse => {
                // Roughly sqrt(n) distance computations across the probed lists.
                let ivf_probes = (self.index_stats.vector_count as f64).sqrt();
                ivf_probes * self.cost_model.distance_computation_cost_us
            }
            QueryStrategy::ProductQuantization => {
                // Distances over compressed codes, modeled as 10% of a full scan.
                let pq_cost = self.index_stats.vector_count as f64 * 0.1;
                pq_cost * self.cost_model.distance_computation_cost_us
            }
            QueryStrategy::ScalarQuantization => {
                // Modeled as 8% of a full scan.
                let sq_cost = self.index_stats.vector_count as f64 * 0.08;
                sq_cost * self.cost_model.distance_computation_cost_us
            }
            QueryStrategy::LocalitySensitiveHashing => {
                // Fixed candidate budget of 1,000 distance computations,
                // independent of collection size.
                let lsh_cost = 10.0 * 100.0;
                lsh_cost * self.cost_model.distance_computation_cost_us
            }
            QueryStrategy::GpuAccelerated => {
                if self.cost_model.gpu_available {
                    // A full scan, scaled down by the GPU cost multiplier.
                    let cpu_cost = self.index_stats.vector_count as f64
                        * self.cost_model.distance_computation_cost_us;
                    cpu_cost * self.cost_model.gpu_cost_multiplier
                } else {
                    // Price the strategy out entirely when no GPU is present.
                    f64::INFINITY
                }
            }
            QueryStrategy::Hybrid => {
                // Graph candidate generation plus exact re-ranking proportional to k.
                let hnsw_cost = (self.index_stats.vector_count as f64).ln() * 16.0;
                let refinement_cost = query.k as f64 * 10.0;
                (hnsw_cost + refinement_cost) * self.cost_model.distance_computation_cost_us
            }
        };

        // Adjust for query shape: batches amortize roughly 20% of the
        // per-query cost, while streaming adds a 20% overhead.
        let cost = match query.query_type {
            VectorQueryType::Single => base_cost,
            VectorQueryType::Batch(n) => base_cost * n as f64 * 0.8,
            VectorQueryType::Streaming => base_cost * 1.2,
        };

        // Prefer measured recall when available; otherwise fall back to a
        // static estimate.
        let recall = self
            .index_stats
            .avg_recalls
            .get(&strategy)
            .copied()
            .unwrap_or_else(|| self.estimate_recall(strategy));

        (cost, recall)
    }

    /// Static recall estimates used when no measured recall has been recorded
    /// for a strategy.
    fn estimate_recall(&self, strategy: QueryStrategy) -> f32 {
        match strategy {
            QueryStrategy::ExhaustiveScan => 1.0,
            QueryStrategy::HnswApproximate => 0.95,
            QueryStrategy::NsgApproximate => 0.96,
            QueryStrategy::IvfCoarse => 0.85,
            QueryStrategy::ProductQuantization => 0.90,
            QueryStrategy::ScalarQuantization => 0.92,
            QueryStrategy::LocalitySensitiveHashing => 0.80,
            QueryStrategy::GpuAccelerated => 0.95,
            QueryStrategy::Hybrid => 0.98,
        }
    }

    /// Derive strategy-specific search parameters, widening the search when
    /// the recall target is strict.
    fn generate_parameters(
        &self,
        strategy: QueryStrategy,
        query: &QueryCharacteristics,
    ) -> HashMap<String, String> {
        let mut params = HashMap::new();

        match strategy {
            QueryStrategy::HnswApproximate => {
                let ef_search = if query.min_recall >= 0.95 {
                    (query.k * 4).max(64)
                } else {
                    (query.k * 2).max(32)
                };
                params.insert("ef_search".to_string(), ef_search.to_string());
            }
            QueryStrategy::NsgApproximate => {
                let search_length = if query.min_recall >= 0.95 {
                    (query.k * 5).max(50)
                } else {
                    (query.k * 3).max(30)
                };
                params.insert("search_length".to_string(), search_length.to_string());
                params.insert("out_degree".to_string(), "32".to_string());
            }
            QueryStrategy::IvfCoarse => {
                let nprobe = if query.min_recall >= 0.90 { 16 } else { 8 };
                params.insert("nprobe".to_string(), nprobe.to_string());
            }
            QueryStrategy::LocalitySensitiveHashing => {
                params.insert("num_probes".to_string(), "3".to_string());
            }
            // The remaining strategies take no tunable parameters here.
            _ => {}
        }

        params
    }

    /// Report higher confidence for strategies that have measured latencies.
    fn calculate_confidence(&self, strategy: QueryStrategy) -> f32 {
        if self.index_stats.avg_latencies.contains_key(&strategy) {
            0.9
        } else {
            0.5
        }
    }

    /// Record an observed latency and recall for a strategy; the latest
    /// observation replaces any previously stored value.
    pub fn update_statistics(&mut self, strategy: QueryStrategy, latency_ms: f64, recall: f32) {
        self.index_stats.avg_latencies.insert(strategy, latency_ms);
        self.index_stats.avg_recalls.insert(strategy, recall);
    }

    /// Refresh the index size metadata that drives the cost estimates.
    pub fn update_index_metadata(&mut self, vector_count: usize, dimensions: usize) {
        self.index_stats.vector_count = vector_count;
        self.index_stats.dimensions = dimensions;
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    fn create_test_stats() -> IndexStatistics {
        IndexStatistics {
            vector_count: 100_000,
            dimensions: 128,
            available_indices: vec![
                QueryStrategy::ExhaustiveScan,
                QueryStrategy::HnswApproximate,
                QueryStrategy::IvfCoarse,
            ],
            avg_latencies: HashMap::new(),
            avg_recalls: HashMap::new(),
        }
    }

    #[test]
    fn test_query_planner_creation() {
        let cost_model = CostModel::default();
        let stats = create_test_stats();
        let _planner = QueryPlanner::new(cost_model, stats);
    }

    #[test]
    fn test_query_planning() {
        let planner = QueryPlanner::new(CostModel::default(), create_test_stats());

        let query = QueryCharacteristics {
            k: 10,
            dimensions: 128,
            min_recall: 0.90,
            max_latency_ms: 100.0,
            query_type: VectorQueryType::Single,
        };

        let plan = planner.plan(&query);
        assert!(plan.is_ok());

        let plan = plan.unwrap();
        assert!(plan.estimated_recall >= query.min_recall);
        assert!(!plan.alternatives.is_empty());
    }
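
    // An added, illustrative check of the parameter heuristics defined above:
    // with this module's thresholds, a high-recall HNSW query widens ef_search
    // to max(4 * k, 64), and an IVF query above the 0.90 threshold probes 16
    // lists. The query values are synthetic; the test exercises the private
    // generate_parameters helper, which the child test module may access.
    #[test]
    fn test_parameter_generation_heuristics() {
        let planner = QueryPlanner::new(CostModel::default(), create_test_stats());

        let query = QueryCharacteristics {
            k: 10,
            dimensions: 128,
            min_recall: 0.95,
            max_latency_ms: 100.0,
            query_type: VectorQueryType::Single,
        };

        let hnsw = planner.generate_parameters(QueryStrategy::HnswApproximate, &query);
        assert_eq!(hnsw.get("ef_search"), Some(&"64".to_string()));

        let ivf = planner.generate_parameters(QueryStrategy::IvfCoarse, &query);
        assert_eq!(ivf.get("nprobe"), Some(&"16".to_string()));
    }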

    #[test]
    fn test_exhaustive_vs_approximate() {
        let planner = QueryPlanner::new(CostModel::default(), create_test_stats());

        let query = QueryCharacteristics {
            k: 10,
            dimensions: 128,
            min_recall: 0.95,
            max_latency_ms: 10.0,
            query_type: VectorQueryType::Single,
        };

        let plan = planner.plan(&query).unwrap();
        assert_ne!(plan.strategy, QueryStrategy::ExhaustiveScan);
    }
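
    // An added, illustrative check of the GPU fallback: with the default cost
    // model (gpu_available = false) the GpuAccelerated strategy is priced at
    // infinity, so the planner should always pick some other available
    // strategy. The index statistics are synthetic.
    #[test]
    fn test_gpu_strategy_skipped_without_gpu() {
        let mut stats = create_test_stats();
        stats.available_indices.push(QueryStrategy::GpuAccelerated);

        let planner = QueryPlanner::new(CostModel::default(), stats);

        let query = QueryCharacteristics {
            k: 10,
            dimensions: 128,
            min_recall: 0.90,
            max_latency_ms: 100.0,
            query_type: VectorQueryType::Single,
        };

        let plan = planner.plan(&query).unwrap();
        assert_ne!(plan.strategy, QueryStrategy::GpuAccelerated);
    }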

    #[test]
    fn test_batch_query_planning() {
        let planner = QueryPlanner::new(CostModel::default(), create_test_stats());

        let query = QueryCharacteristics {
            k: 10,
            dimensions: 128,
            min_recall: 0.90,
            max_latency_ms: 100.0,
            query_type: VectorQueryType::Batch(100),
        };

        let plan = planner.plan(&query).unwrap();
        assert!(plan.estimated_cost_us > 0.0);
    }
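
    // An added, illustrative check of the query-type adjustment in
    // estimate_strategy: under this module's cost model a streaming query is
    // priced at 1.2x the equivalent single query for the same strategy. The
    // query values are synthetic.
    #[test]
    fn test_streaming_cost_overhead() {
        let planner = QueryPlanner::new(CostModel::default(), create_test_stats());

        let single = QueryCharacteristics {
            k: 10,
            dimensions: 128,
            min_recall: 0.90,
            max_latency_ms: 100.0,
            query_type: VectorQueryType::Single,
        };
        let streaming = QueryCharacteristics {
            query_type: VectorQueryType::Streaming,
            ..single.clone()
        };

        let (single_cost, _) = planner.estimate_strategy(QueryStrategy::ExhaustiveScan, &single);
        let (streaming_cost, _) =
            planner.estimate_strategy(QueryStrategy::ExhaustiveScan, &streaming);
        assert!((streaming_cost - single_cost * 1.2).abs() < 1e-6);
    }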

    #[test]
    fn test_statistics_update() {
        let mut planner = QueryPlanner::new(CostModel::default(), create_test_stats());

        planner.update_statistics(QueryStrategy::HnswApproximate, 5.0, 0.96);

        assert_eq!(
            planner
                .index_stats
                .avg_latencies
                .get(&QueryStrategy::HnswApproximate),
            Some(&5.0)
        );
        assert_eq!(
            planner
                .index_stats
                .avg_recalls
                .get(&QueryStrategy::HnswApproximate),
            Some(&0.96)
        );
    }
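
    // An added, illustrative check of the confidence heuristic: the planner
    // reports 0.5 for a strategy with no recorded latency and 0.9 once one has
    // been recorded via update_statistics. The latency/recall numbers are
    // synthetic; the test calls the private calculate_confidence helper.
    #[test]
    fn test_confidence_improves_with_observations() {
        let mut planner = QueryPlanner::new(CostModel::default(), create_test_stats());

        let before = planner.calculate_confidence(QueryStrategy::HnswApproximate);
        assert!((before - 0.5).abs() < f32::EPSILON);

        planner.update_statistics(QueryStrategy::HnswApproximate, 4.2, 0.95);

        let after = planner.calculate_confidence(QueryStrategy::HnswApproximate);
        assert!((after - 0.9).abs() < f32::EPSILON);
    }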
}