use std::collections::HashMap;
use std::fs::{File, OpenOptions};
use std::io::{Read, Seek, SeekFrom, Write};
use std::path::Path;
use crate::error::StoreError;
use crate::page::{Page, PageType};
use crate::types::{PAGE_SIZE, PageId};
/// Number of pages kept in memory when `PageCache::open` is called with `capacity: None`.
const DEFAULT_CACHE_CAPACITY: usize = 4 * 1024;
/// One entry of the intrusive doubly-linked LRU list, stored in `PageCache::lru` keyed by
/// page id. `prev` points toward the head (more recently used); `next` points toward the
/// tail (less recently used, evicted first).
struct LruNode {
    prev: Option<PageId>,
    next: Option<PageId>,
    // Duplicates the map key and is never read; kept for debugging, hence the allowance.
    #[allow(dead_code)]
    page_id: PageId,
}
/// Write-back cache of fixed-size pages stored in a single file, with LRU eviction.
///
/// Resident pages live in `pages`. Their recency order is kept as a doubly-linked list whose
/// nodes are stored in `lru`: the head is the most recently used page and the tail is the
/// next eviction victim. Dirty pages reach the file only on eviction or an explicit sync.
pub struct PageCache {
    /// Backing file; each page is read from and written to `id.byte_offset()`.
    file: File,
    /// Id the next allocation will receive.
    next_page_id: PageId,
    /// Maximum number of resident pages before an insert evicts the LRU tail.
    capacity: usize,
    /// Resident pages keyed by id.
    pages: HashMap<PageId, Page>,
    /// LRU list nodes keyed by page id; kept in step with the key set of `pages`.
    lru: HashMap<PageId, LruNode>,
    /// Most recently used page, if any.
    lru_head: Option<PageId>,
    /// Least recently used page, if any; the first to be evicted.
    lru_tail: Option<PageId>,
}
impl PageCache {
    /// Opens the page file at `path`, creating it if it does not exist, and wraps it in an
    /// empty cache.
    ///
    /// `capacity` bounds how many pages are held in memory at once. `None` selects
    /// [`DEFAULT_CACHE_CAPACITY`]. The next id to allocate is `file_len / PAGE_SIZE`, so any
    /// trailing bytes that do not form a whole page are not counted as a page.
    ///
    /// # Errors
    ///
    /// Returns an error if the file cannot be opened or its metadata cannot be read.
    pub fn open(path: &Path, capacity: Option<usize>) -> Result<Self, StoreError> {
        let file = OpenOptions::new()
            .read(true)
            .write(true)
            .create(true)
            .truncate(false)
            .open(path)?;
        // Integer division already yields 0 for an empty file, so no special case is needed.
        let next_page_id = PageId::new(file.metadata()?.len() / PAGE_SIZE as u64);
        Ok(Self {
            pages: HashMap::new(),
            lru: HashMap::new(),
            lru_head: None,
            lru_tail: None,
            capacity: capacity.unwrap_or(DEFAULT_CACHE_CAPACITY),
            file,
            next_page_id,
        })
    }

    /// Returns the id that the next call to [`allocate`](Self::allocate) will hand out.
    pub fn next_page_id(&self) -> PageId {
        self.next_page_id
    }

    /// Allocates a fresh page of `page_type` at the end of the file and caches it.
    ///
    /// The page exists only in memory until it is written back by eviction,
    /// [`sync`](Self::sync) or [`flush`](Self::flush). All three write only pages whose
    /// `is_dirty()` is true.
    /// NOTE(review): if `Page::new` does not start out dirty, an untouched allocated page is
    /// dropped on eviction without being written, and a later `get` of it will read past the
    /// end of the file. TODO: confirm against `Page::new`.
    ///
    /// # Errors
    ///
    /// Returns an error if making room requires writing back a dirty page and that write
    /// fails. In that case no id is consumed.
    pub fn allocate(&mut self, page_type: PageType) -> Result<PageId, StoreError> {
        let page_id = self.next_page_id;
        self.insert_page(Page::new(page_id, page_type))?;
        // Advance only after the page is actually cached. A failed eviction must not leave
        // behind an id that has no backing page in memory or on disk.
        self.next_page_id = page_id.next();
        Ok(page_id)
    }

