use super::{
H3Error,
frame::{Frame, FrameDecodeError, UniStreamType},
quic_varint::{self, QuicVarIntError},
settings::H3Settings,
};
use crate::{Buffer, Conn, HttpContext, h3::H3ErrorCode};
use futures_lite::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use std::{
future::Future,
io::{self, ErrorKind},
sync::{
Arc, OnceLock,
atomic::{AtomicBool, AtomicU64, Ordering},
},
};
use swansong::{ShutdownCompletion, Swansong};
/// Outcome of processing an inbound bidirectional stream
/// (see `H3Connection::process_inbound_bidi`).
#[derive(Debug)]
// `Request` carries a whole `Conn`, which is presumably much larger than the
// `WebTransport` variant. Boxing it would add a heap allocation per request, so the
// size difference is accepted here. NOTE(review): confirm this is the intent.
#[allow(clippy::large_enum_variant)]
pub enum H3StreamResult<Transport> {
    /// An ordinary HTTP/3 request. `process_inbound_bidi` returns it after the
    /// handler ran and `send_h3` completed.
    Request(Conn<Transport>),
    /// A stream identified as WebTransport, handed back to the caller to drive.
    WebTransport {
        /// The WebTransport session this stream belongs to.
        session_id: u64,
        /// The underlying bidirectional stream.
        transport: Transport,
        /// Presumably bytes already read from `transport` past the stream header.
        /// TODO confirm against `Conn::new_h3`.
        buffer: Buffer,
    },
}
/// What [`H3Connection::process_inbound_uni`] did with an inbound unidirectional
/// stream.
#[derive(Debug)]
pub enum UniStreamResult<S> {
    /// The stream was consumed internally (a control or QPACK stream).
    Handled,

    /// A WebTransport unidirectional stream, returned to the caller.
    WebTransport {
        /// Session id read from the stream header.
        session_id: u64,
        /// The stream itself.
        stream: S,
        /// Bytes already read from `stream` after the session id.
        buffer: Buffer,
    },

