use crate::{buffer::tip::Buffer, Blob, Buf, BufferPool, BufferPooler, Error, IoBufs};
use commonware_utils::sync::AsyncRwLock;
use std::{num::NonZeroUsize, sync::Arc};
/// A writable wrapper around a [`Blob`] that buffers writes at the tip (end)
/// of the blob, deferring flushes to the underlying blob until required.
///
/// Cloning is cheap: clones share the same underlying blob handle and the
/// same lock-protected tip buffer.
#[derive(Clone)]
pub struct Write<B: Blob> {
// The underlying blob that buffered data is eventually flushed to.
blob: B,
// Shared, async-RwLock-protected tip buffer holding not-yet-flushed writes.
buffer: Arc<AsyncRwLock<Buffer>>,
}
impl<B: Blob> Write<B> {
/// Creates a buffered writer over `blob`.
///
/// - `size`: current logical size of the blob in bytes.
/// - `capacity`: tip-buffer capacity in bytes (non-zero by construction).
/// - `pool`: buffer pool backing the tip buffer's allocations.
pub fn new(blob: B, size: u64, capacity: NonZeroUsize, pool: BufferPool) -> Self {
Self {
blob,
buffer: Arc::new(AsyncRwLock::new(Buffer::new(size, capacity.get(), pool))),
}
}
/// Convenience constructor that clones the [`BufferPool`] out of a
/// [`BufferPooler`] and delegates to [`Write::new`].
pub fn from_pooler(
pooler: &impl BufferPooler,
blob: B,
size: u64,
capacity: NonZeroUsize,
) -> Self {
Self::new(blob, size, capacity, pooler.storage_buffer_pool().clone())
}
/// Returns the logical size of the blob, including any bytes still held
/// in the tip buffer and not yet flushed to the underlying blob.
#[allow(clippy::len_without_is_empty)]
pub async fn size(&self) -> u64 {
let buffer = self.buffer.read().await;
buffer.size()
}
/// Reads `len` bytes starting at `offset`, transparently stitching the
/// result together from the underlying blob and the unflushed tip buffer.
///
/// # Errors
///
/// - [`Error::OffsetOverflow`] if `offset + len` overflows `u64`.
/// - [`Error::BlobInsufficientLength`] if the range extends past the
///   current logical size.
/// - Any error returned by the underlying blob read.
pub async fn read_at(&self, offset: u64, len: usize) -> Result<IoBufs, Error> {
let end_offset = offset
.checked_add(len as u64)
.ok_or(Error::OffsetOverflow)?;
let buffer = self.buffer.read().await;
if end_offset > buffer.size() {
return Err(Error::BlobInsufficientLength);
}
// Zero-length reads succeed trivially (after the bounds check above).
if len == 0 {
return Ok(IoBufs::default());
}
// `buffer.offset` is the blob offset at which the buffered tip begins:
// bytes below it live in the underlying blob, bytes at or above it live
// in the tip buffer (this is what the three cases below rely on).
// Case 1: range lies entirely within the tip buffer.
if offset >= buffer.offset {
let start = (offset - buffer.offset) as usize;
let end = start + len;
return Ok(buffer.slice(start..end).into());
}
// Case 2: range lies entirely in the underlying blob.
if end_offset <= buffer.offset {
return Ok(self.blob.read_at(offset, len).await?.freeze());
}
// Case 3: range straddles the boundary — read the leading part from the
// blob and append the trailing part from the start of the tip buffer.
let blob_len = (buffer.offset - offset) as usize;
let tip_len = len - blob_len;
let tip = buffer.slice(..tip_len);
let mut blob = self.blob.read_at(offset, blob_len).await?.freeze();
blob.append(tip);
Ok(blob)
}
/// Writes `bufs` at `offset`, absorbing data into the tip buffer when
/// possible and falling back to direct writes on the underlying blob.
///
/// # Errors
///
/// - [`Error::OffsetOverflow`] if `offset + bufs.remaining()` overflows.
/// - Any error from flushing to or writing the underlying blob.
pub async fn write_at(&self, offset: u64, bufs: impl Into<IoBufs> + Send) -> Result<(), Error> {
let mut bufs = bufs.into();
// Up-front overflow check on the full write range; only the check
// matters, the sum itself is intentionally discarded.
offset
.checked_add(bufs.remaining() as u64)
.ok_or(Error::OffsetOverflow)?;
let mut buffer = self.buffer.write().await;
let mut current_offset = offset;
while bufs.has_remaining() {
let chunk = bufs.chunk();
let chunk_len = chunk.len();
// Fast path: the chunk merges into the tip buffer as-is.
if buffer.merge(chunk, current_offset) {
bufs.advance(chunk_len);
current_offset += chunk_len as u64;
continue;
}
let chunk_end = current_offset + chunk_len as u64;
// The chunk reaches into (or past) the buffered region but could not
// be merged: flush the current tip contents to the blob, then retry
// the merge against the now-empty buffer.
if buffer.offset < chunk_end {
if let Some((old_buf, old_offset)) = buffer.take() {
self.blob.write_at(old_offset, old_buf).await?;
if buffer.merge(chunk, current_offset) {
bufs.advance(chunk_len);
current_offset += chunk_len as u64;
continue;
}
}
}
// Fallback: write the chunk straight to the underlying blob — it
// either lies wholly below the buffered region or could not be
// merged even after a flush (presumably larger than capacity;
// `Buffer::merge` semantics are not visible here — verify).
let direct = bufs.split_to(chunk_len);
self.blob.write_at(current_offset, direct).await?;
current_offset += chunk_len as u64;
// Advance the tip's start past directly-written data so subsequent
// reads of this range are served by the blob rather than the buffer.
buffer.offset = buffer.offset.max(current_offset);
}
Ok(())
}
/// Resizes the blob to `len` bytes.
///
/// Any buffered data that `Buffer::resize` reports as needing to survive
/// the resize is flushed first, then the underlying blob is resized.
pub async fn resize(&self, len: u64) -> Result<(), Error> {
let mut buffer = self.buffer.write().await;
if let Some((buf, offset)) = buffer.resize(len) {
self.blob.write_at(offset, buf).await?;
}
self.blob.resize(len).await?;
Ok(())
}
/// Flushes any buffered tip data to the underlying blob, then syncs the
/// blob so previously written data is durable.
pub async fn sync(&self) -> Result<(), Error> {
let mut buffer = self.buffer.write().await;
if let Some((buf, offset)) = buffer.take() {
self.blob.write_at(offset, buf).await?;
}
self.blob.sync().await
}
}