use std::{
io,
pin::Pin,
task::{Context, Poll},
};
use bytes::{Bytes, BytesMut};
use futures::{
FutureExt, SinkExt, Stream, StreamExt,
channel::{mpsc, oneshot},
};
use tokio::{io::AsyncRead, task::JoinHandle};
use tokio_util::codec::{Decoder, FramedRead};
use crate::{
BaseError, CodecError, Error, ResponseHeader,
body::{Body, ChunkedBodyDecoder, FixedSizeBodyDecoder, MessageBodyDecoder, SimpleBodyDecoder},
client::response::IncomingResponse,
connection::ConnectionReader,
response::{Response, ResponseHeaderDecoder, ResponseHeaderDecoderOptions},
};
/// Options for the [`ResponseDecoder`].
///
/// Bundles the options handed to the response header decoder together with
/// the maximum line length that is passed on to the response body decoder.
#[derive(Copy, Clone)]
pub struct ResponseDecoderOptions {
/// Options used when constructing the response header decoder.
header_decoder_options: ResponseHeaderDecoderOptions,
/// Line length limit forwarded to the chunked body decoder.
// NOTE(review): `None` presumably means "no limit" - confirm against
// `ChunkedBodyDecoder`.
max_line_length: Option<usize>,
}
impl ResponseDecoderOptions {
    /// Create the default response decoder options.
    ///
    /// By default, the maximum line length and the maximum header field
    /// length are both set to 4096 and the maximum number of header fields
    /// is set to 256.
    #[inline]
    pub const fn new() -> Self {
        let line_limit = Some(4096);
        let field_count_limit = Some(256);

        Self {
            header_decoder_options: ResponseHeaderDecoderOptions::new()
                .max_line_length(line_limit)
                .max_header_field_length(line_limit)
                .max_header_fields(field_count_limit),
            max_line_length: line_limit,
        }
    }

    /// Set the maximum line length.
    ///
    /// The limit is applied to the header decoder options and it is also
    /// remembered for the body decoder.
    #[inline]
    pub const fn max_line_length(mut self, max_length: Option<usize>) -> Self {
        let updated = self.header_decoder_options.max_line_length(max_length);

        self.max_line_length = max_length;
        self.header_decoder_options = updated;
        self
    }

    /// Set the maximum header field length in the header decoder options.
    #[inline]
    pub const fn max_header_field_length(mut self, max_length: Option<usize>) -> Self {
        let updated = self
            .header_decoder_options
            .max_header_field_length(max_length);

        self.header_decoder_options = updated;
        self
    }

