#![allow(missing_docs)]
use tokio_tungstenite::tungstenite::protocol::CloseFrame;
/// Lifecycle phase of a WebSocket connection as tracked by [`WebSocketStateMachine`].
#[derive(Clone, Debug, PartialEq)]
pub enum ConnectionState {
    /// No connection attempt has been started yet.
    Initial,
    /// A connection attempt has been started but not yet established.
    Connecting,
    /// The connection is established.
    Connected,
    /// A disconnect was requested locally and the connection has not closed yet.
    Disconnecting,
    /// The connection has closed.
    Disconnected {
        /// Close code and reason, when one was supplied with the close event.
        reason: Option<CloseReason>,
    },
    /// The connection entered an error state.
    Error {
        /// Description of the error that was reported.
        message: String,
    },
}
/// Close status code and accompanying text recorded when a connection closes.
///
/// All fields are `Eq`, so the type derives `Eq` in addition to `PartialEq`.
/// This lets it be used wherever total equality is required.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CloseReason {
    /// Numeric close status code (e.g. `1000` for a normal closure).
    pub code: u16,
    /// Text accompanying the close code.
    pub reason: String,
}
impl From<CloseFrame<'_>> for CloseReason {
fn from(frame: CloseFrame) -> Self {
Self {
code: frame.code.into(),
reason: frame.reason.to_string(),
}
}
}
/// Inputs fed to [`WebSocketStateMachine::handle_event`] to drive state changes.
#[derive(Clone, Debug)]
pub enum StateMachineEvent {
    /// Request that a connection attempt begin.
    StartConnection,
    /// The pending connection attempt succeeded.
    ConnectionEstablished,
    /// A ping was received.
    PingReceived,
    /// A pong was received.
    PongReceived,
    /// Application data was received.
    DataReceived,
    /// Request that the open connection be shut down.
    RequestDisconnect,
    /// The connection closed, optionally carrying its close code and reason.
    ConnectionClosed(Option<CloseReason>),
    /// An error was reported; the payload describes it.
    ErrorOccurred(String),
}
/// Tracks the lifecycle of one WebSocket connection and validates each
/// transition between [`ConnectionState`]s.
///
/// Derives `Debug` so the machine can be logged or embedded in other `Debug` types.
#[derive(Debug)]
pub struct WebSocketStateMachine {
    /// Current lifecycle phase; initialised by `new` and updated by `handle_event`.
    state: ConnectionState,
}
impl WebSocketStateMachine {
    /// Creates a state machine in [`ConnectionState::Initial`].
    pub fn new() -> Self {
        Self {
            state: ConnectionState::Initial,
        }
    }

    /// Returns the current connection state.
    pub fn current_state(&self) -> &ConnectionState {
        &self.state
    }

    /// Applies `event` to the current state.
    ///
    /// Accepted transitions:
    ///
    /// - `Initial` + `StartConnection` → `Connecting`
    /// - `Connecting` + `ConnectionEstablished` → `Connected`
    /// - `Connected` + `PingReceived` / `PongReceived` / `DataReceived` → `Connected`
    /// - `Connected` + `RequestDisconnect` → `Disconnecting`
    /// - `Disconnecting` + `PingReceived` / `PongReceived` / `DataReceived` → `Disconnecting`
    ///   (the peer may keep sending until it answers our Close frame)
    /// - `Connected` or `Disconnecting` + `ConnectionClosed(reason)` → `Disconnected { reason }`
    /// - any state + `ErrorOccurred(message)` → `Error { message }`
    ///
    /// # Errors
    ///
    /// Returns a message naming the current state and the rejected event if
    /// `event` is not valid in the current state. In that case the state is
    /// left unchanged.
    pub fn handle_event(&mut self, event: StateMachineEvent) -> Result<(), String> {
        use ConnectionState::*;
        use StateMachineEvent::*;
        // Match on the event by value: payloads (close reason, error message)
        // move straight into the new state instead of being cloned per call.
        let new_state = match (&self.state, event) {
            (Initial, StartConnection) => Connecting,
            (Connecting, ConnectionEstablished) => Connected,
            (Connected, PingReceived | PongReceived | DataReceived) => Connected,
            (Connected, RequestDisconnect) => Disconnecting,
            // RFC 6455 §5.5.1: after we send Close, frames already in flight
            // from the peer can still arrive until its own Close is received.
            (Disconnecting, PingReceived | PongReceived | DataReceived) => Disconnecting,
            (Connected | Disconnecting, ConnectionClosed(reason)) => Disconnected { reason },
            (_, ErrorOccurred(message)) => Error { message },
            (state, unexpected) => {
                return Err(format!(
                    "Invalid state transition from {state:?} with event {unexpected:?}"
                ));
            }
        };
        self.state = new_state;
        Ok(())
    }

