Skip to main content

diskann_benchmark_core/search/
api.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use std::{any::Any, future::Future, num::NonZeroUsize, pin::Pin, sync::Arc};
7
8use diskann::{ANNResult, graph, utils::async_tools};
9use diskann_benchmark_runner::utils::MicroSeconds;
10use diskann_utils::{
11    future::{AsyncFriendly, boxit},
12    views::{self, Matrix},
13};
14
15use crate::{
16    internal,
17    search::{
18        ResultIds,
19        ids::{Bounded, IdAggregator, ResultIdsInner},
20    },
21};
22
/// Necessary behavior for Id aggregation. Used by [`Search::Id`].
///
/// This trait has a blanket implementation and thus need not be implemented manually:
/// every type that is `Default + Clone + Send + Sync + 'static` is automatically an [`Id`].
pub trait Id: Default + Clone + Send + Sync + 'static {}

impl<T> Id for T where T: Default + Clone + Send + Sync + 'static {}
29
/// Indicate whether the number of items returned from search is bounded by a fixed amount
/// or can grow to an unknown size.
#[derive(Debug, Clone, Copy)]
pub enum IdCount {
    /// The number of ids returned from search is known to be bounded by the contained value.
    Fixed(NonZeroUsize),

    /// The number of ids returned from search is unknown or unbounded. A size hint can
    /// be provided that can potentially improve performance.
    Dynamic(Option<NonZeroUsize>),
}
41
42/// The core search API for approximate nearest neighbor searches.
43///
44/// This uses a model where queries are stored internally and identified by their
45/// index. Queries are numbered from `0` to `N-1` where `N = Search::num_queries()`
46/// is the total number of queries.
47///
48/// This trait is used in conjunction with [`search`] and [`search_all`]. See the
49/// documentation of those methods for more details.
50pub trait Search: AsyncFriendly {
51    /// The identifier for the type returned by search. These are canonically the
52    /// unique IDs associated with indexed vectors.
53    type Id: Id;
54
55    /// Custom input search parameters.
56    type Parameters: Clone + AsyncFriendly;
57
58    /// Custom output parameters. This augments the standard metrics collected by
59    /// [`search`] and allows implementation-specific data to be returned.
60    type Output: AsyncFriendly;
61
62    /// The number of queries that can be searched. The machinery in [`search`] and
63    /// [`search_all`] will invoke [`Search::search`] for each index in `0..N` where
64    /// `N` is the returned value of this method.
65    fn num_queries(&self) -> usize;
66
67    /// Provide a hint for the number of IDs returned for each query. This is used to
68    /// optimize internal buffer allocations.
69    fn id_count(&self, parameters: &Self::Parameters) -> IdCount;
70
71    /// Perform a search for the query identified by `index` using `parameters`. The
72    /// results must be written into `buffer`. Customized output is returned.
73    fn search<O>(
74        &self,
75        parameters: &Self::Parameters,
76        buffer: &mut O,
77        index: usize,
78    ) -> impl Future<Output = ANNResult<Self::Output>> + Send
79    where
80        O: graph::SearchOutputBuffer<Self::Id> + Send;
81}
82
83/// Aggregated results for a single invocation of [`search`]. This corresponds to a
84/// potentially parallelized batch of queries.
85///
86/// # Note
87///
88/// In the documentation of the member functions, the term "querywise" describes that the
89/// returned collection has an ordered correspondence with the original queries.
90///
91/// If the [`Search`] object that generated these results as `N` queries (as returned by
92/// [`Search::num_queries`]), then for these returned container, entry `i` will correspond
93/// to the `i`th query for `i` in `0..N`.
94#[derive(Debug)]
95pub struct SearchResults<I, T> {
96    ids: ResultIds<I>,
97    latencies: Vec<MicroSeconds>,
98    output: Vec<T>,
99    end_to_end_latency: MicroSeconds,
100}
101
102impl<I, T> SearchResults<I, T> {
103    /// Return the number of queries in the batch.
104    pub fn len(&self) -> usize {
105        self.latencies.len()
106    }
107
108    /// Return `true` only if `self.len() == 0`.
109    pub fn is_empty(&self) -> bool {
110        self.len() == 0
111    }
112
113    /// Return the wall clock time taken to process all queries in the batch.
114    pub fn end_to_end_latency(&self) -> MicroSeconds {
115        self.end_to_end_latency
116    }
117
118    /// Return the querywise computed IDs from search.
119    pub fn ids(&self) -> &ResultIds<I> {
120        &self.ids
121    }
122
123    /// Return the querywise latencies for each search. If [`Self::latencies_mut`] has been
124    /// called, the return slice loses its querywise guarantee.
125    pub fn latencies(&self) -> &[MicroSeconds] {
126        &self.latencies
127    }
128
129    /// Return the querywise latencies for each search by mutable reference. This is for
130    /// efficient use of [`diskann_benchmark_runner::utils::percentiles::compute_percentiles`].
131    ///
132    /// Modifying the underlying slice invalidates the querywise guarantee.
133    pub fn latencies_mut(&mut self) -> &mut [MicroSeconds] {
134        &mut self.latencies
135    }
136
137    /// Return the querywise customized outputs from search.
138    pub fn output(&self) -> &[T] {
139        &self.output
140    }
141
142    /// Consume `self`, returning the querywise customized outputs from search by value.
143    pub fn take_output(self) -> Vec<T> {
144        self.output
145    }
146}
147
148impl<I, T> SearchResults<I, T>
149where
150    I: Clone + Default,
151    T: Any,
152{
153    fn new(batch: BatchResultsInner<I>) -> Self
154    where
155        I: Clone + Default,
156        T: Any,
157    {
158        // The idea here is that we use `Collector` and dynamic dispatch for the output
159        // aggregation to avoid monomorphising the collection algorithm for all output
160        // types `T`.
161        let mut output = Vec::<T>::new();
162        let mut f = |any: Box<dyn Any>| match any.downcast::<Vec<T>>() {
163            Ok(outputs) => output.extend(*outputs),
164            Err(_) => panic!("Bad `Any` cast during aggregation"),
165        };
166
167        let Collector {
168            ids,
169            latencies,
170            end_to_end_latency,
171        } = Collector::collect(batch, &mut f);
172
173        Self {
174            ids,
175            latencies,
176            output,
177            end_to_end_latency,
178        }
179    }
180}
181
182#[derive(Debug)]
183struct Collector<I> {
184    ids: ResultIds<I>,
185    latencies: Vec<MicroSeconds>,
186    end_to_end_latency: MicroSeconds,
187}
188
189impl<I> Collector<I>
190where
191    I: Clone + Default,
192{
193    fn collect(batch: BatchResultsInner<I>, collect_any: &mut dyn FnMut(Box<dyn Any>)) -> Self {
194        let mut aggregator = IdAggregator::new();
195        let mut latencies = Vec::new();
196
197        batch.task_results.into_iter().for_each(|results| {
198            aggregator.push(results.ids);
199            latencies.extend_from_slice(&results.latencies);
200            (collect_any)(results.outputs);
201        });
202
203        Self {
204            ids: aggregator.finish(),
205            latencies,
206            end_to_end_latency: batch.end_to_end_latency,
207        }
208    }
209}
210
211/// Perform a search using the provided [`Search`] object. Argument `parameters` will be
212/// provided to each invocation of [`Search::search`]. The search will be parallelized into
213/// `ntasks` tasks using the provided `runtime`.
214///
215/// The returned results will have querywise correspondence with the original queries as
216/// described in the documentation of [`SearchResults`].
217///
218/// # See Also
219///
220/// [`search_all`], [`search_all_with`].
221pub fn search<S>(
222    search: Arc<S>,
223    parameters: S::Parameters,
224    ntasks: NonZeroUsize,
225    runtime: &tokio::runtime::Runtime,
226) -> anyhow::Result<SearchResults<S::Id, S::Output>>
227where
228    S: Search,
229{
230    let results = runtime.block_on(search_inner::<S::Id>(search, Arc::new(parameters), ntasks))?;
231    Ok(SearchResults::new(results))
232}
233
234/// An extension of [`search`] that allows multiple runs with different parameters with
235/// automatic result aggregation.
236///
237/// The elements of `parameters` will be executed sequentially. The element yielded from `parameters`
238/// is of type [`Run`], which encapsulates both the search parameters and setup information
239/// such as the number of tasks and repetitions. The returned vector will have the same length as
240/// the `parameters` iterator, with each entry corresponding to the aggregated results
241/// for the respective run.
242///
243/// The aggregation behavior is defined by `aggregator` using the [`Aggregate`] trait.
244/// [`Aggregate::aggregate`] will be provided with the raw results of all repetitions of
245/// a single result from `parameters`.
246///
247/// # Notes on Repetitions
248///
249/// Each run will be repeated `R` times where `R` is defined by [`Run::setup`]. Callers are
250/// encouraged to use multiple repetitions to obtain more stable performance metrics. Result
251/// aggregation can summarize the results across a repetition group to reduce memory consumption.
252///
253/// # See Also
254///
255/// [`search`], [`search_all_with`].
256pub fn search_all<S, Itr, A>(
257    object: Arc<S>,
258    parameters: Itr,
259    aggregator: A,
260) -> anyhow::Result<Vec<A::Output>>
261where
262    S: Search,
263    Itr: IntoIterator<Item = Run<S::Parameters>>,
264    A: Aggregate<S::Parameters, S::Id, S::Output>,
265{
266    search_all_with(
267        object,
268        parameters,
269        aggregator,
270        |_: &mut tokio::runtime::Builder| {},
271    )
272}
273
274/// An extension of [`search`] that allows multiple runs with different parameters with
275/// automatic result aggregation.
276///
277/// The elements of `parameters` will be executed sequentially. The element yielded from `parameters`
278/// is of type [`Run`], which encapsulates both the search parameters and setup information
279/// such as the number of tasks and repetitions. The returned vector will have the same length as
280/// the `parameters` iterator, with each entry corresponding to the aggregated results
281/// for the respective run.
282///
283/// The aggregation behavior is defined by `aggregator` using the [`Aggregate`] trait.
284/// [`Aggregate::aggregate`] will be provided with the raw results of all repetitions of
285/// a single result from `parameters`.
286///
287/// When new [`tokio::runtime::Builder`]s are created, they will be passed to the `on_builder`
288/// callback for customization. Note that these builders will already be initialized with the
289/// number of threads specified by the corresponding [`Run`].
290///
291/// # Notes on Repetitions
292///
293/// Each run will be repeated `R` times where `R` is defined by [`Run::setup`]. Callers are
294/// encouraged to use multiple repetitions to obtain more stable performance metrics. Result
295/// aggregation can summarize the results across a repetition group to reduce memory consumption.
296///
297/// # See Also
298///
299/// [`search_all`], [`search`].
300pub fn search_all_with<S, Itr, A>(
301    object: Arc<S>,
302    parameters: Itr,
303    mut aggregator: A,
304    mut on_builder: impl FnMut(&mut tokio::runtime::Builder),
305) -> anyhow::Result<Vec<A::Output>>
306where
307    S: Search,
308    Itr: IntoIterator<Item = Run<S::Parameters>>,
309    A: Aggregate<S::Parameters, S::Id, S::Output>,
310{
311    let mut output = Vec::new();
312    for run in parameters {
313        let runtime = crate::tokio::runtime_with(run.setup().threads.into(), &mut on_builder)?;
314
315        let reps: usize = run.setup().reps.into();
316        let raw = (0..reps)
317            .map(|_| -> anyhow::Result<_> {
318                search(
319                    object.clone(),
320                    run.parameters().clone(),
321                    run.setup().tasks,
322                    &runtime,
323                )
324            })
325            .collect::<anyhow::Result<Vec<_>>>()?;
326
327        output.push(aggregator.aggregate(run, raw)?);
328    }
329
330    Ok(output)
331}
332
/// High level configuration for one search run executed by [`search_all`].
#[derive(Debug, Clone, PartialEq)]
pub struct Setup {
    /// Worker thread count for the [`tokio::runtime::Runtime`] created for the run.
    pub threads: NonZeroUsize,

