use serde::{Deserialize, Serialize};
use std::collections::BTreeMap;
/// Index families that the cost model can score.
///
/// `PartialOrd`/`Ord` are derived, so variants compare in declaration order,
/// which is also the key order of any `BTreeMap` keyed by this enum.
///
/// NOTE(review): the names presumably refer to HNSW graphs, inverted-file
/// (IVF) indexes, locality-sensitive hashing (LSH) and product quantization
/// (PQ) — confirm against the index implementations.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
pub enum IndexFamily {
/// Reported as `"hnsw"` by `as_str`.
Hnsw,
/// Reported as `"ivf"` by `as_str`.
Ivf,
/// Reported as `"lsh"` by `as_str`.
Lsh,
/// Reported as `"pq"` by `as_str`.
Pq,
}
impl IndexFamily {
    /// Returns every index family exactly once, in `Hnsw`, `Ivf`, `Lsh`, `Pq` order.
    pub fn all() -> [IndexFamily; 4] {
        const EVERY_FAMILY: [IndexFamily; 4] = [
            IndexFamily::Hnsw,
            IndexFamily::Ivf,
            IndexFamily::Lsh,
            IndexFamily::Pq,
        ];
        EVERY_FAMILY
    }

    /// Returns the lowercase textual identifier of this family.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Hnsw => "hnsw",
            Self::Ivf => "ivf",
            Self::Lsh => "lsh",
            Self::Pq => "pq",
        }
    }
}
/// Description of a query workload that `CostModel::estimate` scores index
/// families against.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct WorkloadProfile {
/// Dataset size; the estimators treat 0 as 1 and use the value as `n`.
pub data_size: usize,
/// Dimension count; feeds the LSH cost term and, above 512, a recall penalty.
/// NOTE(review): presumably the vector dimensionality — confirm.
pub dim: usize,
/// Recall target supplied by the caller. Stored as given; none of the
/// visible estimators read it.
pub requested_recall: f32,
/// Query density; `new` defaults it to 1.0 and `with_query_density` keeps it
/// within `[0.0, 1.0]`.
/// NOTE(review): presumably the fraction of the dataset a filtered query can
/// match (1.0 = unfiltered) — confirm with callers.
pub query_density: f32,
/// Result count; `new` defaults it to 10 and `with_k` keeps it at least 1.
pub k: usize,
}
impl WorkloadProfile {
    /// Creates a profile with the given size, dimension and recall target.
    ///
    /// `query_density` starts at 1.0 and `k` at 10; use the `with_*` builders
    /// to override them. `requested_recall` is stored as given.
    pub fn new(data_size: usize, dim: usize, requested_recall: f32) -> Self {
        Self {
            data_size,
            dim,
            requested_recall,
            query_density: 1.0,
            k: 10,
        }
    }

    /// Returns the profile with `query_density` set to `density` clamped into
    /// `[0.0, 1.0]`.
    ///
    /// A NaN `density` is ignored and the current value is kept.
    pub fn with_query_density(mut self, density: f32) -> Self {
        // `f32::clamp` returns NaN for a NaN receiver, so without this guard a
        // NaN would be stored and break the derived `PartialEq` as well as
        // every density-scaled cost computed from this profile.
        if !density.is_nan() {
            self.query_density = density.clamp(0.0, 1.0);
        }
        self
    }

