// diskann_benchmark_core/build/api.rs
1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use std::{
7    any::Any,
8    future::Future,
9    num::NonZeroUsize,
10    ops::Range,
11    pin::Pin,
12    sync::{
13        Arc,
14        atomic::{AtomicUsize, Ordering},
15    },
16};
17
18use diskann::{ANNError, ANNResult};
19use diskann_benchmark_runner::utils::MicroSeconds;
20use diskann_utils::future::{AsyncFriendly, boxit};
21
22/// The core build API.
23///
24/// This uses a model where the data over which the index build is stored internally and
25/// identified by its index. Data is numbered from `0` to `N - 1` where `N = Build::num_data()`
26/// is the total number of data points.
27///
28/// The trait is used in conjunction with [`build`] and [`build_tracked`]. See the documentation
29/// of those methods for more details.
pub trait Build: AsyncFriendly {
    /// Custom output parameters. This augments the standard metrics collected by [`build`] and
    /// allows implementation-specific data to be returned.
    type Output: AsyncFriendly;

    /// Return the number of data points to build the index over. The machinery in [`build`] and
    /// [`build_tracked`] will partition the range `0..num_data()` into disjoint ranges and call
    /// [`Build::build`] on each range in an unspecified order.
    fn num_data(&self) -> usize;

    /// Insert the data points specified by the range. Implementations may assume that the range is
    /// non-empty, within `0..num_data()`, and disjoint from other ranges passed to concurrent calls
    /// while in [`build`] or [`build_tracked`].
    ///
    /// Multiple calls may be made in parallel.
    ///
    /// The returned future must be `Send`: the build machinery drives it from spawned
    /// tokio tasks.
    fn build(&self, range: Range<usize>) -> impl Future<Output = ANNResult<Self::Output>> + Send;
}
47
48/// The results of processing a single batch during build.
49///
50/// This struct is marked as `#[non_exhaustive]` to allow for future extension.
51///
52/// See: [`BuildResults`], [`build`] and [`build_tracked`].
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct BatchResult<T> {
    /// The index of the task that executed this batch. This will be in the range `0..ntasks` where
    /// `ntasks` is the number of tasks specified to [`build`].
    pub taskid: usize,

    /// The range of data points processed by this batch.
    ///
    /// Half-open: `batch.start` is included, `batch.end` is not.
    pub batch: Range<usize>,

    /// The wall clock time taken to process this batch.
    pub latency: MicroSeconds,

    /// The customized [`Build::Output`] for this batch.
    pub output: T,
}
69
impl<T> BatchResult<T> {
    /// Return the number of points in the batch associated with this result.
    ///
    /// This is the length of the half-open range [`BatchResult::batch`].
    pub fn batchsize(&self) -> usize {
        self.batch.len()
    }
}
76
77/// Aggregated results for a build operation.
78///
79/// See: [`build`] and [`build_tracked`].
#[derive(Debug)]
pub struct BuildResults<T> {
    /// Per-batch results, grouped by task id (all of task `0`'s batches, then task `1`'s, ...).
    output: Vec<BatchResult<T>>,
    /// Wall-clock time for the whole build operation.
    end_to_end_latency: MicroSeconds,
}
85
86impl<T> BuildResults<T> {
87    /// Return the total wall-clock time for the entire build operation.
88    pub fn end_to_end_latency(&self) -> MicroSeconds {
89        self.end_to_end_latency
90    }
91
92    /// Return the per-batch results by reference.
93    pub fn output(&self) -> &[BatchResult<T>] {
94        &self.output
95    }
96
97    /// Consume `self` and return the per-batch results by value.
98    pub fn take_output(self) -> Vec<BatchResult<T>> {
99        self.output
100    }
101}
102
103impl<T> BuildResults<T>
104where
105    T: Any,
106{
107    /// This is a private inner constructor that converts the type-erased `BuildResultsInner` into
108    /// a fully typed container.
109    ///
110    /// This requires that the dynamic type of the boxed [`Any`] outputs in `inner` is `T`.
111    fn new(inner: BuildResultsInner) -> Self {
112        let BuildResultsInner {
113            end_to_end_latency,
114            task_results,
115        } = inner;
116        let mut output = Vec::with_capacity(task_results.iter().map(|t| t.len()).sum());
117
118        task_results
119            .into_iter()
120            .enumerate()
121            .for_each(|(taskid, results)| {
122                results.into_iter().for_each(|r| {
123                    output.push(BatchResult {
124                        taskid,
125                        batch: r.batch,
126                        latency: r.latency,
127                        output: *r
128                            .output
129                            .downcast::<T>()
130                            .expect("incorrect downcast applied"),
131                    })
132                })
133            });
134
135        Self {
136            output,
137            end_to_end_latency,
138        }
139    }
140}
141
/// Control the parallel partitioning strategy for [`build`] and [`build_tracked`].
///
/// Many aspects of this enum are `#[non_exhaustive]` to allow for future extension.
/// Prefer the associated constructor functions over constructing variants directly.
#[derive(Debug, Clone, Copy, PartialEq)]
#[non_exhaustive]
pub enum Parallelism {
    /// Use dynamic load balancing to partition the work into batches of at most `batchsize`.
    /// When the batchsize is 1, the implementation guarantees sequential execution.
    ///
    /// The batches assigned to each task can be assumed to be monotonically increasing.
    ///
    /// See: [`Parallelism::dynamic`].
    #[non_exhaustive]
    Dynamic {
        batchsize: NonZeroUsize,
        ntasks: NonZeroUsize,
    },

