1use std::{
7 any::Any,
8 future::Future,
9 num::NonZeroUsize,
10 ops::Range,
11 pin::Pin,
12 sync::{
13 Arc,
14 atomic::{AtomicUsize, Ordering},
15 },
16};
17
18use diskann::{ANNError, ANNResult};
19use diskann_benchmark_runner::utils::MicroSeconds;
20use diskann_utils::future::{AsyncFriendly, boxit};
21
/// A batched, async index-construction driver.
///
/// Implementors expose the total number of items to process and an async
/// `build` over a half-open sub-range of those items. The driver functions
/// in this module ([`build`], [`build_tracked`]) split `0..num_data()` into
/// batches and invoke [`Build::build`] once per batch.
pub trait Build: AsyncFriendly {
    /// Result produced for each completed batch.
    type Output: AsyncFriendly;

    /// Total number of items; batches are sub-ranges of `0..num_data()`.
    fn num_data(&self) -> usize;

    /// Builds the given half-open `range` of items, yielding one `Output`.
    fn build(&self, range: Range<usize>) -> impl Future<Output = ANNResult<Self::Output>> + Send;
}
47
/// The outcome of building a single batch.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct BatchResult<T> {
    /// Index of the worker task that processed this batch.
    pub taskid: usize,

    /// The half-open range of items this batch covered.
    pub batch: Range<usize>,

    /// Wall-clock time spent building this batch.
    pub latency: MicroSeconds,

    /// The value returned by [`Build::build`] for this batch.
    pub output: T,
}
69
impl<T> BatchResult<T> {
    /// Number of items covered by this batch (the length of `batch`).
    pub fn batchsize(&self) -> usize {
        self.batch.len()
    }
}
76
/// Aggregated results of a full build run.
#[derive(Debug)]
pub struct BuildResults<T> {
    /// Flattened per-batch results collected from all tasks.
    output: Vec<BatchResult<T>>,
    /// Wall-clock time for the entire run, including orchestration.
    end_to_end_latency: MicroSeconds,
}
85
86impl<T> BuildResults<T> {
87 pub fn end_to_end_latency(&self) -> MicroSeconds {
89 self.end_to_end_latency
90 }
91
92 pub fn output(&self) -> &[BatchResult<T>] {
94 &self.output
95 }
96
97 pub fn take_output(self) -> Vec<BatchResult<T>> {
99 self.output
100 }
101}
102
103impl<T> BuildResults<T>
104where
105 T: Any,
106{
107 fn new(inner: BuildResultsInner) -> Self {
112 let BuildResultsInner {
113 end_to_end_latency,
114 task_results,
115 } = inner;
116 let mut output = Vec::with_capacity(task_results.iter().map(|t| t.len()).sum());
117
118 task_results
119 .into_iter()
120 .enumerate()
121 .for_each(|(taskid, results)| {
122 results.into_iter().for_each(|r| {
123 output.push(BatchResult {
124 taskid,
125 batch: r.batch,
126 latency: r.latency,
127 output: *r
128 .output
129 .downcast::<T>()
130 .expect("incorrect downcast applied"),
131 })
132 })
133 });
134
135 Self {
136 output,
137 end_to_end_latency,
138 }
139 }
140}
141
/// Strategy for distributing batches across tasks.
#[derive(Debug, Clone, Copy, PartialEq)]
#[non_exhaustive]
pub enum Parallelism {
    /// `ntasks` workers dynamically pull `batchsize`-item batches from a
    /// shared work queue until the data is exhausted.
    #[non_exhaustive]
    Dynamic {
        batchsize: NonZeroUsize,
        ntasks: NonZeroUsize,
    },

    /// A single task processes `batchsize`-item batches in order
    /// (implemented as `Dynamic` with one task).
    #[non_exhaustive]
    Sequential { batchsize: NonZeroUsize },

