Skip to main content

diskann_disk/search/provider/
disk_provider.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use std::{
7    collections::HashMap,
8    future::Future,
9    num::NonZeroUsize,
10    ops::Range,
11    sync::{
12        atomic::{AtomicU64, AtomicUsize},
13        Arc,
14    },
15    time::Instant,
16};
17
18use diskann::{
19    graph::{
20        self,
21        glue::{self, ExpandBeam, IdIterator, SearchExt, SearchPostProcess, SearchStrategy},
22        search_output_buffer, AdjacencyList, DiskANNIndex, SearchOutputBuffer, SearchParams,
23    },
24    neighbor::Neighbor,
25    provider::{
26        Accessor, BuildQueryComputer, DataProvider, DefaultContext, DelegateNeighbor, HasId,
27        NeighborAccessor,
28    },
29    utils::{
30        object_pool::{ObjectPool, PoolOption, TryAsPooled},
31        IntoUsize, VectorRepr,
32    },
33    ANNError, ANNResult,
34};
35use diskann_providers::storage::StorageReadProvider;
36use diskann_providers::{
37    model::{
38        compute_pq_distance, compute_pq_distance_for_pq_coordinates, graph::traits::GraphDataType,
39        pq::quantizer_preprocess, PQData, PQScratch,
40    },
41    storage::{get_compressed_pq_file, get_disk_index_file, get_pq_pivot_file, LoadWith},
42};
43use diskann_vector::{distance::Metric, DistanceFunction, PreprocessedDistanceFunction};
44use futures_util::future;
45use tokio::runtime::Runtime;
46use tracing::debug;
47
48use crate::{
49    data_model::{CachingStrategy, GraphHeader},
50    filter_parameter::{default_vector_filter, VectorFilter},
51    search::{
52        provider::disk_vertex_provider_factory::DiskVertexProviderFactory,
53        traits::{VertexProvider, VertexProviderFactory},
54    },
55    storage::{api::AsyncDiskLoadContext, disk_index_reader::DiskIndexReader},
56    utils::AlignedFileReaderFactory,
57    utils::QueryStatistics,
58};
59
60///////////////////
61// Disk Provider //
62///////////////////
63
/// A data provider that loads data from disk using the disk readers.
///
/// The data format for disk is different from that of the in-memory providers: the disk
/// format stores both the vectors and the adjacency list next to each other for better
/// locality and quicker access.
///
/// Please refer to the RFC documentation at
/// `docs/rfcs/cy2025/disk_provider_for_async_index.md` for design details.
pub struct DiskProvider<Data>
where
    Data: GraphDataType<VectorIdType = u32>,
{
    /// Holds the graph header information that contains metadata about disk-index file.
    graph_header: GraphHeader,

    /// Full precision distance comparer used in post_process to reorder results.
    distance_comparer: <Data::VectorDataType as VectorRepr>::Distance,

    /// The PQ data used for quantization.
    pq_data: Arc<PQData>,

    /// The number of points in the graph.
    num_points: usize,

    /// Metric used for distance computation.
    metric: Metric,

    /// The number of IO operations that can be done in parallel.
    search_io_limit: usize,
}
91
/// `DataProvider` implementation for the disk provider.
///
/// Internal and external ids are both `u32` and map onto each other unchanged.
impl<Data> DataProvider for DiskProvider<Data>
where
    Data: GraphDataType<VectorIdType = u32>,
{
    type Context = DefaultContext;

    type InternalId = u32;

    type ExternalId = u32;

    type Error = ANNError;

    /// Translate an external id to its corresponding internal id.
    ///
    /// This is the identity mapping and never fails.
    fn to_internal_id(
        &self,
        _context: &DefaultContext,
        gid: &Self::ExternalId,
    ) -> Result<Self::InternalId, Self::Error> {
        Ok(*gid)
    }

    /// Translate an internal id to its corresponding external id.
    ///
    /// This is the identity mapping and never fails.
    fn to_external_id(
        &self,
        _context: &DefaultContext,
        id: Self::InternalId,
    ) -> Result<Self::ExternalId, Self::Error> {
        Ok(id)
    }
}
122
123impl<Data> LoadWith<AsyncDiskLoadContext> for DiskProvider<Data>
124where
125    Data: GraphDataType<VectorIdType = u32>,
126{
127    type Error = ANNError;
128
129    async fn load_with<P>(provider: &P, ctx: &AsyncDiskLoadContext) -> ANNResult<Self>
130    where
131        P: StorageReadProvider,
132    {
133        debug!(
134            "DiskProvider::load_with() called with file: {:?}",
135            get_disk_index_file(ctx.quant_load_context.metadata.prefix())
136        );
137
138        let graph_header = {
139            let aligned_reader_factory = AlignedFileReaderFactory::new(get_disk_index_file(
140                ctx.quant_load_context.metadata.prefix(),
141            ));
142
143            let caching_strategy = if ctx.num_nodes_to_cache > 0 {
144                CachingStrategy::StaticCacheWithBfsNodes(ctx.num_nodes_to_cache)
145            } else {
146                CachingStrategy::None
147            };
148
149            let vertex_provider_factory = DiskVertexProviderFactory::<Data, _>::new(
150                aligned_reader_factory,
151                caching_strategy,
152            )?;
153            VertexProviderFactory::get_header(&vertex_provider_factory)?
154        };
155
156        let metric = ctx.quant_load_context.metric;
157        let num_points = ctx.num_points;
158
159        let index_path_prefix = ctx.quant_load_context.metadata.prefix();
160        let index_reader = DiskIndexReader::<<Data as GraphDataType>::VectorDataType>::new(
161            get_pq_pivot_file(index_path_prefix),
162            get_compressed_pq_file(index_path_prefix),
163            provider,
164        )?;
165
166        Self::new(
167            &index_reader,
168            graph_header,
169            metric,
170            num_points,
171            ctx.search_io_limit,
172        )
173    }
174}
175
176impl<Data> DiskProvider<Data>
177where
178    Data: GraphDataType<VectorIdType = u32>,
179{
180    fn new(
181        disk_index_reader: &DiskIndexReader<Data::VectorDataType>,
182        graph_header: GraphHeader,
183        metric: Metric,
184        num_points: usize,
185        search_io_limit: usize,
186    ) -> ANNResult<Self> {
187        let distance_comparer =
188            Data::VectorDataType::distance(metric, Some(graph_header.metadata().dims));
189
190        let pq_data = disk_index_reader.get_pq_data();
191
192        Ok(Self {
193            graph_header,
194            distance_comparer,
195            pq_data,
196            num_points,
197            metric,
198            search_io_limit,
199        })
200    }
201}
202
/// The search strategy for the disk provider. This is used to create the search accessor
/// for use in search in quant space and post_process function to reorder with full precision vectors.
///
/// # Why vertex_provider_factory and scratch_pool are here instead of DiskProvider
///
/// The DataProvider trait requires 'static bounds for multi-threaded async contexts,
/// but vertex_provider_factory may have non-'static lifetime bounds (e.g., borrowing
/// from local data structures). Moving these components to the search strategy allows
/// DiskProvider to satisfy 'static constraints while enabling flexible per-search
/// resource management.
pub struct DiskSearchStrategy<'a, Data, ProviderFactory>
where
    Data: GraphDataType<VectorIdType = u32>,
    ProviderFactory: VertexProviderFactory<Data>,
{
    // IO statistics for the search. The tracker is built from atomics because
    // DiskSearchStrategy has a "Send" trait bound, though it is not expected to be shared
    // across threads.
    io_tracker: IOTracker,
    vector_filter: &'a (dyn Fn(&u32) -> bool + Send + Sync), // Fn param is u32 as we validate "VectorIdType = u32" everywhere in this provider in trait bounds.
    /// The query vector this strategy searches for.
    query: &'a [Data::VectorDataType],

    /// The vertex provider factory is used to create the vertex provider for each search instance.
    vertex_provider_factory: &'a ProviderFactory,

    /// Scratch pool for disk search operations that need allocations.
    scratch_pool: &'a Arc<ObjectPool<DiskSearchScratch<Data, ProviderFactory::VertexProviderType>>>,
}
229
// Struct to track IO. This is used by single thread, but needs to be Atomic as the Strategy has "Send" trait bound.
// There should be minimal to no overhead compared to using a raw reference.
//
// `Default` yields a tracker with every counter at zero.
#[derive(Default)]
struct IOTracker {
    // Accumulated time spent loading vertices, in microseconds.
    io_time_us: AtomicU64,
    // Accumulated time spent on query preprocessing, in microseconds.
    preprocess_time_us: AtomicU64,
    // Number of vertex loads issued so far.
    io_count: AtomicUsize,
}

impl IOTracker {
    // Add `time` (microseconds) to the given time counter.
    fn add_time(category: &AtomicU64, time: u64) {
        category.fetch_add(time, std::sync::atomic::Ordering::Relaxed);
    }

    // Read the current value (microseconds) of the given time counter.
    fn time(category: &AtomicU64) -> u64 {
        category.load(std::sync::atomic::Ordering::Relaxed)
    }

    // Record `count` additional IO operations.
    fn add_io_count(&self, count: usize) {
        self.io_count
            .fetch_add(count, std::sync::atomic::Ordering::Relaxed);
    }

    // Number of IO operations recorded so far.
    fn io_count(&self) -> usize {
        self.io_count.load(std::sync::atomic::Ordering::Relaxed)
    }
}
266
/// Post-processor that drops candidates rejected by `filter` and reorders the rest by
/// full-precision distance (see its `SearchPostProcess` implementation).
#[derive(Clone, Copy)]
pub struct RerankAndFilter<'a> {
    /// Predicate over vertex ids; only ids for which it returns `true` are kept.
    filter: &'a (dyn Fn(&u32) -> bool + Send + Sync),
}

