1use std::marker::PhantomData;
2use std::task::{Context, Poll};
3
4use bytes::{Buf, Bytes};
5
6use futures_util::ready;
7use tracing::trace;
8
9use crate::{
10 buf::BufList,
11 error::TransportError,
12 proto::{
13 frame::{self, Frame, PayloadLen},
14 stream::StreamId,
15 },
16 quic::{BidiStream, RecvStream, SendStream},
17 stream::WriteBuf,
18};
19
20pub struct FrameStream<S, B> {
21 stream: S,
22 bufs: BufList<Bytes>,
23 decoder: FrameDecoder,
24 remaining_data: usize,
25 is_eos: bool,
27 _phantom_buffer: PhantomData<B>,
28}
29
30impl<S, B> FrameStream<S, B> {
31 pub fn new(stream: S) -> Self {
32 Self::with_bufs(stream, BufList::new())
33 }
34
35 pub(crate) fn with_bufs(stream: S, bufs: BufList<Bytes>) -> Self {
36 Self {
37 stream,
38 bufs,
39 decoder: FrameDecoder::default(),
40 remaining_data: 0,
41 is_eos: false,
42 _phantom_buffer: PhantomData,
43 }
44 }
45}
46
47impl<S, B> FrameStream<S, B>
48where
49 S: RecvStream,
50{
51 pub fn poll_next(
52 &mut self,
53 cx: &mut Context<'_>,
54 ) -> Poll<Result<Option<Frame<PayloadLen>>, Error>> {
55 assert!(
56 self.remaining_data == 0,
57 "There is still data to read, please call poll_data() until it returns None."
58 );
59
60 loop {
61 let end = self.try_recv(cx)?;
62
63 return match self.decoder.decode(&mut self.bufs)? {
64 Some(Frame::Data(PayloadLen(len))) => {
65 self.remaining_data = len;
66 Poll::Ready(Ok(Some(Frame::Data(PayloadLen(len)))))
67 }
68 Some(frame) => Poll::Ready(Ok(Some(frame))),
69 None => match end {
70 Poll::Ready(false) => continue,
72 Poll::Pending => Poll::Pending,
73 Poll::Ready(true) => {
74 if self.bufs.has_remaining() {
75 Poll::Ready(Err(Error::UnexpectedEnd))
78 } else {
79 Poll::Ready(Ok(None))
80 }
81 }
82 },
83 };
84 }
85 }
86
87 pub fn poll_data(&mut self, cx: &mut Context<'_>) -> Poll<Result<Option<impl Buf>, Error>> {
88 if self.remaining_data == 0 {
89 return Poll::Ready(Ok(None));
90 };
91
92 let end = ready!(self.try_recv(cx))?;
93 let data = self.bufs.take_chunk(self.remaining_data as usize);
94
95 match (data, end) {
96 (None, true) => Poll::Ready(Ok(None)),
97 (None, false) => Poll::Pending,
98 (Some(d), true)
99 if d.remaining() < self.remaining_data && !self.bufs.has_remaining() =>
100 {
101 Poll::Ready(Err(Error::UnexpectedEnd))
102 }
103 (Some(d), _) => {
104 self.remaining_data -= d.remaining();
105 Poll::Ready(Ok(Some(d)))
106 }
107 }
108 }
109
110 pub(crate) fn stop_sending(&mut self, error_code: crate::error::Code) {
111 let _ = self.stream.stop_sending(error_code.into());
112 }
113
114 pub(crate) fn has_data(&self) -> bool {
115 self.remaining_data != 0
116 }
117
118 pub(crate) fn is_eos(&self) -> bool {
119 self.is_eos && !self.bufs.has_remaining()
120 }
121
122 fn try_recv(&mut self, cx: &mut Context<'_>) -> Poll<Result<bool, Error>> {
123 if self.is_eos {
124 return Poll::Ready(Ok(true));
125 }
126 match self.stream.poll_data(cx) {
127 Poll::Ready(Err(e)) => Poll::Ready(Err(Error::Quic(e.into()))),
128 Poll::Pending => Poll::Pending,
129 Poll::Ready(Ok(None)) => {
130 self.is_eos = true;
131 Poll::Ready(Ok(true))
132 }
133 Poll::Ready(Ok(Some(mut d))) => {
134 self.bufs.push_bytes(&mut d);
135 Poll::Ready(Ok(false))
136 }
137 }
138 }
139}
140
141impl<T, B> SendStream<B> for FrameStream<T, B>
142where
143 T: SendStream<B>,
144 B: Buf,
145{
146 type Error = <T as SendStream<B>>::Error;
147
148 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
149 self.stream.poll_ready(cx)
150 }
151
152 fn send_data<D: Into<WriteBuf<B>>>(&mut self, data: D) -> Result<(), Self::Error> {
153 self.stream.send_data(data)
154 }
155
156 fn poll_finish(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
157 self.stream.poll_finish(cx)
158 }
159
160 fn reset(&mut self, reset_code: u64) {
161 self.stream.reset(reset_code)
162 }
163
164 fn id(&self) -> StreamId {
165 self.stream.id()
166 }
167}
168
169impl<S, B> FrameStream<S, B>
170where
171 S: BidiStream<B>,
172 B: Buf,
173{
174 pub(crate) fn split(self) -> (FrameStream<S::SendStream, B>, FrameStream<S::RecvStream, B>) {
175 let (send, recv) = self.stream.split();
176 (
177 FrameStream {
178 stream: send,
179 bufs: BufList::new(),
180 decoder: FrameDecoder::default(),
181 remaining_data: 0,
182 is_eos: false,
183 _phantom_buffer: PhantomData,
184 },
185 FrameStream {
186 stream: recv,
187 bufs: self.bufs,
188 decoder: self.decoder,
189 remaining_data: self.remaining_data,
190 is_eos: self.is_eos,
191 _phantom_buffer: PhantomData,
192 },
193 )
194 }
195}
196
/// Incremental frame decoder working over buffered stream bytes.
///
/// Public type, so it derives `Debug` alongside `Default`.
#[derive(Debug, Default)]
pub struct FrameDecoder {
    /// Byte count reported by the last `frame::Error::Incomplete`. Decoding is not
    /// retried until at least this many bytes are buffered. `None` when no such
    /// requirement is pending.
    expected: Option<usize>,
}
201
202impl FrameDecoder {
203 fn decode<B: Buf>(&mut self, src: &mut BufList<B>) -> Result<Option<Frame<PayloadLen>>, Error> {
204 loop {
207 if !src.has_remaining() {
208 return Ok(None);
209 }
210
211 if let Some(min) = self.expected {
212 if src.remaining() < min {
213 return Ok(None);
214 }
215 }
216
217 let (pos, decoded) = {
218 let mut cur = src.cursor();
219 let decoded = Frame::decode(&mut cur);
220 (cur.position() as usize, decoded)
221 };
222
223 match decoded {
224 Err(frame::Error::UnknownFrame(ty)) => {
225 trace!("ignore unknown frame type {:#x}", ty);
226 src.advance(pos);
227 self.expected = None;
228 continue;
229 }
230 Err(frame::Error::Incomplete(min)) => {
231 self.expected = Some(min);
232 return Ok(None);
233 }
234 Err(e) => return Err(e.into()),
235 Ok(frame) => {
236 src.advance(pos);
237 self.expected = None;
238 return Ok(Some(frame));
239 }
240 }
241 }
242 }
243}
244
245#[derive(Debug)]
246pub enum Error {
247 Proto(frame::Error),
248 Quic(TransportError),
249 UnexpectedEnd,
250}
251
252impl From<frame::Error> for Error {
253 fn from(err: frame::Error) -> Self {
254 Error::Proto(err)
255 }
256}
257
#[cfg(test)]
mod tests {
    use super::*;

