Skip to main content

diskann_disk/storage/quant/pq/
pq_generation.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use std::marker::PhantomData;
7
8use diskann::{utils::VectorRepr, ANNError};
9use diskann_providers::storage::{StorageReadProvider, StorageWriteProvider};
10use diskann_providers::{
11    forward_threadpool,
12    model::{
13        pq::{accum_row_inplace, generate_pq_pivots},
14        GeneratePivotArguments,
15    },
16    storage::PQStorage,
17    utils::{AsThreadPool, BridgeErr, Timer},
18};
19use diskann_quantization::{product::TransposedTable, CompressInto};
20use diskann_utils::views::MatrixBase;
21use diskann_vector::distance::Metric;
22use tracing::info;
23
24use crate::storage::quant::compressor::{CompressionStage, QuantCompressor};
25
26pub struct PQGenerationContext<'a, Storage, Pool>
27where
28    Storage: StorageReadProvider + StorageWriteProvider,
29    Pool: AsThreadPool,
30{
31    pub pq_storage: PQStorage,
32    pub num_chunks: usize,
33    pub seed: Option<u64>,
34    pub p_val: f64,
35    pub storage_provider: &'a Storage,
36    pub pool: Pool,
37    pub metric: Metric,
38    pub dim: usize,
39    pub max_kmeans_reps: usize,
40    pub num_centers: usize,
41}
42
43pub struct PQGeneration<'a, T, Storage, Pool>
44where
45    T: VectorRepr,
46    Storage: StorageReadProvider + StorageWriteProvider + 'a,
47    Pool: AsThreadPool,
48{
49    table: TransposedTable,
50    num_chunks: usize,
51    phantom_data: PhantomData<T>,
52    phantom_storage: PhantomData<&'a Storage>,
53    phantom_pool: PhantomData<Pool>,
54}
55
56impl<'a, T, Storage, Pool> QuantCompressor<T> for PQGeneration<'a, T, Storage, Pool>
57where
58    T: VectorRepr,
59    Storage: StorageReadProvider + StorageWriteProvider + 'a,
60    Pool: AsThreadPool,
61{
62    type CompressorContext = PQGenerationContext<'a, Storage, Pool>;
63
64    fn new_at_stage(
65        stage: CompressionStage,
66        context: &Self::CompressorContext,
67    ) -> diskann::ANNResult<Self> {
68        // validate that the number of chunks is correct.
69        if context.num_chunks > context.dim {
70            return Err(ANNError::log_pq_error(
71                "Error: number of chunks more than dimension.",
72            ));
73        }
74
75        let pivots_exists = context
76            .pq_storage
77            .pivot_data_exist(context.storage_provider);
78
79        let pool = &context.pool;
80        forward_threadpool!(pool = pool: Pool);
81
82        if !pivots_exists {
83            if stage == CompressionStage::Resume {
84                //checks for error case when stage is Resume and pivot data doesn't exist.
85                return Err(ANNError::log_pq_error(
86                    "Error: Pivot data does not exist when start_vertex_id is not 0.",
87                ));
88            }
89
90            let timer = Timer::new();
91
92            let rng =
93                diskann_providers::utils::create_rnd_provider_from_optional_seed(context.seed);
94            let (mut train_data, train_size, train_dim) = context
95                .pq_storage
96                .get_random_train_data_slice::<T, Storage>(
97                    context.p_val,
98                    context.storage_provider,
99                    &mut rng.create_rnd(),
100                )?;
101
102            generate_pq_pivots(
103                GeneratePivotArguments::new(
104                    train_size,
105                    train_dim,
106                    context.num_centers,
107                    context.num_chunks,
108                    context.max_kmeans_reps,
109                    context.metric == Metric::L2,
110                )?,
111                &mut train_data,
112                &context.pq_storage,
113                context.storage_provider,
114                rng,
115                pool,
116            )?;
117
118            info!(
119                "PQ pivot generation took {} seconds",
120                timer.elapsed().as_secs_f64()
121            );
122        }
123
124        let (_, full_dim) = context
125            .pq_storage
126            .read_existing_pivot_metadata(context.storage_provider)?;
127
128        //Load the pivots
129        let num_chunks = context.num_chunks;
130        let (mut full_pivot_data, centroid, chunk_offsets, _) =
131            context.pq_storage.load_existing_pivot_data(
132                &num_chunks,
133                &context.num_centers,
134                &full_dim,
135                context.storage_provider,
136                false,
137            )?;
138
139        let mut full_pivot_data_mat = diskann_utils::views::MutMatrixView::try_from(
140            full_pivot_data.as_mut_slice(),
141            context.num_centers,
142            full_dim,
143        )
144        .bridge_err()?;
145
146        accum_row_inplace(full_pivot_data_mat.as_mut_view(), centroid.as_slice());
147
148        let table = TransposedTable::from_parts(
149            full_pivot_data_mat.as_view(),
150            diskann_quantization::views::ChunkOffsetsView::new(&chunk_offsets)
151                .bridge_err()?
152                .to_owned(),
153        )
154        .map_err(|err| ANNError::log_pq_error(diskann_quantization::error::format(&err)))?;
155
156        Ok(Self {
157            table,
158            num_chunks,
159            phantom_data: PhantomData,
160            phantom_pool: PhantomData,
161            phantom_storage: PhantomData,
162        })
163    }
164
165    fn compress(
166        &self,
167        vector: MatrixBase<&[f32]>,
168        output: MatrixBase<&mut [u8]>,
169    ) -> Result<(), diskann::ANNError> {
170        self.table
171            .compress_into(vector, output)
172            .map_err(|err| ANNError::log_pq_error(diskann_quantization::error::format(&err)))
173    }
174
175    fn compressed_bytes(&self) -> usize {
176        self.num_chunks
177    }
178}
179
180//////////////////
181///// Tests /////
182/////////////////
183
#[cfg(test)]
mod pq_generation_tests {
    use diskann::ANNError;
    use diskann_providers::model::pq::generate_pq_pivots;
    use diskann_providers::model::GeneratePivotArguments;
    use diskann_providers::storage::{PQStorage, StorageWriteProvider, VirtualStorageProvider};
    use diskann_providers::utils::{
        create_thread_pool_for_test, file_util::load_bin, save_bin_f32, AsThreadPool,
    };
    use diskann_utils::test_data_root;
    use diskann_utils::views::{MatrixView, MutMatrixView};
    use diskann_vector::distance::Metric;
    use rstest::rstest;
    use vfs::FileSystem;