    /// Run the build with just a single task. Input data is still batched.
    ///
    /// See: [`Parallelism::sequential`].
    #[non_exhaustive]
    Sequential { batchsize: NonZeroUsize },

    /// Create a fixed parallelism strategy with `ntasks` executors. This strategy
    /// partitions the problem space into roughly `ntasks` balanced contiguous chunks.
    ///
    /// If `batchsize` is `Some`, than each chunk will be further subdivided into at most
    /// `batchsize` sized subchunks which are then provided to [`Build::build`].
    ///
    /// If `batchsize` is `None`, then the entire task partition is supplied in a single call
    /// to [`Build::build`].
    ///
    /// See: [`Parallelism::fixed`].
    #[non_exhaustive]
    Fixed {
        batchsize: Option<NonZeroUsize>,
        ntasks: NonZeroUsize,
    },
}

impl Parallelism {
    /// Construct a [`Self::Sequential`] strategy processing batches of `batchsize`.
    pub fn sequential(batchsize: NonZeroUsize) -> Self {
        Self::Sequential { batchsize }
    }

    /// Construct a dynamically load-balanced strategy ([`Self::Dynamic`]) using `ntasks`
    /// workers and batches of at most `batchsize`.
    pub fn dynamic(batchsize: NonZeroUsize, ntasks: NonZeroUsize) -> Self {
        Self::Dynamic { batchsize, ntasks }
    }

    /// Construct a statically partitioned strategy ([`Self::Fixed`]) with `ntasks`
    /// executors and optional sub-partitioning into `batchsize`-sized chunks.
    pub fn fixed(batchsize: Option<NonZeroUsize>, ntasks: NonZeroUsize) -> Self {
        Self::Fixed { batchsize, ntasks }
    }
}
207
/// Enable lazy creation of a progress reporter for the long running build operation.
///
/// See: [`Progress`].
pub trait AsProgress {
    /// Construct a progress reporter for an operation consisting of `max` points.
    ///
    /// In [`build_tracked`], `max` is `Build::num_data()` for the builder being run.
    fn as_progress(&self, max: usize) -> Arc<dyn Progress>;
}
215
/// A simple progress reporter for long running operations.
pub trait Progress: AsyncFriendly {
    /// Indicate that `handled` points have been processed.
    ///
    /// The build machinery calls this from multiple concurrently running tasks, so
    /// implementations should accumulate counts in a thread-safe way.
    fn progress(&self, handled: usize);

    /// Indicate that the operation has finished.
    fn finish(&self);
}
224
/// Perform a build operation and return the results.
///
/// Equivalent to [`build_tracked`] with no progress reporter; see [`build_tracked`]
/// for more details.
pub fn build<B>(
    builder: Arc<B>,
    parallelism: Parallelism,
    runtime: &tokio::runtime::Runtime,
) -> anyhow::Result<BuildResults<B::Output>>
where
    B: Build,
{
    build_tracked(builder, parallelism, runtime, None)
}
238
239/// Perform a build operation.
240///
241/// Work will be performed by spawning `ntasks` concurrent tasks in the provided `runtime`.
242/// These tasks will partition the problem space `0..builder.num_data()` into batches according
243/// to the policy in `parallelism`.
244///
245/// If `as_progress` is provided, it will be used to create a progress reporter.
246pub fn build_tracked<B>(
247    builder: Arc<B>,
248    parallelism: Parallelism,
249    runtime: &tokio::runtime::Runtime,
250    as_progress: Option<&dyn AsProgress>,
251) -> anyhow::Result<BuildResults<B::Output>>
252where
253    B: Build,
254{
255    let max = builder.num_data();
256    let results = runtime.block_on(build_inner(
257        builder,
258        parallelism,
259        as_progress.map(|p| p.as_progress(max)),
260    ))?;
261    Ok(BuildResults::new(results))
262}
263
264///////////
265// Inner //
266///////////
267
268/// An inner build method with no generic parameters to reduce code-generation.
269fn build_inner(
270    build: Arc<dyn BuildInner>,
271    parallelism: Parallelism,
272    progress: Option<Arc<dyn Progress>>,
273) -> impl Future<Output = anyhow::Result<BuildResultsInner>> + Send {
274    match parallelism {
275        Parallelism::Dynamic { batchsize, ntasks } => {
276            boxit(build_inner_dynamic(build, batchsize, ntasks, progress))
277        }
278        Parallelism::Sequential { batchsize } => {
279            // Sequential is just dynamic with one task. The dynamic load balancer will ensure that batches
280            // are processed in order.
281            boxit(build_inner_dynamic(
282                build,
283                batchsize,
284                diskann::utils::ONE,
285                progress,
286            ))
287        }
288        Parallelism::Fixed { batchsize, ntasks } => {
289            boxit(build_inner_fixed(build, batchsize, ntasks, progress))
290        }
291    }
292}
293
/// Shorthand for a pinned, boxed, `Send` future whose borrows live for `'a`.
type Pinned<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
295
/// A dyn-compatible version of [`Build`] to reduce monomorphization bloat.
trait BuildInner: AsyncFriendly {
    /// See [`Build::num_data`].
    fn num_data(&self) -> usize;

