use crate::PutHandle;
use futures_lite::AsyncRead;
use std::{
io,
pin::Pin,
task::{Context, Poll, ready},
};
use trillium_http::{Body, BodySource, Headers};
/// An [`AsyncRead`] adapter that streams an upstream [`Body`] to its caller while
/// copying the same bytes into a cache writer, up to a byte cap.
pub(crate) struct TeeingReader<W: PutHandle> {
    /// The body being read and passed through to the caller.
    upstream: Body,
    /// Trailers obtained from `upstream` once it reported end-of-stream.
    trailers: Option<Headers>,
    /// Where the cache copy currently stands (writing, abandoned, finalizing, finished).
    state: TeeState<W>,
    /// Bytes already counted against `cap` that the writer has not yet accepted.
    pending: Vec<u8>,
    /// Total bytes handed to the cache copy so far, including those still in `pending`.
    bytes_written: u64,
    /// Maximum number of bytes the cache copy may receive before it is abandoned.
    cap: u64,
}
/// Lifecycle of the cache copy made by [`TeeingReader`].
enum TeeState<W: PutHandle> {
    /// Bytes read from upstream are still being forwarded to `writer`.
    Active { writer: W },
    /// Upstream reached end-of-stream; the writer's `finalize` future is being driven.
    Finalizing(Pin<Box<dyn Future<Output = io::Result<()>> + Send>>),
    /// Caching was abandoned and the writer dropped; the body still passes through.
    Aborted,
    /// Terminal state: every subsequent read reports end-of-stream.
    Done,
}
impl<W: PutHandle> TeeingReader<W> {
    /// Wraps `upstream` so that every byte read from it is also offered to `writer`,
    /// with at most `cap` bytes sent to the cache copy in total.
    pub(crate) fn new(upstream: Body, writer: W, cap: u64) -> Self {
        let state = TeeState::Active { writer };
        Self {
            state,
            pending: Vec::new(),
            trailers: None,
            bytes_written: 0,
            upstream,
            cap,
        }
    }

    /// Abandons the cache copy: any queued, not-yet-written bytes are discarded and the
    /// current state (including an active writer) is dropped in favour of `Aborted`.
    /// `upstream` is left untouched.
    fn abort(&mut self) {
        self.pending.clear();
        self.state = TeeState::Aborted;
    }
}
impl<W: PutHandle> TeeingReader<W> {
    /// Drives the writer's in-flight `finalize` future, if there is one.
    ///
    /// Returns `Poll::Pending` while that future is still running. When it resolves, a
    /// failure is only logged (caching is best-effort and must not disturb the body being
    /// served), and the state advances to `TeeState::Done`. In every other state this is
    /// a no-op that reports `Poll::Ready(())`.
    fn poll_finalize_step(&mut self, cx: &mut Context<'_>) -> Poll<()> {
        if let TeeState::Finalizing(fut) = &mut self.state {
            if let Err(e) = ready!(fut.as_mut().poll(cx)) {
                log::warn!("cache: finalize failed: {e}");
            }
            self.state = TeeState::Done;
        }
        Poll::Ready(())
    }

    /// Pushes bytes queued in `pending` into the cache writer until the queue is empty.
    ///
    /// Returns `Poll::Pending` when the writer's `poll_write` returns `Pending`
    /// (backpressure). A zero-length write or a write error abandons caching via
    /// `abort`, which also empties the queue.
    fn poll_drain_pending(&mut self, cx: &mut Context<'_>) -> Poll<()> {
        while !self.pending.is_empty() {
            let TeeState::Active { writer } = &mut self.state else {
                // No writer left to receive these bytes; discard them.
                self.pending.clear();
                break;
            };
            match ready!(Pin::new(writer).poll_write(cx, &self.pending)) {
                Ok(0) | Err(_) => self.abort(),
                Ok(written) => {
                    self.pending.drain(..written);
                }
            }
        }
        Poll::Ready(())
    }

