use lancedb::Table;
use lancedb::{query::VectorQuery, DistanceType};
/// Vector-index configuration recommended by [`VectorOptimizer::calculate_index_params`].
#[derive(Debug, Clone)]
pub struct VectorIndexParams {
/// Whether an index should be built at all; `false` for datasets the optimizer
/// considers small enough to search by brute force.
pub should_create_index: bool,
/// Number of partitions the index should use (0 when no index is recommended).
pub num_partitions: u32,
/// Number of sub-vectors each vector is split into; normally chosen to evenly divide
/// the vector dimension (0 when no index is recommended).
pub num_sub_vectors: u32,
/// Bit width for the index's quantized codes — presumably per sub-vector; confirm
/// against the index builder. The optimizer currently always uses 8.
pub num_bits: u8,
/// Distance metric for the index; the optimizer always selects cosine.
pub distance_type: DistanceType,
}
/// Query-time tuning recommended by [`VectorOptimizer::calculate_search_params`].
#[derive(Debug, Clone)]
pub struct SearchParams {
/// Number of index partitions to probe per query (applied via `VectorQuery::nprobes`).
pub nprobes: usize,
/// Refine factor applied via `VectorQuery::refine_factor` when `Some`;
/// `None` leaves the query's default untouched.
pub refine_factor: Option<u32>,
}
/// Stateless namespace for heuristics that size vector indexes and tune vector
/// search from the dataset's row count and vector dimension.
pub struct VectorOptimizer;
impl VectorOptimizer {
    /// Fewest partitions the optimizer will ever recommend or assume.
    const MIN_PARTITIONS: u32 = 2;
    /// Most partitions the optimizer will ever recommend.
    const MAX_PARTITIONS: u32 = 1024;
    /// Datasets with fewer rows than this are not indexed (brute-force search).
    const MIN_ROWS_FOR_INDEX: usize = 1_000;
    /// Rows-per-partition cap used when sizing partitions for datasets of 100k+ rows
    /// (before the `MAX_PARTITIONS` clamp is applied).
    const MAX_ROWS_PER_PARTITION: usize = 8_000;
    /// Quantization bit width; the same value is used at every dataset size.
    const NUM_BITS: u8 = 8;
    /// Column whose index `optimize_query` looks for.
    const EMBEDDING_COLUMN: &'static str = "embedding";

    /// Tunes `query`'s search parameters (`nprobes`, `refine_factor`) for `table`.
    ///
    /// When `table` has an index whose only column is `embedding`, the partition count is
    /// estimated from the current row count using the same formula
    /// [`Self::calculate_index_params`] uses when building an index. The probe and refine
    /// settings then come from [`Self::calculate_search_params`]. Without such an index the
    /// query is returned unchanged.
    ///
    /// `table_name` is used only for debug logging.
    ///
    /// # Errors
    ///
    /// Propagates any error from `Table::count_rows` or `Table::list_indices`.
    pub async fn optimize_query(
        mut query: VectorQuery,
        table: &Table,
        table_name: &str,
    ) -> Result<VectorQuery, lancedb::Error> {
        let row_count = table.count_rows(None).await?;
        let indices = table.list_indices().await?;
        let has_index = indices
            .iter()
            .any(|idx| idx.columns == [Self::EMBEDDING_COLUMN]);

        if !has_index {
            tracing::debug!(
                "No index found for {}, using default search (rows={}, has_index={})",
                table_name,
                row_count,
                has_index
            );
            return Ok(query);
        }

        // `calculate_index_params` never recommends an index below MIN_ROWS_FOR_INDEX, so
        // an index on a table this small came from elsewhere (or rows were since removed);
        // assume the minimal layout. Otherwise mirror the index-creation formula so the
        // probe count is derived from the partition count the index would actually have.
        let estimated_partitions = if row_count < Self::MIN_ROWS_FOR_INDEX {
            Self::MIN_PARTITIONS
        } else {
            Self::partitions_for_rows(row_count)
        };
        let search_params = Self::calculate_search_params(estimated_partitions, row_count);
        query = query.nprobes(search_params.nprobes);
        if let Some(refine_factor) = search_params.refine_factor {
            query = query.refine_factor(refine_factor);
        }
        tracing::debug!(
            "Applied search optimization to {}: nprobes={}, refine_factor={:?}, rows={}, has_index={}",
            table_name,
            search_params.nprobes,
            search_params.refine_factor,
            row_count,
            has_index
        );
        Ok(query)
    }