    /// See [`Build::build`]. The output is type-erased to `Box<dyn Any + Send>`;
    /// [`BuildResults::new`] downcasts it back to the concrete `Build::Output`.
    fn build(&self, range: Range<usize>) -> Pinned<'_, ANNResult<Box<dyn Any + Send>>>;
}
302
303impl<T> BuildInner for T
304where
305    T: Build,
306{
307    fn num_data(&self) -> usize {
308        <T as Build>::num_data(self)
309    }
310
311    fn build(&self, range: Range<usize>) -> Pinned<'_, ANNResult<Box<dyn Any + Send>>> {
312        use futures_util::TryFutureExt;
313
314        boxit(<T as Build>::build(self, range).map_ok(|r| -> Box<dyn Any + Send> { Box::new(r) }))
315    }
316}
317
/// Type erased inner build results.
///
/// Produced by the inner build drivers and converted into a typed [`BuildResults`]
/// by [`BuildResults::new`].
#[derive(Debug)]
struct BuildResultsInner {
    /// Wall-clock time for the entire build.
    end_to_end_latency: MicroSeconds,

    /// This field has an implicit correspondence with the task-id.
    ///
    /// Index `0` corresponds to task `0`, index `1` to task `1` and so on.
    task_results: Vec<Vec<BatchResultsInner>>,
}
328
/// The type-erased result of processing a single batch.
#[derive(Debug)]
struct BatchResultsInner {
    /// The half-open range of data points this batch covered.
    batch: Range<usize>,
    /// Wall-clock time spent in the `build` call for this batch.
    latency: MicroSeconds,
    /// Note that this has dynamic type `Build::Output`.
    output: Box<dyn Any + Send>,
}
336
337//---------//
338// Dynamic //
339//---------//
340
/// The inner implementation for [`Parallelism::Dynamic`].
///
/// Spawns `ntasks` tokio tasks that repeatedly claim the next batch from a shared
/// [`ControlBlock`] until the range `0..build.num_data()` is exhausted.
async fn build_inner_dynamic(
    build: Arc<dyn BuildInner>,
    batchsize: NonZeroUsize,
    ntasks: NonZeroUsize,
    progress: Option<Arc<dyn Progress>>,
) -> anyhow::Result<BuildResultsInner> {
    let start = std::time::Instant::now();
    let control = ControlBlock::new(build.num_data(), batchsize);
    let handles: Vec<_> = (0..ntasks.get())
        .map(|_| {
            let build_clone = build.clone();
            let control_clone = control.clone();
            let progress_clone = progress.clone();
            tokio::spawn(async move {
                let mut results = Vec::new();
                // Dynamic load balancing: keep claiming batches until none remain.
                while let Some(batch) = control_clone.next() {
                    let start = std::time::Instant::now();
                    let output = build_clone.build(batch.clone()).await?;
                    let latency: MicroSeconds = start.elapsed().into();

                    // Report progress after each completed batch.
                    if let Some(p) = progress_clone.as_deref() {
                        p.progress(batch.len());
                    }

                    results.push(BatchResultsInner {
                        batch,
                        latency,
                        output,
                    });
                }
                Ok::<_, ANNError>(results)
            })
        })
        .collect();

    // Join in spawn order so `task_results[i]` corresponds to task `i`.
    let mut task_results = Vec::with_capacity(ntasks.into());
    for h in handles {
        task_results.push(h.await??);
    }

    let end_to_end_latency: MicroSeconds = start.elapsed().into();
    if let Some(p) = progress.as_deref() {
        p.finish();
    }

    Ok(BuildResultsInner {
        end_to_end_latency,
        task_results,
    })
}
392
/// A shared, thread-safe dispenser of contiguous work batches over `0..max`.
#[derive(Debug, Clone)]
struct ControlBlock(Arc<ControlBlockInner>);

