use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Execution strategy a vector-similarity query can be planned onto.
///
/// Which strategies are actually usable is reported at runtime through
/// `IndexStatistics::available_indices`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum QueryStrategy {
    /// Brute-force scan over every vector; exact results.
    ExhaustiveScan,
    /// Hierarchical Navigable Small World graph search; approximate.
    HnswApproximate,
    /// Navigating Spreading-out Graph search; approximate.
    NsgApproximate,
    /// Inverted-file coarse quantizer probing; approximate.
    IvfCoarse,
    /// Scan over product-quantized codes; compressed, approximate.
    ProductQuantization,
    /// Scan over scalar-quantized vectors; compressed, approximate.
    ScalarQuantization,
    /// Hash-bucket probing; fastest, lowest expected recall.
    LocalitySensitiveHashing,
    /// Full scan offloaded to a GPU (only runnable when one is available).
    GpuAccelerated,
    /// Approximate candidate generation followed by exact re-ranking.
    Hybrid,
}
/// Hardware/workload cost parameters used to price candidate strategies.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CostModel {
    /// Cost of one full-precision distance computation, in microseconds.
    pub distance_computation_cost_us: f64,
    /// Cost of one index-structure lookup, in microseconds.
    // NOTE(review): not read by the current estimator — confirm intended use.
    pub index_lookup_cost_us: f64,
    /// Cost of one memory access, in nanoseconds.
    // NOTE(review): not read by the current estimator — confirm intended use.
    pub memory_access_cost_ns: f64,
    /// Whether a GPU backend can be used for `GpuAccelerated` plans.
    pub gpu_available: bool,
    /// Multiplier applied to the CPU scan cost when running on GPU
    /// (values below 1.0 model a speedup).
    pub gpu_cost_multiplier: f64,
}
impl Default for CostModel {
fn default() -> Self {
Self {
distance_computation_cost_us: 0.5,
index_lookup_cost_us: 0.1,
memory_access_cost_ns: 50.0,
gpu_available: false,
gpu_cost_multiplier: 0.1, }
}
}
/// Per-query requirements the planner balances cost and recall against.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryCharacteristics {
    /// Number of nearest neighbours requested.
    pub k: usize,
    /// Dimensionality of the query vector.
    // NOTE(review): not read by the cost estimator today — confirm whether
    // per-distance cost should scale with dimensions.
    pub dimensions: usize,
    /// Minimum acceptable recall in [0, 1]; the planner prefers the cheapest
    /// strategy meeting this floor.
    pub min_recall: f32,
    /// Latency budget in milliseconds.
    // NOTE(review): currently advisory only — `plan` does not enforce it.
    pub max_latency_ms: f64,
    /// Whether this is a single query, a batch, or part of a stream.
    pub query_type: VectorQueryType,
}
/// Shape of the incoming workload; affects the cost multiplier applied to a
/// strategy's base estimate.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum VectorQueryType {
    /// One standalone query.
    Single,
    /// A batch of the given number of queries, costed with an amortization
    /// discount.
    Batch(usize),
    /// Continuous stream of queries, costed with a small overhead surcharge.
    Streaming,
}
/// Snapshot of the index's state plus measured per-strategy performance,
/// used both for cost estimation and for plan confidence.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IndexStatistics {
    /// Total number of indexed vectors.
    pub vector_count: usize,
    /// Dimensionality of the indexed vectors.
    pub dimensions: usize,
    /// Strategies that can actually be executed against this index.
    pub available_indices: Vec<QueryStrategy>,
    /// Measured latency per strategy, in milliseconds (empty until
    /// `update_statistics` has been called).
    pub avg_latencies: HashMap<QueryStrategy, f64>,
    /// Measured recall per strategy; overrides the static recall estimates
    /// when present.
    pub avg_recalls: HashMap<QueryStrategy, f32>,
}
/// Output of the planner: the chosen strategy, its estimates, and ranked
/// fallbacks.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryPlan {
    /// Strategy selected for execution.
    pub strategy: QueryStrategy,
    /// Estimated execution cost, in microseconds.
    pub estimated_cost_us: f64,
    /// Estimated recall for the selected strategy.
    pub estimated_recall: f32,
    /// Confidence in the estimates (higher when backed by measurements).
    pub confidence: f32,
    /// Up to three alternative (strategy, cost_us, recall) tuples, cheapest
    /// first, excluding the selected strategy.
    pub alternatives: Vec<(QueryStrategy, f64, f32)>,
    /// Strategy-specific tuning parameters (e.g. `ef_search`, `nprobe`),
    /// serialized as strings.
    pub parameters: HashMap<String, String>,
}
/// Cost-based planner that picks a `QueryStrategy` for each query from the
/// strategies the index makes available.
pub struct QueryPlanner {
    /// Hardware cost parameters used to price strategies.
    cost_model: CostModel,
    /// Index metadata and measured per-strategy statistics.
    index_stats: IndexStatistics,
}
impl QueryPlanner {
pub fn new(cost_model: CostModel, index_stats: IndexStatistics) -> Self {
Self {
cost_model,
index_stats,
}
}
pub fn plan(&self, query: &QueryCharacteristics) -> Result<QueryPlan> {
let mut candidates = Vec::new();
for strategy in &self.index_stats.available_indices {
let (cost, recall) = self.estimate_strategy(*strategy, query);
candidates.push((*strategy, cost, recall));
}
candidates.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
let best = candidates
.iter()
.find(|(_, _, recall)| *recall >= query.min_recall)
.or_else(|| candidates.first())
.ok_or_else(|| anyhow::anyhow!("No suitable strategy found"))?;
let (strategy, cost, recall) = *best;
let parameters = self.generate_parameters(strategy, query);
let confidence = self.calculate_confidence(strategy);
Ok(QueryPlan {
strategy,
estimated_cost_us: cost,
estimated_recall: recall,
confidence,
alternatives: candidates
.iter()
.filter(|(s, _, _)| *s != strategy)
.take(3)
.copied()
.collect(),
parameters,
})
}
fn estimate_strategy(
&self,
strategy: QueryStrategy,
query: &QueryCharacteristics,
) -> (f64, f32) {
let base_cost = match strategy {
QueryStrategy::ExhaustiveScan => {
self.index_stats.vector_count as f64 * self.cost_model.distance_computation_cost_us
}
QueryStrategy::HnswApproximate => {
let hnsw_complexity = (self.index_stats.vector_count as f64).ln() * 16.0;
hnsw_complexity * self.cost_model.distance_computation_cost_us
}
QueryStrategy::NsgApproximate => {
let nsg_complexity = (self.index_stats.vector_count as f64).ln() * 12.0;
nsg_complexity * self.cost_model.distance_computation_cost_us
}
QueryStrategy::IvfCoarse => {
let ivf_probes = (self.index_stats.vector_count as f64).sqrt();
ivf_probes * self.cost_model.distance_computation_cost_us
}
QueryStrategy::ProductQuantization => {
let pq_cost = self.index_stats.vector_count as f64 * 0.1;
pq_cost * self.cost_model.distance_computation_cost_us
}
QueryStrategy::ScalarQuantization => {
let sq_cost = self.index_stats.vector_count as f64 * 0.08;
sq_cost * self.cost_model.distance_computation_cost_us
}
QueryStrategy::LocalitySensitiveHashing => {
let lsh_cost = 10.0 * 100.0; lsh_cost * self.cost_model.distance_computation_cost_us
}
QueryStrategy::GpuAccelerated => {
if self.cost_model.gpu_available {
let cpu_cost = self.index_stats.vector_count as f64
* self.cost_model.distance_computation_cost_us;
cpu_cost * self.cost_model.gpu_cost_multiplier
} else {
f64::INFINITY }
}
QueryStrategy::Hybrid => {
let hnsw_cost = (self.index_stats.vector_count as f64).ln() * 16.0;
let refinement_cost = query.k as f64 * 10.0;
(hnsw_cost + refinement_cost) * self.cost_model.distance_computation_cost_us
}
};
let cost = match query.query_type {
VectorQueryType::Single => base_cost,
VectorQueryType::Batch(n) => base_cost * n as f64 * 0.8, VectorQueryType::Streaming => base_cost * 1.2, };
let recall = self
.index_stats
.avg_recalls
.get(&strategy)
.copied()
.unwrap_or_else(|| self.estimate_recall(strategy));
(cost, recall)
}
fn estimate_recall(&self, strategy: QueryStrategy) -> f32 {
match strategy {
QueryStrategy::ExhaustiveScan => 1.0,
QueryStrategy::HnswApproximate => 0.95,
QueryStrategy::NsgApproximate => 0.96, QueryStrategy::IvfCoarse => 0.85,
QueryStrategy::ProductQuantization => 0.90,
QueryStrategy::ScalarQuantization => 0.92,
QueryStrategy::LocalitySensitiveHashing => 0.80,
QueryStrategy::GpuAccelerated => 0.95,
QueryStrategy::Hybrid => 0.98,
}
}
fn generate_parameters(
&self,
strategy: QueryStrategy,
query: &QueryCharacteristics,
) -> HashMap<String, String> {
let mut params = HashMap::new();
match strategy {
QueryStrategy::HnswApproximate => {
let ef_search = if query.min_recall >= 0.95 {
(query.k * 4).max(64)
} else {
(query.k * 2).max(32)
};
params.insert("ef_search".to_string(), ef_search.to_string());
}
QueryStrategy::NsgApproximate => {
let search_length = if query.min_recall >= 0.95 {
(query.k * 5).max(50)
} else {
(query.k * 3).max(30)
};
params.insert("search_length".to_string(), search_length.to_string());
params.insert("out_degree".to_string(), "32".to_string());
}
QueryStrategy::IvfCoarse => {
let nprobe = if query.min_recall >= 0.90 { 16 } else { 8 };
params.insert("nprobe".to_string(), nprobe.to_string());
}
QueryStrategy::LocalitySensitiveHashing => {
params.insert("num_probes".to_string(), "3".to_string());
}
_ => {}
}
params
}
fn calculate_confidence(&self, strategy: QueryStrategy) -> f32 {
if self.index_stats.avg_latencies.contains_key(&strategy) {
0.9
} else {
0.5 }
}
pub fn update_statistics(&mut self, strategy: QueryStrategy, latency_ms: f64, recall: f32) {
self.index_stats.avg_latencies.insert(strategy, latency_ms);
self.index_stats.avg_recalls.insert(strategy, recall);
}
pub fn update_index_metadata(&mut self, vector_count: usize, dimensions: usize) {
self.index_stats.vector_count = vector_count;
self.index_stats.dimensions = dimensions;
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture: 100k 128-dimensional vectors with exact, HNSW and IVF
    /// indices available and no measured statistics yet.
    fn create_test_stats() -> IndexStatistics {
        let available = vec![
            QueryStrategy::ExhaustiveScan,
            QueryStrategy::HnswApproximate,
            QueryStrategy::IvfCoarse,
        ];
        IndexStatistics {
            vector_count: 100_000,
            dimensions: 128,
            available_indices: available,
            avg_latencies: HashMap::new(),
            avg_recalls: HashMap::new(),
        }
    }

    /// Convenience constructor for a k=10, 128-dimensional query.
    fn make_query(
        min_recall: f32,
        max_latency_ms: f64,
        query_type: VectorQueryType,
    ) -> QueryCharacteristics {
        QueryCharacteristics {
            k: 10,
            dimensions: 128,
            min_recall,
            max_latency_ms,
            query_type,
        }
    }

    #[test]
    fn test_query_planner_creation() {
        // Construction alone must not panic.
        let _planner = QueryPlanner::new(CostModel::default(), create_test_stats());
    }

    #[test]
    fn test_query_planning() -> Result<()> {
        let planner = QueryPlanner::new(CostModel::default(), create_test_stats());
        let query = make_query(0.90, 100.0, VectorQueryType::Single);
        let result = planner.plan(&query);
        assert!(result.is_ok());
        let plan = result?;
        // The chosen strategy must satisfy the recall floor, and the other
        // available strategies must show up as alternatives.
        assert!(plan.estimated_recall >= query.min_recall);
        assert!(!plan.alternatives.is_empty());
        Ok(())
    }

    #[test]
    fn test_exhaustive_vs_approximate() -> Result<()> {
        let planner = QueryPlanner::new(CostModel::default(), create_test_stats());
        // A 95%-recall query over 100k vectors is far cheaper via HNSW than
        // via brute force, so the exhaustive scan must not be chosen.
        let query = make_query(0.95, 10.0, VectorQueryType::Single);
        let plan = planner.plan(&query)?;
        assert_ne!(plan.strategy, QueryStrategy::ExhaustiveScan);
        Ok(())
    }

    #[test]
    fn test_batch_query_planning() -> Result<()> {
        let planner = QueryPlanner::new(CostModel::default(), create_test_stats());
        let query = make_query(0.90, 100.0, VectorQueryType::Batch(100));
        let plan = planner.plan(&query)?;
        assert!(plan.estimated_cost_us > 0.0);
        Ok(())
    }

    #[test]
    fn test_statistics_update() {
        let mut planner = QueryPlanner::new(CostModel::default(), create_test_stats());
        planner.update_statistics(QueryStrategy::HnswApproximate, 5.0, 0.96);
        // The first observation for a strategy is stored as-is.
        let latency = planner
            .index_stats
            .avg_latencies
            .get(&QueryStrategy::HnswApproximate);
        assert_eq!(latency, Some(&5.0));
        let recall = planner
            .index_stats
            .avg_recalls
            .get(&QueryStrategy::HnswApproximate);
        assert_eq!(recall, Some(&0.96));
    }
}