1use std::marker::PhantomData;
7
8use diskann::{utils::VectorRepr, ANNError};
9use diskann_providers::storage::{StorageReadProvider, StorageWriteProvider};
10use diskann_providers::{
11 forward_threadpool,
12 model::{
13 pq::{accum_row_inplace, generate_pq_pivots},
14 GeneratePivotArguments,
15 },
16 storage::PQStorage,
17 utils::{AsThreadPool, BridgeErr, Timer},
18};
19use diskann_quantization::{product::TransposedTable, CompressInto};
20use diskann_utils::views::MatrixBase;
21use diskann_vector::distance::Metric;
22use tracing::info;
23
24use crate::storage::quant::compressor::{CompressionStage, QuantCompressor};
25
26pub struct PQGenerationContext<'a, Storage, Pool>
27where
28 Storage: StorageReadProvider + StorageWriteProvider,
29 Pool: AsThreadPool,
30{
31 pub pq_storage: PQStorage,
32 pub num_chunks: usize,
33 pub seed: Option<u64>,
34 pub p_val: f64,
35 pub storage_provider: &'a Storage,
36 pub pool: Pool,
37 pub metric: Metric,
38 pub dim: usize,
39 pub max_kmeans_reps: usize,
40 pub num_centers: usize,
41}
42
43pub struct PQGeneration<'a, T, Storage, Pool>
44where
45 T: VectorRepr,
46 Storage: StorageReadProvider + StorageWriteProvider + 'a,
47 Pool: AsThreadPool,
48{
49 table: TransposedTable,
50 num_chunks: usize,
51 phantom_data: PhantomData<T>,
52 phantom_storage: PhantomData<&'a Storage>,
53 phantom_pool: PhantomData<Pool>,
54}
55
56impl<'a, T, Storage, Pool> QuantCompressor<T> for PQGeneration<'a, T, Storage, Pool>
57where
58 T: VectorRepr,
59 Storage: StorageReadProvider + StorageWriteProvider + 'a,
60 Pool: AsThreadPool,
61{
62 type CompressorContext = PQGenerationContext<'a, Storage, Pool>;
63
64 fn new_at_stage(
65 stage: CompressionStage,
66 context: &Self::CompressorContext,
67 ) -> diskann::ANNResult<Self> {
68 if context.num_chunks > context.dim {
70 return Err(ANNError::log_pq_error(
71 "Error: number of chunks more than dimension.",
72 ));
73 }
74
75 let pivots_exists = context
76 .pq_storage
77 .pivot_data_exist(context.storage_provider);
78
79 let pool = &context.pool;
80 forward_threadpool!(pool = pool: Pool);
81
82 if !pivots_exists {
83 if stage == CompressionStage::Resume {
84 return Err(ANNError::log_pq_error(
86 "Error: Pivot data does not exist when start_vertex_id is not 0.",
87 ));
88 }
89
90 let timer = Timer::new();
91
92 let rng =
93 diskann_providers::utils::create_rnd_provider_from_optional_seed(context.seed);
94 let (mut train_data, train_size, train_dim) = context
95 .pq_storage
96 .get_random_train_data_slice::<T, Storage>(
97 context.p_val,
98 context.storage_provider,
99 &mut rng.create_rnd(),
100 )?;
101
102 generate_pq_pivots(
103 GeneratePivotArguments::new(
104 train_size,
105 train_dim,
106 context.num_centers,
107 context.num_chunks,
108 context.max_kmeans_reps,
109 context.metric == Metric::L2,
110 )?,
111 &mut train_data,
112 &context.pq_storage,
113 context.storage_provider,
114 rng,
115 pool,
116 )?;
117
118 info!(
119 "PQ pivot generation took {} seconds",
120 timer.elapsed().as_secs_f64()
121 );
122 }
123
124 let (_, full_dim) = context
125 .pq_storage
126 .read_existing_pivot_metadata(context.storage_provider)?;
127
128 let num_chunks = context.num_chunks;
130 let (mut full_pivot_data, centroid, chunk_offsets, _) =
131 context.pq_storage.load_existing_pivot_data(
132 &num_chunks,
133 &context.num_centers,
134 &full_dim,
135 context.storage_provider,
136 false,
137 )?;
138
139 let mut full_pivot_data_mat = diskann_utils::views::MutMatrixView::try_from(
140 full_pivot_data.as_mut_slice(),
141 context.num_centers,
142 full_dim,
143 )
144 .bridge_err()?;
145
146 accum_row_inplace(full_pivot_data_mat.as_mut_view(), centroid.as_slice());
147
148 let table = TransposedTable::from_parts(
149 full_pivot_data_mat.as_view(),
150 diskann_quantization::views::ChunkOffsetsView::new(&chunk_offsets)
151 .bridge_err()?
152 .to_owned(),
153 )
154 .map_err(|err| ANNError::log_pq_error(diskann_quantization::error::format(&err)))?;
155
156 Ok(Self {
157 table,
158 num_chunks,
159 phantom_data: PhantomData,
160 phantom_pool: PhantomData,
161 phantom_storage: PhantomData,
162 })
163 }
164
165 fn compress(
166 &self,
167 vector: MatrixBase<&[f32]>,
168 output: MatrixBase<&mut [u8]>,
169 ) -> Result<(), diskann::ANNError> {
170 self.table
171 .compress_into(vector, output)
172 .map_err(|err| ANNError::log_pq_error(diskann_quantization::error::format(&err)))
173 }
174
175 fn compressed_bytes(&self) -> usize {
176 self.num_chunks
177 }
178}
179
#[cfg(test)]
mod pq_generation_tests {
    use diskann::ANNError;
    use diskann_providers::model::pq::generate_pq_pivots;
    use diskann_providers::model::GeneratePivotArguments;
    use diskann_providers::storage::{PQStorage, StorageWriteProvider, VirtualStorageProvider};
    use diskann_providers::utils::{
        create_thread_pool_for_test, file_util::load_bin, save_bin_f32, AsThreadPool,
    };
    use diskann_utils::test_data_root;
    use diskann_utils::views::{MatrixView, MutMatrixView};
    use diskann_vector::distance::Metric;
    use rstest::rstest;
    use vfs::FileSystem;

