// diskann_disk/build/builder/core.rs
/*
 * Copyright (c) Microsoft Corporation.
 * Licensed under the MIT license.
 */
5use std::mem::{self, size_of};
6
7use diskann::ANNResult;
8use diskann_providers::storage::{StorageReadProvider, StorageWriteProvider};
9use diskann_providers::{
10    model::{
11        graph::traits::GraphDataType, IndexConfiguration, GRAPH_SLACK_FACTOR,
12        MAX_PQ_TRAINING_SET_SIZE,
13    },
14    storage::PQStorage,
15    utils::{
16        load_metadata_from_file, RayonThreadPool, SampleVectorReader, SamplingDensity,
17        READ_WRITE_BLOCK_SIZE,
18    },
19};
20use diskann_utils::io::read_bin;
21use rand::{seq::SliceRandom, Rng};
22use tracing::info;
23
24use crate::{
25    build::chunking::{
26        checkpoint::{
27            CheckpointContext, CheckpointManager, CheckpointManagerExt, Progress, WorkStage,
28        },
29        continuation::ChunkingConfig,
30    },
31    disk_index_build_parameter::BYTES_IN_GB,
32    storage::{CachedReader, CachedWriter, DiskIndexWriter},
33    utils::partition_with_ram_budget,
34    DiskIndexBuildParameters, QuantizationType,
35};
36
37/// Overhead factor for RAM estimation during index build (10% buffer).
38const OVERHEAD_FACTOR: f64 = 1.1f64;
39
40/// Estimate RAM usage in bytes for building an index.
41#[inline]
42fn estimate_build_index_ram_usage(
43    num_points: u64,
44    dim: u64,
45    datasize: u64,
46    graph_degree: u64,
47    build_quantization_type: &QuantizationType,
48) -> f64 {
49    let graph_size =
50        (num_points * graph_degree * mem::size_of::<u32>() as u64) as f64 * GRAPH_SLACK_FACTOR;
51
52    let single_vec_size = match *build_quantization_type {
53        QuantizationType::FP => dim.next_multiple_of(8u64) * datasize,
54        // We can skip PQ pivots data as it is very small(~3MB) for even large datasets like OAI-3072.
55        QuantizationType::PQ { num_chunks } => num_chunks as u64,
56        // `+ std::mem::size_of::<f32>()` for f32 compensation metadata for the scalar quantizer.
57        QuantizationType::SQ { nbits, .. } => {
58            (nbits as u64 * dim).div_ceil(8) + std::mem::size_of::<f32>() as u64
59        }
60    };
61
62    OVERHEAD_FACTOR * (graph_size + (single_vec_size * num_points) as f64)
63}
64
/// Core shared functionality between sync and async disk index builders.
/// Contains only fields and methods that are truly needed by both builder types.
pub struct DiskIndexBuilderCore<'a, Data, StorageProvider>
where
    Data: GraphDataType<VectorIdType = u32>,
    StorageProvider: StorageReadProvider + StorageWriteProvider,
{
    /// Writes the on-disk index layout and derives the names of build artifact files.
    pub index_writer: DiskIndexWriter,

    /// Handle to product-quantization (PQ) storage for this build.
    pub pq_storage: PQStorage,

    /// User-supplied build parameters (memory budget, quantization type, ...).
    pub disk_build_param: DiskIndexBuildParameters,

    /// Graph/index configuration shared by both builder types.
    pub index_configuration: IndexConfiguration,

    /// Configuration for the chunked (resumable) build path.
    pub chunking_config: ChunkingConfig,

    /// Records per-stage progress so an interrupted build can resume at a stage boundary.
    pub checkpoint_record_manager: Box<dyn CheckpointManager>,

    /// Backend used for all file reads and writes during the build.
    pub storage_provider: &'a StorageProvider,

    /// Ties the generic vector data type `Data` to this struct without storing a value of it.
    pub _phantom: std::marker::PhantomData<Data>,
}
88
impl<'a, Data, StorageProvider> DiskIndexBuilderCore<'a, Data, StorageProvider>
where
    Data: GraphDataType<VectorIdType = u32>,
    StorageProvider: StorageReadProvider + StorageWriteProvider,
{
    /// Write the final on-disk index layout, then remove intermediate build files.
    ///
    /// The layout write runs under the `WriteDiskLayout` checkpoint stage so a
    /// resumed build can skip it; the cleanup runs unconditionally afterwards.
    pub(crate) fn create_disk_layout(&mut self) -> ANNResult<()> {
        self.checkpoint_record_manager.execute_stage(
            WorkStage::WriteDiskLayout,
            WorkStage::End,
            || {
                self.index_writer
                    .create_disk_layout::<Data, StorageProvider>(self.storage_provider)?;
                Ok(())
            },
            // Nothing to recover when resuming past this stage.
            || Ok(()),
        )?;

        self.index_writer
            .index_build_cleanup(self.storage_provider)?;

        Ok(())
    }

    /// Derive the `IndexConfiguration` used to build one shard's in-memory index.
    ///
    /// The result is a clone of the base configuration with (a) the degree
    /// reduced to 2/3 of the pruned degree (the merge step later combines
    /// shard graphs up to the full degree) and (b) `max_points` taken from the
    /// shard data file's metadata header.
    pub(crate) fn create_shard_index_config(
        &self,
        shard_base_file: &str,
    ) -> ANNResult<IndexConfiguration> {
        let base_config = &self.index_configuration;
        let storage_provider = self.storage_provider;

        let search_list_size = base_config.config.l_build().get();
        let pruned_degree = base_config.config.pruned_degree().get();

        let low_degree_params = diskann::graph::config::Builder::new(
            2 * pruned_degree / 3,
            diskann::graph::config::MaxDegree::default_slack(),
            search_list_size,
            base_config.dist_metric.into(),
        )
        .build()?;

        let metadata = load_metadata_from_file(storage_provider, shard_base_file)?;

        let mut index_config = base_config.clone();
        index_config.max_points = metadata.npoints();
        index_config.config = low_degree_params;

        Ok(index_config)
    }

    /// Copy the vectors listed in `shard_ids_file` out of `dataset_file` and
    /// write them to `shard_base_file` as a bin file (u32 point count,
    /// u32 dimension, then the raw vectors).
    ///
    /// A zero placeholder is written for the point count first and overwritten
    /// at the end, once the true number of copied vectors is known.
    pub(crate) fn retrieve_shard_data_from_ids<T>(
        &self,
        dataset_file: &str,
        shard_ids_file: &str,
        shard_base_file: &str,
    ) -> ANNResult<()>
    where
        T: Default + bytemuck::Pod,
    {
        let storage_provider = self.storage_provider;
        let shard_ids = read_bin::<u32>(&mut storage_provider.open_reader(shard_ids_file)?)?;
        let shard_size = shard_ids.nrows();
        info!("Loaded {} shard ids from {}", shard_size, shard_ids_file);
        let max_id = shard_ids.as_slice().iter().max().copied().unwrap_or(0);
        // Approximate fraction of the dataset covered by this shard.
        let sampling_rate = shard_ids.as_slice().len() as f64 / (max_id + 1) as f64;

        let mut dataset_reader: SampleVectorReader<T, _> = SampleVectorReader::new(
            dataset_file,
            SamplingDensity::from_sample_rate(sampling_rate),
            storage_provider,
        )?;

        let (_npts, dim) = dataset_reader.get_dataset_headers();

        let mut shard_base_cached_writer = CachedWriter::<StorageProvider>::new(
            shard_base_file,
            READ_WRITE_BLOCK_SIZE,
            storage_provider.create_for_write(shard_base_file)?,
        )?;

        // Placeholder point count; replaced after all vectors are written.
        let dummy_size: u32 = 0;
        shard_base_cached_writer.write(&dummy_size.to_le_bytes())?;
        shard_base_cached_writer.write(&dim.to_le_bytes())?;

        let mut num_written: u32 = 0;
        dataset_reader.read_vectors(shard_ids.as_slice().iter().copied(), |vector_t| {
            // Casting Pod type to bytes always succeeds (u8 has alignment of 1)
            let vector_bytes: &[u8] = bytemuck::must_cast_slice(vector_t);
            shard_base_cached_writer.write(vector_bytes)?;
            num_written += 1;
            Ok(())
        })?;

        info!(
            "Written file: {} with {} points",
            shard_base_file, num_written
        );

        // Seek back to the start and replace the placeholder with the real count.
        shard_base_cached_writer.flush()?;
        shard_base_cached_writer.reset()?;
        shard_base_cached_writer.write(&num_written.to_le_bytes())?;

        Ok(())
    }

    /// Merge `num_parts` per-shard in-memory Vamana graphs into a single graph
    /// file at `output_vamana`.
    ///
    /// Each shard stores nodes under local ids; the shard id maps translate
    /// them back to global ids. For every global node, neighbor lists from all
    /// shards containing it are deduplicated, shuffled, and truncated to
    /// `max_degree`. The file's size header is written last (after a reset)
    /// because it is only known once the merge completes.
    #[allow(clippy::too_many_arguments)]
    fn merge_shards(
        &self,
        merged_index_prefix: &str,
        num_parts: usize,
        max_degree: u32,
        output_vamana: String,
        rng: &mut impl Rng,
    ) -> ANNResult<()> {
        // Read ID maps
        let mut vamana_names = vec![String::new(); num_parts];
        let mut id_maps: Vec<Vec<u32>> = vec![Vec::new(); num_parts];
        for shard in 0..num_parts {
            vamana_names[shard] = DiskIndexWriter::get_merged_index_subshard_mem_index_file(
                merged_index_prefix,
                shard,
            );

            let id_maps_file =
                DiskIndexWriter::get_merged_index_subshard_id_map_file(merged_index_prefix, shard);
            id_maps[shard] = self.read_idmap(id_maps_file)?;
        }

        // find max node id
        let num_nodes: u32 = *id_maps.iter().flatten().max().unwrap_or(&0) + 1;
        let num_elements: u32 = id_maps.iter().map(|idmap| idmap.len() as u32).sum();
        info!("# nodes: {}, max degree: {}", num_nodes, max_degree);

        // compute inverse map: node -> shards
        let mut node_shard: Vec<(u32, u32)> = Vec::with_capacity(num_elements as usize);
        for (shard, id_map) in id_maps.iter().enumerate() {
            info!("Creating inverse map -- shard #{}", shard);
            node_shard.extend(id_map.iter().map(|node_id| (*node_id, shard as u32)));
        }
        // Sort by (global node id, shard id) so each node's shard entries are
        // contiguous and visited in ascending node order below.
        node_shard.sort_unstable_by(|left, right| {
            left.0.cmp(&right.0).then_with(|| left.1.cmp(&right.1))
        });

        info!("Finished computing node -> shards map");

        // create cached vamana readers
        let mut vamana_readers = Vec::new();
        for name in &vamana_names {
            let reader = CachedReader::<StorageProvider>::new(
                name,
                READ_WRITE_BLOCK_SIZE,
                self.storage_provider,
            )?;
            vamana_readers.push(reader);
        }

        // create cached vamana writers
        let mut merged_vamana_cached_writer = CachedWriter::<StorageProvider>::new(
            &output_vamana,
            READ_WRITE_BLOCK_SIZE,
            self.storage_provider.create_for_write(&output_vamana)?,
        )?;

        // expected file size + max degree + medoid_id + frozen_point info
        let vamana_metadata_size =
            size_of::<u64>() + size_of::<u32>() + size_of::<u32>() + size_of::<u64>();

        // we initialize the size of the merged index to the metadata size
        // we will overwrite the index size at the end
        let mut merged_index_size: u64 = vamana_metadata_size as u64;
        merged_vamana_cached_writer.write(&merged_index_size.to_le_bytes())?;

        let mut read_buf_8_bytes = [0u8; 8];

        // get max input width
        let mut max_input_width = 0;
        // read width from each vamana to advance buffer by sizeof(uint32_t) bytes
        for reader in &mut vamana_readers {
            reader.read(&mut read_buf_8_bytes)?;
            let _expected_file_size: u64 = u64::from_le_bytes(read_buf_8_bytes);
            let input_width = reader.read_u32()?;
            max_input_width = input_width.max(max_input_width);
        }

        // write max_degree to merged_vamana_index
        let output_width: u32 = max_degree;
        info!(
            "Max input width: {}, output width: {}",
            max_input_width, output_width
        );

        merged_vamana_cached_writer.write(&output_width.to_le_bytes())?;

        // write medoid to merged_vamana_index
        for shard in 0..num_parts {
            // read medoid
            let mut medoid: u32 = vamana_readers[shard].read_u32()?;
            vamana_readers[shard].read(&mut read_buf_8_bytes)?;
            let vamana_index_frozen: u64 = u64::from_le_bytes(read_buf_8_bytes);
            debug_assert_eq!(vamana_index_frozen, 0);

            // rename medoid
            medoid = id_maps[shard][medoid as usize];

            // write renamed medoid
            // NOTE(review): every shard's medoid is read to advance the reader,
            // but only the last shard's renamed medoid is written to the output.
            if shard == (num_parts - 1) {
                // uncomment if running hierarchical
                merged_vamana_cached_writer.write(&medoid.to_le_bytes())?;
            }
        }

        let vamana_index_frozen: u64 = 0; // as of now the functionality to merge many overlapping vamana
                                          // indices is supported only for bulk indices without frozen point.
                                          // Hence the final index will also not have any frozen points.
        merged_vamana_cached_writer.write(&vamana_index_frozen.to_le_bytes())?;

        info!("Starting merge");

        // nbr_set[id] marks whether `id` is already in final_nbrs (dedup across shards).
        let mut nbr_set = vec![false; num_nodes as usize];
        let mut final_nbrs: Vec<u32> = Vec::new();
        let mut cur_id = 0;
        for pair in &node_shard {
            let (node_id, shard_id) = *pair;
            // Entries are sorted by node id, so seeing a larger id means the
            // previous node's neighbor list is complete: flush it.
            if cur_id < node_id {
                // Shuffle before truncation so the kept `max_degree` neighbors
                // are an unbiased sample of all accumulated neighbors.
                final_nbrs.shuffle(rng);

                let nnbrs: u32 = std::cmp::min(final_nbrs.len() as u32, max_degree);
                merged_vamana_cached_writer.write(&nnbrs.to_le_bytes())?;

                let bytes = final_nbrs
                    .iter()
                    .take(nnbrs as usize)
                    .flat_map(|x| x.to_le_bytes())
                    .collect::<Vec<u8>>();
                merged_vamana_cached_writer.write(&bytes)?;

                merged_index_size += (size_of::<u32>() + nnbrs as usize * size_of::<u32>()) as u64;
                // Coarse progress indicator, roughly every 500k nodes.
                if cur_id % 499999 == 1 {
                    print!(".");
                }
                cur_id = node_id;

                final_nbrs.iter().for_each(|p| nbr_set[*p as usize] = false);
                final_nbrs.clear();
            }

            // read num of neighbors from vamana index
            let num_nbrs = vamana_readers[shard_id as usize].read_u32()?;

            if num_nbrs == 0 {
                info!(
                    "WARNING: shard #{}, node_id {} has 0 nbrs",
                    shard_id, node_id
                );
            } else {
                let mut nbrs_bytes = vec![0u8; num_nbrs as usize * mem::size_of::<u32>()];
                vamana_readers[shard_id as usize].read(&mut nbrs_bytes)?;
                // NOTE(review): `cast_slice` u8 -> u32 requires the buffer to be
                // 4-byte aligned; a `Vec<u8>` allocation does not guarantee that.
                // Consider `bytemuck::pod_collect_to_vec` — TODO confirm.
                let nbrs: &[u32] = bytemuck::cast_slice(&nbrs_bytes);

                // rename nodes
                for j in 0..num_nbrs {
                    let nbr = nbrs[j as usize];
                    let renamed_node = id_maps[shard_id as usize][nbr as usize];
                    if !nbr_set[renamed_node as usize] {
                        nbr_set[renamed_node as usize] = true;
                        final_nbrs.push(renamed_node);
                    }
                }
            }
        }

        // write the last node, to be refactored...
        final_nbrs.shuffle(rng);

        let nnbrs: u32 = std::cmp::min(final_nbrs.len() as u32, max_degree);
        merged_vamana_cached_writer.write(&nnbrs.to_le_bytes())?;

        let bytes = final_nbrs
            .iter()
            .take(nnbrs as usize)
            .flat_map(|x| x.to_le_bytes())
            .collect::<Vec<u8>>();
        merged_vamana_cached_writer.write(&bytes)?;

        merged_index_size += (size_of::<u32>() + nnbrs as usize * size_of::<u32>()) as u64;

        nbr_set.clear();
        final_nbrs.clear();

        info!("Expected size: {}", merged_index_size);
        // Seek back and overwrite the placeholder size header with the real size.
        merged_vamana_cached_writer.reset()?;
        merged_vamana_cached_writer.write(&merged_index_size.to_le_bytes())?;

        info!("Finished merge");
        Ok(())
    }

    /// Load a shard's local-to-global id map (a flat `u32` bin file).
    fn read_idmap(&self, idmaps_path: String) -> Result<Vec<u32>, diskann_utils::io::ReadBinError> {
        let data = read_bin::<u32>(&mut self.storage_provider.open_reader(&idmaps_path)?)?;
        Ok(data.into_inner().into_vec())
    }

    /// Merge all shard graphs into the final in-memory index file, then delete
    /// the per-shard temporary files.
    fn merge_shards_and_cleanup(
        &self,
        merged_index_prefix: &str,
        num_parts: usize,
        max_degree: u32,
        rng: &mut impl Rng,
    ) -> ANNResult<()> {
        // merge all in-memory indices into one
        self.merge_shards(
            merged_index_prefix,
            num_parts,
            max_degree,
            self.index_writer.get_mem_index_file(),
            rng,
        )?;

        // delete tempFiles
        for p in 0..num_parts {
            let shard_base_file =
                DiskIndexWriter::get_merged_index_subshard_data_file(merged_index_prefix, p);
            let shard_ids_file =
                DiskIndexWriter::get_merged_index_subshard_id_map_file(merged_index_prefix, p);
            let shard_index_file =
                DiskIndexWriter::get_merged_index_subshard_mem_index_file(merged_index_prefix, p);
            let shard_index_file_data =
                DiskIndexWriter::get_merged_index_subshard_mem_dataset_file(&shard_index_file);

            self.storage_provider.delete(&shard_base_file)?;
            self.storage_provider.delete(&shard_ids_file)?;
            self.storage_provider.delete(&shard_index_file)?;
            // Check if shard dataset file exists before deleting it.
            // Async build path doesn't always create this file.
            if self.storage_provider.exists(&shard_index_file_data) {
                self.storage_provider.delete(&shard_index_file_data)?;
            }
        }

        Ok(())
    }
}
431
/// Strategy for building the disk index, chosen from the RAM budget.
pub(crate) enum IndexBuildStrategy {
    /// The whole index fits in the RAM budget and is built in one pass.
    OneShot,
    /// The data is partitioned into shards that are built separately and merged.
    Merged,
}
436
437pub(crate) fn determine_build_strategy<Data: GraphDataType>(
438    index_configuration: &IndexConfiguration,
439    index_build_ram_limit_in_bytes: f64,
440    build_quantization_type: &QuantizationType,
441) -> IndexBuildStrategy {
442    let estimated_index_ram_in_bytes = estimate_build_index_ram_usage(
443        index_configuration.max_points as u64,
444        index_configuration.dim as u64,
445        mem::size_of::<Data::VectorDataType>() as u64,
446        index_configuration.config.max_degree().get() as u64,
447        build_quantization_type,
448    );
449
450    info!(
451        "Estimated index RAM usage: {} GB, index_build_ram_limit={} GB",
452        estimated_index_ram_in_bytes / BYTES_IN_GB,
453        index_build_ram_limit_in_bytes / BYTES_IN_GB
454    );
455
456    if estimated_index_ram_in_bytes >= index_build_ram_limit_in_bytes {
457        info!(
458            "Insufficient memory budget for index build in one shot, index_build_ram_limit={} GB estimated_index_ram={} GB",
459            index_build_ram_limit_in_bytes / BYTES_IN_GB,
460            estimated_index_ram_in_bytes / BYTES_IN_GB,
461        );
462        IndexBuildStrategy::Merged
463    } else {
464        info!(
465            "Full index fits in RAM budget, should consume at most {} GBs, so building in one shot",
466            estimated_index_ram_in_bytes / BYTES_IN_GB
467        );
468        IndexBuildStrategy::OneShot
469    }
470}
471
/// State for the sharded ("merged Vamana") build path: partition the dataset,
/// build an index per shard, then merge the shard graphs and clean up.
pub(crate) struct MergedVamanaIndexWorkflow<'a> {
    /// Thread pool used by the partitioning step.
    pool: &'a RayonThreadPool,
    /// RNG seeded from the index configuration; used for partitioning and
    /// for shuffling neighbor lists during the merge.
    rng: diskann_providers::utils::StandardRng,
    /// Path of the full input dataset file.
    dataset_file: String,
    /// Target maximum degree of the final merged graph.
    max_degree: u32,
    /// File-name prefix shared by all per-shard intermediate files.
    pub merged_index_prefix: String,
}
479
impl<'a> MergedVamanaIndexWorkflow<'a> {
    /// Capture from `builder` everything the sharded workflow needs: a seeded
    /// RNG, the dataset and merged-index file names, and the target degree.
    pub(crate) fn new<Data, StorageProvider>(
        builder: &mut DiskIndexBuilderCore<'_, Data, StorageProvider>,
        pool: &'a RayonThreadPool,
    ) -> Self
    where
        Data: GraphDataType<VectorIdType = u32>,
        StorageProvider: StorageReadProvider + StorageWriteProvider,
    {
        let rng = diskann_providers::utils::create_rnd_from_optional_seed(
            builder.index_configuration.random_seed,
        );
        let dataset_file = builder.index_writer.get_dataset_file();
        let merged_index_prefix = builder.index_writer.get_merged_index_prefix();
        let max_degree = builder.index_configuration.config.pruned_degree_u32().get();

        Self {
            pool,
            rng,
            dataset_file,
            merged_index_prefix,
            max_degree,
        }
    }

