1use std::{
7 collections::HashMap,
8 future::Future,
9 num::NonZeroUsize,
10 ops::Range,
11 sync::{
12 atomic::{AtomicU64, AtomicUsize},
13 Arc,
14 },
15 time::Instant,
16};
17
18use diskann::{
19 graph::{
20 self,
21 glue::{self, ExpandBeam, IdIterator, SearchExt, SearchPostProcess, SearchStrategy},
22 search::Knn,
23 search_output_buffer, AdjacencyList, DiskANNIndex, SearchOutputBuffer,
24 },
25 neighbor::Neighbor,
26 provider::{
27 Accessor, BuildQueryComputer, DataProvider, DefaultContext, DelegateNeighbor, HasId,
28 NeighborAccessor,
29 },
30 utils::{
31 object_pool::{ObjectPool, PoolOption, TryAsPooled},
32 IntoUsize, VectorRepr,
33 },
34 ANNError, ANNResult,
35};
36use diskann_providers::storage::StorageReadProvider;
37use diskann_providers::{
38 model::{
39 compute_pq_distance, compute_pq_distance_for_pq_coordinates, graph::traits::GraphDataType,
40 pq::quantizer_preprocess, PQData, PQScratch,
41 },
42 storage::{get_compressed_pq_file, get_disk_index_file, get_pq_pivot_file, LoadWith},
43};
44use diskann_vector::{distance::Metric, DistanceFunction, PreprocessedDistanceFunction};
45use futures_util::future;
46use tokio::runtime::Runtime;
47use tracing::debug;
48
49use crate::{
50 data_model::{CachingStrategy, GraphHeader},
51 filter_parameter::{default_vector_filter, VectorFilter},
52 search::{
53 provider::disk_vertex_provider_factory::DiskVertexProviderFactory,
54 traits::{VertexProvider, VertexProviderFactory},
55 },
56 storage::{api::AsyncDiskLoadContext, disk_index_reader::DiskIndexReader},
57 utils::AlignedFileReaderFactory,
58 utils::QueryStatistics,
59};
60
/// Data provider backed by an on-disk graph index whose PQ-compressed vectors are held in
/// memory (`pq_data`); full-precision vectors and adjacency lists are fetched through a
/// vertex provider at search time.
///
/// Internal and external ids are the same `u32` values (see the `DataProvider` impl).
pub struct DiskProvider<Data>
where
    Data: GraphDataType<VectorIdType = u32>,
{
    /// Header of the on-disk graph; supplies the dimension, the medoid (search start point)
    /// and the maximum degree used elsewhere in this module.
    graph_header: GraphHeader,

    /// Full-precision distance function, built from `metric` and the header's dimension;
    /// used when re-ranking candidates against the query.
    distance_comparer: <Data::VectorDataType as VectorRepr>::Distance,

    /// PQ data (pivots and compressed vectors) obtained from the `DiskIndexReader`; used for
    /// approximate distances during graph traversal.
    pq_data: Arc<PQData>,

    /// Number of points in the index; the exclusive upper bound of `id_iterator`.
    num_points: usize,

    /// Distance metric, also passed to the PQ query preprocessing.
    metric: Metric,

    /// Maximum number of vertex loads a single search may issue.
    search_io_limit: usize,
}
92
93impl<Data> DataProvider for DiskProvider<Data>
94where
95 Data: GraphDataType<VectorIdType = u32>,
96{
97 type Context = DefaultContext;
98
99 type InternalId = u32;
100
101 type ExternalId = u32;
102
103 type Error = ANNError;
104
105 fn to_internal_id(
107 &self,
108 _context: &DefaultContext,
109 gid: &Self::ExternalId,
110 ) -> Result<Self::InternalId, Self::Error> {
111 Ok(*gid)
112 }
113
114 fn to_external_id(
116 &self,
117 _context: &DefaultContext,
118 id: Self::InternalId,
119 ) -> Result<Self::ExternalId, Self::Error> {
120 Ok(id)
121 }
122}
123
124impl<Data> LoadWith<AsyncDiskLoadContext> for DiskProvider<Data>
125where
126 Data: GraphDataType<VectorIdType = u32>,
127{
128 type Error = ANNError;
129
130 async fn load_with<P>(provider: &P, ctx: &AsyncDiskLoadContext) -> ANNResult<Self>
131 where
132 P: StorageReadProvider,
133 {
134 debug!(
135 "DiskProvider::load_with() called with file: {:?}",
136 get_disk_index_file(ctx.quant_load_context.metadata.prefix())
137 );
138
139 let graph_header = {
140 let aligned_reader_factory = AlignedFileReaderFactory::new(get_disk_index_file(
141 ctx.quant_load_context.metadata.prefix(),
142 ));
143
144 let caching_strategy = if ctx.num_nodes_to_cache > 0 {
145 CachingStrategy::StaticCacheWithBfsNodes(ctx.num_nodes_to_cache)
146 } else {
147 CachingStrategy::None
148 };
149
150 let vertex_provider_factory = DiskVertexProviderFactory::<Data, _>::new(
151 aligned_reader_factory,
152 caching_strategy,
153 )?;
154 VertexProviderFactory::get_header(&vertex_provider_factory)?
155 };
156
157 let metric = ctx.quant_load_context.metric;
158 let num_points = ctx.num_points;
159
160 let index_path_prefix = ctx.quant_load_context.metadata.prefix();
161 let index_reader = DiskIndexReader::<<Data as GraphDataType>::VectorDataType>::new(
162 get_pq_pivot_file(index_path_prefix),
163 get_compressed_pq_file(index_path_prefix),
164 provider,
165 )?;
166
167 Self::new(
168 &index_reader,
169 graph_header,
170 metric,
171 num_points,
172 ctx.search_io_limit,
173 )
174 }
175}
176
177impl<Data> DiskProvider<Data>
178where
179 Data: GraphDataType<VectorIdType = u32>,
180{
181 fn new(
182 disk_index_reader: &DiskIndexReader<Data::VectorDataType>,
183 graph_header: GraphHeader,
184 metric: Metric,
185 num_points: usize,
186 search_io_limit: usize,
187 ) -> ANNResult<Self> {
188 let distance_comparer =
189 Data::VectorDataType::distance(metric, Some(graph_header.metadata().dims));
190
191 let pq_data = disk_index_reader.get_pq_data();
192
193 Ok(Self {
194 graph_header,
195 distance_comparer,
196 pq_data,
197 num_points,
198 metric,
199 search_io_limit,
200 })
201 }
202}
203
/// Search strategy for one query against a `DiskProvider`.
///
/// Holds the per-query state shared by the accessors it creates: the IO tracker, the query,
/// and the result filter. It also holds references to the vertex-provider factory and scratch
/// pool. In this module it is built by `DiskIndexSearcher::search_strategy`.
pub struct DiskSearchStrategy<'a, Data, ProviderFactory>
where
    Data: GraphDataType<VectorIdType = u32>,
    ProviderFactory: VertexProviderFactory<Data>,
{
    /// IO-count and timing counters for this query.
    io_tracker: IOTracker,

    /// Predicate applied to candidate ids by the `RerankAndFilter` post-processor.
    vector_filter: &'a (dyn Fn(&u32) -> bool + Send + Sync),
    /// The query vector.
    query: &'a [Data::VectorDataType],

    /// Factory used when the scratch pool needs to create a new vertex provider.
    vertex_provider_factory: &'a ProviderFactory,

    /// Pool of reusable per-search scratch space.
    scratch_pool: &'a Arc<ObjectPool<DiskSearchScratch<Data, ProviderFactory::VertexProviderType>>>,
}
230
/// Per-query counters for vertex-load IO and query preprocessing.
///
/// All counters start at zero. The derived `Default` is the same as the previous hand-written
/// impl, because the atomics' `Default` is zero. Updates and reads use relaxed ordering: the
/// counters are statistics and do not synchronize any other memory.
#[derive(Default)]
struct IOTracker {
    /// Accumulated time spent loading vertices, in microseconds.
    io_time_us: AtomicU64,
    /// Accumulated time spent in PQ query preprocessing, in microseconds.
    preprocess_time_us: AtomicU64,
    /// Number of vertices loaded so far.
    io_count: AtomicUsize,
}

impl IOTracker {
    /// Adds `time` (microseconds) to the given time counter.
    fn add_time(category: &AtomicU64, time: u64) {
        category.fetch_add(time, std::sync::atomic::Ordering::Relaxed);
    }

    /// Returns the current value of the given time counter, in microseconds.
    fn time(category: &AtomicU64) -> u64 {
        category.load(std::sync::atomic::Ordering::Relaxed)
    }

    /// Adds `count` vertex loads to the IO counter.
    fn add_io_count(&self, count: usize) {
        self.io_count
            .fetch_add(count, std::sync::atomic::Ordering::Relaxed);
    }

    /// Returns the number of vertex loads recorded so far.
    fn io_count(&self) -> usize {
        self.io_count.load(std::sync::atomic::Ordering::Relaxed)
    }
}
267
/// Post-processor that re-ranks search candidates with full-precision distances and keeps
/// only the ids accepted by `filter`.
#[derive(Copy, Clone)]
pub struct RerankAndFilter<'f> {
    /// Predicate deciding whether a candidate id may appear in the output.
    filter: &'f (dyn Fn(&u32) -> bool + Send + Sync),
}

