use rcgen::{CertifiedKey, generate_simple_self_signed};
use rustls::pki_types::CertificateDer;
use std::{future::Future, io, sync::Arc};
use trillium_http::{
Conn, HttpContext, Upgrade,
h3::{H3Connection, H3Error, H3ErrorCode, H3StreamResult, UniStreamResult},
};
use trillium_quinn::{QuicConfig, QuinnTransport};
use trillium_server_common::{
Info, QuicConfig as QuicConfigTrait, QuicConnectionTrait, QuicEndpoint, QuicTransportReceive,
QuicTransportSend, Server,
};
use trillium_smol::{SmolRuntime, SmolTransport, SmolUdpSocket, async_net::TcpStream};
use trillium_testing::{Runtime, runtime};
/// Uninhabited marker type: it has no variants, so no value of it can ever exist.
///
/// It is only named as the `Server` type parameter when binding the QUIC endpoint
/// (`<QuicConfig as QuicConfigTrait<FixtureServer>>::bind`).
enum FixtureServer {}
/// `Server` implementation that exists purely to name the smol-based associated types.
///
/// `FixtureServer` has no variants, so none of these methods can produce or receive a
/// real value: the constructors panic with `unreachable!` and `accept` is statically
/// impossible to call.
impl Server for FixtureServer {
    type Runtime = SmolRuntime;
    type Transport = SmolTransport<TcpStream>;
    type UdpTransport = SmolUdpSocket;

    /// Never returns: a `FixtureServer` value cannot be constructed.
    fn from_host_and_port(_host: &str, _port: u16) -> Self {
        unreachable!("FixtureServer is uninhabited")
    }

    /// Never returns: this fixture type is not meant to supply a runtime.
    fn runtime() -> Self::Runtime {
        unreachable!("FixtureServer is uninhabited")
    }

    /// Cannot be called, because no `&mut FixtureServer` can exist.
    async fn accept(&mut self) -> io::Result<Self::Transport> {
        // The enum has no variants, so the empty match is exhaustive.
        match *self {}
    }
}
/// An HTTP/3 server fixture listening on an OS-assigned localhost port, secured with a
/// self-signed certificate generated when the server is created.
#[derive(Debug)]
pub struct H3Server {
/// Base URL of the server, of the form `https://localhost:{port}/`.
base_url: String,
/// DER encoding of the self-signed certificate the server presents.
cert_der: CertificateDer<'static>,
/// HTTP context shared with the accept loop and every `H3Connection`; also used to
/// shut the server down.
context: Arc<HttpContext>,
}
impl H3Server {
    /// Starts an HTTP/3 fixture server whose requests are answered by `handler`.
    ///
    /// This is [`H3Server::with_upgrade`] with an upgrade handler that does nothing.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`H3Server::with_upgrade`].
    pub async fn new<H, HFut>(handler: H) -> Self
    where
        H: Fn(Conn<QuinnTransport>) -> HFut + Send + Sync + 'static,
        HFut: Future<Output = Conn<QuinnTransport>> + Send + 'static,
    {
        Self::with_upgrade(handler, noop_upgrade).await
    }

