use crate::Timestamp;
use crate::deflate::PerMessageDeflate;
use crate::error::Error;
use crate::frame_policy::ClientFramePolicy;
use crate::websocket_close::CloseCode;
use crate::websocket_connection_shared::{
DEFAULT_MAX_DECOMPRESSED_SIZE, DEFAULT_MAX_FRAME_SIZE, DEFAULT_MAX_MESSAGE_SIZE,
SharedConnectionState,
};
use crate::websocket_connection_types::{
ConnectionEvent, ConnectionOutput, ConnectionState, RandomSource, TimerId,
};
use crate::websocket_extension::{Extension, PerMessageDeflateConfig};
use crate::websocket_handshake_request::HandshakeRequest;
use crate::websocket_handshake_response::{HandshakeResponse, HandshakeValidator};
/// Configuration for a [`WebSocketClientConnection`].
///
/// Build it with `ClientConnectionOptions::new(host, path)` (or `Default`) and
/// the chaining setters, then hand it to `WebSocketClientConnection::new`.
#[derive(Debug, Clone)]
pub struct ClientConnectionOptions {
    /// Request path passed to `HandshakeRequest::new` when connecting.
    path: String,
    /// Host passed to `HandshakeRequest::new` when connecting.
    host: String,
    /// Optional origin; when set it is added to the handshake request.
    origin: Option<String>,
    /// Subprotocols offered in the handshake, in insertion order. A subprotocol
    /// selected by the server must be one of these or the handshake is rejected.
    protocols: Vec<String>,
    /// permessage-deflate offer; `None` means no compression extension is
    /// offered, and any extension returned by the server is rejected.
    deflate_config: Option<PerMessageDeflateConfig>,
    /// Extra `(name, value)` headers appended to the handshake request as-is.
    additional_headers: Vec<(String, String)>,
    /// Ping interval in milliseconds, forwarded to the shared connection state.
    /// When `0`, no ping timer is scheduled after the handshake completes.
    ping_interval_millis: u64,
    /// Pong timeout in milliseconds, forwarded to the shared connection state.
    pong_timeout_millis: u64,
    /// Close-handshake timeout in milliseconds, forwarded to the shared
    /// connection state.
    close_timeout_millis: u64,
    /// Frame size limit forwarded to the shared connection state (presumably
    /// bytes of payload per frame — enforcement lives there; TODO confirm).
    max_frame_size: usize,
    /// Message size limit forwarded to the shared connection state (presumably
    /// bytes per reassembled message — TODO confirm).
    max_message_size: usize,
    /// Decompressed-size limit forwarded to the shared connection state
    /// (presumably bytes after inflating a compressed message — TODO confirm).
    max_decompressed_size: usize,
}
impl Default for ClientConnectionOptions {
fn default() -> Self {
Self {
path: "/".to_string(),
host: "localhost".to_string(),
origin: None,
protocols: Vec::new(),
deflate_config: None,
additional_headers: Vec::new(),
ping_interval_millis: 30_000, pong_timeout_millis: 10_000, close_timeout_millis: 5_000, max_frame_size: DEFAULT_MAX_FRAME_SIZE,
max_message_size: DEFAULT_MAX_MESSAGE_SIZE,
max_decompressed_size: DEFAULT_MAX_DECOMPRESSED_SIZE,
}
}
}
impl ClientConnectionOptions {
    /// Creates options for `host` and request `path`; every other setting is
    /// taken from the `Default` implementation.
    pub fn new(host: &str, path: &str) -> Self {
        Self {
            host: host.to_owned(),
            path: path.to_owned(),
            ..Self::default()
        }
    }

    /// Sets the origin added to the handshake request (replaces any previous one).
    pub fn origin(self, origin: &str) -> Self {
        Self {
            origin: Some(origin.into()),
            ..self
        }
    }

    /// Appends a subprotocol to the offered list; call order is preserved.
    pub fn protocol(self, protocol: &str) -> Self {
        let mut options = self;
        options.protocols.push(protocol.into());
        options
    }

    /// Offers permessage-deflate with `config` (replaces any previous offer).
    pub fn deflate(self, config: PerMessageDeflateConfig) -> Self {
        Self {
            deflate_config: Some(config),
            ..self
        }
    }

    /// Appends an extra header to the handshake request. Repeated names are
    /// kept; nothing is deduplicated here.
    pub fn header(self, name: &str, value: &str) -> Self {
        let mut options = self;
        options.additional_headers.push((name.into(), value.into()));
        options
    }

    /// Sets the ping interval in milliseconds (`0`: no ping timer after the handshake).
    pub fn ping_interval(self, millis: u64) -> Self {
        Self {
            ping_interval_millis: millis,
            ..self
        }
    }

