1use std::io::{Read, Seek};
6
7use diskann::{ANNError, ANNResult};
8use diskann_providers::storage::StorageReadProvider;
9use tracing::info;
10
11pub struct CachedReader<Storage>
13where
14 Storage: StorageReadProvider,
15{
16 reader: Storage::Reader,
18
19 cache_size: u64,
21
22 cache_buf: Vec<u8>,
24
25 cur_off: u64,
27
28 size: u64,
30}
31
32impl<Storage> CachedReader<Storage>
33where
34 Storage: StorageReadProvider,
35{
36 pub fn new(
37 filename: &str,
38 cache_size: u64,
39 storage_provider: &Storage,
40 ) -> std::io::Result<Self> {
41 info!("Opening: {}", filename);
42 let mut reader = storage_provider.open_reader(filename)?;
43 let size = storage_provider.get_length(filename)?;
44
45 let cache_size = cache_size.min(size);
46 let mut cache_buf = vec![0; cache_size as usize];
47 reader.read_exact(&mut cache_buf)?;
48 info!(
49 "Opened: {}, size: {}, cache_size: {}",
50 filename, size, cache_size
51 );
52
53 Ok(Self {
54 reader,
55 cache_size,
56 cache_buf,
57 cur_off: 0,
58 size,
59 })
60 }
61
62 pub fn get_file_size(&self) -> u64 {
63 self.size
64 }
65
66 pub fn read(&mut self, read_buf: &mut [u8]) -> ANNResult<()> {
67 let n_bytes = read_buf.len() as u64;
68 if n_bytes <= (self.cache_size - self.cur_off) {
69 read_buf.copy_from_slice(
71 &self.cache_buf
72 [(self.cur_off as usize)..(self.cur_off as usize + n_bytes as usize)],
73 );
74 self.cur_off += n_bytes;
75 } else {
76 let cached_bytes = self.cache_size - self.cur_off;
78 if n_bytes - cached_bytes > self.size - self.reader.stream_position()? {
79 return Err(ANNError::log_index_error(format_args!(
80 "Reading beyond end of file, n_bytes: {} cached_bytes: {} fsize: {} current pos: {}",
81 n_bytes,
82 cached_bytes,
83 self.size,
84 self.reader.stream_position()?
85 )));
86 }
87
88 read_buf[..cached_bytes as usize]
89 .copy_from_slice(&self.cache_buf[self.cur_off as usize..]);
90 self.reader
92 .read_exact(&mut read_buf[cached_bytes as usize..])?;
93 self.cur_off = self.cache_size;
95
96 let size_left = self.size - self.reader.stream_position()?;
97 if size_left >= self.cache_size {
98 self.reader.read_exact(&mut self.cache_buf)?;
99 self.cur_off = 0;
100 }
101 }
104 Ok(())
105 }
106
107 pub fn read_u32(&mut self) -> ANNResult<u32> {
108 let mut bytes = [0u8; 4];
109 self.read(&mut bytes)?;
110 Ok(u32::from_le_bytes(bytes))
111 }
112}
113
#[cfg(test)]
mod cached_reader_test {

    use diskann_providers::storage::{StorageWriteProvider, VirtualStorageProvider};
    use vfs::MemoryFS;

    use super::*;

    /// Storage backend used by every test in this module.
    type TestStorage = VirtualStorageProvider<MemoryFS>;

    /// 72-byte fixture shared by all tests.
    const FIXTURE: [u8; 72] = [
        2, 0, 1, 2, 8, 0, 1, 3, 0x00, 0x01, 0x80, 0x3f, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00,
        0x40, 0x40, 0x00, 0x00, 0x80, 0x40, 0x00, 0x00, 0xa0, 0x40, 0x00, 0x00, 0xc0, 0x40,
        0x00, 0x00, 0xe0, 0x40, 0x00, 0x00, 0x00, 0x41, 0x00, 0x00, 0x10, 0x41, 0x00, 0x00,
        0x20, 0x41, 0x00, 0x00, 0x30, 0x41, 0x00, 0x00, 0x40, 0x41, 0x00, 0x00, 0x50, 0x41,
        0x00, 0x00, 0x60, 0x41, 0x00, 0x00, 0x70, 0x41, 0x00, 0x11, 0x80, 0x41,
    ];

    /// Builds an in-memory storage provider containing `FIXTURE` at `file_name`.
    fn storage_with_fixture(file_name: &str) -> TestStorage {
        let storage_provider = VirtualStorageProvider::new(MemoryFS::new());
        {
            // Scope the writer so it is dropped before the file is read back.
            let mut writer = storage_provider.create_for_write(file_name).unwrap();
            writer.write_all(&FIXTURE).unwrap();
        }
        storage_provider
    }

    #[test]
    fn cached_reader_works() {
        let file_name = "/cached_reader_works_test2.bin";
        let storage_provider = storage_with_fixture(file_name);

        let mut reader = CachedReader::<TestStorage>::new(file_name, 8, &storage_provider).unwrap();
        assert_eq!(reader.get_file_size(), 72);
        assert_eq!(reader.cache_size, 8);

        // Served entirely from the cache.
        let mut all_from_cache_buf = vec![0; 4];
        reader.read(&mut all_from_cache_buf).unwrap();
        assert_eq!(all_from_cache_buf, [2, 0, 1, 2]);
        assert_eq!(reader.cur_off, 4);

        // Four bytes from the cache, two from the file; the cache is then refilled.
        let mut partial_from_cache_buf = vec![0; 6];
        reader.read(&mut partial_from_cache_buf).unwrap();
        assert_eq!(partial_from_cache_buf, [8, 0, 1, 3, 0x00, 0x01]);
        assert_eq!(reader.cur_off, 0);

        // Request larger than the cache: bytes 10..70 of the fixture.
        let mut over_cache_size_buf = vec![0; 60];
        reader.read(&mut over_cache_size_buf).unwrap();
        assert_eq!(over_cache_size_buf.as_slice(), &FIXTURE[10..70]);

        // Fewer than `cache_size` bytes remain, so the cache stays exhausted.
        let mut remaining_less_than_cache_size_buf = vec![0; 2];
        reader
            .read(&mut remaining_less_than_cache_size_buf)
            .unwrap();
        assert_eq!(remaining_less_than_cache_size_buf, [0x80, 0x41]);
        assert_eq!(reader.cur_off, reader.cache_size);

        storage_provider
            .delete(file_name)
            .expect("Failed to delete file");
    }

    #[test]
    #[should_panic(expected = "Reading beyond end of file")]
    fn failed_for_reading_beyond_end_of_file() {
        let file_name = "/failed_for_reading_beyond_end_of_file_test_2.bin";
        let storage_provider = storage_with_fixture(file_name);

        let mut reader = CachedReader::<TestStorage>::new(file_name, 8, &storage_provider).unwrap();
        storage_provider
            .delete(file_name)
            .expect("Failed to delete file");

        // One byte more than the file holds.
        let mut over_size_buf = vec![0; 73];
        reader.read(&mut over_size_buf).unwrap();
    }
}