impl ControlBlock {
    /// Create a control block covering `0..max` that hands out batches of at most
    /// `batchsize` points.
    fn new(max: usize, batchsize: NonZeroUsize) -> Self {
        Self(Arc::new(ControlBlockInner::new(max, batchsize)))
    }

    /// Claim the next batch, or return `None` once the whole range has been handed out.
    ///
    /// Multiple threads may race on the head pointer, so advancement uses an atomic
    /// read-modify-write (`fetch_update` is itself a compare-exchange loop); every
    /// caller therefore receives a disjoint range.
    fn next(&self) -> Option<Range<usize>> {
        let inner = &self.0;
        // `saturating_add` guards against overflow when `max` is near `usize::MAX`.
        let advance = |start: usize| start.saturating_add(inner.batchsize.get()).min(inner.max);

        inner
            .head
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |start| {
                let next = advance(start);
                // A zero-length step means all work has been dispensed.
                (next != start).then_some(next)
            })
            .ok()
            .map(|claimed| claimed..advance(claimed))
    }
}

/// Shared state behind [`ControlBlock`].
#[derive(Debug)]
struct ControlBlockInner {
    /// First index that has not yet been handed out.
    head: AtomicUsize,
    /// One past the last index to dispense.
    max: usize,
    /// Maximum number of points per batch.
    batchsize: NonZeroUsize,
}

impl ControlBlockInner {
    fn new(max: usize, batchsize: NonZeroUsize) -> Self {
        Self {
            head: AtomicUsize::new(0),
            max,
            batchsize,
        }
    }
}
444
445//-------//
446// Fixed //
447//-------//
448
449async fn build_inner_fixed(
450    build: Arc<dyn BuildInner>,
451    batchsize: Option<NonZeroUsize>,
452    ntasks: NonZeroUsize,
453    progress: Option<Arc<dyn Progress>>,
454) -> anyhow::Result<BuildResultsInner> {
455    use diskann::utils::async_tools::PartitionIter;
456
457    let start = std::time::Instant::now();
458    let handles: Vec<_> = PartitionIter::new(build.num_data(), ntasks)
459        .map(|range| {
460            let build_clone = build.clone();
461            let progress_clone = progress.clone();
462            tokio::spawn(async move {
463                let mut results = Vec::new();
464                match batchsize {
465                    Some(batchsize) => {
466                        for batch in Chunks::new(range, batchsize) {
467                            let start = std::time::Instant::now();
468                            let output = build_clone.build(batch.clone()).await?;
469                            let latency: MicroSeconds = start.elapsed().into();
470
471                            if let Some(p) = progress_clone.as_deref() {
472                                p.progress(batch.len());
473                            }
474
475                            results.push(BatchResultsInner {
476                                batch,
477                                latency,
478                                output,
479                            });
480                        }
481                    }
482                    None => {
483                        let start = std::time::Instant::now();
484                        let output = build_clone.build(range.clone()).await?;
485                        let latency: MicroSeconds = start.elapsed().into();
486
487                        if let Some(p) = progress_clone.as_deref() {
488                            p.progress(range.len());
489                        }
490
491                        results.push(BatchResultsInner {
492                            batch: range,
493                            latency,
494                            output,
495                        });
496                    }
497                }
498                Ok::<_, ANNError>(results)
499            })
500        })
501        .collect();
502
503    let mut task_results = Vec::with_capacity(ntasks.into());
504    for h in handles {
505        task_results.push(h.await??);
506    }
507
508    let end_to_end_latency: MicroSeconds = start.elapsed().into();
509    if let Some(p) = progress.as_deref() {
510        p.finish();
511    }
512
513    Ok(BuildResultsInner {
514        end_to_end_latency,
515        task_results,
516    })
517}
518
/// An iterator that partitions a [`Range<usize>`] into equal-sized sub-ranges.
///
/// Every yielded chunk covers exactly `chunk_size` points except possibly the final
/// chunk, which may be shorter.
#[derive(Debug, Clone)]
struct Chunks {
    /// The current position in the range.
    current: usize,
    /// The end of the range.
    end: usize,
    /// The size of each chunk (except possibly the last).
    chunk_size: NonZeroUsize,
}

impl Chunks {
    /// Create an iterator over `range` yielding sub-ranges of at most `chunk_size` points.
    fn new(range: Range<usize>, chunk_size: NonZeroUsize) -> Self {
        Self {
            current: range.start,
            end: range.end,
            chunk_size,
        }
    }
}

impl Iterator for Chunks {
    type Item = Range<usize>;

