use crate::error::{Error, ErrorCode};
use crate::frame::{self, Frame, GoawayPayload, SettingsPayload};
use crate::settings::Settings;
use crate::varint::VarInt;
use super::{RecvBuffer, SendBuffer};
/// Progress of the local (sending) control stream.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ControlSendState {
/// Nothing queued yet; the stream type and SETTINGS frame are still pending.
Initial,
/// The stream type and SETTINGS frame have been queued; other frames may follow.
Ready,
}
/// Progress of parsing the peer's (receiving) control stream.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ControlRecvState {
/// The unidirectional stream-type prefix has not been consumed yet.
WaitingType,
/// The stream type is known; the next frame must be SETTINGS.
WaitingSettings,
/// SETTINGS has been received; subsequent control frames are passed through.
Ready,
}
#[derive(Debug)]
pub(crate) struct ControlStreamSend {
stream_id: Option<u64>,
send_buf: SendBuffer,
send_state: ControlSendState,
}
impl ControlStreamSend {
pub fn new() -> Self {
Self {
stream_id: None,
send_buf: SendBuffer::new(),
send_state: ControlSendState::Initial,
}
}
pub fn set_stream_id(&mut self, stream_id: u64) {
self.stream_id = Some(stream_id);
}
pub fn stream_id(&self) -> Option<u64> {
self.stream_id
}
pub fn send_settings(&mut self, settings: &Settings) {
if self.send_state != ControlSendState::Initial {
return;
}
let mut buf = vec![0x00];
let payload = SettingsPayload::from_settings(settings);
let frame = Frame::Settings(payload);
let frame_len =
frame::encoded_frame_len(&frame).expect("SETTINGS frame fields fit in VarInt");
buf.resize(1 + frame_len, 0);
let written =
frame::encode_frame(&mut buf[1..], &frame).expect("encoded_frame_len validated above");
debug_assert_eq!(written, frame_len);
self.send_buf.push(&buf);
self.send_state = ControlSendState::Ready;
}
pub fn send_goaway(&mut self, id: VarInt) -> Result<(), Error> {
if self.send_state != ControlSendState::Ready {
return Err(crate::error::Error::ConnectionError(
crate::error::ErrorCode::ClosedCriticalStream,
));
}
let frame = Frame::Goaway(GoawayPayload::new(id));
let frame_len = frame::encoded_frame_len(&frame)
.expect("GOAWAY id is VarInt typed, payload always fits");
let mut buf = vec![0u8; frame_len];
let written =
frame::encode_frame(&mut buf, &frame).expect("encoded_frame_len validated above");
debug_assert_eq!(written, frame_len);
self.send_buf.push(&buf);
Ok(())
}
pub fn get_data(&self) -> &[u8] {
self.send_buf.peek()
}
pub fn consume_data(&mut self, len: usize) {
self.send_buf.consume(len);
}
pub fn has_pending(&self) -> bool {
self.send_buf.has_pending()
}
}
impl Default for ControlStreamSend {
fn default() -> Self {
Self::new()
}
}
#[derive(Debug)]
pub(crate) struct ControlStreamRecv {
stream_id: Option<u64>,
recv_buf: RecvBuffer,
recv_state: ControlRecvState,
peer_settings: Option<Settings>,
}
impl ControlStreamRecv {
pub fn new() -> Self {
Self {
stream_id: None,
recv_buf: RecvBuffer::new(),
recv_state: ControlRecvState::WaitingType,
peer_settings: None,
}
}
pub fn set_stream_id(&mut self, stream_id: u64) {
self.stream_id = Some(stream_id);
}
pub fn stream_id(&self) -> Option<u64> {
self.stream_id
}
pub fn skip_stream_type(&mut self) {
if self.recv_state == ControlRecvState::WaitingType {
self.recv_state = ControlRecvState::WaitingSettings;
}
}
pub fn receive(&mut self, data: &[u8]) {
self.recv_buf.push(data);
}
pub fn has_pending_data(&self) -> bool {
!self.recv_buf.peek().is_empty()
}
pub fn process(&mut self) -> Result<Option<Frame>, Error> {
loop {
let data = self.recv_buf.peek();
if data.is_empty() {
return Ok(None);
}
match self.recv_state {
ControlRecvState::WaitingType => {
if data.is_empty() {
return Ok(None);
}
if data[0] != 0x00 {
return Err(Error::ConnectionError(ErrorCode::StreamCreationError));
}
self.recv_buf.consume(1);
self.recv_state = ControlRecvState::WaitingSettings;
}
ControlRecvState::WaitingSettings | ControlRecvState::Ready => {
let header = match frame::decode_frame_header(data) {
Ok(h) => h,
Err(crate::error::FrameDecodeError::BufferTooShort) => return Ok(None),
Err(crate::error::FrameDecodeError::Http2Frame(_)) => {
return Err(Error::ConnectionError(ErrorCode::FrameUnexpected));
}
Err(crate::error::FrameDecodeError::InvalidLength) => {
return Err(Error::ConnectionError(ErrorCode::FrameError));
}
Err(e) => return Err(Error::FrameDecode(e)),
};
let Some(total_len) = header.total_len() else {
return Err(Error::ConnectionError(ErrorCode::FrameError));
};
if data.len() < total_len {
return Ok(None);
}
let (frame, consumed) = frame::decode_frame(data).map_err(|e| match e {
crate::error::FrameDecodeError::ServerPushNotSupported(_) => {
Error::ConnectionError(ErrorCode::FrameUnexpected)
}
crate::error::FrameDecodeError::InvalidSetting(_) => {
Error::ConnectionError(ErrorCode::SettingsError)
}
crate::error::FrameDecodeError::InvalidLength => {
Error::ConnectionError(ErrorCode::FrameError)
}
other => Error::FrameDecode(other),
})?;
self.recv_buf.consume(consumed);
if self.recv_state == ControlRecvState::WaitingSettings {
if let Frame::Settings(ref payload) = frame {
self.peer_settings = Some(Settings::from_payload(payload));
self.recv_state = ControlRecvState::Ready;
} else {
return Err(Error::ConnectionError(ErrorCode::MissingSettings));
}
} else {
match &frame {
Frame::Settings(_) => {
return Err(Error::ConnectionError(ErrorCode::FrameUnexpected));
}
Frame::Data(_) | Frame::Headers(_) => {
return Err(Error::ConnectionError(ErrorCode::FrameUnexpected));
}
_ => {}
}
}
return Ok(Some(frame));
}
}
}
}
#[cfg(test)]
pub fn is_ready(&self) -> bool {
self.recv_state == ControlRecvState::Ready
}
}
impl Default for ControlStreamRecv {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_control_stream_send() {
        use crate::varint::VarInt;
        let mut stream = ControlStreamSend::new();
        let settings = Settings::new().max_field_section_size(VarInt::from_static(16384));
        stream.send_settings(&settings);
        assert!(stream.has_pending());
        let data = stream.get_data();
        assert!(!data.is_empty());
        // The first byte on the wire is the control stream type.
        assert_eq!(data[0], 0x00);
    }

