httproxide_h3_quinn/
lib.rs

//! QUIC transport implementation using Quinn
//!
//! This crate implements the `h3::quic` traits on top of Quinn.
4use std::{
5    convert::TryInto,
6    fmt::{self, Display},
7    pin::Pin,
8    sync::Arc,
9    task::{self, Poll},
10};
11
12use bytes::{Buf, Bytes};
13use futures_util::future::FutureExt as _;
14use futures_util::io::AsyncWrite as _;
15use futures_util::ready;
16use futures_util::stream::StreamExt as _;
17
18pub use quinn::{
19    self, crypto::Session, Endpoint, IncomingBiStreams, IncomingUniStreams, NewConnection, OpenBi,
20    OpenUni, VarInt, WriteError,
21};
22
23use h3::quic::{self, Error, StreamId, WriteBuf};
24
25pub struct Connection {
26    conn: quinn::Connection,
27    incoming_bi: IncomingBiStreams,
28    opening_bi: Option<OpenBi>,
29    incoming_uni: IncomingUniStreams,
30    opening_uni: Option<OpenUni>,
31}
32
33impl Connection {
34    pub fn new(new_conn: NewConnection) -> Self {
35        let NewConnection {
36            uni_streams,
37            bi_streams,
38            connection,
39            ..
40        } = new_conn;
41
42        Self {
43            conn: connection,
44            incoming_bi: bi_streams,
45            opening_bi: None,
46            incoming_uni: uni_streams,
47            opening_uni: None,
48        }
49    }
50}
51
52#[derive(Debug)]
53pub struct ConnectionError(quinn::ConnectionError);
54
55impl std::error::Error for ConnectionError {}
56
57impl fmt::Display for ConnectionError {
58    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
59        self.0.fmt(f)
60    }
61}
62
63impl Error for ConnectionError {
64    fn is_timeout(&self) -> bool {
65        matches!(self.0, quinn::ConnectionError::TimedOut)
66    }
67
68    fn err_code(&self) -> Option<u64> {
69        match self.0 {
70            quinn::ConnectionError::ApplicationClosed(quinn_proto::ApplicationClose {
71                error_code,
72                ..
73            }) => Some(error_code.into_inner()),
74            _ => None,
75        }
76    }
77}
78
79impl From<quinn::ConnectionError> for ConnectionError {
80    fn from(e: quinn::ConnectionError) -> Self {
81        Self(e)
82    }
83}
84
85impl<B> quic::Connection<B> for Connection
86where
87    B: Buf,
88{
89    type SendStream = SendStream<B>;
90    type RecvStream = RecvStream;
91    type BidiStream = BidiStream<B>;
92    type OpenStreams = OpenStreams;
93    type Error = ConnectionError;
94
95    fn poll_accept_bidi(
96        &mut self,
97        cx: &mut task::Context<'_>,
98    ) -> Poll<Result<Option<Self::BidiStream>, Self::Error>> {
99        let (send, recv) = match ready!(self.incoming_bi.next().poll_unpin(cx)) {
100            Some(x) => x?,
101            None => return Poll::Ready(Ok(None)),
102        };
103        Poll::Ready(Ok(Some(Self::BidiStream {
104            send: Self::SendStream::new(send),
105            recv: Self::RecvStream::new(recv),
106        })))
107    }
108
109    fn poll_accept_recv(
110        &mut self,
111        cx: &mut task::Context<'_>,
112    ) -> Poll<Result<Option<Self::RecvStream>, Self::Error>> {
113        let recv = match ready!(self.incoming_uni.poll_next_unpin(cx)) {
114            Some(x) => x?,
115            None => return Poll::Ready(Ok(None)),
116        };
117        Poll::Ready(Ok(Some(Self::RecvStream::new(recv))))
118    }
119
120    fn poll_open_bidi(
121        &mut self,
122        cx: &mut task::Context<'_>,
123    ) -> Poll<Result<Self::BidiStream, Self::Error>> {
124        if self.opening_bi.is_none() {
125            self.opening_bi = Some(self.conn.open_bi());
126        }
127
128        let (send, recv) = ready!(self.opening_bi.as_mut().unwrap().poll_unpin(cx))?;
129        Poll::Ready(Ok(Self::BidiStream {
130            send: Self::SendStream::new(send),
131            recv: Self::RecvStream::new(recv),
132        }))
133    }
134
135    fn poll_open_send(
136        &mut self,
137        cx: &mut task::Context<'_>,
138    ) -> Poll<Result<Self::SendStream, Self::Error>> {
139        if self.opening_uni.is_none() {
140            self.opening_uni = Some(self.conn.open_uni());
141        }
142
143        let send = ready!(self.opening_uni.as_mut().unwrap().poll_unpin(cx))?;
144        Poll::Ready(Ok(Self::SendStream::new(send)))
145    }
146
147    fn opener(&self) -> Self::OpenStreams {
148        OpenStreams {
149            conn: self.conn.clone(),
150            opening_bi: None,
151            opening_uni: None,
152        }
153    }
154
155    fn close(&mut self, code: h3::error::Code, reason: &[u8]) {
156        self.conn.close(
157            VarInt::from_u64(code.value()).expect("error code VarInt"),
158            reason,
159        );
160    }
161}
162
163pub struct OpenStreams {
164    conn: quinn::Connection,
165    opening_bi: Option<OpenBi>,
166    opening_uni: Option<OpenUni>,
167}
168
169impl<B> quic::OpenStreams<B> for OpenStreams
170where
171    B: Buf,
172{
173    type RecvStream = RecvStream;
174    type SendStream = SendStream<B>;
175    type BidiStream = BidiStream<B>;
176    type Error = ConnectionError;
177
178    fn poll_open_bidi(
179        &mut self,
180        cx: &mut task::Context<'_>,
181    ) -> Poll<Result<Self::BidiStream, Self::Error>> {
182        if self.opening_bi.is_none() {
183            self.opening_bi = Some(self.conn.open_bi());
184        }
185
186        let (send, recv) = ready!(self.opening_bi.as_mut().unwrap().poll_unpin(cx))?;
187        Poll::Ready(Ok(Self::BidiStream {
188            send: Self::SendStream::new(send),
189            recv: Self::RecvStream::new(recv),
190        }))
191    }
192
193    fn poll_open_send(
194        &mut self,
195        cx: &mut task::Context<'_>,
196    ) -> Poll<Result<Self::SendStream, Self::Error>> {
197        if self.opening_uni.is_none() {
198            self.opening_uni = Some(self.conn.open_uni());
199        }
200
201        let send = ready!(self.opening_uni.as_mut().unwrap().poll_unpin(cx))?;
202        Poll::Ready(Ok(Self::SendStream::new(send)))
203    }
204
205    fn close(&mut self, code: h3::error::Code, reason: &[u8]) {
206        self.conn.close(
207            VarInt::from_u64(code.value()).expect("error code VarInt"),
208            reason,
209        );
210    }
211}
212
213impl Clone for OpenStreams {
214    fn clone(&self) -> Self {
215        Self {
216            conn: self.conn.clone(),
217            opening_bi: None,
218            opening_uni: None,
219        }
220    }
221}
222
223pub struct BidiStream<B>
224where
225    B: Buf,
226{
227    send: SendStream<B>,
228    recv: RecvStream,
229}
230
231impl<B> quic::BidiStream<B> for BidiStream<B>
232where
233    B: Buf,
234{
235    type SendStream = SendStream<B>;
236    type RecvStream = RecvStream;
237
238    fn split(self) -> (Self::SendStream, Self::RecvStream) {
239        (self.send, self.recv)
240    }
241}
242
243impl<B> quic::RecvStream for BidiStream<B>
244where
245    B: Buf,
246{
247    type Buf = Bytes;
248    type Error = ReadError;
249
250    fn poll_data(
251        &mut self,
252        cx: &mut task::Context<'_>,
253    ) -> Poll<Result<Option<Self::Buf>, Self::Error>> {
254        self.recv.poll_data(cx)
255    }
256
257    fn stop_sending(&mut self, error_code: u64) {
258        self.recv.stop_sending(error_code)
259    }
260}
261
262impl<B> quic::SendStream<B> for BidiStream<B>
263where
264    B: Buf,
265{
266    type Error = SendStreamError;
267
268    fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
269        self.send.poll_ready(cx)
270    }
271
272    fn poll_finish(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
273        self.send.poll_finish(cx)
274    }
275
276    fn reset(&mut self, reset_code: u64) {
277        self.send.reset(reset_code)
278    }
279
280    fn send_data<D: Into<WriteBuf<B>>>(&mut self, data: D) -> Result<(), Self::Error> {
281        self.send.send_data(data)
282    }
283
284    fn id(&self) -> StreamId {
285        self.send.id()
286    }
287}
288
289pub struct RecvStream {
290    stream: quinn::RecvStream,
291}
292
293impl RecvStream {
294    fn new(stream: quinn::RecvStream) -> Self {
295        Self { stream }
296    }
297}
298
299impl quic::RecvStream for RecvStream {
300    type Buf = Bytes;
301    type Error = ReadError;
302
303    fn poll_data(
304        &mut self,
305        cx: &mut task::Context<'_>,
306    ) -> Poll<Result<Option<Self::Buf>, Self::Error>> {
307        Poll::Ready(Ok(ready!(self
308            .stream
309            .read_chunk(usize::MAX, true)
310            .poll_unpin(cx))?
311        .map(|c| (c.bytes))))
312    }
313
314    fn stop_sending(&mut self, error_code: u64) {
315        let _ = self
316            .stream
317            .stop(VarInt::from_u64(error_code).expect("invalid error_code"));
318    }
319}
320
321#[derive(Debug)]
322pub struct ReadError(quinn::ReadError);
323
324impl std::error::Error for ReadError {}
325
326impl fmt::Display for ReadError {
327    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
328        self.0.fmt(f)
329    }
330}
331
332impl From<ReadError> for Arc<dyn Error> {
333    fn from(e: ReadError) -> Self {
334        Arc::new(e)
335    }
336}
337
338impl From<quinn::ReadError> for ReadError {
339    fn from(e: quinn::ReadError) -> Self {
340        Self(e)
341    }
342}
343
344impl Error for ReadError {
345    fn is_timeout(&self) -> bool {
346        matches!(
347            self.0,
348            quinn::ReadError::ConnectionLost(quinn::ConnectionError::TimedOut)
349        )
350    }
351
352    fn err_code(&self) -> Option<u64> {
353        match self.0 {
354            quinn::ReadError::ConnectionLost(quinn::ConnectionError::ApplicationClosed(
355                quinn_proto::ApplicationClose { error_code, .. },
356            )) => Some(error_code.into_inner()),
357            quinn::ReadError::Reset(error_code) => Some(error_code.into_inner()),
358            _ => None,
359        }
360    }
361}
362
363pub struct SendStream<B: Buf> {
364    stream: quinn::SendStream,
365    writing: Option<WriteBuf<B>>,
366}
367
368impl<B> SendStream<B>
369where
370    B: Buf,
371{
372    fn new(stream: quinn::SendStream) -> SendStream<B> {
373        Self {
374            stream,
375            writing: None,
376        }
377    }
378}
379
380impl<B> quic::SendStream<B> for SendStream<B>
381where
382    B: Buf,
383{
384    type Error = SendStreamError;
385
386    fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
387        if let Some(ref mut data) = self.writing {
388            while data.has_remaining() {
389                match ready!(Pin::new(&mut self.stream).poll_write(cx, data.chunk())) {
390                    Ok(cnt) => data.advance(cnt),
391                    Err(err) => {
392                        // We are forced to use AsyncWrite for now because we cannot store
393                        // the result of a call to:
394                        // quinn::send_stream::write<'a>(&'a mut self, buf: &'a [u8]) -> Write<'a, S>.
395                        //
396                        // This is why we have to unpack the error from io::Error below. This should not
397                        // panic as long as quinn's AsyncWrite impl doesn't change.
398                        return Poll::Ready(Err(SendStreamError::Write(
399                            err.into_inner()
400                                .expect("write stream returned an empty error")
401                                .downcast_ref::<WriteError>()
402                                .expect(
403                                    "write stream returned an error which type is not WriteError",
404                                )
405                                .clone(),
406                        )));
407                    }
408                }
409            }
410        }
411        self.writing = None;
412        Poll::Ready(Ok(()))
413    }
414
415    fn poll_finish(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
416        self.stream.finish().poll_unpin(cx).map_err(Into::into)
417    }
418
419    fn reset(&mut self, reset_code: u64) {
420        let _ = self
421            .stream
422            .reset(VarInt::from_u64(reset_code).unwrap_or(VarInt::MAX));
423    }
424
425    fn send_data<D: Into<WriteBuf<B>>>(&mut self, data: D) -> Result<(), Self::Error> {
426        if self.writing.is_some() {
427            return Err(Self::Error::NotReady);
428        }
429        self.writing = Some(data.into());
430        Ok(())
431    }
432
433    fn id(&self) -> StreamId {
434        self.stream.id().0.try_into().expect("invalid stream id")
435    }
436}
437
438#[derive(Debug)]
439pub enum SendStreamError {
440    Write(WriteError),
441    NotReady,
442}
443
444impl std::error::Error for SendStreamError {}
445
446impl Display for SendStreamError {
447    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
448        write!(f, "{:?}", self)
449    }
450}
451
452impl From<WriteError> for SendStreamError {
453    fn from(e: WriteError) -> Self {
454        Self::Write(e)
455    }
456}
457
458impl Error for SendStreamError {
459    fn is_timeout(&self) -> bool {
460        match self {
461            Self::Write(quinn::WriteError::ConnectionLost(quinn::ConnectionError::TimedOut)) => {
462                true
463            }
464            _ => false,
465        }
466    }
467
468    fn err_code(&self) -> Option<u64> {
469        match self {
470            Self::Write(quinn::WriteError::Stopped(error_code)) => Some(error_code.into_inner()),
471            Self::Write(quinn::WriteError::ConnectionLost(
472                quinn::ConnectionError::ApplicationClosed(quinn_proto::ApplicationClose {
473                    error_code,
474                    ..
475                }),
476            )) => Some(error_code.into_inner()),
477            _ => None,
478        }
479    }
480}
481
482impl From<SendStreamError> for Arc<dyn Error> {
483    fn from(e: SendStreamError) -> Self {
484        Arc::new(e)
485    }
486}