use super::types::ConstrainedError;
use std::fmt;
use std::time::{Duration, Instant};
/// Lifecycle state of a TCP-style connection.
///
/// Transitions between states are driven by `StateMachine`; the predicates on
/// this type describe what each state permits. The default is
/// [`ConnectionState::Closed`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum ConnectionState {
    /// No connection; the initial state.
    #[default]
    Closed,
    /// The local side initiated the handshake and awaits the peer's reply.
    SynSent,
    /// A handshake request was received from the peer; awaiting its ACK.
    SynReceived,
    /// Handshake complete; data may flow in both directions.
    Established,
    /// The local side requested close; data may still be sent and received.
    FinWait,
    /// Shutdown in progress; data may still be received but not sent.
    Closing,
    /// Shutdown finished; lingering before returning to `Closed`.
    TimeWait,
}

impl ConnectionState {
    /// Returns `true` if application data may be sent (`Established`, `FinWait`).
    pub const fn can_send_data(&self) -> bool {
        matches!(self, Self::Established | Self::FinWait)
    }

    /// Returns `true` if application data may be received
    /// (`Established`, `FinWait`, `Closing`).
    pub const fn can_receive_data(&self) -> bool {
        matches!(self, Self::Established | Self::FinWait | Self::Closing)
    }

    /// Returns `true` while the connection is being set up, is established, or
    /// is shutting down. `Closed` and `TimeWait` are not considered open.
    pub const fn is_open(&self) -> bool {
        matches!(
            self,
            Self::SynSent | Self::SynReceived | Self::Established | Self::FinWait | Self::Closing
        )
    }

    /// Returns `true` for `Closed` and `TimeWait`, the states in which no data
    /// can be sent or received.
    pub const fn is_closed(&self) -> bool {
        matches!(self, Self::Closed | Self::TimeWait)
    }

    /// Returns `true` only for `Established`.
    pub const fn is_established(&self) -> bool {
        matches!(self, Self::Established)
    }

    /// Maximum time the connection may stay in this state before it is
    /// considered timed out.
    ///
    /// `Closed` returns `Duration::MAX`, i.e. it effectively never times out.
    pub const fn timeout(&self) -> Duration {
        match self {
            Self::Closed => Duration::MAX,
            Self::SynSent | Self::SynReceived => Duration::from_secs(5),
            Self::Established => Duration::from_secs(300),
            Self::FinWait | Self::Closing => Duration::from_secs(30),
            Self::TimeWait => Duration::from_secs(4),
        }
    }
}

impl fmt::Display for ConnectionState {
    /// Writes the upper-case state name (e.g. `SYN_SENT`).
    ///
    /// Uses `Formatter::pad` so width, fill, alignment and precision flags
    /// such as `{:<12}` are honoured (plain `write!` would ignore them).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            Self::Closed => "CLOSED",
            Self::SynSent => "SYN_SENT",
            Self::SynReceived => "SYN_RCVD",
            Self::Established => "ESTABLISHED",
            Self::FinWait => "FIN_WAIT",
            Self::Closing => "CLOSING",
            Self::TimeWait => "TIME_WAIT",
        };
        f.pad(name)
    }
}
/// Input event fed to `StateMachine::transition`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum StateEvent {
    /// Local request to open a connection.
    Open,
    /// A SYN was received from the peer.
    RecvSyn,
    /// A SYN-ACK was received from the peer.
    RecvSynAck,
    /// An ACK was received from the peer.
    RecvAck,
    /// A FIN was received from the peer.
    RecvFin,
    /// A RST (reset) was received from the peer.
    RecvRst,
    /// Local request to close the connection.
    Close,
    /// A timer expired; presumably raised by the caller once
    /// `StateMachine::is_timed_out` reports `true`.
    Timeout,
}

