use std::cell::UnsafeCell;
use std::collections::{HashMap, VecDeque};
use std::ops::{Deref, DerefMut};
use std::sync::{Arc, Mutex};
use crate::api::errors::Result;
use crate::concurrency::{Guard as LatchGuard, HybridLatch};
use crate::layout::BlobGuid;
use super::backend::{AlignedBlobBuf, Backend};
/// In-memory blob cache layered in front of another [`Backend`].
///
/// Reads are served from the cache when possible; misses and all writes are
/// forwarded to `backend`.
pub struct BufferManager {
/// Underlying store that cache misses, writes, deletes and listings are forwarded to.
backend: Arc<dyn Backend>,
/// Number of cached blobs the eviction loop shrinks the cache back to; `new` clamps it to >= 1.
capacity: usize,
/// Cache map and recency queue, guarded together by a single mutex.
state: Mutex<BufferManagerState>,
}
/// Mutable bookkeeping shared by every `BufferManager` operation.
struct BufferManagerState {
/// Resident blobs keyed by GUID.
cache: HashMap<BlobGuid, Arc<CachedBlob>>,
/// Recency order of the cached GUIDs: front is least recently used, back is most recently used.
lru: VecDeque<BlobGuid>,
}
/// One cached blob: its buffer plus the latch used to coordinate access to it.
pub struct CachedBlob {
/// Latch passed to `LatchGuard` whenever a guard over `buf` is created.
latch: HybridLatch,
/// Blob contents; wrapped in `UnsafeCell` so guards can hand out references through `&self`.
buf: UnsafeCell<AlignedBlobBuf>,
}
// SAFETY(review): `UnsafeCell` makes `CachedBlob` `!Sync` by default. Every access to
// `buf` in this file goes through a guard that also carries a `LatchGuard` for `latch`;
// soundness relies on `HybridLatch` actually enforcing the shared/exclusive exclusion
// that its constructor names suggest — TODO confirm against `HybridLatch`.
// NOTE(review): the optimistic path reads `buf` without holding the latch and appears
// to depend on callers checking `validate()` afterwards — confirm this is acceptable.
unsafe impl Sync for CachedBlob {}
impl CachedBlob {
    /// Wraps `buf` in a new cache entry with a freshly constructed latch.
    fn new(buf: AlignedBlobBuf) -> Self {
        Self {
            latch: HybridLatch::new(),
            buf: UnsafeCell::new(buf),
        }
    }

