use std::{fmt, io::Write, ops::RangeInclusive};
use byteorder::{BigEndian, WriteBytesExt};
use enum_repr::EnumRepr;
pub use enumflags2;
use enumflags2::{bitflags, BitFlags};
pub use nom;
use nom::{
combinator::map,
number::streaming::{be_u16, be_u24, be_u32, be_u8},
sequence::tuple,
IResult,
};
use fluke_buffet::{Piece, Roll, RollMut};
pub const PREFACE: &[u8] = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n";
pub fn preface(i: Roll) -> IResult<Roll, ()> {
let (i, _) = nom::bytes::streaming::tag(PREFACE)(i)?;
Ok((i, ()))
}
/// Conversion of a value into a [`Piece`], with access to a scratch buffer
/// that implementors may use to serialize into before handing bytes out.
pub trait IntoPiece {
/// Consumes `self` and returns its byte representation as a [`Piece`].
///
/// # Errors
///
/// Returns any I/O error raised while serializing into `scratch`.
fn into_piece(self, scratch: &mut RollMut) -> std::io::Result<Piece>;
}
/// Frame type octets this module recognizes, with their wire values.
///
/// Type octets not listed here are carried as [`FrameType::Unknown`].
#[EnumRepr(type = "u8")]
#[derive(Debug, Clone, Copy)]
pub enum RawFrameType {
Data = 0x00,
Headers = 0x01,
Priority = 0x02,
RstStream = 0x03,
Settings = 0x04,
PushPromise = 0x05,
Ping = 0x06,
GoAway = 0x07,
WindowUpdate = 0x08,
Continuation = 0x09,
}
/// A frame type together with the flags that are defined for it.
///
/// Types whose flags are not modeled carry no flag data; their flags encode
/// as 0. Unrecognized type octets keep their raw type and flags bytes.
#[derive(Debug, Clone, Copy)]
pub enum FrameType {
Data(BitFlags<DataFlags>),
Headers(BitFlags<HeadersFlags>),
Priority,
RstStream,
Settings(BitFlags<SettingsFlags>),
PushPromise,
Ping(BitFlags<PingFlags>),
GoAway,
WindowUpdate,
Continuation(BitFlags<ContinuationFlags>),
// Raw type/flags octets of a frame type not listed in `RawFrameType`.
Unknown(EncodedFrameType),
}
impl FrameType {
pub fn into_frame(self, stream_id: StreamId) -> Frame {
Frame {
frame_type: self,
len: 0,
reserved: 0,
stream_id,
}
}
}
/// Flags defined for DATA frames.
#[bitflags]
#[repr(u8)]
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum DataFlags {
Padded = 0x08,
EndStream = 0x01,
}
/// Flags defined for HEADERS frames.
#[bitflags]
#[repr(u8)]
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum HeadersFlags {
Priority = 0x20,
Padded = 0x08,
EndHeaders = 0x04,
EndStream = 0x01,
}
/// Flags defined for SETTINGS frames.
#[bitflags]
#[repr(u8)]
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum SettingsFlags {
Ack = 0x01,
}
/// Flags defined for PING frames.
#[bitflags]
#[repr(u8)]
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum PingFlags {
Ack = 0x01,
}
/// Flags defined for CONTINUATION frames.
#[bitflags]
#[repr(u8)]
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ContinuationFlags {
EndHeaders = 0x04,
}
/// The raw type and flags octets of a frame header, before interpretation.
#[derive(Debug, Clone, Copy)]
pub struct EncodedFrameType {
/// Frame type octet.
pub ty: u8,
/// Flags octet, including any bits not defined for the type.
pub flags: u8,
}
impl EncodedFrameType {
fn parse(i: Roll) -> IResult<Roll, Self> {
let (i, (ty, flags)) = tuple((be_u8, be_u8))(i)?;
Ok((i, Self { ty, flags }))
}
}
impl From<(RawFrameType, u8)> for EncodedFrameType {
fn from((ty, flags): (RawFrameType, u8)) -> Self {
Self {
ty: ty.repr(),
flags,
}
}
}
impl FrameType {
pub(crate) fn encode(self) -> EncodedFrameType {
match self {
FrameType::Data(f) => (RawFrameType::Data, f.bits()).into(),
FrameType::Headers(f) => (RawFrameType::Headers, f.bits()).into(),
FrameType::Priority => (RawFrameType::Priority, 0).into(),
FrameType::RstStream => (RawFrameType::RstStream, 0).into(),
FrameType::Settings(f) => (RawFrameType::Settings, f.bits()).into(),
FrameType::PushPromise => (RawFrameType::PushPromise, 0).into(),
FrameType::Ping(f) => (RawFrameType::Ping, f.bits()).into(),
FrameType::GoAway => (RawFrameType::GoAway, 0).into(),
FrameType::WindowUpdate => (RawFrameType::WindowUpdate, 0).into(),
FrameType::Continuation(f) => (RawFrameType::Continuation, f.bits()).into(),
FrameType::Unknown(ft) => ft,
}
}
fn decode(ft: EncodedFrameType) -> Self {
match RawFrameType::from_repr(ft.ty) {
Some(ty) => match ty {
RawFrameType::Data => {
FrameType::Data(BitFlags::<DataFlags>::from_bits_truncate(ft.flags))
}
RawFrameType::Headers => {
FrameType::Headers(BitFlags::<HeadersFlags>::from_bits_truncate(ft.flags))
}
RawFrameType::Priority => FrameType::Priority,
RawFrameType::RstStream => FrameType::RstStream,
RawFrameType::Settings => {
FrameType::Settings(BitFlags::<SettingsFlags>::from_bits_truncate(ft.flags))
}
RawFrameType::PushPromise => FrameType::PushPromise,
RawFrameType::Ping => {
FrameType::Ping(BitFlags::<PingFlags>::from_bits_truncate(ft.flags))
}
RawFrameType::GoAway => FrameType::GoAway,
RawFrameType::WindowUpdate => FrameType::WindowUpdate,
RawFrameType::Continuation => FrameType::Continuation(
BitFlags::<ContinuationFlags>::from_bits_truncate(ft.flags),
),
},
None => FrameType::Unknown(ft),
}
}
}
/// An HTTP/2 stream identifier.
///
/// The field is public and unchecked; use `TryFrom<u32>` to reject values
/// with the high (reserved) bit set.
#[derive(Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct StreamId(pub u32);

