diskann_disk/storage/quant/pq/
pq_generation.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use std::marker::PhantomData;
7
8use diskann::{utils::VectorRepr, ANNError};
9use diskann_providers::storage::{StorageReadProvider, StorageWriteProvider};
10use diskann_providers::{
11    forward_threadpool,
12    model::{
13        pq::{accum_row_inplace, generate_pq_pivots},
14        GeneratePivotArguments,
15    },
16    storage::PQStorage,
17    utils::{AsThreadPool, BridgeErr, Timer},
18};
19use diskann_quantization::{product::TransposedTable, CompressInto};
20use diskann_utils::views::MatrixBase;
21use diskann_vector::distance::Metric;
22use tracing::info;
23
24use crate::storage::quant::compressor::{CompressionStage, QuantCompressor};
25
26pub struct PQGenerationContext<'a, Storage, Pool>
27where
28    Storage: StorageReadProvider + StorageWriteProvider,
29    Pool: AsThreadPool,
30{
31    pub pq_storage: PQStorage,
32    pub num_chunks: usize,
33    pub seed: Option<u64>,
34    pub p_val: f64,
35    pub storage_provider: &'a Storage,
36    pub pool: Pool,
37    pub metric: Metric,
38    pub dim: usize,
39    pub max_kmeans_reps: usize,
40    pub num_centers: usize,
41}
42
43pub struct PQGeneration<'a, T, Storage, Pool>
44where
45    T: VectorRepr,
46    Storage: StorageReadProvider + StorageWriteProvider + 'a,
47    Pool: AsThreadPool,
48{
49    table: TransposedTable,
50    num_chunks: usize,
51    phantom_data: PhantomData<T>,
52    phantom_storage: PhantomData<&'a Storage>,
53    phantom_pool: PhantomData<Pool>,
54}
55
56impl<'a, T, Storage, Pool> QuantCompressor<T> for PQGeneration<'a, T, Storage, Pool>
57where
58    T: VectorRepr,
59    Storage: StorageReadProvider + StorageWriteProvider + 'a,
60    Pool: AsThreadPool,
61{
62    type CompressorContext = PQGenerationContext<'a, Storage, Pool>;
63
64    fn new_at_stage(
65        stage: CompressionStage,
66        context: &Self::CompressorContext,
67    ) -> diskann::ANNResult<Self> {
68        // validate that the number of chunks is correct.
69        if context.num_chunks > context.dim {
70            return Err(ANNError::log_pq_error(
71                "Error: number of chunks more than dimension.",
72            ));
73        }
74
75        let pivots_exists = context
76            .pq_storage
77            .pivot_data_exist(context.storage_provider);
78
79        let pool = &context.pool;
80        forward_threadpool!(pool = pool: Pool);
81
82        if !pivots_exists {
83            if stage == CompressionStage::Resume {
84                //checks for error case when stage is Resume and pivot data doesn't exist.
85                return Err(ANNError::log_pq_error(
86                    "Error: Pivot data does not exist when start_vertex_id is not 0.",
87                ));
88            }
89
90            let timer = Timer::new();
91
92            let rng =
93                diskann_providers::utils::create_rnd_provider_from_optional_seed(context.seed);
94            let (mut train_data, train_size, train_dim) = context
95                .pq_storage
96                .get_random_train_data_slice::<T, Storage>(
97                    context.p_val,
98                    context.storage_provider,
99                    &mut rng.create_rnd(),
100                )?;
101
102            generate_pq_pivots(
103                GeneratePivotArguments::new(
104                    train_size,
105                    train_dim,
106                    context.num_centers,
107                    context.num_chunks,
108                    context.max_kmeans_reps,
109                    context.metric == Metric::L2,
110                )?,
111                &mut train_data,
112                &context.pq_storage,
113                context.storage_provider,
114                rng,
115                pool,
116            )?;
117
118            info!(
119                "PQ pivot generation took {} seconds",
120                timer.elapsed().as_secs_f64()
121            );
122        }
123
124        let (_, full_dim) = context
125            .pq_storage
126            .read_existing_pivot_metadata(context.storage_provider)?;
127
128        //Load the pivots
129        let num_chunks = context.num_chunks;
130        let (mut full_pivot_data, centroid, chunk_offsets, _) =
131            context.pq_storage.load_existing_pivot_data(
132                &num_chunks,
133                &context.num_centers,
134                &full_dim,
135                context.storage_provider,
136                false,
137            )?;
138
139        let mut full_pivot_data_mat = diskann_utils::views::MutMatrixView::try_from(
140            full_pivot_data.as_mut_slice(),
141            context.num_centers,
142            full_dim,
143        )
144        .bridge_err()?;
145
146        accum_row_inplace(full_pivot_data_mat.as_mut_view(), centroid.as_slice());
147
148        let table = TransposedTable::from_parts(
149            full_pivot_data_mat.as_view(),
150            diskann_quantization::views::ChunkOffsetsView::new(&chunk_offsets)
151                .bridge_err()?
152                .to_owned(),
153        )
154        .map_err(|err| ANNError::log_pq_error(diskann_quantization::error::format(&err)))?;
155
156        Ok(Self {
157            table,
158            num_chunks,
159            phantom_data: PhantomData,
160            phantom_pool: PhantomData,
161            phantom_storage: PhantomData,
162        })
163    }
164
165    fn compress(
166        &self,
167        vector: MatrixBase<&[f32]>,
168        output: MatrixBase<&mut [u8]>,
169    ) -> Result<(), diskann::ANNError> {
170        self.table
171            .compress_into(vector, output)
172            .map_err(|err| ANNError::log_pq_error(diskann_quantization::error::format(&err)))
173    }
174
175    fn compressed_bytes(&self) -> usize {
176        self.num_chunks
177    }
178}
179
180//////////////////
181///// Tests /////
182/////////////////
183
#[cfg(test)]
mod pq_generation_tests {
    use diskann::ANNError;
    use diskann_providers::model::pq::generate_pq_pivots;
    use diskann_providers::model::GeneratePivotArguments;
    use diskann_providers::storage::{PQStorage, StorageWriteProvider, VirtualStorageProvider};
    use diskann_providers::utils::{
        create_thread_pool_for_test, file_util::load_bin, save_bin_f32, AsThreadPool,
    };
    use diskann_utils::test_data_root;
    use diskann_utils::views::{MatrixView, MutMatrixView};
    use diskann_vector::distance::Metric;
    use rstest::rstest;
    use vfs::{FileSystem, MemoryFS, OverlayFS, PhysicalFS};

