// Runs the README's code examples as doctests. Rustdoc sets `cfg(doctest)` — not
// `cfg(test)` — while collecting doctests, so gating this module on `test` would strip it
// before rustdoc ever saw the included README and its examples would never be exercised.
#[cfg(doctest)]
#[doc = include_str!("../README.md")]
mod readme {}
mod session_router;
mod stream;
use crate::session_router::Router;
pub use crate::stream::{
Datagram, InboundBidiStream, InboundStream, InboundUniStream, OutboundBidiStream,
OutboundUniStream,
};
use async_channel::Receiver;
use futures_lite::AsyncWriteExt;
use std::{
io,
sync::{Arc, OnceLock},
};
use swansong::Swansong;
use trillium::{Conn, Handler, Info, Method, Status, Transport, Upgrade};
use trillium_http::h3::{H3Connection, quic_varint};
use trillium_server_common::{
QuicConnection, Runtime,
h3::{
StreamId,
web_transport::{WebTransportDispatcher, WebTransportStream},
},
};
/// One accepted WebTransport session, handed to a [`WebTransportHandler`].
///
/// Inbound streams and datagrams for this session arrive on the receivers below, which are
/// obtained by registering the session with the connection's router. Outbound streams and
/// datagrams are written directly to the QUIC connection.
pub struct WebTransportConnection {
    /// Session id, converted from the `StreamId` of the upgraded CONNECT request. It is
    /// written into outbound stream headers and, divided by four, prefixes outbound datagrams.
    session_id: u64,
    /// Peer-initiated bidirectional streams routed to this session.
    bidi_rx: Receiver<InboundBidiStream>,
    /// Peer-initiated unidirectional streams routed to this session.
    uni_rx: Receiver<InboundUniStream>,
    /// Datagrams routed to this session.
    datagram_rx: Receiver<Datagram>,
    /// Child of the H3 connection's swansong. It interrupts the accept/recv methods when
    /// shutdown begins.
    swansong: Swansong,
    /// The upgraded CONNECT request that established this session.
    upgrade: Upgrade,
    /// The HTTP/3 connection this session runs on.
    h3_connection: Arc<H3Connection>,
    /// Used to open outbound streams and to send datagrams.
    quic_connection: QuicConnection,
    /// Runtime the session was started on.
    runtime: Runtime,
}
impl WebTransportConnection {
    /// Returns the runtime this session was started on.
    pub fn runtime(&self) -> &Runtime {
        &self.runtime
    }

    /// Returns the HTTP/3 connection that carries this session.
    pub fn h3_connection(&self) -> &H3Connection {
        &self.h3_connection
    }

    /// Returns the upgraded CONNECT request that established this session.
    pub fn upgrade(&self) -> &Upgrade {
        &self.upgrade
    }

    /// Mutable counterpart of [`upgrade`](Self::upgrade).
    pub fn upgrade_mut(&mut self) -> &mut Upgrade {
        &mut self.upgrade
    }

    /// Waits for the next bidirectional stream the peer opens on this session.
    ///
    /// Resolves to `None` if the session's swansong interrupts the wait or if the stream
    /// channel has closed.
    pub async fn accept_bidi(&self) -> Option<InboundBidiStream> {
        let received = self.swansong.interrupt(self.bidi_rx.recv()).await?;
        received.ok()
    }

    /// Waits for the next unidirectional stream the peer opens on this session.
    ///
    /// Resolves to `None` if the session's swansong interrupts the wait or if the stream
    /// channel has closed.
    pub async fn accept_uni(&self) -> Option<InboundUniStream> {
        let received = self.swansong.interrupt(self.uni_rx.recv()).await?;
        received.ok()
    }

    /// Waits for the next datagram routed to this session.
    ///
    /// Resolves to `None` if the session's swansong interrupts the wait or if the datagram
    /// channel has closed.
    pub async fn recv_datagram(&self) -> Option<Datagram> {
        let received = self.swansong.interrupt(self.datagram_rx.recv()).await?;
        received.ok()
    }

