use ahash::AHashMap;
use parking_lot::Mutex;
use crate::page::{Page, PageId, PageManager};
use mentedb_core::error::{MenteError, MenteResult};
use tracing::{debug, trace};
/// Index of a frame inside `BufferPoolInner::frames`.
type FrameId = usize;
/// One slot of the buffer pool: a page buffer plus the bookkeeping used for
/// pinning, write-back and clock replacement.
struct Frame {
/// Heap-allocated page buffer owned by this frame.
page: Box<Page>,
/// Identity of the page currently held; `None` means the frame is free
/// (`find_victim` hands out such frames first).
page_id: Option<PageId>,
/// Number of outstanding pins; `find_victim` never selects a frame whose
/// count is non-zero.
pin_count: u32,
/// Set when the cached copy has been modified and not yet written back.
dirty: bool,
/// Clock reference bit: set when the page is fetched, cleared by the
/// replacement sweep to give the frame a second chance.
reference: bool,
}
impl Frame {
fn new() -> Self {
Self {
page: Box::new(Page::zeroed()),
page_id: None,
pin_count: 0,
dirty: false,
reference: false,
}
}
}
/// Mutable state of the buffer pool, kept behind the mutex in `BufferPool`.
struct BufferPoolInner {
/// All frames of the pool; `BufferPool::new` creates `capacity` of them.
frames: Vec<Frame>,
/// Maps each cached page to the index of the frame that holds it.
page_table: AHashMap<PageId, FrameId>,
/// Index of the next frame the replacement sweep in `find_victim` examines.
clock_hand: usize,
/// Number of frames, as passed to `BufferPool::new`.
capacity: usize,
}
/// Fixed-capacity cache of pages with pin counts, dirty tracking and
/// clock-style replacement.
///
/// Every method takes `&self` and locks the single internal mutex for its
/// whole duration. Pages are handed out as clones (`Box<Page>`), so callers
/// modify their own copy and push changes back with `update_page`.
pub struct BufferPool {
/// All pool state, guarded by one `parking_lot::Mutex`.
inner: Mutex<BufferPoolInner>,
}
impl BufferPool {
pub fn new(capacity: usize) -> Self {
assert!(capacity > 0, "buffer pool capacity must be > 0");
let frames = (0..capacity).map(|_| Frame::new()).collect();
Self {
inner: Mutex::new(BufferPoolInner {
frames,
page_table: AHashMap::with_capacity(capacity),
clock_hand: 0,
capacity,
}),
}
}
/// Returns a copy of page `page_id`, loading it into the pool if it is not
/// already cached.
///
/// On a hit the frame's pin count is incremented and its reference bit is
/// set. On a miss a frame is chosen with `find_victim`; if that frame holds a
/// dirty page it is written back through `pm` first, then the requested page
/// is read from `pm` and installed with a pin count of 1.
///
/// The returned `Box<Page>` is a clone of the cached page, not a view into
/// the pool; modifications must be pushed back with `update_page`. Each
/// successful call adds one pin, which `unpin_page` releases.
///
/// # Errors
///
/// Propagates the error from `find_victim` when no frame can be reclaimed,
/// and any error returned by `pm.write_page` or `pm.read_page`. When one of
/// those I/O calls fails, the victim page stays cached and stays registered
/// in the page table, so the pool remains consistent.
pub fn fetch_page(&self, page_id: PageId, pm: &mut PageManager) -> MenteResult<Box<Page>> {
    let mut inner = self.inner.lock();

    // Fast path: the page is already resident.
    if let Some(&frame_id) = inner.page_table.get(&page_id) {
        let frame = &mut inner.frames[frame_id];
        frame.pin_count += 1;
        frame.reference = true;
        trace!(page_id = page_id.0, frame_id, "buffer pool hit");
        return Ok(frame.page.clone());
    }

    let frame_id = Self::find_victim(&mut inner)?;
    let evicted = inner.frames[frame_id].page_id;

    if inner.frames[frame_id].dirty
        && let Some(old_pid) = evicted
    {
        pm.write_page(old_pid, &inner.frames[frame_id].page)?;
        // The on-disk copy is now current. Clear the flag right away so a
        // failure further down cannot leave a flushed page marked dirty.
        inner.frames[frame_id].dirty = false;
        debug!(page_id = old_pid.0, frame_id, "flushed dirty victim");
    }

    // Read the replacement BEFORE unmapping the victim: if the read fails we
    // return early with the victim still cached and still in the page table,
    // instead of leaving an orphaned frame that claims a page id the table no
    // longer knows about.
    let page = pm.read_page(page_id)?;

    if let Some(old_pid) = evicted {
        inner.page_table.remove(&old_pid);
    }
    {
        let frame = &mut inner.frames[frame_id];
        *frame.page = *page;
        frame.page_id = Some(page_id);
        frame.pin_count = 1;
        frame.dirty = false;
        frame.reference = true;
    }
    inner.page_table.insert(page_id, frame_id);
    trace!(
        page_id = page_id.0,
        frame_id, "loaded page into buffer pool"
    );
    Ok(inner.frames[frame_id].page.clone())
}
/// Adds one pin to a page that is already cached.
///
/// # Errors
///
/// Returns `MenteError::Storage` if `page_id` is not in the pool.
pub fn pin_page(&self, page_id: PageId) -> MenteResult<()> {
    let mut guard = self.inner.lock();

    let Some(slot) = guard.page_table.get(&page_id).copied() else {
        return Err(MenteError::Storage(format!(
            "page {} not in buffer pool",
            page_id.0
        )));
    };

    guard.frames[slot].pin_count += 1;
    Ok(())
}
/// Releases one pin on a cached page and optionally marks it dirty.
///
/// The pin count never goes below zero: unpinning a page whose count is
/// already zero leaves the count unchanged. Passing `dirty = false` never
/// clears an existing dirty flag.
///
/// # Errors
///
/// Returns `MenteError::Storage` if `page_id` is not in the pool.
pub fn unpin_page(&self, page_id: PageId, dirty: bool) -> MenteResult<()> {
    let mut guard = self.inner.lock();

    let Some(slot) = guard.page_table.get(&page_id).copied() else {
        return Err(MenteError::Storage(format!(
            "page {} not in buffer pool",
            page_id.0
        )));
    };

    let frame = &mut guard.frames[slot];
    frame.pin_count = frame.pin_count.saturating_sub(1);
    frame.dirty |= dirty;
    Ok(())
}
/// Overwrites the cached copy of `page_id` with `page` and marks the frame
/// dirty. Pin count and reference bit are left untouched.
///
/// # Errors
///
/// Returns `MenteError::Storage` if `page_id` is not in the pool.
pub fn update_page(&self, page_id: PageId, page: &Page) -> MenteResult<()> {
    let mut guard = self.inner.lock();

    let Some(slot) = guard.page_table.get(&page_id).copied() else {
        return Err(MenteError::Storage(format!(
            "page {} not in buffer pool",
            page_id.0
        )));
    };

    let frame = &mut guard.frames[slot];
    *frame.page = page.clone();
    frame.dirty = true;
    Ok(())
}
/// Writes the cached copy of `page_id` through `pm` if it is dirty, then
/// clears the dirty flag. A clean page is a no-op.
///
/// # Errors
///
/// Returns `MenteError::Storage` if `page_id` is not in the pool, and
/// propagates any error from `pm.write_page` (the page stays dirty then).
pub fn flush_page(&self, page_id: PageId, pm: &mut PageManager) -> MenteResult<()> {
    let mut guard = self.inner.lock();

    let Some(slot) = guard.page_table.get(&page_id).copied() else {
        return Err(MenteError::Storage(format!(
            "page {} not in buffer pool",
            page_id.0
        )));
    };

    let frame = &mut guard.frames[slot];
    if !frame.dirty {
        return Ok(());
    }

    pm.write_page(page_id, &frame.page)?;
    frame.dirty = false;
    debug!(page_id = page_id.0, "flushed page");
    Ok(())
}
pub fn flush_all(&self, pm: &mut PageManager) -> MenteResult<()> {
let mut inner = self.inner.lock();
for frame in &mut inner.frames {
if frame.dirty
&& let Some(pid) = frame.page_id
{
pm.write_page(pid, &frame.page)?;
frame.dirty = false;
}
}
debug!("flushed all dirty pages");
Ok(())
}
/// Picks the frame that the next page load should use.
///
/// Free frames (no `page_id`) are preferred, lowest index first. Otherwise a
/// clock sweep runs for at most `2 * capacity` steps starting at
/// `clock_hand`: pinned frames are skipped, an unpinned frame with its
/// reference bit set has the bit cleared and is passed over once, and the
/// first unpinned frame with a clear bit is returned. The hand is advanced
/// past every frame that was examined.
///
/// # Errors
///
/// Returns `MenteError::Storage` if the sweep ends without finding an
/// unpinned frame.
fn find_victim(inner: &mut BufferPoolInner) -> MenteResult<FrameId> {
    let frame_count = inner.capacity;

    // A frame that has never held a page needs no eviction at all.
    if let Some(free) = (0..frame_count).find(|&slot| inner.frames[slot].page_id.is_none()) {
        return Ok(free);
    }

    // Two revolutions: the first clears reference bits, the second is then
    // guaranteed to hit any unpinned frame.
    let mut steps_left = frame_count * 2;
    while steps_left > 0 {
        steps_left -= 1;

        let candidate = inner.clock_hand;
        inner.clock_hand = (candidate + 1) % frame_count;

        let frame = &mut inner.frames[candidate];
        if frame.pin_count != 0 {
            continue;
        }
        if frame.reference {
            // Second chance: spare it this time round.
            frame.reference = false;
            continue;
        }
        return Ok(candidate);
    }

    Err(MenteError::Storage(
        "buffer pool full: all pages are pinned".into(),
    ))
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::page::Page;

    /// Opens a `PageManager` inside a fresh temporary directory. The
    /// directory handle is returned so it outlives the manager.
    fn setup() -> (tempfile::TempDir, PageManager) {
        let tmp = tempfile::tempdir().unwrap();
        let manager = PageManager::open(tmp.path()).unwrap();
        (tmp, manager)
    }

    /// Writes a zeroed page for `pid` whose header carries the page id and
    /// whose data starts with `prefix` (may be empty).
    fn seed_page(pm: &mut PageManager, pid: PageId, prefix: &[u8]) {
        let mut page = Page::zeroed();
        page.header.page_id = pid.0;
        page.data[..prefix.len()].copy_from_slice(prefix);
        pm.write_page(pid, &page).unwrap();
    }

    /// A page can be fetched from disk and then served again from the cache.
    #[test]
    fn test_fetch_and_cache_hit() {
        let (_dir, mut pm) = setup();
        let pool = BufferPool::new(4);

        let pid = pm.allocate_page().unwrap();
        seed_page(&mut pm, pid, b"abc");

        let first = pool.fetch_page(pid, &mut pm).unwrap();
        assert_eq!(&first.data[0..3], b"abc");
        pool.unpin_page(pid, false).unwrap();

        let second = pool.fetch_page(pid, &mut pm).unwrap();
        assert_eq!(&second.data[0..3], b"abc");
        pool.unpin_page(pid, false).unwrap();
    }

    /// An updated page reaches disk once it is flushed.
    #[test]
    fn test_dirty_flush() {
        let (_dir, mut pm) = setup();
        let pool = BufferPool::new(4);

        let pid = pm.allocate_page().unwrap();
        seed_page(&mut pm, pid, &[42]);

        let _ = pool.fetch_page(pid, &mut pm).unwrap();

        let mut replacement = Page::zeroed();
        replacement.header.page_id = pid.0;
        replacement.data[0] = 99;
        pool.update_page(pid, &replacement).unwrap();
        pool.unpin_page(pid, true).unwrap();
        pool.flush_page(pid, &mut pm).unwrap();

        let persisted = pm.read_page(pid).unwrap();
        assert_eq!(persisted.data[0], 99);
    }

    /// With two frames, a third page can be loaded once the first two are
    /// unpinned.
    #[test]
    fn test_eviction() {
        let (_dir, mut pm) = setup();
        let pool = BufferPool::new(2);

        let p1 = pm.allocate_page().unwrap();
        let p2 = pm.allocate_page().unwrap();
        let p3 = pm.allocate_page().unwrap();
        for pid in [p1, p2, p3] {
            seed_page(&mut pm, pid, &[pid.0 as u8]);
        }

        let _ = pool.fetch_page(p1, &mut pm).unwrap();
        pool.unpin_page(p1, false).unwrap();
        let _ = pool.fetch_page(p2, &mut pm).unwrap();
        pool.unpin_page(p2, false).unwrap();

        let third = pool.fetch_page(p3, &mut pm).unwrap();
        assert_eq!(third.data[0], p3.0 as u8);
        pool.unpin_page(p3, false).unwrap();
    }

    /// Fetching fails when every frame is still pinned.
    #[test]
    fn test_all_pinned_error() {
        let (_dir, mut pm) = setup();
        let pool = BufferPool::new(2);

        let p1 = pm.allocate_page().unwrap();
        let p2 = pm.allocate_page().unwrap();
        let p3 = pm.allocate_page().unwrap();
        for pid in [p1, p2, p3] {
            seed_page(&mut pm, pid, &[]);
        }

        let _ = pool.fetch_page(p1, &mut pm).unwrap();
        let _ = pool.fetch_page(p2, &mut pm).unwrap();

        assert!(pool.fetch_page(p3, &mut pm).is_err());
    }

    /// `flush_all` persists every dirty page.
    #[test]
    fn test_flush_all() {
        let (_dir, mut pm) = setup();
        let pool = BufferPool::new(4);

        let p1 = pm.allocate_page().unwrap();
        let p2 = pm.allocate_page().unwrap();
        for pid in [p1, p2] {
            seed_page(&mut pm, pid, &[]);
        }

        let _ = pool.fetch_page(p1, &mut pm).unwrap();
        let _ = pool.fetch_page(p2, &mut pm).unwrap();

        let mut first_edit = Page::zeroed();
        first_edit.data[0] = 0xAA;
        pool.update_page(p1, &first_edit).unwrap();

        let mut second_edit = Page::zeroed();
        second_edit.data[0] = 0xBB;
        pool.update_page(p2, &second_edit).unwrap();

        pool.unpin_page(p1, true).unwrap();
        pool.unpin_page(p2, true).unwrap();
        pool.flush_all(&mut pm).unwrap();

        let disk1 = pm.read_page(p1).unwrap();
        let disk2 = pm.read_page(p2).unwrap();
        assert_eq!(disk1.data[0], 0xAA);
        assert_eq!(disk2.data[0], 0xBB);
    }
}