1use std::{
7 io::{Seek, SeekFrom, Write},
8 marker::PhantomData,
9};
10
11use diskann::{error::IntoANNResult, utils::VectorRepr, ANNError, ANNResult};
12use diskann_providers::storage::{StorageReadProvider, StorageWriteProvider};
13use diskann_providers::{
14 forward_threadpool,
15 utils::{load_metadata_from_file, AsThreadPool, BridgeErr, ParallelIteratorInPool, Timer},
16};
17use diskann_utils::{io::Metadata, views};
18use rayon::iter::IndexedParallelIterator;
19use tracing::info;
20
21use crate::{
22 build::chunking::{
23 checkpoint::Progress,
24 continuation::{process_while_resource_is_available, ChunkingConfig},
25 },
26 storage::quant::compressor::{CompressionStage, QuantCompressor},
27};
28
/// Checkpointable generation settings: where to resume and where to write.
#[derive(Clone, Debug)]
pub struct GeneratorContext {
    /// Number of vectors already compressed in a previous run; generation
    /// resumes from this index (0 means start from scratch).
    pub offset: usize,
    /// Destination path of the compressed (quantized) data file.
    pub compressed_data_path: String,
}
41
42impl GeneratorContext {
43 pub fn new(offset: usize, compressed_data_path: String) -> Self {
44 Self {
45 offset,
46 compressed_data_path,
47 }
48 }
49}
50
/// Streams raw vectors of element type `T` from a data file and writes their
/// quantized codes, block by block, to the compressed file named in `context`,
/// using the compressor `Q`.
pub struct QuantDataGenerator<T, Q>
where
    T: Copy + VectorRepr,
    Q: QuantCompressor<T>,
{
    /// Compressor that turns full-precision vectors into fixed-size codes.
    pub quantizer: Q,
    /// Path of the raw input vector file.
    pub data_path: String,
    /// Resume offset and compressed-output destination.
    pub context: GeneratorContext,
    // Ties the generator to the raw element type `T` without storing one.
    phantom: PhantomData<T>,
}
64
impl<T, Q> QuantDataGenerator<T, Q>
where
    T: Copy + VectorRepr,
    Q: QuantCompressor<T>,
{
    /// Creates a generator reading raw vectors from `data_path`.
    ///
    /// The quantizer is instantiated at `CompressionStage::Start` when the
    /// context offset is zero and at `CompressionStage::Resume` otherwise,
    /// so an interrupted run can pick up where it stopped.
    ///
    /// # Errors
    /// Propagates any failure from `Q::new_at_stage`.
    pub fn new(
        data_path: String,
        context: GeneratorContext,
        quantizer_context: &Q::CompressorContext,
    ) -> ANNResult<Self> {
        let stage = match context.offset {
            0 => CompressionStage::Start,
            _ => CompressionStage::Resume,
        };
        let quantizer = Q::new_at_stage(stage, quantizer_context)?;
        Ok(Self {
            data_path,
            context,
            quantizer,
            phantom: PhantomData,
        })
    }

    /// Compresses the raw vectors in blocks and writes the codes, preceded by
    /// a `Metadata` header, to `context.compressed_data_path`.
    ///
    /// Processing starts at vector index `context.offset`; each block is
    /// read, widened to `f32`, compressed in parallel on `pool`, and flushed
    /// to storage before the next block starts, so completed work survives an
    /// interruption.
    ///
    /// Returns `Progress::Completed` when every point was compressed, or
    /// `Progress::Processed(n)` — `n` including the starting offset — when
    /// the continuation checker in `chunking_config` stopped the run early.
    ///
    /// # Errors
    /// Fails on invalid parameters (see `validate_params`), storage I/O
    /// errors, or compression errors from the quantizer.
    pub fn generate_data<Storage, Pool>(
        &self,
        storage_provider: &Storage,
        pool: &Pool,
        chunking_config: &ChunkingConfig,
    ) -> ANNResult<Progress>
    where
        Storage: StorageReadProvider + StorageWriteProvider,
        Pool: AsThreadPool,
    {
        let timer = Timer::new();

        let metadata = load_metadata_from_file(storage_provider, &self.data_path)?;
        let (num_points, dim) = metadata.into_dims();

        self.validate_params(num_points, storage_provider)?;

        let offset = self.context.offset;
        let compressed_path = self.context.compressed_data_path.as_str();

        // A fresh run (offset 0) must not mix with stale output from an
        // earlier run, so any existing compressed file is removed first.
        if offset == 0 && storage_provider.exists(compressed_path) {
            storage_provider.delete(compressed_path)?;
        }

        info!("Generating quantized data for {}", compressed_path);

        let data_reader = &mut storage_provider.open_reader(&self.data_path)?;

        // Resuming reopens the existing file; a fresh run creates it and
        // writes the (num_points, compressed_bytes) metadata header.
        let mut compressed_data_writer = if offset > 0 {
            storage_provider.open_writer(compressed_path)?
        } else {
            let mut sp = storage_provider.create_for_write(compressed_path)?;
            Metadata::new(num_points, self.quantizer.compressed_bytes())?.write(&mut sp)?;
            sp
        };

        // Skip the two i32 metadata fields plus the `offset` vectors that
        // were already processed in a previous run.
        data_reader.seek(SeekFrom::Start(
            (size_of::<i32>() * 2 + offset * dim * size_of::<T>()) as u64,
        ))?;

        let compressed_size = self.quantizer.compressed_bytes();
        let max_block_size = chunking_config.data_compression_chunk_vector_count;
        let num_remaining = num_points - offset;

        let block_size = std::cmp::min(num_points, max_block_size);
        // Ceiling division: one extra block for a non-full remainder.
        let num_blocks =
            num_remaining / block_size + !num_remaining.is_multiple_of(block_size) as usize;

        info!(
            "Compressing with block size {}, num_remaining {}, num_blocks {}, offset {}, num_points {}",
            block_size, num_remaining, num_blocks, offset, num_points
        );

        // Reused across blocks; only the first `cur_block_size` rows are
        // meaningful (and written) on each iteration.
        let mut compressed_buffer = vec![0_u8; block_size * compressed_size];

        forward_threadpool!(pool = pool: Pool);
        // Processes one block: read raw vectors, widen to f32, compress in
        // parallel, then write the codes at this block's position in the file.
        let action = |block_index| -> ANNResult<()> {
            let start_index: usize = offset + block_index * block_size;
            let end_index: usize = std::cmp::min(start_index + block_size, num_points);
            let cur_block_size: usize = end_index - start_index;

            let block_compressed_base = &mut compressed_buffer[..cur_block_size * compressed_size];

            let raw_block: Vec<T> =
                diskann::utils::read_exact_into(data_reader, cur_block_size * dim)?;

            // `full_dim` is T's full f32 dimension, which may differ from the
            // stored `dim`; every vector is widened to f32 before compression.
            let full_dim = T::full_dimension(&raw_block[..dim]).into_ann_result()?;
            let mut block_data: Vec<f32> = vec![f32::default(); cur_block_size * full_dim];
            for (v, dst) in raw_block
                .chunks_exact(dim)
                .zip(block_data.chunks_exact_mut(full_dim))
            {
                T::as_f32_into(v, dst).into_ann_result()?;
            }

            // Number of vectors handed to the quantizer per parallel work item.
            const BATCH_SIZE: usize = 128;

            let mut compressed_block = views::MutMatrixView::try_from(
                block_compressed_base,
                cur_block_size,
                compressed_size,
            )
            .bridge_err()?;
            let base_block =
                views::MatrixView::try_from(&block_data, cur_block_size, full_dim).bridge_err()?;
            base_block
                .par_window_iter(BATCH_SIZE)
                .zip_eq(compressed_block.par_window_iter_mut(BATCH_SIZE))
                .try_for_each_in_pool(pool, |(src, dst)| self.quantizer.compress(src, dst))?;

            // Seek past the 2 x i32 header to this block's slot, then flush so
            // a checkpointed restart always sees a fully written prefix.
            let write_offset = start_index * compressed_size + std::mem::size_of::<i32>() * 2;
            compressed_data_writer.seek(SeekFrom::Start(write_offset as u64))?;
            compressed_data_writer.write_all(block_compressed_base)?;
            compressed_data_writer.flush()?;
            Ok(())
        };

        // Run `action` per block while the continuation checker grants the
        // resource; translate the processed block count back into a vector
        // count so the caller can resume from it.
        let progress = process_while_resource_is_available(
            action,
            0..num_blocks,
            chunking_config.continuation_checker.clone_box(),
        )?
        .map(|processed| processed * block_size + offset);

        info!(
            "Quant data generation took {} seconds",
            timer.elapsed().as_secs_f64()
        );

        Ok(progress)
    }

    /// Checks that the resume offset is consistent with the input size and,
    /// when resuming, that the existing compressed file has exactly the
    /// length a run stopped after `offset` vectors would have produced.
    ///
    /// # Errors
    /// Returns a PQ error when the offset exceeds `num_points` or the file
    /// length mismatches, and a file-not-found error when the expected
    /// compressed file is missing.
    fn validate_params<Storage: StorageReadProvider + StorageWriteProvider>(
        &self,
        num_points: usize,
        storage_provider: &Storage,
    ) -> ANNResult<()> {
        if self.context.offset > num_points {
            return Err(ANNError::log_pq_error(
                "Error: offset for compression is more than number of points",
            ));
        }

        let compressed_path = &self.context.compressed_data_path;

        if self.context.offset > 0 {
            if !storage_provider.exists(compressed_path) {
                return Err(ANNError::log_file_not_found_error(format!(
                    "Error: Generator expected compressed file {compressed_path} but did not find it."
                )));
            }
            // Expected length: `offset` rows of codes plus the two i32
            // metadata fields written at file creation.
            let expected_length = self.quantizer.compressed_bytes() * self.context.offset
                + std::mem::size_of::<i32>() * 2;
            let existing_length =
                storage_provider.get_length(&self.context.compressed_data_path)?;

            if existing_length != expected_length as u64 {
                return Err(ANNError::log_pq_error(format_args!(
                    "Error: compressed data file length {existing_length} does not match expected length {expected_length}."
                )));
            }
        }

        Ok(())
    }
}
265
#[cfg(test)]
mod generator_tests {
    use std::{
        io::BufReader,
        sync::{Arc, RwLock},
    };

