diskann_disk/build/builder/
core.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5use std::mem::{self, size_of};
6
7use diskann::ANNResult;
8use diskann_providers::storage::{StorageReadProvider, StorageWriteProvider};
9use diskann_providers::{
10    model::{
11        graph::traits::GraphDataType, IndexConfiguration, GRAPH_SLACK_FACTOR,
12        MAX_PQ_TRAINING_SET_SIZE,
13    },
14    storage::PQStorage,
15    utils::{
16        load_bin, load_metadata_from_file, RayonThreadPool, SampleVectorReader, SamplingDensity,
17        READ_WRITE_BLOCK_SIZE,
18    },
19};
20use rand::{seq::SliceRandom, Rng};
21use tracing::info;
22
23use crate::{
24    build::chunking::{
25        checkpoint::{
26            CheckpointContext, CheckpointManager, CheckpointManagerExt, Progress, WorkStage,
27        },
28        continuation::ChunkingConfig,
29    },
30    disk_index_build_parameter::BYTES_IN_GB,
31    storage::{CachedReader, CachedWriter, DiskIndexWriter},
32    utils::partition_with_ram_budget,
33    DiskIndexBuildParameters, QuantizationType,
34};
35
36/// Overhead factor for RAM estimation during index build (10% buffer).
37const OVERHEAD_FACTOR: f64 = 1.1f64;
38
39/// Estimate RAM usage in bytes for building an index.
40#[inline]
41fn estimate_build_index_ram_usage(
42    num_points: u64,
43    dim: u64,
44    datasize: u64,
45    graph_degree: u64,
46    build_quantization_type: &QuantizationType,
47) -> f64 {
48    let graph_size =
49        (num_points * graph_degree * mem::size_of::<u32>() as u64) as f64 * GRAPH_SLACK_FACTOR;
50
51    let single_vec_size = match *build_quantization_type {
52        QuantizationType::FP => dim.next_multiple_of(8u64) * datasize,
53        // We can skip PQ pivots data as it is very small(~3MB) for even large datasets like OAI-3072.
54        QuantizationType::PQ { num_chunks } => num_chunks as u64,
55        // `+ std::mem::size_of::<f32>()` for f32 compensation metadata for the scalar quantizer.
56        QuantizationType::SQ { nbits, .. } => {
57            (nbits as u64 * dim).div_ceil(8) + std::mem::size_of::<f32>() as u64
58        }
59    };
60
61    OVERHEAD_FACTOR * (graph_size + (single_vec_size * num_points) as f64)
62}
63
/// Core shared functionality between sync and async disk index builders.
/// Contains only fields and methods that are truly needed by both builder types.
pub struct DiskIndexBuilderCore<'a, Data, StorageProvider>
where
    Data: GraphDataType<VectorIdType = u32>,
    StorageProvider: StorageReadProvider + StorageWriteProvider,
{
    /// Writes the on-disk index layout and intermediate build files.
    pub index_writer: DiskIndexWriter,

    /// Storage backend for product-quantization (PQ) artifacts.
    pub pq_storage: PQStorage,

    /// User-supplied build parameters (memory budget, quantization settings).
    pub disk_build_param: DiskIndexBuildParameters,

    /// Index configuration (dimensions, max points, graph degrees, metric).
    pub index_configuration: IndexConfiguration,

    /// Configuration for chunked (resumable) builds.
    pub chunking_config: ChunkingConfig,

    /// Checkpoint manager used to record and resume build stages.
    pub checkpoint_record_manager: Box<dyn CheckpointManager>,

    /// Storage backend used for all reads and writes during the build.
    pub storage_provider: &'a StorageProvider,

    /// Carries the `Data` type parameter; stores no runtime data.
    pub _phantom: std::marker::PhantomData<Data>,
}
87
88impl<'a, Data, StorageProvider> DiskIndexBuilderCore<'a, Data, StorageProvider>
89where
90    Data: GraphDataType<VectorIdType = u32>,
91    StorageProvider: StorageReadProvider + StorageWriteProvider,
92{
    /// Final build stage: write the on-disk index layout, then remove
    /// intermediate build artifacts.
    pub(crate) fn create_disk_layout(&mut self) -> ANNResult<()> {
        // Run layout creation as the WriteDiskLayout checkpoint stage; the
        // second (no-op) closure is the resumption path taken when the stage
        // was already completed in a previous run.
        self.checkpoint_record_manager.execute_stage(
            WorkStage::WriteDiskLayout,
            WorkStage::End,
            || {
                self.index_writer
                    .create_disk_layout::<Data, StorageProvider>(self.storage_provider)?;
                Ok(())
            },
            || Ok(()),
        )?;

        // Cleanup runs unconditionally, whether the stage ran or was skipped.
        self.index_writer
            .index_build_cleanup(self.storage_provider)?;

        Ok(())
    }
110
111    pub(crate) fn create_shard_index_config(
112        &self,
113        shard_base_file: &str,
114    ) -> ANNResult<IndexConfiguration> {
115        let base_config = &self.index_configuration;
116        let storage_provider = self.storage_provider;
117
118        let search_list_size = base_config.config.l_build().get();
119        let pruned_degree = base_config.config.pruned_degree().get();
120
121        let low_degree_params = diskann::graph::config::Builder::new(
122            2 * pruned_degree / 3,
123            diskann::graph::config::MaxDegree::default_slack(),
124            search_list_size,
125            base_config.dist_metric.into(),
126        )
127        .build()?;
128
129        let metadata = load_metadata_from_file(storage_provider, shard_base_file)?;
130
131        let mut index_config = base_config.clone();
132        index_config.max_points = metadata.npoints;
133        index_config.config = low_degree_params;
134
135        Ok(index_config)
136    }
137
138    pub(crate) fn retrieve_shard_data_from_ids<T>(
139        &self,
140        dataset_file: &str,
141        shard_ids_file: &str,
142        shard_base_file: &str,
143    ) -> ANNResult<()>
144    where
145        T: Default + bytemuck::Pod,
146    {
147        let storage_provider = self.storage_provider;
148        let (shard_ids, shard_size, _) = load_bin::<u32, StorageProvider::Reader>(
149            &mut storage_provider.open_reader(shard_ids_file)?,
150            0,
151        )?;
152        info!("Loaded {} shard ids from {}", shard_size, shard_ids_file);
153        let max_id = shard_ids.iter().max().copied().unwrap_or(0);
154        let sampling_rate = shard_ids.len() as f64 / (max_id + 1) as f64;
155
156        let mut dataset_reader: SampleVectorReader<T, _> = SampleVectorReader::new(
157            dataset_file,
158            SamplingDensity::from_sample_rate(sampling_rate),
159            storage_provider,
160        )?;
161
162        let (_npts, dim) = dataset_reader.get_dataset_headers();
163
164        let mut shard_base_cached_writer = CachedWriter::<StorageProvider>::new(
165            shard_base_file,
166            READ_WRITE_BLOCK_SIZE,
167            storage_provider.create_for_write(shard_base_file)?,
168        )?;
169
170        let dummy_size: u32 = 0;
171        shard_base_cached_writer.write(&dummy_size.to_le_bytes())?;
172        shard_base_cached_writer.write(&dim.to_le_bytes())?;
173
174        let mut num_written: u32 = 0;
175        dataset_reader.read_vectors(shard_ids.iter().copied(), |vector_t| {
176            // Casting Pod type to bytes always succeeds (u8 has alignment of 1)
177            let vector_bytes: &[u8] = bytemuck::must_cast_slice(vector_t);
178            shard_base_cached_writer.write(vector_bytes)?;
179            num_written += 1;
180            Ok(())
181        })?;
182
183        info!(
184            "Written file: {} with {} points",
185            shard_base_file, num_written
186        );
187
188        shard_base_cached_writer.flush()?;
189        shard_base_cached_writer.reset()?;
190        shard_base_cached_writer.write(&num_written.to_le_bytes())?;
191
192        Ok(())
193    }
194
    /// Merge `num_parts` per-shard Vamana graphs into one graph file
    /// (`output_vamana`): shard-local node ids are renamed to global ids via
    /// the per-shard id maps, each node's neighbor lists are unioned across
    /// shards, and at most `max_degree` randomly shuffled neighbors are kept
    /// per node.
    ///
    /// NOTE(review): the merge loop appears to assume global node ids are
    /// contiguous (every id in 0..num_nodes occurs in at least one shard); a
    /// gap would leave a node without an adjacency record in the output —
    /// confirm the partitioner guarantees this.
    #[allow(clippy::too_many_arguments)]
    fn merge_shards(
        &self,
        merged_index_prefix: &str,
        num_parts: usize,
        max_degree: u32,
        output_vamana: String,
        rng: &mut impl Rng,
    ) -> ANNResult<()> {
        // Read ID maps: id_maps[shard][local_id] == global node id.
        let mut vamana_names = vec![String::new(); num_parts];
        let mut id_maps: Vec<Vec<u32>> = vec![Vec::new(); num_parts];
        for shard in 0..num_parts {
            vamana_names[shard] = DiskIndexWriter::get_merged_index_subshard_mem_index_file(
                merged_index_prefix,
                shard,
            );

            let id_maps_file =
                DiskIndexWriter::get_merged_index_subshard_id_map_file(merged_index_prefix, shard);
            id_maps[shard] = self.read_idmap(id_maps_file)?;
        }

        // find max node id (num_nodes = max global id + 1)
        let num_nodes: u32 = *id_maps.iter().flatten().max().unwrap_or(&0) + 1;
        let num_elements: u32 = id_maps.iter().map(|idmap| idmap.len() as u32).sum();
        info!("# nodes: {}, max degree: {}", num_nodes, max_degree);

        // compute inverse map: node -> shards containing that node
        let mut node_shard: Vec<(u32, u32)> = Vec::with_capacity(num_elements as usize);
        for (shard, id_map) in id_maps.iter().enumerate() {
            info!("Creating inverse map -- shard #{}", shard);
            node_shard.extend(id_map.iter().map(|node_id| (*node_id, shard as u32)));
        }
        // Sort by (global node id, shard id) so all entries for a node are
        // consecutive in the merge loop below.
        node_shard.sort_unstable_by(|left, right| {
            left.0.cmp(&right.0).then_with(|| left.1.cmp(&right.1))
        });

        info!("Finished computing node -> shards map");

        // create cached vamana readers, one per shard file
        let mut vamana_readers = Vec::new();
        for name in &vamana_names {
            let reader = CachedReader::<StorageProvider>::new(
                name,
                READ_WRITE_BLOCK_SIZE,
                self.storage_provider,
            )?;
            vamana_readers.push(reader);
        }

        // create cached vamana writers
        let mut merged_vamana_cached_writer = CachedWriter::<StorageProvider>::new(
            &output_vamana,
            READ_WRITE_BLOCK_SIZE,
            self.storage_provider.create_for_write(&output_vamana)?,
        )?;

        // expected file size + max degree + medoid_id + frozen_point info
        let vamana_metadata_size =
            size_of::<u64>() + size_of::<u32>() + size_of::<u32>() + size_of::<u64>();

        // we initialize the size of the merged index to the metadata size
        // we will overwrite the index size at the end
        let mut merged_index_size: u64 = vamana_metadata_size as u64;
        merged_vamana_cached_writer.write(&merged_index_size.to_le_bytes())?;

        let mut read_buf_8_bytes = [0u8; 8];

        // get max input width
        let mut max_input_width = 0;
        // read width from each vamana to advance buffer by sizeof(uint32_t) bytes
        for reader in &mut vamana_readers {
            reader.read(&mut read_buf_8_bytes)?;
            let _expected_file_size: u64 = u64::from_le_bytes(read_buf_8_bytes);
            let input_width = reader.read_u32()?;
            max_input_width = input_width.max(max_input_width);
        }

        // write max_degree to merged_vamana_index
        let output_width: u32 = max_degree;
        info!(
            "Max input width: {}, output width: {}",
            max_input_width, output_width
        );

        merged_vamana_cached_writer.write(&output_width.to_le_bytes())?;

        // write medoid to merged_vamana_index
        for shard in 0..num_parts {
            // read medoid (also advances every reader past its header)
            let mut medoid: u32 = vamana_readers[shard].read_u32()?;
            vamana_readers[shard].read(&mut read_buf_8_bytes)?;
            let vamana_index_frozen: u64 = u64::from_le_bytes(read_buf_8_bytes);
            debug_assert_eq!(vamana_index_frozen, 0);

            // rename medoid from shard-local id to global id
            medoid = id_maps[shard][medoid as usize];

            // write renamed medoid — only the last shard's medoid becomes the
            // merged index's medoid; the others are read just to advance the stream
            if shard == (num_parts - 1) {
                // uncomment if running hierarchical
                merged_vamana_cached_writer.write(&medoid.to_le_bytes())?;
            }
        }

        let vamana_index_frozen: u64 = 0; // as of now the functionality to merge many overlapping vamana
                                          // indices is supported only for bulk indices without frozen point.
                                          // Hence the final index will also not have any frozen points.
        merged_vamana_cached_writer.write(&vamana_index_frozen.to_le_bytes())?;

        info!("Starting merge");

        // nbr_set deduplicates neighbors; final_nbrs accumulates the renamed
        // neighbor union for the node currently being assembled (cur_id).
        let mut nbr_set = vec![false; num_nodes as usize];
        let mut final_nbrs: Vec<u32> = Vec::new();
        let mut cur_id = 0;
        for pair in &node_shard {
            let (node_id, shard_id) = *pair;
            // A larger node_id means cur_id's union is complete: flush it.
            if cur_id < node_id {
                // Shuffle before truncating so the kept subset is random.
                final_nbrs.shuffle(rng);

                let nnbrs: u32 = std::cmp::min(final_nbrs.len() as u32, max_degree);
                merged_vamana_cached_writer.write(&nnbrs.to_le_bytes())?;

                let bytes = final_nbrs
                    .iter()
                    .take(nnbrs as usize)
                    .flat_map(|x| x.to_le_bytes())
                    .collect::<Vec<u8>>();
                merged_vamana_cached_writer.write(&bytes)?;

                merged_index_size += (size_of::<u32>() + nnbrs as usize * size_of::<u32>()) as u64;
                // Progress dot roughly every 500k nodes.
                if cur_id % 499999 == 1 {
                    print!(".");
                }
                cur_id = node_id;

                // Reset the dedup bitmap for the next node.
                final_nbrs.iter().for_each(|p| nbr_set[*p as usize] = false);
                final_nbrs.clear();
            }

            // read num of neighbors from vamana index
            let num_nbrs = vamana_readers[shard_id as usize].read_u32()?;

            if num_nbrs == 0 {
                info!(
                    "WARNING: shard #{}, node_id {} has 0 nbrs",
                    shard_id, node_id
                );
            } else {
                let mut nbrs_bytes = vec![0u8; num_nbrs as usize * mem::size_of::<u32>()];
                vamana_readers[shard_id as usize].read(&mut nbrs_bytes)?;
                let nbrs: &[u32] = bytemuck::cast_slice(&nbrs_bytes);

                // rename nodes to global ids and add unseen ones to the union
                for j in 0..num_nbrs {
                    let nbr = nbrs[j as usize];
                    let renamed_node = id_maps[shard_id as usize][nbr as usize];
                    if !nbr_set[renamed_node as usize] {
                        nbr_set[renamed_node as usize] = true;
                        final_nbrs.push(renamed_node);
                    }
                }
            }
        }

        // write the last node, to be refactored...
        final_nbrs.shuffle(rng);

        let nnbrs: u32 = std::cmp::min(final_nbrs.len() as u32, max_degree);
        merged_vamana_cached_writer.write(&nnbrs.to_le_bytes())?;

        let bytes = final_nbrs
            .iter()
            .take(nnbrs as usize)
            .flat_map(|x| x.to_le_bytes())
            .collect::<Vec<u8>>();
        merged_vamana_cached_writer.write(&bytes)?;

        merged_index_size += (size_of::<u32>() + nnbrs as usize * size_of::<u32>()) as u64;

        nbr_set.clear();
        final_nbrs.clear();

        info!("Expected size: {}", merged_index_size);
        // Seek back to the start and patch the real index size into the header.
        merged_vamana_cached_writer.reset()?;
        merged_vamana_cached_writer.write(&merged_index_size.to_le_bytes())?;

        info!("Finished merge");
        Ok(())
    }
386
387    fn read_idmap(&self, idmaps_path: String) -> std::io::Result<Vec<u32>> {
388        let (data, _npts, _dim) = load_bin::<u32, StorageProvider::Reader>(
389            &mut self.storage_provider.open_reader(&idmaps_path)?,
390            0,
391        )?;
392        Ok(data)
393    }
394
395    fn merge_shards_and_cleanup(
396        &self,
397        merged_index_prefix: &str,
398        num_parts: usize,
399        max_degree: u32,
400        rng: &mut impl Rng,
401    ) -> ANNResult<()> {
402        // merge all in-memory indices into one
403        self.merge_shards(
404            merged_index_prefix,
405            num_parts,
406            max_degree,
407            self.index_writer.get_mem_index_file(),
408            rng,
409        )?;
410
411        // delete tempFiles
412        for p in 0..num_parts {
413            let shard_base_file =
414                DiskIndexWriter::get_merged_index_subshard_data_file(merged_index_prefix, p);
415            let shard_ids_file =
416                DiskIndexWriter::get_merged_index_subshard_id_map_file(merged_index_prefix, p);
417            let shard_index_file =
418                DiskIndexWriter::get_merged_index_subshard_mem_index_file(merged_index_prefix, p);
419            let shard_index_file_data =
420                DiskIndexWriter::get_merged_index_subshard_mem_dataset_file(&shard_index_file);
421
422            self.storage_provider.delete(&shard_base_file)?;
423            self.storage_provider.delete(&shard_ids_file)?;
424            self.storage_provider.delete(&shard_index_file)?;
425            // Check if shard dataset file exists before deleting it.
426            // Async build path doesn't always create this file.
427            if self.storage_provider.exists(&shard_index_file_data) {
428                self.storage_provider.delete(&shard_index_file_data)?;
429            }
430        }
431
432        Ok(())
433    }
434}
435
/// Strategy for building the disk index, chosen from the RAM budget.
pub(crate) enum IndexBuildStrategy {
    /// The whole index fits in the RAM budget; build it in a single pass.
    OneShot,
    /// Partition into shards, build each separately, then merge the graphs.
    Merged,
}
440
441pub(crate) fn determine_build_strategy<Data: GraphDataType>(
442    index_configuration: &IndexConfiguration,
443    index_build_ram_limit_in_bytes: f64,
444    build_quantization_type: &QuantizationType,
445) -> IndexBuildStrategy {
446    let estimated_index_ram_in_bytes = estimate_build_index_ram_usage(
447        index_configuration.max_points as u64,
448        index_configuration.dim as u64,
449        mem::size_of::<Data::VectorDataType>() as u64,
450        index_configuration.config.max_degree().get() as u64,
451        build_quantization_type,
452    );
453
454    info!(
455        "Estimated index RAM usage: {} GB, index_build_ram_limit={} GB",
456        estimated_index_ram_in_bytes / BYTES_IN_GB,
457        index_build_ram_limit_in_bytes / BYTES_IN_GB
458    );
459
460    if estimated_index_ram_in_bytes >= index_build_ram_limit_in_bytes {
461        info!(
462            "Insufficient memory budget for index build in one shot, index_build_ram_limit={} GB estimated_index_ram={} GB",
463            index_build_ram_limit_in_bytes / BYTES_IN_GB,
464            estimated_index_ram_in_bytes / BYTES_IN_GB,
465        );
466        IndexBuildStrategy::Merged
467    } else {
468        info!(
469            "Full index fits in RAM budget, should consume at most {} GBs, so building in one shot",
470            estimated_index_ram_in_bytes / BYTES_IN_GB
471        );
472        IndexBuildStrategy::OneShot
473    }
474}
475
/// Drives the sharded ("merged") build path: partition the data, build an
/// index per shard, then merge the shard graphs into one index.
pub(crate) struct MergedVamanaIndexWorkflow<'a> {
    /// Thread pool used by the partitioning step.
    pool: &'a RayonThreadPool,
    /// RNG seeded from the index configuration's optional seed.
    rng: diskann_providers::utils::StandardRng,
    /// Path of the full input dataset file.
    dataset_file: String,
    /// Final pruned degree of the merged graph.
    max_degree: u32,
    /// File-name prefix shared by all per-shard intermediate files.
    pub merged_index_prefix: String,
}
483
484impl<'a> MergedVamanaIndexWorkflow<'a> {
485    pub(crate) fn new<Data, StorageProvider>(
486        builder: &mut DiskIndexBuilderCore<'_, Data, StorageProvider>,
487        pool: &'a RayonThreadPool,
488    ) -> Self
489    where
490        Data: GraphDataType<VectorIdType = u32>,
491        StorageProvider: StorageReadProvider + StorageWriteProvider,
492    {
493        let rng = diskann_providers::utils::create_rnd_from_optional_seed(
494            builder.index_configuration.random_seed,
495        );
496        let dataset_file = builder.index_writer.get_dataset_file();
497        let merged_index_prefix = builder.index_writer.get_merged_index_prefix();
498        let max_degree = builder.index_configuration.config.pruned_degree_u32().get();
499
500        Self {
501            pool,
502            rng,
503            dataset_file,
504            merged_index_prefix,
505            max_degree,
506        }
507    }
508
    /// Partition the dataset into shards that each fit the RAM budget.
    ///
    /// Returns the number of partitions created, or — on resume — the number
    /// of partitions found on disk from a previous run.
    pub(crate) fn partition_data<Data, StorageProvider>(
        &mut self,
        builder: &mut DiskIndexBuilderCore<'_, Data, StorageProvider>,
    ) -> ANNResult<usize>
    where
        Data: GraphDataType<VectorIdType = u32>,
        StorageProvider: StorageReadProvider + StorageWriteProvider,
    {
        // Advance to PartitionData stage if current stage is InMemIndexBuild
        builder.checkpoint_record_manager.execute_stage(
            WorkStage::InMemIndexBuild,
            WorkStage::PartitionData,
            || Ok(()),
            || Ok(()),
        )?;

        // Partition data stage
        builder.checkpoint_record_manager.execute_stage(
            WorkStage::PartitionData,
            WorkStage::BuildIndicesOnShards(0),
            || {
                let num_points = builder.index_configuration.max_points;
                // Sample at a rate that yields ~MAX_PQ_TRAINING_SET_SIZE points.
                let sampling_rate = MAX_PQ_TRAINING_SET_SIZE / num_points as f64;

                let ram_budget_in_bytes =
                    builder.disk_build_param.build_memory_limit().in_bytes() as f64;
                // calculate how many partitions we need, in order to fit in RAM budget
                // save id_map for each partition to disk
                partition_with_ram_budget::<Data::VectorDataType, _, _, _>(
                    &self.dataset_file,
                    builder.index_configuration.dim,
                    sampling_rate,
                    ram_budget_in_bytes,
                    2, // k_base
                    &self.merged_index_prefix,
                    builder.storage_provider,
                    &mut self.rng,
                    self.pool,
                    |num_points, dim| {
                        let datasize = std::mem::size_of::<Data::VectorDataType>() as u64;
                        // Shards are built at 2/3 of the final degree (matches
                        // create_shard_index_config), so estimate with that degree.
                        let graph_degree = 2 * self.max_degree / 3;
                        estimate_build_index_ram_usage(
                            num_points,
                            dim,
                            datasize,
                            graph_degree as u64,
                            builder.disk_build_param.build_quantization(),
                        )
                    },
                )
            },
            || {
                // load num_parts based on file names
                // Resumption path: count the shard id-map files left by a
                // previous (interrupted) run instead of repartitioning.
                let mut p = 0;
                while builder.storage_provider.exists(
                    &DiskIndexWriter::get_merged_index_subshard_id_map_file(
                        &self.merged_index_prefix,
                        p,
                    ),
                ) {
                    p += 1;
                }
                info!("Found {} existing partitions from previous run", p);
                Ok(p)
            },
        )
    }
576
577    pub(crate) fn merge_and_cleanup<Data, StorageProvider>(
578        &mut self,
579        builder: &mut DiskIndexBuilderCore<'_, Data, StorageProvider>,
580        num_parts: usize,
581    ) -> ANNResult<()>
582    where
583        Data: GraphDataType<VectorIdType = u32>,
584        StorageProvider: StorageReadProvider + StorageWriteProvider,
585    {
586        if builder
587            .checkpoint_record_manager
588            .get_resumption_point(WorkStage::MergeIndices)?
589            .is_some()
590        {
591            builder.merge_shards_and_cleanup(
592                &self.merged_index_prefix,
593                num_parts,
594                self.max_degree,
595                &mut self.rng,
596            )?;
597            builder
598                .checkpoint_record_manager
599                .update(Progress::Completed, WorkStage::WriteDiskLayout)?;
600        }
601
602        Ok(())
603    }
604
605    pub(crate) fn get_shard_context<'b, Data, StorageProvider>(
606        &self,
607        builder: &'b DiskIndexBuilderCore<'_, Data, StorageProvider>,
608        p: usize,
609        num_parts: usize,
610    ) -> CheckpointContext<'b>
611    where
612        Data: GraphDataType<VectorIdType = u32>,
613        StorageProvider: StorageReadProvider + StorageWriteProvider,
614    {
615        let current_stage = WorkStage::BuildIndicesOnShards(p);
616        let next_stage = if p == num_parts - 1 {
617            // If this is the last shard, next stage is MergeIndices
618            WorkStage::MergeIndices
619        } else {
620            // Otherwise, continue with the next shard
621            WorkStage::BuildIndicesOnShards(p + 1)
622        };
623        CheckpointContext::new(
624            builder.checkpoint_record_manager.as_ref(),
625            current_stage,
626            next_stage,
627        )
628    }
629}
630
631#[cfg(test)]
632pub(crate) mod disk_index_builder_tests {
633    use std::{io::Read, sync::Arc};
634
635    use diskann::{
636        graph::config,
637        utils::{IntoUsize, VectorRepr, ONE},
638        ANNResult,
639    };
640    use diskann_providers::storage::VirtualStorageProvider;
641    use diskann_providers::{
642        common::AlignedBoxWithSlice,
643        storage::{get_compressed_pq_file, get_disk_index_file, get_pq_pivot_file},
644        test_utils::graph_data_type_utils::{
645            GraphDataF32VectorU32Data, GraphDataF32VectorUnitData,
646        },
647        utils::{file_util, BridgeErr, Timer},
648    };
649    use diskann_utils::test_data_root;
650    use diskann_vector::{
651        distance::Metric::{self, L2},
652        DistanceFunction,
653    };
654    use rstest::rstest;
655    use vfs::OverlayFS;
656
657    use super::*;
658    use crate::{
659        build::builder::build::DiskIndexBuilder,
660        data_model::{CachingStrategy, GraphHeader},
661        disk_index_build_parameter::{DiskIndexBuildParameters, MemoryBudget, NumPQChunks},
662        search::provider::{
663            disk_provider::DiskIndexSearcher,
664            disk_vertex_provider_factory::DiskVertexProviderFactory,
665        },
666        storage::disk_index_reader::DiskIndexReader,
667        utils::{QueryStatistics, VirtualAlignedReaderFactory},
668    };
669    const DEFAULT_DISK_SECTOR_LEN: usize = 4096;
670    pub const TEST_DATA_FILE: &str = "/sift/siftsmall_learn_256pts.fbin";
671    /// We can use the same index prefix for all tests since we use virtual storage provider
672    const INDEX_PATH_PREFIX: &str = "/disk_index_build/sift_learn_test_disk_index_build";
673    const TRUTH_INDEX_PATH_PREFIX_R4_L50: &str = "/disk_index_build/truth_sift_learn_R4_L50";
674
    /// Checkpoint/chunking settings for tests exercising resumable builds.
    pub struct CheckpointParams {
        /// Chunking configuration passed through to the builder.
        pub chunking_config: ChunkingConfig,
        /// Checkpoint manager recording stage progress across restarts.
        pub checkpoint_record_manager: Box<dyn CheckpointManager>,
    }
679
    /// Parameter bundle for disk-index build tests; see `Default` for baselines.
    pub struct TestParams {
        // Vector dimension of the dataset (D).
        pub dim: usize,
        // Full (unquantized) dimension, used to size PQ chunks.
        pub full_dim: usize,
        // Graph max degree (R).
        pub max_degree: u32,
        // Number of PQ chunks for compressed vectors.
        pub num_pq_chunks: usize,
        // Quantization applied during build (FP = none).
        pub build_quantization_type: QuantizationType,
        // Build-time search list size (L).
        pub l_build: u32,
        // Path to the input dataset file.
        pub data_path: String,
        // Prefix for all output index files.
        pub index_path_prefix: String,
        // Optional associated-data file fed to the index writer.
        pub associated_data_path: Option<String>,
        // RAM budget for the build, in GB.
        pub index_build_ram_gb: f64,
        // When set, the build uses the chunking/checkpoint (resumable) path.
        pub checkpoint_params: Option<CheckpointParams>,
        // Number of build threads.
        pub num_threads: usize,
        // Distance metric for the index.
        pub metric: Metric,
    }
695
    impl Default for TestParams {
        /// Baseline parameters matching the R=4, L=50 truth index fixtures.
        fn default() -> Self {
            Self {
                dim: 128, // D
                full_dim: 128,
                max_degree: 4, // R
                num_pq_chunks: 128,
                build_quantization_type: QuantizationType::FP, // No quantization, i.e. QuantizationType::FP
                l_build: 50,
                data_path: TEST_DATA_FILE.to_string(),
                index_path_prefix: INDEX_PATH_PREFIX.to_string(),
                associated_data_path: None,
                index_build_ram_gb: 1.0,
                checkpoint_params: None,
                num_threads: 1,
                metric: L2,
            }
        }
    }
715
716    impl TestParams {
717        /// Returns the appropriate truth index path prefix for build comparison.
718        fn truth_index_path_prefix(&self) -> &str {
719            match (self.max_degree, self.l_build, self.index_build_ram_gb) {
720                (4, 50, 1.0) => TRUTH_INDEX_PATH_PREFIX_R4_L50,
721                (max_degree, l_build, index_build_ram_gb) => panic!(
722                    "Truth index path not found for max_degree={}, l_build={}, index_build_ram_gb={}",
723                    max_degree, l_build, index_build_ram_gb
724                ),
725            }
726        }
727        pub fn truth_pq_compressed_path(&self) -> String {
728            let prefix = match self.num_pq_chunks {
729                128 => TRUTH_INDEX_PATH_PREFIX_R4_L50,
730                num_pq_chunks => panic!(
731                    "Truth pq compressed path not found for num_pq_chunks={}",
732                    num_pq_chunks,
733                ),
734            };
735            get_compressed_pq_file(prefix)
736        }
737
738        pub fn pq_compressed_path(&self) -> String {
739            get_compressed_pq_file(&self.index_path_prefix)
740        }
741    }
742
    /// Creates a virtual storage provider overlaying the test data root, so
    /// tests read checked-in fixtures while writes land in the overlay.
    pub fn new_vfs() -> VirtualStorageProvider<OverlayFS> {
        VirtualStorageProvider::new_overlay(test_data_root())
    }
746
    /// Test fixture bundling a storage provider with build parameters.
    pub struct IndexBuildFixture<StorageProvider: StorageReadProvider + StorageWriteProvider> {
        // Shared storage backend (usually a virtual/overlay FS in tests).
        pub storage_provider: Arc<StorageProvider>,
        // Build parameters for the index under test.
        pub params: TestParams,
    }
751
752    impl<StorageProvider: StorageReadProvider + StorageWriteProvider + 'static>
753        IndexBuildFixture<StorageProvider>
754    {
755        pub fn new(storage_provider: StorageProvider, params: TestParams) -> ANNResult<Self> {
756            Ok(Self {
757                storage_provider: Arc::new(storage_provider),
758                params,
759            })
760        }
761
762        pub fn build<T>(&self) -> ANNResult<()>
763        where
764            T: GraphDataType<VectorIdType = u32>,
765            StorageProvider::Reader: std::marker::Send + Read,
766        {
767            // Create disk index build parameters
768            let disk_index_build_parameters = DiskIndexBuildParameters::new(
769                MemoryBudget::try_from_gb(self.params.index_build_ram_gb)?,
770                self.params.build_quantization_type,
771                NumPQChunks::new_with(self.params.num_pq_chunks, self.params.full_dim)?,
772            );
773
774            let config = config::Builder::new_with(
775                self.params.max_degree.into_usize(),
776                config::MaxDegree::default_slack(),
777                self.params.l_build.into_usize(),
778                self.params.metric.into(),
779                |b| {
780                    b.saturate_after_prune(true);
781                },
782            )
783            .build()?;
784
785            let metadata =
786                load_metadata_from_file(self.storage_provider.as_ref(), &self.params.data_path)
787                    .unwrap();
788
789            assert_eq!(
790                self.params.dim, metadata.ndims,
791                "Parameters dimension {} and data dimension {} are not equal",
792                self.params.dim, metadata.ndims
793            );
794
795            let config = IndexConfiguration::new(
796                self.params.metric,
797                self.params.dim,
798                metadata.npoints,
799                ONE,
800                self.params.num_threads,
801                config,
802            )
803            .with_pseudo_rng_from_seed(100);
804
805            let disk_index_writer = DiskIndexWriter::new(
806                self.params.data_path.clone(),
807                self.params.index_path_prefix.clone(),
808                self.params.associated_data_path.clone(),
809                DEFAULT_DISK_SECTOR_LEN,
810            )?;
811
812            let mut disk_index = match self.params.checkpoint_params {
813                Some(ref checkpoint_params) => {
814                    let checkpoint_record_manager =
815                        checkpoint_params.checkpoint_record_manager.clone_box();
816                    let chunking_config = checkpoint_params.chunking_config.clone();
817                    DiskIndexBuilder::<T, _>::new_with_chunking_config(
818                        self.storage_provider.as_ref(),
819                        disk_index_build_parameters,
820                        config,
821                        disk_index_writer,
822                        chunking_config,
823                        checkpoint_record_manager,
824                    )
825                }
826                None => DiskIndexBuilder::<T, _>::new(
827                    self.storage_provider.as_ref(),
828                    disk_index_build_parameters,
829                    config,
830                    disk_index_writer,
831                ),
832            }?;
833
834            let timer = Timer::new();
835            disk_index.build()?;
836            println!("Indexing time: {} seconds", timer.elapsed().as_secs_f64());
837
838            Ok(())
839        }
840
841        pub fn compare_pq_compressed_files(&self) {
842            self.compare_files(
843                &self.params.pq_compressed_path(),
844                &self.params.truth_pq_compressed_path(),
845            );
846        }
847
848        pub fn assert_index_max_degree<T: GraphDataType>(&self) -> ANNResult<()> {
849            let index_file_path = get_disk_index_file(&self.params.index_path_prefix);
850            let file_data = load_file_to_vec(self.storage_provider.as_ref(), &index_file_path);
851            let graph_header = GraphHeader::try_from(&file_data[8..])?;
852            let max_degree = graph_header.max_degree::<T::VectorDataType>()?;
853            assert_eq!(
854                max_degree, self.params.max_degree as usize,
855                "Max degree mismatch: expected {}, got {}",
856                self.params.max_degree, max_degree
857            );
858
859            Ok(())
860        }
861
862        fn compare_disk_index_with_associated_data(
863            &self,
864            pivot_file_prefix_test: &str,
865            pivot_file_prefix_expected: &str,
866            index_file_suffix: &str,
867        ) {
868            let pq_pivot_path = pivot_file_prefix_test.to_string() + index_file_suffix;
869            let pq_pivot_path_truth = pivot_file_prefix_expected.to_string() + index_file_suffix;
870            let file1 = load_file_to_vec(self.storage_provider.as_ref(), &pq_pivot_path);
871            let file2 = load_file_to_vec(self.storage_provider.as_ref(), &pq_pivot_path_truth);
872            compare_disk_index_graphs(&file1, &file2)
873        }
874
875        pub fn compare_files(&self, file_path1: &str, file_path2: &str) {
876            let file1 = load_file_to_vec(self.storage_provider.as_ref(), file_path1);
877            let file2 = load_file_to_vec(self.storage_provider.as_ref(), file_path2);
878
879            assert_eq!(file1.len(), file2.len());
880            assert_eq!(file1, file2)
881        }
882    }
883
884    /// Common helper function for one-shot async index build tests
885    fn run_one_shot_test<F>(index_path_prefix: String, params_customizer: F)
886    where
887        F: FnOnce(TestParams) -> TestParams,
888    {
889        let l_build = 64;
890        let max_degree = 16;
891        let top_k = 10;
892        let search_l = 32;
893
894        let base_params = TestParams {
895            l_build,
896            max_degree,
897            index_path_prefix,
898            ..TestParams::default()
899        };
900
901        let params = params_customizer(base_params);
902
903        let fixture = IndexBuildFixture::new(new_vfs(), params).unwrap();
904        fixture.build::<GraphDataF32VectorUnitData>().unwrap();
905
906        // Validate search recall against ground truth for async tests
907        verify_search_result_with_ground_truth::<GraphDataF32VectorUnitData>(
908            &fixture.params,
909            top_k,
910            search_l,
911            &fixture.storage_provider,
912        )
913        .unwrap();
914
915        fixture
916            .assert_index_max_degree::<GraphDataF32VectorUnitData>()
917            .unwrap();
918
919        // Assert that all data was kept in memory and no files were written to the disk.
920        let mem_index_file_path = format!("{}_mem.index.data", fixture.params.index_path_prefix);
921        assert!(!fixture.storage_provider.exists(&mem_index_file_path));
922    }
923
924    #[rstest]
925    fn test_build_from_iter_one_shot_with_metric(
926        #[values(Metric::L2, Metric::InnerProduct, Metric::Cosine)] metric: Metric,
927    ) {
928        let index_path_prefix = format!("{}_metric_{:?}", INDEX_PATH_PREFIX, metric);
929
930        run_one_shot_test(index_path_prefix, |params| TestParams { metric, ..params });
931    }
932
933    #[test]
934    fn test_build_from_iter_one_shot_with_associated_data() {
935        // Set up test data
936        let params = TestParams {
937            associated_data_path: Some(
938                "/sift/siftsmall_learn_256pts_u32_associated_data.fbin".to_string(),
939            ),
940            ..TestParams::default()
941        };
942
943        // Create fixture with virtual storage provider
944        let fixture = IndexBuildFixture::new(new_vfs(), params).unwrap();
945
946        // Build the index with the associated data
947        fixture.build::<GraphDataF32VectorU32Data>().unwrap();
948
949        // Assert that all data was kept in memory and no files were written to the disk.
950        let mem_index_file_path = format!("{}_mem.index.data", fixture.params.index_path_prefix);
951        let mem_index_associated_data_path = format!(
952            "{}_mem.index.associated_data",
953            fixture.params.index_path_prefix
954        );
955        assert!(!fixture.storage_provider.exists(&mem_index_file_path));
956        assert!(!fixture
957            .storage_provider
958            .exists(&mem_index_associated_data_path));
959
960        // assert index files are expected.
961        fixture.compare_disk_index_with_associated_data(
962            &fixture.params.index_path_prefix,
963            fixture.params.truth_index_path_prefix(),
964            "_disk.index",
965        );
966    }
967
968    #[test]
969    fn test_build_from_iter_merged_index() {
970        // Use the same parameters from [test_sift_build_and_search] in diskann_index
971        let l_build = 64;
972        let max_degree = 16;
973        let top_k = 10;
974        let search_l = 32;
975
976        let index_path_prefix =
977            "/disk_index_build/disk_index_sift_learn_test_disk_index_build_merged".to_string();
978        let params = TestParams {
979            l_build,
980            max_degree,
981            index_path_prefix,
982            index_build_ram_gb: 0.0001, // small enough to trigger merged index build
983            ..TestParams::default()
984        };
985
986        let fixture = IndexBuildFixture::new(new_vfs(), params).unwrap();
987
988        fixture.build::<GraphDataF32VectorUnitData>().unwrap();
989
990        verify_search_result_with_ground_truth::<GraphDataF32VectorUnitData>(
991            &fixture.params,
992            top_k,
993            search_l,
994            &fixture.storage_provider,
995        )
996        .unwrap();
997
998        fixture
999            .assert_index_max_degree::<GraphDataF32VectorUnitData>()
1000            .unwrap();
1001    }
1002
1003    #[rstest]
1004    #[case(QuantizationType::SQ { nbits: 2, standard_deviation: None }, "SQ quantization is only supported for 1 bit")]
1005    fn test_build_quantization_type_failure_cases(
1006        #[case] build_quantization_type: QuantizationType,
1007        #[case] error_message: &str,
1008    ) {
1009        let storage_provider = VirtualStorageProvider::new_overlay(test_data_root());
1010        let disk_index_builder = create_disk_index_builder(
1011            1000, // num_points
1012            128,  // dim
1013            128,  // num_pq_chunks
1014            &storage_provider,
1015            build_quantization_type,
1016        );
1017
1018        let err = disk_index_builder.err().unwrap();
1019        assert!(err.to_string().contains(error_message));
1020    }
1021
1022    fn load_file_to_vec<StorageType: StorageReadProvider>(
1023        storage_provider: &StorageType,
1024        file_path: &str,
1025    ) -> Vec<u8> {
1026        let mut file = storage_provider.open_reader(file_path).unwrap();
1027        let mut buffer = vec![];
1028        file.read_to_end(&mut buffer).unwrap();
1029        buffer
1030    }
1031
    /// Verifies that search results exactly match the ground truth of nearest neighbors
    ///
    /// This function performs validation of search results by:
    /// 1. Running searches on the index using actual data points from the dataset as queries
    /// 2. Computing the exact ground truth results using direct distance calculations
    /// 3. Verifying that the search engine returns precisely the same results as the ground truth
    ///
    /// `params` supplies the index path prefix, data path, and metric;
    /// `top_k`/`search_l` are the search breadth parameters. Returns an error
    /// if any index file cannot be read or a searcher component fails to build.
    pub(crate) fn verify_search_result_with_ground_truth<
        G: GraphDataType<VectorIdType = u32, AssociatedDataType = ()>,
    >(
        params: &TestParams,
        top_k: usize,
        search_l: u32,
        storage_provider: &Arc<VirtualStorageProvider<OverlayFS>>,
    ) -> ANNResult<()> {
        // Resolve the index artifact paths from the common prefix.
        let pq_pivot_path = get_pq_pivot_file(&params.index_path_prefix);
        let pq_compressed_path = get_compressed_pq_file(&params.index_path_prefix);
        let index_file_path = get_disk_index_file(&params.index_path_prefix);

        // Reader over the PQ pivots and compressed vectors.
        let index_reader = DiskIndexReader::<G::VectorDataType>::new(
            pq_pivot_path,
            pq_compressed_path,
            storage_provider.as_ref(),
        )?;

        // Vertex provider over the disk index file, with caching disabled so
        // every search hits the (virtual) disk path.
        let vertex_provider_factory = DiskVertexProviderFactory::new(
            VirtualAlignedReaderFactory::new(index_file_path, Arc::clone(storage_provider)),
            CachingStrategy::None,
        )?;

        let search_engine = DiskIndexSearcher::<G, DiskVertexProviderFactory<G, _>>::new(
            1,
            u32::MAX as usize,
            &index_reader,
            vertex_provider_factory,
            params.metric,
            None,
        )?;

        // Load the full raw dataset; its rows double as the query set below.
        let (data, npoints, dim) = file_util::load_bin::<G::VectorDataType, _>(
            storage_provider.as_ref(),
            &params.data_path,
            0,
        )?;
        let data =
            diskann_utils::views::Matrix::try_from(data.into(), npoints, dim).bridge_err()?;
        let distance = <G::VectorDataType>::distance(params.metric, Some(dim));

        // Here, we use elements of the dataset to search the dataset itself.
        //
        // We do this for each query, computing the expected ground truth and verifying
        // that our simple graph search matches.
        //
        // Because this dataset is small, we can expect exact equality.
        for (q, query_data) in data.row_iter().enumerate() {
            // Brute-force ground truth under the same distance function.
            let gt =
                diskann_providers::test_utils::groundtruth(data.as_view(), query_data, |a, b| {
                    distance.evaluate_similarity(a, b)
                });

            // Copy the query into an 8-byte-aligned scratch buffer.
            // NOTE(review): presumably the searcher requires aligned query
            // memory — confirm against `search_internal`'s contract.
            let mut query: AlignedBoxWithSlice<G::VectorDataType> =
                AlignedBoxWithSlice::<G::VectorDataType>::new(dim, 8)?;
            query.memcpy(query_data)?;

            let mut query_stats = QueryStatistics::default();

            // Output buffers sized for the top-k results.
            let mut indices = vec![0u32; top_k];
            let mut distances = vec![0f32; top_k];
            let mut associated_data = vec![(); top_k];

            _ = search_engine.search_internal(
                &query,
                top_k,
                search_l,
                None, // beam_width
                &mut query_stats,
                &mut indices,
                &mut distances,
                &mut associated_data,
                &|_| true,
                false,
            );

            // Panics (failing the test) if the search output diverges from
            // the ground truth for query `q`.
            diskann_providers::test_utils::assert_top_k_exactly_match(
                q, &gt, &indices, &distances, top_k,
            );
        }

        Ok(())
    }
1121
1122    // Compare that the index built in test is the same as the truth index. The truth index doesn't have associated data, we are only comparing the vector and neighbor data.
1123    pub fn compare_disk_index_graphs(graph_data: &[u8], truth_graph_data: &[u8]) {
1124        let graph_header = GraphHeader::try_from(&graph_data[8..]).unwrap();
1125        let truth_graph_header = GraphHeader::try_from(&truth_graph_data[8..]).unwrap();
1126
1127        let test_node_per_block = graph_header.metadata().num_nodes_per_block;
1128        let test_max_node_length = graph_header.metadata().node_len;
1129
1130        let truth_node_per_block = truth_graph_header.metadata().num_nodes_per_block;
1131        let truth_max_node_length = truth_graph_header.metadata().node_len;
1132
1133        assert_eq!(
1134            graph_header.metadata().num_pts,
1135            truth_graph_header.metadata().num_pts
1136        );
1137
1138        assert_eq!(
1139            graph_header.metadata().dims,
1140            truth_graph_header.metadata().dims
1141        );
1142
1143        let num_pts = graph_header.metadata().num_pts as usize;
1144        let dim = graph_header.metadata().dims;
1145
1146        for idx in 0..num_pts {
1147            let test_node_id_offset = node_data_offset(
1148                idx,
1149                test_max_node_length as usize,
1150                test_node_per_block as usize,
1151                DEFAULT_DISK_SECTOR_LEN,
1152            );
1153
1154            let truth_node_id_offset = node_data_offset(
1155                idx,
1156                truth_max_node_length as usize,
1157                truth_node_per_block as usize,
1158                DEFAULT_DISK_SECTOR_LEN,
1159            );
1160
1161            // Assert that the vector data is the same between the test and truth graphs for this node.
1162            assert_eq!(
1163                &graph_data
1164                    [test_node_id_offset..test_node_id_offset + dim * std::mem::size_of::<f32>()],
1165                &truth_graph_data
1166                    [truth_node_id_offset..truth_node_id_offset + dim * std::mem::size_of::<f32>()]
1167            );
1168
1169            // Assert that the neighbor count is the same between the test and truth graphs for this node.
1170            let test_nbr_cnt_offset = test_node_id_offset + dim * std::mem::size_of::<f32>();
1171            let truth_nbr_cnt_offset = truth_node_id_offset + dim * std::mem::size_of::<f32>();
1172
1173            let test_nbr_count = u32::from_le_bytes([
1174                graph_data[test_nbr_cnt_offset],
1175                graph_data[test_nbr_cnt_offset + 1],
1176                graph_data[test_nbr_cnt_offset + 2],
1177                graph_data[test_nbr_cnt_offset + 3],
1178            ]);
1179
1180            let truth_nbr_count = u32::from_le_bytes([
1181                truth_graph_data[truth_nbr_cnt_offset],
1182                truth_graph_data[truth_nbr_cnt_offset + 1],
1183                truth_graph_data[truth_nbr_cnt_offset + 2],
1184                truth_graph_data[truth_nbr_cnt_offset + 3],
1185            ]);
1186
1187            assert_eq!(test_nbr_count, truth_nbr_count);
1188
1189            // Assert the neighbors (u32) are the same between the test and truth graphs for this node.
1190            let test_nbr_offset = test_nbr_cnt_offset + 4;
1191            let truth_nbr_offset = truth_nbr_cnt_offset + 4;
1192            assert_eq!(
1193                graph_data[test_nbr_offset..test_nbr_offset + test_nbr_count as usize * 4],
1194                truth_graph_data[truth_nbr_offset..truth_nbr_offset + truth_nbr_count as usize * 4]
1195            );
1196        }
1197    }
1198
1199    pub fn node_data_offset(
1200        node_id: usize,
1201        node_length: usize,
1202        nodes_per_block: usize,
1203        block_size: usize,
1204    ) -> usize {
1205        let block_id = node_id / nodes_per_block;
1206        let node_id_in_block = node_id % nodes_per_block;
1207        let offset = block_id * block_size + node_id_in_block * node_length;
1208        offset + block_size
1209    }
1210
1211    fn create_disk_index_builder(
1212        num_points: usize,
1213        dim: usize,
1214        num_of_pq_chunks: usize,
1215        storage_provider: &VirtualStorageProvider<OverlayFS>,
1216        build_quantization_type: QuantizationType,
1217    ) -> ANNResult<
1218        DiskIndexBuilder<'_, GraphDataF32VectorUnitData, VirtualStorageProvider<OverlayFS>>,
1219    > {
1220        let memory_budget = MemoryBudget::try_from_gb(1.0)?;
1221        let num_pq_chunks = NumPQChunks::new_with(num_of_pq_chunks, dim)?;
1222
1223        let build_parameters =
1224            DiskIndexBuildParameters::new(memory_budget, build_quantization_type, num_pq_chunks);
1225
1226        let index_configuration = IndexConfiguration::new(
1227            L2,
1228            dim,
1229            num_points,
1230            ONE,
1231            1,
1232            config::Builder::new_with(4, config::MaxDegree::default_slack(), 50, L2.into(), |b| {
1233                b.saturate_after_prune(true);
1234            })
1235            .build()?,
1236        );
1237
1238        let disk_index_writer = DiskIndexWriter::new(
1239            "data_path".to_string(),
1240            "index_path_prefix".to_string(),
1241            None,
1242            DEFAULT_DISK_SECTOR_LEN,
1243        )?;
1244
1245        DiskIndexBuilder::<GraphDataF32VectorUnitData, VirtualStorageProvider<OverlayFS>>::new(
1246            storage_provider,
1247            build_parameters,
1248            index_configuration,
1249            disk_index_writer,
1250        )
1251    }
1252}
1253
/// Tests for `estimate_build_index_ram_usage`.
#[cfg(test)]
mod ram_estimation_tests {
    use rstest::rstest;

