1use std::io::{Seek, SeekFrom, Write};
6
7use diskann_providers::storage::StorageWriteProvider;
8use tracing::info;
9
10pub struct CachedWriter<Storage>
12where
13 Storage: StorageWriteProvider,
14{
15 writer: Storage::Writer,
17
18 cache_size: u64,
20
21 cache_buf: Vec<u8>,
23
24 cur_off: u64,
26
27 fsize: u64,
29}
30
31impl<Storage> CachedWriter<Storage>
32where
33 Storage: StorageWriteProvider,
34{
35 pub fn new(filename: &str, cache_size: u64, writer: Storage::Writer) -> std::io::Result<Self> {
36 if cache_size == 0 {
37 return Err(std::io::Error::other("Cache size must be greater than 0"));
38 }
39
40 info!("Opened: {}, cache_size: {}", filename, cache_size);
41 Ok(Self {
42 writer,
43 cache_size,
44 cache_buf: vec![0; cache_size as usize],
45 cur_off: 0,
46 fsize: 0,
47 })
48 }
49
50 pub fn flush(&mut self) -> std::io::Result<()> {
51 if self.cur_off > 0 {
53 self.flush_cache()?;
54 }
55
56 self.writer.flush()?;
57 info!("Finished writing {}B", self.fsize);
58 Ok(())
59 }
60
61 pub fn get_file_size(&self) -> u64 {
62 self.fsize
63 }
64
65 pub fn write(&mut self, write_buf: &[u8]) -> std::io::Result<()> {
67 let n_bytes = write_buf.len() as u64;
68 if n_bytes <= (self.cache_size - self.cur_off) {
69 self.cache_buf[(self.cur_off as usize)..((self.cur_off + n_bytes) as usize)]
71 .copy_from_slice(&write_buf[..n_bytes as usize]);
72 self.cur_off += n_bytes;
73 } else {
74 self.writer
77 .write_all(&self.cache_buf[..self.cur_off as usize])?;
78 self.fsize += self.cur_off;
79 self.writer.write_all(write_buf)?;
81 self.fsize += n_bytes;
82 self.cache_buf.fill(0);
84 self.cur_off = 0;
85 }
86 Ok(())
87 }
88
89 pub fn reset(&mut self) -> std::io::Result<()> {
90 self.flush_cache()?;
91 self.writer.seek(SeekFrom::Start(0))?;
92 Ok(())
93 }
94
95 fn flush_cache(&mut self) -> std::io::Result<()> {
96 self.writer
97 .write_all(&self.cache_buf[..self.cur_off as usize])?;
98 self.fsize += self.cur_off;
99 self.cache_buf.fill(0);
100 self.cur_off = 0;
101 Ok(())
102 }
103}
104
105impl<Storage> Drop for CachedWriter<Storage>
106where
107 Storage: StorageWriteProvider,
108{
109 fn drop(&mut self) {
110 let _: std::io::Result<()> = self.flush();
112 }
113}
114
#[cfg(test)]
mod cached_writer_test {
    use diskann_providers::storage::VirtualStorageProvider;
    use vfs::OverlayFS;

    use super::*;

    /// Writes that fit are only cached; a write that overflows the remaining cache space
    /// drains the cache and goes straight to the underlying writer.
    #[test]
    fn cached_writer_works() {
        let file_name = "/cached_writer_works_test.bin";
        let storage_provider = VirtualStorageProvider::new_overlay(".");

        let data: [u8; 72] = [
            2, 0, 1, 2, 8, 0, 1, 3, 0x00, 0x01, 0x80, 0x3f, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00,
            0x40, 0x40, 0x00, 0x00, 0x80, 0x40, 0x00, 0x00, 0xa0, 0x40, 0x00, 0x00, 0xc0, 0x40,
            0x00, 0x00, 0xe0, 0x40, 0x00, 0x00, 0x00, 0x41, 0x00, 0x00, 0x10, 0x41, 0x00, 0x00,
            0x20, 0x41, 0x00, 0x00, 0x30, 0x41, 0x00, 0x00, 0x40, 0x41, 0x00, 0x00, 0x50, 0x41,
            0x00, 0x00, 0x60, 0x41, 0x00, 0x00, 0x70, 0x41, 0x00, 0x11, 0x80, 0x41,
        ];

        let inner_writer = storage_provider.create_for_write(file_name).unwrap();
        let mut cached_writer =
            CachedWriter::<VirtualStorageProvider<OverlayFS>>::new(file_name, 8, inner_writer)
                .unwrap();
        assert_eq!(cached_writer.cache_size, 8);
        assert_eq!(cached_writer.get_file_size(), 0);

        // 4 bytes fit in the 8-byte cache: nothing reaches the underlying writer yet.
        let cache_all_buf = &data[0..4];
        cached_writer.write(cache_all_buf).unwrap();
        assert_eq!(&cached_writer.cache_buf[..4], cache_all_buf);
        assert_eq!(&cached_writer.cache_buf[4..], vec![0; 4]);
        assert_eq!(cached_writer.cur_off, 4);
        assert_eq!(cached_writer.get_file_size(), 0);

        // 6 more bytes do not fit in the remaining 4: the cache is drained and the new bytes
        // are written through, for 10 bytes total.
        let write_all_buf = &data[4..10];
        cached_writer.write(write_all_buf).unwrap();
        assert_eq!(cached_writer.cache_buf, vec![0; 8]);
        assert_eq!(cached_writer.cur_off, 0);
        assert_eq!(cached_writer.get_file_size(), 10);

        storage_provider
            .delete(file_name)
            .expect("Failed to delete file");
    }

    /// `flush` pushes pending cached bytes to the underlying writer and counts them in
    /// `get_file_size`.
    #[test]
    fn cached_writer_flush_writes_cached_bytes() {
        let file_name = "/cached_writer_flush_test.bin";
        let storage_provider = VirtualStorageProvider::new_overlay(".");

        let inner_writer = storage_provider.create_for_write(file_name).unwrap();
        let mut cached_writer =
            CachedWriter::<VirtualStorageProvider<OverlayFS>>::new(file_name, 8, inner_writer)
                .unwrap();

        cached_writer.write(&[1, 2, 3]).unwrap();
        assert_eq!(cached_writer.cur_off, 3);
        assert_eq!(cached_writer.get_file_size(), 0);

        cached_writer.flush().unwrap();
        assert_eq!(cached_writer.cur_off, 0);
        assert_eq!(cached_writer.cache_buf, vec![0; 8]);
        assert_eq!(cached_writer.get_file_size(), 3);

        drop(cached_writer);
        storage_provider
            .delete(file_name)
            .expect("Failed to delete file");
    }

    /// A zero-sized cache is rejected by the constructor.
    #[test]
    fn cached_writer_rejects_zero_cache_size() {
        let file_name = "/cached_writer_zero_cache_test.bin";
        let storage_provider = VirtualStorageProvider::new_overlay(".");

        let inner_writer = storage_provider.create_for_write(file_name).unwrap();
        let result =
            CachedWriter::<VirtualStorageProvider<OverlayFS>>::new(file_name, 0, inner_writer);
        assert!(result.is_err());

        storage_provider
            .delete(file_name)
            .expect("Failed to delete file");
    }
}