1use crate::crypto;
2use anyhow::Context;
3use rustls::RootCertStore;
4use rustls::pki_types::{CertificateDer, ServerName, UnixTime};
5use std::collections::HashSet;
6use std::path::PathBuf;
7use std::sync::{LazyLock, Mutex};
8use std::{fs, io, net, sync::Arc, time};
9use url::Url;
10#[cfg(feature = "iroh")]
11use web_transport_iroh::iroh;
12use web_transport_ws::{tokio_tungstenite, tungstenite};
13
/// Process-wide set of `(host, port)` pairs for which a WebSocket connection has already
/// succeeded. WebSocket attempts to a pair in this set skip the head-start delay normally
/// given to QUIC.
static WEBSOCKET_WON: LazyLock<Mutex<HashSet<(String, u16)>>> = LazyLock::new(Default::default);
16
// TLS settings for outgoing client connections.
// Populated from CLI flags / environment variables (clap) or a config file (serde).
#[derive(Clone, Default, Debug, clap::Args, serde::Serialize, serde::Deserialize)]
#[serde(default, deny_unknown_fields)]
#[non_exhaustive]
pub struct ClientTls {
    // PEM files containing trusted root certificates.
    // When empty, `Client::new` loads the platform's native root store instead.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    #[arg(id = "tls-root", long = "tls-root", env = "MOQ_CLIENT_TLS_ROOT")]
    pub root: Vec<PathBuf>,