impl<'a> RerankAndFilter<'a> {
    /// Wrap the given id predicate.
    fn new(filter: &'a (dyn Fn(&u32) -> bool + Send + Sync)) -> Self {
        Self { filter }
    }
}
277
278impl<Data, VP>
279    SearchPostProcess<
280        DiskAccessor<'_, Data, VP>,
281        [Data::VectorDataType],
282        (
283            <DiskProvider<Data> as DataProvider>::InternalId,
284            Data::AssociatedDataType,
285        ),
286    > for RerankAndFilter<'_>
287where
288    Data: GraphDataType<VectorIdType = u32>,
289    VP: VertexProvider<Data>,
290{
291    type Error = ANNError;
292    async fn post_process<I, B>(
293        &self,
294        accessor: &mut DiskAccessor<'_, Data, VP>,
295        query: &[Data::VectorDataType],
296        _computer: &DiskQueryComputer,
297        candidates: I,
298        output: &mut B,
299    ) -> Result<usize, Self::Error>
300    where
301        I: Iterator<Item = Neighbor<u32>> + Send,
302        B: SearchOutputBuffer<(u32, Data::AssociatedDataType)> + Send + ?Sized,
303    {
304        let provider = accessor.provider;
305
306        let mut uncached_ids = Vec::new();
307        let mut reranked = candidates
308            .map(|n| n.id)
309            .filter(|id| (self.filter)(id))
310            .filter_map(|n| {
311                if let Some(entry) = accessor.scratch.distance_cache.get(&n) {
312                    Some(Ok::<((u32, _), f32), ANNError>(((n, entry.1), entry.0)))
313                } else {
314                    uncached_ids.push(n);
315                    None
316                }
317            })
318            .collect::<Result<Vec<_>, _>>()?;
319        if !uncached_ids.is_empty() {
320            ensure_vertex_loaded(&mut accessor.scratch.vertex_provider, &uncached_ids)?;
321            for n in &uncached_ids {
322                let v = accessor.scratch.vertex_provider.get_vector(n)?;
323                let d = provider.distance_comparer.evaluate_similarity(query, v);
324                let a = accessor.scratch.vertex_provider.get_associated_data(n)?;
325                reranked.push(((*n, *a), d));
326            }
327        }
328
329        // Sort the full precision distances.
330        reranked
331            .sort_unstable_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
332        // Store the reranked results.
333        Ok(output.extend(reranked))
334    }
335}
336
/// `SearchStrategy` implementation that searches a `DiskProvider` through a
/// `DiskAccessor` and post-processes with `RerankAndFilter`.
impl<'this, Data, ProviderFactory>
    SearchStrategy<
        DiskProvider<Data>,
        [Data::VectorDataType],
        (
            <DiskProvider<Data> as DataProvider>::InternalId,
            Data::AssociatedDataType,
        ),
    > for DiskSearchStrategy<'this, Data, ProviderFactory>
where
    Data: GraphDataType<VectorIdType = u32>,
    ProviderFactory: VertexProviderFactory<Data>,
{
    type QueryComputer = DiskQueryComputer;
    type SearchAccessor<'a> = DiskAccessor<'a, Data, ProviderFactory::VertexProviderType>;
    type SearchAccessorError = ANNError;
    type PostProcessor = RerankAndFilter<'this>;

    /// Create the accessor for one search, wiring in this strategy's IO tracker, query,
    /// vertex provider factory and scratch pool.
    ///
    /// # Errors
    /// Returns whatever `DiskAccessor::new` reports.
    fn search_accessor<'a>(
        &'a self,
        provider: &'a DiskProvider<Data>,
        _context: &DefaultContext,
    ) -> Result<Self::SearchAccessor<'a>, Self::SearchAccessorError> {
        DiskAccessor::new(
            provider,
            &self.io_tracker,
            self.query,
            self.vertex_provider_factory,
            self.scratch_pool,
        )
    }

    /// Create the post-processor, which applies this strategy's vector filter.
    fn post_processor(&self) -> Self::PostProcessor {
        RerankAndFilter::new(self.vector_filter)
    }
}
373
/// The query computer for the disk provider. This is used to compute the distance between the query vector and the PQ coordinates.
pub struct DiskQueryComputer {
    /// Number of PQ chunks, taken from the provider's PQ data when the computer is built.
    num_pq_chunks: usize,
    /// Copy of the PQ distance table held in the search scratch when the computer is built.
    query_centroid_l2_distance: Vec<f32>,
}
379
380impl PreprocessedDistanceFunction<&[u8], f32> for DiskQueryComputer {
381    fn evaluate_similarity(&self, changing: &[u8]) -> f32 {
382        let mut dist = 0.0f32;
383        #[allow(clippy::expect_used)]
384        compute_pq_distance_for_pq_coordinates(
385            changing,
386            self.num_pq_chunks,
387            &self.query_centroid_l2_distance,
388            std::slice::from_mut(&mut dist),
389        )
390        .expect("PQ distance compute for PQ coordinates is expected to succeed");
391        dist
392    }
393}
394
395impl<Data, VP> BuildQueryComputer<[Data::VectorDataType]> for DiskAccessor<'_, Data, VP>
396where
397    Data: GraphDataType<VectorIdType = u32>,
398    VP: VertexProvider<Data>,
399{
400    type QueryComputerError = ANNError;
401    type QueryComputer = DiskQueryComputer;
402
403    fn build_query_computer(
404        &self,
405        _from: &[Data::VectorDataType],
406    ) -> Result<Self::QueryComputer, Self::QueryComputerError> {
407        Ok(DiskQueryComputer {
408            num_pq_chunks: self.provider.pq_data.get_num_chunks(),
409            query_centroid_l2_distance: self
410                .scratch
411                .pq_scratch
412                .aligned_pqtable_dist_scratch
413                .as_slice()
414                .to_vec(),
415        })
416    }
417
418    async fn distances_unordered<Itr, F>(
419        &mut self,
420        vec_id_itr: Itr,
421        _computer: &Self::QueryComputer,
422        f: F,
423    ) -> Result<(), Self::GetError>
424    where
425        F: Send + FnMut(f32, Self::Id),
426        Itr: Iterator<Item = Self::Id>,
427    {
428        self.pq_distances(&vec_id_itr.collect::<Box<[_]>>(), f)
429    }
430}
431
432impl<Data, VP> ExpandBeam<[Data::VectorDataType]> for DiskAccessor<'_, Data, VP>
433where
434    Data: GraphDataType<VectorIdType = u32>,
435    VP: VertexProvider<Data>,
436{
437    fn expand_beam<Itr, P, F>(
438        &mut self,
439        ids: Itr,
440        _computer: &Self::QueryComputer,
441        mut pred: P,
442        mut f: F,
443    ) -> impl std::future::Future<Output = Result<(), Self::GetError>> + Send
444    where
445        Itr: Iterator<Item = Self::Id> + Send,
446        P: glue::HybridPredicate<Self::Id> + Send + Sync,
447        F: FnMut(f32, Self::Id) + Send,
448    {
449        let result = (|| {
450            let io_limit = self.provider.search_io_limit - self.io_tracker.io_count();
451            let load_ids: Box<[_]> = ids.take(io_limit).collect();
452
453            self.ensure_loaded(&load_ids)?;
454            let mut ids = Vec::new();
455            for i in load_ids {
456                ids.clear();
457                ids.extend(
458                    self.scratch
459                        .vertex_provider
460                        .get_adjacency_list(&i)?
461                        .iter()
462                        .copied()
463                        .filter(|id| pred.eval_mut(id)),
464                );
465
466                self.pq_distances(&ids, &mut f)?;
467            }
468
469            Ok(())
470        })();
471
472        std::future::ready(result)
473    }
474}
475
// Scratch space for disk search operations that need allocations.
// These allocations are amortized across searches using the scratch pool.
struct DiskSearchScratch<Data, VP>
where
    Data: GraphDataType<VectorIdType = u32>,
    VP: VertexProvider<Data>,
{
    // Full-precision distance and associated data of vertices loaded during the search,
    // keyed by vertex id. Consulted again when reranking.
    distance_cache: HashMap<u32, (f32, Data::AssociatedDataType)>,
    // Buffers used for the PQ distance computations.
    pq_scratch: PQScratch,
    // Source of vectors, adjacency lists and associated data for loaded vertices.
    vertex_provider: VP,
}
487
// Arguments needed to create a `DiskSearchScratch` for the scratch pool.
#[derive(Clone)]
struct DiskSearchScratchArgs<'a, ProviderFactory> {
    // The first four values are passed straight to `PQScratch::new`.
    graph_degree: usize,
    dim: usize,
    num_pq_chunks: usize,
    num_pq_centers: usize,
    // Factory used to create the scratch's vertex provider.
    vertex_factory: &'a ProviderFactory,
    // Header handed to the factory when creating the vertex provider.
    graph_header: &'a GraphHeader,
}
497
498impl<Data, ProviderFactory> TryAsPooled<&DiskSearchScratchArgs<'_, ProviderFactory>>
499    for DiskSearchScratch<Data, ProviderFactory::VertexProviderType>
500where
501    Data: GraphDataType<VectorIdType = u32>,
502    ProviderFactory: VertexProviderFactory<Data>,
503{
504    type Error = ANNError;
505
506    fn try_create(args: &DiskSearchScratchArgs<ProviderFactory>) -> Result<Self, Self::Error> {
507        let pq_scratch = PQScratch::new(
508            args.graph_degree,
509            args.dim,
510            args.num_pq_chunks,
511            args.num_pq_centers,
512        )?;
513
514        const DEFAULT_BEAM_WIDTH: usize = 0; // Setting as 0 to avoid preallocation of memory.
515        let vertex_provider = args
516            .vertex_factory
517            .create_vertex_provider(DEFAULT_BEAM_WIDTH, args.graph_header)?;
518
519        Ok(Self {
520            distance_cache: HashMap::new(),
521            pq_scratch,
522            vertex_provider,
523        })
524    }
525
526    fn try_modify(
527        &mut self,
528        _args: &DiskSearchScratchArgs<ProviderFactory>,
529    ) -> Result<(), Self::Error> {
530        self.distance_cache.clear();
531        self.vertex_provider.clear();
532        Ok(())
533    }
534}
535
/// Accessor used by a single search over a [`DiskProvider`].
pub struct DiskAccessor<'a, Data, VP>
where
    Data: GraphDataType<VectorIdType = u32>,
    VP: VertexProvider<Data>,
{
    /// The provider being searched.
    provider: &'a DiskProvider<Data>,
    /// Tracker that accumulates IO counts and timings for the search.
    io_tracker: &'a IOTracker,
    /// Scratch space obtained from the scratch pool.
    scratch: PoolOption<DiskSearchScratch<Data, VP>>,
    /// The query vector, used for full-precision distance computations.
    query: &'a [Data::VectorDataType],
}
546
547impl<Data, VP> DiskAccessor<'_, Data, VP>
548where
549    Data: GraphDataType<VectorIdType = u32>,
550    VP: VertexProvider<Data>,
551{
552    // Compute the PQ distance between each ID in `ids` and the distance table stored in
553    // `self`, invoking the callback with the results of each computation in order.
554    fn pq_distances<F>(&mut self, ids: &[u32], mut f: F) -> ANNResult<()>
555    where
556        F: FnMut(f32, u32),
557    {
558        let pq_scratch = &mut self.scratch.pq_scratch;
559        compute_pq_distance(
560            ids,
561            self.provider.pq_data.get_num_chunks(),
562            &pq_scratch.aligned_pqtable_dist_scratch,
563            self.provider.pq_data.pq_compressed_data().get_data(),
564            &mut pq_scratch.aligned_pq_coord_scratch,
565            &mut pq_scratch.aligned_dist_scratch,
566        )?;
567
568        for (i, id) in ids.iter().enumerate() {
569            let distance = self.scratch.pq_scratch.aligned_dist_scratch[i];
570            f(distance, *id);
571        }
572
573        Ok(())
574    }
575}
576
577impl<Data, VP> SearchExt for DiskAccessor<'_, Data, VP>
578where
579    Data: GraphDataType<VectorIdType = u32>,
580    VP: VertexProvider<Data>,
581{
582    async fn starting_points(&self) -> ANNResult<Vec<u32>> {
583        let start_vertex_id = self.provider.graph_header.metadata().medoid as u32;
584        Ok(vec![start_vertex_id])
585    }
586
587    fn terminate_early(&mut self) -> bool {
588        self.io_tracker.io_count() > self.provider.search_io_limit
589    }
590}
591
592impl<'a, Data, VP> DiskAccessor<'a, Data, VP>
593where
594    Data: GraphDataType<VectorIdType = u32>,
595    VP: VertexProvider<Data>,
596{
597    fn new<VPF>(
598        provider: &'a DiskProvider<Data>,
599        io_tracker: &'a IOTracker,
600        query: &'a [Data::VectorDataType],
601        vertex_provider_factory: &'a VPF,
602        scratch_pool: &'a Arc<ObjectPool<DiskSearchScratch<Data, VP>>>,
603    ) -> ANNResult<Self>
604    where
605        VPF: VertexProviderFactory<Data, VertexProviderType = VP>,
606    {
607        let mut scratch = PoolOption::try_pooled(
608            scratch_pool,
609            &DiskSearchScratchArgs {
610                graph_degree: provider.graph_header.max_degree::<Data::VectorDataType>()?,
611                dim: provider.graph_header.metadata().dims,
612                num_pq_chunks: provider.pq_data.get_num_chunks(),
613                num_pq_centers: provider.pq_data.get_num_centers(),
614                vertex_factory: vertex_provider_factory,
615                graph_header: &provider.graph_header,
616            },
617        )?;
618
619        scratch.pq_scratch.set(
620            provider.graph_header.metadata().dims,
621            query,
622            1.0_f32, // Normalization factor
623        )?;
624        let start_vertex_id = provider.graph_header.metadata().medoid as u32;
625
626        let timer = Instant::now();
627        quantizer_preprocess(
628            &mut scratch.pq_scratch,
629            &provider.pq_data,
630            provider.metric,
631            &[start_vertex_id],
632        )?;
633        IOTracker::add_time(
634            &io_tracker.preprocess_time_us,
635            timer.elapsed().as_micros() as u64,
636        );
637
638        Ok(Self {
639            provider,
640            io_tracker,
641            scratch,
642            query,
643        })
644    }
645    fn ensure_loaded(&mut self, ids: &[u32]) -> Result<(), ANNError> {
646        if ids.is_empty() {
647            return Ok(());
648        }
649        let scratch = &mut self.scratch;
650        let timer = Instant::now();
651        ensure_vertex_loaded(&mut scratch.vertex_provider, ids)?;
652        IOTracker::add_time(
653            &self.io_tracker.io_time_us,
654            timer.elapsed().as_micros() as u64,
655        );
656        self.io_tracker.add_io_count(ids.len());
657        for id in ids {
658            let distance = self
659                .provider
660                .distance_comparer
661                .evaluate_similarity(self.query, scratch.vertex_provider.get_vector(id)?);
662            let associated_data = *scratch.vertex_provider.get_associated_data(id)?;
663            scratch
664                .distance_cache
665                .insert(*id, (distance, associated_data));
666        }
667        Ok(())
668    }
669}
670
/// The accessor identifies vertices by `u32`, matching the `VectorIdType = u32` bound.
impl<Data, VP> HasId for DiskAccessor<'_, Data, VP>
where
    Data: GraphDataType<VectorIdType = u32>,
    VP: VertexProvider<Data>,
{
    type Id = u32;
}
678
impl<'a, Data, VP> Accessor for DiskAccessor<'a, Data, VP>
where
    Data: GraphDataType<VectorIdType = u32>,
    VP: VertexProvider<Data>,
{
    /// This references the PQ vector in the underlying `pq_data` store.
    type Extended = &'a [u8];

    /// This accessor returns raw slices. There *is* a chance of racing when the fast
    /// providers are used. We just have to live with it.
    ///
    /// Since the underlying PQ store is shared, we ignore the `'b` lifetime here and
    /// instead use `'a`.
    type Element<'b>
        = &'a [u8]
    where
        Self: 'b;

    /// `ElementRef` can have arbitrary lifetimes.
    type ElementRef<'b> = &'b [u8];

    /// Error type returned by `get_element`.
    ///
    /// NOTE(review): an earlier comment here said out-of-bounds accesses panic rather
    /// than propagate an error, but `get_element` forwards the `Result` of
    /// `get_compressed_vector` — confirm which behavior applies.
    type GetError = ANNError;

    /// Return the compressed PQ vector for `id` from the provider's PQ data, as a ready
    /// future.
    fn get_element(
        &mut self,
        id: Self::Id,
    ) -> impl Future<Output = Result<Self::Element<'_>, Self::GetError>> + Send {
        std::future::ready(self.provider.pq_data.get_compressed_vector(id as usize))
    }
}
710
impl<Data, VP> IdIterator<Range<u32>> for DiskAccessor<'_, Data, VP>
where
    Data: GraphDataType<VectorIdType = u32>,
    VP: VertexProvider<Data>,
{
    /// Iterate over every vertex id, i.e. `0..num_points`.
    ///
    /// The `usize` point count is narrowed to `u32` with `as`, which truncates silently;
    /// this presumably relies on the index never holding more than `u32::MAX` points —
    /// TODO confirm.
    async fn id_iterator(&mut self) -> Result<Range<u32>, ANNError> {
        Ok(0..self.provider.num_points as u32)
    }
}
720
/// Neighbor access is delegated to [`AsNeighborAccessor`], a wrapper around a mutable
/// borrow of the accessor.
impl<'a, 'b, Data, VP> DelegateNeighbor<'a> for DiskAccessor<'b, Data, VP>
where
    Data: GraphDataType<VectorIdType = u32>,
    VP: VertexProvider<Data>,
{
    type Delegate = AsNeighborAccessor<'a, 'b, Data, VP>;
    /// Wrap `self` in the delegate for the duration of the `'a` borrow.
    fn delegate_neighbor(&'a mut self) -> Self::Delegate {
        AsNeighborAccessor(self)
    }
}
731
/// A light-weight wrapper around `&mut DiskAccessor` used to tailor the semantics of
/// [`NeighborAccessor`].
///
/// This implementation ensures that the vector data for adjacency lists is also retrieved
/// and cached to enhance reranking.
///
/// NOTE(review): `get_neighbors` loads the vertex through the vertex provider but does
/// not itself write to the scratch's distance cache — confirm where the caching described
/// above takes place.
pub struct AsNeighborAccessor<'a, 'b, Data, VP>(&'a mut DiskAccessor<'b, Data, VP>)
where
    Data: GraphDataType<VectorIdType = u32>,
    VP: VertexProvider<Data>;

