1use std::{
2 marker::PhantomData,
3 pin::Pin,
4 task::{Context, Poll},
5};
6
7use bytes::{Buf, BufMut, Bytes};
8use futures_util::{future, ready};
9use pin_project_lite::pin_project;
10use tokio::io::ReadBuf;
11
12use crate::{
13 buf::BufList,
14 error::{internal_error::InternalConnectionError, Code},
15 frame::FrameStream,
16 proto::{
17 coding::{Decode as _, Encode},
18 frame::{Frame, Settings},
19 stream::StreamType,
20 varint::VarInt,
21 },
22 quic::{
23 self, BidiStream, ConnectionErrorIncoming, RecvStream, SendStream, SendStreamUnframed,
24 StreamErrorIncoming,
25 },
26 webtransport::SessionId,
27};
28
29#[inline]
30pub(crate) async fn write<S, D, B>(stream: &mut S, data: D) -> Result<(), StreamErrorIncoming>
32where
33 S: SendStream<B>,
34 D: Into<WriteBuf<B>>,
35 B: Buf,
36{
37 stream.send_data(data)?;
38 future::poll_fn(|cx| stream.poll_ready(cx)).await?;
39
40 Ok(())
41}
42
/// Capacity of [`WriteBuf`]'s stack buffer, sized as one `StreamType::MAX_ENCODED_SIZE` plus one
/// `Frame::MAX_ENCODED_SIZE` (a stream type followed by a frame as written by `Frame::encode`).
const WRITE_BUF_ENCODE_SIZE: usize = StreamType::MAX_ENCODED_SIZE + Frame::MAX_ENCODED_SIZE;