    /// Starts an HTTP/3 fixture server with both a request handler and an upgrade handler.
    ///
    /// A self-signed certificate for `localhost` is generated, a QUIC endpoint is bound
    /// to `localhost` on an OS-assigned port, and a background task is spawned that
    /// accepts connections and runs each one through `run_connection`.
    ///
    /// * `handler` is invoked for every request `Conn`.
    /// * `upgrade_handler` is invoked for every `Conn` that reports `should_upgrade()`.
    ///
    /// # Panics
    ///
    /// Panics if certificate generation fails, if `localhost` does not resolve to at
    /// least one address, if binding the QUIC endpoint returns `None` or an error, or if
    /// the endpoint's local address cannot be read.
    pub async fn with_upgrade<H, HFut, U, UFut>(handler: H, upgrade_handler: U) -> Self
    where
        H: Fn(Conn<QuinnTransport>) -> HFut + Send + Sync + 'static,
        HFut: Future<Output = Conn<QuinnTransport>> + Send + 'static,
        U: Fn(Upgrade<QuinnTransport>) -> UFut + Send + Sync + 'static,
        UFut: Future<Output = ()> + Send + 'static,
    {
        // Throwaway self-signed certificate for `localhost`; its DER form is kept so it
        // can be handed out through `cert_der()`.
        let CertifiedKey { cert, signing_key } =
            generate_simple_self_signed(vec!["localhost".to_string()])
                .expect("generate self-signed certificate");
        let cert_pem = cert.pem().into_bytes();
        let key_pem = signing_key.serialize_pem().into_bytes();
        let cert_der = cert.der().clone();
        let quic = QuicConfig::from_single_cert(&cert_pem, &key_pem);

        let rt = SmolRuntime::default();
        let mut info = Info::default();
        // Port 0 lets the OS choose a free port; the actual port is read back from the
        // bound endpoint below.
        let resolved: std::net::SocketAddr = {
            use std::net::ToSocketAddrs;
            ("localhost", 0)
                .to_socket_addrs()
                .expect("resolve localhost")
                .next()
                .expect("at least one address for localhost")
        };
        let endpoint =
            <QuicConfig as QuicConfigTrait<FixtureServer>>::bind(quic, resolved, rt, &mut info)
                .expect("QuicConfig::bind returned None")
                .expect("QuicConfig::bind failed");
        let addr = endpoint.local_addr().expect("local_addr");
        let base_url = format!("https://localhost:{}/", addr.port());

        let context = Arc::new(HttpContext::new());
        let handler = Arc::new(handler);
        let upgrade_handler = Arc::new(upgrade_handler);
        let trillium_rt: Runtime = runtime().into();
        let endpoint = Arc::new(endpoint);
        let swansong = context.swansong().clone();
        let context_for_loop = context.clone();
        let trillium_rt_for_loop = trillium_rt.clone();

        // Accept loop: runs until the swansong interrupts it or the endpoint stops
        // yielding connections. Each connection gets its own task and `H3Connection`.
        trillium_rt.spawn(async move {
            while let Some(connection) = swansong.interrupt(endpoint.accept()).await.flatten() {
                let h3 = H3Connection::new(context_for_loop.clone());
                // `spawn` only borrows the runtime, so a single clone (the one handed
                // to the connection task) is enough.
                trillium_rt_for_loop.spawn(run_connection(
                    Arc::new(connection),
                    h3,
                    handler.clone(),
                    upgrade_handler.clone(),
                    trillium_rt_for_loop.clone(),
                ));
            }
        });

        Self {
            base_url,
            cert_der,
            context,
        }
    }

    /// Returns the server's base URL, of the form `https://localhost:{port}/`.
    pub fn base_url(&self) -> &str {
        &self.base_url
    }

