1use std::{any::Any, future::Future, num::NonZeroUsize, pin::Pin, sync::Arc};
7
8use diskann::{ANNResult, graph, utils::async_tools};
9use diskann_benchmark_runner::utils::MicroSeconds;
10use diskann_utils::{
11 future::{AsyncFriendly, boxit},
12 views::{self, Matrix},
13};
14
15use crate::{
16 internal,
17 search::{
18 ResultIds,
19 ids::{Bounded, IdAggregator, ResultIdsInner},
20 },
21};
22
/// Marker trait for the id type that a [`Search`] implementation writes into its
/// output buffers.
///
/// `Default` is needed to pre-fill the fixed-width id matrix before a batch is
/// searched. `Clone` is required by the result aggregation in `Collector`.
/// `Send + Sync + 'static` allow ids to move between spawned Tokio tasks.
/// Every type meeting these bounds is automatically an `Id`.
pub trait Id: Default + Clone + Send + Sync + 'static {}

impl<T: Default + Clone + Send + Sync + 'static> Id for T {}
29
/// How many ids a single query may produce, as reported by [`Search::id_count`].
///
/// The variant selects the kind of output buffer each query writes into.
/// Equality and hashing are derived so callers can compare and key on it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum IdCount {
    /// At most this many ids per query.
    ///
    /// Each query writes into a pre-allocated row of this width. The number of
    /// ids actually written is recorded per query.
    Fixed(NonZeroUsize),

    /// A variable number of ids per query, collected into a growable vector.
    ///
    /// The optional value is only used as that vector's initial capacity.
    Dynamic(Option<NonZeroUsize>),
}
41
/// A fixed set of queries that can each be searched independently by index.
///
/// Implementors are shared across spawned Tokio tasks (each task searches a
/// sub-range of query indices), hence the [`AsyncFriendly`] bound.
pub trait Search: AsyncFriendly {
    /// The identifier type written into the output buffer for each result.
    type Id: Id;

    /// Parameters for a search run.
    ///
    /// They are cloned once per repetition by [`search_all_with`] and shared
    /// between tasks through an `Arc`.
    type Parameters: Clone + AsyncFriendly;

    /// Extra per-query value returned by [`Search::search`] alongside the ids.
    type Output: AsyncFriendly;

    /// Total number of queries.
    ///
    /// The driver partitions this count into per-task ranges of query indices.
    fn num_queries(&self) -> usize;

    /// How many ids each query may produce under `parameters`.
    ///
    /// The result selects between a fixed-width and a growable output buffer.
    fn id_count(&self, parameters: &Self::Parameters) -> IdCount;

