Skip to main content

moq_native/
client.rs

1use crate::crypto;
2use anyhow::Context;
3use rustls::RootCertStore;
4use rustls::pki_types::{CertificateDer, ServerName, UnixTime};
5use std::collections::HashSet;
6use std::path::PathBuf;
7use std::sync::{LazyLock, Mutex};
8use std::{fs, io, net, sync::Arc, time};
9use url::Url;
10#[cfg(feature = "iroh")]
11use web_transport_iroh::iroh;
12use web_transport_ws::{tokio_tungstenite, tungstenite};
13
/// Servers, keyed by `(host, port)`, against which WebSocket won the connection race.
///
/// QUIC is not given a head start against these servers on later connection attempts.
static WEBSOCKET_WON: LazyLock<Mutex<HashSet<(String, u16)>>> = LazyLock::new(Default::default);
16
17/// TLS configuration for the client.
18#[derive(Clone, Default, Debug, clap::Args, serde::Serialize, serde::Deserialize)]
19#[serde(default, deny_unknown_fields)]
20#[non_exhaustive]
21pub struct ClientTls {
22	/// Use the TLS root at this path, encoded as PEM.
23	///
24	/// This value can be provided multiple times for multiple roots.
25	/// If this is empty, system roots will be used instead
26	#[serde(skip_serializing_if = "Vec::is_empty")]
27	#[arg(id = "tls-root", long = "tls-root", env = "MOQ_CLIENT_TLS_ROOT")]
28	pub root: Vec<PathBuf>,
29
30	/// Danger: Disable TLS certificate verification.
31	///
32	/// Fine for local development and between relays, but should be used in caution in production.
33	#[serde(skip_serializing_if = "Option::is_none")]
34	#[arg(
35		id = "tls-disable-verify",
36		long = "tls-disable-verify",
37		env = "MOQ_CLIENT_TLS_DISABLE_VERIFY",
38		default_missing_value = "true",
39		num_args = 0..=1,
40		value_parser = clap::value_parser!(bool),
41	)]
42	pub disable_verify: Option<bool>,
43}
44
45/// WebSocket configuration for the client.
46#[derive(Clone, Debug, clap::Args, serde::Serialize, serde::Deserialize)]
47#[serde(default, deny_unknown_fields)]
48#[non_exhaustive]
49pub struct ClientWebSocket {
50	/// Whether to enable WebSocket support.
51	#[arg(
52		id = "websocket-enabled",
53		long = "websocket-enabled",
54		env = "MOQ_CLIENT_WEBSOCKET_ENABLED",
55		default_value = "true"
56	)]
57	pub enabled: bool,
58
59	/// Delay in milliseconds before attempting WebSocket fallback (default: 200)
60	/// If WebSocket won the previous race for a given server, this will be 0.
61	#[arg(
62		id = "websocket-delay",
63		long = "websocket-delay",
64		env = "MOQ_CLIENT_WEBSOCKET_DELAY",
65		default_value = "200ms",
66		value_parser = humantime::parse_duration,
67	)]
68	#[serde(with = "humantime_serde")]
69	#[serde(skip_serializing_if = "Option::is_none")]
70	pub delay: Option<time::Duration>,
71}
72
73impl Default for ClientWebSocket {
74	fn default() -> Self {
75		Self {
76			enabled: true,
77			delay: Some(time::Duration::from_millis(200)),
78		}
79	}
80}
81
82/// Configuration for the MoQ client.
83#[derive(Clone, Debug, clap::Parser, serde::Serialize, serde::Deserialize)]
84#[serde(deny_unknown_fields, default)]
85#[non_exhaustive]
86pub struct ClientConfig {
87	/// Listen for UDP packets on the given address.
88	#[arg(
89		id = "client-bind",
90		long = "client-bind",
91		default_value = "[::]:0",
92		env = "MOQ_CLIENT_BIND"
93	)]
94	pub bind: net::SocketAddr,
95
96	#[command(flatten)]
97	#[serde(default)]
98	pub tls: ClientTls,
99
100	#[command(flatten)]
101	#[serde(default)]
102	pub websocket: ClientWebSocket,
103}
104
105impl ClientConfig {
106	pub fn init(self) -> anyhow::Result<Client> {
107		Client::new(self)
108	}
109}
110
111impl Default for ClientConfig {
112	fn default() -> Self {
113		Self {
114			bind: "[::]:0".parse().unwrap(),
115			tls: ClientTls::default(),
116			websocket: ClientWebSocket::default(),
117		}
118	}
119}
120
121/// Client for establishing MoQ connections over QUIC, WebTransport, or WebSocket.
122///
123/// Create via [`ClientConfig::init`] or [`Client::new`].
124#[derive(Clone)]
125#[non_exhaustive]
126pub struct Client {
127	pub moq: moq_lite::Client,
128	pub quic: quinn::Endpoint,
129	pub tls: rustls::ClientConfig,
130	pub transport: Arc<quinn::TransportConfig>,
131	pub websocket: ClientWebSocket,
132	#[cfg(feature = "iroh")]
133	pub iroh: Option<iroh::Endpoint>,
134}
135
136impl Client {
137	pub fn new(config: ClientConfig) -> anyhow::Result<Self> {
138		let provider = crypto::provider();
139
140		// Create a list of acceptable root certificates.
141		let mut roots = RootCertStore::empty();
142
143		if config.tls.root.is_empty() {
144			let native = rustls_native_certs::load_native_certs();
145
146			// Log any errors that occurred while loading the native root certificates.
147			for err in native.errors {
148				tracing::warn!(%err, "failed to load root cert");
149			}
150
151			// Add the platform's native root certificates.
152			for cert in native.certs {
153				roots.add(cert).context("failed to add root cert")?;
154			}
155		} else {
156			// Add the specified root certificates.
157			for root in &config.tls.root {
158				let root = fs::File::open(root).context("failed to open root cert file")?;
159				let mut root = io::BufReader::new(root);
160
161				let root = rustls_pemfile::certs(&mut root)
162					.next()
163					.context("no roots found")?
164					.context("failed to read root cert")?;
165
166				roots.add(root).context("failed to add root cert")?;
167			}
168		}
169
170		// Create the TLS configuration we'll use as a client (relay -> relay)
171		let mut tls = rustls::ClientConfig::builder_with_provider(provider.clone())
172			.with_protocol_versions(&[&rustls::version::TLS13])?
173			.with_root_certificates(roots)
174			.with_no_client_auth();
175
176		// Allow disabling TLS verification altogether.
177		if config.tls.disable_verify.unwrap_or_default() {
178			tracing::warn!("TLS server certificate verification is disabled; A man-in-the-middle attack is possible.");
179
180			let noop = NoCertificateVerification(provider.clone());
181			tls.dangerous().set_certificate_verifier(Arc::new(noop));
182		}
183
184		let socket = std::net::UdpSocket::bind(config.bind).context("failed to bind UDP socket")?;
185
186		// TODO Validate the BBR implementation before enabling it
187		let mut transport = quinn::TransportConfig::default();
188		transport.max_idle_timeout(Some(time::Duration::from_secs(10).try_into().unwrap()));
189		transport.keep_alive_interval(Some(time::Duration::from_secs(4)));
190		//transport.congestion_controller_factory(Arc::new(quinn::congestion::BbrConfig::default()));
191		transport.mtu_discovery_config(None); // Disable MTU discovery
192		let transport = Arc::new(transport);
193
194		// There's a bit more boilerplate to make a generic endpoint.
195		let runtime = quinn::default_runtime().context("no async runtime")?;
196		let endpoint_config = quinn::EndpointConfig::default();
197
198		// Create the generic QUIC endpoint.
199		let quic =
200			quinn::Endpoint::new(endpoint_config, None, socket, runtime).context("failed to create QUIC endpoint")?;
201
202		Ok(Self {
203			moq: moq_lite::Client::new(),
204			quic,
205			tls,
206			transport,
207			websocket: config.websocket,
208			#[cfg(feature = "iroh")]
209			iroh: None,
210		})
211	}
212
213	#[cfg(feature = "iroh")]
214	pub fn with_iroh(mut self, iroh: Option<iroh::Endpoint>) -> Self {
215		self.iroh = iroh;
216		self
217	}
218
219	pub fn with_publish(mut self, publish: impl Into<Option<moq_lite::OriginConsumer>>) -> Self {
220		self.moq = self.moq.with_publish(publish);
221		self
222	}
223
224	pub fn with_consume(mut self, consume: impl Into<Option<moq_lite::OriginProducer>>) -> Self {
225		self.moq = self.moq.with_consume(consume);
226		self
227	}
228
229	// TODO: Uncomment when observability feature is merged
230	// pub fn with_stats(mut self, stats: impl Into<Option<Arc<dyn moq_lite::Stats>>>) -> Self {
231	// 	self.moq = self.moq.with_stats(stats);
232	// 	self
233	// }
234
235	/// Establish a WebTransport/QUIC connection followed by a MoQ handshake.
236	pub async fn connect(&self, url: Url) -> anyhow::Result<moq_lite::Session> {
237		#[cfg(feature = "iroh")]
238		if crate::iroh::is_iroh_url(&url) {
239			let session = self.connect_iroh(url).await?;
240			let session = self.moq.connect(session).await?;
241			return Ok(session);
242		}
243
244		// Create futures for both possible protocols
245		let quic_url = url.clone();
246		let quic_handle = async {
247			let res = self.connect_quic(quic_url).await;
248			if let Err(err) = &res {
249				tracing::warn!(%err, "QUIC connection failed");
250			}
251			res
252		};
253
254		let ws_handle = async {
255			if !self.websocket.enabled {
256				return None;
257			}
258
259			let res = self.connect_websocket(url).await;
260			if let Err(err) = &res {
261				tracing::warn!(%err, "WebSocket connection failed");
262			}
263			Some(res)
264		};
265
266		// Race the connection futures
267		Ok(tokio::select! {
268			Ok(quic) = quic_handle => self.moq.connect(quic).await?,
269			Some(Ok(ws)) = ws_handle => self.moq.connect(ws).await?,
270			// If both attempts fail, return an error
271			else => anyhow::bail!("failed to connect to server"),
272		})
273	}
274
275	async fn connect_quic(&self, mut url: Url) -> anyhow::Result<web_transport_quinn::Session> {
276		let mut config = self.tls.clone();
277
278		let host = url.host().context("invalid DNS name")?.to_string();
279		let port = url.port().unwrap_or(443);
280
281		// Look up the DNS entry.
282		let ip = tokio::net::lookup_host((host.clone(), port))
283			.await
284			.context("failed DNS lookup")?
285			.next()
286			.context("no DNS entries")?;
287
288		if url.scheme() == "http" {
289			// Perform a HTTP request to fetch the certificate fingerprint.
290			let mut fingerprint = url.clone();
291			fingerprint.set_path("/certificate.sha256");
292			fingerprint.set_query(None);
293			fingerprint.set_fragment(None);
294
295			tracing::warn!(url = %fingerprint, "performing insecure HTTP request for certificate");
296
297			let resp = reqwest::get(fingerprint.as_str())
298				.await
299				.context("failed to fetch fingerprint")?
300				.error_for_status()
301				.context("fingerprint request failed")?;
302
303			let fingerprint = resp.text().await.context("failed to read fingerprint")?;
304			let fingerprint = hex::decode(fingerprint.trim()).context("invalid fingerprint")?;
305
306			let verifier = FingerprintVerifier::new(config.crypto_provider().clone(), fingerprint);
307			config.dangerous().set_certificate_verifier(Arc::new(verifier));
308
309			url.set_scheme("https").expect("failed to set scheme");
310		}
311
312		let alpn = match url.scheme() {
313			"https" => web_transport_quinn::ALPN,
314			"moql" => moq_lite::lite::ALPN,
315			"moqt" => moq_lite::ietf::ALPN,
316			_ => anyhow::bail!("url scheme must be 'http', 'https', or 'moql'"),
317		};
318
319		// TODO support connecting to both ALPNs at the same time
320		config.alpn_protocols = vec![alpn.as_bytes().to_vec()];
321		config.key_log = Arc::new(rustls::KeyLogFile::new());
322
323		let config: quinn::crypto::rustls::QuicClientConfig = config.try_into()?;
324		let mut config = quinn::ClientConfig::new(Arc::new(config));
325		config.transport_config(self.transport.clone());
326
327		tracing::debug!(%url, %ip, %alpn, "connecting");
328
329		let connection = self.quic.connect_with(config, ip, &host)?.await?;
330		tracing::Span::current().record("id", connection.stable_id());
331
332		let session = match alpn {
333			web_transport_quinn::ALPN => web_transport_quinn::Session::connect(connection, url).await?,
334			moq_lite::lite::ALPN | moq_lite::ietf::ALPN => web_transport_quinn::Session::raw(connection, url),
335			_ => unreachable!("ALPN was checked above"),
336		};
337
338		Ok(session)
339	}
340
341	async fn connect_websocket(&self, mut url: Url) -> anyhow::Result<web_transport_ws::Session> {
342		anyhow::ensure!(self.websocket.enabled, "WebSocket support is disabled");
343
344		let host = url.host_str().context("missing hostname")?.to_string();
345		let port = url.port().unwrap_or_else(|| match url.scheme() {
346			"https" | "wss" | "moql" | "moqt" => 443,
347			"http" | "ws" => 80,
348			_ => 443,
349		});
350		let key = (host, port);
351
352		// Apply a small penalty to WebSocket to improve odds for QUIC to connect first,
353		// unless we've already had to fall back to WebSockets for this server.
354		// TODO if let chain
355		match self.websocket.delay {
356			Some(delay) if !WEBSOCKET_WON.lock().unwrap().contains(&key) => {
357				tokio::time::sleep(delay).await;
358				tracing::debug!(%url, delay_ms = %delay.as_millis(), "QUIC not yet connected, attempting WebSocket fallback");
359			}
360			_ => {}
361		}
362
363		// Convert URL scheme: http:// -> ws://, https:// -> wss://
364		let needs_tls = match url.scheme() {
365			"http" => {
366				url.set_scheme("ws").expect("failed to set scheme");
367				false
368			}
369			"https" | "moql" | "moqt" => {
370				url.set_scheme("wss").expect("failed to set scheme");
371				true
372			}
373			"ws" => false,
374			"wss" => true,
375			_ => anyhow::bail!("unsupported URL scheme for WebSocket: {}", url.scheme()),
376		};
377
378		tracing::debug!(%url, "connecting via WebSocket");
379
380		// Use the existing TLS config (which respects tls-disable-verify) for secure connections
381		let connector = if needs_tls {
382			Some(tokio_tungstenite::Connector::Rustls(Arc::new(self.tls.clone())))
383		} else {
384			None
385		};
386
387		// Connect using tokio-tungstenite
388		let (ws_stream, _response) = tokio_tungstenite::connect_async_tls_with_config(
389			url.as_str(),
390			Some(tungstenite::protocol::WebSocketConfig {
391				max_message_size: Some(64 << 20), // 64 MB
392				max_frame_size: Some(16 << 20),   // 16 MB
393				accept_unmasked_frames: false,
394				..Default::default()
395			}),
396			false, // disable_nagle
397			connector,
398		)
399		.await
400		.context("failed to connect WebSocket")?;
401
402		// Wrap WebSocket in WebTransport compatibility layer
403		// Similar to what the relay does: web_transport_ws::Session::new(socket, true)
404		let session = web_transport_ws::Session::new(ws_stream, false);
405
406		tracing::warn!(%url, "using WebSocket fallback");
407		WEBSOCKET_WON.lock().unwrap().insert(key);
408
409		Ok(session)
410	}
411
412	#[cfg(feature = "iroh")]
413	async fn connect_iroh(&self, url: Url) -> anyhow::Result<web_transport_iroh::Session> {
414		let endpoint = self.iroh.as_ref().context("Iroh support is not enabled")?;
415		let alpn = match url.scheme() {
416			"moql+iroh" | "iroh" => moq_lite::lite::ALPN,
417			"moqt+iroh" => moq_lite::ietf::ALPN,
418			"h3+iroh" => web_transport_iroh::ALPN_H3,
419			_ => anyhow::bail!("Invalid URL: unknown scheme"),
420		};
421		let host = url.host().context("Invalid URL: missing host")?.to_string();
422		let endpoint_id: iroh::EndpointId = host.parse().context("Invalid URL: host is not an iroh endpoint id")?;
423		let conn = endpoint.connect(endpoint_id, alpn.as_bytes()).await?;
424		let session = match alpn {
425			web_transport_iroh::ALPN_H3 => {
426				// We need to change the scheme to `https` because currently web_transport_iroh only
427				// accepts that scheme.
428				let url = url_set_scheme(url, "https")?;
429				web_transport_iroh::Session::connect_h3(conn, url).await?
430			}
431			_ => web_transport_iroh::Session::raw(conn),
432		};
433		Ok(session)
434	}
435}
436
437#[derive(Debug)]
438struct NoCertificateVerification(crypto::Provider);
439
440impl rustls::client::danger::ServerCertVerifier for NoCertificateVerification {
441	fn verify_server_cert(
442		&self,
443		_end_entity: &CertificateDer<'_>,
444		_intermediates: &[CertificateDer<'_>],
445		_server_name: &ServerName<'_>,
446		_ocsp: &[u8],
447		_now: UnixTime,
448	) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
449		Ok(rustls::client::danger::ServerCertVerified::assertion())
450	}
451
452	fn verify_tls12_signature(
453		&self,
454		message: &[u8],
455		cert: &CertificateDer<'_>,
456		dss: &rustls::DigitallySignedStruct,
457	) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
458		rustls::crypto::verify_tls12_signature(message, cert, dss, &self.0.signature_verification_algorithms)
459	}
460
461	fn verify_tls13_signature(
462		&self,
463		message: &[u8],
464		cert: &CertificateDer<'_>,
465		dss: &rustls::DigitallySignedStruct,
466	) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
467		rustls::crypto::verify_tls13_signature(message, cert, dss, &self.0.signature_verification_algorithms)
468	}
469
470	fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
471		self.0.signature_verification_algorithms.supported_schemes()
472	}
473}
474
475// Verify the certificate matches a provided fingerprint.
476#[derive(Debug)]
477struct FingerprintVerifier {
478	provider: crypto::Provider,
479	fingerprint: Vec<u8>,
480}
481
482impl FingerprintVerifier {
483	pub fn new(provider: crypto::Provider, fingerprint: Vec<u8>) -> Self {
484		Self { provider, fingerprint }
485	}
486}
487
488impl rustls::client::danger::ServerCertVerifier for FingerprintVerifier {
489	fn verify_server_cert(
490		&self,
491		end_entity: &CertificateDer<'_>,
492		_intermediates: &[CertificateDer<'_>],
493		_server_name: &ServerName<'_>,
494		_ocsp: &[u8],
495		_now: UnixTime,
496	) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
497		let fingerprint = crypto::sha256(&self.provider, end_entity);
498		if fingerprint.as_ref() == self.fingerprint.as_slice() {
499			Ok(rustls::client::danger::ServerCertVerified::assertion())
500		} else {
501			Err(rustls::Error::General("fingerprint mismatch".into()))
502		}
503	}
504
505	fn verify_tls12_signature(
506		&self,
507		message: &[u8],
508		cert: &CertificateDer<'_>,
509		dss: &rustls::DigitallySignedStruct,
510	) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
511		rustls::crypto::verify_tls12_signature(message, cert, dss, &self.provider.signature_verification_algorithms)
512	}
513
514	fn verify_tls13_signature(
515		&self,
516		message: &[u8],
517		cert: &CertificateDer<'_>,
518		dss: &rustls::DigitallySignedStruct,
519	) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
520		rustls::crypto::verify_tls13_signature(message, cert, dss, &self.provider.signature_verification_algorithms)
521	}
522
523	fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
524		self.provider.signature_verification_algorithms.supported_schemes()
525	}
526}
527
/// Returns a new URL with a changed scheme.
///
/// [`Url::set_scheme`] returns an error if the scheme change is not valid according to
/// [the URL specification's section on legal scheme state overrides](https://url.spec.whatwg.org/#scheme-state).
///
/// This function allows all scheme changes, as long as the resulting URL is valid.
#[cfg(feature = "iroh")]
fn url_set_scheme(url: Url, scheme: &str) -> anyhow::Result<Url> {
	// Keep everything after the first ':' (authority, path, query, fragment) and reparse.
	let (_, rest) = url.as_str().split_once(':').context("invalid URL")?;
	let rebuilt: Url = format!("{scheme}:{rest}").parse()?;
	Ok(rebuilt)
}
544
#[cfg(test)]
mod tests {
	use super::*;
	use clap::Parser;

