Skip to main content

nodedb_vector/rerank/
gating.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! Option-combination gating for `rerank`.
4//!
5//! Validates a [`VectorAnnOptions`] request against the index shape and returns
6//! the [`CodecName`] the search should use (or `None` for FP32-only), surfacing
7//! unsupported combinations as [`RerankError::BadInput`] with precise messages.
8
9use nodedb_types::vector_ann::{VectorAnnOptions, VectorQuantization};
10
11use super::codec::CodecName;
12use super::types::RerankError;
13
/// Layout of the vector index a search targets. [`validate_options`] consults
/// it to reject option combinations the index cannot serve.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum IndexShape {
    /// An index holding a single vector per entry.
    SingleVector,
    /// An index holding multiple vectors per entry (the multi-vector /
    /// MetaEmbed layout that `meta_token_budget` refers to).
    MultiVector,
}
21
22/// Validate the option combo against the index shape and the collection's
23/// configured quantization. Returns the [`CodecName`] the search should use
24/// for rerank (`None` when the request is FP32-only), or
25/// [`RerankError::BadInput`] when the combination is invalid.
26///
27/// # Quantization contract
28///
29/// The codec is fixed at collection-creation time. `collection_quant` is the
30/// quantization that was declared via DDL; `opts.quantization` is the optional
31/// search-time override.
32///
33/// - If `opts.quantization` is `None`: honor `collection_quant` — map it to a
34///   `CodecName` via `codec_name_for_quant`. Ternary / OPQ collection configs
35///   still surface as `BadInput` because they have no HNSW-integration path.
36/// - If `opts.quantization` is `Some(q)` and `q == collection_quant`: proceed
37///   (same as old behavior — the caller is being explicit about what the index
38///   already uses).
39/// - If `opts.quantization` is `Some(q)` and `q != collection_quant`: return
40///   `RerankError::BadInput` naming both the requested codec and the collection's
41///   configured codec. Silent fallback is never allowed.
42/// - `Some(VectorQuantization::None)` against any non-`None` `collection_quant`
43///   is also a contradiction and returns `BadInput`.
44pub fn validate_options(
45    opts: &VectorAnnOptions,
46    index_shape: IndexShape,
47    collection_quant: VectorQuantization,
48) -> Result<Option<CodecName>, RerankError> {
49    validate_meta_token_budget(opts, index_shape)?;
50    validate_quantization_with_collection(opts, collection_quant)
51}
52
53fn validate_meta_token_budget(
54    opts: &VectorAnnOptions,
55    index_shape: IndexShape,
56) -> Result<(), RerankError> {
57    if opts.meta_token_budget.is_none() {
58        return Ok(());
59    }
60    match index_shape {
61        IndexShape::SingleVector => Err(RerankError::BadInput(
62            "meta_token_budget requires a multi-vector (MetaEmbed) index; \
63             the target collection is single-vector. \
64             Multi-vector indexes are not yet available in this deployment."
65                .to_owned(),
66        )),
67        IndexShape::MultiVector => Err(RerankError::BadInput(
68            "meta_token_budget routing not yet implemented; \
69             multi-vector indexes exist but PLAID/MaxSim dispatch is not wired."
70                .to_owned(),
71        )),
72    }
73}
74
75/// Map a `VectorQuantization` variant to its `CodecName`, returning `None`
76/// for variants that have no codec path (i.e. `None` / `VectorQuantization::None`).
77/// Variants that are not yet routable (Ternary, Opq, unknown) return `None`
78/// because their error is surfaced by `validate_quantization` — callers that
79/// need error surfacing should use `validate_options` instead.
80pub(crate) fn codec_name_for_quant(q: VectorQuantization) -> Option<CodecName> {
81    match q {
82        VectorQuantization::None => None,
83        VectorQuantization::Sq8 => Some(CodecName::Sq8),
84        VectorQuantization::Pq => Some(CodecName::Pq),
85        VectorQuantization::Binary => Some(CodecName::Binary),
86        VectorQuantization::RaBitQ => Some(CodecName::RaBitQ),
87        VectorQuantization::Bbq => Some(CodecName::Bbq),
88        // Not yet routable — validate_options surfaces a precise error.
89        _ => None,
90    }
91}
92
93/// Validate the search-time quantization against the collection's configured
94/// codec, and surface a precise `BadInput` on any mismatch. No silent fallback.
95fn validate_quantization_with_collection(
96    opts: &VectorAnnOptions,
97    collection_quant: VectorQuantization,
98) -> Result<Option<CodecName>, RerankError> {
99    match opts.quantization {
100        // Caller did not specify a codec: honor whatever the collection was
101        // built with. Ternary / OPQ are still unroutable even at this level.
102        None => map_collection_quant(collection_quant),
103
104        // Caller explicitly requested "no quantization" (FP32 path).
105        Some(VectorQuantization::None) => {
106            if collection_quant != VectorQuantization::None {
107                return Err(RerankError::BadInput(format!(
108                    "search-time quantization 'None' does not match collection's configured \
109                     quantization '{collection_quant:?}'; the codec is fixed at \
110                     collection-creation time"
111                )));
112            }
113            Ok(None)
114        }
115
116        // Caller specified a concrete codec.
117        Some(requested) => {
118            if requested != collection_quant {
119                return Err(RerankError::BadInput(format!(
120                    "search-time quantization '{requested:?}' does not match collection's \
121                     configured quantization '{collection_quant:?}'; the codec is fixed at \
122                     collection-creation time"
123                )));
124            }
125            // The requested codec matches the collection config — validate it is routable.
126            map_collection_quant(collection_quant)
127        }
128    }
129}
130
131/// Map a `VectorQuantization` that matches (or defaults from) the collection's
132/// config to a `CodecName`. Returns `BadInput` for unroutable variants.
133fn map_collection_quant(q: VectorQuantization) -> Result<Option<CodecName>, RerankError> {
134    match q {
135        VectorQuantization::None => Ok(None),
136        VectorQuantization::Sq8 => Ok(Some(CodecName::Sq8)),
137        VectorQuantization::Pq => Ok(Some(CodecName::Pq)),
138        VectorQuantization::Binary => Ok(Some(CodecName::Binary)),
139        VectorQuantization::RaBitQ => Ok(Some(CodecName::RaBitQ)),
140        VectorQuantization::Bbq => Ok(Some(CodecName::Bbq)),
141        VectorQuantization::Ternary => Err(RerankError::BadInput(
142            "quantization=ternary: codec exists in nodedb-codec but has no HNSW-integration \
143             path in nodedb-vector; cannot serve a search request with ternary quantization \
144             until the index-side wiring lands."
145                .to_owned(),
146        )),
147        VectorQuantization::Opq => Err(RerankError::BadInput(
148            "quantization=opq: codec exists in nodedb-codec but has no HNSW-integration \
149             path in nodedb-vector; cannot serve a search request with opq quantization \
150             until the index-side wiring lands."
151                .to_owned(),
152        )),
153        // Safety net for future non_exhaustive variants added to VectorQuantization
154        // before nodedb-vector is updated. Treat as unroutable until wired.
155        _ => Err(RerankError::BadInput(
156            "quantization variant is not yet routable in nodedb-vector; \
157             update gating.rs when the HNSW-integration path lands."
158                .to_owned(),
159        )),
160    }
161}
162
#[cfg(test)]
mod tests {
    use super::*;

