1use crate::distance::{DistanceMetric, distance};
13use crate::error::VectorError;
14use crate::hnsw::SearchResult;
15
16use super::lifecycle::VectorCollection;
17use super::payload_index::FilterPredicate;
18use super::segment::SealedSegment;
19
20#[inline]
23fn sq8_score(
24 codec: &crate::quantize::sq8::Sq8Codec,
25 query: &[f32],
26 encoded: &[u8],
27 metric: DistanceMetric,
28) -> f32 {
29 match metric {
30 DistanceMetric::Cosine => codec.asymmetric_cosine(query, encoded),
31 DistanceMetric::InnerProduct => codec.asymmetric_ip(query, encoded),
32 _ => codec.asymmetric_l2(query, encoded),
37 }
38}
39
40fn quantized_search(
46 seg: &SealedSegment,
47 query: &[f32],
48 top_k: usize,
49 ef: usize,
50 metric: DistanceMetric,
51) -> Result<Vec<SearchResult>, VectorError> {
52 let rerank_k = top_k.saturating_mul(3).max(20);
53 let hnsw_candidates = seg.index.search(query, rerank_k, ef);
54
55 let mut scored: Vec<(u32, f32)> = if let Some((codec, codes)) = &seg.pq {
57 let table = codec.build_distance_table(query)?;
58 let m = codec.m;
59 hnsw_candidates
60 .into_iter()
61 .filter_map(|r| {
62 let start = (r.id as usize).checked_mul(m)?;
63 let end = start.checked_add(m)?;
64 let slice = codes.get(start..end)?;
65 Some((r.id, codec.asymmetric_distance(&table, slice)))
66 })
67 .collect()
68 } else if let Some((codec, data)) = &seg.sq8 {
69 let dim = codec.dim();
70 hnsw_candidates
71 .into_iter()
72 .filter_map(|r| {
73 let start = (r.id as usize).checked_mul(dim)?;
74 let end = start.checked_add(dim)?;
75 let slice = data.get(start..end)?;
76 Some((r.id, sq8_score(codec, query, slice, metric)))
77 })
78 .collect()
79 } else {
80 hnsw_candidates
81 .into_iter()
82 .map(|r| (r.id, r.distance))
83 .collect()
84 };
85 scored.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
86
87 let keep = rerank_k.min(scored.len());
89 scored.truncate(keep);
90
91 if let Some(mmap) = &seg.mmap_vectors {
93 let ids: Vec<u32> = scored.iter().map(|&(id, _)| id).collect();
94 mmap.prefetch_batch(&ids);
95 }
96
97 let mut reranked: Vec<SearchResult> = scored
99 .into_iter()
100 .filter_map(|(id, _)| {
101 let v = if let Some(mmap) = &seg.mmap_vectors {
102 mmap.get_vector(id)?
103 } else {
104 seg.index.get_vector(id)?
105 };
106 Some(SearchResult {
107 id,
108 distance: distance(query, v, metric),
109 })
110 })
111 .collect();
112 reranked.sort_by(|a, b| {
113 a.distance
114 .partial_cmp(&b.distance)
115 .unwrap_or(std::cmp::Ordering::Equal)
116 });
117 reranked.truncate(top_k);
118 Ok(reranked)
119}
120
121impl VectorCollection {
122 pub fn search(&self, query: &[f32], top_k: usize, ef: usize) -> Vec<SearchResult> {
124 if let Some(ref dispatch) = self.codec_dispatch {
128 let mut all: Vec<SearchResult> = Vec::new();
129
130 let codec_results = dispatch.search(query, top_k, ef);
131 for r in codec_results {
132 all.push(SearchResult {
133 id: r.id,
134 distance: r.distance,
135 });
136 }
137
138 let growing_results = self.growing.search(query, top_k);
140 for mut r in growing_results {
141 r.id += self.growing_base_id;
142 all.push(r);
143 }
144
145 for seg in &self.building {
147 let results = seg.flat.search(query, top_k);
148 for mut r in results {
149 r.id += seg.base_id;
150 all.push(r);
151 }
152 }
153
154 all.sort_by(|a, b| {
155 a.distance
156 .partial_cmp(&b.distance)
157 .unwrap_or(std::cmp::Ordering::Equal)
158 });
159 all.truncate(top_k);
160 return all;
161 }
162
163 let mut all: Vec<SearchResult> = Vec::new();
164
165 let growing_results = self.growing.search(query, top_k);
167 for mut r in growing_results {
168 r.id += self.growing_base_id;
169 all.push(r);
170 }
171
172 for seg in &self.sealed {
174 let results = if seg.pq.is_some() || seg.sq8.is_some() {
175 match quantized_search(seg, query, top_k, ef, self.params.metric) {
176 Ok(r) => r,
177 Err(e) => {
178 tracing::warn!(error = %e, "quantized_search budget exhausted; skipping segment");
179 seg.index.search(query, top_k, ef)
180 }
181 }
182 } else {
183 seg.index.search(query, top_k, ef)
184 };
185 for mut r in results {
186 r.id += seg.base_id;
187 all.push(r);
188 }
189 }
190
191 for seg in &self.building {
193 let results = seg.flat.search(query, top_k);
194 for mut r in results {
195 r.id += seg.base_id;
196 all.push(r);
197 }
198 }
199
200 all.sort_by(|a, b| {
201 a.distance
202 .partial_cmp(&b.distance)
203 .unwrap_or(std::cmp::Ordering::Equal)
204 });
205 all.truncate(top_k);
206 all
207 }
208
209 pub fn search_with_metric(
216 &self,
217 query: &[f32],
218 top_k: usize,
219 ef: usize,
220 metric: DistanceMetric,
221 ) -> Vec<SearchResult> {
222 if let Some(ref dispatch) = self.codec_dispatch {
228 let mut all: Vec<SearchResult> = Vec::new();
229 let codec_results = dispatch.search(query, top_k, ef);
230 for r in codec_results {
231 all.push(SearchResult {
232 id: r.id,
233 distance: r.distance,
234 });
235 }
236 for mut r in self.growing.search_with_metric(query, top_k, metric) {
237 r.id += self.growing_base_id;
238 all.push(r);
239 }
240 for seg in &self.building {
241 for mut r in seg.flat.search_with_metric(query, top_k, metric) {
242 r.id += seg.base_id;
243 all.push(r);
244 }
245 }
246 all.sort_by(|a, b| {
247 a.distance
248 .partial_cmp(&b.distance)
249 .unwrap_or(std::cmp::Ordering::Equal)
250 });
251 all.truncate(top_k);
252 return all;
253 }
254
255 let mut all: Vec<SearchResult> = Vec::new();
256
257 for mut r in self.growing.search_with_metric(query, top_k, metric) {
258 r.id += self.growing_base_id;
259 all.push(r);
260 }
261
262 for seg in &self.sealed {
263 let results = if seg.pq.is_some() || seg.sq8.is_some() {
264 match quantized_search(seg, query, top_k, ef, metric) {
265 Ok(r) => r,
266 Err(e) => {
267 tracing::warn!(error = %e, "quantized_search budget exhausted; skipping segment");
268 seg.index.search(query, top_k, ef)
269 }
270 }
271 } else {
272 seg.index.search(query, top_k, ef)
273 };
274 for mut r in results {
275 r.id += seg.base_id;
276 all.push(r);
277 }
278 }
279
280 for seg in &self.building {
281 for mut r in seg.flat.search_with_metric(query, top_k, metric) {
282 r.id += seg.base_id;
283 all.push(r);
284 }
285 }
286
287 all.sort_by(|a, b| {
288 a.distance
289 .partial_cmp(&b.distance)
290 .unwrap_or(std::cmp::Ordering::Equal)
291 });
292 all.truncate(top_k);
293 all
294 }
295
296 pub fn search_with_bitmap_bytes_and_metric(
298 &self,
299 query: &[f32],
300 top_k: usize,
301 ef: usize,
302 bitmap: &[u8],
303 metric: DistanceMetric,
304 ) -> Vec<SearchResult> {
305 let mut all: Vec<SearchResult> = Vec::new();
306
307 let growing_results = self.growing.search_filtered_offset_with_metric(
308 query,
309 top_k,
310 bitmap,
311 self.growing_base_id,
312 metric,
313 );
314 for mut r in growing_results {
315 r.id += self.growing_base_id;
316 all.push(r);
317 }
318
319 for seg in &self.sealed {
320 let results =
321 seg.index
322 .search_with_bitmap_bytes_offset(query, top_k, ef, bitmap, seg.base_id);
323 for mut r in results {
324 if let Some(v) = seg.index.get_vector(r.id.wrapping_sub(seg.base_id)) {
326 r.distance = crate::distance::distance(query, v, metric);
327 }
328 r.id += seg.base_id;
329 all.push(r);
330 }
331 }
332
333 for seg in &self.building {
334 let results = seg.flat.search_filtered_offset_with_metric(
335 query,
336 top_k,
337 bitmap,
338 seg.base_id,
339 metric,
340 );
341 for mut r in results {
342 r.id += seg.base_id;
343 all.push(r);
344 }
345 }
346
347 all.sort_by(|a, b| {
348 a.distance
349 .partial_cmp(&b.distance)
350 .unwrap_or(std::cmp::Ordering::Equal)
351 });
352 all.truncate(top_k);
353 all
354 }
355
356 pub fn search_with_bitmap_bytes(
358 &self,
359 query: &[f32],
360 top_k: usize,
361 ef: usize,
362 bitmap: &[u8],
363 ) -> Vec<SearchResult> {
364 let mut all: Vec<SearchResult> = Vec::new();
365
366 let growing_results =
367 self.growing
368 .search_filtered_offset(query, top_k, bitmap, self.growing_base_id);
369 for mut r in growing_results {
370 r.id += self.growing_base_id;
371 all.push(r);
372 }
373
374 for seg in &self.sealed {
375 let results =
376 seg.index
377 .search_with_bitmap_bytes_offset(query, top_k, ef, bitmap, seg.base_id);
378 for mut r in results {
379 r.id += seg.base_id;
380 all.push(r);
381 }
382 }
383
384 for seg in &self.building {
385 let results = seg
386 .flat
387 .search_filtered_offset(query, top_k, bitmap, seg.base_id);
388 for mut r in results {
389 r.id += seg.base_id;
390 all.push(r);
391 }
392 }
393
394 all.sort_by(|a, b| {
395 a.distance
396 .partial_cmp(&b.distance)
397 .unwrap_or(std::cmp::Ordering::Equal)
398 });
399 all.truncate(top_k);
400 all
401 }
402
403 pub fn search_with_payload_filter(
416 &self,
417 query: &[f32],
418 top_k: usize,
419 ef: usize,
420 predicate: &FilterPredicate,
421 ) -> (Vec<SearchResult>, bool) {
422 match self.payload.pre_filter(predicate) {
423 Some(bm) => {
424 let mut bm_bytes = Vec::new();
427 if bm.serialize_into(&mut bm_bytes).is_ok() {
428 let results = self.search_with_bitmap_bytes(query, top_k, ef, &bm_bytes);
429 (results, true)
430 } else {
431 (self.search(query, top_k, ef), false)
433 }
434 }
435 None => {
436 (self.search(query, top_k, ef), false)
438 }
439 }
440 }
441}
442
#[cfg(test)]
mod tests {
    use crate::collection::lifecycle::VectorCollection;
    use crate::collection::segment::DEFAULT_SEAL_THRESHOLD;
    use crate::distance::DistanceMetric;
    use crate::hnsw::{HnswIndex, HnswParams};