    /// Search query `index`, writing its result ids into `buffer`.
    ///
    /// The time spent awaiting the returned future is recorded as this query's
    /// latency.
    fn search<O>(
        &self,
        parameters: &Self::Parameters,
        buffer: &mut O,
        index: usize,
    ) -> impl Future<Output = ANNResult<Self::Output>> + Send
    where
        O: graph::SearchOutputBuffer<Self::Id> + Send;
}
82
/// Aggregated results of one [`search`] call over every query.
///
/// The per-task results are concatenated in task order. Entry `i` of the ids,
/// latencies, and outputs therefore refer to the same query. This assumes the
/// task ranges are contiguous and ascending (TODO confirm with `PartitionIter`).
#[derive(Debug)]
pub struct SearchResults<I, T> {
    /// Result ids for every query.
    ids: ResultIds<I>,
    /// Per-query search latency.
    latencies: Vec<MicroSeconds>,
    /// Per-query value returned by [`Search::search`].
    output: Vec<T>,
    /// Wall-clock time from spawning the first task until the last one was joined.
    end_to_end_latency: MicroSeconds,
}
101
102impl<I, T> SearchResults<I, T> {
103 pub fn len(&self) -> usize {
105 self.latencies.len()
106 }
107
108 pub fn is_empty(&self) -> bool {
110 self.len() == 0
111 }
112
113 pub fn end_to_end_latency(&self) -> MicroSeconds {
115 self.end_to_end_latency
116 }
117
118 pub fn ids(&self) -> &ResultIds<I> {
120 &self.ids
121 }
122
123 pub fn latencies(&self) -> &[MicroSeconds] {
126 &self.latencies
127 }
128
129 pub fn latencies_mut(&mut self) -> &mut [MicroSeconds] {
134 &mut self.latencies
135 }
136
137 pub fn output(&self) -> &[T] {
139 &self.output
140 }
141
142 pub fn take_output(self) -> Vec<T> {
144 self.output
145 }
146}
147
148impl<I, T> SearchResults<I, T>
149where
150 I: Clone + Default,
151 T: Any,
152{
153 fn new(batch: BatchResultsInner<I>) -> Self
154 where
155 I: Clone + Default,
156 T: Any,
157 {
158 let mut output = Vec::<T>::new();
162 let mut f = |any: Box<dyn Any>| match any.downcast::<Vec<T>>() {
163 Ok(outputs) => output.extend(*outputs),
164 Err(_) => panic!("Bad `Any` cast during aggregation"),
165 };
166
167 let Collector {
168 ids,
169 latencies,
170 end_to_end_latency,
171 } = Collector::collect(batch, &mut f);
172
173 Self {
174 ids,
175 latencies,
176 output,
177 end_to_end_latency,
178 }
179 }
180}
181
/// Task-order concatenation of the ids and latencies in a [`BatchResultsInner`].
#[derive(Debug)]
struct Collector<I> {
    /// Ids of all tasks, merged by `IdAggregator`.
    ids: ResultIds<I>,
    /// Per-query latencies of all tasks, in task order.
    latencies: Vec<MicroSeconds>,
    /// Copied unchanged from the batch.
    end_to_end_latency: MicroSeconds,
}
188
189impl<I> Collector<I>
190where
191 I: Clone + Default,
192{
193 fn collect(batch: BatchResultsInner<I>, collect_any: &mut dyn FnMut(Box<dyn Any>)) -> Self {
194 let mut aggregator = IdAggregator::new();
195 let mut latencies = Vec::new();
196
197 batch.task_results.into_iter().for_each(|results| {
198 aggregator.push(results.ids);
199 latencies.extend_from_slice(&results.latencies);
200 (collect_any)(results.outputs);
201 });
202
203 Self {
204 ids: aggregator.finish(),
205 latencies,
206 end_to_end_latency: batch.end_to_end_latency,
207 }
208 }
209}
210
211pub fn search<S>(
222 search: Arc<S>,
223 parameters: S::Parameters,
224 ntasks: NonZeroUsize,
225 runtime: &tokio::runtime::Runtime,
226) -> anyhow::Result<SearchResults<S::Id, S::Output>>
227where
228 S: Search,
229{
230 let results = runtime.block_on(search_inner::<S::Id>(search, Arc::new(parameters), ntasks))?;
231 Ok(SearchResults::new(results))
232}
233
234pub fn search_all<S, Itr, A>(
257 object: Arc<S>,
258 parameters: Itr,
259 aggregator: A,
260) -> anyhow::Result<Vec<A::Output>>
261where
262 S: Search,
263 Itr: IntoIterator<Item = Run<S::Parameters>>,
264 A: Aggregate<S::Parameters, S::Id, S::Output>,
265{
266 search_all_with(
267 object,
268 parameters,
269 aggregator,
270 |_: &mut tokio::runtime::Builder| {},
271 )
272}
273
274pub fn search_all_with<S, Itr, A>(
301 object: Arc<S>,
302 parameters: Itr,
303 mut aggregator: A,
304 mut on_builder: impl FnMut(&mut tokio::runtime::Builder),
305) -> anyhow::Result<Vec<A::Output>>
306where
307 S: Search,
308 Itr: IntoIterator<Item = Run<S::Parameters>>,
309 A: Aggregate<S::Parameters, S::Id, S::Output>,
310{
311 let mut output = Vec::new();
312 for run in parameters {
313 let runtime = crate::tokio::runtime_with(run.setup().threads.into(), &mut on_builder)?;
314
315 let reps: usize = run.setup().reps.into();
316 let raw = (0..reps)
317 .map(|_| -> anyhow::Result<_> {
318 search(
319 object.clone(),
320 run.parameters().clone(),
321 run.setup().tasks,
322 &runtime,
323 )
324 })
325 .collect::<anyhow::Result<Vec<_>>>()?;
326
327 output.push(aggregator.aggregate(run, raw)?);
328 }
329
330 Ok(output)
331}
332
/// Execution settings for a single [`Run`].
///
/// All fields are plain non-zero counts, so the type is fully comparable and hashable.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Setup {
    /// Passed to `crate::tokio::runtime_with` when building this run's runtime
    /// (presumably the worker-thread count; confirm against that helper).
    pub threads: NonZeroUsize,

    /// Number of Tokio tasks the queries are partitioned across.
    pub tasks: NonZeroUsize,

    /// Number of times the full query set is searched.
    ///
    /// The aggregator receives one [`SearchResults`] per repetition.
    pub reps: NonZeroUsize,
}
347
/// One benchmark configuration: search parameters plus the [`Setup`] to run them under.
///
/// Passed by value to [`Aggregate::aggregate`] together with that run's results.
#[derive(Debug)]
pub struct Run<P> {
    /// Search parameters; cloned for every repetition.
    parameters: P,
    /// Threads, tasks, and repetitions for this run.
    setup: Setup,
}
354
355impl<P> Run<P> {
356 pub fn new(parameters: P, setup: Setup) -> Self {
358 Self { parameters, setup }
359 }
360
361 pub fn parameters(&self) -> &P {
363 &self.parameters
364 }
365
366 pub fn setup(&self) -> &Setup {
368 &self.setup
369 }
370}
371
/// Reduces the repeated [`SearchResults`] of one [`Run`] into a summary value.
///
/// [`search_all_with`] calls this once per run, in run order.
pub trait Aggregate<P, I, O> {
    /// Per-run summary; [`search_all_with`] returns one of these per run.
    type Output;