    /// Data is pre-partitioned evenly across `ntasks` tasks; each task walks
    /// its partition in `batchsize` chunks, or builds the whole partition as
    /// one batch when `batchsize` is `None`.
    #[non_exhaustive]
    Fixed {
        batchsize: Option<NonZeroUsize>,
        ntasks: NonZeroUsize,
    },
}
183
184impl Parallelism {
185 pub fn dynamic(batchsize: NonZeroUsize, ntasks: NonZeroUsize) -> Self {
189 Self::Dynamic { batchsize, ntasks }
190 }
191
192 pub fn fixed(batchsize: Option<NonZeroUsize>, ntasks: NonZeroUsize) -> Self {
197 Self::Fixed { batchsize, ntasks }
198 }
199
200 pub fn sequential(batchsize: NonZeroUsize) -> Self {
204 Self::Sequential { batchsize }
205 }
206}
207
/// Factory for progress trackers, parameterized by the total item count.
pub trait AsProgress {
    /// Creates a tracker that expects `max` items in total.
    fn as_progress(&self, max: usize) -> Arc<dyn Progress>;
}

/// Progress-reporting callbacks invoked while a build runs.
pub trait Progress: AsyncFriendly {
    /// Reports that `handled` additional items have completed.
    fn progress(&self, handled: usize);

    /// Signals that the whole build has finished.
    fn finish(&self);
}
224
225pub fn build<B>(
229 builder: Arc<B>,
230 parallelism: Parallelism,
231 runtime: &tokio::runtime::Runtime,
232) -> anyhow::Result<BuildResults<B::Output>>
233where
234 B: Build,
235{
236 build_tracked(builder, parallelism, runtime, None)
237}
238
239pub fn build_tracked<B>(
247 builder: Arc<B>,
248 parallelism: Parallelism,
249 runtime: &tokio::runtime::Runtime,
250 as_progress: Option<&dyn AsProgress>,
251) -> anyhow::Result<BuildResults<B::Output>>
252where
253 B: Build,
254{
255 let max = builder.num_data();
256 let results = runtime.block_on(build_inner(
257 builder,
258 parallelism,
259 as_progress.map(|p| p.as_progress(max)),
260 ))?;
261 Ok(BuildResults::new(results))
262}
263
264fn build_inner(
270 build: Arc<dyn BuildInner>,
271 parallelism: Parallelism,
272 progress: Option<Arc<dyn Progress>>,
273) -> impl Future<Output = anyhow::Result<BuildResultsInner>> + Send {
274 match parallelism {
275 Parallelism::Dynamic { batchsize, ntasks } => {
276 boxit(build_inner_dynamic(build, batchsize, ntasks, progress))
277 }
278 Parallelism::Sequential { batchsize } => {
279 boxit(build_inner_dynamic(
282 build,
283 batchsize,
284 diskann::utils::ONE,
285 progress,
286 ))
287 }
288 Parallelism::Fixed { batchsize, ntasks } => {
289 boxit(build_inner_fixed(build, batchsize, ntasks, progress))
290 }
291 }
292}
293
/// A boxed, pinned, `Send` future — the object-safe return type used by
/// [`BuildInner::build`].
type Pinned<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;

/// Object-safe, type-erased counterpart of [`Build`], letting the driver
/// functions operate on `Arc<dyn BuildInner>` regardless of the concrete
/// output type.
trait BuildInner: AsyncFriendly {
    /// See [`Build::num_data`].
    fn num_data(&self) -> usize;

    /// See [`Build::build`]; the output is erased to `Box<dyn Any + Send>`
    /// and recovered later by downcasting in `BuildResults::new`.
    fn build(&self, range: Range<usize>) -> Pinned<'_, ANNResult<Box<dyn Any + Send>>>;
}
302
303impl<T> BuildInner for T
304where
305 T: Build,
306{
307 fn num_data(&self) -> usize {
308 <T as Build>::num_data(self)
309 }
310
311 fn build(&self, range: Range<usize>) -> Pinned<'_, ANNResult<Box<dyn Any + Send>>> {
312 use futures_util::TryFutureExt;
313
314 boxit(<T as Build>::build(self, range).map_ok(|r| -> Box<dyn Any + Send> { Box::new(r) }))
315 }
316}
317
/// Type-erased aggregate of a full build run, prior to downcasting.
#[derive(Debug)]
struct BuildResultsInner {
    /// Wall-clock time for the entire run.
    end_to_end_latency: MicroSeconds,

    /// One entry per task; each inner vector holds that task's batch
    /// results in the order the task completed them.
    task_results: Vec<Vec<BatchResultsInner>>,
}

