use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::time::Duration;
use crate::store::{Query, VecStore};
/// A single tuning suggestion produced while analyzing a query.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OptimizationHint {
    /// Area of the system the suggestion applies to.
    pub category: HintCategory,
    /// Human-readable description of the suggested change.
    pub suggestion: String,
    /// Coarse rating of how much the change is expected to matter.
    pub impact: Impact,
    /// Estimated improvement. The hints generated in this file use values
    /// between 5.0 and 90.0 (presumably a percentage — confirm with consumers).
    pub estimated_improvement: f32,
}
/// Classifies what kind of change an [`OptimizationHint`] recommends.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum HintCategory {
    /// Adding or switching an index (metadata index, approximate index).
    Index,
    /// Adjusting the query itself, e.g. `k` or vector normalization.
    QueryParam,
    /// Changing how the query filters candidates.
    Filter,
    /// Reducing vector dimensionality.
    Dimension,
    /// Grouping several queries into batch operations.
    Batching,
}
/// Coarse rating of how much an [`OptimizationHint`] is expected to help.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum Impact {
    /// Large expected benefit.
    High,
    /// Moderate expected benefit.
    Medium,
    /// Small expected benefit.
    Low,
}
/// Estimated query cost, split by processing stage.
///
/// All values are unitless heuristic scores, not measured times.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CostBreakdown {
    /// Cost of comparing the query vector against candidate vectors.
    pub similarity_cost: f32,
    /// Cost of evaluating the metadata filter.
    pub filter_cost: f32,
    /// Cost attributed to index lookups.
    pub index_cost: f32,
    /// Cost of selecting the top results.
    pub sorting_cost: f32,
    /// Sum of the four component costs as recorded when the breakdown was built.
    pub total_cost: f32,
}
impl CostBreakdown {
    /// Recomputes the overall cost by adding the four component costs.
    ///
    /// The stored `total_cost` field is deliberately ignored.
    fn total(&self) -> f32 {
        let Self {
            similarity_cost,
            filter_cost,
            index_cost,
            sorting_cost,
            total_cost: _,
        } = self;
        similarity_cost + filter_cost + index_cost + sorting_cost
    }
}
/// Ordered description of how a query is expected to be executed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutionPlan {
    /// Processing stages in execution order.
    pub steps: Vec<PlanStep>,
    /// Row-count estimates: the store size first, followed by the estimate
    /// after each stage that narrows the candidate set.
    pub estimated_rows: Vec<usize>,
    /// Whether the plan is expected to use an index.
    pub uses_index: bool,
}
/// One stage of an [`ExecutionPlan`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PlanStep {
    /// Short stage name, e.g. `"Filter"` or `"Top-K"`.
    pub name: String,
    /// Human-readable explanation of what the stage does.
    pub description: String,
    /// Heuristic cost assigned to the stage.
    pub cost: f32,
}
/// Complete result of analyzing a single query.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryAnalysis {
    /// Overall heuristic cost of the query.
    pub estimated_cost: f32,
    /// The same cost split by processing stage.
    pub cost_breakdown: CostBreakdown,
    /// Suggestions for making the query cheaper.
    pub hints: Vec<OptimizationHint>,
    /// Expected sequence of processing stages.
    pub execution_plan: ExecutionPlan,
    /// Coarse classification derived from the estimated cost.
    pub complexity: QueryComplexity,
}
/// Coarse classification of a query by its estimated cost.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum QueryComplexity {
    /// Cheapest class of query.
    Simple,
    /// Middle class of query.
    Moderate,
    /// Most expensive class of query.
    Complex,
}
/// Heuristic query analyzer bound to a borrowed [`VecStore`].
pub struct QueryOptimizer<'a> {
    /// Store whose size drives the cost estimates.
    store: &'a VecStore,
}
impl<'a> QueryOptimizer<'a> {
pub fn new(store: &'a VecStore) -> Self {
Self { store }
}
/// Estimates the cost of running `query` against the store and collects
/// optimization hints and an execution plan for it.
///
/// The vector dimensionality used for the estimate is taken from the query
/// vector itself, so high-dimensional queries are costed (and hinted)
/// accordingly. A query with an empty vector falls back to 128 dimensions.
///
/// Complexity is `Simple` below a total cost of 10, `Moderate` below 100 and
/// `Complex` otherwise.
///
/// # Errors
///
/// No step in this function currently fails; the `Result` is part of the
/// public signature and is kept for callers.
pub fn analyze_query(&self, query: &Query) -> Result<QueryAnalysis> {
    // Dimension assumed when the query carries no vector to measure.
    const FALLBACK_DIM: usize = 128;

    let store_size = self.store.len();
    let vector_dim = match query.vector.len() {
        0 => FALLBACK_DIM,
        dim => dim,
    };

    let cost_breakdown = self.estimate_costs(query, store_size, vector_dim);
    let total_cost = cost_breakdown.total();
    let execution_plan = self.generate_execution_plan(query, store_size);
    let hints = self.generate_hints(query, store_size, vector_dim, &cost_breakdown);

    let complexity = match total_cost {
        cost if cost < 10.0 => QueryComplexity::Simple,
        cost if cost < 100.0 => QueryComplexity::Moderate,
        _ => QueryComplexity::Complex,
    };

    Ok(QueryAnalysis {
        estimated_cost: total_cost,
        cost_breakdown,
        hints,
        execution_plan,
        complexity,
    })
}
/// Builds the heuristic per-stage cost estimate for `query`.
///
/// * Similarity: `0.001 * vector_dim` per compared vector; a filter is assumed
///   to halve the number of vectors compared.
/// * Filter: `0.0005` per stored vector when a filter is present.
/// * Index: `log2(store_size) * 0.001` when a filter is present.
/// * Sorting: `log2(compared * k) * 0.002` when more vectors are compared than
///   results requested.
///
/// Every component is finite and non-negative, including for an empty store or
/// `k == 0`.
fn estimate_costs(&self, query: &Query, store_size: usize, vector_dim: usize) -> CostBreakdown {
    let has_filter = query.filter.is_some();
    let base_comparison_cost = 0.001 * vector_dim as f32;

    let vectors_to_compare = if has_filter {
        store_size / 2
    } else {
        store_size
    };
    let similarity_cost = vectors_to_compare as f32 * base_comparison_cost;

    let filter_cost = if has_filter {
        store_size as f32 * 0.0005
    } else {
        0.0
    };

    // log2(0) is -inf, which would poison the total; an empty store has no
    // index to consult, so it costs nothing.
    let index_cost = if has_filter && store_size > 0 {
        (store_size as f32).log2() * 0.001
    } else {
        0.0
    };

    // With k == 0 the product below is 0 and log2 would again yield -inf;
    // selecting zero results requires no sorting work.
    let k = query.k;
    let sorting_cost = if k > 0 && vectors_to_compare > k {
        (vectors_to_compare as f32 * k as f32).log2() * 0.002
    } else {
        0.0
    };

    CostBreakdown {
        similarity_cost,
        filter_cost,
        index_cost,
        sorting_cost,
        total_cost: similarity_cost + filter_cost + index_cost + sorting_cost,
    }
}
fn generate_hints(
&self,
query: &Query,
store_size: usize,
vector_dim: usize,
costs: &CostBreakdown,
) -> Vec<OptimizationHint> {
let mut hints = Vec::new();
if query.k > 100 {
hints.push(OptimizationHint {
category: HintCategory::QueryParam,
suggestion: format!(
"Consider reducing k from {} to 100 or less. Large K values increase memory and sorting overhead.",
query.k
),
impact: Impact::Medium,
estimated_improvement: 20.0,
});
}
if query.filter.is_some() && store_size > 1000 {
hints.push(OptimizationHint {
category: HintCategory::Index,
suggestion: "Add metadata index for filtered fields to speed up filtering. Use MetadataIndexManager to create indexes.".to_string(),
impact: Impact::High,
estimated_improvement: 70.0,
});
}
if vector_dim > 512 {
hints.push(OptimizationHint {
category: HintCategory::Dimension,
suggestion: format!(
"Consider dimensionality reduction from {} to 128-256 dimensions using PCA. This can speed up similarity computation by 2-4x.",
vector_dim
),
impact: Impact::High,
estimated_improvement: 60.0,
});
}
if costs.similarity_cost > costs.total_cost * 0.8 && store_size > 10000 {
hints.push(OptimizationHint {
category: HintCategory::Index,
suggestion: "Similarity computation dominates cost. Consider using IVF-PQ or LSH indexing for approximate search on large datasets.".to_string(),
impact: Impact::High,
estimated_improvement: 90.0,
});
}
if store_size > 5000 {
hints.push(OptimizationHint {
category: HintCategory::Batching,
suggestion: "For multiple queries, use batch operations to amortize index lookup costs across queries.".to_string(),
impact: Impact::Medium,
estimated_improvement: 30.0,
});
}
let vector = &query.vector;
let magnitude: f32 = vector.iter().map(|x| x * x).sum::<f32>().sqrt();
if (magnitude - 1.0).abs() > 0.1 {
hints.push(OptimizationHint {
category: HintCategory::QueryParam,
suggestion: format!(
"Query vector is not normalized (magnitude: {:.3}). Normalize vectors for cosine similarity to improve accuracy.",
magnitude
),
impact: Impact::Low,
estimated_improvement: 5.0,
});
}
hints
}
/// Describes the stages a query is expected to go through.
///
/// The plan always contains a "Similarity" and a "Top-K" step; a "Filter" step
/// is prepended when the query has a filter, in which case the candidate set is
/// assumed to be halved and the plan is flagged as using an index.
///
/// `estimated_rows` starts with the store size and then records the estimate
/// after the filter (if any) and after top-k selection. The final estimate is
/// capped at the number of candidates, because top-k selection cannot return
/// more rows than it was given.
fn generate_execution_plan(&self, query: &Query, store_size: usize) -> ExecutionPlan {
    let has_filter = query.filter.is_some();
    let mut steps = Vec::with_capacity(3);
    let mut estimated_rows = vec![store_size];
    // Number of vectors that reach the similarity stage.
    let mut candidates = store_size;

    if has_filter {
        steps.push(PlanStep {
            name: "Filter".to_string(),
            description: "Apply metadata filter to reduce candidate set".to_string(),
            cost: 0.5,
        });
        candidates = store_size / 2;
        estimated_rows.push(candidates);
    }

    steps.push(PlanStep {
        name: "Similarity".to_string(),
        description: format!("Compute similarity for {} vectors", candidates),
        cost: candidates as f32 * 0.001,
    });

    let k = query.k;
    steps.push(PlanStep {
        name: "Top-K".to_string(),
        description: format!("Select top {} results", k),
        cost: 0.1,
    });
    estimated_rows.push(k.min(candidates));

    ExecutionPlan {
        steps,
        estimated_rows,
        uses_index: has_filter,
    }
}
/// Analyzes both queries and reports which one is estimated to be cheaper.
///
/// `faster_query` is `1` when `query1` is strictly cheaper and `2` otherwise
/// (including ties). `relative_difference` is the absolute cost difference
/// divided by the cheaper query's cost. When that cost is not positive (for
/// example on an empty store) the more expensive query's cost is used as the
/// baseline instead, and if no positive, finite baseline exists the relative
/// difference is reported as `0.0` rather than NaN or infinity.
///
/// A relative difference above 0.3 is reported as significant.
///
/// # Errors
///
/// Propagates any error returned by [`Self::analyze_query`].
pub fn compare_queries(&self, query1: &Query, query2: &Query) -> Result<QueryComparison> {
    let query1_cost = self.analyze_query(query1)?.estimated_cost;
    let query2_cost = self.analyze_query(query2)?.estimated_cost;

    let faster_query = if query1_cost < query2_cost { 1 } else { 2 };
    let cost_difference = (query1_cost - query2_cost).abs();

    // Dividing by a zero cost would yield NaN (0/0) or infinity, so pick a
    // baseline that is guaranteed positive or skip the division entirely.
    let cheaper_cost = query1_cost.min(query2_cost);
    let baseline = if cheaper_cost > 0.0 {
        cheaper_cost
    } else {
        query1_cost.max(query2_cost)
    };
    let relative_difference = if baseline > 0.0 && cost_difference.is_finite() {
        cost_difference / baseline
    } else {
        0.0
    };

    let recommendation = if relative_difference > 0.3 {
        format!(
            "Query {} is significantly faster ({:.1}% improvement)",
            faster_query,
            relative_difference * 100.0
        )
    } else {
        "Both queries have similar performance".to_string()
    };

    Ok(QueryComparison {
        query1_cost,
        query2_cost,
        faster_query,
        cost_difference,
        relative_difference,
        recommendation,
    })
}
/// Summarizes store-wide tuning advice based solely on the store's size.
///
/// Recommendations are cumulative: a store that exceeds a larger threshold also
/// receives the advice attached to every smaller threshold.
pub fn store_optimization_summary(&self) -> StoreOptimizationSummary {
    // (size threshold, advice) pairs, listed in the order they are reported.
    const SIZE_ADVICE: [(usize, &str); 3] = [
        (
            100000,
            "Consider partitioning large dataset by metadata for faster queries",
        ),
        (
            50000,
            "Use approximate indexes (IVF-PQ, LSH) for better scaling",
        ),
        (
            10000,
            "Add metadata indexes for frequently filtered fields",
        ),
    ];

    let store_size = self.store.len();
    let recommendations: Vec<String> = SIZE_ADVICE
        .iter()
        .filter(|(threshold, _)| store_size > *threshold)
        .map(|(_, advice)| (*advice).to_string())
        .collect();
    let estimated_query_time = self.estimate_avg_query_time(store_size);

    StoreOptimizationSummary {
        store_size,
        estimated_query_time,
        recommendations,
    }
}
/// Rough average query latency for a store holding `store_size` vectors.
///
/// The model charges 0.001 ms per stored vector with a floor of 0.1 ms. The
/// result keeps sub-millisecond precision, so the floor is actually observable
/// (a whole-millisecond conversion would truncate it, and every estimate for
/// stores under 1000 vectors, to zero).
fn estimate_avg_query_time(&self, store_size: usize) -> Duration {
    // f64 keeps the store size exact well beyond f32's 2^24 integer limit.
    let millis = (store_size as f64 * 0.001).max(0.1);
    // `millis` is finite and positive, so `from_secs_f64` cannot panic here.
    Duration::from_secs_f64(millis / 1000.0)
}
}
/// Outcome of comparing the estimated costs of two queries.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryComparison {
    /// Estimated cost of the first query.
    pub query1_cost: f32,
    /// Estimated cost of the second query.
    pub query2_cost: f32,
    /// `1` or `2`, identifying the query estimated to be cheaper.
    pub faster_query: u8,
    /// Absolute difference between the two estimated costs.
    pub cost_difference: f32,
    /// Cost difference relative to the cheaper query's cost.
    pub relative_difference: f32,
    /// Human-readable verdict on the comparison.
    pub recommendation: String,
}
/// Store-wide tuning advice derived from the store's size.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StoreOptimizationSummary {
    /// Number of vectors in the store when the summary was produced.
    pub store_size: usize,
    /// Heuristic estimate of the average query latency.
    pub estimated_query_time: Duration,
    /// Suggested store-level optimizations, most drastic first.
    pub recommendations: Vec<String>,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Metadata;
    use std::collections::HashMap;
    use tempfile::TempDir;

    /// Builds a store holding 100 vectors of 128 dimensions, each tagged with
    /// `category = "test"`.
    ///
    /// The `TempDir` guard is returned together with the store: dropping it
    /// deletes the backing directory, so callers must keep it alive for as long
    /// as they use the store.
    fn create_test_store() -> Result<(TempDir, VecStore)> {
        let temp_dir = TempDir::new()?;
        let mut store = VecStore::open(temp_dir.path().join("test.db"))?;
        for i in 0..100 {
            let mut metadata = Metadata {
                fields: HashMap::new(),
            };
            metadata
                .fields
                .insert("category".to_string(), serde_json::json!("test"));
            store.upsert(format!("doc{}", i), vec![i as f32 * 0.01; 128], metadata)?;
        }
        Ok((temp_dir, store))
    }

    #[test]
    fn test_basic_analysis() -> Result<()> {
        let (_temp_dir, store) = create_test_store()?;
        let optimizer = QueryOptimizer::new(&store);
        let query = Query::new(vec![0.5; 128]).with_limit(10);
        let analysis = optimizer.analyze_query(&query)?;
        assert!(analysis.estimated_cost > 0.0);
        assert!(matches!(
            analysis.complexity,
            QueryComplexity::Simple | QueryComplexity::Moderate
        ));
        Ok(())
    }

    #[test]
    fn test_filter_hint() -> Result<()> {
        let (_temp_dir, store) = create_test_store()?;
        let optimizer = QueryOptimizer::new(&store);
        let query = Query::new(vec![0.5; 128])
            .with_limit(10)
            .with_filter("category = 'test'");
        let analysis = optimizer.analyze_query(&query)?;
        // The metadata-index hint only fires above 1000 vectors, so with this
        // 100-vector store the hint list is non-empty for other reasons (the
        // query vector is not normalized). Checking the filter cost makes sure
        // the filter is at least reflected in the analysis.
        assert!(!analysis.hints.is_empty());
        assert!(analysis.cost_breakdown.filter_cost > 0.0);
        Ok(())
    }

    #[test]
    fn test_large_k_hint() -> Result<()> {
        let (_temp_dir, store) = create_test_store()?;
        let optimizer = QueryOptimizer::new(&store);
        let query = Query::new(vec![0.5; 128]).with_limit(200);
        let analysis = optimizer.analyze_query(&query)?;
        // Matching on the category alone is not enough: the normalization hint
        // shares `QueryParam`, so also require the k-specific wording.
        let has_k_hint = analysis.hints.iter().any(|h| {
            h.category == HintCategory::QueryParam && h.suggestion.contains("reducing k")
        });
        assert!(has_k_hint);
        Ok(())
    }

    #[test]
    fn test_execution_plan() -> Result<()> {
        let (_temp_dir, store) = create_test_store()?;
        let optimizer = QueryOptimizer::new(&store);
        let query = Query::new(vec![0.5; 128])
            .with_limit(10)
            .with_filter("category = 'test'");
        let analysis = optimizer.analyze_query(&query)?;
        assert!(!analysis.execution_plan.steps.is_empty());
        assert!(analysis.execution_plan.uses_index);
        Ok(())
    }

    #[test]
    fn test_query_comparison() -> Result<()> {
        let (_temp_dir, store) = create_test_store()?;
        let optimizer = QueryOptimizer::new(&store);
        let query1 = Query::new(vec![0.5; 128]).with_limit(10);
        let query2 = Query::new(vec![0.5; 128]).with_limit(100);
        let comparison = optimizer.compare_queries(&query1, &query2)?;
        assert!(comparison.faster_query == 1 || comparison.faster_query == 2);
        assert!(comparison.cost_difference >= 0.0);
        Ok(())
    }
}