1use std::{
7 collections::HashMap,
8 future::Future,
9 num::NonZeroUsize,
10 ops::Range,
11 sync::{
12 atomic::{AtomicU64, AtomicUsize},
13 Arc,
14 },
15 time::Instant,
16};
17
18use diskann::{
19 graph::{
20 self,
21 glue::{self, ExpandBeam, IdIterator, SearchExt, SearchPostProcess, SearchStrategy},
22 search::Knn,
23 search_output_buffer, AdjacencyList, DiskANNIndex, SearchOutputBuffer,
24 },
25 neighbor::Neighbor,
26 provider::{
27 Accessor, BuildQueryComputer, DataProvider, DefaultContext, DelegateNeighbor, HasId,
28 NeighborAccessor,
29 },
30 utils::{
31 object_pool::{ObjectPool, PoolOption, TryAsPooled},
32 IntoUsize, VectorRepr,
33 },
34 ANNError, ANNResult,
35};
36use diskann_providers::storage::StorageReadProvider;
37use diskann_providers::{
38 model::{
39 compute_pq_distance, compute_pq_distance_for_pq_coordinates, graph::traits::GraphDataType,
40 pq::quantizer_preprocess, PQData, PQScratch,
41 },
42 storage::{get_compressed_pq_file, get_disk_index_file, get_pq_pivot_file, LoadWith},
43};
44use diskann_vector::{distance::Metric, DistanceFunction, PreprocessedDistanceFunction};
45use futures_util::future;
46use tokio::runtime::Runtime;
47use tracing::debug;
48
49use crate::{
50 data_model::{CachingStrategy, GraphHeader},
51 filter_parameter::{default_vector_filter, VectorFilter},
52 search::{
53 provider::disk_vertex_provider_factory::DiskVertexProviderFactory,
54 traits::{VertexProvider, VertexProviderFactory},
55 },
56 storage::{api::AsyncDiskLoadContext, disk_index_reader::DiskIndexReader},
57 utils::AlignedFileReaderFactory,
58 utils::QueryStatistics,
59};
60
61pub struct DiskProvider<Data>
71where
72 Data: GraphDataType<VectorIdType = u32>,
73{
74 graph_header: GraphHeader,
76
77 distance_comparer: <Data::VectorDataType as VectorRepr>::Distance,
79
80 pq_data: Arc<PQData>,
82
83 num_points: usize,
85
86 metric: Metric,
88
89 search_io_limit: usize,
91}
92
93impl<Data> DataProvider for DiskProvider<Data>
94where
95 Data: GraphDataType<VectorIdType = u32>,
96{
97 type Context = DefaultContext;
98
99 type InternalId = u32;
100
101 type ExternalId = u32;
102
103 type Error = ANNError;
104
105 fn to_internal_id(
107 &self,
108 _context: &DefaultContext,
109 gid: &Self::ExternalId,
110 ) -> Result<Self::InternalId, Self::Error> {
111 Ok(*gid)
112 }
113
114 fn to_external_id(
116 &self,
117 _context: &DefaultContext,
118 id: Self::InternalId,
119 ) -> Result<Self::ExternalId, Self::Error> {
120 Ok(id)
121 }
122}
123
124impl<Data> LoadWith<AsyncDiskLoadContext> for DiskProvider<Data>
125where
126 Data: GraphDataType<VectorIdType = u32>,
127{
128 type Error = ANNError;
129
130 async fn load_with<P>(provider: &P, ctx: &AsyncDiskLoadContext) -> ANNResult<Self>
131 where
132 P: StorageReadProvider,
133 {
134 debug!(
135 "DiskProvider::load_with() called with file: {:?}",
136 get_disk_index_file(ctx.quant_load_context.metadata.prefix())
137 );
138
139 let graph_header = {
140 let aligned_reader_factory = AlignedFileReaderFactory::new(get_disk_index_file(
141 ctx.quant_load_context.metadata.prefix(),
142 ));
143
144 let caching_strategy = if ctx.num_nodes_to_cache > 0 {
145 CachingStrategy::StaticCacheWithBfsNodes(ctx.num_nodes_to_cache)
146 } else {
147 CachingStrategy::None
148 };
149
150 let vertex_provider_factory = DiskVertexProviderFactory::<Data, _>::new(
151 aligned_reader_factory,
152 caching_strategy,
153 )?;
154 VertexProviderFactory::get_header(&vertex_provider_factory)?
155 };
156
157 let metric = ctx.quant_load_context.metric;
158 let num_points = ctx.num_points;
159
160 let index_path_prefix = ctx.quant_load_context.metadata.prefix();
161 let index_reader = DiskIndexReader::<<Data as GraphDataType>::VectorDataType>::new(
162 get_pq_pivot_file(index_path_prefix),
163 get_compressed_pq_file(index_path_prefix),
164 provider,
165 )?;
166
167 Self::new(
168 &index_reader,
169 graph_header,
170 metric,
171 num_points,
172 ctx.search_io_limit,
173 )
174 }
175}
176
177impl<Data> DiskProvider<Data>
178where
179 Data: GraphDataType<VectorIdType = u32>,
180{
181 fn new(
182 disk_index_reader: &DiskIndexReader<Data::VectorDataType>,
183 graph_header: GraphHeader,
184 metric: Metric,
185 num_points: usize,
186 search_io_limit: usize,
187 ) -> ANNResult<Self> {
188 let distance_comparer =
189 Data::VectorDataType::distance(metric, Some(graph_header.metadata().dims));
190
191 let pq_data = disk_index_reader.get_pq_data();
192
193 Ok(Self {
194 graph_header,
195 distance_comparer,
196 pq_data,
197 num_points,
198 metric,
199 search_io_limit,
200 })
201 }
202}
203
204pub struct DiskSearchStrategy<'a, Data, ProviderFactory>
215where
216 Data: GraphDataType<VectorIdType = u32>,
217 ProviderFactory: VertexProviderFactory<Data>,
218{
219 io_tracker: IOTracker,
221 vector_filter: &'a (dyn Fn(&u32) -> bool + Send + Sync), query: &'a [Data::VectorDataType],
223
224 vertex_provider_factory: &'a ProviderFactory,
226
227 scratch_pool: &'a Arc<ObjectPool<DiskSearchScratch<Data, ProviderFactory::VertexProviderType>>>,
229}
230
/// Counters for a single search: time spent on I/O, time spent on PQ preprocessing and the
/// number of I/O operations.
///
/// All counters are atomics updated with relaxed ordering, so they can be bumped through a
/// shared reference. `Default` yields a tracker with every counter at zero.
#[derive(Default)]
struct IOTracker {
    /// Accumulated I/O time in microseconds.
    io_time_us: AtomicU64,
    /// Accumulated PQ preprocessing time in microseconds.
    preprocess_time_us: AtomicU64,
    /// Number of I/O operations recorded so far.
    io_count: AtomicUsize,
}

impl IOTracker {
    /// Adds `time` to the given time counter.
    fn add_time(category: &AtomicU64, time: u64) {
        category.fetch_add(time, std::sync::atomic::Ordering::Relaxed);
    }

    /// Reads the current value of the given time counter.
    fn time(category: &AtomicU64) -> u64 {
        category.load(std::sync::atomic::Ordering::Relaxed)
    }

    /// Adds `count` to the I/O operation counter.
    fn add_io_count(&self, count: usize) {
        self.io_count
            .fetch_add(count, std::sync::atomic::Ordering::Relaxed);
    }

    /// Returns the number of I/O operations recorded so far.
    fn io_count(&self) -> usize {
        self.io_count.load(std::sync::atomic::Ordering::Relaxed)
    }
}
267
/// Post-processor that drops candidates rejected by a filter and re-ranks the rest.
///
/// Holds only a borrowed predicate, so it is cheap to copy.
#[derive(Clone, Copy)]
pub struct RerankAndFilter<'a> {
    /// Predicate over vertex ids; `true` keeps the candidate.
    filter: &'a (dyn Fn(&u32) -> bool + Send + Sync),
}

