1use crate::crypto;
2use anyhow::Context;
3use rustls::pki_types::{CertificateDer, ServerName, UnixTime};
4use rustls::RootCertStore;
5use std::collections::HashSet;
6use std::path::PathBuf;
7use std::sync::{LazyLock, Mutex};
8use std::{fs, io, net, sync::Arc, time};
9use url::Url;
10#[cfg(feature = "iroh")]
11use web_transport_iroh::iroh;
12use web_transport_ws::{tokio_tungstenite, tungstenite};
13
/// Process-wide set of `(host, port)` pairs for which a WebSocket connection has succeeded.
///
/// `Client::connect_websocket` inserts into this set after connecting, and skips its
/// start-up delay for any key already present.
static WEBSOCKET_WON: LazyLock<Mutex<HashSet<(String, u16)>>> =
    LazyLock::new(|| Mutex::new(HashSet::default()));
16
// TLS settings for outgoing client connections.
#[derive(Clone, Default, Debug, clap::Args, serde::Serialize, serde::Deserialize)]
#[serde(default, deny_unknown_fields)]
pub struct ClientTls {
    /// Path to a PEM-encoded root certificate to trust.
    ///
    /// May be given multiple times. When none are given, the platform's native roots are used instead.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    #[arg(id = "tls-root", long = "tls-root", env = "MOQ_CLIENT_TLS_ROOT")]
    pub root: Vec<PathBuf>,

