1use std::path::PathBuf;
2use std::{net, time::Duration};
3
4use crate::crypto;
5#[cfg(feature = "iroh")]
6use crate::iroh::IrohQuicRequest;
7use anyhow::Context;
8use moq_lite::Session;
9use rand::Rng;
10use rustls::pki_types::{CertificateDer, PrivatePkcs8KeyDer};
11use rustls::server::{ClientHello, ResolvesServerCert};
12use rustls::sign::CertifiedKey;
13use std::fs;
14use std::io::{self, Cursor, Read};
15use std::sync::{Arc, RwLock};
16use url::Url;
17#[cfg(feature = "iroh")]
18use web_transport_iroh::iroh;
19use web_transport_quinn::http;
20
21use futures::FutureExt;
22use futures::future::BoxFuture;
23use futures::stream::{FuturesUnordered, StreamExt};
24
// TLS certificate settings for the server, filled in from CLI flags / environment
// variables (clap) or from a config file (serde).
//
// NOTE: plain `//` comments are used here on purpose; clap's derive turns `///`
// doc comments into `--help` text, which would change program output.
#[derive(clap::Args, Clone, Default, Debug, serde::Serialize, serde::Deserialize)]
#[serde(deny_unknown_fields)]
#[non_exhaustive]
pub struct ServerTlsConfig {
    // Paths of PEM certificate chain files. Entry `i` is paired with `key[i]`
    // when the certificates are loaded, so both lists must have the same length.
    #[arg(long = "tls-cert", id = "tls-cert", env = "MOQ_SERVER_TLS_CERT")]
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub cert: Vec<PathBuf>,

    // Paths of PEM private key files, one per entry in `cert`.
    #[arg(long = "tls-key", id = "tls-key", env = "MOQ_SERVER_TLS_KEY")]
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub key: Vec<PathBuf>,

    // Hostnames for a self-signed certificate generated at startup (and on every
    // reload). Comma-separated on the command line; empty means "do not generate".
    #[arg(
        long = "tls-generate",
        id = "tls-generate",
        value_delimiter = ',',
        env = "MOQ_SERVER_TLS_GENERATE"
    )]
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub generate: Vec<String>,
}
54
// Top-level server settings, filled in from CLI flags / environment variables
// (clap) or from a config file (serde). Consumed by `Server::new`.
//
// NOTE: plain `//` comments are used here on purpose; clap's derive turns `///`
// doc comments into `--help` text, which would change program output.
#[derive(clap::Args, Clone, Debug, Default, serde::Serialize, serde::Deserialize)]
#[serde(deny_unknown_fields, default)]
#[non_exhaustive]
pub struct ServerConfig {
    // UDP address to listen on. `Server::new` falls back to `[::]:443` when unset.
    // `listen` is accepted as an alias both in config files and on the CLI.
    #[serde(alias = "listen")]
    #[arg(id = "server-bind", long = "server-bind", alias = "listen", env = "MOQ_SERVER_BIND")]
    pub bind: Option<net::SocketAddr>,

    // Server ID (hex string) embedded in every generated connection ID.
    // When unset, quinn's default connection ID generation is left in place.
    #[arg(id = "server-quic-lb-id", long = "server-quic-lb-id", env = "MOQ_SERVER_QUIC_LB_ID")]
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub quic_lb_id: Option<ServerId>,