    fn next(&mut self) -> Option<Self::Item> {
        if self.current >= self.end {
            return None;
        }

        let start = self.current;
        // Use `saturating_add` so a range ending near `usize::MAX` cannot overflow
        // (`start + chunk_size` would panic in debug builds). This matches the
        // overflow handling in `ControlBlock::next`.
        let end = start.saturating_add(self.chunk_size.get()).min(self.end);
        self.current = end;

        Some(start..end)
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        if self.current >= self.end {
            return (0, Some(0));
        }

        // Number of remaining chunks, rounding up to account for a short final chunk.
        let remaining = self.end - self.current;
        let count = remaining.div_ceil(self.chunk_size.get());
        (count, Some(count))
    }
}

impl ExactSizeIterator for Chunks {}
567
568///////////
569// Tests //
570///////////
571
572#[cfg(test)]
573mod tests {
574    use super::*;
575
576    use std::sync::atomic::AtomicBool;
577
578    /////////////////////////////////
579    // BatchResult / BuildResults //
580    /////////////////////////////////
581
    #[test]
    fn test_batch_result_batchsize() {
        // `batchsize` is the length of the half-open `batch` range.
        let result = BatchResult {
            taskid: 0,
            batch: 10..25,
            latency: MicroSeconds::new(1000),
            output: "test",
        };
        assert_eq!(result.batchsize(), 15);

        // An empty range reports a batchsize of zero.
        let empty_result = BatchResult {
            taskid: 1,
            batch: 5..5,
            latency: MicroSeconds::new(0),
            output: 42,
        };
        assert_eq!(empty_result.batchsize(), 0);
    }
600
    #[test]
    fn test_build_results_accessors() {
        let batch1 = BatchResult {
            taskid: 0,
            batch: 0..10,
            latency: MicroSeconds::new(100),
            output: "first",
        };
        let batch2 = BatchResult {
            taskid: 1,
            batch: 10..20,
            latency: MicroSeconds::new(200),
            output: "second",
        };

        let results = BuildResults {
            output: vec![batch1, batch2],
            end_to_end_latency: MicroSeconds::new(500),
        };

        // Borrowing accessors leave `results` usable afterwards.
        assert_eq!(results.end_to_end_latency(), MicroSeconds::new(500));
        assert_eq!(results.output().len(), 2);
        assert_eq!(results.output()[0].output, "first");
        assert_eq!(results.output()[1].output, "second");

        // `take_output` consumes the container and yields the batches by value.
        let output = results.take_output();
        assert_eq!(output.len(), 2);
        assert_eq!(output[0].output, "first");
        assert_eq!(output[1].output, "second");
    }
631
632    ///////////////////
633    // Control Block //
634    ///////////////////
635
636    fn sort_ranges(x: &Range<usize>, y: &Range<usize>) -> std::cmp::Ordering {
637        x.start.cmp(&y.start)
638    }
639
640    fn check_ranges(x: &mut [Range<usize>], total: usize) {
641        x.sort_by(sort_ranges);
642        let mut expected_start = 0;
643        for r in x {
644            assert_eq!(r.start, expected_start);
645            expected_start = r.end;
646        }
647        assert_eq!(expected_start, total);
648    }
649
650    /// Helper to collect all ranges from a ControlBlock.
651    fn collect_all_ranges(control: &ControlBlock) -> Vec<Range<usize>> {
652        let mut ranges = Vec::new();
653        while let Some(range) = control.next() {
654            ranges.push(range);
655        }
656        ranges
657    }
658
    #[test]
    fn test_control_block() {
        // (max, batchsize, description)
        let test_cases: &[(usize, usize, &str)] = &[
            (10, 3, "not evenly divisible"),
            (9, 3, "exact multiple of batchsize"),
            (0, 5, "empty range"),
            (1, 1, "single element"),
            (3, 10, "batchsize larger than max"),
            (5, 5, "batchsize equals max"),
            (5, 1, "batchsize one (sequential)"),
            (10000, 128, "larger range"),
            (usize::MAX, usize::MAX / 2 - 1, "very large numbers"),
        ];

        for &(max, batchsize, desc) in test_cases {
            let control = ControlBlock::new(max, NonZeroUsize::new(batchsize).unwrap());
            let mut ranges = collect_all_ranges(&control);
            // Ceiling division: a short final batch still counts as one range.
            let expected_num_ranges = max.div_ceil(batchsize);

            assert_eq!(
                ranges.len(),
                expected_num_ranges,
                "{desc}: max={max}, batchsize={batchsize}: expected {expected_num_ranges} ranges, got {}",
                ranges.len()
            );
            check_ranges(&mut ranges, max);
            // Once exhausted, the control block must keep returning `None`.
            for _ in 1..3 {
                assert!(control.next().is_none(), "{desc}: expected no more ranges");
            }
        }
    }
691
    #[test]
    fn concurrent_access_yields_disjoint_complete_ranges() {
        let max = 10000;
        let control = ControlBlock::new(max, NonZeroUsize::new(7).unwrap());
        let num_threads = 4;

        // Release all threads simultaneously to maximize contention on the
        // control block's atomic head.
        let barrier = std::sync::Barrier::new(num_threads);
        let mut all_ranges = std::thread::scope(|s| {
            let handles: Vec<_> = (0..num_threads)
                .map(|_| {
                    s.spawn(|| {
                        barrier.wait();
                        collect_all_ranges(&control.clone())
                    })
                })
                .collect();

            handles
                .into_iter()
                .flat_map(|h| h.join().unwrap())
                .collect::<Vec<_>>()
        });

        // Every point in `0..max` must be covered exactly once across all threads.
        check_ranges(&mut all_ranges, max);
    }
717
718    ////////////
719    // Chunks //
720    ////////////
721
    #[test]
    fn test_chunks_basic() {
        // Basic cases: (range, chunk_size, expected_chunks)
        #[expect(
            clippy::single_range_in_vec_init,
            reason = "these are test cases - sometimes we do need an array of a simgle element range"
        )]
        let test_cases: &[(_, _, &[_])] = &[
            // Evenly divisible
            (0..9, 3, &[0..3, 3..6, 6..9]),
            // Not evenly divisible - last chunk is smaller
            (0..10, 3, &[0..3, 3..6, 6..9, 9..10]),
            // Chunk size equals range length
            (0..5, 5, &[0..5]),
            // Chunk size larger than range length
            (0..3, 10, &[0..3]),
            // Single element
            (0..1, 1, &[0..1]),
            // Single element with larger chunk size
            (0..1, 5, &[0..1]),
            // Empty range
            (0..0, 3, &[]),
            // Non-zero start
            (5..15, 3, &[5..8, 8..11, 11..14, 14..15]),
            // Non-zero start, evenly divisible
            (10..16, 2, &[10..12, 12..14, 14..16]),
        ];

        // Each case must produce exactly the expected chunk sequence, in order.
        for (range, chunk_size, expected) in test_cases {
            let chunks: Vec<_> = Chunks::new(range.clone(), nz(*chunk_size)).collect();
            assert_eq!(
                &chunks, expected,
                "Chunks::new({:?}, {}) produced {:?}, expected {:?}",
                range, chunk_size, chunks, expected
            );
        }
    }