    use diskann::utils::read_exact_into;
    use diskann_providers::storage::VirtualStorageProvider;
    use diskann_providers::utils::{create_thread_pool_for_test, save_bytes};
    use diskann_utils::{
        io::{write_bin, Metadata},
        views::MatrixView,
    };
    use rstest::rstest;
    use vfs::{FileSystem, MemoryFS};

    use super::*;
    use crate::build::chunking::continuation::{
        ContinuationGrant, ContinuationTrackerTrait, NaiveContinuationTracker,
    };

    /// Compressor stub that emits the same `output_dim`-byte code for every
    /// vector, so tests can validate file layout without a real quantizer.
    pub struct DummyCompressor {
        pub output_dim: u32,
        pub code: Vec<u8>,
    }
    impl DummyCompressor {
        pub fn new(output_dim: u32) -> Self {
            Self {
                output_dim,
                code: (0..output_dim).map(|x| (x % 256) as u8).collect(),
            }
        }
    }
    impl QuantCompressor<f32> for DummyCompressor {
        type CompressorContext = u32;

        fn new_at_stage(
            _stage: CompressionStage,
            context: &Self::CompressorContext,
        ) -> ANNResult<Self> {
            Ok(Self::new(*context))
        }

        // Ignores the input vectors and copies the fixed code into every
        // output row.
        fn compress(
            &self,
            _vector: views::MatrixView<f32>,
            mut output: views::MutMatrixView<u8>,
        ) -> ANNResult<()> {
            output
                .row_iter_mut()
                .for_each(|r| r.copy_from_slice(&self.code));
            Ok(())
        }

