Skip to main content

diskann_disk/search/provider/
disk_provider.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use std::{
7    collections::HashMap,
8    future::Future,
9    num::NonZeroUsize,
10    ops::Range,
11    sync::{
12        atomic::{AtomicU64, AtomicUsize},
13        Arc,
14    },
15    time::Instant,
16};
17
18use diskann::{
19    graph::{
20        self,
21        glue::{self, ExpandBeam, IdIterator, SearchExt, SearchPostProcess, SearchStrategy},
22        search::Knn,
23        search_output_buffer, AdjacencyList, DiskANNIndex, SearchOutputBuffer,
24    },
25    neighbor::Neighbor,
26    provider::{
27        Accessor, BuildQueryComputer, DataProvider, DefaultContext, DelegateNeighbor, HasId,
28        NeighborAccessor,
29    },
30    utils::{
31        object_pool::{ObjectPool, PoolOption, TryAsPooled},
32        IntoUsize, VectorRepr,
33    },
34    ANNError, ANNResult,
35};
36use diskann_providers::storage::StorageReadProvider;
37use diskann_providers::{
38    model::{
39        compute_pq_distance, compute_pq_distance_for_pq_coordinates, graph::traits::GraphDataType,
40        pq::quantizer_preprocess, PQData, PQScratch,
41    },
42    storage::{get_compressed_pq_file, get_disk_index_file, get_pq_pivot_file, LoadWith},
43};
44use diskann_vector::{distance::Metric, DistanceFunction, PreprocessedDistanceFunction};
45use futures_util::future;
46use tokio::runtime::Runtime;
47use tracing::debug;
48
49use crate::{
50    data_model::{CachingStrategy, GraphHeader},
51    filter_parameter::{default_vector_filter, VectorFilter},
52    search::{
53        provider::disk_vertex_provider_factory::DiskVertexProviderFactory,
54        traits::{VertexProvider, VertexProviderFactory},
55    },
56    storage::{api::AsyncDiskLoadContext, disk_index_reader::DiskIndexReader},
57    utils::AlignedFileReaderFactory,
58    utils::QueryStatistics,
59};
60
61///////////////////
62// Disk Provider //
63///////////////////
64
/// The DiskProvider is a data provider that loads data from disk using the disk readers.
/// The data format for disk is different from that of the in-memory providers.
/// The disk format stores both the vectors and the adjacency list next to each other for
/// better locality for quicker access.
/// Please refer to the RFC documentation at `docs\rfcs\cy2025\disk_provider_for_async_index.md` for design details.
pub struct DiskProvider<Data>
where
    Data: GraphDataType<VectorIdType = u32>,
{
    /// Holds the graph header information that contains metadata about disk-index file.
    /// The dimensionality, medoid and maximum degree used by this file are read from it.
    graph_header: GraphHeader,

    /// Full precision distance comparer used in post_process to reorder results.
    distance_comparer: <Data::VectorDataType as VectorRepr>::Distance,

    /// The PQ data used for quantization.
    pq_data: Arc<PQData>,

    /// The number of points in the graph.
    num_points: usize,

    /// Metric used for distance computation.
    metric: Metric,

    /// IO budget of a single search: the per-search IO count is compared against this
    /// value to cap beam expansion and to decide on early termination.
    search_io_limit: usize,
}
92
/// [`DataProvider`] implementation for the disk provider.
///
/// Internal and external ids are both `u32` and map to each other unchanged.
impl<Data> DataProvider for DiskProvider<Data>
where
    Data: GraphDataType<VectorIdType = u32>,
{
    type Context = DefaultContext;

    type InternalId = u32;

    type ExternalId = u32;

    type Error = ANNError;

    /// Translate an external id to its corresponding internal id.
    ///
    /// This is the identity mapping; it never fails.
    fn to_internal_id(
        &self,
        _context: &DefaultContext,
        gid: &Self::ExternalId,
    ) -> Result<Self::InternalId, Self::Error> {
        Ok(*gid)
    }

    /// Translate an internal id to its corresponding external id.
    ///
    /// This is the identity mapping; it never fails.
    fn to_external_id(
        &self,
        _context: &DefaultContext,
        id: Self::InternalId,
    ) -> Result<Self::ExternalId, Self::Error> {
        Ok(id)
    }
}
123
124impl<Data> LoadWith<AsyncDiskLoadContext> for DiskProvider<Data>
125where
126    Data: GraphDataType<VectorIdType = u32>,
127{
128    type Error = ANNError;
129
130    async fn load_with<P>(provider: &P, ctx: &AsyncDiskLoadContext) -> ANNResult<Self>
131    where
132        P: StorageReadProvider,
133    {
134        debug!(
135            "DiskProvider::load_with() called with file: {:?}",
136            get_disk_index_file(ctx.quant_load_context.metadata.prefix())
137        );
138
139        let graph_header = {
140            let aligned_reader_factory = AlignedFileReaderFactory::new(get_disk_index_file(
141                ctx.quant_load_context.metadata.prefix(),
142            ));
143
144            let caching_strategy = if ctx.num_nodes_to_cache > 0 {
145                CachingStrategy::StaticCacheWithBfsNodes(ctx.num_nodes_to_cache)
146            } else {
147                CachingStrategy::None
148            };
149
150            let vertex_provider_factory = DiskVertexProviderFactory::<Data, _>::new(
151                aligned_reader_factory,
152                caching_strategy,
153            )?;
154            VertexProviderFactory::get_header(&vertex_provider_factory)?
155        };
156
157        let metric = ctx.quant_load_context.metric;
158        let num_points = ctx.num_points;
159
160        let index_path_prefix = ctx.quant_load_context.metadata.prefix();
161        let index_reader = DiskIndexReader::<<Data as GraphDataType>::VectorDataType>::new(
162            get_pq_pivot_file(index_path_prefix),
163            get_compressed_pq_file(index_path_prefix),
164            provider,
165        )?;
166
167        Self::new(
168            &index_reader,
169            graph_header,
170            metric,
171            num_points,
172            ctx.search_io_limit,
173        )
174    }
175}
176
177impl<Data> DiskProvider<Data>
178where
179    Data: GraphDataType<VectorIdType = u32>,
180{
181    fn new(
182        disk_index_reader: &DiskIndexReader<Data::VectorDataType>,
183        graph_header: GraphHeader,
184        metric: Metric,
185        num_points: usize,
186        search_io_limit: usize,
187    ) -> ANNResult<Self> {
188        let distance_comparer =
189            Data::VectorDataType::distance(metric, Some(graph_header.metadata().dims));
190
191        let pq_data = disk_index_reader.get_pq_data();
192
193        Ok(Self {
194            graph_header,
195            distance_comparer,
196            pq_data,
197            num_points,
198            metric,
199            search_io_limit,
200        })
201    }
202}
203
/// The search strategy for the disk provider. This is used to create the search accessor
/// for use in search in quant space and post_process function to reorder with full precision vectors.
///
/// # Why vertex_provider_factory and scratch_pool are here instead of DiskProvider
///
/// The DataProvider trait requires 'static bounds for multi-threaded async contexts,
/// but vertex_provider_factory may have non-'static lifetime bounds (e.g., borrowing
/// from local data structures). Moving these components to the search strategy allows
/// DiskProvider to satisfy 'static constraints while enabling flexible per-search
/// resource management.
pub struct DiskSearchStrategy<'a, Data, ProviderFactory>
where
    Data: GraphDataType<VectorIdType = u32>,
    ProviderFactory: VertexProviderFactory<Data>,
{
    /// IO statistics of the search this strategy belongs to. The tracker is stored by
    /// value and uses atomic counters because DiskSearchStrategy has a "Send" trait
    /// bound, though it is not expected to be shared across threads.
    io_tracker: IOTracker,
    /// Predicate handed to the post-processor to decide which ids are kept.
    /// Fn param is u32 as we validate "VectorIdType = u32" everywhere in this provider in trait bounds.
    vector_filter: &'a (dyn Fn(&u32) -> bool + Send + Sync),
    /// The query vector in full precision.
    query: &'a [Data::VectorDataType],

    /// The vertex provider factory is used to create the vertex provider for each search instance.
    vertex_provider_factory: &'a ProviderFactory,

    /// Scratch pool for disk search operations that need allocations.
    scratch_pool: &'a Arc<ObjectPool<DiskSearchScratch<Data, ProviderFactory::VertexProviderType>>>,
}
230
// Struct to track IO. This is used by single thread, but needs to be Atomic as the Strategy has "Send" trait bound.
// There should be minimal to no overhead compared to using a raw reference.
//
// All counters start at zero (the derived `Default`), and every access uses relaxed
// ordering: the values are plain statistics and do not synchronize other memory.
#[derive(Default)]
struct IOTracker {
    // Accumulated time spent loading vertices, in microseconds.
    io_time_us: AtomicU64,
    // Accumulated time spent in query preprocessing, in microseconds.
    preprocess_time_us: AtomicU64,
    // Number of IO operations recorded so far.
    io_count: AtomicUsize,
}

impl IOTracker {
    // Add `time` (microseconds) to the given time counter.
    fn add_time(category: &AtomicU64, time: u64) {
        use std::sync::atomic::Ordering::Relaxed;
        category.fetch_add(time, Relaxed);
    }

    // Read the current value of the given time counter.
    fn time(category: &AtomicU64) -> u64 {
        use std::sync::atomic::Ordering::Relaxed;
        category.load(Relaxed)
    }

    // Record `count` additional IO operations.
    fn add_io_count(&self, count: usize) {
        use std::sync::atomic::Ordering::Relaxed;
        self.io_count.fetch_add(count, Relaxed);
    }

    // Number of IO operations recorded so far.
    fn io_count(&self) -> usize {
        use std::sync::atomic::Ordering::Relaxed;
        self.io_count.load(Relaxed)
    }
}
267
/// Post-processor that keeps only the candidates accepted by `filter` and reorders
/// them by full precision distance.
#[derive(Clone, Copy)]
pub struct RerankAndFilter<'a> {
    /// Predicate over vertex ids; a candidate is kept when it returns `true`.
    filter: &'a (dyn Fn(&u32) -> bool + Send + Sync),
}