impl<'f> RerankAndFilter<'f> {
    /// Wraps `filter` in a post-processor.
    fn new(filter: &'f (dyn Fn(&u32) -> bool + Send + Sync)) -> Self {
        RerankAndFilter { filter }
    }
}
278
279impl<Data, VP>
280 SearchPostProcess<
281 DiskAccessor<'_, Data, VP>,
282 [Data::VectorDataType],
283 (
284 <DiskProvider<Data> as DataProvider>::InternalId,
285 Data::AssociatedDataType,
286 ),
287 > for RerankAndFilter<'_>
288where
289 Data: GraphDataType<VectorIdType = u32>,
290 VP: VertexProvider<Data>,
291{
292 type Error = ANNError;
293 async fn post_process<I, B>(
294 &self,
295 accessor: &mut DiskAccessor<'_, Data, VP>,
296 query: &[Data::VectorDataType],
297 _computer: &DiskQueryComputer,
298 candidates: I,
299 output: &mut B,
300 ) -> Result<usize, Self::Error>
301 where
302 I: Iterator<Item = Neighbor<u32>> + Send,
303 B: SearchOutputBuffer<(u32, Data::AssociatedDataType)> + Send + ?Sized,
304 {
305 let provider = accessor.provider;
306
307 let mut uncached_ids = Vec::new();
308 let mut reranked = candidates
309 .map(|n| n.id)
310 .filter(|id| (self.filter)(id))
311 .filter_map(|n| {
312 if let Some(entry) = accessor.scratch.distance_cache.get(&n) {
313 Some(Ok::<((u32, _), f32), ANNError>(((n, entry.1), entry.0)))
314 } else {
315 uncached_ids.push(n);
316 None
317 }
318 })
319 .collect::<Result<Vec<_>, _>>()?;
320 if !uncached_ids.is_empty() {
321 ensure_vertex_loaded(&mut accessor.scratch.vertex_provider, &uncached_ids)?;
322 for n in &uncached_ids {
323 let v = accessor.scratch.vertex_provider.get_vector(n)?;
324 let d = provider.distance_comparer.evaluate_similarity(query, v);
325 let a = accessor.scratch.vertex_provider.get_associated_data(n)?;
326 reranked.push(((*n, *a), d));
327 }
328 }
329
330 reranked
332 .sort_unstable_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
333 Ok(output.extend(reranked))
335 }
336}
337
338impl<'this, Data, ProviderFactory>
339 SearchStrategy<
340 DiskProvider<Data>,
341 [Data::VectorDataType],
342 (
343 <DiskProvider<Data> as DataProvider>::InternalId,
344 Data::AssociatedDataType,
345 ),
346 > for DiskSearchStrategy<'this, Data, ProviderFactory>
347where
348 Data: GraphDataType<VectorIdType = u32>,
349 ProviderFactory: VertexProviderFactory<Data>,
350{
351 type QueryComputer = DiskQueryComputer;
352 type SearchAccessor<'a> = DiskAccessor<'a, Data, ProviderFactory::VertexProviderType>;
353 type SearchAccessorError = ANNError;
354 type PostProcessor = RerankAndFilter<'this>;
355
356 fn search_accessor<'a>(
357 &'a self,
358 provider: &'a DiskProvider<Data>,
359 _context: &DefaultContext,
360 ) -> Result<Self::SearchAccessor<'a>, Self::SearchAccessorError> {
361 DiskAccessor::new(
362 provider,
363 &self.io_tracker,
364 self.query,
365 self.vertex_provider_factory,
366 self.scratch_pool,
367 )
368 }
369
370 fn post_processor(&self) -> Self::PostProcessor {
371 RerankAndFilter::new(self.vector_filter)
372 }
373}
374
/// Query computer for PQ distances.
///
/// Holds a snapshot of the query's PQ distance table, copied from the accessor's
/// `aligned_pqtable_dist_scratch` in `build_query_computer`.
pub struct DiskQueryComputer {
    /// Number of PQ chunks per compressed vector.
    num_pq_chunks: usize,
    /// The query's PQ distance table, passed to `compute_pq_distance_for_pq_coordinates`.
    query_centroid_l2_distance: Vec<f32>,
}
380
381impl PreprocessedDistanceFunction<&[u8], f32> for DiskQueryComputer {
382 fn evaluate_similarity(&self, changing: &[u8]) -> f32 {
383 let mut dist = 0.0f32;
384 #[allow(clippy::expect_used)]
385 compute_pq_distance_for_pq_coordinates(
386 changing,
387 self.num_pq_chunks,
388 &self.query_centroid_l2_distance,
389 std::slice::from_mut(&mut dist),
390 )
391 .expect("PQ distance compute for PQ coordinates is expected to succeed");
392 dist
393 }
394}
395
396impl<Data, VP> BuildQueryComputer<[Data::VectorDataType]> for DiskAccessor<'_, Data, VP>
397where
398 Data: GraphDataType<VectorIdType = u32>,
399 VP: VertexProvider<Data>,
400{
401 type QueryComputerError = ANNError;
402 type QueryComputer = DiskQueryComputer;
403
404 fn build_query_computer(
405 &self,
406 _from: &[Data::VectorDataType],
407 ) -> Result<Self::QueryComputer, Self::QueryComputerError> {
408 Ok(DiskQueryComputer {
409 num_pq_chunks: self.provider.pq_data.get_num_chunks(),
410 query_centroid_l2_distance: self
411 .scratch
412 .pq_scratch
413 .aligned_pqtable_dist_scratch
414 .as_slice()
415 .to_vec(),
416 })
417 }
418
419 async fn distances_unordered<Itr, F>(
420 &mut self,
421 vec_id_itr: Itr,
422 _computer: &Self::QueryComputer,
423 f: F,
424 ) -> Result<(), Self::GetError>
425 where
426 F: Send + FnMut(f32, Self::Id),
427 Itr: Iterator<Item = Self::Id>,
428 {
429 self.pq_distances(&vec_id_itr.collect::<Box<[_]>>(), f)
430 }
431}
432
433impl<Data, VP> ExpandBeam<[Data::VectorDataType]> for DiskAccessor<'_, Data, VP>
434where
435 Data: GraphDataType<VectorIdType = u32>,
436 VP: VertexProvider<Data>,
437{
438 fn expand_beam<Itr, P, F>(
439 &mut self,
440 ids: Itr,
441 _computer: &Self::QueryComputer,
442 mut pred: P,
443 mut f: F,
444 ) -> impl std::future::Future<Output = Result<(), Self::GetError>> + Send
445 where
446 Itr: Iterator<Item = Self::Id> + Send,
447 P: glue::HybridPredicate<Self::Id> + Send + Sync,
448 F: FnMut(f32, Self::Id) + Send,
449 {
450 let result = (|| {
451 let io_limit = self.provider.search_io_limit - self.io_tracker.io_count();
452 let load_ids: Box<[_]> = ids.take(io_limit).collect();
453
454 self.ensure_loaded(&load_ids)?;
455 let mut ids = Vec::new();
456 for i in load_ids {
457 ids.clear();
458 ids.extend(
459 self.scratch
460 .vertex_provider
461 .get_adjacency_list(&i)?
462 .iter()
463 .copied()
464 .filter(|id| pred.eval_mut(id)),
465 );
466
467 self.pq_distances(&ids, &mut f)?;
468 }
469
470 Ok(())
471 })();
472
473 std::future::ready(result)
474 }
475}
476
/// Reusable per-search scratch space, handed out by the searcher's `ObjectPool`.
struct DiskSearchScratch<Data, VP>
where
    Data: GraphDataType<VectorIdType = u32>,
    VP: VertexProvider<Data>,
{
    /// Full-precision distance and associated data for vertices loaded by
    /// `DiskAccessor::ensure_loaded`, keyed by vertex id. `RerankAndFilter` reuses these
    /// entries instead of recomputing them.
    distance_cache: HashMap<u32, (f32, Data::AssociatedDataType)>,
    /// PQ buffers: the query distance table plus coordinate and distance scratch.
    pq_scratch: PQScratch,
    /// Vertex provider used to load vectors, adjacency lists and associated data.
    vertex_provider: VP,
}
488
/// Arguments used by the object pool to create a `DiskSearchScratch`.
#[derive(Clone)]
struct DiskSearchScratchArgs<'a, ProviderFactory> {
    /// Maximum graph degree; sizes the PQ scratch buffers.
    graph_degree: usize,
    /// Vector dimension.
    dim: usize,
    /// Number of PQ chunks per compressed vector.
    num_pq_chunks: usize,
    /// Number of PQ centers.
    num_pq_centers: usize,
    /// Factory used to create the scratch's vertex provider.
    vertex_factory: &'a ProviderFactory,
    /// Graph header passed to the vertex-provider factory.
    graph_header: &'a GraphHeader,
}
498
499impl<Data, ProviderFactory> TryAsPooled<&DiskSearchScratchArgs<'_, ProviderFactory>>
500 for DiskSearchScratch<Data, ProviderFactory::VertexProviderType>
501where
502 Data: GraphDataType<VectorIdType = u32>,
503 ProviderFactory: VertexProviderFactory<Data>,
504{
505 type Error = ANNError;
506
507 fn try_create(args: &DiskSearchScratchArgs<ProviderFactory>) -> Result<Self, Self::Error> {
508 let pq_scratch = PQScratch::new(
509 args.graph_degree,
510 args.dim,
511 args.num_pq_chunks,
512 args.num_pq_centers,
513 )?;
514
515 const DEFAULT_BEAM_WIDTH: usize = 0; let vertex_provider = args
517 .vertex_factory
518 .create_vertex_provider(DEFAULT_BEAM_WIDTH, args.graph_header)?;
519
520 Ok(Self {
521 distance_cache: HashMap::new(),
522 pq_scratch,
523 vertex_provider,
524 })
525 }
526
527 fn try_modify(
528 &mut self,
529 _args: &DiskSearchScratchArgs<ProviderFactory>,
530 ) -> Result<(), Self::Error> {
531 self.distance_cache.clear();
532 self.vertex_provider.clear();
533 Ok(())
534 }
535}
536
/// Per-search accessor over a `DiskProvider`.
///
/// Borrows the provider, the query's IO tracker and the query itself. It also holds a pooled
/// scratch, which is taken from the pool in `DiskAccessor::new`.
pub struct DiskAccessor<'a, Data, VP>
where
    Data: GraphDataType<VectorIdType = u32>,
    VP: VertexProvider<Data>,
{
    /// The provider being searched.
    provider: &'a DiskProvider<Data>,
    /// IO and timing counters shared with the owning strategy.
    io_tracker: &'a IOTracker,
    /// Pooled scratch (distance cache, PQ buffers, vertex provider).
    scratch: PoolOption<DiskSearchScratch<Data, VP>>,
    /// The query vector.
    query: &'a [Data::VectorDataType],
}
547
548impl<Data, VP> DiskAccessor<'_, Data, VP>
549where
550 Data: GraphDataType<VectorIdType = u32>,
551 VP: VertexProvider<Data>,
552{
553 fn pq_distances<F>(&mut self, ids: &[u32], mut f: F) -> ANNResult<()>
556 where
557 F: FnMut(f32, u32),
558 {
559 let pq_scratch = &mut self.scratch.pq_scratch;
560 compute_pq_distance(
561 ids,
562 self.provider.pq_data.get_num_chunks(),
563 &pq_scratch.aligned_pqtable_dist_scratch,
564 self.provider.pq_data.pq_compressed_data().get_data(),
565 &mut pq_scratch.aligned_pq_coord_scratch,
566 &mut pq_scratch.aligned_dist_scratch,
567 )?;
568
569 for (i, id) in ids.iter().enumerate() {
570 let distance = self.scratch.pq_scratch.aligned_dist_scratch[i];
571 f(distance, *id);
572 }
573
574 Ok(())
575 }
576}
577
578impl<Data, VP> SearchExt for DiskAccessor<'_, Data, VP>
579where
580 Data: GraphDataType<VectorIdType = u32>,
581 VP: VertexProvider<Data>,
582{
583 async fn starting_points(&self) -> ANNResult<Vec<u32>> {
584 let start_vertex_id = self.provider.graph_header.metadata().medoid as u32;
585 Ok(vec![start_vertex_id])
586 }
587
588 fn terminate_early(&mut self) -> bool {
589 self.io_tracker.io_count() > self.provider.search_io_limit
590 }
591}
592
593impl<'a, Data, VP> DiskAccessor<'a, Data, VP>
594where
595 Data: GraphDataType<VectorIdType = u32>,
596 VP: VertexProvider<Data>,
597{
598 fn new<VPF>(
599 provider: &'a DiskProvider<Data>,
600 io_tracker: &'a IOTracker,
601 query: &'a [Data::VectorDataType],
602 vertex_provider_factory: &'a VPF,
603 scratch_pool: &'a Arc<ObjectPool<DiskSearchScratch<Data, VP>>>,
604 ) -> ANNResult<Self>
605 where
606 VPF: VertexProviderFactory<Data, VertexProviderType = VP>,
607 {
608 let mut scratch = PoolOption::try_pooled(
609 scratch_pool,
610 &DiskSearchScratchArgs {
611 graph_degree: provider.graph_header.max_degree::<Data::VectorDataType>()?,
612 dim: provider.graph_header.metadata().dims,
613 num_pq_chunks: provider.pq_data.get_num_chunks(),
614 num_pq_centers: provider.pq_data.get_num_centers(),
615 vertex_factory: vertex_provider_factory,
616 graph_header: &provider.graph_header,
617 },
618 )?;
619
620 scratch.pq_scratch.set(
621 provider.graph_header.metadata().dims,
622 query,
623 1.0_f32, )?;
625 let start_vertex_id = provider.graph_header.metadata().medoid as u32;
626
627 let timer = Instant::now();
628 quantizer_preprocess(
629 &mut scratch.pq_scratch,
630 &provider.pq_data,
631 provider.metric,
632 &[start_vertex_id],
633 )?;
634 IOTracker::add_time(
635 &io_tracker.preprocess_time_us,
636 timer.elapsed().as_micros() as u64,
637 );
638
639 Ok(Self {
640 provider,
641 io_tracker,
642 scratch,
643 query,
644 })
645 }
646 fn ensure_loaded(&mut self, ids: &[u32]) -> Result<(), ANNError> {
647 if ids.is_empty() {
648 return Ok(());
649 }
650 let scratch = &mut self.scratch;
651 let timer = Instant::now();
652 ensure_vertex_loaded(&mut scratch.vertex_provider, ids)?;
653 IOTracker::add_time(
654 &self.io_tracker.io_time_us,
655 timer.elapsed().as_micros() as u64,
656 );
657 self.io_tracker.add_io_count(ids.len());
658 for id in ids {
659 let distance = self
660 .provider
661 .distance_comparer
662 .evaluate_similarity(self.query, scratch.vertex_provider.get_vector(id)?);
663 let associated_data = *scratch.vertex_provider.get_associated_data(id)?;
664 scratch
665 .distance_cache
666 .insert(*id, (distance, associated_data));
667 }
668 Ok(())
669 }
670}
671
672impl<Data, VP> HasId for DiskAccessor<'_, Data, VP>
673where
674 Data: GraphDataType<VectorIdType = u32>,
675 VP: VertexProvider<Data>,
676{
677 type Id = u32;
678}
679
680impl<'a, Data, VP> Accessor for DiskAccessor<'a, Data, VP>
681where
682 Data: GraphDataType<VectorIdType = u32>,
683 VP: VertexProvider<Data>,
684{
685 type Extended = &'a [u8];
687
688 type Element<'b>
694 = &'a [u8]
695 where
696 Self: 'b;
697
698 type ElementRef<'b> = &'b [u8];
700
701 type GetError = ANNError;
703
704 fn get_element(
705 &mut self,
706 id: Self::Id,
707 ) -> impl Future<Output = Result<Self::Element<'_>, Self::GetError>> + Send {
708 std::future::ready(self.provider.pq_data.get_compressed_vector(id as usize))
709 }
710}
711
712impl<Data, VP> IdIterator<Range<u32>> for DiskAccessor<'_, Data, VP>
713where
714 Data: GraphDataType<VectorIdType = u32>,
715 VP: VertexProvider<Data>,
716{
717 async fn id_iterator(&mut self) -> Result<Range<u32>, ANNError> {
718 Ok(0..self.provider.num_points as u32)
719 }
720}
721
722impl<'a, 'b, Data, VP> DelegateNeighbor<'a> for DiskAccessor<'b, Data, VP>
723where
724 Data: GraphDataType<VectorIdType = u32>,
725 VP: VertexProvider<Data>,
726{
727 type Delegate = AsNeighborAccessor<'a, 'b, Data, VP>;
728 fn delegate_neighbor(&'a mut self) -> Self::Delegate {
729 AsNeighborAccessor(self)
730 }
731}
732
/// Adapter that exposes a borrowed `DiskAccessor` as a `NeighborAccessor`.
pub struct AsNeighborAccessor<'a, 'b, Data, VP>(&'a mut DiskAccessor<'b, Data, VP>)
where
    Data: GraphDataType<VectorIdType = u32>,
    VP: VertexProvider<Data>;