        fn compressed_bytes(&self) -> usize {
            self.output_dim as usize
        }
    }

    // Builds a num_points x dim matrix whose entries equal their flat index.
    fn create_test_data(num_points: usize, dim: usize) -> Vec<f32> {
        let mut data = Vec::new();

        for i in 0..num_points {
            for j in 0..dim {
                data.push((i * dim + j) as f32);
            }
        }

        data
    }

    /// Continuation checker that returns `Stop` on every `stop_count`-th
    /// query, forcing the generator to checkpoint and return
    /// `Progress::Processed`.
    struct MockStopContinuationChecker {
        count: Arc<RwLock<usize>>,
        stop_count: usize,
    }

    impl Clone for MockStopContinuationChecker {
        fn clone(&self) -> Self {
            MockStopContinuationChecker {
                count: self.count.clone(),
                stop_count: self.stop_count,
            }
        }
    }

    impl ContinuationTrackerTrait for MockStopContinuationChecker {
        fn get_continuation_grant(&self) -> ContinuationGrant {
            // Shared counter so the behavior persists across clone_box calls.
            let mut count = self.count.write().unwrap();
            *count += 1;
            if !(*count).is_multiple_of(self.stop_count) {
                ContinuationGrant::Continue
            } else {
                ContinuationGrant::Stop
            }
        }
    }