    /// Returns page `page_id`, loading it from the file on a cache miss.
    ///
    /// Returns `Ok(None)` for ids at or beyond [`next_page_id`](Self::next_page_id). The page
    /// becomes the most recently used one.
    ///
    /// # Errors
    ///
    /// Returns an error if reading or decoding the page fails, or if evicting a dirty page to
    /// make room fails.
    pub fn get(&mut self, page_id: PageId) -> Result<Option<&Page>, StoreError> {
        self.ensure_cached(page_id)?;
        Ok(self.pages.get(&page_id))
    }

    /// Mutable counterpart of [`get`](Self::get), with the same loading, recency and error
    /// behaviour.
    pub fn get_mut(&mut self, page_id: PageId) -> Result<Option<&mut Page>, StoreError> {
        self.ensure_cached(page_id)?;
        Ok(self.pages.get_mut(&page_id))
    }

    /// Makes `page_id` resident and most recently used, loading it from the file if needed.
    ///
    /// Does nothing for ids at or beyond `next_page_id`, which have no backing page.
    fn ensure_cached(&mut self, page_id: PageId) -> Result<(), StoreError> {
        if self.pages.contains_key(&page_id) {
            self.touch(page_id);
        } else if page_id.as_u64() < self.next_page_id.as_u64() {
            let page = self.load_page(page_id)?;
            // `insert_page` links the page at the LRU head, so no extra touch is required.
            self.insert_page(page)?;
        }
        Ok(())
    }

    /// Reads page `page_id` from the file and decodes it with `Page::from_bytes`.
    fn load_page(&mut self, page_id: PageId) -> Result<Page, StoreError> {
        let buf = self.read_raw(page_id)?;
        Page::from_bytes(page_id, &buf)
    }

    /// Reads the `PAGE_SIZE` bytes stored on disk for `page_id`, bypassing the cache.
    ///
    /// The cache is not consulted. If a cached copy of the page is dirty, the returned bytes
    /// do not include its unsynced changes.
    ///
    /// # Errors
    ///
    /// Returns an error if seeking fails or the file ends before a full page is read.
    pub fn read_raw(&mut self, page_id: PageId) -> Result<[u8; PAGE_SIZE], StoreError> {
        let mut buf = [0u8; PAGE_SIZE];
        self.file.seek(SeekFrom::Start(page_id.byte_offset()))?;
        self.file.read_exact(&mut buf)?;
        Ok(buf)
    }

    /// Caches `page` at the LRU head, evicting the LRU tail first if the cache is full.
    ///
    /// Eviction happens before insertion, so a capacity of 0 behaves like a capacity of 1.
    /// If the eviction fails, `page` is dropped and the cache is left unchanged.
    fn insert_page(&mut self, page: Page) -> Result<(), StoreError> {
        let page_id = page.id;
        // Re-inserting a resident page would link its LRU node twice and corrupt the list.
        debug_assert!(!self.pages.contains_key(&page_id), "page already cached");
        if self.pages.len() >= self.capacity {
            self.evict_one()?;
        }
        self.pages.insert(page_id, page);
        self.add_to_lru(page_id);
        Ok(())
    }

    /// Moves `page_id` to the LRU head (most recently used).
    fn touch(&mut self, page_id: PageId) {
        if self.lru_head == Some(page_id) {
            return;
        }
        self.remove_from_lru(page_id);
        self.add_to_lru(page_id);
    }

    /// Links `page_id` at the LRU head. The caller must ensure it is not already linked.
    fn add_to_lru(&mut self, page_id: PageId) {
        let node = LruNode {
            page_id,
            prev: None,
            next: self.lru_head,
        };
        if let Some(old_head) = self.lru_head {
            if let Some(head_node) = self.lru.get_mut(&old_head) {
                head_node.prev = Some(page_id);
            }
        }
        self.lru.insert(page_id, node);
        self.lru_head = Some(page_id);
        // The first node in an empty list is both head and tail.
        if self.lru_tail.is_none() {
            self.lru_tail = Some(page_id);
        }
    }