    /// Partition the dataset into shards sized to fit the build RAM budget and
    /// return the number of partitions.
    ///
    /// Runs under the `PartitionData` checkpoint stage; when resuming a build
    /// whose partitioning already completed, the partition count is recovered
    /// by probing which shard id-map files exist.
    pub(crate) fn partition_data<Data, StorageProvider>(
        &mut self,
        builder: &mut DiskIndexBuilderCore<'_, Data, StorageProvider>,
    ) -> ANNResult<usize>
    where
        Data: GraphDataType<VectorIdType = u32>,
        StorageProvider: StorageReadProvider + StorageWriteProvider,
    {
        // Advance to PartitionData stage if current stage is InMemIndexBuild
        builder.checkpoint_record_manager.execute_stage(
            WorkStage::InMemIndexBuild,
            WorkStage::PartitionData,
            || Ok(()),
            || Ok(()),
        )?;

        // Partition data stage
        builder.checkpoint_record_manager.execute_stage(
            WorkStage::PartitionData,
            WorkStage::BuildIndicesOnShards(0),
            || {
                let num_points = builder.index_configuration.max_points;
                // Sample only enough points for PQ training.
                let sampling_rate = MAX_PQ_TRAINING_SET_SIZE / num_points as f64;

                let ram_budget_in_bytes =
                    builder.disk_build_param.build_memory_limit().in_bytes() as f64;
                // calculate how many partitions we need, in order to fit in RAM budget
                // save id_map for each partition to disk
                partition_with_ram_budget::<Data::VectorDataType, _, _, _>(
                    &self.dataset_file,
                    builder.index_configuration.dim,
                    sampling_rate,
                    ram_budget_in_bytes,
                    2, // k_base
                    &self.merged_index_prefix,
                    builder.storage_provider,
                    &mut self.rng,
                    self.pool,
                    |num_points, dim| {
                        // Shards are built at 2/3 of the final degree, so the
                        // per-shard RAM estimate uses the reduced degree.
                        let datasize = std::mem::size_of::<Data::VectorDataType>() as u64;
                        let graph_degree = 2 * self.max_degree / 3;
                        estimate_build_index_ram_usage(
                            num_points,
                            dim,
                            datasize,
                            graph_degree as u64,
                            builder.disk_build_param.build_quantization(),
                        )
                    },
                )
            },
            || {
                // load num_parts based on file names
                let mut p = 0;
                while builder.storage_provider.exists(
                    &DiskIndexWriter::get_merged_index_subshard_id_map_file(
                        &self.merged_index_prefix,
                        p,
                    ),
                ) {
                    p += 1;
                }
                info!("Found {} existing partitions from previous run", p);
                Ok(p)
            },
        )
    }