742
743impl<Data, VP> HasId for AsNeighborAccessor<'_, '_, Data, VP>
744where
745 Data: GraphDataType<VectorIdType = u32>,
746 VP: VertexProvider<Data>,
747{
748 type Id = u32;
749}
750
751impl<Data, VP> NeighborAccessor for AsNeighborAccessor<'_, '_, Data, VP>
752where
753 Data: GraphDataType<VectorIdType = u32>,
754 VP: VertexProvider<Data>,
755{
756 fn get_neighbors(
757 self,
758 id: Self::Id,
759 neighbors: &mut AdjacencyList<Self::Id>,
760 ) -> impl Future<Output = ANNResult<Self>> + Send {
761 if self.0.io_tracker.io_count() > self.0.provider.search_io_limit {
762 return future::ok(self); }
764
765 if let Err(e) = ensure_vertex_loaded(&mut self.0.scratch.vertex_provider, &[id]) {
766 return future::err(e);
767 }
768 let list = match self.0.scratch.vertex_provider.get_adjacency_list(&id) {
769 Ok(list) => list,
770 Err(e) => return future::err(e),
771 };
772 neighbors.overwrite_trusted(list);
773 future::ok(self)
774 }
775}
776
/// Synchronous front end for searching an on-disk index.
///
/// Owns the following:
/// - the `DiskANNIndex` over a `DiskProvider`;
/// - the Tokio runtime used to `block_on` its async search API;
/// - the vertex-provider factory;
/// - a pool of per-search scratch space.
pub struct DiskIndexSearcher<
    Data,
    ProviderFactory = DiskVertexProviderFactory<Data, AlignedFileReaderFactory>,