    // Writes a raw data file and, when offset > 0, a partially filled
    // compressed file (offset rows of the dummy code) to mimic a resumed run.
    fn generate_data_and_compressed(
        num_points: usize,
        dim: usize,
        offset: usize,
        output_dim: u32,
    ) -> ANNResult<(VirtualStorageProvider<MemoryFS>, String, String)> {
        let storage_provider = VirtualStorageProvider::new_memory();
        storage_provider
            .filesystem()
            .create_dir("/test_data")
            .expect("Could not create test directory");

        let data_path = "/test_data/test_data.bin".to_string();
        let compressed_path = "/test_data/test_compressed.bin".to_string();

        let data = create_test_data(num_points, dim);
        let view = MatrixView::try_from(data.as_slice(), num_points, dim).unwrap();
        write_bin(
            view,
            &mut storage_provider.create_for_write(data_path.as_str())?,
        )?;

        if offset > 0 {
            // Pre-populate the first `offset` rows so the generator's
            // resume-path length validation passes.
            let code = (0..output_dim).map(|x| (x % 256) as u8).collect::<Vec<_>>();
            let mut buffer = vec![0_u8; offset * output_dim as usize];
            buffer
                .chunks_exact_mut(output_dim as usize)
                .for_each(|bf| bf.copy_from_slice(code.as_slice()));
            let _ = save_bytes(
                &mut storage_provider.create_for_write(compressed_path.as_str())?,
                buffer.as_slice(),
                num_points,
                output_dim as usize,
                0,
            )?;
        }

        Ok((storage_provider, data_path, compressed_path))
    }

    // Constructs a generator backed by DummyCompressor and immediately runs
    // generate_data, returning both so callers can inspect the emitted codes.
    fn create_and_call_generator<F: vfs::FileSystem>(
        offset: usize,
        compressed_path: String,
        storage_provider: &VirtualStorageProvider<F>,
        data_path: String,
        output_dim: u32,
        chunking_config: &ChunkingConfig,
    ) -> (
        QuantDataGenerator<f32, DummyCompressor>,
        Result<Progress, ANNError>,
    ) {
        let pool: diskann_providers::utils::RayonThreadPool = create_thread_pool_for_test();
        let context = GeneratorContext::new(offset, compressed_path.clone());
        let generator = QuantDataGenerator::<f32, DummyCompressor>::new(
            data_path.clone(),
            context,
            &output_dim,
        )
        .unwrap();
        let result = generator.generate_data(storage_provider, &&pool, chunking_config);
        (generator, result)
    }

    // Cases: (num_points, dim, output_dim, offset, stop_count, expected
    // payload bytes written when the run stops or completes).
    #[rstest]
    #[case(100, 8, 4, 0, 10, 100 * 4)]
    #[case(100, 8, 4, 50, 10, 100 * 4)]
    #[case(257, 4, 8, 0, 10, 257 * 8)]
    #[case(60_000, 384, 192, 5_000, 10, 60_000 * 192)]
    #[case(60_000, 384, 192, 0, 10, 60_000 * 192)]
    #[case(60_000, 384, 192, 0, 2, 10_000 * 192)]
    #[case(60_000, 384, 192, 1000, 2, 11_000 * 192)]
    fn test_generate_data_from_offset(
        #[case] num_points: usize,
        #[case] dim: usize,
        #[case] output_dim: u32,
        #[case] offset: usize,
        #[case] config_stop_count: usize,
        #[case] expected_size: usize,
    ) -> ANNResult<()> {
        let (storage_provider, data_path, compressed_path) =
            generate_data_and_compressed(num_points, dim, offset, output_dim)?;

        let chunking_config = ChunkingConfig {
            continuation_checker: Box::new(MockStopContinuationChecker {
                count: Arc::new(RwLock::new(0)),
                stop_count: config_stop_count,
            }),
            data_compression_chunk_vector_count: 10_000,
            inmemory_build_chunk_vector_count: 10_000,
        };

        let (generator, result) = create_and_call_generator(
            offset,
            compressed_path.clone(),
            &storage_provider,
            data_path,
            output_dim,
            &chunking_config,
        );

        // File must hold the expected payload plus the 2 x i32 header.
        assert!(result.is_ok(), "Result is not ok, got {:?}", result);
        assert!(storage_provider.exists(&compressed_path));
        let file_len = storage_provider.get_length(&compressed_path)? as usize;
        assert_eq!(file_len, expected_size + 2 * std::mem::size_of::<i32>());

        let mut r = storage_provider.open_reader(compressed_path.as_str())?;
        let mut reader = BufReader::new(&mut r);
        let metadata = Metadata::read(&mut reader)?;

        let data: Vec<u8> = read_exact_into(&mut reader, expected_size)?;

        assert_eq!(metadata.ndims_u32(), output_dim);
        assert_eq!(metadata.npoints(), num_points);

        // Every row must be exactly the dummy compressor's fixed code.
        data.chunks_exact(output_dim as usize)
            .for_each(|chunk| assert_eq!(chunk, generator.quantizer.code.as_slice()));

        Ok(())
    }

