Skip to main content

diskann_disk/storage/
disk_index_reader.rs

/*
 * Copyright (c) Microsoft Corporation.
 * Licensed under the MIT license.
 */
5use std::{marker::PhantomData, sync::Arc};
6
7use diskann::ANNResult;
8use diskann_providers::storage::StorageReadProvider;
9use diskann_providers::{model::pq::PQData, storage::PQStorage, utils::load_metadata_from_file};
10use tracing::info;
11
12/// This struct is used by the DiskIndexSearcher to read the index data from storage. Noted that the index data here is different from index graph,
13/// It includes the PQ data, pivot table, and the warmup query data.
14/// The Storage acts as a provider to read the data from storage system.
15/// The storage provider should be provided as a generic type and be specified by the caller when it initializes the DiskIndexSearcher.
16pub struct DiskIndexReader<VectorType> {
17    phantom: PhantomData<VectorType>,
18
19    pq_data: Arc<PQData>,
20
21    num_points: usize,
22}
23
24impl<VectorType> DiskIndexReader<VectorType> {
25    /// Create DiskIndexReader instance
26    pub fn new<Storage: StorageReadProvider>(
27        pq_pivot_path: String,
28        pq_compressed_data_path: String,
29        storage_provider: &Storage,
30    ) -> ANNResult<Self> {
31        let pq_storage = PQStorage::new(&pq_pivot_path, &pq_compressed_data_path, None);
32        let pq_pivot_table = pq_storage.load_pq_pivots_bin::<Storage>(
33            &pq_pivot_path,
34            0, // Use 0 to infer num_pq_chunks from the file
35            storage_provider,
36        )?;
37
38        // Auto-detect number of points from compressed PQ file metadata
39        let metadata = load_metadata_from_file(storage_provider, &pq_compressed_data_path)?;
40
41        let pq_compressed_data = PQStorage::load_pq_compressed_vectors_bin::<Storage>(
42            &pq_compressed_data_path,
43            metadata.npoints(),
44            pq_pivot_table.get_num_chunks(),
45            storage_provider,
46        )?;
47        info!(
48            "Loaded PQ centroids and in-memory compressed vectors. #points:{} #pq_chunks: {}",
49            metadata.npoints(),
50            pq_pivot_table.get_num_chunks()
51        );
52
53        Ok(DiskIndexReader {
54            phantom: PhantomData,
55            pq_data: Arc::<PQData>::new(PQData::new(pq_pivot_table, pq_compressed_data)?),
56            num_points: metadata.npoints(),
57        })
58    }
59
60    pub fn get_pq_data(&self) -> Arc<PQData> {
61        Arc::clone(&self.pq_data)
62    }
63
64    pub fn get_num_points(&self) -> usize {
65        self.num_points
66    }
67}
68
#[cfg(test)]
mod disk_index_storage_test {
    use diskann::ANNErrorKind;
    use diskann_providers::storage::VirtualStorageProvider;
    use diskann_utils::test_data_root;
    use vfs::OverlayFS;

    use super::*;

    /// Path prefix of the SIFT-small PQ fixtures, relative to the test data root.
    const SIFTSMALL_PREFIX: &str = "/sift/siftsmall_learn";

    /// Builds a reader over `<prefix>_pq_pivots.bin` and `<prefix>_pq_compressed.bin`,
    /// resolved against an overlay of the test data root.
    fn load_reader(prefix: &str) -> ANNResult<DiskIndexReader<f32>> {
        let storage_provider = VirtualStorageProvider::new_overlay(test_data_root());
        DiskIndexReader::<f32>::new::<VirtualStorageProvider<OverlayFS>>(
            format!("{prefix}_pq_pivots.bin"),
            format!("{prefix}_pq_compressed.bin"),
            &storage_provider,
        )
    }

    #[test]
    fn load_pivot_test() {
        let reader = load_reader(SIFTSMALL_PREFIX).unwrap();

        // Creating the backend storage is sufficient to verify the constraints on the
        // PQ schema as both `FixedChunkPQTable` and the possible alternatives (such as
        // `quantization::TransposedTable`) check for the well-formedness of the schema.
        let _: Arc<PQData> = reader.get_pq_data();
    }

    #[test]
    fn load_pivot_file_not_exist_test() {
        let Err(err) = load_reader("/sift/siftsmall_learn_file_not_exist") else {
            panic!("this function should not have succeeded");
        };
        assert_eq!(err.kind(), ANNErrorKind::PQError);
        assert!(err.to_string().contains("PQ k-means pivot file not found"));
    }

    #[test]
    fn test_get_num_points() {
        let reader = load_reader(SIFTSMALL_PREFIX).unwrap();
        assert_eq!(reader.get_num_points(), 25000);
    }
}