1use std::mem::{self, size_of};
6
7use diskann::ANNResult;
8use diskann_providers::storage::{StorageReadProvider, StorageWriteProvider};
9use diskann_providers::{
10 model::{
11 graph::traits::GraphDataType, IndexConfiguration, GRAPH_SLACK_FACTOR,
12 MAX_PQ_TRAINING_SET_SIZE,
13 },
14 storage::PQStorage,
15 utils::{
16 load_metadata_from_file, RayonThreadPool, SampleVectorReader, SamplingDensity,
17 READ_WRITE_BLOCK_SIZE,
18 },
19};
20use diskann_utils::io::read_bin;
21use rand::{seq::SliceRandom, Rng};
22use tracing::info;
23
24use crate::{
25 build::chunking::{
26 checkpoint::{
27 CheckpointContext, CheckpointManager, CheckpointManagerExt, Progress, WorkStage,
28 },
29 continuation::ChunkingConfig,
30 },
31 disk_index_build_parameter::BYTES_IN_GB,
32 storage::{CachedReader, CachedWriter, DiskIndexWriter},
33 utils::partition_with_ram_budget,
34 DiskIndexBuildParameters, QuantizationType,
35};
36
/// Safety margin (10%) applied on top of the raw RAM-usage estimate.
const OVERHEAD_FACTOR: f64 = 1.1f64;
39
40#[inline]
42fn estimate_build_index_ram_usage(
43 num_points: u64,
44 dim: u64,
45 datasize: u64,
46 graph_degree: u64,
47 build_quantization_type: &QuantizationType,
48) -> f64 {
49 let graph_size =
50 (num_points * graph_degree * mem::size_of::<u32>() as u64) as f64 * GRAPH_SLACK_FACTOR;
51
52 let single_vec_size = match *build_quantization_type {
53 QuantizationType::FP => dim.next_multiple_of(8u64) * datasize,
54 QuantizationType::PQ { num_chunks } => num_chunks as u64,
56 QuantizationType::SQ { nbits, .. } => {
58 (nbits as u64 * dim).div_ceil(8) + std::mem::size_of::<f32>() as u64
59 }
60 };
61
62 OVERHEAD_FACTOR * (graph_size + (single_vec_size * num_points) as f64)
63}
64
/// Shared state for building a disk-based Vamana index.
///
/// Used by both the one-shot build and the sharded ("merged") build path,
/// and carries the checkpoint machinery that lets an interrupted build
/// resume.
pub struct DiskIndexBuilderCore<'a, Data, StorageProvider>
where
    Data: GraphDataType<VectorIdType = u32>,
    StorageProvider: StorageReadProvider + StorageWriteProvider,
{
    /// Writes the on-disk index layout and intermediate build files.
    pub index_writer: DiskIndexWriter,

    /// Storage abstraction for product-quantization artifacts.
    pub pq_storage: PQStorage,

    /// Build-time parameters (memory budget, quantization type).
    pub disk_build_param: DiskIndexBuildParameters,

    /// Graph/index configuration (degree, search list size, metric, ...).
    pub index_configuration: IndexConfiguration,

    /// Configuration for chunked (resumable) builds.
    pub chunking_config: ChunkingConfig,

    /// Records per-stage progress so an interrupted build can resume.
    pub checkpoint_record_manager: Box<dyn CheckpointManager>,

    /// Backend used for all file reads and writes.
    pub storage_provider: &'a StorageProvider,

    /// Marks `Data` as a type parameter without storing a value of it.
    pub _phantom: std::marker::PhantomData<Data>,
}
88
impl<'a, Data, StorageProvider> DiskIndexBuilderCore<'a, Data, StorageProvider>
where
    Data: GraphDataType<VectorIdType = u32>,
    StorageProvider: StorageReadProvider + StorageWriteProvider,
{
    /// Writes the final disk layout and then removes intermediate build
    /// files.
    ///
    /// The layout write is checkpointed (`WriteDiskLayout` -> `End`) so a
    /// resumed build skips it when it already completed; cleanup runs in
    /// either case.
    pub(crate) fn create_disk_layout(&mut self) -> ANNResult<()> {
        self.checkpoint_record_manager.execute_stage(
            WorkStage::WriteDiskLayout,
            WorkStage::End,
            || {
                self.index_writer
                    .create_disk_layout::<Data, StorageProvider>(self.storage_provider)?;
                Ok(())
            },
            // Resume path: nothing extra to recover for this stage.
            || Ok(()),
        )?;

        self.index_writer
            .index_build_cleanup(self.storage_provider)?;

        Ok(())
    }

    /// Derives the index configuration used to build one shard.
    ///
    /// Shards are built with a reduced degree (2/3 of the pruned degree) and
    /// `max_points` taken from the shard data file's metadata; everything
    /// else is cloned from the base configuration.
    pub(crate) fn create_shard_index_config(
        &self,
        shard_base_file: &str,
    ) -> ANNResult<IndexConfiguration> {
        let base_config = &self.index_configuration;
        let storage_provider = self.storage_provider;

        let search_list_size = base_config.config.l_build().get();
        let pruned_degree = base_config.config.pruned_degree().get();

        // Lower-degree graph parameters for the per-shard in-memory builds.
        let low_degree_params = diskann::graph::config::Builder::new(
            2 * pruned_degree / 3,
            diskann::graph::config::MaxDegree::default_slack(),
            search_list_size,
            base_config.dist_metric.into(),
        )
        .build()?;

        let metadata = load_metadata_from_file(storage_provider, shard_base_file)?;

        let mut index_config = base_config.clone();
        index_config.max_points = metadata.npoints();
        index_config.config = low_degree_params;

        Ok(index_config)
    }

    /// Copies the vectors listed in `shard_ids_file` from `dataset_file`
    /// into `shard_base_file` (header: point count + dim, then raw vectors).
    ///
    /// The point count is not known up front, so a zero placeholder is
    /// written first and patched with the real count after the copy.
    pub(crate) fn retrieve_shard_data_from_ids<T>(
        &self,
        dataset_file: &str,
        shard_ids_file: &str,
        shard_base_file: &str,
    ) -> ANNResult<()>
    where
        T: Default + bytemuck::Pod,
    {
        let storage_provider = self.storage_provider;
        let shard_ids = read_bin::<u32>(&mut storage_provider.open_reader(shard_ids_file)?)?;
        let shard_size = shard_ids.nrows();
        info!("Loaded {} shard ids from {}", shard_size, shard_ids_file);
        // Approximate fraction of dataset rows belonging to this shard,
        // used to tune the sampling reader below.
        let max_id = shard_ids.as_slice().iter().max().copied().unwrap_or(0);
        let sampling_rate = shard_ids.as_slice().len() as f64 / (max_id + 1) as f64;

        let mut dataset_reader: SampleVectorReader<T, _> = SampleVectorReader::new(
            dataset_file,
            SamplingDensity::from_sample_rate(sampling_rate),
            storage_provider,
        )?;

        let (_npts, dim) = dataset_reader.get_dataset_headers();

        let mut shard_base_cached_writer = CachedWriter::<StorageProvider>::new(
            shard_base_file,
            READ_WRITE_BLOCK_SIZE,
            storage_provider.create_for_write(shard_base_file)?,
        )?;

        // Placeholder point count; overwritten once the copy completes.
        let dummy_size: u32 = 0;
        shard_base_cached_writer.write(&dummy_size.to_le_bytes())?;
        shard_base_cached_writer.write(&dim.to_le_bytes())?;

        let mut num_written: u32 = 0;
        dataset_reader.read_vectors(shard_ids.as_slice().iter().copied(), |vector_t| {
            let vector_bytes: &[u8] = bytemuck::must_cast_slice(vector_t);
            shard_base_cached_writer.write(vector_bytes)?;
            num_written += 1;
            Ok(())
        })?;

        info!(
            "Written file: {} with {} points",
            shard_base_file, num_written
        );

        // Patch the real point count over the placeholder at offset 0.
        shard_base_cached_writer.flush()?;
        shard_base_cached_writer.reset()?;
        shard_base_cached_writer.write(&num_written.to_le_bytes())?;

        Ok(())
    }

    /// Merges the per-shard in-memory Vamana indices into a single graph
    /// file at `output_vamana`.
    ///
    /// Each shard's node ids are translated back to global ids via its id
    /// map; a node appearing in several shards gets the deduplicated union
    /// of its neighbors, shuffled and truncated to `max_degree`.
    #[allow(clippy::too_many_arguments)]
    fn merge_shards(
        &self,
        merged_index_prefix: &str,
        num_parts: usize,
        max_degree: u32,
        output_vamana: String,
        rng: &mut impl Rng,
    ) -> ANNResult<()> {
        // Per-shard index file names and (local id -> global id) maps.
        let mut vamana_names = vec![String::new(); num_parts];
        let mut id_maps: Vec<Vec<u32>> = vec![Vec::new(); num_parts];
        for shard in 0..num_parts {
            vamana_names[shard] = DiskIndexWriter::get_merged_index_subshard_mem_index_file(
                merged_index_prefix,
                shard,
            );

            let id_maps_file =
                DiskIndexWriter::get_merged_index_subshard_id_map_file(merged_index_prefix, shard);
            id_maps[shard] = self.read_idmap(id_maps_file)?;
        }

        // Node count from the highest global id; total (node, shard) pairs.
        let num_nodes: u32 = *id_maps.iter().flatten().max().unwrap_or(&0) + 1;
        let num_elements: u32 = id_maps.iter().map(|idmap| idmap.len() as u32).sum();
        info!("# nodes: {}, max degree: {}", num_nodes, max_degree);

        // Inverse map: every (global node id, shard) pair, sorted by node id
        // so each node's shards are processed consecutively below.
        let mut node_shard: Vec<(u32, u32)> = Vec::with_capacity(num_elements as usize);
        for (shard, id_map) in id_maps.iter().enumerate() {
            info!("Creating inverse map -- shard #{}", shard);
            node_shard.extend(id_map.iter().map(|node_id| (*node_id, shard as u32)));
        }
        node_shard.sort_unstable_by(|left, right| {
            left.0.cmp(&right.0).then_with(|| left.1.cmp(&right.1))
        });

        info!("Finished computing node -> shards map");

        let mut vamana_readers = Vec::new();
        for name in &vamana_names {
            let reader = CachedReader::<StorageProvider>::new(
                name,
                READ_WRITE_BLOCK_SIZE,
                self.storage_provider,
            )?;
            vamana_readers.push(reader);
        }

        let mut merged_vamana_cached_writer = CachedWriter::<StorageProvider>::new(
            &output_vamana,
            READ_WRITE_BLOCK_SIZE,
            self.storage_provider.create_for_write(&output_vamana)?,
        )?;

        // Header layout: file size (u64) + width (u32) + medoid (u32) +
        // frozen-point count (u64).
        let vamana_metadata_size =
            size_of::<u64>() + size_of::<u32>() + size_of::<u32>() + size_of::<u64>();

        // Placeholder file size; patched at the end once known.
        let mut merged_index_size: u64 = vamana_metadata_size as u64;
        merged_vamana_cached_writer.write(&merged_index_size.to_le_bytes())?;

        let mut read_buf_8_bytes = [0u8; 8];

        // Consume each shard's header (file size + width). The per-shard
        // sizes are not needed; only the widths are tracked for logging.
        let mut max_input_width = 0;
        for reader in &mut vamana_readers {
            reader.read(&mut read_buf_8_bytes)?;
            let _expected_file_size: u64 = u64::from_le_bytes(read_buf_8_bytes);
            let input_width = reader.read_u32()?;
            max_input_width = input_width.max(max_input_width);
        }

        // The merged graph's width is capped at the requested max degree.
        let output_width: u32 = max_degree;
        info!(
            "Max input width: {}, output width: {}",
            max_input_width, output_width
        );

        merged_vamana_cached_writer.write(&output_width.to_le_bytes())?;

        // Read every shard's medoid + frozen-count header fields. Only the
        // last shard's (globally renamed) medoid is written to the merged
        // index as its entry point.
        for shard in 0..num_parts {
            let mut medoid: u32 = vamana_readers[shard].read_u32()?;
            vamana_readers[shard].read(&mut read_buf_8_bytes)?;
            let vamana_index_frozen: u64 = u64::from_le_bytes(read_buf_8_bytes);
            debug_assert_eq!(vamana_index_frozen, 0);

            // Translate the shard-local medoid to its global id.
            medoid = id_maps[shard][medoid as usize];

            if shard == (num_parts - 1) {
                merged_vamana_cached_writer.write(&medoid.to_le_bytes())?;
            }
        }

        // The merged index carries no frozen points.
        let vamana_index_frozen: u64 = 0;
        merged_vamana_cached_writer.write(&vamana_index_frozen.to_le_bytes())?;

        info!("Starting merge");

        // `nbr_set` deduplicates neighbors across shards; `final_nbrs`
        // accumulates the current node's merged neighbor list.
        let mut nbr_set = vec![false; num_nodes as usize];
        let mut final_nbrs: Vec<u32> = Vec::new();
        let mut cur_id = 0;
        for pair in &node_shard {
            let (node_id, shard_id) = *pair;
            // A new node id means the previous node's list is complete:
            // shuffle, truncate to `max_degree`, and write it out.
            if cur_id < node_id {
                final_nbrs.shuffle(rng);

                let nnbrs: u32 = std::cmp::min(final_nbrs.len() as u32, max_degree);
                merged_vamana_cached_writer.write(&nnbrs.to_le_bytes())?;

                let bytes = final_nbrs
                    .iter()
                    .take(nnbrs as usize)
                    .flat_map(|x| x.to_le_bytes())
                    .collect::<Vec<u8>>();
                merged_vamana_cached_writer.write(&bytes)?;

                merged_index_size += (size_of::<u32>() + nnbrs as usize * size_of::<u32>()) as u64;
                // Progress dot roughly every 500k nodes.
                if cur_id % 499999 == 1 {
                    print!(".");
                }
                cur_id = node_id;

                // Reset the dedup bitmap for the next node.
                final_nbrs.iter().for_each(|p| nbr_set[*p as usize] = false);
                final_nbrs.clear();
            }

            // Append this shard's (renamed, deduplicated) neighbors.
            let num_nbrs = vamana_readers[shard_id as usize].read_u32()?;

            if num_nbrs == 0 {
                info!(
                    "WARNING: shard #{}, node_id {} has 0 nbrs",
                    shard_id, node_id
                );
            } else {
                let mut nbrs_bytes = vec![0u8; num_nbrs as usize * mem::size_of::<u32>()];
                vamana_readers[shard_id as usize].read(&mut nbrs_bytes)?;
                let nbrs: &[u32] = bytemuck::cast_slice(&nbrs_bytes);

                for j in 0..num_nbrs {
                    let nbr = nbrs[j as usize];
                    let renamed_node = id_maps[shard_id as usize][nbr as usize];
                    if !nbr_set[renamed_node as usize] {
                        nbr_set[renamed_node as usize] = true;
                        final_nbrs.push(renamed_node);
                    }
                }
            }
        }

        // Flush the neighbor list of the final node.
        final_nbrs.shuffle(rng);

        let nnbrs: u32 = std::cmp::min(final_nbrs.len() as u32, max_degree);
        merged_vamana_cached_writer.write(&nnbrs.to_le_bytes())?;

        let bytes = final_nbrs
            .iter()
            .take(nnbrs as usize)
            .flat_map(|x| x.to_le_bytes())
            .collect::<Vec<u8>>();
        merged_vamana_cached_writer.write(&bytes)?;

        merged_index_size += (size_of::<u32>() + nnbrs as usize * size_of::<u32>()) as u64;

        nbr_set.clear();
        final_nbrs.clear();

        // Patch the real file size over the placeholder at offset 0.
        info!("Expected size: {}", merged_index_size);
        merged_vamana_cached_writer.reset()?;
        merged_vamana_cached_writer.write(&merged_index_size.to_le_bytes())?;

        info!("Finished merge");
        Ok(())
    }

    /// Loads a shard's (local id -> global id) map from `idmaps_path`.
    fn read_idmap(&self, idmaps_path: String) -> Result<Vec<u32>, diskann_utils::io::ReadBinError> {
        let data = read_bin::<u32>(&mut self.storage_provider.open_reader(&idmaps_path)?)?;
        Ok(data.into_inner().into_vec())
    }

    /// Merges all shard indices into the final in-memory index file, then
    /// deletes each shard's data, id-map, and index files.
    fn merge_shards_and_cleanup(
        &self,
        merged_index_prefix: &str,
        num_parts: usize,
        max_degree: u32,
        rng: &mut impl Rng,
    ) -> ANNResult<()> {
        self.merge_shards(
            merged_index_prefix,
            num_parts,
            max_degree,
            self.index_writer.get_mem_index_file(),
            rng,
        )?;

        for p in 0..num_parts {
            let shard_base_file =
                DiskIndexWriter::get_merged_index_subshard_data_file(merged_index_prefix, p);
            let shard_ids_file =
                DiskIndexWriter::get_merged_index_subshard_id_map_file(merged_index_prefix, p);
            let shard_index_file =
                DiskIndexWriter::get_merged_index_subshard_mem_index_file(merged_index_prefix, p);
            let shard_index_file_data =
                DiskIndexWriter::get_merged_index_subshard_mem_dataset_file(&shard_index_file);

            self.storage_provider.delete(&shard_base_file)?;
            self.storage_provider.delete(&shard_ids_file)?;
            self.storage_provider.delete(&shard_index_file)?;
            // The dataset file is only present for some builds; delete it
            // conditionally instead of failing on a missing file.
            if self.storage_provider.exists(&shard_index_file_data) {
                self.storage_provider.delete(&shard_index_file_data)?;
            }
        }

        Ok(())
    }
}
431
/// How the disk index will be constructed, chosen from the RAM budget.
pub(crate) enum IndexBuildStrategy {
    /// The whole index fits in the RAM budget and is built in one pass.
    OneShot,
    /// The data is partitioned into shards built separately and merged.
    Merged,
}
436
437pub(crate) fn determine_build_strategy<Data: GraphDataType>(
438 index_configuration: &IndexConfiguration,
439 index_build_ram_limit_in_bytes: f64,
440 build_quantization_type: &QuantizationType,
441) -> IndexBuildStrategy {
442 let estimated_index_ram_in_bytes = estimate_build_index_ram_usage(
443 index_configuration.max_points as u64,
444 index_configuration.dim as u64,
445 mem::size_of::<Data::VectorDataType>() as u64,
446 index_configuration.config.max_degree().get() as u64,
447 build_quantization_type,
448 );
449
450 info!(
451 "Estimated index RAM usage: {} GB, index_build_ram_limit={} GB",
452 estimated_index_ram_in_bytes / BYTES_IN_GB,
453 index_build_ram_limit_in_bytes / BYTES_IN_GB
454 );
455
456 if estimated_index_ram_in_bytes >= index_build_ram_limit_in_bytes {
457 info!(
458 "Insufficient memory budget for index build in one shot, index_build_ram_limit={} GB estimated_index_ram={} GB",
459 index_build_ram_limit_in_bytes / BYTES_IN_GB,
460 estimated_index_ram_in_bytes / BYTES_IN_GB,
461 );
462 IndexBuildStrategy::Merged
463 } else {
464 info!(
465 "Full index fits in RAM budget, should consume at most {} GBs, so building in one shot",
466 estimated_index_ram_in_bytes / BYTES_IN_GB
467 );
468 IndexBuildStrategy::OneShot
469 }
470}
471
/// Drives the sharded ("merged") Vamana build: partition the dataset,
/// build one index per shard, then merge the shards into a single graph.
pub(crate) struct MergedVamanaIndexWorkflow<'a> {
    /// Thread pool used during dataset partitioning.
    pool: &'a RayonThreadPool,
    /// RNG created from the index configuration's optional seed.
    rng: diskann_providers::utils::StandardRng,
    /// Path of the full input dataset file.
    dataset_file: String,
    /// Target pruned degree of the final merged graph.
    max_degree: u32,
    /// File-name prefix shared by all per-shard intermediate files.
    pub merged_index_prefix: String,
}
479
impl<'a> MergedVamanaIndexWorkflow<'a> {
    /// Captures everything the merged-build workflow needs from `builder`:
    /// an optionally seeded RNG, the dataset path, the merged-index file
    /// prefix, and the target pruned degree.
    pub(crate) fn new<Data, StorageProvider>(
        builder: &mut DiskIndexBuilderCore<'_, Data, StorageProvider>,
        pool: &'a RayonThreadPool,
    ) -> Self
    where
        Data: GraphDataType<VectorIdType = u32>,
        StorageProvider: StorageReadProvider + StorageWriteProvider,
    {
        let rng = diskann_providers::utils::create_rnd_from_optional_seed(
            builder.index_configuration.random_seed,
        );
        let dataset_file = builder.index_writer.get_dataset_file();
        let merged_index_prefix = builder.index_writer.get_merged_index_prefix();
        let max_degree = builder.index_configuration.config.pruned_degree_u32().get();

        Self {
            pool,
            rng,
            dataset_file,
            merged_index_prefix,
            max_degree,
        }
    }

