use std::collections::{HashMap, VecDeque};
use std::sync::{Mutex, MutexGuard, RwLock, RwLockReadGuard, RwLockWriteGuard};
use super::page::Page;
/// Acquires `lock` for reading.
///
/// If a previous holder panicked and poisoned the lock, the poison is ignored
/// and the guard is recovered, so a single panic cannot wedge the cache.
fn cache_read<T>(lock: &RwLock<T>) -> RwLockReadGuard<'_, T> {
    match lock.read() {
        Ok(guard) => guard,
        Err(poisoned) => poisoned.into_inner(),
    }
}
/// Acquires `lock` for writing.
///
/// A poisoned lock is not treated as fatal: the guard is recovered from the
/// `PoisonError` and returned as if the lock were healthy.
fn cache_write<T>(lock: &RwLock<T>) -> RwLockWriteGuard<'_, T> {
    match lock.write() {
        Ok(guard) => guard,
        Err(poisoned) => poisoned.into_inner(),
    }
}
/// Locks `lock`, recovering the guard if a previous holder panicked.
///
/// Mirrors `cache_read` / `cache_write` for plain mutexes.
fn cache_lock<T>(lock: &Mutex<T>) -> MutexGuard<'_, T> {
    match lock.lock() {
        Ok(guard) => guard,
        Err(poisoned) => poisoned.into_inner(),
    }
}
/// Capacity used by `with_default_capacity` (and `Default`).
///
/// For a single [`PageCacheShard`] this is the per-shard page count; for a
/// [`PageCache`] it is the requested total, which is then split across shards.
pub const DEFAULT_CACHE_CAPACITY: usize = 100_000;
/// Smallest capacity a shard accepts; smaller requests are raised to this.
pub const MIN_CACHE_CAPACITY: usize = 2;
/// A cached page together with its eviction and write-back bookkeeping.
struct CacheEntry {
    /// The cached page contents.
    page: Page,
    /// Set when the page is accessed; the eviction hand clears it and spares
    /// the entry once.
    visited: bool,
    /// Number of outstanding pins; an entry with a non-zero count is skipped
    /// by eviction.
    pin_count: usize,
    /// Set by `mark_dirty`, cleared by `mark_clean` and by the flush methods.
    dirty: bool,
}
impl CacheEntry {
    /// Wraps `page` in a fresh entry that is unvisited, unpinned and clean.
    fn new(page: Page) -> Self {
        Self {
            dirty: false,
            pin_count: 0,
            visited: false,
            page,
        }
    }
}
/// Counters describing how a cache has been used.
#[derive(Debug, Default, Clone)]
pub struct CacheStats {
    /// `get` calls that found the page.
    pub hits: u64,
    /// `get` calls that did not find the page.
    pub misses: u64,
    /// Pages dropped by the eviction hand to make room for new ones.
    pub evictions: u64,
    /// Dirty pages handed back for writing, on eviction or by a flush.
    pub writebacks: u64,
}
impl CacheStats {
    /// Fraction of lookups that were hits, in `[0, 1]`.
    ///
    /// Returns `0.0` when no lookups have been recorded yet.
    pub fn hit_rate(&self) -> f64 {
        match self.hits + self.misses {
            0 => 0.0,
            lookups => self.hits as f64 / lookups as f64,
        }
    }
}
/// One shard of the page cache: a bounded map from page id to [`Page`] with
/// SIEVE-style eviction.
///
/// Pages sit in a FIFO queue in insertion order. A roving "hand" walks that
/// queue; it clears the visited bit of each page it passes and evicts the
/// first page that is neither visited nor pinned.
///
/// # Locking
///
/// State is split across several locks. Whenever more than one of them is
/// held at the same time, they are acquired in this order:
///
/// `fifo` → `hand` → `index` → `entries` → `free_slots` → `stats`
///
/// Structural changes hold `fifo` for their whole duration, which serializes
/// them. These are inserting a new page id, evicting, removing and clearing.
///
/// Operations on a single entry keep the `index` read guard alive while they
/// touch `entries`. A slot is only handed out again after its index mapping
/// has been removed (which needs the `index` write lock), so the slot they
/// looked up still belongs to the requested page.
pub struct PageCacheShard {
    /// Page count at which inserting a new page first triggers an eviction.
    capacity: usize,
    /// Maps a page id to its slot in `entries`.
    index: RwLock<HashMap<u32, usize>>,
    /// Cached page ids in insertion order, oldest at the front.
    fifo: Mutex<VecDeque<u32>>,
    /// Slot storage; `None` marks a vacant slot.
    entries: RwLock<Vec<Option<CacheEntry>>>,
    /// Vacant slots in `entries` that can be reused.
    free_slots: Mutex<Vec<usize>>,
    /// Position of the eviction hand within `fifo`.
    hand: Mutex<usize>,
    /// Hit, miss, eviction and write-back counters.
    stats: Mutex<CacheStats>,
}
impl PageCacheShard {
    /// Creates an empty shard that holds up to `capacity` pages.
    ///
    /// Requests below [`MIN_CACHE_CAPACITY`] are raised to it.
    pub fn new(capacity: usize) -> Self {
        let capacity = capacity.max(MIN_CACHE_CAPACITY);
        Self {
            capacity,
            index: RwLock::new(HashMap::with_capacity(capacity)),
            fifo: Mutex::new(VecDeque::with_capacity(capacity)),
            entries: RwLock::new(Vec::with_capacity(capacity)),
            free_slots: Mutex::new(Vec::new()),
            hand: Mutex::new(0),
            stats: Mutex::new(CacheStats::default()),
        }
    }

