1use std::{
7 collections::HashMap,
8 future::Future,
9 num::NonZeroUsize,
10 ops::Range,
11 sync::{
12 atomic::{AtomicU64, AtomicUsize},
13 Arc,
14 },
15 time::Instant,
16};
17
18use diskann::{
19 graph::{
20 self,
21 glue::{self, ExpandBeam, IdIterator, SearchExt, SearchPostProcess, SearchStrategy},
22 search_output_buffer, AdjacencyList, DiskANNIndex, SearchOutputBuffer, SearchParams,
23 },
24 neighbor::Neighbor,
25 provider::{
26 Accessor, BuildQueryComputer, DataProvider, DefaultContext, DelegateNeighbor, HasId,
27 NeighborAccessor,
28 },
29 utils::{
30 object_pool::{ObjectPool, PoolOption, TryAsPooled},
31 IntoUsize, VectorRepr,
32 },
33 ANNError, ANNResult,
34};
35use diskann_providers::storage::StorageReadProvider;
36use diskann_providers::{
37 model::{
38 compute_pq_distance, compute_pq_distance_for_pq_coordinates, graph::traits::GraphDataType,
39 pq::quantizer_preprocess, PQData, PQScratch,
40 },
41 storage::{get_compressed_pq_file, get_disk_index_file, get_pq_pivot_file, LoadWith},
42};
43use diskann_vector::{distance::Metric, DistanceFunction, PreprocessedDistanceFunction};
44use futures_util::future;
45use tokio::runtime::Runtime;
46use tracing::debug;
47
48use crate::{
49 data_model::{CachingStrategy, GraphHeader},
50 filter_parameter::{default_vector_filter, VectorFilter},
51 search::{
52 provider::disk_vertex_provider_factory::DiskVertexProviderFactory,
53 traits::{VertexProvider, VertexProviderFactory},
54 },
55 storage::{api::AsyncDiskLoadContext, disk_index_reader::DiskIndexReader},
56 utils::AlignedFileReaderFactory,
57 utils::QueryStatistics,
58};
59
/// Read-only data provider used by disk-index search.
///
/// Holds the graph header, the full-precision distance function and the PQ data
/// needed to score candidates. Internal and external ids are the same `u32`.
pub struct DiskProvider<Data>
where
    Data: GraphDataType<VectorIdType = u32>,
{
    /// Header of the on-disk graph; its metadata supplies dims, medoid and point count.
    graph_header: GraphHeader,

    /// Full-precision distance function used to compute exact query distances.
    distance_comparer: <Data::VectorDataType as VectorRepr>::Distance,

    /// PQ data (chunk/center counts and compressed vectors) obtained from the
    /// `DiskIndexReader`.
    pq_data: Arc<PQData>,

    /// Number of points in the index; `id_iterator` yields `0..num_points`.
    num_points: usize,

    /// Metric used for the full-precision comparer and for PQ query preprocessing.
    metric: Metric,

    /// Budget of vertex loads per search, compared against the tracker's IO count.
    search_io_limit: usize,
}
91
92impl<Data> DataProvider for DiskProvider<Data>
93where
94 Data: GraphDataType<VectorIdType = u32>,
95{
96 type Context = DefaultContext;
97
98 type InternalId = u32;
99
100 type ExternalId = u32;
101
102 type Error = ANNError;
103
104 fn to_internal_id(
106 &self,
107 _context: &DefaultContext,
108 gid: &Self::ExternalId,
109 ) -> Result<Self::InternalId, Self::Error> {
110 Ok(*gid)
111 }
112
113 fn to_external_id(
115 &self,
116 _context: &DefaultContext,
117 id: Self::InternalId,
118 ) -> Result<Self::ExternalId, Self::Error> {
119 Ok(id)
120 }
121}
122
123impl<Data> LoadWith<AsyncDiskLoadContext> for DiskProvider<Data>
124where
125 Data: GraphDataType<VectorIdType = u32>,
126{
127 type Error = ANNError;
128
129 async fn load_with<P>(provider: &P, ctx: &AsyncDiskLoadContext) -> ANNResult<Self>
130 where
131 P: StorageReadProvider,
132 {
133 debug!(
134 "DiskProvider::load_with() called with file: {:?}",
135 get_disk_index_file(ctx.quant_load_context.metadata.prefix())
136 );
137
138 let graph_header = {
139 let aligned_reader_factory = AlignedFileReaderFactory::new(get_disk_index_file(
140 ctx.quant_load_context.metadata.prefix(),
141 ));
142
143 let caching_strategy = if ctx.num_nodes_to_cache > 0 {
144 CachingStrategy::StaticCacheWithBfsNodes(ctx.num_nodes_to_cache)
145 } else {
146 CachingStrategy::None
147 };
148
149 let vertex_provider_factory = DiskVertexProviderFactory::<Data, _>::new(
150 aligned_reader_factory,
151 caching_strategy,
152 )?;
153 VertexProviderFactory::get_header(&vertex_provider_factory)?
154 };
155
156 let metric = ctx.quant_load_context.metric;
157 let num_points = ctx.num_points;
158
159 let index_path_prefix = ctx.quant_load_context.metadata.prefix();
160 let index_reader = DiskIndexReader::<<Data as GraphDataType>::VectorDataType>::new(
161 get_pq_pivot_file(index_path_prefix),
162 get_compressed_pq_file(index_path_prefix),
163 provider,
164 )?;
165
166 Self::new(
167 &index_reader,
168 graph_header,
169 metric,
170 num_points,
171 ctx.search_io_limit,
172 )
173 }
174}
175
176impl<Data> DiskProvider<Data>
177where
178 Data: GraphDataType<VectorIdType = u32>,
179{
180 fn new(
181 disk_index_reader: &DiskIndexReader<Data::VectorDataType>,
182 graph_header: GraphHeader,
183 metric: Metric,
184 num_points: usize,
185 search_io_limit: usize,
186 ) -> ANNResult<Self> {
187 let distance_comparer =
188 Data::VectorDataType::distance(metric, Some(graph_header.metadata().dims));
189
190 let pq_data = disk_index_reader.get_pq_data();
191
192 Ok(Self {
193 graph_header,
194 distance_comparer,
195 pq_data,
196 num_points,
197 metric,
198 search_io_limit,
199 })
200 }
201}
202
/// Search strategy for one query against a [`DiskProvider`].
///
/// Combines per-query state (query vector, filter, IO counters) with searcher-owned
/// resources (vertex provider factory, scratch pool) used to build [`DiskAccessor`]s.
pub struct DiskSearchStrategy<'a, Data, ProviderFactory>
where
    Data: GraphDataType<VectorIdType = u32>,
    ProviderFactory: VertexProviderFactory<Data>,
{
    /// IO and timing counters shared by every accessor created from this strategy.
    io_tracker: IOTracker,
    /// Predicate handed to the `RerankAndFilter` post-processor.
    vector_filter: &'a (dyn Fn(&u32) -> bool + Send + Sync),
    /// The full-precision query vector.
    query: &'a [Data::VectorDataType],

    /// Factory passed to pooled scratch creation to build vertex providers.
    vertex_provider_factory: &'a ProviderFactory,

    /// Pool of reusable per-search scratch buffers.
    scratch_pool: &'a Arc<ObjectPool<DiskSearchScratch<Data, ProviderFactory::VertexProviderType>>>,
}
229
/// Per-search IO and timing counters.
///
/// All counters start at zero. They are atomics so that accessors holding only a
/// shared reference can update them.
#[derive(Default)]
struct IOTracker {
    /// Accumulated time spent loading vertices, in microseconds.
    io_time_us: AtomicU64,
    /// Accumulated time spent preprocessing the query for PQ, in microseconds.
    preprocess_time_us: AtomicU64,
    /// Number of vertex loads counted against the search's IO budget.
    io_count: AtomicUsize,
}
247
248impl IOTracker {
249 fn add_time(category: &AtomicU64, time: u64) {
250 category.fetch_add(time, std::sync::atomic::Ordering::Relaxed);
251 }
252
253 fn time(category: &AtomicU64) -> u64 {
254 category.load(std::sync::atomic::Ordering::Relaxed)
255 }
256
257 fn add_io_count(&self, count: usize) {
258 self.io_count
259 .fetch_add(count, std::sync::atomic::Ordering::Relaxed);
260 }
261
262 fn io_count(&self) -> usize {
263 self.io_count.load(std::sync::atomic::Ordering::Relaxed)
264 }
265}
266
/// Post-processor that drops filtered-out candidates and reranks the rest by
/// full-precision distance.
#[derive(Clone, Copy)]
pub struct RerankAndFilter<'a> {
    /// Predicate deciding whether a candidate id may appear in the output.
    filter: &'a (dyn Fn(&u32) -> bool + Send + Sync),
}

