h3_quinn/
lib.rs

1//! QUIC Transport implementation with Quinn
2//!
3//! This module implements QUIC traits with Quinn.
4#![deny(missing_docs)]
5
6use std::{
7    convert::TryInto,
8    fmt::{self, Display},
9    future::Future,
10    pin::Pin,
11    sync::Arc,
12    task::{self, Poll},
13};
14
15use bytes::{Buf, Bytes, BytesMut};
16
17use futures::{
18    ready,
19    stream::{self},
20    Stream, StreamExt,
21};
22
23#[cfg(feature = "datagram")]
24use h3_datagram::{datagram::Datagram, quic_traits};
25
26pub use quinn::{self, AcceptBi, AcceptUni, Endpoint, OpenBi, OpenUni, VarInt, WriteError};
27use quinn::{ApplicationClose, ClosedStream, ReadDatagram};
28
29use h3::quic::{self, Error, StreamId, WriteBuf};
30use tokio_util::sync::ReusableBoxFuture;
31
32#[cfg(feature = "tracing")]
33use tracing::instrument;
34
/// Pinned, boxed [`Stream`] that is `Send + Sync`.
///
/// Like a plain boxed stream, but the trait object is additionally `Sync`
/// (presumably so the structs holding it can themselves be `Sync` — confirm).
type BoxStreamSync<'a, T> = Pin<Box<dyn Stream<Item = T> + Sync + Send + 'a>>;
37
/// A QUIC connection backed by Quinn
///
/// Implements a [`quic::Connection`] backed by a [`quinn::Connection`].
pub struct Connection {
    // The underlying Quinn connection handle; cloned into openers and streams.
    conn: quinn::Connection,
    // Never-ending stream of `accept_bi()` results, built in `Connection::new`.
    incoming_bi: BoxStreamSync<'static, <AcceptBi<'static> as Future>::Output>,
    // Lazily created on the first `poll_open_bidi` call.
    opening_bi: Option<BoxStreamSync<'static, <OpenBi<'static> as Future>::Output>>,
    // Never-ending stream of `accept_uni()` results, built in `Connection::new`.
    incoming_uni: BoxStreamSync<'static, <AcceptUni<'static> as Future>::Output>,
    // Lazily created on the first `poll_open_send` call.
    opening_uni: Option<BoxStreamSync<'static, <OpenUni<'static> as Future>::Output>>,
    // Never-ending stream of `read_datagram()` results, built in `Connection::new`.
    datagrams: BoxStreamSync<'static, <ReadDatagram<'static> as Future>::Output>,
}
49
50impl Connection {
51    /// Create a [`Connection`] from a [`quinn::Connection`]
52    pub fn new(conn: quinn::Connection) -> Self {
53        Self {
54            conn: conn.clone(),
55            incoming_bi: Box::pin(stream::unfold(conn.clone(), |conn| async {
56                Some((conn.accept_bi().await, conn))
57            })),
58            opening_bi: None,
59            incoming_uni: Box::pin(stream::unfold(conn.clone(), |conn| async {
60                Some((conn.accept_uni().await, conn))
61            })),
62            opening_uni: None,
63            datagrams: Box::pin(stream::unfold(conn, |conn| async {
64                Some((conn.read_datagram().await, conn))
65            })),
66        }
67    }
68}
69
/// The error type for [`Connection`]
///
/// Wraps reasons a Quinn connection might be lost: a thin newtype around
/// [`quinn::ConnectionError`] that implements the h3 [`Error`] trait.
#[derive(Debug)]
pub struct ConnectionError(quinn::ConnectionError);
75
76impl std::error::Error for ConnectionError {}
77
78impl fmt::Display for ConnectionError {
79    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
80        self.0.fmt(f)
81    }
82}
83
84impl Error for ConnectionError {
85    fn is_timeout(&self) -> bool {
86        matches!(self.0, quinn::ConnectionError::TimedOut)
87    }
88
89    fn err_code(&self) -> Option<u64> {
90        match self.0 {
91            quinn::ConnectionError::ApplicationClosed(ApplicationClose { error_code, .. }) => {
92                Some(error_code.into_inner())
93            }
94            _ => None,
95        }
96    }
97}
98
99impl From<quinn::ConnectionError> for ConnectionError {
100    fn from(e: quinn::ConnectionError) -> Self {
101        Self(e)
102    }
103}
104
/// Types of errors when sending a datagram.
///
/// Mirrors [`quinn::SendDatagramError`], with the connection-loss cause boxed
/// behind the h3 [`Error`] trait.
#[derive(Debug)]
pub enum SendDatagramError {
    /// Datagrams are not supported by the peer
    UnsupportedByPeer,
    /// Datagrams are locally disabled
    Disabled,
    /// The datagram was too large to be sent.
    TooLarge,
    /// Network error: the connection was lost; the boxed value is the cause
    ConnectionLost(Box<dyn Error>),
}
117
118impl fmt::Display for SendDatagramError {
119    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
120        match self {
121            SendDatagramError::UnsupportedByPeer => write!(f, "datagrams not supported by peer"),
122            SendDatagramError::Disabled => write!(f, "datagram support disabled"),
123            SendDatagramError::TooLarge => write!(f, "datagram too large"),
124            SendDatagramError::ConnectionLost(_) => write!(f, "connection lost"),
125        }
126    }
127}
128
// Marker impl using the default methods; `source()` returns `None`. The
// `ConnectionLost` cause is boxed as an h3 `Error` trait object, which is not
// directly convertible to `&dyn std::error::Error` here.
impl std::error::Error for SendDatagramError {}
130
131impl Error for SendDatagramError {
132    fn is_timeout(&self) -> bool {
133        false
134    }
135
136    fn err_code(&self) -> Option<u64> {
137        match self {
138            Self::ConnectionLost(err) => err.err_code(),
139            _ => None,
140        }
141    }
142}
143
144impl From<quinn::SendDatagramError> for SendDatagramError {
145    fn from(value: quinn::SendDatagramError) -> Self {
146        match value {
147            quinn::SendDatagramError::UnsupportedByPeer => Self::UnsupportedByPeer,
148            quinn::SendDatagramError::Disabled => Self::Disabled,
149            quinn::SendDatagramError::TooLarge => Self::TooLarge,
150            quinn::SendDatagramError::ConnectionLost(err) => {
151                Self::ConnectionLost(ConnectionError::from(err).into())
152            }
153        }
154    }
155}
156
157impl<B> quic::Connection<B> for Connection
158where
159    B: Buf,
160{
161    type RecvStream = RecvStream;
162    type OpenStreams = OpenStreams;
163    type AcceptError = ConnectionError;
164
165    #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
166    fn poll_accept_bidi(
167        &mut self,
168        cx: &mut task::Context<'_>,
169    ) -> Poll<Result<Option<Self::BidiStream>, Self::AcceptError>> {
170        let (send, recv) = match ready!(self.incoming_bi.poll_next_unpin(cx)) {
171            Some(x) => x?,
172            None => return Poll::Ready(Ok(None)),
173        };
174        Poll::Ready(Ok(Some(Self::BidiStream {
175            send: Self::SendStream::new(send),
176            recv: Self::RecvStream::new(recv),
177        })))
178    }
179
180    #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
181    fn poll_accept_recv(
182        &mut self,
183        cx: &mut task::Context<'_>,
184    ) -> Poll<Result<Option<Self::RecvStream>, Self::AcceptError>> {
185        let recv = match ready!(self.incoming_uni.poll_next_unpin(cx)) {
186            Some(x) => x?,
187            None => return Poll::Ready(Ok(None)),
188        };
189        Poll::Ready(Ok(Some(Self::RecvStream::new(recv))))
190    }
191
192    fn opener(&self) -> Self::OpenStreams {
193        OpenStreams {
194            conn: self.conn.clone(),
195            opening_bi: None,
196            opening_uni: None,
197        }
198    }
199}
200
201impl<B> quic::OpenStreams<B> for Connection
202where
203    B: Buf,
204{
205    type SendStream = SendStream<B>;
206    type BidiStream = BidiStream<B>;
207    type OpenError = ConnectionError;
208
209    #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
210    fn poll_open_bidi(
211        &mut self,
212        cx: &mut task::Context<'_>,
213    ) -> Poll<Result<Self::BidiStream, Self::OpenError>> {
214        if self.opening_bi.is_none() {
215            self.opening_bi = Some(Box::pin(stream::unfold(self.conn.clone(), |conn| async {
216                Some((conn.clone().open_bi().await, conn))
217            })));
218        }
219
220        let (send, recv) =
221            ready!(self.opening_bi.as_mut().unwrap().poll_next_unpin(cx)).unwrap()?;
222        Poll::Ready(Ok(Self::BidiStream {
223            send: Self::SendStream::new(send),
224            recv: RecvStream::new(recv),
225        }))
226    }
227
228    #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
229    fn poll_open_send(
230        &mut self,
231        cx: &mut task::Context<'_>,
232    ) -> Poll<Result<Self::SendStream, Self::OpenError>> {
233        if self.opening_uni.is_none() {
234            self.opening_uni = Some(Box::pin(stream::unfold(self.conn.clone(), |conn| async {
235                Some((conn.open_uni().await, conn))
236            })));
237        }
238
239        let send = ready!(self.opening_uni.as_mut().unwrap().poll_next_unpin(cx)).unwrap()?;
240        Poll::Ready(Ok(Self::SendStream::new(send)))
241    }
242
243    #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
244    fn close(&mut self, code: h3::error::Code, reason: &[u8]) {
245        self.conn.close(
246            VarInt::from_u64(code.value()).expect("error code VarInt"),
247            reason,
248        );
249    }
250}
251
#[cfg(feature = "datagram")]
impl<B> quic_traits::SendDatagramExt<B> for Connection
where
    B: Buf,
{
    type Error = SendDatagramError;

    /// Encodes `data` into a fresh buffer and hands it to Quinn unreliably.
    #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
    fn send_datagram(&mut self, data: Datagram<B>) -> Result<(), SendDatagramError> {
        // TODO investigate static buffer from known max datagram size
        let mut encoded = BytesMut::new();
        data.encode(&mut encoded);
        self.conn
            .send_datagram(encoded.freeze())
            .map_err(SendDatagramError::from)
    }
}
269
#[cfg(feature = "datagram")]
impl quic_traits::RecvDatagramExt for Connection {
    type Buf = Bytes;

