use crate::{Pool, h3::alt_svc::DEFAULT_BROKEN_DURATION, pool::PoolEntry};
use alt_svc::AltSvcCache;
use std::{
io::{self, ErrorKind},
net::SocketAddr,
sync::{Arc, OnceLock},
time::Duration,
};
use trillium_http::{
HttpContext,
h3::{H3Connection, H3Error, H3ErrorCode, H3StreamResult, UniStreamResult},
};
#[cfg(feature = "webtransport")]
use trillium_server_common::h3::web_transport::{WebTransportDispatcher, WebTransportStream};
use trillium_server_common::{
ArcedConnector, ArcedQuicClientConfig, ArcedQuicEndpoint, Connector, QuicConnection, Runtime,
url::{Origin, Url},
};
mod alt_svc;
/// A pooled HTTP/3 connection: the QUIC connection together with the H3 state driving it.
#[derive(Debug, Clone)]
pub(crate) struct H3PoolEntry {
    /// The underlying QUIC connection.
    pub(crate) quic_conn: QuicConnection,
    /// Shared HTTP/3 connection state running on top of `quic_conn`.
    pub(crate) h3: Arc<H3Connection>,
    /// WebTransport dispatcher slot shared with the inbound stream tasks; while it is unset,
    /// those tasks reject inbound WebTransport streams.
    #[cfg(feature = "webtransport")]
    pub(crate) dispatcher: Arc<OnceLock<WebTransportDispatcher>>,
}
/// Shared HTTP/3 client state: QUIC configuration, connection pool, alt-svc cache, and the
/// lazily bound local QUIC endpoints (one per address family).
#[derive(Debug, Clone)]
pub(crate) struct H3ClientState {
    /// QUIC client configuration used to bind local endpoints.
    pub(crate) config: ArcedQuicClientConfig,
    /// Pool of established HTTP/3 connections, keyed by origin.
    pub(crate) pool: Pool<Origin, H3PoolEntry>,
    /// Cache of `Alt-Svc` advertisements.
    pub(crate) alt_svc: AltSvcCache,
    /// Duration passed to alt-svc entries when they are marked broken.
    pub(crate) broken_duration: Duration,
    /// Local endpoint used for IPv4 peers; bound on first use.
    endpoint_v4: OnceLockEndpoint,
    /// Local endpoint used for IPv6 peers; bound on first use.
    endpoint_v6: OnceLockEndpoint,
}
/// A cloneable, shared, set-once slot holding a bound QUIC endpoint.
///
/// Clones share the same underlying slot. `Default` yields an empty slot
/// (`Arc::new(OnceLock::new())`).
#[derive(Debug, Clone, Default)]
struct OnceLockEndpoint(Arc<OnceLock<ArcedQuicEndpoint>>);
impl H3ClientState {
    /// Feeds an `Alt-Svc` header value received for `url` into the alt-svc cache.
    pub(crate) fn update_alt_svc(&self, alt_svc: &str, url: &Url) {
        self.alt_svc.update(alt_svc, url);
    }

    /// Marks the alt-svc cache entry for `origin` as broken, passing `self.broken_duration`.
    ///
    /// Does nothing when the cache has no entry for `origin`.
    pub(crate) fn mark_broken(&self, origin: &Origin) {
        if let Some(mut entry) = self.alt_svc.get_mut(origin) {
            entry.mark_broken(self.broken_duration);
        }
    }

