use bitflags::bitflags;
use bytes::{Buf, BufMut, Bytes, BytesMut};
use crate::stream::{OperationId, StreamId};
use serde::{Deserialize, Serialize};
use anyhow::Result;
/// Size in bytes of the fixed frame header that precedes every payload.
///
/// The sum mirrors the fields written by `Frame::to_bytes`, in order:
/// stream id (`u32`), operation id (`u32`), flags (`u8`), frame type (`u8`)
/// and payload length (`u32`).
pub const FRAME_HEADER_LEN: usize = 4 + 4 + 1 + 1 + 4;
bitflags! {
    /// Bit set stored in the one-byte flags field of a frame header.
    ///
    /// Bits 0 through 6 are assigned below; bit 7 has no named flag.
    #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
    pub struct FrameFlags: u8 {
        /// No bits set.
        const EMPTY = 0;
        /// Bit 0 (`0x01`).
        const REQUEST = 1 << 0;
        /// Bit 1 (`0x02`).
        const RESPONSE = 1 << 1;
        /// Bit 2 (`0x04`).
        const FRAGMENTED = 1 << 2;
        /// Bit 3 (`0x08`).
        const FIN = 1 << 3;
        /// Bit 4 (`0x10`).
        const OPEN = 1 << 4;
        /// Bit 5 (`0x20`).
        const RESET = 1 << 5;
        /// Bit 6 (`0x40`).
        const CLOSE = 1 << 6;
    }
}
impl FrameFlags {
    /// Returns a flag set with no bits set.
    pub fn new() -> Self {
        Self::empty()
    }

    /// Returns `REQUEST | OPEN`.
    pub fn open_request() -> Self {
        FrameFlags::REQUEST | FrameFlags::OPEN
    }

    /// Returns `RESPONSE | OPEN`.
    pub fn open_response() -> Self {
        FrameFlags::RESPONSE | FrameFlags::OPEN
    }

    /// Returns `RESPONSE | OPEN | RESET`.
    pub fn open_reset() -> Self {
        FrameFlags::RESPONSE | FrameFlags::OPEN | FrameFlags::RESET
    }

    /// Returns `REQUEST | CLOSE`.
    pub fn close_request() -> Self {
        FrameFlags::REQUEST | FrameFlags::CLOSE
    }

    /// Returns `RESPONSE | CLOSE`.
    pub fn close_response() -> Self {
        FrameFlags::RESPONSE | FrameFlags::CLOSE
    }

    /// Returns `true` when no bit is set.
    pub fn is_none(&self) -> bool {
        self.is_empty()
    }

    /// Returns `true` when the `REQUEST` bit is set (other bits are ignored).
    pub fn is_request(&self) -> bool {
        self.contains(FrameFlags::REQUEST)
    }

    /// Returns `true` when the `RESPONSE` bit is set (other bits are ignored).
    pub fn is_response(&self) -> bool {
        self.contains(FrameFlags::RESPONSE)
    }

    /// Returns `true` when the `FRAGMENTED` bit is set (other bits are ignored).
    pub fn is_fragmented(&self) -> bool {
        self.contains(FrameFlags::FRAGMENTED)
    }

    /// Returns `true` when the `OPEN` bit is set (other bits are ignored).
    pub fn is_open(&self) -> bool {
        self.contains(FrameFlags::OPEN)
    }

    /// Returns `true` when the `FIN` bit is set (other bits are ignored).
    pub fn is_fin(&self) -> bool {
        self.contains(FrameFlags::FIN)
    }

    /// Returns `true` when the `RESET` bit is set (other bits are ignored).
    pub fn is_reset(&self) -> bool {
        self.contains(FrameFlags::RESET)
    }

    /// Returns `true` only when the set is exactly `REQUEST | OPEN`.
    pub fn is_open_request(self) -> bool {
        self == Self::open_request()
    }

    /// Returns `true` only when the set is exactly `RESPONSE | OPEN`.
    pub fn is_open_response(self) -> bool {
        self == Self::open_response()
    }