    /// Number of index partitions to use for `row_count` rows.
    ///
    /// The base count depends on dataset size:
    /// - below 10k rows: `sqrt(rows) / 2`
    /// - below 100k rows: `sqrt(rows)`
    /// - otherwise: the larger of `rows / MAX_ROWS_PER_PARTITION` and `sqrt(rows)`
    ///
    /// The result is always clamped to `MIN_PARTITIONS..=MAX_PARTITIONS`.
    fn partitions_for_rows(row_count: usize) -> u32 {
        // `f64 as u32` saturates, so huge row counts cannot wrap here.
        let sqrt_rows = (row_count as f64).sqrt() as u32;
        let partitions = if row_count < 10_000 {
            sqrt_rows / 2
        } else if row_count < 100_000 {
            sqrt_rows
        } else {
            // Divide before narrowing so row counts beyond u32::MAX do not truncate.
            let by_size = u32::try_from(row_count / Self::MAX_ROWS_PER_PARTITION)
                .unwrap_or(u32::MAX);
            by_size.max(sqrt_rows)
        };
        partitions.clamp(Self::MIN_PARTITIONS, Self::MAX_PARTITIONS)
    }

    /// Recommends index parameters for a dataset of `row_count` vectors of
    /// `vector_dimension` components.
    ///
    /// Below `MIN_ROWS_FOR_INDEX` rows no index is recommended
    /// (`should_create_index == false`, zero partitions and sub-vectors). Otherwise:
    /// - the partition count comes from [`Self::partitions_for_rows`];
    /// - the sub-vector count targets about 16 dimensions per sub-vector, adjusted by
    ///   [`Self::find_optimal_sub_vectors`].
    ///
    /// Cosine distance and 8-bit codes are always used.
    pub fn calculate_index_params(row_count: usize, vector_dimension: usize) -> VectorIndexParams {
        if row_count < Self::MIN_ROWS_FOR_INDEX {
            tracing::debug!(
                "Dataset size {} is small, skipping index creation (brute force will be faster)",
                row_count
            );
            return VectorIndexParams {
                should_create_index: false,
                num_partitions: 0,
                num_sub_vectors: 0,
                num_bits: Self::NUM_BITS,
                distance_type: DistanceType::Cosine,
            };
        }

        let num_partitions = Self::partitions_for_rows(row_count);
        let base_sub_vectors = std::cmp::max(1, vector_dimension / 16);
        let num_sub_vectors = Self::find_optimal_sub_vectors(vector_dimension, base_sub_vectors);
        let num_bits = Self::NUM_BITS;
        tracing::debug!(
            "Calculated index params for {} rows, {} dimensions: partitions={}, sub_vectors={}, bits={}",
            row_count,
            vector_dimension,
            num_partitions,
            num_sub_vectors,
            num_bits
        );
        VectorIndexParams {
            should_create_index: true,
            num_partitions,
            num_sub_vectors,
            num_bits,
            distance_type: DistanceType::Cosine,
        }
    }

    /// Recommends query-time parameters for an index with `num_partitions` partitions over
    /// `row_count` rows.
    ///
    /// `nprobes` is never below `max(1, num_partitions / 20)`. Above that floor, it is about
    /// `num_partitions / 7` for datasets under 10k rows and `num_partitions / 10` otherwise.
    ///
    /// `refine_factor` is:
    /// - `Some(20)` above 100k rows;
    /// - `Some(10)` above 10k rows;
    /// - `None` otherwise.
    pub fn calculate_search_params(num_partitions: u32, row_count: usize) -> SearchParams {
        let min_nprobes = std::cmp::max(1, num_partitions / 20);
        let max_nprobes = std::cmp::max(min_nprobes, num_partitions / 7);
        // Smaller datasets probe a larger share of partitions.
        let nprobes = if row_count < 10_000 {
            max_nprobes as usize
        } else {
            std::cmp::max(min_nprobes, num_partitions / 10) as usize
        };
        let refine_factor = if row_count > 100_000 {
            Some(20)
        } else if row_count > 10_000 {
            Some(10)
        } else {
            None
        };
        tracing::debug!(
            "Calculated search params for {} partitions, {} rows: nprobes={}, refine_factor={:?}",
            num_partitions,
            row_count,
            nprobes,
            refine_factor
        );
        SearchParams {
            nprobes,
            refine_factor,
        }
    }

