1use std::path::PathBuf;
2use std::{net, sync::Arc, time::Duration};
3
4use anyhow::Context;
5use ring::digest::{digest, SHA256};
6use rustls::crypto::ring::sign::any_supported_type;
7use rustls::pki_types::{CertificateDer, PrivatePkcs8KeyDer};
8use rustls::server::{ClientHello, ResolvesServerCert};
9use rustls::sign::CertifiedKey;
10use std::fs;
11use std::io::{self, Cursor, Read};
12
13use futures::future::BoxFuture;
14use futures::stream::{FuturesUnordered, StreamExt};
15use futures::FutureExt;
16
// A certificate chain / private key pair of file paths, parseable from a
// `<chain>:<key>` string via `ServerTlsCert::parse`.
// (Plain comment rather than `///`: clap's derive may turn a doc comment on an
// `Args` struct into `about` text for the command it is flattened into.)
#[derive(clap::Args, Clone, Debug, serde::Serialize, serde::Deserialize)]
#[serde(deny_unknown_fields)]
pub struct ServerTlsCert {
    /// Path to the certificate chain file.
    pub chain: PathBuf,
    /// Path to the private key file.
    pub key: PathBuf,
}
23
24impl ServerTlsCert {
25 pub fn parse(s: &str) -> anyhow::Result<Self> {
27 let (chain, key) = s.split_once(':').context("invalid certificate")?;
28 Ok(Self {
29 chain: PathBuf::from(chain),
30 key: PathBuf::from(key),
31 })
32 }
33}
34
// TLS settings for the server: explicit certificate/key files and/or hostnames
// to generate a self-signed certificate for.
// (Plain comment rather than `///`: clap's derive may turn a doc comment on an
// `Args` struct into `about` text for the command it is flattened into.)
#[derive(clap::Args, Clone, Default, Debug, serde::Serialize, serde::Deserialize)]
#[serde(deny_unknown_fields)]
pub struct ServerTlsConfig {
    /// Certificate chain file(s); the Nth entry is paired with the Nth `--tls-key`.
    #[arg(long = "tls-cert", id = "tls-cert", env = "MOQ_SERVER_TLS_CERT")]
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub cert: Vec<PathBuf>,

    /// Private key file(s); the Nth entry is paired with the Nth `--tls-cert`.
    #[arg(long = "tls-key", id = "tls-key", env = "MOQ_SERVER_TLS_KEY")]
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub key: Vec<PathBuf>,

    /// Comma-separated hostnames to generate a self-signed certificate for at startup.
    #[arg(
        long = "tls-generate",
        id = "tls-generate",
        value_delimiter = ',',
        env = "MOQ_SERVER_TLS_GENERATE"
    )]
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub generate: Vec<String>,
}
59
// Configuration for `Server`: the UDP listen address plus TLS settings.
// (Plain comment rather than `///`: clap's derive may turn a doc comment on an
// `Args` struct into `about` text for the command it is flattened into.)
#[derive(clap::Args, Clone, Debug, Default, serde::Serialize, serde::Deserialize)]
#[serde(deny_unknown_fields, default)]
pub struct ServerConfig {
    /// UDP address to listen on; `Server::new` uses `[::]:443` when unset.
    #[arg(long, env = "MOQ_SERVER_LISTEN")]
    pub listen: Option<net::SocketAddr>,

