1use std::path::PathBuf;
2use std::{net, sync::Arc, time::Duration};
3
4use anyhow::Context;
5use ring::digest::{digest, SHA256};
6use rustls::crypto::ring::sign::any_supported_type;
7use rustls::pki_types::{CertificateDer, PrivatePkcs8KeyDer};
8use rustls::server::{ClientHello, ResolvesServerCert};
9use rustls::sign::CertifiedKey;
10use std::fs;
11use std::io::{self, Cursor, Read};
12use url::Url;
13
14use futures::future::BoxFuture;
15use futures::stream::{FuturesUnordered, StreamExt};
16use futures::FutureExt;
17
18use web_transport::quinn as web_transport_quinn;
19
20#[derive(clap::Args, Clone, Debug, serde::Serialize, serde::Deserialize)]
21#[serde(deny_unknown_fields)]
22pub struct ServerTlsCert {
23 pub chain: PathBuf,
24 pub key: PathBuf,
25}
26
27impl ServerTlsCert {
28 pub fn parse(s: &str) -> anyhow::Result<Self> {
30 let (chain, key) = s.split_once(':').context("invalid certificate")?;
31 Ok(Self {
32 chain: PathBuf::from(chain),
33 key: PathBuf::from(key),
34 })
35 }
36}
37
38#[derive(clap::Args, Clone, Default, Debug, serde::Serialize, serde::Deserialize)]
39#[serde(deny_unknown_fields)]
40pub struct ServerTlsConfig {
41 #[arg(long = "tls-cert", value_parser = ServerTlsCert::parse)]
43 #[serde(default, skip_serializing_if = "Vec::is_empty")]
44 pub cert: Vec<ServerTlsCert>,
45
46 #[arg(long = "tls-generate")]
49 #[serde(default, skip_serializing_if = "Vec::is_empty")]
50 pub generate: Vec<String>,
51}
52
53#[derive(clap::Args, Clone, Debug, Default, serde::Serialize, serde::Deserialize)]
54#[serde(deny_unknown_fields, default)]
55pub struct ServerConfig {
56 #[arg(long)]
59 pub listen: Option<net::SocketAddr>,
60
61 #[command(flatten)]
62 #[serde(default)]
63 pub tls: ServerTlsConfig,
64}
65
66impl ServerConfig {
67 pub fn init(self) -> anyhow::Result<Server> {
68 Server::new(self)
69 }
70}
71
72pub struct Server {
73 quic: quinn::Endpoint,
74 accept: FuturesUnordered<BoxFuture<'static, anyhow::Result<web_transport_quinn::Session>>>,
75 fingerprints: Vec<String>,
76}
77
78impl Server {
79 pub fn new(config: ServerConfig) -> anyhow::Result<Self> {
80 let mut transport = quinn::TransportConfig::default();
83 transport.max_idle_timeout(Some(Duration::from_secs(10).try_into().unwrap()));
84 transport.keep_alive_interval(Some(Duration::from_secs(4)));
85 transport.congestion_controller_factory(Arc::new(quinn::congestion::BbrConfig::default()));
86 transport.mtu_discovery_config(None); let transport = Arc::new(transport);
88
89 let provider = Arc::new(rustls::crypto::aws_lc_rs::default_provider());
90 let mut serve = ServeCerts::default();
91
92 for cert in &config.tls.cert {
94 serve.load(&cert.chain, &cert.key)?;
95 }
96
97 if !config.tls.generate.is_empty() {
98 serve.generate(&config.tls.generate)?;
99 }
100
101 let fingerprints = serve.fingerprints();
102
103 let mut tls = rustls::ServerConfig::builder_with_provider(provider)
104 .with_protocol_versions(&[&rustls::version::TLS13])?
105 .with_no_client_auth()
106 .with_cert_resolver(Arc::new(serve));
107
108 tls.alpn_protocols = vec![
109 web_transport::quinn::ALPN.as_bytes().to_vec(),
110 moq_lite::ALPN.as_bytes().to_vec(),
111 ];
112 tls.key_log = Arc::new(rustls::KeyLogFile::new());
113
114 let tls: quinn::crypto::rustls::QuicServerConfig = tls.try_into()?;
115 let mut tls = quinn::ServerConfig::with_crypto(Arc::new(tls));
116 tls.transport_config(transport.clone());
117
118 let runtime = quinn::default_runtime().context("no async runtime")?;
120 let endpoint_config = quinn::EndpointConfig::default();
121
122 let listen = config.listen.unwrap_or("[::]:443".parse().unwrap());
123 let socket = std::net::UdpSocket::bind(listen).context("failed to bind UDP socket")?;
124
125 let quic = quinn::Endpoint::new(endpoint_config, Some(tls), socket, runtime)
127 .context("failed to create QUIC endpoint")?;
128
129 Ok(Self {
130 quic: quic.clone(),
131 accept: Default::default(),
132 fingerprints,
133 })
134 }
135
136 pub fn fingerprints(&self) -> &[String] {
137 &self.fingerprints
138 }
139
140 pub async fn accept(&mut self) -> Option<web_transport_quinn::Session> {
141 loop {
142 tokio::select! {
143 res = self.quic.accept() => {
144 let conn = res?;
145 self.accept.push(Self::accept_session(conn).boxed());
146 }
147 Some(res) = self.accept.next() => {
148 if let Ok(session) = res {
149 return Some(session)
150 }
151 }
152 _ = tokio::signal::ctrl_c() => {
153 self.close();
154 tokio::time::sleep(Duration::from_millis(100)).await;
156
157 return None;
158 }
159 }
160 }
161 }
162
163 async fn accept_session(conn: quinn::Incoming) -> anyhow::Result<web_transport_quinn::Session> {
164 let mut conn = conn.accept()?;
165
166 let handshake = conn
167 .handshake_data()
168 .await?
169 .downcast::<quinn::crypto::rustls::HandshakeData>()
170 .unwrap();
171
172 let alpn = handshake.protocol.context("missing ALPN")?;
173 let alpn = String::from_utf8(alpn).context("failed to decode ALPN")?;
174 let host = handshake.server_name.unwrap_or_default();
175
176 tracing::debug!(%host, ip = %conn.remote_address(), %alpn, "accepting");
177
178 let conn = conn.await.context("failed to establish QUIC connection")?;
180
181 let span = tracing::Span::current();
182 span.record("id", conn.stable_id()); let session = match alpn.as_str() {
185 web_transport::quinn::ALPN => {
186 let request = web_transport::quinn::Request::accept(conn)
188 .await
189 .context("failed to receive WebTransport request")?;
190
191 request
193 .ok()
194 .await
195 .context("failed to respond to WebTransport request")?
196 }
197 moq_lite::ALPN => {
199 let url = Url::parse(format!("moql://{}", host).as_str()).unwrap();
201 web_transport::quinn::Session::raw(conn, url)
202 }
203 _ => anyhow::bail!("unsupported ALPN: {}", alpn),
204 };
205
206 Ok(session)
207 }
208
209 pub fn local_addr(&self) -> anyhow::Result<net::SocketAddr> {
210 self.quic.local_addr().context("failed to get local address")
211 }
212
213 pub fn close(&mut self) {
214 self.quic.close(quinn::VarInt::from_u32(0), b"server shutdown");
215 }
216}
217
218#[derive(Debug, Default)]
219struct ServeCerts {
220 certs: Vec<Arc<CertifiedKey>>,
221}
222
223impl ServeCerts {
224 pub fn load(&mut self, chain: &PathBuf, key: &PathBuf) -> anyhow::Result<()> {
226 let chain = fs::File::open(chain).context("failed to open cert file")?;
227 let mut chain = io::BufReader::new(chain);
228
229 let chain: Vec<CertificateDer> = rustls_pemfile::certs(&mut chain)
230 .collect::<Result<_, _>>()
231 .context("failed to read certs")?;
232
233 anyhow::ensure!(!chain.is_empty(), "could not find certificate");
234
235 let mut keys = fs::File::open(key).context("failed to open key file")?;
237
238 let mut buf = Vec::new();
240 keys.read_to_end(&mut buf)?;
241
242 let key = rustls_pemfile::private_key(&mut Cursor::new(&buf))?.context("missing private key")?;
243 let key = rustls::crypto::ring::sign::any_supported_type(&key)?;
244
245 self.certs.push(Arc::new(CertifiedKey::new(chain, key)));
246
247 Ok(())
248 }
249
250 pub fn generate(&mut self, hostnames: &[String]) -> anyhow::Result<()> {
251 let key_pair = rcgen::KeyPair::generate()?;
252
253 let mut params = rcgen::CertificateParams::new(hostnames)?;
254
255 params.not_before = time::OffsetDateTime::now_utc() - time::Duration::days(1);
258 params.not_after = params.not_before + time::Duration::days(14);
259
260 let cert = params.self_signed(&key_pair)?;
262
263 let key = PrivatePkcs8KeyDer::from(key_pair.serialized_der());
265 let key = any_supported_type(&key.into())?;
266
267 self.certs.push(Arc::new(CertifiedKey::new(vec![cert.into()], key)));
269
270 Ok(())
271 }
272
273 pub fn fingerprints(&self) -> Vec<String> {
275 self.certs
276 .iter()
277 .map(|ck| {
278 let fingerprint = digest(&SHA256, ck.cert[0].as_ref());
279 let fingerprint = hex::encode(fingerprint.as_ref());
280 fingerprint
281 })
282 .collect()
283 }
284}
285
286impl ResolvesServerCert for ServeCerts {
287 fn resolve(&self, client_hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
288 let server_name = client_hello.server_name()?;
289 let dns_name = webpki::DnsNameRef::try_from_ascii_str(server_name).ok()?;
290
291 for ck in &self.certs {
292 let leaf = ck.end_entity_cert().expect("missing certificate");
295 let parsed = webpki::EndEntityCert::try_from(leaf.as_ref()).expect("failed to parse certificate");
296
297 if parsed.verify_is_valid_for_dns_name(dns_name).is_ok() {
298 return Some(ck.clone());
299 }
300 }
301
302 tracing::warn!(?server_name, "no certificate found for server name");
303
304 None
305 }
306}