    /// Summarize `results`, which holds one entry per repetition of `run`.
    ///
    /// # Errors
    ///
    /// An error aborts [`search_all_with`], which returns it to its caller unchanged.
    fn aggregate(
        &mut self,
        run: Run<P>,
        results: Vec<SearchResults<I, O>>,
    ) -> anyhow::Result<Self::Output>;
}
391
392fn search_inner<I>(
401 search: Arc<dyn SearchInner<Id = I>>,
402 parameters: Arc<dyn Any + Send + Sync>,
403 ntasks: NonZeroUsize,
404) -> impl Future<Output = anyhow::Result<BatchResultsInner<I>>> + Send
405where
406 I: Id,
407{
408 let fut = async move {
409 let start = std::time::Instant::now();
410 let handles: Vec<_> = async_tools::PartitionIter::new(search.num_queries(), ntasks)
411 .map(|range| {
412 let search_clone = search.clone();
413 let parameters_clone = parameters.clone();
414 tokio::spawn(
415 async move { search_clone.search_batch(&*parameters_clone, range).await },
416 )
417 })
418 .collect();
419
420 let mut task_results = Vec::with_capacity(ntasks.into());
421 for h in handles {
422 task_results.push(h.await??);
423 }
424
425 let end_to_end_latency: MicroSeconds = start.elapsed().into();
426
427 Ok(BatchResultsInner {
428 end_to_end_latency,
429 task_results,
430 })
431 };
432
433 boxit(fut)
434}
435
/// Raw per-task output of `search_inner`, before aggregation into [`SearchResults`].
#[derive(Debug)]
struct BatchResultsInner<I> {
    /// Time from just before the tasks were spawned until the last one was joined.
    end_to_end_latency: MicroSeconds,
    /// One entry per spawned task, in spawn order.
    task_results: Vec<SearchResultsInner<I>>,
}
441
/// Results for the contiguous range of queries handled by one task.
#[derive(Debug)]
struct SearchResultsInner<I> {
    /// Ids for each query in the task's range.
    ids: ResultIdsInner<I>,
    /// Per-query search latency, in query order.
    latencies: Vec<MicroSeconds>,