/// Outgoing wire data: a prefix encoded on the stack (stream type, stream header and/or frame
/// header) followed by an optional frame whose payload is yielded after the prefix.
///
/// Through its `Buf` implementation, `chunk()` returns the unread part of the encoded prefix
/// first and only then the frame payload (for frames that carry one, such as DATA, whose
/// `Frame::encode` output is just the header).
pub struct WriteBuf<B> {
    // Encoded prefix bytes; only `buf[..len]` holds meaningful data.
    buf: [u8; WRITE_BUF_ENCODE_SIZE],
    // Number of bytes written into `buf`.
    len: usize,
    // Number of prefix bytes already consumed through `Buf::advance`.
    pos: usize,
    // Frame whose encoding was written into `buf`; its payload (if any) follows the prefix.
    frame: Option<Frame<B>>,
}
61
62impl<B> WriteBuf<B>
63where
64 B: Buf,
65{
66 fn encode_stream_type(&mut self, ty: StreamType) {
67 let mut buf_mut = &mut self.buf[self.len..];
68
69 ty.encode(&mut buf_mut);
70 self.len = WRITE_BUF_ENCODE_SIZE - buf_mut.remaining_mut();
71 }
72
73 fn encode_value(&mut self, value: impl Encode) {
74 let mut buf_mut = &mut self.buf[self.len..];
75 value.encode(&mut buf_mut);
76 self.len = WRITE_BUF_ENCODE_SIZE - buf_mut.remaining_mut();
77 }
78
79 fn encode_frame_header(&mut self) {
80 if let Some(frame) = self.frame.as_ref() {
81 let mut buf_mut = &mut self.buf[self.len..];
82 frame.encode(&mut buf_mut);
83 self.len = WRITE_BUF_ENCODE_SIZE - buf_mut.remaining_mut();
84 }
85 }
86}
87
88impl<B> From<StreamType> for WriteBuf<B>
89where
90 B: Buf,
91{
92 fn from(ty: StreamType) -> Self {
93 let mut me = Self {
94 buf: [0; WRITE_BUF_ENCODE_SIZE],
95 len: 0,
96 pos: 0,
97 frame: None,
98 };
99 me.encode_stream_type(ty);
100 me
101 }
102}
103
104impl<B> From<UniStreamHeader> for WriteBuf<B>
105where
106 B: Buf,
107{
108 fn from(header: UniStreamHeader) -> Self {
109 let mut this = Self {
110 buf: [0; WRITE_BUF_ENCODE_SIZE],
111 len: 0,
112 pos: 0,
113 frame: None,
114 };
115
116 this.encode_value(header);
117 this
118 }
119}
120
/// Header written at the start of an outgoing unidirectional stream.
///
/// Every variant encodes its stream type first, followed by variant-specific data (see the
/// `Encode` impl): the SETTINGS for a control stream, the session id for a WebTransport stream,
/// and nothing more for encoder/decoder streams.
pub enum UniStreamHeader {
    /// Control stream, carrying the local SETTINGS.
    Control(Settings),
    /// WebTransport unidirectional stream bound to the given session.
    WebTransportUni(SessionId),
    /// Encoder stream (stream type only).
    Encoder,
    /// Decoder stream (stream type only).
    Decoder,
}
127
128impl Encode for UniStreamHeader {
129 fn encode<B: BufMut>(&self, buf: &mut B) {
130 match self {
131 Self::Control(settings) => {
132 StreamType::CONTROL.encode(buf);
133 settings.encode(buf);
134 }
135 Self::WebTransportUni(session_id) => {
136 StreamType::WEBTRANSPORT_UNI.encode(buf);
137 session_id.encode(buf);
138 }
139 UniStreamHeader::Encoder => {
140 StreamType::ENCODER.encode(buf);
141 }
142 UniStreamHeader::Decoder => {
143 StreamType::DECODER.encode(buf);
144 }
145 }
146 }
147}
148
149impl<B> From<BidiStreamHeader> for WriteBuf<B>
150where
151 B: Buf,
152{
153 fn from(header: BidiStreamHeader) -> Self {
154 let mut this = Self {
155 buf: [0; WRITE_BUF_ENCODE_SIZE],
156 len: 0,
157 pos: 0,
158 frame: None,
159 };
160
161 this.encode_value(header);
162 this
163 }
164}
165
/// Header written at the start of an outgoing bidirectional stream.
///
/// Only the WebTransport bidirectional stream type exists here; it is followed by the session
/// id (see the `Encode` impl).
pub enum BidiStreamHeader {
    /// WebTransport bidirectional stream bound to the given session.
    WebTransportBidi(SessionId),
}
169
170impl Encode for BidiStreamHeader {
171 fn encode<B: BufMut>(&self, buf: &mut B) {
172 match self {
173 Self::WebTransportBidi(session_id) => {
174 StreamType::WEBTRANSPORT_BIDI.encode(buf);
175 session_id.encode(buf);
176 }
177 }
178 }
179}
180
181impl<B> From<Frame<B>> for WriteBuf<B>
182where
183 B: Buf,
184{
185 fn from(frame: Frame<B>) -> Self {
186 let mut me = Self {
187 buf: [0; WRITE_BUF_ENCODE_SIZE],
188 len: 0,
189 pos: 0,
190 frame: Some(frame),
191 };
192 me.encode_frame_header();
193 me
194 }
195}
196
197impl<B> From<(StreamType, Frame<B>)> for WriteBuf<B>
198where
199 B: Buf,
200{
201 fn from(ty_stream: (StreamType, Frame<B>)) -> Self {
202 let (ty, frame) = ty_stream;
203 let mut me = Self {
204 buf: [0; WRITE_BUF_ENCODE_SIZE],
205 len: 0,
206 pos: 0,
207 frame: Some(frame),
208 };
209 me.encode_value(ty);
210 me.encode_frame_header();
211 me
212 }
213}
214
215impl<B> Buf for WriteBuf<B>
216where
217 B: Buf,
218{
219 fn remaining(&self) -> usize {
220 self.len - self.pos
221 + self
222 .frame
223 .as_ref()
224 .and_then(|f| f.payload())
225 .map_or(0, |x| x.remaining())
226 }
227
228 fn chunk(&self) -> &[u8] {
229 if self.len - self.pos > 0 {
230 &self.buf[self.pos..self.len]
231 } else if let Some(payload) = self.frame.as_ref().and_then(|f| f.payload()) {
232 payload.chunk()
233 } else {
234 &[]
235 }
236 }
237
238 fn advance(&mut self, mut cnt: usize) {
239 let remaining_header = self.len - self.pos;
240 if remaining_header > 0 {
241 let advanced = usize::min(cnt, remaining_header);
242 self.pos += advanced;
243 cnt -= advanced;
244 }
245
246 if let Some(payload) = self.frame.as_mut().and_then(|f| f.payload_mut()) {
247 payload.advance(cnt);
248 }
249 }
250}
251
/// Incoming unidirectional stream, classified by the stream type that
/// `AcceptRecvStream::into_stream` decoded from its header.
pub(super) enum AcceptedRecvStream<S, B>
where
    S: quic::RecvStream,
    B: Buf,
{
    /// Control stream, wrapped in a `FrameStream`.
    Control(FrameStream<S, B>),
    /// Push stream, wrapped in a `FrameStream`.
    Push(FrameStream<S, B>),
    /// Encoder stream, handed over as the buffered raw stream.
    Encoder(BufRecvStream<S, B>),
    /// Decoder stream, handed over as the buffered raw stream.
    Decoder(BufRecvStream<S, B>),
    /// WebTransport unidirectional stream together with the session id read from its header.
    WebTransportUni(SessionId, BufRecvStream<S, B>),
    /// Any stream type not matched above.
    Unknown(BufRecvStream<S, B>),
}
264
/// Reads the header of an incoming unidirectional stream (its stream type, plus the id that
/// follows it on push and WebTransport streams) before the stream is handed off.
pub(super) struct AcceptRecvStream<S, B> {
    // Wrapped stream; bytes read past the header stay buffered in it.
    stream: BufRecvStream<S, B>,
    // Decoded stream type, `None` until it has been read.
    ty: Option<StreamType>,
    // Id following the stream type on PUSH and WEBTRANSPORT_UNI streams (the session id for
    // WebTransport; presumably the push id for push streams).
    id: Option<VarInt>,
    // Encoded length of a varint, derived from its first buffered byte in `poll_next_varint`.
    expected: Option<usize>,
}
273
274impl<S, B> AcceptRecvStream<S, B>
275where
276 S: RecvStream,
277 B: Buf,
278{
279 pub fn new(stream: S) -> Self {
280 Self {
281 stream: BufRecvStream::new(stream),
282 ty: None,
283 id: None,
284 expected: None,
285 }
286 }
287
288 pub fn into_stream(self) -> AcceptedRecvStream<S, B> {
289 match self.ty.expect("Stream type not resolved yet") {
290 StreamType::CONTROL => AcceptedRecvStream::Control(FrameStream::new(self.stream)),
291 StreamType::PUSH => AcceptedRecvStream::Push(FrameStream::new(self.stream)),
292 StreamType::ENCODER => AcceptedRecvStream::Encoder(self.stream),
293 StreamType::DECODER => AcceptedRecvStream::Decoder(self.stream),
294 StreamType::WEBTRANSPORT_UNI => AcceptedRecvStream::WebTransportUni(
295 SessionId::from_varint(self.id.expect("Session ID not resolved yet")),
296 self.stream,
297 ),
298 _ => AcceptedRecvStream::Unknown(self.stream),
299 }
300 }
301
302 fn poll_next_varint(
304 &mut self,
305 cx: &mut Context<'_>,
306 ) -> Poll<Result<(VarInt, Option<StreamEnd>), PollTypeError>> {
307 let mut stream_stopped = None;
309
310 loop {
311 if stream_stopped.is_some() {
312 return Poll::Ready(Err(PollTypeError::EndOfStream));
313 }
314 stream_stopped = match ready!(self.stream.poll_read(cx)) {
319 Ok(false) => None,
320 Ok(true) => Some(StreamEnd::EndOfStream),
321 Err(StreamErrorIncoming::ConnectionErrorIncoming { connection_error }) => {
322 return Poll::Ready(Err(PollTypeError::IncomingError(connection_error)));
323 }
324 Err(StreamErrorIncoming::StreamTerminated { error_code }) => {
325 Some(StreamEnd::Reset(error_code))
326 }
327 Err(StreamErrorIncoming::Unknown(err)) => {
328 #[cfg(feature = "tracing")]
329 tracing::error!("Unknown error when reading stream {}", err);
330
331 Some(StreamEnd::Other)
332 }
333 };
334
335 let mut buf = self.stream.buf_mut();
336 if self.expected.is_none() && buf.remaining() >= 1 {
337 self.expected = Some(VarInt::encoded_size(buf.chunk()[0]));
338 }
339
340 if let Some(expected) = self.expected {
341 if buf.remaining() < expected {
342 continue;
343 }
344 } else {
345 continue;
346 }
347
348 let reult = VarInt::decode(&mut buf).map_err(|_| {
349 PollTypeError::InternalError(InternalConnectionError::new(
350 Code::H3_INTERNAL_ERROR,
351 "Unexpected end parsing varint".to_string(),
352 ))
353 })?;
354
355 return Poll::Ready(Ok((reult, stream_stopped)));
356 }
357 }
358
359 pub fn poll_type(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), PollTypeError>> {
360 if self.ty.is_none() {
362 let (var, _) = ready!(self.poll_next_varint(cx))?;
366 let ty = StreamType::from_value(var.0);
367 self.ty = Some(ty);
368 }
369
370 if matches!(
372 self.ty,
373 Some(StreamType::PUSH | StreamType::WEBTRANSPORT_UNI)
374 ) && self.id.is_none()
375 {
376 let (var, _) = ready!(self.poll_next_varint(cx))?;
377 self.id = Some(var);
378 }
379
380 Poll::Ready(Ok(()))
381 }
382}
383
/// How the stream ended while a varint was being read in `poll_next_varint`.
enum StreamEnd {
    /// The underlying stream reported its end (`poll_read` returned `Ok(true)`).
    EndOfStream,
    /// The stream was terminated with the given error code.
    Reset(u64),
    /// Reading failed with an unclassified error.
    Other,
}
390
/// Failure while decoding the header of an incoming unidirectional stream.
pub(super) enum PollTypeError {
    /// A connection-level error surfaced while reading the stream.
    IncomingError(ConnectionErrorIncoming),
    /// Decoding a varint failed although enough bytes appeared to be buffered.
    InternalError(InternalConnectionError),
    /// The stream ended (or was reset / failed) before a complete varint was read.
    EndOfStream,
}
398
pin_project! {
    /// Receive-stream wrapper that buffers incoming data, so bytes read while inspecting the
    /// start of a stream (e.g. its header) remain available to whoever reads it next.
    pub struct BufRecvStream<S, B> {
        // Data read from `stream` that has not been consumed yet.
        buf: BufList<Bytes>,
        // Set once `stream` has reported its end (`poll_data` returned `None`).
        //
        // Data may still be available in `buf`.
        eos: bool,
        // The wrapped QUIC stream.
        stream: S,
        // Carries the otherwise unused type parameter `B`.
        _marker: PhantomData<B>,
    }
}
419
420impl<S, B> std::fmt::Debug for BufRecvStream<S, B> {
421 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
422 f.debug_struct("BufRecvStream")
423 .field("buf", &self.buf)
424 .field("eos", &self.eos)
425 .field("stream", &"...")
426 .finish()
427 }
428}
429
430impl<S, B> BufRecvStream<S, B> {
431 pub fn new(stream: S) -> Self {
432 Self {
433 buf: BufList::new(),
434 eos: false,
435 stream,
436 _marker: PhantomData,
437 }
438 }
439}
440
441impl<B, S: RecvStream> BufRecvStream<S, B> {
442 pub fn poll_read(&mut self, cx: &mut Context<'_>) -> Poll<Result<bool, StreamErrorIncoming>> {
446 let data = ready!(self.stream.poll_data(cx))?;
447
448 if let Some(mut data) = data {
449 self.buf.push_bytes(&mut data);
450 Poll::Ready(Ok(false))
451 } else {
452 self.eos = true;
453 Poll::Ready(Ok(true))
454 }
455 }
456
457 #[inline]
459 pub(crate) fn buf_mut(&mut self) -> &mut BufList<Bytes> {
460 &mut self.buf
461 }
462
463 pub fn take_chunk(&mut self, limit: usize) -> Option<Bytes> {
467 self.buf.take_chunk(limit)
468 }
469
470 pub fn has_remaining(&mut self) -> bool {
472 self.buf.has_remaining()
473 }
474
475 #[inline]
476 pub(crate) fn buf(&self) -> &BufList<Bytes> {
477 &self.buf
478 }
479
480 pub fn is_eos(&self) -> bool {
481 self.eos
482 }
483}
484
485impl<S: RecvStream, B> RecvStream for BufRecvStream<S, B> {
486 type Buf = Bytes;
487
488 fn poll_data(
489 &mut self,
490 cx: &mut std::task::Context<'_>,
491 ) -> Poll<Result<Option<Self::Buf>, StreamErrorIncoming>> {
492 if let Some(chunk) = self.buf.take_first_chunk() {
494 return Poll::Ready(Ok(Some(chunk)));
495 }
496
497 if let Some(mut data) = ready!(self.stream.poll_data(cx))? {
498 Poll::Ready(Ok(Some(data.copy_to_bytes(data.remaining()))))
499 } else {
500 self.eos = true;
501 Poll::Ready(Ok(None))
502 }
503 }
504
505 fn stop_sending(&mut self, error_code: u64) {
506 self.stream.stop_sending(error_code)
507 }
508
509 fn recv_id(&self) -> quic::StreamId {
510 self.stream.recv_id()
511 }
512}
513
514impl<S, B> SendStream<B> for BufRecvStream<S, B>
515where
516 B: Buf,
517 S: SendStream<B>,
518{
519 fn poll_finish(
520 &mut self,
521 cx: &mut std::task::Context<'_>,
522 ) -> Poll<Result<(), StreamErrorIncoming>> {
523 self.stream.poll_finish(cx)
524 }
525
526 fn reset(&mut self, reset_code: u64) {
527 self.stream.reset(reset_code)
528 }
529
530 fn send_id(&self) -> quic::StreamId {
531 self.stream.send_id()
532 }
533
534 fn poll_ready(
535 &mut self,
536 cx: &mut std::task::Context<'_>,
537 ) -> Poll<Result<(), StreamErrorIncoming>> {
538 self.stream.poll_ready(cx)
539 }
540
541 fn send_data<T: Into<WriteBuf<B>>>(&mut self, data: T) -> Result<(), StreamErrorIncoming> {
542 self.stream.send_data(data)
543 }
544}
545
546impl<S, B> SendStreamUnframed<B> for BufRecvStream<S, B>
547where
548 B: Buf,
549 S: SendStreamUnframed<B>,
550{
551 #[inline]
552 fn poll_send<D: Buf>(
553 &mut self,
554 cx: &mut std::task::Context<'_>,
555 buf: &mut D,
556 ) -> Poll<Result<usize, StreamErrorIncoming>> {
557 self.stream.poll_send(cx, buf)
558 }
559}
560
561impl<S, B> BidiStream<B> for BufRecvStream<S, B>
562where
563 B: Buf,
564 S: BidiStream<B>,
565{
566 type SendStream = BufRecvStream<S::SendStream, B>;
567
568 type RecvStream = BufRecvStream<S::RecvStream, B>;
569
570 fn split(self) -> (Self::SendStream, Self::RecvStream) {
571 let (send, recv) = self.stream.split();
572 (
573 BufRecvStream {
574 buf: BufList::new(),
576 eos: self.eos,
577 stream: send,
578 _marker: PhantomData,
579 },
580 BufRecvStream {
581 buf: self.buf,
582 eos: self.eos,
583 stream: recv,
584 _marker: PhantomData,
585 },
586 )
587 }
588}
589
590impl<S, B> futures_util::io::AsyncRead for BufRecvStream<S, B>
591where
592 B: Buf,
593 S: RecvStream,
594{
595 fn poll_read(
596 mut self: Pin<&mut Self>,
597 cx: &mut Context<'_>,
598 buf: &mut [u8],
599 ) -> Poll<futures_util::io::Result<usize>> {
600 let p = &mut *self;
601 if !p.has_remaining() {
606 let eos = ready!(p.poll_read(cx).map_err(|err| convert_to_std_io_error(err)))?;
607 if eos {
608 return Poll::Ready(Ok(0));
609 }
610 }
611
612 let chunk = p.buf_mut().take_chunk(buf.len());
613 if let Some(chunk) = chunk {
614 assert!(chunk.len() <= buf.len());
615 let len = chunk.len().min(buf.len());
616 buf[..len].copy_from_slice(&chunk);
618 Poll::Ready(Ok(len))
619 } else {
620 Poll::Ready(Ok(0))
621 }
622 }
623}
624
625impl<S, B> tokio::io::AsyncRead for BufRecvStream<S, B>
626where
627 B: Buf,
628 S: RecvStream,
629{
630 fn poll_read(
631 mut self: Pin<&mut Self>,
632 cx: &mut Context<'_>,
633 buf: &mut ReadBuf<'_>,
634 ) -> Poll<futures_util::io::Result<()>> {
635 let p = &mut *self;
636 if !p.has_remaining() {
641 let eos = ready!(p.poll_read(cx).map_err(|err| convert_to_std_io_error(err)))?;
642 if eos {
643 return Poll::Ready(Ok(()));
644 }
645 }
646
647 let chunk = p.buf_mut().take_chunk(buf.remaining());
648 if let Some(chunk) = chunk {
649 assert!(chunk.len() <= buf.remaining());
650 buf.put_slice(&chunk);
652 Poll::Ready(Ok(()))
653 } else {
654 Poll::Ready(Ok(()))
655 }
656 }
657}
658
659impl<S, B> futures_util::io::AsyncWrite for BufRecvStream<S, B>
660where
661 B: Buf,
662 S: SendStreamUnframed<B>,
663{
664 fn poll_write(
665 mut self: Pin<&mut Self>,
666 cx: &mut Context<'_>,
667 mut buf: &[u8],
668 ) -> Poll<std::io::Result<usize>> {
669 let p = &mut *self;
670 p.poll_send(cx, &mut buf)
671 .map_err(|err| convert_to_std_io_error(err))
672 }
673
674 fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<std::io::Result<()>> {
675 Poll::Ready(Ok(()))
676 }
677
678 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
679 let p = &mut *self;
680 p.poll_finish(cx)
681 .map_err(|err| convert_to_std_io_error(err))
682 }
683}
684
685impl<S, B> tokio::io::AsyncWrite for BufRecvStream<S, B>
686where
687 B: Buf,
688 S: SendStreamUnframed<B>,
689{
690 fn poll_write(
691 mut self: Pin<&mut Self>,
692 cx: &mut Context<'_>,
693 mut buf: &[u8],
694 ) -> Poll<std::io::Result<usize>> {
695 let p = &mut *self;
696 p.poll_send(cx, &mut buf)
697 .map_err(|err| convert_to_std_io_error(err))
698 }
699
700 fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<std::io::Result<()>> {
701 Poll::Ready(Ok(()))
702 }
703
704 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
705 let p = &mut *self;
706 p.poll_finish(cx)
707 .map_err(|err| convert_to_std_io_error(err))
708 }
709}
710
711fn convert_to_std_io_error(error: StreamErrorIncoming) -> std::io::Error {
712 std::io::Error::new(std::io::ErrorKind::Other, error)
713}
714
#[cfg(test)]
mod tests {
    use crate::proto::coding::BufExt;

