moq_native/
client.rs

1use anyhow::Context;
2use ring::digest::{digest, SHA256};
3use rustls::pki_types::{CertificateDer, ServerName, UnixTime};
4use rustls::RootCertStore;
5use std::path::PathBuf;
6use std::{fs, io, net, sync::Arc, time};
7use url::Url;
8
9use web_transport::quinn as web_transport_quinn;
10
// TLS options for outgoing connections.
//
// The derives expose every field as a command-line flag / environment variable (clap)
// and make the struct (de)serializable (serde). When deserializing, missing fields fall
// back to `Default` and unknown fields are rejected.
#[derive(Clone, Default, Debug, clap::Args, serde::Serialize, serde::Deserialize)]
#[serde(default, deny_unknown_fields)]
pub struct ClientTls {
	// NOTE: `Client::new` reads only the first certificate found in each listed file.
	/// Use the TLS root at this path, encoded as PEM.
	///
	/// This value can be provided multiple times for multiple roots.
	/// If this is empty, system roots will be used instead
	#[serde(skip_serializing_if = "Vec::is_empty")]
	#[arg(long = "tls-root", env = "MOQ_CLIENT_TLS_ROOT")]
	pub root: Vec<PathBuf>,

	/// Danger: Disable TLS certificate verification.
	///
	/// Fine for local development and between relays, but should be used in caution in production.
	// This is an Option<bool> so clap skips over it when not provided, otherwise it is set to false.
	// `None` is treated the same as `Some(false)` when the client is built.
	#[serde(skip_serializing_if = "Option::is_none")]
	#[arg(long = "tls-disable-verify", env = "MOQ_CLIENT_TLS_DISABLE_VERIFY")]
	pub disable_verify: Option<bool>,
}
30
// Settings used to build a `Client`: the local UDP bind address plus the TLS options.
//
// Parsable from the command line / environment (clap) and (de)serializable (serde).
// When deserializing, missing fields fall back to `Default` and unknown fields are rejected.
#[derive(Clone, Debug, clap::Parser, serde::Serialize, serde::Deserialize)]
#[serde(deny_unknown_fields, default)]
pub struct ClientConfig {
	// The default value here must stay in sync with `ClientConfig::default()`.
	// NOTE(review): the explicit `id` presumably avoids clashing with another `bind`
	// argument when this struct is combined with other configs; confirm.
	/// Listen for UDP packets on the given address.
	#[arg(long, id = "client-bind", default_value = "[::]:0", env = "MOQ_CLIENT_BIND")]
	pub bind: net::SocketAddr,

