moq_native/
client.rs

1use crate::crypto;
2use anyhow::Context;
3use rustls::pki_types::{CertificateDer, ServerName, UnixTime};
4use rustls::RootCertStore;
5use std::collections::HashSet;
6use std::path::PathBuf;
7use std::sync::{LazyLock, Mutex};
8use std::{fs, io, net, sync::Arc, time};
9use url::Url;
10#[cfg(feature = "iroh")]
11use web_transport_iroh::iroh;
12use web_transport_ws::{tokio_tungstenite, tungstenite};
13
// Track servers (hostname, port) where WebSocket won the race, so we won't give QUIC a headstart next time.
//
// `Client::connect_websocket` consults this set before sleeping for the configured delay and inserts
// into it once a WebSocket connection has been established. The set is process-wide (shared by every
// `Client`), and nothing visible in this file ever removes an entry.
static WEBSOCKET_WON: LazyLock<Mutex<HashSet<(String, u16)>>> = LazyLock::new(|| Mutex::new(HashSet::new()));
16
/// TLS configuration for the client.
// The same fields can be populated from CLI flags / environment variables (clap) or from a
// deserialized config (serde); unknown serde keys are rejected by `deny_unknown_fields`.
#[derive(Clone, Default, Debug, clap::Args, serde::Serialize, serde::Deserialize)]
#[serde(default, deny_unknown_fields)]
pub struct ClientTls {
	/// Use the TLS root at this path, encoded as PEM.
	///
	/// This value can be provided multiple times for multiple roots.
	/// If this is empty, system roots will be used instead
	#[serde(skip_serializing_if = "Vec::is_empty")]
	#[arg(id = "tls-root", long = "tls-root", env = "MOQ_CLIENT_TLS_ROOT")]
	pub root: Vec<PathBuf>,