    type Error = ConnectionError;

    /// Polls the next received datagram; `Ok(None)` if the datagram stream ends.
    #[inline]
    #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
    fn poll_accept_datagram(
        &mut self,
        cx: &mut task::Context<'_>,
    ) -> Poll<Result<Option<Self::Buf>, Self::Error>> {
        let next = ready!(self.datagrams.poll_next_unpin(cx));
        Poll::Ready(next.transpose().map_err(ConnectionError::from))
    }
}
289
/// Stream opener backed by a Quinn connection
///
/// Implements [`quic::OpenStreams`] using [`quinn::Connection`],
/// [`quinn::OpenBi`], [`quinn::OpenUni`].
pub struct OpenStreams {
    // Shared handle to the Quinn connection this opener creates streams on.
    conn: quinn::Connection,
    // Lazily created on the first `poll_open_bidi` call.
    opening_bi: Option<BoxStreamSync<'static, <OpenBi<'static> as Future>::Output>>,
    // Lazily created on the first `poll_open_send` call.
    opening_uni: Option<BoxStreamSync<'static, <OpenUni<'static> as Future>::Output>>,
}
299
300impl<B> quic::OpenStreams<B> for OpenStreams
301where
302    B: Buf,
303{
304    type SendStream = SendStream<B>;
305    type BidiStream = BidiStream<B>;
306    type OpenError = ConnectionError;
307
308    #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
309    fn poll_open_bidi(
310        &mut self,
311        cx: &mut task::Context<'_>,
312    ) -> Poll<Result<Self::BidiStream, Self::OpenError>> {
313        if self.opening_bi.is_none() {
314            self.opening_bi = Some(Box::pin(stream::unfold(self.conn.clone(), |conn| async {
315                Some((conn.open_bi().await, conn))
316            })));
317        }
318
319        let (send, recv) =
320            ready!(self.opening_bi.as_mut().unwrap().poll_next_unpin(cx)).unwrap()?;
321        Poll::Ready(Ok(Self::BidiStream {
322            send: Self::SendStream::new(send),
323            recv: RecvStream::new(recv),
324        }))
325    }
326
327    #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
328    fn poll_open_send(
329        &mut self,
330        cx: &mut task::Context<'_>,
331    ) -> Poll<Result<Self::SendStream, Self::OpenError>> {
332        if self.opening_uni.is_none() {
333            self.opening_uni = Some(Box::pin(stream::unfold(self.conn.clone(), |conn| async {
334                Some((conn.open_uni().await, conn))
335            })));
336        }
337
338        let send = ready!(self.opening_uni.as_mut().unwrap().poll_next_unpin(cx)).unwrap()?;
339        Poll::Ready(Ok(Self::SendStream::new(send)))
340    }
341
342    #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
343    fn close(&mut self, code: h3::error::Code, reason: &[u8]) {
344        self.conn.close(
345            VarInt::from_u64(code.value()).expect("error code VarInt"),
346            reason,
347        );
348    }
349}
350
351impl Clone for OpenStreams {
352    fn clone(&self) -> Self {
353        Self {
354            conn: self.conn.clone(),
355            opening_bi: None,
356            opening_uni: None,
357        }
358    }
359}
360
/// Quinn-backed bidirectional stream
///
/// Implements [`quic::BidiStream`] which allows the stream to be split
/// into two structs each implementing one direction.
pub struct BidiStream<B>
where
    B: Buf,
{
    // Outgoing half.
    send: SendStream<B>,
    // Incoming half.
    recv: RecvStream,
}
372
373impl<B> quic::BidiStream<B> for BidiStream<B>
374where
375    B: Buf,
376{
377    type SendStream = SendStream<B>;
378    type RecvStream = RecvStream;
379
380    fn split(self) -> (Self::SendStream, Self::RecvStream) {
381        (self.send, self.recv)
382    }
383}
384
385impl<B: Buf> quic::RecvStream for BidiStream<B> {
386    type Buf = Bytes;
387    type Error = ReadError;
388
389    fn poll_data(
390        &mut self,
391        cx: &mut task::Context<'_>,
392    ) -> Poll<Result<Option<Self::Buf>, Self::Error>> {
393        self.recv.poll_data(cx)
394    }
395
396    fn stop_sending(&mut self, error_code: u64) {
397        self.recv.stop_sending(error_code)
398    }
399
400    fn recv_id(&self) -> StreamId {
401        self.recv.recv_id()
402    }
403}
404
405impl<B> quic::SendStream<B> for BidiStream<B>
406where
407    B: Buf,
408{
409    type Error = SendStreamError;
410
411    fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
412        self.send.poll_ready(cx)
413    }
414
415    fn poll_finish(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
416        self.send.poll_finish(cx)
417    }
418
419    fn reset(&mut self, reset_code: u64) {
420        self.send.reset(reset_code)
421    }
422
423    fn send_data<D: Into<WriteBuf<B>>>(&mut self, data: D) -> Result<(), Self::Error> {
424        self.send.send_data(data)
425    }
426
427    fn send_id(&self) -> StreamId {
428        self.send.send_id()
429    }
430}
431impl<B> quic::SendStreamUnframed<B> for BidiStream<B>
432where
433    B: Buf,
434{
435    fn poll_send<D: Buf>(
436        &mut self,
437        cx: &mut task::Context<'_>,
438        buf: &mut D,
439    ) -> Poll<Result<usize, Self::Error>> {
440        self.send.poll_send(cx, buf)
441    }
442}
443
/// Quinn-backed receive stream
///
/// Implements a [`quic::RecvStream`] backed by a [`quinn::RecvStream`].
pub struct RecvStream {
    // `Some` while idle; `None` while a read owns the stream inside
    // `read_chunk_fut` (it is moved back when that read completes).
    stream: Option<quinn::RecvStream>,
    // Reused boxed future for the in-flight `read_chunk` call.
    read_chunk_fut: ReadChunkFuture,
}
451
// Future type for one in-flight read: it takes ownership of the stream and
// returns it alongside the read result, so the stream can be restored.
type ReadChunkFuture = ReusableBoxFuture<
    'static,
    (
        quinn::RecvStream,
        Result<Option<quinn::Chunk>, quinn::ReadError>,
    ),