    /// Sets the pong timeout in milliseconds.
    pub fn pong_timeout(self, millis: u64) -> Self {
        Self {
            pong_timeout_millis: millis,
            ..self
        }
    }

    /// Sets the close-handshake timeout in milliseconds.
    pub fn close_timeout(self, millis: u64) -> Self {
        Self {
            close_timeout_millis: millis,
            ..self
        }
    }

    /// Sets the frame size limit handed to the shared connection state.
    pub fn max_frame_size(self, size: usize) -> Self {
        Self {
            max_frame_size: size,
            ..self
        }
    }

    /// Sets the message size limit handed to the shared connection state.
    pub fn max_message_size(self, size: usize) -> Self {
        Self {
            max_message_size: size,
            ..self
        }
    }

    /// Sets the decompressed-size limit handed to the shared connection state.
    pub fn max_decompressed_size(self, size: usize) -> Self {
        Self {
            max_decompressed_size: size,
            ..self
        }
    }
}
/// Client side of a WebSocket connection.
///
/// Received bytes are fed in with `feed_recv_buf`; data to send, timer
/// requests and connection events are drained with `poll_output` /
/// `poll_event`. `R` is the randomness source used by the client frame policy.
pub struct WebSocketClientConnection<R: RandomSource> {
    /// Connection state, output/event queues and frame processing; most
    /// methods delegate to it.
    shared: SharedConnectionState,
    /// Client frame policy: generates the handshake nonce and is passed to
    /// every frame-processing and sending call.
    policy: ClientFramePolicy<R>,
    /// Options this connection was created with.
    options: ClientConnectionOptions,
    /// Handshake nonce; all zeroes until `connect` generates one.
    nonce: [u8; 16],
    /// Set by `connect`; cleared once the server's handshake response has been
    /// accepted.
    handshake_validator: Option<HandshakeValidator>,
    /// Subprotocol taken from the server's handshake response.
    negotiated_protocol: Option<String>,
    /// `Sec-WebSocket-Extensions` header values taken from the server's
    /// handshake response.
    negotiated_extensions: Vec<String>,
}
impl<R: RandomSource> WebSocketClientConnection<R> {
    /// Creates a client connection that has not started its handshake.
    ///
    /// Size limits and timer intervals from `options` are handed to the shared
    /// connection state. `random` backs the client frame policy, which also
    /// supplies the handshake nonce. This constructor does not queue any output
    /// itself; call [`connect`](Self::connect) to begin the handshake.
    pub fn new(options: ClientConnectionOptions, random: R) -> Self {
        let shared = SharedConnectionState::new(
            options.max_frame_size,
            options.max_message_size,
            options.max_decompressed_size,
            options.ping_interval_millis,
            options.pong_timeout_millis,
            options.close_timeout_millis,
        );
        Self {
            shared,
            policy: ClientFramePolicy::new(random),
            options,
            // Placeholder; `connect` replaces it with a freshly generated nonce.
            nonce: [0u8; 16],
            handshake_validator: None,
            negotiated_protocol: None,
            negotiated_extensions: Vec::new(),
        }
    }

    /// Current state of the connection.
    pub fn state(&self) -> ConnectionState {
        self.shared.state()
    }

    /// Subprotocol selected by the server, once the handshake has been accepted.
    pub fn protocol(&self) -> Option<&str> {
        self.negotiated_protocol.as_deref()
    }

    /// `Sec-WebSocket-Extensions` values returned by the server, once the
    /// handshake has been accepted.
    pub fn extensions(&self) -> &[String] {
        &self.negotiated_extensions
    }

    /// Starts the opening handshake.
    ///
    /// Generates a fresh nonce and builds the handshake request from the
    /// options (origin, subprotocols, permessage-deflate offer, extra headers).
    /// It then queues the request as [`ConnectionOutput::SendData`] and moves
    /// to [`ConnectionState::Connecting`].
    ///
    /// # Errors
    ///
    /// - An invalid-state error if the connection is not `Disconnected`.
    /// - Any error from building the request or from the state transition.
    pub fn connect(&mut self) -> Result<(), Error> {
        if self.shared.state() != ConnectionState::Disconnected {
            return Err(Error::invalid_state("already connected or connecting"));
        }
        self.nonce = self.policy.nonce();
        let mut request = HandshakeRequest::new(&self.options.path, &self.options.host);
        if let Some(origin) = &self.options.origin {
            request = request.origin(origin);
        }
        for protocol in &self.options.protocols {
            request = request.protocol(protocol);
        }
        if let Some(deflate_config) = &self.options.deflate_config {
            let ext = deflate_config.to_extension();
            request = request.extension(&ext.encode());
        }
        for (name, value) in &self.options.additional_headers {
            request = request.header(name, value);
        }
        let encoded = request.build(self.nonce)?;
        // The validator checks the server's response against this same nonce.
        self.handshake_validator = Some(HandshakeValidator::new(self.nonce));
        self.shared
            .enqueue_output(ConnectionOutput::SendData(encoded));
        self.shared.set_state(ConnectionState::Connecting)?;
        Ok(())
    }