impl StreamId {
    /// Identifier 0, used for frames that concern the connection as a whole.
    pub const CONNECTION: Self = Self(0);

    /// Returns `true` when the identifier is even.
    ///
    /// NOTE(review): this is also `true` for [`StreamId::CONNECTION`] (0);
    /// callers that care must check for it separately.
    pub fn is_server_initiated(&self) -> bool {
        (self.0 & 1) == 0
    }
}
/// Error returned when a `u32` with its high (reserved) bit set is converted
/// into a [`StreamId`].
#[derive(Debug, thiserror::Error)]
#[error("invalid stream id: {0}")]
pub struct StreamIdOutOfRange(u32);

impl TryFrom<u32> for StreamId {
    type Error = StreamIdOutOfRange;

    /// Accepts values that fit in 31 bits; anything larger is rejected.
    fn try_from(value: u32) -> Result<Self, Self::Error> {
        const LARGEST_STREAM_ID: u32 = 0x7fff_ffff;
        if value <= LARGEST_STREAM_ID {
            Ok(Self(value))
        } else {
            Err(StreamIdOutOfRange(value))
        }
    }
}
impl fmt::Debug for StreamId {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt::Debug::fmt(&self.0, f)
}
}
impl fmt::Display for StreamId {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt::Display::fmt(&self.0, f)
}
}
/// An HTTP/2 frame header: payload length, type and flags, reserved bit and
/// stream identifier. The payload itself is not part of this struct.
pub struct Frame {
/// Frame type together with its modeled flags.
pub frame_type: FrameType,
/// Reserved bit preceding the stream identifier (0 or 1 when parsed; any
/// non-zero value sets the bit when written).
pub reserved: u8,
/// Stream this frame belongs to; [`StreamId::CONNECTION`] for connection-level frames.
pub stream_id: StreamId,
/// Payload length in octets (a 24-bit field on the wire).
pub len: u32,
}
impl fmt::Debug for Frame {
    /// Renders a compact summary such as `#3:Headers { len: 12, flags: ... }`.
    ///
    /// Connection-level frames are prefixed with `Conn:`. `reserved`, `len` and
    /// `flags` appear only when non-zero / non-empty. Unknown frame types are
    /// printed as `UnknownFrame(type, flags)` in hex.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Lets a `Display` value be used as a `debug_struct` field, so flags are
        // shown in their `Display` form.
        struct ViaDisplay<'a>(&'a dyn fmt::Display);
        impl fmt::Debug for ViaDisplay<'_> {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                fmt::Display::fmt(self.0, f)
            }
        }

        match self.stream_id.0 {
            0 => f.write_str("Conn:")?,
            id => write!(f, "#{}:", id)?,
        }

        let name = match &self.frame_type {
            FrameType::Data(_) => "Data",
            FrameType::Headers(_) => "Headers",
            FrameType::Priority => "Priority",
            FrameType::RstStream => "RstStream",
            FrameType::Settings(_) => "Settings",
            FrameType::PushPromise => "PushPromise",
            FrameType::Ping(_) => "Ping",
            FrameType::GoAway => "GoAway",
            FrameType::WindowUpdate => "WindowUpdate",
            FrameType::Continuation(_) => "Continuation",
            FrameType::Unknown(EncodedFrameType { ty, flags }) => {
                return write!(f, "UnknownFrame({:#x}, {:#x})", ty, flags)
            }
        };

        let mut dbg = f.debug_struct(name);
        if self.reserved != 0 {
            dbg.field("reserved", &self.reserved);
        }
        if self.len > 0 {
            dbg.field("len", &self.len);
        }

        let flags: Option<&dyn fmt::Display> = match &self.frame_type {
            FrameType::Data(fl) if !fl.is_empty() => Some(fl as &dyn fmt::Display),
            FrameType::Headers(fl) if !fl.is_empty() => Some(fl as &dyn fmt::Display),
            FrameType::Settings(fl) if !fl.is_empty() => Some(fl as &dyn fmt::Display),
            FrameType::Ping(fl) if !fl.is_empty() => Some(fl as &dyn fmt::Display),
            FrameType::Continuation(fl) if !fl.is_empty() => Some(fl as &dyn fmt::Display),
            _ => None,
        };
        if let Some(fl) = flags {
            dbg.field("flags", &ViaDisplay(fl));
        }

        dbg.finish()
    }
}
impl Frame {
pub fn new(frame_type: FrameType, stream_id: StreamId) -> Self {
Self {
frame_type,
reserved: 0,
stream_id,
len: 0,
}
}
pub fn with_len(mut self, len: u32) -> Self {
self.len = len;
self
}
pub fn parse(i: Roll) -> IResult<Roll, Self> {
let (i, (len, frame_type, (reserved, stream_id))) = tuple((
be_u24,
EncodedFrameType::parse,
parse_reserved_and_stream_id,
))(i)?;
let frame = Frame {
frame_type: FrameType::decode(frame_type),
reserved,
stream_id,
len,
};
Ok((i, frame))
}
pub fn write_into(self, mut w: impl std::io::Write) -> std::io::Result<()> {
use byteorder::{BigEndian, WriteBytesExt};
w.write_u24::<BigEndian>(self.len as _)?;
let ft = self.frame_type.encode();
w.write_u8(ft.ty)?;
w.write_u8(ft.flags)?;
w.write_all(&pack_reserved_and_stream_id(self.reserved, self.stream_id))?;
Ok(())
}
pub fn is_ack(self) -> bool {
match self.frame_type {
FrameType::Data(_) => false,
FrameType::Headers(_) => false,
FrameType::Priority => false,
FrameType::RstStream => false,
FrameType::Settings(flags) => flags.contains(SettingsFlags::Ack),
FrameType::PushPromise => false,
FrameType::Ping(flags) => flags.contains(PingFlags::Ack),
FrameType::GoAway => false,
FrameType::WindowUpdate => false,
FrameType::Continuation(_) => false,
FrameType::Unknown(_) => false,
}
}
}
impl IntoPiece for Frame {
    /// Serializes the header into `scratch` and takes the written octets out of it.
    ///
    /// `scratch` is expected to be empty, since everything in it is taken.
    fn into_piece(self, scratch: &mut RollMut) -> std::io::Result<Piece> {
        debug_assert_eq!(scratch.len(), 0);
        self.write_into(&mut *scratch)?;
        let written = scratch.take_all();
        Ok(written.into())
    }
}
/// Parses a 4-octet word as one leading reserved bit followed by a 31-bit
/// unsigned value.
///
/// Returns `(reserved, value)`, where `reserved` is 0 or 1.
pub fn parse_reserved_and_u31(i: Roll) -> IResult<Roll, (u8, u32)> {
// First bit of the word.
fn reserved(i: (Roll, usize)) -> IResult<(Roll, usize), u8> {
nom::bits::streaming::take(1_usize)(i)
}
// Remaining 31 bits of the word.
fn stream_id(i: (Roll, usize)) -> IResult<(Roll, usize), u32> {
nom::bits::streaming::take(31_usize)(i)
}
nom::bits::bits(tuple((reserved, stream_id)))(i)
}
fn parse_reserved_and_stream_id(i: Roll) -> IResult<Roll, (u8, StreamId)> {
parse_reserved_and_u31(i).map(|(i, (reserved, stream_id))| (i, (reserved, StreamId(stream_id))))
}
/// Packs a reserved bit and a 31-bit stream identifier into 4 big-endian octets.
///
/// Bit 31 is set if and only if `reserved` is non-zero. Only the low 31 bits
/// of `stream_id` are used, so a `StreamId` constructed with its high bit set
/// cannot silently turn on the reserved bit.
fn pack_reserved_and_stream_id(reserved: u8, stream_id: StreamId) -> [u8; 4] {
    const STREAM_ID_MASK: u32 = 0x7fff_ffff;
    let reserved_bit: u32 = if reserved != 0 { 1 << 31 } else { 0 };
    ((stream_id.0 & STREAM_ID_MASK) | reserved_bit).to_be_bytes()
}
/// Stream priority information: an exclusive flag, the dependency stream and
/// the raw weight octet.
#[derive(Debug)]
pub struct PrioritySpec {
    /// Set when the reserved/exclusive bit of the first word is 1.
    pub exclusive: bool,
    /// Stream this one depends on.
    pub stream_dependency: StreamId,
    /// Weight octet exactly as it appears on the wire (not adjusted).
    pub weight: u8,
}