	/// Danger: Disable TLS certificate verification.
	///
	/// Fine for local development and between relays, but should be used with caution in production.
	// `Client::new` treats `None` the same as `false` (it calls `unwrap_or_default()`).
	#[serde(skip_serializing_if = "Option::is_none")]
	#[arg(
		id = "tls-disable-verify",
		long = "tls-disable-verify",
		env = "MOQ_CLIENT_TLS_DISABLE_VERIFY",
		action = clap::ArgAction::SetTrue
	)]
	pub disable_verify: Option<bool>,
}
41
42/// WebSocket configuration for the client.
43#[derive(Clone, Default, Debug, clap::Args, serde::Serialize, serde::Deserialize)]
44#[serde(default, deny_unknown_fields)]
45pub struct ClientWebSocket {
46	/// Delay in milliseconds before attempting WebSocket fallback (default: 200)
47	/// If WebSocket won the previous race for a given server, this will be 0.
48	#[arg(
49		id = "websocket-delay",
50		long = "websocket-delay",
51		env = "MOQ_CLIENT_WEBSOCKET_DELAY",
52		default_value = "200ms",
53		value_parser = humantime::parse_duration,
54	)]
55	#[serde(with = "humantime_serde")]
56	#[serde(skip_serializing_if = "Option::is_none")]
57	pub delay: Option<time::Duration>,
58}
59
/// Configuration for the MoQ client.
// Can be parsed from the command line / environment (clap) or deserialized (serde).
// With the container-level `#[serde(default)]`, fields missing from a deserialized config
// fall back to the values of this type's `Default` impl.
#[derive(Clone, Debug, clap::Parser, serde::Serialize, serde::Deserialize)]
#[serde(deny_unknown_fields, default)]
pub struct ClientConfig {
	/// Listen for UDP packets on the given address.
	// `Client::new` binds the UDP socket backing the QUIC endpoint to this address.
	#[arg(
		id = "client-bind",
		long = "client-bind",
		default_value = "[::]:0",
		env = "MOQ_CLIENT_BIND"
	)]
	pub bind: net::SocketAddr,

	// TLS roots and verification settings; flattened so its flags appear at the top level.
	#[command(flatten)]
	#[serde(default)]
	pub tls: ClientTls,

	// WebSocket fallback settings; flattened so its flags appear at the top level.
	#[command(flatten)]
	#[serde(default)]
	pub websocket: ClientWebSocket,
}
81
impl ClientConfig {
	/// Consume this configuration and build a [`Client`] from it.
	///
	/// Equivalent to calling [`Client::new`] with `self`; any error it reports is returned unchanged.
	pub fn init(self) -> anyhow::Result<Client> {
		Client::new(self)
	}
}
87
88impl Default for ClientConfig {
89	fn default() -> Self {
90		Self {
91			bind: "[::]:0".parse().unwrap(),
92			tls: ClientTls::default(),
93			websocket: ClientWebSocket::default(),
94		}
95	}
96}
97
/// Client for establishing MoQ connections over QUIC, WebTransport, or WebSocket.
///
/// Create via [`ClientConfig::init`] or [`Client::new`].
#[derive(Clone)]
pub struct Client {
	/// QUIC endpoint, created in [`Client::new`] on the UDP socket bound to the configured address.
	pub quic: quinn::Endpoint,
	/// Base TLS configuration; cloned and adjusted (ALPN, verifier, key log) for each QUIC connection.
	pub tls: rustls::ClientConfig,
	/// QUIC transport settings applied to every outgoing QUIC connection.
	pub transport: Arc<quinn::TransportConfig>,
	/// How long the WebSocket fallback waits before connecting; `None` means no wait.
	pub websocket_delay: Option<time::Duration>,
	/// Optional iroh endpoint; `None` until set with `with_iroh`.
	#[cfg(feature = "iroh")]
	pub iroh: Option<iroh::Endpoint>,
}
110
111impl Client {
112	pub fn new(config: ClientConfig) -> anyhow::Result<Self> {
113		let provider = crypto::provider();
114
115		// Create a list of acceptable root certificates.
116		let mut roots = RootCertStore::empty();
117
118		if config.tls.root.is_empty() {
119			let native = rustls_native_certs::load_native_certs();
120
121			// Log any errors that occurred while loading the native root certificates.
122			for err in native.errors {
123				tracing::warn!(%err, "failed to load root cert");
124			}
125
126			// Add the platform's native root certificates.
127			for cert in native.certs {
128				roots.add(cert).context("failed to add root cert")?;
129			}
130		} else {
131			// Add the specified root certificates.
132			for root in &config.tls.root {
133				let root = fs::File::open(root).context("failed to open root cert file")?;
134				let mut root = io::BufReader::new(root);
135
136				let root = rustls_pemfile::certs(&mut root)
137					.next()
138					.context("no roots found")?
139					.context("failed to read root cert")?;
140
141				roots.add(root).context("failed to add root cert")?;
142			}
143		}
144
145		// Create the TLS configuration we'll use as a client (relay -> relay)
146		let mut tls = rustls::ClientConfig::builder_with_provider(provider.clone())
147			.with_protocol_versions(&[&rustls::version::TLS13])?
148			.with_root_certificates(roots)
149			.with_no_client_auth();
150
151		// Allow disabling TLS verification altogether.
152		if config.tls.disable_verify.unwrap_or_default() {
153			tracing::warn!("TLS server certificate verification is disabled; A man-in-the-middle attack is possible.");
154
155			let noop = NoCertificateVerification(provider.clone());
156			tls.dangerous().set_certificate_verifier(Arc::new(noop));
157		}
158
159		let socket = std::net::UdpSocket::bind(config.bind).context("failed to bind UDP socket")?;
160
161		// TODO Validate the BBR implementation before enabling it
162		let mut transport = quinn::TransportConfig::default();
163		transport.max_idle_timeout(Some(time::Duration::from_secs(10).try_into().unwrap()));
164		transport.keep_alive_interval(Some(time::Duration::from_secs(4)));
165		//transport.congestion_controller_factory(Arc::new(quinn::congestion::BbrConfig::default()));
166		transport.mtu_discovery_config(None); // Disable MTU discovery
167		let transport = Arc::new(transport);
168
169		// There's a bit more boilerplate to make a generic endpoint.
170		let runtime = quinn::default_runtime().context("no async runtime")?;
171		let endpoint_config = quinn::EndpointConfig::default();
172
173		// Create the generic QUIC endpoint.
174		let quic =
175			quinn::Endpoint::new(endpoint_config, None, socket, runtime).context("failed to create QUIC endpoint")?;
176
177		Ok(Self {
178			quic,
179			tls,
180			transport,
181			websocket_delay: config.websocket.delay,
182			#[cfg(feature = "iroh")]
183			iroh: None,
184		})
185	}
186
	/// Set (or clear, with `None`) the iroh endpoint used for iroh URLs.
	///
	/// Returns `&mut Self` so calls can be chained.
	#[cfg(feature = "iroh")]
	pub fn with_iroh(&mut self, iroh: Option<iroh::Endpoint>) -> &mut Self {
		self.iroh = iroh;
		self
	}
192
193	/// Establish a WebTransport/QUIC connection followed by a MoQ handshake.
194	pub async fn connect(
195		&self,
196		url: Url,
197		publish: impl Into<Option<moq_lite::OriginConsumer>>,
198		subscribe: impl Into<Option<moq_lite::OriginProducer>>,
199	) -> anyhow::Result<moq_lite::Session> {
200		#[cfg(feature = "iroh")]
201		if crate::iroh::is_iroh_url(&url) {
202			let session = self.connect_iroh(url).await?;
203			let session = moq_lite::Session::connect(session, publish, subscribe).await?;
204			return Ok(session);
205		}
206
207		let session = self.connect_quic(url).await?;
208		let session = moq_lite::Session::connect(session, publish, subscribe).await?;
209		Ok(session)
210	}
211
212	/// Establish a WebTransport/QUIC connection or a WebSocket connection, whichever is available first.
213	///
214	/// Establishes a MoQ handshake on the winning transport.
215	pub async fn connect_with_fallback(
216		&self,
217		url: Url,
218		publish: impl Into<Option<moq_lite::OriginConsumer>>,
219		subscribe: impl Into<Option<moq_lite::OriginProducer>>,
220	) -> anyhow::Result<moq_lite::Session> {
221		#[cfg(feature = "iroh")]
222		if crate::iroh::is_iroh_url(&url) {
223			let session = self.connect_iroh(url).await?;
224			let session = moq_lite::Session::connect(session, publish, subscribe).await?;
225			return Ok(session);
226		}
227
228		// Create futures for both possible protocols
229		let quic_url = url.clone();
230		let quic_handle = async {
231			match self.connect_quic(quic_url).await {
232				Ok(session) => Some(session),
233				Err(err) => {
234					tracing::warn!(%err, "QUIC connection failed");
235					None
236				}
237			}
238		};
239
240		let ws_handle = async {
241			match self.connect_websocket(url).await {
242				Ok(session) => Some(session),
243				Err(err) => {
244					tracing::warn!(%err, "WebSocket connection failed");
245					None
246				}
247			}
248		};
249
250		// Race the connection futures
251		Ok(tokio::select! {
252			Some(quic) = quic_handle => moq_lite::Session::connect(quic, publish, subscribe).await?,
253			Some(ws) = ws_handle => moq_lite::Session::connect(ws, publish, subscribe).await?,
254			// If both attempts fail, return an error
255			else => anyhow::bail!("failed to connect to server"),
256		})
257	}
258
259	async fn connect_quic(&self, mut url: Url) -> anyhow::Result<web_transport_quinn::Session> {
260		let mut config = self.tls.clone();
261
262		let host = url.host().context("invalid DNS name")?.to_string();
263		let port = url.port().unwrap_or(443);
264
265		// Look up the DNS entry.
266		let ip = tokio::net::lookup_host((host.clone(), port))
267			.await
268			.context("failed DNS lookup")?
269			.next()
270			.context("no DNS entries")?;
271
272		if url.scheme() == "http" {
273			// Perform a HTTP request to fetch the certificate fingerprint.
274			let mut fingerprint = url.clone();
275			fingerprint.set_path("/certificate.sha256");
276			fingerprint.set_query(None);
277			fingerprint.set_fragment(None);
278
279			tracing::warn!(url = %fingerprint, "performing insecure HTTP request for certificate");
280
281			let resp = reqwest::get(fingerprint.as_str())
282				.await
283				.context("failed to fetch fingerprint")?
284				.error_for_status()
285				.context("fingerprint request failed")?;
286
287			let fingerprint = resp.text().await.context("failed to read fingerprint")?;
288			let fingerprint = hex::decode(fingerprint.trim()).context("invalid fingerprint")?;
289
290			let verifier = FingerprintVerifier::new(config.crypto_provider().clone(), fingerprint);
291			config.dangerous().set_certificate_verifier(Arc::new(verifier));
292
293			url.set_scheme("https").expect("failed to set scheme");
294		}
295
296		let alpn = match url.scheme() {
297			"https" => web_transport_quinn::ALPN,
298			"moql" => moq_lite::lite::ALPN,
299			"moqt" => moq_lite::ietf::ALPN,
300			_ => anyhow::bail!("url scheme must be 'http', 'https', or 'moql'"),
301		};
302
303		// TODO support connecting to both ALPNs at the same time
304		config.alpn_protocols = vec![alpn.as_bytes().to_vec()];
305		config.key_log = Arc::new(rustls::KeyLogFile::new());
306
307		let config: quinn::crypto::rustls::QuicClientConfig = config.try_into()?;
308		let mut config = quinn::ClientConfig::new(Arc::new(config));
309		config.transport_config(self.transport.clone());
310
311		tracing::debug!(%url, %ip, %alpn, "connecting");
312
313		let connection = self.quic.connect_with(config, ip, &host)?.await?;
314		tracing::Span::current().record("id", connection.stable_id());
315
316		let session = match alpn {
317			web_transport_quinn::ALPN => web_transport_quinn::Session::connect(connection, url).await?,
318			moq_lite::lite::ALPN | moq_lite::ietf::ALPN => web_transport_quinn::Session::raw(connection, url),
319			_ => unreachable!("ALPN was checked above"),
320		};
321
322		Ok(session)
323	}
324
325	async fn connect_websocket(&self, mut url: Url) -> anyhow::Result<web_transport_ws::Session> {
326		let host = url.host_str().context("missing hostname")?.to_string();
327		let port = url.port().unwrap_or_else(|| match url.scheme() {
328			"https" | "wss" | "moql" | "moqt" => 443,
329			"http" | "ws" => 80,
330			_ => 443,
331		});
332		let key = (host, port);
333
334		// Apply a small penalty to WebSocket to improve odds for QUIC to connect first,
335		// unless we've already had to fall back to WebSockets for this server.
336		// TODO if let chain
337		match self.websocket_delay {
338			Some(delay) if !WEBSOCKET_WON.lock().unwrap().contains(&key) => {
339				tokio::time::sleep(delay).await;
340				tracing::debug!(%url, delay_ms = %delay.as_millis(), "QUIC not yet connected, attempting WebSocket fallback");
341			}
342			_ => {}
343		}
344
345		// Convert URL scheme: http:// -> ws://, https:// -> wss://
346		match url.scheme() {
347			"http" => {
348				url.set_scheme("ws").expect("failed to set scheme");
349			}
350			"https" | "moql" | "moqt" => {
351				url.set_scheme("wss").expect("failed to set scheme");
352			}
353			"ws" | "wss" => {}
354			_ => anyhow::bail!("unsupported URL scheme for WebSocket: {}", url.scheme()),
355		};
356
357		tracing::debug!(%url, "connecting via WebSocket");
358
359		// Connect using tokio-tungstenite
360		let (ws_stream, _response) = tokio_tungstenite::connect_async_with_config(
361			url.as_str(),
362			Some(tungstenite::protocol::WebSocketConfig {
363				max_message_size: Some(64 << 20), // 64 MB
364				max_frame_size: Some(16 << 20),   // 16 MB
365				accept_unmasked_frames: false,
366				..Default::default()
367			}),
368			false, // disable_nagle
369		)
370		.await
371		.context("failed to connect WebSocket")?;
372
373		// Wrap WebSocket in WebTransport compatibility layer
374		// Similar to what the relay does: web_transport_ws::Session::new(socket, true)
375		let session = web_transport_ws::Session::new(ws_stream, false);
376
377		tracing::warn!(%url, "using WebSocket fallback");
378		WEBSOCKET_WON.lock().unwrap().insert(key);
379
380		Ok(session)
381	}
382
383	#[cfg(feature = "iroh")]
384	async fn connect_iroh(&self, url: Url) -> anyhow::Result<web_transport_iroh::Session> {
385		let endpoint = self.iroh.as_ref().context("Iroh support is not enabled")?;
386		let alpn = match url.scheme() {
387			"moql+iroh" | "iroh" => moq_lite::lite::ALPN,
388			"moqt+iroh" => moq_lite::ietf::ALPN,
389			"h3+iroh" => web_transport_iroh::ALPN_H3,
390			_ => anyhow::bail!("Invalid URL: unknown scheme"),
391		};
392		let host = url.host().context("Invalid URL: missing host")?.to_string();
393		let endpoint_id: iroh::EndpointId = host.parse().context("Invalid URL: host is not an iroh endpoint id")?;
394		let conn = endpoint.connect(endpoint_id, alpn.as_bytes()).await?;
395		let session = match alpn {
396			web_transport_iroh::ALPN_H3 => {
397				// We need to change the scheme to `https` because currently web_transport_iroh only
398				// accepts that scheme.
399				let url = url_set_scheme(url, "https")?;
400				web_transport_iroh::Session::connect_h3(conn, url).await?
401			}
402			_ => web_transport_iroh::Session::raw(conn),
403		};
404		Ok(session)
405	}
406}
407
/// Certificate verifier that accepts any server certificate.
///
/// Installed by `Client::new` when `disable_verify` is set. The certificate chain, server name,
/// OCSP data and validity time are all ignored; only the handshake signature checks below are
/// still delegated to rustls. The wrapped provider supplies the signature verification algorithms.
#[derive(Debug)]
struct NoCertificateVerification(crypto::Provider);

