use crate::traits::BlockStore;
use bytes::{Bytes, BytesMut};
use futures::stream::Stream;
use ipfrs_core::{Block, Cid, Error, Result};
use std::io::SeekFrom;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncSeek, ReadBuf};
/// Tuning parameters for streaming reads and writes.
#[derive(Debug, Clone)]
pub struct StreamConfig {
    /// Chunk size in bytes. `StreamingWriter` uses it both as the initial
    /// buffer capacity and as the size of each block it stores.
    pub buffer_size: usize,
    /// Whether upcoming blocks should be prefetched.
    /// NOTE(review): not read anywhere in this module; confirm where it is honored.
    pub prefetch: bool,
    /// Maximum number of prefetched blocks to queue.
    /// NOTE(review): not read anywhere in this module; confirm where it is honored.
    pub prefetch_queue_size: usize,
}

impl Default for StreamConfig {
    /// 64 KiB chunks, prefetching enabled, prefetch queue depth of 4.
    fn default() -> Self {
        const KIB: usize = 1024;
        StreamConfig {
            prefetch: true,
            prefetch_queue_size: 4,
            buffer_size: 64 * KIB,
        }
    }
}
/// Half-open byte range `[start, end)` within a block.
///
/// `end == None` means "through the end of the block". Ranges are not
/// validated on construction; [`ByteRange::length`] clamps them against the
/// actual block size.
#[derive(Debug, Clone, Copy)]
pub struct ByteRange {
    /// First byte offset included in the range.
    pub start: u64,
    /// Exclusive end offset, or `None` for "to the end of the block".
    pub end: Option<u64>,
}

impl ByteRange {
    /// Range from `start` through the end of the block.
    #[inline]
    pub fn from(start: u64) -> Self {
        Self { start, end: None }
    }

    /// Range `[start, end)`. An inverted range (`start > end`) is accepted and
    /// has length 0.
    #[inline]
    pub fn new(start: u64, end: u64) -> Self {
        Self {
            start,
            end: Some(end),
        }
    }

    /// Range of `length` bytes beginning at `start`.
    ///
    /// The end offset saturates at `u64::MAX`. Plain `start + length` would
    /// panic in debug builds and wrap around in release builds.
    #[inline]
    pub fn with_length(start: u64, length: u64) -> Self {
        Self {
            start,
            end: Some(start.saturating_add(length)),
        }
    }

    /// Number of bytes this range covers within a block of `total_size` bytes.
    ///
    /// The end is clamped to `total_size`. The result is 0 when `start` lies at
    /// or beyond the clamped end.
    #[inline]
    pub fn length(&self, total_size: u64) -> u64 {
        let end = self.end.unwrap_or(total_size).min(total_size);
        end.saturating_sub(self.start)
    }
}
/// In-memory reader over the bytes of a single block.
pub struct BlockReader {
    data: Bytes,
    // Read cursor in bytes. Seeking does not bound it by the data length, so it
    // may exceed `data.len()`.
    position: u64,
}

impl BlockReader {
    /// Creates a reader positioned at the start of `block`'s data.
    #[inline]
    pub fn new(block: &Block) -> Self {
        Self::from_bytes(block.data().clone())
    }

    /// Creates a reader positioned at the start of `data`.
    #[inline]
    pub fn from_bytes(data: Bytes) -> Self {
        Self { data, position: 0 }
    }

    /// Bytes left between the cursor and the end of the data.
    ///
    /// Returns 0 when the cursor has been moved past the end. Plain subtraction
    /// would underflow there: a panic in debug builds, a huge bogus value in release.
    #[inline]
    pub fn remaining(&self) -> u64 {
        self.size().saturating_sub(self.position)
    }

