1use super::query_options::{IndexType, QuantizationKind};
12
/// Heuristic cost breakdown for executing one vector query plan.
///
/// Every figure is a coarse model output intended for comparing candidate
/// plans with one another; none of them is a measured latency.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct VectorCost {
    /// Estimated cost of walking the ANN graph, in microseconds.
    pub graph_traversal_us: f32,

    /// Estimated cost of decoding one stored vector code, in nanoseconds per vector.
    pub codec_decode_ns: f32,

    /// Estimated cost of re-ranking candidates, in microseconds (`0.0` when no
    /// rerank pass is charged).
    pub rerank_us: f32,

    /// Fraction of the collection expected to satisfy the query filter
    /// (nominally in `[0, 1]`).
    pub filter_selectivity: f32,

    /// Number of candidate vectors the query is expected to consider.
    pub candidate_set_size: usize,

    /// Heuristic recall estimate for the chosen index/quantization pairing.
    pub predicted_recall: f32,
}
38
39pub struct CostModelInputs {
41 pub n_vectors: usize,
43
44 pub dim: usize,
46
47 pub index_type: IndexType,
49
50 pub quantization: QuantizationKind,
52
53 pub ef_search: usize,
55
56 pub global_selectivity: f32,
59
60 pub candidate_set_size: usize,
63}
64
65fn codec_decode_ns_for(quantization: QuantizationKind) -> f32 {
69 match quantization {
70 QuantizationKind::None => 16.0,
71 QuantizationKind::Sq8 => 4.0,
72 QuantizationKind::Pq | QuantizationKind::Opq => 0.5,
73 QuantizationKind::RaBitQ | QuantizationKind::Bbq | QuantizationKind::Binary => 0.3,
74 QuantizationKind::Ternary => 0.4,
75 }
76}
77
78fn predicted_recall_for(index_type: IndexType, quantization: QuantizationKind) -> f32 {
80 match (index_type, quantization) {
81 (IndexType::Hnsw, QuantizationKind::None) => 0.99,
82 (IndexType::Hnsw, QuantizationKind::Sq8) => 0.97,
83 (IndexType::Hnsw, QuantizationKind::Pq) | (IndexType::Hnsw, QuantizationKind::Opq) => 0.92,
84 (IndexType::Hnsw, QuantizationKind::RaBitQ) | (IndexType::Hnsw, QuantizationKind::Bbq) => {
85 0.95
86 }
87 (IndexType::Hnsw, QuantizationKind::Binary) => 0.85,
88 (IndexType::Hnsw, QuantizationKind::Ternary) => 0.90,
89 (IndexType::Vamana, QuantizationKind::RaBitQ) => 0.96,
91 (IndexType::Vamana, QuantizationKind::Bbq) => 0.96,
92 (IndexType::Vamana, QuantizationKind::None) => 0.99,
93 (IndexType::Vamana, QuantizationKind::Sq8) => 0.97,
94 (IndexType::Vamana, _) => 0.92,
95 }
96}
97
98pub fn estimate_cost(inputs: &CostModelInputs) -> VectorCost {
117 let n = inputs.n_vectors.max(1) as f32;
118
119 let graph_traversal_us = match inputs.index_type {
121 IndexType::Hnsw => n.log2() * inputs.ef_search as f32 * 0.001,
122 IndexType::Vamana => inputs.ef_search as f32 * 0.002,
123 };
124
125 let codec_decode_ns = codec_decode_ns_for(inputs.quantization);
127
128 const DEFAULT_OVERSAMPLE: u8 = 3;
130 let rerank_us = if inputs.quantization == QuantizationKind::None {
131 0.0
132 } else {
133 DEFAULT_OVERSAMPLE as f32 * inputs.ef_search as f32 * 0.01
134 };
135
136 let predicted_recall = predicted_recall_for(inputs.index_type, inputs.quantization);
137
138 VectorCost {
139 graph_traversal_us,
140 codec_decode_ns,
141 rerank_us,
142 filter_selectivity: inputs.global_selectivity,
143 candidate_set_size: inputs.candidate_set_size,
144 predicted_recall,
145 }
146}
147
#[cfg(test)]
mod tests {
    use super::*;

    /// Every quantization variant, so per-codec checks cannot silently skip one.
    const ALL_QUANTIZATIONS: [QuantizationKind; 8] = [
        QuantizationKind::None,
        QuantizationKind::Sq8,
        QuantizationKind::Pq,
        QuantizationKind::Opq,
        QuantizationKind::RaBitQ,
        QuantizationKind::Bbq,
        QuantizationKind::Binary,
        QuantizationKind::Ternary,
    ];

