use crate::deflate::PerMessageDeflate;
use crate::error::{Error, ErrorKind};
use crate::frame_policy::ServerFramePolicy;
use crate::websocket_close::CloseCode;
use crate::websocket_connection_shared::{
DEFAULT_MAX_DECOMPRESSED_SIZE, DEFAULT_MAX_FRAME_SIZE, DEFAULT_MAX_MESSAGE_SIZE,
SharedConnectionState,
};
use crate::websocket_connection_types::{
ConnectionEvent, ConnectionOutput, ConnectionState, TimerId,
};
use crate::websocket_extension::{Extension, PerMessageDeflateConfig};
use crate::websocket_handshake::calculate_accept_from_key;
use crate::websocket_handshake_request::{HandshakeRequestValidator, ServerHandshakeRequest};
use crate::websocket_handshake_response::ServerHandshakeResponse;
use shiguredo_http11::{HeaderName, Response};
/// Upper bound (1 MiB) on bytes buffered after the client's handshake request while the
/// application has not yet accepted or rejected it.
const MAX_PENDING_FRAME_DATA: usize = 1 << 20;
/// Configuration for a [`WebSocketServerConnection`], built with the chained setters on
/// this type.
#[derive(Debug, Clone)]
pub struct ServerConnectionOptions {
    /// Subprotocols the server is willing to select, matched against the client's offer.
    protocols: Vec<String>,
    /// permessage-deflate settings; `None` means compression is never negotiated
    /// automatically.
    deflate_config: Option<PerMessageDeflateConfig>,
    /// Extra `(name, value)` headers added to automatically built handshake responses.
    additional_headers: Vec<(String, String)>,
    /// Ping interval in milliseconds.
    ping_interval_millis: u64,
    /// Pong timeout in milliseconds.
    pong_timeout_millis: u64,
    /// Close-handshake timeout in milliseconds.
    close_timeout_millis: u64,
    /// Maximum accepted frame size.
    max_frame_size: usize,
    /// Maximum accepted message size.
    max_message_size: usize,
    /// Maximum size of a message after decompression.
    max_decompressed_size: usize,
}
impl Default for ServerConnectionOptions {
    /// No subprotocols, no compression and no extra headers. Timers are a 30 s ping
    /// interval, a 10 s pong timeout and a 5 s close timeout. Size limits use the shared
    /// `DEFAULT_*` constants.
    fn default() -> Self {
        let (ping_interval_millis, pong_timeout_millis, close_timeout_millis) =
            (30_000, 10_000, 5_000);
        Self {
            protocols: vec![],
            deflate_config: None,
            additional_headers: vec![],
            ping_interval_millis,
            pong_timeout_millis,
            close_timeout_millis,
            max_frame_size: DEFAULT_MAX_FRAME_SIZE,
            max_message_size: DEFAULT_MAX_MESSAGE_SIZE,
            max_decompressed_size: DEFAULT_MAX_DECOMPRESSED_SIZE,
        }
    }
}
impl ServerConnectionOptions {
    /// Returns the default options (see the `Default` impl).
    pub fn new() -> Self {
        Default::default()
    }

    /// Adds `protocol` to the subprotocols the server is willing to select.
    pub fn protocol(mut self, protocol: &str) -> Self {
        self.protocols.push(protocol.to_owned());
        self
    }

    /// Enables automatic permessage-deflate negotiation using `config`.
    pub fn deflate(self, config: PerMessageDeflateConfig) -> Self {
        Self {
            deflate_config: Some(config),
            ..self
        }
    }

    /// Adds a header to automatically built handshake responses.
    pub fn header(mut self, name: &str, value: &str) -> Self {
        let entry = (name.to_owned(), value.to_owned());
        self.additional_headers.push(entry);
        self
    }

    /// Sets the ping interval in milliseconds.
    ///
    /// With `0`, accepting the handshake does not arm the initial ping timer.
    pub fn ping_interval(self, millis: u64) -> Self {
        Self {
            ping_interval_millis: millis,
            ..self
        }
    }

    /// Sets the pong timeout in milliseconds.
    pub fn pong_timeout(self, millis: u64) -> Self {
        Self {
            pong_timeout_millis: millis,
            ..self
        }
    }

