1use std::io::{Read, Seek};
6
7use diskann::{ANNError, ANNResult};
8use diskann_providers::storage::StorageReadProvider;
9use tracing::info;
10
11pub struct CachedReader<Storage>
13where
14 Storage: StorageReadProvider,
15{
16 reader: Storage::Reader,
18
19 cache_size: u64,
21
22 cache_buf: Vec<u8>,
24
25 cur_off: u64,
27
28 size: u64,
30}
31
32impl<Storage> CachedReader<Storage>
33where
34 Storage: StorageReadProvider,
35{
36 pub fn new(
37 filename: &str,
38 cache_size: u64,
39 storage_provider: &Storage,
40 ) -> std::io::Result<Self> {
41 info!("Opening: {}", filename);
42 let mut reader = storage_provider.open_reader(filename)?;
43 let size = storage_provider.get_length(filename)?;
44
45 let cache_size = cache_size.min(size);
46 let mut cache_buf = vec![0; cache_size as usize];
47 reader.read_exact(&mut cache_buf)?;
48 info!(
49 "Opened: {}, size: {}, cache_size: {}",
50 filename, size, cache_size
51 );
52
53 Ok(Self {
54 reader,
55 cache_size,
56 cache_buf,
57 cur_off: 0,
58 size,
59 })
60 }
61
62 pub fn get_file_size(&self) -> u64 {
63 self.size
64 }
65
66 pub fn read(&mut self, read_buf: &mut [u8]) -> ANNResult<()> {
67 let n_bytes = read_buf.len() as u64;
68 if n_bytes <= (self.cache_size - self.cur_off) {
69 read_buf.copy_from_slice(
71 &self.cache_buf
72 [(self.cur_off as usize)..(self.cur_off as usize + n_bytes as usize)],
73 );
74 self.cur_off += n_bytes;
75 } else {
76 let cached_bytes = self.cache_size - self.cur_off;
78 if n_bytes - cached_bytes > self.size - self.reader.stream_position()? {
79 return Err(ANNError::log_index_error(format_args!(
80 "Reading beyond end of file, n_bytes: {} cached_bytes: {} fsize: {} current pos: {}",
81 n_bytes,
82 cached_bytes,
83 self.size,
84 self.reader.stream_position()?
85 )));
86 }
87
88 read_buf[..cached_bytes as usize]
89 .copy_from_slice(&self.cache_buf[self.cur_off as usize..]);
90 self.reader
92 .read_exact(&mut read_buf[cached_bytes as usize..])?;
93 self.cur_off = self.cache_size;
95
96 let size_left = self.size - self.reader.stream_position()?;
97 if size_left >= self.cache_size {
98 self.reader.read_exact(&mut self.cache_buf)?;
99 self.cur_off = 0;
100 }
101 }
104 Ok(())
105 }
106
107 pub fn read_u32(&mut self) -> ANNResult<u32> {
108 let mut bytes = [0u8; 4];
109 self.read(&mut bytes)?;
110 Ok(u32::from_le_bytes(bytes))
111 }
112}
113
#[cfg(test)]
mod cached_reader_test {

    use diskann_providers::storage::{StorageWriteProvider, VirtualStorageProvider};
    use vfs::MemoryFS;

    use super::*;

    /// Storage backend used by every test in this module.
    type MemoryStorage = VirtualStorageProvider<MemoryFS>;

    /// Contents written to the backing file for every test (72 bytes).
    const TEST_DATA: [u8; 72] = [
        2, 0, 1, 2, 8, 0, 1, 3, 0x00, 0x01, 0x80, 0x3f, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00,
        0x40, 0x40, 0x00, 0x00, 0x80, 0x40, 0x00, 0x00, 0xa0, 0x40, 0x00, 0x00, 0xc0, 0x40,
        0x00, 0x00, 0xe0, 0x40, 0x00, 0x00, 0x00, 0x41, 0x00, 0x00, 0x10, 0x41, 0x00, 0x00,
        0x20, 0x41, 0x00, 0x00, 0x30, 0x41, 0x00, 0x00, 0x40, 0x41, 0x00, 0x00, 0x50, 0x41,
        0x00, 0x00, 0x60, 0x41, 0x00, 0x00, 0x70, 0x41, 0x00, 0x11, 0x80, 0x41,
    ];

    /// Builds an in-memory storage provider holding `TEST_DATA` at `file_name`.
    fn storage_with_test_file(file_name: &str) -> MemoryStorage {
        let storage_provider = VirtualStorageProvider::new_memory();
        {
            // Scope the writer so it is dropped before the file is read back.
            let mut writer = storage_provider.create_for_write(file_name).unwrap();
            writer.write_all(&TEST_DATA).unwrap();
        }
        storage_provider
    }

    #[test]
    fn cached_reader_works() {
        let file_name = "/cached_reader_works_test2.bin";
        let storage_provider = storage_with_test_file(file_name);

        let mut reader =
            CachedReader::<MemoryStorage>::new(file_name, 8, &storage_provider).unwrap();
        assert_eq!(reader.get_file_size(), 72);
        assert_eq!(reader.cache_size, 8);

        // Bytes 0..4 come entirely from the initial cache window.
        let mut all_from_cache_buf = [0u8; 4];
        reader.read(&mut all_from_cache_buf).unwrap();
        assert_eq!(&all_from_cache_buf[..], &TEST_DATA[0..4]);
        assert_eq!(reader.cur_off, 4);

        // Bytes 4..10: 4 cached bytes plus 2 from the reader, then a refill.
        let mut partial_from_cache_buf = [0u8; 6];
        reader.read(&mut partial_from_cache_buf).unwrap();
        assert_eq!(&partial_from_cache_buf[..], &TEST_DATA[4..10]);
        assert_eq!(reader.cur_off, 0);

        // Bytes 10..70: 8 cached bytes plus 52 from the reader.
        let mut over_cache_size_buf = [0u8; 60];
        reader.read(&mut over_cache_size_buf).unwrap();
        assert_eq!(&over_cache_size_buf[..], &TEST_DATA[10..70]);

        // Bytes 70..72: fewer than `cache_size` bytes were left, so there was
        // no refill and the tail comes straight from the reader.
        let mut remaining_less_than_cache_size_buf = [0u8; 2];
        reader
            .read(&mut remaining_less_than_cache_size_buf)
            .unwrap();
        assert_eq!(&remaining_less_than_cache_size_buf[..], &TEST_DATA[70..72]);
        assert_eq!(reader.cur_off, reader.cache_size);

        storage_provider
            .delete(file_name)
            .expect("Failed to delete file");
    }

    #[test]
    #[should_panic(expected = "Reading beyond end of file")]
    fn failed_for_reading_beyond_end_of_file() {
        let file_name = "/failed_for_reading_beyond_end_of_file_test_2.bin";
        let storage_provider = storage_with_test_file(file_name);

        let mut reader =
            CachedReader::<MemoryStorage>::new(file_name, 8, &storage_provider).unwrap();
        storage_provider
            .delete(file_name)
            .expect("Failed to delete file");

        // Request one byte more than the whole file holds.
        let mut over_size_buf = [0u8; 73];
        reader.read(&mut over_size_buf).unwrap();
    }
}