    /// Returns `true` only when the set is exactly `RESPONSE | OPEN | RESET`.
    pub fn is_open_reset(self) -> bool {
        self == Self::open_reset()
    }

    /// Returns `true` only when the set is exactly `REQUEST | CLOSE`.
    pub fn is_close_request(self) -> bool {
        self == Self::close_request()
    }

    /// Returns `true` only when the set is exactly `RESPONSE | CLOSE`.
    pub fn is_close_response(self) -> bool {
        self == Self::close_response()
    }

    /// Sets every bit of `flag` in this set.
    pub fn set_flag(&mut self, flag: FrameFlags) {
        self.insert(flag);
    }

    /// Clears every bit of `flag` in this set.
    pub fn unset_flag(&mut self, flag: FrameFlags) {
        self.remove(flag);
    }

    /// Overwrites this set with the bits of `value`.
    ///
    /// Bits that do not correspond to a declared flag are dropped.
    pub fn from_u8(&mut self, value: u8) {
        // The parsed value must be stored back into the receiver; building it
        // and dropping it would leave the flags unchanged.
        *self = Self::from_bits_truncate(value);
    }

    /// Returns the raw bits of this set.
    pub fn as_u8(&self) -> u8 {
        self.bits()
    }
}
impl Serialize for FrameFlags {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
serializer.serialize_u8(self.bits())
}
}
impl<'de> Deserialize<'de> for FrameFlags {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let bits = u8::deserialize(deserializer)?;
FrameFlags::from_bits(bits)
.ok_or_else(|| serde::de::Error::custom("Invalid FrameFlags bits"))
}
}
/// Frame type tag, carried as a single byte in the frame header.
///
/// Conversion to and from the header byte goes through the `From<u8>` and
/// `From<FrameType> for u8` impls in this file. The serde derives are
/// independent of those conversions.
#[repr(u8)]
#[derive(Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)]
pub enum FrameType {
/// Header byte `0`.
Control = 0,
/// Header byte `1`.
Data = 1,
/// Header byte `2`.
Meta = 2,
/// Header byte `3`.
FileMeta = 3,
/// Header byte `4`.
FileData = 4,
/// Header byte `5`.
ContentRequest = 5,
/// Any other header byte, kept verbatim. `From<u8>` only produces this
/// variant for values above `5`; a hand-built `Reserved(n)` with `n <= 5`
/// converts to the same byte as the corresponding named variant.
Reserved(u8),
}
/// Maps a header byte to its frame type; bytes without a named variant
/// become `FrameType::Reserved`.
impl From<u8> for FrameType {
    fn from(value: u8) -> Self {
        // Named variants, indexed by their header byte.
        const NAMED: [FrameType; 6] = [
            FrameType::Control,
            FrameType::Data,
            FrameType::Meta,
            FrameType::FileMeta,
            FrameType::FileData,
            FrameType::ContentRequest,
        ];
        NAMED
            .get(usize::from(value))
            .copied()
            .unwrap_or(FrameType::Reserved(value))
    }
}
impl From<FrameType> for u8 {
fn from(ft: FrameType) -> Self {
match ft {
FrameType::Control => 0,
FrameType::Data => 1,
FrameType::Meta => 2,
FrameType::FileMeta => 3,
FrameType::FileData => 4,
FrameType::ContentRequest => 5,
FrameType::Reserved(other) => other,
}
}
}
/// Fixed-size header that precedes a frame's payload.
///
/// `Frame::to_bytes` and `Frame::from_bytes` encode and decode these fields
/// in the order they are declared here.
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
pub struct FrameHeader {
/// Stream identifier (encoded as a `u32`).
pub stream_id: StreamId,
/// Operation identifier (encoded as a `u32`).
pub operation_id: OperationId,
/// Flag bits (encoded as one byte).
pub flags: FrameFlags,
/// Frame type tag (encoded as one byte).
pub frame_type: FrameType,
/// Payload length in bytes. `new` and `empty` initialise it to `0`;
/// `Frame::from_bytes` fills it from the decoded header.
pub payload_len: u32,
}
impl FrameHeader {
    /// Creates a header from the given fields with `payload_len` set to `0`.
    pub fn new(
        stream_id: StreamId,
        operation_id: OperationId,
        flags: FrameFlags,
        frame_type: FrameType,
    ) -> Self {
        Self {
            payload_len: 0,
            stream_id,
            operation_id,
            flags,
            frame_type,
        }
    }

