Skip to main content

diskann_disk/storage/quant/pq/
pq_generation.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use std::marker::PhantomData;
7
8use diskann::{utils::VectorRepr, ANNError};
9use diskann_providers::storage::{StorageReadProvider, StorageWriteProvider};
10use diskann_providers::{
11    forward_threadpool,
12    model::{
13        pq::{accum_row_inplace, generate_pq_pivots},
14        GeneratePivotArguments,
15    },
16    storage::PQStorage,
17    utils::{AsThreadPool, BridgeErr, Timer},
18};
19use diskann_quantization::{product::TransposedTable, CompressInto};
20use diskann_utils::views::MatrixBase;
21use diskann_vector::distance::Metric;
22use tracing::info;
23
24use crate::storage::quant::compressor::{CompressionStage, QuantCompressor};
25
26pub struct PQGenerationContext<'a, Storage, Pool>
27where
28    Storage: StorageReadProvider + StorageWriteProvider,
29    Pool: AsThreadPool,
30{
31    pub pq_storage: PQStorage,
32    pub num_chunks: usize,
33    pub seed: Option<u64>,
34    pub p_val: f64,
35    pub storage_provider: &'a Storage,
36    pub pool: Pool,
37    pub metric: Metric,
38    pub dim: usize,
39    pub max_kmeans_reps: usize,
40    pub num_centers: usize,
41}
42
43pub struct PQGeneration<'a, T, Storage, Pool>
44where
45    T: VectorRepr,
46    Storage: StorageReadProvider + StorageWriteProvider + 'a,
47    Pool: AsThreadPool,
48{
49    table: TransposedTable,
50    num_chunks: usize,
51    phantom_data: PhantomData<T>,
52    phantom_storage: PhantomData<&'a Storage>,
53    phantom_pool: PhantomData<Pool>,
54}
55
56impl<'a, T, Storage, Pool> QuantCompressor<T> for PQGeneration<'a, T, Storage, Pool>
57where
58    T: VectorRepr,
59    Storage: StorageReadProvider + StorageWriteProvider + 'a,
60    Pool: AsThreadPool,
61{
62    type CompressorContext = PQGenerationContext<'a, Storage, Pool>;
63
64    fn new_at_stage(
65        stage: CompressionStage,
66        context: &Self::CompressorContext,
67    ) -> diskann::ANNResult<Self> {
68        // validate that the number of chunks is correct.
69        if context.num_chunks > context.dim {
70            return Err(ANNError::log_pq_error(
71                "Error: number of chunks more than dimension.",
72            ));
73        }
74
75        let pivots_exists = context
76            .pq_storage
77            .pivot_data_exist(context.storage_provider);
78
79        let pool = &context.pool;
80        forward_threadpool!(pool = pool: Pool);
81
82        if !pivots_exists {
83            if stage == CompressionStage::Resume {
84                //checks for error case when stage is Resume and pivot data doesn't exist.
85                return Err(ANNError::log_pq_error(
86                    "Error: Pivot data does not exist when start_vertex_id is not 0.",
87                ));
88            }
89
90            let timer = Timer::new();
91
92            let rng =
93                diskann_providers::utils::create_rnd_provider_from_optional_seed(context.seed);
94            let (mut train_data, train_size, train_dim) = context
95                .pq_storage
96                .get_random_train_data_slice::<T, Storage>(
97                    context.p_val,
98                    context.storage_provider,
99                    &mut rng.create_rnd(),
100                )?;
101
102            generate_pq_pivots(
103                GeneratePivotArguments::new(
104                    train_size,
105                    train_dim,
106                    context.num_centers,
107                    context.num_chunks,
108                    context.max_kmeans_reps,
109                    context.metric == Metric::L2,
110                )?,
111                &mut train_data,
112                &context.pq_storage,
113                context.storage_provider,
114                rng,
115                pool,
116            )?;
117
118            info!(
119                "PQ pivot generation took {} seconds",
120                timer.elapsed().as_secs_f64()
121            );
122        }
123
124        let (_, full_dim) = context
125            .pq_storage
126            .read_existing_pivot_metadata(context.storage_provider)?;
127
128        //Load the pivots
129        let num_chunks = context.num_chunks;
130        let (mut full_pivot_data, centroid, chunk_offsets, _) =
131            context.pq_storage.load_existing_pivot_data(
132                &num_chunks,
133                &context.num_centers,
134                &full_dim,
135                context.storage_provider,
136                false,
137            )?;
138
139        let mut full_pivot_data_mat = diskann_utils::views::MutMatrixView::try_from(
140            full_pivot_data.as_mut_slice(),
141            context.num_centers,
142            full_dim,
143        )
144        .bridge_err()?;
145
146        accum_row_inplace(full_pivot_data_mat.as_mut_view(), centroid.as_slice());
147
148        let table = TransposedTable::from_parts(
149            full_pivot_data_mat.as_view(),
150            diskann_quantization::views::ChunkOffsetsView::new(&chunk_offsets)
151                .bridge_err()?
152                .to_owned(),
153        )
154        .map_err(|err| ANNError::log_pq_error(diskann_quantization::error::format(&err)))?;
155
156        Ok(Self {
157            table,
158            num_chunks,
159            phantom_data: PhantomData,
160            phantom_pool: PhantomData,
161            phantom_storage: PhantomData,
162        })
163    }
164
165    fn compress(
166        &self,
167        vector: MatrixBase<&[f32]>,
168        output: MatrixBase<&mut [u8]>,
169    ) -> Result<(), diskann::ANNError> {
170        self.table
171            .compress_into(vector, output)
172            .map_err(|err| ANNError::log_pq_error(diskann_quantization::error::format(&err)))
173    }
174
175    fn compressed_bytes(&self) -> usize {
176        self.num_chunks
177    }
178}
179
180//////////////////
181///// Tests /////
182/////////////////
183
#[cfg(test)]
mod pq_generation_tests {
    use diskann::ANNError;
    use diskann_providers::model::pq::generate_pq_pivots;
    use diskann_providers::model::GeneratePivotArguments;
    use diskann_providers::storage::{PQStorage, StorageWriteProvider, VirtualStorageProvider};
    use diskann_providers::utils::{
        create_thread_pool_for_test, file_util::load_bin, save_bin_f32, AsThreadPool,
    };
    use diskann_utils::test_data_root;
    use diskann_utils::views::{MatrixView, MutMatrixView};
    use diskann_vector::distance::Metric;
    use rstest::rstest;
    use vfs::{FileSystem, MemoryFS, OverlayFS};