    /// Creates a shard with [`DEFAULT_CACHE_CAPACITY`] slots.
    pub fn with_default_capacity() -> Self {
        Self::new(DEFAULT_CACHE_CAPACITY)
    }

    /// Number of pages currently cached.
    pub fn len(&self) -> usize {
        cache_read(&self.index).len()
    }

    /// Returns `true` when no pages are cached.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Capacity after clamping to [`MIN_CACHE_CAPACITY`].
    pub fn capacity(&self) -> usize {
        self.capacity
    }

    /// Snapshot of the shard's counters.
    pub fn stats(&self) -> CacheStats {
        cache_lock(&self.stats).clone()
    }

    /// Zeroes all counters.
    pub fn reset_stats(&self) {
        *cache_lock(&self.stats) = CacheStats::default();
    }

    /// Returns a copy of the cached page and marks it visited, so the next
    /// eviction pass spares it.
    ///
    /// Counts a hit when the page is found and a miss otherwise.
    pub fn get(&self, page_id: u32) -> Option<Page> {
        let found = {
            // Hold the index guard for the whole lookup so `slot` cannot be
            // recycled for a different page between the lookup and the read.
            let index = cache_read(&self.index);
            index
                .get(&page_id)
                .copied()
                .and_then(|slot| self.copy_and_mark_visited(slot))
        };
        let mut stats = cache_lock(&self.stats);
        if found.is_some() {
            stats.hits += 1;
        } else {
            stats.misses += 1;
        }
        found
    }

    /// Clones the page stored in `slot` and sets its visited bit.
    ///
    /// The `entries` write lock is only taken when the bit is not already set.
    /// The caller must hold the `index` read guard.
    fn copy_and_mark_visited(&self, slot: usize) -> Option<Page> {
        let (page, needs_mark) = {
            let entries = cache_read(&self.entries);
            let entry = entries.get(slot)?.as_ref()?;
            (entry.page.clone(), !entry.visited)
        };
        if needs_mark {
            let mut entries = cache_write(&self.entries);
            if let Some(Some(entry)) = entries.get_mut(slot) {
                entry.visited = true;
            }
        }
        Some(page)
    }