    /// Offers a freshly read `chunk` to the cache writer without ever blocking the reader.
    ///
    /// Does nothing unless the tee is `Active`. If accepting `chunk` would push the cache
    /// copy past `cap`, or the writer fails / writes zero bytes, caching is abandoned.
    /// Bytes the writer does not accept right away (partial write or `Pending`) are queued
    /// in `pending`; they are flushed before the next upstream read.
    fn tee_chunk(&mut self, cx: &mut Context<'_>, chunk: &[u8]) {
        let len = chunk.len() as u64;
        let TeeState::Active { writer } = &mut self.state else {
            return;
        };
        if self.bytes_written.saturating_add(len) > self.cap {
            self.abort();
            return;
        }
        match Pin::new(writer).poll_write(cx, chunk) {
            Poll::Ready(Ok(0)) | Poll::Ready(Err(_)) => self.abort(),
            Poll::Ready(Ok(written)) => {
                self.bytes_written += len;
                self.pending.extend_from_slice(&chunk[written..]);
            }
            Poll::Pending => {
                self.bytes_written += len;
                self.pending.extend_from_slice(chunk);
            }
        }
    }
}

impl<W: PutHandle> AsyncRead for TeeingReader<W> {
    /// Reads from the upstream body, forwarding the same bytes to the cache writer.
    ///
    /// Each call goes through these steps in order:
    /// 1. Finish any pending finalization.
    /// 2. Flush bytes queued for the writer. This applies backpressure: the call returns
    ///    `Pending` while the writer is not ready.
    /// 3. Read from upstream.
    /// 4. Tee the bytes that were read.
    ///
    /// At upstream EOF, trailers are captured regardless of cache state. An active writer
    /// is then finalized, and the reader settles in `Done`. Upstream read errors abort the
    /// cache copy, so a truncated body is never committed, and are then returned to the
    /// caller.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        let this = self.get_mut();

        // An empty buffer makes any reader return Ok(0). That must not be mistaken for
        // EOF, which would finalize a truncated cache entry and end the body early.
        if buf.is_empty() {
            return Poll::Ready(Ok(0));
        }

        ready!(this.poll_finalize_step(cx));
        if matches!(this.state, TeeState::Done) {
            return Poll::Ready(Ok(0));
        }

        ready!(this.poll_drain_pending(cx));

        let n = match ready!(Pin::new(&mut this.upstream).poll_read(cx, buf)) {
            Ok(n) => n,
            Err(e) => {
                // A body that failed mid-stream must never be committed to the cache.
                this.abort();
                return Poll::Ready(Err(e));
            }
        };

        if n == 0 {
            // Capture trailers exactly once, even when caching was abandoned, so the
            // downstream consumer still receives them via `BodySource::trailers`.
            this.trailers = this.upstream.trailers();
            match std::mem::replace(&mut this.state, TeeState::Done) {
                TeeState::Active { writer } => {
                    let fut = Box::pin(writer.finalize(this.trailers.clone()))
                        as Pin<Box<dyn Future<Output = io::Result<()>> + Send>>;
                    this.state = TeeState::Finalizing(fut);
                    ready!(this.poll_finalize_step(cx));
                }
                // Nothing to finalize; `Done` (already in place) is the terminal state.
                TeeState::Aborted => {}
                TeeState::Finalizing(_) | TeeState::Done => {
                    unreachable!("finalization states are resolved before upstream is read")
                }
            }
            return Poll::Ready(Ok(0));
        }

        this.tee_chunk(cx, &buf[..n]);
        Poll::Ready(Ok(n))
    }
}
impl<W: PutHandle> BodySource for TeeingReader<W> {
    /// Takes the trailers stored by this reader, leaving `None` behind, so a second call
    /// returns `None`.
    fn trailers(self: Pin<&mut Self>) -> Option<Headers> {
        let this = Pin::into_inner(self);
        std::mem::take(&mut this.trailers)
    }
}