    use super::{CompressionStage, PQGeneration, PQGenerationContext};
    use crate::storage::quant::compressor::QuantCompressor;

    const TEST_PQ_DATA_PATH: &str = "/sift/siftsmall_learn.bin";
    const TEST_PQ_PIVOTS_PATH: &str = "/sift/siftsmall_learn_pq_pivots.bin";
    const TEST_PQ_COMPRESSED_PATH: &str = "/sift/siftsmall_learn_pq_compressed.bin";

    /// Sample validation data: npoints = 5, dim = 8, vectors
    /// `[1.0; 8]`, `[2.0; 8]`, `[2.1; 8]`, `[2.2; 8]`, `[100.0; 8]`.
    const VALIDATION_DATA: [f32; 40] = [
        1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 2.0f32, 2.0f32, 2.0f32,
        2.0f32, 2.0f32, 2.0f32, 2.0f32, 2.0f32, 2.1f32, 2.1f32, 2.1f32, 2.1f32, 2.1f32, 2.1f32,
        2.1f32, 2.1f32, 2.2f32, 2.2f32, 2.2f32, 2.2f32, 2.2f32, 2.2f32, 2.2f32, 2.2f32, 100.0f32,
        100.0f32, 100.0f32, 100.0f32, 100.0f32, 100.0f32, 100.0f32, 100.0f32,
    ];

