1use serde::{Deserialize, Serialize};
30use std::collections::BTreeMap;
31
32#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
37pub enum IndexFamily {
38 Hnsw,
40 Ivf,
42 Lsh,
44 Pq,
46}
47
48impl IndexFamily {
49 pub fn all() -> [IndexFamily; 4] {
51 [
52 IndexFamily::Hnsw,
53 IndexFamily::Ivf,
54 IndexFamily::Lsh,
55 IndexFamily::Pq,
56 ]
57 }
58
59 pub fn as_str(&self) -> &'static str {
61 match self {
62 IndexFamily::Hnsw => "hnsw",
63 IndexFamily::Ivf => "ivf",
64 IndexFamily::Lsh => "lsh",
65 IndexFamily::Pq => "pq",
66 }
67 }
68}
69
70#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
77pub struct WorkloadProfile {
78 pub data_size: usize,
80 pub dim: usize,
82 pub requested_recall: f32,
84 pub query_density: f32,
90 pub k: usize,
92}
93
94impl WorkloadProfile {
95 pub fn new(data_size: usize, dim: usize, requested_recall: f32) -> Self {
97 Self {
98 data_size,
99 dim,
100 requested_recall,
101 query_density: 1.0,
102 k: 10,
103 }
104 }
105
106 pub fn with_query_density(mut self, density: f32) -> Self {
108 self.query_density = density.clamp(0.0, 1.0);
109 self
110 }
111
112 pub fn with_k(mut self, k: usize) -> Self {
114 self.k = k.max(1);
115 self
116 }
117}
118
119#[derive(Debug, Clone, Serialize, Deserialize)]
127pub struct IndexParameters {
128 pub hnsw_m: usize,
130 pub hnsw_ef: usize,
132 pub ivf_n_clusters: usize,
134 pub ivf_n_probes: usize,
136 pub lsh_tables: usize,
138 pub lsh_hash_functions: usize,
140 pub lsh_avg_bucket_size: usize,
142 pub pq_subquantizers: usize,
144 pub pq_centroids: usize,
146}
147
148impl Default for IndexParameters {
149 fn default() -> Self {
150 Self {
151 hnsw_m: 16,
152 hnsw_ef: 50,
153 ivf_n_clusters: 256,
154 ivf_n_probes: 8,
155 lsh_tables: 10,
156 lsh_hash_functions: 8,
157 lsh_avg_bucket_size: 64,
158 pq_subquantizers: 8,
159 pq_centroids: 256,
160 }
161 }
162}
163
164#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
171pub struct CostWeights {
172 weights: BTreeMap<IndexFamily, f64>,
173}
174
175impl Default for CostWeights {
176 fn default() -> Self {
177 let mut weights = BTreeMap::new();
178 for fam in IndexFamily::all() {
179 weights.insert(fam, 1.0);
180 }
181 Self { weights }
182 }
183}
184
185impl CostWeights {
186 pub fn get(&self, family: IndexFamily) -> f64 {
188 self.weights.get(&family).copied().unwrap_or(1.0)
189 }
190
191 pub fn set(&mut self, family: IndexFamily, weight: f64) {
194 let clamped = weight.clamp(0.05, 20.0);
195 self.weights.insert(family, clamped);
196 }
197}
198
199fn expected_recall_floor(family: IndexFamily) -> f32 {
206 match family {
207 IndexFamily::Hnsw => 0.95,
208 IndexFamily::Ivf => 0.85,
209 IndexFamily::Lsh => 0.75,
210 IndexFamily::Pq => 0.88,
211 }
212}
213
214#[derive(Debug, Clone, Default)]
216pub struct CostModel {
217 parameters: IndexParameters,
218 weights: CostWeights,
219}
220
221impl CostModel {
222 pub fn new(parameters: IndexParameters, weights: CostWeights) -> Self {
224 Self {
225 parameters,
226 weights,
227 }
228 }
229
230 pub fn weights_mut(&mut self) -> &mut CostWeights {
232 &mut self.weights
233 }
234
235 pub fn weights(&self) -> &CostWeights {
237 &self.weights
238 }
239
240 pub fn parameters(&self) -> &IndexParameters {
242 &self.parameters
243 }
244
245 pub fn recall_floor(family: IndexFamily) -> f32 {
248 expected_recall_floor(family)
249 }
250
251 pub fn estimate(&self, family: IndexFamily, workload: &WorkloadProfile) -> CostEstimate {
258 let density_scale = (workload.query_density.clamp(0.01, 1.0)) as f64;
262 let n = workload.data_size.max(1) as f64;
263 let dim = workload.dim.max(1) as f64;
264 let k = workload.k.max(1) as f64;
265
266 let raw_cost = match family {
267 IndexFamily::Hnsw => self.estimate_hnsw(n, k),
268 IndexFamily::Ivf => self.estimate_ivf(n),
269 IndexFamily::Lsh => self.estimate_lsh(dim),
270 IndexFamily::Pq => self.estimate_pq(n),
271 };
272
273 let density_factor = match family {
276 IndexFamily::Hnsw => 1.0 / density_scale.max(0.1),
277 IndexFamily::Ivf => 1.0,
278 IndexFamily::Lsh => density_scale.max(0.5),
279 IndexFamily::Pq => density_scale.max(0.5),
280 };
281
282 let weight = self.weights.get(family);
283 let cost = raw_cost * weight * density_factor;
284
285 let recall = self.estimate_recall(family, workload);
288
289 CostEstimate {
290 family,
291 cost,
292 recall,
293 }
294 }
295
296 fn estimate_hnsw(&self, n: f64, k: f64) -> f64 {
298 let p = &self.parameters;
299 let log_n = n.ln().max(1.0);
300 (p.hnsw_ef as f64) * (p.hnsw_m as f64) * log_n + k
301 }
302
303 fn estimate_ivf(&self, n: f64) -> f64 {
307 let p = &self.parameters;
308 let n_clusters = p.ivf_n_clusters.max(1) as f64;
309 let n_probes = p.ivf_n_probes.max(1) as f64;
310 n_clusters + n * (n_probes / n_clusters)
311 }
312
313 fn estimate_lsh(&self, dim: f64) -> f64 {
315 let p = &self.parameters;
316 let l = p.lsh_tables.max(1) as f64;
317 let kk = p.lsh_hash_functions.max(1) as f64;
318 let bucket = p.lsh_avg_bucket_size.max(1) as f64;
319 kk * l * dim + l * bucket
320 }
321
322 fn estimate_pq(&self, n: f64) -> f64 {
325 let p = &self.parameters;
326 let cents = p.pq_centroids.max(1) as f64;
327 let subs = p.pq_subquantizers.max(1) as f64;
328 cents * subs + n * subs / 8.0
329 }
330
331 fn estimate_recall(&self, family: IndexFamily, workload: &WorkloadProfile) -> f32 {
333 let floor = expected_recall_floor(family);
334 let lift = match family {
336 IndexFamily::Hnsw => {
337 let ef = self.parameters.hnsw_ef as f32;
338 ((ef - 32.0) / 200.0).clamp(0.0, 0.04)
339 }
340 IndexFamily::Ivf => {
341 let probes = self.parameters.ivf_n_probes as f32;
342 ((probes - 4.0) / 64.0).clamp(0.0, 0.08)
343 }
344 IndexFamily::Lsh => {
345 let l = self.parameters.lsh_tables as f32;
346 ((l - 4.0) / 64.0).clamp(0.0, 0.10)
347 }
348 IndexFamily::Pq => {
349 let cents = self.parameters.pq_centroids as f32;
350 ((cents - 64.0) / 1024.0).clamp(0.0, 0.06)
351 }
352 };
353 let dim_penalty = if workload.dim > 512 {
355 ((workload.dim as f32 - 512.0) / 4096.0).min(0.05)
356 } else {
357 0.0
358 };
359 (floor + lift - dim_penalty).clamp(0.0, 1.0)
360 }
361}
362
363#[derive(Debug, Clone, PartialEq)]
365pub struct CostEstimate {
366 pub family: IndexFamily,
368 pub cost: f64,
370 pub recall: f32,
372}
373
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for weight comparisons; weights are stored exactly.
    const WEIGHT_EPS: f64 = 1e-12;

    /// Default-density, `k = 10` workload.
    fn profile(n: usize, dim: usize, recall: f32) -> WorkloadProfile {
        WorkloadProfile::new(n, dim, recall)
    }

    #[test]
    fn index_family_all_returns_four_distinct() {
        let names: Vec<&str> = IndexFamily::all()
            .iter()
            .map(IndexFamily::as_str)
            .collect();
        assert_eq!(names.len(), 4);
        assert_eq!(names, ["hnsw", "ivf", "lsh", "pq"]);
    }

    #[test]
    fn cost_weights_default_is_unit() {
        let weights = CostWeights::default();
        assert!(IndexFamily::all()
            .iter()
            .all(|&family| (weights.get(family) - 1.0).abs() < WEIGHT_EPS));
    }

    #[test]
    fn cost_weights_set_clamps_outliers() {
        let mut weights = CostWeights::default();

        weights.set(IndexFamily::Hnsw, 1000.0);
        let hnsw = weights.get(IndexFamily::Hnsw);
        assert!((hnsw - 20.0).abs() < WEIGHT_EPS);

        weights.set(IndexFamily::Pq, 0.0);
        let pq = weights.get(IndexFamily::Pq);
        assert!((pq - 0.05).abs() < WEIGHT_EPS);
    }

    #[test]
    fn hnsw_cost_grows_with_log_n() {
        let model = CostModel::default();
        let small = model.estimate(IndexFamily::Hnsw, &profile(1_000, 128, 0.9)).cost;
        let large = model.estimate(IndexFamily::Hnsw, &profile(1_000_000, 128, 0.9)).cost;
        assert!(large > small, "HNSW cost must grow with data size");
        // Logarithmic growth: 1000x more data stays well under 4x the cost.
        assert!(large < small * 4.0);
    }

    #[test]
    fn ivf_cost_grows_with_n() {
        let model = CostModel::default();
        let small = model.estimate(IndexFamily::Ivf, &profile(10_000, 128, 0.9)).cost;
        let large = model.estimate(IndexFamily::Ivf, &profile(1_000_000, 128, 0.9)).cost;
        assert!(large > small);
        // The probed share of n dominates the fixed centroid term.
        assert!(large > small * 10.0);
    }

    #[test]
    fn lsh_cost_independent_of_n() {
        let model = CostModel::default();
        let small = model.estimate(IndexFamily::Lsh, &profile(1_000, 128, 0.8)).cost;
        let large = model.estimate(IndexFamily::Lsh, &profile(1_000_000, 128, 0.8)).cost;
        assert!((large - small).abs() < 1e-9);
    }

    #[test]
    fn pq_cost_grows_with_n() {
        let model = CostModel::default();
        let small = model.estimate(IndexFamily::Pq, &profile(1_000, 128, 0.9)).cost;
        let large = model.estimate(IndexFamily::Pq, &profile(100_000, 128, 0.9)).cost;
        assert!(large > small);
    }

    #[test]
    fn weights_scale_cost_linearly() {
        let mut model = CostModel::default();
        let wl = profile(10_000, 128, 0.9);
        let baseline = model.estimate(IndexFamily::Hnsw, &wl).cost;
        model.weights_mut().set(IndexFamily::Hnsw, 2.0);
        let scaled = model.estimate(IndexFamily::Hnsw, &wl).cost;
        assert!((scaled - 2.0 * baseline).abs() < 1e-6);
    }

    #[test]
    fn recall_floors_match_expectations() {
        let hnsw = CostModel::recall_floor(IndexFamily::Hnsw);
        let pq = CostModel::recall_floor(IndexFamily::Pq);
        let lsh = CostModel::recall_floor(IndexFamily::Lsh);
        assert!((hnsw - 0.95).abs() < 1e-6);
        assert!((pq - 0.88).abs() < 1e-6);
        assert!(lsh < hnsw);
    }

    #[test]
    fn high_dim_penalises_recall_estimate() {
        let model = CostModel::default();
        let low = model.estimate(IndexFamily::Hnsw, &profile(10_000, 128, 0.9)).recall;
        let high = model.estimate(IndexFamily::Hnsw, &profile(10_000, 4096, 0.9)).recall;
        assert!(high < low);
    }

    #[test]
    fn density_biases_toward_filterable_indices() {
        let model = CostModel::default();
        let hnsw_cost_at = |density: f32| {
            model
                .estimate(
                    IndexFamily::Hnsw,
                    &profile(10_000, 128, 0.9).with_query_density(density),
                )
                .cost
        };
        let unfiltered = hnsw_cost_at(1.0);
        let very_selective = hnsw_cost_at(0.05);
        assert!(very_selective > unfiltered);
    }

    #[test]
    fn density_helps_lsh_and_pq() {
        let model = CostModel::default();
        let lsh_cost_at = |density: f32| {
            model
                .estimate(
                    IndexFamily::Lsh,
                    &profile(10_000, 128, 0.8).with_query_density(density),
                )
                .cost
        };
        let unfiltered = lsh_cost_at(1.0);
        let selective = lsh_cost_at(0.5);
        assert!(selective <= unfiltered);
    }
}