    /// TLS certificate settings.
    #[command(flatten)]
    #[serde(default)]
    pub tls: ServerTlsConfig,
}
72
73impl ServerConfig {
74 pub fn init(self) -> anyhow::Result<Server> {
75 Server::new(self)
76 }
77}
78
/// A QUIC server that accepts incoming connections and yields WebTransport
/// session requests.
pub struct Server {
    /// The bound QUIC endpoint that incoming connections arrive on.
    quic: quinn::Endpoint,
    /// In-flight handshakes for incoming connections; each resolves to a
    /// WebTransport request or an error.
    accept: FuturesUnordered<BoxFuture<'static, anyhow::Result<web_transport_quinn::Request>>>,
    /// Hex-encoded SHA-256 fingerprints of each configured leaf certificate.
    fingerprints: Vec<String>,
}
84
85impl Server {
86 pub fn new(config: ServerConfig) -> anyhow::Result<Self> {
87 let mut transport = quinn::TransportConfig::default();
90 transport.max_idle_timeout(Some(Duration::from_secs(10).try_into().unwrap()));
91 transport.keep_alive_interval(Some(Duration::from_secs(4)));
92 transport.mtu_discovery_config(None); let transport = Arc::new(transport);
95
96 let provider = Arc::new(rustls::crypto::aws_lc_rs::default_provider());
97 let mut serve = ServeCerts::default();
98
99 anyhow::ensure!(
101 config.tls.cert.len() == config.tls.key.len(),
102 "must provide both cert and key"
103 );
104
105 for (cert, key) in config.tls.cert.iter().zip(config.tls.key.iter()) {
106 serve.load(cert, key)?;
107 }
108
109 if !config.tls.generate.is_empty() {
110 serve.generate(&config.tls.generate)?;
111 }
112
113 let fingerprints = serve.fingerprints();
114
115 let mut tls = rustls::ServerConfig::builder_with_provider(provider)
116 .with_protocol_versions(&[&rustls::version::TLS13])?
117 .with_no_client_auth()
118 .with_cert_resolver(Arc::new(serve));
119
120 tls.alpn_protocols = vec![
121 web_transport_quinn::ALPN.as_bytes().to_vec(),
122 moq_lite::ALPN.as_bytes().to_vec(),
123 ];
124 tls.key_log = Arc::new(rustls::KeyLogFile::new());
125
126 let tls: quinn::crypto::rustls::QuicServerConfig = tls.try_into()?;
127 let mut tls = quinn::ServerConfig::with_crypto(Arc::new(tls));
128 tls.transport_config(transport.clone());
129
130 let runtime = quinn::default_runtime().context("no async runtime")?;
132 let endpoint_config = quinn::EndpointConfig::default();
133
134 let listen = config.listen.unwrap_or("[::]:443".parse().unwrap());
135 let socket = std::net::UdpSocket::bind(listen).context("failed to bind UDP socket")?;
136
137 let quic = quinn::Endpoint::new(endpoint_config, Some(tls), socket, runtime)
139 .context("failed to create QUIC endpoint")?;
140
141 Ok(Self {
142 quic: quic.clone(),
143 accept: Default::default(),
144 fingerprints,
145 })
146 }
147
148 pub fn fingerprints(&self) -> &[String] {
149 &self.fingerprints
150 }
151
152 pub async fn accept(&mut self) -> Option<web_transport_quinn::Request> {
158 loop {
159 tokio::select! {
160 res = self.quic.accept() => {
161 let conn = res?;
162 self.accept.push(Self::accept_session(conn).boxed());
163 }
164 Some(res) = self.accept.next() => {
165 if let Ok(session) = res {
166 return Some(session)
167 }
168 }
169 _ = tokio::signal::ctrl_c() => {
170 self.close();
171 tokio::time::sleep(Duration::from_millis(100)).await;
173
174 return None;
175 }
176 }
177 }
178 }
179
180 async fn accept_session(conn: quinn::Incoming) -> anyhow::Result<web_transport_quinn::Request> {
181 let mut conn = conn.accept()?;
182
183 let handshake = conn
184 .handshake_data()
185 .await?
186 .downcast::<quinn::crypto::rustls::HandshakeData>()
187 .unwrap();
188
189 let alpn = handshake.protocol.context("missing ALPN")?;
190 let alpn = String::from_utf8(alpn).context("failed to decode ALPN")?;
191 let host = handshake.server_name.unwrap_or_default();
192
193 tracing::debug!(%host, ip = %conn.remote_address(), %alpn, "accepting");
194
195 let conn = conn.await.context("failed to establish QUIC connection")?;
197
198 let span = tracing::Span::current();
199 span.record("id", conn.stable_id()); match alpn.as_str() {
202 web_transport_quinn::ALPN => {
203 web_transport_quinn::Request::accept(conn)
205 .await
206 .context("failed to receive WebTransport request")
207 }
208 _ => anyhow::bail!("unsupported ALPN: {}", alpn),
210 }
211 }
212
213 pub fn local_addr(&self) -> anyhow::Result<net::SocketAddr> {
214 self.quic.local_addr().context("failed to get local address")
215 }
216
217 pub fn close(&mut self) {
218 self.quic.close(quinn::VarInt::from_u32(0), b"server shutdown");
219 }
220}
221
/// Set of certificates served by this server; used as the rustls server
/// certificate resolver, which picks a certificate by SNI and otherwise falls
/// back to the first one.
#[derive(Debug, Default)]
struct ServeCerts {
    /// Certificates in the order they were loaded or generated; the first entry
    /// is the fallback when no certificate matches the client's server name.
    certs: Vec<Arc<CertifiedKey>>,
}
226
227impl ServeCerts {
228 pub fn load(&mut self, chain: &PathBuf, key: &PathBuf) -> anyhow::Result<()> {
230 let chain = fs::File::open(chain).context("failed to open cert file")?;
231 let mut chain = io::BufReader::new(chain);
232
233 let chain: Vec<CertificateDer> = rustls_pemfile::certs(&mut chain)
234 .collect::<Result<_, _>>()
235 .context("failed to read certs")?;
236
237 anyhow::ensure!(!chain.is_empty(), "could not find certificate");
238
239 let mut keys = fs::File::open(key).context("failed to open key file")?;
241
242 let mut buf = Vec::new();
244 keys.read_to_end(&mut buf)?;
245
246 let key = rustls_pemfile::private_key(&mut Cursor::new(&buf))?.context("missing private key")?;
247 let key = rustls::crypto::ring::sign::any_supported_type(&key)?;
248
249 self.certs.push(Arc::new(CertifiedKey::new(chain, key)));
250
251 Ok(())
252 }
253
254 pub fn generate(&mut self, hostnames: &[String]) -> anyhow::Result<()> {
255 let key_pair = rcgen::KeyPair::generate()?;
256
257 let mut params = rcgen::CertificateParams::new(hostnames)?;
258
259 params.not_before = time::OffsetDateTime::now_utc() - time::Duration::days(1);
262 params.not_after = params.not_before + time::Duration::days(14);
263
264 let cert = params.self_signed(&key_pair)?;
266
267 let key = PrivatePkcs8KeyDer::from(key_pair.serialized_der());
269 let key = any_supported_type(&key.into())?;
270
271 self.certs.push(Arc::new(CertifiedKey::new(vec![cert.into()], key)));
273
274 Ok(())
275 }
276
277 pub fn fingerprints(&self) -> Vec<String> {
279 self.certs
280 .iter()
281 .map(|ck| {
282 let fingerprint = digest(&SHA256, ck.cert[0].as_ref());
283 let fingerprint = hex::encode(fingerprint.as_ref());
284 fingerprint
285 })
286 .collect()
287 }
288
289 fn best_certificate(&self, client_hello: &ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
291 let server_name = client_hello.server_name()?;
292 let dns_name = webpki::DnsNameRef::try_from_ascii_str(server_name).ok()?;
293
294 for ck in &self.certs {
295 let leaf = ck.end_entity_cert().expect("missing certificate");
298 let parsed = webpki::EndEntityCert::try_from(leaf.as_ref()).expect("failed to parse certificate");
299
300 if parsed.verify_is_valid_for_dns_name(dns_name).is_ok() {
301 return Some(ck.clone());
302 }
303 }
304
305 None
306 }
307}
308
309impl ResolvesServerCert for ServeCerts {
310 fn resolve(&self, client_hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
311 if let Some(cert) = self.best_certificate(&client_hello) {
312 return Some(cert);
313 }
314
315 tracing::warn!(server_name = ?client_hello.server_name(), "no SNI certificate found");
318
319 self.certs.first().cloned()
320 }
321}