    /// Caches `page` under `page_id`.
    ///
    /// If the id is already cached, its page is replaced in place and marked
    /// visited; the dirty flag is left untouched.
    ///
    /// Otherwise, when the shard is full, one page is evicted first. If every
    /// cached page is pinned, nothing can be evicted and the shard grows
    /// beyond its capacity.
    ///
    /// Returns the evicted page only when it was dirty, so the caller can
    /// write it back. Returns `None` in every other case.
    pub fn insert(&self, page_id: u32, page: Page) -> Option<Page> {
        // Fast path: overwrite an existing entry without taking the FIFO lock.
        // `try_replace` hands the page back only when the id is not cached, so
        // `?` returns `None` exactly when the page was stored in place.
        let page = self.try_replace(page_id, page)?;
        let mut fifo = cache_lock(&self.fifo);
        // Re-check now that other structural changes are excluded. A
        // concurrent insert of the same id may have won the race, and
        // inserting again would orphan a slot and duplicate the id in the FIFO.
        let page = self.try_replace(page_id, page)?;
        let evicted = if self.len() >= self.capacity {
            self.evict_locked(&mut fifo)
        } else {
            None
        };
        {
            let mut index = cache_write(&self.index);
            let mut entries = cache_write(&self.entries);
            let slot = cache_lock(&self.free_slots)
                .pop()
                .unwrap_or(entries.len());
            if slot >= entries.len() {
                entries.resize_with(slot + 1, || None);
            }
            entries[slot] = Some(CacheEntry::new(page));
            index.insert(page_id, slot);
        }
        fifo.push_back(page_id);
        evicted
    }

    /// Overwrites the page cached under `page_id` and marks it visited.
    ///
    /// Returns `None` when the page was stored. Returns `Some(page)`
    /// (the untouched argument) when `page_id` is not cached.
    fn try_replace(&self, page_id: u32, page: Page) -> Option<Page> {
        let index = cache_read(&self.index);
        let Some(&slot) = index.get(&page_id) else {
            return Some(page);
        };
        let mut entries = cache_write(&self.entries);
        match entries.get_mut(slot) {
            Some(Some(entry)) => {
                entry.page = page;
                entry.visited = true;
                None
            }
            _ => Some(page),
        }
    }

    /// Advances the eviction hand over `fifo` until one page is evicted.
    ///
    /// Pinned and visited pages are skipped, and their visited bit is cleared.
    /// The first page that is neither pinned nor visited is evicted. The hand
    /// gives up after `2 * fifo.len()` steps, which only happens when every
    /// page is pinned.
    ///
    /// Returns the evicted page only if it was dirty. `fifo` must be the
    /// caller's guard of `self.fifo`, so index mappings cannot change while
    /// this runs.
    fn evict_locked(&self, fifo: &mut VecDeque<u32>) -> Option<Page> {
        let mut hand = cache_lock(&self.hand);
        let mut steps_left = fifo.len() * 2;
        while steps_left > 0 && !fifo.is_empty() {
            steps_left -= 1;
            if *hand >= fifo.len() {
                *hand = 0;
            }
            let page_id = fifo[*hand];
            let slot = cache_read(&self.index).get(&page_id).copied();
            let Some(slot) = slot else {
                // An id with no index mapping has nothing to evict; drop it so
                // it cannot linger in the queue forever.
                fifo.remove(*hand);
                continue;
            };
            // Decide and take under one write lock, so a concurrent `pin`,
            // `get` or `mark_dirty` cannot slip in between the check and the
            // removal. `dirty` is read from the entry actually taken.
            let victim = {
                let mut entries = cache_write(&self.entries);
                match entries.get_mut(slot) {
                    Some(Some(entry)) if entry.pin_count > 0 || entry.visited => {
                        entry.visited = false;
                        None
                    }
                    Some(cell) => cell.take(),
                    None => None,
                }
            };
            let Some(victim) = victim else {
                *hand += 1;
                continue;
            };
            // The slot is released only after its mapping is gone, so readers
            // holding the index guard never see it reused for another page.
            cache_write(&self.index).remove(&page_id);
            // After removal `*hand` already points at the next page in line.
            fifo.remove(*hand);
            cache_lock(&self.free_slots).push(slot);
            {
                let mut stats = cache_lock(&self.stats);
                stats.evictions += 1;
                if victim.dirty {
                    stats.writebacks += 1;
                }
            }
            return if victim.dirty { Some(victim.page) } else { None };
        }
        None
    }

