use crate::error::{Error, ErrorCode};
/// Lifecycle state of a single HTTP/2 stream.
///
/// The variants correspond to the stream states of RFC 7540 §5.1 (carried
/// over into RFC 9113 §5.1). `StreamState::default()` returns `Idle`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum StreamState {
/// Nothing has been exchanged on the stream yet; the initial state.
/// Either side may open it with HEADERS.
#[default]
Idle,
/// Reserved by this endpoint. Only a HEADERS frame sent by this endpoint
/// (or an RST_STREAM) moves the stream out of this state.
ReservedLocal,
/// Reserved by the peer. Only a HEADERS frame received from the peer
/// (or an RST_STREAM) moves the stream out of this state.
ReservedRemote,
/// Both endpoints may send HEADERS and DATA.
Open,
/// This endpoint has sent END_STREAM. It may still receive HEADERS/DATA
/// but may not send them.
HalfClosedLocal,
/// The peer has sent END_STREAM. This endpoint may still send HEADERS/DATA
/// but will reject any it receives.
HalfClosedRemote,
/// Terminal state, reached by END_STREAM in both directions or by
/// RST_STREAM. All HEADERS and DATA are rejected.
Closed,
}
impl StreamState {
#[must_use]
pub const fn is_open(&self) -> bool {
matches!(self, Self::Open)
}
#[must_use]
pub const fn is_closed(&self) -> bool {
matches!(self, Self::Closed)
}
#[must_use]
pub const fn is_idle(&self) -> bool {
matches!(self, Self::Idle)
}
#[must_use]
pub const fn can_send(&self) -> bool {
matches!(self, Self::Open | Self::HalfClosedRemote)
}
#[must_use]
pub const fn can_recv(&self) -> bool {
matches!(self, Self::Open | Self::HalfClosedLocal)
}
}
/// Tracks the [`StreamState`] of one HTTP/2 stream and validates the
/// transitions caused by HEADERS, DATA and RST_STREAM frames in both
/// directions.
#[derive(Debug, Clone)]
pub struct StateMachine {
/// Current state. Starts as `Idle` (see `StateMachine::new`). The field is
/// private, so only code in this module can change it.
state: StreamState,
}
impl StateMachine {
#[must_use]
pub const fn new() -> Self {
Self {
state: StreamState::Idle,
}
}
#[must_use]
pub const fn state(&self) -> StreamState {
self.state
}
pub fn send_headers(&mut self, end_stream: bool) -> Result<(), Error> {
self.state = match self.state {
StreamState::Idle => {
if end_stream {
StreamState::HalfClosedLocal
} else {
StreamState::Open
}
}
StreamState::ReservedLocal => {
if end_stream {
StreamState::Closed
} else {
StreamState::HalfClosedRemote
}
}
StreamState::Open => {
if end_stream {
StreamState::HalfClosedLocal
} else {
StreamState::Open
}
}
StreamState::HalfClosedRemote => {
if end_stream {
StreamState::Closed
} else {
StreamState::HalfClosedRemote
}
}
_ => {
return Err(Error::stream_error(
ErrorCode::StreamClosed,
"cannot send HEADERS in current state",
));
}
};
Ok(())
}
pub fn recv_headers(&mut self, end_stream: bool) -> Result<(), Error> {
self.state = match self.state {
StreamState::Idle => {
if end_stream {
StreamState::HalfClosedRemote
} else {
StreamState::Open
}
}
StreamState::ReservedRemote => {
if end_stream {
StreamState::Closed
} else {
StreamState::HalfClosedLocal
}
}
StreamState::Open | StreamState::HalfClosedLocal => {
if end_stream {
if self.state == StreamState::HalfClosedLocal {
StreamState::Closed
} else {
StreamState::HalfClosedRemote
}
} else {
self.state
}
}
_ => {
return Err(Error::stream_error(
ErrorCode::StreamClosed,
"cannot receive HEADERS in current state",
));
}
};
Ok(())
}
pub fn send_data(&mut self, end_stream: bool) -> Result<(), Error> {
match self.state {
StreamState::Open => {
if end_stream {
self.state = StreamState::HalfClosedLocal;
}
Ok(())
}
StreamState::HalfClosedRemote => {
if end_stream {
self.state = StreamState::Closed;
}
Ok(())
}
_ => Err(Error::stream_error(
ErrorCode::StreamClosed,
"cannot send DATA in current state",
)),
}
}
pub fn recv_data(&mut self, end_stream: bool) -> Result<(), Error> {
match self.state {
StreamState::Open => {
if end_stream {
self.state = StreamState::HalfClosedRemote;
}
Ok(())
}
StreamState::HalfClosedLocal => {
if end_stream {
self.state = StreamState::Closed;
}
Ok(())
}
_ => Err(Error::stream_error(
ErrorCode::StreamClosed,
"cannot receive DATA in current state",
)),
}
}
pub fn send_rst_stream(&mut self) {
self.state = StreamState::Closed;
}
pub fn recv_rst_stream(&mut self) {
self.state = StreamState::Closed;
}
#[must_use]
pub const fn sent_end_stream(&self) -> bool {
matches!(
self.state,
StreamState::HalfClosedLocal | StreamState::Closed
)
}
#[must_use]
pub const fn received_end_stream(&self) -> bool {
matches!(
self.state,
StreamState::HalfClosedRemote | StreamState::Closed
)
}
}
impl Default for StateMachine {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_idle_to_open() {
        let mut sm = StateMachine::new();
        assert_eq!(sm.state(), StreamState::Idle);
        sm.send_headers(false).unwrap();
        assert_eq!(sm.state(), StreamState::Open);
    }

