// diskann_disk/storage/quant/generator.rs
/*
 * Copyright (c) Microsoft Corporation.
 * Licensed under the MIT license.
 */

6use std::{
7    io::{Seek, SeekFrom, Write},
8    marker::PhantomData,
9};
10
11use diskann::{error::IntoANNResult, utils::VectorRepr, ANNError, ANNResult};
12use diskann_providers::storage::{StorageReadProvider, StorageWriteProvider};
13use diskann_providers::{
14    forward_threadpool,
15    utils::{
16        load_metadata_from_file, write_metadata, AsThreadPool, BridgeErr, ParallelIteratorInPool,
17        Timer,
18    },
19};
20use diskann_utils::views::{self};
21use rayon::iter::IndexedParallelIterator;
22use tracing::info;
23
24use crate::{
25    build::chunking::{
26        checkpoint::Progress,
27        continuation::{process_while_resource_is_available, ChunkingConfig},
28    },
29    storage::quant::compressor::{CompressionStage, QuantCompressor},
30};
31
/// Checkpoint state for vector quantization.
///
/// Captures where in the dataset quantization should start (or resume) and
/// where the compressed vectors are written, so an interrupted run can pick
/// up exactly where it left off.
#[derive(Clone, Debug)]
pub struct GeneratorContext {
    /// Point index at which to start/resume quantization (checkpoint support).
    pub offset: usize,
    /// Destination path for the compressed data file.
    pub compressed_data_path: String,
}

impl GeneratorContext {
    /// Creates a context resuming at `offset` and writing to `compressed_data_path`.
    pub fn new(offset: usize, compressed_data_path: String) -> Self {
        GeneratorContext {
            offset,
            compressed_data_path,
        }
    }
}