    /// Runs `f` on the live entry for `page_id`.
    ///
    /// Returns `None` when the page is not cached. The `index` read guard is
    /// held while `f` runs, so the slot cannot be recycled for a different
    /// page in between.
    fn with_entry_mut<R>(&self, page_id: u32, f: impl FnOnce(&mut CacheEntry) -> R) -> Option<R> {
        let index = cache_read(&self.index);
        let slot = *index.get(&page_id)?;
        let mut entries = cache_write(&self.entries);
        let entry = entries.get_mut(slot)?.as_mut()?;
        Some(f(entry))
    }

    /// Flags `page_id` as dirty; does nothing if it is not cached.
    pub fn mark_dirty(&self, page_id: u32) {
        self.with_entry_mut(page_id, |entry| entry.dirty = true);
    }

    /// Clears the dirty flag of `page_id`; does nothing if it is not cached.
    pub fn mark_clean(&self, page_id: u32) {
        self.with_entry_mut(page_id, |entry| entry.dirty = false);
    }

    /// Adds one pin to `page_id`, so eviction skips it.
    ///
    /// Returns `false` if the page is not cached.
    pub fn pin(&self, page_id: u32) -> bool {
        self.with_entry_mut(page_id, |entry| entry.pin_count += 1)
            .is_some()
    }

    /// Releases one pin on `page_id`.
    ///
    /// Returns `false` if the page is not cached or was not pinned.
    pub fn unpin(&self, page_id: u32) -> bool {
        self.with_entry_mut(page_id, |entry| {
            if entry.pin_count == 0 {
                false
            } else {
                entry.pin_count -= 1;
                true
            }
        })
        .unwrap_or(false)
    }

    /// Removes `page_id` from the shard and returns its page.
    ///
    /// Pins and the dirty flag are ignored. Returns `None` if the page was not
    /// cached.
    pub fn remove(&self, page_id: u32) -> Option<Page> {
        let mut fifo = cache_lock(&self.fifo);
        let (slot, entry) = {
            let mut index = cache_write(&self.index);
            let slot = index.remove(&page_id)?;
            let mut entries = cache_write(&self.entries);
            (slot, entries.get_mut(slot).and_then(Option::take))
        };
        if let Some(pos) = fifo.iter().position(|&id| id == page_id) {
            fifo.remove(pos);
            // Entries behind `pos` shifted down by one; keep the hand on the
            // same page it pointed at before.
            let mut hand = cache_lock(&self.hand);
            if pos < *hand {
                *hand -= 1;
            }
        }
        cache_lock(&self.free_slots).push(slot);
        entry.map(|e| e.page)
    }

    /// Returns a copy of every dirty page and marks those pages clean.
    ///
    /// Each returned page is counted as a write-back.
    pub fn flush_dirty(&self) -> Vec<(u32, Page)> {
        self.take_dirty(usize::MAX)
    }

    /// Like [`flush_dirty`](Self::flush_dirty), but returns at most `max`
    /// pages.
    pub fn flush_some_dirty(&self, max: usize) -> Vec<(u32, Page)> {
        self.take_dirty(max)
    }

    /// Copies out up to `limit` dirty pages and clears their dirty flags in
    /// the same critical section.
    ///
    /// Doing both under one `entries` write lock means a page re-dirtied after
    /// being copied keeps its flag and gets flushed again later. Otherwise its
    /// newer contents would be marked clean without ever being written.
    ///
    /// The result is not preallocated to `limit`, because callers may pass a
    /// huge limit (e.g. `usize::MAX`).
    fn take_dirty(&self, limit: usize) -> Vec<(u32, Page)> {
        if limit == 0 {
            return Vec::new();
        }
        let mut flushed = Vec::new();
        {
            let index = cache_read(&self.index);
            let mut entries = cache_write(&self.entries);
            for (&page_id, &slot) in index.iter() {
                if flushed.len() >= limit {
                    break;
                }
                if let Some(Some(entry)) = entries.get_mut(slot) {
                    if entry.dirty {
                        entry.dirty = false;
                        flushed.push((page_id, entry.page.clone()));
                    }
                }
            }
        }
        cache_lock(&self.stats).writebacks += flushed.len() as u64;
        flushed
    }

