use crate::storage::StorageBackend;
use anyhow::Result;
use std::collections::{HashMap, VecDeque};
use std::sync::{Arc, RwLock};
/// A single cached page.
struct CacheEntry {
    // Shared so `get_or_load` can hand out the bytes without copying.
    data: Arc<Vec<u8>>,
    // True when the page was modified via `put_dirty` and has not yet
    // been written back by `flush`.
    dirty: bool,
}
/// Mutable cache state, guarded by the `RwLock` inside `PageCache`.
struct CacheInner {
    // Resident pages keyed by page id.
    entries: HashMap<u64, CacheEntry>,
    // LRU bookkeeping: front = least recently used, back = most recent.
    order: VecDeque<u64>,
    // Maximum number of resident pages before eviction kicks in.
    // NOTE(review): a capacity of 0 disables eviction entirely (see the
    // `capacity > 0` guards below) — confirm that is intended.
    capacity: usize,
}
impl CacheInner {
    /// Promotes `page_id` to most-recently-used (the back of `order`).
    ///
    /// No-op when the id is already the most recent one; an id not present
    /// in `order` is simply appended. Linear scan, O(len(order)).
    fn touch(&mut self, page_id: u64) {
        // Already the freshest entry: nothing to move.
        if self.order.back() == Some(&page_id) {
            return;
        }
        // Pull the id out of its current slot (if any)…
        if let Some(pos) = self.order.iter().position(|&id| id == page_id) {
            self.order.remove(pos);
        }
        // …and re-append it as most recently used.
        self.order.push_back(page_id);
    }
}
/// A thread-safe page cache with LRU eviction and write-back support.
///
/// All state lives behind a single `RwLock`; methods take `&self` so the
/// cache can be shared across threads (e.g. via `Arc`).
pub struct PageCache {
    inner: RwLock<CacheInner>,
}
impl PageCache {
pub fn new(capacity: usize) -> Self {
PageCache {
inner: RwLock::new(CacheInner {
entries: HashMap::new(),
order: VecDeque::new(),
capacity,
}),
}
}
pub fn get_or_load(&self, page_id: u64, backend: &dyn StorageBackend) -> Result<Arc<Vec<u8>>> {
{
let inner = self.inner.read().expect("lock poisoned");
if let Some(entry) = inner.entries.get(&page_id) {
return Ok(entry.data.clone());
}
}
let data = Arc::new(backend.read_page(page_id)?);
let mut inner = self.inner.write().expect("lock poisoned");
if let Some(entry) = inner.entries.get(&page_id) {
return Ok(entry.data.clone());
}
while inner.entries.len() >= inner.capacity && inner.capacity > 0 {
if let Some(id) = inner.order.pop_front() {
inner.entries.remove(&id);
} else {
break; }
}
inner.entries.insert(
page_id,
CacheEntry {
data: data.clone(),
dirty: false,
},
);
inner.order.push_back(page_id);
Ok(data)
}
pub fn put_dirty(&self, page_id: u64, data: Vec<u8>) {
let mut inner = self.inner.write().expect("lock poisoned");
let data = Arc::new(data);
if inner.entries.contains_key(&page_id) {
inner.entries.get_mut(&page_id).unwrap().data = data;
inner.entries.get_mut(&page_id).unwrap().dirty = true;
inner.touch(page_id);
} else {
while inner.entries.len() >= inner.capacity && inner.capacity > 0 {
if let Some(id) = inner.order.pop_front() {
inner.entries.remove(&id);
} else {
break; }
}
inner
.entries
.insert(page_id, CacheEntry { data, dirty: true });
inner.order.push_back(page_id);
}
}
#[allow(dead_code)]
pub fn flush(&self, backend: &mut dyn StorageBackend) -> Result<()> {
let mut inner = self.inner.write().expect("lock poisoned");
for (&page_id, entry) in inner.entries.iter_mut() {
if entry.dirty {
backend.write_page(page_id, &entry.data[..])?;
entry.dirty = false;
}
}
Ok(())
}
#[allow(dead_code)]
pub fn invalidate(&self, page_id: u64) {
let mut inner = self.inner.write().expect("lock poisoned");
inner.entries.remove(&page_id);
inner.order.retain(|&id| id != page_id);
}
pub fn invalidate_from(&self, from_page: u64) {
let mut inner = self.inner.write().expect("lock poisoned");
inner.entries.retain(|&id, _| id < from_page);
inner.order.retain(|&id| id < from_page);
}
#[allow(dead_code)]
pub fn cached_page_count(&self) -> usize {
self.inner.read().expect("lock poisoned").entries.len()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::PAGE_SIZE;
    use crate::storage::backend::MemoryBackend;

    /// Builds a full page filled with a single repeated byte.
    fn make_page(byte: u8) -> Vec<u8> {
        vec![byte; PAGE_SIZE]
    }

    #[test]
    fn test_cache_miss_loads_from_backend() {
        let mut backend = MemoryBackend::new();
        backend.write_page(1, &make_page(0xAB)).unwrap();
        let cache = PageCache::new(4);
        let page = cache.get_or_load(1, &backend).unwrap();
        assert_eq!(page[0], 0xAB);
    }

    #[test]
    fn test_cache_hit_returns_same_bytes() {
        let mut backend = MemoryBackend::new();
        backend.write_page(1, &make_page(0x11)).unwrap();
        let cache = PageCache::new(4);
        let p1 = cache.get_or_load(1, &backend).unwrap();
        let p2 = cache.get_or_load(1, &backend).unwrap();
        assert_eq!(p1[0], p2[0]);
    }

    #[test]
    fn test_lru_eviction_respects_capacity() {
        let mut backend = MemoryBackend::new();
        for i in 1u64..=5 {
            backend.write_page(i, &make_page(i as u8)).unwrap();
        }
        let cache = PageCache::new(3);
        for i in 1u64..=4 {
            cache.get_or_load(i, &backend).unwrap();
        }
        assert!(cache.cached_page_count() <= 3);
    }

    #[test]
    fn test_dirty_page_written_back_on_flush() {
        let mut backend = MemoryBackend::new();
        backend.write_page(1, &make_page(0x00)).unwrap();
        let cache = PageCache::new(4);
        cache.put_dirty(1, make_page(0xFF));
        cache.flush(&mut backend).unwrap();
        let page = backend.read_page(1).unwrap();
        assert_eq!(page[0], 0xFF);
    }

    #[test]
    #[cfg(not(target_os = "wasi"))]
    fn test_concurrent_reads() {
        use std::sync::Arc;
        use std::thread;
        let mut backend = MemoryBackend::new();
        backend.write_page(1, &make_page(0x42)).unwrap();
        let cache = Arc::new(PageCache::new(8));
        let handles: Vec<_> = (0..4)
            .map(|_| {
                let c = cache.clone();
                let b = backend.clone();
                thread::spawn(move || {
                    let page = c.get_or_load(1, &b).unwrap();
                    assert_eq!(page[0], 0x42);
                })
            })
            .collect();
        for h in handles {
            h.join().unwrap();
        }
    }

    #[test]
    fn test_lru_eviction_evicts_correct_page() {
        let mut backend = MemoryBackend::new();
        for i in 1u64..=4 {
            backend.write_page(i, &make_page(i as u8)).unwrap();
        }
        let cache = PageCache::new(3);
        for i in 1u64..=4 {
            cache.get_or_load(i, &backend).unwrap();
        }
        // Previously this test only checked the count (twice); it never
        // verified WHICH page was evicted. Probe by mutating the backend
        // copies: a cached page returns stale bytes, an evicted page is
        // reloaded and returns the fresh bytes.
        assert_eq!(cache.cached_page_count(), 3);
        backend.write_page(1, &make_page(0x99)).unwrap();
        let p1 = cache.get_or_load(1, &backend).unwrap();
        assert_eq!(p1[0], 0x99, "page 1 (LRU) should have been evicted and reloaded");
        backend.write_page(3, &make_page(0x77)).unwrap();
        let p3 = cache.get_or_load(3, &backend).unwrap();
        assert_eq!(p3[0], 3, "page 3 should still be served from the cache");
    }

    #[test]
    fn test_put_dirty_after_eviction_does_not_panic() {
        let mut backend = MemoryBackend::new();
        for i in 1u64..=3 {
            backend.write_page(i, &make_page(i as u8)).unwrap();
        }
        let cache = PageCache::new(2);
        for i in 1u64..=3 {
            cache.get_or_load(i, &backend).unwrap();
        }
        cache.put_dirty(2, make_page(0xBB));
        assert_eq!(cache.cached_page_count(), 2);
    }
}