use bytes::{Bytes, BytesMut};
use futures_sink::Sink;
use futures_util::ready;
use pin_project::pin_project;
use std::pin::Pin;
use std::task::{Context, Poll};
/// A [`Sink`] adapter that coalesces small [`Bytes`] items into larger chunks
/// before forwarding them to an inner `Sink<Bytes>`.
///
/// Items are copied into an internal buffer of at most `limit` bytes. When an
/// incoming item does not fit in the remaining buffer space, the buffered bytes
/// are sent to the inner sink first. The item is then buffered if it is shorter
/// than `limit`; otherwise it is forwarded to the inner sink unchanged.
///
/// Whenever data has been handed to the inner sink, the next `poll_ready` call
/// flushes the inner sink before reporting readiness.
#[pin_project(project = BufBytesSinkProj)]
#[derive(Debug)]
pub struct BufBytesSink<S> {
    /// Bytes accumulated but not yet sent to `inner`.
    buf: BytesMut,
    /// Item accepted by `start_send` that has not yet been moved into `buf`
    /// or forwarded to `inner`.
    pending_write: Option<Bytes>,
    /// Set after a successful `start_send` on `inner`; cleared once a flush
    /// of `inner` has completed.
    needs_flush: bool,
    /// Maximum number of bytes held in `buf`.
    limit: usize,
    /// The wrapped sink.
    #[pin]
    inner: S,
}
impl<S> BufBytesSink<S> {
    /// Wraps `inner` with an 8 KiB coalescing buffer.
    pub fn new(inner: S) -> Self {
        Self::with_capacity(8 * 1024, inner)
    }

    /// Wraps `inner` with a coalescing buffer that holds at most `limit` bytes.
    ///
    /// The buffer is allocated up front, so filling it the first time does not
    /// go through repeated grow-and-copy reallocations.
    fn with_capacity(limit: usize, inner: S) -> Self {
        BufBytesSink {
            buf: BytesMut::with_capacity(limit),
            pending_write: None,
            needs_flush: false,
            limit,
            inner,
        }
    }
}
impl<S> BufBytesSinkProj<'_, S>
where
    S: Sink<Bytes>,
{
    /// Polls the inner sink for readiness to accept one item.
    fn inner_poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), S::Error>> {
        self.inner.as_mut().poll_ready(cx)
    }

    /// Flushes the inner sink.
    ///
    /// `needs_flush` is cleared once the flush completes, whether it succeeded
    /// or failed; it is left untouched while the flush is still pending.
    fn inner_poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), S::Error>> {
        let result = ready!(self.inner.as_mut().poll_flush(cx));
        *self.needs_flush = false;
        Poll::Ready(result)
    }

    /// Closes the inner sink.
    fn inner_poll_close(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), S::Error>> {
        self.inner.as_mut().poll_close(cx)
    }

    /// Hands `item` to the inner sink and, on success, records that the inner
    /// sink now needs a flush.
    ///
    /// Callers must have observed `Ready(Ok(()))` from `inner_poll_ready` first,
    /// as required by the `Sink` protocol.
    fn inner_start_send(&mut self, item: Bytes) -> Result<(), S::Error> {
        let result = self.inner.as_mut().start_send(item);
        if result.is_ok() {
            *self.needs_flush = true;
        }
        result
    }

    /// Moves `pending_write` (if any) into the buffer or into the inner sink.
    ///
    /// The item is appended to `buf` if it fits. Otherwise the current buffer
    /// contents are sent to the inner sink first. After that, the item is
    /// buffered if it is shorter than `limit`, or forwarded to the inner sink
    /// unchanged.
    ///
    /// If the inner sink is not ready, or reports an error before the item has
    /// been handed over, the item is put back into `pending_write` so a later
    /// call can retry it.
    fn poll_write_pending(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), S::Error>> {
        let pending_write = match self.pending_write.take() {
            Some(item) => item,
            None => return Poll::Ready(Ok(())),
        };

        // Fast path: the item fits in the remaining buffer space. `buf` is only
        // ever extended up to `limit` bytes, so the subtraction cannot underflow.
        if pending_write.len() <= *self.limit - self.buf.len() {
            self.buf.extend_from_slice(&pending_write);
            return Poll::Ready(Ok(()));
        }

        // Not enough room: send the buffered bytes to make space.
        match self.poll_write_buf(cx) {
            Poll::Ready(Ok(())) => {}
            poll => {
                *self.pending_write = Some(pending_write);
                return poll;
            }
        }
        debug_assert!(self.buf.is_empty());

        // The buffer is now empty; small items are still worth coalescing.
        if pending_write.len() < *self.limit {
            self.buf.extend_from_slice(&pending_write);
            return Poll::Ready(Ok(()));
        }

        // Large items are forwarded as-is rather than copied.
        // NOTE(review): an item of exactly `limit` bytes is copied by the fast
        // path above when `buf` is empty, but forwarded without copying here.
        match self.inner_poll_ready(cx) {
            Poll::Ready(Ok(())) => {}
            poll => {
                *self.pending_write = Some(pending_write);
                return poll;
            }
        }
        self.inner_start_send(pending_write)?;
        Poll::Ready(Ok(()))
    }

    /// Sends the buffered bytes (if any) to the inner sink as a single chunk,
    /// leaving `buf` empty once the chunk has been taken.
    fn poll_write_buf(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), S::Error>> {
        if self.buf.is_empty() {
            return Poll::Ready(Ok(()));
        }
        ready!(self.inner_poll_ready(cx))?;
        let buf = self.buf.split().freeze();
        self.inner_start_send(buf)?;
        Poll::Ready(Ok(()))
    }

    /// Hands everything held by this adapter to the inner sink: first the
    /// pending item, then the buffer. Does not flush the inner sink itself.
    fn poll_flush_shallow(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), S::Error>> {
        ready!(self.poll_write_pending(cx))?;
        ready!(self.poll_write_buf(cx))?;
        Poll::Ready(Ok(()))
    }
}
impl<S> Sink<Bytes> for BufBytesSink<S>
where
S: Sink<Bytes>,
{
type Error = S::Error;
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
let mut this = self.project();
ready!(this.poll_write_pending(cx))?;
if *this.needs_flush {
ready!(this.inner_poll_flush(cx))?;
}
Poll::Ready(Ok(()))
}
fn start_send(self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> {
let this = self.project();
debug_assert!(this.pending_write.is_none());
debug_assert!(!*this.needs_flush);
*this.pending_write = Some(item);
Ok(())
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
let mut this = self.project();
ready!(this.poll_flush_shallow(cx))?;
this.inner_poll_flush(cx)
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
let mut this = self.project();
ready!(this.poll_flush_shallow(cx))?;
this.inner_poll_close(cx)
}
}
#[cfg(test)]
mod test {
    use super::*;
    use futures_util::SinkExt;