	// TLS options, flattened into the same argument namespace as `bind`.
	#[command(flatten)]
	#[serde(default)]
	pub tls: ClientTls,
}
42
43impl Default for ClientConfig {
44	fn default() -> Self {
45		Self {
46			bind: "[::]:0".parse().unwrap(),
47			tls: ClientTls::default(),
48		}
49	}
50}
51
impl ClientConfig {
	/// Consume this configuration and build a [`Client`] from it.
	///
	/// This simply forwards to [`Client::new`]; any error it returns is passed through.
	pub fn init(self) -> anyhow::Result<Client> {
		Client::new(self)
	}
}
57
/// A QUIC client used to open sessions to remote servers.
#[derive(Clone)]
pub struct Client {
	/// The QUIC endpoint, created on a UDP socket bound to the configured address.
	pub quic: quinn::Endpoint,
	/// Base TLS configuration; `connect` clones it and adjusts the copy per connection.
	pub tls: rustls::ClientConfig,
	/// Transport parameters applied to every connection made by `connect`.
	pub transport: Arc<quinn::TransportConfig>,
}
64
65impl Client {
66	pub fn new(config: ClientConfig) -> anyhow::Result<Self> {
67		let provider = Arc::new(rustls::crypto::aws_lc_rs::default_provider());
68
69		// Create a list of acceptable root certificates.
70		let mut roots = RootCertStore::empty();
71
72		if config.tls.root.is_empty() {
73			let native = rustls_native_certs::load_native_certs();
74
75			// Log any errors that occurred while loading the native root certificates.
76			for err in native.errors {
77				tracing::warn!(?err, "failed to load root cert");
78			}
79
80			// Add the platform's native root certificates.
81			for cert in native.certs {
82				roots.add(cert).context("failed to add root cert")?;
83			}
84		} else {
85			// Add the specified root certificates.
86			for root in &config.tls.root {
87				let root = fs::File::open(root).context("failed to open root cert file")?;
88				let mut root = io::BufReader::new(root);
89
90				let root = rustls_pemfile::certs(&mut root)
91					.next()
92					.context("no roots found")?
93					.context("failed to read root cert")?;
94
95				roots.add(root).context("failed to add root cert")?;
96			}
97		}
98
99		// Create the TLS configuration we'll use as a client (relay -> relay)
100		let mut tls = rustls::ClientConfig::builder_with_provider(provider.clone())
101			.with_protocol_versions(&[&rustls::version::TLS13])?
102			.with_root_certificates(roots)
103			.with_no_client_auth();
104
105		// Allow disabling TLS verification altogether.
106		if config.tls.disable_verify.unwrap_or_default() {
107			tracing::warn!("TLS server certificate verification is disabled; A man-in-the-middle attack is possible.");
108
109			let noop = NoCertificateVerification(provider.clone());
110			tls.dangerous().set_certificate_verifier(Arc::new(noop));
111		}
112
113		let socket = std::net::UdpSocket::bind(config.bind).context("failed to bind UDP socket")?;
114
115		// Enable BBR congestion control
116		// TODO validate the implementation
117		let mut transport = quinn::TransportConfig::default();
118		transport.max_idle_timeout(Some(time::Duration::from_secs(10).try_into().unwrap()));
119		transport.keep_alive_interval(Some(time::Duration::from_secs(4)));
120		transport.congestion_controller_factory(Arc::new(quinn::congestion::BbrConfig::default()));
121		transport.mtu_discovery_config(None); // Disable MTU discovery
122		let transport = Arc::new(transport);
123
124		// There's a bit more boilerplate to make a generic endpoint.
125		let runtime = quinn::default_runtime().context("no async runtime")?;
126		let endpoint_config = quinn::EndpointConfig::default();
127
128		// Create the generic QUIC endpoint.
129		let quic =
130			quinn::Endpoint::new(endpoint_config, None, socket, runtime).context("failed to create QUIC endpoint")?;
131
132		Ok(Self { quic, tls, transport })
133	}
134
135	pub async fn connect(&self, mut url: Url) -> anyhow::Result<web_transport_quinn::Session> {
136		let mut config = self.tls.clone();
137
138		let host = url.host().context("invalid DNS name")?.to_string();
139		let port = url.port().unwrap_or(443);
140
141		// Look up the DNS entry.
142		let ip = tokio::net::lookup_host((host.clone(), port))
143			.await
144			.context("failed DNS lookup")?
145			.next()
146			.context("no DNS entries")?;
147
148		if url.scheme() == "http" {
149			// Perform a HTTP request to fetch the certificate fingerprint.
150			let mut fingerprint = url.clone();
151			fingerprint.set_path("/certificate.sha256");
152			fingerprint.set_query(None);
153			fingerprint.set_fragment(None);
154
155			tracing::warn!(url = %fingerprint, "performing insecure HTTP request for certificate");
156
157			let resp = reqwest::get(fingerprint.as_str())
158				.await
159				.context("failed to fetch fingerprint")?
160				.error_for_status()
161				.context("fingerprint request failed")?;
162
163			let fingerprint = resp.text().await.context("failed to read fingerprint")?;
164			let fingerprint = hex::decode(fingerprint.trim()).context("invalid fingerprint")?;
165
166			let verifier = FingerprintVerifier::new(config.crypto_provider().clone(), fingerprint);
167			config.dangerous().set_certificate_verifier(Arc::new(verifier));
168
169			url.set_scheme("https").expect("failed to set scheme");
170		}
171
172		let alpn = match url.scheme() {
173			"https" => web_transport::quinn::ALPN,
174			"moql" => moq_lite::ALPN,
175			_ => anyhow::bail!("url scheme must be 'http', 'https', or 'moql'"),
176		};
177
178		// TODO support connecting to both ALPNs at the same time
179		config.alpn_protocols = vec![alpn.as_bytes().to_vec()];
180		config.key_log = Arc::new(rustls::KeyLogFile::new());
181
182		let config: quinn::crypto::rustls::QuicClientConfig = config.try_into()?;
183		let mut config = quinn::ClientConfig::new(Arc::new(config));
184		config.transport_config(self.transport.clone());
185
186		tracing::debug!(%url, %ip, %alpn, "connecting");
187
188		let connection = self.quic.connect_with(config, ip, &host)?.await?;
189		tracing::Span::current().record("id", connection.stable_id());
190
191		let session = match url.scheme() {
192			"https" => web_transport::quinn::Session::connect(connection, url).await?,
193			moq_lite::ALPN => web_transport::quinn::Session::raw(connection, url),
194			_ => unreachable!(),
195		};
196
197		Ok(session)
198	}
199}
200
/// Certificate verifier that accepts any server certificate.
///
/// The wrapped crypto provider is still used to check handshake signatures.
#[derive(Debug)]
struct NoCertificateVerification(Arc<rustls::crypto::CryptoProvider>);
203
204impl rustls::client::danger::ServerCertVerifier for NoCertificateVerification {
205	fn verify_server_cert(
206		&self,
207		_end_entity: &CertificateDer<'_>,
208		_intermediates: &[CertificateDer<'_>],
209		_server_name: &ServerName<'_>,
210		_ocsp: &[u8],
211		_now: UnixTime,
212	) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
213		Ok(rustls::client::danger::ServerCertVerified::assertion())
214	}
215
216	fn verify_tls12_signature(
217		&self,
218		message: &[u8],
219		cert: &CertificateDer<'_>,
220		dss: &rustls::DigitallySignedStruct,
221	) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
222		rustls::crypto::verify_tls12_signature(message, cert, dss, &self.0.signature_verification_algorithms)
223	}
224
225	fn verify_tls13_signature(
226		&self,
227		message: &[u8],
228		cert: &CertificateDer<'_>,
229		dss: &rustls::DigitallySignedStruct,
230	) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
231		rustls::crypto::verify_tls13_signature(message, cert, dss, &self.0.signature_verification_algorithms)
232	}
233
234	fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
235		self.0.signature_verification_algorithms.supported_schemes()
236	}
237}
238
/// Certificate verifier that accepts a server certificate only when its SHA-256 digest
/// matches a provided fingerprint.
#[derive(Debug)]
struct FingerprintVerifier {
	/// Supplies the signature algorithms used to check handshake signatures.
	provider: Arc<rustls::crypto::CryptoProvider>,
	/// Expected SHA-256 digest of the server's end-entity certificate, as raw bytes.
	fingerprint: Vec<u8>,
}
245
impl FingerprintVerifier {
	/// Create a verifier that accepts only a certificate whose SHA-256 digest equals
	/// `fingerprint` (raw bytes, not hex), using `provider` for signature checks.
	pub fn new(provider: Arc<rustls::crypto::CryptoProvider>, fingerprint: Vec<u8>) -> Self {
		Self { provider, fingerprint }
	}
}
251
252impl rustls::client::danger::ServerCertVerifier for FingerprintVerifier {
253	fn verify_server_cert(
254		&self,
255		end_entity: &CertificateDer<'_>,
256		_intermediates: &[CertificateDer<'_>],
257		_server_name: &ServerName<'_>,
258		_ocsp: &[u8],
259		_now: UnixTime,
260	) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
261		let fingerprint = digest(&SHA256, end_entity);
262		if fingerprint.as_ref() == self.fingerprint.as_slice() {
263			Ok(rustls::client::danger::ServerCertVerified::assertion())
264		} else {
265			Err(rustls::Error::General("fingerprint mismatch".into()))
266		}
267	}
268
269	fn verify_tls12_signature(
270		&self,
271		message: &[u8],
272		cert: &CertificateDer<'_>,
273		dss: &rustls::DigitallySignedStruct,
274	) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
275		rustls::crypto::verify_tls12_signature(message, cert, dss, &self.provider.signature_verification_algorithms)
276	}
277
278	fn verify_tls13_signature(
279		&self,
280		message: &[u8],
281		cert: &CertificateDer<'_>,
282		dss: &rustls::DigitallySignedStruct,
283	) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
284		rustls::crypto::verify_tls13_signature(message, cert, dss, &self.provider.signature_verification_algorithms)
285	}
286
287	fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
288		self.provider.signature_verification_algorithms.supported_schemes()
289	}
290}