use std::collections::HashMap;
use std::ops::{Deref, DerefMut};
use std::path::Path;
use crate::error::{PageError, PageResult};
use crate::file::PageFile;
use crate::page::{Page, PageId, PageSize};
use crate::store::PageStore;
use crate::sync::{self, Arc, AtomicBool, AtomicU64, AtomicUsize, Mutex, Ordering, RwLock};
use crate::sync::{RwLockReadGuard, RwLockWriteGuard};
/// Sentinel raw id stored in a frame that has not yet been assigned a page.
const NO_PAGE: u64 = u64::MAX;
/// One buffer-pool frame: a page buffer plus the atomics that track its state.
struct FrameInner {
/// The page contents, behind a reader/writer lock.
page: RwLock<Page>,
/// Raw id of the page installed in this frame; starts out as `NO_PAGE`.
id: AtomicU64,
/// Pin count: raised for every `PageGuard` handed out, lowered when one is dropped.
pin: AtomicUsize,
/// Set when the page is created or handed out for writing; cleared after a successful flush.
dirty: AtomicBool,
/// Reference bit: set on access, cleared by the victim search.
referenced: AtomicBool,
}
impl FrameInner {
    /// Wraps `page` in a fresh frame: no page id assigned, unpinned, clean and
    /// with the reference bit cleared.
    fn new(page: Page) -> Self {
        let page = RwLock::new(page);
        let id = AtomicU64::new(NO_PAGE);
        let pin = AtomicUsize::new(0);
        let dirty = AtomicBool::new(false);
        let referenced = AtomicBool::new(false);
        Self {
            page,
            id,
            pin,
            dirty,
            referenced,
        }
    }

