use super::{CloseCode, Frame, Opcode, WsError};
use crate::bytes::Bytes;
use std::time::Duration;
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CloseReason {
pub code: Option<CloseCode>,
pub raw_code: Option<u16>,
pub text: Option<String>,
}
impl CloseReason {
#[must_use]
pub fn new(code: CloseCode, text: Option<&str>) -> Self {
let raw = u16::from(code);
Self {
code: Some(code),
raw_code: Some(raw),
text: text.map(String::from),
}
}
#[must_use]
pub fn empty() -> Self {
Self {
code: None,
raw_code: None,
text: None,
}
}
#[must_use]
pub fn normal() -> Self {
Self::new(CloseCode::Normal, None)
}
#[must_use]
pub fn going_away() -> Self {
Self::new(CloseCode::GoingAway, None)
}
#[must_use]
pub fn with_text(code: CloseCode, text: &str) -> Self {
Self::new(code, Some(text))
}
pub fn parse(payload: &[u8]) -> Result<Self, WsError> {
match payload.len() {
0 => Ok(Self::empty()),
1 => Err(WsError::InvalidClosePayload),
_ => {
let code_raw = u16::from_be_bytes([payload[0], payload[1]]);
if !CloseCode::is_valid_received_code(code_raw) {
return Err(WsError::InvalidClosePayload);
}
let code = CloseCode::from_u16(code_raw);
let text = if payload.len() > 2 {
let text_bytes = &payload[2..];
let text_str = std::str::from_utf8(text_bytes)
.map_err(|_| WsError::InvalidClosePayload)?;
Some(text_str.to_string())
} else {
None
};
Ok(Self {
code,
raw_code: Some(code_raw),
text,
})
}
}
}
#[must_use]
pub fn encode(&self) -> Bytes {
match self.outbound_payload_parts() {
(DropReasonText::No, code, text) => Self::encode_parts(code, text),
(DropReasonText::Yes, code, _) => Self::encode_parts(code, None),
}
}
fn encode_parts(code: Option<u16>, text: Option<&str>) -> Bytes {
match (code, text) {
(None, None) => Bytes::new(),
(None, Some(_text)) => Bytes::new(),
(Some(code_val), None) => Bytes::copy_from_slice(&code_val.to_be_bytes()),
(Some(code_val), Some(text)) => {
let mut buf = Vec::with_capacity(2 + text.len());
buf.extend_from_slice(&code_val.to_be_bytes());
buf.extend_from_slice(text.as_bytes());
Bytes::from(buf)
}
}
}
#[must_use]
pub fn to_frame(&self) -> Frame {
match self.outbound_payload_parts() {
(DropReasonText::No, code, text) => Frame::close(code, text),
(DropReasonText::Yes, code, _) => Frame::close(code, None),
}
}
#[must_use]
fn outbound_wire_code(&self) -> Option<u16> {
match self.raw_code {
Some(code) if CloseCode::is_valid_code(code) => Some(code),
Some(_) => None,
None => self
.code
.map(u16::from)
.filter(|code| CloseCode::is_valid_code(*code)),
}
}
#[must_use]
fn outbound_payload_parts(&self) -> (DropReasonText, Option<u16>, Option<&str>) {
let code = self.outbound_wire_code();
let drop_text = (self.raw_code.is_some() || self.code.is_some()) && code.is_none();
(
if drop_text {
DropReasonText::Yes
} else {
DropReasonText::No
},
code,
self.text.as_deref(),
)
}
#[must_use]
pub fn is_normal(&self) -> bool {
self.wire_code() == Some(u16::from(CloseCode::Normal))
}
#[must_use]
pub fn is_error(&self) -> bool {
matches!(
self.code,
Some(
CloseCode::ProtocolError
| CloseCode::InvalidPayload
| CloseCode::PolicyViolation
| CloseCode::InternalError
)
)
}
#[must_use]
pub const fn wire_code(&self) -> Option<u16> {
match (self.raw_code, self.code) {
(Some(code), _) => Some(code),
(None, Some(code)) => Some(code as u16),
(None, None) => None,
}
}
}
impl Default for CloseReason {
fn default() -> Self {
Self::empty()
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum DropReasonText {
No,
Yes,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum CloseState {
#[default]
Open,
CloseSent,
CloseReceived,
Closed,
}
impl CloseState {
#[must_use]
pub const fn is_open(self) -> bool {
matches!(self, Self::Open)
}
#[must_use]
pub const fn is_closed(self) -> bool {
matches!(self, Self::Closed)
}
#[must_use]
pub const fn is_closing(self) -> bool {
matches!(self, Self::CloseSent | Self::CloseReceived)
}
}
#[derive(Debug, Clone)]
pub struct CloseConfig {
pub close_timeout: Duration,
pub close_on_drop: bool,
pub cancellation_code: CloseCode,
}
impl Default for CloseConfig {
fn default() -> Self {
Self {
close_timeout: Duration::from_secs(5),
close_on_drop: true,
cancellation_code: CloseCode::GoingAway,
}
}
}
impl CloseConfig {
#[must_use]
pub fn new() -> Self {
Self::default()
}
#[must_use]
pub fn with_timeout(mut self, timeout: Duration) -> Self {
self.close_timeout = timeout;
self
}
#[must_use]
pub fn with_close_on_drop(mut self, enabled: bool) -> Self {
self.close_on_drop = enabled;
self
}
#[must_use]
pub fn with_cancellation_code(mut self, code: CloseCode) -> Self {
self.cancellation_code = code;
self
}
}
#[derive(Debug)]
pub struct CloseHandshake {
state: CloseState,
config: CloseConfig,
our_reason: Option<CloseReason>,
peer_reason: Option<CloseReason>,
}
impl CloseHandshake {
#[must_use]
pub fn new() -> Self {
Self::with_config(CloseConfig::default())
}
#[must_use]
pub fn with_config(config: CloseConfig) -> Self {
Self {
state: CloseState::Open,
config,
our_reason: None,
peer_reason: None,
}
}
#[must_use]
pub const fn state(&self) -> CloseState {
self.state
}
#[must_use]
pub const fn is_open(&self) -> bool {
self.state.is_open()
}
#[must_use]
pub const fn is_closed(&self) -> bool {
self.state.is_closed()
}
#[must_use]
pub fn our_reason(&self) -> Option<&CloseReason> {
self.our_reason.as_ref()
}
#[must_use]
pub fn peer_reason(&self) -> Option<&CloseReason> {
self.peer_reason.as_ref()
}
#[must_use]
pub const fn close_timeout(&self) -> Duration {
self.config.close_timeout
}
pub fn initiate(&mut self, reason: CloseReason) -> Option<Frame> {
match self.state {
CloseState::Open => {
self.state = CloseState::CloseSent;
let frame = reason.to_frame();
self.our_reason = Some(reason);
Some(frame)
}
CloseState::CloseReceived => {
let frame = reason.to_frame();
self.our_reason = Some(reason);
Some(frame)
}
CloseState::CloseSent | CloseState::Closed => None,
}
}
pub fn receive_close(&mut self, frame: &Frame) -> Result<Option<Frame>, WsError> {
if frame.opcode != Opcode::Close {
return Err(WsError::InvalidOpcode(frame.opcode as u8));
}
let reason = CloseReason::parse(&frame.payload)?;
match self.state {
CloseState::Open => {
self.state = CloseState::CloseReceived;
let response_code = reason.outbound_wire_code();
self.peer_reason = Some(reason);
let response = if let Some(response_code) = response_code {
Frame::close(Some(response_code), None)
} else {
Frame::close(None, None)
};
Ok(Some(response))
}
CloseState::CloseSent => {
self.state = CloseState::Closed;
self.peer_reason = Some(reason);
Ok(None)
}
CloseState::CloseReceived | CloseState::Closed => {
Ok(None)
}
}
}
pub fn mark_response_sent(&mut self) {
if self.state == CloseState::CloseReceived {
self.state = CloseState::Closed;
}
}
pub fn force_close(&mut self, reason: CloseReason) {
self.state = CloseState::Closed;
if self.our_reason.is_none() {
self.our_reason = Some(reason);
}
}
#[cfg(test)]
pub fn reset(&mut self) {
self.state = CloseState::Open;
self.our_reason = None;
self.peer_reason = None;
}
}
impl Default for CloseHandshake {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn close_reason_parse_empty() {
let reason = CloseReason::parse(&[]).unwrap();
assert_eq!(reason.code, None);
assert_eq!(reason.raw_code, None);
assert_eq!(reason.text, None);
}
#[test]
fn close_reason_parse_code_only() {
let payload = 1000u16.to_be_bytes();
let reason = CloseReason::parse(&payload).unwrap();
assert_eq!(reason.code, Some(CloseCode::Normal));
assert_eq!(reason.raw_code, Some(1000));
assert_eq!(reason.text, None);
}
#[test]
fn close_reason_parse_code_and_text() {
let mut payload = Vec::new();
payload.extend_from_slice(&1001u16.to_be_bytes());
payload.extend_from_slice(b"Going away");
let reason = CloseReason::parse(&payload).unwrap();
assert_eq!(reason.code, Some(CloseCode::GoingAway));
assert_eq!(reason.raw_code, Some(1001));
assert_eq!(reason.text.as_deref(), Some("Going away"));
}
#[test]
fn close_reason_parse_custom_code_preserves_raw_code() {
let mut payload = Vec::new();
payload.extend_from_slice(&3001u16.to_be_bytes());
payload.extend_from_slice(b"custom");
let reason = CloseReason::parse(&payload).unwrap();
assert_eq!(reason.code, None);
assert_eq!(reason.raw_code, Some(3001));
assert_eq!(reason.wire_code(), Some(3001));
assert_eq!(reason.text.as_deref(), Some("custom"));
}
#[test]
fn close_reason_parse_invalid_single_byte() {
let result = CloseReason::parse(&[0x00]);
assert!(result.is_err());
}
#[test]
fn close_reason_parse_invalid_utf8() {
let mut payload = Vec::new();
payload.extend_from_slice(&1000u16.to_be_bytes());
payload.extend_from_slice(&[0xFF, 0xFE]);
let result = CloseReason::parse(&payload);
assert!(result.is_err());
}
#[test]
fn close_reason_parse_invalid_reserved_code() {
let payload = 1004u16.to_be_bytes();
let result = CloseReason::parse(&payload);
assert!(matches!(result, Err(WsError::InvalidClosePayload)));
}
#[test]
fn close_reason_parse_unassigned_code_accepted() {
let payload = 1016u16.to_be_bytes();
let result = CloseReason::parse(&payload).unwrap();
assert_eq!(result.code, None);
assert_eq!(result.raw_code, Some(1016));
}
#[test]
fn close_reason_parse_iana_registered_1012_accepted() {
let payload = 1012u16.to_be_bytes();
let result = CloseReason::parse(&payload);
assert!(result.is_ok(), "IANA-registered code 1012 must parse");
}
#[test]
fn close_reason_encode_empty() {
let reason = CloseReason::empty();
let encoded = reason.encode();
assert!(encoded.is_empty());
}
#[test]
fn close_reason_encode_code_only() {
let reason = CloseReason::new(CloseCode::Normal, None);
let encoded = reason.encode();
assert_eq!(encoded.as_ref(), &1000u16.to_be_bytes());
}
#[test]
fn close_reason_encode_code_and_text() {
let reason = CloseReason::with_text(CloseCode::GoingAway, "bye");
let encoded = reason.encode();
let mut expected = Vec::new();
expected.extend_from_slice(&1001u16.to_be_bytes());
expected.extend_from_slice(b"bye");
assert_eq!(encoded.as_ref(), expected.as_slice());
}
#[test]
fn close_reason_roundtrip() {
let original = CloseReason::with_text(CloseCode::Normal, "goodbye");
let encoded = original.encode();
let parsed = CloseReason::parse(&encoded).unwrap();
assert_eq!(original.code, parsed.code);
assert_eq!(original.raw_code, parsed.raw_code);
assert_eq!(original.text, parsed.text);
}
#[test]
fn close_reason_unsendable_received_code_encodes_as_empty_close() {
let payload = 2000u16.to_be_bytes();
let reason = CloseReason::parse(&payload).unwrap();
assert_eq!(reason.raw_code, Some(2000));
assert!(reason.encode().is_empty());
assert!(reason.to_frame().payload.is_empty());
}
#[test]
fn close_reason_without_code_drops_text_on_encode() {
let reason = CloseReason {
code: None,
raw_code: None,
text: Some("text without code".to_string()),
};
assert!(reason.encode().is_empty());
assert!(reason.to_frame().payload.is_empty());
}
#[test]
fn close_code_valid_ranges() {
assert!(CloseCode::is_valid_code(1000));
assert!(CloseCode::is_valid_code(1003));
assert!(CloseCode::is_valid_code(1007));
assert!(CloseCode::is_valid_code(1011));
assert!(CloseCode::is_valid_code(3000));
assert!(CloseCode::is_valid_code(4999));
assert!(!CloseCode::is_valid_code(1004)); assert!(!CloseCode::is_valid_code(1005)); assert!(!CloseCode::is_valid_code(1006)); assert!(!CloseCode::is_valid_code(999)); assert!(!CloseCode::is_valid_code(5000)); }
#[test]
fn close_state_transitions() {
assert!(CloseState::Open.is_open());
assert!(!CloseState::Open.is_closed());
assert!(!CloseState::Open.is_closing());
assert!(!CloseState::CloseSent.is_open());
assert!(!CloseState::CloseSent.is_closed());
assert!(CloseState::CloseSent.is_closing());
assert!(!CloseState::Closed.is_open());
assert!(CloseState::Closed.is_closed());
assert!(!CloseState::Closed.is_closing());
}
#[test]
fn handshake_initiate_from_open() {
let mut handshake = CloseHandshake::new();
let frame = handshake.initiate(CloseReason::normal());
assert!(frame.is_some());
assert_eq!(handshake.state(), CloseState::CloseSent);
assert!(handshake.our_reason().is_some());
}
#[test]
fn handshake_initiate_when_already_closing() {
let mut handshake = CloseHandshake::new();
handshake.initiate(CloseReason::normal());
let frame = handshake.initiate(CloseReason::normal());
assert!(frame.is_none());
}
#[test]
fn handshake_receive_close_from_open() {
let mut handshake = CloseHandshake::new();
let close_frame = Frame::close(Some(1000), Some("bye"));
let response = handshake.receive_close(&close_frame).unwrap();
assert!(response.is_some());
assert_eq!(handshake.state(), CloseState::CloseReceived);
assert!(handshake.peer_reason().is_some());
}
#[test]
fn handshake_receive_close_echoes_custom_code() {
let mut handshake = CloseHandshake::new();
let close_frame = Frame::close(Some(3001), Some("custom"));
let response = handshake.receive_close(&close_frame).unwrap().unwrap();
assert_eq!(response.opcode, Opcode::Close);
assert_eq!(&response.payload[..2], &3001u16.to_be_bytes());
assert_eq!(handshake.peer_reason().unwrap().wire_code(), Some(3001));
assert_eq!(handshake.state(), CloseState::CloseReceived);
}
#[test]
fn handshake_receive_close_rejects_rfc6455_tls_handshake_code_1015() {
let mut handshake = CloseHandshake::new();
let close_frame = Frame {
fin: true,
rsv1: false,
rsv2: false,
rsv3: false,
opcode: Opcode::Close,
masked: false,
mask_key: None,
payload: Bytes::copy_from_slice(&1015u16.to_be_bytes()),
};
let err = handshake.receive_close(&close_frame).unwrap_err();
assert!(matches!(err, WsError::InvalidClosePayload));
assert_eq!(handshake.state(), CloseState::Open);
assert!(handshake.peer_reason().is_none());
}
#[test]
fn handshake_receive_empty_close_keeps_response_payload_empty() {
let mut handshake = CloseHandshake::new();
let close_frame = Frame::close(None, None);
let response = handshake.receive_close(&close_frame).unwrap().unwrap();
assert_eq!(response.opcode, Opcode::Close);
assert!(response.payload.is_empty());
assert_eq!(handshake.peer_reason(), Some(&CloseReason::empty()));
assert_eq!(handshake.state(), CloseState::CloseReceived);
}
#[test]
fn handshake_receive_close_after_sent() {
let mut handshake = CloseHandshake::new();
handshake.initiate(CloseReason::normal());
let close_frame = Frame::close(Some(1000), None);
let response = handshake.receive_close(&close_frame).unwrap();
assert!(response.is_none()); assert_eq!(handshake.state(), CloseState::Closed);
}
#[test]
fn handshake_complete_flow_initiator() {
let mut handshake = CloseHandshake::new();
let frame = handshake.initiate(CloseReason::normal());
assert!(frame.is_some());
assert_eq!(handshake.state(), CloseState::CloseSent);
let peer_close = Frame::close(Some(1000), None);
let response = handshake.receive_close(&peer_close).unwrap();
assert!(response.is_none()); assert_eq!(handshake.state(), CloseState::Closed);
}
#[test]
fn handshake_complete_flow_receiver() {
let mut handshake = CloseHandshake::new();
let peer_close = Frame::close(Some(1000), Some("goodbye"));
let response = handshake.receive_close(&peer_close).unwrap();
assert!(response.is_some());
assert_eq!(handshake.state(), CloseState::CloseReceived);
let frame = handshake.initiate(CloseReason::normal());
assert!(frame.is_some());
assert_eq!(handshake.state(), CloseState::CloseReceived);
handshake.mark_response_sent();
assert_eq!(handshake.state(), CloseState::Closed);
}
#[test]
fn handshake_response_initiation_does_not_complete_before_send() {
let mut handshake = CloseHandshake::new();
let peer_close = Frame::close(Some(1000), Some("goodbye"));
let response = handshake.receive_close(&peer_close).unwrap();
assert!(response.is_some());
assert_eq!(handshake.state(), CloseState::CloseReceived);
let frame = handshake.initiate(CloseReason::with_text(CloseCode::Normal, "ack"));
assert!(frame.is_some());
assert_eq!(handshake.state(), CloseState::CloseReceived);
assert!(!handshake.is_open());
assert!(!handshake.is_closed());
}
#[test]
fn handshake_mark_response_sent_closes_after_receiving_peer_close() {
let mut handshake = CloseHandshake::new();
let peer_close = Frame::close(Some(1000), Some("bye"));
let response = handshake.receive_close(&peer_close).unwrap();
assert!(response.is_some());
assert_eq!(handshake.state(), CloseState::CloseReceived);
handshake.mark_response_sent();
assert_eq!(handshake.state(), CloseState::Closed);
assert!(handshake.is_closed());
}
#[test]
fn handshake_force_close() {
let mut handshake = CloseHandshake::new();
handshake.force_close(CloseReason::new(CloseCode::Abnormal, None));
assert_eq!(handshake.state(), CloseState::Closed);
assert!(handshake.our_reason().is_some());
}
#[test]
fn handshake_force_close_preserves_supplied_reason() {
let mut handshake = CloseHandshake::new();
let reason = CloseReason::with_text(CloseCode::GoingAway, "cancelled by region close");
handshake.force_close(reason.clone());
assert_eq!(handshake.state(), CloseState::Closed);
assert_eq!(handshake.our_reason(), Some(&reason));
assert!(handshake.peer_reason().is_none());
}
#[test]
fn close_config_builder() {
let config = CloseConfig::new()
.with_timeout(Duration::from_secs(10))
.with_close_on_drop(false)
.with_cancellation_code(CloseCode::InternalError);
assert_eq!(config.close_timeout, Duration::from_secs(10));
assert!(!config.close_on_drop);
assert_eq!(config.cancellation_code, CloseCode::InternalError);
}
#[test]
fn close_reason_debug_clone_eq() {
let r = CloseReason::normal();
let dbg = format!("{r:?}");
assert!(dbg.contains("CloseReason"), "{dbg}");
let cloned = r.clone();
assert_eq!(r, cloned);
}
#[test]
fn close_state_debug_clone_copy_eq_default() {
let s = CloseState::default();
let dbg = format!("{s:?}");
assert!(dbg.contains("Open"), "{dbg}");
let copied = s;
let cloned = s;
assert_eq!(copied, cloned);
assert_ne!(s, CloseState::CloseSent);
}
#[test]
fn close_config_debug_clone() {
let cfg = CloseConfig::new();
let dbg = format!("{cfg:?}");
assert!(dbg.contains("CloseConfig"), "{dbg}");
let cloned = cfg.clone();
assert_eq!(cfg.close_timeout, cloned.close_timeout);
}
#[test]
fn handshake_receive_close_does_not_echo_unassigned_code() {
let mut handshake = CloseHandshake::new();
let close_frame = Frame {
fin: true,
rsv1: false,
rsv2: false,
rsv3: false,
opcode: Opcode::Close,
masked: false,
mask_key: None,
payload: Bytes::copy_from_slice(&1016u16.to_be_bytes()),
};
let response = handshake.receive_close(&close_frame).unwrap().unwrap();
assert_eq!(response.opcode, Opcode::Close);
assert!(response.payload.is_empty());
assert_eq!(handshake.state(), CloseState::CloseReceived);
}
#[test]
fn to_frame_with_unassigned_raw_code_does_not_panic() {
let payload = 2000u16.to_be_bytes();
let reason = CloseReason::parse(&payload).unwrap();
assert_eq!(reason.raw_code, Some(2000));
let frame = reason.to_frame();
assert_eq!(frame.opcode, Opcode::Close);
assert!(frame.payload.is_empty());
}
#[test]
fn close_reason_code_falls_back_when_raw_code_is_absent() {
let reason = CloseReason {
code: Some(CloseCode::Normal),
raw_code: None,
text: None,
};
assert_eq!(reason.wire_code(), Some(1000));
assert!(reason.is_normal());
assert_eq!(reason.encode().as_ref(), &1000u16.to_be_bytes());
assert_eq!(reason.to_frame().payload.as_ref(), &1000u16.to_be_bytes());
}
#[test]
fn typed_unsendable_close_code_fails_closed_without_panicking() {
let reason = CloseReason {
code: Some(CloseCode::Abnormal),
raw_code: None,
text: None,
};
assert_eq!(reason.wire_code(), Some(1006));
assert!(reason.encode().is_empty());
let frame = reason.to_frame();
assert_eq!(frame.opcode, Opcode::Close);
assert!(frame.payload.is_empty());
}
#[test]
fn typed_unsendable_close_code_drops_reason_text_instead_of_synthesizing_normal_close() {
let reason = CloseReason {
code: Some(CloseCode::NoStatusReceived),
raw_code: None,
text: Some("must not hit wire".to_string()),
};
assert_eq!(reason.wire_code(), Some(1005));
assert!(reason.encode().is_empty());
let frame = reason.to_frame();
assert_eq!(frame.opcode, Opcode::Close);
assert!(frame.payload.is_empty());
}
}