use crate::{Pool, h3::alt_svc::DEFAULT_BROKEN_DURATION, pool::PoolEntry};
use alt_svc::AltSvcCache;
use std::{
io::{self, ErrorKind},
net::SocketAddr,
sync::{Arc, OnceLock},
time::Duration,
};
use trillium_http::{
HttpContext,
h3::{H3Connection, H3Error, H3ErrorCode, UniStreamResult},
};
use trillium_server_common::{
ArcedConnector, ArcedQuicClientConfig, ArcedQuicEndpoint, Connector, QuicConnection, Runtime,
url::{Origin, Url},
};
mod alt_svc;
#[derive(Clone, Debug)]
pub(crate) struct H3ClientState {
pub(crate) config: ArcedQuicClientConfig,
pub(crate) pool: Pool<Origin, QuicConnection>,
pub(crate) alt_svc: AltSvcCache,
pub(crate) broken_duration: Duration,
endpoint_v4: OnceLockEndpoint,
endpoint_v6: OnceLockEndpoint,
}
#[derive(Clone, Debug)]
struct OnceLockEndpoint(Arc<OnceLock<ArcedQuicEndpoint>>);
impl Default for OnceLockEndpoint {
fn default() -> Self {
Self(Arc::new(OnceLock::new()))
}
}
impl H3ClientState {
pub(crate) fn update_alt_svc(&self, alt_svc: &str, url: &Url) {
self.alt_svc.update(alt_svc, url);
}
pub(crate) fn mark_broken(&self, origin: &Origin) {
if let Some(mut entry) = self.alt_svc.get_mut(origin) {
entry.mark_broken(self.broken_duration);
}
}
pub(crate) async fn get_or_create_quic_conn(
&self,
origin: &Origin,
host: &str,
port: u16,
connector: &ArcedConnector,
context: &Arc<HttpContext>,
) -> io::Result<QuicConnection> {
if let Some(conn) = self.pool.peek_candidate(origin) {
return Ok(conn);
}
let addr = *connector
.resolve(host, port)
.await?
.first()
.ok_or_else(|| io::Error::new(ErrorKind::NotFound, "no addresses resolved for host"))?;
let endpoint = self.endpoint_for(addr)?;
let conn = endpoint.connect(addr, host).await?;
setup_h3_connection(&conn, context, &connector.runtime());
self.pool
.insert(origin.clone(), PoolEntry::new(conn.clone(), None));
Ok(conn)
}
fn endpoint_for(&self, addr: SocketAddr) -> io::Result<ArcedQuicEndpoint> {
let (holder, bind_addr) = if addr.is_ipv6() {
(&self.endpoint_v6, "[::]:0")
} else {
(&self.endpoint_v4, "0.0.0.0:0")
};
if let Some(ep) = holder.0.get() {
return Ok(ep.clone());
}
let bind_addr: SocketAddr = bind_addr
.parse()
.map_err(|e| io::Error::new(ErrorKind::InvalidInput, e))?;
let ep = self.config.bind(bind_addr)?;
let _ = holder.0.set(ep);
Ok(holder.0.get().unwrap().clone())
}
pub(crate) fn new(config: ArcedQuicClientConfig) -> Self {
Self {
config,
pool: Default::default(),
alt_svc: AltSvcCache::default(),
broken_duration: DEFAULT_BROKEN_DURATION,
endpoint_v4: OnceLockEndpoint::default(),
endpoint_v6: OnceLockEndpoint::default(),
}
}
}
fn setup_h3_connection(quic_conn: &QuicConnection, context: &Arc<HttpContext>, runtime: &Runtime) {
let h3 = H3Connection::new(context.clone());
spawn_outbound_control_stream(quic_conn, &h3, runtime);
spawn_qpack_encoder_stream(quic_conn, &h3, runtime);
spawn_qpack_decoder_stream(quic_conn, &h3, runtime);
spawn_inbound_uni_streams(quic_conn, &h3, runtime);
spawn_reject_inbound_bidi_streams(quic_conn, runtime);
}
fn spawn_outbound_control_stream(conn: &QuicConnection, h3: &Arc<H3Connection>, runtime: &Runtime) {
let (conn, h3) = (conn.clone(), h3.clone());
runtime.spawn(async move {
let _guard = h3.swansong().guard();
let result: Result<(), H3Error> =
async { h3.run_outbound_control(conn.open_uni().await?.1).await }.await;
if let Err(error) = result {
log::debug!("client H3 control stream error: {error}");
}
});
}
fn spawn_qpack_encoder_stream(conn: &QuicConnection, h3: &Arc<H3Connection>, runtime: &Runtime) {
let (conn, h3) = (conn.clone(), h3.clone());
runtime.spawn(async move {
let result: Result<(), H3Error> =
async { h3.run_encoder(conn.open_uni().await?.1).await }.await;
if let Err(error) = result {
log::debug!("client H3 qpack encoder error: {error}");
}
});
}
fn spawn_qpack_decoder_stream(conn: &QuicConnection, h3: &Arc<H3Connection>, runtime: &Runtime) {
let (conn, h3) = (conn.clone(), h3.clone());
runtime.spawn(async move {
let result: Result<(), H3Error> =
async { h3.run_decoder(conn.open_uni().await?.1).await }.await;
if let Err(error) = result {
log::debug!("client H3 qpack decoder error: {error}");
}
});
}
fn spawn_reject_inbound_bidi_streams(conn: &QuicConnection, runtime: &Runtime) {
let conn = conn.clone();
runtime.spawn(async move {
while let Ok((_stream_id, mut transport)) = conn.accept_bidi().await {
log::debug!("rejecting unexpected inbound bidi stream from server");
let code = H3ErrorCode::StreamCreationError;
transport.stop(code.into());
transport.reset(code.into());
}
});
}
fn spawn_inbound_uni_streams(conn: &QuicConnection, h3: &Arc<H3Connection>, runtime: &Runtime) {
let (conn, h3, runtime) = (conn.clone(), h3.clone(), runtime.clone());
runtime.clone().spawn(async move {
while let Ok((_stream_id, recv)) = conn.accept_uni().await {
let (conn_for_error, h3) = (conn.clone(), h3.clone());
runtime.spawn(async move {
match h3.process_inbound_uni(recv).await {
Ok(UniStreamResult::Handled) => {}
Ok(
UniStreamResult::WebTransport { mut stream, .. }
| UniStreamResult::Unknown { mut stream, .. },
) => {
log::debug!("ignoring unexpected inbound uni stream");
stream.stop(H3ErrorCode::StreamCreationError.into());
}
Err(error) => {
log::debug!("client H3 inbound uni stream error: {error}");
if let H3Error::Protocol(code) = error {
conn_for_error.close(code.into(), code.reason().as_bytes());
}
h3.shut_down().await;
}
}
});
}
});
}