    /// Number of tasks the queries are split into. This is deliberately independent of
    /// `threads` so that truly asynchronous providers can oversubscribe the runtime.
    pub tasks: NonZeroUsize,

    /// How many times the search is repeated.
    pub reps: NonZeroUsize,
}

/// One search run: the [`Search::Parameters`] to use, paired with the [`Setup`] describing
/// how to execute them.
#[derive(Debug)]
pub struct Run<P> {
    parameters: P,
    setup: Setup,
}

impl<P> Run<P> {
    /// Bundle `parameters` and `setup` into a [`Run`].
    pub fn new(parameters: P, setup: Setup) -> Self {
        Run { setup, parameters }
    }

    /// Borrow the search parameters.
    pub fn parameters(&self) -> &P {
        let Self { parameters, .. } = self;
        parameters
    }

    /// Borrow the execution setup.
    pub fn setup(&self) -> &Setup {
        let Self { setup, .. } = self;
        setup
    }
}
371
372/// Aggregate search results from multiple repetitions of a single run in [`search_all`].
373///
374/// # Type Parameters
375/// - `P`: The type of [`Search::Parameters`].
376/// - `I`: The type of [`Search::Id`].
377/// - `O`: The type of [`Search::Output`].
378pub trait Aggregate<P, I, O> {
379    /// The type of the aggregated result.
380    type Output;
381
382    /// Aggregate the `results` for all repetitions of `run`.
383    ///
384    /// The length of `results` is guaranteed to be equal to [`Run::setup().reps`](Setup::reps).
385    fn aggregate(
386        &mut self,
387        run: Run<P>,
388        results: Vec<SearchResults<I, O>>,
389    ) -> anyhow::Result<Self::Output>;
390}
391
392///////////
393// Inner //
394///////////
395
396/// The inner search method is only parameterized by the ID type to minimize monomorphization.
397///
398/// The dynamic type of `parameters` must be the same as `Search::Parameters` for the
399/// concrete type of `search`.
400fn search_inner<I>(
401    search: Arc<dyn SearchInner<Id = I>>,
402    parameters: Arc<dyn Any + Send + Sync>,
403    ntasks: NonZeroUsize,
404) -> impl Future<Output = anyhow::Result<BatchResultsInner<I>>> + Send
405where
406    I: Id,
407{
408    let fut = async move {
409        let start = std::time::Instant::now();
410        let handles: Vec<_> = async_tools::PartitionIter::new(search.num_queries(), ntasks)
411            .map(|range| {
412                let search_clone = search.clone();
413                let parameters_clone = parameters.clone();
414                tokio::spawn(
415                    async move { search_clone.search_batch(&*parameters_clone, range).await },
416                )
417            })
418            .collect();
419
420        let mut task_results = Vec::with_capacity(ntasks.into());
421        for h in handles {
422            task_results.push(h.await??);
423        }
424
425        let end_to_end_latency: MicroSeconds = start.elapsed().into();
426
427        Ok(BatchResultsInner {
428            end_to_end_latency,
429            task_results,
430        })
431    };
432
433    boxit(fut)
434}
435
436#[derive(Debug)]
437struct BatchResultsInner<I> {
438    end_to_end_latency: MicroSeconds,
439    task_results: Vec<SearchResultsInner<I>>,
440}
441
442/// Note: Maintain the invariant that the number of entries in all fields is the same. That
443/// is, this is something approximating an array of structs with special handling for the
444/// result ids.
445#[derive(Debug)]
446struct SearchResultsInner<I> {
447    ids: ResultIdsInner<I>,
448    latencies: Vec<MicroSeconds>,
449
450    // Result belonging strictly to the device under test. The concrete type is guaranteed
451    // to be `Vec<Search::Output>`.
452    outputs: Box<dyn Any + Send>,
453}
454
455impl<I> SearchResultsInner<I> {
456    /// A custom constructor for `SearchResultsInner` that ensures the dynamic type of the outputs.
457    fn new<T>(ids: ResultIdsInner<I>, latencies: Vec<MicroSeconds>, outputs: Vec<T::Output>) -> Self
458    where
459        T: Search<Id = I>,
460    {
461        Self {
462            ids,
463            latencies,
464            outputs: Box::new(outputs),
465        }
466    }
467}
468
469// General boxed futures need to be Pinned to be pollable.
470type Pinned<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
471
472trait SearchInner: AsyncFriendly {
473    type Id: Id;
474
475    fn num_queries(&self) -> usize;
476
477    fn search_batch<'a>(
478        &'a self,
479        parameters: &'a dyn Any,
480        range: std::ops::Range<usize>,
481    ) -> Pinned<'a, ANNResult<SearchResultsInner<Self::Id>>>;
482}
483
484impl<T> SearchInner for T
485where
486    T: Search,
487{
488    type Id = <T as Search>::Id;
489
490    fn num_queries(&self) -> usize {
491        <T as Search>::num_queries(self)
492    }
493
494    fn search_batch<'a>(
495        &'a self,
496        parameters: &'a dyn Any,
497        range: std::ops::Range<usize>,
498    ) -> Pinned<'a, ANNResult<SearchResultsInner<Self::Id>>> {
499        let parameters = parameters
500            .downcast_ref::<T::Parameters>()
501            .expect("the internal search API should always pass the correct dynamic type");
502
503        match self.id_count(parameters) {
504            IdCount::Fixed(num_ids) => boxit(search_batch_fixed(self, range, parameters, num_ids)),
505            IdCount::Dynamic(hint) => boxit(search_batch_dynamic(self, range, parameters, hint)),
506        }
507    }
508}
509
510async fn search_batch_fixed<T>(
511    search: &T,
512    range: std::ops::Range<usize>,
513    parameters: &T::Parameters,
514    num_ids: NonZeroUsize,
515) -> ANNResult<SearchResultsInner<T::Id>>
516where
517    T: Search,
518{
519    let mut lengths = Vec::with_capacity(range.len());
520    let mut ids = Matrix::new(views::Init(T::Id::default), range.len(), num_ids.into());
521
522    let mut latencies = Vec::<MicroSeconds>::with_capacity(range.len());
523    let mut outputs = Vec::<T::Output>::with_capacity(range.len());
524
525    for (ids, index) in std::iter::zip(ids.row_iter_mut(), range) {
526        let mut buffer = internal::buffer::Buffer::slice(ids);
527
528        let start = std::time::Instant::now();
529        let output = search.search(parameters, &mut buffer, index).await?;
530        lengths.push(buffer.current_len());
531
532        latencies.push(start.elapsed().into());
533        outputs.push(output);
534    }
535
536    Ok(SearchResultsInner::new::<T>(
537        ResultIdsInner::Fixed(Bounded::new(ids, lengths)),
538        latencies,
539        outputs,
540    ))
541}
542
543async fn search_batch_dynamic<T>(
544    search: &T,
545    range: std::ops::Range<usize>,
546    parameters: &T::Parameters,
547    hint: Option<NonZeroUsize>,
548) -> ANNResult<SearchResultsInner<T::Id>>
549where
550    T: Search,
551{
552    let mut ids = Vec::with_capacity(range.len());
553    let mut latencies = Vec::<MicroSeconds>::with_capacity(range.len());
554    let mut outputs = Vec::<T::Output>::with_capacity(range.len());
555
556    let hint = hint.map(|i| i.into()).unwrap_or(0);
557
558    for index in range {
559        let mut these_ids = Vec::with_capacity(hint);
560        let mut buffer = internal::buffer::Buffer::vector(&mut these_ids);
561
562        let start = std::time::Instant::now();
563        let output = search.search(parameters, &mut buffer, index).await?;
564        latencies.push(start.elapsed().into());
565
566        ids.push(these_ids);
567        outputs.push(output);
568    }
569
570    Ok(SearchResultsInner::new::<T>(
571        ResultIdsInner::Dynamic(ids),
572        latencies,
573        outputs,
574    ))
575}
576
577///////////
578// Tests //
579///////////
580
#[cfg(test)]
mod tests {
    use super::*;

