use std::{
io,
io::{Read, Write},
ops::{Deref, DerefMut},
};
use bytes::{Buf, BytesMut};
use crate::{fuse::Fuse, sink::IterSink, Encoder};
/// Default send high-water mark, in bytes (128 KiB).
///
/// The readiness check of a new writer writes to the underlying I/O object
/// while at least this many encoded bytes are buffered.
const DEFAULT_SEND_HIGH_WATER_MARK: usize = 128 * 1024;
/// Sink that encodes items with an [`Encoder`] and writes the encoded bytes
/// to an underlying I/O object `T`.
///
/// Encoded bytes accumulate in an internal buffer. A readiness check
/// (`IterSink::ready`, or `Sink::poll_ready` with the `async` feature) writes
/// to `T` while at least the send high-water mark of bytes is buffered (see
/// [`FramedWrite::set_send_high_water_mark`]); flushing writes everything.
#[cfg_attr(feature = "async", pin_project::pin_project)]
#[derive(Debug)]
pub struct FramedWrite<T, E> {
    // Buffering core wrapping a `Fuse`, which holds the I/O object and encoder.
    #[cfg_attr(feature = "async", pin)]
    inner: FramedWriteImpl<Fuse<T, E>>,
}
impl<T, E> FramedWrite<T, E> {
pub fn new(inner: T, encoder: E) -> Self {
Self {
inner: FramedWriteImpl::new(Fuse::new(inner, encoder)),
}
}
pub fn send_high_water_mark(&self) -> usize {
self.inner.high_water_mark
}
pub fn set_send_high_water_mark(&mut self, hwm: usize) {
self.inner.high_water_mark = hwm;
}
pub fn release(self) -> (T, E) {
let fuse = self.inner.release();
(fuse.io, fuse.codec)
}
pub fn into_inner(self) -> T {
self.release().0
}
pub fn encoder(&self) -> &E {
&self.inner.codec
}
pub fn encoder_mut(&mut self) -> &mut E {
&mut self.inner.codec
}
}
impl<T, E> Deref for FramedWrite<T, E> {
type Target = T;
fn deref(&self) -> &T {
&self.inner
}
}
impl<T, E> DerefMut for FramedWrite<T, E> {
fn deref_mut(&mut self) -> &mut T {
&mut self.inner
}
}
impl<T, E, I> IterSink<I> for FramedWrite<T, E>
where
T: Write,
E: Encoder<I>,
{
type Error = E::Error;
fn start_send(&mut self, item: I) -> Result<(), Self::Error> {
self.inner.start_send(item)
}
fn ready(&mut self) -> Result<(), Self::Error> {
self.inner.ready()
}
fn flush(&mut self) -> Result<(), Self::Error> {
self.inner.flush()
}
}
/// Buffering core of [`FramedWrite`]: items are encoded into `buffer` and
/// later written to `inner`, which acts as both the writer and the encoder.
#[cfg_attr(feature = "async", pin_project::pin_project)]
#[derive(Debug)]
pub(crate) struct FramedWriteImpl<T> {
    /// The I/O object, which also provides the `Encoder` implementation.
    #[cfg_attr(feature = "async", pin)]
    pub(crate) inner: T,
    /// Buffered byte count at or above which a readiness check writes to `inner`.
    pub(crate) high_water_mark: usize,
    /// Encoded bytes that have not yet been written to `inner`.
    buffer: BytesMut,
}
impl<T> FramedWriteImpl<T> {
    /// Initial capacity of the encode buffer: 8 KiB.
    // The previous literal `1028 * 8` (8224 bytes) was a typo for 8 KiB.
    const INITIAL_BUFFER_CAPACITY: usize = 8 * 1024;

    /// Wraps `inner` with an empty encode buffer and the default send
    /// high-water mark (`DEFAULT_SEND_HIGH_WATER_MARK`).
    pub(crate) fn new(inner: T) -> FramedWriteImpl<T> {
        FramedWriteImpl {
            inner,
            high_water_mark: DEFAULT_SEND_HIGH_WATER_MARK,
            buffer: BytesMut::with_capacity(Self::INITIAL_BUFFER_CAPACITY),
        }
    }

