use std::{
marker::PhantomData,
sync::{atomic::AtomicUsize, Arc},
task::{Context, Poll},
};
use bytes::{Buf, BytesMut};
use futures_util::future;
use http::request;
#[cfg(feature = "tracing")]
use tracing::{info, instrument, trace};
use crate::{
connection::{self, ConnectionInner},
error::{
connection_error_creators::CloseStream, internal_error::InternalConnectionError, Code,
ConnectionError, StreamError,
},
frame::FrameStream,
proto::{frame::Frame, headers::Header, push::PushId},
qpack,
quic::{self, StreamId},
shared_state::{ConnectionState, SharedState},
stream::{self, BufRecvStream},
};
use super::stream::RequestStream;
/// Cloneable handle used by an HTTP/3 client to issue requests.
///
/// Every live handle is counted in `sender_count`; when the last one is
/// dropped the connection is told to shut down (see the `Drop` impl).
pub struct SendRequest<T, B>
where
    T: quic::OpenStreams<B>,
    B: Buf,
{
    /// Stream opener used to create one bidirectional stream per request.
    pub(super) open: T,
    /// Connection-wide state shared between all handles and request streams.
    pub(super) conn_state: Arc<SharedState>,
    /// Field-section size limit handed to every `connection::RequestStream`.
    /// NOTE(review): presumably the *local* limit for received headers (the
    /// peer's limit is read from `settings()` instead) — confirm.
    pub(super) max_field_section_size: u64,
    /// Number of live `SendRequest` handles for this connection.
    pub(super) sender_count: Arc<AtomicUsize>,
    /// Ties the buffer type `B` to the struct; `fn(B)` keeps the marker from
    /// influencing the struct's `Send`/`Sync` auto traits.
    pub(super) _buf: PhantomData<fn(B)>,
    /// Whether the next request stream should carry a grease frame; cleared
    /// after the first request is sent.
    pub(super) send_grease_frame: bool,
}
impl<T, B> ConnectionState for SendRequest<T, B>
where
T: quic::OpenStreams<B>,
B: Buf,
{
fn shared_state(&self) -> &SharedState {
&self.conn_state
}
}
// The impl body is empty: `SendRequest` relies entirely on the trait's
// provided (default) methods.
// NOTE(review): presumably this is where `handle_connection_error_on_stream`
// and `handle_quic_stream_error`, used by `send_request` and `Drop`, come from,
// built on top of `ConnectionState` — confirm.
impl<T, B> CloseStream for SendRequest<T, B>
where
    T: quic::OpenStreams<B>,
    B: Buf,
{
}
impl<T, B> SendRequest<T, B>
where
    T: quic::OpenStreams<B>,
    B: Buf,
{
    /// Sends the header section of `req` on a new bidirectional stream.
    ///
    /// Only the request head (method, URI, headers, extensions) is used; the
    /// `()` body is discarded. The returned [`RequestStream`] is presumably
    /// used to send any body and to receive the response.
    ///
    /// The header block is encoded with stateless QPACK and checked against
    /// the peer's `max_field_section_size` *before* a stream is opened. As a
    /// result, a rejected request never consumes a stream ID and never leaves
    /// the peer with an abandoned request stream.
    ///
    /// # Errors
    ///
    /// - The error from `check_peer_connection_closing` if the peer is closing.
    /// - An `H3_INTERNAL_ERROR` connection error if the headers cannot be built
    ///   or encoded.
    /// - [`StreamError::HeaderTooBig`] if the encoded headers exceed the peer's
    ///   advertised limit.
    /// - The mapped QUIC error if opening the stream or writing HEADERS fails.
    #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
    pub async fn send_request(
        &mut self,
        req: http::Request<()>,
    ) -> Result<RequestStream<T::BidiStream, B>, StreamError> {
        if let Some(error) = self.check_peer_connection_closing() {
            return Err(error);
        }

        let (parts, _) = req.into_parts();
        let request::Parts {
            method,
            uri,
            headers,
            extensions,
            ..
        } = parts;
        let headers = Header::request(method, uri, headers, extensions).map_err(|_e| {
            self.handle_connection_error_on_stream(InternalConnectionError {
                code: Code::H3_INTERNAL_ERROR,
                message: "Failed to build request headers".to_string(),
            })
        })?;

        // Encode and size-check first: stateless QPACK does not depend on the
        // stream, and failing here avoids opening a stream we would abandon.
        let mut block = BytesMut::new();
        let mem_size = qpack::encode_stateless(&mut block, headers).map_err(|_e| {
            self.handle_connection_error_on_stream(InternalConnectionError {
                code: Code::H3_INTERNAL_ERROR,
                message: "Failed to encode headers".to_string(),
            })
        })?;

        // RFC 9114 §4.2.2: do not send a field section larger than the size
        // the peer advertised in SETTINGS_MAX_FIELD_SECTION_SIZE.
        let peer_max_field_section_size = self.settings().max_field_section_size;
        if mem_size > peer_max_field_section_size {
            return Err(StreamError::HeaderTooBig {
                actual_size: mem_size,
                max_size: peer_max_field_section_size,
            });
        }

        let mut stream = future::poll_fn(|cx| self.open.poll_open_bidi(cx))
            .await
            .map_err(|e| self.handle_quic_stream_error(e))?;

        stream::write(&mut stream, Frame::Headers(block.freeze()))
            .await
            .map_err(|e| self.handle_quic_stream_error(e))?;

        let request_stream = RequestStream {
            inner: connection::RequestStream::new(
                FrameStream::new(BufRecvStream::new(stream)),
                self.max_field_section_size,
                self.conn_state.clone(),
                self.send_grease_frame,
            ),
        };
        // Only the first request stream from this handle carries grease.
        self.send_grease_frame = false;
        Ok(request_stream)
    }
}
impl<T, B> Clone for SendRequest<T, B>
where
T: quic::OpenStreams<B> + Clone,
B: Buf,
{
fn clone(&self) -> Self {
self.sender_count
.fetch_add(1, std::sync::atomic::Ordering::Release);
Self {
conn_state: self.conn_state.clone(),
open: self.open.clone(),
max_field_section_size: self.max_field_section_size,
sender_count: self.sender_count.clone(),
_buf: PhantomData,
send_grease_frame: self.send_grease_frame,
}
}
}
impl<T, B> Drop for SendRequest<T, B>
where
    T: quic::OpenStreams<B>,
    B: Buf,
{
    /// Releases this handle's slot in the shared sender count.
    ///
    /// If this was the last live handle (the count was 1 before the
    /// decrement), it reports an `H3_NO_ERROR` connection error with the
    /// message "Connection closed by client". This presumably closes the
    /// connection; any value returned by the handler is discarded.
    fn drop(&mut self) {
        use std::sync::atomic::Ordering;

        let senders_before = self.sender_count.fetch_sub(1, Ordering::AcqRel);
        if senders_before != 1 {
            // Other handles are still alive; leave the connection alone.
            return;
        }

        self.handle_connection_error_on_stream(InternalConnectionError::new(
            Code::H3_NO_ERROR,
            "Connection closed by client".to_string(),
        ));
    }
}
/// Client-side HTTP/3 connection driver.
///
/// Polling it (see `poll_close` / `wait_idle`) processes the peer's control
/// stream and rejects server-initiated bidirectional streams.
pub struct Connection<C, B>
where
    C: quic::Connection<B>,
    B: Buf,
{
    /// Shared HTTP/3 connection machinery (control stream, shared state, ...).
    pub inner: ConnectionInner<C, B>,
    /// Closing state handed to `ConnectionInner::shutdown` when this side
    /// shuts down. NOTE(review): presumably records the id sent in our own
    /// GOAWAY — confirm.
    pub(super) sent_closing: Option<PushId>,
    /// Closing state updated by `ConnectionInner::process_goaway` when the
    /// server sends GOAWAY. NOTE(review): presumably the last stream id the
    /// server will process — confirm.
    pub(super) recv_closing: Option<StreamId>,
}
impl<C, B> ConnectionState for Connection<C, B>
where
    C: quic::Connection<B>,
    B: Buf,
{
    /// Returns the state shared with this connection's request handles.
    fn shared_state(&self) -> &SharedState {
        let inner = &self.inner;
        &inner.shared
    }
}
impl<C, B> Connection<C, B>
where
    C: quic::Connection<B>,
    B: Buf,
{
    /// Starts a graceful shutdown of the connection.
    ///
    /// `_max_push` is currently ignored: `PushId(0)` is always passed to the
    /// inner shutdown. NOTE(review): presumably because this client does not
    /// support server push — confirm.
    #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
    pub async fn shutdown(&mut self, _max_push: usize) -> Result<(), ConnectionError> {
        let last_push = PushId(0);
        self.inner.shutdown(&mut self.sent_closing, last_push).await
    }

    /// Drives the connection until it closes, yielding the closing error.
    #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
    pub async fn wait_idle(&mut self) -> ConnectionError {
        future::poll_fn(|cx| self.poll_close(cx)).await
    }

    /// Polls the control stream and the bidirectional accept queue.
    ///
    /// Returns `Poll::Ready(err)` once the connection must close. Cases:
    /// - the control stream yields an error;
    /// - a GOAWAY names a non-request stream id (`H3_ID_ERROR`);
    /// - GOAWAY processing fails;
    /// - a frame other than SETTINGS/GOAWAY arrives (`H3_FRAME_UNEXPECTED`);
    /// - the server opens a bidirectional stream (`H3_STREAM_CREATION_ERROR`).
    ///
    /// Otherwise it returns `Poll::Pending`.
    #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
    pub fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll<ConnectionError> {
        loop {
            let polled = match self.inner.poll_control(cx) {
                Poll::Ready(polled) => polled,
                Poll::Pending => break,
            };
            let frame = match polled {
                Ok(frame) => frame,
                Err(connection_error) => return Poll::Ready(connection_error),
            };
            match frame {
                Frame::Settings(_) => {
                    #[cfg(feature = "tracing")]
                    trace!("Got settings");
                }
                Frame::Goaway(id) => {
                    // A GOAWAY sent to a client must carry a request stream id.
                    if !StreamId::from(id).is_request() {
                        let id_error = InternalConnectionError::new(
                            Code::H3_ID_ERROR,
                            format!("non-request StreamId in a GoAway frame: {}", id),
                        );
                        return Poll::Ready(self.inner.handle_connection_error(id_error));
                    }
                    if let Err(err) = self.inner.process_goaway(&mut self.recv_closing, id) {
                        return Poll::Ready(err);
                    }
                    #[cfg(feature = "tracing")]
                    info!("Server initiated graceful shutdown, last: StreamId({})", id);
                }
                unexpected => {
                    let frame_error = InternalConnectionError::new(
                        Code::H3_FRAME_UNEXPECTED,
                        format!("on client control stream: {:?}", unexpected),
                    );
                    return Poll::Ready(self.inner.handle_connection_error(frame_error));
                }
            }
        }

        // Any readiness on the bidirectional accept queue means the server
        // tried to open a request-style stream, which a client must reject.
        if self.inner.poll_accept_bi(cx).is_pending() {
            return Poll::Pending;
        }
        Poll::Ready(
            self.inner
                .handle_connection_error(InternalConnectionError::new(
                    Code::H3_STREAM_CREATION_ERROR,
                    "client received a server-initiated bidirectional stream".to_string(),
                )),
        )
    }
}