    #[test]
    fn test_control_stream_send_settings_only_once() {
        let mut stream = ControlStreamSend::new();
        let settings = Settings::new();
        stream.send_settings(&settings);
        let queued = stream.get_data().len();
        stream.send_settings(&settings);
        assert_eq!(stream.get_data().len(), queued);
    }

    #[test]
    fn test_control_stream_send_goaway_before_settings_is_error() {
        use crate::varint::VarInt;
        let mut stream = ControlStreamSend::new();
        let result = stream.send_goaway(VarInt::from_static(0));
        assert!(
            matches!(
                result,
                Err(Error::ConnectionError(ErrorCode::ClosedCriticalStream))
            ),
            "expected ConnectionError(ClosedCriticalStream), got {result:?}"
        );
        stream.send_settings(&Settings::new());
        assert!(stream.send_goaway(VarInt::from_static(0)).is_ok());
    }

    #[test]
    fn test_control_stream_recv() {
        let mut stream = ControlStreamRecv::new();
        let data = [0x00, 0x04, 0x00];
        stream.receive(&data);
        let frame = stream.process().unwrap().unwrap();
        assert!(matches!(frame, Frame::Settings(_)));
        assert!(stream.is_ready());
    }

    #[test]
    fn test_control_stream_recv_split_stream_type() {
        let mut stream = ControlStreamRecv::new();
        stream.receive(&[0x00]);
        assert!(stream.process().unwrap().is_none());
        assert!(!stream.is_ready());
        stream.receive(&[0x04, 0x00]);
        let frame = stream.process().unwrap().unwrap();
        assert!(matches!(frame, Frame::Settings(_)));
        assert!(stream.is_ready());
    }

