use bytes::{Buf, BufMut, Bytes};
use std::{convert::TryInto, fmt};
use tracing::trace;
use super::{
coding::Encode,
stream::{InvalidStreamId, StreamId},
varint::{BufExt, BufMutExt, UnexpectedEnd, VarInt},
};
#[derive(Debug, PartialEq)]
pub enum FrameError {
Malformed,
UnsupportedFrame(u64), UnknownFrame(u64), InvalidFrameValue,
Incomplete(usize),
Settings(SettingsError),
InvalidStreamId(InvalidStreamId),
}
impl std::error::Error for FrameError {}
impl fmt::Display for FrameError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
FrameError::Malformed => write!(f, "frame is malformed"),
FrameError::UnsupportedFrame(c) => write!(f, "frame 0x{:x} is not allowed h3", c),
FrameError::UnknownFrame(c) => write!(f, "frame 0x{:x} ignored", c),
FrameError::InvalidFrameValue => write!(f, "frame value is invalid"),
FrameError::Incomplete(x) => write!(f, "internal error: frame incomplete {}", x),
FrameError::Settings(x) => write!(f, "invalid settings: {}", x),
FrameError::InvalidStreamId(x) => write!(f, "invalid stream id: {}", x),
}
}
}
pub enum Frame<B> {
Data(B),
Headers(Bytes),
CancelPush(StreamId),
Settings(Settings),
PushPromise(PushPromise),
Goaway(StreamId),
MaxPushId(StreamId),
Grease,
}
pub struct PayloadLen(pub usize);
impl From<usize> for PayloadLen {
fn from(len: usize) -> Self {
PayloadLen(len)
}
}
impl Frame<PayloadLen> {
pub const MAX_ENCODED_SIZE: usize = VarInt::MAX_SIZE * 3;
pub fn decode<T: Buf>(buf: &mut T) -> Result<Self, FrameError> {
let remaining = buf.remaining();
let ty = FrameType::decode(buf).map_err(|_| FrameError::Incomplete(remaining + 1))?;
let len = buf
.get_var()
.map_err(|_| FrameError::Incomplete(remaining + 1))?;
if ty == FrameType::DATA {
return Ok(Frame::Data((len as usize).into()));
}
if buf.remaining() < len as usize {
return Err(FrameError::Incomplete(2 + len as usize));
}
let mut payload = buf.take(len as usize);
let frame = match ty {
FrameType::HEADERS => Ok(Frame::Headers(payload.copy_to_bytes(len as usize))),
FrameType::SETTINGS => Ok(Frame::Settings(Settings::decode(&mut payload)?)),
FrameType::CANCEL_PUSH => Ok(Frame::CancelPush(payload.get_var()?.try_into()?)),
FrameType::PUSH_PROMISE => Ok(Frame::PushPromise(PushPromise::decode(&mut payload)?)),
FrameType::GOAWAY => Ok(Frame::Goaway(payload.get_var()?.try_into()?)),
FrameType::MAX_PUSH_ID => Ok(Frame::MaxPushId(payload.get_var()?.try_into()?)),
FrameType::H2_PRIORITY
| FrameType::H2_PING
| FrameType::H2_WINDOW_UPDATE
| FrameType::H2_CONTINUATION => Err(FrameError::UnsupportedFrame(ty.0)),
_ => {
buf.advance(len as usize);
Err(FrameError::UnknownFrame(ty.0))
}
};
if let Ok(frame) = &frame {
trace!(
"got frame {:?}, len: {}, remaining: {}",
frame,
len,
buf.remaining()
);
}
frame
}
}
impl<B> Encode for Frame<B>
where
B: Buf,
{
fn encode<T: BufMut>(&self, buf: &mut T) {
match self {
Frame::Data(b) => {
FrameType::DATA.encode(buf);
buf.write_var(b.remaining() as u64);
}
Frame::Headers(f) => {
FrameType::HEADERS.encode(buf);
buf.write_var(f.len() as u64);
}
Frame::Settings(f) => f.encode(buf),
Frame::PushPromise(f) => f.encode(buf),
Frame::CancelPush(id) => simple_frame_encode(FrameType::CANCEL_PUSH, *id, buf),
Frame::Goaway(id) => simple_frame_encode(FrameType::GOAWAY, *id, buf),
Frame::MaxPushId(id) => simple_frame_encode(FrameType::MAX_PUSH_ID, *id, buf),
Frame::Grease => {
FrameType::grease().encode(buf);
buf.write_var(6);
buf.put_slice(b"grease");
}
}
}
}
impl<B> Frame<B>
where
B: Buf,
{
pub fn payload(&self) -> Option<&dyn Buf> {
match self {
Frame::Data(f) => Some(f),
Frame::Headers(f) => Some(f),
Frame::PushPromise(f) => Some(&f.encoded),
_ => None,
}
}
pub fn payload_mut(&mut self) -> Option<&mut dyn Buf> {
match self {
Frame::Data(f) => Some(f),
Frame::Headers(f) => Some(f),
Frame::PushPromise(f) => Some(&mut f.encoded),
_ => None,
}
}
#[cfg(test)]
pub fn encode_with_payload<T: BufMut>(&mut self, buf: &mut T) {
self.encode(buf);
match self {
Frame::Data(b) => {
while b.has_remaining() {
let pos = {
let chunk = b.chunk();
buf.put_slice(chunk);
chunk.len()
};
b.advance(pos)
}
}
Frame::Headers(b) => buf.put_slice(b),
_ => (),
}
}
}
impl fmt::Debug for Frame<PayloadLen> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Frame::Data(len) => write!(f, "Data: {} bytes", len.0),
Frame::Headers(frame) => write!(f, "Headers({} entries)", frame.len()),
Frame::Settings(_) => write!(f, "Settings"),
Frame::CancelPush(id) => write!(f, "CancelPush({})", id),
Frame::PushPromise(frame) => write!(f, "PushPromise({})", frame.id),
Frame::Goaway(id) => write!(f, "GoAway({})", id),
Frame::MaxPushId(id) => write!(f, "MaxPushId({})", id),
Frame::Grease => write!(f, "Grease()"),
}
}
}
impl<B> fmt::Debug for Frame<B>
where
B: Buf,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Frame::Data(data) => write!(f, "Data: {} bytes", data.remaining()),
Frame::Headers(frame) => write!(f, "Headers({} entries)", frame.len()),
Frame::Settings(_) => write!(f, "Settings"),
Frame::CancelPush(id) => write!(f, "CancelPush({})", id),
Frame::PushPromise(frame) => write!(f, "PushPromise({})", frame.id),
Frame::Goaway(id) => write!(f, "GoAway({})", id),
Frame::MaxPushId(id) => write!(f, "MaxPushId({})", id),
Frame::Grease => write!(f, "Grease()"),
}
}
}
#[cfg(test)]
impl<T, U> PartialEq<Frame<T>> for Frame<U> {
fn eq(&self, other: &Frame<T>) -> bool {
match self {
Frame::Data(_) => matches!(other, Frame::Data(_)),
Frame::Settings(x) => matches!(other, Frame::Settings(y) if x == y),
Frame::Headers(x) => matches!(other, Frame::Headers(y) if x == y),
Frame::CancelPush(x) => matches!(other, Frame::CancelPush(y) if x == y),
Frame::PushPromise(x) => matches!(other, Frame::PushPromise(y) if x == y),
Frame::Goaway(x) => matches!(other, Frame::Goaway(y) if x == y),
Frame::MaxPushId(x) => matches!(other, Frame::MaxPushId(y) if x == y),
Frame::Grease => matches!(other, Frame::Grease),
}
}
}
#[cfg(test)]
impl Frame<Bytes> {
pub fn headers<T: Into<Bytes>>(block: T) -> Self {
Frame::Headers(block.into())
}
}
macro_rules! frame_types {
{$($name:ident = $val:expr,)*} => {
impl FrameType {
$(pub const $name: FrameType = FrameType($val);)*
}
}
}
frame_types! {
DATA = 0x0,
HEADERS = 0x1,
H2_PRIORITY = 0x2,
CANCEL_PUSH = 0x3,
SETTINGS = 0x4,
PUSH_PROMISE = 0x5,
H2_PING = 0x6,
GOAWAY = 0x7,
H2_WINDOW_UPDATE = 0x8,
H2_CONTINUATION = 0x9,
MAX_PUSH_ID = 0xD,
}
impl FrameType {
pub fn grease() -> Self {
FrameType(fastrand::u64(0..0x210842108421083) * 0x1f + 0x21)
}
}
#[derive(Copy, Clone, Eq, PartialEq, Debug)]
pub struct FrameType(u64);
impl FrameType {
fn decode<B: Buf>(buf: &mut B) -> Result<Self, UnexpectedEnd> {
Ok(FrameType(buf.get_var()?))
}
pub fn encode<B: BufMut>(&self, buf: &mut B) {
buf.write_var(self.0);
}
#[cfg(test)]
pub(crate) const RESERVED: FrameType = FrameType(0x1f * 1337 + 0x21);
}
pub(crate) trait FrameHeader {
fn len(&self) -> usize;
const TYPE: FrameType;
fn encode_header<T: BufMut>(&self, buf: &mut T) {
Self::TYPE.encode(buf);
buf.write_var(self.len() as u64);
}
}
#[derive(Debug, PartialEq)]
pub struct PushPromise {
id: u64,
encoded: Bytes,
}
impl FrameHeader for PushPromise {
const TYPE: FrameType = FrameType::PUSH_PROMISE;
fn encode_header<T: BufMut>(&self, buf: &mut T) {
Self::TYPE.encode(buf);
buf.write_var(self.len() as u64);
buf.write_var(self.id);
}
fn len(&self) -> usize {
VarInt::from_u64(self.id)
.expect("PushPromise id varint overflow")
.size()
+ self.encoded.as_ref().len()
}
}
impl PushPromise {
fn decode<B: Buf>(buf: &mut B) -> Result<Self, UnexpectedEnd> {
Ok(PushPromise {
id: buf.get_var()?,
encoded: buf.copy_to_bytes(buf.remaining()),
})
}
fn encode<B: BufMut>(&self, buf: &mut B) {
self.encode_header(buf);
buf.put(self.encoded.clone());
}
}
fn simple_frame_encode<B: BufMut>(ty: FrameType, id: StreamId, buf: &mut B) {
ty.encode(buf);
buf.write_var(1);
id.encode(buf);
}
#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)]
pub struct SettingId(pub u64);
impl SettingId {
const NONE: SettingId = SettingId(0);
pub fn grease() -> Self {
SettingId(fastrand::u64(0..0x210842108421083) * 0x1f + 0x21)
}
fn is_supported(self) -> bool {
matches!(
self,
SettingId::MAX_HEADER_LIST_SIZE
| SettingId::QPACK_MAX_TABLE_CAPACITY
| SettingId::QPACK_MAX_BLOCKED_STREAMS,
)
}
fn is_forbidden(&self) -> bool {
matches!(
self,
SettingId(0x00) | SettingId(0x02) | SettingId(0x03) | SettingId(0x04) | SettingId(0x05)
)
}
fn decode<B: Buf>(buf: &mut B) -> Result<Self, UnexpectedEnd> {
Ok(SettingId(buf.get_var()?))
}
fn encode<B: BufMut>(&self, buf: &mut B) {
buf.write_var(self.0);
}
}
macro_rules! setting_identifiers {
{$($name:ident = $val:expr,)*} => {
impl SettingId {
$(pub const $name: SettingId = SettingId($val);)*
}
}
}
setting_identifiers! {
QPACK_MAX_TABLE_CAPACITY = 0x1,
QPACK_MAX_BLOCKED_STREAMS = 0x7,
MAX_HEADER_LIST_SIZE = 0x6,
}
const SETTINGS_LEN: usize = 4;
#[derive(Debug, PartialEq)]
pub struct Settings {
entries: [(SettingId, u64); SETTINGS_LEN],
len: usize,
}
impl Default for Settings {
fn default() -> Self {
Self {
entries: [(SettingId::NONE, 0); SETTINGS_LEN],
len: 0,
}
}
}
impl FrameHeader for Settings {
const TYPE: FrameType = FrameType::SETTINGS;
fn len(&self) -> usize {
self.entries[..self.len].iter().fold(0, |len, (id, val)| {
len + VarInt::from_u64(id.0).unwrap().size() + VarInt::from_u64(*val).unwrap().size()
})
}
}
impl Settings {
pub const MAX_ENCODED_SIZE: usize = SETTINGS_LEN * 2 * VarInt::MAX_SIZE;
pub fn insert(&mut self, id: SettingId, value: u64) -> Result<(), SettingsError> {
if self.len >= self.entries.len() {
return Err(SettingsError::Exceeded);
}
if self.entries[..self.len].iter().any(|(i, _)| *i == id) {
return Err(SettingsError::Repeated(id));
}
self.entries[self.len] = (id, value);
self.len += 1;
Ok(())
}
pub fn get(&self, id: SettingId) -> Option<u64> {
for (entry_id, value) in self.entries.iter() {
if id == *entry_id {
return Some(*value);
}
}
None
}
pub(super) fn encode<T: BufMut>(&self, buf: &mut T) {
self.encode_header(buf);
for (id, val) in self.entries[..self.len].iter() {
id.encode(buf);
buf.write_var(*val);
}
}
pub(super) fn decode<T: Buf>(buf: &mut T) -> Result<Settings, SettingsError> {
let mut settings = Settings::default();
while buf.has_remaining() {
if buf.remaining() < 2 {
return Err(SettingsError::Malformed);
}
let identifier = SettingId::decode(buf).map_err(|_| SettingsError::Malformed)?;
let value = buf.get_var().map_err(|_| SettingsError::Malformed)?;
if identifier.is_forbidden() {
return Err(SettingsError::InvalidSettingId(identifier.0));
}
if identifier.is_supported() {
settings.insert(identifier, value)?;
}
}
Ok(settings)
}
}
#[derive(Debug, PartialEq)]
pub enum SettingsError {
Exceeded,
Malformed,
Repeated(SettingId),
InvalidSettingId(u64),
InvalidSettingValue(SettingId, u64),
}
impl std::error::Error for SettingsError {}
impl fmt::Display for SettingsError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
SettingsError::Exceeded => write!(
f,
"max settings number exeeded, check for duplicate entries"
),
SettingsError::Malformed => write!(f, "malformed settings frame"),
SettingsError::Repeated(id) => write!(f, "got setting 0x{:x} twice", id.0),
SettingsError::InvalidSettingId(id) => write!(f, "setting id 0x{:x} is invalid", id),
SettingsError::InvalidSettingValue(id, val) => {
write!(f, "setting 0x{:x} has invalid value {}", id.0, val)
}
}
}
}
impl From<SettingsError> for FrameError {
fn from(e: SettingsError) -> Self {
Self::Settings(e)
}
}
impl From<UnexpectedEnd> for FrameError {
fn from(e: UnexpectedEnd) -> Self {
FrameError::Incomplete(e.0)
}
}
impl From<InvalidStreamId> for FrameError {
fn from(e: InvalidStreamId) -> Self {
FrameError::InvalidStreamId(e)
}
}
#[cfg(test)]
mod tests {
use super::*;
use assert_matches::assert_matches;
use std::io::Cursor;
#[test]
fn unknown_frame_type() {
let mut buf = Cursor::new(&[22, 4, 0, 255, 128, 0, 3, 1, 2]);
assert_matches!(Frame::decode(&mut buf), Err(FrameError::UnknownFrame(22)));
assert_matches!(Frame::decode(&mut buf), Ok(Frame::CancelPush(StreamId(2))));
}
#[test]
fn len_unexpected_end() {
let mut buf = Cursor::new(&[0, 255]);
let decoded = Frame::decode(&mut buf);
assert_matches!(decoded, Err(FrameError::Incomplete(3)));
}
#[test]
fn type_unexpected_end() {
let mut buf = Cursor::new(&[255]);
let decoded = Frame::decode(&mut buf);
assert_matches!(decoded, Err(FrameError::Incomplete(2)));
}
#[test]
fn buffer_too_short() {
let mut buf = Cursor::new(&[4, 4, 0, 255, 128]);
let decoded = Frame::decode(&mut buf);
assert_matches!(decoded, Err(FrameError::Incomplete(6)));
}
fn codec_frame_check(mut frame: Frame<Bytes>, wire: &[u8], check_frame: Frame<Bytes>) {
let mut buf = Vec::new();
frame.encode_with_payload(&mut buf);
assert_eq!(&buf, &wire);
let mut read = Cursor::new(&buf);
let decoded = Frame::decode(&mut read).unwrap();
assert_eq!(check_frame, decoded);
}
#[test]
fn settings_frame() {
codec_frame_check(
Frame::Settings(Settings {
entries: [
(SettingId::MAX_HEADER_LIST_SIZE, 0xfad1),
(SettingId::QPACK_MAX_TABLE_CAPACITY, 0xfad2),
(SettingId::QPACK_MAX_BLOCKED_STREAMS, 0xfad3),
(SettingId(95), 0),
],
len: 4,
}),
&[
4, 18, 6, 128, 0, 250, 209, 1, 128, 0, 250, 210, 7, 128, 0, 250, 211, 64, 95, 0,
],
Frame::Settings(Settings {
entries: [
(SettingId::MAX_HEADER_LIST_SIZE, 0xfad1),
(SettingId::QPACK_MAX_TABLE_CAPACITY, 0xfad2),
(SettingId::QPACK_MAX_BLOCKED_STREAMS, 0xfad3),
(SettingId(0), 0),
],
len: 3,
}),
);
}
#[test]
fn settings_frame_emtpy() {
codec_frame_check(
Frame::Settings(Settings::default()),
&[4, 0],
Frame::Settings(Settings::default()),
);
}
#[test]
fn data_frame() {
codec_frame_check(
Frame::Data(Bytes::from("1234567")),
&[0, 7, 49, 50, 51, 52, 53, 54, 55],
Frame::Data(Bytes::from("1234567")),
);
}
#[test]
fn simple_frames() {
codec_frame_check(
Frame::CancelPush(StreamId(2)),
&[3, 1, 2],
Frame::CancelPush(StreamId(2)),
);
codec_frame_check(
Frame::Goaway(StreamId(2)),
&[7, 1, 2],
Frame::Goaway(StreamId(2)),
);
codec_frame_check(
Frame::MaxPushId(StreamId(2)),
&[13, 1, 2],
Frame::MaxPushId(StreamId(2)),
);
}
#[test]
fn headers_frames() {
codec_frame_check(
Frame::headers("TODO QPACK"),
&[1, 10, 84, 79, 68, 79, 32, 81, 80, 65, 67, 75],
Frame::headers("TODO QPACK"),
);
codec_frame_check(
Frame::PushPromise(PushPromise {
id: 134,
encoded: Bytes::from("TODO QPACK"),
}),
&[5, 12, 64, 134, 84, 79, 68, 79, 32, 81, 80, 65, 67, 75],
Frame::PushPromise(PushPromise {
id: 134,
encoded: Bytes::from("TODO QPACK"),
}),
);
}
#[test]
fn reserved_frame() {
let mut raw = vec![];
VarInt::from_u32(0x21 + 2 * 0x1f).encode(&mut raw);
raw.extend(&[6, 0, 255, 128, 0, 250, 218]);
let mut buf = Cursor::new(&raw);
let decoded = Frame::decode(&mut buf);
assert_matches!(decoded, Err(FrameError::UnknownFrame(95)));
}
}