    /// Empty 3-dimensional collection configured with the L2 metric.
    fn make_collection() -> VectorCollection {
        let params = HnswParams {
            metric: DistanceMetric::L2,
            ..HnswParams::default()
        };
        VectorCollection::new(3, params)
    }

    /// Vectors inserted into the growing segment are found by `search`.
    #[test]
    fn insert_and_search() {
        let mut coll = make_collection();
        (0..100u32).for_each(|i| {
            coll.insert(vec![i as f32, 0.0, 0.0]);
        });
        assert_eq!(coll.len(), 100);

        let results = coll.search(&[50.0, 0.0, 0.0], 3, 64);
        assert_eq!(results.len(), 3);
        assert_eq!(results[0].id, 50);
    }

    /// Sealing empties the growing segment into one building segment that
    /// remains searchable.
    #[test]
    fn seal_moves_to_building() {
        let mut coll = VectorCollection::new(2, HnswParams::default());
        (0..DEFAULT_SEAL_THRESHOLD).for_each(|i| {
            coll.insert(vec![i as f32, 0.0]);
        });
        assert!(coll.needs_seal());

        let req = coll.seal("test_key").unwrap();
        assert_eq!(req.vectors.len(), DEFAULT_SEAL_THRESHOLD);
        assert_eq!(coll.building.len(), 1);
        assert_eq!(coll.growing.len(), 0);

        let results = coll.search(&[100.0, 0.0], 1, 64);
        assert!(!results.is_empty());
    }

