1use std::{
7 io::{Seek, SeekFrom, Write},
8 marker::PhantomData,
9};
10
11use diskann::{error::IntoANNResult, utils::VectorRepr, ANNError, ANNResult};
12use diskann_providers::storage::{StorageReadProvider, StorageWriteProvider};
13use diskann_providers::{
14 forward_threadpool,
15 utils::{
16 load_metadata_from_file, write_metadata, AsThreadPool, BridgeErr, ParallelIteratorInPool,
17 Timer,
18 },
19};
20use diskann_utils::views::{self};
21use rayon::iter::IndexedParallelIterator;
22use tracing::info;
23
24use crate::{
25 build::chunking::{
26 checkpoint::Progress,
27 continuation::{process_while_resource_is_available, ChunkingConfig},
28 },
29 storage::quant::compressor::{CompressionStage, QuantCompressor},
30};
31
/// Bookkeeping for quantized-data generation: where to (re)start and where
/// the compressed output file lives.
#[derive(Clone, Debug)]
pub struct GeneratorContext {
    /// Number of vectors already compressed; `0` means start from scratch,
    /// any other value means resume after that many vectors.
    pub offset: usize,
    /// Destination path of the compressed data file.
    pub compressed_data_path: String,
}

impl GeneratorContext {
    /// Creates a context from a resume offset and an output path.
    pub fn new(offset: usize, compressed_data_path: String) -> Self {
        GeneratorContext {
            compressed_data_path,
            offset,
        }
    }
}
53
54pub struct QuantDataGenerator<T, Q>
58where
59 T: Copy + VectorRepr,
60 Q: QuantCompressor<T>,
61{
62 pub quantizer: Q,
63 pub data_path: String, pub context: GeneratorContext, phantom: PhantomData<T>,
66}
67
impl<T, Q> QuantDataGenerator<T, Q>
where
    T: Copy + VectorRepr,
    Q: QuantCompressor<T>,
{
    /// Builds a generator for the dataset at `data_path`.
    ///
    /// The compressor is constructed at `CompressionStage::Start` when
    /// `context.offset == 0` and at `CompressionStage::Resume` otherwise,
    /// so an interrupted run can pick up where it left off.
    ///
    /// # Errors
    /// Propagates any error from `Q::new_at_stage`.
    pub fn new(
        data_path: String,
        context: GeneratorContext,
        quantizer_context: &Q::CompressorContext,
    ) -> ANNResult<Self> {
        // A non-zero offset implies a previous partial run, so the
        // compressor must be re-created in its resume stage.
        let stage = match context.offset {
            0 => CompressionStage::Start,
            _ => CompressionStage::Resume,
        };
        let quantizer = Q::new_at_stage(stage, quantizer_context)?;
        Ok(Self {
            data_path,
            context,
            quantizer,
            phantom: PhantomData,
        })
    }

    /// Compresses the dataset block by block, writing the quantized codes to
    /// `context.compressed_data_path`.
    ///
    /// Processing is checkpointed: blocks are handed to
    /// `process_while_resource_is_available`, and the returned `Progress`
    /// either reports `Completed` or the number of vectors processed so far
    /// (so a later call with that value as `context.offset` can resume).
    ///
    /// # Errors
    /// Fails on invalid parameters (see `validate_params`), on any storage
    /// I/O error, or on a compressor error.
    pub fn generate_data<Storage, Pool>(
        &self,
        storage_provider: &Storage, pool: &Pool, chunking_config: &ChunkingConfig, ) -> ANNResult<Progress>
    where
        Storage: StorageReadProvider + StorageWriteProvider,
        Pool: AsThreadPool,
    {
        let timer = Timer::new();

        // The input file starts with a (npoints, ndims) metadata header.
        let metadata = load_metadata_from_file(storage_provider, &self.data_path)?;
        let (num_points, dim) = (metadata.npoints, metadata.ndims);

        self.validate_params(num_points, storage_provider)?;

        let offset = self.context.offset;
        let compressed_path = self.context.compressed_data_path.as_str();

        // A fresh run (offset 0) discards any stale output from earlier runs.
        if offset == 0 && storage_provider.exists(compressed_path) {
            storage_provider.delete(compressed_path)?;
        }

        info!("Generating quantized data for {}", compressed_path);

        let data_reader = &mut storage_provider.open_reader(&self.data_path)?;

        // Resume appends into the existing file; a fresh run creates it and
        // writes the (num_points, compressed_bytes) metadata header first.
        let mut compressed_data_writer = if offset > 0 {
            storage_provider.open_writer(compressed_path)?
        } else {
            let mut sp = storage_provider.create_for_write(compressed_path)?;
            write_metadata(&mut sp, num_points, self.quantizer.compressed_bytes())?;
            sp
        };

        // Skip the two-i32 metadata header plus the `offset` vectors that
        // were already processed in a previous run.
        data_reader.seek(SeekFrom::Start(
            (size_of::<i32>() * 2 + offset * dim * size_of::<T>()) as u64,
        ))?;

        let compressed_size = self.quantizer.compressed_bytes();
        let max_block_size = chunking_config.data_compression_chunk_vector_count;
        let num_remaining = num_points - offset;

        // NOTE(review): block_size is clamped by num_points, not
        // num_remaining — harmless (partial blocks are handled below) but
        // the buffer may be larger than needed when resuming near the end.
        // Also, num_points == 0 would make block_size 0 and the division
        // below panic — presumably prevented upstream; TODO confirm.
        let block_size = std::cmp::min(num_points, max_block_size);
        // Ceiling division: one extra block for any trailing partial block.
        let num_blocks =
            num_remaining / block_size + !num_remaining.is_multiple_of(block_size) as usize;

        info!(
            "Compressing with block size {}, num_remaining {}, num_blocks {}, offset {}, num_points {}",
            block_size, num_remaining, num_blocks, offset, num_points
        );

        // Single output buffer reused for every block.
        let mut compressed_buffer = vec![0_u8; block_size * compressed_size];

        forward_threadpool!(pool = pool: Pool);
        // Compresses one block: read raw vectors, widen to f32, quantize in
        // parallel batches, then write the codes at their absolute offset.
        let action = |block_index| -> ANNResult<()> {
            let start_index: usize = offset + block_index * block_size;
            // The last block may be shorter than block_size.
            let end_index: usize = std::cmp::min(start_index + block_size, num_points);
            let cur_block_size: usize = end_index - start_index;

            let block_compressed_base = &mut compressed_buffer[..cur_block_size * compressed_size];

            let raw_block: Vec<T> =
                diskann::utils::read_exact_into(data_reader, cur_block_size * dim)?;

            // full_dim may differ from dim for packed representations of T
            // — derived from the first vector of the block.
            let full_dim = T::full_dimension(&raw_block[..dim]).into_ann_result()?; let mut block_data: Vec<f32> = vec![f32::default(); cur_block_size * full_dim];
            for (v, dst) in raw_block
                .chunks_exact(dim)
                .zip(block_data.chunks_exact_mut(full_dim))
            {
                T::as_f32_into(v, dst).into_ann_result()?;
            }

            // Number of vectors handed to each parallel compress call.
            const BATCH_SIZE: usize = 128;

            let mut compressed_block = views::MutMatrixView::try_from(
                block_compressed_base,
                cur_block_size,
                compressed_size,
            )
            .bridge_err()?;
            let base_block =
                views::MatrixView::try_from(&block_data, cur_block_size, full_dim).bridge_err()?;
            // zip_eq asserts both sides yield the same number of windows.
            base_block
                .par_window_iter(BATCH_SIZE)
                .zip_eq(compressed_block.par_window_iter_mut(BATCH_SIZE))
                .try_for_each_in_pool(pool, |(src, dst)| self.quantizer.compress(src, dst))?;

            // Absolute position in the output file: metadata header plus all
            // previously written codes.
            let write_offset = start_index * compressed_size + std::mem::size_of::<i32>() * 2;
            compressed_data_writer.seek(SeekFrom::Start(write_offset as u64))?;
            compressed_data_writer.write_all(block_compressed_base)?;
            // Flush per block so an interrupted run leaves a consistent,
            // resumable file.
            compressed_data_writer.flush()?;
            Ok(())
        };

        // Convert the "blocks processed" count into a vector count so it can
        // be fed back in as the resume offset.
        let progress = process_while_resource_is_available(
            action,
            0..num_blocks,
            chunking_config.continuation_checker.clone_box(),
        )?
        .map(|processed| processed * block_size + offset);

        info!(
            "Quant data generation took {} seconds",
            timer.elapsed().as_secs_f64()
        );

        Ok(progress)
    }

    /// Validates the resume offset against the dataset and, when resuming,
    /// checks that the existing compressed file has exactly the length a
    /// previous run would have left behind.
    ///
    /// # Errors
    /// - offset greater than the number of points,
    /// - resuming but the compressed file is missing,
    /// - resuming but the compressed file length does not match the offset.
    fn validate_params<Storage: StorageReadProvider + StorageWriteProvider>(
        &self,
        num_points: usize,
        storage_provider: &Storage,
    ) -> ANNResult<()> {
        if self.context.offset > num_points {
            return Err(ANNError::log_pq_error(
                "Error: offset for compression is more than number of points",
            ));
        }

        let compressed_path = &self.context.compressed_data_path;

        if self.context.offset > 0 {
            if !storage_provider.exists(compressed_path) {
                return Err(ANNError::log_file_not_found_error(format!(
                    "Error: Generator expected compressed file {compressed_path} but did not find it."
                )));
            }
            // Expected length: metadata header + one code per already
            // processed vector.
            let expected_length = self.quantizer.compressed_bytes() * self.context.offset
                + std::mem::size_of::<i32>() * 2;
            let existing_length =
                storage_provider.get_length(&self.context.compressed_data_path)?;

            if existing_length != expected_length as u64 {
                return Err(ANNError::log_pq_error(format_args!(
                    "Error: compressed data file length {existing_length} does not match expected length {expected_length}."
                )));
            }
        }

        Ok(())
    }
}
268
#[cfg(test)]
mod generator_tests {
    use std::{
        io::BufReader,
        sync::{Arc, RwLock},
    };