/// Uses the same `u32` id type as the wrapped [`DiskAccessor`].
impl<Data, VP> HasId for AsNeighborAccessor<'_, '_, Data, VP>
where
    Data: GraphDataType<VectorIdType = u32>,
    VP: VertexProvider<Data>,
{
    type Id = u32;
}
749
750impl<Data, VP> NeighborAccessor for AsNeighborAccessor<'_, '_, Data, VP>
751where
752    Data: GraphDataType<VectorIdType = u32>,
753    VP: VertexProvider<Data>,
754{
755    fn get_neighbors(
756        self,
757        id: Self::Id,
758        neighbors: &mut AdjacencyList<Self::Id>,
759    ) -> impl Future<Output = ANNResult<Self>> + Send {
760        if self.0.io_tracker.io_count() > self.0.provider.search_io_limit {
761            return future::ok(self); // Returning empty results in `neighbors` out param if IO limit is reached.
762        }
763
764        if let Err(e) = ensure_vertex_loaded(&mut self.0.scratch.vertex_provider, &[id]) {
765            return future::err(e);
766        }
767        let list = match self.0.scratch.vertex_provider.get_adjacency_list(&id) {
768            Ok(list) => list,
769            Err(e) => return future::err(e),
770        };
771        neighbors.overwrite_trusted(list);
772        future::ok(self)
773    }
774}
775
/// [`DiskIndexSearcher`] is a helper class to make it easy to construct index
/// and do repeated search operations. It is a wrapper around the index.
/// This is useful for drivers such as search_disk_index.exe in tools.
pub struct DiskIndexSearcher<
    Data,
    ProviderFactory = DiskVertexProviderFactory<Data, AlignedFileReaderFactory>,