    use std::hash::{self, Hash, Hasher};

    // We intentionally do not derive `Clone` to ensure that it is not needed
    // in the implementations.
    #[derive(Debug)]
    struct TestSearch {
        queries: usize,
        // A hash function to determine the number and value of returned IDs.
        hasher: fn(usize, usize) -> usize,
    }

    impl TestSearch {
        /// Expected number of IDs returned for query `index` under `id_count`.
        fn count(&self, index: usize, id_count: &IdCount) -> usize {
            match id_count {
                IdCount::Fixed(n) => (self.hasher)(index, index) % n.get(),
                IdCount::Dynamic(_) => (self.hasher)(index, index) % DYNAMIC_MAX,
            }
        }

        /// Expected ID at `position` in the results of query `index`.
        fn format(&self, index: usize, position: usize) -> String {
            (self.hasher)(index, position).to_string()
        }

        /// Verify latencies, IDs and outputs of `results` against the deterministic
        /// expectations derived from `count` and `format`.
        fn check(&self, id_count: &IdCount, mut results: SearchResults<String, usize>) {
            let num_queries = self.queries;

            // End-to-end latency should not be zero.
            assert_ne!(
                results.end_to_end_latency().as_seconds(),
                0.0,
                "end to end latency should be non-zero"
            );

            assert_eq!(results.latencies().len(), num_queries);
            assert_eq!(results.latencies_mut().len(), num_queries);

            let rows = results.ids().as_rows();
            assert_eq!(rows.nrows(), num_queries);
            for i in 0..num_queries {
                let row = rows.row(i);
                assert_eq!(
                    row.len(),
                    self.count(i, id_count),
                    "incorrect length for output row {}",
                    i
                );

                for (j, id) in row.iter().enumerate() {
                    assert_eq!(
                        id,
                        &self.format(i, j),
                        "mismatch for query {} at position {}",
                        i,
                        j
                    );
                }
            }

            let expected_output: Vec<_> =
                (0..num_queries).map(|i| self.count(i, id_count)).collect();

            assert_eq!(results.output(), &expected_output);

            let output = results.take_output();
            assert_eq!(output, expected_output);
        }
    }