    /// Creates an [`OptimisticGuard`] over this blob.
    ///
    /// The guard is built with `LatchGuard::optimistic`; callers are expected to
    /// check [`OptimisticGuard::validate`] after reading (presumably to detect a
    /// concurrent writer — confirm against `HybridLatch`).
    ///
    /// Marked `#[must_use]`: a guard that is discarded immediately is never what
    /// the caller intended.
    #[must_use]
    pub fn read_optimistic(&self) -> OptimisticGuard<'_> {
        OptimisticGuard {
            latch: LatchGuard::optimistic(&self.latch),
            buf: &self.buf,
        }
    }

    /// Creates a [`BlobReadGuard`] that holds a `LatchGuard::shared` guard for as
    /// long as it lives and dereferences to the blob buffer.
    ///
    /// Marked `#[must_use]`: dropping the result right away drops the latch guard
    /// with it.
    #[must_use]
    pub fn read(&self) -> BlobReadGuard<'_> {
        BlobReadGuard {
            _latch: LatchGuard::shared(&self.latch),
            buf: &self.buf,
        }
    }

    /// Creates a [`BlobWriteGuard`] that holds a `LatchGuard::exclusive` guard for
    /// as long as it lives and dereferences (mutably) to the blob buffer.
    ///
    /// Marked `#[must_use]`: dropping the result right away drops the latch guard
    /// with it.
    #[must_use]
    pub fn write(&self) -> BlobWriteGuard<'_> {
        BlobWriteGuard {
            _latch: LatchGuard::exclusive(&self.latch),
            buf: &self.buf,
        }
    }
}
/// Guard returned by `CachedBlob::read_optimistic`: an optimistic latch guard paired
/// with the blob's buffer cell.
pub struct OptimisticGuard<'a> {
/// Latch guard created with `LatchGuard::optimistic`; `validate` delegates to it.
latch: LatchGuard<'a>,
/// Cell holding the blob bytes that `as_slice` exposes.
buf: &'a UnsafeCell<AlignedBlobBuf>,
}
impl<'a> OptimisticGuard<'a> {
#[must_use]
pub fn as_slice(&self) -> &'a [u8] {
unsafe { (&*self.buf.get()).as_slice() }
}
#[must_use]
pub fn validate(&self) -> bool {
self.latch.validate()
}
}
/// Guard returned by `CachedBlob::read`; dereferences to the blob buffer.
pub struct BlobReadGuard<'a> {
/// Latch guard created with `LatchGuard::shared`; never read, only kept alive alongside the guard.
_latch: LatchGuard<'a>,
/// Cell holding the blob buffer that `Deref` exposes.
buf: &'a UnsafeCell<AlignedBlobBuf>,
}
impl Deref for BlobReadGuard<'_> {
    type Target = AlignedBlobBuf;

    /// Gives shared access to the blob buffer for as long as the guard is borrowed.
    fn deref(&self) -> &Self::Target {
        let raw: *const AlignedBlobBuf = self.buf.get();
        // SAFETY(review): `raw` comes from a live `&UnsafeCell`, so it is valid and
        // aligned. Absence of a concurrent `&mut` relies on the shared latch guard
        // stored in `_latch` excluding writers — confirm against `HybridLatch`.
        unsafe { &*raw }
    }
}
/// Guard returned by `CachedBlob::write`; dereferences (also mutably) to the blob buffer.
pub struct BlobWriteGuard<'a> {
/// Latch guard created with `LatchGuard::exclusive`; never read, only kept alive alongside the guard.
_latch: LatchGuard<'a>,
/// Cell holding the blob buffer that `Deref`/`DerefMut` expose.
buf: &'a UnsafeCell<AlignedBlobBuf>,
}
impl Deref for BlobWriteGuard<'_> {
    type Target = AlignedBlobBuf;

    /// Gives shared access to the blob buffer for as long as the guard is borrowed.
    fn deref(&self) -> &Self::Target {
        let raw: *const AlignedBlobBuf = self.buf.get();
        // SAFETY(review): `raw` comes from a live `&UnsafeCell`, so it is valid and
        // aligned. Absence of other accessors relies on the exclusive latch guard
        // stored in `_latch` — confirm against `HybridLatch`.
        unsafe { &*raw }
    }
}
impl DerefMut for BlobWriteGuard<'_> {
    /// Gives mutable access to the blob buffer for as long as the guard is mutably borrowed.
    fn deref_mut(&mut self) -> &mut Self::Target {
        let raw: *mut AlignedBlobBuf = self.buf.get();
        // SAFETY(review): `raw` comes from a live `&UnsafeCell`, so it is valid and
        // aligned; `&mut self` rules out aliasing through this guard. Exclusion of
        // other threads relies on the exclusive latch guard stored in `_latch` —
        // confirm against `HybridLatch`.
        unsafe { &mut *raw }
    }
}
impl BufferManager {
    /// Creates a buffer manager caching blobs of `backend`.
    ///
    /// `capacity` is clamped to at least 1.
    #[must_use]
    pub fn new(backend: Arc<dyn Backend>, capacity: usize) -> Self {
        Self {
            backend,
            capacity: capacity.max(1),
            state: Mutex::new(BufferManagerState {
                cache: HashMap::new(),
                lru: VecDeque::new(),
            }),
        }
    }

    /// Returns the (clamped) capacity this manager was created with.
    #[must_use]
    pub fn capacity(&self) -> usize {
        self.capacity
    }

    /// Returns the number of blobs currently resident in the cache.
    ///
    /// # Panics
    /// Panics if the state mutex is poisoned.
    #[must_use]
    pub fn cached_count(&self) -> usize {
        self.state.lock().unwrap().cache.len()
    }

    /// Drops every cache entry and the whole recency queue.
    ///
    /// Entries that callers still hold through [`BufferManager::pin`] stay alive
    /// for those holders (they are `Arc`s) but are no longer reachable through
    /// the manager.
    ///
    /// # Panics
    /// Panics if the state mutex is poisoned.
    pub fn clear(&self) {
        let mut state = self.state.lock().unwrap();
        state.cache.clear();
        state.lru.clear();
    }

    /// Removes `guid` from the recency queue if it is present.
    fn forget_lru(state: &mut BufferManagerState, guid: BlobGuid) {
        if let Some(pos) = state.lru.iter().position(|g| *g == guid) {
            state.lru.remove(pos);
        }
    }

    /// Marks `guid` as most recently used (moves it to the back of the queue).
    fn touch_lru(state: &mut BufferManagerState, guid: BlobGuid) {
        Self::forget_lru(state, guid);
        state.lru.push_back(guid);
    }