    use super::*;
    use crate::QuantizationType;

    /// Re-derives the expected RAM estimate for each quantization type and
    /// checks the production function for exact f64 equality.
    #[rstest]
    #[case(QuantizationType::FP)]
    #[case(QuantizationType::PQ { num_chunks: 15 })]
    #[case(QuantizationType::SQ { nbits: 1, standard_deviation: None })]
    fn test_estimate_build_index_ram_usage(#[case] build_quantization_type: QuantizationType) {
        let num_points = 1000;
        let dim = 128;
        // Vectors are f32 here, so each element occupies 4 bytes.
        let size_of_t = std::mem::size_of::<f32>() as u64;
        let graph_degree = 50;

        // Per-vector in-memory size, mirroring the production formula.
        // NOTE(review): for FP the production code uses
        // `dim.next_multiple_of(8) * datasize`; the simpler `dim * size_of_t`
        // here is only equivalent because dim = 128 is already a multiple of 8.
        let single_vec_size = match build_quantization_type {
            QuantizationType::FP => dim * size_of_t,
            QuantizationType::PQ { num_chunks } => num_chunks as u64,
            // Packed bits plus an f32 of compensation metadata per vector.
            QuantizationType::SQ { nbits, .. } => {
                (nbits as u64 * dim).div_ceil(8) + std::mem::size_of::<f32>() as u64
            }
        };
        // Graph adjacency (u32 neighbors with slack) plus vector storage,
        // then inflated by the fixed overhead factor.
        let mut expected_ram_usage = (num_points as f64)
            * (graph_degree as f64)
            * (std::mem::size_of::<u32>() as f64)
            * GRAPH_SLACK_FACTOR
            + (num_points * single_vec_size) as f64;
        expected_ram_usage *= OVERHEAD_FACTOR;

        let actual_ram_usage = estimate_build_index_ram_usage(
            num_points,
            dim,
            size_of_t,
            graph_degree,
            &build_quantization_type,
        );

        // Exact equality relies on both sides performing the same f64
        // operations in the same order as the production function.
        assert_eq!(actual_ram_usage, expected_ram_usage);
    }
}