1use std::{
7 collections::HashMap,
8 future::Future,
9 num::NonZeroUsize,
10 ops::Range,
11 sync::{
12 atomic::{AtomicU64, AtomicUsize},
13 Arc,
14 },
15 time::Instant,
16};
17
18use diskann::{
19 graph::{
20 self,
21 glue::{
22 self, DefaultPostProcessor, ExpandBeam, IdIterator, SearchExt, SearchPostProcess,
23 SearchStrategy,
24 },
25 search::Knn,
26 search_output_buffer, AdjacencyList, DiskANNIndex,
27 },
28 neighbor::Neighbor,
29 provider::{
30 Accessor, BuildQueryComputer, DataProvider, DefaultContext, DelegateNeighbor, HasId,
31 NeighborAccessor, NoopGuard,
32 },
33 utils::{
34 object_pool::{ObjectPool, PoolOption, TryAsPooled},
35 IntoUsize, VectorRepr,
36 },
37 ANNError, ANNResult,
38};
39use diskann_providers::storage::StorageReadProvider;
40use diskann_providers::{
41 model::{
42 compute_pq_distance, compute_pq_distance_for_pq_coordinates, graph::traits::GraphDataType,
43 pq::quantizer_preprocess, PQData, PQScratch,
44 },
45 storage::{get_compressed_pq_file, get_disk_index_file, get_pq_pivot_file, LoadWith},
46};
47use diskann_vector::{distance::Metric, DistanceFunction, PreprocessedDistanceFunction};
48use futures_util::future;
49use tokio::runtime::Runtime;
50use tracing::debug;
51
52use crate::{
53 data_model::{CachingStrategy, GraphHeader},
54 filter_parameter::{default_vector_filter, VectorFilter},
55 search::{
56 provider::disk_vertex_provider_factory::DiskVertexProviderFactory,
57 traits::{VertexProvider, VertexProviderFactory},
58 },
59 storage::{api::AsyncDiskLoadContext, disk_index_reader::DiskIndexReader},
60 utils::AlignedFileReaderFactory,
61 utils::QueryStatistics,
62};
63
/// Data provider backing searches over an on-disk index.
///
/// Bundles everything the search path in this file reads: the graph header,
/// a full-precision distance functor, the shared PQ data, the point count,
/// the metric and the per-search I/O limit.
pub struct DiskProvider<Data>
where
    Data: GraphDataType<VectorIdType = u32>,
{
    /// Graph header; queried for `metadata()` (dims, medoid) and `max_degree()`.
    graph_header: GraphHeader,

    /// Distance functor built from `metric` and the header's `dims`; evaluated
    /// on vectors fetched through a vertex provider.
    distance_comparer: <Data::VectorDataType as VectorRepr>::Distance,

    /// PQ data obtained from the `DiskIndexReader` (chunk/center counts and
    /// compressed vectors).
    pq_data: Arc<PQData>,

    /// Exclusive upper bound of the id range produced by `id_iterator`.
    num_points: usize,

    /// Metric handed to `quantizer_preprocess` when an accessor is created.
    metric: Metric,

    /// Limit that the tracked I/O count of a search is compared against.
    search_io_limit: usize,
}
95
96impl<Data> DataProvider for DiskProvider<Data>
97where
98 Data: GraphDataType<VectorIdType = u32>,
99{
100 type Context = DefaultContext;
101
102 type InternalId = u32;
103
104 type ExternalId = u32;
105
106 type Guard = NoopGuard<u32>;
107
108 type Error = ANNError;
109
110 fn to_internal_id(
112 &self,
113 _context: &DefaultContext,
114 gid: &Self::ExternalId,
115 ) -> Result<Self::InternalId, Self::Error> {
116 Ok(*gid)
117 }
118
119 fn to_external_id(
121 &self,
122 _context: &DefaultContext,
123 id: Self::InternalId,
124 ) -> Result<Self::ExternalId, Self::Error> {
125 Ok(id)
126 }
127}
128
129impl<Data> LoadWith<AsyncDiskLoadContext> for DiskProvider<Data>
130where
131 Data: GraphDataType<VectorIdType = u32>,
132{
133 type Error = ANNError;
134
135 async fn load_with<P>(provider: &P, ctx: &AsyncDiskLoadContext) -> ANNResult<Self>
136 where
137 P: StorageReadProvider,
138 {
139 debug!(
140 "DiskProvider::load_with() called with file: {:?}",
141 get_disk_index_file(ctx.quant_load_context.metadata.prefix())
142 );
143
144 let graph_header = {
145 let aligned_reader_factory = AlignedFileReaderFactory::new(get_disk_index_file(
146 ctx.quant_load_context.metadata.prefix(),
147 ));
148
149 let caching_strategy = if ctx.num_nodes_to_cache > 0 {
150 CachingStrategy::StaticCacheWithBfsNodes(ctx.num_nodes_to_cache)
151 } else {
152 CachingStrategy::None
153 };
154
155 let vertex_provider_factory = DiskVertexProviderFactory::<Data, _>::new(
156 aligned_reader_factory,
157 caching_strategy,
158 )?;
159 VertexProviderFactory::get_header(&vertex_provider_factory)?
160 };
161
162 let metric = ctx.quant_load_context.metric;
163 let num_points = ctx.num_points;
164
165 let index_path_prefix = ctx.quant_load_context.metadata.prefix();
166 let index_reader = DiskIndexReader::<<Data as GraphDataType>::VectorDataType>::new(
167 get_pq_pivot_file(index_path_prefix),
168 get_compressed_pq_file(index_path_prefix),
169 provider,
170 )?;
171
172 Self::new(
173 &index_reader,
174 graph_header,
175 metric,
176 num_points,
177 ctx.search_io_limit,
178 )
179 }
180}
181
182impl<Data> DiskProvider<Data>
183where
184 Data: GraphDataType<VectorIdType = u32>,
185{
186 fn new(
187 disk_index_reader: &DiskIndexReader<Data::VectorDataType>,
188 graph_header: GraphHeader,
189 metric: Metric,
190 num_points: usize,
191 search_io_limit: usize,
192 ) -> ANNResult<Self> {
193 let distance_comparer =
194 Data::VectorDataType::distance(metric, Some(graph_header.metadata().dims));
195
196 let pq_data = disk_index_reader.get_pq_data();
197
198 Ok(Self {
199 graph_header,
200 distance_comparer,
201 pq_data,
202 num_points,
203 metric,
204 search_io_limit,
205 })
206 }
207}
208
/// Per-query search strategy: carries the query, the id filter and the shared
/// resources from which a `DiskAccessor` is built.
pub struct DiskSearchStrategy<'a, Data, ProviderFactory>
where
    Data: GraphDataType<VectorIdType = u32>,
    ProviderFactory: VertexProviderFactory<Data>,
{
    /// I/O and timing counters for this query; read back after the search to
    /// fill the query statistics.
    io_tracker: IOTracker,

    /// Predicate over vertex ids, handed to the default post-processor.
    vector_filter: &'a (dyn Fn(&u32) -> bool + Send + Sync),

    /// The query vector.
    query: &'a [Data::VectorDataType],

    /// Factory used when a pooled scratch has to create its vertex provider.
    vertex_provider_factory: &'a ProviderFactory,

    /// Pool of reusable per-search scratch objects.
    scratch_pool: &'a Arc<ObjectPool<DiskSearchScratch<Data, ProviderFactory::VertexProviderType>>>,
}
235
/// Counters collected while a single search runs.
///
/// All fields are atomics so they can be updated through a shared reference.
/// `Default` yields all-zero counters (every atomic defaults to `0`), which is
/// why the impl is derived rather than written by hand.
#[derive(Default)]
struct IOTracker {
    /// Accumulated time spent loading vertices, in microseconds.
    io_time_us: AtomicU64,
    /// Accumulated time spent in query PQ preprocessing, in microseconds.
    preprocess_time_us: AtomicU64,
    /// Number of vertex loads counted so far.
    io_count: AtomicUsize,
}

impl IOTracker {
    // `Relaxed` is used throughout: within this file the counters are plain
    // statistics and no other memory is published through them.

    /// Adds `time` (microseconds) to the given time counter.
    fn add_time(category: &AtomicU64, time: u64) {
        category.fetch_add(time, std::sync::atomic::Ordering::Relaxed);
    }

    /// Reads the current value of the given time counter (microseconds).
    fn time(category: &AtomicU64) -> u64 {
        category.load(std::sync::atomic::Ordering::Relaxed)
    }

    /// Adds `count` to the I/O counter.
    fn add_io_count(&self, count: usize) {
        self.io_count
            .fetch_add(count, std::sync::atomic::Ordering::Relaxed);
    }

    /// Returns the number of I/Os counted so far.
    fn io_count(&self) -> usize {
        self.io_count.load(std::sync::atomic::Ordering::Relaxed)
    }
}
272
/// Post-processor state: a borrowed predicate deciding which candidate ids
/// may appear in the final results.
#[derive(Clone, Copy)]
pub struct RerankAndFilter<'a> {
    /// Returns `true` for ids that should be kept.
    filter: &'a (dyn Fn(&u32) -> bool + Send + Sync),
}