>;
459
460impl RecvStream {
461    fn new(stream: quinn::RecvStream) -> Self {
462        Self {
463            stream: Some(stream),
464            // Should only allocate once the first time it's used
465            read_chunk_fut: ReusableBoxFuture::new(async { unreachable!() }),
466        }
467    }
468}
469
470impl quic::RecvStream for RecvStream {
471    type Buf = Bytes;
472    type Error = ReadError;
473
474    #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
475    fn poll_data(
476        &mut self,
477        cx: &mut task::Context<'_>,
478    ) -> Poll<Result<Option<Self::Buf>, Self::Error>> {
479        if let Some(mut stream) = self.stream.take() {
480            self.read_chunk_fut.set(async move {
481                let chunk = stream.read_chunk(usize::MAX, true).await;
482                (stream, chunk)
483            })
484        };
485
486        let (stream, chunk) = ready!(self.read_chunk_fut.poll(cx));
487        self.stream = Some(stream);
488        Poll::Ready(Ok(chunk?.map(|c| c.bytes)))
489    }
490
491    #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
492    fn stop_sending(&mut self, error_code: u64) {
493        self.stream
494            .as_mut()
495            .unwrap()
496            .stop(VarInt::from_u64(error_code).expect("invalid error_code"))
497            .ok();
498    }
499
500    #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
501    fn recv_id(&self) -> StreamId {
502        self.stream
503            .as_ref()
504            .unwrap()
505            .id()
506            .0
507            .try_into()
508            .expect("invalid stream id")
509    }
510}
511
/// The error type for [`RecvStream`]
///
/// Wraps errors that occur when reading from a receive stream: a thin
/// newtype around [`quinn::ReadError`] implementing the h3 [`Error`] trait.
#[derive(Debug)]
pub struct ReadError(quinn::ReadError);
517
518impl From<ReadError> for std::io::Error {
519    fn from(value: ReadError) -> Self {
520        value.0.into()
521    }
522}
523
524impl std::error::Error for ReadError {
525    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
526        self.0.source()
527    }
528}
529
530impl fmt::Display for ReadError {
531    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
532        self.0.fmt(f)
533    }
534}
535
536impl From<ReadError> for Arc<dyn Error> {
537    fn from(e: ReadError) -> Self {
538        Arc::new(e)
539    }
540}
541
542impl From<quinn::ReadError> for ReadError {
543    fn from(e: quinn::ReadError) -> Self {
544        Self(e)
545    }
546}
547
548impl Error for ReadError {
549    fn is_timeout(&self) -> bool {
550        matches!(
551            self.0,
552            quinn::ReadError::ConnectionLost(quinn::ConnectionError::TimedOut)
553        )
554    }
555
556    fn err_code(&self) -> Option<u64> {
557        match self.0 {
558            quinn::ReadError::ConnectionLost(quinn::ConnectionError::ApplicationClosed(
559                ApplicationClose { error_code, .. },
560            )) => Some(error_code.into_inner()),
561            quinn::ReadError::Reset(error_code) => Some(error_code.into_inner()),
562            _ => None,
563        }
564    }
565}
566
/// Quinn-backed send stream
///
/// Implements a [`quic::SendStream`] backed by a [`quinn::SendStream`].
pub struct SendStream<B: Buf> {
    // Underlying Quinn stream.
    stream: quinn::SendStream,
    // Data accepted by `send_data` but not yet fully written; drained by
    // `poll_ready`. `Some` means the stream is not ready for more data.
    writing: Option<WriteBuf<B>>,
}
574
575impl<B> SendStream<B>
576where
577    B: Buf,
578{
579    fn new(stream: quinn::SendStream) -> SendStream<B> {
580        Self {
581            stream: stream,
582            writing: None,
583        }
584    }
585}
586
587impl<B> quic::SendStream<B> for SendStream<B>
588where
589    B: Buf,
590{
591    type Error = SendStreamError;
592
593    #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
594    fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
595        if let Some(ref mut data) = self.writing {
596            while data.has_remaining() {
597                let stream = Pin::new(&mut self.stream);
598                let written = ready!(stream.poll_write(cx, data.chunk()))
599                    .map_err(|err| SendStreamError::Write(err))?;
600                data.advance(written);
601            }
602        }
603        // all data is written
604        self.writing = None;
605        Poll::Ready(Ok(()))
606    }
607
608    #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
609    fn poll_finish(&mut self, _cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
610        Poll::Ready(self.stream.finish().map_err(|e| e.into()))
611    }
612
613    #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
614    fn reset(&mut self, reset_code: u64) {
615        let _ = self
616            .stream
617            .reset(VarInt::from_u64(reset_code).unwrap_or(VarInt::MAX));
618    }
619
620    #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
621    fn send_data<D: Into<WriteBuf<B>>>(&mut self, data: D) -> Result<(), Self::Error> {
622        if self.writing.is_some() {
623            return Err(Self::Error::NotReady);
624        }
625        self.writing = Some(data.into());
626        Ok(())
627    }
628
629    #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
630    fn send_id(&self) -> StreamId {
631        self.stream.id().0.try_into().expect("invalid stream id")
632    }
633}
634
635impl<B> quic::SendStreamUnframed<B> for SendStream<B>
636where
637    B: Buf,
638{
639    #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
640    fn poll_send<D: Buf>(
641        &mut self,
642        cx: &mut task::Context<'_>,
643        buf: &mut D,
644    ) -> Poll<Result<usize, Self::Error>> {
645        if self.writing.is_some() {
646            // This signifies a bug in implementation
647            panic!("poll_send called while send stream is not ready")
648        }
649
650        let s = Pin::new(&mut self.stream);
651
652        let res = ready!(s.poll_write(cx, buf.chunk()));
653        match res {
654            Ok(written) => {
655                buf.advance(written);
656                Poll::Ready(Ok(written))
657            }
658            Err(err) => Poll::Ready(Err(SendStreamError::Write(err))),
659        }
660    }
661}
662
/// The error type for [`SendStream`]
///
/// Wraps errors that can happen writing to or polling a send stream.
#[derive(Debug)]
pub enum SendStreamError {
    /// Errors when writing, wrapping a [`quinn::WriteError`]
    Write(WriteError),
    /// Error when the stream is not ready, because it is still sending
    /// data from a previous call
    NotReady,
    /// Error when the stream is closed (returned by `poll_finish`)
    StreamClosed(ClosedStream),
}
676
677impl From<SendStreamError> for std::io::Error {
678    fn from(value: SendStreamError) -> Self {
679        match value {
680            SendStreamError::Write(err) => err.into(),
681            SendStreamError::NotReady => {
682                std::io::Error::new(std::io::ErrorKind::Other, "send stream is not ready")
683            }
684            SendStreamError::StreamClosed(err) => err.into(),
685        }
686    }
687}
688
689impl std::error::Error for SendStreamError {}
690
691impl Display for SendStreamError {
692    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
693        write!(f, "{:?}", self)
694    }
695}
696
697impl From<WriteError> for SendStreamError {
698    fn from(e: WriteError) -> Self {
699        Self::Write(e)
700    }
701}
702
703impl From<ClosedStream> for SendStreamError {
704    fn from(value: ClosedStream) -> Self {
705        Self::StreamClosed(value)
706    }
707}
708
709impl Error for SendStreamError {
710    fn is_timeout(&self) -> bool {
711        matches!(
712            self,
713            Self::Write(quinn::WriteError::ConnectionLost(
714                quinn::ConnectionError::TimedOut
715            ))
716        )
717    }
718
719    fn err_code(&self) -> Option<u64> {
720        match self {
721            Self::Write(quinn::WriteError::Stopped(error_code)) => Some(error_code.into_inner()),
722            Self::Write(quinn::WriteError::ConnectionLost(
723                quinn::ConnectionError::ApplicationClosed(ApplicationClose { error_code, .. }),
724            )) => Some(error_code.into_inner()),
725            _ => None,
726        }
727    }
728}
729
730impl From<SendStreamError> for Arc<dyn Error> {
731    fn from(e: SendStreamError) -> Self {
732        Arc::new(e)
733    }
734}