    /// Chooses how many sub-vectors to split a `dimension`-wide vector into.
    ///
    /// A divisor of `dimension` is acceptable when the resulting sub-vector width is a
    /// multiple of 4 or at most 8. The search order is:
    /// 1. the largest acceptable divisor `<= target`;
    /// 2. failing that, the smallest acceptable divisor above `target` — at worst
    ///    `dimension` itself, since width 1 is always acceptable.
    ///
    /// The result is always at least 1 and never panics (including `dimension == 0`).
    fn find_optimal_sub_vectors(dimension: usize, target: usize) -> u32 {
        // Candidates start at 1, so the modulo/division below never divides by zero.
        let is_acceptable = |candidate: usize| {
            dimension % candidate == 0 && {
                let width = dimension / candidate;
                width % 4 == 0 || width <= 8
            }
        };
        // An explicit Option (rather than a sentinel of 1) keeps a genuine match of 1
        // from being mistaken for "nothing found".
        let best = (1..=target)
            .rev()
            .find(|&c| is_acceptable(c))
            .or_else(|| (target.saturating_add(1)..=dimension).find(|&c| is_acceptable(c)))
            .unwrap_or(1);
        u32::try_from(best.clamp(1, dimension.max(1))).unwrap_or(u32::MAX)
    }

    /// Returns `true` when an existing index's layout has drifted far enough from `optimal`
    /// to be worth rebuilding.
    ///
    /// The rebuild threshold is more than 50% relative difference in partitions or more than
    /// 25% in sub-vectors. It always returns `false` when `optimal` recommends no index.
    pub fn should_recreate_index(
        current_partitions: u32,
        current_sub_vectors: u32,
        optimal: &VectorIndexParams,
    ) -> bool {
        if !optimal.should_create_index {
            return false;
        }
        // Float division: a zero optimum yields NaN (equal, not exceeded) or infinity
        // (different, exceeded) instead of panicking.
        let relative_diff =
            |current: u32, target: u32| (current as f32 - target as f32).abs() / target as f32;
        relative_diff(current_partitions, optimal.num_partitions) > 0.5
            || relative_diff(current_sub_vectors, optimal.num_sub_vectors) > 0.25
    }