    /// Waits for the next inbound stream in either direction.
    ///
    /// The bidirectional and unidirectional accepts are raced. The first one to resolve
    /// decides the result, even if it resolves to `None`.
    pub async fn accept_next_stream(&self) -> Option<InboundStream> {
        let next_bidi = async { self.accept_bidi().await.map(InboundStream::Bidi) };
        let next_uni = async { self.accept_uni().await.map(InboundStream::Uni) };
        futures_lite::future::race(next_bidi, next_uni).await
    }

    /// Opens a bidirectional stream to the peer and writes the header that binds it to this
    /// session.
    ///
    /// # Errors
    ///
    /// Returns any I/O error from opening the stream or from writing the header.
    pub async fn open_bidi(&self) -> io::Result<OutboundBidiStream> {
        let (_opened_id, mut bidi) = self.quic_connection.open_bidi().await?;
        let header = wt_bidi_header(self.session_id);
        bidi.write_all(&header).await?;
        Ok(OutboundBidiStream::new(bidi))
    }

    /// Opens a unidirectional stream to the peer and writes the header that binds it to this
    /// session.
    ///
    /// # Errors
    ///
    /// Returns any I/O error from opening the stream or from writing the header.
    pub async fn open_uni(&self) -> io::Result<OutboundUniStream> {
        let (_opened_id, mut uni) = self.quic_connection.open_uni().await?;
        let header = wt_uni_header(self.session_id);
        uni.write_all(&header).await?;
        Ok(OutboundUniStream::new(uni))
    }

    /// Sends `payload` as a datagram on this session.
    ///
    /// The payload is prefixed with `session_id / 4` (the quarter stream id), encoded as a
    /// QUIC varint.
    ///
    /// # Errors
    ///
    /// Returns whatever error the QUIC connection reports when sending the datagram.
    pub fn send_datagram(&self, payload: &[u8]) -> io::Result<()> {
        let quarter_stream_id = self.session_id / 4;
        let prefix_len = quic_varint::encoded_len(quarter_stream_id);
        let mut datagram = vec![0u8; prefix_len + payload.len()];
        quic_varint::encode(quarter_stream_id, &mut datagram).unwrap();
        datagram[prefix_len..].copy_from_slice(payload);
        self.quic_connection.send_datagram(&datagram)
    }
}
/// The next event handled by the per-connection routing loop spawned in `upgrade`.
enum RoutingAction {
    /// Raw bytes of an inbound QUIC datagram, passed to the router's `route_datagram`.
    Datagram(Vec<u8>),
    /// A WebTransport stream received from the dispatcher, passed to the router's `route`.
    Stream(WebTransportStream),
}
/// Signal value that opens a WebTransport bidirectional stream (`WEBTRANSPORT_STREAM`,
/// 0x41, in WebTransport over HTTP/3).
const WT_BIDI_SIGNAL_VALUE: u64 = 0x41;

/// Unidirectional stream type of a WebTransport stream (`WEBTRANSPORT_STREAM`, 0x54, in
/// WebTransport over HTTP/3).
const WT_UNI_STREAM_TYPE: u64 = 0x54;

/// Header written at the start of an outbound WebTransport bidirectional stream: the 0x41
/// signal value followed by `session_id`, each encoded as a QUIC varint.
fn wt_bidi_header(session_id: u64) -> Vec<u8> {
    wt_stream_header(WT_BIDI_SIGNAL_VALUE, session_id)
}

/// Header written at the start of an outbound WebTransport unidirectional stream: the 0x54
/// stream type followed by `session_id`, each encoded as a QUIC varint.
fn wt_uni_header(session_id: u64) -> Vec<u8> {
    wt_stream_header(WT_UNI_STREAM_TYPE, session_id)
}