    /// Returns an HTTP/3 connection for `origin`, reusing a pooled candidate when one exists.
    ///
    /// On a pool miss, this performs the following steps:
    /// - resolves `host:port` through `connector`;
    /// - connects to the *first* resolved address over the local QUIC endpoint matching that
    ///   address's family;
    /// - spawns the HTTP/3 stream tasks on the connector's runtime;
    /// - inserts the new entry into the pool.
    ///
    /// NOTE(review): the pool lookup and the insert are not atomic, so concurrent callers that
    /// both miss may each open a connection to the same origin.
    ///
    /// # Errors
    ///
    /// Returns an error of kind [`ErrorKind::NotFound`] if resolution yields no addresses.
    /// Propagates any error from resolution, binding the local endpoint, or connecting.
    pub(crate) async fn get_or_create_quic_conn(
        &self,
        origin: &Origin,
        host: &str,
        port: u16,
        connector: &ArcedConnector,
        context: &Arc<HttpContext>,
    ) -> io::Result<H3PoolEntry> {
        if let Some(entry) = self.pool.peek_candidate(origin) {
            return Ok(entry);
        }

        let addr = *connector
            .resolve(host, port)
            .await?
            .first()
            .ok_or_else(|| io::Error::new(ErrorKind::NotFound, "no addresses resolved for host"))?;

        let endpoint = self.endpoint_for(addr)?;
        let conn = endpoint.connect(addr, host).await?;
        let entry = setup_h3_connection(conn, context, &connector.runtime());
        self.pool
            .insert(origin.clone(), PoolEntry::new(entry.clone(), None));
        Ok(entry)
    }

    /// Returns the local QUIC endpoint for `addr`'s address family, binding it on first use.
    ///
    /// IPv6 peers use an endpoint bound to `[::]:0`; IPv4 peers use one bound to `0.0.0.0:0`.
    ///
    /// # Errors
    ///
    /// Propagates any error from binding the endpoint.
    fn endpoint_for(&self, addr: SocketAddr) -> io::Result<ArcedQuicEndpoint> {
        // Build the wildcard bind addresses directly instead of parsing string literals, which
        // removes an error path that could never fire.
        let (holder, bind_addr) = if addr.is_ipv6() {
            (
                &self.endpoint_v6,
                SocketAddr::from((std::net::Ipv6Addr::UNSPECIFIED, 0)),
            )
        } else {
            (
                &self.endpoint_v4,
                SocketAddr::from((std::net::Ipv4Addr::UNSPECIFIED, 0)),
            )
        };

        if let Some(endpoint) = holder.0.get() {
            return Ok(endpoint.clone());
        }

        let endpoint = self.config.bind(bind_addr)?;
        // If another caller initialized the slot between the `get` above and here,
        // `get_or_init` returns that endpoint and drops the one we just bound. That way all
        // callers share a single endpoint per address family, without an `unwrap`.
        Ok(holder.0.get_or_init(|| endpoint).clone())
    }