    /// Sets the close-handshake timeout in milliseconds.
    pub fn close_timeout(self, millis: u64) -> Self {
        Self {
            close_timeout_millis: millis,
            ..self
        }
    }

    /// Sets the maximum accepted frame size.
    pub fn max_frame_size(self, size: usize) -> Self {
        Self {
            max_frame_size: size,
            ..self
        }
    }

    /// Sets the maximum accepted message size.
    pub fn max_message_size(self, size: usize) -> Self {
        Self {
            max_message_size: size,
            ..self
        }
    }

    /// Sets the maximum size of a message after decompression.
    pub fn max_decompressed_size(self, size: usize) -> Self {
        Self {
            max_decompressed_size: size,
            ..self
        }
    }
}
/// Server side of a WebSocket connection, driven by feeding received bytes in and
/// polling events and outputs out.
pub struct WebSocketServerConnection {
    /// Connection machinery shared with other roles: frame processing, state, and
    /// event/output queues.
    shared: SharedConnectionState,
    /// Server-side frame policy handed to the shared frame and send paths.
    policy: ServerFramePolicy,
    /// Options this connection was created with.
    options: ServerConnectionOptions,
    /// Accumulates and validates the client's opening handshake bytes.
    handshake_validator: HandshakeRequestValidator,
    /// Parsed handshake request awaiting `accept_handshake` / `reject_handshake`.
    pending_request: Option<ServerHandshakeRequest>,
    /// Bytes received after the handshake request, replayed once it is accepted.
    pending_frame_data: Vec<u8>,
    /// Subprotocol selected in the accepted handshake response.
    negotiated_protocol: Option<String>,
    /// Extension strings sent in the accepted handshake response.
    negotiated_extensions: Vec<String>,
}
impl WebSocketServerConnection {
    /// Creates a new server-side connection configured by `options`.
    ///
    /// Size limits and timer durations are forwarded to the shared connection state.
    /// The options are also kept for handshake negotiation.
    pub fn new(options: ServerConnectionOptions) -> Self {
        let shared = SharedConnectionState::new(
            options.max_frame_size,
            options.max_message_size,
            options.max_decompressed_size,
            options.ping_interval_millis,
            options.pong_timeout_millis,
            options.close_timeout_millis,
        );
        Self {
            shared,
            policy: ServerFramePolicy,
            options,
            handshake_validator: HandshakeRequestValidator::new(),
            pending_request: None,
            pending_frame_data: Vec::new(),
            negotiated_protocol: None,
            negotiated_extensions: Vec::new(),
        }
    }

    /// Returns the current connection state.
    pub fn state(&self) -> ConnectionState {
        self.shared.state()
    }

    /// Returns the subprotocol chosen in the accepted handshake, if any.
    pub fn protocol(&self) -> Option<&str> {
        self.negotiated_protocol.as_deref()
    }

    /// Returns the extension strings sent in the accepted handshake response.
    pub fn extensions(&self) -> &[String] {
        &self.negotiated_extensions
    }

    /// Returns the parsed client handshake request while it awaits a decision.
    pub fn handshake_request(&self) -> Option<&ServerHandshakeRequest> {
        self.pending_request.as_ref()
    }

    /// Feeds bytes received from the transport into the connection.
    ///
    /// Before the handshake is accepted, the bytes go to the handshake parser, or are
    /// buffered once a request is pending. Afterwards they are processed as frames.
    ///
    /// # Errors
    ///
    /// Fails if the connection has already failed or is closed. If handshake or frame
    /// processing fails, the connection is marked failed and the error is returned.
    pub fn feed_recv_buf(&mut self, buf: &[u8]) -> Result<(), Error> {
        if self.shared.is_failed() {
            return Err(Error::invalid_state("connection has failed"));
        }
        let result = match self.shared.state() {
            ConnectionState::Disconnected | ConnectionState::Connecting => {
                self.process_handshake(buf)
            }
            ConnectionState::Connected | ConnectionState::Closing => {
                self.shared.process_frames(buf, &mut self.policy)
            }
            ConnectionState::Closed => {
                return Err(Error::invalid_state("connection is closed"));
            }
        };
        if result.is_err() {
            self.shared.mark_failed();
        }
        result
    }