impl<'a> RerankAndFilter<'a> {
    /// Create a post-processor that borrows `filter` for `'a`.
    fn new(filter: &'a (dyn Fn(&u32) -> bool + Send + Sync)) -> Self {
        RerankAndFilter { filter }
    }
}
278
279impl<Data, VP>
280    SearchPostProcess<
281        DiskAccessor<'_, Data, VP>,
282        [Data::VectorDataType],
283        (
284            <DiskProvider<Data> as DataProvider>::InternalId,
285            Data::AssociatedDataType,
286        ),
287    > for RerankAndFilter<'_>
288where
289    Data: GraphDataType<VectorIdType = u32>,
290    VP: VertexProvider<Data>,
291{
292    type Error = ANNError;
293    async fn post_process<I, B>(
294        &self,
295        accessor: &mut DiskAccessor<'_, Data, VP>,
296        query: &[Data::VectorDataType],
297        _computer: &DiskQueryComputer,
298        candidates: I,
299        output: &mut B,
300    ) -> Result<usize, Self::Error>
301    where
302        I: Iterator<Item = Neighbor<u32>> + Send,
303        B: SearchOutputBuffer<(u32, Data::AssociatedDataType)> + Send + ?Sized,
304    {
305        let provider = accessor.provider;
306
307        let mut uncached_ids = Vec::new();
308        let mut reranked = candidates
309            .map(|n| n.id)
310            .filter(|id| (self.filter)(id))
311            .filter_map(|n| {
312                if let Some(entry) = accessor.scratch.distance_cache.get(&n) {
313                    Some(Ok::<((u32, _), f32), ANNError>(((n, entry.1), entry.0)))
314                } else {
315                    uncached_ids.push(n);
316                    None
317                }
318            })
319            .collect::<Result<Vec<_>, _>>()?;
320        if !uncached_ids.is_empty() {
321            ensure_vertex_loaded(&mut accessor.scratch.vertex_provider, &uncached_ids)?;
322            for n in &uncached_ids {
323                let v = accessor.scratch.vertex_provider.get_vector(n)?;
324                let d = provider.distance_comparer.evaluate_similarity(query, v);
325                let a = accessor.scratch.vertex_provider.get_associated_data(n)?;
326                reranked.push(((*n, *a), d));
327            }
328        }
329
330        // Sort the full precision distances.
331        reranked
332            .sort_unstable_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
333        // Store the reranked results.
334        Ok(output.extend(reranked))
335    }
336}
337
/// [`SearchStrategy`] implementation wiring the disk accessor, the PQ query computer and
/// the rerank-and-filter post-processor together.
impl<'this, Data, ProviderFactory>
    SearchStrategy<
        DiskProvider<Data>,
        [Data::VectorDataType],
        (
            <DiskProvider<Data> as DataProvider>::InternalId,
            Data::AssociatedDataType,
        ),
    > for DiskSearchStrategy<'this, Data, ProviderFactory>
where
    Data: GraphDataType<VectorIdType = u32>,
    ProviderFactory: VertexProviderFactory<Data>,
{
    type QueryComputer = DiskQueryComputer;
    type SearchAccessor<'a> = DiskAccessor<'a, Data, ProviderFactory::VertexProviderType>;
    type SearchAccessorError = ANNError;
    type PostProcessor = RerankAndFilter<'this>;

    /// Create the accessor used for one search over `provider`.
    ///
    /// The accessor shares this strategy's IO tracker, query, vertex provider factory
    /// and scratch pool. Any error from `DiskAccessor::new` is returned unchanged.
    fn search_accessor<'a>(
        &'a self,
        provider: &'a DiskProvider<Data>,
        _context: &DefaultContext,
    ) -> Result<Self::SearchAccessor<'a>, Self::SearchAccessorError> {
        DiskAccessor::new(
            provider,
            &self.io_tracker,
            self.query,
            self.vertex_provider_factory,
            self.scratch_pool,
        )
    }

    /// Create the post-processor, which filters with this strategy's vector filter.
    fn post_processor(&self) -> Self::PostProcessor {
        RerankAndFilter::new(self.vector_filter)
    }
}
374
/// The query computer for the disk provider. This is used to compute the distance between the query vector and the PQ coordinates.
pub struct DiskQueryComputer {
    /// Number of PQ chunks, as reported by the provider's PQ data.
    num_pq_chunks: usize,
    /// Owned copy of the accessor's PQ distance table (`aligned_pqtable_dist_scratch`)
    /// taken when the computer is built.
    query_centroid_l2_distance: Vec<f32>,
}
380
381impl PreprocessedDistanceFunction<&[u8], f32> for DiskQueryComputer {
382    fn evaluate_similarity(&self, changing: &[u8]) -> f32 {
383        let mut dist = 0.0f32;
384        #[allow(clippy::expect_used)]
385        compute_pq_distance_for_pq_coordinates(
386            changing,
387            self.num_pq_chunks,
388            &self.query_centroid_l2_distance,
389            std::slice::from_mut(&mut dist),
390        )
391        .expect("PQ distance compute for PQ coordinates is expected to succeed");
392        dist
393    }
394}
395
396impl<Data, VP> BuildQueryComputer<[Data::VectorDataType]> for DiskAccessor<'_, Data, VP>
397where
398    Data: GraphDataType<VectorIdType = u32>,
399    VP: VertexProvider<Data>,
400{
401    type QueryComputerError = ANNError;
402    type QueryComputer = DiskQueryComputer;
403
404    fn build_query_computer(
405        &self,
406        _from: &[Data::VectorDataType],
407    ) -> Result<Self::QueryComputer, Self::QueryComputerError> {
408        Ok(DiskQueryComputer {
409            num_pq_chunks: self.provider.pq_data.get_num_chunks(),
410            query_centroid_l2_distance: self
411                .scratch
412                .pq_scratch
413                .aligned_pqtable_dist_scratch
414                .as_slice()
415                .to_vec(),
416        })
417    }
418
419    async fn distances_unordered<Itr, F>(
420        &mut self,
421        vec_id_itr: Itr,
422        _computer: &Self::QueryComputer,
423        f: F,
424    ) -> Result<(), Self::GetError>
425    where
426        F: Send + FnMut(f32, Self::Id),
427        Itr: Iterator<Item = Self::Id>,
428    {
429        self.pq_distances(&vec_id_itr.collect::<Box<[_]>>(), f)
430    }
431}
432
433impl<Data, VP> ExpandBeam<[Data::VectorDataType]> for DiskAccessor<'_, Data, VP>
434where
435    Data: GraphDataType<VectorIdType = u32>,
436    VP: VertexProvider<Data>,
437{
438    fn expand_beam<Itr, P, F>(
439        &mut self,
440        ids: Itr,
441        _computer: &Self::QueryComputer,
442        mut pred: P,
443        mut f: F,
444    ) -> impl std::future::Future<Output = Result<(), Self::GetError>> + Send
445    where
446        Itr: Iterator<Item = Self::Id> + Send,
447        P: glue::HybridPredicate<Self::Id> + Send + Sync,
448        F: FnMut(f32, Self::Id) + Send,
449    {
450        let result = (|| {
451            let io_limit = self.provider.search_io_limit - self.io_tracker.io_count();
452            let load_ids: Box<[_]> = ids.take(io_limit).collect();
453
454            self.ensure_loaded(&load_ids)?;
455            let mut ids = Vec::new();
456            for i in load_ids {
457                ids.clear();
458                ids.extend(
459                    self.scratch
460                        .vertex_provider
461                        .get_adjacency_list(&i)?
462                        .iter()
463                        .copied()
464                        .filter(|id| pred.eval_mut(id)),
465                );
466
467                self.pq_distances(&ids, &mut f)?;
468            }
469
470            Ok(())
471        })();
472
473        std::future::ready(result)
474    }
475}
476
// Scratch space for disk search operations that need allocations.
// These allocations are amortized across searches using the scratch pool.
struct DiskSearchScratch<Data, VP>
where
    Data: GraphDataType<VectorIdType = u32>,
    VP: VertexProvider<Data>,
{
    // Maps a vertex id to its full precision distance to the query and its associated
    // data. Filled when vertices are loaded and consulted during reranking; cleared
    // when the scratch is reused.
    distance_cache: HashMap<u32, (f32, Data::AssociatedDataType)>,
    // Scratch buffers for the PQ distance computations.
    pq_scratch: PQScratch,
    // Provider through which vectors, associated data and adjacency lists are read;
    // cleared when the scratch is reused.
    vertex_provider: VP,
}
488
// Arguments needed to create a `DiskSearchScratch` for the scratch pool.
#[derive(Clone)]
struct DiskSearchScratchArgs<'a, ProviderFactory> {
    // The first four fields are forwarded to `PQScratch::new` in this order.
    graph_degree: usize,
    dim: usize,
    num_pq_chunks: usize,
    num_pq_centers: usize,
    // Factory used to create the scratch's vertex provider.
    vertex_factory: &'a ProviderFactory,
    // Graph header passed to the factory when the vertex provider is created.
    graph_header: &'a GraphHeader,
}
498
499impl<Data, ProviderFactory> TryAsPooled<&DiskSearchScratchArgs<'_, ProviderFactory>>
500    for DiskSearchScratch<Data, ProviderFactory::VertexProviderType>
501where
502    Data: GraphDataType<VectorIdType = u32>,
503    ProviderFactory: VertexProviderFactory<Data>,
504{
505    type Error = ANNError;
506
507    fn try_create(args: &DiskSearchScratchArgs<ProviderFactory>) -> Result<Self, Self::Error> {
508        let pq_scratch = PQScratch::new(
509            args.graph_degree,
510            args.dim,
511            args.num_pq_chunks,
512            args.num_pq_centers,
513        )?;
514
515        const DEFAULT_BEAM_WIDTH: usize = 0; // Setting as 0 to avoid preallocation of memory.
516        let vertex_provider = args
517            .vertex_factory
518            .create_vertex_provider(DEFAULT_BEAM_WIDTH, args.graph_header)?;
519
520        Ok(Self {
521            distance_cache: HashMap::new(),
522            pq_scratch,
523            vertex_provider,
524        })
525    }
526
527    fn try_modify(
528        &mut self,
529        _args: &DiskSearchScratchArgs<ProviderFactory>,
530    ) -> Result<(), Self::Error> {
531        self.distance_cache.clear();
532        self.vertex_provider.clear();
533        Ok(())
534    }
535}
536
/// Accessor used during a search over a [`DiskProvider`].
///
/// It combines the shared provider with per-search state: the IO tracker, a pooled
/// scratch space and the query vector.
pub struct DiskAccessor<'a, Data, VP>
where
    Data: GraphDataType<VectorIdType = u32>,
    VP: VertexProvider<Data>,
{
    /// The provider being searched.
    provider: &'a DiskProvider<Data>,
    /// IO statistics updated as vertices are loaded.
    io_tracker: &'a IOTracker,
    /// Scratch space obtained from the scratch pool.
    scratch: PoolOption<DiskSearchScratch<Data, VP>>,
    /// The query vector in full precision.
    query: &'a [Data::VectorDataType],
}
547
548impl<Data, VP> DiskAccessor<'_, Data, VP>
549where
550    Data: GraphDataType<VectorIdType = u32>,
551    VP: VertexProvider<Data>,
552{
553    // Compute the PQ distance between each ID in `ids` and the distance table stored in
554    // `self`, invoking the callback with the results of each computation in order.
555    fn pq_distances<F>(&mut self, ids: &[u32], mut f: F) -> ANNResult<()>
556    where
557        F: FnMut(f32, u32),
558    {
559        let pq_scratch = &mut self.scratch.pq_scratch;
560        compute_pq_distance(
561            ids,
562            self.provider.pq_data.get_num_chunks(),
563            &pq_scratch.aligned_pqtable_dist_scratch,
564            self.provider.pq_data.pq_compressed_data().get_data(),
565            &mut pq_scratch.aligned_pq_coord_scratch,
566            &mut pq_scratch.aligned_dist_scratch,
567        )?;
568
569        for (i, id) in ids.iter().enumerate() {
570            let distance = self.scratch.pq_scratch.aligned_dist_scratch[i];
571            f(distance, *id);
572        }
573
574        Ok(())
575    }
576}
577
578impl<Data, VP> SearchExt for DiskAccessor<'_, Data, VP>
579where
580    Data: GraphDataType<VectorIdType = u32>,
581    VP: VertexProvider<Data>,
582{
583    async fn starting_points(&self) -> ANNResult<Vec<u32>> {
584        let start_vertex_id = self.provider.graph_header.metadata().medoid as u32;
585        Ok(vec![start_vertex_id])
586    }
587
588    fn terminate_early(&mut self) -> bool {
589        self.io_tracker.io_count() > self.provider.search_io_limit
590    }
591}
592
593impl<'a, Data, VP> DiskAccessor<'a, Data, VP>
594where
595    Data: GraphDataType<VectorIdType = u32>,
596    VP: VertexProvider<Data>,
597{
598    fn new<VPF>(
599        provider: &'a DiskProvider<Data>,
600        io_tracker: &'a IOTracker,
601        query: &'a [Data::VectorDataType],
602        vertex_provider_factory: &'a VPF,
603        scratch_pool: &'a Arc<ObjectPool<DiskSearchScratch<Data, VP>>>,
604    ) -> ANNResult<Self>
605    where
606        VPF: VertexProviderFactory<Data, VertexProviderType = VP>,
607    {
608        let mut scratch = PoolOption::try_pooled(
609            scratch_pool,
610            &DiskSearchScratchArgs {
611                graph_degree: provider.graph_header.max_degree::<Data::VectorDataType>()?,
612                dim: provider.graph_header.metadata().dims,
613                num_pq_chunks: provider.pq_data.get_num_chunks(),
614                num_pq_centers: provider.pq_data.get_num_centers(),
615                vertex_factory: vertex_provider_factory,
616                graph_header: &provider.graph_header,
617            },
618        )?;
619
620        scratch.pq_scratch.set(
621            provider.graph_header.metadata().dims,
622            query,
623            1.0_f32, // Normalization factor
624        )?;
625        let start_vertex_id = provider.graph_header.metadata().medoid as u32;
626
627        let timer = Instant::now();
628        quantizer_preprocess(
629            &mut scratch.pq_scratch,
630            &provider.pq_data,
631            provider.metric,
632            &[start_vertex_id],
633        )?;
634        IOTracker::add_time(
635            &io_tracker.preprocess_time_us,
636            timer.elapsed().as_micros() as u64,
637        );
638
639        Ok(Self {
640            provider,
641            io_tracker,
642            scratch,
643            query,
644        })
645    }
646    fn ensure_loaded(&mut self, ids: &[u32]) -> Result<(), ANNError> {
647        if ids.is_empty() {
648            return Ok(());
649        }
650        let scratch = &mut self.scratch;
651        let timer = Instant::now();
652        ensure_vertex_loaded(&mut scratch.vertex_provider, ids)?;
653        IOTracker::add_time(
654            &self.io_tracker.io_time_us,
655            timer.elapsed().as_micros() as u64,
656        );
657        self.io_tracker.add_io_count(ids.len());
658        for id in ids {
659            let distance = self
660                .provider
661                .distance_comparer
662                .evaluate_similarity(self.query, scratch.vertex_provider.get_vector(id)?);
663            let associated_data = *scratch.vertex_provider.get_associated_data(id)?;
664            scratch
665                .distance_cache
666                .insert(*id, (distance, associated_data));
667        }
668        Ok(())
669    }
670}
671
/// The accessor identifies vertices by their `u32` id.
impl<Data, VP> HasId for DiskAccessor<'_, Data, VP>
where
    Data: GraphDataType<VectorIdType = u32>,
    VP: VertexProvider<Data>,
{
    type Id = u32;
}
679
680impl<'a, Data, VP> Accessor for DiskAccessor<'a, Data, VP>
681where
682    Data: GraphDataType<VectorIdType = u32>,
683    VP: VertexProvider<Data>,
684{
685    /// This references the PQ vector in the underlying `pq_data` store.
686    type Extended = &'a [u8];
687
688    /// This accessor returns raw slices. There *is* a chance of racing when the fast
689    /// providers are used. We just have to live with it.
690    ///
691    /// Since the underlying PQ store is shared, we ignore the `'b` lifetime here and
692    /// instead use `'a`.
693    type Element<'b>
694        = &'a [u8]
695    where
696        Self: 'b;
697
698    /// `ElementRef` can have arbitrary lifetimes.
699    type ElementRef<'b> = &'b [u8];
700
701    /// Choose to panic on an out-of-bounds access rather than propagate an error.
702    type GetError = ANNError;
703
704    fn get_element(
705        &mut self,
706        id: Self::Id,
707    ) -> impl Future<Output = Result<Self::Element<'_>, Self::GetError>> + Send {
708        std::future::ready(self.provider.pq_data.get_compressed_vector(id as usize))
709    }
710}
711
/// Iteration over all vertex ids of the provider.
impl<Data, VP> IdIterator<Range<u32>> for DiskAccessor<'_, Data, VP>
where
    Data: GraphDataType<VectorIdType = u32>,
    VP: VertexProvider<Data>,
{
    /// Return the half-open id range `0..num_points`.
    ///
    /// `num_points` is converted with `as u32`, so a count that does not fit in `u32`
    /// would be truncated.
    async fn id_iterator(&mut self) -> Result<Range<u32>, ANNError> {
        Ok(0..self.provider.num_points as u32)
    }
}
721
/// Neighbor access is delegated to [`AsNeighborAccessor`], which mutably borrows the accessor.
impl<'a, 'b, Data, VP> DelegateNeighbor<'a> for DiskAccessor<'b, Data, VP>
where
    Data: GraphDataType<VectorIdType = u32>,
    VP: VertexProvider<Data>,
{
    type Delegate = AsNeighborAccessor<'a, 'b, Data, VP>;
    /// Wrap a mutable borrow of this accessor in the neighbor-accessor wrapper.
    fn delegate_neighbor(&'a mut self) -> Self::Delegate {
        AsNeighborAccessor(self)
    }
}
732
/// A light-weight wrapper around `&mut DiskAccessor` used to tailor the semantics of
/// [`NeighborAccessor`].
///
/// Neighbor lookups first make sure the vertex is loaded into the vertex provider and
/// then read its adjacency list from there.
pub struct AsNeighborAccessor<'a, 'b, Data, VP>(&'a mut DiskAccessor<'b, Data, VP>)
where
    Data: GraphDataType<VectorIdType = u32>,
    VP: VertexProvider<Data>;
