h3_quinn/
lib.rs

1//! QUIC Transport implementation with Quinn
2//!
3//! This module implements QUIC traits with Quinn.
4#![deny(missing_docs)]
5
6use std::{
7    convert::TryInto,
8    future::Future,
9    pin::Pin,
10    sync::Arc,
11    task::{self, Poll},
12};
13
14use bytes::{Buf, Bytes};
15
16use futures::{
17    ready,
18    stream::{self},
19    Stream, StreamExt,
20};
21
22use quinn::ReadError;
23pub use quinn::{self, AcceptBi, AcceptUni, Endpoint, OpenBi, OpenUni, VarInt};
24
25use h3::{
26    error::Code,
27    quic::{self, ConnectionErrorIncoming, StreamErrorIncoming, StreamId, WriteBuf},
28};
29use tokio_util::sync::ReusableBoxFuture;
30
31#[cfg(feature = "tracing")]
32use tracing::instrument;
33
34#[cfg(feature = "datagram")]
35pub mod datagram;
36
37/// BoxStream with Sync trait
38type BoxStreamSync<'a, T> = Pin<Box<dyn Stream<Item = T> + Sync + Send + 'a>>;
39
40/// A QUIC connection backed by Quinn
41///
42/// Implements a [`quic::Connection`] backed by a [`quinn::Connection`].
43pub struct Connection {
44    conn: quinn::Connection,
45    incoming_bi: BoxStreamSync<'static, <AcceptBi<'static> as Future>::Output>,
46    opening_bi: Option<BoxStreamSync<'static, <OpenBi<'static> as Future>::Output>>,
47    incoming_uni: BoxStreamSync<'static, <AcceptUni<'static> as Future>::Output>,
48    opening_uni: Option<BoxStreamSync<'static, <OpenUni<'static> as Future>::Output>>,
49}
50
51impl Connection {
52    /// Create a [`Connection`] from a [`quinn::Connection`]
53    pub fn new(conn: quinn::Connection) -> Self {
54        Self {
55            conn: conn.clone(),
56            incoming_bi: Box::pin(stream::unfold(conn.clone(), |conn| async {
57                Some((conn.accept_bi().await, conn))
58            })),
59            opening_bi: None,
60            incoming_uni: Box::pin(stream::unfold(conn.clone(), |conn| async {
61                Some((conn.accept_uni().await, conn))
62            })),
63            opening_uni: None,
64        }
65    }
66}
67
68impl<B> quic::Connection<B> for Connection
69where
70    B: Buf,
71{
72    type RecvStream = RecvStream;
73    type OpenStreams = OpenStreams;
74
75    #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
76    fn poll_accept_bidi(
77        &mut self,
78        cx: &mut task::Context<'_>,
79    ) -> Poll<Result<Self::BidiStream, ConnectionErrorIncoming>> {
80        let (send, recv) = ready!(self.incoming_bi.poll_next_unpin(cx))
81            .expect("self.incoming_bi BoxStream never returns None")
82            .map_err(|e| convert_connection_error(e))?;
83        Poll::Ready(Ok(Self::BidiStream {
84            send: Self::SendStream::new(send),
85            recv: Self::RecvStream::new(recv),
86        }))
87    }
88
89    #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
90    fn poll_accept_recv(
91        &mut self,
92        cx: &mut task::Context<'_>,
93    ) -> Poll<Result<Self::RecvStream, ConnectionErrorIncoming>> {
94        let recv = ready!(self.incoming_uni.poll_next_unpin(cx))
95            .expect("self.incoming_uni BoxStream never returns None")
96            .map_err(|e| convert_connection_error(e))?;
97        Poll::Ready(Ok(Self::RecvStream::new(recv)))
98    }
99
100    fn opener(&self) -> Self::OpenStreams {
101        OpenStreams {
102            conn: self.conn.clone(),
103            opening_bi: None,
104            opening_uni: None,
105        }
106    }
107}
108
109fn convert_connection_error(e: quinn::ConnectionError) -> h3::quic::ConnectionErrorIncoming {
110    match e {
111        quinn::ConnectionError::ApplicationClosed(application_close) => {
112            ConnectionErrorIncoming::ApplicationClose {
113                error_code: application_close.error_code.into(),
114            }
115        }
116        quinn::ConnectionError::TimedOut => ConnectionErrorIncoming::Timeout,
117
118        error @ quinn::ConnectionError::VersionMismatch
119        | error @ quinn::ConnectionError::Reset
120        | error @ quinn::ConnectionError::LocallyClosed
121        | error @ quinn::ConnectionError::CidsExhausted
122        | error @ quinn::ConnectionError::TransportError(_)
123        | error @ quinn::ConnectionError::ConnectionClosed(_) => {
124            ConnectionErrorIncoming::Undefined(Arc::new(error))
125        }
126    }
127}
128
129impl<B> quic::OpenStreams<B> for Connection
130where
131    B: Buf,
132{
133    type SendStream = SendStream<B>;
134    type BidiStream = BidiStream<B>;
135
136    #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
137    fn poll_open_bidi(
138        &mut self,
139        cx: &mut task::Context<'_>,
140    ) -> Poll<Result<Self::BidiStream, StreamErrorIncoming>> {
141        let bi = self.opening_bi.get_or_insert_with(|| {
142            Box::pin(stream::unfold(self.conn.clone(), |conn| async {
143                Some((conn.open_bi().await, conn))
144            }))
145        });
146        let (send, recv) = ready!(bi.poll_next_unpin(cx))
147            .expect("BoxStream does not return None")
148            .map_err(|e| StreamErrorIncoming::ConnectionErrorIncoming {
149                connection_error: convert_connection_error(e),
150            })?;
151        Poll::Ready(Ok(Self::BidiStream {
152            send: Self::SendStream::new(send),
153            recv: RecvStream::new(recv),
154        }))
155    }
156
157    #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
158    fn poll_open_send(
159        &mut self,
160        cx: &mut task::Context<'_>,
161    ) -> Poll<Result<Self::SendStream, StreamErrorIncoming>> {
162        let uni = self.opening_uni.get_or_insert_with(|| {
163            Box::pin(stream::unfold(self.conn.clone(), |conn| async {
164                Some((conn.open_uni().await, conn))
165            }))
166        });
167
168        let send = ready!(uni.poll_next_unpin(cx))
169            .expect("BoxStream does not return None")
170            .map_err(|e| StreamErrorIncoming::ConnectionErrorIncoming {
171                connection_error: convert_connection_error(e),
172            })?;
173        Poll::Ready(Ok(Self::SendStream::new(send)))
174    }
175
176    #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
177    fn close(&mut self, code: Code, reason: &[u8]) {
178        self.conn.close(
179            VarInt::from_u64(code.value()).expect("error code VarInt"),
180            reason,
181        );
182    }
183}
184
185/// Stream opener backed by a Quinn connection
186///
187/// Implements [`quic::OpenStreams`] using [`quinn::Connection`],
188/// [`quinn::OpenBi`], [`quinn::OpenUni`].
189pub struct OpenStreams {
190    conn: quinn::Connection,
191    opening_bi: Option<BoxStreamSync<'static, <OpenBi<'static> as Future>::Output>>,
192    opening_uni: Option<BoxStreamSync<'static, <OpenUni<'static> as Future>::Output>>,
193}
194
195impl<B> quic::OpenStreams<B> for OpenStreams
196where
197    B: Buf,
198{
199    type SendStream = SendStream<B>;
200    type BidiStream = BidiStream<B>;
201
202    #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
203    fn poll_open_bidi(
204        &mut self,
205        cx: &mut task::Context<'_>,
206    ) -> Poll<Result<Self::BidiStream, StreamErrorIncoming>> {
207        let bi = self.opening_bi.get_or_insert_with(|| {
208            Box::pin(stream::unfold(self.conn.clone(), |conn| async {
209                Some((conn.open_bi().await, conn))
210            }))
211        });
212
213        let (send, recv) = ready!(bi.poll_next_unpin(cx))
214            .expect("BoxStream does not return None")
215            .map_err(|e| StreamErrorIncoming::ConnectionErrorIncoming {
216                connection_error: convert_connection_error(e),
217            })?;
218        Poll::Ready(Ok(Self::BidiStream {
219            send: Self::SendStream::new(send),
220            recv: RecvStream::new(recv),
221        }))
222    }
223
224    #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
225    fn poll_open_send(
226        &mut self,
227        cx: &mut task::Context<'_>,
228    ) -> Poll<Result<Self::SendStream, StreamErrorIncoming>> {
229        let uni = self.opening_uni.get_or_insert_with(|| {
230            Box::pin(stream::unfold(self.conn.clone(), |conn| async {
231                Some((conn.open_uni().await, conn))
232            }))
233        });
234
235        let send = ready!(uni.poll_next_unpin(cx))
236            .expect("BoxStream does not return None")
237            .map_err(|e| StreamErrorIncoming::ConnectionErrorIncoming {
238                connection_error: convert_connection_error(e),
239            })?;
240        Poll::Ready(Ok(Self::SendStream::new(send)))
241    }
242
243    #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
244    fn close(&mut self, code: Code, reason: &[u8]) {
245        self.conn.close(
246            VarInt::from_u64(code.value()).expect("error code VarInt"),
247            reason,
248        );
249    }
250}
251
252impl Clone for OpenStreams {
253    fn clone(&self) -> Self {
254        Self {
255            conn: self.conn.clone(),
256            opening_bi: None,
257            opening_uni: None,
258        }
259    }
260}
261
262/// Quinn-backed bidirectional stream
263///
264/// Implements [`quic::BidiStream`] which allows the stream to be split
265/// into two structs each implementing one direction.
266pub struct BidiStream<B>
267where
268    B: Buf,
269{
270    send: SendStream<B>,
271    recv: RecvStream,
272}
273
274impl<B> quic::BidiStream<B> for BidiStream<B>
275where
276    B: Buf,
277{
278    type SendStream = SendStream<B>;
279    type RecvStream = RecvStream;
280
281    fn split(self) -> (Self::SendStream, Self::RecvStream) {
282        (self.send, self.recv)
283    }
284}
285
286impl<B: Buf> quic::RecvStream for BidiStream<B> {
287    type Buf = Bytes;
288
289    fn poll_data(
290        &mut self,
291        cx: &mut task::Context<'_>,
292    ) -> Poll<Result<Option<Self::Buf>, StreamErrorIncoming>> {
293        self.recv.poll_data(cx)
294    }
295
296    fn stop_sending(&mut self, error_code: u64) {
297        self.recv.stop_sending(error_code)
298    }
299
300    fn recv_id(&self) -> StreamId {
301        self.recv.recv_id()
302    }
303}
304
305impl<B> quic::SendStream<B> for BidiStream<B>
306where
307    B: Buf,
308{
309    fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), StreamErrorIncoming>> {
310        self.send.poll_ready(cx)
311    }
312
313    fn poll_finish(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), StreamErrorIncoming>> {
314        self.send.poll_finish(cx)
315    }
316
317    fn reset(&mut self, reset_code: u64) {
318        self.send.reset(reset_code)
319    }
320
321    fn send_data<D: Into<WriteBuf<B>>>(&mut self, data: D) -> Result<(), StreamErrorIncoming> {
322        self.send.send_data(data)
323    }
324
325    fn send_id(&self) -> StreamId {
326        self.send.send_id()
327    }
328}
329impl<B> quic::SendStreamUnframed<B> for BidiStream<B>
330where
331    B: Buf,
332{
333    fn poll_send<D: Buf>(
334        &mut self,
335        cx: &mut task::Context<'_>,
336        buf: &mut D,
337    ) -> Poll<Result<usize, StreamErrorIncoming>> {
338        self.send.poll_send(cx, buf)
339    }
340}
341
342/// Quinn-backed receive stream
343///
344/// Implements a [`quic::RecvStream`] backed by a [`quinn::RecvStream`].
345pub struct RecvStream {
346    stream: Option<quinn::RecvStream>,
347    read_chunk_fut: ReadChunkFuture,
348}
349
350type ReadChunkFuture = ReusableBoxFuture<
351    'static,
352    (
353        quinn::RecvStream,
354        Result<Option<quinn::Chunk>, quinn::ReadError>,
355    ),
356>;
357
358impl RecvStream {
359    fn new(stream: quinn::RecvStream) -> Self {
360        Self {
361            stream: Some(stream),
362            // Should only allocate once the first time it's used
363            read_chunk_fut: ReusableBoxFuture::new(async { unreachable!() }),
364        }
365    }
366}
367
368impl quic::RecvStream for RecvStream {
369    type Buf = Bytes;
370
371    #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
372    fn poll_data(
373        &mut self,
374        cx: &mut task::Context<'_>,
375    ) -> Poll<Result<Option<Self::Buf>, StreamErrorIncoming>> {
376        if let Some(mut stream) = self.stream.take() {
377            self.read_chunk_fut.set(async move {
378                let chunk = stream.read_chunk(usize::MAX, true).await;
379                (stream, chunk)
380            })
381        };
382
383        let (stream, chunk) = ready!(self.read_chunk_fut.poll(cx));
384        self.stream = Some(stream);
385        Poll::Ready(Ok(chunk
386            .map_err(|e| convert_read_error_to_stream_error(e))?
387            .map(|c| c.bytes)))
388    }
389
390    #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
391    fn stop_sending(&mut self, error_code: u64) {
392        self.stream
393            .as_mut()
394            .unwrap()
395            .stop(VarInt::from_u64(error_code).expect("invalid error_code"))
396            .ok();
397    }
398
399    #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
400    fn recv_id(&self) -> StreamId {
401        let num: u64 = self.stream.as_ref().unwrap().id().into();
402
403        num.try_into().expect("invalid stream id")
404    }
405}
406
407fn convert_read_error_to_stream_error(error: ReadError) -> StreamErrorIncoming {
408    match error {
409        ReadError::Reset(var_int) => StreamErrorIncoming::StreamTerminated {
410            error_code: var_int.into_inner(),
411        },
412        ReadError::ConnectionLost(connection_error) => {
413            StreamErrorIncoming::ConnectionErrorIncoming {
414                connection_error: convert_connection_error(connection_error),
415            }
416        }
417        error @ ReadError::ClosedStream => StreamErrorIncoming::Unknown(Box::new(error)),
418        ReadError::IllegalOrderedRead => panic!("h3-quinn only performs ordered reads"),
419        error @ ReadError::ZeroRttRejected => StreamErrorIncoming::Unknown(Box::new(error)),
420    }
421}
422
423fn convert_write_error_to_stream_error(error: quinn::WriteError) -> StreamErrorIncoming {
424    match error {
425        quinn::WriteError::Stopped(var_int) => StreamErrorIncoming::StreamTerminated {
426            error_code: var_int.into_inner(),
427        },
428        quinn::WriteError::ConnectionLost(connection_error) => {
429            StreamErrorIncoming::ConnectionErrorIncoming {
430                connection_error: convert_connection_error(connection_error),
431            }
432        }
433        error @ quinn::WriteError::ClosedStream | error @ quinn::WriteError::ZeroRttRejected => {
434            StreamErrorIncoming::Unknown(Box::new(error))
435        }
436    }
437}
438
439/// Quinn-backed send stream
440///
441/// Implements a [`quic::SendStream`] backed by a [`quinn::SendStream`].
442pub struct SendStream<B: Buf> {
443    stream: quinn::SendStream,
444    writing: Option<WriteBuf<B>>,
445}
446
447impl<B> SendStream<B>
448where
449    B: Buf,
450{
451    fn new(stream: quinn::SendStream) -> SendStream<B> {
452        Self {
453            stream: stream,
454            writing: None,
455        }
456    }
457}
458
459impl<B> quic::SendStream<B> for SendStream<B>
460where
461    B: Buf,
462{
463    #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
464    fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), StreamErrorIncoming>> {
465        if let Some(ref mut data) = self.writing {
466            while data.has_remaining() {
467                let stream = Pin::new(&mut self.stream);
468                let written = ready!(stream.poll_write(cx, data.chunk()))
469                    .map_err(|err| convert_write_error_to_stream_error(err))?;
470                data.advance(written);
471            }
472        }
473        // all data is written
474        self.writing = None;
475        Poll::Ready(Ok(()))
476    }
477
478    #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
479    fn poll_finish(
480        &mut self,
481        _cx: &mut task::Context<'_>,
482    ) -> Poll<Result<(), StreamErrorIncoming>> {
483        Poll::Ready(
484            self.stream
485                .finish()
486                .map_err(|e| StreamErrorIncoming::Unknown(Box::new(e))),
487        )
488    }
489
490    #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
491    fn reset(&mut self, reset_code: u64) {
492        let _ = self
493            .stream
494            .reset(VarInt::from_u64(reset_code).unwrap_or(VarInt::MAX));
495    }
496
497    #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
498    fn send_data<D: Into<WriteBuf<B>>>(&mut self, data: D) -> Result<(), StreamErrorIncoming> {
499        if self.writing.is_some() {
500            // This can only happen if the traits are misused by h3 itself
501            // If this happens log an error and close the connection with H3_INTERNAL_ERROR
502
503            #[cfg(feature = "tracing")]
504            tracing::error!("send_data called while send stream is not ready");
505            return Err(StreamErrorIncoming::ConnectionErrorIncoming {
506                connection_error: ConnectionErrorIncoming::InternalError(
507                    "internal error in the http stack".to_string(),
508                ),
509            });
510        }
511        self.writing = Some(data.into());
512        Ok(())
513    }
514
515    #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
516    fn send_id(&self) -> StreamId {
517        let num: u64 = self.stream.id().into();
518        num.try_into().expect("invalid stream id")
519    }
520}
521
522impl<B> quic::SendStreamUnframed<B> for SendStream<B>
523where
524    B: Buf,
525{
526    #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
527    fn poll_send<D: Buf>(
528        &mut self,
529        cx: &mut task::Context<'_>,
530        buf: &mut D,
531    ) -> Poll<Result<usize, StreamErrorIncoming>> {
532        if self.writing.is_some() {
533            // This signifies a bug in implementation
534            panic!("poll_send called while send stream is not ready")
535        }
536
537        let s = Pin::new(&mut self.stream);
538
539        let res = ready!(s.poll_write(cx, buf.chunk()));
540        match res {
541            Ok(written) => {
542                buf.advance(written);
543                Poll::Ready(Ok(written))
544            }
545            Err(err) => Poll::Ready(Err(convert_write_error_to_stream_error(err))),
546        }
547    }
548}