Skip to main content

trillium_server_common/
h3.rs

1//! HTTP/3 specific exports
2
3mod priority;
4pub mod web_transport;
5use crate::{
6    ArcHandler, ArcedQuicEndpoint, BoxedBidiStream, QuicConnection, QuicTransportReceive,
7    QuicTransportSend, RuntimeTrait,
8};
9use priority::{PrioritizedStream, PriorityRegistry, transport_priority};
10use std::sync::Arc;
11use trillium::{Handler, KnownHeaderName, Listener, Upgrade};
12use trillium_http::{
13    HttpContext,
14    h3::{H3Connection, H3Error, H3ErrorCode, H3StreamResult, UniStreamResult},
15};
16use web_transport::{WebTransportDispatcher, WebTransportStream};
17
/// A QUIC stream identifier.
///
/// The HTTP/3 server inserts one into the state of each request it serves, holding the id of
/// the QUIC stream that request arrived on. Convert to and from the raw `u64` with [`From`].
///
/// Equality, hashing and ordering are derived so handlers can compare ids, key maps by them,
/// or sort them without first unwrapping to `u64`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct StreamId(u64);

impl From<StreamId> for u64 {
    fn from(val: StreamId) -> Self {
        val.0
    }
}

impl From<u64> for StreamId {
    fn from(value: u64) -> Self {
        Self(value)
    }
}
32
33pub(crate) async fn run_h3(
34    quic_binding: ArcedQuicEndpoint,
35    context: Arc<HttpContext>,
36    handler: ArcHandler<impl Handler>,
37    runtime: impl RuntimeTrait,
38    listener: Option<Listener>,
39    local_alt_svc: Option<&'static str>,
40) {
41    let swansong = context.swansong();
42    while let Some(connection) = swansong.interrupt(quic_binding.accept()).await.flatten() {
43        let h3 = H3Connection::new(context.clone());
44        let handler = handler.clone();
45        let runtime = runtime.clone();
46        runtime.clone().spawn(run_h3_connection(
47            connection,
48            h3,
49            handler,
50            runtime,
51            listener.clone(),
52            local_alt_svc,
53        ));
54    }
55}
56
57async fn run_h3_connection(
58    connection: QuicConnection,
59    h3: Arc<H3Connection>,
60    handler: ArcHandler<impl Handler>,
61    runtime: impl RuntimeTrait,
62    listener: Option<Listener>,
63    local_alt_svc: Option<&'static str>,
64) {
65    let wt_dispatcher = h3
66        .context()
67        .config()
68        .webtransport_enabled()
69        .then(WebTransportDispatcher::new);
70
71    log::trace!("new quic connection from {}", connection.remote_address());
72
73    let priorities = PriorityRegistry::default();
74    h3.register_priority_callback({
75        let priorities = priorities.clone();
76        move |stream_id, priority, is_update| {
77            priorities.apply(stream_id, transport_priority(priority), is_update)
78        }
79    });
80
81    spawn_outbound_control_stream(&connection, &h3, &runtime);
82    spawn_qpack_encoder_stream(&connection, &h3, &runtime);
83    spawn_qpack_decoder_stream(&connection, &h3, &runtime);
84    spawn_inbound_uni_streams(&connection, &h3, &runtime, &wt_dispatcher);
85    handle_inbound_bidi_streams(
86        connection,
87        h3.clone(),
88        handler,
89        runtime,
90        wt_dispatcher,
91        listener,
92        local_alt_svc,
93        priorities,
94    )
95    .await;
96}
97
98#[allow(clippy::too_many_arguments)]
99async fn handle_inbound_bidi_streams(
100    connection: QuicConnection,
101    h3: Arc<H3Connection>,
102    handler: ArcHandler<impl Handler>,
103    runtime: impl RuntimeTrait,
104    wt_dispatcher: Option<WebTransportDispatcher>,
105    listener: Option<Listener>,
106    local_alt_svc: Option<&'static str>,
107    priorities: PriorityRegistry,
108) {
109    loop {
110        match h3.swansong().interrupt(connection.accept_bidi()).await {
111            None => {
112                log::trace!("H3 bidi accept loop: interrupted by swansong shutdown");
113                break;
114            }
115            Some(Err(e)) => {
116                log::debug!("H3 bidi accept loop: accept_bidi error: {e}");
117                break;
118            }
119            Some(Ok((stream_id, transport))) => {
120                handle_bidi_stream(
121                    stream_id,
122                    transport,
123                    &h3,
124                    &handler,
125                    &connection,
126                    &runtime,
127                    &wt_dispatcher,
128                    listener.clone(),
129                    local_alt_svc,
130                    &priorities,
131                );
132            }
133        }
134    }
135
136    h3.shut_down();
137}
138
139#[allow(clippy::too_many_arguments)]
140fn handle_bidi_stream(
141    stream_id: u64,
142    transport: BoxedBidiStream,
143    h3: &Arc<H3Connection>,
144    handler: &ArcHandler<impl Handler>,
145    connection: &QuicConnection,
146    runtime: &impl RuntimeTrait,
147    wt_dispatcher: &Option<WebTransportDispatcher>,
148    listener: Option<Listener>,
149    local_alt_svc: Option<&'static str>,
150    priorities: &PriorityRegistry,
151) {
152    log::trace!("H3 bidi stream {stream_id}: spawning handler task");
153    let (h3, handler, connection, wt_dispatcher, priorities) = (
154        h3.clone(),
155        handler.clone(),
156        connection.clone(),
157        wt_dispatcher.clone(),
158        priorities.clone(),
159    );
160
161    // Wrap the stream so RFC 9218 priority signals routed to its slot are applied to the QUIC
162    // send stream as it writes. trillium-http emits the initial priority and any PRIORITY_UPDATE
163    // to the connection callback, which stores into this slot.
164    let slot = priorities.register(stream_id);
165    let transport: BoxedBidiStream = Box::new(PrioritizedStream::new(transport, slot, stream_id));
166
167    runtime.spawn(async move {
168        let peer_ip = connection.remote_address().ip();
169        let quic_connection = connection.clone();
170        let wt_dispatcher = wt_dispatcher.clone();
171
172        let handler_fn = {
173            let handler = handler.clone();
174            let wt_dispatcher = wt_dispatcher.clone();
175            move |mut conn: trillium_http::Conn<_>| async move {
176                conn.set_peer_ip(Some(peer_ip));
177                conn.set_secure(true);
178
179                let state = conn.state_mut();
180                state.insert(quic_connection);
181                state.insert(StreamId(stream_id));
182                if let Some(listener) = listener {
183                    if let Some(addr) = listener.socket_addr() {
184                        state.insert(addr);
185                    }
186                    state.insert(listener);
187                }
188                if let Some(dispatcher) = wt_dispatcher {
189                    state.insert(dispatcher);
190                }
191                if let Some(alt_svc) = local_alt_svc {
192                    conn.response_headers_mut()
193                        .try_insert(KnownHeaderName::AltSvc, alt_svc);
194                }
195
196                let conn = handler.run(conn.into()).await;
197                let conn = handler.before_send(conn).await;
198
199                conn.into_inner()
200            }
201        };
202
203        let result = h3
204            .clone()
205            .process_inbound_bidi(transport, handler_fn, stream_id)
206            .with_reset(|t, code| {
207                // RFC 9114 §4.1.2: stream-level protocol errors (notably H3_MESSAGE_ERROR)
208                // MUST RST the stream. We stop the recv side and reset the send side with
209                // the same code so the peer sees the error on whichever direction it's
210                // listening on.
211                let raw = u64::from(code);
212                t.stop(raw);
213                t.reset(raw);
214            })
215            .await;
216
217        match result {
218            Ok(H3StreamResult::Request(conn)) if conn.should_upgrade() => {
219                let upgrade = Upgrade::from(conn);
220                if handler.has_upgrade(&upgrade) {
221                    log::debug!("upgrading h3 stream");
222                    handler.upgrade(upgrade).await;
223                } else {
224                    log::error!("h3 upgrade specified but no upgrade handler provided");
225                }
226            }
227
228            Ok(H3StreamResult::Request(_)) => {}
229
230            Ok(H3StreamResult::WebTransport {
231                session_id,
232                mut transport,
233                buffer,
234            }) => {
235                if let Some(dispatcher) = &wt_dispatcher {
236                    dispatcher.dispatch(WebTransportStream::Bidi {
237                        session_id,
238                        stream: Box::new(transport),
239                        buffer: buffer.into(),
240                    });
241                } else {
242                    transport.stop(H3ErrorCode::StreamCreationError.into());
243                    transport.reset(H3ErrorCode::StreamCreationError.into());
244                }
245            }
246
247            Err(error) => {
248                log::debug!("H3 bidi stream {stream_id}: error: {error}");
249                handle_h3_error(error, &connection, &h3);
250            }
251        }
252
253        priorities.deregister(stream_id);
254    });
255}
256
257fn spawn_inbound_uni_streams(
258    connection: &QuicConnection,
259    h3: &Arc<H3Connection>,
260    runtime: &impl RuntimeTrait,
261    wt_dispatcher: &Option<WebTransportDispatcher>,
262) {
263    let (connection, h3, runtime, wt_dispatcher) = (
264        connection.clone(),
265        h3.clone(),
266        runtime.clone(),
267        wt_dispatcher.clone(),
268    );
269    runtime.clone().spawn(async move {
270        while let Some(Ok((_stream_id, recv))) =
271            h3.swansong().interrupt(connection.accept_uni()).await
272        {
273            let (connection, h3, wt_dispatcher) =
274                (connection.clone(), h3.clone(), wt_dispatcher.clone());
275
276            runtime.spawn(async move {
277                // RFC 9114 §8.1 / RFC 9204 §6 connection-level errors must close the
278                // QUIC connection while the recv stream is still alive — otherwise
279                // quinn's RecvStream::drop sends STOP_SENDING, and the peer's malformed
280                // RESET_STREAM response can race ahead and override our app error code
281                // with FINAL_SIZE_ERROR on the wire. The closure fires inside
282                // process_inbound_uni_with_close before stream drops, so the close sets
283                // quinn's conn.error first and the drop becomes a no-op.
284                let close_connection = {
285                    let connection = connection.clone();
286                    let h3 = h3.clone();
287                    move |code: H3ErrorCode| {
288                        connection.close(code.into(), code.reason().as_bytes());
289                        h3.shut_down();
290                    }
291                };
292                let result = h3
293                    .process_inbound_uni_with_close(recv, close_connection)
294                    .await;
295
296                match result {
297                    Ok(UniStreamResult::Handled) => {}
298                    Ok(UniStreamResult::WebTransport {
299                        session_id,
300                        mut stream,
301                        buffer,
302                    }) => {
303                        if let Some(dispatcher) = &wt_dispatcher {
304                            dispatcher.dispatch(WebTransportStream::Uni {
305                                session_id,
306                                stream: Box::new(stream),
307                                buffer: buffer.into(),
308                            });
309                        } else {
310                            stream.stop(H3ErrorCode::StreamCreationError.into());
311                        }
312                    }
313
314                    Ok(UniStreamResult::Unknown { mut stream, .. }) => {
315                        stream.stop(H3ErrorCode::StreamCreationError.into());
316                    }
317
318                    Err(error) => {
319                        // Connection-level protocol errors already fired the close
320                        // callback above; this call is a no-op for the close path
321                        // (idempotent) and still useful for logging plus I/O errors.
322                        handle_h3_error(error, &connection, &h3);
323                    }
324                }
325            });
326        }
327
328        h3.shut_down();
329    });
330}
331
332fn spawn_qpack_decoder_stream(
333    connection: &QuicConnection,
334    h3: &Arc<H3Connection>,
335    runtime: &impl RuntimeTrait,
336) {
337    let (connection, h3) = (connection.clone(), h3.clone());
338
339    runtime.spawn(async move {
340        log::trace!("H3: opening outbound QPACK decoder stream");
341        let stream = match connection.open_uni().await {
342            Ok((_stream_id, stream)) => stream,
343            Err(err) => {
344                log::error!("H3: open_uni for QPACK decoder stream failed: {err:?}");
345                h3.shut_down();
346                return;
347            }
348        };
349
350        let result = h3.run_decoder(stream).await;
351
352        if let Err(error) = result {
353            handle_h3_error(error, &connection, &h3);
354        }
355
356        h3.shut_down();
357    });
358}
359
360fn spawn_qpack_encoder_stream(
361    connection: &QuicConnection,
362    h3: &Arc<H3Connection>,
363    runtime: &impl RuntimeTrait,
364) {
365    let (connection, h3) = (connection.clone(), h3.clone());
366    runtime.spawn(async move {
367        log::trace!("H3: opening outbound QPACK encoder stream");
368        let stream = match connection.open_uni().await {
369            Ok((_stream_id, stream)) => stream,
370            Err(err) => {
371                log::error!("H3: open_uni for QPACK encoder stream failed: {err:?}");
372                h3.shut_down();
373                return;
374            }
375        };
376
377        let result = h3.run_encoder(stream).await;
378
379        if let Err(error) = result {
380            handle_h3_error(error, &connection, &h3);
381        }
382
383        h3.shut_down();
384    });
385}
386
387fn spawn_outbound_control_stream(
388    connection: &QuicConnection,
389    h3: &Arc<H3Connection>,
390    runtime: &impl RuntimeTrait,
391) {
392    let (connection, h3) = (connection.clone(), h3.clone());
393    runtime.spawn(async move {
394        log::trace!("H3: opening outbound control stream");
395        let stream = match connection.open_uni().await {
396            Ok((_stream_id, stream)) => stream,
397            Err(err) => {
398                log::error!("H3: open_uni for outbound control stream failed: {err:?}");
399                h3.shut_down();
400                return;
401            }
402        };
403
404        let result = h3.run_outbound_control(stream).await;
405
406        if let Err(error) = result {
407            handle_h3_error(error, &connection, &h3);
408        }
409
410        h3.shut_down();
411    });
412}
413
414fn handle_h3_error(error: H3Error, connection: &QuicConnection, h3: &H3Connection) {
415    log::debug!("H3 error: {error}");
416    if let H3Error::Protocol(code) = error
417        && code.is_connection_error()
418    {
419        connection.close(code.into(), code.reason().as_bytes());
420        h3.shut_down();
421    }
422}