742
/// The wrapper uses the same `u32` ids as the accessor it wraps.
impl<Data, VP> HasId for AsNeighborAccessor<'_, '_, Data, VP>
where
    Data: GraphDataType<VectorIdType = u32>,
    VP: VertexProvider<Data>,
{
    type Id = u32;
}
750
751impl<Data, VP> NeighborAccessor for AsNeighborAccessor<'_, '_, Data, VP>
752where
753    Data: GraphDataType<VectorIdType = u32>,
754    VP: VertexProvider<Data>,
755{
756    fn get_neighbors(
757        self,
758        id: Self::Id,
759        neighbors: &mut AdjacencyList<Self::Id>,
760    ) -> impl Future<Output = ANNResult<Self>> + Send {
761        if self.0.io_tracker.io_count() > self.0.provider.search_io_limit {
762            return future::ok(self); // Returning empty results in `neighbors` out param if IO limit is reached.
763        }
764
765        if let Err(e) = ensure_vertex_loaded(&mut self.0.scratch.vertex_provider, &[id]) {
766            return future::err(e);
767        }
768        let list = match self.0.scratch.vertex_provider.get_adjacency_list(&id) {
769            Ok(list) => list,
770            Err(e) => return future::err(e),
771        };
772        neighbors.overwrite_trusted(list);
773        future::ok(self)
774    }
775}
776
/// [`DiskIndexSearcher`] is a helper class to make it easy to construct index
/// and do repeated search operations. It is a wrapper around the index.
/// This is useful for drivers such as search_disk_index.exe in tools.
pub struct DiskIndexSearcher<
    Data,
    ProviderFactory = DiskVertexProviderFactory<Data, AlignedFileReaderFactory>,
