Skip to main content

diskann_disk/search/provider/
disk_provider.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use std::{
7    collections::HashMap,
8    future::Future,
9    num::NonZeroUsize,
10    ops::Range,
11    sync::{
12        atomic::{AtomicU64, AtomicUsize},
13        Arc,
14    },
15    time::Instant,
16};
17
18use diskann::{
19    graph::{
20        self,
21        glue::{self, ExpandBeam, IdIterator, SearchExt, SearchPostProcess, SearchStrategy},
22        search::Knn,
23        search_output_buffer, AdjacencyList, DiskANNIndex, SearchOutputBuffer,
24    },
25    neighbor::Neighbor,
26    provider::{
27        Accessor, BuildQueryComputer, DataProvider, DefaultContext, DelegateNeighbor, HasId,
28        NeighborAccessor,
29    },
30    utils::{
31        object_pool::{ObjectPool, PoolOption, TryAsPooled},
32        IntoUsize, VectorRepr,
33    },
34    ANNError, ANNResult,
35};
36use diskann_providers::storage::StorageReadProvider;
37use diskann_providers::{
38    model::{
39        compute_pq_distance, compute_pq_distance_for_pq_coordinates, graph::traits::GraphDataType,
40        pq::quantizer_preprocess, PQData, PQScratch,
41    },
42    storage::{get_compressed_pq_file, get_disk_index_file, get_pq_pivot_file, LoadWith},
43};
44use diskann_vector::{distance::Metric, DistanceFunction, PreprocessedDistanceFunction};
45use futures_util::future;
46use tokio::runtime::Runtime;
47use tracing::debug;
48
49use crate::{
50    data_model::{CachingStrategy, GraphHeader},
51    filter_parameter::{default_vector_filter, VectorFilter},
52    search::{
53        provider::disk_vertex_provider_factory::DiskVertexProviderFactory,
54        traits::{VertexProvider, VertexProviderFactory},
55    },
56    storage::{api::AsyncDiskLoadContext, disk_index_reader::DiskIndexReader},
57    utils::AlignedFileReaderFactory,
58    utils::QueryStatistics,
59};
60
61///////////////////
62// Disk Provider //
63///////////////////
64
65/// The DiskProvider is a data provider that loads data from disk using the disk readers
66/// The data format for disk is different from that of the in-memory providers.
67/// The disk format stores both the vectors and the adjacency list next to each other for
68/// better locality for quicker access.
69/// Please refer to the RFC documentation at [`docs\rfcs\cy2025\disk_provider_for_async_index.md`] for design details.
70pub struct DiskProvider<Data>
71where
72    Data: GraphDataType<VectorIdType = u32>,
73{
74    /// Holds the graph header information that contains metadata about disk-index file.
75    graph_header: GraphHeader,
76
77    // Full precision distance comparer used in post_process to reorder results.
78    distance_comparer: <Data::VectorDataType as VectorRepr>::Distance,
79
80    /// The PQ data used for quantization.
81    pq_data: Arc<PQData>,
82
83    /// The number of points in the graph.
84    num_points: usize,
85
86    /// Metric used for distance computation.
87    metric: Metric,
88
89    /// The number of IO operations that can be done in parallel.
90    search_io_limit: usize,
91}
92
93impl<Data> DataProvider for DiskProvider<Data>
94where
95    Data: GraphDataType<VectorIdType = u32>,
96{
97    type Context = DefaultContext;
98
99    type InternalId = u32;
100
101    type ExternalId = u32;
102
103    type Error = ANNError;
104
105    /// Translate an external id to its corresponding internal id.
106    fn to_internal_id(
107        &self,
108        _context: &DefaultContext,
109        gid: &Self::ExternalId,
110    ) -> Result<Self::InternalId, Self::Error> {
111        Ok(*gid)
112    }
113
114    /// Translate an internal id its corresponding external id.
115    fn to_external_id(
116        &self,
117        _context: &DefaultContext,
118        id: Self::InternalId,
119    ) -> Result<Self::ExternalId, Self::Error> {
120        Ok(id)
121    }
122}
123
124impl<Data> LoadWith<AsyncDiskLoadContext> for DiskProvider<Data>
125where
126    Data: GraphDataType<VectorIdType = u32>,
127{
128    type Error = ANNError;
129
130    async fn load_with<P>(provider: &P, ctx: &AsyncDiskLoadContext) -> ANNResult<Self>
131    where
132        P: StorageReadProvider,
133    {
134        debug!(
135            "DiskProvider::load_with() called with file: {:?}",
136            get_disk_index_file(ctx.quant_load_context.metadata.prefix())
137        );
138
139        let graph_header = {
140            let aligned_reader_factory = AlignedFileReaderFactory::new(get_disk_index_file(
141                ctx.quant_load_context.metadata.prefix(),
142            ));
143
144            let caching_strategy = if ctx.num_nodes_to_cache > 0 {
145                CachingStrategy::StaticCacheWithBfsNodes(ctx.num_nodes_to_cache)
146            } else {
147                CachingStrategy::None
148            };
149
150            let vertex_provider_factory = DiskVertexProviderFactory::<Data, _>::new(
151                aligned_reader_factory,
152                caching_strategy,
153            )?;
154            VertexProviderFactory::get_header(&vertex_provider_factory)?
155        };
156
157        let metric = ctx.quant_load_context.metric;
158        let num_points = ctx.num_points;
159
160        let index_path_prefix = ctx.quant_load_context.metadata.prefix();
161        let index_reader = DiskIndexReader::<<Data as GraphDataType>::VectorDataType>::new(
162            get_pq_pivot_file(index_path_prefix),
163            get_compressed_pq_file(index_path_prefix),
164            provider,
165        )?;
166
167        Self::new(
168            &index_reader,
169            graph_header,
170            metric,
171            num_points,
172            ctx.search_io_limit,
173        )
174    }
175}
176
177impl<Data> DiskProvider<Data>
178where
179    Data: GraphDataType<VectorIdType = u32>,
180{
181    fn new(
182        disk_index_reader: &DiskIndexReader<Data::VectorDataType>,
183        graph_header: GraphHeader,
184        metric: Metric,
185        num_points: usize,
186        search_io_limit: usize,
187    ) -> ANNResult<Self> {
188        let distance_comparer =
189            Data::VectorDataType::distance(metric, Some(graph_header.metadata().dims));
190
191        let pq_data = disk_index_reader.get_pq_data();
192
193        Ok(Self {
194            graph_header,
195            distance_comparer,
196            pq_data,
197            num_points,
198            metric,
199            search_io_limit,
200        })
201    }
202}
203
204/// The search strategy for the disk provider. This is used to create the search accessor
205/// for use in search in quant space and post_process function to reorder with full precision vectors.
206///
207/// # Why vertex_provider_factory and scratch_pool are here instead of DiskProvider
208///
209/// The DataProvider trait requires 'static bounds for multi-threaded async contexts,
210/// but vertex_provider_factory may have non-'static lifetime bounds (e.g., borrowing
211/// from local data structures). Moving these components to the search strategy allows
212/// DiskProvider to satisfy 'static constraints while enabling flexible per-search
213/// resource management.
214pub struct DiskSearchStrategy<'a, Data, ProviderFactory>
215where
216    Data: GraphDataType<VectorIdType = u32>,
217    ProviderFactory: VertexProviderFactory<Data>,
218{
219    // This needs to be Arc instead of Rc because DiskSearchStrategy has "Send" trait bound, though this is not expected to be shared across threads.
220    io_tracker: IOTracker,
221    vector_filter: &'a (dyn Fn(&u32) -> bool + Send + Sync), // Fn param is u32 as we validate "VectorIdType = u32" everywhere in this provider in trait bounds.
222    query: &'a [Data::VectorDataType],
223
224    /// The vertex provider factory is used to create the vertex provider for each search instance.
225    vertex_provider_factory: &'a ProviderFactory,
226
227    /// Scratch pool for disk search operations that need allocations.
228    scratch_pool: &'a Arc<ObjectPool<DiskSearchScratch<Data, ProviderFactory::VertexProviderType>>>,
229}
230
/// Per-search I/O bookkeeping: how many vertices were loaded, and how long loading and
/// query preprocessing took.
///
/// Accessors only hold a shared `&IOTracker`. A `&T` is `Send` only when `T: Sync`, so
/// `Cell` counters are not an option. The intended use is updates from a single thread,
/// where relaxed atomics cost essentially nothing compared with plain integers.
#[derive(Default)]
struct IOTracker {
    /// Cumulative time spent loading vertices, in microseconds.
    io_time_us: AtomicU64,
    /// Cumulative time spent in PQ query preprocessing, in microseconds.
    preprocess_time_us: AtomicU64,
    /// Number of vertices loaded so far (as reported via `add_io_count`).
    io_count: AtomicUsize,
}

impl IOTracker {
    /// Add `time` (microseconds) to the given timing counter.
    fn add_time(category: &AtomicU64, time: u64) {
        category.fetch_add(time, std::sync::atomic::Ordering::Relaxed);
    }

    /// Read the current value of the given timing counter.
    fn time(category: &AtomicU64) -> u64 {
        category.load(std::sync::atomic::Ordering::Relaxed)
    }

    /// Record `count` additional vertex loads.
    fn add_io_count(&self, count: usize) {
        self.io_count
            .fetch_add(count, std::sync::atomic::Ordering::Relaxed);
    }

    /// Total number of vertex loads recorded so far.
    fn io_count(&self) -> usize {
        self.io_count.load(std::sync::atomic::Ordering::Relaxed)
    }
}
267
/// Post-processor that drops candidate ids rejected by a caller-supplied predicate and
/// reranks the survivors by full-precision distance.
#[derive(Copy, Clone)]
pub struct RerankAndFilter<'a> {
    /// Candidates for which this returns `false` are discarded before reranking.
    filter: &'a (dyn Fn(&u32) -> bool + Send + Sync),
}