> where
    Data: GraphDataType<VectorIdType = u32>,
    ProviderFactory: VertexProviderFactory<Data>,
{
    /// The index being searched, backed by a [`DiskProvider`].
    index: DiskANNIndex<DiskProvider<Data>>,
    /// Tokio runtime supplied by the caller, or a current-thread runtime created in `new`.
    runtime: Runtime,

    /// The vertex provider factory is used to create the vertex provider for each search instance.
    vertex_provider_factory: ProviderFactory,

    /// Scratch pool for disk search operations that need allocations.
    scratch_pool: Arc<ObjectPool<DiskSearchScratch<Data, ProviderFactory::VertexProviderType>>>,
}
795
/// Statistics reported alongside the results of a search.
// NOTE(review): the exact meaning of each counter is defined where the values are filled
// in, which is outside this view; the field docs below follow the field names only.
#[derive(Debug)]
pub struct SearchResultStats {
    /// Presumably the number of distance comparisons performed — TODO confirm.
    pub cmps: u32,
    /// Presumably the number of results produced by the search — TODO confirm.
    pub result_count: u32,
    /// Per-query statistics collected during the search.
    pub query_statistics: QueryStatistics,
}
802
/// `SearchResult` is a struct representing the result of a search operation.
///
/// It contains a list of vector results and a statistics object.
pub struct SearchResult<AssociatedData> {
    /// A list of nearest neighbors resulting from the search.
    pub results: Vec<SearchResultItem<AssociatedData>>,
    /// Statistics gathered while running the search.
    pub stats: SearchResultStats,
}
812
/// `SearchResultItem` is a struct representing a nearest neighbor resulting from a search.
///
/// It contains the vertex id, associated data, and the distance to the query vector.
pub struct SearchResultItem<AssociatedData> {
    /// The vertex id of the nearest neighbor.
    pub vertex_id: u32,
    /// The associated data of the nearest neighbor as a fixed size byte array.
    /// The length is determined when the index is created.
    pub data: AssociatedData,
    /// The distance between the nearest neighbor and the query vector.
    pub distance: f32,
}
826
827impl<Data, ProviderFactory> DiskIndexSearcher<Data, ProviderFactory>
828where
829    Data: GraphDataType<VectorIdType = u32>,
830    ProviderFactory: VertexProviderFactory<Data>,
831{
832    /// Create a new asynchronous disk searcher instance.
833    ///
834    /// # Arguments
835    /// * `num_threads` - The maximum number of threads to use.
836    /// * `search_io_limit` - I/O operation limit.
837    /// * `disk_index_reader` - The disk index reader.
838    /// * `vertex_provider_factory` - The vertex provider factory.
839    /// * `metric` - Distance metric used for vector similarity calculations.
840    /// * `runtime` - Tokio runtime handle for executing async operations.
841    pub fn new(
842        num_threads: usize,
843        search_io_limit: usize,
844        disk_index_reader: &DiskIndexReader<Data::VectorDataType>,
845        vertex_provider_factory: ProviderFactory,
846        metric: Metric,
847        runtime: Option<Runtime>,
848    ) -> ANNResult<Self> {
849        let runtime = match runtime {
850            Some(rt) => rt,
851            None => tokio::runtime::Builder::new_current_thread().build()?,
852        };
853
854        let graph_header = vertex_provider_factory.get_header()?;
855        let metadata = graph_header.metadata();
856        let max_degree = graph_header.max_degree::<Data::VectorDataType>()? as u32;
857
858        let config = graph::config::Builder::new(
859            max_degree.into_usize(),
860            graph::config::MaxDegree::default_slack(),
861            1, // build-search-list-size
862            metric.into(),
863        )
864        .build()?;
865
866        debug!("Creating DiskIndexSearcher with index_config: {:?}", config);
867
868        let graph_header = vertex_provider_factory.get_header()?;
869        let pq_data = disk_index_reader.get_pq_data();
870        let scratch_pool_args = DiskSearchScratchArgs {
871            graph_degree: graph_header.max_degree::<Data::VectorDataType>()?,
872            dim: graph_header.metadata().dims,
873            num_pq_chunks: pq_data.get_num_chunks(),
874            num_pq_centers: pq_data.get_num_centers(),
875            vertex_factory: &vertex_provider_factory,
876            graph_header: &graph_header,
877        };
878        let scratch_pool = Arc::new(ObjectPool::try_new(&scratch_pool_args, 0, None)?);
879
880        let disk_provider = DiskProvider::new(
881            disk_index_reader,
882            graph_header,
883            metric,
884            metadata.num_pts.into_usize(),
885            search_io_limit,
886        )?;
887
888        let index = DiskANNIndex::new(config, disk_provider, NonZeroUsize::new(num_threads));
889        Ok(Self {
890            index,
891            runtime,
892            vertex_provider_factory,
893            scratch_pool,
894        })
895    }
896
    /// Helper method to create a DiskSearchStrategy with common parameters.
    ///
    /// Each strategy gets a fresh, zeroed `IOTracker` and borrows this searcher's vertex
    /// provider factory and scratch pool.
    fn search_strategy<'a>(
        &'a self,
        query: &'a [Data::VectorDataType],
        vector_filter: &'a (dyn Fn(&Data::VectorIdType) -> bool + Send + Sync),
    ) -> DiskSearchStrategy<'a, Data, ProviderFactory> {
        DiskSearchStrategy {
            io_tracker: IOTracker::default(),
            vector_filter,
            query,
            vertex_provider_factory: &self.vertex_provider_factory,
            scratch_pool: &self.scratch_pool,
        }
    }