    /// Merge the shard graphs and delete shard temp files — but only when the
    /// checkpoint record shows the `MergeIndices` stage still has to run. On
    /// success, the checkpoint is updated to `Completed` with `WriteDiskLayout`
    /// as the next stage.
    pub(crate) fn merge_and_cleanup<Data, StorageProvider>(
        &mut self,
        builder: &mut DiskIndexBuilderCore<'_, Data, StorageProvider>,
        num_parts: usize,
    ) -> ANNResult<()>
    where
        Data: GraphDataType<VectorIdType = u32>,
        StorageProvider: StorageReadProvider + StorageWriteProvider,
    {
        if builder
            .checkpoint_record_manager
            .get_resumption_point(WorkStage::MergeIndices)?
            .is_some()
        {
            builder.merge_shards_and_cleanup(
                &self.merged_index_prefix,
                num_parts,
                self.max_degree,
                &mut self.rng,
            )?;
            builder
                .checkpoint_record_manager
                .update(Progress::Completed, WorkStage::WriteDiskLayout)?;
        }

        Ok(())
    }

    /// Build the checkpoint context for shard `p`: its current stage is
    /// `BuildIndicesOnShards(p)` and the next stage is either the next shard
    /// or `MergeIndices` after the last shard.
    pub(crate) fn get_shard_context<'b, Data, StorageProvider>(
        &self,
        builder: &'b DiskIndexBuilderCore<'_, Data, StorageProvider>,
        p: usize,
        num_parts: usize,
    ) -> CheckpointContext<'b>
    where
        Data: GraphDataType<VectorIdType = u32>,
        StorageProvider: StorageReadProvider + StorageWriteProvider,
    {
        let current_stage = WorkStage::BuildIndicesOnShards(p);
        let next_stage = if p == num_parts - 1 {
            // If this is the last shard, next stage is MergeIndices
            WorkStage::MergeIndices
        } else {
            // Otherwise, continue with the next shard
            WorkStage::BuildIndicesOnShards(p + 1)
        };
        CheckpointContext::new(
            builder.checkpoint_record_manager.as_ref(),
            current_stage,
            next_stage,
        )
    }
}
626
627#[cfg(test)]
628pub(crate) mod disk_index_builder_tests {
629    use std::{io::Read, sync::Arc};
630
631    use diskann::{
632        graph::config,
633        utils::{IntoUsize, VectorRepr, ONE},
634        ANNResult,
635    };
636    use diskann_providers::storage::VirtualStorageProvider;
637    use diskann_providers::{
638        common::AlignedBoxWithSlice,
639        storage::{get_compressed_pq_file, get_disk_index_file, get_pq_pivot_file},
640        test_utils::graph_data_type_utils::{
641            GraphDataF32VectorU32Data, GraphDataF32VectorUnitData,
642        },
643        utils::Timer,
644    };
645    use diskann_utils::test_data_root;
646    use diskann_vector::{
647        distance::Metric::{self, L2},
648        DistanceFunction,
649    };
650    use rstest::rstest;
651    use vfs::OverlayFS;
652
653    use super::*;
654    use crate::{
655        build::builder::build::DiskIndexBuilder,
656        data_model::{CachingStrategy, GraphHeader},
657        disk_index_build_parameter::{DiskIndexBuildParameters, MemoryBudget, NumPQChunks},
658        search::provider::{
659            disk_provider::DiskIndexSearcher,
660            disk_vertex_provider_factory::DiskVertexProviderFactory,
661        },
662        storage::disk_index_reader::DiskIndexReader,
663        utils::{QueryStatistics, VirtualAlignedReaderFactory},
664    };
    /// Sector size used when laying out the disk index in tests.
    const DEFAULT_DISK_SECTOR_LEN: usize = 4096;
    /// Bundled 256-point SIFT-small dataset used by these tests.
    pub const TEST_DATA_FILE: &str = "/sift/siftsmall_learn_256pts.fbin";
    /// We can use the same index prefix for all tests since we use virtual storage provider
    const INDEX_PATH_PREFIX: &str = "/disk_index_build/sift_learn_test_disk_index_build";
    /// Golden (truth) index prefix for builds with R=4, L=50.
    const TRUTH_INDEX_PATH_PREFIX_R4_L50: &str = "/disk_index_build/truth_sift_learn_R4_L50";
670
    /// Checkpoint/chunking settings for resumable-build tests.
    pub struct CheckpointParams {
        /// Chunked-build configuration passed to the builder.
        pub chunking_config: ChunkingConfig,
        /// Checkpoint manager recording per-stage progress for resumption.
        pub checkpoint_record_manager: Box<dyn CheckpointManager>,
    }
675
    /// Parameters describing one disk-index build test scenario.
    pub struct TestParams {
        /// Vector dimension; must match the data file's metadata header.
        pub dim: usize,
        /// Full (unquantized) dimension, used to size the number of PQ chunks.
        pub full_dim: usize,
        /// Graph max degree (R).
        pub max_degree: u32,
        /// Number of PQ chunks for compressed vectors.
        pub num_pq_chunks: usize,
        /// Quantization used during the build (FP = none).
        pub build_quantization_type: QuantizationType,
        /// Build-time search list size (L).
        pub l_build: u32,
        /// Path of the input dataset file.
        pub data_path: String,
        /// Prefix for all generated index files.
        pub index_path_prefix: String,
        /// Optional associated-data file path.
        pub associated_data_path: Option<String>,
        /// RAM budget for the build, in GB.
        pub index_build_ram_gb: f64,
        /// Optional checkpointing setup for resumable-build tests.
        pub checkpoint_params: Option<CheckpointParams>,
        /// Number of build threads.
        pub num_threads: usize,
        /// Distance metric used for the build.
        pub metric: Metric,
    }
691
    impl Default for TestParams {
        /// Defaults matching the bundled 256-point SIFT-small test dataset.
        fn default() -> Self {
            Self {
                dim: 128, // D
                full_dim: 128,
                max_degree: 4, // R
                num_pq_chunks: 128,
                build_quantization_type: QuantizationType::FP, // No quantization, i.e. QuantizationType::FP
                l_build: 50,
                data_path: TEST_DATA_FILE.to_string(),
                index_path_prefix: INDEX_PATH_PREFIX.to_string(),
                associated_data_path: None,
                index_build_ram_gb: 1.0,
                checkpoint_params: None,
                num_threads: 1,
                metric: L2,
            }
        }
    }
711
712    impl TestParams {
713        /// Returns the appropriate truth index path prefix for build comparison.
714        fn truth_index_path_prefix(&self) -> &str {
715            match (self.max_degree, self.l_build, self.index_build_ram_gb) {
716                (4, 50, 1.0) => TRUTH_INDEX_PATH_PREFIX_R4_L50,
717                (max_degree, l_build, index_build_ram_gb) => panic!(
718                    "Truth index path not found for max_degree={}, l_build={}, index_build_ram_gb={}",
719                    max_degree, l_build, index_build_ram_gb
720                ),
721            }
722        }
723        pub fn truth_pq_compressed_path(&self) -> String {
724            let prefix = match self.num_pq_chunks {
725                128 => TRUTH_INDEX_PATH_PREFIX_R4_L50,
726                num_pq_chunks => panic!(
727                    "Truth pq compressed path not found for num_pq_chunks={}",
728                    num_pq_chunks,
729                ),
730            };
731            get_compressed_pq_file(prefix)
732        }
733
734        pub fn pq_compressed_path(&self) -> String {
735            get_compressed_pq_file(&self.index_path_prefix)
736        }
737    }
738
    /// Creates an overlay virtual storage provider rooted at the shared test data directory.
    pub fn new_vfs() -> VirtualStorageProvider<OverlayFS> {
        VirtualStorageProvider::new_overlay(test_data_root())
    }
742
    /// Test fixture bundling a storage provider with the build parameters.
    pub struct IndexBuildFixture<StorageProvider: StorageReadProvider + StorageWriteProvider> {
        /// Shared storage backend used for the build and all file comparisons.
        pub storage_provider: Arc<StorageProvider>,
        /// Scenario parameters for this fixture.
        pub params: TestParams,
    }
747
748    impl<StorageProvider: StorageReadProvider + StorageWriteProvider + 'static>
749        IndexBuildFixture<StorageProvider>
750    {
751        pub fn new(storage_provider: StorageProvider, params: TestParams) -> ANNResult<Self> {
752            Ok(Self {
753                storage_provider: Arc::new(storage_provider),
754                params,
755            })
756        }
757
        /// Runs a full disk-index build for data type `T` using the fixture's parameters.
        ///
        /// Asserts that the configured dimension matches the data file's header
        /// before building; uses a fixed PRNG seed so builds are deterministic.
        pub fn build<T>(&self) -> ANNResult<()>
        where
            T: GraphDataType<VectorIdType = u32>,
            StorageProvider::Reader: std::marker::Send + Read,
        {
            // Create disk index build parameters
            let disk_index_build_parameters = DiskIndexBuildParameters::new(
                MemoryBudget::try_from_gb(self.params.index_build_ram_gb)?,
                self.params.build_quantization_type,
                NumPQChunks::new_with(self.params.num_pq_chunks, self.params.full_dim)?,
            );

            // Graph build parameters (R, L, metric) with saturation after prune enabled.
            let config = config::Builder::new_with(
                self.params.max_degree.into_usize(),
                config::MaxDegree::default_slack(),
                self.params.l_build.into_usize(),
                self.params.metric.into(),
                |b| {
                    b.saturate_after_prune(true);
                },
            )
            .build()?;

            let metadata =
                load_metadata_from_file(self.storage_provider.as_ref(), &self.params.data_path)
                    .unwrap();

            assert_eq!(
                self.params.dim,
                metadata.ndims(),
                "Parameters dimension {} and data dimension {} are not equal",
                self.params.dim,
                metadata.ndims(),
            );

            // Fixed seed keeps the built graph reproducible across test runs.
            let config = IndexConfiguration::new(
                self.params.metric,
                self.params.dim,
                metadata.npoints(),
                ONE,
                self.params.num_threads,
                config,
            )
            .with_pseudo_rng_from_seed(100);

            let disk_index_writer = DiskIndexWriter::new(
                self.params.data_path.clone(),
                self.params.index_path_prefix.clone(),
                self.params.associated_data_path.clone(),
                DEFAULT_DISK_SECTOR_LEN,
            )?;

            // Use the checkpointing constructor only when checkpoint params are provided.
            let mut disk_index = match self.params.checkpoint_params {
                Some(ref checkpoint_params) => {
                    let checkpoint_record_manager =
                        checkpoint_params.checkpoint_record_manager.clone_box();
                    let chunking_config = checkpoint_params.chunking_config.clone();
                    DiskIndexBuilder::<T, _>::new_with_chunking_config(
                        self.storage_provider.as_ref(),
                        disk_index_build_parameters,
                        config,
                        disk_index_writer,
                        chunking_config,
                        checkpoint_record_manager,
                    )
                }
                None => DiskIndexBuilder::<T, _>::new(
                    self.storage_provider.as_ref(),
                    disk_index_build_parameters,
                    config,
                    disk_index_writer,
                ),
            }?;

            let timer = Timer::new();
            disk_index.build()?;
            println!("Indexing time: {} seconds", timer.elapsed().as_secs_f64());

            Ok(())
        }
838
839        pub fn compare_pq_compressed_files(&self) {
840            self.compare_files(
841                &self.params.pq_compressed_path(),
842                &self.params.truth_pq_compressed_path(),
843            );
844        }
845
846        pub fn assert_index_max_degree<T: GraphDataType>(&self) -> ANNResult<()> {
847            let index_file_path = get_disk_index_file(&self.params.index_path_prefix);
848            let file_data = load_file_to_vec(self.storage_provider.as_ref(), &index_file_path);
849            let graph_header = GraphHeader::try_from(&file_data[8..])?;
850            let max_degree = graph_header.max_degree::<T::VectorDataType>()?;
851            assert_eq!(
852                max_degree, self.params.max_degree as usize,
853                "Max degree mismatch: expected {}, got {}",
854                self.params.max_degree, max_degree
855            );
856
857            Ok(())
858        }
859
860        fn compare_disk_index_with_associated_data(
861            &self,
862            pivot_file_prefix_test: &str,
863            pivot_file_prefix_expected: &str,
864            index_file_suffix: &str,
865        ) {
866            let pq_pivot_path = pivot_file_prefix_test.to_string() + index_file_suffix;
867            let pq_pivot_path_truth = pivot_file_prefix_expected.to_string() + index_file_suffix;
868            let file1 = load_file_to_vec(self.storage_provider.as_ref(), &pq_pivot_path);
869            let file2 = load_file_to_vec(self.storage_provider.as_ref(), &pq_pivot_path_truth);
870            compare_disk_index_graphs(&file1, &file2)
871        }
872
873        pub fn compare_files(&self, file_path1: &str, file_path2: &str) {
874            let file1 = load_file_to_vec(self.storage_provider.as_ref(), file_path1);
875            let file2 = load_file_to_vec(self.storage_provider.as_ref(), file_path2);
876
877            assert_eq!(file1.len(), file2.len());
878            assert_eq!(file1, file2)
879        }
880    }
881
882    /// Common helper function for one-shot async index build tests
883    fn run_one_shot_test<F>(index_path_prefix: String, params_customizer: F)
884    where
885        F: FnOnce(TestParams) -> TestParams,
886    {
887        let l_build = 64;
888        let max_degree = 16;
889        let top_k = 10;
890        let search_l = 32;
891
892        let base_params = TestParams {
893            l_build,
894            max_degree,
895            index_path_prefix,
896            ..TestParams::default()
897        };
898
899        let params = params_customizer(base_params);
900
901        let fixture = IndexBuildFixture::new(new_vfs(), params).unwrap();
902        fixture.build::<GraphDataF32VectorUnitData>().unwrap();
903
904        // Validate search recall against ground truth for async tests
905        verify_search_result_with_ground_truth::<GraphDataF32VectorUnitData>(
906            &fixture.params,
907            top_k,
908            search_l,
909            &fixture.storage_provider,
910        )
911        .unwrap();
912
913        fixture
914            .assert_index_max_degree::<GraphDataF32VectorUnitData>()
915            .unwrap();
916
917        // Assert that all data was kept in memory and no files were written to the disk.
918        let mem_index_file_path = format!("{}_mem.index.data", fixture.params.index_path_prefix);
919        assert!(!fixture.storage_provider.exists(&mem_index_file_path));
920    }
921
922    #[rstest]
923    fn test_build_from_iter_one_shot_with_metric(
924        #[values(Metric::L2, Metric::InnerProduct, Metric::Cosine)] metric: Metric,
925    ) {
926        let index_path_prefix = format!("{}_metric_{:?}", INDEX_PATH_PREFIX, metric);
927
928        run_one_shot_test(index_path_prefix, |params| TestParams { metric, ..params });
929    }
930
931    #[test]
932    fn test_build_from_iter_one_shot_with_associated_data() {
933        // Set up test data
934        let params = TestParams {
935            associated_data_path: Some(
936                "/sift/siftsmall_learn_256pts_u32_associated_data.fbin".to_string(),
937            ),
938            ..TestParams::default()
939        };
940
941        // Create fixture with virtual storage provider
942        let fixture = IndexBuildFixture::new(new_vfs(), params).unwrap();
943
944        // Build the index with the associated data
945        fixture.build::<GraphDataF32VectorU32Data>().unwrap();
946
947        // Assert that all data was kept in memory and no files were written to the disk.
948        let mem_index_file_path = format!("{}_mem.index.data", fixture.params.index_path_prefix);
949        let mem_index_associated_data_path = format!(
950            "{}_mem.index.associated_data",
951            fixture.params.index_path_prefix
952        );
953        assert!(!fixture.storage_provider.exists(&mem_index_file_path));
954        assert!(!fixture
955            .storage_provider
956            .exists(&mem_index_associated_data_path));
957
958        // assert index files are expected.
959        fixture.compare_disk_index_with_associated_data(
960            &fixture.params.index_path_prefix,
961            fixture.params.truth_index_path_prefix(),
962            "_disk.index",
963        );
964    }
965
966    #[test]
967    fn test_build_from_iter_merged_index() {
968        // Use the same parameters from [test_sift_build_and_search] in diskann_index
969        let l_build = 64;
970        let max_degree = 16;
971        let top_k = 10;
972        let search_l = 32;
973
974        let index_path_prefix =
975            "/disk_index_build/disk_index_sift_learn_test_disk_index_build_merged".to_string();
976        let params = TestParams {
977            l_build,
978            max_degree,
979            index_path_prefix,
980            index_build_ram_gb: 0.0001, // small enough to trigger merged index build
981            ..TestParams::default()
982        };
983
984        let fixture = IndexBuildFixture::new(new_vfs(), params).unwrap();
985
986        fixture.build::<GraphDataF32VectorUnitData>().unwrap();
987
988        verify_search_result_with_ground_truth::<GraphDataF32VectorUnitData>(
989            &fixture.params,
990            top_k,
991            search_l,
992            &fixture.storage_provider,
993        )
994        .unwrap();
995
996        fixture
997            .assert_index_max_degree::<GraphDataF32VectorUnitData>()
998            .unwrap();
999    }
1000
1001    #[rstest]
1002    #[case(QuantizationType::SQ { nbits: 2, standard_deviation: None }, "SQ quantization is only supported for 1 bit")]
1003    fn test_build_quantization_type_failure_cases(
1004        #[case] build_quantization_type: QuantizationType,
1005        #[case] error_message: &str,
1006    ) {
1007        let storage_provider = VirtualStorageProvider::new_overlay(test_data_root());
1008        let disk_index_builder = create_disk_index_builder(
1009            1000, // num_points
1010            128,  // dim
1011            128,  // num_pq_chunks
1012            &storage_provider,
1013            build_quantization_type,
1014        );
1015
1016        let err = disk_index_builder.err().unwrap();
1017        assert!(err.to_string().contains(error_message));
1018    }
1019
1020    fn load_file_to_vec<StorageType: StorageReadProvider>(
1021        storage_provider: &StorageType,
1022        file_path: &str,
1023    ) -> Vec<u8> {
1024        let mut file = storage_provider.open_reader(file_path).unwrap();
1025        let mut buffer = vec![];
1026        file.read_to_end(&mut buffer).unwrap();
1027        buffer
1028    }
1029
    /// Verifies that search results exactly match the ground truth of nearest neighbors
    ///
    /// This function performs validation of search results by:
    /// 1. Running searches on the index using actual data points from the dataset as queries
    /// 2. Computing the exact ground truth results using direct distance calculations
    /// 3. Verifying that the search engine returns precisely the same results as the ground truth
    ///
    /// Panics (via the `assert_top_k_exactly_match` helper) when any query's
    /// top-k results deviate from the exact ground truth; returns an error if
    /// reading index artifacts or constructing the searcher fails.
    pub(crate) fn verify_search_result_with_ground_truth<
        G: GraphDataType<VectorIdType = u32, AssociatedDataType = ()>,
    >(
        params: &TestParams,
        top_k: usize,
        search_l: u32,
        storage_provider: &Arc<VirtualStorageProvider<OverlayFS>>,
    ) -> ANNResult<()> {
        // Derive the paths of the build artifacts from the index prefix.
        let pq_pivot_path = get_pq_pivot_file(&params.index_path_prefix);
        let pq_compressed_path = get_compressed_pq_file(&params.index_path_prefix);
        let index_file_path = get_disk_index_file(&params.index_path_prefix);

        // Open the PQ pivot and compressed-vector artifacts of the built index.
        let index_reader = DiskIndexReader::<G::VectorDataType>::new(
            pq_pivot_path,
            pq_compressed_path,
            storage_provider.as_ref(),
        )?;

        // Vertex provider reads graph nodes from the disk index file; caching
        // is disabled so every node access hits the (virtual) storage.
        let vertex_provider_factory = DiskVertexProviderFactory::new(
            VirtualAlignedReaderFactory::new(index_file_path, Arc::clone(storage_provider)),
            CachingStrategy::None,
        )?;

        // NOTE(review): the positional `1` and `u32::MAX as usize` follow
        // `DiskIndexSearcher::new`'s signature — confirm their meaning there.
        let search_engine = DiskIndexSearcher::<G, DiskVertexProviderFactory<G, _>>::new(
            1,
            u32::MAX as usize,
            &index_reader,
            vertex_provider_factory,
            params.metric,
            None,
        )?;

        // Load the full dataset; its rows double as the query set below.
        let data =
            read_bin::<G::VectorDataType>(&mut storage_provider.open_reader(&params.data_path)?)?;
        let dim = data.ncols();
        let distance = <G::VectorDataType>::distance(params.metric, Some(dim));

        // Here, we use elements of the dataset to search the dataset itself.
        //
        // We do this for each query, computing the expected ground truth and verifying
        // that our simple graph search matches.
        //
        // Because this dataset is small, we can expect exact equality.
        for (q, query_data) in data.row_iter().enumerate() {
            // Exact ground truth by brute-force similarity over all rows.
            let gt =
                diskann_providers::test_utils::groundtruth(data.as_view(), query_data, |a, b| {
                    distance.evaluate_similarity(a, b)
                });

            // Copy the query into 8-byte-aligned storage as the searcher expects.
            let mut query: AlignedBoxWithSlice<G::VectorDataType> =
                AlignedBoxWithSlice::<G::VectorDataType>::new(dim, 8)?;
            query.memcpy(query_data)?;

            let mut query_stats = QueryStatistics::default();

            // Output buffers filled in place by the search call.
            let mut indices = vec![0u32; top_k];
            let mut distances = vec![0f32; top_k];
            let mut associated_data = vec![(); top_k];

            // Return value intentionally ignored; results land in the buffers above.
            _ = search_engine.search_internal(
                &query,
                top_k,
                search_l,
                None, // beam_width
                &mut query_stats,
                &mut indices,
                &mut distances,
                &mut associated_data,
                &|_| true,
                false,
            );

            diskann_providers::test_utils::assert_top_k_exactly_match(
                q, &gt, &indices, &distances, top_k,
            );
        }

        Ok(())
    }
1115
1116    // Compare that the index built in test is the same as the truth index. The truth index doesn't have associated data, we are only comparing the vector and neighbor data.
1117    pub fn compare_disk_index_graphs(graph_data: &[u8], truth_graph_data: &[u8]) {
1118        let graph_header = GraphHeader::try_from(&graph_data[8..]).unwrap();
1119        let truth_graph_header = GraphHeader::try_from(&truth_graph_data[8..]).unwrap();
1120
1121        let test_node_per_block = graph_header.metadata().num_nodes_per_block;
1122        let test_max_node_length = graph_header.metadata().node_len;
1123
1124        let truth_node_per_block = truth_graph_header.metadata().num_nodes_per_block;
1125        let truth_max_node_length = truth_graph_header.metadata().node_len;
1126
1127        assert_eq!(
1128            graph_header.metadata().num_pts,
1129            truth_graph_header.metadata().num_pts
1130        );
1131
1132        assert_eq!(
1133            graph_header.metadata().dims,
1134            truth_graph_header.metadata().dims
1135        );
1136
1137        let num_pts = graph_header.metadata().num_pts as usize;
1138        let dim = graph_header.metadata().dims;
1139
1140        for idx in 0..num_pts {
1141            let test_node_id_offset = node_data_offset(
1142                idx,
1143                test_max_node_length as usize,
1144                test_node_per_block as usize,
1145                DEFAULT_DISK_SECTOR_LEN,
1146            );
1147
1148            let truth_node_id_offset = node_data_offset(
1149                idx,
1150                truth_max_node_length as usize,
1151                truth_node_per_block as usize,
1152                DEFAULT_DISK_SECTOR_LEN,
1153            );
1154
1155            // Assert that the vector data is the same between the test and truth graphs for this node.
1156            assert_eq!(
1157                &graph_data
1158                    [test_node_id_offset..test_node_id_offset + dim * std::mem::size_of::<f32>()],
1159                &truth_graph_data
1160                    [truth_node_id_offset..truth_node_id_offset + dim * std::mem::size_of::<f32>()]
1161            );
1162
1163            // Assert that the neighbor count is the same between the test and truth graphs for this node.
1164            let test_nbr_cnt_offset = test_node_id_offset + dim * std::mem::size_of::<f32>();
1165            let truth_nbr_cnt_offset = truth_node_id_offset + dim * std::mem::size_of::<f32>();
1166
1167            let test_nbr_count = u32::from_le_bytes([
1168                graph_data[test_nbr_cnt_offset],
1169                graph_data[test_nbr_cnt_offset + 1],
1170                graph_data[test_nbr_cnt_offset + 2],
1171                graph_data[test_nbr_cnt_offset + 3],
1172            ]);
1173
1174            let truth_nbr_count = u32::from_le_bytes([
1175                truth_graph_data[truth_nbr_cnt_offset],
1176                truth_graph_data[truth_nbr_cnt_offset + 1],
1177                truth_graph_data[truth_nbr_cnt_offset + 2],
1178                truth_graph_data[truth_nbr_cnt_offset + 3],
1179            ]);
1180
1181            assert_eq!(test_nbr_count, truth_nbr_count);
1182
1183            // Assert the neighbors (u32) are the same between the test and truth graphs for this node.
1184            let test_nbr_offset = test_nbr_cnt_offset + 4;
1185            let truth_nbr_offset = truth_nbr_cnt_offset + 4;
1186            assert_eq!(
1187                graph_data[test_nbr_offset..test_nbr_offset + test_nbr_count as usize * 4],
1188                truth_graph_data[truth_nbr_offset..truth_nbr_offset + truth_nbr_count as usize * 4]
1189            );
1190        }
1191    }
1192
1193    pub fn node_data_offset(
1194        node_id: usize,
1195        node_length: usize,
1196        nodes_per_block: usize,
1197        block_size: usize,
1198    ) -> usize {
1199        let block_id = node_id / nodes_per_block;
1200        let node_id_in_block = node_id % nodes_per_block;
1201        let offset = block_id * block_size + node_id_in_block * node_length;
1202        offset + block_size
1203    }
1204
1205    fn create_disk_index_builder(
1206        num_points: usize,
1207        dim: usize,
1208        num_of_pq_chunks: usize,
1209        storage_provider: &VirtualStorageProvider<OverlayFS>,
1210        build_quantization_type: QuantizationType,
1211    ) -> ANNResult<
1212        DiskIndexBuilder<'_, GraphDataF32VectorUnitData, VirtualStorageProvider<OverlayFS>>,
1213    > {
1214        let memory_budget = MemoryBudget::try_from_gb(1.0)?;
1215        let num_pq_chunks = NumPQChunks::new_with(num_of_pq_chunks, dim)?;
1216
1217        let build_parameters =
1218            DiskIndexBuildParameters::new(memory_budget, build_quantization_type, num_pq_chunks);
1219
1220        let index_configuration = IndexConfiguration::new(
1221            L2,
1222            dim,
1223            num_points,
1224            ONE,
1225            1,
1226            config::Builder::new_with(4, config::MaxDegree::default_slack(), 50, L2.into(), |b| {
1227                b.saturate_after_prune(true);
1228            })
1229            .build()?,
1230        );
1231
1232        let disk_index_writer = DiskIndexWriter::new(
1233            "data_path".to_string(),
1234            "index_path_prefix".to_string(),
1235            None,
1236            DEFAULT_DISK_SECTOR_LEN,
1237        )?;
1238
1239        DiskIndexBuilder::<GraphDataF32VectorUnitData, VirtualStorageProvider<OverlayFS>>::new(
1240            storage_provider,
1241            build_parameters,
1242            index_configuration,
1243            disk_index_writer,
1244        )
1245    }
1246}
1247
#[cfg(test)]
mod ram_estimation_tests {
    use rstest::rstest;