    /// Evicts unreferenced entries until the cache fits `capacity`, or until
    /// nothing evictable is left.
    fn evict_to_capacity(&self, state: &mut BufferManagerState) {
        while state.cache.len() > self.capacity {
            if !Self::try_evict_lru(state) {
                break;
            }
        }
    }

    /// Looks up `guid` in the cache and, on a hit, marks it most recently used.
    fn get_cached(&self, guid: BlobGuid) -> Option<Arc<CachedBlob>> {
        let mut state = self.state.lock().unwrap();
        let entry = state.cache.get(&guid).cloned()?;
        Self::touch_lru(&mut state, guid);
        Some(entry)
    }

    /// Caches a copy of `contents` under `guid` unless an entry already exists.
    ///
    /// An existing entry is only refreshed in the recency queue; its buffer is
    /// left untouched. After a fresh insert the cache is shrunk back towards
    /// `capacity`; since the new entry is referenced only by the cache at that
    /// point, it can itself be the victim when every older entry is still held
    /// elsewhere.
    fn insert_into_cache(&self, guid: BlobGuid, contents: &AlignedBlobBuf) {
        let mut state = self.state.lock().unwrap();
        if state.cache.contains_key(&guid) {
            Self::touch_lru(&mut state, guid);
            return;
        }
        state
            .cache
            .insert(guid, Arc::new(CachedBlob::new(contents.clone())));
        state.lru.push_back(guid);
        self.evict_to_capacity(&mut state);
    }

    /// Evicts the least recently used entry that nobody outside the cache holds.
    ///
    /// An entry counts as unreferenced when its `Arc` strong count is at most 1
    /// (the cache's own reference). Returns `false` when no such entry exists.
    fn try_evict_lru(state: &mut BufferManagerState) -> bool {
        let victim = state.lru.iter().position(|guid| {
            matches!(state.cache.get(guid), Some(entry) if Arc::strong_count(entry) <= 1)
        });
        match victim {
            Some(idx) => {
                if let Some(guid) = state.lru.remove(idx) {
                    state.cache.remove(&guid);
                }
                true
            }
            None => false,
        }
    }

    /// Unconditionally removes `guid` from the cache and the recency queue.
    fn evict_from_cache(&self, guid: BlobGuid) {
        let mut state = self.state.lock().unwrap();
        state.cache.remove(&guid);
        Self::forget_lru(&mut state, guid);
    }

    /// Returns the cache entry for `guid`, loading it from the backend on a miss.
    ///
    /// While the caller keeps the returned `Arc`, the entry's strong count stays
    /// above 1 and `try_evict_lru` skips it. If nothing else is evictable the
    /// cache is allowed to exceed `capacity` rather than fail the pin.
    ///
    /// # Errors
    /// Propagates any error from the backend read.
    ///
    /// # Panics
    /// Panics if the state mutex is poisoned.
    pub fn pin(&self, guid: BlobGuid) -> Result<Arc<CachedBlob>> {
        if let Some(entry) = self.get_cached(guid) {
            return Ok(entry);
        }

        // Miss: perform the backend read without holding the state mutex.
        let mut scratch = AlignedBlobBuf::zeroed();
        self.backend.read_blob(guid, &mut scratch)?;

        let mut state = self.state.lock().unwrap();

        // Re-check under the lock: another thread may have cached this blob while
        // we were reading. Replacing its entry would leave that thread holding an
        // orphaned `CachedBlob` whose writes could never be committed.
        if let Some(existing) = state.cache.get(&guid).cloned() {
            Self::touch_lru(&mut state, guid);
            return Ok(existing);
        }

        // Move the freshly read buffer straight into the entry (no extra copy).
        let entry = Arc::new(CachedBlob::new(scratch));
        state.cache.insert(guid, Arc::clone(&entry));
        Self::touch_lru(&mut state, guid);
        // `entry` is still held locally, so its strong count is 2 and the
        // eviction pass cannot pick the blob we are about to hand out.
        self.evict_to_capacity(&mut state);
        Ok(entry)
    }

