use crate::error::{Error, Result};
use crate::frame::{
ContinuationFrame, DataFrame, FRAME_HEADER_SIZE, Frame, FrameFlags, FrameHeader, FrameType,
GoawayFrame, HeadersFrame, PingFrame, PriorityFrame, PriorityUpdateFrame, RstStreamFrame,
SettingsFrame, WindowUpdateFrame,
};
/// Accumulates the wire encoding of HTTP/2 frames into a growable byte buffer.
#[derive(Debug)]
pub struct FrameEncoder {
    /// Encoded bytes produced so far, in the order frames were encoded.
    buf: Vec<u8>,
}
impl FrameEncoder {
#[must_use]
pub fn new() -> Self {
Self { buf: Vec::new() }
}
#[must_use]
pub fn buffer(&self) -> &[u8] {
&self.buf
}
pub fn clear(&mut self) {
self.buf.clear();
}
pub fn take(&mut self) -> Vec<u8> {
std::mem::take(&mut self.buf)
}
pub fn encode(&mut self, frame: &Frame) -> Result<()> {
match frame {
Frame::Data(f) => self.encode_data(f),
Frame::Headers(f) => self.encode_headers(f),
Frame::Priority(f) => self.encode_priority(f),
Frame::RstStream(f) => self.encode_rst_stream(f),
Frame::Settings(f) => self.encode_settings(f),
Frame::PushPromise { .. } => {
Err(Error::protocol_error(
"PUSH_PROMISE is not supported and should not be sent",
))
}
Frame::Ping(f) => self.encode_ping(f),
Frame::Goaway(f) => self.encode_goaway(f),
Frame::WindowUpdate(f) => self.encode_window_update(f),
Frame::Continuation(f) => self.encode_continuation(f),
Frame::PriorityUpdate(f) => self.encode_priority_update(f),
Frame::Unknown { header, payload } => self.encode_unknown(header, payload),
}
}
fn encode_header(&mut self, header: &FrameHeader) {
self.buf.push(((header.length >> 16) & 0xff) as u8);
self.buf.push(((header.length >> 8) & 0xff) as u8);
self.buf.push((header.length & 0xff) as u8);
self.buf.push(header.frame_type);
self.buf.push(header.flags.bits());
self.buf.push(((header.stream_id >> 24) & 0x7f) as u8);
self.buf.push(((header.stream_id >> 16) & 0xff) as u8);
self.buf.push(((header.stream_id >> 8) & 0xff) as u8);
self.buf.push((header.stream_id & 0xff) as u8);
}
fn encode_data(&mut self, frame: &DataFrame) -> Result<()> {
let mut flags = FrameFlags::empty();
if frame.end_stream {
flags = flags.set(FrameFlags::END_STREAM);
}
let (length, pad_length) = if let Some(pad_len) = frame.pad_length {
flags = flags.set(FrameFlags::PADDED);
let total = 1 + frame.data.len() as u32 + u32::from(pad_len);
(total, Some(pad_len))
} else {
(frame.data.len() as u32, None)
};
let header = FrameHeader::new(FrameType::Data, flags, frame.stream_id).with_length(length);
self.encode_header(&header);
if let Some(pad_len) = pad_length {
self.buf.push(pad_len);
}
self.buf.extend_from_slice(&frame.data);
if let Some(pad_len) = pad_length {
self.buf.extend(std::iter::repeat_n(0u8, pad_len as usize));
}
Ok(())
}
fn encode_headers(&mut self, frame: &HeadersFrame) -> Result<()> {
let mut flags = FrameFlags::empty();
if frame.end_stream {
flags = flags.set(FrameFlags::END_STREAM);
}
if frame.end_headers {
flags = flags.set(FrameFlags::END_HEADERS);
}
let (length, pad_length) = if let Some(pad_len) = frame.pad_length {
flags = flags.set(FrameFlags::PADDED);
let total = 1 + frame.header_block_fragment.len() as u32 + u32::from(pad_len);
(total, Some(pad_len))
} else {
(frame.header_block_fragment.len() as u32, None)
};
let header =
FrameHeader::new(FrameType::Headers, flags, frame.stream_id).with_length(length);
self.encode_header(&header);
if let Some(pad_len) = pad_length {
self.buf.push(pad_len);
}
self.buf.extend_from_slice(&frame.header_block_fragment);
if let Some(pad_len) = pad_length {
self.buf.extend(std::iter::repeat_n(0u8, pad_len as usize));
}
Ok(())
}
fn encode_priority(&mut self, _frame: &PriorityFrame) -> Result<()> {
Err(Error::protocol_error(
"PRIORITY frame is deprecated in RFC 9113 and should not be sent",
))
}
fn encode_rst_stream(&mut self, frame: &RstStreamFrame) -> Result<()> {
let header = FrameHeader::new(FrameType::RstStream, FrameFlags::empty(), frame.stream_id)
.with_length(4);
self.encode_header(&header);
self.buf.extend_from_slice(&frame.error_code.to_be_bytes());
Ok(())
}
fn encode_settings(&mut self, frame: &SettingsFrame) -> Result<()> {
let mut flags = FrameFlags::empty();
if frame.ack {
flags = flags.set(FrameFlags::ACK);
}
let length = if frame.ack {
0
} else {
(frame.settings.len() * 6) as u32
};
let header = FrameHeader::new(FrameType::Settings, flags, 0).with_length(length);
self.encode_header(&header);
if !frame.ack {
for setting in &frame.settings {
self.buf.extend_from_slice(&setting.id.to_be_bytes());
self.buf.extend_from_slice(&setting.value.to_be_bytes());
}
}
Ok(())
}
fn encode_ping(&mut self, frame: &PingFrame) -> Result<()> {
let mut flags = FrameFlags::empty();
if frame.ack {
flags = flags.set(FrameFlags::ACK);
}
let header = FrameHeader::new(FrameType::Ping, flags, 0).with_length(8);
self.encode_header(&header);
self.buf.extend_from_slice(&frame.opaque_data);
Ok(())
}
fn encode_goaway(&mut self, frame: &GoawayFrame) -> Result<()> {
let length = (8 + frame.debug_data.len()) as u32;
let header =
FrameHeader::new(FrameType::Goaway, FrameFlags::empty(), 0).with_length(length);
self.encode_header(&header);
self.buf.push(((frame.last_stream_id >> 24) & 0x7f) as u8);
self.buf.push(((frame.last_stream_id >> 16) & 0xff) as u8);
self.buf.push(((frame.last_stream_id >> 8) & 0xff) as u8);
self.buf.push((frame.last_stream_id & 0xff) as u8);
self.buf.extend_from_slice(&frame.error_code.to_be_bytes());
self.buf.extend_from_slice(&frame.debug_data);
Ok(())
}
fn encode_window_update(&mut self, frame: &WindowUpdateFrame) -> Result<()> {
let header = FrameHeader::new(
FrameType::WindowUpdate,
FrameFlags::empty(),
frame.stream_id,
)
.with_length(4);
self.encode_header(&header);
self.buf
.push(((frame.window_size_increment >> 24) & 0x7f) as u8);
self.buf
.push(((frame.window_size_increment >> 16) & 0xff) as u8);
self.buf
.push(((frame.window_size_increment >> 8) & 0xff) as u8);
self.buf.push((frame.window_size_increment & 0xff) as u8);
Ok(())
}
fn encode_continuation(&mut self, frame: &ContinuationFrame) -> Result<()> {
let mut flags = FrameFlags::empty();
if frame.end_headers {
flags = flags.set(FrameFlags::END_HEADERS);
}
let length = frame.header_block_fragment.len() as u32;
let header =
FrameHeader::new(FrameType::Continuation, flags, frame.stream_id).with_length(length);
self.encode_header(&header);
self.buf.extend_from_slice(&frame.header_block_fragment);
Ok(())
}
fn encode_priority_update(&mut self, frame: &PriorityUpdateFrame) -> Result<()> {
let length = (4 + frame.priority_field_value.len()) as u32;
let header =
FrameHeader::new(FrameType::PriorityUpdate, FrameFlags::empty(), 0).with_length(length);
self.encode_header(&header);
self.buf
.push(((frame.prioritized_element_id >> 24) & 0x7f) as u8);
self.buf
.push(((frame.prioritized_element_id >> 16) & 0xff) as u8);
self.buf
.push(((frame.prioritized_element_id >> 8) & 0xff) as u8);
self.buf.push((frame.prioritized_element_id & 0xff) as u8);
self.buf.extend_from_slice(&frame.priority_field_value);
Ok(())
}
fn encode_unknown(&mut self, header: &FrameHeader, payload: &[u8]) -> Result<()> {
let header = header.with_length(payload.len() as u32);
self.encode_header(&header);
self.buf.extend_from_slice(payload);
Ok(())
}
}
impl Default for FrameEncoder {
fn default() -> Self {
Self::new()
}
}
pub fn encode_header(buf: &mut [u8], header: &FrameHeader) -> Result<()> {
Error::check_buffer_size(FRAME_HEADER_SIZE, buf)?;
buf[0] = ((header.length >> 16) & 0xff) as u8;
buf[1] = ((header.length >> 8) & 0xff) as u8;
buf[2] = (header.length & 0xff) as u8;
buf[3] = header.frame_type;
buf[4] = header.flags.bits();
buf[5] = ((header.stream_id >> 24) & 0x7f) as u8;
buf[6] = ((header.stream_id >> 16) & 0xff) as u8;
buf[7] = ((header.stream_id >> 8) & 0xff) as u8;
buf[8] = (header.stream_id & 0xff) as u8;
Ok(())
}
pub fn encode_frame(buf: &mut [u8], frame: &Frame) -> Result<usize> {
let mut encoder = FrameEncoder::new();
encoder.encode(frame)?;
let encoded = encoder.buffer();
Error::check_buffer_size(encoded.len(), buf)?;
buf[..encoded.len()].copy_from_slice(encoded);
Ok(encoded.len())
}
/// Encodes `frame` into a freshly allocated byte vector.
///
/// # Errors
///
/// Propagates any error from `FrameEncoder::encode`.
pub fn encode_frame_to_vec(frame: &Frame) -> Result<Vec<u8>> {
    let mut scratch = FrameEncoder::default();
    scratch.encode(frame)?;
    Ok(scratch.buf)
}