1use std::marker::PhantomData;
7
8use diskann::{utils::VectorRepr, ANNError};
9use diskann_providers::storage::{StorageReadProvider, StorageWriteProvider};
10use diskann_providers::{
11 forward_threadpool,
12 model::{
13 pq::{accum_row_inplace, generate_pq_pivots},
14 GeneratePivotArguments,
15 },
16 storage::PQStorage,
17 utils::{AsThreadPool, BridgeErr, Timer},
18};
19use diskann_quantization::{product::TransposedTable, CompressInto};
20use diskann_utils::views::MatrixBase;
21use diskann_vector::distance::Metric;
22use tracing::info;
23
24use crate::storage::quant::compressor::{CompressionStage, QuantCompressor};
25
26pub struct PQGenerationContext<'a, Storage, Pool>
27where
28 Storage: StorageReadProvider + StorageWriteProvider,
29 Pool: AsThreadPool,
30{
31 pub pq_storage: PQStorage,
32 pub num_chunks: usize,
33 pub seed: Option<u64>,
34 pub p_val: f64,
35 pub storage_provider: &'a Storage,
36 pub pool: Pool,
37 pub metric: Metric,
38 pub dim: usize,
39 pub max_kmeans_reps: usize,
40 pub num_centers: usize,
41}
42
43pub struct PQGeneration<'a, T, Storage, Pool>
44where
45 T: VectorRepr,
46 Storage: StorageReadProvider + StorageWriteProvider + 'a,
47 Pool: AsThreadPool,
48{
49 table: TransposedTable,
50 num_chunks: usize,
51 phantom_data: PhantomData<T>,
52 phantom_storage: PhantomData<&'a Storage>,
53 phantom_pool: PhantomData<Pool>,
54}
55
56impl<'a, T, Storage, Pool> QuantCompressor<T> for PQGeneration<'a, T, Storage, Pool>
57where
58 T: VectorRepr,
59 Storage: StorageReadProvider + StorageWriteProvider + 'a,
60 Pool: AsThreadPool,
61{
62 type CompressorContext = PQGenerationContext<'a, Storage, Pool>;
63
64 fn new_at_stage(
65 stage: CompressionStage,
66 context: &Self::CompressorContext,
67 ) -> diskann::ANNResult<Self> {
68 if context.num_chunks > context.dim {
70 return Err(ANNError::log_pq_error(
71 "Error: number of chunks more than dimension.",
72 ));
73 }
74
75 let pivots_exists = context
76 .pq_storage
77 .pivot_data_exist(context.storage_provider);
78
79 let pool = &context.pool;
80 forward_threadpool!(pool = pool: Pool);
81
82 if !pivots_exists {
83 if stage == CompressionStage::Resume {
84 return Err(ANNError::log_pq_error(
86 "Error: Pivot data does not exist when start_vertex_id is not 0.",
87 ));
88 }
89
90 let timer = Timer::new();
91
92 let rng =
93 diskann_providers::utils::create_rnd_provider_from_optional_seed(context.seed);
94 let (mut train_data, train_size, train_dim) = context
95 .pq_storage
96 .get_random_train_data_slice::<T, Storage>(
97 context.p_val,
98 context.storage_provider,
99 &mut rng.create_rnd(),
100 )?;
101
102 generate_pq_pivots(
103 GeneratePivotArguments::new(
104 train_size,
105 train_dim,
106 context.num_centers,
107 context.num_chunks,
108 context.max_kmeans_reps,
109 context.metric == Metric::L2,
110 )?,
111 &mut train_data,
112 &context.pq_storage,
113 context.storage_provider,
114 rng,
115 pool,
116 )?;
117
118 info!(
119 "PQ pivot generation took {} seconds",
120 timer.elapsed().as_secs_f64()
121 );
122 }
123
124 let (_, full_dim) = context
125 .pq_storage
126 .read_existing_pivot_metadata(context.storage_provider)?;
127
128 let num_chunks = context.num_chunks;
130 let (mut full_pivot_data, centroid, chunk_offsets, _) =
131 context.pq_storage.load_existing_pivot_data(
132 &num_chunks,
133 &context.num_centers,
134 &full_dim,
135 context.storage_provider,
136 false,
137 )?;
138
139 let mut full_pivot_data_mat = diskann_utils::views::MutMatrixView::try_from(
140 full_pivot_data.as_mut_slice(),
141 context.num_centers,
142 full_dim,
143 )
144 .bridge_err()?;
145
146 accum_row_inplace(full_pivot_data_mat.as_mut_view(), centroid.as_slice());
147
148 let table = TransposedTable::from_parts(
149 full_pivot_data_mat.as_view(),
150 diskann_quantization::views::ChunkOffsetsView::new(&chunk_offsets)
151 .bridge_err()?
152 .to_owned(),
153 )
154 .map_err(|err| ANNError::log_pq_error(diskann_quantization::error::format(&err)))?;
155
156 Ok(Self {
157 table,
158 num_chunks,
159 phantom_data: PhantomData,
160 phantom_pool: PhantomData,
161 phantom_storage: PhantomData,
162 })
163 }
164
165 fn compress(
166 &self,
167 vector: MatrixBase<&[f32]>,
168 output: MatrixBase<&mut [u8]>,
169 ) -> Result<(), diskann::ANNError> {
170 self.table
171 .compress_into(vector, output)
172 .map_err(|err| ANNError::log_pq_error(diskann_quantization::error::format(&err)))
173 }
174
175 fn compressed_bytes(&self) -> usize {
176 self.num_chunks
177 }
178}
179
180#[cfg(test)]
185mod pq_generation_tests {
186 use diskann::ANNError;
187 use diskann_providers::model::pq::generate_pq_pivots;
188 use diskann_providers::model::GeneratePivotArguments;
189 use diskann_providers::storage::{
190 PQStorage, StorageReadProvider, StorageWriteProvider, VirtualStorageProvider,
191 };
192 use diskann_providers::utils::{create_thread_pool_for_test, AsThreadPool};
193 use diskann_utils::{
194 io::{read_bin, write_bin},
195 test_data_root,
196 views::{MatrixView, MutMatrixView},
197 };
198 use diskann_vector::distance::Metric;
199 use rstest::rstest;
200 use vfs::FileSystem;
201
202 use super::{CompressionStage, PQGeneration, PQGenerationContext};
203 use crate::storage::quant::compressor::QuantCompressor;
204
205 const TEST_PQ_DATA_PATH: &str = "/sift/siftsmall_learn.bin";
206 const TEST_PQ_PIVOTS_PATH: &str = "/sift/siftsmall_learn_pq_pivots.bin";
207 const TEST_PQ_COMPRESSED_PATH: &str = "/sift/siftsmall_learn_pq_compressed.bin";
208 const VALIDATION_DATA: [f32; 40] = [
209 1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 2.0f32, 2.0f32, 2.0f32,
211 2.0f32, 2.0f32, 2.0f32, 2.0f32, 2.0f32, 2.1f32, 2.1f32, 2.1f32, 2.1f32, 2.1f32, 2.1f32,
212 2.1f32, 2.1f32, 2.2f32, 2.2f32, 2.2f32, 2.2f32, 2.2f32, 2.2f32, 2.2f32, 2.2f32, 100.0f32,
213 100.0f32, 100.0f32, 100.0f32, 100.0f32, 100.0f32, 100.0f32, 100.0f32,
214 ];
215 #[allow(clippy::too_many_arguments)]
216 fn create_new_compressor<'a, R: AsThreadPool, F: vfs::FileSystem>(
217 stage: CompressionStage,
218 provider: &'a VirtualStorageProvider<F>,
219 dim: usize,
220 num_chunks: usize,
221 max_kmeans_reps: usize,
222 num_centers: usize,
223 p_val: f64,
224 pool: R,
225 pivots_path: String,
226 compressed_path: String,
227 data_path: Option<&str>,
228 ) -> Result<PQGeneration<'a, f32, VirtualStorageProvider<F>, R>, ANNError> {
229 let pq_storage = PQStorage::new(&pivots_path, &compressed_path, data_path);
230 let context = PQGenerationContext::<'_, _, _> {
231 pq_storage,
232 num_chunks,
233 num_centers,
234 seed: Some(42),
235 p_val,
236 max_kmeans_reps,
237 storage_provider: provider,
238 pool,
239 metric: Metric::L2,
240 dim,
241 };
242 PQGeneration::<_, _, _>::new_at_stage(stage, &context)
243 }
244
245 #[rstest]
246 fn test_create_and_load_pivots_file() {
247 let storage_provider = VirtualStorageProvider::new_memory();
248 storage_provider
249 .filesystem()
250 .create_dir("/pq_generation_tests")
251 .expect("Could not create test directory");
252
253 let pivot_file_name = "/pq_generation_tests/generate_pq_pivots_test.bin";
254 let pivot_file_name_compressor = "/pq_generation_tests/compressor_pivots_test.bin";
255 let compressed_file_name = "/pq_generation_tests/compressed_not_used.bin";
256 let data_path = "/pq_generation_tests/data_path.bin";
257 let pq_storage: PQStorage =
258 PQStorage::new(pivot_file_name, compressed_file_name, Some(data_path));
259
260 let (ndata, dim, num_centers, num_chunks, max_k_means_reps) = (5, 8, 2, 2, 5);
261 let mut train_data: Vec<f32> = VALIDATION_DATA.to_vec();
262
263 write_bin(
264 MatrixView::try_from(train_data.as_slice(), ndata, dim).unwrap(),
265 &mut storage_provider.create_for_write(data_path).unwrap(),
266 )
267 .unwrap();
268
269 let pool = create_thread_pool_for_test();
270 generate_pq_pivots(
271 GeneratePivotArguments::new(
272 ndata,
273 dim,
274 num_centers,
275 num_chunks,
276 max_k_means_reps,
277 true,
278 )
279 .unwrap(),
280 &mut train_data,
281 &pq_storage,
282 &storage_provider,
283 diskann_providers::utils::create_rnd_provider_from_seed_in_tests(42),
284 &pool,
285 )
286 .unwrap();
287
288 let compressor = create_new_compressor(
289 CompressionStage::Start,
290 &storage_provider,
291 dim,
292 num_chunks,
293 max_k_means_reps,
294 num_centers,
295 1.0, &pool,
297 pivot_file_name_compressor.to_string(),
298 compressed_file_name.to_string(),
299 Some(data_path),
300 );
301
302 assert!(compressor.is_ok());
303
304 let compressor = compressor.unwrap();
305 assert_eq!(compressor.num_chunks, num_chunks);
306 assert_eq!(compressor.compressed_bytes(), num_chunks);
307
308 assert_eq!(compressor.table.dim(), dim);
309 assert_eq!(compressor.table.ncenters(), num_centers);
310 assert_eq!(compressor.table.nchunks(), num_chunks);
311
312 assert!(&storage_provider.exists(pivot_file_name_compressor));
313 let compressor_pivots = read_bin::<u8>(
314 &mut storage_provider
315 .open_reader(pivot_file_name_compressor)
316 .unwrap(),
317 )
318 .unwrap();
319 let true_pivots =
320 read_bin::<u8>(&mut storage_provider.open_reader(pivot_file_name).unwrap()).unwrap();
321 assert_eq!(compressor_pivots, true_pivots);
322 }
323
324 #[rstest]
325 fn throw_error_for_resume_and_no_existing_file() {
326 let storage_provider = VirtualStorageProvider::new_memory();
327 storage_provider
328 .filesystem()
329 .create_dir("/pq_generation_tests")
330 .expect("Could not create test directory");
331
332 let pivot_file_name = "/pq_generation_tests/generate_pq_pivots_test.bin";
333 let compressed_file_name = "/pq_generation_tests/compressed_not_used.bin";
334 let data_path = "/pq_generation_tests/data_path.bin";
335
336 let (ndata, dim, num_centers, num_chunks, max_k_means_reps) = (5, 8, 2, 2, 5);
337
338 write_bin(
339 MatrixView::try_from(VALIDATION_DATA.as_slice(), ndata, dim).unwrap(),
340 &mut storage_provider.create_for_write(data_path).unwrap(),
341 )
342 .unwrap();
343 let pool = create_thread_pool_for_test();
344
345 let compressor = create_new_compressor(
346 CompressionStage::Resume,
347 &storage_provider,
348 dim,
349 num_chunks,
350 max_k_means_reps,
351 num_centers,
352 1.0,
353 &pool,
354 pivot_file_name.to_string(),
355 compressed_file_name.to_string(),
356 Some(data_path),
357 );
358
359 assert!(compressor.is_err());
360 }
361
362 #[rstest]
363 fn test_pq_end_to_end_with_codebook() {
364 let storage_provider = VirtualStorageProvider::new_overlay(test_data_root());
365
366 let pool = create_thread_pool_for_test();
367 let dim = 128;
368 let num_chunks = 1;
369 let max_k_means_reps = 10;
370
371 let compressor = create_new_compressor(
372 CompressionStage::Resume,
373 &storage_provider,
374 dim,
375 num_chunks,
376 max_k_means_reps,
377 256,
378 1.0,
379 &pool,
380 TEST_PQ_PIVOTS_PATH.to_string(),
381 "".to_string(),
382 None,
383 );
384
385 if let Err(x) = compressor.as_ref() {
386 println!("Error creating compressor: {x}");
387 };
388
389 assert!(compressor.is_ok());
390
391 let data_matrix =
392 read_bin::<f32>(&mut storage_provider.open_reader(TEST_PQ_DATA_PATH).unwrap()).unwrap();
393 let npts = data_matrix.nrows();
394 let mut compressed_mat = vec![0_u8; num_chunks * npts];
395 let result = compressor.unwrap().compress(
396 data_matrix.as_view(),
397 MutMatrixView::try_from(&mut compressed_mat, npts, num_chunks).unwrap(),
398 );
399 assert!(result.is_ok());
400
401 let compressed_gt = read_bin::<u8>(
402 &mut storage_provider
403 .open_reader(TEST_PQ_COMPRESSED_PATH)
404 .unwrap(),
405 )
406 .unwrap();
407 assert_eq!(compressed_gt.as_slice(), &compressed_mat);
408 }
409
410 #[rstest]
411 #[case(129, 128, 256)] #[case(128, 0, 256)] #[case(128, 128, 0)] fn test_parameter_error_cases(
415 #[case] dim: usize,
416 #[case] num_chunks: usize,
417 #[case] centers: usize,
418 ) {
419 let storage_provider = VirtualStorageProvider::new_overlay(test_data_root());
421 let pool = create_thread_pool_for_test();
422 let max_k_means_reps = 10;
423 let compressor = create_new_compressor(
424 CompressionStage::Start,
425 &storage_provider,
426 dim,
427 num_chunks,
428 max_k_means_reps,
429 centers,
430 1.0,
431 &pool,
432 TEST_PQ_PIVOTS_PATH.to_string(),
433 "".to_string(),
434 None,
435 );
436 assert!(compressor.is_err());
437 }
438}