    /// Partitions the dataset into shards that individually fit the RAM
    /// budget and returns the number of shards.
    ///
    /// Checkpointed in two steps: `InMemIndexBuild` -> `PartitionData`
    /// (a no-op transition) and `PartitionData` ->
    /// `BuildIndicesOnShards(0)` (the actual partitioning). On resume, the
    /// shard count is recovered by counting existing id-map files.
    pub(crate) fn partition_data<Data, StorageProvider>(
        &mut self,
        builder: &mut DiskIndexBuilderCore<'_, Data, StorageProvider>,
    ) -> ANNResult<usize>
    where
        Data: GraphDataType<VectorIdType = u32>,
        StorageProvider: StorageReadProvider + StorageWriteProvider,
    {
        // Advance the checkpoint past the in-memory build stage; no work
        // is done in this transition.
        builder.checkpoint_record_manager.execute_stage(
            WorkStage::InMemIndexBuild,
            WorkStage::PartitionData,
            || Ok(()),
            || Ok(()),
        )?;

        builder.checkpoint_record_manager.execute_stage(
            WorkStage::PartitionData,
            WorkStage::BuildIndicesOnShards(0),
            || {
                let num_points = builder.index_configuration.max_points;
                // Fraction of points sampled for PQ training.
                let sampling_rate = MAX_PQ_TRAINING_SET_SIZE / num_points as f64;

                let ram_budget_in_bytes =
                    builder.disk_build_param.build_memory_limit().in_bytes() as f64;
                partition_with_ram_budget::<Data::VectorDataType, _, _, _>(
                    &self.dataset_file,
                    builder.index_configuration.dim,
                    sampling_rate,
                    ram_budget_in_bytes,
                    // NOTE(review): presumably the number of closest centers
                    // each point is assigned to — confirm against
                    // `partition_with_ram_budget`'s signature.
                    2,
                    &self.merged_index_prefix,
                    builder.storage_provider,
                    &mut self.rng,
                    self.pool,
                    // RAM estimator for a candidate shard: shards are built
                    // at 2/3 of the target degree (see
                    // `create_shard_index_config`).
                    |num_points, dim| {
                        let datasize = std::mem::size_of::<Data::VectorDataType>() as u64;
                        let graph_degree = 2 * self.max_degree / 3;
                        estimate_build_index_ram_usage(
                            num_points,
                            dim,
                            datasize,
                            graph_degree as u64,
                            builder.disk_build_param.build_quantization(),
                        )
                    },
                )
            },
            // Resume path: recover the shard count by probing for the
            // id-map files the previous run already wrote.
            || {
                let mut p = 0;
                while builder.storage_provider.exists(
                    &DiskIndexWriter::get_merged_index_subshard_id_map_file(
                        &self.merged_index_prefix,
                        p,
                    ),
                ) {
                    p += 1;
                }
                info!("Found {} existing partitions from previous run", p);
                Ok(p)
            },
        )
    }