    const DYNAMIC_MAX: usize = 5;

    impl Search for TestSearch {
        type Id = String;
        type Parameters = IdCount;
        type Output = usize;

        fn num_queries(&self) -> usize {
            self.queries
        }

        fn id_count(&self, parameters: &IdCount) -> IdCount {
            *parameters
        }

        async fn search<O>(
            &self,
            params: &IdCount,
            buffer: &mut O,
            index: usize,
        ) -> ANNResult<Self::Output>
        where
            O: graph::SearchOutputBuffer<Self::Id> + Send,
        {
            let count = self.count(index, params);
            let set = buffer.extend((0..count).map(|i| (self.format(index, i), i as f32)));
            assert_eq!(set, count);
            Ok(count)
        }
    }

    fn hash(a: usize, b: usize) -> usize {
        let mut hasher = hash::DefaultHasher::new();
        a.hash(&mut hasher);
        b.hash(&mut hasher);
        hasher.finish() as usize
    }

    // This test sweeps across a wide variety of threads, tasks, and behavior.
    //
    // We use hashing to generate deterministic but non-uniform results.
    #[test]
    fn test_search() {
        for num_queries in [3, 4, 5] {
            let searcher = Arc::new(TestSearch {
                queries: num_queries,
                hasher: hash,
            });

            for num_threads in 1..6 {
                let runtime = crate::tokio::runtime(num_threads).unwrap();

                for num_tasks in 1..6 {
                    let num_tasks = NonZeroUsize::new(num_tasks).unwrap();
                    for id_count in [
                        IdCount::Fixed(NonZeroUsize::new(3).unwrap()),
                        IdCount::Dynamic(Some(NonZeroUsize::new(4).unwrap())),
                        IdCount::Dynamic(None),
                    ] {
                        let results =
                            search(searcher.clone(), id_count, num_tasks, &runtime).unwrap();

                        searcher.check(&id_count, results);
                    }
                }
            }
        }
    }