    #[test]
    fn test_idle_to_half_closed_local() {
        let mut sm = StateMachine::new();
        sm.send_headers(true).unwrap();
        assert_eq!(sm.state(), StreamState::HalfClosedLocal);
    }

    #[test]
    fn test_open_to_half_closed_local() {
        let mut sm = StateMachine::new();
        sm.send_headers(false).unwrap();
        sm.send_data(true).unwrap();
        assert_eq!(sm.state(), StreamState::HalfClosedLocal);
    }

    #[test]
    fn test_open_to_half_closed_remote() {
        let mut sm = StateMachine::new();
        sm.recv_headers(false).unwrap();
        sm.recv_data(true).unwrap();
        assert_eq!(sm.state(), StreamState::HalfClosedRemote);
    }

    #[test]
    fn test_half_closed_to_closed() {
        let mut sm = StateMachine::new();
        sm.send_headers(false).unwrap();
        sm.send_data(true).unwrap();
        assert_eq!(sm.state(), StreamState::HalfClosedLocal);
        sm.recv_data(true).unwrap();
        assert_eq!(sm.state(), StreamState::Closed);
    }

    #[test]
    fn test_rst_stream_closes() {
        let mut sm = StateMachine::new();
        sm.send_headers(false).unwrap();
        sm.send_rst_stream();
        assert_eq!(sm.state(), StreamState::Closed);
    }

    #[test]
    fn test_server_response_from_half_closed_remote() {
        let mut sm = StateMachine::new();
        sm.recv_headers(true).unwrap();
        assert_eq!(sm.state(), StreamState::HalfClosedRemote);
        sm.send_headers(true).unwrap();
        assert_eq!(sm.state(), StreamState::Closed);
    }

    #[test]
    fn test_server_response_with_body_from_half_closed_remote() {
        let mut sm = StateMachine::new();
        sm.recv_headers(true).unwrap();
        assert_eq!(sm.state(), StreamState::HalfClosedRemote);
        sm.send_headers(false).unwrap();
        assert_eq!(sm.state(), StreamState::HalfClosedRemote);
        sm.send_data(true).unwrap();
        assert_eq!(sm.state(), StreamState::Closed);
    }

    #[test]
    fn test_trailer_headers_from_open() {
        let mut sm = StateMachine::new();
        sm.send_headers(false).unwrap();
        assert_eq!(sm.state(), StreamState::Open);
        sm.send_headers(true).unwrap();
        assert_eq!(sm.state(), StreamState::HalfClosedLocal);
    }

