use std::collections::VecDeque;
use crate::deflate::PerMessageDeflate;
use crate::error::Error;
use crate::fragment_buffer::FragmentBuffer;
use crate::frame_policy::FramePolicy;
use crate::websocket_close::{CloseCode, truncate_reason};
use crate::websocket_connection_types::{
ConnectionEvent, ConnectionOutput, ConnectionState, TimerId,
};
use crate::websocket_frame::{DecodedFrame, Frame, FrameDecoder};
use crate::websocket_opcode::Opcode;
/// Default limit, in bytes, for the payload of one incoming non-control frame (64 MiB).
pub const DEFAULT_MAX_FRAME_SIZE: usize = 64 << 20;
/// Default limit, in bytes, for an incoming message measured before decompression (64 MiB).
pub const DEFAULT_MAX_MESSAGE_SIZE: usize = 64 << 20;
/// Default output limit, in bytes, handed to the permessage-deflate decompressor (16 MiB).
pub const DEFAULT_MAX_DECOMPRESSED_SIZE: usize = 16 << 20;
/// Transport-independent core of a WebSocket connection: lifecycle state,
/// closing-handshake bookkeeping, reassembly of incoming frames, optional
/// permessage-deflate, and FIFO queues of pending events and driver outputs.
///
/// Role-specific behaviour (encoding/sending frames and verifying masking) is
/// supplied on each call through a `FramePolicy` argument.
pub(crate) struct SharedConnectionState {
/// Current lifecycle state; changed through `set_state`.
state: ConnectionState,
/// Set once our Close frame has been handed to the policy; at most one is ever sent.
close_sent: bool,
/// Set once a Close frame arrives from the peer; afterwards Pings are not answered.
close_received: bool,
/// Set after we send a Ping; `handle_pong` and `handle_close` clear it.
awaiting_pong: bool,
/// Sticky failure flag set by `mark_failed`; this impl never clears it.
failed: bool,
/// Events for the application, drained in FIFO order by `poll_event`.
event_queue: VecDeque<ConnectionEvent>,
/// Instructions for the I/O driver (set/clear timers, close the transport),
/// drained by `poll_output`. NOTE(review): `FramePolicy::encode_and_send`
/// presumably pushes encoded frames here as well — confirm.
output_queue: VecDeque<ConnectionOutput>,
/// Incremental decoder fed with raw bytes by `process_frames`.
frame_decoder: FrameDecoder,
/// Reassembly state of a fragmented message: opcode, RSV1 flag and payload so far.
fragment_buffer: FragmentBuffer,
/// Negotiated permessage-deflate context; `None` until `enable_deflate` is called.
deflate: Option<PerMessageDeflate>,
/// Maximum payload length, in bytes, of a single incoming non-control frame.
max_frame_size: usize,
/// Maximum incoming message size, in bytes, measured on the wire (before decompression).
max_message_size: usize,
/// Output limit passed to the decompressor for compressed messages.
max_decompressed_size: usize,
/// Period used to re-arm the `Ping` timer; `0` disables re-arming.
ping_interval_millis: u64,
/// Duration of the `PongTimeout` timer armed after each Ping we send.
pong_timeout_millis: u64,
/// Duration of the `CloseTimeout` timer armed when our Close frame is sent.
close_timeout_millis: u64,
}
impl SharedConnectionState {
/// Creates a connection core in the `Disconnected` state.
///
/// The core starts with empty event/output queues, a fresh frame decoder and
/// fragment buffer, and no permessage-deflate context. The size limits and
/// timer durations are stored as given. A `ping_interval_millis` of `0` means
/// the periodic `Ping` timer is never re-armed by `handle_timer`.
pub(crate) fn new(
    max_frame_size: usize,
    max_message_size: usize,
    max_decompressed_size: usize,
    ping_interval_millis: u64,
    pong_timeout_millis: u64,
    close_timeout_millis: u64,
) -> Self {
    let frame_decoder = FrameDecoder::new();
    let fragment_buffer = FragmentBuffer::new();
    Self {
        max_frame_size,
        max_message_size,
        max_decompressed_size,
        ping_interval_millis,
        pong_timeout_millis,
        close_timeout_millis,
        state: ConnectionState::Disconnected,
        close_sent: false,
        close_received: false,
        awaiting_pong: false,
        failed: false,
        event_queue: VecDeque::new(),
        output_queue: VecDeque::new(),
        frame_decoder,
        fragment_buffer,
        deflate: None,
    }
}
/// Returns a copy of the current lifecycle state.
pub(crate) fn state(&self) -> ConnectionState {
    self.state
}

/// Reports whether `mark_failed` has been called.
pub(crate) fn is_failed(&self) -> bool {
    self.failed
}

/// Raises the sticky failure flag; nothing in this impl lowers it again.
pub(crate) fn mark_failed(&mut self) {
    self.failed = true;
}

/// Installs the negotiated permessage-deflate context, replacing any earlier one.
pub(crate) fn enable_deflate(&mut self, deflate: PerMessageDeflate) {
    self.deflate = Some(deflate);
}

/// Queues an application-facing event, to be handed out later by `poll_event`.
pub(crate) fn emit_event(&mut self, event: ConnectionEvent) {
    self.event_queue.push_back(event);
}

/// Queues a driver-facing instruction, to be handed out later by `poll_output`.
pub(crate) fn enqueue_output(&mut self, output: ConnectionOutput) {
    self.output_queue.push_back(output);
}
pub(crate) fn set_state(&mut self, new_state: ConnectionState) -> Result<(), Error> {
if self.state == new_state {
return Ok(());
}
if !self.state.can_transition_to(new_state) {
return Err(Error::invalid_state(format!(
"invalid state transition from {:?} to {:?}",
self.state, new_state
)));
}
self.state = new_state;
self.event_queue
.push_back(ConnectionEvent::StateChanged(new_state));
Ok(())
}
pub(crate) fn check_connected(&self) -> Result<(), Error> {
if self.state != ConnectionState::Connected {
return Err(Error::invalid_state("not connected"));
}
Ok(())
}
pub(crate) fn close(
&mut self,
code: CloseCode,
reason: &str,
policy: &mut impl FramePolicy,
) -> Result<(), Error> {
if !matches!(
self.state,
ConnectionState::Connected | ConnectionState::Closing
) {
return Err(Error::invalid_state("connection is not established"));
}
if !code.is_sendable() {
return Err(Error::invalid_input(format!(
"close code {} is not sendable",
code.as_u16()
)));
}
if reason.len() > 123 {
return Err(Error::invalid_input(format!(
"close reason exceeds 123 bytes: {} bytes",
reason.len()
)));
}
self.close_internal(code, reason, policy);
Ok(())
}
/// Sends our Close frame (once) and enters the `Closing` state.
///
/// Does nothing when the state is neither `Connected` nor `Closing`, or when a
/// Close frame has already been sent. The reason is cut to 123 bytes with
/// `truncate_reason`. If the frame still cannot be built, a Close with the
/// same code and an empty reason is sent instead.
///
/// Also arms the `CloseTimeout` timer, whose expiry (see `handle_timer`)
/// forces `Closed` if the peer never answers.
pub(crate) fn close_internal(
    &mut self,
    code: CloseCode,
    reason: &str,
    policy: &mut impl FramePolicy,
) {
    let can_close = matches!(
        self.state,
        ConnectionState::Connected | ConnectionState::Closing
    );
    if !can_close || self.close_sent {
        return;
    }
    let status = code.as_u16();
    let frame = match Frame::close(Some(status), truncate_reason(reason, 123)) {
        Ok(frame) => frame,
        Err(_) => Frame::close(Some(status), "")
            .expect("empty reason close frame must always succeed"),
    };
    policy.encode_and_send(&frame, self);
    self.close_sent = true;
    self.output_queue.push_back(ConnectionOutput::SetTimer {
        id: TimerId::CloseTimeout,
        duration_millis: self.close_timeout_millis,
    });
    self.set_state(ConnectionState::Closing)
        .expect("unreachable: Connected/Closing -> Closing must be valid");
}
/// Sends `text` as one Text data frame, compressed when deflate applies.
///
/// # Errors
/// `invalid_state` unless connected; otherwise any compression error.
pub(crate) fn send_text(
    &mut self,
    text: &str,
    policy: &mut impl FramePolicy,
) -> Result<(), Error> {
    self.check_connected()?;
    let payload = String::from(text).into_bytes();
    self.send_data_frame(Opcode::Text, payload, policy)
}
/// Sends `data` as one Binary data frame, compressed when deflate applies.
///
/// # Errors
/// `invalid_state` unless connected; otherwise any compression error.
pub(crate) fn send_binary(
    &mut self,
    data: &[u8],
    policy: &mut impl FramePolicy,
) -> Result<(), Error> {
    self.check_connected()?;
    let payload = Vec::from(data);
    self.send_data_frame(Opcode::Binary, payload, policy)
}
/// Sends an application-requested Ping carrying `data` and arms the pong timeout.
///
/// # Errors
/// `invalid_state` unless connected; otherwise any failure building the Ping.
pub(crate) fn send_ping(
    &mut self,
    data: &[u8],
    policy: &mut impl FramePolicy,
) -> Result<(), Error> {
    self.check_connected()
        .and_then(|()| self.send_ping_internal(data, policy))
}
/// Builds a single data frame with `opcode` and hands it to the policy.
///
/// The payload is first offered to `compress_if_enabled`. RSV1 is set exactly
/// when the payload was compressed. No state check happens here; the public
/// `send_*` wrappers perform it.
///
/// # Errors
/// Propagates compression failures.
pub(crate) fn send_data_frame(
    &mut self,
    opcode: Opcode,
    payload: Vec<u8>,
    policy: &mut impl FramePolicy,
) -> Result<(), Error> {
    let (body, is_compressed) = self.compress_if_enabled(payload)?;
    let mut frame = Frame::new(opcode, body);
    frame.rsv1 = is_compressed;
    policy.encode_and_send(&frame, self);
    Ok(())
}
/// Compresses `payload` when permessage-deflate is active and its
/// `should_compress` check accepts the payload.
///
/// The threshold of 64 is presumably a minimum payload size; `should_compress`
/// decides. Returns the (possibly compressed) bytes together with a flag saying
/// whether compression was applied.
///
/// # Errors
/// Propagates errors from `PerMessageDeflate::compress`.
pub(crate) fn compress_if_enabled(
    &mut self,
    payload: Vec<u8>,
) -> Result<(Vec<u8>, bool), Error> {
    const COMPRESSION_THRESHOLD: usize = 64;
    let deflate = match self.deflate.as_mut() {
        Some(deflate) => deflate,
        None => return Ok((payload, false)),
    };
    if !deflate.should_compress(&payload, COMPRESSION_THRESHOLD) {
        return Ok((payload, false));
    }
    let compressed = deflate.compress(&payload)?;
    Ok((compressed, true))
}
pub(crate) fn decompress_if_needed(
&mut self,
payload: Vec<u8>,
compressed: bool,
policy: &mut impl FramePolicy,
) -> Result<Vec<u8>, Error> {
if compressed {
if let Some(deflate) = &mut self.deflate {
deflate.decompress(&payload, self.max_decompressed_size)
} else {
self.close_internal(
CloseCode::PROTOCOL_ERROR,
"received compressed frame without permessage-deflate",
policy,
);
Err(Error::protocol_violation(
"received compressed frame without permessage-deflate",
))
}
} else {
Ok(payload)
}
}
/// Feeds `buf` to the frame decoder and handles every frame it can now produce.
///
/// The loop ends when the decoder reports `Ok(None)`, presumably meaning that
/// not enough bytes are buffered for another frame. A decode error sends a
/// `PROTOCOL_ERROR` Close ("frame decode error") and is returned.
///
/// # Errors
/// The first decode or frame-handling error. Later frames are not processed in this call.
pub(crate) fn process_frames(
    &mut self,
    buf: &[u8],
    policy: &mut impl FramePolicy,
) -> Result<(), Error> {
    self.frame_decoder.feed(buf);
    loop {
        let decoded = match self.frame_decoder.decode_with_info() {
            Ok(Some(decoded)) => decoded,
            Ok(None) => return Ok(()),
            Err(err) => {
                self.close_internal(CloseCode::PROTOCOL_ERROR, "frame decode error", policy);
                return Err(err);
            }
        };
        self.handle_decoded_frame(decoded, policy)?;
    }
}
/// Checks the frame's masking with the policy, then passes it to `handle_frame`.
///
/// A masking violation sends a `PROTOCOL_ERROR` Close whose reason is the
/// error's display text, and returns that error.
pub(crate) fn handle_decoded_frame(
    &mut self,
    decoded: DecodedFrame,
    policy: &mut impl FramePolicy,
) -> Result<(), Error> {
    match policy.verify_frame_masking(decoded.masked) {
        Err(masking_error) => {
            self.close_internal(
                CloseCode::PROTOCOL_ERROR,
                &masking_error.to_string(),
                policy,
            );
            Err(masking_error)
        }
        Ok(_) => self.handle_frame(decoded.frame, policy),
    }
}
pub(crate) fn handle_frame(
&mut self,
frame: Frame,
policy: &mut impl FramePolicy,
) -> Result<(), Error> {
if !frame.opcode.is_control() && frame.payload.len() > self.max_frame_size {
self.close_internal(
CloseCode::MESSAGE_TOO_BIG,
"frame payload too large",
policy,
);
return Err(Error::protocol_violation("frame payload too large"));
}
if frame.rsv2 || frame.rsv3 {
self.close_internal(CloseCode::PROTOCOL_ERROR, "reserved bits set", policy);
return Err(Error::protocol_violation("reserved bits set"));
}
if frame.rsv1 {
if self.deflate.is_none() {
self.close_internal(
CloseCode::PROTOCOL_ERROR,
"rsv1 set without permessage-deflate",
policy,
);
return Err(Error::protocol_violation(
"rsv1 set without permessage-deflate",
));
}
if frame.opcode.is_control() {
self.close_internal(
CloseCode::PROTOCOL_ERROR,
"rsv1 must not be set on control frames",
policy,
);
return Err(Error::protocol_violation(
"rsv1 must not be set on control frames",
));
}
if frame.opcode == Opcode::Continuation {
self.close_internal(
CloseCode::PROTOCOL_ERROR,
"rsv1 must not be set on continuation frames",
policy,
);
return Err(Error::protocol_violation(
"rsv1 must not be set on continuation frames",
));
}
}
match frame.opcode {
Opcode::Continuation => self.handle_continuation(frame, policy)?,
Opcode::Text | Opcode::Binary => self.handle_data_frame(frame, policy)?,
Opcode::Close => self.handle_close(frame, policy)?,
Opcode::Ping => self.handle_ping(frame, policy)?,
Opcode::Pong => self.handle_pong(frame)?,
}
Ok(())
}
/// Handles the first (or only) frame of a Text/Binary message.
///
/// A new message is rejected while a fragmented one is still being
/// reassembled. `max_message_size` is enforced on this frame's wire payload
/// for final and non-final frames alike. This is the same pre-decompression
/// measure `handle_continuation` applies to reassembled fragments.
///
/// A final frame is decompressed if RSV1 is set and then emitted. A non-final
/// frame starts reassembly, with the fragment buffer remembering its opcode
/// and RSV1 flag.
///
/// # Errors
/// Protocol violations (after a Close frame is sent), decompression errors,
/// and invalid UTF-8 in Text messages.
pub(crate) fn handle_data_frame(
    &mut self,
    frame: Frame,
    policy: &mut impl FramePolicy,
) -> Result<(), Error> {
    if !self.fragment_buffer.is_empty() {
        self.close_internal(
            CloseCode::PROTOCOL_ERROR,
            "new message started before previous completed",
            policy,
        );
        return Err(Error::protocol_violation(
            "new message started before previous completed",
        ));
    }
    // Applies to unfragmented messages too. Otherwise a single frame larger
    // than `max_message_size` (but within `max_frame_size`) would bypass the
    // message limit.
    if frame.payload.len() > self.max_message_size {
        self.close_internal(CloseCode::MESSAGE_TOO_BIG, "message too large", policy);
        return Err(Error::protocol_violation("message too large"));
    }
    if frame.fin {
        let payload = self.decompress_if_needed(frame.payload, frame.rsv1, policy)?;
        self.emit_message(frame.opcode, payload, policy)
    } else {
        self.fragment_buffer
            .start(frame.opcode, frame.payload, frame.rsv1);
        Ok(())
    }
}
/// Appends a continuation frame to the message currently being reassembled.
///
/// A continuation without a started message is a protocol violation. The
/// `max_message_size` limit is checked before the fragment is copied, so the
/// buffer never grows past the limit, not even transiently. The final
/// fragment triggers decompression (according to the first frame's RSV1 flag)
/// and emission of the complete message.
///
/// # Errors
/// Protocol violations (after a Close frame is sent), decompression errors,
/// and invalid UTF-8 in Text messages.
pub(crate) fn handle_continuation(
    &mut self,
    frame: Frame,
    policy: &mut impl FramePolicy,
) -> Result<(), Error> {
    if self.fragment_buffer.is_empty() {
        self.close_internal(
            CloseCode::PROTOCOL_ERROR,
            "continuation frame without initial frame",
            policy,
        );
        return Err(Error::protocol_violation(
            "continuation frame without initial frame",
        ));
    }
    // Size after appending; saturating so a pathological length cannot wrap.
    let total = self
        .fragment_buffer
        .len()
        .saturating_add(frame.payload.len());
    if total > self.max_message_size {
        self.close_internal(CloseCode::MESSAGE_TOO_BIG, "message too large", policy);
        return Err(Error::protocol_violation("message too large"));
    }
    self.fragment_buffer.append(&frame.payload);
    if frame.fin {
        let (opcode, payload, compressed) = self.fragment_buffer.take();
        let payload = self.decompress_if_needed(payload, compressed, policy)?;
        self.emit_message(opcode, payload, policy)?;
    }
    Ok(())
}
/// Turns a complete, already-decompressed message into an application event.
///
/// Binary payloads are passed through unchanged. Text payloads must be valid
/// UTF-8. If they are not, an `Error` event is queued, an `INVALID_PAYLOAD`
/// Close is sent, and a protocol-violation error is returned. Any other opcode
/// produces no event.
pub(crate) fn emit_message(
    &mut self,
    opcode: Opcode,
    payload: Vec<u8>,
    policy: &mut impl FramePolicy,
) -> Result<(), Error> {
    let event = match opcode {
        Opcode::Binary => ConnectionEvent::BinaryMessage(payload),
        Opcode::Text => match String::from_utf8(payload) {
            Ok(text) => ConnectionEvent::TextMessage(text),
            Err(utf8_error) => {
                self.event_queue.push_back(ConnectionEvent::Error(format!(
                    "invalid UTF-8 in text message: {}",
                    utf8_error
                )));
                self.close_internal(CloseCode::INVALID_PAYLOAD, "invalid UTF-8", policy);
                return Err(Error::protocol_violation("invalid UTF-8 in text message"));
            }
        },
        // Only data opcodes carry application messages.
        _ => return Ok(()),
    };
    self.event_queue.push_back(event);
    Ok(())
}
/// Processes a Close frame from the peer and completes the closing handshake.
///
/// The payload must be empty, or a 2-byte big-endian status code accepted by
/// `CloseCode::is_valid` followed by a UTF-8 reason. If validation fails, a
/// `PROTOCOL_ERROR` Close is sent via `close_internal` and an error returned.
///
/// On success it:
/// * emits `ConnectionEvent::Close`;
/// * replies with a Close if we had not sent one, echoing the peer's code when
///   sendable and 1000 otherwise;
/// * clears the pong, ping and close timers;
/// * moves to `Closed` and asks the driver to close the transport;
/// * discards buffered frame and fragment data.
///
/// # Errors
/// The protocol violations above, or an `invalid_state` error if `Closed` is
/// not reachable from the current state.
pub(crate) fn handle_close(
    &mut self,
    frame: Frame,
    policy: &mut impl FramePolicy,
) -> Result<(), Error> {
    self.close_received = true;
    let body = frame.payload.as_slice();
    let (code, reason) = match body.len() {
        0 => (None, String::new()),
        1 => {
            self.close_internal(
                CloseCode::PROTOCOL_ERROR,
                "close frame payload length must be 0 or >= 2",
                policy,
            );
            return Err(Error::protocol_violation(
                "close frame payload length must be 0 or >= 2",
            ));
        }
        _ => {
            let code_val = u16::from_be_bytes([body[0], body[1]]);
            let close_code = CloseCode::new(code_val);
            if !close_code.is_valid() {
                self.close_internal(
                    CloseCode::PROTOCOL_ERROR,
                    &format!("invalid close code: {}", code_val),
                    policy,
                );
                return Err(Error::protocol_violation(format!(
                    "invalid close code: {}",
                    code_val
                )));
            }
            match std::str::from_utf8(&body[2..]) {
                Ok(text) => (Some(close_code), text.to_owned()),
                Err(_) => {
                    self.close_internal(
                        CloseCode::PROTOCOL_ERROR,
                        "close frame reason is not valid UTF-8",
                        policy,
                    );
                    return Err(Error::protocol_violation(
                        "close frame reason is not valid UTF-8",
                    ));
                }
            }
        }
    };
    self.event_queue
        .push_back(ConnectionEvent::Close { code, reason });
    if !self.close_sent {
        let reply_code = match code {
            Some(peer_code) if peer_code.is_sendable() => peer_code.as_u16(),
            _ => 1000,
        };
        let reply = Frame::close(Some(reply_code), "")
            .expect("empty reason close reply frame must always succeed");
        policy.encode_and_send(&reply, self);
        self.close_sent = true;
    }
    self.awaiting_pong = false;
    for id in [TimerId::PongTimeout, TimerId::Ping, TimerId::CloseTimeout] {
        self.output_queue.push_back(ConnectionOutput::ClearTimer { id });
    }
    self.set_state(ConnectionState::Closed)?;
    self.output_queue
        .push_back(ConnectionOutput::CloseConnection);
    self.frame_decoder.clear();
    self.fragment_buffer.clear();
    Ok(())
}
/// Surfaces a Ping to the application and answers it with a Pong carrying the
/// same payload, unless the peer has already sent a Close frame.
///
/// # Errors
/// Propagates a failure of `Frame::pong` to build the reply.
pub(crate) fn handle_ping(
    &mut self,
    frame: Frame,
    policy: &mut impl FramePolicy,
) -> Result<(), Error> {
    let payload = frame.payload;
    self.event_queue
        .push_back(ConnectionEvent::Ping(payload.clone()));
    if self.close_received {
        return Ok(());
    }
    let pong = Frame::pong(payload)?;
    policy.encode_and_send(&pong, self);
    Ok(())
}
/// Records a Pong from the peer and surfaces it to the application.
///
/// Any Pong, solicited or not, settles an outstanding keep-alive Ping and
/// cancels its timeout timer.
pub(crate) fn handle_pong(&mut self, frame: Frame) -> Result<(), Error> {
    self.awaiting_pong = false;
    let cancel_timeout = ConnectionOutput::ClearTimer {
        id: TimerId::PongTimeout,
    };
    self.output_queue.push_back(cancel_timeout);
    self.event_queue
        .push_back(ConnectionEvent::Pong(frame.payload));
    Ok(())
}
/// Reacts to the expiry of a timer previously requested via `ConnectionOutput::SetTimer`.
///
/// * `Ping`: while connected, sends a keep-alive Ping unless one is still
///   unanswered. It then re-arms itself when `ping_interval_millis > 0`.
/// * `PongTimeout`: if a Ping is still unanswered, queues an error event and
///   starts closing with `POLICY_VIOLATION`. The pending flag is cleared so a
///   duplicate expiry cannot report the timeout twice.
/// * `CloseTimeout`: if the closing handshake is still incomplete, forces
///   `Closed`. It then performs the same cleanup as `handle_close` (pong flag,
///   ping/pong timers, buffered data) and asks the driver to close the transport.
///
/// # Errors
/// Propagates failures from sending a Ping or from the state transition.
pub(crate) fn handle_timer(
    &mut self,
    timer_id: TimerId,
    policy: &mut impl FramePolicy,
) -> Result<(), Error> {
    match timer_id {
        TimerId::Ping => {
            if self.state == ConnectionState::Connected && !self.awaiting_pong {
                self.send_ping_internal(&[], policy)?;
            }
            if self.state == ConnectionState::Connected && self.ping_interval_millis > 0 {
                self.output_queue.push_back(ConnectionOutput::SetTimer {
                    id: TimerId::Ping,
                    duration_millis: self.ping_interval_millis,
                });
            }
        }
        TimerId::PongTimeout => {
            if self.awaiting_pong {
                self.awaiting_pong = false;
                self.event_queue
                    .push_back(ConnectionEvent::Error("pong timeout".to_string()));
                self.close_internal(CloseCode::POLICY_VIOLATION, "pong timeout", policy);
            }
        }
        TimerId::CloseTimeout => {
            if self.state == ConnectionState::Closing {
                self.set_state(ConnectionState::Closed)?;
                // Mirror the cleanup done by `handle_close`. Without it a
                // still-armed PongTimeout would later emit a spurious
                // "pong timeout" error on an already-closed connection.
                self.awaiting_pong = false;
                self.output_queue.push_back(ConnectionOutput::ClearTimer {
                    id: TimerId::PongTimeout,
                });
                self.output_queue
                    .push_back(ConnectionOutput::ClearTimer { id: TimerId::Ping });
                self.output_queue
                    .push_back(ConnectionOutput::CloseConnection);
                self.frame_decoder.clear();
                self.fragment_buffer.clear();
            }
        }
    }
    Ok(())
}
/// Sends a Ping carrying `data` without checking the state, then marks a Pong
/// as outstanding and arms the `PongTimeout` timer.
///
/// # Errors
/// Propagates a failure of `Frame::ping` to build the frame.
pub(crate) fn send_ping_internal(
    &mut self,
    data: &[u8],
    policy: &mut impl FramePolicy,
) -> Result<(), Error> {
    let ping = Frame::ping(Vec::from(data))?;
    policy.encode_and_send(&ping, self);
    self.awaiting_pong = true;
    let timeout = self.pong_timeout_millis;
    self.output_queue.push_back(ConnectionOutput::SetTimer {
        id: TimerId::PongTimeout,
        duration_millis: timeout,
    });
    Ok(())
}
/// Removes and returns the oldest pending application event, or `None` when
/// none is queued.
pub(crate) fn poll_event(&mut self) -> Option<ConnectionEvent> {
    self.event_queue.pop_front()
}

/// Removes and returns the oldest pending driver instruction, or `None` when
/// none is queued.
pub(crate) fn poll_output(&mut self) -> Option<ConnectionOutput> {
    self.output_queue.pop_front()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::ErrorKind;

    /// Builds a core with the default size limits and every timer duration set to 0.
    fn new_shared() -> SharedConnectionState {
        SharedConnectionState::new(
            DEFAULT_MAX_FRAME_SIZE,
            DEFAULT_MAX_MESSAGE_SIZE,
            DEFAULT_MAX_DECOMPRESSED_SIZE,
            0,
            0,
            0,
        )
    }

    /// An allowed transition succeeds, updates the state and emits `StateChanged`.
    #[test]
    fn set_state_は許可遷移を成功させ_state_changed_を_emit_する() {
        let mut shared = new_shared();
        assert!(shared.set_state(ConnectionState::Connecting).is_ok());
        assert_eq!(shared.state(), ConnectionState::Connecting);
        assert_eq!(
            shared.poll_event(),
            Some(ConnectionEvent::StateChanged(ConnectionState::Connecting)),
            "StateChanged event must be emitted"
        );
    }

    /// Re-entering the current state is an `Ok` no-op that emits no event.
    #[test]
    fn set_state_は同一状態への遷移を_no_op_として_ok_を返す() {
        let mut shared = new_shared();
        assert!(shared.set_state(ConnectionState::Disconnected).is_ok());
        assert!(shared.poll_event().is_none());
    }

    /// A disallowed transition fails with `InvalidState`, names both states in
    /// the reason, and leaves the state and event queue untouched.
    #[test]
    fn set_state_は不正遷移を_invalid_state_で拒否し_state_を変えない() {
        let mut shared = new_shared();
        let err = match shared.set_state(ConnectionState::Connected) {
            Ok(()) => panic!("invalid transition must return Err"),
            Err(err) => err,
        };
        assert_eq!(err.kind, ErrorKind::InvalidState);
        for state_name in ["Disconnected", "Connected"] {
            assert!(err.reason.contains(state_name));
        }
        assert_eq!(shared.state(), ConnectionState::Disconnected);
        assert!(shared.poll_event().is_none());
    }

    /// Once `Closed` is reached, every further transition is rejected.
    #[test]
    fn set_state_は終端状態からの遷移を拒否する() {
        let mut shared = new_shared();
        let path = [
            (ConnectionState::Connecting, "Disconnected -> Connecting"),
            (ConnectionState::Connected, "Connecting -> Connected"),
            (ConnectionState::Closing, "Connected -> Closing"),
            (ConnectionState::Closed, "Closing -> Closed"),
        ];
        for (target, step) in path {
            shared.set_state(target).expect(step);
        }
        let rejected = [
            ConnectionState::Disconnected,
            ConnectionState::Connecting,
            ConnectionState::Connected,
            ConnectionState::Closing,
        ];
        for next in rejected {
            let err = shared
                .set_state(next)
                .expect_err("transition from Closed must be rejected");
            assert_eq!(err.kind, ErrorKind::InvalidState);
        }
        assert_eq!(shared.state(), ConnectionState::Closed);
    }
}