    /// Creates a header with stream id `0`, operation id `0`, no flags,
    /// frame type `Data` and `payload_len` `0`.
    pub fn empty() -> Self {
        Self::new(StreamId(0), OperationId(0), FrameFlags::EMPTY, FrameType::Data)
    }

    /// Sets every bit of `flag` in the header flags.
    pub fn set_flag(&mut self, flag: FrameFlags) {
        self.flags.insert(flag);
    }

    /// Clears every bit of `flag` in the header flags.
    pub fn unset_flag(&mut self, flag: FrameFlags) {
        self.flags.remove(flag);
    }

    /// Returns `true` when every bit of `flag` is set in the header flags.
    pub fn has_flag(&self, flag: FrameFlags) -> bool {
        self.flags.contains(flag)
    }

    /// Returns `true` when the `FIN` bit is set.
    pub fn is_fin(&self) -> bool {
        self.has_flag(FrameFlags::FIN)
    }

    /// Returns `true` when the `REQUEST` bit is set.
    pub fn is_request(&self) -> bool {
        self.has_flag(FrameFlags::REQUEST)
    }

    /// Returns `true` when the `RESPONSE` bit is set.
    pub fn is_response(&self) -> bool {
        self.has_flag(FrameFlags::RESPONSE)
    }

    /// Returns `true` when the `FRAGMENTED` bit is set.
    pub fn is_fragmented(&self) -> bool {
        self.has_flag(FrameFlags::FRAGMENTED)
    }

    /// Sets the `REQUEST` bit.
    pub fn set_request(&mut self) {
        self.set_flag(FrameFlags::REQUEST);
    }

    /// Sets the `RESPONSE` bit.
    pub fn set_response(&mut self) {
        self.set_flag(FrameFlags::RESPONSE);
    }

    /// Sets the `FRAGMENTED` bit.
    pub fn set_fragmented(&mut self) {
        self.set_flag(FrameFlags::FRAGMENTED);
    }
}
/// A frame: a [`FrameHeader`] plus the payload bytes it describes.
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
pub struct Frame {
/// Header fields of the frame.
pub header: FrameHeader,
/// Payload bytes that follow the header when the frame is encoded.
pub payload: Bytes,
}
impl Frame {
    /// Creates a frame with an empty header (see `FrameHeader::empty`) and no payload.
    pub fn empty() -> Self {
        Self {
            header: FrameHeader::empty(),
            payload: Bytes::new(),
        }
    }

    /// Returns a new [`FrameBuilder`].
    pub fn builder() -> FrameBuilder {
        FrameBuilder::new()
    }

    /// Encodes the frame as header followed by payload.
    ///
    /// Layout (integers big-endian, as written by `BufMut::put_u32`):
    /// stream id `u32`, operation id `u32`, flags `u8`, frame type `u8`,
    /// payload length `u32`, then the payload bytes. The length written is
    /// the actual `payload.len()`, not `header.payload_len`.
    ///
    /// # Errors
    ///
    /// Returns an error when the payload is longer than `u32::MAX` bytes and
    /// therefore cannot be described by the length field.
    pub fn to_bytes(&self) -> Result<Bytes> {
        // A plain `as u32` cast would silently wrap and emit a length prefix
        // that does not match the bytes that follow it.
        let payload_len = <u32 as std::convert::TryFrom<usize>>::try_from(self.payload.len())
            .map_err(|_| {
                anyhow::anyhow!(
                    "Frame payload of {} bytes does not fit in the u32 length field",
                    self.payload.len()
                )
            })?;

        let mut buf = BytesMut::with_capacity(FRAME_HEADER_LEN + self.payload.len());
        buf.put_u32(self.header.stream_id.0);
        buf.put_u32(self.header.operation_id.0);
        buf.put_u8(self.header.flags.bits());
        buf.put_u8(self.header.frame_type.into());
        buf.put_u32(payload_len);
        buf.extend_from_slice(&self.payload);
        Ok(buf.freeze())
    }

