Skip to main content

diskann_disk/storage/cached_reader.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5use std::io::{Read, Seek};
6
7use diskann::{ANNError, ANNResult};
8use diskann_providers::storage::StorageReadProvider;
9use tracing::info;
10
11/// Sequential cached reads with a generic storage provider with read access.
12pub struct CachedReader<Storage>
13where
14    Storage: StorageReadProvider,
15{
16    /// File reader
17    reader: Storage::Reader,
18
19    /// # bytes to cache in one shot read
20    cache_size: u64,
21
22    /// Underlying buf for cache
23    cache_buf: Vec<u8>,
24
25    /// Offset into cache_buf for cur_pos
26    cur_off: u64,
27
28    /// File size
29    size: u64,
30}
31
32impl<Storage> CachedReader<Storage>
33where
34    Storage: StorageReadProvider,
35{
36    pub fn new(
37        filename: &str,
38        cache_size: u64,
39        storage_provider: &Storage,
40    ) -> std::io::Result<Self> {
41        info!("Opening: {}", filename);
42        let mut reader = storage_provider.open_reader(filename)?;
43        let size = storage_provider.get_length(filename)?;
44
45        let cache_size = cache_size.min(size);
46        let mut cache_buf = vec![0; cache_size as usize];
47        reader.read_exact(&mut cache_buf)?;
48        info!(
49            "Opened: {}, size: {}, cache_size: {}",
50            filename, size, cache_size
51        );
52
53        Ok(Self {
54            reader,
55            cache_size,
56            cache_buf,
57            cur_off: 0,
58            size,
59        })
60    }
61
62    pub fn get_file_size(&self) -> u64 {
63        self.size
64    }
65
66    pub fn read(&mut self, read_buf: &mut [u8]) -> ANNResult<()> {
67        let n_bytes = read_buf.len() as u64;
68        if n_bytes <= (self.cache_size - self.cur_off) {
69            // case 1: cache contains all data
70            read_buf.copy_from_slice(
71                &self.cache_buf
72                    [(self.cur_off as usize)..(self.cur_off as usize + n_bytes as usize)],
73            );
74            self.cur_off += n_bytes;
75        } else {
76            // case 2: cache contains some data
77            let cached_bytes = self.cache_size - self.cur_off;
78            if n_bytes - cached_bytes > self.size - self.reader.stream_position()? {
79                return Err(ANNError::log_index_error(format_args!(
80                    "Reading beyond end of file, n_bytes: {} cached_bytes: {} fsize: {} current pos: {}",
81                    n_bytes,
82                    cached_bytes,
83                    self.size,
84                    self.reader.stream_position()?
85                )));
86            }
87
88            read_buf[..cached_bytes as usize]
89                .copy_from_slice(&self.cache_buf[self.cur_off as usize..]);
90            // go to disk and fetch more data
91            self.reader
92                .read_exact(&mut read_buf[cached_bytes as usize..])?;
93            // reset cur off
94            self.cur_off = self.cache_size;
95
96            let size_left = self.size - self.reader.stream_position()?;
97            if size_left >= self.cache_size {
98                self.reader.read_exact(&mut self.cache_buf)?;
99                self.cur_off = 0;
100            }
101            // note that if size_left < cache_size, then cur_off = cache_size,
102            // so subsequent reads will all be directly from file
103        }
104        Ok(())
105    }
106
107    pub fn read_u32(&mut self) -> ANNResult<u32> {
108        let mut bytes = [0u8; 4];
109        self.read(&mut bytes)?;
110        Ok(u32::from_le_bytes(bytes))
111    }
112}
113
#[cfg(test)]
mod cached_reader_test {

    use diskann_providers::storage::{StorageWriteProvider, VirtualStorageProvider};
    use vfs::MemoryFS;

    use super::*;