impl PrioritySpec {
    /// Parses the 5-octet priority block: exclusive bit, 31-bit stream
    /// dependency, then the weight octet.
    pub fn parse(i: Roll) -> IResult<Roll, Self> {
        let (rest, ((exclusive_bit, stream_dependency), weight)) =
            tuple((parse_reserved_and_stream_id, be_u8))(i)?;
        let spec = PrioritySpec {
            exclusive: exclusive_bit != 0,
            stream_dependency,
            weight,
        };
        Ok((rest, spec))
    }
}
/// A 32-bit error code as carried on the wire.
///
/// May hold values that have no [`KnownErrorCode`] variant.
#[derive(Clone, Copy)]
pub struct ErrorCode(u32);

impl ErrorCode {
    /// Returns the numeric wire value.
    pub fn as_repr(self) -> u32 {
        let ErrorCode(raw) = self;
        raw
    }
}
impl fmt::Debug for ErrorCode {
    /// Shows the [`KnownErrorCode`] name when there is one, otherwise
    /// `ErrorCode(0x..)` with the value in hex.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        if let Some(known) = KnownErrorCode::from_repr(self.as_repr()) {
            return fmt::Debug::fmt(&known, f);
        }
        write!(f, "ErrorCode(0x{:02x})", self.0)
    }
}
impl From<KnownErrorCode> for ErrorCode {
fn from(e: KnownErrorCode) -> Self {
Self(e as u32)
}
}
/// Error codes that have a name in this module, with their 32-bit wire values.
///
/// Values without a variant are still representable as [`ErrorCode`].
#[EnumRepr(type = "u32")]
#[derive(Debug, Clone, Copy)]
pub enum KnownErrorCode {
NoError = 0x00,
ProtocolError = 0x01,
InternalError = 0x02,
FlowControlError = 0x03,
SettingsTimeout = 0x04,
StreamClosed = 0x05,
FrameSizeError = 0x06,
RefusedStream = 0x07,
Cancel = 0x08,
CompressionError = 0x09,
ConnectError = 0x0a,
EnhanceYourCalm = 0x0b,
InadequateSecurity = 0x0c,
Http1_1Required = 0x0d,
}
impl TryFrom<ErrorCode> for KnownErrorCode {
type Error = ();
fn try_from(e: ErrorCode) -> Result<Self, Self::Error> {
KnownErrorCode::from_repr(e.0).ok_or(())
}
}
/// Values of the SETTINGS parameters modeled by this module.
#[derive(Clone, Copy, Debug)]
pub struct Settings {
    pub header_table_size: u32,
    pub enable_push: bool,
    pub max_concurrent_streams: u32,
    pub initial_window_size: u32,
    pub max_frame_size: u32,
    pub max_header_list_size: u32,
}