    /// Options with only the search-time quantization set.
    fn opts_with_quant(q: Option<VectorQuantization>) -> VectorAnnOptions {
        VectorAnnOptions {
            quantization: q,
            ..Default::default()
        }
    }

    /// Options with a `meta_token_budget` and a search-time quantization set.
    fn opts_with_budget(budget: u8, q: Option<VectorQuantization>) -> VectorAnnOptions {
        VectorAnnOptions {
            meta_token_budget: Some(budget),
            quantization: q,
            ..Default::default()
        }
    }

    // ── Per-codec routing and meta_token_budget gating ────────────────────────

    #[test]
    fn none_quantization_returns_none() {
        let result = validate_options(
            &opts_with_quant(None),
            IndexShape::SingleVector,
            VectorQuantization::None,
        );
        assert_eq!(result.unwrap(), None);
    }

    #[test]
    fn explicit_none_quantization_returns_none() {
        let result = validate_options(
            &opts_with_quant(Some(VectorQuantization::None)),
            IndexShape::SingleVector,
            VectorQuantization::None,
        );
        assert_eq!(result.unwrap(), None);
    }

    #[test]
    fn sq8_returns_codec() {
        let result = validate_options(
            &opts_with_quant(Some(VectorQuantization::Sq8)),
            IndexShape::SingleVector,
            VectorQuantization::Sq8,
        );
        assert_eq!(result.unwrap(), Some(CodecName::Sq8));
    }