    use super::*;
    use crate::QuantizationType;

    /// Recomputes the expected RAM estimate independently for each
    /// quantization type and checks `estimate_build_index_ram_usage` agrees
    /// exactly.
    #[rstest]
    #[case(QuantizationType::FP)]
    #[case(QuantizationType::PQ { num_chunks: 15 })]
    #[case(QuantizationType::SQ { nbits: 1, standard_deviation: None })]
    fn test_estimate_build_index_ram_usage(#[case] build_quantization_type: QuantizationType) {
        let num_points = 1000;
        let dim = 128;
        let size_of_t = std::mem::size_of::<f32>() as u64;
        let graph_degree = 50;

        // Bytes needed to hold one vector under the given quantization scheme.
        let single_vec_size = match build_quantization_type {
            QuantizationType::FP => dim * size_of_t,
            QuantizationType::PQ { num_chunks } => num_chunks as u64,
            QuantizationType::SQ { nbits, .. } => {
                (nbits as u64 * dim).div_ceil(8) + std::mem::size_of::<f32>() as u64
            }
        };

        // Graph adjacency (with slack) plus vector storage, then the build
        // overhead buffer applied on top.
        let graph_bytes = (num_points as f64)
            * (graph_degree as f64)
            * (std::mem::size_of::<u32>() as f64)
            * GRAPH_SLACK_FACTOR;
        let expected_ram_usage =
            (graph_bytes + (num_points * single_vec_size) as f64) * OVERHEAD_FACTOR;

        let actual_ram_usage = estimate_build_index_ram_usage(
            num_points,
            dim,
            size_of_t,
            graph_degree,
            &build_quantization_type,
        );

        assert_eq!(actual_ram_usage, expected_ram_usage);
    }
}