1#![deny(missing_docs)]
5
6use std::{
7 convert::TryInto,
8 future::Future,
9 pin::Pin,
10 sync::Arc,
11 task::{self, Poll},
12};
13
14use bytes::{Buf, Bytes};
15
16use futures::{
17 ready,
18 stream::{self},
19 Stream, StreamExt,
20};
21
22use quinn::ReadError;
23pub use quinn::{self, AcceptBi, AcceptUni, Endpoint, OpenBi, OpenUni, VarInt};
24
25use h3::{
26 error::Code,
27 quic::{self, ConnectionErrorIncoming, StreamErrorIncoming, StreamId, WriteBuf},
28};
29use tokio_util::sync::ReusableBoxFuture;
30
31#[cfg(feature = "tracing")]
32use tracing::instrument;
33
34#[cfg(feature = "datagram")]
35pub mod datagram;
36
37type BoxStreamSync<'a, T> = Pin<Box<dyn Stream<Item = T> + Sync + Send + 'a>>;
39
40pub struct Connection {
44 conn: quinn::Connection,
45 incoming_bi: BoxStreamSync<'static, <AcceptBi<'static> as Future>::Output>,
46 opening_bi: Option<BoxStreamSync<'static, <OpenBi<'static> as Future>::Output>>,
47 incoming_uni: BoxStreamSync<'static, <AcceptUni<'static> as Future>::Output>,
48 opening_uni: Option<BoxStreamSync<'static, <OpenUni<'static> as Future>::Output>>,
49}
50
51impl Connection {
52 pub fn new(conn: quinn::Connection) -> Self {
54 Self {
55 conn: conn.clone(),
56 incoming_bi: Box::pin(stream::unfold(conn.clone(), |conn| async {
57 Some((conn.accept_bi().await, conn))
58 })),
59 opening_bi: None,
60 incoming_uni: Box::pin(stream::unfold(conn.clone(), |conn| async {
61 Some((conn.accept_uni().await, conn))
62 })),
63 opening_uni: None,
64 }
65 }
66}
67
68impl<B> quic::Connection<B> for Connection
69where
70 B: Buf,
71{
72 type RecvStream = RecvStream;
73 type OpenStreams = OpenStreams;
74
75 #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
76 fn poll_accept_bidi(
77 &mut self,
78 cx: &mut task::Context<'_>,
79 ) -> Poll<Result<Self::BidiStream, ConnectionErrorIncoming>> {
80 let (send, recv) = ready!(self.incoming_bi.poll_next_unpin(cx))
81 .expect("self.incoming_bi BoxStream never returns None")
82 .map_err(|e| convert_connection_error(e))?;
83 Poll::Ready(Ok(Self::BidiStream {
84 send: Self::SendStream::new(send),
85 recv: Self::RecvStream::new(recv),
86 }))
87 }
88
89 #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
90 fn poll_accept_recv(
91 &mut self,
92 cx: &mut task::Context<'_>,
93 ) -> Poll<Result<Self::RecvStream, ConnectionErrorIncoming>> {
94 let recv = ready!(self.incoming_uni.poll_next_unpin(cx))
95 .expect("self.incoming_uni BoxStream never returns None")
96 .map_err(|e| convert_connection_error(e))?;
97 Poll::Ready(Ok(Self::RecvStream::new(recv)))
98 }
99
100 fn opener(&self) -> Self::OpenStreams {
101 OpenStreams {
102 conn: self.conn.clone(),
103 opening_bi: None,
104 opening_uni: None,
105 }
106 }
107}
108
109fn convert_connection_error(e: quinn::ConnectionError) -> h3::quic::ConnectionErrorIncoming {
110 match e {
111 quinn::ConnectionError::ApplicationClosed(application_close) => {
112 ConnectionErrorIncoming::ApplicationClose {
113 error_code: application_close.error_code.into(),
114 }
115 }
116 quinn::ConnectionError::TimedOut => ConnectionErrorIncoming::Timeout,
117
118 error @ quinn::ConnectionError::VersionMismatch
119 | error @ quinn::ConnectionError::Reset
120 | error @ quinn::ConnectionError::LocallyClosed
121 | error @ quinn::ConnectionError::CidsExhausted
122 | error @ quinn::ConnectionError::TransportError(_)
123 | error @ quinn::ConnectionError::ConnectionClosed(_) => {
124 ConnectionErrorIncoming::Undefined(Arc::new(error))
125 }
126 }
127}
128
129impl<B> quic::OpenStreams<B> for Connection
130where
131 B: Buf,
132{
133 type SendStream = SendStream<B>;
134 type BidiStream = BidiStream<B>;
135
136 #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
137 fn poll_open_bidi(
138 &mut self,
139 cx: &mut task::Context<'_>,
140 ) -> Poll<Result<Self::BidiStream, StreamErrorIncoming>> {
141 let bi = self.opening_bi.get_or_insert_with(|| {
142 Box::pin(stream::unfold(self.conn.clone(), |conn| async {
143 Some((conn.open_bi().await, conn))
144 }))
145 });
146 let (send, recv) = ready!(bi.poll_next_unpin(cx))
147 .expect("BoxStream does not return None")
148 .map_err(|e| StreamErrorIncoming::ConnectionErrorIncoming {
149 connection_error: convert_connection_error(e),
150 })?;
151 Poll::Ready(Ok(Self::BidiStream {
152 send: Self::SendStream::new(send),
153 recv: RecvStream::new(recv),
154 }))
155 }
156
157 #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
158 fn poll_open_send(
159 &mut self,
160 cx: &mut task::Context<'_>,
161 ) -> Poll<Result<Self::SendStream, StreamErrorIncoming>> {
162 let uni = self.opening_uni.get_or_insert_with(|| {
163 Box::pin(stream::unfold(self.conn.clone(), |conn| async {
164 Some((conn.open_uni().await, conn))
165 }))
166 });
167
168 let send = ready!(uni.poll_next_unpin(cx))
169 .expect("BoxStream does not return None")
170 .map_err(|e| StreamErrorIncoming::ConnectionErrorIncoming {
171 connection_error: convert_connection_error(e),
172 })?;
173 Poll::Ready(Ok(Self::SendStream::new(send)))
174 }
175
176 #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
177 fn close(&mut self, code: Code, reason: &[u8]) {
178 self.conn.close(
179 VarInt::from_u64(code.value()).expect("error code VarInt"),
180 reason,
181 );
182 }
183}
184
185pub struct OpenStreams {
190 conn: quinn::Connection,
191 opening_bi: Option<BoxStreamSync<'static, <OpenBi<'static> as Future>::Output>>,
192 opening_uni: Option<BoxStreamSync<'static, <OpenUni<'static> as Future>::Output>>,
193}
194
195impl<B> quic::OpenStreams<B> for OpenStreams
196where
197 B: Buf,
198{
199 type SendStream = SendStream<B>;
200 type BidiStream = BidiStream<B>;
201
202 #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
203 fn poll_open_bidi(
204 &mut self,
205 cx: &mut task::Context<'_>,
206 ) -> Poll<Result<Self::BidiStream, StreamErrorIncoming>> {
207 let bi = self.opening_bi.get_or_insert_with(|| {
208 Box::pin(stream::unfold(self.conn.clone(), |conn| async {
209 Some((conn.open_bi().await, conn))
210 }))
211 });
212
213 let (send, recv) = ready!(bi.poll_next_unpin(cx))
214 .expect("BoxStream does not return None")
215 .map_err(|e| StreamErrorIncoming::ConnectionErrorIncoming {
216 connection_error: convert_connection_error(e),
217 })?;
218 Poll::Ready(Ok(Self::BidiStream {
219 send: Self::SendStream::new(send),
220 recv: RecvStream::new(recv),
221 }))
222 }
223
224 #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
225 fn poll_open_send(
226 &mut self,
227 cx: &mut task::Context<'_>,
228 ) -> Poll<Result<Self::SendStream, StreamErrorIncoming>> {
229 let uni = self.opening_uni.get_or_insert_with(|| {
230 Box::pin(stream::unfold(self.conn.clone(), |conn| async {
231 Some((conn.open_uni().await, conn))
232 }))
233 });
234
235 let send = ready!(uni.poll_next_unpin(cx))
236 .expect("BoxStream does not return None")
237 .map_err(|e| StreamErrorIncoming::ConnectionErrorIncoming {
238 connection_error: convert_connection_error(e),
239 })?;
240 Poll::Ready(Ok(Self::SendStream::new(send)))
241 }
242
243 #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
244 fn close(&mut self, code: Code, reason: &[u8]) {
245 self.conn.close(
246 VarInt::from_u64(code.value()).expect("error code VarInt"),
247 reason,
248 );
249 }
250}
251
252impl Clone for OpenStreams {
253 fn clone(&self) -> Self {
254 Self {
255 conn: self.conn.clone(),
256 opening_bi: None,
257 opening_uni: None,
258 }
259 }
260}
261
262pub struct BidiStream<B>
267where
268 B: Buf,
269{
270 send: SendStream<B>,
271 recv: RecvStream,
272}
273
274impl<B> quic::BidiStream<B> for BidiStream<B>
275where
276 B: Buf,
277{
278 type SendStream = SendStream<B>;
279 type RecvStream = RecvStream;
280
281 fn split(self) -> (Self::SendStream, Self::RecvStream) {
282 (self.send, self.recv)
283 }
284}
285
286impl<B: Buf> quic::RecvStream for BidiStream<B> {
287 type Buf = Bytes;
288
289 fn poll_data(
290 &mut self,
291 cx: &mut task::Context<'_>,
292 ) -> Poll<Result<Option<Self::Buf>, StreamErrorIncoming>> {
293 self.recv.poll_data(cx)
294 }
295
296 fn stop_sending(&mut self, error_code: u64) {
297 self.recv.stop_sending(error_code)
298 }
299
300 fn recv_id(&self) -> StreamId {
301 self.recv.recv_id()
302 }
303}
304
305impl<B> quic::SendStream<B> for BidiStream<B>
306where
307 B: Buf,
308{
309 fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), StreamErrorIncoming>> {
310 self.send.poll_ready(cx)
311 }
312
313 fn poll_finish(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), StreamErrorIncoming>> {
314 self.send.poll_finish(cx)
315 }
316
317 fn reset(&mut self, reset_code: u64) {
318 self.send.reset(reset_code)
319 }
320
321 fn send_data<D: Into<WriteBuf<B>>>(&mut self, data: D) -> Result<(), StreamErrorIncoming> {
322 self.send.send_data(data)
323 }
324
325 fn send_id(&self) -> StreamId {
326 self.send.send_id()
327 }
328}
329impl<B> quic::SendStreamUnframed<B> for BidiStream<B>
330where
331 B: Buf,
332{
333 fn poll_send<D: Buf>(
334 &mut self,
335 cx: &mut task::Context<'_>,
336 buf: &mut D,
337 ) -> Poll<Result<usize, StreamErrorIncoming>> {
338 self.send.poll_send(cx, buf)
339 }
340}
341
342pub struct RecvStream {
346 stream: Option<quinn::RecvStream>,
347 read_chunk_fut: ReadChunkFuture,
348}
349
350type ReadChunkFuture = ReusableBoxFuture<
351 'static,
352 (
353 quinn::RecvStream,
354 Result<Option<quinn::Chunk>, quinn::ReadError>,
355 ),
356>;
357
358impl RecvStream {
359 fn new(stream: quinn::RecvStream) -> Self {
360 Self {
361 stream: Some(stream),
362 read_chunk_fut: ReusableBoxFuture::new(async { unreachable!() }),
364 }
365 }
366}
367
368impl quic::RecvStream for RecvStream {
369 type Buf = Bytes;
370
371 #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
372 fn poll_data(
373 &mut self,
374 cx: &mut task::Context<'_>,
375 ) -> Poll<Result<Option<Self::Buf>, StreamErrorIncoming>> {
376 if let Some(mut stream) = self.stream.take() {
377 self.read_chunk_fut.set(async move {
378 let chunk = stream.read_chunk(usize::MAX, true).await;
379 (stream, chunk)
380 })
381 };
382
383 let (stream, chunk) = ready!(self.read_chunk_fut.poll(cx));
384 self.stream = Some(stream);
385 Poll::Ready(Ok(chunk
386 .map_err(|e| convert_read_error_to_stream_error(e))?
387 .map(|c| c.bytes)))
388 }
389
390 #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
391 fn stop_sending(&mut self, error_code: u64) {
392 self.stream
393 .as_mut()
394 .unwrap()
395 .stop(VarInt::from_u64(error_code).expect("invalid error_code"))
396 .ok();
397 }
398
399 #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
400 fn recv_id(&self) -> StreamId {
401 let num: u64 = self.stream.as_ref().unwrap().id().into();
402
403 num.try_into().expect("invalid stream id")
404 }
405}
406
407fn convert_read_error_to_stream_error(error: ReadError) -> StreamErrorIncoming {
408 match error {
409 ReadError::Reset(var_int) => StreamErrorIncoming::StreamTerminated {
410 error_code: var_int.into_inner(),
411 },
412 ReadError::ConnectionLost(connection_error) => {
413 StreamErrorIncoming::ConnectionErrorIncoming {
414 connection_error: convert_connection_error(connection_error),
415 }
416 }
417 error @ ReadError::ClosedStream => StreamErrorIncoming::Unknown(Box::new(error)),
418 ReadError::IllegalOrderedRead => panic!("h3-quinn only performs ordered reads"),
419 error @ ReadError::ZeroRttRejected => StreamErrorIncoming::Unknown(Box::new(error)),
420 }
421}
422
423fn convert_write_error_to_stream_error(error: quinn::WriteError) -> StreamErrorIncoming {
424 match error {
425 quinn::WriteError::Stopped(var_int) => StreamErrorIncoming::StreamTerminated {
426 error_code: var_int.into_inner(),
427 },
428 quinn::WriteError::ConnectionLost(connection_error) => {
429 StreamErrorIncoming::ConnectionErrorIncoming {
430 connection_error: convert_connection_error(connection_error),
431 }
432 }
433 error @ quinn::WriteError::ClosedStream | error @ quinn::WriteError::ZeroRttRejected => {
434 StreamErrorIncoming::Unknown(Box::new(error))
435 }
436 }
437}
438
439pub struct SendStream<B: Buf> {
443 stream: quinn::SendStream,
444 writing: Option<WriteBuf<B>>,
445}
446
447impl<B> SendStream<B>
448where
449 B: Buf,
450{
451 fn new(stream: quinn::SendStream) -> SendStream<B> {
452 Self {
453 stream: stream,
454 writing: None,
455 }
456 }
457}
458
459impl<B> quic::SendStream<B> for SendStream<B>
460where
461 B: Buf,
462{
463 #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
464 fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), StreamErrorIncoming>> {
465 if let Some(ref mut data) = self.writing {
466 while data.has_remaining() {
467 let stream = Pin::new(&mut self.stream);
468 let written = ready!(stream.poll_write(cx, data.chunk()))
469 .map_err(|err| convert_write_error_to_stream_error(err))?;
470 data.advance(written);
471 }
472 }
473 self.writing = None;
475 Poll::Ready(Ok(()))
476 }
477
478 #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
479 fn poll_finish(
480 &mut self,
481 _cx: &mut task::Context<'_>,
482 ) -> Poll<Result<(), StreamErrorIncoming>> {
483 Poll::Ready(
484 self.stream
485 .finish()
486 .map_err(|e| StreamErrorIncoming::Unknown(Box::new(e))),
487 )
488 }
489
490 #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
491 fn reset(&mut self, reset_code: u64) {
492 let _ = self
493 .stream
494 .reset(VarInt::from_u64(reset_code).unwrap_or(VarInt::MAX));
495 }
496
497 #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
498 fn send_data<D: Into<WriteBuf<B>>>(&mut self, data: D) -> Result<(), StreamErrorIncoming> {
499 if self.writing.is_some() {
500 #[cfg(feature = "tracing")]
504 tracing::error!("send_data called while send stream is not ready");
505 return Err(StreamErrorIncoming::ConnectionErrorIncoming {
506 connection_error: ConnectionErrorIncoming::InternalError(
507 "internal error in the http stack".to_string(),
508 ),
509 });
510 }
511 self.writing = Some(data.into());
512 Ok(())
513 }
514
515 #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
516 fn send_id(&self) -> StreamId {
517 let num: u64 = self.stream.id().into();
518 num.try_into().expect("invalid stream id")
519 }
520}
521
522impl<B> quic::SendStreamUnframed<B> for SendStream<B>
523where
524 B: Buf,
525{
526 #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
527 fn poll_send<D: Buf>(
528 &mut self,
529 cx: &mut task::Context<'_>,
530 buf: &mut D,
531 ) -> Poll<Result<usize, StreamErrorIncoming>> {
532 if self.writing.is_some() {
533 panic!("poll_send called while send stream is not ready")
535 }
536
537 let s = Pin::new(&mut self.stream);
538
539 let res = ready!(s.poll_write(cx, buf.chunk()));
540 match res {
541 Ok(written) => {
542 buf.advance(written);
543 Poll::Ready(Ok(written))
544 }
545 Err(err) => Poll::Ready(Err(convert_write_error_to_stream_error(err))),
546 }
547 }
548}