    use assert_matches::assert_matches;
    use bytes::{BufMut, BytesMut};
    use futures_util::future::poll_fn;
    use std::{collections::VecDeque, fmt, sync::Arc};

    use crate::{
        proto::{coding::Encode, frame::FrameType, varint::VarInt},
        quic,
    };

    #[test]
    fn one_frame() {
        let mut buf = BytesMut::with_capacity(16);
        Frame::headers(&b"salut"[..]).encode_with_payload(&mut buf);
        let mut buf = BufList::from(buf);

        let mut decoder = FrameDecoder::default();
        assert_matches!(decoder.decode(&mut buf), Ok(Some(Frame::Headers(_))));
    }

    #[test]
    fn incomplete_frame() {
        let frame = Frame::headers(&b"salut"[..]);

        let mut buf = BytesMut::with_capacity(16);
        frame.encode(&mut buf);
        buf.truncate(buf.len() - 1);
        let mut buf = BufList::from(buf);

        let mut decoder = FrameDecoder::default();
        assert_matches!(decoder.decode(&mut buf), Ok(None));
    }

    #[test]
    fn header_spread_multiple_buf() {
        let mut buf = BytesMut::with_capacity(16);
        Frame::headers(&b"salut"[..]).encode_with_payload(&mut buf);
        let mut buf_list = BufList::new();
        // Split the encoded frame after its first byte.
        buf_list.push(&buf[..1]);
        buf_list.push(&buf[1..]);

        let mut decoder = FrameDecoder::default();
        assert_matches!(decoder.decode(&mut buf_list), Ok(Some(Frame::Headers(_))));
    }

