1mod priority;
4pub mod web_transport;
5use crate::{
6 ArcHandler, ArcedQuicEndpoint, BoxedBidiStream, QuicConnection, QuicTransportReceive,
7 QuicTransportSend, RuntimeTrait,
8};
9use priority::{PrioritizedStream, PriorityRegistry, transport_priority};
10use std::sync::Arc;
11use trillium::{Handler, KnownHeaderName, Listener, Upgrade};
12use trillium_http::{
13 HttpContext,
14 h3::{H3Connection, H3Error, H3ErrorCode, H3StreamResult, UniStreamResult},
15};
16use web_transport::{WebTransportDispatcher, WebTransportStream};
17
/// Identifier of the QUIC stream that carried a request.
///
/// A thin wrapper around the raw `u64` stream id. A value of this type is inserted into the
/// state of every conn served over an HTTP/3 bidirectional stream. Convert to and from `u64`
/// with `From`/`Into`.
///
/// Equality, ordering and hashing all follow the wrapped integer, so a `StreamId` can be
/// compared directly and used as a key in maps and sets.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct StreamId(u64);
21impl From<StreamId> for u64 {
22 fn from(val: StreamId) -> Self {
23 val.0
24 }
25}
26
27impl From<u64> for StreamId {
28 fn from(value: u64) -> Self {
29 Self(value)
30 }
31}
32
33pub(crate) async fn run_h3(
34 quic_binding: ArcedQuicEndpoint,
35 context: Arc<HttpContext>,
36 handler: ArcHandler<impl Handler>,
37 runtime: impl RuntimeTrait,
38 listener: Option<Listener>,
39 local_alt_svc: Option<&'static str>,
40) {
41 let swansong = context.swansong();
42 while let Some(connection) = swansong.interrupt(quic_binding.accept()).await.flatten() {
43 let h3 = H3Connection::new(context.clone());
44 let handler = handler.clone();
45 let runtime = runtime.clone();
46 runtime.clone().spawn(run_h3_connection(
47 connection,
48 h3,
49 handler,
50 runtime,
51 listener.clone(),
52 local_alt_svc,
53 ));
54 }
55}
56
57async fn run_h3_connection(
58 connection: QuicConnection,
59 h3: Arc<H3Connection>,
60 handler: ArcHandler<impl Handler>,
61 runtime: impl RuntimeTrait,
62 listener: Option<Listener>,
63 local_alt_svc: Option<&'static str>,
64) {
65 let wt_dispatcher = h3
66 .context()
67 .config()
68 .webtransport_enabled()
69 .then(WebTransportDispatcher::new);
70
71 log::trace!("new quic connection from {}", connection.remote_address());
72
73 let priorities = PriorityRegistry::default();
74 h3.register_priority_callback({
75 let priorities = priorities.clone();
76 move |stream_id, priority, is_update| {
77 priorities.apply(stream_id, transport_priority(priority), is_update)
78 }
79 });
80
81 spawn_outbound_control_stream(&connection, &h3, &runtime);
82 spawn_qpack_encoder_stream(&connection, &h3, &runtime);
83 spawn_qpack_decoder_stream(&connection, &h3, &runtime);
84 spawn_inbound_uni_streams(&connection, &h3, &runtime, &wt_dispatcher);
85 handle_inbound_bidi_streams(
86 connection,
87 h3.clone(),
88 handler,
89 runtime,
90 wt_dispatcher,
91 listener,
92 local_alt_svc,
93 priorities,
94 )
95 .await;
96}
97
98#[allow(clippy::too_many_arguments)]
99async fn handle_inbound_bidi_streams(
100 connection: QuicConnection,
101 h3: Arc<H3Connection>,
102 handler: ArcHandler<impl Handler>,
103 runtime: impl RuntimeTrait,
104 wt_dispatcher: Option<WebTransportDispatcher>,
105 listener: Option<Listener>,
106 local_alt_svc: Option<&'static str>,
107 priorities: PriorityRegistry,
108) {
109 loop {
110 match h3.swansong().interrupt(connection.accept_bidi()).await {
111 None => {
112 log::trace!("H3 bidi accept loop: interrupted by swansong shutdown");
113 break;
114 }
115 Some(Err(e)) => {
116 log::debug!("H3 bidi accept loop: accept_bidi error: {e}");
117 break;
118 }
119 Some(Ok((stream_id, transport))) => {
120 handle_bidi_stream(
121 stream_id,
122 transport,
123 &h3,
124 &handler,
125 &connection,
126 &runtime,
127 &wt_dispatcher,
128 listener.clone(),
129 local_alt_svc,
130 &priorities,
131 );
132 }
133 }
134 }
135
136 h3.shut_down();
137}
138
139#[allow(clippy::too_many_arguments)]
140fn handle_bidi_stream(
141 stream_id: u64,
142 transport: BoxedBidiStream,
143 h3: &Arc<H3Connection>,
144 handler: &ArcHandler<impl Handler>,
145 connection: &QuicConnection,
146 runtime: &impl RuntimeTrait,
147 wt_dispatcher: &Option<WebTransportDispatcher>,
148 listener: Option<Listener>,
149 local_alt_svc: Option<&'static str>,
150 priorities: &PriorityRegistry,
151) {
152 log::trace!("H3 bidi stream {stream_id}: spawning handler task");
153 let (h3, handler, connection, wt_dispatcher, priorities) = (
154 h3.clone(),
155 handler.clone(),
156 connection.clone(),
157 wt_dispatcher.clone(),
158 priorities.clone(),
159 );
160
161 let slot = priorities.register(stream_id);
165 let transport: BoxedBidiStream = Box::new(PrioritizedStream::new(transport, slot, stream_id));
166
167 runtime.spawn(async move {
168 let peer_ip = connection.remote_address().ip();
169 let quic_connection = connection.clone();
170 let wt_dispatcher = wt_dispatcher.clone();
171
172 let handler_fn = {
173 let handler = handler.clone();
174 let wt_dispatcher = wt_dispatcher.clone();
175 move |mut conn: trillium_http::Conn<_>| async move {
176 conn.set_peer_ip(Some(peer_ip));
177 conn.set_secure(true);
178
179 let state = conn.state_mut();
180 state.insert(quic_connection);
181 state.insert(StreamId(stream_id));
182 if let Some(listener) = listener {
183 if let Some(addr) = listener.socket_addr() {
184 state.insert(addr);
185 }
186 state.insert(listener);
187 }
188 if let Some(dispatcher) = wt_dispatcher {
189 state.insert(dispatcher);
190 }
191 if let Some(alt_svc) = local_alt_svc {
192 conn.response_headers_mut()
193 .try_insert(KnownHeaderName::AltSvc, alt_svc);
194 }
195
196 let conn = handler.run(conn.into()).await;
197 let conn = handler.before_send(conn).await;
198
199 conn.into_inner()
200 }
201 };
202
203 let result = h3
204 .clone()
205 .process_inbound_bidi(transport, handler_fn, stream_id)
206 .with_reset(|t, code| {
207 let raw = u64::from(code);
212 t.stop(raw);
213 t.reset(raw);
214 })
215 .await;
216
217 match result {
218 Ok(H3StreamResult::Request(conn)) if conn.should_upgrade() => {
219 let upgrade = Upgrade::from(conn);
220 if handler.has_upgrade(&upgrade) {
221 log::debug!("upgrading h3 stream");
222 handler.upgrade(upgrade).await;
223 } else {
224 log::error!("h3 upgrade specified but no upgrade handler provided");
225 }
226 }
227
228 Ok(H3StreamResult::Request(_)) => {}
229
230 Ok(H3StreamResult::WebTransport {
231 session_id,
232 mut transport,
233 buffer,
234 }) => {
235 if let Some(dispatcher) = &wt_dispatcher {
236 dispatcher.dispatch(WebTransportStream::Bidi {
237 session_id,
238 stream: Box::new(transport),
239 buffer: buffer.into(),
240 });
241 } else {
242 transport.stop(H3ErrorCode::StreamCreationError.into());
243 transport.reset(H3ErrorCode::StreamCreationError.into());
244 }
245 }
246
247 Err(error) => {
248 log::debug!("H3 bidi stream {stream_id}: error: {error}");
249 handle_h3_error(error, &connection, &h3);
250 }
251 }
252
253 priorities.deregister(stream_id);
254 });
255}
256
257fn spawn_inbound_uni_streams(
258 connection: &QuicConnection,
259 h3: &Arc<H3Connection>,
260 runtime: &impl RuntimeTrait,
261 wt_dispatcher: &Option<WebTransportDispatcher>,
262) {
263 let (connection, h3, runtime, wt_dispatcher) = (
264 connection.clone(),
265 h3.clone(),
266 runtime.clone(),
267 wt_dispatcher.clone(),
268 );
269 runtime.clone().spawn(async move {
270 while let Some(Ok((_stream_id, recv))) =
271 h3.swansong().interrupt(connection.accept_uni()).await
272 {
273 let (connection, h3, wt_dispatcher) =
274 (connection.clone(), h3.clone(), wt_dispatcher.clone());
275
276 runtime.spawn(async move {
277 let close_connection = {
285 let connection = connection.clone();
286 let h3 = h3.clone();
287 move |code: H3ErrorCode| {
288 connection.close(code.into(), code.reason().as_bytes());
289 h3.shut_down();
290 }
291 };
292 let result = h3
293 .process_inbound_uni_with_close(recv, close_connection)
294 .await;
295
296 match result {
297 Ok(UniStreamResult::Handled) => {}
298 Ok(UniStreamResult::WebTransport {
299 session_id,
300 mut stream,
301 buffer,
302 }) => {
303 if let Some(dispatcher) = &wt_dispatcher {
304 dispatcher.dispatch(WebTransportStream::Uni {
305 session_id,
306 stream: Box::new(stream),
307 buffer: buffer.into(),
308 });
309 } else {
310 stream.stop(H3ErrorCode::StreamCreationError.into());
311 }
312 }
313
314 Ok(UniStreamResult::Unknown { mut stream, .. }) => {
315 stream.stop(H3ErrorCode::StreamCreationError.into());
316 }
317
318 Err(error) => {
319 handle_h3_error(error, &connection, &h3);
323 }
324 }
325 });
326 }
327
328 h3.shut_down();
329 });
330}
331
332fn spawn_qpack_decoder_stream(
333 connection: &QuicConnection,
334 h3: &Arc<H3Connection>,
335 runtime: &impl RuntimeTrait,
336) {
337 let (connection, h3) = (connection.clone(), h3.clone());
338
339 runtime.spawn(async move {
340 log::trace!("H3: opening outbound QPACK decoder stream");
341 let stream = match connection.open_uni().await {
342 Ok((_stream_id, stream)) => stream,
343 Err(err) => {
344 log::error!("H3: open_uni for QPACK decoder stream failed: {err:?}");
345 h3.shut_down();
346 return;
347 }
348 };
349
350 let result = h3.run_decoder(stream).await;
351
352 if let Err(error) = result {
353 handle_h3_error(error, &connection, &h3);
354 }
355
356 h3.shut_down();
357 });
358}
359
360fn spawn_qpack_encoder_stream(
361 connection: &QuicConnection,
362 h3: &Arc<H3Connection>,
363 runtime: &impl RuntimeTrait,
364) {
365 let (connection, h3) = (connection.clone(), h3.clone());
366 runtime.spawn(async move {
367 log::trace!("H3: opening outbound QPACK encoder stream");
368 let stream = match connection.open_uni().await {
369 Ok((_stream_id, stream)) => stream,
370 Err(err) => {
371 log::error!("H3: open_uni for QPACK encoder stream failed: {err:?}");
372 h3.shut_down();
373 return;
374 }
375 };
376
377 let result = h3.run_encoder(stream).await;
378
379 if let Err(error) = result {
380 handle_h3_error(error, &connection, &h3);
381 }
382
383 h3.shut_down();
384 });
385}
386
387fn spawn_outbound_control_stream(
388 connection: &QuicConnection,
389 h3: &Arc<H3Connection>,
390 runtime: &impl RuntimeTrait,
391) {
392 let (connection, h3) = (connection.clone(), h3.clone());
393 runtime.spawn(async move {
394 log::trace!("H3: opening outbound control stream");
395 let stream = match connection.open_uni().await {
396 Ok((_stream_id, stream)) => stream,
397 Err(err) => {
398 log::error!("H3: open_uni for outbound control stream failed: {err:?}");
399 h3.shut_down();
400 return;
401 }
402 };
403
404 let result = h3.run_outbound_control(stream).await;
405
406 if let Err(error) = result {
407 handle_h3_error(error, &connection, &h3);
408 }
409
410 h3.shut_down();
411 });
412}
413
414fn handle_h3_error(error: H3Error, connection: &QuicConnection, h3: &H3Connection) {
415 log::debug!("H3 error: {error}");
416 if let H3Error::Protocol(code) = error
417 && code.is_connection_error()
418 {
419 connection.close(code.into(), code.reason().as_bytes());
420 h3.shut_down();
421 }
422}