759
    #[test]
    fn test_chunks_size_hint() {
        // `size_hint` must stay exact (lower == upper) as the iterator advances.
        let mut chunks = Chunks::new(0..10, nz(3));

        assert_eq!(chunks.size_hint(), (4, Some(4)));
        assert_eq!(chunks.len(), 4);

        chunks.next(); // consume 0..3
        assert_eq!(chunks.size_hint(), (3, Some(3)));
        assert_eq!(chunks.len(), 3);

        chunks.next(); // consume 3..6
        assert_eq!(chunks.size_hint(), (2, Some(2)));

        chunks.next(); // consume 6..9
        assert_eq!(chunks.size_hint(), (1, Some(1)));

        chunks.next(); // consume 9..10
        assert_eq!(chunks.size_hint(), (0, Some(0)));
        assert_eq!(chunks.len(), 0);

        // After exhaustion
        assert!(chunks.next().is_none());
        assert_eq!(chunks.size_hint(), (0, Some(0)));
    }
786
787    #[test]
788    fn test_chunks_empty_range() {
789        let chunks: Vec<_> = Chunks::new(0..0, nz(5)).collect();
790        assert!(chunks.is_empty());
791
792        let chunks: Vec<_> = Chunks::new(10..10, nz(3)).collect();
793        assert!(chunks.is_empty());
794    }
795
    #[test]
    fn test_chunks_covers_entire_range() {
        // Verify that chunks cover the entire range without gaps or overlaps
        let test_cases: &[(Range<usize>, usize)] = &[
            (0..100, 7),
            (0..1000, 13),
            (50..150, 11),
            (0..1, 1),
            (0..17, 17),
            (0..17, 18),
        ];

        for (range, chunk_size) in test_cases {
            let chunks: Vec<_> = Chunks::new(range.clone(), nz(*chunk_size)).collect();

            // Verify no gaps and no overlaps: each chunk must begin where the
            // previous one ended.
            let mut expected_start = range.start;
            for chunk in &chunks {
                assert_eq!(
                    chunk.start, expected_start,
                    "Gap detected at {} (expected {})",
                    chunk.start, expected_start
                );
                assert!(chunk.end > chunk.start, "Empty chunk detected: {:?}", chunk);
                expected_start = chunk.end;
            }
            assert_eq!(expected_start, range.end, "Chunks don't cover entire range");

            // Verify chunk sizes: only the final chunk may be short.
            for (i, chunk) in chunks.iter().enumerate() {
                if i < chunks.len() - 1 {
                    assert_eq!(chunk.len(), *chunk_size, "Non-final chunk has wrong size");
                } else {
                    assert!(
                        chunk.len() <= *chunk_size,
                        "Final chunk is larger than chunk_size"
                    );
                }
            }
        }
    }
