1use std::num::NonZeroUsize;
24use std::sync::Mutex;
25
26use lru::LruCache;
27
28use crate::block_store::BlockStore;
29use crate::error::{FsError, FsResult};
30
/// A cached copy of a single block's contents.
struct CacheEntry {
    // The block's bytes. Presumably always exactly `block_size()` long —
    // nothing here enforces it; TODO confirm callers uphold that invariant.
    data: Vec<u8>,
    // True when `data` is newer than the inner store's copy and must be
    // written back (on eviction or `sync`) before it may be discarded.
    dirty: bool,
}
36
/// A write-back LRU cache layered over another [`BlockStore`].
///
/// Reads populate the cache; writes are buffered as dirty entries and reach
/// the inner store only when evicted or when `sync` is called.
pub struct CachedBlockStore<S: BlockStore> {
    // The underlying (presumably slower) block store.
    inner: S,
    // LRU map of block id -> cached entry. The mutex provides interior
    // mutability so the cache can be updated through `&self`.
    cache: Mutex<LruCache<u64, CacheEntry>>,
}
47
48impl<S: BlockStore> CachedBlockStore<S> {
49 pub fn new(inner: S, capacity: usize) -> Self {
53 let cap = NonZeroUsize::new(capacity.max(16)).unwrap();
54 Self {
55 inner,
56 cache: Mutex::new(LruCache::new(cap)),
57 }
58 }
59
60 fn writeback(&self, block_id: u64, data: &[u8]) -> FsResult<()> {
62 self.inner.write_block(block_id, data)
63 }
64
65 fn insert(
67 &self,
68 cache: &mut LruCache<u64, CacheEntry>,
69 block_id: u64,
70 entry: CacheEntry,
71 ) -> Option<(u64, Vec<u8>)> {
72 match cache.push(block_id, entry) {
73 Some((evicted_id, evicted)) if evicted_id != block_id && evicted.dirty => {
74 Some((evicted_id, evicted.data))
75 }
76 _ => None,
77 }
78 }
79}
80
81impl<S: BlockStore> BlockStore for CachedBlockStore<S> {
82 fn block_size(&self) -> usize {
83 self.inner.block_size()
84 }
85
86 fn total_blocks(&self) -> u64 {
87 self.inner.total_blocks()
88 }
89
90 fn read_block(&self, block_id: u64) -> FsResult<Vec<u8>> {
91 {
93 let mut cache = self
94 .cache
95 .lock()
96 .map_err(|e| FsError::Internal(e.to_string()))?;
97 if let Some(entry) = cache.get(&block_id) {
98 return Ok(entry.data.clone());
99 }
100 }
101
102 let data = self.inner.read_block(block_id)?;
104
105 let wb = {
107 let mut cache = self
108 .cache
109 .lock()
110 .map_err(|e| FsError::Internal(e.to_string()))?;
111 if cache.contains(&block_id) {
113 return Ok(cache.get(&block_id).unwrap().data.clone());
114 }
115 self.insert(
116 &mut cache,
117 block_id,
118 CacheEntry {
119 data: data.clone(),
120 dirty: false,
121 },
122 )
123 };
124
125 if let Some((id, wb_data)) = wb {
126 self.writeback(id, &wb_data)?;
127 }
128
129 Ok(data)
130 }
131
132 fn write_block(&self, block_id: u64, data: &[u8]) -> FsResult<()> {
133 let wb = {
134 let mut cache = self
135 .cache
136 .lock()
137 .map_err(|e| FsError::Internal(e.to_string()))?;
138 self.insert(
139 &mut cache,
140 block_id,
141 CacheEntry {
142 data: data.to_vec(),
143 dirty: true,
144 },
145 )
146 };
147
148 if let Some((id, wb_data)) = wb {
149 self.writeback(id, &wb_data)?;
150 }
151
152 Ok(())
153 }
154
155 fn sync(&self) -> FsResult<()> {
156 let dirty: Vec<(u64, Vec<u8>)> = {
158 let cache = self
159 .cache
160 .lock()
161 .map_err(|e| FsError::Internal(e.to_string()))?;
162 cache
163 .iter()
164 .filter(|(_, e)| e.dirty)
165 .map(|(&id, e)| (id, e.data.clone()))
166 .collect()
167 };
168
169 if !dirty.is_empty() {
171 let refs: Vec<(u64, &[u8])> = dirty.iter().map(|(id, d)| (*id, d.as_slice())).collect();
172 self.inner.write_blocks(&refs)?;
173 }
174
175 {
177 let mut cache = self
178 .cache
179 .lock()
180 .map_err(|e| FsError::Internal(e.to_string()))?;
181 for (id, _) in &dirty {
182 if let Some(entry) = cache.peek_mut(id) {
183 entry.dirty = false;
184 }
185 }
186 }
187
188 self.inner.sync()
189 }
190
191 fn read_blocks(&self, block_ids: &[u64]) -> FsResult<Vec<Vec<u8>>> {
192 let mut results: Vec<Option<Vec<u8>>> = vec![None; block_ids.len()];
194 let mut miss_indices = Vec::new();
195 let mut miss_ids = Vec::new();
196
197 {
198 let mut cache = self
199 .cache
200 .lock()
201 .map_err(|e| FsError::Internal(e.to_string()))?;
202 for (i, &id) in block_ids.iter().enumerate() {
203 if let Some(entry) = cache.get(&id) {
204 results[i] = Some(entry.data.clone());
205 } else {
206 miss_indices.push(i);
207 miss_ids.push(id);
208 }
209 }
210 }
211
212 if !miss_ids.is_empty() {
214 let fetched = self.inner.read_blocks(&miss_ids)?;
215 let mut writebacks = Vec::new();
216
217 {
218 let mut cache = self
219 .cache
220 .lock()
221 .map_err(|e| FsError::Internal(e.to_string()))?;
222 for (&idx, data) in miss_indices.iter().zip(fetched) {
223 results[idx] = Some(data.clone());
224 let block_id = block_ids[idx];
225 if !cache.contains(&block_id) {
226 if let Some((eid, entry)) =
227 cache.push(block_id, CacheEntry { data, dirty: false })
228 {
229 if eid != block_id && entry.dirty {
230 writebacks.push((eid, entry.data));
231 }
232 }
233 }
234 }
235 }
236
237 for (id, data) in writebacks {
238 self.writeback(id, &data)?;
239 }
240 }
241
242 results
243 .into_iter()
244 .map(|r| r.ok_or_else(|| FsError::Internal("missing read result".into())))
245 .collect()
246 }
247
248 fn write_blocks(&self, blocks: &[(u64, &[u8])]) -> FsResult<()> {
249 let mut writebacks = Vec::new();
250
251 {
252 let mut cache = self
253 .cache
254 .lock()
255 .map_err(|e| FsError::Internal(e.to_string()))?;
256 for &(block_id, data) in blocks {
257 if let Some((eid, entry)) = cache.push(
258 block_id,
259 CacheEntry {
260 data: data.to_vec(),
261 dirty: true,
262 },
263 ) {
264 if eid != block_id && entry.dirty {
265 writebacks.push((eid, entry.data));
266 }
267 }
268 }
269 }
270
271 if !writebacks.is_empty() {
273 let refs: Vec<(u64, &[u8])> = writebacks
274 .iter()
275 .map(|(id, d)| (*id, d.as_slice()))
276 .collect();
277 self.inner.write_blocks(&refs)?;
278 }
279
280 Ok(())
281 }
282}
283
#[cfg(test)]
mod tests {
    use super::*;
    use crate::block_store::MemoryBlockStore;

