Skip to main content

diskann_tools/utils/
build_disk_index.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use diskann::{
7    graph::config,
8    utils::{IntoUsize, ONE},
9    ANNError, ANNResult,
10};
11use diskann_disk::{
12    build::{
13        builder::build::DiskIndexBuilder,
14        chunking::{checkpoint::CheckpointManager, continuation::ChunkingConfig},
15    },
16    disk_index_build_parameter::{
17        DiskIndexBuildParameters, MemoryBudget, NumPQChunks, DISK_SECTOR_LEN,
18    },
19    storage::DiskIndexWriter,
20    QuantizationType,
21};
22use diskann_providers::storage::{StorageReadProvider, StorageWriteProvider};
23use diskann_providers::{
24    model::{graph::traits::GraphDataType, IndexConfiguration},
25    utils::{load_metadata_from_file, Timer},
26};
27use diskann_vector::distance::Metric;
28use opentelemetry::global::BoxedSpan;
29#[cfg(feature = "perf_test")]
30use opentelemetry::{
31    trace::{Span, Tracer},
32    KeyValue,
33};
34
35pub struct ChunkingParameters {
36    pub chunking_config: ChunkingConfig,
37    pub checkpoint_record_manager: Box<dyn CheckpointManager>,
38}
39
/// A simple struct to contain the underlying dimension of the data and
/// its full-precision vector dimension.
///
/// * `dim` is the length of the vector when represented with the underlying datatype
/// * `full_dim` is the length of the vector when converted to a full-precision slice, i.e. `[f32]`
///
/// # Notes
///
/// These values are the same when using primitive data types to represent the vectors
/// such as `half::f16` or `f32`, however, for quantized vectors used in place of
/// full-precision vectors such as [`common::MinMaxElement`] these might be different.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct DimensionValues {
    dim: usize,
    full_dim: usize,
}

impl DimensionValues {
    /// Creates a new value from the underlying dimension `dim` and the full-precision
    /// dimension `full_dim`. No relationship between the two is enforced.
    #[must_use]
    pub const fn new(dim: usize, full_dim: usize) -> Self {
        Self { dim, full_dim }
    }

    /// Returns the vector length in the underlying datatype.
    #[must_use]
    pub const fn dim(&self) -> usize {
        self.dim
    }

