1use std::{
2 marker::PhantomData,
3 pin::Pin,
4 task::{Context, Poll},
5};
6
7use bytes::{Buf, BufMut, Bytes};
8use futures_util::{future, ready};
9use pin_project_lite::pin_project;
10use tokio::io::ReadBuf;
11
12use crate::{
13 buf::BufList,
14 error::{Code, ErrorLevel},
15 frame::FrameStream,
16 proto::{
17 coding::{Decode as _, Encode},
18 frame::{Frame, Settings},
19 stream::StreamType,
20 varint::VarInt,
21 },
22 quic::{self, BidiStream, RecvStream, SendStream, SendStreamUnframed},
23 webtransport::SessionId,
24 Error,
25};
26
27#[inline]
28pub(crate) async fn write<S, D, B>(stream: &mut S, data: D) -> Result<(), Error>
30where
31 S: SendStream<B>,
32 D: Into<WriteBuf<B>>,
33 B: Buf,
34{
35 stream.send_data(data)?;
36 future::poll_fn(|cx| stream.poll_ready(cx)).await?;
37
38 Ok(())
39}
40
/// Capacity of the inline header array in [`WriteBuf`]: room for one encoded
/// `StreamType` plus one encoded `Frame` header.
const WRITE_BUF_ENCODE_SIZE: usize = StreamType::MAX_ENCODED_SIZE + Frame::MAX_ENCODED_SIZE;

