use crate::{Client, Conn, IntoUrl};
use std::{
borrow::Cow,
error::Error,
fmt::{self, Display},
ops::{Deref, DerefMut},
};
use trillium_http::{Method, Status, Version};
use trillium_server_common::h3::web_transport::WebTransportDispatcher;
use trillium_webtransport::{DEFAULT_MAX_DATAGRAM_BUFFER, Router, WebTransportConnection};
impl Client {
    /// Builds a [`Conn`] set up for a WebTransport upgrade against `url`.
    ///
    /// The returned conn uses the `CONNECT` method, is pinned to HTTP/3, and carries the
    /// `webtransport` protocol. It has not been sent yet; hand it to
    /// [`Conn::into_webtransport`] to perform the upgrade.
    pub fn webtransport(&self, url: impl IntoUrl) -> Conn {
        let mut upgrade_conn = self.build_conn(Method::Connect, url);
        upgrade_conn.protocol = Some(Cow::Borrowed("webtransport"));
        upgrade_conn.http_version = Version::Http3;
        upgrade_conn
    }
}
impl Conn {
/// Sends this conn and, if the response is accepted, converts it into a
/// [`WebTransportConnection`].
///
/// The conn must not have a response status yet, must use the `CONNECT` method, and must
/// carry the `webtransport` protocol — the shape produced by [`Client::webtransport`].
///
/// # Errors
///
/// Every failure returns a [`WebTransportConnectError`] that owns the conn and an
/// [`ErrorKind`]:
///
/// - [`ErrorKind::AlreadyExecuted`] if the conn already has a response status.
/// - [`ErrorKind::InvalidConn`] if the method is not `CONNECT` or the protocol is not
///   `webtransport`, or if, after the request, the conn has no WebTransport pool entry or
///   its protocol session is not an HTTP/3 session.
/// - [`ErrorKind::ExtendedConnectUnsupported`] or [`ErrorKind::Http`] if awaiting the conn
///   fails.
/// - [`ErrorKind::Status`] if the response status is anything other than `Status::Ok`.
/// - [`ErrorKind::DispatcherTypeMismatch`] if the dispatcher does not yield a [`Router`].
///
/// # Panics
///
/// Panics if awaiting the conn succeeds but no response status was recorded.
pub async fn into_webtransport(
mut self,
) -> Result<WebTransportConnection, WebTransportConnectError> {
// A status means a response was already received for this conn.
if self.status().is_some() {
return Err(WebTransportConnectError::new(
self,
ErrorKind::AlreadyExecuted,
));
}
// Only a CONNECT request carrying the `webtransport` protocol can be upgraded.
if self.method() != Method::Connect || self.protocol.as_deref() != Some("webtransport") {
return Err(WebTransportConnectError::new(self, ErrorKind::InvalidConn));
}
// Execute the request in place (by mutable reference) so the conn can still be returned
// to the caller inside the error.
if let Err(e) = (&mut self).await {
let kind = match e {
// Surfaced as its own kind rather than being wrapped in `ErrorKind::Http`.
trillium_http::Error::ExtendedConnectUnsupported => {
ErrorKind::ExtendedConnectUnsupported
}
other => other.into(),
};
return Err(WebTransportConnectError::new(self, kind));
}
let status = self.status().expect("response did not include status");
// Any status other than `Status::Ok` is treated as a rejected upgrade.
if status != Status::Ok {
return Err(WebTransportConnectError::new(
self,
ErrorKind::Status(status),
));
}
// Move the pool entry out of the conn; it supplies the dispatcher and the QUIC connection.
let Some(entry) = self.wt_pool_entry.take() else {
return Err(WebTransportConnectError::new(self, ErrorKind::InvalidConn));
};
// NOTE(review): on the error paths below the pool entry has already been taken, so the conn
// handed back inside the error no longer holds it — confirm that is intended.
let Some((h3_connection, session_id)) = self.protocol_session.as_h3() else {
return Err(WebTransportConnectError::new(self, ErrorKind::InvalidConn));
};
// Reuse the entry's dispatcher, creating it on first use.
let dispatcher = entry
.dispatcher
.get_or_init(WebTransportDispatcher::new)
.clone();
let runtime = self.config.runtime();
let max_datagram_buffer = DEFAULT_MAX_DATAGRAM_BUFFER;
// `None` here is reported as a dispatcher that already holds a different handler type.
let Some(router) = dispatcher.get_or_init_with(|| Router::new(max_datagram_buffer)) else {
return Err(WebTransportConnectError::new(
self,
ErrorKind::DispatcherTypeMismatch,
));
};
// NOTE(review): this runs on every upgrade; presumably spawning is a no-op when a routing
// task already exists for this router — confirm against `Router::spawn_routing_task`.
router
.clone()
.spawn_routing_task(entry.quic_conn.clone(), runtime.clone());
// Register this session with the router and obtain its three receivers.
let (bidi_rx, uni_rx, datagram_rx) = router.sessions().lock().await.register(session_id);
// The session gets a child of the HTTP/3 connection's swansong.
let session_swansong = h3_connection.swansong().child();
let path = self.path.clone();
let authority = self.authority.clone();
// Move headers and state out of the conn, leaving defaults behind, instead of cloning them.
let request_headers = std::mem::take(&mut self.request_headers);
let response_headers = std::mem::take(&mut self.response_headers);
let state = std::mem::take(&mut self.state);
Ok(WebTransportConnection::new(
session_id,
bidi_rx,
uni_rx,
datagram_rx,
session_swansong,
request_headers,
response_headers,
state,
path,
authority,
h3_connection,
entry.quic_conn,
runtime,
))
}
}
/// The reason a WebTransport upgrade attempt failed.
///
/// Carried by [`WebTransportConnectError`] alongside the conn that could not be upgraded.
#[derive(thiserror::Error, Debug)]
#[non_exhaustive]
pub enum ErrorKind {
/// An error from `trillium_http` raised while executing the request. Displayed
/// transparently, i.e. with the inner error's own message.
#[error(transparent)]
Http(#[from] trillium_http::Error),
/// The response carried a status other than `Status::Ok`; the received status is included.
#[error("Unexpected response status {0} for WebTransport upgrade")]
Status(Status),
/// The conn already had a response status when the upgrade was attempted.
#[error(
"Conn is not in a valid state for WebTransport upgrade — build via `Client::webtransport` \
and do not await separately"
)]
AlreadyExecuted,
/// The conn cannot be upgraded: it is not a `CONNECT` request with the `webtransport`
/// protocol, or it lacks the WebTransport pool entry or HTTP/3 session the upgrade needs.
#[error("Conn is not configured for a WebTransport upgrade")]
InvalidConn,
/// Executing the request failed with `trillium_http::Error::ExtendedConnectUnsupported`.
#[error("peer does not support WebTransport over HTTP/3")]
ExtendedConnectUnsupported,
/// The connection's dispatcher did not yield a `Router` when one was requested.
#[error("dispatcher already initialized with an incompatible handler type")]
DispatcherTypeMismatch,
}
/// The error returned when a [`Conn`] cannot be upgraded to a WebTransport session.
///
/// It owns the conn that failed, so the caller can still inspect it through `Deref` /
/// `DerefMut` or take it back with `Conn::from`.
#[derive(Debug)]
pub struct WebTransportConnectError {
/// Why the upgrade failed.
pub kind: ErrorKind,
// The conn that failed to upgrade; boxed, which keeps this error type small.
conn: Box<Conn>,
}
impl WebTransportConnectError {
    /// Pairs the conn that failed to upgrade with the reason for the failure.
    fn new(conn: Conn, kind: ErrorKind) -> Self {
        // The conn is moved to the heap so the error value itself stays small.
        let conn = Box::new(conn);
        Self { kind, conn }
    }
}
impl From<WebTransportConnectError> for Conn {
    /// Recovers the conn that failed to upgrade, discarding the error kind.
    fn from(value: WebTransportConnectError) -> Self {
        let boxed_conn = value.conn;
        *boxed_conn
    }
}
impl Deref for WebTransportConnectError {
    type Target = Conn;

    /// Gives shared access to the conn that failed to upgrade.
    fn deref(&self) -> &Self::Target {
        &*self.conn
    }
}
impl DerefMut for WebTransportConnectError {
    /// Gives exclusive access to the conn that failed to upgrade.
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut *self.conn
    }
}
impl Error for WebTransportConnectError {}
impl Display for WebTransportConnectError {
    /// Formats the error using the message of its [`ErrorKind`], forwarding the formatter
    /// unchanged.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        Display::fmt(&self.kind, f)
    }
}