impl<'a> RerankAndFilter<'a> {
    /// Wraps `filter`; ids for which it returns `false` are excluded from results.
    fn new(filter: &'a (dyn Fn(&u32) -> bool + Send + Sync)) -> Self {
        Self { filter }
    }
}
277
278impl<Data, VP>
279 SearchPostProcess<
280 DiskAccessor<'_, Data, VP>,
281 [Data::VectorDataType],
282 (
283 <DiskProvider<Data> as DataProvider>::InternalId,
284 Data::AssociatedDataType,
285 ),
286 > for RerankAndFilter<'_>
287where
288 Data: GraphDataType<VectorIdType = u32>,
289 VP: VertexProvider<Data>,
290{
291 type Error = ANNError;
292 async fn post_process<I, B>(
293 &self,
294 accessor: &mut DiskAccessor<'_, Data, VP>,
295 query: &[Data::VectorDataType],
296 _computer: &DiskQueryComputer,
297 candidates: I,
298 output: &mut B,
299 ) -> Result<usize, Self::Error>
300 where
301 I: Iterator<Item = Neighbor<u32>> + Send,
302 B: SearchOutputBuffer<(u32, Data::AssociatedDataType)> + Send + ?Sized,
303 {
304 let provider = accessor.provider;
305
306 let mut uncached_ids = Vec::new();
307 let mut reranked = candidates
308 .map(|n| n.id)
309 .filter(|id| (self.filter)(id))
310 .filter_map(|n| {
311 if let Some(entry) = accessor.scratch.distance_cache.get(&n) {
312 Some(Ok::<((u32, _), f32), ANNError>(((n, entry.1), entry.0)))
313 } else {
314 uncached_ids.push(n);
315 None
316 }
317 })
318 .collect::<Result<Vec<_>, _>>()?;
319 if !uncached_ids.is_empty() {
320 ensure_vertex_loaded(&mut accessor.scratch.vertex_provider, &uncached_ids)?;
321 for n in &uncached_ids {
322 let v = accessor.scratch.vertex_provider.get_vector(n)?;
323 let d = provider.distance_comparer.evaluate_similarity(query, v);
324 let a = accessor.scratch.vertex_provider.get_associated_data(n)?;
325 reranked.push(((*n, *a), d));
326 }
327 }
328
329 reranked
331 .sort_unstable_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
332 Ok(output.extend(reranked))
334 }
335}
336
337impl<'this, Data, ProviderFactory>
338 SearchStrategy<
339 DiskProvider<Data>,
340 [Data::VectorDataType],
341 (
342 <DiskProvider<Data> as DataProvider>::InternalId,
343 Data::AssociatedDataType,
344 ),
345 > for DiskSearchStrategy<'this, Data, ProviderFactory>
346where
347 Data: GraphDataType<VectorIdType = u32>,
348 ProviderFactory: VertexProviderFactory<Data>,
349{
350 type QueryComputer = DiskQueryComputer;
351 type SearchAccessor<'a> = DiskAccessor<'a, Data, ProviderFactory::VertexProviderType>;
352 type SearchAccessorError = ANNError;
353 type PostProcessor = RerankAndFilter<'this>;
354
355 fn search_accessor<'a>(
356 &'a self,
357 provider: &'a DiskProvider<Data>,
358 _context: &DefaultContext,
359 ) -> Result<Self::SearchAccessor<'a>, Self::SearchAccessorError> {
360 DiskAccessor::new(
361 provider,
362 &self.io_tracker,
363 self.query,
364 self.vertex_provider_factory,
365 self.scratch_pool,
366 )
367 }
368
369 fn post_processor(&self) -> Self::PostProcessor {
370 RerankAndFilter::new(self.vector_filter)
371 }
372}
373
/// PQ distance computer for one query.
pub struct DiskQueryComputer {
    /// Number of PQ chunks per compressed vector.
    num_pq_chunks: usize,
    /// Copy of the accessor's `aligned_pqtable_dist_scratch` table for this query.
    query_centroid_l2_distance: Vec<f32>,
}
379
380impl PreprocessedDistanceFunction<&[u8], f32> for DiskQueryComputer {
381 fn evaluate_similarity(&self, changing: &[u8]) -> f32 {
382 let mut dist = 0.0f32;
383 #[allow(clippy::expect_used)]
384 compute_pq_distance_for_pq_coordinates(
385 changing,
386 self.num_pq_chunks,
387 &self.query_centroid_l2_distance,
388 std::slice::from_mut(&mut dist),
389 )
390 .expect("PQ distance compute for PQ coordinates is expected to succeed");
391 dist
392 }
393}
394
395impl<Data, VP> BuildQueryComputer<[Data::VectorDataType]> for DiskAccessor<'_, Data, VP>
396where
397 Data: GraphDataType<VectorIdType = u32>,
398 VP: VertexProvider<Data>,
399{
400 type QueryComputerError = ANNError;
401 type QueryComputer = DiskQueryComputer;
402
403 fn build_query_computer(
404 &self,
405 _from: &[Data::VectorDataType],
406 ) -> Result<Self::QueryComputer, Self::QueryComputerError> {
407 Ok(DiskQueryComputer {
408 num_pq_chunks: self.provider.pq_data.get_num_chunks(),
409 query_centroid_l2_distance: self
410 .scratch
411 .pq_scratch
412 .aligned_pqtable_dist_scratch
413 .as_slice()
414 .to_vec(),
415 })
416 }
417
418 async fn distances_unordered<Itr, F>(
419 &mut self,
420 vec_id_itr: Itr,
421 _computer: &Self::QueryComputer,
422 f: F,
423 ) -> Result<(), Self::GetError>
424 where
425 F: Send + FnMut(f32, Self::Id),
426 Itr: Iterator<Item = Self::Id>,
427 {
428 self.pq_distances(&vec_id_itr.collect::<Box<[_]>>(), f)
429 }
430}
431
432impl<Data, VP> ExpandBeam<[Data::VectorDataType]> for DiskAccessor<'_, Data, VP>
433where
434 Data: GraphDataType<VectorIdType = u32>,
435 VP: VertexProvider<Data>,
436{
437 fn expand_beam<Itr, P, F>(
438 &mut self,
439 ids: Itr,
440 _computer: &Self::QueryComputer,
441 mut pred: P,
442 mut f: F,
443 ) -> impl std::future::Future<Output = Result<(), Self::GetError>> + Send
444 where
445 Itr: Iterator<Item = Self::Id> + Send,
446 P: glue::HybridPredicate<Self::Id> + Send + Sync,
447 F: FnMut(f32, Self::Id) + Send,
448 {
449 let result = (|| {
450 let io_limit = self.provider.search_io_limit - self.io_tracker.io_count();
451 let load_ids: Box<[_]> = ids.take(io_limit).collect();
452
453 self.ensure_loaded(&load_ids)?;
454 let mut ids = Vec::new();
455 for i in load_ids {
456 ids.clear();
457 ids.extend(
458 self.scratch
459 .vertex_provider
460 .get_adjacency_list(&i)?
461 .iter()
462 .copied()
463 .filter(|id| pred.eval_mut(id)),
464 );
465
466 self.pq_distances(&ids, &mut f)?;
467 }
468
469 Ok(())
470 })();
471
472 std::future::ready(result)
473 }
474}
475
/// Per-search scratch space taken from an [`ObjectPool`]; cleared by `try_modify`
/// before reuse.
struct DiskSearchScratch<Data, VP>
where
    Data: GraphDataType<VectorIdType = u32>,
    VP: VertexProvider<Data>,
{
    /// Exact distance and associated data for each vertex loaded during this search,
    /// keyed by vertex id; consulted when reranking.
    distance_cache: HashMap<u32, (f32, Data::AssociatedDataType)>,
    /// Buffers for the query's PQ table and batched PQ distance computation.
    pq_scratch: PQScratch,
    /// Loads and serves vertices (vectors, adjacency lists, associated data) from disk.
    vertex_provider: VP,
}
487
/// Arguments used to create (or reuse) a pooled [`DiskSearchScratch`].
#[derive(Clone)]
struct DiskSearchScratchArgs<'a, ProviderFactory> {
    /// Maximum graph degree; passed to `PQScratch::new`.
    graph_degree: usize,
    /// Vector dimensionality; passed to `PQScratch::new`.
    dim: usize,
    /// Number of PQ chunks; passed to `PQScratch::new`.
    num_pq_chunks: usize,
    /// Number of PQ centers; passed to `PQScratch::new`.
    num_pq_centers: usize,
    /// Factory that creates the scratch's vertex provider.
    vertex_factory: &'a ProviderFactory,
    /// Graph header handed to the vertex provider factory.
    graph_header: &'a GraphHeader,
}
497
498impl<Data, ProviderFactory> TryAsPooled<&DiskSearchScratchArgs<'_, ProviderFactory>>
499 for DiskSearchScratch<Data, ProviderFactory::VertexProviderType>
500where
501 Data: GraphDataType<VectorIdType = u32>,
502 ProviderFactory: VertexProviderFactory<Data>,
503{
504 type Error = ANNError;
505
506 fn try_create(args: &DiskSearchScratchArgs<ProviderFactory>) -> Result<Self, Self::Error> {
507 let pq_scratch = PQScratch::new(
508 args.graph_degree,
509 args.dim,
510 args.num_pq_chunks,
511 args.num_pq_centers,
512 )?;
513
514 const DEFAULT_BEAM_WIDTH: usize = 0; let vertex_provider = args
516 .vertex_factory
517 .create_vertex_provider(DEFAULT_BEAM_WIDTH, args.graph_header)?;
518
519 Ok(Self {
520 distance_cache: HashMap::new(),
521 pq_scratch,
522 vertex_provider,
523 })
524 }
525
526 fn try_modify(
527 &mut self,
528 _args: &DiskSearchScratchArgs<ProviderFactory>,
529 ) -> Result<(), Self::Error> {
530 self.distance_cache.clear();
531 self.vertex_provider.clear();
532 Ok(())
533 }
534}
535
/// Accessor used by a single search over a [`DiskProvider`].
///
/// Holds scratch space taken from the searcher's pool, the query, and a reference to
/// the owning strategy's IO tracker.
pub struct DiskAccessor<'a, Data, VP>
where
    Data: GraphDataType<VectorIdType = u32>,
    VP: VertexProvider<Data>,
{
    /// Provider supplying PQ data, the graph header and search limits.
    provider: &'a DiskProvider<Data>,
    /// IO/timing counters shared with the owning search strategy.
    io_tracker: &'a IOTracker,
    /// Pooled per-search scratch (PQ buffers, vertex provider, distance cache).
    scratch: PoolOption<DiskSearchScratch<Data, VP>>,
    /// Full-precision query vector.
    query: &'a [Data::VectorDataType],
}
546
547impl<Data, VP> DiskAccessor<'_, Data, VP>
548where
549 Data: GraphDataType<VectorIdType = u32>,
550 VP: VertexProvider<Data>,
551{
552 fn pq_distances<F>(&mut self, ids: &[u32], mut f: F) -> ANNResult<()>
555 where
556 F: FnMut(f32, u32),
557 {
558 let pq_scratch = &mut self.scratch.pq_scratch;
559 compute_pq_distance(
560 ids,
561 self.provider.pq_data.get_num_chunks(),
562 &pq_scratch.aligned_pqtable_dist_scratch,
563 self.provider.pq_data.pq_compressed_data().get_data(),
564 &mut pq_scratch.aligned_pq_coord_scratch,
565 &mut pq_scratch.aligned_dist_scratch,
566 )?;
567
568 for (i, id) in ids.iter().enumerate() {
569 let distance = self.scratch.pq_scratch.aligned_dist_scratch[i];
570 f(distance, *id);
571 }
572
573 Ok(())
574 }
575}
576
577impl<Data, VP> SearchExt for DiskAccessor<'_, Data, VP>
578where
579 Data: GraphDataType<VectorIdType = u32>,
580 VP: VertexProvider<Data>,
581{
582 async fn starting_points(&self) -> ANNResult<Vec<u32>> {
583 let start_vertex_id = self.provider.graph_header.metadata().medoid as u32;
584 Ok(vec![start_vertex_id])
585 }
586
587 fn terminate_early(&mut self) -> bool {
588 self.io_tracker.io_count() > self.provider.search_io_limit
589 }
590}
591
592impl<'a, Data, VP> DiskAccessor<'a, Data, VP>
593where
594 Data: GraphDataType<VectorIdType = u32>,
595 VP: VertexProvider<Data>,
596{
597 fn new<VPF>(
598 provider: &'a DiskProvider<Data>,
599 io_tracker: &'a IOTracker,
600 query: &'a [Data::VectorDataType],
601 vertex_provider_factory: &'a VPF,
602 scratch_pool: &'a Arc<ObjectPool<DiskSearchScratch<Data, VP>>>,
603 ) -> ANNResult<Self>
604 where
605 VPF: VertexProviderFactory<Data, VertexProviderType = VP>,
606 {
607 let mut scratch = PoolOption::try_pooled(
608 scratch_pool,
609 &DiskSearchScratchArgs {
610 graph_degree: provider.graph_header.max_degree::<Data::VectorDataType>()?,
611 dim: provider.graph_header.metadata().dims,
612 num_pq_chunks: provider.pq_data.get_num_chunks(),
613 num_pq_centers: provider.pq_data.get_num_centers(),
614 vertex_factory: vertex_provider_factory,
615 graph_header: &provider.graph_header,
616 },
617 )?;
618
619 scratch.pq_scratch.set(
620 provider.graph_header.metadata().dims,
621 query,
622 1.0_f32, )?;
624 let start_vertex_id = provider.graph_header.metadata().medoid as u32;
625
626 let timer = Instant::now();
627 quantizer_preprocess(
628 &mut scratch.pq_scratch,
629 &provider.pq_data,
630 provider.metric,
631 &[start_vertex_id],
632 )?;
633 IOTracker::add_time(
634 &io_tracker.preprocess_time_us,
635 timer.elapsed().as_micros() as u64,
636 );
637
638 Ok(Self {
639 provider,
640 io_tracker,
641 scratch,
642 query,
643 })
644 }
645 fn ensure_loaded(&mut self, ids: &[u32]) -> Result<(), ANNError> {
646 if ids.is_empty() {
647 return Ok(());
648 }
649 let scratch = &mut self.scratch;
650 let timer = Instant::now();
651 ensure_vertex_loaded(&mut scratch.vertex_provider, ids)?;
652 IOTracker::add_time(
653 &self.io_tracker.io_time_us,
654 timer.elapsed().as_micros() as u64,
655 );
656 self.io_tracker.add_io_count(ids.len());
657 for id in ids {
658 let distance = self
659 .provider
660 .distance_comparer
661 .evaluate_similarity(self.query, scratch.vertex_provider.get_vector(id)?);
662 let associated_data = *scratch.vertex_provider.get_associated_data(id)?;
663 scratch
664 .distance_cache
665 .insert(*id, (distance, associated_data));
666 }
667 Ok(())
668 }
669}
670
/// [`DiskAccessor`] identifies vertices by the provider's `u32` ids.
impl<Data, VP> HasId for DiskAccessor<'_, Data, VP>
where
    Data: GraphDataType<VectorIdType = u32>,
    VP: VertexProvider<Data>,
{
    type Id = u32;
}
678
679impl<'a, Data, VP> Accessor for DiskAccessor<'a, Data, VP>
680where
681 Data: GraphDataType<VectorIdType = u32>,
682 VP: VertexProvider<Data>,
683{
684 type Extended = &'a [u8];
686
687 type Element<'b>
693 = &'a [u8]
694 where
695 Self: 'b;
696
697 type ElementRef<'b> = &'b [u8];
699
700 type GetError = ANNError;
702
703 fn get_element(
704 &mut self,
705 id: Self::Id,
706 ) -> impl Future<Output = Result<Self::Element<'_>, Self::GetError>> + Send {
707 std::future::ready(self.provider.pq_data.get_compressed_vector(id as usize))
708 }
709}
710
711impl<Data, VP> IdIterator<Range<u32>> for DiskAccessor<'_, Data, VP>
712where
713 Data: GraphDataType<VectorIdType = u32>,
714 VP: VertexProvider<Data>,
715{
716 async fn id_iterator(&mut self) -> Result<Range<u32>, ANNError> {
717 Ok(0..self.provider.num_points as u32)
718 }
719}
720
/// Neighbor lookups on a [`DiskAccessor`] go through [`AsNeighborAccessor`], which
/// mutably borrows the accessor for the duration of the lookup.
impl<'a, 'b, Data, VP> DelegateNeighbor<'a> for DiskAccessor<'b, Data, VP>
where
    Data: GraphDataType<VectorIdType = u32>,
    VP: VertexProvider<Data>,
{
    type Delegate = AsNeighborAccessor<'a, 'b, Data, VP>;
    fn delegate_neighbor(&'a mut self) -> Self::Delegate {
        AsNeighborAccessor(self)
    }
}
731
/// Adapter exposing a mutably borrowed [`DiskAccessor`] as a [`NeighborAccessor`].
pub struct AsNeighborAccessor<'a, 'b, Data, VP>(&'a mut DiskAccessor<'b, Data, VP>)
where
    Data: GraphDataType<VectorIdType = u32>,
    VP: VertexProvider<Data>;
