use std::task::{Context, Poll};
use bytes::{Buf, BufMut as _, Bytes};
use futures_util::{future, ready};
use quic::RecvStream;
use crate::{
buf::BufList,
error::{Code, ErrorLevel},
frame::FrameStream,
proto::{
coding::{BufExt, Decode as _, Encode},
frame::Frame,
stream::StreamType,
varint::VarInt,
},
quic::{self, SendStream},
Error,
};
/// Queues `data` on `stream` with `send_data`, then awaits `poll_ready` before returning.
///
/// `data` can be anything convertible into a [`WriteBuf`]: a stream type, a frame, or a
/// `(StreamType, Frame)` pair.
///
/// # Errors
///
/// Returns the first error produced by either `send_data` or `poll_ready`.
#[inline]
pub(crate) async fn write<S, D, B>(stream: &mut S, data: D) -> Result<(), Error>
where
    S: SendStream<B>,
    D: Into<WriteBuf<B>>,
    B: Buf,
{
    stream.send_data(data)?;
    let stream_ready = future::poll_fn(|cx| stream.poll_ready(cx));
    stream_ready.await?;
    Ok(())
}
/// Capacity of the inline staging area: room for a stream-type encoding followed by a
/// frame encoding, each at its maximum encoded size.
const WRITE_BUF_ENCODE_SIZE: usize = StreamType::MAX_ENCODED_SIZE + Frame::MAX_ENCODED_SIZE;

/// A [`Buf`] that first yields bytes staged inline (an optional stream type and/or a frame
/// encoding), then the payload of the held frame, if it has one.
///
/// For frames without a separate payload the whole frame is staged inline; for payload
/// frames only the header is staged and the payload is read from the frame itself.
pub struct WriteBuf<B: Buf> {
    /// Inline staging area for the encoded stream type and/or frame header.
    buf: [u8; WRITE_BUF_ENCODE_SIZE],
    /// Count of valid staged bytes at the front of `buf`.
    len: usize,
    /// Count of staged bytes already consumed through `advance`.
    pos: usize,
    /// The frame whose encoding was staged; its payload (if any) is yielded after `buf`.
    frame: Option<Frame<B>>,
}
impl<B: Buf> WriteBuf<B> {
    /// Encodes `ty` into `buf` directly after the bytes already staged and grows `len`
    /// by the number of bytes written.
    fn encode_stream_type(&mut self, ty: StreamType) {
        let mut tail = &mut self.buf[self.len..];
        let free_before = tail.len();
        ty.encode(&mut tail);
        self.len += free_before - tail.len();
    }

    /// Encodes the held frame (if any) into `buf` directly after the bytes already staged
    /// and grows `len` by the number of bytes written. Does nothing when no frame is held.
    fn encode_frame_header(&mut self) {
        let frame = match self.frame.as_ref() {
            Some(frame) => frame,
            None => return,
        };
        let mut tail = &mut self.buf[self.len..];
        let free_before = tail.len();
        frame.encode(&mut tail);
        self.len += free_before - tail.len();
    }
}
impl<B> From<StreamType> for WriteBuf<B>
where
B: Buf,
{
fn from(ty: StreamType) -> Self {
let mut me = Self {
buf: [0; WRITE_BUF_ENCODE_SIZE],
len: 0,
pos: 0,
frame: None,
};
me.encode_stream_type(ty);
me
}
}
impl<B> From<Frame<B>> for WriteBuf<B>
where
B: Buf,
{
fn from(frame: Frame<B>) -> Self {
let mut me = Self {
buf: [0; WRITE_BUF_ENCODE_SIZE],
len: 0,
pos: 0,
frame: Some(frame),
};
me.encode_frame_header();
me
}
}
impl<B: Buf> From<(StreamType, Frame<B>)> for WriteBuf<B> {
    /// Stages the stream type followed by the encoding of `frame`, in that order.
    fn from((ty, frame): (StreamType, Frame<B>)) -> Self {
        // Begin with the type-only buffer, then append the frame encoding after it.
        let mut staged = <Self as From<StreamType>>::from(ty);
        staged.frame = Some(frame);
        staged.encode_frame_header();
        staged
    }
}
impl<B: Buf> Buf for WriteBuf<B> {
    /// Unconsumed staged bytes plus whatever is left of the frame payload.
    fn remaining(&self) -> usize {
        let header_left = self.len - self.pos;
        let payload_left = match self.frame.as_ref().and_then(|f| f.payload()) {
            Some(payload) => payload.remaining(),
            None => 0,
        };
        header_left + payload_left
    }