    /// Returns `true` when an indexed dataset has just crossed a growth milestone and its
    /// index may deserve re-tuning.
    ///
    /// Milestones are fixed row counts up to 1M, then every multiple of 100k. A milestone
    /// only matches when `current_rows` falls within 50 rows at or above it. Callers that
    /// only observe the count after large batch inserts may therefore skip milestones.
    pub fn should_optimize_for_growth(
        current_rows: usize,
        vector_dimension: usize,
        has_embedding_index: bool,
    ) -> bool {
        const GROWTH_MILESTONES: [usize; 9] = [
            1_000, 5_000, 10_000, 25_000, 50_000, 100_000, 250_000, 500_000, 1_000_000,
        ];
        const MILESTONE_WINDOW: usize = 50;

        if !has_embedding_index {
            return false;
        }
        if !Self::calculate_index_params(current_rows, vector_dimension).should_create_index {
            return false;
        }
        if let Some(milestone) = GROWTH_MILESTONES
            .iter()
            .copied()
            .find(|&m| (m..=m + MILESTONE_WINDOW).contains(&current_rows))
        {
            tracing::info!(
                "Dataset reached {} rows milestone, considering index optimization",
                milestone
            );
            return true;
        }
        current_rows > 1_000_000 && current_rows % 100_000 <= MILESTONE_WINDOW
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_small_dataset_no_index() {
        let params = VectorOptimizer::calculate_index_params(500, 768);
        assert!(!params.should_create_index);
    }

    #[test]
    fn test_medium_dataset_creates_index() {
        let params = VectorOptimizer::calculate_index_params(5000, 768);
        assert!(params.should_create_index);
        assert!(params.num_partitions >= 2);
        assert!(params.num_sub_vectors >= 1);
        assert_eq!(params.num_bits, 8);
    }

    #[test]
    fn test_large_dataset_optimized() {
        let params = VectorOptimizer::calculate_index_params(200_000, 1536);
        assert!(params.should_create_index);
        // sqrt(200_000) ~= 447 dominates 200_000 / 8_000 = 25.
        assert!(params.num_partitions > 100);
        // 1536 / 16 = 96 sub-vectors of width 16.
        assert!(params.num_sub_vectors > 50);
        assert_eq!(params.num_bits, 8);
    }

    #[test]
    fn test_partition_bounds() {
        // sqrt(1000) = 31, halved for small datasets.
        assert_eq!(
            VectorOptimizer::calculate_index_params(1000, 768).num_partitions,
            15
        );
        // Very large datasets are capped at 1024 partitions.
        assert_eq!(
            VectorOptimizer::calculate_index_params(10_000_000, 768).num_partitions,
            1024
        );
    }

    #[test]
    fn test_zero_dimension_does_not_panic() {
        let params = VectorOptimizer::calculate_index_params(5000, 0);
        assert!(params.should_create_index);
        assert_eq!(params.num_sub_vectors, 1);
    }

    #[test]
    fn test_sub_vector_calculation() {
        assert_eq!(VectorOptimizer::find_optimal_sub_vectors(768, 48), 48);
        assert_eq!(VectorOptimizer::find_optimal_sub_vectors(1000, 62), 50);
        assert_eq!(VectorOptimizer::find_optimal_sub_vectors(1536, 96), 96);
    }

    #[test]
    fn test_search_params() {
        let search_params = VectorOptimizer::calculate_search_params(100, 50_000);
        assert!(search_params.nprobes >= 5);
        assert!(search_params.nprobes <= 15);
        assert!(search_params.refine_factor.is_some());
    }

    #[test]
    fn test_search_params_by_dataset_size() {
        // Small dataset: probe ~1/7 of partitions, no refinement.
        let small = VectorOptimizer::calculate_search_params(100, 5_000);
        assert_eq!(small.nprobes, 14);
        assert_eq!(small.refine_factor, None);

        // Large dataset: probe ~1/10 of partitions, heavier refinement.
        let large = VectorOptimizer::calculate_search_params(100, 200_000);
        assert_eq!(large.nprobes, 10);
        assert_eq!(large.refine_factor, Some(20));

        // Degenerate partition count still probes at least one partition.
        assert_eq!(VectorOptimizer::calculate_search_params(0, 10).nprobes, 1);
    }

    #[test]
    fn test_should_recreate_index() {
        // 447 partitions, 96 sub-vectors.
        let optimal = VectorOptimizer::calculate_index_params(200_000, 1536);
        assert!(!VectorOptimizer::should_recreate_index(447, 96, &optimal));
        assert!(!VectorOptimizer::should_recreate_index(447, 80, &optimal));
        assert!(VectorOptimizer::should_recreate_index(100, 96, &optimal));
        assert!(VectorOptimizer::should_recreate_index(447, 48, &optimal));

        let no_index = VectorOptimizer::calculate_index_params(500, 768);
        assert!(!VectorOptimizer::should_recreate_index(1, 1, &no_index));
    }

    #[test]
    fn test_growth_optimization() {
        assert!(VectorOptimizer::should_optimize_for_growth(1000, 768, true));
        assert!(VectorOptimizer::should_optimize_for_growth(5000, 768, true));
        assert!(VectorOptimizer::should_optimize_for_growth(10000, 768, true));
        assert!(!VectorOptimizer::should_optimize_for_growth(1500, 768, true));
        assert!(!VectorOptimizer::should_optimize_for_growth(7500, 768, true));
        assert!(!VectorOptimizer::should_optimize_for_growth(5000, 768, false));
    }

    #[test]
    fn test_growth_optimization_boundaries() {
        // The milestone window is inclusive of +50 rows.
        assert!(VectorOptimizer::should_optimize_for_growth(1_000_050, 768, true));
        assert!(!VectorOptimizer::should_optimize_for_growth(1_000_051, 768, true));
        // Beyond 1M rows every multiple of 100k is a milestone.
        assert!(VectorOptimizer::should_optimize_for_growth(2_000_020, 768, true));
        assert!(!VectorOptimizer::should_optimize_for_growth(2_000_051, 768, true));
        // Datasets too small to index never trigger.
        assert!(!VectorOptimizer::should_optimize_for_growth(999, 768, true));
    }
}