    /// Returns the profile with `k` set to `k`, raised to 1 if it is 0.
    pub fn with_k(mut self, k: usize) -> Self {
        self.k = k.max(1);
        self
    }
}
/// Tunable parameters of each index family, as consumed by the cost and
/// recall estimators of `CostModel`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IndexParameters {
/// Multiplied into the HNSW cost term.
pub hnsw_m: usize,
/// Multiplied into the HNSW cost term; values above 32 add HNSW recall lift.
pub hnsw_ef: usize,
/// Cluster count in the IVF cost term (treated as at least 1).
pub ivf_n_clusters: usize,
/// Probe count in the IVF cost term (treated as at least 1); values above 4
/// add IVF recall lift.
pub ivf_n_probes: usize,
/// Table count in the LSH cost term (treated as at least 1); values above 4
/// add LSH recall lift.
pub lsh_tables: usize,
/// Hash functions per table in the LSH cost term (treated as at least 1).
pub lsh_hash_functions: usize,
/// Average bucket size in the LSH cost term (treated as at least 1).
pub lsh_avg_bucket_size: usize,
/// Sub-quantizer count in the PQ cost term (treated as at least 1).
pub pq_subquantizers: usize,
/// Centroid count in the PQ cost term (treated as at least 1); values above
/// 64 add PQ recall lift.
pub pq_centroids: usize,
}
impl Default for IndexParameters {
    /// Returns the baseline parameter set, grouped here by index family.
    fn default() -> Self {
        let (hnsw_m, hnsw_ef) = (16, 50);
        let (ivf_n_clusters, ivf_n_probes) = (256, 8);
        let (lsh_tables, lsh_hash_functions, lsh_avg_bucket_size) = (10, 8, 64);
        let (pq_subquantizers, pq_centroids) = (8, 256);

        Self {
            hnsw_m,
            hnsw_ef,
            ivf_n_clusters,
            ivf_n_probes,
            lsh_tables,
            lsh_hash_functions,
            lsh_avg_bucket_size,
            pq_subquantizers,
            pq_centroids,
        }
    }
}
/// Per-family multipliers applied to raw cost estimates.
///
/// The map is private; values are read through `get` and written through
/// `set`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct CostWeights {
// Families missing from the map are read back as 1.0 by `get`.
weights: BTreeMap<IndexFamily, f64>,
}
impl Default for CostWeights {
fn default() -> Self {
let mut weights = BTreeMap::new();
for fam in IndexFamily::all() {
weights.insert(fam, 1.0);
}
Self { weights }
}
}
impl CostWeights {
    /// Returns the weight stored for `family`, or 1.0 if none is stored.
    pub fn get(&self, family: IndexFamily) -> f64 {
        self.weights.get(&family).copied().unwrap_or(1.0)
    }

    /// Stores `weight` for `family`, clamped into `[0.05, 20.0]`.
    ///
    /// A NaN `weight` is ignored and the previously stored value (if any) is
    /// kept.
    pub fn set(&mut self, family: IndexFamily, weight: f64) {
        // `f64::clamp` returns NaN for a NaN receiver, so storing it would turn
        // every later cost for this family into NaN and break the derived
        // `PartialEq`. Infinities are fine: they clamp to the bounds.
        if weight.is_nan() {
            return;
        }
        self.weights.insert(family, weight.clamp(0.05, 20.0));
    }
}
/// Returns the baseline recall assumed for `family` before any
/// parameter-dependent lift or dimension penalty is applied.
fn expected_recall_floor(family: IndexFamily) -> f32 {
    const HNSW_FLOOR: f32 = 0.95;
    const IVF_FLOOR: f32 = 0.85;
    const LSH_FLOOR: f32 = 0.75;
    const PQ_FLOOR: f32 = 0.88;

    match family {
        IndexFamily::Hnsw => HNSW_FLOOR,
        IndexFamily::Ivf => IVF_FLOOR,
        IndexFamily::Lsh => LSH_FLOOR,
        IndexFamily::Pq => PQ_FLOOR,
    }
}
/// Estimates per-query cost and expected recall of each index family for a
/// given workload.
///
/// `Default` combines `IndexParameters::default()` with
/// `CostWeights::default()`.
#[derive(Debug, Clone, Default)]
pub struct CostModel {
// Index tuning parameters read by the per-family cost and recall estimators.
parameters: IndexParameters,
// Per-family multipliers applied to the raw cost in `estimate`.
weights: CostWeights,
}
impl CostModel {
    /// Builds a model from explicit index parameters and per-family weights.
    pub fn new(parameters: IndexParameters, weights: CostWeights) -> Self {
        Self {
            parameters,
            weights,
        }
    }