    /// Writes the cached contents of `guid` to the backend.
    ///
    /// The entry's read guard is held for the duration of the backend write.
    /// If `guid` is not cached this does nothing and returns `Ok(())`.
    ///
    /// # Errors
    /// Propagates any error from the backend write.
    ///
    /// # Panics
    /// Panics if the state mutex is poisoned.
    pub fn commit(&self, guid: BlobGuid) -> Result<()> {
        if let Some(entry) = self.get_cached(guid) {
            let buf = entry.read();
            self.backend.write_blob(guid, &buf)?;
        }
        Ok(())
    }
}
impl Backend for BufferManager {
    /// Copies blob `guid` into `dst`, from the cache on a hit or from the inner
    /// backend on a miss (in which case a copy is cached afterwards).
    ///
    /// # Errors
    /// Propagates any error from the inner backend read.
    fn read_blob(&self, guid: BlobGuid, dst: &mut AlignedBlobBuf) -> Result<()> {
        if let Some(entry) = self.get_cached(guid) {
            let buf = entry.read();
            dst.as_mut_slice().copy_from_slice(buf.as_slice());
            return Ok(());
        }
        self.backend.read_blob(guid, dst)?;
        self.insert_into_cache(guid, dst);
        Ok(())
    }

    /// Writes `src` through to the inner backend and refreshes the cached copy
    /// if the blob is resident. A blob that is not cached is not inserted.
    ///
    /// For a cached blob the write guard is taken first and kept across the
    /// backend write, and the cached bytes are only replaced once that write has
    /// succeeded. A failed write therefore leaves the cache unchanged, and the
    /// cache and backend updates of one call are not interleaved with another
    /// writer going through the same entry.
    ///
    /// # Errors
    /// Propagates any error from the inner backend write.
    fn write_blob(&self, guid: BlobGuid, src: &AlignedBlobBuf) -> Result<()> {
        match self.get_cached(guid) {
            Some(entry) => {
                let mut buf = entry.write();
                self.backend.write_blob(guid, src)?;
                buf.as_mut_slice().copy_from_slice(src.as_slice());
                Ok(())
            }
            None => self.backend.write_blob(guid, src),
        }
    }

    /// Drops `guid` from the cache, then deletes it from the inner backend.
    ///
    /// # Errors
    /// Propagates any error from the inner backend delete; the cache entry has
    /// already been removed by then.
    fn delete_blob(&self, guid: BlobGuid) -> Result<()> {
        self.evict_from_cache(guid);
        self.backend.delete_blob(guid)
    }

    /// Lists blobs by delegating to the inner backend; the cache is not consulted.
    fn list_blobs(&self) -> Result<Vec<BlobGuid>> {
        self.backend.list_blobs()
    }

    /// Flushes the inner backend.
    fn flush(&self) -> Result<()> {
        self.backend.flush()
    }