    use diskann::utils::read_exact_into;
    use diskann_providers::storage::VirtualStorageProvider;
    use diskann_providers::utils::{
        create_thread_pool_for_test, read_metadata, save_bin_f32, save_bytes,
    };
    use rstest::rstest;
    use vfs::{FileSystem, MemoryFS, OverlayFS};

    use super::*;
    use crate::build::chunking::continuation::{
        ContinuationGrant, ContinuationTrackerTrait, NaiveContinuationTracker,
    };

    /// Test compressor that ignores the input and emits a fixed byte
    /// pattern (`0, 1, 2, …` mod 256) of length `output_dim` per vector, so
    /// tests can verify the output file by pattern alone.
    pub struct DummyCompressor {
        pub output_dim: u32,
        pub code: Vec<u8>,
    }
    impl DummyCompressor {
        pub fn new(output_dim: u32) -> Self {
            Self {
                output_dim,
                code: (0..output_dim).map(|x| (x % 256) as u8).collect(),
            }
        }
    }
    impl QuantCompressor<f32> for DummyCompressor {
        type CompressorContext = u32;

        // Stage is irrelevant for the dummy: construction is stateless.
        fn new_at_stage(
            _stage: CompressionStage,
            context: &Self::CompressorContext,
        ) -> ANNResult<Self> {
            Ok(Self::new(*context))
        }

