diskann_disk/build/builder/
quantizer.rs1use diskann::{ANNError, ANNResult};
7use diskann_providers::storage::{StorageReadProvider, StorageWriteProvider};
8use diskann_providers::{
9 index::diskann_async::train_pq,
10 model::{
11 graph::{
12 provider::async_::{common::NoStore, inmem::WithBits},
13 traits::GraphDataType,
14 },
15 FixedChunkPQTable, IndexConfiguration, MAX_PQ_TRAINING_SET_SIZE,
16 },
17 storage::{PQStorage, SQStorage},
18 utils::{BridgeErr, PQPathNames},
19};
20use diskann_quantization::scalar::train::ScalarQuantizationParameters;
21use diskann_utils::views::MatrixView;
22use tracing::info;
23
24use crate::QuantizationType;
25
26#[derive(Clone)]
28pub enum BuildQuantizer {
29 NoQuant(NoStore),
30 Scalar1Bit(WithBits<1>),
31 PQ(FixedChunkPQTable),
32}
33
34impl BuildQuantizer {
35 pub fn train<Data, StorageProvider>(
37 build_quantization_type: &QuantizationType,
38 index_path_prefix: &str,
39 index_configuration: &IndexConfiguration,
40 pq_storage: &PQStorage,
41 storage_provider: &StorageProvider,
42 ) -> ANNResult<Self>
43 where
44 Data: GraphDataType<VectorIdType = u32>,
45 StorageProvider: StorageReadProvider + StorageWriteProvider,
46 {
47 let num_points = index_configuration.max_points;
48 let p_val = MAX_PQ_TRAINING_SET_SIZE / (num_points as f64);
49 match *build_quantization_type {
50 QuantizationType::FP => Ok(Self::NoQuant(NoStore)),
51 QuantizationType::PQ { num_chunks } => {
52 let table = {
53 let seed = index_configuration.random_seed;
55 let mut rnd =
56 diskann_providers::utils::create_rnd_provider_from_optional_seed(seed)
57 .create_rnd();
58 let (train_data, train_size, train_dim) = pq_storage
59 .get_random_train_data_slice::<Data::VectorDataType, _>(
60 p_val,
61 storage_provider,
62 &mut rnd,
63 )?;
64 train_pq(
65 MatrixView::try_from(&train_data, train_size, train_dim).bridge_err()?,
66 num_chunks,
67 &mut rnd,
68 index_configuration.num_threads,
69 )?
70 };
71 let pq_paths = PQPathNames::new(index_path_prefix);
74 let pq_build_storage =
75 PQStorage::new(&pq_paths.pivots, &pq_paths.compressed_data, None);
76 pq_build_storage.write_pivot_data(
77 table.get_pq_table(),
78 table.get_centroids(),
79 table.get_chunk_offsets(),
80 table.get_num_centers(),
81 table.get_dim(),
82 storage_provider,
83 )?;
84 Ok(Self::PQ(table))
85 }
86 QuantizationType::SQ {
87 nbits,
88 standard_deviation,
89 } => {
90 if nbits != 1 {
91 return Err(ANNError::log_index_config_error(
92 "build_quantization_type".to_string(),
93 "SQ quantization is only supported for 1 bit".to_string(),
94 ));
95 }
96 let rng = diskann_providers::utils::create_rnd_provider_from_optional_seed(
97 index_configuration.random_seed,
98 );
99 let (train_data_vector, train_size, train_dim) = pq_storage
100 .get_random_train_data_slice::<Data::VectorDataType, _>(
101 p_val,
102 storage_provider,
103 &mut rng.create_rnd(),
104 )?;
105
106 let quantizer_params = if let Some(std_dev) = standard_deviation {
107 ScalarQuantizationParameters::new(std_dev)
108 } else {
109 ScalarQuantizationParameters::default()
110 };
111
112 let quantizer = quantizer_params.train(
113 MatrixView::try_from(&train_data_vector, train_size, train_dim).bridge_err()?,
114 );
115
116 info!("Now quantizer is trained and saving to file");
117 let sq_storage = SQStorage::new(index_path_prefix);
118 sq_storage.save_quantizer(&quantizer, storage_provider)?;
119
120 Ok(Self::Scalar1Bit(WithBits::<1>::new(quantizer)))
121 }
122 }
123 }
124
125 pub fn load<StorageProvider>(
127 build_quantization_type: &QuantizationType,
128 index_path_prefix: &str,
129 storage_provider: &StorageProvider,
130 ) -> ANNResult<Self>
131 where
132 StorageProvider: StorageReadProvider,
133 {
134 match build_quantization_type {
135 QuantizationType::FP => Ok(Self::NoQuant(NoStore)),
136 QuantizationType::PQ { num_chunks } => {
137 let pq_pivots_paths = PQPathNames::new(index_path_prefix);
138 let pq_build_storage = PQStorage::new(
139 &pq_pivots_paths.pivots,
140 &pq_pivots_paths.compressed_data,
141 None,
142 );
143 let table = pq_build_storage.load_pq_pivots_bin::<StorageProvider>(
144 &pq_pivots_paths.pivots,
145 *num_chunks,
146 storage_provider,
147 )?;
148 Ok(Self::PQ(table))
149 }
150 QuantizationType::SQ { .. } => {
151 let sq_storage = SQStorage::new(index_path_prefix);
152 let sq_quantizer = sq_storage.load_quantizer(storage_provider)?;
153 Ok(Self::Scalar1Bit(WithBits::<1>::new(sq_quantizer)))
154 }
155 }
156 }
157}