1use std::task::{Context, Poll};
2
3use bytes::{Buf, BufMut as _, Bytes};
4use futures_util::{future, ready};
5use quic::RecvStream;
6
7use crate::{
8 buf::BufList,
9 error::Code,
10 frame::FrameStream,
11 proto::{
12 coding::{BufExt, Decode as _, Encode},
13 frame::Frame,
14 stream::StreamType,
15 varint::VarInt,
16 },
17 quic::{self, SendStream},
18 Error,
19};
20
21#[inline]
22pub(crate) async fn write<S, D, B>(stream: &mut S, data: D) -> Result<(), Error>
23where
24 S: SendStream<B>,
25 D: Into<WriteBuf<B>>,
26 B: Buf,
27{
28 stream.send_data(data)?;
29 future::poll_fn(|cx| stream.poll_ready(cx)).await?;
30
31 Ok(())
32}
33
34const WRITE_BUF_ENCODE_SIZE: usize = StreamType::MAX_ENCODED_SIZE + Frame::MAX_ENCODED_SIZE;
35
36pub struct WriteBuf<B>
47where
48 B: Buf,
49{
50 buf: [u8; WRITE_BUF_ENCODE_SIZE],
51 len: usize,
52 pos: usize,
53 frame: Option<Frame<B>>,
54}
55
56impl<B> WriteBuf<B>
57where
58 B: Buf,
59{
60 fn encode_stream_type(&mut self, ty: StreamType) {
61 let mut buf_mut = &mut self.buf[self.len..];
62 ty.encode(&mut buf_mut);
63 self.len = WRITE_BUF_ENCODE_SIZE - buf_mut.remaining_mut();
64 }
65
66 fn encode_frame_header(&mut self) {
67 if let Some(frame) = self.frame.as_ref() {
68 let mut buf_mut = &mut self.buf[self.len..];
69 frame.encode(&mut buf_mut);
70 self.len = WRITE_BUF_ENCODE_SIZE - buf_mut.remaining_mut();
71 }
72 }
73}
74
75impl<B> From<StreamType> for WriteBuf<B>
76where
77 B: Buf,
78{
79 fn from(ty: StreamType) -> Self {
80 let mut me = Self {
81 buf: [0; WRITE_BUF_ENCODE_SIZE],
82 len: 0,
83 pos: 0,
84 frame: None,
85 };
86 me.encode_stream_type(ty);
87 me
88 }
89}
90
91impl<B> From<Frame<B>> for WriteBuf<B>
92where
93 B: Buf,
94{
95 fn from(frame: Frame<B>) -> Self {
96 let mut me = Self {
97 buf: [0; WRITE_BUF_ENCODE_SIZE],
98 len: 0,
99 pos: 0,
100 frame: Some(frame),
101 };
102 me.encode_frame_header();
103 me
104 }
105}
106
107impl<B> From<(StreamType, Frame<B>)> for WriteBuf<B>
108where
109 B: Buf,
110{
111 fn from(ty_stream: (StreamType, Frame<B>)) -> Self {
112 let (ty, frame) = ty_stream;
113 let mut me = Self {
114 buf: [0; WRITE_BUF_ENCODE_SIZE],
115 len: 0,
116 pos: 0,
117 frame: Some(frame),
118 };
119 me.encode_stream_type(ty);
120 me.encode_frame_header();
121 me
122 }
123}
124
125impl<B> Buf for WriteBuf<B>
126where
127 B: Buf,
128{
129 fn remaining(&self) -> usize {
130 self.len - self.pos
131 + self
132 .frame
133 .as_ref()
134 .and_then(|f| f.payload())
135 .map_or(0, |x| x.remaining())
136 }
137
138 fn chunk(&self) -> &[u8] {
139 if self.len - self.pos > 0 {
140 &self.buf[self.pos..self.len]
141 } else if let Some(payload) = self.frame.as_ref().and_then(|f| f.payload()) {
142 payload.chunk()
143 } else {
144 &[]
145 }
146 }
147
148 fn advance(&mut self, mut cnt: usize) {
149 let remaining_header = self.len - self.pos;
150 if remaining_header > 0 {
151 let advanced = usize::min(cnt, remaining_header);
152 self.pos += advanced;
153 cnt -= advanced;
154 }
155
156 if let Some(payload) = self.frame.as_mut().and_then(|f| f.payload_mut()) {
157 payload.advance(cnt);
158 }
159 }
160}
161
162pub(super) enum AcceptedRecvStream<S, B>
163where
164 S: quic::RecvStream,
165{
166 Control(FrameStream<S, B>),
167 Push(u64, FrameStream<S, B>),
168 Encoder(S),
169 Decoder(S),
170 Reserved,
171}
172
173pub(super) struct AcceptRecvStream<S>
174where
175 S: quic::RecvStream,
176{
177 stream: S,
178 ty: Option<StreamType>,
179 push_id: Option<u64>,
180 buf: BufList<Bytes>,
181 expected: Option<usize>,
182}
183
184impl<S> AcceptRecvStream<S>
185where
186 S: RecvStream,
187{
188 pub fn new(stream: S) -> Self {
189 Self {
190 stream,
191 ty: None,
192 push_id: None,
193 buf: BufList::new(),
194 expected: None,
195 }
196 }
197
198 pub fn into_stream<B>(self) -> Result<AcceptedRecvStream<S, B>, Error> {
199 Ok(match self.ty.expect("Stream type not resolved yet") {
200 StreamType::CONTROL => {
201 AcceptedRecvStream::Control(FrameStream::with_bufs(self.stream, self.buf))
202 }
203 StreamType::PUSH => AcceptedRecvStream::Push(
204 self.push_id.expect("Push ID not resolved yet"),
205 FrameStream::with_bufs(self.stream, self.buf),
206 ),
207 StreamType::ENCODER => AcceptedRecvStream::Encoder(self.stream),
208 StreamType::DECODER => AcceptedRecvStream::Decoder(self.stream),
209 t if t.value() > 0x21 && (t.value() - 0x21) % 0x1f == 0 => AcceptedRecvStream::Reserved,
210 t => {
211 return Err(Code::H3_STREAM_CREATION_ERROR
212 .with_reason(format!("unknown stream type 0x{:x}", t.value())))
213 }
214 })
215 }
216
217 pub fn poll_type(&mut self, cx: &mut Context) -> Poll<Result<(), Error>> {
218 loop {
219 match (self.ty.as_ref(), self.push_id) {
220 (Some(&StreamType::PUSH), Some(_)) | (Some(_), _) => return Poll::Ready(Ok(())),
222 _ => (),
223 }
224
225 match ready!(self.stream.poll_data(cx))? {
226 Some(mut b) => self.buf.push_bytes(&mut b),
227 None => {
228 return Poll::Ready(Err(Code::H3_STREAM_CREATION_ERROR
229 .with_reason("Stream closed before type received")))
230 }
231 };
232
233 if self.expected.is_none() && self.buf.remaining() >= 1 {
234 self.expected = Some(VarInt::encoded_size(self.buf.chunk()[0]));
235 }
236
237 if let Some(expected) = self.expected {
238 if self.buf.remaining() < expected {
239 continue;
240 }
241 } else {
242 continue;
243 }
244
245 if self.ty.is_none() {
246 self.ty = Some(StreamType::decode(&mut self.buf).map_err(|_| {
248 Code::H3_INTERNAL_ERROR.with_reason("Unexpected end parsing stream type")
249 })?);
250 self.expected = None;
252 } else {
253 self.push_id = Some(self.buf.get_var().map_err(|_| {
255 Code::H3_INTERNAL_ERROR.with_reason("Unexpected end parsing stream type")
256 })?);
257 }
258 }
259 }
260}
261
#[cfg(test)]
mod tests {
    use super::*;
    use crate::proto::stream::StreamId;

