1use crate::crypto;
2use anyhow::Context;
3use rustls::RootCertStore;
4use rustls::pki_types::{CertificateDer, ServerName, UnixTime};
5use std::collections::HashSet;
6use std::path::PathBuf;
7use std::sync::{LazyLock, Mutex};
8use std::{fs, io, net, sync::Arc, time};
9use url::Url;
10#[cfg(feature = "iroh")]
11use web_transport_iroh::iroh;
12use web_transport_ws::{tokio_tungstenite, tungstenite};
13
/// `(host, port)` pairs for which a WebSocket connection has succeeded in this process.
///
/// While a pair is recorded, `Client::connect_websocket` skips the configured start-up delay,
/// so WebSocket races QUIC immediately for that server. Nothing in this module removes entries.
static WEBSOCKET_WON: LazyLock<Mutex<HashSet<(String, u16)>>> = LazyLock::new(Default::default);
16
// TLS settings the client uses to verify servers.
//
// Configurable through CLI flags / environment variables (clap) and serde.
#[derive(Clone, Default, Debug, clap::Args, serde::Serialize, serde::Deserialize)]
#[serde(default, deny_unknown_fields)]
#[non_exhaustive]
pub struct ClientTls {
    /// PEM file(s) containing trusted root certificates.
    ///
    /// May be given multiple times. When empty, the platform's native root store is used instead.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    #[arg(id = "tls-root", long = "tls-root", env = "MOQ_CLIENT_TLS_ROOT")]
    pub root: Vec<PathBuf>,

