use crate::{Blob, Error, RwLock};
use commonware_utils::StableBuf;
use futures::{future::Shared, FutureExt};
use std::{
collections::{hash_map::Entry, HashMap},
future::Future,
num::NonZeroUsize,
pin::Pin,
sync::{
atomic::{AtomicBool, AtomicU64, Ordering},
Arc,
},
};
use tracing::{debug, trace};
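/// A shareable future resolving to the contents of a fetched page, allowing concurrent
/// readers that fault on the same page to await a single underlying blob read.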
type PageFetchFut = Shared<Pin<Box<dyn Future<Output = Result<StableBuf, Arc<Error>>> + Send>>>;
/// A page cache shared across blobs, using clock (second-chance) eviction.
pub struct Pool {
    /// Maps (blob id, page number) to the index of that page in `cache`.
    index: HashMap<(u64, u64), usize>,
    /// The cached pages, holding at most `capacity` entries.
    cache: Vec<CacheEntry>,
    /// The clock hand: the index in `cache` of the next eviction candidate.
    clock: usize,
    /// The maximum number of pages the pool will hold.
    capacity: usize,
    /// In-flight page fetches, keyed by (blob id, page number), so concurrent readers
    /// faulting on the same page share one blob read.
    page_fetches: HashMap<(u64, u64), PageFetchFut>,
}
/// A single cached page along with its clock-eviction metadata.
struct CacheEntry {
    /// The (blob id, page number) cached by this entry.
    key: (u64, u64),
    /// Set on access and cleared by the clock sweep, giving the page a "second chance".
    referenced: AtomicBool,
    /// The page contents.
    data: Vec<u8>,
}
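/// A cheaply cloneable handle to a shared [Pool], paired with the page size used to
/// split blobs into pages.
///
/// Illustrative sketch of in-crate usage (assumes a `blob` implementing [Blob] is in
/// scope; [PoolRef::read] is `pub(super)`, so this only compiles from the parent module):
///
/// ```ignore
/// // `blob` is assumed to be an already-open [Blob] with at least one page of data.
/// let pool_ref = PoolRef::new(NZUsize!(1024), NZUsize!(10));
/// let blob_id = pool_ref.next_id().await;
/// let mut buf = vec![0u8; 1024];
/// pool_ref.read(&blob, blob_id, &mut buf, 0).await?;
/// ```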
#[derive(Clone)]
pub struct PoolRef {
    /// The size of each page in the pool.
    pub(super) page_size: usize,
    /// The id to assign to the next blob that uses this pool.
    next_id: Arc<AtomicU64>,
    /// The shared buffer pool.
    pool: Arc<RwLock<Pool>>,
}
impl PoolRef {
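    /// Creates a new [PoolRef] backed by a pool holding up to `capacity` pages of
    /// `page_size` bytes each.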
pub fn new(page_size: NonZeroUsize, capacity: NonZeroUsize) -> Self {
Self {
page_size: page_size.get(),
next_id: Arc::new(AtomicU64::new(0)),
pool: Arc::new(RwLock::new(Pool::new(capacity.get()))),
}
}
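    /// Returns a unique id for the next blob to be tracked by this pool.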
pub async fn next_id(&self) -> u64 {
self.next_id.fetch_add(1, Ordering::Relaxed)
}
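    /// Converts a byte offset into the (page number, offset within that page) containing it.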
pub fn offset_to_page(&self, offset: u64) -> (u64, usize) {
Pool::offset_to_page(self.page_size, offset)
}
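    /// Reads exactly `buf.len()` bytes from `blob` starting at `offset`, serving as much
    /// as possible from the buffer pool and faulting the rest in from the blob.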
pub(super) async fn read<B: Blob>(
&self,
blob: &B,
blob_id: u64,
mut buf: &mut [u8],
mut offset: u64,
) -> Result<(), Error> {
        while !buf.is_empty() {
            // Fast path: serve as much as possible from the buffer pool under a read lock.
            {
                let buffer_pool = self.pool.read().await;
                let count = buffer_pool.read_at(self.page_size, blob_id, buf, offset);
                if count != 0 {
                    offset += count as u64;
                    buf = &mut buf[count..];
                    continue;
                }
            }
            // Slow path: the page containing `offset` isn't cached; fault it in from the blob.
let count = self
.read_after_page_fault(blob, blob_id, buf, offset)
.await?;
offset += count as u64;
buf = &mut buf[count..];
}
Ok(())
}
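    /// Handles a page fault: fetches the page containing `offset` from `blob`, caches it,
    /// and copies as much of it as fits into `buf`, returning the number of bytes copied.
    /// If another task is already fetching the same page, awaits that fetch instead of
    /// issuing a duplicate blob read.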
async fn read_after_page_fault<B: Blob>(
&self,
blob: &B,
blob_id: u64,
buf: &mut [u8],
offset: u64,
) -> Result<usize, Error> {
assert!(!buf.is_empty());
let (page_num, offset_in_page) = Pool::offset_to_page(self.page_size, offset);
let page_size = self.page_size;
trace!(page_num, blob_id, "page fault");
        let (fetch_future, is_first_fetcher) = {
            let mut pool = self.pool.write().await;
            // Re-check the cache while holding the write lock: another task may have
            // cached this page while we were waiting to acquire it.
            let count = pool.read_at(page_size, blob_id, buf, offset);
            if count != 0 {
                return Ok(count);
            }
            match pool.page_fetches.entry((blob_id, page_num)) {
                // Another task is already fetching this page; share its future.
                Entry::Occupied(o) => (o.get().clone(), false),
                // We are the first to fault on this page; register a shared fetch so any
                // concurrent readers can await the same underlying blob read.
                Entry::Vacant(v) => {
                    let blob = blob.clone();
                    let future = async move {
                        blob.read_at(vec![0; page_size], page_num * page_size as u64)
                            .await
                            .map_err(Arc::new)
                    };
                    let shareable = future.boxed().shared();
                    v.insert(shareable.clone());
                    (shareable, true)
                }
            }
        };
        let fetch_result = fetch_future.await;
        if !is_first_fetcher {
            // Another task owns the fetch (it will cache the page and clean up the
            // `page_fetches` entry); just copy the bytes we need from the shared result.
            let page_buf: Vec<u8> = fetch_result.map_err(|_| Error::ReadFailed)?.into();
            let bytes_to_copy = std::cmp::min(buf.len(), page_size - offset_in_page);
            buf[..bytes_to_copy]
                .copy_from_slice(&page_buf[offset_in_page..offset_in_page + bytes_to_copy]);
            return Ok(bytes_to_copy);
        }
        // We initiated the fetch, so we are responsible for removing the in-flight entry
        // and, on success, inserting the page into the cache.
        let mut pool = self.pool.write().await;
        let _ = pool.page_fetches.remove(&(blob_id, page_num));
        let Ok(page_buf) = fetch_result else {
            return Err(Error::ReadFailed);
        };
        pool.cache(page_size, blob_id, page_buf.as_ref(), page_num);
        // Copy the requested portion of the newly fetched page into the caller's buffer.
        let page_buf: Vec<u8> = page_buf.into();
        let bytes_to_copy = std::cmp::min(buf.len(), page_size - offset_in_page);
        buf[..bytes_to_copy]
            .copy_from_slice(&page_buf[offset_in_page..offset_in_page + bytes_to_copy]);
        Ok(bytes_to_copy)
}
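    /// Caches whole pages from `buf` in the buffer pool, starting at the page-aligned
    /// `offset`. Returns the number of trailing bytes that were not cached because they
    /// do not fill a complete page.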
pub async fn cache(&self, blob_id: u64, mut buf: &[u8], offset: u64) -> usize {
let (mut page_num, offset_in_page) = self.offset_to_page(offset);
assert_eq!(offset_in_page, 0);
{
let mut buffer_pool = self.pool.write().await;
while buf.len() >= self.page_size {
buffer_pool.cache(self.page_size, blob_id, &buf[..self.page_size], page_num);
buf = &buf[self.page_size..];
page_num += 1;
}
}
buf.len()
}
}
impl Pool {
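    /// Creates a new buffer pool that holds up to `capacity` pages.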
pub fn new(capacity: usize) -> Self {
assert!(capacity > 0);
Self {
index: HashMap::new(),
cache: Vec::new(),
clock: 0,
capacity,
page_fetches: HashMap::new(),
}
}
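    /// Converts a byte offset into the (page number, offset within that page) containing it.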
fn offset_to_page(page_size: usize, offset: u64) -> (u64, usize) {
(
offset / page_size as u64,
(offset % page_size as u64) as usize,
)
}
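    /// Attempts to serve a read at `offset` from the page cache. Returns the number of
    /// bytes copied into `buf`: 0 on a cache miss, and otherwise at most the number of
    /// bytes remaining in the page containing `offset`.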
fn read_at(&self, page_size: usize, blob_id: u64, buf: &mut [u8], offset: u64) -> usize {
let (page_num, offset_in_page) = Self::offset_to_page(page_size, offset);
let page_index = self.index.get(&(blob_id, page_num));
let Some(&page_index) = page_index else {
return 0;
};
        let page = &self.cache[page_index];
        assert_eq!(page.key, (blob_id, page_num));
        // Mark the page as recently used so the clock sweep gives it a second chance.
        page.referenced.store(true, Ordering::Relaxed);
        let page = &page.data;
let bytes_to_copy = std::cmp::min(buf.len(), page_size - offset_in_page);
buf[..bytes_to_copy].copy_from_slice(&page[offset_in_page..offset_in_page + bytes_to_copy]);
bytes_to_copy
}
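    /// Inserts `page` (exactly one page long) into the cache for (`blob_id`, `page_num`),
    /// updating the entry in place if the page is already cached, and otherwise evicting
    /// a page via the clock algorithm once the pool is at capacity.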
fn cache(&mut self, page_size: usize, blob_id: u64, page: &[u8], page_num: u64) {
assert_eq!(page.len(), page_size);
let key = (blob_id, page_num);
let index_entry = self.index.entry(key);
if let Entry::Occupied(index_entry) = index_entry {
debug!(blob_id, page_num, "updating duplicate page");
let entry = &mut self.cache[*index_entry.get()];
assert_eq!(entry.key, key);
entry.referenced.store(true, Ordering::Relaxed);
entry.data.copy_from_slice(page);
return;
}
        if self.cache.len() < self.capacity {
            // The pool isn't full yet; append a new entry without evicting anything.
            self.index.insert(key, self.cache.len());
            self.cache.push(CacheEntry {
                key,
                referenced: AtomicBool::new(true),
                data: page.into(),
            });
            return;
        }
        // The pool is full: run the clock (second-chance) sweep. Advance the hand past
        // recently used pages, clearing their referenced bits, until a page that hasn't
        // been used since the last sweep is found.
        while self.cache[self.clock].referenced.load(Ordering::Relaxed) {
            self.cache[self.clock]
                .referenced
                .store(false, Ordering::Relaxed);
            self.clock = (self.clock + 1) % self.cache.len();
        }
        // Evict the page under the hand and reuse its slot for the new page.
        let entry = &mut self.cache[self.clock];
        entry.referenced.store(true, Ordering::Relaxed);
        assert!(self.index.remove(&entry.key).is_some());
        self.index.insert(key, self.clock);
        entry.key = key;
        entry.data.copy_from_slice(page);
        self.clock = (self.clock + 1) % self.cache.len();
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{deterministic, Runner as _, Storage as _};
use commonware_macros::test_traced;
use commonware_utils::NZUsize;
const PAGE_SIZE: usize = 1024;
#[test_traced]
fn test_pool_basic() {
let mut pool: Pool = Pool::new(10);
let mut buf = vec![0; PAGE_SIZE];
let bytes_read = pool.read_at(PAGE_SIZE, 0, &mut buf, 0);
assert_eq!(bytes_read, 0);
pool.cache(PAGE_SIZE, 0, &[1; PAGE_SIZE], 0);
let bytes_read = pool.read_at(PAGE_SIZE, 0, &mut buf, 0);
assert_eq!(bytes_read, PAGE_SIZE);
assert_eq!(buf, [1; PAGE_SIZE]);
pool.cache(PAGE_SIZE, 0, &[2; PAGE_SIZE], 0);
let bytes_read = pool.read_at(PAGE_SIZE, 0, &mut buf, 0);
assert_eq!(bytes_read, PAGE_SIZE);
assert_eq!(buf, [2; PAGE_SIZE]);
        // Cache 11 pages into a pool with capacity 10; the clock sweep evicts page 0.
        for i in 0u64..11 {
            pool.cache(PAGE_SIZE, 0, &[i as u8; PAGE_SIZE], i);
        }
        let bytes_read = pool.read_at(PAGE_SIZE, 0, &mut buf, 0);
        assert_eq!(bytes_read, 0);
for i in 1u64..11 {
let bytes_read = pool.read_at(PAGE_SIZE, 0, &mut buf, i * PAGE_SIZE as u64);
assert_eq!(bytes_read, PAGE_SIZE);
assert_eq!(buf, [i as u8; PAGE_SIZE]);
}
        // A read that starts mid-page returns only the bytes remaining in that page.
        let mut buf = vec![0; PAGE_SIZE];
        let bytes_read = pool.read_at(PAGE_SIZE, 0, &mut buf, PAGE_SIZE as u64 + 2);
        assert_eq!(bytes_read, PAGE_SIZE - 2);
        assert_eq!(&buf[..PAGE_SIZE - 2], [1; PAGE_SIZE - 2]);
}
#[test_traced]
fn test_pool_read_with_blob() {
let executor = deterministic::Runner::default();
executor.start(|context| async move {
let (blob, size) = context
.open("test", "blob".as_bytes())
.await
.expect("Failed to open blob");
assert_eq!(size, 0);
            // Write 11 pages of distinct data to the blob.
            for i in 0..11 {
let buf = vec![i as u8; PAGE_SIZE];
blob.write_at(buf, i * PAGE_SIZE as u64).await.unwrap();
}
let pool_ref = PoolRef::new(NZUsize!(PAGE_SIZE), NZUsize!(10));
assert_eq!(pool_ref.next_id().await, 0);
assert_eq!(pool_ref.next_id().await, 1);
            // First pass: each page read faults, is fetched from the blob, and is cached.
            for i in 0..11 {
let mut buf = vec![0; PAGE_SIZE];
pool_ref
.read(&blob, 0, &mut buf, i * PAGE_SIZE as u64)
.await
.unwrap();
assert_eq!(buf, [i as u8; PAGE_SIZE]);
}
            // Second pass: these reads should now be served from the buffer pool.
            for i in 1..11 {
let mut buf = vec![0; PAGE_SIZE];
pool_ref
.read(&blob, 0, &mut buf, i * PAGE_SIZE as u64)
.await
.unwrap();
assert_eq!(buf, [i as u8; PAGE_SIZE]);
}
blob.sync().await.unwrap();
});
}
}