    /// Mutable access to the per-family cost weights.
    pub fn weights_mut(&mut self) -> &mut CostWeights {
        &mut self.weights
    }

    /// Shared access to the per-family cost weights.
    pub fn weights(&self) -> &CostWeights {
        &self.weights
    }

    /// Shared access to the index parameters used by the estimators.
    pub fn parameters(&self) -> &IndexParameters {
        &self.parameters
    }

    /// Baseline recall assumed for `family`, before lift and dimension penalty.
    pub fn recall_floor(family: IndexFamily) -> f32 {
        expected_recall_floor(family)
    }

    /// Estimates the cost and recall of serving `workload` with `family`.
    ///
    /// The cost is `raw_cost * weight * density_factor`, where:
    /// - `raw_cost` comes from the family-specific estimator,
    /// - `weight` is the family's entry in the model's [`CostWeights`],
    /// - `density_factor` depends on `workload.query_density`: HNSW cost grows
    ///   as density falls (up to 10x), IVF is unaffected, and LSH/PQ cost
    ///   shrinks with density (down to 0.5x).
    ///
    /// Zero `data_size`, `dim` and `k` are treated as 1. A NaN `query_density`
    /// is treated as 1.0.
    pub fn estimate(&self, family: IndexFamily, workload: &WorkloadProfile) -> CostEstimate {
        // `f32::clamp` propagates NaN and the `f64::max` calls below discard
        // it, so a NaN density would otherwise be scored as 0.1 for HNSW but
        // 0.5 for LSH/PQ. Fall back to the same default `WorkloadProfile::new`
        // uses instead.
        let density_scale = if workload.query_density.is_nan() {
            1.0
        } else {
            workload.query_density.clamp(0.01, 1.0) as f64
        };
        let n = workload.data_size.max(1) as f64;
        let dim = workload.dim.max(1) as f64;
        let k = workload.k.max(1) as f64;

        let raw_cost = match family {
            IndexFamily::Hnsw => self.estimate_hnsw(n, k),
            IndexFamily::Ivf => self.estimate_ivf(n),
            IndexFamily::Lsh => self.estimate_lsh(dim),
            IndexFamily::Pq => self.estimate_pq(n),
        };
        let density_factor = match family {
            IndexFamily::Hnsw => 1.0 / density_scale.max(0.1),
            IndexFamily::Ivf => 1.0,
            IndexFamily::Lsh => density_scale.max(0.5),
            IndexFamily::Pq => density_scale.max(0.5),
        };

        let weight = self.weights.get(family);
        let cost = raw_cost * weight * density_factor;
        let recall = self.estimate_recall(family, workload);
        CostEstimate {
            family,
            cost,
            recall,
        }
    }

    /// HNSW raw cost: `ef * m * max(ln n, 1) + k`.
    fn estimate_hnsw(&self, n: f64, k: f64) -> f64 {
        let p = &self.parameters;
        // Guard against zero like the other estimators do; otherwise a zero
        // `hnsw_ef` or `hnsw_m` collapses the cost to just `k`.
        let ef = p.hnsw_ef.max(1) as f64;
        let m = p.hnsw_m.max(1) as f64;
        let log_n = n.ln().max(1.0);
        ef * m * log_n + k
    }

    /// IVF raw cost: one comparison per cluster plus a scan of the probed
    /// share of the data, `n_clusters + n * (n_probes / n_clusters)`.
    fn estimate_ivf(&self, n: f64) -> f64 {
        let p = &self.parameters;
        let clusters = p.ivf_n_clusters.max(1);
        // No more clusters can be probed than exist; without this cap the scan
        // term would exceed a full scan of `n`.
        let probes = p.ivf_n_probes.max(1).min(clusters);
        let n_clusters = clusters as f64;
        let n_probes = probes as f64;
        n_clusters + n * (n_probes / n_clusters)
    }