    fn base_inputs() -> CostModelInputs {
        CostModelInputs {
            n_vectors: 1_000_000,
            dim: 768,
            index_type: IndexType::Hnsw,
            quantization: QuantizationKind::Sq8,
            ef_search: 64,
            global_selectivity: 1.0,
            candidate_set_size: 1_000_000,
        }
    }

    #[test]
    fn estimate_returns_finite_positive_values() {
        let cost = estimate_cost(&base_inputs());
        assert!(cost.graph_traversal_us.is_finite() && cost.graph_traversal_us > 0.0);
        assert!(cost.codec_decode_ns.is_finite() && cost.codec_decode_ns > 0.0);
        assert!(cost.rerank_us.is_finite() && cost.rerank_us >= 0.0);
        assert!(cost.predicted_recall > 0.0 && cost.predicted_recall <= 1.0);
        assert!(cost.filter_selectivity >= 0.0 && cost.filter_selectivity <= 1.0);
    }

    #[test]
    fn hnsw_graph_traversal_scales_with_log_n() {
        let small = estimate_cost(&CostModelInputs {
            n_vectors: 1_000,
            ..base_inputs()
        });
        let large = estimate_cost(&CostModelInputs {
            n_vectors: 1_000_000,
            ..base_inputs()
        });
        assert!(
            large.graph_traversal_us > small.graph_traversal_us,
            "HNSW traversal should grow with n_vectors"
        );
    }

    #[test]
    fn vamana_traversal_is_flat_across_n() {
        let small = estimate_cost(&CostModelInputs {
            n_vectors: 1_000,
            index_type: IndexType::Vamana,
            quantization: QuantizationKind::RaBitQ,
            ..base_inputs()
        });
        let large = estimate_cost(&CostModelInputs {
            n_vectors: 1_000_000_000,
            index_type: IndexType::Vamana,
            quantization: QuantizationKind::RaBitQ,
            ..base_inputs()
        });
        assert!(
            (small.graph_traversal_us - large.graph_traversal_us).abs() < f32::EPSILON,
            "Vamana traversal cost must not depend on n_vectors"
        );
    }

    #[test]
    fn sq8_decode_more_expensive_than_rabitq() {
        let sq8 = estimate_cost(&base_inputs());
        let rabitq = estimate_cost(&CostModelInputs {
            quantization: QuantizationKind::RaBitQ,
            ..base_inputs()
        });
        assert!(
            sq8.codec_decode_ns > rabitq.codec_decode_ns,
            "SQ8 decode ({}) should cost more ns/vector than RaBitQ ({})",
            sq8.codec_decode_ns,
            rabitq.codec_decode_ns
        );
    }

    #[test]
    fn every_codec_has_finite_positive_decode_cost() {
        for quant in ALL_QUANTIZATIONS {
            let cost = estimate_cost(&CostModelInputs {
                quantization: quant,
                ..base_inputs()
            });
            assert!(
                cost.codec_decode_ns.is_finite() && cost.codec_decode_ns > 0.0,
                "decode cost invalid for {quant:?}: {}",
                cost.codec_decode_ns
            );
        }
    }

    #[test]
    fn none_quantization_has_zero_rerank() {
        let cost = estimate_cost(&CostModelInputs {
            quantization: QuantizationKind::None,
            ..base_inputs()
        });
        assert_eq!(cost.rerank_us, 0.0, "FP32 traversal needs no rerank");
    }

    #[test]
    fn compressed_quantization_has_nonzero_rerank() {
        let cost = estimate_cost(&CostModelInputs {
            quantization: QuantizationKind::Bbq,
            ..base_inputs()
        });
        assert!(
            cost.rerank_us > 0.0,
            "BBQ traversal should incur rerank cost"
        );
    }

    #[test]
    fn predicted_recall_within_bounds() {
        // Covers every variant, including Opq, which the previous list omitted.
        for quant in ALL_QUANTIZATIONS {
            for idx in [IndexType::Hnsw, IndexType::Vamana] {
                let cost = estimate_cost(&CostModelInputs {
                    index_type: idx,
                    quantization: quant,
                    ..base_inputs()
                });
                assert!(
                    cost.predicted_recall >= 0.0 && cost.predicted_recall <= 1.0,
                    "recall out of bounds for {idx:?} + {quant:?}: {}",
                    cost.predicted_recall
                );
            }
        }
    }

    #[test]
    fn candidate_set_size_and_selectivity_passthrough() {
        let inputs = CostModelInputs {
            global_selectivity: 0.05,
            candidate_set_size: 50_000,
            ..base_inputs()
        };
        let cost = estimate_cost(&inputs);
        assert!((cost.filter_selectivity - 0.05).abs() < f32::EPSILON);
        assert_eq!(cost.candidate_set_size, 50_000);
    }
}