1use std::{
5 io::{Read, Seek, Write},
6 mem,
7};
8
9use byteorder::{ByteOrder, LittleEndian, ReadBytesExt};
10use diskann::{ANNError, ANNResult};
11use diskann_providers::storage::{StorageReadProvider, StorageWriteProvider};
12use diskann_providers::{
13 model::graph::traits::GraphDataType,
14 storage::{get_mem_index_file, path_utility::*},
15 utils::{save_bytes, READ_WRITE_BLOCK_SIZE},
16};
17use tracing::info;
18
19use crate::{
20 data_model::{GraphHeader, GraphMetadata},
21 storage::{CachedReader, CachedWriter},
22};
23
/// Mutable working state shared by the private helpers of [`DiskIndexWriter`]
/// while a disk layout is being produced.
struct DiskIndexWriterState<StorageProvider>
where
    StorageProvider: StorageReadProvider,
{
    /// Reader over the in-memory index file; `None` until `open_vamana_reader` has run.
    muti_shard_index_reader: Option<StorageProvider::Reader>,

    /// Reader over the associated-data file; stays `None` when the writer was
    /// configured without one.
    associated_data_reader: Option<CachedReader<StorageProvider>>,

    /// Reader over the dataset file; `None` until `open_dataset_reader` has run.
    dataset_reader: Option<CachedReader<StorageProvider>>,

    /// Second `u32` of the dataset file header.
    dims: u64,
    /// First `u32` of the dataset file header.
    num_pts: u64,
    /// Read from the in-memory index file header.
    max_degree: u32,
    /// Read from the in-memory index file header.
    medoid: u32,
    /// Read from the in-memory index file header.
    vamana_frozen_num: u64,
    /// Size in bytes of one node record; computed in `create_disk_layout`.
    node_len: u64,
    /// Second `u32` of the associated-data file header; multiplied by the size of
    /// `Data::AssociatedDataType` wherever a byte count is needed.
    associated_data_length: usize,
    /// Block size handed to every `CachedReader`.
    read_blk_size: u64,
    /// Block size handed to the `CachedWriter`.
    write_blk_size: u64,
}
53
54impl<StorageProvider> DiskIndexWriterState<StorageProvider>
55where
56 StorageProvider: StorageReadProvider,
57{
58 fn new() -> Self {
60 DiskIndexWriterState {
61 muti_shard_index_reader: None,
62 associated_data_reader: None,
63 dataset_reader: None,
64 dims: 0,
65 num_pts: 0,
66 max_degree: 0,
67 medoid: 0,
68 vamana_frozen_num: 0,
69 node_len: 0,
70 associated_data_length: 0,
71 read_blk_size: READ_WRITE_BLOCK_SIZE,
72 write_blk_size: READ_WRITE_BLOCK_SIZE,
73 }
74 }
75}
76
/// Writes the on-disk index layout from a dataset file and an in-memory index file.
pub struct DiskIndexWriter {
    /// Path of the dataset file the vectors are read from.
    dataset_file: String,

    /// Prefix from which all index-related file names are derived.
    index_path_prefix: String,

    /// Optional path of the per-point associated-data file.
    associated_data_file: Option<String>,