    /// LSH raw cost: hashing (`hash_functions * tables * dim`) plus bucket
    /// scans (`tables * avg_bucket_size`). Does not depend on the data size.
    fn estimate_lsh(&self, dim: f64) -> f64 {
        let p = &self.parameters;
        let l = p.lsh_tables.max(1) as f64;
        let kk = p.lsh_hash_functions.max(1) as f64;
        let bucket = p.lsh_avg_bucket_size.max(1) as f64;
        kk * l * dim + l * bucket
    }

    /// PQ raw cost: `centroids * subquantizers + n * subquantizers / 8`.
    fn estimate_pq(&self, n: f64) -> f64 {
        let p = &self.parameters;
        let cents = p.pq_centroids.max(1) as f64;
        let subs = p.pq_subquantizers.max(1) as f64;
        cents * subs + n * subs / 8.0
    }

    /// Expected recall: the family's floor, plus a capped lift from its main
    /// tuning parameter, minus a penalty of up to 0.05 for `dim > 512`; the
    /// result is clamped into `[0.0, 1.0]`.
    fn estimate_recall(&self, family: IndexFamily, workload: &WorkloadProfile) -> f32 {
        let floor = expected_recall_floor(family);
        let lift = match family {
            IndexFamily::Hnsw => {
                let ef = self.parameters.hnsw_ef as f32;
                ((ef - 32.0) / 200.0).clamp(0.0, 0.04)
            }
            IndexFamily::Ivf => {
                let probes = self.parameters.ivf_n_probes as f32;
                ((probes - 4.0) / 64.0).clamp(0.0, 0.08)
            }
            IndexFamily::Lsh => {
                let l = self.parameters.lsh_tables as f32;
                ((l - 4.0) / 64.0).clamp(0.0, 0.10)
            }
            IndexFamily::Pq => {
                let cents = self.parameters.pq_centroids as f32;
                ((cents - 64.0) / 1024.0).clamp(0.0, 0.06)
            }
        };
        let dim_penalty = if workload.dim > 512 {
            ((workload.dim as f32 - 512.0) / 4096.0).min(0.05)
        } else {
            0.0
        };
        (floor + lift - dim_penalty).clamp(0.0, 1.0)
    }
}
/// Result of `CostModel::estimate` for one index family.
#[derive(Debug, Clone, PartialEq)]
pub struct CostEstimate {
/// The family this estimate was computed for.
pub family: IndexFamily,
/// Raw family cost multiplied by the family weight and the density factor.
pub cost: f64,
/// Estimated recall, clamped into `[0.0, 1.0]`.
pub recall: f32,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a workload profile with the default query density and `k`.
    fn profile(n: usize, dim: usize, recall: f32) -> WorkloadProfile {
        WorkloadProfile::new(n, dim, recall)
    }

    /// Cost reported by a default-configured model for `family`.
    fn default_cost(family: IndexFamily, wl: &WorkloadProfile) -> f64 {
        CostModel::default().estimate(family, wl).cost
    }

    /// `all()` yields the four families with their expected names, in order.
    #[test]
    fn index_family_all_returns_four_distinct() {
        let families = IndexFamily::all();
        assert_eq!(families.len(), 4);

        let mut names = Vec::new();
        for family in families.iter() {
            names.push(family.as_str());
        }
        assert_eq!(names, vec!["hnsw", "ivf", "lsh", "pq"]);
    }

    /// Every family starts with a weight of exactly 1.0.
    #[test]
    fn cost_weights_default_is_unit() {
        let weights = CostWeights::default();
        for family in IndexFamily::all().iter().copied() {
            let deviation = (weights.get(family) - 1.0).abs();
            assert!(deviation < 1e-12);
        }
    }

    /// Out-of-range weights are pulled back to the 0.05..=20.0 bounds.
    #[test]
    fn cost_weights_set_clamps_outliers() {
        let mut weights = CostWeights::default();

        weights.set(IndexFamily::Hnsw, 1000.0);
        let upper = weights.get(IndexFamily::Hnsw);
        assert!((upper - 20.0).abs() < 1e-12);

        weights.set(IndexFamily::Pq, 0.0);
        let lower = weights.get(IndexFamily::Pq);
        assert!((lower - 0.05).abs() < 1e-12);
    }