impl Default for Settings {
    /// Returns this module's default settings.
    ///
    /// NOTE(review): `enable_push: false`, `max_concurrent_streams: 100` and
    /// `max_header_list_size: 0` differ from the protocol's initial values
    /// (push enabled, no limits). `Settings::parse` starts from this value, so
    /// a peer that omits those settings is recorded with these numbers. Confirm
    /// that this is intended, especially a header-list limit of 0.
    fn default() -> Self {
        let initial_window_size = 65_535; // 2^16 - 1
        let max_frame_size = 16_384; // 2^14
        Settings {
            header_table_size: 4096,
            enable_push: false,
            max_concurrent_streams: 100,
            initial_window_size,
            max_frame_size,
            max_header_list_size: 0,
        }
    }
}
/// 16-bit identifiers of the settings modeled by [`Settings`].
///
/// Identifiers not listed here are skipped by `Settings::parse`.
#[EnumRepr(type = "u16")]
#[derive(Debug, Clone, Copy)]
enum SettingIdentifier {
HeaderTableSize = 0x01,
EnablePush = 0x02,
MaxConcurrentStreams = 0x03,
InitialWindowSize = 0x04,
MaxFrameSize = 0x05,
MaxHeaderListSize = 0x06,
}
impl Settings {
const MAX_INITIAL_WINDOW_SIZE: u32 = (1 << 31) - 1;
const MAX_FRAME_SIZE_ALLOWED_RANGE: RangeInclusive<u32> = (1 << 14)..=((1 << 24) - 1);
pub fn parse(mut i: Roll) -> IResult<Roll, Self> {
tracing::trace!("parsing settings frame, roll length: {}", i.len());
let mut settings = Self::default();
while !i.is_empty() {
let (rest, (id, value)) = tuple((be_u16, be_u32))(i)?;
tracing::trace!(%id, %value, "Got setting pair");
match SettingIdentifier::from_repr(id) {
None => {
}
Some(id) => match id {
SettingIdentifier::HeaderTableSize => {
settings.header_table_size = value;
}
SettingIdentifier::EnablePush => {
settings.enable_push = match value {
0 => false,
1 => true,
_ => {
return Err(nom::Err::Error(nom::error::Error::new(
rest,
nom::error::ErrorKind::Digit,
)));
}
}
}
SettingIdentifier::MaxConcurrentStreams => {
settings.max_concurrent_streams = value;
}
SettingIdentifier::InitialWindowSize => {
if value > Self::MAX_INITIAL_WINDOW_SIZE {
return Err(nom::Err::Error(nom::error::Error::new(
rest,
nom::error::ErrorKind::Digit,
)));
}
settings.initial_window_size = value;
}
SettingIdentifier::MaxFrameSize => {
if !Self::MAX_FRAME_SIZE_ALLOWED_RANGE.contains(&value) {
return Err(nom::Err::Error(nom::error::Error::new(
rest,
nom::error::ErrorKind::Digit,
)));
}
settings.max_frame_size = value;
}
SettingIdentifier::MaxHeaderListSize => {
settings.max_header_list_size = value;
}
},
}
i = rest;
}
Ok((i, settings))
}
pub fn pairs(&self) -> impl Iterator<Item = (u16, u32)> {
[
(
SettingIdentifier::HeaderTableSize as u16,
self.header_table_size,
),
(
SettingIdentifier::EnablePush as u16,
self.enable_push as u32,
),
(
SettingIdentifier::MaxConcurrentStreams as u16,
self.max_concurrent_streams,
),
(
SettingIdentifier::InitialWindowSize as u16,
self.initial_window_size,
),
(SettingIdentifier::MaxFrameSize as u16, self.max_frame_size),
(
SettingIdentifier::MaxHeaderListSize as u16,
self.max_header_list_size,
),
]
.into_iter()
}
pub fn write_into(self, mut w: impl std::io::Write) -> std::io::Result<()> {
use byteorder::{BigEndian, WriteBytesExt};
for (id, value) in self.pairs() {
w.write_u16::<BigEndian>(id)?;
w.write_u32::<BigEndian>(value)?;
}
Ok(())
}
}
impl IntoPiece for Settings {
    /// Serializes all setting pairs into `scratch` and takes the written octets out.
    ///
    /// `scratch` is expected to be empty, since everything in it is taken.
    fn into_piece(self, scratch: &mut RollMut) -> std::io::Result<Piece> {
        debug_assert_eq!(scratch.len(), 0);
        self.write_into(&mut *scratch)?;
        let written = scratch.take_all();
        Ok(written.into())
    }
}
/// Payload of a GOAWAY frame.
pub struct GoAway {
/// Last stream identifier field of the payload.
pub last_stream_id: StreamId,
/// Error code explaining why the connection is being shut down.
pub error_code: ErrorCode,
/// Opaque octets following the error code; may be empty.
pub additional_debug_data: Piece,
}
impl IntoPiece for GoAway {
    /// Serializes the GOAWAY payload.
    ///
    /// The payload is the last stream id (reserved bit cleared), then the
    /// 32-bit error code, then the debug data.
    ///
    /// # Panics
    ///
    /// Panics if `RollMut::put_to_roll` returns an error; its result is unwrapped.
    /// NOTE(review): presumably that only happens when `scratch` cannot provide
    /// the requested length. Confirm, and consider propagating the error instead.
    fn into_piece(self, scratch: &mut RollMut) -> std::io::Result<Piece> {
        // The first word is one reserved bit plus a 31-bit stream id. Mask so
        // an out-of-range `StreamId` cannot set the reserved bit.
        let last_stream_id = self.last_stream_id.0 & 0x7fff_ffff;
        let roll = scratch
            .put_to_roll(8 + self.additional_debug_data.len(), |mut slice| {
                slice.write_u32::<BigEndian>(last_stream_id)?;
                slice.write_u32::<BigEndian>(self.error_code.0)?;
                slice.write_all(&self.additional_debug_data[..])?;
                Ok(())
            })
            .unwrap();
        Ok(roll.into())
    }
}
impl GoAway {
    /// Parses a GOAWAY payload.
    ///
    /// The payload is a reserved bit plus 31-bit last stream id, then a 32-bit
    /// error code, then arbitrary debug data. Everything after the first 8
    /// octets becomes `additional_debug_data`, so the returned remainder is
    /// always empty.
    pub fn parse(i: Roll) -> IResult<Roll, Self> {
        let (rest, (last_stream_id, error_code)) = tuple((be_u32, be_u32))(i)?;
        let goaway = Self {
            // The high bit is reserved and must be ignored on receipt; keeping it
            // would produce a stream id above 2^31 - 1.
            last_stream_id: StreamId(last_stream_id & 0x7fff_ffff),
            error_code: ErrorCode(error_code),
            additional_debug_data: rest.into(),
        };
        Ok((Roll::empty(), goaway))
    }
}
/// Payload of a RST_STREAM frame: a single 32-bit error code.
pub struct RstStream {
/// Reason the stream is being reset.
pub error_code: ErrorCode,
}
impl IntoPiece for RstStream {
    /// Serializes the 4-octet big-endian error code.
    ///
    /// # Panics
    ///
    /// Panics if `RollMut::put_to_roll` returns an error; its result is unwrapped.
    fn into_piece(self, scratch: &mut RollMut) -> std::io::Result<Piece> {
        let code = self.error_code.as_repr();
        let roll = scratch
            .put_to_roll(4, |mut out| {
                out.write_u32::<BigEndian>(code)?;
                Ok(())
            })
            .unwrap();
        Ok(roll.into())
    }
}
impl RstStream {
pub fn parse(i: Roll) -> IResult<Roll, Self> {
let (rest, error_code) = be_u32(i)?;
Ok((
rest,
Self {
error_code: ErrorCode(error_code),
},
))
}
}
impl<T> IntoPiece for T
where
Piece: From<T>,
{
fn into_piece(self, _scratch: &mut RollMut) -> std::io::Result<Piece> {
Ok(self.into())
}
}