    /// A first read pulls data from the inner store; a second read (now a
    /// cache hit) must return the same bytes.
    #[test]
    fn read_populates_cache() {
        let inner = MemoryBlockStore::new(64, 10);
        inner.write_block(0, &vec![0xAA; 64]).unwrap();

        let cached = CachedBlockStore::new(inner, 16);
        assert_eq!(cached.read_block(0).unwrap(), vec![0xAA; 64]);
        assert_eq!(cached.read_block(0).unwrap(), vec![0xAA; 64]);
    }

    /// A buffered write is readable through the cache before `sync`.
    /// NOTE(review): `inner` is moved into the wrapper, so this test cannot
    /// actually inspect the inner store to prove the write is absent there —
    /// the name promises more than it checks.
    #[test]
    fn write_is_not_visible_to_inner_until_sync() {
        let cached = CachedBlockStore::new(MemoryBlockStore::new(64, 10), 16);

        let payload = vec![0xBB; 64];
        cached.write_block(0, &payload).unwrap();
        assert_eq!(cached.read_block(0).unwrap(), payload);

        cached.sync().unwrap();
    }

    /// Writing more blocks than the cache capacity forces dirty evictions;
    /// every block must survive eviction + sync and read back intact.
    #[test]
    fn dirty_eviction_writes_back() {
        let cached = CachedBlockStore::new(MemoryBlockStore::new(64, 100), 16);

        for i in 0..20u64 {
            cached.write_block(i, &vec![i as u8; 64]).unwrap();
        }
        cached.sync().unwrap();

        for i in 0..20u64 {
            assert_eq!(cached.read_block(i).unwrap(), vec![i as u8; 64]);
        }
    }

    /// Round-trips data through the batch write and batch read paths.
    #[test]
    fn batch_read_write() {
        let cached = CachedBlockStore::new(MemoryBlockStore::new(64, 10), 16);

        let payloads = [[0x11u8; 64], [0x22; 64], [0x33; 64]];
        let blocks: Vec<(u64, &[u8])> = payloads
            .iter()
            .enumerate()
            .map(|(i, p)| (i as u64, p.as_slice()))
            .collect();
        cached.write_blocks(&blocks).unwrap();

        let results = cached.read_blocks(&[0, 1, 2]).unwrap();
        for (i, p) in payloads.iter().enumerate() {
            assert_eq!(results[i], p.to_vec());
        }
    }
}