    /// Type-erased `Vec<Search::Output>`, one entry per query.
    ///
    /// It is erased so that this struct is generic only over the id type. It is
    /// downcast back when the task results are flattened into [`SearchResults`].
    outputs: Box<dyn Any + Send>,
}
454
455impl<I> SearchResultsInner<I> {
456 fn new<T>(ids: ResultIdsInner<I>, latencies: Vec<MicroSeconds>, outputs: Vec<T::Output>) -> Self
458 where
459 T: Search<Id = I>,
460 {
461 Self {
462 ids,
463 latencies,
464 outputs: Box::new(outputs),
465 }
466 }
467}
468
/// A heap-allocated, `Send` future that may borrow for `'a`.
///
/// This is the return type of [`SearchInner::search_batch`].
type Pinned<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
471
/// Object-safe, type-erased counterpart of [`Search`] used by `search_inner`.
///
/// Parameters arrive as `&dyn Any` and results come back as boxed futures.
/// This lets the task-spawning driver be generic only over the id type,
/// presumably to avoid monomorphizing it for every `Search` implementation.
trait SearchInner: AsyncFriendly {
    /// Same as [`Search::Id`].
    type Id: Id;

    /// Same as [`Search::num_queries`].
    fn num_queries(&self) -> usize;

    /// Search every query index in `range`, returning per-query ids, latencies,
    /// and type-erased outputs.
    ///
    /// `parameters` must be the implementor's concrete `Search::Parameters`.
    fn search_batch<'a>(
        &'a self,
        parameters: &'a dyn Any,
        range: std::ops::Range<usize>,
    ) -> Pinned<'a, ANNResult<SearchResultsInner<Self::Id>>>;
}
483
484impl<T> SearchInner for T
485where
486 T: Search,
487{
488 type Id = <T as Search>::Id;
489
490 fn num_queries(&self) -> usize {
491 <T as Search>::num_queries(self)
492 }
493
494 fn search_batch<'a>(
495 &'a self,
496 parameters: &'a dyn Any,
497 range: std::ops::Range<usize>,
498 ) -> Pinned<'a, ANNResult<SearchResultsInner<Self::Id>>> {
499 let parameters = parameters
500 .downcast_ref::<T::Parameters>()
501 .expect("the internal search API should always pass the correct dynamic type");
502
503 match self.id_count(parameters) {
504 IdCount::Fixed(num_ids) => boxit(search_batch_fixed(self, range, parameters, num_ids)),
505 IdCount::Dynamic(hint) => boxit(search_batch_dynamic(self, range, parameters, hint)),
506 }
507 }
508}
509
510async fn search_batch_fixed<T>(
511 search: &T,
512 range: std::ops::Range<usize>,
513 parameters: &T::Parameters,
514 num_ids: NonZeroUsize,
515) -> ANNResult<SearchResultsInner<T::Id>>
516where
517 T: Search,
518{
519 let mut lengths = Vec::with_capacity(range.len());
520 let mut ids = Matrix::new(views::Init(T::Id::default), range.len(), num_ids.into());
521
522 let mut latencies = Vec::<MicroSeconds>::with_capacity(range.len());
523 let mut outputs = Vec::<T::Output>::with_capacity(range.len());
524
525 for (ids, index) in std::iter::zip(ids.row_iter_mut(), range) {
526 let mut buffer = internal::buffer::Buffer::slice(ids);
527
528 let start = std::time::Instant::now();
529 let output = search.search(parameters, &mut buffer, index).await?;
530 lengths.push(buffer.current_len());
531
532 latencies.push(start.elapsed().into());
533 outputs.push(output);
534 }
535
536 Ok(SearchResultsInner::new::<T>(
537 ResultIdsInner::Fixed(Bounded::new(ids, lengths)),
538 latencies,
539 outputs,
540 ))
541}
542
543async fn search_batch_dynamic<T>(
544 search: &T,
545 range: std::ops::Range<usize>,
546 parameters: &T::Parameters,
547 hint: Option<NonZeroUsize>,
548) -> ANNResult<SearchResultsInner<T::Id>>
549where
550 T: Search,
551{
552 let mut ids = Vec::with_capacity(range.len());
553 let mut latencies = Vec::<MicroSeconds>::with_capacity(range.len());
554 let mut outputs = Vec::<T::Output>::with_capacity(range.len());
555
556 let hint = hint.map(|i| i.into()).unwrap_or(0);
557
558 for index in range {
559 let mut these_ids = Vec::with_capacity(hint);
560 let mut buffer = internal::buffer::Buffer::vector(&mut these_ids);
561
562 let start = std::time::Instant::now();
563 let output = search.search(parameters, &mut buffer, index).await?;
564 latencies.push(start.elapsed().into());
565
566 ids.push(these_ids);
567 outputs.push(output);
568 }
569
570 Ok(SearchResultsInner::new::<T>(
571 ResultIdsInner::Dynamic(ids),
572 latencies,
573 outputs,
574 ))
575}
576
#[cfg(test)]
mod tests {
    use super::*;