    /// Number of cached pages whose dirty flag is set.
    pub fn dirty_count(&self) -> usize {
        let index = cache_read(&self.index);
        let entries = cache_read(&self.entries);
        index
            .values()
            .filter(|&&slot| matches!(entries.get(slot), Some(Some(entry)) if entry.dirty))
            .count()
    }

    /// Drops every cached page and resets the eviction hand.
    ///
    /// Counters are left untouched. Locks are taken in the shard's documented
    /// order, so this cannot deadlock against a concurrent eviction.
    pub fn clear(&self) {
        let mut fifo = cache_lock(&self.fifo);
        let mut hand = cache_lock(&self.hand);
        let mut index = cache_write(&self.index);
        let mut entries = cache_write(&self.entries);
        let mut free_slots = cache_lock(&self.free_slots);
        fifo.clear();
        *hand = 0;
        index.clear();
        entries.clear();
        free_slots.clear();
    }

    /// Returns `true` if `page_id` is cached. Does not count as a lookup.
    pub fn contains(&self, page_id: u32) -> bool {
        cache_read(&self.index).contains_key(&page_id)
    }

    /// Ids of all cached pages, in no particular order.
    pub fn page_ids(&self) -> Vec<u32> {
        cache_read(&self.index).keys().copied().collect()
    }
}
impl Default for PageCacheShard {
fn default() -> Self {
Self::with_default_capacity()
}
}
/// Number of independent shards in a [`PageCache`].
///
/// Keep this a power of two: `PageCache::shard_for` picks a shard from the low
/// bits of the page id.
const NUM_SHARDS: usize = 8;
/// A page cache split into `NUM_SHARDS` independent [`PageCacheShard`]s.
///
/// Each shard has its own locks, so operations on pages that map to different
/// shards do not contend with each other.
pub struct PageCache {
    /// The shards; a given page id is always routed to the same shard.
    shards: Box<[PageCacheShard]>,
    /// Total capacity: per-shard capacity multiplied by the shard count.
    capacity: usize,
}
impl PageCache {
    /// Creates a cache for about `capacity` pages, spread evenly over
    /// `NUM_SHARDS` shards.
    ///
    /// Every shard gets `ceil(capacity / NUM_SHARDS)` slots, but at least
    /// [`MIN_CACHE_CAPACITY`]. The total reported by
    /// [`capacity`](Self::capacity) is therefore a multiple of the shard count
    /// and may exceed the request.
    pub fn new(capacity: usize) -> Self {
        let per_shard = capacity.div_ceil(NUM_SHARDS).max(MIN_CACHE_CAPACITY);
        let shards: Box<[PageCacheShard]> = (0..NUM_SHARDS)
            .map(|_| PageCacheShard::new(per_shard))
            .collect();
        Self {
            shards,
            capacity: per_shard * NUM_SHARDS,
        }
    }

    /// Creates a cache sized for [`DEFAULT_CACHE_CAPACITY`] pages in total.
    pub fn with_default_capacity() -> Self {
        Self::new(DEFAULT_CACHE_CAPACITY)
    }

    /// Shard responsible for `page_id`; the mapping is deterministic.
    #[inline]
    fn shard_for(&self, page_id: u32) -> &PageCacheShard {
        // `%` instead of a `& (NUM_SHARDS - 1)` mask stays correct even if the
        // shard count stops being a power of two. For powers of two the
        // compiler lowers it to the same mask.
        &self.shards[page_id as usize % NUM_SHARDS]
    }

    /// Total number of cached pages across all shards.
    ///
    /// Shards are read one after another, so under concurrent use this is not
    /// an atomic snapshot.
    pub fn len(&self) -> usize {
        self.shards.iter().map(|s| s.len()).sum()
    }

    /// Returns `true` when every shard is empty.
    pub fn is_empty(&self) -> bool {
        self.shards.iter().all(|s| s.is_empty())
    }

    /// Effective total capacity (per-shard capacity × shard count).
    pub fn capacity(&self) -> usize {
        self.capacity
    }

    /// Counters summed over all shards.
    pub fn stats(&self) -> CacheStats {
        let mut agg = CacheStats::default();
        for s in self.shards.iter() {
            let cs = s.stats();
            agg.hits += cs.hits;
            agg.misses += cs.misses;
            agg.evictions += cs.evictions;
            agg.writebacks += cs.writebacks;
        }
        agg
    }