    /// Set the maximum number of header fields in the header decoder
    /// options.
    #[inline]
    pub const fn max_header_fields(mut self, max_fields: Option<usize>) -> Self {
        let updated = self.header_decoder_options.max_header_fields(max_fields);

        self.header_decoder_options = updated;
        self
    }
}
impl Default for ResponseDecoderOptions {
#[inline]
fn default() -> Self {
Self::new()
}
}
/// Decoder turning the data read from a [`ConnectionReader`] into an
/// incoming response (a decoded header plus a streamed body).
#[derive(Copy, Clone)]
pub struct ResponseDecoder {
/// Options applied to the header decoder and the body decoder.
options: ResponseDecoderOptions,
}
impl ResponseDecoder {
#[inline]
pub const fn new(options: ResponseDecoderOptions) -> Self {
Self { options }
}
pub async fn decode<IO>(
self,
connection: ConnectionReader<IO>,
) -> Result<(IncomingResponse, ConnectionReaderJoinHandle<IO>), Error>
where
IO: AsyncRead + Send + 'static,
{
let (header, connection) = self.decode_header(connection).await?;
let decoder = ResponseBodyDecoder::new(&header, self.options.max_line_length)?;
let (reader, body) = ResponseBodyReader::new(connection, decoder);
let connection = reader.spawn();
let response = IncomingResponse::new(Response::new(header, body));
Ok((response, connection))
}
async fn decode_header<IO>(
&self,
connection: ConnectionReader<IO>,
) -> Result<(ResponseHeader, ConnectionReader<IO>), Error>
where
IO: AsyncRead,
{
let header_decoder = ResponseHeaderDecoder::new(self.options.header_decoder_options);
let mut header_decoder = FramedRead::new(connection, header_decoder);
let header = header_decoder
.next()
.await
.ok_or_else(|| Error::from_static_msg("connection closed"))?
.map_err(|err| match err {
CodecError::IO(err) => Error::from(err),
CodecError::Other(err) => Error::from(err),
})?;
let buffer = header_decoder.read_buffer_mut();
let chunk = buffer.split();
let connection = header_decoder.into_inner();
Ok((header, connection.prepend(chunk.freeze())))
}
}
/// Join handle of the task reading a response body.
///
/// The task output is the connection reader if the body was decoded
/// completely. The output is `None` if a body chunk could not be decoded or
/// if reading stopped before the body decoder reported completion.
pub type ConnectionReaderJoinHandle<IO> = JoinHandle<Option<ConnectionReader<IO>>>;
/// Response body decoder selected according to the response header.
enum ResponseBodyDecoder {
/// Used when the body is neither chunked nor delimited by a
/// Content-Length header field.
// NOTE(review): presumably this reads the body until the connection is
// closed - confirm against `SimpleBodyDecoder`.
Simple(SimpleBodyDecoder),
/// Used for bodies of a known size (given by the Content-Length header
/// field, or zero for status codes 1xx, 204 and 304).
FixedSize(FixedSizeBodyDecoder),
/// Used for bodies transferred using the chunked transfer encoding.
Chunked(ChunkedBodyDecoder),
}
impl ResponseBodyDecoder {
    /// Select a body decoder for a response with a given header.
    ///
    /// The rules are evaluated in this order:
    ///
    /// 1. Responses with status codes 1xx, 204 and 304 get a fixed-size
    ///    decoder of length zero.
    /// 2. If a `Transfer-Encoding` header field is present and `chunked` is
    ///    its final transfer coding, the chunked decoder is used (with the
    ///    given `max_line_length`). Any other `Transfer-Encoding` value
    ///    selects the simple decoder.
    /// 3. If a `Content-Length` header field is present, a fixed-size
    ///    decoder of that length is used.
    /// 4. Otherwise the simple decoder is used.
    ///
    /// # Errors
    ///
    /// An error is returned if the `Content-Length` header field has no
    /// value or if the value cannot be parsed.
    fn new(header: &ResponseHeader, max_line_length: Option<usize>) -> Result<Self, Error> {
        let status = header.status().code();

        let decoder = if (100..200).contains(&status) || status == 204 || status == 304 {
            Self::FixedSize(FixedSizeBodyDecoder::new(0))
        } else if let Some(tenc) = header.get_header_field("transfer-encoding") {
            let tenc = tenc.value().map(|v| v.as_ref()).unwrap_or(b"");

            // The value is a comma separated list of transfer codings (e.g.
            // "gzip, chunked"). The message is chunk-framed whenever
            // "chunked" is the last coding in the list, not only when it is
            // the sole value.
            if Self::is_chunked_final(tenc) {
                Self::Chunked(ChunkedBodyDecoder::new(max_line_length))
            } else {
                Self::Simple(SimpleBodyDecoder::new())
            }
        } else if let Some(clength) = header.get_header_field("content-length") {
            let clength = clength
                .value()
                .ok_or_else(|| Error::from_static_msg("missing Content-Length value"))?
                .parse()
                .map_err(|_| Error::from_static_msg("invalid Content-Length value"))?;

            Self::FixedSize(FixedSizeBodyDecoder::new(clength))
        } else {
            Self::Simple(SimpleBodyDecoder::new())
        };

        Ok(decoder)
    }