    use std::hash::{self, Hash, Hasher};

    /// A deterministic [`Search`] whose ids, id counts, and outputs are all
    /// derived from `hasher`, so `check` can recompute the expected results.
    #[derive(Debug)]
    struct TestSearch {
        queries: usize,
        hasher: fn(usize, usize) -> usize,
    }

    impl TestSearch {
        /// Number of ids query `index` produces.
        ///
        /// The count is kept below the fixed width for `IdCount::Fixed`, and
        /// below `DYNAMIC_MAX` for `IdCount::Dynamic`.
        fn count(&self, index: usize, id_count: &IdCount) -> usize {
            match id_count {
                IdCount::Fixed(n) => (self.hasher)(index, index) % n.get(),
                IdCount::Dynamic(_) => (self.hasher)(index, index) % DYNAMIC_MAX,
            }
        }

        /// The id expected at `position` in the results of query `index`.
        fn format(&self, index: usize, position: usize) -> String {
            (self.hasher)(index, position).to_string()
        }

        /// Assert that `results` is exactly what a search with `id_count` must produce.
        fn check(&self, id_count: &IdCount, mut results: SearchResults<String, usize>) {
            let num_queries = self.queries;

            assert_ne!(
                results.end_to_end_latency().as_seconds(),
                0.0,
                "end to end latency should be non-zero"
            );

            assert_eq!(results.latencies().len(), num_queries);
            assert_eq!(results.latencies_mut().len(), num_queries);

            let rows = results.ids().as_rows();
            assert_eq!(rows.nrows(), num_queries);
            for i in 0..num_queries {
                let row = rows.row(i);
                assert_eq!(
                    row.len(),
                    self.count(i, id_count),
                    "incorrect length for output row {}",
                    i
                );

                for (j, id) in row.iter().enumerate() {
                    assert_eq!(
                        id,
                        &self.format(i, j),
                        "mismatch for query {} at position {}",
                        i,
                        j
                    );
                }
            }

            let expected_output: Vec<_> =
                (0..num_queries).map(|i| self.count(i, id_count)).collect();

            assert_eq!(results.output(), &expected_output);

            let output = results.take_output();
            assert_eq!(output, expected_output);
        }
    }

    /// Upper bound (exclusive) on the number of ids produced under `IdCount::Dynamic`.
    const DYNAMIC_MAX: usize = 5;

    impl Search for TestSearch {
        type Id = String;
        type Parameters = IdCount;
        type Output = usize;

        fn num_queries(&self) -> usize {
            self.queries
        }

        fn id_count(&self, parameters: &IdCount) -> IdCount {
            *parameters
        }

        async fn search<O>(
            &self,
            params: &IdCount,
            buffer: &mut O,
            index: usize,
        ) -> ANNResult<Self::Output>
        where
            O: graph::SearchOutputBuffer<Self::Id> + Send,
        {
            let count = self.count(index, params);
            let set = buffer.extend((0..count).map(|i| (self.format(index, i), i as f32)));
            assert_eq!(set, count);
            Ok(count)
        }
    }

    /// Deterministic mixing of two integers, used to generate ids and counts.
    fn hash(a: usize, b: usize) -> usize {
        let mut hasher = hash::DefaultHasher::new();
        a.hash(&mut hasher);
        b.hash(&mut hasher);
        hasher.finish() as usize
    }

