Skip to main content

diskann_tools/utils/
build_disk_index.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use diskann::{
7    graph::config,
8    utils::{IntoUsize, ONE},
9    ANNError, ANNResult,
10};
11use diskann_disk::{
12    build::{
13        builder::build::DiskIndexBuilder,
14        chunking::{checkpoint::CheckpointManager, continuation::ChunkingConfig},
15    },
16    data_model::GraphDataType,
17    disk_index_build_parameter::{
18        DiskIndexBuildParameters, MemoryBudget, NumPQChunks, DISK_SECTOR_LEN,
19    },
20    storage::DiskIndexWriter,
21    QuantizationType,
22};
23use diskann_providers::storage::{StorageReadProvider, StorageWriteProvider};
24use diskann_providers::{
25    model::IndexConfiguration,
26    utils::{load_metadata_from_file, Timer},
27};
28use diskann_vector::distance::Metric;
29use opentelemetry::global::BoxedSpan;
30#[cfg(feature = "perf_test")]
31use opentelemetry::{
32    trace::{Span, Tracer},
33    KeyValue,
34};
35
36pub struct ChunkingParameters {
37    pub chunking_config: ChunkingConfig,
38    pub checkpoint_record_manager: Box<dyn CheckpointManager>,
39}
40
/// Pairs the stored dimension of a vector with its full-precision dimension.
///
/// * `dim` is the length of the vector when represented with the underlying datatype
/// * `full_dim` is the length of the vector when converted to a full-precision
///   slice, i.e. `[f32]`
///
/// # Notes
///
/// Both values are identical when the vectors are stored with primitive data
/// types such as `half::f16` or `f32`. For quantized vectors that stand in for
/// full-precision vectors, such as [`common::MinMaxElement`], the two values
/// might differ.
#[derive(Clone, Copy, PartialEq, Debug)]
pub struct DimensionValues {
    dim: usize,
    full_dim: usize,
}

impl DimensionValues {
    /// Creates a new pair from the stored dimension and the full-precision dimension.
    pub fn new(dim: usize, full_dim: usize) -> Self {
        DimensionValues { dim, full_dim }
    }

    /// Returns the vector length in the underlying datatype.
    pub fn dim(&self) -> usize {
        self.dim
    }