    #[test]
    fn test_control_stream_recv_skip_stream_type() {
        let mut stream = ControlStreamRecv::new();
        stream.skip_stream_type();
        stream.receive(&[0x04, 0x00]);
        let frame = stream.process().unwrap().unwrap();
        assert!(matches!(frame, Frame::Settings(_)));
        assert!(stream.is_ready());
    }

    #[test]
    fn test_control_stream_recv_wrong_stream_type() {
        let mut stream = ControlStreamRecv::new();
        stream.receive(&[0x01]);
        let result = stream.process();
        assert!(
            matches!(
                result,
                Err(Error::ConnectionError(ErrorCode::StreamCreationError))
            ),
            "expected ConnectionError(StreamCreationError), got {result:?}"
        );
    }

    #[test]
    fn test_control_stream_recv_first_frame_must_be_settings() {
        let mut stream = ControlStreamRecv::new();
        // Stream type 0x00, then a DATA frame (type 0x00, length 3).
        stream.receive(&[0x00, 0x00, 0x03, 0x01, 0x02, 0x03]);
        let result = stream.process();
        assert!(
            matches!(
                result,
                Err(Error::ConnectionError(ErrorCode::MissingSettings))
            ),
            "expected ConnectionError(MissingSettings), got {result:?}"
        );
    }

    #[test]
    fn test_control_stream_recv_duplicate_settings() {
        let mut stream = ControlStreamRecv::new();
        let data = [0x00, 0x04, 0x00];
        stream.receive(&data);
        stream.process().unwrap().unwrap();
        assert!(stream.is_ready());
        let data = [0x04, 0x00];
        stream.receive(&data);
        let result = stream.process();
        assert!(
            matches!(
                result,
                Err(Error::ConnectionError(ErrorCode::FrameUnexpected))
            ),
            "expected ConnectionError(FrameUnexpected), got {result:?}"
        );
    }

    #[test]
    fn test_control_stream_recv_data_frame_is_error() {
        let mut stream = ControlStreamRecv::new();
        let data = [0x00, 0x04, 0x00];
        stream.receive(&data);
        stream.process().unwrap().unwrap();
        let data = [0x00, 0x03, 0x01, 0x02, 0x03];
        stream.receive(&data);
        let result = stream.process();
        assert!(
            matches!(
                result,
                Err(Error::ConnectionError(ErrorCode::FrameUnexpected))
            ),
            "expected ConnectionError(FrameUnexpected), got {result:?}"
        );
    }

    #[test]
    fn test_control_stream_recv_http2_frame_is_error() {
        let mut stream = ControlStreamRecv::new();
        let data = [0x00, 0x04, 0x00];
        stream.receive(&data);
        stream.process().unwrap().unwrap();
        let data = [0x06, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00];
        stream.receive(&data);
        let result = stream.process();
        assert!(
            matches!(
                result,
                Err(Error::ConnectionError(ErrorCode::FrameUnexpected))
            ),
            "expected ConnectionError(FrameUnexpected), got {result:?}"
        );
    }

    #[test]
    fn test_control_stream_recv_http2_settings_id() {
        let mut stream = ControlStreamRecv::new();
        let data = [0x00, 0x04, 0x02, 0x02, 0x01];
        stream.receive(&data);
        let result = stream.process();
        assert!(
            matches!(
                result,
                Err(Error::ConnectionError(ErrorCode::SettingsError))
            ),
            "expected ConnectionError(SettingsError), got {result:?}"
        );
    }
}