Skip to main content

diskann_disk/storage/
cached_writer.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5use std::io::{Seek, SeekFrom, Write};
6
7use diskann_providers::storage::StorageWriteProvider;
8use tracing::info;
9
10/// Sequential cached writes with a generic storage provider with write access.
11pub struct CachedWriter<Storage>
12where
13    Storage: StorageWriteProvider,
14{
15    /// File writer
16    writer: Storage::Writer,
17
18    /// # bytes to cache for one shot write
19    cache_size: u64,
20
21    /// Underlying buf for cache
22    cache_buf: Vec<u8>,
23
24    /// Offset into cache_buf for cur_pos
25    cur_off: u64,
26
27    /// File size
28    fsize: u64,
29}
30
31impl<Storage> CachedWriter<Storage>
32where
33    Storage: StorageWriteProvider,
34{
35    pub fn new(filename: &str, cache_size: u64, writer: Storage::Writer) -> std::io::Result<Self> {
36        if cache_size == 0 {
37            return Err(std::io::Error::other("Cache size must be greater than 0"));
38        }
39
40        info!("Opened: {}, cache_size: {}", filename, cache_size);
41        Ok(Self {
42            writer,
43            cache_size,
44            cache_buf: vec![0; cache_size as usize],
45            cur_off: 0,
46            fsize: 0,
47        })
48    }
49
50    pub fn flush(&mut self) -> std::io::Result<()> {
51        // dump any remaining data in memory
52        if self.cur_off > 0 {
53            self.flush_cache()?;
54        }
55
56        self.writer.flush()?;
57        info!("Finished writing {}B", self.fsize);
58        Ok(())
59    }
60
61    pub fn get_file_size(&self) -> u64 {
62        self.fsize
63    }
64
65    /// Writes n_bytes from write_buf to the underlying cache
66    pub fn write(&mut self, write_buf: &[u8]) -> std::io::Result<()> {
67        let n_bytes = write_buf.len() as u64;
68        if n_bytes <= (self.cache_size - self.cur_off) {
69            // case 1: cache can take all data
70            self.cache_buf[(self.cur_off as usize)..((self.cur_off + n_bytes) as usize)]
71                .copy_from_slice(&write_buf[..n_bytes as usize]);
72            self.cur_off += n_bytes;
73        } else {
74            // case 2: cache cant take all data
75            // go to disk and write existing cache data
76            self.writer
77                .write_all(&self.cache_buf[..self.cur_off as usize])?;
78            self.fsize += self.cur_off;
79            // write the new data to disk
80            self.writer.write_all(write_buf)?;
81            self.fsize += n_bytes;
82            // clear cache data and reset cur_off
83            self.cache_buf.fill(0);
84            self.cur_off = 0;
85        }
86        Ok(())
87    }
88
89    pub fn reset(&mut self) -> std::io::Result<()> {
90        self.flush_cache()?;
91        self.writer.seek(SeekFrom::Start(0))?;
92        Ok(())
93    }
94
95    fn flush_cache(&mut self) -> std::io::Result<()> {
96        self.writer
97            .write_all(&self.cache_buf[..self.cur_off as usize])?;
98        self.fsize += self.cur_off;
99        self.cache_buf.fill(0);
100        self.cur_off = 0;
101        Ok(())
102    }
103}
104
105impl<Storage> Drop for CachedWriter<Storage>
106where
107    Storage: StorageWriteProvider,
108{
109    fn drop(&mut self) {
110        // Do not panic if errors are encountered in the destructor.
111        let _: std::io::Result<()> = self.flush();
112    }
113}
114
#[cfg(test)]
mod cached_writer_test {
    use diskann_providers::storage::VirtualStorageProvider;
    use vfs::OverlayFS;

    use super::*;

    /// Exercises the three write paths: a write that fits in the cache, a write that
    /// overflows it and is passed through, and an explicit flush of buffered bytes.
    #[test]
    fn cached_writer_works() {
        let file_name = "/cached_writer_works_test.bin";
        let storage_provider = VirtualStorageProvider::new_overlay(".");

        // Arbitrary payload; this test only writes the first 13 bytes of it.
        let data: [u8; 72] = [
            2, 0, 1, 2, 8, 0, 1, 3, 0x00, 0x01, 0x80, 0x3f, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00,
            0x40, 0x40, 0x00, 0x00, 0x80, 0x40, 0x00, 0x00, 0xa0, 0x40, 0x00, 0x00, 0xc0, 0x40,
            0x00, 0x00, 0xe0, 0x40, 0x00, 0x00, 0x00, 0x41, 0x00, 0x00, 0x10, 0x41, 0x00, 0x00,
            0x20, 0x41, 0x00, 0x00, 0x30, 0x41, 0x00, 0x00, 0x40, 0x41, 0x00, 0x00, 0x50, 0x41,
            0x00, 0x00, 0x60, 0x41, 0x00, 0x00, 0x70, 0x41, 0x00, 0x11, 0x80, 0x41,
        ];

        let inner_writer = storage_provider.create_for_write(file_name).unwrap();
        let mut cached_writer =
            CachedWriter::<VirtualStorageProvider<OverlayFS>>::new(file_name, 8, inner_writer)
                .unwrap();
        assert_eq!(cached_writer.cache_size, 8);
        assert_eq!(cached_writer.cur_off, 0);
        assert_eq!(cached_writer.get_file_size(), 0);

        // 4 bytes fit into the 8-byte cache: nothing reaches the writer yet.
        let cache_all_buf = &data[0..4];
        cached_writer.write(cache_all_buf).unwrap();
        assert_eq!(&cached_writer.cache_buf[..4], cache_all_buf);
        assert_eq!(&cached_writer.cache_buf[4..], vec![0; 4]);
        assert_eq!(cached_writer.cur_off, 4);
        assert_eq!(cached_writer.get_file_size(), 0);

        // 6 more bytes exceed the 4 free bytes: the cache is emptied and the new bytes
        // are passed straight through, 10 bytes in total.
        let write_all_buf = &data[4..10];
        cached_writer.write(write_all_buf).unwrap();
        assert_eq!(cached_writer.cache_buf, vec![0; 8]);
        assert_eq!(cached_writer.cur_off, 0);
        assert_eq!(cached_writer.get_file_size(), 10);

        // Bytes that are only buffered are counted once an explicit flush writes them out.
        let tail_buf = &data[10..13];
        cached_writer.write(tail_buf).unwrap();
        assert_eq!(&cached_writer.cache_buf[..3], tail_buf);
        assert_eq!(cached_writer.cur_off, 3);
        assert_eq!(cached_writer.get_file_size(), 10);
        cached_writer.flush().unwrap();
        assert_eq!(cached_writer.cache_buf, vec![0; 8]);
        assert_eq!(cached_writer.cur_off, 0);
        assert_eq!(cached_writer.get_file_size(), 13);

        storage_provider
            .delete(file_name)
            .expect("Failed to delete file");
    }

    /// A cache size of zero is rejected by the constructor.
    #[test]
    fn zero_cache_size_is_rejected() {
        let file_name = "/cached_writer_zero_cache_test.bin";
        let storage_provider = VirtualStorageProvider::new_overlay(".");

        let inner_writer = storage_provider.create_for_write(file_name).unwrap();
        let result =
            CachedWriter::<VirtualStorageProvider<OverlayFS>>::new(file_name, 0, inner_writer);
        assert!(result.is_err());

        storage_provider
            .delete(file_name)
            .expect("Failed to delete file");
    }
}