Skip to main content

moq_native/
client.rs

1use crate::crypto;
2use anyhow::Context;
3use rustls::RootCertStore;
4use rustls::pki_types::{CertificateDer, ServerName, UnixTime};
5use std::collections::HashSet;
6use std::path::PathBuf;
7use std::sync::{LazyLock, Mutex};
8use std::{fs, io, net, sync::Arc, time};
9use url::Url;
10#[cfg(feature = "iroh")]
11use web_transport_iroh::iroh;
12use web_transport_ws::{tokio_tungstenite, tungstenite};
13
// Servers, keyed by (hostname, port), against which the WebSocket attempt beat QUIC.
// Connections to these servers skip the QUIC head start on later attempts.
static WEBSOCKET_WON: LazyLock<Mutex<HashSet<(String, u16)>>> = LazyLock::new(Default::default);
16
17/// TLS configuration for the client.
18#[derive(Clone, Default, Debug, clap::Args, serde::Serialize, serde::Deserialize)]
19#[serde(default, deny_unknown_fields)]
20#[non_exhaustive]
21pub struct ClientTls {
22	/// Use the TLS root at this path, encoded as PEM.
23	///
24	/// This value can be provided multiple times for multiple roots.
25	/// If this is empty, system roots will be used instead
26	#[serde(skip_serializing_if = "Vec::is_empty")]
27	#[arg(id = "tls-root", long = "tls-root", env = "MOQ_CLIENT_TLS_ROOT")]
28	pub root: Vec<PathBuf>,
29
30	/// Danger: Disable TLS certificate verification.
31	///
32	/// Fine for local development and between relays, but should be used in caution in production.
33	#[serde(skip_serializing_if = "Option::is_none")]
34	#[arg(
35		id = "tls-disable-verify",
36		long = "tls-disable-verify",
37		env = "MOQ_CLIENT_TLS_DISABLE_VERIFY",
38		action = clap::ArgAction::SetTrue
39	)]
40	pub disable_verify: Option<bool>,
41}
42
43/// WebSocket configuration for the client.
44#[derive(Clone, Debug, clap::Args, serde::Serialize, serde::Deserialize)]
45#[serde(default, deny_unknown_fields)]
46#[non_exhaustive]
47pub struct ClientWebSocket {
48	/// Whether to enable WebSocket support.
49	#[arg(
50		id = "websocket-enabled",
51		long = "websocket-enabled",
52		env = "MOQ_CLIENT_WEBSOCKET_ENABLED",
53		default_value = "true"
54	)]
55	pub enabled: bool,
56
57	/// Delay in milliseconds before attempting WebSocket fallback (default: 200)
58	/// If WebSocket won the previous race for a given server, this will be 0.
59	#[arg(
60		id = "websocket-delay",
61		long = "websocket-delay",
62		env = "MOQ_CLIENT_WEBSOCKET_DELAY",
63		default_value = "200ms",
64		value_parser = humantime::parse_duration,
65	)]
66	#[serde(with = "humantime_serde")]
67	#[serde(skip_serializing_if = "Option::is_none")]
68	pub delay: Option<time::Duration>,
69}
70
71impl Default for ClientWebSocket {
72	fn default() -> Self {
73		Self {
74			enabled: true,
75			delay: Some(time::Duration::from_millis(200)),
76		}
77	}
78}
79
80/// Configuration for the MoQ client.
81#[derive(Clone, Debug, clap::Parser, serde::Serialize, serde::Deserialize)]
82#[serde(deny_unknown_fields, default)]
83#[non_exhaustive]
84pub struct ClientConfig {
85	/// Listen for UDP packets on the given address.
86	#[arg(
87		id = "client-bind",
88		long = "client-bind",
89		default_value = "[::]:0",
90		env = "MOQ_CLIENT_BIND"
91	)]
92	pub bind: net::SocketAddr,
93
94	#[command(flatten)]
95	#[serde(default)]
96	pub tls: ClientTls,
97
98	#[command(flatten)]
99	#[serde(default)]
100	pub websocket: ClientWebSocket,
101}
102
103impl ClientConfig {
104	pub fn init(self) -> anyhow::Result<Client> {
105		Client::new(self)
106	}
107}
108
109impl Default for ClientConfig {
110	fn default() -> Self {
111		Self {
112			bind: "[::]:0".parse().unwrap(),
113			tls: ClientTls::default(),
114			websocket: ClientWebSocket::default(),
115		}
116	}
117}
118
119/// Client for establishing MoQ connections over QUIC, WebTransport, or WebSocket.
120///
121/// Create via [`ClientConfig::init`] or [`Client::new`].
122#[derive(Clone)]
123#[non_exhaustive]
124pub struct Client {
125	pub moq: moq_lite::Client,
126	pub quic: quinn::Endpoint,
127	pub tls: rustls::ClientConfig,
128	pub transport: Arc<quinn::TransportConfig>,
129	pub websocket: ClientWebSocket,
130	#[cfg(feature = "iroh")]
131	pub iroh: Option<iroh::Endpoint>,
132}
133
134impl Client {
135	pub fn new(config: ClientConfig) -> anyhow::Result<Self> {
136		let provider = crypto::provider();
137
138		// Create a list of acceptable root certificates.
139		let mut roots = RootCertStore::empty();
140
141		if config.tls.root.is_empty() {
142			let native = rustls_native_certs::load_native_certs();
143
144			// Log any errors that occurred while loading the native root certificates.
145			for err in native.errors {
146				tracing::warn!(%err, "failed to load root cert");
147			}
148
149			// Add the platform's native root certificates.
150			for cert in native.certs {
151				roots.add(cert).context("failed to add root cert")?;
152			}
153		} else {
154			// Add the specified root certificates.
155			for root in &config.tls.root {
156				let root = fs::File::open(root).context("failed to open root cert file")?;
157				let mut root = io::BufReader::new(root);
158
159				let root = rustls_pemfile::certs(&mut root)
160					.next()
161					.context("no roots found")?
162					.context("failed to read root cert")?;
163
164				roots.add(root).context("failed to add root cert")?;
165			}
166		}
167
168		// Create the TLS configuration we'll use as a client (relay -> relay)
169		let mut tls = rustls::ClientConfig::builder_with_provider(provider.clone())
170			.with_protocol_versions(&[&rustls::version::TLS13])?
171			.with_root_certificates(roots)
172			.with_no_client_auth();
173
174		// Allow disabling TLS verification altogether.
175		if config.tls.disable_verify.unwrap_or_default() {
176			tracing::warn!("TLS server certificate verification is disabled; A man-in-the-middle attack is possible.");
177
178			let noop = NoCertificateVerification(provider.clone());
179			tls.dangerous().set_certificate_verifier(Arc::new(noop));
180		}
181
182		let socket = std::net::UdpSocket::bind(config.bind).context("failed to bind UDP socket")?;
183
184		// TODO Validate the BBR implementation before enabling it
185		let mut transport = quinn::TransportConfig::default();
186		transport.max_idle_timeout(Some(time::Duration::from_secs(10).try_into().unwrap()));
187		transport.keep_alive_interval(Some(time::Duration::from_secs(4)));
188		//transport.congestion_controller_factory(Arc::new(quinn::congestion::BbrConfig::default()));
189		transport.mtu_discovery_config(None); // Disable MTU discovery
190		let transport = Arc::new(transport);
191
192		// There's a bit more boilerplate to make a generic endpoint.
193		let runtime = quinn::default_runtime().context("no async runtime")?;
194		let endpoint_config = quinn::EndpointConfig::default();
195
196		// Create the generic QUIC endpoint.
197		let quic =
198			quinn::Endpoint::new(endpoint_config, None, socket, runtime).context("failed to create QUIC endpoint")?;
199
200		Ok(Self {
201			moq: moq_lite::Client::new(),
202			quic,
203			tls,
204			transport,
205			websocket: config.websocket,
206			#[cfg(feature = "iroh")]
207			iroh: None,
208		})
209	}
210
211	#[cfg(feature = "iroh")]
212	pub fn with_iroh(mut self, iroh: Option<iroh::Endpoint>) -> Self {
213		self.iroh = iroh;
214		self
215	}
216
217	pub fn with_publish(mut self, publish: impl Into<Option<moq_lite::OriginConsumer>>) -> Self {
218		self.moq = self.moq.with_publish(publish);
219		self
220	}
221
222	pub fn with_consume(mut self, consume: impl Into<Option<moq_lite::OriginProducer>>) -> Self {
223		self.moq = self.moq.with_consume(consume);
224		self
225	}
226
227	// TODO: Uncomment when observability feature is merged
228	// pub fn with_stats(mut self, stats: impl Into<Option<Arc<dyn moq_lite::Stats>>>) -> Self {
229	// 	self.moq = self.moq.with_stats(stats);
230	// 	self
231	// }
232
233	/// Establish a WebTransport/QUIC connection followed by a MoQ handshake.
234	pub async fn connect(&self, url: Url) -> anyhow::Result<moq_lite::Session> {
235		#[cfg(feature = "iroh")]
236		if crate::iroh::is_iroh_url(&url) {
237			let session = self.connect_iroh(url).await?;
238			let session = self.moq.connect(session).await?;
239			return Ok(session);
240		}
241
242		// Create futures for both possible protocols
243		let quic_url = url.clone();
244		let quic_handle = async {
245			let res = self.connect_quic(quic_url).await;
246			if let Err(err) = &res {
247				tracing::warn!(%err, "QUIC connection failed");
248			}
249			res
250		};
251
252		let ws_handle = async {
253			if !self.websocket.enabled {
254				return None;
255			}
256
257			let res = self.connect_websocket(url).await;
258			if let Err(err) = &res {
259				tracing::warn!(%err, "WebSocket connection failed");
260			}
261			Some(res)
262		};
263
264		// Race the connection futures
265		Ok(tokio::select! {
266			Ok(quic) = quic_handle => self.moq.connect(quic).await?,
267			Some(Ok(ws)) = ws_handle => self.moq.connect(ws).await?,
268			// If both attempts fail, return an error
269			else => anyhow::bail!("failed to connect to server"),
270		})
271	}
272
273	async fn connect_quic(&self, mut url: Url) -> anyhow::Result<web_transport_quinn::Session> {
274		let mut config = self.tls.clone();
275
276		let host = url.host().context("invalid DNS name")?.to_string();
277		let port = url.port().unwrap_or(443);
278
279		// Look up the DNS entry.
280		let ip = tokio::net::lookup_host((host.clone(), port))
281			.await
282			.context("failed DNS lookup")?
283			.next()
284			.context("no DNS entries")?;
285
286		if url.scheme() == "http" {
287			// Perform a HTTP request to fetch the certificate fingerprint.
288			let mut fingerprint = url.clone();
289			fingerprint.set_path("/certificate.sha256");
290			fingerprint.set_query(None);
291			fingerprint.set_fragment(None);
292
293			tracing::warn!(url = %fingerprint, "performing insecure HTTP request for certificate");
294
295			let resp = reqwest::get(fingerprint.as_str())
296				.await
297				.context("failed to fetch fingerprint")?
298				.error_for_status()
299				.context("fingerprint request failed")?;
300
301			let fingerprint = resp.text().await.context("failed to read fingerprint")?;
302			let fingerprint = hex::decode(fingerprint.trim()).context("invalid fingerprint")?;
303
304			let verifier = FingerprintVerifier::new(config.crypto_provider().clone(), fingerprint);
305			config.dangerous().set_certificate_verifier(Arc::new(verifier));
306
307			url.set_scheme("https").expect("failed to set scheme");
308		}
309
310		let alpn = match url.scheme() {
311			"https" => web_transport_quinn::ALPN,
312			"moql" => moq_lite::lite::ALPN,
313			"moqt" => moq_lite::ietf::ALPN,
314			_ => anyhow::bail!("url scheme must be 'http', 'https', or 'moql'"),
315		};
316
317		// TODO support connecting to both ALPNs at the same time
318		config.alpn_protocols = vec![alpn.as_bytes().to_vec()];
319		config.key_log = Arc::new(rustls::KeyLogFile::new());
320
321		let config: quinn::crypto::rustls::QuicClientConfig = config.try_into()?;
322		let mut config = quinn::ClientConfig::new(Arc::new(config));
323		config.transport_config(self.transport.clone());
324
325		tracing::debug!(%url, %ip, %alpn, "connecting");
326
327		let connection = self.quic.connect_with(config, ip, &host)?.await?;
328		tracing::Span::current().record("id", connection.stable_id());
329
330		let session = match alpn {
331			web_transport_quinn::ALPN => web_transport_quinn::Session::connect(connection, url).await?,
332			moq_lite::lite::ALPN | moq_lite::ietf::ALPN => web_transport_quinn::Session::raw(connection, url),
333			_ => unreachable!("ALPN was checked above"),
334		};
335
336		Ok(session)
337	}
338
339	async fn connect_websocket(&self, mut url: Url) -> anyhow::Result<web_transport_ws::Session> {
340		anyhow::ensure!(self.websocket.enabled, "WebSocket support is disabled");
341
342		let host = url.host_str().context("missing hostname")?.to_string();
343		let port = url.port().unwrap_or_else(|| match url.scheme() {
344			"https" | "wss" | "moql" | "moqt" => 443,
345			"http" | "ws" => 80,
346			_ => 443,
347		});
348		let key = (host, port);
349
350		// Apply a small penalty to WebSocket to improve odds for QUIC to connect first,
351		// unless we've already had to fall back to WebSockets for this server.
352		// TODO if let chain
353		match self.websocket.delay {
354			Some(delay) if !WEBSOCKET_WON.lock().unwrap().contains(&key) => {
355				tokio::time::sleep(delay).await;
356				tracing::debug!(%url, delay_ms = %delay.as_millis(), "QUIC not yet connected, attempting WebSocket fallback");
357			}
358			_ => {}
359		}
360
361		// Convert URL scheme: http:// -> ws://, https:// -> wss://
362		let needs_tls = match url.scheme() {
363			"http" => {
364				url.set_scheme("ws").expect("failed to set scheme");
365				false
366			}
367			"https" | "moql" | "moqt" => {
368				url.set_scheme("wss").expect("failed to set scheme");
369				true
370			}
371			"ws" => false,
372			"wss" => true,
373			_ => anyhow::bail!("unsupported URL scheme for WebSocket: {}", url.scheme()),
374		};
375
376		tracing::debug!(%url, "connecting via WebSocket");
377
378		// Use the existing TLS config (which respects tls-disable-verify) for secure connections
379		let connector = if needs_tls {
380			Some(tokio_tungstenite::Connector::Rustls(Arc::new(self.tls.clone())))
381		} else {
382			None
383		};
384
385		// Connect using tokio-tungstenite
386		let (ws_stream, _response) = tokio_tungstenite::connect_async_tls_with_config(
387			url.as_str(),
388			Some(tungstenite::protocol::WebSocketConfig {
389				max_message_size: Some(64 << 20), // 64 MB
390				max_frame_size: Some(16 << 20),   // 16 MB
391				accept_unmasked_frames: false,
392				..Default::default()
393			}),
394			false, // disable_nagle
395			connector,
396		)
397		.await
398		.context("failed to connect WebSocket")?;
399
400		// Wrap WebSocket in WebTransport compatibility layer
401		// Similar to what the relay does: web_transport_ws::Session::new(socket, true)
402		let session = web_transport_ws::Session::new(ws_stream, false);
403
404		tracing::warn!(%url, "using WebSocket fallback");
405		WEBSOCKET_WON.lock().unwrap().insert(key);
406
407		Ok(session)
408	}
409
410	#[cfg(feature = "iroh")]
411	async fn connect_iroh(&self, url: Url) -> anyhow::Result<web_transport_iroh::Session> {
412		let endpoint = self.iroh.as_ref().context("Iroh support is not enabled")?;
413		let alpn = match url.scheme() {
414			"moql+iroh" | "iroh" => moq_lite::lite::ALPN,
415			"moqt+iroh" => moq_lite::ietf::ALPN,
416			"h3+iroh" => web_transport_iroh::ALPN_H3,
417			_ => anyhow::bail!("Invalid URL: unknown scheme"),
418		};
419		let host = url.host().context("Invalid URL: missing host")?.to_string();
420		let endpoint_id: iroh::EndpointId = host.parse().context("Invalid URL: host is not an iroh endpoint id")?;
421		let conn = endpoint.connect(endpoint_id, alpn.as_bytes()).await?;
422		let session = match alpn {
423			web_transport_iroh::ALPN_H3 => {
424				// We need to change the scheme to `https` because currently web_transport_iroh only
425				// accepts that scheme.
426				let url = url_set_scheme(url, "https")?;
427				web_transport_iroh::Session::connect_h3(conn, url).await?
428			}
429			_ => web_transport_iroh::Session::raw(conn),
430		};
431		Ok(session)
432	}
433}
434
435#[derive(Debug)]
436struct NoCertificateVerification(crypto::Provider);
437
438impl rustls::client::danger::ServerCertVerifier for NoCertificateVerification {
439	fn verify_server_cert(
440		&self,
441		_end_entity: &CertificateDer<'_>,
442		_intermediates: &[CertificateDer<'_>],
443		_server_name: &ServerName<'_>,
444		_ocsp: &[u8],
445		_now: UnixTime,
446	) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
447		Ok(rustls::client::danger::ServerCertVerified::assertion())
448	}
449
450	fn verify_tls12_signature(
451		&self,
452		message: &[u8],
453		cert: &CertificateDer<'_>,
454		dss: &rustls::DigitallySignedStruct,
455	) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
456		rustls::crypto::verify_tls12_signature(message, cert, dss, &self.0.signature_verification_algorithms)
457	}
458
459	fn verify_tls13_signature(
460		&self,
461		message: &[u8],
462		cert: &CertificateDer<'_>,
463		dss: &rustls::DigitallySignedStruct,
464	) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
465		rustls::crypto::verify_tls13_signature(message, cert, dss, &self.0.signature_verification_algorithms)
466	}
467
468	fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
469		self.0.signature_verification_algorithms.supported_schemes()
470	}
471}
472
473// Verify the certificate matches a provided fingerprint.
474#[derive(Debug)]
475struct FingerprintVerifier {
476	provider: crypto::Provider,
477	fingerprint: Vec<u8>,
478}
479
480impl FingerprintVerifier {
481	pub fn new(provider: crypto::Provider, fingerprint: Vec<u8>) -> Self {
482		Self { provider, fingerprint }
483	}
484}
485
486impl rustls::client::danger::ServerCertVerifier for FingerprintVerifier {
487	fn verify_server_cert(
488		&self,
489		end_entity: &CertificateDer<'_>,
490		_intermediates: &[CertificateDer<'_>],
491		_server_name: &ServerName<'_>,
492		_ocsp: &[u8],
493		_now: UnixTime,
494	) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
495		let fingerprint = crypto::sha256(&self.provider, end_entity);
496		if fingerprint.as_ref() == self.fingerprint.as_slice() {
497			Ok(rustls::client::danger::ServerCertVerified::assertion())
498		} else {
499			Err(rustls::Error::General("fingerprint mismatch".into()))
500		}
501	}
502
503	fn verify_tls12_signature(
504		&self,
505		message: &[u8],
506		cert: &CertificateDer<'_>,
507		dss: &rustls::DigitallySignedStruct,
508	) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
509		rustls::crypto::verify_tls12_signature(message, cert, dss, &self.provider.signature_verification_algorithms)
510	}
511
512	fn verify_tls13_signature(
513		&self,
514		message: &[u8],
515		cert: &CertificateDer<'_>,
516		dss: &rustls::DigitallySignedStruct,
517	) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
518		rustls::crypto::verify_tls13_signature(message, cert, dss, &self.provider.signature_verification_algorithms)
519	}
520
521	fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
522		self.provider.signature_verification_algorithms.supported_schemes()
523	}
524}
525
/// Returns `url` with its scheme replaced by `scheme`.
///
/// [`Url::set_scheme`] rejects changes that the URL specification does not allow as
/// [legal scheme state overrides](https://url.spec.whatwg.org/#scheme-state). This helper
/// instead re-parses the serialized URL with the new scheme. Any change is therefore
/// accepted, as long as the result is a valid URL.
#[cfg(feature = "iroh")]
fn url_set_scheme(url: Url, scheme: &str) -> anyhow::Result<Url> {
	let serialized = url.to_string();
	let (_, rest) = serialized.split_once(":").context("invalid URL")?;
	let replaced = Url::parse(&format!("{scheme}:{rest}"))?;
	Ok(replaced)
}