Skip to main content

moq_native/
websocket.rs

1use anyhow::Context;
2use std::collections::HashSet;
3use std::sync::{Arc, LazyLock, Mutex};
4use std::time;
5use url::Url;
6use web_transport_ws::{tokio_tungstenite, tungstenite};
7
// Process-wide memory of servers, keyed by (hostname, port), for which a WebSocket
// connection has already been established. `connect` consults this set to skip the
// artificial delay that otherwise gives QUIC a head start.
static WEBSOCKET_WON: LazyLock<Mutex<HashSet<(String, u16)>>> = LazyLock::new(Mutex::default);
10
11/// WebSocket configuration for the client.
12#[derive(Clone, Debug, clap::Args, serde::Serialize, serde::Deserialize)]
13#[serde(default, deny_unknown_fields)]
14#[non_exhaustive]
15pub struct ClientWebSocket {
16	/// Whether to enable WebSocket support.
17	#[arg(
18		id = "websocket-enabled",
19		long = "websocket-enabled",
20		env = "MOQ_CLIENT_WEBSOCKET_ENABLED",
21		default_value = "true"
22	)]
23	pub enabled: bool,
24
25	/// Delay in milliseconds before attempting WebSocket fallback (default: 200)
26	/// If WebSocket won the previous race for a given server, this will be 0.
27	#[arg(
28		id = "websocket-delay",
29		long = "websocket-delay",
30		env = "MOQ_CLIENT_WEBSOCKET_DELAY",
31		default_value = "200ms",
32		value_parser = humantime::parse_duration,
33	)]
34	#[serde(with = "humantime_serde")]
35	#[serde(skip_serializing_if = "Option::is_none")]
36	pub delay: Option<time::Duration>,
37}
38
39impl Default for ClientWebSocket {
40	fn default() -> Self {
41		Self {
42			enabled: true,
43			delay: Some(time::Duration::from_millis(200)),
44		}
45	}
46}
47
48pub(crate) async fn race_handle(
49	config: &ClientWebSocket,
50	tls: &rustls::ClientConfig,
51	url: Url,
52) -> Option<anyhow::Result<web_transport_ws::Session>> {
53	if !config.enabled {
54		return None;
55	}
56	let res = connect(config, tls, url).await;
57	if let Err(err) = &res {
58		tracing::warn!(%err, "WebSocket connection failed");
59	}
60	Some(res)
61}
62
63pub(crate) async fn connect(
64	config: &ClientWebSocket,
65	tls: &rustls::ClientConfig,
66	mut url: Url,
67) -> anyhow::Result<web_transport_ws::Session> {
68	anyhow::ensure!(config.enabled, "WebSocket support is disabled");
69
70	let host = url.host_str().context("missing hostname")?.to_string();
71	let port = url.port().unwrap_or_else(|| match url.scheme() {
72		"https" | "wss" | "moql" | "moqt" => 443,
73		"http" | "ws" => 80,
74		_ => 443,
75	});
76	let key = (host, port);
77
78	// Apply a small penalty to WebSocket to improve odds for QUIC to connect first,
79	// unless we've already had to fall back to WebSockets for this server.
80	// TODO if let chain
81	match config.delay {
82		Some(delay) if !WEBSOCKET_WON.lock().unwrap().contains(&key) => {
83			tokio::time::sleep(delay).await;
84			tracing::debug!(%url, delay_ms = %delay.as_millis(), "QUIC not yet connected, attempting WebSocket fallback");
85		}
86		_ => {}
87	}
88
89	// Convert URL scheme: http:// -> ws://, https:// -> wss://
90	// Custom protocols (moqt://, moql://) use raw QUIC and don't support WebSocket.
91	let needs_tls = match url.scheme() {
92		"http" => {
93			url.set_scheme("ws").expect("failed to set scheme");
94			false
95		}
96		"https" => {
97			url.set_scheme("wss").expect("failed to set scheme");
98			true
99		}
100		"ws" => false,
101		"wss" => true,
102		_ => anyhow::bail!("unsupported URL scheme for WebSocket: {}", url.scheme()),
103	};
104
105	tracing::debug!(%url, "connecting via WebSocket");
106
107	// Use the existing TLS config (which respects tls-disable-verify) for secure connections
108	let connector = if needs_tls {
109		Some(tokio_tungstenite::Connector::Rustls(Arc::new(tls.clone())))
110	} else {
111		None
112	};
113
114	// Connect using tokio-tungstenite
115	let (ws_stream, _response) = tokio_tungstenite::connect_async_tls_with_config(
116		url.as_str(),
117		Some(tungstenite::protocol::WebSocketConfig {
118			max_message_size: Some(64 << 20), // 64 MB
119			max_frame_size: Some(16 << 20),   // 16 MB
120			accept_unmasked_frames: false,
121			..Default::default()
122		}),
123		false, // disable_nagle
124		connector,
125	)
126	.await
127	.context("failed to connect WebSocket")?;
128
129	// Wrap WebSocket in WebTransport compatibility layer
130	// Similar to what the relay does: web_transport_ws::Session::new(socket, true)
131	let session = web_transport_ws::Session::new(ws_stream, false);
132
133	tracing::warn!(%url, "using WebSocket fallback");
134	WEBSOCKET_WON.lock().unwrap().insert(key);
135
136	Ok(session)
137}