1use std::path::PathBuf;
2use std::{net, sync::Arc, time::Duration};
3
4use crate::crypto;
5use anyhow::Context;
6use rustls::pki_types::{CertificateDer, PrivatePkcs8KeyDer};
7use rustls::server::{ClientHello, ResolvesServerCert};
8use rustls::sign::CertifiedKey;
9use std::fs;
10use std::io::{self, Cursor, Read};
11use url::Url;
12use web_transport_quinn::http;
13
14use futures::future::BoxFuture;
15use futures::stream::{FuturesUnordered, StreamExt};
16use futures::FutureExt;
17
18#[derive(clap::Args, Clone, Debug, serde::Serialize, serde::Deserialize)]
19#[serde(deny_unknown_fields)]
20pub struct ServerTlsCert {
21 pub chain: PathBuf,
22 pub key: PathBuf,
23}
24
25impl ServerTlsCert {
26 pub fn parse(s: &str) -> anyhow::Result<Self> {
28 let (chain, key) = s.split_once(':').context("invalid certificate")?;
29 Ok(Self {
30 chain: PathBuf::from(chain),
31 key: PathBuf::from(key),
32 })
33 }
34}
35
36#[derive(clap::Args, Clone, Default, Debug, serde::Serialize, serde::Deserialize)]
37#[serde(deny_unknown_fields)]
38pub struct ServerTlsConfig {
39 #[arg(long = "tls-cert", id = "tls-cert", env = "MOQ_SERVER_TLS_CERT")]
41 #[serde(default, skip_serializing_if = "Vec::is_empty")]
42 pub cert: Vec<PathBuf>,
43
44 #[arg(long = "tls-key", id = "tls-key", env = "MOQ_SERVER_TLS_KEY")]
46 #[serde(default, skip_serializing_if = "Vec::is_empty")]
47 pub key: Vec<PathBuf>,
48
49 #[arg(
52 long = "tls-generate",
53 id = "tls-generate",
54 value_delimiter = ',',
55 env = "MOQ_SERVER_TLS_GENERATE"
56 )]
57 #[serde(default, skip_serializing_if = "Vec::is_empty")]
58 pub generate: Vec<String>,
59}
60
61#[derive(clap::Args, Clone, Debug, Default, serde::Serialize, serde::Deserialize)]
62#[serde(deny_unknown_fields, default)]
63pub struct ServerConfig {
64 #[serde(alias = "listen")]
67 #[arg(id = "server-bind", long = "server-bind", alias = "listen", env = "MOQ_SERVER_BIND")]
68 pub bind: Option<net::SocketAddr>,
69
70 #[command(flatten)]
71 #[serde(default)]
72 pub tls: ServerTlsConfig,
73}
74
75impl ServerConfig {
76 pub fn init(self) -> anyhow::Result<Server> {
77 Server::new(self)
78 }
79}
80
81pub struct Server {
82 quic: quinn::Endpoint,
83 accept: FuturesUnordered<BoxFuture<'static, anyhow::Result<Request>>>,
84 fingerprints: Vec<String>,
85}
86
87impl Server {
88 pub fn new(config: ServerConfig) -> anyhow::Result<Self> {
89 let mut transport = quinn::TransportConfig::default();
92 transport.max_idle_timeout(Some(Duration::from_secs(10).try_into().unwrap()));
93 transport.keep_alive_interval(Some(Duration::from_secs(4)));
94 transport.mtu_discovery_config(None); let transport = Arc::new(transport);
97
98 let provider = crypto::provider();
99
100 let mut serve = ServeCerts::new(provider.clone());
101
102 anyhow::ensure!(
104 config.tls.cert.len() == config.tls.key.len(),
105 "must provide both cert and key"
106 );
107
108 for (cert, key) in config.tls.cert.iter().zip(config.tls.key.iter()) {
109 serve.load(cert, key)?;
110 }
111
112 if !config.tls.generate.is_empty() {
113 serve.generate(&config.tls.generate)?;
114 }
115
116 let fingerprints = serve.fingerprints();
117
118 let mut tls = rustls::ServerConfig::builder_with_provider(provider)
119 .with_protocol_versions(&[&rustls::version::TLS13])?
120 .with_no_client_auth()
121 .with_cert_resolver(Arc::new(serve));
122
123 tls.alpn_protocols = vec![
124 web_transport_quinn::ALPN.as_bytes().to_vec(),
125 moq_lite::ALPN.as_bytes().to_vec(),
126 ];
127 tls.key_log = Arc::new(rustls::KeyLogFile::new());
128
129 let tls: quinn::crypto::rustls::QuicServerConfig = tls.try_into()?;
130 let mut tls = quinn::ServerConfig::with_crypto(Arc::new(tls));
131 tls.transport_config(transport.clone());
132
133 let runtime = quinn::default_runtime().context("no async runtime")?;
135 let endpoint_config = quinn::EndpointConfig::default();
136
137 let listen = config.bind.unwrap_or("[::]:443".parse().unwrap());
138 let socket = std::net::UdpSocket::bind(listen).context("failed to bind UDP socket")?;
139
140 let quic = quinn::Endpoint::new(endpoint_config, Some(tls), socket, runtime)
142 .context("failed to create QUIC endpoint")?;
143
144 Ok(Self {
145 quic: quic.clone(),
146 accept: Default::default(),
147 fingerprints,
148 })
149 }
150
151 pub fn fingerprints(&self) -> &[String] {
152 &self.fingerprints
153 }
154
155 pub async fn accept(&mut self) -> Option<Request> {
164 loop {
165 tokio::select! {
166 res = self.quic.accept() => {
167 let conn = res?;
168 self.accept.push(Self::accept_session(conn).boxed());
169 }
170 Some(res) = self.accept.next() => {
171 match res {
172 Ok(session) => return Some(session),
173 Err(err) => tracing::debug!(%err, "failed to accept session"),
174 }
175 }
176 _ = tokio::signal::ctrl_c() => {
177 self.close();
178 tokio::time::sleep(Duration::from_millis(100)).await;
180
181 return None;
182 }
183 }
184 }
185 }
186
187 async fn accept_session(conn: quinn::Incoming) -> anyhow::Result<Request> {
188 let mut conn = conn.accept()?;
189
190 let handshake = conn
191 .handshake_data()
192 .await?
193 .downcast::<quinn::crypto::rustls::HandshakeData>()
194 .unwrap();
195
196 let alpn = handshake.protocol.context("missing ALPN")?;
197 let alpn = String::from_utf8(alpn).context("failed to decode ALPN")?;
198 let host = handshake.server_name.unwrap_or_default();
199
200 tracing::debug!(%host, ip = %conn.remote_address(), %alpn, "accepting");
201
202 let conn = conn.await.context("failed to establish QUIC connection")?;
204
205 let span = tracing::Span::current();
206 span.record("id", conn.stable_id()); tracing::debug!(%host, ip = %conn.remote_address(), %alpn, "accepted");
208
209 match alpn.as_str() {
210 web_transport_quinn::ALPN => {
211 let request = web_transport_quinn::Request::accept(conn)
213 .await
214 .context("failed to receive WebTransport request")?;
215 Ok(Request::WebTransport(request))
216 }
217 moq_lite::ALPN => Ok(Request::Quic(QuicRequest::accept(conn))),
218 _ => anyhow::bail!("unsupported ALPN: {alpn}"),
219 }
220 }
221
222 pub fn local_addr(&self) -> anyhow::Result<net::SocketAddr> {
223 self.quic.local_addr().context("failed to get local address")
224 }
225
226 pub fn close(&mut self) {
227 self.quic.close(quinn::VarInt::from_u32(0), b"server shutdown");
228 }
229}
230
231pub enum Request {
232 WebTransport(web_transport_quinn::Request),
233 Quic(QuicRequest),
234}
235
236impl Request {
237 pub async fn close(self, status: http::StatusCode) -> Result<(), quinn::WriteError> {
239 match self {
240 Self::WebTransport(request) => request.close(status).await,
241 Self::Quic(request) => {
242 request.close(status);
243 Ok(())
244 }
245 }
246 }
247
248 pub async fn ok(self) -> Result<web_transport_quinn::Session, quinn::WriteError> {
253 match self {
254 Request::WebTransport(request) => request.ok().await,
255 Request::Quic(request) => Ok(request.ok()),
256 }
257 }
258
259 pub fn url(&self) -> &Url {
261 match self {
262 Request::WebTransport(request) => request.url(),
263 Request::Quic(request) => request.url(),
264 }
265 }
266}
267
268pub struct QuicRequest {
269 connection: quinn::Connection,
270 url: Url,
271}
272
273impl QuicRequest {
274 pub fn accept(connection: quinn::Connection) -> Self {
276 let url: Url = format!("moql://{}", connection.remote_address())
277 .parse()
278 .expect("URL is valid");
279 Self { connection, url }
280 }
281
282 pub fn ok(self) -> web_transport_quinn::Session {
284 web_transport_quinn::Session::raw(self.connection, self.url)
285 }
286
287 pub fn url(&self) -> &Url {
289 &self.url
290 }
291
292 pub fn close(self, status: http::StatusCode) {
296 self.connection
297 .close(status.as_u16().into(), status.as_str().as_bytes());
298 }
299}
300
301#[derive(Debug)]
302struct ServeCerts {
303 certs: Vec<Arc<CertifiedKey>>,
304 provider: crypto::Provider,
305}
306
307impl ServeCerts {
308 pub fn new(provider: crypto::Provider) -> Self {
309 Self {
310 certs: Vec::new(),
311 provider,
312 }
313 }
314
315 pub fn load(&mut self, chain: &PathBuf, key: &PathBuf) -> anyhow::Result<()> {
317 let chain = fs::File::open(chain).context("failed to open cert file")?;
318 let mut chain = io::BufReader::new(chain);
319
320 let chain: Vec<CertificateDer> = rustls_pemfile::certs(&mut chain)
321 .collect::<Result<_, _>>()
322 .context("failed to read certs")?;
323
324 anyhow::ensure!(!chain.is_empty(), "could not find certificate");
325
326 let mut keys = fs::File::open(key).context("failed to open key file")?;
328
329 let mut buf = Vec::new();
331 keys.read_to_end(&mut buf)?;
332
333 let key = rustls_pemfile::private_key(&mut Cursor::new(&buf))?.context("missing private key")?;
334 let key = self.provider.key_provider.load_private_key(key)?;
335
336 self.certs.push(Arc::new(CertifiedKey::new(chain, key)));
337
338 Ok(())
339 }
340
341 pub fn generate(&mut self, hostnames: &[String]) -> anyhow::Result<()> {
342 let key_pair = rcgen::KeyPair::generate()?;
343
344 let mut params = rcgen::CertificateParams::new(hostnames)?;
345
346 params.not_before = time::OffsetDateTime::now_utc() - time::Duration::days(1);
349 params.not_after = params.not_before + time::Duration::days(14);
350
351 let cert = params.self_signed(&key_pair)?;
353
354 let key_der = key_pair.serialized_der().to_vec();
356 let key_der = PrivatePkcs8KeyDer::from(key_der);
357 let key = self.provider.key_provider.load_private_key(key_der.into())?;
358
359 self.certs.push(Arc::new(CertifiedKey::new(vec![cert.into()], key)));
361
362 Ok(())
363 }
364
365 pub fn fingerprints(&self) -> Vec<String> {
367 self.certs
368 .iter()
369 .map(|ck| {
370 let fingerprint = crate::crypto::sha256(&self.provider, ck.cert[0].as_ref());
371 hex::encode(fingerprint)
372 })
373 .collect()
374 }
375
376 fn best_certificate(&self, client_hello: &ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
378 let server_name = client_hello.server_name()?;
379 let dns_name = rustls::pki_types::ServerName::try_from(server_name).ok()?;
380
381 for ck in &self.certs {
382 let leaf: webpki::EndEntityCert = ck
383 .end_entity_cert()
384 .expect("missing certificate")
385 .try_into()
386 .expect("failed to parse certificate");
387
388 if leaf.verify_is_valid_for_subject_name(&dns_name).is_ok() {
389 return Some(ck.clone());
390 }
391 }
392
393 None
394 }
395}
396
397impl ResolvesServerCert for ServeCerts {
398 fn resolve(&self, client_hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
399 if let Some(cert) = self.best_certificate(&client_hello) {
400 return Some(cert);
401 }
402
403 tracing::warn!(server_name = ?client_hello.server_name(), "no SNI certificate found");
406
407 self.certs.first().cloned()
408 }
409}