    /// The unconsumed staged bytes while any remain, otherwise the payload's current chunk
    /// (or an empty slice when there is no payload).
    fn chunk(&self) -> &[u8] {
        if self.pos < self.len {
            return &self.buf[self.pos..self.len];
        }
        match self.frame.as_ref().and_then(|f| f.payload()) {
            Some(payload) => payload.chunk(),
            None => &[],
        }
    }

    /// Consumes staged bytes first; any excess is forwarded to the payload, when present.
    fn advance(&mut self, cnt: usize) {
        let header_step = usize::min(cnt, self.len - self.pos);
        self.pos += header_step;
        let payload_step = cnt - header_step;
        if let Some(payload) = self.frame.as_mut().and_then(|f| f.payload_mut()) {
            payload.advance(payload_step);
        }
    }
}
/// An incoming unidirectional stream, classified by its decoded stream type.
pub(super) enum AcceptedRecvStream<S: quic::RecvStream, B> {
    /// Control stream, wrapped in a `FrameStream`.
    Control(FrameStream<S, B>),
    /// Push stream: the decoded push ID and the stream wrapped in a `FrameStream`.
    Push(u64, FrameStream<S, B>),
    /// Stream typed `StreamType::ENCODER`, handed over as the raw transport stream.
    Encoder(S),
    /// Stream typed `StreamType::DECODER`, handed over as the raw transport stream.
    Decoder(S),
    /// A stream type matching the reserved `0x1f * N + 0x21` pattern; no stream is kept.
    Reserved,
}
/// Incrementally decodes the identifying header of an incoming unidirectional stream.
pub(super) struct AcceptRecvStream<S: quic::RecvStream> {
    /// The underlying transport receive stream.
    stream: S,
    /// The decoded stream type, once known.
    ty: Option<StreamType>,
    /// The decoded push ID (relevant for push streams only).
    push_id: Option<u64>,
    /// Bytes received from `stream` that header decoding has not consumed.
    buf: BufList<Bytes>,
    /// Encoded length of the varint being decoded, derived from its first byte.
    expected: Option<usize>,
}
impl<S> AcceptRecvStream<S>
where
    S: RecvStream,
{
    /// Wraps a freshly accepted receive stream whose header has not been read yet.
    pub fn new(stream: S) -> Self {
        Self {
            stream,
            ty: None,
            push_id: None,
            buf: BufList::new(),
            expected: None,
        }
    }

    /// Converts into an [`AcceptedRecvStream`] according to the decoded stream type.
    ///
    /// For control and push streams, bytes already buffered past the header are handed
    /// on to the resulting `FrameStream`.
    ///
    /// # Errors
    ///
    /// Returns a connection-level `H3_STREAM_CREATION_ERROR` for a stream type that is
    /// neither one of the known types nor reserved.
    ///
    /// # Panics
    ///
    /// Panics if called before [`poll_type`](Self::poll_type) has returned `Ready(Ok(()))`.
    pub fn into_stream<B>(self) -> Result<AcceptedRecvStream<S, B>, Error> {
        Ok(match self.ty.expect("Stream type not resolved yet") {
            StreamType::CONTROL => {
                AcceptedRecvStream::Control(FrameStream::with_bufs(self.stream, self.buf))
            }
            StreamType::PUSH => AcceptedRecvStream::Push(
                self.push_id.expect("Push ID not resolved yet"),
                FrameStream::with_bufs(self.stream, self.buf),
            ),
            StreamType::ENCODER => AcceptedRecvStream::Encoder(self.stream),
            StreamType::DECODER => AcceptedRecvStream::Decoder(self.stream),
            // Reserved types are `0x1f * N + 0x21` for every N >= 0, so 0x21 itself is
            // reserved too (the previous `> 0x21` check missed it).
            t if t.value() >= 0x21 && (t.value() - 0x21) % 0x1f == 0 => {
                AcceptedRecvStream::Reserved
            }
            t => {
                return Err(Code::H3_STREAM_CREATION_ERROR.with_reason(
                    format!("unknown stream type 0x{:x}", t.value()),
                    ErrorLevel::ConnectionError,
                ))
            }
        })
    }

    /// Reads from the stream until its identifying header has been decoded.
    ///
    /// The header is a stream-type varint, followed (for push streams only) by a push-ID
    /// varint. Any bytes received past the header stay in `self.buf`.
    ///
    /// Returns `Poll::Ready(Ok(()))` once every required varint is decoded, and
    /// `Poll::Pending` while the transport has no more data yet.
    ///
    /// # Errors
    ///
    /// - `H3_STREAM_CREATION_ERROR` if the stream ends before the header is complete.
    /// - `H3_INTERNAL_ERROR` if a varint cannot be decoded from the buffered bytes.
    /// - Any error reported by the underlying `poll_data`.
    pub fn poll_type(&mut self, cx: &mut Context) -> Poll<Result<(), Error>> {
        loop {
            // Done once the type is known and, for push streams, the push ID as well.
            // (The previous `(Some(_), _)` arm also matched a PUSH stream without a push ID,
            // so the ID was never decoded and `into_stream` would panic.)
            match (self.ty.as_ref(), self.push_id) {
                (Some(&StreamType::PUSH), None) | (None, _) => (),
                (Some(_), _) => return Poll::Ready(Ok(())),
            }

            // The first byte of a varint tells how many bytes the whole varint occupies.
            if self.expected.is_none() && self.buf.remaining() >= 1 {
                self.expected = Some(VarInt::encoded_size(self.buf.chunk()[0]));
            }

            // Read from the transport only when the buffer cannot yet hold the next varint:
            // a single chunk may already carry both the stream type and the push ID, and
            // polling again in that case could wait for data that never comes.
            let buffered_enough =
                matches!(self.expected, Some(expected) if self.buf.remaining() >= expected);
            if !buffered_enough {
                match ready!(self.stream.poll_data(cx))? {
                    Some(mut b) => self.buf.push_bytes(&mut b),
                    None => {
                        return Poll::Ready(Err(Code::H3_STREAM_CREATION_ERROR.with_reason(
                            "Stream closed before type received",
                            ErrorLevel::ConnectionError,
                        )));
                    }
                };
                continue;
            }

            if self.ty.is_none() {
                self.ty = Some(StreamType::decode(&mut self.buf).map_err(|_| {
                    Code::H3_INTERNAL_ERROR.with_reason(
                        "Unexpected end parsing stream type",
                        ErrorLevel::ConnectionError,
                    )
                })?);
                // A push ID, if one is required, starts a fresh varint.
                self.expected = None;
            } else {
                self.push_id = Some(self.buf.get_var().map_err(|_| {
                    Code::H3_INTERNAL_ERROR.with_reason(
                        "Unexpected end parsing push ID",
                        ErrorLevel::ConnectionError,
                    )
                })?);
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::proto::stream::StreamId;

    /// Fixture shared by the advance tests: ENCODER stream type plus a DATA frame "hey".
    fn encoder_then_data_hey() -> WriteBuf<Bytes> {
        WriteBuf::<Bytes>::from((StreamType::ENCODER, Frame::Data(Bytes::from("hey"))))
    }

    #[test]
    fn write_buf_encode_streamtype() {
        let staged = WriteBuf::<Bytes>::from(StreamType::ENCODER);
        assert_eq!(staged.chunk(), b"\x02");
        assert_eq!(staged.len, 1);
    }

    #[test]
    fn write_buf_encode_frame() {
        let staged = WriteBuf::<Bytes>::from(Frame::Goaway(StreamId(2)));
        assert_eq!(staged.chunk(), b"\x07\x01\x02");
        assert_eq!(staged.len, 3);
    }

    #[test]
    fn write_buf_encode_streamtype_then_frame() {
        let staged = WriteBuf::<Bytes>::from((StreamType::ENCODER, Frame::Goaway(StreamId(2))));
        assert_eq!(staged.chunk(), b"\x02\x07\x01\x02");
    }

    #[test]
    fn write_buf_advances() {
        let mut staged = encoder_then_data_hey();
        // Staged header: stream type 0x02, frame type 0x00, payload length 3.
        assert_eq!(staged.chunk(), b"\x02\x00\x03");
        staged.advance(3);
        assert_eq!(staged.remaining(), 3);
        assert_eq!(staged.chunk(), b"hey");
        staged.advance(2);
        assert_eq!(staged.chunk(), b"y");
        staged.advance(1);
        assert_eq!(staged.remaining(), 0);
    }

    #[test]
    fn write_buf_advance_jumps_header_and_payload_start() {
        let mut staged = encoder_then_data_hey();
        // One call crosses all 3 staged header bytes plus the first payload byte.
        staged.advance(4);
        assert_eq!(staged.chunk(), b"ey");
    }
}