1use std::mem::{self, size_of};
6
7use diskann::ANNResult;
8use diskann_providers::storage::{StorageReadProvider, StorageWriteProvider};
9use diskann_providers::{
10 model::{
11 graph::traits::GraphDataType, IndexConfiguration, GRAPH_SLACK_FACTOR,
12 MAX_PQ_TRAINING_SET_SIZE,
13 },
14 storage::PQStorage,
15 utils::{
16 load_bin, load_metadata_from_file, RayonThreadPool, SampleVectorReader, SamplingDensity,
17 READ_WRITE_BLOCK_SIZE,
18 },
19};
20use rand::{seq::SliceRandom, Rng};
21use tracing::info;
22
23use crate::{
24 build::chunking::{
25 checkpoint::{
26 CheckpointContext, CheckpointManager, CheckpointManagerExt, Progress, WorkStage,
27 },
28 continuation::ChunkingConfig,
29 },
30 disk_index_build_parameter::BYTES_IN_GB,
31 storage::{CachedReader, CachedWriter, DiskIndexWriter},
32 utils::partition_with_ram_budget,
33 DiskIndexBuildParameters, QuantizationType,
34};
35
36const OVERHEAD_FACTOR: f64 = 1.1f64;
38
39#[inline]
41fn estimate_build_index_ram_usage(
42 num_points: u64,
43 dim: u64,
44 datasize: u64,
45 graph_degree: u64,
46 build_quantization_type: &QuantizationType,
47) -> f64 {
48 let graph_size =
49 (num_points * graph_degree * mem::size_of::<u32>() as u64) as f64 * GRAPH_SLACK_FACTOR;
50
51 let single_vec_size = match *build_quantization_type {
52 QuantizationType::FP => dim.next_multiple_of(8u64) * datasize,
53 QuantizationType::PQ { num_chunks } => num_chunks as u64,
55 QuantizationType::SQ { nbits, .. } => {
57 (nbits as u64 * dim).div_ceil(8) + std::mem::size_of::<f32>() as u64
58 }
59 };
60
61 OVERHEAD_FACTOR * (graph_size + (single_vec_size * num_points) as f64)
62}
63
/// Core state shared by the disk-index build workflows (one-shot and merged).
///
/// Bundles the writer for on-disk artifacts, build parameters, checkpointing
/// state, and a borrowed storage provider used for all reads and writes.
pub struct DiskIndexBuilderCore<'a, Data, StorageProvider>
where
    Data: GraphDataType<VectorIdType = u32>,
    StorageProvider: StorageReadProvider + StorageWriteProvider,
{
    /// Writes index artifacts (disk layout, shard files) and knows their paths.
    pub index_writer: DiskIndexWriter,

    /// Handle for product-quantization artifacts (pivots / compressed data).
    pub pq_storage: PQStorage,

    /// Build-time parameters (memory budget, quantization choice, ...).
    pub disk_build_param: DiskIndexBuildParameters,

    /// Base index configuration (dimensions, max points, graph parameters).
    pub index_configuration: IndexConfiguration,

    /// Configuration for chunked/resumable build execution.
    pub chunking_config: ChunkingConfig,

    /// Records build-stage progress so an interrupted build can resume.
    pub checkpoint_record_manager: Box<dyn CheckpointManager>,

    /// Storage backend used for every file read/write during the build.
    pub storage_provider: &'a StorageProvider,

    // Ties the `Data` type parameter to the struct without storing a value.
    pub _phantom: std::marker::PhantomData<Data>,
}
87
impl<'a, Data, StorageProvider> DiskIndexBuilderCore<'a, Data, StorageProvider>
where
    Data: GraphDataType<VectorIdType = u32>,
    StorageProvider: StorageReadProvider + StorageWriteProvider,
{
    /// Writes the final on-disk index layout and then cleans up intermediate
    /// build artifacts.
    ///
    /// The layout write runs inside the `WriteDiskLayout -> End` checkpoint
    /// stage; the resume closure is a no-op, so on resume the write is skipped
    /// if already recorded as done. Cleanup runs unconditionally afterwards.
    pub(crate) fn create_disk_layout(&mut self) -> ANNResult<()> {
        self.checkpoint_record_manager.execute_stage(
            WorkStage::WriteDiskLayout,
            WorkStage::End,
            || {
                self.index_writer
                    .create_disk_layout::<Data, StorageProvider>(self.storage_provider)?;
                Ok(())
            },
            || Ok(()),
        )?;

        self.index_writer
            .index_build_cleanup(self.storage_provider)?;

        Ok(())
    }

    /// Derives an `IndexConfiguration` for a single shard from the base
    /// configuration: the graph degree is reduced to 2/3 of the pruned degree
    /// and `max_points` is taken from the shard data file's metadata header.
    pub(crate) fn create_shard_index_config(
        &self,
        shard_base_file: &str,
    ) -> ANNResult<IndexConfiguration> {
        let base_config = &self.index_configuration;
        let storage_provider = self.storage_provider;

        let search_list_size = base_config.config.l_build().get();
        let pruned_degree = base_config.config.pruned_degree().get();

        // Shard graphs are built with a lower degree than the final merged
        // graph; the merge step later re-caps degrees at the full budget.
        let low_degree_params = diskann::graph::config::Builder::new(
            2 * pruned_degree / 3,
            diskann::graph::config::MaxDegree::default_slack(),
            search_list_size,
            base_config.dist_metric.into(),
        )
        .build()?;

        let metadata = load_metadata_from_file(storage_provider, shard_base_file)?;

        let mut index_config = base_config.clone();
        index_config.max_points = metadata.npoints;
        index_config.config = low_degree_params;

        Ok(index_config)
    }

    /// Materializes one shard's data file: loads the shard's vector ids from
    /// `shard_ids_file`, copies the matching vectors out of `dataset_file`,
    /// and writes them to `shard_base_file` in the standard
    /// `(npoints: u32, dim: u32, vectors...)` binary layout.
    ///
    /// The point count is not known up front, so a `0` placeholder is written
    /// first and patched with the real count at the end (`reset()` is assumed
    /// to rewind the cached writer to the start of the file — confirm against
    /// `CachedWriter`).
    pub(crate) fn retrieve_shard_data_from_ids<T>(
        &self,
        dataset_file: &str,
        shard_ids_file: &str,
        shard_base_file: &str,
    ) -> ANNResult<()>
    where
        T: Default + bytemuck::Pod,
    {
        let storage_provider = self.storage_provider;
        let (shard_ids, shard_size, _) = load_bin::<u32, StorageProvider::Reader>(
            &mut storage_provider.open_reader(shard_ids_file)?,
            0,
        )?;
        info!("Loaded {} shard ids from {}", shard_size, shard_ids_file);
        // Fraction of the dataset this shard covers; used to size the sampled
        // reader. Assumes ids are roughly uniform over [0, max_id].
        let max_id = shard_ids.iter().max().copied().unwrap_or(0);
        let sampling_rate = shard_ids.len() as f64 / (max_id + 1) as f64;

        let mut dataset_reader: SampleVectorReader<T, _> = SampleVectorReader::new(
            dataset_file,
            SamplingDensity::from_sample_rate(sampling_rate),
            storage_provider,
        )?;

        let (_npts, dim) = dataset_reader.get_dataset_headers();

        let mut shard_base_cached_writer = CachedWriter::<StorageProvider>::new(
            shard_base_file,
            READ_WRITE_BLOCK_SIZE,
            storage_provider.create_for_write(shard_base_file)?,
        )?;

        // Placeholder point count; overwritten with `num_written` below.
        let dummy_size: u32 = 0;
        shard_base_cached_writer.write(&dummy_size.to_le_bytes())?;
        shard_base_cached_writer.write(&dim.to_le_bytes())?;

        let mut num_written: u32 = 0;
        dataset_reader.read_vectors(shard_ids.iter().copied(), |vector_t| {
            let vector_bytes: &[u8] = bytemuck::must_cast_slice(vector_t);
            shard_base_cached_writer.write(vector_bytes)?;
            num_written += 1;
            Ok(())
        })?;

        info!(
            "Written file: {} with {} points",
            shard_base_file, num_written
        );

        // Flush data, rewind, and patch the real point count into the header.
        shard_base_cached_writer.flush()?;
        shard_base_cached_writer.reset()?;
        shard_base_cached_writer.write(&num_written.to_le_bytes())?;

        Ok(())
    }

    /// Merges the per-shard Vamana graphs into one graph file at
    /// `output_vamana`.
    ///
    /// Output layout: `(index_size: u64, width: u32, medoid: u32,
    /// frozen: u64)` header, then per node `(nnbrs: u32, nbrs: [u32; nnbrs])`.
    /// The `index_size` field is written as a placeholder and patched at the
    /// end once the true size is known.
    ///
    /// Each shard's id map translates shard-local node ids to global ids.
    /// Nodes appearing in several shards have their neighbor lists unioned
    /// (deduplicated via `nbr_set`), shuffled, and truncated to `max_degree`.
    ///
    /// NOTE(review): records are emitted only when the sorted global id
    /// advances — this assumes global ids are dense (every id in
    /// `0..num_nodes` appears in some shard); a gap would yield fewer records
    /// than `num_nodes`. Confirm the partitioner guarantees density.
    #[allow(clippy::too_many_arguments)]
    fn merge_shards(
        &self,
        merged_index_prefix: &str,
        num_parts: usize,
        max_degree: u32,
        output_vamana: String,
        rng: &mut impl Rng,
    ) -> ANNResult<()> {
        // Collect each shard's graph file name and its local->global id map.
        let mut vamana_names = vec![String::new(); num_parts];
        let mut id_maps: Vec<Vec<u32>> = vec![Vec::new(); num_parts];
        for shard in 0..num_parts {
            vamana_names[shard] = DiskIndexWriter::get_merged_index_subshard_mem_index_file(
                merged_index_prefix,
                shard,
            );

            let id_maps_file =
                DiskIndexWriter::get_merged_index_subshard_id_map_file(merged_index_prefix, shard);
            id_maps[shard] = self.read_idmap(id_maps_file)?;
        }

        let num_nodes: u32 = *id_maps.iter().flatten().max().unwrap_or(&0) + 1;
        let num_elements: u32 = id_maps.iter().map(|idmap| idmap.len() as u32).sum();
        info!("# nodes: {}, max degree: {}", num_nodes, max_degree);

        // Inverse map: (global node id, shard) pairs, sorted by node id so the
        // merge loop below sees all shard copies of a node consecutively.
        let mut node_shard: Vec<(u32, u32)> = Vec::with_capacity(num_elements as usize);
        for (shard, id_map) in id_maps.iter().enumerate() {
            info!("Creating inverse map -- shard #{}", shard);
            node_shard.extend(id_map.iter().map(|node_id| (*node_id, shard as u32)));
        }
        node_shard.sort_unstable_by(|left, right| {
            left.0.cmp(&right.0).then_with(|| left.1.cmp(&right.1))
        });

        info!("Finished computing node -> shards map");

        let mut vamana_readers = Vec::new();
        for name in &vamana_names {
            let reader = CachedReader::<StorageProvider>::new(
                name,
                READ_WRITE_BLOCK_SIZE,
                self.storage_provider,
            )?;
            vamana_readers.push(reader);
        }

        let mut merged_vamana_cached_writer = CachedWriter::<StorageProvider>::new(
            &output_vamana,
            READ_WRITE_BLOCK_SIZE,
            self.storage_provider.create_for_write(&output_vamana)?,
        )?;

        // Header = index_size (u64) + width (u32) + medoid (u32) + frozen (u64).
        let vamana_metadata_size =
            size_of::<u64>() + size_of::<u32>() + size_of::<u32>() + size_of::<u64>();

        // Placeholder for the final file size; patched after the merge.
        let mut merged_index_size: u64 = vamana_metadata_size as u64;
        merged_vamana_cached_writer.write(&merged_index_size.to_le_bytes())?;

        let mut read_buf_8_bytes = [0u8; 8];

        // Consume each shard header (file size + width); track the max width
        // for logging only — the merged width is the requested max_degree.
        let mut max_input_width = 0;
        for reader in &mut vamana_readers {
            reader.read(&mut read_buf_8_bytes)?;
            let _expected_file_size: u64 = u64::from_le_bytes(read_buf_8_bytes);
            let input_width = reader.read_u32()?;
            max_input_width = input_width.max(max_input_width);
        }

        let output_width: u32 = max_degree;
        info!(
            "Max input width: {}, output width: {}",
            max_input_width, output_width
        );

        merged_vamana_cached_writer.write(&output_width.to_le_bytes())?;

        // Consume every shard's medoid + frozen count to keep readers aligned;
        // only the last shard's (globally remapped) medoid is written out.
        for shard in 0..num_parts {
            let mut medoid: u32 = vamana_readers[shard].read_u32()?;
            vamana_readers[shard].read(&mut read_buf_8_bytes)?;
            let vamana_index_frozen: u64 = u64::from_le_bytes(read_buf_8_bytes);
            // Shard indices are expected to have no frozen points.
            debug_assert_eq!(vamana_index_frozen, 0);

            medoid = id_maps[shard][medoid as usize];

            if shard == (num_parts - 1) {
                merged_vamana_cached_writer.write(&medoid.to_le_bytes())?;
            }
        }

        // Merged index carries no frozen points.
        let vamana_index_frozen: u64 = 0;
        merged_vamana_cached_writer.write(&vamana_index_frozen.to_le_bytes())?;

        info!("Starting merge");

        // nbr_set flags which global ids are already in final_nbrs (dedupe).
        let mut nbr_set = vec![false; num_nodes as usize];
        let mut final_nbrs: Vec<u32> = Vec::new();
        let mut cur_id = 0;
        for pair in &node_shard {
            let (node_id, shard_id) = *pair;
            // The sorted order means a new node id marks the point where the
            // previous node's accumulated neighbors can be flushed.
            if cur_id < node_id {
                // Shuffle before truncating so the kept max_degree neighbors
                // are a random subset of the union.
                final_nbrs.shuffle(rng);

                let nnbrs: u32 = std::cmp::min(final_nbrs.len() as u32, max_degree);
                merged_vamana_cached_writer.write(&nnbrs.to_le_bytes())?;

                let bytes = final_nbrs
                    .iter()
                    .take(nnbrs as usize)
                    .flat_map(|x| x.to_le_bytes())
                    .collect::<Vec<u8>>();
                merged_vamana_cached_writer.write(&bytes)?;

                merged_index_size += (size_of::<u32>() + nnbrs as usize * size_of::<u32>()) as u64;
                // Coarse progress indicator (~every 500k nodes).
                if cur_id % 499999 == 1 {
                    print!(".");
                }
                cur_id = node_id;

                // Reset the dedupe flags touched by the flushed node.
                final_nbrs.iter().for_each(|p| nbr_set[*p as usize] = false);
                final_nbrs.clear();
            }

            // Read this shard's adjacency list for the current node and union
            // it (via global ids) into final_nbrs.
            let num_nbrs = vamana_readers[shard_id as usize].read_u32()?;

            if num_nbrs == 0 {
                info!(
                    "WARNING: shard #{}, node_id {} has 0 nbrs",
                    shard_id, node_id
                );
            } else {
                let mut nbrs_bytes = vec![0u8; num_nbrs as usize * mem::size_of::<u32>()];
                vamana_readers[shard_id as usize].read(&mut nbrs_bytes)?;
                let nbrs: &[u32] = bytemuck::cast_slice(&nbrs_bytes);

                for j in 0..num_nbrs {
                    let nbr = nbrs[j as usize];
                    // Translate shard-local neighbor id to the global id.
                    let renamed_node = id_maps[shard_id as usize][nbr as usize];
                    if !nbr_set[renamed_node as usize] {
                        nbr_set[renamed_node as usize] = true;
                        final_nbrs.push(renamed_node);
                    }
                }
            }
        }

        // Flush the last node, which the loop above never reaches.
        final_nbrs.shuffle(rng);

        let nnbrs: u32 = std::cmp::min(final_nbrs.len() as u32, max_degree);
        merged_vamana_cached_writer.write(&nnbrs.to_le_bytes())?;

        let bytes = final_nbrs
            .iter()
            .take(nnbrs as usize)
            .flat_map(|x| x.to_le_bytes())
            .collect::<Vec<u8>>();
        merged_vamana_cached_writer.write(&bytes)?;

        merged_index_size += (size_of::<u32>() + nnbrs as usize * size_of::<u32>()) as u64;

        nbr_set.clear();
        final_nbrs.clear();

        info!("Expected size: {}", merged_index_size);
        // Rewind and patch the true file size into the header placeholder.
        merged_vamana_cached_writer.reset()?;
        merged_vamana_cached_writer.write(&merged_index_size.to_le_bytes())?;

        info!("Finished merge");
        Ok(())
    }

    /// Loads a shard's id-map file: a flat `u32` array mapping shard-local
    /// node ids (by position) to global node ids.
    fn read_idmap(&self, idmaps_path: String) -> std::io::Result<Vec<u32>> {
        let (data, _npts, _dim) = load_bin::<u32, StorageProvider::Reader>(
            &mut self.storage_provider.open_reader(&idmaps_path)?,
            0,
        )?;
        Ok(data)
    }

    /// Merges all shard graphs into the final in-memory index file and then
    /// deletes the per-shard intermediate files (data, id map, mem index, and
    /// — if present — the mem-index dataset file).
    fn merge_shards_and_cleanup(
        &self,
        merged_index_prefix: &str,
        num_parts: usize,
        max_degree: u32,
        rng: &mut impl Rng,
    ) -> ANNResult<()> {
        self.merge_shards(
            merged_index_prefix,
            num_parts,
            max_degree,
            self.index_writer.get_mem_index_file(),
            rng,
        )?;

        for p in 0..num_parts {
            let shard_base_file =
                DiskIndexWriter::get_merged_index_subshard_data_file(merged_index_prefix, p);
            let shard_ids_file =
                DiskIndexWriter::get_merged_index_subshard_id_map_file(merged_index_prefix, p);
            let shard_index_file =
                DiskIndexWriter::get_merged_index_subshard_mem_index_file(merged_index_prefix, p);
            let shard_index_file_data =
                DiskIndexWriter::get_merged_index_subshard_mem_dataset_file(&shard_index_file);

            self.storage_provider.delete(&shard_base_file)?;
            self.storage_provider.delete(&shard_ids_file)?;
            self.storage_provider.delete(&shard_index_file)?;
            // The dataset file is optional, so only delete it if it exists.
            if self.storage_provider.exists(&shard_index_file_data) {
                self.storage_provider.delete(&shard_index_file_data)?;
            }
        }

        Ok(())
    }
}
435
/// How the disk index should be built, chosen from the RAM budget.
pub(crate) enum IndexBuildStrategy {
    /// The whole index fits within the RAM budget; build it in one pass.
    OneShot,
    /// Partition the data into shards, build per-shard indices, then merge.
    Merged,
}
440
441pub(crate) fn determine_build_strategy<Data: GraphDataType>(
442 index_configuration: &IndexConfiguration,
443 index_build_ram_limit_in_bytes: f64,
444 build_quantization_type: &QuantizationType,
445) -> IndexBuildStrategy {
446 let estimated_index_ram_in_bytes = estimate_build_index_ram_usage(
447 index_configuration.max_points as u64,
448 index_configuration.dim as u64,
449 mem::size_of::<Data::VectorDataType>() as u64,
450 index_configuration.config.max_degree().get() as u64,
451 build_quantization_type,
452 );
453
454 info!(
455 "Estimated index RAM usage: {} GB, index_build_ram_limit={} GB",
456 estimated_index_ram_in_bytes / BYTES_IN_GB,
457 index_build_ram_limit_in_bytes / BYTES_IN_GB
458 );
459
460 if estimated_index_ram_in_bytes >= index_build_ram_limit_in_bytes {
461 info!(
462 "Insufficient memory budget for index build in one shot, index_build_ram_limit={} GB estimated_index_ram={} GB",
463 index_build_ram_limit_in_bytes / BYTES_IN_GB,
464 estimated_index_ram_in_bytes / BYTES_IN_GB,
465 );
466 IndexBuildStrategy::Merged
467 } else {
468 info!(
469 "Full index fits in RAM budget, should consume at most {} GBs, so building in one shot",
470 estimated_index_ram_in_bytes / BYTES_IN_GB
471 );
472 IndexBuildStrategy::OneShot
473 }
474}
475
/// State for the sharded ("merged") build workflow: partition the dataset,
/// build a per-shard index, then merge the shard graphs into one index.
pub(crate) struct MergedVamanaIndexWorkflow<'a> {
    /// Thread pool used by the partitioning step.
    pool: &'a RayonThreadPool,
    /// RNG seeded from the index configuration (see `new`); drives
    /// partitioning and neighbor shuffling during merge.
    rng: diskann_providers::utils::StandardRng,
    /// Path of the full input dataset file.
    dataset_file: String,
    /// Final (pruned) graph degree of the merged index.
    max_degree: u32,
    /// Path prefix under which all per-shard intermediate files are written.
    pub merged_index_prefix: String,
}
483
impl<'a> MergedVamanaIndexWorkflow<'a> {
    /// Captures the builder-derived parameters (dataset file, merged-index
    /// prefix, pruned degree) and creates an RNG from the configuration's
    /// optional random seed.
    pub(crate) fn new<Data, StorageProvider>(
        builder: &mut DiskIndexBuilderCore<'_, Data, StorageProvider>,
        pool: &'a RayonThreadPool,
    ) -> Self
    where
        Data: GraphDataType<VectorIdType = u32>,
        StorageProvider: StorageReadProvider + StorageWriteProvider,
    {
        let rng = diskann_providers::utils::create_rnd_from_optional_seed(
            builder.index_configuration.random_seed,
        );
        let dataset_file = builder.index_writer.get_dataset_file();
        let merged_index_prefix = builder.index_writer.get_merged_index_prefix();
        let max_degree = builder.index_configuration.config.pruned_degree_u32().get();

        Self {
            pool,
            rng,
            dataset_file,
            merged_index_prefix,
            max_degree,
        }
    }