    // Repeatedly restarts the generator from each reported checkpoint until
    // it completes, then verifies the final file as if produced in one run.
    #[test]
    fn test_stop_and_continue_chunking_config() -> ANNResult<()> {
        let (num_points, dim, output_dim) = (256, 128, 128);
        let chunking_config = ChunkingConfig {
            continuation_checker: Box::<NaiveContinuationTracker>::default(),
            data_compression_chunk_vector_count: 10,
            inmemory_build_chunk_vector_count: 10,
        };
        let (storage_provider, data_path, compressed_path) =
            generate_data_and_compressed(num_points, dim, 0, output_dim)?;
        let (mut generator, mut result) = create_and_call_generator(
            0,
            compressed_path.clone(),
            &storage_provider,
            data_path.clone(),
            output_dim,
            &chunking_config,
        );
        loop {
            match result.as_ref().unwrap() {
                Progress::Completed => break,
                Progress::Processed(num_points) => {
                    // Resume from the checkpointed vector count.
                    (generator, result) = create_and_call_generator(
                        *num_points,
                        compressed_path.clone(),
                        &storage_provider,
                        data_path.clone(),
                        output_dim,
                        &chunking_config,
                    );
                }
            }
        }

        assert!(result.is_ok(), "Result is not ok, got {:?}", result);
        assert!(storage_provider.exists(&compressed_path));
        let file_len = storage_provider.get_length(&compressed_path)? as usize;
        let expected_size = (num_points * output_dim as usize) + 2 * std::mem::size_of::<i32>();
        assert_eq!(file_len, expected_size,);

        let mut r = storage_provider.open_reader(compressed_path.as_str())?;
        let mut reader = BufReader::new(&mut r);
        let metadata = Metadata::read(&mut reader)?;

        let data: Vec<u8> =
            read_exact_into(&mut reader, expected_size - 2 * std::mem::size_of::<i32>())?;

        assert_eq!(metadata.ndims_u32(), output_dim);
        assert_eq!(metadata.npoints(), num_points);

        data.chunks_exact(output_dim as usize)
            .for_each(|chunk| assert_eq!(chunk, generator.quantizer.code.as_slice()));
        Ok(())
    }

    // Invalid-offset scenarios: offset beyond num_points, and a resume offset
    // inconsistent with the pre-existing compressed file's length.
    #[rstest]
    #[case(
        1_024,
        384,
        192,
        1_025,
        0,
        "offset for compression is more than number of points"
    )]
    #[case(
        1_1024,
        384,
        192,
        5,
        15,
        "compressed data file length 2888 does not match expected length 968."
    )]
    fn test_offset_error_case(
        #[case] num_points: usize,
        #[case] dim: usize,
        #[case] output_dim: u32,
        #[case] offset: usize,
        #[case] error_offset: usize,
        #[case] msg: String,
    ) -> ANNResult<()> {
        assert!(offset > 0);
        let (storage_provider, data_path, compressed_path) =
            generate_data_and_compressed(num_points, dim, error_offset, output_dim)?;

        let (_, result) = create_and_call_generator(
            offset,
            compressed_path,
            &storage_provider,
            data_path,
            output_dim,
            &ChunkingConfig::default(),
        );

        assert!(result.is_err());
        if let Err(e) = result {
            let error_msg = format!("{:?}", e);
            assert!(error_msg.contains(&msg), "{}", &error_msg);
        }

        Ok(())
    }
}