> where
    Data: GraphDataType<VectorIdType = u32>,
    ProviderFactory: VertexProviderFactory<Data>,
{
    /// The graph index; searches go through `index.search` / `index.flat_search`.
    index: DiskANNIndex<DiskProvider<Data>>,
    /// Runtime on which each search is driven with `block_on`.
    runtime: Runtime,

    /// Factory used by the scratch pool to create vertex providers.
    vertex_provider_factory: ProviderFactory,

    /// Pool of reusable search scratch space shared by all searches on this searcher.
    scratch_pool: Arc<ObjectPool<DiskSearchScratch<Data, ProviderFactory::VertexProviderType>>>,
}
796
/// Summary statistics for one search.
#[derive(Debug)]
pub struct SearchResultStats {
    /// Number of distance comparisons reported by the index.
    pub cmps: u32,
    /// Number of results the index reported writing to the output buffer.
    pub result_count: u32,
    /// Detailed per-query statistics (timings, IO counts, hops).
    pub query_statistics: QueryStatistics,
}
803
/// Results of one search together with its statistics.
pub struct SearchResult<AssociatedData> {
    /// Result items. `DiskIndexSearcher::search` always returns `return_list_size` entries;
    /// see `stats.result_count` for how many the index reported.
    pub results: Vec<SearchResultItem<AssociatedData>>,
    /// Statistics for the search.
    pub stats: SearchResultStats,
}
813
/// One search result.
pub struct SearchResultItem<AssociatedData> {
    /// Id of the matched vertex.
    pub vertex_id: u32,
    /// Associated data stored with the vertex.
    pub data: AssociatedData,
    /// Full-precision distance between the query and the vertex.
    pub distance: f32,
}
827
828impl<Data, ProviderFactory> DiskIndexSearcher<Data, ProviderFactory>
829where
830 Data: GraphDataType<VectorIdType = u32>,
831 ProviderFactory: VertexProviderFactory<Data>,
832{
833 pub fn new(
843 num_threads: usize,
844 search_io_limit: usize,
845 disk_index_reader: &DiskIndexReader<Data::VectorDataType>,
846 vertex_provider_factory: ProviderFactory,
847 metric: Metric,
848 runtime: Option<Runtime>,
849 ) -> ANNResult<Self> {
850 let runtime = match runtime {
851 Some(rt) => rt,
852 None => tokio::runtime::Builder::new_current_thread().build()?,
853 };
854
855 let graph_header = vertex_provider_factory.get_header()?;
856 let metadata = graph_header.metadata();
857 let max_degree = graph_header.max_degree::<Data::VectorDataType>()? as u32;
858
859 let config = graph::config::Builder::new(
860 max_degree.into_usize(),
861 graph::config::MaxDegree::default_slack(),
862 1, metric.into(),
864 )
865 .build()?;
866
867 debug!("Creating DiskIndexSearcher with index_config: {:?}", config);
868
869 let graph_header = vertex_provider_factory.get_header()?;
870 let pq_data = disk_index_reader.get_pq_data();
871 let scratch_pool_args = DiskSearchScratchArgs {
872 graph_degree: graph_header.max_degree::<Data::VectorDataType>()?,
873 dim: graph_header.metadata().dims,
874 num_pq_chunks: pq_data.get_num_chunks(),
875 num_pq_centers: pq_data.get_num_centers(),
876 vertex_factory: &vertex_provider_factory,
877 graph_header: &graph_header,
878 };
879 let scratch_pool = Arc::new(ObjectPool::try_new(&scratch_pool_args, 0, None)?);
880
881 let disk_provider = DiskProvider::new(
882 disk_index_reader,
883 graph_header,
884 metric,
885 metadata.num_pts.into_usize(),
886 search_io_limit,
887 )?;
888
889 let index = DiskANNIndex::new(config, disk_provider, NonZeroUsize::new(num_threads));
890 Ok(Self {
891 index,
892 runtime,
893 vertex_provider_factory,
894 scratch_pool,
895 })
896 }
897
898 fn search_strategy<'a>(
900 &'a self,
901 query: &'a [Data::VectorDataType],
902 vector_filter: &'a (dyn Fn(&Data::VectorIdType) -> bool + Send + Sync),
903 ) -> DiskSearchStrategy<'a, Data, ProviderFactory> {
904 DiskSearchStrategy {
905 io_tracker: IOTracker::default(),
906 vector_filter,
907 query,
908 vertex_provider_factory: &self.vertex_provider_factory,
909 scratch_pool: &self.scratch_pool,
910 }
911 }
912
913 pub fn search(
916 &self,
917 query: &[Data::VectorDataType],
918 return_list_size: u32,
919 search_list_size: u32,
920 beam_width: Option<usize>,
921 vector_filter: Option<VectorFilter<Data>>,
922 is_flat_search: bool,
923 ) -> ANNResult<SearchResult<Data::AssociatedDataType>> {
924 let mut query_stats = QueryStatistics::default();
925 let mut indices = vec![0u32; return_list_size as usize];
926 let mut distances = vec![0f32; return_list_size as usize];
927 let mut associated_data =
928 vec![Data::AssociatedDataType::default(); return_list_size as usize];
929
930 let stats = self.search_internal(
931 query,
932 return_list_size as usize,
933 search_list_size,
934 beam_width,
935 &mut query_stats,
936 &mut indices,
937 &mut distances,
938 &mut associated_data,
939 &vector_filter.unwrap_or(default_vector_filter::<Data>()),
940 is_flat_search,
941 )?;
942
943 let mut search_result = SearchResult {
944 results: Vec::with_capacity(return_list_size as usize),
945 stats,
946 };
947
948 for ((vertex_id, distance), associated_data) in indices
949 .into_iter()
950 .zip(distances.into_iter())
951 .zip(associated_data.into_iter())
952 {
953 search_result.results.push(SearchResultItem {
954 vertex_id,
955 distance,
956 data: associated_data,
957 });
958 }
959
960 Ok(search_result)
961 }
962
963 #[allow(clippy::too_many_arguments)]
966 pub(crate) fn search_internal(
967 &self,
968 query: &[Data::VectorDataType],
969 k_value: usize,
970 search_list_size: u32,
971 beam_width: Option<usize>,
972 query_stats: &mut QueryStatistics,
973 indices: &mut [u32],
974 distances: &mut [f32],
975 associated_data: &mut [Data::AssociatedDataType],
976 vector_filter: &(dyn Fn(&Data::VectorIdType) -> bool + Send + Sync),
977 is_flat_search: bool,
978 ) -> ANNResult<SearchResultStats> {
979 let mut result_output_buffer = search_output_buffer::IdDistanceAssociatedData::new(
980 &mut indices[..k_value],
981 &mut distances[..k_value],
982 &mut associated_data[..k_value],
983 );
984
985 let strategy = self.search_strategy(query, vector_filter);
986 let timer = Instant::now();
987 let k = k_value;
988 let l = search_list_size as usize;
989 let stats = if is_flat_search {
990 self.runtime.block_on(self.index.flat_search(
991 &strategy,
992 &DefaultContext,
993 strategy.query,
994 vector_filter,
995 &Knn::new(k, l, beam_width)?,
996 &mut result_output_buffer,
997 ))?
998 } else {
999 let knn_search = Knn::new(k, l, beam_width)?;
1000 self.runtime.block_on(self.index.search(
1001 knn_search,
1002 &strategy,
1003 &DefaultContext,
1004 strategy.query,
1005 &mut result_output_buffer,
1006 ))?
1007 };
1008 query_stats.total_comparisons = stats.cmps;
1009 query_stats.search_hops = stats.hops;
1010
1011 query_stats.total_execution_time_us = timer.elapsed().as_micros();
1012 query_stats.io_time_us = IOTracker::time(&strategy.io_tracker.io_time_us) as u128;
1013 query_stats.total_io_operations = strategy.io_tracker.io_count() as u32;
1014 query_stats.total_vertices_loaded = strategy.io_tracker.io_count() as u32;
1015 query_stats.query_pq_preprocess_time_us =
1016 IOTracker::time(&strategy.io_tracker.preprocess_time_us) as u128;
1017 query_stats.cpu_time_us = query_stats.total_execution_time_us
1018 - query_stats.io_time_us
1019 - query_stats.query_pq_preprocess_time_us;
1020 Ok(SearchResultStats {
1021 cmps: query_stats.total_comparisons,
1022 result_count: stats.result_count,
1023 query_statistics: query_stats.clone(),
1024 })
1025 }
1026}
1027
1028fn ensure_vertex_loaded<Data: GraphDataType, V: VertexProvider<Data>>(
1034 vertex_provider: &mut V,
1035 ids: &[Data::VectorIdType],
1036) -> ANNResult<()> {
1037 vertex_provider.load_vertices(ids)?;
1038 for (idx, id) in ids.iter().enumerate() {
1039 vertex_provider.process_loaded_node(id, idx)?;
1040 }
1041 Ok(())
1042}
1043
1044#[cfg(test)]
1045mod disk_provider_tests {
1046 use diskann::{
1047 graph::{
1048 search::{record::VisitedSearchRecord, Knn},
1049 KnnSearchError,
1050 },
1051 utils::IntoUsize,
1052 ANNErrorKind,
1053 };
1054 use diskann_providers::storage::{
1055 DynWriteProvider, StorageReadProvider, VirtualStorageProvider,
1056 };
1057 use diskann_providers::{
1058 common::AlignedBoxWithSlice,
1059 test_utils::graph_data_type_utils::{
1060 GraphDataF32VectorU32Data, GraphDataF32VectorUnitData,
1061 },
1062 utils::{create_thread_pool, load_aligned_bin, PQPathNames, ParallelIteratorInPool},
1063 };
1064 use diskann_utils::{io::read_bin, test_data_root};
1065 use diskann_vector::distance::Metric;
1066 use rayon::prelude::{IndexedParallelIterator, IntoParallelRefIterator};
1067 use rstest::rstest;
1068 use vfs::OverlayFS;
1069
1070 use super::*;
1071 use crate::{
1072 build::builder::core::disk_index_builder_tests::{IndexBuildFixture, TestParams},
1073 utils::{QueryStatistics, VirtualAlignedReaderFactory},
1074 };
1075
1076 const TEST_INDEX_PREFIX_128DIM: &str =
1077 "/disk_index_search/disk_index_sift_learn_R4_L50_A1.2_truth_search";
1078 const TEST_INDEX_128DIM: &str =
1079 "/disk_index_search/disk_index_sift_learn_R4_L50_A1.2_truth_search_disk.index";
1080 const TEST_PQ_PIVOT_128DIM: &str =
1081 "/disk_index_search/disk_index_sift_learn_R4_L50_A1.2_truth_search_pq_pivots.bin";
1082 const TEST_PQ_COMPRESSED_128DIM: &str =
1083 "/disk_index_search/disk_index_sift_learn_R4_L50_A1.2_truth_search_pq_compressed.bin";
1084 const TEST_TRUTH_RESULT_10PTS_128DIM: &str =
1085 "/disk_index_search/disk_index_10pts_idx_uint32_truth_search_res.bin";
1086 const TEST_QUERY_10PTS_128DIM: &str = "/disk_index_search/disk_index_sample_query_10pts.fbin";
1087
1088 const TEST_INDEX_PREFIX_100DIM: &str = "/disk_index_search/256pts_100dim_f32_truth_Index";
1089 const TEST_INDEX_100DIM: &str = "/disk_index_search/256pts_100dim_f32_truth_Index_disk.index";
1090 const TEST_PQ_PIVOT_100DIM: &str =
1091 "/disk_index_search/256pts_100dim_f32_truth_Index_pq_pivots.bin";
1092 const TEST_PQ_COMPRESSED_100DIM: &str =
1093 "/disk_index_search/256pts_100dim_f32_truth_Index_pq_compressed.bin";
1094 const TEST_TRUTH_RESULT_10PTS_100DIM: &str =
1095 "/disk_index_search/256pts_100dim_f32_truth_query_result.bin";
1096 const TEST_QUERY_10PTS_100DIM: &str = "/disk_index_search/10pts_100dim_f32_base_query.bin";
1097 const TEST_DATA_FILE: &str = "/disk_index_search/disk_index_siftsmall_learn_256pts_data.fbin";
1098 const TEST_INDEX: &str =
1099 "/disk_index_search/disk_index_sift_learn_R4_L50_A1.2_truth_search_disk.index";
1100 const TEST_INDEX_PREFIX: &str =
1101 "/disk_index_search/disk_index_sift_learn_R4_L50_A1.2_truth_search";
1102 const TEST_PQ_PIVOT: &str =
1103 "/disk_index_search/disk_index_sift_learn_R4_L50_A1.2_truth_search_pq_pivots.bin";
1104 const TEST_PQ_COMPRESSED: &str =
1105 "/disk_index_search/disk_index_sift_learn_R4_L50_A1.2_truth_search_pq_compressed.bin";
1106
1107 #[test]
1108 fn test_disk_search_k10_l20_single_or_multi_thread_100dim() {
1109 let storage_provider = Arc::new(VirtualStorageProvider::new_overlay(test_data_root()));
1110
1111 let search_engine = create_disk_index_searcher(
1112 CreateDiskIndexSearcherParams {
1113 max_thread_num: 5,
1114 pq_pivot_file_path: TEST_PQ_PIVOT_100DIM,
1115 pq_compressed_file_path: TEST_PQ_COMPRESSED_100DIM,
1116 index_path: TEST_INDEX_100DIM,
1117 index_path_prefix: TEST_INDEX_PREFIX_100DIM,
1118 ..Default::default()
1119 },
1120 &storage_provider,
1121 );
1122 test_disk_search(TestDiskSearchParams {
1124 storage_provider: storage_provider.as_ref(),
1125 index_search_engine: &search_engine,
1126 thread_num: 1,
1127 query_file_path: TEST_QUERY_10PTS_100DIM,
1128 truth_result_file_path: TEST_TRUTH_RESULT_10PTS_100DIM,
1129 k: 10,
1130 l: 20,
1131 dim: 104,
1132 });
1133 test_disk_search(TestDiskSearchParams {
1135 storage_provider: storage_provider.as_ref(),
1136 index_search_engine: &search_engine,
1137 thread_num: 5,
1138 query_file_path: TEST_QUERY_10PTS_100DIM,
1139 truth_result_file_path: TEST_TRUTH_RESULT_10PTS_100DIM,
1140 k: 10,
1141 l: 20,
1142 dim: 104,
1143 });
1144 }
1145
1146 #[test]
1147 fn test_disk_search_k10_l20_single_or_multi_thread_128dim() {
1148 let storage_provider = Arc::new(VirtualStorageProvider::new_overlay(test_data_root()));
1149
1150 let search_engine = create_disk_index_searcher::<GraphDataF32VectorUnitData>(
1151 CreateDiskIndexSearcherParams {
1152 max_thread_num: 5,
1153 pq_pivot_file_path: TEST_PQ_PIVOT_128DIM,
1154 pq_compressed_file_path: TEST_PQ_COMPRESSED_128DIM,
1155 index_path: TEST_INDEX_128DIM,
1156 index_path_prefix: TEST_INDEX_PREFIX_128DIM,
1157 ..Default::default()
1158 },
1159 &storage_provider,
1160 );
1161 test_disk_search(TestDiskSearchParams {
1163 storage_provider: storage_provider.as_ref(),
1164 index_search_engine: &search_engine,
1165 thread_num: 1,
1166 query_file_path: TEST_QUERY_10PTS_128DIM,
1167 truth_result_file_path: TEST_TRUTH_RESULT_10PTS_128DIM,
1168 k: 10,
1169 l: 20,
1170 dim: 128,
1171 });
1172 test_disk_search(TestDiskSearchParams {
1174 storage_provider: storage_provider.as_ref(),
1175 index_search_engine: &search_engine,
1176 thread_num: 5,
1177 query_file_path: TEST_QUERY_10PTS_128DIM,
1178 truth_result_file_path: TEST_TRUTH_RESULT_10PTS_128DIM,
1179 k: 10,
1180 l: 20,
1181 dim: 128,
1182 });
1183 }
1184
1185 fn get_truth_associated_data<StorageReader: StorageReadProvider>(
1186 storage_provider: &StorageReader,
1187 ) -> Vec<u32> {
1188 const ASSOCIATED_DATA_FILE: &str = "/sift/siftsmall_learn_256pts_u32_associated_data.fbin";
1189
1190 let data =
1191 read_bin::<u32>(&mut storage_provider.open_reader(ASSOCIATED_DATA_FILE).unwrap())
1192 .unwrap();
1193 data.into_inner().into_vec()
1194 }
1195
1196 #[test]
1197 fn test_disk_search_with_associated_data_k10_l20_single_or_multi_thread_128dim() {
1198 let storage_provider = VirtualStorageProvider::new_overlay(test_data_root());
1199 let index_path_prefix = "/disk_index_search/disk_index_sift_learn_R4_L50_A1.2_test_disk_index_search_associated_data";
1200 let params = TestParams {
1201 data_path: TEST_DATA_FILE.to_string(),
1202 index_path_prefix: index_path_prefix.to_string(),
1203 associated_data_path: Some(
1204 "/sift/siftsmall_learn_256pts_u32_associated_data.fbin".to_string(),
1205 ),
1206 ..TestParams::default()
1207 };
1208 let fixture = IndexBuildFixture::new(storage_provider, params).unwrap();
1209 fixture.build::<GraphDataF32VectorU32Data>().unwrap();
1211 {
1212 let search_engine = create_disk_index_searcher::<GraphDataF32VectorU32Data>(
1213 CreateDiskIndexSearcherParams {
1214 max_thread_num: 5,
1215 pq_pivot_file_path: format!("{}_pq_pivots.bin", index_path_prefix).as_str(),
1216 pq_compressed_file_path: format!("{}_pq_compressed.bin", index_path_prefix)
1217 .as_str(),
1218 index_path: format!("{}_disk.index", index_path_prefix).as_str(), index_path_prefix,
1220 ..Default::default()
1221 },
1222 &fixture.storage_provider,
1223 );
1224
1225 test_disk_search_with_associated(
1227 TestDiskSearchAssociateParams {
1228 storage_provider: fixture.storage_provider.as_ref(),
1229 index_search_engine: &search_engine,
1230 thread_num: 1,
1231 query_file_path: TEST_QUERY_10PTS_128DIM,
1232 truth_result_file_path: TEST_TRUTH_RESULT_10PTS_128DIM,
1233 k: 10,
1234 l: 20,
1235 dim: 128,
1236 },
1237 None,
1238 );
1239
1240 test_disk_search_with_associated(
1242 TestDiskSearchAssociateParams {
1243 storage_provider: fixture.storage_provider.as_ref(),
1244 index_search_engine: &search_engine,
1245 thread_num: 5,
1246 query_file_path: TEST_QUERY_10PTS_128DIM,
1247 truth_result_file_path: TEST_TRUTH_RESULT_10PTS_128DIM,
1248 k: 10,
1249 l: 20,
1250 dim: 128,
1251 },
1252 None,
1253 );
1254 }
1255
1256 fixture
1257 .storage_provider
1258 .delete(&format!("{}_disk.index", index_path_prefix))
1259 .expect("Failed to delete file");
1260 fixture
1261 .storage_provider
1262 .delete(&format!("{}_pq_pivots.bin", index_path_prefix))
1263 .expect("Failed to delete file");
1264 fixture
1265 .storage_provider
1266 .delete(&format!("{}_pq_compressed.bin", index_path_prefix))
1267 .expect("Failed to delete file");
1268 }
1269
1270 struct CreateDiskIndexSearcherParams<'a> {
1271 max_thread_num: usize,
1272 pq_pivot_file_path: &'a str,
1273 pq_compressed_file_path: &'a str,
1274 index_path: &'a str,
1275 index_path_prefix: &'a str,
1276 io_limit: usize,
1277 }
1278
1279 impl Default for CreateDiskIndexSearcherParams<'_> {
1280 fn default() -> Self {
1281 Self {
1282 max_thread_num: 1,
1283 pq_pivot_file_path: "",
1284 pq_compressed_file_path: "",
1285 index_path: "",
1286 index_path_prefix: "",
1287 io_limit: usize::MAX,
1288 }
1289 }
1290 }
1291
1292 fn create_disk_index_searcher<Data>(
1293 params: CreateDiskIndexSearcherParams,
1294 storage_provider: &Arc<VirtualStorageProvider<OverlayFS>>,
1295 ) -> DiskIndexSearcher<
1296 Data,
1297 DiskVertexProviderFactory<Data, VirtualAlignedReaderFactory<OverlayFS>>,
1298 >
1299 where
1300 Data: GraphDataType<VectorIdType = u32>,
1301 {
1302 assert!(params.io_limit > 0);
1303
1304 let runtime = tokio::runtime::Builder::new_multi_thread()
1305 .worker_threads(params.max_thread_num)
1306 .build()
1307 .unwrap();
1308
1309 let disk_index_reader = DiskIndexReader::<Data::VectorDataType>::new(
1310 params.pq_pivot_file_path.to_string(),
1311 params.pq_compressed_file_path.to_string(),
1312 storage_provider.as_ref(),
1313 )
1314 .unwrap();
1315
1316 let aligned_reader_factory = VirtualAlignedReaderFactory::new(
1317 get_disk_index_file(params.index_path_prefix),
1318 Arc::clone(storage_provider),
1319 );
1320 let caching_strategy = CachingStrategy::None;
1321 let vertex_provider_factory =
1322 DiskVertexProviderFactory::<Data, _>::new(aligned_reader_factory, caching_strategy)
1323 .unwrap();
1324
1325 DiskIndexSearcher::<Data, DiskVertexProviderFactory<Data, _>>::new(
1326 params.max_thread_num,
1327 params.io_limit,
1328 &disk_index_reader,
1329 vertex_provider_factory,
1330 Metric::L2,
1331 Some(runtime),
1332 )
1333 .unwrap()
1334 }
1335
1336 fn load_query_result<StorageReader: StorageReadProvider>(
1337 storage_provider: &StorageReader,
1338 query_result_path: &str,
1339 ) -> Vec<u32> {
1340 let result =
1341 read_bin::<u32>(&mut storage_provider.open_reader(query_result_path).unwrap()).unwrap();
1342 result.into_inner().into_vec()
1343 }
1344
1345 struct TestDiskSearchParams<'a, StorageType> {
1346 storage_provider: &'a StorageType,
1347 index_search_engine: &'a DiskIndexSearcher<
1348 GraphDataF32VectorUnitData,
1349 DiskVertexProviderFactory<
1350 GraphDataF32VectorUnitData,
1351 VirtualAlignedReaderFactory<OverlayFS>,
1352 >,
1353 >,
1354 thread_num: u64,
1355 query_file_path: &'a str,
1356 truth_result_file_path: &'a str,
1357 k: usize,
1358 l: usize,
1359 dim: usize,
1360 }
1361
1362 struct TestDiskSearchAssociateParams<'a, StorageType> {
1363 storage_provider: &'a StorageType,
1364 index_search_engine: &'a DiskIndexSearcher<
1365 GraphDataF32VectorU32Data,
1366 DiskVertexProviderFactory<
1367 GraphDataF32VectorU32Data,
1368 VirtualAlignedReaderFactory<OverlayFS>,
1369 >,
1370 >,
1371 thread_num: u64,
1372 query_file_path: &'a str,
1373 truth_result_file_path: &'a str,
1374 k: usize,
1375 l: usize,
1376 dim: usize,
1377 }
1378
1379 fn test_disk_search<StorageType: StorageReadProvider>(
1380 params: TestDiskSearchParams<StorageType>,
1381 ) {
1382 let query_vector = load_aligned_bin(params.storage_provider, params.query_file_path)
1383 .unwrap()
1384 .0;
1385 let mut aligned_query = AlignedBoxWithSlice::<f32>::new(query_vector.len(), 32).unwrap();
1386 aligned_query.memcpy(query_vector.as_slice()).unwrap();
1387
1388 let queries = aligned_query
1389 .split_into_nonoverlapping_mut_slices(0..aligned_query.len(), params.dim)
1390 .unwrap();
1391
1392 let truth_result =
1393 load_query_result(params.storage_provider, params.truth_result_file_path);
1394
1395 let pool = create_thread_pool(params.thread_num.into_usize()).unwrap();
1396 queries
1398 .par_iter()
1399 .enumerate()
1400 .for_each_in_pool(&pool, |(i, query)| {
1401 let mut aligned_box = AlignedBoxWithSlice::<f32>::new(query.len() + 1, 32).unwrap();
1403 let mut temp = Vec::with_capacity(query.len() + 1);
1404 temp.push(0.0);
1405 temp.extend_from_slice(query);
1406 aligned_box.memcpy(temp.as_slice()).unwrap();
1407 let query = &aligned_box.as_slice()[1..];
1408
1409 let mut query_stats = QueryStatistics::default();
1410 let mut indices = vec![0u32; 10];
1411 let mut distances = vec![0f32; 10];
1412 let mut associated_data = vec![(); 10];
1413
1414 let result = params
1415 .index_search_engine
1416 .search_internal(
1418 query,
1419 params.k,
1420 params.l as u32,
1421 None, &mut query_stats,
1423 &mut indices,
1424 &mut distances,
1425 &mut associated_data,
1426 &(|_| true),
1427 false,
1428 );
1429
1430 let truth_slice = &truth_result[i * params.k..(i + 1) * params.k];
1432
1433 assert!(result.is_ok(), "Expected search to succeed");
1434
1435 let result_unwrapped = result.unwrap();
1436 assert!(
1437 result_unwrapped.query_statistics.total_io_operations > 0,
1438 "Expected IO operations to be greater than 0"
1439 );
1440 assert!(
1441 result_unwrapped.query_statistics.total_vertices_loaded > 0,
1442 "Expected vertices loaded to be greater than 0"
1443 );
1444
1445 assert_eq!(
1447 indices, truth_slice,
1448 "Results DO NOT match with the truth result for query {}",
1449 i
1450 );
1451 });
1452 }
1453
1454 fn test_disk_search_with_associated<StorageType: StorageReadProvider>(
1455 params: TestDiskSearchAssociateParams<StorageType>,
1456 beam_width: Option<usize>,
1457 ) {
1458 let query_vector = load_aligned_bin(params.storage_provider, params.query_file_path)
1459 .unwrap()
1460 .0;
1461 let mut aligned_query = AlignedBoxWithSlice::<f32>::new(query_vector.len(), 32).unwrap();
1462 aligned_query.memcpy(query_vector.as_slice()).unwrap();
1463 let queries = aligned_query
1464 .split_into_nonoverlapping_mut_slices(0..aligned_query.len(), params.dim)
1465 .unwrap();
1466 let truth_result =
1467 load_query_result(params.storage_provider, params.truth_result_file_path);
1468 let pool = create_thread_pool(params.thread_num.into_usize()).unwrap();
1469 queries
1471 .par_iter()
1472 .enumerate()
1473 .for_each_in_pool(&pool, |(i, query)| {
1474 let mut aligned_box = AlignedBoxWithSlice::<f32>::new(query.len() + 1, 32).unwrap();
1476 let mut temp = Vec::with_capacity(query.len() + 1);
1477 temp.push(0.0);
1478 temp.extend_from_slice(query);
1479 aligned_box.memcpy(temp.as_slice()).unwrap();
1480 let query = &aligned_box.as_slice()[1..];
1481 let result = params
1482 .index_search_engine
1483 .search(query, params.k as u32, params.l as u32, beam_width, None, false)
1484 .unwrap();
1485 let indices: Vec<u32> = result.results.iter().map(|item| item.vertex_id).collect();
1486 let associated_data: Vec<u32> =
1487 result.results.iter().map(|item| item.data).collect();
1488 let truth_data = get_truth_associated_data(params.storage_provider);
1489 let associated_data_truth: Vec<u32> = indices
1490 .iter()
1491 .map(|&vid| truth_data[vid as usize])
1492 .collect();
1493 assert_eq!(
1494 associated_data, associated_data_truth,
1495 "Associated data DO NOT match with the truth result for query {}, associated_data from search: {:?}, associated_data from truth result: {:?}",
1496 i,associated_data, associated_data_truth
1497 );
1498 let truth_slice = &truth_result[i * params.k..(i + 1) * params.k];
1499 assert_eq!(
1500 indices, truth_slice,
1501 "Results DO NOT match with the truth result for query {}",
1502 i
1503 );
1504 });
1505 }
1506
1507 #[test]
1508 fn test_disk_search_invalid_input() {
1509 let storage_provider = Arc::new(VirtualStorageProvider::new_overlay(test_data_root()));
1510 let ctx = &DefaultContext;
1511
1512 let params = CreateDiskIndexSearcherParams {
1513 max_thread_num: 5,
1514 pq_pivot_file_path: TEST_PQ_PIVOT_128DIM,
1515 pq_compressed_file_path: TEST_PQ_COMPRESSED_128DIM,
1516 index_path: TEST_INDEX_128DIM,
1517 index_path_prefix: TEST_INDEX_PREFIX_128DIM,
1518 ..Default::default()
1519 };
1520
1521 let paths = PQPathNames::for_disk_index(TEST_INDEX_PREFIX_128DIM);
1522 assert_eq!(
1523 paths.pivots, params.pq_pivot_file_path,
1524 "pq_pivot_file_path is not correct"
1525 );
1526 assert_eq!(
1527 paths.compressed_data, params.pq_compressed_file_path,
1528 "pq_compressed_file_path is not correct"
1529 );
1530 assert_eq!(
1531 params.index_path,
1532 format!("{}_disk.index", params.index_path_prefix),
1533 "index_path is not correct"
1534 );
1535
1536 let res = Knn::new_default(20, 10);
1538 assert!(res.is_err());
1539 assert_eq!(
1540 <KnnSearchError as std::convert::Into<ANNError>>::into(res.unwrap_err()).kind(),
1541 ANNErrorKind::IndexError
1542 );
1543 let res = Knn::new(10, 10, Some(0));
1545 assert!(res.is_err());
1546
1547 let search_engine =
1548 create_disk_index_searcher::<GraphDataF32VectorU32Data>(params, &storage_provider);
1549
1550 assert_eq!(
1552 search_engine
1553 .index
1554 .data_provider
1555 .to_external_id(ctx, 0)
1556 .unwrap(),
1557 0
1558 );
1559 assert_eq!(
1560 search_engine
1561 .index
1562 .data_provider
1563 .to_internal_id(ctx, &0)
1564 .unwrap(),
1565 0
1566 );
1567
1568 let provider_max_degree = search_engine
1569 .index
1570 .data_provider
1571 .graph_header
1572 .max_degree::<<GraphDataF32VectorU32Data as GraphDataType>::VectorDataType>()
1573 .unwrap();
1574 let index_max_degree = search_engine.index.config.pruned_degree().get();
1575 assert_eq!(provider_max_degree, index_max_degree);
1576
1577 let query = vec![0f32; 128];
1578 let mut query_stats = QueryStatistics::default();
1579 let mut indices = vec![0u32; 10];
1580 let mut distances = vec![0f32; 10];
1581 let mut associated_data = vec![0u32; 10];
1582
1583 let result = search_engine.search_internal(
1585 &query,
1586 10,
1587 10 - 1,
1588 None,
1589 &mut query_stats,
1590 &mut indices,
1591 &mut distances,
1592 &mut associated_data,
1593 &|_| true,
1594 false,
1595 );
1596
1597 assert!(result.is_err());
1598 assert_eq!(result.unwrap_err().kind(), ANNErrorKind::IndexError);
1599 }
1600
1601 #[test]
1602 fn test_disk_search_beam_search() {
1603 let storage_provider = Arc::new(VirtualStorageProvider::new_overlay(test_data_root()));
1604
1605 let search_engine = create_disk_index_searcher::<GraphDataF32VectorUnitData>(
1606 CreateDiskIndexSearcherParams {
1607 max_thread_num: 1,
1608 pq_pivot_file_path: TEST_PQ_PIVOT,
1609 pq_compressed_file_path: TEST_PQ_COMPRESSED,
1610 index_path: TEST_INDEX,
1611 index_path_prefix: TEST_INDEX_PREFIX,
1612 ..Default::default()
1613 },
1614 &storage_provider,
1615 );
1616
1617 let query_vector: [f32; 128] = [1f32; 128];
1618 let mut indices = vec![0u32; 10];
1619 let mut distances = vec![0f32; 10];
1620 let mut associated_data = vec![(); 10];
1621
1622 let mut result_output_buffer = search_output_buffer::IdDistanceAssociatedData::new(
1623 &mut indices,
1624 &mut distances,
1625 &mut associated_data,
1626 );
1627 let strategy = search_engine.search_strategy(&query_vector, &|_| true);
1628 let mut search_record = VisitedSearchRecord::new(0);
1629 let search_params = Knn::new(10, 10, Some(4)).unwrap();
1630 let recorded_search =
1631 diskann::graph::search::RecordedKnn::new(search_params, &mut search_record);
1632 search_engine
1633 .runtime
1634 .block_on(search_engine.index.search(
1635 recorded_search,
1636 &strategy,
1637 &DefaultContext,
1638 query_vector.as_slice(),
1639 &mut result_output_buffer,
1640 ))
1641 .unwrap();
1642
1643 let ids = search_record
1644 .visited
1645 .iter()
1646 .map(|n| n.id)
1647 .collect::<Vec<_>>();
1648
1649 const EXPECTED_NODES: [u32; 18] = [
1650 72, 118, 108, 86, 84, 152, 170, 82, 114, 87, 207, 176, 79, 153, 67, 165, 141, 180,
1651 ]; assert_eq!(ids, &EXPECTED_NODES);
1654
1655 let return_list_size = 10;
1656 let search_list_size = 10;
1657 let result = search_engine.search(
1658 &query_vector,
1659 return_list_size,
1660 search_list_size,
1661 Some(4),
1662 None,
1663 false,
1664 );
1665 assert!(result.is_ok(), "Expected search to succeed");
1666 let search_result = result.unwrap();
1667 assert_eq!(
1668 search_result.results.len() as u32,
1669 return_list_size,
1670 "Expected result count to match"
1671 );
1672 assert_eq!(
1673 indices,
1674 vec![152, 72, 170, 118, 87, 165, 79, 141, 108, 86],
1675 "Expected indices to match"
1676 );
1677 }
1678
1679 #[cfg(feature = "experimental_diversity_search")]
1680 #[test]
1681 fn test_disk_search_diversity_search() {
1682 use diskann::graph::DiverseSearchParams;
1683 use diskann::neighbor::AttributeValueProvider;
1684 use std::collections::HashMap;
1685
1686 #[derive(Debug, Clone)]
1688 struct TestAttributeProvider {
1689 attributes: HashMap<u32, u32>,
1690 }
1691 impl TestAttributeProvider {
1692 fn new() -> Self {
1693 Self {
1694 attributes: HashMap::new(),
1695 }
1696 }
1697 fn insert(&mut self, id: u32, attribute: u32) {
1698 self.attributes.insert(id, attribute);
1699 }
1700 }
1701 impl diskann::provider::HasId for TestAttributeProvider {
1702 type Id = u32;
1703 }
1704
1705 impl AttributeValueProvider for TestAttributeProvider {
1706 type Value = u32;
1707
1708 fn get(&self, id: Self::Id) -> Option<Self::Value> {
1709 self.attributes.get(&id).copied()
1710 }
1711 }
1712
1713 let storage_provider = Arc::new(VirtualStorageProvider::new_overlay(test_data_root()));
1714
1715 let search_engine = create_disk_index_searcher::<GraphDataF32VectorUnitData>(
1716 CreateDiskIndexSearcherParams {
1717 max_thread_num: 1,
1718 pq_pivot_file_path: TEST_PQ_PIVOT,
1719 pq_compressed_file_path: TEST_PQ_COMPRESSED,
1720 index_path: TEST_INDEX,
1721 index_path_prefix: TEST_INDEX_PREFIX,
1722 ..Default::default()
1723 },
1724 &storage_provider,
1725 );
1726
1727 let query_vector: [f32; 128] = [1f32; 128];
1728
1729 let mut attribute_provider = TestAttributeProvider::new();
1731 let num_vectors = 256; for i in 0..num_vectors {
1733 let label = (i % 15) + 1;
1735 attribute_provider.insert(i, label);
1736 }
1737 let attribute_provider = std::sync::Arc::new(attribute_provider);
1739
1740 let mut indices = vec![0u32; 10];
1741 let mut distances = vec![0f32; 10];
1742 let mut associated_data = vec![(); 10];
1743
1744 let mut result_output_buffer = search_output_buffer::IdDistanceAssociatedData::new(
1745 &mut indices,
1746 &mut distances,
1747 &mut associated_data,
1748 );
1749 let strategy = search_engine.search_strategy(&query_vector, &|_| true);
1750
1751 let diverse_params = DiverseSearchParams::new(
1753 0, 3, attribute_provider.clone(),
1756 );
1757
1758 let search_params = Knn::new(10, 20, None).unwrap();
1759
1760 let diverse_search = diskann::graph::search::Diverse::new(search_params, diverse_params);
1761 let stats = search_engine
1762 .runtime
1763 .block_on(search_engine.index.search(
1764 diverse_search,
1765 &strategy,
1766 &DefaultContext,
1767 query_vector.as_slice(),
1768 &mut result_output_buffer,
1769 ))
1770 .unwrap();
1771
1772 assert!(
1774 stats.result_count > 0,
1775 "Expected to get some results during diversity search"
1776 );
1777
1778 let return_list_size = 10;
1779 let search_list_size = 20;
1780 let diverse_results_k = 1;
1781 let diverse_params = DiverseSearchParams::new(
1782 0, diverse_results_k,
1784 attribute_provider.clone(),
1785 );
1786
1787 let mut indices2 = vec![0u32; return_list_size as usize];
1789 let mut distances2 = vec![0f32; return_list_size as usize];
1790 let mut associated_data2 = vec![(); return_list_size as usize];
1791 let mut result_output_buffer2 = search_output_buffer::IdDistanceAssociatedData::new(
1792 &mut indices2,
1793 &mut distances2,
1794 &mut associated_data2,
1795 );
1796 let strategy2 = search_engine.search_strategy(&query_vector, &|_| true);
1797 let search_params2 =
1798 Knn::new(return_list_size as usize, search_list_size as usize, None).unwrap();
1799
1800 let diverse_search2 = diskann::graph::search::Diverse::new(search_params2, diverse_params);
1801 let stats = search_engine
1802 .runtime
1803 .block_on(search_engine.index.search(
1804 diverse_search2,
1805 &strategy2,
1806 &DefaultContext,
1807 query_vector.as_slice(),
1808 &mut result_output_buffer2,
1809 ))
1810 .unwrap();
1811
1812 assert!(
1814 stats.result_count > 0,
1815 "Expected diversity search to return results"
1816 );
1817 assert!(
1818 stats.result_count <= return_list_size,
1819 "Expected result count to be <= {}",
1820 return_list_size
1821 );
1822
1823 assert!(
1825 stats.result_count > 0,
1826 "Expected to get some search results"
1827 );
1828
1829 println!("\n=== Diversity Search Results ===");
1831 println!("Query: [1f32; 128]");
1832 println!("diverse_results_k: {}", diverse_results_k);
1833 println!("Total results: {}\n", stats.result_count);
1834 println!("{:<10} {:<15} {:<10}", "Vertex ID", "Distance", "Label");
1835 println!("{}", "-".repeat(35));
1836 for i in 0..stats.result_count as usize {
1837 let attribute_value = attribute_provider.get(indices2[i]).unwrap_or(0);
1838 println!(
1839 "{:<10} {:<15.2} {:<10}",
1840 indices2[i], distances2[i], attribute_value
1841 );
1842 }
1843
1844 for i in 0..(stats.result_count as usize).saturating_sub(1) {
1846 assert!(distances2[i] >= 0.0, "Expected non-negative distance");
1847 assert!(
1848 distances2[i] <= distances2[i + 1],
1849 "Expected distances to be sorted in ascending order"
1850 );
1851 }
1852
1853 let mut attribute_counts = HashMap::new();
1855 for item in indices2.iter().take(stats.result_count as usize) {
1856 if let Some(attribute_value) = attribute_provider.get(*item) {
1857 *attribute_counts.entry(attribute_value).or_insert(0) += 1;
1858 }
1859 }
1860
1861 println!("\n=== Attribute Distribution ===");
1863 let mut sorted_attrs: Vec<_> = attribute_counts.iter().collect();
1864 sorted_attrs.sort_by_key(|(k, _)| *k);
1865 for (attribute_value, count) in &sorted_attrs {
1866 println!(
1867 "Label {}: {} occurrences (max allowed: {})",
1868 attribute_value, count, diverse_results_k
1869 );
1870 }
1871 println!("Total unique labels: {}", attribute_counts.len());
1872 println!("================================\n");
1873
1874 for (attribute_value, count) in &attribute_counts {
1876 println!(
1877 "Assert: Label {} has {} occurrences (max: {})",
1878 attribute_value, count, diverse_results_k
1879 );
1880 assert!(
1881 *count <= diverse_results_k,
1882 "Attribute value {} appears {} times, which exceeds diverse_results_k of {}",
1883 attribute_value,
1884 count,
1885 diverse_results_k
1886 );
1887 }
1888
1889 println!(
1892 "Assert: Found {} unique labels (expected at least 2)",
1893 attribute_counts.len()
1894 );
1895 assert!(
1896 attribute_counts.len() >= 2,
1897 "Expected at least 2 different attribute values for diversity, got {}",
1898 attribute_counts.len()
1899 );
1900 }
1901
1902 #[rstest]
1903 #[case(
1905 |_id: &u32| true,
1906 false,
1907 10,
1908 vec![152, 118, 72, 170, 87, 141, 79, 207, 124, 86],
1909 vec![256101.7, 256675.3, 256709.69, 256712.5, 256760.08, 256958.5, 257006.1, 257025.7, 257105.67, 257107.67],
1910 )]
1911 #[case(
1914 |id: &u32| *id == 0 || *id == 1,
1915 false,
1916 0,
1917 vec![0; 10],
1918 vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
1919 )]
1920 #[case(
1923 |id: &u32| *id == 0 || *id == 1,
1924 true,
1925 2,
1926 vec![1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
1927 vec![257247.28, 258179.28, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
1928 )]
1929 #[case(
1932 |id: &u32| *id == 72 || *id == 87 || *id == 170,
1933 false,
1934 3,
1935 vec![72, 170, 87, 0, 0, 0, 0, 0, 0, 0],
1936 vec![256709.69, 256712.5, 256760.08, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
1937 )]
1938 #[case(
1941 |id: &u32| *id == 72 || *id == 87 || *id == 170,
1942 true,
1943 3,
1944 vec![72, 170, 87, 0, 0, 0, 0, 0, 0, 0],
1945 vec![256709.69, 256712.5, 256760.08, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
1946 )]
1947 fn test_search_with_vector_filter(
1948 #[case] vector_filter: fn(&u32) -> bool,
1949 #[case] is_flat_search: bool,
1950 #[case] expected_result_count: u32,
1951 #[case] expected_indices: Vec<u32>,
1952 #[case] expected_distances: Vec<f32>,
1953 ) {
1954 let check_distances = |got: &[f32], expected: &[f32]| -> bool {
1959 const ABS_TOLERANCE: f32 = 0.02;
1960 assert_eq!(got.len(), expected.len());
1961 for (i, (g, e)) in std::iter::zip(got.iter(), expected.iter()).enumerate() {
1962 if (g - e).abs() > ABS_TOLERANCE {
1963 panic!(
1964 "distances differ at position {} by more than {}\n\n\
1965 got: {:?}\nexpected: {:?}",
1966 i, ABS_TOLERANCE, got, expected,
1967 );
1968 }
1969 }
1970 true
1971 };
1972
1973 let storage_provider = Arc::new(VirtualStorageProvider::new_overlay(test_data_root()));
1974
1975 let search_engine = create_disk_index_searcher::<GraphDataF32VectorUnitData>(
1976 CreateDiskIndexSearcherParams {
1977 max_thread_num: 5,
1978 pq_pivot_file_path: TEST_PQ_PIVOT_128DIM,
1979 pq_compressed_file_path: TEST_PQ_COMPRESSED_128DIM,
1980 index_path: TEST_INDEX_128DIM,
1981 index_path_prefix: TEST_INDEX_PREFIX_128DIM,
1982 ..Default::default()
1983 },
1984 &storage_provider,
1985 );
1986 let query = vec![0.1f32; 128];
1987 let mut query_stats = QueryStatistics::default();
1988 let mut indices = vec![0u32; 10];
1989 let mut distances = vec![0f32; 10];
1990 let mut associated_data = vec![(); 10];
1991
1992 let result = search_engine.search_internal(
1993 &query,
1994 10,
1995 10,
1996 None, &mut query_stats,
1998 &mut indices,
1999 &mut distances,
2000 &mut associated_data,
2001 &vector_filter,
2002 is_flat_search,
2003 );
2004
2005 assert!(result.is_ok(), "Expected search to succeed");
2006 assert_eq!(
2007 result.unwrap().result_count,
2008 expected_result_count,
2009 "Expected result count to match"
2010 );
2011 assert_eq!(indices, expected_indices, "Expected indices to match");
2012 assert!(
2013 check_distances(&distances, &expected_distances),
2014 "Expected distances to match"
2015 );
2016
2017 let result_with_filter = search_engine.search(
2018 &query,
2019 10,
2020 10,
2021 None, Some(Box::new(vector_filter)),
2023 is_flat_search,
2024 );
2025
2026 assert!(result_with_filter.is_ok(), "Expected search to succeed");
2027 let result_with_filter_unwrapped = result_with_filter.unwrap();
2028 assert_eq!(
2029 result_with_filter_unwrapped.stats.result_count, expected_result_count,
2030 "Expected result count to match"
2031 );
2032 let actual_indices = result_with_filter_unwrapped
2033 .results
2034 .iter()
2035 .map(|x| x.vertex_id)
2036 .collect::<Vec<_>>();
2037 assert_eq!(
2038 actual_indices, expected_indices,
2039 "Expected indices to match"
2040 );
2041 let actual_distances = result_with_filter_unwrapped
2042 .results
2043 .iter()
2044 .map(|x| x.distance)
2045 .collect::<Vec<_>>();
2046 assert!(
2047 check_distances(&actual_distances, &expected_distances),
2048 "Expected distances to match"
2049 );
2050 }
2051
2052 #[test]
2053 fn test_beam_search_respects_io_limit() {
2054 let io_limit = 11; let storage_provider = Arc::new(VirtualStorageProvider::new_overlay(test_data_root()));
2056
2057 let search_engine = create_disk_index_searcher::<GraphDataF32VectorUnitData>(
2058 CreateDiskIndexSearcherParams {
2059 max_thread_num: 1,
2060 pq_pivot_file_path: TEST_PQ_PIVOT,
2061 pq_compressed_file_path: TEST_PQ_COMPRESSED,
2062 index_path: TEST_INDEX,
2063 index_path_prefix: TEST_INDEX_PREFIX,
2064 io_limit,
2065 },
2066 &storage_provider,
2067 );
2068 let query_vector: [f32; 128] = [1f32; 128];
2069
2070 let mut indices = vec![0u32; 10];
2071 let mut distances = vec![0f32; 10];
2072 let mut associated_data = vec![(); 10];
2073
2074 let mut result_output_buffer = search_output_buffer::IdDistanceAssociatedData::new(
2075 &mut indices,
2076 &mut distances,
2077 &mut associated_data,
2078 );
2079
2080 let strategy = search_engine.search_strategy(&query_vector, &|_| true);
2081
2082 let mut search_record = VisitedSearchRecord::new(0);
2083 let search_params = Knn::new(10, 10, Some(4)).unwrap();
2084 let recorded_search =
2085 diskann::graph::search::RecordedKnn::new(search_params, &mut search_record);
2086 search_engine
2087 .runtime
2088 .block_on(search_engine.index.search(
2089 recorded_search,
2090 &strategy,
2091 &DefaultContext,
2092 query_vector.as_slice(),
2093 &mut result_output_buffer,
2094 ))
2095 .unwrap();
2096 let visited_ids = search_record
2097 .visited
2098 .iter()
2099 .map(|n| n.id)
2100 .collect::<Vec<_>>();
2101
2102 let query_stats = strategy.io_tracker;
2103 assert!(
2105 query_stats.io_count() <= io_limit,
2106 "Expected IO operations to be <= {}, but got {}",
2107 io_limit,
2108 query_stats.io_count()
2109 );
2110
2111 const EXPECTED_NODES: [u32; 17] = [
2112 72, 118, 108, 86, 84, 152, 170, 82, 114, 87, 207, 176, 79, 153, 67, 165, 141,
2113 ]; let mut matching_count = 0;
2117 for expected_node in EXPECTED_NODES.iter() {
2118 if visited_ids.contains(expected_node) {
2119 matching_count += 1;
2120 }
2121 }
2122
2123 let recall = (matching_count as f32 / EXPECTED_NODES.len() as f32) * 100.0;
2125
2126 assert!(recall >= 60.0, "Match percentage is below 60%: {}", recall);
2129 }
2130}