    /// Shorthand for a `Bytes` chunk backed by static data.
    fn chunk(data: &'static [u8]) -> Bytes {
        Bytes::from_static(data)
    }

    #[tokio::test]
    async fn simple_writes() {
        let mut sink = BufBytesSink::with_capacity(5, Vec::<Bytes>::new());
        sink.feed(chunk(b"aaaa")).await.unwrap();
        sink.feed(chunk(b"b")).await.unwrap();
        sink.feed(chunk(b"ccc")).await.unwrap();
        sink.feed(chunk(b"d")).await.unwrap();
        assert_eq!(sink.inner, vec![chunk(b"aaaab")]);
        sink.flush().await.unwrap();
        assert_eq!(sink.inner, vec![chunk(b"aaaab"), chunk(b"cccd")]);
    }

    #[tokio::test]
    async fn oversized_writes() {
        let mut sink = BufBytesSink::with_capacity(5, Vec::<Bytes>::new());
        sink.feed(chunk(b"aaaa")).await.unwrap();
        sink.feed(chunk(b"bbbbbb")).await.unwrap();
        sink.feed(chunk(b"c")).await.unwrap();
        sink.flush().await.unwrap();
        assert_eq!(
            sink.inner,
            vec![chunk(b"aaaa"), chunk(b"bbbbbb"), chunk(b"c")],
        );
    }

    #[tokio::test]
    async fn close_sends_buffered_data() {
        let mut sink = BufBytesSink::with_capacity(5, Vec::<Bytes>::new());
        sink.feed(chunk(b"aa")).await.unwrap();
        sink.feed(chunk(b"b")).await.unwrap();
        // Nothing reaches the inner sink until the buffer overflows or is drained.
        assert!(sink.inner.is_empty());
        sink.close().await.unwrap();
        assert_eq!(sink.inner, vec![chunk(b"aab")]);
    }
}