    /// Check if `chunked` is the final transfer coding of a given
    /// `Transfer-Encoding` header field value.
    ///
    /// The value is treated as a comma separated list. Spaces and horizontal
    /// tabs around the list elements are ignored, empty list elements are
    /// skipped and the coding name is compared case-insensitively.
    fn is_chunked_final(value: &[u8]) -> bool {
        /// Strip leading and trailing spaces and horizontal tabs.
        fn trim_ows(token: &[u8]) -> &[u8] {
            let is_ows = |b: &u8| matches!(*b, b' ' | b'\t');

            let start = token
                .iter()
                .position(|b| !is_ows(b))
                .unwrap_or(token.len());

            let end = token
                .iter()
                .rposition(|b| !is_ows(b))
                .map_or(start, |idx| idx + 1);

            &token[start..end]
        }

        value
            .rsplit(|&b| b == b',')
            .map(trim_ows)
            .find(|coding| !coding.is_empty())
            .is_some_and(|coding| coding.eq_ignore_ascii_case(b"chunked"))
    }
}
impl MessageBodyDecoder for ResponseBodyDecoder {
    /// Ask the wrapped decoder whether the body is complete.
    fn is_complete(&self) -> bool {
        match self {
            Self::Chunked(decoder) => decoder.is_complete(),
            Self::FixedSize(decoder) => decoder.is_complete(),
            Self::Simple(decoder) => decoder.is_complete(),
        }
    }

    /// Forward the decoding of `data` to the wrapped decoder.
    fn decode(&mut self, data: &mut BytesMut) -> Result<Option<Bytes>, BaseError> {
        match self {
            Self::Chunked(decoder) => decoder.decode(data),
            Self::FixedSize(decoder) => decoder.decode(data),
            Self::Simple(decoder) => decoder.decode(data),
        }
    }

