use std::{
collections::HashSet,
marker::PhantomData,
option::Option,
result::Result,
sync::Arc,
task::{Context, Poll},
};
use bytes::{Buf, BytesMut};
use futures_util::{
future::{self, Future},
ready,
};
use http::{response, HeaderMap, Request, Response};
use pin_project_lite::pin_project;
use quic::RecvStream;
use quic::StreamId;
use tokio::sync::mpsc;
use crate::{
config::Config,
connection::{self, ConnectionInner, ConnectionState, SharedStateRef},
error::{Code, Error, ErrorLevel},
ext::Datagram,
frame::{FrameStream, FrameStreamError},
proto::{
frame::{Frame, PayloadLen},
headers::Header,
push::PushId,
},
qpack,
quic::{self, RecvDatagramExt, SendDatagramExt, SendStream as _},
request::ResolveRequest,
stream::{self, BufRecvStream},
};
use tracing::{error, trace, warn};
/// Create a [`Builder`] for configuring a server [`Connection`].
///
/// The builder starts from the default [`Config`]; adjust it with the setter
/// methods and finish with [`Builder::build`].
pub fn builder() -> Builder {
    Builder::new()
}
/// Server side of an HTTP/3 connection.
///
/// Accepts request streams from the peer, tracks which accepted requests are
/// still in flight, and coordinates graceful shutdown in both directions.
pub struct Connection<C: quic::Connection<B>, B: Buf> {
    /// Underlying connection state: used to accept request streams, read the
    /// control stream, close the connection and send shutdown notices.
    pub inner: ConnectionInner<C, B>,
    /// Limit on the size of a request header section, taken from the builder
    /// config and applied when decoding request headers.
    max_field_section_size: u64,
    /// IDs of accepted request streams that have not reported completion yet.
    ongoing_streams: HashSet<StreamId>,
    /// Receives the IDs of request streams whose completion guard was dropped.
    request_end_recv: mpsc::UnboundedReceiver<StreamId>,
    /// Sender handed (cloned) to the completion guard of every request stream.
    request_end_send: mpsc::UnboundedSender<StreamId>,
    /// Highest stream ID still accepted once this side began shutting down;
    /// `None` until a shutdown was initiated.
    sent_closing: Option<StreamId>,
    /// Set once a GOAWAY from the peer has been processed.
    recv_closing: Option<PushId>,
    /// ID of the most recently accepted request stream, if any.
    last_accepted_stream: Option<StreamId>,
}
impl<C: quic::Connection<B>, B: Buf> ConnectionState for Connection<C, B> {
    /// Return the state this connection shares with the request streams it creates.
    fn shared_state(&self) -> &SharedStateRef {
        let inner = &self.inner;
        &inner.shared
    }
}
impl<C, B> Connection<C, B>
where
C: quic::Connection<B>,
B: Buf,
{
pub async fn new(conn: C) -> Result<Self, Error> {
builder().build(conn).await
}
pub fn close<T: AsRef<str>>(&mut self, code: Code, reason: T) -> Error {
self.inner.close(code, reason)
}
}
impl<C, B> Connection<C, B>
where
    C: quic::Connection<B>,
    B: Buf,
{
    /// Accept the next incoming request.
    ///
    /// Waits for a new request stream, reads its first frame, decodes the
    /// request headers, and resolves them into a `(Request, RequestStream)` pair.
    ///
    /// Returns `Ok(None)` when no request is produced:
    /// * [`Connection::poll_accept_request`] reports that no further streams will
    ///   be handed out; a shutdown with `max_requests = 0` is initiated first.
    /// * The connection is already closed.
    /// * [`Connection::accept_with_frame`] returned `Ok(None)`.
    ///
    /// # Errors
    /// Connection-level application errors close the connection with the
    /// carried code and reason before being returned. Other errors (including
    /// those from `accept_with_frame` and request resolution) are returned as is.
    pub async fn accept(
        &mut self,
    ) -> Result<Option<(Request<()>, RequestStream<C::BidiStream, B>)>, Error> {
        let mut stream = match future::poll_fn(|cx| self.poll_accept_request(cx)).await {
            Ok(Some(s)) => FrameStream::new(BufRecvStream::new(s)),
            Ok(None) => {
                // No further request streams: shut down without granting any
                // stream IDs beyond the last accepted one.
                self.shutdown(0).await?;
                return Ok(None);
            }
            Err(err) => {
                match err.inner.kind {
                    crate::error::Kind::Closed => return Ok(None),
                    crate::error::Kind::Application {
                        code,
                        reason,
                        level: ErrorLevel::ConnectionError,
                    } => {
                        return Err(self.inner.close(
                            code,
                            reason.unwrap_or_else(|| String::into_boxed_str(String::from(""))),
                        ))
                    }
                    _ => return Err(err),
                };
            }
        };

        // Read the first frame of the new request stream; it is validated below.
        let frame = future::poll_fn(|cx| stream.poll_next(cx)).await;
        let req = self.accept_with_frame(stream, frame)?;
        if let Some(req) = req {
            Ok(Some(req.resolve().await?))
        } else {
            Ok(None)
        }
    }

    /// Validate the first frame of an accepted request stream and prepare the
    /// request from it.
    ///
    /// `frame` is the result of reading the first frame from `stream`. On
    /// success, returns a [`ResolveRequest`] wrapping the new [`RequestStream`]
    /// and the decoded header fields. Returns `Ok(None)` if the error
    /// encountered means the connection is closed.
    ///
    /// A header section larger than `max_field_section_size` is not an error
    /// here: the request stream is still handed back, with the decoder's
    /// reported size in place of the headers. NOTE(review): presumably this lets
    /// the caller answer with an error response; confirm in `ResolveRequest`.
    ///
    /// The first successful header decode clears `send_grease_frame` on the
    /// inner connection, so only request streams created before that point are
    /// constructed with it enabled.
    ///
    /// # Errors
    /// * Stream ended before any frame: connection closed with `H3_REQUEST_INCOMPLETE`.
    /// * First frame is not HEADERS: connection closed with `H3_FRAME_UNEXPECTED`.
    /// * Connection-level application errors while reading the frame or decoding
    ///   the headers: connection closed with the carried code and reason.
    /// * Stream-level application errors: the stream is reset (frame read) or
    ///   stopped (header decode) with the carried code, and the error is returned.
    /// * Any other error is returned unchanged.
    pub fn accept_with_frame(
        &mut self,
        mut stream: FrameStream<C::BidiStream, B>,
        frame: Result<Option<Frame<PayloadLen>>, FrameStreamError>,
    ) -> Result<Option<ResolveRequest<C, B>>, Error> {
        let mut encoded = match frame {
            Ok(Some(Frame::Headers(h))) => h,
            // The stream finished before carrying any frame.
            Ok(None) => {
                return Err(self.inner.close(
                    Code::H3_REQUEST_INCOMPLETE,
                    "request stream closed before headers",
                ));
            }
            // A request must start with a HEADERS frame.
            Ok(Some(_)) => {
                return Err(self.inner.close(
                    Code::H3_FRAME_UNEXPECTED,
                    "first request frame is not headers",
                ));
            }
            Err(e) => {
                let err: Error = e.into();
                if err.is_closed() {
                    return Ok(None);
                }
                match err.inner.kind {
                    crate::error::Kind::Closed => return Ok(None),
                    crate::error::Kind::Application {
                        code,
                        reason,
                        level: ErrorLevel::ConnectionError,
                    } => {
                        return Err(self.inner.close(
                            code,
                            reason.unwrap_or_else(|| String::into_boxed_str(String::from(""))),
                        ))
                    }
                    crate::error::Kind::Application {
                        code,
                        reason: _,
                        level: ErrorLevel::StreamError,
                    } => {
                        // Only this request is affected: reset its stream.
                        stream.reset(code.into());
                        return Err(err);
                    }
                    _ => return Err(err),
                };
            }
        };

        // Wrap the stream before decoding so the completion guard is in place
        // and the stream can be stopped if decoding fails at stream level.
        let mut request_stream = RequestStream {
            request_end: Arc::new(RequestEnd {
                request_end: self.request_end_send.clone(),
                stream_id: stream.send_id(),
            }),
            inner: connection::RequestStream::new(
                stream,
                self.max_field_section_size,
                self.inner.shared.clone(),
                self.inner.send_grease_frame,
            ),
        };

        let decoded = match qpack::decode_stateless(&mut encoded, self.max_field_section_size) {
            // Oversized header section: keep the stream, pass the size along.
            Err(qpack::DecoderError::HeaderTooLong(cancel_size)) => Err(cancel_size),
            Ok(decoded) => {
                // Only streams created before the first decoded request get grease enabled.
                self.inner.send_grease_frame = false;
                Ok(decoded)
            }
            Err(e) => {
                let err: Error = e.into();
                if err.is_closed() {
                    return Ok(None);
                }
                match err.inner.kind {
                    crate::error::Kind::Closed => return Ok(None),
                    crate::error::Kind::Application {
                        code,
                        reason,
                        level: ErrorLevel::ConnectionError,
                    } => {
                        return Err(self.inner.close(
                            code,
                            reason.unwrap_or_else(|| String::into_boxed_str(String::from(""))),
                        ))
                    }
                    crate::error::Kind::Application {
                        code,
                        reason: _,
                        level: ErrorLevel::StreamError,
                    } => {
                        request_stream.stop_stream(code);
                        return Err(err);
                    }
                    _ => return Err(err),
                };
            }
        };

        Ok(Some(ResolveRequest::new(
            request_stream,
            decoded,
            self.max_field_section_size,
        )))
    }

    /// Start a graceful shutdown of the connection.
    ///
    /// The last stream ID still accepted is the last accepted request stream's
    /// ID advanced by `max_requests` (via `StreamId`'s `+ usize`). If nothing
    /// was accepted yet, it is `StreamId::FIRST_REQUEST`. The limit is passed to
    /// `ConnectionInner::shutdown` together with `sent_closing`. NOTE(review):
    /// presumably `sent_closing` is set there. Once it is set,
    /// [`Connection::poll_accept_request`] rejects streams whose ID exceeds it.
    pub async fn shutdown(&mut self, max_requests: usize) -> Result<(), Error> {
        let max_id = self
            .last_accepted_stream
            .map(|id| id + max_requests)
            .unwrap_or(StreamId::FIRST_REQUEST);

        self.inner.shutdown(&mut self.sent_closing, max_id).await
    }

    /// Poll for the next request stream opened by the peer.
    ///
    /// Each call first drains the control stream and any request-completion
    /// notifications, which also registers this task for wake-ups from them.
    ///
    /// Returns:
    /// * `Ready(Ok(Some(stream)))` for a newly accepted request stream. The
    ///   stream is recorded as the last accepted one and as ongoing.
    /// * `Ready(Ok(None))` once no more requests will be accepted and all
    ///   ongoing requests have completed. This happens when the inner
    ///   connection has no more streams, a peer GOAWAY was processed, or a
    ///   stream past the shutdown limit was rejected.
    /// * `Ready(Err(_))` for control-stream or accept errors.
    /// * `Pending` otherwise.
    ///
    /// Streams whose ID exceeds the limit recorded in `sent_closing` are
    /// rejected with `H3_REQUEST_REJECTED` in both directions.
    pub fn poll_accept_request(
        &mut self,
        cx: &mut Context<'_>,
    ) -> Poll<Result<Option<C::BidiStream>, Error>> {
        // Called for their side effects (processing frames / completions and
        // registering wakers); only control-stream errors are propagated.
        let _ = self.poll_control(cx)?;
        let _ = self.poll_requests_completion(cx);
        loop {
            match self.inner.poll_accept_request(cx) {
                Poll::Ready(Err(x)) => break Poll::Ready(Err(x)),
                Poll::Ready(Ok(None)) => {
                    // No more incoming streams: finish only once all ongoing
                    // requests are done; a completion notification wakes us.
                    if self.poll_requests_completion(cx).is_ready() {
                        break Poll::Ready(Ok(None));
                    } else {
                        break Poll::Pending;
                    }
                }
                Poll::Pending => {
                    // After a peer GOAWAY, report the end as soon as the
                    // connection has no ongoing requests left.
                    if self.recv_closing.is_some() && self.poll_requests_completion(cx).is_ready() {
                        break Poll::Ready(Ok(None));
                    } else {
                        return Poll::Pending;
                    }
                }
                Poll::Ready(Ok(Some(mut s))) => {
                    // While shutting down, reject streams beyond the announced
                    // limit. Later streams may still be within it, so keep looping.
                    if let Some(max_id) = self.sent_closing {
                        if s.send_id() > max_id {
                            s.stop_sending(Code::H3_REQUEST_REJECTED.value());
                            s.reset(Code::H3_REQUEST_REJECTED.value());
                            if self.poll_requests_completion(cx).is_ready() {
                                break Poll::Ready(Ok(None));
                            }
                            continue;
                        }
                    }
                    self.last_accepted_stream = Some(s.send_id());
                    self.ongoing_streams.insert(s.send_id());
                    break Poll::Ready(Ok(Some(s)));
                }
            };
        }
    }

    /// Process every control-stream frame that is currently available.
    ///
    /// Never returns `Ready(Ok(()))`. It returns `Pending` once no more frames
    /// are ready, or `Ready(Err(_))` as soon as handling a frame fails.
    pub(crate) fn poll_control(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
        while (self.poll_next_control(cx)?).is_ready() {}
        Poll::Pending
    }

    /// Read and handle a single frame from the peer's control stream.
    ///
    /// * SETTINGS: logged at trace level.
    /// * GOAWAY: passed to `process_goaway`, which may update `recv_closing`.
    /// * MAX_PUSH_ID / CANCEL_PUSH: ignored with a warning.
    /// * Any other frame: `H3_FRAME_UNEXPECTED` connection-level error.
    ///
    /// On success the frame is returned to the caller.
    pub(crate) fn poll_next_control(
        &mut self,
        cx: &mut Context<'_>,
    ) -> Poll<Result<Frame<PayloadLen>, Error>> {
        let frame = ready!(self.inner.poll_control(cx))?;

        match &frame {
            Frame::Settings(w) => trace!("Got settings > {:?}", w),
            &Frame::Goaway(id) => self.inner.process_goaway(&mut self.recv_closing, id)?,
            f @ Frame::MaxPushId(_) | f @ Frame::CancelPush(_) => {
                warn!("Control frame ignored {:?}", f);
            }
            frame => {
                return Poll::Ready(Err(Code::H3_FRAME_UNEXPECTED.with_reason(
                    format!("on server control stream: {:?}", frame),
                    ErrorLevel::ConnectionError,
                )))
            }
        }
        Poll::Ready(Ok(frame))
    }

    /// Drain request-completion notifications and report whether every
    /// accepted request has finished.
    ///
    /// Each received stream ID is removed from `ongoing_streams`. Returns
    /// `Ready(())` when the set is empty or the channel is closed, and
    /// `Pending` otherwise. In the `Pending` case the channel has registered
    /// the task, so a later completion wakes it. The connection holds its own
    /// sender (`request_end_send`), so the closed-channel case is not expected
    /// while `self` is alive.
    fn poll_requests_completion(&mut self, cx: &mut Context<'_>) -> Poll<()> {
        loop {
            match self.request_end_recv.poll_recv(cx) {
                // Channel closed: no further completions can arrive.
                Poll::Ready(None) => return Poll::Ready(()),
                // A request finished; stop tracking it.
                Poll::Ready(Some(id)) => {
                    self.ongoing_streams.remove(&id);
                }
                Poll::Pending => {
                    if self.ongoing_streams.is_empty() {
                        return Poll::Ready(());
                    } else {
                        return Poll::Pending;
                    }
                }
            }
        }
    }
}
impl<C: quic::Connection<B> + SendDatagramExt<B>, B: Buf> Connection<C, B> {
    /// Wrap `data` in a [`Datagram`] associated with `stream_id` and send it
    /// over the underlying transport.
    ///
    /// # Errors
    /// Propagates any error reported by the transport's `send_datagram`.
    pub fn send_datagram(&mut self, stream_id: StreamId, data: B) -> Result<(), Error> {
        let datagram = Datagram::new(stream_id, data);
        let transport = &mut self.inner.conn;
        transport.send_datagram(datagram)?;
        Ok(())
    }
}
impl<C: quic::Connection<B> + RecvDatagramExt, B: Buf> Connection<C, B> {
    /// Return a future that resolves to the next datagram received on this
    /// connection, or `None` when the transport yields no datagram.
    ///
    /// The future borrows the connection mutably until it completes.
    pub fn read_datagram(&mut self) -> ReadDatagram<C, B> {
        ReadDatagram {
            _marker: PhantomData,
            conn: self,
        }
    }
}
impl<C: quic::Connection<B>, B: Buf> Drop for Connection<C, B> {
    /// Close the connection with `H3_NO_ERROR` when the handle is dropped.
    fn drop(&mut self) {
        // `close` returns an `Error` describing the closure; there is nobody
        // to hand it to during drop, so it is discarded explicitly.
        let _ = self.inner.close(Code::H3_NO_ERROR, "");
    }
}
/// Builder for server [`Connection`]s.
///
/// Holds the [`Config`] that [`Builder::build`] applies; obtain one with
/// [`builder()`].
pub struct Builder {
    /// Settings accumulated by the setter methods.
    pub(crate) config: Config,
}
impl Builder {
pub(super) fn new() -> Self {
Builder {
config: Default::default(),
}
}
pub fn max_field_section_size(&mut self, value: u64) -> &mut Self {
self.config.max_field_section_size = value;
self
}
#[inline]
pub fn send_grease(&mut self, value: bool) -> &mut Self {
self.config.send_grease = value;
self
}
#[inline]
pub fn enable_webtransport(&mut self, value: bool) -> &mut Self {
self.config.enable_webtransport = value;
self
}
pub fn enable_connect(&mut self, value: bool) -> &mut Self {
self.config.enable_extended_connect = value;
self
}
pub fn max_webtransport_sessions(&mut self, value: u64) -> &mut Self {
self.config.max_webtransport_sessions = value;
self
}
pub fn enable_datagram(&mut self, value: bool) -> &mut Self {
self.config.enable_datagram = value;
self
}
}
impl Builder {
    /// Set up a server [`Connection`] over `conn` using this builder's config.
    ///
    /// Creates the channel on which request streams report their completion
    /// and initializes the underlying [`ConnectionInner`].
    ///
    /// # Errors
    /// Returns any error raised by `ConnectionInner::new`.
    pub async fn build<C, B>(&self, conn: C) -> Result<Connection<C, B>, Error>
    where
        C: quic::Connection<B>,
        B: Buf,
    {
        let (request_end_send, request_end_recv) = mpsc::unbounded_channel();
        let inner = ConnectionInner::new(conn, SharedStateRef::default(), self.config).await?;

        Ok(Connection {
            inner,
            max_field_section_size: self.config.max_field_section_size,
            ongoing_streams: HashSet::new(),
            request_end_recv,
            request_end_send,
            sent_closing: None,
            recv_closing: None,
            last_accepted_stream: None,
        })
    }
}
/// Completion guard for one accepted request.
///
/// When dropped, it sends `stream_id` through `request_end` so the owning
/// [`Connection`] can stop tracking the request as ongoing.
struct RequestEnd {
    /// Sender half of the connection's request-completion channel.
    request_end: mpsc::UnboundedSender<StreamId>,
    /// ID of the request stream this guard belongs to.
    stream_id: StreamId,
}
/// A request stream accepted by a server [`Connection`].
///
/// Used to read the request body and trailers and to send the response.
///
/// NOTE: field order matters. Fields drop in declaration order, so `inner` is
/// dropped before `request_end`. When this value holds the last reference to
/// the guard, the connection is therefore notified only after the stream
/// itself has been dropped.
pub struct RequestStream<S, B> {
    /// Wrapped request stream implementation.
    inner: connection::RequestStream<S, B>,
    /// Shared completion guard; it fires when the last clone is dropped.
    request_end: Arc<RequestEnd>,
}
impl<S, B> AsMut<connection::RequestStream<S, B>> for RequestStream<S, B> {
    /// Mutably borrow the wrapped request stream.
    fn as_mut(&mut self) -> &mut connection::RequestStream<S, B> {
        let Self { inner, .. } = self;
        inner
    }
}
impl<S, B> ConnectionState for RequestStream<S, B> {
    /// Return the connection state held by the wrapped request stream.
    fn shared_state(&self) -> &SharedStateRef {
        let Self { inner, .. } = self;
        &inner.conn_state
    }
}
impl<S: quic::RecvStream, B: Buf> RequestStream<S, B> {
    /// Receive the next chunk of request body data.
    ///
    /// Delegates to the wrapped stream's `recv_data`. NOTE(review): `Ok(None)`
    /// presumably marks the end of the body; confirm there.
    pub async fn recv_data(&mut self) -> Result<Option<impl Buf>, Error> {
        let stream = &mut self.inner;
        stream.recv_data().await
    }