    /// Unlinks `page_id` from the LRU list, repairing its neighbours and the head/tail
    /// pointers. Does nothing if the page is not linked.
    fn remove_from_lru(&mut self, page_id: PageId) {
        let Some(node) = self.lru.remove(&page_id) else {
            return;
        };
        if let Some(prev_id) = node.prev {
            if let Some(prev_node) = self.lru.get_mut(&prev_id) {
                prev_node.next = node.next;
            }
        } else {
            self.lru_head = node.next;
        }
        if let Some(next_id) = node.next {
            if let Some(next_node) = self.lru.get_mut(&next_id) {
                next_node.prev = node.prev;
            }
        } else {
            self.lru_tail = node.prev;
        }
    }

    /// Evicts the least recently used page, writing it back first if it is dirty.
    ///
    /// Does nothing if the cache is empty. If the write-back fails, the page stays cached and
    /// the error is returned.
    fn evict_one(&mut self) -> Result<(), StoreError> {
        let Some(page_id) = self.lru_tail else {
            return Ok(());
        };
        if let Some(page) = self.pages.get_mut(&page_id) {
            if page.is_dirty() {
                self.file.seek(SeekFrom::Start(page.id.byte_offset()))?;
                self.file.write_all(page.as_bytes())?;
            }
        }
        self.pages.remove(&page_id);
        self.remove_from_lru(page_id);
        Ok(())
    }

    /// Writes every dirty cached page to the file, marks each one clean, then calls
    /// `File::sync_all`.
    ///
    /// # Errors
    ///
    /// Stops at the first failed seek, write or sync. Pages written before the failure are
    /// already marked clean.
    pub fn sync(&mut self) -> Result<(), StoreError> {
        for page in self.pages.values_mut() {
            if page.is_dirty() {
                let page_offset = page.id.byte_offset();
                self.file.seek(SeekFrom::Start(page_offset))?;
                self.file.write_all(page.as_bytes())?;
                page.mark_clean();
            }
        }
        self.file.sync_all()?;
        Ok(())
    }

    /// Number of pages currently resident in memory.
    // Used by tests only in this module; allowance keeps non-test builds warning-free.
    #[allow(dead_code)]
    pub fn cached_count(&self) -> usize {
        self.pages.len()
    }

    /// Number of resident pages with unsynced modifications.
    // Only the `Debug` impl calls this in this module; allowance kept for other builds.
    #[allow(dead_code)]
    pub fn dirty_count(&self) -> usize {
        self.pages.values().filter(|p| p.is_dirty()).count()
    }

    /// Syncs all dirty pages to disk, then drops every cached page.
    ///
    /// # Errors
    ///
    /// Returns the [`sync`](Self::sync) error, in which case the cache is left intact.
    // Not called from this module; kept as public API.
    #[allow(dead_code)]
    pub fn flush(&mut self) -> Result<(), StoreError> {
        self.sync()?;
        self.pages.clear();
        self.lru.clear();
        self.lru_head = None;
        self.lru_tail = None;
        Ok(())
    }

    /// Loads up to `count` consecutive pages starting at `start` into the cache.
    ///
    /// Ids at or beyond `next_page_id` are skipped. Pages that are already resident are left
    /// untouched, and their recency is not updated. If `count` exceeds the capacity, later
    /// loads evict earlier ones.
    ///
    /// # Errors
    ///
    /// Returns the first load or eviction error. Pages loaded before the error stay cached.
    // Not called from this module; kept as public API.
    #[allow(dead_code)]
    pub fn prefetch(&mut self, start: PageId, count: usize) -> Result<(), StoreError> {
        let first = start.as_u64();
        // Clamp up front to allocated ids. This avoids probing every id in `0..count` and
        // cannot overflow.
        let end = first
            .saturating_add(count as u64)
            .min(self.next_page_id.as_u64());
        for raw_id in first..end {
            let page_id = PageId::new(raw_id);
            if !self.pages.contains_key(&page_id) {
                let page = self.load_page(page_id)?;
                self.insert_page(page)?;
            }
        }
        Ok(())
    }
}
impl std::fmt::Debug for PageCache {
    /// Summarises the cache with counts and ids only. Page contents and the file handle are
    /// omitted, which `finish_non_exhaustive` signals with `..`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut summary = f.debug_struct("PageCache");
        summary
            .field("cached", &self.pages.len())
            .field("capacity", &self.capacity)
            .field("next_page_id", &self.next_page_id)
            .field("dirty", &self.dirty_count());
        summary.finish_non_exhaustive()
    }
}
#[cfg(test)]
mod cache_tests {
    use super::*;
    use tempfile::tempdir;