    /// Danger: disable verification of the server's TLS certificate.
    ///
    /// Any certificate is accepted, which allows man-in-the-middle attacks.
    // NOTE(review): with `ArgAction::SetTrue` clap presumably always yields `Some(..)` when
    // parsed from the CLI, so `None` likely only occurs for serde/Default construction — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[arg(
        id = "tls-disable-verify",
        long = "tls-disable-verify",
        env = "MOQ_CLIENT_TLS_DISABLE_VERIFY",
        action = clap::ArgAction::SetTrue
    )]
    pub disable_verify: Option<bool>,
}
42
// WebSocket fallback settings for the client.
//
// When enabled, `Client::connect` races a WebSocket connection against QUIC. The WebSocket
// attempt starts after `delay`, unless WebSocket already won for the same host and port.
#[derive(Clone, Debug, clap::Args, serde::Serialize, serde::Deserialize)]
#[serde(default, deny_unknown_fields)]
#[non_exhaustive]
pub struct ClientWebSocket {
    /// Whether to attempt a WebSocket connection alongside QUIC.
    // NOTE(review): a `bool` field is presumably a `SetTrue` flag in clap, so together with
    // `default_value = "true"` the CLI flag can likely only enable it; disabling probably
    // requires the env var or serde config — confirm.
    #[arg(
        id = "websocket-enabled",
        long = "websocket-enabled",
        env = "MOQ_CLIENT_WEBSOCKET_ENABLED",
        default_value = "true"
    )]
    pub enabled: bool,

    /// How long QUIC gets to connect before a WebSocket attempt is also started.
    ///
    /// `None` starts the WebSocket attempt immediately. The delay is also skipped once
    /// WebSocket has succeeded for the same host and port.
    // Keep the "200ms" CLI default in sync with `impl Default for ClientWebSocket`.
    #[arg(
        id = "websocket-delay",
        long = "websocket-delay",
        env = "MOQ_CLIENT_WEBSOCKET_DELAY",
        default_value = "200ms",
        value_parser = humantime::parse_duration,
    )]
    #[serde(with = "humantime_serde")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub delay: Option<time::Duration>,
}
70
71impl Default for ClientWebSocket {
72 fn default() -> Self {
73 Self {
74 enabled: true,
75 delay: Some(time::Duration::from_millis(200)),
76 }
77 }
78}
79
// Configuration used to build a `Client`.
//
// Parsable from CLI flags / environment variables (clap) and (de)serializable with serde.
#[derive(Clone, Debug, clap::Parser, serde::Serialize, serde::Deserialize)]
#[serde(deny_unknown_fields, default)]
#[non_exhaustive]
pub struct ClientConfig {
    /// Local address to bind the client's UDP (QUIC) socket to.
    #[arg(
        id = "client-bind",
        long = "client-bind",
        default_value = "[::]:0",
        env = "MOQ_CLIENT_BIND"
    )]
    pub bind: net::SocketAddr,

    // TLS verification settings; their flags are flattened into this command.
    #[command(flatten)]
    #[serde(default)]
    pub tls: ClientTls,

    // WebSocket fallback settings; their flags are flattened into this command.
    #[command(flatten)]
    #[serde(default)]
    pub websocket: ClientWebSocket,
}
102
103impl ClientConfig {
104 pub fn init(self) -> anyhow::Result<Client> {
105 Client::new(self)
106 }
107}
108
109impl Default for ClientConfig {
110 fn default() -> Self {
111 Self {
112 bind: "[::]:0".parse().unwrap(),
113 tls: ClientTls::default(),
114 websocket: ClientWebSocket::default(),
115 }
116 }
117}
118
/// A MoQ client that connects over QUIC (raw or WebTransport), WebSocket, and optionally iroh.
///
/// Built with [`Client::new`] or [`ClientConfig::init`].
#[derive(Clone)]
#[non_exhaustive]
pub struct Client {
    /// MoQ layer used to establish a session on top of whichever transport connects.
    pub moq: moq_lite::Client,
    /// QUIC endpoint bound to the configured local address.
    pub quic: quinn::Endpoint,
    /// Base TLS 1.3 configuration; cloned and customised for each connection attempt.
    pub tls: rustls::ClientConfig,
    /// QUIC transport parameters shared by every connection.
    pub transport: Arc<quinn::TransportConfig>,
    /// WebSocket fallback settings.
    pub websocket: ClientWebSocket,
    /// Endpoint used for iroh URLs; `None` unless set with [`Client::with_iroh`].
    #[cfg(feature = "iroh")]
    pub iroh: Option<iroh::Endpoint>,
}
133
134impl Client {
135 pub fn new(config: ClientConfig) -> anyhow::Result<Self> {
136 let provider = crypto::provider();
137
138 let mut roots = RootCertStore::empty();
140
141 if config.tls.root.is_empty() {
142 let native = rustls_native_certs::load_native_certs();
143
144 for err in native.errors {
146 tracing::warn!(%err, "failed to load root cert");
147 }
148
149 for cert in native.certs {
151 roots.add(cert).context("failed to add root cert")?;
152 }
153 } else {
154 for root in &config.tls.root {
156 let root = fs::File::open(root).context("failed to open root cert file")?;
157 let mut root = io::BufReader::new(root);
158
159 let root = rustls_pemfile::certs(&mut root)
160 .next()
161 .context("no roots found")?
162 .context("failed to read root cert")?;
163
164 roots.add(root).context("failed to add root cert")?;
165 }
166 }
167
168 let mut tls = rustls::ClientConfig::builder_with_provider(provider.clone())
170 .with_protocol_versions(&[&rustls::version::TLS13])?
171 .with_root_certificates(roots)
172 .with_no_client_auth();
173
174 if config.tls.disable_verify.unwrap_or_default() {
176 tracing::warn!("TLS server certificate verification is disabled; A man-in-the-middle attack is possible.");
177
178 let noop = NoCertificateVerification(provider.clone());
179 tls.dangerous().set_certificate_verifier(Arc::new(noop));
180 }
181
182 let socket = std::net::UdpSocket::bind(config.bind).context("failed to bind UDP socket")?;
183
184 let mut transport = quinn::TransportConfig::default();
186 transport.max_idle_timeout(Some(time::Duration::from_secs(10).try_into().unwrap()));
187 transport.keep_alive_interval(Some(time::Duration::from_secs(4)));
188 transport.mtu_discovery_config(None); let transport = Arc::new(transport);
191
192 let runtime = quinn::default_runtime().context("no async runtime")?;
194 let endpoint_config = quinn::EndpointConfig::default();
195
196 let quic =
198 quinn::Endpoint::new(endpoint_config, None, socket, runtime).context("failed to create QUIC endpoint")?;
199
200 Ok(Self {
201 moq: moq_lite::Client::new(),
202 quic,
203 tls,
204 transport,
205 websocket: config.websocket,
206 #[cfg(feature = "iroh")]
207 iroh: None,
208 })
209 }
210
211 #[cfg(feature = "iroh")]
212 pub fn with_iroh(mut self, iroh: Option<iroh::Endpoint>) -> Self {
213 self.iroh = iroh;
214 self
215 }
216
217 pub fn with_publish(mut self, publish: impl Into<Option<moq_lite::OriginConsumer>>) -> Self {
218 self.moq = self.moq.with_publish(publish);
219 self
220 }
221
222 pub fn with_consume(mut self, consume: impl Into<Option<moq_lite::OriginProducer>>) -> Self {
223 self.moq = self.moq.with_consume(consume);
224 self
225 }
226
227 pub async fn connect(&self, url: Url) -> anyhow::Result<moq_lite::Session> {
235 #[cfg(feature = "iroh")]
236 if crate::iroh::is_iroh_url(&url) {
237 let session = self.connect_iroh(url).await?;
238 let session = self.moq.connect(session).await?;
239 return Ok(session);
240 }
241
242 let quic_url = url.clone();
244 let quic_handle = async {
245 let res = self.connect_quic(quic_url).await;
246 if let Err(err) = &res {
247 tracing::warn!(%err, "QUIC connection failed");
248 }
249 res
250 };
251
252 let ws_handle = async {
253 if !self.websocket.enabled {
254 return None;
255 }
256
257 let res = self.connect_websocket(url).await;
258 if let Err(err) = &res {
259 tracing::warn!(%err, "WebSocket connection failed");
260 }
261 Some(res)
262 };
263
264 Ok(tokio::select! {
266 Ok(quic) = quic_handle => self.moq.connect(quic).await?,
267 Some(Ok(ws)) = ws_handle => self.moq.connect(ws).await?,
268 else => anyhow::bail!("failed to connect to server"),
270 })
271 }
272
273 async fn connect_quic(&self, mut url: Url) -> anyhow::Result<web_transport_quinn::Session> {
274 let mut config = self.tls.clone();
275
276 let host = url.host().context("invalid DNS name")?.to_string();
277 let port = url.port().unwrap_or(443);
278
279 let ip = tokio::net::lookup_host((host.clone(), port))
281 .await
282 .context("failed DNS lookup")?
283 .next()
284 .context("no DNS entries")?;
285
286 if url.scheme() == "http" {
287 let mut fingerprint = url.clone();
289 fingerprint.set_path("/certificate.sha256");
290 fingerprint.set_query(None);
291 fingerprint.set_fragment(None);
292
293 tracing::warn!(url = %fingerprint, "performing insecure HTTP request for certificate");
294
295 let resp = reqwest::get(fingerprint.as_str())
296 .await
297 .context("failed to fetch fingerprint")?
298 .error_for_status()
299 .context("fingerprint request failed")?;
300
301 let fingerprint = resp.text().await.context("failed to read fingerprint")?;
302 let fingerprint = hex::decode(fingerprint.trim()).context("invalid fingerprint")?;
303
304 let verifier = FingerprintVerifier::new(config.crypto_provider().clone(), fingerprint);
305 config.dangerous().set_certificate_verifier(Arc::new(verifier));
306
307 url.set_scheme("https").expect("failed to set scheme");
308 }
309
310 let alpn = match url.scheme() {
311 "https" => web_transport_quinn::ALPN,
312 "moql" => moq_lite::lite::ALPN,
313 "moqt" => moq_lite::ietf::ALPN,
314 _ => anyhow::bail!("url scheme must be 'http', 'https', or 'moql'"),
315 };
316
317 config.alpn_protocols = vec![alpn.as_bytes().to_vec()];
319 config.key_log = Arc::new(rustls::KeyLogFile::new());
320
321 let config: quinn::crypto::rustls::QuicClientConfig = config.try_into()?;
322 let mut config = quinn::ClientConfig::new(Arc::new(config));
323 config.transport_config(self.transport.clone());
324
325 tracing::debug!(%url, %ip, %alpn, "connecting");
326
327 let connection = self.quic.connect_with(config, ip, &host)?.await?;
328 tracing::Span::current().record("id", connection.stable_id());
329
330 let session = match alpn {
331 web_transport_quinn::ALPN => web_transport_quinn::Session::connect(connection, url).await?,
332 moq_lite::lite::ALPN | moq_lite::ietf::ALPN => web_transport_quinn::Session::raw(connection, url),
333 _ => unreachable!("ALPN was checked above"),
334 };
335
336 Ok(session)
337 }
338
339 async fn connect_websocket(&self, mut url: Url) -> anyhow::Result<web_transport_ws::Session> {
340 anyhow::ensure!(self.websocket.enabled, "WebSocket support is disabled");
341
342 let host = url.host_str().context("missing hostname")?.to_string();
343 let port = url.port().unwrap_or_else(|| match url.scheme() {
344 "https" | "wss" | "moql" | "moqt" => 443,
345 "http" | "ws" => 80,
346 _ => 443,
347 });
348 let key = (host, port);
349
350 match self.websocket.delay {
354 Some(delay) if !WEBSOCKET_WON.lock().unwrap().contains(&key) => {
355 tokio::time::sleep(delay).await;
356 tracing::debug!(%url, delay_ms = %delay.as_millis(), "QUIC not yet connected, attempting WebSocket fallback");
357 }
358 _ => {}
359 }
360
361 let needs_tls = match url.scheme() {
363 "http" => {
364 url.set_scheme("ws").expect("failed to set scheme");
365 false
366 }
367 "https" | "moql" | "moqt" => {
368 url.set_scheme("wss").expect("failed to set scheme");
369 true
370 }
371 "ws" => false,
372 "wss" => true,
373 _ => anyhow::bail!("unsupported URL scheme for WebSocket: {}", url.scheme()),
374 };
375
376 tracing::debug!(%url, "connecting via WebSocket");
377
378 let connector = if needs_tls {
380 Some(tokio_tungstenite::Connector::Rustls(Arc::new(self.tls.clone())))
381 } else {
382 None
383 };
384
385 let (ws_stream, _response) = tokio_tungstenite::connect_async_tls_with_config(
387 url.as_str(),
388 Some(tungstenite::protocol::WebSocketConfig {
389 max_message_size: Some(64 << 20), max_frame_size: Some(16 << 20), accept_unmasked_frames: false,
392 ..Default::default()
393 }),
394 false, connector,
396 )
397 .await
398 .context("failed to connect WebSocket")?;
399
400 let session = web_transport_ws::Session::new(ws_stream, false);
403
404 tracing::warn!(%url, "using WebSocket fallback");
405 WEBSOCKET_WON.lock().unwrap().insert(key);
406
407 Ok(session)
408 }
409
410 #[cfg(feature = "iroh")]
411 async fn connect_iroh(&self, url: Url) -> anyhow::Result<web_transport_iroh::Session> {
412 let endpoint = self.iroh.as_ref().context("Iroh support is not enabled")?;
413 let alpn = match url.scheme() {
414 "moql+iroh" | "iroh" => moq_lite::lite::ALPN,
415 "moqt+iroh" => moq_lite::ietf::ALPN,
416 "h3+iroh" => web_transport_iroh::ALPN_H3,
417 _ => anyhow::bail!("Invalid URL: unknown scheme"),
418 };
419 let host = url.host().context("Invalid URL: missing host")?.to_string();
420 let endpoint_id: iroh::EndpointId = host.parse().context("Invalid URL: host is not an iroh endpoint id")?;
421 let conn = endpoint.connect(endpoint_id, alpn.as_bytes()).await?;
422 let session = match alpn {
423 web_transport_iroh::ALPN_H3 => {
424 let url = url_set_scheme(url, "https")?;
427 web_transport_iroh::Session::connect_h3(conn, url).await?
428 }
429 _ => web_transport_iroh::Session::raw(conn),
430 };
431 Ok(session)
432 }
433}
434
435#[derive(Debug)]
436struct NoCertificateVerification(crypto::Provider);
437
438impl rustls::client::danger::ServerCertVerifier for NoCertificateVerification {
439 fn verify_server_cert(
440 &self,
441 _end_entity: &CertificateDer<'_>,
442 _intermediates: &[CertificateDer<'_>],
443 _server_name: &ServerName<'_>,
444 _ocsp: &[u8],
445 _now: UnixTime,
446 ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
447 Ok(rustls::client::danger::ServerCertVerified::assertion())
448 }
449
450 fn verify_tls12_signature(
451 &self,
452 message: &[u8],
453 cert: &CertificateDer<'_>,
454 dss: &rustls::DigitallySignedStruct,
455 ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
456 rustls::crypto::verify_tls12_signature(message, cert, dss, &self.0.signature_verification_algorithms)
457 }
458
459 fn verify_tls13_signature(
460 &self,
461 message: &[u8],
462 cert: &CertificateDer<'_>,
463 dss: &rustls::DigitallySignedStruct,
464 ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
465 rustls::crypto::verify_tls13_signature(message, cert, dss, &self.0.signature_verification_algorithms)
466 }
467
468 fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
469 self.0.signature_verification_algorithms.supported_schemes()
470 }
471}
472
473#[derive(Debug)]
475struct FingerprintVerifier {
476 provider: crypto::Provider,
477 fingerprint: Vec<u8>,
478}
479
480impl FingerprintVerifier {
481 pub fn new(provider: crypto::Provider, fingerprint: Vec<u8>) -> Self {
482 Self { provider, fingerprint }
483 }
484}
485
486impl rustls::client::danger::ServerCertVerifier for FingerprintVerifier {
487 fn verify_server_cert(
488 &self,
489 end_entity: &CertificateDer<'_>,
490 _intermediates: &[CertificateDer<'_>],
491 _server_name: &ServerName<'_>,
492 _ocsp: &[u8],
493 _now: UnixTime,
494 ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
495 let fingerprint = crypto::sha256(&self.provider, end_entity);
496 if fingerprint.as_ref() == self.fingerprint.as_slice() {
497 Ok(rustls::client::danger::ServerCertVerified::assertion())
498 } else {
499 Err(rustls::Error::General("fingerprint mismatch".into()))
500 }
501 }
502
503 fn verify_tls12_signature(
504 &self,
505 message: &[u8],
506 cert: &CertificateDer<'_>,
507 dss: &rustls::DigitallySignedStruct,
508 ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
509 rustls::crypto::verify_tls12_signature(message, cert, dss, &self.provider.signature_verification_algorithms)
510 }
511
512 fn verify_tls13_signature(
513 &self,
514 message: &[u8],
515 cert: &CertificateDer<'_>,
516 dss: &rustls::DigitallySignedStruct,
517 ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
518 rustls::crypto::verify_tls13_signature(message, cert, dss, &self.provider.signature_verification_algorithms)
519 }
520
521 fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
522 self.provider.signature_verification_algorithms.supported_schemes()
523 }
524}
525
/// Return `url` with its scheme replaced by `scheme`.
///
/// The URL is rewritten textually and re-parsed, rather than using `Url::set_scheme`. Per the
/// `url` crate docs, `set_scheme` refuses to switch between special schemes (e.g. `https`)
/// and non-special ones (e.g. `h3+iroh`).
#[cfg(feature = "iroh")]
fn url_set_scheme(url: Url, scheme: &str) -> anyhow::Result<Url> {
    let serialized = url.to_string();
    let (_, rest) = serialized.split_once(':').context("invalid URL")?;
    let rewritten: Url = format!("{scheme}:{rest}").parse()?;
    Ok(rewritten)
}
541}