impl<'a> RerankAndFilter<'a> {
    /// Wrap `filter` as a post-processor.
    fn new(filter: &'a (dyn Fn(&u32) -> bool + Send + Sync)) -> Self {
        RerankAndFilter { filter }
    }
}
278
279impl<Data, VP>
280    SearchPostProcess<
281        DiskAccessor<'_, Data, VP>,
282        [Data::VectorDataType],
283        (
284            <DiskProvider<Data> as DataProvider>::InternalId,
285            Data::AssociatedDataType,
286        ),
287    > for RerankAndFilter<'_>
288where
289    Data: GraphDataType<VectorIdType = u32>,
290    VP: VertexProvider<Data>,
291{
292    type Error = ANNError;
293    async fn post_process<I, B>(
294        &self,
295        accessor: &mut DiskAccessor<'_, Data, VP>,
296        query: &[Data::VectorDataType],
297        _computer: &DiskQueryComputer,
298        candidates: I,
299        output: &mut B,
300    ) -> Result<usize, Self::Error>
301    where
302        I: Iterator<Item = Neighbor<u32>> + Send,
303        B: SearchOutputBuffer<(u32, Data::AssociatedDataType)> + Send + ?Sized,
304    {
305        let provider = accessor.provider;
306
307        let mut uncached_ids = Vec::new();
308        let mut reranked = candidates
309            .map(|n| n.id)
310            .filter(|id| (self.filter)(id))
311            .filter_map(|n| {
312                if let Some(entry) = accessor.scratch.distance_cache.get(&n) {
313                    Some(Ok::<((u32, _), f32), ANNError>(((n, entry.1), entry.0)))
314                } else {
315                    uncached_ids.push(n);
316                    None
317                }
318            })
319            .collect::<Result<Vec<_>, _>>()?;
320        if !uncached_ids.is_empty() {
321            ensure_vertex_loaded(&mut accessor.scratch.vertex_provider, &uncached_ids)?;
322            for n in &uncached_ids {
323                let v = accessor.scratch.vertex_provider.get_vector(n)?;
324                let d = provider.distance_comparer.evaluate_similarity(query, v);
325                let a = accessor.scratch.vertex_provider.get_associated_data(n)?;
326                reranked.push(((*n, *a), d));
327            }
328        }
329
330        // Sort the full precision distances.
331        reranked
332            .sort_unstable_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
333        // Store the reranked results.
334        Ok(output.extend(reranked))
335    }
336}
337
338impl<'this, Data, ProviderFactory>
339    SearchStrategy<
340        DiskProvider<Data>,
341        [Data::VectorDataType],
342        (
343            <DiskProvider<Data> as DataProvider>::InternalId,
344            Data::AssociatedDataType,
345        ),
346    > for DiskSearchStrategy<'this, Data, ProviderFactory>
347where
348    Data: GraphDataType<VectorIdType = u32>,
349    ProviderFactory: VertexProviderFactory<Data>,
350{
351    type QueryComputer = DiskQueryComputer;
352    type SearchAccessor<'a> = DiskAccessor<'a, Data, ProviderFactory::VertexProviderType>;
353    type SearchAccessorError = ANNError;
354    type PostProcessor = RerankAndFilter<'this>;
355
356    fn search_accessor<'a>(
357        &'a self,
358        provider: &'a DiskProvider<Data>,
359        _context: &DefaultContext,
360    ) -> Result<Self::SearchAccessor<'a>, Self::SearchAccessorError> {
361        DiskAccessor::new(
362            provider,
363            &self.io_tracker,
364            self.query,
365            self.vertex_provider_factory,
366            self.scratch_pool,
367        )
368    }
369
370    fn post_processor(&self) -> Self::PostProcessor {
371        RerankAndFilter::new(self.vector_filter)
372    }
373}
374
/// Query computer for the disk provider: evaluates the distance between the query and a
/// PQ-compressed vector.
///
/// It owns a private copy of the query's PQ distance table, so it does not borrow the
/// accessor's scratch space.
pub struct DiskQueryComputer {
    /// Number of PQ chunks in each compressed vector.
    num_pq_chunks: usize,
    /// Per-query PQ distance lookup table.
    /// NOTE(review): layout is whatever `compute_pq_distance_for_pq_coordinates` expects.
    query_centroid_l2_distance: Vec<f32>,
}
380
381impl PreprocessedDistanceFunction<&[u8], f32> for DiskQueryComputer {
382    fn evaluate_similarity(&self, changing: &[u8]) -> f32 {
383        let mut dist = 0.0f32;
384        #[allow(clippy::expect_used)]
385        compute_pq_distance_for_pq_coordinates(
386            changing,
387            self.num_pq_chunks,
388            &self.query_centroid_l2_distance,
389            std::slice::from_mut(&mut dist),
390        )
391        .expect("PQ distance compute for PQ coordinates is expected to succeed");
392        dist
393    }
394}
395
396impl<Data, VP> BuildQueryComputer<[Data::VectorDataType]> for DiskAccessor<'_, Data, VP>
397where
398    Data: GraphDataType<VectorIdType = u32>,
399    VP: VertexProvider<Data>,
400{
401    type QueryComputerError = ANNError;
402    type QueryComputer = DiskQueryComputer;
403
404    fn build_query_computer(
405        &self,
406        _from: &[Data::VectorDataType],
407    ) -> Result<Self::QueryComputer, Self::QueryComputerError> {
408        Ok(DiskQueryComputer {
409            num_pq_chunks: self.provider.pq_data.get_num_chunks(),
410            query_centroid_l2_distance: self
411                .scratch
412                .pq_scratch
413                .aligned_pqtable_dist_scratch
414                .as_slice()
415                .to_vec(),
416        })
417    }
418
419    async fn distances_unordered<Itr, F>(
420        &mut self,
421        vec_id_itr: Itr,
422        _computer: &Self::QueryComputer,
423        f: F,
424    ) -> Result<(), Self::GetError>
425    where
426        F: Send + FnMut(f32, Self::Id),
427        Itr: Iterator<Item = Self::Id>,
428    {
429        self.pq_distances(&vec_id_itr.collect::<Box<[_]>>(), f)
430    }
431}
432
433impl<Data, VP> ExpandBeam<[Data::VectorDataType]> for DiskAccessor<'_, Data, VP>
434where
435    Data: GraphDataType<VectorIdType = u32>,
436    VP: VertexProvider<Data>,
437{
438    fn expand_beam<Itr, P, F>(
439        &mut self,
440        ids: Itr,
441        _computer: &Self::QueryComputer,
442        mut pred: P,
443        mut f: F,
444    ) -> impl std::future::Future<Output = Result<(), Self::GetError>> + Send
445    where
446        Itr: Iterator<Item = Self::Id> + Send,
447        P: glue::HybridPredicate<Self::Id> + Send + Sync,
448        F: FnMut(f32, Self::Id) + Send,
449    {
450        let result = (|| {
451            let io_limit = self.provider.search_io_limit - self.io_tracker.io_count();
452            let load_ids: Box<[_]> = ids.take(io_limit).collect();
453
454            self.ensure_loaded(&load_ids)?;
455            let mut ids = Vec::new();
456            for i in load_ids {
457                ids.clear();
458                ids.extend(
459                    self.scratch
460                        .vertex_provider
461                        .get_adjacency_list(&i)?
462                        .iter()
463                        .copied()
464                        .filter(|id| pred.eval_mut(id)),
465                );
466
467                self.pq_distances(&ids, &mut f)?;
468            }
469
470            Ok(())
471        })();
472
473        std::future::ready(result)
474    }
475}
476
477// Scratch space for disk search operations that need allocations.
478// These allocations are amortized across searches using the scratch pool.
479struct DiskSearchScratch<Data, VP>
480where
481    Data: GraphDataType<VectorIdType = u32>,
482    VP: VertexProvider<Data>,
483{
484    distance_cache: HashMap<u32, (f32, Data::AssociatedDataType)>,
485    pq_scratch: PQScratch,
486    vertex_provider: VP,
487}
488
489#[derive(Clone)]
490struct DiskSearchScratchArgs<'a, ProviderFactory> {
491    graph_degree: usize,
492    dim: usize,
493    num_pq_chunks: usize,
494    num_pq_centers: usize,
495    vertex_factory: &'a ProviderFactory,
496    graph_header: &'a GraphHeader,
497}
498
499impl<Data, ProviderFactory> TryAsPooled<&DiskSearchScratchArgs<'_, ProviderFactory>>
500    for DiskSearchScratch<Data, ProviderFactory::VertexProviderType>
501where
502    Data: GraphDataType<VectorIdType = u32>,
503    ProviderFactory: VertexProviderFactory<Data>,
504{
505    type Error = ANNError;
506
507    fn try_create(args: &DiskSearchScratchArgs<ProviderFactory>) -> Result<Self, Self::Error> {
508        let pq_scratch = PQScratch::new(
509            args.graph_degree,
510            args.dim,
511            args.num_pq_chunks,
512            args.num_pq_centers,
513        )?;
514
515        const DEFAULT_BEAM_WIDTH: usize = 0; // Setting as 0 to avoid preallocation of memory.
516        let vertex_provider = args
517            .vertex_factory
518            .create_vertex_provider(DEFAULT_BEAM_WIDTH, args.graph_header)?;
519
520        Ok(Self {
521            distance_cache: HashMap::new(),
522            pq_scratch,
523            vertex_provider,
524        })
525    }
526
527    fn try_modify(
528        &mut self,
529        _args: &DiskSearchScratchArgs<ProviderFactory>,
530    ) -> Result<(), Self::Error> {
531        self.distance_cache.clear();
532        self.vertex_provider.clear();
533        Ok(())
534    }
535}
536
537pub struct DiskAccessor<'a, Data, VP>
538where
539    Data: GraphDataType<VectorIdType = u32>,
540    VP: VertexProvider<Data>,
541{
542    provider: &'a DiskProvider<Data>,
543    io_tracker: &'a IOTracker,
544    scratch: PoolOption<DiskSearchScratch<Data, VP>>,
545    query: &'a [Data::VectorDataType],
546}
547
548impl<Data, VP> DiskAccessor<'_, Data, VP>
549where
550    Data: GraphDataType<VectorIdType = u32>,
551    VP: VertexProvider<Data>,
552{
553    // Compute the PQ distance between each ID in `ids` and the distance table stored in
554    // `self`, invoking the callback with the results of each computation in order.
555    fn pq_distances<F>(&mut self, ids: &[u32], mut f: F) -> ANNResult<()>
556    where
557        F: FnMut(f32, u32),
558    {
559        let pq_scratch = &mut self.scratch.pq_scratch;
560        compute_pq_distance(
561            ids,
562            self.provider.pq_data.get_num_chunks(),
563            &pq_scratch.aligned_pqtable_dist_scratch,
564            self.provider.pq_data.pq_compressed_data().get_data(),
565            &mut pq_scratch.aligned_pq_coord_scratch,
566            &mut pq_scratch.aligned_dist_scratch,
567        )?;
568
569        for (i, id) in ids.iter().enumerate() {
570            let distance = self.scratch.pq_scratch.aligned_dist_scratch[i];
571            f(distance, *id);
572        }
573
574        Ok(())
575    }
576}
577
578impl<Data, VP> SearchExt for DiskAccessor<'_, Data, VP>
579where
580    Data: GraphDataType<VectorIdType = u32>,
581    VP: VertexProvider<Data>,
582{
583    async fn starting_points(&self) -> ANNResult<Vec<u32>> {
584        let start_vertex_id = self.provider.graph_header.metadata().medoid as u32;
585        Ok(vec![start_vertex_id])
586    }
587
588    fn terminate_early(&mut self) -> bool {
589        self.io_tracker.io_count() > self.provider.search_io_limit
590    }
591}
592
593impl<'a, Data, VP> DiskAccessor<'a, Data, VP>
594where
595    Data: GraphDataType<VectorIdType = u32>,
596    VP: VertexProvider<Data>,
597{
598    fn new<VPF>(
599        provider: &'a DiskProvider<Data>,
600        io_tracker: &'a IOTracker,
601        query: &'a [Data::VectorDataType],
602        vertex_provider_factory: &'a VPF,
603        scratch_pool: &'a Arc<ObjectPool<DiskSearchScratch<Data, VP>>>,
604    ) -> ANNResult<Self>
605    where
606        VPF: VertexProviderFactory<Data, VertexProviderType = VP>,
607    {
608        let mut scratch = PoolOption::try_pooled(
609            scratch_pool,
610            &DiskSearchScratchArgs {
611                graph_degree: provider.graph_header.max_degree::<Data::VectorDataType>()?,
612                dim: provider.graph_header.metadata().dims,
613                num_pq_chunks: provider.pq_data.get_num_chunks(),
614                num_pq_centers: provider.pq_data.get_num_centers(),
615                vertex_factory: vertex_provider_factory,
616                graph_header: &provider.graph_header,
617            },
618        )?;
619
620        scratch.pq_scratch.set(
621            provider.graph_header.metadata().dims,
622            query,
623            1.0_f32, // Normalization factor
624        )?;
625        let start_vertex_id = provider.graph_header.metadata().medoid as u32;
626
627        let timer = Instant::now();
628        quantizer_preprocess(
629            &mut scratch.pq_scratch,
630            &provider.pq_data,
631            provider.metric,
632            &[start_vertex_id],
633        )?;
634        IOTracker::add_time(
635            &io_tracker.preprocess_time_us,
636            timer.elapsed().as_micros() as u64,
637        );
638
639        Ok(Self {
640            provider,
641            io_tracker,
642            scratch,
643            query,
644        })
645    }
646    fn ensure_loaded(&mut self, ids: &[u32]) -> Result<(), ANNError> {
647        if ids.is_empty() {
648            return Ok(());
649        }
650        let scratch = &mut self.scratch;
651        let timer = Instant::now();
652        ensure_vertex_loaded(&mut scratch.vertex_provider, ids)?;
653        IOTracker::add_time(
654            &self.io_tracker.io_time_us,
655            timer.elapsed().as_micros() as u64,
656        );
657        self.io_tracker.add_io_count(ids.len());
658        for id in ids {
659            let distance = self
660                .provider
661                .distance_comparer
662                .evaluate_similarity(self.query, scratch.vertex_provider.get_vector(id)?);
663            let associated_data = *scratch.vertex_provider.get_associated_data(id)?;
664            scratch
665                .distance_cache
666                .insert(*id, (distance, associated_data));
667        }
668        Ok(())
669    }
670}
671
672impl<Data, VP> HasId for DiskAccessor<'_, Data, VP>
673where
674    Data: GraphDataType<VectorIdType = u32>,
675    VP: VertexProvider<Data>,
676{
677    type Id = u32;
678}
679
680impl<'a, Data, VP> Accessor for DiskAccessor<'a, Data, VP>
681where
682    Data: GraphDataType<VectorIdType = u32>,
683    VP: VertexProvider<Data>,
684{
685    /// This references the PQ vector in the underlying `pq_data` store.
686    type Extended = &'a [u8];
687
688    /// This accessor returns raw slices. There *is* a chance of racing when the fast
689    /// providers are used. We just have to live with it.
690    ///
691    /// Since the underlying PQ store is shared, we ignore the `'b` lifetime here and
692    /// instead use `'a`.
693    type Element<'b>
694        = &'a [u8]
695    where
696        Self: 'b;
697
698    /// `ElementRef` can have arbitrary lifetimes.
699    type ElementRef<'b> = &'b [u8];
700
701    /// Choose to panic on an out-of-bounds access rather than propagate an error.
702    type GetError = ANNError;
703
704    fn get_element(
705        &mut self,
706        id: Self::Id,
707    ) -> impl Future<Output = Result<Self::Element<'_>, Self::GetError>> + Send {
708        std::future::ready(self.provider.pq_data.get_compressed_vector(id as usize))
709    }
710}
711
712impl<Data, VP> IdIterator<Range<u32>> for DiskAccessor<'_, Data, VP>
713where
714    Data: GraphDataType<VectorIdType = u32>,
715    VP: VertexProvider<Data>,
716{
717    async fn id_iterator(&mut self) -> Result<Range<u32>, ANNError> {
718        Ok(0..self.provider.num_points as u32)
719    }
720}
721
722impl<'a, 'b, Data, VP> DelegateNeighbor<'a> for DiskAccessor<'b, Data, VP>
723where
724    Data: GraphDataType<VectorIdType = u32>,
725    VP: VertexProvider<Data>,
726{
727    type Delegate = AsNeighborAccessor<'a, 'b, Data, VP>;
728    fn delegate_neighbor(&'a mut self) -> Self::Delegate {
729        AsNeighborAccessor(self)
730    }
731}
732
733/// A light-weight wrapper around `&mut DiskAccessor` used to tailor the semantics of
734/// [`NeighborAccessor`].
735///
736/// This implementation ensures that the vector data for adjacency lists is also retrieved
737/// and cached to enhance reranking.
738pub struct AsNeighborAccessor<'a, 'b, Data, VP>(&'a mut DiskAccessor<'b, Data, VP>)
739where
740    Data: GraphDataType<VectorIdType = u32>,
741    VP: VertexProvider<Data>;
742
743impl<Data, VP> HasId for AsNeighborAccessor<'_, '_, Data, VP>
744where
745    Data: GraphDataType<VectorIdType = u32>,
746    VP: VertexProvider<Data>,
747{
748    type Id = u32;
749}
750
751impl<Data, VP> NeighborAccessor for AsNeighborAccessor<'_, '_, Data, VP>
752where
753    Data: GraphDataType<VectorIdType = u32>,
754    VP: VertexProvider<Data>,
755{
756    fn get_neighbors(
757        self,
758        id: Self::Id,
759        neighbors: &mut AdjacencyList<Self::Id>,
760    ) -> impl Future<Output = ANNResult<Self>> + Send {
761        if self.0.io_tracker.io_count() > self.0.provider.search_io_limit {
762            return future::ok(self); // Returning empty results in `neighbors` out param if IO limit is reached.
763        }
764
765        if let Err(e) = ensure_vertex_loaded(&mut self.0.scratch.vertex_provider, &[id]) {
766            return future::err(e);
767        }
768        let list = match self.0.scratch.vertex_provider.get_adjacency_list(&id) {
769            Ok(list) => list,
770            Err(e) => return future::err(e),
771        };
772        neighbors.overwrite_trusted(list);
773        future::ok(self)
774    }
775}
776
777/// [`DiskIndexSearcher`] is a helper class to make it easy to construct index
778/// and do repeated search operations. It is a wrapper around the index.
779/// This is useful for drivers such as search_disk_index.exe in tools.
780pub struct DiskIndexSearcher<
781    Data,
782    ProviderFactory = DiskVertexProviderFactory<Data, AlignedFileReaderFactory>,
783> where
784    Data: GraphDataType<VectorIdType = u32>,
785    ProviderFactory: VertexProviderFactory<Data>,
786{
787    index: DiskANNIndex<DiskProvider<Data>>,
788    runtime: Runtime,
789
790    /// The vertex provider factory is used to create the vertex provider for each search instance.
791    vertex_provider_factory: ProviderFactory,
792
793    /// Scratch pool for disk search operations that need allocations.
794    scratch_pool: Arc<ObjectPool<DiskSearchScratch<Data, ProviderFactory::VertexProviderType>>>,
795}
796
797#[derive(Debug)]
798pub struct SearchResultStats {
799    pub cmps: u32,
800    pub result_count: u32,
801    pub query_statistics: QueryStatistics,
802}
803
804/// `SearchResult` is a struct representing the result of a search operation.
805///
806/// It contains a list of vector results and a statistics object
807///
808pub struct SearchResult<AssociatedData> {
809    /// A list of nearest neighbors resulting from the search.
810    pub results: Vec<SearchResultItem<AssociatedData>>,
811    pub stats: SearchResultStats,
812}
813
/// `SearchResultItem` is one nearest neighbor returned by a search.
///
/// It contains the vertex id, the vertex's associated data, and its distance to the query
/// vector.
pub struct SearchResultItem<AssociatedData> {
    /// The vertex id of the nearest neighbor.
    pub vertex_id: u32,
    /// The associated data stored with the neighbor. Its concrete type is the index's
    /// `AssociatedDataType`.
    pub data: AssociatedData,
    /// The distance between the nearest neighbor and the query vector.
    pub distance: f32,
}
827
828impl<Data, ProviderFactory> DiskIndexSearcher<Data, ProviderFactory>
829where
830    Data: GraphDataType<VectorIdType = u32>,
831    ProviderFactory: VertexProviderFactory<Data>,
832{
833    /// Create a new asynchronous disk searcher instance.
834    ///
835    /// # Arguments
836    /// * `num_threads` - The maximum number of threads to use.
837    /// * `search_io_limit` - I/O operation limit.
838    /// * `disk_index_reader` - The disk index reader.
839    /// * `vertex_provider_factory` - The vertex provider factory.
840    /// * `metric` - Distance metric used for vector similarity calculations.
841    /// * `runtime` - Tokio runtime handle for executing async operations.
842    pub fn new(
843        num_threads: usize,
844        search_io_limit: usize,
845        disk_index_reader: &DiskIndexReader<Data::VectorDataType>,
846        vertex_provider_factory: ProviderFactory,
847        metric: Metric,
848        runtime: Option<Runtime>,
849    ) -> ANNResult<Self> {
850        let runtime = match runtime {
851            Some(rt) => rt,
852            None => tokio::runtime::Builder::new_current_thread().build()?,
853        };
854
855        let graph_header = vertex_provider_factory.get_header()?;
856        let metadata = graph_header.metadata();
857        let max_degree = graph_header.max_degree::<Data::VectorDataType>()? as u32;
858
859        let config = graph::config::Builder::new(
860            max_degree.into_usize(),
861            graph::config::MaxDegree::default_slack(),
862            1, // build-search-list-size
863            metric.into(),
864        )
865        .build()?;
866
867        debug!("Creating DiskIndexSearcher with index_config: {:?}", config);
868
869        let graph_header = vertex_provider_factory.get_header()?;
870        let pq_data = disk_index_reader.get_pq_data();
871        let scratch_pool_args = DiskSearchScratchArgs {
872            graph_degree: graph_header.max_degree::<Data::VectorDataType>()?,
873            dim: graph_header.metadata().dims,
874            num_pq_chunks: pq_data.get_num_chunks(),
875            num_pq_centers: pq_data.get_num_centers(),
876            vertex_factory: &vertex_provider_factory,
877            graph_header: &graph_header,
878        };
879        let scratch_pool = Arc::new(ObjectPool::try_new(&scratch_pool_args, 0, None)?);
880
881        let disk_provider = DiskProvider::new(
882            disk_index_reader,
883            graph_header,
884            metric,
885            metadata.num_pts.into_usize(),
886            search_io_limit,
887        )?;
888
889        let index = DiskANNIndex::new(config, disk_provider, NonZeroUsize::new(num_threads));
890        Ok(Self {
891            index,
892            runtime,
893            vertex_provider_factory,
894            scratch_pool,
895        })
896    }
897
898    /// Helper method to create a DiskSearchStrategy with common parameters
899    fn search_strategy<'a>(
900        &'a self,
901        query: &'a [Data::VectorDataType],
902        vector_filter: &'a (dyn Fn(&Data::VectorIdType) -> bool + Send + Sync),
903    ) -> DiskSearchStrategy<'a, Data, ProviderFactory> {
904        DiskSearchStrategy {
905            io_tracker: IOTracker::default(),
906            vector_filter,
907            query,
908            vertex_provider_factory: &self.vertex_provider_factory,
909            scratch_pool: &self.scratch_pool,
910        }
911    }
912
913    /// Perform a search on the disk index.
914    /// return the list of nearest neighbors and associated data.
915    pub fn search(
916        &self,
917        query: &[Data::VectorDataType],
918        return_list_size: u32,
919        search_list_size: u32,
920        beam_width: Option<usize>,
921        vector_filter: Option<VectorFilter<Data>>,
922        is_flat_search: bool,
923    ) -> ANNResult<SearchResult<Data::AssociatedDataType>> {
924        let mut query_stats = QueryStatistics::default();
925        let mut indices = vec![0u32; return_list_size as usize];
926        let mut distances = vec![0f32; return_list_size as usize];
927        let mut associated_data =
928            vec![Data::AssociatedDataType::default(); return_list_size as usize];
929
930        let stats = self.search_internal(
931            query,
932            return_list_size as usize,
933            search_list_size,
934            beam_width,
935            &mut query_stats,
936            &mut indices,
937            &mut distances,
938            &mut associated_data,
939            &vector_filter.unwrap_or(default_vector_filter::<Data>()),
940            is_flat_search,
941        )?;
942
943        let mut search_result = SearchResult {
944            results: Vec::with_capacity(return_list_size as usize),
945            stats,
946        };
947
948        for ((vertex_id, distance), associated_data) in indices
949            .into_iter()
950            .zip(distances.into_iter())
951            .zip(associated_data.into_iter())
952        {
953            search_result.results.push(SearchResultItem {
954                vertex_id,
955                distance,
956                data: associated_data,
957            });
958        }
959
960        Ok(search_result)
961    }
962
963    /// Perform a raw search on the disk index.
964    /// This is a lower-level API that allows more control over the search parameters and output buffers.
965    #[allow(clippy::too_many_arguments)]
966    pub(crate) fn search_internal(
967        &self,
968        query: &[Data::VectorDataType],
969        k_value: usize,
970        search_list_size: u32,
971        beam_width: Option<usize>,
972        query_stats: &mut QueryStatistics,
973        indices: &mut [u32],
974        distances: &mut [f32],
975        associated_data: &mut [Data::AssociatedDataType],
976        vector_filter: &(dyn Fn(&Data::VectorIdType) -> bool + Send + Sync),
977        is_flat_search: bool,
978    ) -> ANNResult<SearchResultStats> {
979        let mut result_output_buffer = search_output_buffer::IdDistanceAssociatedData::new(
980            &mut indices[..k_value],
981            &mut distances[..k_value],
982            &mut associated_data[..k_value],
983        );
984
985        let strategy = self.search_strategy(query, vector_filter);
986        let timer = Instant::now();
987        let k = k_value;
988        let l = search_list_size as usize;
989        let stats = if is_flat_search {
990            self.runtime.block_on(self.index.flat_search(
991                &strategy,
992                &DefaultContext,
993                strategy.query,
994                vector_filter,
995                &Knn::new(k, l, beam_width)?,
996                &mut result_output_buffer,
997            ))?
998        } else {
999            let knn_search = Knn::new(k, l, beam_width)?;
1000            self.runtime.block_on(self.index.search(
1001                knn_search,
1002                &strategy,
1003                &DefaultContext,
1004                strategy.query,
1005                &mut result_output_buffer,
1006            ))?
1007        };
1008        query_stats.total_comparisons = stats.cmps;
1009        query_stats.search_hops = stats.hops;
1010
1011        query_stats.total_execution_time_us = timer.elapsed().as_micros();
1012        query_stats.io_time_us = IOTracker::time(&strategy.io_tracker.io_time_us) as u128;
1013        query_stats.total_io_operations = strategy.io_tracker.io_count() as u32;
1014        query_stats.total_vertices_loaded = strategy.io_tracker.io_count() as u32;
1015        query_stats.query_pq_preprocess_time_us =
1016            IOTracker::time(&strategy.io_tracker.preprocess_time_us) as u128;
1017        query_stats.cpu_time_us = query_stats.total_execution_time_us
1018            - query_stats.io_time_us
1019            - query_stats.query_pq_preprocess_time_us;
1020        Ok(SearchResultStats {
1021            cmps: query_stats.total_comparisons,
1022            result_count: stats.result_count,
1023            query_statistics: query_stats.clone(),
1024        })
1025    }
1026}
1027
1028/// Helper function to ensure vertices are loaded and processed.
1029///
1030/// This is a convenience function that combines `load_vertices` and `process_loaded_node`
1031/// for each vertex ID. It first loads all the vertices in batch, then processes each
1032/// loaded node.
1033fn ensure_vertex_loaded<Data: GraphDataType, V: VertexProvider<Data>>(
1034    vertex_provider: &mut V,
1035    ids: &[Data::VectorIdType],
1036) -> ANNResult<()> {
1037    vertex_provider.load_vertices(ids)?;
1038    for (idx, id) in ids.iter().enumerate() {
1039        vertex_provider.process_loaded_node(id, idx)?;
1040    }
1041    Ok(())
1042}
1043
1044#[cfg(test)]
1045mod disk_provider_tests {
1046    use diskann::{
1047        graph::{
1048            search::{record::VisitedSearchRecord, Knn},
1049            KnnSearchError,
1050        },
1051        utils::IntoUsize,
1052        ANNErrorKind,
1053    };
1054    use diskann_providers::storage::{
1055        DynWriteProvider, StorageReadProvider, VirtualStorageProvider,
1056    };
1057    use diskann_providers::{
1058        common::AlignedBoxWithSlice,
1059        test_utils::graph_data_type_utils::{
1060            GraphDataF32VectorU32Data, GraphDataF32VectorUnitData,
1061        },
1062        utils::{
1063            create_thread_pool, file_util, load_aligned_bin, PQPathNames, ParallelIteratorInPool,
1064        },
1065    };
1066    use diskann_utils::test_data_root;
1067    use diskann_vector::distance::Metric;
1068    use rayon::prelude::{IndexedParallelIterator, IntoParallelRefIterator};
1069    use rstest::rstest;
1070    use vfs::OverlayFS;
1071
1072    use super::*;
1073    use crate::{
1074        build::builder::core::disk_index_builder_tests::{IndexBuildFixture, TestParams},
1075        utils::{QueryStatistics, VirtualAlignedReaderFactory},
1076    };
1077
1078    const TEST_INDEX_PREFIX_128DIM: &str =
1079        "/disk_index_search/disk_index_sift_learn_R4_L50_A1.2_truth_search";
1080    const TEST_INDEX_128DIM: &str =
1081        "/disk_index_search/disk_index_sift_learn_R4_L50_A1.2_truth_search_disk.index";
1082    const TEST_PQ_PIVOT_128DIM: &str =
1083        "/disk_index_search/disk_index_sift_learn_R4_L50_A1.2_truth_search_pq_pivots.bin";
1084    const TEST_PQ_COMPRESSED_128DIM: &str =
1085        "/disk_index_search/disk_index_sift_learn_R4_L50_A1.2_truth_search_pq_compressed.bin";
1086    const TEST_TRUTH_RESULT_10PTS_128DIM: &str =
1087        "/disk_index_search/disk_index_10pts_idx_uint32_truth_search_res.bin";
1088    const TEST_QUERY_10PTS_128DIM: &str = "/disk_index_search/disk_index_sample_query_10pts.fbin";
1089
1090    const TEST_INDEX_PREFIX_100DIM: &str = "/disk_index_search/256pts_100dim_f32_truth_Index";
1091    const TEST_INDEX_100DIM: &str = "/disk_index_search/256pts_100dim_f32_truth_Index_disk.index";
1092    const TEST_PQ_PIVOT_100DIM: &str =
1093        "/disk_index_search/256pts_100dim_f32_truth_Index_pq_pivots.bin";
1094    const TEST_PQ_COMPRESSED_100DIM: &str =
1095        "/disk_index_search/256pts_100dim_f32_truth_Index_pq_compressed.bin";
1096    const TEST_TRUTH_RESULT_10PTS_100DIM: &str =
1097        "/disk_index_search/256pts_100dim_f32_truth_query_result.bin";
1098    const TEST_QUERY_10PTS_100DIM: &str = "/disk_index_search/10pts_100dim_f32_base_query.bin";
1099    const TEST_DATA_FILE: &str = "/disk_index_search/disk_index_siftsmall_learn_256pts_data.fbin";
1100    const TEST_INDEX: &str =
1101        "/disk_index_search/disk_index_sift_learn_R4_L50_A1.2_truth_search_disk.index";
1102    const TEST_INDEX_PREFIX: &str =
1103        "/disk_index_search/disk_index_sift_learn_R4_L50_A1.2_truth_search";
1104    const TEST_PQ_PIVOT: &str =
1105        "/disk_index_search/disk_index_sift_learn_R4_L50_A1.2_truth_search_pq_pivots.bin";
1106    const TEST_PQ_COMPRESSED: &str =
1107        "/disk_index_search/disk_index_sift_learn_R4_L50_A1.2_truth_search_pq_compressed.bin";
1108
1109    #[test]
1110    fn test_disk_search_k10_l20_single_or_multi_thread_100dim() {
1111        let storage_provider = Arc::new(VirtualStorageProvider::new_overlay(test_data_root()));
1112
1113        let search_engine = create_disk_index_searcher(
1114            CreateDiskIndexSearcherParams {
1115                max_thread_num: 5,
1116                pq_pivot_file_path: TEST_PQ_PIVOT_100DIM,
1117                pq_compressed_file_path: TEST_PQ_COMPRESSED_100DIM,
1118                index_path: TEST_INDEX_100DIM,
1119                index_path_prefix: TEST_INDEX_PREFIX_100DIM,
1120                ..Default::default()
1121            },
1122            &storage_provider,
1123        );
1124        // Test single thread.
1125        test_disk_search(TestDiskSearchParams {
1126            storage_provider: storage_provider.as_ref(),
1127            index_search_engine: &search_engine,
1128            thread_num: 1,
1129            query_file_path: TEST_QUERY_10PTS_100DIM,
1130            truth_result_file_path: TEST_TRUTH_RESULT_10PTS_100DIM,
1131            k: 10,
1132            l: 20,
1133            dim: 104,
1134        });
1135        // Test multi thread.
1136        test_disk_search(TestDiskSearchParams {
1137            storage_provider: storage_provider.as_ref(),
1138            index_search_engine: &search_engine,
1139            thread_num: 5,
1140            query_file_path: TEST_QUERY_10PTS_100DIM,
1141            truth_result_file_path: TEST_TRUTH_RESULT_10PTS_100DIM,
1142            k: 10,
1143            l: 20,
1144            dim: 104,
1145        });
1146    }
1147
1148    #[test]
1149    fn test_disk_search_k10_l20_single_or_multi_thread_128dim() {
1150        let storage_provider = Arc::new(VirtualStorageProvider::new_overlay(test_data_root()));
1151
1152        let search_engine = create_disk_index_searcher::<GraphDataF32VectorUnitData>(
1153            CreateDiskIndexSearcherParams {
1154                max_thread_num: 5,
1155                pq_pivot_file_path: TEST_PQ_PIVOT_128DIM,
1156                pq_compressed_file_path: TEST_PQ_COMPRESSED_128DIM,
1157                index_path: TEST_INDEX_128DIM,
1158                index_path_prefix: TEST_INDEX_PREFIX_128DIM,
1159                ..Default::default()
1160            },
1161            &storage_provider,
1162        );
1163        // Test single thread.
1164        test_disk_search(TestDiskSearchParams {
1165            storage_provider: storage_provider.as_ref(),
1166            index_search_engine: &search_engine,
1167            thread_num: 1,
1168            query_file_path: TEST_QUERY_10PTS_128DIM,
1169            truth_result_file_path: TEST_TRUTH_RESULT_10PTS_128DIM,
1170            k: 10,
1171            l: 20,
1172            dim: 128,
1173        });
1174        // Test multi thread.
1175        test_disk_search(TestDiskSearchParams {
1176            storage_provider: storage_provider.as_ref(),
1177            index_search_engine: &search_engine,
1178            thread_num: 5,
1179            query_file_path: TEST_QUERY_10PTS_128DIM,
1180            truth_result_file_path: TEST_TRUTH_RESULT_10PTS_128DIM,
1181            k: 10,
1182            l: 20,
1183            dim: 128,
1184        });
1185    }
1186
1187    fn get_truth_associated_data<StorageReader: StorageReadProvider>(
1188        storage_provider: &StorageReader,
1189    ) -> Vec<u32> {
1190        const ASSOCIATED_DATA_FILE: &str = "/sift/siftsmall_learn_256pts_u32_associated_data.fbin";
1191
1192        let (data, _npts, _dim) =
1193            file_util::load_bin::<u32, StorageReader>(storage_provider, ASSOCIATED_DATA_FILE, 0)
1194                .unwrap();
1195        data
1196    }
1197
1198    #[test]
1199    fn test_disk_search_with_associated_data_k10_l20_single_or_multi_thread_128dim() {
1200        let storage_provider = VirtualStorageProvider::new_overlay(test_data_root());
1201        let index_path_prefix = "/disk_index_search/disk_index_sift_learn_R4_L50_A1.2_test_disk_index_search_associated_data";
1202        let params = TestParams {
1203            data_path: TEST_DATA_FILE.to_string(),
1204            index_path_prefix: index_path_prefix.to_string(),
1205            associated_data_path: Some(
1206                "/sift/siftsmall_learn_256pts_u32_associated_data.fbin".to_string(),
1207            ),
1208            ..TestParams::default()
1209        };
1210        let fixture = IndexBuildFixture::new(storage_provider, params).unwrap();
1211        // Build the index with the associated data
1212        fixture.build::<GraphDataF32VectorU32Data>().unwrap();
1213        {
1214            let search_engine = create_disk_index_searcher::<GraphDataF32VectorU32Data>(
1215                CreateDiskIndexSearcherParams {
1216                    max_thread_num: 5,
1217                    pq_pivot_file_path: format!("{}_pq_pivots.bin", index_path_prefix).as_str(),
1218                    pq_compressed_file_path: format!("{}_pq_compressed.bin", index_path_prefix)
1219                        .as_str(),
1220                    index_path: format!("{}_disk.index", index_path_prefix).as_str(), //TEST_INDEX_128DIM,
1221                    index_path_prefix,
1222                    ..Default::default()
1223                },
1224                &fixture.storage_provider,
1225            );
1226
1227            // Test single thread.
1228            test_disk_search_with_associated(
1229                TestDiskSearchAssociateParams {
1230                    storage_provider: fixture.storage_provider.as_ref(),
1231                    index_search_engine: &search_engine,
1232                    thread_num: 1,
1233                    query_file_path: TEST_QUERY_10PTS_128DIM,
1234                    truth_result_file_path: TEST_TRUTH_RESULT_10PTS_128DIM,
1235                    k: 10,
1236                    l: 20,
1237                    dim: 128,
1238                },
1239                None,
1240            );
1241
1242            // Test multi thread.
1243            test_disk_search_with_associated(
1244                TestDiskSearchAssociateParams {
1245                    storage_provider: fixture.storage_provider.as_ref(),
1246                    index_search_engine: &search_engine,
1247                    thread_num: 5,
1248                    query_file_path: TEST_QUERY_10PTS_128DIM,
1249                    truth_result_file_path: TEST_TRUTH_RESULT_10PTS_128DIM,
1250                    k: 10,
1251                    l: 20,
1252                    dim: 128,
1253                },
1254                None,
1255            );
1256        }
1257
1258        fixture
1259            .storage_provider
1260            .delete(&format!("{}_disk.index", index_path_prefix))
1261            .expect("Failed to delete file");
1262        fixture
1263            .storage_provider
1264            .delete(&format!("{}_pq_pivots.bin", index_path_prefix))
1265            .expect("Failed to delete file");
1266        fixture
1267            .storage_provider
1268            .delete(&format!("{}_pq_compressed.bin", index_path_prefix))
1269            .expect("Failed to delete file");
1270    }
1271
1272    struct CreateDiskIndexSearcherParams<'a> {
1273        max_thread_num: usize,
1274        pq_pivot_file_path: &'a str,
1275        pq_compressed_file_path: &'a str,
1276        index_path: &'a str,
1277        index_path_prefix: &'a str,
1278        io_limit: usize,
1279    }
1280
1281    impl Default for CreateDiskIndexSearcherParams<'_> {
1282        fn default() -> Self {
1283            Self {
1284                max_thread_num: 1,
1285                pq_pivot_file_path: "",
1286                pq_compressed_file_path: "",
1287                index_path: "",
1288                index_path_prefix: "",
1289                io_limit: usize::MAX,
1290            }
1291        }
1292    }
1293
1294    fn create_disk_index_searcher<Data>(
1295        params: CreateDiskIndexSearcherParams,
1296        storage_provider: &Arc<VirtualStorageProvider<OverlayFS>>,
1297    ) -> DiskIndexSearcher<
1298        Data,
1299        DiskVertexProviderFactory<Data, VirtualAlignedReaderFactory<OverlayFS>>,
1300    >
1301    where
1302        Data: GraphDataType<VectorIdType = u32>,
1303    {
1304        assert!(params.io_limit > 0);
1305
1306        let runtime = tokio::runtime::Builder::new_multi_thread()
1307            .worker_threads(params.max_thread_num)
1308            .build()
1309            .unwrap();
1310
1311        let disk_index_reader = DiskIndexReader::<Data::VectorDataType>::new(
1312            params.pq_pivot_file_path.to_string(),
1313            params.pq_compressed_file_path.to_string(),
1314            storage_provider.as_ref(),
1315        )
1316        .unwrap();
1317
1318        let aligned_reader_factory = VirtualAlignedReaderFactory::new(
1319            get_disk_index_file(params.index_path_prefix),
1320            Arc::clone(storage_provider),
1321        );
1322        let caching_strategy = CachingStrategy::None;
1323        let vertex_provider_factory =
1324            DiskVertexProviderFactory::<Data, _>::new(aligned_reader_factory, caching_strategy)
1325                .unwrap();
1326
1327        DiskIndexSearcher::<Data, DiskVertexProviderFactory<Data, _>>::new(
1328            params.max_thread_num,
1329            params.io_limit,
1330            &disk_index_reader,
1331            vertex_provider_factory,
1332            Metric::L2,
1333            Some(runtime),
1334        )
1335        .unwrap()
1336    }
1337
1338    fn load_query_result<StorageReader: StorageReadProvider>(
1339        storage_provider: &StorageReader,
1340        query_result_path: &str,
1341    ) -> Vec<u32> {
1342        let (result, _, _) =
1343            file_util::load_bin::<u32, StorageReader>(storage_provider, query_result_path, 0)
1344                .unwrap();
1345        result
1346    }
1347
1348    struct TestDiskSearchParams<'a, StorageType> {
1349        storage_provider: &'a StorageType,
1350        index_search_engine: &'a DiskIndexSearcher<
1351            GraphDataF32VectorUnitData,
1352            DiskVertexProviderFactory<
1353                GraphDataF32VectorUnitData,
1354                VirtualAlignedReaderFactory<OverlayFS>,
1355            >,
1356        >,
1357        thread_num: u64,
1358        query_file_path: &'a str,
1359        truth_result_file_path: &'a str,
1360        k: usize,
1361        l: usize,
1362        dim: usize,
1363    }
1364
1365    struct TestDiskSearchAssociateParams<'a, StorageType> {
1366        storage_provider: &'a StorageType,
1367        index_search_engine: &'a DiskIndexSearcher<
1368            GraphDataF32VectorU32Data,
1369            DiskVertexProviderFactory<
1370                GraphDataF32VectorU32Data,
1371                VirtualAlignedReaderFactory<OverlayFS>,
1372            >,
1373        >,
1374        thread_num: u64,
1375        query_file_path: &'a str,
1376        truth_result_file_path: &'a str,
1377        k: usize,
1378        l: usize,
1379        dim: usize,
1380    }
1381
1382    fn test_disk_search<StorageType: StorageReadProvider>(
1383        params: TestDiskSearchParams<StorageType>,
1384    ) {
1385        let query_vector = load_aligned_bin(params.storage_provider, params.query_file_path)
1386            .unwrap()
1387            .0;
1388        let mut aligned_query = AlignedBoxWithSlice::<f32>::new(query_vector.len(), 32).unwrap();
1389        aligned_query.memcpy(query_vector.as_slice()).unwrap();
1390
1391        let queries = aligned_query
1392            .split_into_nonoverlapping_mut_slices(0..aligned_query.len(), params.dim)
1393            .unwrap();
1394
1395        let truth_result =
1396            load_query_result(params.storage_provider, params.truth_result_file_path);
1397
1398        let pool = create_thread_pool(params.thread_num.into_usize()).unwrap();
1399        // Convert query_vector to number of Vertex with data type f32 and dimension equals to dim.
1400        queries
1401            .par_iter()
1402            .enumerate()
1403            .for_each_in_pool(&pool, |(i, query)| {
1404                // Test search_with_associated_data with an unaligned query. Some distance functions require aligned data.
1405                let mut aligned_box = AlignedBoxWithSlice::<f32>::new(query.len() + 1, 32).unwrap();
1406                let mut temp = Vec::with_capacity(query.len() + 1);
1407                temp.push(0.0);
1408                temp.extend_from_slice(query);
1409                aligned_box.memcpy(temp.as_slice()).unwrap();
1410                let query = &aligned_box.as_slice()[1..];
1411
1412                let mut query_stats = QueryStatistics::default();
1413                let mut indices = vec![0u32; 10];
1414                let mut distances = vec![0f32; 10];
1415                let mut associated_data = vec![(); 10];
1416
1417                let result = params
1418                    .index_search_engine
1419                    //.search_with_associated_data(query, params.k as u32, params.l as u32)
1420                    .search_internal(
1421                        query,
1422                        params.k,
1423                        params.l as u32,
1424                        None, // beam_width
1425                        &mut query_stats,
1426                        &mut indices,
1427                        &mut distances,
1428                        &mut associated_data,
1429                        &(|_| true),
1430                        false,
1431                    );
1432
1433                // Calculate the range of the truth_result for this query
1434                let truth_slice = &truth_result[i * params.k..(i + 1) * params.k];
1435
1436                assert!(result.is_ok(), "Expected search to succeed");
1437
1438                let result_unwrapped = result.unwrap();
1439                assert!(
1440                    result_unwrapped.query_statistics.total_io_operations > 0,
1441                    "Expected IO operations to be greater than 0"
1442                );
1443                assert!(
1444                    result_unwrapped.query_statistics.total_vertices_loaded > 0,
1445                    "Expected vertices loaded to be greater than 0"
1446                );
1447
1448                // Compare res with truth_slice using assert_eq!
1449                assert_eq!(
1450                    indices, truth_slice,
1451                    "Results DO NOT match with the truth result for query {}",
1452                    i
1453                );
1454            });
1455    }
1456
1457    fn test_disk_search_with_associated<StorageType: StorageReadProvider>(
1458        params: TestDiskSearchAssociateParams<StorageType>,
1459        beam_width: Option<usize>,
1460    ) {
1461        let query_vector = load_aligned_bin(params.storage_provider, params.query_file_path)
1462            .unwrap()
1463            .0;
1464        let mut aligned_query = AlignedBoxWithSlice::<f32>::new(query_vector.len(), 32).unwrap();
1465        aligned_query.memcpy(query_vector.as_slice()).unwrap();
1466        let queries = aligned_query
1467            .split_into_nonoverlapping_mut_slices(0..aligned_query.len(), params.dim)
1468            .unwrap();
1469        let truth_result =
1470            load_query_result(params.storage_provider, params.truth_result_file_path);
1471        let pool = create_thread_pool(params.thread_num.into_usize()).unwrap();
1472        // Convert query_vector to number of Vertex with data type f32 and dimension equals to dim.
1473        queries
1474            .par_iter()
1475            .enumerate()
1476            .for_each_in_pool(&pool, |(i, query)| {
1477                // Test search_with_associated_data with an unaligned query. Some distance functions require aligned data.
1478                let mut aligned_box = AlignedBoxWithSlice::<f32>::new(query.len() + 1, 32).unwrap();
1479                let mut temp = Vec::with_capacity(query.len() + 1);
1480                temp.push(0.0);
1481                temp.extend_from_slice(query);
1482                aligned_box.memcpy(temp.as_slice()).unwrap();
1483                let query = &aligned_box.as_slice()[1..];
1484                let result = params
1485                    .index_search_engine
1486                    .search(query, params.k as u32, params.l as u32, beam_width, None, false)
1487                    .unwrap();
1488                let indices: Vec<u32> = result.results.iter().map(|item| item.vertex_id).collect();
1489                let associated_data: Vec<u32> =
1490                    result.results.iter().map(|item| item.data).collect();
1491                let truth_data = get_truth_associated_data(params.storage_provider);
1492                let associated_data_truth: Vec<u32> = indices
1493                    .iter()
1494                    .map(|&vid| truth_data[vid as usize])
1495                    .collect();
1496                assert_eq!(
1497                    associated_data, associated_data_truth,
1498                    "Associated data DO NOT match with the truth result for query {}, associated_data from search: {:?}, associated_data from truth result: {:?}",
1499                    i,associated_data, associated_data_truth
1500                );
1501                let truth_slice = &truth_result[i * params.k..(i + 1) * params.k];
1502                assert_eq!(
1503                    indices, truth_slice,
1504                    "Results DO NOT match with the truth result for query {}",
1505                    i
1506                );
1507            });
1508    }
1509
1510    #[test]
1511    fn test_disk_search_invalid_input() {
1512        let storage_provider = Arc::new(VirtualStorageProvider::new_overlay(test_data_root()));
1513        let ctx = &DefaultContext;
1514
1515        let params = CreateDiskIndexSearcherParams {
1516            max_thread_num: 5,
1517            pq_pivot_file_path: TEST_PQ_PIVOT_128DIM,
1518            pq_compressed_file_path: TEST_PQ_COMPRESSED_128DIM,
1519            index_path: TEST_INDEX_128DIM,
1520            index_path_prefix: TEST_INDEX_PREFIX_128DIM,
1521            ..Default::default()
1522        };
1523
1524        let paths = PQPathNames::for_disk_index(TEST_INDEX_PREFIX_128DIM);
1525        assert_eq!(
1526            paths.pivots, params.pq_pivot_file_path,
1527            "pq_pivot_file_path is not correct"
1528        );
1529        assert_eq!(
1530            paths.compressed_data, params.pq_compressed_file_path,
1531            "pq_compressed_file_path is not correct"
1532        );
1533        assert_eq!(
1534            params.index_path,
1535            format!("{}_disk.index", params.index_path_prefix),
1536            "index_path is not correct"
1537        );
1538
1539        // Test error case: l < k
1540        let res = Knn::new_default(20, 10);
1541        assert!(res.is_err());
1542        assert_eq!(
1543            <KnnSearchError as std::convert::Into<ANNError>>::into(res.unwrap_err()).kind(),
1544            ANNErrorKind::IndexError
1545        );
1546        // Test error case: beam_width = 0
1547        let res = Knn::new(10, 10, Some(0));
1548        assert!(res.is_err());
1549
1550        let search_engine =
1551            create_disk_index_searcher::<GraphDataF32VectorU32Data>(params, &storage_provider);
1552
1553        // minor validation tests to improve code coverage
1554        assert_eq!(
1555            search_engine
1556                .index
1557                .data_provider
1558                .to_external_id(ctx, 0)
1559                .unwrap(),
1560            0
1561        );
1562        assert_eq!(
1563            search_engine
1564                .index
1565                .data_provider
1566                .to_internal_id(ctx, &0)
1567                .unwrap(),
1568            0
1569        );
1570
1571        let provider_max_degree = search_engine
1572            .index
1573            .data_provider
1574            .graph_header
1575            .max_degree::<<GraphDataF32VectorU32Data as GraphDataType>::VectorDataType>()
1576            .unwrap();
1577        let index_max_degree = search_engine.index.config.pruned_degree().get();
1578        assert_eq!(provider_max_degree, index_max_degree);
1579
1580        let query = vec![0f32; 128];
1581        let mut query_stats = QueryStatistics::default();
1582        let mut indices = vec![0u32; 10];
1583        let mut distances = vec![0f32; 10];
1584        let mut associated_data = vec![0u32; 10];
1585
1586        // Set L: {} to a value of at least K:
1587        let result = search_engine.search_internal(
1588            &query,
1589            10,
1590            10 - 1,
1591            None,
1592            &mut query_stats,
1593            &mut indices,
1594            &mut distances,
1595            &mut associated_data,
1596            &|_| true,
1597            false,
1598        );
1599
1600        assert!(result.is_err());
1601        assert_eq!(result.unwrap_err().kind(), ANNErrorKind::IndexError);
1602    }
1603
1604    #[test]
1605    fn test_disk_search_beam_search() {
1606        let storage_provider = Arc::new(VirtualStorageProvider::new_overlay(test_data_root()));
1607
1608        let search_engine = create_disk_index_searcher::<GraphDataF32VectorUnitData>(
1609            CreateDiskIndexSearcherParams {
1610                max_thread_num: 1,
1611                pq_pivot_file_path: TEST_PQ_PIVOT,
1612                pq_compressed_file_path: TEST_PQ_COMPRESSED,
1613                index_path: TEST_INDEX,
1614                index_path_prefix: TEST_INDEX_PREFIX,
1615                ..Default::default()
1616            },
1617            &storage_provider,
1618        );
1619
1620        let query_vector: [f32; 128] = [1f32; 128];
1621        let mut indices = vec![0u32; 10];
1622        let mut distances = vec![0f32; 10];
1623        let mut associated_data = vec![(); 10];
1624
1625        let mut result_output_buffer = search_output_buffer::IdDistanceAssociatedData::new(
1626            &mut indices,
1627            &mut distances,
1628            &mut associated_data,
1629        );
1630        let strategy = search_engine.search_strategy(&query_vector, &|_| true);
1631        let mut search_record = VisitedSearchRecord::new(0);
1632        let search_params = Knn::new(10, 10, Some(4)).unwrap();
1633        let recorded_search =
1634            diskann::graph::search::RecordedKnn::new(search_params, &mut search_record);
1635        search_engine
1636            .runtime
1637            .block_on(search_engine.index.search(
1638                recorded_search,
1639                &strategy,
1640                &DefaultContext,
1641                query_vector.as_slice(),
1642                &mut result_output_buffer,
1643            ))
1644            .unwrap();
1645
1646        let ids = search_record
1647            .visited
1648            .iter()
1649            .map(|n| n.id)
1650            .collect::<Vec<_>>();
1651
1652        const EXPECTED_NODES: [u32; 18] = [
1653            72, 118, 108, 86, 84, 152, 170, 82, 114, 87, 207, 176, 79, 153, 67, 165, 141, 180,
1654        ]; //Expected nodes for query = [1f32; 128] with beam_width=4
1655
1656        assert_eq!(ids, &EXPECTED_NODES);
1657
1658        let return_list_size = 10;
1659        let search_list_size = 10;
1660        let result = search_engine.search(
1661            &query_vector,
1662            return_list_size,
1663            search_list_size,
1664            Some(4),
1665            None,
1666            false,
1667        );
1668        assert!(result.is_ok(), "Expected search to succeed");
1669        let search_result = result.unwrap();
1670        assert_eq!(
1671            search_result.results.len() as u32,
1672            return_list_size,
1673            "Expected result count to match"
1674        );
1675        assert_eq!(
1676            indices,
1677            vec![152, 72, 170, 118, 87, 165, 79, 141, 108, 86],
1678            "Expected indices to match"
1679        );
1680    }
1681
1682    #[cfg(feature = "experimental_diversity_search")]
1683    #[test]
1684    fn test_disk_search_diversity_search() {
1685        use diskann::graph::DiverseSearchParams;
1686        use diskann::neighbor::AttributeValueProvider;
1687        use std::collections::HashMap;
1688
1689        // Simple test attribute provider
1690        #[derive(Debug, Clone)]
1691        struct TestAttributeProvider {
1692            attributes: HashMap<u32, u32>,
1693        }
1694        impl TestAttributeProvider {
1695            fn new() -> Self {
1696                Self {
1697                    attributes: HashMap::new(),
1698                }
1699            }
1700            fn insert(&mut self, id: u32, attribute: u32) {
1701                self.attributes.insert(id, attribute);
1702            }
1703        }
1704        impl diskann::provider::HasId for TestAttributeProvider {
1705            type Id = u32;
1706        }
1707
1708        impl AttributeValueProvider for TestAttributeProvider {
1709            type Value = u32;
1710
1711            fn get(&self, id: Self::Id) -> Option<Self::Value> {
1712                self.attributes.get(&id).copied()
1713            }
1714        }
1715
1716        let storage_provider = Arc::new(VirtualStorageProvider::new_overlay(test_data_root()));
1717
1718        let search_engine = create_disk_index_searcher::<GraphDataF32VectorUnitData>(
1719            CreateDiskIndexSearcherParams {
1720                max_thread_num: 1,
1721                pq_pivot_file_path: TEST_PQ_PIVOT,
1722                pq_compressed_file_path: TEST_PQ_COMPRESSED,
1723                index_path: TEST_INDEX,
1724                index_path_prefix: TEST_INDEX_PREFIX,
1725                ..Default::default()
1726            },
1727            &storage_provider,
1728        );
1729
1730        let query_vector: [f32; 128] = [1f32; 128];
1731
1732        // Create attribute provider with random labels (1 to 3) for all vectors
1733        let mut attribute_provider = TestAttributeProvider::new();
1734        let num_vectors = 256; // Number of vectors in the test dataset
1735        for i in 0..num_vectors {
1736            // Assign labels 1-3 based on modulo to ensure distribution
1737            let label = (i % 15) + 1;
1738            attribute_provider.insert(i, label);
1739        }
1740        // Wrap in Arc once to avoid cloning the HashMap later
1741        let attribute_provider = std::sync::Arc::new(attribute_provider);
1742
1743        let mut indices = vec![0u32; 10];
1744        let mut distances = vec![0f32; 10];
1745        let mut associated_data = vec![(); 10];
1746
1747        let mut result_output_buffer = search_output_buffer::IdDistanceAssociatedData::new(
1748            &mut indices,
1749            &mut distances,
1750            &mut associated_data,
1751        );
1752        let strategy = search_engine.search_strategy(&query_vector, &|_| true);
1753
1754        // Create diverse search parameters with attribute provider
1755        let diverse_params = DiverseSearchParams::new(
1756            0, // diverse_attribute_id
1757            3, // diverse_results_k
1758            attribute_provider.clone(),
1759        );
1760
1761        let search_params = Knn::new(10, 20, None).unwrap();
1762
1763        let diverse_search = diskann::graph::search::Diverse::new(search_params, diverse_params);
1764        let stats = search_engine
1765            .runtime
1766            .block_on(search_engine.index.search(
1767                diverse_search,
1768                &strategy,
1769                &DefaultContext,
1770                query_vector.as_slice(),
1771                &mut result_output_buffer,
1772            ))
1773            .unwrap();
1774
1775        // Verify that search was performed and returned some results
1776        assert!(
1777            stats.result_count > 0,
1778            "Expected to get some results during diversity search"
1779        );
1780
1781        let return_list_size = 10;
1782        let search_list_size = 20;
1783        let diverse_results_k = 1;
1784        let diverse_params = DiverseSearchParams::new(
1785            0, // diverse_attribute_id
1786            diverse_results_k,
1787            attribute_provider.clone(),
1788        );
1789
1790        // Test diverse search using the search API
1791        let mut indices2 = vec![0u32; return_list_size as usize];
1792        let mut distances2 = vec![0f32; return_list_size as usize];
1793        let mut associated_data2 = vec![(); return_list_size as usize];
1794        let mut result_output_buffer2 = search_output_buffer::IdDistanceAssociatedData::new(
1795            &mut indices2,
1796            &mut distances2,
1797            &mut associated_data2,
1798        );
1799        let strategy2 = search_engine.search_strategy(&query_vector, &|_| true);
1800        let search_params2 =
1801            Knn::new(return_list_size as usize, search_list_size as usize, None).unwrap();
1802
1803        let diverse_search2 = diskann::graph::search::Diverse::new(search_params2, diverse_params);
1804        let stats = search_engine
1805            .runtime
1806            .block_on(search_engine.index.search(
1807                diverse_search2,
1808                &strategy2,
1809                &DefaultContext,
1810                query_vector.as_slice(),
1811                &mut result_output_buffer2,
1812            ))
1813            .unwrap();
1814
1815        // Verify results
1816        assert!(
1817            stats.result_count > 0,
1818            "Expected diversity search to return results"
1819        );
1820        assert!(
1821            stats.result_count <= return_list_size,
1822            "Expected result count to be <= {}",
1823            return_list_size
1824        );
1825
1826        // Verify that we got some results
1827        assert!(
1828            stats.result_count > 0,
1829            "Expected to get some search results"
1830        );
1831
1832        // Print search results with their attributes
1833        println!("\n=== Diversity Search Results ===");
1834        println!("Query: [1f32; 128]");
1835        println!("diverse_results_k: {}", diverse_results_k);
1836        println!("Total results: {}\n", stats.result_count);
1837        println!("{:<10} {:<15} {:<10}", "Vertex ID", "Distance", "Label");
1838        println!("{}", "-".repeat(35));
1839        for i in 0..stats.result_count as usize {
1840            let attribute_value = attribute_provider.get(indices2[i]).unwrap_or(0);
1841            println!(
1842                "{:<10} {:<15.2} {:<10}",
1843                indices2[i], distances2[i], attribute_value
1844            );
1845        }
1846
1847        // Verify that distances are non-negative and sorted
1848        for i in 0..(stats.result_count as usize).saturating_sub(1) {
1849            assert!(distances2[i] >= 0.0, "Expected non-negative distance");
1850            assert!(
1851                distances2[i] <= distances2[i + 1],
1852                "Expected distances to be sorted in ascending order"
1853            );
1854        }
1855
1856        // Verify diversity: Check that we have diverse attribute values in the results
1857        let mut attribute_counts = HashMap::new();
1858        for item in indices2.iter().take(stats.result_count as usize) {
1859            if let Some(attribute_value) = attribute_provider.get(*item) {
1860                *attribute_counts.entry(attribute_value).or_insert(0) += 1;
1861            }
1862        }
1863
1864        // Print attribute distribution
1865        println!("\n=== Attribute Distribution ===");
1866        let mut sorted_attrs: Vec<_> = attribute_counts.iter().collect();
1867        sorted_attrs.sort_by_key(|(k, _)| *k);
1868        for (attribute_value, count) in &sorted_attrs {
1869            println!(
1870                "Label {}: {} occurrences (max allowed: {})",
1871                attribute_value, count, diverse_results_k
1872            );
1873        }
1874        println!("Total unique labels: {}", attribute_counts.len());
1875        println!("================================\n");
1876
1877        // With diverse_results_k = 5, we expect at most 5 results per attribute value
1878        for (attribute_value, count) in &attribute_counts {
1879            println!(
1880                "Assert: Label {} has {} occurrences (max: {})",
1881                attribute_value, count, diverse_results_k
1882            );
1883            assert!(
1884                *count <= diverse_results_k,
1885                "Attribute value {} appears {} times, which exceeds diverse_results_k of {}",
1886                attribute_value,
1887                count,
1888                diverse_results_k
1889            );
1890        }
1891
1892        // Verify that we have multiple different attribute values (diversity)
1893        // With 3 possible labels and diverse_results_k=5, we should see at least 2 different labels
1894        println!(
1895            "Assert: Found {} unique labels (expected at least 2)",
1896            attribute_counts.len()
1897        );
1898        assert!(
1899            attribute_counts.len() >= 2,
1900            "Expected at least 2 different attribute values for diversity, got {}",
1901            attribute_counts.len()
1902        );
1903    }
1904
1905    #[rstest]
1906    // This case checks expected behavior of unfiltered search.
1907    #[case(
1908        |_id: &u32| true,
1909        false,
1910        10,
1911        vec![152, 118, 72, 170, 87, 141, 79, 207, 124, 86],
1912        vec![256101.7, 256675.3, 256709.69, 256712.5, 256760.08, 256958.5, 257006.1, 257025.7, 257105.67, 257107.67],
1913    )]
1914    // This case validates post-filtering using 2 ids which are not present in the unfiltered result set.
1915    // It is expected that the post-filtering will return an empty result
1916    #[case(
1917        |id: &u32| *id == 0 || *id == 1,
1918        false,
1919        0,
1920        vec![0; 10],
1921        vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
1922    )]
1923    // This case validates pre-filtering using 2 ids which are not present in the unfiltered result set.
1924    // It is expected that the pre-filtering will do search over matching ids
1925    #[case(
1926        |id: &u32| *id == 0 || *id == 1,
1927        true,
1928        2,
1929        vec![1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
1930        vec![257247.28, 258179.28, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
1931    )]
1932    // This case validates post-filtering using 3 ids from the unfiltered result set.
1933    // It is expected that the post-filtering will filter out non-matching ids
1934    #[case(
1935        |id: &u32| *id == 72 || *id == 87 || *id == 170,
1936        false,
1937        3,
1938        vec![72, 170, 87, 0, 0, 0, 0, 0, 0, 0],
1939        vec![256709.69, 256712.5, 256760.08, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
1940    )]
1941    // This case validates pre-filtering using 3 ids from the unfiltered result set.
1942    // It is expected that the pre-filtering will do search over matching ids
1943    #[case(
1944        |id: &u32| *id == 72 || *id == 87 || *id == 170,
1945        true,
1946        3,
1947        vec![72, 170, 87, 0, 0, 0, 0, 0, 0, 0],
1948        vec![256709.69, 256712.5, 256760.08, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
1949    )]
1950    fn test_search_with_vector_filter(
1951        #[case] vector_filter: fn(&u32) -> bool,
1952        #[case] is_flat_search: bool,
1953        #[case] expected_result_count: u32,
1954        #[case] expected_indices: Vec<u32>,
1955        #[case] expected_distances: Vec<f32>,
1956    ) {
1957        // Exact distances can vary slightly depending on the architecture used
1958        // to compute distances due to different unrolling strategies and SIMD widthd.
1959        //
1960        // This parameter allows for a small margin when matching distances.
1961        let check_distances = |got: &[f32], expected: &[f32]| -> bool {
1962            const ABS_TOLERANCE: f32 = 0.02;
1963            assert_eq!(got.len(), expected.len());
1964            for (i, (g, e)) in std::iter::zip(got.iter(), expected.iter()).enumerate() {
1965                if (g - e).abs() > ABS_TOLERANCE {
1966                    panic!(
1967                        "distances differ at position {} by more than {}\n\n\
1968                         got: {:?}\nexpected: {:?}",
1969                        i, ABS_TOLERANCE, got, expected,
1970                    );
1971                }
1972            }
1973            true
1974        };
1975
1976        let storage_provider = Arc::new(VirtualStorageProvider::new_overlay(test_data_root()));
1977
1978        let search_engine = create_disk_index_searcher::<GraphDataF32VectorUnitData>(
1979            CreateDiskIndexSearcherParams {
1980                max_thread_num: 5,
1981                pq_pivot_file_path: TEST_PQ_PIVOT_128DIM,
1982                pq_compressed_file_path: TEST_PQ_COMPRESSED_128DIM,
1983                index_path: TEST_INDEX_128DIM,
1984                index_path_prefix: TEST_INDEX_PREFIX_128DIM,
1985                ..Default::default()
1986            },
1987            &storage_provider,
1988        );
1989        let query = vec![0.1f32; 128];
1990        let mut query_stats = QueryStatistics::default();
1991        let mut indices = vec![0u32; 10];
1992        let mut distances = vec![0f32; 10];
1993        let mut associated_data = vec![(); 10];
1994
1995        let result = search_engine.search_internal(
1996            &query,
1997            10,
1998            10,
1999            None, // beam_width
2000            &mut query_stats,
2001            &mut indices,
2002            &mut distances,
2003            &mut associated_data,
2004            &vector_filter,
2005            is_flat_search,
2006        );
2007
2008        assert!(result.is_ok(), "Expected search to succeed");
2009        assert_eq!(
2010            result.unwrap().result_count,
2011            expected_result_count,
2012            "Expected result count to match"
2013        );
2014        assert_eq!(indices, expected_indices, "Expected indices to match");
2015        assert!(
2016            check_distances(&distances, &expected_distances),
2017            "Expected distances to match"
2018        );
2019
2020        let result_with_filter = search_engine.search(
2021            &query,
2022            10,
2023            10,
2024            None, // beam_width
2025            Some(Box::new(vector_filter)),
2026            is_flat_search,
2027        );
2028
2029        assert!(result_with_filter.is_ok(), "Expected search to succeed");
2030        let result_with_filter_unwrapped = result_with_filter.unwrap();
2031        assert_eq!(
2032            result_with_filter_unwrapped.stats.result_count, expected_result_count,
2033            "Expected result count to match"
2034        );
2035        let actual_indices = result_with_filter_unwrapped
2036            .results
2037            .iter()
2038            .map(|x| x.vertex_id)
2039            .collect::<Vec<_>>();
2040        assert_eq!(
2041            actual_indices, expected_indices,
2042            "Expected indices to match"
2043        );
2044        let actual_distances = result_with_filter_unwrapped
2045            .results
2046            .iter()
2047            .map(|x| x.distance)
2048            .collect::<Vec<_>>();
2049        assert!(
2050            check_distances(&actual_distances, &expected_distances),
2051            "Expected distances to match"
2052        );
2053    }
2054
2055    #[test]
2056    fn test_beam_search_respects_io_limit() {
2057        let io_limit = 11; // Set a small IO limit for testing
2058        let storage_provider = Arc::new(VirtualStorageProvider::new_overlay(test_data_root()));
2059
2060        let search_engine = create_disk_index_searcher::<GraphDataF32VectorUnitData>(
2061            CreateDiskIndexSearcherParams {
2062                max_thread_num: 1,
2063                pq_pivot_file_path: TEST_PQ_PIVOT,
2064                pq_compressed_file_path: TEST_PQ_COMPRESSED,
2065                index_path: TEST_INDEX,
2066                index_path_prefix: TEST_INDEX_PREFIX,
2067                io_limit,
2068            },
2069            &storage_provider,
2070        );
2071        let query_vector: [f32; 128] = [1f32; 128];
2072
2073        let mut indices = vec![0u32; 10];
2074        let mut distances = vec![0f32; 10];
2075        let mut associated_data = vec![(); 10];
2076
2077        let mut result_output_buffer = search_output_buffer::IdDistanceAssociatedData::new(
2078            &mut indices,
2079            &mut distances,
2080            &mut associated_data,
2081        );
2082
2083        let strategy = search_engine.search_strategy(&query_vector, &|_| true);
2084
2085        let mut search_record = VisitedSearchRecord::new(0);
2086        let search_params = Knn::new(10, 10, Some(4)).unwrap();
2087        let recorded_search =
2088            diskann::graph::search::RecordedKnn::new(search_params, &mut search_record);
2089        search_engine
2090            .runtime
2091            .block_on(search_engine.index.search(
2092                recorded_search,
2093                &strategy,
2094                &DefaultContext,
2095                query_vector.as_slice(),
2096                &mut result_output_buffer,
2097            ))
2098            .unwrap();
2099        let visited_ids = search_record
2100            .visited
2101            .iter()
2102            .map(|n| n.id)
2103            .collect::<Vec<_>>();
2104
2105        let query_stats = strategy.io_tracker;
2106        //Verify the IO limit was respected
2107        assert!(
2108            query_stats.io_count() <= io_limit,
2109            "Expected IO operations to be <= {}, but got {}",
2110            io_limit,
2111            query_stats.io_count()
2112        );
2113
2114        const EXPECTED_NODES: [u32; 17] = [
2115            72, 118, 108, 86, 84, 152, 170, 82, 114, 87, 207, 176, 79, 153, 67, 165, 141,
2116        ]; //Expected nodes for query = [1f32; 128] with beam_width=4
2117
2118        // Count matching results
2119        let mut matching_count = 0;
2120        for expected_node in EXPECTED_NODES.iter() {
2121            if visited_ids.contains(expected_node) {
2122                matching_count += 1;
2123            }
2124        }
2125
2126        // Calculate recall
2127        let recall = (matching_count as f32 / EXPECTED_NODES.len() as f32) * 100.0;
2128
2129        //Verify the recall is above 60%. The threshold her eis arbitrary, just to make sure when
2130        // search hits io_limit that it doesn't break and the recall degrades gracefully
2131        assert!(recall >= 60.0, "Match percentage is below 60%: {}", recall);
2132    }
2133}