        // Writes the fixed code into every output row, ignoring the input.
        fn compress(
            &self,
            _vector: views::MatrixView<f32>,
            mut output: views::MutMatrixView<u8>,
        ) -> ANNResult<()> {
            output
                .row_iter_mut()
                .for_each(|r| r.copy_from_slice(&self.code));
            Ok(())
        }

        fn compressed_bytes(&self) -> usize {
            self.output_dim as usize
        }
    }

    /// Builds a deterministic dataset: element (i, j) has value i*dim + j.
    fn create_test_data(num_points: usize, dim: usize) -> Vec<f32> {
        let mut data = Vec::new();

        for i in 0..num_points {
            for j in 0..dim {
                data.push((i * dim + j) as f32);
            }
        }

        data
    }

    /// Continuation checker that grants `Stop` on every `stop_count`-th
    /// call, simulating an interrupted run.
    struct MockStopContinuationChecker {
        count: Arc<RwLock<usize>>,
        stop_count: usize,
    }

    impl Clone for MockStopContinuationChecker {
        fn clone(&self) -> Self {
            MockStopContinuationChecker {
                // Clones share the same counter so stop cadence is global.
                count: self.count.clone(),
                stop_count: self.stop_count,
            }
        }
    }

    impl ContinuationTrackerTrait for MockStopContinuationChecker {
        fn get_continuation_grant(&self) -> ContinuationGrant {
            let mut count = self.count.write().unwrap();
            *count += 1;
            if !(*count).is_multiple_of(self.stop_count) {
                ContinuationGrant::Continue
            } else {
                ContinuationGrant::Stop
            }
        }
    }