    /// Accepts the pending handshake with an automatically built response.
    ///
    /// The response contains:
    /// - the first client-offered subprotocol that is also in the options;
    /// - permessage-deflate, negotiated from the first acceptable client offer, when
    ///   deflate is configured;
    /// - the additional headers from the options.
    ///
    /// # Errors
    ///
    /// Fails if no handshake request is pending, or with any error from
    /// [`Self::accept_handshake`].
    pub fn accept_handshake_auto(&mut self) -> Result<(), Error> {
        let request = self
            .pending_request
            .as_ref()
            .ok_or_else(|| Error::invalid_state("handshake request not available"))?;
        let mut response = ServerHandshakeResponse::new();
        if let Some(protocol) = self.select_protocol(request) {
            response = response.protocol(&protocol);
        }
        if let Some(config) = self.select_deflate(request) {
            response = response.extension(&config.to_extension().encode());
        }
        for (name, value) in &self.options.additional_headers {
            response = response.header(name, value);
        }
        self.accept_handshake(response)
    }

    /// Accepts the pending handshake with `response`.
    ///
    /// On success this:
    /// - queues the `101 Switching Protocols` bytes;
    /// - enables deflate if negotiated;
    /// - moves to `Connected` and emits `ConnectionEvent::Connected`;
    /// - arms the ping timer when the interval is non-zero;
    /// - replays any frame data that arrived while the request was pending.
    ///
    /// The response is fully validated before any state changes. If it is rejected, the
    /// pending request is kept, so the caller can retry with a corrected response or
    /// call [`Self::reject_handshake`].
    ///
    /// # Errors
    ///
    /// - Fails if no handshake is in progress or no request is pending.
    /// - Fails if the response names a protocol or extension the client did not offer,
    ///   carries invalid deflate parameters, or adds a reserved header.
    /// - Fails if replaying buffered frame data fails. In that case the connection is
    ///   marked failed, as it would be had the data come through `feed_recv_buf`.
    pub fn accept_handshake(&mut self, response: ServerHandshakeResponse) -> Result<(), Error> {
        if self.shared.state() != ConnectionState::Connecting {
            return Err(Error::invalid_state("handshake is not in progress"));
        }
        let request = self
            .pending_request
            .as_ref()
            .ok_or_else(|| Error::invalid_state("handshake request not available"))?;
        // Everything fallible is done against a borrowed request first, so a bad
        // response does not discard it.
        Self::validate_handshake_response(request, &response)?;
        let encoded = Self::build_handshake_response(request, &response)?;
        let deflate = Self::negotiate_deflate(request, &response)?;

        // Commit point: the handshake is being accepted.
        self.pending_request = None;
        self.shared
            .enqueue_output(ConnectionOutput::SendData(encoded));
        if let Some(deflate) = deflate {
            self.shared.enable_deflate(deflate);
        }
        self.negotiated_protocol = response.protocol.clone();
        self.negotiated_extensions = response.extensions.clone();
        self.shared.set_state(ConnectionState::Connected)?;
        self.shared.emit_event(ConnectionEvent::Connected {
            protocol: self.negotiated_protocol.clone(),
            extensions: self.negotiated_extensions.clone(),
        });
        if self.options.ping_interval_millis > 0 {
            self.shared.enqueue_output(ConnectionOutput::SetTimer {
                id: TimerId::Ping,
                duration_millis: self.options.ping_interval_millis,
            });
        }
        // The validator is done; reset it before replaying so an error below cannot
        // skip the reset.
        self.handshake_validator.reset();
        if !self.pending_frame_data.is_empty() {
            let pending = std::mem::take(&mut self.pending_frame_data);
            if let Err(e) = self.shared.process_frames(&pending, &mut self.policy) {
                // Same treatment as a frame error arriving through `feed_recv_buf`.
                self.shared.mark_failed();
                return Err(e);
            }
        }
        Ok(())
    }