    /// Feeds bytes received from the transport.
    ///
    /// While `Connecting`, the bytes go to the handshake validator. While
    /// `Connected` or `Closing`, they are processed as frames. Any processing
    /// error marks the connection as failed before being returned.
    ///
    /// # Errors
    ///
    /// - An invalid-state error if the connection has already failed, or is
    ///   `Disconnected` / `Closed`.
    /// - Otherwise, whatever handshake or frame processing returns.
    pub fn feed_recv_buf(&mut self, buf: &[u8], now: Timestamp) -> Result<(), Error> {
        if self.shared.is_failed() {
            return Err(Error::invalid_state("connection has failed"));
        }
        let result = match self.shared.state() {
            ConnectionState::Connecting => self.process_handshake(buf, now),
            ConnectionState::Connected | ConnectionState::Closing => {
                self.shared.process_frames(buf, &mut self.policy)
            }
            ConnectionState::Disconnected | ConnectionState::Closed => {
                return Err(Error::invalid_state("connection is closed"));
            }
        };
        if result.is_err() {
            self.shared.mark_failed();
        }
        result
    }

    /// Sends a text message via the shared connection state.
    pub fn send_text(&mut self, text: &str) -> Result<(), Error> {
        self.shared.send_text(text, &mut self.policy)
    }

    /// Sends a binary message via the shared connection state.
    pub fn send_binary(&mut self, data: &[u8]) -> Result<(), Error> {
        self.shared.send_binary(data, &mut self.policy)
    }

    /// Sends a ping carrying `data` via the shared connection state.
    pub fn send_ping(&mut self, data: &[u8]) -> Result<(), Error> {
        self.shared.send_ping(data, &mut self.policy)
    }

    /// Starts (or answers) the closing handshake with `code` and `reason`.
    pub fn close(&mut self, code: CloseCode, reason: &str) -> Result<(), Error> {
        self.shared.close(code, reason, &mut self.policy)
    }

    /// Handles expiry of a timer previously requested through
    /// [`ConnectionOutput::SetTimer`].
    pub fn handle_timer(&mut self, timer_id: TimerId) -> Result<(), Error> {
        self.shared.handle_timer(timer_id, &mut self.policy)
    }

    /// Returns the next pending connection event, if any.
    pub fn poll_event(&mut self) -> Option<ConnectionEvent> {
        self.shared.poll_event()
    }

    /// Returns the next pending output (data to send, timer request, ...), if any.
    pub fn poll_output(&mut self) -> Option<ConnectionOutput> {
        self.shared.poll_output()
    }

    /// Feeds handshake bytes to the validator and completes the handshake once
    /// a full response is available.
    ///
    /// Bytes that followed the response in the same buffer(s) are passed
    /// straight to frame processing.
    fn process_handshake(&mut self, buf: &[u8], now: Timestamp) -> Result<(), Error> {
        let validator = self
            .handshake_validator
            .as_mut()
            .ok_or_else(|| Error::invalid_state("handshake validator not initialized"))?;
        validator.feed(buf);
        let Some(response) = validator.validate()? else {
            // No complete response yet (presumably more bytes are needed).
            return Ok(());
        };
        // Copy trailing bytes out now; `complete_handshake` needs `&mut self`.
        let remaining = validator.remaining().to_vec();
        self.complete_handshake(response, now)?;
        self.handshake_validator = None;
        if !remaining.is_empty() {
            self.shared.process_frames(&remaining, &mut self.policy)?;
        }
        Ok(())
    }