    /// Returns the page id currently recorded for this frame.
    #[inline]
    fn resident_id(&self) -> PageId {
        let raw = self.id.load(Ordering::Acquire);
        PageId::new(raw)
    }
}
/// Bookkeeping that the pool keeps behind its mutex.
struct Core {
/// Maps each resident page id to the index of the frame holding it.
map: HashMap<PageId, usize>,
/// Indices of frames that currently hold no resident page.
free: Vec<usize>,
/// Index of the next frame the victim search will examine.
hand: usize,
}
/// A fixed-capacity page cache over a backing store `S` (`PageFile` by default).
pub struct BufferPool<S = PageFile> {
/// Backing store that pages are read from and written to.
store: S,
/// The frames, allocated once in `new` and addressed by slot index.
frames: Vec<Arc<FrameInner>>,
/// Residency map, free list and clock hand, guarded by a single mutex.
core: Mutex<Core>,
/// Number of frames in the pool (never less than 1).
capacity: usize,
}
impl BufferPool<PageFile> {
    /// Opens the page file at `path` with the given `page_size` and builds a
    /// pool of `capacity` frames on top of it.
    ///
    /// # Errors
    ///
    /// Propagates the error returned by `PageFile::open`.
    pub fn open<P: AsRef<Path>>(path: P, page_size: PageSize, capacity: usize) -> PageResult<Self> {
        let store = PageFile::open(path, page_size)?;
        let pool = Self::new(store, capacity);
        Ok(pool)
    }
}
impl<S: PageStore> BufferPool<S> {
#[must_use]
pub fn new(store: S, capacity: usize) -> Self {
let capacity = capacity.max(1);
let mut frames = Vec::with_capacity(capacity);
for _ in 0..capacity {
frames.push(Arc::new(FrameInner::new(store.allocate_page())));
}
let free = (0..capacity).collect();
Self {
store,
frames,
core: Mutex::new(Core {
map: HashMap::with_capacity(capacity),
free,
hand: 0,
}),
capacity,
}
}
#[inline]
#[must_use]
pub fn capacity(&self) -> usize {
self.capacity
}
#[must_use]
pub fn resident_len(&self) -> usize {
sync::lock(&self.core).map.len()
}
#[must_use]
pub fn is_resident(&self, id: PageId) -> bool {
sync::lock(&self.core).map.contains_key(&id)
}
pub fn fetch(&self, id: PageId) -> PageResult<PageGuard> {
let mut core = sync::lock(&self.core);
if let Some(&slot) = core.map.get(&id) {
let frame = self.frames[slot].clone();
let _ = frame.pin.fetch_add(1, Ordering::AcqRel);
frame.referenced.store(true, Ordering::Release);
return Ok(PageGuard { frame });
}
let slot = self.take_slot(&mut core)?;
{
let frame = &self.frames[slot];
let mut page = sync::write(&frame.page);
if let Err(err) = self.store.read_into(id, &mut page) {
drop(page);
core.free.push(slot);
return Err(err);
}
}
Ok(self.install(&mut core, slot, id, false))
}
pub fn new_page(&self, id: PageId) -> PageResult<PageGuard> {
let mut core = sync::lock(&self.core);
if let Some(&slot) = core.map.get(&id) {
let frame = self.frames[slot].clone();
sync::write(&frame.page).reset();
let _ = frame.pin.fetch_add(1, Ordering::AcqRel);
frame.dirty.store(true, Ordering::Release);
frame.referenced.store(true, Ordering::Release);
return Ok(PageGuard { frame });
}
let slot = self.take_slot(&mut core)?;
sync::write(&self.frames[slot].page).reset();
Ok(self.install(&mut core, slot, id, true))
}
pub fn flush(&self, id: PageId) -> PageResult<()> {
let core = sync::lock(&self.core);
if let Some(&slot) = core.map.get(&id) {
self.flush_slot(slot, id)?;
}
Ok(())
}
pub fn flush_all(&self) -> PageResult<()> {
let core = sync::lock(&self.core);
for (&id, &slot) in core.map.iter() {
self.flush_slot(slot, id)?;
}
Ok(())
}
pub fn checkpoint(&self) -> PageResult<()> {
self.flush_all()?;
self.sync()
}
pub fn sync(&self) -> PageResult<()> {
self.store.sync()
}
fn install(&self, core: &mut Core, slot: usize, id: PageId, dirty: bool) -> PageGuard {
let frame = &self.frames[slot];
frame.id.store(id.get(), Ordering::Release);
frame.dirty.store(dirty, Ordering::Release);
frame.referenced.store(true, Ordering::Release);
frame.pin.store(1, Ordering::Release);
let _ = core.map.insert(id, slot);
PageGuard {
frame: self.frames[slot].clone(),
}
}
fn flush_slot(&self, slot: usize, id: PageId) -> PageResult<()> {
let frame = &self.frames[slot];
if frame.dirty.load(Ordering::Acquire) {
let mut page = sync::write(&frame.page);
self.store.write_page(id, &mut page)?;
frame.dirty.store(false, Ordering::Release);
}
Ok(())
}
fn take_slot(&self, core: &mut Core) -> PageResult<usize> {
if let Some(slot) = core.free.pop() {
return Ok(slot);
}
let slot = match self.find_victim(core) {
Some(slot) => slot,
None => {
return Err(PageError::BufferPoolExhausted {
capacity: self.capacity,
});
}
};
let victim_id = self.frames[slot].resident_id();
self.flush_slot(slot, victim_id)?;
let _ = core.map.remove(&victim_id);
Ok(slot)
}
fn find_victim(&self, core: &mut Core) -> Option<usize> {
let n = self.capacity;
let mut steps = 0;
while steps < 2 * n {
let slot = core.hand;
core.hand = (core.hand + 1) % n;
steps += 1;
let frame = &self.frames[slot];
if frame.pin.load(Ordering::Acquire) > 0 {
continue;
}
if frame.referenced.swap(false, Ordering::AcqRel) {
continue;
}
return Some(slot);
}
None
}
}
/// Handle to a resident page; dropping it lowers the frame's pin count by one.
pub struct PageGuard {
/// The frame this guard refers to.
frame: Arc<FrameInner>,
}
impl PageGuard {
    /// Returns the id of the page held by this guard's frame.
    #[inline]
    #[must_use]
    pub fn id(&self) -> PageId {
        self.frame.resident_id()
    }

    /// Returns whether the frame's dirty flag is currently set.
    #[inline]
    #[must_use]
    pub fn is_dirty(&self) -> bool {
        self.frame.dirty.load(Ordering::Acquire)
    }