    /// Checks that `response` is consistent with what the client offered in `request`.
    ///
    /// The checks are:
    /// - the protocol, if any, must be one the client offered;
    /// - each extension string must parse strictly and name only extensions the client
    ///   offered;
    /// - there is at most one permessage-deflate element;
    /// - no additional header collides with a header this module generates itself.
    fn validate_handshake_response(
        request: &ServerHandshakeRequest,
        response: &ServerHandshakeResponse,
    ) -> Result<(), Error> {
        if let Some(protocol) = &response.protocol
            && !request.protocols.iter().any(|p| p == protocol)
        {
            return Err(Error::handshake_rejected(format!(
                "unsupported protocol: {}",
                protocol
            )));
        }
        for extension in &response.extensions {
            let parsed = Extension::parse_strict(extension).map_err(|e| {
                Error::handshake_rejected(format!("invalid extension response '{extension}': {e}"))
            })?;
            if parsed.is_empty() {
                return Err(Error::handshake_rejected(format!(
                    "invalid extension response: '{}'",
                    extension
                )));
            }
            // Every extension named in this response element must appear, by name, in
            // at least one of the client's offers.
            let supported = parsed.iter().all(|ext| {
                request
                    .extensions
                    .iter()
                    .any(|req| Extension::parse(req).iter().any(|e| e.name == ext.name))
            });
            if !supported {
                return Err(Error::handshake_rejected(format!(
                    "unsupported extension: {}",
                    extension
                )));
            }
        }
        let pmce_count = response
            .extensions
            .iter()
            .flat_map(|s| Extension::parse(s))
            .filter(|e| e.name == "permessage-deflate")
            .count();
        if pmce_count > 1 {
            return Err(Error::handshake_rejected(
                "response contains multiple permessage-deflate elements",
            ));
        }
        // These headers are emitted by `build_handshake_response`; letting callers add
        // them again would produce duplicate or contradictory values.
        const RESERVED: &[&str] = &[
            "upgrade",
            "connection",
            "sec-websocket-accept",
            "sec-websocket-protocol",
            "sec-websocket-extensions",
        ];
        for (name, _) in &response.additional_headers {
            if RESERVED.contains(&name.to_ascii_lowercase().as_str()) {
                return Err(Error::invalid_input(format!(
                    "additional header '{}' conflicts with a reserved WebSocket header",
                    name
                )));
            }
        }
        Ok(())
    }

    /// Builds the server-side deflate context from the permessage-deflate element in
    /// `response`, if there is one.
    ///
    /// # Errors
    ///
    /// Rejects the handshake if the element:
    /// - has invalid parameters;
    /// - requests `server_max_window_bits` below 15 (the only supported value);
    /// - exceeds the client's `server_max_window_bits` offer;
    /// - includes `client_max_window_bits` although the client did not offer it.
    fn negotiate_deflate(
        request: &ServerHandshakeRequest,
        response: &ServerHandshakeResponse,
    ) -> Result<Option<PerMessageDeflate>, Error> {
        // Client-offered permessage-deflate parameters that constrain the response.
        let client_offered_smwb: Option<u8> = request
            .extensions
            .iter()
            .flat_map(|s| Extension::parse(s))
            .filter(|e| e.name == "permessage-deflate")
            .find_map(|e| {
                e.get_param("server_max_window_bits")
                    .and_then(|p| p.value.as_deref())
                    .and_then(|v| v.parse::<u8>().ok())
            });
        let client_offered_cmwb = request
            .extensions
            .iter()
            .flat_map(|s| Extension::parse(s))
            .any(|e| {
                e.name == "permessage-deflate" && e.get_param("client_max_window_bits").is_some()
            });

        let mut deflate = None;
        for ext in response
            .extensions
            .iter()
            .flat_map(|s| Extension::parse(s))
            .filter(|e| e.name == "permessage-deflate")
        {
            let config = PerMessageDeflateConfig::from_extension_for_client_response(&ext)
                .map_err(|e| {
                    Error::handshake_rejected(format!(
                        "invalid permessage-deflate parameters: {:?}",
                        e
                    ))
                })?;
            if let Some(smwb) = config.server_max_window_bits
                && smwb < 15
            {
                return Err(Error::handshake_rejected(format!(
                    "server_max_window_bits={} is not supported (only 15 is supported)",
                    smwb
                )));
            }
            if let (Some(smwb), Some(offered)) =
                (config.server_max_window_bits, client_offered_smwb)
                && smwb > offered
            {
                return Err(Error::handshake_rejected(format!(
                    "server_max_window_bits={} exceeds client offer={}",
                    smwb, offered
                )));
            }
            if ext.get_param("client_max_window_bits").is_some() && !client_offered_cmwb {
                return Err(Error::handshake_rejected(
                    "client_max_window_bits included without client offer",
                ));
            }
            deflate = Some(PerMessageDeflate::new_server(config));
        }
        Ok(deflate)
    }

