diskann_disk/storage/disk_index_writer.rs

/* Copyright (c) Microsoft Corporation.
 * Licensed under the MIT license.
 */
use std::{
    io::{Read, Seek, Write},
    mem,
};

use byteorder::{ByteOrder, LittleEndian, ReadBytesExt};
use diskann::{ANNError, ANNResult};
use diskann_providers::storage::{StorageReadProvider, StorageWriteProvider};
use diskann_providers::{
    model::graph::traits::GraphDataType,
    storage::{get_mem_index_file, path_utility::*},
    utils::{save_bytes, READ_WRITE_BLOCK_SIZE},
};
use tracing::info;

use crate::{
    data_model::{GraphHeader, GraphMetadata},
    storage::{CachedReader, CachedWriter},
};

// Struct DiskIndexWriterState maintains the state of the process of creating a disk
// layout from an index and its associated data. By moving the state into this struct,
// we can create it on the stack in the create_disk_layout() function and then pass it
// to other functions as data processing proceeds.
struct DiskIndexWriterState<StorageProvider>
where
    StorageProvider: StorageReadProvider,
{
    // If the index data does not fit into memory, this reader is created and used
    // to fetch the data from disk.
    multi_shard_index_reader: Option<StorageProvider::Reader>,

    // Reader to get associated data from disk.
    associated_data_reader: Option<CachedReader<StorageProvider>>,

    // Reader to get the dataset from disk.
    dataset_reader: Option<CachedReader<StorageProvider>>,

    // Parameters required for processing data. They are set before data processing starts.
    dims: u64,
    num_pts: u64,
    max_degree: u32,
    medoid: u32,
    vamana_frozen_num: u64,
    node_len: u64,
    associated_data_length: usize,
    read_blk_size: u64,
    write_blk_size: u64,
}

impl<StorageProvider> DiskIndexWriterState<StorageProvider>
where
    StorageProvider: StorageReadProvider,
{
    /// Create DiskIndexWriterState instance
    fn new() -> Self {
        DiskIndexWriterState {
            multi_shard_index_reader: None,
            associated_data_reader: None,
            dataset_reader: None,
            dims: 0,
            num_pts: 0,
            max_degree: 0,
            medoid: 0,
            vamana_frozen_num: 0,
            node_len: 0,
            associated_data_length: 0,
            read_blk_size: READ_WRITE_BLOCK_SIZE,
            write_blk_size: READ_WRITE_BLOCK_SIZE,
        }
    }
}

/// DiskIndexWriter is used to write disk index data to the storage system.
/// The storage system and data types are provided as parameters to methods that need them.
pub struct DiskIndexWriter {
    /// Dataset file
    dataset_file: String,

    /// Index file path prefix
    index_path_prefix: String,

    /// Optional associated data file
    associated_data_file: Option<String>,

    /// Block size (bytes) used when writing the disk index.
    block_size: usize,
}

impl DiskIndexWriter {
    /// Create DiskIndexWriter instance
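    ///
    /// # Example
    ///
    /// A minimal sketch; the paths are illustrative placeholders, not shipped data:
    ///
    /// ```ignore
    /// let writer = DiskIndexWriter::new(
    ///     "/data/points.fbin".to_string(), // dataset file
    ///     "/data/my_index".to_string(),    // index path prefix
    ///     None,                            // no associated data
    ///     4096,                            // block size in bytes; must be >= GraphMetadata::get_size()
    /// )?;
    /// ```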
    pub fn new(
        dataset_file: String,
        index_path_prefix: String,
        associated_data_file: Option<String>,
        block_size: usize,
    ) -> ANNResult<Self> {
        if block_size < GraphMetadata::get_size() {
            return Err(ANNError::log_index_config_error(
                "index_block_size".to_string(),
                format!(
                    "block_size must be at least the size of GraphMetadata: {}",
                    GraphMetadata::get_size()
                ),
            ));
        }

        Ok(DiskIndexWriter {
            dataset_file,
            index_path_prefix,
            associated_data_file,
            block_size,
        })
    }

    pub fn dataset_file(&self) -> &String {
        &self.dataset_file
    }

    pub fn index_path_prefix(&self) -> &String {
        &self.index_path_prefix
    }

    pub fn block_size(&self) -> usize {
        self.block_size
    }

    fn get_neighbors_number<StorageProvider>(
        state: &mut DiskIndexWriterState<StorageProvider>,
    ) -> ANNResult<u32>
    where
        StorageProvider: StorageReadProvider,
    {
        let num_nbrs: u32;
        if let Some(vamana_reader) = state.multi_shard_index_reader.as_mut() {
            num_nbrs = vamana_reader.read_u32::<LittleEndian>()?;
        } else {
            return Err(ANNError::log_index_error("invalid index reader"));
        }

        Ok(num_nbrs)
    }

    fn copy_neighbors<StorageProvider>(
        state: &mut DiskIndexWriterState<StorageProvider>,
        nbrs_buf: &mut [u8],
    ) -> ANNResult<()>
    where
        StorageProvider: StorageReadProvider,
    {
        if let Some(vamana_reader) = state.multi_shard_index_reader.as_mut() {
            vamana_reader.read_exact(nbrs_buf)?;
        } else {
            return Err(ANNError::log_index_error("invalid index reader"));
        }

        Ok(())
    }

    fn open_vamana_reader<StorageProvider>(
        &self,
        state: &mut DiskIndexWriterState<StorageProvider>,
        storage_provider: &StorageProvider,
    ) -> ANNResult<()>
    where
        StorageProvider: StorageReadProvider,
    {
        let mem_index_file = self.get_mem_index_file();

        // Get the actual file size so it can be validated against the size recorded in the index.
        let actual_file_size = storage_provider.get_length(mem_index_file.as_str())?;
        info!("Vamana index file size={}", actual_file_size);

        state.multi_shard_index_reader =
            Some(storage_provider.open_reader(mem_index_file.as_str())?);

        if let Some(vamana_reader) = state.multi_shard_index_reader.as_mut() {
            let index_file_size = vamana_reader.read_u64::<LittleEndian>()?;
            if index_file_size != actual_file_size {
                info!(
                    "Vamana index file size does not match the size recorded in its metadata. size from metadata: {}, actual file size: {}",
                    index_file_size, actual_file_size
                );
            }

            state.max_degree = vamana_reader.read_u32::<LittleEndian>()?;
            state.medoid = vamana_reader.read_u32::<LittleEndian>()?;
            state.vamana_frozen_num = vamana_reader.read_u64::<LittleEndian>()?;

            return Ok(());
        }

        Err(ANNError::log_index_error("invalid index reader"))
    }

    fn open_associated_data_reader<StorageProvider>(
        &self,
        state: &mut DiskIndexWriterState<StorageProvider>,
        storage_provider: &StorageProvider,
    ) -> ANNResult<()>
    where
        StorageProvider: StorageReadProvider,
    {
        (state.associated_data_reader, state.associated_data_length) = match &self
            .associated_data_file
        {
            Some(associated_data_stream) => {
                let mut associated_data_reader = CachedReader::<StorageProvider>::new(
                    associated_data_stream.as_str(),
                    state.read_blk_size,
                    storage_provider,
                )?;

                let associated_data_num_pts = associated_data_reader.read_u32()? as u64;
                let length = associated_data_reader.read_u32()? as usize;

                if state.num_pts != associated_data_num_pts {
                    return Err(ANNError::log_index_error(format_args!(
                        "Number of points in dataset file ({}) does not match number of points in associated data file ({}).",
                        state.num_pts, associated_data_num_pts
                    )));
                }

                (Some(associated_data_reader), length)
            }

            None => (None, 0),
        };

        Ok(())
    }

    fn open_dataset_reader<StorageProvider>(
        &self,
        state: &mut DiskIndexWriterState<StorageProvider>,
        storage_provider: &StorageProvider,
    ) -> ANNResult<()>
    where
        StorageProvider: StorageReadProvider,
    {
        let dataset_reader = CachedReader::<StorageProvider>::new(
            self.dataset_file.as_str(),
            state.read_blk_size,
            storage_provider,
        )?;
        state.dataset_reader = Some(dataset_reader);

        if let Some(dataset_reader) = state.dataset_reader.as_mut() {
            state.num_pts = dataset_reader.read_u32()? as u64;
            state.dims = dataset_reader.read_u32()? as u64;
        }

        Ok(())
    }

    fn read_neighbors<Data, StorageProvider>(
        &self,
        state: &mut DiskIndexWriterState<StorageProvider>,
        block_buf: &mut [u8],
    ) -> ANNResult<()>
    where
        Data: GraphDataType<VectorIdType = u32>,
        StorageProvider: StorageReadProvider,
    {
        block_buf.fill(0);

        // write coords of node first
        if let Some(dataset_reader) = state.dataset_reader.as_mut() {
            let mut cur_node_coords =
                vec![0u8; (state.dims as usize) * mem::size_of::<Data::VectorDataType>()];

            dataset_reader.read(&mut cur_node_coords)?;
            block_buf[..cur_node_coords.len()].copy_from_slice(&cur_node_coords);
        }

        // read cur node's num_nbrs
        let num_nbrs: u32 = Self::get_neighbors_number(state)?;

        // sanity checks on num_nbrs
        debug_assert!(num_nbrs > 0);
        debug_assert!(num_nbrs <= state.max_degree);

        let num_nbrs_start = state.dims as usize * mem::size_of::<Data::VectorDataType>();
        let nbrs_buf_start = num_nbrs_start + mem::size_of::<u32>();

        // write num_nbrs
        LittleEndian::write_u32(
            &mut block_buf[num_nbrs_start..(num_nbrs_start + mem::size_of::<u32>())],
            num_nbrs,
        );

        // write neighbors
        let nbr_buf_end = nbrs_buf_start + (num_nbrs as usize) * mem::size_of::<u32>();

        let nbrs_buf = &mut block_buf[nbrs_buf_start..nbr_buf_end];

        Self::copy_neighbors(state, nbrs_buf)?;

        if let Some(associated_data_reader) = state.associated_data_reader.as_mut() {
            // The associated data sits at the very end of the node record.
            let cur_node_associated_data = &mut block_buf[(state.node_len as usize
                - state.associated_data_length * mem::size_of::<Data::AssociatedDataType>())
                ..(state.node_len as usize)];
            associated_data_reader.read(cur_node_associated_data)?;
        }

        Ok(())
    }

    fn write_header<Data, Writer, StorageProvider>(
        &self,
        state: &mut DiskIndexWriterState<StorageProvider>,
        block_size: usize,
        num_nodes_per_block: u64,
        disk_index_file_size: u64,
        writer: &mut Writer,
    ) -> ANNResult<()>
    where
        Data: GraphDataType<VectorIdType = u32>,
        Writer: Write + Seek,
        StorageProvider: StorageReadProvider,
    {
        let mut vamana_frozen_loc: u64 = 0;
        if state.vamana_frozen_num == 1 {
            vamana_frozen_loc = state.medoid as u64;
        }

        let graph_metadata = GraphMetadata::new(
            state.num_pts,
            state.dims as usize,
            state.medoid as u64,
            state.node_len,
            num_nodes_per_block,
            state.vamana_frozen_num,
            vamana_frozen_loc,
            disk_index_file_size,
            state.associated_data_length * mem::size_of::<Data::AssociatedDataType>(),
        );

        let header = GraphHeader::new(
            graph_metadata,
            block_size as u64,
            GraphHeader::CURRENT_LAYOUT_VERSION,
        );

        let bytes_header = header.to_bytes()?;
        save_bytes(
            writer,
            bytes_header.as_slice(),
            bytes_header.len(),
            1, // num_points is kept to make the graph compatible with the C++ version
            0,
        )?;

        Ok(())
    }

    /// Create disk layout.
    /// Block #1: GraphMetadata.
    /// Block #2..#n: num_nodes_per_block nodes in each block
    ///
    /// GraphMetadata layout:
    /// | number_of_points (8 bytes) | dimensions (8 bytes) | medoid (8 bytes) |
    /// ...| node_len (8 bytes) | num_nodes_per_sector (8 bytes) | vamana_frozen_point_num (8 bytes) |
    /// ...| vamana_frozen_loc (8 bytes) | append_reorder_data (8 bytes) | disk_index_file_size (8 bytes) |
    /// ...| associated_data_length (8 bytes) | block_size (8 bytes) | layout_version (8 bytes) |
    ///
    /// The metadata layout is kept compatible with the C++ diskann codebase.
    ///
    /// After the metadata structure, the graph is laid out as a sequence of vertices and relevant data,
    /// including the vector, the out-neighbor count, the list of out neighbors, the associated data, and appropriate padding:
    /// `| vector (dim * size_of<VectorDataType> bytes) | neighbor_count (4 bytes) | neighbors (neighbor_count * 4 bytes) |
    /// ... | filler | associated_data (associated_data_length * size_of<AssociatedDataType> bytes) |`
    ///
    /// The node_len and the length of the filler are calculated as follows:
    ///
    /// `node_len = (max_degree + 1) * size_of<u32> + dim * size_of<VectorDataType> + associated_data_length * size_of<AssociatedDataType>`
    ///
    /// `filler length = node_len - (length of vector + 4 + 4 * neighbor_count + associated_data_length * size_of<AssociatedDataType>)`
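    ///
    /// As an illustration (numbers chosen for this doc, not taken from any shipped index):
    /// with f32 vectors of dim = 128, max_degree = 4, and no associated data,
    /// `node_len = (4 + 1) * 4 + 128 * 4 = 532` bytes, a 4096-byte block holds
    /// `4096 / 532 = 7` nodes, and a node with neighbor_count = 3 carries
    /// `532 - (512 + 4 + 12) = 4` bytes of filler.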
    ///
    /// The filler pads each node to node_len bytes on disk, so the block number and offset of a node can be computed from its id.
    ///
    /// When node_len < disk block size, we pack as many nodes as possible into a disk sector without splitting a node across sectors.
    /// For example, if node_len is 600B, we can pack 6 of these on a 4KB sector and leave 4096 - 3600 = 496B unused.
    ///
    /// When node_len > disk block size, we align the start of each node to the block size.
    /// For example, if node_len is 6700B, a node spans two 4KB sectors, beginning at the start of the first sector
    /// and ending on the second sector, followed by 1492 bytes of padding to align the next node to a block boundary.
    ///
    /// # Arguments
    /// * `storage_provider` - the storage provider for I/O operations
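    ///
    /// # Example
    ///
    /// A sketch of a typical call; the graph data type shown is the one used by this
    /// file's tests, and `storage_provider` is any provider implementing both traits:
    ///
    /// ```ignore
    /// writer.create_disk_layout::<GraphDataF32VectorU32Data, _>(&storage_provider)?;
    /// ```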
    pub fn create_disk_layout<Data, StorageProvider>(
        &self,
        storage_provider: &StorageProvider,
    ) -> ANNResult<()>
    where
        Data: GraphDataType<VectorIdType = u32>,
        StorageProvider: StorageReadProvider + StorageWriteProvider,
    {
        let block_size = self.block_size;
        let mut state: DiskIndexWriterState<StorageProvider> = DiskIndexWriterState::new();

        self.open_dataset_reader(&mut state, storage_provider)?;
        self.open_associated_data_reader(&mut state, storage_provider)?;
        self.open_vamana_reader(&mut state, storage_provider)?;

        let vector_size = state.dims * (mem::size_of::<Data::VectorDataType>() as u64);

        state.node_len = ((state.max_degree as u64 + 1) * (mem::size_of::<u32>() as u64))
            + vector_size
            + (state.associated_data_length * mem::size_of::<Data::AssociatedDataType>()) as u64;

        let num_nodes_per_block = (block_size as u64) / state.node_len; // 0 if node_len > block_size
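        // Illustrative magnitudes (assumed numbers, not from any dataset): with 4096-byte
        // blocks and node_len = 532, seven nodes fit per block; with node_len = 6700 this
        // quotient is 0 and each node instead spans 6700.div_ceil(4096) = 2 blocks, handled
        // by the multi-sector path below.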

        info!("block_size: {}B", block_size);
        info!("medoid: {}", state.medoid);
        info!("node_len: {}B", state.node_len);
        info!("num_nodes_per_sector: {}", num_nodes_per_block);
        info!(
            "associated_data_length: {}B",
            state.associated_data_length * mem::size_of::<Data::AssociatedDataType>()
        );

        // Number of data sectors; the metadata sector is accounted for separately.
        let num_blocks = if num_nodes_per_block > 0 {
            state.num_pts.div_ceil(num_nodes_per_block)
        } else {
            let num_block_per_node = state.node_len.div_ceil(block_size as u64);
            info!("num_sector_per_node: {}", num_block_per_node);
            state.num_pts * num_block_per_node
        };
        info!("num_blocks: {}", num_blocks);

        let disk_layout_file = self.disk_index_file();
        {
            let storage_writer = storage_provider.create_for_write(disk_layout_file.as_str())?;
            let mut diskann_writer = CachedWriter::<StorageProvider>::new(
                disk_layout_file.as_str(),
                state.write_blk_size,
                storage_writer,
            )?;

            // Buffer of block_size bytes for each block.
            let mut block_buf = vec![0u8; block_size];
            // Reserve the first block for the header; it is written at the end.
            diskann_writer.write(&block_buf)?;

            if num_nodes_per_block > 0 {
                let mut cur_node_id = 0u64;
                let mut node_buf = vec![0u8; state.node_len as usize];

                // Write multiple nodes per sector
                for sector in 0..num_blocks {
                    if sector % 100_000 == 0 {
                        info!("Sector #{} written", sector);
                    }
                    block_buf.fill(0);

                    for sector_node_id in 0..num_nodes_per_block {
                        if cur_node_id >= state.num_pts {
                            break;
                        }

                        self.read_neighbors::<Data, _>(&mut state, &mut node_buf)?;

                        // get offset into sector_buf
                        let sector_node_buf_start = (sector_node_id * state.node_len) as usize;
                        let sector_node_buf = &mut block_buf[sector_node_buf_start
                            ..(sector_node_buf_start + state.node_len as usize)];
                        sector_node_buf.copy_from_slice(&node_buf[..(state.node_len as usize)]);

                        cur_node_id += 1;
                    }

                    // flush sector to disk
                    diskann_writer.write(&block_buf)?;
                }
            } else {
                // Write multi-sector nodes
                let mut multi_block_buf =
                    vec![0u8; state.node_len.next_multiple_of(block_size as u64) as usize];
                let num_block_per_node = state.node_len.div_ceil(block_size as u64);

                for node_idx in 0..state.num_pts {
                    if (node_idx * num_block_per_node).is_multiple_of(100_000) {
                        info!("Sector #{} written", node_idx * num_block_per_node);
                    }

                    self.read_neighbors::<Data, _>(&mut state, &mut multi_block_buf)?;

                    // flush sector to disk
                    diskann_writer.write(&multi_block_buf)?;
                }
            }

            // Be sure to flush the writer before it goes out of scope so we can open a new one.
            diskann_writer.flush()?;
        }

        // Write the header. The file must be re-opened because the cached writer cannot seek back
        // to the start of the file. CachedWriter owns the underlying writer, so a new writer is
        // opened; the scope above ensures the old writer has been dropped.
        {
            let mut storage_writer = storage_provider.open_writer(disk_layout_file.as_str())?;
            let disk_index_file_size = (num_blocks + 1) * (block_size as u64); // +1 for the metadata block
            self.write_header::<Data, _, _>(
                &mut state,
                block_size,
                num_nodes_per_block,
                disk_index_file_size,
                &mut storage_writer,
            )?;

            storage_writer.flush()?;
            Ok(())
        }
    }

    pub fn index_build_cleanup<StorageProvider>(
        &self,
        storage_provider: &StorageProvider,
    ) -> ANNResult<()>
    where
        StorageProvider: StorageReadProvider + StorageWriteProvider,
    {
        // Clean up the in-memory index file if it exists.
        let inmem_index_identifier = self.get_mem_index_file();
        if storage_provider.exists(&inmem_index_identifier) {
            storage_provider.delete(&inmem_index_identifier)?;
        }

        Ok(())
    }

    pub fn disk_index_file(&self) -> String {
        get_disk_index_file(&self.index_path_prefix)
    }

    pub fn get_pq_pivot_file(&self) -> String {
        get_pq_pivot_file(&self.index_path_prefix)
    }

    pub fn get_compressed_pq_pivot_file(&self) -> String {
        get_compressed_pq_file(&self.index_path_prefix)
    }

    pub fn get_disk_index_pq_pivot_file(&self) -> String {
        get_disk_index_pq_pivot_file(&self.index_path_prefix)
    }

    pub fn get_disk_index_compressed_pq_file(&self) -> String {
        get_disk_index_compressed_pq_file(&self.index_path_prefix)
    }

    pub fn get_index_path_prefix(&self) -> String {
        self.index_path_prefix.clone()
    }

    pub fn get_dataset_file(&self) -> String {
        self.dataset_file.clone()
    }

    pub fn get_associated_data_file(&self) -> Option<String> {
        self.associated_data_file.clone()
    }

    pub fn get_mem_index_file(&self) -> String {
        get_mem_index_file(&self.index_path_prefix)
    }

    pub fn get_merged_index_prefix(&self) -> String {
        self.get_mem_index_file() + "_tempFiles"
    }

    pub fn get_merged_index_subshard_id_map_file(prefix: &str, shard: usize) -> String {
        format!("{}_subshard-{}_ids_uint32.bin", prefix, shard)
    }

    pub fn get_merged_index_subshard_data_file(prefix: &str, shard: usize) -> String {
        format!("{}_subshard-{}.bin", prefix, shard)
    }

    pub fn get_merged_index_subshard_prefix(prefix: &str, shard: usize) -> String {
        format!("{}_subshard-{}", prefix, shard)
    }

    pub fn get_merged_index_subshard_mem_index_file(prefix: &str, shard: usize) -> String {
        format!("{}_subshard-{}_mem.index", prefix, shard)
    }

    pub fn get_merged_index_subshard_mem_dataset_file(subshard_mem_index_prefix: &str) -> String {
        get_mem_index_data_file(subshard_mem_index_prefix)
    }
}

#[cfg(test)]
mod disk_index_storage_test {
    use diskann_providers::storage::VirtualStorageProvider;
    use diskann_providers::test_utils::graph_data_type_utils::GraphDataF32VectorU32Data;
    use diskann_utils::test_data_root;
    use vfs::OverlayFS;

    use super::*;

    const TRUTH_DISK_LAYOUT_METADATA_LENGTH: usize = 80;
    const DEFAULT_DISK_SECTOR_LEN: usize = 4096;

    #[test]
    fn create_disk_layout_test_low_dim() {
        let storage_provider = VirtualStorageProvider::new_overlay(test_data_root());

        const TEST_DATA_FILE_LOW_DIM: &str = "/sift/siftsmall_learn_256pts.fbin";
        const DISK_INDEX_PATH_PREFIX_LOW_DIM: &str =
            "/disk_index_misc/disk_index_siftsmall_learn_256pts_R4_L50_A1.2_index_writer";
        const TRUTH_DISK_LAYOUT_LOW_DIM: &str =
            "/disk_index_misc/truth_disk_index_siftsmall_learn_256pts_R4_L50_A1.2_disk.index";

        let disk_layout_file = {
            let index_writer = DiskIndexWriter::new(
                TEST_DATA_FILE_LOW_DIM.to_string(),
                DISK_INDEX_PATH_PREFIX_LOW_DIM.to_string(),
                None,
                DEFAULT_DISK_SECTOR_LEN,
            )
            .unwrap();
            index_writer
                .create_disk_layout::<GraphDataF32VectorU32Data, _>(&storage_provider)
                .unwrap();

            let disk_layout_file = index_writer.disk_index_file();
            let mut rust_disk_layout: Vec<u8> = Vec::new();
            storage_provider
                .open_reader(disk_layout_file.as_str())
                .unwrap()
                .read_to_end(&mut rust_disk_layout)
                .unwrap();
            let mut truth_disk_layout: Vec<u8> = Vec::new();
            storage_provider
                .open_reader(TRUTH_DISK_LAYOUT_LOW_DIM)
                .unwrap()
                .read_to_end(&mut truth_disk_layout)
                .unwrap();

            // Assert that the metadata in the rust disk index is compatible with the truth disk index.
            assert!(
                rust_disk_layout[8..TRUTH_DISK_LAYOUT_METADATA_LENGTH]
                    == truth_disk_layout[8..TRUTH_DISK_LAYOUT_METADATA_LENGTH]
            );

            // Assert that the rest of the disk index is identical to the truth disk index.
            assert!(
                rust_disk_layout[DEFAULT_DISK_SECTOR_LEN..]
                    == truth_disk_layout[DEFAULT_DISK_SECTOR_LEN..]
            );
            disk_layout_file
        };

        storage_provider
            .delete(disk_layout_file.as_str())
            .expect("Failed to delete file");
    }

    #[test]
    fn create_disk_layout_test_low_dim_with_associated_data() {
        let storage_provider = VirtualStorageProvider::new_overlay(test_data_root());

        const TEST_DATA_FILE_LOW_DIM: &str = "/sift/siftsmall_learn_256pts.fbin";
        const DISK_INDEX_PATH_PREFIX_LOW_DIM_WITH_ASSOCIATED_DATA: &str = "/disk_index_misc/disk_index_siftsmall_learn_256pts_R4_L50_A1.2_index_writer_with_associated_data";
        const TRUTH_DISK_LAYOUT_LOW_DIM: &str =
            "/disk_index_misc/truth_disk_index_siftsmall_learn_256pts_R4_L50_A1.2_disk.index";

        const ASSOCIATED_DATA_FILE: &str = "/sift/siftsmall_learn_256pts_u32_associated_data.fbin";

        let disk_layout_file = {
            let index_writer = DiskIndexWriter::new(
                TEST_DATA_FILE_LOW_DIM.to_string(),
                DISK_INDEX_PATH_PREFIX_LOW_DIM_WITH_ASSOCIATED_DATA.to_string(),
                Some(ASSOCIATED_DATA_FILE.to_string()),
                DEFAULT_DISK_SECTOR_LEN,
            )
            .unwrap();

            index_writer
                .create_disk_layout::<GraphDataF32VectorU32Data, _>(&storage_provider)
                .unwrap();

            let disk_layout_file = index_writer.disk_index_file();

            let mut rust_disk_layout: Vec<u8> = Vec::new();
            storage_provider
                .open_reader(disk_layout_file.as_str())
                .unwrap()
                .read_to_end(&mut rust_disk_layout)
                .unwrap();
            let mut truth_disk_layout: Vec<u8> = Vec::new();
            storage_provider
                .open_reader(TRUTH_DISK_LAYOUT_LOW_DIM)
                .unwrap()
                .read_to_end(&mut truth_disk_layout)
                .unwrap();

            let mut associated_data: Vec<u8> = Vec::new();
            storage_provider
                .open_reader(ASSOCIATED_DATA_FILE)
                .unwrap()
                .read_to_end(&mut associated_data)
                .unwrap();

            compare_disk_index_graphs(&rust_disk_layout, &truth_disk_layout);
            compare_disk_index_graphs_associated_data::<u32>(&rust_disk_layout, &associated_data);
            disk_layout_file
        };

        storage_provider
            .delete(disk_layout_file.as_str())
            .expect("Failed to delete file");
    }

    #[test]
    fn create_disk_layout_test_high_dim() {
        let storage_provider = VirtualStorageProvider::new_overlay(test_data_root());

        const TEST_DATA_FILE_HIGH_DIM: &str = "/disk_index_misc/rand_float_1024D_1K_norm1.0.bin";
        const DISK_INDEX_PATH_PREFIX_HIGH_DIM: &str =
            "/disk_index_misc/disk_index_rand_float_1024D_1Kpts_R4_L50_A1.2_index_writer";
        const TRUTH_DISK_LAYOUT_HIGH_DIM: &str =
            "/disk_index_misc/truth_disk_index_rand_float_1024D_1Kpts_R4_L50_A1.2_index_writer.index";

        let disk_layout_file = {
            let index_writer = DiskIndexWriter::new(
                TEST_DATA_FILE_HIGH_DIM.to_string(),
                DISK_INDEX_PATH_PREFIX_HIGH_DIM.to_string(),
                None,
                DEFAULT_DISK_SECTOR_LEN,
            )
            .unwrap();

            index_writer
                .create_disk_layout::<GraphDataF32VectorU32Data, _>(&storage_provider)
                .unwrap();

            let disk_layout_file = index_writer.disk_index_file();
            let mut rust_disk_layout: Vec<u8> = Vec::new();
            storage_provider
                .open_reader(disk_layout_file.as_str())
                .unwrap()
                .read_to_end(&mut rust_disk_layout)
                .unwrap();
            let mut truth_disk_layout: Vec<u8> = Vec::new();
            storage_provider
                .open_reader(TRUTH_DISK_LAYOUT_HIGH_DIM)
                .unwrap()
                .read_to_end(&mut truth_disk_layout)
                .unwrap();

            // Assert that the metadata in the rust disk index is compatible with the truth disk index.
            assert!(
                rust_disk_layout[8..TRUTH_DISK_LAYOUT_METADATA_LENGTH]
                    == truth_disk_layout[8..TRUTH_DISK_LAYOUT_METADATA_LENGTH]
            );

            // Assert that the rest of the disk index is identical to the truth disk index.
            assert!(
                rust_disk_layout[DEFAULT_DISK_SECTOR_LEN..]
                    == truth_disk_layout[DEFAULT_DISK_SECTOR_LEN..]
            );
            disk_layout_file
        };

        storage_provider
            .delete(disk_layout_file.as_str())
            .expect("Failed to delete file");
    }

    #[test]
    fn disk_index_writer_rejects_small_block_size() {
        let small_block_size = GraphMetadata::get_size() - 1;
        let result = DiskIndexWriter::new(
            "dataset".to_string(),
            "index".to_string(),
            None,
            small_block_size,
        );

        assert!(result.is_err());
    }

    struct ExpectedWriter {
        dataset_file: String,
        index_path_prefix: String,
        pq_pivot_file: String,
        compressed_pq_file: String,
        disk_index_pq_pivot_file: String,
        disk_index_compressed_pq_file: String,
        associated_data_file: Option<String>,
    }

    fn assert_writer_eq_expected(writer: &DiskIndexWriter, expected: &ExpectedWriter) {
        assert_eq!(writer.dataset_file(), &expected.dataset_file);
        assert_eq!(writer.index_path_prefix(), &expected.index_path_prefix);

        assert_eq!(writer.get_pq_pivot_file(), expected.pq_pivot_file);
        assert_eq!(
            writer.get_compressed_pq_pivot_file(),
            expected.compressed_pq_file
        );
        assert_eq!(
            writer.get_disk_index_pq_pivot_file(),
            expected.disk_index_pq_pivot_file
        );
        assert_eq!(
            writer.get_disk_index_compressed_pq_file(),
            expected.disk_index_compressed_pq_file
        );
        assert_eq!(
            writer.get_associated_data_file(),
            expected.associated_data_file
        );
    }

    #[test]
    fn test_dataset_file_and_index_path_prefix() {
        let dataset_file_name = "dataset_file.txt";
        let index_path_prefix = "index_path_prefix";
        let associated_data_file = "associated_data_file.txt";
        let writer = DiskIndexWriter::new(
            dataset_file_name.to_string(),
            index_path_prefix.to_string(),
            Some(associated_data_file.to_string()),
            DEFAULT_DISK_SECTOR_LEN,
        )
        .unwrap();
        let expected = ExpectedWriter {
            dataset_file: dataset_file_name.to_string(),
            index_path_prefix: index_path_prefix.to_string(),
            pq_pivot_file: get_pq_pivot_file(index_path_prefix),
            compressed_pq_file: get_compressed_pq_file(index_path_prefix),
            disk_index_pq_pivot_file: get_disk_index_pq_pivot_file(index_path_prefix),
            disk_index_compressed_pq_file: get_disk_index_compressed_pq_file(index_path_prefix),
            associated_data_file: Some(associated_data_file.to_string()),
        };
        assert_writer_eq_expected(&writer, &expected);
    }

    #[test]
    fn test_get_mem_index_file() {
        let writer = DiskIndexWriter::new(
            "test_index".to_string(),
            "test_dataset".to_string(),
            None,
            DEFAULT_DISK_SECTOR_LEN,
        )
        .unwrap();

        assert_eq!(writer.get_mem_index_file(), "test_dataset_mem.index");
    }

    #[test]
    fn test_get_merged_index_prefix() {
        let writer = DiskIndexWriter::new(
            "test_index".to_string(),
            "test_dataset".to_string(),
            None,
            DEFAULT_DISK_SECTOR_LEN,
        )
        .unwrap();

        assert_eq!(
            writer.get_merged_index_prefix(),
            "test_dataset_mem.index_tempFiles"
        );
    }

    #[test]
    fn test_get_merged_index_subshard_id_map_file() {
        let prefix = "test_prefix";
        let shard = 5;
        assert_eq!(
            DiskIndexWriter::get_merged_index_subshard_id_map_file(prefix, shard),
            "test_prefix_subshard-5_ids_uint32.bin"
        );
    }

    #[test]
    fn test_get_merged_index_subshard_data_file() {
        let prefix = "test_prefix";
        let shard = 5;
        assert_eq!(
            DiskIndexWriter::get_merged_index_subshard_data_file(prefix, shard),
            "test_prefix_subshard-5.bin"
        );
    }

    #[test]
    fn test_get_merged_index_subshard_prefix() {
        let prefix = "test_prefix";
        let shard = 5;
        assert_eq!(
            DiskIndexWriter::get_merged_index_subshard_prefix(prefix, shard),
            "test_prefix_subshard-5"
        );
    }

    #[test]
    fn test_get_merged_index_subshard_mem_index_file() {
        let prefix = "test_prefix";
        let shard = 5;
        assert_eq!(
            DiskIndexWriter::get_merged_index_subshard_mem_index_file(prefix, shard),
            "test_prefix_subshard-5_mem.index"
        );
    }

    #[test]
    fn test_get_merged_index_subshard_mem_dataset_file() {
        let prefix = "test_prefix";
        assert_eq!(
            DiskIndexWriter::get_merged_index_subshard_mem_dataset_file(prefix),
            "test_prefix.data"
        );
    }

    #[test]
    fn test_disk_index_writer_state_uninitialized() {
        let mut state = DiskIndexWriterState::<VirtualStorageProvider<OverlayFS>>::new();
        let mut buf = vec![0u8; 16];
        assert!(DiskIndexWriter::get_neighbors_number(&mut state).is_err());
        assert!(DiskIndexWriter::copy_neighbors(&mut state, &mut buf).is_err());
    }

    // Compare that the index built in the test matches the truth index. The truth index has no
    // associated data, so only the vector and neighbor data are compared.
    pub fn compare_disk_index_graphs(graph_data: &[u8], truth_graph_data: &[u8]) {
        let graph_header = GraphHeader::try_from(&graph_data[8..]).unwrap();
        let truth_graph_header = GraphHeader::try_from(&truth_graph_data[8..]).unwrap();

        let test_node_per_block = graph_header.metadata().num_nodes_per_block;
        let test_max_node_length = graph_header.metadata().node_len;

        let truth_node_per_block = truth_graph_header.metadata().num_nodes_per_block;
        let truth_max_node_length = truth_graph_header.metadata().node_len;

        assert_eq!(
            graph_header.metadata().num_pts,
            truth_graph_header.metadata().num_pts
        );

        assert_eq!(
            graph_header.metadata().dims,
            truth_graph_header.metadata().dims
        );

        let num_pts = graph_header.metadata().num_pts as usize;
        let dim = graph_header.metadata().dims;

        for idx in 0..num_pts {
            let test_node_id_offset = node_data_offset(
                idx,
                test_max_node_length as usize,
                test_node_per_block as usize,
                DEFAULT_DISK_SECTOR_LEN,
            );

            let truth_node_id_offset = node_data_offset(
                idx,
                truth_max_node_length as usize,
                truth_node_per_block as usize,
                DEFAULT_DISK_SECTOR_LEN,
            );

            // Assert that the vector data is the same between the test and truth graphs for this node.
            assert_eq!(
                &graph_data
                    [test_node_id_offset..test_node_id_offset + dim * std::mem::size_of::<f32>()],
                &truth_graph_data
                    [truth_node_id_offset..truth_node_id_offset + dim * std::mem::size_of::<f32>()]
            );

            // Assert that the neighbor count is the same between the test and truth graphs for this node.
            let test_nbr_cnt_offset = test_node_id_offset + dim * std::mem::size_of::<f32>();
            let truth_nbr_cnt_offset = truth_node_id_offset + dim * std::mem::size_of::<f32>();

            let test_nbr_count = u32::from_le_bytes([
                graph_data[test_nbr_cnt_offset],
                graph_data[test_nbr_cnt_offset + 1],
                graph_data[test_nbr_cnt_offset + 2],
                graph_data[test_nbr_cnt_offset + 3],
            ]);

            let truth_nbr_count = u32::from_le_bytes([
                truth_graph_data[truth_nbr_cnt_offset],
                truth_graph_data[truth_nbr_cnt_offset + 1],
                truth_graph_data[truth_nbr_cnt_offset + 2],
                truth_graph_data[truth_nbr_cnt_offset + 3],
            ]);

            assert_eq!(test_nbr_count, truth_nbr_count);

            // Assert the neighbors (u32) are the same between the test and truth graphs for this node.
            let test_nbr_offset = test_nbr_cnt_offset + 4;
            let truth_nbr_offset = truth_nbr_cnt_offset + 4;
            assert_eq!(
                graph_data[test_nbr_offset..test_nbr_offset + test_nbr_count as usize * 4],
                truth_graph_data[truth_nbr_offset..truth_nbr_offset + truth_nbr_count as usize * 4]
            );
        }
    }

    // Compare that the associated data in the index graph built in the test matches the associated data input.
    pub fn compare_disk_index_graphs_associated_data<AssociatedDataType>(
        graph_data: &[u8],
        associated_data: &[u8],
    ) {
        let graph_header = GraphHeader::try_from(&graph_data[8..]).unwrap();
        let test_node_per_block = graph_header.metadata().num_nodes_per_block;
        let test_max_node_length = graph_header.metadata().node_len as usize;

        let mut associated_data_offset = 0;
        let data_npts = u32::from_le_bytes([
            associated_data[associated_data_offset],
            associated_data[associated_data_offset + 1],
            associated_data[associated_data_offset + 2],
            associated_data[associated_data_offset + 3],
        ]) as usize;

        associated_data_offset = 4;
        let _associated_data_length = u32::from_le_bytes([
            associated_data[associated_data_offset],
            associated_data[associated_data_offset + 1],
            associated_data[associated_data_offset + 2],
            associated_data[associated_data_offset + 3],
        ]) as usize;

        let num_pts = graph_header.metadata().num_pts as usize;
        assert_eq!(num_pts, data_npts);

        associated_data_offset = 8;

        for idx in 0..num_pts {
            let test_node_id_offset = node_data_offset(
                idx,
                test_max_node_length,
                test_node_per_block as usize,
                DEFAULT_DISK_SECTOR_LEN,
            );

            let node_buf_end = test_node_id_offset + test_max_node_length;

            assert_eq!(
                graph_data[(node_buf_end - mem::size_of::<AssociatedDataType>())..node_buf_end],
                associated_data[associated_data_offset
                    ..associated_data_offset + mem::size_of::<AssociatedDataType>()]
            );

            associated_data_offset += mem::size_of::<AssociatedDataType>();
        }
    }

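    // Computes the absolute byte offset of a node's record in the disk layout: skip to the
    // node's block, then to its slot within that block; the trailing `+ block_size` skips
    // the metadata block at the start of the file. Assumes nodes_per_block > 0, i.e. the
    // single-sector-node layout used by these tests.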
    pub fn node_data_offset(
        node_id: usize,
        node_length: usize,
        nodes_per_block: usize,
        block_size: usize,
    ) -> usize {
        let block_id = node_id / nodes_per_block;
        let node_id_in_block = node_id % nodes_per_block;
        let offset = block_id * block_size + node_id_in_block * node_length;
        offset + block_size
    }
}