    /// Returns `true` only in [`ConnectionState::Connected`]; in particular
    /// it is `false` while `Disconnecting`.
    pub fn can_send_data(&self) -> bool {
        matches!(self.state, ConnectionState::Connected)
    }

    /// Returns `true` while a connection attempt is in progress.
    pub fn is_connecting(&self) -> bool {
        matches!(self.state, ConnectionState::Connecting)
    }

    /// Returns `true` when the connection is established.
    pub fn is_connected(&self) -> bool {
        matches!(self.state, ConnectionState::Connected)
    }

    /// Returns `true` once the connection has ended, either by closing
    /// (`Disconnected`) or by failing (`Error`).
    pub fn is_disconnected(&self) -> bool {
        matches!(
            self.state,
            ConnectionState::Disconnected { .. } | ConnectionState::Error { .. }
        )
    }
}
impl Default for WebSocketStateMachine {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Walks the full happy path: connect, receive data, then disconnect cleanly.
    #[test]
    fn test_state_transitions() {
        let mut sm = WebSocketStateMachine::new();
        assert_eq!(sm.current_state(), &ConnectionState::Initial);
        assert!(sm.handle_event(StateMachineEvent::StartConnection).is_ok());
        assert_eq!(sm.current_state(), &ConnectionState::Connecting);
        assert!(sm
            .handle_event(StateMachineEvent::ConnectionEstablished)
            .is_ok());
        assert_eq!(sm.current_state(), &ConnectionState::Connected);
        assert!(sm.can_send_data());
        assert!(sm.handle_event(StateMachineEvent::DataReceived).is_ok());
        assert_eq!(sm.current_state(), &ConnectionState::Connected);
        assert!(sm
            .handle_event(StateMachineEvent::RequestDisconnect)
            .is_ok());
        assert_eq!(sm.current_state(), &ConnectionState::Disconnecting);
        let close_reason = Some(CloseReason {
            code: 1000,
            reason: "Normal closure".to_string(),
        });
        assert!(sm
            .handle_event(StateMachineEvent::ConnectionClosed(close_reason.clone()))
            .is_ok());
        assert_eq!(
            sm.current_state(),
            &ConnectionState::Disconnected {
                reason: close_reason
            }
        );
        assert!(sm.is_disconnected());
    }

    /// An event that is not valid in the current state is rejected and the
    /// state is left untouched.
    #[test]
    fn test_invalid_transitions() {
        let mut sm = WebSocketStateMachine::new();
        assert!(sm.handle_event(StateMachineEvent::DataReceived).is_err());
        assert_eq!(sm.current_state(), &ConnectionState::Initial);
    }

    /// Heartbeat frames do not change an open connection's state.
    #[test]
    fn test_heartbeat_keeps_connection_open() {
        let mut sm = WebSocketStateMachine::new();
        sm.handle_event(StateMachineEvent::StartConnection).unwrap();
        sm.handle_event(StateMachineEvent::ConnectionEstablished).unwrap();
        assert!(sm.handle_event(StateMachineEvent::PingReceived).is_ok());
        assert!(sm.handle_event(StateMachineEvent::PongReceived).is_ok());
        assert!(sm.is_connected());
    }

    /// `ErrorOccurred` moves the machine to `Error` even mid-handshake.
    #[test]
    fn test_error_during_connect() {
        let mut sm = WebSocketStateMachine::new();
        sm.handle_event(StateMachineEvent::StartConnection).unwrap();
        assert!(sm
            .handle_event(StateMachineEvent::ErrorOccurred("refused".to_string()))
            .is_ok());
        assert_eq!(
            sm.current_state(),
            &ConnectionState::Error {
                message: "refused".to_string()
            }
        );
        assert!(sm.is_disconnected());
        assert!(!sm.can_send_data());
    }
}