/// Outgoing wire data: header bytes encoded into a fixed inline array, optionally
/// followed by the payload of a frame, which is read in place rather than copied.
///
/// Implements [`Buf`]: `chunk()` yields the unsent header bytes first, then the
/// frame payload (if any).
///
/// NOTE(review): `UniStreamHeader` / `BidiStreamHeader` (including a `Settings`
/// payload) are also encoded into this fixed-size array. Writing past the end of a
/// `&mut [u8]` panics, so confirm those encodings always fit `WRITE_BUF_ENCODE_SIZE`.
pub struct WriteBuf<B> {
    // Encoded header bytes; only `buf[..len]` holds data.
    buf: [u8; WRITE_BUF_ENCODE_SIZE],
    // Number of bytes encoded into `buf`.
    len: usize,
    // Number of header bytes already consumed through `Buf::advance`.
    pos: usize,
    // Frame whose header was encoded into `buf`; its payload (if any) is served
    // after the header bytes.
    frame: Option<Frame<B>>,
}
59
60impl<B> WriteBuf<B>
61where
62 B: Buf,
63{
64 fn encode_stream_type(&mut self, ty: StreamType) {
65 let mut buf_mut = &mut self.buf[self.len..];
66
67 ty.encode(&mut buf_mut);
68 self.len = WRITE_BUF_ENCODE_SIZE - buf_mut.remaining_mut();
69 }
70
71 fn encode_value(&mut self, value: impl Encode) {
72 let mut buf_mut = &mut self.buf[self.len..];
73 value.encode(&mut buf_mut);
74 self.len = WRITE_BUF_ENCODE_SIZE - buf_mut.remaining_mut();
75 }
76
77 fn encode_frame_header(&mut self) {
78 if let Some(frame) = self.frame.as_ref() {
79 let mut buf_mut = &mut self.buf[self.len..];
80 frame.encode(&mut buf_mut);
81 self.len = WRITE_BUF_ENCODE_SIZE - buf_mut.remaining_mut();
82 }
83 }
84}
85
86impl<B> From<StreamType> for WriteBuf<B>
87where
88 B: Buf,
89{
90 fn from(ty: StreamType) -> Self {
91 let mut me = Self {
92 buf: [0; WRITE_BUF_ENCODE_SIZE],
93 len: 0,
94 pos: 0,
95 frame: None,
96 };
97 me.encode_stream_type(ty);
98 me
99 }
100}
101
102impl<B> From<UniStreamHeader> for WriteBuf<B>
103where
104 B: Buf,
105{
106 fn from(header: UniStreamHeader) -> Self {
107 let mut this = Self {
108 buf: [0; WRITE_BUF_ENCODE_SIZE],
109 len: 0,
110 pos: 0,
111 frame: None,
112 };
113
114 this.encode_value(header);
115 this
116 }
117}
118
119pub enum UniStreamHeader {
120 Control(Settings),
121 WebTransportUni(SessionId),
122}
123
124impl Encode for UniStreamHeader {
125 fn encode<B: BufMut>(&self, buf: &mut B) {
126 match self {
127 Self::Control(settings) => {
128 StreamType::CONTROL.encode(buf);
129 settings.encode(buf);
130 }
131 Self::WebTransportUni(session_id) => {
132 StreamType::WEBTRANSPORT_UNI.encode(buf);
133 session_id.encode(buf);
134 }
135 }
136 }
137}
138
139impl<B> From<BidiStreamHeader> for WriteBuf<B>
140where
141 B: Buf,
142{
143 fn from(header: BidiStreamHeader) -> Self {
144 let mut this = Self {
145 buf: [0; WRITE_BUF_ENCODE_SIZE],
146 len: 0,
147 pos: 0,
148 frame: None,
149 };
150
151 this.encode_value(header);
152 this
153 }
154}
155
156pub enum BidiStreamHeader {
157 Control(Settings),
158 WebTransportBidi(SessionId),
159}
160
161impl Encode for BidiStreamHeader {
162 fn encode<B: BufMut>(&self, buf: &mut B) {
163 match self {
164 Self::Control(settings) => {
165 StreamType::CONTROL.encode(buf);
166 settings.encode(buf);
167 }
168 Self::WebTransportBidi(session_id) => {
169 StreamType::WEBTRANSPORT_BIDI.encode(buf);
170 session_id.encode(buf);
171 }
172 }
173 }
174}
175
176impl<B> From<Frame<B>> for WriteBuf<B>
177where
178 B: Buf,
179{
180 fn from(frame: Frame<B>) -> Self {
181 let mut me = Self {
182 buf: [0; WRITE_BUF_ENCODE_SIZE],
183 len: 0,
184 pos: 0,
185 frame: Some(frame),
186 };
187 me.encode_frame_header();
188 me
189 }
190}
191
192impl<B> From<(StreamType, Frame<B>)> for WriteBuf<B>
193where
194 B: Buf,
195{
196 fn from(ty_stream: (StreamType, Frame<B>)) -> Self {
197 let (ty, frame) = ty_stream;
198 let mut me = Self {
199 buf: [0; WRITE_BUF_ENCODE_SIZE],
200 len: 0,
201 pos: 0,
202 frame: Some(frame),
203 };
204 me.encode_value(ty);
205 me.encode_frame_header();
206 me
207 }
208}
209
210impl<B> Buf for WriteBuf<B>
211where
212 B: Buf,
213{
214 fn remaining(&self) -> usize {
215 self.len - self.pos
216 + self
217 .frame
218 .as_ref()
219 .and_then(|f| f.payload())
220 .map_or(0, |x| x.remaining())
221 }
222
223 fn chunk(&self) -> &[u8] {
224 if self.len - self.pos > 0 {
225 &self.buf[self.pos..self.len]
226 } else if let Some(payload) = self.frame.as_ref().and_then(|f| f.payload()) {
227 payload.chunk()
228 } else {
229 &[]
230 }
231 }
232
233 fn advance(&mut self, mut cnt: usize) {
234 let remaining_header = self.len - self.pos;
235 if remaining_header > 0 {
236 let advanced = usize::min(cnt, remaining_header);
237 self.pos += advanced;
238 cnt -= advanced;
239 }
240
241 if let Some(payload) = self.frame.as_mut().and_then(|f| f.payload_mut()) {
242 payload.advance(cnt);
243 }
244 }
245}
246
/// Result of `AcceptRecvStream::into_stream`: an incoming receive stream classified
/// by the stream type read from its first bytes.
pub(super) enum AcceptedRecvStream<S, B>
where
    S: quic::RecvStream,
    B: Buf,
{
    /// `StreamType::CONTROL`, wrapped for frame decoding.
    Control(FrameStream<S, B>),
    /// `StreamType::PUSH`, with the push id that followed the stream type.
    Push(u64, FrameStream<S, B>),
    /// `StreamType::ENCODER` (QPACK encoder stream), as raw buffered bytes.
    Encoder(BufRecvStream<S, B>),
    /// `StreamType::DECODER` (QPACK decoder stream), as raw buffered bytes.
    Decoder(BufRecvStream<S, B>),
    /// `StreamType::WEBTRANSPORT_UNI`, with the session id that followed the stream type.
    WebTransportUni(SessionId, BufRecvStream<S, B>),
    /// A reserved stream type; the stream itself is not kept.
    Reserved,
}
259
/// Incoming receive stream whose type (and, for some types, a following id) is
/// parsed from its first bytes by `poll_type` before `into_stream` classifies it.
pub(super) struct AcceptRecvStream<S, B> {
    // Underlying stream plus the bytes read from it but not yet consumed.
    stream: BufRecvStream<S, B>,
    // Stream type, once decoded.
    ty: Option<StreamType>,
    // Push id (PUSH) or session id (WEBTRANSPORT_UNI) that follows the type, once decoded.
    id: Option<VarInt>,
    // Encoded length of the varint currently being parsed, derived from its first
    // byte; `None` until that byte has been buffered.
    expected: Option<usize>,
}
268
269impl<S, B> AcceptRecvStream<S, B>
270where
271 S: RecvStream,
272 B: Buf,
273{
274 pub fn new(stream: S) -> Self {
275 Self {
276 stream: BufRecvStream::new(stream),
277 ty: None,
278 id: None,
279 expected: None,
280 }
281 }
282
283 pub fn into_stream(self) -> Result<AcceptedRecvStream<S, B>, Error> {
284 Ok(match self.ty.expect("Stream type not resolved yet") {
285 StreamType::CONTROL => AcceptedRecvStream::Control(FrameStream::new(self.stream)),
286 StreamType::PUSH => AcceptedRecvStream::Push(
287 self.id.expect("Push ID not resolved yet").into_inner(),
288 FrameStream::new(self.stream),
289 ),
290 StreamType::ENCODER => AcceptedRecvStream::Encoder(self.stream),
291 StreamType::DECODER => AcceptedRecvStream::Decoder(self.stream),
292 StreamType::WEBTRANSPORT_UNI => AcceptedRecvStream::WebTransportUni(
293 SessionId::from_varint(self.id.expect("Session ID not resolved yet")),
294 self.stream,
295 ),
296 t if t.value() > 0x21 && (t.value() - 0x21) % 0x1f == 0 => AcceptedRecvStream::Reserved,
297
298 t => {
313 return Err(Code::H3_STREAM_CREATION_ERROR.with_reason(
314 format!("unknown stream type 0x{:x}", t.value()),
315 crate::error::ErrorLevel::ConnectionError,
316 ))
317 }
318 })
319 }
320
321 pub fn poll_type(&mut self, cx: &mut Context) -> Poll<Result<(), Error>> {
322 loop {
323 match self.ty {
325 Some(StreamType::PUSH | StreamType::WEBTRANSPORT_UNI) => {
326 if self.id.is_some() {
327 return Poll::Ready(Ok(()));
328 }
329 }
330 Some(_) => return Poll::Ready(Ok(())),
331 None => (),
332 };
333
334 if ready!(self.stream.poll_read(cx))? {
335 return Poll::Ready(Err(Code::H3_STREAM_CREATION_ERROR.with_reason(
336 "Stream closed before type received",
337 ErrorLevel::ConnectionError,
338 )));
339 };
340
341 let mut buf = self.stream.buf_mut();
342 if self.expected.is_none() && buf.remaining() >= 1 {
343 self.expected = Some(VarInt::encoded_size(buf.chunk()[0]));
344 }
345
346 if let Some(expected) = self.expected {
347 if buf.remaining() < expected {
349 continue;
350 }
351 } else {
352 continue;
353 }
354
355 if self.ty.is_none() {
357 self.ty = Some(StreamType::decode(&mut buf).map_err(|_| {
359 Code::H3_INTERNAL_ERROR.with_reason(
360 "Unexpected end parsing stream type",
361 ErrorLevel::ConnectionError,
362 )
363 })?);
364 self.expected = None;
366 } else {
367 self.id = Some(VarInt::decode(&mut buf).map_err(|_| {
369 Code::H3_INTERNAL_ERROR.with_reason(
370 "Unexpected end parsing push or session id",
371 ErrorLevel::ConnectionError,
372 )
373 })?);
374 }
375 }
376 }
377}
378
pin_project! {
    /// Receive stream wrapper that keeps data read from the underlying stream in a
    /// `BufList`, exposing that buffer directly (`buf`, `buf_mut`, `take_chunk`).
    pub struct BufRecvStream<S, B> {
        // Data read from `stream` but not yet consumed.
        buf: BufList<Bytes>,
        // Set once the underlying stream has reported end of stream; data may still
        // be left in `buf`.
        eos: bool,
        // The wrapped stream.
        stream: S,
        // Carries the send-buffer type parameter `B` without storing a value of it.
        _marker: PhantomData<B>,
    }
}
399
400impl<S, B> std::fmt::Debug for BufRecvStream<S, B> {
401 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
402 f.debug_struct("BufRecvStream")
403 .field("buf", &self.buf)
404 .field("eos", &self.eos)
405 .field("stream", &"...")
406 .finish()
407 }
408}
409
410impl<S, B> BufRecvStream<S, B> {
411 pub fn new(stream: S) -> Self {
412 Self {
413 buf: BufList::new(),
414 eos: false,
415 stream,
416 _marker: PhantomData,
417 }
418 }
419}
420
421impl<B, S: RecvStream> BufRecvStream<S, B> {
422 pub fn poll_read(&mut self, cx: &mut Context<'_>) -> Poll<Result<bool, S::Error>> {
426 let data = ready!(self.stream.poll_data(cx))?;
427
428 if let Some(mut data) = data {
429 self.buf.push_bytes(&mut data);
430 Poll::Ready(Ok(false))
431 } else {
432 self.eos = true;
433 Poll::Ready(Ok(true))
434 }
435 }
436
437 #[inline]
439 pub(crate) fn buf_mut(&mut self) -> &mut BufList<Bytes> {
440 &mut self.buf
441 }
442
443 pub fn take_chunk(&mut self, limit: usize) -> Option<Bytes> {
447 self.buf.take_chunk(limit)
448 }
449
450 pub fn has_remaining(&mut self) -> bool {
452 self.buf.has_remaining()
453 }
454
455 #[inline]
456 pub(crate) fn buf(&self) -> &BufList<Bytes> {
457 &self.buf
458 }
459
460 pub fn is_eos(&self) -> bool {
461 self.eos
462 }
463}
464
465impl<S: RecvStream, B> RecvStream for BufRecvStream<S, B> {
466 type Buf = Bytes;
467
468 type Error = S::Error;
469
470 fn poll_data(
471 &mut self,
472 cx: &mut std::task::Context<'_>,
473 ) -> Poll<Result<Option<Self::Buf>, Self::Error>> {
474 if let Some(chunk) = self.buf.take_first_chunk() {
476 return Poll::Ready(Ok(Some(chunk)));
477 }
478
479 if let Some(mut data) = ready!(self.stream.poll_data(cx))? {
480 Poll::Ready(Ok(Some(data.copy_to_bytes(data.remaining()))))
481 } else {
482 self.eos = true;
483 Poll::Ready(Ok(None))
484 }
485 }
486
487 fn stop_sending(&mut self, error_code: u64) {
488 self.stream.stop_sending(error_code)
489 }
490
491 fn recv_id(&self) -> quic::StreamId {
492 self.stream.recv_id()
493 }
494}
495
496impl<S, B> SendStream<B> for BufRecvStream<S, B>
497where
498 B: Buf,
499 S: SendStream<B>,
500{
501 type Error = S::Error;
502
503 fn poll_finish(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
504 self.stream.poll_finish(cx)
505 }
506
507 fn reset(&mut self, reset_code: u64) {
508 self.stream.reset(reset_code)
509 }
510
511 fn send_id(&self) -> quic::StreamId {
512 self.stream.send_id()
513 }
514
515 fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
516 self.stream.poll_ready(cx)
517 }
518
519 fn send_data<T: Into<WriteBuf<B>>>(&mut self, data: T) -> Result<(), Self::Error> {
520 self.stream.send_data(data)
521 }
522}
523
524impl<S, B> SendStreamUnframed<B> for BufRecvStream<S, B>
525where
526 B: Buf,
527 S: SendStreamUnframed<B>,
528{
529 #[inline]
530 fn poll_send<D: Buf>(
531 &mut self,
532 cx: &mut std::task::Context<'_>,
533 buf: &mut D,
534 ) -> Poll<Result<usize, Self::Error>> {
535 self.stream.poll_send(cx, buf)
536 }
537}
538
539impl<S, B> BidiStream<B> for BufRecvStream<S, B>
540where
541 B: Buf,
542 S: BidiStream<B>,
543{
544 type SendStream = BufRecvStream<S::SendStream, B>;
545
546 type RecvStream = BufRecvStream<S::RecvStream, B>;
547
548 fn split(self) -> (Self::SendStream, Self::RecvStream) {
549 let (send, recv) = self.stream.split();
550 (
551 BufRecvStream {
552 buf: BufList::new(),
554 eos: self.eos,
555 stream: send,
556 _marker: PhantomData,
557 },
558 BufRecvStream {
559 buf: self.buf,
560 eos: self.eos,
561 stream: recv,
562 _marker: PhantomData,
563 },
564 )
565 }
566}
567
568impl<S, B> futures_util::io::AsyncRead for BufRecvStream<S, B>
569where
570 B: Buf,
571 S: RecvStream,
572 S::Error: Into<std::io::Error>,
573{
574 fn poll_read(
575 mut self: Pin<&mut Self>,
576 cx: &mut Context<'_>,
577 buf: &mut [u8],
578 ) -> Poll<futures_util::io::Result<usize>> {
579 let p = &mut *self;
580 if !p.has_remaining() {
585 let eos = ready!(p.poll_read(cx).map_err(Into::into))?;
586 if eos {
587 return Poll::Ready(Ok(0));
588 }
589 }
590
591 let chunk = p.buf_mut().take_chunk(buf.len());
592 if let Some(chunk) = chunk {
593 assert!(chunk.len() <= buf.len());
594 let len = chunk.len().min(buf.len());
595 buf[..len].copy_from_slice(&chunk);
597 Poll::Ready(Ok(len))
598 } else {
599 Poll::Ready(Ok(0))
600 }
601 }
602}
603
604impl<S, B> tokio::io::AsyncRead for BufRecvStream<S, B>
605where
606 B: Buf,
607 S: RecvStream,
608 S::Error: Into<std::io::Error>,
609{
610 fn poll_read(
611 mut self: Pin<&mut Self>,
612 cx: &mut Context<'_>,
613 buf: &mut ReadBuf<'_>,
614 ) -> Poll<futures_util::io::Result<()>> {
615 let p = &mut *self;
616 if !p.has_remaining() {
621 let eos = ready!(p.poll_read(cx).map_err(Into::into))?;
622 if eos {
623 return Poll::Ready(Ok(()));
624 }
625 }
626
627 let chunk = p.buf_mut().take_chunk(buf.remaining());
628 if let Some(chunk) = chunk {
629 assert!(chunk.len() <= buf.remaining());
630 buf.put_slice(&chunk);
632 Poll::Ready(Ok(()))
633 } else {
634 Poll::Ready(Ok(()))
635 }
636 }
637}
638
639impl<S, B> futures_util::io::AsyncWrite for BufRecvStream<S, B>
640where
641 B: Buf,
642 S: SendStreamUnframed<B>,
643 S::Error: Into<std::io::Error>,
644{
645 fn poll_write(
646 mut self: Pin<&mut Self>,
647 cx: &mut Context<'_>,
648 mut buf: &[u8],
649 ) -> Poll<std::io::Result<usize>> {
650 let p = &mut *self;
651 p.poll_send(cx, &mut buf).map_err(Into::into)
652 }
653
654 fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<std::io::Result<()>> {
655 Poll::Ready(Ok(()))
656 }
657
658 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
659 let p = &mut *self;
660 p.poll_finish(cx).map_err(Into::into)
661 }
662}
663
664impl<S, B> tokio::io::AsyncWrite for BufRecvStream<S, B>
665where
666 B: Buf,
667 S: SendStreamUnframed<B>,
668 S::Error: Into<std::io::Error>,
669{
670 fn poll_write(
671 mut self: Pin<&mut Self>,
672 cx: &mut Context<'_>,
673 mut buf: &[u8],
674 ) -> Poll<std::io::Result<usize>> {
675 let p = &mut *self;
676 p.poll_send(cx, &mut buf).map_err(Into::into)
677 }
678
679 fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<std::io::Result<()>> {
680 Poll::Ready(Ok(()))
681 }
682
683 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
684 let p = &mut *self;
685 p.poll_finish(cx).map_err(Into::into)
686 }
687}
688
#[cfg(test)]
mod tests {
    use quinn_proto::coding::BufExt;