54/// [`QuantDataGenerator`] orchestrates the process of reading vector data, applying quantization,
55/// and writing compressed results to storage. It resumes data generation from the checkpoint manager
56/// and processes data in batches.
57pub struct QuantDataGenerator<T, Q>
58where
59    T: Copy + VectorRepr,
60    Q: QuantCompressor<T>,
61{
62    pub quantizer: Q,
63    pub data_path: String,         // Path to the source vector data
64    pub context: GeneratorContext, // Overloadable context that contains metric and offset info
65    phantom: PhantomData<T>,
66}
67
impl<T, Q> QuantDataGenerator<T, Q>
where
    T: Copy + VectorRepr,
    Q: QuantCompressor<T>,
{
73    pub fn new(
74        data_path: String,
75        context: GeneratorContext,
76        quantizer_context: &Q::CompressorContext,
77    ) -> ANNResult<Self> {
78        let stage = match context.offset {
79            0 => CompressionStage::Start,
80            _ => CompressionStage::Resume,
81        };
82        let quantizer = Q::new_at_stage(stage, quantizer_context)?;
83        Ok(Self {
84            data_path,
85            context,
86            quantizer,
87            phantom: PhantomData,
88        })
89    }
90
    /// Reads the source data file, processes vectors in batches, compresses them
    /// using the provided quantizer, and writes the results to the compressed data file.
    /// It supports checkpointing through the chunking_config and resumes from previous
    /// interruptions using the offset stored in the context.
    ///
    /// The implementation is adapted from generate_quantized_data_internal in pq_construction.rs.
    ///
    /// # Processing Flow
    /// 1. Checks if starting from beginning (offset=0) and deletes any existing output if needed
    /// 2. Opens source data file and reads metadata (num_points and dimension)
    /// 3. Creates or opens output compressed file and writes metadata header - [num_points as i32, compressed_vector_size as i32]
    /// 4. Processes data in blocks of size given by chunking_config.data_compression_chunk_vector_count = 50_000
    /// 5. Compresses each block in small batch sizes in parallel to (potentially) take advantage of batch compression with quantizer
    /// 6. Writes compressed blocks to the output file.
    ///
    /// Returns the resulting [`Progress`]: completion, or the number of points
    /// processed so far when the continuation checker stopped the run early.
    pub fn generate_data<Storage, Pool>(
        &self,
        storage_provider: &Storage, // Provider for reading source data and writing compressed results
        pool: &Pool,                // Thread pool for parallel processing
        chunking_config: &ChunkingConfig, // Configuration for batching and checkpoint handling
    ) -> ANNResult<Progress>
    where
        Storage: StorageReadProvider + StorageWriteProvider,
        Pool: AsThreadPool,
    {
        let timer = Timer::new();

        let metadata = load_metadata_from_file(storage_provider, &self.data_path)?;
        let (num_points, dim) = (metadata.npoints, metadata.ndims);

        self.validate_params(num_points, storage_provider)?;

        let offset = self.context.offset;
        let compressed_path = self.context.compressed_data_path.as_str();

        // A fresh run (offset == 0) must not inherit stale output from an earlier attempt.
        if offset == 0 && storage_provider.exists(compressed_path) {
            storage_provider.delete(compressed_path)?;
        }

        info!("Generating quantized data for {}", compressed_path);

        let data_reader = &mut storage_provider.open_reader(&self.data_path)?;

        // Open the writer for the compressed dataset if starting from the middle,
        // else create a new one and stamp the metadata header.
        let mut compressed_data_writer = if offset > 0 {
            storage_provider.open_writer(compressed_path)?
        } else {
            let mut sp = storage_provider.create_for_write(compressed_path)?;
            // Write metadata header: [num_points as i32, compressed_vector_size as i32].
            write_metadata(&mut sp, num_points, self.quantizer.compressed_bytes())?;
            sp
        };

        // Seek the reader to `offset` vectors past the two-i32 metadata header.
        data_reader.seek(SeekFrom::Start(
            (size_of::<i32>() * 2 + offset * dim * size_of::<T>()) as u64,
        ))?;

        let compressed_size = self.quantizer.compressed_bytes();
        let max_block_size = chunking_config.data_compression_chunk_vector_count;
        let num_remaining = num_points - offset;

        // NOTE(review): assumes num_points > 0 — an empty dataset would make
        // block_size zero and the division below panic; TODO confirm callers
        // never pass an empty source file.
        let block_size = std::cmp::min(num_points, max_block_size);
        let num_blocks =
            num_remaining / block_size + !num_remaining.is_multiple_of(block_size) as usize;

        info!(
            "Compressing with block size {}, num_remaining {}, num_blocks {}, offset {}, num_points {}",
            block_size, num_remaining, num_blocks, offset, num_points
        );

        // Scratch buffer reused for every block's compressed output.
        let mut compressed_buffer = vec![0_u8; block_size * compressed_size];

        forward_threadpool!(pool = pool: Pool);
        // Every block has size exactly block_size, except for potentially the last one.
        // The closure reads sequentially from data_reader, so blocks must be
        // processed in order — process_while_resource_is_available iterates 0..num_blocks.
        let action = |block_index| -> ANNResult<()> {
            let start_index: usize = offset + block_index * block_size;
            let end_index: usize = std::cmp::min(start_index + block_size, num_points);
            let cur_block_size: usize = end_index - start_index;

            let block_compressed_base = &mut compressed_buffer[..cur_block_size * compressed_size];

            let raw_block: Vec<T> =
                diskann::utils::read_exact_into(data_reader, cur_block_size * dim)?;

            let full_dim = T::full_dimension(&raw_block[..dim]).into_ann_result()?; // read full-dimension from first vector

            // Decode every stored vector into an f32 row of width full_dim.
            let mut block_data: Vec<f32> = vec![f32::default(); cur_block_size * full_dim];
            for (v, dst) in raw_block
                .chunks_exact(dim)
                .zip(block_data.chunks_exact_mut(full_dim))
            {
                T::as_f32_into(v, dst).into_ann_result()?;
            }

            // We need some batch size of data to pass to `compress`. There is a balance
            // to achieve here. It must be:
            //
            // 1. Small enough to allow for parallelism across threads/tasks.
            // 2. Large enough to take advantage of cache locality in `compress`.
            //
            // A value of 128 is a somewhat arbitrary compromise, meaning each task will
            // process `BATCH_SIZE` many dataset vectors at a time.
            const BATCH_SIZE: usize = 128;

            // Wrap the data in `MatrixViews` so we do not need to manually construct view
            // in the compression loop.
            let mut compressed_block = views::MutMatrixView::try_from(
                block_compressed_base,
                cur_block_size,
                compressed_size,
            )
            .bridge_err()?;
            let base_block =
                views::MatrixView::try_from(&block_data, cur_block_size, full_dim).bridge_err()?;
            base_block
                .par_window_iter(BATCH_SIZE)
                .zip_eq(compressed_block.par_window_iter_mut(BATCH_SIZE))
                .try_for_each_in_pool(pool, |(src, dst)| self.quantizer.compress(src, dst))?;

            // Position the writer at this block's slot (header + preceding vectors)
            // and persist before the next checkpoint decision.
            let write_offset = start_index * compressed_size + std::mem::size_of::<i32>() * 2;
            compressed_data_writer.seek(SeekFrom::Start(write_offset as u64))?;
            compressed_data_writer.write_all(block_compressed_base)?;
            compressed_data_writer.flush()?;
            Ok(())
        };

        let progress = process_while_resource_is_available(
            action,
            0..num_blocks,
            chunking_config.continuation_checker.clone_box(),
        )?
        // Translate "blocks processed" into "points processed" for the checkpoint.
        // NOTE(review): if the run stopped right after a short final block this
        // can exceed num_points — presumably such a run reports Completed
        // instead; verify against process_while_resource_is_available.
        .map(|processed| processed * block_size + offset);

        info!(
            "Quant data generation took {} seconds",
            timer.elapsed().as_secs_f64()
        );

        Ok(progress)
    }

232    fn validate_params<Storage: StorageReadProvider + StorageWriteProvider>(
233        &self,
234        num_points: usize,
235        storage_provider: &Storage,
236    ) -> ANNResult<()> {
237        if self.context.offset > num_points {
238            //check to make sure offset is within limits.
239            return Err(ANNError::log_pq_error(
240                "Error: offset for compression is more than number of points",
241            ));
242        }
243
244        let compressed_path = &self.context.compressed_data_path;
245
246        if self.context.offset > 0 {
247            if !storage_provider.exists(compressed_path) {
248                return Err(ANNError::log_file_not_found_error(format!(
249                    "Error: Generator expected compressed file {compressed_path} but did not find it."
250                )));
251            }
252            let expected_length = self.quantizer.compressed_bytes() * self.context.offset
253                + std::mem::size_of::<i32>() * 2;
254            let existing_length =
255                storage_provider.get_length(&self.context.compressed_data_path)?;
256
257            if existing_length != expected_length as u64 {
258                //check to make sure compressed data file lengths is as expected based on offset.
259                return Err(ANNError::log_pq_error(format_args!(
260                    "Error: compressed data file length {existing_length} does not match expected length {expected_length}."
261                )));
262            }
263        }
264
265        Ok(())
266    }
}

