1use std::path::PathBuf;
2use std::{net, sync::Arc, time::Duration};
3
4use anyhow::Context;
5use ring::digest::{digest, SHA256};
6use rustls::crypto::ring::sign::any_supported_type;
7use rustls::pki_types::{CertificateDer, PrivatePkcs8KeyDer};
8use rustls::server::{ClientHello, ResolvesServerCert};
9use rustls::sign::CertifiedKey;
10use std::fs;
11use std::io::{self, Cursor, Read};
12
13use futures::future::BoxFuture;
14use futures::stream::{FuturesUnordered, StreamExt};
15use futures::FutureExt;
16
17use web_transport::quinn as web_transport_quinn;
18
// A certificate chain / private key pair stored on disk.
// Built from a `<chain>:<key>` string by `ServerTlsCert::parse` (the `--tls-cert` value parser).
#[derive(clap::Args, Clone, Debug, serde::Serialize, serde::Deserialize)]
#[serde(deny_unknown_fields)]
pub struct ServerTlsCert {
    /// Path to the PEM-encoded certificate chain (leaf certificate first).
    pub chain: PathBuf,
    /// Path to the PEM-encoded private key matching the leaf certificate.
    pub key: PathBuf,
}
25
26impl ServerTlsCert {
27 pub fn parse(s: &str) -> anyhow::Result<Self> {
29 let (chain, key) = s.split_once(':').context("invalid certificate")?;
30 Ok(Self {
31 chain: PathBuf::from(chain),
32 key: PathBuf::from(key),
33 })
34 }
35}
36
37#[derive(clap::Args, Clone, Default, Debug, serde::Serialize, serde::Deserialize)]
38#[serde(deny_unknown_fields)]
39pub struct ServerTlsConfig {
40 #[arg(long = "tls-cert", value_parser = ServerTlsCert::parse, value_delimiter = ',', env = "MOQ_SERVER_TLS_CERT")]
42 #[serde(default, skip_serializing_if = "Vec::is_empty")]
43 pub cert: Vec<ServerTlsCert>,
44
45 #[arg(long = "tls-generate", value_delimiter = ',', env = "MOQ_SERVER_TLS_GENERATE")]
48 #[serde(default, skip_serializing_if = "Vec::is_empty")]
49 pub generate: Vec<String>,
50}
51
52#[derive(clap::Args, Clone, Debug, Default, serde::Serialize, serde::Deserialize)]
53#[serde(deny_unknown_fields, default)]
54pub struct ServerConfig {
55 #[arg(long, env = "MOQ_SERVER_LISTEN")]
58 pub listen: Option<net::SocketAddr>,
59
60 #[command(flatten)]
61 #[serde(default)]
62 pub tls: ServerTlsConfig,
63}
64
65impl ServerConfig {
66 pub fn init(self) -> anyhow::Result<Server> {
67 Server::new(self)
68 }
69}
70
/// A QUIC endpoint that accepts incoming connections and yields WebTransport session requests.
pub struct Server {
    // The QUIC endpoint bound to the configured UDP socket.
    quic: quinn::Endpoint,
    // Handshakes in progress; each resolves to a WebTransport request or an error.
    accept: FuturesUnordered<BoxFuture<'static, anyhow::Result<web_transport_quinn::Request>>>,
    // Hex-encoded SHA-256 digests of each served leaf certificate (see `ServeCerts::fingerprints`).
    fingerprints: Vec<String>,
}
76
77impl Server {
78 pub fn new(config: ServerConfig) -> anyhow::Result<Self> {
79 let mut transport = quinn::TransportConfig::default();
82 transport.max_idle_timeout(Some(Duration::from_secs(10).try_into().unwrap()));
83 transport.keep_alive_interval(Some(Duration::from_secs(4)));
84 transport.congestion_controller_factory(Arc::new(quinn::congestion::BbrConfig::default()));
85 transport.mtu_discovery_config(None); let transport = Arc::new(transport);
87
88 let provider = Arc::new(rustls::crypto::aws_lc_rs::default_provider());
89 let mut serve = ServeCerts::default();
90
91 for cert in &config.tls.cert {
93 serve.load(&cert.chain, &cert.key)?;
94 }
95
96 if !config.tls.generate.is_empty() {
97 serve.generate(&config.tls.generate)?;
98 }
99
100 let fingerprints = serve.fingerprints();
101
102 let mut tls = rustls::ServerConfig::builder_with_provider(provider)
103 .with_protocol_versions(&[&rustls::version::TLS13])?
104 .with_no_client_auth()
105 .with_cert_resolver(Arc::new(serve));
106
107 tls.alpn_protocols = vec![
108 web_transport::quinn::ALPN.as_bytes().to_vec(),
109 moq_lite::ALPN.as_bytes().to_vec(),
110 ];
111 tls.key_log = Arc::new(rustls::KeyLogFile::new());
112
113 let tls: quinn::crypto::rustls::QuicServerConfig = tls.try_into()?;
114 let mut tls = quinn::ServerConfig::with_crypto(Arc::new(tls));
115 tls.transport_config(transport.clone());
116
117 let runtime = quinn::default_runtime().context("no async runtime")?;
119 let endpoint_config = quinn::EndpointConfig::default();
120
121 let listen = config.listen.unwrap_or("[::]:443".parse().unwrap());
122 let socket = std::net::UdpSocket::bind(listen).context("failed to bind UDP socket")?;
123
124 let quic = quinn::Endpoint::new(endpoint_config, Some(tls), socket, runtime)
126 .context("failed to create QUIC endpoint")?;
127
128 Ok(Self {
129 quic: quic.clone(),
130 accept: Default::default(),
131 fingerprints,
132 })
133 }
134
135 pub fn fingerprints(&self) -> &[String] {
136 &self.fingerprints
137 }
138
139 pub async fn accept(&mut self) -> Option<web_transport_quinn::Request> {
145 loop {
146 tokio::select! {
147 res = self.quic.accept() => {
148 let conn = res?;
149 self.accept.push(Self::accept_session(conn).boxed());
150 }
151 Some(res) = self.accept.next() => {
152 if let Ok(session) = res {
153 return Some(session)
154 }
155 }
156 _ = tokio::signal::ctrl_c() => {
157 self.close();
158 tokio::time::sleep(Duration::from_millis(100)).await;
160
161 return None;
162 }
163 }
164 }
165 }
166
167 async fn accept_session(conn: quinn::Incoming) -> anyhow::Result<web_transport_quinn::Request> {
168 let mut conn = conn.accept()?;
169
170 let handshake = conn
171 .handshake_data()
172 .await?
173 .downcast::<quinn::crypto::rustls::HandshakeData>()
174 .unwrap();
175
176 let alpn = handshake.protocol.context("missing ALPN")?;
177 let alpn = String::from_utf8(alpn).context("failed to decode ALPN")?;
178 let host = handshake.server_name.unwrap_or_default();
179
180 tracing::debug!(%host, ip = %conn.remote_address(), %alpn, "accepting");
181
182 let conn = conn.await.context("failed to establish QUIC connection")?;
184
185 let span = tracing::Span::current();
186 span.record("id", conn.stable_id()); match alpn.as_str() {
189 web_transport::quinn::ALPN => {
190 web_transport::quinn::Request::accept(conn)
192 .await
193 .context("failed to receive WebTransport request")
194 }
195 _ => anyhow::bail!("unsupported ALPN: {}", alpn),
197 }
198 }
199
200 pub fn local_addr(&self) -> anyhow::Result<net::SocketAddr> {
201 self.quic.local_addr().context("failed to get local address")
202 }
203
204 pub fn close(&mut self) {
205 self.quic.close(quinn::VarInt::from_u32(0), b"server shutdown");
206 }
207}
208
/// The certificates the server can present; one is chosen per handshake from the client's SNI.
#[derive(Debug, Default)]
struct ServeCerts {
    // Loaded and/or generated certificates, in insertion order. The first entry is the
    // fallback when no certificate matches the requested server name.
    certs: Vec<Arc<CertifiedKey>>,
}
213
214impl ServeCerts {
215 pub fn load(&mut self, chain: &PathBuf, key: &PathBuf) -> anyhow::Result<()> {
217 let chain = fs::File::open(chain).context("failed to open cert file")?;
218 let mut chain = io::BufReader::new(chain);
219
220 let chain: Vec<CertificateDer> = rustls_pemfile::certs(&mut chain)
221 .collect::<Result<_, _>>()
222 .context("failed to read certs")?;
223
224 anyhow::ensure!(!chain.is_empty(), "could not find certificate");
225
226 let mut keys = fs::File::open(key).context("failed to open key file")?;
228
229 let mut buf = Vec::new();
231 keys.read_to_end(&mut buf)?;
232
233 let key = rustls_pemfile::private_key(&mut Cursor::new(&buf))?.context("missing private key")?;
234 let key = rustls::crypto::ring::sign::any_supported_type(&key)?;
235
236 self.certs.push(Arc::new(CertifiedKey::new(chain, key)));
237
238 Ok(())
239 }
240
241 pub fn generate(&mut self, hostnames: &[String]) -> anyhow::Result<()> {
242 let key_pair = rcgen::KeyPair::generate()?;
243
244 let mut params = rcgen::CertificateParams::new(hostnames)?;
245
246 params.not_before = time::OffsetDateTime::now_utc() - time::Duration::days(1);
249 params.not_after = params.not_before + time::Duration::days(14);
250
251 let cert = params.self_signed(&key_pair)?;
253
254 let key = PrivatePkcs8KeyDer::from(key_pair.serialized_der());
256 let key = any_supported_type(&key.into())?;
257
258 self.certs.push(Arc::new(CertifiedKey::new(vec![cert.into()], key)));
260
261 Ok(())
262 }
263
264 pub fn fingerprints(&self) -> Vec<String> {
266 self.certs
267 .iter()
268 .map(|ck| {
269 let fingerprint = digest(&SHA256, ck.cert[0].as_ref());
270 let fingerprint = hex::encode(fingerprint.as_ref());
271 fingerprint
272 })
273 .collect()
274 }
275
276 fn best_certificate(&self, client_hello: &ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
278 let server_name = client_hello.server_name()?;
279 let dns_name = webpki::DnsNameRef::try_from_ascii_str(server_name).ok()?;
280
281 for ck in &self.certs {
282 let leaf = ck.end_entity_cert().expect("missing certificate");
285 let parsed = webpki::EndEntityCert::try_from(leaf.as_ref()).expect("failed to parse certificate");
286
287 if parsed.verify_is_valid_for_dns_name(dns_name).is_ok() {
288 return Some(ck.clone());
289 }
290 }
291
292 None
293 }
294}
295
296impl ResolvesServerCert for ServeCerts {
297 fn resolve(&self, client_hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
298 if let Some(cert) = self.best_certificate(&client_hello) {
299 return Some(cert);
300 }
301
302 tracing::warn!(server_name = ?client_hello.server_name(), "no SNI certificate found");
305
306 self.certs.first().cloned()
307 }
308}