/// Type-erased result of a single batch; the `output` is downcast back to
/// the concrete type in `BuildResults::new`.
#[derive(Debug)]
struct BatchResultsInner {
    batch: Range<usize>,
    latency: MicroSeconds,
    output: Box<dyn Any + Send>,
}
336
337async fn build_inner_dynamic(
343 build: Arc<dyn BuildInner>,
344 batchsize: NonZeroUsize,
345 ntasks: NonZeroUsize,
346 progress: Option<Arc<dyn Progress>>,
347) -> anyhow::Result<BuildResultsInner> {
348 let start = std::time::Instant::now();
349 let control = ControlBlock::new(build.num_data(), batchsize);
350 let handles: Vec<_> = (0..ntasks.get())
351 .map(|_| {
352 let build_clone = build.clone();
353 let control_clone = control.clone();
354 let progress_clone = progress.clone();
355 tokio::spawn(async move {
356 let mut results = Vec::new();
357 while let Some(batch) = control_clone.next() {
358 let start = std::time::Instant::now();
359 let output = build_clone.build(batch.clone()).await?;
360 let latency: MicroSeconds = start.elapsed().into();
361
362 if let Some(p) = progress_clone.as_deref() {
363 p.progress(batch.len());
364 }
365
366 results.push(BatchResultsInner {
367 batch,
368 latency,
369 output,
370 });
371 }
372 Ok::<_, ANNError>(results)
373 })
374 })
375 .collect();
376
377 let mut task_results = Vec::with_capacity(ntasks.into());
378 for h in handles {
379 task_results.push(h.await??);
380 }
381
382 let end_to_end_latency: MicroSeconds = start.elapsed().into();
383 if let Some(p) = progress.as_deref() {
384 p.finish();
385 }
386
387 Ok(BuildResultsInner {
388 end_to_end_latency,
389 task_results,
390 })
391}
392
/// Cloneable handle to a shared batch dispenser: each call to [`Self::next`]
/// claims a disjoint batch of at most `batchsize` items from `0..max`.
#[derive(Debug, Clone)]
struct ControlBlock(Arc<ControlBlockInner>);

impl ControlBlock {
    fn new(max: usize, batchsize: NonZeroUsize) -> Self {
        Self(Arc::new(ControlBlockInner::new(max, batchsize)))
    }

    /// Atomically claims the next batch, or returns `None` once all of
    /// `0..max` has been handed out. Safe to call concurrently: claimed
    /// ranges are disjoint and together cover the whole input exactly once.
    fn next(&self) -> Option<Range<usize>> {
        let inner = &self.0;
        let step = inner.batchsize.get();

        // `fetch_update` retries the closure on contention, exactly like a
        // manual compare-exchange loop. Returning `None` from the closure
        // (head already at `max`) surfaces as `Err` and means "exhausted".
        inner
            .head
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |start| {
                // Saturate so a huge batchsize near `usize::MAX` cannot overflow.
                let next = start.saturating_add(step).min(inner.max);
                (next != start).then_some(next)
            })
            .ok()
            .map(|start| start..start.saturating_add(step).min(inner.max))
    }
}

/// Shared state behind [`ControlBlock`]: the next unclaimed index plus the
/// fixed bounds.
#[derive(Debug)]
struct ControlBlockInner {
    head: AtomicUsize,
    max: usize,
    batchsize: NonZeroUsize,
}

