Skip to main content

oxirs_vec/optimizer/
cost_model.rs

1//! Per-index cost formulas for vector search query optimization.
2//!
3//! Each formula reflects the asymptotic search cost of a specific approximate
4//! nearest-neighbour index, expressed as **expected number of distance
5//! computations** for a single k-NN query.  Distance computations dominate
6//! query latency for vector search, so distance-count is a robust proxy for
7//! wall-clock time.
8//!
9//! The formulas implemented here are:
10//!
11//! - **HNSW**: `O(log n × M × ef)` — `ef * (M * log n)` expected probes
12//!   through the hierarchy.
//! - **IVF**:  `O(n_clusters + n × nprobe / n_clusters)` — `n_clusters` for
//!   coarse centroid selection, then `n × nprobe / n_clusters` for fine
//!   search inside the probed cells.
16//! - **LSH**:  `O(L × bucket_size + K × L × dim)` — `K * L * dim` for hash
17//!   evaluation and `L * bucket_size` for candidate scoring.
18//! - **PQ**:   `O(centroids × subquantizers + n × subquantizers)` — codebook
19//!   lookup table build then asymmetric distance against every encoded vector.
20//!
21//! Each cost is multiplied by a tunable per-index weight obtained from
22//! historical statistics, enabling online cost-model adaptation when actual
23//! latency measurements drift from the static formula.
24//!
25//! All costs are expressed in **abstract distance-equivalent units**.  Higher
26//! is more expensive; the dispatcher picks the lowest-cost index that meets
27//! the requested recall.
28
29use serde::{Deserialize, Serialize};
30use std::collections::BTreeMap;
31
32/// Index families recognised by the optimizer cost model.
33///
34/// Each variant maps to a concrete cost formula encoded in
35/// [`CostModel::estimate`].
36#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
37pub enum IndexFamily {
38    /// Hierarchical Navigable Small World graph.
39    Hnsw,
40    /// Inverted File index (Voronoi-cell partitioning).
41    Ivf,
42    /// Locality-Sensitive Hashing.
43    Lsh,
44    /// Product Quantization (asymmetric distance against coded vectors).
45    Pq,
46}
47
48impl IndexFamily {
49    /// All families known to the cost model.
50    pub fn all() -> [IndexFamily; 4] {
51        [
52            IndexFamily::Hnsw,
53            IndexFamily::Ivf,
54            IndexFamily::Lsh,
55            IndexFamily::Pq,
56        ]
57    }
58
59    /// Stable string identifier for serialization and metrics.
60    pub fn as_str(&self) -> &'static str {
61        match self {
62            IndexFamily::Hnsw => "hnsw",
63            IndexFamily::Ivf => "ivf",
64            IndexFamily::Lsh => "lsh",
65            IndexFamily::Pq => "pq",
66        }
67    }
68}
69
70/// Workload characteristics the cost model needs to estimate per-index cost.
71///
72/// `data_size` is the total number of indexed vectors; `dim` is the
73/// dimensionality; `requested_recall` is in `[0.0, 1.0]`; `query_density`
74/// expresses the expected fraction of the dataset that satisfies the query
75/// predicate (e.g. 1.0 = no filtering, 0.1 = 10% of data is candidate).
76#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
77pub struct WorkloadProfile {
78    /// Number of vectors currently indexed.
79    pub data_size: usize,
80    /// Vector dimensionality.
81    pub dim: usize,
82    /// Minimum acceptable recall (0.0 to 1.0).
83    pub requested_recall: f32,
84    /// Expected fraction of data passing pre-filter (0.0 to 1.0).
85    ///
86    /// Use `1.0` for unfiltered queries.  Lower values indicate selective
87    /// filters that should bias the optimizer toward indices that benefit
88    /// from filtering (LSH and PQ degrade gracefully; HNSW does not).
89    pub query_density: f32,
90    /// Number of nearest neighbours requested.
91    pub k: usize,
92}
93
94impl WorkloadProfile {
95    /// Create a profile with sensible defaults: density=1.0 (unfiltered), k=10.
96    pub fn new(data_size: usize, dim: usize, requested_recall: f32) -> Self {
97        Self {
98            data_size,
99            dim,
100            requested_recall,
101            query_density: 1.0,
102            k: 10,
103        }
104    }
105
106    /// Set the query density (filter selectivity).
107    pub fn with_query_density(mut self, density: f32) -> Self {
108        self.query_density = density.clamp(0.0, 1.0);
109        self
110    }
111
112    /// Set the number of nearest neighbours.
113    pub fn with_k(mut self, k: usize) -> Self {
114        self.k = k.max(1);
115        self
116    }
117}
118
119/// Tunable per-index parameters that affect the cost formula.
120///
121/// Defaults match the canonical defaults used elsewhere in the crate:
122/// `HnswConfig::m = 16`, `HnswConfig::ef = 50`, `IvfConfig::n_clusters = 256`,
123/// `IvfConfig::n_probes = 8`, `LshConfig::num_tables = 10`,
124/// `LshConfig::num_hash_functions = 8`, `PQConfig::n_subquantizers = 8`,
125/// `PQConfig::n_centroids = 256`.
126#[derive(Debug, Clone, Serialize, Deserialize)]
127pub struct IndexParameters {
128    /// HNSW max degree per layer.
129    pub hnsw_m: usize,
130    /// HNSW search beam width.
131    pub hnsw_ef: usize,
132    /// IVF coarse centroids.
133    pub ivf_n_clusters: usize,
134    /// IVF cells probed per query.
135    pub ivf_n_probes: usize,
136    /// LSH hash tables.
137    pub lsh_tables: usize,
138    /// LSH hash functions per table.
139    pub lsh_hash_functions: usize,
140    /// LSH expected bucket size (data_size / (tables * 2^hash_functions) clamped).
141    pub lsh_avg_bucket_size: usize,
142    /// PQ subquantizers.
143    pub pq_subquantizers: usize,
144    /// PQ centroids per subquantizer.
145    pub pq_centroids: usize,
146}
147
148impl Default for IndexParameters {
149    fn default() -> Self {
150        Self {
151            hnsw_m: 16,
152            hnsw_ef: 50,
153            ivf_n_clusters: 256,
154            ivf_n_probes: 8,
155            lsh_tables: 10,
156            lsh_hash_functions: 8,
157            lsh_avg_bucket_size: 64,
158            pq_subquantizers: 8,
159            pq_centroids: 256,
160        }
161    }
162}
163
164/// Online-learnable weights applied to each per-index formula output.
165///
166/// `weight = 1.0` means "trust the formula"; higher values indicate the
167/// formula systematically underestimates real latency for that family;
168/// lower values indicate it overestimates.  Weights are updated by
169/// [`crate::optimizer::query_stats::QueryStats::recommended_weights`].
170#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
171pub struct CostWeights {
172    weights: BTreeMap<IndexFamily, f64>,
173}
174
175impl Default for CostWeights {
176    fn default() -> Self {
177        let mut weights = BTreeMap::new();
178        for fam in IndexFamily::all() {
179            weights.insert(fam, 1.0);
180        }
181        Self { weights }
182    }
183}
184
185impl CostWeights {
186    /// Get weight for a family (defaults to 1.0 if not yet set).
187    pub fn get(&self, family: IndexFamily) -> f64 {
188        self.weights.get(&family).copied().unwrap_or(1.0)
189    }
190
191    /// Set weight for a family.  Values are clamped to `[0.05, 20.0]` to
192    /// prevent runaway feedback loops from outlier observations.
193    pub fn set(&mut self, family: IndexFamily, weight: f64) {
194        let clamped = weight.clamp(0.05, 20.0);
195        self.weights.insert(family, clamped);
196    }
197}
198
199/// Expected recall floor per family at default parameters.
200///
201/// These are conservative empirical lower bounds — the dispatcher uses them
202/// to filter out indices that *cannot meet* the requested recall before any
203/// cost comparison is performed.  Indices with adaptive parameters (HNSW
204/// `ef`, IVF `nprobe`) can usually exceed these floors when tuned.
205fn expected_recall_floor(family: IndexFamily) -> f32 {
206    match family {
207        IndexFamily::Hnsw => 0.95,
208        IndexFamily::Ivf => 0.85,
209        IndexFamily::Lsh => 0.75,
210        IndexFamily::Pq => 0.88,
211    }
212}
213
214/// Cost-model entrypoint used by the optimizer dispatcher.
215#[derive(Debug, Clone, Default)]
216pub struct CostModel {
217    parameters: IndexParameters,
218    weights: CostWeights,
219}
220
221impl CostModel {
222    /// Construct a cost model with explicit parameters and weights.
223    pub fn new(parameters: IndexParameters, weights: CostWeights) -> Self {
224        Self {
225            parameters,
226            weights,
227        }
228    }
229
230    /// Return mutable access to the weights for online adaptation.
231    pub fn weights_mut(&mut self) -> &mut CostWeights {
232        &mut self.weights
233    }
234
235    /// Borrow the current weights.
236    pub fn weights(&self) -> &CostWeights {
237        &self.weights
238    }
239
240    /// Borrow the current parameters.
241    pub fn parameters(&self) -> &IndexParameters {
242        &self.parameters
243    }
244
245    /// Recall floor that an index family is expected to deliver at its
246    /// default parameter configuration.
247    pub fn recall_floor(family: IndexFamily) -> f32 {
248        expected_recall_floor(family)
249    }
250
251    /// Estimate the (cost, recall) pair for executing `workload` against
252    /// `family`.
253    ///
254    /// `cost` is in abstract distance-equivalent units (higher = slower);
255    /// `recall` is a 0..1 estimate.  Cost is multiplied by the
256    /// learned per-family weight.
257    pub fn estimate(&self, family: IndexFamily, workload: &WorkloadProfile) -> CostEstimate {
258        // Scale by query density: highly selective queries benefit from
259        // filtered scans; less selective queries amortize over more probes.
260        // Density of 1.0 means no scaling (the default unfiltered query).
261        let density_scale = (workload.query_density.clamp(0.01, 1.0)) as f64;
262        let n = workload.data_size.max(1) as f64;
263        let dim = workload.dim.max(1) as f64;
264        let k = workload.k.max(1) as f64;
265
266        let raw_cost = match family {
267            IndexFamily::Hnsw => self.estimate_hnsw(n, k),
268            IndexFamily::Ivf => self.estimate_ivf(n),
269            IndexFamily::Lsh => self.estimate_lsh(dim),
270            IndexFamily::Pq => self.estimate_pq(n),
271        };
272
273        // Density boosts indices that benefit from filtering (LSH, PQ) and
274        // penalises HNSW which has no native filtering primitive.
275        let density_factor = match family {
276            IndexFamily::Hnsw => 1.0 / density_scale.max(0.1),
277            IndexFamily::Ivf => 1.0,
278            IndexFamily::Lsh => density_scale.max(0.5),
279            IndexFamily::Pq => density_scale.max(0.5),
280        };
281
282        let weight = self.weights.get(family);
283        let cost = raw_cost * weight * density_factor;
284
285        // Estimate recall: at default parameters use the family floor;
286        // adapt slightly based on dim and requested_recall.
287        let recall = self.estimate_recall(family, workload);
288
289        CostEstimate {
290            family,
291            cost,
292            recall,
293        }
294    }
295
296    /// HNSW: `ef * M * log n + k`
297    fn estimate_hnsw(&self, n: f64, k: f64) -> f64 {
298        let p = &self.parameters;
299        let log_n = n.ln().max(1.0);
300        (p.hnsw_ef as f64) * (p.hnsw_m as f64) * log_n + k
301    }
302
303    /// IVF: `n_clusters + (n / max(nprobe, 1)) * (nprobe / n_clusters)`
304    /// = `n_clusters + n / n_clusters * nprobe` …simplifies to
305    /// `n_clusters + n * (n_probes / n_clusters)`.
306    fn estimate_ivf(&self, n: f64) -> f64 {
307        let p = &self.parameters;
308        let n_clusters = p.ivf_n_clusters.max(1) as f64;
309        let n_probes = p.ivf_n_probes.max(1) as f64;
310        n_clusters + n * (n_probes / n_clusters)
311    }
312
313    /// LSH: `K * L * dim + L * avg_bucket_size`
314    fn estimate_lsh(&self, dim: f64) -> f64 {
315        let p = &self.parameters;
316        let l = p.lsh_tables.max(1) as f64;
317        let kk = p.lsh_hash_functions.max(1) as f64;
318        let bucket = p.lsh_avg_bucket_size.max(1) as f64;
319        kk * l * dim + l * bucket
320    }
321
322    /// PQ: `centroids * subquantizers + n * subquantizers / 8`
323    /// (codebook precompute, then table lookup per coded vector).
324    fn estimate_pq(&self, n: f64) -> f64 {
325        let p = &self.parameters;
326        let cents = p.pq_centroids.max(1) as f64;
327        let subs = p.pq_subquantizers.max(1) as f64;
328        cents * subs + n * subs / 8.0
329    }
330
331    /// Estimate recall for a family at current parameters.
332    fn estimate_recall(&self, family: IndexFamily, workload: &WorkloadProfile) -> f32 {
333        let floor = expected_recall_floor(family);
334        // Tighter beam widths/probes lift recall closer to 1.0.
335        let lift = match family {
336            IndexFamily::Hnsw => {
337                let ef = self.parameters.hnsw_ef as f32;
338                ((ef - 32.0) / 200.0).clamp(0.0, 0.04)
339            }
340            IndexFamily::Ivf => {
341                let probes = self.parameters.ivf_n_probes as f32;
342                ((probes - 4.0) / 64.0).clamp(0.0, 0.08)
343            }
344            IndexFamily::Lsh => {
345                let l = self.parameters.lsh_tables as f32;
346                ((l - 4.0) / 64.0).clamp(0.0, 0.10)
347            }
348            IndexFamily::Pq => {
349                let cents = self.parameters.pq_centroids as f32;
350                ((cents - 64.0) / 1024.0).clamp(0.0, 0.06)
351            }
352        };
353        // Higher dimensionality slightly degrades approximate recall.
354        let dim_penalty = if workload.dim > 512 {
355            ((workload.dim as f32 - 512.0) / 4096.0).min(0.05)
356        } else {
357            0.0
358        };
359        (floor + lift - dim_penalty).clamp(0.0, 1.0)
360    }
361}
362
363/// One row of cost-model output: family + estimated cost + estimated recall.
364#[derive(Debug, Clone, PartialEq)]
365pub struct CostEstimate {
366    /// Index family being estimated.
367    pub family: IndexFamily,
368    /// Abstract cost units (lower is better).
369    pub cost: f64,
370    /// Estimated recall in `[0.0, 1.0]`.
371    pub recall: f32,
372}
373
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for an unfiltered, k=10 workload.
    fn workload(n: usize, dim: usize, recall: f32) -> WorkloadProfile {
        WorkloadProfile::new(n, dim, recall)
    }

    #[test]
    fn index_family_all_returns_four_distinct() {
        let all = IndexFamily::all();
        assert_eq!(all.len(), 4);
        let strs: Vec<_> = all.iter().map(|f| f.as_str()).collect();
        assert_eq!(strs, vec!["hnsw", "ivf", "lsh", "pq"]);
    }

    #[test]
    fn cost_weights_default_is_unit() {
        let w = CostWeights::default();
        for f in IndexFamily::all() {
            assert!((w.get(f) - 1.0).abs() < 1e-12);
        }
    }

    #[test]
    fn cost_weights_set_clamps_outliers() {
        let mut w = CostWeights::default();
        w.set(IndexFamily::Hnsw, 1000.0);
        assert!((w.get(IndexFamily::Hnsw) - 20.0).abs() < 1e-12);
        w.set(IndexFamily::Pq, 0.0);
        assert!((w.get(IndexFamily::Pq) - 0.05).abs() < 1e-12);
    }

    #[test]
    fn hnsw_cost_grows_with_log_n() {
        let cm = CostModel::default();
        let small = cm.estimate(IndexFamily::Hnsw, &workload(1_000, 128, 0.9));
        let large = cm.estimate(IndexFamily::Hnsw, &workload(1_000_000, 128, 0.9));
        assert!(
            large.cost > small.cost,
            "HNSW cost must grow with data size"
        );
        // Log scaling: 1000x more data should be only ~2x more cost.
        assert!(large.cost < small.cost * 4.0);
    }

    #[test]
    fn ivf_cost_grows_with_n() {
        let cm = CostModel::default();
        let small = cm.estimate(IndexFamily::Ivf, &workload(10_000, 128, 0.9));
        let large = cm.estimate(IndexFamily::Ivf, &workload(1_000_000, 128, 0.9));
        // IVF is roughly linear (n_probes / n_clusters fraction of n).
        assert!(large.cost > small.cost);
        assert!(large.cost > small.cost * 10.0);
    }

    #[test]
    fn lsh_cost_independent_of_n() {
        let cm = CostModel::default();
        let small = cm.estimate(IndexFamily::Lsh, &workload(1_000, 128, 0.8));
        let large = cm.estimate(IndexFamily::Lsh, &workload(1_000_000, 128, 0.8));
        // LSH has bucket scan, but average bucket size is parameterised
        // independently from n in our cost model.
        assert!((large.cost - small.cost).abs() < 1e-9);
    }

    #[test]
    fn pq_cost_grows_with_n() {
        let cm = CostModel::default();
        let small = cm.estimate(IndexFamily::Pq, &workload(1_000, 128, 0.9));
        let large = cm.estimate(IndexFamily::Pq, &workload(100_000, 128, 0.9));
        assert!(large.cost > small.cost);
    }

    #[test]
    fn weights_scale_cost_linearly() {
        let mut cm = CostModel::default();
        let baseline = cm.estimate(IndexFamily::Hnsw, &workload(10_000, 128, 0.9));
        cm.weights_mut().set(IndexFamily::Hnsw, 2.0);
        let scaled = cm.estimate(IndexFamily::Hnsw, &workload(10_000, 128, 0.9));
        assert!((scaled.cost - 2.0 * baseline.cost).abs() < 1e-6);
    }

    #[test]
    fn recall_floors_match_expectations() {
        assert!((CostModel::recall_floor(IndexFamily::Hnsw) - 0.95).abs() < 1e-6);
        assert!((CostModel::recall_floor(IndexFamily::Pq) - 0.88).abs() < 1e-6);
        assert!(
            CostModel::recall_floor(IndexFamily::Lsh) < CostModel::recall_floor(IndexFamily::Hnsw)
        );
    }

    #[test]
    fn high_dim_penalises_recall_estimate() {
        let cm = CostModel::default();
        let low_dim = cm.estimate(IndexFamily::Hnsw, &workload(10_000, 128, 0.9));
        let high_dim = cm.estimate(IndexFamily::Hnsw, &workload(10_000, 4096, 0.9));
        assert!(high_dim.recall < low_dim.recall);
    }

    #[test]
    fn density_biases_toward_filterable_indices() {
        let cm = CostModel::default();
        let unfiltered = cm.estimate(
            IndexFamily::Hnsw,
            &workload(10_000, 128, 0.9).with_query_density(1.0),
        );
        let very_selective = cm.estimate(
            IndexFamily::Hnsw,
            &workload(10_000, 128, 0.9).with_query_density(0.05),
        );
        // HNSW gets more expensive when the query is very selective
        // because it cannot exploit the filter.
        assert!(very_selective.cost > unfiltered.cost);
    }

    #[test]
    fn density_helps_lsh_and_pq() {
        let cm = CostModel::default();
        // Both families named in the test must get cheaper under a
        // selective filter; checking only one would let the other regress.
        for family in [IndexFamily::Lsh, IndexFamily::Pq] {
            let unfiltered = cm.estimate(
                family,
                &workload(10_000, 128, 0.8).with_query_density(1.0),
            );
            let selective = cm.estimate(
                family,
                &workload(10_000, 128, 0.8).with_query_density(0.5),
            );
            assert!(
                selective.cost < unfiltered.cost,
                "{} cost must drop for a selective query",
                family.as_str()
            );
        }
    }
}