use std::{
io,
pin::Pin,
task::{Context, Poll},
};
use bytes::{Bytes, BytesMut};
use futures::{
FutureExt, SinkExt, Stream, StreamExt,
channel::{
mpsc::{self, SendError},
oneshot,
},
ready,
};
use tokio::{
io::{AsyncWrite, AsyncWriteExt},
task::JoinHandle,
};
use crate::{
Error, Version,
body::{Body, ChunkedStream},
connection::ConnectionWriter,
header::HeaderFieldValue,
response::{ResponseBuilder, ResponseHeaderEncoder},
server::response::OutgoingResponse,
};
/// Boxed future producing the response to write, or `None` when nothing should be written
/// for this entry.
type FutureResponse = Pin<Box<dyn Future<Output = Option<OutgoingResponse>> + Send>>;

/// One queued entry: the response future, the HTTP version of the request it answers, and
/// the future that decides whether the connection is closed after the response.
type ResponsePipelineElement = (FutureResponse, Version, CloseConnectionFuture);

/// Writes HTTP responses to a connection in the order in which they were queued.
///
/// A spawned task owns the [`ConnectionWriter`] and handles queued entries one at a time;
/// [`ResponsePipeline::close`] hands the writer back if the connection is still open.
pub struct ResponsePipeline<IO> {
    /// Queue of pending responses, drained by `task`.
    sender: mpsc::Sender<ResponsePipelineElement>,
    /// Writer task; yields the writer back unless the connection was shut down.
    task: JoinHandle<Result<Option<ConnectionWriter<IO>>, Error>>,
}
impl<IO> ResponsePipeline<IO>
where
    IO: AsyncWrite + Send + Unpin + 'static,
{
    /// Creates a pipeline that writes queued responses to `writer` in submission order.
    ///
    /// `depth` is the buffer size of the underlying channel; once it is full,
    /// [`send`](Self::send) waits for the writer task to catch up.
    ///
    /// The writer task stops, after shutting the writer down, as soon as a response requires
    /// the connection to be closed or writing fails.
    ///
    /// # Panics
    ///
    /// Panics if called outside a Tokio runtime (the writer task is started with
    /// `tokio::spawn`).
    pub fn new(mut writer: ConnectionWriter<IO>, depth: usize) -> Self {
        let (sender, mut queue) = mpsc::channel::<ResponsePipelineElement>(depth);

        let task = tokio::spawn(async move {
            let mut encoder = OutgoingResponseSender::new();

            while let Some((pending, request_version, close_decision)) = queue.next().await {
                let Some(response) = pending.await else {
                    // No response was produced for this entry; move on to the next one.
                    continue;
                };

                let close = close_decision.await;
                let outcome = encoder
                    .send(&mut writer, response, request_version, close)
                    .await
                    .map_err(Error::from);

                // Only a successful write that leaves the connection open lets us continue.
                if !matches!(outcome, Ok(false)) {
                    // Best effort: the connection is being abandoned either way.
                    let _ = writer.shutdown().await;
                    return outcome.map(|_| None);
                }
            }

            // The queue was closed while the connection is still usable: hand it back.
            Ok(Some(writer))
        });

        Self { sender, task }
    }
}
impl<IO> ResponsePipeline<IO> {
pub async fn send<F>(
&mut self,
response: F,
version: Version,
close: CloseConnectionFuture,
) -> Result<(), SendError>
where
F: Future<Output = Option<OutgoingResponse>> + Send + 'static,
{
self.sender.send((Box::pin(response), version, close)).await
}
pub async fn close(mut self) -> Result<Option<ConnectionWriter<IO>>, Error> {
self.sender.close_channel();
self.task
.await
.map_err(|_| Error::from_static_msg("interrupted"))?
}
}
/// Resolves to whether the connection should be closed after the associated response.
///
/// Created together with a [`CloseConnectionResolver`]. If the resolver is dropped without
/// calling [`CloseConnectionResolver::resolve`], the future resolves to `true`.
pub struct CloseConnectionFuture {
    inner: oneshot::Receiver<bool>,
}

impl CloseConnectionFuture {
    /// Creates a linked future/resolver pair.
    pub fn new() -> (Self, CloseConnectionResolver) {
        let (decision_tx, decision_rx) = oneshot::channel();
        (
            Self { inner: decision_rx },
            CloseConnectionResolver { inner: decision_tx },
        )
    }
}

impl Future for CloseConnectionFuture {
    type Output = bool;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let close = match ready!(self.inner.poll_unpin(cx)) {
            Ok(decision) => decision,
            // The resolver went away without deciding; closing is the safe default.
            Err(_) => true,
        };
        Poll::Ready(close)
    }
}
/// Sending half of a [`CloseConnectionFuture`]; delivers the decision on whether the
/// connection should be closed after the associated response.
pub struct CloseConnectionResolver {
    inner: oneshot::Sender<bool>,
}