    /// Decodes one frame from the start of `bytes`.
    ///
    /// Unknown flag bits are dropped. The payload is exactly the number of
    /// bytes declared in the header; any bytes after it are ignored.
    ///
    /// # Errors
    ///
    /// Returns an error when `bytes` is shorter than the header, or when it
    /// holds fewer payload bytes than the header declares.
    pub fn from_bytes(mut bytes: Bytes) -> Result<Self> {
        if bytes.len() < FRAME_HEADER_LEN {
            anyhow::bail!("Not enough bytes to read frame header");
        }
        let stream_id = StreamId(bytes.get_u32());
        let operation_id = OperationId(bytes.get_u32());
        let flags = FrameFlags::from_bits_truncate(bytes.get_u8());
        let frame_type = FrameType::from(bytes.get_u8());
        let payload_len = bytes.get_u32();

        let declared = <usize as std::convert::TryFrom<u32>>::try_from(payload_len)
            .map_err(|_| {
                anyhow::anyhow!(
                    "Declared payload length {} does not fit in usize",
                    payload_len
                )
            })?;
        // Without this check a truncated buffer would decode into a frame
        // whose payload disagrees with `header.payload_len`.
        if bytes.len() < declared {
            anyhow::bail!(
                "Not enough bytes to read frame payload: header declares {} but only {} remain",
                declared,
                bytes.len()
            );
        }
        // Take exactly the declared payload so header and payload agree.
        let payload = bytes.split_to(declared);

        Ok(Self {
            header: FrameHeader { stream_id, operation_id, flags, frame_type, payload_len },
            payload,
        })
    }

    /// Encodes the frame like [`Frame::to_bytes`] and copies the result into a `Vec<u8>`.
    ///
    /// # Errors
    ///
    /// Same as [`Frame::to_bytes`].
    pub fn to_byte_array(&self) -> Result<Vec<u8>> {
        self.to_bytes().map(|encoded| encoded.to_vec())
    }

    /// Copies `bytes` and decodes it like [`Frame::from_bytes`].
    ///
    /// # Errors
    ///
    /// Same as [`Frame::from_bytes`].
    pub fn from_byte_array(bytes: &[u8]) -> Result<Self> {
        Self::from_bytes(Bytes::copy_from_slice(bytes))
    }

    /// Returns the encoded size of the frame: header length plus payload length.
    pub fn len(&self) -> usize {
        FRAME_HEADER_LEN + self.payload.len()
    }

    /// Returns the actual payload length in bytes (not `header.payload_len`).
    pub fn payload_len(&self) -> usize {
        self.payload.len()
    }
}
/// Builder that accumulates a header and payload and turns them into a [`Frame`].
pub struct FrameBuilder {
/// Header under construction; `new` starts it from `FrameHeader::empty()`.
pub header: FrameHeader,
/// Payload under construction; `new` starts it empty.
pub payload: Bytes,
}
impl FrameBuilder {
    /// Creates a builder holding `FrameHeader::empty()` and an empty payload.
    pub fn new() -> Self {
        Self {
            header: FrameHeader::empty(),
            payload: Bytes::new(),
        }
    }

    /// Sets the `FIN` bit when `fin` is `true`, clears it otherwise.
    pub fn with_fin(mut self, fin: bool) -> Self {
        self.header.flags.set(FrameFlags::FIN, fin);
        self
    }