    /// 72-byte fixture shared by the tests below.
    ///
    /// The tests only compare raw bytes, so this is not a well-formed vector file. It has
    /// 8 arbitrary leading bytes, followed by 64 bytes that decode, as little-endian
    /// `f32`, to values close to 1.0 through 16.0 (a few bytes are perturbed).
    const TEST_DATA: [u8; 72] = [
        2, 0, 1, 2, 8, 0, 1, 3, 0x00, 0x01, 0x80, 0x3f, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00,
        0x40, 0x40, 0x00, 0x00, 0x80, 0x40, 0x00, 0x00, 0xa0, 0x40, 0x00, 0x00, 0xc0, 0x40,
        0x00, 0x00, 0xe0, 0x40, 0x00, 0x00, 0x00, 0x41, 0x00, 0x00, 0x10, 0x41, 0x00, 0x00,
        0x20, 0x41, 0x00, 0x00, 0x30, 0x41, 0x00, 0x00, 0x40, 0x41, 0x00, 0x00, 0x50, 0x41,
        0x00, 0x00, 0x60, 0x41, 0x00, 0x00, 0x70, 0x41, 0x00, 0x11, 0x80, 0x41,
    ];

    /// Creates an in-memory storage provider containing `TEST_DATA` at `file_name`.
    fn storage_with_test_data(file_name: &str) -> VirtualStorageProvider<MemoryFS> {
        let storage_provider = VirtualStorageProvider::new_memory();
        // Scope the writer so it is dropped before the file is read back.
        {
            let mut writer = storage_provider.create_for_write(file_name).unwrap();
            writer.write_all(&TEST_DATA).unwrap();
        }
        storage_provider
    }

    #[test]
    fn cached_reader_works() {
        let file_name = "/cached_reader_works_test2.bin";
        let storage_provider = storage_with_test_data(file_name);

        let mut reader =
            CachedReader::<VirtualStorageProvider<MemoryFS>>::new(file_name, 8, &storage_provider)
                .unwrap();
        assert_eq!(reader.get_file_size(), 72);
        assert_eq!(reader.cache_size, 8);

        let mut all_from_cache_buf = vec![0; 4];
        reader.read(all_from_cache_buf.as_mut_slice()).unwrap();
        assert_eq!(all_from_cache_buf, [2, 0, 1, 2]);
        assert_eq!(reader.cur_off, 4);

        let mut partial_from_cache_buf = vec![0; 6];
        reader.read(partial_from_cache_buf.as_mut_slice()).unwrap();
        assert_eq!(partial_from_cache_buf, [8, 0, 1, 3, 0x00, 0x01]);
        assert_eq!(reader.cur_off, 0);

        let mut over_cache_size_buf = vec![0; 60];
        reader.read(over_cache_size_buf.as_mut_slice()).unwrap();
        assert_eq!(
            over_cache_size_buf,
            [
                0x80, 0x3f, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x40, 0x40, 0x00, 0x00, 0x80, 0x40,
                0x00, 0x00, 0xa0, 0x40, 0x00, 0x00, 0xc0, 0x40, 0x00, 0x00, 0xe0, 0x40, 0x00, 0x00,
                0x00, 0x41, 0x00, 0x00, 0x10, 0x41, 0x00, 0x00, 0x20, 0x41, 0x00, 0x00, 0x30, 0x41,
                0x00, 0x00, 0x40, 0x41, 0x00, 0x00, 0x50, 0x41, 0x00, 0x00, 0x60, 0x41, 0x00, 0x00,
                0x70, 0x41, 0x00, 0x11
            ]
        );

        let mut remaining_less_than_cache_size_buf = vec![0; 2];
        reader
            .read(remaining_less_than_cache_size_buf.as_mut_slice())
            .unwrap();
        assert_eq!(remaining_less_than_cache_size_buf, [0x80, 0x41]);
        assert_eq!(reader.cur_off, reader.cache_size);

        storage_provider
            .delete(file_name)
            .expect("Failed to delete file");
    }

    #[test]
    #[should_panic(expected = "Reading beyond end of file")]
    fn failed_for_reading_beyond_end_of_file() {
        let file_name = "/failed_for_reading_beyond_end_of_file_test_2.bin";
        let storage_provider = storage_with_test_data(file_name);

        let mut reader =
            CachedReader::<VirtualStorageProvider<MemoryFS>>::new(file_name, 8, &storage_provider)
                .unwrap();
        // Clean up before the expected panic; a delete placed after the failing read
        // would never run.
        storage_provider
            .delete(file_name)
            .expect("Failed to delete file");

        let mut over_size_buf = vec![0; 73];
        reader.read(over_size_buf.as_mut_slice()).unwrap();
    }
}