837
    #[test]
    fn test_chunks_large_range() {
        // Test with a large range to ensure no overflow issues
        let range = 0..1_000_000;
        let chunk_size = 1000;
        let chunks: Vec<_> = Chunks::new(range.clone(), nz(chunk_size)).collect();

        // 1_000_000 / 1000 chunks; spot-check the first and last.
        assert_eq!(chunks.len(), 1000);
        assert_eq!(chunks.first(), Some(&(0..1000)));
        assert_eq!(chunks.last(), Some(&(999_000..1_000_000)));
    }
849
850    ///////////////////////////
851    // Build / Build Tracked //
852    ///////////////////////////
853
    /// Helper to construct a `NonZeroUsize` from a `usize` in tests.
    ///
    /// Panics if `n` is zero.
    fn nz(n: usize) -> NonZeroUsize {
        NonZeroUsize::new(n).unwrap()
    }
858
    /// A mock implementation of [`Build`] that returns the range it was called with.
    struct MockBuild {
        // Total number of points reported by `num_data`.
        num_data: usize,
    }

    impl MockBuild {
        /// Create a mock builder over `num_data` points.
        fn new(num_data: usize) -> Self {
            Self { num_data }
        }
    }
869
    impl Build for MockBuild {
        type Output = Range<usize>;

        fn num_data(&self) -> usize {
            self.num_data
        }

        /// Echo the requested range back as the "build output" so tests can
        /// verify which ranges were dispatched.
        async fn build(&self, range: Range<usize>) -> ANNResult<Self::Output> {
            Ok(range)
        }
    }
881
    /// A mock implementation of [`Progress`] that tracks calls.
    struct MockProgress {
        // Running sum of all `progress(handled)` calls.
        total_handled: AtomicUsize,
        // Set once `finish` has been called.
        finish_called: AtomicBool,
    }

    impl MockProgress {
        fn new() -> Self {
            Self {
                total_handled: AtomicUsize::new(0),
                finish_called: AtomicBool::new(false),
            }
        }

        /// Total number of points reported via `progress`.
        fn total_handled(&self) -> usize {
            self.total_handled.load(Ordering::Relaxed)
        }

        /// Whether `finish` has been called.
        fn was_finished(&self) -> bool {
            self.finish_called.load(Ordering::Relaxed)
        }
    }
904
    impl Progress for MockProgress {
        // Atomic accumulation: `progress` may be invoked from concurrent tasks.
        fn progress(&self, handled: usize) {
            self.total_handled.fetch_add(handled, Ordering::Relaxed);
        }

        fn finish(&self) {
            self.finish_called.store(true, Ordering::Relaxed);
        }
    }
914
    /// A mock implementation of [`AsProgress`] that creates a [`MockProgress`].
    struct MockAsProgress {
        // The reporter shared with the build machinery.
        progress: Arc<MockProgress>,
        // Records the `max` value passed to `as_progress`.
        expected_max: AtomicUsize,
    }

    impl MockAsProgress {
        fn new() -> Self {
            Self {
                progress: Arc::new(MockProgress::new()),
                expected_max: AtomicUsize::new(0),
            }
        }

        /// Access the underlying shared reporter.
        fn progress(&self) -> &Arc<MockProgress> {
            &self.progress
        }

        /// The `max` most recently received by `as_progress`.
        fn received_max(&self) -> usize {
            self.expected_max.load(Ordering::Relaxed)
        }
    }
937
    impl AsProgress for MockAsProgress {
        // Record the requested `max` and hand out the shared reporter.
        fn as_progress(&self, max: usize) -> Arc<dyn Progress> {
            self.expected_max.store(max, Ordering::Relaxed);
            self.progress.clone()
        }
    }