    #[test]
    fn varint_spread_multiple_buf() {
        let mut buf = BytesMut::with_capacity(16);
        Frame::headers("salut".repeat(1024)).encode_with_payload(&mut buf);

        let mut buf_list = BufList::new();
        // Split inside the encoded payload-length varint.
        buf_list.push(&buf[..2]);
        buf_list.push(&buf[2..]);

        let mut decoder = FrameDecoder::default();
        assert_matches!(decoder.decode(&mut buf_list), Ok(Some(Frame::Headers(_))));
    }

    #[test]
    fn two_frames_then_incomplete() {
        let mut buf = BytesMut::with_capacity(64);
        Frame::headers(&b"header"[..]).encode_with_payload(&mut buf);
        Frame::Data(&b"body"[..]).encode_with_payload(&mut buf);
        Frame::headers(&b"trailer"[..]).encode_with_payload(&mut buf);

        buf.truncate(buf.len() - 1);
        let mut buf = BufList::from(buf);

        let mut decoder = FrameDecoder::default();
        assert_matches!(decoder.decode(&mut buf), Ok(Some(Frame::Headers(_))));
        assert_matches!(
            decoder.decode(&mut buf),
            Ok(Some(Frame::Data(PayloadLen(4))))
        );
        assert_matches!(decoder.decode(&mut buf), Ok(None));
    }

    /// Drives a poll closure to completion with `poll_fn` and asserts on the result.
    macro_rules! assert_poll_matches {
        ($poll_fn:expr, $match:pat) => {
            assert_matches!(
                poll_fn($poll_fn).await,
                $match
            );
        };
        ($poll_fn:expr, $match:pat if $cond:expr ) => {
            assert_matches!(
                poll_fn($poll_fn).await,
                $match if $cond
            );
        }
    }

