Skip to main content

diskann_disk/build/builder/
quantizer.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5//! Disk index quantizer implementation.
6use diskann::{ANNError, ANNResult};
7use diskann_providers::storage::{StorageReadProvider, StorageWriteProvider};
8use diskann_providers::{
9    index::diskann_async::train_pq,
10    model::{
11        graph::{
12            provider::async_::{common::NoStore, inmem::WithBits},
13            traits::GraphDataType,
14        },
15        FixedChunkPQTable, IndexConfiguration, MAX_PQ_TRAINING_SET_SIZE,
16    },
17    storage::{PQStorage, SQStorage},
18    utils::{BridgeErr, PQPathNames},
19};
20use diskann_quantization::scalar::train::ScalarQuantizationParameters;
21use diskann_utils::views::MatrixView;
22use tracing::info;
23
24use crate::QuantizationType;
25
26/// Quantizer types used specifically for async disk index building.
27#[derive(Clone)]
28pub enum BuildQuantizer {
29    NoQuant(NoStore),
30    Scalar1Bit(WithBits<1>),
31    PQ(FixedChunkPQTable),
32}
33
34impl BuildQuantizer {
35    /// Train a new quantizer from scratch.
36    pub fn train<Data, StorageProvider>(
37        build_quantization_type: &QuantizationType,
38        index_path_prefix: &str,
39        index_configuration: &IndexConfiguration,
40        pq_storage: &PQStorage,
41        storage_provider: &StorageProvider,
42    ) -> ANNResult<Self>
43    where
44        Data: GraphDataType<VectorIdType = u32>,
45        StorageProvider: StorageReadProvider + StorageWriteProvider,
46    {
47        let num_points = index_configuration.max_points;
48        let p_val = MAX_PQ_TRAINING_SET_SIZE / (num_points as f64);
49        match *build_quantization_type {
50            QuantizationType::FP => Ok(Self::NoQuant(NoStore)),
51            QuantizationType::PQ { num_chunks } => {
52                let table = {
53                    //generate pq pivots.
54                    let seed = index_configuration.random_seed;
55                    let mut rnd =
56                        diskann_providers::utils::create_rnd_provider_from_optional_seed(seed)
57                            .create_rnd();
58                    let (train_data, train_size, train_dim) = pq_storage
59                        .get_random_train_data_slice::<Data::VectorDataType, _>(
60                            p_val,
61                            storage_provider,
62                            &mut rnd,
63                        )?;
64                    train_pq(
65                        MatrixView::try_from(&train_data, train_size, train_dim).bridge_err()?,
66                        num_chunks,
67                        &mut rnd,
68                        index_configuration.num_threads,
69                    )?
70                };
71                // Save at checkpoint. Note the the compressed data path and pivots path here
72                // are different than the ones used in quant vector generation.
73                let pq_paths = PQPathNames::new(index_path_prefix);
74                let pq_build_storage =
75                    PQStorage::new(&pq_paths.pivots, &pq_paths.compressed_data, None);
76                pq_build_storage.write_pivot_data(
77                    table.get_pq_table(),
78                    table.get_centroids(),
79                    table.get_chunk_offsets(),
80                    table.get_num_centers(),
81                    table.get_dim(),
82                    storage_provider,
83                )?;
84                Ok(Self::PQ(table))
85            }
86            QuantizationType::SQ {
87                nbits,
88                standard_deviation,
89            } => {
90                if nbits != 1 {
91                    return Err(ANNError::log_index_config_error(
92                        "build_quantization_type".to_string(),
93                        "SQ quantization is only supported for 1 bit".to_string(),
94                    ));
95                }
96                let rng = diskann_providers::utils::create_rnd_provider_from_optional_seed(
97                    index_configuration.random_seed,
98                );
99                let (train_data_vector, train_size, train_dim) = pq_storage
100                    .get_random_train_data_slice::<Data::VectorDataType, _>(
101                        p_val,
102                        storage_provider,
103                        &mut rng.create_rnd(),
104                    )?;
105
106                let quantizer_params = if let Some(std_dev) = standard_deviation {
107                    ScalarQuantizationParameters::new(std_dev)
108                } else {
109                    ScalarQuantizationParameters::default()
110                };
111
112                let quantizer = quantizer_params.train(
113                    MatrixView::try_from(&train_data_vector, train_size, train_dim).bridge_err()?,
114                );
115
116                info!("Now quantizer is trained and saving to file");
117                let sq_storage = SQStorage::new(index_path_prefix);
118                sq_storage.save_quantizer(&quantizer, storage_provider)?;
119
120                Ok(Self::Scalar1Bit(WithBits::<1>::new(quantizer)))
121            }
122        }
123    }
124
125    /// Load a previously trained quantizer from storage.
126    pub fn load<StorageProvider>(
127        build_quantization_type: &QuantizationType,
128        index_path_prefix: &str,
129        storage_provider: &StorageProvider,
130    ) -> ANNResult<Self>
131    where
132        StorageProvider: StorageReadProvider,
133    {
134        match build_quantization_type {
135            QuantizationType::FP => Ok(Self::NoQuant(NoStore)),
136            QuantizationType::PQ { num_chunks } => {
137                let pq_pivots_paths = PQPathNames::new(index_path_prefix);
138                let pq_build_storage = PQStorage::new(
139                    &pq_pivots_paths.pivots,
140                    &pq_pivots_paths.compressed_data,
141                    None,
142                );
143                let table = pq_build_storage.load_pq_pivots_bin::<StorageProvider>(
144                    &pq_pivots_paths.pivots,
145                    *num_chunks,
146                    storage_provider,
147                )?;
148                Ok(Self::PQ(table))
149            }
150            QuantizationType::SQ { .. } => {
151                let sq_storage = SQStorage::new(index_path_prefix);
152                let sq_quantizer = sq_storage.load_quantizer(storage_provider)?;
153                Ok(Self::Scalar1Bit(WithBits::<1>::new(sq_quantizer)))
154            }
155        }
156    }
157}