Skip to main content

nodedb_vector/collection/
search.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! VectorCollection search: multi-segment merging with SQ8 reranking.
4//!
5//! The `search_with_payload_filter` method wires payload bitmap pre-filtering
6//! into the search path. When all referenced fields in the predicate are
7//! indexed, the bitmap is built and passed to `search_with_bitmap_bytes`.
8//! When any field is un-indexed, the search falls back to the full unfiltered
9//! path and lets the caller apply post-filtering — the un-indexed predicate is
10//! never silently dropped.
11
12use crate::distance::{DistanceMetric, distance};
13use crate::error::VectorError;
14use crate::hnsw::SearchResult;
15
16use super::lifecycle::VectorCollection;
17use super::payload_index::FilterPredicate;
18use super::segment::SealedSegment;
19
20/// Score a single candidate via the SQ8 codec, using the metric-appropriate
21/// asymmetric distance.
22#[inline]
23fn sq8_score(
24    codec: &crate::quantize::sq8::Sq8Codec,
25    query: &[f32],
26    encoded: &[u8],
27    metric: DistanceMetric,
28) -> f32 {
29    match metric {
30        DistanceMetric::Cosine => codec.asymmetric_cosine(query, encoded),
31        DistanceMetric::InnerProduct => codec.asymmetric_ip(query, encoded),
32        // L2 (and all other metrics that don't have a specialized asymmetric
33        // form yet) fall back to squared L2 — correct for ordering when the
34        // metric is L2 and a reasonable proxy otherwise since we rerank with
35        // exact FP32 below.
36        _ => codec.asymmetric_l2(query, encoded),
37    }
38}
39
40/// Candidate-generation + rerank for a sealed segment that has a quantized
41/// codec attached. Generates a widened candidate pool via HNSW, re-scores
42/// candidates using the quantized codec (this is where SQ8/PQ actually pay
43/// off — the FP32 vectors need not be resident), and reranks the top
44/// `top_k` via exact FP32 distance from mmap or index storage.
45fn quantized_search(
46    seg: &SealedSegment,
47    query: &[f32],
48    top_k: usize,
49    ef: usize,
50    metric: DistanceMetric,
51) -> Result<Vec<SearchResult>, VectorError> {
52    let rerank_k = top_k.saturating_mul(3).max(20);
53    let hnsw_candidates = seg.index.search(query, rerank_k, ef);
54
55    // Phase 1: rank candidates by quantized distance.
56    let mut scored: Vec<(u32, f32)> = if let Some((codec, codes)) = &seg.pq {
57        let table = codec.build_distance_table(query)?;
58        let m = codec.m;
59        hnsw_candidates
60            .into_iter()
61            .filter_map(|r| {
62                let start = (r.id as usize).checked_mul(m)?;
63                let end = start.checked_add(m)?;
64                let slice = codes.get(start..end)?;
65                Some((r.id, codec.asymmetric_distance(&table, slice)))
66            })
67            .collect()
68    } else if let Some((codec, data)) = &seg.sq8 {
69        let dim = codec.dim();
70        hnsw_candidates
71            .into_iter()
72            .filter_map(|r| {
73                let start = (r.id as usize).checked_mul(dim)?;
74                let end = start.checked_add(dim)?;
75                let slice = data.get(start..end)?;
76                Some((r.id, sq8_score(codec, query, slice, metric)))
77            })
78            .collect()
79    } else {
80        hnsw_candidates
81            .into_iter()
82            .map(|r| (r.id, r.distance))
83            .collect()
84    };
85    scored.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
86
87    // Keep only the most promising candidates for FP32 rerank.
88    let keep = rerank_k.min(scored.len());
89    scored.truncate(keep);
90
91    // Prefetch FP32 vectors for reranking.
92    if let Some(mmap) = &seg.mmap_vectors {
93        let ids: Vec<u32> = scored.iter().map(|&(id, _)| id).collect();
94        mmap.prefetch_batch(&ids);
95    }
96
97    // Phase 2: rerank with exact FP32.
98    let mut reranked: Vec<SearchResult> = scored
99        .into_iter()
100        .filter_map(|(id, _)| {
101            let v = if let Some(mmap) = &seg.mmap_vectors {
102                mmap.get_vector(id)?
103            } else {
104                seg.index.get_vector(id)?
105            };
106            Some(SearchResult {
107                id,
108                distance: distance(query, v, metric),
109            })
110        })
111        .collect();
112    reranked.sort_by(|a, b| {
113        a.distance
114            .partial_cmp(&b.distance)
115            .unwrap_or(std::cmp::Ordering::Equal)
116    });
117    reranked.truncate(top_k);
118    Ok(reranked)
119}
120
121impl VectorCollection {
122    /// Search across all segments, merging results by distance.
123    pub fn search(&self, query: &[f32], top_k: usize, ef: usize) -> Vec<SearchResult> {
124        // Codec-dispatch fast path: if a collection-level HnswCodecIndex has
125        // been built (RaBitQ or BBQ), use it exclusively for sealed-segment
126        // results and fall back to the growing/building flat segments only.
127        if let Some(ref dispatch) = self.codec_dispatch {
128            let mut all: Vec<SearchResult> = Vec::new();
129
130            let codec_results = dispatch.search(query, top_k, ef);
131            for r in codec_results {
132                all.push(SearchResult {
133                    id: r.id,
134                    distance: r.distance,
135                });
136            }
137
138            // Growing segment (brute-force, not yet in codec index).
139            let growing_results = self.growing.search(query, top_k);
140            for mut r in growing_results {
141                r.id += self.growing_base_id;
142                all.push(r);
143            }
144
145            // Building segments (brute-force while codec index rebuilds).
146            for seg in &self.building {
147                let results = seg.flat.search(query, top_k);
148                for mut r in results {
149                    r.id += seg.base_id;
150                    all.push(r);
151                }
152            }
153
154            all.sort_by(|a, b| {
155                a.distance
156                    .partial_cmp(&b.distance)
157                    .unwrap_or(std::cmp::Ordering::Equal)
158            });
159            all.truncate(top_k);
160            return all;
161        }
162
163        let mut all: Vec<SearchResult> = Vec::new();
164
165        // Search growing segment (brute-force).
166        let growing_results = self.growing.search(query, top_k);
167        for mut r in growing_results {
168            r.id += self.growing_base_id;
169            all.push(r);
170        }
171
172        // Search sealed segments.
173        for seg in &self.sealed {
174            let results = if seg.pq.is_some() || seg.sq8.is_some() {
175                match quantized_search(seg, query, top_k, ef, self.params.metric) {
176                    Ok(r) => r,
177                    Err(e) => {
178                        tracing::warn!(error = %e, "quantized_search budget exhausted; skipping segment");
179                        seg.index.search(query, top_k, ef)
180                    }
181                }
182            } else {
183                seg.index.search(query, top_k, ef)
184            };
185            for mut r in results {
186                r.id += seg.base_id;
187                all.push(r);
188            }
189        }
190
191        // Search building segments (brute-force while HNSW builds).
192        for seg in &self.building {
193            let results = seg.flat.search(query, top_k);
194            for mut r in results {
195                r.id += seg.base_id;
196                all.push(r);
197            }
198        }
199
200        all.sort_by(|a, b| {
201            a.distance
202                .partial_cmp(&b.distance)
203                .unwrap_or(std::cmp::Ordering::Equal)
204        });
205        all.truncate(top_k);
206        all
207    }
208
209    /// Search across all segments using an explicit metric override.
210    ///
211    /// For sealed segments with quantized codecs, the metric override is applied
212    /// during candidate reranking. Growing and building segments apply it exactly
213    /// via brute-force. The HNSW graph structure was built with the collection
214    /// metric; using a different metric affects the scoring but not graph traversal.
215    pub fn search_with_metric(
216        &self,
217        query: &[f32],
218        top_k: usize,
219        ef: usize,
220        metric: DistanceMetric,
221    ) -> Vec<SearchResult> {
222        // Codec-dispatch fast path: codec dispatch does not yet support per-query
223        // metric override — fall through to the non-codec path which does.
224        // When a codec index is active, we search only the growing/building
225        // segments with the override and add codec results with collection metric
226        // (approximate cross-metric search for the codec-indexed segments).
227        if let Some(ref dispatch) = self.codec_dispatch {
228            let mut all: Vec<SearchResult> = Vec::new();
229            let codec_results = dispatch.search(query, top_k, ef);
230            for r in codec_results {
231                all.push(SearchResult {
232                    id: r.id,
233                    distance: r.distance,
234                });
235            }
236            for mut r in self.growing.search_with_metric(query, top_k, metric) {
237                r.id += self.growing_base_id;
238                all.push(r);
239            }
240            for seg in &self.building {
241                for mut r in seg.flat.search_with_metric(query, top_k, metric) {
242                    r.id += seg.base_id;
243                    all.push(r);
244                }
245            }
246            all.sort_by(|a, b| {
247                a.distance
248                    .partial_cmp(&b.distance)
249                    .unwrap_or(std::cmp::Ordering::Equal)
250            });
251            all.truncate(top_k);
252            return all;
253        }
254
255        let mut all: Vec<SearchResult> = Vec::new();
256
257        for mut r in self.growing.search_with_metric(query, top_k, metric) {
258            r.id += self.growing_base_id;
259            all.push(r);
260        }
261
262        for seg in &self.sealed {
263            let results = if seg.pq.is_some() || seg.sq8.is_some() {
264                match quantized_search(seg, query, top_k, ef, metric) {
265                    Ok(r) => r,
266                    Err(e) => {
267                        tracing::warn!(error = %e, "quantized_search budget exhausted; skipping segment");
268                        seg.index.search(query, top_k, ef)
269                    }
270                }
271            } else {
272                seg.index.search(query, top_k, ef)
273            };
274            for mut r in results {
275                r.id += seg.base_id;
276                all.push(r);
277            }
278        }
279
280        for seg in &self.building {
281            for mut r in seg.flat.search_with_metric(query, top_k, metric) {
282                r.id += seg.base_id;
283                all.push(r);
284            }
285        }
286
287        all.sort_by(|a, b| {
288            a.distance
289                .partial_cmp(&b.distance)
290                .unwrap_or(std::cmp::Ordering::Equal)
291        });
292        all.truncate(top_k);
293        all
294    }
295
296    /// Search with a pre-filter bitmap (byte-array format) and explicit metric override.
297    pub fn search_with_bitmap_bytes_and_metric(
298        &self,
299        query: &[f32],
300        top_k: usize,
301        ef: usize,
302        bitmap: &[u8],
303        metric: DistanceMetric,
304    ) -> Vec<SearchResult> {
305        let mut all: Vec<SearchResult> = Vec::new();
306
307        let growing_results = self.growing.search_filtered_offset_with_metric(
308            query,
309            top_k,
310            bitmap,
311            self.growing_base_id,
312            metric,
313        );
314        for mut r in growing_results {
315            r.id += self.growing_base_id;
316            all.push(r);
317        }
318
319        for seg in &self.sealed {
320            let results =
321                seg.index
322                    .search_with_bitmap_bytes_offset(query, top_k, ef, bitmap, seg.base_id);
323            for mut r in results {
324                // Rerank with the requested metric using the stored FP32 vector.
325                if let Some(v) = seg.index.get_vector(r.id.wrapping_sub(seg.base_id)) {
326                    r.distance = crate::distance::distance(query, v, metric);
327                }
328                r.id += seg.base_id;
329                all.push(r);
330            }
331        }
332
333        for seg in &self.building {
334            let results = seg.flat.search_filtered_offset_with_metric(
335                query,
336                top_k,
337                bitmap,
338                seg.base_id,
339                metric,
340            );
341            for mut r in results {
342                r.id += seg.base_id;
343                all.push(r);
344            }
345        }
346
347        all.sort_by(|a, b| {
348            a.distance
349                .partial_cmp(&b.distance)
350                .unwrap_or(std::cmp::Ordering::Equal)
351        });
352        all.truncate(top_k);
353        all
354    }
355
356    /// Search with a pre-filter bitmap (byte-array format).
357    pub fn search_with_bitmap_bytes(
358        &self,
359        query: &[f32],
360        top_k: usize,
361        ef: usize,
362        bitmap: &[u8],
363    ) -> Vec<SearchResult> {
364        let mut all: Vec<SearchResult> = Vec::new();
365
366        let growing_results =
367            self.growing
368                .search_filtered_offset(query, top_k, bitmap, self.growing_base_id);
369        for mut r in growing_results {
370            r.id += self.growing_base_id;
371            all.push(r);
372        }
373
374        for seg in &self.sealed {
375            let results =
376                seg.index
377                    .search_with_bitmap_bytes_offset(query, top_k, ef, bitmap, seg.base_id);
378            for mut r in results {
379                r.id += seg.base_id;
380                all.push(r);
381            }
382        }
383
384        for seg in &self.building {
385            let results = seg
386                .flat
387                .search_filtered_offset(query, top_k, bitmap, seg.base_id);
388            for mut r in results {
389                r.id += seg.base_id;
390                all.push(r);
391            }
392        }
393
394        all.sort_by(|a, b| {
395            a.distance
396                .partial_cmp(&b.distance)
397                .unwrap_or(std::cmp::Ordering::Equal)
398        });
399        all.truncate(top_k);
400        all
401    }
402
403    /// Search with a structured payload predicate.
404    ///
405    /// If `predicate` is fully covered by indexed fields (all leaf fields have
406    /// a bitmap index), the bitmap is built and HNSW traversal uses it as a
407    /// pre-filter.
408    ///
409    /// If any field in `predicate` is un-indexed, the method returns
410    /// `(results, false)` where `false` signals that the predicate was NOT
411    /// applied and the caller must apply it as a post-filter. This guarantees
412    /// the un-indexed predicate is never silently dropped.
413    ///
414    /// Returns `(results, filter_was_applied)`.
415    pub fn search_with_payload_filter(
416        &self,
417        query: &[f32],
418        top_k: usize,
419        ef: usize,
420        predicate: &FilterPredicate,
421    ) -> (Vec<SearchResult>, bool) {
422        match self.payload.pre_filter(predicate) {
423            Some(bm) => {
424                // Serialize the bitmap to the byte format expected by
425                // `search_with_bitmap_bytes`.
426                let mut bm_bytes = Vec::new();
427                if bm.serialize_into(&mut bm_bytes).is_ok() {
428                    let results = self.search_with_bitmap_bytes(query, top_k, ef, &bm_bytes);
429                    (results, true)
430                } else {
431                    // Serialization failure: fall back to unfiltered search.
432                    (self.search(query, top_k, ef), false)
433                }
434            }
435            None => {
436                // Un-indexed field present: full scan, caller must post-filter.
437                (self.search(query, top_k, ef), false)
438            }
439        }
440    }
441}
442
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use crate::collection::lifecycle::VectorCollection;
    use crate::collection::segment::DEFAULT_SEAL_THRESHOLD;
    use crate::distance::DistanceMetric;
    use crate::hnsw::{HnswIndex, HnswParams};
    use crate::quantize::sq8::Sq8Codec;

    /// Default HNSW parameters with the metric pinned to L2.
    fn l2_params() -> HnswParams {
        HnswParams {
            metric: DistanceMetric::L2,
            ..HnswParams::default()
        }
    }

    /// Small-graph L2 parameters used by the codec-dispatch tests.
    fn small_l2_params() -> HnswParams {
        HnswParams {
            m: 8,
            ef_construction: 50,
            ..l2_params()
        }
    }

    /// Insert one point per id: `[id, 0.0, ..]` with `dim` components.
    fn insert_axis_points(coll: &mut VectorCollection, dim: usize, ids: std::ops::Range<usize>) {
        for id in ids {
            let mut point = vec![0.0f32; dim];
            point[0] = id as f32;
            coll.insert(point);
        }
    }

    /// Seal the growing segment under `key`, then build its HNSW index
    /// synchronously from the seal request and promote it to sealed.
    fn seal_and_build(coll: &mut VectorCollection, key: &'static str) {
        let request = coll.seal(key).unwrap();
        let mut graph = HnswIndex::new(request.dim, request.params);
        for vector in &request.vectors {
            graph.insert(vector.clone()).unwrap();
        }
        coll.complete_build(request.segment_id, graph);
    }

    fn make_collection() -> VectorCollection {
        VectorCollection::new(3, l2_params())
    }

    #[test]
    fn insert_and_search() {
        let mut coll = make_collection();
        insert_axis_points(&mut coll, 3, 0..100);
        assert_eq!(coll.len(), 100);

        let hits = coll.search(&[50.0, 0.0, 0.0], 3, 64);
        assert_eq!(hits.len(), 3);
        assert_eq!(hits[0].id, 50);
    }

    #[test]
    fn seal_moves_to_building() {
        let mut coll = VectorCollection::new(2, HnswParams::default());
        insert_axis_points(&mut coll, 2, 0..DEFAULT_SEAL_THRESHOLD);
        assert!(coll.needs_seal());

        let request = coll.seal("test_key").unwrap();
        assert_eq!(request.vectors.len(), DEFAULT_SEAL_THRESHOLD);
        assert_eq!(coll.building.len(), 1);
        assert_eq!(coll.growing.len(), 0);

        let hits = coll.search(&[100.0, 0.0], 1, 64);
        assert!(!hits.is_empty());
    }

    #[test]
    fn complete_build_promotes_to_sealed() {
        let mut coll = VectorCollection::new(2, HnswParams::default());
        insert_axis_points(&mut coll, 2, 0..100);
        seal_and_build(&mut coll, "test");

        assert_eq!(coll.building.len(), 0);
        assert_eq!(coll.sealed.len(), 1);

        let hits = coll.search(&[50.0, 0.0], 3, 64);
        assert!(!hits.is_empty());
    }

    #[test]
    fn multi_segment_search_merges() {
        let mut coll = VectorCollection::new(2, l2_params());

        // First 100 points end up in a sealed segment, the next 100 stay growing.
        insert_axis_points(&mut coll, 2, 0..100);
        seal_and_build(&mut coll, "test");
        insert_axis_points(&mut coll, 2, 100..200);

        let hits = coll.search(&[150.0, 0.0], 3, 64);
        assert_eq!(hits.len(), 3);
        assert_eq!(hits[0].id, 150);
    }

    #[test]
    fn delete_across_segments() {
        let mut coll = VectorCollection::new(2, HnswParams::default());
        insert_axis_points(&mut coll, 2, 0..10);
        assert!(coll.delete(5));
        assert_eq!(coll.live_count(), 9);

        let hits = coll.search(&[5.0, 0.0], 10, 64);
        assert!(hits.iter().all(|hit| hit.id != 5));
    }

    /// Collection with one sealed L2 segment holding `n` points `[i, 0.0]`.
    fn make_sealed_collection(n: usize) -> VectorCollection {
        let mut coll = VectorCollection::new(2, l2_params());
        insert_axis_points(&mut coll, 2, 0..n);
        seal_and_build(&mut coll, "seg");
        coll
    }

    /// Calibrate an SQ8 codec on the first sealed segment's vectors and attach
    /// the resulting codes to that segment.
    fn attach_sq8(coll: &mut VectorCollection) {
        let segment = &mut coll.sealed[0];
        let dim = segment.index.dim();
        let originals: Vec<Vec<f32>> = (0..segment.index.len())
            .filter_map(|i| segment.index.get_vector(i as u32).map(|v| v.to_vec()))
            .collect();
        let slices: Vec<&[f32]> = originals.iter().map(Vec::as_slice).collect();
        let codec = Sq8Codec::calibrate(&slices, dim);
        let codes: Vec<u8> = originals.iter().flat_map(|v| codec.quantize(v)).collect();
        segment.sq8 = Some((codec, codes));
    }

    #[test]
    fn sq8_search_returns_correct_nearest_neighbor() {
        let mut coll = make_sealed_collection(200);
        attach_sq8(&mut coll);

        let hits = coll.search(&[100.0, 0.0], 5, 64);
        assert!(!hits.is_empty(), "expected non-empty results");
        assert_eq!(
            hits[0].id, 100,
            "nearest neighbor of [100,0] should be id=100, got id={}",
            hits[0].id
        );
    }

    #[test]
    fn sq8_search_recall_matches_hnsw() {
        // Two identical collections; only the second one gets SQ8 codes.
        let plain = make_sealed_collection(500);
        let mut quantized = make_sealed_collection(500);
        attach_sq8(&mut quantized);

        let query = [250.0f32, 0.0];
        let top_k = 5;

        let plain_ids: HashSet<u32> = plain
            .search(&query, top_k, 64)
            .iter()
            .map(|hit| hit.id)
            .collect();
        let sq8_ids: HashSet<u32> = quantized
            .search(&query, top_k, 64)
            .iter()
            .map(|hit| hit.id)
            .collect();

        let overlap = plain_ids.intersection(&sq8_ids).count();
        assert!(
            overlap >= 4,
            "SQ8 recall too low: {overlap}/5 results matched plain HNSW (need >=4)"
        );
    }

    #[test]
    fn codec_dispatch_bbq_search_returns_results_and_stats_report_bbq() {
        let dim = 4;
        let mut coll = VectorCollection::new(dim, small_l2_params());

        // 50 points on the first axis: point i = [i, 0, 0, 0].
        insert_axis_points(&mut coll, dim, 0..50);

        // Build the collection-level BBQ dispatch index over current vectors.
        let built = coll.build_codec_dispatch("bbq");
        assert!(
            built.is_some(),
            "build_codec_dispatch(bbq) should return Some"
        );

        // Query next to id=25.
        let query = [25.0f32, 0.0, 0.0, 0.0];
        let hits = coll.search(&query, 5, 32);
        assert!(
            !hits.is_empty(),
            "BBQ codec-dispatch search should return results"
        );

        // Stats must now report BBQ quantization.
        let stats = coll.stats();
        assert_eq!(
            stats.quantization,
            nodedb_types::VectorIndexQuantization::Bbq,
            "stats quantization should be Bbq after build_codec_dispatch(bbq)"
        );
    }

    #[test]
    fn codec_dispatch_rabitq_search_non_empty() {
        let dim = 4;
        let mut coll = VectorCollection::new(dim, small_l2_params());
        insert_axis_points(&mut coll, dim, 0..50);
        coll.build_codec_dispatch("rabitq").unwrap();

        let hits = coll.search(&[10.0, 0.0, 0.0, 0.0], 3, 32);
        assert!(
            !hits.is_empty(),
            "RaBitQ dispatch search should return results"
        );

        let stats = coll.stats();
        assert_eq!(
            stats.quantization,
            nodedb_types::VectorIndexQuantization::RaBitQ
        );
    }

    #[test]
    fn sq8_search_does_not_scan_all_vectors() {
        // Guards the SQ8 search path on a large segment. The regression this
        // protects against is an O(N) linear scan replacing graph-guided
        // traversal; whatever the implementation, returning the correct
        // nearest neighbour is the invariant checked here.
        let mut coll = make_sealed_collection(2000);
        attach_sq8(&mut coll);

        let hits = coll.search(&[1000.0, 0.0], 5, 64);
        assert!(!hits.is_empty(), "expected non-empty results");
        assert_eq!(
            hits[0].id, 1000,
            "nearest neighbor of [1000,0] should be id=1000, got id={}",
            hits[0].id
        );
    }
}