1use std::path::PathBuf;
2use std::{net, time::Duration};
3
4use crate::crypto;
5#[cfg(feature = "iroh")]
6use crate::iroh::IrohQuicRequest;
7use anyhow::Context;
8use moq_lite::Session;
9use rustls::pki_types::{CertificateDer, PrivatePkcs8KeyDer};
10use rustls::server::{ClientHello, ResolvesServerCert};
11use rustls::sign::CertifiedKey;
12use std::fs;
13use std::io::{self, Cursor, Read};
14use std::sync::{Arc, RwLock};
15use url::Url;
16#[cfg(feature = "iroh")]
17use web_transport_iroh::iroh;
18use web_transport_quinn::http;
19
20use futures::future::BoxFuture;
21use futures::stream::{FuturesUnordered, StreamExt};
22use futures::FutureExt;
23
// TLS settings for the server: PEM certificate/key files to load and/or
// hostnames for which a self-signed certificate is generated.
//
// NOTE: plain `//` comments are used on purpose; `///` doc comments on clap
// fields would become `--help` text.
#[derive(clap::Args, Clone, Default, Debug, serde::Serialize, serde::Deserialize)]
#[serde(deny_unknown_fields)]
pub struct ServerTlsConfig {
    // PEM certificate chain files. Paired positionally with `key` (the i-th cert
    // uses the i-th key); both lists must have the same length.
    #[arg(long = "tls-cert", id = "tls-cert", env = "MOQ_SERVER_TLS_CERT")]
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub cert: Vec<PathBuf>,

    // PEM private key files, paired positionally with `cert`.
    #[arg(long = "tls-key", id = "tls-key", env = "MOQ_SERVER_TLS_KEY")]
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub key: Vec<PathBuf>,

    // Hostnames for a generated self-signed certificate (comma-separated on the
    // CLI / in the env var). When empty, no certificate is generated.
    #[arg(
        long = "tls-generate",
        id = "tls-generate",
        value_delimiter = ',',
        env = "MOQ_SERVER_TLS_GENERATE"
    )]
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub generate: Vec<String>,
}
52
// Server configuration: the UDP address to bind and the TLS settings.
//
// NOTE: plain `//` comments are used on purpose; `///` doc comments on clap
// fields would become `--help` text.
#[derive(clap::Args, Clone, Debug, Default, serde::Serialize, serde::Deserialize)]
#[serde(deny_unknown_fields, default)]
pub struct ServerConfig {
    // UDP address to listen on; `Server::new` falls back to `[::]:443` when unset.
    // `listen` is accepted as an alias both in config files and on the CLI.
    #[serde(alias = "listen")]
    #[arg(id = "server-bind", long = "server-bind", alias = "listen", env = "MOQ_SERVER_BIND")]
    pub bind: Option<net::SocketAddr>,

