1use crate::crypto;
2use anyhow::Context;
3use rustls::pki_types::{CertificateDer, ServerName, UnixTime};
4use rustls::RootCertStore;
5use std::path::PathBuf;
6use std::{fs, io, net, sync::Arc, time};
7use url::Url;
8
// TLS options for outgoing client connections, settable via CLI flags, env vars, or serde config.
#[derive(Clone, Default, Debug, clap::Args, serde::Serialize, serde::Deserialize)]
#[serde(default, deny_unknown_fields)]
pub struct ClientTls {
    /// PEM file(s) containing root certificates to trust.
    ///
    /// When empty, the platform's native root certificates are used instead.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    #[arg(long = "tls-root", env = "MOQ_CLIENT_TLS_ROOT")]
    pub root: Vec<PathBuf>,

    /// Danger: disable TLS server certificate verification.
    ///
    /// Any server certificate is accepted (handshake signatures are still checked),
    /// so a man-in-the-middle attack is possible. Unset is treated as `false`.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[arg(long = "tls-disable-verify", env = "MOQ_CLIENT_TLS_DISABLE_VERIFY")]
    pub disable_verify: Option<bool>,
}
28
// Client configuration, parsed from CLI flags / env vars (clap) or a config file (serde).
// Missing serde fields fall back to `ClientConfig::default()`.
#[derive(Clone, Debug, clap::Parser, serde::Serialize, serde::Deserialize)]
#[serde(deny_unknown_fields, default)]
pub struct ClientConfig {
    /// Local UDP address the QUIC socket binds to.
    ///
    /// Defaults to `[::]:0`: the unspecified IPv6 address with an OS-assigned port.
    #[arg(long, id = "client-bind", default_value = "[::]:0", env = "MOQ_CLIENT_BIND")]
    pub bind: net::SocketAddr,

