diskann_disk/storage/
cached_reader.rs

/*
 * Copyright (c) Microsoft Corporation.
 * Licensed under the MIT license.
 */
5use std::io::{Read, Seek};
6
7use diskann::{ANNError, ANNResult};
8use diskann_providers::storage::StorageReadProvider;
9use tracing::info;
10
11/// Sequential cached reads with a generic storage provider with read access.
12pub struct CachedReader<Storage>
13where
14    Storage: StorageReadProvider,
15{
16    /// File reader
17    reader: Storage::Reader,
18
19    /// # bytes to cache in one shot read
20    cache_size: u64,
21
22    /// Underlying buf for cache
23    cache_buf: Vec<u8>,
24
25    /// Offset into cache_buf for cur_pos
26    cur_off: u64,
27
28    /// File size
29    size: u64,
30}
31
32impl<Storage> CachedReader<Storage>
33where
34    Storage: StorageReadProvider,
35{
36    pub fn new(
37        filename: &str,
38        cache_size: u64,
39        storage_provider: &Storage,
40    ) -> std::io::Result<Self> {
41        info!("Opening: {}", filename);
42        let mut reader = storage_provider.open_reader(filename)?;
43        let size = storage_provider.get_length(filename)?;
44
45        let cache_size = cache_size.min(size);
46        let mut cache_buf = vec![0; cache_size as usize];
47        reader.read_exact(&mut cache_buf)?;
48        info!(
49            "Opened: {}, size: {}, cache_size: {}",
50            filename, size, cache_size
51        );
52
53        Ok(Self {
54            reader,
55            cache_size,
56            cache_buf,
57            cur_off: 0,
58            size,
59        })
60    }
61
62    pub fn get_file_size(&self) -> u64 {
63        self.size
64    }
65
66    pub fn read(&mut self, read_buf: &mut [u8]) -> ANNResult<()> {
67        let n_bytes = read_buf.len() as u64;
68        if n_bytes <= (self.cache_size - self.cur_off) {
69            // case 1: cache contains all data
70            read_buf.copy_from_slice(
71                &self.cache_buf
72                    [(self.cur_off as usize)..(self.cur_off as usize + n_bytes as usize)],
73            );
74            self.cur_off += n_bytes;
75        } else {
76            // case 2: cache contains some data
77            let cached_bytes = self.cache_size - self.cur_off;
78            if n_bytes - cached_bytes > self.size - self.reader.stream_position()? {
79                return Err(ANNError::log_index_error(format_args!(
80                    "Reading beyond end of file, n_bytes: {} cached_bytes: {} fsize: {} current pos: {}",
81                    n_bytes,
82                    cached_bytes,
83                    self.size,
84                    self.reader.stream_position()?
85                )));
86            }
87
88            read_buf[..cached_bytes as usize]
89                .copy_from_slice(&self.cache_buf[self.cur_off as usize..]);
90            // go to disk and fetch more data
91            self.reader
92                .read_exact(&mut read_buf[cached_bytes as usize..])?;
93            // reset cur off
94            self.cur_off = self.cache_size;
95
96            let size_left = self.size - self.reader.stream_position()?;
97            if size_left >= self.cache_size {
98                self.reader.read_exact(&mut self.cache_buf)?;
99                self.cur_off = 0;
100            }
101            // note that if size_left < cache_size, then cur_off = cache_size,
102            // so subsequent reads will all be directly from file
103        }
104        Ok(())
105    }
106
107    pub fn read_u32(&mut self) -> ANNResult<u32> {
108        let mut bytes = [0u8; 4];
109        self.read(&mut bytes)?;
110        Ok(u32::from_le_bytes(bytes))
111    }
112}
113
#[cfg(test)]
mod cached_reader_test {

    use diskann_providers::storage::{StorageWriteProvider, VirtualStorageProvider};
    use vfs::MemoryFS;

    use super::*;

    type TestStorage = VirtualStorageProvider<MemoryFS>;

    /// File contents shared by every test: 72 arbitrary bytes. The tests only
    /// compare them verbatim, so no header or vector layout is implied.
    const DATA: [u8; 72] = [
        2, 0, 1, 2, 8, 0, 1, 3, 0x00, 0x01, 0x80, 0x3f, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00,
        0x40, 0x40, 0x00, 0x00, 0x80, 0x40, 0x00, 0x00, 0xa0, 0x40, 0x00, 0x00, 0xc0, 0x40,
        0x00, 0x00, 0xe0, 0x40, 0x00, 0x00, 0x00, 0x41, 0x00, 0x00, 0x10, 0x41, 0x00, 0x00,
        0x20, 0x41, 0x00, 0x00, 0x30, 0x41, 0x00, 0x00, 0x40, 0x41, 0x00, 0x00, 0x50, 0x41,
        0x00, 0x00, 0x60, 0x41, 0x00, 0x00, 0x70, 0x41, 0x00, 0x11, 0x80, 0x41,
    ];