    /// Returns the DER encoding of the self-signed certificate generated for this server.
    pub fn cert_der(&self) -> &CertificateDer<'static> {
        &self.cert_der
    }

    /// Shuts down the server's `HttpContext` and waits for that shutdown to complete.
    ///
    /// NOTE(review): the accept loop is interrupted through the context's swansong, so
    /// this presumably stops it as well — confirm against `HttpContext::shut_down`.
    pub async fn shut_down(self) {
        self.context.shut_down().await;
    }
}
/// Upgrade handler that ignores the upgrade and returns immediately.
///
/// `H3Server::new` passes it to `H3Server::with_upgrade` as the upgrade handler.
async fn noop_upgrade<T>(_: Upgrade<T>) {}
/// Runs one accepted QUIC connection.
///
/// Spawns the four background stream tasks on `rt` (outbound control, QPACK encoder,
/// QPACK decoder, inbound unidirectional streams) and then stays in the bidirectional
/// accept loop until that loop returns.
///
/// * `connection` - the accepted QUIC connection.
/// * `h3` - per-connection HTTP/3 state shared by every task spawned here.
/// * `handler` - request handler, forwarded to `accept_bidi_loop`.
/// * `upgrade_handler` - upgrade handler, forwarded to `accept_bidi_loop`.
/// * `rt` - runtime used to spawn the background tasks.
async fn run_connection<C, H, HFut, U, UFut>(
connection: Arc<C>,
h3: Arc<H3Connection>,
handler: Arc<H>,
upgrade_handler: Arc<U>,
rt: Runtime,
) where
C: QuicConnectionTrait + 'static,
H: Fn(Conn<C::BidiStream>) -> HFut + Send + Sync + 'static,
HFut: Future<Output = Conn<C::BidiStream>> + Send + 'static,
U: Fn(Upgrade<C::BidiStream>) -> UFut + Send + Sync + 'static,
UFut: Future<Output = ()> + Send + 'static,
{
// The helpers below only spawn tasks; none of them is awaited here.
spawn_outbound_control(&connection, &h3, &rt);
spawn_qpack_encoder(&connection, &h3, &rt);
spawn_qpack_decoder(&connection, &h3, &rt);
spawn_inbound_uni(&connection, &h3, &rt);
// Request streams are accepted on this task until the loop ends.
accept_bidi_loop(connection, h3, handler, upgrade_handler, rt).await;
}
async fn accept_bidi_loop<C, H, HFut, U, UFut>(
connection: Arc<C>,
h3: Arc<H3Connection>,
handler: Arc<H>,
upgrade_handler: Arc<U>,
rt: Runtime,
) where
C: QuicConnectionTrait + 'static,
H: Fn(Conn<C::BidiStream>) -> HFut + Send + Sync + 'static,
HFut: Future<Output = Conn<C::BidiStream>> + Send + 'static,
U: Fn(Upgrade<C::BidiStream>) -> UFut + Send + Sync + 'static,
UFut: Future<Output = ()> + Send + 'static,
{
loop {
match h3.swansong().interrupt(connection.accept_bidi()).await {
None => break,
Some(Err(_)) => break,
Some(Ok((stream_id, transport))) => {
let connection = connection.clone();
let h3 = h3.clone();
let handler = handler.clone();
let upgrade_handler = upgrade_handler.clone();
rt.spawn(async move {
let handler_fn = {
let handler = handler.clone();
move |conn: Conn<_>| {
let handler = handler.clone();
async move { handler(conn).await }
}
};
let result = h3
.clone()
.process_inbound_bidi_with_reset(
transport,
handler_fn,
stream_id,
|t, code| {
let raw = u64::from(code);
t.stop(raw);
t.reset(raw);
},
)
.await;
match result {
Ok(H3StreamResult::Request(conn)) if conn.should_upgrade() => {
upgrade_handler(Upgrade::from(conn)).await;
}
Ok(H3StreamResult::Request(_)) => {}
Ok(H3StreamResult::WebTransport { mut transport, .. }) => {
transport.stop(H3ErrorCode::StreamCreationError.into());
transport.reset(H3ErrorCode::StreamCreationError.into());
}
Err(error) => handle_error(error, &*connection, &h3),
}
});
}
}
}
h3.shut_down();
}
fn spawn_inbound_uni<C>(connection: &Arc<C>, h3: &Arc<H3Connection>, rt: &Runtime)
where
C: QuicConnectionTrait + 'static,
{
let (connection, h3, rt) = (connection.clone(), h3.clone(), rt.clone());
rt.clone().spawn(async move {
while let Some(Ok((_id, recv))) = h3.swansong().interrupt(connection.accept_uni()).await {
let (connection, h3) = (connection.clone(), h3.clone());
rt.spawn(async move {
let close_connection = {
let connection = connection.clone();
let h3 = h3.clone();
move |code: H3ErrorCode| {
connection.close(code.into(), code.reason().as_bytes());
h3.shut_down();
}
};
let result = h3
.process_inbound_uni_with_close(recv, close_connection)
.await;
match result {
Ok(UniStreamResult::Handled) => {}
Ok(UniStreamResult::WebTransport { mut stream, .. }) => {
stream.stop(H3ErrorCode::StreamCreationError.into());
}
Ok(UniStreamResult::Unknown { mut stream, .. }) => {
stream.stop(H3ErrorCode::StreamCreationError.into());
}
Err(error) => handle_error(error, &*connection, &h3),
}
});
}
h3.shut_down();
});
}
/// Spawns a task on `rt` that opens a unidirectional stream on `connection` and runs
/// `H3Connection::run_outbound_control` on it.
///
/// A failure to open the stream is ignored. An error returned by
/// `run_outbound_control` is passed to `handle_error`. In every case `h3.shut_down()`
/// is called before the task ends.
fn spawn_outbound_control<C>(connection: &Arc<C>, h3: &Arc<H3Connection>, rt: &Runtime)
where
    C: QuicConnectionTrait + 'static,
{
    let connection = Arc::clone(connection);
    let h3 = Arc::clone(h3);

    rt.spawn(async move {
        // `Some` only when the stream opened and running it returned an error.
        let failure = match connection.open_uni().await {
            Ok((_id, stream)) => h3.run_outbound_control(stream).await.err(),
            Err(_) => None,
        };
        if let Some(error) = failure {
            handle_error(error, &*connection, &h3);
        }
        h3.shut_down();
    });
}
/// Spawns a task on `rt` that opens a unidirectional stream on `connection` and runs
/// `H3Connection::run_encoder` on it.
///
/// A failure to open the stream is ignored. An error returned by `run_encoder` is
/// passed to `handle_error`. In every case `h3.shut_down()` is called before the task
/// ends.
fn spawn_qpack_encoder<C>(connection: &Arc<C>, h3: &Arc<H3Connection>, rt: &Runtime)
where
    C: QuicConnectionTrait + 'static,
{
    let connection = Arc::clone(connection);
    let h3 = Arc::clone(h3);

    rt.spawn(async move {
        // `Some` only when the stream opened and running it returned an error.
        let failure = match connection.open_uni().await {
            Ok((_id, stream)) => h3.run_encoder(stream).await.err(),
            Err(_) => None,
        };
        if let Some(error) = failure {
            handle_error(error, &*connection, &h3);
        }
        h3.shut_down();
    });
}
/// Spawns a task on `rt` that opens a unidirectional stream on `connection` and runs
/// `H3Connection::run_decoder` on it.
///
/// A failure to open the stream is ignored. An error returned by `run_decoder` is
/// passed to `handle_error`. In every case `h3.shut_down()` is called before the task
/// ends.
fn spawn_qpack_decoder<C>(connection: &Arc<C>, h3: &Arc<H3Connection>, rt: &Runtime)
where
    C: QuicConnectionTrait + 'static,
{
    let connection = Arc::clone(connection);
    let h3 = Arc::clone(h3);

    rt.spawn(async move {
        // `Some` only when the stream opened and running it returned an error.
        let failure = match connection.open_uni().await {
            Ok((_id, stream)) => h3.run_decoder(stream).await.err(),
            Err(_) => None,
        };
        if let Some(error) = failure {
            handle_error(error, &*connection, &h3);
        }
        h3.shut_down();
    });
}
/// Reacts to an HTTP/3 error raised while processing a stream.
///
/// Only `H3Error::Protocol` errors whose code reports `is_connection_error()` have an
/// effect: the QUIC connection is closed with that code and its reason text, and `h3`
/// is shut down. Every other error is dropped without further action.
fn handle_error<C: QuicConnectionTrait>(error: H3Error, connection: &C, h3: &H3Connection) {
    match error {
        H3Error::Protocol(code) if code.is_connection_error() => {
            connection.close(code.into(), code.reason().as_bytes());
            h3.shut_down();
        }
        _ => {}
    }
}