    /// Reports whether `guid` exists: `true` immediately if it is cached,
    /// otherwise whatever the inner backend answers.
    fn has_blob(&self, guid: BlobGuid) -> Result<bool> {
        // The temporary mutex guard is dropped at the end of this statement, so
        // the backend is never queried while the state lock is held.
        let cached = self.state.lock().unwrap().cache.contains_key(&guid);
        if cached {
            return Ok(true);
        }
        self.backend.has_blob(guid)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::store::backend::MemoryBackend;

    /// Builds a zeroed blob buffer whose byte at offset 100 is `byte_at_100`.
    fn make_buf(byte_at_100: u8) -> AlignedBlobBuf {
        let mut blob = AlignedBlobBuf::zeroed();
        blob.as_mut_slice()[100] = byte_at_100;
        blob
    }

    /// Builds a GUID that is all zeroes except for its first byte.
    fn guid_with_first_byte(first: u8) -> [u8; 16] {
        let mut guid = [0u8; 16];
        guid[0] = first;
        guid
    }

    /// Fresh in-memory backend behind the trait object the manager expects.
    fn memory_backend() -> Arc<dyn Backend> {
        Arc::new(MemoryBackend::new())
    }

    /// A first read populates the cache; a second read does not add another entry.
    #[test]
    fn read_caches_after_first_load() {
        let guid = [0xAB; 16];
        let inner = memory_backend();
        inner.write_blob(guid, &make_buf(7)).unwrap();

        let bm = BufferManager::new(Arc::clone(&inner), 4);
        assert_eq!(bm.cached_count(), 0);

        let mut dst = AlignedBlobBuf::zeroed();
        bm.read_blob(guid, &mut dst).unwrap();
        assert_eq!(dst.as_slice()[100], 7);
        assert_eq!(bm.cached_count(), 1);

        bm.read_blob(guid, &mut dst).unwrap();
        assert_eq!(bm.cached_count(), 1);
    }

    /// Reading more blobs than the capacity keeps only the most recent ones.
    #[test]
    fn lru_eviction_at_capacity() {
        let inner = memory_backend();
        for i in 0..10u8 {
            inner
                .write_blob(guid_with_first_byte(i), &make_buf(i))
                .unwrap();
        }

        let bm = BufferManager::new(inner, 4);
        for i in 0..10u8 {
            let mut dst = AlignedBlobBuf::zeroed();
            bm.read_blob(guid_with_first_byte(i), &mut dst).unwrap();
        }

        assert_eq!(
            bm.cached_count(),
            4,
            "cache must shrink to capacity after over-fill",
        );

        let newest = guid_with_first_byte(9);
        let oldest = guid_with_first_byte(0);
        let state = bm.state.lock().unwrap();
        assert!(state.cache.contains_key(&newest));
        assert!(!state.cache.contains_key(&oldest));
    }

    /// A write through the manager is visible in the inner backend.
    #[test]
    fn write_through_propagates_to_inner_backend() {
        let guid = [0xCD; 16];
        let inner = memory_backend();
        let bm = BufferManager::new(Arc::clone(&inner), 4);

        bm.write_blob(guid, &make_buf(0x42)).unwrap();
        assert!(inner.has_blob(guid).unwrap());

        let mut dst = AlignedBlobBuf::zeroed();
        inner.read_blob(guid, &mut dst).unwrap();
        assert_eq!(dst.as_slice()[100], 0x42);
    }

    /// Writing a blob that is already cached refreshes the cached copy.
    #[test]
    fn write_through_updates_cache_if_present() {
        let guid = [0xEF; 16];
        let inner = memory_backend();
        inner.write_blob(guid, &make_buf(1)).unwrap();
        let bm = BufferManager::new(Arc::clone(&inner), 4);

        let mut dst = AlignedBlobBuf::zeroed();
        bm.read_blob(guid, &mut dst).unwrap();
        assert_eq!(dst.as_slice()[100], 1);

        bm.write_blob(guid, &make_buf(99)).unwrap();
        bm.read_blob(guid, &mut dst).unwrap();
        assert_eq!(dst.as_slice()[100], 99);
    }

    /// Deleting removes the blob from both the cache and the inner backend.
    #[test]
    fn delete_evicts_from_cache_and_inner() {
        let guid = [0x33; 16];
        let inner = memory_backend();
        inner.write_blob(guid, &make_buf(5)).unwrap();
        let bm = BufferManager::new(Arc::clone(&inner), 4);

        let mut dst = AlignedBlobBuf::zeroed();
        bm.read_blob(guid, &mut dst).unwrap();
        assert_eq!(bm.cached_count(), 1);

        bm.delete_blob(guid).unwrap();
        assert_eq!(bm.cached_count(), 0);
        assert!(!inner.has_blob(guid).unwrap());
        assert!(!bm.has_blob(guid).unwrap());
    }

    /// `has_blob` answers `true` for a cached blob and `false` for an unknown one.
    #[test]
    fn has_blob_fast_path_avoids_inner_when_cached() {
        let cached = [0x77; 16];
        let unknown = [0x88; 16];
        let inner = memory_backend();
        inner.write_blob(cached, &make_buf(11)).unwrap();
        let bm = BufferManager::new(Arc::clone(&inner), 4);

        let mut dst = AlignedBlobBuf::zeroed();
        bm.read_blob(cached, &mut dst).unwrap();

        assert!(bm.has_blob(cached).unwrap());
        assert!(!bm.has_blob(unknown).unwrap());
    }

    /// Eight threads each read their own blob repeatedly; all complete and
    /// exactly eight blobs end up cached.
    #[test]
    fn concurrent_reads_on_different_blobs_progress() {
        use std::thread;

        let inner = memory_backend();
        for i in 0..16u8 {
            inner
                .write_blob(guid_with_first_byte(i), &make_buf(i))
                .unwrap();
        }
        let bm = Arc::new(BufferManager::new(inner, 16));

        let workers: Vec<_> = (0..8u8)
            .map(|t| {
                let bm = Arc::clone(&bm);
                thread::spawn(move || {
                    // Each thread targets a distinct even-numbered blob.
                    let tag = t * 2;
                    let guid = guid_with_first_byte(tag);
                    for _ in 0..50 {
                        let mut dst = AlignedBlobBuf::zeroed();
                        bm.read_blob(guid, &mut dst).unwrap();
                        assert_eq!(dst.as_slice()[100], tag);
                    }
                })
            })
            .collect();

        for worker in workers {
            worker.join().unwrap();
        }
        assert_eq!(bm.cached_count(), 8);
    }
}