use crate::ReadAt;
use async_trait::async_trait;
use futures::{lock::Mutex, AsyncRead, AsyncReadExt};
use pin_project::pin_project;
use std::{
collections::VecDeque,
fmt, io,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
/// A source that can open a fresh sequential reader starting at an arbitrary byte offset.
///
/// `ReadAtWrapper` builds positional reads on top of this by caching readers
/// ("heads") and tracking how far each one has advanced.
#[async_trait(?Send)]
pub trait GetReaderAt {
    /// Sequential reader type returned by [`GetReaderAt::get_reader_at`].
    type Reader: Unpin + AsyncRead;

    /// Opens a reader for this source starting at `offset`.
    ///
    /// Implementations are expected to return a reader positioned at `offset`;
    /// `ReadAtWrapper` records `offset` as the head's position without checking it.
    ///
    /// # Errors
    /// Returns any I/O error encountered while opening the reader.
    async fn get_reader_at(self: &Arc<Self>, offset: u64) -> io::Result<Self::Reader>;
}
/// Adapts a [`GetReaderAt`] source into positional reads.
///
/// It keeps a small pool of sequential readers ("heads") and reuses one whose
/// current position matches the requested offset.
pub struct ReadAtWrapper<Source: GetReaderAt> {
    /// Idle heads; the front is the least recently returned, the back the most recent.
    heads: Mutex<VecDeque<Head<Source::Reader>>>,
    /// Source used to open new heads when no pooled head matches.
    source: Arc<Source>,
    /// Total length reported through `ReadAt::len`.
    len: u64,
    /// Maximum number of idle heads retained in `heads`.
    max_heads: usize,
}
impl<Source> ReadAtWrapper<Source>
where
    Source: GetReaderAt,
{
    /// Number of idle heads kept in the pool unless overridden with
    /// [`ReadAtWrapper::with_max_heads`].
    pub const DEFAULT_MAX_HEADS: usize = 3;

    /// Creates a wrapper over `source`, which will report `len` as its total length.
    ///
    /// If `initial_head` is `Some((offset, reader))`, `reader` is placed in the pool
    /// as a head positioned at `offset`. A first read at that offset can then reuse
    /// it instead of opening a new reader.
    pub fn new(
        source: Arc<Source>,
        len: u64,
        initial_head: Option<(u64, Source::Reader)>,
    ) -> Self {
        let mut heads = VecDeque::new();
        if let Some((offset, reader)) = initial_head {
            let head = Head { offset, reader };
            tracing::debug!("{:?}: initial", head);
            heads.push_back(head);
        }
        Self {
            heads: Mutex::new(heads),
            source,
            len,
            max_heads: Self::DEFAULT_MAX_HEADS,
        }
    }

    /// Sets how many idle heads the pool retains (default [`Self::DEFAULT_MAX_HEADS`]).
    ///
    /// With `0`, heads are discarded as soon as they are returned.
    #[must_use]
    pub fn with_max_heads(mut self, max_heads: usize) -> Self {
        self.max_heads = max_heads;
        self
    }

    /// Takes a head positioned at `offset` out of the pool.
    ///
    /// If no pooled head matches, it opens a new one via [`GetReaderAt::get_reader_at`].
    ///
    /// # Errors
    /// Propagates any error returned by `get_reader_at`.
    async fn borrow_head(&self, offset: u64) -> io::Result<Head<Source::Reader>> {
        let mut heads = self.heads.lock().await;
        if let Some(index) = heads.iter().position(|head| head.offset == offset) {
            let head = heads
                .remove(index)
                .expect("internal logic error in heads pool manipulation");
            tracing::trace!("{:?}: borrowing", head);
            return Ok(head);
        }
        // Release the pool lock before awaiting the source so other reads can still
        // borrow and return heads while a new reader is being opened.
        drop(heads);
        let reader = self.source.get_reader_at(offset).await?;
        let head = Head { offset, reader };
        tracing::debug!("{:?}: new head", head);
        Ok(head)
    }

    /// Puts `head` back into the pool as the most recently returned entry.
    ///
    /// The oldest heads are evicted until at most `max_heads` remain.
    async fn return_head(&self, head: Head<Source::Reader>) {
        tracing::trace!("{:?}: returning", head);
        let mut heads = self.heads.lock().await;
        heads.push_back(head);
        // `while` rather than `if`: restores the bound even when the pool was already
        // over it (e.g. an initial head combined with `max_heads == 0`).
        while heads.len() > self.max_heads {
            heads.pop_front();
        }
    }
}
#[async_trait(?Send)]
impl<Source> ReadAt for ReadAtWrapper<Source>
where
    Source: GetReaderAt,
{
    /// Reads up to `buf.len()` bytes starting at `offset`, returning how many were read.
    ///
    /// An empty `buf`, or an `offset` at or beyond `len`, returns `Ok(0)` immediately.
    /// In that case no head is borrowed and no reader is opened on the source.
    ///
    /// # Errors
    /// Propagates errors from opening a reader or from the read itself. A head whose
    /// read failed is dropped instead of being returned to the pool, so a possibly
    /// broken reader is not reused.
    async fn read_at(&self, offset: u64, buf: &mut [u8]) -> io::Result<usize> {
        if buf.is_empty() || offset >= self.len {
            return Ok(0);
        }
        let mut head = self.borrow_head(offset).await?;
        let read = head.read(buf).await?;
        self.return_head(head).await;
        Ok(read)
    }

    /// Returns the total length given to [`ReadAtWrapper::new`].
    fn len(&self) -> u64 {
        self.len
    }
}
/// A pooled sequential reader paired with the absolute offset it will read from next.
#[pin_project]
struct Head<R: AsyncRead + Unpin> {
    /// Offset of the next byte `reader` will yield; advanced after each successful read.
    offset: u64,
    /// Underlying sequential reader.
    #[pin]
    reader: R,
}
/// Debug output shows only the head's position, as `Head(offset = N)`.
/// The reader itself is omitted because `R` is not required to be `Debug`.
impl<R: AsyncRead + Unpin> fmt::Debug for Head<R> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_fmt(format_args!("Head(offset = {})", self.offset))
    }
}
/// Delegates reads to the inner reader while keeping `offset` in sync.
impl<R: AsyncRead + Unpin> AsyncRead for Head<R> {
    /// Polls the inner reader.
    ///
    /// On `Ready(Ok(n))`, `offset` advances by `n`. Pending and error results leave
    /// `offset` unchanged. The poll result is passed through as-is.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        let this = self.project();
        let position = this.offset;
        this.reader.poll_read(cx, buf).map_ok(|count| {
            *position += count as u64;
            count
        })
    }
}