// diskann_disk/storage/quant/pq/pq_generation.rs

/*
 * Copyright (c) Microsoft Corporation.
 * Licensed under the MIT license.
 */

use std::marker::PhantomData;

use diskann::{utils::VectorRepr, ANNError};
use diskann_providers::storage::{StorageReadProvider, StorageWriteProvider};
use diskann_providers::{
    forward_threadpool,
    model::{
        pq::{accum_row_inplace, generate_pq_pivots},
        GeneratePivotArguments,
    },
    storage::PQStorage,
    utils::{AsThreadPool, BridgeErr, Timer},
};
use diskann_quantization::{product::TransposedTable, CompressInto};
use diskann_utils::views::MatrixBase;
use diskann_vector::distance::Metric;
use tracing::info;

use crate::storage::quant::compressor::{CompressionStage, QuantCompressor};

26pub struct PQGenerationContext<'a, Storage, Pool>
27where
28    Storage: StorageReadProvider + StorageWriteProvider,
29    Pool: AsThreadPool,
30{
31    pub pq_storage: PQStorage,
32    pub num_chunks: usize,
33    pub seed: Option<u64>,
34    pub p_val: f64,
35    pub storage_provider: &'a Storage,
36    pub pool: Pool,
37    pub metric: Metric,
38    pub dim: usize,
39    pub max_kmeans_reps: usize,
40    pub num_centers: usize,
41}
42
43pub struct PQGeneration<'a, T, Storage, Pool>
44where
45    T: VectorRepr,
46    Storage: StorageReadProvider + StorageWriteProvider + 'a,
47    Pool: AsThreadPool,
48{
49    table: TransposedTable,
50    num_chunks: usize,
51    phantom_data: PhantomData<T>,
52    phantom_storage: PhantomData<&'a Storage>,
53    phantom_pool: PhantomData<Pool>,
54}
55
56impl<'a, T, Storage, Pool> QuantCompressor<T> for PQGeneration<'a, T, Storage, Pool>
57where
58    T: VectorRepr,
59    Storage: StorageReadProvider + StorageWriteProvider + 'a,
60    Pool: AsThreadPool,
61{
62    type CompressorContext = PQGenerationContext<'a, Storage, Pool>;
63
64    fn new_at_stage(
65        stage: CompressionStage,
66        context: &Self::CompressorContext,
67    ) -> diskann::ANNResult<Self> {
68        // validate that the number of chunks is correct.
69        if context.num_chunks > context.dim {
70            return Err(ANNError::log_pq_error(
71                "Error: number of chunks more than dimension.",
72            ));
73        }
74
75        let pivots_exists = context
76            .pq_storage
77            .pivot_data_exist(context.storage_provider);
78
79        let pool = &context.pool;
80        forward_threadpool!(pool = pool: Pool);
81
82        if !pivots_exists {
83            if stage == CompressionStage::Resume {
84                //checks for error case when stage is Resume and pivot data doesn't exist.
85                return Err(ANNError::log_pq_error(
86                    "Error: Pivot data does not exist when start_vertex_id is not 0.",
87                ));
88            }
89
90            let timer = Timer::new();
91
92            let rng =
93                diskann_providers::utils::create_rnd_provider_from_optional_seed(context.seed);
94            let (mut train_data, train_size, train_dim) = context
95                .pq_storage
96                .get_random_train_data_slice::<T, Storage>(
97                    context.p_val,
98                    context.storage_provider,
99                    &mut rng.create_rnd(),
100                )?;
101
102            generate_pq_pivots(
103                GeneratePivotArguments::new(
104                    train_size,
105                    train_dim,
106                    context.num_centers,
107                    context.num_chunks,
108                    context.max_kmeans_reps,
109                    context.metric == Metric::L2,
110                )?,
111                &mut train_data,
112                &context.pq_storage,
113                context.storage_provider,
114                rng,
115                pool,
116            )?;
117
118            info!(
119                "PQ pivot generation took {} seconds",
120                timer.elapsed().as_secs_f64()
121            );
122        }
123
124        let (_, full_dim) = context
125            .pq_storage
126            .read_existing_pivot_metadata(context.storage_provider)?;
127
128        //Load the pivots
129        let num_chunks = context.num_chunks;
130        let (mut full_pivot_data, centroid, chunk_offsets, _) =
131            context.pq_storage.load_existing_pivot_data(
132                &num_chunks,
133                &context.num_centers,
134                &full_dim,
135                context.storage_provider,
136                false,
137            )?;
138
139        let mut full_pivot_data_mat = diskann_utils::views::MutMatrixView::try_from(
140            full_pivot_data.as_mut_slice(),
141            context.num_centers,
142            full_dim,
143        )
144        .bridge_err()?;
145
146        accum_row_inplace(full_pivot_data_mat.as_mut_view(), centroid.as_slice());
147
148        let table = TransposedTable::from_parts(
149            full_pivot_data_mat.as_view(),
150            diskann_quantization::views::ChunkOffsetsView::new(&chunk_offsets)
151                .bridge_err()?
152                .to_owned(),
153        )
154        .map_err(|err| ANNError::log_pq_error(diskann_quantization::error::format(&err)))?;
155
156        Ok(Self {
157            table,
158            num_chunks,
159            phantom_data: PhantomData,
160            phantom_pool: PhantomData,
161            phantom_storage: PhantomData,
162        })
163    }
164
165    fn compress(
166        &self,
167        vector: MatrixBase<&[f32]>,
168        output: MatrixBase<&mut [u8]>,
169    ) -> Result<(), diskann::ANNError> {
170        self.table
171            .compress_into(vector, output)
172            .map_err(|err| ANNError::log_pq_error(diskann_quantization::error::format(&err)))
173    }
174
175    fn compressed_bytes(&self) -> usize {
176        self.num_chunks
177    }
178}
179
//////////////////
///// Tests /////
/////////////////