911
912    /// Perform a search on the disk index.
913    /// return the list of nearest neighbors and associated data.
914    pub fn search(
915        &self,
916        query: &[Data::VectorDataType],
917        return_list_size: u32,
918        search_list_size: u32,
919        beam_width: Option<usize>,
920        vector_filter: Option<VectorFilter<Data>>,
921        is_flat_search: bool,
922    ) -> ANNResult<SearchResult<Data::AssociatedDataType>> {
923        let mut query_stats = QueryStatistics::default();
924        let mut indices = vec![0u32; return_list_size as usize];
925        let mut distances = vec![0f32; return_list_size as usize];
926        let mut associated_data =
927            vec![Data::AssociatedDataType::default(); return_list_size as usize];
928
929        let stats = self.search_internal(
930            query,
931            return_list_size as usize,
932            search_list_size,
933            beam_width,
934            &mut query_stats,
935            &mut indices,
936            &mut distances,
937            &mut associated_data,
938            &vector_filter.unwrap_or(default_vector_filter::<Data>()),
939            is_flat_search,
940        )?;
941
942        let mut search_result = SearchResult {
943            results: Vec::with_capacity(return_list_size as usize),
944            stats,
945        };
946
947        for ((vertex_id, distance), associated_data) in indices
948            .into_iter()
949            .zip(distances.into_iter())
950            .zip(associated_data.into_iter())
951        {
952            search_result.results.push(SearchResultItem {
953                vertex_id,
954                distance,
955                data: associated_data,
956            });
957        }
958
959        Ok(search_result)
960    }
961
962    /// Perform a raw search on the disk index.
963    /// This is a lower-level API that allows more control over the search parameters and output buffers.
964    #[allow(clippy::too_many_arguments)]
965    pub(crate) fn search_internal(
966        &self,
967        query: &[Data::VectorDataType],
968        k_value: usize,
969        search_list_size: u32,
970        beam_width: Option<usize>,
971        query_stats: &mut QueryStatistics,
972        indices: &mut [u32],
973        distances: &mut [f32],
974        associated_data: &mut [Data::AssociatedDataType],
975        vector_filter: &(dyn Fn(&Data::VectorIdType) -> bool + Send + Sync),
976        is_flat_search: bool,
977    ) -> ANNResult<SearchResultStats> {
978        let mut result_output_buffer = search_output_buffer::IdDistanceAssociatedData::new(
979            &mut indices[..k_value],
980            &mut distances[..k_value],
981            &mut associated_data[..k_value],
982        );
983
984        let strategy = self.search_strategy(query, vector_filter);
985        let timer = Instant::now();
986        let stats = if is_flat_search {
987            self.runtime.block_on(self.index.flat_search(
988                &strategy,
989                &DefaultContext,
990                strategy.query,
991                vector_filter,
992                &SearchParams::new(k_value, search_list_size as usize, beam_width)?,
993                &mut result_output_buffer,
994            ))?
995        } else {
996            self.runtime.block_on(self.index.search(
997                &strategy,
998                &DefaultContext,
999                strategy.query,
1000                &SearchParams::new(k_value, search_list_size as usize, beam_width)?,
1001                &mut result_output_buffer,
1002            ))?
1003        };
1004        query_stats.total_comparisons = stats.cmps;
1005        query_stats.search_hops = stats.hops;
1006
1007        query_stats.total_execution_time_us = timer.elapsed().as_micros();
1008        query_stats.io_time_us = IOTracker::time(&strategy.io_tracker.io_time_us) as u128;
1009        query_stats.total_io_operations = strategy.io_tracker.io_count() as u32;
1010        query_stats.total_vertices_loaded = strategy.io_tracker.io_count() as u32;
1011        query_stats.query_pq_preprocess_time_us =
1012            IOTracker::time(&strategy.io_tracker.preprocess_time_us) as u128;
1013        query_stats.cpu_time_us = query_stats.total_execution_time_us
1014            - query_stats.io_time_us
1015            - query_stats.query_pq_preprocess_time_us;
1016        Ok(SearchResultStats {
1017            cmps: query_stats.total_comparisons,
1018            result_count: stats.result_count,
1019            query_statistics: query_stats.clone(),
1020        })
1021    }
1022}
1023
1024/// Helper function to ensure vertices are loaded and processed.
1025///
1026/// This is a convenience function that combines `load_vertices` and `process_loaded_node`
1027/// for each vertex ID. It first loads all the vertices in batch, then processes each
1028/// loaded node.
1029fn ensure_vertex_loaded<Data: GraphDataType, V: VertexProvider<Data>>(
1030    vertex_provider: &mut V,
1031    ids: &[Data::VectorIdType],
1032) -> ANNResult<()> {
1033    vertex_provider.load_vertices(ids)?;
1034    for (idx, id) in ids.iter().enumerate() {
1035        vertex_provider.process_loaded_node(id, idx)?;
1036    }
1037    Ok(())
1038}
1039
1040#[cfg(test)]
1041mod disk_provider_tests {
1042    use diskann::{
1043        graph::{search::record::VisitedSearchRecord, SearchParamsError},
1044        utils::IntoUsize,
1045        ANNErrorKind,
1046    };
1047    use diskann_providers::storage::{
1048        DynWriteProvider, StorageReadProvider, VirtualStorageProvider,
1049    };
1050    use diskann_providers::{
1051        common::AlignedBoxWithSlice,
1052        test_utils::graph_data_type_utils::{
1053            GraphDataF32VectorU32Data, GraphDataF32VectorUnitData,
1054        },
1055        utils::{
1056            create_thread_pool, file_util, load_aligned_bin, PQPathNames, ParallelIteratorInPool,
1057        },
1058    };
1059    use diskann_utils::test_data_root;
1060    use diskann_vector::distance::Metric;
1061    use rayon::prelude::{IndexedParallelIterator, IntoParallelRefIterator};
1062    use rstest::rstest;
1063    use vfs::OverlayFS;
1064
1065    use super::*;
1066    use crate::{
1067        build::builder::core::disk_index_builder_tests::{IndexBuildFixture, TestParams},
1068        utils::{QueryStatistics, VirtualAlignedReaderFactory},
1069    };
1070
    // Paths below are passed to a `VirtualStorageProvider` overlay created from
    // `test_data_root()` by the tests in this module.

    // --- 128-dimensional index fixture (prefix plus the three files derived from it) ---
    const TEST_INDEX_PREFIX_128DIM: &str =
        "/disk_index_search/disk_index_sift_learn_R4_L50_A1.2_truth_search";
    const TEST_INDEX_128DIM: &str =
        "/disk_index_search/disk_index_sift_learn_R4_L50_A1.2_truth_search_disk.index";
    const TEST_PQ_PIVOT_128DIM: &str =
        "/disk_index_search/disk_index_sift_learn_R4_L50_A1.2_truth_search_pq_pivots.bin";
    const TEST_PQ_COMPRESSED_128DIM: &str =
        "/disk_index_search/disk_index_sift_learn_R4_L50_A1.2_truth_search_pq_compressed.bin";
    // Ground-truth result ids and the query vectors used with the 128-dim fixture.
    const TEST_TRUTH_RESULT_10PTS_128DIM: &str =
        "/disk_index_search/disk_index_10pts_idx_uint32_truth_search_res.bin";
    const TEST_QUERY_10PTS_128DIM: &str = "/disk_index_search/disk_index_sample_query_10pts.fbin";

    // --- 100-dimensional index fixture ---
    const TEST_INDEX_PREFIX_100DIM: &str = "/disk_index_search/256pts_100dim_f32_truth_Index";
    const TEST_INDEX_100DIM: &str = "/disk_index_search/256pts_100dim_f32_truth_Index_disk.index";
    const TEST_PQ_PIVOT_100DIM: &str =
        "/disk_index_search/256pts_100dim_f32_truth_Index_pq_pivots.bin";
    const TEST_PQ_COMPRESSED_100DIM: &str =
        "/disk_index_search/256pts_100dim_f32_truth_Index_pq_compressed.bin";
    // Ground-truth result ids and the query vectors used with the 100-dim fixture.
    const TEST_TRUTH_RESULT_10PTS_100DIM: &str =
        "/disk_index_search/256pts_100dim_f32_truth_query_result.bin";
    const TEST_QUERY_10PTS_100DIM: &str = "/disk_index_search/10pts_100dim_f32_base_query.bin";
    // Raw vector data used as build input by the test that builds its own index.
    const TEST_DATA_FILE: &str = "/disk_index_search/disk_index_siftsmall_learn_256pts_data.fbin";
    // NOTE(review): the four constants below hold exactly the same strings as
    // TEST_INDEX_128DIM, TEST_INDEX_PREFIX_128DIM, TEST_PQ_PIVOT_128DIM and
    // TEST_PQ_COMPRESSED_128DIM above; consider consolidating them.
    const TEST_INDEX: &str =
        "/disk_index_search/disk_index_sift_learn_R4_L50_A1.2_truth_search_disk.index";
    const TEST_INDEX_PREFIX: &str =
        "/disk_index_search/disk_index_sift_learn_R4_L50_A1.2_truth_search";
    const TEST_PQ_PIVOT: &str =
        "/disk_index_search/disk_index_sift_learn_R4_L50_A1.2_truth_search_pq_pivots.bin";
    const TEST_PQ_COMPRESSED: &str =
        "/disk_index_search/disk_index_sift_learn_R4_L50_A1.2_truth_search_pq_compressed.bin";
1101
1102    #[test]
1103    fn test_disk_search_k10_l20_single_or_multi_thread_100dim() {
1104        let storage_provider = Arc::new(VirtualStorageProvider::new_overlay(test_data_root()));
1105
1106        let search_engine = create_disk_index_searcher(
1107            CreateDiskIndexSearcherParams {
1108                max_thread_num: 5,
1109                pq_pivot_file_path: TEST_PQ_PIVOT_100DIM,
1110                pq_compressed_file_path: TEST_PQ_COMPRESSED_100DIM,
1111                index_path: TEST_INDEX_100DIM,
1112                index_path_prefix: TEST_INDEX_PREFIX_100DIM,
1113                ..Default::default()
1114            },
1115            &storage_provider,
1116        );
1117        // Test single thread.
1118        test_disk_search(TestDiskSearchParams {
1119            storage_provider: storage_provider.as_ref(),
1120            index_search_engine: &search_engine,
1121            thread_num: 1,
1122            query_file_path: TEST_QUERY_10PTS_100DIM,
1123            truth_result_file_path: TEST_TRUTH_RESULT_10PTS_100DIM,
1124            k: 10,
1125            l: 20,
1126            dim: 104,
1127        });
1128        // Test multi thread.
1129        test_disk_search(TestDiskSearchParams {
1130            storage_provider: storage_provider.as_ref(),
1131            index_search_engine: &search_engine,
1132            thread_num: 5,
1133            query_file_path: TEST_QUERY_10PTS_100DIM,
1134            truth_result_file_path: TEST_TRUTH_RESULT_10PTS_100DIM,
1135            k: 10,
1136            l: 20,
1137            dim: 104,
1138        });
1139    }
1140
1141    #[test]
1142    fn test_disk_search_k10_l20_single_or_multi_thread_128dim() {
1143        let storage_provider = Arc::new(VirtualStorageProvider::new_overlay(test_data_root()));
1144
1145        let search_engine = create_disk_index_searcher::<GraphDataF32VectorUnitData>(
1146            CreateDiskIndexSearcherParams {
1147                max_thread_num: 5,
1148                pq_pivot_file_path: TEST_PQ_PIVOT_128DIM,
1149                pq_compressed_file_path: TEST_PQ_COMPRESSED_128DIM,
1150                index_path: TEST_INDEX_128DIM,
1151                index_path_prefix: TEST_INDEX_PREFIX_128DIM,
1152                ..Default::default()
1153            },
1154            &storage_provider,
1155        );
1156        // Test single thread.
1157        test_disk_search(TestDiskSearchParams {
1158            storage_provider: storage_provider.as_ref(),
1159            index_search_engine: &search_engine,
1160            thread_num: 1,
1161            query_file_path: TEST_QUERY_10PTS_128DIM,
1162            truth_result_file_path: TEST_TRUTH_RESULT_10PTS_128DIM,
1163            k: 10,
1164            l: 20,
1165            dim: 128,
1166        });
1167        // Test multi thread.
1168        test_disk_search(TestDiskSearchParams {
1169            storage_provider: storage_provider.as_ref(),
1170            index_search_engine: &search_engine,
1171            thread_num: 5,
1172            query_file_path: TEST_QUERY_10PTS_128DIM,
1173            truth_result_file_path: TEST_TRUTH_RESULT_10PTS_128DIM,
1174            k: 10,
1175            l: 20,
1176            dim: 128,
1177        });
1178    }
1179
1180    fn get_truth_associated_data<StorageReader: StorageReadProvider>(
1181        storage_provider: &StorageReader,
1182    ) -> Vec<u32> {
1183        const ASSOCIATED_DATA_FILE: &str = "/sift/siftsmall_learn_256pts_u32_associated_data.fbin";
1184
1185        let (data, _npts, _dim) =
1186            file_util::load_bin::<u32, StorageReader>(storage_provider, ASSOCIATED_DATA_FILE, 0)
1187                .unwrap();
1188        data
1189    }
1190
1191    #[test]
1192    fn test_disk_search_with_associated_data_k10_l20_single_or_multi_thread_128dim() {
1193        let storage_provider = VirtualStorageProvider::new_overlay(test_data_root());
1194        let index_path_prefix = "/disk_index_search/disk_index_sift_learn_R4_L50_A1.2_test_disk_index_search_associated_data";
1195        let params = TestParams {
1196            data_path: TEST_DATA_FILE.to_string(),
1197            index_path_prefix: index_path_prefix.to_string(),
1198            associated_data_path: Some(
1199                "/sift/siftsmall_learn_256pts_u32_associated_data.fbin".to_string(),
1200            ),
1201            ..TestParams::default()
1202        };
1203        let fixture = IndexBuildFixture::new(storage_provider, params).unwrap();
1204        // Build the index with the associated data
1205        fixture.build::<GraphDataF32VectorU32Data>().unwrap();
1206        {
1207            let search_engine = create_disk_index_searcher::<GraphDataF32VectorU32Data>(
1208                CreateDiskIndexSearcherParams {
1209                    max_thread_num: 5,
1210                    pq_pivot_file_path: format!("{}_pq_pivots.bin", index_path_prefix).as_str(),
1211                    pq_compressed_file_path: format!("{}_pq_compressed.bin", index_path_prefix)
1212                        .as_str(),
1213                    index_path: format!("{}_disk.index", index_path_prefix).as_str(), //TEST_INDEX_128DIM,
1214                    index_path_prefix,
1215                    ..Default::default()
1216                },
1217                &fixture.storage_provider,
1218            );
1219
1220            // Test single thread.
1221            test_disk_search_with_associated(
1222                TestDiskSearchAssociateParams {
1223                    storage_provider: fixture.storage_provider.as_ref(),
1224                    index_search_engine: &search_engine,
1225                    thread_num: 1,
1226                    query_file_path: TEST_QUERY_10PTS_128DIM,
1227                    truth_result_file_path: TEST_TRUTH_RESULT_10PTS_128DIM,
1228                    k: 10,
1229                    l: 20,
1230                    dim: 128,
1231                },
1232                None,
1233            );
1234
1235            // Test multi thread.
1236            test_disk_search_with_associated(
1237                TestDiskSearchAssociateParams {
1238                    storage_provider: fixture.storage_provider.as_ref(),
1239                    index_search_engine: &search_engine,
1240                    thread_num: 5,
1241                    query_file_path: TEST_QUERY_10PTS_128DIM,
1242                    truth_result_file_path: TEST_TRUTH_RESULT_10PTS_128DIM,
1243                    k: 10,
1244                    l: 20,
1245                    dim: 128,
1246                },
1247                None,
1248            );
1249        }
1250
1251        fixture
1252            .storage_provider
1253            .delete(&format!("{}_disk.index", index_path_prefix))
1254            .expect("Failed to delete file");
1255        fixture
1256            .storage_provider
1257            .delete(&format!("{}_pq_pivots.bin", index_path_prefix))
1258            .expect("Failed to delete file");
1259        fixture
1260            .storage_provider
1261            .delete(&format!("{}_pq_compressed.bin", index_path_prefix))
1262            .expect("Failed to delete file");
1263    }
1264
    /// Inputs for `create_disk_index_searcher`.
    struct CreateDiskIndexSearcherParams<'a> {
        /// Worker-thread count for the tokio runtime; also passed as the first argument of
        /// `DiskIndexSearcher::new`.
        max_thread_num: usize,
        /// Path of the PQ pivot file handed to `DiskIndexReader::new`.
        pq_pivot_file_path: &'a str,
        /// Path of the PQ compressed-data file handed to `DiskIndexReader::new`.
        pq_compressed_file_path: &'a str,
        /// Full path of the disk index file. `create_disk_index_searcher` does not read it
        /// (it derives the file from `index_path_prefix`); `test_disk_search_invalid_input`
        /// only checks it for consistency with the prefix.
        index_path: &'a str,
        /// Index path prefix; `get_disk_index_file` turns it into the index file to open.
        index_path_prefix: &'a str,
        /// IO limit passed as the second argument of `DiskIndexSearcher::new`; must be > 0.
        io_limit: usize,
    }
