1use std::{net, sync::Arc, time};
2
3use anyhow::Context;
4use clap::Parser;
5use url::Url;
6
7use crate::tls;
8
9use futures::future::BoxFuture;
10use futures::stream::{FuturesUnordered, StreamExt};
11use futures::FutureExt;
12
13use web_transport::quinn as web_transport_quinn;
14
// Command-line options for the QUIC endpoint; converted into a `Config` by `Args::load`.
#[derive(Parser, Clone)]
pub struct Args {
    // Local UDP address to bind. The default is the unspecified IPv6 address with port 0,
    // i.e. the OS picks the port.
    #[arg(long, default_value = "[::]:0")]
    pub bind: net::SocketAddr,

    // TLS options, flattened so they appear as top-level arguments of this command.
    #[command(flatten)]
    pub tls: tls::Args,
}
24
25impl Default for Args {
26 fn default() -> Self {
27 Self {
28 bind: "[::]:0".parse().unwrap(),
29 tls: Default::default(),
30 }
31 }
32}
33
34impl Args {
35 pub fn load(&self) -> anyhow::Result<Config> {
36 let tls = self.tls.load()?;
37 Ok(Config { bind: self.bind, tls })
38 }
39}
40
41#[derive(Clone)]
42pub struct Config {
43 pub bind: net::SocketAddr,
44 pub tls: tls::Config,
45}
46
47pub struct Endpoint {
48 pub client: Client,
49 pub server: Option<Server>,
50}
51
52impl Endpoint {
53 pub fn new(config: Config) -> anyhow::Result<Self> {
54 let mut transport = quinn::TransportConfig::default();
57 transport.max_idle_timeout(Some(time::Duration::from_secs(10).try_into().unwrap()));
58 transport.keep_alive_interval(Some(time::Duration::from_secs(4)));
59 transport.congestion_controller_factory(Arc::new(quinn::congestion::BbrConfig::default()));
60 transport.mtu_discovery_config(None); let transport = Arc::new(transport);
62
63 let mut server_config = None;
64
65 if let Some(mut config) = config.tls.server {
66 config.alpn_protocols = vec![
67 web_transport::quinn::ALPN.as_bytes().to_vec(),
68 moq_lite::ALPN.as_bytes().to_vec(),
69 ];
70 config.key_log = Arc::new(rustls::KeyLogFile::new());
71
72 let config: quinn::crypto::rustls::QuicServerConfig = config.try_into()?;
73 let mut config = quinn::ServerConfig::with_crypto(Arc::new(config));
74 config.transport_config(transport.clone());
75
76 server_config = Some(config);
77 }
78
79 let runtime = quinn::default_runtime().context("no async runtime")?;
81 let endpoint_config = quinn::EndpointConfig::default();
82 let socket = std::net::UdpSocket::bind(config.bind).context("failed to bind UDP socket")?;
83
84 let quic = quinn::Endpoint::new(endpoint_config, server_config.clone(), socket, runtime)
86 .context("failed to create QUIC endpoint")?;
87
88 let server = server_config.is_some().then(|| Server {
89 quic: quic.clone(),
90 accept: Default::default(),
91 });
92
93 let client = Client {
94 quic,
95 config: config.tls.client,
96 transport,
97 };
98
99 Ok(Self { client, server })
100 }
101}
102
103pub struct Server {
104 quic: quinn::Endpoint,
105 accept: FuturesUnordered<BoxFuture<'static, anyhow::Result<web_transport_quinn::Session>>>,
106}
107
108impl Server {
109 pub async fn accept(&mut self) -> Option<web_transport_quinn::Session> {
110 loop {
111 tokio::select! {
112 res = self.quic.accept() => {
113 let conn = res?;
114 self.accept.push(Self::accept_session(conn).boxed());
115 }
116 Some(res) = self.accept.next() => {
117 if let Ok(session) = res {
118 return Some(session)
119 }
120 }
121 _ = tokio::signal::ctrl_c() => {
122 self.close();
123 tokio::time::sleep(std::time::Duration::from_millis(100)).await;
125
126 return None;
127 }
128 }
129 }
130 }
131
132 async fn accept_session(conn: quinn::Incoming) -> anyhow::Result<web_transport_quinn::Session> {
133 let mut conn = conn.accept()?;
134
135 let handshake = conn
136 .handshake_data()
137 .await?
138 .downcast::<quinn::crypto::rustls::HandshakeData>()
139 .unwrap();
140
141 let alpn = handshake.protocol.context("missing ALPN")?;
142 let alpn = String::from_utf8(alpn).context("failed to decode ALPN")?;
143 let host = handshake.server_name.unwrap_or_default();
144
145 tracing::debug!(%host, ip = %conn.remote_address(), %alpn, "accepting");
146
147 let conn = conn.await.context("failed to establish QUIC connection")?;
149
150 let span = tracing::Span::current();
151 span.record("id", conn.stable_id()); let session = match alpn.as_str() {
154 web_transport::quinn::ALPN => {
155 let request = web_transport::quinn::Request::accept(conn)
157 .await
158 .context("failed to receive WebTransport request")?;
159
160 request
162 .ok()
163 .await
164 .context("failed to respond to WebTransport request")?
165 }
166 moq_lite::ALPN => {
168 let url = Url::parse(format!("moql://{}", host).as_str()).unwrap();
170 web_transport::quinn::Session::raw(conn, url)
171 }
172 _ => anyhow::bail!("unsupported ALPN: {}", alpn),
173 };
174
175 Ok(session)
176 }
177
178 pub fn local_addr(&self) -> anyhow::Result<net::SocketAddr> {
179 self.quic.local_addr().context("failed to get local address")
180 }
181
182 pub fn close(&mut self) {
183 self.quic.close(quinn::VarInt::from_u32(0), b"server shutdown");
184 }
185}
186
187#[derive(Clone)]
188pub struct Client {
189 quic: quinn::Endpoint,
190 config: rustls::ClientConfig,
191 transport: Arc<quinn::TransportConfig>,
192}
193
194impl Client {
195 pub async fn connect(&self, mut url: Url) -> anyhow::Result<web_transport_quinn::Session> {
196 let mut config = self.config.clone();
197
198 let host = url.host().context("invalid DNS name")?.to_string();
199 let port = url.port().unwrap_or(443);
200
201 let ip = tokio::net::lookup_host((host.clone(), port))
203 .await
204 .context("failed DNS lookup")?
205 .next()
206 .context("no DNS entries")?;
207
208 if url.scheme() == "http" {
209 let mut fingerprint = url.clone();
211 fingerprint.set_path("/certificate.sha256");
212
213 tracing::warn!(url = %fingerprint, "performing insecure HTTP request for certificate");
214
215 let resp = reqwest::get(fingerprint.as_str())
216 .await
217 .context("failed to fetch fingerprint")?
218 .error_for_status()
219 .context("fingerprint request failed")?;
220
221 let fingerprint = resp.text().await.context("failed to read fingerprint")?;
222 let fingerprint = hex::decode(fingerprint.trim()).context("invalid fingerprint")?;
223
224 let verifier = tls::FingerprintVerifier::new(config.crypto_provider().clone(), fingerprint);
225 config.dangerous().set_certificate_verifier(Arc::new(verifier));
226
227 url.set_scheme("https").expect("failed to set scheme");
228 }
229
230 let alpn = match url.scheme() {
231 "https" => web_transport::quinn::ALPN,
232 "moql" => moq_lite::ALPN,
233 _ => anyhow::bail!("url scheme must be 'http', 'https', or 'moql'"),
234 };
235
236 config.alpn_protocols = vec![alpn.as_bytes().to_vec()];
238 config.key_log = Arc::new(rustls::KeyLogFile::new());
239
240 let config: quinn::crypto::rustls::QuicClientConfig = config.try_into()?;
241 let mut config = quinn::ClientConfig::new(Arc::new(config));
242 config.transport_config(self.transport.clone());
243
244 tracing::debug!(%url, %ip, %alpn, "connecting");
245
246 let connection = self.quic.connect_with(config, ip, &host)?.await?;
247 tracing::Span::current().record("id", connection.stable_id());
248
249 let session = match url.scheme() {
250 "https" => web_transport::quinn::Session::connect(connection, url).await?,
251 moq_lite::ALPN => web_transport::quinn::Session::raw(connection, url),
252 _ => unreachable!(),
253 };
254
255 Ok(session)
256 }
257}