impl<'a> RerankAndFilter<'a> {
    /// Wraps the given id predicate.
    fn new(filter: &'a (dyn Fn(&u32) -> bool + Send + Sync)) -> Self {
        RerankAndFilter { filter }
    }
}
283
284impl<Data, VP>
285 SearchPostProcess<
286 DiskAccessor<'_, Data, VP>,
287 &[Data::VectorDataType],
288 (
289 <DiskProvider<Data> as DataProvider>::InternalId,
290 Data::AssociatedDataType,
291 ),
292 > for RerankAndFilter<'_>
293where
294 Data: GraphDataType<VectorIdType = u32>,
295 VP: VertexProvider<Data>,
296{
297 type Error = ANNError;
298 async fn post_process<I, B>(
299 &self,
300 accessor: &mut DiskAccessor<'_, Data, VP>,
301 query: &[Data::VectorDataType],
302 _computer: &DiskQueryComputer,
303 candidates: I,
304 output: &mut B,
305 ) -> Result<usize, Self::Error>
306 where
307 I: Iterator<Item = Neighbor<u32>> + Send,
308 B: search_output_buffer::SearchOutputBuffer<(u32, Data::AssociatedDataType)>
309 + Send
310 + ?Sized,
311 {
312 let provider = accessor.provider;
313
314 let mut uncached_ids = Vec::new();
315 let mut reranked = candidates
316 .map(|n| n.id)
317 .filter(|id| (self.filter)(id))
318 .filter_map(|n| {
319 if let Some(entry) = accessor.scratch.distance_cache.get(&n) {
320 Some(Ok::<((u32, _), f32), ANNError>(((n, entry.1), entry.0)))
321 } else {
322 uncached_ids.push(n);
323 None
324 }
325 })
326 .collect::<Result<Vec<_>, _>>()?;
327 if !uncached_ids.is_empty() {
328 ensure_vertex_loaded(&mut accessor.scratch.vertex_provider, &uncached_ids)?;
329 for n in &uncached_ids {
330 let v = accessor.scratch.vertex_provider.get_vector(n)?;
331 let d = provider.distance_comparer.evaluate_similarity(query, v);
332 let a = accessor.scratch.vertex_provider.get_associated_data(n)?;
333 reranked.push(((*n, *a), d));
334 }
335 }
336
337 reranked
339 .sort_unstable_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
340 Ok(output.extend(reranked))
342 }
343}
344
345impl<'this, Data, ProviderFactory> SearchStrategy<DiskProvider<Data>, &[Data::VectorDataType]>
346 for DiskSearchStrategy<'this, Data, ProviderFactory>
347where
348 Data: GraphDataType<VectorIdType = u32>,
349 ProviderFactory: VertexProviderFactory<Data>,
350{
351 type QueryComputer = DiskQueryComputer;
352 type SearchAccessor<'a> = DiskAccessor<'a, Data, ProviderFactory::VertexProviderType>;
353 type SearchAccessorError = ANNError;
354
355 fn search_accessor<'a>(
356 &'a self,
357 provider: &'a DiskProvider<Data>,
358 _context: &DefaultContext,
359 ) -> Result<Self::SearchAccessor<'a>, Self::SearchAccessorError> {
360 DiskAccessor::new(
361 provider,
362 &self.io_tracker,
363 self.query,
364 self.vertex_provider_factory,
365 self.scratch_pool,
366 )
367 }
368}
369
370impl<'this, Data, ProviderFactory>
371 DefaultPostProcessor<
372 DiskProvider<Data>,
373 &[Data::VectorDataType],
374 (
375 <DiskProvider<Data> as DataProvider>::InternalId,
376 Data::AssociatedDataType,
377 ),
378 > for DiskSearchStrategy<'this, Data, ProviderFactory>
379where
380 Data: GraphDataType<VectorIdType = u32>,
381 ProviderFactory: VertexProviderFactory<Data>,
382{
383 type Processor = RerankAndFilter<'this>;
384
385 fn default_post_processor(&self) -> Self::Processor {
386 RerankAndFilter::new(self.vector_filter)
387 }
388}
389
390pub struct DiskQueryComputer {
392 num_pq_chunks: usize,
393 query_centroid_l2_distance: Vec<f32>,
394}
395
396impl PreprocessedDistanceFunction<&[u8], f32> for DiskQueryComputer {
397 fn evaluate_similarity(&self, changing: &[u8]) -> f32 {
398 let mut dist = 0.0f32;
399 #[allow(clippy::expect_used)]
400 compute_pq_distance_for_pq_coordinates(
401 changing,
402 self.num_pq_chunks,
403 &self.query_centroid_l2_distance,
404 std::slice::from_mut(&mut dist),
405 )
406 .expect("PQ distance compute for PQ coordinates is expected to succeed");
407 dist
408 }
409}
410
411impl<Data, VP> BuildQueryComputer<&[Data::VectorDataType]> for DiskAccessor<'_, Data, VP>
412where
413 Data: GraphDataType<VectorIdType = u32>,
414 VP: VertexProvider<Data>,
415{
416 type QueryComputerError = ANNError;
417 type QueryComputer = DiskQueryComputer;
418
419 fn build_query_computer(
420 &self,
421 _from: &[Data::VectorDataType],
422 ) -> Result<Self::QueryComputer, Self::QueryComputerError> {
423 Ok(DiskQueryComputer {
424 num_pq_chunks: self.provider.pq_data.get_num_chunks(),
425 query_centroid_l2_distance: self
426 .scratch
427 .pq_scratch
428 .aligned_pqtable_dist_scratch
429 .as_slice()
430 .to_vec(),
431 })
432 }
433
434 async fn distances_unordered<Itr, F>(
435 &mut self,
436 vec_id_itr: Itr,
437 _computer: &Self::QueryComputer,
438 f: F,
439 ) -> Result<(), Self::GetError>
440 where
441 F: Send + FnMut(f32, Self::Id),
442 Itr: Iterator<Item = Self::Id>,
443 {
444 self.pq_distances(&vec_id_itr.collect::<Box<[_]>>(), f)
445 }
446}
447
448impl<Data, VP> ExpandBeam<&[Data::VectorDataType]> for DiskAccessor<'_, Data, VP>
449where
450 Data: GraphDataType<VectorIdType = u32>,
451 VP: VertexProvider<Data>,
452{
453 fn expand_beam<Itr, P, F>(
454 &mut self,
455 ids: Itr,
456 _computer: &Self::QueryComputer,
457 mut pred: P,
458 mut f: F,
459 ) -> impl std::future::Future<Output = Result<(), Self::GetError>> + Send
460 where
461 Itr: Iterator<Item = Self::Id> + Send,
462 P: glue::HybridPredicate<Self::Id> + Send + Sync,
463 F: FnMut(f32, Self::Id) + Send,
464 {
465 let result = (|| {
466 let io_limit = self.provider.search_io_limit - self.io_tracker.io_count();
467 let load_ids: Box<[_]> = ids.take(io_limit).collect();
468
469 self.ensure_loaded(&load_ids)?;
470 let mut ids = Vec::new();
471 for i in load_ids {
472 ids.clear();
473 ids.extend(
474 self.scratch
475 .vertex_provider
476 .get_adjacency_list(&i)?
477 .iter()
478 .copied()
479 .filter(|id| pred.eval_mut(id)),
480 );
481
482 self.pq_distances(&ids, &mut f)?;
483 }
484
485 Ok(())
486 })();
487
488 std::future::ready(result)
489 }
490}
491
/// Reusable per-search working state, handed out from an object pool.
struct DiskSearchScratch<Data, VP>
where
    Data: GraphDataType<VectorIdType = u32>,
    VP: VertexProvider<Data>,
{
    /// Vertex id -> (full-precision distance to the query, associated data),
    /// filled when vertices are loaded and consulted during post-processing.
    distance_cache: HashMap<u32, (f32, Data::AssociatedDataType)>,
    /// Scratch buffers used for PQ preprocessing and PQ distance computation.
    pq_scratch: PQScratch,
    /// Provider used to load vertices and read their vector, adjacency list
    /// and associated data.
    vertex_provider: VP,
}
503
/// Arguments needed to create a `DiskSearchScratch` in the object pool.
#[derive(Clone)]
struct DiskSearchScratchArgs<'a, ProviderFactory> {
    /// First argument to `PQScratch::new`; callers pass the graph's max degree.
    graph_degree: usize,
    /// Second argument to `PQScratch::new`; callers pass the header's `dims`.
    dim: usize,
    /// Third argument to `PQScratch::new`; number of PQ chunks.
    num_pq_chunks: usize,
    /// Fourth argument to `PQScratch::new`; number of PQ centers.
    num_pq_centers: usize,
    /// Factory that creates the scratch's vertex provider.
    vertex_factory: &'a ProviderFactory,
    /// Graph header passed to the factory when creating the vertex provider.
    graph_header: &'a GraphHeader,
}
513
514impl<Data, ProviderFactory> TryAsPooled<&DiskSearchScratchArgs<'_, ProviderFactory>>
515 for DiskSearchScratch<Data, ProviderFactory::VertexProviderType>
516where
517 Data: GraphDataType<VectorIdType = u32>,
518 ProviderFactory: VertexProviderFactory<Data>,
519{
520 type Error = ANNError;
521
522 fn try_create(args: &DiskSearchScratchArgs<ProviderFactory>) -> Result<Self, Self::Error> {
523 let pq_scratch = PQScratch::new(
524 args.graph_degree,
525 args.dim,
526 args.num_pq_chunks,
527 args.num_pq_centers,
528 )?;
529
530 const DEFAULT_BEAM_WIDTH: usize = 0; let vertex_provider = args
532 .vertex_factory
533 .create_vertex_provider(DEFAULT_BEAM_WIDTH, args.graph_header)?;
534
535 Ok(Self {
536 distance_cache: HashMap::new(),
537 pq_scratch,
538 vertex_provider,
539 })
540 }
541
542 fn try_modify(
543 &mut self,
544 _args: &DiskSearchScratchArgs<ProviderFactory>,
545 ) -> Result<(), Self::Error> {
546 self.distance_cache.clear();
547 self.vertex_provider.clear();
548 Ok(())
549 }
550}
551
/// Accessor used during a single search: combines the shared provider with
/// per-query scratch state and counters.
pub struct DiskAccessor<'a, Data, VP>
where
    Data: GraphDataType<VectorIdType = u32>,
    VP: VertexProvider<Data>,
{
    /// Shared provider (graph header, PQ data, distance functor, I/O limit).
    provider: &'a DiskProvider<Data>,
    /// Counters updated as vertices are loaded and the query is preprocessed.
    io_tracker: &'a IOTracker,
    /// Scratch obtained from the pool for the duration of this accessor.
    scratch: PoolOption<DiskSearchScratch<Data, VP>>,
    /// The query vector.
    query: &'a [Data::VectorDataType],
}
562
563impl<Data, VP> DiskAccessor<'_, Data, VP>
564where
565 Data: GraphDataType<VectorIdType = u32>,
566 VP: VertexProvider<Data>,
567{
568 fn pq_distances<F>(&mut self, ids: &[u32], mut f: F) -> ANNResult<()>
571 where
572 F: FnMut(f32, u32),
573 {
574 let pq_scratch = &mut self.scratch.pq_scratch;
575 compute_pq_distance(
576 ids,
577 self.provider.pq_data.get_num_chunks(),
578 &pq_scratch.aligned_pqtable_dist_scratch,
579 self.provider.pq_data.pq_compressed_data().get_data(),
580 &mut pq_scratch.aligned_pq_coord_scratch,
581 &mut pq_scratch.aligned_dist_scratch,
582 )?;
583
584 for (i, id) in ids.iter().enumerate() {
585 let distance = self.scratch.pq_scratch.aligned_dist_scratch[i];
586 f(distance, *id);
587 }
588
589 Ok(())
590 }
591}
592
593impl<Data, VP> SearchExt for DiskAccessor<'_, Data, VP>
594where
595 Data: GraphDataType<VectorIdType = u32>,
596 VP: VertexProvider<Data>,
597{
598 async fn starting_points(&self) -> ANNResult<Vec<u32>> {
599 let start_vertex_id = self.provider.graph_header.metadata().medoid as u32;
600 Ok(vec![start_vertex_id])
601 }
602
603 fn terminate_early(&mut self) -> bool {
604 self.io_tracker.io_count() > self.provider.search_io_limit
605 }
606}
607
608impl<'a, Data, VP> DiskAccessor<'a, Data, VP>
609where
610 Data: GraphDataType<VectorIdType = u32>,
611 VP: VertexProvider<Data>,
612{
613 fn new<VPF>(
614 provider: &'a DiskProvider<Data>,
615 io_tracker: &'a IOTracker,
616 query: &'a [Data::VectorDataType],
617 vertex_provider_factory: &'a VPF,
618 scratch_pool: &'a Arc<ObjectPool<DiskSearchScratch<Data, VP>>>,
619 ) -> ANNResult<Self>
620 where
621 VPF: VertexProviderFactory<Data, VertexProviderType = VP>,
622 {
623 let mut scratch = PoolOption::try_pooled(
624 scratch_pool,
625 &DiskSearchScratchArgs {
626 graph_degree: provider.graph_header.max_degree::<Data::VectorDataType>()?,
627 dim: provider.graph_header.metadata().dims,
628 num_pq_chunks: provider.pq_data.get_num_chunks(),
629 num_pq_centers: provider.pq_data.get_num_centers(),
630 vertex_factory: vertex_provider_factory,
631 graph_header: &provider.graph_header,
632 },
633 )?;
634
635 scratch.pq_scratch.set(
636 provider.graph_header.metadata().dims,
637 query,
638 1.0_f32, )?;
640 let start_vertex_id = provider.graph_header.metadata().medoid as u32;
641
642 let timer = Instant::now();
643 quantizer_preprocess(
644 &mut scratch.pq_scratch,
645 &provider.pq_data,
646 provider.metric,
647 &[start_vertex_id],
648 )?;
649 IOTracker::add_time(
650 &io_tracker.preprocess_time_us,
651 timer.elapsed().as_micros() as u64,
652 );
653
654 Ok(Self {
655 provider,
656 io_tracker,
657 scratch,
658 query,
659 })
660 }
661 fn ensure_loaded(&mut self, ids: &[u32]) -> Result<(), ANNError> {
662 if ids.is_empty() {
663 return Ok(());
664 }
665 let scratch = &mut self.scratch;
666 let timer = Instant::now();
667 ensure_vertex_loaded(&mut scratch.vertex_provider, ids)?;
668 IOTracker::add_time(
669 &self.io_tracker.io_time_us,
670 timer.elapsed().as_micros() as u64,
671 );
672 self.io_tracker.add_io_count(ids.len());
673 for id in ids {
674 let distance = self
675 .provider
676 .distance_comparer
677 .evaluate_similarity(self.query, scratch.vertex_provider.get_vector(id)?);
678 let associated_data = *scratch.vertex_provider.get_associated_data(id)?;
679 scratch
680 .distance_cache
681 .insert(*id, (distance, associated_data));
682 }
683 Ok(())
684 }
685}
686
/// Vertex ids handled by the accessor are plain `u32`s.
impl<Data, VP> HasId for DiskAccessor<'_, Data, VP>
where
    Data: GraphDataType<VectorIdType = u32>,
    VP: VertexProvider<Data>,
{
    type Id = u32;
}
694
695impl<Data, VP> Accessor for DiskAccessor<'_, Data, VP>
696where
697 Data: GraphDataType<VectorIdType = u32>,
698 VP: VertexProvider<Data>,
699{
700 type Element<'a>
703 = &'a [u8]
704 where
705 Self: 'a;
706
707 type ElementRef<'a> = &'a [u8];
709
710 type GetError = ANNError;
712
713 fn get_element(
714 &mut self,
715 id: Self::Id,
716 ) -> impl Future<Output = Result<Self::Element<'_>, Self::GetError>> + Send {
717 std::future::ready(self.provider.pq_data.get_compressed_vector(id as usize))
718 }
719}
720
721impl<Data, VP> IdIterator<Range<u32>> for DiskAccessor<'_, Data, VP>
722where
723 Data: GraphDataType<VectorIdType = u32>,
724 VP: VertexProvider<Data>,
725{
726 async fn id_iterator(&mut self) -> Result<Range<u32>, ANNError> {
727 Ok(0..self.provider.num_points as u32)
728 }
729}
730
731impl<'a, 'b, Data, VP> DelegateNeighbor<'a> for DiskAccessor<'b, Data, VP>
732where
733 Data: GraphDataType<VectorIdType = u32>,
734 VP: VertexProvider<Data>,
735{
736 type Delegate = AsNeighborAccessor<'a, 'b, Data, VP>;
737 fn delegate_neighbor(&'a mut self) -> Self::Delegate {
738 AsNeighborAccessor(self)
739 }
740}
741
/// Adapter holding a mutable borrow of a `DiskAccessor` so that it can be used
/// to read adjacency lists.
pub struct AsNeighborAccessor<'a, 'b, Data, VP>(&'a mut DiskAccessor<'b, Data, VP>)
where
    Data: GraphDataType<VectorIdType = u32>,
    VP: VertexProvider<Data>;