    /// Writes [`DATA`] to `file_name` on a fresh in-memory filesystem and opens a
    /// [`CachedReader`] over it with the requested `cache_size`.
    ///
    /// The storage provider is returned as well so the caller can delete the file.
    fn create_reader(
        file_name: &str,
        cache_size: u64,
    ) -> (TestStorage, CachedReader<TestStorage>) {
        let filesystem = MemoryFS::new();
        let storage_provider = VirtualStorageProvider::new(filesystem);
        {
            let mut writer = storage_provider.create_for_write(file_name).unwrap();
            writer.write_all(&DATA).unwrap();
        }
        let reader =
            CachedReader::<TestStorage>::new(file_name, cache_size, &storage_provider).unwrap();
        (storage_provider, reader)
    }

    #[test]
    fn cached_reader_works() {
        let file_name = "/cached_reader_works_test2.bin";
        let (storage_provider, mut reader) = create_reader(file_name, 8);
        assert_eq!(reader.get_file_size(), 72);
        assert_eq!(reader.cache_size, 8);

        // Served entirely from the cache.
        let mut all_from_cache_buf = [0u8; 4];
        reader.read(&mut all_from_cache_buf).unwrap();
        assert_eq!(all_from_cache_buf, [2, 0, 1, 2]);
        assert_eq!(reader.cur_off, 4);

        // Drains the last 4 cached bytes, reads 2 from storage, then refills.
        let mut partial_from_cache_buf = [0u8; 6];
        reader.read(&mut partial_from_cache_buf).unwrap();
        assert_eq!(partial_from_cache_buf, [8, 0, 1, 3, 0x00, 0x01]);
        assert_eq!(reader.cur_off, 0);

        // Larger than the cache; afterwards fewer than cache_size bytes remain,
        // so the cache is left exhausted instead of being refilled.
        let mut over_cache_size_buf = [0u8; 60];
        reader.read(&mut over_cache_size_buf).unwrap();
        assert_eq!(&over_cache_size_buf[..], &DATA[10..70]);
        assert_eq!(reader.cur_off, reader.cache_size);

        // The tail is read straight from storage.
        let mut remaining_less_than_cache_size_buf = [0u8; 2];
        reader
            .read(&mut remaining_less_than_cache_size_buf)
            .unwrap();
        assert_eq!(remaining_less_than_cache_size_buf, [0x80, 0x41]);
        assert_eq!(reader.cur_off, reader.cache_size);

        storage_provider
            .delete(file_name)
            .expect("Failed to delete file");
    }

    #[test]
    fn read_u32_decodes_little_endian_across_cache_refills() {
        let file_name = "/read_u32_decodes_little_endian_test.bin";
        let (storage_provider, mut reader) = create_reader(file_name, 8);

        assert_eq!(reader.read_u32().unwrap(), 0x0201_0002);
        assert_eq!(reader.read_u32().unwrap(), 0x0301_0008);
        // The cache is exhausted here: this value comes from storage and the
        // cache is refilled afterwards.
        assert_eq!(reader.read_u32().unwrap(), 0x3f80_0100);
        assert_eq!(reader.cur_off, 0);

        storage_provider
            .delete(file_name)
            .expect("Failed to delete file");
    }

    #[test]
    fn cache_is_clamped_to_file_size() {
        let file_name = "/cache_is_clamped_to_file_size_test.bin";
        let (storage_provider, mut reader) = create_reader(file_name, 1024);
        assert_eq!(reader.cache_size, 72);

        let mut whole_file = [0u8; 72];
        reader.read(&mut whole_file).unwrap();
        assert_eq!(whole_file, DATA);
        assert_eq!(reader.cur_off, reader.cache_size);

        // Nothing is left in the cache or in the file.
        assert!(reader.read(&mut [0u8; 1]).is_err());

        storage_provider
            .delete(file_name)
            .expect("Failed to delete file");
    }

    #[test]
    #[should_panic(expected = "Reading beyond end of file")]
    fn failed_for_reading_beyond_end_of_file() {
        let file_name = "/failed_for_reading_beyond_end_of_file_test_2.bin";
        let (storage_provider, mut reader) = create_reader(file_name, 8);
        storage_provider
            .delete(file_name)
            .expect("Failed to delete file");

        let mut over_size_buf = [0u8; 73];
        reader.read(&mut over_size_buf).unwrap();
    }
}