    /// Validates the server's handshake response and transitions to `Connected`.
    ///
    /// The handshake is rejected when the server:
    /// - selects a subprotocol the client did not offer;
    /// - sends a malformed `Sec-WebSocket-Extensions` value;
    /// - accepts an extension the client did not offer;
    /// - returns more than one permessage-deflate element;
    /// - returns permessage-deflate parameters that conflict with the offer.
    ///
    /// The negotiated protocol and extensions are recorded only after all of
    /// these checks pass, so a rejected handshake never leaves
    /// [`protocol`](Self::protocol) / [`extensions`](Self::extensions)
    /// populated. On success, this method emits [`ConnectionEvent::Connected`].
    /// If the ping interval is non-zero, it also queues a ping timer.
    fn complete_handshake(
        &mut self,
        response: HandshakeResponse,
        _now: Timestamp,
    ) -> Result<(), Error> {
        if let Some(ref protocol) = response.protocol
            && !self.options.protocols.iter().any(|p| p == protocol)
        {
            return Err(Error::handshake_rejected(format!(
                "server returned unsolicited protocol: {}",
                protocol
            )));
        }

        // Parse every header value exactly once. permessage-deflate is the only
        // extension `connect` offers, and only when a deflate config is set.
        let deflate_offered = self.options.deflate_config.is_some();
        let mut accepted = Vec::new();
        for ext_str in &response.extensions {
            let parsed = Extension::parse_strict(ext_str).map_err(|e| {
                Error::handshake_rejected(format!("invalid Sec-WebSocket-Extensions value: {e}"))
            })?;
            for ext in parsed {
                if !(deflate_offered && ext.name == "permessage-deflate") {
                    return Err(Error::handshake_rejected(format!(
                        "server returned unsolicited extension: {}",
                        ext.name
                    )));
                }
                accepted.push(ext);
            }
        }

        // The server may accept at most one permessage-deflate element.
        let mut deflate_elements = accepted.iter().filter(|e| e.name == "permessage-deflate");
        let deflate_ext = deflate_elements.next();
        if deflate_elements.next().is_some() {
            return Err(Error::handshake_rejected(
                "server returned multiple permessage-deflate elements",
            ));
        }
        if let Some(ext) = deflate_ext {
            self.validate_deflate_negotiation(ext)?;
        }

        // Every check passed: only now record what was negotiated.
        self.negotiated_protocol = response.protocol;
        self.negotiated_extensions = response.extensions;

        self.shared.set_state(ConnectionState::Connected)?;
        self.shared.emit_event(ConnectionEvent::Connected {
            protocol: self.negotiated_protocol.clone(),
            extensions: self.negotiated_extensions.clone(),
        });
        if self.options.ping_interval_millis > 0 {
            self.shared.enqueue_output(ConnectionOutput::SetTimer {
                id: TimerId::Ping,
                duration_millis: self.options.ping_interval_millis,
            });
        }
        Ok(())
    }

    /// Checks a server-accepted permessage-deflate element against the
    /// client's offer and, if acceptable, enables compression.
    ///
    /// The handshake is rejected when:
    /// - the client never offered permessage-deflate;
    /// - the response parameters cannot be parsed;
    /// - `client_max_window_bits` is present without having been offered;
    /// - either window size exceeds the offered value;
    /// - `client_max_window_bits` is below 15, the only value this client supports.
    fn validate_deflate_negotiation(&mut self, ext: &Extension) -> Result<(), Error> {
        // `complete_handshake` already rejects unsolicited extensions; re-check
        // so this method can never enable compression the client did not offer.
        let Some(deflate_config) = &self.options.deflate_config else {
            return Err(Error::handshake_rejected(
                "server returned unsolicited extension: permessage-deflate",
            ));
        };
        let config = PerMessageDeflateConfig::from_extension_for_client_response(ext).map_err(
            |e| Error::handshake_rejected(format!("invalid permessage-deflate response: {:?}", e)),
        )?;

        // RFC 7692 §7.1.2.2: a response may only carry client_max_window_bits
        // when the offer did.
        let client_offered_cmwb = deflate_config.client_max_window_bits.is_some();
        let server_included_cmwb = ext.get_param("client_max_window_bits").is_some();
        if server_included_cmwb && !client_offered_cmwb {
            return Err(Error::handshake_rejected(
                "server included client_max_window_bits without client offer",
            ));
        }
        // RFC 7692 §7.1.2.1: the accepted server window may not exceed the offer.
        if let (Some(client_smwb), Some(server_smwb)) = (
            deflate_config.server_max_window_bits,
            config.server_max_window_bits,
        ) && server_smwb > client_smwb
        {
            return Err(Error::handshake_rejected(format!(
                "server_max_window_bits {} exceeds client offer {}",
                server_smwb, client_smwb
            )));
        }
        if let (Some(client_cmwb), Some(server_cmwb)) = (
            deflate_config.client_max_window_bits,
            config.client_max_window_bits,
        ) && server_cmwb > client_cmwb
        {
            return Err(Error::handshake_rejected(format!(
                "client_max_window_bits {} exceeds client offer {}",
                server_cmwb, client_cmwb
            )));
        }
        // This client's compressor only supports the full 15-bit window.
        if let Some(cmwb) = config.client_max_window_bits
            && cmwb < 15
        {
            return Err(Error::handshake_rejected(format!(
                "client_max_window_bits={} is not supported (only 15 is supported)",
                cmwb
            )));
        }

        self.shared
            .enable_deflate(PerMessageDeflate::new_client(config));
        Ok(())
    }
}