Skip to main content

nodedb_vector/rerank/
pipeline.rs

1// SPDX-License-Identifier: Apache-2.0
2
3use nodedb_types::vector_ann::VectorAnnOptions;
4use nodedb_types::vector_distance::DistanceMetric;
5
6use super::gating::codec_name_for_quant;
7use super::sidecar::CodecSidecar;
8use super::types::{Candidate, Ranked, RerankError};
9
10/// Shared rerank pipeline. Both Origin and Lite call this after their index-level coarse search.
11///
12/// Callers use `opts.oversample` to compute `fetch_k` before pre-fetching from HNSW;
13/// this function receives whatever candidates were fetched and reranks by exact distance.
14///
15/// When `opts.quantization` is `None` (or `VectorQuantization::None`), the FP32 path is used:
16/// `fetch_vector` is called once per candidate and must return the stored full-precision vector.
17/// Returning `None` for any id is a hard inconsistency error.
18///
19/// When `opts.quantization` is `Some(_)`, a `CodecSidecar` must be provided. The sidecar
20/// encodes query and stored vectors; `fetch_vector` is not called in this path.
21///
22/// When `opts.query_dim = Some(d)`, the FP32 path applies Matryoshka truncated-distance
23/// reranking using only the first `d` components. `d` must satisfy `0 < d <= query.len()`.
24/// `query_dim` combined with `quantization` is not supported — return `BadInput` if both set.
25///
26/// `target_recall`, `oversample`, and `meta_token_budget` are accepted via `opts` but not
27/// honored here — callers handle those before calling this function.
28pub fn rerank<'v, F>(
29    candidates: Vec<Candidate>,
30    query: &[f32],
31    metric: DistanceMetric,
32    k: usize,
33    opts: &VectorAnnOptions,
34    sidecar: Option<&CodecSidecar>,
35    mut fetch_vector: F,
36) -> Result<Vec<Ranked>, RerankError>
37where
38    F: FnMut(u32) -> Option<&'v [f32]>,
39{
40    if k == 0 {
41        return Err(RerankError::BadInput("k must be > 0".into()));
42    }
43    if query.is_empty() {
44        return Err(RerankError::BadInput("query is empty".into()));
45    }
46
47    // Determine requested codec (if any) from opts.
48    let requested_codec = opts.quantization.and_then(codec_name_for_quant);
49
50    // Part C: query_dim + quantization combination is not supported.
51    if opts.query_dim.is_some() && requested_codec.is_some() {
52        return Err(RerankError::BadInput(
53            "rerank: query_dim (Matryoshka truncation) is not yet supported in combination \
54             with quantization codecs — use one or the other"
55                .into(),
56        ));
57    }
58
59    if candidates.is_empty() {
60        return Ok(Vec::new());
61    }
62
63    // Codec path.
64    if let Some(requested) = requested_codec {
65        let sc = sidecar.ok_or_else(|| {
66            RerankError::BadInput(
67                "rerank: opts.quantization requested but no codec sidecar provided".into(),
68            )
69        })?;
70
71        let actual = sc.codec_name();
72        if actual != requested {
73            return Err(RerankError::BadInput(format!(
74                "rerank: requested codec {requested:?} does not match sidecar codec {actual:?}"
75            )));
76        }
77
78        let prepared = sc.prepare_query(query)?;
79
80        let mut scored: Vec<Ranked> = Vec::with_capacity(candidates.len());
81        for c in candidates {
82            match sc.distance_prepared(&prepared, c.id)? {
83                None => {
84                    return Err(RerankError::BadInput(format!(
85                        "rerank: candidate id {} not present in sidecar (index/sidecar drift)",
86                        c.id
87                    )));
88                }
89                Some(d) => {
90                    scored.push(Ranked {
91                        id: c.id,
92                        distance: d,
93                    });
94                }
95            }
96        }
97
98        scored.sort_unstable_by(|a, b| {
99            a.distance
100                .partial_cmp(&b.distance)
101                .unwrap_or(std::cmp::Ordering::Equal)
102        });
103        scored.truncate(k);
104
105        // Suppress unused-closure warning — fetch_vector is not used in codec path.
106        let _ = &mut fetch_vector;
107        return Ok(scored);
108    }
109
110    // FP32 path: validate query_dim before touching candidates.
111    let effective_dim: usize = match opts.query_dim {
112        Some(d) => {
113            let d = d as usize;
114            if d == 0 || d > query.len() {
115                return Err(RerankError::BadInput(format!(
116                    "query_dim={d} is out of range; query has {} dimensions \
117                     (must be 0 < query_dim <= query.len())",
118                    query.len(),
119                )));
120            }
121            d
122        }
123        None => query.len(),
124    };
125
126    // Truncate query once; candidates are sliced inline using the same length.
127    let query_slice = crate::matryoshka::truncate(query, effective_dim);
128
129    let mut scored: Vec<Ranked> = Vec::with_capacity(candidates.len());
130    let query_dim = query.len();
131
132    for c in candidates {
133        let vec = fetch_vector(c.id).ok_or_else(|| {
134            RerankError::BadInput(format!(
135                "rerank: fetch_vector returned None for id {}",
136                c.id
137            ))
138        })?;
139        if vec.len() != query_dim {
140            return Err(RerankError::BadInput(format!(
141                "candidate id={} has dim {} but query has dim {}",
142                c.id,
143                vec.len(),
144                query_dim,
145            )));
146        }
147        let vec_slice = crate::matryoshka::truncate(vec, effective_dim);
148        let d = crate::distance::distance(query_slice, vec_slice, metric);
149        scored.push(Ranked {
150            id: c.id,
151            distance: d,
152        });
153    }
154
155    scored.sort_unstable_by(|a, b| {
156        a.distance
157            .partial_cmp(&b.distance)
158            .unwrap_or(std::cmp::Ordering::Equal)
159    });
160    scored.truncate(k);
161
162    Ok(scored)
163}
164
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use std::sync::Arc;

    use nodedb_types::vector_ann::{VectorAnnOptions, VectorQuantization};

    use crate::rerank::codec::{CodecName, PreparedQuery, RerankCodec};
    use crate::rerank::sidecar::CodecSidecar;
    use crate::rerank::types::RerankError;

    // ── fixtures ─────────────────────────────────────────────────────────────

    /// In-memory stand-in for the full-precision vector storage.
    type VectorStore = HashMap<u32, Vec<f32>>;

    /// Default options: no quantization, no `query_dim`.
    fn plain_opts() -> VectorAnnOptions {
        VectorAnnOptions::default()
    }

    /// Default options with only `query_dim` set.
    fn dim_opts(dim: u32) -> VectorAnnOptions {
        VectorAnnOptions {
            query_dim: Some(dim),
            ..Default::default()
        }
    }

    /// Default options with only `quantization` set.
    fn quant_opts(quant: VectorQuantization) -> VectorAnnOptions {
        VectorAnnOptions {
            quantization: Some(quant),
            ..Default::default()
        }
    }

    /// A candidate whose coarse index distance is irrelevant to the test.
    fn cand(id: u32) -> Candidate {
        Candidate {
            id,
            index_distance: 0.0,
        }
    }

    /// One candidate per id, in the given order.
    fn cands(ids: &[u32]) -> Vec<Candidate> {
        ids.iter().copied().map(cand).collect()
    }

    /// Builds a store from `(id, vector)` pairs.
    fn vector_store(pairs: &[(u32, Vec<f32>)]) -> VectorStore {
        pairs.iter().map(|(id, v)| (*id, v.clone())).collect()
    }

    /// `fetch_vector` callback backed by `store`; unknown ids yield `None`.
    fn fetcher<'a>(store: &'a VectorStore) -> impl FnMut(u32) -> Option<&'a [f32]> {
        move |id| store.get(&id).map(Vec::as_slice)
    }

    /// Ids of a ranked result, in result order.
    fn ranked_ids(results: &[Ranked]) -> Vec<u32> {
        results.iter().map(|r| r.id).collect()
    }

    /// Stub codec that stores vectors as raw little-endian f32 bytes and scores them with
    /// Euclidean distance. It reports whichever `CodecName` it was constructed with, so a
    /// test can pair it with the matching `VectorQuantization` variant.
    struct StubCodec {
        name: CodecName,
    }

    impl RerankCodec for StubCodec {
        fn encode(&self, v: &[f32]) -> Result<Vec<u8>, RerankError> {
            let mut bytes = Vec::with_capacity(v.len() * 4);
            for x in v {
                bytes.extend_from_slice(&x.to_le_bytes());
            }
            Ok(bytes)
        }

        fn prepare_query(&self, q: &[f32]) -> Result<PreparedQuery, RerankError> {
            Ok(PreparedQuery::Raw(q.to_vec()))
        }

        fn distance_prepared(
            &self,
            prepared: &PreparedQuery,
            encoded: &[u8],
        ) -> Result<f32, RerankError> {
            // Only the raw representation produced by `prepare_query` above is understood.
            let query = if let PreparedQuery::Raw(raw) = prepared {
                raw
            } else {
                return Err(RerankError::BadInput(
                    "StubCodec expects Raw prepared query".into(),
                ));
            };

            // Decode 4 bytes at a time; any trailing partial chunk is ignored.
            let stored: Vec<f32> = encoded
                .chunks_exact(4)
                .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
                .collect();
            if query.len() != stored.len() {
                return Err(RerankError::BadInput("dimension mismatch".into()));
            }

            let sum_of_squares = query
                .iter()
                .zip(stored.iter())
                .map(|(a, b)| {
                    let diff = a - b;
                    diff * diff
                })
                .sum::<f32>();
            Ok(sum_of_squares.sqrt())
        }

        fn name(&self) -> CodecName {
            self.name
        }

        fn to_bytes(&self) -> Result<Vec<u8>, RerankError> {
            Err(RerankError::BadInput(
                "StubCodec does not support serialization".into(),
            ))
        }
    }

    /// Empty sidecar wrapping a `StubCodec` that reports `name`.
    fn stub_sidecar(name: CodecName) -> CodecSidecar {
        CodecSidecar::new(Arc::new(StubCodec { name }))
    }

    // ── FP32 path (no sidecar) ───────────────────────────────────────────────

    #[test]
    fn happy_path_top2() {
        let vectors = vector_store(&[
            (1, vec![1.0, 0.0]),
            (2, vec![0.1, 0.0]),
            (3, vec![0.5, 0.0]),
            (4, vec![2.0, 0.0]),
            (5, vec![0.3, 0.0]),
        ]);
        let top = rerank(
            cands(&[1, 2, 3, 4, 5]),
            &[0.0, 0.0],
            DistanceMetric::L2,
            2,
            &plain_opts(),
            None,
            fetcher(&vectors),
        )
        .unwrap();
        assert_eq!(top.len(), 2);
        // Nearest to the origin is id=2 (x = 0.1), followed by id=5 (x = 0.3).
        assert_eq!(top[0].id, 2);
        assert_eq!(top[1].id, 5);
    }

    #[test]
    fn empty_candidates_returns_empty() {
        let vectors = VectorStore::new();
        let top = rerank(
            Vec::new(),
            &[1.0, 2.0],
            DistanceMetric::L2,
            3,
            &plain_opts(),
            None,
            fetcher(&vectors),
        )
        .unwrap();
        assert!(top.is_empty());
    }

    #[test]
    fn dim_mismatch_returns_bad_input() {
        // Stored vector has 3 components, the query only 2.
        let vectors = vector_store(&[(7, vec![1.0, 2.0, 3.0])]);
        let msg = rerank(
            cands(&[7]),
            &[1.0, 2.0],
            DistanceMetric::L2,
            1,
            &plain_opts(),
            None,
            fetcher(&vectors),
        )
        .unwrap_err()
        .to_string();
        assert!(msg.contains("7"), "expected id in message: {msg}");
        assert!(
            msg.contains("3"),
            "expected candidate dim in message: {msg}"
        );
        assert!(msg.contains("2"), "expected query dim in message: {msg}");
    }

    #[test]
    fn k_zero_returns_bad_input() {
        let vectors = vector_store(&[(1, vec![1.0])]);
        let failure = rerank(
            cands(&[1]),
            &[0.0],
            DistanceMetric::L2,
            0,
            &plain_opts(),
            None,
            fetcher(&vectors),
        )
        .unwrap_err();
        assert!(failure.to_string().contains("k must be > 0"));
    }

    #[test]
    fn k_exceeds_candidates_returns_all() {
        let vectors = vector_store(&[
            (1, vec![1.0, 0.0]),
            (2, vec![2.0, 0.0]),
            (3, vec![3.0, 0.0]),
        ]);
        let top = rerank(
            cands(&[1, 2, 3]),
            &[0.0, 0.0],
            DistanceMetric::L2,
            10,
            &plain_opts(),
            None,
            fetcher(&vectors),
        )
        .unwrap();
        assert_eq!(top.len(), 3);
    }

    // ── query_dim (Matryoshka truncated-distance reranking) ──────────────────

    #[test]
    fn query_dim_truncated_ranking_differs_from_full() {
        // id=1 is close to the query in both components; id=2 matches the first
        // component exactly but is far away in the second.
        let vectors = vector_store(&[(1, vec![0.1, 0.1]), (2, vec![0.0, 9.0])]);
        let query = [0.0_f32, 1.0];
        let best_id = |ann_opts: &VectorAnnOptions| {
            let top = rerank(
                cands(&[1, 2]),
                &query,
                DistanceMetric::L2,
                1,
                ann_opts,
                None,
                fetcher(&vectors),
            )
            .unwrap();
            ranked_ids(&top)
        };

        assert_eq!(
            best_id(&plain_opts()),
            vec![1],
            "full-dim should rank id=1 first"
        );
        assert_eq!(
            best_id(&dim_opts(1)),
            vec![2],
            "truncated-dim=1 should rank id=2 first"
        );
    }

    #[test]
    fn query_dim_zero_returns_bad_input() {
        let vectors = vector_store(&[(1, vec![1.0, 2.0])]);
        let msg = rerank(
            cands(&[1]),
            &[0.0, 0.0],
            DistanceMetric::L2,
            1,
            &dim_opts(0),
            None,
            fetcher(&vectors),
        )
        .unwrap_err()
        .to_string();
        assert!(
            msg.contains("query_dim=0"),
            "error should name query_dim=0: {msg}"
        );
    }

    #[test]
    fn query_dim_exceeds_query_len_returns_bad_input() {
        let vectors = vector_store(&[(1, vec![1.0, 2.0])]);
        let msg = rerank(
            cands(&[1]),
            &[0.0, 0.0],
            DistanceMetric::L2,
            1,
            &dim_opts(5),
            None,
            fetcher(&vectors),
        )
        .unwrap_err()
        .to_string();
        assert!(
            msg.contains("query_dim=5"),
            "error should name query_dim=5: {msg}"
        );
        assert!(
            msg.contains('2'),
            "error should mention query len (2): {msg}"
        );
    }

    #[test]
    fn query_dim_equal_to_query_len_matches_full_dim() {
        let vectors = vector_store(&[
            (1, vec![1.0, 0.0, 0.0]),
            (2, vec![0.5, 0.0, 0.0]),
            (3, vec![3.0, 0.0, 0.0]),
        ]);
        let query = [0.0_f32, 0.0, 0.0];
        let ranking = |ann_opts: &VectorAnnOptions| {
            let top = rerank(
                cands(&[1, 2, 3]),
                &query,
                DistanceMetric::L2,
                3,
                ann_opts,
                None,
                fetcher(&vectors),
            )
            .unwrap();
            ranked_ids(&top)
        };

        let full_ids = ranking(&plain_opts());
        let trunc_ids = ranking(&dim_opts(3));
        assert_eq!(
            full_ids, trunc_ids,
            "query_dim == query.len() should produce identical ranking"
        );
    }

    #[test]
    fn fetch_returns_none_is_bad_input() {
        // Nothing is stored, so the lookup for id 99 must come back empty.
        let vectors = VectorStore::new();
        let err = rerank(
            cands(&[99]),
            &[0.0, 0.0],
            DistanceMetric::L2,
            1,
            &plain_opts(),
            None,
            fetcher(&vectors),
        )
        .unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("99"),
            "error should name the missing id (99): {msg}"
        );
        assert!(
            matches!(err, RerankError::BadInput(_)),
            "expected BadInput, got: {err}"
        );
    }

    // ── codec path ───────────────────────────────────────────────────────────

    /// The codec path scores through the sidecar and never touches `fetch_vector`.
    #[test]
    fn codec_path_uses_sidecar() {
        // The stub reports Binary, so the request must be VectorQuantization::Binary.
        let mut sidecar = stub_sidecar(CodecName::Binary);
        // Euclidean distances from [0, 0] are 1.0, 2.0 and 3.0 respectively.
        sidecar.encode_and_insert(1, &[1.0, 0.0]).unwrap();
        sidecar.encode_and_insert(2, &[0.0, 2.0]).unwrap();
        sidecar.encode_and_insert(3, &[3.0, 0.0]).unwrap();

        let ann_opts = quant_opts(VectorQuantization::Binary);

        // A panicking fetch closure proves the FP32 lookup is bypassed.
        let top = rerank(
            cands(&[1, 2, 3]),
            &[0.0, 0.0],
            DistanceMetric::L2,
            3,
            &ann_opts,
            Some(&sidecar),
            |_id| panic!("fetch_vector must not be called in codec path"),
        )
        .unwrap();

        assert_eq!(top.len(), 3);
        assert_eq!(top[0].id, 1);
        assert_eq!(top[1].id, 2);
        assert_eq!(top[2].id, 3);
        assert!((top[0].distance - 1.0).abs() < 1e-5);
    }

    /// A codec is requested via `opts` but no sidecar is supplied → BadInput.
    #[test]
    fn codec_requested_but_no_sidecar_returns_bad_input() {
        let vectors = VectorStore::new();
        let ann_opts = quant_opts(VectorQuantization::Binary);
        let err = rerank(
            cands(&[1]),
            &[0.0, 0.0],
            DistanceMetric::L2,
            1,
            &ann_opts,
            None,
            fetcher(&vectors),
        )
        .unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("no codec sidecar provided"),
            "expected sidecar-missing message: {msg}"
        );
        assert!(matches!(err, RerankError::BadInput(_)));
    }

    /// The sidecar's codec (Binary) differs from the requested one (Sq8) → BadInput.
    #[test]
    fn codec_name_mismatch_returns_bad_input() {
        let mut sidecar = stub_sidecar(CodecName::Binary);
        sidecar.encode_and_insert(1, &[1.0, 0.0]).unwrap();

        let ann_opts = quant_opts(VectorQuantization::Sq8);
        let vectors = VectorStore::new();
        let err = rerank(
            cands(&[1]),
            &[0.0, 0.0],
            DistanceMetric::L2,
            1,
            &ann_opts,
            Some(&sidecar),
            fetcher(&vectors),
        )
        .unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("Sq8") || msg.contains("sq8"),
            "expected requested codec in message: {msg}"
        );
        assert!(
            msg.contains("Binary") || msg.contains("binary"),
            "expected actual codec in message: {msg}"
        );
        assert!(matches!(err, RerankError::BadInput(_)));
    }

    /// Setting both `query_dim` and `quantization` is rejected → BadInput.
    #[test]
    fn codec_with_query_dim_returns_bad_input() {
        let mut sidecar = stub_sidecar(CodecName::Binary);
        sidecar.encode_and_insert(1, &[1.0, 0.0]).unwrap();

        let ann_opts = VectorAnnOptions {
            query_dim: Some(1),
            quantization: Some(VectorQuantization::Binary),
            ..Default::default()
        };
        let vectors = VectorStore::new();
        let err = rerank(
            cands(&[1]),
            &[0.0, 0.0],
            DistanceMetric::L2,
            1,
            &ann_opts,
            Some(&sidecar),
            fetcher(&vectors),
        )
        .unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("query_dim") && msg.contains("quantization"),
            "expected both terms in message: {msg}"
        );
        assert!(matches!(err, RerankError::BadInput(_)));
    }

    /// Without a quantization request the FP32 path runs even if a sidecar is passed.
    #[test]
    fn fp32_path_with_some_sidecar_argument() {
        let sidecar = stub_sidecar(CodecName::Binary);
        // The sidecar stays empty; only the FP32 store holds vectors.
        let vectors = vector_store(&[(1, vec![1.0, 0.0]), (2, vec![0.1, 0.0])]);
        let top = rerank(
            cands(&[1, 2]),
            &[0.0, 0.0],
            DistanceMetric::L2,
            2,
            &plain_opts(),
            Some(&sidecar),
            fetcher(&vectors),
        )
        .unwrap();
        // id=2 (x = 0.1) lies nearer the origin than id=1 (x = 1.0).
        assert_eq!(top[0].id, 2);
        assert_eq!(top[1].id, 1);
    }

    /// A candidate id that the sidecar does not hold → BadInput (index/sidecar drift).
    #[test]
    fn codec_path_missing_id_in_sidecar_returns_bad_input() {
        // Nothing is inserted, so id 99 cannot be found.
        let sidecar = stub_sidecar(CodecName::Binary);
        let ann_opts = quant_opts(VectorQuantization::Binary);
        let vectors = VectorStore::new();
        let err = rerank(
            cands(&[99]),
            &[0.0, 0.0],
            DistanceMetric::L2,
            1,
            &ann_opts,
            Some(&sidecar),
            fetcher(&vectors),
        )
        .unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("99"), "expected id 99 in message: {msg}");
        assert!(
            msg.contains("sidecar drift") || msg.contains("not present in sidecar"),
            "expected drift message: {msg}"
        );
        assert!(matches!(err, RerankError::BadInput(_)));
    }
}