1use std::marker::PhantomData;
7
8use diskann::{utils::VectorRepr, ANNError};
9use diskann_providers::storage::{StorageReadProvider, StorageWriteProvider};
10use diskann_providers::{
11 forward_threadpool,
12 model::{
13 pq::{accum_row_inplace, generate_pq_pivots},
14 GeneratePivotArguments,
15 },
16 storage::PQStorage,
17 utils::{AsThreadPool, BridgeErr, Timer},
18};
19use diskann_quantization::{product::TransposedTable, CompressInto};
20use diskann_utils::views::MatrixBase;
21use diskann_vector::distance::Metric;
22use tracing::info;
23
24use crate::storage::quant::compressor::{CompressionStage, QuantCompressor};
25
26pub struct PQGenerationContext<'a, Storage, Pool>
27where
28 Storage: StorageReadProvider + StorageWriteProvider,
29 Pool: AsThreadPool,
30{
31 pub pq_storage: PQStorage,
32 pub num_chunks: usize,
33 pub seed: Option<u64>,
34 pub p_val: f64,
35 pub storage_provider: &'a Storage,
36 pub pool: Pool,
37 pub metric: Metric,
38 pub dim: usize,
39 pub max_kmeans_reps: usize,
40 pub num_centers: usize,
41}
42
43pub struct PQGeneration<'a, T, Storage, Pool>
44where
45 T: VectorRepr,
46 Storage: StorageReadProvider + StorageWriteProvider + 'a,
47 Pool: AsThreadPool,
48{
49 table: TransposedTable,
50 num_chunks: usize,
51 phantom_data: PhantomData<T>,
52 phantom_storage: PhantomData<&'a Storage>,
53 phantom_pool: PhantomData<Pool>,
54}
55
56impl<'a, T, Storage, Pool> QuantCompressor<T> for PQGeneration<'a, T, Storage, Pool>
57where
58 T: VectorRepr,
59 Storage: StorageReadProvider + StorageWriteProvider + 'a,
60 Pool: AsThreadPool,
61{
62 type CompressorContext = PQGenerationContext<'a, Storage, Pool>;
63
64 fn new_at_stage(
65 stage: CompressionStage,
66 context: &Self::CompressorContext,
67 ) -> diskann::ANNResult<Self> {
68 if context.num_chunks > context.dim {
70 return Err(ANNError::log_pq_error(
71 "Error: number of chunks more than dimension.",
72 ));
73 }
74
75 let pivots_exists = context
76 .pq_storage
77 .pivot_data_exist(context.storage_provider);
78
79 let pool = &context.pool;
80 forward_threadpool!(pool = pool: Pool);
81
82 if !pivots_exists {
83 if stage == CompressionStage::Resume {
84 return Err(ANNError::log_pq_error(
86 "Error: Pivot data does not exist when start_vertex_id is not 0.",
87 ));
88 }
89
90 let timer = Timer::new();
91
92 let rng =
93 diskann_providers::utils::create_rnd_provider_from_optional_seed(context.seed);
94 let (mut train_data, train_size, train_dim) = context
95 .pq_storage
96 .get_random_train_data_slice::<T, Storage>(
97 context.p_val,
98 context.storage_provider,
99 &mut rng.create_rnd(),
100 )?;
101
102 generate_pq_pivots(
103 GeneratePivotArguments::new(
104 train_size,
105 train_dim,
106 context.num_centers,
107 context.num_chunks,
108 context.max_kmeans_reps,
109 context.metric == Metric::L2,
110 )?,
111 &mut train_data,
112 &context.pq_storage,
113 context.storage_provider,
114 rng,
115 pool,
116 )?;
117
118 info!(
119 "PQ pivot generation took {} seconds",
120 timer.elapsed().as_secs_f64()
121 );
122 }
123
124 let (_, full_dim) = context
125 .pq_storage
126 .read_existing_pivot_metadata(context.storage_provider)?;
127
128 let num_chunks = context.num_chunks;
130 let (mut full_pivot_data, centroid, chunk_offsets, _) =
131 context.pq_storage.load_existing_pivot_data(
132 &num_chunks,
133 &context.num_centers,
134 &full_dim,
135 context.storage_provider,
136 false,
137 )?;
138
139 let mut full_pivot_data_mat = diskann_utils::views::MutMatrixView::try_from(
140 full_pivot_data.as_mut_slice(),
141 context.num_centers,
142 full_dim,
143 )
144 .bridge_err()?;
145
146 accum_row_inplace(full_pivot_data_mat.as_mut_view(), centroid.as_slice());
147
148 let table = TransposedTable::from_parts(
149 full_pivot_data_mat.as_view(),
150 diskann_quantization::views::ChunkOffsetsView::new(&chunk_offsets)
151 .bridge_err()?
152 .to_owned(),
153 )
154 .map_err(|err| ANNError::log_pq_error(diskann_quantization::error::format(&err)))?;
155
156 Ok(Self {
157 table,
158 num_chunks,
159 phantom_data: PhantomData,
160 phantom_pool: PhantomData,
161 phantom_storage: PhantomData,
162 })
163 }
164
165 fn compress(
166 &self,
167 vector: MatrixBase<&[f32]>,
168 output: MatrixBase<&mut [u8]>,
169 ) -> Result<(), diskann::ANNError> {
170 self.table
171 .compress_into(vector, output)
172 .map_err(|err| ANNError::log_pq_error(diskann_quantization::error::format(&err)))
173 }
174
175 fn compressed_bytes(&self) -> usize {
176 self.num_chunks
177 }
178}
179
#[cfg(test)]
mod pq_generation_tests {
    use diskann::ANNError;
    use diskann_providers::model::pq::generate_pq_pivots;
    use diskann_providers::model::GeneratePivotArguments;
    use diskann_providers::storage::{PQStorage, StorageWriteProvider, VirtualStorageProvider};
    use diskann_providers::utils::{
        create_thread_pool_for_test, file_util::load_bin, save_bin_f32, AsThreadPool,
    };
    use diskann_utils::test_data_root;
    use diskann_utils::views::{MatrixView, MutMatrixView};
    use diskann_vector::distance::Metric;
    use rstest::rstest;
    use vfs::{FileSystem, MemoryFS, OverlayFS};

