diskann_disk/storage/
disk_index_reader.rs1use std::{marker::PhantomData, sync::Arc};
6
7use diskann::ANNResult;
8use diskann_providers::storage::StorageReadProvider;
9use diskann_providers::{model::pq::PQData, storage::PQStorage, utils::load_metadata_from_file};
10use tracing::info;
11
12pub struct DiskIndexReader<VectorType> {
17 phantom: PhantomData<VectorType>,
18
19 pq_data: Arc<PQData>,
20
21 num_points: usize,
22}
23
24impl<VectorType> DiskIndexReader<VectorType> {
25 pub fn new<Storage: StorageReadProvider>(
27 pq_pivot_path: String,
28 pq_compressed_data_path: String,
29 storage_provider: &Storage,
30 ) -> ANNResult<Self> {
31 let pq_storage = PQStorage::new(&pq_pivot_path, &pq_compressed_data_path, None);
32 let pq_pivot_table = pq_storage.load_pq_pivots_bin::<Storage>(
33 &pq_pivot_path,
34 0, storage_provider,
36 )?;
37
38 let metadata = load_metadata_from_file(storage_provider, &pq_compressed_data_path)?;
40
41 let pq_compressed_data = PQStorage::load_pq_compressed_vectors_bin::<Storage>(
42 &pq_compressed_data_path,
43 metadata.npoints,
44 pq_pivot_table.get_num_chunks(),
45 storage_provider,
46 )?;
47 info!(
48 "Loaded PQ centroids and in-memory compressed vectors. #points:{} #pq_chunks: {}",
49 metadata.npoints,
50 pq_pivot_table.get_num_chunks()
51 );
52
53 Ok(DiskIndexReader {
54 phantom: PhantomData,
55 pq_data: Arc::<PQData>::new(PQData::new(pq_pivot_table, pq_compressed_data)?),
56 num_points: metadata.npoints,
57 })
58 }
59
60 pub fn get_pq_data(&self) -> Arc<PQData> {
61 Arc::clone(&self.pq_data)
62 }
63
64 pub fn get_num_points(&self) -> usize {
65 self.num_points
66 }
67}
68
#[cfg(test)]
mod disk_index_storage_test {
    use diskann::ANNErrorKind;
    use diskann_providers::storage::VirtualStorageProvider;
    use diskann_utils::test_data_root;
    use vfs::OverlayFS;

    use super::*;

    /// Prefix of the SIFT-small learn-set PQ files shipped in the test-data root.
    const SIFTSMALL_LEARN_PREFIX: &str = "/sift/siftsmall_learn";

    /// Opens a reader over `<prefix>_pq_pivots.bin` / `<prefix>_pq_compressed.bin`
    /// through an overlay of the shared test-data directory.
    fn open_reader(prefix: &str) -> ANNResult<DiskIndexReader<f32>> {
        let storage_provider = VirtualStorageProvider::new_overlay(test_data_root());
        DiskIndexReader::<f32>::new::<VirtualStorageProvider<OverlayFS>>(
            format!("{prefix}_pq_pivots.bin"),
            format!("{prefix}_pq_compressed.bin"),
            &storage_provider,
        )
    }

    #[test]
    fn load_pivot_test() {
        let reader = open_reader(SIFTSMALL_LEARN_PREFIX).unwrap();
        let _: Arc<PQData> = reader.get_pq_data();
    }

    #[test]
    fn load_pivot_file_not_exist_test() {
        // `DiskIndexReader` is not `Debug`, so `unwrap_err` is unavailable; destructure instead.
        let Err(err) = open_reader("/sift/siftsmall_learn_file_not_exist") else {
            panic!("this function should not have succeeded");
        };
        assert_eq!(err.kind(), ANNErrorKind::PQError);
        assert!(err.to_string().contains("PQ k-means pivot file not found"));
    }

    #[test]
    fn test_get_num_points() {
        let reader = open_reader(SIFTSMALL_LEARN_PREFIX).unwrap();
        assert_eq!(reader.get_num_points(), 25000);
    }
}