    /// Forward the end-of-stream decoding of `data` to the wrapped decoder.
    fn decode_eof(&mut self, data: &mut BytesMut) -> Result<Option<Bytes>, BaseError> {
        match self {
            Self::Chunked(decoder) => decoder.decode_eof(data),
            Self::FixedSize(decoder) => decoder.decode_eof(data),
            Self::Simple(decoder) => decoder.decode_eof(data),
        }
    }
}
impl Decoder for ResponseBodyDecoder {
type Item = Bytes;
type Error = CodecError;
fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
MessageBodyDecoder::decode(self, buf).map_err(CodecError::Other)
}
fn decode_eof(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
MessageBodyDecoder::decode_eof(self, buf).map_err(CodecError::Other)
}
}
/// Reader decoding a response body from a connection and forwarding the
/// decoded chunks into the body stream given to the response consumer.
struct ResponseBodyReader<IO> {
/// The connection wrapped with the body decoder; yields decoded chunks.
source: FramedRead<ConnectionReader<IO>, ResponseBodyDecoder>,
/// Sending half of the channel feeding the response body stream.
sink: InternalBodyStreamSender,
/// Becomes ready once the response body stream has been dropped.
body_drop: BodyDropFuture,
}
impl<IO> ResponseBodyReader<IO>
where
    IO: AsyncRead,
{
    /// Create a new body reader for a given connection and body decoder.
    ///
    /// The reader is returned together with the response body it feeds. The
    /// two are connected by a channel with a buffer size of 4.
    fn new(connection: ConnectionReader<IO>, decoder: ResponseBodyDecoder) -> (Self, Body) {
        let (sink, chunks) = mpsc::channel(4);

        let (stream, body_drop) = ResponseBodyStream::new(chunks);

        let source = FramedRead::new(connection, decoder);

        let reader = Self {
            source,
            sink,
            body_drop,
        };

        (reader, Body::from_stream(stream))
    }
}
impl<IO> ResponseBodyReader<IO> {
    /// Take the underlying connection reader.
    ///
    /// The connection is returned only if the body decoder reports that the
    /// body is complete. In that case, any data left in the read buffer is
    /// put back in front of the connection reader. Otherwise `None` is
    /// returned.
    fn take_connection(mut self) -> Option<ConnectionReader<IO>> {
        if self.source.decoder().is_complete() {
            let leftover = self.source.read_buffer_mut().split().freeze();

            Some(self.source.into_inner().prepend(leftover))
        } else {
            None
        }
    }
}
impl<IO> ResponseBodyReader<IO>
where
    IO: AsyncRead + Send + 'static,
{
    /// Spawn a tokio task driving this reader.
    ///
    /// The task forwards all body chunks (including a possible error) into
    /// the body channel. Its output is `None` if a chunk was an error.
    /// Otherwise, once the reader is exhausted or the channel rejects a
    /// chunk, the output is the result of `take_connection`.
    fn spawn(mut self) -> JoinHandle<Option<ConnectionReader<IO>>> {
        tokio::spawn(async move {
            loop {
                let Some(chunk) = self.next().await else {
                    break;
                };

                let failed = chunk.is_err();

                // the error (if any) is still forwarded to the body consumer
                let delivered = self.sink.send(chunk).await.is_ok();

                if failed {
                    return None;
                }

                if !delivered {
                    break;
                }
            }

            self.take_connection()
        })
    }
}
impl<IO> Stream for ResponseBodyReader<IO>
where
    IO: AsyncRead,
{
    type Item = io::Result<Bytes>;

    /// Poll the next body chunk.
    ///
    /// Codec errors are converted into `io::Error`. If the underlying source
    /// is pending, the stream still ends when the body decoder reports
    /// completion or when the body drop future becomes ready.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let polled = self.source.poll_next_unpin(cx);

        match polled {
            Poll::Ready(Some(Ok(chunk))) => Poll::Ready(Some(Ok(chunk))),
            Poll::Ready(Some(Err(CodecError::IO(err)))) => Poll::Ready(Some(Err(err))),
            Poll::Ready(Some(Err(CodecError::Other(err)))) => {
                Poll::Ready(Some(Err(io::Error::other(err))))
            }
            Poll::Ready(None) => Poll::Ready(None),
            Poll::Pending => {
                // the drop future is polled only if the body is incomplete
                let finished = self.source.decoder().is_complete()
                    || self.body_drop.poll_unpin(cx).is_ready();

                if finished {
                    Poll::Ready(None)
                } else {
                    Poll::Pending
                }
            }
        }
    }
}
/// Item passed through the body channel: a body chunk or an I/O error.
type BodyStreamItem = io::Result<Bytes>;
/// Receiving half of the body channel.
type InternalBodyStream = mpsc::Receiver<BodyStreamItem>;
/// Sending half of the body channel.
type InternalBodyStreamSender = mpsc::Sender<BodyStreamItem>;
/// Receiving half of the one-shot channel signalling the body stream drop.
type BodyDropFuture = oneshot::Receiver<()>;
/// Sending half of the one-shot channel signalling the body stream drop.
type BodyDropEventSender = oneshot::Sender<()>;
/// Response body stream that sends a notification over a one-shot channel
/// when it is dropped.
struct ResponseBodyStream {
/// Receiver of the body chunks.
inner: InternalBodyStream,
/// Sender used for the drop notification; taken when the stream is dropped.
drop: Option<BodyDropEventSender>,
}
impl ResponseBodyStream {
    /// Create a new body stream on top of a given chunk receiver.
    ///
    /// The stream is returned together with the receiving half of a one-shot
    /// channel that gets notified when the stream is dropped.
    fn new(receiver: InternalBodyStream) -> (Self, oneshot::Receiver<()>) {
        let (notify, dropped) = oneshot::channel();

        (
            Self {
                inner: receiver,
                drop: Some(notify),
            },
            dropped,
        )
    }
}
impl Drop for ResponseBodyStream {
    /// Send the drop notification (if it has not been sent yet).
    fn drop(&mut self) {
        let Some(notify) = self.drop.take() else {
            return;
        };

        // the send result is deliberately ignored
        let _ = notify.send(());
    }
}
impl Stream for ResponseBodyStream {
    type Item = io::Result<Bytes>;

    /// Poll the next item from the underlying chunk receiver.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let receiver = Pin::new(&mut self.inner);

        receiver.poll_next(cx)
    }
}