> where
    Data: GraphDataType<VectorIdType = u32>,
    ProviderFactory: VertexProviderFactory<Data>,
{
    /// The index built over the disk provider.
    index: DiskANNIndex<DiskProvider<Data>>,
    /// Tokio runtime owned by the searcher; either supplied by the caller or built as a
    /// current-thread runtime at construction.
    runtime: Runtime,

    /// The vertex provider factory is used to create the vertex provider for each search instance.
    vertex_provider_factory: ProviderFactory,

    /// Scratch pool for disk search operations that need allocations.
    scratch_pool: Arc<ObjectPool<DiskSearchScratch<Data, ProviderFactory::VertexProviderType>>>,
}
796
/// Statistics returned alongside the results of a search.
#[derive(Debug)]
pub struct SearchResultStats {
    /// NOTE(review): presumably the number of distance comparisons performed — confirm
    /// against the code that fills this in.
    pub cmps: u32,
    /// NOTE(review): presumably the number of results actually produced — confirm
    /// against the code that fills this in.
    pub result_count: u32,
    /// Detailed per-query statistics.
    pub query_statistics: QueryStatistics,
}
803
/// `SearchResult` is a struct representing the result of a search operation.
///
/// It contains a list of result items and a statistics object.
///
pub struct SearchResult<AssociatedData> {
    /// A list of nearest neighbors resulting from the search.
    pub results: Vec<SearchResultItem<AssociatedData>>,
    /// Statistics collected for the search.
    pub stats: SearchResultStats,
}
813
/// `SearchResultItem` is a struct representing a nearest neighbor resulting from a search.
///
/// It contains the vertex id, associated data, and the distance to the query vector.
///
pub struct SearchResultItem<AssociatedData> {
    /// The vertex id of the nearest neighbor.
    pub vertex_id: u32,
    /// The associated data of the nearest neighbor.
    pub data: AssociatedData,
    /// The distance between the nearest neighbor and the query vector.
    pub distance: f32,
}
827
828impl<Data, ProviderFactory> DiskIndexSearcher<Data, ProviderFactory>
829where
830    Data: GraphDataType<VectorIdType = u32>,
831    ProviderFactory: VertexProviderFactory<Data>,
832{
833    /// Create a new asynchronous disk searcher instance.
834    ///
835    /// # Arguments
836    /// * `num_threads` - The maximum number of threads to use.
837    /// * `search_io_limit` - I/O operation limit.
838    /// * `disk_index_reader` - The disk index reader.
839    /// * `vertex_provider_factory` - The vertex provider factory.
840    /// * `metric` - Distance metric used for vector similarity calculations.
841    /// * `runtime` - Tokio runtime handle for executing async operations.
842    pub fn new(
843        num_threads: usize,
844        search_io_limit: usize,
845        disk_index_reader: &DiskIndexReader<Data::VectorDataType>,
846        vertex_provider_factory: ProviderFactory,
847        metric: Metric,
848        runtime: Option<Runtime>,
849    ) -> ANNResult<Self> {
850        let runtime = match runtime {
851            Some(rt) => rt,
852            None => tokio::runtime::Builder::new_current_thread().build()?,
853        };
854
855        let graph_header = vertex_provider_factory.get_header()?;
856        let metadata = graph_header.metadata();
857        let max_degree = graph_header.max_degree::<Data::VectorDataType>()? as u32;
858
859        let config = graph::config::Builder::new(
860            max_degree.into_usize(),
861            graph::config::MaxDegree::default_slack(),
862            1, // build-search-list-size
863            metric.into(),
864        )
865        .build()?;
866
867        debug!("Creating DiskIndexSearcher with index_config: {:?}", config);
868
869        let graph_header = vertex_provider_factory.get_header()?;
870        let pq_data = disk_index_reader.get_pq_data();
871        let scratch_pool_args = DiskSearchScratchArgs {
872            graph_degree: graph_header.max_degree::<Data::VectorDataType>()?,
873            dim: graph_header.metadata().dims,
874            num_pq_chunks: pq_data.get_num_chunks(),
875            num_pq_centers: pq_data.get_num_centers(),
876            vertex_factory: &vertex_provider_factory,
877            graph_header: &graph_header,
878        };
879        let scratch_pool = Arc::new(ObjectPool::try_new(&scratch_pool_args, 0, None)?);
880
881        let disk_provider = DiskProvider::new(
882            disk_index_reader,
883            graph_header,
884            metric,
885            metadata.num_pts.into_usize(),
886            search_io_limit,
887        )?;
888
889        let index = DiskANNIndex::new(config, disk_provider, NonZeroUsize::new(num_threads));
890        Ok(Self {
891            index,
892            runtime,
893            vertex_provider_factory,
894            scratch_pool,
895        })
896    }
897
    /// Helper method to create a DiskSearchStrategy with common parameters.
    ///
    /// Every call starts with a fresh, zeroed IO tracker, while the vertex provider
    /// factory and the scratch pool are borrowed from the searcher.
    fn search_strategy<'a>(
        &'a self,
        query: &'a [Data::VectorDataType],
        vector_filter: &'a (dyn Fn(&Data::VectorIdType) -> bool + Send + Sync),
    ) -> DiskSearchStrategy<'a, Data, ProviderFactory> {
        DiskSearchStrategy {
            io_tracker: IOTracker::default(),
            vector_filter,
            query,
            vertex_provider_factory: &self.vertex_provider_factory,
            scratch_pool: &self.scratch_pool,
        }
    }