    /// Returns the vector length after conversion to a full-precision slice.
    pub fn full_dim(&self) -> usize {
        self.full_dim
    }
}
70
71pub struct BuildDiskIndexParameters<'a> {
72    pub metric: Metric,
73    pub data_path: &'a str,
74    pub r: u32,
75    pub l: u32,
76    pub index_path_prefix: &'a str,
77    pub num_threads: usize,
78    pub num_of_pq_chunks: usize,
79    pub index_build_ram_limit_gb: f64,
80    pub build_quantization_type: QuantizationType,
81    pub chunking_parameters: Option<ChunkingParameters>,
82    pub dim_values: DimensionValues,
83}
84
85/// The main function to build a disk index
86pub fn build_disk_index<Data, StorageProviderType>(
87    storage_provider: &StorageProviderType,
88    parameters: BuildDiskIndexParameters,
89) -> ANNResult<()>
90where
91    Data: GraphDataType<VectorIdType = u32>,
92    StorageProviderType: StorageReadProvider + StorageWriteProvider + 'static,
93    <StorageProviderType as StorageReadProvider>::Reader: std::marker::Send,
94{
95    let build_parameters = DiskIndexBuildParameters::new(
96        MemoryBudget::try_from_gb(parameters.index_build_ram_limit_gb)?,
97        parameters.build_quantization_type,
98        NumPQChunks::new_with(
99            parameters.num_of_pq_chunks,
100            parameters.dim_values.full_dim(),
101        )?,
102    );
103
104    let config = config::Builder::new_with(
105        parameters.r.into_usize(),
106        config::MaxDegree::default_slack(),
107        parameters.l.into_usize(),
108        parameters.metric.into(),
109        |b| {
110            b.saturate_after_prune(true);
111        },
112    )
113    .build()?;
114
115    let metadata = load_metadata_from_file(storage_provider, parameters.data_path)?;
116
117    if metadata.ndims() != parameters.dim_values.dim() {
118        return Err(ANNError::log_index_config_error(
119            format!("{:?}", parameters.dim_values),
120            format!("dim_values must match with data_dim {}", metadata.ndims()),
121        ));
122    }
123
124    let index_configuration = IndexConfiguration::new(
125        parameters.metric,
126        metadata.ndims(),
127        metadata.npoints(),
128        ONE,
129        parameters.num_threads,
130        config,
131    )
132    .with_pseudo_rng();
133
134    let disk_index_writer = DiskIndexWriter::new(
135        parameters.data_path.to_string(),
136        parameters.index_path_prefix.to_string(),
137        Option::None,
138        DISK_SECTOR_LEN,
139    )?;
140
141    let mut disk_index = match parameters.chunking_parameters {
142        Some(chunking_parameters) => {
143            let chunking_config = chunking_parameters.chunking_config;
144            let checkpoint_record_manager = chunking_parameters.checkpoint_record_manager;
145            DiskIndexBuilder::<Data, StorageProviderType>::new_with_chunking_config(
146                storage_provider,
147                build_parameters,
148                index_configuration,
149                disk_index_writer,
150                chunking_config,
151                checkpoint_record_manager,
152            )
153        }
154        None => DiskIndexBuilder::<Data, StorageProviderType>::new(
155            storage_provider,
156            build_parameters,
157            index_configuration,
158            disk_index_writer,
159        ),
160    }?;
161
162    let mut _span: BoxedSpan;
163    #[cfg(feature = "perf_test")]
164    {
165        let tracer = opentelemetry::global::tracer("");
166
167        // Start a span for the search iteration.
168        _span = tracer.start("index-build statistics".to_string());
169    }
170
171    let timer = Timer::new();
172    disk_index.build()?;
173
174    let diff = timer.elapsed();
175    println!("Indexing time: {} seconds", diff.as_secs_f64());
176
177    #[cfg(feature = "perf_test")]
178    {
179        _span.set_attribute(KeyValue::new("total_time", diff.as_secs_f64()));
180        _span.set_attribute(KeyValue::new("total_comparisons", 0i64));
181        _span.set_attribute(KeyValue::new("search_hops", 0i64));
182        _span.end();
183    }
184
185    Ok(())
186}
187
#[cfg(test)]
mod tests {
    use diskann::ANNErrorKind;
    use diskann_providers::storage::VirtualStorageProvider;
    use vfs::MemoryFS;

    use super::*;
    use crate::utils::GraphDataInt8Vector;

    /// Runs `build_disk_index` against an empty in-memory storage provider,
    /// varying only the number of PQ chunks. The data file does not exist, so
    /// the call is expected to fail in every case.
    fn run_build(num_of_pq_chunks: usize) -> ANNResult<()> {
        let storage_provider = VirtualStorageProvider::new_memory();
        let parameters = BuildDiskIndexParameters {
            metric: Metric::L2,
            data_path: "test_data_path",
            r: 10,
            l: 20,
            index_path_prefix: "test_index_path_prefix",
            num_threads: 4,
            num_of_pq_chunks,
            index_build_ram_limit_gb: 1.0,
            build_quantization_type: QuantizationType::FP,
            chunking_parameters: None,
            dim_values: DimensionValues::new(128, 128),
        };

        build_disk_index::<GraphDataInt8Vector, VirtualStorageProvider<MemoryFS>>(
            &storage_provider,
            parameters,
        )
    }

    /// A valid chunk count must not be rejected as a configuration error; the
    /// failure has to come from a later step.
    #[test]
    fn test_build_disk_index_with_num_of_pq_chunks() {
        let outcome = run_build(8);
        assert!(outcome.is_err());

        let kind = outcome.unwrap_err().kind();
        assert_ne!(kind, ANNErrorKind::IndexConfigError);
    }

    /// A chunk count of zero is rejected up front as a configuration error.
    #[test]
    fn test_build_disk_index_with_zero_num_of_pq_chunks() {
        let outcome = run_build(0);
        assert!(outcome.is_err());

        let kind = outcome.unwrap_err().kind();
        assert_eq!(kind, ANNErrorKind::IndexConfigError);
    }
}
246}