    #[test]
    fn pq_returns_codec() {
        let result = validate_options(
            &opts_with_quant(Some(VectorQuantization::Pq)),
            IndexShape::SingleVector,
            VectorQuantization::Pq,
        );
        assert_eq!(result.unwrap(), Some(CodecName::Pq));
    }

    #[test]
    fn binary_returns_codec() {
        let result = validate_options(
            &opts_with_quant(Some(VectorQuantization::Binary)),
            IndexShape::SingleVector,
            VectorQuantization::Binary,
        );
        assert_eq!(result.unwrap(), Some(CodecName::Binary));
    }

    #[test]
    fn rabitq_returns_codec() {
        let result = validate_options(
            &opts_with_quant(Some(VectorQuantization::RaBitQ)),
            IndexShape::SingleVector,
            VectorQuantization::RaBitQ,
        );
        assert_eq!(result.unwrap(), Some(CodecName::RaBitQ));
    }

    #[test]
    fn bbq_returns_codec() {
        let result = validate_options(
            &opts_with_quant(Some(VectorQuantization::Bbq)),
            IndexShape::SingleVector,
            VectorQuantization::Bbq,
        );
        assert_eq!(result.unwrap(), Some(CodecName::Bbq));
    }

    #[test]
    fn ternary_returns_bad_input() {
        let err = validate_options(
            &opts_with_quant(Some(VectorQuantization::Ternary)),
            IndexShape::SingleVector,
            VectorQuantization::Ternary,
        )
        .unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("ternary"), "expected 'ternary' in: {msg}");
        assert!(matches!(err, RerankError::BadInput(_)));
    }

    #[test]
    fn opq_returns_bad_input() {
        let err = validate_options(
            &opts_with_quant(Some(VectorQuantization::Opq)),
            IndexShape::SingleVector,
            VectorQuantization::Opq,
        )
        .unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("opq"), "expected 'opq' in: {msg}");
        assert!(matches!(err, RerankError::BadInput(_)));
    }

    #[test]
    fn meta_token_budget_single_vec_returns_bad_input() {
        let err = validate_options(
            &opts_with_budget(8, None),
            IndexShape::SingleVector,
            VectorQuantization::None,
        )
        .unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("single-vector"),
            "expected 'single-vector' in: {msg}"
        );
        assert!(matches!(err, RerankError::BadInput(_)));
    }

    #[test]
    fn meta_token_budget_multi_vec_returns_bad_input() {
        let err = validate_options(
            &opts_with_budget(8, None),
            IndexShape::MultiVector,
            VectorQuantization::None,
        )
        .unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("PLAID") || msg.contains("MaxSim"),
            "expected 'PLAID' or 'MaxSim' in: {msg}"
        );
        assert!(matches!(err, RerankError::BadInput(_)));
    }

    #[test]
    fn meta_token_budget_none_passes_with_sq8() {
        // The budget is set to `None` explicitly so the test name is visibly true.
        let opts = VectorAnnOptions {
            meta_token_budget: None,
            quantization: Some(VectorQuantization::Sq8),
            ..Default::default()
        };
        let result = validate_options(&opts, IndexShape::SingleVector, VectorQuantization::Sq8);
        assert_eq!(result.unwrap(), Some(CodecName::Sq8));
    }

    #[test]
    fn meta_token_budget_is_checked_before_quantization() {
        // Budget set AND quantization mismatched: the budget error is reported.
        let err = validate_options(
            &opts_with_budget(8, Some(VectorQuantization::Sq8)),
            IndexShape::SingleVector,
            VectorQuantization::Pq,
        )
        .unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("meta_token_budget"),
            "expected the meta_token_budget error first: {msg}"
        );
        assert!(matches!(err, RerankError::BadInput(_)));
    }

    // ── Mismatch / collection-default behaviour ───────────────────────────────

    #[test]
    fn quantization_mismatch_returns_bad_input() {
        let opts = opts_with_quant(Some(VectorQuantization::Sq8));
        let err =
            validate_options(&opts, IndexShape::SingleVector, VectorQuantization::Pq).unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("Sq8") && msg.contains("Pq"),
            "message must name both requested and configured codec: {msg}"
        );
        assert!(matches!(err, RerankError::BadInput(_)));
    }

    #[test]
    fn quantization_matches_collection_passes() {
        let opts = opts_with_quant(Some(VectorQuantization::RaBitQ));
        let result = validate_options(&opts, IndexShape::SingleVector, VectorQuantization::RaBitQ);
        assert_eq!(result.unwrap(), Some(CodecName::RaBitQ));
    }

    #[test]
    fn quantization_none_with_collection_codec_uses_collection_codec() {
        // Caller didn't specify; collection was built with Sq8 → return Sq8.
        let opts = opts_with_quant(None);
        let result = validate_options(&opts, IndexShape::SingleVector, VectorQuantization::Sq8);
        assert_eq!(result.unwrap(), Some(CodecName::Sq8));
    }

    #[test]
    fn quantization_none_with_collection_none_returns_none() {
        // Both unset → FP32-only path.
        let opts = opts_with_quant(None);
        let result = validate_options(&opts, IndexShape::SingleVector, VectorQuantization::None);
        assert_eq!(result.unwrap(), None);
    }

    #[test]
    fn explicit_none_against_sq8_collection_returns_bad_input() {
        // Requesting "no codec" against a collection configured with Sq8 is contradictory.
        let opts = opts_with_quant(Some(VectorQuantization::None));
        let err =
            validate_options(&opts, IndexShape::SingleVector, VectorQuantization::Sq8).unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("None") && msg.contains("Sq8"),
            "message must name both requested 'None' and collection's 'Sq8': {msg}"
        );
        assert!(matches!(err, RerankError::BadInput(_)));
    }

    #[test]
    fn quantization_none_with_ternary_collection_returns_bad_input() {
        // Caller didn't specify, but the collection default is unroutable.
        let err = validate_options(
            &opts_with_quant(None),
            IndexShape::SingleVector,
            VectorQuantization::Ternary,
        )
        .unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("ternary"), "expected 'ternary' in: {msg}");
        assert!(matches!(err, RerankError::BadInput(_)));
    }

    #[test]
    fn quantization_none_with_opq_collection_returns_bad_input() {
        // Caller didn't specify, but the collection default is unroutable.
        let err = validate_options(
            &opts_with_quant(None),
            IndexShape::SingleVector,
            VectorQuantization::Opq,
        )
        .unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("opq"), "expected 'opq' in: {msg}");
        assert!(matches!(err, RerankError::BadInput(_)));
    }

    #[test]
    fn concrete_request_against_ternary_collection_reports_mismatch() {
        // The mismatch check runs before routability, so the caller learns
        // about the contradiction rather than about the ternary gap.
        let err = validate_options(
            &opts_with_quant(Some(VectorQuantization::Sq8)),
            IndexShape::SingleVector,
            VectorQuantization::Ternary,
        )
        .unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("Sq8") && msg.contains("Ternary"),
            "message must name both requested and configured codec: {msg}"
        );
        assert!(matches!(err, RerankError::BadInput(_)));
    }

    // ── codec_name_for_quant ──────────────────────────────────────────────────

    #[test]
    fn codec_name_for_quant_agrees_with_collection_default() {
        for q in [
            VectorQuantization::None,
            VectorQuantization::Sq8,
            VectorQuantization::Pq,
            VectorQuantization::Binary,
            VectorQuantization::RaBitQ,
            VectorQuantization::Bbq,
        ] {
            let expected =
                validate_options(&opts_with_quant(None), IndexShape::SingleVector, q).unwrap();
            assert_eq!(codec_name_for_quant(q), expected, "disagreement for {q:?}");
        }
    }

    #[test]
    fn codec_name_for_quant_unroutable_returns_none() {
        assert_eq!(codec_name_for_quant(VectorQuantization::Ternary), None);
        assert_eq!(codec_name_for_quant(VectorQuantization::Opq), None);
    }
}