912
913    /// Perform a search on the disk index.
914    /// return the list of nearest neighbors and associated data.
915    pub fn search(
916        &self,
917        query: &[Data::VectorDataType],
918        return_list_size: u32,
919        search_list_size: u32,
920        beam_width: Option<usize>,
921        vector_filter: Option<VectorFilter<Data>>,
922        is_flat_search: bool,
923    ) -> ANNResult<SearchResult<Data::AssociatedDataType>> {
924        let mut query_stats = QueryStatistics::default();
925        let mut indices = vec![0u32; return_list_size as usize];
926        let mut distances = vec![0f32; return_list_size as usize];
927        let mut associated_data =
928            vec![Data::AssociatedDataType::default(); return_list_size as usize];
929
930        let stats = self.search_internal(
931            query,
932            return_list_size as usize,
933            search_list_size,
934            beam_width,
935            &mut query_stats,
936            &mut indices,
937            &mut distances,
938            &mut associated_data,
939            &vector_filter.unwrap_or(default_vector_filter::<Data>()),
940            is_flat_search,
941        )?;
942
943        let mut search_result = SearchResult {
944            results: Vec::with_capacity(return_list_size as usize),
945            stats,
946        };
947
948        for ((vertex_id, distance), associated_data) in indices
949            .into_iter()
950            .zip(distances.into_iter())
951            .zip(associated_data.into_iter())
952        {
953            search_result.results.push(SearchResultItem {
954                vertex_id,
955                distance,
956                data: associated_data,
957            });
958        }
959
960        Ok(search_result)
961    }
962
963    /// Perform a raw search on the disk index.
964    /// This is a lower-level API that allows more control over the search parameters and output buffers.
965    #[allow(clippy::too_many_arguments)]
966    pub(crate) fn search_internal(
967        &self,
968        query: &[Data::VectorDataType],
969        k_value: usize,
970        search_list_size: u32,
971        beam_width: Option<usize>,
972        query_stats: &mut QueryStatistics,
973        indices: &mut [u32],
974        distances: &mut [f32],
975        associated_data: &mut [Data::AssociatedDataType],
976        vector_filter: &(dyn Fn(&Data::VectorIdType) -> bool + Send + Sync),
977        is_flat_search: bool,
978    ) -> ANNResult<SearchResultStats> {
979        let mut result_output_buffer = search_output_buffer::IdDistanceAssociatedData::new(
980            &mut indices[..k_value],
981            &mut distances[..k_value],
982            &mut associated_data[..k_value],
983        );
984
985        let strategy = self.search_strategy(query, vector_filter);
986        let timer = Instant::now();
987        let k = k_value;
988        let l = search_list_size as usize;
989        let stats = if is_flat_search {
990            self.runtime.block_on(self.index.flat_search(
991                &strategy,
992                &DefaultContext,
993                strategy.query,
994                vector_filter,
995                &Knn::new(k, l, beam_width)?,
996                &mut result_output_buffer,
997            ))?
998        } else {
999            let knn_search = Knn::new(k, l, beam_width)?;
1000            self.runtime.block_on(self.index.search(
1001                knn_search,
1002                &strategy,
1003                &DefaultContext,
1004                strategy.query,
1005                &mut result_output_buffer,
1006            ))?
1007        };
1008        query_stats.total_comparisons = stats.cmps;
1009        query_stats.search_hops = stats.hops;
1010
1011        query_stats.total_execution_time_us = timer.elapsed().as_micros();
1012        query_stats.io_time_us = IOTracker::time(&strategy.io_tracker.io_time_us) as u128;
1013        query_stats.total_io_operations = strategy.io_tracker.io_count() as u32;
1014        query_stats.total_vertices_loaded = strategy.io_tracker.io_count() as u32;
1015        query_stats.query_pq_preprocess_time_us =
1016            IOTracker::time(&strategy.io_tracker.preprocess_time_us) as u128;
1017        query_stats.cpu_time_us = query_stats.total_execution_time_us
1018            - query_stats.io_time_us
1019            - query_stats.query_pq_preprocess_time_us;
1020        Ok(SearchResultStats {
1021            cmps: query_stats.total_comparisons,
1022            result_count: stats.result_count,
1023            query_statistics: query_stats.clone(),
1024        })
1025    }
1026}
1027
1028/// Helper function to ensure vertices are loaded and processed.
1029///
1030/// This is a convenience function that combines `load_vertices` and `process_loaded_node`
1031/// for each vertex ID. It first loads all the vertices in batch, then processes each
1032/// loaded node.
1033fn ensure_vertex_loaded<Data: GraphDataType, V: VertexProvider<Data>>(
1034    vertex_provider: &mut V,
1035    ids: &[Data::VectorIdType],
1036) -> ANNResult<()> {
1037    vertex_provider.load_vertices(ids)?;
1038    for (idx, id) in ids.iter().enumerate() {
1039        vertex_provider.process_loaded_node(id, idx)?;
1040    }
1041    Ok(())
1042}
1043
1044#[cfg(test)]
1045mod disk_provider_tests {
1046    use diskann::{
1047        graph::{
1048            search::{record::VisitedSearchRecord, Knn},
1049            KnnSearchError,
1050        },
1051        utils::IntoUsize,
1052        ANNErrorKind,
1053    };
1054    use diskann_providers::storage::{
1055        DynWriteProvider, StorageReadProvider, VirtualStorageProvider,
1056    };
1057    use diskann_providers::{
1058        common::AlignedBoxWithSlice,
1059        test_utils::graph_data_type_utils::{
1060            GraphDataF32VectorU32Data, GraphDataF32VectorUnitData,
1061        },
1062        utils::{create_thread_pool, load_aligned_bin, PQPathNames, ParallelIteratorInPool},
1063    };
1064    use diskann_utils::{io::read_bin, test_data_root};
1065    use diskann_vector::distance::Metric;
1066    use rayon::prelude::{IndexedParallelIterator, IntoParallelRefIterator};
1067    use rstest::rstest;
1068    use vfs::OverlayFS;
1069
1070    use super::*;
1071    use crate::{
1072        build::builder::core::disk_index_builder_tests::{IndexBuildFixture, TestParams},
1073        utils::{QueryStatistics, VirtualAlignedReaderFactory},
1074    };
1075
1076    const TEST_INDEX_PREFIX_128DIM: &str =
1077        "/disk_index_search/disk_index_sift_learn_R4_L50_A1.2_truth_search";
1078    const TEST_INDEX_128DIM: &str =
1079        "/disk_index_search/disk_index_sift_learn_R4_L50_A1.2_truth_search_disk.index";
1080    const TEST_PQ_PIVOT_128DIM: &str =
1081        "/disk_index_search/disk_index_sift_learn_R4_L50_A1.2_truth_search_pq_pivots.bin";
1082    const TEST_PQ_COMPRESSED_128DIM: &str =
1083        "/disk_index_search/disk_index_sift_learn_R4_L50_A1.2_truth_search_pq_compressed.bin";
1084    const TEST_TRUTH_RESULT_10PTS_128DIM: &str =
1085        "/disk_index_search/disk_index_10pts_idx_uint32_truth_search_res.bin";
1086    const TEST_QUERY_10PTS_128DIM: &str = "/disk_index_search/disk_index_sample_query_10pts.fbin";
1087
1088    const TEST_INDEX_PREFIX_100DIM: &str = "/disk_index_search/256pts_100dim_f32_truth_Index";
1089    const TEST_INDEX_100DIM: &str = "/disk_index_search/256pts_100dim_f32_truth_Index_disk.index";
1090    const TEST_PQ_PIVOT_100DIM: &str =
1091        "/disk_index_search/256pts_100dim_f32_truth_Index_pq_pivots.bin";
1092    const TEST_PQ_COMPRESSED_100DIM: &str =
1093        "/disk_index_search/256pts_100dim_f32_truth_Index_pq_compressed.bin";
1094    const TEST_TRUTH_RESULT_10PTS_100DIM: &str =
1095        "/disk_index_search/256pts_100dim_f32_truth_query_result.bin";
1096    const TEST_QUERY_10PTS_100DIM: &str = "/disk_index_search/10pts_100dim_f32_base_query.bin";
1097    const TEST_DATA_FILE: &str = "/disk_index_search/disk_index_siftsmall_learn_256pts_data.fbin";
1098    const TEST_INDEX: &str =
1099        "/disk_index_search/disk_index_sift_learn_R4_L50_A1.2_truth_search_disk.index";
1100    const TEST_INDEX_PREFIX: &str =
1101        "/disk_index_search/disk_index_sift_learn_R4_L50_A1.2_truth_search";
1102    const TEST_PQ_PIVOT: &str =
1103        "/disk_index_search/disk_index_sift_learn_R4_L50_A1.2_truth_search_pq_pivots.bin";
1104    const TEST_PQ_COMPRESSED: &str =
1105        "/disk_index_search/disk_index_sift_learn_R4_L50_A1.2_truth_search_pq_compressed.bin";
1106
1107    #[test]
1108    fn test_disk_search_k10_l20_single_or_multi_thread_100dim() {
1109        let storage_provider = Arc::new(VirtualStorageProvider::new_overlay(test_data_root()));
1110
1111        let search_engine = create_disk_index_searcher(
1112            CreateDiskIndexSearcherParams {
1113                max_thread_num: 5,
1114                pq_pivot_file_path: TEST_PQ_PIVOT_100DIM,
1115                pq_compressed_file_path: TEST_PQ_COMPRESSED_100DIM,
1116                index_path: TEST_INDEX_100DIM,
1117                index_path_prefix: TEST_INDEX_PREFIX_100DIM,
1118                ..Default::default()
1119            },
1120            &storage_provider,
1121        );
1122        // Test single thread.
1123        test_disk_search(TestDiskSearchParams {
1124            storage_provider: storage_provider.as_ref(),
1125            index_search_engine: &search_engine,
1126            thread_num: 1,
1127            query_file_path: TEST_QUERY_10PTS_100DIM,
1128            truth_result_file_path: TEST_TRUTH_RESULT_10PTS_100DIM,
1129            k: 10,
1130            l: 20,
1131            dim: 104,
1132        });
1133        // Test multi thread.
1134        test_disk_search(TestDiskSearchParams {
1135            storage_provider: storage_provider.as_ref(),
1136            index_search_engine: &search_engine,
1137            thread_num: 5,
1138            query_file_path: TEST_QUERY_10PTS_100DIM,
1139            truth_result_file_path: TEST_TRUTH_RESULT_10PTS_100DIM,
1140            k: 10,
1141            l: 20,
1142            dim: 104,
1143        });
1144    }
1145
1146    #[test]
1147    fn test_disk_search_k10_l20_single_or_multi_thread_128dim() {
1148        let storage_provider = Arc::new(VirtualStorageProvider::new_overlay(test_data_root()));
1149
1150        let search_engine = create_disk_index_searcher::<GraphDataF32VectorUnitData>(
1151            CreateDiskIndexSearcherParams {
1152                max_thread_num: 5,
1153                pq_pivot_file_path: TEST_PQ_PIVOT_128DIM,
1154                pq_compressed_file_path: TEST_PQ_COMPRESSED_128DIM,
1155                index_path: TEST_INDEX_128DIM,
1156                index_path_prefix: TEST_INDEX_PREFIX_128DIM,
1157                ..Default::default()
1158            },
1159            &storage_provider,
1160        );
1161        // Test single thread.
1162        test_disk_search(TestDiskSearchParams {
1163            storage_provider: storage_provider.as_ref(),
1164            index_search_engine: &search_engine,
1165            thread_num: 1,
1166            query_file_path: TEST_QUERY_10PTS_128DIM,
1167            truth_result_file_path: TEST_TRUTH_RESULT_10PTS_128DIM,
1168            k: 10,
1169            l: 20,
1170            dim: 128,
1171        });
1172        // Test multi thread.
1173        test_disk_search(TestDiskSearchParams {
1174            storage_provider: storage_provider.as_ref(),
1175            index_search_engine: &search_engine,
1176            thread_num: 5,
1177            query_file_path: TEST_QUERY_10PTS_128DIM,
1178            truth_result_file_path: TEST_TRUTH_RESULT_10PTS_128DIM,
1179            k: 10,
1180            l: 20,
1181            dim: 128,
1182        });
1183    }
1184
1185    fn get_truth_associated_data<StorageReader: StorageReadProvider>(
1186        storage_provider: &StorageReader,
1187    ) -> Vec<u32> {
1188        const ASSOCIATED_DATA_FILE: &str = "/sift/siftsmall_learn_256pts_u32_associated_data.fbin";
1189
1190        let data =
1191            read_bin::<u32>(&mut storage_provider.open_reader(ASSOCIATED_DATA_FILE).unwrap())
1192                .unwrap();
1193        data.into_inner().into_vec()
1194    }
1195
1196    #[test]
1197    fn test_disk_search_with_associated_data_k10_l20_single_or_multi_thread_128dim() {
1198        let storage_provider = VirtualStorageProvider::new_overlay(test_data_root());
1199        let index_path_prefix = "/disk_index_search/disk_index_sift_learn_R4_L50_A1.2_test_disk_index_search_associated_data";
1200        let params = TestParams {
1201            data_path: TEST_DATA_FILE.to_string(),
1202            index_path_prefix: index_path_prefix.to_string(),
1203            associated_data_path: Some(
1204                "/sift/siftsmall_learn_256pts_u32_associated_data.fbin".to_string(),
1205            ),
1206            ..TestParams::default()
1207        };
1208        let fixture = IndexBuildFixture::new(storage_provider, params).unwrap();
1209        // Build the index with the associated data
1210        fixture.build::<GraphDataF32VectorU32Data>().unwrap();
1211        {
1212            let search_engine = create_disk_index_searcher::<GraphDataF32VectorU32Data>(
1213                CreateDiskIndexSearcherParams {
1214                    max_thread_num: 5,
1215                    pq_pivot_file_path: format!("{}_pq_pivots.bin", index_path_prefix).as_str(),
1216                    pq_compressed_file_path: format!("{}_pq_compressed.bin", index_path_prefix)
1217                        .as_str(),
1218                    index_path: format!("{}_disk.index", index_path_prefix).as_str(), //TEST_INDEX_128DIM,
1219                    index_path_prefix,
1220                    ..Default::default()
1221                },
1222                &fixture.storage_provider,
1223            );
1224
1225            // Test single thread.
1226            test_disk_search_with_associated(
1227                TestDiskSearchAssociateParams {
1228                    storage_provider: fixture.storage_provider.as_ref(),
1229                    index_search_engine: &search_engine,
1230                    thread_num: 1,
1231                    query_file_path: TEST_QUERY_10PTS_128DIM,
1232                    truth_result_file_path: TEST_TRUTH_RESULT_10PTS_128DIM,
1233                    k: 10,
1234                    l: 20,
1235                    dim: 128,
1236                },
1237                None,
1238            );
1239
1240            // Test multi thread.
1241            test_disk_search_with_associated(
1242                TestDiskSearchAssociateParams {
1243                    storage_provider: fixture.storage_provider.as_ref(),
1244                    index_search_engine: &search_engine,
1245                    thread_num: 5,
1246                    query_file_path: TEST_QUERY_10PTS_128DIM,
1247                    truth_result_file_path: TEST_TRUTH_RESULT_10PTS_128DIM,
1248                    k: 10,
1249                    l: 20,
1250                    dim: 128,
1251                },
1252                None,
1253            );
1254        }
1255
1256        fixture
1257            .storage_provider
1258            .delete(&format!("{}_disk.index", index_path_prefix))
1259            .expect("Failed to delete file");
1260        fixture
1261            .storage_provider
1262            .delete(&format!("{}_pq_pivots.bin", index_path_prefix))
1263            .expect("Failed to delete file");
1264        fixture
1265            .storage_provider
1266            .delete(&format!("{}_pq_compressed.bin", index_path_prefix))
1267            .expect("Failed to delete file");
1268    }
1269
1270    struct CreateDiskIndexSearcherParams<'a> {
1271        max_thread_num: usize,
1272        pq_pivot_file_path: &'a str,
1273        pq_compressed_file_path: &'a str,
1274        index_path: &'a str,
1275        index_path_prefix: &'a str,
1276        io_limit: usize,
1277    }
1278
1279    impl Default for CreateDiskIndexSearcherParams<'_> {
1280        fn default() -> Self {
1281            Self {
1282                max_thread_num: 1,
1283                pq_pivot_file_path: "",
1284                pq_compressed_file_path: "",
1285                index_path: "",
1286                index_path_prefix: "",
1287                io_limit: usize::MAX,
1288            }
1289        }
1290    }
1291
1292    fn create_disk_index_searcher<Data>(
1293        params: CreateDiskIndexSearcherParams,
1294        storage_provider: &Arc<VirtualStorageProvider<OverlayFS>>,
1295    ) -> DiskIndexSearcher<
1296        Data,
1297        DiskVertexProviderFactory<Data, VirtualAlignedReaderFactory<OverlayFS>>,
1298    >
1299    where
1300        Data: GraphDataType<VectorIdType = u32>,
1301    {
1302        assert!(params.io_limit > 0);
1303
1304        let runtime = tokio::runtime::Builder::new_multi_thread()
1305            .worker_threads(params.max_thread_num)
1306            .build()
1307            .unwrap();
1308
1309        let disk_index_reader = DiskIndexReader::<Data::VectorDataType>::new(
1310            params.pq_pivot_file_path.to_string(),
1311            params.pq_compressed_file_path.to_string(),
1312            storage_provider.as_ref(),
1313        )
1314        .unwrap();
1315
1316        let aligned_reader_factory = VirtualAlignedReaderFactory::new(
1317            get_disk_index_file(params.index_path_prefix),
1318            Arc::clone(storage_provider),
1319        );
1320        let caching_strategy = CachingStrategy::None;
1321        let vertex_provider_factory =
1322            DiskVertexProviderFactory::<Data, _>::new(aligned_reader_factory, caching_strategy)
1323                .unwrap();
1324
1325        DiskIndexSearcher::<Data, DiskVertexProviderFactory<Data, _>>::new(
1326            params.max_thread_num,
1327            params.io_limit,
1328            &disk_index_reader,
1329            vertex_provider_factory,
1330            Metric::L2,
1331            Some(runtime),
1332        )
1333        .unwrap()
1334    }
1335
1336    fn load_query_result<StorageReader: StorageReadProvider>(
1337        storage_provider: &StorageReader,
1338        query_result_path: &str,
1339    ) -> Vec<u32> {
1340        let result =
1341            read_bin::<u32>(&mut storage_provider.open_reader(query_result_path).unwrap()).unwrap();
1342        result.into_inner().into_vec()
1343    }
1344
1345    struct TestDiskSearchParams<'a, StorageType> {
1346        storage_provider: &'a StorageType,
1347        index_search_engine: &'a DiskIndexSearcher<
1348            GraphDataF32VectorUnitData,
1349            DiskVertexProviderFactory<
1350                GraphDataF32VectorUnitData,
1351                VirtualAlignedReaderFactory<OverlayFS>,
1352            >,
1353        >,
1354        thread_num: u64,
1355        query_file_path: &'a str,
1356        truth_result_file_path: &'a str,
1357        k: usize,
1358        l: usize,
1359        dim: usize,
1360    }
1361
1362    struct TestDiskSearchAssociateParams<'a, StorageType> {
1363        storage_provider: &'a StorageType,
1364        index_search_engine: &'a DiskIndexSearcher<
1365            GraphDataF32VectorU32Data,
1366            DiskVertexProviderFactory<
1367                GraphDataF32VectorU32Data,
1368                VirtualAlignedReaderFactory<OverlayFS>,
1369            >,
1370        >,
1371        thread_num: u64,
1372        query_file_path: &'a str,
1373        truth_result_file_path: &'a str,
1374        k: usize,
1375        l: usize,
1376        dim: usize,
1377    }
1378
1379    fn test_disk_search<StorageType: StorageReadProvider>(
1380        params: TestDiskSearchParams<StorageType>,
1381    ) {
1382        let query_vector = load_aligned_bin(params.storage_provider, params.query_file_path)
1383            .unwrap()
1384            .0;
1385        let mut aligned_query = AlignedBoxWithSlice::<f32>::new(query_vector.len(), 32).unwrap();
1386        aligned_query.memcpy(query_vector.as_slice()).unwrap();
1387
1388        let queries = aligned_query
1389            .split_into_nonoverlapping_mut_slices(0..aligned_query.len(), params.dim)
1390            .unwrap();
1391
1392        let truth_result =
1393            load_query_result(params.storage_provider, params.truth_result_file_path);
1394
1395        let pool = create_thread_pool(params.thread_num.into_usize()).unwrap();
1396        // Convert query_vector to number of Vertex with data type f32 and dimension equals to dim.
1397        queries
1398            .par_iter()
1399            .enumerate()
1400            .for_each_in_pool(&pool, |(i, query)| {
1401                // Test search_with_associated_data with an unaligned query. Some distance functions require aligned data.
1402                let mut aligned_box = AlignedBoxWithSlice::<f32>::new(query.len() + 1, 32).unwrap();
1403                let mut temp = Vec::with_capacity(query.len() + 1);
1404                temp.push(0.0);
1405                temp.extend_from_slice(query);
1406                aligned_box.memcpy(temp.as_slice()).unwrap();
1407                let query = &aligned_box.as_slice()[1..];
1408
1409                let mut query_stats = QueryStatistics::default();
1410                let mut indices = vec![0u32; 10];
1411                let mut distances = vec![0f32; 10];
1412                let mut associated_data = vec![(); 10];
1413
1414                let result = params
1415                    .index_search_engine
1416                    //.search_with_associated_data(query, params.k as u32, params.l as u32)
1417                    .search_internal(
1418                        query,
1419                        params.k,
1420                        params.l as u32,
1421                        None, // beam_width
1422                        &mut query_stats,
1423                        &mut indices,
1424                        &mut distances,
1425                        &mut associated_data,
1426                        &(|_| true),
1427                        false,
1428                    );
1429
1430                // Calculate the range of the truth_result for this query
1431                let truth_slice = &truth_result[i * params.k..(i + 1) * params.k];
1432
1433                assert!(result.is_ok(), "Expected search to succeed");
1434
1435                let result_unwrapped = result.unwrap();
1436                assert!(
1437                    result_unwrapped.query_statistics.total_io_operations > 0,
1438                    "Expected IO operations to be greater than 0"
1439                );
1440                assert!(
1441                    result_unwrapped.query_statistics.total_vertices_loaded > 0,
1442                    "Expected vertices loaded to be greater than 0"
1443                );
1444
1445                // Compare res with truth_slice using assert_eq!
1446                assert_eq!(
1447                    indices, truth_slice,
1448                    "Results DO NOT match with the truth result for query {}",
1449                    i
1450                );
1451            });
1452    }
1453
1454    fn test_disk_search_with_associated<StorageType: StorageReadProvider>(
1455        params: TestDiskSearchAssociateParams<StorageType>,
1456        beam_width: Option<usize>,
1457    ) {
1458        let query_vector = load_aligned_bin(params.storage_provider, params.query_file_path)
1459            .unwrap()
1460            .0;
1461        let mut aligned_query = AlignedBoxWithSlice::<f32>::new(query_vector.len(), 32).unwrap();
1462        aligned_query.memcpy(query_vector.as_slice()).unwrap();
1463        let queries = aligned_query
1464            .split_into_nonoverlapping_mut_slices(0..aligned_query.len(), params.dim)
1465            .unwrap();
1466        let truth_result =
1467            load_query_result(params.storage_provider, params.truth_result_file_path);
1468        let pool = create_thread_pool(params.thread_num.into_usize()).unwrap();
1469        // Convert query_vector to number of Vertex with data type f32 and dimension equals to dim.
1470        queries
1471            .par_iter()
1472            .enumerate()
1473            .for_each_in_pool(&pool, |(i, query)| {
1474                // Test search_with_associated_data with an unaligned query. Some distance functions require aligned data.
1475                let mut aligned_box = AlignedBoxWithSlice::<f32>::new(query.len() + 1, 32).unwrap();
1476                let mut temp = Vec::with_capacity(query.len() + 1);
1477                temp.push(0.0);
1478                temp.extend_from_slice(query);
1479                aligned_box.memcpy(temp.as_slice()).unwrap();
1480                let query = &aligned_box.as_slice()[1..];
1481                let result = params
1482                    .index_search_engine
1483                    .search(query, params.k as u32, params.l as u32, beam_width, None, false)
1484                    .unwrap();
1485                let indices: Vec<u32> = result.results.iter().map(|item| item.vertex_id).collect();
1486                let associated_data: Vec<u32> =
1487                    result.results.iter().map(|item| item.data).collect();
1488                let truth_data = get_truth_associated_data(params.storage_provider);
1489                let associated_data_truth: Vec<u32> = indices
1490                    .iter()
1491                    .map(|&vid| truth_data[vid as usize])
1492                    .collect();
1493                assert_eq!(
1494                    associated_data, associated_data_truth,
1495                    "Associated data DO NOT match with the truth result for query {}, associated_data from search: {:?}, associated_data from truth result: {:?}",
1496                    i,associated_data, associated_data_truth
1497                );
1498                let truth_slice = &truth_result[i * params.k..(i + 1) * params.k];
1499                assert_eq!(
1500                    indices, truth_slice,
1501                    "Results DO NOT match with the truth result for query {}",
1502                    i
1503                );
1504            });
1505    }
1506
1507    #[test]
1508    fn test_disk_search_invalid_input() {
1509        let storage_provider = Arc::new(VirtualStorageProvider::new_overlay(test_data_root()));
1510        let ctx = &DefaultContext;
1511
1512        let params = CreateDiskIndexSearcherParams {
1513            max_thread_num: 5,
1514            pq_pivot_file_path: TEST_PQ_PIVOT_128DIM,
1515            pq_compressed_file_path: TEST_PQ_COMPRESSED_128DIM,
1516            index_path: TEST_INDEX_128DIM,
1517            index_path_prefix: TEST_INDEX_PREFIX_128DIM,
1518            ..Default::default()
1519        };
1520
1521        let paths = PQPathNames::for_disk_index(TEST_INDEX_PREFIX_128DIM);
1522        assert_eq!(
1523            paths.pivots, params.pq_pivot_file_path,
1524            "pq_pivot_file_path is not correct"
1525        );
1526        assert_eq!(
1527            paths.compressed_data, params.pq_compressed_file_path,
1528            "pq_compressed_file_path is not correct"
1529        );
1530        assert_eq!(
1531            params.index_path,
1532            format!("{}_disk.index", params.index_path_prefix),
1533            "index_path is not correct"
1534        );
1535
1536        // Test error case: l < k
1537        let res = Knn::new_default(20, 10);
1538        assert!(res.is_err());
1539        assert_eq!(
1540            <KnnSearchError as std::convert::Into<ANNError>>::into(res.unwrap_err()).kind(),
1541            ANNErrorKind::IndexError
1542        );
1543        // Test error case: beam_width = 0
1544        let res = Knn::new(10, 10, Some(0));
1545        assert!(res.is_err());
1546
1547        let search_engine =
1548            create_disk_index_searcher::<GraphDataF32VectorU32Data>(params, &storage_provider);
1549
1550        // minor validation tests to improve code coverage
1551        assert_eq!(
1552            search_engine
1553                .index
1554                .data_provider
1555                .to_external_id(ctx, 0)
1556                .unwrap(),
1557            0
1558        );
1559        assert_eq!(
1560            search_engine
1561                .index
1562                .data_provider
1563                .to_internal_id(ctx, &0)
1564                .unwrap(),
1565            0
1566        );
1567
1568        let provider_max_degree = search_engine
1569            .index
1570            .data_provider
1571            .graph_header
1572            .max_degree::<<GraphDataF32VectorU32Data as GraphDataType>::VectorDataType>()
1573            .unwrap();
1574        let index_max_degree = search_engine.index.config.pruned_degree().get();
1575        assert_eq!(provider_max_degree, index_max_degree);
1576
1577        let query = vec![0f32; 128];
1578        let mut query_stats = QueryStatistics::default();
1579        let mut indices = vec![0u32; 10];
1580        let mut distances = vec![0f32; 10];
1581        let mut associated_data = vec![0u32; 10];
1582
1583        // Set L: {} to a value of at least K:
1584        let result = search_engine.search_internal(
1585            &query,
1586            10,
1587            10 - 1,
1588            None,
1589            &mut query_stats,
1590            &mut indices,
1591            &mut distances,
1592            &mut associated_data,
1593            &|_| true,
1594            false,
1595        );
1596
1597        assert!(result.is_err());
1598        assert_eq!(result.unwrap_err().kind(), ANNErrorKind::IndexError);
1599    }
1600
1601    #[test]
1602    fn test_disk_search_beam_search() {
1603        let storage_provider = Arc::new(VirtualStorageProvider::new_overlay(test_data_root()));
1604
1605        let search_engine = create_disk_index_searcher::<GraphDataF32VectorUnitData>(
1606            CreateDiskIndexSearcherParams {
1607                max_thread_num: 1,
1608                pq_pivot_file_path: TEST_PQ_PIVOT,
1609                pq_compressed_file_path: TEST_PQ_COMPRESSED,
1610                index_path: TEST_INDEX,
1611                index_path_prefix: TEST_INDEX_PREFIX,
1612                ..Default::default()
1613            },
1614            &storage_provider,
1615        );
1616
1617        let query_vector: [f32; 128] = [1f32; 128];
1618        let mut indices = vec![0u32; 10];
1619        let mut distances = vec![0f32; 10];
1620        let mut associated_data = vec![(); 10];
1621
1622        let mut result_output_buffer = search_output_buffer::IdDistanceAssociatedData::new(
1623            &mut indices,
1624            &mut distances,
1625            &mut associated_data,
1626        );
1627        let strategy = search_engine.search_strategy(&query_vector, &|_| true);
1628        let mut search_record = VisitedSearchRecord::new(0);
1629        let search_params = Knn::new(10, 10, Some(4)).unwrap();
1630        let recorded_search =
1631            diskann::graph::search::RecordedKnn::new(search_params, &mut search_record);
1632        search_engine
1633            .runtime
1634            .block_on(search_engine.index.search(
1635                recorded_search,
1636                &strategy,
1637                &DefaultContext,
1638                query_vector.as_slice(),
1639                &mut result_output_buffer,
1640            ))
1641            .unwrap();
1642
1643        let ids = search_record
1644            .visited
1645            .iter()
1646            .map(|n| n.id)
1647            .collect::<Vec<_>>();
1648
1649        const EXPECTED_NODES: [u32; 18] = [
1650            72, 118, 108, 86, 84, 152, 170, 82, 114, 87, 207, 176, 79, 153, 67, 165, 141, 180,
1651        ]; //Expected nodes for query = [1f32; 128] with beam_width=4
1652
1653        assert_eq!(ids, &EXPECTED_NODES);
1654
1655        let return_list_size = 10;
1656        let search_list_size = 10;
1657        let result = search_engine.search(
1658            &query_vector,
1659            return_list_size,
1660            search_list_size,
1661            Some(4),
1662            None,
1663            false,
1664        );
1665        assert!(result.is_ok(), "Expected search to succeed");
1666        let search_result = result.unwrap();
1667        assert_eq!(
1668            search_result.results.len() as u32,
1669            return_list_size,
1670            "Expected result count to match"
1671        );
1672        assert_eq!(
1673            indices,
1674            vec![152, 72, 170, 118, 87, 165, 79, 141, 108, 86],
1675            "Expected indices to match"
1676        );
1677    }
1678
/// Exercises attribute-diverse k-NN search against the on-disk test index.
///
/// Only compiled with the `experimental_diversity_search` feature. The test
/// labels every vector with one of 15 attribute values, runs a first diverse
/// search as a smoke test, then runs a second one with `diverse_results_k = 1`
/// and checks that the returned distances are non-negative and ascending and
/// that no attribute value occurs more often than `diverse_results_k`.
#[cfg(feature = "experimental_diversity_search")]
#[test]
fn test_disk_search_diversity_search() {
    use diskann::graph::DiverseSearchParams;
    use diskann::neighbor::AttributeValueProvider;
    use std::collections::HashMap;

    // Minimal in-memory attribute provider: maps a vertex id to its label.
    #[derive(Debug, Clone)]
    struct TestAttributeProvider {
        attributes: HashMap<u32, u32>,
    }
    impl TestAttributeProvider {
        fn new() -> Self {
            Self {
                attributes: HashMap::new(),
            }
        }
        fn insert(&mut self, id: u32, attribute: u32) {
            self.attributes.insert(id, attribute);
        }
    }
    impl diskann::provider::HasId for TestAttributeProvider {
        type Id = u32;
    }

    impl AttributeValueProvider for TestAttributeProvider {
        type Value = u32;

        fn get(&self, id: Self::Id) -> Option<Self::Value> {
            self.attributes.get(&id).copied()
        }
    }

    let storage_provider = Arc::new(VirtualStorageProvider::new_overlay(test_data_root()));

    let search_engine = create_disk_index_searcher::<GraphDataF32VectorUnitData>(
        CreateDiskIndexSearcherParams {
            max_thread_num: 1,
            pq_pivot_file_path: TEST_PQ_PIVOT,
            pq_compressed_file_path: TEST_PQ_COMPRESSED,
            index_path: TEST_INDEX,
            index_path_prefix: TEST_INDEX_PREFIX,
            ..Default::default()
        },
        &storage_provider,
    );

    let query_vector: [f32; 128] = [1f32; 128];

    // Label every vector deterministically: `id % 15 + 1` yields labels 1..=15.
    let mut attribute_provider = TestAttributeProvider::new();
    let num_vectors = 256; // Number of vectors in the test dataset
    for id in 0..num_vectors {
        attribute_provider.insert(id, (id % 15) + 1);
    }
    // Wrap in an Arc once so both parameter sets share the map instead of cloning it.
    let attribute_provider = Arc::new(attribute_provider);

    // ---- First pass: smoke test with up to 3 results per label. ----
    let mut indices = vec![0u32; 10];
    let mut distances = vec![0f32; 10];
    let mut associated_data = vec![(); 10];

    let mut result_output_buffer = search_output_buffer::IdDistanceAssociatedData::new(
        &mut indices,
        &mut distances,
        &mut associated_data,
    );
    let strategy = search_engine.search_strategy(&query_vector, &|_| true);

    let diverse_params = DiverseSearchParams::new(
        0, // diverse_attribute_id
        3, // diverse_results_k
        Arc::clone(&attribute_provider),
    );

    let search_params = Knn::new(10, 20, None).unwrap();

    let diverse_search = diskann::graph::search::Diverse::new(search_params, diverse_params);
    let stats = search_engine
        .runtime
        .block_on(search_engine.index.search(
            diverse_search,
            &strategy,
            &DefaultContext,
            query_vector.as_slice(),
            &mut result_output_buffer,
        ))
        .unwrap();

    assert!(
        stats.result_count > 0,
        "Expected to get some results during diversity search"
    );

    // ---- Second pass: at most one result per label, checked in detail. ----
    let return_list_size = 10;
    let search_list_size = 20;
    let diverse_results_k = 1;
    let diverse_params = DiverseSearchParams::new(
        0, // diverse_attribute_id
        diverse_results_k,
        Arc::clone(&attribute_provider),
    );

    let mut indices2 = vec![0u32; return_list_size as usize];
    let mut distances2 = vec![0f32; return_list_size as usize];
    let mut associated_data2 = vec![(); return_list_size as usize];
    let mut result_output_buffer2 = search_output_buffer::IdDistanceAssociatedData::new(
        &mut indices2,
        &mut distances2,
        &mut associated_data2,
    );
    let strategy2 = search_engine.search_strategy(&query_vector, &|_| true);
    let search_params2 =
        Knn::new(return_list_size as usize, search_list_size as usize, None).unwrap();

    let diverse_search2 = diskann::graph::search::Diverse::new(search_params2, diverse_params);
    let stats = search_engine
        .runtime
        .block_on(search_engine.index.search(
            diverse_search2,
            &strategy2,
            &DefaultContext,
            query_vector.as_slice(),
            &mut result_output_buffer2,
        ))
        .unwrap();

    assert!(
        stats.result_count > 0,
        "Expected diversity search to return results"
    );
    assert!(
        stats.result_count <= return_list_size,
        "Expected result count to be <= {}",
        return_list_size
    );

    // Only the first `result_count` slots of the output buffers hold results.
    let returned = stats.result_count as usize;
    let result_ids = &indices2[..returned];
    let result_distances = &distances2[..returned];

    // Print search results with their attributes (visible with `--nocapture`).
    println!("\n=== Diversity Search Results ===");
    println!("Query: [1f32; 128]");
    println!("diverse_results_k: {}", diverse_results_k);
    println!("Total results: {}\n", stats.result_count);
    println!("{:<10} {:<15} {:<10}", "Vertex ID", "Distance", "Label");
    println!("{}", "-".repeat(35));
    for (id, distance) in result_ids.iter().zip(result_distances) {
        let attribute_value = attribute_provider.get(*id).unwrap_or(0);
        println!("{:<10} {:<15.2} {:<10}", id, distance, attribute_value);
    }

    // Every returned distance must be non-negative, including the last one.
    for distance in result_distances {
        assert!(*distance >= 0.0, "Expected non-negative distance");
    }
    // Results must come back ordered from nearest to farthest.
    for pair in result_distances.windows(2) {
        assert!(
            pair[0] <= pair[1],
            "Expected distances to be sorted in ascending order"
        );
    }

    // Count how often each label occurs among the returned ids.
    let mut attribute_counts = HashMap::new();
    for id in result_ids {
        if let Some(attribute_value) = attribute_provider.get(*id) {
            *attribute_counts.entry(attribute_value).or_insert(0) += 1;
        }
    }

    // Print attribute distribution
    println!("\n=== Attribute Distribution ===");
    let mut sorted_attrs: Vec<_> = attribute_counts.iter().collect();
    sorted_attrs.sort_by_key(|(k, _)| *k);
    for (attribute_value, count) in &sorted_attrs {
        println!(
            "Label {}: {} occurrences (max allowed: {})",
            attribute_value, count, diverse_results_k
        );
    }
    println!("Total unique labels: {}", attribute_counts.len());
    println!("================================\n");

    // With diverse_results_k = 1, each label may appear at most once.
    for (attribute_value, count) in &attribute_counts {
        assert!(
            *count <= diverse_results_k,
            "Attribute value {} appears {} times, which exceeds diverse_results_k of {}",
            attribute_value,
            count,
            diverse_results_k
        );
    }

    // The results are expected to span several of the 15 labels rather than
    // collapse onto a single one.
    assert!(
        attribute_counts.len() >= 2,
        "Expected at least 2 different attribute values for diversity, got {}",
        attribute_counts.len()
    );
}
1901
1902    #[rstest]
1903    // This case checks expected behavior of unfiltered search.
1904    #[case(
1905        |_id: &u32| true,
1906        false,
1907        10,
1908        vec![152, 118, 72, 170, 87, 141, 79, 207, 124, 86],
1909        vec![256101.7, 256675.3, 256709.69, 256712.5, 256760.08, 256958.5, 257006.1, 257025.7, 257105.67, 257107.67],
1910    )]
1911    // This case validates post-filtering using 2 ids which are not present in the unfiltered result set.
1912    // It is expected that the post-filtering will return an empty result
1913    #[case(
1914        |id: &u32| *id == 0 || *id == 1,
1915        false,
1916        0,
1917        vec![0; 10],
1918        vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
1919    )]
1920    // This case validates pre-filtering using 2 ids which are not present in the unfiltered result set.
1921    // It is expected that the pre-filtering will do search over matching ids
1922    #[case(
1923        |id: &u32| *id == 0 || *id == 1,
1924        true,
1925        2,
1926        vec![1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
1927        vec![257247.28, 258179.28, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
1928    )]
1929    // This case validates post-filtering using 3 ids from the unfiltered result set.
1930    // It is expected that the post-filtering will filter out non-matching ids
1931    #[case(
1932        |id: &u32| *id == 72 || *id == 87 || *id == 170,
1933        false,
1934        3,
1935        vec![72, 170, 87, 0, 0, 0, 0, 0, 0, 0],
1936        vec![256709.69, 256712.5, 256760.08, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
1937    )]
1938    // This case validates pre-filtering using 3 ids from the unfiltered result set.
1939    // It is expected that the pre-filtering will do search over matching ids
1940    #[case(
1941        |id: &u32| *id == 72 || *id == 87 || *id == 170,
1942        true,
1943        3,
1944        vec![72, 170, 87, 0, 0, 0, 0, 0, 0, 0],
1945        vec![256709.69, 256712.5, 256760.08, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
1946    )]
1947    fn test_search_with_vector_filter(
1948        #[case] vector_filter: fn(&u32) -> bool,
1949        #[case] is_flat_search: bool,
1950        #[case] expected_result_count: u32,
1951        #[case] expected_indices: Vec<u32>,
1952        #[case] expected_distances: Vec<f32>,
1953    ) {
1954        // Exact distances can vary slightly depending on the architecture used
1955        // to compute distances due to different unrolling strategies and SIMD widthd.
1956        //
1957        // This parameter allows for a small margin when matching distances.
1958        let check_distances = |got: &[f32], expected: &[f32]| -> bool {
1959            const ABS_TOLERANCE: f32 = 0.02;
1960            assert_eq!(got.len(), expected.len());
1961            for (i, (g, e)) in std::iter::zip(got.iter(), expected.iter()).enumerate() {
1962                if (g - e).abs() > ABS_TOLERANCE {
1963                    panic!(
1964                        "distances differ at position {} by more than {}\n\n\
1965                         got: {:?}\nexpected: {:?}",
1966                        i, ABS_TOLERANCE, got, expected,
1967                    );
1968                }
1969            }
1970            true
1971        };
1972
1973        let storage_provider = Arc::new(VirtualStorageProvider::new_overlay(test_data_root()));
1974
1975        let search_engine = create_disk_index_searcher::<GraphDataF32VectorUnitData>(
1976            CreateDiskIndexSearcherParams {
1977                max_thread_num: 5,
1978                pq_pivot_file_path: TEST_PQ_PIVOT_128DIM,
1979                pq_compressed_file_path: TEST_PQ_COMPRESSED_128DIM,
1980                index_path: TEST_INDEX_128DIM,
1981                index_path_prefix: TEST_INDEX_PREFIX_128DIM,
1982                ..Default::default()
1983            },
1984            &storage_provider,
1985        );
1986        let query = vec![0.1f32; 128];
1987        let mut query_stats = QueryStatistics::default();
1988        let mut indices = vec![0u32; 10];
1989        let mut distances = vec![0f32; 10];
1990        let mut associated_data = vec![(); 10];
1991
1992        let result = search_engine.search_internal(
1993            &query,
1994            10,
1995            10,
1996            None, // beam_width
1997            &mut query_stats,
1998            &mut indices,
1999            &mut distances,
2000            &mut associated_data,
2001            &vector_filter,
2002            is_flat_search,
2003        );
2004
2005        assert!(result.is_ok(), "Expected search to succeed");
2006        assert_eq!(
2007            result.unwrap().result_count,
2008            expected_result_count,
2009            "Expected result count to match"
2010        );
2011        assert_eq!(indices, expected_indices, "Expected indices to match");
2012        assert!(
2013            check_distances(&distances, &expected_distances),
2014            "Expected distances to match"
2015        );
2016
2017        let result_with_filter = search_engine.search(
2018            &query,
2019            10,
2020            10,
2021            None, // beam_width
2022            Some(Box::new(vector_filter)),
2023            is_flat_search,
2024        );
2025
2026        assert!(result_with_filter.is_ok(), "Expected search to succeed");
2027        let result_with_filter_unwrapped = result_with_filter.unwrap();
2028        assert_eq!(
2029            result_with_filter_unwrapped.stats.result_count, expected_result_count,
2030            "Expected result count to match"
2031        );
2032        let actual_indices = result_with_filter_unwrapped
2033            .results
2034            .iter()
2035            .map(|x| x.vertex_id)
2036            .collect::<Vec<_>>();
2037        assert_eq!(
2038            actual_indices, expected_indices,
2039            "Expected indices to match"
2040        );
2041        let actual_distances = result_with_filter_unwrapped
2042            .results
2043            .iter()
2044            .map(|x| x.distance)
2045            .collect::<Vec<_>>();
2046        assert!(
2047            check_distances(&actual_distances, &expected_distances),
2048            "Expected distances to match"
2049        );
2050    }
2051
/// Runs a recorded beam search with a small `io_limit` and checks two things:
/// the IO count reported by the strategy's tracker stays within the limit, and
/// at least 60% of the reference node list still shows up among the visited
/// nodes.
#[test]
fn test_beam_search_respects_io_limit() {
    // Deliberately small so the search is expected to run into the limit.
    let io_limit = 11;

    let storage_provider = Arc::new(VirtualStorageProvider::new_overlay(test_data_root()));
    let searcher_params = CreateDiskIndexSearcherParams {
        max_thread_num: 1,
        pq_pivot_file_path: TEST_PQ_PIVOT,
        pq_compressed_file_path: TEST_PQ_COMPRESSED,
        index_path: TEST_INDEX,
        index_path_prefix: TEST_INDEX_PREFIX,
        io_limit,
    };
    let search_engine = create_disk_index_searcher::<GraphDataF32VectorUnitData>(
        searcher_params,
        &storage_provider,
    );

    let query_vector: [f32; 128] = [1f32; 128];

    // Output buffers for the top-10 results.
    let mut indices = vec![0u32; 10];
    let mut distances = vec![0f32; 10];
    let mut associated_data = vec![(); 10];
    let mut result_output_buffer = search_output_buffer::IdDistanceAssociatedData::new(
        &mut indices,
        &mut distances,
        &mut associated_data,
    );

    let strategy = search_engine.search_strategy(&query_vector, &|_| true);

    // Record every visited node so it can be compared against the reference list.
    let mut search_record = VisitedSearchRecord::new(0);
    let recorded_search = diskann::graph::search::RecordedKnn::new(
        Knn::new(10, 10, Some(4)).unwrap(),
        &mut search_record,
    );
    let search_future = search_engine.index.search(
        recorded_search,
        &strategy,
        &DefaultContext,
        query_vector.as_slice(),
        &mut result_output_buffer,
    );
    search_engine.runtime.block_on(search_future).unwrap();

    let visited_ids: Vec<_> = search_record.visited.iter().map(|n| n.id).collect();

    // The IO limit must have been respected.
    let io_tracker = strategy.io_tracker;
    assert!(
        io_tracker.io_count() <= io_limit,
        "Expected IO operations to be <= {}, but got {}",
        io_limit,
        io_tracker.io_count()
    );

    // Expected nodes for query = [1f32; 128] with beam_width=4
    const EXPECTED_NODES: [u32; 17] = [
        72, 118, 108, 86, 84, 152, 170, 82, 114, 87, 207, 176, 79, 153, 67, 165, 141,
    ];

    // Percentage of the reference nodes that were visited.
    let hits = EXPECTED_NODES
        .iter()
        .filter(|&node| visited_ids.contains(node))
        .count();
    let recall = (hits as f32 / EXPECTED_NODES.len() as f32) * 100.0;

    // The 60% threshold here is arbitrary; it only makes sure that when the
    // search hits io_limit it doesn't break and the recall degrades gracefully.
    assert!(recall >= 60.0, "Match percentage is below 60%: {}", recall);
}
2130}