1#![deny(missing_docs)]
5
6use std::{
7 convert::TryInto,
8 fmt::{self, Display},
9 future::Future,
10 pin::Pin,
11 sync::Arc,
12 task::{self, Poll},
13};
14
15use bytes::{Buf, Bytes, BytesMut};
16
17use futures::{
18 ready,
19 stream::{self},
20 Stream, StreamExt,
21};
22
23#[cfg(feature = "datagram")]
24use h3_datagram::{datagram::Datagram, quic_traits};
25
26pub use quinn::{self, AcceptBi, AcceptUni, Endpoint, OpenBi, OpenUni, VarInt, WriteError};
27use quinn::{ApplicationClose, ClosedStream, ReadDatagram};
28
29use h3::quic::{self, Error, StreamId, WriteBuf};
30use tokio_util::sync::ReusableBoxFuture;
31
32#[cfg(feature = "tracing")]
33use tracing::instrument;
34
35type BoxStreamSync<'a, T> = Pin<Box<dyn Stream<Item = T> + Sync + Send + 'a>>;
37
38pub struct Connection {
42 conn: quinn::Connection,
43 incoming_bi: BoxStreamSync<'static, <AcceptBi<'static> as Future>::Output>,
44 opening_bi: Option<BoxStreamSync<'static, <OpenBi<'static> as Future>::Output>>,
45 incoming_uni: BoxStreamSync<'static, <AcceptUni<'static> as Future>::Output>,
46 opening_uni: Option<BoxStreamSync<'static, <OpenUni<'static> as Future>::Output>>,
47 datagrams: BoxStreamSync<'static, <ReadDatagram<'static> as Future>::Output>,
48}
49
50impl Connection {
51 pub fn new(conn: quinn::Connection) -> Self {
53 Self {
54 conn: conn.clone(),
55 incoming_bi: Box::pin(stream::unfold(conn.clone(), |conn| async {
56 Some((conn.accept_bi().await, conn))
57 })),
58 opening_bi: None,
59 incoming_uni: Box::pin(stream::unfold(conn.clone(), |conn| async {
60 Some((conn.accept_uni().await, conn))
61 })),
62 opening_uni: None,
63 datagrams: Box::pin(stream::unfold(conn, |conn| async {
64 Some((conn.read_datagram().await, conn))
65 })),
66 }
67 }
68}
69
70#[derive(Debug)]
74pub struct ConnectionError(quinn::ConnectionError);
75
76impl std::error::Error for ConnectionError {}
77
78impl fmt::Display for ConnectionError {
79 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
80 self.0.fmt(f)
81 }
82}
83
84impl Error for ConnectionError {
85 fn is_timeout(&self) -> bool {
86 matches!(self.0, quinn::ConnectionError::TimedOut)
87 }
88
89 fn err_code(&self) -> Option<u64> {
90 match self.0 {
91 quinn::ConnectionError::ApplicationClosed(ApplicationClose { error_code, .. }) => {
92 Some(error_code.into_inner())
93 }
94 _ => None,
95 }
96 }
97}
98
99impl From<quinn::ConnectionError> for ConnectionError {
100 fn from(e: quinn::ConnectionError) -> Self {
101 Self(e)
102 }
103}
104
105#[derive(Debug)]
107pub enum SendDatagramError {
108 UnsupportedByPeer,
110 Disabled,
112 TooLarge,
114 ConnectionLost(Box<dyn Error>),
116}
117
118impl fmt::Display for SendDatagramError {
119 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
120 match self {
121 SendDatagramError::UnsupportedByPeer => write!(f, "datagrams not supported by peer"),
122 SendDatagramError::Disabled => write!(f, "datagram support disabled"),
123 SendDatagramError::TooLarge => write!(f, "datagram too large"),
124 SendDatagramError::ConnectionLost(_) => write!(f, "connection lost"),
125 }
126 }
127}
128
129impl std::error::Error for SendDatagramError {}
130
131impl Error for SendDatagramError {
132 fn is_timeout(&self) -> bool {
133 false
134 }
135
136 fn err_code(&self) -> Option<u64> {
137 match self {
138 Self::ConnectionLost(err) => err.err_code(),
139 _ => None,
140 }
141 }
142}
143
144impl From<quinn::SendDatagramError> for SendDatagramError {
145 fn from(value: quinn::SendDatagramError) -> Self {
146 match value {
147 quinn::SendDatagramError::UnsupportedByPeer => Self::UnsupportedByPeer,
148 quinn::SendDatagramError::Disabled => Self::Disabled,
149 quinn::SendDatagramError::TooLarge => Self::TooLarge,
150 quinn::SendDatagramError::ConnectionLost(err) => {
151 Self::ConnectionLost(ConnectionError::from(err).into())
152 }
153 }
154 }
155}
156
157impl<B> quic::Connection<B> for Connection
158where
159 B: Buf,
160{
161 type RecvStream = RecvStream;
162 type OpenStreams = OpenStreams;
163 type AcceptError = ConnectionError;
164
165 #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
166 fn poll_accept_bidi(
167 &mut self,
168 cx: &mut task::Context<'_>,
169 ) -> Poll<Result<Option<Self::BidiStream>, Self::AcceptError>> {
170 let (send, recv) = match ready!(self.incoming_bi.poll_next_unpin(cx)) {
171 Some(x) => x?,
172 None => return Poll::Ready(Ok(None)),
173 };
174 Poll::Ready(Ok(Some(Self::BidiStream {
175 send: Self::SendStream::new(send),
176 recv: Self::RecvStream::new(recv),
177 })))
178 }
179
180 #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
181 fn poll_accept_recv(
182 &mut self,
183 cx: &mut task::Context<'_>,
184 ) -> Poll<Result<Option<Self::RecvStream>, Self::AcceptError>> {
185 let recv = match ready!(self.incoming_uni.poll_next_unpin(cx)) {
186 Some(x) => x?,
187 None => return Poll::Ready(Ok(None)),
188 };
189 Poll::Ready(Ok(Some(Self::RecvStream::new(recv))))
190 }
191
192 fn opener(&self) -> Self::OpenStreams {
193 OpenStreams {
194 conn: self.conn.clone(),
195 opening_bi: None,
196 opening_uni: None,
197 }
198 }
199}
200
201impl<B> quic::OpenStreams<B> for Connection
202where
203 B: Buf,
204{
205 type SendStream = SendStream<B>;
206 type BidiStream = BidiStream<B>;
207 type OpenError = ConnectionError;
208
209 #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
210 fn poll_open_bidi(
211 &mut self,
212 cx: &mut task::Context<'_>,
213 ) -> Poll<Result<Self::BidiStream, Self::OpenError>> {
214 if self.opening_bi.is_none() {
215 self.opening_bi = Some(Box::pin(stream::unfold(self.conn.clone(), |conn| async {
216 Some((conn.clone().open_bi().await, conn))
217 })));
218 }
219
220 let (send, recv) =
221 ready!(self.opening_bi.as_mut().unwrap().poll_next_unpin(cx)).unwrap()?;
222 Poll::Ready(Ok(Self::BidiStream {
223 send: Self::SendStream::new(send),
224 recv: RecvStream::new(recv),
225 }))
226 }
227
228 #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
229 fn poll_open_send(
230 &mut self,
231 cx: &mut task::Context<'_>,
232 ) -> Poll<Result<Self::SendStream, Self::OpenError>> {
233 if self.opening_uni.is_none() {
234 self.opening_uni = Some(Box::pin(stream::unfold(self.conn.clone(), |conn| async {
235 Some((conn.open_uni().await, conn))
236 })));
237 }
238
239 let send = ready!(self.opening_uni.as_mut().unwrap().poll_next_unpin(cx)).unwrap()?;
240 Poll::Ready(Ok(Self::SendStream::new(send)))
241 }
242
243 #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
244 fn close(&mut self, code: h3::error::Code, reason: &[u8]) {
245 self.conn.close(
246 VarInt::from_u64(code.value()).expect("error code VarInt"),
247 reason,
248 );
249 }
250}
251
#[cfg(feature = "datagram")]
impl<B> quic_traits::SendDatagramExt<B> for Connection
where
    B: Buf,
{
    type Error = SendDatagramError;

    /// Encodes `data` into a fresh buffer and sends it as one QUIC datagram.
    #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
    fn send_datagram(&mut self, data: Datagram<B>) -> Result<(), SendDatagramError> {
        let mut encoded = BytesMut::new();
        data.encode(&mut encoded);
        self.conn
            .send_datagram(encoded.freeze())
            .map_err(SendDatagramError::from)
    }
}
269
#[cfg(feature = "datagram")]
impl quic_traits::RecvDatagramExt for Connection {
    type Buf = Bytes;