    /// Builds a [`PQGeneration`] over `provider` with a fixed seed (42) and the L2 metric.
    // The helper exposes every `PQGenerationContext` knob the tests vary, so it
    // necessarily takes many arguments.
    #[allow(clippy::too_many_arguments)]
    fn create_new_compressor<'a, R: AsThreadPool, F: vfs::FileSystem>(
        stage: CompressionStage,
        provider: &'a VirtualStorageProvider<F>,
        dim: usize,
        num_chunks: usize,
        max_kmeans_reps: usize,
        num_centers: usize,
        p_val: f64,
        pool: R,
        pivots_path: String,
        compressed_path: String,
        data_path: Option<&str>,
    ) -> Result<PQGeneration<'a, f32, VirtualStorageProvider<F>, R>, ANNError> {
        let pq_storage = PQStorage::new(&pivots_path, &compressed_path, data_path);
        let context = PQGenerationContext::<'_, _, _> {
            pq_storage,
            num_chunks,
            num_centers,
            seed: Some(42),
            p_val,
            max_kmeans_reps,
            storage_provider: provider,
            pool,
            metric: Metric::L2,
            dim,
        };
        PQGeneration::<_, _, _>::new_at_stage(stage, &context)
    }

    #[rstest]
    fn test_create_and_load_pivots_file() {
        let storage_provider = VirtualStorageProvider::new_memory();
        storage_provider
            .filesystem()
            .create_dir("/pq_generation_tests")
            .expect("Could not create test directory");

        let pivot_file_name = "/pq_generation_tests/generate_pq_pivots_test.bin";
        let pivot_file_name_compressor = "/pq_generation_tests/compressor_pivots_test.bin";
        let compressed_file_name = "/pq_generation_tests/compressed_not_used.bin";
        let data_path = "/pq_generation_tests/data_path.bin";
        let pq_storage: PQStorage =
            PQStorage::new(pivot_file_name, compressed_file_name, Some(data_path));

        let (ndata, dim, num_centers, num_chunks, max_k_means_reps) = (5, 8, 2, 2, 5);
        let mut train_data: Vec<f32> = VALIDATION_DATA.to_vec();

        // A failed write would make every later assertion meaningless, so fail loudly here.
        save_bin_f32(
            &mut storage_provider.create_for_write(data_path).unwrap(),
            &train_data,
            ndata,
            dim,
            0,
        )
        .expect("Could not write test data");

        let pool = create_thread_pool_for_test();
        generate_pq_pivots(
            GeneratePivotArguments::new(
                ndata,
                dim,
                num_centers,
                num_chunks,
                max_k_means_reps,
                true,
            )
            .unwrap(),
            &mut train_data,
            &pq_storage,
            &storage_provider,
            diskann_providers::utils::create_rnd_provider_from_seed_in_tests(42),
            &pool,
        )
        .unwrap();

        let compressor = create_new_compressor(
            CompressionStage::Start,
            &storage_provider,
            dim,
            num_chunks,
            max_k_means_reps,
            num_centers,
            1.0, // take all the data to compute the codebook
            &pool,
            pivot_file_name_compressor.to_string(),
            compressed_file_name.to_string(),
            Some(data_path),
        )
        .expect("Could not create compressor at Start stage");

        assert_eq!(compressor.num_chunks, num_chunks);
        assert_eq!(compressor.compressed_bytes(), num_chunks);

        assert_eq!(compressor.table.dim(), dim);
        assert_eq!(compressor.table.ncenters(), num_centers);
        assert_eq!(compressor.table.nchunks(), num_chunks);

        assert!(storage_provider.exists(pivot_file_name_compressor));
        let (compressor_pivots, cn, cd) =
            load_bin::<u8, _>(&storage_provider, pivot_file_name_compressor, 0).unwrap();
        let (true_pivots, n, d) = load_bin::<u8, _>(&storage_provider, pivot_file_name, 0).unwrap();

        assert_eq!(cn, n);
        assert_eq!(cd, d);
        assert_eq!(compressor_pivots, true_pivots);
    }

    #[rstest]
    fn throw_error_for_resume_and_no_existing_file() {
        let storage_provider = VirtualStorageProvider::new_memory();
        storage_provider
            .filesystem()
            .create_dir("/pq_generation_tests")
            .expect("Could not create test directory");

        let pivot_file_name = "/pq_generation_tests/generate_pq_pivots_test.bin";
        let compressed_file_name = "/pq_generation_tests/compressed_not_used.bin";
        let data_path = "/pq_generation_tests/data_path.bin";

        let (ndata, dim, num_centers, num_chunks, max_k_means_reps) = (5, 8, 2, 2, 5);

        save_bin_f32(
            &mut storage_provider.create_for_write(data_path).unwrap(),
            &VALIDATION_DATA,
            ndata,
            dim,
            0,
        )
        .expect("Could not write test data");
        let pool = create_thread_pool_for_test();

        let compressor = create_new_compressor(
            CompressionStage::Resume,
            &storage_provider,
            dim,
            num_chunks,
            max_k_means_reps,
            num_centers,
            1.0,
            &pool,
            pivot_file_name.to_string(),
            compressed_file_name.to_string(),
            Some(data_path),
        );

        assert!(compressor.is_err());
    }

    #[rstest]
    fn test_pq_end_to_end_with_codebook() {
        let storage_provider = VirtualStorageProvider::new_overlay(test_data_root());

        let pool = create_thread_pool_for_test();
        let dim = 128;
        let num_chunks = 1;
        let max_k_means_reps = 10;

        let compressor = create_new_compressor(
            CompressionStage::Resume,
            &storage_provider,
            dim,
            num_chunks,
            max_k_means_reps,
            256,
            1.0,
            &pool,
            TEST_PQ_PIVOTS_PATH.to_string(),
            "".to_string(),
            None,
        )
        .unwrap_or_else(|err| panic!("Error creating compressor: {err}"));

        let (data, npts, dim) =
            load_bin::<f32, _>(&storage_provider, TEST_PQ_DATA_PATH, 0).unwrap();
        let mut compressed_mat = vec![0_u8; num_chunks * npts];
        compressor
            .compress(
                MatrixView::try_from(&data, npts, dim).unwrap(),
                MutMatrixView::try_from(&mut compressed_mat, npts, num_chunks).unwrap(),
            )
            .expect("Compression failed");

        let (compressed_gt, _, _) =
            load_bin::<u8, _>(&storage_provider, TEST_PQ_COMPRESSED_PATH, 0).unwrap();
        assert_eq!(compressed_gt, compressed_mat);
    }

    // Case order is (dim, num_chunks, centers).
    #[rstest]
    #[case(128, 129, 256)] // num_chunks > dim
    #[case(128, 0, 256)] // num_chunks == 0
    #[case(128, 128, 0)] // num_centers == 0
    fn test_parameter_error_cases(
        #[case] dim: usize,
        #[case] num_chunks: usize,
        #[case] centers: usize,
    ) {
        // Each case must be rejected: num_chunks > dim, num_chunks == 0, num_centers == 0.
        let storage_provider = VirtualStorageProvider::new_overlay(test_data_root());
        let pool = create_thread_pool_for_test();
        let max_k_means_reps = 10;
        let compressor = create_new_compressor(
            CompressionStage::Start,
            &storage_provider,
            dim,
            num_chunks,
            max_k_means_reps,
            centers,
            1.0,
            &pool,
            TEST_PQ_PIVOTS_PATH.to_string(),
            "".to_string(),
            None,
        );
        assert!(compressor.is_err());
    }
}