    /// Converts an error from the HTTP encoder into an `invalid_input` error.
    fn http_error<E: ToString>(err: E) -> Error {
        Error::invalid_input(err.to_string())
    }

    /// Encodes the `101 Switching Protocols` response for `request`.
    ///
    /// The response carries the computed `Sec-WebSocket-Accept`, the optional protocol
    /// and extensions, and any additional headers.
    fn build_handshake_response(
        request: &ServerHandshakeRequest,
        response: &ServerHandshakeResponse,
    ) -> Result<Vec<u8>, Error> {
        let accept = calculate_accept_from_key(&request.key);
        let mut builder = Response::new(101, "Switching Protocols")
            .map_err(Self::http_error)?
            .header("Upgrade", "websocket")
            .map_err(Self::http_error)?
            .header("Connection", "Upgrade")
            .map_err(Self::http_error)?
            .header("Sec-WebSocket-Accept", &accept)
            .map_err(Self::http_error)?;
        if let Some(protocol) = &response.protocol {
            builder = builder
                .header("Sec-WebSocket-Protocol", protocol)
                .map_err(Self::http_error)?;
        }
        if !response.extensions.is_empty() {
            builder = builder
                .header("Sec-WebSocket-Extensions", response.extensions.join(", "))
                .map_err(Self::http_error)?;
        }
        for (name, value) in &response.additional_headers {
            let header_name = HeaderName::new(name).map_err(Self::http_error)?;
            builder = builder
                .header(header_name, value)
                .map_err(Self::http_error)?;
        }
        builder.encode().map_err(Self::http_error)
    }

    /// Encodes an HTTP rejection response with `Connection: close` plus `headers`.
    fn build_rejection_response(
        status_code: u16,
        reason: &str,
        headers: &[(&str, &str)],
    ) -> Result<Vec<u8>, Error> {
        let mut builder = Response::new(status_code, reason)
            .map_err(Self::http_error)?
            .header("Connection", "close")
            .map_err(Self::http_error)?;
        for &(name, value) in headers {
            let header_name = HeaderName::new(name).map_err(Self::http_error)?;
            builder = builder
                .header(header_name, value)
                .map_err(Self::http_error)?;
        }
        builder.encode().map_err(Self::http_error)
    }

    /// Rejects the in-progress handshake with an HTTP error response.
    ///
    /// On success it queues the response, moves to `Closed`, and queues a
    /// `CloseConnection` output. The response is encoded first. If encoding fails
    /// (e.g. an invalid header name), the pending request and buffered data are left
    /// untouched.
    ///
    /// # Errors
    ///
    /// Fails if no handshake is in progress, or if the response cannot be encoded.
    pub fn reject_handshake(
        &mut self,
        status_code: u16,
        reason: &str,
        headers: &[(&str, &str)],
    ) -> Result<(), Error> {
        if self.shared.state() != ConnectionState::Connecting {
            return Err(Error::invalid_state("handshake is not in progress"));
        }
        let encoded = Self::build_rejection_response(status_code, reason, headers)?;
        self.pending_request = None;
        self.pending_frame_data.clear();
        self.handshake_validator.reset();
        self.shared
            .enqueue_output(ConnectionOutput::SendData(encoded));
        self.shared.set_state(ConnectionState::Closed)?;
        self.shared
            .enqueue_output(ConnectionOutput::CloseConnection);
        Ok(())
    }

    /// Sends a text message.
    pub fn send_text(&mut self, text: &str) -> Result<(), Error> {
        self.shared.send_text(text, &mut self.policy)
    }