751
/// The adapter uses the same `u32` ids as the accessor it wraps.
impl<Data, VP> HasId for AsNeighborAccessor<'_, '_, Data, VP>
where
    Data: GraphDataType<VectorIdType = u32>,
    VP: VertexProvider<Data>,
{
    type Id = u32;
}
759
760impl<Data, VP> NeighborAccessor for AsNeighborAccessor<'_, '_, Data, VP>
761where
762 Data: GraphDataType<VectorIdType = u32>,
763 VP: VertexProvider<Data>,
764{
765 fn get_neighbors(
766 self,
767 id: Self::Id,
768 neighbors: &mut AdjacencyList<Self::Id>,
769 ) -> impl Future<Output = ANNResult<Self>> + Send {
770 if self.0.io_tracker.io_count() > self.0.provider.search_io_limit {
771 return future::ok(self); }
773
774 if let Err(e) = ensure_vertex_loaded(&mut self.0.scratch.vertex_provider, &[id]) {
775 return future::err(e);
776 }
777 let list = match self.0.scratch.vertex_provider.get_adjacency_list(&id) {
778 Ok(list) => list,
779 Err(e) => return future::err(e),
780 };
781 neighbors.overwrite_trusted(list);
782 future::ok(self)
783 }
784}
785
/// Entry point for searching a disk index: owns the index, the runtime used
/// to drive its async search calls, the vertex provider factory and the pool
/// of per-search scratch objects.
pub struct DiskIndexSearcher<
    Data,
    ProviderFactory = DiskVertexProviderFactory<Data, AlignedFileReaderFactory>,
