use crate::AsyncRead;
use futures_util::future::poll_fn;
use futures_util::ready;
use std::io;
use std::io::Read;
use std::pin::Pin;
use std::task::Poll;
/// Growable byte buffer that readers write into directly, after its filled bytes.
///
/// Only `buf[..len]` holds data; it is what `Deref<Target = [u8]>` exposes.
/// Space beyond it is handed to the next read.
#[derive(Debug, Clone)]
pub struct UninitBuf {
    // Capacity limit consulted before the backing vector is grown.
    max_size: usize,
    // Backing storage; its Vec length is adjusted by the read methods.
    buf: Vec<u8>,
    // Number of bytes filled with read data.
    len: usize,
    // Set when the last read left no spare capacity; asks the next read to grow first.
    expand: bool,
}
impl UninitBuf {
pub fn with_capacity(capacity: usize, max_size: usize) -> Self {
UninitBuf {
max_size,
buf: Vec::with_capacity(capacity),
len: 0,
expand: false,
}
}
pub fn clear(&mut self) {
self.len = 0;
}
pub fn len(&self) -> usize {
self.len
}
}
impl UninitBuf {
    /// Minimum free space guaranteed after a growth step.
    const GROW_STEP: usize = 32;

    /// Makes every byte of the current allocation initialized.
    ///
    /// Readers must receive a `&mut [u8]`. Building one over never-written
    /// capacity (e.g. via `Vec::set_len`) is undefined behavior, and
    /// `Read::read` implementations are allowed to read from their buffer.
    /// Newly reserved capacity is therefore zero-filled. The Vec's length is
    /// then left at its capacity, so each byte is zeroed at most once over the
    /// buffer's lifetime. `self.len` alone marks how many bytes hold real data.
    fn init_spare_capacity(&mut self) {
        let cap = self.buf.capacity();
        // `cap` is the current capacity, so this never reallocates. It is a
        // no-op once the whole allocation has been initialized.
        self.buf.resize(cap, 0);
    }

    /// Shrinks the Vec's length back to the filled prefix (`self.len`).
    ///
    /// This is a safe `truncate`: it only drops already-initialized bytes from
    /// the Vec's length and does nothing if the length is already `<= self.len`.
    fn set_safe_size(&mut self) {
        self.buf.truncate(self.len);
    }

    /// Requests growth on the next read when this read left no spare capacity.
    fn mark_expand(&mut self) {
        self.expand = self.len == self.buf.capacity();
    }

    /// Grows the buffer if needed and returns the initialized tail after the filled bytes.
    fn spare_mut(&mut self) -> &mut [u8] {
        self.reserve_if_needed();
        self.init_spare_capacity();
        &mut self.buf[self.len..]
    }

    /// Records `amt` bytes written by a reader into the slice from `spare_mut`.
    ///
    /// Returns `amt`.
    ///
    /// # Panics
    /// Panics if `amt` exceeds the spare space that was handed out. Only a
    /// reader violating its contract can cause this. Without the check, `len`
    /// could run past the initialized region.
    fn commit(&mut self, amt: usize) -> usize {
        let spare = self.buf.len() - self.len;
        assert!(
            amt <= spare,
            "reader reported {} bytes read into a {}-byte buffer",
            amt,
            spare
        );
        self.len += amt;
        self.mark_expand();
        amt
    }

    /// Performs one `read` from `r` into the space after the filled bytes.
    ///
    /// Returns the number of bytes appended.
    ///
    /// # Errors
    /// Propagates any error from `r.read`. In that case no bytes are recorded.
    ///
    /// # Panics
    /// Panics if the buffer is full and its capacity has reached `max_size`,
    /// or if `r` reports more bytes than it was given.
    pub fn read_from_sync(&mut self, r: &mut impl Read) -> io::Result<usize> {
        let amt = r.read(self.spare_mut())?;
        Ok(self.commit(amt))
    }

    /// Async counterpart of [`UninitBuf::read_from_sync`]; polls `r.poll_read` once to completion.
    ///
    /// If the future is dropped before completing, no bytes are recorded and
    /// the buffer stays in a valid state.
    ///
    /// # Errors / Panics
    /// Same as `read_from_sync`.
    pub async fn read_from_async<R>(&mut self, r: &mut R) -> io::Result<usize>
    where
        R: AsyncRead + Unpin,
    {
        let buf = self.spare_mut();
        let amt = poll_fn(|cx| Pin::new(&mut *r).poll_read(cx, buf)).await?;
        Ok(self.commit(amt))
    }

    /// Lends the spare space to the poll-style reader `r`.
    ///
    /// Records the bytes it reports once it returns `Ready(Ok(n))`. On `Pending`
    /// or `Ready(Err(_))`, that result is returned unchanged and nothing is
    /// recorded.
    ///
    /// # Panics
    /// Same as `read_from_sync`.
    pub fn poll_delegate(
        &mut self,
        r: impl FnOnce(&mut [u8]) -> Poll<io::Result<usize>>,
    ) -> Poll<io::Result<usize>> {
        let amt = ready!(r(self.spare_mut())?);
        Ok(self.commit(amt)).into()
    }

    /// Grows the allocation when it is full, or when the previous read filled it.
    ///
    /// Growth roughly doubles the capacity, guarantees at least `GROW_STEP`
    /// free bytes, and never takes the capacity beyond `max_size`.
    ///
    /// # Panics
    /// Panics if there is no free space left and the capacity has already
    /// reached `max_size`.
    fn reserve_if_needed(&mut self) {
        let cap = self.buf.capacity();
        let reserve_needed = self.len == cap;
        let is_at_max = cap >= self.max_size;
        if reserve_needed && is_at_max {
            panic!("No headroom and buf is at max capacity");
        }
        if !is_at_max && (self.expand || reserve_needed) {
            let wanted = self.len.saturating_add(Self::GROW_STEP);
            if cap < wanted {
                // Amortized doubling like `Vec::reserve`, clamped to `max_size`.
                // `target > cap`: here `cap < wanted` and `cap < max_size`.
                let target = cap.saturating_mul(2).max(wanted).min(self.max_size);
                // `reserve_exact` is relative to the Vec's length, which may be
                // anywhere in `self.len..=cap`.
                self.buf.reserve_exact(target - self.buf.len());
            }
            self.expand = false;
        }
    }
}
impl Drop for UninitBuf {
    fn drop(&mut self) {
        // Bring the backing Vec's length back to the filled prefix (`self.len`)
        // before the Vec itself is dropped. A read that did not finish can leave
        // that length larger.
        // NOTE(review): `u8` has no destructor, so this looks purely defensive.
        Self::set_safe_size(self);
    }
}
impl std::ops::Deref for UninitBuf {
    type Target = [u8];

    /// Views only the filled prefix of the backing vector; spare space is hidden.
    fn deref(&self) -> &[u8] {
        let filled = self.len;
        &self.buf.as_slice()[..filled]
    }
}