// diskann_disk/storage/quant/generator.rs
/*
 * Copyright (c) Microsoft Corporation.
 * Licensed under the MIT license.
 */
5
6use std::{
7    io::{Seek, SeekFrom, Write},
8    marker::PhantomData,
9};
10
11use diskann::{error::IntoANNResult, utils::VectorRepr, ANNError, ANNResult};
12use diskann_providers::storage::{StorageReadProvider, StorageWriteProvider};
13use diskann_providers::{
14    forward_threadpool,
15    utils::{load_metadata_from_file, AsThreadPool, BridgeErr, ParallelIteratorInPool, Timer},
16};
17use diskann_utils::{io::Metadata, views};
18use rayon::iter::IndexedParallelIterator;
19use tracing::info;
20
21use crate::{
22    build::chunking::{
23        checkpoint::Progress,
24        continuation::{process_while_resource_is_available, ChunkingConfig},
25    },
26    storage::quant::compressor::{CompressionStage, QuantCompressor},
27};
28
/// [`GeneratorContext`] carries the checkpoint state for vector quantization.
///
/// It records where in the dataset quantization starts (or resumes) as well as
/// the destination path for the compressed output.
#[derive(Clone, Debug)]
pub struct GeneratorContext {
    /// Point index at which quantization starts or resumes (checkpoint support).
    pub offset: usize,
    /// Path the compressed data is written to.
    pub compressed_data_path: String,
}