    /// Disable verification of the server's TLS certificate.
    ///
    /// This makes man-in-the-middle attacks possible.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[arg(
        id = "tls-disable-verify",
        long = "tls-disable-verify",
        env = "MOQ_CLIENT_TLS_DISABLE_VERIFY",
        action = clap::ArgAction::SetTrue
    )]
    pub disable_verify: Option<bool>,
}
41
42#[derive(Clone, Default, Debug, clap::Args, serde::Serialize, serde::Deserialize)]
44#[serde(default, deny_unknown_fields)]
45pub struct ClientWebSocket {
46 #[arg(
49 id = "websocket-delay",
50 long = "websocket-delay",
51 env = "MOQ_CLIENT_WEBSOCKET_DELAY",
52 default_value = "200ms",
53 value_parser = humantime::parse_duration,
54 )]
55 #[serde(with = "humantime_serde")]
56 #[serde(skip_serializing_if = "Option::is_none")]
57 pub delay: Option<time::Duration>,
58}
59
// Client configuration, parsed from CLI/environment via clap or deserialized via serde.
#[derive(Clone, Debug, clap::Parser, serde::Serialize, serde::Deserialize)]
#[serde(deny_unknown_fields, default)]
pub struct ClientConfig {
    /// Local address to bind the client's UDP socket to.
    #[arg(
        id = "client-bind",
        long = "client-bind",
        default_value = "[::]:0",
        env = "MOQ_CLIENT_BIND"
    )]
    pub bind: net::SocketAddr,

    // TLS settings; their arguments are flattened into this command.
    #[command(flatten)]
    #[serde(default)]
    pub tls: ClientTls,

    // WebSocket fallback settings; their arguments are flattened into this command.
    #[command(flatten)]
    #[serde(default)]
    pub websocket: ClientWebSocket,
}
81
82impl ClientConfig {
83 pub fn init(self) -> anyhow::Result<Client> {
84 Client::new(self)
85 }
86}
87
88impl Default for ClientConfig {
89 fn default() -> Self {
90 Self {
91 bind: "[::]:0".parse().unwrap(),
92 tls: ClientTls::default(),
93 websocket: ClientWebSocket::default(),
94 }
95 }
96}
97
/// A configured MoQ client.
///
/// Holds a QUIC endpoint and the TLS/transport settings applied to each outgoing connection.
/// It also carries an optional WebSocket fallback delay and, with the `iroh` feature, an iroh endpoint.
#[derive(Clone)]
pub struct Client {
    /// QUIC endpoint bound to the configured local UDP address.
    pub quic: quinn::Endpoint,
    /// Base rustls configuration; each connection clones it and sets ALPN (and possibly a verifier).
    pub tls: rustls::ClientConfig,
    /// QUIC transport parameters shared by every connection.
    pub transport: Arc<quinn::TransportConfig>,
    /// Delay before the WebSocket fallback attempt starts; `None` means no delay.
    pub websocket_delay: Option<time::Duration>,
    /// iroh endpoint used for iroh URLs; `None` until provided via `with_iroh`.
    #[cfg(feature = "iroh")]
    pub iroh: Option<iroh::Endpoint>,
}
110
111impl Client {
112 pub fn new(config: ClientConfig) -> anyhow::Result<Self> {
113 let provider = crypto::provider();
114
115 let mut roots = RootCertStore::empty();
117
118 if config.tls.root.is_empty() {
119 let native = rustls_native_certs::load_native_certs();
120
121 for err in native.errors {
123 tracing::warn!(%err, "failed to load root cert");
124 }
125
126 for cert in native.certs {
128 roots.add(cert).context("failed to add root cert")?;
129 }
130 } else {
131 for root in &config.tls.root {
133 let root = fs::File::open(root).context("failed to open root cert file")?;
134 let mut root = io::BufReader::new(root);
135
136 let root = rustls_pemfile::certs(&mut root)
137 .next()
138 .context("no roots found")?
139 .context("failed to read root cert")?;
140
141 roots.add(root).context("failed to add root cert")?;
142 }
143 }
144
145 let mut tls = rustls::ClientConfig::builder_with_provider(provider.clone())
147 .with_protocol_versions(&[&rustls::version::TLS13])?
148 .with_root_certificates(roots)
149 .with_no_client_auth();
150
151 if config.tls.disable_verify.unwrap_or_default() {
153 tracing::warn!("TLS server certificate verification is disabled; A man-in-the-middle attack is possible.");
154
155 let noop = NoCertificateVerification(provider.clone());
156 tls.dangerous().set_certificate_verifier(Arc::new(noop));
157 }
158
159 let socket = std::net::UdpSocket::bind(config.bind).context("failed to bind UDP socket")?;
160
161 let mut transport = quinn::TransportConfig::default();
163 transport.max_idle_timeout(Some(time::Duration::from_secs(10).try_into().unwrap()));
164 transport.keep_alive_interval(Some(time::Duration::from_secs(4)));
165 transport.mtu_discovery_config(None); let transport = Arc::new(transport);
168
169 let runtime = quinn::default_runtime().context("no async runtime")?;
171 let endpoint_config = quinn::EndpointConfig::default();
172
173 let quic =
175 quinn::Endpoint::new(endpoint_config, None, socket, runtime).context("failed to create QUIC endpoint")?;
176
177 Ok(Self {
178 quic,
179 tls,
180 transport,
181 websocket_delay: config.websocket.delay,
182 #[cfg(feature = "iroh")]
183 iroh: None,
184 })
185 }
186
187 #[cfg(feature = "iroh")]
188 pub fn with_iroh(&mut self, iroh: Option<iroh::Endpoint>) -> &mut Self {
189 self.iroh = iroh;
190 self
191 }
192
193 pub async fn connect(
195 &self,
196 url: Url,
197 publish: impl Into<Option<moq_lite::OriginConsumer>>,
198 subscribe: impl Into<Option<moq_lite::OriginProducer>>,
199 ) -> anyhow::Result<moq_lite::Session> {
200 #[cfg(feature = "iroh")]
201 if crate::iroh::is_iroh_url(&url) {
202 let session = self.connect_iroh(url).await?;
203 let session = moq_lite::Session::connect(session, publish, subscribe).await?;
204 return Ok(session);
205 }
206
207 let session = self.connect_quic(url).await?;
208 let session = moq_lite::Session::connect(session, publish, subscribe).await?;
209 Ok(session)
210 }
211
212 pub async fn connect_with_fallback(
216 &self,
217 url: Url,
218 publish: impl Into<Option<moq_lite::OriginConsumer>>,
219 subscribe: impl Into<Option<moq_lite::OriginProducer>>,
220 ) -> anyhow::Result<moq_lite::Session> {
221 #[cfg(feature = "iroh")]
222 if crate::iroh::is_iroh_url(&url) {
223 let session = self.connect_iroh(url).await?;
224 let session = moq_lite::Session::connect(session, publish, subscribe).await?;
225 return Ok(session);
226 }
227
228 let quic_url = url.clone();
230 let quic_handle = async {
231 match self.connect_quic(quic_url).await {
232 Ok(session) => Some(session),
233 Err(err) => {
234 tracing::warn!(%err, "QUIC connection failed");
235 None
236 }
237 }
238 };
239
240 let ws_handle = async {
241 match self.connect_websocket(url).await {
242 Ok(session) => Some(session),
243 Err(err) => {
244 tracing::warn!(%err, "WebSocket connection failed");
245 None
246 }
247 }
248 };
249
250 Ok(tokio::select! {
252 Some(quic) = quic_handle => moq_lite::Session::connect(quic, publish, subscribe).await?,
253 Some(ws) = ws_handle => moq_lite::Session::connect(ws, publish, subscribe).await?,
254 else => anyhow::bail!("failed to connect to server"),
256 })
257 }
258
259 async fn connect_quic(&self, mut url: Url) -> anyhow::Result<web_transport_quinn::Session> {
260 let mut config = self.tls.clone();
261
262 let host = url.host().context("invalid DNS name")?.to_string();
263 let port = url.port().unwrap_or(443);
264
265 let ip = tokio::net::lookup_host((host.clone(), port))
267 .await
268 .context("failed DNS lookup")?
269 .next()
270 .context("no DNS entries")?;
271
272 if url.scheme() == "http" {
273 let mut fingerprint = url.clone();
275 fingerprint.set_path("/certificate.sha256");
276 fingerprint.set_query(None);
277 fingerprint.set_fragment(None);
278
279 tracing::warn!(url = %fingerprint, "performing insecure HTTP request for certificate");
280
281 let resp = reqwest::get(fingerprint.as_str())
282 .await
283 .context("failed to fetch fingerprint")?
284 .error_for_status()
285 .context("fingerprint request failed")?;
286
287 let fingerprint = resp.text().await.context("failed to read fingerprint")?;
288 let fingerprint = hex::decode(fingerprint.trim()).context("invalid fingerprint")?;
289
290 let verifier = FingerprintVerifier::new(config.crypto_provider().clone(), fingerprint);
291 config.dangerous().set_certificate_verifier(Arc::new(verifier));
292
293 url.set_scheme("https").expect("failed to set scheme");
294 }
295
296 let alpn = match url.scheme() {
297 "https" => web_transport_quinn::ALPN,
298 "moql" => moq_lite::lite::ALPN,
299 "moqt" => moq_lite::ietf::ALPN,
300 _ => anyhow::bail!("url scheme must be 'http', 'https', or 'moql'"),
301 };
302
303 config.alpn_protocols = vec![alpn.as_bytes().to_vec()];
305 config.key_log = Arc::new(rustls::KeyLogFile::new());
306
307 let config: quinn::crypto::rustls::QuicClientConfig = config.try_into()?;
308 let mut config = quinn::ClientConfig::new(Arc::new(config));
309 config.transport_config(self.transport.clone());
310
311 tracing::debug!(%url, %ip, %alpn, "connecting");
312
313 let connection = self.quic.connect_with(config, ip, &host)?.await?;
314 tracing::Span::current().record("id", connection.stable_id());
315
316 let session = match alpn {
317 web_transport_quinn::ALPN => web_transport_quinn::Session::connect(connection, url).await?,
318 moq_lite::lite::ALPN | moq_lite::ietf::ALPN => web_transport_quinn::Session::raw(connection, url),
319 _ => unreachable!("ALPN was checked above"),
320 };
321
322 Ok(session)
323 }
324
325 async fn connect_websocket(&self, mut url: Url) -> anyhow::Result<web_transport_ws::Session> {
326 let host = url.host_str().context("missing hostname")?.to_string();
327 let port = url.port().unwrap_or_else(|| match url.scheme() {
328 "https" | "wss" | "moql" | "moqt" => 443,
329 "http" | "ws" => 80,
330 _ => 443,
331 });
332 let key = (host, port);
333
334 match self.websocket_delay {
338 Some(delay) if !WEBSOCKET_WON.lock().unwrap().contains(&key) => {
339 tokio::time::sleep(delay).await;
340 tracing::debug!(%url, delay_ms = %delay.as_millis(), "QUIC not yet connected, attempting WebSocket fallback");
341 }
342 _ => {}
343 }
344
345 match url.scheme() {
347 "http" => {
348 url.set_scheme("ws").expect("failed to set scheme");
349 }
350 "https" | "moql" | "moqt" => {
351 url.set_scheme("wss").expect("failed to set scheme");
352 }
353 "ws" | "wss" => {}
354 _ => anyhow::bail!("unsupported URL scheme for WebSocket: {}", url.scheme()),
355 };
356
357 tracing::debug!(%url, "connecting via WebSocket");
358
359 let (ws_stream, _response) = tokio_tungstenite::connect_async_with_config(
361 url.as_str(),
362 Some(tungstenite::protocol::WebSocketConfig {
363 max_message_size: Some(64 << 20), max_frame_size: Some(16 << 20), accept_unmasked_frames: false,
366 ..Default::default()
367 }),
368 false, )
370 .await
371 .context("failed to connect WebSocket")?;
372
373 let session = web_transport_ws::Session::new(ws_stream, false);
376
377 tracing::warn!(%url, "using WebSocket fallback");
378 WEBSOCKET_WON.lock().unwrap().insert(key);
379
380 Ok(session)
381 }
382
383 #[cfg(feature = "iroh")]
384 async fn connect_iroh(&self, url: Url) -> anyhow::Result<web_transport_iroh::Session> {
385 let endpoint = self.iroh.as_ref().context("Iroh support is not enabled")?;
386 let alpn = match url.scheme() {
387 "moql+iroh" | "iroh" => moq_lite::lite::ALPN,
388 "moqt+iroh" => moq_lite::ietf::ALPN,
389 "h3+iroh" => web_transport_iroh::ALPN_H3,
390 _ => anyhow::bail!("Invalid URL: unknown scheme"),
391 };
392 let host = url.host().context("Invalid URL: missing host")?.to_string();
393 let endpoint_id: iroh::EndpointId = host.parse().context("Invalid URL: host is not an iroh endpoint id")?;
394 let conn = endpoint.connect(endpoint_id, alpn.as_bytes()).await?;
395 let session = match alpn {
396 web_transport_iroh::ALPN_H3 => {
397 let url = url_set_scheme(url, "https")?;
400 web_transport_iroh::Session::connect_h3(conn, url).await?
401 }
402 _ => web_transport_iroh::Session::raw(conn),
403 };
404 Ok(session)
405 }
406}
407
/// Server certificate verifier that accepts every certificate.
///
/// Installed by `Client::new` when `disable_verify` is set. TLS handshake signatures are
/// still checked using the wrapped crypto provider's signature algorithms.
#[derive(Debug)]
struct NoCertificateVerification(crypto::Provider);
410
411impl rustls::client::danger::ServerCertVerifier for NoCertificateVerification {
412 fn verify_server_cert(
413 &self,
414 _end_entity: &CertificateDer<'_>,
415 _intermediates: &[CertificateDer<'_>],
416 _server_name: &ServerName<'_>,
417 _ocsp: &[u8],
418 _now: UnixTime,
419 ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
420 Ok(rustls::client::danger::ServerCertVerified::assertion())
421 }
422
423 fn verify_tls12_signature(
424 &self,
425 message: &[u8],
426 cert: &CertificateDer<'_>,
427 dss: &rustls::DigitallySignedStruct,
428 ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
429 rustls::crypto::verify_tls12_signature(message, cert, dss, &self.0.signature_verification_algorithms)
430 }
431
432 fn verify_tls13_signature(
433 &self,
434 message: &[u8],
435 cert: &CertificateDer<'_>,
436 dss: &rustls::DigitallySignedStruct,
437 ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
438 rustls::crypto::verify_tls13_signature(message, cert, dss, &self.0.signature_verification_algorithms)
439 }
440
441 fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
442 self.0.signature_verification_algorithms.supported_schemes()
443 }
444}
445
/// Server certificate verifier that pins the end-entity certificate by digest.
///
/// It accepts only a certificate whose `crypto::sha256` digest equals `fingerprint`.
/// Intermediates, server name and validity time are not checked.
#[derive(Debug)]
struct FingerprintVerifier {
    // Crypto provider used for hashing and for handshake signature verification.
    provider: crypto::Provider,
    // Expected digest of the server's end-entity certificate, as raw bytes.
    fingerprint: Vec<u8>,
}
452
453impl FingerprintVerifier {
454 pub fn new(provider: crypto::Provider, fingerprint: Vec<u8>) -> Self {
455 Self { provider, fingerprint }
456 }
457}
458
459impl rustls::client::danger::ServerCertVerifier for FingerprintVerifier {
460 fn verify_server_cert(
461 &self,
462 end_entity: &CertificateDer<'_>,
463 _intermediates: &[CertificateDer<'_>],
464 _server_name: &ServerName<'_>,
465 _ocsp: &[u8],
466 _now: UnixTime,
467 ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
468 let fingerprint = crypto::sha256(&self.provider, end_entity);
469 if fingerprint.as_ref() == self.fingerprint.as_slice() {
470 Ok(rustls::client::danger::ServerCertVerified::assertion())
471 } else {
472 Err(rustls::Error::General("fingerprint mismatch".into()))
473 }
474 }
475
476 fn verify_tls12_signature(
477 &self,
478 message: &[u8],
479 cert: &CertificateDer<'_>,
480 dss: &rustls::DigitallySignedStruct,
481 ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
482 rustls::crypto::verify_tls12_signature(message, cert, dss, &self.provider.signature_verification_algorithms)
483 }
484
485 fn verify_tls13_signature(
486 &self,
487 message: &[u8],
488 cert: &CertificateDer<'_>,
489 dss: &rustls::DigitallySignedStruct,
490 ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
491 rustls::crypto::verify_tls13_signature(message, cert, dss, &self.provider.signature_verification_algorithms)
492 }
493
494 fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
495 self.provider.signature_verification_algorithms.supported_schemes()
496 }
497}
498
/// Returns `url` with its scheme replaced by `scheme`.
///
/// `Url::set_scheme` refuses to switch between "special" schemes (such as `https`) and
/// non-special ones (such as `h3+iroh`). This function therefore rewrites the serialized URL
/// and parses it again.
///
/// # Errors
/// Fails if the URL has no `:` separator or the rewritten string is not a valid URL.
#[cfg(feature = "iroh")]
fn url_set_scheme(url: Url, scheme: &str) -> anyhow::Result<Url> {
    // Borrow the existing serialization instead of allocating a copy with `to_string()`.
    let (_, rest) = url.as_str().split_once(':').context("invalid URL")?;
    let rewritten = format!("{scheme}:{rest}");
    rewritten
        .parse::<Url>()
        .with_context(|| format!("invalid URL after setting scheme: {rewritten}"))
}