impl fmt::Display for StateEvent {
    /// Writes the upper-case event name (e.g. `RECV_SYN_ACK`).
    ///
    /// Uses `Formatter::pad` so width, fill, alignment and precision flags
    /// are honoured (plain `write!` would ignore them).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            Self::Open => "OPEN",
            Self::RecvSyn => "RECV_SYN",
            Self::RecvSynAck => "RECV_SYN_ACK",
            Self::RecvAck => "RECV_ACK",
            Self::RecvFin => "RECV_FIN",
            Self::RecvRst => "RECV_RST",
            Self::Close => "CLOSE",
            Self::Timeout => "TIMEOUT",
        };
        f.pad(name)
    }
}
/// Tracks the current `ConnectionState`, when it was entered, and a short,
/// bounded log of recent transitions.
#[derive(Debug)]
pub struct StateMachine {
    /// Current state; `new` starts it at `Closed`.
    state: ConnectionState,
    /// Instant the current state was entered; reset on every successful transition.
    state_entered: Instant,
    /// Recent transitions as `(from, event, to)`, oldest first; length is capped.
    history: Vec<(ConnectionState, StateEvent, ConnectionState)>,
}
impl StateMachine {
    /// Maximum number of transitions retained in the history; once reached,
    /// the oldest entry is evicted before a new one is recorded.
    const HISTORY_CAPACITY: usize = 8;

    /// Creates a machine in `ConnectionState::Closed` with an empty history.
    pub fn new() -> Self {
        Self {
            state: ConnectionState::Closed,
            state_entered: Instant::now(),
            history: Vec::with_capacity(Self::HISTORY_CAPACITY),
        }
    }

    /// Returns the current state.
    pub fn state(&self) -> ConnectionState {
        self.state
    }

    /// Time elapsed since the current state was entered.
    pub fn time_in_state(&self) -> Duration {
        self.state_entered.elapsed()
    }

    /// Returns `true` once `time_in_state` strictly exceeds the current state's
    /// `ConnectionState::timeout`. Never true in `Closed`, whose timeout is
    /// `Duration::MAX`.
    pub fn is_timed_out(&self) -> bool {
        self.time_in_state() > self.state.timeout()
    }

    /// Whether application data may be sent in the current state.
    pub fn can_send_data(&self) -> bool {
        self.state.can_send_data()
    }

    /// Whether application data may be received in the current state.
    pub fn can_receive_data(&self) -> bool {
        self.state.can_receive_data()
    }

    /// Applies `event` to the current state and returns the resulting state.
    ///
    /// On success the new state becomes current, the state timer is reset, the
    /// transition is appended to the history (evicting the oldest entry when the
    /// history is full), and a `tracing` trace event is emitted.
    ///
    /// # Errors
    ///
    /// Returns `ConstrainedError::InvalidStateTransition` if `event` is not
    /// accepted in the current state. The machine (state, timer, history) is
    /// left unchanged in that case.
    pub fn transition(&mut self, event: StateEvent) -> Result<ConnectionState, ConstrainedError> {
        let old_state = self.state;
        // Validate first so a rejected event cannot mutate anything.
        let new_state = self.next_state(event)?;

        if self.history.len() >= Self::HISTORY_CAPACITY {
            // Tiny, fixed-size buffer: shifting is cheap and keeps `history()`
            // a contiguous slice in oldest-first order.
            self.history.remove(0);
        }
        self.history.push((old_state, event, new_state));

        self.state = new_state;
        self.state_entered = Instant::now();
        tracing::trace!(
            from = %old_state,
            event = %event,
            to = %new_state,
            "State transition"
        );
        Ok(new_state)
    }

