Skip to main content

moq_native/
websocket.rs

1use qmux::tokio_tungstenite;
2use qmux::tokio_tungstenite::tungstenite::{self, client::IntoClientRequest, http};
3use std::collections::HashSet;
4use std::sync::{Arc, LazyLock, Mutex};
5use std::{net, time};
6use url::Url;
7
8/// Errors specific to the WebSocket fallback backend.
9#[derive(Debug, thiserror::Error)]
10#[non_exhaustive]
11pub enum Error {
12	#[error(transparent)]
13	Io(#[from] std::io::Error),
14
15	#[error("WebSocket support is disabled")]
16	Disabled,
17
18	#[error("missing hostname")]
19	MissingHostname,
20
21	#[error("unsupported URL scheme for WebSocket: {0}")]
22	UnsupportedScheme(String),
23
24	#[error("failed to connect WebSocket")]
25	Connect(#[source] qmux::Error),
26
27	#[error("failed to build WebSocket request")]
28	BuildRequest(#[source] tungstenite::Error),
29
30	#[error("failed to build WebSocket protocols header")]
31	ProtocolHeader(#[source] http::header::InvalidHeaderValue),
32
33	#[error("failed to connect WebSocket")]
34	WebSocketConnect(#[source] tungstenite::Error),
35
36	#[error(transparent)]
37	ConnectRejected(#[from] crate::ConnectError),
38
39	#[error("WebSocket accept failed")]
40	Accept(#[source] qmux::Error),
41}
42
43type Result<T> = std::result::Result<T, Error>;
44
/// Servers, keyed by `(hostname, port)`, where WebSocket previously won the race against QUIC.
///
/// Connections to these servers skip the QUIC head-start delay on later attempts.
static WEBSOCKET_WON: LazyLock<Mutex<HashSet<(String, u16)>>> = LazyLock::new(Default::default);
47
48/// WebSocket configuration for the client.
49#[derive(Clone, Debug, clap::Args, serde::Serialize, serde::Deserialize)]
50#[serde(default, deny_unknown_fields)]
51#[group(id = "websocket-client")]
52#[non_exhaustive]
53pub struct Client {
54	/// Whether to enable WebSocket support.
55	#[arg(
56		id = "websocket-enabled",
57		long = "websocket-enabled",
58		env = "MOQ_CLIENT_WEBSOCKET_ENABLED",
59		default_value = "true"
60	)]
61	pub enabled: bool,
62
63	/// Delay in milliseconds before attempting WebSocket fallback (default: 200)
64	/// If WebSocket won the previous race for a given server, this will be 0.
65	#[arg(
66		id = "websocket-delay",
67		long = "websocket-delay",
68		env = "MOQ_CLIENT_WEBSOCKET_DELAY",
69		default_value = "200ms",
70		value_parser = humantime::parse_duration,
71	)]
72	#[serde(with = "humantime_serde")]
73	#[serde(skip_serializing_if = "Option::is_none")]
74	pub delay: Option<time::Duration>,
75}
76
77impl Default for Client {
78	fn default() -> Self {
79		Self {
80			enabled: true,
81			delay: Some(time::Duration::from_millis(200)),
82		}
83	}
84}
85
86pub(crate) async fn race_handle(
87	config: &Client,
88	tls: &rustls::ClientConfig,
89	url: Url,
90	alpns: &[&str],
91) -> Option<Result<qmux::Session>> {
92	if !config.enabled {
93		return None;
94	}
95
96	// Only attempt WebSocket for HTTP-based schemes.
97	// Custom protocols (moqt://, moql://) use raw QUIC and don't support WebSocket.
98	match url.scheme() {
99		"http" | "https" | "ws" | "wss" => {}
100		_ => return None,
101	}
102
103	let res = connect(config, tls, url, alpns).await;
104	if let Err(err) = &res {
105		tracing::warn!(%err, "WebSocket connection failed");
106	}
107	Some(res)
108}
109
110pub(crate) async fn connect(
111	config: &Client,
112	tls: &rustls::ClientConfig,
113	mut url: Url,
114	alpns: &[&str],
115) -> Result<qmux::Session> {
116	if !config.enabled {
117		return Err(Error::Disabled);
118	}
119
120	let host = url.host_str().ok_or(Error::MissingHostname)?.to_string();
121	let port = url.port().unwrap_or_else(|| match url.scheme() {
122		"https" | "wss" | "moql" | "moqt" => 443,
123		"http" | "ws" => 80,
124		_ => 443,
125	});
126	let key = (host, port);
127
128	// Apply a small penalty to WebSocket to improve odds for QUIC to connect first,
129	// unless we've already had to fall back to WebSockets for this server.
130	// TODO if let chain
131	match config.delay {
132		Some(delay) if !WEBSOCKET_WON.lock().unwrap().contains(&key) => {
133			tokio::time::sleep(delay).await;
134			tracing::debug!(%url, delay_ms = %delay.as_millis(), "QUIC not yet connected, attempting WebSocket fallback");
135		}
136		_ => {}
137	}
138
139	// Convert URL scheme: http:// -> ws://, https:// -> wss://
140	// Custom protocols (moqt://, moql://) use raw QUIC and don't support WebSocket.
141	let needs_tls = match url.scheme() {
142		"http" => {
143			url.set_scheme("ws").expect("failed to set scheme");
144			false
145		}
146		"https" => {
147			url.set_scheme("wss").expect("failed to set scheme");
148			true
149		}
150		"ws" => false,
151		"wss" => true,
152		_ => return Err(Error::UnsupportedScheme(url.scheme().to_string())),
153	};
154
155	tracing::debug!(%url, "connecting via WebSocket");
156
157	// Use the existing TLS config (which respects tls-disable-verify) for secure connections.
158	let connector = if needs_tls {
159		tokio_tungstenite::Connector::Rustls(Arc::new(tls.clone()))
160	} else {
161		tokio_tungstenite::Connector::Plain
162	};
163
164	let mut request = url.as_str().into_client_request().map_err(Error::BuildRequest)?;
165	let protocols = websocket_subprotocols(alpns).join(", ");
166	request.headers_mut().insert(
167		http::header::SEC_WEBSOCKET_PROTOCOL,
168		http::HeaderValue::from_str(&protocols).map_err(Error::ProtocolHeader)?,
169	);
170
171	let (socket, response) = if needs_tls {
172		tokio_tungstenite::connect_async_tls_with_config(request, None, false, Some(connector))
173			.await
174			.map_err(map_websocket_error)?
175	} else {
176		tokio_tungstenite::connect_async_with_config(request, None, false)
177			.await
178			.map_err(map_websocket_error)?
179	};
180
181	let alpn = response
182		.headers()
183		.get(http::header::SEC_WEBSOCKET_PROTOCOL)
184		.and_then(|header| header.to_str().ok())
185		.map(str::to_owned);
186	let bare = qmux::ws::Bare::new(socket).with_keep_alive(qmux::KeepAlive::default());
187	let bare = match alpn.as_deref() {
188		Some(alpn) => bare.with_alpn(alpn),
189		None => bare,
190	};
191	let session = bare.connect();
192
193	tracing::warn!(%url, "using WebSocket fallback");
194	WEBSOCKET_WON.lock().unwrap().insert(key);
195
196	Ok(session)
197}
198
199fn websocket_subprotocols(alpns: &[&str]) -> Vec<String> {
200	let mut protocols = Vec::with_capacity(qmux::ALPNS.len() + qmux::PREFIXES.len() * alpns.len());
201	for (&bare, &prefix) in qmux::ALPNS.iter().zip(qmux::PREFIXES) {
202		protocols.push(bare.to_string());
203		protocols.extend(alpns.iter().map(|alpn| format!("{prefix}{alpn}")));
204	}
205	protocols
206}
207
208impl Error {
209	pub(crate) fn connect_error(&self) -> Option<crate::ConnectError> {
210		match self {
211			Self::ConnectRejected(err) => Some(*err),
212			_ => None,
213		}
214	}
215}
216
217fn map_websocket_error(err: tungstenite::Error) -> Error {
218	if let tungstenite::Error::Http(response) = &err
219		&& let Some(err) = crate::ConnectError::from_status_u16(response.status().as_u16())
220	{
221		return err.into();
222	}
223
224	Error::WebSocketConnect(err)
225}
226
227/// Listens for incoming WebSocket connections on a TCP port.
228///
229/// Use with [`crate::Server::with_websocket`] to accept WebSocket connections
230/// alongside QUIC connections on a separate port.
231pub struct Listener {
232	listener: tokio::net::TcpListener,
233	server: qmux::Server,
234}
235
236impl Listener {
237	pub async fn bind(addr: net::SocketAddr) -> Result<Self> {
238		Self::bind_with_alpns(addr, moq_net::ALPNS).await
239	}
240
241	pub async fn bind_with_alpns(addr: net::SocketAddr, alpns: &[&str]) -> Result<Self> {
242		let listener = tokio::net::TcpListener::bind(addr).await?;
243		let server = qmux::Server::new().with_protocols(alpns);
244		Ok(Self { listener, server })
245	}
246
247	pub fn local_addr(&self) -> Result<net::SocketAddr> {
248		Ok(self.listener.local_addr()?)
249	}
250
251	pub async fn accept(&self) -> Option<Result<qmux::Session>> {
252		match self.listener.accept().await {
253			Ok((stream, addr)) => {
254				tracing::debug!(%addr, "accepted WebSocket TCP connection");
255				let server = self.server.clone();
256				Some(server.accept(stream).await.map_err(Error::Accept))
257			}
258			Err(e) => Some(Err(e.into())),
259		}
260	}
261}