mod peer_settings_wait;
use super::{
H3Error,
frame::{Frame, FrameDecodeError, UniStreamType},
quic_varint::{self, QuicVarIntError},
settings::H3Settings,
};
use crate::{
Buffer, Conn, HttpContext,
h3::{H3ErrorCode, MAX_BUFFER_SIZE},
headers::qpack::{DecoderDynamicTable, EncoderDynamicTable, FieldSection},
};
use event_listener::Event;
use futures_lite::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use std::{
future::Future,
io::{self, ErrorKind},
pin::Pin,
sync::{
Arc, OnceLock,
atomic::{AtomicBool, AtomicU64, Ordering},
},
task::{Context, Poll},
};
use swansong::{ShutdownCompletion, Swansong};
/// Outcome of processing an inbound bidirectional HTTP/3 stream.
///
/// Produced by [`H3Connection::process_inbound_bidi`] (which in turn receives it from
/// `Conn::new_h3`).
#[derive(Debug)]
#[allow(
    clippy::large_enum_variant,
    reason = "Request is produced for every request stream; boxing Conn would add a heap allocation each time"
)]
pub enum H3StreamResult<Transport> {
    /// The stream carries an HTTP/3 request.
    ///
    /// When returned from [`H3Connection::process_inbound_bidi`], this is the conn after the
    /// handler has run and the response has been sent.
    Request(Conn<Transport>),
    /// The stream was identified as a WebTransport bidirectional stream and is handed back to
    /// the caller.
    WebTransport {
        /// Session id associated with this stream.
        session_id: u64,
        /// The underlying stream transport.
        transport: Transport,
        /// Bytes already read from `transport` — presumably whatever was buffered past the
        /// stream header during detection; confirm against `Conn::new_h3`.
        buffer: Buffer,
    },
}
/// Outcome of [`H3Connection::process_inbound_uni`].
#[derive(Debug)]
pub enum UniStreamResult<T> {
    /// The stream was consumed internally (control or QPACK encoder/decoder stream), or
    /// processing was interrupted by connection shutdown.
    Handled,
    /// A WebTransport unidirectional stream, returned to the caller for handling.
    WebTransport {
        /// Varint read immediately after the stream type.
        session_id: u64,
        /// The stream, positioned after the bytes already copied into `buffer`.
        stream: T,
        /// Bytes read from `stream` beyond the session id that have not been consumed yet.
        buffer: Buffer,
    },
    /// A stream type this connection does not handle itself: push streams and unrecognized
    /// stream types.
    Unknown {
        /// The raw stream-type varint.
        stream_type: u64,
        /// The stream. Any bytes that were read past the type varint while decoding it are
        /// not returned.
        stream: T,
    },
}
/// Per-connection HTTP/3 state; constructed behind an `Arc` by [`H3Connection::new`].
#[derive(Debug)]
pub struct H3Connection {
    // Shared HTTP context; its config supplies buffer sizes, QPACK limits and local SETTINGS.
    context: Arc<HttpContext>,
    // Shutdown coordinator for this connection, obtained from the context's swansong via
    // `child()`.
    swansong: Swansong,
    // SETTINGS received as the first frame of the peer's control stream; set at most once.
    pub(super) peer_settings: OnceLock<H3Settings>,
    // Event for tasks waiting on `peer_settings`; presumably notified by
    // `wake_peer_settings_waiters` (defined outside this view) — confirm.
    pub(super) peer_settings_event: Event,
    // Highest accepted bidirectional stream id; only meaningful once `has_accepted_stream`
    // is true (0 is both a valid id and the initial value).
    max_accepted_stream_id: AtomicU64,
    // Whether any bidirectional stream has been accepted yet.
    has_accepted_stream: AtomicBool,
    // QPACK decoder dynamic table; its reader consumes the peer's QPACK encoder stream.
    decoder_dynamic_table: DecoderDynamicTable,
    // QPACK encoder dynamic table; its reader consumes the peer's QPACK decoder stream.
    encoder_dynamic_table: EncoderDynamicTable,
}
impl H3Connection {
pub fn new(context: Arc<HttpContext>) -> Arc<Self> {
let swansong = context.swansong.child();
let max_table_capacity = context.config.dynamic_table_capacity;
let blocked_streams = context.config.h3_blocked_streams;
let encoder_dynamic_table = EncoderDynamicTable::new(&context);
Arc::new(Self {
context,
swansong,
peer_settings: OnceLock::new(),
peer_settings_event: Event::new(),
max_accepted_stream_id: AtomicU64::new(0),
has_accepted_stream: AtomicBool::new(false),
decoder_dynamic_table: DecoderDynamicTable::new(max_table_capacity, blocked_streams),
encoder_dynamic_table,
})
}
pub fn swansong(&self) -> &Swansong {
&self.swansong
}
pub fn shut_down(&self) -> ShutdownCompletion {
self.decoder_dynamic_table.fail(H3ErrorCode::NoError);
self.encoder_dynamic_table.fail(H3ErrorCode::NoError);
self.wake_peer_settings_waiters();
self.swansong.shut_down()
}
pub fn context(&self) -> Arc<HttpContext> {
self.context.clone()
}
pub fn peer_settings(&self) -> Option<&H3Settings> {
self.peer_settings.get()
}
fn record_accepted_stream(&self, stream_id: u64) {
self.max_accepted_stream_id
.fetch_max(stream_id, Ordering::Relaxed);
self.has_accepted_stream.store(true, Ordering::Relaxed);
}
fn goaway_id(&self) -> u64 {
if self.has_accepted_stream.load(Ordering::Relaxed) {
self.max_accepted_stream_id.load(Ordering::Relaxed) + 4
} else {
0
}
}
pub async fn process_inbound_bidi<Transport, Handler, Fut>(
self: Arc<Self>,
transport: Transport,
handler: Handler,
stream_id: u64,
) -> Result<H3StreamResult<Transport>, H3Error>
where
Transport: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static,
Handler: FnOnce(Conn<Transport>) -> Fut,
Fut: Future<Output = Conn<Transport>>,
{
self.record_accepted_stream(stream_id);
let _guard = self.swansong.guard();
let buffer = Vec::with_capacity(self.context.config.request_buffer_initial_len).into();
match Conn::new_h3(self, transport, buffer, stream_id).await? {
H3StreamResult::Request(conn) => Ok(H3StreamResult::Request(
handler(conn).await.send_h3().await?,
)),
wt @ H3StreamResult::WebTransport { .. } => Ok(wt),
}
}
#[cfg(feature = "unstable")]
pub async fn decode_field_section(
&self,
encoded: &[u8],
stream_id: u64,
) -> Result<FieldSection<'static>, H3Error> {
self.decoder_dynamic_table.decode(encoded, stream_id).await
}
#[cfg(not(feature = "unstable"))]
pub(crate) async fn decode_field_section(
&self,
encoded: &[u8],
stream_id: u64,
) -> Result<FieldSection<'static>, H3Error> {
self.decoder_dynamic_table.decode(encoded, stream_id).await
}
#[cfg(feature = "unstable")]
#[allow(clippy::unnecessary_wraps, reason = "future-proofing api")]
pub fn encode_field_section(
&self,
field_section: &FieldSection<'_>,
buf: &mut Vec<u8>,
stream_id: u64,
) -> Result<(), H3Error> {
self.encoder_dynamic_table
.encode(field_section, buf, stream_id);
Ok(())
}
#[cfg(not(feature = "unstable"))]
#[allow(clippy::unnecessary_wraps, reason = "future-proofing api")]
pub(crate) fn encode_field_section(
&self,
field_section: &FieldSection<'_>,
buf: &mut Vec<u8>,
stream_id: u64,
) -> Result<(), H3Error> {
self.encoder_dynamic_table
.encode(field_section, buf, stream_id);
Ok(())
}
pub async fn run_outbound_control<T>(&self, mut stream: T) -> Result<(), H3Error>
where
T: AsyncWrite + Unpin + Send,
{
let mut buf = vec![0; 128];
let settings = Frame::Settings(H3Settings::from(&self.context.config));
log::trace!(
"H3 outbound control: sending SETTINGS: {:?}",
H3Settings::from(&self.context.config)
);
write(&mut buf, &mut stream, |buf| {
let mut written = quic_varint::encode(UniStreamType::Control, buf)?;
written += settings.encode(&mut buf[written..])?;
Some(written)
})
.await?;
log::trace!("H3 outbound control: SETTINGS sent");
self.swansong.clone().await;
write(&mut buf, &mut stream, |buf| {
Frame::Goaway(self.goaway_id()).encode(buf)
})
.await?;
Ok(())
}
pub async fn run_encoder<T>(&self, mut stream: T) -> Result<(), H3Error>
where
T: AsyncWrite + Unpin + Send,
{
self.encoder_dynamic_table
.run_writer(&mut stream, self.swansong.clone())
.await
}
pub async fn run_decoder<T>(&self, mut stream: T) -> Result<(), H3Error>
where
T: AsyncWrite + Unpin + Send,
{
self.decoder_dynamic_table
.run_writer(&mut stream, self.swansong.clone())
.await
}
pub async fn process_inbound_uni<T>(&self, mut stream: T) -> Result<UniStreamResult<T>, H3Error>
where
T: AsyncRead + Unpin + Send,
{
self.swansong
.interrupt(async move {
let mut buf = vec![0; 128];
let mut filled = 0;
let stream_type =
read(
&mut buf,
&mut filled,
&mut stream,
|data| match quic_varint::decode(data) {
Ok(ok) => Ok(Some(ok)),
Err(QuicVarIntError::UnexpectedEnd) => Ok(None),
Err(QuicVarIntError::UnknownValue { bytes, value }) => {
Ok(Some((value, bytes)))
}
},
)
.await?;
match UniStreamType::try_from(stream_type) {
Ok(UniStreamType::Control) => {
log::trace!("H3 inbound uni: control stream");
self.run_inbound_control(&mut buf, &mut filled, &mut stream)
.await?;
Ok(UniStreamResult::Handled)
}
Ok(UniStreamType::QpackEncoder) => {
log::trace!(
"H3 inbound uni: QPACK encoder stream ({filled} bytes pre-read)"
);
let mut reader = Prepended {
head: &buf[..filled],
tail: stream,
};
log::trace!("QPACK encoder stream: started");
self.decoder_dynamic_table.run_reader(&mut reader).await?;
Ok(UniStreamResult::Handled)
}
Ok(UniStreamType::QpackDecoder) => {
log::trace!(
"H3 inbound uni: QPACK decoder stream ({filled} bytes pre-read)"
);
let mut reader = Prepended {
head: &buf[..filled],
tail: stream,
};
self.encoder_dynamic_table.run_reader(&mut reader).await?;
Ok(UniStreamResult::Handled)
}
Ok(UniStreamType::WebTransport) => {
log::trace!("H3 inbound uni: WebTransport stream");
let session_id = read(&mut buf, &mut filled, &mut stream, |data| {
match quic_varint::decode(data) {
Ok(ok) => Ok(Some(ok)),
Err(QuicVarIntError::UnexpectedEnd) => Ok(None),
Err(QuicVarIntError::UnknownValue { bytes, value }) => {
Ok(Some((value, bytes)))
}
}
})
.await?;
buf.truncate(filled);
Ok(UniStreamResult::WebTransport {
session_id,
stream,
buffer: buf.into(),
})
}
Ok(UniStreamType::Push) => {
log::trace!("H3 inbound uni: push stream (push not supported)");
Ok(UniStreamResult::Unknown {
stream_type,
stream,
})
}
Err(_) => {
log::trace!("H3 inbound uni: unknown stream type {stream_type:#x}");
Ok(UniStreamResult::Unknown {
stream_type,
stream,
})
}
}
})
.await
.unwrap_or(Ok(UniStreamResult::Handled)) }
async fn run_inbound_control<T>(
&self,
buf: &mut Vec<u8>,
filled: &mut usize,
stream: &mut T,
) -> Result<(), H3Error>
where
T: AsyncRead + Unpin + Send,
{
let settings = read(buf, filled, stream, |data| match Frame::decode(data) {
Ok((Frame::Settings(s), consumed)) => Ok(Some((s, consumed))),
Ok(_) => Err(H3ErrorCode::FrameUnexpected),
Err(FrameDecodeError::Incomplete) => Ok(None),
Err(FrameDecodeError::Error(code)) => Err(code),
})
.await?;
log::trace!("H3 peer settings: {settings:?}");
self.peer_settings
.set(settings)
.map_err(|_| H3ErrorCode::FrameUnexpected)?;
self.wake_peer_settings_waiters();
self.encoder_dynamic_table
.initialize_from_peer_settings(settings);
loop {
let frame = self
.swansong
.interrupt(read(buf, filled, stream, |data| {
match Frame::decode(data) {
Ok((frame, consumed)) => Ok(Some((frame, consumed))),
Err(FrameDecodeError::Incomplete) => Ok(None),
Err(FrameDecodeError::Error(code)) => Err(code),
}
}))
.await
.transpose()?;
match frame {
None => {
log::trace!("H3 control stream: interrupted by shutdown");
return Ok(());
}
Some(Frame::Goaway(id)) => {
log::trace!("H3 control stream: peer sent GOAWAY(stream_id={id})");
self.swansong.shut_down();
return Ok(());
}
Some(Frame::Settings(_)) => {
return Err(H3ErrorCode::FrameUnexpected.into());
}
Some(Frame::Unknown(n)) => {
log::trace!("H3 control stream: skipping unknown frame (payload {n} bytes)");
let n = usize::try_from(n).unwrap_or(usize::MAX);
let in_buf = n.min(*filled);
buf.copy_within(in_buf..*filled, 0);
*filled -= in_buf;
let mut todo = n - in_buf;
let mut scratch = [0u8; 256];
while todo > 0 {
let to_read = todo.min(scratch.len());
let n = stream
.read(&mut scratch[..to_read])
.await
.map_err(H3Error::Io)?;
if n == 0 {
return Err(H3ErrorCode::ClosedCriticalStream.into());
}
todo -= n;
}
}
other => {
log::trace!("H3 control stream: ignoring {other:?}");
}
}
}
}
}
/// Encodes a message into `buf` with `f`, then writes it to `stream` and flushes.
///
/// `f` receives all of `buf`. It returns `Some(len)` when the encoding fit in the first `len`
/// bytes, or `None` when `buf` is too small. On `None`, `buf` grows (doubling, with a small
/// minimum so an empty buffer can still grow) and `f` is retried. Growth stops once
/// `buf.len()` has reached `MAX_BUFFER_SIZE`.
///
/// Returns the number of bytes written.
///
/// # Errors
///
/// - `ErrorKind::OutOfMemory` ("runaway allocation") if `f` still returns `None` once `buf`
///   has reached `MAX_BUFFER_SIZE`.
/// - Any I/O error from writing or flushing `stream`.
async fn write(
    buf: &mut Vec<u8>,
    mut stream: impl AsyncWrite + Unpin + Send,
    mut f: impl FnMut(&mut [u8]) -> Option<usize>,
) -> io::Result<usize> {
    // Without a floor, a zero-length buffer would "double" to zero and loop forever.
    const MIN_GROWTH: usize = 64;

    let written = loop {
        if let Some(w) = f(buf) {
            break w;
        }
        if buf.len() >= MAX_BUFFER_SIZE {
            return Err(io::Error::new(ErrorKind::OutOfMemory, "runaway allocation"));
        }
        buf.resize((buf.len() * 2).max(MIN_GROWTH), 0);
    };
    stream.write_all(&buf[..written]).await?;
    stream.flush().await?;
    Ok(written)
}
/// An `AsyncRead` adapter that yields the bytes in `head` before reading from `tail`.
///
/// Used to replay bytes that were already buffered while a stream's type was being read.
struct Prepended<'a, T> {
    // Bytes to return first; the slice shrinks as they are handed out.
    head: &'a [u8],
    // The underlying reader, used once `head` is exhausted.
    tail: T,
}
impl<T: AsyncRead + Unpin> AsyncRead for Prepended<'_, T> {
    /// Serves buffered `head` bytes first, as many as fit in `out`. Only once `head` is empty
    /// is the read forwarded to `tail`.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        out: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        let Self { head, tail } = self.get_mut();
        if head.is_empty() {
            return Pin::new(tail).poll_read(cx, out);
        }
        // Copy the `&'a [u8]` out so the split halves keep the original lifetime.
        let pending = *head;
        let count = pending.len().min(out.len());
        let (served, rest) = pending.split_at(count);
        out[..count].copy_from_slice(served);
        *head = rest;
        Poll::Ready(Ok(count))
    }
}
/// Reads from `stream` into `buf` until `f` can parse a value from the filled prefix.
///
/// `buf[..*filled]` holds bytes already read but not yet consumed. `f` inspects that prefix
/// and returns `Ok(Some((value, consumed)))` on success, `Ok(None)` if it needs more bytes,
/// or an error code. On success the `consumed` bytes are removed from the front of `buf`
/// (the remainder is shifted down and `*filled` is reduced), so leftovers are kept for the
/// next call. When `buf` is full, it grows (doubling, with a small minimum so an empty
/// buffer can still grow) until it reaches `MAX_BUFFER_SIZE`.
///
/// # Errors
///
/// - The error code returned by `f`.
/// - `ErrorKind::OutOfMemory` ("runaway allocation") if `buf` has reached `MAX_BUFFER_SIZE`
///   without `f` succeeding.
/// - `ErrorKind::UnexpectedEof` ("stream closed") if `stream` ends first.
/// - Any I/O error from `stream`.
async fn read<R>(
    buf: &mut Vec<u8>,
    filled: &mut usize,
    stream: &mut (impl AsyncRead + Unpin + Send),
    f: impl Fn(&[u8]) -> Result<Option<(R, usize)>, H3ErrorCode>,
) -> Result<R, H3Error> {
    // Without a floor, an empty buffer would "grow" to zero bytes. The following zero-length
    // read would then return 0 and be misreported as the stream closing.
    const MIN_GROWTH: usize = 64;

    loop {
        if let Some((result, consumed)) = f(&buf[..*filled])? {
            buf.copy_within(consumed..*filled, 0);
            *filled -= consumed;
            return Ok(result);
        }
        if *filled >= buf.len() {
            if buf.len() >= MAX_BUFFER_SIZE {
                return Err(io::Error::new(ErrorKind::OutOfMemory, "runaway allocation").into());
            }
            buf.resize((buf.len() * 2).max(MIN_GROWTH), 0);
        }
        let n = stream.read(&mut buf[*filled..]).await?;
        if n == 0 {
            return Err(io::Error::new(ErrorKind::UnexpectedEof, "stream closed").into());
        }
        *filled += n;
    }
}