use std::{
collections::HashSet,
future::poll_fn,
option::Option,
result::Result,
task::{ready, Context, Poll},
};
use bytes::Buf;
use quic::RecvStream;
use quic::StreamId;
use tokio::sync::mpsc;
use crate::{
connection::ConnectionInner,
error::{internal_error::InternalConnectionError, Code, ConnectionError},
frame::FrameStream,
proto::{
frame::{Frame, PayloadLen},
push::PushId,
},
quic::{self, SendStream as _},
shared_state::{ConnectionState, SharedState},
stream::BufRecvStream,
};
#[cfg(feature = "tracing")]
use tracing::{instrument, trace, warn};
use super::request::RequestResolver;
/// Server-side HTTP/3 connection.
///
/// Hands out incoming request streams (see [`Connection::accept`]) and keeps track of which
/// accepted requests are still running, so that shutting down can wait for them to finish.
pub struct Connection<C: quic::Connection<B>, B: Buf> {
    /// Inner connection machinery: accepts bidirectional streams, reads the control stream,
    /// performs shutdown and closes the connection.
    pub inner: ConnectionInner<C, B>,
    /// Field-section size limit copied into every [`RequestResolver`] this connection creates.
    pub(super) max_field_section_size: u64,
    /// IDs of accepted request streams whose completion has not yet been received on
    /// `request_end_recv`.
    pub(super) ongoing_streams: HashSet<StreamId>,
    /// Receives IDs of finished request streams; each received ID is removed from
    /// `ongoing_streams`.
    pub(super) request_end_recv: mpsc::UnboundedReceiver<StreamId>,
    /// Sending half of the completion channel, cloned into every new [`RequestResolver`].
    pub(super) request_end_send: mpsc::UnboundedSender<StreamId>,
    /// Stream-ID limit used once this side starts closing (presumably filled in through
    /// `ConnectionInner::shutdown`, which receives it by `&mut`); accepted streams with a
    /// larger ID are rejected.
    pub(super) sent_closing: Option<StreamId>,
    /// Closing state received from the peer, updated by `ConnectionInner::process_goaway`.
    /// Once set, accepting stops as soon as every ongoing request has completed.
    pub(super) recv_closing: Option<PushId>,
    /// ID of the most recently accepted request stream, used to derive the shutdown limit.
    pub(super) last_accepted_stream: Option<StreamId>,
}
impl<C: quic::Connection<B>, B: Buf> ConnectionState for Connection<C, B> {
    /// Returns the state owned by the inner connection.
    ///
    /// Request resolvers created by this connection receive a clone of that same state.
    fn shared_state(&self) -> &SharedState {
        let inner = &self.inner;
        &inner.shared
    }
}
impl<C: quic::Connection<B>, B: Buf> Connection<C, B> {
    /// Creates a server connection on top of `conn`.
    ///
    /// Uses the builder returned by `super::builder::builder()`, presumably configured with
    /// its default settings.
    ///
    /// # Errors
    ///
    /// Returns the [`ConnectionError`] produced by the builder's `build` step, if any.
    #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
    pub async fn new(conn: C) -> Result<Self, ConnectionError> {
        let builder = super::builder::builder();
        builder.build(conn).await
    }
}
// The whole impl block is feature-gated, so the individual methods need no extra `cfg`.
#[cfg(feature = "i-implement-a-third-party-backend-and-opt-into-breaking-changes")]
impl<C, B> Connection<C, B>
where
    C: quic::Connection<B>,
    B: Buf,
{
    /// Wraps an already-accepted request stream in a [`RequestResolver`].
    ///
    /// The resolver receives a clone of the completion sender, the current
    /// `send_grease_frame` flag, the field-section size limit and the shared state.
    ///
    /// NOTE(review): unlike [`Connection::accept`], this does not clear
    /// `inner.send_grease_frame`, so every resolver built this way inherits the current flag.
    pub fn create_resolver(&self, stream: FrameStream<C::BidiStream, B>) -> RequestResolver<C, B> {
        self.create_resolver_internal(stream)
    }

    /// Low-level counterpart of [`Connection::accept`] for third-party backends.
    ///
    /// Returns `Ready(Ok(Some(stream)))` for a newly accepted request stream. The stream's ID
    /// is recorded as ongoing until its completion is reported on the request-end channel.
    /// Returns `Ready(Ok(None))` once no further requests will be accepted.
    ///
    /// Unlike [`Connection::accept`], yielding `Ok(None)` does not trigger a call to
    /// `shutdown`; the caller is responsible for that.
    pub fn poll_accept_request_stream(
        &mut self,
        cx: &mut Context<'_>,
    ) -> Poll<Result<Option<C::BidiStream>, ConnectionError>> {
        self.poll_accept_request_stream_internal(cx)
    }
}
impl<C, B> Connection<C, B>
where
    C: quic::Connection<B>,
    B: Buf,
{
    /// Accepts the next incoming request.
    ///
    /// Returns `Ok(Some(resolver))` with the new request stream wrapped in a
    /// [`RequestResolver`]. Returns `Ok(None)` when no more requests will be accepted.
    /// Before returning `Ok(None)`, `shutdown(0)` is performed so the peer can learn which
    /// request was the last one accepted.
    ///
    /// Only the resolver for the first accepted request inherits `send_grease_frame`. The
    /// flag is cleared afterwards.
    ///
    /// # Errors
    ///
    /// Propagates connection errors raised while reading the control stream, accepting
    /// streams, or performing the final shutdown.
    #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
    pub async fn accept(&mut self) -> Result<Option<RequestResolver<C, B>>, ConnectionError> {
        let stream = match poll_fn(|cx| self.poll_accept_request_stream_internal(cx)).await? {
            Some(s) => FrameStream::new(BufRecvStream::new(s)),
            None => {
                // Always finish with a shutdown announcing the last accepted request.
                self.shutdown(0).await?;
                return Ok(None);
            }
        };

        let resolver = self.create_resolver_internal(stream);

        // Grease is only sent on the first request of the connection.
        self.inner.send_grease_frame = false;

        Ok(Some(resolver))
    }

    /// Builds a [`RequestResolver`] for `stream`.
    ///
    /// The resolver is wired to this connection's completion channel, grease flag,
    /// field-section size limit and shared state.
    fn create_resolver_internal(
        &self,
        stream: FrameStream<C::BidiStream, B>,
    ) -> RequestResolver<C, B> {
        RequestResolver {
            frame_stream: stream,
            request_end_send: self.request_end_send.clone(),
            send_grease_frame: self.inner.send_grease_frame,
            max_field_section_size: self.max_field_section_size,
            shared: self.inner.shared.clone(),
        }
    }

    /// Starts a graceful shutdown, allowing `max_requests` more requests past the last one
    /// accepted.
    ///
    /// The stream-ID limit is `last_accepted_stream + max_requests` (using `StreamId`'s
    /// `usize` addition; the exact ID step lives in that impl). If nothing has been accepted
    /// yet, the limit is `StreamId::FIRST_REQUEST`. The limit and `sent_closing` are handed to
    /// `ConnectionInner::shutdown`. Once `sent_closing` holds a limit, streams with a larger ID
    /// are rejected when accepted.
    ///
    /// # Errors
    ///
    /// Propagates the error returned by `ConnectionInner::shutdown`.
    #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
    pub async fn shutdown(&mut self, max_requests: usize) -> Result<(), ConnectionError> {
        let max_id = self
            .last_accepted_stream
            .map(|id| id + max_requests)
            .unwrap_or(StreamId::FIRST_REQUEST);

        self.inner.shutdown(&mut self.sent_closing, max_id).await
    }

    /// Polls for the next request stream.
    ///
    /// Processes pending control frames and request completions first. It then returns:
    /// - `Ready(Ok(Some(stream)))` for a newly accepted stream, whose ID is recorded as the
    ///   last accepted stream and as ongoing;
    /// - `Ready(Ok(None))` once accepting is over. That happens when no stream is pending, a
    ///   closing was received from the peer, and all ongoing requests have completed. It also
    ///   happens when a stream is rejected after this side announced a shutdown limit and all
    ///   ongoing requests have completed;
    /// - `Pending` otherwise.
    ///
    /// Streams whose ID exceeds the announced limit (`sent_closing`) are refused with
    /// `H3_REQUEST_REJECTED` via `stop_sending` and `reset`.
    #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
    fn poll_accept_request_stream_internal(
        &mut self,
        cx: &mut Context<'_>,
    ) -> Poll<Result<Option<C::BidiStream>, ConnectionError>> {
        // Only errors matter here: `Pending` just means no more control frames right now.
        let _ = self.poll_control(cx)?;
        // Drain completion notifications so `ongoing_streams` is up to date.
        let _ = self.poll_requests_completion(cx);

        loop {
            let mut stream = match self.inner.poll_accept_bi(cx)? {
                Poll::Ready(stream) => stream,
                Poll::Pending => {
                    // No new stream. We are finished only if the peer asked to close and every
                    // accepted request has completed; otherwise rely on the wakers registered
                    // by the polls above.
                    let done = self.recv_closing.is_some()
                        && self.poll_requests_completion(cx).is_ready();
                    return if done {
                        Poll::Ready(Ok(None))
                    } else {
                        Poll::Pending
                    };
                }
            };

            let id = stream.send_id();
            if matches!(self.sent_closing, Some(max_id) if id > max_id) {
                // A shutdown limit has been announced: refuse streams beyond it.
                stream.stop_sending(Code::H3_REQUEST_REJECTED.value());
                stream.reset(Code::H3_REQUEST_REJECTED.value());
                if self.poll_requests_completion(cx).is_ready() {
                    return Poll::Ready(Ok(None));
                }
                continue;
            }

            self.last_accepted_stream = Some(id);
            self.ongoing_streams.insert(id);
            return Poll::Ready(Ok(Some(stream)));
        }
    }

    /// Processes control-stream frames until none are immediately available.
    ///
    /// Returns `Ready(Err(_))` if handling a frame fails; otherwise always returns `Pending`.
    #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
    pub(crate) fn poll_control(
        &mut self,
        cx: &mut Context<'_>,
    ) -> Poll<Result<(), ConnectionError>> {
        while self.poll_next_control(cx)?.is_ready() {}
        Poll::Pending
    }

    /// Reads and handles one frame from the peer's control stream.
    ///
    /// - `SETTINGS` is logged (with the `tracing` feature) and otherwise ignored.
    /// - `GOAWAY` is forwarded to `ConnectionInner::process_goaway`, which updates
    ///   `recv_closing`.
    /// - `MAX_PUSH_ID` and `CANCEL_PUSH` are ignored, with a warning under `tracing`.
    /// - Any other frame is a connection error with `H3_FRAME_UNEXPECTED`.
    ///
    /// On success the frame is returned to the caller.
    #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
    pub(crate) fn poll_next_control(
        &mut self,
        cx: &mut Context<'_>,
    ) -> Poll<Result<Frame<PayloadLen>, ConnectionError>> {
        let frame = ready!(self.inner.poll_control(cx))?;

        match &frame {
            Frame::Settings(_setting) => {
                #[cfg(feature = "tracing")]
                trace!("Got settings > {:?}", _setting);
            }
            &Frame::Goaway(id) => self.inner.process_goaway(&mut self.recv_closing, id)?,
            _frame @ Frame::MaxPushId(_) | _frame @ Frame::CancelPush(_) => {
                #[cfg(feature = "tracing")]
                warn!("Control frame ignored {:?}", _frame);
            }
            frame => {
                return Poll::Ready(Err(self.inner.handle_connection_error(
                    InternalConnectionError::new(
                        Code::H3_FRAME_UNEXPECTED,
                        format!("on server control stream: {:?}", frame),
                    ),
                )));
            }
        }

        Poll::Ready(Ok(frame))
    }

    /// Drains request-completion notifications and reports whether any request is still
    /// running.
    ///
    /// Each received stream ID is removed from `ongoing_streams`. Returns `Ready(())` when the
    /// channel is closed or no ongoing streams remain. Otherwise returns `Pending`; in that
    /// case `poll_recv` has registered the task's waker for the next notification.
    #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
    fn poll_requests_completion(&mut self, cx: &mut Context<'_>) -> Poll<()> {
        loop {
            match self.request_end_recv.poll_recv(cx) {
                Poll::Ready(Some(id)) => {
                    self.ongoing_streams.remove(&id);
                }
                Poll::Ready(None) => return Poll::Ready(()),
                Poll::Pending if self.ongoing_streams.is_empty() => return Poll::Ready(()),
                Poll::Pending => return Poll::Pending,
            }
        }
    }
}
impl<C: quic::Connection<B>, B: Buf> Drop for Connection<C, B> {
    /// Closes the underlying connection with `H3_NO_ERROR` when this handle is dropped.
    #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
    fn drop(&mut self) {
        let reason = String::from("Connection was closed by the server");
        self.inner.close_connection(Code::H3_NO_ERROR, reason);
    }
}
/// Links a request stream's ID to the channel that reports finished requests back to the
/// owning connection.
pub(super) struct RequestEnd {
    /// ID of the request stream this value tracks.
    pub(super) stream_id: StreamId,
    /// Sender feeding the connection's request-end receiver. Presumably `stream_id` is sent
    /// on it when the request finishes; the sending code is not visible here.
    pub(super) request_end: mpsc::UnboundedSender<StreamId>,
}