    /// Partitions the dataset into shards that each fit the build RAM budget
    /// and returns the number of partitions.
    ///
    /// Runs inside the `PartitionData -> BuildIndicesOnShards(0)` checkpoint
    /// stage. On resume, the second closure counts the shard id-map files left
    /// by a previous run instead of re-partitioning. The per-shard RAM
    /// estimate uses 2/3 of `max_degree`, matching the reduced degree used
    /// when building each shard index.
    pub(crate) fn partition_data<Data, StorageProvider>(
        &mut self,
        builder: &mut DiskIndexBuilderCore<'_, Data, StorageProvider>,
    ) -> ANNResult<usize>
    where
        Data: GraphDataType<VectorIdType = u32>,
        StorageProvider: StorageReadProvider + StorageWriteProvider,
    {
        // No-op stage: advances the checkpoint record past InMemIndexBuild
        // (presumably skipped in the merged workflow — confirm against
        // CheckpointManager semantics).
        builder.checkpoint_record_manager.execute_stage(
            WorkStage::InMemIndexBuild,
            WorkStage::PartitionData,
            || Ok(()),
            || Ok(()),
        )?;

        builder.checkpoint_record_manager.execute_stage(
            WorkStage::PartitionData,
            WorkStage::BuildIndicesOnShards(0),
            || {
                let num_points = builder.index_configuration.max_points;
                // Sample enough points for PQ training, capped by the dataset.
                let sampling_rate = MAX_PQ_TRAINING_SET_SIZE / num_points as f64;

                let ram_budget_in_bytes =
                    builder.disk_build_param.build_memory_limit().in_bytes() as f64;
                partition_with_ram_budget::<Data::VectorDataType, _, _, _>(
                    &self.dataset_file,
                    builder.index_configuration.dim,
                    sampling_rate,
                    ram_budget_in_bytes,
                    2,
                    &self.merged_index_prefix,
                    builder.storage_provider,
                    &mut self.rng,
                    self.pool,
                    |num_points, dim| {
                        let datasize = std::mem::size_of::<Data::VectorDataType>() as u64;
                        // Shards are built at 2/3 of the final degree.
                        let graph_degree = 2 * self.max_degree / 3;
                        estimate_build_index_ram_usage(
                            num_points,
                            dim,
                            datasize,
                            graph_degree as u64,
                            builder.disk_build_param.build_quantization(),
                        )
                    },
                )
            },
            || {
                // Resume path: count contiguous shard id-map files from a
                // previous run to recover the partition count.
                let mut p = 0;
                while builder.storage_provider.exists(
                    &DiskIndexWriter::get_merged_index_subshard_id_map_file(
                        &self.merged_index_prefix,
                        p,
                    ),
                ) {
                    p += 1;
                }
                info!("Found {} existing partitions from previous run", p);
                Ok(p)
            },
        )
    }