    /// Size in bytes of one block of the disk layout.
    block_size: usize,
}
92
93impl DiskIndexWriter {
94 pub fn new(
96 dataset_file: String,
97 index_path_prefix: String,
98 associated_data_file: Option<String>,
99 block_size: usize,
100 ) -> ANNResult<Self> {
101 if block_size < GraphMetadata::get_size() {
102 return Err(ANNError::log_index_config_error(
103 "index_block_size".to_string(),
104 format!(
105 "block_size should be greater than the size of GraphMetadata: {}",
106 GraphMetadata::get_size()
107 ),
108 ));
109 }
110
111 Ok(DiskIndexWriter {
112 dataset_file,
113 index_path_prefix,
114 associated_data_file,
115 block_size,
116 })
117 }
118
119 pub fn dataset_file(&self) -> &String {
120 &self.dataset_file
121 }
122
123 pub fn index_path_prefix(&self) -> &String {
124 &self.index_path_prefix
125 }
126
127 pub fn block_size(&self) -> usize {
128 self.block_size
129 }
130
131 fn get_neighbors_number<StorageProvider>(
132 state: &mut DiskIndexWriterState<StorageProvider>,
133 ) -> ANNResult<u32>
134 where
135 StorageProvider: StorageReadProvider,
136 {
137 let num_nbrs: u32;
138 if let Some(vamana_reader) = state.muti_shard_index_reader.as_mut() {
139 num_nbrs = vamana_reader.read_u32::<LittleEndian>()?;
140 } else {
141 return Err(ANNError::log_index_error("invalid index reader"));
142 }
143
144 Ok(num_nbrs)
145 }
146
147 fn copy_neighbors<StorageProvider>(
148 state: &mut DiskIndexWriterState<StorageProvider>,
149 nbrs_buf: &mut [u8],
150 ) -> ANNResult<()>
151 where
152 StorageProvider: StorageReadProvider,
153 {
154 if let Some(vamana_reader) = state.muti_shard_index_reader.as_mut() {
155 vamana_reader.read_exact(nbrs_buf)?;
156 } else {
157 return Err(ANNError::log_index_error("invalid index reader"));
158 }
159
160 Ok(())
161 }
162
163 fn open_vamana_reader<StorageProvider>(
164 &self,
165 state: &mut DiskIndexWriterState<StorageProvider>,
166 storage_provider: &StorageProvider,
167 ) -> ANNResult<()>
168 where
169 StorageProvider: StorageReadProvider,
170 {
171 let mem_index_file = self.get_mem_index_file();
172
173 let actual_file_size = storage_provider.get_length(mem_index_file.as_str())?;
175 info!("Vamana index file size={}", actual_file_size);
176
177 state.muti_shard_index_reader =
178 Some(storage_provider.open_reader(mem_index_file.as_str())?);
179
180 if let Some(vamana_reader) = state.muti_shard_index_reader.as_mut() {
181 let index_file_size = vamana_reader.read_u64::<LittleEndian>()?;
182 if index_file_size != actual_file_size {
183 info!(
184 "Vamana Index file size does not match expected size per meta-data. file size from file: {}, actual file size: {}",
185 index_file_size, actual_file_size
186 );
187 }
188
189 state.max_degree = vamana_reader.read_u32::<LittleEndian>()?;
190 state.medoid = vamana_reader.read_u32::<LittleEndian>()?;
191 state.vamana_frozen_num = vamana_reader.read_u64::<LittleEndian>()?;
192
193 return Ok(());
194 }
195
196 Err(ANNError::log_index_error("invalid index reader"))
197 }
198
199 fn open_associated_data_reader<StorageProvider>(
200 &self,
201 state: &mut DiskIndexWriterState<StorageProvider>,
202 storage_provider: &StorageProvider,
203 ) -> ANNResult<()>
204 where
205 StorageProvider: StorageReadProvider,
206 {
207 (state.associated_data_reader, state.associated_data_length) = match &self
208 .associated_data_file
209 {
210 Some(associated_data_stream) => {
211 let mut associated_data_reader = CachedReader::<StorageProvider>::new(
212 associated_data_stream.as_str(),
213 state.read_blk_size,
214 storage_provider,
215 )?;
216
217 let associated_data_num_pts = associated_data_reader.read_u32()? as u64;
218 let length = associated_data_reader.read_u32()? as usize;
219
220 if state.num_pts != associated_data_num_pts {
221 return Err(ANNError::log_index_error(format_args!(
222 "Number of points in dataset file ({}) does not match number of points in associated data file ({}).",
223 state.num_pts, associated_data_num_pts
224 )));
225 }
226
227 (Option::Some(associated_data_reader), length)
228 }
229
230 None => (Option::None, 0),
231 };
232
233 Ok(())
234 }
235
236 fn open_dataset_reader<StorageProvider>(
237 &self,
238 state: &mut DiskIndexWriterState<StorageProvider>,
239 storage_provider: &StorageProvider,
240 ) -> ANNResult<()>
241 where
242 StorageProvider: StorageReadProvider,
243 {
244 let dataset_reader = CachedReader::<StorageProvider>::new(
245 self.dataset_file.as_str(),
246 state.read_blk_size,
247 storage_provider,
248 )?;
249 state.dataset_reader = Some(dataset_reader);
250
251 if let Some(dataset_reader) = state.dataset_reader.as_mut() {
252 state.num_pts = dataset_reader.read_u32()? as u64;
253 state.dims = dataset_reader.read_u32()? as u64;
254 }
255
256 Ok(())
257 }
258
259 fn read_neighbors<Data, StorageProvider>(
260 &self,
261 state: &mut DiskIndexWriterState<StorageProvider>,
262 block_buf: &mut [u8],
263 ) -> ANNResult<()>
264 where
265 Data: GraphDataType<VectorIdType = u32>,
266 StorageProvider: StorageReadProvider,
267 {
268 block_buf.fill(0);
269
270 if let Some(dataset_reader) = state.dataset_reader.as_mut() {
272 let mut cur_node_coords =
273 vec![0u8; (state.dims as usize) * mem::size_of::<Data::VectorDataType>()];
274
275 dataset_reader.read(&mut cur_node_coords)?;
276 block_buf[..cur_node_coords.len()].copy_from_slice(&cur_node_coords);
277 }
278
279 let num_nbrs: u32 = Self::get_neighbors_number(state)?;
281
282 debug_assert!(num_nbrs > 0);
284 debug_assert!(num_nbrs <= state.max_degree);
285
286 let num_nbrs_start = state.dims as usize * mem::size_of::<Data::VectorDataType>();
287 let nbrs_buf_start = num_nbrs_start + mem::size_of::<u32>();
288
289 LittleEndian::write_u32(
291 &mut block_buf[num_nbrs_start..(num_nbrs_start + mem::size_of::<u32>())],
292 num_nbrs,
293 );
294
295 let nbr_buf_end = nbrs_buf_start + (num_nbrs as usize) * mem::size_of::<u32>();
297
298 let nbrs_buf = &mut block_buf[nbrs_buf_start..nbr_buf_end];
299
300 Self::copy_neighbors(state, nbrs_buf)?;
301
302 if let Some(associated_data_reader) = state.associated_data_reader.as_mut() {
303 let cur_node_associated_data = &mut block_buf[(state.node_len as usize
304 - state.associated_data_length * mem::size_of::<Data::AssociatedDataType>())
305 ..(state.node_len as usize)];
306 associated_data_reader.read(cur_node_associated_data)?;
307 }
308
309 Ok(())
310 }
311
312 fn write_header<Data, Writer, StorageProvider>(
313 &self,
314 state: &mut DiskIndexWriterState<StorageProvider>,
315 block_size: usize,
316 num_nodes_per_block: u64,
317 disk_index_file_size: u64,
318 writer: &mut Writer,
319 ) -> ANNResult<()>
320 where
321 Data: GraphDataType<VectorIdType = u32>,
322 Writer: Write + Seek,
323 StorageProvider: StorageReadProvider,
324 {
325 let mut vamana_frozen_loc: u64 = 0;
326 if state.vamana_frozen_num == 1 {
327 vamana_frozen_loc = state.medoid as u64;
328 }
329
330 let graph_metadata = GraphMetadata::new(
331 state.num_pts,
332 state.dims as usize,
333 state.medoid as u64,
334 state.node_len,
335 num_nodes_per_block,
336 state.vamana_frozen_num,
337 vamana_frozen_loc,
338 disk_index_file_size,
339 state.associated_data_length * mem::size_of::<Data::AssociatedDataType>(),
340 );
341
342 let header = GraphHeader::new(
343 graph_metadata,
344 block_size as u64,
345 GraphHeader::CURRENT_LAYOUT_VERSION,
346 );
347
348 let bytes_header = header.to_bytes()?;
349 save_bytes(
350 writer,
351 bytes_header.as_slice(),
352 bytes_header.len(), 1, 0,
355 )?;
356
357 Ok(())
358 }
359
360 pub fn create_disk_layout<Data, StorageProvider>(
396 &self,
397 storage_provider: &StorageProvider,
398 ) -> ANNResult<()>
399 where
400 Data: GraphDataType<VectorIdType = u32>,
401 StorageProvider: StorageReadProvider + StorageWriteProvider,
402 {
403 let block_size = self.block_size;
404 let mut state: DiskIndexWriterState<StorageProvider> = DiskIndexWriterState::new();
405
406 self.open_dataset_reader(&mut state, storage_provider)?;
407 self.open_associated_data_reader(&mut state, storage_provider)?;
408 self.open_vamana_reader(&mut state, storage_provider)?;
409
410 let vector_size = state.dims * (mem::size_of::<Data::VectorDataType>() as u64);
411
412 state.node_len = ((state.max_degree as u64 + 1) * (mem::size_of::<u32>() as u64))
413 + vector_size
414 + (state.associated_data_length * mem::size_of::<Data::AssociatedDataType>()) as u64;
415
416 let num_nodes_per_block = (block_size as u64) / state.node_len; info!("block_size: {}B", block_size);
419 info!("medoid: {}B", state.medoid);
420 info!("node_len: {}B", state.node_len);
421 info!("num_nodes_per_sector: {}B", num_nodes_per_block);
422 info!(
423 "associated_data_length: {}B",
424 state.associated_data_length * mem::size_of::<Data::AssociatedDataType>()
425 );
426
427 let num_blocks = if num_nodes_per_block > 0 {
429 state.num_pts.div_ceil(num_nodes_per_block)
430 } else {
431 let num_block_per_node = state.node_len.div_ceil(block_size as u64);
432 info!("num_sector_per_node: {}B", num_block_per_node);
433 state.num_pts * num_block_per_node
434 };
435 info!("num_blocks: {}B", num_blocks);
436
437 let disk_layout_file = self.disk_index_file();
438 {
439 let storage_writer = storage_provider.create_for_write(disk_layout_file.as_str())?;
440 let mut diskann_writer = CachedWriter::<StorageProvider>::new(
441 disk_layout_file.as_str(),
442 state.write_blk_size,
443 storage_writer,
444 )?;
445
446 let mut block_buf = vec![0u8; block_size];
448 diskann_writer.write(&block_buf)?;
449
450 if num_nodes_per_block > 0 {
451 let mut cur_node_id = 0u64;
452 let mut node_buf = vec![0u8; state.node_len as usize];
453
454 for sector in 0..num_blocks {
456 if sector % 100_000 == 0 {
457 info!("Sector #{} written", sector);
458 }
459 block_buf.fill(0);
460
461 for sector_node_id in 0..num_nodes_per_block {
462 if cur_node_id >= state.num_pts {
463 break;
464 }
465
466 self.read_neighbors::<Data, _>(&mut state, &mut node_buf)?;
467
468 let sector_node_buf_start = (sector_node_id * state.node_len) as usize;
470 let sector_node_buf = &mut block_buf[sector_node_buf_start
471 ..(sector_node_buf_start + state.node_len as usize)];
472 sector_node_buf.copy_from_slice(&node_buf[..(state.node_len as usize)]);
473
474 cur_node_id += 1;
475 }
476
477 diskann_writer.write(&block_buf)?;
479 }
480 } else {
481 let mut multi_block_buf =
483 vec![0u8; state.node_len.next_multiple_of(block_size as u64) as usize];
484 let num_block_per_node = state.node_len.div_ceil(block_size as u64);
485
486 for node_idx in 0..state.num_pts {
487 if (node_idx * num_block_per_node).is_multiple_of(100_000) {
488 info!("Sector #{} written", node_idx * num_block_per_node);
489 }
490
491 self.read_neighbors::<Data, _>(&mut state, &mut multi_block_buf)?;
492
493 diskann_writer.write(&multi_block_buf)?;
495 }
496 }
497
498 diskann_writer.flush()?;
500 }
501
502 {
506 let mut storage_writer = storage_provider.open_writer(disk_layout_file.as_str())?;
507 let disk_index_file_size = (num_blocks + 1) * (block_size as u64);
508 self.write_header::<Data, _, _>(
509 &mut state,
510 block_size,
511 num_nodes_per_block,
512 disk_index_file_size,
513 &mut storage_writer,
514 )?;
515
516 storage_writer.flush()?;
517 Ok(())
518 }
519 }
520
521 pub fn index_build_cleanup<StorageProvider>(
522 &self,
523 storage_provider: &StorageProvider,
524 ) -> ANNResult<()>
525 where
526 StorageProvider: StorageReadProvider + StorageWriteProvider,
527 {
528 let inmem_index_identifier = self.get_mem_index_file();
530 if storage_provider.exists(&inmem_index_identifier) {
531 storage_provider.delete(&get_mem_index_file(&self.index_path_prefix))?;
532 }
533
534 Ok(())
535 }
536
537 pub fn disk_index_file(&self) -> String {
538 get_disk_index_file(&self.index_path_prefix)
539 }
540
541 pub fn get_pq_pivot_file(&self) -> String {
542 get_pq_pivot_file(&self.index_path_prefix)
543 }
544
545 pub fn get_compressed_pq_pivot_file(&self) -> String {
546 get_compressed_pq_file(&self.index_path_prefix)
547 }
548
549 pub fn get_disk_index_pq_pivot_file(&self) -> String {
550 get_disk_index_pq_pivot_file(&self.index_path_prefix)
551 }
552
553 pub fn get_disk_index_compressed_pq_file(&self) -> String {
554 get_disk_index_compressed_pq_file(&self.index_path_prefix)
555 }
556
557 pub fn get_index_path_prefix(&self) -> String {
558 self.index_path_prefix.clone()
559 }
560
561 pub fn get_dataset_file(&self) -> String {
562 self.dataset_file.clone()
563 }
564
565 pub fn get_associated_data_file(&self) -> Option<String> {
566 self.associated_data_file.clone()
567 }
568
569 pub fn get_mem_index_file(&self) -> String {
570 get_mem_index_file(&self.index_path_prefix)
571 }
572
573 pub fn get_merged_index_prefix(&self) -> String {
574 self.get_mem_index_file().clone() + "_tempFiles"
575 }
576
/// Name of the id-map file of sub-shard `shard`: `{prefix}_subshard-{shard}_ids_uint32.bin`.
pub fn get_merged_index_subshard_id_map_file(prefix: &str, shard: usize) -> String {
    format!("{prefix}_subshard-{shard}_ids_uint32.bin")
}
580
/// Name of the data file of sub-shard `shard`: `{prefix}_subshard-{shard}.bin`.
pub fn get_merged_index_subshard_data_file(prefix: &str, shard: usize) -> String {
    format!("{prefix}_subshard-{shard}.bin")
}
584
/// Prefix shared by the files of sub-shard `shard`: `{prefix}_subshard-{shard}`.
pub fn get_merged_index_subshard_prefix(prefix: &str, shard: usize) -> String {
    format!("{prefix}_subshard-{shard}")
}
588
/// Name of the in-memory index file of sub-shard `shard`:
/// `{prefix}_subshard-{shard}_mem.index`.
pub fn get_merged_index_subshard_mem_index_file(prefix: &str, shard: usize) -> String {
    format!("{prefix}_subshard-{shard}_mem.index")
}
592
/// Returns the result of `get_mem_index_data_file` for the given sub-shard
/// in-memory index prefix.
pub fn get_merged_index_subshard_mem_dataset_file(subshard_mem_index_prefix: &str) -> String {
    get_mem_index_data_file(subshard_mem_index_prefix)
}
596}
597
598#[cfg(test)]
599mod disk_index_storage_test {
600 use diskann_providers::storage::VirtualStorageProvider;
601 use diskann_providers::test_utils::graph_data_type_utils::GraphDataF32VectorU32Data;
602 use diskann_utils::test_data_root;
603 use vfs::OverlayFS;
604
605 use super::*;
606
/// Exclusive end offset of the header byte range (`8..`) that the layout tests
/// compare against the truth file.
const TRUTH_DISK_LAYOUT_METADATA_LENGTH: usize = 80;
/// Block size handed to `DiskIndexWriter::new` by the tests; also the offset
/// from which node data is compared.
const DEFAULT_DISK_SECTOR_LEN: usize = 4096;
609
/// Builds a disk layout where several nodes fit in one block and checks the
/// header bytes and all node blocks against the stored truth file.
#[test]
fn create_disk_layout_test_low_dim() {
    const TEST_DATA_FILE_LOW_DIM: &str = "/sift/siftsmall_learn_256pts.fbin";
    const DISK_INDEX_PATH_PREFIX_LOW_DIM: &str =
        "/disk_index_misc/disk_index_siftsmall_learn_256pts_R4_L50_A1.2_index_writer";
    const TRUTH_DISK_LAYOUT_LOW_DIM: &str =
        "/disk_index_misc/truth_disk_index_siftsmall_learn_256pts_R4_L50_A1.2_disk.index";

    let storage_provider = VirtualStorageProvider::new_overlay(test_data_root());

    // Reads a whole file from the virtual storage into memory.
    let read_all = |path: &str| -> Vec<u8> {
        let mut bytes: Vec<u8> = Vec::new();
        storage_provider
            .open_reader(path)
            .unwrap()
            .read_to_end(&mut bytes)
            .unwrap();
        bytes
    };

    let index_writer = DiskIndexWriter::new(
        TEST_DATA_FILE_LOW_DIM.to_string(),
        DISK_INDEX_PATH_PREFIX_LOW_DIM.to_string(),
        None,
        DEFAULT_DISK_SECTOR_LEN,
    )
    .unwrap();
    index_writer
        .create_disk_layout::<GraphDataF32VectorU32Data, _>(&storage_provider)
        .unwrap();

    let disk_layout_file = index_writer.disk_index_file();
    let rust_disk_layout = read_all(disk_layout_file.as_str());
    let truth_disk_layout = read_all(TRUTH_DISK_LAYOUT_LOW_DIM);

    // Header bytes, skipping the first 8 bytes.
    let header_range = 8..TRUTH_DISK_LAYOUT_METADATA_LENGTH;
    assert!(rust_disk_layout[header_range.clone()] == truth_disk_layout[header_range]);

    // Everything after the header block.
    assert!(
        rust_disk_layout[DEFAULT_DISK_SECTOR_LEN..] == truth_disk_layout[DEFAULT_DISK_SECTOR_LEN..]
    );

    storage_provider
        .delete(disk_layout_file.as_str())
        .expect("Failed to delete file");
}
664
/// Builds a disk layout with associated data and checks both the graph content
/// (against the truth layout) and the associated data stored per node.
#[test]
fn create_disk_layout_test_low_dim_with_associated_data() {
    const TEST_DATA_FILE_LOW_DIM: &str = "/sift/siftsmall_learn_256pts.fbin";
    const DISK_INDEX_PATH_PREFIX_LOW_DIM_WITH_ASSOCIATED_DATA: &str = "/disk_index_misc/disk_index_siftsmall_learn_256pts_R4_L50_A1.2_index_writer_with_associated_data";
    const TRUTH_DISK_LAYOUT_LOW_DIM: &str =
        "/disk_index_misc/truth_disk_index_siftsmall_learn_256pts_R4_L50_A1.2_disk.index";
    const ASSOCIATED_DATA_FILE: &str = "/sift/siftsmall_learn_256pts_u32_associated_data.fbin";

    let storage_provider = VirtualStorageProvider::new_overlay(test_data_root());

    // Reads a whole file from the virtual storage into memory.
    let read_all = |path: &str| -> Vec<u8> {
        let mut bytes: Vec<u8> = Vec::new();
        storage_provider
            .open_reader(path)
            .unwrap()
            .read_to_end(&mut bytes)
            .unwrap();
        bytes
    };

    let index_writer = DiskIndexWriter::new(
        TEST_DATA_FILE_LOW_DIM.to_string(),
        DISK_INDEX_PATH_PREFIX_LOW_DIM_WITH_ASSOCIATED_DATA.to_string(),
        Some(ASSOCIATED_DATA_FILE.to_string()),
        DEFAULT_DISK_SECTOR_LEN,
    )
    .unwrap();
    index_writer
        .create_disk_layout::<GraphDataF32VectorU32Data, _>(&storage_provider)
        .unwrap();

    let disk_layout_file = index_writer.disk_index_file();
    let rust_disk_layout = read_all(disk_layout_file.as_str());
    let truth_disk_layout = read_all(TRUTH_DISK_LAYOUT_LOW_DIM);
    let associated_data = read_all(ASSOCIATED_DATA_FILE);

    compare_disk_index_graphs(&rust_disk_layout, &truth_disk_layout);
    compare_disk_index_graphs_associated_data::<u32>(&rust_disk_layout, &associated_data);

    storage_provider
        .delete(disk_layout_file.as_str())
        .expect("Failed to delete file");
}
720
/// Builds a disk layout for 1024-dimensional data and checks the header bytes
/// and all node blocks against the stored truth file.
#[test]
fn create_disk_layout_test_high_dim() {
    const TEST_DATA_FILE_HIGH_DIM: &str = "/disk_index_misc/rand_float_1024D_1K_norm1.0.bin";
    const DISK_INDEX_PATH_PREFIX_HIGH_DIM: &str =
        "/disk_index_misc/disk_index_rand_float_1024D_1Kpts_R4_L50_A1.2_index_writer";
    const TRUTH_DISK_LAYOUT_HIGH_DIM: &str =
        "/disk_index_misc/truth_disk_index_rand_float_1024D_1Kpts_R4_L50_A1.2_index_writer.index";

    let storage_provider = VirtualStorageProvider::new_overlay(test_data_root());

    // Reads a whole file from the virtual storage into memory.
    let read_all = |path: &str| -> Vec<u8> {
        let mut bytes: Vec<u8> = Vec::new();
        storage_provider
            .open_reader(path)
            .unwrap()
            .read_to_end(&mut bytes)
            .unwrap();
        bytes
    };

    let index_writer = DiskIndexWriter::new(
        TEST_DATA_FILE_HIGH_DIM.to_string(),
        DISK_INDEX_PATH_PREFIX_HIGH_DIM.to_string(),
        None,
        DEFAULT_DISK_SECTOR_LEN,
    )
    .unwrap();
    index_writer
        .create_disk_layout::<GraphDataF32VectorU32Data, _>(&storage_provider)
        .unwrap();

    let disk_layout_file = index_writer.disk_index_file();
    let rust_disk_layout = read_all(disk_layout_file.as_str());
    let truth_disk_layout = read_all(TRUTH_DISK_LAYOUT_HIGH_DIM);

    // Header bytes, skipping the first 8 bytes.
    let header_range = 8..TRUTH_DISK_LAYOUT_METADATA_LENGTH;
    assert!(rust_disk_layout[header_range.clone()] == truth_disk_layout[header_range]);

    // Everything after the header block.
    assert!(
        rust_disk_layout[DEFAULT_DISK_SECTOR_LEN..] == truth_disk_layout[DEFAULT_DISK_SECTOR_LEN..]
    );

    storage_provider
        .delete(disk_layout_file.as_str())
        .expect("Failed to delete file");
}
776
/// A block size one byte below `GraphMetadata::get_size()` must be rejected.
#[test]
fn disk_index_writer_rejects_small_block_size() {
    let undersized_block = GraphMetadata::get_size() - 1;

    let outcome = DiskIndexWriter::new(
        String::from("dataset"),
        String::from("index"),
        None,
        undersized_block,
    );

    assert!(outcome.is_err());
}
789
/// Values a `DiskIndexWriter` is expected to report through its accessors;
/// consumed by `assert_writer_eq_expected`.
struct ExpectedWriter {
    dataset_file: String,
    index_path_prefix: String,
    pq_pivot_file: String,
    compressed_pq_file: String,
    disk_index_pq_pivot_file: String,
    disk_index_compressed_pq_file: String,
    associated_data_file: Option<String>,
}
799
800 fn assert_writer_eq_expected(writer: &DiskIndexWriter, expected: &ExpectedWriter) {
801 assert_eq!(writer.dataset_file(), &expected.dataset_file);
802 assert_eq!(writer.index_path_prefix(), &expected.index_path_prefix);
803
804 assert_eq!(writer.get_pq_pivot_file(), expected.pq_pivot_file);
805 assert_eq!(
806 writer.get_compressed_pq_pivot_file(),
807 expected.compressed_pq_file
808 );
809 assert_eq!(
810 writer.get_disk_index_pq_pivot_file(),
811 expected.disk_index_pq_pivot_file
812 );
813 assert_eq!(
814 writer.get_disk_index_compressed_pq_file(),
815 expected.disk_index_compressed_pq_file
816 );
817 assert_eq!(
818 writer.get_associated_data_file(),
819 expected.associated_data_file
820 );
821 }
822
/// The accessors must echo the constructor arguments and derive every file
/// name from the index path prefix.
#[test]
fn test_dataset_file_and_index_path_prefix() {
    let dataset = "dataset_file.txt";
    let prefix = "index_path_prefix";
    let associated = "associated_data_file.txt";

    let expected = ExpectedWriter {
        dataset_file: dataset.to_string(),
        index_path_prefix: prefix.to_string(),
        pq_pivot_file: get_pq_pivot_file(prefix),
        compressed_pq_file: get_compressed_pq_file(prefix),
        disk_index_pq_pivot_file: get_disk_index_pq_pivot_file(prefix),
        disk_index_compressed_pq_file: get_disk_index_compressed_pq_file(prefix),
        associated_data_file: Some(associated.to_string()),
    };

    let writer = DiskIndexWriter::new(
        dataset.to_string(),
        prefix.to_string(),
        Some(associated.to_string()),
        DEFAULT_DISK_SECTOR_LEN,
    )
    .unwrap();

    assert_writer_eq_expected(&writer, &expected);
}
846
/// The in-memory index file name is derived from the index path prefix.
#[test]
fn test_get_mem_index_file() {
    let dataset_file = String::from("test_index");
    let index_path_prefix = String::from("test_dataset");

    let writer = DiskIndexWriter::new(
        dataset_file,
        index_path_prefix,
        None,
        DEFAULT_DISK_SECTOR_LEN,
    )
    .unwrap();

    let mem_index_file = writer.get_mem_index_file();
    assert_eq!(mem_index_file, "test_dataset_mem.index");
}
859
/// The merged-index prefix is the in-memory index file name plus `_tempFiles`.
#[test]
fn test_get_merged_index_prefix() {
    let dataset_file = String::from("test_index");
    let index_path_prefix = String::from("test_dataset");

    let writer = DiskIndexWriter::new(
        dataset_file,
        index_path_prefix,
        None,
        DEFAULT_DISK_SECTOR_LEN,
    )
    .unwrap();

    let merged_prefix = writer.get_merged_index_prefix();
    assert_eq!(merged_prefix, "test_dataset_mem.index_tempFiles");
}
875
/// Checks the id-map file name for a sub-shard.
#[test]
fn test_get_merged_index_subshard_id_map_file() {
    let actual = DiskIndexWriter::get_merged_index_subshard_id_map_file("test_prefix", 5);
    assert_eq!(actual, "test_prefix_subshard-5_ids_uint32.bin");
}
885
/// Checks the data file name for a sub-shard.
#[test]
fn test_get_merged_index_subshard_data_file() {
    let actual = DiskIndexWriter::get_merged_index_subshard_data_file("test_prefix", 5);
    assert_eq!(actual, "test_prefix_subshard-5.bin");
}
895
/// Checks the common prefix for a sub-shard.
#[test]
fn test_get_merged_index_subshard_prefix() {
    let actual = DiskIndexWriter::get_merged_index_subshard_prefix("test_prefix", 5);
    assert_eq!(actual, "test_prefix_subshard-5");
}
905
/// Checks the in-memory index file name for a sub-shard.
#[test]
fn test_get_merged_index_subshard_mem_index_file() {
    let actual = DiskIndexWriter::get_merged_index_subshard_mem_index_file("test_prefix", 5);
    assert_eq!(actual, "test_prefix_subshard-5_mem.index");
}
915
/// Checks the in-memory dataset file name for a sub-shard prefix.
#[test]
fn test_get_merged_index_subshard_mem_dataset_file() {
    let actual = DiskIndexWriter::get_merged_index_subshard_mem_dataset_file("test_prefix");
    assert_eq!(actual, "test_prefix.data");
}
924
/// Both index-reader helpers must fail on a state whose index reader was never opened.
#[test]
fn test_disk_index_writer_state_uninitialized() {
    type Provider = VirtualStorageProvider<OverlayFS>;

    let mut state = DiskIndexWriterState::<Provider>::new();
    assert!(DiskIndexWriter::get_neighbors_number(&mut state).is_err());

    let mut scratch = [0u8; 16];
    assert!(DiskIndexWriter::copy_neighbors(&mut state, &mut scratch).is_err());
}
932
933 pub fn compare_disk_index_graphs(graph_data: &[u8], truth_graph_data: &[u8]) {
935 let graph_header = GraphHeader::try_from(&graph_data[8..]).unwrap();
936 let truth_graph_header = GraphHeader::try_from(&truth_graph_data[8..]).unwrap();
937
938 let test_node_per_block = graph_header.metadata().num_nodes_per_block;
939 let test_max_node_length = graph_header.metadata().node_len;
940
941 let truth_node_per_block = truth_graph_header.metadata().num_nodes_per_block;
942 let truth_max_node_length = truth_graph_header.metadata().node_len;
943
944 assert_eq!(
945 graph_header.metadata().num_pts,
946 truth_graph_header.metadata().num_pts
947 );
948
949 assert_eq!(
950 graph_header.metadata().dims,
951 truth_graph_header.metadata().dims
952 );
953
954 let num_pts = graph_header.metadata().num_pts as usize;
955 let dim = graph_header.metadata().dims;
956
957 for idx in 0..num_pts {
958 let test_node_id_offset = node_data_offset(
959 idx,
960 test_max_node_length as usize,
961 test_node_per_block as usize,
962 DEFAULT_DISK_SECTOR_LEN,
963 );
964
965 let truth_node_id_offset = node_data_offset(
966 idx,
967 truth_max_node_length as usize,
968 truth_node_per_block as usize,
969 DEFAULT_DISK_SECTOR_LEN,
970 );
971
972 assert_eq!(
974 &graph_data
975 [test_node_id_offset..test_node_id_offset + dim * std::mem::size_of::<f32>()],
976 &truth_graph_data
977 [truth_node_id_offset..truth_node_id_offset + dim * std::mem::size_of::<f32>()]
978 );
979
980 let test_nbr_cnt_offset = test_node_id_offset + dim * std::mem::size_of::<f32>();
982 let truth_nbr_cnt_offset = truth_node_id_offset + dim * std::mem::size_of::<f32>();
983
984 let test_nbr_count = u32::from_le_bytes([
985 graph_data[test_nbr_cnt_offset],
986 graph_data[test_nbr_cnt_offset + 1],
987 graph_data[test_nbr_cnt_offset + 2],
988 graph_data[test_nbr_cnt_offset + 3],
989 ]);
990
991 let truth_nbr_count = u32::from_le_bytes([
992 truth_graph_data[truth_nbr_cnt_offset],
993 truth_graph_data[truth_nbr_cnt_offset + 1],
994 truth_graph_data[truth_nbr_cnt_offset + 2],
995 truth_graph_data[truth_nbr_cnt_offset + 3],
996 ]);
997
998 assert_eq!(test_nbr_count, truth_nbr_count);
999
1000 let test_nbr_offset = test_nbr_cnt_offset + 4;
1002 let truth_nbr_offset = truth_nbr_cnt_offset + 4;
1003 assert_eq!(
1004 graph_data[test_nbr_offset..test_nbr_offset + test_nbr_count as usize * 4],
1005 truth_graph_data[truth_nbr_offset..truth_nbr_offset + truth_nbr_count as usize * 4]
1006 );
1007 }
1008 }
1009
1010 pub fn compare_disk_index_graphs_associated_data<AssociatedDataType>(
1012 graph_data: &[u8],
1013 associated_data: &[u8],
1014 ) {
1015 let graph_header = GraphHeader::try_from(&graph_data[8..]).unwrap();
1016 let test_node_per_block = graph_header.metadata().num_nodes_per_block;
1017 let test_max_node_length = graph_header.metadata().node_len as usize;
1018
1019 let mut associated_data_offset = 0;
1020 let data_npts = u32::from_le_bytes([
1021 associated_data[associated_data_offset],
1022 associated_data[associated_data_offset + 1],
1023 associated_data[associated_data_offset + 2],
1024 associated_data[associated_data_offset + 3],
1025 ]) as usize;
1026
1027 associated_data_offset = 4;
1028 let _associated_data_length = u32::from_le_bytes([
1029 associated_data[associated_data_offset],
1030 associated_data[associated_data_offset + 1],
1031 associated_data[associated_data_offset + 2],
1032 associated_data[associated_data_offset + 3],
1033 ]) as usize;
1034
1035 let num_pts = graph_header.metadata().num_pts as usize;
1036 assert_eq!(num_pts, data_npts);
1037
1038 associated_data_offset = 8;
1039
1040 for idx in 0..num_pts {
1041 let test_node_id_offset = node_data_offset(
1042 idx,
1043 test_max_node_length,
1044 test_node_per_block as usize,
1045 DEFAULT_DISK_SECTOR_LEN,
1046 );
1047
1048 let node_buf_end = test_node_id_offset + test_max_node_length;
1049
1050 assert_eq!(
1051 graph_data[(node_buf_end - mem::size_of::<AssociatedDataType>())..node_buf_end],
1052 associated_data[associated_data_offset
1053 ..associated_data_offset + mem::size_of::<AssociatedDataType>()]
1054 );
1055
1056 associated_data_offset += mem::size_of::<AssociatedDataType>();
1057 }
1058 }
1059
/// Byte offset of node `node_id` inside a disk index image whose first block
/// is reserved for the header.
///
/// * `nodes_per_block > 0`: nodes are packed `nodes_per_block` per block.
/// * `nodes_per_block == 0`: one node does not fit in a block, so every node
///   starts on a block boundary and occupies `ceil(node_length / block_size)`
///   blocks.
///
/// # Panics
///
/// Panics if `nodes_per_block == 0` and `block_size == 0`.
pub fn node_data_offset(
    node_id: usize,
    node_length: usize,
    nodes_per_block: usize,
    block_size: usize,
) -> usize {
    let offset_after_header = if nodes_per_block == 0 {
        // Multi-block-per-node layout; avoids the division by zero below.
        node_id * node_length.div_ceil(block_size) * block_size
    } else {
        let block_id = node_id / nodes_per_block;
        let node_id_in_block = node_id % nodes_per_block;
        block_id * block_size + node_id_in_block * node_length
    };

    // Skip the header block.
    offset_after_header + block_size
}
1071}