    /// HNSW cost rises with data size, but sub-linearly.
    #[test]
    fn hnsw_cost_grows_with_log_n() {
        let small = default_cost(IndexFamily::Hnsw, &profile(1_000, 128, 0.9));
        let large = default_cost(IndexFamily::Hnsw, &profile(1_000_000, 128, 0.9));
        assert!(large > small, "HNSW cost must grow with data size");
        assert!(large < small * 4.0);
    }

    /// IVF cost rises steeply with data size.
    #[test]
    fn ivf_cost_grows_with_n() {
        let small = default_cost(IndexFamily::Ivf, &profile(10_000, 128, 0.9));
        let large = default_cost(IndexFamily::Ivf, &profile(1_000_000, 128, 0.9));
        assert!(large > small);
        assert!(large > small * 10.0);
    }

    /// LSH cost does not change with data size.
    #[test]
    fn lsh_cost_independent_of_n() {
        let small = default_cost(IndexFamily::Lsh, &profile(1_000, 128, 0.8));
        let large = default_cost(IndexFamily::Lsh, &profile(1_000_000, 128, 0.8));
        let difference = (large - small).abs();
        assert!(difference < 1e-9);
    }

    /// PQ cost rises with data size.
    #[test]
    fn pq_cost_grows_with_n() {
        let small = default_cost(IndexFamily::Pq, &profile(1_000, 128, 0.9));
        let large = default_cost(IndexFamily::Pq, &profile(100_000, 128, 0.9));
        assert!(large > small);
    }

    /// Doubling a family's weight doubles its estimated cost.
    #[test]
    fn weights_scale_cost_linearly() {
        let wl = profile(10_000, 128, 0.9);
        let mut model = CostModel::default();

        let before = model.estimate(IndexFamily::Hnsw, &wl).cost;
        model.weights_mut().set(IndexFamily::Hnsw, 2.0);
        let after = model.estimate(IndexFamily::Hnsw, &wl).cost;

        assert!((after - 2.0 * before).abs() < 1e-6);
    }

    /// Recall floors match the documented per-family baselines.
    #[test]
    fn recall_floors_match_expectations() {
        let hnsw = CostModel::recall_floor(IndexFamily::Hnsw);
        let pq = CostModel::recall_floor(IndexFamily::Pq);
        let lsh = CostModel::recall_floor(IndexFamily::Lsh);

        assert!((hnsw - 0.95).abs() < 1e-6);
        assert!((pq - 0.88).abs() < 1e-6);
        assert!(lsh < hnsw);
    }

    /// Dimensions above 512 lower the recall estimate.
    #[test]
    fn high_dim_penalises_recall_estimate() {
        let model = CostModel::default();
        let recall_at = |dim: usize| {
            model
                .estimate(IndexFamily::Hnsw, &profile(10_000, dim, 0.9))
                .recall
        };
        assert!(recall_at(4096) < recall_at(128));
    }

    /// Low query density makes HNSW more expensive.
    #[test]
    fn density_biases_toward_filterable_indices() {
        let unfiltered = profile(10_000, 128, 0.9).with_query_density(1.0);
        let very_selective = profile(10_000, 128, 0.9).with_query_density(0.05);

        let unfiltered_cost = default_cost(IndexFamily::Hnsw, &unfiltered);
        let selective_cost = default_cost(IndexFamily::Hnsw, &very_selective);
        assert!(selective_cost > unfiltered_cost);
    }

    /// Lower query density never makes LSH more expensive.
    #[test]
    fn density_helps_lsh_and_pq() {
        let unfiltered = profile(10_000, 128, 0.8).with_query_density(1.0);
        let selective = profile(10_000, 128, 0.8).with_query_density(0.5);

        let unfiltered_cost = default_cost(IndexFamily::Lsh, &unfiltered);
        let selective_cost = default_cost(IndexFamily::Lsh, &selective);
        assert!(selective_cost <= unfiltered_cost);
    }
}