1use bytes::Buf;
3use futures::{
4 future::poll_fn,
5 ready,
6 stream::{self},
7 Stream, StreamExt,
8};
9use h3::{
10 error::Code,
11 quic::{self, ConnectionErrorIncoming, StreamErrorIncoming, StreamId, WriteBuf},
12};
13pub use msquic_async;
14pub use msquic_async::msquic;
15use std::pin::Pin;
16use std::sync::Arc;
17use std::task::{self, Poll};
18use tokio_util::sync::ReusableBoxFuture;
19#[cfg(feature = "tracing")]
20use tracing::instrument;
21
// The `datagram` submodule is compiled only when the `datagram` cargo feature
// is enabled; its contents live in a separate file.
#[cfg(feature = "datagram")]
pub mod datagram;
24
/// A pinned, boxed, type-erased [`Stream`] yielding `T` that is both `Send`
/// and `Sync`.
type BoxStreamSync<'a, T> = Pin<Box<dyn Stream<Item = T> + Sync + Send + 'a>>;
27
/// Wrapper around an [`msquic_async::Connection`] on which this file
/// implements the `h3::quic` connection traits.
pub struct Connection {
    /// Handle to the underlying msquic-async connection.
    conn: msquic_async::Connection,
    /// Stream yielding each accepted inbound bidirectional stream (or the
    /// error produced while accepting it).
    incoming: BoxStreamSync<'static, Result<msquic_async::Stream, msquic_async::StreamStartError>>,
    /// Stream yielding locally opened bidirectional streams; `None` until
    /// `poll_open_bidi` is called for the first time.
    opening: Option<
        BoxStreamSync<'static, Result<msquic_async::Stream, msquic_async::StreamStartError>>,
    >,
    /// Stream yielding each accepted inbound unidirectional stream (or the
    /// error produced while accepting it).
    incoming_uni:
        BoxStreamSync<'static, Result<msquic_async::ReadStream, msquic_async::StreamStartError>>,
    /// Stream yielding locally opened unidirectional streams; `None` until
    /// `poll_open_send` is called for the first time.
    opening_uni: Option<
        BoxStreamSync<'static, Result<msquic_async::Stream, msquic_async::StreamStartError>>,
    >,
}
43
44impl Connection {
45 pub fn new(conn: msquic_async::Connection) -> Self {
47 Self {
48 conn: conn.clone(),
49 incoming: Box::pin(stream::unfold(conn.clone(), |conn| async {
50 Some((conn.accept_inbound_stream().await, conn))
51 })),
52 opening: None,
53 incoming_uni: Box::pin(stream::unfold(conn.clone(), |conn| async {
54 Some((conn.accept_inbound_uni_stream().await, conn))
55 })),
56 opening_uni: None,
57 }
58 }
59}
60
61fn convert_connection_error(e: msquic_async::ConnectionError) -> ConnectionErrorIncoming {
62 match e {
63 msquic_async::ConnectionError::ShutdownByPeer(error_code) => {
64 ConnectionErrorIncoming::ApplicationClose { error_code }
65 }
66 msquic_async::ConnectionError::ShutdownByTransport(status, code) => {
67 if matches!(
68 status.try_as_status_code().unwrap(),
69 msquic::StatusCode::QUIC_STATUS_CONNECTION_TIMEOUT
70 | msquic::StatusCode::QUIC_STATUS_CONNECTION_IDLE
71 ) {
72 ConnectionErrorIncoming::Timeout
73 } else {
74 ConnectionErrorIncoming::Undefined(Arc::new(
75 msquic_async::ConnectionError::ShutdownByTransport(status, code),
76 ))
77 }
78 }
79
80 error @ msquic_async::ConnectionError::ShutdownByLocal
81 | error @ msquic_async::ConnectionError::ConnectionClosed
82 | error @ msquic_async::ConnectionError::SslKeyLogFileAlreadySet
83 | error @ msquic_async::ConnectionError::OtherError(_) => {
84 ConnectionErrorIncoming::Undefined(Arc::new(error))
85 }
86 }
87}
88
89fn convert_start_error(e: msquic_async::StreamStartError) -> ConnectionErrorIncoming {
90 match e {
91 msquic_async::StreamStartError::ConnectionLost(error) => convert_connection_error(error),
92
93 error @ msquic_async::StreamStartError::ConnectionNotStarted
94 | error @ msquic_async::StreamStartError::LimitReached
95 | error @ msquic_async::StreamStartError::OtherError(_) => {
96 ConnectionErrorIncoming::Undefined(Arc::new(error))
97 }
98 }
99}
100
101impl<B> quic::Connection<B> for Connection
102where
103 B: Buf,
104{
105 type RecvStream = RecvStream;
106 type OpenStreams = OpenStreams;
107
108 #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
109 fn poll_accept_bidi(
110 &mut self,
111 cx: &mut task::Context<'_>,
112 ) -> Poll<Result<Self::BidiStream, ConnectionErrorIncoming>> {
113 let stream = ready!(self.incoming.poll_next_unpin(cx))
114 .expect("self.incoming BoxStream never returns None")
115 .map_err(convert_start_error)?;
116 if let (Some(read), Some(write)) = stream.split() {
117 Poll::Ready(Ok(Self::BidiStream {
118 send: Self::SendStream::new(write),
119 recv: RecvStream::new(read),
120 }))
121 } else {
122 unreachable!("msquic-async should always return a bidirectional stream");
123 }
124 }
125
126 #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
127 fn poll_accept_recv(
128 &mut self,
129 cx: &mut task::Context<'_>,
130 ) -> Poll<Result<Self::RecvStream, ConnectionErrorIncoming>> {
131 let recv = ready!(self.incoming_uni.poll_next_unpin(cx))
132 .expect("self.incoming_uni BoxStream never returns None")
133 .map_err(convert_start_error)?;
134 Poll::Ready(Ok(Self::RecvStream::new(recv)))
135 }
136
137 fn opener(&self) -> Self::OpenStreams {
138 OpenStreams {
139 conn: self.conn.clone(),
140 opening: None,
141 opening_uni: None,
142 }
143 }
144}
145
146impl<B> quic::OpenStreams<B> for Connection
147where
148 B: Buf,
149{
150 type SendStream = SendStream<B>;
151 type BidiStream = BidiStream<B>;
152
153 #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
154 fn poll_open_bidi(
155 &mut self,
156 cx: &mut task::Context<'_>,
157 ) -> Poll<Result<Self::BidiStream, StreamErrorIncoming>> {
158 if self.opening.is_none() {
159 self.opening = Some(Box::pin(stream::unfold(self.conn.clone(), |conn| async {
160 Some((
161 conn.clone()
162 .open_outbound_stream(msquic_async::StreamType::Bidirectional, false)
163 .await,
164 conn,
165 ))
166 })));
167 }
168
169 let stream = ready!(self.opening.as_mut().unwrap().poll_next_unpin(cx))
170 .unwrap()
171 .map_err(|e| StreamErrorIncoming::ConnectionErrorIncoming {
172 connection_error: convert_start_error(e),
173 })?;
174 if let (Some(read), Some(write)) = stream.split() {
175 Poll::Ready(Ok(Self::BidiStream {
176 send: Self::SendStream::new(write),
177 recv: RecvStream::new(read),
178 }))
179 } else {
180 unreachable!("msquic-async should always return a bidirectional stream");
181 }
182 }
183
184 #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
185 fn poll_open_send(
186 &mut self,
187 cx: &mut task::Context<'_>,
188 ) -> Poll<Result<Self::SendStream, StreamErrorIncoming>> {
189 if self.opening_uni.is_none() {
190 self.opening_uni = Some(Box::pin(stream::unfold(self.conn.clone(), |conn| async {
191 Some((
192 conn.open_outbound_stream(msquic_async::StreamType::Unidirectional, false)
193 .await,
194 conn,
195 ))
196 })));
197 }
198
199 let stream = ready!(self.opening_uni.as_mut().unwrap().poll_next_unpin(cx))
200 .unwrap()
201 .map_err(|e| StreamErrorIncoming::ConnectionErrorIncoming {
202 connection_error: convert_start_error(e),
203 })?;
204 if let (None, Some(write)) = stream.split() {
205 Poll::Ready(Ok(Self::SendStream::new(write)))
206 } else {
207 unreachable!("msquic-async should always return a unidirectional stream");
208 }
209 }
210
211 #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
212 fn close(&mut self, code: Code, _reason: &[u8]) {
213 self.conn.shutdown(code.value()).ok();
214 }
215}
216
/// Handle used to open outbound streams on an [`msquic_async::Connection`].
///
/// Values are created by `Connection::opener` and by `Clone`; both start
/// with no open operation in progress.
pub struct OpenStreams {
    /// Handle to the underlying msquic-async connection.
    conn: msquic_async::Connection,
    /// Stream yielding locally opened bidirectional streams; `None` until
    /// `poll_open_bidi` is called for the first time.
    opening: Option<
        BoxStreamSync<'static, Result<msquic_async::Stream, msquic_async::StreamStartError>>,
    >,
    /// Stream yielding locally opened unidirectional streams; `None` until
    /// `poll_open_send` is called for the first time.
    opening_uni: Option<
        BoxStreamSync<'static, Result<msquic_async::Stream, msquic_async::StreamStartError>>,
    >,
}
230
231impl<B> quic::OpenStreams<B> for OpenStreams
232where
233 B: Buf,
234{
235 type SendStream = SendStream<B>;
236 type BidiStream = BidiStream<B>;
237
238 #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
239 fn poll_open_bidi(
240 &mut self,
241 cx: &mut task::Context<'_>,
242 ) -> Poll<Result<Self::BidiStream, StreamErrorIncoming>> {
243 if self.opening.is_none() {
244 self.opening = Some(Box::pin(stream::unfold(self.conn.clone(), |conn| async {
245 Some((
246 conn.open_outbound_stream(msquic_async::StreamType::Bidirectional, false)
247 .await,
248 conn,
249 ))
250 })));
251 }
252
253 let stream = ready!(self.opening.as_mut().unwrap().poll_next_unpin(cx))
254 .unwrap()
255 .map_err(|e| StreamErrorIncoming::ConnectionErrorIncoming {
256 connection_error: convert_start_error(e),
257 })?;
258 if let (Some(read), Some(write)) = stream.split() {
259 Poll::Ready(Ok(Self::BidiStream {
260 send: Self::SendStream::new(write),
261 recv: RecvStream::new(read),
262 }))
263 } else {
264 unreachable!("msquic-async should always return a bidirectional stream");
265 }
266 }
267
268 #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
269 fn poll_open_send(
270 &mut self,
271 cx: &mut task::Context<'_>,
272 ) -> Poll<Result<Self::SendStream, StreamErrorIncoming>> {
273 if self.opening_uni.is_none() {
274 self.opening_uni = Some(Box::pin(stream::unfold(self.conn.clone(), |conn| async {
275 Some((
276 conn.open_outbound_stream(msquic_async::StreamType::Unidirectional, false)
277 .await,
278 conn,
279 ))
280 })));
281 }
282
283 let stream = ready!(self.opening_uni.as_mut().unwrap().poll_next_unpin(cx))
284 .unwrap()
285 .map_err(|e| StreamErrorIncoming::ConnectionErrorIncoming {
286 connection_error: convert_start_error(e),
287 })?;
288 if let (None, Some(write)) = stream.split() {
289 Poll::Ready(Ok(Self::SendStream::new(write)))
290 } else {
291 unreachable!("msquic-async should always return a unidirectional stream");
292 }
293 }
294
295 #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
296 fn close(&mut self, code: Code, _reason: &[u8]) {
297 self.conn.shutdown(code.value()).ok();
298 }
299}
300
301impl Clone for OpenStreams {
302 fn clone(&self) -> Self {
303 Self {
304 conn: self.conn.clone(),
305 opening: None,
306 opening_uni: None,
307 }
308 }
309}
310
/// A bidirectional stream, stored as its separate send and receive halves.
pub struct BidiStream<B>
where
    B: Buf,
{
    /// Send half of the stream.
    send: SendStream<B>,
    /// Receive half of the stream.
    recv: RecvStream,
}
322
323impl<B> quic::BidiStream<B> for BidiStream<B>
324where
325 B: Buf,
326{
327 type SendStream = SendStream<B>;
328 type RecvStream = RecvStream;
329
330 fn split(self) -> (Self::SendStream, Self::RecvStream) {
331 (self.send, self.recv)
332 }
333}
334
335impl<B: Buf> quic::RecvStream for BidiStream<B> {
336 type Buf = msquic_async::StreamRecvBuffer;
337
338 fn poll_data(
339 &mut self,
340 cx: &mut task::Context<'_>,
341 ) -> Poll<Result<Option<Self::Buf>, StreamErrorIncoming>> {
342 self.recv.poll_data(cx)
343 }
344
345 fn stop_sending(&mut self, error_code: u64) {
346 self.recv.stop_sending(error_code)
347 }
348
349 fn recv_id(&self) -> StreamId {
350 self.recv.recv_id()
351 }
352}
353
354impl<B> quic::SendStream<B> for BidiStream<B>
355where
356 B: Buf,
357{
358 fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), StreamErrorIncoming>> {
359 self.send.poll_ready(cx)
360 }
361
362 fn poll_finish(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), StreamErrorIncoming>> {
363 self.send.poll_finish(cx)
364 }
365
366 fn reset(&mut self, reset_code: u64) {
367 self.send.reset(reset_code)
368 }
369
370 fn send_data<D: Into<WriteBuf<B>>>(&mut self, data: D) -> Result<(), StreamErrorIncoming> {
371 self.send.send_data(data)
372 }
373
374 fn send_id(&self) -> StreamId {
375 self.send.send_id()
376 }
377}
378impl<B> quic::SendStreamUnframed<B> for BidiStream<B>
379where
380 B: Buf,
381{
382 fn poll_send<D: Buf>(
383 &mut self,
384 cx: &mut task::Context<'_>,
385 buf: &mut D,
386 ) -> Poll<Result<usize, StreamErrorIncoming>> {
387 self.send.poll_send(cx, buf)
388 }
389}
390
/// Receive side of a stream, wrapping an [`msquic_async::ReadStream`].
pub struct RecvStream {
    /// The read stream. It is `None` while it has been moved into
    /// `read_chunk_fut` for a read that has not completed yet.
    stream: Option<msquic_async::ReadStream>,
    /// Reusable boxed future performing the current chunk read; it hands the
    /// stream back together with the read result.
    read_chunk_fut: ReadChunkFuture,
}
398
/// Reusable boxed future that resolves to the read stream paired with the
/// result of reading one chunk from it.
type ReadChunkFuture = ReusableBoxFuture<
    'static,
    (
        msquic_async::ReadStream,
        Result<Option<msquic_async::StreamRecvBuffer>, msquic_async::ReadError>,
    ),
