1use std::{
7 io::{Seek, SeekFrom, Write},
8 marker::PhantomData,
9};
10
11use diskann::{error::IntoANNResult, utils::VectorRepr, ANNError, ANNResult};
12use diskann_providers::storage::{StorageReadProvider, StorageWriteProvider};
13use diskann_providers::{
14 forward_threadpool,
15 utils::{
16 load_metadata_from_file, write_metadata, AsThreadPool, BridgeErr, ParallelIteratorInPool,
17 Timer,
18 },
19};
20use diskann_utils::views::{self};
21use rayon::iter::IndexedParallelIterator;
22use tracing::info;
23
24use crate::{
25 build::chunking::{
26 checkpoint::Progress,
27 continuation::{process_while_resource_is_available, ChunkingConfig},
28 },
29 storage::quant::compressor::{CompressionStage, QuantCompressor},
30};
31
/// Checkpointing parameters for quantized-data generation: how many input
/// vectors have already been processed and where the compressed output lives.
#[derive(Clone, Debug)]
pub struct GeneratorContext {
    /// Number of input vectors already compressed by a previous run
    /// (0 means a fresh run).
    pub offset: usize,
    /// Path of the compressed output file.
    pub compressed_data_path: String,
}

impl GeneratorContext {
    /// Bundles a resume offset with the compressed-output path.
    pub fn new(offset: usize, compressed_data_path: String) -> Self {
        GeneratorContext {
            compressed_data_path,
            offset,
        }
    }
}
53
/// Streams raw vectors from `data_path`, compresses them with `quantizer`,
/// and writes the result to the path held in `context`; a non-zero
/// `context.offset` resumes an interrupted run.
pub struct QuantDataGenerator<T, Q>
where
    T: Copy + VectorRepr,
    Q: QuantCompressor<T>,
{
    /// Compressor applied to each block of input vectors.
    pub quantizer: Q,
    /// Path of the raw input vector file.
    pub data_path: String,
    /// Resume offset and compressed-output path.
    pub context: GeneratorContext,
    // Ties the element type `T` to the generator without storing any `T`.
    phantom: PhantomData<T>,
}
67
impl<T, Q> QuantDataGenerator<T, Q>
where
    T: Copy + VectorRepr,
    Q: QuantCompressor<T>,
{
    /// Builds a generator for `data_path`, choosing the compressor stage from
    /// `context.offset`: 0 constructs the quantizer at
    /// `CompressionStage::Start`, any other value at
    /// `CompressionStage::Resume` so previously persisted state can be reused.
    ///
    /// # Errors
    /// Propagates any error from `Q::new_at_stage`.
    pub fn new(
        data_path: String,
        context: GeneratorContext,
        quantizer_context: &Q::CompressorContext,
    ) -> ANNResult<Self> {
        let stage = match context.offset {
            0 => CompressionStage::Start,
            _ => CompressionStage::Resume,
        };
        let quantizer = Q::new_at_stage(stage, quantizer_context)?;
        Ok(Self {
            data_path,
            context,
            quantizer,
            phantom: PhantomData,
        })
    }

    /// Compresses the vectors in `self.data_path` block by block and writes
    /// them to `self.context.compressed_data_path`, starting
    /// `self.context.offset` vectors into the input.
    ///
    /// Returns the `Progress` from the continuation machinery, mapped so that
    /// `Processed(n)` reports the absolute number of vectors compressed so
    /// far — directly usable as the `offset` of a follow-up run.
    ///
    /// # Errors
    /// Fails on invalid offset/length combinations (see `validate_params`),
    /// on storage or I/O errors, or if the quantizer rejects a batch.
    pub fn generate_data<Storage, Pool>(
        &self,
        storage_provider: &Storage,
        pool: &Pool,
        chunking_config: &ChunkingConfig,
    ) -> ANNResult<Progress>
    where
        Storage: StorageReadProvider + StorageWriteProvider,
        Pool: AsThreadPool,
    {
        let timer = Timer::new();

        let metadata = load_metadata_from_file(storage_provider, &self.data_path)?;
        let (num_points, dim) = (metadata.npoints, metadata.ndims);

        self.validate_params(num_points, storage_provider)?;

        let offset = self.context.offset;
        let compressed_path = self.context.compressed_data_path.as_str();

        // A fresh run must not inherit stale output from an earlier attempt.
        if offset == 0 && storage_provider.exists(compressed_path) {
            storage_provider.delete(compressed_path)?;
        }

        info!("Generating quantized data for {}", compressed_path);

        let data_reader = &mut storage_provider.open_reader(&self.data_path)?;

        // Resume appends into the existing file; a fresh run creates it and
        // writes the 2 x i32 metadata header (point count, bytes per vector)
        // first.
        let mut compressed_data_writer = if offset > 0 {
            storage_provider.open_writer(compressed_path)?
        } else {
            let mut sp = storage_provider.create_for_write(compressed_path)?;
            write_metadata(&mut sp, num_points, self.quantizer.compressed_bytes())?;
            sp
        };

        // Skip the input's 2 x i32 metadata header plus the `offset` vectors
        // a previous run already consumed.
        data_reader.seek(SeekFrom::Start(
            (size_of::<i32>() * 2 + offset * dim * size_of::<T>()) as u64,
        ))?;

        let compressed_size = self.quantizer.compressed_bytes();
        let max_block_size = chunking_config.data_compression_chunk_vector_count;
        let num_remaining = num_points - offset;

        // NOTE(review): clamps by num_points rather than num_remaining. This
        // is harmless because end_index below is clamped to num_points, but it
        // means the final (or only) block on resume may be smaller than
        // block_size.
        let block_size = std::cmp::min(num_points, max_block_size);
        // Ceiling division: one extra block covers a partial tail.
        let num_blocks =
            num_remaining / block_size + !num_remaining.is_multiple_of(block_size) as usize;

        info!(
            "Compressing with block size {}, num_remaining {}, num_blocks {}, offset {}, num_points {}",
            block_size, num_remaining, num_blocks, offset, num_points
        );

        // Scratch output buffer reused for every block; each iteration only
        // touches the prefix matching the current block's row count.
        let mut compressed_buffer = vec![0_u8; block_size * compressed_size];

        forward_threadpool!(pool = pool: Pool);
        // Per-block work: read raw vectors sequentially, widen to f32,
        // compress in parallel batches, then write at the block's byte offset.
        // Relies on blocks being visited in ascending order, since
        // `data_reader` is seeked only once, before the loop.
        let action = |block_index| -> ANNResult<()> {
            let start_index: usize = offset + block_index * block_size;
            let end_index: usize = std::cmp::min(start_index + block_size, num_points);
            let cur_block_size: usize = end_index - start_index;

            let block_compressed_base = &mut compressed_buffer[..cur_block_size * compressed_size];

            let raw_block: Vec<T> =
                diskann::utils::read_exact_into(data_reader, cur_block_size * dim)?;

            // `full_dim` may differ from `dim` for packed representations of
            // T — assumption based on VectorRepr's API; TODO confirm.
            let full_dim = T::full_dimension(&raw_block[..dim]).into_ann_result()?;
            let mut block_data: Vec<f32> = vec![f32::default(); cur_block_size * full_dim];
            // Convert each stored vector into its f32 expansion.
            for (v, dst) in raw_block
                .chunks_exact(dim)
                .zip(block_data.chunks_exact_mut(full_dim))
            {
                T::as_f32_into(v, dst).into_ann_result()?;
            }

            // Rows handed to each parallel worker per compress() call.
            const BATCH_SIZE: usize = 128;

            let mut compressed_block = views::MutMatrixView::try_from(
                block_compressed_base,
                cur_block_size,
                compressed_size,
            )
            .bridge_err()?;
            let base_block =
                views::MatrixView::try_from(&block_data, cur_block_size, full_dim).bridge_err()?;
            // zip_eq panics on length mismatch; both views have
            // cur_block_size rows, so the window counts always agree.
            base_block
                .par_window_iter(BATCH_SIZE)
                .zip_eq(compressed_block.par_window_iter_mut(BATCH_SIZE))
                .try_for_each_in_pool(pool, |(src, dst)| self.quantizer.compress(src, dst))?;

            // Absolute byte position of this block in the output file:
            // metadata header plus all previously compressed vectors.
            let write_offset = start_index * compressed_size + std::mem::size_of::<i32>() * 2;
            compressed_data_writer.seek(SeekFrom::Start(write_offset as u64))?;
            compressed_data_writer.write_all(block_compressed_base)?;
            // Flush per block so an interrupted run leaves a resumable file.
            compressed_data_writer.flush()?;
            Ok(())
        };

        // Run `action` over the block indices until completion or until the
        // continuation checker revokes the resource, then translate "blocks
        // processed" back into an absolute vector count for checkpointing.
        let progress = process_while_resource_is_available(
            action,
            0..num_blocks,
            chunking_config.continuation_checker.clone_box(),
        )?
        .map(|processed| processed * block_size + offset);

        info!(
            "Quant data generation took {} seconds",
            timer.elapsed().as_secs_f64()
        );

        Ok(progress)
    }

    /// Sanity-checks resume parameters against the input size and any
    /// pre-existing compressed output.
    ///
    /// # Errors
    /// - offset beyond the number of input points;
    /// - resume (`offset > 0`) requested but no compressed file exists;
    /// - existing compressed file whose length does not equal
    ///   `2 * size_of::<i32>() + offset * compressed_bytes`.
    fn validate_params<Storage: StorageReadProvider + StorageWriteProvider>(
        &self,
        num_points: usize,
        storage_provider: &Storage,
    ) -> ANNResult<()> {
        if self.context.offset > num_points {
            return Err(ANNError::log_pq_error(
                "Error: offset for compression is more than number of points",
            ));
        }

        let compressed_path = &self.context.compressed_data_path;

        if self.context.offset > 0 {
            if !storage_provider.exists(compressed_path) {
                return Err(ANNError::log_file_not_found_error(format!(
                    "Error: Generator expected compressed file {compressed_path} but did not find it."
                )));
            }
            // A valid resume file holds exactly the header plus `offset`
            // compressed vectors.
            let expected_length = self.quantizer.compressed_bytes() * self.context.offset
                + std::mem::size_of::<i32>() * 2;
            let existing_length =
                storage_provider.get_length(&self.context.compressed_data_path)?;

            if existing_length != expected_length as u64 {
                return Err(ANNError::log_pq_error(format_args!(
                    "Error: compressed data file length {existing_length} does not match expected length {expected_length}."
                )));
            }
        }

        Ok(())
    }
}
268
269#[cfg(test)]
274mod generator_tests {
275 use std::{
276 io::BufReader,
277 sync::{Arc, RwLock},
278 };
279
280 use diskann::utils::read_exact_into;
281 use diskann_providers::storage::VirtualStorageProvider;
282 use diskann_providers::utils::{
283 create_thread_pool_for_test, read_metadata, save_bin_f32, save_bytes,
284 };
285 use rstest::rstest;
286 use vfs::{FileSystem, MemoryFS};
287
288 use super::*;
289 use crate::build::chunking::continuation::{
290 ContinuationGrant, ContinuationTrackerTrait, NaiveContinuationTracker,
291 };
292
293 pub struct DummyCompressor {
294 pub output_dim: u32,
295 pub code: Vec<u8>,
296 }
297 impl DummyCompressor {
298 pub fn new(output_dim: u32) -> Self {
299 Self {
300 output_dim,
301 code: (0..output_dim).map(|x| (x % 256) as u8).collect(),
302 }
303 }
304 }
    impl QuantCompressor<f32> for DummyCompressor {
        // The context is just the desired output width in bytes.
        type CompressorContext = u32;

        /// Stage is ignored: the dummy has no trained state to start or
        /// resume from.
        fn new_at_stage(
            _stage: CompressionStage,
            context: &Self::CompressorContext,
        ) -> ANNResult<Self> {
            Ok(Self::new(*context))
        }

        /// Ignores the input vectors and stamps `self.code` into every output
        /// row, so tests can check the generated file's contents exactly.
        fn compress(
            &self,
            _vector: views::MatrixView<f32>,
            mut output: views::MutMatrixView<u8>,
        ) -> ANNResult<()> {
            output
                .row_iter_mut()
                .for_each(|r| r.copy_from_slice(&self.code));
            Ok(())
        }

        fn compressed_bytes(&self) -> usize {
            self.output_dim as usize
        }
    }
330
331 fn create_test_data(num_points: usize, dim: usize) -> Vec<f32> {
332 let mut data = Vec::new();
333
334 for i in 0..num_points {
336 for j in 0..dim {
337 data.push((i * dim + j) as f32);
338 }
339 }
340
341 data
342 }
343
344 struct MockStopContinuationChecker {
346 count: Arc<RwLock<usize>>,
347 stop_count: usize,
348 }
349
350 impl Clone for MockStopContinuationChecker {
351 fn clone(&self) -> Self {
352 MockStopContinuationChecker {
353 count: self.count.clone(),
354 stop_count: self.stop_count,
355 }
356 }
357 }
358
    impl ContinuationTrackerTrait for MockStopContinuationChecker {
        /// Increments the shared call counter and returns `Stop` on every
        /// `stop_count`-th call, `Continue` otherwise — deterministically
        /// forcing the generator to checkpoint mid-run.
        fn get_continuation_grant(&self) -> ContinuationGrant {
            let mut count = self.count.write().unwrap();
            *count += 1;
            if !(*count).is_multiple_of(self.stop_count) {
                ContinuationGrant::Continue
            } else {
                ContinuationGrant::Stop
            }
        }
    }
370
    /// Creates an in-memory storage provider holding a raw f32 data file and,
    /// when `offset > 0`, a partially written compressed file containing
    /// `offset` rows of the dummy code pattern — simulating a run that was
    /// interrupted after compressing `offset` vectors.
    ///
    /// Returns the provider plus the data-file and compressed-file paths.
    fn generate_data_and_compressed(
        num_points: usize,
        dim: usize,
        offset: usize,
        output_dim: u32,
    ) -> ANNResult<(VirtualStorageProvider<MemoryFS>, String, String)> {
        let storage_provider = VirtualStorageProvider::new_memory();
        storage_provider
            .filesystem()
            .create_dir("/test_data")
            .expect("Could not create test directory");

        let data_path = "/test_data/test_data.bin".to_string();
        let compressed_path = "/test_data/test_compressed.bin".to_string();

        let _ = save_bin_f32(
            &mut storage_provider.create_for_write(data_path.as_str())?,
            &create_test_data(num_points, dim),
            num_points,
            dim,
            0,
        )?;

        if offset > 0 {
            // Pre-seed `offset` compressed rows with the same byte pattern
            // DummyCompressor emits, so a resumed run sees consistent data.
            let code = (0..output_dim).map(|x| (x % 256) as u8).collect::<Vec<_>>();
            let mut buffer = vec![0_u8; offset * output_dim as usize];
            buffer
                .chunks_exact_mut(output_dim as usize)
                .for_each(|bf| bf.copy_from_slice(code.as_slice()));
            // NOTE(review): the metadata passed to save_bytes claims
            // `num_points` rows even though only `offset` rows are written —
            // presumably matching the header generate_data itself writes up
            // front; validate_params only checks the file's byte length.
            let _ = save_bytes(
                &mut storage_provider.create_for_write(compressed_path.as_str())?,
                buffer.as_slice(),
                num_points,
                output_dim as usize,
                0,
            )?;
        }

        Ok((storage_provider, data_path, compressed_path))
    }
414
    /// Builds a `QuantDataGenerator<f32, DummyCompressor>` with the given
    /// offset/paths and runs `generate_data` on a fresh test thread pool.
    ///
    /// Returns the generator itself (so callers can inspect
    /// `quantizer.code`) together with the raw generation result.
    fn create_and_call_generator<F: vfs::FileSystem>(
        offset: usize,
        compressed_path: String,
        storage_provider: &VirtualStorageProvider<F>,
        data_path: String,
        output_dim: u32,
        chunking_config: &ChunkingConfig,
    ) -> (
        QuantDataGenerator<f32, DummyCompressor>,
        Result<Progress, ANNError>,
    ) {
        let pool: diskann_providers::utils::RayonThreadPool = create_thread_pool_for_test();
        let context = GeneratorContext::new(offset, compressed_path.clone());
        let generator = QuantDataGenerator::<f32, DummyCompressor>::new(
            data_path.clone(),
            context,
            &output_dim,
        )
        .unwrap();
        // NOTE(review): `&&pool` passes a double reference; presumably
        // AsThreadPool is implemented for references — confirm whether a
        // single `&pool` would suffice.
        let result = generator.generate_data(storage_provider, &&pool, chunking_config);
        (generator, result)
    }
439
    /// End-to-end generation from a given offset: verifies the output file
    /// exists, has the expected byte length, carries the right metadata, and
    /// that every compressed row equals the dummy code pattern.
    ///
    /// Case tuple: (num_points, dim, output_dim, offset, stop_count,
    /// expected_size). With stop_count = 2 the mock checker halts the run
    /// after one 10_000-vector chunk, so only a prefix of the output (plus
    /// any pre-seeded `offset` rows) is expected.
    #[rstest]
    #[case(100, 8, 4, 0, 10, 100 * 4)]
    #[case(100, 8, 4, 50, 10, 100 * 4)]
    #[case(257, 4, 8, 0, 10, 257 * 8)]
    #[case(60_000, 384, 192, 5_000, 10, 60_000 * 192)]
    #[case(60_000, 384, 192, 0, 10, 60_000 * 192)]
    #[case(60_000, 384, 192, 0, 2, 10_000 * 192)]
    #[case(60_000, 384, 192, 1000, 2, 11_000 * 192)]
    fn test_generate_data_from_offset(
        #[case] num_points: usize,
        #[case] dim: usize,
        #[case] output_dim: u32,
        #[case] offset: usize,
        #[case] config_stop_count: usize,
        #[case] expected_size: usize,
    ) -> ANNResult<()> {
        let (storage_provider, data_path, compressed_path) =
            generate_data_and_compressed(num_points, dim, offset, output_dim)?;

        let chunking_config = ChunkingConfig {
            continuation_checker: Box::new(MockStopContinuationChecker {
                count: Arc::new(RwLock::new(0)),
                stop_count: config_stop_count,
            }),
            data_compression_chunk_vector_count: 10_000,
            inmemory_build_chunk_vector_count: 10_000,
        };

        let (generator, result) = create_and_call_generator(
            offset,
            compressed_path.clone(),
            &storage_provider,
            data_path,
            output_dim,
            &chunking_config,
        );

        assert!(result.is_ok(), "Result is not ok, got {:?}", result);
        assert!(storage_provider.exists(&compressed_path));
        // File length = payload written so far plus the 2 x i32 header.
        let file_len = storage_provider.get_length(&compressed_path)? as usize;
        assert_eq!(file_len, expected_size + 2 * std::mem::size_of::<i32>());

        let mut r = storage_provider.open_reader(compressed_path.as_str())?;
        let mut reader = BufReader::new(&mut r);
        let metadata = read_metadata(&mut reader)?;

        let data: Vec<u8> = read_exact_into(&mut reader, expected_size)?;

        // The header always advertises the full point count, even for
        // partially completed (stopped) runs.
        assert_eq!(metadata.ndims as u32, output_dim);
        assert_eq!(metadata.npoints, num_points);

        // Every row the dummy compressor wrote must equal its fixed code.
        data.chunks_exact(output_dim as usize)
            .for_each(|chunk| assert_eq!(chunk, generator.quantizer.code.as_slice()));

        Ok(())
    }
500
    /// Stop/resume round-trip: runs generation with a small chunk size, and
    /// whenever the run reports `Progress::Processed(n)` re-creates the
    /// generator with offset `n` and runs again, looping until
    /// `Progress::Completed`. Then validates the fully assembled output file.
    #[test]
    fn test_stop_and_continue_chunking_config() -> ANNResult<()> {
        let (num_points, dim, output_dim) = (256, 128, 128);
        let chunking_config = ChunkingConfig {
            continuation_checker: Box::<NaiveContinuationTracker>::default(),
            data_compression_chunk_vector_count: 10,
            inmemory_build_chunk_vector_count: 10,
        };
        let (storage_provider, data_path, compressed_path) =
            generate_data_and_compressed(num_points, dim, 0, output_dim)?;
        let (mut generator, mut result) = create_and_call_generator(
            0,
            compressed_path.clone(),
            &storage_provider,
            data_path.clone(),
            output_dim,
            &chunking_config,
        );
        // Keep resuming from the reported absolute offset until done.
        loop {
            match result.as_ref().unwrap() {
                Progress::Completed => break,
                Progress::Processed(num_points) => {
                    (generator, result) = create_and_call_generator(
                        *num_points,
                        compressed_path.clone(),
                        &storage_provider,
                        data_path.clone(),
                        output_dim,
                        &chunking_config,
                    );
                }
            }
        }

        assert!(result.is_ok(), "Result is not ok, got {:?}", result);
        assert!(storage_provider.exists(&compressed_path));
        let file_len = storage_provider.get_length(&compressed_path)? as usize;
        // Full payload plus the 2 x i32 metadata header.
        let expected_size = (num_points * output_dim as usize) + 2 * std::mem::size_of::<i32>();
        assert_eq!(file_len, expected_size,);

        let mut r = storage_provider.open_reader(compressed_path.as_str())?;
        let mut reader = BufReader::new(&mut r);
        let metadata = read_metadata(&mut reader)?;

        let data: Vec<u8> =
            read_exact_into(&mut reader, expected_size - 2 * std::mem::size_of::<i32>())?;

        assert_eq!(metadata.ndims as u32, output_dim);
        assert_eq!(metadata.npoints, num_points);

        // Every compressed row must equal the dummy code pattern.
        data.chunks_exact(output_dim as usize)
            .for_each(|chunk| assert_eq!(chunk, generator.quantizer.code.as_slice()));
        Ok(())
    }
559
    /// Error paths for bad resume offsets.
    ///
    /// Case 1: offset past the end of the input. Case 2: resume at offset 5
    /// against a pre-seeded file holding 15 rows, whose byte length
    /// (15 * 192 + 8 = 2888) fails the expected-length check
    /// (5 * 192 + 8 = 968).
    ///
    /// NOTE(review): `1_1024` reads like a typo for `1_024` (it is 11 024),
    /// but the point count does not affect the length-mismatch message, so
    /// the case still passes.
    #[rstest]
    #[case(
        1_024,
        384,
        192,
        1_025,
        0,
        "offset for compression is more than number of points"
    )]
    #[case(
        1_1024,
        384,
        192,
        5,
        15,
        "compressed data file length 2888 does not match expected length 968."
    )]
    fn test_offset_error_case(
        #[case] num_points: usize,
        #[case] dim: usize,
        #[case] output_dim: u32,
        #[case] offset: usize,
        #[case] error_offset: usize,
        #[case] msg: String,
    ) -> ANNResult<()> {
        assert!(offset > 0);
        let (storage_provider, data_path, compressed_path) =
            generate_data_and_compressed(num_points, dim, error_offset, output_dim)?;

        let (_, result) = create_and_call_generator(
            offset,
            compressed_path,
            &storage_provider,
            data_path,
            output_dim,
            &ChunkingConfig::default(),
        );

        assert!(result.is_err());
        if let Err(e) = result {
            let error_msg = format!("{:?}", e);
            assert!(error_msg.contains(&msg), "{}", &error_msg);
        }

        Ok(())
    }
606}