    /// Sends a binary message.
    pub fn send_binary(&mut self, data: &[u8]) -> Result<(), Error> {
        self.shared.send_binary(data, &mut self.policy)
    }

    /// Sends a ping with `data` as payload.
    pub fn send_ping(&mut self, data: &[u8]) -> Result<(), Error> {
        self.shared.send_ping(data, &mut self.policy)
    }

    /// Starts the closing handshake with `code` and `reason`.
    pub fn close(&mut self, code: CloseCode, reason: &str) -> Result<(), Error> {
        self.shared.close(code, reason, &mut self.policy)
    }

    /// Notifies the connection that the timer `timer_id` has fired.
    pub fn handle_timer(&mut self, timer_id: TimerId) -> Result<(), Error> {
        self.shared.handle_timer(timer_id, &mut self.policy)
    }

    /// Pops the next application-facing event, if any.
    pub fn poll_event(&mut self) -> Option<ConnectionEvent> {
        self.shared.poll_event()
    }

    /// Pops the next action the I/O layer must perform, if any.
    pub fn poll_output(&mut self) -> Option<ConnectionOutput> {
        self.shared.poll_output()
    }

    /// Appends `data` to `pending`, refusing to grow it past `MAX_PENDING_FRAME_DATA`.
    fn buffer_pending(pending: &mut Vec<u8>, data: &[u8]) -> Result<(), Error> {
        if pending.len() + data.len() > MAX_PENDING_FRAME_DATA {
            return Err(Error::protocol_violation(
                "pending frame data exceeds limit while awaiting handshake acceptance",
            ));
        }
        pending.extend_from_slice(data);
        Ok(())
    }

    /// Handles bytes received before the handshake has been accepted.
    ///
    /// Once a request is parsed, all later bytes (including any that followed the
    /// request in the same read) are buffered under the same size limit. Invalid
    /// requests are answered with `426` (unsupported version) or `400`, and the
    /// validation error is returned.
    fn process_handshake(&mut self, buf: &[u8]) -> Result<(), Error> {
        if self.pending_request.is_some() {
            return Self::buffer_pending(&mut self.pending_frame_data, buf);
        }
        if self.shared.state() == ConnectionState::Disconnected {
            self.shared.set_state(ConnectionState::Connecting)?;
        }
        self.handshake_validator.feed(buf);
        match self.handshake_validator.validate() {
            Ok(Some(request)) => {
                Self::buffer_pending(
                    &mut self.pending_frame_data,
                    self.handshake_validator.remaining(),
                )?;
                self.pending_request = Some(request);
                Ok(())
            }
            Ok(None) => Ok(()),
            Err(e) if e.kind == ErrorKind::VersionNotSupported => {
                self.reject_handshake(426, "Upgrade Required", &[("Sec-WebSocket-Version", "13")])?;
                Err(e)
            }
            Err(e) => {
                self.reject_handshake(400, "Bad Request", &[])?;
                Err(e)
            }
        }
    }

    /// Picks the first client-offered subprotocol the server also supports.
    fn select_protocol(&self, request: &ServerHandshakeRequest) -> Option<String> {
        request
            .protocols
            .iter()
            .find(|&offered| self.options.protocols.contains(offered))
            .cloned()
    }

    /// Negotiates permessage-deflate against the first acceptable client offer.
    ///
    /// Returns `None` when deflate is not configured or no offer is acceptable.
    /// Offers that fail to parse are skipped. Offers that require
    /// `server_max_window_bits` below 15 are also skipped, since only 15 is supported.
    fn select_deflate(&self, request: &ServerHandshakeRequest) -> Option<PerMessageDeflateConfig> {
        let server_config = self.options.deflate_config.as_ref()?;
        request
            .extensions
            .iter()
            .flat_map(|s| Extension::parse(s))
            .filter(|ext| ext.name == "permessage-deflate")
            .filter_map(|ext| PerMessageDeflateConfig::from_extension_for_server_request(&ext).ok())
            .find(|offer| !offer.server_max_window_bits.is_some_and(|v| v < 15))
            .map(|offer| PerMessageDeflateConfig::negotiate(&offer, server_config))
    }
}