    /// Sets the stream id.
    pub fn with_stream_id(mut self, stream_id: StreamId) -> Self {
        self.header.stream_id = stream_id;
        self
    }

    /// Sets the operation id.
    pub fn with_operation_id(mut self, operation_id: OperationId) -> Self {
        self.header.operation_id = operation_id;
        self
    }

    /// Replaces the whole flag set, discarding any bits set earlier.
    pub fn with_flags(mut self, flags: FrameFlags) -> Self {
        self.header.flags = flags;
        self
    }

    /// Sets the frame type.
    pub fn with_frame_type(mut self, frame_type: FrameType) -> Self {
        self.header.frame_type = frame_type;
        self
    }

    /// Adds the `REQUEST` bit to the current flags.
    pub fn as_request(mut self) -> Self {
        self.header.flags.insert(FrameFlags::REQUEST);
        self
    }

    /// Adds the `RESPONSE` bit to the current flags.
    pub fn as_response(mut self) -> Self {
        self.header.flags.insert(FrameFlags::RESPONSE);
        self
    }

    /// Adds the `FRAGMENTED` bit to the current flags.
    pub fn as_fragmented(mut self) -> Self {
        self.header.flags.insert(FrameFlags::FRAGMENTED);
        self
    }

    /// Sets the payload.
    pub fn with_payload(mut self, payload: Bytes) -> Self {
        self.payload = payload;
        self
    }

    /// Consumes the builder and returns the frame.
    ///
    /// `header.payload_len` is set from the payload so that the header
    /// describes the bytes it carries; otherwise it would stay at the `0`
    /// from `FrameHeader::empty()` and a built frame would not compare equal
    /// to the same frame after an encode/decode round trip. Payloads longer
    /// than `u32::MAX` bytes saturate the field at `u32::MAX`.
    pub fn build(self) -> Frame {
        let Self { mut header, payload } = self;
        header.payload_len =
            <u32 as std::convert::TryFrom<usize>>::try_from(payload.len()).unwrap_or(u32::MAX);
        Frame { header, payload }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;

    /// Encoding a frame and decoding the result preserves the header fields
    /// and the payload.
    #[test]
    fn test_frame_to_bytes_and_back() {
        let payload = Bytes::from_static(b"test-data");
        let original = Frame::builder()
            .with_frame_type(FrameType::Data)
            .with_flags(FrameFlags::FIN)
            .with_operation_id(OperationId(456))
            .with_stream_id(StreamId(123))
            .with_payload(payload)
            .build();

        let encoded = original.to_bytes().expect("to_bytes failed");
        let decoded = Frame::from_bytes(encoded).expect("from_bytes failed");

        let (sent, received) = (&original.header, &decoded.header);
        assert_eq!(sent.stream_id, received.stream_id);
        assert_eq!(sent.operation_id, received.operation_id);
        assert_eq!(sent.flags, received.flags);
        assert_eq!(sent.frame_type, received.frame_type);
        assert_eq!(original.payload, decoded.payload);
    }

    /// Each open/close constructor satisfies its exact-match predicate and
    /// the single-bit predicates for the bits it sets.
    #[test]
    fn test_frame_flags_open_close() {
        {
            let flags = FrameFlags::open_request();
            assert!(flags.is_open_request());
            assert!(flags.is_open());
            assert!(flags.is_request());
        }
        {
            let flags = FrameFlags::open_response();
            assert!(flags.is_open_response());
            assert!(flags.is_open());
            assert!(flags.is_response());
        }
        {
            let flags = FrameFlags::open_reset();
            assert!(flags.is_open_reset());
            assert!(flags.is_open());
            assert!(flags.is_response());
            assert!(flags.is_reset());
        }
        {
            let flags = FrameFlags::close_request();
            assert!(flags.is_close_request());
            assert!(flags.is_request());
        }
        {
            let flags = FrameFlags::close_response();
            assert!(flags.is_close_response());
            assert!(flags.is_response());
        }
    }
}