741
/// Uses the same `u32` vertex ids as the wrapped [`DiskAccessor`].
impl<Data, VP> HasId for AsNeighborAccessor<'_, '_, Data, VP>
where
    Data: GraphDataType<VectorIdType = u32>,
    VP: VertexProvider<Data>,
{
    type Id = u32;
}
749
750impl<Data, VP> NeighborAccessor for AsNeighborAccessor<'_, '_, Data, VP>
751where
752 Data: GraphDataType<VectorIdType = u32>,
753 VP: VertexProvider<Data>,
754{
755 fn get_neighbors(
756 self,
757 id: Self::Id,
758 neighbors: &mut AdjacencyList<Self::Id>,
759 ) -> impl Future<Output = ANNResult<Self>> + Send {
760 if self.0.io_tracker.io_count() > self.0.provider.search_io_limit {
761 return future::ok(self); }
763
764 if let Err(e) = ensure_vertex_loaded(&mut self.0.scratch.vertex_provider, &[id]) {
765 return future::err(e);
766 }
767 let list = match self.0.scratch.vertex_provider.get_adjacency_list(&id) {
768 Ok(list) => list,
769 Err(e) => return future::err(e),
770 };
771 neighbors.overwrite_trusted(list);
772 future::ok(self)
773 }
774}
775
/// Synchronous searcher over an on-disk DiskANN index.
///
/// Wraps a [`DiskANNIndex`] over a [`DiskProvider`] and blocks on its async search API
/// using an owned Tokio runtime.
pub struct DiskIndexSearcher<
    Data,
    ProviderFactory = DiskVertexProviderFactory<Data, AlignedFileReaderFactory>,
> where
    Data: GraphDataType<VectorIdType = u32>,
    ProviderFactory: VertexProviderFactory<Data>,
{
    /// The graph index and its disk-backed data provider.
    index: DiskANNIndex<DiskProvider<Data>>,
    /// Runtime used to block on the index's async search calls.
    runtime: Runtime,

    /// Factory used to create vertex providers for search scratch.
    vertex_provider_factory: ProviderFactory,

    /// Pool of reusable per-search scratch buffers shared across searches.
    scratch_pool: Arc<ObjectPool<DiskSearchScratch<Data, ProviderFactory::VertexProviderType>>>,
}
795
/// Summary statistics for one search.
#[derive(Debug)]
pub struct SearchResultStats {
    /// Number of distance comparisons reported by the index search.
    pub cmps: u32,
    /// Number of results reported by the index search.
    pub result_count: u32,
    /// Detailed timing and IO statistics for the query.
    pub query_statistics: QueryStatistics,
}

/// Results of one search together with its statistics.
pub struct SearchResult<AssociatedData> {
    /// Result items, as produced by `DiskIndexSearcher::search`.
    pub results: Vec<SearchResultItem<AssociatedData>>,
    /// Statistics of the search that produced `results`.
    pub stats: SearchResultStats,
}