    /// Locks the page for shared access and returns a read handle.
    #[inline]
    #[must_use]
    pub fn read(&self) -> PageRef<'_> {
        PageRef {
            guard: sync::read(&self.frame.page),
        }
    }

    /// Locks the page for exclusive access, marks the frame dirty and returns
    /// a write handle.
    ///
    /// The dirty flag is set even if the caller never mutates the page.
    #[inline]
    #[must_use]
    pub fn write(&self) -> PageMut<'_> {
        // Take the page lock BEFORE setting the flag. A flush holds this same
        // lock while it writes the page and clears `dirty`; if the flag were
        // set first, a flush already in progress could clear it after we set
        // it but before we mutate, leaving a modified page marked clean and
        // silently dropping the update on the next flush or eviction.
        let guard = sync::write(&self.frame.page);
        self.frame.dirty.store(true, Ordering::Release);
        PageMut { guard }
    }
}
impl Drop for PageGuard {
    /// Releases this guard's pin on its frame.
    fn drop(&mut self) {
        let pin_count = &self.frame.pin;
        let _ = pin_count.fetch_sub(1, Ordering::AcqRel);
    }
}
impl std::fmt::Debug for PageGuard {
    /// Renders the guard as a `PageGuard` struct with its page `id` and
    /// current `dirty` flag; the page contents are not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("PageGuard");
        let _ = out.field("id", &self.id());
        let _ = out.field("dirty", &self.is_dirty());
        out.finish()
    }
}
/// Shared access to a page, valid for as long as the read lock is held.
pub struct PageRef<'a> {
/// Read lock on the frame's page.
guard: RwLockReadGuard<'a, Page>,
}
impl Deref for PageRef<'_> {
    type Target = Page;

    /// Borrows the page behind the read lock.
    #[inline]
    fn deref(&self) -> &Page {
        &*self.guard
    }
}
/// Exclusive access to a page, valid for as long as the write lock is held.
pub struct PageMut<'a> {
/// Write lock on the frame's page.
guard: RwLockWriteGuard<'a, Page>,
}
impl Deref for PageMut<'_> {
    type Target = Page;

    /// Borrows the page behind the write lock immutably.
    #[inline]
    fn deref(&self) -> &Page {
        &*self.guard
    }
}
impl DerefMut for PageMut<'_> {
    /// Borrows the page behind the write lock mutably.
    #[inline]
    fn deref_mut(&mut self) -> &mut Page {
        &mut *self.guard
    }
}
// Unit and property tests for the buffer pool; compiled only for non-loom test builds.
#[cfg(all(test, not(loom)))]
mod tests {
#![allow(clippy::unwrap_used, clippy::expect_used)]
use std::collections::HashMap;
use proptest::prelude::*;
use super::*;
use crate::page::Lsn;
use crate::test_store::MemStore;
// Page size handed to every `MemStore` in this module.
const PS: usize = 4096;
// Builds a pool of `capacity` frames over a fresh in-memory store.
fn pool(capacity: usize) -> BufferPool<MemStore> {
BufferPool::new(MemStore::new(PS), capacity)
}
// A page created and written through the pool is served back by `fetch`
// with the written byte intact.
#[test]
fn test_new_page_then_fetch_serves_from_cache() {
let pool = pool(8);
{
let guard = pool.new_page(PageId::new(0)).unwrap();
guard.write().payload_mut()[0] = 0x7A;
}
assert!(pool.is_resident(PageId::new(0)));
let guard = pool.fetch(PageId::new(0)).unwrap();
assert_eq!(guard.read().payload()[0], 0x7A);
}
// Requesting zero frames yields a pool of one frame.
#[test]
fn test_capacity_is_clamped_up_to_one() {
assert_eq!(pool(0).capacity(), 1);
}
// With the only frame pinned by `_held`, admitting another page must fail
// with `BufferPoolExhausted` and leave page 0 resident.
#[test]
fn test_pinned_page_is_never_evicted() {
let pool = pool(1);
let _held = pool.new_page(PageId::new(0)).unwrap();
assert!(matches!(
pool.new_page(PageId::new(1)),
Err(PageError::BufferPoolExhausted { capacity: 1 })
));
assert!(pool.is_resident(PageId::new(0)));
}
// In a one-frame pool, admitting page 1 evicts dirty page 0; the store must
// then contain page 0 and a later fetch must see the LSN that was set.
#[test]
fn test_dirty_page_is_flushed_before_eviction() {
let pool = pool(1);
{
let guard = pool.new_page(PageId::new(0)).unwrap();
guard.write().set_lsn(Lsn::new(9));
}
{
let _ = pool.new_page(PageId::new(1)).unwrap();
}
assert!(pool.store_contains(0));
let guard = pool.fetch(PageId::new(0)).unwrap();
assert_eq!(guard.read().lsn(), Lsn::new(9));
}
// Two-frame pool: pages 0 and 1 are created and flushed, page 0 is fetched
// again, then page 2 is admitted. Expects page 1 to be the one evicted.
#[test]
fn test_clock_keeps_the_recently_used_page() {
let pool = pool(2);
let _ = pool.new_page(PageId::new(0)).unwrap();
let _ = pool.new_page(PageId::new(1)).unwrap();
pool.flush_all().unwrap();
let _ = pool.fetch(PageId::new(0)).unwrap();
let _ = pool.new_page(PageId::new(2)).unwrap();
assert!(pool.is_resident(PageId::new(0)));
assert!(!pool.is_resident(PageId::new(1)));
assert!(pool.is_resident(PageId::new(2)));
}
// A freshly created page reports dirty; after `flush` a new guard reports clean.
#[test]
fn test_flush_clears_dirty() {
let pool = pool(4);
{
let guard = pool.new_page(PageId::new(0)).unwrap();
assert!(guard.is_dirty());
}
pool.flush(PageId::new(0)).unwrap();
let guard = pool.fetch(PageId::new(0)).unwrap();
assert!(!guard.is_dirty());
}
// Fetching a page the store never received fails with `ShortRead` and
// leaves the pool empty with its capacity unchanged.
#[test]
fn test_fetch_missing_unwritten_page_errors() {
let pool = pool(4);
assert!(matches!(
pool.fetch(PageId::new(99)),
Err(PageError::ShortRead { .. })
));
assert_eq!(pool.resident_len(), 0);
assert_eq!(pool.capacity(), 4);
}
proptest! {
#![proptest_config(ProptestConfig::with_cases(48))]
// Model-based check: six pages cycle through a two-frame pool under a random
// sequence of reads and optional writes; the first payload byte of every page
// must always match the value recorded in `expected`.
#[test]
fn pool_never_loses_data(
ops in proptest::collection::vec((0u8..6, any::<u8>(), any::<bool>()), 1..200),
) {
const N: u64 = 6;
let pool = pool(2); let mut expected: HashMap<u64, u8> = HashMap::new();
// Seed every page with marker 0 and push it to the store.
for id in 0..N {
let guard = pool.new_page(PageId::new(id)).unwrap();
guard.write().payload_mut()[0] = 0;
let _ = expected.insert(id, 0);
drop(guard);
}
pool.flush_all().unwrap();
// Replay the generated operations, checking each page before optionally rewriting it.
for (id, marker, write) in ops {
let id = id as u64 % N;
let guard = pool.fetch(PageId::new(id)).unwrap();
prop_assert_eq!(guard.read().payload()[0], expected[&id]);
if write {
guard.write().payload_mut()[0] = marker;
let _ = expected.insert(id, marker);
}
}
// Final pass: after a full flush every page must still match the model.
pool.flush_all().unwrap();
for id in 0..N {
let guard = pool.fetch(PageId::new(id)).unwrap();
prop_assert_eq!(guard.read().payload()[0], expected[&id]);
}
}
}
// Test-only accessor exposing the backing store's `contains`.
impl BufferPool<MemStore> {
fn store_contains(&self, id: u64) -> bool {
self.store.contains(id)
}
}
}
// Concurrency tests run under the loom model checker; compiled only for loom test builds.
#[cfg(all(test, loom))]
mod loom_tests {
use super::*;
use crate::sync::Arc;
use crate::test_store::MemStore;
// One-frame pool with page 0 pinned by `held`: a second thread's attempt to
// admit page 1 must fail, and page 0 must remain resident under its guard.
#[test]
fn loom_pinned_page_never_evicted() {
loom::model(|| {
let pool = Arc::new(BufferPool::new(MemStore::new(4096), 1));
let held = pool.new_page(PageId::new(0)).unwrap();
let p = Arc::clone(&pool);
let other = loom::thread::spawn(move || p.new_page(PageId::new(1)).is_err());
assert!(pool.is_resident(PageId::new(0)));
let admit_failed = other.join().unwrap();
assert!(admit_failed);
assert_eq!(held.id(), PageId::new(0));
drop(held);
});
}
// One-frame pool: page 0 is written and its guard dropped, then another
// thread admits page 1, which evicts page 0. The store must contain page 0.
#[test]
fn loom_dirty_page_flushed_on_eviction() {
loom::model(|| {
let store_pages = {
let pool = Arc::new(BufferPool::new(MemStore::new(4096), 1));
{
let guard = pool.new_page(PageId::new(0)).unwrap();
guard.write().payload_mut()[0] = 0x5A;
}
let p = Arc::clone(&pool);
let t = loom::thread::spawn(move || {
let _ = p.new_page(PageId::new(1)).unwrap();
});
t.join().unwrap();
pool.store_contains_loom(0)
};
assert!(store_pages, "evicted dirty page 0 was not flushed");
});
}
// Test-only accessor exposing the backing store's `contains`.
impl BufferPool<MemStore> {
fn store_contains_loom(&self, id: u64) -> bool {
self.store.contains(id)
}
}
}