    /// Looks up the successor of the current state for `event` without
    /// mutating the machine.
    ///
    /// # Errors
    ///
    /// Returns `ConstrainedError::InvalidStateTransition` for any
    /// `(state, event)` pair not listed in the transition table.
    fn next_state(&self, event: StateEvent) -> Result<ConnectionState, ConstrainedError> {
        use ConnectionState::*;
        use StateEvent::*;
        let new_state = match (self.state, event) {
            // Opening: active open or incoming handshake.
            (Closed, Open) => SynSent,
            (Closed, RecvSyn) => SynReceived,
            // Handshake in progress.
            (SynSent, RecvSynAck) => Established,
            (SynSent, RecvRst) => Closed,
            (SynSent, Timeout) => Closed,
            (SynSent, Close) => Closed,
            (SynReceived, RecvAck) => Established,
            (SynReceived, RecvRst) => Closed,
            (SynReceived, Timeout) => Closed,
            (SynReceived, Close) => Closed,
            // Connected.
            (Established, RecvFin) => Closing,
            (Established, Close) => FinWait,
            (Established, RecvRst) => Closed,
            (Established, Timeout) => Closed,
            // Shutdown.
            (FinWait, RecvAck) => Closing,
            (FinWait, RecvFin) => TimeWait,
            (FinWait, RecvRst) => Closed,
            (FinWait, Timeout) => Closed,
            (Closing, RecvAck) => TimeWait,
            (Closing, RecvFin) => TimeWait,
            (Closing, RecvRst) => Closed,
            (Closing, Timeout) => Closed,
            (TimeWait, Timeout) => Closed,
            (TimeWait, RecvRst) => Closed,
            // Everything else is rejected rather than silently ignored.
            _ => {
                return Err(ConstrainedError::InvalidStateTransition {
                    from: self.state.to_string(),
                    to: format!("{} -> ?", event),
                });
            }
        };
        Ok(new_state)
    }

    /// Test-only: jumps straight to `state`, bypassing validation and history,
    /// and resets the state timer.
    #[cfg(test)]
    pub fn force_state(&mut self, state: ConnectionState) {
        self.state = state;
        self.state_entered = Instant::now();
    }

    /// Recorded transitions as `(from, event, to)`, oldest first, holding at
    /// most `HISTORY_CAPACITY` entries.
    pub fn history(&self) -> &[(ConnectionState, StateEvent, ConnectionState)] {
        &self.history
    }
}
impl Default for StateMachine {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_state_display() {
        assert_eq!(format!("{}", ConnectionState::Closed), "CLOSED");
        assert_eq!(format!("{}", ConnectionState::Established), "ESTABLISHED");
        assert_eq!(format!("{}", ConnectionState::SynSent), "SYN_SENT");
    }

    #[test]
    fn test_state_properties() {
        assert!(!ConnectionState::Closed.can_send_data());
        assert!(ConnectionState::Established.can_send_data());
        assert!(ConnectionState::FinWait.can_send_data());
        assert!(ConnectionState::Closed.is_closed());
        assert!(ConnectionState::TimeWait.is_closed());
        assert!(!ConnectionState::Established.is_closed());
        assert!(ConnectionState::Established.is_established());
        assert!(!ConnectionState::SynSent.is_established());
    }

    #[test]
    fn test_state_machine_new() {
        let sm = StateMachine::new();
        assert_eq!(sm.state(), ConnectionState::Closed);
    }

    #[test]
    fn test_normal_connection_flow() {
        let mut sm = StateMachine::new();
        assert_eq!(
            sm.transition(StateEvent::Open).unwrap(),
            ConnectionState::SynSent
        );
        assert_eq!(
            sm.transition(StateEvent::RecvSynAck).unwrap(),
            ConnectionState::Established
        );
        assert_eq!(
            sm.transition(StateEvent::Close).unwrap(),
            ConnectionState::FinWait
        );
        assert_eq!(
            sm.transition(StateEvent::RecvFin).unwrap(),
            ConnectionState::TimeWait
        );
        assert_eq!(
            sm.transition(StateEvent::Timeout).unwrap(),
            ConnectionState::Closed
        );
    }

    #[test]
    fn test_responder_flow() {
        let mut sm = StateMachine::new();
        assert_eq!(
            sm.transition(StateEvent::RecvSyn).unwrap(),
            ConnectionState::SynReceived
        );
        assert_eq!(
            sm.transition(StateEvent::RecvAck).unwrap(),
            ConnectionState::Established
        );
    }

