moq_native/
client.rs

1use crate::crypto;
2use anyhow::Context;
3use rustls::pki_types::{CertificateDer, ServerName, UnixTime};
4use rustls::RootCertStore;
5use std::path::PathBuf;
6use std::{fs, io, net, sync::Arc, time};
7use url::Url;
8
9#[derive(Clone, Default, Debug, clap::Args, serde::Serialize, serde::Deserialize)]
10#[serde(default, deny_unknown_fields)]
11pub struct ClientTls {
12	/// Use the TLS root at this path, encoded as PEM.
13	///
14	/// This value can be provided multiple times for multiple roots.
15	/// If this is empty, system roots will be used instead
16	#[serde(skip_serializing_if = "Vec::is_empty")]
17	#[arg(id = "tls-root", long = "tls-root", env = "MOQ_CLIENT_TLS_ROOT")]
18	pub root: Vec<PathBuf>,
19
20	/// Danger: Disable TLS certificate verification.
21	///
22	/// Fine for local development and between relays, but should be used in caution in production.
23	// This is an Option<bool> so clap skips over it when not provided, otherwise it is set to false.
24	#[serde(skip_serializing_if = "Option::is_none")]
25	#[arg(
26		id = "tls-disable-verify",
27		long = "tls-disable-verify",
28		env = "MOQ_CLIENT_TLS_DISABLE_VERIFY"
29	)]
30	pub disable_verify: Option<bool>,
31}
32
// Client configuration: the local UDP bind address plus TLS options.
// It can be parsed from CLI flags / environment variables (clap) or deserialized (serde).
// Plain `//` comments here on purpose: clap uses `///` doc comments as help/about text.
#[derive(Clone, Debug, clap::Parser, serde::Serialize, serde::Deserialize)]
#[serde(deny_unknown_fields, default)]
pub struct ClientConfig {
	/// Listen for UDP packets on the given address.
	// "[::]:0" = IPv6 unspecified address with an OS-assigned port.
	// Keep this default in sync with `impl Default for ClientConfig`.
	#[arg(
		id = "client-bind",
		long = "client-bind",
		default_value = "[::]:0",
		env = "MOQ_CLIENT_BIND"
	)]
	pub bind: net::SocketAddr,

	// TLS flags are flattened into this command; a missing `tls` table in serde input
	// falls back to `ClientTls::default()`.
	#[command(flatten)]
	#[serde(default)]
	pub tls: ClientTls,
}
49
50impl Default for ClientConfig {
51	fn default() -> Self {
52		Self {
53			bind: "[::]:0".parse().unwrap(),
54			tls: ClientTls::default(),
55		}
56	}
57}
58
59impl ClientConfig {
60	pub fn init(self) -> anyhow::Result<Client> {
61		Client::new(self)
62	}
63}
64
/// A QUIC client endpoint together with the TLS and transport settings applied to each
/// outgoing connection.
///
/// Create one with [`Client::new`] (or [`ClientConfig::init`]) and open sessions with
/// [`Client::connect`].
#[derive(Clone)]
pub struct Client {
	/// The bound QUIC endpoint that outgoing connections are made from.
	pub quic: quinn::Endpoint,
	/// Base TLS configuration. `connect` clones it per connection and sets the ALPN
	/// (and, for `http://` URLs, a fingerprint verifier) on the copy.
	pub tls: rustls::ClientConfig,
	/// QUIC transport settings shared by every connection.
	pub transport: Arc<quinn::TransportConfig>,
}
71
72impl Client {
73	pub fn new(config: ClientConfig) -> anyhow::Result<Self> {
74		let provider = crypto::provider();
75
76		// Create a list of acceptable root certificates.
77		let mut roots = RootCertStore::empty();
78
79		if config.tls.root.is_empty() {
80			let native = rustls_native_certs::load_native_certs();
81
82			// Log any errors that occurred while loading the native root certificates.
83			for err in native.errors {
84				tracing::warn!(%err, "failed to load root cert");
85			}
86
87			// Add the platform's native root certificates.
88			for cert in native.certs {
89				roots.add(cert).context("failed to add root cert")?;
90			}
91		} else {
92			// Add the specified root certificates.
93			for root in &config.tls.root {
94				let root = fs::File::open(root).context("failed to open root cert file")?;
95				let mut root = io::BufReader::new(root);
96
97				let root = rustls_pemfile::certs(&mut root)
98					.next()
99					.context("no roots found")?
100					.context("failed to read root cert")?;
101
102				roots.add(root).context("failed to add root cert")?;
103			}
104		}
105
106		// Create the TLS configuration we'll use as a client (relay -> relay)
107		let mut tls = rustls::ClientConfig::builder_with_provider(provider.clone())
108			.with_protocol_versions(&[&rustls::version::TLS13])?
109			.with_root_certificates(roots)
110			.with_no_client_auth();
111
112		// Allow disabling TLS verification altogether.
113		if config.tls.disable_verify.unwrap_or_default() {
114			tracing::warn!("TLS server certificate verification is disabled; A man-in-the-middle attack is possible.");
115
116			let noop = NoCertificateVerification(provider.clone());
117			tls.dangerous().set_certificate_verifier(Arc::new(noop));
118		}
119
120		let socket = std::net::UdpSocket::bind(config.bind).context("failed to bind UDP socket")?;
121
122		// TODO Validate the BBR implementation before enabling it
123		let mut transport = quinn::TransportConfig::default();
124		transport.max_idle_timeout(Some(time::Duration::from_secs(10).try_into().unwrap()));
125		transport.keep_alive_interval(Some(time::Duration::from_secs(4)));
126		//transport.congestion_controller_factory(Arc::new(quinn::congestion::BbrConfig::default()));
127		transport.mtu_discovery_config(None); // Disable MTU discovery
128		let transport = Arc::new(transport);
129
130		// There's a bit more boilerplate to make a generic endpoint.
131		let runtime = quinn::default_runtime().context("no async runtime")?;
132		let endpoint_config = quinn::EndpointConfig::default();
133
134		// Create the generic QUIC endpoint.
135		let quic =
136			quinn::Endpoint::new(endpoint_config, None, socket, runtime).context("failed to create QUIC endpoint")?;
137
138		Ok(Self { quic, tls, transport })
139	}
140
141	pub async fn connect(&self, mut url: Url) -> anyhow::Result<web_transport_quinn::Session> {
142		let mut config = self.tls.clone();
143
144		let host = url.host().context("invalid DNS name")?.to_string();
145		let port = url.port().unwrap_or(443);
146
147		// Look up the DNS entry.
148		let ip = tokio::net::lookup_host((host.clone(), port))
149			.await
150			.context("failed DNS lookup")?
151			.next()
152			.context("no DNS entries")?;
153
154		if url.scheme() == "http" {
155			// Perform a HTTP request to fetch the certificate fingerprint.
156			let mut fingerprint = url.clone();
157			fingerprint.set_path("/certificate.sha256");
158			fingerprint.set_query(None);
159			fingerprint.set_fragment(None);
160
161			tracing::warn!(url = %fingerprint, "performing insecure HTTP request for certificate");
162
163			let resp = reqwest::get(fingerprint.as_str())
164				.await
165				.context("failed to fetch fingerprint")?
166				.error_for_status()
167				.context("fingerprint request failed")?;
168
169			let fingerprint = resp.text().await.context("failed to read fingerprint")?;
170			let fingerprint = hex::decode(fingerprint.trim()).context("invalid fingerprint")?;
171
172			let verifier = FingerprintVerifier::new(config.crypto_provider().clone(), fingerprint);
173			config.dangerous().set_certificate_verifier(Arc::new(verifier));
174
175			url.set_scheme("https").expect("failed to set scheme");
176		}
177
178		let alpn = match url.scheme() {
179			"https" => web_transport_quinn::ALPN,
180			"moql" => moq_lite::ALPN,
181			_ => anyhow::bail!("url scheme must be 'http', 'https', or 'moql'"),
182		};
183
184		// TODO support connecting to both ALPNs at the same time
185		config.alpn_protocols = vec![alpn.as_bytes().to_vec()];
186		config.key_log = Arc::new(rustls::KeyLogFile::new());
187
188		let config: quinn::crypto::rustls::QuicClientConfig = config.try_into()?;
189		let mut config = quinn::ClientConfig::new(Arc::new(config));
190		config.transport_config(self.transport.clone());
191
192		tracing::debug!(%url, %ip, %alpn, "connecting");
193
194		let connection = self.quic.connect_with(config, ip, &host)?.await?;
195		tracing::Span::current().record("id", connection.stable_id());
196
197		let session = match alpn {
198			web_transport_quinn::ALPN => web_transport_quinn::Session::connect(connection, url).await?,
199			moq_lite::ALPN => web_transport_quinn::Session::raw(connection, url),
200			_ => unreachable!("ALPN was checked above"),
201		};
202
203		Ok(session)
204	}
205}
206
207#[derive(Debug)]
208struct NoCertificateVerification(crypto::Provider);
209
210impl rustls::client::danger::ServerCertVerifier for NoCertificateVerification {
211	fn verify_server_cert(
212		&self,
213		_end_entity: &CertificateDer<'_>,
214		_intermediates: &[CertificateDer<'_>],
215		_server_name: &ServerName<'_>,
216		_ocsp: &[u8],
217		_now: UnixTime,
218	) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
219		Ok(rustls::client::danger::ServerCertVerified::assertion())
220	}
221
222	fn verify_tls12_signature(
223		&self,
224		message: &[u8],
225		cert: &CertificateDer<'_>,
226		dss: &rustls::DigitallySignedStruct,
227	) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
228		rustls::crypto::verify_tls12_signature(message, cert, dss, &self.0.signature_verification_algorithms)
229	}
230
231	fn verify_tls13_signature(
232		&self,
233		message: &[u8],
234		cert: &CertificateDer<'_>,
235		dss: &rustls::DigitallySignedStruct,
236	) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
237		rustls::crypto::verify_tls13_signature(message, cert, dss, &self.0.signature_verification_algorithms)
238	}
239
240	fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
241		self.0.signature_verification_algorithms.supported_schemes()
242	}
243}
244
245// Verify the certificate matches a provided fingerprint.
246#[derive(Debug)]
247struct FingerprintVerifier {
248	provider: crypto::Provider,
249	fingerprint: Vec<u8>,
250}
251
252impl FingerprintVerifier {
253	pub fn new(provider: crypto::Provider, fingerprint: Vec<u8>) -> Self {
254		Self { provider, fingerprint }
255	}
256}
257
258impl rustls::client::danger::ServerCertVerifier for FingerprintVerifier {
259	fn verify_server_cert(
260		&self,
261		end_entity: &CertificateDer<'_>,
262		_intermediates: &[CertificateDer<'_>],
263		_server_name: &ServerName<'_>,
264		_ocsp: &[u8],
265		_now: UnixTime,
266	) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
267		let fingerprint = crypto::sha256(&self.provider, end_entity);
268		if fingerprint.as_ref() == self.fingerprint.as_slice() {
269			Ok(rustls::client::danger::ServerCertVerified::assertion())
270		} else {
271			Err(rustls::Error::General("fingerprint mismatch".into()))
272		}
273	}
274
275	fn verify_tls12_signature(
276		&self,
277		message: &[u8],
278		cert: &CertificateDer<'_>,
279		dss: &rustls::DigitallySignedStruct,
280	) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
281		rustls::crypto::verify_tls12_signature(message, cert, dss, &self.provider.signature_verification_algorithms)
282	}
283
284	fn verify_tls13_signature(
285		&self,
286		message: &[u8],
287		cert: &CertificateDer<'_>,
288		dss: &rustls::DigitallySignedStruct,
289	) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
290		rustls::crypto::verify_tls13_signature(message, cert, dss, &self.provider.signature_verification_algorithms)
291	}
292
293	fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
294		self.provider.signature_verification_algorithms.supported_schemes()
295	}
296}