1use anyhow::Context;
2use ring::digest::{digest, SHA256};
3use rustls::pki_types::{CertificateDer, ServerName, UnixTime};
4use rustls::RootCertStore;
5use std::path::PathBuf;
6use std::{fs, io, net, sync::Arc, time};
7use url::Url;
8
9use web_transport::quinn as web_transport_quinn;
10
/// TLS options for outgoing client connections.
#[derive(Clone, Default, Debug, clap::Args, serde::Serialize, serde::Deserialize)]
#[serde(default, deny_unknown_fields)]
pub struct ClientTls {
	/// Trust root certificates from this PEM-encoded file.
	///
	/// May be provided multiple times for multiple roots.
	/// If empty, the platform's native root store is used instead.
	#[serde(skip_serializing_if = "Vec::is_empty")]
	#[arg(long = "tls-root")]
	pub root: Vec<PathBuf>,

	/// Danger: disable TLS server certificate verification.
	///
	/// Any server certificate is accepted, so a man-in-the-middle attack is possible.
	/// Treated as `false` when unset.
	#[serde(skip_serializing_if = "Option::is_none")]
	#[arg(long = "tls-disable-verify")]
	pub disable_verify: Option<bool>,
}
30
/// Configuration used to build a [`Client`].
#[derive(Clone, Debug, clap::Parser, serde::Serialize, serde::Deserialize)]
#[serde(deny_unknown_fields, default)]
pub struct ClientConfig {
	/// Local UDP address the client socket binds to.
	///
	/// Defaults to `[::]:0`: the IPv6 unspecified address and an OS-chosen port.
	#[arg(long, id = "client-bind", default_value = "[::]:0")]
	pub bind: net::SocketAddr,

