Skip to main content

diskann_disk/search/provider/
disk_provider.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use std::{
7    collections::HashMap,
8    future::Future,
9    num::NonZeroUsize,
10    ops::Range,
11    sync::{
12        atomic::{AtomicU64, AtomicUsize},
13        Arc,
14    },
15    time::Instant,
16};
17
18use diskann::{
19    graph::{
20        self,
21        glue::{
22            self, DefaultPostProcessor, ExpandBeam, IdIterator, SearchExt, SearchPostProcess,
23            SearchStrategy,
24        },
25        search::Knn,
26        search_output_buffer, AdjacencyList, DiskANNIndex,
27    },
28    neighbor::Neighbor,
29    provider::{
30        Accessor, BuildQueryComputer, DataProvider, DefaultContext, DelegateNeighbor, HasId,
31        NeighborAccessor, NoopGuard,
32    },
33    utils::{
34        object_pool::{ObjectPool, PoolOption, TryAsPooled},
35        IntoUsize, VectorRepr,
36    },
37    ANNError, ANNResult,
38};
39use diskann_providers::storage::StorageReadProvider;
40use diskann_providers::{
41    model::{
42        compute_pq_distance, compute_pq_distance_for_pq_coordinates, graph::traits::GraphDataType,
43        pq::quantizer_preprocess, PQData, PQScratch,
44    },
45    storage::{get_compressed_pq_file, get_disk_index_file, get_pq_pivot_file, LoadWith},
46};
47use diskann_vector::{distance::Metric, DistanceFunction, PreprocessedDistanceFunction};
48use futures_util::future;
49use tokio::runtime::Runtime;
50use tracing::debug;
51
52use crate::{
53    data_model::{CachingStrategy, GraphHeader},
54    filter_parameter::{default_vector_filter, VectorFilter},
55    search::{
56        provider::disk_vertex_provider_factory::DiskVertexProviderFactory,
57        traits::{VertexProvider, VertexProviderFactory},
58    },
59    storage::{api::AsyncDiskLoadContext, disk_index_reader::DiskIndexReader},
60    utils::AlignedFileReaderFactory,
61    utils::QueryStatistics,
62};
63
64///////////////////
65// Disk Provider //
66///////////////////
67
68/// The DiskProvider is a data provider that loads data from disk using the disk readers
69/// The data format for disk is different from that of the in-memory providers.
70/// The disk format stores both the vectors and the adjacency list next to each other for
71/// better locality for quicker access.
72/// Please refer to the RFC documentation at [`docs\rfcs\cy2025\disk_provider_for_async_index.md`] for design details.
73pub struct DiskProvider<Data>
74where
75    Data: GraphDataType<VectorIdType = u32>,
76{
77    /// Holds the graph header information that contains metadata about disk-index file.
78    graph_header: GraphHeader,
79
80    // Full precision distance comparer used in post_process to reorder results.
81    distance_comparer: <Data::VectorDataType as VectorRepr>::Distance,
82
83    /// The PQ data used for quantization.
84    pq_data: Arc<PQData>,
85
86    /// The number of points in the graph.
87    num_points: usize,
88
89    /// Metric used for distance computation.
90    metric: Metric,
91
92    /// The number of IO operations that can be done in parallel.
93    search_io_limit: usize,
94}
95
96impl<Data> DataProvider for DiskProvider<Data>
97where
98    Data: GraphDataType<VectorIdType = u32>,
99{
100    type Context = DefaultContext;
101
102    type InternalId = u32;
103
104    type ExternalId = u32;
105
106    type Guard = NoopGuard<u32>;
107
108    type Error = ANNError;
109
110    /// Translate an external id to its corresponding internal id.
111    fn to_internal_id(
112        &self,
113        _context: &DefaultContext,
114        gid: &Self::ExternalId,
115    ) -> Result<Self::InternalId, Self::Error> {
116        Ok(*gid)
117    }
118
119    /// Translate an internal id its corresponding external id.
120    fn to_external_id(
121        &self,
122        _context: &DefaultContext,
123        id: Self::InternalId,
124    ) -> Result<Self::ExternalId, Self::Error> {
125        Ok(id)
126    }
127}
128
129impl<Data> LoadWith<AsyncDiskLoadContext> for DiskProvider<Data>
130where
131    Data: GraphDataType<VectorIdType = u32>,
132{
133    type Error = ANNError;
134
135    async fn load_with<P>(provider: &P, ctx: &AsyncDiskLoadContext) -> ANNResult<Self>
136    where
137        P: StorageReadProvider,
138    {
139        debug!(
140            "DiskProvider::load_with() called with file: {:?}",
141            get_disk_index_file(ctx.quant_load_context.metadata.prefix())
142        );
143
144        let graph_header = {
145            let aligned_reader_factory = AlignedFileReaderFactory::new(get_disk_index_file(
146                ctx.quant_load_context.metadata.prefix(),
147            ));
148
149            let caching_strategy = if ctx.num_nodes_to_cache > 0 {
150                CachingStrategy::StaticCacheWithBfsNodes(ctx.num_nodes_to_cache)
151            } else {
152                CachingStrategy::None
153            };
154
155            let vertex_provider_factory = DiskVertexProviderFactory::<Data, _>::new(
156                aligned_reader_factory,
157                caching_strategy,
158            )?;
159            VertexProviderFactory::get_header(&vertex_provider_factory)?
160        };
161
162        let metric = ctx.quant_load_context.metric;
163        let num_points = ctx.num_points;
164
165        let index_path_prefix = ctx.quant_load_context.metadata.prefix();
166        let index_reader = DiskIndexReader::<<Data as GraphDataType>::VectorDataType>::new(
167            get_pq_pivot_file(index_path_prefix),
168            get_compressed_pq_file(index_path_prefix),
169            provider,
170        )?;
171
172        Self::new(
173            &index_reader,
174            graph_header,
175            metric,
176            num_points,
177            ctx.search_io_limit,
178        )
179    }
180}
181
182impl<Data> DiskProvider<Data>
183where
184    Data: GraphDataType<VectorIdType = u32>,
185{
186    fn new(
187        disk_index_reader: &DiskIndexReader<Data::VectorDataType>,
188        graph_header: GraphHeader,
189        metric: Metric,
190        num_points: usize,
191        search_io_limit: usize,
192    ) -> ANNResult<Self> {
193        let distance_comparer =
194            Data::VectorDataType::distance(metric, Some(graph_header.metadata().dims));
195
196        let pq_data = disk_index_reader.get_pq_data();
197
198        Ok(Self {
199            graph_header,
200            distance_comparer,
201            pq_data,
202            num_points,
203            metric,
204            search_io_limit,
205        })
206    }
207}
208
209/// The search strategy for the disk provider. This is used to create the search accessor
210/// for use in search in quant space and post_process function to reorder with full precision vectors.
211///
212/// # Why vertex_provider_factory and scratch_pool are here instead of DiskProvider
213///
214/// The DataProvider trait requires 'static bounds for multi-threaded async contexts,
215/// but vertex_provider_factory may have non-'static lifetime bounds (e.g., borrowing
216/// from local data structures). Moving these components to the search strategy allows
217/// DiskProvider to satisfy 'static constraints while enabling flexible per-search
218/// resource management.
219pub struct DiskSearchStrategy<'a, Data, ProviderFactory>
220where
221    Data: GraphDataType<VectorIdType = u32>,
222    ProviderFactory: VertexProviderFactory<Data>,
223{
224    // This needs to be Arc instead of Rc because DiskSearchStrategy has "Send" trait bound, though this is not expected to be shared across threads.
225    io_tracker: IOTracker,
226    vector_filter: &'a (dyn Fn(&u32) -> bool + Send + Sync), // Fn param is u32 as we validate "VectorIdType = u32" everywhere in this provider in trait bounds.
227    query: &'a [Data::VectorDataType],
228
229    /// The vertex provider factory is used to create the vertex provider for each search instance.
230    vertex_provider_factory: &'a ProviderFactory,
231
232    /// Scratch pool for disk search operations that need allocations.
233    scratch_pool: &'a Arc<ObjectPool<DiskSearchScratch<Data, ProviderFactory::VertexProviderType>>>,
234}
235
236// Struct to track IO. This is used by single thread, but needs to be Atomic as the Strategy has "Send" trait bound.
237// There should be minimal to no overhead compared to using a raw reference.
238struct IOTracker {
239    io_time_us: AtomicU64,
240    preprocess_time_us: AtomicU64,
241    io_count: AtomicUsize,
242}
243
244impl Default for IOTracker {
245    fn default() -> Self {
246        Self {
247            io_time_us: AtomicU64::new(0),
248            preprocess_time_us: AtomicU64::new(0),
249            io_count: AtomicUsize::new(0),
250        }
251    }
252}
253
254impl IOTracker {
255    fn add_time(category: &AtomicU64, time: u64) {
256        category.fetch_add(time, std::sync::atomic::Ordering::Relaxed);
257    }
258
259    fn time(category: &AtomicU64) -> u64 {
260        category.load(std::sync::atomic::Ordering::Relaxed)
261    }
262
263    fn add_io_count(&self, count: usize) {
264        self.io_count
265            .fetch_add(count, std::sync::atomic::Ordering::Relaxed);
266    }
267
268    fn io_count(&self) -> usize {
269        self.io_count.load(std::sync::atomic::Ordering::Relaxed)
270    }
271}
272
273#[derive(Clone, Copy)]
274pub struct RerankAndFilter<'a> {
275    filter: &'a (dyn Fn(&u32) -> bool + Send + Sync),
276}
277
278impl<'a> RerankAndFilter<'a> {
279    fn new(filter: &'a (dyn Fn(&u32) -> bool + Send + Sync)) -> Self {
280        Self { filter }
281    }
282}
283
284impl<Data, VP>
285    SearchPostProcess<
286        DiskAccessor<'_, Data, VP>,
287        &[Data::VectorDataType],
288        (
289            <DiskProvider<Data> as DataProvider>::InternalId,
290            Data::AssociatedDataType,
291        ),
292    > for RerankAndFilter<'_>
293where
294    Data: GraphDataType<VectorIdType = u32>,
295    VP: VertexProvider<Data>,
296{
297    type Error = ANNError;
298    async fn post_process<I, B>(
299        &self,
300        accessor: &mut DiskAccessor<'_, Data, VP>,
301        query: &[Data::VectorDataType],
302        _computer: &DiskQueryComputer,
303        candidates: I,
304        output: &mut B,
305    ) -> Result<usize, Self::Error>
306    where
307        I: Iterator<Item = Neighbor<u32>> + Send,
308        B: search_output_buffer::SearchOutputBuffer<(u32, Data::AssociatedDataType)>
309            + Send
310            + ?Sized,
311    {
312        let provider = accessor.provider;
313
314        let mut uncached_ids = Vec::new();
315        let mut reranked = candidates
316            .map(|n| n.id)
317            .filter(|id| (self.filter)(id))
318            .filter_map(|n| {
319                if let Some(entry) = accessor.scratch.distance_cache.get(&n) {
320                    Some(Ok::<((u32, _), f32), ANNError>(((n, entry.1), entry.0)))
321                } else {
322                    uncached_ids.push(n);
323                    None
324                }
325            })
326            .collect::<Result<Vec<_>, _>>()?;
327        if !uncached_ids.is_empty() {
328            ensure_vertex_loaded(&mut accessor.scratch.vertex_provider, &uncached_ids)?;
329            for n in &uncached_ids {
330                let v = accessor.scratch.vertex_provider.get_vector(n)?;
331                let d = provider.distance_comparer.evaluate_similarity(query, v);
332                let a = accessor.scratch.vertex_provider.get_associated_data(n)?;
333                reranked.push(((*n, *a), d));
334            }
335        }
336
337        // Sort the full precision distances.
338        reranked
339            .sort_unstable_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
340        // Store the reranked results.
341        Ok(output.extend(reranked))
342    }
343}
344
345impl<'this, Data, ProviderFactory> SearchStrategy<DiskProvider<Data>, &[Data::VectorDataType]>
346    for DiskSearchStrategy<'this, Data, ProviderFactory>
347where
348    Data: GraphDataType<VectorIdType = u32>,
349    ProviderFactory: VertexProviderFactory<Data>,
350{
351    type QueryComputer = DiskQueryComputer;
352    type SearchAccessor<'a> = DiskAccessor<'a, Data, ProviderFactory::VertexProviderType>;
353    type SearchAccessorError = ANNError;
354
355    fn search_accessor<'a>(
356        &'a self,
357        provider: &'a DiskProvider<Data>,
358        _context: &DefaultContext,
359    ) -> Result<Self::SearchAccessor<'a>, Self::SearchAccessorError> {
360        DiskAccessor::new(
361            provider,
362            &self.io_tracker,
363            self.query,
364            self.vertex_provider_factory,
365            self.scratch_pool,
366        )
367    }
368}
369
370impl<'this, Data, ProviderFactory>
371    DefaultPostProcessor<
372        DiskProvider<Data>,
373        &[Data::VectorDataType],
374        (
375            <DiskProvider<Data> as DataProvider>::InternalId,
376            Data::AssociatedDataType,
377        ),
378    > for DiskSearchStrategy<'this, Data, ProviderFactory>
379where
380    Data: GraphDataType<VectorIdType = u32>,
381    ProviderFactory: VertexProviderFactory<Data>,
382{
383    type Processor = RerankAndFilter<'this>;
384
385    fn default_post_processor(&self) -> Self::Processor {
386        RerankAndFilter::new(self.vector_filter)
387    }
388}
389
390/// The query computer for the disk provider. This is used to compute the distance between the query vector and the PQ coordinates.
391pub struct DiskQueryComputer {
392    num_pq_chunks: usize,
393    query_centroid_l2_distance: Vec<f32>,
394}
395
396impl PreprocessedDistanceFunction<&[u8], f32> for DiskQueryComputer {
397    fn evaluate_similarity(&self, changing: &[u8]) -> f32 {
398        let mut dist = 0.0f32;
399        #[allow(clippy::expect_used)]
400        compute_pq_distance_for_pq_coordinates(
401            changing,
402            self.num_pq_chunks,
403            &self.query_centroid_l2_distance,
404            std::slice::from_mut(&mut dist),
405        )
406        .expect("PQ distance compute for PQ coordinates is expected to succeed");
407        dist
408    }
409}
410
411impl<Data, VP> BuildQueryComputer<&[Data::VectorDataType]> for DiskAccessor<'_, Data, VP>
412where
413    Data: GraphDataType<VectorIdType = u32>,
414    VP: VertexProvider<Data>,
415{
416    type QueryComputerError = ANNError;
417    type QueryComputer = DiskQueryComputer;
418
419    fn build_query_computer(
420        &self,
421        _from: &[Data::VectorDataType],
422    ) -> Result<Self::QueryComputer, Self::QueryComputerError> {
423        Ok(DiskQueryComputer {
424            num_pq_chunks: self.provider.pq_data.get_num_chunks(),
425            query_centroid_l2_distance: self
426                .scratch
427                .pq_scratch
428                .aligned_pqtable_dist_scratch
429                .as_slice()
430                .to_vec(),
431        })
432    }
433
434    async fn distances_unordered<Itr, F>(
435        &mut self,
436        vec_id_itr: Itr,
437        _computer: &Self::QueryComputer,
438        f: F,
439    ) -> Result<(), Self::GetError>
440    where
441        F: Send + FnMut(f32, Self::Id),
442        Itr: Iterator<Item = Self::Id>,
443    {
444        self.pq_distances(&vec_id_itr.collect::<Box<[_]>>(), f)
445    }
446}
447
448impl<Data, VP> ExpandBeam<&[Data::VectorDataType]> for DiskAccessor<'_, Data, VP>
449where
450    Data: GraphDataType<VectorIdType = u32>,
451    VP: VertexProvider<Data>,
452{
453    fn expand_beam<Itr, P, F>(
454        &mut self,
455        ids: Itr,
456        _computer: &Self::QueryComputer,
457        mut pred: P,
458        mut f: F,
459    ) -> impl std::future::Future<Output = Result<(), Self::GetError>> + Send
460    where
461        Itr: Iterator<Item = Self::Id> + Send,
462        P: glue::HybridPredicate<Self::Id> + Send + Sync,
463        F: FnMut(f32, Self::Id) + Send,
464    {
465        let result = (|| {
466            let io_limit = self.provider.search_io_limit - self.io_tracker.io_count();
467            let load_ids: Box<[_]> = ids.take(io_limit).collect();
468
469            self.ensure_loaded(&load_ids)?;
470            let mut ids = Vec::new();
471            for i in load_ids {
472                ids.clear();
473                ids.extend(
474                    self.scratch
475                        .vertex_provider
476                        .get_adjacency_list(&i)?
477                        .iter()
478                        .copied()
479                        .filter(|id| pred.eval_mut(id)),
480                );
481
482                self.pq_distances(&ids, &mut f)?;
483            }
484
485            Ok(())
486        })();
487
488        std::future::ready(result)
489    }
490}
491
492// Scratch space for disk search operations that need allocations.
493// These allocations are amortized across searches using the scratch pool.
494struct DiskSearchScratch<Data, VP>
495where
496    Data: GraphDataType<VectorIdType = u32>,
497    VP: VertexProvider<Data>,
498{
499    distance_cache: HashMap<u32, (f32, Data::AssociatedDataType)>,
500    pq_scratch: PQScratch,
501    vertex_provider: VP,
502}
503
504#[derive(Clone)]
505struct DiskSearchScratchArgs<'a, ProviderFactory> {
506    graph_degree: usize,
507    dim: usize,
508    num_pq_chunks: usize,
509    num_pq_centers: usize,
510    vertex_factory: &'a ProviderFactory,
511    graph_header: &'a GraphHeader,
512}
513
514impl<Data, ProviderFactory> TryAsPooled<&DiskSearchScratchArgs<'_, ProviderFactory>>
515    for DiskSearchScratch<Data, ProviderFactory::VertexProviderType>
516where
517    Data: GraphDataType<VectorIdType = u32>,
518    ProviderFactory: VertexProviderFactory<Data>,
519{
520    type Error = ANNError;
521
522    fn try_create(args: &DiskSearchScratchArgs<ProviderFactory>) -> Result<Self, Self::Error> {
523        let pq_scratch = PQScratch::new(
524            args.graph_degree,
525            args.dim,
526            args.num_pq_chunks,
527            args.num_pq_centers,
528        )?;
529
530        const DEFAULT_BEAM_WIDTH: usize = 0; // Setting as 0 to avoid preallocation of memory.
531        let vertex_provider = args
532            .vertex_factory
533            .create_vertex_provider(DEFAULT_BEAM_WIDTH, args.graph_header)?;
534
535        Ok(Self {
536            distance_cache: HashMap::new(),
537            pq_scratch,
538            vertex_provider,
539        })
540    }
541
542    fn try_modify(
543        &mut self,
544        _args: &DiskSearchScratchArgs<ProviderFactory>,
545    ) -> Result<(), Self::Error> {
546        self.distance_cache.clear();
547        self.vertex_provider.clear();
548        Ok(())
549    }
550}
551
552pub struct DiskAccessor<'a, Data, VP>
553where
554    Data: GraphDataType<VectorIdType = u32>,
555    VP: VertexProvider<Data>,
556{
557    provider: &'a DiskProvider<Data>,
558    io_tracker: &'a IOTracker,
559    scratch: PoolOption<DiskSearchScratch<Data, VP>>,
560    query: &'a [Data::VectorDataType],
561}
562
563impl<Data, VP> DiskAccessor<'_, Data, VP>
564where
565    Data: GraphDataType<VectorIdType = u32>,
566    VP: VertexProvider<Data>,
567{
568    // Compute the PQ distance between each ID in `ids` and the distance table stored in
569    // `self`, invoking the callback with the results of each computation in order.
570    fn pq_distances<F>(&mut self, ids: &[u32], mut f: F) -> ANNResult<()>
571    where
572        F: FnMut(f32, u32),
573    {
574        let pq_scratch = &mut self.scratch.pq_scratch;
575        compute_pq_distance(
576            ids,
577            self.provider.pq_data.get_num_chunks(),
578            &pq_scratch.aligned_pqtable_dist_scratch,
579            self.provider.pq_data.pq_compressed_data().get_data(),
580            &mut pq_scratch.aligned_pq_coord_scratch,
581            &mut pq_scratch.aligned_dist_scratch,
582        )?;
583
584        for (i, id) in ids.iter().enumerate() {
585            let distance = self.scratch.pq_scratch.aligned_dist_scratch[i];
586            f(distance, *id);
587        }
588
589        Ok(())
590    }
591}
592
593impl<Data, VP> SearchExt for DiskAccessor<'_, Data, VP>
594where
595    Data: GraphDataType<VectorIdType = u32>,
596    VP: VertexProvider<Data>,
597{
598    async fn starting_points(&self) -> ANNResult<Vec<u32>> {
599        let start_vertex_id = self.provider.graph_header.metadata().medoid as u32;
600        Ok(vec![start_vertex_id])
601    }
602
603    fn terminate_early(&mut self) -> bool {
604        self.io_tracker.io_count() > self.provider.search_io_limit
605    }
606}
607
608impl<'a, Data, VP> DiskAccessor<'a, Data, VP>
609where
610    Data: GraphDataType<VectorIdType = u32>,
611    VP: VertexProvider<Data>,
612{
613    fn new<VPF>(
614        provider: &'a DiskProvider<Data>,
615        io_tracker: &'a IOTracker,
616        query: &'a [Data::VectorDataType],
617        vertex_provider_factory: &'a VPF,
618        scratch_pool: &'a Arc<ObjectPool<DiskSearchScratch<Data, VP>>>,
619    ) -> ANNResult<Self>
620    where
621        VPF: VertexProviderFactory<Data, VertexProviderType = VP>,
622    {
623        let mut scratch = PoolOption::try_pooled(
624            scratch_pool,
625            &DiskSearchScratchArgs {
626                graph_degree: provider.graph_header.max_degree::<Data::VectorDataType>()?,
627                dim: provider.graph_header.metadata().dims,
628                num_pq_chunks: provider.pq_data.get_num_chunks(),
629                num_pq_centers: provider.pq_data.get_num_centers(),
630                vertex_factory: vertex_provider_factory,
631                graph_header: &provider.graph_header,
632            },
633        )?;
634
635        scratch.pq_scratch.set(
636            provider.graph_header.metadata().dims,
637            query,
638            1.0_f32, // Normalization factor
639        )?;
640        let start_vertex_id = provider.graph_header.metadata().medoid as u32;
641
642        let timer = Instant::now();
643        quantizer_preprocess(
644            &mut scratch.pq_scratch,
645            &provider.pq_data,
646            provider.metric,
647            &[start_vertex_id],
648        )?;
649        IOTracker::add_time(
650            &io_tracker.preprocess_time_us,
651            timer.elapsed().as_micros() as u64,
652        );
653
654        Ok(Self {
655            provider,
656            io_tracker,
657            scratch,
658            query,
659        })
660    }
661    fn ensure_loaded(&mut self, ids: &[u32]) -> Result<(), ANNError> {
662        if ids.is_empty() {
663            return Ok(());
664        }
665        let scratch = &mut self.scratch;
666        let timer = Instant::now();
667        ensure_vertex_loaded(&mut scratch.vertex_provider, ids)?;
668        IOTracker::add_time(
669            &self.io_tracker.io_time_us,
670            timer.elapsed().as_micros() as u64,
671        );
672        self.io_tracker.add_io_count(ids.len());
673        for id in ids {
674            let distance = self
675                .provider
676                .distance_comparer
677                .evaluate_similarity(self.query, scratch.vertex_provider.get_vector(id)?);
678            let associated_data = *scratch.vertex_provider.get_associated_data(id)?;
679            scratch
680                .distance_cache
681                .insert(*id, (distance, associated_data));
682        }
683        Ok(())
684    }
685}
686
687impl<Data, VP> HasId for DiskAccessor<'_, Data, VP>
688where
689    Data: GraphDataType<VectorIdType = u32>,
690    VP: VertexProvider<Data>,
691{
692    type Id = u32;
693}
694
695impl<Data, VP> Accessor for DiskAccessor<'_, Data, VP>
696where
697    Data: GraphDataType<VectorIdType = u32>,
698    VP: VertexProvider<Data>,
699{
700    /// This accessor returns raw slices. There *is* a chance of racing when the fast
701    /// providers are used. We just have to live with it.
702    type Element<'a>
703        = &'a [u8]
704    where
705        Self: 'a;
706
707    /// `ElementRef` can have arbitrary lifetimes.
708    type ElementRef<'a> = &'a [u8];
709
710    /// Choose to panic on an out-of-bounds access rather than propagate an error.
711    type GetError = ANNError;
712
713    fn get_element(
714        &mut self,
715        id: Self::Id,
716    ) -> impl Future<Output = Result<Self::Element<'_>, Self::GetError>> + Send {
717        std::future::ready(self.provider.pq_data.get_compressed_vector(id as usize))
718    }
719}
720
721impl<Data, VP> IdIterator<Range<u32>> for DiskAccessor<'_, Data, VP>
722where
723    Data: GraphDataType<VectorIdType = u32>,
724    VP: VertexProvider<Data>,
725{
726    async fn id_iterator(&mut self) -> Result<Range<u32>, ANNError> {
727        Ok(0..self.provider.num_points as u32)
728    }
729}
730
731impl<'a, 'b, Data, VP> DelegateNeighbor<'a> for DiskAccessor<'b, Data, VP>
732where
733    Data: GraphDataType<VectorIdType = u32>,
734    VP: VertexProvider<Data>,
735{
736    type Delegate = AsNeighborAccessor<'a, 'b, Data, VP>;
737    fn delegate_neighbor(&'a mut self) -> Self::Delegate {
738        AsNeighborAccessor(self)
739    }
740}
741
742/// A light-weight wrapper around `&mut DiskAccessor` used to tailor the semantics of
743/// [`NeighborAccessor`].
744///
745/// This implementation ensures that the vector data for adjacency lists is also retrieved
746/// and cached to enhance reranking.
747pub struct AsNeighborAccessor<'a, 'b, Data, VP>(&'a mut DiskAccessor<'b, Data, VP>)
748where
749    Data: GraphDataType<VectorIdType = u32>,
750    VP: VertexProvider<Data>;
751
752impl<Data, VP> HasId for AsNeighborAccessor<'_, '_, Data, VP>
753where
754    Data: GraphDataType<VectorIdType = u32>,
755    VP: VertexProvider<Data>,
756{
757    type Id = u32;
758}
759
760impl<Data, VP> NeighborAccessor for AsNeighborAccessor<'_, '_, Data, VP>
761where
762    Data: GraphDataType<VectorIdType = u32>,
763    VP: VertexProvider<Data>,
764{
765    fn get_neighbors(
766        self,
767        id: Self::Id,
768        neighbors: &mut AdjacencyList<Self::Id>,
769    ) -> impl Future<Output = ANNResult<Self>> + Send {
770        if self.0.io_tracker.io_count() > self.0.provider.search_io_limit {
771            return future::ok(self); // Returning empty results in `neighbors` out param if IO limit is reached.
772        }
773
774        if let Err(e) = ensure_vertex_loaded(&mut self.0.scratch.vertex_provider, &[id]) {
775            return future::err(e);
776        }
777        let list = match self.0.scratch.vertex_provider.get_adjacency_list(&id) {
778            Ok(list) => list,
779            Err(e) => return future::err(e),
780        };
781        neighbors.overwrite_trusted(list);
782        future::ok(self)
783    }
784}
785
786/// [`DiskIndexSearcher`] is a helper class to make it easy to construct index
787/// and do repeated search operations. It is a wrapper around the index.
788/// This is useful for drivers such as search_disk_index.exe in tools.
789pub struct DiskIndexSearcher<
790    Data,
791    ProviderFactory = DiskVertexProviderFactory<Data, AlignedFileReaderFactory>,
792> where
793    Data: GraphDataType<VectorIdType = u32>,
794    ProviderFactory: VertexProviderFactory<Data>,
795{
796    index: DiskANNIndex<DiskProvider<Data>>,
797    runtime: Runtime,
798
799    /// The vertex provider factory is used to create the vertex provider for each search instance.
800    vertex_provider_factory: ProviderFactory,
801
802    /// Scratch pool for disk search operations that need allocations.
803    scratch_pool: Arc<ObjectPool<DiskSearchScratch<Data, ProviderFactory::VertexProviderType>>>,
804}
805
806#[derive(Debug)]
807pub struct SearchResultStats {
808    pub cmps: u32,
809    pub result_count: u32,
810    pub query_statistics: QueryStatistics,
811}
812
813/// `SearchResult` is a struct representing the result of a search operation.
814///
815/// It contains a list of vector results and a statistics object
816///
817pub struct SearchResult<AssociatedData> {
818    /// A list of nearest neighbors resulting from the search.
819    pub results: Vec<SearchResultItem<AssociatedData>>,
820    pub stats: SearchResultStats,
821}
822
823/// `VectorResult` is a struct representing a nearest neighbor resulting from a search.
824///
825/// It contains the vertex id, associated data, and the distance to the query vector.
826///
827pub struct SearchResultItem<AssociatedData> {
828    /// The vertex id of the nearest neighbor.
829    pub vertex_id: u32,
830    /// The associated data of the nearest neighbor as a fixed size byte array.
831    /// The length is determined when the index is created.
832    pub data: AssociatedData,
833    /// The distance between the nearest neighbor and the query vector.
834    pub distance: f32,
835}
836
837impl<Data, ProviderFactory> DiskIndexSearcher<Data, ProviderFactory>
838where
839    Data: GraphDataType<VectorIdType = u32>,
840    ProviderFactory: VertexProviderFactory<Data>,
841{
842    /// Create a new asynchronous disk searcher instance.
843    ///
844    /// # Arguments
845    /// * `num_threads` - The maximum number of threads to use.
846    /// * `search_io_limit` - I/O operation limit.
847    /// * `disk_index_reader` - The disk index reader.
848    /// * `vertex_provider_factory` - The vertex provider factory.
849    /// * `metric` - Distance metric used for vector similarity calculations.
850    /// * `runtime` - Tokio runtime handle for executing async operations.
851    pub fn new(
852        num_threads: usize,
853        search_io_limit: usize,
854        disk_index_reader: &DiskIndexReader<Data::VectorDataType>,
855        vertex_provider_factory: ProviderFactory,
856        metric: Metric,
857        runtime: Option<Runtime>,
858    ) -> ANNResult<Self> {
859        let runtime = match runtime {
860            Some(rt) => rt,
861            None => tokio::runtime::Builder::new_current_thread().build()?,
862        };
863
864        let graph_header = vertex_provider_factory.get_header()?;
865        let metadata = graph_header.metadata();
866        let max_degree = graph_header.max_degree::<Data::VectorDataType>()? as u32;
867
868        let config = graph::config::Builder::new(
869            max_degree.into_usize(),
870            graph::config::MaxDegree::default_slack(),
871            1, // build-search-list-size
872            metric.into(),
873        )
874        .build()?;
875
876        debug!("Creating DiskIndexSearcher with index_config: {:?}", config);
877
878        let graph_header = vertex_provider_factory.get_header()?;
879        let pq_data = disk_index_reader.get_pq_data();
880        let scratch_pool_args = DiskSearchScratchArgs {
881            graph_degree: graph_header.max_degree::<Data::VectorDataType>()?,
882            dim: graph_header.metadata().dims,
883            num_pq_chunks: pq_data.get_num_chunks(),
884            num_pq_centers: pq_data.get_num_centers(),
885            vertex_factory: &vertex_provider_factory,
886            graph_header: &graph_header,
887        };
888        let scratch_pool = Arc::new(ObjectPool::try_new(&scratch_pool_args, 0, None)?);
889
890        let disk_provider = DiskProvider::new(
891            disk_index_reader,
892            graph_header,
893            metric,
894            metadata.num_pts.into_usize(),
895            search_io_limit,
896        )?;
897
898        let index = DiskANNIndex::new(config, disk_provider, NonZeroUsize::new(num_threads));
899        Ok(Self {
900            index,
901            runtime,
902            vertex_provider_factory,
903            scratch_pool,
904        })
905    }
906
907    /// Helper method to create a DiskSearchStrategy with common parameters
908    fn search_strategy<'a>(
909        &'a self,
910        query: &'a [Data::VectorDataType],
911        vector_filter: &'a (dyn Fn(&Data::VectorIdType) -> bool + Send + Sync),
912    ) -> DiskSearchStrategy<'a, Data, ProviderFactory> {
913        DiskSearchStrategy {
914            io_tracker: IOTracker::default(),
915            vector_filter,
916            query,
917            vertex_provider_factory: &self.vertex_provider_factory,
918            scratch_pool: &self.scratch_pool,
919        }
920    }
921
922    /// Perform a search on the disk index.
923    /// return the list of nearest neighbors and associated data.
924    pub fn search(
925        &self,
926        query: &[Data::VectorDataType],
927        return_list_size: u32,
928        search_list_size: u32,
929        beam_width: Option<usize>,
930        vector_filter: Option<VectorFilter<Data>>,
931        is_flat_search: bool,
932    ) -> ANNResult<SearchResult<Data::AssociatedDataType>> {
933        let mut query_stats = QueryStatistics::default();
934        let mut indices = vec![0u32; return_list_size as usize];
935        let mut distances = vec![0f32; return_list_size as usize];
936        let mut associated_data =
937            vec![Data::AssociatedDataType::default(); return_list_size as usize];
938
939        let stats = self.search_internal(
940            query,
941            return_list_size as usize,
942            search_list_size,
943            beam_width,
944            &mut query_stats,
945            &mut indices,
946            &mut distances,
947            &mut associated_data,
948            &vector_filter.unwrap_or(default_vector_filter::<Data>()),
949            is_flat_search,
950        )?;
951
952        let mut search_result = SearchResult {
953            results: Vec::with_capacity(return_list_size as usize),
954            stats,
955        };
956
957        for ((vertex_id, distance), associated_data) in indices
958            .into_iter()
959            .zip(distances.into_iter())
960            .zip(associated_data.into_iter())
961        {
962            search_result.results.push(SearchResultItem {
963                vertex_id,
964                distance,
965                data: associated_data,
966            });
967        }
968
969        Ok(search_result)
970    }
971
972    /// Perform a raw search on the disk index.
973    /// This is a lower-level API that allows more control over the search parameters and output buffers.
974    #[allow(clippy::too_many_arguments)]
975    pub(crate) fn search_internal(
976        &self,
977        query: &[Data::VectorDataType],
978        k_value: usize,
979        search_list_size: u32,
980        beam_width: Option<usize>,
981        query_stats: &mut QueryStatistics,
982        indices: &mut [u32],
983        distances: &mut [f32],
984        associated_data: &mut [Data::AssociatedDataType],
985        vector_filter: &(dyn Fn(&Data::VectorIdType) -> bool + Send + Sync),
986        is_flat_search: bool,
987    ) -> ANNResult<SearchResultStats> {
988        let mut result_output_buffer = search_output_buffer::IdDistanceAssociatedData::new(
989            &mut indices[..k_value],
990            &mut distances[..k_value],
991            &mut associated_data[..k_value],
992        );
993
994        let strategy = self.search_strategy(query, vector_filter);
995        let timer = Instant::now();
996        let k = k_value;
997        let l = search_list_size as usize;
998        let stats = if is_flat_search {
999            self.runtime.block_on(self.index.flat_search(
1000                &strategy,
1001                &DefaultContext,
1002                strategy.query,
1003                vector_filter,
1004                &Knn::new(k, l, beam_width)?,
1005                &mut result_output_buffer,
1006            ))?
1007        } else {
1008            let knn_search = Knn::new(k, l, beam_width)?;
1009            self.runtime.block_on(self.index.search(
1010                knn_search,
1011                &strategy,
1012                &DefaultContext,
1013                strategy.query,
1014                &mut result_output_buffer,
1015            ))?
1016        };
1017        query_stats.total_comparisons = stats.cmps;
1018        query_stats.search_hops = stats.hops;
1019
1020        query_stats.total_execution_time_us = timer.elapsed().as_micros();
1021        query_stats.io_time_us = IOTracker::time(&strategy.io_tracker.io_time_us) as u128;
1022        query_stats.total_io_operations = strategy.io_tracker.io_count() as u32;
1023        query_stats.total_vertices_loaded = strategy.io_tracker.io_count() as u32;
1024        query_stats.query_pq_preprocess_time_us =
1025            IOTracker::time(&strategy.io_tracker.preprocess_time_us) as u128;
1026        query_stats.cpu_time_us = query_stats.total_execution_time_us
1027            - query_stats.io_time_us
1028            - query_stats.query_pq_preprocess_time_us;
1029        Ok(SearchResultStats {
1030            cmps: query_stats.total_comparisons,
1031            result_count: stats.result_count,
1032            query_statistics: query_stats.clone(),
1033        })
1034    }
1035}
1036
1037/// Helper function to ensure vertices are loaded and processed.
1038///
1039/// This is a convenience function that combines `load_vertices` and `process_loaded_node`
1040/// for each vertex ID. It first loads all the vertices in batch, then processes each
1041/// loaded node.
1042fn ensure_vertex_loaded<Data: GraphDataType, V: VertexProvider<Data>>(
1043    vertex_provider: &mut V,
1044    ids: &[Data::VectorIdType],
1045) -> ANNResult<()> {
1046    vertex_provider.load_vertices(ids)?;
1047    for (idx, id) in ids.iter().enumerate() {
1048        vertex_provider.process_loaded_node(id, idx)?;
1049    }
1050    Ok(())
1051}
1052
1053#[cfg(test)]
1054mod disk_provider_tests {
1055    use diskann::{
1056        graph::{
1057            search::{record::VisitedSearchRecord, Knn},
1058            KnnSearchError,
1059        },
1060        utils::IntoUsize,
1061        ANNErrorKind,
1062    };
1063    use diskann_providers::storage::{
1064        DynWriteProvider, StorageReadProvider, VirtualStorageProvider,
1065    };
1066    use diskann_providers::{
1067        common::AlignedBoxWithSlice,
1068        test_utils::graph_data_type_utils::{
1069            GraphDataF32VectorU32Data, GraphDataF32VectorUnitData,
1070        },
1071        utils::{create_thread_pool, load_aligned_bin, PQPathNames, ParallelIteratorInPool},
1072    };
1073    use diskann_utils::{io::read_bin, test_data_root};
1074    use diskann_vector::distance::Metric;
1075    use rayon::prelude::{IndexedParallelIterator, IntoParallelRefIterator};
1076    use rstest::rstest;
1077    use vfs::OverlayFS;
1078
1079    use super::*;
1080    use crate::{
1081        build::builder::core::disk_index_builder_tests::{IndexBuildFixture, TestParams},
1082        utils::{QueryStatistics, VirtualAlignedReaderFactory},
1083    };
1084
1085    const TEST_INDEX_PREFIX_128DIM: &str =
1086        "/disk_index_search/disk_index_sift_learn_R4_L50_A1.2_truth_search";
1087    const TEST_INDEX_128DIM: &str =
1088        "/disk_index_search/disk_index_sift_learn_R4_L50_A1.2_truth_search_disk.index";
1089    const TEST_PQ_PIVOT_128DIM: &str =
1090        "/disk_index_search/disk_index_sift_learn_R4_L50_A1.2_truth_search_pq_pivots.bin";
1091    const TEST_PQ_COMPRESSED_128DIM: &str =
1092        "/disk_index_search/disk_index_sift_learn_R4_L50_A1.2_truth_search_pq_compressed.bin";
1093    const TEST_TRUTH_RESULT_10PTS_128DIM: &str =
1094        "/disk_index_search/disk_index_10pts_idx_uint32_truth_search_res.bin";
1095    const TEST_QUERY_10PTS_128DIM: &str = "/disk_index_search/disk_index_sample_query_10pts.fbin";
1096
1097    const TEST_INDEX_PREFIX_100DIM: &str = "/disk_index_search/256pts_100dim_f32_truth_Index";
1098    const TEST_INDEX_100DIM: &str = "/disk_index_search/256pts_100dim_f32_truth_Index_disk.index";
1099    const TEST_PQ_PIVOT_100DIM: &str =
1100        "/disk_index_search/256pts_100dim_f32_truth_Index_pq_pivots.bin";
1101    const TEST_PQ_COMPRESSED_100DIM: &str =
1102        "/disk_index_search/256pts_100dim_f32_truth_Index_pq_compressed.bin";
1103    const TEST_TRUTH_RESULT_10PTS_100DIM: &str =
1104        "/disk_index_search/256pts_100dim_f32_truth_query_result.bin";
1105    const TEST_QUERY_10PTS_100DIM: &str = "/disk_index_search/10pts_100dim_f32_base_query.bin";
1106    const TEST_DATA_FILE: &str = "/disk_index_search/disk_index_siftsmall_learn_256pts_data.fbin";
1107    const TEST_INDEX: &str =
1108        "/disk_index_search/disk_index_sift_learn_R4_L50_A1.2_truth_search_disk.index";
1109    const TEST_INDEX_PREFIX: &str =
1110        "/disk_index_search/disk_index_sift_learn_R4_L50_A1.2_truth_search";
1111    const TEST_PQ_PIVOT: &str =
1112        "/disk_index_search/disk_index_sift_learn_R4_L50_A1.2_truth_search_pq_pivots.bin";
1113    const TEST_PQ_COMPRESSED: &str =
1114        "/disk_index_search/disk_index_sift_learn_R4_L50_A1.2_truth_search_pq_compressed.bin";
1115
1116    #[test]
1117    fn test_disk_search_k10_l20_single_or_multi_thread_100dim() {
1118        let storage_provider = Arc::new(VirtualStorageProvider::new_overlay(test_data_root()));
1119
1120        let search_engine = create_disk_index_searcher(
1121            CreateDiskIndexSearcherParams {
1122                max_thread_num: 5,
1123                pq_pivot_file_path: TEST_PQ_PIVOT_100DIM,
1124                pq_compressed_file_path: TEST_PQ_COMPRESSED_100DIM,
1125                index_path: TEST_INDEX_100DIM,
1126                index_path_prefix: TEST_INDEX_PREFIX_100DIM,
1127                ..Default::default()
1128            },
1129            &storage_provider,
1130        );
1131        // Test single thread.
1132        test_disk_search(TestDiskSearchParams {
1133            storage_provider: storage_provider.as_ref(),
1134            index_search_engine: &search_engine,
1135            thread_num: 1,
1136            query_file_path: TEST_QUERY_10PTS_100DIM,
1137            truth_result_file_path: TEST_TRUTH_RESULT_10PTS_100DIM,
1138            k: 10,
1139            l: 20,
1140            dim: 104,
1141        });
1142        // Test multi thread.
1143        test_disk_search(TestDiskSearchParams {
1144            storage_provider: storage_provider.as_ref(),
1145            index_search_engine: &search_engine,
1146            thread_num: 5,
1147            query_file_path: TEST_QUERY_10PTS_100DIM,
1148            truth_result_file_path: TEST_TRUTH_RESULT_10PTS_100DIM,
1149            k: 10,
1150            l: 20,
1151            dim: 104,
1152        });
1153    }
1154
1155    #[test]
1156    fn test_disk_search_k10_l20_single_or_multi_thread_128dim() {
1157        let storage_provider = Arc::new(VirtualStorageProvider::new_overlay(test_data_root()));
1158
1159        let search_engine = create_disk_index_searcher::<GraphDataF32VectorUnitData>(
1160            CreateDiskIndexSearcherParams {
1161                max_thread_num: 5,
1162                pq_pivot_file_path: TEST_PQ_PIVOT_128DIM,
1163                pq_compressed_file_path: TEST_PQ_COMPRESSED_128DIM,
1164                index_path: TEST_INDEX_128DIM,
1165                index_path_prefix: TEST_INDEX_PREFIX_128DIM,
1166                ..Default::default()
1167            },
1168            &storage_provider,
1169        );
1170        // Test single thread.
1171        test_disk_search(TestDiskSearchParams {
1172            storage_provider: storage_provider.as_ref(),
1173            index_search_engine: &search_engine,
1174            thread_num: 1,
1175            query_file_path: TEST_QUERY_10PTS_128DIM,
1176            truth_result_file_path: TEST_TRUTH_RESULT_10PTS_128DIM,
1177            k: 10,
1178            l: 20,
1179            dim: 128,
1180        });
1181        // Test multi thread.
1182        test_disk_search(TestDiskSearchParams {
1183            storage_provider: storage_provider.as_ref(),
1184            index_search_engine: &search_engine,
1185            thread_num: 5,
1186            query_file_path: TEST_QUERY_10PTS_128DIM,
1187            truth_result_file_path: TEST_TRUTH_RESULT_10PTS_128DIM,
1188            k: 10,
1189            l: 20,
1190            dim: 128,
1191        });
1192    }
1193
1194    fn get_truth_associated_data<StorageReader: StorageReadProvider>(
1195        storage_provider: &StorageReader,
1196    ) -> Vec<u32> {
1197        const ASSOCIATED_DATA_FILE: &str = "/sift/siftsmall_learn_256pts_u32_associated_data.fbin";
1198
1199        let data =
1200            read_bin::<u32>(&mut storage_provider.open_reader(ASSOCIATED_DATA_FILE).unwrap())
1201                .unwrap();
1202        data.into_inner().into_vec()
1203    }
1204
1205    #[test]
1206    fn test_disk_search_with_associated_data_k10_l20_single_or_multi_thread_128dim() {
1207        let storage_provider = VirtualStorageProvider::new_overlay(test_data_root());
1208        let index_path_prefix = "/disk_index_search/disk_index_sift_learn_R4_L50_A1.2_test_disk_index_search_associated_data";
1209        let params = TestParams {
1210            data_path: TEST_DATA_FILE.to_string(),
1211            index_path_prefix: index_path_prefix.to_string(),
1212            associated_data_path: Some(
1213                "/sift/siftsmall_learn_256pts_u32_associated_data.fbin".to_string(),
1214            ),
1215            ..TestParams::default()
1216        };
1217        let fixture = IndexBuildFixture::new(storage_provider, params).unwrap();
1218        // Build the index with the associated data
1219        fixture.build::<GraphDataF32VectorU32Data>().unwrap();
1220        {
1221            let search_engine = create_disk_index_searcher::<GraphDataF32VectorU32Data>(
1222                CreateDiskIndexSearcherParams {
1223                    max_thread_num: 5,
1224                    pq_pivot_file_path: format!("{}_pq_pivots.bin", index_path_prefix).as_str(),
1225                    pq_compressed_file_path: format!("{}_pq_compressed.bin", index_path_prefix)
1226                        .as_str(),
1227                    index_path: format!("{}_disk.index", index_path_prefix).as_str(), //TEST_INDEX_128DIM,
1228                    index_path_prefix,
1229                    ..Default::default()
1230                },
1231                &fixture.storage_provider,
1232            );
1233
1234            // Test single thread.
1235            test_disk_search_with_associated(
1236                TestDiskSearchAssociateParams {
1237                    storage_provider: fixture.storage_provider.as_ref(),
1238                    index_search_engine: &search_engine,
1239                    thread_num: 1,
1240                    query_file_path: TEST_QUERY_10PTS_128DIM,
1241                    truth_result_file_path: TEST_TRUTH_RESULT_10PTS_128DIM,
1242                    k: 10,
1243                    l: 20,
1244                    dim: 128,
1245                },
1246                None,
1247            );
1248
1249            // Test multi thread.
1250            test_disk_search_with_associated(
1251                TestDiskSearchAssociateParams {
1252                    storage_provider: fixture.storage_provider.as_ref(),
1253                    index_search_engine: &search_engine,
1254                    thread_num: 5,
1255                    query_file_path: TEST_QUERY_10PTS_128DIM,
1256                    truth_result_file_path: TEST_TRUTH_RESULT_10PTS_128DIM,
1257                    k: 10,
1258                    l: 20,
1259                    dim: 128,
1260                },
1261                None,
1262            );
1263        }
1264
1265        fixture
1266            .storage_provider
1267            .delete(&format!("{}_disk.index", index_path_prefix))
1268            .expect("Failed to delete file");
1269        fixture
1270            .storage_provider
1271            .delete(&format!("{}_pq_pivots.bin", index_path_prefix))
1272            .expect("Failed to delete file");
1273        fixture
1274            .storage_provider
1275            .delete(&format!("{}_pq_compressed.bin", index_path_prefix))
1276            .expect("Failed to delete file");
1277    }
1278
1279    struct CreateDiskIndexSearcherParams<'a> {
1280        max_thread_num: usize,
1281        pq_pivot_file_path: &'a str,
1282        pq_compressed_file_path: &'a str,
1283        index_path: &'a str,
1284        index_path_prefix: &'a str,
1285        io_limit: usize,
1286    }
1287
1288    impl Default for CreateDiskIndexSearcherParams<'_> {
1289        fn default() -> Self {
1290            Self {
1291                max_thread_num: 1,
1292                pq_pivot_file_path: "",
1293                pq_compressed_file_path: "",
1294                index_path: "",
1295                index_path_prefix: "",
1296                io_limit: usize::MAX,
1297            }
1298        }
1299    }
1300
1301    fn create_disk_index_searcher<Data>(
1302        params: CreateDiskIndexSearcherParams,
1303        storage_provider: &Arc<VirtualStorageProvider<OverlayFS>>,
1304    ) -> DiskIndexSearcher<
1305        Data,
1306        DiskVertexProviderFactory<Data, VirtualAlignedReaderFactory<OverlayFS>>,
1307    >
1308    where
1309        Data: GraphDataType<VectorIdType = u32>,
1310    {
1311        assert!(params.io_limit > 0);
1312
1313        let runtime = tokio::runtime::Builder::new_multi_thread()
1314            .worker_threads(params.max_thread_num)
1315            .build()
1316            .unwrap();
1317
1318        let disk_index_reader = DiskIndexReader::<Data::VectorDataType>::new(
1319            params.pq_pivot_file_path.to_string(),
1320            params.pq_compressed_file_path.to_string(),
1321            storage_provider.as_ref(),
1322        )
1323        .unwrap();
1324
1325        let aligned_reader_factory = VirtualAlignedReaderFactory::new(
1326            get_disk_index_file(params.index_path_prefix),
1327            Arc::clone(storage_provider),
1328        );
1329        let caching_strategy = CachingStrategy::None;
1330        let vertex_provider_factory =
1331            DiskVertexProviderFactory::<Data, _>::new(aligned_reader_factory, caching_strategy)
1332                .unwrap();
1333
1334        DiskIndexSearcher::<Data, DiskVertexProviderFactory<Data, _>>::new(
1335            params.max_thread_num,
1336            params.io_limit,
1337            &disk_index_reader,
1338            vertex_provider_factory,
1339            Metric::L2,
1340            Some(runtime),
1341        )
1342        .unwrap()
1343    }
1344
1345    fn load_query_result<StorageReader: StorageReadProvider>(
1346        storage_provider: &StorageReader,
1347        query_result_path: &str,
1348    ) -> Vec<u32> {
1349        let result =
1350            read_bin::<u32>(&mut storage_provider.open_reader(query_result_path).unwrap()).unwrap();
1351        result.into_inner().into_vec()
1352    }
1353
    /// Inputs for `test_disk_search`: a searcher whose associated data type is `()`, the query
    /// and ground-truth files, and the search parameters to use.
    struct TestDiskSearchParams<'a, StorageType> {
        /// Storage the query and ground-truth files are read from.
        storage_provider: &'a StorageType,
        /// Searcher under test.
        index_search_engine: &'a DiskIndexSearcher<
            GraphDataF32VectorUnitData,
            DiskVertexProviderFactory<
                GraphDataF32VectorUnitData,
                VirtualAlignedReaderFactory<OverlayFS>,
            >,
        >,
        /// Number of threads in the pool used to issue the queries.
        thread_num: u64,
        /// Path of the query vectors file.
        query_file_path: &'a str,
        /// Path of the ground-truth neighbor ids; `k` consecutive ids per query.
        truth_result_file_path: &'a str,
        /// Number of neighbors requested per query (k).
        k: usize,
        /// Search list size (L).
        l: usize,
        /// Row width used to split the loaded query file into individual queries.
        dim: usize,
    }

    /// Inputs for `test_disk_search_with_associated`: same as `TestDiskSearchParams`, but for a
    /// searcher whose associated data type is `u32`.
    struct TestDiskSearchAssociateParams<'a, StorageType> {
        /// Storage the query, ground-truth and associated-data files are read from.
        storage_provider: &'a StorageType,
        /// Searcher under test.
        index_search_engine: &'a DiskIndexSearcher<
            GraphDataF32VectorU32Data,
            DiskVertexProviderFactory<
                GraphDataF32VectorU32Data,
                VirtualAlignedReaderFactory<OverlayFS>,
            >,
        >,
        /// Number of threads in the pool used to issue the queries.
        thread_num: u64,
        /// Path of the query vectors file.
        query_file_path: &'a str,
        /// Path of the ground-truth neighbor ids; `k` consecutive ids per query.
        truth_result_file_path: &'a str,
        /// Number of neighbors requested per query (k).
        k: usize,
        /// Search list size (L).
        l: usize,
        /// Row width used to split the loaded query file into individual queries.
        dim: usize,
    }