>;
406
407impl RecvStream {
408 fn new(stream: msquic_async::ReadStream) -> Self {
409 Self {
410 stream: Some(stream),
411 read_chunk_fut: ReusableBoxFuture::new(async { unreachable!() }),
413 }
414 }
415}
416
417impl quic::RecvStream for RecvStream {
418 type Buf = msquic_async::StreamRecvBuffer;
419
420 #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
421 fn poll_data(
422 &mut self,
423 cx: &mut task::Context<'_>,
424 ) -> Poll<Result<Option<Self::Buf>, StreamErrorIncoming>> {
425 if let Some(stream) = self.stream.take() {
426 self.read_chunk_fut.set(async move {
427 let chunk = poll_fn(|cx| stream.poll_read_chunk(cx)).await;
428 (stream, chunk)
429 })
430 };
431
432 let (stream, chunk) = ready!(self.read_chunk_fut.poll(cx));
433 self.stream = Some(stream);
434 let chunk = chunk
435 .map_err(convert_read_error_to_stream_error)?
436 .filter(|x| !x.is_empty() || !x.fin());
437 Poll::Ready(Ok(chunk))
438 }
439
440 #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
441 fn stop_sending(&mut self, error_code: u64) {
442 self.stream.as_mut().unwrap().abort_read(error_code).ok();
443 }
444
445 #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
446 fn recv_id(&self) -> StreamId {
447 self.stream
448 .as_ref()
449 .unwrap()
450 .id()
451 .expect("id")
452 .try_into()
453 .expect("invalid stream id")
454 }
455}
456
457fn convert_read_error_to_stream_error(error: msquic_async::ReadError) -> StreamErrorIncoming {
458 match error {
459 msquic_async::ReadError::Reset(error_code) => {
460 StreamErrorIncoming::StreamTerminated { error_code }
461 }
462 msquic_async::ReadError::ConnectionLost(connection_error) => {
463 StreamErrorIncoming::ConnectionErrorIncoming {
464 connection_error: convert_connection_error(connection_error),
465 }
466 }
467 error @ msquic_async::ReadError::Closed
468 | error @ msquic_async::ReadError::OtherError(_) => {
469 StreamErrorIncoming::Unknown(Box::new(error))
470 }
471 }
472}
473
/// Send side of a stream, wrapping an [`msquic_async::WriteStream`].
pub struct SendStream<B: Buf> {
    /// The write stream. It is `None` while it has been moved into
    /// `write_fut` for a write that has not completed yet.
    stream: Option<msquic_async::WriteStream>,
    /// Data accepted by `send_data` that `poll_ready` has not finished
    /// writing; `None` when nothing is buffered.
    writing: Option<WriteBuf<B>>,
    /// Reusable boxed future performing the current write; it hands the
    /// stream back together with the write result.
    write_fut: WriteFuture,
}
482
/// Reusable boxed future that resolves to the write stream paired with the
/// result of one write (the number of bytes written, or the write error).
type WriteFuture = ReusableBoxFuture<
    'static,
    (
        msquic_async::WriteStream,
        Result<usize, msquic_async::WriteError>,
    ),