1273
1274    impl Default for CreateDiskIndexSearcherParams<'_> {
1275        fn default() -> Self {
1276            Self {
1277                max_thread_num: 1,
1278                pq_pivot_file_path: "",
1279                pq_compressed_file_path: "",
1280                index_path: "",
1281                index_path_prefix: "",
1282                io_limit: usize::MAX,
1283            }
1284        }
1285    }
1286
1287    fn create_disk_index_searcher<Data>(
1288        params: CreateDiskIndexSearcherParams,
1289        storage_provider: &Arc<VirtualStorageProvider<OverlayFS>>,
1290    ) -> DiskIndexSearcher<
1291        Data,
1292        DiskVertexProviderFactory<Data, VirtualAlignedReaderFactory<OverlayFS>>,
1293    >
1294    where
1295        Data: GraphDataType<VectorIdType = u32>,
1296    {
1297        assert!(params.io_limit > 0);
1298
1299        let runtime = tokio::runtime::Builder::new_multi_thread()
1300            .worker_threads(params.max_thread_num)
1301            .build()
1302            .unwrap();
1303
1304        let disk_index_reader = DiskIndexReader::<Data::VectorDataType>::new(
1305            params.pq_pivot_file_path.to_string(),
1306            params.pq_compressed_file_path.to_string(),
1307            storage_provider.as_ref(),
1308        )
1309        .unwrap();
1310
1311        let aligned_reader_factory = VirtualAlignedReaderFactory::new(
1312            get_disk_index_file(params.index_path_prefix),
1313            Arc::clone(storage_provider),
1314        );
1315        let caching_strategy = CachingStrategy::None;
1316        let vertex_provider_factory =
1317            DiskVertexProviderFactory::<Data, _>::new(aligned_reader_factory, caching_strategy)
1318                .unwrap();
1319
1320        DiskIndexSearcher::<Data, DiskVertexProviderFactory<Data, _>>::new(
1321            params.max_thread_num,
1322            params.io_limit,
1323            &disk_index_reader,
1324            vertex_provider_factory,
1325            Metric::L2,
1326            Some(runtime),
1327        )
1328        .unwrap()
1329    }
1330
1331    fn load_query_result<StorageReader: StorageReadProvider>(
1332        storage_provider: &StorageReader,
1333        query_result_path: &str,
1334    ) -> Vec<u32> {
1335        let (result, _, _) =
1336            file_util::load_bin::<u32, StorageReader>(storage_provider, query_result_path, 0)
1337                .unwrap();
1338        result
1339    }
1340
    /// Inputs for `test_disk_search` (index without associated data).
    struct TestDiskSearchParams<'a, StorageType> {
        /// Storage from which the query and truth files are loaded.
        storage_provider: &'a StorageType,
        /// Searcher under test.
        index_search_engine: &'a DiskIndexSearcher<
            GraphDataF32VectorUnitData,
            DiskVertexProviderFactory<
                GraphDataF32VectorUnitData,
                VirtualAlignedReaderFactory<OverlayFS>,
            >,
        >,
        /// Size of the thread pool the queries are run on.
        thread_num: u64,
        /// Path of the query vectors file.
        query_file_path: &'a str,
        /// Path of the expected result ids file.
        truth_result_file_path: &'a str,
        /// Number of results requested per query; also the stride into the truth results.
        k: usize,
        /// Search list size passed to the search.
        l: usize,
        /// Length of each query slice when the loaded query buffer is split.
        dim: usize,
    }
1357
    /// Inputs for `test_disk_search_with_associated` (index carrying `u32` associated data).
    struct TestDiskSearchAssociateParams<'a, StorageType> {
        /// Storage from which the query, truth and associated-data files are loaded.
        storage_provider: &'a StorageType,
        /// Searcher under test.
        index_search_engine: &'a DiskIndexSearcher<
            GraphDataF32VectorU32Data,
            DiskVertexProviderFactory<
                GraphDataF32VectorU32Data,
                VirtualAlignedReaderFactory<OverlayFS>,
            >,
        >,
        /// Size of the thread pool the queries are run on.
        thread_num: u64,
        /// Path of the query vectors file.
        query_file_path: &'a str,
        /// Path of the expected result ids file.
        truth_result_file_path: &'a str,
        /// Number of results requested per query; also the stride into the truth results.
        k: usize,
        /// Search list size passed to the search.
        l: usize,
        /// Length of each query slice when the loaded query buffer is split.
        dim: usize,
    }