    // TLS options, exposed as top-level flags via `flatten`.
    #[command(flatten)]
    #[serde(default)]
    pub tls: ClientTls,
}
40
41impl Default for ClientConfig {
42 fn default() -> Self {
43 Self {
44 bind: "[::]:0".parse().unwrap(),
45 tls: ClientTls::default(),
46 }
47 }
48}
49
50impl ClientConfig {
51 pub fn init(self) -> anyhow::Result<Client> {
52 Client::new(self)
53 }
54}
55
/// A QUIC client endpoint together with the TLS and transport settings applied to
/// every outgoing connection.
#[derive(Clone)]
pub struct Client {
    /// QUIC endpoint bound to the configured local address.
    pub quic: quinn::Endpoint,
    /// Base TLS configuration; `connect` clones it and sets the ALPN (and, for
    /// `http` URLs, a fingerprint verifier) per connection.
    pub tls: rustls::ClientConfig,
    /// Transport parameters (idle timeout, keep-alive, MTU discovery) shared by all connections.
    pub transport: Arc<quinn::TransportConfig>,
}
62
63impl Client {
64 pub fn new(config: ClientConfig) -> anyhow::Result<Self> {
65 let provider = crypto::provider();
66
67 let mut roots = RootCertStore::empty();
69
70 if config.tls.root.is_empty() {
71 let native = rustls_native_certs::load_native_certs();
72
73 for err in native.errors {
75 tracing::warn!(%err, "failed to load root cert");
76 }
77
78 for cert in native.certs {
80 roots.add(cert).context("failed to add root cert")?;
81 }
82 } else {
83 for root in &config.tls.root {
85 let root = fs::File::open(root).context("failed to open root cert file")?;
86 let mut root = io::BufReader::new(root);
87
88 let root = rustls_pemfile::certs(&mut root)
89 .next()
90 .context("no roots found")?
91 .context("failed to read root cert")?;
92
93 roots.add(root).context("failed to add root cert")?;
94 }
95 }
96
97 let mut tls = rustls::ClientConfig::builder_with_provider(provider.clone())
99 .with_protocol_versions(&[&rustls::version::TLS13])?
100 .with_root_certificates(roots)
101 .with_no_client_auth();
102
103 if config.tls.disable_verify.unwrap_or_default() {
105 tracing::warn!("TLS server certificate verification is disabled; A man-in-the-middle attack is possible.");
106
107 let noop = NoCertificateVerification(provider.clone());
108 tls.dangerous().set_certificate_verifier(Arc::new(noop));
109 }
110
111 let socket = std::net::UdpSocket::bind(config.bind).context("failed to bind UDP socket")?;
112
113 let mut transport = quinn::TransportConfig::default();
115 transport.max_idle_timeout(Some(time::Duration::from_secs(10).try_into().unwrap()));
116 transport.keep_alive_interval(Some(time::Duration::from_secs(4)));
117 transport.mtu_discovery_config(None); let transport = Arc::new(transport);
120
121 let runtime = quinn::default_runtime().context("no async runtime")?;
123 let endpoint_config = quinn::EndpointConfig::default();
124
125 let quic =
127 quinn::Endpoint::new(endpoint_config, None, socket, runtime).context("failed to create QUIC endpoint")?;
128
129 Ok(Self { quic, tls, transport })
130 }
131
132 pub async fn connect(&self, mut url: Url) -> anyhow::Result<web_transport_quinn::Session> {
133 let mut config = self.tls.clone();
134
135 let host = url.host().context("invalid DNS name")?.to_string();
136 let port = url.port().unwrap_or(443);
137
138 let ip = tokio::net::lookup_host((host.clone(), port))
140 .await
141 .context("failed DNS lookup")?
142 .next()
143 .context("no DNS entries")?;
144
145 if url.scheme() == "http" {
146 let mut fingerprint = url.clone();
148 fingerprint.set_path("/certificate.sha256");
149 fingerprint.set_query(None);
150 fingerprint.set_fragment(None);
151
152 tracing::warn!(url = %fingerprint, "performing insecure HTTP request for certificate");
153
154 let resp = reqwest::get(fingerprint.as_str())
155 .await
156 .context("failed to fetch fingerprint")?
157 .error_for_status()
158 .context("fingerprint request failed")?;
159
160 let fingerprint = resp.text().await.context("failed to read fingerprint")?;
161 let fingerprint = hex::decode(fingerprint.trim()).context("invalid fingerprint")?;
162
163 let verifier = FingerprintVerifier::new(config.crypto_provider().clone(), fingerprint);
164 config.dangerous().set_certificate_verifier(Arc::new(verifier));
165
166 url.set_scheme("https").expect("failed to set scheme");
167 }
168
169 let alpn = match url.scheme() {
170 "https" => web_transport_quinn::ALPN,
171 "moql" => moq_lite::ALPN,
172 _ => anyhow::bail!("url scheme must be 'http', 'https', or 'moql'"),
173 };
174
175 config.alpn_protocols = vec![alpn.as_bytes().to_vec()];
177 config.key_log = Arc::new(rustls::KeyLogFile::new());
178
179 let config: quinn::crypto::rustls::QuicClientConfig = config.try_into()?;
180 let mut config = quinn::ClientConfig::new(Arc::new(config));
181 config.transport_config(self.transport.clone());
182
183 tracing::debug!(%url, %ip, %alpn, "connecting");
184
185 let connection = self.quic.connect_with(config, ip, &host)?.await?;
186 tracing::Span::current().record("id", connection.stable_id());
187
188 let session = match url.scheme() {
189 "https" => web_transport_quinn::Session::connect(connection, url).await?,
190 moq_lite::ALPN => web_transport_quinn::Session::raw(connection, url),
191 _ => unreachable!(),
192 };
193
194 Ok(session)
195 }
196}
197
/// Certificate verifier that accepts any server certificate.
///
/// Handshake signatures are still checked with the wrapped provider's algorithms.
/// The certificate itself (issuer, name, expiry) is never validated. It is installed
/// only when `disable_verify` is set.
#[derive(Debug)]
struct NoCertificateVerification(crypto::Provider);
200
201impl rustls::client::danger::ServerCertVerifier for NoCertificateVerification {
202 fn verify_server_cert(
203 &self,
204 _end_entity: &CertificateDer<'_>,
205 _intermediates: &[CertificateDer<'_>],
206 _server_name: &ServerName<'_>,
207 _ocsp: &[u8],
208 _now: UnixTime,
209 ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
210 Ok(rustls::client::danger::ServerCertVerified::assertion())
211 }
212
213 fn verify_tls12_signature(
214 &self,
215 message: &[u8],
216 cert: &CertificateDer<'_>,
217 dss: &rustls::DigitallySignedStruct,
218 ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
219 rustls::crypto::verify_tls12_signature(message, cert, dss, &self.0.signature_verification_algorithms)
220 }
221
222 fn verify_tls13_signature(
223 &self,
224 message: &[u8],
225 cert: &CertificateDer<'_>,
226 dss: &rustls::DigitallySignedStruct,
227 ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
228 rustls::crypto::verify_tls13_signature(message, cert, dss, &self.0.signature_verification_algorithms)
229 }
230
231 fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
232 self.0.signature_verification_algorithms.supported_schemes()
233 }
234}
235
/// Certificate verifier that pins the server's end-entity certificate to a SHA-256 fingerprint.
#[derive(Debug)]
struct FingerprintVerifier {
    /// Crypto provider used for hashing and for handshake signature verification.
    provider: crypto::Provider,
    /// Expected `crypto::sha256` digest of the server's DER-encoded end-entity certificate.
    fingerprint: Vec<u8>,
}
242
243impl FingerprintVerifier {
244 pub fn new(provider: crypto::Provider, fingerprint: Vec<u8>) -> Self {
245 Self { provider, fingerprint }
246 }
247}
248
249impl rustls::client::danger::ServerCertVerifier for FingerprintVerifier {
250 fn verify_server_cert(
251 &self,
252 end_entity: &CertificateDer<'_>,
253 _intermediates: &[CertificateDer<'_>],
254 _server_name: &ServerName<'_>,
255 _ocsp: &[u8],
256 _now: UnixTime,
257 ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
258 let fingerprint = crypto::sha256(&self.provider, end_entity);
259 if fingerprint.as_ref() == self.fingerprint.as_slice() {
260 Ok(rustls::client::danger::ServerCertVerified::assertion())
261 } else {
262 Err(rustls::Error::General("fingerprint mismatch".into()))
263 }
264 }
265
266 fn verify_tls12_signature(
267 &self,
268 message: &[u8],
269 cert: &CertificateDer<'_>,
270 dss: &rustls::DigitallySignedStruct,
271 ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
272 rustls::crypto::verify_tls12_signature(message, cert, dss, &self.provider.signature_verification_algorithms)
273 }
274
275 fn verify_tls13_signature(
276 &self,
277 message: &[u8],
278 cert: &CertificateDer<'_>,
279 dss: &rustls::DigitallySignedStruct,
280 ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
281 rustls::crypto::verify_tls13_signature(message, cert, dss, &self.provider.signature_verification_algorithms)
282 }
283
284 fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
285 self.provider.signature_verification_algorithms.supported_schemes()
286 }
287}