    /// Zeroes the counters of every shard.
    pub fn reset_stats(&self) {
        for s in self.shards.iter() {
            s.reset_stats();
        }
    }

    /// See [`PageCacheShard::get`].
    pub fn get(&self, page_id: u32) -> Option<Page> {
        self.shard_for(page_id).get(page_id)
    }

    /// See [`PageCacheShard::insert`]. Eviction only considers pages in the
    /// same shard as `page_id`.
    pub fn insert(&self, page_id: u32, page: Page) -> Option<Page> {
        self.shard_for(page_id).insert(page_id, page)
    }

    /// See [`PageCacheShard::mark_dirty`].
    pub fn mark_dirty(&self, page_id: u32) {
        self.shard_for(page_id).mark_dirty(page_id)
    }

    /// See [`PageCacheShard::mark_clean`].
    pub fn mark_clean(&self, page_id: u32) {
        self.shard_for(page_id).mark_clean(page_id)
    }

    /// See [`PageCacheShard::pin`].
    pub fn pin(&self, page_id: u32) -> bool {
        self.shard_for(page_id).pin(page_id)
    }

    /// See [`PageCacheShard::unpin`].
    pub fn unpin(&self, page_id: u32) -> bool {
        self.shard_for(page_id).unpin(page_id)
    }

    /// See [`PageCacheShard::remove`].
    pub fn remove(&self, page_id: u32) -> Option<Page> {
        self.shard_for(page_id).remove(page_id)
    }

    /// Returns `true` if `page_id` is cached.
    pub fn contains(&self, page_id: u32) -> bool {
        self.shard_for(page_id).contains(page_id)
    }

    /// Collects and cleans the dirty pages of every shard, shard by shard.
    pub fn flush_dirty(&self) -> Vec<(u32, Page)> {
        let mut out = Vec::new();
        for s in self.shards.iter() {
            out.extend(s.flush_dirty());
        }
        out
    }

    /// Collects and cleans at most `max` dirty pages, visiting shards in
    /// order until the budget is spent.
    pub fn flush_some_dirty(&self, max: usize) -> Vec<(u32, Page)> {
        // No `with_capacity(max)`: callers may pass a huge `max` to mean "all",
        // which would panic on capacity overflow or over-allocate.
        let mut out = Vec::new();
        for s in self.shards.iter() {
            let budget = max - out.len();
            if budget == 0 {
                break;
            }
            out.extend(s.flush_some_dirty(budget));
        }
        out
    }

    /// Number of dirty pages across all shards.
    pub fn dirty_count(&self) -> usize {
        self.shards.iter().map(|s| s.dirty_count()).sum()
    }

    /// Clears every shard. Counters are kept.
    pub fn clear(&self) {
        for s in self.shards.iter() {
            s.clear();
        }
    }

