httproxide_h3/
frame.rs

1use std::marker::PhantomData;
2use std::task::{Context, Poll};
3
4use bytes::{Buf, Bytes};
5
6use futures_util::ready;
7use tracing::trace;
8
9use crate::{
10    buf::BufList,
11    error::TransportError,
12    proto::{
13        frame::{self, Frame, PayloadLen},
14        stream::StreamId,
15    },
16    quic::{BidiStream, RecvStream, SendStream},
17    stream::WriteBuf,
18};
19
/// Frame-oriented wrapper around a stream `S`.
///
/// Bytes received from `stream` are queued in `bufs` and turned into
/// [`Frame`]s by `decoder`. The payload of a `DATA` frame is not part of the
/// decoded frame: it is handed out separately, and `remaining_data` tracks
/// how much of it is still to be yielded. `B` is only a marker for the
/// buffer type used on the sending side.
pub struct FrameStream<S, B> {
    /// The wrapped stream.
    stream: S,
    /// Bytes received from `stream` that have not been consumed yet.
    bufs: BufList<Bytes>,
    /// Decoder state, kept across calls while a frame is incomplete.
    decoder: FrameDecoder,
    /// Payload bytes of the current `DATA` frame that were not yielded yet.
    remaining_data: usize,
    /// Set to true when `stream` reaches the end.
    is_eos: bool,
    /// Ties the otherwise unused `B` parameter to the type.
    _phantom_buffer: PhantomData<B>,
}
29
30impl<S, B> FrameStream<S, B> {
31    pub fn new(stream: S) -> Self {
32        Self::with_bufs(stream, BufList::new())
33    }
34
35    pub(crate) fn with_bufs(stream: S, bufs: BufList<Bytes>) -> Self {
36        Self {
37            stream,
38            bufs,
39            decoder: FrameDecoder::default(),
40            remaining_data: 0,
41            is_eos: false,
42            _phantom_buffer: PhantomData,
43        }
44    }
45}
46
47impl<S, B> FrameStream<S, B>
48where
49    S: RecvStream,
50{
51    pub fn poll_next(
52        &mut self,
53        cx: &mut Context<'_>,
54    ) -> Poll<Result<Option<Frame<PayloadLen>>, Error>> {
55        assert!(
56            self.remaining_data == 0,
57            "There is still data to read, please call poll_data() until it returns None."
58        );
59
60        loop {
61            let end = self.try_recv(cx)?;
62
63            return match self.decoder.decode(&mut self.bufs)? {
64                Some(Frame::Data(PayloadLen(len))) => {
65                    self.remaining_data = len;
66                    Poll::Ready(Ok(Some(Frame::Data(PayloadLen(len)))))
67                }
68                Some(frame) => Poll::Ready(Ok(Some(frame))),
69                None => match end {
70                    // Received a chunk but frame is incomplete, poll until we get `Pending`.
71                    Poll::Ready(false) => continue,
72                    Poll::Pending => Poll::Pending,
73                    Poll::Ready(true) => {
74                        if self.bufs.has_remaining() {
75                            // Reached the end of receive stream, but there is still some data:
76                            // The frame is incomplete.
77                            Poll::Ready(Err(Error::UnexpectedEnd))
78                        } else {
79                            Poll::Ready(Ok(None))
80                        }
81                    }
82                },
83            };
84        }
85    }
86
87    pub fn poll_data(&mut self, cx: &mut Context<'_>) -> Poll<Result<Option<impl Buf>, Error>> {
88        if self.remaining_data == 0 {
89            return Poll::Ready(Ok(None));
90        };
91
92        let end = ready!(self.try_recv(cx))?;
93        let data = self.bufs.take_chunk(self.remaining_data as usize);
94
95        match (data, end) {
96            (None, true) => Poll::Ready(Ok(None)),
97            (None, false) => Poll::Pending,
98            (Some(d), true)
99                if d.remaining() < self.remaining_data && !self.bufs.has_remaining() =>
100            {
101                Poll::Ready(Err(Error::UnexpectedEnd))
102            }
103            (Some(d), _) => {
104                self.remaining_data -= d.remaining();
105                Poll::Ready(Ok(Some(d)))
106            }
107        }
108    }
109
110    pub(crate) fn stop_sending(&mut self, error_code: crate::error::Code) {
111        let _ = self.stream.stop_sending(error_code.into());
112    }
113
114    pub(crate) fn has_data(&self) -> bool {
115        self.remaining_data != 0
116    }
117
118    pub(crate) fn is_eos(&self) -> bool {
119        self.is_eos && !self.bufs.has_remaining()
120    }
121
122    fn try_recv(&mut self, cx: &mut Context<'_>) -> Poll<Result<bool, Error>> {
123        if self.is_eos {
124            return Poll::Ready(Ok(true));
125        }
126        match self.stream.poll_data(cx) {
127            Poll::Ready(Err(e)) => Poll::Ready(Err(Error::Quic(e.into()))),
128            Poll::Pending => Poll::Pending,
129            Poll::Ready(Ok(None)) => {
130                self.is_eos = true;
131                Poll::Ready(Ok(true))
132            }
133            Poll::Ready(Ok(Some(mut d))) => {
134                self.bufs.push_bytes(&mut d);
135                Poll::Ready(Ok(false))
136            }
137        }
138    }
139}
140
141impl<T, B> SendStream<B> for FrameStream<T, B>
142where
143    T: SendStream<B>,
144    B: Buf,
145{
146    type Error = <T as SendStream<B>>::Error;
147
148    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
149        self.stream.poll_ready(cx)
150    }
151
152    fn send_data<D: Into<WriteBuf<B>>>(&mut self, data: D) -> Result<(), Self::Error> {
153        self.stream.send_data(data)
154    }
155
156    fn poll_finish(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
157        self.stream.poll_finish(cx)
158    }
159
160    fn reset(&mut self, reset_code: u64) {
161        self.stream.reset(reset_code)
162    }
163
164    fn id(&self) -> StreamId {
165        self.stream.id()
166    }
167}
168
169impl<S, B> FrameStream<S, B>
170where
171    S: BidiStream<B>,
172    B: Buf,
173{
174    pub(crate) fn split(self) -> (FrameStream<S::SendStream, B>, FrameStream<S::RecvStream, B>) {
175        let (send, recv) = self.stream.split();
176        (
177            FrameStream {
178                stream: send,
179                bufs: BufList::new(),
180                decoder: FrameDecoder::default(),
181                remaining_data: 0,
182                is_eos: false,
183                _phantom_buffer: PhantomData,
184            },
185            FrameStream {
186                stream: recv,
187                bufs: self.bufs,
188                decoder: self.decoder,
189                remaining_data: self.remaining_data,
190                is_eos: self.is_eos,
191                _phantom_buffer: PhantomData,
192            },
193        )
194    }
195}
196
/// Stateful decoder that extracts frames from a [`BufList`].
#[derive(Default)]
pub struct FrameDecoder {
    /// Byte count reported by the last incomplete decode attempt; decoding is
    /// not retried while fewer bytes than this are buffered. `None` when no
    /// attempt is pending.
    expected: Option<usize>,
}
201
202impl FrameDecoder {
203    fn decode<B: Buf>(&mut self, src: &mut BufList<B>) -> Result<Option<Frame<PayloadLen>>, Error> {
204        // Decode in a loop since we ignore unknown frames, and there may be
205        // other frames already in our BufList.
206        loop {
207            if !src.has_remaining() {
208                return Ok(None);
209            }
210
211            if let Some(min) = self.expected {
212                if src.remaining() < min {
213                    return Ok(None);
214                }
215            }
216
217            let (pos, decoded) = {
218                let mut cur = src.cursor();
219                let decoded = Frame::decode(&mut cur);
220                (cur.position() as usize, decoded)
221            };
222
223            match decoded {
224                Err(frame::Error::UnknownFrame(ty)) => {
225                    trace!("ignore unknown frame type {:#x}", ty);
226                    src.advance(pos);
227                    self.expected = None;
228                    continue;
229                }
230                Err(frame::Error::Incomplete(min)) => {
231                    self.expected = Some(min);
232                    return Ok(None);
233                }
234                Err(e) => return Err(e.into()),
235                Ok(frame) => {
236                    src.advance(pos);
237                    self.expected = None;
238                    return Ok(Some(frame));
239                }
240            }
241        }
242    }
243}
244
/// Errors produced while reading frames from a [`FrameStream`].
#[derive(Debug)]
pub enum Error {
    /// The frame decoder reported an error.
    Proto(frame::Error),
    /// Receiving from the underlying stream failed.
    Quic(TransportError),
    /// The stream ended while a frame was still incomplete.
    UnexpectedEnd,
}
251
252impl From<frame::Error> for Error {
253    fn from(err: frame::Error) -> Self {
254        Error::Proto(err)
255    }
256}
257
258#[cfg(test)]
259mod tests {
260    use super::*;
261
262    use assert_matches::assert_matches;
263    use bytes::{BufMut, BytesMut};
264    use futures_util::future::poll_fn;
265    use std::{collections::VecDeque, fmt, sync::Arc};
266    use tokio;
267
268    use crate::{
269        proto::{coding::Encode, frame::FrameType, varint::VarInt},
270        quic,
271    };
272
273    // Decoder
274
275    #[test]
276    fn one_frame() {
277        let mut buf = BytesMut::with_capacity(16);
278        Frame::headers(&b"salut"[..]).encode_with_payload(&mut buf);
279        let mut buf = BufList::from(buf);
280
281        let mut decoder = FrameDecoder::default();
282        assert_matches!(decoder.decode(&mut buf), Ok(Some(Frame::Headers(_))));
283    }
284
285    #[test]
286    fn incomplete_frame() {
287        let frame = Frame::headers(&b"salut"[..]);
288
289        let mut buf = BytesMut::with_capacity(16);
290        frame.encode(&mut buf);
291        buf.truncate(buf.len() - 1);
292        let mut buf = BufList::from(buf);
293
294        let mut decoder = FrameDecoder::default();
295        assert_matches!(decoder.decode(&mut buf), Ok(None));
296    }
297
298    #[test]
299    fn header_spread_multiple_buf() {
300        let mut buf = BytesMut::with_capacity(16);
301        Frame::headers(&b"salut"[..]).encode_with_payload(&mut buf);
302        let mut buf_list = BufList::new();
303        // Cut buffer between type and length
304        buf_list.push(&buf[..1]);
305        buf_list.push(&buf[1..]);
306
307        let mut decoder = FrameDecoder::default();
308        assert_matches!(decoder.decode(&mut buf_list), Ok(Some(Frame::Headers(_))));
309    }
310
311    #[test]
312    fn varint_spread_multiple_buf() {
313        let mut buf = BytesMut::with_capacity(16);
314        Frame::headers("salut".repeat(1024)).encode_with_payload(&mut buf);
315
316        let mut buf_list = BufList::new();
317        // Cut buffer in the middle of length's varint
318        buf_list.push(&buf[..2]);
319        buf_list.push(&buf[2..]);
320
321        let mut decoder = FrameDecoder::default();
322        assert_matches!(decoder.decode(&mut buf_list), Ok(Some(Frame::Headers(_))));
323    }
324
325    #[test]
326    fn two_frames_then_incomplete() {
327        let mut buf = BytesMut::with_capacity(64);
328        Frame::headers(&b"header"[..]).encode_with_payload(&mut buf);
329        Frame::Data(&b"body"[..]).encode_with_payload(&mut buf);
330        Frame::headers(&b"trailer"[..]).encode_with_payload(&mut buf);
331
332        buf.truncate(buf.len() - 1);
333        let mut buf = BufList::from(buf);
334
335        let mut decoder = FrameDecoder::default();
336        assert_matches!(decoder.decode(&mut buf), Ok(Some(Frame::Headers(_))));
337        assert_matches!(
338            decoder.decode(&mut buf),
339            Ok(Some(Frame::Data(PayloadLen(4))))
340        );
341        assert_matches!(decoder.decode(&mut buf), Ok(None));
342    }
343
344    // FrameStream
345
346    macro_rules! assert_poll_matches {
347        ($poll_fn:expr, $match:pat) => {
348            assert_matches!(
349                poll_fn($poll_fn).await,
350                $match
351            );
352        };
353        ($poll_fn:expr, $match:pat if $cond:expr ) => {
354            assert_matches!(
355                poll_fn($poll_fn).await,
356                $match if $cond
357            );
358        }
359    }
360
361    #[tokio::test]
362    async fn poll_full_request() {
363        let mut recv = FakeRecv::default();
364        let mut buf = BytesMut::with_capacity(64);
365
366        Frame::headers(&b"header"[..]).encode_with_payload(&mut buf);
367        Frame::Data(&b"body"[..]).encode_with_payload(&mut buf);
368        Frame::headers(&b"trailer"[..]).encode_with_payload(&mut buf);
369        recv.chunk(buf.freeze());
370
371        let mut stream: FrameStream<_, ()> = FrameStream::new(recv);
372
373        assert_poll_matches!(
374            |mut cx| stream.poll_next(&mut cx),
375            Ok(Some(Frame::Headers(_)))
376        );
377        assert_poll_matches!(
378            |mut cx| stream.poll_next(&mut cx),
379            Ok(Some(Frame::Data(PayloadLen(4))))
380        );
381        assert_poll_matches!(
382            |mut cx| to_bytes(stream.poll_data(&mut cx)),
383            Ok(Some(b)) if b.remaining() == 4
384        );
385        assert_poll_matches!(
386            |mut cx| stream.poll_next(&mut cx),
387            Ok(Some(Frame::Headers(_)))
388        );
389    }
390
391    #[tokio::test]
392    async fn poll_next_incomplete_frame() {
393        let mut recv = FakeRecv::default();
394        let mut buf = BytesMut::with_capacity(64);
395
396        Frame::headers(&b"header"[..]).encode_with_payload(&mut buf);
397        let mut buf = buf.freeze();
398        recv.chunk(buf.split_to(buf.len() - 1));
399        let mut stream: FrameStream<_, ()> = FrameStream::new(recv);
400
401        assert_poll_matches!(
402            |mut cx| stream.poll_next(&mut cx),
403            Err(Error::UnexpectedEnd)
404        );
405    }
406
407    #[tokio::test]
408    #[should_panic(
409        expected = "There is still data to read, please call poll_data() until it returns None"
410    )]
411    async fn poll_next_reamining_data() {
412        let mut recv = FakeRecv::default();
413        let mut buf = BytesMut::with_capacity(64);
414
415        FrameType::DATA.encode(&mut buf);
416        VarInt::from(4u32).encode(&mut buf);
417        recv.chunk(buf.freeze());
418        let mut stream: FrameStream<_, ()> = FrameStream::new(recv);
419
420        assert_poll_matches!(
421            |mut cx| stream.poll_next(&mut cx),
422            Ok(Some(Frame::Data(PayloadLen(4))))
423        );
424
425        // There is still data to consume, poll_next should panic
426        let _ = poll_fn(|mut cx| stream.poll_next(&mut cx)).await;
427    }
428
429    #[tokio::test]
430    async fn poll_data_split() {
431        let mut recv = FakeRecv::default();
432        let mut buf = BytesMut::with_capacity(64);
433
434        // Body is split into two bufs
435        Frame::Data(Bytes::from("body")).encode_with_payload(&mut buf);
436
437        let mut buf = buf.freeze();
438        recv.chunk(buf.split_to(buf.len() - 2));
439        recv.chunk(buf);
440        let mut stream: FrameStream<_, ()> = FrameStream::new(recv);
441
442        // We get the total size of data about to be received
443        assert_poll_matches!(
444            |mut cx| stream.poll_next(&mut cx),
445            Ok(Some(Frame::Data(PayloadLen(4))))
446        );
447
448        // Then we get parts of body, chunked as they arrived
449        assert_poll_matches!(
450            |mut cx| to_bytes(stream.poll_data(&mut cx)),
451            Ok(Some(b)) if b.remaining() == 2
452        );
453        assert_poll_matches!(
454            |mut cx| to_bytes(stream.poll_data(&mut cx)),
455            Ok(Some(b)) if b.remaining() == 2
456        );
457    }
458
459    #[tokio::test]
460    async fn poll_data_unexpected_end() {
461        let mut recv = FakeRecv::default();
462        let mut buf = BytesMut::with_capacity(64);
463
464        // Truncated body
465        FrameType::DATA.encode(&mut buf);
466        VarInt::from(4u32).encode(&mut buf);
467        buf.put_slice(&b"b"[..]);
468        recv.chunk(buf.freeze());
469        let mut stream: FrameStream<_, ()> = FrameStream::new(recv);
470
471        assert_poll_matches!(
472            |mut cx| stream.poll_next(&mut cx),
473            Ok(Some(Frame::Data(PayloadLen(4))))
474        );
475        assert_poll_matches!(
476            |mut cx| to_bytes(stream.poll_data(&mut cx)),
477            Err(Error::UnexpectedEnd)
478        );
479    }
480
481    #[tokio::test]
482    async fn poll_data_ignores_unknown_frames() {
483        use crate::proto::varint::BufMutExt as _;
484
485        let mut recv = FakeRecv::default();
486        let mut buf = BytesMut::with_capacity(64);
487
488        // grease a lil
489        crate::proto::frame::FrameType::grease().encode(&mut buf);
490        buf.write_var(0);
491
492        // grease with some data
493        crate::proto::frame::FrameType::grease().encode(&mut buf);
494        buf.write_var(6);
495        buf.put_slice(b"grease");
496
497        // Body
498        Frame::Data(Bytes::from("body")).encode_with_payload(&mut buf);
499
500        recv.chunk(buf.freeze());
501        let mut stream: FrameStream<_, ()> = FrameStream::new(recv);
502
503        assert_poll_matches!(
504            |mut cx| stream.poll_next(&mut cx),
505            Ok(Some(Frame::Data(PayloadLen(4))))
506        );
507        assert_poll_matches!(
508            |mut cx| to_bytes(stream.poll_data(&mut cx)),
509            Ok(Some(b)) if &*b == b"body"
510        );
511    }
512
513    #[tokio::test]
514    async fn poll_data_eos_but_buffered_data() {
515        let mut recv = FakeRecv::default();
516        let mut buf = BytesMut::with_capacity(64);
517
518        FrameType::DATA.encode(&mut buf);
519        VarInt::from(4u32).encode(&mut buf);
520        buf.put_slice(&b"bo"[..]);
521        recv.chunk(buf.clone().freeze());
522
523        let mut stream: FrameStream<_, ()> = FrameStream::new(recv);
524
525        assert_poll_matches!(
526            |mut cx| stream.poll_next(&mut cx),
527            Ok(Some(Frame::Data(PayloadLen(4))))
528        );
529
530        buf.truncate(0);
531        buf.put_slice(&b"dy"[..]);
532        stream.bufs.push_bytes(&mut buf.freeze());
533
534        assert_poll_matches!(
535            |mut cx| to_bytes(stream.poll_data(&mut cx)),
536            Ok(Some(b)) if &*b == b"bo"
537        );
538
539        assert_poll_matches!(
540            |mut cx| to_bytes(stream.poll_data(&mut cx)),
541            Ok(Some(b)) if &*b == b"dy"
542        );
543    }
544
545    // Helpers
546
547    #[derive(Default)]
548    struct FakeRecv {
549        chunks: VecDeque<Bytes>,
550    }
551
552    impl FakeRecv {
553        fn chunk(&mut self, buf: Bytes) -> &mut Self {
554            self.chunks.push_back(buf.into());
555            self
556        }
557    }
558
559    impl RecvStream for FakeRecv {
560        type Buf = Bytes;
561        type Error = FakeError;
562
563        fn poll_data(
564            &mut self,
565            _: &mut Context<'_>,
566        ) -> Poll<Result<Option<Self::Buf>, Self::Error>> {
567            Poll::Ready(Ok(self.chunks.pop_front()))
568        }
569
570        fn stop_sending(&mut self, _: u64) {
571            unimplemented!()
572        }
573    }
574
    /// Error type plugged into `FakeRecv`. `FakeRecv::poll_data` never
    /// produces it, so every method below is left unimplemented.
    #[derive(Debug)]
    struct FakeError;

    impl quic::Error for FakeError {
        fn is_timeout(&self) -> bool {
            unimplemented!()
        }

        fn err_code(&self) -> Option<u64> {
            unimplemented!()
        }
    }

    impl std::error::Error for FakeError {}
    impl fmt::Display for FakeError {
        fn fmt(&self, _: &mut fmt::Formatter<'_>) -> fmt::Result {
            unimplemented!()
        }
    }

    // Provides the conversion that `FrameStream::try_recv` applies to stream
    // errors via `e.into()`.
    impl Into<Arc<dyn quic::Error>> for FakeError {
        fn into(self) -> Arc<dyn quic::Error> {
            unimplemented!()
        }
    }
600
601    fn to_bytes(x: Poll<Result<Option<impl Buf>, Error>>) -> Poll<Result<Option<Bytes>, Error>> {
602        x.map(|b| b.map(|b| b.map(|mut b| b.copy_to_bytes(b.remaining()))))
603    }
604}