impl ControlBlockInner {
    fn new(max: usize, batchsize: NonZeroUsize) -> Self {
        Self {
            head: AtomicUsize::new(0),
            max,
            batchsize,
        }
    }
}
444
/// Fixed-partition strategy: the data is split evenly into `ntasks`
/// contiguous partitions up front (via `PartitionIter`), with one spawned
/// task per partition. Each task either walks its partition in `batchsize`
/// chunks or, when `batchsize` is `None`, builds the whole partition as a
/// single batch.
async fn build_inner_fixed(
    build: Arc<dyn BuildInner>,
    batchsize: Option<NonZeroUsize>,
    ntasks: NonZeroUsize,
    progress: Option<Arc<dyn Progress>>,
) -> anyhow::Result<BuildResultsInner> {
    use diskann::utils::async_tools::PartitionIter;

    let start = std::time::Instant::now();
    // `collect()` forces the spawns to happen eagerly, so all tasks run
    // concurrently before we start joining them below.
    let handles: Vec<_> = PartitionIter::new(build.num_data(), ntasks)
        .map(|range| {
            let build_clone = build.clone();
            let progress_clone = progress.clone();
            tokio::spawn(async move {
                let mut results = Vec::new();
                match batchsize {
                    Some(batchsize) => {
                        // Sub-divide this task's partition into fixed-size chunks.
                        for batch in Chunks::new(range, batchsize) {
                            let start = std::time::Instant::now();
                            let output = build_clone.build(batch.clone()).await?;
                            let latency: MicroSeconds = start.elapsed().into();

                            if let Some(p) = progress_clone.as_deref() {
                                p.progress(batch.len());
                            }

                            results.push(BatchResultsInner {
                                batch,
                                latency,
                                output,
                            });
                        }
                    }
                    None => {
                        // Whole partition as a single batch.
                        let start = std::time::Instant::now();
                        let output = build_clone.build(range.clone()).await?;
                        let latency: MicroSeconds = start.elapsed().into();

                        if let Some(p) = progress_clone.as_deref() {
                            p.progress(range.len());
                        }

                        results.push(BatchResultsInner {
                            batch: range,
                            latency,
                            output,
                        });
                    }
                }
                Ok::<_, ANNError>(results)
            })
        })
        .collect();

    // Join in spawn order; `??` surfaces both join errors and build errors.
    let mut task_results = Vec::with_capacity(ntasks.into());
    for h in handles {
        task_results.push(h.await??);
    }

    let end_to_end_latency: MicroSeconds = start.elapsed().into();
    if let Some(p) = progress.as_deref() {
        p.finish();
    }

    Ok(BuildResultsInner {
        end_to_end_latency,
        task_results,
    })
}
518
/// An iterator over contiguous, non-overlapping sub-ranges ("chunks") of a
/// `Range<usize>`. Every chunk is exactly `chunk_size` long except possibly
/// the final one, which may be shorter.
#[derive(Debug, Clone)]
struct Chunks {
    /// Start of the next chunk to yield.
    current: usize,
    /// Exclusive upper bound of the overall range.
    end: usize,
    /// Maximum length of each yielded chunk.
    chunk_size: NonZeroUsize,
}

impl Chunks {
    fn new(range: Range<usize>, chunk_size: NonZeroUsize) -> Self {
        Self {
            current: range.start,
            end: range.end,
            chunk_size,
        }
    }
}

impl Iterator for Chunks {
    type Item = Range<usize>;

    fn next(&mut self) -> Option<Self::Item> {
        if self.current >= self.end {
            return None;
        }

        let start = self.current;
        // BUGFIX: saturate instead of `start + chunk_size`, which could
        // overflow (debug panic / release wrap) for ranges near
        // `usize::MAX`. This mirrors `ControlBlock::next`.
        let end = start.saturating_add(self.chunk_size.get()).min(self.end);
        self.current = end;

        Some(start..end)
    }

    /// Exact remaining chunk count: ceil(remaining / chunk_size).
    fn size_hint(&self) -> (usize, Option<usize>) {
        if self.current >= self.end {
            return (0, Some(0));
        }

        let remaining = self.end - self.current;
        let count = remaining.div_ceil(self.chunk_size.get());
        (count, Some(count))
    }
}

