use bytes::Buf;
use futures_util::future;
use http::{HeaderMap, Response};
use quic::StreamId;
#[cfg(feature = "tracing")]
use tracing::instrument;
use crate::{
connection::{self},
error::{
connection_error_creators::{CloseStream, HandleFrameStreamErrorOnRequestStream},
internal_error::InternalConnectionError,
Code, StreamError,
},
proto::{frame::Frame, headers::Header},
qpack,
quic::{self},
shared_state::{ConnectionState, SharedState},
};
use std::{
convert::TryFrom,
task::{Context, Poll},
};
/// Client-side handle to a single HTTP/3 request stream.
///
/// Wraps the connection-level `connection::RequestStream`, which holds the underlying QUIC
/// stream (`inner.stream`) and gives access to the connection's shared state
/// (`inner.conn_state`). The operations available depend on the bounds of `S`:
/// receiving the response when `S: quic::RecvStream`, sending the request body when
/// `S: quic::SendStream<B>`, and splitting into send/receive halves when
/// `S: quic::BidiStream<B>`.
pub struct RequestStream<S, B> {
    // `pub(super)` so the parent module can reach the connection-level stream directly.
    pub(super) inner: connection::RequestStream<S, B>,
}
impl<S, B> ConnectionState for RequestStream<S, B> {
    /// Expose the connection-wide shared state held by the wrapped connection-level stream.
    fn shared_state(&self) -> &SharedState {
        let connection_stream = &self.inner;
        &connection_stream.conn_state
    }
}
// Opt this stream into the `CloseStream` helpers. The empty body means only the trait's
// provided (default) methods are used. `recv_response` calls
// `handle_connection_error_on_stream` on `self`, which presumably comes from this trait
// (TODO confirm).
impl<S, B> CloseStream for RequestStream<S, B> {}
impl<S, B> RequestStream<S, B>
where
    S: quic::RecvStream,
{
    /// Receive the response head (status line and headers) for this request.
    ///
    /// Waits for the first frame on the stream, which must be a `HEADERS` frame. It then
    /// QPACK-decodes the frame and builds an `http::Response<()>` whose version is set to
    /// `HTTP/3`. Read the response body and trailers afterwards with
    /// [`RequestStream::recv_data`] and [`RequestStream::recv_trailers`].
    ///
    /// # Errors
    ///
    /// - Errors while reading the next frame are mapped through
    ///   `handle_frame_stream_error_on_request_stream`.
    /// - If the stream finishes before any frame arrives, or the first frame is not `HEADERS`,
    ///   a connection error with code `H3_FRAME_UNEXPECTED` is raised.
    /// - If the header block exceeds the configured `max_field_section_size`, `STOP_SENDING`
    ///   with `H3_REQUEST_CANCELLED` is issued and [`StreamError::HeaderTooBig`] is returned.
    /// - Any other QPACK decoding failure raises a connection error with code
    ///   `QPACK_DECOMPRESSION_FAILED`.
    /// - If the decoded fields do not form a valid response, `STOP_SENDING` with
    ///   `H3_MESSAGE_ERROR` is issued and a [`StreamError::StreamError`] carrying that same
    ///   code is returned.
    #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
    pub async fn recv_response(&mut self) -> Result<Response<()>, StreamError> {
        let mut frame = future::poll_fn(|cx| self.inner.stream.poll_next(cx))
            .await
            .map_err(|e| self.handle_frame_stream_error_on_request_stream(e))?
            .ok_or_else(|| {
                self.handle_connection_error_on_stream(InternalConnectionError::new(
                    Code::H3_FRAME_UNEXPECTED,
                    "Stream finished without receiving response headers".to_string(),
                ))
            })?;

        // A response must start with a HEADERS frame. Any other first frame is escalated
        // to a connection error.
        let decoded = if let Frame::Headers(ref mut encoded) = frame {
            match qpack::decode_stateless(encoded, self.inner.max_field_section_size) {
                // The header block is larger than we accept. Treat this as a stream-level
                // failure and stop the peer from sending more on this stream.
                Err(qpack::DecoderError::HeaderTooLong(cancel_size)) => {
                    self.inner.stop_sending(Code::H3_REQUEST_CANCELLED);
                    return Err(StreamError::HeaderTooBig {
                        actual_size: cancel_size,
                        max_size: self.inner.max_field_section_size,
                    });
                }
                Ok(decoded) => decoded,
                Err(_e) => {
                    return Err(
                        self.handle_connection_error_on_stream(InternalConnectionError {
                            code: Code::QPACK_DECOMPRESSION_FAILED,
                            message: "Failed to decode headers".to_string(),
                        }),
                    )
                }
            }
        } else {
            return Err(
                self.handle_connection_error_on_stream(InternalConnectionError::new(
                    Code::H3_FRAME_UNEXPECTED,
                    "First response frame is not headers".to_string(),
                )),
            );
        };

        let qpack::Decoded { fields, .. } = decoded;

        // Both the field-to-header conversion and the response-parts extraction fail for the
        // same reason (a malformed response), so they share a single error path.
        let (status, headers) = match Header::try_from(fields)
            .ok()
            .and_then(|header| header.into_response_parts().ok())
        {
            Some(parts) => parts,
            None => {
                // RFC 9114 §4.1.2: a malformed response MUST be treated as a stream error of
                // type H3_MESSAGE_ERROR, so signal the peer with the same code we report
                // locally rather than H3_REQUEST_CANCELLED.
                self.inner.stream.stop_sending(Code::H3_MESSAGE_ERROR);
                return Err(StreamError::StreamError {
                    code: Code::H3_MESSAGE_ERROR,
                    reason: "Received malformed header".to_string(),
                });
            }
        };

        let mut resp = Response::new(());
        *resp.status_mut() = status;
        *resp.headers_mut() = headers;
        *resp.version_mut() = http::Version::HTTP_3;
        Ok(resp)
    }

    /// Receive the next chunk of response body data.
    ///
    /// Async wrapper around [`RequestStream::poll_recv_data`]. `Ok(None)` is presumably
    /// returned once no more body data is available (TODO confirm against
    /// `connection::RequestStream::poll_recv_data`).
    #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
    pub async fn recv_data(&mut self) -> Result<Option<impl Buf>, StreamError> {
        future::poll_fn(|cx| self.poll_recv_data(cx)).await
    }

    /// Poll for the next chunk of response body data.
    ///
    /// Delegates directly to the wrapped connection-level request stream.
    pub fn poll_recv_data(
        &mut self,
        cx: &mut Context<'_>,
    ) -> Poll<Result<Option<impl Buf>, StreamError>> {
        self.inner.poll_recv_data(cx)
    }

    /// Receive the response trailers, if any.
    ///
    /// Async wrapper around [`RequestStream::poll_recv_trailers`].
    #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
    pub async fn recv_trailers(&mut self) -> Result<Option<HeaderMap>, StreamError> {
        future::poll_fn(|cx| self.poll_recv_trailers(cx)).await
    }

    /// Poll for the response trailers.
    ///
    /// Delegates to the wrapped connection-level request stream. If that reports
    /// [`StreamError::HeaderTooBig`], `STOP_SENDING` with `H3_REQUEST_CANCELLED` is also
    /// issued on the QUIC stream before the error is returned unchanged.
    #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
    pub fn poll_recv_trailers(
        &mut self,
        cx: &mut Context<'_>,
    ) -> Poll<Result<Option<HeaderMap>, StreamError>> {
        let res = self.inner.poll_recv_trailers(cx);
        // An oversized trailer block is rejected locally, so ask the peer to stop sending
        // on this stream.
        if matches!(res, Poll::Ready(Err(StreamError::HeaderTooBig { .. }))) {
            self.inner.stream.stop_sending(Code::H3_REQUEST_CANCELLED);
        }
        res
    }

    /// Ask the peer to stop sending on this stream, using `error_code` as the reason.
    #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
    pub fn stop_sending(&mut self, error_code: Code) {
        self.inner.stream.stop_sending(error_code)
    }

    /// The QUIC stream id of this request stream.
    pub fn id(&self) -> StreamId {
        self.inner.stream.id()
    }
}
impl<S, B> RequestStream<S, B>
where
    B: Buf,
    S: quic::SendStream<B>,
{
    /// Send one chunk of request body data.
    ///
    /// Forwards `buf` unchanged to the wrapped connection-level request stream.
    #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
    pub async fn send_data(&mut self, buf: B) -> Result<(), StreamError> {
        let connection_stream = &mut self.inner;
        connection_stream.send_data(buf).await
    }

    /// Stop (abort) the sending side of this stream with `error_code`.
    ///
    /// Forwards to the wrapped connection-level request stream.
    pub fn stop_stream(&mut self, error_code: Code) {
        let connection_stream = &mut self.inner;
        connection_stream.stop_stream(error_code);
    }

    /// Send the request trailers.
    ///
    /// Forwards `trailers` unchanged to the wrapped connection-level request stream.
    #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
    pub async fn send_trailers(&mut self, trailers: HeaderMap) -> Result<(), StreamError> {
        let connection_stream = &mut self.inner;
        connection_stream.send_trailers(trailers).await
    }

    /// Finish the sending side of this request stream.
    ///
    /// Forwards to the wrapped connection-level request stream.
    #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
    pub async fn finish(&mut self) -> Result<(), StreamError> {
        let connection_stream = &mut self.inner;
        connection_stream.finish().await
    }
}
impl<S, B> RequestStream<S, B>
where
    B: Buf,
    S: quic::BidiStream<B>,
{
    /// Split a bidirectional request stream into independent send and receive halves.
    ///
    /// Consumes `self`. Returns `(send_half, recv_half)`, each wrapping the corresponding
    /// half produced by the connection-level stream's `split`.
    pub fn split(
        self,
    ) -> (
        RequestStream<S::SendStream, B>,
        RequestStream<S::RecvStream, B>,
    ) {
        let (send_half, recv_half) = self.inner.split();
        let sender = RequestStream { inner: send_half };
        let receiver = RequestStream { inner: recv_half };
        (sender, receiver)
    }
}