944
    /// End-to-end test of [`build`] and [`build_tracked`] across the Dynamic,
    /// Sequential, and Fixed parallelism strategies, including empty-data,
    /// single-task, oversized-batchsize, and uneven-partition edge cases.
    #[test]
    fn test_build() {
        // (num_threads, num_data, parallelism, description)
        let test_cases: &[(usize, usize, Parallelism, &str)] = &[
            (
                4,
                100,
                Parallelism::dynamic(nz(10), nz(4)),
                "basic multi-task",
            ),
            (1, 50, Parallelism::dynamic(nz(10), nz(1)), "single task"),
            (4, 0, Parallelism::dynamic(nz(10), nz(4)), "empty data"),
            (
                4,
                5,
                Parallelism::dynamic(nz(100), nz(4)),
                "batchsize larger than data",
            ),
            (2, 20, Parallelism::dynamic(nz(5), nz(2)), "small dataset"),
            (
                8,
                1000,
                Parallelism::dynamic(nz(7), nz(8)),
                "larger dataset with odd batchsize",
            ),
            (
                4,
                100,
                Parallelism::dynamic(nz(10), nz(1)),
                "multiple threads but single task",
            ),
            (
                2,
                50,
                Parallelism::sequential(nz(10)),
                "sequential execution",
            ),
            // Fixed parallelism test cases
            (
                4,
                100,
                Parallelism::fixed(Some(nz(10)), nz(4)),
                "fixed with batchsize",
            ),
            (
                4,
                100,
                Parallelism::fixed(None, nz(4)),
                "fixed without batchsize (whole partition per task)",
            ),
            (
                2,
                50,
                Parallelism::fixed(Some(nz(5)), nz(2)),
                "fixed with small batchsize",
            ),
            (
                8,
                1000,
                Parallelism::fixed(Some(nz(100)), nz(8)),
                "fixed larger dataset",
            ),
            (
                4,
                0,
                Parallelism::fixed(Some(nz(10)), nz(4)),
                "fixed empty data",
            ),
            (
                4,
                5,
                Parallelism::fixed(Some(nz(100)), nz(4)),
                "fixed batchsize larger than partition",
            ),
            (
                1,
                50,
                Parallelism::fixed(Some(nz(10)), nz(1)),
                "fixed single task with batchsize",
            ),
            (
                1,
                50,
                Parallelism::fixed(None, nz(1)),
                "fixed single task without batchsize",
            ),
            (
                4,
                7,
                Parallelism::fixed(Some(nz(2)), nz(4)),
                "fixed uneven partition with batchsize",
            ),
        ];

        for (num_threads, num_data, parallelism, desc) in test_cases {
            let num_data = *num_data;
            let runtime = crate::tokio::runtime(*num_threads).unwrap();

            // Derive, per strategy, the task count and the number of batches
            // the build should produce for `num_data` points.
            let (ntasks, expected_batches) = match parallelism {
                Parallelism::Dynamic { batchsize, ntasks } => {
                    let expected = num_data.div_ceil(batchsize.get());
                    (*ntasks, expected)
                }
                Parallelism::Sequential { batchsize } => {
                    // Sequential is a single logical task.
                    let expected = num_data.div_ceil(batchsize.get());
                    (nz(1), expected)
                }
                Parallelism::Fixed { batchsize, ntasks } => {
                    // For Fixed, data is first partitioned among tasks, then each partition is batched.
                    // We need to calculate how many batches each task produces.
                    use diskann::utils::async_tools::PartitionIter;
                    let expected: usize = PartitionIter::new(num_data, *ntasks)
                        .map(|partition| match batchsize {
                            Some(bs) => partition.len().div_ceil(bs.get()),
                            None => {
                                // No batchsize: one batch per non-empty partition.
                                if partition.is_empty() {
                                    0
                                } else {
                                    1
                                }
                            }
                        })
                        .sum();
                    (*ntasks, expected)
                }
            };

            let builder = Arc::new(MockBuild::new(num_data));
            let mock_as_progress = MockAsProgress::new();

            // Validation shared by the tracked and untracked builds below.
            let check_results = |results: BuildResults<Range<usize>>| {
                if num_data == 0 {
                    assert!(
                        results.output().is_empty(),
                        "{desc}: no batches for empty data"
                    );
                    return;
                }

                // Verify that each BatchResult's output matches its batch range.
                for batch_result in results.output() {
                    assert_eq!(
                        batch_result.output, batch_result.batch,
                        "{desc}: output range should match batch range"
                    );
                    assert!(
                        batch_result.taskid < ntasks.get(),
                        "{desc}: taskid {} should be less than ntasks {}",
                        batch_result.taskid,
                        ntasks.get()
                    );
                }

                assert_eq!(
                    results.output().len(),
                    expected_batches,
                    "{desc}: expected {expected_batches} batches, got {}",
                    results.output().len()
                );

                // Verify all data points are covered exactly once.
                let mut ranges: Vec<_> = results.output().iter().map(|r| r.batch.clone()).collect();
                check_ranges(&mut ranges, num_data);
            };

            // Tracked build
            let results = build_tracked(
                builder.clone(),
                *parallelism,
                &runtime,
                Some(&mock_as_progress),
            )
            .unwrap_or_else(|_| panic!("{desc}: build_tracked should succeed"));

            // Verify progress tracking.
            assert_eq!(
                mock_as_progress.received_max(),
                num_data,
                "{desc}: as_progress should receive num_data as max"
            );
            assert_eq!(
                mock_as_progress.progress().total_handled(),
                num_data,
                "{desc}: total progress should equal num_data"
            );
            assert!(
                mock_as_progress.progress().was_finished(),
                "{desc}: finish should be called"
            );

            check_results(results);

            // Untracked Build
            let results = build(builder, *parallelism, &runtime)
                .unwrap_or_else(|_| panic!("{desc}: build should succeed"));
            check_results(results);
        }
    }
1143}