    /// Total number of bytes in the underlying data.
    pub fn size(&self) -> u64 {
        self.data.len() as u64
    }
}
impl AsyncRead for BlockReader {
    /// Copies as many unread bytes as fit in `buf` and advances the cursor.
    ///
    /// Always returns `Poll::Ready`. When the cursor is at or past the end,
    /// nothing is written to `buf`, which callers observe as a 0-byte read (EOF).
    fn poll_read(
        mut self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        // Compare in u64 *before* narrowing. On 32-bit targets `position as usize`
        // would truncate a far past-the-end cursor and re-read earlier bytes.
        if self.position >= self.data.len() as u64 {
            return Poll::Ready(Ok(()));
        }
        // Lossless: position < data.len() <= usize::MAX.
        let start = self.position as usize;
        let unread = &self.data[start..];
        let n = unread.len().min(buf.remaining());
        buf.put_slice(&unread[..n]);
        self.position += n as u64;
        Poll::Ready(Ok(()))
    }
}
impl AsyncSeek for BlockReader {
    /// Moves the cursor. The seek completes synchronously; `poll_complete`
    /// just reports the new position.
    ///
    /// Seeking past the end is allowed; later reads then return 0 bytes.
    ///
    /// # Errors
    /// `InvalidInput` if the target would be negative or would overflow `u64`.
    fn start_seek(mut self: Pin<&mut Self>, position: SeekFrom) -> std::io::Result<()> {
        let (base, offset) = match position {
            // Absolute positions are taken as-is. Casting to i64 first would
            // misreport values above i64::MAX as negative.
            SeekFrom::Start(pos) => {
                self.position = pos;
                return Ok(());
            }
            SeekFrom::End(offset) => (self.data.len() as u64, offset),
            SeekFrom::Current(offset) => (self.position, offset),
        };
        // Checked arithmetic: the previous `as i64 + offset` could overflow.
        let target = if offset >= 0 {
            base.checked_add(offset as u64)
        } else {
            base.checked_sub(offset.unsigned_abs())
        };
        match target {
            Some(pos) => {
                self.position = pos;
                Ok(())
            }
            None => Err(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                if offset < 0 {
                    "seek to negative position"
                } else {
                    "seek position overflows u64"
                },
            )),
        }
    }

    /// Returns the cursor position set by the preceding `start_seek`.
    fn poll_complete(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<u64>> {
        Poll::Ready(Ok(self.position))
    }
}
/// A slice of a stored block's bytes, as produced by
/// `StreamingBlockStore::get_range`.
pub struct PartialBlock {
    /// CID of the block the bytes were taken from.
    pub cid: Cid,
    /// The requested range. `get_range` stores it exactly as passed in,
    /// without clamping it to the block size.
    pub range: ByteRange,
    /// The bytes actually returned.
    pub data: Bytes,
    /// Size in bytes of the whole block.
    pub total_size: u64,
}

impl PartialBlock {
    /// `true` when this slice starts at offset 0 and holds every byte of the block.
    pub fn is_complete(&self) -> bool {
        let holds_all_bytes = self.data.len() as u64 == self.total_size;
        let starts_at_zero = self.range.start == 0;
        starts_at_zero && holds_all_bytes
    }
}
/// Extension of [`BlockStore`] with byte-range and reader-style access to blocks.
///
/// Every method has a default implementation built on `BlockStore::get`, and the
/// blanket impl below makes the trait available on every `BlockStore`. The
/// defaults fetch the entire block before slicing or measuring it.
#[async_trait::async_trait]
pub trait StreamingBlockStore: BlockStore {
    /// Fetches the bytes of block `cid` that fall inside `range`.
    ///
    /// Both ends of the range are clamped to the block size, so an out-of-bounds
    /// or inverted range yields empty `data` rather than an error. The returned
    /// `PartialBlock::range` is the range exactly as requested.
    ///
    /// Returns `Ok(None)` when the store has no block for `cid`.
    async fn get_range(&self, cid: &Cid, range: ByteRange) -> Result<Option<PartialBlock>> {
        let block = match self.get(cid).await? {
            Some(block) => block,
            None => return Ok(None),
        };
        let data = block.data();
        let total_size = data.len() as u64;
        // Clamp in u64 *before* narrowing to usize. `range.start as usize` would
        // silently truncate offsets >= 2^32 on 32-bit targets and slice the wrong
        // bytes. After clamping, both values are <= data.len(), so the casts are lossless.
        let start = range.start.min(total_size) as usize;
        let end = range.end.map_or(total_size, |e| e.min(total_size)) as usize;
        let slice = if start < end {
            data.slice(start..end)
        } else {
            Bytes::new()
        };
        Ok(Some(PartialBlock {
            cid: *block.cid(),
            range,
            data: slice,
            total_size,
        }))
    }

    /// Returns a reader over block `cid`, or `Ok(None)` if it is not stored.
    async fn reader(&self, cid: &Cid) -> Result<Option<BlockReader>> {
        let block = self.get(cid).await?;
        Ok(block.map(|b| BlockReader::new(&b)))
    }

    /// Returns the size of block `cid`, or `Ok(None)` if it is not stored.
    ///
    /// The default implementation fetches the whole block to measure it.
    async fn get_size(&self, cid: &Cid) -> Result<Option<u64>> {
        let block = self.get(cid).await?;
        Ok(block.map(|b| b.size()))
    }
}

// Every block store gets the default streaming helpers for free.
impl<T: BlockStore> StreamingBlockStore for T {}
/// Buffers written bytes and stores them in `store` as a sequence of blocks of
/// `config.buffer_size` bytes each. The last block may be shorter. CIDs are
/// recorded in write order.
pub struct StreamingWriter<S: BlockStore> {
    store: Arc<S>,
    buffer: BytesMut,
    config: StreamConfig,
    written_cids: Vec<Cid>,
}

impl<S: BlockStore> StreamingWriter<S> {
    /// Creates a writer using `StreamConfig::default()`.
    pub fn new(store: Arc<S>) -> Self {
        Self::with_config(store, StreamConfig::default())
    }

    /// Creates a writer with an explicit configuration.
    ///
    /// A `buffer_size` of 0 is treated as 1. Otherwise every write would try to
    /// store empty chunks forever.
    pub fn with_config(store: Arc<S>, config: StreamConfig) -> Self {
        Self {
            store,
            buffer: BytesMut::with_capacity(config.buffer_size),
            config,
            written_cids: Vec::new(),
        }
    }