    type Error = ConnectionError;

    /// Polls for the next datagram received on the connection, as raw bytes.
    #[inline]
    #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
    fn poll_accept_datagram(
        &mut self,
        cx: &mut task::Context<'_>,
    ) -> Poll<Result<Option<Self::Buf>, Self::Error>> {
        let next = ready!(self.datagrams.poll_next_unpin(cx));
        // Option<Result<_, _>> -> Result<Option<_>, _>, wrapping Quinn's error.
        Poll::Ready(next.transpose().map_err(ConnectionError::from))
    }
}
289
290pub struct OpenStreams {
295 conn: quinn::Connection,
296 opening_bi: Option<BoxStreamSync<'static, <OpenBi<'static> as Future>::Output>>,
297 opening_uni: Option<BoxStreamSync<'static, <OpenUni<'static> as Future>::Output>>,
298}
299
300impl<B> quic::OpenStreams<B> for OpenStreams
301where
302 B: Buf,
303{
304 type SendStream = SendStream<B>;
305 type BidiStream = BidiStream<B>;
306 type OpenError = ConnectionError;
307
308 #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
309 fn poll_open_bidi(
310 &mut self,
311 cx: &mut task::Context<'_>,
312 ) -> Poll<Result<Self::BidiStream, Self::OpenError>> {
313 if self.opening_bi.is_none() {
314 self.opening_bi = Some(Box::pin(stream::unfold(self.conn.clone(), |conn| async {
315 Some((conn.open_bi().await, conn))
316 })));
317 }
318
319 let (send, recv) =
320 ready!(self.opening_bi.as_mut().unwrap().poll_next_unpin(cx)).unwrap()?;
321 Poll::Ready(Ok(Self::BidiStream {
322 send: Self::SendStream::new(send),
323 recv: RecvStream::new(recv),
324 }))
325 }
326
327 #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
328 fn poll_open_send(
329 &mut self,
330 cx: &mut task::Context<'_>,
331 ) -> Poll<Result<Self::SendStream, Self::OpenError>> {
332 if self.opening_uni.is_none() {
333 self.opening_uni = Some(Box::pin(stream::unfold(self.conn.clone(), |conn| async {
334 Some((conn.open_uni().await, conn))
335 })));
336 }
337
338 let send = ready!(self.opening_uni.as_mut().unwrap().poll_next_unpin(cx)).unwrap()?;
339 Poll::Ready(Ok(Self::SendStream::new(send)))
340 }
341
342 #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
343 fn close(&mut self, code: h3::error::Code, reason: &[u8]) {
344 self.conn.close(
345 VarInt::from_u64(code.value()).expect("error code VarInt"),
346 reason,
347 );
348 }
349}
350
351impl Clone for OpenStreams {
352 fn clone(&self) -> Self {
353 Self {
354 conn: self.conn.clone(),
355 opening_bi: None,
356 opening_uni: None,
357 }
358 }
359}
360
361pub struct BidiStream<B>
366where
367 B: Buf,
368{
369 send: SendStream<B>,
370 recv: RecvStream,
371}
372
373impl<B> quic::BidiStream<B> for BidiStream<B>
374where
375 B: Buf,
376{
377 type SendStream = SendStream<B>;
378 type RecvStream = RecvStream;
379
380 fn split(self) -> (Self::SendStream, Self::RecvStream) {
381 (self.send, self.recv)
382 }
383}
384
385impl<B: Buf> quic::RecvStream for BidiStream<B> {
386 type Buf = Bytes;
387 type Error = ReadError;
388
389 fn poll_data(
390 &mut self,
391 cx: &mut task::Context<'_>,
392 ) -> Poll<Result<Option<Self::Buf>, Self::Error>> {
393 self.recv.poll_data(cx)
394 }
395
396 fn stop_sending(&mut self, error_code: u64) {
397 self.recv.stop_sending(error_code)
398 }
399
400 fn recv_id(&self) -> StreamId {
401 self.recv.recv_id()
402 }
403}
404
405impl<B> quic::SendStream<B> for BidiStream<B>
406where
407 B: Buf,
408{
409 type Error = SendStreamError;
410
411 fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
412 self.send.poll_ready(cx)
413 }
414
415 fn poll_finish(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
416 self.send.poll_finish(cx)
417 }
418
419 fn reset(&mut self, reset_code: u64) {
420 self.send.reset(reset_code)
421 }
422
423 fn send_data<D: Into<WriteBuf<B>>>(&mut self, data: D) -> Result<(), Self::Error> {
424 self.send.send_data(data)
425 }
426
427 fn send_id(&self) -> StreamId {
428 self.send.send_id()
429 }
430}
431impl<B> quic::SendStreamUnframed<B> for BidiStream<B>
432where
433 B: Buf,
434{
435 fn poll_send<D: Buf>(
436 &mut self,
437 cx: &mut task::Context<'_>,
438 buf: &mut D,
439 ) -> Poll<Result<usize, Self::Error>> {
440 self.send.poll_send(cx, buf)
441 }
442}
443
444pub struct RecvStream {
448 stream: Option<quinn::RecvStream>,
449 read_chunk_fut: ReadChunkFuture,
450}
451
452type ReadChunkFuture = ReusableBoxFuture<
453 'static,
454 (
455 quinn::RecvStream,
456 Result<Option<quinn::Chunk>, quinn::ReadError>,
457 ),
458>;
459
460impl RecvStream {
461 fn new(stream: quinn::RecvStream) -> Self {
462 Self {
463 stream: Some(stream),
464 read_chunk_fut: ReusableBoxFuture::new(async { unreachable!() }),
466 }
467 }
468}
469
470impl quic::RecvStream for RecvStream {
471 type Buf = Bytes;
472 type Error = ReadError;
473
474 #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
475 fn poll_data(
476 &mut self,
477 cx: &mut task::Context<'_>,
478 ) -> Poll<Result<Option<Self::Buf>, Self::Error>> {
479 if let Some(mut stream) = self.stream.take() {
480 self.read_chunk_fut.set(async move {
481 let chunk = stream.read_chunk(usize::MAX, true).await;
482 (stream, chunk)
483 })
484 };
485
486 let (stream, chunk) = ready!(self.read_chunk_fut.poll(cx));
487 self.stream = Some(stream);
488 Poll::Ready(Ok(chunk?.map(|c| c.bytes)))
489 }
490
491 #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
492 fn stop_sending(&mut self, error_code: u64) {
493 self.stream
494 .as_mut()
495 .unwrap()
496 .stop(VarInt::from_u64(error_code).expect("invalid error_code"))
497 .ok();
498 }
499
500 #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
501 fn recv_id(&self) -> StreamId {
502 let num: u64 = self.stream.as_ref().unwrap().id().into();
503
504 num.try_into().expect("invalid stream id")
505 }
506}
507
508#[derive(Debug)]
512pub struct ReadError(quinn::ReadError);
513
514impl From<ReadError> for std::io::Error {
515 fn from(value: ReadError) -> Self {
516 value.0.into()
517 }
518}
519
520impl std::error::Error for ReadError {
521 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
522 self.0.source()
523 }
524}
525
526impl fmt::Display for ReadError {
527 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
528 self.0.fmt(f)
529 }
530}
531
532impl From<ReadError> for Arc<dyn Error> {
533 fn from(e: ReadError) -> Self {
534 Arc::new(e)
535 }
536}
537
538impl From<quinn::ReadError> for ReadError {
539 fn from(e: quinn::ReadError) -> Self {
540 Self(e)
541 }
542}
543
544impl Error for ReadError {
545 fn is_timeout(&self) -> bool {
546 matches!(
547 self.0,
548 quinn::ReadError::ConnectionLost(quinn::ConnectionError::TimedOut)
549 )
550 }
551
552 fn err_code(&self) -> Option<u64> {
553 match self.0 {
554 quinn::ReadError::ConnectionLost(quinn::ConnectionError::ApplicationClosed(
555 ApplicationClose { error_code, .. },
556 )) => Some(error_code.into_inner()),
557 quinn::ReadError::Reset(error_code) => Some(error_code.into_inner()),
558 _ => None,
559 }
560 }
561}
562
563pub struct SendStream<B: Buf> {
567 stream: quinn::SendStream,
568 writing: Option<WriteBuf<B>>,
569}
570
571impl<B> SendStream<B>
572where
573 B: Buf,
574{
575 fn new(stream: quinn::SendStream) -> SendStream<B> {
576 Self {
577 stream: stream,
578 writing: None,
579 }
580 }
581}
582
583impl<B> quic::SendStream<B> for SendStream<B>
584where
585 B: Buf,
586{
587 type Error = SendStreamError;
588
589 #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
590 fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
591 if let Some(ref mut data) = self.writing {
592 while data.has_remaining() {
593 let stream = Pin::new(&mut self.stream);
594 let written = ready!(stream.poll_write(cx, data.chunk()))
595 .map_err(|err| SendStreamError::Write(err))?;
596 data.advance(written);
597 }
598 }
599 self.writing = None;
601 Poll::Ready(Ok(()))
602 }
603
604 #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
605 fn poll_finish(&mut self, _cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
606 Poll::Ready(self.stream.finish().map_err(|e| e.into()))
607 }
608
609 #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
610 fn reset(&mut self, reset_code: u64) {
611 let _ = self
612 .stream
613 .reset(VarInt::from_u64(reset_code).unwrap_or(VarInt::MAX));
614 }
615
616 #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
617 fn send_data<D: Into<WriteBuf<B>>>(&mut self, data: D) -> Result<(), Self::Error> {
618 if self.writing.is_some() {
619 return Err(Self::Error::NotReady);
620 }
621 self.writing = Some(data.into());
622 Ok(())
623 }
624
625 #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
626 fn send_id(&self) -> StreamId {
627 let num: u64 = self.stream.id().into();
628 num.try_into().expect("invalid stream id")
629 }
630}
631
632impl<B> quic::SendStreamUnframed<B> for SendStream<B>
633where
634 B: Buf,
635{
636 #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
637 fn poll_send<D: Buf>(
638 &mut self,
639 cx: &mut task::Context<'_>,
640 buf: &mut D,
641 ) -> Poll<Result<usize, Self::Error>> {
642 if self.writing.is_some() {
643 panic!("poll_send called while send stream is not ready")
645 }
646
647 let s = Pin::new(&mut self.stream);
648
649 let res = ready!(s.poll_write(cx, buf.chunk()));
650 match res {
651 Ok(written) => {
652 buf.advance(written);
653 Poll::Ready(Ok(written))
654 }
655 Err(err) => Poll::Ready(Err(SendStreamError::Write(err))),
656 }
657 }
658}
659
660#[derive(Debug)]
664pub enum SendStreamError {
665 Write(WriteError),
667 NotReady,
670 StreamClosed(ClosedStream),
672}
673
674impl From<SendStreamError> for std::io::Error {
675 fn from(value: SendStreamError) -> Self {
676 match value {
677 SendStreamError::Write(err) => err.into(),
678 SendStreamError::NotReady => {
679 std::io::Error::new(std::io::ErrorKind::Other, "send stream is not ready")
680 }
681 SendStreamError::StreamClosed(err) => err.into(),
682 }
683 }
684}
685
686impl std::error::Error for SendStreamError {}
687
688impl Display for SendStreamError {
689 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
690 write!(f, "{:?}", self)
691 }
692}
693
694impl From<WriteError> for SendStreamError {
695 fn from(e: WriteError) -> Self {
696 Self::Write(e)
697 }
698}
699
700impl From<ClosedStream> for SendStreamError {
701 fn from(value: ClosedStream) -> Self {
702 Self::StreamClosed(value)
703 }
704}
705
706impl Error for SendStreamError {
707 fn is_timeout(&self) -> bool {
708 matches!(
709 self,
710 Self::Write(quinn::WriteError::ConnectionLost(
711 quinn::ConnectionError::TimedOut
712 ))
713 )
714 }
715
716 fn err_code(&self) -> Option<u64> {
717 match self {
718 Self::Write(quinn::WriteError::Stopped(error_code)) => Some(error_code.into_inner()),
719 Self::Write(quinn::WriteError::ConnectionLost(
720 quinn::ConnectionError::ApplicationClosed(ApplicationClose { error_code, .. }),
721 )) => Some(error_code.into_inner()),
722 _ => None,
723 }
724 }
725}
726
727impl From<SendStreamError> for Arc<dyn Error> {
728 fn from(e: SendStreamError) -> Self {
729 Arc::new(e)
730 }
731}