    use super::{CompressionStage, PQGeneration, PQGenerationContext};
    use crate::storage::quant::compressor::QuantCompressor;

    const TEST_PQ_DATA_PATH: &str = "/sift/siftsmall_learn.bin";
    const TEST_PQ_PIVOTS_PATH: &str = "/sift/siftsmall_learn_pq_pivots.bin";
    const TEST_PQ_COMPRESSED_PATH: &str = "/sift/siftsmall_learn_pq_compressed.bin";

    /// Sample validation data: npoints = 5, dim = 8, one point per row.
    const VALIDATION_DATA: [f32; 40] = [
        1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, // [1.0; 8]
        2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, // [2.0; 8]
        2.1, 2.1, 2.1, 2.1, 2.1, 2.1, 2.1, 2.1, // [2.1; 8]
        2.2, 2.2, 2.2, 2.2, 2.2, 2.2, 2.2, 2.2, // [2.2; 8]
        100.0, 100.0, 100.0, 100.0, 100.0, 100.0, 100.0, 100.0, // [100.0; 8]
    ];

    /// Builds a `PQGeneration` compressor over `provider` using seed 42 and the L2 metric.
    // Each argument maps onto a distinct `PQGenerationContext` / `PQStorage` field;
    // bundling them into another struct would only move the long parameter list.
    #[allow(clippy::too_many_arguments)]
    fn create_new_compressor<'a, R: AsThreadPool>(
        stage: CompressionStage,
        provider: &'a VirtualStorageProvider<OverlayFS>,
        dim: usize,
        num_chunks: usize,
        max_kmeans_reps: usize,
        num_centers: usize,
        p_val: f64,
        pool: R,
        pivots_path: String,
        compressed_path: String,
        data_path: Option<&str>,
    ) -> Result<PQGeneration<'a, f32, VirtualStorageProvider<OverlayFS>, R>, ANNError> {
        let pq_storage = PQStorage::new(&pivots_path, &compressed_path, data_path);
        let context = PQGenerationContext::<'_, _, _> {
            pq_storage,
            num_chunks,
            num_centers,
            seed: Some(42),
            p_val,
            max_kmeans_reps,
            storage_provider: provider,
            pool,
            metric: Metric::L2,
            dim,
        };
        PQGeneration::<_, _, _>::new_at_stage(stage, &context)
    }

    /// Pivots written by the compressor must match the ones produced by calling
    /// `generate_pq_pivots` directly with the same seed and data.
    #[rstest]
    fn test_create_and_load_pivots_file() {
        let fs = OverlayFS::new(&[MemoryFS::default().into()]);
        fs.create_dir("/pq_generation_tests")
            .expect("Could not create test directory");
        let storage_provider = VirtualStorageProvider::new(fs);

        let pivot_file_name = "/pq_generation_tests/generate_pq_pivots_test.bin";
        let pivot_file_name_compressor = "/pq_generation_tests/compressor_pivots_test.bin";
        let compressed_file_name = "/pq_generation_tests/compressed_not_used.bin";
        let data_path = "/pq_generation_tests/data_path.bin";
        let pq_storage: PQStorage =
            PQStorage::new(pivot_file_name, compressed_file_name, Some(data_path));

        let (ndata, dim, num_centers, num_chunks, max_k_means_reps) = (5, 8, 2, 2, 5);
        let mut train_data: Vec<f32> = VALIDATION_DATA.to_vec();

        // A failed write must fail the test here, not later with a confusing load error.
        save_bin_f32(
            &mut storage_provider.create_for_write(data_path).unwrap(),
            &train_data,
            ndata,
            dim,
            0,
        )
        .expect("Could not write test data");

        let pool = create_thread_pool_for_test();
        generate_pq_pivots(
            GeneratePivotArguments::new(
                ndata,
                dim,
                num_centers,
                num_chunks,
                max_k_means_reps,
                true,
            )
            .unwrap(),
            &mut train_data,
            &pq_storage,
            &storage_provider,
            diskann_providers::utils::create_rnd_provider_from_seed_in_tests(42),
            &pool,
        )
        .unwrap();

        let compressor = create_new_compressor(
            CompressionStage::Start,
            &storage_provider,
            dim,
            num_chunks,
            max_k_means_reps,
            num_centers,
            1.0, // take all the data to compute the codebook
            &pool,
            pivot_file_name_compressor.to_string(),
            compressed_file_name.to_string(),
            Some(data_path),
        )
        .expect("Compressor creation should succeed");

        assert_eq!(compressor.num_chunks, num_chunks);
        assert_eq!(compressor.compressed_bytes(), num_chunks);

        assert_eq!(compressor.table.dim(), dim);
        assert_eq!(compressor.table.ncenters(), num_centers);
        assert_eq!(compressor.table.nchunks(), num_chunks);

        assert!(storage_provider.exists(pivot_file_name_compressor));
        let (compressor_pivots, cn, cd) =
            load_bin::<u8, _>(&storage_provider, pivot_file_name_compressor, 0).unwrap();
        let (true_pivots, n, d) = load_bin::<u8, _>(&storage_provider, pivot_file_name, 0).unwrap();

        assert_eq!(cn, n);
        assert_eq!(cd, d);
        assert_eq!(compressor_pivots, true_pivots);
    }

    /// Resuming without an existing pivot file must be rejected.
    #[rstest]
    fn throw_error_for_resume_and_no_existing_file() {
        let fs = OverlayFS::new(&[MemoryFS::default().into()]);
        fs.create_dir("/pq_generation_tests")
            .expect("Could not create test directory");
        let storage_provider = VirtualStorageProvider::new(fs);

        let pivot_file_name = "/pq_generation_tests/generate_pq_pivots_test.bin";
        let compressed_file_name = "/pq_generation_tests/compressed_not_used.bin";
        let data_path = "/pq_generation_tests/data_path.bin";

        let (ndata, dim, num_centers, num_chunks, max_k_means_reps) = (5, 8, 2, 2, 5);

        save_bin_f32(
            &mut storage_provider.create_for_write(data_path).unwrap(),
            &VALIDATION_DATA,
            ndata,
            dim,
            0,
        )
        .expect("Could not write test data");
        let pool = create_thread_pool_for_test();

        let compressor = create_new_compressor(
            CompressionStage::Resume,
            &storage_provider,
            dim,
            num_chunks,
            max_k_means_reps,
            num_centers,
            1.0,
            &pool,
            pivot_file_name.to_string(),
            compressed_file_name.to_string(),
            Some(data_path),
        );

        assert!(compressor.is_err());
    }

    /// Compressing the SIFT-small learn set with the stored codebook must reproduce
    /// the stored ground-truth codes.
    #[rstest]
    fn test_pq_end_to_end_with_codebook() {
        let storage_provider = VirtualStorageProvider::new_overlay(test_data_root());

        let pool = create_thread_pool_for_test();
        let dim = 128;
        let num_chunks = 1;
        let max_k_means_reps = 10;

        let compressor = create_new_compressor(
            CompressionStage::Resume,
            &storage_provider,
            dim,
            num_chunks,
            max_k_means_reps,
            256,
            1.0,
            &pool,
            TEST_PQ_PIVOTS_PATH.to_string(),
            "".to_string(),
            None,
        )
        .unwrap_or_else(|err| panic!("Error creating compressor: {err}"));

        let (data, npts, dim) =
            load_bin::<f32, _>(&storage_provider, TEST_PQ_DATA_PATH, 0).unwrap();
        let mut compressed_mat = vec![0_u8; num_chunks * npts];
        compressor
            .compress(
                MatrixView::try_from(&data, npts, dim).unwrap(),
                MutMatrixView::try_from(&mut compressed_mat, npts, num_chunks).unwrap(),
            )
            .expect("Compression should succeed");

        let (compressed_gt, _, _) =
            load_bin::<u8, _>(&storage_provider, TEST_PQ_COMPRESSED_PATH, 0).unwrap();
        assert_eq!(compressed_gt, compressed_mat);
    }

    /// Invalid parameters (num_chunks > dim, num_chunks == 0, num_centers == 0)
    /// must be rejected.
    #[rstest]
    #[case(129, 128, 256)] // num_chunks > dim
    #[case(128, 0, 256)] // num_chunks == 0
    #[case(128, 128, 0)] // num_centers == 0
    fn test_parameter_error_cases(
        #[case] dim: usize,
        #[case] num_chunks: usize,
        #[case] centers: usize,
    ) {
        let storage_provider = VirtualStorageProvider::new_overlay(test_data_root());
        let pool = create_thread_pool_for_test();
        let max_k_means_reps = 10;
        let compressor = create_new_compressor(
            CompressionStage::Start,
            &storage_provider,
            dim,
            num_chunks,
            max_k_means_reps,
            centers,
            1.0,
            &pool,
            TEST_PQ_PIVOTS_PATH.to_string(),
            "".to_string(),
            None,
        );
        assert!(compressor.is_err());
    }
}
429            None,
430        );
431        assert!(compressor.is_err());
432    }
433}