1374
1375    fn test_disk_search<StorageType: StorageReadProvider>(
1376        params: TestDiskSearchParams<StorageType>,
1377    ) {
1378        let query_vector = load_aligned_bin(params.storage_provider, params.query_file_path)
1379            .unwrap()
1380            .0;
1381        let mut aligned_query = AlignedBoxWithSlice::<f32>::new(query_vector.len(), 32).unwrap();
1382        aligned_query.memcpy(query_vector.as_slice()).unwrap();
1383
1384        let queries = aligned_query
1385            .split_into_nonoverlapping_mut_slices(0..aligned_query.len(), params.dim)
1386            .unwrap();
1387
1388        let truth_result =
1389            load_query_result(params.storage_provider, params.truth_result_file_path);
1390
1391        let pool = create_thread_pool(params.thread_num.into_usize()).unwrap();
1392        // Convert query_vector to number of Vertex with data type f32 and dimension equals to dim.
1393        queries
1394            .par_iter()
1395            .enumerate()
1396            .for_each_in_pool(&pool, |(i, query)| {
1397                // Test search_with_associated_data with an unaligned query. Some distance functions require aligned data.
1398                let mut aligned_box = AlignedBoxWithSlice::<f32>::new(query.len() + 1, 32).unwrap();
1399                let mut temp = Vec::with_capacity(query.len() + 1);
1400                temp.push(0.0);
1401                temp.extend_from_slice(query);
1402                aligned_box.memcpy(temp.as_slice()).unwrap();
1403                let query = &aligned_box.as_slice()[1..];
1404
1405                let mut query_stats = QueryStatistics::default();
1406                let mut indices = vec![0u32; 10];
1407                let mut distances = vec![0f32; 10];
1408                let mut associated_data = vec![(); 10];
1409
1410                let result = params
1411                    .index_search_engine
1412                    //.search_with_associated_data(query, params.k as u32, params.l as u32)
1413                    .search_internal(
1414                        query,
1415                        params.k,
1416                        params.l as u32,
1417                        None, // beam_width
1418                        &mut query_stats,
1419                        &mut indices,
1420                        &mut distances,
1421                        &mut associated_data,
1422                        &(|_| true),
1423                        false,
1424                    );
1425
1426                // Calculate the range of the truth_result for this query
1427                let truth_slice = &truth_result[i * params.k..(i + 1) * params.k];
1428
1429                assert!(result.is_ok(), "Expected search to succeed");
1430
1431                let result_unwrapped = result.unwrap();
1432                assert!(
1433                    result_unwrapped.query_statistics.total_io_operations > 0,
1434                    "Expected IO operations to be greater than 0"
1435                );
1436                assert!(
1437                    result_unwrapped.query_statistics.total_vertices_loaded > 0,
1438                    "Expected vertices loaded to be greater than 0"
1439                );
1440
1441                // Compare res with truth_slice using assert_eq!
1442                assert_eq!(
1443                    indices, truth_slice,
1444                    "Results DO NOT match with the truth result for query {}",
1445                    i
1446                );
1447            });
1448    }
1449
1450    fn test_disk_search_with_associated<StorageType: StorageReadProvider>(
1451        params: TestDiskSearchAssociateParams<StorageType>,
1452        beam_width: Option<usize>,
1453    ) {
1454        let query_vector = load_aligned_bin(params.storage_provider, params.query_file_path)
1455            .unwrap()
1456            .0;
1457        let mut aligned_query = AlignedBoxWithSlice::<f32>::new(query_vector.len(), 32).unwrap();
1458        aligned_query.memcpy(query_vector.as_slice()).unwrap();
1459        let queries = aligned_query
1460            .split_into_nonoverlapping_mut_slices(0..aligned_query.len(), params.dim)
1461            .unwrap();
1462        let truth_result =
1463            load_query_result(params.storage_provider, params.truth_result_file_path);
1464        let pool = create_thread_pool(params.thread_num.into_usize()).unwrap();
1465        // Convert query_vector to number of Vertex with data type f32 and dimension equals to dim.
1466        queries
1467            .par_iter()
1468            .enumerate()
1469            .for_each_in_pool(&pool, |(i, query)| {
1470                // Test search_with_associated_data with an unaligned query. Some distance functions require aligned data.
1471                let mut aligned_box = AlignedBoxWithSlice::<f32>::new(query.len() + 1, 32).unwrap();
1472                let mut temp = Vec::with_capacity(query.len() + 1);
1473                temp.push(0.0);
1474                temp.extend_from_slice(query);
1475                aligned_box.memcpy(temp.as_slice()).unwrap();
1476                let query = &aligned_box.as_slice()[1..];
1477                let result = params
1478                    .index_search_engine
1479                    .search(query, params.k as u32, params.l as u32, beam_width, None, false)
1480                    .unwrap();
1481                let indices: Vec<u32> = result.results.iter().map(|item| item.vertex_id).collect();
1482                let associated_data: Vec<u32> =
1483                    result.results.iter().map(|item| item.data).collect();
1484                let truth_data = get_truth_associated_data(params.storage_provider);
1485                let associated_data_truth: Vec<u32> = indices
1486                    .iter()
1487                    .map(|&vid| truth_data[vid as usize])
1488                    .collect();
1489                assert_eq!(
1490                    associated_data, associated_data_truth,
1491                    "Associated data DO NOT match with the truth result for query {}, associated_data from search: {:?}, associated_data from truth result: {:?}",
1492                    i,associated_data, associated_data_truth
1493                );
1494                let truth_slice = &truth_result[i * params.k..(i + 1) * params.k];
1495                assert_eq!(
1496                    indices, truth_slice,
1497                    "Results DO NOT match with the truth result for query {}",
1498                    i
1499                );
1500            });
1501    }
1502
1503    #[test]
1504    fn test_disk_search_invalid_input() {
1505        let storage_provider = Arc::new(VirtualStorageProvider::new_overlay(test_data_root()));
1506        let ctx = &DefaultContext;
1507
1508        let params = CreateDiskIndexSearcherParams {
1509            max_thread_num: 5,
1510            pq_pivot_file_path: TEST_PQ_PIVOT_128DIM,
1511            pq_compressed_file_path: TEST_PQ_COMPRESSED_128DIM,
1512            index_path: TEST_INDEX_128DIM,
1513            index_path_prefix: TEST_INDEX_PREFIX_128DIM,
1514            ..Default::default()
1515        };
1516
1517        let paths = PQPathNames::for_disk_index(TEST_INDEX_PREFIX_128DIM);
1518        assert_eq!(
1519            paths.pivots, params.pq_pivot_file_path,
1520            "pq_pivot_file_path is not correct"
1521        );
1522        assert_eq!(
1523            paths.compressed_data, params.pq_compressed_file_path,
1524            "pq_compressed_file_path is not correct"
1525        );
1526        assert_eq!(
1527            params.index_path,
1528            format!("{}_disk.index", params.index_path_prefix),
1529            "index_path is not correct"
1530        );
1531
1532        let res = SearchParams::new_default(0, 10);
1533        assert!(res.is_err());
1534        assert_eq!(
1535            <SearchParamsError as std::convert::Into<ANNError>>::into(res.unwrap_err()).kind(),
1536            ANNErrorKind::IndexError
1537        );
1538        let res = SearchParams::new_default(20, 10);
1539        assert!(res.is_err());
1540        let res = SearchParams::new_default(10, 0);
1541        assert!(res.is_err());
1542        let res = SearchParams::new(10, 10, Some(0));
1543        assert!(res.is_err());
1544
1545        let search_engine =
1546            create_disk_index_searcher::<GraphDataF32VectorU32Data>(params, &storage_provider);
1547
1548        // minor validation tests to improve code coverage
1549        assert_eq!(
1550            search_engine
1551                .index
1552                .data_provider
1553                .to_external_id(ctx, 0)
1554                .unwrap(),
1555            0
1556        );
1557        assert_eq!(
1558            search_engine
1559                .index
1560                .data_provider
1561                .to_internal_id(ctx, &0)
1562                .unwrap(),
1563            0
1564        );
1565
1566        let provider_max_degree = search_engine
1567            .index
1568            .data_provider
1569            .graph_header
1570            .max_degree::<<GraphDataF32VectorU32Data as GraphDataType>::VectorDataType>()
1571            .unwrap();
1572        let index_max_degree = search_engine.index.config.pruned_degree().get();
1573        assert_eq!(provider_max_degree, index_max_degree);
1574
1575        let query = vec![0f32; 128];
1576        let mut query_stats = QueryStatistics::default();
1577        let mut indices = vec![0u32; 10];
1578        let mut distances = vec![0f32; 10];
1579        let mut associated_data = vec![0u32; 10];
1580
1581        // Set L: {} to a value of at least K:
1582        let result = search_engine.search_internal(
1583            &query,
1584            10,
1585            10 - 1,
1586            None,
1587            &mut query_stats,
1588            &mut indices,
1589            &mut distances,
1590            &mut associated_data,
1591            &|_| true,
1592            false,
1593        );
1594
1595        assert!(result.is_err());
1596        assert_eq!(result.unwrap_err().kind(), ANNErrorKind::IndexError);
1597    }
1598
1599    #[test]
1600    fn test_disk_search_beam_search() {
1601        let storage_provider = Arc::new(VirtualStorageProvider::new_overlay(test_data_root()));
1602
1603        let search_engine = create_disk_index_searcher::<GraphDataF32VectorUnitData>(
1604            CreateDiskIndexSearcherParams {
1605                max_thread_num: 1,
1606                pq_pivot_file_path: TEST_PQ_PIVOT,
1607                pq_compressed_file_path: TEST_PQ_COMPRESSED,
1608                index_path: TEST_INDEX,
1609                index_path_prefix: TEST_INDEX_PREFIX,
1610                ..Default::default()
1611            },
1612            &storage_provider,
1613        );
1614
1615        let query_vector: [f32; 128] = [1f32; 128];
1616        let mut indices = vec![0u32; 10];
1617        let mut distances = vec![0f32; 10];
1618        let mut associated_data = vec![(); 10];
1619
1620        let mut result_output_buffer = search_output_buffer::IdDistanceAssociatedData::new(
1621            &mut indices,
1622            &mut distances,
1623            &mut associated_data,
1624        );
1625        let strategy = search_engine.search_strategy(&query_vector, &|_| true);
1626        let mut search_record = VisitedSearchRecord::new(0);
1627        search_engine
1628            .runtime
1629            .block_on(search_engine.index.search_recorded(
1630                &strategy,
1631                &DefaultContext,
1632                &query_vector,
1633                &SearchParams::new(10, 10, Some(4)).unwrap(),
1634                &mut result_output_buffer,
1635                &mut search_record,
1636            ))
1637            .unwrap();
1638
1639        let ids = search_record
1640            .visited
1641            .iter()
1642            .map(|n| n.id)
1643            .collect::<Vec<_>>();
1644
1645        const EXPECTED_NODES: [u32; 18] = [
1646            72, 118, 108, 86, 84, 152, 170, 82, 114, 87, 207, 176, 79, 153, 67, 165, 141, 180,
1647        ]; //Expected nodes for query = [1f32; 128] with beam_width=4
1648
1649        assert_eq!(ids, &EXPECTED_NODES);
1650
1651        let return_list_size = 10;
1652        let search_list_size = 10;
1653        let result = search_engine.search(
1654            &query_vector,
1655            return_list_size,
1656            search_list_size,
1657            Some(4),
1658            None,
1659            false,
1660        );
1661        assert!(result.is_ok(), "Expected search to succeed");
1662        let search_result = result.unwrap();
1663        assert_eq!(
1664            search_result.results.len() as u32,
1665            return_list_size,
1666            "Expected result count to match"
1667        );
1668        assert_eq!(
1669            indices,
1670            vec![152, 72, 170, 118, 87, 165, 79, 141, 108, 86],
1671            "Expected indices to match"
1672        );
1673    }
1674
1675    #[cfg(feature = "experimental_diversity_search")]
1676    #[test]
1677    fn test_disk_search_diversity_search() {
1678        use diskann::graph::DiverseSearchParams;
1679        use diskann::neighbor::AttributeValueProvider;
1680        use std::collections::HashMap;
1681
1682        // Simple test attribute provider
1683        #[derive(Debug, Clone)]
1684        struct TestAttributeProvider {
1685            attributes: HashMap<u32, u32>,
1686        }
1687        impl TestAttributeProvider {
1688            fn new() -> Self {
1689                Self {
1690                    attributes: HashMap::new(),
1691                }
1692            }
1693            fn insert(&mut self, id: u32, attribute: u32) {
1694                self.attributes.insert(id, attribute);
1695            }
1696        }
1697        impl diskann::provider::HasId for TestAttributeProvider {
1698            type Id = u32;
1699        }
1700
1701        impl AttributeValueProvider for TestAttributeProvider {
1702            type Value = u32;
1703
1704            fn get(&self, id: Self::Id) -> Option<Self::Value> {
1705                self.attributes.get(&id).copied()
1706            }
1707        }
1708
1709        let storage_provider = Arc::new(VirtualStorageProvider::new_overlay(test_data_root()));
1710
1711        let search_engine = create_disk_index_searcher::<GraphDataF32VectorUnitData>(
1712            CreateDiskIndexSearcherParams {
1713                max_thread_num: 1,
1714                pq_pivot_file_path: TEST_PQ_PIVOT,
1715                pq_compressed_file_path: TEST_PQ_COMPRESSED,
1716                index_path: TEST_INDEX,
1717                index_path_prefix: TEST_INDEX_PREFIX,
1718                ..Default::default()
1719            },
1720            &storage_provider,
1721        );
1722
1723        let query_vector: [f32; 128] = [1f32; 128];
1724
1725        // Create attribute provider with random labels (1 to 3) for all vectors
1726        let mut attribute_provider = TestAttributeProvider::new();
1727        let num_vectors = 256; // Number of vectors in the test dataset
1728        for i in 0..num_vectors {
1729            // Assign labels 1-3 based on modulo to ensure distribution
1730            let label = (i % 15) + 1;
1731            attribute_provider.insert(i, label);
1732        }
1733        // Wrap in Arc once to avoid cloning the HashMap later
1734        let attribute_provider = std::sync::Arc::new(attribute_provider);
1735
1736        let mut indices = vec![0u32; 10];
1737        let mut distances = vec![0f32; 10];
1738        let mut associated_data = vec![(); 10];
1739
1740        let mut result_output_buffer = search_output_buffer::IdDistanceAssociatedData::new(
1741            &mut indices,
1742            &mut distances,
1743            &mut associated_data,
1744        );
1745        let strategy = search_engine.search_strategy(&query_vector, &|_| true);
1746        let mut search_record = VisitedSearchRecord::new(0);
1747
1748        // Create diverse search parameters with attribute provider
1749        let diverse_params = DiverseSearchParams::new(
1750            0, // diverse_attribute_id
1751            3, // diverse_results_k
1752            attribute_provider.clone(),
1753        );
1754
1755        let search_params = SearchParams::new(10, 20, None).unwrap();
1756
1757        search_engine
1758            .runtime
1759            .block_on(search_engine.index.diverse_search_experimental(
1760                &strategy,
1761                &DefaultContext,
1762                &query_vector,
1763                &search_params,
1764                &diverse_params,
1765                &mut result_output_buffer,
1766                &mut search_record,
1767            ))
1768            .unwrap();
1769
1770        let ids = search_record
1771            .visited
1772            .iter()
1773            .map(|n| n.id)
1774            .collect::<Vec<_>>();
1775
1776        // Verify that search was performed and visited some nodes
1777        assert!(
1778            !ids.is_empty(),
1779            "Expected to visit some nodes during diversity search"
1780        );
1781
1782        let return_list_size = 10;
1783        let search_list_size = 20;
1784        let diverse_results_k = 1;
1785        let diverse_params = DiverseSearchParams::new(
1786            0, // diverse_attribute_id
1787            diverse_results_k,
1788            attribute_provider.clone(),
1789        );
1790
1791        // Test diverse search using the experimental API
1792        let mut indices2 = vec![0u32; return_list_size as usize];
1793        let mut distances2 = vec![0f32; return_list_size as usize];
1794        let mut associated_data2 = vec![(); return_list_size as usize];
1795        let mut result_output_buffer2 = search_output_buffer::IdDistanceAssociatedData::new(
1796            &mut indices2,
1797            &mut distances2,
1798            &mut associated_data2,
1799        );
1800        let strategy2 = search_engine.search_strategy(&query_vector, &|_| true);
1801        let mut search_record2 = VisitedSearchRecord::new(0);
1802        let search_params2 =
1803            SearchParams::new(return_list_size as usize, search_list_size as usize, None).unwrap();
1804
1805        let stats = search_engine
1806            .runtime
1807            .block_on(search_engine.index.diverse_search_experimental(
1808                &strategy2,
1809                &DefaultContext,
1810                &query_vector,
1811                &search_params2,
1812                &diverse_params,
1813                &mut result_output_buffer2,
1814                &mut search_record2,
1815            ))
1816            .unwrap();
1817
1818        // Verify results
1819        assert!(
1820            stats.result_count > 0,
1821            "Expected diversity search to return results"
1822        );
1823        assert!(
1824            stats.result_count <= return_list_size,
1825            "Expected result count to be <= {}",
1826            return_list_size
1827        );
1828
1829        // Verify that we got some results
1830        assert!(
1831            stats.result_count > 0,
1832            "Expected to get some search results"
1833        );
1834
1835        // Print search results with their attributes
1836        println!("\n=== Diversity Search Results ===");
1837        println!("Query: [1f32; 128]");
1838        println!("diverse_results_k: {}", diverse_results_k);
1839        println!("Total results: {}\n", stats.result_count);
1840        println!("{:<10} {:<15} {:<10}", "Vertex ID", "Distance", "Label");
1841        println!("{}", "-".repeat(35));
1842        for i in 0..stats.result_count as usize {
1843            let attribute_value = attribute_provider.get(indices2[i]).unwrap_or(0);
1844            println!(
1845                "{:<10} {:<15.2} {:<10}",
1846                indices2[i], distances2[i], attribute_value
1847            );
1848        }
1849
1850        // Verify that distances are non-negative and sorted
1851        for i in 0..(stats.result_count as usize).saturating_sub(1) {
1852            assert!(distances2[i] >= 0.0, "Expected non-negative distance");
1853            assert!(
1854                distances2[i] <= distances2[i + 1],
1855                "Expected distances to be sorted in ascending order"
1856            );
1857        }
1858
1859        // Verify diversity: Check that we have diverse attribute values in the results
1860        let mut attribute_counts = HashMap::new();
1861        for item in indices2.iter().take(stats.result_count as usize) {
1862            if let Some(attribute_value) = attribute_provider.get(*item) {
1863                *attribute_counts.entry(attribute_value).or_insert(0) += 1;
1864            }
1865        }
1866
1867        // Print attribute distribution
1868        println!("\n=== Attribute Distribution ===");
1869        let mut sorted_attrs: Vec<_> = attribute_counts.iter().collect();
1870        sorted_attrs.sort_by_key(|(k, _)| *k);
1871        for (attribute_value, count) in &sorted_attrs {
1872            println!(
1873                "Label {}: {} occurrences (max allowed: {})",
1874                attribute_value, count, diverse_results_k
1875            );
1876        }
1877        println!("Total unique labels: {}", attribute_counts.len());
1878        println!("================================\n");
1879
1880        // With diverse_results_k = 5, we expect at most 5 results per attribute value
1881        for (attribute_value, count) in &attribute_counts {
1882            println!(
1883                "Assert: Label {} has {} occurrences (max: {})",
1884                attribute_value, count, diverse_results_k
1885            );
1886            assert!(
1887                *count <= diverse_results_k,
1888                "Attribute value {} appears {} times, which exceeds diverse_results_k of {}",
1889                attribute_value,
1890                count,
1891                diverse_results_k
1892            );
1893        }
1894
1895        // Verify that we have multiple different attribute values (diversity)
1896        // With 3 possible labels and diverse_results_k=5, we should see at least 2 different labels
1897        println!(
1898            "Assert: Found {} unique labels (expected at least 2)",
1899            attribute_counts.len()
1900        );
1901        assert!(
1902            attribute_counts.len() >= 2,
1903            "Expected at least 2 different attribute values for diversity, got {}",
1904            attribute_counts.len()
1905        );
1906    }
1907
1908    #[rstest]
1909    // This case checks expected behavior of unfiltered search.
1910    #[case(
1911        |_id: &u32| true,
1912        false,
1913        10,
1914        vec![152, 118, 72, 170, 87, 141, 79, 207, 124, 86],
1915        vec![256101.7, 256675.3, 256709.69, 256712.5, 256760.08, 256958.5, 257006.1, 257025.7, 257105.67, 257107.67],
1916    )]
1917    // This case validates post-filtering using 2 ids which are not present in the unfiltered result set.
1918    // It is expected that the post-filtering will return an empty result
1919    #[case(
1920        |id: &u32| *id == 0 || *id == 1,
1921        false,
1922        0,
1923        vec![0; 10],
1924        vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
1925    )]
1926    // This case validates pre-filtering using 2 ids which are not present in the unfiltered result set.
1927    // It is expected that the pre-filtering will do search over matching ids
1928    #[case(
1929        |id: &u32| *id == 0 || *id == 1,
1930        true,
1931        2,
1932        vec![1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
1933        vec![257247.28, 258179.28, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
1934    )]
1935    // This case validates post-filtering using 3 ids from the unfiltered result set.
1936    // It is expected that the post-filtering will filter out non-matching ids
1937    #[case(
1938        |id: &u32| *id == 72 || *id == 87 || *id == 170,
1939        false,
1940        3,
1941        vec![72, 170, 87, 0, 0, 0, 0, 0, 0, 0],
1942        vec![256709.69, 256712.5, 256760.08, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
1943    )]
1944    // This case validates pre-filtering using 3 ids from the unfiltered result set.
1945    // It is expected that the pre-filtering will do search over matching ids
1946    #[case(
1947        |id: &u32| *id == 72 || *id == 87 || *id == 170,
1948        true,
1949        3,
1950        vec![72, 170, 87, 0, 0, 0, 0, 0, 0, 0],
1951        vec![256709.69, 256712.5, 256760.08, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
1952    )]
1953    fn test_search_with_vector_filter(
1954        #[case] vector_filter: fn(&u32) -> bool,
1955        #[case] is_flat_search: bool,
1956        #[case] expected_result_count: u32,
1957        #[case] expected_indices: Vec<u32>,
1958        #[case] expected_distances: Vec<f32>,
1959    ) {
1960        // Exact distances can vary slightly depending on the architecture used
1961        // to compute distances due to different unrolling strategies and SIMD widthd.
1962        //
1963        // This parameter allows for a small margin when matching distances.
1964        let check_distances = |got: &[f32], expected: &[f32]| -> bool {
1965            const ABS_TOLERANCE: f32 = 0.02;
1966            assert_eq!(got.len(), expected.len());
1967            for (i, (g, e)) in std::iter::zip(got.iter(), expected.iter()).enumerate() {
1968                if (g - e).abs() > ABS_TOLERANCE {
1969                    panic!(
1970                        "distances differ at position {} by more than {}\n\n\
1971                         got: {:?}\nexpected: {:?}",
1972                        i, ABS_TOLERANCE, got, expected,
1973                    );
1974                }
1975            }
1976            true
1977        };
1978
1979        let storage_provider = Arc::new(VirtualStorageProvider::new_overlay(test_data_root()));
1980
1981        let search_engine = create_disk_index_searcher::<GraphDataF32VectorUnitData>(
1982            CreateDiskIndexSearcherParams {
1983                max_thread_num: 5,
1984                pq_pivot_file_path: TEST_PQ_PIVOT_128DIM,
1985                pq_compressed_file_path: TEST_PQ_COMPRESSED_128DIM,
1986                index_path: TEST_INDEX_128DIM,
1987                index_path_prefix: TEST_INDEX_PREFIX_128DIM,
1988                ..Default::default()
1989            },
1990            &storage_provider,
1991        );
1992        let query = vec![0.1f32; 128];
1993        let mut query_stats = QueryStatistics::default();
1994        let mut indices = vec![0u32; 10];
1995        let mut distances = vec![0f32; 10];
1996        let mut associated_data = vec![(); 10];
1997
1998        let result = search_engine.search_internal(
1999            &query,
2000            10,
2001            10,
2002            None, // beam_width
2003            &mut query_stats,
2004            &mut indices,
2005            &mut distances,
2006            &mut associated_data,
2007            &vector_filter,
2008            is_flat_search,
2009        );
2010
2011        assert!(result.is_ok(), "Expected search to succeed");
2012        assert_eq!(
2013            result.unwrap().result_count,
2014            expected_result_count,
2015            "Expected result count to match"
2016        );
2017        assert_eq!(indices, expected_indices, "Expected indices to match");
2018        assert!(
2019            check_distances(&distances, &expected_distances),
2020            "Expected distances to match"
2021        );
2022
2023        let result_with_filter = search_engine.search(
2024            &query,
2025            10,
2026            10,
2027            None, // beam_width
2028            Some(Box::new(vector_filter)),
2029            is_flat_search,
2030        );
2031
2032        assert!(result_with_filter.is_ok(), "Expected search to succeed");
2033        let result_with_filter_unwrapped = result_with_filter.unwrap();
2034        assert_eq!(
2035            result_with_filter_unwrapped.stats.result_count, expected_result_count,
2036            "Expected result count to match"
2037        );
2038        let actual_indices = result_with_filter_unwrapped
2039            .results
2040            .iter()
2041            .map(|x| x.vertex_id)
2042            .collect::<Vec<_>>();
2043        assert_eq!(
2044            actual_indices, expected_indices,
2045            "Expected indices to match"
2046        );
2047        let actual_distances = result_with_filter_unwrapped
2048            .results
2049            .iter()
2050            .map(|x| x.distance)
2051            .collect::<Vec<_>>();
2052        assert!(
2053            check_distances(&actual_distances, &expected_distances),
2054            "Expected distances to match"
2055        );
2056    }
2057
2058    #[test]
2059    fn test_beam_search_respects_io_limit() {
2060        let io_limit = 11; // Set a small IO limit for testing
2061        let storage_provider = Arc::new(VirtualStorageProvider::new_overlay(test_data_root()));
2062
2063        let search_engine = create_disk_index_searcher::<GraphDataF32VectorUnitData>(
2064            CreateDiskIndexSearcherParams {
2065                max_thread_num: 1,
2066                pq_pivot_file_path: TEST_PQ_PIVOT,
2067                pq_compressed_file_path: TEST_PQ_COMPRESSED,
2068                index_path: TEST_INDEX,
2069                index_path_prefix: TEST_INDEX_PREFIX,
2070                io_limit,
2071            },
2072            &storage_provider,
2073        );
2074        let query_vector: [f32; 128] = [1f32; 128];
2075
2076        let mut indices = vec![0u32; 10];
2077        let mut distances = vec![0f32; 10];
2078        let mut associated_data = vec![(); 10];
2079
2080        let mut result_output_buffer = search_output_buffer::IdDistanceAssociatedData::new(
2081            &mut indices,
2082            &mut distances,
2083            &mut associated_data,
2084        );
2085
2086        let strategy = search_engine.search_strategy(&query_vector, &|_| true);
2087
2088        let mut search_record = VisitedSearchRecord::new(0);
2089        search_engine
2090            .runtime
2091            .block_on(search_engine.index.search_recorded(
2092                &strategy,
2093                &DefaultContext,
2094                &query_vector,
2095                &SearchParams::new(10, 10, Some(4)).unwrap(),
2096                &mut result_output_buffer,
2097                &mut search_record,
2098            ))
2099            .unwrap();
2100        let visited_ids = search_record
2101            .visited
2102            .iter()
2103            .map(|n| n.id)
2104            .collect::<Vec<_>>();
2105
2106        let query_stats = strategy.io_tracker;
2107        //Verify the IO limit was respected
2108        assert!(
2109            query_stats.io_count() <= io_limit,
2110            "Expected IO operations to be <= {}, but got {}",
2111            io_limit,
2112            query_stats.io_count()
2113        );
2114
2115        const EXPECTED_NODES: [u32; 17] = [
2116            72, 118, 108, 86, 84, 152, 170, 82, 114, 87, 207, 176, 79, 153, 67, 165, 141,
2117        ]; //Expected nodes for query = [1f32; 128] with beam_width=4
2118
2119        // Count matching results
2120        let mut matching_count = 0;
2121        for expected_node in EXPECTED_NODES.iter() {
2122            if visited_ids.contains(expected_node) {
2123                matching_count += 1;
2124            }
2125        }
2126
2127        // Calculate recall
2128        let recall = (matching_count as f32 / EXPECTED_NODES.len() as f32) * 100.0;
2129
2130        //Verify the recall is above 60%. The threshold her eis arbitrary, just to make sure when
2131        // search hits io_limit that it doesn't break and the recall degrades gracefully
2132        assert!(recall >= 60.0, "Match percentage is below 60%: {}", recall);
2133    }
2134}