impl CloseConnectionResolver {
    /// Delivers `close` to the paired future, consuming the resolver.
    ///
    /// If the paired future has already been dropped, the value is silently discarded.
    pub fn resolve(self, close: bool) {
        let Self { inner } = self;
        // Sending only fails when the receiving future is gone; nobody is left to notify.
        let _ = inner.send(close);
    }
}
/// Serializes [`OutgoingResponse`]s onto a connection.
///
/// Holds a scratch buffer and a header encoder that are reused for every response written
/// through the same sender.
struct OutgoingResponseSender {
    buffer: BytesMut,
    header_encoder: ResponseHeaderEncoder,
}

impl OutgoingResponseSender {
    /// Creates a sender with an empty scratch buffer.
    fn new() -> Self {
        Self {
            buffer: BytesMut::new(),
            header_encoder: ResponseHeaderEncoder::new(),
        }
    }

    /// Writes `response` (header and body) to `connection` and flushes it.
    ///
    /// The response is re-framed before it is written:
    /// - the status line uses `version`, the HTTP version of the request being answered;
    /// - any `Content-Length`, `Transfer-Encoding` and `Connection` fields set by the handler
    ///   are replaced;
    /// - a body of known size gets `Content-Length`;
    /// - otherwise an HTTP/1.1 response that keeps the connection open is sent chunked;
    /// - any other body is delimited by closing the connection.
    ///
    /// Options from the handler's own `Connection` field are preserved (HTTP/1.1 only),
    /// except for `close`/`keep-alive`, which are recomputed.
    ///
    /// Returns whether the connection must be closed after this response. This is `close`
    /// as passed in, forced to `true` when the handler's `Connection` field contains `close`
    /// or when the body can only be delimited by closing the connection.
    ///
    /// # Errors
    ///
    /// Returns any I/O error raised while writing to `connection` or produced by the body
    /// stream.
    async fn send<IO>(
        &mut self,
        connection: &mut ConnectionWriter<IO>,
        response: OutgoingResponse,
        version: Version,
        mut close: bool,
    ) -> io::Result<bool>
    where
        IO: AsyncWrite + Unpin,
    {
        // The value is cloned so the tokens borrowed from it outlive `response`, which is
        // consumed by `deconstruct` below.
        let connection_value = response
            .get_header_field_value("connection")
            .cloned()
            .unwrap_or_else(|| HeaderFieldValue::from(""));
        let mut connection_tokens = connection_value
            .split(|&b| b == b',')
            .map(|elem| elem.trim_ascii())
            .filter(|elem| !elem.is_empty())
            .map(str::from_utf8)
            .filter_map(|res| res.ok())
            .collect::<Vec<_>>();
        close |= connection_tokens
            .iter()
            .any(|elem| elem.eq_ignore_ascii_case("close"));
        // Persistence options are decided below, so drop the handler's versions of them.
        connection_tokens.retain(|elem| {
            !elem.eq_ignore_ascii_case("close") && !elem.eq_ignore_ascii_case("keep-alive")
        });

        let (header, b, _) = response.deconstruct();
        let mut builder = ResponseBuilder::from(header)
            .set_version(version)
            .remove_header_fields("Content-Length")
            .remove_header_fields("Transfer-Encoding")
            .remove_header_fields("Connection");

        let mut body = if let Some(size) = b.size() {
            builder = builder.set_header_field(("Content-Length", size));
            ResponseBody::Plain(b)
        } else if !close && version == Version::Version11 {
            builder = builder.set_header_field(("Transfer-Encoding", "chunked"));
            ResponseBody::Chunked(ChunkedStream::new(b))
        } else {
            // No length and no chunked coding: the end of the body is signalled by closing
            // the connection.
            close = true;
            ResponseBody::Plain(b)
        };

        if version == Version::Version11 {
            if close {
                connection_tokens.push("close");
            }
            if !connection_tokens.is_empty() {
                builder = builder.set_header_field(("Connection", connection_tokens.join(", ")));
            }
        } else if !close {
            // Non-1.1 requests are treated as HTTP/1.0 here. An HTTP/1.0 connection closes
            // after each response unless the server explicitly opts in to persistence, so
            // advertise it whenever we intend to keep the connection open. Otherwise the
            // client assumes the connection is closing.
            builder = builder.set_header_field(("Connection", String::from("keep-alive")));
        }

        let header = builder.header();
        self.header_encoder.encode(&header, &mut self.buffer);
        connection.write_all(&self.buffer.split()).await?;
        while let Some(data) = body.next().await.transpose()? {
            connection.write_all(&data).await?;
        }
        connection.flush().await?;
        Ok(close)
    }
}
/// Body stream of a single response as it is put on the wire.
///
/// `Plain` passes the body bytes through unchanged. They are delimited by `Content-Length`
/// or by closing the connection. `Chunked` wraps them in a [`ChunkedStream`], which is used
/// when the response is sent with `Transfer-Encoding: chunked`.
enum ResponseBody {
    Plain(Body),
    Chunked(ChunkedStream<Body>),
}

impl Stream for ResponseBody {
    type Item = io::Result<Bytes>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // Both variants are polled through `poll_next_unpin`, so no structural pinning is
        // needed.
        match self.get_mut() {
            ResponseBody::Plain(raw) => raw.poll_next_unpin(cx),
            ResponseBody::Chunked(framed) => framed.poll_next_unpin(cx),
        }
    }
}