    use super::{CompressionStage, PQGeneration, PQGenerationContext};
    use crate::storage::quant::compressor::QuantCompressor;

    const TEST_PQ_DATA_PATH: &str = "/sift/siftsmall_learn.bin";
    const TEST_PQ_PIVOTS_PATH: &str = "/sift/siftsmall_learn_pq_pivots.bin";
    const TEST_PQ_COMPRESSED_PATH: &str = "/sift/siftsmall_learn_pq_compressed.bin";
    const VALIDATION_DATA: [f32; 40] = [
        //sample validation data: npoints=5, dim=8, 5 vectors [1.0;8] [2.0;8] [2.1;8] [2.2;8] [100.0;8]
        1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 2.0f32, 2.0f32, 2.0f32,
        2.0f32, 2.0f32, 2.0f32, 2.0f32, 2.0f32, 2.1f32, 2.1f32, 2.1f32, 2.1f32, 2.1f32, 2.1f32,
        2.1f32, 2.1f32, 2.2f32, 2.2f32, 2.2f32, 2.2f32, 2.2f32, 2.2f32, 2.2f32, 2.2f32, 100.0f32,
        100.0f32, 100.0f32, 100.0f32, 100.0f32, 100.0f32, 100.0f32, 100.0f32,
    ];

    /// Builds a `PQGenerationContext` (fixed seed 42, `Metric::L2`) from the given
    /// parameters and attempts to construct a compressor at `stage`.
    // The helper deliberately mirrors every knob of the context, hence the long signature.
    #[allow(clippy::too_many_arguments)]
    fn create_new_compressor<'a, R: AsThreadPool>(
        stage: CompressionStage,
        provider: &'a VirtualStorageProvider<OverlayFS>,
        dim: usize,
        num_chunks: usize,
        max_kmeans_reps: usize,
        num_centers: usize,
        p_val: f64,
        pool: R,
        pivots_path: &str,
        compressed_path: &str,
        data_path: Option<&str>,
    ) -> Result<PQGeneration<'a, f32, VirtualStorageProvider<OverlayFS>, R>, ANNError> {
        let pq_storage = PQStorage::new(pivots_path, compressed_path, data_path);
        let context = PQGenerationContext::<'_, _, _> {
            pq_storage,
            num_chunks,
            num_centers,
            seed: Some(42),
            p_val,
            max_kmeans_reps,
            storage_provider: provider,
            pool,
            metric: Metric::L2,
            dim,
        };
        PQGeneration::<_, _, _>::new_at_stage(stage, &context)
    }

    /// Pivots produced through the compressor must match the ones produced by calling
    /// `generate_pq_pivots` directly with the same data and seed.
    #[rstest]
    fn test_create_and_load_pivots_file() {
        let fs = OverlayFS::new(&[MemoryFS::default().into()]);
        fs.create_dir("/pq_generation_tests")
            .expect("Could not create test directory");
        let storage_provider = VirtualStorageProvider::new(fs);

        let pivot_file_name = "/pq_generation_tests/generate_pq_pivots_test.bin";
        let pivot_file_name_compressor = "/pq_generation_tests/compressor_pivots_test.bin";
        let compressed_file_name = "/pq_generation_tests/compressed_not_used.bin";
        let data_path = "/pq_generation_tests/data_path.bin";
        let pq_storage: PQStorage =
            PQStorage::new(pivot_file_name, compressed_file_name, Some(data_path));

        let (ndata, dim, num_centers, num_chunks, max_k_means_reps) = (5, 8, 2, 2, 5);
        let mut train_data: Vec<f32> = VALIDATION_DATA.to_vec();

        // A failed fixture write must fail the test here rather than surface later as a
        // confusing sampling error.
        save_bin_f32(
            &mut storage_provider.create_for_write(data_path).unwrap(),
            &train_data,
            ndata,
            dim,
            0,
        )
        .expect("Could not write training data");

        // Reference pivots, generated directly.
        let pool = create_thread_pool_for_test();
        generate_pq_pivots(
            GeneratePivotArguments::new(
                ndata,
                dim,
                num_centers,
                num_chunks,
                max_k_means_reps,
                true,
            )
            .unwrap(),
            &mut train_data,
            &pq_storage,
            &storage_provider,
            diskann_providers::utils::create_rnd_provider_from_seed_in_tests(42),
            &pool,
        )
        .unwrap();

        // Pivots generated through the compressor, written to a different file.
        let compressor = create_new_compressor(
            CompressionStage::Start,
            &storage_provider,
            dim,
            num_chunks,
            max_k_means_reps,
            num_centers,
            1.0, //take all the data to compute codebook
            &pool,
            pivot_file_name_compressor,
            compressed_file_name,
            Some(data_path),
        )
        .expect("compressor creation should succeed");

        assert_eq!(compressor.num_chunks, num_chunks);
        assert_eq!(compressor.compressed_bytes(), num_chunks);

        assert_eq!(compressor.table.dim(), dim);
        assert_eq!(compressor.table.ncenters(), num_centers);
        assert_eq!(compressor.table.nchunks(), num_chunks);

        assert!(storage_provider.exists(pivot_file_name_compressor));
        let (compressor_pivots, cn, cd) =
            load_bin::<u8, _>(&storage_provider, pivot_file_name_compressor, 0).unwrap();
        let (true_pivots, n, d) = load_bin::<u8, _>(&storage_provider, pivot_file_name, 0).unwrap();

        assert_eq!(cn, n);
        assert_eq!(cd, d);
        assert_eq!(compressor_pivots, true_pivots);
    }

    /// Resuming without previously generated pivots must be reported as an error.
    #[rstest]
    fn throw_error_for_resume_and_no_existing_file() {
        let fs = OverlayFS::new(&[MemoryFS::default().into()]);
        fs.create_dir("/pq_generation_tests")
            .expect("Could not create test directory");
        let storage_provider = VirtualStorageProvider::new(fs);

        let pivot_file_name = "/pq_generation_tests/generate_pq_pivots_test.bin";
        let compressed_file_name = "/pq_generation_tests/compressed_not_used.bin";
        let data_path = "/pq_generation_tests/data_path.bin";

        let (ndata, dim, num_centers, num_chunks, max_k_means_reps) = (5, 8, 2, 2, 5);

        // The data file exists, so the only thing missing for a resume is the pivot file.
        save_bin_f32(
            &mut storage_provider.create_for_write(data_path).unwrap(),
            &VALIDATION_DATA,
            ndata,
            dim,
            0,
        )
        .expect("Could not write training data");
        let pool = create_thread_pool_for_test();

        let compressor = create_new_compressor(
            CompressionStage::Resume,
            &storage_provider,
            dim,
            num_chunks,
            max_k_means_reps,
            num_centers,
            1.0,
            &pool,
            pivot_file_name,
            compressed_file_name,
            Some(data_path),
        );

        assert!(compressor.is_err());
    }

    /// Loads the checked-in SIFT pivots and verifies that compressing the SIFT learn set
    /// reproduces the checked-in compressed file.
    #[rstest]
    fn test_pq_end_to_end_with_codebook() {
        let storage_provider = VirtualStorageProvider::new(OverlayFS::new(&[
            MemoryFS::default().into(),
            PhysicalFS::new(test_data_root()).into(),
        ]));

        let pool = create_thread_pool_for_test();
        let dim = 128;
        let num_chunks = 1;
        let max_k_means_reps = 10;

        let compressor = create_new_compressor(
            CompressionStage::Resume,
            &storage_provider,
            dim,
            num_chunks,
            max_k_means_reps,
            256,
            1.0,
            &pool,
            TEST_PQ_PIVOTS_PATH,
            "",
            None,
        )
        .unwrap_or_else(|err| panic!("Error creating compressor: {err}"));

        let (data, npts, dim) =
            load_bin::<f32, _>(&storage_provider, TEST_PQ_DATA_PATH, 0).unwrap();
        let mut compressed_mat = vec![0_u8; num_chunks * npts];
        compressor
            .compress(
                MatrixView::try_from(&data, npts, dim).unwrap(),
                MutMatrixView::try_from(&mut compressed_mat, npts, num_chunks).unwrap(),
            )
            .expect("PQ compression should succeed");

        let (compressed_gt, _, _) =
            load_bin::<u8, _>(&storage_provider, TEST_PQ_COMPRESSED_PATH, 0).unwrap();
        assert_eq!(compressed_gt, compressed_mat);
    }

    /// Invalid parameter combinations must make compressor construction fail.
    #[rstest]
    #[case(129, 128, 256)] // num_chunks > dim
    #[case(128, 0, 256)] // num_chunks == 0
    #[case(128, 128, 0)] // num_centers == 0
    fn test_parameter_error_cases(
        #[case] dim: usize,
        #[case] num_chunks: usize,
        #[case] centers: usize,
    ) {
        // NOTE(review): this uses the relative root "tests/data/" while the end-to-end test
        // uses `test_data_root()`; confirm which root is intended before unifying them.
        let storage_provider = VirtualStorageProvider::new(OverlayFS::new(&[
            MemoryFS::default().into(),
            PhysicalFS::new("tests/data/").into(),
        ]));
        let pool = create_thread_pool_for_test();
        let max_k_means_reps = 10;
        let compressor = create_new_compressor(
            CompressionStage::Start,
            &storage_provider,
            dim,
            num_chunks,
            max_k_means_reps,
            centers,
            1.0,
            &pool,
            TEST_PQ_PIVOTS_PATH,
            "",
            None,
        );
        assert!(compressor.is_err());
    }
}