//////////////////
///// Tests /////
/////////////////

#[cfg(test)]
mod generator_tests {
    use std::{
        io::BufReader,
        sync::{Arc, RwLock},
    };

    use diskann::utils::read_exact_into;
    use diskann_providers::storage::VirtualStorageProvider;
    use diskann_providers::utils::{
        create_thread_pool_for_test, read_metadata, save_bin_f32, save_bytes,
    };
    use rstest::rstest;
    use vfs::{FileSystem, MemoryFS, OverlayFS};

    use super::*;
    use crate::build::chunking::continuation::{
        ContinuationGrant, ContinuationTrackerTrait, NaiveContinuationTracker,
    };

293    pub struct DummyCompressor {
294        pub output_dim: u32,
295        pub code: Vec<u8>,
296    }
297    impl DummyCompressor {
298        pub fn new(output_dim: u32) -> Self {
299            Self {
300                output_dim,
301                code: (0..output_dim).map(|x| (x % 256) as u8).collect(),
302            }
303        }
304    }
305    impl QuantCompressor<f32> for DummyCompressor {
306        type CompressorContext = u32;
307
308        fn new_at_stage(
309            _stage: CompressionStage,
310            context: &Self::CompressorContext,
311        ) -> ANNResult<Self> {
312            Ok(Self::new(*context))
313        }
314
315        fn compress(
316            &self,
317            _vector: views::MatrixView<f32>,
318            mut output: views::MutMatrixView<u8>,
319        ) -> ANNResult<()> {
320            output
321                .row_iter_mut()
322                .for_each(|r| r.copy_from_slice(&self.code));
323            Ok(())
324        }
325
326        fn compressed_bytes(&self) -> usize {
327            self.output_dim as usize
328        }
329    }
330
331    fn create_test_data(num_points: usize, dim: usize) -> Vec<f32> {
332        let mut data = Vec::new();
333
334        // Generate some test vector data
335        for i in 0..num_points {
336            for j in 0..dim {
337                data.push((i * dim + j) as f32);
338            }
339        }
340
341        data
342    }
343
344    //Mock continuation checker that stops after stop_count - 1 iterations.
345    struct MockStopContinuationChecker {
346        count: Arc<RwLock<usize>>,
347        stop_count: usize,
348    }
349
350    impl Clone for MockStopContinuationChecker {
351        fn clone(&self) -> Self {
352            MockStopContinuationChecker {
353                count: self.count.clone(),
354                stop_count: self.stop_count,
355            }
356        }
357    }
358
359    impl ContinuationTrackerTrait for MockStopContinuationChecker {
360        fn get_continuation_grant(&self) -> ContinuationGrant {
361            let mut count = self.count.write().unwrap();
362            *count += 1;
363            if !(*count).is_multiple_of(self.stop_count) {
364                ContinuationGrant::Continue
365            } else {
366                ContinuationGrant::Stop
367            }
368        }
369    }
370
    /// Builds an in-memory storage provider holding a freshly generated source
    /// dataset and, when `offset > 0`, a partially written compressed file that
    /// simulates an interrupted run resumable at `offset`.
    ///
    /// Returns the provider together with the source and compressed file paths.
    fn generate_data_and_compressed(
        num_points: usize,
        dim: usize,
        offset: usize,
        output_dim: u32,
    ) -> ANNResult<(VirtualStorageProvider<OverlayFS>, String, String)> {
        let fs = OverlayFS::new(&[MemoryFS::default().into()]);
        fs.create_dir("/test_data")
            .expect("Could not create test directory");
        let storage_provider = VirtualStorageProvider::new(fs);

        let data_path = "/test_data/test_data.bin".to_string();
        let compressed_path = "/test_data/test_compressed.bin".to_string();

        // Setup test data
        let _ = save_bin_f32(
            &mut storage_provider.create_for_write(data_path.as_str())?,
            &create_test_data(num_points, dim),
            num_points,
            dim,
            0,
        )?;

        if offset > 0 {
            // Pre-write the head of the compressed file: `offset` rows each holding
            // the code DummyCompressor produces, so resume validation passes.
            let code = (0..output_dim).map(|x| (x % 256) as u8).collect::<Vec<_>>(); //this is the same code as in DummyQuantizer

            let mut buffer = vec![0_u8; offset * output_dim as usize];
            buffer
                .chunks_exact_mut(output_dim as usize)
                .for_each(|bf| bf.copy_from_slice(code.as_slice()));
            // NOTE(review): the metadata passed here claims `num_points` rows while
            // the payload holds only `offset` — presumably save_bytes writes the
            // [npoints, dim] header and the generator expects exactly this
            // truncated layout (header + offset vectors); verify against save_bytes.
            let _ = save_bytes(
                &mut storage_provider.create_for_write(compressed_path.as_str())?,
                buffer.as_slice(),
                num_points,
                output_dim as usize,
                0,
            )?;
        }

        Ok((storage_provider, data_path, compressed_path))
    }

414    fn create_and_call_generator(
415        offset: usize,
416        compressed_path: String,
417        storage_provider: &VirtualStorageProvider<OverlayFS>,
418        data_path: String,
419        output_dim: u32,
420        chunking_config: &ChunkingConfig,
421    ) -> (
422        QuantDataGenerator<f32, DummyCompressor>,
423        Result<Progress, ANNError>,
424    ) {
425        let pool: diskann_providers::utils::RayonThreadPool = create_thread_pool_for_test();
426        // Create generator
427        let context = GeneratorContext::new(offset, compressed_path.clone());
428        let generator = QuantDataGenerator::<f32, DummyCompressor>::new(
429            data_path.clone(),
430            context,
431            &output_dim,
432        )
433        .unwrap();
434        // Run generator
435        let result = generator.generate_data(storage_provider, &&pool, chunking_config);
436        (generator, result)
437    }
438
    /// End-to-end generation over a matrix of dataset sizes, resume offsets, and
    /// forced-stop configurations; verifies the output file's length, header, and
    /// that every compressed row equals the dummy quantizer's code.
    #[rstest]
    #[case(100, 8, 4, 0, 10, 100 * 4)] //small test that fits in BATCH_SIZE
    #[case(100, 8, 4, 50, 10, 100 * 4)] //small test that fits in BATCH_SIZE with offset > 0
    #[case(257, 4, 8, 0, 10, 257 * 8)] //larger than BATCH_SIZE and not multiple of it
    #[case(60_000, 384, 192, 5_000, 10, 60_000 * 192)] //larger than chunk_vector_count = 10_000 with offset > 0
    #[case(60_000, 384, 192, 0, 10, 60_000 * 192)] //larger than chunk_vector_count = 10_000 with offset = 0
    #[case(60_000, 384, 192, 0, 2, 10_000 * 192)] //should stop after 1 action block
    #[case(60_000, 384, 192, 1000, 2, 11_000 * 192)] //same as above but with offset
    fn test_generate_data_from_offset(
        #[case] num_points: usize,
        #[case] dim: usize,
        #[case] output_dim: u32,
        #[case] offset: usize,
        #[case] config_stop_count: usize,
        #[case] expected_size: usize, // bytes of compressed payload expected on disk (excluding header)
    ) -> ANNResult<()> {
        let (storage_provider, data_path, compressed_path) =
            generate_data_and_compressed(num_points, dim, offset, output_dim)?;

        // The mock checker stops the run after `config_stop_count - 1` grants.
        let chunking_config = ChunkingConfig {
            continuation_checker: Box::new(MockStopContinuationChecker {
                count: Arc::new(RwLock::new(0)),
                stop_count: config_stop_count,
            }),
            data_compression_chunk_vector_count: 10_000,
            inmemory_build_chunk_vector_count: 10_000,
        };

        let (generator, result) = create_and_call_generator(
            offset,
            compressed_path.clone(),
            &storage_provider,
            data_path,
            output_dim,
            &chunking_config,
        );

        assert!(result.is_ok(), "Result is not ok, got {:?}", result); //should have completed correctly
        assert!(storage_provider.exists(&compressed_path)); // Verify output file

        // Check compressed data size
        let file_len = storage_provider.get_length(&compressed_path)? as usize;
        assert_eq!(file_len, expected_size + 2 * std::mem::size_of::<i32>());

        let mut r = storage_provider.open_reader(compressed_path.as_str())?;
        let mut reader = BufReader::new(&mut r);
        let metadata = read_metadata(&mut reader)?;

        let data: Vec<u8> = read_exact_into(&mut reader, expected_size)?;

        // Check header
        assert_eq!(metadata.ndims as u32, output_dim);
        assert_eq!(metadata.npoints, num_points);

        // Check compressed data content: every row must equal the dummy code.
        data.chunks_exact(output_dim as usize)
            .for_each(|chunk| assert_eq!(chunk, generator.quantizer.code.as_slice()));

        Ok(())
    }