// `size_hint` is exact, so `len()` is free.
impl ExactSizeIterator for Chunks {}
567
568#[cfg(test)]
573mod tests {
574 use super::*;
575
576 use std::sync::atomic::AtomicBool;
577
578 #[test]
583 fn test_batch_result_batchsize() {
584 let result = BatchResult {
585 taskid: 0,
586 batch: 10..25,
587 latency: MicroSeconds::new(1000),
588 output: "test",
589 };
590 assert_eq!(result.batchsize(), 15);
591
592 let empty_result = BatchResult {
593 taskid: 1,
594 batch: 5..5,
595 latency: MicroSeconds::new(0),
596 output: 42,
597 };
598 assert_eq!(empty_result.batchsize(), 0);
599 }
600
601 #[test]
602 fn test_build_results_accessors() {
603 let batch1 = BatchResult {
604 taskid: 0,
605 batch: 0..10,
606 latency: MicroSeconds::new(100),
607 output: "first",
608 };
609 let batch2 = BatchResult {
610 taskid: 1,
611 batch: 10..20,
612 latency: MicroSeconds::new(200),
613 output: "second",
614 };
615
616 let results = BuildResults {
617 output: vec![batch1, batch2],
618 end_to_end_latency: MicroSeconds::new(500),
619 };
620
621 assert_eq!(results.end_to_end_latency(), MicroSeconds::new(500));
622 assert_eq!(results.output().len(), 2);
623 assert_eq!(results.output()[0].output, "first");
624 assert_eq!(results.output()[1].output, "second");
625
626 let output = results.take_output();
627 assert_eq!(output.len(), 2);
628 assert_eq!(output[0].output, "first");
629 assert_eq!(output[1].output, "second");
630 }
631
632 fn sort_ranges(x: &Range<usize>, y: &Range<usize>) -> std::cmp::Ordering {
637 x.start.cmp(&y.start)
638 }
639
640 fn check_ranges(x: &mut [Range<usize>], total: usize) {
641 x.sort_by(sort_ranges);
642 let mut expected_start = 0;
643 for r in x {
644 assert_eq!(r.start, expected_start);
645 expected_start = r.end;
646 }
647 assert_eq!(expected_start, total);
648 }
649
650 fn collect_all_ranges(control: &ControlBlock) -> Vec<Range<usize>> {
652 let mut ranges = Vec::new();
653 while let Some(range) = control.next() {
654 ranges.push(range);
655 }
656 ranges
657 }
658
    /// Exercises `ControlBlock` over a table of (max, batchsize) edge cases,
    /// verifying batch count, full coverage, and exhaustion behavior.
    #[test]
    fn test_control_block() {
        // (max, batchsize, description)
        let test_cases: &[(usize, usize, &str)] = &[
            (10, 3, "not evenly divisible"),
            (9, 3, "exact multiple of batchsize"),
            (0, 5, "empty range"),
            (1, 1, "single element"),
            (3, 10, "batchsize larger than max"),
            (5, 5, "batchsize equals max"),
            (5, 1, "batchsize one (sequential)"),
            (10000, 128, "larger range"),
            (usize::MAX, usize::MAX / 2 - 1, "very large numbers"),
        ];

        for &(max, batchsize, desc) in test_cases {
            let control = ControlBlock::new(max, NonZeroUsize::new(batchsize).unwrap());
            let mut ranges = collect_all_ranges(&control);
            let expected_num_ranges = max.div_ceil(batchsize);

            assert_eq!(
                ranges.len(),
                expected_num_ranges,
                "{desc}: max={max}, batchsize={batchsize}: expected {expected_num_ranges} ranges, got {}",
                ranges.len()
            );
            check_ranges(&mut ranges, max);
            // Once exhausted, the control block must keep returning `None`.
            for _ in 1..3 {
                assert!(control.next().is_none(), "{desc}: expected no more ranges");
            }
        }
    }

    /// Races several threads against one `ControlBlock` and checks the union
    /// of all claimed ranges still tiles the full input exactly once.
    #[test]
    fn concurrent_access_yields_disjoint_complete_ranges() {
        let max = 10000;
        let control = ControlBlock::new(max, NonZeroUsize::new(7).unwrap());
        let num_threads = 4;

        // Barrier maximizes contention by releasing all threads at once.
        let barrier = std::sync::Barrier::new(num_threads);
        let mut all_ranges = std::thread::scope(|s| {
            let handles: Vec<_> = (0..num_threads)
                .map(|_| {
                    s.spawn(|| {
                        barrier.wait();
                        collect_all_ranges(&control.clone())
                    })
                })
                .collect();

            handles
                .into_iter()
                .flat_map(|h| h.join().unwrap())
                .collect::<Vec<_>>()
        });

        check_ranges(&mut all_ranges, max);
    }
717
    /// Table-driven check of `Chunks` output against exact expected ranges.
    #[test]
    fn test_chunks_basic() {
        #[expect(
            clippy::single_range_in_vec_init,
            reason = "these are test cases - sometimes we do need an array of a simgle element range"
        )]
        // (input range, chunk_size, expected chunks)
        let test_cases: &[(_, _, &[_])] = &[
            (0..9, 3, &[0..3, 3..6, 6..9]),
            (0..10, 3, &[0..3, 3..6, 6..9, 9..10]),
            (0..5, 5, &[0..5]),
            (0..3, 10, &[0..3]),
            (0..1, 1, &[0..1]),
            (0..1, 5, &[0..1]),
            (0..0, 3, &[]),
            (5..15, 3, &[5..8, 8..11, 11..14, 14..15]),
            (10..16, 2, &[10..12, 12..14, 14..16]),
        ];

        for (range, chunk_size, expected) in test_cases {
            let chunks: Vec<_> = Chunks::new(range.clone(), nz(*chunk_size)).collect();
            assert_eq!(
                &chunks, expected,
                "Chunks::new({:?}, {}) produced {:?}, expected {:?}",
                range, chunk_size, chunks, expected
            );
        }
    }

    /// `size_hint`/`len` stay exact as the iterator is consumed.
    #[test]
    fn test_chunks_size_hint() {
        let mut chunks = Chunks::new(0..10, nz(3));

        assert_eq!(chunks.size_hint(), (4, Some(4)));
        assert_eq!(chunks.len(), 4);

        chunks.next(); assert_eq!(chunks.size_hint(), (3, Some(3)));
        assert_eq!(chunks.len(), 3);

        chunks.next(); assert_eq!(chunks.size_hint(), (2, Some(2)));

        chunks.next(); assert_eq!(chunks.size_hint(), (1, Some(1)));

        chunks.next(); assert_eq!(chunks.size_hint(), (0, Some(0)));
        assert_eq!(chunks.len(), 0);

        assert!(chunks.next().is_none());
        assert_eq!(chunks.size_hint(), (0, Some(0)));
    }

    /// Empty input ranges yield no chunks, regardless of start position.
    #[test]
    fn test_chunks_empty_range() {
        let chunks: Vec<_> = Chunks::new(0..0, nz(5)).collect();
        assert!(chunks.is_empty());

        let chunks: Vec<_> = Chunks::new(10..10, nz(3)).collect();
        assert!(chunks.is_empty());
    }