>;
490
491impl<B> SendStream<B>
492where
493 B: Buf,
494{
495 fn new(stream: msquic_async::WriteStream) -> SendStream<B> {
496 Self {
497 stream: Some(stream),
498 writing: None,
499 write_fut: ReusableBoxFuture::new(async { unreachable!() }),
500 }
501 }
502}
503
504impl<B> quic::SendStream<B> for SendStream<B>
505where
506 B: Buf,
507{
508 #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
509 fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), StreamErrorIncoming>> {
510 if let Some(ref mut data) = self.writing {
511 while data.has_remaining() {
512 if let Some(mut stream) = self.stream.take() {
513 let chunk = data.chunk().to_owned(); self.write_fut.set(async move {
515 let ret = poll_fn(|cx| stream.poll_write(cx, &chunk, false)).await;
516 (stream, ret)
517 });
518 }
519
520 let (stream, res) = ready!(self.write_fut.poll(cx));
521 self.stream = Some(stream);
522 match res {
523 Ok(cnt) => data.advance(cnt),
524 Err(err) => {
525 return Poll::Ready(Err(convert_write_error_to_stream_error(err)));
526 }
527 }
528 }
529 }
530 self.writing = None;
531 Poll::Ready(Ok(()))
532 }
533
534 #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
535 fn poll_finish(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), StreamErrorIncoming>> {
536 self.stream
537 .as_mut()
538 .unwrap()
539 .poll_finish_write(cx)
540 .map_err(convert_write_error_to_stream_error)
541 }
542
543 #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
544 fn reset(&mut self, reset_code: u64) {
545 let _ = self.stream.as_mut().unwrap().abort_write(reset_code);
546 }
547
548 #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
549 fn send_data<D: Into<WriteBuf<B>>>(&mut self, data: D) -> Result<(), StreamErrorIncoming> {
550 if self.writing.is_some() {
551 #[cfg(feature = "tracing")]
555 tracing::error!("send_data called while send stream is not ready");
556 return Err(StreamErrorIncoming::ConnectionErrorIncoming {
557 connection_error: ConnectionErrorIncoming::InternalError(
558 "internal error in the http stack".to_string(),
559 ),
560 });
561 }
562 self.writing = Some(data.into());
563 Ok(())
564 }
565
566 #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
567 fn send_id(&self) -> StreamId {
568 self.stream
569 .as_ref()
570 .unwrap()
571 .id()
572 .expect("id")
573 .try_into()
574 .expect("invalid stream id")
575 }
576}
577
578impl<B> quic::SendStreamUnframed<B> for SendStream<B>
579where
580 B: Buf,
581{
582 #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
583 fn poll_send<D: Buf>(
584 &mut self,
585 cx: &mut task::Context<'_>,
586 buf: &mut D,
587 ) -> Poll<Result<usize, StreamErrorIncoming>> {
588 if self.writing.is_some() {
589 panic!("poll_send called while send stream is not ready")
591 }
592
593 let res = ready!(self
594 .stream
595 .as_mut()
596 .unwrap()
597 .poll_write(cx, buf.chunk(), false));
598 match res {
599 Ok(written) => {
600 buf.advance(written);
601 Poll::Ready(Ok(written))
602 }
603 Err(err) => Poll::Ready(Err(convert_write_error_to_stream_error(err))),
604 }
605 }
606}
607
608fn convert_write_error_to_stream_error(error: msquic_async::WriteError) -> StreamErrorIncoming {
609 match error {
610 msquic_async::WriteError::Stopped(error_code) => {
611 StreamErrorIncoming::StreamTerminated { error_code }
612 }
613 msquic_async::WriteError::ConnectionLost(connection_error) => {
614 StreamErrorIncoming::ConnectionErrorIncoming {
615 connection_error: convert_connection_error(connection_error),
616 }
617 }
618 error @ msquic_async::WriteError::Closed
619 | error @ msquic_async::WriteError::Finished
620 | error @ msquic_async::WriteError::OtherError(_) => {
621 StreamErrorIncoming::Unknown(Box::new(error))
622 }
623 }
624}