    // Number of random nonce bytes appended after the server ID. On the CLI this
    // flag requires `--server-quic-lb-id`. `Server::new` uses 8 when unset and
    // rejects values below 4.
    #[arg(
        id = "server-quic-lb-nonce",
        long = "server-quic-lb-nonce",
        requires = "server-quic-lb-id",
        env = "MOQ_SERVER_QUIC_LB_NONCE"
    )]
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub quic_lb_nonce: Option<usize>,

    // TLS certificate settings; flattened into the same CLI namespace.
    #[command(flatten)]
    #[serde(default)]
    pub tls: ServerTlsConfig,
}
87
88impl ServerConfig {
89 pub fn init(self) -> anyhow::Result<Server> {
90 Server::new(self)
91 }
92}
93
/// A QUIC server that accepts WebTransport and raw-QUIC MoQ connections
/// (and, with the `iroh` feature, connections from an iroh endpoint).
pub struct Server {
    /// MoQ server settings; cloned into every accepted [`Request`].
    moq: moq_lite::Server,
    /// The bound QUIC endpoint that incoming connections arrive on.
    quic: quinn::Endpoint,
    /// In-flight handshakes; each future resolves to a ready [`Request`] or an error.
    accept: FuturesUnordered<BoxFuture<'static, anyhow::Result<Request>>>,
    /// Certificate store shared with the TLS resolver (and the reload task on unix).
    certs: Arc<ServeCerts>,
    /// Optional iroh endpoint polled alongside the QUIC endpoint.
    #[cfg(feature = "iroh")]
    iroh: Option<iroh::Endpoint>,
}
105
106impl Server {
107 pub fn new(config: ServerConfig) -> anyhow::Result<Self> {
108 let mut transport = quinn::TransportConfig::default();
111 transport.max_idle_timeout(Some(Duration::from_secs(10).try_into().unwrap()));
112 transport.keep_alive_interval(Some(Duration::from_secs(4)));
113 transport.mtu_discovery_config(None); let transport = Arc::new(transport);
116
117 let provider = crypto::provider();
118
119 let certs = ServeCerts::new(provider.clone());
120
121 certs.load_certs(&config.tls)?;
122
123 let certs = Arc::new(certs);
124
125 #[cfg(unix)]
126 tokio::spawn(Self::reload_certs(certs.clone(), config.tls.clone()));
127
128 let mut tls = rustls::ServerConfig::builder_with_provider(provider)
129 .with_protocol_versions(&[&rustls::version::TLS13])?
130 .with_no_client_auth()
131 .with_cert_resolver(certs.clone());
132
133 tls.alpn_protocols = vec![
134 web_transport_quinn::ALPN.as_bytes().to_vec(),
135 moq_lite::lite::ALPN.as_bytes().to_vec(),
136 moq_lite::ietf::ALPN.as_bytes().to_vec(),
137 ];
138 tls.key_log = Arc::new(rustls::KeyLogFile::new());
139
140 let tls: quinn::crypto::rustls::QuicServerConfig = tls.try_into()?;
141 let mut tls = quinn::ServerConfig::with_crypto(Arc::new(tls));
142 tls.transport_config(transport.clone());
143
144 let runtime = quinn::default_runtime().context("no async runtime")?;
146
147 let mut endpoint_config = quinn::EndpointConfig::default();
149 if let Some(server_id) = config.quic_lb_id {
150 let nonce_len = config.quic_lb_nonce.unwrap_or(8);
151 anyhow::ensure!(nonce_len >= 4, "quic_lb_nonce must be at least 4");
152
153 let cid_len = 1 + server_id.len() + nonce_len;
154 anyhow::ensure!(cid_len <= 20, "connection ID length ({cid_len}) exceeds maximum of 20");
155
156 tracing::info!(
157 ?server_id,
158 nonce_len,
159 "using QUIC-LB compatible connection ID generation"
160 );
161 endpoint_config.cid_generator(move || Box::new(ServerIdGenerator::new(server_id.clone(), nonce_len)));
162 }
163
164 let listen = config.bind.unwrap_or("[::]:443".parse().unwrap());
165 let socket = std::net::UdpSocket::bind(listen).context("failed to bind UDP socket")?;
166
167 let quic = quinn::Endpoint::new(endpoint_config, Some(tls), socket, runtime)
169 .context("failed to create QUIC endpoint")?;
170
171 Ok(Self {
172 quic: quic.clone(),
173 accept: Default::default(),
174 certs,
175 moq: moq_lite::Server::new(),
176 #[cfg(feature = "iroh")]
177 iroh: None,
178 })
179 }
180
/// Builder-style setter: replace the optional iroh endpoint that
/// [`Server::accept`] polls alongside the QUIC endpoint.
#[cfg(feature = "iroh")]
pub fn with_iroh(self, iroh: Option<iroh::Endpoint>) -> Self {
    Self { iroh, ..self }
}
186
187 pub fn with_publish(mut self, publish: impl Into<Option<moq_lite::OriginConsumer>>) -> Self {
188 self.moq = self.moq.with_publish(publish);
189 self
190 }
191
192 pub fn with_consume(mut self, consume: impl Into<Option<moq_lite::OriginProducer>>) -> Self {
193 self.moq = self.moq.with_consume(consume);
194 self
195 }
196
197 #[cfg(unix)]
204 async fn reload_certs(certs: Arc<ServeCerts>, tls_config: ServerTlsConfig) {
205 use tokio::signal::unix::{SignalKind, signal};
206
207 let mut listener = signal(SignalKind::user_defined1()).expect("failed to listen for signals");
209
210 while listener.recv().await.is_some() {
211 tracing::info!("reloading server certificates");
212
213 if let Err(err) = certs.load_certs(&tls_config) {
214 tracing::warn!(%err, "failed to reload server certificates");
215 }
216 }
217 }
218
219 pub fn tls_info(&self) -> Arc<RwLock<ServerTlsInfo>> {
221 self.certs.info.clone()
222 }
223
/// Wait for the next connection that has completed its handshake and return
/// it as a [`Request`].
///
/// Handshakes run concurrently: each incoming connection is pushed onto
/// `self.accept` and driven while this method keeps polling for new ones.
/// A connection whose handshake fails is logged at debug level and skipped.
///
/// Returns `None` when the QUIC endpoint (or the iroh endpoint, if enabled)
/// stops yielding connections, or when Ctrl-C is received. In the Ctrl-C
/// case the endpoint is closed first.
pub async fn accept(&mut self) -> Option<Request> {
    loop {
        // Next incoming iroh connection. Without the `iroh` feature, or when no
        // iroh endpoint is configured, this future never resolves, so the
        // corresponding `select!` branch is effectively disabled.
        let iroh_accept_fut = async {
            #[cfg(feature = "iroh")]
            if let Some(endpoint) = self.iroh.as_ref() {
                endpoint.accept().await
            } else {
                std::future::pending::<_>().await
            }

            #[cfg(not(feature = "iroh"))]
            std::future::pending::<()>().await
        };

        tokio::select! {
            // New QUIC connection: start its handshake in the background.
            res = self.quic.accept() => {
                // `None` here ends the accept loop for the caller.
                let conn = res?;
                self.accept.push(Self::accept_session(self.moq.clone(), conn).boxed());
            }
            // New iroh connection: same treatment as QUIC.
            res = iroh_accept_fut => {
                #[cfg(feature = "iroh")]
                {
                    let conn = res?;
                    self.accept.push(Self::accept_iroh_session(self.moq.clone(), conn).boxed());
                }
                // Never reached without the feature; the binding only pins the type.
                #[cfg(not(feature = "iroh"))]
                let _: () = res;
            }
            // A pending handshake finished; this branch is skipped while `self.accept` is empty.
            Some(res) = self.accept.next() => {
                match res {
                    Ok(session) => return Some(session),
                    Err(err) => tracing::debug!(%err, "failed to accept session"),
                }
            }
            _ = tokio::signal::ctrl_c() => {
                self.close();
                // NOTE(review): presumably gives the close frame time to be sent
                // before the caller tears things down — confirm.
                tokio::time::sleep(Duration::from_millis(100)).await;

                return None;
            }
        }
    }
}
277
278 async fn accept_session(server: moq_lite::Server, conn: quinn::Incoming) -> anyhow::Result<Request> {
279 let mut conn = conn.accept()?;
280
281 let handshake = conn
282 .handshake_data()
283 .await?
284 .downcast::<quinn::crypto::rustls::HandshakeData>()
285 .unwrap();
286
287 let alpn = handshake.protocol.context("missing ALPN")?;
288 let alpn = String::from_utf8(alpn).context("failed to decode ALPN")?;
289 let host = handshake.server_name.unwrap_or_default();
290
291 tracing::debug!(%host, ip = %conn.remote_address(), %alpn, "accepting");
292
293 let conn = conn.await.context("failed to establish QUIC connection")?;
295
296 let span = tracing::Span::current();
297 span.record("id", conn.stable_id()); tracing::debug!(%host, ip = %conn.remote_address(), %alpn, "accepted");
299
300 match alpn.as_str() {
301 web_transport_quinn::ALPN => {
302 let request = web_transport_quinn::Request::accept(conn)
304 .await
305 .context("failed to receive WebTransport request")?;
306 Ok(Request {
307 server: server.clone(),
308 kind: RequestKind::WebTransport(request),
309 })
310 }
311 moq_lite::lite::ALPN | moq_lite::ietf::ALPN => Ok(Request {
312 server: server.clone(),
313 kind: RequestKind::Quic(QuicRequest::accept(conn)),
314 }),
315 _ => anyhow::bail!("unsupported ALPN: {alpn}"),
316 }
317 }
318
/// Drive one incoming iroh connection through its handshake and wrap it in a
/// [`Request`] according to its ALPN.
///
/// # Errors
/// Fails if the connection cannot be accepted or established, if the ALPN is
/// not valid UTF-8, if the WebTransport request cannot be received, or if the
/// ALPN is not one of the supported protocols.
#[cfg(feature = "iroh")]
async fn accept_iroh_session(server: moq_lite::Server, conn: iroh::endpoint::Incoming) -> anyhow::Result<Request> {
    let connection = conn.accept()?.await?;
    let alpn = String::from_utf8(connection.alpn().to_vec()).context("failed to decode ALPN")?;

    let span = tracing::Span::current();
    span.record("id", connection.stable_id());
    tracing::debug!(remote = %connection.remote_id().fmt_short(), %alpn, "accepted");

    let kind = match alpn.as_str() {
        web_transport_iroh::ALPN_H3 => {
            let request = web_transport_iroh::H3Request::accept(connection)
                .await
                .context("failed to receive WebTransport request")?;
            RequestKind::IrohWebTransport(request)
        }
        moq_lite::lite::ALPN | moq_lite::ietf::ALPN => RequestKind::IrohQuic(IrohQuicRequest::accept(connection)),
        _ => return Err(anyhow::anyhow!("unsupported ALPN: {alpn}")),
    };

    Ok(Request { server, kind })
}
345
/// Borrow the iroh endpoint, if one was configured with [`Server::with_iroh`].
#[cfg(feature = "iroh")]
pub fn iroh_endpoint(&self) -> Option<&iroh::Endpoint> {
    Option::as_ref(&self.iroh)
}
350
351 pub fn local_addr(&self) -> anyhow::Result<net::SocketAddr> {
352 self.quic.local_addr().context("failed to get local address")
353 }
354
355 pub fn close(&mut self) {
356 self.quic.close(quinn::VarInt::from_u32(0), b"server shutdown");
357 }
358}
359
/// The transport a [`Request`] arrived on; decides how it is accepted or rejected.
enum RequestKind {
    /// WebTransport request received over a quinn connection.
    WebTransport(web_transport_quinn::Request),
    /// Raw QUIC connection that negotiated a MoQ ALPN directly.
    Quic(QuicRequest),
    /// WebTransport request received over an iroh connection.
    #[cfg(feature = "iroh")]
    IrohWebTransport(web_transport_iroh::H3Request),
    /// Raw iroh connection that negotiated a MoQ ALPN directly.
    #[cfg(feature = "iroh")]
    IrohQuic(IrohQuicRequest),
}
369
/// A connection that finished its handshake and is waiting to be accepted
/// (turned into a MoQ session) or rejected.
pub struct Request {
    /// MoQ server settings used when the request is accepted; starts as a clone
    /// of the `Server`'s settings and can be overridden per request.
    server: moq_lite::Server,
    /// The underlying transport-specific request.
    kind: RequestKind,
}
374
375impl Request {
376 pub async fn reject(self, status: http::StatusCode) -> anyhow::Result<()> {
378 match self.kind {
379 RequestKind::WebTransport(request) => request.close(status).await?,
380 RequestKind::Quic(request) => request.close(status),
381 #[cfg(feature = "iroh")]
382 RequestKind::IrohWebTransport(request) => request.close(status).await?,
383 #[cfg(feature = "iroh")]
384 RequestKind::IrohQuic(request) => request.close(status),
385 }
386 Ok(())
387 }
388
389 pub fn with_publish(mut self, publish: impl Into<Option<moq_lite::OriginConsumer>>) -> Self {
390 self.server = self.server.with_publish(publish);
391 self
392 }
393
394 pub fn with_consume(mut self, consume: impl Into<Option<moq_lite::OriginProducer>>) -> Self {
395 self.server = self.server.with_consume(consume);
396 self
397 }
398
399 pub async fn accept(self) -> anyhow::Result<Session> {
407 let session = match self.kind {
408 RequestKind::WebTransport(request) => self.server.accept(request.ok().await?).await?,
409 RequestKind::Quic(request) => self.server.accept(request.ok()).await?,
410 #[cfg(feature = "iroh")]
411 RequestKind::IrohWebTransport(request) => self.server.accept(request.ok().await?).await?,
412 #[cfg(feature = "iroh")]
413 RequestKind::IrohQuic(request) => self.server.accept(request.ok()).await?,
414 };
415 Ok(session)
416 }
417
418 pub fn url(&self) -> Option<&Url> {
420 match &self.kind {
421 RequestKind::WebTransport(request) => Some(request.url()),
422 #[cfg(feature = "iroh")]
423 RequestKind::IrohWebTransport(request) => Some(request.url()),
424 _ => None,
425 }
426 }
427}
428
/// A raw QUIC connection (no WebTransport layer) waiting to be accepted or rejected.
pub struct QuicRequest {
    /// The established QUIC connection.
    connection: quinn::Connection,
    /// Synthetic `moql://<remote address>` URL built when the request is created.
    url: Url,
}
436
437impl QuicRequest {
438 pub fn accept(connection: quinn::Connection) -> Self {
440 let url: Url = format!("moql://{}", connection.remote_address())
441 .parse()
442 .expect("URL is valid");
443 Self { connection, url }
444 }
445
446 pub fn ok(self) -> web_transport_quinn::Session {
448 web_transport_quinn::Session::raw(self.connection, self.url)
449 }
450
451 pub fn url(&self) -> &Url {
453 &self.url
454 }
455
456 pub fn close(self, status: http::StatusCode) {
460 self.connection
461 .close(status.as_u16().into(), status.as_str().as_bytes());
462 }
463}
464
/// Snapshot of the certificates the server currently serves.
#[derive(Debug)]
pub struct ServerTlsInfo {
    /// Loaded certificate chains with their signing keys; the first entry is the
    /// fallback when no certificate matches the client's SNI.
    pub(crate) certs: Vec<Arc<CertifiedKey>>,
    /// Hex-encoded SHA-256 of the first certificate in each chain, in the same order as `certs`.
    pub fingerprints: Vec<String>,
}
471
/// Certificate store and SNI-based resolver; its contents can be replaced at
/// runtime through [`ServeCerts::load_certs`] / [`ServeCerts::set_certs`].
#[derive(Debug)]
struct ServeCerts {
    /// Current certificates and fingerprints, shared with callers of `Server::tls_info`.
    info: Arc<RwLock<ServerTlsInfo>>,
    /// Crypto provider used to load private keys and hash certificates.
    provider: crypto::Provider,
}
477
478impl ServeCerts {
479 pub fn new(provider: crypto::Provider) -> Self {
480 Self {
481 info: Arc::new(RwLock::new(ServerTlsInfo {
482 certs: Vec::new(),
483 fingerprints: Vec::new(),
484 })),
485 provider,
486 }
487 }
488
489 pub fn load_certs(&self, config: &ServerTlsConfig) -> anyhow::Result<()> {
490 anyhow::ensure!(config.cert.len() == config.key.len(), "must provide both cert and key");
491
492 let mut certs = Vec::new();
493
494 for (cert, key) in config.cert.iter().zip(config.key.iter()) {
496 certs.push(Arc::new(self.load(cert, key)?));
497 }
498
499 if !config.generate.is_empty() {
501 certs.push(Arc::new(self.generate(&config.generate)?));
502 }
503
504 self.set_certs(certs);
505 Ok(())
506 }
507
508 fn load(&self, chain_path: &PathBuf, key_path: &PathBuf) -> anyhow::Result<CertifiedKey> {
510 let chain = fs::File::open(chain_path).context("failed to open cert file")?;
511 let mut chain = io::BufReader::new(chain);
512
513 let chain: Vec<CertificateDer> = rustls_pemfile::certs(&mut chain)
514 .collect::<Result<_, _>>()
515 .context("failed to read certs")?;
516
517 anyhow::ensure!(!chain.is_empty(), "could not find certificate");
518
519 let mut keys = fs::File::open(key_path).context("failed to open key file")?;
521
522 let mut buf = Vec::new();
524 keys.read_to_end(&mut buf)?;
525
526 let key = rustls_pemfile::private_key(&mut Cursor::new(&buf))?.context("missing private key")?;
527 let key = self.provider.key_provider.load_private_key(key)?;
528
529 let certified_key = CertifiedKey::new(chain, key);
530
531 certified_key.keys_match().context(format!(
532 "private key {} doesn't match certificate {}",
533 key_path.display(),
534 chain_path.display()
535 ))?;
536
537 Ok(certified_key)
538 }
539
540 fn generate(&self, hostnames: &[String]) -> anyhow::Result<CertifiedKey> {
541 let key_pair = rcgen::KeyPair::generate()?;
542
543 let mut params = rcgen::CertificateParams::new(hostnames)?;
544
545 params.not_before = time::OffsetDateTime::now_utc() - time::Duration::days(1);
548 params.not_after = params.not_before + time::Duration::days(14);
549
550 let cert = params.self_signed(&key_pair)?;
552
553 let key_der = key_pair.serialized_der().to_vec();
555 let key_der = PrivatePkcs8KeyDer::from(key_der);
556 let key = self.provider.key_provider.load_private_key(key_der.into())?;
557
558 Ok(CertifiedKey::new(vec![cert.into()], key))
560 }
561
562 pub fn set_certs(&self, certs: Vec<Arc<CertifiedKey>>) {
564 let fingerprints = certs
565 .iter()
566 .map(|ck| {
567 let fingerprint = crate::crypto::sha256(&self.provider, ck.cert[0].as_ref());
568 hex::encode(fingerprint)
569 })
570 .collect();
571
572 let mut info = self.info.write().expect("info write lock poisoned");
573 info.certs = certs;
574 info.fingerprints = fingerprints;
575 }
576
577 fn best_certificate(&self, client_hello: &ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
579 let server_name = client_hello.server_name()?;
580 let dns_name = rustls::pki_types::ServerName::try_from(server_name).ok()?;
581
582 for ck in self.info.read().expect("info read lock poisoned").certs.iter() {
583 let leaf: webpki::EndEntityCert = ck
584 .end_entity_cert()
585 .expect("missing certificate")
586 .try_into()
587 .expect("failed to parse certificate");
588
589 if leaf.verify_is_valid_for_subject_name(&dns_name).is_ok() {
590 return Some(ck.clone());
591 }
592 }
593
594 None
595 }
596}
597
598impl ResolvesServerCert for ServeCerts {
599 fn resolve(&self, client_hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
600 if let Some(cert) = self.best_certificate(&client_hello) {
601 return Some(cert);
602 }
603
604 tracing::warn!(server_name = ?client_hello.server_name(), "no SNI certificate found");
607
608 self.info
609 .read()
610 .expect("info read lock poisoned")
611 .certs
612 .first()
613 .cloned()
614 }
615}
616
/// Raw server ID bytes embedded in generated connection IDs.
///
/// Serialized and deserialized as a hex string; also parsed from hex via `FromStr`.
#[serde_with::serde_as]
#[derive(Clone, serde::Serialize, serde::Deserialize)]
pub struct ServerId(#[serde_as(as = "serde_with::hex::Hex")] Vec<u8>);
621
622impl ServerId {
623 fn len(&self) -> usize {
624 self.0.len()
625 }
626}
627
628impl std::fmt::Debug for ServerId {
629 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
630 f.debug_tuple("QuicLbServerId").field(&hex::encode(&self.0)).finish()
631 }
632}
633
634impl std::str::FromStr for ServerId {
635 type Err = hex::FromHexError;
636
637 fn from_str(s: &str) -> Result<Self, Self::Err> {
638 hex::decode(s).map(Self)
639 }
640}
641
/// Connection ID generator producing IDs of the form
/// `[length byte][server ID][random nonce]`.
struct ServerIdGenerator {
    /// Server ID bytes copied verbatim into every connection ID.
    server_id: ServerId,
    /// Number of random bytes appended after the server ID.
    nonce_len: usize,
}
657
658impl ServerIdGenerator {
659 fn new(server_id: ServerId, nonce_len: usize) -> Self {
660 Self { server_id, nonce_len }
661 }
662}
663
664impl quinn::ConnectionIdGenerator for ServerIdGenerator {
665 fn generate_cid(&mut self) -> quinn::ConnectionId {
666 let cid_len = self.cid_len();
667 let mut cid = Vec::with_capacity(cid_len);
668 cid.push((cid_len - 1) as u8);
670 cid.extend(self.server_id.0.iter());
671 cid.extend(rand::rng().random_iter::<u8>().take(self.nonce_len));
672 quinn::ConnectionId::new(cid.as_slice())
673 }
674
675 fn cid_len(&self) -> usize {
676 1 + self.server_id.len() + self.nonce_len
677 }
678
679 fn cid_lifetime(&self) -> Option<Duration> {
680 None
681 }
682}