    /// An aggregator for testing [`search_all`]. This simply invokes [`TestSearch::check`]
    /// on the inner results, verifies the number of results, and returns a deterministic
    /// hash of `seed` and the number of previous calls so the caller can verify that runs
    /// were aggregated in order.
    struct Aggregator<'a> {
        /// The searcher provided to [`search_all`].
        searcher: Arc<TestSearch>,

        /// A seed for randomizing the return values.
        seed: usize,

        /// A count for the number of times `aggregate` was called.
        called: &'a mut usize,
    }

    impl Aggregate<IdCount, String, usize> for Aggregator<'_> {
        type Output = usize;

        fn aggregate(
            &mut self,
            run: Run<IdCount>,
            results: Vec<SearchResults<String, usize>>,
        ) -> anyhow::Result<Self::Output> {
            assert_eq!(
                results.len(),
                run.setup().reps.get(),
                "the incorrect number of results was returned",
            );

            for result in results {
                self.searcher.check(run.parameters(), result);
            }

            let count = *self.called;
            *self.called += 1;
            Ok(hash(self.seed, count))
        }
    }

    #[test]
    fn test_search_all() {
        let counts = [
            IdCount::Fixed(NonZeroUsize::new(3).unwrap()),
            IdCount::Dynamic(Some(NonZeroUsize::new(4).unwrap())),
            IdCount::Dynamic(None),
        ];

        let seed = 0x2f1b462446d1f225;

        for num_queries in [3, 4, 5] {
            let searcher = Arc::new(TestSearch {
                queries: num_queries,
                hasher: hash,
            });

            let iter = itertools::iproduct!((1..6), (1..6), (2..3), counts,).map(
                |(threads, tasks, reps, parameters)| {
                    Run::new(
                        parameters,
                        Setup {
                            threads: NonZeroUsize::new(threads).unwrap(),
                            tasks: NonZeroUsize::new(tasks).unwrap(),
                            reps: NonZeroUsize::new(reps).unwrap(),
                        },
                    )
                },
            );

            // `search_all`
            {
                let mut called = 0usize;
                let aggregator = Aggregator {
                    searcher: searcher.clone(),
                    seed,
                    called: &mut called,
                };

                let len = iter.size_hint().0;
                let results = search_all(searcher.clone(), iter.clone(), aggregator).unwrap();

                assert_eq!(results.len(), len);
                assert_eq!(called, len);

                for (i, r) in results.into_iter().enumerate() {
                    assert_eq!(r, hash(seed, i), "mismatch for result {}", i);
                }
            }

            // `search_all_with`
            {
                let mut called = 0usize;
                let aggregator = Aggregator {
                    searcher: searcher.clone(),
                    seed,
                    called: &mut called,
                };

                let len = iter.size_hint().0;
                let mut builder_calls = 0usize;
                let results = search_all_with(searcher, iter, aggregator, |_| {
                    builder_calls += 1;
                })
                .unwrap();

                assert_eq!(results.len(), len);
                assert_eq!(called, len);
                assert_eq!(builder_calls, len);

                for (i, r) in results.into_iter().enumerate() {
                    assert_eq!(r, hash(seed, i), "mismatch for result {}", i);
                }
            }
        }
    }
}