    #[tokio::test]
    async fn poll_full_request() {
        let mut recv = FakeRecv::default();
        let mut buf = BytesMut::with_capacity(64);

        Frame::headers(&b"header"[..]).encode_with_payload(&mut buf);
        Frame::Data(&b"body"[..]).encode_with_payload(&mut buf);
        Frame::headers(&b"trailer"[..]).encode_with_payload(&mut buf);
        recv.chunk(buf.freeze());

        let mut stream: FrameStream<_, ()> = FrameStream::new(recv);

        assert_poll_matches!(
            |mut cx| stream.poll_next(&mut cx),
            Ok(Some(Frame::Headers(_)))
        );
        assert_poll_matches!(
            |mut cx| stream.poll_next(&mut cx),
            Ok(Some(Frame::Data(PayloadLen(4))))
        );
        assert_poll_matches!(
            |mut cx| to_bytes(stream.poll_data(&mut cx)),
            Ok(Some(b)) if b.remaining() == 4
        );
        assert_poll_matches!(
            |mut cx| stream.poll_next(&mut cx),
            Ok(Some(Frame::Headers(_)))
        );
    }

    #[tokio::test]
    async fn poll_next_incomplete_frame() {
        let mut recv = FakeRecv::default();
        let mut buf = BytesMut::with_capacity(64);

        Frame::headers(&b"header"[..]).encode_with_payload(&mut buf);
        let mut buf = buf.freeze();
        recv.chunk(buf.split_to(buf.len() - 1));
        let mut stream: FrameStream<_, ()> = FrameStream::new(recv);

        assert_poll_matches!(
            |mut cx| stream.poll_next(&mut cx),
            Err(Error::UnexpectedEnd)
        );
    }