impl<'a> RerankAndFilter<'a> {
    /// Wraps `filter` without taking ownership of it.
    fn new(filter: &'a (dyn Fn(&u32) -> bool + Send + Sync)) -> Self {
        RerankAndFilter { filter }
    }
}
278
279impl<Data, VP>
280 SearchPostProcess<
281 DiskAccessor<'_, Data, VP>,
282 [Data::VectorDataType],
283 (
284 <DiskProvider<Data> as DataProvider>::InternalId,
285 Data::AssociatedDataType,
286 ),
287 > for RerankAndFilter<'_>
288where
289 Data: GraphDataType<VectorIdType = u32>,
290 VP: VertexProvider<Data>,
291{
292 type Error = ANNError;
293 async fn post_process<I, B>(
294 &self,
295 accessor: &mut DiskAccessor<'_, Data, VP>,
296 query: &[Data::VectorDataType],
297 _computer: &DiskQueryComputer,
298 candidates: I,
299 output: &mut B,
300 ) -> Result<usize, Self::Error>
301 where
302 I: Iterator<Item = Neighbor<u32>> + Send,
303 B: SearchOutputBuffer<(u32, Data::AssociatedDataType)> + Send + ?Sized,
304 {
305 let provider = accessor.provider;
306
307 let mut uncached_ids = Vec::new();
308 let mut reranked = candidates
309 .map(|n| n.id)
310 .filter(|id| (self.filter)(id))
311 .filter_map(|n| {
312 if let Some(entry) = accessor.scratch.distance_cache.get(&n) {
313 Some(Ok::<((u32, _), f32), ANNError>(((n, entry.1), entry.0)))
314 } else {
315 uncached_ids.push(n);
316 None
317 }
318 })
319 .collect::<Result<Vec<_>, _>>()?;
320 if !uncached_ids.is_empty() {
321 ensure_vertex_loaded(&mut accessor.scratch.vertex_provider, &uncached_ids)?;
322 for n in &uncached_ids {
323 let v = accessor.scratch.vertex_provider.get_vector(n)?;
324 let d = provider.distance_comparer.evaluate_similarity(query, v);
325 let a = accessor.scratch.vertex_provider.get_associated_data(n)?;
326 reranked.push(((*n, *a), d));
327 }
328 }
329
330 reranked
332 .sort_unstable_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
333 Ok(output.extend(reranked))
335 }
336}
337
338impl<'this, Data, ProviderFactory>
339 SearchStrategy<
340 DiskProvider<Data>,
341 [Data::VectorDataType],
342 (
343 <DiskProvider<Data> as DataProvider>::InternalId,
344 Data::AssociatedDataType,
345 ),
346 > for DiskSearchStrategy<'this, Data, ProviderFactory>
347where
348 Data: GraphDataType<VectorIdType = u32>,
349 ProviderFactory: VertexProviderFactory<Data>,
350{
351 type QueryComputer = DiskQueryComputer;
352 type SearchAccessor<'a> = DiskAccessor<'a, Data, ProviderFactory::VertexProviderType>;
353 type SearchAccessorError = ANNError;
354 type PostProcessor = RerankAndFilter<'this>;
355
356 fn search_accessor<'a>(
357 &'a self,
358 provider: &'a DiskProvider<Data>,
359 _context: &DefaultContext,
360 ) -> Result<Self::SearchAccessor<'a>, Self::SearchAccessorError> {
361 DiskAccessor::new(
362 provider,
363 &self.io_tracker,
364 self.query,
365 self.vertex_provider_factory,
366 self.scratch_pool,
367 )
368 }
369
370 fn post_processor(&self) -> Self::PostProcessor {
371 RerankAndFilter::new(self.vector_filter)
372 }
373}
374
/// Query-side state needed to score PQ-coded vectors.
pub struct DiskQueryComputer {
    /// Number of PQ chunks per compressed vector.
    num_pq_chunks: usize,
    /// Copy of the PQ table distance scratch that was prepared for the query.
    query_centroid_l2_distance: Vec<f32>,
}
380
381impl PreprocessedDistanceFunction<&[u8], f32> for DiskQueryComputer {
382 fn evaluate_similarity(&self, changing: &[u8]) -> f32 {
383 let mut dist = 0.0f32;
384 #[allow(clippy::expect_used)]
385 compute_pq_distance_for_pq_coordinates(
386 changing,
387 self.num_pq_chunks,
388 &self.query_centroid_l2_distance,
389 std::slice::from_mut(&mut dist),
390 )
391 .expect("PQ distance compute for PQ coordinates is expected to succeed");
392 dist
393 }
394}
395
396impl<Data, VP> BuildQueryComputer<[Data::VectorDataType]> for DiskAccessor<'_, Data, VP>
397where
398 Data: GraphDataType<VectorIdType = u32>,
399 VP: VertexProvider<Data>,
400{
401 type QueryComputerError = ANNError;
402 type QueryComputer = DiskQueryComputer;
403
404 fn build_query_computer(
405 &self,
406 _from: &[Data::VectorDataType],
407 ) -> Result<Self::QueryComputer, Self::QueryComputerError> {
408 Ok(DiskQueryComputer {
409 num_pq_chunks: self.provider.pq_data.get_num_chunks(),
410 query_centroid_l2_distance: self
411 .scratch
412 .pq_scratch
413 .aligned_pqtable_dist_scratch
414 .as_slice()
415 .to_vec(),
416 })
417 }
418
419 async fn distances_unordered<Itr, F>(
420 &mut self,
421 vec_id_itr: Itr,
422 _computer: &Self::QueryComputer,
423 f: F,
424 ) -> Result<(), Self::GetError>
425 where
426 F: Send + FnMut(f32, Self::Id),
427 Itr: Iterator<Item = Self::Id>,
428 {
429 self.pq_distances(&vec_id_itr.collect::<Box<[_]>>(), f)
430 }
431}
432
433impl<Data, VP> ExpandBeam<[Data::VectorDataType]> for DiskAccessor<'_, Data, VP>
434where
435 Data: GraphDataType<VectorIdType = u32>,
436 VP: VertexProvider<Data>,
437{
438 fn expand_beam<Itr, P, F>(
439 &mut self,
440 ids: Itr,
441 _computer: &Self::QueryComputer,
442 mut pred: P,
443 mut f: F,
444 ) -> impl std::future::Future<Output = Result<(), Self::GetError>> + Send
445 where
446 Itr: Iterator<Item = Self::Id> + Send,
447 P: glue::HybridPredicate<Self::Id> + Send + Sync,
448 F: FnMut(f32, Self::Id) + Send,
449 {
450 let result = (|| {
451 let io_limit = self.provider.search_io_limit - self.io_tracker.io_count();
452 let load_ids: Box<[_]> = ids.take(io_limit).collect();
453
454 self.ensure_loaded(&load_ids)?;
455 let mut ids = Vec::new();
456 for i in load_ids {
457 ids.clear();
458 ids.extend(
459 self.scratch
460 .vertex_provider
461 .get_adjacency_list(&i)?
462 .iter()
463 .copied()
464 .filter(|id| pred.eval_mut(id)),
465 );
466
467 self.pq_distances(&ids, &mut f)?;
468 }
469
470 Ok(())
471 })();
472
473 std::future::ready(result)
474 }
475}
476
477struct DiskSearchScratch<Data, VP>
480where
481 Data: GraphDataType<VectorIdType = u32>,
482 VP: VertexProvider<Data>,
483{
484 distance_cache: HashMap<u32, (f32, Data::AssociatedDataType)>,
485 pq_scratch: PQScratch,
486 vertex_provider: VP,
487}
488
489#[derive(Clone)]
490struct DiskSearchScratchArgs<'a, ProviderFactory> {
491 graph_degree: usize,
492 dim: usize,
493 num_pq_chunks: usize,
494 num_pq_centers: usize,
495 vertex_factory: &'a ProviderFactory,
496 graph_header: &'a GraphHeader,
497}
498
499impl<Data, ProviderFactory> TryAsPooled<&DiskSearchScratchArgs<'_, ProviderFactory>>
500 for DiskSearchScratch<Data, ProviderFactory::VertexProviderType>
501where
502 Data: GraphDataType<VectorIdType = u32>,
503 ProviderFactory: VertexProviderFactory<Data>,
504{
505 type Error = ANNError;
506
507 fn try_create(args: &DiskSearchScratchArgs<ProviderFactory>) -> Result<Self, Self::Error> {
508 let pq_scratch = PQScratch::new(
509 args.graph_degree,
510 args.dim,
511 args.num_pq_chunks,
512 args.num_pq_centers,
513 )?;
514
515 const DEFAULT_BEAM_WIDTH: usize = 0; let vertex_provider = args
517 .vertex_factory
518 .create_vertex_provider(DEFAULT_BEAM_WIDTH, args.graph_header)?;
519
520 Ok(Self {
521 distance_cache: HashMap::new(),
522 pq_scratch,
523 vertex_provider,
524 })
525 }
526
527 fn try_modify(
528 &mut self,
529 _args: &DiskSearchScratchArgs<ProviderFactory>,
530 ) -> Result<(), Self::Error> {
531 self.distance_cache.clear();
532 self.vertex_provider.clear();
533 Ok(())
534 }
535}
536
537pub struct DiskAccessor<'a, Data, VP>
538where
539 Data: GraphDataType<VectorIdType = u32>,
540 VP: VertexProvider<Data>,
541{
542 provider: &'a DiskProvider<Data>,
543 io_tracker: &'a IOTracker,
544 scratch: PoolOption<DiskSearchScratch<Data, VP>>,
545 query: &'a [Data::VectorDataType],
546}
547
548impl<Data, VP> DiskAccessor<'_, Data, VP>
549where
550 Data: GraphDataType<VectorIdType = u32>,
551 VP: VertexProvider<Data>,
552{
553 fn pq_distances<F>(&mut self, ids: &[u32], mut f: F) -> ANNResult<()>
556 where
557 F: FnMut(f32, u32),
558 {
559 let pq_scratch = &mut self.scratch.pq_scratch;
560 compute_pq_distance(
561 ids,
562 self.provider.pq_data.get_num_chunks(),
563 &pq_scratch.aligned_pqtable_dist_scratch,
564 self.provider.pq_data.pq_compressed_data().get_data(),
565 &mut pq_scratch.aligned_pq_coord_scratch,
566 &mut pq_scratch.aligned_dist_scratch,
567 )?;
568
569 for (i, id) in ids.iter().enumerate() {
570 let distance = self.scratch.pq_scratch.aligned_dist_scratch[i];
571 f(distance, *id);
572 }
573
574 Ok(())
575 }
576}
577
578impl<Data, VP> SearchExt for DiskAccessor<'_, Data, VP>
579where
580 Data: GraphDataType<VectorIdType = u32>,
581 VP: VertexProvider<Data>,
582{
583 async fn starting_points(&self) -> ANNResult<Vec<u32>> {
584 let start_vertex_id = self.provider.graph_header.metadata().medoid as u32;
585 Ok(vec![start_vertex_id])
586 }
587
588 fn terminate_early(&mut self) -> bool {
589 self.io_tracker.io_count() > self.provider.search_io_limit
590 }
591}
592
593impl<'a, Data, VP> DiskAccessor<'a, Data, VP>
594where
595 Data: GraphDataType<VectorIdType = u32>,
596 VP: VertexProvider<Data>,
597{
598 fn new<VPF>(
599 provider: &'a DiskProvider<Data>,
600 io_tracker: &'a IOTracker,
601 query: &'a [Data::VectorDataType],
602 vertex_provider_factory: &'a VPF,
603 scratch_pool: &'a Arc<ObjectPool<DiskSearchScratch<Data, VP>>>,
604 ) -> ANNResult<Self>
605 where
606 VPF: VertexProviderFactory<Data, VertexProviderType = VP>,
607 {
608 let mut scratch = PoolOption::try_pooled(
609 scratch_pool,
610 &DiskSearchScratchArgs {
611 graph_degree: provider.graph_header.max_degree::<Data::VectorDataType>()?,
612 dim: provider.graph_header.metadata().dims,
613 num_pq_chunks: provider.pq_data.get_num_chunks(),
614 num_pq_centers: provider.pq_data.get_num_centers(),
615 vertex_factory: vertex_provider_factory,
616 graph_header: &provider.graph_header,
617 },
618 )?;
619
620 scratch.pq_scratch.set(
621 provider.graph_header.metadata().dims,
622 query,
623 1.0_f32, )?;
625 let start_vertex_id = provider.graph_header.metadata().medoid as u32;
626
627 let timer = Instant::now();
628 quantizer_preprocess(
629 &mut scratch.pq_scratch,
630 &provider.pq_data,
631 provider.metric,
632 &[start_vertex_id],
633 )?;
634 IOTracker::add_time(
635 &io_tracker.preprocess_time_us,
636 timer.elapsed().as_micros() as u64,
637 );
638
639 Ok(Self {
640 provider,
641 io_tracker,
642 scratch,
643 query,
644 })
645 }
646 fn ensure_loaded(&mut self, ids: &[u32]) -> Result<(), ANNError> {
647 if ids.is_empty() {
648 return Ok(());
649 }
650 let scratch = &mut self.scratch;
651 let timer = Instant::now();
652 ensure_vertex_loaded(&mut scratch.vertex_provider, ids)?;
653 IOTracker::add_time(
654 &self.io_tracker.io_time_us,
655 timer.elapsed().as_micros() as u64,
656 );
657 self.io_tracker.add_io_count(ids.len());
658 for id in ids {
659 let distance = self
660 .provider
661 .distance_comparer
662 .evaluate_similarity(self.query, scratch.vertex_provider.get_vector(id)?);
663 let associated_data = *scratch.vertex_provider.get_associated_data(id)?;
664 scratch
665 .distance_cache
666 .insert(*id, (distance, associated_data));
667 }
668 Ok(())
669 }
670}
671
672impl<Data, VP> HasId for DiskAccessor<'_, Data, VP>
673where
674 Data: GraphDataType<VectorIdType = u32>,
675 VP: VertexProvider<Data>,
676{
677 type Id = u32;
678}
679
680impl<'a, Data, VP> Accessor for DiskAccessor<'a, Data, VP>
681where
682 Data: GraphDataType<VectorIdType = u32>,
683 VP: VertexProvider<Data>,
684{
685 type Extended = &'a [u8];
687
688 type Element<'b>
694 = &'a [u8]
695 where
696 Self: 'b;
697
698 type ElementRef<'b> = &'b [u8];
700
701 type GetError = ANNError;
703
704 fn get_element(
705 &mut self,
706 id: Self::Id,
707 ) -> impl Future<Output = Result<Self::Element<'_>, Self::GetError>> + Send {
708 std::future::ready(self.provider.pq_data.get_compressed_vector(id as usize))
709 }
710}
711
712impl<Data, VP> IdIterator<Range<u32>> for DiskAccessor<'_, Data, VP>
713where
714 Data: GraphDataType<VectorIdType = u32>,
715 VP: VertexProvider<Data>,
716{
717 async fn id_iterator(&mut self) -> Result<Range<u32>, ANNError> {
718 Ok(0..self.provider.num_points as u32)
719 }
720}
721
722impl<'a, 'b, Data, VP> DelegateNeighbor<'a> for DiskAccessor<'b, Data, VP>
723where
724 Data: GraphDataType<VectorIdType = u32>,
725 VP: VertexProvider<Data>,
726{
727 type Delegate = AsNeighborAccessor<'a, 'b, Data, VP>;
728 fn delegate_neighbor(&'a mut self) -> Self::Delegate {
729 AsNeighborAccessor(self)
730 }
731}
732
733pub struct AsNeighborAccessor<'a, 'b, Data, VP>(&'a mut DiskAccessor<'b, Data, VP>)
739where
740 Data: GraphDataType<VectorIdType = u32>,
741 VP: VertexProvider<Data>;
742
743impl<Data, VP> HasId for AsNeighborAccessor<'_, '_, Data, VP>
744where
745 Data: GraphDataType<VectorIdType = u32>,
746 VP: VertexProvider<Data>,
747{
748 type Id = u32;
749}
750
751impl<Data, VP> NeighborAccessor for AsNeighborAccessor<'_, '_, Data, VP>
752where
753 Data: GraphDataType<VectorIdType = u32>,
754 VP: VertexProvider<Data>,
755{
756 fn get_neighbors(
757 self,
758 id: Self::Id,
759 neighbors: &mut AdjacencyList<Self::Id>,
760 ) -> impl Future<Output = ANNResult<Self>> + Send {
761 if self.0.io_tracker.io_count() > self.0.provider.search_io_limit {
762 return future::ok(self); }
764
765 if let Err(e) = ensure_vertex_loaded(&mut self.0.scratch.vertex_provider, &[id]) {
766 return future::err(e);
767 }
768 let list = match self.0.scratch.vertex_provider.get_adjacency_list(&id) {
769 Ok(list) => list,
770 Err(e) => return future::err(e),
771 };
772 neighbors.overwrite_trusted(list);
773 future::ok(self)
774 }
775}
776
777pub struct DiskIndexSearcher<
781 Data,
782 ProviderFactory = DiskVertexProviderFactory<Data, AlignedFileReaderFactory>,
783> where
784 Data: GraphDataType<VectorIdType = u32>,
785 ProviderFactory: VertexProviderFactory<Data>,
786{
787 index: DiskANNIndex<DiskProvider<Data>>,
788 runtime: Runtime,
789
790 vertex_provider_factory: ProviderFactory,
792
793 scratch_pool: Arc<ObjectPool<DiskSearchScratch<Data, ProviderFactory::VertexProviderType>>>,
795}
796
797#[derive(Debug)]
798pub struct SearchResultStats {
799 pub cmps: u32,
800 pub result_count: u32,
801 pub query_statistics: QueryStatistics,
802}
803
804pub struct SearchResult<AssociatedData> {
809 pub results: Vec<SearchResultItem<AssociatedData>>,
811 pub stats: SearchResultStats,
812}
813
/// A single search hit.
pub struct SearchResultItem<AssociatedData> {
    /// Id of the matching vertex.
    pub vertex_id: u32,
    /// Data associated with the matching vertex.
    pub data: AssociatedData,
    /// Distance between the query and the matching vertex.
    pub distance: f32,
}
827
828impl<Data, ProviderFactory> DiskIndexSearcher<Data, ProviderFactory>
829where
830 Data: GraphDataType<VectorIdType = u32>,
831 ProviderFactory: VertexProviderFactory<Data>,
832{
833 pub fn new(
843 num_threads: usize,
844 search_io_limit: usize,
845 disk_index_reader: &DiskIndexReader<Data::VectorDataType>,
846 vertex_provider_factory: ProviderFactory,
847 metric: Metric,
848 runtime: Option<Runtime>,
849 ) -> ANNResult<Self> {
850 let runtime = match runtime {
851 Some(rt) => rt,
852 None => tokio::runtime::Builder::new_current_thread().build()?,
853 };
854
855 let graph_header = vertex_provider_factory.get_header()?;
856 let metadata = graph_header.metadata();
857 let max_degree = graph_header.max_degree::<Data::VectorDataType>()? as u32;
858
859 let config = graph::config::Builder::new(
860 max_degree.into_usize(),
861 graph::config::MaxDegree::default_slack(),
862 1, metric.into(),
864 )
865 .build()?;
866
867 debug!("Creating DiskIndexSearcher with index_config: {:?}", config);
868
869 let graph_header = vertex_provider_factory.get_header()?;
870 let pq_data = disk_index_reader.get_pq_data();
871 let scratch_pool_args = DiskSearchScratchArgs {
872 graph_degree: graph_header.max_degree::<Data::VectorDataType>()?,
873 dim: graph_header.metadata().dims,
874 num_pq_chunks: pq_data.get_num_chunks(),
875 num_pq_centers: pq_data.get_num_centers(),
876 vertex_factory: &vertex_provider_factory,
877 graph_header: &graph_header,
878 };
879 let scratch_pool = Arc::new(ObjectPool::try_new(&scratch_pool_args, 0, None)?);
880
881 let disk_provider = DiskProvider::new(
882 disk_index_reader,
883 graph_header,
884 metric,
885 metadata.num_pts.into_usize(),
886 search_io_limit,
887 )?;
888
889 let index = DiskANNIndex::new(config, disk_provider, NonZeroUsize::new(num_threads));
890 Ok(Self {
891 index,
892 runtime,
893 vertex_provider_factory,
894 scratch_pool,
895 })
896 }
897
898 fn search_strategy<'a>(
900 &'a self,
901 query: &'a [Data::VectorDataType],
902 vector_filter: &'a (dyn Fn(&Data::VectorIdType) -> bool + Send + Sync),
903 ) -> DiskSearchStrategy<'a, Data, ProviderFactory> {
904 DiskSearchStrategy {
905 io_tracker: IOTracker::default(),
906 vector_filter,
907 query,
908 vertex_provider_factory: &self.vertex_provider_factory,
909 scratch_pool: &self.scratch_pool,
910 }
911 }
912
913 pub fn search(
916 &self,
917 query: &[Data::VectorDataType],
918 return_list_size: u32,
919 search_list_size: u32,
920 beam_width: Option<usize>,
921 vector_filter: Option<VectorFilter<Data>>,
922 is_flat_search: bool,
923 ) -> ANNResult<SearchResult<Data::AssociatedDataType>> {
924 let mut query_stats = QueryStatistics::default();
925 let mut indices = vec![0u32; return_list_size as usize];
926 let mut distances = vec![0f32; return_list_size as usize];
927 let mut associated_data =
928 vec![Data::AssociatedDataType::default(); return_list_size as usize];
929
930 let stats = self.search_internal(
931 query,
932 return_list_size as usize,
933 search_list_size,
934 beam_width,
935 &mut query_stats,
936 &mut indices,
937 &mut distances,
938 &mut associated_data,
939 &vector_filter.unwrap_or(default_vector_filter::<Data>()),
940 is_flat_search,
941 )?;
942
943 let mut search_result = SearchResult {
944 results: Vec::with_capacity(return_list_size as usize),
945 stats,
946 };
947
948 for ((vertex_id, distance), associated_data) in indices
949 .into_iter()
950 .zip(distances.into_iter())
951 .zip(associated_data.into_iter())
952 {
953 search_result.results.push(SearchResultItem {
954 vertex_id,
955 distance,
956 data: associated_data,
957 });
958 }
959
960 Ok(search_result)
961 }
962
963 #[allow(clippy::too_many_arguments)]
966 pub(crate) fn search_internal(
967 &self,
968 query: &[Data::VectorDataType],
969 k_value: usize,
970 search_list_size: u32,
971 beam_width: Option<usize>,
972 query_stats: &mut QueryStatistics,
973 indices: &mut [u32],
974 distances: &mut [f32],
975 associated_data: &mut [Data::AssociatedDataType],
976 vector_filter: &(dyn Fn(&Data::VectorIdType) -> bool + Send + Sync),
977 is_flat_search: bool,
978 ) -> ANNResult<SearchResultStats> {
979 let mut result_output_buffer = search_output_buffer::IdDistanceAssociatedData::new(
980 &mut indices[..k_value],
981 &mut distances[..k_value],
982 &mut associated_data[..k_value],
983 );
984
985 let strategy = self.search_strategy(query, vector_filter);
986 let timer = Instant::now();
987 let k = k_value;
988 let l = search_list_size as usize;
989 let stats = if is_flat_search {
990 self.runtime.block_on(self.index.flat_search(
991 &strategy,
992 &DefaultContext,
993 strategy.query,
994 vector_filter,
995 &Knn::new(k, l, beam_width)?,
996 &mut result_output_buffer,
997 ))?
998 } else {
999 let knn_search = Knn::new(k, l, beam_width)?;
1000 self.runtime.block_on(self.index.search(
1001 knn_search,
1002 &strategy,
1003 &DefaultContext,
1004 strategy.query,
1005 &mut result_output_buffer,
1006 ))?
1007 };
1008 query_stats.total_comparisons = stats.cmps;
1009 query_stats.search_hops = stats.hops;
1010
1011 query_stats.total_execution_time_us = timer.elapsed().as_micros();
1012 query_stats.io_time_us = IOTracker::time(&strategy.io_tracker.io_time_us) as u128;
1013 query_stats.total_io_operations = strategy.io_tracker.io_count() as u32;
1014 query_stats.total_vertices_loaded = strategy.io_tracker.io_count() as u32;
1015 query_stats.query_pq_preprocess_time_us =
1016 IOTracker::time(&strategy.io_tracker.preprocess_time_us) as u128;
1017 query_stats.cpu_time_us = query_stats.total_execution_time_us
1018 - query_stats.io_time_us
1019 - query_stats.query_pq_preprocess_time_us;
1020 Ok(SearchResultStats {
1021 cmps: query_stats.total_comparisons,
1022 result_count: stats.result_count,
1023 query_statistics: query_stats.clone(),
1024 })
1025 }
1026}
1027
1028fn ensure_vertex_loaded<Data: GraphDataType, V: VertexProvider<Data>>(
1034 vertex_provider: &mut V,
1035 ids: &[Data::VectorIdType],
1036) -> ANNResult<()> {
1037 vertex_provider.load_vertices(ids)?;
1038 for (idx, id) in ids.iter().enumerate() {
1039 vertex_provider.process_loaded_node(id, idx)?;
1040 }
1041 Ok(())
1042}
1043
1044#[cfg(test)]
1045mod disk_provider_tests {
1046 use diskann::{
1047 graph::{
1048 search::{record::VisitedSearchRecord, Knn},
1049 KnnSearchError,
1050 },
1051 utils::IntoUsize,
1052 ANNErrorKind,
1053 };
1054 use diskann_providers::storage::{
1055 DynWriteProvider, StorageReadProvider, VirtualStorageProvider,
1056 };
1057 use diskann_providers::{
1058 common::AlignedBoxWithSlice,
1059 test_utils::graph_data_type_utils::{
1060 GraphDataF32VectorU32Data, GraphDataF32VectorUnitData,
1061 },
1062 utils::{
1063 create_thread_pool, file_util, load_aligned_bin, PQPathNames, ParallelIteratorInPool,
1064 },
1065 };
1066 use diskann_utils::test_data_root;
1067 use diskann_vector::distance::Metric;
1068 use rayon::prelude::{IndexedParallelIterator, IntoParallelRefIterator};
1069 use rstest::rstest;
1070 use vfs::OverlayFS;
1071
1072 use super::*;
1073 use crate::{
1074 build::builder::core::disk_index_builder_tests::{IndexBuildFixture, TestParams},
1075 utils::{QueryStatistics, VirtualAlignedReaderFactory},
1076 };
1077
1078 const TEST_INDEX_PREFIX_128DIM: &str =
1079 "/disk_index_search/disk_index_sift_learn_R4_L50_A1.2_truth_search";
1080 const TEST_INDEX_128DIM: &str =
1081 "/disk_index_search/disk_index_sift_learn_R4_L50_A1.2_truth_search_disk.index";
1082 const TEST_PQ_PIVOT_128DIM: &str =
1083 "/disk_index_search/disk_index_sift_learn_R4_L50_A1.2_truth_search_pq_pivots.bin";
1084 const TEST_PQ_COMPRESSED_128DIM: &str =
1085 "/disk_index_search/disk_index_sift_learn_R4_L50_A1.2_truth_search_pq_compressed.bin";
1086 const TEST_TRUTH_RESULT_10PTS_128DIM: &str =
1087 "/disk_index_search/disk_index_10pts_idx_uint32_truth_search_res.bin";
1088 const TEST_QUERY_10PTS_128DIM: &str = "/disk_index_search/disk_index_sample_query_10pts.fbin";
1089
1090 const TEST_INDEX_PREFIX_100DIM: &str = "/disk_index_search/256pts_100dim_f32_truth_Index";
1091 const TEST_INDEX_100DIM: &str = "/disk_index_search/256pts_100dim_f32_truth_Index_disk.index";
1092 const TEST_PQ_PIVOT_100DIM: &str =
1093 "/disk_index_search/256pts_100dim_f32_truth_Index_pq_pivots.bin";
1094 const TEST_PQ_COMPRESSED_100DIM: &str =
1095 "/disk_index_search/256pts_100dim_f32_truth_Index_pq_compressed.bin";
1096 const TEST_TRUTH_RESULT_10PTS_100DIM: &str =
1097 "/disk_index_search/256pts_100dim_f32_truth_query_result.bin";
1098 const TEST_QUERY_10PTS_100DIM: &str = "/disk_index_search/10pts_100dim_f32_base_query.bin";
1099 const TEST_DATA_FILE: &str = "/disk_index_search/disk_index_siftsmall_learn_256pts_data.fbin";
1100 const TEST_INDEX: &str =
1101 "/disk_index_search/disk_index_sift_learn_R4_L50_A1.2_truth_search_disk.index";
1102 const TEST_INDEX_PREFIX: &str =
1103 "/disk_index_search/disk_index_sift_learn_R4_L50_A1.2_truth_search";
1104 const TEST_PQ_PIVOT: &str =
1105 "/disk_index_search/disk_index_sift_learn_R4_L50_A1.2_truth_search_pq_pivots.bin";
1106 const TEST_PQ_COMPRESSED: &str =
1107 "/disk_index_search/disk_index_sift_learn_R4_L50_A1.2_truth_search_pq_compressed.bin";
1108
1109 #[test]
1110 fn test_disk_search_k10_l20_single_or_multi_thread_100dim() {
1111 let storage_provider = Arc::new(VirtualStorageProvider::new_overlay(test_data_root()));
1112
1113 let search_engine = create_disk_index_searcher(
1114 CreateDiskIndexSearcherParams {
1115 max_thread_num: 5,
1116 pq_pivot_file_path: TEST_PQ_PIVOT_100DIM,
1117 pq_compressed_file_path: TEST_PQ_COMPRESSED_100DIM,
1118 index_path: TEST_INDEX_100DIM,
1119 index_path_prefix: TEST_INDEX_PREFIX_100DIM,
1120 ..Default::default()
1121 },
1122 &storage_provider,
1123 );
1124 test_disk_search(TestDiskSearchParams {
1126 storage_provider: storage_provider.as_ref(),
1127 index_search_engine: &search_engine,
1128 thread_num: 1,
1129 query_file_path: TEST_QUERY_10PTS_100DIM,
1130 truth_result_file_path: TEST_TRUTH_RESULT_10PTS_100DIM,
1131 k: 10,
1132 l: 20,
1133 dim: 104,
1134 });
1135 test_disk_search(TestDiskSearchParams {
1137 storage_provider: storage_provider.as_ref(),
1138 index_search_engine: &search_engine,
1139 thread_num: 5,
1140 query_file_path: TEST_QUERY_10PTS_100DIM,
1141 truth_result_file_path: TEST_TRUTH_RESULT_10PTS_100DIM,
1142 k: 10,
1143 l: 20,
1144 dim: 104,
1145 });
1146 }
1147
1148 #[test]
1149 fn test_disk_search_k10_l20_single_or_multi_thread_128dim() {
1150 let storage_provider = Arc::new(VirtualStorageProvider::new_overlay(test_data_root()));
1151
1152 let search_engine = create_disk_index_searcher::<GraphDataF32VectorUnitData>(
1153 CreateDiskIndexSearcherParams {
1154 max_thread_num: 5,
1155 pq_pivot_file_path: TEST_PQ_PIVOT_128DIM,
1156 pq_compressed_file_path: TEST_PQ_COMPRESSED_128DIM,
1157 index_path: TEST_INDEX_128DIM,
1158 index_path_prefix: TEST_INDEX_PREFIX_128DIM,
1159 ..Default::default()
1160 },
1161 &storage_provider,
1162 );
1163 test_disk_search(TestDiskSearchParams {
1165 storage_provider: storage_provider.as_ref(),
1166 index_search_engine: &search_engine,
1167 thread_num: 1,
1168 query_file_path: TEST_QUERY_10PTS_128DIM,
1169 truth_result_file_path: TEST_TRUTH_RESULT_10PTS_128DIM,
1170 k: 10,
1171 l: 20,
1172 dim: 128,
1173 });
1174 test_disk_search(TestDiskSearchParams {
1176 storage_provider: storage_provider.as_ref(),
1177 index_search_engine: &search_engine,
1178 thread_num: 5,
1179 query_file_path: TEST_QUERY_10PTS_128DIM,
1180 truth_result_file_path: TEST_TRUTH_RESULT_10PTS_128DIM,
1181 k: 10,
1182 l: 20,
1183 dim: 128,
1184 });
1185 }
1186
1187 fn get_truth_associated_data<StorageReader: StorageReadProvider>(
1188 storage_provider: &StorageReader,
1189 ) -> Vec<u32> {
1190 const ASSOCIATED_DATA_FILE: &str = "/sift/siftsmall_learn_256pts_u32_associated_data.fbin";
1191
1192 let (data, _npts, _dim) =
1193 file_util::load_bin::<u32, StorageReader>(storage_provider, ASSOCIATED_DATA_FILE, 0)
1194 .unwrap();
1195 data
1196 }
1197
/// Builds an index carrying `u32` associated data through `IndexBuildFixture`, searches it with
/// one and with five query threads, and finally deletes the generated index artifacts.
#[test]
fn test_disk_search_with_associated_data_k10_l20_single_or_multi_thread_128dim() {
    let index_path_prefix = "/disk_index_search/disk_index_sift_learn_R4_L50_A1.2_test_disk_index_search_associated_data";
    // The three artifacts derived from the prefix: opened by the searcher below and removed
    // once the searches are done.
    let disk_index_path = format!("{}_disk.index", index_path_prefix);
    let pq_pivots_path = format!("{}_pq_pivots.bin", index_path_prefix);
    let pq_compressed_path = format!("{}_pq_compressed.bin", index_path_prefix);

    let fixture = IndexBuildFixture::new(
        VirtualStorageProvider::new_overlay(test_data_root()),
        TestParams {
            data_path: TEST_DATA_FILE.to_string(),
            index_path_prefix: index_path_prefix.to_string(),
            associated_data_path: Some(
                "/sift/siftsmall_learn_256pts_u32_associated_data.fbin".to_string(),
            ),
            ..TestParams::default()
        },
    )
    .unwrap();
    fixture.build::<GraphDataF32VectorU32Data>().unwrap();

    // Inner scope: the searcher is dropped before the files it reads are deleted.
    {
        let search_engine = create_disk_index_searcher::<GraphDataF32VectorU32Data>(
            CreateDiskIndexSearcherParams {
                max_thread_num: 5,
                pq_pivot_file_path: pq_pivots_path.as_str(),
                pq_compressed_file_path: pq_compressed_path.as_str(),
                index_path: disk_index_path.as_str(),
                index_path_prefix,
                ..Default::default()
            },
            &fixture.storage_provider,
        );

        // Single-threaded pass first, then the multi-threaded pass over the same queries.
        for thread_num in [1, 5] {
            test_disk_search_with_associated(
                TestDiskSearchAssociateParams {
                    storage_provider: fixture.storage_provider.as_ref(),
                    index_search_engine: &search_engine,
                    thread_num,
                    query_file_path: TEST_QUERY_10PTS_128DIM,
                    truth_result_file_path: TEST_TRUTH_RESULT_10PTS_128DIM,
                    k: 10,
                    l: 20,
                    dim: 128,
                },
                None,
            );
        }
    }

    for artifact in [&disk_index_path, &pq_pivots_path, &pq_compressed_path] {
        fixture
            .storage_provider
            .delete(artifact)
            .expect("Failed to delete file");
    }
}
1271
/// Inputs for `create_disk_index_searcher`. Tests usually set the paths and take the remaining
/// fields from `Default` via struct-update syntax.
struct CreateDiskIndexSearcherParams<'a> {
    /// Worker-thread count for the searcher's runtime; also forwarded to the searcher.
    max_thread_num: usize,
    /// Path of the PQ pivot file.
    pq_pivot_file_path: &'a str,
    /// Path of the PQ compressed-vector file.
    pq_compressed_file_path: &'a str,
    /// Path of the disk index file itself.
    index_path: &'a str,
    /// Prefix from which the disk index file name is derived.
    index_path_prefix: &'a str,
    /// Upper bound on search IO; must be greater than zero.
    io_limit: usize,
}