    /// Completing a build moves the segment from `building` to `sealed`.
    #[test]
    fn complete_build_promotes_to_sealed() {
        let mut coll = VectorCollection::new(2, HnswParams::default());
        (0..100).for_each(|i| {
            coll.insert(vec![i as f32, 0.0]);
        });
        let req = coll.seal("test").unwrap();

        let mut index = HnswIndex::new(req.dim, req.params);
        req.vectors.iter().for_each(|v| {
            index.insert(v.clone()).unwrap();
        });
        coll.complete_build(req.segment_id, index);

        assert_eq!(coll.building.len(), 0);
        assert_eq!(coll.sealed.len(), 1);

        let results = coll.search(&[50.0, 0.0], 3, 64);
        assert!(!results.is_empty());
    }

    /// A hit in the growing segment is reported with its collection-wide id
    /// when a sealed segment precedes it.
    #[test]
    fn multi_segment_search_merges() {
        let params = HnswParams {
            metric: DistanceMetric::L2,
            ..HnswParams::default()
        };
        let mut coll = VectorCollection::new(2, params);

        // First hundred vectors end up in a sealed segment.
        (0..100).for_each(|i| {
            coll.insert(vec![i as f32, 0.0]);
        });
        let req = coll.seal("test").unwrap();
        let mut idx = HnswIndex::new(2, req.params);
        req.vectors.iter().for_each(|v| {
            idx.insert(v.clone()).unwrap();
        });
        coll.complete_build(req.segment_id, idx);

        // Next hundred stay in the growing segment.
        (100..200).for_each(|i| {
            coll.insert(vec![i as f32, 0.0]);
        });

        let results = coll.search(&[150.0, 0.0], 3, 64);
        assert_eq!(results.len(), 3);
        assert_eq!(results[0].id, 150);
    }