    /// `search` returns correct results for every combination of query count,
    /// thread count, task count, and id-count mode.
    #[test]
    fn test_search() {
        for num_queries in [3, 4, 5] {
            let searcher = Arc::new(TestSearch {
                queries: num_queries,
                hasher: hash,
            });

            for num_threads in 1..6 {
                let runtime = crate::tokio::runtime(num_threads).unwrap();

                for num_tasks in 1..6 {
                    let num_tasks = NonZeroUsize::new(num_tasks).unwrap();
                    for id_count in [
                        IdCount::Fixed(NonZeroUsize::new(3).unwrap()),
                        IdCount::Dynamic(Some(NonZeroUsize::new(4).unwrap())),
                        IdCount::Dynamic(None),
                    ] {
                        let results =
                            search(searcher.clone(), id_count, num_tasks, &runtime).unwrap();

                        searcher.check(&id_count, results);
                    }
                }
            }
        }
    }

    /// Test aggregator that checks every repetition and returns a value that
    /// encodes the order in which it was called.
    struct Aggregator<'a> {
        /// Used to validate each repetition's results.
        searcher: Arc<TestSearch>,
        /// Mixed with the call index to form the aggregated value.
        seed: usize,
        /// Number of times `aggregate` has been called.
        called: &'a mut usize,
    }

    impl Aggregate<IdCount, String, usize> for Aggregator<'_> {
        type Output = usize;

        fn aggregate(
            &mut self,
            run: Run<IdCount>,
            results: Vec<SearchResults<String, usize>>,
        ) -> anyhow::Result<Self::Output> {
            assert_eq!(
                results.len(),
                run.setup().reps.get(),
                "the incorrect number of results was returned",
            );

            for result in results {
                self.searcher.check(run.parameters(), result);
            }

            let count = *self.called;
            *self.called += 1;
            Ok(hash(self.seed, count))
        }
    }

    /// `search_all` and `search_all_with` aggregate every run exactly once, in order.
    #[test]
    fn test_search_all() {
        let counts = [
            IdCount::Fixed(NonZeroUsize::new(3).unwrap()),
            IdCount::Dynamic(Some(NonZeroUsize::new(4).unwrap())),
            IdCount::Dynamic(None),
        ];

        let seed = 0x2f1b462446d1f225;

        for num_queries in [3, 4, 5] {
            let searcher = Arc::new(TestSearch {
                queries: num_queries,
                hasher: hash,
            });

            let iter = itertools::iproduct!((1..6), (1..6), (2..3), counts,).map(
                |(threads, tasks, reps, parameters)| {
                    Run::new(
                        parameters,
                        Setup {
                            threads: NonZeroUsize::new(threads).unwrap(),
                            tasks: NonZeroUsize::new(tasks).unwrap(),
                            reps: NonZeroUsize::new(reps).unwrap(),
                        },
                    )
                },
            );

            // Count the runs directly. `size_hint().0` is only guaranteed to be
            // a lower bound, so it is not a reliable expected length.
            let len = iter.clone().count();

            // `search_all`: one aggregated value per run, in run order.
            {
                let mut called = 0usize;
                let aggregator = Aggregator {
                    searcher: searcher.clone(),
                    seed,
                    called: &mut called,
                };

                let results = search_all(searcher.clone(), iter.clone(), aggregator).unwrap();

                assert_eq!(results.len(), len);
                assert_eq!(called, len);

                for (i, r) in results.into_iter().enumerate() {
                    assert_eq!(r, hash(seed, i), "mismatch for result {}", i);
                }
            }

            // `search_all_with`: same as above, and the builder hook runs once per run.
            {
                let mut called = 0usize;
                let aggregator = Aggregator {
                    searcher: searcher.clone(),
                    seed,
                    called: &mut called,
                };

                let mut builder_calls = 0usize;
                let results = search_all_with(searcher, iter, aggregator, |_| {
                    builder_calls += 1;
                })
                .unwrap();

                assert_eq!(results.len(), len);
                assert_eq!(called, len);
                assert_eq!(builder_calls, len);

                for (i, r) in results.into_iter().enumerate() {
                    assert_eq!(r, hash(seed, i), "mismatch for result {}", i);
                }
            }
        }
    }
}