impl Default for CreateDiskIndexSearcherParams<'_> {
    /// One thread, every path empty, and no effective IO limit (`usize::MAX`).
    fn default() -> Self {
        const UNSET_PATH: &str = "";

        CreateDiskIndexSearcherParams {
            max_thread_num: 1,
            pq_pivot_file_path: UNSET_PATH,
            pq_compressed_file_path: UNSET_PATH,
            index_path: UNSET_PATH,
            index_path_prefix: UNSET_PATH,
            io_limit: usize::MAX,
        }
    }
}
1293
1294 fn create_disk_index_searcher<Data>(
1295 params: CreateDiskIndexSearcherParams,
1296 storage_provider: &Arc<VirtualStorageProvider<OverlayFS>>,
1297 ) -> DiskIndexSearcher<
1298 Data,
1299 DiskVertexProviderFactory<Data, VirtualAlignedReaderFactory<OverlayFS>>,
1300 >
1301 where
1302 Data: GraphDataType<VectorIdType = u32>,
1303 {
1304 assert!(params.io_limit > 0);
1305
1306 let runtime = tokio::runtime::Builder::new_multi_thread()
1307 .worker_threads(params.max_thread_num)
1308 .build()
1309 .unwrap();
1310
1311 let disk_index_reader = DiskIndexReader::<Data::VectorDataType>::new(
1312 params.pq_pivot_file_path.to_string(),
1313 params.pq_compressed_file_path.to_string(),
1314 storage_provider.as_ref(),
1315 )
1316 .unwrap();
1317
1318 let aligned_reader_factory = VirtualAlignedReaderFactory::new(
1319 get_disk_index_file(params.index_path_prefix),
1320 Arc::clone(storage_provider),
1321 );
1322 let caching_strategy = CachingStrategy::None;
1323 let vertex_provider_factory =
1324 DiskVertexProviderFactory::<Data, _>::new(aligned_reader_factory, caching_strategy)
1325 .unwrap();
1326
1327 DiskIndexSearcher::<Data, DiskVertexProviderFactory<Data, _>>::new(
1328 params.max_thread_num,
1329 params.io_limit,
1330 &disk_index_reader,
1331 vertex_provider_factory,
1332 Metric::L2,
1333 Some(runtime),
1334 )
1335 .unwrap()
1336 }
1337
1338 fn load_query_result<StorageReader: StorageReadProvider>(
1339 storage_provider: &StorageReader,
1340 query_result_path: &str,
1341 ) -> Vec<u32> {
1342 let (result, _, _) =
1343 file_util::load_bin::<u32, StorageReader>(storage_provider, query_result_path, 0)
1344 .unwrap();
1345 result
1346 }
1347
1348 struct TestDiskSearchParams<'a, StorageType> {
1349 storage_provider: &'a StorageType,
1350 index_search_engine: &'a DiskIndexSearcher<
1351 GraphDataF32VectorUnitData,
1352 DiskVertexProviderFactory<
1353 GraphDataF32VectorUnitData,
1354 VirtualAlignedReaderFactory<OverlayFS>,
1355 >,
1356 >,
1357 thread_num: u64,
1358 query_file_path: &'a str,
1359 truth_result_file_path: &'a str,
1360 k: usize,
1361 l: usize,
1362 dim: usize,
1363 }
1364
1365 struct TestDiskSearchAssociateParams<'a, StorageType> {
1366 storage_provider: &'a StorageType,
1367 index_search_engine: &'a DiskIndexSearcher<
1368 GraphDataF32VectorU32Data,
1369 DiskVertexProviderFactory<
1370 GraphDataF32VectorU32Data,
1371 VirtualAlignedReaderFactory<OverlayFS>,
1372 >,
1373 >,
1374 thread_num: u64,
1375 query_file_path: &'a str,
1376 truth_result_file_path: &'a str,
1377 k: usize,
1378 l: usize,
1379 dim: usize,
1380 }
1381
1382 fn test_disk_search<StorageType: StorageReadProvider>(
1383 params: TestDiskSearchParams<StorageType>,
1384 ) {
1385 let query_vector = load_aligned_bin(params.storage_provider, params.query_file_path)
1386 .unwrap()
1387 .0;
1388 let mut aligned_query = AlignedBoxWithSlice::<f32>::new(query_vector.len(), 32).unwrap();
1389 aligned_query.memcpy(query_vector.as_slice()).unwrap();
1390
1391 let queries = aligned_query
1392 .split_into_nonoverlapping_mut_slices(0..aligned_query.len(), params.dim)
1393 .unwrap();
1394
1395 let truth_result =
1396 load_query_result(params.storage_provider, params.truth_result_file_path);
1397
1398 let pool = create_thread_pool(params.thread_num.into_usize()).unwrap();
1399 queries
1401 .par_iter()
1402 .enumerate()
1403 .for_each_in_pool(&pool, |(i, query)| {
1404 let mut aligned_box = AlignedBoxWithSlice::<f32>::new(query.len() + 1, 32).unwrap();
1406 let mut temp = Vec::with_capacity(query.len() + 1);
1407 temp.push(0.0);
1408 temp.extend_from_slice(query);
1409 aligned_box.memcpy(temp.as_slice()).unwrap();
1410 let query = &aligned_box.as_slice()[1..];
1411
1412 let mut query_stats = QueryStatistics::default();
1413 let mut indices = vec![0u32; 10];
1414 let mut distances = vec![0f32; 10];
1415 let mut associated_data = vec![(); 10];
1416
1417 let result = params
1418 .index_search_engine
1419 .search_internal(
1421 query,
1422 params.k,
1423 params.l as u32,
1424 None, &mut query_stats,
1426 &mut indices,
1427 &mut distances,
1428 &mut associated_data,
1429 &(|_| true),
1430 false,
1431 );
1432
1433 let truth_slice = &truth_result[i * params.k..(i + 1) * params.k];
1435
1436 assert!(result.is_ok(), "Expected search to succeed");
1437
1438 let result_unwrapped = result.unwrap();
1439 assert!(
1440 result_unwrapped.query_statistics.total_io_operations > 0,
1441 "Expected IO operations to be greater than 0"
1442 );
1443 assert!(
1444 result_unwrapped.query_statistics.total_vertices_loaded > 0,
1445 "Expected vertices loaded to be greater than 0"
1446 );
1447
1448 assert_eq!(
1450 indices, truth_slice,
1451 "Results DO NOT match with the truth result for query {}",
1452 i
1453 );
1454 });
1455 }
1456
1457 fn test_disk_search_with_associated<StorageType: StorageReadProvider>(
1458 params: TestDiskSearchAssociateParams<StorageType>,
1459 beam_width: Option<usize>,
1460 ) {
1461 let query_vector = load_aligned_bin(params.storage_provider, params.query_file_path)
1462 .unwrap()
1463 .0;
1464 let mut aligned_query = AlignedBoxWithSlice::<f32>::new(query_vector.len(), 32).unwrap();
1465 aligned_query.memcpy(query_vector.as_slice()).unwrap();
1466 let queries = aligned_query
1467 .split_into_nonoverlapping_mut_slices(0..aligned_query.len(), params.dim)
1468 .unwrap();
1469 let truth_result =
1470 load_query_result(params.storage_provider, params.truth_result_file_path);
1471 let pool = create_thread_pool(params.thread_num.into_usize()).unwrap();
1472 queries
1474 .par_iter()
1475 .enumerate()
1476 .for_each_in_pool(&pool, |(i, query)| {
1477 let mut aligned_box = AlignedBoxWithSlice::<f32>::new(query.len() + 1, 32).unwrap();
1479 let mut temp = Vec::with_capacity(query.len() + 1);
1480 temp.push(0.0);
1481 temp.extend_from_slice(query);
1482 aligned_box.memcpy(temp.as_slice()).unwrap();
1483 let query = &aligned_box.as_slice()[1..];
1484 let result = params
1485 .index_search_engine
1486 .search(query, params.k as u32, params.l as u32, beam_width, None, false)
1487 .unwrap();
1488 let indices: Vec<u32> = result.results.iter().map(|item| item.vertex_id).collect();
1489 let associated_data: Vec<u32> =
1490 result.results.iter().map(|item| item.data).collect();
1491 let truth_data = get_truth_associated_data(params.storage_provider);
1492 let associated_data_truth: Vec<u32> = indices
1493 .iter()
1494 .map(|&vid| truth_data[vid as usize])
1495 .collect();
1496 assert_eq!(
1497 associated_data, associated_data_truth,
1498 "Associated data DO NOT match with the truth result for query {}, associated_data from search: {:?}, associated_data from truth result: {:?}",
1499 i,associated_data, associated_data_truth
1500 );
1501 let truth_slice = &truth_result[i * params.k..(i + 1) * params.k];
1502 assert_eq!(
1503 indices, truth_slice,
1504 "Results DO NOT match with the truth result for query {}",
1505 i
1506 );
1507 });
1508 }
1509
/// Covers rejected search parameters plus a few sanity checks on the 128-dimension fixture:
/// the fixture path constants agree with each other, `Knn` construction rejects bad arguments,
/// id mapping for id 0 is the identity, the provider and index agree on max degree, and
/// `search_internal` fails with `IndexError` for a search list of 9 with 10 results requested.
#[test]
fn test_disk_search_invalid_input() {
    let storage_provider = Arc::new(VirtualStorageProvider::new_overlay(test_data_root()));
    let ctx = &DefaultContext;

    let params = CreateDiskIndexSearcherParams {
        max_thread_num: 5,
        pq_pivot_file_path: TEST_PQ_PIVOT_128DIM,
        pq_compressed_file_path: TEST_PQ_COMPRESSED_128DIM,
        index_path: TEST_INDEX_128DIM,
        index_path_prefix: TEST_INDEX_PREFIX_128DIM,
        ..Default::default()
    };

    // The fixture constants must match the names derived from the index prefix.
    let paths = PQPathNames::for_disk_index(TEST_INDEX_PREFIX_128DIM);
    assert_eq!(
        paths.pivots, params.pq_pivot_file_path,
        "pq_pivot_file_path is not correct"
    );
    assert_eq!(
        paths.compressed_data, params.pq_compressed_file_path,
        "pq_compressed_file_path is not correct"
    );
    assert_eq!(
        params.index_path,
        format!("{}_disk.index", params.index_path_prefix),
        "index_path is not correct"
    );

    // Rejected at construction and convertible into an `ANNError` of kind `IndexError`.
    let rejected_default = Knn::new_default(20, 10);
    assert!(rejected_default.is_err());
    let converted =
        <KnnSearchError as std::convert::Into<ANNError>>::into(rejected_default.unwrap_err());
    assert_eq!(converted.kind(), ANNErrorKind::IndexError);

    // A third argument of `Some(0)` is rejected as well.
    assert!(Knn::new(10, 10, Some(0)).is_err());

    let search_engine =
        create_disk_index_searcher::<GraphDataF32VectorU32Data>(params, &storage_provider);
    let provider = &search_engine.index.data_provider;

    assert_eq!(provider.to_external_id(ctx, 0).unwrap(), 0);
    assert_eq!(provider.to_internal_id(ctx, &0).unwrap(), 0);

    let provider_max_degree = provider
        .graph_header
        .max_degree::<<GraphDataF32VectorU32Data as GraphDataType>::VectorDataType>()
        .unwrap();
    let index_max_degree = search_engine.index.config.pruned_degree().get();
    assert_eq!(provider_max_degree, index_max_degree);

    let query = vec![0f32; 128];
    let mut query_stats = QueryStatistics::default();
    let mut indices = vec![0u32; 10];
    let mut distances = vec![0f32; 10];
    let mut associated_data = vec![0u32; 10];

    // Ten results requested with a search list one smaller than that.
    let outcome = search_engine.search_internal(
        &query,
        10,
        10 - 1,
        None,
        &mut query_stats,
        &mut indices,
        &mut distances,
        &mut associated_data,
        &|_| true,
        false,
    );

    assert!(outcome.is_err());
    assert_eq!(outcome.unwrap_err().kind(), ANNErrorKind::IndexError);
}
1603
/// Runs a recorded k=10 / l=10 search with the third `Knn::new` argument set to `Some(4)`,
/// pins the exact sequence of visited node ids, then repeats the query through the public
/// `search` entry point and checks the result count.
#[test]
fn test_disk_search_beam_search() {
    let storage_provider = Arc::new(VirtualStorageProvider::new_overlay(test_data_root()));

    let search_engine = create_disk_index_searcher::<GraphDataF32VectorUnitData>(
        CreateDiskIndexSearcherParams {
            max_thread_num: 1,
            pq_pivot_file_path: TEST_PQ_PIVOT,
            pq_compressed_file_path: TEST_PQ_COMPRESSED,
            index_path: TEST_INDEX,
            index_path_prefix: TEST_INDEX_PREFIX,
            ..Default::default()
        },
        &storage_provider,
    );

    let query_vector: [f32; 128] = [1f32; 128];
    let mut indices = vec![0u32; 10];
    let mut distances = vec![0f32; 10];
    let mut associated_data = vec![(); 10];

    let mut result_output_buffer = search_output_buffer::IdDistanceAssociatedData::new(
        &mut indices,
        &mut distances,
        &mut associated_data,
    );
    let strategy = search_engine.search_strategy(&query_vector, &|_| true);
    let mut search_record = VisitedSearchRecord::new(0);
    let search_params = Knn::new(10, 10, Some(4)).unwrap();
    let recorded_search =
        diskann::graph::search::RecordedKnn::new(search_params, &mut search_record);
    search_engine
        .runtime
        .block_on(search_engine.index.search(
            recorded_search,
            &strategy,
            &DefaultContext,
            query_vector.as_slice(),
            &mut result_output_buffer,
        ))
        .unwrap();

    let ids = search_record
        .visited
        .iter()
        .map(|n| n.id)
        .collect::<Vec<_>>();

    const EXPECTED_NODES: [u32; 18] = [
        72, 118, 108, 86, 84, 152, 170, 82, 114, 87, 207, 176, 79, 153, 67, 165, 141, 180,
    ];
    assert_eq!(ids, &EXPECTED_NODES);

    let return_list_size = 10;
    let search_list_size = 10;
    // `expect` includes the error value in the failure output.
    let search_result = search_engine
        .search(
            &query_vector,
            return_list_size,
            search_list_size,
            Some(4),
            None,
            false,
        )
        .expect("Expected search to succeed");
    assert_eq!(
        search_result.results.len() as u32,
        return_list_size,
        "Expected result count to match"
    );
    // NOTE(review): `indices` is the buffer filled by the recorded search above, not by the
    // `search` call - confirm whether `search_result` was meant to be checked here as well.
    assert_eq!(
        indices,
        vec![152, 72, 170, 118, 87, 165, 79, 141, 108, 86],
        "Expected indices to match"
    );
}
1681
/// Exercises diversity search with a synthetic attribute provider that assigns each of 256
/// vector ids a label in `1..=15`.
///
/// A first search only checks that results come back. A second search passes
/// `diverse_results_k = 1` and checks that the results are non-negative, sorted by ascending
/// distance, contain no label more than `diverse_results_k` times, and span at least two labels.
#[cfg(feature = "experimental_diversity_search")]
#[test]
fn test_disk_search_diversity_search() {
    use diskann::graph::DiverseSearchParams;
    use diskann::neighbor::AttributeValueProvider;
    use std::collections::HashMap;

    /// Attribute provider backed by an in-memory id -> label map.
    #[derive(Debug, Clone)]
    struct TestAttributeProvider {
        attributes: HashMap<u32, u32>,
    }
    impl TestAttributeProvider {
        fn new() -> Self {
            Self {
                attributes: HashMap::new(),
            }
        }
        fn insert(&mut self, id: u32, attribute: u32) {
            self.attributes.insert(id, attribute);
        }
    }
    impl diskann::provider::HasId for TestAttributeProvider {
        type Id = u32;
    }

    impl AttributeValueProvider for TestAttributeProvider {
        type Value = u32;

        fn get(&self, id: Self::Id) -> Option<Self::Value> {
            self.attributes.get(&id).copied()
        }
    }

    let storage_provider = Arc::new(VirtualStorageProvider::new_overlay(test_data_root()));

    let search_engine = create_disk_index_searcher::<GraphDataF32VectorUnitData>(
        CreateDiskIndexSearcherParams {
            max_thread_num: 1,
            pq_pivot_file_path: TEST_PQ_PIVOT,
            pq_compressed_file_path: TEST_PQ_COMPRESSED,
            index_path: TEST_INDEX,
            index_path_prefix: TEST_INDEX_PREFIX,
            ..Default::default()
        },
        &storage_provider,
    );

    let query_vector: [f32; 128] = [1f32; 128];

    // Labels cycle through 1..=15 by id.
    let mut attribute_provider = TestAttributeProvider::new();
    let num_vectors = 256;
    for i in 0..num_vectors {
        let label = (i % 15) + 1;
        attribute_provider.insert(i, label);
    }
    let attribute_provider = std::sync::Arc::new(attribute_provider);

    // First search: only checks that something is returned.
    let mut indices = vec![0u32; 10];
    let mut distances = vec![0f32; 10];
    let mut associated_data = vec![(); 10];

    let mut result_output_buffer = search_output_buffer::IdDistanceAssociatedData::new(
        &mut indices,
        &mut distances,
        &mut associated_data,
    );
    let strategy = search_engine.search_strategy(&query_vector, &|_| true);

    let diverse_params = DiverseSearchParams::new(0, 3, attribute_provider.clone());

    let search_params = Knn::new(10, 20, None).unwrap();

    let diverse_search = diskann::graph::search::Diverse::new(search_params, diverse_params);
    let stats = search_engine
        .runtime
        .block_on(search_engine.index.search(
            diverse_search,
            &strategy,
            &DefaultContext,
            query_vector.as_slice(),
            &mut result_output_buffer,
        ))
        .unwrap();

    assert!(
        stats.result_count > 0,
        "Expected to get some results during diversity search"
    );

    // Second search: the per-label limit is enforced and checked below.
    let return_list_size = 10;
    let search_list_size = 20;
    let diverse_results_k = 1;
    let diverse_params = DiverseSearchParams::new(0, diverse_results_k, attribute_provider.clone());

    let mut indices2 = vec![0u32; return_list_size as usize];
    let mut distances2 = vec![0f32; return_list_size as usize];
    let mut associated_data2 = vec![(); return_list_size as usize];
    let mut result_output_buffer2 = search_output_buffer::IdDistanceAssociatedData::new(
        &mut indices2,
        &mut distances2,
        &mut associated_data2,
    );
    let strategy2 = search_engine.search_strategy(&query_vector, &|_| true);
    let search_params2 =
        Knn::new(return_list_size as usize, search_list_size as usize, None).unwrap();

    let diverse_search2 = diskann::graph::search::Diverse::new(search_params2, diverse_params);
    let stats = search_engine
        .runtime
        .block_on(search_engine.index.search(
            diverse_search2,
            &strategy2,
            &DefaultContext,
            query_vector.as_slice(),
            &mut result_output_buffer2,
        ))
        .unwrap();

    assert!(
        stats.result_count > 0,
        "Expected diversity search to return results"
    );
    assert!(
        stats.result_count <= return_list_size,
        "Expected result count to be <= {}",
        return_list_size
    );

    // Only the first `result_count` buffer entries were written by the search.
    let result_count = stats.result_count as usize;
    let returned_ids = &indices2[..result_count];
    let returned_distances = &distances2[..result_count];

    println!("\n=== Diversity Search Results ===");
    println!("Query: [1f32; 128]");
    println!("diverse_results_k: {}", diverse_results_k);
    println!("Total results: {}\n", stats.result_count);
    println!("{:<10} {:<15} {:<10}", "Vertex ID", "Distance", "Label");
    println!("{}", "-".repeat(35));
    for (id, distance) in returned_ids.iter().zip(returned_distances) {
        let attribute_value = attribute_provider.get(*id).unwrap_or(0);
        println!("{:<10} {:<15.2} {:<10}", id, distance, attribute_value);
    }

    // Every returned distance, including the last one, must be non-negative.
    assert!(
        returned_distances.iter().all(|distance| *distance >= 0.0),
        "Expected non-negative distance"
    );
    for pair in returned_distances.windows(2) {
        assert!(
            pair[0] <= pair[1],
            "Expected distances to be sorted in ascending order"
        );
    }

    // Count how often each label occurs among the returned ids.
    let mut attribute_counts = HashMap::new();
    for item in returned_ids {
        if let Some(attribute_value) = attribute_provider.get(*item) {
            *attribute_counts.entry(attribute_value).or_insert(0) += 1;
        }
    }

    println!("\n=== Attribute Distribution ===");
    let mut sorted_attrs: Vec<_> = attribute_counts.iter().collect();
    sorted_attrs.sort_by_key(|(k, _)| *k);
    for (attribute_value, count) in &sorted_attrs {
        println!(
            "Label {}: {} occurrences (max allowed: {})",
            attribute_value, count, diverse_results_k
        );
    }
    println!("Total unique labels: {}", attribute_counts.len());
    println!("================================\n");

    for (attribute_value, count) in &attribute_counts {
        println!(
            "Assert: Label {} has {} occurrences (max: {})",
            attribute_value, count, diverse_results_k
        );
        assert!(
            *count <= diverse_results_k,
            "Attribute value {} appears {} times, which exceeds diverse_results_k of {}",
            attribute_value,
            count,
            diverse_results_k
        );
    }

    println!(
        "Assert: Found {} unique labels (expected at least 2)",
        attribute_counts.len()
    );
    assert!(
        attribute_counts.len() >= 2,
        "Expected at least 2 different attribute values for diversity, got {}",
        attribute_counts.len()
    );
}
1904
1905 #[rstest]
1906 #[case(
1908 |_id: &u32| true,
1909 false,
1910 10,
1911 vec![152, 118, 72, 170, 87, 141, 79, 207, 124, 86],
1912 vec![256101.7, 256675.3, 256709.69, 256712.5, 256760.08, 256958.5, 257006.1, 257025.7, 257105.67, 257107.67],
1913 )]
1914 #[case(
1917 |id: &u32| *id == 0 || *id == 1,
1918 false,
1919 0,
1920 vec![0; 10],
1921 vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
1922 )]
1923 #[case(
1926 |id: &u32| *id == 0 || *id == 1,
1927 true,
1928 2,
1929 vec![1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
1930 vec![257247.28, 258179.28, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
1931 )]
1932 #[case(
1935 |id: &u32| *id == 72 || *id == 87 || *id == 170,
1936 false,
1937 3,
1938 vec![72, 170, 87, 0, 0, 0, 0, 0, 0, 0],
1939 vec![256709.69, 256712.5, 256760.08, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
1940 )]
1941 #[case(
1944 |id: &u32| *id == 72 || *id == 87 || *id == 170,
1945 true,
1946 3,
1947 vec![72, 170, 87, 0, 0, 0, 0, 0, 0, 0],
1948 vec![256709.69, 256712.5, 256760.08, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
1949 )]
1950 fn test_search_with_vector_filter(
1951 #[case] vector_filter: fn(&u32) -> bool,
1952 #[case] is_flat_search: bool,
1953 #[case] expected_result_count: u32,
1954 #[case] expected_indices: Vec<u32>,
1955 #[case] expected_distances: Vec<f32>,
1956 ) {
1957 let check_distances = |got: &[f32], expected: &[f32]| -> bool {
1962 const ABS_TOLERANCE: f32 = 0.02;
1963 assert_eq!(got.len(), expected.len());
1964 for (i, (g, e)) in std::iter::zip(got.iter(), expected.iter()).enumerate() {
1965 if (g - e).abs() > ABS_TOLERANCE {
1966 panic!(
1967 "distances differ at position {} by more than {}\n\n\
1968 got: {:?}\nexpected: {:?}",
1969 i, ABS_TOLERANCE, got, expected,
1970 );
1971 }
1972 }
1973 true
1974 };
1975
1976 let storage_provider = Arc::new(VirtualStorageProvider::new_overlay(test_data_root()));
1977
1978 let search_engine = create_disk_index_searcher::<GraphDataF32VectorUnitData>(
1979 CreateDiskIndexSearcherParams {
1980 max_thread_num: 5,
1981 pq_pivot_file_path: TEST_PQ_PIVOT_128DIM,
1982 pq_compressed_file_path: TEST_PQ_COMPRESSED_128DIM,
1983 index_path: TEST_INDEX_128DIM,
1984 index_path_prefix: TEST_INDEX_PREFIX_128DIM,
1985 ..Default::default()
1986 },
1987 &storage_provider,
1988 );
1989 let query = vec![0.1f32; 128];
1990 let mut query_stats = QueryStatistics::default();
1991 let mut indices = vec![0u32; 10];
1992 let mut distances = vec![0f32; 10];
1993 let mut associated_data = vec![(); 10];
1994
1995 let result = search_engine.search_internal(
1996 &query,
1997 10,
1998 10,
1999 None, &mut query_stats,
2001 &mut indices,
2002 &mut distances,
2003 &mut associated_data,
2004 &vector_filter,
2005 is_flat_search,
2006 );
2007
2008 assert!(result.is_ok(), "Expected search to succeed");
2009 assert_eq!(
2010 result.unwrap().result_count,
2011 expected_result_count,
2012 "Expected result count to match"
2013 );
2014 assert_eq!(indices, expected_indices, "Expected indices to match");
2015 assert!(
2016 check_distances(&distances, &expected_distances),
2017 "Expected distances to match"
2018 );
2019
2020 let result_with_filter = search_engine.search(
2021 &query,
2022 10,
2023 10,
2024 None, Some(Box::new(vector_filter)),
2026 is_flat_search,
2027 );
2028
2029 assert!(result_with_filter.is_ok(), "Expected search to succeed");
2030 let result_with_filter_unwrapped = result_with_filter.unwrap();
2031 assert_eq!(
2032 result_with_filter_unwrapped.stats.result_count, expected_result_count,
2033 "Expected result count to match"
2034 );
2035 let actual_indices = result_with_filter_unwrapped
2036 .results
2037 .iter()
2038 .map(|x| x.vertex_id)
2039 .collect::<Vec<_>>();
2040 assert_eq!(
2041 actual_indices, expected_indices,
2042 "Expected indices to match"
2043 );
2044 let actual_distances = result_with_filter_unwrapped
2045 .results
2046 .iter()
2047 .map(|x| x.distance)
2048 .collect::<Vec<_>>();
2049 assert!(
2050 check_distances(&actual_distances, &expected_distances),
2051 "Expected distances to match"
2052 );
2053 }
2054
/// Runs a recorded search on a searcher built with an IO limit of 11 and checks that the
/// strategy's IO tracker stays within that limit, while at least 60% of a reference list of
/// node ids is still visited.
#[test]
fn test_beam_search_respects_io_limit() {
    let io_limit = 11;
    let storage_provider = Arc::new(VirtualStorageProvider::new_overlay(test_data_root()));

    let searcher_params = CreateDiskIndexSearcherParams {
        max_thread_num: 1,
        pq_pivot_file_path: TEST_PQ_PIVOT,
        pq_compressed_file_path: TEST_PQ_COMPRESSED,
        index_path: TEST_INDEX,
        index_path_prefix: TEST_INDEX_PREFIX,
        io_limit,
    };
    let search_engine = create_disk_index_searcher::<GraphDataF32VectorUnitData>(
        searcher_params,
        &storage_provider,
    );
    let query_vector: [f32; 128] = [1f32; 128];

    let mut indices = vec![0u32; 10];
    let mut distances = vec![0f32; 10];
    let mut associated_data = vec![(); 10];

    let mut result_output_buffer = search_output_buffer::IdDistanceAssociatedData::new(
        &mut indices,
        &mut distances,
        &mut associated_data,
    );

    let strategy = search_engine.search_strategy(&query_vector, &|_| true);

    let mut search_record = VisitedSearchRecord::new(0);
    let search_params = Knn::new(10, 10, Some(4)).unwrap();
    let recorded_search =
        diskann::graph::search::RecordedKnn::new(search_params, &mut search_record);
    let search_future = search_engine.index.search(
        recorded_search,
        &strategy,
        &DefaultContext,
        query_vector.as_slice(),
        &mut result_output_buffer,
    );
    search_engine.runtime.block_on(search_future).unwrap();

    let visited_ids: Vec<_> = search_record.visited.iter().map(|n| n.id).collect();

    // The IO count is read from the strategy that drove the search.
    let io_tracker = strategy.io_tracker;
    assert!(
        io_tracker.io_count() <= io_limit,
        "Expected IO operations to be <= {}, but got {}",
        io_limit,
        io_tracker.io_count()
    );

    const EXPECTED_NODES: [u32; 17] = [
        72, 118, 108, 86, 84, 152, 170, 82, 114, 87, 207, 176, 79, 153, 67, 165, 141,
    ];
    // Share of the reference nodes that were visited despite the IO cap.
    let matching_count = EXPECTED_NODES
        .iter()
        .filter(|&expected_node| visited_ids.contains(expected_node))
        .count();

    let recall = (matching_count as f32 / EXPECTED_NODES.len() as f32) * 100.0;

    assert!(recall >= 60.0, "Match percentage is below 60%: {}", recall);
}
2133}