/// One search hit.
pub struct SearchResultItem<AssociatedData> {
    /// Id of the matched vertex.
    pub vertex_id: u32,
    /// Associated data stored with the vertex.
    pub data: AssociatedData,
    /// Distance between the query and the vertex, as written by the search.
    pub distance: f32,
}
826
827impl<Data, ProviderFactory> DiskIndexSearcher<Data, ProviderFactory>
828where
829 Data: GraphDataType<VectorIdType = u32>,
830 ProviderFactory: VertexProviderFactory<Data>,
831{
832 pub fn new(
842 num_threads: usize,
843 search_io_limit: usize,
844 disk_index_reader: &DiskIndexReader<Data::VectorDataType>,
845 vertex_provider_factory: ProviderFactory,
846 metric: Metric,
847 runtime: Option<Runtime>,
848 ) -> ANNResult<Self> {
849 let runtime = match runtime {
850 Some(rt) => rt,
851 None => tokio::runtime::Builder::new_current_thread().build()?,
852 };
853
854 let graph_header = vertex_provider_factory.get_header()?;
855 let metadata = graph_header.metadata();
856 let max_degree = graph_header.max_degree::<Data::VectorDataType>()? as u32;
857
858 let config = graph::config::Builder::new(
859 max_degree.into_usize(),
860 graph::config::MaxDegree::default_slack(),
861 1, metric.into(),
863 )
864 .build()?;
865
866 debug!("Creating DiskIndexSearcher with index_config: {:?}", config);
867
868 let graph_header = vertex_provider_factory.get_header()?;
869 let pq_data = disk_index_reader.get_pq_data();
870 let scratch_pool_args = DiskSearchScratchArgs {
871 graph_degree: graph_header.max_degree::<Data::VectorDataType>()?,
872 dim: graph_header.metadata().dims,
873 num_pq_chunks: pq_data.get_num_chunks(),
874 num_pq_centers: pq_data.get_num_centers(),
875 vertex_factory: &vertex_provider_factory,
876 graph_header: &graph_header,
877 };
878 let scratch_pool = Arc::new(ObjectPool::try_new(&scratch_pool_args, 0, None)?);
879
880 let disk_provider = DiskProvider::new(
881 disk_index_reader,
882 graph_header,
883 metric,
884 metadata.num_pts.into_usize(),
885 search_io_limit,
886 )?;
887
888 let index = DiskANNIndex::new(config, disk_provider, NonZeroUsize::new(num_threads));
889 Ok(Self {
890 index,
891 runtime,
892 vertex_provider_factory,
893 scratch_pool,
894 })
895 }
896
897 fn search_strategy<'a>(
899 &'a self,
900 query: &'a [Data::VectorDataType],
901 vector_filter: &'a (dyn Fn(&Data::VectorIdType) -> bool + Send + Sync),
902 ) -> DiskSearchStrategy<'a, Data, ProviderFactory> {
903 DiskSearchStrategy {
904 io_tracker: IOTracker::default(),
905 vector_filter,
906 query,
907 vertex_provider_factory: &self.vertex_provider_factory,
908 scratch_pool: &self.scratch_pool,
909 }
910 }
911
912 pub fn search(
915 &self,
916 query: &[Data::VectorDataType],
917 return_list_size: u32,
918 search_list_size: u32,
919 beam_width: Option<usize>,
920 vector_filter: Option<VectorFilter<Data>>,
921 is_flat_search: bool,
922 ) -> ANNResult<SearchResult<Data::AssociatedDataType>> {
923 let mut query_stats = QueryStatistics::default();
924 let mut indices = vec![0u32; return_list_size as usize];
925 let mut distances = vec![0f32; return_list_size as usize];
926 let mut associated_data =
927 vec![Data::AssociatedDataType::default(); return_list_size as usize];
928
929 let stats = self.search_internal(
930 query,
931 return_list_size as usize,
932 search_list_size,
933 beam_width,
934 &mut query_stats,
935 &mut indices,
936 &mut distances,
937 &mut associated_data,
938 &vector_filter.unwrap_or(default_vector_filter::<Data>()),
939 is_flat_search,
940 )?;
941
942 let mut search_result = SearchResult {
943 results: Vec::with_capacity(return_list_size as usize),
944 stats,
945 };
946
947 for ((vertex_id, distance), associated_data) in indices
948 .into_iter()
949 .zip(distances.into_iter())
950 .zip(associated_data.into_iter())
951 {
952 search_result.results.push(SearchResultItem {
953 vertex_id,
954 distance,
955 data: associated_data,
956 });
957 }
958
959 Ok(search_result)
960 }
961
962 #[allow(clippy::too_many_arguments)]
965 pub(crate) fn search_internal(
966 &self,
967 query: &[Data::VectorDataType],
968 k_value: usize,
969 search_list_size: u32,
970 beam_width: Option<usize>,
971 query_stats: &mut QueryStatistics,
972 indices: &mut [u32],
973 distances: &mut [f32],
974 associated_data: &mut [Data::AssociatedDataType],
975 vector_filter: &(dyn Fn(&Data::VectorIdType) -> bool + Send + Sync),
976 is_flat_search: bool,
977 ) -> ANNResult<SearchResultStats> {
978 let mut result_output_buffer = search_output_buffer::IdDistanceAssociatedData::new(
979 &mut indices[..k_value],
980 &mut distances[..k_value],
981 &mut associated_data[..k_value],
982 );
983
984 let strategy = self.search_strategy(query, vector_filter);
985 let timer = Instant::now();
986 let stats = if is_flat_search {
987 self.runtime.block_on(self.index.flat_search(
988 &strategy,
989 &DefaultContext,
990 strategy.query,
991 vector_filter,
992 &SearchParams::new(k_value, search_list_size as usize, beam_width)?,
993 &mut result_output_buffer,
994 ))?
995 } else {
996 self.runtime.block_on(self.index.search(
997 &strategy,
998 &DefaultContext,
999 strategy.query,
1000 &SearchParams::new(k_value, search_list_size as usize, beam_width)?,
1001 &mut result_output_buffer,
1002 ))?
1003 };
1004 query_stats.total_comparisons = stats.cmps;
1005 query_stats.search_hops = stats.hops;
1006
1007 query_stats.total_execution_time_us = timer.elapsed().as_micros();
1008 query_stats.io_time_us = IOTracker::time(&strategy.io_tracker.io_time_us) as u128;
1009 query_stats.total_io_operations = strategy.io_tracker.io_count() as u32;
1010 query_stats.total_vertices_loaded = strategy.io_tracker.io_count() as u32;
1011 query_stats.query_pq_preprocess_time_us =
1012 IOTracker::time(&strategy.io_tracker.preprocess_time_us) as u128;
1013 query_stats.cpu_time_us = query_stats.total_execution_time_us
1014 - query_stats.io_time_us
1015 - query_stats.query_pq_preprocess_time_us;
1016 Ok(SearchResultStats {
1017 cmps: query_stats.total_comparisons,
1018 result_count: stats.result_count,
1019 query_statistics: query_stats.clone(),
1020 })
1021 }
1022}
1023
1024fn ensure_vertex_loaded<Data: GraphDataType, V: VertexProvider<Data>>(
1030 vertex_provider: &mut V,
1031 ids: &[Data::VectorIdType],
1032) -> ANNResult<()> {
1033 vertex_provider.load_vertices(ids)?;
1034 for (idx, id) in ids.iter().enumerate() {
1035 vertex_provider.process_loaded_node(id, idx)?;
1036 }
1037 Ok(())
1038}
1039
1040#[cfg(test)]
1041mod disk_provider_tests {
1042 use diskann::{
1043 graph::{search::record::VisitedSearchRecord, SearchParamsError},
1044 utils::IntoUsize,
1045 ANNErrorKind,
1046 };
1047 use diskann_providers::storage::{
1048 DynWriteProvider, StorageReadProvider, VirtualStorageProvider,
1049 };
1050 use diskann_providers::{
1051 common::AlignedBoxWithSlice,
1052 test_utils::graph_data_type_utils::{
1053 GraphDataF32VectorU32Data, GraphDataF32VectorUnitData,
1054 },
1055 utils::{
1056 create_thread_pool, file_util, load_aligned_bin, PQPathNames, ParallelIteratorInPool,
1057 },
1058 };
1059 use diskann_utils::test_data_root;
1060 use diskann_vector::distance::Metric;
1061 use rayon::prelude::{IndexedParallelIterator, IntoParallelRefIterator};
1062 use rstest::rstest;
1063 use vfs::OverlayFS;
1064
1065 use super::*;
1066 use crate::{
1067 build::builder::core::disk_index_builder_tests::{IndexBuildFixture, TestParams},
1068 utils::{QueryStatistics, VirtualAlignedReaderFactory},
1069 };
1070
1071 const TEST_INDEX_PREFIX_128DIM: &str =
1072 "/disk_index_search/disk_index_sift_learn_R4_L50_A1.2_truth_search";
1073 const TEST_INDEX_128DIM: &str =
1074 "/disk_index_search/disk_index_sift_learn_R4_L50_A1.2_truth_search_disk.index";
1075 const TEST_PQ_PIVOT_128DIM: &str =
1076 "/disk_index_search/disk_index_sift_learn_R4_L50_A1.2_truth_search_pq_pivots.bin";
1077 const TEST_PQ_COMPRESSED_128DIM: &str =
1078 "/disk_index_search/disk_index_sift_learn_R4_L50_A1.2_truth_search_pq_compressed.bin";
1079 const TEST_TRUTH_RESULT_10PTS_128DIM: &str =
1080 "/disk_index_search/disk_index_10pts_idx_uint32_truth_search_res.bin";
1081 const TEST_QUERY_10PTS_128DIM: &str = "/disk_index_search/disk_index_sample_query_10pts.fbin";
1082
1083 const TEST_INDEX_PREFIX_100DIM: &str = "/disk_index_search/256pts_100dim_f32_truth_Index";
1084 const TEST_INDEX_100DIM: &str = "/disk_index_search/256pts_100dim_f32_truth_Index_disk.index";
1085 const TEST_PQ_PIVOT_100DIM: &str =
1086 "/disk_index_search/256pts_100dim_f32_truth_Index_pq_pivots.bin";
1087 const TEST_PQ_COMPRESSED_100DIM: &str =
1088 "/disk_index_search/256pts_100dim_f32_truth_Index_pq_compressed.bin";
1089 const TEST_TRUTH_RESULT_10PTS_100DIM: &str =
1090 "/disk_index_search/256pts_100dim_f32_truth_query_result.bin";
1091 const TEST_QUERY_10PTS_100DIM: &str = "/disk_index_search/10pts_100dim_f32_base_query.bin";
1092 const TEST_DATA_FILE: &str = "/disk_index_search/disk_index_siftsmall_learn_256pts_data.fbin";
1093 const TEST_INDEX: &str =
1094 "/disk_index_search/disk_index_sift_learn_R4_L50_A1.2_truth_search_disk.index";
1095 const TEST_INDEX_PREFIX: &str =
1096 "/disk_index_search/disk_index_sift_learn_R4_L50_A1.2_truth_search";
1097 const TEST_PQ_PIVOT: &str =
1098 "/disk_index_search/disk_index_sift_learn_R4_L50_A1.2_truth_search_pq_pivots.bin";
1099 const TEST_PQ_COMPRESSED: &str =
1100 "/disk_index_search/disk_index_sift_learn_R4_L50_A1.2_truth_search_pq_compressed.bin";
1101
1102 #[test]
1103 fn test_disk_search_k10_l20_single_or_multi_thread_100dim() {
1104 let storage_provider = Arc::new(VirtualStorageProvider::new_overlay(test_data_root()));
1105
1106 let search_engine = create_disk_index_searcher(
1107 CreateDiskIndexSearcherParams {
1108 max_thread_num: 5,
1109 pq_pivot_file_path: TEST_PQ_PIVOT_100DIM,
1110 pq_compressed_file_path: TEST_PQ_COMPRESSED_100DIM,
1111 index_path: TEST_INDEX_100DIM,
1112 index_path_prefix: TEST_INDEX_PREFIX_100DIM,
1113 ..Default::default()
1114 },
1115 &storage_provider,
1116 );
1117 test_disk_search(TestDiskSearchParams {
1119 storage_provider: storage_provider.as_ref(),
1120 index_search_engine: &search_engine,
1121 thread_num: 1,
1122 query_file_path: TEST_QUERY_10PTS_100DIM,
1123 truth_result_file_path: TEST_TRUTH_RESULT_10PTS_100DIM,
1124 k: 10,
1125 l: 20,
1126 dim: 104,
1127 });
1128 test_disk_search(TestDiskSearchParams {
1130 storage_provider: storage_provider.as_ref(),
1131 index_search_engine: &search_engine,
1132 thread_num: 5,
1133 query_file_path: TEST_QUERY_10PTS_100DIM,
1134 truth_result_file_path: TEST_TRUTH_RESULT_10PTS_100DIM,
1135 k: 10,
1136 l: 20,
1137 dim: 104,
1138 });
1139 }
1140
1141 #[test]
1142 fn test_disk_search_k10_l20_single_or_multi_thread_128dim() {
1143 let storage_provider = Arc::new(VirtualStorageProvider::new_overlay(test_data_root()));
1144
1145 let search_engine = create_disk_index_searcher::<GraphDataF32VectorUnitData>(
1146 CreateDiskIndexSearcherParams {
1147 max_thread_num: 5,
1148 pq_pivot_file_path: TEST_PQ_PIVOT_128DIM,
1149 pq_compressed_file_path: TEST_PQ_COMPRESSED_128DIM,
1150 index_path: TEST_INDEX_128DIM,
1151 index_path_prefix: TEST_INDEX_PREFIX_128DIM,
1152 ..Default::default()
1153 },
1154 &storage_provider,
1155 );
1156 test_disk_search(TestDiskSearchParams {
1158 storage_provider: storage_provider.as_ref(),
1159 index_search_engine: &search_engine,
1160 thread_num: 1,
1161 query_file_path: TEST_QUERY_10PTS_128DIM,
1162 truth_result_file_path: TEST_TRUTH_RESULT_10PTS_128DIM,
1163 k: 10,
1164 l: 20,
1165 dim: 128,
1166 });
1167 test_disk_search(TestDiskSearchParams {
1169 storage_provider: storage_provider.as_ref(),
1170 index_search_engine: &search_engine,
1171 thread_num: 5,
1172 query_file_path: TEST_QUERY_10PTS_128DIM,
1173 truth_result_file_path: TEST_TRUTH_RESULT_10PTS_128DIM,
1174 k: 10,
1175 l: 20,
1176 dim: 128,
1177 });
1178 }
1179
1180 fn get_truth_associated_data<StorageReader: StorageReadProvider>(
1181 storage_provider: &StorageReader,
1182 ) -> Vec<u32> {
1183 const ASSOCIATED_DATA_FILE: &str = "/sift/siftsmall_learn_256pts_u32_associated_data.fbin";
1184
1185 let (data, _npts, _dim) =
1186 file_util::load_bin::<u32, StorageReader>(storage_provider, ASSOCIATED_DATA_FILE, 0)
1187 .unwrap();
1188 data
1189 }
1190
1191 #[test]
1192 fn test_disk_search_with_associated_data_k10_l20_single_or_multi_thread_128dim() {
1193 let storage_provider = VirtualStorageProvider::new_overlay(test_data_root());
1194 let index_path_prefix = "/disk_index_search/disk_index_sift_learn_R4_L50_A1.2_test_disk_index_search_associated_data";
1195 let params = TestParams {
1196 data_path: TEST_DATA_FILE.to_string(),
1197 index_path_prefix: index_path_prefix.to_string(),
1198 associated_data_path: Some(
1199 "/sift/siftsmall_learn_256pts_u32_associated_data.fbin".to_string(),
1200 ),
1201 ..TestParams::default()
1202 };
1203 let fixture = IndexBuildFixture::new(storage_provider, params).unwrap();
1204 fixture.build::<GraphDataF32VectorU32Data>().unwrap();
1206 {
1207 let search_engine = create_disk_index_searcher::<GraphDataF32VectorU32Data>(
1208 CreateDiskIndexSearcherParams {
1209 max_thread_num: 5,
1210 pq_pivot_file_path: format!("{}_pq_pivots.bin", index_path_prefix).as_str(),
1211 pq_compressed_file_path: format!("{}_pq_compressed.bin", index_path_prefix)
1212 .as_str(),
1213 index_path: format!("{}_disk.index", index_path_prefix).as_str(), index_path_prefix,
1215 ..Default::default()
1216 },
1217 &fixture.storage_provider,
1218 );
1219
1220 test_disk_search_with_associated(
1222 TestDiskSearchAssociateParams {
1223 storage_provider: fixture.storage_provider.as_ref(),
1224 index_search_engine: &search_engine,
1225 thread_num: 1,
1226 query_file_path: TEST_QUERY_10PTS_128DIM,
1227 truth_result_file_path: TEST_TRUTH_RESULT_10PTS_128DIM,
1228 k: 10,
1229 l: 20,
1230 dim: 128,
1231 },
1232 None,
1233 );
1234
1235 test_disk_search_with_associated(
1237 TestDiskSearchAssociateParams {
1238 storage_provider: fixture.storage_provider.as_ref(),
1239 index_search_engine: &search_engine,
1240 thread_num: 5,
1241 query_file_path: TEST_QUERY_10PTS_128DIM,
1242 truth_result_file_path: TEST_TRUTH_RESULT_10PTS_128DIM,
1243 k: 10,
1244 l: 20,
1245 dim: 128,
1246 },
1247 None,
1248 );
1249 }
1250
1251 fixture
1252 .storage_provider
1253 .delete(&format!("{}_disk.index", index_path_prefix))
1254 .expect("Failed to delete file");
1255 fixture
1256 .storage_provider
1257 .delete(&format!("{}_pq_pivots.bin", index_path_prefix))
1258 .expect("Failed to delete file");
1259 fixture
1260 .storage_provider
1261 .delete(&format!("{}_pq_compressed.bin", index_path_prefix))
1262 .expect("Failed to delete file");
1263 }
1264
1265 struct CreateDiskIndexSearcherParams<'a> {
1266 max_thread_num: usize,
1267 pq_pivot_file_path: &'a str,
1268 pq_compressed_file_path: &'a str,
1269 index_path: &'a str,
1270 index_path_prefix: &'a str,
1271 io_limit: usize,
1272 }
1273
1274 impl Default for CreateDiskIndexSearcherParams<'_> {
1275 fn default() -> Self {
1276 Self {
1277 max_thread_num: 1,
1278 pq_pivot_file_path: "",
1279 pq_compressed_file_path: "",
1280 index_path: "",
1281 index_path_prefix: "",
1282 io_limit: usize::MAX,
1283 }
1284 }
1285 }
1286
1287 fn create_disk_index_searcher<Data>(
1288 params: CreateDiskIndexSearcherParams,
1289 storage_provider: &Arc<VirtualStorageProvider<OverlayFS>>,
1290 ) -> DiskIndexSearcher<
1291 Data,
1292 DiskVertexProviderFactory<Data, VirtualAlignedReaderFactory<OverlayFS>>,
1293 >
1294 where
1295 Data: GraphDataType<VectorIdType = u32>,
1296 {
1297 assert!(params.io_limit > 0);
1298
1299 let runtime = tokio::runtime::Builder::new_multi_thread()
1300 .worker_threads(params.max_thread_num)
1301 .build()
1302 .unwrap();
1303
1304 let disk_index_reader = DiskIndexReader::<Data::VectorDataType>::new(
1305 params.pq_pivot_file_path.to_string(),
1306 params.pq_compressed_file_path.to_string(),
1307 storage_provider.as_ref(),
1308 )
1309 .unwrap();
1310
1311 let aligned_reader_factory = VirtualAlignedReaderFactory::new(
1312 get_disk_index_file(params.index_path_prefix),
1313 Arc::clone(storage_provider),
1314 );
1315 let caching_strategy = CachingStrategy::None;
1316 let vertex_provider_factory =
1317 DiskVertexProviderFactory::<Data, _>::new(aligned_reader_factory, caching_strategy)
1318 .unwrap();
1319
1320 DiskIndexSearcher::<Data, DiskVertexProviderFactory<Data, _>>::new(
1321 params.max_thread_num,
1322 params.io_limit,
1323 &disk_index_reader,
1324 vertex_provider_factory,
1325 Metric::L2,
1326 Some(runtime),
1327 )
1328 .unwrap()
1329 }
1330
1331 fn load_query_result<StorageReader: StorageReadProvider>(
1332 storage_provider: &StorageReader,
1333 query_result_path: &str,
1334 ) -> Vec<u32> {
1335 let (result, _, _) =
1336 file_util::load_bin::<u32, StorageReader>(storage_provider, query_result_path, 0)
1337 .unwrap();
1338 result
1339 }
1340
1341 struct TestDiskSearchParams<'a, StorageType> {
1342 storage_provider: &'a StorageType,
1343 index_search_engine: &'a DiskIndexSearcher<
1344 GraphDataF32VectorUnitData,
1345 DiskVertexProviderFactory<
1346 GraphDataF32VectorUnitData,
1347 VirtualAlignedReaderFactory<OverlayFS>,
1348 >,
1349 >,
1350 thread_num: u64,
1351 query_file_path: &'a str,
1352 truth_result_file_path: &'a str,
1353 k: usize,
1354 l: usize,
1355 dim: usize,
1356 }
1357
1358 struct TestDiskSearchAssociateParams<'a, StorageType> {
1359 storage_provider: &'a StorageType,
1360 index_search_engine: &'a DiskIndexSearcher<
1361 GraphDataF32VectorU32Data,
1362 DiskVertexProviderFactory<
1363 GraphDataF32VectorU32Data,
1364 VirtualAlignedReaderFactory<OverlayFS>,
1365 >,
1366 >,
1367 thread_num: u64,
1368 query_file_path: &'a str,
1369 truth_result_file_path: &'a str,
1370 k: usize,
1371 l: usize,
1372 dim: usize,
1373 }
1374
1375 fn test_disk_search<StorageType: StorageReadProvider>(
1376 params: TestDiskSearchParams<StorageType>,
1377 ) {
1378 let query_vector = load_aligned_bin(params.storage_provider, params.query_file_path)
1379 .unwrap()
1380 .0;
1381 let mut aligned_query = AlignedBoxWithSlice::<f32>::new(query_vector.len(), 32).unwrap();
1382 aligned_query.memcpy(query_vector.as_slice()).unwrap();
1383
1384 let queries = aligned_query
1385 .split_into_nonoverlapping_mut_slices(0..aligned_query.len(), params.dim)
1386 .unwrap();
1387
1388 let truth_result =
1389 load_query_result(params.storage_provider, params.truth_result_file_path);
1390
1391 let pool = create_thread_pool(params.thread_num.into_usize()).unwrap();
1392 queries
1394 .par_iter()
1395 .enumerate()
1396 .for_each_in_pool(&pool, |(i, query)| {
1397 let mut aligned_box = AlignedBoxWithSlice::<f32>::new(query.len() + 1, 32).unwrap();
1399 let mut temp = Vec::with_capacity(query.len() + 1);
1400 temp.push(0.0);
1401 temp.extend_from_slice(query);
1402 aligned_box.memcpy(temp.as_slice()).unwrap();
1403 let query = &aligned_box.as_slice()[1..];
1404
1405 let mut query_stats = QueryStatistics::default();
1406 let mut indices = vec![0u32; 10];
1407 let mut distances = vec![0f32; 10];
1408 let mut associated_data = vec![(); 10];
1409
1410 let result = params
1411 .index_search_engine
1412 .search_internal(
1414 query,
1415 params.k,
1416 params.l as u32,
1417 None, &mut query_stats,
1419 &mut indices,
1420 &mut distances,
1421 &mut associated_data,
1422 &(|_| true),
1423 false,
1424 );
1425
1426 let truth_slice = &truth_result[i * params.k..(i + 1) * params.k];
1428
1429 assert!(result.is_ok(), "Expected search to succeed");
1430
1431 let result_unwrapped = result.unwrap();
1432 assert!(
1433 result_unwrapped.query_statistics.total_io_operations > 0,
1434 "Expected IO operations to be greater than 0"
1435 );
1436 assert!(
1437 result_unwrapped.query_statistics.total_vertices_loaded > 0,
1438 "Expected vertices loaded to be greater than 0"
1439 );
1440
1441 assert_eq!(
1443 indices, truth_slice,
1444 "Results DO NOT match with the truth result for query {}",
1445 i
1446 );
1447 });
1448 }
1449
1450 fn test_disk_search_with_associated<StorageType: StorageReadProvider>(
1451 params: TestDiskSearchAssociateParams<StorageType>,
1452 beam_width: Option<usize>,
1453 ) {
1454 let query_vector = load_aligned_bin(params.storage_provider, params.query_file_path)
1455 .unwrap()
1456 .0;
1457 let mut aligned_query = AlignedBoxWithSlice::<f32>::new(query_vector.len(), 32).unwrap();
1458 aligned_query.memcpy(query_vector.as_slice()).unwrap();
1459 let queries = aligned_query
1460 .split_into_nonoverlapping_mut_slices(0..aligned_query.len(), params.dim)
1461 .unwrap();
1462 let truth_result =
1463 load_query_result(params.storage_provider, params.truth_result_file_path);
1464 let pool = create_thread_pool(params.thread_num.into_usize()).unwrap();
1465 queries
1467 .par_iter()
1468 .enumerate()
1469 .for_each_in_pool(&pool, |(i, query)| {
1470 let mut aligned_box = AlignedBoxWithSlice::<f32>::new(query.len() + 1, 32).unwrap();
1472 let mut temp = Vec::with_capacity(query.len() + 1);
1473 temp.push(0.0);
1474 temp.extend_from_slice(query);
1475 aligned_box.memcpy(temp.as_slice()).unwrap();
1476 let query = &aligned_box.as_slice()[1..];
1477 let result = params
1478 .index_search_engine
1479 .search(query, params.k as u32, params.l as u32, beam_width, None, false)
1480 .unwrap();
1481 let indices: Vec<u32> = result.results.iter().map(|item| item.vertex_id).collect();
1482 let associated_data: Vec<u32> =
1483 result.results.iter().map(|item| item.data).collect();
1484 let truth_data = get_truth_associated_data(params.storage_provider);
1485 let associated_data_truth: Vec<u32> = indices
1486 .iter()
1487 .map(|&vid| truth_data[vid as usize])
1488 .collect();
1489 assert_eq!(
1490 associated_data, associated_data_truth,
1491 "Associated data DO NOT match with the truth result for query {}, associated_data from search: {:?}, associated_data from truth result: {:?}",
1492 i,associated_data, associated_data_truth
1493 );
1494 let truth_slice = &truth_result[i * params.k..(i + 1) * params.k];
1495 assert_eq!(
1496 indices, truth_slice,
1497 "Results DO NOT match with the truth result for query {}",
1498 i
1499 );
1500 });
1501 }
1502
1503 #[test]
1504 fn test_disk_search_invalid_input() {
1505 let storage_provider = Arc::new(VirtualStorageProvider::new_overlay(test_data_root()));
1506 let ctx = &DefaultContext;
1507
1508 let params = CreateDiskIndexSearcherParams {
1509 max_thread_num: 5,
1510 pq_pivot_file_path: TEST_PQ_PIVOT_128DIM,
1511 pq_compressed_file_path: TEST_PQ_COMPRESSED_128DIM,
1512 index_path: TEST_INDEX_128DIM,
1513 index_path_prefix: TEST_INDEX_PREFIX_128DIM,
1514 ..Default::default()
1515 };
1516
1517 let paths = PQPathNames::for_disk_index(TEST_INDEX_PREFIX_128DIM);
1518 assert_eq!(
1519 paths.pivots, params.pq_pivot_file_path,
1520 "pq_pivot_file_path is not correct"
1521 );
1522 assert_eq!(
1523 paths.compressed_data, params.pq_compressed_file_path,
1524 "pq_compressed_file_path is not correct"
1525 );
1526 assert_eq!(
1527 params.index_path,
1528 format!("{}_disk.index", params.index_path_prefix),
1529 "index_path is not correct"
1530 );
1531
1532 let res = SearchParams::new_default(0, 10);
1533 assert!(res.is_err());
1534 assert_eq!(
1535 <SearchParamsError as std::convert::Into<ANNError>>::into(res.unwrap_err()).kind(),
1536 ANNErrorKind::IndexError
1537 );
1538 let res = SearchParams::new_default(20, 10);
1539 assert!(res.is_err());
1540 let res = SearchParams::new_default(10, 0);
1541 assert!(res.is_err());
1542 let res = SearchParams::new(10, 10, Some(0));
1543 assert!(res.is_err());
1544
1545 let search_engine =
1546 create_disk_index_searcher::<GraphDataF32VectorU32Data>(params, &storage_provider);
1547
1548 assert_eq!(
1550 search_engine
1551 .index
1552 .data_provider
1553 .to_external_id(ctx, 0)
1554 .unwrap(),
1555 0
1556 );
1557 assert_eq!(
1558 search_engine
1559 .index
1560 .data_provider
1561 .to_internal_id(ctx, &0)
1562 .unwrap(),
1563 0
1564 );
1565
1566 let provider_max_degree = search_engine
1567 .index
1568 .data_provider
1569 .graph_header
1570 .max_degree::<<GraphDataF32VectorU32Data as GraphDataType>::VectorDataType>()
1571 .unwrap();
1572 let index_max_degree = search_engine.index.config.pruned_degree().get();
1573 assert_eq!(provider_max_degree, index_max_degree);
1574
1575 let query = vec![0f32; 128];
1576 let mut query_stats = QueryStatistics::default();
1577 let mut indices = vec![0u32; 10];
1578 let mut distances = vec![0f32; 10];
1579 let mut associated_data = vec![0u32; 10];
1580
1581 let result = search_engine.search_internal(
1583 &query,
1584 10,
1585 10 - 1,
1586 None,
1587 &mut query_stats,
1588 &mut indices,
1589 &mut distances,
1590 &mut associated_data,
1591 &|_| true,
1592 false,
1593 );
1594
1595 assert!(result.is_err());
1596 assert_eq!(result.unwrap_err().kind(), ANNErrorKind::IndexError);
1597 }
1598
1599 #[test]
1600 fn test_disk_search_beam_search() {
1601 let storage_provider = Arc::new(VirtualStorageProvider::new_overlay(test_data_root()));
1602
1603 let search_engine = create_disk_index_searcher::<GraphDataF32VectorUnitData>(
1604 CreateDiskIndexSearcherParams {
1605 max_thread_num: 1,
1606 pq_pivot_file_path: TEST_PQ_PIVOT,
1607 pq_compressed_file_path: TEST_PQ_COMPRESSED,
1608 index_path: TEST_INDEX,
1609 index_path_prefix: TEST_INDEX_PREFIX,
1610 ..Default::default()
1611 },
1612 &storage_provider,
1613 );
1614
1615 let query_vector: [f32; 128] = [1f32; 128];
1616 let mut indices = vec![0u32; 10];
1617 let mut distances = vec![0f32; 10];
1618 let mut associated_data = vec![(); 10];
1619
1620 let mut result_output_buffer = search_output_buffer::IdDistanceAssociatedData::new(
1621 &mut indices,
1622 &mut distances,
1623 &mut associated_data,
1624 );
1625 let strategy = search_engine.search_strategy(&query_vector, &|_| true);
1626 let mut search_record = VisitedSearchRecord::new(0);
1627 search_engine
1628 .runtime
1629 .block_on(search_engine.index.search_recorded(
1630 &strategy,
1631 &DefaultContext,
1632 &query_vector,
1633 &SearchParams::new(10, 10, Some(4)).unwrap(),
1634 &mut result_output_buffer,
1635 &mut search_record,
1636 ))
1637 .unwrap();
1638
1639 let ids = search_record
1640 .visited
1641 .iter()
1642 .map(|n| n.id)
1643 .collect::<Vec<_>>();
1644
1645 const EXPECTED_NODES: [u32; 18] = [
1646 72, 118, 108, 86, 84, 152, 170, 82, 114, 87, 207, 176, 79, 153, 67, 165, 141, 180,
1647 ]; assert_eq!(ids, &EXPECTED_NODES);
1650
1651 let return_list_size = 10;
1652 let search_list_size = 10;
1653 let result = search_engine.search(
1654 &query_vector,
1655 return_list_size,
1656 search_list_size,
1657 Some(4),
1658 None,
1659 false,
1660 );
1661 assert!(result.is_ok(), "Expected search to succeed");
1662 let search_result = result.unwrap();
1663 assert_eq!(
1664 search_result.results.len() as u32,
1665 return_list_size,
1666 "Expected result count to match"
1667 );
1668 assert_eq!(
1669 indices,
1670 vec![152, 72, 170, 118, 87, 165, 79, 141, 108, 86],
1671 "Expected indices to match"
1672 );
1673 }
1674
1675 #[cfg(feature = "experimental_diversity_search")]
1676 #[test]
1677 fn test_disk_search_diversity_search() {
1678 use diskann::graph::DiverseSearchParams;
1679 use diskann::neighbor::AttributeValueProvider;
1680 use std::collections::HashMap;
1681
1682 #[derive(Debug, Clone)]
1684 struct TestAttributeProvider {
1685 attributes: HashMap<u32, u32>,
1686 }
1687 impl TestAttributeProvider {
1688 fn new() -> Self {
1689 Self {
1690 attributes: HashMap::new(),
1691 }
1692 }
1693 fn insert(&mut self, id: u32, attribute: u32) {
1694 self.attributes.insert(id, attribute);
1695 }
1696 }
1697 impl diskann::provider::HasId for TestAttributeProvider {
1698 type Id = u32;
1699 }
1700
1701 impl AttributeValueProvider for TestAttributeProvider {
1702 type Value = u32;
1703
1704 fn get(&self, id: Self::Id) -> Option<Self::Value> {
1705 self.attributes.get(&id).copied()
1706 }
1707 }
1708
1709 let storage_provider = Arc::new(VirtualStorageProvider::new_overlay(test_data_root()));
1710
1711 let search_engine = create_disk_index_searcher::<GraphDataF32VectorUnitData>(
1712 CreateDiskIndexSearcherParams {
1713 max_thread_num: 1,
1714 pq_pivot_file_path: TEST_PQ_PIVOT,
1715 pq_compressed_file_path: TEST_PQ_COMPRESSED,
1716 index_path: TEST_INDEX,
1717 index_path_prefix: TEST_INDEX_PREFIX,
1718 ..Default::default()
1719 },
1720 &storage_provider,
1721 );
1722
1723 let query_vector: [f32; 128] = [1f32; 128];
1724
1725 let mut attribute_provider = TestAttributeProvider::new();
1727 let num_vectors = 256; for i in 0..num_vectors {
1729 let label = (i % 15) + 1;
1731 attribute_provider.insert(i, label);
1732 }
1733 let attribute_provider = std::sync::Arc::new(attribute_provider);
1735
1736 let mut indices = vec![0u32; 10];
1737 let mut distances = vec![0f32; 10];
1738 let mut associated_data = vec![(); 10];
1739
1740 let mut result_output_buffer = search_output_buffer::IdDistanceAssociatedData::new(
1741 &mut indices,
1742 &mut distances,
1743 &mut associated_data,
1744 );
1745 let strategy = search_engine.search_strategy(&query_vector, &|_| true);
1746 let mut search_record = VisitedSearchRecord::new(0);
1747
1748 let diverse_params = DiverseSearchParams::new(
1750 0, 3, attribute_provider.clone(),
1753 );
1754
1755 let search_params = SearchParams::new(10, 20, None).unwrap();
1756
1757 search_engine
1758 .runtime
1759 .block_on(search_engine.index.diverse_search_experimental(
1760 &strategy,
1761 &DefaultContext,
1762 &query_vector,
1763 &search_params,
1764 &diverse_params,
1765 &mut result_output_buffer,
1766 &mut search_record,
1767 ))
1768 .unwrap();
1769
1770 let ids = search_record
1771 .visited
1772 .iter()
1773 .map(|n| n.id)
1774 .collect::<Vec<_>>();
1775
1776 assert!(
1778 !ids.is_empty(),
1779 "Expected to visit some nodes during diversity search"
1780 );
1781
1782 let return_list_size = 10;
1783 let search_list_size = 20;
1784 let diverse_results_k = 1;
1785 let diverse_params = DiverseSearchParams::new(
1786 0, diverse_results_k,
1788 attribute_provider.clone(),
1789 );
1790
1791 let mut indices2 = vec![0u32; return_list_size as usize];
1793 let mut distances2 = vec![0f32; return_list_size as usize];
1794 let mut associated_data2 = vec![(); return_list_size as usize];
1795 let mut result_output_buffer2 = search_output_buffer::IdDistanceAssociatedData::new(
1796 &mut indices2,
1797 &mut distances2,
1798 &mut associated_data2,
1799 );
1800 let strategy2 = search_engine.search_strategy(&query_vector, &|_| true);
1801 let mut search_record2 = VisitedSearchRecord::new(0);
1802 let search_params2 =
1803 SearchParams::new(return_list_size as usize, search_list_size as usize, None).unwrap();
1804
1805 let stats = search_engine
1806 .runtime
1807 .block_on(search_engine.index.diverse_search_experimental(
1808 &strategy2,
1809 &DefaultContext,
1810 &query_vector,
1811 &search_params2,
1812 &diverse_params,
1813 &mut result_output_buffer2,
1814 &mut search_record2,
1815 ))
1816 .unwrap();
1817
1818 assert!(
1820 stats.result_count > 0,
1821 "Expected diversity search to return results"
1822 );
1823 assert!(
1824 stats.result_count <= return_list_size,
1825 "Expected result count to be <= {}",
1826 return_list_size
1827 );
1828
1829 assert!(
1831 stats.result_count > 0,
1832 "Expected to get some search results"
1833 );
1834
1835 println!("\n=== Diversity Search Results ===");
1837 println!("Query: [1f32; 128]");
1838 println!("diverse_results_k: {}", diverse_results_k);
1839 println!("Total results: {}\n", stats.result_count);
1840 println!("{:<10} {:<15} {:<10}", "Vertex ID", "Distance", "Label");
1841 println!("{}", "-".repeat(35));
1842 for i in 0..stats.result_count as usize {
1843 let attribute_value = attribute_provider.get(indices2[i]).unwrap_or(0);
1844 println!(
1845 "{:<10} {:<15.2} {:<10}",
1846 indices2[i], distances2[i], attribute_value
1847 );
1848 }
1849
1850 for i in 0..(stats.result_count as usize).saturating_sub(1) {
1852 assert!(distances2[i] >= 0.0, "Expected non-negative distance");
1853 assert!(
1854 distances2[i] <= distances2[i + 1],
1855 "Expected distances to be sorted in ascending order"
1856 );
1857 }
1858
1859 let mut attribute_counts = HashMap::new();
1861 for item in indices2.iter().take(stats.result_count as usize) {
1862 if let Some(attribute_value) = attribute_provider.get(*item) {
1863 *attribute_counts.entry(attribute_value).or_insert(0) += 1;
1864 }
1865 }
1866
1867 println!("\n=== Attribute Distribution ===");
1869 let mut sorted_attrs: Vec<_> = attribute_counts.iter().collect();
1870 sorted_attrs.sort_by_key(|(k, _)| *k);
1871 for (attribute_value, count) in &sorted_attrs {
1872 println!(
1873 "Label {}: {} occurrences (max allowed: {})",
1874 attribute_value, count, diverse_results_k
1875 );
1876 }
1877 println!("Total unique labels: {}", attribute_counts.len());
1878 println!("================================\n");
1879
1880 for (attribute_value, count) in &attribute_counts {
1882 println!(
1883 "Assert: Label {} has {} occurrences (max: {})",
1884 attribute_value, count, diverse_results_k
1885 );
1886 assert!(
1887 *count <= diverse_results_k,
1888 "Attribute value {} appears {} times, which exceeds diverse_results_k of {}",
1889 attribute_value,
1890 count,
1891 diverse_results_k
1892 );
1893 }
1894
1895 println!(
1898 "Assert: Found {} unique labels (expected at least 2)",
1899 attribute_counts.len()
1900 );
1901 assert!(
1902 attribute_counts.len() >= 2,
1903 "Expected at least 2 different attribute values for diversity, got {}",
1904 attribute_counts.len()
1905 );
1906 }
1907
1908 #[rstest]
1909 #[case(
1911 |_id: &u32| true,
1912 false,
1913 10,
1914 vec![152, 118, 72, 170, 87, 141, 79, 207, 124, 86],
1915 vec![256101.7, 256675.3, 256709.69, 256712.5, 256760.08, 256958.5, 257006.1, 257025.7, 257105.67, 257107.67],
1916 )]
1917 #[case(
1920 |id: &u32| *id == 0 || *id == 1,
1921 false,
1922 0,
1923 vec![0; 10],
1924 vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
1925 )]
1926 #[case(
1929 |id: &u32| *id == 0 || *id == 1,
1930 true,
1931 2,
1932 vec![1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
1933 vec![257247.28, 258179.28, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
1934 )]
1935 #[case(
1938 |id: &u32| *id == 72 || *id == 87 || *id == 170,
1939 false,
1940 3,
1941 vec![72, 170, 87, 0, 0, 0, 0, 0, 0, 0],
1942 vec![256709.69, 256712.5, 256760.08, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
1943 )]
1944 #[case(
1947 |id: &u32| *id == 72 || *id == 87 || *id == 170,
1948 true,
1949 3,
1950 vec![72, 170, 87, 0, 0, 0, 0, 0, 0, 0],
1951 vec![256709.69, 256712.5, 256760.08, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
1952 )]
1953 fn test_search_with_vector_filter(
1954 #[case] vector_filter: fn(&u32) -> bool,
1955 #[case] is_flat_search: bool,
1956 #[case] expected_result_count: u32,
1957 #[case] expected_indices: Vec<u32>,
1958 #[case] expected_distances: Vec<f32>,
1959 ) {
1960 let check_distances = |got: &[f32], expected: &[f32]| -> bool {
1965 const ABS_TOLERANCE: f32 = 0.02;
1966 assert_eq!(got.len(), expected.len());
1967 for (i, (g, e)) in std::iter::zip(got.iter(), expected.iter()).enumerate() {
1968 if (g - e).abs() > ABS_TOLERANCE {
1969 panic!(
1970 "distances differ at position {} by more than {}\n\n\
1971 got: {:?}\nexpected: {:?}",
1972 i, ABS_TOLERANCE, got, expected,
1973 );
1974 }
1975 }
1976 true
1977 };
1978
1979 let storage_provider = Arc::new(VirtualStorageProvider::new_overlay(test_data_root()));
1980
1981 let search_engine = create_disk_index_searcher::<GraphDataF32VectorUnitData>(
1982 CreateDiskIndexSearcherParams {
1983 max_thread_num: 5,
1984 pq_pivot_file_path: TEST_PQ_PIVOT_128DIM,
1985 pq_compressed_file_path: TEST_PQ_COMPRESSED_128DIM,
1986 index_path: TEST_INDEX_128DIM,
1987 index_path_prefix: TEST_INDEX_PREFIX_128DIM,
1988 ..Default::default()
1989 },
1990 &storage_provider,
1991 );
1992 let query = vec![0.1f32; 128];
1993 let mut query_stats = QueryStatistics::default();
1994 let mut indices = vec![0u32; 10];
1995 let mut distances = vec![0f32; 10];
1996 let mut associated_data = vec![(); 10];
1997
1998 let result = search_engine.search_internal(
1999 &query,
2000 10,
2001 10,
2002 None, &mut query_stats,
2004 &mut indices,
2005 &mut distances,
2006 &mut associated_data,
2007 &vector_filter,
2008 is_flat_search,
2009 );
2010
2011 assert!(result.is_ok(), "Expected search to succeed");
2012 assert_eq!(
2013 result.unwrap().result_count,
2014 expected_result_count,
2015 "Expected result count to match"
2016 );
2017 assert_eq!(indices, expected_indices, "Expected indices to match");
2018 assert!(
2019 check_distances(&distances, &expected_distances),
2020 "Expected distances to match"
2021 );
2022
2023 let result_with_filter = search_engine.search(
2024 &query,
2025 10,
2026 10,
2027 None, Some(Box::new(vector_filter)),
2029 is_flat_search,
2030 );
2031
2032 assert!(result_with_filter.is_ok(), "Expected search to succeed");
2033 let result_with_filter_unwrapped = result_with_filter.unwrap();
2034 assert_eq!(
2035 result_with_filter_unwrapped.stats.result_count, expected_result_count,
2036 "Expected result count to match"
2037 );
2038 let actual_indices = result_with_filter_unwrapped
2039 .results
2040 .iter()
2041 .map(|x| x.vertex_id)
2042 .collect::<Vec<_>>();
2043 assert_eq!(
2044 actual_indices, expected_indices,
2045 "Expected indices to match"
2046 );
2047 let actual_distances = result_with_filter_unwrapped
2048 .results
2049 .iter()
2050 .map(|x| x.distance)
2051 .collect::<Vec<_>>();
2052 assert!(
2053 check_distances(&actual_distances, &expected_distances),
2054 "Expected distances to match"
2055 );
2056 }
2057
2058 #[test]
2059 fn test_beam_search_respects_io_limit() {
2060 let io_limit = 11; let storage_provider = Arc::new(VirtualStorageProvider::new_overlay(test_data_root()));
2062
2063 let search_engine = create_disk_index_searcher::<GraphDataF32VectorUnitData>(
2064 CreateDiskIndexSearcherParams {
2065 max_thread_num: 1,
2066 pq_pivot_file_path: TEST_PQ_PIVOT,
2067 pq_compressed_file_path: TEST_PQ_COMPRESSED,
2068 index_path: TEST_INDEX,
2069 index_path_prefix: TEST_INDEX_PREFIX,
2070 io_limit,
2071 },
2072 &storage_provider,
2073 );
2074 let query_vector: [f32; 128] = [1f32; 128];
2075
2076 let mut indices = vec![0u32; 10];
2077 let mut distances = vec![0f32; 10];
2078 let mut associated_data = vec![(); 10];
2079
2080 let mut result_output_buffer = search_output_buffer::IdDistanceAssociatedData::new(
2081 &mut indices,
2082 &mut distances,
2083 &mut associated_data,
2084 );
2085
2086 let strategy = search_engine.search_strategy(&query_vector, &|_| true);
2087
2088 let mut search_record = VisitedSearchRecord::new(0);
2089 search_engine
2090 .runtime
2091 .block_on(search_engine.index.search_recorded(
2092 &strategy,
2093 &DefaultContext,
2094 &query_vector,
2095 &SearchParams::new(10, 10, Some(4)).unwrap(),
2096 &mut result_output_buffer,
2097 &mut search_record,
2098 ))
2099 .unwrap();
2100 let visited_ids = search_record
2101 .visited
2102 .iter()
2103 .map(|n| n.id)
2104 .collect::<Vec<_>>();
2105
2106 let query_stats = strategy.io_tracker;
2107 assert!(
2109 query_stats.io_count() <= io_limit,
2110 "Expected IO operations to be <= {}, but got {}",
2111 io_limit,
2112 query_stats.io_count()
2113 );
2114
2115 const EXPECTED_NODES: [u32; 17] = [
2116 72, 118, 108, 86, 84, 152, 170, 82, 114, 87, 207, 176, 79, 153, 67, 165, 141,
2117 ]; let mut matching_count = 0;
2121 for expected_node in EXPECTED_NODES.iter() {
2122 if visited_ids.contains(expected_node) {
2123 matching_count += 1;
2124 }
2125 }
2126
2127 let recall = (matching_count as f32 / EXPECTED_NODES.len() as f32) * 100.0;
2129
2130 assert!(recall >= 60.0, "Match percentage is below 60%: {}", recall);
2133 }
2134}