	/// Parse `args` as a command line and return the resulting `tls.disable_verify` value.
	fn cli_disable_verify(args: &[&str]) -> Option<bool> {
		ClientConfig::parse_from(args.iter().copied()).tls.disable_verify
	}

	#[test]
	fn test_toml_disable_verify_survives_update_from() {
		let source = r#"
			tls.disable_verify = true
		"#;

		let mut config: ClientConfig = toml::from_str(source).unwrap();
		assert_eq!(config.tls.disable_verify, Some(true));

		// Simulate: TOML loaded, then CLI args re-applied (no --tls-disable-verify flag).
		config.update_from(["test"]);
		assert_eq!(config.tls.disable_verify, Some(true), "update_from must keep the TOML value");
	}

	#[test]
	fn test_cli_disable_verify_flag() {
		// A bare flag falls back to `default_missing_value`.
		assert_eq!(cli_disable_verify(&["test", "--tls-disable-verify"]), Some(true));
	}

	#[test]
	fn test_cli_disable_verify_explicit_false() {
		assert_eq!(cli_disable_verify(&["test", "--tls-disable-verify", "false"]), Some(false));
	}

	#[test]
	fn test_cli_no_disable_verify() {
		// Without the flag the value stays unset, so a config-file value can take effect.
		assert_eq!(cli_disable_verify(&["test"]), None);
	}
}