Skip to main content

moq_native/
websocket.rs

1use qmux::tokio_tungstenite;
2use std::collections::HashSet;
3use std::sync::{Arc, LazyLock, Mutex};
4use std::{net, time};
5use url::Url;
6
7/// Errors specific to the WebSocket fallback backend.
8#[derive(Debug, thiserror::Error)]
9#[non_exhaustive]
10pub enum Error {
11	#[error(transparent)]
12	Io(#[from] std::io::Error),
13
14	#[error("WebSocket support is disabled")]
15	Disabled,
16
17	#[error("missing hostname")]
18	MissingHostname,
19
20	#[error("unsupported URL scheme for WebSocket: {0}")]
21	UnsupportedScheme(String),
22
23	#[error("failed to connect WebSocket")]
24	Connect(#[source] qmux::Error),
25
26	#[error("WebSocket accept failed")]
27	Accept(#[source] qmux::Error),
28}
29
30type Result<T> = std::result::Result<T, Error>;
31
/// Servers, keyed by `(hostname, port)`, where WebSocket beat QUIC in an earlier race.
/// Connection attempts to these servers skip the QUIC head-start delay.
static WEBSOCKET_WON: LazyLock<Mutex<HashSet<(String, u16)>>> = LazyLock::new(Default::default);
34
35/// WebSocket configuration for the client.
36#[derive(Clone, Debug, clap::Args, serde::Serialize, serde::Deserialize)]
37#[serde(default, deny_unknown_fields)]
38#[group(id = "websocket-client")]
39#[non_exhaustive]
40pub struct Client {
41	/// Whether to enable WebSocket support.
42	#[arg(
43		id = "websocket-enabled",
44		long = "websocket-enabled",
45		env = "MOQ_CLIENT_WEBSOCKET_ENABLED",
46		default_value = "true"
47	)]
48	pub enabled: bool,
49
50	/// Delay in milliseconds before attempting WebSocket fallback (default: 200)
51	/// If WebSocket won the previous race for a given server, this will be 0.
52	#[arg(
53		id = "websocket-delay",
54		long = "websocket-delay",
55		env = "MOQ_CLIENT_WEBSOCKET_DELAY",
56		default_value = "200ms",
57		value_parser = humantime::parse_duration,
58	)]
59	#[serde(with = "humantime_serde")]
60	#[serde(skip_serializing_if = "Option::is_none")]
61	pub delay: Option<time::Duration>,
62}
63
64impl Default for Client {
65	fn default() -> Self {
66		Self {
67			enabled: true,
68			delay: Some(time::Duration::from_millis(200)),
69		}
70	}
71}
72
73pub(crate) async fn race_handle(
74	config: &Client,
75	tls: &rustls::ClientConfig,
76	url: Url,
77	alpns: &[&str],
78) -> Option<Result<qmux::Session>> {
79	if !config.enabled {
80		return None;
81	}
82
83	// Only attempt WebSocket for HTTP-based schemes.
84	// Custom protocols (moqt://, moql://) use raw QUIC and don't support WebSocket.
85	match url.scheme() {
86		"http" | "https" | "ws" | "wss" => {}
87		_ => return None,
88	}
89
90	let res = connect(config, tls, url, alpns).await;
91	if let Err(err) = &res {
92		tracing::warn!(%err, "WebSocket connection failed");
93	}
94	Some(res)
95}
96
97pub(crate) async fn connect(
98	config: &Client,
99	tls: &rustls::ClientConfig,
100	mut url: Url,
101	alpns: &[&str],
102) -> Result<qmux::Session> {
103	if !config.enabled {
104		return Err(Error::Disabled);
105	}
106
107	let host = url.host_str().ok_or(Error::MissingHostname)?.to_string();
108	let port = url.port().unwrap_or_else(|| match url.scheme() {
109		"https" | "wss" | "moql" | "moqt" => 443,
110		"http" | "ws" => 80,
111		_ => 443,
112	});
113	let key = (host, port);
114
115	// Apply a small penalty to WebSocket to improve odds for QUIC to connect first,
116	// unless we've already had to fall back to WebSockets for this server.
117	// TODO if let chain
118	match config.delay {
119		Some(delay) if !WEBSOCKET_WON.lock().unwrap().contains(&key) => {
120			tokio::time::sleep(delay).await;
121			tracing::debug!(%url, delay_ms = %delay.as_millis(), "QUIC not yet connected, attempting WebSocket fallback");
122		}
123		_ => {}
124	}
125
126	// Convert URL scheme: http:// -> ws://, https:// -> wss://
127	// Custom protocols (moqt://, moql://) use raw QUIC and don't support WebSocket.
128	let needs_tls = match url.scheme() {
129		"http" => {
130			url.set_scheme("ws").expect("failed to set scheme");
131			false
132		}
133		"https" => {
134			url.set_scheme("wss").expect("failed to set scheme");
135			true
136		}
137		"ws" => false,
138		"wss" => true,
139		_ => return Err(Error::UnsupportedScheme(url.scheme().to_string())),
140	};
141
142	tracing::debug!(%url, "connecting via WebSocket");
143
144	// Use the existing TLS config (which respects tls-disable-verify) for secure connections
145	let connector = if needs_tls {
146		tokio_tungstenite::Connector::Rustls(Arc::new(tls.clone()))
147	} else {
148		tokio_tungstenite::Connector::Plain
149	};
150
151	let session = qmux::Client::new()
152		.with_protocols(alpns)
153		.with_connector(connector)
154		.with_keep_alive(qmux::KeepAlive::default()) // 5s ping / 30s deadline — parity with QUIC
155		.connect(url.as_str())
156		.await
157		.map_err(Error::Connect)?;
158
159	tracing::warn!(%url, "using WebSocket fallback");
160	WEBSOCKET_WON.lock().unwrap().insert(key);
161
162	Ok(session)
163}
164
165/// Listens for incoming WebSocket connections on a TCP port.
166///
167/// Use with [`crate::Server::with_websocket`] to accept WebSocket connections
168/// alongside QUIC connections on a separate port.
169pub struct Listener {
170	listener: tokio::net::TcpListener,
171	server: qmux::Server,
172}
173
174impl Listener {
175	pub async fn bind(addr: net::SocketAddr) -> Result<Self> {
176		Self::bind_with_alpns(addr, moq_net::ALPNS).await
177	}
178
179	pub async fn bind_with_alpns(addr: net::SocketAddr, alpns: &[&str]) -> Result<Self> {
180		let listener = tokio::net::TcpListener::bind(addr).await?;
181		let server = qmux::Server::new().with_protocols(alpns);
182		Ok(Self { listener, server })
183	}
184
185	pub fn local_addr(&self) -> Result<net::SocketAddr> {
186		Ok(self.listener.local_addr()?)
187	}
188
189	pub async fn accept(&self) -> Option<Result<qmux::Session>> {
190		match self.listener.accept().await {
191			Ok((stream, addr)) => {
192				tracing::debug!(%addr, "accepted WebSocket TCP connection");
193				let server = self.server.clone();
194				Some(server.accept(stream).await.map_err(Error::Accept))
195			}
196			Err(e) => Some(Err(e.into())),
197		}
198	}
199}