// httproxide_h3/stream.rs

1use std::task::{Context, Poll};
2
3use bytes::{Buf, BufMut as _, Bytes};
4use futures_util::{future, ready};
5use quic::RecvStream;
6
7use crate::{
8    buf::BufList,
9    error::Code,
10    frame::FrameStream,
11    proto::{
12        coding::{BufExt, Decode as _, Encode},
13        frame::Frame,
14        stream::StreamType,
15        varint::VarInt,
16    },
17    quic::{self, SendStream},
18    Error,
19};
20
21#[inline]
22pub(crate) async fn write<S, D, B>(stream: &mut S, data: D) -> Result<(), Error>
23where
24    S: SendStream<B>,
25    D: Into<WriteBuf<B>>,
26    B: Buf,
27{
28    stream.send_data(data)?;
29    future::poll_fn(|cx| stream.poll_ready(cx)).await?;
30
31    Ok(())
32}
33
34const WRITE_BUF_ENCODE_SIZE: usize = StreamType::MAX_ENCODED_SIZE + Frame::MAX_ENCODED_SIZE;
35
36/// Wrap frames to encode their header on the stack before sending them on the wire
37///
38/// Implements `Buf` so wire data is seamlessly available for transport layer transmits:
39/// `Buf::chunk()` will yield the encoded header, then the payload. For unidirectional streams,
40/// this type makes it possible to prefix wire data with the `StreamType`.
41///
42/// Conveying frames as `Into<WriteBuf>` makes it possible to encode only when generating wire-format
43/// data is necessary (say, in `quic::SendStream::send_data`). It also has a public API ergonomy
44/// advantage: `WriteBuf` doesn't have to appear in public associated types. On the other hand,
45/// QUIC implementers have to call `into()`, which will encode the header in `Self::buf`.
46pub struct WriteBuf<B>
47where
48    B: Buf,
49{
50    buf: [u8; WRITE_BUF_ENCODE_SIZE],
51    len: usize,
52    pos: usize,
53    frame: Option<Frame<B>>,
54}
55
56impl<B> WriteBuf<B>
57where
58    B: Buf,
59{
60    fn encode_stream_type(&mut self, ty: StreamType) {
61        let mut buf_mut = &mut self.buf[self.len..];
62        ty.encode(&mut buf_mut);
63        self.len = WRITE_BUF_ENCODE_SIZE - buf_mut.remaining_mut();
64    }
65
66    fn encode_frame_header(&mut self) {
67        if let Some(frame) = self.frame.as_ref() {
68            let mut buf_mut = &mut self.buf[self.len..];
69            frame.encode(&mut buf_mut);
70            self.len = WRITE_BUF_ENCODE_SIZE - buf_mut.remaining_mut();
71        }
72    }
73}
74
75impl<B> From<StreamType> for WriteBuf<B>
76where
77    B: Buf,
78{
79    fn from(ty: StreamType) -> Self {
80        let mut me = Self {
81            buf: [0; WRITE_BUF_ENCODE_SIZE],
82            len: 0,
83            pos: 0,
84            frame: None,
85        };
86        me.encode_stream_type(ty);
87        me
88    }
89}
90
91impl<B> From<Frame<B>> for WriteBuf<B>
92where
93    B: Buf,
94{
95    fn from(frame: Frame<B>) -> Self {
96        let mut me = Self {
97            buf: [0; WRITE_BUF_ENCODE_SIZE],
98            len: 0,
99            pos: 0,
100            frame: Some(frame),
101        };
102        me.encode_frame_header();
103        me
104    }
105}
106
107impl<B> From<(StreamType, Frame<B>)> for WriteBuf<B>
108where
109    B: Buf,
110{
111    fn from(ty_stream: (StreamType, Frame<B>)) -> Self {
112        let (ty, frame) = ty_stream;
113        let mut me = Self {
114            buf: [0; WRITE_BUF_ENCODE_SIZE],
115            len: 0,
116            pos: 0,
117            frame: Some(frame),
118        };
119        me.encode_stream_type(ty);
120        me.encode_frame_header();
121        me
122    }
123}
124
125impl<B> Buf for WriteBuf<B>
126where
127    B: Buf,
128{
129    fn remaining(&self) -> usize {
130        self.len - self.pos
131            + self
132                .frame
133                .as_ref()
134                .and_then(|f| f.payload())
135                .map_or(0, |x| x.remaining())
136    }
137
138    fn chunk(&self) -> &[u8] {
139        if self.len - self.pos > 0 {
140            &self.buf[self.pos..self.len]
141        } else if let Some(payload) = self.frame.as_ref().and_then(|f| f.payload()) {
142            payload.chunk()
143        } else {
144            &[]
145        }
146    }
147
148    fn advance(&mut self, mut cnt: usize) {
149        let remaining_header = self.len - self.pos;
150        if remaining_header > 0 {
151            let advanced = usize::min(cnt, remaining_header);
152            self.pos += advanced;
153            cnt -= advanced;
154        }
155
156        if let Some(payload) = self.frame.as_mut().and_then(|f| f.payload_mut()) {
157            payload.advance(cnt);
158        }
159    }
160}
161
162pub(super) enum AcceptedRecvStream<S, B>
163where
164    S: quic::RecvStream,
165{
166    Control(FrameStream<S, B>),
167    Push(u64, FrameStream<S, B>),
168    Encoder(S),
169    Decoder(S),
170    Reserved,
171}
172
173pub(super) struct AcceptRecvStream<S>
174where
175    S: quic::RecvStream,
176{
177    stream: S,
178    ty: Option<StreamType>,
179    push_id: Option<u64>,
180    buf: BufList<Bytes>,
181    expected: Option<usize>,
182}
183
184impl<S> AcceptRecvStream<S>
185where
186    S: RecvStream,
187{
188    pub fn new(stream: S) -> Self {
189        Self {
190            stream,
191            ty: None,
192            push_id: None,
193            buf: BufList::new(),
194            expected: None,
195        }
196    }
197
198    pub fn into_stream<B>(self) -> Result<AcceptedRecvStream<S, B>, Error> {
199        Ok(match self.ty.expect("Stream type not resolved yet") {
200            StreamType::CONTROL => {
201                AcceptedRecvStream::Control(FrameStream::with_bufs(self.stream, self.buf))
202            }
203            StreamType::PUSH => AcceptedRecvStream::Push(
204                self.push_id.expect("Push ID not resolved yet"),
205                FrameStream::with_bufs(self.stream, self.buf),
206            ),
207            StreamType::ENCODER => AcceptedRecvStream::Encoder(self.stream),
208            StreamType::DECODER => AcceptedRecvStream::Decoder(self.stream),
209            t if t.value() > 0x21 && (t.value() - 0x21) % 0x1f == 0 => AcceptedRecvStream::Reserved,
210            t => {
211                return Err(Code::H3_STREAM_CREATION_ERROR
212                    .with_reason(format!("unknown stream type 0x{:x}", t.value())))
213            }
214        })
215    }
216
217    pub fn poll_type(&mut self, cx: &mut Context) -> Poll<Result<(), Error>> {
218        loop {
219            match (self.ty.as_ref(), self.push_id) {
220                // When accepting a Push stream, we want to parse two VarInts: [StreamType, PUSH_ID]
221                (Some(&StreamType::PUSH), Some(_)) | (Some(_), _) => return Poll::Ready(Ok(())),
222                _ => (),
223            }
224
225            match ready!(self.stream.poll_data(cx))? {
226                Some(mut b) => self.buf.push_bytes(&mut b),
227                None => {
228                    return Poll::Ready(Err(Code::H3_STREAM_CREATION_ERROR
229                        .with_reason("Stream closed before type received")))
230                }
231            };
232
233            if self.expected.is_none() && self.buf.remaining() >= 1 {
234                self.expected = Some(VarInt::encoded_size(self.buf.chunk()[0]));
235            }
236
237            if let Some(expected) = self.expected {
238                if self.buf.remaining() < expected {
239                    continue;
240                }
241            } else {
242                continue;
243            }
244
245            if self.ty.is_none() {
246                // Parse StreamType
247                self.ty = Some(StreamType::decode(&mut self.buf).map_err(|_| {
248                    Code::H3_INTERNAL_ERROR.with_reason("Unexpected end parsing stream type")
249                })?);
250                // Get the next VarInt for PUSH_ID on the next iteration
251                self.expected = None;
252            } else {
253                // Parse PUSH_ID
254                self.push_id = Some(self.buf.get_var().map_err(|_| {
255                    Code::H3_INTERNAL_ERROR.with_reason("Unexpected end parsing stream type")
256                })?);
257            }
258        }
259    }
260}
261
#[cfg(test)]
mod tests {
    use super::*;
    use crate::proto::stream::StreamId;