/// Encodes `stream_type` and then `session_id` as consecutive QUIC varints.
///
/// # Panics
///
/// Panics if `quic_varint::encode` fails. The buffer is sized with
/// `quic_varint::encoded_len`, so this presumably only happens for values outside the varint
/// range.
fn wt_stream_header(stream_type: u64, session_id: u64) -> Vec<u8> {
    let mut buf =
        vec![0u8; quic_varint::encoded_len(stream_type) + quic_varint::encoded_len(session_id)];
    let mut offset = quic_varint::encode(stream_type, &mut buf)
        .expect("buffer sized with quic_varint::encoded_len");
    offset += quic_varint::encode(session_id, &mut buf[offset..])
        .expect("buffer sized with quic_varint::encoded_len");
    // Defensive: keep only the bytes that were actually written.
    buf.truncate(offset);
    buf
}
/// Default value for `WebTransport::with_max_datagram_buffer`.
// NOTE(review): forwarded to `Router::new`; presumably the number of inbound datagrams
// buffered per session — confirm there.
const DEFAULT_MAX_DATAGRAM_BUFFER: usize = 16;
/// Trillium handler that accepts WebTransport sessions on HTTP/3 CONNECT requests and runs
/// `handler` once for each session.
pub struct WebTransport<H> {
    /// Runtime captured from the server's shared state in `Handler::init`. It is used to
    /// spawn the per-connection routing task.
    runtime: OnceLock<Runtime>,
    /// Size passed to `Router::new` when a connection's router is first created.
    max_datagram_buffer: usize,
    /// User handler run once per accepted session.
    handler: H,
}
/// Per-session logic for [`WebTransport`].
///
/// `run` is called once for each accepted [`WebTransportConnection`]. After the returned
/// future completes, the session is shut down and unregistered from the router. Any
/// `Fn(WebTransportConnection) -> Fut` whose `Fut` is a `Send` future returning `()`
/// implements this trait automatically.
pub trait WebTransportHandler: Send + Sync + 'static {
    /// Drives a single WebTransport session to completion.
    fn run(&self, connection: WebTransportConnection) -> impl Future<Output = ()> + Send;
}
/// Lets plain closures and functions serve as a [`WebTransportHandler`]. Each session calls
/// the function and awaits the future it returns.
impl<Fun, Fut> WebTransportHandler for Fun
where
    Fut: Future<Output = ()> + Send,
    Fun: Fn(WebTransportConnection) -> Fut + Send + Sync + 'static,
{
    async fn run(&self, connection: WebTransportConnection) {
        let session = self(connection);
        session.await;
    }
}
impl<H> WebTransport<H>
where
H: WebTransportHandler,
{
pub fn new(handler: H) -> Self {
Self {
handler,
runtime: Default::default(),
max_datagram_buffer: DEFAULT_MAX_DATAGRAM_BUFFER,
}
}
pub fn with_max_datagram_buffer(mut self, max: usize) -> Self {
self.max_datagram_buffer = max;
self
}
fn runtime(&self) -> &Runtime {
self.runtime.get().unwrap()
}
}
/// Marker that `WebTransport`'s `Handler::run` puts into the state of accepted CONNECT
/// requests. `has_upgrade` checks for it to claim the upgrade.
struct WTUpgrade;
impl<H> Handler for WebTransport<H>
where
    H: WebTransportHandler,
{
    /// Claims a request as a WebTransport session when it uses the CONNECT method and a
    /// `QuicConnection` is present in the inner conn's state.
    ///
    /// A claimed request is tagged with `WTUpgrade`, given status 200 and halted. Every
    /// other request passes through unchanged.
    ///
    /// NOTE(review): the extended-CONNECT `:protocol` value is not checked, so any CONNECT
    /// over QUIC is claimed here. Confirm this is intended.
    async fn run(&self, conn: Conn) -> Conn {
        // The QUIC connection marker lives in the underlying trillium_http conn's state.
        let inner: &trillium_http::Conn<Box<dyn Transport>> = conn.as_ref();
        if inner.state().contains::<QuicConnection>() && conn.method() == Method::Connect
        {
            conn.with_state(WTUpgrade).with_status(Status::Ok).halt()
        } else {
            conn
        }
    }

    /// Captures the shared `Runtime` from `info`, panicking if none is present. It then
    /// enables HTTP/3 datagrams and WebTransport in the server config.
    async fn init(&mut self, info: &mut Info) {
        // `get_or_init`: if a runtime was already stored, it is kept.
        self.runtime.get_or_init(|| {
            info.shared_state::<Runtime>()
                .cloned()
                .expect("webtransport requires a Runtime")
        });
        info.config_mut()
            .set_h3_datagrams_enabled(true)
            .set_webtransport_enabled(true);
    }

    /// Claims upgrades whose state carries the `WTUpgrade` marker set by `run`.
    fn has_upgrade(&self, upgrade: &Upgrade) -> bool {
        upgrade.state().get::<WTUpgrade>().is_some()
    }

    /// Runs one WebTransport session on an upgraded CONNECT request.
    ///
    /// Steps:
    /// 1. Extract the H3 connection, QUIC connection, stream id and dispatcher from the
    ///    upgrade. If any is missing, log an error and return.
    /// 2. Ensure the connection's routing task is spawned.
    /// 3. Register this session with the router.
    /// 4. Run the user handler to completion.
    /// 5. Shut the session's swansong down and unregister the session.
    async fn upgrade(&self, mut upgrade: Upgrade) {
        let Some(h3_connection) = upgrade.h3_connection() else {
            log::error!("missing H3Connection in upgrade state");
            return;
        };
        // Taken (not cloned) out of the state; it is moved into the session below.
        let Some(quic_connection) = upgrade.state_mut().take::<QuicConnection>() else {
            log::error!("missing QuicConnection in upgrade state");
            return;
        };
        let Some(stream_id) = upgrade.state_mut().take::<StreamId>() else {
            log::error!("missing StreamId in upgrade state");
            return;
        };
        let Some(dispatcher) = upgrade.state().get::<WebTransportDispatcher>().cloned() else {
            log::error!("missing WebTransportDispatcher in upgrade state");
            return;
        };
        let max_datagram_buffer = self.max_datagram_buffer;
        // Presumably the dispatcher holds one router per connection. Only the first caller's
        // closure creates it, with this handler's buffer size. `None` means the dispatcher
        // already holds a handler of a different type.
        let Some(router) = dispatcher.get_or_init_with(|| Router::new(max_datagram_buffer)) else {
            log::error!("WebTransportDispatcher has a handler of an unexpected type");
            return;
        };
        // `take_routing_rx` yields the receiver only while it has not been taken, so
        // presumably just one session per router spawns the routing task below.
        if let Some(routing_rx) = router.take_routing_rx() {
            let router = router.clone();
            let quic = quic_connection.clone();
            self.runtime().clone().spawn(async move {
                loop {
                    // Wait for whichever arrives first: a WebTransport stream from the
                    // dispatcher or an inbound QUIC datagram. The losing future is dropped.
                    let action = futures_lite::future::race(
                        async { routing_rx.recv().await.ok().map(RoutingAction::Stream) },
                        async {
                            let mut data = Vec::new();
                            quic.recv_datagram(|d| data.extend_from_slice(d))
                                .await
                                .ok()
                                .map(|()| RoutingAction::Datagram(data))
                        },
                    )
                    .await;
                    match action {
                        Some(RoutingAction::Stream(stream)) => {
                            router.sessions.lock().await.route(stream);
                        }
                        Some(RoutingAction::Datagram(data)) => {
                            router.sessions.lock().await.route_datagram(&data);
                        }
                        // Either the routing channel closed or receiving a datagram failed.
                        // NOTE(review): a datagram receive error also ends stream routing for
                        // every session on this connection. Confirm it only happens while the
                        // connection is closing.
                        None => break,
                    }
                }
            });
        }
        // Session id converted from the upgraded CONNECT request's `StreamId`.
        let session_id = stream_id.into();
        log::trace!("starting webtransport session {session_id}");
        // Child of the H3 connection's swansong. It interrupts the session's accept/recv
        // calls and can also be shut down on its own.
        let session_swansong = h3_connection.swansong().child();
        // Receivers for this session's inbound bidi streams, uni streams and datagrams.
        let (bidi_rx, uni_rx, datagram_rx) = router.sessions.lock().await.register(session_id);
        let runtime = self.runtime().clone();
        self.handler
            .run(WebTransportConnection {
                session_id,
                bidi_rx,
                uni_rx,
                datagram_rx,
                swansong: session_swansong.clone(),
                upgrade,
                h3_connection,
                quic_connection,
                runtime,
            })
            .await;
        log::trace!("finished handler, cleaning up");
        // Stop anything still waiting on this session, then stop routing to it.
        // NOTE(review): not reached if the handler panics, leaving the session registered.
        session_swansong.shut_down().await;
        router.sessions.lock().await.unregister(session_id);
    }
}