Skip to main content

h3_msquic_async/
lib.rs

1/// This file is based on the `lib.rs` from the `h3-quinn` crate.
2use bytes::Buf;
3use futures::{
4    future::poll_fn,
5    ready,
6    stream::{self},
7    Stream, StreamExt,
8};
9use h3::{
10    error::Code,
11    quic::{self, ConnectionErrorIncoming, StreamErrorIncoming, StreamId, WriteBuf},
12};
13pub use msquic_async;
14pub use msquic_async::msquic;
15use std::pin::Pin;
16use std::sync::Arc;
17use std::task::{self, Poll};
18use tokio_util::sync::ReusableBoxFuture;
19#[cfg(feature = "tracing")]
20use tracing::instrument;
21
22#[cfg(feature = "datagram")]
23pub mod datagram;
24
25/// BoxStream with Sync trait
26type BoxStreamSync<'a, T> = Pin<Box<dyn Stream<Item = T> + Sync + Send + 'a>>;
27
28/// A QUIC connection backed by msquic-async
29///
30/// Implements a [`quic::Connection`] backed by a [`msquic_async::Connection`].
31pub struct Connection {
32    conn: msquic_async::Connection,
33    incoming: BoxStreamSync<'static, Result<msquic_async::Stream, msquic_async::StreamStartError>>,
34    opening: Option<
35        BoxStreamSync<'static, Result<msquic_async::Stream, msquic_async::StreamStartError>>,
36    >,
37    incoming_uni:
38        BoxStreamSync<'static, Result<msquic_async::ReadStream, msquic_async::StreamStartError>>,
39    opening_uni: Option<
40        BoxStreamSync<'static, Result<msquic_async::Stream, msquic_async::StreamStartError>>,
41    >,
42}
43
44impl Connection {
45    /// Create a [`Connection`] from a [`msquic_async::Connection`]
46    pub fn new(conn: msquic_async::Connection) -> Self {
47        Self {
48            conn: conn.clone(),
49            incoming: Box::pin(stream::unfold(conn.clone(), |conn| async {
50                Some((conn.accept_inbound_stream().await, conn))
51            })),
52            opening: None,
53            incoming_uni: Box::pin(stream::unfold(conn.clone(), |conn| async {
54                Some((conn.accept_inbound_uni_stream().await, conn))
55            })),
56            opening_uni: None,
57        }
58    }
59}
60
61fn convert_connection_error(e: msquic_async::ConnectionError) -> ConnectionErrorIncoming {
62    match e {
63        msquic_async::ConnectionError::ShutdownByPeer(error_code) => {
64            ConnectionErrorIncoming::ApplicationClose { error_code }
65        }
66        msquic_async::ConnectionError::ShutdownByTransport(status, code) => {
67            if matches!(
68                status.try_as_status_code().unwrap(),
69                msquic::StatusCode::QUIC_STATUS_CONNECTION_TIMEOUT
70                    | msquic::StatusCode::QUIC_STATUS_CONNECTION_IDLE
71            ) {
72                ConnectionErrorIncoming::Timeout
73            } else {
74                ConnectionErrorIncoming::Undefined(Arc::new(
75                    msquic_async::ConnectionError::ShutdownByTransport(status, code),
76                ))
77            }
78        }
79
80        error @ msquic_async::ConnectionError::ShutdownByLocal
81        | error @ msquic_async::ConnectionError::ConnectionClosed
82        | error @ msquic_async::ConnectionError::SslKeyLogFileAlreadySet
83        | error @ msquic_async::ConnectionError::OtherError(_) => {
84            ConnectionErrorIncoming::Undefined(Arc::new(error))
85        }
86    }
87}
88
89fn convert_start_error(e: msquic_async::StreamStartError) -> ConnectionErrorIncoming {
90    match e {
91        msquic_async::StreamStartError::ConnectionLost(error) => convert_connection_error(error),
92
93        error @ msquic_async::StreamStartError::ConnectionNotStarted
94        | error @ msquic_async::StreamStartError::LimitReached
95        | error @ msquic_async::StreamStartError::OtherError(_) => {
96            ConnectionErrorIncoming::Undefined(Arc::new(error))
97        }
98    }
99}
100
101impl<B> quic::Connection<B> for Connection
102where
103    B: Buf,
104{
105    type RecvStream = RecvStream;
106    type OpenStreams = OpenStreams;
107
108    #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
109    fn poll_accept_bidi(
110        &mut self,
111        cx: &mut task::Context<'_>,
112    ) -> Poll<Result<Self::BidiStream, ConnectionErrorIncoming>> {
113        let stream = ready!(self.incoming.poll_next_unpin(cx))
114            .expect("self.incoming BoxStream never returns None")
115            .map_err(convert_start_error)?;
116        if let (Some(read), Some(write)) = stream.split() {
117            Poll::Ready(Ok(Self::BidiStream {
118                send: Self::SendStream::new(write),
119                recv: RecvStream::new(read),
120            }))
121        } else {
122            unreachable!("msquic-async should always return a bidirectional stream");
123        }
124    }
125
126    #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
127    fn poll_accept_recv(
128        &mut self,
129        cx: &mut task::Context<'_>,
130    ) -> Poll<Result<Self::RecvStream, ConnectionErrorIncoming>> {
131        let recv = ready!(self.incoming_uni.poll_next_unpin(cx))
132            .expect("self.incoming_uni BoxStream never returns None")
133            .map_err(convert_start_error)?;
134        Poll::Ready(Ok(Self::RecvStream::new(recv)))
135    }
136
137    fn opener(&self) -> Self::OpenStreams {
138        OpenStreams {
139            conn: self.conn.clone(),
140            opening: None,
141            opening_uni: None,
142        }
143    }
144}
145
146impl<B> quic::OpenStreams<B> for Connection
147where
148    B: Buf,
149{
150    type SendStream = SendStream<B>;
151    type BidiStream = BidiStream<B>;
152
153    #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
154    fn poll_open_bidi(
155        &mut self,
156        cx: &mut task::Context<'_>,
157    ) -> Poll<Result<Self::BidiStream, StreamErrorIncoming>> {
158        if self.opening.is_none() {
159            self.opening = Some(Box::pin(stream::unfold(self.conn.clone(), |conn| async {
160                Some((
161                    conn.clone()
162                        .open_outbound_stream(msquic_async::StreamType::Bidirectional, false)
163                        .await,
164                    conn,
165                ))
166            })));
167        }
168
169        let stream = ready!(self.opening.as_mut().unwrap().poll_next_unpin(cx))
170            .unwrap()
171            .map_err(|e| StreamErrorIncoming::ConnectionErrorIncoming {
172                connection_error: convert_start_error(e),
173            })?;
174        if let (Some(read), Some(write)) = stream.split() {
175            Poll::Ready(Ok(Self::BidiStream {
176                send: Self::SendStream::new(write),
177                recv: RecvStream::new(read),
178            }))
179        } else {
180            unreachable!("msquic-async should always return a bidirectional stream");
181        }
182    }
183
184    #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
185    fn poll_open_send(
186        &mut self,
187        cx: &mut task::Context<'_>,
188    ) -> Poll<Result<Self::SendStream, StreamErrorIncoming>> {
189        if self.opening_uni.is_none() {
190            self.opening_uni = Some(Box::pin(stream::unfold(self.conn.clone(), |conn| async {
191                Some((
192                    conn.open_outbound_stream(msquic_async::StreamType::Unidirectional, false)
193                        .await,
194                    conn,
195                ))
196            })));
197        }
198
199        let stream = ready!(self.opening_uni.as_mut().unwrap().poll_next_unpin(cx))
200            .unwrap()
201            .map_err(|e| StreamErrorIncoming::ConnectionErrorIncoming {
202                connection_error: convert_start_error(e),
203            })?;
204        if let (None, Some(write)) = stream.split() {
205            Poll::Ready(Ok(Self::SendStream::new(write)))
206        } else {
207            unreachable!("msquic-async should always return a unidirectional stream");
208        }
209    }
210
211    #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
212    fn close(&mut self, code: Code, _reason: &[u8]) {
213        self.conn.shutdown(code.value()).ok();
214    }
215}
216
217/// Stream opener backed by a msquic connection
218///
219/// Implements [`quic::OpenStreams`] using [`msquic_async::Connection`],
220/// [`msquic_async::OpenOutboundStream`].
221pub struct OpenStreams {
222    conn: msquic_async::Connection,
223    opening: Option<
224        BoxStreamSync<'static, Result<msquic_async::Stream, msquic_async::StreamStartError>>,
225    >,
226    opening_uni: Option<
227        BoxStreamSync<'static, Result<msquic_async::Stream, msquic_async::StreamStartError>>,
228    >,
229}
230
231impl<B> quic::OpenStreams<B> for OpenStreams
232where
233    B: Buf,
234{
235    type SendStream = SendStream<B>;
236    type BidiStream = BidiStream<B>;
237
238    #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
239    fn poll_open_bidi(
240        &mut self,
241        cx: &mut task::Context<'_>,
242    ) -> Poll<Result<Self::BidiStream, StreamErrorIncoming>> {
243        if self.opening.is_none() {
244            self.opening = Some(Box::pin(stream::unfold(self.conn.clone(), |conn| async {
245                Some((
246                    conn.open_outbound_stream(msquic_async::StreamType::Bidirectional, false)
247                        .await,
248                    conn,
249                ))
250            })));
251        }
252
253        let stream = ready!(self.opening.as_mut().unwrap().poll_next_unpin(cx))
254            .unwrap()
255            .map_err(|e| StreamErrorIncoming::ConnectionErrorIncoming {
256                connection_error: convert_start_error(e),
257            })?;
258        if let (Some(read), Some(write)) = stream.split() {
259            Poll::Ready(Ok(Self::BidiStream {
260                send: Self::SendStream::new(write),
261                recv: RecvStream::new(read),
262            }))
263        } else {
264            unreachable!("msquic-async should always return a bidirectional stream");
265        }
266    }
267
268    #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
269    fn poll_open_send(
270        &mut self,
271        cx: &mut task::Context<'_>,
272    ) -> Poll<Result<Self::SendStream, StreamErrorIncoming>> {
273        if self.opening_uni.is_none() {
274            self.opening_uni = Some(Box::pin(stream::unfold(self.conn.clone(), |conn| async {
275                Some((
276                    conn.open_outbound_stream(msquic_async::StreamType::Unidirectional, false)
277                        .await,
278                    conn,
279                ))
280            })));
281        }
282
283        let stream = ready!(self.opening_uni.as_mut().unwrap().poll_next_unpin(cx))
284            .unwrap()
285            .map_err(|e| StreamErrorIncoming::ConnectionErrorIncoming {
286                connection_error: convert_start_error(e),
287            })?;
288        if let (None, Some(write)) = stream.split() {
289            Poll::Ready(Ok(Self::SendStream::new(write)))
290        } else {
291            unreachable!("msquic-async should always return a unidirectional stream");
292        }
293    }
294
295    #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
296    fn close(&mut self, code: Code, _reason: &[u8]) {
297        self.conn.shutdown(code.value()).ok();
298    }
299}
300
301impl Clone for OpenStreams {
302    fn clone(&self) -> Self {
303        Self {
304            conn: self.conn.clone(),
305            opening: None,
306            opening_uni: None,
307        }
308    }
309}
310
311/// msquic-backed bidirectional stream
312///
313/// Implements [`quic::BidiStream`] which allows the stream to be split
314/// into two structs each implementing one direction.
315pub struct BidiStream<B>
316where
317    B: Buf,
318{
319    send: SendStream<B>,
320    recv: RecvStream,
321}
322
323impl<B> quic::BidiStream<B> for BidiStream<B>
324where
325    B: Buf,
326{
327    type SendStream = SendStream<B>;
328    type RecvStream = RecvStream;
329
330    fn split(self) -> (Self::SendStream, Self::RecvStream) {
331        (self.send, self.recv)
332    }
333}
334
335impl<B: Buf> quic::RecvStream for BidiStream<B> {
336    type Buf = msquic_async::StreamRecvBuffer;
337
338    fn poll_data(
339        &mut self,
340        cx: &mut task::Context<'_>,
341    ) -> Poll<Result<Option<Self::Buf>, StreamErrorIncoming>> {
342        self.recv.poll_data(cx)
343    }
344
345    fn stop_sending(&mut self, error_code: u64) {
346        self.recv.stop_sending(error_code)
347    }
348
349    fn recv_id(&self) -> StreamId {
350        self.recv.recv_id()
351    }
352}
353
354impl<B> quic::SendStream<B> for BidiStream<B>
355where
356    B: Buf,
357{
358    fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), StreamErrorIncoming>> {
359        self.send.poll_ready(cx)
360    }
361
362    fn poll_finish(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), StreamErrorIncoming>> {
363        self.send.poll_finish(cx)
364    }
365
366    fn reset(&mut self, reset_code: u64) {
367        self.send.reset(reset_code)
368    }
369
370    fn send_data<D: Into<WriteBuf<B>>>(&mut self, data: D) -> Result<(), StreamErrorIncoming> {
371        self.send.send_data(data)
372    }
373
374    fn send_id(&self) -> StreamId {
375        self.send.send_id()
376    }
377}
378impl<B> quic::SendStreamUnframed<B> for BidiStream<B>
379where
380    B: Buf,
381{
382    fn poll_send<D: Buf>(
383        &mut self,
384        cx: &mut task::Context<'_>,
385        buf: &mut D,
386    ) -> Poll<Result<usize, StreamErrorIncoming>> {
387        self.send.poll_send(cx, buf)
388    }
389}
390
391/// msquic-backed receive stream
392///
393/// Implements a [`quic::RecvStream`] backed by a [`msquic_async::ReadStream`].
394pub struct RecvStream {
395    stream: Option<msquic_async::ReadStream>,
396    read_chunk_fut: ReadChunkFuture,
397}
398
399type ReadChunkFuture = ReusableBoxFuture<
400    'static,
401    (
402        msquic_async::ReadStream,
403        Result<Option<msquic_async::StreamRecvBuffer>, msquic_async::ReadError>,
404    ),
405>;
406
407impl RecvStream {
408    fn new(stream: msquic_async::ReadStream) -> Self {
409        Self {
410            stream: Some(stream),
411            // Should only allocate once the first time it's used
412            read_chunk_fut: ReusableBoxFuture::new(async { unreachable!() }),
413        }
414    }
415}
416
417impl quic::RecvStream for RecvStream {
418    type Buf = msquic_async::StreamRecvBuffer;
419
420    #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
421    fn poll_data(
422        &mut self,
423        cx: &mut task::Context<'_>,
424    ) -> Poll<Result<Option<Self::Buf>, StreamErrorIncoming>> {
425        if let Some(stream) = self.stream.take() {
426            self.read_chunk_fut.set(async move {
427                let chunk = poll_fn(|cx| stream.poll_read_chunk(cx)).await;
428                (stream, chunk)
429            })
430        };
431
432        let (stream, chunk) = ready!(self.read_chunk_fut.poll(cx));
433        self.stream = Some(stream);
434        let chunk = chunk
435            .map_err(convert_read_error_to_stream_error)?
436            .filter(|x| !x.is_empty() || !x.fin());
437        Poll::Ready(Ok(chunk))
438    }
439
440    #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
441    fn stop_sending(&mut self, error_code: u64) {
442        self.stream.as_mut().unwrap().abort_read(error_code).ok();
443    }
444
445    #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
446    fn recv_id(&self) -> StreamId {
447        self.stream
448            .as_ref()
449            .unwrap()
450            .id()
451            .expect("id")
452            .try_into()
453            .expect("invalid stream id")
454    }
455}
456
457fn convert_read_error_to_stream_error(error: msquic_async::ReadError) -> StreamErrorIncoming {
458    match error {
459        msquic_async::ReadError::Reset(error_code) => {
460            StreamErrorIncoming::StreamTerminated { error_code }
461        }
462        msquic_async::ReadError::ConnectionLost(connection_error) => {
463            StreamErrorIncoming::ConnectionErrorIncoming {
464                connection_error: convert_connection_error(connection_error),
465            }
466        }
467        error @ msquic_async::ReadError::Closed
468        | error @ msquic_async::ReadError::OtherError(_) => {
469            StreamErrorIncoming::Unknown(Box::new(error))
470        }
471    }
472}
473
474/// msquic-async-backed send stream
475///
476/// Implements a [`quic::SendStream`] backed by a [`msquic_async::WriteStream`].
477pub struct SendStream<B: Buf> {
478    stream: Option<msquic_async::WriteStream>,
479    writing: Option<WriteBuf<B>>,
480    write_fut: WriteFuture,
481}
482
483type WriteFuture = ReusableBoxFuture<
484    'static,
485    (
486        msquic_async::WriteStream,
487        Result<usize, msquic_async::WriteError>,
488    ),
489>;
490
491impl<B> SendStream<B>
492where
493    B: Buf,
494{
495    fn new(stream: msquic_async::WriteStream) -> SendStream<B> {
496        Self {
497            stream: Some(stream),
498            writing: None,
499            write_fut: ReusableBoxFuture::new(async { unreachable!() }),
500        }
501    }
502}
503
504impl<B> quic::SendStream<B> for SendStream<B>
505where
506    B: Buf,
507{
508    #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
509    fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), StreamErrorIncoming>> {
510        if let Some(ref mut data) = self.writing {
511            while data.has_remaining() {
512                if let Some(mut stream) = self.stream.take() {
513                    let chunk = data.chunk().to_owned(); // FIXME - avoid copy
514                    self.write_fut.set(async move {
515                        let ret = poll_fn(|cx| stream.poll_write(cx, &chunk, false)).await;
516                        (stream, ret)
517                    });
518                }
519
520                let (stream, res) = ready!(self.write_fut.poll(cx));
521                self.stream = Some(stream);
522                match res {
523                    Ok(cnt) => data.advance(cnt),
524                    Err(err) => {
525                        return Poll::Ready(Err(convert_write_error_to_stream_error(err)));
526                    }
527                }
528            }
529        }
530        self.writing = None;
531        Poll::Ready(Ok(()))
532    }
533
534    #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
535    fn poll_finish(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), StreamErrorIncoming>> {
536        self.stream
537            .as_mut()
538            .unwrap()
539            .poll_finish_write(cx)
540            .map_err(convert_write_error_to_stream_error)
541    }
542
543    #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
544    fn reset(&mut self, reset_code: u64) {
545        let _ = self.stream.as_mut().unwrap().abort_write(reset_code);
546    }
547
548    #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
549    fn send_data<D: Into<WriteBuf<B>>>(&mut self, data: D) -> Result<(), StreamErrorIncoming> {
550        if self.writing.is_some() {
551            // This can only happen if the traits are misused by h3 itself
552            // If this happens log an error and close the connection with H3_INTERNAL_ERROR
553
554            #[cfg(feature = "tracing")]
555            tracing::error!("send_data called while send stream is not ready");
556            return Err(StreamErrorIncoming::ConnectionErrorIncoming {
557                connection_error: ConnectionErrorIncoming::InternalError(
558                    "internal error in the http stack".to_string(),
559                ),
560            });
561        }
562        self.writing = Some(data.into());
563        Ok(())
564    }
565
566    #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
567    fn send_id(&self) -> StreamId {
568        self.stream
569            .as_ref()
570            .unwrap()
571            .id()
572            .expect("id")
573            .try_into()
574            .expect("invalid stream id")
575    }
576}
577
578impl<B> quic::SendStreamUnframed<B> for SendStream<B>
579where
580    B: Buf,
581{
582    #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
583    fn poll_send<D: Buf>(
584        &mut self,
585        cx: &mut task::Context<'_>,
586        buf: &mut D,
587    ) -> Poll<Result<usize, StreamErrorIncoming>> {
588        if self.writing.is_some() {
589            // This signifies a bug in implementation
590            panic!("poll_send called while send stream is not ready")
591        }
592
593        let res = ready!(self
594            .stream
595            .as_mut()
596            .unwrap()
597            .poll_write(cx, buf.chunk(), false));
598        match res {
599            Ok(written) => {
600                buf.advance(written);
601                Poll::Ready(Ok(written))
602            }
603            Err(err) => Poll::Ready(Err(convert_write_error_to_stream_error(err))),
604        }
605    }
606}
607
608fn convert_write_error_to_stream_error(error: msquic_async::WriteError) -> StreamErrorIncoming {
609    match error {
610        msquic_async::WriteError::Stopped(error_code) => {
611            StreamErrorIncoming::StreamTerminated { error_code }
612        }
613        msquic_async::WriteError::ConnectionLost(connection_error) => {
614            StreamErrorIncoming::ConnectionErrorIncoming {
615                connection_error: convert_connection_error(connection_error),
616            }
617        }
618        error @ msquic_async::WriteError::Closed
619        | error @ msquic_async::WriteError::Finished
620        | error @ msquic_async::WriteError::OtherError(_) => {
621            StreamErrorIncoming::Unknown(Box::new(error))
622        }
623    }
624}