    /// A lone stream type is encoded into the inline header bytes.
    #[test]
    fn write_buf_encode_streamtype() {
        let buf = WriteBuf::<Bytes>::from(StreamType::ENCODER);

        assert_eq!(buf.len, 1);
        assert_eq!(buf.chunk(), b"\x02");
    }

    /// A payload-less frame is fully encoded into the inline header bytes.
    #[test]
    fn write_buf_encode_frame() {
        let buf = WriteBuf::<Bytes>::from(Frame::Goaway(StreamId(2)));

        assert_eq!(buf.len, 3);
        assert_eq!(buf.chunk(), b"\x07\x01\x02");
    }

    /// The stream type comes first, then the frame encoding.
    #[test]
    fn write_buf_encode_streamtype_then_frame() {
        let buf = WriteBuf::<Bytes>::from((StreamType::ENCODER, Frame::Goaway(StreamId(2))));

        assert_eq!(buf.chunk(), b"\x02\x07\x01\x02");
    }

    /// Advancing walks through the header bytes and then the payload.
    #[test]
    fn write_buf_advances() {
        let mut buf =
            WriteBuf::<Bytes>::from((StreamType::ENCODER, Frame::Data(Bytes::from("hey"))));

        // Stream type 0x02, then the DATA frame header `00 03`.
        assert_eq!(buf.chunk(), b"\x02\x00\x03");

        buf.advance(3);
        assert_eq!(buf.remaining(), 3);
        assert_eq!(buf.chunk(), b"hey");

        buf.advance(2);
        assert_eq!(buf.chunk(), b"y");

        buf.advance(1);
        assert_eq!(buf.remaining(), 0);
    }

    /// A single advance may cross from the header into the payload.
    #[test]
    fn write_buf_advance_jumps_header_and_payload_start() {
        let mut buf =
            WriteBuf::<Bytes>::from((StreamType::ENCODER, Frame::Data(Bytes::from("hey"))));

        // 3 header bytes + 1 payload byte.
        buf.advance(4);

        assert_eq!(buf.chunk(), b"ey");
    }
}