    /// Ids of all cached pages, grouped by shard.
    pub fn page_ids(&self) -> Vec<u32> {
        let mut out = Vec::with_capacity(self.len());
        for s in self.shards.iter() {
            out.extend(s.page_ids());
        }
        out
    }
}
impl Default for PageCache {
fn default() -> Self {
Self::with_default_capacity()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::engine::page::PageType;

    /// Builds a B-tree leaf page with the given id, used as test data.
    fn make_page(id: u32) -> Page {
        Page::new(PageType::BTreeLeaf, id)
    }

    #[test]
    fn test_cache_basic() {
        let cache = PageCacheShard::new(100);
        assert!(cache.is_empty());
        assert_eq!(cache.capacity(), 100);
        cache.insert(1, make_page(1));
        assert_eq!(cache.len(), 1);
        assert!(cache.contains(1));
        let page = cache.get(1);
        assert!(page.is_some());
        let page = cache.get(999);
        assert!(page.is_none());
    }

    #[test]
    fn test_cache_eviction() {
        let cache = PageCacheShard::new(4);
        for i in 0..4 {
            cache.insert(i, make_page(i));
        }
        assert_eq!(cache.len(), 4);
        cache.insert(100, make_page(100));
        assert_eq!(cache.len(), 4);
        assert!(cache.contains(100));
    }

    /// A page read via `get` gets a second chance and survives the next eviction.
    #[test]
    fn test_cache_sieve_visited() {
        let cache = PageCacheShard::new(4);
        for i in 0..4 {
            cache.insert(i, make_page(i));
        }
        cache.get(0);
        cache.insert(100, make_page(100));
        assert!(cache.contains(0));
        assert!(cache.contains(100));
    }

    #[test]
    fn test_cache_dirty() {
        let cache = PageCacheShard::new(100);
        cache.insert(1, make_page(1));
        cache.mark_dirty(1);
        let dirty = cache.flush_dirty();
        assert_eq!(dirty.len(), 1);
        assert_eq!(dirty[0].0, 1);
        // Flushing marks pages clean, so a second flush finds nothing.
        let dirty = cache.flush_dirty();
        assert_eq!(dirty.len(), 0);
    }

    /// A pinned page is skipped by eviction.
    #[test]
    fn test_cache_pin() {
        let cache = PageCacheShard::new(2);
        cache.insert(1, make_page(1));
        cache.insert(2, make_page(2));
        assert!(cache.pin(1));
        cache.insert(3, make_page(3));
        assert!(cache.contains(1));
        assert!(cache.unpin(1));
    }

    #[test]
    fn test_cache_remove() {
        let cache = PageCacheShard::new(100);
        cache.insert(1, make_page(1));
        assert!(cache.contains(1));
        let removed = cache.remove(1);
        assert!(removed.is_some());
        assert!(!cache.contains(1));
    }

    #[test]
    fn test_cache_stats() {
        let cache = PageCacheShard::new(100);
        cache.insert(1, make_page(1));
        cache.get(1);
        cache.get(1);
        cache.get(999);
        let stats = cache.stats();
        assert_eq!(stats.hits, 2);
        assert_eq!(stats.misses, 1);
        assert!((stats.hit_rate() - 0.666).abs() < 0.01);
    }

    #[test]
    fn test_cache_clear() {
        let cache = PageCacheShard::new(100);
        for i in 0..50 {
            cache.insert(i, make_page(i));
        }
        assert_eq!(cache.len(), 50);
        cache.clear();
        assert!(cache.is_empty());
    }

    /// Re-inserting an existing id replaces the page instead of adding a second entry.
    #[test]
    fn test_cache_update_existing() {
        let cache = PageCacheShard::new(100);
        let mut page1 = make_page(1);
        page1.as_bytes_mut()[100] = 0xAA;
        cache.insert(1, page1);
        let mut page1_updated = make_page(1);
        page1_updated.as_bytes_mut()[100] = 0xBB;
        cache.insert(1, page1_updated);
        assert_eq!(cache.len(), 1);
        let retrieved = cache.get(1).unwrap();
        assert_eq!(retrieved.as_bytes()[100], 0xBB);
    }

    /// A panic while holding the index lock must not make the shard unusable.
    #[test]
    fn test_cache_recovers_after_index_lock_poisoning() {
        let cache = std::sync::Arc::new(PageCacheShard::new(8));
        let poison_target = std::sync::Arc::clone(&cache);
        let _ = std::thread::spawn(move || {
            let _guard = poison_target
                .index
                .write()
                .expect("index lock should be acquired");
            panic!("poison index lock");
        })
        .join();
        cache.insert(1, make_page(1));
        assert!(cache.contains(1));
        assert!(cache.get(1).is_some());
    }

    /// A panic while holding the stats lock must not break counting.
    #[test]
    fn test_cache_recovers_after_stats_lock_poisoning() {
        let cache = std::sync::Arc::new(PageCacheShard::new(8));
        let poison_target = std::sync::Arc::clone(&cache);
        let _ = std::thread::spawn(move || {
            let _guard = poison_target
                .stats
                .lock()
                .expect("stats lock should be acquired");
            panic!("poison stats lock");
        })
        .join();
        assert!(cache.get(999).is_none());
        assert_eq!(cache.stats().misses, 1);
        cache.reset_stats();
        assert_eq!(cache.stats().misses, 0);
    }

    /// Single-lock, unbounded map used only as the baseline in the scaling benchmark.
    mod legacy_baseline {
        use super::Page;
        use std::collections::HashMap;
        use std::sync::RwLock;

        pub struct LegacyPageCache {
            entries: RwLock<HashMap<u32, Page>>,
        }

        impl LegacyPageCache {
            pub fn new(_capacity: usize) -> Self {
                Self {
                    entries: RwLock::new(HashMap::new()),
                }
            }

            pub fn insert(&self, page_id: u32, page: Page) {
                let mut entries = self.entries.write().unwrap();
                entries.insert(page_id, page);
            }

            pub fn get(&self, page_id: u32) -> Option<Page> {
                let entries = self.entries.read().unwrap();
                entries.get(&page_id).cloned()
            }
        }
    }

    /// Runs `run` on `workers` threads and returns the total wall-clock time.
    ///
    /// Each thread makes `ops_per_worker` calls. Worker `w` uses page ids
    /// starting at `w * 1_000_000`, so workers never share ids.
    fn run_workload<F>(workers: usize, ops_per_worker: usize, run: F) -> std::time::Duration
    where
        F: Fn(u32, &Page) + Send + Sync + 'static + Clone,
    {
        use std::sync::Arc;
        use std::time::Instant;
        let run = Arc::new(run);
        let start = Instant::now();
        let mut handles = Vec::with_capacity(workers);
        for w in 0..workers {
            let run = Arc::clone(&run);
            handles.push(std::thread::spawn(move || {
                let base = (w as u32) * 1_000_000;
                let page = make_page(0);
                for i in 0..ops_per_worker {
                    let id = base + (i as u32);
                    run(id, &page);
                }
            }));
        }
        for h in handles {
            h.join().unwrap();
        }
        start.elapsed()
    }

    /// Rough scaling benchmark: the sharded cache with one worker versus many
    /// workers, and versus a single-lock baseline.
    ///
    /// The assertions compare wall-clock times. Those depend on core count,
    /// build profile and machine load, so the test is opt-in rather than part
    /// of the default run. Run it with:
    /// `cargo test --release -- --ignored test_sharded_cache_scales_concurrently`
    #[test]
    #[ignore = "wall-clock benchmark; run explicitly with --ignored (preferably --release)"]
    fn test_sharded_cache_scales_concurrently() {
        use std::sync::Arc;
        const WORKERS: usize = 10;
        const OPS: usize = 5_000;
        const CAPACITY: usize = 200_000;
        let sharded = Arc::new(PageCache::new(CAPACITY));
        let s1 = Arc::clone(&sharded);
        let sharded_serial = run_workload(1, OPS * WORKERS, move |id, page| {
            s1.insert(id, page.clone());
            let _ = s1.get(id);
        });
        let sharded = Arc::new(PageCache::new(CAPACITY));
        let s2 = Arc::clone(&sharded);
        let sharded_parallel = run_workload(WORKERS, OPS, move |id, page| {
            s2.insert(id, page.clone());
            let _ = s2.get(id);
        });
        let legacy = Arc::new(legacy_baseline::LegacyPageCache::new(CAPACITY));
        let l2 = Arc::clone(&legacy);
        let legacy_parallel = run_workload(WORKERS, OPS, move |id, page| {
            l2.insert(id, page.clone());
            let _ = l2.get(id);
        });
        eprintln!(
            "page_cache concurrency: sharded 1w={:?} sharded {}w={:?} legacy {}w={:?}",
            sharded_serial, WORKERS, sharded_parallel, WORKERS, legacy_parallel
        );
        assert!(
            sharded_parallel.as_nanos() < sharded_serial.as_nanos() * 7,
            "sharded cache did not scale: 1w={:?} {}w={:?}",
            sharded_serial,
            WORKERS,
            sharded_parallel
        );
        assert!(
            sharded_parallel.as_nanos() * 12 < legacy_parallel.as_nanos() * 10,
            "sharded cache did not beat legacy baseline: sharded {}w={:?} legacy {}w={:?}",
            WORKERS,
            sharded_parallel,
            WORKERS,
            legacy_parallel
        );
    }
}