795
    /// Property check: chunks are contiguous, non-empty, cover the whole
    /// range, and only the final chunk may be shorter than `chunk_size`.
    #[test]
    fn test_chunks_covers_entire_range() {
        // (input range, chunk_size)
        let test_cases: &[(Range<usize>, usize)] = &[
            (0..100, 7),
            (0..1000, 13),
            (50..150, 11),
            (0..1, 1),
            (0..17, 17),
            (0..17, 18),
        ];

        for (range, chunk_size) in test_cases {
            let chunks: Vec<_> = Chunks::new(range.clone(), nz(*chunk_size)).collect();

            // Contiguity: each chunk starts where the previous one ended.
            let mut expected_start = range.start;
            for chunk in &chunks {
                assert_eq!(
                    chunk.start, expected_start,
                    "Gap detected at {} (expected {})",
                    chunk.start, expected_start
                );
                assert!(chunk.end > chunk.start, "Empty chunk detected: {:?}", chunk);
                expected_start = chunk.end;
            }
            assert_eq!(expected_start, range.end, "Chunks don't cover entire range");

            // Size invariant: all chunks full-size except possibly the last.
            for (i, chunk) in chunks.iter().enumerate() {
                if i < chunks.len() - 1 {
                    assert_eq!(chunk.len(), *chunk_size, "Non-final chunk has wrong size");
                } else {
                    assert!(
                        chunk.len() <= *chunk_size,
                        "Final chunk is larger than chunk_size"
                    );
                }
            }
        }
    }

    /// Sanity check on a large, evenly-divisible range.
    #[test]
    fn test_chunks_large_range() {
        let range = 0..1_000_000;
        let chunk_size = 1000;
        let chunks: Vec<_> = Chunks::new(range.clone(), nz(chunk_size)).collect();

        assert_eq!(chunks.len(), 1000);
        assert_eq!(chunks.first(), Some(&(0..1000)));
        assert_eq!(chunks.last(), Some(&(999_000..1_000_000)));
    }