	// TLS settings, exposed on the command line as the `--tls-*` options.
	#[command(flatten)]
	#[serde(default)]
	pub tls: ClientTls,
}
42
43impl Default for ClientConfig {
44 fn default() -> Self {
45 Self {
46 bind: "[::]:0".parse().unwrap(),
47 tls: ClientTls::default(),
48 }
49 }
50}
51
52impl ClientConfig {
53 pub fn init(self) -> anyhow::Result<Client> {
54 Client::new(self)
55 }
56}
57
/// A bound QUIC endpoint plus the TLS and transport settings used by [`Client::connect`].
#[derive(Clone)]
pub struct Client {
	/// The QUIC endpoint, bound to `ClientConfig::bind`.
	pub quic: quinn::Endpoint,
	/// Base TLS configuration.
	///
	/// It is cloned per connection and then customised with the ALPN and, for `http://`
	/// URLs, a fingerprint-pinning verifier.
	pub tls: rustls::ClientConfig,
	/// Transport parameters applied to every outgoing connection.
	pub transport: Arc<quinn::TransportConfig>,
}
64
65impl Client {
66 pub fn new(config: ClientConfig) -> anyhow::Result<Self> {
67 let provider = Arc::new(rustls::crypto::aws_lc_rs::default_provider());
68
69 let mut roots = RootCertStore::empty();
71
72 if config.tls.root.is_empty() {
73 let native = rustls_native_certs::load_native_certs();
74
75 for err in native.errors {
77 tracing::warn!(?err, "failed to load root cert");
78 }
79
80 for cert in native.certs {
82 roots.add(cert).context("failed to add root cert")?;
83 }
84 } else {
85 for root in &config.tls.root {
87 let root = fs::File::open(root).context("failed to open root cert file")?;
88 let mut root = io::BufReader::new(root);
89
90 let root = rustls_pemfile::certs(&mut root)
91 .next()
92 .context("no roots found")?
93 .context("failed to read root cert")?;
94
95 roots.add(root).context("failed to add root cert")?;
96 }
97 }
98
99 let mut tls = rustls::ClientConfig::builder_with_provider(provider.clone())
101 .with_protocol_versions(&[&rustls::version::TLS13])?
102 .with_root_certificates(roots)
103 .with_no_client_auth();
104
105 if config.tls.disable_verify.unwrap_or_default() {
107 tracing::warn!("TLS server certificate verification is disabled; A man-in-the-middle attack is possible.");
108
109 let noop = NoCertificateVerification(provider.clone());
110 tls.dangerous().set_certificate_verifier(Arc::new(noop));
111 }
112
113 let socket = std::net::UdpSocket::bind(config.bind).context("failed to bind UDP socket")?;
114
115 let mut transport = quinn::TransportConfig::default();
118 transport.max_idle_timeout(Some(time::Duration::from_secs(10).try_into().unwrap()));
119 transport.keep_alive_interval(Some(time::Duration::from_secs(4)));
120 transport.congestion_controller_factory(Arc::new(quinn::congestion::BbrConfig::default()));
121 transport.mtu_discovery_config(None); let transport = Arc::new(transport);
123
124 let runtime = quinn::default_runtime().context("no async runtime")?;
126 let endpoint_config = quinn::EndpointConfig::default();
127
128 let quic =
130 quinn::Endpoint::new(endpoint_config, None, socket, runtime).context("failed to create QUIC endpoint")?;
131
132 Ok(Self { quic, tls, transport })
133 }
134
135 pub async fn connect(&self, mut url: Url) -> anyhow::Result<web_transport_quinn::Session> {
136 let mut config = self.tls.clone();
137
138 let host = url.host().context("invalid DNS name")?.to_string();
139 let port = url.port().unwrap_or(443);
140
141 let ip = tokio::net::lookup_host((host.clone(), port))
143 .await
144 .context("failed DNS lookup")?
145 .next()
146 .context("no DNS entries")?;
147
148 if url.scheme() == "http" {
149 let mut fingerprint = url.clone();
151 fingerprint.set_path("/certificate.sha256");
152 fingerprint.set_query(None);
153 fingerprint.set_fragment(None);
154
155 tracing::warn!(url = %fingerprint, "performing insecure HTTP request for certificate");
156
157 let resp = reqwest::get(fingerprint.as_str())
158 .await
159 .context("failed to fetch fingerprint")?
160 .error_for_status()
161 .context("fingerprint request failed")?;
162
163 let fingerprint = resp.text().await.context("failed to read fingerprint")?;
164 let fingerprint = hex::decode(fingerprint.trim()).context("invalid fingerprint")?;
165
166 let verifier = FingerprintVerifier::new(config.crypto_provider().clone(), fingerprint);
167 config.dangerous().set_certificate_verifier(Arc::new(verifier));
168
169 url.set_scheme("https").expect("failed to set scheme");
170 }
171
172 let alpn = match url.scheme() {
173 "https" => web_transport::quinn::ALPN,
174 "moql" => moq_lite::ALPN,
175 _ => anyhow::bail!("url scheme must be 'http', 'https', or 'moql'"),
176 };
177
178 config.alpn_protocols = vec![alpn.as_bytes().to_vec()];
180 config.key_log = Arc::new(rustls::KeyLogFile::new());
181
182 let config: quinn::crypto::rustls::QuicClientConfig = config.try_into()?;
183 let mut config = quinn::ClientConfig::new(Arc::new(config));
184 config.transport_config(self.transport.clone());
185
186 tracing::debug!(%url, %ip, %alpn, "connecting");
187
188 let connection = self.quic.connect_with(config, ip, &host)?.await?;
189 tracing::Span::current().record("id", connection.stable_id());
190
191 let session = match url.scheme() {
192 "https" => web_transport::quinn::Session::connect(connection, url).await?,
193 moq_lite::ALPN => web_transport::quinn::Session::raw(connection, url),
194 _ => unreachable!(),
195 };
196
197 Ok(session)
198 }
199}
200
201#[derive(Debug)]
202struct NoCertificateVerification(Arc<rustls::crypto::CryptoProvider>);
203
204impl rustls::client::danger::ServerCertVerifier for NoCertificateVerification {
205 fn verify_server_cert(
206 &self,
207 _end_entity: &CertificateDer<'_>,
208 _intermediates: &[CertificateDer<'_>],
209 _server_name: &ServerName<'_>,
210 _ocsp: &[u8],
211 _now: UnixTime,
212 ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
213 Ok(rustls::client::danger::ServerCertVerified::assertion())
214 }
215
216 fn verify_tls12_signature(
217 &self,
218 message: &[u8],
219 cert: &CertificateDer<'_>,
220 dss: &rustls::DigitallySignedStruct,
221 ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
222 rustls::crypto::verify_tls12_signature(message, cert, dss, &self.0.signature_verification_algorithms)
223 }
224
225 fn verify_tls13_signature(
226 &self,
227 message: &[u8],
228 cert: &CertificateDer<'_>,
229 dss: &rustls::DigitallySignedStruct,
230 ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
231 rustls::crypto::verify_tls13_signature(message, cert, dss, &self.0.signature_verification_algorithms)
232 }
233
234 fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
235 self.0.signature_verification_algorithms.supported_schemes()
236 }
237}
238
239#[derive(Debug)]
241struct FingerprintVerifier {
242 provider: Arc<rustls::crypto::CryptoProvider>,
243 fingerprint: Vec<u8>,
244}
245
246impl FingerprintVerifier {
247 pub fn new(provider: Arc<rustls::crypto::CryptoProvider>, fingerprint: Vec<u8>) -> Self {
248 Self { provider, fingerprint }
249 }
250}
251
252impl rustls::client::danger::ServerCertVerifier for FingerprintVerifier {
253 fn verify_server_cert(
254 &self,
255 end_entity: &CertificateDer<'_>,
256 _intermediates: &[CertificateDer<'_>],
257 _server_name: &ServerName<'_>,
258 _ocsp: &[u8],
259 _now: UnixTime,
260 ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
261 let fingerprint = digest(&SHA256, end_entity);
262 if fingerprint.as_ref() == self.fingerprint.as_slice() {
263 Ok(rustls::client::danger::ServerCertVerified::assertion())
264 } else {
265 Err(rustls::Error::General("fingerprint mismatch".into()))
266 }
267 }
268
269 fn verify_tls12_signature(
270 &self,
271 message: &[u8],
272 cert: &CertificateDer<'_>,
273 dss: &rustls::DigitallySignedStruct,
274 ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
275 rustls::crypto::verify_tls12_signature(message, cert, dss, &self.provider.signature_verification_algorithms)
276 }
277
278 fn verify_tls13_signature(
279 &self,
280 message: &[u8],
281 cert: &CertificateDer<'_>,
282 dss: &rustls::DigitallySignedStruct,
283 ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
284 rustls::crypto::verify_tls13_signature(message, cert, dss, &self.provider.signature_verification_algorithms)
285 }
286
287 fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
288 self.provider.signature_verification_algorithms.supported_schemes()
289 }
290}