    use super::*;

    #[test]
    fn write_wt_uni_header() {
        let mut w = WriteBuf::<Bytes>::from(UniStreamHeader::WebTransportUni(
            SessionId::from_varint(VarInt(5)),
        ));

        let ty = w.get_var().unwrap();
        assert_eq!(ty, 0x54);

        // The session id must directly follow the stream type, with nothing after it.
        let id = w.get_var().unwrap();
        assert_eq!(id, 5);
        assert_eq!(w.remaining(), 0);
    }

    #[test]
    fn write_buf_encode_streamtype() {
        let wbuf = WriteBuf::<Bytes>::from(StreamType::ENCODER);

        assert_eq!(wbuf.chunk(), b"\x02");
        assert_eq!(wbuf.len, 1);
    }

    #[test]
    fn write_buf_encode_frame() {
        let wbuf = WriteBuf::<Bytes>::from(Frame::Goaway(VarInt(2)));

        assert_eq!(wbuf.chunk(), b"\x07\x01\x02");
        assert_eq!(wbuf.len, 3);
    }

    #[test]
    fn write_buf_encode_streamtype_then_frame() {
        let wbuf = WriteBuf::<Bytes>::from((StreamType::ENCODER, Frame::Goaway(VarInt(2))));

        assert_eq!(wbuf.chunk(), b"\x02\x07\x01\x02");
    }

    #[test]
    fn write_buf_advances() {
        let mut wbuf =
            WriteBuf::<Bytes>::from((StreamType::ENCODER, Frame::Data(Bytes::from("hey"))));

        assert_eq!(wbuf.chunk(), b"\x02\x00\x03");
        wbuf.advance(3);
        assert_eq!(wbuf.remaining(), 3);
        assert_eq!(wbuf.chunk(), b"hey");
        wbuf.advance(2);
        assert_eq!(wbuf.chunk(), b"y");
        wbuf.advance(1);
        assert_eq!(wbuf.remaining(), 0);
    }

    #[test]
    fn write_buf_advance_jumps_header_and_payload_start() {
        let mut wbuf =
            WriteBuf::<Bytes>::from((StreamType::ENCODER, Frame::Data(Bytes::from("hey"))));

        wbuf.advance(4);
        assert_eq!(wbuf.chunk(), b"ey");
    }
}