849
    /// Convenience constructor for `NonZeroUsize` test values; panics on 0.
    fn nz(n: usize) -> NonZeroUsize {
        NonZeroUsize::new(n).unwrap()
    }
858
    /// Minimal `Build` implementor: echoes each requested range back as its
    /// output, so tests can verify batch routing without doing real work.
    struct MockBuild {
        // Total item count reported through `Build::num_data`.
        num_data: usize,
    }

    impl MockBuild {
        fn new(num_data: usize) -> Self {
            Self { num_data }
        }
    }

    impl Build for MockBuild {
        type Output = Range<usize>;

        fn num_data(&self) -> usize {
            self.num_data
        }

        // The output *is* the input range, letting assertions compare
        // `BatchResult::output` against `BatchResult::batch`.
        async fn build(&self, range: Range<usize>) -> ANNResult<Self::Output> {
            Ok(range)
        }
    }
881
    /// Thread-safe `Progress` stub that tallies reported items and records
    /// whether `finish` was invoked.
    struct MockProgress {
        // Running sum of all `progress(handled)` calls.
        total_handled: AtomicUsize,
        // Set once `finish` has been called.
        finish_called: AtomicBool,
    }

    impl MockProgress {
        fn new() -> Self {
            Self {
                total_handled: AtomicUsize::new(0),
                finish_called: AtomicBool::new(false),
            }
        }

        fn total_handled(&self) -> usize {
            self.total_handled.load(Ordering::Relaxed)
        }

        fn was_finished(&self) -> bool {
            self.finish_called.load(Ordering::Relaxed)
        }
    }

    impl Progress for MockProgress {
        fn progress(&self, handled: usize) {
            self.total_handled.fetch_add(handled, Ordering::Relaxed);
        }

        fn finish(&self) {
            self.finish_called.store(true, Ordering::Relaxed);
        }
    }
914
    /// `AsProgress` stub that hands out a shared `MockProgress` and records
    /// the `max` value it was asked to track.
    struct MockAsProgress {
        // The tracker returned from every `as_progress` call.
        progress: Arc<MockProgress>,
        // Last `max` passed to `as_progress` (0 until called).
        expected_max: AtomicUsize,
    }

    impl MockAsProgress {
        fn new() -> Self {
            Self {
                progress: Arc::new(MockProgress::new()),
                expected_max: AtomicUsize::new(0),
            }
        }

        fn progress(&self) -> &Arc<MockProgress> {
            &self.progress
        }

        fn received_max(&self) -> usize {
            self.expected_max.load(Ordering::Relaxed)
        }
    }

    impl AsProgress for MockAsProgress {
        fn as_progress(&self, max: usize) -> Arc<dyn Progress> {
            self.expected_max.store(max, Ordering::Relaxed);
            self.progress.clone()
        }
    }