    #[test]
    fn test_reset_from_any_state() {
        let mut sm = StateMachine::new();
        sm.transition(StateEvent::Open).unwrap();
        assert_eq!(sm.state(), ConnectionState::SynSent);
        assert_eq!(
            sm.transition(StateEvent::RecvRst).unwrap(),
            ConnectionState::Closed
        );
    }

    #[test]
    fn test_invalid_transition() {
        let mut sm = StateMachine::new();
        let result = sm.transition(StateEvent::RecvSynAck);
        assert!(result.is_err());
        match result {
            Err(ConstrainedError::InvalidStateTransition { from, .. }) => {
                assert_eq!(from, "CLOSED");
            }
            _ => panic!("Expected InvalidStateTransition error"),
        }
    }

    /// A rejected event must not change the state or record history.
    #[test]
    fn test_failed_transition_leaves_machine_unchanged() {
        let mut sm = StateMachine::new();
        assert!(sm.transition(StateEvent::RecvAck).is_err());
        assert_eq!(sm.state(), ConnectionState::Closed);
        assert!(sm.history().is_empty());
    }

    #[test]
    fn test_close_during_handshake() {
        let mut sm = StateMachine::new();
        sm.transition(StateEvent::RecvSyn).unwrap();
        assert_eq!(
            sm.transition(StateEvent::Close).unwrap(),
            ConnectionState::Closed
        );
    }

    #[test]
    fn test_timeout_detection() {
        let sm = StateMachine::new();
        assert!(!sm.is_timed_out());
    }

    #[test]
    fn test_closed_never_times_out() {
        assert_eq!(ConnectionState::Closed.timeout(), Duration::MAX);
    }

    #[test]
    fn test_history_tracking() {
        let mut sm = StateMachine::new();
        sm.transition(StateEvent::Open).unwrap();
        sm.transition(StateEvent::RecvSynAck).unwrap();
        let history = sm.history();
        assert_eq!(history.len(), 2);
        assert_eq!(history[0].0, ConnectionState::Closed);
        assert_eq!(history[0].1, StateEvent::Open);
        assert_eq!(history[0].2, ConnectionState::SynSent);
    }

    /// Ten transitions must leave only the eight most recent, oldest first.
    #[test]
    fn test_history_is_bounded() {
        let mut sm = StateMachine::new();
        for _ in 0..5 {
            sm.transition(StateEvent::Open).unwrap();
            sm.transition(StateEvent::RecvRst).unwrap();
        }
        let history = sm.history();
        assert_eq!(history.len(), 8);
        assert_eq!(
            history[0],
            (ConnectionState::Closed, StateEvent::Open, ConnectionState::SynSent)
        );
        assert_eq!(
            history[7],
            (ConnectionState::SynSent, StateEvent::RecvRst, ConnectionState::Closed)
        );
    }

    #[test]
    fn test_force_state_bypasses_validation() {
        let mut sm = StateMachine::new();
        sm.force_state(ConnectionState::Closing);
        assert_eq!(sm.state(), ConnectionState::Closing);
        assert!(sm.history().is_empty());
        assert_eq!(
            sm.transition(StateEvent::RecvAck).unwrap(),
            ConnectionState::TimeWait
        );
    }

    #[test]
    fn test_event_display() {
        assert_eq!(format!("{}", StateEvent::Open), "OPEN");
        assert_eq!(format!("{}", StateEvent::RecvSyn), "RECV_SYN");
        assert_eq!(format!("{}", StateEvent::Close), "CLOSE");
    }

    #[test]
    fn test_state_timeout_durations() {
        assert!(ConnectionState::SynSent.timeout() < Duration::from_secs(60));
        assert!(ConnectionState::Established.timeout() >= Duration::from_secs(60));
        assert!(ConnectionState::TimeWait.timeout() < Duration::from_secs(60));
    }
}