impl rustls::client::danger::ServerCertVerifier for NoCertificateVerification {
	// Unconditionally asserts that the certificate is valid; every argument is ignored.
	fn verify_server_cert(
		&self,
		_end_entity: &CertificateDer<'_>,
		_intermediates: &[CertificateDer<'_>],
		_server_name: &ServerName<'_>,
		_ocsp: &[u8],
		_now: UnixTime,
	) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
		Ok(rustls::client::danger::ServerCertVerified::assertion())
	}

	// Delegates the TLS 1.2 handshake signature check to rustls using the provider's algorithms.
	fn verify_tls12_signature(
		&self,
		message: &[u8],
		cert: &CertificateDer<'_>,
		dss: &rustls::DigitallySignedStruct,
	) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
		rustls::crypto::verify_tls12_signature(message, cert, dss, &self.0.signature_verification_algorithms)
	}

	// Delegates the TLS 1.3 handshake signature check to rustls using the provider's algorithms.
	fn verify_tls13_signature(
		&self,
		message: &[u8],
		cert: &CertificateDer<'_>,
		dss: &rustls::DigitallySignedStruct,
	) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
		rustls::crypto::verify_tls13_signature(message, cert, dss, &self.0.signature_verification_algorithms)
	}

	// Advertises the signature schemes supported by the provider's algorithms.
	fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
		self.0.signature_verification_algorithms.supported_schemes()
	}
}
445
446// Verify the certificate matches a provided fingerprint.
447#[derive(Debug)]
448struct FingerprintVerifier {
449	provider: crypto::Provider,
450	fingerprint: Vec<u8>,
451}
452
453impl FingerprintVerifier {
454	pub fn new(provider: crypto::Provider, fingerprint: Vec<u8>) -> Self {
455		Self { provider, fingerprint }
456	}
457}
458
459impl rustls::client::danger::ServerCertVerifier for FingerprintVerifier {
460	fn verify_server_cert(
461		&self,
462		end_entity: &CertificateDer<'_>,
463		_intermediates: &[CertificateDer<'_>],
464		_server_name: &ServerName<'_>,
465		_ocsp: &[u8],
466		_now: UnixTime,
467	) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
468		let fingerprint = crypto::sha256(&self.provider, end_entity);
469		if fingerprint.as_ref() == self.fingerprint.as_slice() {
470			Ok(rustls::client::danger::ServerCertVerified::assertion())
471		} else {
472			Err(rustls::Error::General("fingerprint mismatch".into()))
473		}
474	}
475
476	fn verify_tls12_signature(
477		&self,
478		message: &[u8],
479		cert: &CertificateDer<'_>,
480		dss: &rustls::DigitallySignedStruct,
481	) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
482		rustls::crypto::verify_tls12_signature(message, cert, dss, &self.provider.signature_verification_algorithms)
483	}
484
485	fn verify_tls13_signature(
486		&self,
487		message: &[u8],
488		cert: &CertificateDer<'_>,
489		dss: &rustls::DigitallySignedStruct,
490	) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
491		rustls::crypto::verify_tls13_signature(message, cert, dss, &self.provider.signature_verification_algorithms)
492	}
493
494	fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
495		self.provider.signature_verification_algorithms.supported_schemes()
496	}
497}
498
/// Returns a copy of `url` whose scheme has been replaced by `scheme`.
///
/// [`Url::set_scheme`] rejects scheme changes that are not allowed by
/// [the URL specification's section on legal scheme state overrides](https://url.spec.whatwg.org/#scheme-state).
///
/// This helper instead swaps the scheme in the serialized URL and parses the result again,
/// so any change is accepted as long as the resulting URL is valid.
#[cfg(feature = "iroh")]
fn url_set_scheme(url: Url, scheme: &str) -> anyhow::Result<Url> {
	// Everything after the first ':' is kept verbatim.
	let (_, rest) = url.as_str().split_once(':').context("invalid URL")?;
	let rewritten = format!("{}:{}", scheme, rest).parse()?;
	Ok(rewritten)
}