    /// A push stream or a stream type this module does not recognize.
    Unknown {
        /// Raw stream-type value read from the stream.
        stream_type: u64,
        /// The stream. Its type prefix, and possibly further bytes, have already
        /// been read from it.
        stream: S,
    },
}
/// Per-connection HTTP/3 state, handed out as an `Arc` (see `new`) so that the
/// various stream-driving methods can share it.
#[derive(Debug)]
pub struct H3Connection {
    /// Shared server context; its `config` and `swansong` are read by this type.
    context: Arc<HttpContext>,
    /// Connection-scoped shutdown handle, created as a child of `context.swansong`.
    swansong: Swansong,
    /// SETTINGS received from the peer. It is set at most once, by the inbound
    /// control stream.
    peer_settings: OnceLock<H3Settings>,
    /// Highest request stream id accepted so far. It is only meaningful once
    /// `has_accepted_stream` is true, because 0 is itself a valid stream id.
    max_accepted_stream_id: AtomicU64,
    /// Whether any request stream has been accepted yet.
    has_accepted_stream: AtomicBool,
}
impl H3Connection {
pub fn new(context: Arc<HttpContext>) -> Arc<Self> {
let swansong = context.swansong.child();
Arc::new(Self {
context,
swansong,
peer_settings: OnceLock::new(),
max_accepted_stream_id: AtomicU64::new(0),
has_accepted_stream: AtomicBool::new(false),
})
}
pub fn swansong(&self) -> &Swansong {
&self.swansong
}
pub fn shut_down(&self) -> ShutdownCompletion {
self.swansong.shut_down()
}
pub fn context(&self) -> Arc<HttpContext> {
self.context.clone()
}
pub fn peer_settings(&self) -> Option<&H3Settings> {
self.peer_settings.get()
}
fn record_accepted_stream(&self, stream_id: u64) {
self.max_accepted_stream_id
.fetch_max(stream_id, Ordering::Relaxed);
self.has_accepted_stream.store(true, Ordering::Relaxed);
}
fn goaway_id(&self) -> u64 {
if self.has_accepted_stream.load(Ordering::Relaxed) {
self.max_accepted_stream_id.load(Ordering::Relaxed) + 4
} else {
0
}
}
pub async fn process_inbound_bidi<Transport, Handler, Fut>(
self: Arc<Self>,
transport: Transport,
handler: Handler,
stream_id: u64,
) -> Result<H3StreamResult<Transport>, H3Error>
where
Transport: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static,
Handler: FnOnce(Conn<Transport>) -> Fut,
Fut: Future<Output = Conn<Transport>>,
{
self.record_accepted_stream(stream_id);
let _guard = self.swansong.guard();
let buffer = Vec::with_capacity(self.context.config.request_buffer_initial_len).into();
match Conn::new_h3(self, transport, buffer).await? {
H3StreamResult::Request(conn) => Ok(H3StreamResult::Request(
handler(conn).await.send_h3().await?,
)),
wt @ H3StreamResult::WebTransport { .. } => Ok(wt),
}
}
pub async fn run_outbound_control<T>(&self, mut stream: T) -> Result<(), H3Error>
where
T: AsyncWrite + Unpin + Send,
{
let mut buf = vec![0; 128];
let settings = Frame::Settings(H3Settings::from(&self.context.config));
write(&mut buf, &mut stream, |buf| {
let mut written = quic_varint::encode(UniStreamType::Control, buf)?;
written += settings.encode(&mut buf[written..])?;
Some(written)
})
.await?;
self.swansong.clone().await;
write(&mut buf, &mut stream, |buf| {
Frame::Goaway(self.goaway_id()).encode(buf)
})
.await?;
Ok(())
}
pub async fn run_encoder<T>(&self, mut stream: T) -> Result<(), H3Error>
where
T: AsyncWrite + Unpin + Send,
{
let mut buf = vec![0; 8];
write(&mut buf, &mut stream, |buf| {
quic_varint::encode(UniStreamType::QpackEncoder, buf)
})
.await?;
self.swansong.clone().await;
Ok(())
}
pub async fn run_decoder<T>(&self, mut stream: T) -> Result<(), H3Error>
where
T: AsyncWrite + Unpin + Send,
{
let mut buf = vec![0; 8];
write(&mut buf, &mut stream, |buf| {
quic_varint::encode(UniStreamType::QpackDecoder, buf)
})
.await?;
self.swansong.clone().await;
Ok(())
}
pub async fn process_inbound_uni<T>(&self, mut stream: T) -> Result<UniStreamResult<T>, H3Error>
where
T: AsyncRead + Unpin + Send,
{
let mut buf = vec![0; 128];
let mut filled = 0;
let stream_type = read(
&mut buf,
&mut filled,
&mut stream,
|data| match quic_varint::decode::<u64>(data) {
Ok(ok) => Ok(Some(ok)),
Err(QuicVarIntError::UnexpectedEnd) => Ok(None),
Err(QuicVarIntError::UnknownValue { bytes, value }) => Ok(Some((value, bytes))),
},
)
.await?;
match UniStreamType::try_from(stream_type) {
Ok(UniStreamType::Control) => {
self.run_inbound_control(&mut buf, &mut filled, &mut stream)
.await?;
Ok(UniStreamResult::Handled)
}
Ok(UniStreamType::QpackEncoder | UniStreamType::QpackDecoder) => {
self.swansong.clone().await;
Ok(UniStreamResult::Handled)
}
Ok(UniStreamType::WebTransport) => {
let session_id =
read(
&mut buf,
&mut filled,
&mut stream,
|data| match quic_varint::decode::<u64>(data) {
Ok(ok) => Ok(Some(ok)),
Err(QuicVarIntError::UnexpectedEnd) => Ok(None),
Err(QuicVarIntError::UnknownValue { bytes, value }) => {
Ok(Some((value, bytes)))
}
},
)
.await?;
buf.truncate(filled);
Ok(UniStreamResult::WebTransport {
session_id,
stream,
buffer: buf.into(),
})
}
Ok(UniStreamType::Push) | Err(_) => Ok(UniStreamResult::Unknown {
stream_type,
stream,
}),
}
}
async fn run_inbound_control<T>(
&self,
buf: &mut Vec<u8>,
filled: &mut usize,
stream: &mut T,
) -> Result<(), H3Error>
where
T: AsyncRead + Unpin + Send,
{
let settings = read(buf, filled, stream, |data| match Frame::decode(data) {
Ok((Frame::Settings(s), consumed)) => Ok(Some((s, consumed))),
Ok(_) => Err(H3ErrorCode::FrameUnexpected),
Err(FrameDecodeError::Incomplete) => Ok(None),
Err(FrameDecodeError::Error(code)) => Err(code),
})
.await?;
self.peer_settings
.set(settings)
.map_err(|_| H3ErrorCode::FrameUnexpected)?;
loop {
let frame = read(buf, filled, stream, |data| match Frame::decode(data) {
Ok((frame, consumed)) => Ok(Some((frame, consumed))),
Err(FrameDecodeError::Incomplete) => Ok(None),
Err(FrameDecodeError::Error(code)) => Err(code),
})
.await?;
match frame {
Frame::Goaway(_) => {
self.swansong.shut_down();
return Ok(());
}
Frame::Settings(_) => {
return Err(H3ErrorCode::FrameUnexpected.into());
}
_ => { }
}
}
}
}
/// Growth limit, in bytes, for the scratch buffers used by `read` and `write`.
/// A buffer that has reached this size is not grown any further.
const MAX_BUFFER_SIZE: usize = 10 * 1024;
/// Encodes a message into `buf` with `f`, then writes it to `stream` and flushes.
///
/// `f` receives all of `buf`. It returns `Some(len)` with the number of bytes it
/// encoded, or `None` if `buf` was too small. On `None`, `buf` is grown and `f` is
/// called again. Growth doubles the size, with a minimum of 64 bytes and a maximum
/// of [`MAX_BUFFER_SIZE`]. Returns the number of bytes written.
///
/// # Errors
///
/// * [`ErrorKind::OutOfMemory`] if `f` still returns `None` once `buf` holds
///   `MAX_BUFFER_SIZE` bytes.
/// * Any I/O error from writing or flushing `stream`.
async fn write(
    buf: &mut Vec<u8>,
    mut stream: impl AsyncWrite + Unpin + Send,
    mut f: impl FnMut(&mut [u8]) -> Option<usize>,
) -> io::Result<usize> {
    let written = loop {
        if let Some(w) = f(buf) {
            break w;
        }
        if buf.len() >= MAX_BUFFER_SIZE {
            return Err(io::Error::new(ErrorKind::OutOfMemory, "runaway allocation"));
        }
        // Plain doubling would leave an empty buffer empty, so this loop would
        // never terminate. It could also overshoot MAX_BUFFER_SIZE. Clamp both ends.
        buf.resize((buf.len() * 2).clamp(64, MAX_BUFFER_SIZE), 0);
    };
    stream.write_all(&buf[..written]).await?;
    stream.flush().await?;
    Ok(written)
}
/// Reads from `stream` until `f` can decode a value from the buffered bytes.
///
/// `buf[..*filled]` holds bytes that have been read but not yet consumed. On each
/// iteration they are passed to `f`, which returns one of:
/// * `Ok(Some((value, consumed)))` on success;
/// * `Ok(None)` if more input is needed;
/// * `Err(code)` on a protocol error.
///
/// On success the `consumed` bytes are removed from the front of `buf` and
/// `*filled` is reduced, so leftover bytes remain available to the next call. If
/// the buffer is full when more input is needed, it is grown. Growth doubles the
/// size, with a minimum of 64 bytes and a maximum of [`MAX_BUFFER_SIZE`].
///
/// # Errors
///
/// * `f`'s error code, converted into [`H3Error`].
/// * [`ErrorKind::OutOfMemory`] if `f` still needs input once the buffer holds
///   `MAX_BUFFER_SIZE` bytes.
/// * [`ErrorKind::UnexpectedEof`] if the stream ends before `f` succeeds.
/// * Any I/O error from reading `stream`.
async fn read<R>(
    buf: &mut Vec<u8>,
    filled: &mut usize,
    stream: &mut (impl AsyncRead + Unpin + Send),
    f: impl Fn(&[u8]) -> Result<Option<(R, usize)>, H3ErrorCode>,
) -> Result<R, H3Error> {
    loop {
        if let Some((result, consumed)) = f(&buf[..*filled])? {
            buf.copy_within(consumed..*filled, 0);
            *filled -= consumed;
            return Ok(result);
        }
        if *filled >= buf.len() {
            if buf.len() >= MAX_BUFFER_SIZE {
                return Err(io::Error::new(ErrorKind::OutOfMemory, "runaway allocation").into());
            }
            // Plain doubling would leave an empty buffer empty. The read below
            // would then return 0 and be misreported as EOF. Doubling could also
            // overshoot MAX_BUFFER_SIZE.
            buf.resize((buf.len() * 2).clamp(64, MAX_BUFFER_SIZE), 0);
        }
        let n = stream.read(&mut buf[*filled..]).await?;
        if n == 0 {
            return Err(io::Error::new(ErrorKind::UnexpectedEof, "stream closed").into());
        }
        *filled += n;
    }
}