944
    /// End-to-end exercise of `build` / `build_tracked` across all
    /// `Parallelism` strategies, checking batch counts, range coverage,
    /// taskid bounds, and progress reporting.
    #[test]
    fn test_build() {
        // (runtime threads, num_data, parallelism, description)
        let test_cases: &[(usize, usize, Parallelism, &str)] = &[
            (
                4,
                100,
                Parallelism::dynamic(nz(10), nz(4)),
                "basic multi-task",
            ),
            (1, 50, Parallelism::dynamic(nz(10), nz(1)), "single task"),
            (4, 0, Parallelism::dynamic(nz(10), nz(4)), "empty data"),
            (
                4,
                5,
                Parallelism::dynamic(nz(100), nz(4)),
                "batchsize larger than data",
            ),
            (2, 20, Parallelism::dynamic(nz(5), nz(2)), "small dataset"),
            (
                8,
                1000,
                Parallelism::dynamic(nz(7), nz(8)),
                "larger dataset with odd batchsize",
            ),
            (
                4,
                100,
                Parallelism::dynamic(nz(10), nz(1)),
                "multiple threads but single task",
            ),
            (
                2,
                50,
                Parallelism::sequential(nz(10)),
                "sequential execution",
            ),
            (
                4,
                100,
                Parallelism::fixed(Some(nz(10)), nz(4)),
                "fixed with batchsize",
            ),
            (
                4,
                100,
                Parallelism::fixed(None, nz(4)),
                "fixed without batchsize (whole partition per task)",
            ),
            (
                2,
                50,
                Parallelism::fixed(Some(nz(5)), nz(2)),
                "fixed with small batchsize",
            ),
            (
                8,
                1000,
                Parallelism::fixed(Some(nz(100)), nz(8)),
                "fixed larger dataset",
            ),
            (
                4,
                0,
                Parallelism::fixed(Some(nz(10)), nz(4)),
                "fixed empty data",
            ),
            (
                4,
                5,
                Parallelism::fixed(Some(nz(100)), nz(4)),
                "fixed batchsize larger than partition",
            ),
            (
                1,
                50,
                Parallelism::fixed(Some(nz(10)), nz(1)),
                "fixed single task with batchsize",
            ),
            (
                1,
                50,
                Parallelism::fixed(None, nz(1)),
                "fixed single task without batchsize",
            ),
            (
                4,
                7,
                Parallelism::fixed(Some(nz(2)), nz(4)),
                "fixed uneven partition with batchsize",
            ),
        ];

        for (num_threads, num_data, parallelism, desc) in test_cases {
            let num_data = *num_data;
            let runtime = crate::tokio::runtime(*num_threads).unwrap();

            // Derive the expected task count and total batch count from the
            // strategy, mirroring the driver's own partitioning logic.
            let (ntasks, expected_batches) = match parallelism {
                Parallelism::Dynamic { batchsize, ntasks } => {
                    let expected = num_data.div_ceil(batchsize.get());
                    (*ntasks, expected)
                }
                Parallelism::Sequential { batchsize } => {
                    let expected = num_data.div_ceil(batchsize.get());
                    (nz(1), expected)
                }
                Parallelism::Fixed { batchsize, ntasks } => {
                    use diskann::utils::async_tools::PartitionIter;
                    let expected: usize = PartitionIter::new(num_data, *ntasks)
                        .map(|partition| match batchsize {
                            Some(bs) => partition.len().div_ceil(bs.get()),
                            None => {
                                if partition.is_empty() {
                                    0
                                } else {
                                    1
                                }
                            }
                        })
                        .sum();
                    (*ntasks, expected)
                }
            };

            let builder = Arc::new(MockBuild::new(num_data));
            let mock_as_progress = MockAsProgress::new();

            // Shared assertions for both the tracked and untracked runs.
            let check_results = |results: BuildResults<Range<usize>>| {
                if num_data == 0 {
                    assert!(
                        results.output().is_empty(),
                        "{desc}: no batches for empty data"
                    );
                    return;
                }

                for batch_result in results.output() {
                    assert_eq!(
                        batch_result.output, batch_result.batch,
                        "{desc}: output range should match batch range"
                    );
                    assert!(
                        batch_result.taskid < ntasks.get(),
                        "{desc}: taskid {} should be less than ntasks {}",
                        batch_result.taskid,
                        ntasks.get()
                    );
                }

                assert_eq!(
                    results.output().len(),
                    expected_batches,
                    "{desc}: expected {expected_batches} batches, got {}",
                    results.output().len()
                );

                // The union of all batches must tile `0..num_data` exactly.
                let mut ranges: Vec<_> = results.output().iter().map(|r| r.batch.clone()).collect();
                check_ranges(&mut ranges, num_data);
            };

            let results = build_tracked(
                builder.clone(),
                *parallelism,
                &runtime,
                Some(&mock_as_progress),
            )
            .unwrap_or_else(|_| panic!("{desc}: build_tracked should succeed"));

            // Progress plumbing: sized by num_data, fully reported, finished.
            assert_eq!(
                mock_as_progress.received_max(),
                num_data,
                "{desc}: as_progress should receive num_data as max"
            );
            assert_eq!(
                mock_as_progress.progress().total_handled(),
                num_data,
                "{desc}: total progress should equal num_data"
            );
            assert!(
                mock_as_progress.progress().was_finished(),
                "{desc}: finish should be called"
            );

            check_results(results);

            // The untracked entry point must produce equivalent results.
            let results = build(builder, *parallelism, &runtime)
                .unwrap_or_else(|_| panic!("{desc}: build should succeed"));
            check_results(results);
        }
    }
1143}