    /// Merges all shard indices and deletes the shard leftovers, then marks
    /// the checkpoint completed up to `WriteDiskLayout`.
    ///
    /// Runs only when the checkpoint manager reports a resumption point for
    /// `MergeIndices` (taken to mean the merge has not completed yet —
    /// confirm semantics with `CheckpointManager`).
    pub(crate) fn merge_and_cleanup<Data, StorageProvider>(
        &mut self,
        builder: &mut DiskIndexBuilderCore<'_, Data, StorageProvider>,
        num_parts: usize,
    ) -> ANNResult<()>
    where
        Data: GraphDataType<VectorIdType = u32>,
        StorageProvider: StorageReadProvider + StorageWriteProvider,
    {
        if builder
            .checkpoint_record_manager
            .get_resumption_point(WorkStage::MergeIndices)?
            .is_some()
        {
            builder.merge_shards_and_cleanup(
                &self.merged_index_prefix,
                num_parts,
                self.max_degree,
                &mut self.rng,
            )?;
            builder
                .checkpoint_record_manager
                .update(Progress::Completed, WorkStage::WriteDiskLayout)?;
        }

        Ok(())
    }

    /// Builds the checkpoint context for building shard `p`: the next stage
    /// is the merge after the last shard (`num_parts - 1`), otherwise the
    /// next shard build.
    pub(crate) fn get_shard_context<'b, Data, StorageProvider>(
        &self,
        builder: &'b DiskIndexBuilderCore<'_, Data, StorageProvider>,
        p: usize,
        num_parts: usize,
    ) -> CheckpointContext<'b>
    where
        Data: GraphDataType<VectorIdType = u32>,
        StorageProvider: StorageReadProvider + StorageWriteProvider,
    {
        let current_stage = WorkStage::BuildIndicesOnShards(p);
        let next_stage = if p == num_parts - 1 {
            WorkStage::MergeIndices
        } else {
            WorkStage::BuildIndicesOnShards(p + 1)
        };
        CheckpointContext::new(
            builder.checkpoint_record_manager.as_ref(),
            current_stage,
            next_stage,
        )
    }
}
626
627#[cfg(test)]
628pub(crate) mod disk_index_builder_tests {
629 use std::{io::Read, sync::Arc};
630
631 use diskann::{
632 graph::config,
633 utils::{IntoUsize, VectorRepr, ONE},
634 ANNResult,
635 };
636 use diskann_providers::storage::VirtualStorageProvider;
637 use diskann_providers::{
638 common::AlignedBoxWithSlice,
639 storage::{get_compressed_pq_file, get_disk_index_file, get_pq_pivot_file},
640 test_utils::graph_data_type_utils::{
641 GraphDataF32VectorU32Data, GraphDataF32VectorUnitData,
642 },
643 utils::Timer,
644 };
645 use diskann_utils::test_data_root;
646 use diskann_vector::{
647 distance::Metric::{self, L2},
648 DistanceFunction,
649 };
650 use rstest::rstest;
651 use vfs::OverlayFS;
652
653 use super::*;
654 use crate::{
655 build::builder::build::DiskIndexBuilder,
656 data_model::{CachingStrategy, GraphHeader},
657 disk_index_build_parameter::{DiskIndexBuildParameters, MemoryBudget, NumPQChunks},
658 search::provider::{
659 disk_provider::DiskIndexSearcher,
660 disk_vertex_provider_factory::DiskVertexProviderFactory,
661 },
662 storage::disk_index_reader::DiskIndexReader,
663 utils::{QueryStatistics, VirtualAlignedReaderFactory},
664 };
    /// Sector size used for the on-disk index layout in these tests.
    const DEFAULT_DISK_SECTOR_LEN: usize = 4096;
    /// Small SIFT sample dataset (256 points) used by the tests.
    pub const TEST_DATA_FILE: &str = "/sift/siftsmall_learn_256pts.fbin";
    /// Output prefix for indices built by these tests.
    const INDEX_PATH_PREFIX: &str = "/disk_index_build/sift_learn_test_disk_index_build";
    /// Prefix of the pre-built truth index (max_degree=4, l_build=50).
    const TRUTH_INDEX_PATH_PREFIX_R4_L50: &str = "/disk_index_build/truth_sift_learn_R4_L50";
670
    /// Checkpointing setup for tests that exercise resumable builds.
    pub struct CheckpointParams {
        /// Chunking configuration forwarded to the builder.
        pub chunking_config: ChunkingConfig,
        /// Checkpoint manager forwarded to the builder (cloned per build).
        pub checkpoint_record_manager: Box<dyn CheckpointManager>,
    }
675
    /// Parameters describing one disk-index build test scenario.
    pub struct TestParams {
        /// Vector dimensionality expected in the dataset file.
        pub dim: usize,
        /// Full dimensionality passed to `NumPQChunks::new_with`.
        pub full_dim: usize,
        /// Graph pruned degree (R).
        pub max_degree: u32,
        /// Number of PQ chunks for the compressed vectors.
        pub num_pq_chunks: usize,
        /// Quantization scheme used during the build.
        pub build_quantization_type: QuantizationType,
        /// Build-time search list size (L).
        pub l_build: u32,
        /// Path of the input dataset.
        pub data_path: String,
        /// Output prefix for all files of the built index.
        pub index_path_prefix: String,
        /// Optional per-vector associated data file.
        pub associated_data_path: Option<String>,
        /// RAM budget in GB; tiny values force the merged build path.
        pub index_build_ram_gb: f64,
        /// Optional checkpointing setup for resumable-build tests.
        pub checkpoint_params: Option<CheckpointParams>,
        /// Number of build threads.
        pub num_threads: usize,
        /// Distance metric used for build and search.
        pub metric: Metric,
    }
691
692 impl Default for TestParams {
693 fn default() -> Self {
694 Self {
695 dim: 128, full_dim: 128,
697 max_degree: 4, num_pq_chunks: 128,
699 build_quantization_type: QuantizationType::FP, l_build: 50,
701 data_path: TEST_DATA_FILE.to_string(),
702 index_path_prefix: INDEX_PATH_PREFIX.to_string(),
703 associated_data_path: None,
704 index_build_ram_gb: 1.0,
705 checkpoint_params: None,
706 num_threads: 1,
707 metric: L2,
708 }
709 }
710 }
711
    impl TestParams {
        /// Maps the (max_degree, l_build, ram budget) combination to the
        /// matching pre-built truth index prefix; panics for combinations
        /// that have no recorded truth index.
        fn truth_index_path_prefix(&self) -> &str {
            match (self.max_degree, self.l_build, self.index_build_ram_gb) {
                (4, 50, 1.0) => TRUTH_INDEX_PATH_PREFIX_R4_L50,
                (max_degree, l_build, index_build_ram_gb) => panic!(
                    "Truth index path not found for max_degree={}, l_build={}, index_build_ram_gb={}",
                    max_degree, l_build, index_build_ram_gb
                ),
            }
        }
        /// Path of the truth PQ-compressed vectors for `num_pq_chunks`;
        /// panics for chunk counts without a recorded truth file.
        pub fn truth_pq_compressed_path(&self) -> String {
            let prefix = match self.num_pq_chunks {
                128 => TRUTH_INDEX_PATH_PREFIX_R4_L50,
                num_pq_chunks => panic!(
                    "Truth pq compressed path not found for num_pq_chunks={}",
                    num_pq_chunks,
                ),
            };
            get_compressed_pq_file(prefix)
        }

        /// Path of the PQ-compressed vectors produced by this test run.
        pub fn pq_compressed_path(&self) -> String {
            get_compressed_pq_file(&self.index_path_prefix)
        }
    }
738
739 pub fn new_vfs() -> VirtualStorageProvider<OverlayFS> {
740 VirtualStorageProvider::new_overlay(test_data_root())
741 }
742
    /// Test harness bundling a storage provider with scenario parameters.
    pub struct IndexBuildFixture<StorageProvider: StorageReadProvider + StorageWriteProvider> {
        /// Shared storage backend the index is built on and verified from.
        pub storage_provider: Arc<StorageProvider>,
        /// Scenario parameters for this fixture.
        pub params: TestParams,
    }
747
    impl<StorageProvider: StorageReadProvider + StorageWriteProvider + 'static>
        IndexBuildFixture<StorageProvider>
    {
        /// Wraps `storage_provider` in an `Arc` and stores the parameters.
        pub fn new(storage_provider: StorageProvider, params: TestParams) -> ANNResult<Self> {
            Ok(Self {
                storage_provider: Arc::new(storage_provider),
                params,
            })
        }

        /// Runs a full disk-index build for vector type `T` using this
        /// fixture's parameters.
        ///
        /// Panics if the dataset's dimensionality does not match
        /// `params.dim`; returns any build error otherwise.
        pub fn build<T>(&self) -> ANNResult<()>
        where
            T: GraphDataType<VectorIdType = u32>,
            StorageProvider::Reader: std::marker::Send + Read,
        {
            let disk_index_build_parameters = DiskIndexBuildParameters::new(
                MemoryBudget::try_from_gb(self.params.index_build_ram_gb)?,
                self.params.build_quantization_type,
                NumPQChunks::new_with(self.params.num_pq_chunks, self.params.full_dim)?,
            );

            let config = config::Builder::new_with(
                self.params.max_degree.into_usize(),
                config::MaxDegree::default_slack(),
                self.params.l_build.into_usize(),
                self.params.metric.into(),
                |b| {
                    b.saturate_after_prune(true);
                },
            )
            .build()?;

            let metadata =
                load_metadata_from_file(self.storage_provider.as_ref(), &self.params.data_path)
                    .unwrap();

            assert_eq!(
                self.params.dim,
                metadata.ndims(),
                "Parameters dimension {} and data dimension {} are not equal",
                self.params.dim,
                metadata.ndims(),
            );

            // Fixed seed keeps builds reproducible across test runs.
            let config = IndexConfiguration::new(
                self.params.metric,
                self.params.dim,
                metadata.npoints(),
                ONE,
                self.params.num_threads,
                config,
            )
            .with_pseudo_rng_from_seed(100);

            let disk_index_writer = DiskIndexWriter::new(
                self.params.data_path.clone(),
                self.params.index_path_prefix.clone(),
                self.params.associated_data_path.clone(),
                DEFAULT_DISK_SECTOR_LEN,
            )?;

            // Checkpointing tests use the chunking-aware constructor.
            let mut disk_index = match self.params.checkpoint_params {
                Some(ref checkpoint_params) => {
                    let checkpoint_record_manager =
                        checkpoint_params.checkpoint_record_manager.clone_box();
                    let chunking_config = checkpoint_params.chunking_config.clone();
                    DiskIndexBuilder::<T, _>::new_with_chunking_config(
                        self.storage_provider.as_ref(),
                        disk_index_build_parameters,
                        config,
                        disk_index_writer,
                        chunking_config,
                        checkpoint_record_manager,
                    )
                }
                None => DiskIndexBuilder::<T, _>::new(
                    self.storage_provider.as_ref(),
                    disk_index_build_parameters,
                    config,
                    disk_index_writer,
                ),
            }?;

            let timer = Timer::new();
            disk_index.build()?;
            println!("Indexing time: {} seconds", timer.elapsed().as_secs_f64());

            Ok(())
        }

        /// Asserts the produced PQ-compressed file equals the truth file.
        pub fn compare_pq_compressed_files(&self) {
            self.compare_files(
                &self.params.pq_compressed_path(),
                &self.params.truth_pq_compressed_path(),
            );
        }

        /// Reads the built disk index's graph header and asserts its max
        /// degree matches `params.max_degree`.
        pub fn assert_index_max_degree<T: GraphDataType>(&self) -> ANNResult<()> {
            let index_file_path = get_disk_index_file(&self.params.index_path_prefix);
            let file_data = load_file_to_vec(self.storage_provider.as_ref(), &index_file_path);
            // Header is parsed starting past the leading 8 bytes.
            let graph_header = GraphHeader::try_from(&file_data[8..])?;
            let max_degree = graph_header.max_degree::<T::VectorDataType>()?;
            assert_eq!(
                max_degree, self.params.max_degree as usize,
                "Max degree mismatch: expected {}, got {}",
                self.params.max_degree, max_degree
            );

            Ok(())
        }

        /// Compares a built disk-index file against the truth file, node by
        /// node, via `compare_disk_index_graphs`.
        fn compare_disk_index_with_associated_data(
            &self,
            pivot_file_prefix_test: &str,
            pivot_file_prefix_expected: &str,
            index_file_suffix: &str,
        ) {
            let pq_pivot_path = pivot_file_prefix_test.to_string() + index_file_suffix;
            let pq_pivot_path_truth = pivot_file_prefix_expected.to_string() + index_file_suffix;
            let file1 = load_file_to_vec(self.storage_provider.as_ref(), &pq_pivot_path);
            let file2 = load_file_to_vec(self.storage_provider.as_ref(), &pq_pivot_path_truth);
            compare_disk_index_graphs(&file1, &file2)
        }

        /// Asserts the two files are byte-for-byte identical.
        pub fn compare_files(&self, file_path1: &str, file_path2: &str) {
            let file1 = load_file_to_vec(self.storage_provider.as_ref(), file_path1);
            let file2 = load_file_to_vec(self.storage_provider.as_ref(), file_path2);

            assert_eq!(file1.len(), file2.len());
            assert_eq!(file1, file2)
        }
    }
881
882 fn run_one_shot_test<F>(index_path_prefix: String, params_customizer: F)
884 where
885 F: FnOnce(TestParams) -> TestParams,
886 {
887 let l_build = 64;
888 let max_degree = 16;
889 let top_k = 10;
890 let search_l = 32;
891
892 let base_params = TestParams {
893 l_build,
894 max_degree,
895 index_path_prefix,
896 ..TestParams::default()
897 };
898
899 let params = params_customizer(base_params);
900
901 let fixture = IndexBuildFixture::new(new_vfs(), params).unwrap();
902 fixture.build::<GraphDataF32VectorUnitData>().unwrap();
903
904 verify_search_result_with_ground_truth::<GraphDataF32VectorUnitData>(
906 &fixture.params,
907 top_k,
908 search_l,
909 &fixture.storage_provider,
910 )
911 .unwrap();
912
913 fixture
914 .assert_index_max_degree::<GraphDataF32VectorUnitData>()
915 .unwrap();
916
917 let mem_index_file_path = format!("{}_mem.index.data", fixture.params.index_path_prefix);
919 assert!(!fixture.storage_provider.exists(&mem_index_file_path));
920 }
921
    /// One-shot build + exact ground-truth search for each supported metric.
    #[rstest]
    fn test_build_from_iter_one_shot_with_metric(
        #[values(Metric::L2, Metric::InnerProduct, Metric::Cosine)] metric: Metric,
    ) {
        // Distinct output prefix per metric so parameterized runs don't collide.
        let index_path_prefix = format!("{}_metric_{:?}", INDEX_PATH_PREFIX, metric);

        run_one_shot_test(index_path_prefix, |params| TestParams { metric, ..params });
    }
930
    /// One-shot build with per-vector associated data: verifies the
    /// intermediate in-memory files are removed and the resulting disk
    /// graph matches the truth index.
    #[test]
    fn test_build_from_iter_one_shot_with_associated_data() {
        let params = TestParams {
            associated_data_path: Some(
                "/sift/siftsmall_learn_256pts_u32_associated_data.fbin".to_string(),
            ),
            ..TestParams::default()
        };

        let fixture = IndexBuildFixture::new(new_vfs(), params).unwrap();

        fixture.build::<GraphDataF32VectorU32Data>().unwrap();

        // Both intermediate files must be cleaned up after the build.
        let mem_index_file_path = format!("{}_mem.index.data", fixture.params.index_path_prefix);
        let mem_index_associated_data_path = format!(
            "{}_mem.index.associated_data",
            fixture.params.index_path_prefix
        );
        assert!(!fixture.storage_provider.exists(&mem_index_file_path));
        assert!(!fixture
            .storage_provider
            .exists(&mem_index_associated_data_path));

        fixture.compare_disk_index_with_associated_data(
            &fixture.params.index_path_prefix,
            fixture.params.truth_index_path_prefix(),
            "_disk.index",
        );
    }
965
    /// Forces the merged (sharded) build path via a tiny RAM budget, then
    /// verifies exact search results and the final max degree.
    #[test]
    fn test_build_from_iter_merged_index() {
        let l_build = 64;
        let max_degree = 16;
        let top_k = 10;
        let search_l = 32;

        let index_path_prefix =
            "/disk_index_build/disk_index_sift_learn_test_disk_index_build_merged".to_string();
        let params = TestParams {
            l_build,
            max_degree,
            index_path_prefix,
            // Tiny budget so the build cannot run one-shot.
            index_build_ram_gb: 0.0001,
            ..TestParams::default()
        };

        let fixture = IndexBuildFixture::new(new_vfs(), params).unwrap();

        fixture.build::<GraphDataF32VectorUnitData>().unwrap();

        verify_search_result_with_ground_truth::<GraphDataF32VectorUnitData>(
            &fixture.params,
            top_k,
            search_l,
            &fixture.storage_provider,
        )
        .unwrap();

        fixture
            .assert_index_max_degree::<GraphDataF32VectorUnitData>()
            .unwrap();
    }
1000
    /// Invalid quantization settings must fail builder construction with a
    /// descriptive error message.
    #[rstest]
    #[case(QuantizationType::SQ { nbits: 2, standard_deviation: None }, "SQ quantization is only supported for 1 bit")]
    fn test_build_quantization_type_failure_cases(
        #[case] build_quantization_type: QuantizationType,
        #[case] error_message: &str,
    ) {
        let storage_provider = VirtualStorageProvider::new_overlay(test_data_root());
        let disk_index_builder = create_disk_index_builder(
            1000,
            128,
            128,
            &storage_provider,
            build_quantization_type,
        );

        // Construction must fail, and the error must mention the cause.
        let err = disk_index_builder.err().unwrap();
        assert!(err.to_string().contains(error_message));
    }
1019
1020 fn load_file_to_vec<StorageType: StorageReadProvider>(
1021 storage_provider: &StorageType,
1022 file_path: &str,
1023 ) -> Vec<u8> {
1024 let mut file = storage_provider.open_reader(file_path).unwrap();
1025 let mut buffer = vec![];
1026 file.read_to_end(&mut buffer).unwrap();
1027 buffer
1028 }
1029
    /// Searches the built disk index with every dataset vector as a query
    /// and asserts the top-k results exactly match brute-force ground truth
    /// computed over the same data.
    pub(crate) fn verify_search_result_with_ground_truth<
        G: GraphDataType<VectorIdType = u32, AssociatedDataType = ()>,
    >(
        params: &TestParams,
        top_k: usize,
        search_l: u32,
        storage_provider: &Arc<VirtualStorageProvider<OverlayFS>>,
    ) -> ANNResult<()> {
        let pq_pivot_path = get_pq_pivot_file(&params.index_path_prefix);
        let pq_compressed_path = get_compressed_pq_file(&params.index_path_prefix);
        let index_file_path = get_disk_index_file(&params.index_path_prefix);

        let index_reader = DiskIndexReader::<G::VectorDataType>::new(
            pq_pivot_path,
            pq_compressed_path,
            storage_provider.as_ref(),
        )?;

        let vertex_provider_factory = DiskVertexProviderFactory::new(
            VirtualAlignedReaderFactory::new(index_file_path, Arc::clone(storage_provider)),
            CachingStrategy::None,
        )?;

        let search_engine = DiskIndexSearcher::<G, DiskVertexProviderFactory<G, _>>::new(
            1,
            u32::MAX as usize,
            &index_reader,
            vertex_provider_factory,
            params.metric,
            None,
        )?;

        let data =
            read_bin::<G::VectorDataType>(&mut storage_provider.open_reader(&params.data_path)?)?;
        let dim = data.ncols();
        let distance = <G::VectorDataType>::distance(params.metric, Some(dim));

        // Use every dataset vector as a query against the full dataset.
        for (q, query_data) in data.row_iter().enumerate() {
            // Brute-force ground truth under the same distance function.
            let gt =
                diskann_providers::test_utils::groundtruth(data.as_view(), query_data, |a, b| {
                    distance.evaluate_similarity(a, b)
                });

            // The searcher requires queries in aligned memory.
            let mut query: AlignedBoxWithSlice<G::VectorDataType> =
                AlignedBoxWithSlice::<G::VectorDataType>::new(dim, 8)?;
            query.memcpy(query_data)?;

            let mut query_stats = QueryStatistics::default();

            let mut indices = vec![0u32; top_k];
            let mut distances = vec![0f32; top_k];
            let mut associated_data = vec![(); top_k];

            _ = search_engine.search_internal(
                &query,
                top_k,
                search_l,
                None,
                &mut query_stats,
                &mut indices,
                &mut distances,
                &mut associated_data,
                &|_| true,
                false,
            );

            diskann_providers::test_utils::assert_top_k_exactly_match(
                q, &gt, &indices, &distances, top_k,
            );
        }

        Ok(())
    }
1115
1116 pub fn compare_disk_index_graphs(graph_data: &[u8], truth_graph_data: &[u8]) {
1118 let graph_header = GraphHeader::try_from(&graph_data[8..]).unwrap();
1119 let truth_graph_header = GraphHeader::try_from(&truth_graph_data[8..]).unwrap();
1120
1121 let test_node_per_block = graph_header.metadata().num_nodes_per_block;
1122 let test_max_node_length = graph_header.metadata().node_len;
1123
1124 let truth_node_per_block = truth_graph_header.metadata().num_nodes_per_block;
1125 let truth_max_node_length = truth_graph_header.metadata().node_len;
1126
1127 assert_eq!(
1128 graph_header.metadata().num_pts,
1129 truth_graph_header.metadata().num_pts
1130 );
1131
1132 assert_eq!(
1133 graph_header.metadata().dims,
1134 truth_graph_header.metadata().dims
1135 );
1136
1137 let num_pts = graph_header.metadata().num_pts as usize;
1138 let dim = graph_header.metadata().dims;
1139
1140 for idx in 0..num_pts {
1141 let test_node_id_offset = node_data_offset(
1142 idx,
1143 test_max_node_length as usize,
1144 test_node_per_block as usize,
1145 DEFAULT_DISK_SECTOR_LEN,
1146 );
1147
1148 let truth_node_id_offset = node_data_offset(
1149 idx,
1150 truth_max_node_length as usize,
1151 truth_node_per_block as usize,
1152 DEFAULT_DISK_SECTOR_LEN,
1153 );
1154
1155 assert_eq!(
1157 &graph_data
1158 [test_node_id_offset..test_node_id_offset + dim * std::mem::size_of::<f32>()],
1159 &truth_graph_data
1160 [truth_node_id_offset..truth_node_id_offset + dim * std::mem::size_of::<f32>()]
1161 );
1162
1163 let test_nbr_cnt_offset = test_node_id_offset + dim * std::mem::size_of::<f32>();
1165 let truth_nbr_cnt_offset = truth_node_id_offset + dim * std::mem::size_of::<f32>();
1166
1167 let test_nbr_count = u32::from_le_bytes([
1168 graph_data[test_nbr_cnt_offset],
1169 graph_data[test_nbr_cnt_offset + 1],
1170 graph_data[test_nbr_cnt_offset + 2],
1171 graph_data[test_nbr_cnt_offset + 3],
1172 ]);
1173
1174 let truth_nbr_count = u32::from_le_bytes([
1175 truth_graph_data[truth_nbr_cnt_offset],
1176 truth_graph_data[truth_nbr_cnt_offset + 1],
1177 truth_graph_data[truth_nbr_cnt_offset + 2],
1178 truth_graph_data[truth_nbr_cnt_offset + 3],
1179 ]);
1180
1181 assert_eq!(test_nbr_count, truth_nbr_count);
1182
1183 let test_nbr_offset = test_nbr_cnt_offset + 4;
1185 let truth_nbr_offset = truth_nbr_cnt_offset + 4;
1186 assert_eq!(
1187 graph_data[test_nbr_offset..test_nbr_offset + test_nbr_count as usize * 4],
1188 truth_graph_data[truth_nbr_offset..truth_nbr_offset + truth_nbr_count as usize * 4]
1189 );
1190 }
1191 }
1192
1193 pub fn node_data_offset(
1194 node_id: usize,
1195 node_length: usize,
1196 nodes_per_block: usize,
1197 block_size: usize,
1198 ) -> usize {
1199 let block_id = node_id / nodes_per_block;
1200 let node_id_in_block = node_id % nodes_per_block;
1201 let offset = block_id * block_size + node_id_in_block * node_length;
1202 offset + block_size
1203 }
1204
1205 fn create_disk_index_builder(
1206 num_points: usize,
1207 dim: usize,
1208 num_of_pq_chunks: usize,
1209 storage_provider: &VirtualStorageProvider<OverlayFS>,
1210 build_quantization_type: QuantizationType,
1211 ) -> ANNResult<
1212 DiskIndexBuilder<'_, GraphDataF32VectorUnitData, VirtualStorageProvider<OverlayFS>>,
1213 > {
1214 let memory_budget = MemoryBudget::try_from_gb(1.0)?;
1215 let num_pq_chunks = NumPQChunks::new_with(num_of_pq_chunks, dim)?;
1216
1217 let build_parameters =
1218 DiskIndexBuildParameters::new(memory_budget, build_quantization_type, num_pq_chunks);
1219
1220 let index_configuration = IndexConfiguration::new(
1221 L2,
1222 dim,
1223 num_points,
1224 ONE,
1225 1,
1226 config::Builder::new_with(4, config::MaxDegree::default_slack(), 50, L2.into(), |b| {
1227 b.saturate_after_prune(true);
1228 })
1229 .build()?,
1230 );
1231
1232 let disk_index_writer = DiskIndexWriter::new(
1233 "data_path".to_string(),
1234 "index_path_prefix".to_string(),
1235 None,
1236 DEFAULT_DISK_SECTOR_LEN,
1237 )?;
1238
1239 DiskIndexBuilder::<GraphDataF32VectorUnitData, VirtualStorageProvider<OverlayFS>>::new(
1240 storage_provider,
1241 build_parameters,
1242 index_configuration,
1243 disk_index_writer,
1244 )
1245 }
1246}
1247
#[cfg(test)]
mod ram_estimation_tests {
    use rstest::rstest;