    /// Repeatedly restarts the generator from each reported checkpoint until it
    /// reports completion, then verifies the fully compressed output file.
    #[test]
    fn test_stop_and_continue_chunking_config() -> ANNResult<()> {
        let (num_points, dim, output_dim) = (256, 128, 128);
        let chunking_config = ChunkingConfig {
            continuation_checker: Box::<NaiveContinuationTracker>::default(),
            data_compression_chunk_vector_count: 10,
            inmemory_build_chunk_vector_count: 10,
        };
        let (storage_provider, data_path, compressed_path) =
            generate_data_and_compressed(num_points, dim, 0, output_dim)?;
        let (mut generator, mut result) = create_and_call_generator(
            0,
            compressed_path.clone(),
            &storage_provider,
            data_path.clone(),
            output_dim,
            &chunking_config,
        );
        // Drive the stop-and-resume loop: every Processed(n) checkpoint becomes
        // the offset of a fresh generator until Completed is returned.
        loop {
            match result.as_ref().unwrap() {
                Progress::Completed => break,
                Progress::Processed(num_points) => {
                    (generator, result) = create_and_call_generator(
                        *num_points,
                        compressed_path.clone(),
                        &storage_provider,
                        data_path.clone(),
                        output_dim,
                        &chunking_config,
                    );
                }
            }
        }

        assert!(result.is_ok(), "Result is not ok, got {:?}", result); //should have completed correctly
        assert!(storage_provider.exists(&compressed_path)); // Verify output file

        // Check compressed data size
        let file_len = storage_provider.get_length(&compressed_path)? as usize;
        let expected_size = (num_points * output_dim as usize) + 2 * std::mem::size_of::<i32>();
        assert_eq!(file_len, expected_size,);

        let mut r = storage_provider.open_reader(compressed_path.as_str())?;
        let mut reader = BufReader::new(&mut r);
        let metadata = read_metadata(&mut reader)?;

        let data: Vec<u8> =
            read_exact_into(&mut reader, expected_size - 2 * std::mem::size_of::<i32>())?;

        // Check header
        assert_eq!(metadata.ndims as u32, output_dim);
        assert_eq!(metadata.npoints, num_points);

        // Check compressed data content: every row must equal the dummy code.
        data.chunks_exact(output_dim as usize)
            .for_each(|chunk| assert_eq!(chunk, generator.quantizer.code.as_slice()));
        Ok(())
    }

559    #[rstest]
560    #[case(
561        1_024,
562        384,
563        192,
564        1_025,
565        0,
566        "offset for compression is more than number of points"
567    )]
568    #[case(
569        1_1024,
570        384,
571        192,
572        5,
573        15,
574        "compressed data file length 2888 does not match expected length 968."
575    )]
576    fn test_offset_error_case(
577        #[case] num_points: usize,
578        #[case] dim: usize,
579        #[case] output_dim: u32,
580        #[case] offset: usize,
581        #[case] error_offset: usize,
582        #[case] msg: String,
583    ) -> ANNResult<()> {
584        assert!(offset > 0);
585        let (storage_provider, data_path, compressed_path) =
586            generate_data_and_compressed(num_points, dim, error_offset, output_dim)?;
587
588        let (_, result) = create_and_call_generator(
589            offset,
590            compressed_path,
591            &storage_provider,
592            data_path,
593            output_dim,
594            &ChunkingConfig::default(),
595        );
596
597        assert!(result.is_err());
598        if let Err(e) = result {
599            let error_msg = format!("{:?}", e);
600            assert!(error_msg.contains(&msg), "{}", &error_msg);
601        }
602
603        Ok(())
604    }
}