    /// A freshly allocated page is cached and readable with its type and no items.
    #[test]
    fn test_allocate_and_get() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("test.db");
        let mut cache = PageCache::open(&path, Some(10)).unwrap();
        let page_id = cache.allocate(PageType::Leaf).unwrap();
        assert_eq!(page_id, PageId::new(0));
        let page = cache.get(page_id).unwrap().unwrap();
        assert_eq!(page.page_type(), PageType::Leaf);
        assert_eq!(page.item_count(), 0);
    }

    /// Ids that were never allocated are reported as absent, not read from disk.
    #[test]
    fn test_get_beyond_end_returns_none() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("test.db");
        let mut cache = PageCache::open(&path, Some(10)).unwrap();
        assert!(cache.get(PageId::new(5)).unwrap().is_none());
        assert!(cache.get_mut(PageId::new(0)).unwrap().is_none());
    }

    /// Synced data survives reopening, and reopening derives `next_page_id` from file size.
    #[test]
    fn test_sync_and_reload() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("test.db");
        {
            let mut cache = PageCache::open(&path, Some(10)).unwrap();
            let page_id = cache.allocate(PageType::Leaf).unwrap();
            let page = cache.get_mut(page_id).unwrap().unwrap();
            page.insert_item(0, b"test data").unwrap();
            cache.sync().unwrap();
        }
        {
            let mut cache = PageCache::open(&path, Some(10)).unwrap();
            assert_eq!(cache.next_page_id(), PageId::new(1));
            let page = cache.get(PageId::new(0)).unwrap().unwrap();
            assert_eq!(page.get_item(0), b"test data");
        }
    }

    /// The cache never holds more than `capacity` pages.
    #[test]
    fn test_lru_eviction() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("test.db");
        let mut cache = PageCache::open(&path, Some(3)).unwrap();
        let _p0 = cache.allocate(PageType::Leaf).unwrap();
        let _p1 = cache.allocate(PageType::Leaf).unwrap();
        let _p2 = cache.allocate(PageType::Leaf).unwrap();
        assert_eq!(cache.cached_count(), 3);
        let _p3 = cache.allocate(PageType::Leaf).unwrap();
        assert_eq!(cache.cached_count(), 3);
    }

    /// Eviction picks the least recently used page, and a `get` refreshes recency.
    #[test]
    fn test_lru_evicts_least_recently_used() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("test.db");
        let mut cache = PageCache::open(&path, Some(2)).unwrap();
        let p0 = cache.allocate(PageType::Leaf).unwrap();
        let p1 = cache.allocate(PageType::Leaf).unwrap();
        // Touch p0 so that p1 becomes the eviction victim.
        assert!(cache.get(p0).unwrap().is_some());
        let p2 = cache.allocate(PageType::Leaf).unwrap();
        assert!(cache.pages.contains_key(&p0));
        assert!(!cache.pages.contains_key(&p1));
        assert!(cache.pages.contains_key(&p2));
    }

    /// A dirty page is written back on eviction and can be reloaded from the file.
    #[test]
    fn test_evicted_dirty_page_is_written_back() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("test.db");
        let mut cache = PageCache::open(&path, Some(2)).unwrap();
        let p0 = cache.allocate(PageType::Leaf).unwrap();
        cache
            .get_mut(p0)
            .unwrap()
            .unwrap()
            .insert_item(0, b"evicted")
            .unwrap();
        let _p1 = cache.allocate(PageType::Leaf).unwrap();
        let _p2 = cache.allocate(PageType::Leaf).unwrap();
        assert!(!cache.pages.contains_key(&p0));
        let page = cache.get(p0).unwrap().unwrap();
        assert_eq!(page.get_item(0), b"evicted");
    }
}