    use super::*;
    use crate::QuantizationType;

    /// Checks `estimate_build_index_ram_usage` against an independently
    /// mirrored expectation for each quantization flavor.
    #[rstest]
    #[case(QuantizationType::FP)]
    #[case(QuantizationType::PQ { num_chunks: 15 })]
    #[case(QuantizationType::SQ { nbits: 1, standard_deviation: None })]
    fn test_estimate_build_index_ram_usage(#[case] build_quantization_type: QuantizationType) {
        const NUM_POINTS: u64 = 1000;
        // Multiple of 8, so the FP path's `next_multiple_of(8)` padding is a
        // no-op and `DIM * elem_size` matches the estimator exactly.
        const DIM: u64 = 128;
        const GRAPH_DEGREE: u64 = 50;
        let elem_size = std::mem::size_of::<f32>() as u64;

        // Per-vector footprint, mirrored per quantization scheme.
        let bytes_per_vector = match build_quantization_type {
            QuantizationType::FP => DIM * elem_size,
            QuantizationType::PQ { num_chunks } => num_chunks as u64,
            QuantizationType::SQ { nbits, .. } => {
                (nbits as u64 * DIM).div_ceil(8) + std::mem::size_of::<f32>() as u64
            }
        };

        // Graph adjacency footprint with slack, in bytes.
        let graph_bytes = (NUM_POINTS as f64)
            * (GRAPH_DEGREE as f64)
            * (std::mem::size_of::<u32>() as f64)
            * GRAPH_SLACK_FACTOR;
        let expected = OVERHEAD_FACTOR * (graph_bytes + (NUM_POINTS * bytes_per_vector) as f64);

        assert_eq!(
            estimate_build_index_ram_usage(
                NUM_POINTS,
                DIM,
                elem_size,
                GRAPH_DEGREE,
                &build_quantization_type,
            ),
            expected
        );
    }
}