    /// A deleted id is excluded from the live count and from search results.
    #[test]
    fn delete_across_segments() {
        let mut coll = VectorCollection::new(2, HnswParams::default());
        (0..10).for_each(|i| {
            coll.insert(vec![i as f32, 0.0]);
        });
        assert!(coll.delete(5));
        assert_eq!(coll.live_count(), 9);

        let results = coll.search(&[5.0, 0.0], 10, 64);
        assert!(results.iter().all(|r| r.id != 5));
    }

    /// Builds a 2-dimensional L2 collection holding `n` vectors `[i, 0]`,
    /// all promoted into a single sealed segment.
    fn make_sealed_collection(n: usize) -> VectorCollection {
        let params = HnswParams {
            metric: DistanceMetric::L2,
            ..HnswParams::default()
        };
        let mut coll = VectorCollection::new(2, params);
        (0..n).for_each(|i| {
            coll.insert(vec![i as f32, 0.0]);
        });

        let req = coll.seal("seg").unwrap();
        let mut idx = HnswIndex::new(req.dim, req.params);
        req.vectors.iter().for_each(|v| {
            idx.insert(v.clone()).unwrap();
        });
        coll.complete_build(req.segment_id, idx);
        coll
    }

    /// Calibrates an SQ8 codec on the first sealed segment's vectors and
    /// attaches the codec plus the encoded bytes to that segment.
    fn attach_sq8(coll: &mut VectorCollection) {
        use crate::quantize::sq8::Sq8Codec;

        let segment = &mut coll.sealed[0];
        let dim = segment.index.dim();
        let count = segment.index.len();

        let owned: Vec<Vec<f32>> = (0..count)
            .filter_map(|i| segment.index.get_vector(i as u32).map(|v| v.to_vec()))
            .collect();
        let borrowed: Vec<&[f32]> = owned.iter().map(Vec::as_slice).collect();
        let codec = Sq8Codec::calibrate(&borrowed, dim);

        let mut encoded: Vec<u8> = Vec::new();
        for v in &owned {
            encoded.extend(codec.quantize(v));
        }
        segment.sq8 = Some((codec, encoded));
    }