    // Disable server certificate verification (insecure; allows man-in-the-middle attacks).
    // `None` when not specified; `Client::new` treats `None` as `false`.
    // A bare `--tls-disable-verify` means `true`; an explicit `true`/`false` value is also accepted.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[arg(
        id = "tls-disable-verify",
        long = "tls-disable-verify",
        env = "MOQ_CLIENT_TLS_DISABLE_VERIFY",
        default_missing_value = "true",
        num_args = 0..=1,
        value_parser = clap::value_parser!(bool),
    )]
    pub disable_verify: Option<bool>,
}
44
45#[derive(Clone, Debug, clap::Args, serde::Serialize, serde::Deserialize)]
47#[serde(default, deny_unknown_fields)]
48#[non_exhaustive]
49pub struct ClientWebSocket {
50 #[arg(
52 id = "websocket-enabled",
53 long = "websocket-enabled",
54 env = "MOQ_CLIENT_WEBSOCKET_ENABLED",
55 default_value = "true"
56 )]
57 pub enabled: bool,
58
59 #[arg(
62 id = "websocket-delay",
63 long = "websocket-delay",
64 env = "MOQ_CLIENT_WEBSOCKET_DELAY",
65 default_value = "200ms",
66 value_parser = humantime::parse_duration,
67 )]
68 #[serde(with = "humantime_serde")]
69 #[serde(skip_serializing_if = "Option::is_none")]
70 pub delay: Option<time::Duration>,
71}
72
73impl Default for ClientWebSocket {
74 fn default() -> Self {
75 Self {
76 enabled: true,
77 delay: Some(time::Duration::from_millis(200)),
78 }
79 }
80}
81
// Top-level client configuration.
// Parsed from the command line / environment (clap) or deserialized from a config
// file such as TOML (serde).
#[derive(Clone, Debug, clap::Parser, serde::Serialize, serde::Deserialize)]
#[serde(deny_unknown_fields, default)]
#[non_exhaustive]
pub struct ClientConfig {
    // Local UDP address the QUIC endpoint binds to.
    // The default `[::]:0` is an ephemeral port on the IPv6 unspecified address.
    #[arg(
        id = "client-bind",
        long = "client-bind",
        default_value = "[::]:0",
        env = "MOQ_CLIENT_BIND"
    )]
    pub bind: net::SocketAddr,

    // TLS settings: root certificates and verification.
    #[command(flatten)]
    #[serde(default)]
    pub tls: ClientTls,

    // WebSocket fallback settings.
    #[command(flatten)]
    #[serde(default)]
    pub websocket: ClientWebSocket,
}
104
105impl ClientConfig {
106 pub fn init(self) -> anyhow::Result<Client> {
107 Client::new(self)
108 }
109}
110
111impl Default for ClientConfig {
112 fn default() -> Self {
113 Self {
114 bind: "[::]:0".parse().unwrap(),
115 tls: ClientTls::default(),
116 websocket: ClientWebSocket::default(),
117 }
118 }
119}
120
/// A MoQ client able to connect over QUIC, WebSocket, or (with the `iroh` feature) iroh.
///
/// Built with [`Client::new`] / [`ClientConfig::init`]; call `connect` to open a session.
#[derive(Clone)]
#[non_exhaustive]
pub struct Client {
    /// MoQ session layer; runs the handshake over whichever transport connected.
    pub moq: moq_lite::Client,
    /// QUIC endpoint bound in `Client::new` and used for every QUIC connection.
    pub quic: quinn::Endpoint,
    /// Base rustls configuration; cloned and adjusted (ALPN, verifier) per connection.
    pub tls: rustls::ClientConfig,
    /// QUIC transport settings applied to each outgoing QUIC connection.
    pub transport: Arc<quinn::TransportConfig>,
    /// WebSocket fallback settings.
    pub websocket: ClientWebSocket,
    /// Optional iroh endpoint used for iroh URLs; `None` unless set via `with_iroh`.
    #[cfg(feature = "iroh")]
    pub iroh: Option<iroh::Endpoint>,
}
135
136impl Client {
137 pub fn new(config: ClientConfig) -> anyhow::Result<Self> {
138 let provider = crypto::provider();
139
140 let mut roots = RootCertStore::empty();
142
143 if config.tls.root.is_empty() {
144 let native = rustls_native_certs::load_native_certs();
145
146 for err in native.errors {
148 tracing::warn!(%err, "failed to load root cert");
149 }
150
151 for cert in native.certs {
153 roots.add(cert).context("failed to add root cert")?;
154 }
155 } else {
156 for root in &config.tls.root {
158 let root = fs::File::open(root).context("failed to open root cert file")?;
159 let mut root = io::BufReader::new(root);
160
161 let root = rustls_pemfile::certs(&mut root)
162 .next()
163 .context("no roots found")?
164 .context("failed to read root cert")?;
165
166 roots.add(root).context("failed to add root cert")?;
167 }
168 }
169
170 let mut tls = rustls::ClientConfig::builder_with_provider(provider.clone())
172 .with_protocol_versions(&[&rustls::version::TLS13])?
173 .with_root_certificates(roots)
174 .with_no_client_auth();
175
176 if config.tls.disable_verify.unwrap_or_default() {
178 tracing::warn!("TLS server certificate verification is disabled; A man-in-the-middle attack is possible.");
179
180 let noop = NoCertificateVerification(provider.clone());
181 tls.dangerous().set_certificate_verifier(Arc::new(noop));
182 }
183
184 let socket = std::net::UdpSocket::bind(config.bind).context("failed to bind UDP socket")?;
185
186 let mut transport = quinn::TransportConfig::default();
188 transport.max_idle_timeout(Some(time::Duration::from_secs(10).try_into().unwrap()));
189 transport.keep_alive_interval(Some(time::Duration::from_secs(4)));
190 transport.mtu_discovery_config(None); let transport = Arc::new(transport);
193
194 let runtime = quinn::default_runtime().context("no async runtime")?;
196 let endpoint_config = quinn::EndpointConfig::default();
197
198 let quic =
200 quinn::Endpoint::new(endpoint_config, None, socket, runtime).context("failed to create QUIC endpoint")?;
201
202 Ok(Self {
203 moq: moq_lite::Client::new(),
204 quic,
205 tls,
206 transport,
207 websocket: config.websocket,
208 #[cfg(feature = "iroh")]
209 iroh: None,
210 })
211 }
212
213 #[cfg(feature = "iroh")]
214 pub fn with_iroh(mut self, iroh: Option<iroh::Endpoint>) -> Self {
215 self.iroh = iroh;
216 self
217 }
218
219 pub fn with_publish(mut self, publish: impl Into<Option<moq_lite::OriginConsumer>>) -> Self {
220 self.moq = self.moq.with_publish(publish);
221 self
222 }
223
224 pub fn with_consume(mut self, consume: impl Into<Option<moq_lite::OriginProducer>>) -> Self {
225 self.moq = self.moq.with_consume(consume);
226 self
227 }
228
229 pub async fn connect(&self, url: Url) -> anyhow::Result<moq_lite::Session> {
237 #[cfg(feature = "iroh")]
238 if crate::iroh::is_iroh_url(&url) {
239 let session = self.connect_iroh(url).await?;
240 let session = self.moq.connect(session).await?;
241 return Ok(session);
242 }
243
244 let quic_url = url.clone();
246 let quic_handle = async {
247 let res = self.connect_quic(quic_url).await;
248 if let Err(err) = &res {
249 tracing::warn!(%err, "QUIC connection failed");
250 }
251 res
252 };
253
254 let ws_handle = async {
255 if !self.websocket.enabled {
256 return None;
257 }
258
259 let res = self.connect_websocket(url).await;
260 if let Err(err) = &res {
261 tracing::warn!(%err, "WebSocket connection failed");
262 }
263 Some(res)
264 };
265
266 Ok(tokio::select! {
268 Ok(quic) = quic_handle => self.moq.connect(quic).await?,
269 Some(Ok(ws)) = ws_handle => self.moq.connect(ws).await?,
270 else => anyhow::bail!("failed to connect to server"),
272 })
273 }
274
275 async fn connect_quic(&self, mut url: Url) -> anyhow::Result<web_transport_quinn::Session> {
276 let mut config = self.tls.clone();
277
278 let host = url.host().context("invalid DNS name")?.to_string();
279 let port = url.port().unwrap_or(443);
280
281 let ip = tokio::net::lookup_host((host.clone(), port))
283 .await
284 .context("failed DNS lookup")?
285 .next()
286 .context("no DNS entries")?;
287
288 if url.scheme() == "http" {
289 let mut fingerprint = url.clone();
291 fingerprint.set_path("/certificate.sha256");
292 fingerprint.set_query(None);
293 fingerprint.set_fragment(None);
294
295 tracing::warn!(url = %fingerprint, "performing insecure HTTP request for certificate");
296
297 let resp = reqwest::get(fingerprint.as_str())
298 .await
299 .context("failed to fetch fingerprint")?
300 .error_for_status()
301 .context("fingerprint request failed")?;
302
303 let fingerprint = resp.text().await.context("failed to read fingerprint")?;
304 let fingerprint = hex::decode(fingerprint.trim()).context("invalid fingerprint")?;
305
306 let verifier = FingerprintVerifier::new(config.crypto_provider().clone(), fingerprint);
307 config.dangerous().set_certificate_verifier(Arc::new(verifier));
308
309 url.set_scheme("https").expect("failed to set scheme");
310 }
311
312 let alpn = match url.scheme() {
313 "https" => web_transport_quinn::ALPN,
314 "moql" => moq_lite::lite::ALPN,
315 "moqt" => moq_lite::ietf::ALPN,
316 _ => anyhow::bail!("url scheme must be 'http', 'https', or 'moql'"),
317 };
318
319 config.alpn_protocols = vec![alpn.as_bytes().to_vec()];
321 config.key_log = Arc::new(rustls::KeyLogFile::new());
322
323 let config: quinn::crypto::rustls::QuicClientConfig = config.try_into()?;
324 let mut config = quinn::ClientConfig::new(Arc::new(config));
325 config.transport_config(self.transport.clone());
326
327 tracing::debug!(%url, %ip, %alpn, "connecting");
328
329 let connection = self.quic.connect_with(config, ip, &host)?.await?;
330 tracing::Span::current().record("id", connection.stable_id());
331
332 let session = match alpn {
333 web_transport_quinn::ALPN => web_transport_quinn::Session::connect(connection, url).await?,
334 moq_lite::lite::ALPN | moq_lite::ietf::ALPN => web_transport_quinn::Session::raw(connection, url),
335 _ => unreachable!("ALPN was checked above"),
336 };
337
338 Ok(session)
339 }
340
341 async fn connect_websocket(&self, mut url: Url) -> anyhow::Result<web_transport_ws::Session> {
342 anyhow::ensure!(self.websocket.enabled, "WebSocket support is disabled");
343
344 let host = url.host_str().context("missing hostname")?.to_string();
345 let port = url.port().unwrap_or_else(|| match url.scheme() {
346 "https" | "wss" | "moql" | "moqt" => 443,
347 "http" | "ws" => 80,
348 _ => 443,
349 });
350 let key = (host, port);
351
352 match self.websocket.delay {
356 Some(delay) if !WEBSOCKET_WON.lock().unwrap().contains(&key) => {
357 tokio::time::sleep(delay).await;
358 tracing::debug!(%url, delay_ms = %delay.as_millis(), "QUIC not yet connected, attempting WebSocket fallback");
359 }
360 _ => {}
361 }
362
363 let needs_tls = match url.scheme() {
365 "http" => {
366 url.set_scheme("ws").expect("failed to set scheme");
367 false
368 }
369 "https" | "moql" | "moqt" => {
370 url.set_scheme("wss").expect("failed to set scheme");
371 true
372 }
373 "ws" => false,
374 "wss" => true,
375 _ => anyhow::bail!("unsupported URL scheme for WebSocket: {}", url.scheme()),
376 };
377
378 tracing::debug!(%url, "connecting via WebSocket");
379
380 let connector = if needs_tls {
382 Some(tokio_tungstenite::Connector::Rustls(Arc::new(self.tls.clone())))
383 } else {
384 None
385 };
386
387 let (ws_stream, _response) = tokio_tungstenite::connect_async_tls_with_config(
389 url.as_str(),
390 Some(tungstenite::protocol::WebSocketConfig {
391 max_message_size: Some(64 << 20), max_frame_size: Some(16 << 20), accept_unmasked_frames: false,
394 ..Default::default()
395 }),
396 false, connector,
398 )
399 .await
400 .context("failed to connect WebSocket")?;
401
402 let session = web_transport_ws::Session::new(ws_stream, false);
405
406 tracing::warn!(%url, "using WebSocket fallback");
407 WEBSOCKET_WON.lock().unwrap().insert(key);
408
409 Ok(session)
410 }
411
412 #[cfg(feature = "iroh")]
413 async fn connect_iroh(&self, url: Url) -> anyhow::Result<web_transport_iroh::Session> {
414 let endpoint = self.iroh.as_ref().context("Iroh support is not enabled")?;
415 let alpn = match url.scheme() {
416 "moql+iroh" | "iroh" => moq_lite::lite::ALPN,
417 "moqt+iroh" => moq_lite::ietf::ALPN,
418 "h3+iroh" => web_transport_iroh::ALPN_H3,
419 _ => anyhow::bail!("Invalid URL: unknown scheme"),
420 };
421 let host = url.host().context("Invalid URL: missing host")?.to_string();
422 let endpoint_id: iroh::EndpointId = host.parse().context("Invalid URL: host is not an iroh endpoint id")?;
423 let conn = endpoint.connect(endpoint_id, alpn.as_bytes()).await?;
424 let session = match alpn {
425 web_transport_iroh::ALPN_H3 => {
426 let url = url_set_scheme(url, "https")?;
429 web_transport_iroh::Session::connect_h3(conn, url).await?
430 }
431 _ => web_transport_iroh::Session::raw(conn),
432 };
433 Ok(session)
434 }
435}
436
/// Certificate verifier that accepts any server certificate without checking it.
///
/// Wraps the crypto provider so handshake signatures can still be verified with the
/// provider's supported algorithms.
#[derive(Debug)]
struct NoCertificateVerification(crypto::Provider);
439
440impl rustls::client::danger::ServerCertVerifier for NoCertificateVerification {
441 fn verify_server_cert(
442 &self,
443 _end_entity: &CertificateDer<'_>,
444 _intermediates: &[CertificateDer<'_>],
445 _server_name: &ServerName<'_>,
446 _ocsp: &[u8],
447 _now: UnixTime,
448 ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
449 Ok(rustls::client::danger::ServerCertVerified::assertion())
450 }
451
452 fn verify_tls12_signature(
453 &self,
454 message: &[u8],
455 cert: &CertificateDer<'_>,
456 dss: &rustls::DigitallySignedStruct,
457 ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
458 rustls::crypto::verify_tls12_signature(message, cert, dss, &self.0.signature_verification_algorithms)
459 }
460
461 fn verify_tls13_signature(
462 &self,
463 message: &[u8],
464 cert: &CertificateDer<'_>,
465 dss: &rustls::DigitallySignedStruct,
466 ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
467 rustls::crypto::verify_tls13_signature(message, cert, dss, &self.0.signature_verification_algorithms)
468 }
469
470 fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
471 self.0.signature_verification_algorithms.supported_schemes()
472 }
473}
474
/// Certificate verifier that pins the server's end-entity certificate by its SHA-256 digest.
#[derive(Debug)]
struct FingerprintVerifier {
    /// Crypto provider used to hash the certificate and verify handshake signatures.
    provider: crypto::Provider,
    /// Expected SHA-256 digest of the server's end-entity certificate, as raw bytes.
    fingerprint: Vec<u8>,
}
481
482impl FingerprintVerifier {
483 pub fn new(provider: crypto::Provider, fingerprint: Vec<u8>) -> Self {
484 Self { provider, fingerprint }
485 }
486}
487
488impl rustls::client::danger::ServerCertVerifier for FingerprintVerifier {
489 fn verify_server_cert(
490 &self,
491 end_entity: &CertificateDer<'_>,
492 _intermediates: &[CertificateDer<'_>],
493 _server_name: &ServerName<'_>,
494 _ocsp: &[u8],
495 _now: UnixTime,
496 ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
497 let fingerprint = crypto::sha256(&self.provider, end_entity);
498 if fingerprint.as_ref() == self.fingerprint.as_slice() {
499 Ok(rustls::client::danger::ServerCertVerified::assertion())
500 } else {
501 Err(rustls::Error::General("fingerprint mismatch".into()))
502 }
503 }
504
505 fn verify_tls12_signature(
506 &self,
507 message: &[u8],
508 cert: &CertificateDer<'_>,
509 dss: &rustls::DigitallySignedStruct,
510 ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
511 rustls::crypto::verify_tls12_signature(message, cert, dss, &self.provider.signature_verification_algorithms)
512 }
513
514 fn verify_tls13_signature(
515 &self,
516 message: &[u8],
517 cert: &CertificateDer<'_>,
518 dss: &rustls::DigitallySignedStruct,
519 ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
520 rustls::crypto::verify_tls13_signature(message, cert, dss, &self.provider.signature_verification_algorithms)
521 }
522
523 fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
524 self.provider.signature_verification_algorithms.supported_schemes()
525 }
526}
527
/// Returns `url` with its scheme replaced by `scheme`.
///
/// The URL is rebuilt from its serialized form rather than changed via `Url::set_scheme`.
/// Per the `url` crate docs, `set_scheme` rejects switching between non-special schemes
/// (e.g. `h3+iroh`) and special ones such as `https`.
#[cfg(feature = "iroh")]
fn url_set_scheme(url: Url, scheme: &str) -> anyhow::Result<Url> {
    let (_, rest) = url.as_str().split_once(':').context("invalid URL")?;
    let rebuilt: Url = format!("{scheme}:{rest}").parse()?;
    Ok(rebuilt)
}
544
#[cfg(test)]
mod tests {
    use super::*;
    use clap::Parser;

    /// Parses `args` as a command line and returns the resulting `tls.disable_verify`.
    fn cli_disable_verify(args: &[&str]) -> Option<bool> {
        ClientConfig::parse_from(args.iter().copied()).tls.disable_verify
    }

    #[test]
    fn test_toml_disable_verify_survives_update_from() {
        // A value set in TOML must not be reset when CLI arguments are layered on top.
        let mut config: ClientConfig = toml::from_str("tls.disable_verify = true").unwrap();
        assert_eq!(config.tls.disable_verify, Some(true));

        config.update_from(["test"]);
        assert_eq!(config.tls.disable_verify, Some(true));
    }

    #[test]
    fn test_cli_disable_verify_flag() {
        assert_eq!(cli_disable_verify(&["test", "--tls-disable-verify"]), Some(true));
    }

    #[test]
    fn test_cli_disable_verify_explicit_false() {
        assert_eq!(cli_disable_verify(&["test", "--tls-disable-verify", "false"]), Some(false));
    }

    #[test]
    fn test_cli_no_disable_verify() {
        assert_eq!(cli_disable_verify(&["test"]), None);
    }
}