use qmux::tokio_tungstenite;
use qmux::tokio_tungstenite::tungstenite::{self, client::IntoClientRequest, http};
use std::collections::HashSet;
use std::sync::{Arc, LazyLock, Mutex};
use std::{net, time};
use url::Url;
/// Errors produced by the WebSocket transport, both by the client fallback
/// (`connect` / `race_handle`) and by the server-side [`Listener`].
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum Error {
    /// An I/O error, e.g. from binding, querying or accepting on the TCP listener.
    #[error(transparent)]
    Io(#[from] std::io::Error),
    /// A WebSocket connection was requested while `Client::enabled` is false.
    #[error("WebSocket support is disabled")]
    Disabled,
    /// The URL has no host component.
    #[error("missing hostname")]
    MissingHostname,
    /// The URL scheme is not one of `http`, `https`, `ws` or `wss`; carries the scheme.
    #[error("unsupported URL scheme for WebSocket: {0}")]
    UnsupportedScheme(String),
    /// Wraps a `qmux::Error` raised while connecting.
    // NOTE(review): this variant is not constructed anywhere in this module and its
    // message is identical to `WebSocketConnect`'s, which makes the two
    // indistinguishable in logs — confirm it is still used elsewhere.
    #[error("failed to connect WebSocket")]
    Connect(#[source] qmux::Error),
    /// The URL could not be turned into a WebSocket handshake request.
    #[error("failed to build WebSocket request")]
    BuildRequest(#[source] tungstenite::Error),
    /// The offered subprotocol list is not a valid `Sec-WebSocket-Protocol` header value
    /// (e.g. an ALPN containing characters not allowed in HTTP headers).
    #[error("failed to build WebSocket protocols header")]
    ProtocolHeader(#[source] http::header::InvalidHeaderValue),
    /// The WebSocket handshake failed for a reason not mapped to `ConnectRejected`.
    #[error("failed to connect WebSocket")]
    WebSocketConnect(#[source] tungstenite::Error),
    /// The server answered the handshake with an HTTP status that
    /// `crate::ConnectError::from_status_u16` recognises.
    #[error(transparent)]
    ConnectRejected(#[from] crate::ConnectError),
    /// The server-side qmux accept (WebSocket upgrade) failed on an incoming connection.
    #[error("WebSocket accept failed")]
    Accept(#[source] qmux::Error),
}
/// Module-local result type whose error is this module's [`Error`].
type Result<T> = std::result::Result<T, Error>;

/// `(host, port)` pairs for which a WebSocket connection has already succeeded.
///
/// `connect` inserts the key after every successful WebSocket connection and skips
/// the fallback delay for keys already present. Entries are never removed.
static WEBSOCKET_WON: LazyLock<Mutex<HashSet<(String, u16)>>> = LazyLock::new(Mutex::default);
// Client-side WebSocket fallback settings, configurable through clap (CLI flags /
// environment variables) and serde (config files).
//
// Plain `//` comments are used on this type on purpose: clap-derive turns `///` doc
// comments into `--help` text.
//
// `#[serde(default)]` fills missing fields from `impl Default for Client`, so those
// defaults must stay in sync with the clap `default_value`s below.
#[derive(Clone, Debug, clap::Args, serde::Serialize, serde::Deserialize)]
#[serde(default, deny_unknown_fields)]
#[group(id = "websocket-client")]
#[non_exhaustive]
pub struct Client {
    // Whether the WebSocket fallback is attempted at all. When false, `connect`
    // returns `Error::Disabled` and `race_handle` returns `None`.
    //
    // NOTE(review): for a `bool` field clap-derive defaults to `ArgAction::SetTrue`, so
    // with `default_value = "true"` the `--websocket-enabled` flag cannot turn the
    // feature off from the command line. Presumably disabling goes through the env
    // var or config file — confirm this is intended.
    #[arg(
        id = "websocket-enabled",
        long = "websocket-enabled",
        env = "MOQ_CLIENT_WEBSOCKET_ENABLED",
        default_value = "true"
    )]
    pub enabled: bool,
    // How long `connect` sleeps before attempting WebSocket, giving a concurrently
    // racing QUIC connection a head start. The sleep is skipped for (host, port)
    // pairs where WebSocket has already won, and entirely when this is `None`.
    // Parsed/serialized as a humantime string such as "200ms".
    #[arg(
        id = "websocket-delay",
        long = "websocket-delay",
        env = "MOQ_CLIENT_WEBSOCKET_DELAY",
        default_value = "200ms",
        value_parser = humantime::parse_duration,
    )]
    #[serde(with = "humantime_serde")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub delay: Option<time::Duration>,
}
impl Default for Client {
fn default() -> Self {
Self {
enabled: true,
delay: Some(time::Duration::from_millis(200)),
}
}
}
/// Attempt a WebSocket connection as one contender in a transport race
/// (presumably against QUIC — the caller is not visible here).
///
/// Returns `None` without doing anything when WebSocket is disabled or the URL
/// scheme is not `http`, `https`, `ws` or `wss`. Otherwise returns the outcome of
/// [`connect`], logging a warning if it failed.
pub(crate) async fn race_handle(
    config: &Client,
    tls: &rustls::ClientConfig,
    url: Url,
    alpns: &[&str],
) -> Option<Result<qmux::Session>> {
    let websocket_capable = matches!(url.scheme(), "http" | "https" | "ws" | "wss");
    if !(config.enabled && websocket_capable) {
        return None;
    }

    let outcome = connect(config, tls, url, alpns)
        .await
        .inspect_err(|err| tracing::warn!(%err, "WebSocket connection failed"));
    Some(outcome)
}
pub(crate) async fn connect(
config: &Client,
tls: &rustls::ClientConfig,
mut url: Url,
alpns: &[&str],
) -> Result<qmux::Session> {
if !config.enabled {
return Err(Error::Disabled);
}
let host = url.host_str().ok_or(Error::MissingHostname)?.to_string();
let port = url.port().unwrap_or_else(|| match url.scheme() {
"https" | "wss" | "moql" | "moqt" => 443,
"http" | "ws" => 80,
_ => 443,
});
let key = (host, port);
match config.delay {
Some(delay) if !WEBSOCKET_WON.lock().unwrap().contains(&key) => {
tokio::time::sleep(delay).await;
tracing::debug!(%url, delay_ms = %delay.as_millis(), "QUIC not yet connected, attempting WebSocket fallback");
}
_ => {}
}
let needs_tls = match url.scheme() {
"http" => {
url.set_scheme("ws").expect("failed to set scheme");
false
}
"https" => {
url.set_scheme("wss").expect("failed to set scheme");
true
}
"ws" => false,
"wss" => true,
_ => return Err(Error::UnsupportedScheme(url.scheme().to_string())),
};
tracing::debug!(%url, "connecting via WebSocket");
let connector = if needs_tls {
tokio_tungstenite::Connector::Rustls(Arc::new(tls.clone()))
} else {
tokio_tungstenite::Connector::Plain
};
let mut request = url.as_str().into_client_request().map_err(Error::BuildRequest)?;
let protocols = websocket_subprotocols(alpns).join(", ");
request.headers_mut().insert(
http::header::SEC_WEBSOCKET_PROTOCOL,
http::HeaderValue::from_str(&protocols).map_err(Error::ProtocolHeader)?,
);
let (socket, response) = if needs_tls {
tokio_tungstenite::connect_async_tls_with_config(request, None, false, Some(connector))
.await
.map_err(map_websocket_error)?
} else {
tokio_tungstenite::connect_async_with_config(request, None, false)
.await
.map_err(map_websocket_error)?
};
let alpn = response
.headers()
.get(http::header::SEC_WEBSOCKET_PROTOCOL)
.and_then(|header| header.to_str().ok())
.map(str::to_owned);
let bare = qmux::ws::Bare::new(socket).with_keep_alive(qmux::KeepAlive::default());
let bare = match alpn.as_deref() {
Some(alpn) => bare.with_alpn(alpn),
None => bare,
};
let session = bare.connect();
tracing::warn!(%url, "using WebSocket fallback");
WEBSOCKET_WON.lock().unwrap().insert(key);
Ok(session)
}
/// Build the list of WebSocket subprotocols to offer, in preference order.
///
/// Each `qmux::ALPNS` entry is paired with the `qmux::PREFIXES` entry at the same
/// position. For each pair, the bare qmux protocol comes first, followed by every
/// entry of `alpns` with that prefix prepended.
fn websocket_subprotocols(alpns: &[&str]) -> Vec<String> {
    let expected_len = qmux::ALPNS.len() + qmux::PREFIXES.len() * alpns.len();
    let mut offered = Vec::with_capacity(expected_len);
    offered.extend(
        qmux::ALPNS
            .iter()
            .zip(qmux::PREFIXES)
            .flat_map(|(bare, prefix)| {
                std::iter::once(bare.to_string())
                    .chain(alpns.iter().map(move |alpn| format!("{prefix}{alpn}")))
            }),
    );
    offered
}
impl Error {
    /// The `crate::ConnectError` carried by [`Error::ConnectRejected`], or `None`
    /// for every other variant.
    pub(crate) fn connect_error(&self) -> Option<crate::ConnectError> {
        let Self::ConnectRejected(rejected) = self else {
            return None;
        };
        Some(*rejected)
    }
}
fn map_websocket_error(err: tungstenite::Error) -> Error {
if let tungstenite::Error::Http(response) = &err
&& let Some(err) = crate::ConnectError::from_status_u16(response.status().as_u16())
{
return err.into();
}
Error::WebSocketConnect(err)
}
pub struct Listener {
listener: tokio::net::TcpListener,
server: qmux::Server,
}
impl Listener {
pub async fn bind(addr: net::SocketAddr) -> Result<Self> {
Self::bind_with_alpns(addr, moq_net::ALPNS).await
}
pub async fn bind_with_alpns(addr: net::SocketAddr, alpns: &[&str]) -> Result<Self> {
let listener = tokio::net::TcpListener::bind(addr).await?;
let server = qmux::Server::new().with_protocols(alpns);
Ok(Self { listener, server })
}
pub fn local_addr(&self) -> Result<net::SocketAddr> {
Ok(self.listener.local_addr()?)
}
pub async fn accept(&self) -> Option<Result<qmux::Session>> {
match self.listener.accept().await {
Ok((stream, addr)) => {
tracing::debug!(%addr, "accepted WebSocket TCP connection");
let server = self.server.clone();
Some(server.accept(stream).await.map_err(Error::Accept))
}
Err(e) => Some(Err(e.into())),
}
}
}