    /// Creates an in-memory storage provider with a raw data file and, when
    /// `offset > 0`, a partially written compressed file matching what a
    /// prior run with DummyCompressor would have produced.
    fn generate_data_and_compressed(
        num_points: usize,
        dim: usize,
        offset: usize,
        output_dim: u32,
    ) -> ANNResult<(VirtualStorageProvider<OverlayFS>, String, String)> {
        let fs = OverlayFS::new(&[MemoryFS::default().into()]);
        fs.create_dir("/test_data")
            .expect("Could not create test directory");
        let storage_provider = VirtualStorageProvider::new(fs);

        let data_path = "/test_data/test_data.bin".to_string();
        let compressed_path = "/test_data/test_compressed.bin".to_string();

        let _ = save_bin_f32(
            &mut storage_provider.create_for_write(data_path.as_str())?,
            &create_test_data(num_points, dim),
            num_points,
            dim,
            0,
        )?;

        if offset > 0 {
            // Pre-fill `offset` vectors with the dummy code so the resumed
            // run's length validation passes.
            let code = (0..output_dim).map(|x| (x % 256) as u8).collect::<Vec<_>>(); let mut buffer = vec![0_u8; offset * output_dim as usize];
            buffer
                .chunks_exact_mut(output_dim as usize)
                .for_each(|bf| bf.copy_from_slice(code.as_slice()));
            let _ = save_bytes(
                &mut storage_provider.create_for_write(compressed_path.as_str())?,
                buffer.as_slice(),
                num_points,
                output_dim as usize,
                0,
            )?;
        }

        Ok((storage_provider, data_path, compressed_path))
    }

    /// Constructs a generator with a DummyCompressor and runs one
    /// `generate_data` call, returning both so tests can inspect the result
    /// and the compressor's expected code.
    fn create_and_call_generator(
        offset: usize,
        compressed_path: String,
        storage_provider: &VirtualStorageProvider<OverlayFS>,
        data_path: String,
        output_dim: u32,
        chunking_config: &ChunkingConfig,
    ) -> (
        QuantDataGenerator<f32, DummyCompressor>,
        Result<Progress, ANNError>,
    ) {
        let pool: diskann_providers::utils::RayonThreadPool = create_thread_pool_for_test();
        let context = GeneratorContext::new(offset, compressed_path.clone());
        let generator = QuantDataGenerator::<f32, DummyCompressor>::new(
            data_path.clone(),
            context,
            &output_dim,
        )
        .unwrap();
        let result = generator.generate_data(storage_provider, &&pool, chunking_config);
        (generator, result)
    }