    #[tokio::test]
    #[should_panic(
        expected = "There is still data to read, please call poll_data() until it returns None"
    )]
    async fn poll_next_remaining_data() {
        let mut recv = FakeRecv::default();
        let mut buf = BytesMut::with_capacity(64);

        FrameType::DATA.encode(&mut buf);
        VarInt::from(4u32).encode(&mut buf);
        recv.chunk(buf.freeze());
        let mut stream: FrameStream<_, ()> = FrameStream::new(recv);

        assert_poll_matches!(
            |mut cx| stream.poll_next(&mut cx),
            Ok(Some(Frame::Data(PayloadLen(4))))
        );

        let _ = poll_fn(|mut cx| stream.poll_next(&mut cx)).await;
    }

    #[tokio::test]
    async fn poll_data_split() {
        let mut recv = FakeRecv::default();
        let mut buf = BytesMut::with_capacity(64);

        Frame::Data(Bytes::from("body")).encode_with_payload(&mut buf);

        // The last two payload bytes arrive in a second chunk.
        let mut buf = buf.freeze();
        recv.chunk(buf.split_to(buf.len() - 2));
        recv.chunk(buf);
        let mut stream: FrameStream<_, ()> = FrameStream::new(recv);

        assert_poll_matches!(
            |mut cx| stream.poll_next(&mut cx),
            Ok(Some(Frame::Data(PayloadLen(4))))
        );

        assert_poll_matches!(
            |mut cx| to_bytes(stream.poll_data(&mut cx)),
            Ok(Some(b)) if b.remaining() == 2
        );
        assert_poll_matches!(
            |mut cx| to_bytes(stream.poll_data(&mut cx)),
            Ok(Some(b)) if b.remaining() == 2
        );
    }

    #[tokio::test]
    async fn poll_data_unexpected_end() {
        let mut recv = FakeRecv::default();
        let mut buf = BytesMut::with_capacity(64);

        // DATA frame announcing 4 payload bytes, but only 1 is ever sent.
        FrameType::DATA.encode(&mut buf);
        VarInt::from(4u32).encode(&mut buf);
        buf.put_slice(&b"b"[..]);
        recv.chunk(buf.freeze());
        let mut stream: FrameStream<_, ()> = FrameStream::new(recv);

        assert_poll_matches!(
            |mut cx| stream.poll_next(&mut cx),
            Ok(Some(Frame::Data(PayloadLen(4))))
        );
        assert_poll_matches!(
            |mut cx| to_bytes(stream.poll_data(&mut cx)),
            Err(Error::UnexpectedEnd)
        );
    }

    #[tokio::test]
    async fn poll_data_ignores_unknown_frames() {
        use crate::proto::varint::BufMutExt as _;

        let mut recv = FakeRecv::default();
        let mut buf = BytesMut::with_capacity(64);

        // Empty grease frame.
        FrameType::grease().encode(&mut buf);
        buf.write_var(0);

        // Grease frame with a 6-byte payload.
        FrameType::grease().encode(&mut buf);
        buf.write_var(6);
        buf.put_slice(b"grease");

        Frame::Data(Bytes::from("body")).encode_with_payload(&mut buf);

        recv.chunk(buf.freeze());
        let mut stream: FrameStream<_, ()> = FrameStream::new(recv);

        assert_poll_matches!(
            |mut cx| stream.poll_next(&mut cx),
            Ok(Some(Frame::Data(PayloadLen(4))))
        );
        assert_poll_matches!(
            |mut cx| to_bytes(stream.poll_data(&mut cx)),
            Ok(Some(b)) if &*b == b"body"
        );
    }

    #[tokio::test]
    async fn poll_data_eos_but_buffered_data() {
        let mut recv = FakeRecv::default();
        let mut buf = BytesMut::with_capacity(64);

        FrameType::DATA.encode(&mut buf);
        VarInt::from(4u32).encode(&mut buf);
        buf.put_slice(&b"bo"[..]);
        recv.chunk(buf.clone().freeze());

        let mut stream: FrameStream<_, ()> = FrameStream::new(recv);

        assert_poll_matches!(
            |mut cx| stream.poll_next(&mut cx),
            Ok(Some(Frame::Data(PayloadLen(4))))
        );

        // Inject the rest of the payload directly into the buffer, behind "bo".
        buf.truncate(0);
        buf.put_slice(&b"dy"[..]);
        stream.bufs.push_bytes(&mut buf.freeze());

        assert_poll_matches!(
            |mut cx| to_bytes(stream.poll_data(&mut cx)),
            Ok(Some(b)) if &*b == b"bo"
        );

        assert_poll_matches!(
            |mut cx| to_bytes(stream.poll_data(&mut cx)),
            Ok(Some(b)) if &*b == b"dy"
        );
    }

    /// Receive stream that yields queued chunks, then reports end-of-stream.
    #[derive(Default)]
    struct FakeRecv {
        chunks: VecDeque<Bytes>,
    }

    impl FakeRecv {
        fn chunk(&mut self, buf: Bytes) -> &mut Self {
            self.chunks.push_back(buf);
            self
        }
    }

    impl RecvStream for FakeRecv {
        type Buf = Bytes;
        type Error = FakeError;

        fn poll_data(
            &mut self,
            _: &mut Context<'_>,
        ) -> Poll<Result<Option<Self::Buf>, Self::Error>> {
            Poll::Ready(Ok(self.chunks.pop_front()))
        }

        fn stop_sending(&mut self, _: u64) {
            unimplemented!()
        }
    }

    #[derive(Debug)]
    struct FakeError;

    impl quic::Error for FakeError {
        fn is_timeout(&self) -> bool {
            unimplemented!()
        }

        fn err_code(&self) -> Option<u64> {
            unimplemented!()
        }
    }

    impl std::error::Error for FakeError {}
    impl fmt::Display for FakeError {
        fn fmt(&self, _: &mut fmt::Formatter<'_>) -> fmt::Result {
            unimplemented!()
        }
    }

    // `From` rather than `Into`: the blanket impl still provides the `Into`.
    impl From<FakeError> for Arc<dyn quic::Error> {
        fn from(_: FakeError) -> Self {
            unimplemented!()
        }
    }

    /// Flattens a polled `impl Buf` into `Bytes` so tests can compare contents.
    fn to_bytes(x: Poll<Result<Option<impl Buf>, Error>>) -> Poll<Result<Option<Bytes>, Error>> {
        x.map(|b| b.map(|b| b.map(|mut b| b.copy_to_bytes(b.remaining()))))
    }
}