1387
1388    fn test_disk_search<StorageType: StorageReadProvider>(
1389        params: TestDiskSearchParams<StorageType>,
1390    ) {
1391        let query_vector = load_aligned_bin(params.storage_provider, params.query_file_path)
1392            .unwrap()
1393            .0;
1394        let mut aligned_query = AlignedBoxWithSlice::<f32>::new(query_vector.len(), 32).unwrap();
1395        aligned_query.memcpy(query_vector.as_slice()).unwrap();
1396
1397        let queries = aligned_query
1398            .split_into_nonoverlapping_mut_slices(0..aligned_query.len(), params.dim)
1399            .unwrap();
1400
1401        let truth_result =
1402            load_query_result(params.storage_provider, params.truth_result_file_path);
1403
1404        let pool = create_thread_pool(params.thread_num.into_usize()).unwrap();
1405        // Convert query_vector to number of Vertex with data type f32 and dimension equals to dim.
1406        queries
1407            .par_iter()
1408            .enumerate()
1409            .for_each_in_pool(&pool, |(i, query)| {
1410                // Test search_with_associated_data with an unaligned query. Some distance functions require aligned data.
1411                let mut aligned_box = AlignedBoxWithSlice::<f32>::new(query.len() + 1, 32).unwrap();
1412                let mut temp = Vec::with_capacity(query.len() + 1);
1413                temp.push(0.0);
1414                temp.extend_from_slice(query);
1415                aligned_box.memcpy(temp.as_slice()).unwrap();
1416                let query = &aligned_box.as_slice()[1..];
1417
1418                let mut query_stats = QueryStatistics::default();
1419                let mut indices = vec![0u32; 10];
1420                let mut distances = vec![0f32; 10];
1421                let mut associated_data = vec![(); 10];
1422
1423                let result = params
1424                    .index_search_engine
1425                    //.search_with_associated_data(query, params.k as u32, params.l as u32)
1426                    .search_internal(
1427                        query,
1428                        params.k,
1429                        params.l as u32,
1430                        None, // beam_width
1431                        &mut query_stats,
1432                        &mut indices,
1433                        &mut distances,
1434                        &mut associated_data,
1435                        &(|_| true),
1436                        false,
1437                    );
1438
1439                // Calculate the range of the truth_result for this query
1440                let truth_slice = &truth_result[i * params.k..(i + 1) * params.k];
1441
1442                assert!(result.is_ok(), "Expected search to succeed");
1443
1444                let result_unwrapped = result.unwrap();
1445                assert!(
1446                    result_unwrapped.query_statistics.total_io_operations > 0,
1447                    "Expected IO operations to be greater than 0"
1448                );
1449                assert!(
1450                    result_unwrapped.query_statistics.total_vertices_loaded > 0,
1451                    "Expected vertices loaded to be greater than 0"
1452                );
1453
1454                // Compare res with truth_slice using assert_eq!
1455                assert_eq!(
1456                    indices, truth_slice,
1457                    "Results DO NOT match with the truth result for query {}",
1458                    i
1459                );
1460            });
1461    }
1462
1463    fn test_disk_search_with_associated<StorageType: StorageReadProvider>(
1464        params: TestDiskSearchAssociateParams<StorageType>,
1465        beam_width: Option<usize>,
1466    ) {
1467        let query_vector = load_aligned_bin(params.storage_provider, params.query_file_path)
1468            .unwrap()
1469            .0;
1470        let mut aligned_query = AlignedBoxWithSlice::<f32>::new(query_vector.len(), 32).unwrap();
1471        aligned_query.memcpy(query_vector.as_slice()).unwrap();
1472        let queries = aligned_query
1473            .split_into_nonoverlapping_mut_slices(0..aligned_query.len(), params.dim)
1474            .unwrap();
1475        let truth_result =
1476            load_query_result(params.storage_provider, params.truth_result_file_path);
1477        let pool = create_thread_pool(params.thread_num.into_usize()).unwrap();
1478        // Convert query_vector to number of Vertex with data type f32 and dimension equals to dim.
1479        queries
1480            .par_iter()
1481            .enumerate()
1482            .for_each_in_pool(&pool, |(i, query)| {
1483                // Test search_with_associated_data with an unaligned query. Some distance functions require aligned data.
1484                let mut aligned_box = AlignedBoxWithSlice::<f32>::new(query.len() + 1, 32).unwrap();
1485                let mut temp = Vec::with_capacity(query.len() + 1);
1486                temp.push(0.0);
1487                temp.extend_from_slice(query);
1488                aligned_box.memcpy(temp.as_slice()).unwrap();
1489                let query = &aligned_box.as_slice()[1..];
1490                let result = params
1491                    .index_search_engine
1492                    .search(query, params.k as u32, params.l as u32, beam_width, None, false)
1493                    .unwrap();
1494                let indices: Vec<u32> = result.results.iter().map(|item| item.vertex_id).collect();
1495                let associated_data: Vec<u32> =
1496                    result.results.iter().map(|item| item.data).collect();
1497                let truth_data = get_truth_associated_data(params.storage_provider);
1498                let associated_data_truth: Vec<u32> = indices
1499                    .iter()
1500                    .map(|&vid| truth_data[vid as usize])
1501                    .collect();
1502                assert_eq!(
1503                    associated_data, associated_data_truth,
1504                    "Associated data DO NOT match with the truth result for query {}, associated_data from search: {:?}, associated_data from truth result: {:?}",
1505                    i,associated_data, associated_data_truth
1506                );
1507                let truth_slice = &truth_result[i * params.k..(i + 1) * params.k];
1508                assert_eq!(
1509                    indices, truth_slice,
1510                    "Results DO NOT match with the truth result for query {}",
1511                    i
1512                );
1513            });
1514    }
1515
1516    #[test]
1517    fn test_disk_search_invalid_input() {
1518        let storage_provider = Arc::new(VirtualStorageProvider::new_overlay(test_data_root()));
1519        let ctx = &DefaultContext;
1520
1521        let params = CreateDiskIndexSearcherParams {
1522            max_thread_num: 5,
1523            pq_pivot_file_path: TEST_PQ_PIVOT_128DIM,
1524            pq_compressed_file_path: TEST_PQ_COMPRESSED_128DIM,
1525            index_path: TEST_INDEX_128DIM,
1526            index_path_prefix: TEST_INDEX_PREFIX_128DIM,
1527            ..Default::default()
1528        };
1529
1530        let paths = PQPathNames::for_disk_index(TEST_INDEX_PREFIX_128DIM);
1531        assert_eq!(
1532            paths.pivots, params.pq_pivot_file_path,
1533            "pq_pivot_file_path is not correct"
1534        );
1535        assert_eq!(
1536            paths.compressed_data, params.pq_compressed_file_path,
1537            "pq_compressed_file_path is not correct"
1538        );
1539        assert_eq!(
1540            params.index_path,
1541            format!("{}_disk.index", params.index_path_prefix),
1542            "index_path is not correct"
1543        );
1544
1545        // Test error case: l < k
1546        let res = Knn::new_default(20, 10);
1547        assert!(res.is_err());
1548        assert_eq!(
1549            <KnnSearchError as std::convert::Into<ANNError>>::into(res.unwrap_err()).kind(),
1550            ANNErrorKind::IndexError
1551        );
1552        // Test error case: beam_width = 0
1553        let res = Knn::new(10, 10, Some(0));
1554        assert!(res.is_err());
1555
1556        let search_engine =
1557            create_disk_index_searcher::<GraphDataF32VectorU32Data>(params, &storage_provider);
1558
1559        // minor validation tests to improve code coverage
1560        assert_eq!(
1561            search_engine
1562                .index
1563                .data_provider
1564                .to_external_id(ctx, 0)
1565                .unwrap(),
1566            0
1567        );
1568        assert_eq!(
1569            search_engine
1570                .index
1571                .data_provider
1572                .to_internal_id(ctx, &0)
1573                .unwrap(),
1574            0
1575        );
1576
1577        let provider_max_degree = search_engine
1578            .index
1579            .data_provider
1580            .graph_header
1581            .max_degree::<<GraphDataF32VectorU32Data as GraphDataType>::VectorDataType>()
1582            .unwrap();
1583        let index_max_degree = search_engine.index.config.pruned_degree().get();
1584        assert_eq!(provider_max_degree, index_max_degree);
1585
1586        let query = vec![0f32; 128];
1587        let mut query_stats = QueryStatistics::default();
1588        let mut indices = vec![0u32; 10];
1589        let mut distances = vec![0f32; 10];
1590        let mut associated_data = vec![0u32; 10];
1591
1592        // Set L: {} to a value of at least K:
1593        let result = search_engine.search_internal(
1594            &query,
1595            10,
1596            10 - 1,
1597            None,
1598            &mut query_stats,
1599            &mut indices,
1600            &mut distances,
1601            &mut associated_data,
1602            &|_| true,
1603            false,
1604        );
1605
1606        assert!(result.is_err());
1607        assert_eq!(result.unwrap_err().kind(), ANNErrorKind::IndexError);
1608    }
1609
1610    #[test]
1611    fn test_disk_search_beam_search() {
1612        let storage_provider = Arc::new(VirtualStorageProvider::new_overlay(test_data_root()));
1613
1614        let search_engine = create_disk_index_searcher::<GraphDataF32VectorUnitData>(
1615            CreateDiskIndexSearcherParams {
1616                max_thread_num: 1,
1617                pq_pivot_file_path: TEST_PQ_PIVOT,
1618                pq_compressed_file_path: TEST_PQ_COMPRESSED,
1619                index_path: TEST_INDEX,
1620                index_path_prefix: TEST_INDEX_PREFIX,
1621                ..Default::default()
1622            },
1623            &storage_provider,
1624        );
1625
1626        let query_vector: [f32; 128] = [1f32; 128];
1627        let mut indices = vec![0u32; 10];
1628        let mut distances = vec![0f32; 10];
1629        let mut associated_data = vec![(); 10];
1630
1631        let mut result_output_buffer = search_output_buffer::IdDistanceAssociatedData::new(
1632            &mut indices,
1633            &mut distances,
1634            &mut associated_data,
1635        );
1636        let strategy = search_engine.search_strategy(&query_vector, &|_| true);
1637        let mut search_record = VisitedSearchRecord::new(0);
1638        let search_params = Knn::new(10, 10, Some(4)).unwrap();
1639        let recorded_search =
1640            diskann::graph::search::RecordedKnn::new(search_params, &mut search_record);
1641        search_engine
1642            .runtime
1643            .block_on(search_engine.index.search(
1644                recorded_search,
1645                &strategy,
1646                &DefaultContext,
1647                query_vector.as_slice(),
1648                &mut result_output_buffer,
1649            ))
1650            .unwrap();
1651
1652        let ids = search_record
1653            .visited
1654            .iter()
1655            .map(|n| n.id)
1656            .collect::<Vec<_>>();
1657
1658        const EXPECTED_NODES: [u32; 18] = [
1659            72, 118, 108, 86, 84, 152, 170, 82, 114, 87, 207, 176, 79, 153, 67, 165, 141, 180,
1660        ]; //Expected nodes for query = [1f32; 128] with beam_width=4
1661
1662        assert_eq!(ids, &EXPECTED_NODES);
1663
1664        let return_list_size = 10;
1665        let search_list_size = 10;
1666        let result = search_engine.search(
1667            &query_vector,
1668            return_list_size,
1669            search_list_size,
1670            Some(4),
1671            None,
1672            false,
1673        );
1674        assert!(result.is_ok(), "Expected search to succeed");
1675        let search_result = result.unwrap();
1676        assert_eq!(
1677            search_result.results.len() as u32,
1678            return_list_size,
1679            "Expected result count to match"
1680        );
1681        assert_eq!(
1682            indices,
1683            vec![152, 72, 170, 118, 87, 165, 79, 141, 108, 86],
1684            "Expected indices to match"
1685        );
1686    }
1687
1688    #[cfg(feature = "experimental_diversity_search")]
1689    #[test]
1690    fn test_disk_search_diversity_search() {
1691        use diskann::graph::DiverseSearchParams;
1692        use diskann::neighbor::AttributeValueProvider;
1693        use std::collections::HashMap;
1694
1695        // Simple test attribute provider
1696        #[derive(Debug, Clone)]
1697        struct TestAttributeProvider {
1698            attributes: HashMap<u32, u32>,
1699        }
1700        impl TestAttributeProvider {
1701            fn new() -> Self {
1702                Self {
1703                    attributes: HashMap::new(),
1704                }
1705            }
1706            fn insert(&mut self, id: u32, attribute: u32) {
1707                self.attributes.insert(id, attribute);
1708            }
1709        }
1710        impl diskann::provider::HasId for TestAttributeProvider {
1711            type Id = u32;
1712        }
1713
1714        impl AttributeValueProvider for TestAttributeProvider {
1715            type Value = u32;
1716
1717            fn get(&self, id: Self::Id) -> Option<Self::Value> {
1718                self.attributes.get(&id).copied()
1719            }
1720        }
1721
1722        let storage_provider = Arc::new(VirtualStorageProvider::new_overlay(test_data_root()));
1723
1724        let search_engine = create_disk_index_searcher::<GraphDataF32VectorUnitData>(
1725            CreateDiskIndexSearcherParams {
1726                max_thread_num: 1,
1727                pq_pivot_file_path: TEST_PQ_PIVOT,
1728                pq_compressed_file_path: TEST_PQ_COMPRESSED,
1729                index_path: TEST_INDEX,
1730                index_path_prefix: TEST_INDEX_PREFIX,
1731                ..Default::default()
1732            },
1733            &storage_provider,
1734        );
1735
1736        let query_vector: [f32; 128] = [1f32; 128];
1737
1738        // Create attribute provider with random labels (1 to 3) for all vectors
1739        let mut attribute_provider = TestAttributeProvider::new();
1740        let num_vectors = 256; // Number of vectors in the test dataset
1741        for i in 0..num_vectors {
1742            // Assign labels 1-3 based on modulo to ensure distribution
1743            let label = (i % 15) + 1;
1744            attribute_provider.insert(i, label);
1745        }
1746        // Wrap in Arc once to avoid cloning the HashMap later
1747        let attribute_provider = std::sync::Arc::new(attribute_provider);
1748
1749        let mut indices = vec![0u32; 10];
1750        let mut distances = vec![0f32; 10];
1751        let mut associated_data = vec![(); 10];
1752
1753        let mut result_output_buffer = search_output_buffer::IdDistanceAssociatedData::new(
1754            &mut indices,
1755            &mut distances,
1756            &mut associated_data,
1757        );
1758        let strategy = search_engine.search_strategy(&query_vector, &|_| true);
1759
1760        // Create diverse search parameters with attribute provider
1761        let diverse_params = DiverseSearchParams::new(
1762            0, // diverse_attribute_id
1763            3, // diverse_results_k
1764            attribute_provider.clone(),
1765        );
1766
1767        let search_params = Knn::new(10, 20, None).unwrap();
1768
1769        let diverse_search = diskann::graph::search::Diverse::new(search_params, diverse_params);
1770        let stats = search_engine
1771            .runtime
1772            .block_on(search_engine.index.search(
1773                diverse_search,
1774                &strategy,
1775                &DefaultContext,
1776                query_vector.as_slice(),
1777                &mut result_output_buffer,
1778            ))
1779            .unwrap();
1780
1781        // Verify that search was performed and returned some results
1782        assert!(
1783            stats.result_count > 0,
1784            "Expected to get some results during diversity search"
1785        );
1786
1787        let return_list_size = 10;
1788        let search_list_size = 20;
1789        let diverse_results_k = 1;
1790        let diverse_params = DiverseSearchParams::new(
1791            0, // diverse_attribute_id
1792            diverse_results_k,
1793            attribute_provider.clone(),
1794        );
1795
1796        // Test diverse search using the search API
1797        let mut indices2 = vec![0u32; return_list_size as usize];
1798        let mut distances2 = vec![0f32; return_list_size as usize];
1799        let mut associated_data2 = vec![(); return_list_size as usize];
1800        let mut result_output_buffer2 = search_output_buffer::IdDistanceAssociatedData::new(
1801            &mut indices2,
1802            &mut distances2,
1803            &mut associated_data2,
1804        );
1805        let strategy2 = search_engine.search_strategy(&query_vector, &|_| true);
1806        let search_params2 =
1807            Knn::new(return_list_size as usize, search_list_size as usize, None).unwrap();
1808
1809        let diverse_search2 = diskann::graph::search::Diverse::new(search_params2, diverse_params);
1810        let stats = search_engine
1811            .runtime
1812            .block_on(search_engine.index.search(
1813                diverse_search2,
1814                &strategy2,
1815                &DefaultContext,
1816                query_vector.as_slice(),
1817                &mut result_output_buffer2,
1818            ))
1819            .unwrap();
1820
1821        // Verify results
1822        assert!(
1823            stats.result_count > 0,
1824            "Expected diversity search to return results"
1825        );
1826        assert!(
1827            stats.result_count <= return_list_size,
1828            "Expected result count to be <= {}",
1829            return_list_size
1830        );
1831
1832        // Verify that we got some results
1833        assert!(
1834            stats.result_count > 0,
1835            "Expected to get some search results"
1836        );
1837
1838        // Print search results with their attributes
1839        println!("\n=== Diversity Search Results ===");
1840        println!("Query: [1f32; 128]");
1841        println!("diverse_results_k: {}", diverse_results_k);
1842        println!("Total results: {}\n", stats.result_count);
1843        println!("{:<10} {:<15} {:<10}", "Vertex ID", "Distance", "Label");
1844        println!("{}", "-".repeat(35));
1845        for i in 0..stats.result_count as usize {
1846            let attribute_value = attribute_provider.get(indices2[i]).unwrap_or(0);
1847            println!(
1848                "{:<10} {:<15.2} {:<10}",
1849                indices2[i], distances2[i], attribute_value
1850            );
1851        }
1852
1853        // Verify that distances are non-negative and sorted
1854        for i in 0..(stats.result_count as usize).saturating_sub(1) {
1855            assert!(distances2[i] >= 0.0, "Expected non-negative distance");
1856            assert!(
1857                distances2[i] <= distances2[i + 1],
1858                "Expected distances to be sorted in ascending order"
1859            );
1860        }
1861
1862        // Verify diversity: Check that we have diverse attribute values in the results
1863        let mut attribute_counts = HashMap::new();
1864        for item in indices2.iter().take(stats.result_count as usize) {
1865            if let Some(attribute_value) = attribute_provider.get(*item) {
1866                *attribute_counts.entry(attribute_value).or_insert(0) += 1;
1867            }
1868        }
1869
1870        // Print attribute distribution
1871        println!("\n=== Attribute Distribution ===");
1872        let mut sorted_attrs: Vec<_> = attribute_counts.iter().collect();
1873        sorted_attrs.sort_by_key(|(k, _)| *k);
1874        for (attribute_value, count) in &sorted_attrs {
1875            println!(
1876                "Label {}: {} occurrences (max allowed: {})",
1877                attribute_value, count, diverse_results_k
1878            );
1879        }
1880        println!("Total unique labels: {}", attribute_counts.len());
1881        println!("================================\n");
1882
1883        // With diverse_results_k = 5, we expect at most 5 results per attribute value
1884        for (attribute_value, count) in &attribute_counts {
1885            println!(
1886                "Assert: Label {} has {} occurrences (max: {})",
1887                attribute_value, count, diverse_results_k
1888            );
1889            assert!(
1890                *count <= diverse_results_k,
1891                "Attribute value {} appears {} times, which exceeds diverse_results_k of {}",
1892                attribute_value,
1893                count,
1894                diverse_results_k
1895            );
1896        }
1897
1898        // Verify that we have multiple different attribute values (diversity)
1899        // With 3 possible labels and diverse_results_k=5, we should see at least 2 different labels
1900        println!(
1901            "Assert: Found {} unique labels (expected at least 2)",
1902            attribute_counts.len()
1903        );
1904        assert!(
1905            attribute_counts.len() >= 2,
1906            "Expected at least 2 different attribute values for diversity, got {}",
1907            attribute_counts.len()
1908        );
1909    }
1910
1911    #[rstest]
1912    // This case checks expected behavior of unfiltered search.
1913    #[case(
1914        |_id: &u32| true,
1915        false,
1916        10,
1917        vec![152, 118, 72, 170, 87, 141, 79, 207, 124, 86],
1918        vec![256101.7, 256675.3, 256709.69, 256712.5, 256760.08, 256958.5, 257006.1, 257025.7, 257105.67, 257107.67],
1919    )]
1920    // This case validates post-filtering using 2 ids which are not present in the unfiltered result set.
1921    // It is expected that the post-filtering will return an empty result
1922    #[case(
1923        |id: &u32| *id == 0 || *id == 1,
1924        false,
1925        0,
1926        vec![0; 10],
1927        vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
1928    )]
1929    // This case validates pre-filtering using 2 ids which are not present in the unfiltered result set.
1930    // It is expected that the pre-filtering will do search over matching ids
1931    #[case(
1932        |id: &u32| *id == 0 || *id == 1,
1933        true,
1934        2,
1935        vec![1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
1936        vec![257247.28, 258179.28, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
1937    )]
1938    // This case validates post-filtering using 3 ids from the unfiltered result set.
1939    // It is expected that the post-filtering will filter out non-matching ids
1940    #[case(
1941        |id: &u32| *id == 72 || *id == 87 || *id == 170,
1942        false,
1943        3,
1944        vec![72, 170, 87, 0, 0, 0, 0, 0, 0, 0],
1945        vec![256709.69, 256712.5, 256760.08, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
1946    )]
1947    // This case validates pre-filtering using 3 ids from the unfiltered result set.
1948    // It is expected that the pre-filtering will do search over matching ids
1949    #[case(
1950        |id: &u32| *id == 72 || *id == 87 || *id == 170,
1951        true,
1952        3,
1953        vec![72, 170, 87, 0, 0, 0, 0, 0, 0, 0],
1954        vec![256709.69, 256712.5, 256760.08, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
1955    )]
1956    fn test_search_with_vector_filter(
1957        #[case] vector_filter: fn(&u32) -> bool,
1958        #[case] is_flat_search: bool,
1959        #[case] expected_result_count: u32,
1960        #[case] expected_indices: Vec<u32>,
1961        #[case] expected_distances: Vec<f32>,
1962    ) {
1963        // Exact distances can vary slightly depending on the architecture used
1964        // to compute distances due to different unrolling strategies and SIMD widthd.
1965        //
1966        // This parameter allows for a small margin when matching distances.
1967        let check_distances = |got: &[f32], expected: &[f32]| -> bool {
1968            const ABS_TOLERANCE: f32 = 0.02;
1969            assert_eq!(got.len(), expected.len());
1970            for (i, (g, e)) in std::iter::zip(got.iter(), expected.iter()).enumerate() {
1971                if (g - e).abs() > ABS_TOLERANCE {
1972                    panic!(
1973                        "distances differ at position {} by more than {}\n\n\
1974                         got: {:?}\nexpected: {:?}",
1975                        i, ABS_TOLERANCE, got, expected,
1976                    );
1977                }
1978            }
1979            true
1980        };
1981
1982        let storage_provider = Arc::new(VirtualStorageProvider::new_overlay(test_data_root()));
1983
1984        let search_engine = create_disk_index_searcher::<GraphDataF32VectorUnitData>(
1985            CreateDiskIndexSearcherParams {
1986                max_thread_num: 5,
1987                pq_pivot_file_path: TEST_PQ_PIVOT_128DIM,
1988                pq_compressed_file_path: TEST_PQ_COMPRESSED_128DIM,
1989                index_path: TEST_INDEX_128DIM,
1990                index_path_prefix: TEST_INDEX_PREFIX_128DIM,
1991                ..Default::default()
1992            },
1993            &storage_provider,
1994        );
1995        let query = vec![0.1f32; 128];
1996        let mut query_stats = QueryStatistics::default();
1997        let mut indices = vec![0u32; 10];
1998        let mut distances = vec![0f32; 10];
1999        let mut associated_data = vec![(); 10];
2000
2001        let result = search_engine.search_internal(
2002            &query,
2003            10,
2004            10,
2005            None, // beam_width
2006            &mut query_stats,
2007            &mut indices,
2008            &mut distances,
2009            &mut associated_data,
2010            &vector_filter,
2011            is_flat_search,
2012        );
2013
2014        assert!(result.is_ok(), "Expected search to succeed");
2015        assert_eq!(
2016            result.unwrap().result_count,
2017            expected_result_count,
2018            "Expected result count to match"
2019        );
2020        assert_eq!(indices, expected_indices, "Expected indices to match");
2021        assert!(
2022            check_distances(&distances, &expected_distances),
2023            "Expected distances to match"
2024        );
2025
2026        let result_with_filter = search_engine.search(
2027            &query,
2028            10,
2029            10,
2030            None, // beam_width
2031            Some(Box::new(vector_filter)),
2032            is_flat_search,
2033        );
2034
2035        assert!(result_with_filter.is_ok(), "Expected search to succeed");
2036        let result_with_filter_unwrapped = result_with_filter.unwrap();
2037        assert_eq!(
2038            result_with_filter_unwrapped.stats.result_count, expected_result_count,
2039            "Expected result count to match"
2040        );
2041        let actual_indices = result_with_filter_unwrapped
2042            .results
2043            .iter()
2044            .map(|x| x.vertex_id)
2045            .collect::<Vec<_>>();
2046        assert_eq!(
2047            actual_indices, expected_indices,
2048            "Expected indices to match"
2049        );
2050        let actual_distances = result_with_filter_unwrapped
2051            .results
2052            .iter()
2053            .map(|x| x.distance)
2054            .collect::<Vec<_>>();
2055        assert!(
2056            check_distances(&actual_distances, &expected_distances),
2057            "Expected distances to match"
2058        );
2059    }
2060
2061    #[test]
2062    fn test_beam_search_respects_io_limit() {
2063        let io_limit = 11; // Set a small IO limit for testing
2064        let storage_provider = Arc::new(VirtualStorageProvider::new_overlay(test_data_root()));
2065
2066        let search_engine = create_disk_index_searcher::<GraphDataF32VectorUnitData>(
2067            CreateDiskIndexSearcherParams {
2068                max_thread_num: 1,
2069                pq_pivot_file_path: TEST_PQ_PIVOT,
2070                pq_compressed_file_path: TEST_PQ_COMPRESSED,
2071                index_path: TEST_INDEX,
2072                index_path_prefix: TEST_INDEX_PREFIX,
2073                io_limit,
2074            },
2075            &storage_provider,
2076        );
2077        let query_vector: [f32; 128] = [1f32; 128];
2078
2079        let mut indices = vec![0u32; 10];
2080        let mut distances = vec![0f32; 10];
2081        let mut associated_data = vec![(); 10];
2082
2083        let mut result_output_buffer = search_output_buffer::IdDistanceAssociatedData::new(
2084            &mut indices,
2085            &mut distances,
2086            &mut associated_data,
2087        );
2088
2089        let strategy = search_engine.search_strategy(&query_vector, &|_| true);
2090
2091        let mut search_record = VisitedSearchRecord::new(0);
2092        let search_params = Knn::new(10, 10, Some(4)).unwrap();
2093        let recorded_search =
2094            diskann::graph::search::RecordedKnn::new(search_params, &mut search_record);
2095        search_engine
2096            .runtime
2097            .block_on(search_engine.index.search(
2098                recorded_search,
2099                &strategy,
2100                &DefaultContext,
2101                query_vector.as_slice(),
2102                &mut result_output_buffer,
2103            ))
2104            .unwrap();
2105        let visited_ids = search_record
2106            .visited
2107            .iter()
2108            .map(|n| n.id)
2109            .collect::<Vec<_>>();
2110
2111        let query_stats = strategy.io_tracker;
2112        //Verify the IO limit was respected
2113        assert!(
2114            query_stats.io_count() <= io_limit,
2115            "Expected IO operations to be <= {}, but got {}",
2116            io_limit,
2117            query_stats.io_count()
2118        );
2119
2120        const EXPECTED_NODES: [u32; 17] = [
2121            72, 118, 108, 86, 84, 152, 170, 82, 114, 87, 207, 176, 79, 153, 67, 165, 141,
2122        ]; //Expected nodes for query = [1f32; 128] with beam_width=4
2123
2124        // Count matching results
2125        let mut matching_count = 0;
2126        for expected_node in EXPECTED_NODES.iter() {
2127            if visited_ids.contains(expected_node) {
2128                matching_count += 1;
2129            }
2130        }
2131
2132        // Calculate recall
2133        let recall = (matching_count as f32 / EXPECTED_NODES.len() as f32) * 100.0;
2134
2135        //Verify the recall is above 60%. The threshold her eis arbitrary, just to make sure when
2136        // search hits io_limit that it doesn't break and the recall degrades gracefully
2137        assert!(recall >= 60.0, "Match percentage is below 60%: {}", recall);
2138    }
2139}