    /// Receive the request trailers, if the wrapped stream yields any.
    pub async fn recv_trailers(&mut self) -> Result<Option<HeaderMap>, Error> {
        let stream = &mut self.inner;
        stream.recv_trailers().await
    }

    /// Forward a stop-sending request carrying `error_code` to the underlying stream.
    pub fn stop_sending(&mut self, error_code: crate::error::Code) {
        let frames = &mut self.inner.stream;
        frames.stop_sending(error_code)
    }

    /// ID of the underlying stream.
    pub fn id(&self) -> StreamId {
        let frames = &self.inner.stream;
        frames.id()
    }
}
impl<S, B> RequestStream<S, B>
where
S: quic::SendStream<B>,
B: Buf,
{
pub async fn send_response(&mut self, resp: Response<()>) -> Result<(), Error> {
let (parts, _) = resp.into_parts();
let response::Parts {
status, headers, ..
} = parts;
let headers = Header::response(status, headers);
let mut block = BytesMut::new();
let mem_size = qpack::encode_stateless(&mut block, headers)?;
let max_mem_size = self
.inner
.conn_state
.read("send_response")
.peer_config
.max_field_section_size;
if mem_size > max_mem_size {
return Err(Error::header_too_big(mem_size, max_mem_size));
}
stream::write(&mut self.inner.stream, Frame::Headers(block.freeze()))
.await
.map_err(|e| self.maybe_conn_err(e))?;
Ok(())
}
pub async fn send_data(&mut self, buf: B) -> Result<(), Error> {
self.inner.send_data(buf).await
}
pub fn stop_stream(&mut self, error_code: Code) {
self.inner.stop_stream(error_code);
}
pub async fn send_trailers(&mut self, trailers: HeaderMap) -> Result<(), Error> {
self.inner.send_trailers(trailers).await
}
pub async fn finish(&mut self) -> Result<(), Error> {
self.inner.finish().await
}
pub fn send_id(&self) -> StreamId {
self.inner.stream.send_id()
}
}
impl<S, B> RequestStream<S, B>
where
S: quic::BidiStream<B>,
B: Buf,
{
pub fn split(
self,
) -> (
RequestStream<S::SendStream, B>,
RequestStream<S::RecvStream, B>,
) {
let (send, recv) = self.inner.split();
(
RequestStream {
inner: send,
request_end: self.request_end.clone(),
},
RequestStream {
inner: recv,
request_end: self.request_end,
},
)
}
}
impl Drop for RequestEnd {
    /// Send this request's stream ID to the owning connection.
    ///
    /// If the connection's receiver is gone, the failure is only logged.
    fn drop(&mut self) {
        let stream_id = self.stream_id;
        match self.request_end.send(stream_id) {
            Ok(_) => {}
            Err(e) => {
                error!(
                    "failed to notify connection of request end: {} {}",
                    stream_id, e
                );
            }
        }
    }
}
pin_project! {
    /// Future returned by [`Connection::read_datagram`].
    ///
    /// Resolves to the next decoded datagram, or `None` when the transport's
    /// `poll_accept_datagram` yields none.
    pub struct ReadDatagram<'a, C, B>
    where
        C: quic::Connection<B>,
        B: Buf,
    {
        // Connection the datagram is read from; borrowed for the future's lifetime.
        conn: &'a mut Connection<C, B>,
        // Marker for the `B` type parameter, which no other field uses directly.
        _marker: PhantomData<B>,
    }
}
impl<'a, C, B> Future for ReadDatagram<'a, C, B>
where
    C: quic::Connection<B> + RecvDatagramExt,
    B: Buf,
{
    type Output = Result<Option<Datagram<C::Buf>>, Error>;

    /// Poll the transport for a datagram and decode it if one arrived.
    ///
    /// Transport and decoding errors are both converted into [`Error`].
    fn poll(mut self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        tracing::trace!("poll: read_datagram");
        let accepted = ready!(self.conn.inner.conn.poll_accept_datagram(cx))?;
        let decoded = accepted.map(Datagram::decode).transpose()?;
        Poll::Ready(Ok(decoded))
    }
}