use std::collections::VecDeque;
use crate::deflate::PerMessageDeflate;
use crate::error::{Error, ErrorKind};
use crate::websocket_close::CloseCode;
use crate::websocket_extension::{Extension, PerMessageDeflateConfig};
use crate::websocket_frame::{DecodedFrame, Frame, FrameDecoder};
use crate::websocket_handshake::{
HandshakeRequestValidator, ServerHandshakeRequest, ServerHandshakeResponse,
calculate_accept_from_key,
};
use crate::websocket_opcode::Opcode;
use crate::{ConnectionEvent, ConnectionOutput, ConnectionState, TimerId};
use shiguredo_http11::Response;
/// Maximum number of bytes buffered between receiving a complete handshake
/// request and the application accepting it (1 MiB).
const MAX_PENDING_FRAME_DATA: usize = 1 << 20;
/// Default for `ServerConnectionOptions::max_frame_size` (64 MiB).
pub const DEFAULT_MAX_FRAME_SIZE: usize = 64 << 20;
/// Default for `ServerConnectionOptions::max_message_size` (64 MiB).
pub const DEFAULT_MAX_MESSAGE_SIZE: usize = 64 << 20;
/// Default for `ServerConnectionOptions::max_decompressed_size` (16 MiB).
pub const DEFAULT_MAX_DECOMPRESSED_SIZE: usize = 16 << 20;
/// Configuration for a `WebSocketServerConnection`.
///
/// Build with `ServerConnectionOptions::new()` (or `Default`) and the builder
/// methods; defaults are 30 s ping interval, 10 s pong timeout, 5 s close
/// timeout, 64 MiB frame/message limits and a 16 MiB decompression limit.
#[derive(Debug, Clone)]
pub struct ServerConnectionOptions {
/// Subprotocols the server accepts. `accept_handshake_auto` selects the first
/// protocol in the client's list that also appears here.
pub protocols: Vec<String>,
/// permessage-deflate parameters; when `None`, `accept_handshake_auto` never
/// negotiates compression.
pub deflate_config: Option<PerMessageDeflateConfig>,
/// Extra `(name, value)` headers added to the response built by
/// `accept_handshake_auto`.
pub additional_headers: Vec<(String, String)>,
/// Interval between automatic keep-alive pings in milliseconds; `0` disables
/// the ping timer.
pub ping_interval_millis: u64,
/// How long to wait for a pong after sending a ping, in milliseconds.
pub pong_timeout_millis: u64,
/// How long to wait after sending a close frame before giving up and closing
/// the transport, in milliseconds.
pub close_timeout_millis: u64,
/// Maximum payload length in bytes of a single received non-control frame.
pub max_frame_size: usize,
/// Maximum received message payload length in bytes, measured before
/// decompression.
pub max_message_size: usize,
/// Limit passed to the permessage-deflate decompressor for each received message.
pub max_decompressed_size: usize,
}
impl Default for ServerConnectionOptions {
fn default() -> Self {
Self {
protocols: Vec::new(),
deflate_config: None,
additional_headers: Vec::new(),
ping_interval_millis: 30_000, pong_timeout_millis: 10_000, close_timeout_millis: 5_000, max_frame_size: DEFAULT_MAX_FRAME_SIZE,
max_message_size: DEFAULT_MAX_MESSAGE_SIZE,
max_decompressed_size: DEFAULT_MAX_DECOMPRESSED_SIZE,
}
}
}
impl ServerConnectionOptions {
    /// Creates options with the default settings (equivalent to `Default::default()`).
    pub fn new() -> Self {
        Self::default()
    }

    /// Adds a subprotocol the server is willing to accept.
    pub fn protocol(mut self, protocol: &str) -> Self {
        self.protocols.push(protocol.to_string());
        self
    }

    /// Enables permessage-deflate negotiation with the given server-side parameters.
    pub fn deflate(mut self, config: PerMessageDeflateConfig) -> Self {
        self.deflate_config = Some(config);
        self
    }

    /// Adds a header to the handshake response built by `accept_handshake_auto`.
    ///
    /// Names colliding with the headers generated by the handshake (`Upgrade`,
    /// `Connection`, `Sec-WebSocket-Accept`, `Sec-WebSocket-Protocol`,
    /// `Sec-WebSocket-Extensions`) cause the handshake acceptance to fail.
    pub fn header(mut self, name: &str, value: &str) -> Self {
        self.additional_headers
            .push((name.to_string(), value.to_string()));
        self
    }

    /// Sets the keep-alive ping interval in milliseconds; `0` disables pings.
    pub fn ping_interval(mut self, millis: u64) -> Self {
        self.ping_interval_millis = millis;
        self
    }

    /// Sets how long to wait for a pong after sending a ping, in milliseconds.
    ///
    /// Previously this could only be changed by assigning the public field.
    pub fn pong_timeout(mut self, millis: u64) -> Self {
        self.pong_timeout_millis = millis;
        self
    }