impl GeneratorContext {
    /// Creates a context that resumes from `offset` and writes to `compressed_data_path`.
    pub fn new(offset: usize, compressed_data_path: String) -> Self {
        GeneratorContext {
            compressed_data_path,
            offset,
        }
    }
}
50
/// [`QuantDataGenerator`] orchestrates the process of reading vector data, applying quantization,
/// and writing compressed results to storage. It resumes data generation from the checkpoint manager
/// and processes data in batches.
pub struct QuantDataGenerator<T, Q>
where
    T: Copy + VectorRepr,
    Q: QuantCompressor<T>,
{
    /// Quantizer that turns `T` vectors into fixed-size compressed byte codes.
    pub quantizer: Q,
    pub data_path: String,         // Path to the source vector data
    pub context: GeneratorContext, // Overloadable context that contains metric and offset info
    /// Marks the vector element type `T` without storing a value of it.
    phantom: PhantomData<T>,
}
64
65impl<T, Q> QuantDataGenerator<T, Q>
66where
67    T: Copy + VectorRepr,
68    Q: QuantCompressor<T>,
69{
70    pub fn new(
71        data_path: String,
72        context: GeneratorContext,
73        quantizer_context: &Q::CompressorContext,
74    ) -> ANNResult<Self> {
75        let stage = match context.offset {
76            0 => CompressionStage::Start,
77            _ => CompressionStage::Resume,
78        };
79        let quantizer = Q::new_at_stage(stage, quantizer_context)?;
80        Ok(Self {
81            data_path,
82            context,
83            quantizer,
84            phantom: PhantomData,
85        })
86    }
87
88    /// This method reads the source data file, processes vectors in batches, compresses them
89    /// using the provided quantizer, and writes the results to the compressed data file.
90    /// It supports checkpointing through the chunking_config and resumes from previous
91    /// interruptions using the offset stored in the context.
92    //
93    /// The implementation is adapted from generate_quantized_data_internal in pq_construction.rs
94    //
95    /// # Processing Flow
96    /// 1. Checks if starting from beginning (offset=0) and deletes any existing output if needed
97    /// 2. Opens source data file and reads metadata (num_points and dimension)
98    /// 3. Creates or opens output compressed file and writes metadata header - [num_points as i32, compressed_vector_size as i32]
99    /// 4. Processes data in blocks of size given by chunking_config.data_compression_chunk_vector_count = 50_000
100    /// 5. Compresses each block in small batch sizes in parallel to (potentially) take advantage of batch compression with quantizer
101    /// 6. Writes compressed blocks to the output file.
102    pub fn generate_data<Storage, Pool>(
103        &self,
104        storage_provider: &Storage, // Provider for reading source data and writing compressed results
105        pool: &Pool,                // Thread pool for parallel processing
106        chunking_config: &ChunkingConfig, // Configuration for batching and checkpoint handling
107    ) -> ANNResult<Progress>
108    where
109        Storage: StorageReadProvider + StorageWriteProvider,
110        Pool: AsThreadPool,
111    {
112        let timer = Timer::new();
113
114        let metadata = load_metadata_from_file(storage_provider, &self.data_path)?;
115        let (num_points, dim) = metadata.into_dims();
116
117        self.validate_params(num_points, storage_provider)?;
118
119        let offset = self.context.offset;
120        let compressed_path = self.context.compressed_data_path.as_str();
121
122        if offset == 0 && storage_provider.exists(compressed_path) {
123            storage_provider.delete(compressed_path)?;
124        }
125
126        info!("Generating quantized data for {}", compressed_path);
127
128        let data_reader = &mut storage_provider.open_reader(&self.data_path)?;
129
130        //open the writer for the compressed dataset if starting from the middle, else create a new one.
131        let mut compressed_data_writer = if offset > 0 {
132            storage_provider.open_writer(compressed_path)?
133        } else {
134            let mut sp = storage_provider.create_for_write(compressed_path)?;
135            // write metadata to header
136            Metadata::new(num_points, self.quantizer.compressed_bytes())?.write(&mut sp)?;
137            sp
138        };
139
140        //seek to the offset after skipping metadata
141        data_reader.seek(SeekFrom::Start(
142            (size_of::<i32>() * 2 + offset * dim * size_of::<T>()) as u64,
143        ))?;
144
145        let compressed_size = self.quantizer.compressed_bytes();
146        let max_block_size = chunking_config.data_compression_chunk_vector_count;
147        let num_remaining = num_points - offset;
148
149        let block_size = std::cmp::min(num_points, max_block_size);
150        let num_blocks =
151            num_remaining / block_size + !num_remaining.is_multiple_of(block_size) as usize;
152
153        info!(
154            "Compressing with block size {}, num_remaining {}, num_blocks {}, offset {}, num_points {}",
155            block_size, num_remaining, num_blocks, offset, num_points
156        );
157
158        let mut compressed_buffer = vec![0_u8; block_size * compressed_size];
159
160        forward_threadpool!(pool = pool: Pool);
161        //Every block has size exactly block_size, except for potentially the last one
162        let action = |block_index| -> ANNResult<()> {
163            let start_index: usize = offset + block_index * block_size;
164            let end_index: usize = std::cmp::min(start_index + block_size, num_points);
165            let cur_block_size: usize = end_index - start_index;
166
167            let block_compressed_base = &mut compressed_buffer[..cur_block_size * compressed_size];
168
169            let raw_block: Vec<T> =
170                diskann::utils::read_exact_into(data_reader, cur_block_size * dim)?;
171
172            let full_dim = T::full_dimension(&raw_block[..dim]).into_ann_result()?; // read full-dimension from first vector
173
174            let mut block_data: Vec<f32> = vec![f32::default(); cur_block_size * full_dim];
175            for (v, dst) in raw_block
176                .chunks_exact(dim)
177                .zip(block_data.chunks_exact_mut(full_dim))
178            {
179                T::as_f32_into(v, dst).into_ann_result()?;
180            }
181
182            // We need some batch size of data to pass to `compress`. There is a balance
183            // to achieve here. It must be:
184            //
185            // 1. Small enough to allow for parallelism across threads/tasks.
186            // 2. Large enough to take advantage of cache locality in `compress`.
187            //
188            // A value of 128 is a somewhat arbitrary compromise, meaning each task will
189            // process `BATCH_SIZE` many dataset vectors at a time.
190            const BATCH_SIZE: usize = 128;
191
192            // Wrap the data in `MatrixViews` so we do not need to manually construct view
193            // in the compression loop.
194            let mut compressed_block = views::MutMatrixView::try_from(
195                block_compressed_base,
196                cur_block_size,
197                compressed_size,
198            )
199            .bridge_err()?;
200            let base_block =
201                views::MatrixView::try_from(&block_data, cur_block_size, full_dim).bridge_err()?;
202            base_block
203                .par_window_iter(BATCH_SIZE)
204                .zip_eq(compressed_block.par_window_iter_mut(BATCH_SIZE))
205                .try_for_each_in_pool(pool, |(src, dst)| self.quantizer.compress(src, dst))?;
206
207            let write_offset = start_index * compressed_size + std::mem::size_of::<i32>() * 2;
208            compressed_data_writer.seek(SeekFrom::Start(write_offset as u64))?;
209            compressed_data_writer.write_all(block_compressed_base)?;
210            compressed_data_writer.flush()?;
211            Ok(())
212        };
213
214        let progress = process_while_resource_is_available(
215            action,
216            0..num_blocks,
217            chunking_config.continuation_checker.clone_box(),
218        )?
219        .map(|processed| processed * block_size + offset);
220
221        info!(
222            "Quant data generation took {} seconds",
223            timer.elapsed().as_secs_f64()
224        );
225
226        Ok(progress)
227    }
228
229    fn validate_params<Storage: StorageReadProvider + StorageWriteProvider>(
230        &self,
231        num_points: usize,
232        storage_provider: &Storage,
233    ) -> ANNResult<()> {
234        if self.context.offset > num_points {
235            //check to make sure offset is within limits.
236            return Err(ANNError::log_pq_error(
237                "Error: offset for compression is more than number of points",
238            ));
239        }
240
241        let compressed_path = &self.context.compressed_data_path;
242
243        if self.context.offset > 0 {
244            if !storage_provider.exists(compressed_path) {
245                return Err(ANNError::log_file_not_found_error(format!(
246                    "Error: Generator expected compressed file {compressed_path} but did not find it."
247                )));
248            }
249            let expected_length = self.quantizer.compressed_bytes() * self.context.offset
250                + std::mem::size_of::<i32>() * 2;
251            let existing_length =
252                storage_provider.get_length(&self.context.compressed_data_path)?;
253
254            if existing_length != expected_length as u64 {
255                //check to make sure compressed data file lengths is as expected based on offset.
256                return Err(ANNError::log_pq_error(format_args!(
257                    "Error: compressed data file length {existing_length} does not match expected length {expected_length}."
258                )));
259            }
260        }
261
262        Ok(())
263    }
264}
265
266//////////////////
267///// Tests /////
268/////////////////
269
270#[cfg(test)]
271mod generator_tests {
272    use std::{
273        io::BufReader,
274        sync::{Arc, RwLock},
275    };
276
277    use diskann::utils::read_exact_into;
278    use diskann_providers::storage::VirtualStorageProvider;
279    use diskann_providers::utils::{create_thread_pool_for_test, save_bytes};
280    use diskann_utils::{
281        io::{write_bin, Metadata},
282        views::MatrixView,
283    };
284    use rstest::rstest;
285    use vfs::{FileSystem, MemoryFS};
286
287    use super::*;
288    use crate::build::chunking::continuation::{
289        ContinuationGrant, ContinuationTrackerTrait, NaiveContinuationTracker,
290    };
291
292    pub struct DummyCompressor {
293        pub output_dim: u32,
294        pub code: Vec<u8>,
295    }
296    impl DummyCompressor {
297        pub fn new(output_dim: u32) -> Self {
298            Self {
299                output_dim,
300                code: (0..output_dim).map(|x| (x % 256) as u8).collect(),
301            }
302        }
303    }
304    impl QuantCompressor<f32> for DummyCompressor {
305        type CompressorContext = u32;
306
307        fn new_at_stage(
308            _stage: CompressionStage,
309            context: &Self::CompressorContext,
310        ) -> ANNResult<Self> {
311            Ok(Self::new(*context))
312        }
313
314        fn compress(
315            &self,
316            _vector: views::MatrixView<f32>,
317            mut output: views::MutMatrixView<u8>,
318        ) -> ANNResult<()> {
319            output
320                .row_iter_mut()
321                .for_each(|r| r.copy_from_slice(&self.code));
322            Ok(())
323        }
324
325        fn compressed_bytes(&self) -> usize {
326            self.output_dim as usize
327        }
328    }
329
330    fn create_test_data(num_points: usize, dim: usize) -> Vec<f32> {
331        let mut data = Vec::new();
332
333        // Generate some test vector data
334        for i in 0..num_points {
335            for j in 0..dim {
336                data.push((i * dim + j) as f32);
337            }
338        }
339
340        data
341    }
342
343    //Mock continuation checker that stops after stop_count - 1 iterations.
344    struct MockStopContinuationChecker {
345        count: Arc<RwLock<usize>>,
346        stop_count: usize,
347    }
348
349    impl Clone for MockStopContinuationChecker {
350        fn clone(&self) -> Self {
351            MockStopContinuationChecker {
352                count: self.count.clone(),
353                stop_count: self.stop_count,
354            }
355        }
356    }
357
358    impl ContinuationTrackerTrait for MockStopContinuationChecker {
359        fn get_continuation_grant(&self) -> ContinuationGrant {
360            let mut count = self.count.write().unwrap();
361            *count += 1;
362            if !(*count).is_multiple_of(self.stop_count) {
363                ContinuationGrant::Continue
364            } else {
365                ContinuationGrant::Stop
366            }
367        }
368    }
369
370    fn generate_data_and_compressed(
371        num_points: usize,
372        dim: usize,
373        offset: usize,
374        output_dim: u32,
375    ) -> ANNResult<(VirtualStorageProvider<MemoryFS>, String, String)> {
376        let storage_provider = VirtualStorageProvider::new_memory();
377        storage_provider
378            .filesystem()
379            .create_dir("/test_data")
380            .expect("Could not create test directory");
381
382        let data_path = "/test_data/test_data.bin".to_string();
383        let compressed_path = "/test_data/test_compressed.bin".to_string();
384
385        // Setup test data
386        let data = create_test_data(num_points, dim);
387        let view = MatrixView::try_from(data.as_slice(), num_points, dim).unwrap();
388        write_bin(
389            view,
390            &mut storage_provider.create_for_write(data_path.as_str())?,
391        )?;
392
393        if offset > 0 {
394            // write head of file
395            let code = (0..output_dim).map(|x| (x % 256) as u8).collect::<Vec<_>>(); //this is the same code as in DummyQuantizer
396
397            let mut buffer = vec![0_u8; offset * output_dim as usize];
398            buffer
399                .chunks_exact_mut(output_dim as usize)
400                .for_each(|bf| bf.copy_from_slice(code.as_slice()));
401            let _ = save_bytes(
402                &mut storage_provider.create_for_write(compressed_path.as_str())?,
403                buffer.as_slice(),
404                num_points,
405                output_dim as usize,
406                0,
407            )?;
408        }
409
410        Ok((storage_provider, data_path, compressed_path))
411    }
412
413    fn create_and_call_generator<F: vfs::FileSystem>(
414        offset: usize,
415        compressed_path: String,
416        storage_provider: &VirtualStorageProvider<F>,
417        data_path: String,
418        output_dim: u32,
419        chunking_config: &ChunkingConfig,
420    ) -> (
421        QuantDataGenerator<f32, DummyCompressor>,
422        Result<Progress, ANNError>,
423    ) {
424        let pool: diskann_providers::utils::RayonThreadPool = create_thread_pool_for_test();
425        // Create generator
426        let context = GeneratorContext::new(offset, compressed_path.clone());
427        let generator = QuantDataGenerator::<f32, DummyCompressor>::new(
428            data_path.clone(),
429            context,
430            &output_dim,
431        )
432        .unwrap();
433        // Run generator
434        let result = generator.generate_data(storage_provider, &&pool, chunking_config);
435        (generator, result)
436    }
437
    // Verifies a single `generate_data` run from `offset`: the output file must
    // contain `expected_size` payload bytes after the 8-byte header, and every
    // row must equal DummyCompressor's fixed code. `config_stop_count` drives
    // MockStopContinuationChecker, so the last two cases intentionally stop
    // after one processed chunk (chunk size is fixed at 10_000 below).
    #[rstest]
    #[case(100, 8, 4, 0, 10, 100 * 4)] //small test that fits in BATCH_SIZE
    #[case(100, 8, 4, 50, 10, 100 * 4)] //small test that fits in BATCH_SIZE with offset > 0
    #[case(257, 4, 8, 0, 10, 257 * 8)] //larger than BATCH_SIZE and not multiple of it
    #[case(60_000, 384, 192, 5_000, 10, 60_000 * 192)] //larger than chunk_vector_count = 10_000 with offset > 0
    #[case(60_000, 384, 192, 0, 10, 60_000 * 192)] //larger than chunk_vector_count = 10_000 with offset = 0
    #[case(60_000, 384, 192, 0, 2, 10_000 * 192)] //should stop after 1 action block
    #[case(60_000, 384, 192, 1000, 2, 11_000 * 192)] //same as above but with offset
    fn test_generate_data_from_offset(
        #[case] num_points: usize,
        #[case] dim: usize,
        #[case] output_dim: u32,
        #[case] offset: usize,
        #[case] config_stop_count: usize,
        #[case] expected_size: usize,
    ) -> ANNResult<()> {
        let (storage_provider, data_path, compressed_path) =
            generate_data_and_compressed(num_points, dim, offset, output_dim)?;

        let chunking_config = ChunkingConfig {
            continuation_checker: Box::new(MockStopContinuationChecker {
                count: Arc::new(RwLock::new(0)),
                stop_count: config_stop_count,
            }),
            data_compression_chunk_vector_count: 10_000,
            inmemory_build_chunk_vector_count: 10_000,
        };

        let (generator, result) = create_and_call_generator(
            offset,
            compressed_path.clone(),
            &storage_provider,
            data_path,
            output_dim,
            &chunking_config,
        );

        assert!(result.is_ok(), "Result is not ok, got {:?}", result); //should have completed correctly
        assert!(storage_provider.exists(&compressed_path)); // Verify output file

        // Check compressed data size: payload + [npoints, ndims] i32 header
        let file_len = storage_provider.get_length(&compressed_path)? as usize;
        assert_eq!(file_len, expected_size + 2 * std::mem::size_of::<i32>());

        let mut r = storage_provider.open_reader(compressed_path.as_str())?;
        let mut reader = BufReader::new(&mut r);
        let metadata = Metadata::read(&mut reader)?;

        let data: Vec<u8> = read_exact_into(&mut reader, expected_size)?;

        // Check header (always reflects the full dataset, even when stopped early)
        assert_eq!(metadata.ndims_u32(), output_dim);
        assert_eq!(metadata.npoints(), num_points);

        // Check compressed data content: every written row is the dummy code
        data.chunks_exact(output_dim as usize)
            .for_each(|chunk| assert_eq!(chunk, generator.quantizer.code.as_slice()));

        Ok(())
    }
498
499    #[test]
500    fn test_stop_and_continue_chunking_config() -> ANNResult<()> {
501        let (num_points, dim, output_dim) = (256, 128, 128);
502        let chunking_config = ChunkingConfig {
503            continuation_checker: Box::<NaiveContinuationTracker>::default(),
504            data_compression_chunk_vector_count: 10,
505            inmemory_build_chunk_vector_count: 10,
506        };
507        let (storage_provider, data_path, compressed_path) =
508            generate_data_and_compressed(num_points, dim, 0, output_dim)?;
509        let (mut generator, mut result) = create_and_call_generator(
510            0,
511            compressed_path.clone(),
512            &storage_provider,
513            data_path.clone(),
514            output_dim,
515            &chunking_config,
516        );
517        loop {
518            match result.as_ref().unwrap() {
519                Progress::Completed => break,
520                Progress::Processed(num_points) => {
521                    (generator, result) = create_and_call_generator(
522                        *num_points,
523                        compressed_path.clone(),
524                        &storage_provider,
525                        data_path.clone(),
526                        output_dim,
527                        &chunking_config,
528                    );
529                }
530            }
531        }
532
533        assert!(result.is_ok(), "Result is not ok, got {:?}", result); //should have completed correctly
534        assert!(storage_provider.exists(&compressed_path)); // Verify output file
535
536        // Check compressed data size
537        let file_len = storage_provider.get_length(&compressed_path)? as usize;
538        let expected_size = (num_points * output_dim as usize) + 2 * std::mem::size_of::<i32>();
539        assert_eq!(file_len, expected_size,);
540
541        let mut r = storage_provider.open_reader(compressed_path.as_str())?;
542        let mut reader = BufReader::new(&mut r);
543        let metadata = Metadata::read(&mut reader)?;
544
545        let data: Vec<u8> =
546            read_exact_into(&mut reader, expected_size - 2 * std::mem::size_of::<i32>())?;
547
548        // Check header
549        assert_eq!(metadata.ndims_u32(), output_dim);
550        assert_eq!(metadata.npoints(), num_points);
551
552        // Check compressed data content
553        data.chunks_exact(output_dim as usize)
554            .for_each(|chunk| assert_eq!(chunk, generator.quantizer.code.as_slice()));
555        Ok(())
556    }
557
558    #[rstest]
559    #[case(
560        1_024,
561        384,
562        192,
563        1_025,
564        0,
565        "offset for compression is more than number of points"
566    )]
567    #[case(
568        1_1024,
569        384,
570        192,
571        5,
572        15,
573        "compressed data file length 2888 does not match expected length 968."
574    )]
575    fn test_offset_error_case(
576        #[case] num_points: usize,
577        #[case] dim: usize,
578        #[case] output_dim: u32,
579        #[case] offset: usize,
580        #[case] error_offset: usize,
581        #[case] msg: String,
582    ) -> ANNResult<()> {
583        assert!(offset > 0);
584        let (storage_provider, data_path, compressed_path) =
585            generate_data_and_compressed(num_points, dim, error_offset, output_dim)?;
586
587        let (_, result) = create_and_call_generator(
588            offset,
589            compressed_path,
590            &storage_provider,
591            data_path,
592            output_dim,
593            &ChunkingConfig::default(),
594        );
595
596        assert!(result.is_err());
597        if let Err(e) = result {
598            let error_msg = format!("{:?}", e);
599            assert!(error_msg.contains(&msg), "{}", &error_msg);
600        }
601
602        Ok(())
603    }
604}