    /// Merges the shard indices into the final mem-index file and removes the
    /// shard intermediates, then records progress as complete through
    /// `WriteDiskLayout`.
    ///
    /// Runs only if the checkpoint manager reports a resumption point for
    /// `MergeIndices` (i.e. that stage has not yet completed — confirm
    /// `get_resumption_point` semantics).
    pub(crate) fn merge_and_cleanup<Data, StorageProvider>(
        &mut self,
        builder: &mut DiskIndexBuilderCore<'_, Data, StorageProvider>,
        num_parts: usize,
    ) -> ANNResult<()>
    where
        Data: GraphDataType<VectorIdType = u32>,
        StorageProvider: StorageReadProvider + StorageWriteProvider,
    {
        if builder
            .checkpoint_record_manager
            .get_resumption_point(WorkStage::MergeIndices)?
            .is_some()
        {
            builder.merge_shards_and_cleanup(
                &self.merged_index_prefix,
                num_parts,
                self.max_degree,
                &mut self.rng,
            )?;
            builder
                .checkpoint_record_manager
                .update(Progress::Completed, WorkStage::WriteDiskLayout)?;
        }

        Ok(())
    }

    /// Builds the checkpoint context for shard `p` of `num_parts`: the next
    /// stage is the following shard, or `MergeIndices` after the last shard.
    pub(crate) fn get_shard_context<'b, Data, StorageProvider>(
        &self,
        builder: &'b DiskIndexBuilderCore<'_, Data, StorageProvider>,
        p: usize,
        num_parts: usize,
    ) -> CheckpointContext<'b>
    where
        Data: GraphDataType<VectorIdType = u32>,
        StorageProvider: StorageReadProvider + StorageWriteProvider,
    {
        let current_stage = WorkStage::BuildIndicesOnShards(p);
        let next_stage = if p == num_parts - 1 {
            WorkStage::MergeIndices
        } else {
            WorkStage::BuildIndicesOnShards(p + 1)
        };
        CheckpointContext::new(
            builder.checkpoint_record_manager.as_ref(),
            current_stage,
            next_stage,
        )
    }
}
630
631#[cfg(test)]
632pub(crate) mod disk_index_builder_tests {
633 use std::{io::Read, sync::Arc};
634
635 use diskann::{
636 graph::config,
637 utils::{IntoUsize, VectorRepr, ONE},
638 ANNResult,
639 };
640 use diskann_providers::storage::VirtualStorageProvider;
641 use diskann_providers::{
642 common::AlignedBoxWithSlice,
643 storage::{get_compressed_pq_file, get_disk_index_file, get_pq_pivot_file},
644 test_utils::graph_data_type_utils::{
645 GraphDataF32VectorU32Data, GraphDataF32VectorUnitData,
646 },
647 utils::{file_util, BridgeErr, Timer},
648 };
649 use diskann_utils::test_data_root;
650 use diskann_vector::{
651 distance::Metric::{self, L2},
652 DistanceFunction,
653 };
654 use rstest::rstest;
655 use vfs::OverlayFS;
656
657 use super::*;
658 use crate::{
659 build::builder::build::DiskIndexBuilder,
660 data_model::{CachingStrategy, GraphHeader},
661 disk_index_build_parameter::{DiskIndexBuildParameters, MemoryBudget, NumPQChunks},
662 search::provider::{
663 disk_provider::DiskIndexSearcher,
664 disk_vertex_provider_factory::DiskVertexProviderFactory,
665 },
666 storage::disk_index_reader::DiskIndexReader,
667 utils::{QueryStatistics, VirtualAlignedReaderFactory},
668 };
669 const DEFAULT_DISK_SECTOR_LEN: usize = 4096;
670 pub const TEST_DATA_FILE: &str = "/sift/siftsmall_learn_256pts.fbin";
671 const INDEX_PATH_PREFIX: &str = "/disk_index_build/sift_learn_test_disk_index_build";
673 const TRUTH_INDEX_PATH_PREFIX_R4_L50: &str = "/disk_index_build/truth_sift_learn_R4_L50";
674
675 pub struct CheckpointParams {
676 pub chunking_config: ChunkingConfig,
677 pub checkpoint_record_manager: Box<dyn CheckpointManager>,
678 }
679
680 pub struct TestParams {
681 pub dim: usize,
682 pub full_dim: usize,
683 pub max_degree: u32,
684 pub num_pq_chunks: usize,
685 pub build_quantization_type: QuantizationType,
686 pub l_build: u32,
687 pub data_path: String,
688 pub index_path_prefix: String,
689 pub associated_data_path: Option<String>,
690 pub index_build_ram_gb: f64,
691 pub checkpoint_params: Option<CheckpointParams>,
692 pub num_threads: usize,
693 pub metric: Metric,
694 }
695
696 impl Default for TestParams {
697 fn default() -> Self {
698 Self {
699 dim: 128, full_dim: 128,
701 max_degree: 4, num_pq_chunks: 128,
703 build_quantization_type: QuantizationType::FP, l_build: 50,
705 data_path: TEST_DATA_FILE.to_string(),
706 index_path_prefix: INDEX_PATH_PREFIX.to_string(),
707 associated_data_path: None,
708 index_build_ram_gb: 1.0,
709 checkpoint_params: None,
710 num_threads: 1,
711 metric: L2,
712 }
713 }
714 }
715
716 impl TestParams {
717 fn truth_index_path_prefix(&self) -> &str {
719 match (self.max_degree, self.l_build, self.index_build_ram_gb) {
720 (4, 50, 1.0) => TRUTH_INDEX_PATH_PREFIX_R4_L50,
721 (max_degree, l_build, index_build_ram_gb) => panic!(
722 "Truth index path not found for max_degree={}, l_build={}, index_build_ram_gb={}",
723 max_degree, l_build, index_build_ram_gb
724 ),
725 }
726 }
727 pub fn truth_pq_compressed_path(&self) -> String {
728 let prefix = match self.num_pq_chunks {
729 128 => TRUTH_INDEX_PATH_PREFIX_R4_L50,
730 num_pq_chunks => panic!(
731 "Truth pq compressed path not found for num_pq_chunks={}",
732 num_pq_chunks,
733 ),
734 };
735 get_compressed_pq_file(prefix)
736 }
737
738 pub fn pq_compressed_path(&self) -> String {
739 get_compressed_pq_file(&self.index_path_prefix)
740 }
741 }
742
743 pub fn new_vfs() -> VirtualStorageProvider<OverlayFS> {
744 VirtualStorageProvider::new_overlay(test_data_root())
745 }
746
747 pub struct IndexBuildFixture<StorageProvider: StorageReadProvider + StorageWriteProvider> {
748 pub storage_provider: Arc<StorageProvider>,
749 pub params: TestParams,
750 }
751
752 impl<StorageProvider: StorageReadProvider + StorageWriteProvider + 'static>
753 IndexBuildFixture<StorageProvider>
754 {
755 pub fn new(storage_provider: StorageProvider, params: TestParams) -> ANNResult<Self> {
756 Ok(Self {
757 storage_provider: Arc::new(storage_provider),
758 params,
759 })
760 }
761
762 pub fn build<T>(&self) -> ANNResult<()>
763 where
764 T: GraphDataType<VectorIdType = u32>,
765 StorageProvider::Reader: std::marker::Send + Read,
766 {
767 let disk_index_build_parameters = DiskIndexBuildParameters::new(
769 MemoryBudget::try_from_gb(self.params.index_build_ram_gb)?,
770 self.params.build_quantization_type,
771 NumPQChunks::new_with(self.params.num_pq_chunks, self.params.full_dim)?,
772 );
773
774 let config = config::Builder::new_with(
775 self.params.max_degree.into_usize(),
776 config::MaxDegree::default_slack(),
777 self.params.l_build.into_usize(),
778 self.params.metric.into(),
779 |b| {
780 b.saturate_after_prune(true);
781 },
782 )
783 .build()?;
784
785 let metadata =
786 load_metadata_from_file(self.storage_provider.as_ref(), &self.params.data_path)
787 .unwrap();
788
789 assert_eq!(
790 self.params.dim, metadata.ndims,
791 "Parameters dimension {} and data dimension {} are not equal",
792 self.params.dim, metadata.ndims
793 );
794
795 let config = IndexConfiguration::new(
796 self.params.metric,
797 self.params.dim,
798 metadata.npoints,
799 ONE,
800 self.params.num_threads,
801 config,
802 )
803 .with_pseudo_rng_from_seed(100);
804
805 let disk_index_writer = DiskIndexWriter::new(
806 self.params.data_path.clone(),
807 self.params.index_path_prefix.clone(),
808 self.params.associated_data_path.clone(),
809 DEFAULT_DISK_SECTOR_LEN,
810 )?;
811
812 let mut disk_index = match self.params.checkpoint_params {
813 Some(ref checkpoint_params) => {
814 let checkpoint_record_manager =
815 checkpoint_params.checkpoint_record_manager.clone_box();
816 let chunking_config = checkpoint_params.chunking_config.clone();
817 DiskIndexBuilder::<T, _>::new_with_chunking_config(
818 self.storage_provider.as_ref(),
819 disk_index_build_parameters,
820 config,
821 disk_index_writer,
822 chunking_config,
823 checkpoint_record_manager,
824 )
825 }
826 None => DiskIndexBuilder::<T, _>::new(
827 self.storage_provider.as_ref(),
828 disk_index_build_parameters,
829 config,
830 disk_index_writer,
831 ),
832 }?;
833
834 let timer = Timer::new();
835 disk_index.build()?;
836 println!("Indexing time: {} seconds", timer.elapsed().as_secs_f64());
837
838 Ok(())
839 }
840
841 pub fn compare_pq_compressed_files(&self) {
842 self.compare_files(
843 &self.params.pq_compressed_path(),
844 &self.params.truth_pq_compressed_path(),
845 );
846 }
847
848 pub fn assert_index_max_degree<T: GraphDataType>(&self) -> ANNResult<()> {
849 let index_file_path = get_disk_index_file(&self.params.index_path_prefix);
850 let file_data = load_file_to_vec(self.storage_provider.as_ref(), &index_file_path);
851 let graph_header = GraphHeader::try_from(&file_data[8..])?;
852 let max_degree = graph_header.max_degree::<T::VectorDataType>()?;
853 assert_eq!(
854 max_degree, self.params.max_degree as usize,
855 "Max degree mismatch: expected {}, got {}",
856 self.params.max_degree, max_degree
857 );
858
859 Ok(())
860 }
861
862 fn compare_disk_index_with_associated_data(
863 &self,
864 pivot_file_prefix_test: &str,
865 pivot_file_prefix_expected: &str,
866 index_file_suffix: &str,
867 ) {
868 let pq_pivot_path = pivot_file_prefix_test.to_string() + index_file_suffix;
869 let pq_pivot_path_truth = pivot_file_prefix_expected.to_string() + index_file_suffix;
870 let file1 = load_file_to_vec(self.storage_provider.as_ref(), &pq_pivot_path);
871 let file2 = load_file_to_vec(self.storage_provider.as_ref(), &pq_pivot_path_truth);
872 compare_disk_index_graphs(&file1, &file2)
873 }
874
875 pub fn compare_files(&self, file_path1: &str, file_path2: &str) {
876 let file1 = load_file_to_vec(self.storage_provider.as_ref(), file_path1);
877 let file2 = load_file_to_vec(self.storage_provider.as_ref(), file_path2);
878
879 assert_eq!(file1.len(), file2.len());
880 assert_eq!(file1, file2)
881 }
882 }
883
884 fn run_one_shot_test<F>(index_path_prefix: String, params_customizer: F)
886 where
887 F: FnOnce(TestParams) -> TestParams,
888 {
889 let l_build = 64;
890 let max_degree = 16;
891 let top_k = 10;
892 let search_l = 32;
893
894 let base_params = TestParams {
895 l_build,
896 max_degree,
897 index_path_prefix,
898 ..TestParams::default()
899 };
900
901 let params = params_customizer(base_params);
902
903 let fixture = IndexBuildFixture::new(new_vfs(), params).unwrap();
904 fixture.build::<GraphDataF32VectorUnitData>().unwrap();
905
906 verify_search_result_with_ground_truth::<GraphDataF32VectorUnitData>(
908 &fixture.params,
909 top_k,
910 search_l,
911 &fixture.storage_provider,
912 )
913 .unwrap();
914
915 fixture
916 .assert_index_max_degree::<GraphDataF32VectorUnitData>()
917 .unwrap();
918
919 let mem_index_file_path = format!("{}_mem.index.data", fixture.params.index_path_prefix);
921 assert!(!fixture.storage_provider.exists(&mem_index_file_path));
922 }
923
924 #[rstest]
925 fn test_build_from_iter_one_shot_with_metric(
926 #[values(Metric::L2, Metric::InnerProduct, Metric::Cosine)] metric: Metric,
927 ) {
928 let index_path_prefix = format!("{}_metric_{:?}", INDEX_PATH_PREFIX, metric);
929
930 run_one_shot_test(index_path_prefix, |params| TestParams { metric, ..params });
931 }
932
933 #[test]
934 fn test_build_from_iter_one_shot_with_associated_data() {
935 let params = TestParams {
937 associated_data_path: Some(
938 "/sift/siftsmall_learn_256pts_u32_associated_data.fbin".to_string(),
939 ),
940 ..TestParams::default()
941 };
942
943 let fixture = IndexBuildFixture::new(new_vfs(), params).unwrap();
945
946 fixture.build::<GraphDataF32VectorU32Data>().unwrap();
948
949 let mem_index_file_path = format!("{}_mem.index.data", fixture.params.index_path_prefix);
951 let mem_index_associated_data_path = format!(
952 "{}_mem.index.associated_data",
953 fixture.params.index_path_prefix
954 );
955 assert!(!fixture.storage_provider.exists(&mem_index_file_path));
956 assert!(!fixture
957 .storage_provider
958 .exists(&mem_index_associated_data_path));
959
960 fixture.compare_disk_index_with_associated_data(
962 &fixture.params.index_path_prefix,
963 fixture.params.truth_index_path_prefix(),
964 "_disk.index",
965 );
966 }
967
968 #[test]
969 fn test_build_from_iter_merged_index() {
970 let l_build = 64;
972 let max_degree = 16;
973 let top_k = 10;
974 let search_l = 32;
975
976 let index_path_prefix =
977 "/disk_index_build/disk_index_sift_learn_test_disk_index_build_merged".to_string();
978 let params = TestParams {
979 l_build,
980 max_degree,
981 index_path_prefix,
982 index_build_ram_gb: 0.0001, ..TestParams::default()
984 };
985
986 let fixture = IndexBuildFixture::new(new_vfs(), params).unwrap();
987
988 fixture.build::<GraphDataF32VectorUnitData>().unwrap();
989
990 verify_search_result_with_ground_truth::<GraphDataF32VectorUnitData>(
991 &fixture.params,
992 top_k,
993 search_l,
994 &fixture.storage_provider,
995 )
996 .unwrap();
997
998 fixture
999 .assert_index_max_degree::<GraphDataF32VectorUnitData>()
1000 .unwrap();
1001 }
1002
1003 #[rstest]
1004 #[case(QuantizationType::SQ { nbits: 2, standard_deviation: None }, "SQ quantization is only supported for 1 bit")]
1005 fn test_build_quantization_type_failure_cases(
1006 #[case] build_quantization_type: QuantizationType,
1007 #[case] error_message: &str,
1008 ) {
1009 let storage_provider = VirtualStorageProvider::new_overlay(test_data_root());
1010 let disk_index_builder = create_disk_index_builder(
1011 1000, 128, 128, &storage_provider,
1015 build_quantization_type,
1016 );
1017
1018 let err = disk_index_builder.err().unwrap();
1019 assert!(err.to_string().contains(error_message));
1020 }
1021
1022 fn load_file_to_vec<StorageType: StorageReadProvider>(
1023 storage_provider: &StorageType,
1024 file_path: &str,
1025 ) -> Vec<u8> {
1026 let mut file = storage_provider.open_reader(file_path).unwrap();
1027 let mut buffer = vec![];
1028 file.read_to_end(&mut buffer).unwrap();
1029 buffer
1030 }
1031
1032 pub(crate) fn verify_search_result_with_ground_truth<
1039 G: GraphDataType<VectorIdType = u32, AssociatedDataType = ()>,
1040 >(
1041 params: &TestParams,
1042 top_k: usize,
1043 search_l: u32,
1044 storage_provider: &Arc<VirtualStorageProvider<OverlayFS>>,
1045 ) -> ANNResult<()> {
1046 let pq_pivot_path = get_pq_pivot_file(¶ms.index_path_prefix);
1047 let pq_compressed_path = get_compressed_pq_file(¶ms.index_path_prefix);
1048 let index_file_path = get_disk_index_file(¶ms.index_path_prefix);
1049
1050 let index_reader = DiskIndexReader::<G::VectorDataType>::new(
1051 pq_pivot_path,
1052 pq_compressed_path,
1053 storage_provider.as_ref(),
1054 )?;
1055
1056 let vertex_provider_factory = DiskVertexProviderFactory::new(
1057 VirtualAlignedReaderFactory::new(index_file_path, Arc::clone(storage_provider)),
1058 CachingStrategy::None,
1059 )?;
1060
1061 let search_engine = DiskIndexSearcher::<G, DiskVertexProviderFactory<G, _>>::new(
1062 1,
1063 u32::MAX as usize,
1064 &index_reader,
1065 vertex_provider_factory,
1066 params.metric,
1067 None,
1068 )?;
1069
1070 let (data, npoints, dim) = file_util::load_bin::<G::VectorDataType, _>(
1071 storage_provider.as_ref(),
1072 ¶ms.data_path,
1073 0,
1074 )?;
1075 let data =
1076 diskann_utils::views::Matrix::try_from(data.into(), npoints, dim).bridge_err()?;
1077 let distance = <G::VectorDataType>::distance(params.metric, Some(dim));
1078
1079 for (q, query_data) in data.row_iter().enumerate() {
1086 let gt =
1087 diskann_providers::test_utils::groundtruth(data.as_view(), query_data, |a, b| {
1088 distance.evaluate_similarity(a, b)
1089 });
1090
1091 let mut query: AlignedBoxWithSlice<G::VectorDataType> =
1092 AlignedBoxWithSlice::<G::VectorDataType>::new(dim, 8)?;
1093 query.memcpy(query_data)?;
1094
1095 let mut query_stats = QueryStatistics::default();
1096
1097 let mut indices = vec![0u32; top_k];
1098 let mut distances = vec![0f32; top_k];
1099 let mut associated_data = vec![(); top_k];
1100
1101 _ = search_engine.search_internal(
1102 &query,
1103 top_k,
1104 search_l,
1105 None, &mut query_stats,
1107 &mut indices,
1108 &mut distances,
1109 &mut associated_data,
1110 &|_| true,
1111 false,
1112 );
1113
1114 diskann_providers::test_utils::assert_top_k_exactly_match(
1115 q, >, &indices, &distances, top_k,
1116 );
1117 }
1118
1119 Ok(())
1120 }
1121
1122 pub fn compare_disk_index_graphs(graph_data: &[u8], truth_graph_data: &[u8]) {
1124 let graph_header = GraphHeader::try_from(&graph_data[8..]).unwrap();
1125 let truth_graph_header = GraphHeader::try_from(&truth_graph_data[8..]).unwrap();
1126
1127 let test_node_per_block = graph_header.metadata().num_nodes_per_block;
1128 let test_max_node_length = graph_header.metadata().node_len;
1129
1130 let truth_node_per_block = truth_graph_header.metadata().num_nodes_per_block;
1131 let truth_max_node_length = truth_graph_header.metadata().node_len;
1132
1133 assert_eq!(
1134 graph_header.metadata().num_pts,
1135 truth_graph_header.metadata().num_pts
1136 );
1137
1138 assert_eq!(
1139 graph_header.metadata().dims,
1140 truth_graph_header.metadata().dims
1141 );
1142
1143 let num_pts = graph_header.metadata().num_pts as usize;
1144 let dim = graph_header.metadata().dims;
1145
1146 for idx in 0..num_pts {
1147 let test_node_id_offset = node_data_offset(
1148 idx,
1149 test_max_node_length as usize,
1150 test_node_per_block as usize,
1151 DEFAULT_DISK_SECTOR_LEN,
1152 );
1153
1154 let truth_node_id_offset = node_data_offset(
1155 idx,
1156 truth_max_node_length as usize,
1157 truth_node_per_block as usize,
1158 DEFAULT_DISK_SECTOR_LEN,
1159 );
1160
1161 assert_eq!(
1163 &graph_data
1164 [test_node_id_offset..test_node_id_offset + dim * std::mem::size_of::<f32>()],
1165 &truth_graph_data
1166 [truth_node_id_offset..truth_node_id_offset + dim * std::mem::size_of::<f32>()]
1167 );
1168
1169 let test_nbr_cnt_offset = test_node_id_offset + dim * std::mem::size_of::<f32>();
1171 let truth_nbr_cnt_offset = truth_node_id_offset + dim * std::mem::size_of::<f32>();
1172
1173 let test_nbr_count = u32::from_le_bytes([
1174 graph_data[test_nbr_cnt_offset],
1175 graph_data[test_nbr_cnt_offset + 1],
1176 graph_data[test_nbr_cnt_offset + 2],
1177 graph_data[test_nbr_cnt_offset + 3],
1178 ]);
1179
1180 let truth_nbr_count = u32::from_le_bytes([
1181 truth_graph_data[truth_nbr_cnt_offset],
1182 truth_graph_data[truth_nbr_cnt_offset + 1],
1183 truth_graph_data[truth_nbr_cnt_offset + 2],
1184 truth_graph_data[truth_nbr_cnt_offset + 3],
1185 ]);
1186
1187 assert_eq!(test_nbr_count, truth_nbr_count);
1188
1189 let test_nbr_offset = test_nbr_cnt_offset + 4;
1191 let truth_nbr_offset = truth_nbr_cnt_offset + 4;
1192 assert_eq!(
1193 graph_data[test_nbr_offset..test_nbr_offset + test_nbr_count as usize * 4],
1194 truth_graph_data[truth_nbr_offset..truth_nbr_offset + truth_nbr_count as usize * 4]
1195 );
1196 }
1197 }
1198
1199 pub fn node_data_offset(
1200 node_id: usize,
1201 node_length: usize,
1202 nodes_per_block: usize,
1203 block_size: usize,
1204 ) -> usize {
1205 let block_id = node_id / nodes_per_block;
1206 let node_id_in_block = node_id % nodes_per_block;
1207 let offset = block_id * block_size + node_id_in_block * node_length;
1208 offset + block_size
1209 }
1210
1211 fn create_disk_index_builder(
1212 num_points: usize,
1213 dim: usize,
1214 num_of_pq_chunks: usize,
1215 storage_provider: &VirtualStorageProvider<OverlayFS>,
1216 build_quantization_type: QuantizationType,
1217 ) -> ANNResult<
1218 DiskIndexBuilder<'_, GraphDataF32VectorUnitData, VirtualStorageProvider<OverlayFS>>,
1219 > {
1220 let memory_budget = MemoryBudget::try_from_gb(1.0)?;
1221 let num_pq_chunks = NumPQChunks::new_with(num_of_pq_chunks, dim)?;
1222
1223 let build_parameters =
1224 DiskIndexBuildParameters::new(memory_budget, build_quantization_type, num_pq_chunks);
1225
1226 let index_configuration = IndexConfiguration::new(
1227 L2,
1228 dim,
1229 num_points,
1230 ONE,
1231 1,
1232 config::Builder::new_with(4, config::MaxDegree::default_slack(), 50, L2.into(), |b| {
1233 b.saturate_after_prune(true);
1234 })
1235 .build()?,
1236 );
1237
1238 let disk_index_writer = DiskIndexWriter::new(
1239 "data_path".to_string(),
1240 "index_path_prefix".to_string(),
1241 None,
1242 DEFAULT_DISK_SECTOR_LEN,
1243 )?;
1244
1245 DiskIndexBuilder::<GraphDataF32VectorUnitData, VirtualStorageProvider<OverlayFS>>::new(
1246 storage_provider,
1247 build_parameters,
1248 index_configuration,
1249 disk_index_writer,
1250 )
1251 }
1252}
1253
#[cfg(test)]
mod ram_estimation_tests {
    use rstest::rstest;

