1use std::{
5 convert::TryInto,
6 fmt::{self, Display},
7 pin::Pin,
8 sync::Arc,
9 task::{self, Poll},
10};
11
12use bytes::{Buf, Bytes};
13use futures_util::future::FutureExt as _;
14use futures_util::io::AsyncWrite as _;
15use futures_util::ready;
16use futures_util::stream::StreamExt as _;
17
18pub use quinn::{
19 self, crypto::Session, Endpoint, IncomingBiStreams, IncomingUniStreams, NewConnection, OpenBi,
20 OpenUni, VarInt, WriteError,
21};
22
23use h3::quic::{self, Error, StreamId, WriteBuf};
24
25pub struct Connection {
26 conn: quinn::Connection,
27 incoming_bi: IncomingBiStreams,
28 opening_bi: Option<OpenBi>,
29 incoming_uni: IncomingUniStreams,
30 opening_uni: Option<OpenUni>,
31}
32
33impl Connection {
34 pub fn new(new_conn: NewConnection) -> Self {
35 let NewConnection {
36 uni_streams,
37 bi_streams,
38 connection,
39 ..
40 } = new_conn;
41
42 Self {
43 conn: connection,
44 incoming_bi: bi_streams,
45 opening_bi: None,
46 incoming_uni: uni_streams,
47 opening_uni: None,
48 }
49 }
50}
51
52#[derive(Debug)]
53pub struct ConnectionError(quinn::ConnectionError);
54
55impl std::error::Error for ConnectionError {}
56
57impl fmt::Display for ConnectionError {
58 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
59 self.0.fmt(f)
60 }
61}
62
63impl Error for ConnectionError {
64 fn is_timeout(&self) -> bool {
65 matches!(self.0, quinn::ConnectionError::TimedOut)
66 }
67
68 fn err_code(&self) -> Option<u64> {
69 match self.0 {
70 quinn::ConnectionError::ApplicationClosed(quinn_proto::ApplicationClose {
71 error_code,
72 ..
73 }) => Some(error_code.into_inner()),
74 _ => None,
75 }
76 }
77}
78
79impl From<quinn::ConnectionError> for ConnectionError {
80 fn from(e: quinn::ConnectionError) -> Self {
81 Self(e)
82 }
83}
84
85impl<B> quic::Connection<B> for Connection
86where
87 B: Buf,
88{
89 type SendStream = SendStream<B>;
90 type RecvStream = RecvStream;
91 type BidiStream = BidiStream<B>;
92 type OpenStreams = OpenStreams;
93 type Error = ConnectionError;
94
95 fn poll_accept_bidi(
96 &mut self,
97 cx: &mut task::Context<'_>,
98 ) -> Poll<Result<Option<Self::BidiStream>, Self::Error>> {
99 let (send, recv) = match ready!(self.incoming_bi.next().poll_unpin(cx)) {
100 Some(x) => x?,
101 None => return Poll::Ready(Ok(None)),
102 };
103 Poll::Ready(Ok(Some(Self::BidiStream {
104 send: Self::SendStream::new(send),
105 recv: Self::RecvStream::new(recv),
106 })))
107 }
108
109 fn poll_accept_recv(
110 &mut self,
111 cx: &mut task::Context<'_>,
112 ) -> Poll<Result<Option<Self::RecvStream>, Self::Error>> {
113 let recv = match ready!(self.incoming_uni.poll_next_unpin(cx)) {
114 Some(x) => x?,
115 None => return Poll::Ready(Ok(None)),
116 };
117 Poll::Ready(Ok(Some(Self::RecvStream::new(recv))))
118 }
119
120 fn poll_open_bidi(
121 &mut self,
122 cx: &mut task::Context<'_>,
123 ) -> Poll<Result<Self::BidiStream, Self::Error>> {
124 if self.opening_bi.is_none() {
125 self.opening_bi = Some(self.conn.open_bi());
126 }
127
128 let (send, recv) = ready!(self.opening_bi.as_mut().unwrap().poll_unpin(cx))?;
129 Poll::Ready(Ok(Self::BidiStream {
130 send: Self::SendStream::new(send),
131 recv: Self::RecvStream::new(recv),
132 }))
133 }
134
135 fn poll_open_send(
136 &mut self,
137 cx: &mut task::Context<'_>,
138 ) -> Poll<Result<Self::SendStream, Self::Error>> {
139 if self.opening_uni.is_none() {
140 self.opening_uni = Some(self.conn.open_uni());
141 }
142
143 let send = ready!(self.opening_uni.as_mut().unwrap().poll_unpin(cx))?;
144 Poll::Ready(Ok(Self::SendStream::new(send)))
145 }
146
147 fn opener(&self) -> Self::OpenStreams {
148 OpenStreams {
149 conn: self.conn.clone(),
150 opening_bi: None,
151 opening_uni: None,
152 }
153 }
154
155 fn close(&mut self, code: h3::error::Code, reason: &[u8]) {
156 self.conn.close(
157 VarInt::from_u64(code.value()).expect("error code VarInt"),
158 reason,
159 );
160 }
161}
162
163pub struct OpenStreams {
164 conn: quinn::Connection,
165 opening_bi: Option<OpenBi>,
166 opening_uni: Option<OpenUni>,
167}
168
169impl<B> quic::OpenStreams<B> for OpenStreams
170where
171 B: Buf,
172{
173 type RecvStream = RecvStream;
174 type SendStream = SendStream<B>;
175 type BidiStream = BidiStream<B>;
176 type Error = ConnectionError;
177
178 fn poll_open_bidi(
179 &mut self,
180 cx: &mut task::Context<'_>,
181 ) -> Poll<Result<Self::BidiStream, Self::Error>> {
182 if self.opening_bi.is_none() {
183 self.opening_bi = Some(self.conn.open_bi());
184 }
185
186 let (send, recv) = ready!(self.opening_bi.as_mut().unwrap().poll_unpin(cx))?;
187 Poll::Ready(Ok(Self::BidiStream {
188 send: Self::SendStream::new(send),
189 recv: Self::RecvStream::new(recv),
190 }))
191 }
192
193 fn poll_open_send(
194 &mut self,
195 cx: &mut task::Context<'_>,
196 ) -> Poll<Result<Self::SendStream, Self::Error>> {
197 if self.opening_uni.is_none() {
198 self.opening_uni = Some(self.conn.open_uni());
199 }
200
201 let send = ready!(self.opening_uni.as_mut().unwrap().poll_unpin(cx))?;
202 Poll::Ready(Ok(Self::SendStream::new(send)))
203 }
204
205 fn close(&mut self, code: h3::error::Code, reason: &[u8]) {
206 self.conn.close(
207 VarInt::from_u64(code.value()).expect("error code VarInt"),
208 reason,
209 );
210 }
211}
212
213impl Clone for OpenStreams {
214 fn clone(&self) -> Self {
215 Self {
216 conn: self.conn.clone(),
217 opening_bi: None,
218 opening_uni: None,
219 }
220 }
221}
222
223pub struct BidiStream<B>
224where
225 B: Buf,
226{
227 send: SendStream<B>,
228 recv: RecvStream,
229}
230
231impl<B> quic::BidiStream<B> for BidiStream<B>
232where
233 B: Buf,
234{
235 type SendStream = SendStream<B>;
236 type RecvStream = RecvStream;
237
238 fn split(self) -> (Self::SendStream, Self::RecvStream) {
239 (self.send, self.recv)
240 }
241}
242
243impl<B> quic::RecvStream for BidiStream<B>
244where
245 B: Buf,
246{
247 type Buf = Bytes;
248 type Error = ReadError;
249
250 fn poll_data(
251 &mut self,
252 cx: &mut task::Context<'_>,
253 ) -> Poll<Result<Option<Self::Buf>, Self::Error>> {
254 self.recv.poll_data(cx)
255 }
256
257 fn stop_sending(&mut self, error_code: u64) {
258 self.recv.stop_sending(error_code)
259 }
260}
261
262impl<B> quic::SendStream<B> for BidiStream<B>
263where
264 B: Buf,
265{
266 type Error = SendStreamError;
267
268 fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
269 self.send.poll_ready(cx)
270 }
271
272 fn poll_finish(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
273 self.send.poll_finish(cx)
274 }
275
276 fn reset(&mut self, reset_code: u64) {
277 self.send.reset(reset_code)
278 }
279
280 fn send_data<D: Into<WriteBuf<B>>>(&mut self, data: D) -> Result<(), Self::Error> {
281 self.send.send_data(data)
282 }
283
284 fn id(&self) -> StreamId {
285 self.send.id()
286 }
287}
288
289pub struct RecvStream {
290 stream: quinn::RecvStream,
291}
292
293impl RecvStream {
294 fn new(stream: quinn::RecvStream) -> Self {
295 Self { stream }
296 }
297}
298
299impl quic::RecvStream for RecvStream {
300 type Buf = Bytes;
301 type Error = ReadError;
302
303 fn poll_data(
304 &mut self,
305 cx: &mut task::Context<'_>,
306 ) -> Poll<Result<Option<Self::Buf>, Self::Error>> {
307 Poll::Ready(Ok(ready!(self
308 .stream
309 .read_chunk(usize::MAX, true)
310 .poll_unpin(cx))?
311 .map(|c| (c.bytes))))
312 }
313
314 fn stop_sending(&mut self, error_code: u64) {
315 let _ = self
316 .stream
317 .stop(VarInt::from_u64(error_code).expect("invalid error_code"));
318 }
319}
320
321#[derive(Debug)]
322pub struct ReadError(quinn::ReadError);
323
324impl std::error::Error for ReadError {}
325
326impl fmt::Display for ReadError {
327 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
328 self.0.fmt(f)
329 }
330}
331
332impl From<ReadError> for Arc<dyn Error> {
333 fn from(e: ReadError) -> Self {
334 Arc::new(e)
335 }
336}
337
338impl From<quinn::ReadError> for ReadError {
339 fn from(e: quinn::ReadError) -> Self {
340 Self(e)
341 }
342}
343
344impl Error for ReadError {
345 fn is_timeout(&self) -> bool {
346 matches!(
347 self.0,
348 quinn::ReadError::ConnectionLost(quinn::ConnectionError::TimedOut)
349 )
350 }
351
352 fn err_code(&self) -> Option<u64> {
353 match self.0 {
354 quinn::ReadError::ConnectionLost(quinn::ConnectionError::ApplicationClosed(
355 quinn_proto::ApplicationClose { error_code, .. },
356 )) => Some(error_code.into_inner()),
357 quinn::ReadError::Reset(error_code) => Some(error_code.into_inner()),
358 _ => None,
359 }
360 }
361}
362
363pub struct SendStream<B: Buf> {
364 stream: quinn::SendStream,
365 writing: Option<WriteBuf<B>>,
366}
367
368impl<B> SendStream<B>
369where
370 B: Buf,
371{
372 fn new(stream: quinn::SendStream) -> SendStream<B> {
373 Self {
374 stream,
375 writing: None,
376 }
377 }
378}
379
380impl<B> quic::SendStream<B> for SendStream<B>
381where
382 B: Buf,
383{
384 type Error = SendStreamError;
385
386 fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
387 if let Some(ref mut data) = self.writing {
388 while data.has_remaining() {
389 match ready!(Pin::new(&mut self.stream).poll_write(cx, data.chunk())) {
390 Ok(cnt) => data.advance(cnt),
391 Err(err) => {
392 return Poll::Ready(Err(SendStreamError::Write(
399 err.into_inner()
400 .expect("write stream returned an empty error")
401 .downcast_ref::<WriteError>()
402 .expect(
403 "write stream returned an error which type is not WriteError",
404 )
405 .clone(),
406 )));
407 }
408 }
409 }
410 }
411 self.writing = None;
412 Poll::Ready(Ok(()))
413 }
414
415 fn poll_finish(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
416 self.stream.finish().poll_unpin(cx).map_err(Into::into)
417 }
418
419 fn reset(&mut self, reset_code: u64) {
420 let _ = self
421 .stream
422 .reset(VarInt::from_u64(reset_code).unwrap_or(VarInt::MAX));
423 }
424
425 fn send_data<D: Into<WriteBuf<B>>>(&mut self, data: D) -> Result<(), Self::Error> {
426 if self.writing.is_some() {
427 return Err(Self::Error::NotReady);
428 }
429 self.writing = Some(data.into());
430 Ok(())
431 }
432
433 fn id(&self) -> StreamId {
434 self.stream.id().0.try_into().expect("invalid stream id")
435 }
436}
437
438#[derive(Debug)]
439pub enum SendStreamError {
440 Write(WriteError),
441 NotReady,
442}
443
444impl std::error::Error for SendStreamError {}
445
446impl Display for SendStreamError {
447 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
448 write!(f, "{:?}", self)
449 }
450}
451
452impl From<WriteError> for SendStreamError {
453 fn from(e: WriteError) -> Self {
454 Self::Write(e)
455 }
456}
457
458impl Error for SendStreamError {
459 fn is_timeout(&self) -> bool {
460 match self {
461 Self::Write(quinn::WriteError::ConnectionLost(quinn::ConnectionError::TimedOut)) => {
462 true
463 }
464 _ => false,
465 }
466 }
467
468 fn err_code(&self) -> Option<u64> {
469 match self {
470 Self::Write(quinn::WriteError::Stopped(error_code)) => Some(error_code.into_inner()),
471 Self::Write(quinn::WriteError::ConnectionLost(
472 quinn::ConnectionError::ApplicationClosed(quinn_proto::ApplicationClose {
473 error_code,
474 ..
475 }),
476 )) => Some(error_code.into_inner()),
477 _ => None,
478 }
479 }
480}
481
482impl From<SendStreamError> for Arc<dyn Error> {
483 fn from(e: SendStreamError) -> Self {
484 Arc::new(e)
485 }
486}