    use super::{CompressionStage, PQGeneration, PQGenerationContext};
    use crate::storage::quant::compressor::QuantCompressor;

    const TEST_PQ_DATA_PATH: &str = "/sift/siftsmall_learn.bin";
    const TEST_PQ_PIVOTS_PATH: &str = "/sift/siftsmall_learn_pq_pivots.bin";
    const TEST_PQ_COMPRESSED_PATH: &str = "/sift/siftsmall_learn_pq_compressed.bin";

    // 5 points x 8 dimensions, one point per row.
    const VALIDATION_DATA: [f32; 40] = [
        1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, //
        2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, //
        2.1, 2.1, 2.1, 2.1, 2.1, 2.1, 2.1, 2.1, //
        2.2, 2.2, 2.2, 2.2, 2.2, 2.2, 2.2, 2.2, //
        100.0, 100.0, 100.0, 100.0, 100.0, 100.0, 100.0, 100.0,
    ];

    /// Builds a compressor over `provider` with seed 42 and the L2 metric.
    // The helper mirrors every tunable of `PQGenerationContext`; bundling them into yet another
    // struct would only add noise at the call sites, hence the allow.
    #[allow(clippy::too_many_arguments)]
    fn create_new_compressor<'a, R: AsThreadPool>(
        stage: CompressionStage,
        provider: &'a VirtualStorageProvider<OverlayFS>,
        dim: usize,
        num_chunks: usize,
        max_kmeans_reps: usize,
        num_centers: usize,
        p_val: f64,
        pool: R,
        pivots_path: &str,
        compressed_path: &str,
        data_path: Option<&str>,
    ) -> Result<PQGeneration<'a, f32, VirtualStorageProvider<OverlayFS>, R>, ANNError> {
        let context = PQGenerationContext::<'_, _, _> {
            pq_storage: PQStorage::new(pivots_path, compressed_path, data_path),
            num_chunks,
            num_centers,
            seed: Some(42),
            p_val,
            max_kmeans_reps,
            storage_provider: provider,
            pool,
            metric: Metric::L2,
            dim,
        };
        PQGeneration::<_, _, _>::new_at_stage(stage, &context)
    }

    #[rstest]
    fn test_create_and_load_pivots_file() {
        let fs = OverlayFS::new(&[MemoryFS::default().into()]);
        fs.create_dir("/pq_generation_tests")
            .expect("Could not create test directory");
        let storage_provider = VirtualStorageProvider::new(fs);

        let pivot_file_name = "/pq_generation_tests/generate_pq_pivots_test.bin";
        let pivot_file_name_compressor = "/pq_generation_tests/compressor_pivots_test.bin";
        let compressed_file_name = "/pq_generation_tests/compressed_not_used.bin";
        let data_path = "/pq_generation_tests/data_path.bin";
        let pq_storage: PQStorage =
            PQStorage::new(pivot_file_name, compressed_file_name, Some(data_path));

        let (ndata, dim, num_centers, num_chunks, max_k_means_reps) = (5, 8, 2, 2, 5);
        let mut train_data: Vec<f32> = VALIDATION_DATA.to_vec();

        // Persist the data before pivot generation, which receives `train_data` mutably.
        save_bin_f32(
            &mut storage_provider.create_for_write(data_path).unwrap(),
            &train_data,
            ndata,
            dim,
            0,
        )
        .expect("Could not write training data");

        let pool = create_thread_pool_for_test();

        // Reference pivots, generated directly with the same seed the compressor uses.
        generate_pq_pivots(
            GeneratePivotArguments::new(
                ndata,
                dim,
                num_centers,
                num_chunks,
                max_k_means_reps,
                true,
            )
            .unwrap(),
            &mut train_data,
            &pq_storage,
            &storage_provider,
            diskann_providers::utils::create_rnd_provider_from_seed_in_tests(42),
            &pool,
        )
        .unwrap();

        let compressor = create_new_compressor(
            CompressionStage::Start,
            &storage_provider,
            dim,
            num_chunks,
            max_k_means_reps,
            num_centers,
            1.0,
            &pool,
            pivot_file_name_compressor,
            compressed_file_name,
            Some(data_path),
        )
        .expect("Could not create compressor");

        assert_eq!(compressor.num_chunks, num_chunks);
        assert_eq!(compressor.compressed_bytes(), num_chunks);

        assert_eq!(compressor.table.dim(), dim);
        assert_eq!(compressor.table.ncenters(), num_centers);
        assert_eq!(compressor.table.nchunks(), num_chunks);

        assert!(storage_provider.exists(pivot_file_name_compressor));
        let (compressor_pivots, cn, cd) =
            load_bin::<u8, _>(&storage_provider, pivot_file_name_compressor, 0).unwrap();
        let (true_pivots, n, d) = load_bin::<u8, _>(&storage_provider, pivot_file_name, 0).unwrap();

        assert_eq!(cn, n);
        assert_eq!(cd, d);
        assert_eq!(compressor_pivots, true_pivots);
    }

    #[rstest]
    fn throw_error_for_resume_and_no_existing_file() {
        let fs = OverlayFS::new(&[MemoryFS::default().into()]);
        fs.create_dir("/pq_generation_tests")
            .expect("Could not create test directory");
        let storage_provider = VirtualStorageProvider::new(fs);

        let pivot_file_name = "/pq_generation_tests/generate_pq_pivots_test.bin";
        let compressed_file_name = "/pq_generation_tests/compressed_not_used.bin";
        let data_path = "/pq_generation_tests/data_path.bin";

        let (ndata, dim, num_centers, num_chunks, max_k_means_reps) = (5, 8, 2, 2, 5);

        save_bin_f32(
            &mut storage_provider.create_for_write(data_path).unwrap(),
            &VALIDATION_DATA,
            ndata,
            dim,
            0,
        )
        .expect("Could not write training data");
        let pool = create_thread_pool_for_test();

        // No pivot file exists, so resuming must fail rather than train new pivots.
        let compressor = create_new_compressor(
            CompressionStage::Resume,
            &storage_provider,
            dim,
            num_chunks,
            max_k_means_reps,
            num_centers,
            1.0,
            &pool,
            pivot_file_name,
            compressed_file_name,
            Some(data_path),
        );

        assert!(compressor.is_err());
    }

    #[rstest]
    fn test_pq_end_to_end_with_codebook() {
        let storage_provider = VirtualStorageProvider::new_overlay(test_data_root());

        let pool = create_thread_pool_for_test();
        let dim = 128;
        let num_chunks = 1;
        let max_k_means_reps = 10;

        let compressor = create_new_compressor(
            CompressionStage::Resume,
            &storage_provider,
            dim,
            num_chunks,
            max_k_means_reps,
            256,
            1.0,
            &pool,
            TEST_PQ_PIVOTS_PATH,
            "",
            None,
        )
        .unwrap_or_else(|err| panic!("Error creating compressor: {err}"));

        let (data, npts, dim) =
            load_bin::<f32, _>(&storage_provider, TEST_PQ_DATA_PATH, 0).unwrap();
        let mut compressed_mat = vec![0_u8; num_chunks * npts];
        compressor
            .compress(
                MatrixView::try_from(&data, npts, dim).unwrap(),
                MutMatrixView::try_from(&mut compressed_mat, npts, num_chunks).unwrap(),
            )
            .expect("Compression failed");

        let (compressed_gt, _, _) =
            load_bin::<u8, _>(&storage_provider, TEST_PQ_COMPRESSED_PATH, 0).unwrap();
        assert_eq!(compressed_gt, compressed_mat);
    }

    /// Invalid or inconsistent parameter combinations must be rejected.
    #[rstest]
    // 129 dims / 128 chunks do not match the stored 128-dim pivot file.
    #[case(129, 128, 256)]
    // Zero chunks.
    #[case(128, 0, 256)]
    // Zero centers.
    #[case(128, 128, 0)]
    fn test_parameter_error_cases(
        #[case] dim: usize,
        #[case] num_chunks: usize,
        #[case] centers: usize,
    ) {
        let storage_provider = VirtualStorageProvider::new_overlay(test_data_root());
        let pool = create_thread_pool_for_test();
        let max_k_means_reps = 10;
        let compressor = create_new_compressor(
            CompressionStage::Start,
            &storage_provider,
            dim,
            num_chunks,
            max_k_means_reps,
            centers,
            1.0,
            &pool,
            TEST_PQ_PIVOTS_PATH,
            "",
            None,
        );
        assert!(compressor.is_err());
    }
}