    use super::*;
    use crate::QuantizationType;

    /// Verifies `estimate_build_index_ram_usage` against an independently
    /// computed expectation for each quantization type.
    #[rstest]
    #[case(QuantizationType::FP)]
    #[case(QuantizationType::PQ { num_chunks: 15 })]
    #[case(QuantizationType::SQ { nbits: 1, standard_deviation: None })]
    fn test_estimate_build_index_ram_usage(#[case] build_quantization_type: QuantizationType) {
        let num_points: u64 = 1000;
        let dim: u64 = 128;
        let size_of_t = std::mem::size_of::<f32>() as u64;
        let graph_degree: u64 = 50;

        // Mirror the per-vector size model of the implementation exactly.
        let single_vec_size = match build_quantization_type {
            // FP vectors are padded to the next multiple of 8 dimensions,
            // as in `estimate_build_index_ram_usage`. (The previous
            // expectation used the raw `dim`, which only agreed with the
            // implementation because 128 % 8 == 0.)
            QuantizationType::FP => dim.next_multiple_of(8) * size_of_t,
            QuantizationType::PQ { num_chunks } => num_chunks as u64,
            QuantizationType::SQ { nbits, .. } => {
                (nbits as u64 * dim).div_ceil(8) + std::mem::size_of::<f32>() as u64
            }
        };

        // Expected = overhead * (graph-with-slack + vector storage), in the
        // same operation order as the implementation so f64 results match
        // bit-for-bit under `assert_eq!`.
        let graph_bytes = (num_points as f64)
            * (graph_degree as f64)
            * (std::mem::size_of::<u32>() as f64)
            * GRAPH_SLACK_FACTOR;
        let expected_ram_usage =
            OVERHEAD_FACTOR * (graph_bytes + (num_points * single_vec_size) as f64);

        let actual_ram_usage = estimate_build_index_ram_usage(
            num_points,
            dim,
            size_of_t,
            graph_degree,
            &build_quantization_type,
        );

        assert_eq!(actual_ram_usage, expected_ram_usage);
    }
}