    /// Creates client state around `config`.
    ///
    /// The pool and alt-svc cache start empty, and the broken duration is
    /// `DEFAULT_BROKEN_DURATION`. No QUIC endpoints are bound until first use.
    pub(crate) fn new(config: ArcedQuicClientConfig) -> Self {
        Self {
            config,
            pool: Default::default(),
            alt_svc: AltSvcCache::default(),
            broken_duration: DEFAULT_BROKEN_DURATION,
            endpoint_v4: OnceLockEndpoint::default(),
            endpoint_v6: OnceLockEndpoint::default(),
        }
    }
}
/// Wraps a freshly established QUIC connection in HTTP/3 state and starts the tasks driving it.
///
/// In order, this spawns the following on `runtime`:
/// - the outbound control stream;
/// - the QPACK encoder and decoder streams;
/// - the accept loops for inbound unidirectional and bidirectional streams.
fn setup_h3_connection(
    quic_conn: QuicConnection,
    context: &Arc<HttpContext>,
    runtime: &Runtime,
) -> H3PoolEntry {
    let h3 = H3Connection::new(Arc::clone(context));

    // Shared, initially-empty dispatcher slot. The inbound stream tasks consult it to route
    // WebTransport streams.
    #[cfg(feature = "webtransport")]
    let dispatcher: Arc<OnceLock<WebTransportDispatcher>> = Arc::default();

    // Unidirectional streams the client itself must open.
    spawn_outbound_control_stream(&quic_conn, &h3, runtime);
    spawn_qpack_encoder_stream(&quic_conn, &h3, runtime);
    spawn_qpack_decoder_stream(&quic_conn, &h3, runtime);

    // Accept loops for streams opened by the server.
    spawn_inbound_uni_streams(
        &quic_conn,
        &h3,
        runtime,
        #[cfg(feature = "webtransport")]
        &dispatcher,
    );
    spawn_inbound_bidi_streams(
        &quic_conn,
        &h3,
        runtime,
        #[cfg(feature = "webtransport")]
        &dispatcher,
    );

    H3PoolEntry {
        quic_conn,
        h3,
        #[cfg(feature = "webtransport")]
        dispatcher,
    }
}
/// Opens the client's outbound HTTP/3 control stream and drives it on `runtime`.
///
/// The task holds the guard from `h3.swansong().guard()` for its whole lifetime. Presumably
/// this makes graceful shutdown wait on it (confirm against swansong). Failures are only
/// logged at debug level.
fn spawn_outbound_control_stream(conn: &QuicConnection, h3: &Arc<H3Connection>, runtime: &Runtime) {
    let conn = conn.clone();
    let h3 = Arc::clone(h3);
    runtime.spawn(async move {
        let _guard = h3.swansong().guard();
        let outcome: Result<(), H3Error> = async {
            let stream = conn.open_uni().await?.1;
            h3.run_outbound_control(stream).await
        }
        .await;
        if let Err(error) = outcome {
            log::debug!("client H3 control stream error: {error}");
        }
    });
}
/// Opens the client's outbound QPACK encoder stream and drives it with `run_encoder` on
/// `runtime`.
///
/// Failures are only logged at debug level.
fn spawn_qpack_encoder_stream(conn: &QuicConnection, h3: &Arc<H3Connection>, runtime: &Runtime) {
    let conn = conn.clone();
    let h3 = Arc::clone(h3);
    runtime.spawn(async move {
        let outcome: Result<(), H3Error> = async {
            let stream = conn.open_uni().await?.1;
            h3.run_encoder(stream).await
        }
        .await;
        if let Err(error) = outcome {
            log::debug!("client H3 qpack encoder error: {error}");
        }
    });
}
/// Opens the client's outbound QPACK decoder stream and drives it with `run_decoder` on
/// `runtime`.
///
/// Failures are only logged at debug level.
fn spawn_qpack_decoder_stream(conn: &QuicConnection, h3: &Arc<H3Connection>, runtime: &Runtime) {
    let conn = conn.clone();
    let h3 = Arc::clone(h3);
    runtime.spawn(async move {
        let outcome: Result<(), H3Error> = async {
            let stream = conn.open_uni().await?.1;
            h3.run_decoder(stream).await
        }
        .await;
        if let Err(error) = outcome {
            log::debug!("client H3 qpack decoder error: {error}");
        }
    });
}
/// Spawns a task that accepts server-initiated bidirectional streams on `conn`.
///
/// Each accepted stream is processed by `H3Connection::process_inbound_bidi` on its own task:
/// - WebTransport streams are forwarded to the dispatcher when the `webtransport` feature is
///   enabled and the dispatcher has been set. Otherwise they are stopped and reset with
///   `H3_STREAM_CREATION_ERROR`.
/// - Any other (request) stream is a connection error of type `H3_STREAM_CREATION_ERROR`
///   (RFC 9114 §6.1). The QUIC connection is closed with that code and the H3 connection is
///   shut down.
/// - Processing errors are logged at debug level.
///
/// The accept loop ends when `accept_bidi` returns an error.
fn spawn_inbound_bidi_streams(
    conn: &QuicConnection,
    h3: &Arc<H3Connection>,
    runtime: &Runtime,
    #[cfg(feature = "webtransport")] dispatcher: &Arc<OnceLock<WebTransportDispatcher>>,
) {
    let (conn, h3, runtime) = (conn.clone(), h3.clone(), runtime.clone());
    #[cfg(feature = "webtransport")]
    let dispatcher = dispatcher.clone();
    runtime.clone().spawn(async move {
        while let Ok((stream_id, transport)) = conn.accept_bidi().await {
            #[cfg(feature = "webtransport")]
            let dispatcher = dispatcher.clone();
            // Per-stream handles. The connection handle is needed to close it on a protocol
            // violation.
            let (conn_for_error, h3) = (conn.clone(), h3.clone());
            runtime.spawn(async move {
                let result = h3
                    .clone()
                    .process_inbound_bidi(transport, |conn| async move { conn }, stream_id)
                    .await;
                match result {
                    Ok(H3StreamResult::WebTransport {
                        session_id,
                        mut transport,
                        buffer,
                    }) => {
                        #[cfg(feature = "webtransport")]
                        if let Some(dispatcher) = dispatcher.get() {
                            dispatcher.dispatch(WebTransportStream::Bidi {
                                session_id,
                                stream: transport,
                                buffer: buffer.into(),
                            });
                            return;
                        }
                        // Keeps these bindings "used" when the webtransport feature is off.
                        let _ = (session_id, &buffer);
                        log::debug!(
                            "inbound WT bidi stream before any WT session opened on this \
                             connection, rejecting"
                        );
                        transport.stop(H3ErrorCode::StreamCreationError.into());
                        transport.reset(H3ErrorCode::StreamCreationError.into());
                    }
                    Ok(H3StreamResult::Request(_)) => {
                        log::warn!(
                            "server opened a request bidi stream to client (RFC 9114 §6.1 \
                             violation)"
                        );
                        // RFC 9114 §6.1: clients MUST treat receipt of a server-initiated
                        // bidirectional stream as a connection error of type
                        // H3_STREAM_CREATION_ERROR unless an extension permits it. Only
                        // WebTransport is handled above, so close the connection.
                        let code = H3ErrorCode::StreamCreationError;
                        conn_for_error.close(code.into(), code.reason().as_bytes());
                        h3.shut_down().await;
                    }
                    Err(error) => {
                        log::debug!("client H3 inbound bidi stream error: {error}");
                    }
                }
            });
        }
    });
}
fn spawn_inbound_uni_streams(
conn: &QuicConnection,
h3: &Arc<H3Connection>,
runtime: &Runtime,
#[cfg(feature = "webtransport")] dispatcher: &Arc<OnceLock<WebTransportDispatcher>>,
) {
let (conn, h3, runtime) = (conn.clone(), h3.clone(), runtime.clone());
#[cfg(feature = "webtransport")]
let dispatcher = dispatcher.clone();
runtime.clone().spawn(async move {
while let Ok((_stream_id, recv)) = conn.accept_uni().await {
let (conn_for_error, h3) = (conn.clone(), h3.clone());
#[cfg(feature = "webtransport")]
let dispatcher = dispatcher.clone();
runtime.spawn(async move {
match h3.process_inbound_uni(recv).await {
Ok(UniStreamResult::Handled) => {}
Ok(UniStreamResult::WebTransport {
session_id,
mut stream,
buffer,
}) => {
#[cfg(feature = "webtransport")]
if let Some(dispatcher) = dispatcher.get() {
dispatcher.dispatch(WebTransportStream::Uni {
session_id,
stream,
buffer: buffer.into(),
});
return;
}
let _ = (session_id, &buffer);
log::debug!(
"inbound WT uni stream before any WT session opened on this \
connection, rejecting"
);
stream.stop(H3ErrorCode::StreamCreationError.into());
}
Ok(UniStreamResult::Unknown { mut stream, .. }) => {
log::debug!("ignoring unknown inbound uni stream");
stream.stop(H3ErrorCode::StreamCreationError.into());
}
Err(error) => {
log::debug!("client H3 inbound uni stream error: {error}");
if let H3Error::Protocol(code) = error {
conn_for_error.close(code.into(), code.reason().as_bytes());
}
h3.shut_down().await;
}
}
});
}
});
}