    use super::*;

    #[test]
    fn write_wt_uni_header() {
        let mut w = WriteBuf::<Bytes>::from(UniStreamHeader::WebTransportUni(
            SessionId::from_varint(VarInt(5)),
        ));

        let ty = w.get_var().unwrap();
        assert_eq!(ty, 0x54, "stream type must be WEBTRANSPORT_UNI");

        // The session id directly follows the stream type and is the last thing encoded.
        let id = w.get_var().unwrap();
        assert_eq!(id, 5, "session id must round-trip");
        assert_eq!(w.remaining(), 0, "nothing may follow the session id");
    }

    #[test]
    fn write_buf_encode_streamtype() {
        let wbuf = WriteBuf::<Bytes>::from(StreamType::ENCODER);

        assert_eq!(wbuf.chunk(), b"\x02");
        assert_eq!(wbuf.len, 1);
    }

    #[test]
    fn write_buf_encode_frame() {
        let wbuf = WriteBuf::<Bytes>::from(Frame::Goaway(VarInt(2)));

        assert_eq!(wbuf.chunk(), b"\x07\x01\x02");
        assert_eq!(wbuf.len, 3);
    }

    #[test]
    fn write_buf_encode_streamtype_then_frame() {
        let wbuf = WriteBuf::<Bytes>::from((StreamType::ENCODER, Frame::Goaway(VarInt(2))));

        assert_eq!(wbuf.chunk(), b"\x02\x07\x01\x02");
    }

    #[test]
    fn write_buf_advances() {
        let mut wbuf =
            WriteBuf::<Bytes>::from((StreamType::ENCODER, Frame::Data(Bytes::from("hey"))));

        assert_eq!(wbuf.chunk(), b"\x02\x00\x03");
        wbuf.advance(3);
        assert_eq!(wbuf.remaining(), 3);
        assert_eq!(wbuf.chunk(), b"hey");
        wbuf.advance(2);
        assert_eq!(wbuf.chunk(), b"y");
        wbuf.advance(1);
        assert_eq!(wbuf.remaining(), 0);
    }

    #[test]
    fn write_buf_advance_jumps_header_and_payload_start() {
        let mut wbuf =
            WriteBuf::<Bytes>::from((StreamType::ENCODER, Frame::Data(Bytes::from("hey"))));

        wbuf.advance(4);
        assert_eq!(wbuf.chunk(), b"ey");
    }
}