    // Cases cover: fresh runs, resumed runs, non-multiple block counts,
    // large datasets, and early stops (stop_count 2 limits processed size).
    #[rstest]
    #[case(100, 8, 4, 0, 10, 100 * 4)] #[case(100, 8, 4, 50, 10, 100 * 4)] #[case(257, 4, 8, 0, 10, 257 * 8)] #[case(60_000, 384, 192, 5_000, 10, 60_000 * 192)] #[case(60_000, 384, 192, 0, 10, 60_000 * 192)] #[case(60_000, 384, 192, 0, 2, 10_000 * 192)] #[case(60_000, 384, 192, 1000, 2, 11_000 * 192)] fn test_generate_data_from_offset(
        #[case] num_points: usize,
        #[case] dim: usize,
        #[case] output_dim: u32,
        #[case] offset: usize,
        #[case] config_stop_count: usize,
        #[case] expected_size: usize,
    ) -> ANNResult<()> {
        let (storage_provider, data_path, compressed_path) =
            generate_data_and_compressed(num_points, dim, offset, output_dim)?;

        let chunking_config = ChunkingConfig {
            continuation_checker: Box::new(MockStopContinuationChecker {
                count: Arc::new(RwLock::new(0)),
                stop_count: config_stop_count,
            }),
            data_compression_chunk_vector_count: 10_000,
            inmemory_build_chunk_vector_count: 10_000,
        };

        let (generator, result) = create_and_call_generator(
            offset,
            compressed_path.clone(),
            &storage_provider,
            data_path,
            output_dim,
            &chunking_config,
        );

        // File length = processed codes + two-i32 metadata header.
        assert!(result.is_ok(), "Result is not ok, got {:?}", result); assert!(storage_provider.exists(&compressed_path)); let file_len = storage_provider.get_length(&compressed_path)? as usize;
        assert_eq!(file_len, expected_size + 2 * std::mem::size_of::<i32>());

        let mut r = storage_provider.open_reader(compressed_path.as_str())?;
        let mut reader = BufReader::new(&mut r);
        let metadata = read_metadata(&mut reader)?;

        let data: Vec<u8> = read_exact_into(&mut reader, expected_size)?;

        assert_eq!(metadata.ndims as u32, output_dim);
        assert_eq!(metadata.npoints, num_points);

        // Every vector's code must be the dummy pattern.
        data.chunks_exact(output_dim as usize)
            .for_each(|chunk| assert_eq!(chunk, generator.quantizer.code.as_slice()));

        Ok(())
    }

    /// Repeatedly restarts generation from the reported progress offset
    /// until completion, verifying the stop/resume loop produces the same
    /// output as one uninterrupted run.
    #[test]
    fn test_stop_and_continue_chunking_config() -> ANNResult<()> {
        let (num_points, dim, output_dim) = (256, 128, 128);
        let chunking_config = ChunkingConfig {
            continuation_checker: Box::<NaiveContinuationTracker>::default(),
            data_compression_chunk_vector_count: 10,
            inmemory_build_chunk_vector_count: 10,
        };
        let (storage_provider, data_path, compressed_path) =
            generate_data_and_compressed(num_points, dim, 0, output_dim)?;
        let (mut generator, mut result) = create_and_call_generator(
            0,
            compressed_path.clone(),
            &storage_provider,
            data_path.clone(),
            output_dim,
            &chunking_config,
        );
        loop {
            match result.as_ref().unwrap() {
                Progress::Completed => break,
                // Resume from the reported vector count.
                Progress::Processed(num_points) => {
                    (generator, result) = create_and_call_generator(
                        *num_points,
                        compressed_path.clone(),
                        &storage_provider,
                        data_path.clone(),
                        output_dim,
                        &chunking_config,
                    );
                }
            }
        }

        assert!(result.is_ok(), "Result is not ok, got {:?}", result); assert!(storage_provider.exists(&compressed_path)); let file_len = storage_provider.get_length(&compressed_path)? as usize;
        let expected_size = (num_points * output_dim as usize) + 2 * std::mem::size_of::<i32>();
        assert_eq!(file_len, expected_size,);

        let mut r = storage_provider.open_reader(compressed_path.as_str())?;
        let mut reader = BufReader::new(&mut r);
        let metadata = read_metadata(&mut reader)?;

        let data: Vec<u8> =
            read_exact_into(&mut reader, expected_size - 2 * std::mem::size_of::<i32>())?;

        assert_eq!(metadata.ndims as u32, output_dim);
        assert_eq!(metadata.npoints, num_points);

        data.chunks_exact(output_dim as usize)
            .for_each(|chunk| assert_eq!(chunk, generator.quantizer.code.as_slice()));
        Ok(())
    }

    // First case: offset beyond dataset size. Second case: pre-existing
    // compressed file length (from error_offset vectors) disagrees with the
    // requested resume offset.
    #[rstest]
    #[case(
        1_024,
        384,
        192,
        1_025,
        0,
        "offset for compression is more than number of points"
    )]
    #[case(
        1_1024,
        384,
        192,
        5,
        15,
        "compressed data file length 2888 does not match expected length 968."
    )]
    fn test_offset_error_case(
        #[case] num_points: usize,
        #[case] dim: usize,
        #[case] output_dim: u32,
        #[case] offset: usize,
        #[case] error_offset: usize,
        #[case] msg: String,
    ) -> ANNResult<()> {
        assert!(offset > 0);
        let (storage_provider, data_path, compressed_path) =
            generate_data_and_compressed(num_points, dim, error_offset, output_dim)?;

        let (_, result) = create_and_call_generator(
            offset,
            compressed_path,
            &storage_provider,
            data_path,
            output_dim,
            &ChunkingConfig::default(),
        );

        assert!(result.is_err());
        if let Err(e) = result {
            let error_msg = format!("{:?}", e);
            assert!(error_msg.contains(&msg), "{}", &error_msg);
        }

        Ok(())
    }
}