    use super::{CompressionStage, PQGeneration, PQGenerationContext};
    use crate::storage::quant::compressor::QuantCompressor;

    const TEST_PQ_DATA_PATH: &str = "/sift/siftsmall_learn.bin";
    const TEST_PQ_PIVOTS_PATH: &str = "/sift/siftsmall_learn_pq_pivots.bin";
    const TEST_PQ_COMPRESSED_PATH: &str = "/sift/siftsmall_learn_pq_compressed.bin";

    /// Five 8-dimensional training vectors, laid out one vector per line.
    #[rustfmt::skip]
    const VALIDATION_DATA: [f32; 40] = [
        1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
        2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0,
        2.1, 2.1, 2.1, 2.1, 2.1, 2.1, 2.1, 2.1,
        2.2, 2.2, 2.2, 2.2, 2.2, 2.2, 2.2, 2.2,
        100.0, 100.0, 100.0, 100.0, 100.0, 100.0, 100.0, 100.0,
    ];

    /// Builds an `f32` PQ compressor over `provider` with a fixed seed (42) and the L2 metric.
    // The arguments mirror the fields of `PQGenerationContext`; bundling them into a struct
    // would just recreate that type, so the lint is silenced instead.
    #[allow(clippy::too_many_arguments)]
    fn create_new_compressor<'a, R: AsThreadPool, F: FileSystem>(
        stage: CompressionStage,
        provider: &'a VirtualStorageProvider<F>,
        dim: usize,
        num_chunks: usize,
        max_kmeans_reps: usize,
        num_centers: usize,
        p_val: f64,
        pool: R,
        pivots_path: &str,
        compressed_path: &str,
        data_path: Option<&str>,
    ) -> Result<PQGeneration<'a, f32, VirtualStorageProvider<F>, R>, ANNError> {
        let context = PQGenerationContext {
            pq_storage: PQStorage::new(pivots_path, compressed_path, data_path),
            num_chunks,
            num_centers,
            seed: Some(42),
            p_val,
            max_kmeans_reps,
            storage_provider: provider,
            pool,
            metric: Metric::L2,
            dim,
        };
        PQGeneration::<_, _, _>::new_at_stage(stage, &context)
    }

    /// Pivots trained by the compressor must match those produced by calling
    /// `generate_pq_pivots` directly with the same seed and parameters.
    #[rstest]
    fn test_create_and_load_pivots_file() {
        let storage_provider = VirtualStorageProvider::new_memory();
        storage_provider
            .filesystem()
            .create_dir("/pq_generation_tests")
            .expect("Could not create test directory");

        let pivot_file_name = "/pq_generation_tests/generate_pq_pivots_test.bin";
        let pivot_file_name_compressor = "/pq_generation_tests/compressor_pivots_test.bin";
        let compressed_file_name = "/pq_generation_tests/compressed_not_used.bin";
        let data_path = "/pq_generation_tests/data_path.bin";
        let pq_storage = PQStorage::new(pivot_file_name, compressed_file_name, Some(data_path));

        let (ndata, dim, num_centers, num_chunks, max_k_means_reps) = (5, 8, 2, 2, 5);
        let mut train_data: Vec<f32> = VALIDATION_DATA.to_vec();

        // The compressor samples its training data from `data_path`; a failed write here
        // must fail the test rather than let it continue on missing data.
        save_bin_f32(
            &mut storage_provider.create_for_write(data_path).unwrap(),
            &train_data,
            ndata,
            dim,
            0,
        )
        .expect("Could not write test data");

        let pool = create_thread_pool_for_test();
        generate_pq_pivots(
            GeneratePivotArguments::new(
                ndata,
                dim,
                num_centers,
                num_chunks,
                max_k_means_reps,
                true,
            )
            .unwrap(),
            &mut train_data,
            &pq_storage,
            &storage_provider,
            diskann_providers::utils::create_rnd_provider_from_seed_in_tests(42),
            &pool,
        )
        .unwrap();

        let compressor = create_new_compressor(
            CompressionStage::Start,
            &storage_provider,
            dim,
            num_chunks,
            max_k_means_reps,
            num_centers,
            1.0,
            &pool,
            pivot_file_name_compressor,
            compressed_file_name,
            Some(data_path),
        )
        .expect("Compressor creation should succeed");

        assert_eq!(compressor.num_chunks, num_chunks);
        assert_eq!(compressor.compressed_bytes(), num_chunks);

        assert_eq!(compressor.table.dim(), dim);
        assert_eq!(compressor.table.ncenters(), num_centers);
        assert_eq!(compressor.table.nchunks(), num_chunks);

        assert!(
            storage_provider.exists(pivot_file_name_compressor),
            "compressor should have written its own pivot file"
        );
        let (compressor_pivots, cn, cd) =
            load_bin::<u8, _>(&storage_provider, pivot_file_name_compressor, 0).unwrap();
        let (true_pivots, n, d) = load_bin::<u8, _>(&storage_provider, pivot_file_name, 0).unwrap();

        assert_eq!(cn, n);
        assert_eq!(cd, d);
        assert_eq!(compressor_pivots, true_pivots);
    }

    /// Resuming without a stored pivot file must be rejected instead of retraining.
    #[rstest]
    fn throw_error_for_resume_and_no_existing_file() {
        let storage_provider = VirtualStorageProvider::new_memory();
        storage_provider
            .filesystem()
            .create_dir("/pq_generation_tests")
            .expect("Could not create test directory");

        let pivot_file_name = "/pq_generation_tests/generate_pq_pivots_test.bin";
        let compressed_file_name = "/pq_generation_tests/compressed_not_used.bin";
        let data_path = "/pq_generation_tests/data_path.bin";

        let (ndata, dim, num_centers, num_chunks, max_k_means_reps) = (5, 8, 2, 2, 5);

        // Write valid data so the expected error can only come from the missing pivots.
        save_bin_f32(
            &mut storage_provider.create_for_write(data_path).unwrap(),
            &VALIDATION_DATA,
            ndata,
            dim,
            0,
        )
        .expect("Could not write test data");
        let pool = create_thread_pool_for_test();

        let compressor = create_new_compressor(
            CompressionStage::Resume,
            &storage_provider,
            dim,
            num_chunks,
            max_k_means_reps,
            num_centers,
            1.0,
            &pool,
            pivot_file_name,
            compressed_file_name,
            Some(data_path),
        );

        assert!(compressor.is_err());
    }

    /// Compressing the SIFT-small learn set with the stored codebook must reproduce the
    /// stored compressed output exactly.
    #[rstest]
    fn test_pq_end_to_end_with_codebook() {
        let storage_provider = VirtualStorageProvider::new_overlay(test_data_root());

        let pool = create_thread_pool_for_test();
        let dim = 128;
        let num_chunks = 1;
        let max_k_means_reps = 10;

        let compressor = create_new_compressor(
            CompressionStage::Resume,
            &storage_provider,
            dim,
            num_chunks,
            max_k_means_reps,
            256,
            1.0,
            &pool,
            TEST_PQ_PIVOTS_PATH,
            "",
            None,
        )
        .unwrap_or_else(|err| panic!("Error creating compressor: {err}"));

        let (data, npts, data_dim) =
            load_bin::<f32, _>(&storage_provider, TEST_PQ_DATA_PATH, 0).unwrap();
        assert_eq!(data_dim, dim, "unexpected test data dimensionality");

        let mut compressed_mat = vec![0_u8; num_chunks * npts];
        compressor
            .compress(
                MatrixView::try_from(&data, npts, data_dim).unwrap(),
                MutMatrixView::try_from(&mut compressed_mat, npts, num_chunks).unwrap(),
            )
            .expect("Compression should succeed");

        let (compressed_gt, _, _) =
            load_bin::<u8, _>(&storage_provider, TEST_PQ_COMPRESSED_PATH, 0).unwrap();
        assert_eq!(compressed_gt, compressed_mat);
    }

    /// Invalid or inconsistent parameters must be rejected when building a compressor.
    #[rstest]
    // Parameters inconsistent with the stored pivot file (presumably 128-d with one chunk,
    // as used by the end-to-end test).
    #[case(129, 128, 256)]
    // Zero chunks.
    #[case(128, 0, 256)]
    // Zero centers per chunk.
    #[case(128, 128, 0)]
    // More chunks than dimensions.
    #[case(128, 129, 256)]
    fn test_parameter_error_cases(
        #[case] dim: usize,
        #[case] num_chunks: usize,
        #[case] centers: usize,
    ) {
        let storage_provider = VirtualStorageProvider::new_overlay(test_data_root());
        let pool = create_thread_pool_for_test();
        let max_k_means_reps = 10;
        let compressor = create_new_compressor(
            CompressionStage::Start,
            &storage_provider,
            dim,
            num_chunks,
            max_k_means_reps,
            centers,
            1.0,
            &pool,
            TEST_PQ_PIVOTS_PATH,
            "",
            None,
        );
        assert!(compressor.is_err());
    }
}