> where
    Data: GraphDataType<VectorIdType = u32>,
    ProviderFactory: VertexProviderFactory<Data>,
{
    /// Index built over a `DiskProvider`.
    index: DiskANNIndex<DiskProvider<Data>>,
    /// Tokio runtime on which search futures are run with `block_on`.
    runtime: Runtime,

    /// Factory lent to each search strategy for creating vertex providers.
    vertex_provider_factory: ProviderFactory,

    /// Pool of reusable scratch objects, shared by all searches.
    scratch_pool: Arc<ObjectPool<DiskSearchScratch<Data, ProviderFactory::VertexProviderType>>>,
}
805
/// Summary statistics returned with a search.
#[derive(Debug)]
pub struct SearchResultStats {
    /// Number of comparisons reported by the search.
    pub cmps: u32,
    /// Number of results reported by the search.
    pub result_count: u32,
    /// Snapshot of the detailed per-query statistics.
    pub query_statistics: QueryStatistics,
}
812
/// Outcome of a search: the result items plus statistics.
pub struct SearchResult<AssociatedData> {
    /// Result items, one per slot of the requested return list.
    pub results: Vec<SearchResultItem<AssociatedData>>,
    /// Statistics gathered while producing `results`.
    pub stats: SearchResultStats,
}
822
/// One entry of a search result.
pub struct SearchResultItem<AssociatedData> {
    /// Id of the matched vertex.
    pub vertex_id: u32,
    /// Associated data stored with the vertex.
    pub data: AssociatedData,
    /// Distance reported for this vertex.
    pub distance: f32,
}
836
837impl<Data, ProviderFactory> DiskIndexSearcher<Data, ProviderFactory>
838where
839 Data: GraphDataType<VectorIdType = u32>,
840 ProviderFactory: VertexProviderFactory<Data>,
841{
842 pub fn new(
852 num_threads: usize,
853 search_io_limit: usize,
854 disk_index_reader: &DiskIndexReader<Data::VectorDataType>,
855 vertex_provider_factory: ProviderFactory,
856 metric: Metric,
857 runtime: Option<Runtime>,
858 ) -> ANNResult<Self> {
859 let runtime = match runtime {
860 Some(rt) => rt,
861 None => tokio::runtime::Builder::new_current_thread().build()?,
862 };
863
864 let graph_header = vertex_provider_factory.get_header()?;
865 let metadata = graph_header.metadata();
866 let max_degree = graph_header.max_degree::<Data::VectorDataType>()? as u32;
867
868 let config = graph::config::Builder::new(
869 max_degree.into_usize(),
870 graph::config::MaxDegree::default_slack(),
871 1, metric.into(),
873 )
874 .build()?;
875
876 debug!("Creating DiskIndexSearcher with index_config: {:?}", config);
877
878 let graph_header = vertex_provider_factory.get_header()?;
879 let pq_data = disk_index_reader.get_pq_data();
880 let scratch_pool_args = DiskSearchScratchArgs {
881 graph_degree: graph_header.max_degree::<Data::VectorDataType>()?,
882 dim: graph_header.metadata().dims,
883 num_pq_chunks: pq_data.get_num_chunks(),
884 num_pq_centers: pq_data.get_num_centers(),
885 vertex_factory: &vertex_provider_factory,
886 graph_header: &graph_header,
887 };
888 let scratch_pool = Arc::new(ObjectPool::try_new(&scratch_pool_args, 0, None)?);
889
890 let disk_provider = DiskProvider::new(
891 disk_index_reader,
892 graph_header,
893 metric,
894 metadata.num_pts.into_usize(),
895 search_io_limit,
896 )?;
897
898 let index = DiskANNIndex::new(config, disk_provider, NonZeroUsize::new(num_threads));
899 Ok(Self {
900 index,
901 runtime,
902 vertex_provider_factory,
903 scratch_pool,
904 })
905 }
906
907 fn search_strategy<'a>(
909 &'a self,
910 query: &'a [Data::VectorDataType],
911 vector_filter: &'a (dyn Fn(&Data::VectorIdType) -> bool + Send + Sync),
912 ) -> DiskSearchStrategy<'a, Data, ProviderFactory> {
913 DiskSearchStrategy {
914 io_tracker: IOTracker::default(),
915 vector_filter,
916 query,
917 vertex_provider_factory: &self.vertex_provider_factory,
918 scratch_pool: &self.scratch_pool,
919 }
920 }
921
922 pub fn search(
925 &self,
926 query: &[Data::VectorDataType],
927 return_list_size: u32,
928 search_list_size: u32,
929 beam_width: Option<usize>,
930 vector_filter: Option<VectorFilter<Data>>,
931 is_flat_search: bool,
932 ) -> ANNResult<SearchResult<Data::AssociatedDataType>> {
933 let mut query_stats = QueryStatistics::default();
934 let mut indices = vec![0u32; return_list_size as usize];
935 let mut distances = vec![0f32; return_list_size as usize];
936 let mut associated_data =
937 vec![Data::AssociatedDataType::default(); return_list_size as usize];
938
939 let stats = self.search_internal(
940 query,
941 return_list_size as usize,
942 search_list_size,
943 beam_width,
944 &mut query_stats,
945 &mut indices,
946 &mut distances,
947 &mut associated_data,
948 &vector_filter.unwrap_or(default_vector_filter::<Data>()),
949 is_flat_search,
950 )?;
951
952 let mut search_result = SearchResult {
953 results: Vec::with_capacity(return_list_size as usize),
954 stats,
955 };
956
957 for ((vertex_id, distance), associated_data) in indices
958 .into_iter()
959 .zip(distances.into_iter())
960 .zip(associated_data.into_iter())
961 {
962 search_result.results.push(SearchResultItem {
963 vertex_id,
964 distance,
965 data: associated_data,
966 });
967 }
968
969 Ok(search_result)
970 }
971
972 #[allow(clippy::too_many_arguments)]
975 pub(crate) fn search_internal(
976 &self,
977 query: &[Data::VectorDataType],
978 k_value: usize,
979 search_list_size: u32,
980 beam_width: Option<usize>,
981 query_stats: &mut QueryStatistics,
982 indices: &mut [u32],
983 distances: &mut [f32],
984 associated_data: &mut [Data::AssociatedDataType],
985 vector_filter: &(dyn Fn(&Data::VectorIdType) -> bool + Send + Sync),
986 is_flat_search: bool,
987 ) -> ANNResult<SearchResultStats> {
988 let mut result_output_buffer = search_output_buffer::IdDistanceAssociatedData::new(
989 &mut indices[..k_value],
990 &mut distances[..k_value],
991 &mut associated_data[..k_value],
992 );
993
994 let strategy = self.search_strategy(query, vector_filter);
995 let timer = Instant::now();
996 let k = k_value;
997 let l = search_list_size as usize;
998 let stats = if is_flat_search {
999 self.runtime.block_on(self.index.flat_search(
1000 &strategy,
1001 &DefaultContext,
1002 strategy.query,
1003 vector_filter,
1004 &Knn::new(k, l, beam_width)?,
1005 &mut result_output_buffer,
1006 ))?
1007 } else {
1008 let knn_search = Knn::new(k, l, beam_width)?;
1009 self.runtime.block_on(self.index.search(
1010 knn_search,
1011 &strategy,
1012 &DefaultContext,
1013 strategy.query,
1014 &mut result_output_buffer,
1015 ))?
1016 };
1017 query_stats.total_comparisons = stats.cmps;
1018 query_stats.search_hops = stats.hops;
1019
1020 query_stats.total_execution_time_us = timer.elapsed().as_micros();
1021 query_stats.io_time_us = IOTracker::time(&strategy.io_tracker.io_time_us) as u128;
1022 query_stats.total_io_operations = strategy.io_tracker.io_count() as u32;
1023 query_stats.total_vertices_loaded = strategy.io_tracker.io_count() as u32;
1024 query_stats.query_pq_preprocess_time_us =
1025 IOTracker::time(&strategy.io_tracker.preprocess_time_us) as u128;
1026 query_stats.cpu_time_us = query_stats.total_execution_time_us
1027 - query_stats.io_time_us
1028 - query_stats.query_pq_preprocess_time_us;
1029 Ok(SearchResultStats {
1030 cmps: query_stats.total_comparisons,
1031 result_count: stats.result_count,
1032 query_statistics: query_stats.clone(),
1033 })
1034 }
1035}
1036
1037fn ensure_vertex_loaded<Data: GraphDataType, V: VertexProvider<Data>>(
1043 vertex_provider: &mut V,
1044 ids: &[Data::VectorIdType],
1045) -> ANNResult<()> {
1046 vertex_provider.load_vertices(ids)?;
1047 for (idx, id) in ids.iter().enumerate() {
1048 vertex_provider.process_loaded_node(id, idx)?;
1049 }
1050 Ok(())
1051}
1052
1053#[cfg(test)]
1054mod disk_provider_tests {
1055 use diskann::{
1056 graph::{
1057 search::{record::VisitedSearchRecord, Knn},
1058 KnnSearchError,
1059 },
1060 utils::IntoUsize,
1061 ANNErrorKind,
1062 };
1063 use diskann_providers::storage::{
1064 DynWriteProvider, StorageReadProvider, VirtualStorageProvider,
1065 };
1066 use diskann_providers::{
1067 common::AlignedBoxWithSlice,
1068 test_utils::graph_data_type_utils::{
1069 GraphDataF32VectorU32Data, GraphDataF32VectorUnitData,
1070 },
1071 utils::{create_thread_pool, load_aligned_bin, PQPathNames, ParallelIteratorInPool},
1072 };
1073 use diskann_utils::{io::read_bin, test_data_root};
1074 use diskann_vector::distance::Metric;
1075 use rayon::prelude::{IndexedParallelIterator, IntoParallelRefIterator};
1076 use rstest::rstest;
1077 use vfs::OverlayFS;
1078
1079 use super::*;
1080 use crate::{
1081 build::builder::core::disk_index_builder_tests::{IndexBuildFixture, TestParams},
1082 utils::{QueryStatistics, VirtualAlignedReaderFactory},
1083 };
1084
// Fixture paths for the 128-dimensional SIFT index.
// NOTE(review): presumably resolved against the storage provider created from
// `test_data_root()` - confirm.
const TEST_INDEX_PREFIX_128DIM: &str =
    "/disk_index_search/disk_index_sift_learn_R4_L50_A1.2_truth_search";
const TEST_INDEX_128DIM: &str =
    "/disk_index_search/disk_index_sift_learn_R4_L50_A1.2_truth_search_disk.index";
const TEST_PQ_PIVOT_128DIM: &str =
    "/disk_index_search/disk_index_sift_learn_R4_L50_A1.2_truth_search_pq_pivots.bin";
const TEST_PQ_COMPRESSED_128DIM: &str =
    "/disk_index_search/disk_index_sift_learn_R4_L50_A1.2_truth_search_pq_compressed.bin";
const TEST_TRUTH_RESULT_10PTS_128DIM: &str =
    "/disk_index_search/disk_index_10pts_idx_uint32_truth_search_res.bin";
const TEST_QUERY_10PTS_128DIM: &str = "/disk_index_search/disk_index_sample_query_10pts.fbin";

// Fixture paths for the 256-point, 100-dimensional f32 index.
const TEST_INDEX_PREFIX_100DIM: &str = "/disk_index_search/256pts_100dim_f32_truth_Index";
const TEST_INDEX_100DIM: &str = "/disk_index_search/256pts_100dim_f32_truth_Index_disk.index";
const TEST_PQ_PIVOT_100DIM: &str =
    "/disk_index_search/256pts_100dim_f32_truth_Index_pq_pivots.bin";
const TEST_PQ_COMPRESSED_100DIM: &str =
    "/disk_index_search/256pts_100dim_f32_truth_Index_pq_compressed.bin";
const TEST_TRUTH_RESULT_10PTS_100DIM: &str =
    "/disk_index_search/256pts_100dim_f32_truth_query_result.bin";
const TEST_QUERY_10PTS_100DIM: &str = "/disk_index_search/10pts_100dim_f32_base_query.bin";
// Raw data file used when a test builds its own index.
const TEST_DATA_FILE: &str = "/disk_index_search/disk_index_siftsmall_learn_256pts_data.fbin";
// The four constants below hold the same paths as their `*_128DIM` counterparts.
const TEST_INDEX: &str =
    "/disk_index_search/disk_index_sift_learn_R4_L50_A1.2_truth_search_disk.index";
const TEST_INDEX_PREFIX: &str =
    "/disk_index_search/disk_index_sift_learn_R4_L50_A1.2_truth_search";
const TEST_PQ_PIVOT: &str =
    "/disk_index_search/disk_index_sift_learn_R4_L50_A1.2_truth_search_pq_pivots.bin";
const TEST_PQ_COMPRESSED: &str =
    "/disk_index_search/disk_index_sift_learn_R4_L50_A1.2_truth_search_pq_compressed.bin";
1115
/// Searches the 100-dimensional fixture index with k=10, l=20, first with one
/// thread and then with five, reusing the same search engine.
#[test]
fn test_disk_search_k10_l20_single_or_multi_thread_100dim() {
    let storage_provider = Arc::new(VirtualStorageProvider::new_overlay(test_data_root()));

    let searcher_params = CreateDiskIndexSearcherParams {
        max_thread_num: 5,
        pq_pivot_file_path: TEST_PQ_PIVOT_100DIM,
        pq_compressed_file_path: TEST_PQ_COMPRESSED_100DIM,
        index_path: TEST_INDEX_100DIM,
        index_path_prefix: TEST_INDEX_PREFIX_100DIM,
        ..Default::default()
    };
    let search_engine = create_disk_index_searcher(searcher_params, &storage_provider);

    // Single-threaded run first, then the multi-threaded one.
    for thread_num in [1, 5] {
        test_disk_search(TestDiskSearchParams {
            storage_provider: storage_provider.as_ref(),
            index_search_engine: &search_engine,
            thread_num,
            query_file_path: TEST_QUERY_10PTS_100DIM,
            truth_result_file_path: TEST_TRUTH_RESULT_10PTS_100DIM,
            k: 10,
            l: 20,
            dim: 104,
        });
    }
}
1154
/// Searches the 128-dimensional fixture index with k=10, l=20, first with one
/// thread and then with five, reusing the same search engine.
#[test]
fn test_disk_search_k10_l20_single_or_multi_thread_128dim() {
    let storage_provider = Arc::new(VirtualStorageProvider::new_overlay(test_data_root()));

    let searcher_params = CreateDiskIndexSearcherParams {
        max_thread_num: 5,
        pq_pivot_file_path: TEST_PQ_PIVOT_128DIM,
        pq_compressed_file_path: TEST_PQ_COMPRESSED_128DIM,
        index_path: TEST_INDEX_128DIM,
        index_path_prefix: TEST_INDEX_PREFIX_128DIM,
        ..Default::default()
    };
    let search_engine = create_disk_index_searcher::<GraphDataF32VectorUnitData>(
        searcher_params,
        &storage_provider,
    );

    // Single-threaded run first, then the multi-threaded one.
    for thread_num in [1, 5] {
        test_disk_search(TestDiskSearchParams {
            storage_provider: storage_provider.as_ref(),
            index_search_engine: &search_engine,
            thread_num,
            query_file_path: TEST_QUERY_10PTS_128DIM,
            truth_result_file_path: TEST_TRUTH_RESULT_10PTS_128DIM,
            k: 10,
            l: 20,
            dim: 128,
        });
    }
}
1193
1194 fn get_truth_associated_data<StorageReader: StorageReadProvider>(
1195 storage_provider: &StorageReader,
1196 ) -> Vec<u32> {
1197 const ASSOCIATED_DATA_FILE: &str = "/sift/siftsmall_learn_256pts_u32_associated_data.fbin";
1198
1199 let data =
1200 read_bin::<u32>(&mut storage_provider.open_reader(ASSOCIATED_DATA_FILE).unwrap())
1201 .unwrap();
1202 data.into_inner().into_vec()
1203 }
1204
/// Builds a disk index with `u32` associated data, searches it with one and with five
/// query threads, and finally deletes the three index files the searcher was opened on.
#[test]
fn test_disk_search_with_associated_data_k10_l20_single_or_multi_thread_128dim() {
    let storage_provider = VirtualStorageProvider::new_overlay(test_data_root());
    let index_path_prefix = "/disk_index_search/disk_index_sift_learn_R4_L50_A1.2_test_disk_index_search_associated_data";

    // File names derived from the prefix; used both to open the searcher and for cleanup.
    let disk_index_path = format!("{}_disk.index", index_path_prefix);
    let pq_pivots_path = format!("{}_pq_pivots.bin", index_path_prefix);
    let pq_compressed_path = format!("{}_pq_compressed.bin", index_path_prefix);

    let build_params = TestParams {
        data_path: TEST_DATA_FILE.to_string(),
        index_path_prefix: index_path_prefix.to_string(),
        associated_data_path: Some(
            "/sift/siftsmall_learn_256pts_u32_associated_data.fbin".to_string(),
        ),
        ..TestParams::default()
    };
    let fixture = IndexBuildFixture::new(storage_provider, build_params).unwrap();
    fixture.build::<GraphDataF32VectorU32Data>().unwrap();

    // Inner scope: the searcher is dropped before the files are deleted below.
    {
        let search_engine = create_disk_index_searcher::<GraphDataF32VectorU32Data>(
            CreateDiskIndexSearcherParams {
                max_thread_num: 5,
                pq_pivot_file_path: pq_pivots_path.as_str(),
                pq_compressed_file_path: pq_compressed_path.as_str(),
                index_path: disk_index_path.as_str(),
                index_path_prefix,
                ..Default::default()
            },
            &fixture.storage_provider,
        );

        for thread_num in [1u64, 5] {
            test_disk_search_with_associated(
                TestDiskSearchAssociateParams {
                    storage_provider: fixture.storage_provider.as_ref(),
                    index_search_engine: &search_engine,
                    thread_num,
                    query_file_path: TEST_QUERY_10PTS_128DIM,
                    truth_result_file_path: TEST_TRUTH_RESULT_10PTS_128DIM,
                    k: 10,
                    l: 20,
                    dim: 128,
                },
                None,
            );
        }
    }

    // Cleanup order matches the original: disk index, PQ pivots, PQ compressed data.
    for generated_file in [&disk_index_path, &pq_pivots_path, &pq_compressed_path] {
        fixture
            .storage_provider
            .delete(generated_file)
            .expect("Failed to delete file");
    }
}
1278
/// Inputs for `create_disk_index_searcher`.
struct CreateDiskIndexSearcherParams<'a> {
    /// Worker-thread count for the tokio runtime; also forwarded to the searcher constructor.
    max_thread_num: usize,
    /// Path of the PQ pivot file handed to the disk index reader.
    pq_pivot_file_path: &'a str,
    /// Path of the PQ compressed-data file handed to the disk index reader.
    pq_compressed_file_path: &'a str,
    /// Full path of the disk index file. `create_disk_index_searcher` does not read this
    /// field; it derives the file name from `index_path_prefix` instead.
    index_path: &'a str,
    /// Prefix from which the disk index file name is derived.
    index_path_prefix: &'a str,
    /// IO limit forwarded to the searcher; must be greater than zero.
    io_limit: usize,
}
1287
1288 impl Default for CreateDiskIndexSearcherParams<'_> {
1289 fn default() -> Self {
1290 Self {
1291 max_thread_num: 1,
1292 pq_pivot_file_path: "",
1293 pq_compressed_file_path: "",
1294 index_path: "",
1295 index_path_prefix: "",
1296 io_limit: usize::MAX,
1297 }
1298 }
1299 }
1300
1301 fn create_disk_index_searcher<Data>(
1302 params: CreateDiskIndexSearcherParams,
1303 storage_provider: &Arc<VirtualStorageProvider<OverlayFS>>,
1304 ) -> DiskIndexSearcher<
1305 Data,
1306 DiskVertexProviderFactory<Data, VirtualAlignedReaderFactory<OverlayFS>>,
1307 >
1308 where
1309 Data: GraphDataType<VectorIdType = u32>,
1310 {
1311 assert!(params.io_limit > 0);
1312
1313 let runtime = tokio::runtime::Builder::new_multi_thread()
1314 .worker_threads(params.max_thread_num)
1315 .build()
1316 .unwrap();
1317
1318 let disk_index_reader = DiskIndexReader::<Data::VectorDataType>::new(
1319 params.pq_pivot_file_path.to_string(),
1320 params.pq_compressed_file_path.to_string(),
1321 storage_provider.as_ref(),
1322 )
1323 .unwrap();
1324
1325 let aligned_reader_factory = VirtualAlignedReaderFactory::new(
1326 get_disk_index_file(params.index_path_prefix),
1327 Arc::clone(storage_provider),
1328 );
1329 let caching_strategy = CachingStrategy::None;
1330 let vertex_provider_factory =
1331 DiskVertexProviderFactory::<Data, _>::new(aligned_reader_factory, caching_strategy)
1332 .unwrap();
1333
1334 DiskIndexSearcher::<Data, DiskVertexProviderFactory<Data, _>>::new(
1335 params.max_thread_num,
1336 params.io_limit,
1337 &disk_index_reader,
1338 vertex_provider_factory,
1339 Metric::L2,
1340 Some(runtime),
1341 )
1342 .unwrap()
1343 }
1344
1345 fn load_query_result<StorageReader: StorageReadProvider>(
1346 storage_provider: &StorageReader,
1347 query_result_path: &str,
1348 ) -> Vec<u32> {
1349 let result =
1350 read_bin::<u32>(&mut storage_provider.open_reader(query_result_path).unwrap()).unwrap();
1351 result.into_inner().into_vec()
1352 }
1353
1354 struct TestDiskSearchParams<'a, StorageType> {
1355 storage_provider: &'a StorageType,
1356 index_search_engine: &'a DiskIndexSearcher<
1357 GraphDataF32VectorUnitData,
1358 DiskVertexProviderFactory<
1359 GraphDataF32VectorUnitData,
1360 VirtualAlignedReaderFactory<OverlayFS>,
1361 >,
1362 >,
1363 thread_num: u64,
1364 query_file_path: &'a str,
1365 truth_result_file_path: &'a str,
1366 k: usize,
1367 l: usize,
1368 dim: usize,
1369 }
1370
1371 struct TestDiskSearchAssociateParams<'a, StorageType> {
1372 storage_provider: &'a StorageType,
1373 index_search_engine: &'a DiskIndexSearcher<
1374 GraphDataF32VectorU32Data,
1375 DiskVertexProviderFactory<
1376 GraphDataF32VectorU32Data,
1377 VirtualAlignedReaderFactory<OverlayFS>,
1378 >,
1379 >,
1380 thread_num: u64,
1381 query_file_path: &'a str,
1382 truth_result_file_path: &'a str,
1383 k: usize,
1384 l: usize,
1385 dim: usize,
1386 }
1387
1388 fn test_disk_search<StorageType: StorageReadProvider>(
1389 params: TestDiskSearchParams<StorageType>,
1390 ) {
1391 let query_vector = load_aligned_bin(params.storage_provider, params.query_file_path)
1392 .unwrap()
1393 .0;
1394 let mut aligned_query = AlignedBoxWithSlice::<f32>::new(query_vector.len(), 32).unwrap();
1395 aligned_query.memcpy(query_vector.as_slice()).unwrap();
1396
1397 let queries = aligned_query
1398 .split_into_nonoverlapping_mut_slices(0..aligned_query.len(), params.dim)
1399 .unwrap();
1400
1401 let truth_result =
1402 load_query_result(params.storage_provider, params.truth_result_file_path);
1403
1404 let pool = create_thread_pool(params.thread_num.into_usize()).unwrap();
1405 queries
1407 .par_iter()
1408 .enumerate()
1409 .for_each_in_pool(&pool, |(i, query)| {
1410 let mut aligned_box = AlignedBoxWithSlice::<f32>::new(query.len() + 1, 32).unwrap();
1412 let mut temp = Vec::with_capacity(query.len() + 1);
1413 temp.push(0.0);
1414 temp.extend_from_slice(query);
1415 aligned_box.memcpy(temp.as_slice()).unwrap();
1416 let query = &aligned_box.as_slice()[1..];
1417
1418 let mut query_stats = QueryStatistics::default();
1419 let mut indices = vec![0u32; 10];
1420 let mut distances = vec![0f32; 10];
1421 let mut associated_data = vec![(); 10];
1422
1423 let result = params
1424 .index_search_engine
1425 .search_internal(
1427 query,
1428 params.k,
1429 params.l as u32,
1430 None, &mut query_stats,
1432 &mut indices,
1433 &mut distances,
1434 &mut associated_data,
1435 &(|_| true),
1436 false,
1437 );
1438
1439 let truth_slice = &truth_result[i * params.k..(i + 1) * params.k];
1441
1442 assert!(result.is_ok(), "Expected search to succeed");
1443
1444 let result_unwrapped = result.unwrap();
1445 assert!(
1446 result_unwrapped.query_statistics.total_io_operations > 0,
1447 "Expected IO operations to be greater than 0"
1448 );
1449 assert!(
1450 result_unwrapped.query_statistics.total_vertices_loaded > 0,
1451 "Expected vertices loaded to be greater than 0"
1452 );
1453
1454 assert_eq!(
1456 indices, truth_slice,
1457 "Results DO NOT match with the truth result for query {}",
1458 i
1459 );
1460 });
1461 }
1462
1463 fn test_disk_search_with_associated<StorageType: StorageReadProvider>(
1464 params: TestDiskSearchAssociateParams<StorageType>,
1465 beam_width: Option<usize>,
1466 ) {
1467 let query_vector = load_aligned_bin(params.storage_provider, params.query_file_path)
1468 .unwrap()
1469 .0;
1470 let mut aligned_query = AlignedBoxWithSlice::<f32>::new(query_vector.len(), 32).unwrap();
1471 aligned_query.memcpy(query_vector.as_slice()).unwrap();
1472 let queries = aligned_query
1473 .split_into_nonoverlapping_mut_slices(0..aligned_query.len(), params.dim)
1474 .unwrap();
1475 let truth_result =
1476 load_query_result(params.storage_provider, params.truth_result_file_path);
1477 let pool = create_thread_pool(params.thread_num.into_usize()).unwrap();
1478 queries
1480 .par_iter()
1481 .enumerate()
1482 .for_each_in_pool(&pool, |(i, query)| {
1483 let mut aligned_box = AlignedBoxWithSlice::<f32>::new(query.len() + 1, 32).unwrap();
1485 let mut temp = Vec::with_capacity(query.len() + 1);
1486 temp.push(0.0);
1487 temp.extend_from_slice(query);
1488 aligned_box.memcpy(temp.as_slice()).unwrap();
1489 let query = &aligned_box.as_slice()[1..];
1490 let result = params
1491 .index_search_engine
1492 .search(query, params.k as u32, params.l as u32, beam_width, None, false)
1493 .unwrap();
1494 let indices: Vec<u32> = result.results.iter().map(|item| item.vertex_id).collect();
1495 let associated_data: Vec<u32> =
1496 result.results.iter().map(|item| item.data).collect();
1497 let truth_data = get_truth_associated_data(params.storage_provider);
1498 let associated_data_truth: Vec<u32> = indices
1499 .iter()
1500 .map(|&vid| truth_data[vid as usize])
1501 .collect();
1502 assert_eq!(
1503 associated_data, associated_data_truth,
1504 "Associated data DO NOT match with the truth result for query {}, associated_data from search: {:?}, associated_data from truth result: {:?}",
1505 i,associated_data, associated_data_truth
1506 );
1507 let truth_slice = &truth_result[i * params.k..(i + 1) * params.k];
1508 assert_eq!(
1509 indices, truth_slice,
1510 "Results DO NOT match with the truth result for query {}",
1511 i
1512 );
1513 });
1514 }
1515
/// Checks path conventions, rejection of invalid `Knn` parameters, id mapping and max
/// degree on the provider, and that `search_internal` fails with an `IndexError` when
/// called with k = 10 and l = 9.
#[test]
fn test_disk_search_invalid_input() {
    const K: usize = 10;

    let storage_provider = Arc::new(VirtualStorageProvider::new_overlay(test_data_root()));
    let ctx = &DefaultContext;

    let params = CreateDiskIndexSearcherParams {
        max_thread_num: 5,
        pq_pivot_file_path: TEST_PQ_PIVOT_128DIM,
        pq_compressed_file_path: TEST_PQ_COMPRESSED_128DIM,
        index_path: TEST_INDEX_128DIM,
        index_path_prefix: TEST_INDEX_PREFIX_128DIM,
        ..Default::default()
    };

    // The test constants must agree with the names derived from the index prefix.
    let derived_pq_paths = PQPathNames::for_disk_index(TEST_INDEX_PREFIX_128DIM);
    assert_eq!(
        derived_pq_paths.pivots, params.pq_pivot_file_path,
        "pq_pivot_file_path is not correct"
    );
    assert_eq!(
        derived_pq_paths.compressed_data, params.pq_compressed_file_path,
        "pq_compressed_file_path is not correct"
    );
    let derived_index_path = format!("{}_disk.index", params.index_path_prefix);
    assert_eq!(
        params.index_path, derived_index_path,
        "index_path is not correct"
    );

    // `Knn::new_default(20, 10)` must be rejected and convert into an `IndexError`.
    let invalid_default_knn = Knn::new_default(20, 10);
    assert!(invalid_default_knn.is_err());
    let converted_error: ANNError = invalid_default_knn.unwrap_err().into();
    assert_eq!(converted_error.kind(), ANNErrorKind::IndexError);

    // A third argument of `Some(0)` must be rejected as well.
    let invalid_knn = Knn::new(10, 10, Some(0));
    assert!(invalid_knn.is_err());

    let search_engine =
        create_disk_index_searcher::<GraphDataF32VectorU32Data>(params, &storage_provider);
    let data_provider = &search_engine.index.data_provider;

    // Id 0 maps to 0 in both directions.
    assert_eq!(data_provider.to_external_id(ctx, 0).unwrap(), 0);
    assert_eq!(data_provider.to_internal_id(ctx, &0).unwrap(), 0);

    // The degree recorded in the graph header must match the index configuration.
    let provider_max_degree = data_provider
        .graph_header
        .max_degree::<<GraphDataF32VectorU32Data as GraphDataType>::VectorDataType>()
        .unwrap();
    let index_max_degree = search_engine.index.config.pruned_degree().get();
    assert_eq!(provider_max_degree, index_max_degree);

    let query = vec![0f32; 128];
    let mut query_stats = QueryStatistics::default();
    let mut indices = vec![0u32; K];
    let mut distances = vec![0f32; K];
    let mut associated_data = vec![0u32; K];

    // Search list size one below k: the search is expected to fail.
    let result = search_engine.search_internal(
        &query,
        K,
        (K - 1) as u32,
        None,
        &mut query_stats,
        &mut indices,
        &mut distances,
        &mut associated_data,
        &|_| true,
        false,
    );

    assert!(result.is_err());
    assert_eq!(result.unwrap_err().kind(), ANNErrorKind::IndexError);
}
1609
/// Runs a recorded search (`Knn::new(10, 10, Some(4))`) directly on the index and pins
/// the exact visited-node sequence, then runs the public `search` with the same settings
/// and checks the result count and the ids written by the recorded search.
#[test]
fn test_disk_search_beam_search() {
    let storage_provider = Arc::new(VirtualStorageProvider::new_overlay(test_data_root()));

    let searcher_params = CreateDiskIndexSearcherParams {
        max_thread_num: 1,
        pq_pivot_file_path: TEST_PQ_PIVOT,
        pq_compressed_file_path: TEST_PQ_COMPRESSED,
        index_path: TEST_INDEX,
        index_path_prefix: TEST_INDEX_PREFIX,
        ..Default::default()
    };
    let search_engine = create_disk_index_searcher::<GraphDataF32VectorUnitData>(
        searcher_params,
        &storage_provider,
    );

    let query_vector: [f32; 128] = [1f32; 128];
    let mut indices = vec![0u32; 10];
    let mut distances = vec![0f32; 10];
    let mut associated_data = vec![(); 10];

    let mut result_output_buffer = search_output_buffer::IdDistanceAssociatedData::new(
        &mut indices,
        &mut distances,
        &mut associated_data,
    );
    let strategy = search_engine.search_strategy(&query_vector, &|_| true);
    let mut search_record = VisitedSearchRecord::new(0);
    let recorded_search = diskann::graph::search::RecordedKnn::new(
        Knn::new(10, 10, Some(4)).unwrap(),
        &mut search_record,
    );
    let recorded_future = search_engine.index.search(
        recorded_search,
        &strategy,
        &DefaultContext,
        query_vector.as_slice(),
        &mut result_output_buffer,
    );
    search_engine.runtime.block_on(recorded_future).unwrap();

    // Visited nodes, in the order the record holds them.
    let ids: Vec<_> = search_record.visited.iter().map(|n| n.id).collect();

    const EXPECTED_NODES: [u32; 18] = [
        72, 118, 108, 86, 84, 152, 170, 82, 114, 87, 207, 176, 79, 153, 67, 165, 141, 180,
    ];
    assert_eq!(ids, &EXPECTED_NODES);

    let return_list_size = 10;
    let search_list_size = 10;
    let result = search_engine.search(
        &query_vector,
        return_list_size,
        search_list_size,
        Some(4),
        None,
        false,
    );
    assert!(result.is_ok(), "Expected search to succeed");
    let search_result = result.unwrap();
    assert_eq!(
        search_result.results.len() as u32,
        return_list_size,
        "Expected result count to match"
    );
    // `indices` was filled by the recorded search above, not by `search`.
    assert_eq!(
        indices,
        vec![152, 72, 170, 118, 87, 165, 79, 141, 108, 86],
        "Expected indices to match"
    );
}
1687
/// Exercises diversity search against the test index with a synthetic attribute table.
///
/// Every id in `0..256` is given the label `(id % 15) + 1`. Two diverse searches are run:
/// the first (second constructor argument 3) only has to return something; the second
/// uses `diverse_results_k = 1` and is checked for:
///
/// * a result count in `1..=return_list_size`,
/// * non-negative distances in ascending order,
/// * no label occurring more than `diverse_results_k` times,
/// * at least two distinct labels.
///
/// Compared with the previous version, the non-negativity check now covers every returned
/// distance (it used to skip the last one, and everything when only one result came
/// back), and a duplicated `result_count > 0` assertion that could never fire is gone.
#[cfg(feature = "experimental_diversity_search")]
#[test]
fn test_disk_search_diversity_search() {
    use diskann::graph::DiverseSearchParams;
    use diskann::neighbor::AttributeValueProvider;
    use std::collections::HashMap;

    /// In-memory id -> label table used as the attribute source for the diverse search.
    #[derive(Debug, Clone)]
    struct TestAttributeProvider {
        attributes: HashMap<u32, u32>,
    }
    impl TestAttributeProvider {
        fn new() -> Self {
            Self {
                attributes: HashMap::new(),
            }
        }
        fn insert(&mut self, id: u32, attribute: u32) {
            self.attributes.insert(id, attribute);
        }
    }
    impl diskann::provider::HasId for TestAttributeProvider {
        type Id = u32;
    }

    impl AttributeValueProvider for TestAttributeProvider {
        type Value = u32;

        fn get(&self, id: Self::Id) -> Option<Self::Value> {
            self.attributes.get(&id).copied()
        }
    }

    let storage_provider = Arc::new(VirtualStorageProvider::new_overlay(test_data_root()));

    let search_engine = create_disk_index_searcher::<GraphDataF32VectorUnitData>(
        CreateDiskIndexSearcherParams {
            max_thread_num: 1,
            pq_pivot_file_path: TEST_PQ_PIVOT,
            pq_compressed_file_path: TEST_PQ_COMPRESSED,
            index_path: TEST_INDEX,
            index_path_prefix: TEST_INDEX_PREFIX,
            ..Default::default()
        },
        &storage_provider,
    );

    let query_vector: [f32; 128] = [1f32; 128];

    // Labels cycle through 1..=15.
    let mut attribute_provider = TestAttributeProvider::new();
    let num_vectors = 256;
    for i in 0..num_vectors {
        let label = (i % 15) + 1;
        attribute_provider.insert(i, label);
    }
    let attribute_provider = std::sync::Arc::new(attribute_provider);

    let mut indices = vec![0u32; 10];
    let mut distances = vec![0f32; 10];
    let mut associated_data = vec![(); 10];

    let mut result_output_buffer = search_output_buffer::IdDistanceAssociatedData::new(
        &mut indices,
        &mut distances,
        &mut associated_data,
    );
    let strategy = search_engine.search_strategy(&query_vector, &|_| true);

    let diverse_params = DiverseSearchParams::new(0, 3, attribute_provider.clone());

    let search_params = Knn::new(10, 20, None).unwrap();

    let diverse_search = diskann::graph::search::Diverse::new(search_params, diverse_params);
    let stats = search_engine
        .runtime
        .block_on(search_engine.index.search(
            diverse_search,
            &strategy,
            &DefaultContext,
            query_vector.as_slice(),
            &mut result_output_buffer,
        ))
        .unwrap();

    assert!(
        stats.result_count > 0,
        "Expected to get some results during diversity search"
    );

    // Second search: at most one result per label.
    let return_list_size = 10;
    let search_list_size = 20;
    let diverse_results_k = 1;
    let diverse_params =
        DiverseSearchParams::new(0, diverse_results_k, attribute_provider.clone());

    let mut indices2 = vec![0u32; return_list_size as usize];
    let mut distances2 = vec![0f32; return_list_size as usize];
    let mut associated_data2 = vec![(); return_list_size as usize];
    let mut result_output_buffer2 = search_output_buffer::IdDistanceAssociatedData::new(
        &mut indices2,
        &mut distances2,
        &mut associated_data2,
    );
    let strategy2 = search_engine.search_strategy(&query_vector, &|_| true);
    let search_params2 =
        Knn::new(return_list_size as usize, search_list_size as usize, None).unwrap();

    let diverse_search2 = diskann::graph::search::Diverse::new(search_params2, diverse_params);
    let stats = search_engine
        .runtime
        .block_on(search_engine.index.search(
            diverse_search2,
            &strategy2,
            &DefaultContext,
            query_vector.as_slice(),
            &mut result_output_buffer2,
        ))
        .unwrap();

    assert!(
        stats.result_count > 0,
        "Expected diversity search to return results"
    );
    assert!(
        stats.result_count <= return_list_size,
        "Expected result count to be <= {}",
        return_list_size
    );

    // Only the first `result_count` entries of the output buffers are meaningful; the
    // assertion above guarantees the slices are in bounds.
    let result_count = stats.result_count as usize;
    let result_ids = &indices2[..result_count];
    let result_distances = &distances2[..result_count];

    println!("\n=== Diversity Search Results ===");
    println!("Query: [1f32; 128]");
    println!("diverse_results_k: {}", diverse_results_k);
    println!("Total results: {}\n", stats.result_count);
    println!("{:<10} {:<15} {:<10}", "Vertex ID", "Distance", "Label");
    println!("{}", "-".repeat(35));
    for (vertex_id, distance) in result_ids.iter().zip(result_distances) {
        let attribute_value = attribute_provider.get(*vertex_id).unwrap_or(0);
        println!(
            "{:<10} {:<15.2} {:<10}",
            vertex_id, distance, attribute_value
        );
    }

    // Every returned distance must be non-negative ...
    for distance in result_distances {
        assert!(*distance >= 0.0, "Expected non-negative distance");
    }
    // ... and each adjacent pair must be in ascending order.
    for pair in result_distances.windows(2) {
        assert!(
            pair[0] <= pair[1],
            "Expected distances to be sorted in ascending order"
        );
    }

    // Count how often each label occurs among the returned ids.
    let mut attribute_counts = HashMap::new();
    for item in result_ids {
        if let Some(attribute_value) = attribute_provider.get(*item) {
            *attribute_counts.entry(attribute_value).or_insert(0) += 1;
        }
    }

    println!("\n=== Attribute Distribution ===");
    let mut sorted_attrs: Vec<_> = attribute_counts.iter().collect();
    sorted_attrs.sort_by_key(|(k, _)| *k);
    for (attribute_value, count) in &sorted_attrs {
        println!(
            "Label {}: {} occurrences (max allowed: {})",
            attribute_value, count, diverse_results_k
        );
    }
    println!("Total unique labels: {}", attribute_counts.len());
    println!("================================\n");

    for (attribute_value, count) in &attribute_counts {
        println!(
            "Assert: Label {} has {} occurrences (max: {})",
            attribute_value, count, diverse_results_k
        );
        assert!(
            *count <= diverse_results_k,
            "Attribute value {} appears {} times, which exceeds diverse_results_k of {}",
            attribute_value,
            count,
            diverse_results_k
        );
    }

    println!(
        "Assert: Found {} unique labels (expected at least 2)",
        attribute_counts.len()
    );
    assert!(
        attribute_counts.len() >= 2,
        "Expected at least 2 different attribute values for diversity, got {}",
        attribute_counts.len()
    );
}
1910
1911 #[rstest]
1912 #[case(
1914 |_id: &u32| true,
1915 false,
1916 10,
1917 vec![152, 118, 72, 170, 87, 141, 79, 207, 124, 86],
1918 vec![256101.7, 256675.3, 256709.69, 256712.5, 256760.08, 256958.5, 257006.1, 257025.7, 257105.67, 257107.67],
1919 )]
1920 #[case(
1923 |id: &u32| *id == 0 || *id == 1,
1924 false,
1925 0,
1926 vec![0; 10],
1927 vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
1928 )]
1929 #[case(
1932 |id: &u32| *id == 0 || *id == 1,
1933 true,
1934 2,
1935 vec![1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
1936 vec![257247.28, 258179.28, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
1937 )]
1938 #[case(
1941 |id: &u32| *id == 72 || *id == 87 || *id == 170,
1942 false,
1943 3,
1944 vec![72, 170, 87, 0, 0, 0, 0, 0, 0, 0],
1945 vec![256709.69, 256712.5, 256760.08, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
1946 )]
1947 #[case(
1950 |id: &u32| *id == 72 || *id == 87 || *id == 170,
1951 true,
1952 3,
1953 vec![72, 170, 87, 0, 0, 0, 0, 0, 0, 0],
1954 vec![256709.69, 256712.5, 256760.08, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
1955 )]
1956 fn test_search_with_vector_filter(
1957 #[case] vector_filter: fn(&u32) -> bool,
1958 #[case] is_flat_search: bool,
1959 #[case] expected_result_count: u32,
1960 #[case] expected_indices: Vec<u32>,
1961 #[case] expected_distances: Vec<f32>,
1962 ) {
1963 let check_distances = |got: &[f32], expected: &[f32]| -> bool {
1968 const ABS_TOLERANCE: f32 = 0.02;
1969 assert_eq!(got.len(), expected.len());
1970 for (i, (g, e)) in std::iter::zip(got.iter(), expected.iter()).enumerate() {
1971 if (g - e).abs() > ABS_TOLERANCE {
1972 panic!(
1973 "distances differ at position {} by more than {}\n\n\
1974 got: {:?}\nexpected: {:?}",
1975 i, ABS_TOLERANCE, got, expected,
1976 );
1977 }
1978 }
1979 true
1980 };
1981
1982 let storage_provider = Arc::new(VirtualStorageProvider::new_overlay(test_data_root()));
1983
1984 let search_engine = create_disk_index_searcher::<GraphDataF32VectorUnitData>(
1985 CreateDiskIndexSearcherParams {
1986 max_thread_num: 5,
1987 pq_pivot_file_path: TEST_PQ_PIVOT_128DIM,
1988 pq_compressed_file_path: TEST_PQ_COMPRESSED_128DIM,
1989 index_path: TEST_INDEX_128DIM,
1990 index_path_prefix: TEST_INDEX_PREFIX_128DIM,
1991 ..Default::default()
1992 },
1993 &storage_provider,
1994 );
1995 let query = vec![0.1f32; 128];
1996 let mut query_stats = QueryStatistics::default();
1997 let mut indices = vec![0u32; 10];
1998 let mut distances = vec![0f32; 10];
1999 let mut associated_data = vec![(); 10];
2000
2001 let result = search_engine.search_internal(
2002 &query,
2003 10,
2004 10,
2005 None, &mut query_stats,
2007 &mut indices,
2008 &mut distances,
2009 &mut associated_data,
2010 &vector_filter,
2011 is_flat_search,
2012 );
2013
2014 assert!(result.is_ok(), "Expected search to succeed");
2015 assert_eq!(
2016 result.unwrap().result_count,
2017 expected_result_count,
2018 "Expected result count to match"
2019 );
2020 assert_eq!(indices, expected_indices, "Expected indices to match");
2021 assert!(
2022 check_distances(&distances, &expected_distances),
2023 "Expected distances to match"
2024 );
2025
2026 let result_with_filter = search_engine.search(
2027 &query,
2028 10,
2029 10,
2030 None, Some(Box::new(vector_filter)),
2032 is_flat_search,
2033 );
2034
2035 assert!(result_with_filter.is_ok(), "Expected search to succeed");
2036 let result_with_filter_unwrapped = result_with_filter.unwrap();
2037 assert_eq!(
2038 result_with_filter_unwrapped.stats.result_count, expected_result_count,
2039 "Expected result count to match"
2040 );
2041 let actual_indices = result_with_filter_unwrapped
2042 .results
2043 .iter()
2044 .map(|x| x.vertex_id)
2045 .collect::<Vec<_>>();
2046 assert_eq!(
2047 actual_indices, expected_indices,
2048 "Expected indices to match"
2049 );
2050 let actual_distances = result_with_filter_unwrapped
2051 .results
2052 .iter()
2053 .map(|x| x.distance)
2054 .collect::<Vec<_>>();
2055 assert!(
2056 check_distances(&actual_distances, &expected_distances),
2057 "Expected distances to match"
2058 );
2059 }
2060
/// Runs a recorded search with an IO limit of 11 and checks that the strategy's IO
/// tracker stays within the limit while at least 60% of the reference node list is still
/// visited.
#[test]
fn test_beam_search_respects_io_limit() {
    let io_limit = 11;
    let storage_provider = Arc::new(VirtualStorageProvider::new_overlay(test_data_root()));

    let searcher_params = CreateDiskIndexSearcherParams {
        max_thread_num: 1,
        pq_pivot_file_path: TEST_PQ_PIVOT,
        pq_compressed_file_path: TEST_PQ_COMPRESSED,
        index_path: TEST_INDEX,
        index_path_prefix: TEST_INDEX_PREFIX,
        io_limit,
    };
    let search_engine = create_disk_index_searcher::<GraphDataF32VectorUnitData>(
        searcher_params,
        &storage_provider,
    );
    let query_vector: [f32; 128] = [1f32; 128];

    let mut indices = vec![0u32; 10];
    let mut distances = vec![0f32; 10];
    let mut associated_data = vec![(); 10];

    let mut result_output_buffer = search_output_buffer::IdDistanceAssociatedData::new(
        &mut indices,
        &mut distances,
        &mut associated_data,
    );

    let strategy = search_engine.search_strategy(&query_vector, &|_| true);

    let mut search_record = VisitedSearchRecord::new(0);
    let recorded_search = diskann::graph::search::RecordedKnn::new(
        Knn::new(10, 10, Some(4)).unwrap(),
        &mut search_record,
    );
    let recorded_future = search_engine.index.search(
        recorded_search,
        &strategy,
        &DefaultContext,
        query_vector.as_slice(),
        &mut result_output_buffer,
    );
    search_engine.runtime.block_on(recorded_future).unwrap();

    let visited_ids: Vec<_> = search_record.visited.iter().map(|n| n.id).collect();

    // The strategy's tracker holds the IO count for this query.
    let query_stats = strategy.io_tracker;
    assert!(
        query_stats.io_count() <= io_limit,
        "Expected IO operations to be <= {}, but got {}",
        io_limit,
        query_stats.io_count()
    );

    const EXPECTED_NODES: [u32; 17] = [
        72, 118, 108, 86, 84, 152, 170, 82, 114, 87, 207, 176, 79, 153, 67, 165, 141,
    ];
    // How many of the reference nodes were actually visited.
    let matching_count = EXPECTED_NODES
        .iter()
        .filter(|node| visited_ids.contains(*node))
        .count();

    let recall = (matching_count as f32 / EXPECTED_NODES.len() as f32) * 100.0;

    assert!(recall >= 60.0, "Match percentage is below 60%: {}", recall);
}
2139}