#[cfg(test)]
mod pq_generation_tests {
    use diskann::ANNError;
    use diskann_providers::model::pq::generate_pq_pivots;
    use diskann_providers::model::GeneratePivotArguments;
    use diskann_providers::storage::{
        PQStorage, StorageReadProvider, StorageWriteProvider, VirtualStorageProvider,
    };
    use diskann_providers::utils::{create_thread_pool_for_test, AsThreadPool};
    use diskann_utils::{
        io::{read_bin, write_bin},
        test_data_root,
        views::{MatrixView, MutMatrixView},
    };
    use diskann_vector::distance::Metric;
    use rstest::rstest;
    use vfs::FileSystem;

    use super::{CompressionStage, PQGeneration, PQGenerationContext};
    use crate::storage::quant::compressor::QuantCompressor;

    const TEST_PQ_DATA_PATH: &str = "/sift/siftsmall_learn.bin";
    const TEST_PQ_PIVOTS_PATH: &str = "/sift/siftsmall_learn_pq_pivots.bin";
    const TEST_PQ_COMPRESSED_PATH: &str = "/sift/siftsmall_learn_pq_compressed.bin";
    const VALIDATION_DATA: [f32; 40] = [
        // Sample validation data: npoints=5, dim=8, 5 vectors [1.0;8] [2.0;8] [2.1;8] [2.2;8] [100.0;8]
        1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 2.0f32, 2.0f32, 2.0f32,
        2.0f32, 2.0f32, 2.0f32, 2.0f32, 2.0f32, 2.1f32, 2.1f32, 2.1f32, 2.1f32, 2.1f32, 2.1f32,
        2.1f32, 2.1f32, 2.2f32, 2.2f32, 2.2f32, 2.2f32, 2.2f32, 2.2f32, 2.2f32, 2.2f32, 100.0f32,
        100.0f32, 100.0f32, 100.0f32, 100.0f32, 100.0f32, 100.0f32, 100.0f32,
    ];