    /// Appends `data` and stores every complete chunk now in the buffer.
    ///
    /// Returns `Ok(data.len())` on success.
    ///
    /// # Errors
    /// Propagates errors from `Block::new` and `BlockStore::put`.
    pub async fn write(&mut self, data: &[u8]) -> Result<usize> {
        self.buffer.extend_from_slice(data);
        let chunk_size = self.chunk_size();
        while self.buffer.len() >= chunk_size {
            self.flush_chunk().await?;
        }
        Ok(data.len())
    }

    /// Stores any buffered remainder as a final block and returns all CIDs in
    /// the order their blocks were written.
    pub async fn finish(mut self) -> Result<Vec<Cid>> {
        if !self.buffer.is_empty() {
            self.flush_chunk().await?;
        }
        Ok(self.written_cids)
    }

    /// CIDs of the blocks stored so far, in write order.
    pub fn written_cids(&self) -> &[Cid] {
        &self.written_cids
    }

    /// Effective chunk size: the configured size, but never 0, so that
    /// `write`'s flush loop always makes progress.
    fn chunk_size(&self) -> usize {
        self.config.buffer_size.max(1)
    }

    /// Moves up to one chunk from the front of the buffer into a new block,
    /// stores it, and records its CID.
    ///
    /// NOTE(review): the bytes are removed from the buffer before `put`. If
    /// storing fails, that chunk is lost.
    async fn flush_chunk(&mut self) -> Result<()> {
        let take = self.buffer.len().min(self.chunk_size());
        let chunk = self.buffer.split_to(take).freeze();
        let block = Block::new(chunk)?;
        let cid = *block.cid();
        self.store.put(&block).await?;
        self.written_cids.push(cid);
        Ok(())
    }
}
pub struct BlockStream<S: BlockStore> {
store: Arc<S>,
cids: std::vec::IntoIter<Cid>,
}
impl<S: BlockStore + 'static> BlockStream<S> {
pub fn new(store: Arc<S>, cids: Vec<Cid>) -> Self {
Self {
store,
cids: cids.into_iter(),
}
}
}
impl<S: BlockStore + 'static> Stream for BlockStream<S> {
type Item = Result<Block>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
match self.cids.next() {
Some(cid) => {
let store = Arc::clone(&self.store);
let fut = async move {
match store.get(&cid).await? {
Some(block) => Ok(block),
None => Err(Error::BlockNotFound(cid.to_string())),
}
};
let waker = cx.waker().clone();
tokio::spawn(async move {
let _ = fut.await;
waker.wake();
});
Poll::Pending
}
None => Poll::Ready(None),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_byte_range() {
        let range = ByteRange::new(10, 50);
        assert_eq!(range.length(100), 40);
        // `end` is clamped to the block size.
        assert_eq!(range.length(30), 20);
        let range = ByteRange::from(80);
        assert_eq!(range.length(100), 20);
        let range = ByteRange::with_length(10, 30);
        assert_eq!(range.length(100), 30);
        // Inverted and out-of-bounds ranges are empty rather than panicking.
        assert_eq!(ByteRange::new(50, 10).length(100), 0);
        assert_eq!(ByteRange::from(200).length(100), 0);
    }

    #[tokio::test]
    async fn test_block_reader() {
        use tokio::io::AsyncReadExt;
        let block = Block::new(Bytes::from("Hello, World!")).unwrap();
        let mut reader = BlockReader::new(&block);
        assert_eq!(reader.size(), 13);
        let mut buf = vec![0u8; 5];
        let n = reader.read(&mut buf).await.unwrap();
        assert_eq!(n, 5);
        assert_eq!(&buf, b"Hello");
        let n = reader.read(&mut buf).await.unwrap();
        assert_eq!(n, 5);
        assert_eq!(&buf, b", Wor");
        assert_eq!(reader.remaining(), 3);
        // A short read returns only what is left...
        let n = reader.read(&mut buf).await.unwrap();
        assert_eq!(&buf[..n], b"ld!");
        // ...and a read at the end reports EOF as 0 bytes.
        assert_eq!(reader.read(&mut buf).await.unwrap(), 0);
        assert_eq!(reader.remaining(), 0);
    }

    #[tokio::test]
    async fn test_block_reader_seek() {
        use tokio::io::{AsyncReadExt, AsyncSeekExt};
        let block = Block::new(Bytes::from("Hello, World!")).unwrap();
        let mut reader = BlockReader::new(&block);
        reader.seek(SeekFrom::Start(7)).await.unwrap();
        let mut buf = vec![0u8; 5];
        let n = reader.read(&mut buf).await.unwrap();
        assert_eq!(n, 5);
        assert_eq!(&buf, b"World");
        // Seeking relative to the end lands on the same offset.
        assert_eq!(reader.seek(SeekFrom::End(-6)).await.unwrap(), 7);
        // Seeking before the start is rejected.
        assert!(reader.seek(SeekFrom::Current(-100)).await.is_err());
    }
}