    /// Consumes the wrapper and returns the inner value.
    ///
    /// Encoded bytes still in the buffer are dropped, not written.
    pub(crate) fn release(self) -> T {
        self.inner
    }
}
impl<T> Deref for FramedWriteImpl<T> {
type Target = T;
fn deref(&self) -> &T {
&self.inner
}
}
impl<T> DerefMut for FramedWriteImpl<T> {
fn deref_mut(&mut self) -> &mut T {
&mut self.inner
}
}
impl<T: Read> Read for FramedWriteImpl<T> {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.inner.read(buf)
}
}
impl<T, I> IterSink<I> for FramedWriteImpl<T>
where
T: Write + Encoder<I>,
{
type Error = T::Error;
fn start_send(&mut self, item: I) -> Result<(), Self::Error> {
self.inner.encode(item, &mut self.buffer)
}
fn flush(&mut self) -> Result<(), Self::Error> {
while !self.buffer.is_empty() {
let num_write = self.inner.write(&self.buffer)?;
if num_write == 0 {
return Err(err_eof().into());
}
self.buffer.advance(num_write);
}
self.inner.flush().map_err(Into::into)
}
fn ready(&mut self) -> Result<(), Self::Error> {
while self.buffer.len() >= self.high_water_mark {
let num_write = self.inner.write(&self.buffer)?;
if num_write == 0 {
return Err(err_eof().into());
}
self.buffer.advance(num_write);
}
Ok(())
}
}
#[cfg(feature = "async")]
mod if_async {
    use std::{
        marker::Unpin,
        pin::Pin,
        task::{Context, Poll},
    };

    use futures_sink::Sink;
    use futures_util::{
        io::{AsyncRead, AsyncWrite},
        ready,
    };

    use super::*;

    /// Polls `writer` with the front of `buffer` until the buffer is empty or
    /// holds fewer than `threshold` bytes; a `threshold` of `0` drains it.
    ///
    /// Returns `Poll::Pending` when the writer is not ready (unwritten bytes
    /// stay in `buffer`), the writer's error if a write fails, and the
    /// `err_eof` error if the writer accepts zero bytes.
    fn poll_drain<W: AsyncWrite>(
        mut writer: Pin<&mut W>,
        cx: &mut Context<'_>,
        buffer: &mut BytesMut,
        threshold: usize,
    ) -> Poll<io::Result<()>> {
        // The emptiness check matters for `threshold == 0`: without it the
        // loop would issue a zero-length write, get `Ok(0)` and report EOF.
        while !buffer.is_empty() && buffer.len() >= threshold {
            let num_write = ready!(writer.as_mut().poll_write(cx, &buffer[..]))?;
            if num_write == 0 {
                return Poll::Ready(Err(err_eof()));
            }
            buffer.advance(num_write);
        }
        Poll::Ready(Ok(()))
    }

    /// Asynchronous sink API: every call is forwarded to the buffering core.
    impl<T, E, I> Sink<I> for FramedWrite<T, E>
    where
        T: AsyncWrite + Unpin,
        E: Encoder<I>,
    {
        type Error = E::Error;

        fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            self.project().inner.poll_ready(cx)
        }

        fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> {
            self.project().inner.start_send(item)
        }

        fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            self.project().inner.poll_flush(cx)
        }

        fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            self.project().inner.poll_close(cx)
        }
    }

    /// Reads pass straight through to the wrapped value.
    impl<T: AsyncRead + Unpin> AsyncRead for FramedWriteImpl<T> {
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut [u8],
        ) -> Poll<io::Result<usize>> {
            self.project().inner.poll_read(cx, buf)
        }
    }

    impl<T, I> Sink<I> for FramedWriteImpl<T>
    where
        T: AsyncWrite + Encoder<I> + Unpin,
    {
        type Error = T::Error;

        /// Writes buffered bytes until fewer than `high_water_mark` remain.
        /// A mark of `0` drains the buffer completely.
        fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            let this = self.project();
            poll_drain(this.inner, cx, this.buffer, *this.high_water_mark).map_err(Into::into)
        }

        /// Encodes `item` into the write buffer without performing I/O.
        fn start_send(mut self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> {
            let this = &mut *self;
            this.inner.encode(item, &mut this.buffer)
        }

        /// Writes every buffered byte, then flushes the underlying writer.
        fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            let mut this = self.project();
            ready!(poll_drain(this.inner.as_mut(), cx, this.buffer, 0))?;
            this.inner.poll_flush(cx).map_err(Into::into)
        }

        /// Flushes all buffered bytes, then closes the underlying writer.
        fn poll_close(
            mut self: Pin<&mut Self>,
            cx: &mut Context<'_>,
        ) -> Poll<Result<(), Self::Error>> {
            ready!(self.as_mut().poll_flush(cx))?;
            self.project().inner.poll_close(cx).map_err(Into::into)
        }
    }
}
/// Builds the error reported when the underlying writer accepts zero bytes.
fn err_eof() -> io::Error {
    let kind = io::ErrorKind::UnexpectedEof;
    io::Error::new(kind, "End of file")
}