    /// Returns the vector length once converted to a full-precision (`f32`) slice.
    #[must_use]
    pub const fn full_dim(&self) -> usize {
        self.full_dim
    }
}
69
70pub struct BuildDiskIndexParameters<'a> {
71    pub metric: Metric,
72    pub data_path: &'a str,
73    pub r: u32,
74    pub l: u32,
75    pub index_path_prefix: &'a str,
76    pub num_threads: usize,
77    pub num_of_pq_chunks: usize,
78    pub index_build_ram_limit_gb: f64,
79    pub build_quantization_type: QuantizationType,
80    pub chunking_parameters: Option<ChunkingParameters>,
81    pub dim_values: DimensionValues,
82}
83
84/// The main function to build a disk index
85pub fn build_disk_index<Data, StorageProviderType>(
86    storage_provider: &StorageProviderType,
87    parameters: BuildDiskIndexParameters,
88) -> ANNResult<()>
89where
90    Data: GraphDataType<VectorIdType = u32>,
91    StorageProviderType: StorageReadProvider + StorageWriteProvider + 'static,
92    <StorageProviderType as StorageReadProvider>::Reader: std::marker::Send,
93{
94    let build_parameters = DiskIndexBuildParameters::new(
95        MemoryBudget::try_from_gb(parameters.index_build_ram_limit_gb)?,
96        parameters.build_quantization_type,
97        NumPQChunks::new_with(
98            parameters.num_of_pq_chunks,
99            parameters.dim_values.full_dim(),
100        )?,
101    );
102
103    let config = config::Builder::new_with(
104        parameters.r.into_usize(),
105        config::MaxDegree::default_slack(),
106        parameters.l.into_usize(),
107        parameters.metric.into(),
108        |b| {
109            b.saturate_after_prune(true);
110        },
111    )
112    .build()?;
113
114    let metadata = load_metadata_from_file(storage_provider, parameters.data_path)?;
115
116    if metadata.ndims() != parameters.dim_values.dim() {
117        return Err(ANNError::log_index_config_error(
118            format!("{:?}", parameters.dim_values),
119            format!("dim_values must match with data_dim {}", metadata.ndims()),
120        ));
121    }
122
123    let index_configuration = IndexConfiguration::new(
124        parameters.metric,
125        metadata.ndims(),
126        metadata.npoints(),
127        ONE,
128        parameters.num_threads,
129        config,
130    )
131    .with_pseudo_rng();
132
133    let disk_index_writer = DiskIndexWriter::new(
134        parameters.data_path.to_string(),
135        parameters.index_path_prefix.to_string(),
136        Option::None,
137        DISK_SECTOR_LEN,
138    )?;
139
140    let mut disk_index = match parameters.chunking_parameters {
141        Some(chunking_parameters) => {
142            let chunking_config = chunking_parameters.chunking_config;
143            let checkpoint_record_manager = chunking_parameters.checkpoint_record_manager;
144            DiskIndexBuilder::<Data, StorageProviderType>::new_with_chunking_config(
145                storage_provider,
146                build_parameters,
147                index_configuration,
148                disk_index_writer,
149                chunking_config,
150                checkpoint_record_manager,
151            )
152        }
153        None => DiskIndexBuilder::<Data, StorageProviderType>::new(
154            storage_provider,
155            build_parameters,
156            index_configuration,
157            disk_index_writer,
158        ),
159    }?;
160
161    let mut _span: BoxedSpan;
162    #[cfg(feature = "perf_test")]
163    {
164        let tracer = opentelemetry::global::tracer("");
165
166        // Start a span for the search iteration.
167        _span = tracer.start("index-build statistics".to_string());
168    }
169
170    let timer = Timer::new();
171    disk_index.build()?;
172
173    let diff = timer.elapsed();
174    println!("Indexing time: {} seconds", diff.as_secs_f64());
175
176    #[cfg(feature = "perf_test")]
177    {
178        _span.set_attribute(KeyValue::new("total_time", diff.as_secs_f64()));
179        _span.set_attribute(KeyValue::new("total_comparisons", 0i64));
180        _span.set_attribute(KeyValue::new("search_hops", 0i64));
181        _span.end();
182    }
183
184    Ok(())
185}
186
#[cfg(test)]
mod tests {
    use diskann::ANNErrorKind;
    use diskann_providers::storage::VirtualStorageProvider;
    use vfs::MemoryFS;

    use super::*;
    use crate::utils::GraphDataInt8Vector;

    /// Builds the parameter set shared by every test; only the PQ chunk count varies.
    ///
    /// Nothing is written to the in-memory provider beforehand, so `data_path` is
    /// presumably absent when the build runs.
    fn parameters_with_pq_chunks(num_of_pq_chunks: usize) -> BuildDiskIndexParameters<'static> {
        BuildDiskIndexParameters {
            metric: Metric::L2,
            data_path: "test_data_path",
            r: 10,
            l: 20,
            index_path_prefix: "test_index_path_prefix",
            num_threads: 4,
            num_of_pq_chunks,
            index_build_ram_limit_gb: 1.0,
            build_quantization_type: QuantizationType::FP,
            chunking_parameters: None,
            dim_values: DimensionValues::new(128, 128),
        }
    }

    /// Runs `build_disk_index` against a freshly created in-memory storage provider.
    fn build_in_fresh_memory_fs(parameters: BuildDiskIndexParameters<'_>) -> ANNResult<()> {
        let provider = VirtualStorageProvider::new_memory();
        build_disk_index::<GraphDataInt8Vector, VirtualStorageProvider<MemoryFS>>(
            &provider, parameters,
        )
    }

    #[test]
    fn test_build_disk_index_with_num_of_pq_chunks() {
        // A valid chunk count passes parameter validation, so the expected failure must
        // come from a later step rather than from the index configuration.
        let err = build_in_fresh_memory_fs(parameters_with_pq_chunks(8)).unwrap_err();
        assert_ne!(err.kind(), ANNErrorKind::IndexConfigError);
    }

    #[test]
    fn test_build_disk_index_with_zero_num_of_pq_chunks() {
        // Zero PQ chunks is rejected as a configuration error.
        let err = build_in_fresh_memory_fs(parameters_with_pq_chunks(0)).unwrap_err();
        assert_eq!(err.kind(), ANNErrorKind::IndexConfigError);
    }
}