redb/tree_store/page_store/cached_file.rs

use crate::tree_store::page_store::base::PageHint;
use crate::tree_store::page_store::lru_cache::LRUCache;
use crate::{CacheStats, DatabaseError, Result, StorageBackend, StorageError};
use std::ops::{Index, IndexMut};
use std::slice::SliceIndex;
#[cfg(feature = "cache_metrics")]
use std::sync::atomic::AtomicU64;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::{Arc, Mutex, RwLock};

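// A page checked out of the write buffer for mutation. While the page is
// checked out, its slot in the buffer holds None; dropping the page returns
// the (possibly modified) data to the buffer so it can be flushed later.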
pub(super) struct WritablePage {
    buffer: Arc<Mutex<LRUWriteCache>>,
    offset: u64,
    data: Arc<[u8]>,
}

impl WritablePage {
    pub(super) fn mem(&self) -> &[u8] {
        &self.data
    }

    pub(super) fn mem_mut(&mut self) -> &mut [u8] {
        // The buffer slot holds None while this page is checked out, so this
        // should be the only strong reference and get_mut() cannot fail
        Arc::get_mut(&mut self.data).unwrap()
    }
}

impl Drop for WritablePage {
    fn drop(&mut self) {
        self.buffer
            .lock()
            .unwrap()
            .return_value(self.offset, self.data.clone());
    }
}

impl<I: SliceIndex<[u8]>> Index<I> for WritablePage {
    type Output = I::Output;

    fn index(&self, index: I) -> &Self::Output {
        self.mem().index(index)
    }
}

impl<I: SliceIndex<[u8]>> IndexMut<I> for WritablePage {
    fn index_mut(&mut self, index: I) -> &mut Self::Output {
        self.mem_mut().index_mut(index)
    }
}

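// Write buffer keyed by page offset. Values are Option so that a page can be
// checked out for mutation: take_value() leaves None in the slot and
// return_value() puts the data back. Checked-out (None) entries are never
// evicted by pop_lowest_priority().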
#[derive(Default)]
struct LRUWriteCache {
    cache: LRUCache<Option<Arc<[u8]>>>,
}

impl LRUWriteCache {
    fn new() -> Self {
        Self {
            cache: Default::default(),
        }
    }

    fn insert(&mut self, key: u64, value: Arc<[u8]>) {
        assert!(self.cache.insert(key, Some(value)).is_none());
    }

    fn get(&self, key: u64) -> Option<&Arc<[u8]>> {
        self.cache.get(key).map(|x| x.as_ref().unwrap())
    }

    fn remove(&mut self, key: u64) -> Option<Arc<[u8]>> {
        if let Some(value) = self.cache.remove(key) {
            assert!(value.is_some());
            return value;
        }
        None
    }

    fn return_value(&mut self, key: u64, value: Arc<[u8]>) {
        assert!(self.cache.get_mut(key).unwrap().replace(value).is_none());
    }

    fn take_value(&mut self, key: u64) -> Option<Arc<[u8]>> {
        if let Some(value) = self.cache.get_mut(key) {
            let result = value.take().unwrap();
            return Some(result);
        }
        None
    }

    fn pop_lowest_priority(&mut self) -> Option<(u64, Arc<[u8]>)> {
        for _ in 0..self.cache.len() {
            if let Some((k, v)) = self.cache.pop_lowest_priority() {
                if let Some(v_inner) = v {
                    return Some((k, v_inner));
                }

                // Value is borrowed by take_value(). We can't evict it, so put it back.
                self.cache.insert(k, v);
            } else {
                break;
            }
        }
        None
    }

    fn clear(&mut self) {
        self.cache.clear();
    }
}

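// Wraps a StorageBackend and latches the first I/O failure: once any call
// fails (or the database is closed), every later operation returns
// StorageError::PreviousIo or StorageError::DatabaseClosed instead of
// touching the backend again.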
#[derive(Debug)]
struct CheckedBackend {
    file: Box<dyn StorageBackend>,
    io_failed: AtomicBool,
    closed: AtomicBool,
}

impl CheckedBackend {
    fn new(file: Box<dyn StorageBackend>) -> Self {
        Self {
            file,
            io_failed: AtomicBool::new(false),
            closed: AtomicBool::new(false),
        }
    }

    fn check_failure(&self) -> Result<()> {
        if self.io_failed.load(Ordering::Acquire) {
            if self.closed.load(Ordering::Acquire) {
                Err(StorageError::DatabaseClosed)
            } else {
                Err(StorageError::PreviousIo)
            }
        } else {
            Ok(())
        }
    }

    fn close(&self) -> Result {
        self.closed.store(true, Ordering::Release);
        // Also set io_failed, so that check_failure() rejects all future operations
        self.io_failed.store(true, Ordering::Release);
        self.file.close()?;

        Ok(())
    }

    fn len(&self) -> Result<u64> {
        self.check_failure()?;
        let result = self.file.len();
        if result.is_err() {
            self.io_failed.store(true, Ordering::Release);
        }
        result.map_err(StorageError::from)
    }

    fn read(&self, offset: u64, out: &mut [u8]) -> Result<()> {
        self.check_failure()?;
        let result = self.file.read(offset, out);
        if result.is_err() {
            self.io_failed.store(true, Ordering::Release);
        }
        result.map_err(StorageError::from)
    }

    fn set_len(&self, len: u64) -> Result<()> {
        self.check_failure()?;
        let result = self.file.set_len(len);
        if result.is_err() {
            self.io_failed.store(true, Ordering::Release);
        }
        result.map_err(StorageError::from)
    }

    fn sync_data(&self) -> Result<()> {
        self.check_failure()?;
        let result = self.file.sync_data();
        if result.is_err() {
            self.io_failed.store(true, Ordering::Release);
        }
        result.map_err(StorageError::from)
    }

    fn write(&self, offset: u64, data: &[u8]) -> Result<()> {
        self.check_failure()?;
        let result = self.file.write(offset, data);
        if result.is_err() {
            self.io_failed.store(true, Ordering::Release);
        }
        result.map_err(StorageError::from)
    }
}

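// Page-granular cache over a StorageBackend: a lock-striped LRU read cache
// plus an LRU write buffer of pending writes. All offsets must be
// page-aligned.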
pub(super) struct PagedCachedFile {
    file: CheckedBackend,
    page_size: u64,
    max_read_cache_bytes: usize,
    read_cache_bytes: AtomicUsize,
    max_write_buffer_bytes: usize,
    write_buffer_bytes: AtomicUsize,
    #[cfg(feature = "cache_metrics")]
    reads_total: AtomicU64,
    #[cfg(feature = "cache_metrics")]
    reads_hits: AtomicU64,
    #[cfg(feature = "cache_metrics")]
    evictions: AtomicU64,
    read_cache: Vec<RwLock<LRUCache<Arc<[u8]>>>>,
    // TODO: maybe move this cache to WriteTransaction?
    write_buffer: Arc<Mutex<LRUWriteCache>>,
}

impl PagedCachedFile {
    pub(super) fn new(
        file: Box<dyn StorageBackend>,
        page_size: u64,
        max_read_cache_bytes: usize,
        max_write_buffer_bytes: usize,
    ) -> Result<Self, DatabaseError> {
        let read_cache = (0..Self::lock_stripes())
            .map(|_| RwLock::new(LRUCache::new()))
            .collect();

        Ok(Self {
            file: CheckedBackend::new(file),
            page_size,
            max_read_cache_bytes,
            read_cache_bytes: AtomicUsize::new(0),
            max_write_buffer_bytes,
            write_buffer_bytes: AtomicUsize::new(0),
            #[cfg(feature = "cache_metrics")]
            reads_total: Default::default(),
            #[cfg(feature = "cache_metrics")]
            reads_hits: Default::default(),
            #[cfg(feature = "cache_metrics")]
            evictions: Default::default(),
            read_cache,
            write_buffer: Arc::new(Mutex::new(LRUWriteCache::new())),
        })
    }

    #[allow(clippy::unused_self)]
    pub(crate) fn cache_stats(&self) -> CacheStats {
        CacheStats {
            #[cfg(not(feature = "cache_metrics"))]
            evictions: 0,
            #[cfg(feature = "cache_metrics")]
            evictions: self.evictions.load(Ordering::Acquire),
        }
    }

    pub(crate) fn close(&self) -> Result {
        self.file.close()
    }

    pub(crate) fn check_io_errors(&self) -> Result {
        self.file.check_failure()
    }

    pub(crate) fn raw_file_len(&self) -> Result<u64> {
        self.file.len()
    }

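    // Number of lock stripes in the read cache. 131 is prime, so page-aligned
    // offsets (multiples of a power-of-two page size) are spread across every
    // stripe rather than clustering on a few of them.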
    const fn lock_stripes() -> u64 {
        131
    }

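    // Two passes: first write every pending buffer to the backend, then
    // migrate the buffers into the read cache (while its budget allows) so
    // freshly written pages stay warm for readers.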
    fn flush_write_buffer(&self) -> Result {
        let mut write_buffer = self.write_buffer.lock().unwrap();

        for (offset, buffer) in write_buffer.cache.iter() {
            self.file.write(*offset, buffer.as_ref().unwrap())?;
        }
        for (offset, buffer) in write_buffer.cache.iter_mut() {
            let buffer = buffer.take().unwrap();
            let cache_size = self
                .read_cache_bytes
                .fetch_add(buffer.len(), Ordering::AcqRel);

            if cache_size + buffer.len() <= self.max_read_cache_bytes {
                let cache_slot: usize = (offset % Self::lock_stripes()).try_into().unwrap();
                let mut lock = self.read_cache[cache_slot].write().unwrap();
                if let Some(replaced) = lock.insert(*offset, buffer) {
                    // A race could cause us to replace an existing buffer
                    self.read_cache_bytes
                        .fetch_sub(replaced.len(), Ordering::AcqRel);
                }
            } else {
                self.read_cache_bytes
                    .fetch_sub(buffer.len(), Ordering::AcqRel);
                break;
            }
        }
        self.write_buffer_bytes.store(0, Ordering::Release);
        write_buffer.clear();

        Ok(())
    }

    // Caller should invalidate all cached pages that are no longer valid
    pub(super) fn resize(&self, len: u64) -> Result {
        // TODO: be more fine-grained about this invalidation
        self.invalidate_cache_all();

        self.file.set_len(len)
    }

    pub(super) fn flush(&self) -> Result {
        self.flush_write_buffer()?;

        self.file.sync_data()
    }

    // Makes writes visible to readers, but does not guarantee any durability
    pub(super) fn write_barrier(&self) -> Result {
        // TODO: non-durable commits would be much faster if this did not issue writes to disk,
        // and instead just made the data visible to readers
        self.flush_write_buffer()
    }

    // Read directly from the file, ignoring any cached data
    pub(super) fn read_direct(&self, offset: u64, len: usize) -> Result<Vec<u8>> {
        let mut buffer = vec![0; len];
        self.file.read(offset, &mut buffer)?;
        Ok(buffer)
    }

    // Read with caching. Caller must not read overlapping ranges without first calling invalidate_cache().
    // Doing so will not cause UB, but is a logic error.
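    // Lookup order: the write buffer (skipped for PageHint::Clean), then the
    // striped read cache, then a direct read from the backend, whose result is
    // cached (evicting lowest-priority entries if over budget).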
    pub(super) fn read(&self, offset: u64, len: usize, hint: PageHint) -> Result<Arc<[u8]>> {
        debug_assert_eq!(0, offset % self.page_size);
        #[cfg(feature = "cache_metrics")]
        self.reads_total.fetch_add(1, Ordering::AcqRel);

        if !matches!(hint, PageHint::Clean) {
            let lock = self.write_buffer.lock().unwrap();
            if let Some(cached) = lock.get(offset) {
                #[cfg(feature = "cache_metrics")]
                self.reads_hits.fetch_add(1, Ordering::Release);
                debug_assert_eq!(cached.len(), len);
                return Ok(cached.clone());
            }
        }

        let cache_slot: usize = (offset % Self::lock_stripes()).try_into().unwrap();
        {
            let read_lock = self.read_cache[cache_slot].read().unwrap();
            if let Some(cached) = read_lock.get(offset) {
                #[cfg(feature = "cache_metrics")]
                self.reads_hits.fetch_add(1, Ordering::Release);
                debug_assert_eq!(cached.len(), len);
                return Ok(cached.clone());
            }
        }

        let buffer: Arc<[u8]> = self.read_direct(offset, len)?.into();
        let cache_size = self.read_cache_bytes.fetch_add(len, Ordering::AcqRel);
        let mut write_lock = self.read_cache[cache_slot].write().unwrap();
        let cache_size = if let Some(replaced) = write_lock.insert(offset, buffer.clone()) {
            // A race could cause us to replace an existing buffer
            self.read_cache_bytes
                .fetch_sub(replaced.len(), Ordering::AcqRel)
        } else {
            cache_size
        };
        // Over budget: evict lowest-priority entries from this stripe until the
        // new buffer's size has been reclaimed
        let mut removed = 0;
        if cache_size + len > self.max_read_cache_bytes {
            while removed < len {
                if let Some((_, v)) = write_lock.pop_lowest_priority() {
                    #[cfg(feature = "cache_metrics")]
                    {
                        self.evictions.fetch_add(1, Ordering::Relaxed);
                    }
                    removed += v.len();
                } else {
                    break;
                }
            }
        }
        if removed > 0 {
            self.read_cache_bytes.fetch_sub(removed, Ordering::AcqRel);
        }

        Ok(buffer)
    }

    // Discard pending writes to the given range
    pub(super) fn cancel_pending_write(&self, offset: u64, _len: usize) {
        assert_eq!(0, offset % self.page_size);
        if let Some(removed) = self.write_buffer.lock().unwrap().remove(offset) {
            self.write_buffer_bytes
                .fetch_sub(removed.len(), Ordering::Release);
        }
    }

    // Invalidate any caching of the given range. After this call, overlapping reads of the range are allowed
    //
    // NOTE: Invalidating a cached region in subsections is permitted, as long as all subsections are invalidated
    pub(super) fn invalidate_cache(&self, offset: u64, len: usize) {
        let cache_slot: usize = (offset % Self::lock_stripes()).try_into().unwrap();
        let mut lock = self.read_cache[cache_slot].write().unwrap();
        if let Some(removed) = lock.remove(offset) {
            assert_eq!(len, removed.len());
            self.read_cache_bytes
                .fetch_sub(removed.len(), Ordering::AcqRel);
        }
    }

    pub(super) fn invalidate_cache_all(&self) {
        for cache_slot in 0..self.read_cache.len() {
            let mut lock = self.read_cache[cache_slot].write().unwrap();
            while let Some((_, removed)) = lock.pop_lowest_priority() {
                self.read_cache_bytes
                    .fetch_sub(removed.len(), Ordering::AcqRel);
            }
        }
    }

    // If overwrite is true, the page is initialized to zero instead of being read from the file,
    // since the caller intends to overwrite it entirely
    pub(super) fn write(&self, offset: u64, len: usize, overwrite: bool) -> Result<WritablePage> {
        assert_eq!(0, offset % self.page_size);
        let mut lock = self.write_buffer.lock().unwrap();

        // TODO: allow hint that page is known to be dirty and will not be in the read cache
        // Remove the page from the read cache, so readers can't observe data that is
        // about to be mutated; the removed buffer is reused as the initial contents
        let cache_slot: usize = (offset % Self::lock_stripes()).try_into().unwrap();
        let existing = {
            let mut lock = self.read_cache[cache_slot].write().unwrap();
            if let Some(removed) = lock.remove(offset) {
                assert_eq!(
                    len,
                    removed.len(),
                    "cache inconsistency {len} != {} for offset {offset}",
                    removed.len()
                );
                self.read_cache_bytes
                    .fetch_sub(removed.len(), Ordering::AcqRel);
                Some(removed)
            } else {
                None
            }
        };

        let data = if let Some(removed) = lock.take_value(offset) {
            removed
        } else {
            let previous = self.write_buffer_bytes.fetch_add(len, Ordering::AcqRel);
            if previous + len > self.max_write_buffer_bytes {
                // Over budget: flush lowest-priority pending writes to the backend
                // until enough space has been reclaimed
                let mut removed_bytes = 0;
                while removed_bytes < len {
                    if let Some((offset, buffer)) = lock.pop_lowest_priority() {
                        let removed_len = buffer.len();
                        let result = self.file.write(offset, &buffer);
                        if result.is_err() {
                            lock.insert(offset, buffer);
                        }
                        result?;
                        self.write_buffer_bytes
                            .fetch_sub(removed_len, Ordering::Release);
                        #[cfg(feature = "cache_metrics")]
                        {
                            self.evictions.fetch_add(1, Ordering::Relaxed);
                        }
                        removed_bytes += removed_len;
                    } else {
                        break;
                    }
                }
            }
            let result = if let Some(data) = existing {
                data
            } else if overwrite {
                vec![0; len].into()
            } else {
                self.read_direct(offset, len)?.into()
            };
            lock.insert(offset, result);
            lock.take_value(offset).unwrap()
        };
        Ok(WritablePage {
            buffer: self.write_buffer.clone(),
            offset,
            data,
        })
    }
}

#[cfg(test)]
mod test {
    use crate::StorageBackend;
    use crate::backends::InMemoryBackend;
    use crate::tree_store::PageHint;
    use crate::tree_store::page_store::cached_file::PagedCachedFile;
    use std::sync::Arc;
    use std::sync::atomic::Ordering;

    // Concurrent read() and invalidate_cache() of the same page must not leak
    // read_cache_bytes accounting
    #[test]
    fn cache_leak() {
        let backend = InMemoryBackend::new();
        backend.set_len(1024).unwrap();
        let cached_file = PagedCachedFile::new(Box::new(backend), 128, 1024, 128).unwrap();
        let cached_file = Arc::new(cached_file);

        let t1 = {
            let cached_file = cached_file.clone();
            std::thread::spawn(move || {
                for _ in 0..1000 {
                    cached_file.read(0, 128, PageHint::None).unwrap();
                    cached_file.invalidate_cache(0, 128);
                }
            })
        };
        let t2 = {
            let cached_file = cached_file.clone();
            std::thread::spawn(move || {
                for _ in 0..1000 {
                    cached_file.read(0, 128, PageHint::None).unwrap();
                    cached_file.invalidate_cache(0, 128);
                }
            })
        };

        t1.join().unwrap();
        t2.join().unwrap();
        cached_file.invalidate_cache(0, 128);
        assert_eq!(cached_file.read_cache_bytes.load(Ordering::Acquire), 0);
    }
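
    // A minimal round-trip sketch: data written through write() becomes
    // visible to read() as soon as the WritablePage is dropped, and reaches
    // the backend once flush() writes the buffer through
    #[test]
    fn write_read_roundtrip() {
        let backend = InMemoryBackend::new();
        backend.set_len(1024).unwrap();
        let cached_file = PagedCachedFile::new(Box::new(backend), 128, 1024, 1024).unwrap();

        {
            // overwrite=true: the page starts zeroed rather than being read from disk
            let mut page = cached_file.write(0, 128, true).unwrap();
            page.mem_mut().fill(0xab);
        }
        // PageHint::None consults the write buffer, so the pending write is observed
        let data = cached_file.read(0, 128, PageHint::None).unwrap();
        assert_eq!(&data[..], &[0xab_u8; 128][..]);

        // flush() writes the buffer through to the backend and syncs it
        cached_file.flush().unwrap();
        assert_eq!(&cached_file.read_direct(0, 128).unwrap()[..], &[0xab_u8; 128][..]);
    }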
}