    #[test]
    fn write_buf_encode_streamtype() {
        let wbuf = WriteBuf::<Bytes>::from(StreamType::ENCODER);

        assert_eq!(wbuf.chunk(), b"\x02");
        assert_eq!(wbuf.len, 1);
        assert_eq!(wbuf.remaining(), 1);
    }

    #[test]
    fn write_buf_encode_frame() {
        let wbuf = WriteBuf::<Bytes>::from(Frame::Goaway(StreamId(2)));

        assert_eq!(wbuf.chunk(), b"\x07\x01\x02");
        assert_eq!(wbuf.len, 3);
        // A frame without a payload contributes only its header bytes.
        assert_eq!(wbuf.remaining(), 3);
    }

    #[test]
    fn write_buf_encode_streamtype_then_frame() {
        let wbuf = WriteBuf::<Bytes>::from((StreamType::ENCODER, Frame::Goaway(StreamId(2))));

        assert_eq!(wbuf.chunk(), b"\x02\x07\x01\x02");
        assert_eq!(wbuf.remaining(), 4);
    }

    #[test]
    fn write_buf_advances() {
        let mut wbuf =
            WriteBuf::<Bytes>::from((StreamType::ENCODER, Frame::Data(Bytes::from("hey"))));

        assert_eq!(wbuf.remaining(), 6);
        assert_eq!(wbuf.chunk(), b"\x02\x00\x03");
        wbuf.advance(3);
        assert_eq!(wbuf.remaining(), 3);
        assert_eq!(wbuf.chunk(), b"hey");
        wbuf.advance(2);
        assert_eq!(wbuf.chunk(), b"y");
        wbuf.advance(1);
        assert_eq!(wbuf.remaining(), 0);
        assert!(wbuf.chunk().is_empty());
    }

    #[test]
    fn write_buf_advance_jumps_header_and_payload_start() {
        let mut wbuf =
            WriteBuf::<Bytes>::from((StreamType::ENCODER, Frame::Data(Bytes::from("hey"))));

        wbuf.advance(4);
        assert_eq!(wbuf.remaining(), 2);
        assert_eq!(wbuf.chunk(), b"ey");
    }

    #[test]
    fn write_buf_advance_within_header() {
        let mut wbuf =
            WriteBuf::<Bytes>::from((StreamType::ENCODER, Frame::Data(Bytes::from("hey"))));

        wbuf.advance(1);
        assert_eq!(wbuf.remaining(), 5);
        assert_eq!(wbuf.chunk(), b"\x00\x03");
    }

    #[test]
    fn write_buf_without_frame_is_empty_after_header() {
        let mut wbuf = WriteBuf::<Bytes>::from(StreamType::ENCODER);

        wbuf.advance(1);
        assert_eq!(wbuf.remaining(), 0);
        assert!(wbuf.chunk().is_empty());
    }
}