    /// Receive-side mirror of `test_trailer_headers_from_open`.
    #[test]
    fn test_received_trailer_headers_from_open() {
        let mut sm = StateMachine::new();
        sm.recv_headers(false).unwrap();
        assert_eq!(sm.state(), StreamState::Open);
        sm.recv_headers(true).unwrap();
        assert_eq!(sm.state(), StreamState::HalfClosedRemote);
    }

    /// Invalid frames must be rejected and must not disturb the current state.
    #[test]
    fn test_invalid_transitions_are_rejected_without_changing_state() {
        // DATA is not allowed before HEADERS has opened the stream.
        let mut sm = StateMachine::new();
        assert!(sm.send_data(false).is_err());
        assert!(sm.recv_data(false).is_err());
        assert_eq!(sm.state(), StreamState::Idle);

        // After sending END_STREAM, this endpoint can no longer send.
        let mut sm = StateMachine::new();
        sm.send_headers(true).unwrap();
        assert!(sm.send_headers(false).is_err());
        assert!(sm.send_data(true).is_err());
        assert_eq!(sm.state(), StreamState::HalfClosedLocal);

        // After receiving END_STREAM, the peer can no longer send.
        let mut sm = StateMachine::new();
        sm.recv_headers(true).unwrap();
        assert!(sm.recv_headers(false).is_err());
        assert!(sm.recv_data(true).is_err());
        assert_eq!(sm.state(), StreamState::HalfClosedRemote);
    }

    #[test]
    fn test_closed_stream_rejects_all_frames() {
        let mut sm = StateMachine::new();
        sm.send_headers(false).unwrap();
        sm.recv_rst_stream();
        assert_eq!(sm.state(), StreamState::Closed);
        assert!(sm.send_headers(false).is_err());
        assert!(sm.recv_headers(false).is_err());
        assert!(sm.send_data(false).is_err());
        assert!(sm.recv_data(false).is_err());
        assert_eq!(sm.state(), StreamState::Closed);
    }

    /// Reserved states are not reachable through the public API, so they are
    /// constructed directly (the private field is visible to this child module).
    #[test]
    fn test_reserved_states_transition_on_headers() {
        let mut local = StateMachine {
            state: StreamState::ReservedLocal,
        };
        local.send_headers(false).unwrap();
        assert_eq!(local.state(), StreamState::HalfClosedRemote);
        local.send_data(true).unwrap();
        assert_eq!(local.state(), StreamState::Closed);

        let mut remote = StateMachine {
            state: StreamState::ReservedRemote,
        };
        remote.recv_headers(true).unwrap();
        assert_eq!(remote.state(), StreamState::Closed);

        // Only the reserving side may open a reserved stream with HEADERS.
        let mut wrong_local = StateMachine {
            state: StreamState::ReservedLocal,
        };
        assert!(wrong_local.recv_headers(false).is_err());
        assert!(wrong_local.send_data(false).is_err());
        assert_eq!(wrong_local.state(), StreamState::ReservedLocal);

        let mut wrong_remote = StateMachine {
            state: StreamState::ReservedRemote,
        };
        assert!(wrong_remote.send_headers(false).is_err());
        assert_eq!(wrong_remote.state(), StreamState::ReservedRemote);
    }

    #[test]
    fn test_end_stream_queries() {
        let mut sm = StateMachine::new();
        assert!(!sm.sent_end_stream());
        assert!(!sm.received_end_stream());

        sm.recv_headers(false).unwrap();
        sm.send_data(true).unwrap();
        assert!(sm.sent_end_stream());
        assert!(!sm.received_end_stream());

        sm.recv_data(true).unwrap();
        assert!(sm.sent_end_stream());
        assert!(sm.received_end_stream());
    }

    #[test]
    fn test_state_predicates_and_defaults() {
        assert!(StreamState::Open.is_open());
        assert!(StreamState::Closed.is_closed());
        assert!(StreamState::Open.can_send() && StreamState::Open.can_recv());
        assert!(StreamState::HalfClosedRemote.can_send());
        assert!(!StreamState::HalfClosedRemote.can_recv());
        assert!(StreamState::HalfClosedLocal.can_recv());
        assert!(!StreamState::HalfClosedLocal.can_send());
        assert!(StreamState::default().is_idle());
        assert!(StateMachine::default().state().is_idle());
    }
}