    /// With SQ8 attached, the exact match is still ranked first.
    #[test]
    fn sq8_search_returns_correct_nearest_neighbor() {
        let mut coll = make_sealed_collection(200);
        attach_sq8(&mut coll);

        let results = coll.search(&[100.0, 0.0], 5, 64);
        assert!(!results.is_empty(), "expected non-empty results");
        assert_eq!(
            results[0].id, 100,
            "nearest neighbor of [100,0] should be id=100, got id={}",
            results[0].id
        );
    }

    /// SQ8 results overlap with plain HNSW results on at least 4 of 5 ids.
    #[test]
    fn sq8_search_recall_matches_hnsw() {
        use std::collections::HashSet;

        let coll_plain = make_sealed_collection(500);
        let mut coll_sq8 = make_sealed_collection(500);
        attach_sq8(&mut coll_sq8);

        let query = [250.0f32, 0.0];
        let top_k = 5;

        let plain_ids: HashSet<u32> = coll_plain
            .search(&query, top_k, 64)
            .iter()
            .map(|r| r.id)
            .collect();
        let sq8_ids: HashSet<u32> = coll_sq8
            .search(&query, top_k, 64)
            .iter()
            .map(|r| r.id)
            .collect();

        let overlap = plain_ids.intersection(&sq8_ids).count();
        assert!(
            overlap >= 4,
            "SQ8 recall too low: {overlap}/5 results matched plain HNSW (need >=4)"
        );
    }

    /// Building a "bbq" codec dispatch makes search return results and the
    /// stats report BBQ quantization.
    #[test]
    fn codec_dispatch_bbq_search_returns_results_and_stats_report_bbq() {
        let dim = 4;
        let params = HnswParams {
            metric: DistanceMetric::L2,
            m: 8,
            ef_construction: 50,
            ..HnswParams::default()
        };
        let mut coll = VectorCollection::new(dim, params);

        (0u32..50).for_each(|i| {
            coll.insert(vec![i as f32, 0.0, 0.0, 0.0]);
        });

        let dispatch = coll.build_codec_dispatch("bbq");
        assert!(
            dispatch.is_some(),
            "build_codec_dispatch(bbq) should return Some"
        );

        let query = [25.0f32, 0.0, 0.0, 0.0];
        let results = coll.search(&query, 5, 32);
        assert!(
            !results.is_empty(),
            "BBQ codec-dispatch search should return results"
        );

        let stats = coll.stats();
        assert_eq!(
            stats.quantization,
            nodedb_types::VectorIndexQuantization::Bbq,
            "stats quantization should be Bbq after build_codec_dispatch(bbq)"
        );
    }

    /// Building a "rabitq" codec dispatch makes search return results and
    /// the stats report RaBitQ quantization.
    #[test]
    fn codec_dispatch_rabitq_search_non_empty() {
        let dim = 4;
        let params = HnswParams {
            metric: DistanceMetric::L2,
            m: 8,
            ef_construction: 50,
            ..HnswParams::default()
        };
        let mut coll = VectorCollection::new(dim, params);
        (0u32..50).for_each(|i| {
            coll.insert(vec![i as f32, 0.0, 0.0, 0.0]);
        });
        coll.build_codec_dispatch("rabitq").unwrap();

        let results = coll.search(&[10.0, 0.0, 0.0, 0.0], 3, 32);
        assert!(
            !results.is_empty(),
            "RaBitQ dispatch search should return results"
        );

        let stats = coll.stats();
        assert_eq!(
            stats.quantization,
            nodedb_types::VectorIndexQuantization::RaBitQ
        );
    }

    /// A 2000-vector sealed segment with SQ8 attached still ranks the exact
    /// match first.
    #[test]
    fn sq8_search_does_not_scan_all_vectors() {
        let mut coll = make_sealed_collection(2000);
        attach_sq8(&mut coll);

        let results = coll.search(&[1000.0, 0.0], 5, 64);
        assert!(!results.is_empty(), "expected non-empty results");
        assert_eq!(
            results[0].id, 1000,
            "nearest neighbor of [1000,0] should be id=1000, got id={}",
            results[0].id
        );
    }
}