    // TLS options, flattened into this command's arguments.
    #[command(flatten)]
    #[serde(default)]
    pub tls: ServerTlsConfig,
}
67
68impl ServerConfig {
69 pub fn init(self) -> anyhow::Result<Server> {
70 Server::new(self)
71 }
72}
73
/// A QUIC (and optionally iroh) server that accepts MoQ sessions.
///
/// Build it with [`Server::new`] or [`ServerConfig::init`], then call
/// [`Server::accept`] in a loop to receive incoming [`Request`]s.
pub struct Server {
    /// The QUIC endpoint bound to the configured UDP socket.
    quic: quinn::Endpoint,
    /// Handshakes in progress; each resolves to a `Request`, or to an error that
    /// `accept` logs and drops.
    accept: FuturesUnordered<BoxFuture<'static, anyhow::Result<Request>>>,
    /// Certificate store shared with the rustls config (as its cert resolver)
    /// and, on Unix, with the SIGUSR1 reload task.
    certs: Arc<ServeCerts>,
    /// Optional iroh endpoint, set via `with_iroh` and accepted alongside `quic`.
    #[cfg(feature = "iroh")]
    iroh: Option<iroh::Endpoint>,
}
84
85impl Server {
86 pub fn new(config: ServerConfig) -> anyhow::Result<Self> {
87 let mut transport = quinn::TransportConfig::default();
90 transport.max_idle_timeout(Some(Duration::from_secs(10).try_into().unwrap()));
91 transport.keep_alive_interval(Some(Duration::from_secs(4)));
92 transport.mtu_discovery_config(None); let transport = Arc::new(transport);
95
96 let provider = crypto::provider();
97
98 let certs = ServeCerts::new(provider.clone());
99
100 certs.load_certs(&config.tls)?;
101
102 let certs = Arc::new(certs);
103
104 #[cfg(unix)]
105 tokio::spawn(Self::reload_certs(certs.clone(), config.tls.clone()));
106
107 let mut tls = rustls::ServerConfig::builder_with_provider(provider)
108 .with_protocol_versions(&[&rustls::version::TLS13])?
109 .with_no_client_auth()
110 .with_cert_resolver(certs.clone());
111
112 tls.alpn_protocols = vec![
113 web_transport_quinn::ALPN.as_bytes().to_vec(),
114 moq_lite::lite::ALPN.as_bytes().to_vec(),
115 moq_lite::ietf::ALPN.as_bytes().to_vec(),
116 ];
117 tls.key_log = Arc::new(rustls::KeyLogFile::new());
118
119 let tls: quinn::crypto::rustls::QuicServerConfig = tls.try_into()?;
120 let mut tls = quinn::ServerConfig::with_crypto(Arc::new(tls));
121 tls.transport_config(transport.clone());
122
123 let runtime = quinn::default_runtime().context("no async runtime")?;
125 let endpoint_config = quinn::EndpointConfig::default();
126
127 let listen = config.bind.unwrap_or("[::]:443".parse().unwrap());
128 let socket = std::net::UdpSocket::bind(listen).context("failed to bind UDP socket")?;
129
130 let quic = quinn::Endpoint::new(endpoint_config, Some(tls), socket, runtime)
132 .context("failed to create QUIC endpoint")?;
133
134 Ok(Self {
135 quic: quic.clone(),
136 accept: Default::default(),
137 certs,
138 #[cfg(feature = "iroh")]
139 iroh: None,
140 })
141 }
142
143 #[cfg(feature = "iroh")]
144 pub fn with_iroh(&mut self, iroh: Option<iroh::Endpoint>) -> &mut Self {
145 self.iroh = iroh;
146 self
147 }
148
149 #[cfg(unix)]
150 async fn reload_certs(certs: Arc<ServeCerts>, tls_config: ServerTlsConfig) {
151 use tokio::signal::unix::{signal, SignalKind};
152
153 let mut listener = signal(SignalKind::user_defined1()).expect("failed to listen for signals");
155
156 while listener.recv().await.is_some() {
157 tracing::info!("reloading server certificates");
158
159 if let Err(err) = certs.load_certs(&tls_config) {
160 tracing::warn!(%err, "failed to reload server certificates");
161 }
162 }
163 }
164
165 pub fn tls_info(&self) -> Arc<RwLock<ServerTlsInfo>> {
167 self.certs.info.clone()
168 }
169
170 pub async fn accept(&mut self) -> Option<Request> {
178 loop {
179 let iroh_accept_fut = async {
182 #[cfg(feature = "iroh")]
183 if let Some(endpoint) = self.iroh.as_ref() {
184 endpoint.accept().await
185 } else {
186 std::future::pending::<_>().await
187 }
188
189 #[cfg(not(feature = "iroh"))]
190 std::future::pending::<()>().await
191 };
192
193 tokio::select! {
194 res = self.quic.accept() => {
195 let conn = res?;
196 self.accept.push(Self::accept_session(conn).boxed());
197 }
198 res = iroh_accept_fut => {
199 #[cfg(feature = "iroh")]
200 {
201 let conn = res?;
202 self.accept.push(Self::accept_iroh_session(conn).boxed());
203 }
204 #[cfg(not(feature = "iroh"))]
205 let _: () = res;
206 }
207 Some(res) = self.accept.next() => {
208 match res {
209 Ok(session) => return Some(session),
210 Err(err) => tracing::debug!(%err, "failed to accept session"),
211 }
212 }
213 _ = tokio::signal::ctrl_c() => {
214 self.close();
215 tokio::time::sleep(Duration::from_millis(100)).await;
217
218 return None;
219 }
220 }
221 }
222 }
223
224 async fn accept_session(conn: quinn::Incoming) -> anyhow::Result<Request> {
225 let mut conn = conn.accept()?;
226
227 let handshake = conn
228 .handshake_data()
229 .await?
230 .downcast::<quinn::crypto::rustls::HandshakeData>()
231 .unwrap();
232
233 let alpn = handshake.protocol.context("missing ALPN")?;
234 let alpn = String::from_utf8(alpn).context("failed to decode ALPN")?;
235 let host = handshake.server_name.unwrap_or_default();
236
237 tracing::debug!(%host, ip = %conn.remote_address(), %alpn, "accepting");
238
239 let conn = conn.await.context("failed to establish QUIC connection")?;
241
242 let span = tracing::Span::current();
243 span.record("id", conn.stable_id()); tracing::debug!(%host, ip = %conn.remote_address(), %alpn, "accepted");
245
246 match alpn.as_str() {
247 web_transport_quinn::ALPN => {
248 let request = web_transport_quinn::Request::accept(conn)
250 .await
251 .context("failed to receive WebTransport request")?;
252 Ok(Request::WebTransport(request))
253 }
254 moq_lite::lite::ALPN | moq_lite::ietf::ALPN => Ok(Request::Quic(QuicRequest::accept(conn))),
255 _ => anyhow::bail!("unsupported ALPN: {alpn}"),
256 }
257 }
258
259 #[cfg(feature = "iroh")]
260 async fn accept_iroh_session(conn: iroh::endpoint::Incoming) -> anyhow::Result<Request> {
261 let conn = conn.accept()?.await?;
262 let alpn = String::from_utf8(conn.alpn().to_vec()).context("failed to decode ALPN")?;
263 tracing::Span::current().record("id", conn.stable_id());
264 tracing::debug!(remote = %conn.remote_id().fmt_short(), %alpn, "accepted");
265 match alpn.as_str() {
266 web_transport_iroh::ALPN_H3 => {
267 let request = web_transport_iroh::H3Request::accept(conn)
268 .await
269 .context("failed to receive WebTransport request")?;
270 Ok(Request::IrohWebTransport(request))
271 }
272 moq_lite::lite::ALPN | moq_lite::ietf::ALPN => {
273 let request = IrohQuicRequest::accept(conn);
274 Ok(Request::IrohQuic(request))
275 }
276 _ => Err(anyhow::anyhow!("unsupported ALPN: {alpn}")),
277 }
278 }
279
280 #[cfg(feature = "iroh")]
281 pub fn iroh_endpoint(&self) -> Option<&iroh::Endpoint> {
282 self.iroh.as_ref()
283 }
284
285 pub fn local_addr(&self) -> anyhow::Result<net::SocketAddr> {
286 self.quic.local_addr().context("failed to get local address")
287 }
288
289 pub fn close(&mut self) {
290 self.quic.close(quinn::VarInt::from_u32(0), b"server shutdown");
291 }
292}
293
/// An incoming session that has finished its transport-level handshake and is
/// waiting to be accepted ([`Request::accept`]) or rejected ([`Request::reject`]).
pub enum Request {
    /// WebTransport (HTTP/3) request received on the QUIC endpoint.
    WebTransport(web_transport_quinn::Request),
    /// Raw QUIC connection negotiated with the moq-lite or IETF MoQ ALPN.
    Quic(QuicRequest),
    /// WebTransport (HTTP/3) request received on the iroh endpoint.
    #[cfg(feature = "iroh")]
    IrohWebTransport(web_transport_iroh::H3Request),
    /// Raw iroh connection negotiated with the moq-lite or IETF MoQ ALPN.
    #[cfg(feature = "iroh")]
    IrohQuic(IrohQuicRequest),
}
303
304impl Request {
305 pub async fn reject(self, status: http::StatusCode) -> anyhow::Result<()> {
307 match self {
308 Self::WebTransport(request) => request.close(status).await?,
309 Self::Quic(request) => request.close(status),
310 #[cfg(feature = "iroh")]
311 Request::IrohWebTransport(request) => request.close(status).await?,
312 #[cfg(feature = "iroh")]
313 Request::IrohQuic(request) => request.close(status),
314 }
315 Ok(())
316 }
317
318 pub async fn accept(
320 self,
321 publish: impl Into<Option<moq_lite::OriginConsumer>>,
322 subscribe: impl Into<Option<moq_lite::OriginProducer>>,
323 ) -> anyhow::Result<Session> {
324 let session = match self {
325 Request::WebTransport(request) => Session::accept(request.ok().await?, publish, subscribe).await?,
326 Request::Quic(request) => Session::accept(request.ok(), publish, subscribe).await?,
327 #[cfg(feature = "iroh")]
328 Request::IrohWebTransport(request) => Session::accept(request.ok().await?, publish, subscribe).await?,
329 #[cfg(feature = "iroh")]
330 Request::IrohQuic(request) => Session::accept(request.ok(), publish, subscribe).await?,
331 };
332 Ok(session)
333 }
334
335 pub fn url(&self) -> Option<&Url> {
337 match self {
338 Request::WebTransport(request) => Some(request.url()),
339 #[cfg(feature = "iroh")]
340 Request::IrohWebTransport(request) => Some(request.url()),
341 _ => None,
342 }
343 }
344}
345
/// A raw QUIC connection (moq-lite / IETF MoQ ALPN) waiting to be accepted or
/// rejected.
pub struct QuicRequest {
    /// The established QUIC connection.
    connection: quinn::Connection,
    /// Synthetic `moql://` URL built from the peer's address in `accept`.
    url: Url,
}
353
354impl QuicRequest {
355 pub fn accept(connection: quinn::Connection) -> Self {
357 let url: Url = format!("moql://{}", connection.remote_address())
358 .parse()
359 .expect("URL is valid");
360 Self { connection, url }
361 }
362
363 pub fn ok(self) -> web_transport_quinn::Session {
365 web_transport_quinn::Session::raw(self.connection, self.url)
366 }
367
368 pub fn url(&self) -> &Url {
370 &self.url
371 }
372
373 pub fn close(self, status: http::StatusCode) {
377 self.connection
378 .close(status.as_u16().into(), status.as_str().as_bytes());
379 }
380}
381
/// The certificates currently served, plus their fingerprints.
///
/// Shared through [`Server::tls_info`] and replaced as a whole whenever
/// certificates are (re)loaded.
#[derive(Debug)]
pub struct ServerTlsInfo {
    /// Loaded certificates, in the order they are tried for SNI matching; the
    /// first one is also the fallback when nothing matches.
    pub(crate) certs: Vec<Arc<CertifiedKey>>,
    /// Hex-encoded SHA-256 of the first certificate of each chain, index-aligned
    /// with `certs`.
    pub fingerprints: Vec<String>,
}
388
/// Reloadable certificate store that also acts as the rustls certificate resolver.
#[derive(Debug)]
struct ServeCerts {
    /// Served certificates and fingerprints, behind a lock so reloads can swap them.
    info: Arc<RwLock<ServerTlsInfo>>,
    /// Crypto provider used to load private keys and hash certificates.
    provider: crypto::Provider,
}
394
395impl ServeCerts {
396 pub fn new(provider: crypto::Provider) -> Self {
397 Self {
398 info: Arc::new(RwLock::new(ServerTlsInfo {
399 certs: Vec::new(),
400 fingerprints: Vec::new(),
401 })),
402 provider,
403 }
404 }
405
406 pub fn load_certs(&self, config: &ServerTlsConfig) -> anyhow::Result<()> {
407 anyhow::ensure!(config.cert.len() == config.key.len(), "must provide both cert and key");
408
409 let mut certs = Vec::new();
410
411 for (cert, key) in config.cert.iter().zip(config.key.iter()) {
413 certs.push(Arc::new(self.load(cert, key)?));
414 }
415
416 if !config.generate.is_empty() {
418 certs.push(Arc::new(self.generate(&config.generate)?));
419 }
420
421 self.set_certs(certs);
422 Ok(())
423 }
424
425 fn load(&self, chain_path: &PathBuf, key_path: &PathBuf) -> anyhow::Result<CertifiedKey> {
427 let chain = fs::File::open(chain_path).context("failed to open cert file")?;
428 let mut chain = io::BufReader::new(chain);
429
430 let chain: Vec<CertificateDer> = rustls_pemfile::certs(&mut chain)
431 .collect::<Result<_, _>>()
432 .context("failed to read certs")?;
433
434 anyhow::ensure!(!chain.is_empty(), "could not find certificate");
435
436 let mut keys = fs::File::open(key_path).context("failed to open key file")?;
438
439 let mut buf = Vec::new();
441 keys.read_to_end(&mut buf)?;
442
443 let key = rustls_pemfile::private_key(&mut Cursor::new(&buf))?.context("missing private key")?;
444 let key = self.provider.key_provider.load_private_key(key)?;
445
446 let certified_key = CertifiedKey::new(chain, key);
447
448 certified_key.keys_match().context(format!(
449 "private key {} doesn't match certificate {}",
450 key_path.display(),
451 chain_path.display()
452 ))?;
453
454 Ok(certified_key)
455 }
456
457 fn generate(&self, hostnames: &[String]) -> anyhow::Result<CertifiedKey> {
458 let key_pair = rcgen::KeyPair::generate()?;
459
460 let mut params = rcgen::CertificateParams::new(hostnames)?;
461
462 params.not_before = time::OffsetDateTime::now_utc() - time::Duration::days(1);
465 params.not_after = params.not_before + time::Duration::days(14);
466
467 let cert = params.self_signed(&key_pair)?;
469
470 let key_der = key_pair.serialized_der().to_vec();
472 let key_der = PrivatePkcs8KeyDer::from(key_der);
473 let key = self.provider.key_provider.load_private_key(key_der.into())?;
474
475 Ok(CertifiedKey::new(vec![cert.into()], key))
477 }
478
479 pub fn set_certs(&self, certs: Vec<Arc<CertifiedKey>>) {
481 let fingerprints = certs
482 .iter()
483 .map(|ck| {
484 let fingerprint = crate::crypto::sha256(&self.provider, ck.cert[0].as_ref());
485 hex::encode(fingerprint)
486 })
487 .collect();
488
489 let mut info = self.info.write().expect("info write lock poisoned");
490 info.certs = certs;
491 info.fingerprints = fingerprints;
492 }
493
494 fn best_certificate(&self, client_hello: &ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
496 let server_name = client_hello.server_name()?;
497 let dns_name = rustls::pki_types::ServerName::try_from(server_name).ok()?;
498
499 for ck in self.info.read().expect("info read lock poisoned").certs.iter() {
500 let leaf: webpki::EndEntityCert = ck
501 .end_entity_cert()
502 .expect("missing certificate")
503 .try_into()
504 .expect("failed to parse certificate");
505
506 if leaf.verify_is_valid_for_subject_name(&dns_name).is_ok() {
507 return Some(ck.clone());
508 }
509 }
510
511 None
512 }
513}
514
515impl ResolvesServerCert for ServeCerts {
516 fn resolve(&self, client_hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
517 if let Some(cert) = self.best_certificate(&client_hello) {
518 return Some(cert);
519 }
520
521 tracing::warn!(server_name = ?client_hello.server_name(), "no SNI certificate found");
524
525 self.info
526 .read()
527 .expect("info read lock poisoned")
528 .certs
529 .first()
530 .cloned()
531 }
532}