    /// Convenience wrapper that assembles a `PQGenerationContext` and calls
    /// `PQGeneration::new_at_stage`.
    #[allow(clippy::too_many_arguments)]
    fn create_new_compressor<'a, R: AsThreadPool, F: vfs::FileSystem>(
        stage: CompressionStage,
        provider: &'a VirtualStorageProvider<F>,
        dim: usize,
        num_chunks: usize,
        max_kmeans_reps: usize,
        num_centers: usize,
        p_val: f64,
        pool: R,
        pivots_path: String,
        compressed_path: String,
        data_path: Option<&str>,
    ) -> Result<PQGeneration<'a, f32, VirtualStorageProvider<F>, R>, ANNError> {
        let pq_storage = PQStorage::new(&pivots_path, &compressed_path, data_path);
        let context = PQGenerationContext::<'_, _, _> {
            pq_storage,
            num_chunks,
            num_centers,
            seed: Some(42),
            p_val,
            max_kmeans_reps,
            storage_provider: provider,
            pool,
            metric: Metric::L2,
            dim,
        };
        PQGeneration::<_, _, _>::new_at_stage(stage, &context)
    }

    /// Pivots trained through the compressor must match pivots produced by
    /// calling `generate_pq_pivots` directly with the same seed.
    #[rstest]
    fn test_create_and_load_pivots_file() {
        let storage_provider = VirtualStorageProvider::new_memory();
        storage_provider
            .filesystem()
            .create_dir("/pq_generation_tests")
            .expect("Could not create test directory");

        let pivot_file_name = "/pq_generation_tests/generate_pq_pivots_test.bin";
        let pivot_file_name_compressor = "/pq_generation_tests/compressor_pivots_test.bin";
        let compressed_file_name = "/pq_generation_tests/compressed_not_used.bin";
        let data_path = "/pq_generation_tests/data_path.bin";
        let pq_storage: PQStorage =
            PQStorage::new(pivot_file_name, compressed_file_name, Some(data_path));

        let (ndata, dim, num_centers, num_chunks, max_k_means_reps) = (5, 8, 2, 2, 5);
        let mut train_data: Vec<f32> = VALIDATION_DATA.to_vec();

        write_bin(
            MatrixView::try_from(train_data.as_slice(), ndata, dim).unwrap(),
            &mut storage_provider.create_for_write(data_path).unwrap(),
        )
        .unwrap();

        let pool = create_thread_pool_for_test();
        // Reference pivots generated directly, outside the compressor.
        generate_pq_pivots(
            GeneratePivotArguments::new(
                ndata,
                dim,
                num_centers,
                num_chunks,
                max_k_means_reps,
                true,
            )
            .unwrap(),
            &mut train_data,
            &pq_storage,
            &storage_provider,
            diskann_providers::utils::create_rnd_provider_from_seed_in_tests(42),
            &pool,
        )
        .unwrap();

        let compressor = create_new_compressor(
            CompressionStage::Start,
            &storage_provider,
            dim,
            num_chunks,
            max_k_means_reps,
            num_centers,
            1.0, //take all the data to compute codebook
            &pool,
            pivot_file_name_compressor.to_string(),
            compressed_file_name.to_string(),
            Some(data_path),
        );

        assert!(compressor.is_ok());

        let compressor = compressor.unwrap();
        assert_eq!(compressor.num_chunks, num_chunks);
        assert_eq!(compressor.compressed_bytes(), num_chunks);

        assert_eq!(compressor.table.dim(), dim);
        assert_eq!(compressor.table.ncenters(), num_centers);
        assert_eq!(compressor.table.nchunks(), num_chunks);

        // The two pivot files must be byte-identical.
        assert!(&storage_provider.exists(pivot_file_name_compressor));
        let compressor_pivots = read_bin::<u8>(
            &mut storage_provider
                .open_reader(pivot_file_name_compressor)
                .unwrap(),
        )
        .unwrap();
        let true_pivots =
            read_bin::<u8>(&mut storage_provider.open_reader(pivot_file_name).unwrap()).unwrap();
        assert_eq!(compressor_pivots, true_pivots);
    }

    /// `Resume` without an existing pivot file must fail.
    #[rstest]
    fn throw_error_for_resume_and_no_existing_file() {
        let storage_provider = VirtualStorageProvider::new_memory();
        storage_provider
            .filesystem()
            .create_dir("/pq_generation_tests")
            .expect("Could not create test directory");

        let pivot_file_name = "/pq_generation_tests/generate_pq_pivots_test.bin";
        let compressed_file_name = "/pq_generation_tests/compressed_not_used.bin";
        let data_path = "/pq_generation_tests/data_path.bin";

        let (ndata, dim, num_centers, num_chunks, max_k_means_reps) = (5, 8, 2, 2, 5);

        write_bin(
            MatrixView::try_from(VALIDATION_DATA.as_slice(), ndata, dim).unwrap(),
            &mut storage_provider.create_for_write(data_path).unwrap(),
        )
        .unwrap();
        let pool = create_thread_pool_for_test();

        let compressor = create_new_compressor(
            CompressionStage::Resume,
            &storage_provider,
            dim,
            num_chunks,
            max_k_means_reps,
            num_centers,
            1.0,
            &pool,
            pivot_file_name.to_string(),
            compressed_file_name.to_string(),
            Some(data_path),
        );

        assert!(compressor.is_err());
    }

    /// Compressing the SIFT sample with a pre-existing codebook must
    /// reproduce the checked-in compressed ground truth exactly.
    #[rstest]
    fn test_pq_end_to_end_with_codebook() {
        let storage_provider = VirtualStorageProvider::new_overlay(test_data_root());

        let pool = create_thread_pool_for_test();
        let dim = 128;
        let num_chunks = 1;
        let max_k_means_reps = 10;

        let compressor = create_new_compressor(
            CompressionStage::Resume,
            &storage_provider,
            dim,
            num_chunks,
            max_k_means_reps,
            256,
            1.0,
            &pool,
            TEST_PQ_PIVOTS_PATH.to_string(),
            "".to_string(),
            None,
        );

        if let Err(x) = compressor.as_ref() {
            println!("Error creating compressor: {x}");
        };

        assert!(compressor.is_ok());

        let data_matrix =
            read_bin::<f32>(&mut storage_provider.open_reader(TEST_PQ_DATA_PATH).unwrap()).unwrap();
        let npts = data_matrix.nrows();
        let mut compressed_mat = vec![0_u8; num_chunks * npts];
        let result = compressor.unwrap().compress(
            data_matrix.as_view(),
            MutMatrixView::try_from(&mut compressed_mat, npts, num_chunks).unwrap(),
        );
        assert!(result.is_ok());

        let compressed_gt = read_bin::<u8>(
            &mut storage_provider
                .open_reader(TEST_PQ_COMPRESSED_PATH)
                .unwrap(),
        )
        .unwrap();
        assert_eq!(compressed_gt.as_slice(), &compressed_mat);
    }

    /// Parameter-validation error cases: num_chunks > dim, num_chunks == 0,
    /// num_centers == 0.
    #[rstest]
    #[case(129, 128, 256)] // num_chunks > dim
    #[case(128, 0, 256)] // num_chunks == 0
    #[case(128, 128, 0)] // num_centers == 0
    fn test_parameter_error_cases(
        #[case] dim: usize,
        #[case] num_chunks: usize,
        #[case] centers: usize,
    ) {
        let storage_provider = VirtualStorageProvider::new_overlay(test_data_root());
        let pool = create_thread_pool_for_test();
        let max_k_means_reps = 10;
        let compressor = create_new_compressor(
            CompressionStage::Start,
            &storage_provider,
            dim,
            num_chunks,
            max_k_means_reps,
            centers,
            1.0,
            &pool,
            TEST_PQ_PIVOTS_PATH.to_string(),
            "".to_string(),
            None,
        );
        assert!(compressor.is_err());
    }
}