    /// Sets how long to wait for the closing handshake to finish after sending a
    /// close frame, in milliseconds.
    pub fn close_timeout(mut self, millis: u64) -> Self {
        self.close_timeout_millis = millis;
        self
    }

    /// Sets the maximum payload size in bytes of a single received data frame.
    pub fn max_frame_size(mut self, size: usize) -> Self {
        self.max_frame_size = size;
        self
    }

    /// Sets the maximum received message payload size in bytes (before decompression).
    pub fn max_message_size(mut self, size: usize) -> Self {
        self.max_message_size = size;
        self
    }

    /// Sets the limit passed to the decompressor for each received compressed message.
    pub fn max_decompressed_size(mut self, size: usize) -> Self {
        self.max_decompressed_size = size;
        self
    }
}
/// Reassembly state for a fragmented message that is still in progress.
#[derive(Debug, Default)]
struct FragmentBuffer {
/// Opcode of the first fragment; `None` means no message is being reassembled.
opcode: Option<Opcode>,
/// Payload bytes collected so far (still compressed if `compressed` is set).
payload: Vec<u8>,
/// RSV1 bit of the first fragment, i.e. whether the message is deflate-compressed.
compressed: bool,
}
impl FragmentBuffer {
fn new() -> Self {
Self::default()
}
fn is_empty(&self) -> bool {
self.opcode.is_none()
}
fn len(&self) -> usize {
self.payload.len()
}
fn start(&mut self, opcode: Opcode, payload: Vec<u8>, compressed: bool) {
self.opcode = Some(opcode);
self.payload = payload;
self.compressed = compressed;
}
fn append(&mut self, payload: &[u8]) {
self.payload.extend_from_slice(payload);
}
fn take(&mut self) -> (Opcode, Vec<u8>, bool) {
let opcode = self.opcode.take().unwrap_or(Opcode::Binary);
let payload = std::mem::take(&mut self.payload);
let compressed = self.compressed;
self.compressed = false;
(opcode, payload, compressed)
}
fn clear(&mut self) {
self.opcode = None;
self.payload.clear();
self.compressed = false;
}
}
/// Server side of a single WebSocket connection.
///
/// Bytes read from the transport are passed to `feed_recv_buf`. Results are
/// drained with `poll_event` (application-level events) and `poll_output`
/// (data to send, timers to set or clear, and requests to close the transport).
pub struct WebSocketServerConnection {
/// Current lifecycle state; changed through `set_state`, which emits `StateChanged`.
state: ConnectionState,
options: ServerConnectionOptions,
/// Accumulates and validates the HTTP upgrade request.
handshake_validator: HandshakeRequestValidator,
/// Validated request awaiting `accept_handshake` / `reject_handshake`.
pending_request: Option<ServerHandshakeRequest>,
/// Bytes received after the request but before acceptance, replayed as frames on
/// acceptance; bounded by `MAX_PENDING_FRAME_DATA` while buffering.
pending_frame_data: Vec<u8>,
frame_decoder: FrameDecoder,
/// Reassembly buffer for a fragmented message in progress.
fragment_buffer: FragmentBuffer,
/// Subprotocol from the accepted handshake response, if any.
negotiated_protocol: Option<String>,
/// Extension strings from the accepted handshake response.
negotiated_extensions: Vec<String>,
/// permessage-deflate context when compression was negotiated.
deflate: Option<PerMessageDeflate>,
/// A close frame has been queued to the peer.
close_sent: bool,
/// A close frame has been received from the peer.
close_received: bool,
/// A ping was sent and no pong has been received since.
awaiting_pong: bool,
/// Set when processing received data fails; further input is refused.
failed: bool,
event_queue: VecDeque<ConnectionEvent>,
output_queue: VecDeque<ConnectionOutput>,
}
impl WebSocketServerConnection {
pub fn new(options: ServerConnectionOptions) -> Self {
Self {
state: ConnectionState::Disconnected,
options,
handshake_validator: HandshakeRequestValidator::new(),
pending_request: None,
pending_frame_data: Vec::new(),
frame_decoder: FrameDecoder::new(),
fragment_buffer: FragmentBuffer::new(),
negotiated_protocol: None,
negotiated_extensions: Vec::new(),
deflate: None,
close_sent: false,
close_received: false,
awaiting_pong: false,
failed: false,
event_queue: VecDeque::new(),
output_queue: VecDeque::new(),
}
}
/// Returns the current connection state.
pub fn state(&self) -> ConnectionState {
    self.state
}
/// Returns the subprotocol chosen in the accepted handshake, if any.
pub fn protocol(&self) -> Option<&str> {
    let negotiated = &self.negotiated_protocol;
    negotiated.as_deref()
}
/// Returns the extension strings sent in the accepted handshake response.
pub fn extensions(&self) -> &[String] {
    self.negotiated_extensions.as_slice()
}
/// Returns the client's request while it awaits acceptance or rejection.
pub fn handshake_request(&self) -> Option<&ServerHandshakeRequest> {
    let pending = &self.pending_request;
    pending.as_ref()
}
/// Feeds bytes received from the transport.
///
/// Until the handshake is accepted, the bytes go to the handshake parser (or are
/// buffered once a request is pending). Afterwards they are decoded as frames.
///
/// # Errors
/// Returns an invalid-state error if the connection already failed or is closed.
/// Any processing error marks the connection as failed, so later calls are refused.
pub fn feed_recv_buf(&mut self, buf: &[u8]) -> Result<(), Error> {
    if self.failed {
        return Err(Error::invalid_state("connection has failed"));
    }
    let outcome = match self.state {
        ConnectionState::Closed => {
            return Err(Error::invalid_state("connection is closed"));
        }
        ConnectionState::Connected | ConnectionState::Closing => self.process_frames(buf),
        ConnectionState::Disconnected | ConnectionState::Connecting => {
            self.process_handshake(buf)
        }
    };
    self.failed |= outcome.is_err();
    outcome
}
pub fn accept_handshake_auto(&mut self) -> Result<(), Error> {
let request = self
.pending_request
.as_ref()
.ok_or_else(|| Error::invalid_state("handshake request not available"))?;
let mut response = ServerHandshakeResponse::new();
if let Some(protocol) = self.select_protocol(request) {
response = response.protocol(&protocol);
}
if let Some(config) = self.select_deflate(request) {
response = response.extension(&config.to_extension().encode());
}
for (name, value) in &self.options.additional_headers {
response = response.header(name, value);
}
self.accept_handshake(response)
}
pub fn accept_handshake(&mut self, response: ServerHandshakeResponse) -> Result<(), Error> {
if self.state != ConnectionState::Connecting {
return Err(Error::invalid_state("handshake is not in progress"));
}
let request = self
.pending_request
.take()
.ok_or_else(|| Error::invalid_state("handshake request not available"))?;
if let Some(protocol) = &response.protocol
&& !request.protocols.iter().any(|p| p == protocol)
{
return Err(Error::handshake_rejected(format!(
"unsupported protocol: {}",
protocol
)));
}
for extension in &response.extensions {
let parsed = Extension::parse_strict(extension).map_err(|e| {
Error::handshake_rejected(format!(
"invalid extension response '{}': {}",
extension, e
))
})?;
if parsed.is_empty() {
return Err(Error::handshake_rejected(format!(
"invalid extension response: '{}'",
extension
)));
}
let mut supported = true;
for ext in &parsed {
if request
.extensions
.iter()
.any(|req| Extension::parse(req).iter().any(|e| e.name == ext.name))
{
continue;
} else {
supported = false;
break;
}
}
if !supported {
return Err(Error::handshake_rejected(format!(
"unsupported extension: {}",
extension
)));
}
}
{
let pmce_count: usize = response
.extensions
.iter()
.flat_map(|s| Extension::parse(s))
.filter(|e| e.name == "permessage-deflate")
.count();
if pmce_count > 1 {
return Err(Error::handshake_rejected(
"response contains multiple permessage-deflate elements",
));
}
}
const RESERVED: &[&str] = &[
"upgrade",
"connection",
"sec-websocket-accept",
"sec-websocket-protocol",
"sec-websocket-extensions",
];
for (name, _) in &response.additional_headers {
if RESERVED.contains(&name.to_ascii_lowercase().as_str()) {
return Err(Error::invalid_input(format!(
"additional header '{}' conflicts with a reserved WebSocket header",
name
)));
}
}
let accept = calculate_accept_from_key(&request.key);
let mut response_builder = Response::new(101, "Switching Protocols")
.header("Upgrade", "websocket")
.header("Connection", "Upgrade")
.header("Sec-WebSocket-Accept", &accept);
if let Some(protocol) = &response.protocol {
response_builder = response_builder.header("Sec-WebSocket-Protocol", protocol);
}
if !response.extensions.is_empty() {
response_builder = response_builder
.header("Sec-WebSocket-Extensions", &response.extensions.join(", "));
}
for (name, value) in &response.additional_headers {
response_builder = response_builder.header(name, value);
}
let client_offered_smwb: Option<u8> = request
.extensions
.iter()
.flat_map(|s| Extension::parse(s))
.filter(|e| e.name == "permessage-deflate")
.find_map(|e| {
e.get_param("server_max_window_bits")
.and_then(|p| p.value.as_deref())
.and_then(|v| v.parse::<u8>().ok())
});
let mut deflate = None;
for ext_str in &response.extensions {
let extensions = Extension::parse(ext_str);
for ext in extensions {
if ext.name == "permessage-deflate" {
let config = PerMessageDeflateConfig::from_extension_for_client_response(&ext)
.map_err(|e| {
Error::handshake_rejected(format!(
"invalid permessage-deflate parameters: {:?}",
e
))
})?;
if let Some(smwb) = config.server_max_window_bits
&& smwb < 15
{
return Err(Error::handshake_rejected(format!(
"server_max_window_bits={} is not supported (only 15 is supported)",
smwb
)));
}
if let (Some(smwb), Some(offered)) =
(config.server_max_window_bits, client_offered_smwb)
&& smwb > offered
{
return Err(Error::handshake_rejected(format!(
"server_max_window_bits={} exceeds client offer={}",
smwb, offered
)));
}
if ext.get_param("client_max_window_bits").is_some() {
let client_offered_cmwb = request.extensions.iter().any(|req_ext_str| {
Extension::parse(req_ext_str).iter().any(|req_ext| {
req_ext.name == "permessage-deflate"
&& req_ext.get_param("client_max_window_bits").is_some()
})
});
if !client_offered_cmwb {
return Err(Error::handshake_rejected(
"client_max_window_bits included without client offer",
));
}
}
deflate = Some(PerMessageDeflate::new_server(config));
}
}
}
let encoded = response_builder.encode();
self.output_queue
.push_back(ConnectionOutput::SendData(encoded));
self.deflate = deflate;
self.negotiated_protocol = response.protocol.clone();
self.negotiated_extensions = response.extensions.clone();
self.set_state(ConnectionState::Connected);
self.event_queue.push_back(ConnectionEvent::Connected {
protocol: self.negotiated_protocol.clone(),
extensions: self.negotiated_extensions.clone(),
});
if self.options.ping_interval_millis > 0 {
self.output_queue.push_back(ConnectionOutput::SetTimer {
id: TimerId::Ping,
duration_millis: self.options.ping_interval_millis,
});
}
if !self.pending_frame_data.is_empty() {
let pending = std::mem::take(&mut self.pending_frame_data);
self.process_frames(&pending)?;
}
self.handshake_validator.reset();
Ok(())
}
/// Rejects the pending handshake with an HTTP error response.
///
/// Queues a response with `status_code`, `reason`, `Connection: close` and the given
/// `headers`. Then moves to `Closed` and asks the caller to close the transport.
/// Any pending request and buffered bytes are discarded.
///
/// # Errors
/// Returns an invalid-state error if no handshake is in progress.
pub fn reject_handshake(
    &mut self,
    status_code: u16,
    reason: &str,
    headers: &[(&str, &str)],
) -> Result<(), Error> {
    if self.state != ConnectionState::Connecting {
        return Err(Error::invalid_state("handshake is not in progress"));
    }
    self.pending_request = None;
    self.pending_frame_data.clear();
    self.handshake_validator.reset();

    let response = headers.iter().fold(
        Response::new(status_code, reason).header("Connection", "close"),
        |resp, (name, value)| resp.header(name, value),
    );
    let bytes = response.encode();
    self.output_queue.push_back(ConnectionOutput::SendData(bytes));
    self.set_state(ConnectionState::Closed);
    self.output_queue.push_back(ConnectionOutput::CloseConnection);
    Ok(())
}
/// Queues a text message. It is compressed when permessage-deflate is active and
/// worthwhile.
///
/// # Errors
/// Returns an invalid-state error unless the connection is `Connected`.
pub fn send_text(&mut self, text: &str) -> Result<(), Error> {
    self.check_connected()?;
    let bytes = text.as_bytes().to_owned();
    self.send_data_frame(Opcode::Text, bytes)
}
/// Queues a binary message. It is compressed when permessage-deflate is active and
/// worthwhile.
///
/// # Errors
/// Returns an invalid-state error unless the connection is `Connected`.
pub fn send_binary(&mut self, data: &[u8]) -> Result<(), Error> {
    self.check_connected()?;
    let bytes = Vec::from(data);
    self.send_data_frame(Opcode::Binary, bytes)
}
/// Builds a data frame (compressing the payload if enabled) and queues it.
fn send_data_frame(&mut self, opcode: Opcode, payload: Vec<u8>) -> Result<(), Error> {
    let (body, is_compressed) = self.compress_if_enabled(payload)?;
    let mut frame = Frame::new(opcode, body);
    // RSV1 marks a permessage-deflate compressed message.
    frame.rsv1 = is_compressed;
    self.send_frame(frame);
    Ok(())
}
/// Compresses `payload` when deflate is negotiated and the deflate context judges
/// it worthwhile (using a 64-byte threshold).
///
/// Returns the bytes to send and whether they are compressed (the RSV1 value).
fn compress_if_enabled(&mut self, payload: Vec<u8>) -> Result<(Vec<u8>, bool), Error> {
    const COMPRESSION_THRESHOLD: usize = 64;
    let Some(deflate) = self.deflate.as_mut() else {
        return Ok((payload, false));
    };
    if !deflate.should_compress(&payload, COMPRESSION_THRESHOLD) {
        return Ok((payload, false));
    }
    let packed = deflate.compress(&payload)?;
    Ok((packed, true))
}
/// Sends a ping with `data` and arms the pong-timeout timer.
///
/// # Errors
/// Returns an invalid-state error unless `Connected`, or an error if `data` cannot
/// form a ping frame.
pub fn send_ping(&mut self, data: &[u8]) -> Result<(), Error> {
    self.check_connected()?;
    let ping = Frame::ping(data.to_vec())?;
    self.send_frame(ping);
    self.awaiting_pong = true;
    let duration_millis = self.options.pong_timeout_millis;
    self.output_queue.push_back(ConnectionOutput::SetTimer {
        id: TimerId::PongTimeout,
        duration_millis,
    });
    Ok(())
}
/// Starts the closing handshake with `code` and `reason`.
///
/// Sends a close frame once, arms the close timeout and moves to `Closing`. Calling
/// it again after a close frame was sent is a no-op.
///
/// # Errors
/// Fails when the connection is not `Connected`/`Closing`, when `code` may not be
/// sent, or when the close frame cannot be built.
pub fn close(&mut self, code: CloseCode, reason: &str) -> Result<(), Error> {
    let established = matches!(
        self.state,
        ConnectionState::Connected | ConnectionState::Closing
    );
    if !established {
        return Err(Error::invalid_state("connection is not established"));
    }
    if !code.is_sendable() {
        let raw = code.as_u16();
        return Err(Error::invalid_input(format!(
            "close code {} is not sendable",
            raw
        )));
    }
    if self.close_sent {
        return Ok(());
    }
    self.send_frame(Frame::close(Some(code.as_u16()), reason)?);
    self.close_sent = true;
    let duration_millis = self.options.close_timeout_millis;
    self.output_queue.push_back(ConnectionOutput::SetTimer {
        id: TimerId::CloseTimeout,
        duration_millis,
    });
    self.set_state(ConnectionState::Closing);
    Ok(())
}
/// Best-effort close used when the connection must be failed.
///
/// Unless already `Disconnected` or `Closed`, and only if no close frame was sent
/// yet, it does three things:
/// - queues a close frame with `code` and `reason`, truncating the reason to fit;
/// - arms the close timeout;
/// - moves to `Closing`.
///
/// Never fails. If the frame cannot be built with the reason, an empty reason is used.
fn close_internal(&mut self, code: CloseCode, reason: &str) {
    if matches!(
        self.state,
        ConnectionState::Disconnected | ConnectionState::Closed
    ) {
        return;
    }
    if self.close_sent {
        return;
    }
    // A close payload holds at most 125 bytes, 2 of which carry the code, leaving
    // 123 for the reason. Cut on a char boundary: slicing inside a multi-byte UTF-8
    // sequence (as the previous `&reason[..123]` could) panics.
    let mut end = reason.len().min(123);
    while !reason.is_char_boundary(end) {
        end -= 1;
    }
    let truncated_reason = &reason[..end];
    let frame = Frame::close(Some(code.as_u16()), truncated_reason).unwrap_or_else(|_| {
        Frame::close(Some(code.as_u16()), "")
            .expect("close frame with an empty reason must be constructible")
    });
    self.send_frame(frame);
    self.close_sent = true;
    self.output_queue.push_back(ConnectionOutput::SetTimer {
        id: TimerId::CloseTimeout,
        duration_millis: self.options.close_timeout_millis,
    });
    self.set_state(ConnectionState::Closing);
}
/// Handles expiry of a timer requested earlier via `ConnectionOutput::SetTimer`.
///
/// - `Ping`: sends a keep-alive ping unless one is still unanswered. Re-arms itself
///   while `Connected` and `ping_interval_millis > 0`.
/// - `PongTimeout`: if a ping is still unanswered and the connection is open,
///   reports `ConnectionEvent::Error("pong timeout")` and closes with
///   `POLICY_VIOLATION`.
/// - `CloseTimeout`: if the closing handshake is still in progress, moves to
///   `Closed` and asks the caller to close the transport.
///
/// # Errors
/// Propagates failures from sending the ping or close frame.
pub fn handle_timer(&mut self, timer_id: TimerId) -> Result<(), Error> {
    match timer_id {
        TimerId::Ping => {
            if self.state == ConnectionState::Connected && !self.awaiting_pong {
                self.send_ping(&[])?;
            }
            if self.state == ConnectionState::Connected && self.options.ping_interval_millis > 0
            {
                self.output_queue.push_back(ConnectionOutput::SetTimer {
                    id: TimerId::Ping,
                    duration_millis: self.options.ping_interval_millis,
                });
            }
        }
        TimerId::PongTimeout => {
            // The CloseTimeout path reaches `Closed` without clearing this timer, so a
            // stale timeout can still fire. `close()` would then fail with an
            // invalid-state error, so only act while the connection is still open.
            let open = matches!(
                self.state,
                ConnectionState::Connected | ConnectionState::Closing
            );
            if self.awaiting_pong && open {
                self.event_queue
                    .push_back(ConnectionEvent::Error("pong timeout".to_string()));
                self.close(CloseCode::POLICY_VIOLATION, "pong timeout")?;
            }
        }
        TimerId::CloseTimeout => {
            if self.state == ConnectionState::Closing {
                self.set_state(ConnectionState::Closed);
                self.output_queue
                    .push_back(ConnectionOutput::CloseConnection);
            }
        }
    }
    Ok(())
}
/// Pops the oldest pending application event, if any.
pub fn poll_event(&mut self) -> Option<ConnectionEvent> {
    let queue = &mut self.event_queue;
    queue.pop_front()
}
/// Pops the oldest pending output action (data to send, timer change, or close
/// request), if any.
pub fn poll_output(&mut self) -> Option<ConnectionOutput> {
    let queue = &mut self.output_queue;
    queue.pop_front()
}
/// Transitions to `new_state`, emitting `StateChanged` only on an actual change.
fn set_state(&mut self, new_state: ConnectionState) {
    if self.state == new_state {
        return;
    }
    self.state = new_state;
    self.event_queue
        .push_back(ConnectionEvent::StateChanged(new_state));
}
fn check_connected(&self) -> Result<(), Error> {
if self.state != ConnectionState::Connected {
return Err(Error::invalid_state("not connected"));
}
Ok(())
}
/// Encodes `frame` without a mask (server-to-client frames) and queues the bytes.
fn send_frame(&mut self, frame: Frame) {
    let wire_bytes = frame.encode_unmasked();
    self.output_queue.push_back(ConnectionOutput::SendData(wire_bytes));
}
/// Handles bytes received before the handshake is accepted.
///
/// - While a request is pending, bytes are buffered (up to `MAX_PENDING_FRAME_DATA`)
///   for replay after acceptance.
/// - Otherwise they are fed to the validator. A complete request becomes pending,
///   and any bytes after it are buffered.
/// - An unsupported version is answered with `426` and `Sec-WebSocket-Version: 13`.
///   Any other validation error is answered with `400`.
fn process_handshake(&mut self, buf: &[u8]) -> Result<(), Error> {
    if self.pending_request.is_some() {
        let buffered = self.pending_frame_data.len() + buf.len();
        if buffered > MAX_PENDING_FRAME_DATA {
            return Err(Error::protocol_violation(
                "pending frame data exceeds limit while awaiting handshake acceptance",
            ));
        }
        self.pending_frame_data.extend_from_slice(buf);
        return Ok(());
    }
    if self.state == ConnectionState::Disconnected {
        self.set_state(ConnectionState::Connecting);
    }
    self.handshake_validator.feed(buf);
    match self.handshake_validator.validate() {
        Ok(None) => Ok(()),
        Ok(Some(request)) => {
            self.pending_request = Some(request);
            let leftover = self.handshake_validator.remaining();
            self.pending_frame_data.extend_from_slice(leftover);
            Ok(())
        }
        Err(e) if e.kind == ErrorKind::VersionNotSupported => {
            self.reject_handshake(426, "Upgrade Required", &[("Sec-WebSocket-Version", "13")])?;
            Err(e)
        }
        Err(e) => {
            self.reject_handshake(400, "Bad Request", &[])?;
            Err(e)
        }
    }
}
/// Decodes and handles every complete frame available after appending `buf`.
///
/// A decode error fails the connection with `PROTOCOL_ERROR`.
fn process_frames(&mut self, buf: &[u8]) -> Result<(), Error> {
    self.frame_decoder.feed(buf);
    loop {
        let decoded = match self.frame_decoder.decode_with_info() {
            Ok(Some(frame)) => frame,
            Ok(None) => return Ok(()),
            Err(e) => {
                self.close_internal(CloseCode::PROTOCOL_ERROR, "frame decode error");
                return Err(e);
            }
        };
        self.handle_decoded_frame(decoded)?;
    }
}
fn handle_decoded_frame(&mut self, decoded: DecodedFrame) -> Result<(), Error> {
if !decoded.masked {
self.close_internal(CloseCode::PROTOCOL_ERROR, "unmasked client frame");
return Err(Error::protocol_violation("unmasked client frame"));
}
self.handle_frame(decoded.frame)
}
fn handle_frame(&mut self, frame: Frame) -> Result<(), Error> {
if !frame.opcode.is_control() && frame.payload.len() > self.options.max_frame_size {
self.close_internal(CloseCode::MESSAGE_TOO_BIG, "frame payload too large");
return Err(Error::protocol_violation("frame payload too large"));
}
if frame.rsv2 || frame.rsv3 {
self.close_internal(CloseCode::PROTOCOL_ERROR, "reserved bits set");
return Err(Error::protocol_violation("reserved bits set"));
}
if frame.rsv1 {
if self.deflate.is_none() {
self.close_internal(
CloseCode::PROTOCOL_ERROR,
"rsv1 set without permessage-deflate",
);
return Err(Error::protocol_violation(
"rsv1 set without permessage-deflate",
));
}
if frame.opcode.is_control() {
self.close_internal(
CloseCode::PROTOCOL_ERROR,
"rsv1 must not be set on control frames",
);
return Err(Error::protocol_violation(
"rsv1 must not be set on control frames",
));
}
if frame.opcode == Opcode::Continuation {
self.close_internal(
CloseCode::PROTOCOL_ERROR,
"rsv1 must not be set on continuation frames",
);
return Err(Error::protocol_violation(
"rsv1 must not be set on continuation frames",
));
}
}
match frame.opcode {
Opcode::Continuation => self.handle_continuation(frame)?,
Opcode::Text | Opcode::Binary => self.handle_data_frame(frame)?,
Opcode::Close => self.handle_close(frame)?,
Opcode::Ping => self.handle_ping(frame)?,
Opcode::Pong => self.handle_pong(frame)?,
}
Ok(())
}
/// Handles the first (or only) frame of a text/binary message.
///
/// A final frame is decompressed if needed and emitted immediately. A non-final
/// frame starts reassembly in the fragment buffer.
///
/// # Errors
/// - `PROTOCOL_ERROR` if a fragmented message is already in progress.
/// - `MESSAGE_TOO_BIG` if the (pre-decompression) payload exceeds `max_message_size`.
/// - Any decompression or emission error.
fn handle_data_frame(&mut self, frame: Frame) -> Result<(), Error> {
    if !self.fragment_buffer.is_empty() {
        self.close_internal(
            CloseCode::PROTOCOL_ERROR,
            "new message started before previous completed",
        );
        return Err(Error::protocol_violation(
            "new message started before previous completed",
        ));
    }
    // Apply the message limit to unfragmented messages as well. Previously only
    // fragmented messages were checked, so a single frame bigger than
    // `max_message_size` slipped through whenever `max_frame_size` was larger.
    if frame.payload.len() > self.options.max_message_size {
        self.close_internal(CloseCode::MESSAGE_TOO_BIG, "message too large");
        return Err(Error::protocol_violation("message too large"));
    }
    if frame.fin {
        let payload = self.decompress_if_needed(frame.payload, frame.rsv1)?;
        self.emit_message(frame.opcode, payload)?;
    } else {
        self.fragment_buffer
            .start(frame.opcode, frame.payload, frame.rsv1);
    }
    Ok(())
}
/// Handles a continuation frame of a fragmented message.
///
/// Appends the payload to the reassembly buffer. On the final fragment, the message
/// is decompressed (if its first frame had RSV1 set) and emitted.
///
/// # Errors
/// - `PROTOCOL_ERROR` when no fragmented message is in progress.
/// - `MESSAGE_TOO_BIG` when the reassembled (pre-decompression) size would exceed
///   `max_message_size`.
/// - Any decompression or emission error.
fn handle_continuation(&mut self, frame: Frame) -> Result<(), Error> {
    if self.fragment_buffer.is_empty() {
        self.close_internal(
            CloseCode::PROTOCOL_ERROR,
            "continuation frame without initial frame",
        );
        return Err(Error::protocol_violation(
            "continuation frame without initial frame",
        ));
    }
    // Check before appending so an over-limit fragment is never copied into the buffer.
    let total = self
        .fragment_buffer
        .len()
        .saturating_add(frame.payload.len());
    if total > self.options.max_message_size {
        self.close_internal(CloseCode::MESSAGE_TOO_BIG, "message too large");
        return Err(Error::protocol_violation("message too large"));
    }
    self.fragment_buffer.append(&frame.payload);
    if frame.fin {
        let (opcode, payload, compressed) = self.fragment_buffer.take();
        let payload = self.decompress_if_needed(payload, compressed)?;
        self.emit_message(opcode, payload)?;
    }
    Ok(())
}
/// Inflates `payload` when `compressed` is set, bounded by `max_decompressed_size`.
///
/// Uncompressed payloads are returned unchanged. A compressed payload without an
/// active deflate context fails the connection with `PROTOCOL_ERROR`.
// NOTE(review): a decompression error itself is returned without queuing a close
// frame, unlike the other receive-side failures — confirm this is intended.
fn decompress_if_needed(
    &mut self,
    payload: Vec<u8>,
    compressed: bool,
) -> Result<Vec<u8>, Error> {
    if !compressed {
        return Ok(payload);
    }
    let limit = self.options.max_decompressed_size;
    match self.deflate.as_mut() {
        Some(deflate) => deflate.decompress(&payload, limit),
        None => {
            self.close_internal(
                CloseCode::PROTOCOL_ERROR,
                "received compressed frame without permessage-deflate",
            );
            Err(Error::protocol_violation(
                "received compressed frame without permessage-deflate",
            ))
        }
    }
}
/// Turns a complete message into a `TextMessage` or `BinaryMessage` event.
///
/// Text that is not valid UTF-8 produces an `Error` event and closes with
/// `INVALID_PAYLOAD`. Other opcodes are ignored.
fn emit_message(&mut self, opcode: Opcode, payload: Vec<u8>) -> Result<(), Error> {
    let event = match opcode {
        Opcode::Binary => ConnectionEvent::BinaryMessage(payload),
        Opcode::Text => match String::from_utf8(payload) {
            Ok(text) => ConnectionEvent::TextMessage(text),
            Err(utf8_err) => {
                self.event_queue.push_back(ConnectionEvent::Error(format!(
                    "invalid UTF-8 in text message: {}",
                    utf8_err
                )));
                self.close(CloseCode::INVALID_PAYLOAD, "invalid UTF-8")?;
                return Err(Error::protocol_violation("invalid UTF-8 in text message"));
            }
        },
        _ => return Ok(()),
    };
    self.event_queue.push_back(event);
    Ok(())
}
fn handle_close(&mut self, frame: Frame) -> Result<(), Error> {
self.close_received = true;
if frame.payload.len() == 1 {
self.close_internal(
CloseCode::PROTOCOL_ERROR,
"close frame payload length must be 0 or >= 2",
);
return Err(Error::protocol_violation(
"close frame payload length must be 0 or >= 2",
));
}
let (code, reason) = if frame.payload.len() >= 2 {
let code_val = u16::from_be_bytes([frame.payload[0], frame.payload[1]]);
let close_code = CloseCode::new(code_val);
if !close_code.is_valid() {
self.close_internal(
CloseCode::PROTOCOL_ERROR,
&format!("invalid close code: {}", code_val),
);
return Err(Error::protocol_violation(format!(
"invalid close code: {}",
code_val
)));
}
let reason = match String::from_utf8(frame.payload[2..].to_vec()) {
Ok(r) => r,
Err(_) => {
self.close_internal(
CloseCode::PROTOCOL_ERROR,
"close frame reason is not valid UTF-8",
);
return Err(Error::protocol_violation(
"close frame reason is not valid UTF-8",
));
}
};
(Some(close_code), reason)
} else {
(None, String::new())
};
self.event_queue
.push_back(ConnectionEvent::Close { code, reason });
if !self.close_sent {
let reply_code = code
.filter(|c| c.is_sendable())
.map(|c| c.as_u16())
.unwrap_or(1000);
let reply_frame = Frame::close(Some(reply_code), "")?;
self.send_frame(reply_frame);
self.close_sent = true;
}
self.awaiting_pong = false;
self.output_queue.push_back(ConnectionOutput::ClearTimer {
id: TimerId::PongTimeout,
});
self.output_queue
.push_back(ConnectionOutput::ClearTimer { id: TimerId::Ping });
self.output_queue.push_back(ConnectionOutput::ClearTimer {
id: TimerId::CloseTimeout,
});
self.set_state(ConnectionState::Closed);
self.output_queue
.push_back(ConnectionOutput::CloseConnection);
self.frame_decoder.clear();
self.fragment_buffer.clear();
Ok(())
}
/// Reports a ping and answers with a pong carrying the same payload, unless the
/// peer's close frame has already been received.
fn handle_ping(&mut self, frame: Frame) -> Result<(), Error> {
    let payload = frame.payload;
    self.event_queue
        .push_back(ConnectionEvent::Ping(payload.clone()));
    if self.close_received {
        return Ok(());
    }
    self.send_frame(Frame::pong(payload)?);
    Ok(())
}
/// Handles a pong: any pong (the payload is not compared) clears the outstanding
/// ping and its timeout, and is reported as an event.
fn handle_pong(&mut self, frame: Frame) -> Result<(), Error> {
    self.awaiting_pong = false;
    let cancel = ConnectionOutput::ClearTimer {
        id: TimerId::PongTimeout,
    };
    self.output_queue.push_back(cancel);
    let payload = frame.payload;
    self.event_queue.push_back(ConnectionEvent::Pong(payload));
    Ok(())
}
/// Picks the first subprotocol in the client's preference order that the server
/// also supports.
fn select_protocol(&self, request: &ServerHandshakeRequest) -> Option<String> {
    request
        .protocols
        .iter()
        .find(|&offered| self.options.protocols.contains(offered))
        .cloned()
}
/// Negotiates permessage-deflate against the first acceptable client offer.
///
/// Returns `None` when compression is not configured or no offer is acceptable.
/// Offers are skipped when their parameters fail to parse, or when they ask for
/// `server_max_window_bits` below 15 (only 15 is supported).
fn select_deflate(&self, request: &ServerHandshakeRequest) -> Option<PerMessageDeflateConfig> {
    let server_config = self.options.deflate_config.clone()?;
    request
        .extensions
        .iter()
        .flat_map(|s| Extension::parse(s))
        .filter(|ext| ext.name == "permessage-deflate")
        .filter_map(|ext| PerMessageDeflateConfig::from_extension_for_server_request(&ext).ok())
        .find(|offer| offer.server_max_window_bits.is_none_or(|bits| bits >= 15))
        .map(|offer| PerMessageDeflateConfig::negotiate(&offer, &server_config))
}
}