use crate::{buffer::tip::Buffer, Blob, Buf, BufferPool, BufferPooler, Error, IoBufs};
use commonware_utils::sync::AsyncRwLock;
use std::{num::NonZeroUsize, sync::Arc};
/// Mutable state shared by all clones of a [`Write`]: the wrapped blob, its
/// in-memory tip buffer, and whether the blob has changes not yet synced.
struct State<B: Blob> {
    /// The wrapped blob; flushed tip contents and direct writes land here.
    blob: B,
    /// In-memory tip buffer that absorbs writes it is able to merge.
    buffer: Buffer,
    /// Set after every blob write or resize issued through `State`; cleared only
    /// after a successful blob sync (or combined write-and-sync).
    needs_sync: bool,
}
impl<B: Blob> State<B> {
    /// Reads `len` bytes starting at `offset` straight from the blob (the tip
    /// buffer is not consulted) and returns them as immutable buffers.
    async fn read_at(&self, offset: u64, len: usize) -> Result<IoBufs, Error> {
        let raw = self.blob.read_at(offset, len).await?;
        Ok(raw.freeze())
    }

    /// Writes `bufs` to the blob at `offset`, then flags the state as needing a sync.
    async fn write_at(&mut self, offset: u64, bufs: impl Into<IoBufs> + Send) -> Result<(), Error> {
        self.blob.write_at(offset, bufs).await?;
        self.needs_sync = true;
        Ok(())
    }

    /// Writes `bufs` at `offset` and leaves the blob synced on success.
    ///
    /// With nothing pending, this delegates to the blob's combined `write_at_sync`.
    /// If other changes are already pending, the write is issued normally and then
    /// followed by a full [`Self::sync`].
    async fn write_at_sync(
        &mut self,
        offset: u64,
        bufs: impl Into<IoBufs> + Send,
    ) -> Result<(), Error> {
        if !self.needs_sync {
            // Mark dirty before the call: if the combined write+sync fails, the flag
            // stays set and a later `sync` will still reach the blob.
            self.needs_sync = true;
            self.blob.write_at_sync(offset, bufs).await?;
            self.needs_sync = false;
            return Ok(());
        }
        self.write_at(offset, bufs).await?;
        self.sync().await
    }

    /// Resizes the blob to `len` bytes, then flags the state as needing a sync.
    async fn resize(&mut self, len: u64) -> Result<(), Error> {
        self.blob.resize(len).await?;
        self.needs_sync = true;
        Ok(())
    }

    /// Syncs the blob when `needs_sync` is set and clears the flag on success.
    /// Does nothing when the flag is clear.
    async fn sync(&mut self) -> Result<(), Error> {
        if self.needs_sync {
            self.blob.sync().await?;
            self.needs_sync = false;
        }
        Ok(())
    }
}
/// A [`Blob`] wrapper with an in-memory tip [`Buffer`] in front of the blob.
///
/// Writes the tip buffer can merge are kept in memory. Other writes go straight
/// to the blob. Reads are served from the tip buffer, the blob, or a combination
/// of both.
///
/// Cloning is cheap: all clones share the same state behind an [`Arc`].
pub struct Write<B: Blob> {
    state: Arc<AsyncRwLock<State<B>>>,
}

// Implemented by hand rather than derived: `#[derive(Clone)]` would require
// `B: Clone`, but cloning only bumps the `Arc` refcount and never clones `B`.
impl<B: Blob> Clone for Write<B> {
    fn clone(&self) -> Self {
        Self {
            state: Arc::clone(&self.state),
        }
    }
}
impl<B: Blob> Write<B> {
pub fn new(blob: B, size: u64, capacity: NonZeroUsize, pool: BufferPool) -> Self {
Self {
state: Arc::new(AsyncRwLock::new(State {
blob,
buffer: Buffer::new(size, capacity.get(), pool),
needs_sync: true, })),
}
}
pub fn from_pooler(
pooler: &impl BufferPooler,
blob: B,
size: u64,
capacity: NonZeroUsize,
) -> Self {
Self::new(blob, size, capacity, pooler.storage_buffer_pool().clone())
}
pub async fn size(&self) -> u64 {
let state = self.state.read().await;
state.buffer.size()
}
pub async fn read_at(&self, offset: u64, len: usize) -> Result<IoBufs, Error> {
let end_offset = offset
.checked_add(len as u64)
.ok_or(Error::OffsetOverflow)?;
let state = self.state.read().await;
let buffer = &state.buffer;
if end_offset > buffer.size() {
return Err(Error::BlobInsufficientLength);
}
if len == 0 {
return Ok(IoBufs::default());
}
if offset >= buffer.offset {
let start = (offset - buffer.offset) as usize;
let end = start + len;
return Ok(buffer.slice(start..end).into());
}
if end_offset <= buffer.offset {
return state.read_at(offset, len).await;
}
let blob_len = (buffer.offset - offset) as usize;
let tip_len = len - blob_len;
let tip = buffer.slice(..tip_len);
let mut blob = state.read_at(offset, blob_len).await?;
blob.append(tip);
Ok(blob)
}
pub async fn write_at(&self, offset: u64, bufs: impl Into<IoBufs> + Send) -> Result<(), Error> {
let mut bufs = bufs.into();
offset
.checked_add(bufs.remaining() as u64)
.ok_or(Error::OffsetOverflow)?;
let mut state = self.state.write().await;
let mut current_offset = offset;
while bufs.has_remaining() {
let chunk = bufs.chunk();
let chunk_len = chunk.len();
if state.buffer.merge(chunk, current_offset) {
bufs.advance(chunk_len);
current_offset += chunk_len as u64;
continue;
}
let chunk_end = current_offset + chunk_len as u64;
if state.buffer.offset < chunk_end {
if let Some((old_buf, old_offset)) = state.buffer.take() {
state.write_at(old_offset, old_buf).await?;
if state.buffer.merge(chunk, current_offset) {
bufs.advance(chunk_len);
current_offset += chunk_len as u64;
continue;
}
}
}
let direct = bufs.split_to(chunk_len);
state.write_at(current_offset, direct).await?;
current_offset += chunk_len as u64;
state.buffer.offset = state.buffer.offset.max(current_offset);
}
Ok(())
}
pub async fn resize(&self, len: u64) -> Result<(), Error> {
let mut state = self.state.write().await;
if let Some((buf, offset)) = state.buffer.resize(len) {
state.write_at(offset, buf).await?;
}
state.resize(len).await?;
Ok(())
}
pub async fn sync(&self) -> Result<(), Error> {
let mut state = self.state.write().await;
if let Some((buf, offset)) = state.buffer.take() {
return state.write_at_sync(offset, buf).await;
}
state.sync().await
}
}