mod builder;
mod options;
mod split;
pub mod streaming;
mod upgrade;
use crate::{codec, compression, frame, streaming::Streaming, Result, WebSocketError};
use {
bytes::Bytes,
http_body_util::Empty,
hyper::{body::Incoming, header, upgrade::Upgraded, Request, Response, StatusCode},
hyper_util::rt::TokioIo,
tokio::net::TcpStream,
tokio_rustls::{rustls::pki_types::ServerName, TlsConnector},
};
use std::{
borrow::BorrowMut,
collections::VecDeque,
future::poll_fn,
io,
net::SocketAddr,
pin::{pin, Pin},
str::FromStr,
sync::Arc,
task::{ready, Context, Poll},
time::{Duration, Instant},
};
use tokio::io::{AsyncRead, AsyncWrite};
use codec::Codec;
use compression::{Compressor, Decompressor, WebSocketExtensions};
use futures::{task::AtomicWaker, SinkExt};
use tokio_rustls::rustls::{self, pki_types::TrustAnchor};
use tokio_util::codec::Framed;
use url::Url;
pub use crate::stream::MaybeTlsStream;
pub use builder::{HttpRequest, HttpRequestBuilder, WebSocketBuilder};
pub use frame::{Frame, OpCode};
pub use options::{CompressionLevel, DeflateOptions, Fragmentation, Options};
pub use split::{ReadHalf, WriteHalf};
pub use upgrade::UpgradeFut;
/// A client WebSocket running over plain TCP or TLS-over-TCP.
pub type TcpWebSocket = WebSocket<MaybeTlsStream<TcpStream>>;
/// A WebSocket running over a stream obtained from an HTTP upgrade (hyper or reqwest).
pub type HttpWebSocket = WebSocket<HttpStream>;

// Axum integration, only available with the `axum` feature.
#[cfg(feature = "axum")]
pub use upgrade::IncomingUpgrade;

/// Payload limit used when `Options::max_payload_read` is not set (1 MiB).
pub const MAX_PAYLOAD_READ: usize = 1 << 20;
/// Read-buffer limit used when neither `Options::max_read_buffer` nor
/// `Options::max_payload_read` is set: twice [`MAX_PAYLOAD_READ`] (2 MiB).
pub const MAX_READ_BUFFER: usize = 2 * MAX_PAYLOAD_READ;

/// The `101 Switching Protocols` response a server sends back during an upgrade.
pub type HttpResponse = Response<Empty<Bytes>>;
/// Result of a server-side upgrade: the response to send, plus a future for the connection.
pub type UpgradeResult = Result<(HttpResponse, UpgradeFut)>;
/// Connection parameters settled during the opening handshake.
///
/// Built by the client-side response checks and by the server-side upgrade,
/// then consumed by `WebSocket::new` to configure the streaming and
/// fragmentation layers.
#[derive(Debug, Default, Clone)]
pub(crate) struct Negotiation {
    /// permessage-deflate parameters, when compression was negotiated.
    pub(crate) extensions: Option<WebSocketExtensions>,
    /// Local compression level. Must be `Some` whenever `extensions` is `Some`
    /// (see [`Negotiation::compressor`]).
    pub(crate) compression_level: Option<CompressionLevel>,
    pub(crate) max_payload_read: usize,
    /// Limit handed to the fragment-reassembly layer.
    pub(crate) max_read_buffer: usize,
    /// Whether incoming text messages are validated as UTF-8.
    pub(crate) utf8: bool,
    pub(crate) fragmentation: Option<options::Fragmentation>,
    pub(crate) max_backpressure_write_boundary: Option<usize>,
}

impl Negotiation {
    /// Builds the decompressor for incoming messages.
    ///
    /// Returns `None` when no compression extension was negotiated. The
    /// decompressor inflates data produced by the *peer's* compressor, so a
    /// server uses the `client_*` parameters and a client uses the `server_*`
    /// parameters.
    pub(crate) fn decompressor(&self, role: Role) -> Option<Decompressor> {
        let config = self.extensions.as_ref()?;
        log::debug!(
            "Established decompressor for {role} with settings \
            client_no_context_takeover={} server_no_context_takeover={} \
            server_max_window_bits={:?} client_max_window_bits={:?}",
            config.client_no_context_takeover,
            config.server_no_context_takeover,
            config.server_max_window_bits,
            config.client_max_window_bits
        );
        // Inflating with a window larger than the peer's is always safe. zlib's
        // raw deflate cannot use an 8-bit window (a request for 8 yields 9), so
        // a negotiated 8 is widened to 9 on both sides.
        Some(if role == Role::Server {
            if config.client_no_context_takeover {
                Decompressor::no_context_takeover()
            } else {
                #[cfg(feature = "zlib")]
                if let Some(Some(window_bits)) = config.client_max_window_bits {
                    Decompressor::new_with_window_bits(window_bits.max(9))
                } else {
                    Decompressor::new()
                }
                #[cfg(not(feature = "zlib"))]
                Decompressor::new()
            }
        } else if config.server_no_context_takeover {
            Decompressor::no_context_takeover()
        } else {
            #[cfg(feature = "zlib")]
            if let Some(Some(window_bits)) = config.server_max_window_bits {
                Decompressor::new_with_window_bits(window_bits.max(9))
            } else {
                Decompressor::new()
            }
            #[cfg(not(feature = "zlib"))]
            Decompressor::new()
        })
    }

    /// Builds the compressor for outgoing messages.
    ///
    /// Returns `None` when no compression extension was negotiated. The
    /// compressor follows this endpoint's own parameters: a client uses the
    /// `client_*` settings and a server uses the `server_*` settings.
    ///
    /// # Panics
    /// If `extensions` is set but `compression_level` is not.
    pub(crate) fn compressor(&self, role: Role) -> Option<Compressor> {
        let config = self.extensions.as_ref()?;
        log::debug!(
            "Established compressor for {role} with settings \
            client_no_context_takeover={} server_no_context_takeover={} \
            server_max_window_bits={:?} client_max_window_bits={:?}",
            config.client_no_context_takeover,
            config.server_no_context_takeover,
            config.server_max_window_bits,
            config.client_max_window_bits
        );
        let level = self
            .compression_level
            .expect("compression_level must be set whenever extensions are negotiated");
        // NOTE(review): a negotiated window of 8 is passed through unchanged here;
        // confirm `Compressor::new_with_window_bits` copes with it, since zlib's
        // raw deflate does not support an 8-bit window.
        Some(if role == Role::Client {
            if config.client_no_context_takeover {
                Compressor::no_context_takeover(level)
            } else {
                #[cfg(feature = "zlib")]
                if let Some(Some(window_bits)) = config.client_max_window_bits {
                    Compressor::new_with_window_bits(level, window_bits)
                } else {
                    Compressor::new(level)
                }
                #[cfg(not(feature = "zlib"))]
                Compressor::new(level)
            }
        } else if config.server_no_context_takeover {
            Compressor::no_context_takeover(level)
        } else {
            #[cfg(feature = "zlib")]
            if let Some(Some(window_bits)) = config.server_max_window_bits {
                Compressor::new_with_window_bits(level, window_bits)
            } else {
                Compressor::new(level)
            }
            #[cfg(not(feature = "zlib"))]
            Compressor::new(level)
        })
    }
}
/// Which side of the connection this endpoint plays.
///
/// The role selects which negotiated permessage-deflate parameters apply to
/// this endpoint's compressor and decompressor.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum Role {
    /// The endpoint that accepted the connection.
    Server,
    /// The endpoint that initiated the connection.
    Client,
}

impl std::fmt::Display for Role {
    /// Renders the role in lowercase (`"server"` / `"client"`), as used in log messages.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Server => write!(f, "server"),
            Self::Client => write!(f, "client"),
        }
    }
}
/// Selects which of [`WakeProxy`]'s two waker slots is updated.
#[derive(Clone, Copy)]
enum ContextKind {
    /// Slot for the task waiting on read progress.
    Read,
    /// Slot for the task waiting on write progress.
    Write,
}

/// A waker that forwards every wake-up to both a registered reader task and a
/// registered writer task.
#[derive(Default)]
struct WakeProxy {
    read_waker: AtomicWaker,
    write_waker: AtomicWaker,
}

impl futures::task::ArcWake for WakeProxy {
    /// Wakes whatever tasks are currently registered in both slots (read first).
    fn wake_by_ref(this: &Arc<Self>) {
        let proxy: &WakeProxy = this;
        proxy.read_waker.wake();
        proxy.write_waker.wake();
    }
}

impl WakeProxy {
    /// Registers `waker` in the slot chosen by `kind`.
    #[inline]
    fn set_waker(&self, kind: ContextKind, waker: &futures::task::Waker) {
        let slot = match kind {
            ContextKind::Read => &self.read_waker,
            ContextKind::Write => &self.write_waker,
        };
        slot.register(waker);
    }

    /// Calls `f` with a `Context` whose waker is this proxy, so a wake-up
    /// raised inside `f` reaches both registered tasks.
    #[inline(always)]
    fn with_context<F, R>(self: &Arc<Self>, f: F) -> R
    where
        F: FnOnce(&mut Context<'_>) -> R,
    {
        let proxy_waker = futures::task::waker_ref(self);
        f(&mut Context::from_waker(&proxy_waker))
    }
}
/// A raw byte stream obtained from an HTTP upgrade, from either hyper or reqwest.
///
/// Its `AsyncRead`/`AsyncWrite` implementations delegate to the wrapped stream.
pub enum HttpStream {
    /// Stream upgraded through `reqwest` (requires the `reqwest` feature).
    #[cfg(feature = "reqwest")]
    Reqwest(reqwest::Upgraded),
    /// Stream upgraded through hyper, adapted to tokio I/O by `TokioIo`.
    Hyper(TokioIo<Upgraded>),
}

impl From<TokioIo<Upgraded>> for HttpStream {
    /// Wraps a hyper-upgraded stream.
    fn from(io: TokioIo<Upgraded>) -> Self {
        HttpStream::Hyper(io)
    }
}

#[cfg(feature = "reqwest")]
impl From<reqwest::Upgraded> for HttpStream {
    /// Wraps a reqwest-upgraded stream.
    fn from(upgraded: reqwest::Upgraded) -> Self {
        HttpStream::Reqwest(upgraded)
    }
}
impl AsyncRead for HttpStream {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> Poll<io::Result<()>> {
match self.get_mut() {
#[cfg(feature = "reqwest")]
Self::Reqwest(stream) => pin!(stream).poll_read(cx, buf),
Self::Hyper(stream) => pin!(stream).poll_read(cx, buf),
}
}
}
/// Forwards every write operation, including vectored writes, to the wrapped stream.
impl AsyncWrite for HttpStream {
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<std::result::Result<usize, io::Error>> {
        match self.get_mut() {
            #[cfg(feature = "reqwest")]
            Self::Reqwest(stream) => pin!(stream).poll_write(cx, buf),
            Self::Hyper(stream) => pin!(stream).poll_write(cx, buf),
        }
    }

    /// Forwarded explicitly. The trait's default writes only the first
    /// non-empty buffer, which would defeat the inner stream's vectored path.
    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[io::IoSlice<'_>],
    ) -> Poll<std::result::Result<usize, io::Error>> {
        match self.get_mut() {
            #[cfg(feature = "reqwest")]
            Self::Reqwest(stream) => pin!(stream).poll_write_vectored(cx, bufs),
            Self::Hyper(stream) => pin!(stream).poll_write_vectored(cx, bufs),
        }
    }

    /// Reports the wrapped stream's capability. The trait's default always
    /// answers `false`, so callers would never attempt vectored writes.
    fn is_write_vectored(&self) -> bool {
        match self {
            #[cfg(feature = "reqwest")]
            Self::Reqwest(stream) => stream.is_write_vectored(),
            Self::Hyper(stream) => stream.is_write_vectored(),
        }
    }

    fn poll_flush(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<std::result::Result<(), io::Error>> {
        match self.get_mut() {
            #[cfg(feature = "reqwest")]
            Self::Reqwest(stream) => pin!(stream).poll_flush(cx),
            Self::Hyper(stream) => pin!(stream).poll_flush(cx),
        }
    }

    fn poll_shutdown(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<std::result::Result<(), io::Error>> {
        match self.get_mut() {
            #[cfg(feature = "reqwest")]
            Self::Reqwest(stream) => pin!(stream).poll_shutdown(cx),
            Self::Hyper(stream) => pin!(stream).poll_shutdown(cx),
        }
    }
}
/// Reassembly state for a fragmented data message that is still arriving.
pub(super) struct FragmentationState {
    /// Arrival time of the first fragment; checked against the fragment timeout.
    started: Instant,
    /// Opcode of the first fragment, given to the reassembled frame.
    opcode: OpCode,
    /// Compression flag of the first fragment, given to the reassembled frame.
    is_compressed: bool,
    /// Payload bytes received so far across all fragments.
    bytes_read: usize,
    /// Fragment payloads in arrival order.
    parts: VecDeque<Bytes>,
}

/// Queues outgoing frames (optionally split into fragments) and reassembles
/// incoming fragmented data messages.
struct FragmentLayer {
    /// Frames waiting to be handed to the transport, front first.
    outgoing_fragments: VecDeque<Frame>,
    /// The data message currently being reassembled, if any.
    incoming_fragment: Option<FragmentationState>,
    /// Maximum outgoing fragment size; `None` passes `usize::MAX` (no limit).
    fragment_size: Option<usize>,
    /// A reassembled message must stay below this many bytes.
    max_read_buffer: usize,
    /// How long a fragmented message may take to complete; `None` means no limit.
    fragment_timeout: Option<Duration>,
}

impl FragmentLayer {
    fn new(
        fragment_size: Option<usize>,
        max_read_buffer: usize,
        fragment_timeout: Option<Duration>,
    ) -> Self {
        Self {
            fragment_size,
            max_read_buffer,
            fragment_timeout,
            outgoing_fragments: VecDeque::new(),
            incoming_fragment: None,
        }
    }

    /// Queues `frame` for sending, split by `Frame::into_fragments` using
    /// `fragment_size` as the maximum.
    ///
    /// # Panics
    /// If `frame` is already a non-final fragment while `fragment_size` is set;
    /// mixing manual and automatic fragmentation is not supported.
    fn fragment_outgoing(&mut self, frame: Frame) {
        let caller_fragmented = !frame.is_fin();
        if caller_fragmented && self.fragment_size.is_some() {
            panic!(
                "Fragment the frames yourself or use `fragment_size`, but not both. Use Streaming"
            );
        }
        let limit = match self.fragment_size {
            Some(size) => size,
            None => usize::MAX,
        };
        for piece in frame.into_fragments(limit) {
            self.outgoing_fragments.push_back(piece);
        }
    }

    /// Takes the next queued outgoing frame, if any.
    #[inline(always)]
    fn pop_outgoing_fragment(&mut self) -> Option<Frame> {
        self.outgoing_fragments.pop_front()
    }

    /// Whether any outgoing frames are still queued.
    #[inline(always)]
    fn has_outgoing_fragments(&self) -> bool {
        !self.outgoing_fragments.is_empty()
    }

    /// Feeds one received frame through reassembly.
    ///
    /// Returns `Ok(Some(frame))` when a complete message (or a non-data frame)
    /// is ready, and `Ok(None)` when a fragment was buffered and more are expected.
    ///
    /// # Errors
    /// - `InvalidFragment`: a new Text/Binary frame arrives mid-message, or a
    ///   Continuation arrives with no message in progress.
    /// - `FrameTooLarge`: the buffered message reaches `max_read_buffer` bytes.
    /// - `FragmentTimeout`: a continuation arrives after `fragment_timeout` elapsed.
    fn assemble_incoming(&mut self, frame: Frame) -> Result<Option<Frame>> {
        #[cfg(test)]
        println!(
            "<<Fragmentation<< OpCode={:?} fin={} len={}",
            frame.opcode(),
            frame.is_fin(),
            frame.payload.len()
        );
        match frame.opcode {
            OpCode::Text | OpCode::Binary => self.begin_message(frame),
            OpCode::Continuation => self.continue_message(frame),
            // Control frames may be interleaved with fragments and pass straight through.
            _ => Ok(Some(frame)),
        }
    }

    /// Handles the first frame of a data message.
    fn begin_message(&mut self, frame: Frame) -> Result<Option<Frame>> {
        if self.incoming_fragment.is_some() {
            return Err(WebSocketError::InvalidFragment);
        }
        if frame.fin {
            return Ok(Some(frame));
        }
        self.incoming_fragment = Some(FragmentationState {
            started: Instant::now(),
            opcode: frame.opcode,
            is_compressed: frame.is_compressed,
            bytes_read: frame.payload.len(),
            parts: VecDeque::from([frame.payload]),
        });
        Ok(None)
    }

    /// Handles a Continuation frame, emitting the joined message on the final one.
    fn continue_message(&mut self, mut frame: Frame) -> Result<Option<Frame>> {
        use bytes::BufMut;
        let Some(mut pending) = self.incoming_fragment.take() else {
            return Err(WebSocketError::InvalidFragment);
        };
        pending.bytes_read += frame.payload.len();
        if pending.bytes_read >= self.max_read_buffer {
            return Err(WebSocketError::FrameTooLarge);
        }
        match self.fragment_timeout {
            Some(limit) if pending.started.elapsed() > limit => {
                return Err(WebSocketError::FragmentTimeout);
            }
            _ => {}
        }
        pending.parts.push_back(frame.payload);
        if !frame.fin {
            self.incoming_fragment = Some(pending);
            return Ok(None);
        }
        frame.opcode = pending.opcode;
        frame.is_compressed = pending.is_compressed;
        let mut joined = bytes::BytesMut::with_capacity(pending.bytes_read);
        for part in pending.parts {
            joined.put(part);
        }
        frame.payload = joined.freeze();
        Ok(Some(frame))
    }
}
/// A WebSocket connection over an async byte stream `S`.
///
/// Adds two things on top of the lower-level [`Streaming`] layer:
/// - reassembly of fragmented incoming messages, and optional splitting of
///   outgoing ones;
/// - optional UTF-8 validation of incoming text messages.
pub struct WebSocket<S> {
    /// Underlying frame transport, built from the handshake `Negotiation`.
    streaming: Streaming<S>,
    /// Whether incoming (reassembled) text messages must be valid UTF-8.
    check_utf8: bool,
    /// Outgoing fragment queue and incoming reassembly state.
    fragment_layer: FragmentLayer,
}

impl WebSocket<MaybeTlsStream<TcpStream>> {
    /// Starts configuring a client connection to `url`; see [`WebSocketBuilder`].
    pub fn connect(url: Url) -> WebSocketBuilder {
        WebSocketBuilder::new(url)
    }

    /// Opens a TCP connection (wrapped in TLS for `wss`) to `url` and runs the
    /// WebSocket opening handshake over it.
    ///
    /// `tcp_address` bypasses resolution of the URL host. `connector` replaces
    /// the default TLS configuration from `tls_connector`.
    ///
    /// # Errors
    /// - `WebSocketError::InvalidHttpScheme` if the scheme is neither `ws` nor
    ///   `wss`. This is checked before any network I/O.
    /// - I/O errors when the URL has no host, the TLS server name is invalid,
    ///   or connecting / TLS negotiation fails.
    /// - Any error from the handshake itself.
    pub(crate) async fn connect_priv(
        url: Url,
        tcp_address: Option<SocketAddr>,
        connector: Option<TlsConnector>,
        options: Options,
        builder: HttpRequestBuilder,
    ) -> Result<TcpWebSocket> {
        // Validate the scheme first, so unsupported URLs fail without a network
        // round trip and without panicking on URLs that have no host.
        let use_tls = match url.scheme() {
            "ws" => false,
            "wss" => true,
            _ => return Err(WebSocketError::InvalidHttpScheme),
        };
        let host = url
            .host()
            .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "url has no host"))?
            .to_string();

        let tcp_stream = match tcp_address {
            Some(address) => TcpStream::connect(address).await?,
            None => {
                // `ws` and `wss` always have a known default port.
                let port = url.port_or_known_default().expect("port");
                TcpStream::connect(format!("{host}:{port}")).await?
            }
        };
        // Best effort: failing to set TCP_NODELAY is not fatal.
        let _ = tcp_stream.set_nodelay(options.no_delay);

        let stream = if use_tls {
            let connector = connector.unwrap_or_else(tls_connector);
            // IPv6 hosts render as `[addr]`; rustls only recognises the bare
            // address as an IP server name.
            let server_name = host
                .trim_start_matches('[')
                .trim_end_matches(']')
                .to_owned();
            let domain = ServerName::try_from(server_name)
                .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid dnsname"))?;
            MaybeTlsStream::Tls(connector.connect(domain, tcp_stream).await?)
        } else {
            MaybeTlsStream::Plain(tcp_stream)
        };
        WebSocket::handshake_with_request(url, stream, options, builder).await
    }
}
impl<S> WebSocket<S>
where
    S: AsyncRead + AsyncWrite + Send + Unpin + 'static,
{
    /// Runs the client opening handshake over an already-connected `io`,
    /// using a default request builder.
    ///
    /// Equivalent to [`WebSocket::handshake_with_request`] with `HttpRequest::builder()`.
    pub async fn handshake(url: Url, io: S, options: Options) -> Result<WebSocket<S>> {
        let default_builder = HttpRequest::builder();
        Self::handshake_with_request(url, io, options, default_builder).await
    }

    /// Runs the client opening handshake over `io`, starting from a
    /// caller-supplied request builder.
    ///
    /// Request construction:
    /// - A `Host` header is derived from `url` unless `builder` already has one.
    /// - The request targets the path-and-after part of `url`.
    /// - It sends `Upgrade: websocket`, `Connection: upgrade`, a random
    ///   `Sec-WebSocket-Key` and `Sec-WebSocket-Version: 13`.
    /// - If `options.compression` is set, it also sends a
    ///   `Sec-WebSocket-Extensions` offer.
    ///
    /// The HTTP/1 connection is driven on a spawned task (tokio, or smol with
    /// the `smol` feature). After `verify` accepts the response, the raw stream
    /// and any bytes hyper already buffered become the WebSocket.
    ///
    /// NOTE(review): `Sec-WebSocket-Accept` is not checked against the sent key.
    ///
    /// # Errors
    /// HTTP handshake, send or upgrade failures, and any rejection from `verify`.
    ///
    /// # Panics
    /// - If `builder.headers_ref()` is `None` (presumably a builder in an error state).
    /// - If `url` lacks a host or known port when `Host` must be synthesized.
    /// - If the request cannot be built.
    pub async fn handshake_with_request(
        url: Url,
        io: S,
        options: Options,
        mut builder: HttpRequestBuilder,
    ) -> Result<WebSocket<S>> {
        let needs_host = !builder
            .headers_ref()
            .expect("header")
            .contains_key(header::HOST);
        if needs_host {
            let host = url.host().expect("hostname").to_string();
            let port = url.port_or_known_default().expect("port");
            // `Url::port` is `None` for the scheme's default port, which then stays implicit.
            let host_header = match url.port() {
                Some(_) => format!("{host}:{port}"),
                None => host,
            };
            builder = builder.header(header::HOST, host_header.as_str());
        }

        let target_url = &url[url::Position::BeforePath..];
        let mut req = builder
            .method("GET")
            .uri(target_url)
            .header(header::UPGRADE, "websocket")
            .header(header::CONNECTION, "upgrade")
            .header(header::SEC_WEBSOCKET_KEY, generate_key())
            .header(header::SEC_WEBSOCKET_VERSION, "13")
            .body(Empty::<Bytes>::new())
            .expect("request build");
        if let Some(compression) = &options.compression {
            let offer = WebSocketExtensions::from(compression);
            req.headers_mut().insert(
                header::SEC_WEBSOCKET_EXTENSIONS,
                offer.to_string().parse().unwrap(),
            );
        }

        let (mut sender, conn) = hyper::client::conn::http1::handshake(TokioIo::new(io)).await?;
        // The connection task must keep running until hyper hands the IO back
        // through the upgrade.
        let driver = async move {
            if let Err(err) = conn.with_upgrades().await {
                log::error!("upgrading connection: {:?}", err);
            }
        };
        #[cfg(not(feature = "smol"))]
        tokio::spawn(driver);
        #[cfg(feature = "smol")]
        smol::spawn(driver).detach();

        let mut response = sender.send_request(req).await?;
        let negotiated = verify(&response, options)?;
        let upgraded = hyper::upgrade::on(&mut response).await?;
        let parts = upgraded.downcast::<TokioIo<S>>().unwrap();
        Ok(WebSocket::new(
            Role::Client,
            parts.io.into_inner(),
            parts.read_buf,
            negotiated,
        ))
    }
}
impl WebSocket<HttpStream> {
    /// Opens a client WebSocket through `client` (requires the `reqwest` feature).
    ///
    /// `ws`/`wss` URLs are sent as `http`/`https`; other schemes are used as-is.
    /// The request carries:
    /// - `Host`, `Upgrade: websocket` and `Connection: upgrade`;
    /// - a fresh `Sec-WebSocket-Key` and `Sec-WebSocket-Version: 13`;
    /// - a `Sec-WebSocket-Extensions` offer, when `options.compression` is set.
    ///
    /// # Errors
    /// Fails if sending the request fails, `verify_reqwest` rejects the
    /// response, or the connection cannot be upgraded.
    ///
    /// # Panics
    /// If `url` has no host.
    #[cfg(feature = "reqwest")]
    #[cfg_attr(docsrs, doc(cfg(feature = "reqwest")))]
    pub async fn reqwest(
        mut url: Url,
        client: reqwest::Client,
        options: Options,
    ) -> Result<WebSocket<HttpStream>> {
        let host = url.host().expect("hostname").to_string();
        let host_header = if let Some(port) = url.port() {
            format!("{host}:{port}")
        } else {
            host
        };
        match url.scheme() {
            "ws" => {
                let _ = url.set_scheme("http");
            }
            "wss" => {
                let _ = url.set_scheme("https");
            }
            _ => {}
        }
        let req = client
            .get(url.as_str())
            .header(reqwest::header::HOST, host_header.as_str())
            .header(reqwest::header::UPGRADE, "websocket")
            .header(reqwest::header::CONNECTION, "upgrade")
            .header(reqwest::header::SEC_WEBSOCKET_KEY, generate_key())
            .header(reqwest::header::SEC_WEBSOCKET_VERSION, "13");
        let req = if let Some(compression) = options.compression.as_ref() {
            let extensions = WebSocketExtensions::from(compression);
            req.header(
                reqwest::header::SEC_WEBSOCKET_EXTENSIONS,
                extensions.to_string(),
            )
        } else {
            req
        };
        let response = req.send().await?;
        let negotiated = verify_reqwest(&response, options)?;
        let upgraded = response.upgrade().await?;
        Ok(WebSocket::new(
            Role::Client,
            HttpStream::from(upgraded),
            Bytes::new(),
            negotiated,
        ))
    }

    /// Server side: answers an upgrade request using default [`Options`].
    ///
    /// See [`WebSocket::upgrade_with_options`].
    pub fn upgrade<B>(request: impl BorrowMut<Request<B>>) -> UpgradeResult {
        Self::upgrade_with_options(request, Options::default())
    }

    /// Server side: validates a client's upgrade request and prepares the response.
    ///
    /// The returned response is `101 Switching Protocols` with `Connection`,
    /// `Upgrade` and `Sec-WebSocket-Accept`. It also carries
    /// `Sec-WebSocket-Extensions` when the client offered an extension that
    /// parses and `options.compression` is set. An unparseable offer is ignored.
    ///
    /// The returned `UpgradeFut` wraps `hyper::upgrade::on(request)` together
    /// with the negotiated settings.
    ///
    /// # Errors
    /// - `MissingSecWebSocketKey` if the request has no `Sec-WebSocket-Key`.
    /// - `InvalidSecWebsocketVersion` unless `Sec-WebSocket-Version` is `13`.
    pub fn upgrade_with_options<B>(
        mut request: impl BorrowMut<Request<B>>,
        options: Options,
    ) -> UpgradeResult {
        let request = request.borrow_mut();
        let key = request
            .headers()
            .get(header::SEC_WEBSOCKET_KEY)
            .ok_or(WebSocketError::MissingSecWebSocketKey)?;
        if request
            .headers()
            .get(header::SEC_WEBSOCKET_VERSION)
            .map(|v| v.as_bytes())
            != Some(b"13")
        {
            return Err(WebSocketError::InvalidSecWebsocketVersion);
        }
        let maybe_compression = request
            .headers()
            .get(header::SEC_WEBSOCKET_EXTENSIONS)
            .and_then(|h| h.to_str().ok())
            .map(WebSocketExtensions::from_str)
            .and_then(std::result::Result::ok);
        let mut response = Response::builder()
            .status(hyper::StatusCode::SWITCHING_PROTOCOLS)
            .header(hyper::header::CONNECTION, "upgrade")
            .header(hyper::header::UPGRADE, "websocket")
            .header(
                header::SEC_WEBSOCKET_ACCEPT,
                upgrade::sec_websocket_protocol(key.as_bytes()),
            )
            .body(Empty::new())
            .expect("bug: failed to build response");
        // Compression is used only when both the client offered it and the server enabled it.
        let extensions = match (maybe_compression, options.compression.as_ref()) {
            (Some(client_compression), Some(server_compression)) => {
                let offer = server_compression.merge(&client_compression);
                let header_value = offer.to_string().parse().unwrap();
                response
                    .headers_mut()
                    .insert(header::SEC_WEBSOCKET_EXTENSIONS, header_value);
                Some(offer)
            }
            _ => None,
        };
        // `saturating_mul` keeps a very large `max_payload_read` (e.g. `usize::MAX`
        // meaning "unlimited") from overflowing.
        let max_read_buffer = options.max_read_buffer.unwrap_or_else(|| {
            options
                .max_payload_read
                .map_or(MAX_READ_BUFFER, |payload_read| payload_read.saturating_mul(2))
        });
        let stream = UpgradeFut {
            inner: hyper::upgrade::on(request),
            negotiation: Some(Negotiation {
                extensions,
                compression_level: options
                    .compression
                    .as_ref()
                    .map(|compression| compression.level),
                max_payload_read: options.max_payload_read.unwrap_or(MAX_PAYLOAD_READ),
                max_read_buffer,
                utf8: options.check_utf8,
                fragmentation: options.fragmentation.clone(),
                max_backpressure_write_boundary: options.max_backpressure_write_boundary,
            }),
        };
        Ok((response, stream))
    }
}
impl<S> WebSocket<S>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    /// Splits the connection into its framed transport and read/write halves.
    ///
    /// # Safety
    /// This forwards to `Streaming::split_stream`, whose contract the caller
    /// must uphold. NOTE(review): state that invariant here explicitly; it is
    /// not visible from this module.
    pub unsafe fn split_stream(self) -> (Framed<S, Codec>, ReadHalf, WriteHalf) {
        self.streaming.split_stream()
    }

    /// Consumes the WebSocket and returns the lower-level `Streaming` layer,
    /// dropping the fragment reassembly and UTF-8 checks done here.
    pub fn into_streaming(self) -> Streaming<S> {
        let Self { streaming, .. } = self;
        streaming
    }

    /// Polls for the next complete message or control frame.
    ///
    /// Intermediate fragments are absorbed; the call only becomes ready once a
    /// frame survives reassembly and UTF-8 validation, or an error occurs.
    pub fn poll_next_frame(&mut self, cx: &mut Context<'_>) -> Poll<Result<Frame>> {
        loop {
            let raw = match self.streaming.poll_next_frame(cx) {
                Poll::Pending => return Poll::Pending,
                Poll::Ready(result) => result?,
            };
            if let Some(frame) = self.on_frame(raw)? {
                return Poll::Ready(Ok(frame));
            }
        }
    }

    /// Waits for the next complete message or control frame.
    pub async fn next_frame(&mut self) -> Result<Frame> {
        poll_fn(move |cx| self.poll_next_frame(cx)).await
    }

    /// Assembles a WebSocket from an established stream and handshake results.
    ///
    /// `read_buf` holds bytes read past the handshake; it is handed to the
    /// streaming layer.
    pub(crate) fn new(role: Role, stream: S, read_buf: Bytes, opts: Negotiation) -> Self {
        let fragmentation = opts.fragmentation.as_ref();
        let fragment_layer = FragmentLayer::new(
            fragmentation.and_then(|f| f.fragment_size),
            opts.max_read_buffer,
            fragmentation.and_then(|f| f.timeout),
        );
        Self {
            streaming: Streaming::new(role, stream, read_buf, &opts),
            check_utf8: opts.utf8,
            fragment_layer,
        }
    }

    /// Runs a received frame through reassembly and, for text, UTF-8 validation.
    ///
    /// Returns `Ok(None)` while a fragmented message is incomplete.
    fn on_frame(&mut self, frame: Frame) -> Result<Option<Frame>> {
        let Some(frame) = self.fragment_layer.assemble_incoming(frame)? else {
            return Ok(None);
        };
        let must_validate = self.check_utf8 && frame.opcode == OpCode::Text;
        if must_validate {
            #[cfg(not(feature = "simd"))]
            let valid = std::str::from_utf8(&frame.payload).is_ok();
            #[cfg(feature = "simd")]
            let valid = simdutf8::basic::from_utf8(&frame.payload).is_ok();
            if !valid {
                return Err(WebSocketError::InvalidUTF8);
            }
        }
        Ok(Some(frame))
    }
}
/// Exposes incoming frames as a `futures::Stream`.
///
/// Any error is reported as `None` and its details are discarded; use
/// `next_frame` / `poll_next_frame` to observe errors.
impl<S> futures::Stream for WebSocket<S>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    type Item = Frame;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        self.get_mut()
            .poll_next_frame(cx)
            .map(|outcome| outcome.ok())
    }
}
impl<S> WebSocket<S>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    /// Moves every frame queued by `FragmentLayer::fragment_outgoing` into the
    /// streaming sink, respecting its readiness.
    ///
    /// Returns `Poll::Pending` when the transport cannot take another frame
    /// yet. Fragments already moved stay moved, so the next call resumes where
    /// this one stopped.
    fn poll_send_queued_fragments(&mut self, cx: &mut Context<'_>) -> Poll<Result<()>> {
        while self.fragment_layer.has_outgoing_fragments() {
            ready!(self.streaming.poll_ready_unpin(cx))?;
            let fragment = self
                .fragment_layer
                .pop_outgoing_fragment()
                .expect("fragment");
            self.streaming.start_send_unpin(fragment)?;
        }
        Poll::Ready(Ok(()))
    }
}

/// Sends frames through the fragmentation layer, then the streaming layer.
///
/// `start_send` only queues (and possibly splits) a frame. The queue is
/// drained into the transport by `poll_ready`, `poll_flush` and `poll_close`,
/// so queued frames are never lost on close and the queue cannot grow without
/// bound between flushes.
impl<S> futures::Sink<Frame> for WebSocket<S>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    type Error = WebSocketError;

    fn poll_ready(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<std::result::Result<(), Self::Error>> {
        let this = self.get_mut();
        // Hand over the previous item's fragments first, so readiness reflects
        // real transport backpressure.
        ready!(this.poll_send_queued_fragments(cx))?;
        this.streaming.poll_ready_unpin(cx)
    }

    fn start_send(self: Pin<&mut Self>, item: Frame) -> std::result::Result<(), Self::Error> {
        let this = self.get_mut();
        this.fragment_layer.fragment_outgoing(item);
        Ok(())
    }

    fn poll_flush(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<std::result::Result<(), Self::Error>> {
        let this = self.get_mut();
        ready!(this.poll_send_queued_fragments(cx))?;
        this.streaming.poll_flush_unpin(cx)
    }

    fn poll_close(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<std::result::Result<(), Self::Error>> {
        let this = self.get_mut();
        // Frames queued via `start_send` without a flush must still go out before closing.
        ready!(this.poll_send_queued_fragments(cx))?;
        this.streaming.poll_close_unpin(cx)
    }
}
/// Validates a reqwest handshake response and derives the connection's `Negotiation`.
///
/// # Errors
/// - `InvalidStatusCode` unless the status is `101 Switching Protocols`.
/// - `InvalidUpgradeHeader` unless `Upgrade` is `websocket` (case-insensitive).
/// - `InvalidConnectionHeader` unless some `Connection` header contains an
///   `upgrade` token (case-insensitive), as RFC 6455 §4.1 requires.
///
/// A `Sec-WebSocket-Extensions` header is honoured only when `options` enabled
/// compression. An unsolicited one is ignored, because without a local
/// compression level no compressor could be built.
///
/// NOTE(review): `Sec-WebSocket-Accept` is not validated here.
#[cfg(feature = "reqwest")]
fn verify_reqwest(response: &reqwest::Response, options: Options) -> Result<Negotiation> {
    let status = response.status();
    if status != reqwest::StatusCode::SWITCHING_PROTOCOLS {
        return Err(WebSocketError::InvalidStatusCode(status.as_u16()));
    }
    let compression_level = options.compression.as_ref().map(|opts| opts.level);
    let headers = response.headers();
    let upgrade_ok = headers
        .get(reqwest::header::UPGRADE)
        .and_then(|h| h.to_str().ok())
        .map(|h| h.eq_ignore_ascii_case("websocket"))
        .unwrap_or(false);
    if !upgrade_ok {
        return Err(WebSocketError::InvalidUpgradeHeader);
    }
    // `Connection` is a comma-separated token list (e.g. `keep-alive, Upgrade`)
    // and may be repeated; only the presence of the `Upgrade` token matters.
    let connection_ok = headers
        .get_all(reqwest::header::CONNECTION)
        .iter()
        .filter_map(|h| h.to_str().ok())
        .flat_map(|h| h.split(','))
        .any(|token| token.trim().eq_ignore_ascii_case("upgrade"));
    if !connection_ok {
        return Err(WebSocketError::InvalidConnectionHeader);
    }
    let extensions = headers
        .get(reqwest::header::SEC_WEBSOCKET_EXTENSIONS)
        .filter(|_| compression_level.is_some())
        .and_then(|h| h.to_str().ok())
        .and_then(|h| WebSocketExtensions::from_str(h).ok());
    // `saturating_mul` guards against overflow for very large payload limits.
    let max_read_buffer = options.max_read_buffer.unwrap_or_else(|| {
        options
            .max_payload_read
            .map_or(MAX_READ_BUFFER, |payload_read| payload_read.saturating_mul(2))
    });
    Ok(Negotiation {
        extensions,
        compression_level,
        max_payload_read: options.max_payload_read.unwrap_or(MAX_PAYLOAD_READ),
        max_read_buffer,
        utf8: options.check_utf8,
        fragmentation: options.fragmentation.clone(),
        max_backpressure_write_boundary: options.max_backpressure_write_boundary,
    })
}
/// Validates a hyper handshake response and derives the connection's `Negotiation`.
///
/// # Errors
/// - `Redirected` for any 3xx status. It carries the `Location` header, or an
///   empty string if that header is missing or not valid visible ASCII.
/// - `InvalidStatusCode` for any other status than `101 Switching Protocols`.
/// - `InvalidUpgradeHeader` unless `Upgrade` is `websocket` (case-insensitive).
/// - `InvalidConnectionHeader` unless some `Connection` header contains an
///   `upgrade` token (case-insensitive), as RFC 6455 §4.1 requires.
///
/// A `Sec-WebSocket-Extensions` header is honoured only when `options` enabled
/// compression. An unsolicited one is ignored, because without a local
/// compression level no compressor could be built.
///
/// NOTE(review): `Sec-WebSocket-Accept` is not validated here.
fn verify(response: &Response<Incoming>, options: Options) -> Result<Negotiation> {
    let status = response.status();
    if status.is_redirection() {
        let location = response
            .headers()
            .get(header::LOCATION)
            .and_then(|v| v.to_str().ok())
            .unwrap_or("")
            .to_string();
        return Err(WebSocketError::Redirected {
            status_code: status.as_u16(),
            location,
        });
    }
    if status != StatusCode::SWITCHING_PROTOCOLS {
        return Err(WebSocketError::InvalidStatusCode(status.as_u16()));
    }
    let compression_level = options.compression.as_ref().map(|opts| opts.level);
    let headers = response.headers();
    let upgrade_ok = headers
        .get(header::UPGRADE)
        .and_then(|h| h.to_str().ok())
        .map(|h| h.eq_ignore_ascii_case("websocket"))
        .unwrap_or(false);
    if !upgrade_ok {
        return Err(WebSocketError::InvalidUpgradeHeader);
    }
    // `Connection` is a comma-separated token list (e.g. `keep-alive, Upgrade`)
    // and may be repeated; only the presence of the `Upgrade` token matters.
    let connection_ok = headers
        .get_all(header::CONNECTION)
        .iter()
        .filter_map(|h| h.to_str().ok())
        .flat_map(|h| h.split(','))
        .any(|token| token.trim().eq_ignore_ascii_case("upgrade"));
    if !connection_ok {
        return Err(WebSocketError::InvalidConnectionHeader);
    }
    let extensions = headers
        .get(header::SEC_WEBSOCKET_EXTENSIONS)
        .filter(|_| compression_level.is_some())
        .and_then(|h| h.to_str().ok())
        .and_then(|h| WebSocketExtensions::from_str(h).ok());
    // `saturating_mul` guards against overflow for very large payload limits.
    let max_read_buffer = options.max_read_buffer.unwrap_or_else(|| {
        options
            .max_payload_read
            .map_or(MAX_READ_BUFFER, |payload_read| payload_read.saturating_mul(2))
    });
    Ok(Negotiation {
        extensions,
        compression_level,
        max_payload_read: options.max_payload_read.unwrap_or(MAX_PAYLOAD_READ),
        max_read_buffer,
        utf8: options.check_utf8,
        fragmentation: options.fragmentation.clone(),
        max_backpressure_write_boundary: options.max_backpressure_write_boundary,
    })
}
/// Produces a fresh `Sec-WebSocket-Key` value: 16 random bytes encoded with
/// standard (padded) base64.
fn generate_key() -> String {
    use base64::prelude::*;
    let nonce = rand::random::<[u8; 16]>();
    BASE64_STANDARD.encode(nonce)
}
/// Builds the default TLS connector used for `wss://` when no connector is supplied.
///
/// Trust roots are copied from `webpki_roots::TLS_SERVER_ROOTS`. The crypto
/// provider is the process-wide rustls default when one is installed, otherwise
/// the one selected by the `rustls-ring` / `rustls-aws-lc-rs` features. ALPN
/// advertises only `http/1.1`.
///
/// # Panics
/// - If no default provider is installed and neither provider feature is enabled.
/// - If `with_protocol_versions(rustls::ALL_VERSIONS)` returns an error.
fn tls_connector() -> TlsConnector {
let mut root_cert_store = rustls::RootCertStore::empty();
// Re-wrap each bundled root as a rustls `TrustAnchor`.
root_cert_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().map(|ta| TrustAnchor {
subject: ta.subject.clone(),
subject_public_key_info: ta.subject_public_key_info.clone(),
name_constraints: ta.name_constraints.clone(),
}));
// Prefer a provider the application already installed process-wide.
let maybe_provider = rustls::crypto::CryptoProvider::get_default().cloned();
#[cfg(any(feature = "rustls-ring", feature = "rustls-aws-lc-rs"))]
let provider = maybe_provider.unwrap_or_else(|| {
// If both features are enabled, the aws-lc-rs binding below shadows the ring
// one, so aws-lc-rs is used.
#[cfg(feature = "rustls-ring")]
let provider = rustls::crypto::ring::default_provider();
#[cfg(feature = "rustls-aws-lc-rs")]
let provider = rustls::crypto::aws_lc_rs::default_provider();
Arc::new(provider)
});
// Without a provider feature, an installed default provider is mandatory.
#[cfg(not(any(feature = "rustls-ring", feature = "rustls-aws-lc-rs")))]
let provider = maybe_provider.expect(
r#"No Rustls crypto provider was enabled for yawc to connect to a `wss://` endpoint!
Either:
- provide a `connector` in the WebSocketBuilder options
- enable one of the following features: `rustls-ring`, `rustls-aws-lc-rs`"#,
);
let mut config = rustls::ClientConfig::builder_with_provider(provider)
.with_protocol_versions(rustls::ALL_VERSIONS)
.expect("versions")
.with_root_certificates(root_cert_store)
.with_no_client_auth();
// The opening handshake is an HTTP/1.1 Upgrade, so only `http/1.1` is offered via ALPN.
config.alpn_protocols = vec!["http/1.1".into()];
TlsConnector::from(Arc::new(config))
}
#[cfg(test)]
mod tests {
use crate::close::{self, CloseCode};
use super::*;
use futures::SinkExt;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite, DuplexStream, ReadBuf};
struct MockStream {
inner: DuplexStream,
}
impl MockStream {
fn pair(buffer_size: usize) -> (Self, Self) {
let (a, b) = tokio::io::duplex(buffer_size);
(Self { inner: a }, Self { inner: b })
}
}
impl AsyncRead for MockStream {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
Pin::new(&mut self.inner).poll_read(cx, buf)
}
}
impl AsyncWrite for MockStream {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.inner).poll_write(cx, buf)
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.inner).poll_flush(cx)
}
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.inner).poll_shutdown(cx)
}
}
fn create_websocket_pair(buffer_size: usize) -> (WebSocket<MockStream>, WebSocket<MockStream>) {
create_websocket_pair_with_config(buffer_size, None, None)
}
fn create_websocket_pair_with_config(
buffer_size: usize,
fragment_size: Option<usize>,
compression_level: Option<CompressionLevel>,
) -> (WebSocket<MockStream>, WebSocket<MockStream>) {
let (client_stream, server_stream) = MockStream::pair(buffer_size);
let extensions = compression_level.map(|_level| WebSocketExtensions {
server_max_window_bits: None,
client_max_window_bits: None,
server_no_context_takeover: false,
client_no_context_takeover: false,
});
let negotiation = Negotiation {
extensions,
compression_level,
max_payload_read: MAX_PAYLOAD_READ,
max_read_buffer: MAX_READ_BUFFER,
utf8: false,
fragmentation: fragment_size.map(|size| options::Fragmentation {
timeout: None,
fragment_size: Some(size),
}),
max_backpressure_write_boundary: None,
};
let client_ws = WebSocket::new(
Role::Client,
client_stream,
Bytes::new(),
negotiation.clone(),
);
let server_ws = WebSocket::new(Role::Server, server_stream, Bytes::new(), negotiation);
(client_ws, server_ws)
}
#[tokio::test]
async fn test_send_and_receive_text_frame() {
let (mut client, mut server) = create_websocket_pair(1024);
let text = "Hello, WebSocket!";
client
.send(Frame::text(text))
.await
.expect("Failed to send text frame");
let frame = server.next_frame().await.expect("Failed to receive frame");
assert_eq!(frame.opcode(), OpCode::Text);
assert_eq!(frame.payload(), text.as_bytes());
assert!(frame.is_fin());
}
#[tokio::test]
async fn test_send_and_receive_binary_frame() {
let (mut client, mut server) = create_websocket_pair(1024);
let data = vec![1u8, 2, 3, 4, 5];
client
.send(Frame::binary(data.clone()))
.await
.expect("Failed to send binary frame");
let frame = server.next_frame().await.expect("Failed to receive frame");
assert_eq!(frame.opcode(), OpCode::Binary);
assert_eq!(frame.payload(), &data[..]);
assert!(frame.is_fin());
}
#[tokio::test]
async fn test_bidirectional_communication() {
let (mut client, mut server) = create_websocket_pair(2048);
client
.send(Frame::text("Client message"))
.await
.expect("Failed to send from client");
let frame = server
.next_frame()
.await
.expect("Failed to receive at server");
assert_eq!(frame.payload(), b"Client message" as &[u8]);
server
.send(Frame::text("Server response"))
.await
.expect("Failed to send from server");
let frame = client
.next_frame()
.await
.expect("Failed to receive at client");
assert_eq!(frame.payload(), b"Server response" as &[u8]);
}
#[tokio::test]
async fn test_ping_pong() {
let (mut client, mut server) = create_websocket_pair(1024);
client
.send(Frame::pong("pong_data"))
.await
.expect("Failed to send pong");
let frame = server.next_frame().await.expect("Failed to receive pong");
assert_eq!(frame.opcode(), OpCode::Pong);
assert_eq!(frame.payload(), b"pong_data" as &[u8]);
}
#[tokio::test]
async fn test_close_frame() {
let (mut client, mut server) = create_websocket_pair(1024);
client
.send(Frame::close(CloseCode::Normal, b"Goodbye"))
.await
.expect("Failed to send close frame");
let frame = server
.next_frame()
.await
.expect("Failed to receive close frame");
assert_eq!(frame.opcode(), OpCode::Close);
assert_eq!(frame.close_code(), Some(close::CloseCode::Normal));
assert_eq!(
frame.close_reason().expect("Invalid close reason"),
Some("Goodbye")
);
}
#[tokio::test]
async fn test_large_message() {
let (mut client, mut server) = create_websocket_pair(65536);
let large_data = vec![42u8; 10240];
client
.send(Frame::binary(large_data.clone()))
.await
.expect("Failed to send large message");
let frame = server
.next_frame()
.await
.expect("Failed to receive large message");
assert_eq!(frame.opcode(), OpCode::Binary);
assert_eq!(frame.payload().len(), 10240);
assert_eq!(frame.payload(), &large_data[..]);
}
#[tokio::test]
async fn test_multiple_messages() {
let (mut client, mut server) = create_websocket_pair(4096);
for i in 0..10 {
let msg = format!("Message {}", i);
client
.send(Frame::text(msg.clone()))
.await
.expect("Failed to send message");
let frame = server
.next_frame()
.await
.expect("Failed to receive message");
assert_eq!(frame.payload(), msg.as_bytes());
}
}
#[tokio::test]
async fn test_empty_payload() {
let (mut client, mut server) = create_websocket_pair(1024);
client
.send(Frame::text(Bytes::new()))
.await
.expect("Failed to send empty frame");
let frame = server
.next_frame()
.await
.expect("Failed to receive empty frame");
assert_eq!(frame.opcode(), OpCode::Text);
assert_eq!(frame.payload().len(), 0);
}
#[tokio::test]
async fn test_fragmented_message() {
let (mut client, mut server) = create_websocket_pair(2048);
let mut frame1 = Frame::text("Hello, ");
frame1.set_fin(false);
client
.send(frame1)
.await
.expect("Failed to send first fragment");
let frame2 = Frame::continuation("World!");
client
.send(frame2)
.await
.expect("Failed to send final fragment");
let received = server
.next_frame()
.await
.expect("Failed to receive message");
assert_eq!(received.opcode(), OpCode::Text);
assert!(received.is_fin());
assert_eq!(received.payload(), b"Hello, World!" as &[u8]);
}
#[tokio::test]
async fn test_concurrent_send_receive() {
let (mut client, mut server) = create_websocket_pair(4096);
let client_task = tokio::spawn(async move {
for i in 0..5 {
client
.send(Frame::text(format!("Client {}", i)))
.await
.expect("Failed to send from client");
let frame = client
.next_frame()
.await
.expect("Failed to receive at client");
assert_eq!(frame.payload(), format!("Server {}", i).as_bytes());
}
client
});
let server_task = tokio::spawn(async move {
for i in 0..5 {
let frame = server
.next_frame()
.await
.expect("Failed to receive at server");
assert_eq!(frame.payload(), format!("Client {}", i).as_bytes());
server
.send(Frame::text(format!("Server {}", i)))
.await
.expect("Failed to send from server");
}
server
});
client_task.await.expect("Client task failed");
server_task.await.expect("Server task failed");
}
#[tokio::test]
async fn test_utf8_validation() {
let (mut client, mut server) = create_websocket_pair(1024);
let valid_utf8 = "Hello, 世界! 🌍";
client
.send(Frame::text(valid_utf8))
.await
.expect("Failed to send UTF-8 text");
let frame = server
.next_frame()
.await
.expect("Failed to receive UTF-8 text");
assert_eq!(frame.opcode(), OpCode::Text);
assert!(frame.is_utf8());
assert_eq!(std::str::from_utf8(frame.payload()).unwrap(), valid_utf8);
}
#[tokio::test]
async fn test_stream_trait_implementation() {
use futures::StreamExt;
let (mut client, mut server) = create_websocket_pair(1024);
tokio::spawn(async move {
for i in 0..3 {
client
.send(Frame::text(format!("Message {}", i)))
.await
.expect("Failed to send message");
}
});
let mut count = 0;
while let Some(frame) = server.next().await {
assert_eq!(frame.opcode(), OpCode::Text);
count += 1;
if count == 3 {
break;
}
}
assert_eq!(count, 3);
}
#[tokio::test]
async fn test_sink_trait_implementation() {
use futures::SinkExt;
let (mut client, mut server) = create_websocket_pair(1024);
client
.send(Frame::text("Sink message"))
.await
.expect("Failed to send via Sink");
client.flush().await.expect("Failed to flush");
let frame = server
.next_frame()
.await
.expect("Failed to receive message");
assert_eq!(frame.payload(), b"Sink message" as &[u8]);
}
#[tokio::test]
async fn test_rapid_small_messages() {
let (mut client, mut server) = create_websocket_pair(8192);
let count = 100;
let sender = tokio::spawn(async move {
for i in 0..count {
client
.send(Frame::text(format!("{}", i)))
.await
.expect("Failed to send");
}
client
});
for i in 0..count {
let frame = server.next_frame().await.expect("Failed to receive");
assert_eq!(frame.payload(), format!("{}", i).as_bytes());
}
sender.await.expect("Sender task failed");
}
/// Sends a text frame, a pong, and a binary frame back-to-back. The server
/// must yield them in the same order with their opcodes intact, and the data
/// frames must keep their payloads.
#[tokio::test]
async fn test_interleaved_control_and_data_frames() {
    let (mut client, mut server) = create_websocket_pair(2048);

    let outgoing = [
        (Frame::text("Data 1"), "Failed to send"),
        (Frame::pong("pong"), "Failed to send pong"),
        (Frame::binary(vec![1, 2, 3]), "Failed to send"),
    ];
    for (frame, failure_msg) in outgoing {
        client.send(frame).await.expect(failure_msg);
    }

    let text = server.next_frame().await.expect("Failed to receive");
    assert_eq!(text.opcode(), OpCode::Text);
    assert_eq!(text.payload(), b"Data 1" as &[u8]);

    let pong = server.next_frame().await.expect("Failed to receive");
    assert_eq!(pong.opcode(), OpCode::Pong);

    let binary = server.next_frame().await.expect("Failed to receive");
    assert_eq!(binary.opcode(), OpCode::Binary);
    assert_eq!(binary.payload(), &[1u8, 2, 3] as &[u8]);
}
#[tokio::test]
async fn test_client_sends_masked_frames() {
let (mut client, mut _server) = create_websocket_pair(1024);
let frame = Frame::text("test");
client.send(frame).await.expect("Failed to send");
}
#[tokio::test]
async fn test_server_sends_unmasked_frames() {
let (mut _client, mut server) = create_websocket_pair(1024);
let frame = Frame::text("test");
server.send(frame).await.expect("Failed to send");
}
/// Sends a close frame carrying `CloseCode::Away` and an empty reason. The
/// server must parse the same close code back out of the received frame.
#[tokio::test]
async fn test_close_code_variants() {
    let (mut client, mut server) = create_websocket_pair(1024);

    let close_frame = Frame::close(close::CloseCode::Away, b"");
    client
        .send(close_frame)
        .await
        .expect("Failed to send close");

    let received_code = server
        .next_frame()
        .await
        .expect("Failed to receive")
        .close_code();
    assert_eq!(received_code, Some(close::CloseCode::Away));
}
/// Sends one text message split by hand into five fragments. The first is a
/// `Text` opener, the rest are `Continuation`s, and FIN is set only on the
/// last. The server must deliver a single reassembled, FIN-flagged text frame
/// whose payload is the concatenation of all parts.
#[tokio::test]
async fn test_multiple_fragments() {
    let (mut client, mut server) = create_websocket_pair(4096);

    for part in 0..5 {
        let opcode = match part {
            0 => OpCode::Text,
            _ => OpCode::Continuation,
        };
        let mut fragment = Frame::from((opcode, format!("part{}", part)));
        fragment.set_fin(part == 4);
        client
            .send(fragment)
            .await
            .expect("Failed to send fragment");
    }

    let reassembled = server.next_frame().await.expect("Failed to receive");
    assert_eq!(reassembled.opcode(), OpCode::Text);
    assert!(reassembled.is_fin());
    assert_eq!(
        reassembled.payload(),
        "part0part1part2part3part4".as_bytes()
    );
}
/// With a `fragment_size` of 100 configured on both ends, sends a 300-byte
/// binary message. The server must return it as one binary frame holding the
/// original bytes. Only the reassembled result is checked, not how the
/// message was split on the wire.
#[tokio::test]
async fn test_automatic_fragmentation_large_messages() {
    let (client_stream, server_stream) = MockStream::pair(8192);
    // Every field not listed keeps its `Default` value (no extensions, no
    // compression level, UTF-8 validation off, no backpressure boundary).
    let negotiation = Negotiation {
        max_payload_read: MAX_PAYLOAD_READ,
        max_read_buffer: MAX_READ_BUFFER,
        fragmentation: Some(options::Fragmentation {
            timeout: None,
            fragment_size: Some(100),
        }),
        ..Default::default()
    };
    let mut client_ws = WebSocket::new(
        Role::Client,
        client_stream,
        Bytes::new(),
        negotiation.clone(),
    );
    let mut server_ws = WebSocket::new(Role::Server, server_stream, Bytes::new(), negotiation);

    let payload = vec![b'A'; 300];
    client_ws.send(Frame::binary(payload.clone())).await.unwrap();

    let received = server_ws.next_frame().await.unwrap();
    assert_eq!(received.opcode(), OpCode::Binary);
    assert_eq!(received.payload(), payload.as_slice());
}
/// With a `fragment_size` of 100 configured on both ends, sends a 50-byte
/// text message, which is below the fragment size. The server must return it
/// unchanged as a single text frame.
#[tokio::test]
async fn test_automatic_fragmentation_small_messages() {
    let (client_stream, server_stream) = MockStream::pair(8192);
    // Every field not listed keeps its `Default` value (no extensions, no
    // compression level, UTF-8 validation off, no backpressure boundary).
    let negotiation = Negotiation {
        max_payload_read: MAX_PAYLOAD_READ,
        max_read_buffer: MAX_READ_BUFFER,
        fragmentation: Some(options::Fragmentation {
            timeout: None,
            fragment_size: Some(100),
        }),
        ..Default::default()
    };
    let mut client_ws = WebSocket::new(
        Role::Client,
        client_stream,
        Bytes::new(),
        negotiation.clone(),
    );
    let mut server_ws = WebSocket::new(Role::Server, server_stream, Bytes::new(), negotiation);

    let payload = vec![b'B'; 50];
    client_ws.send(Frame::text(payload.clone())).await.unwrap();

    let received = server_ws.next_frame().await.unwrap();
    assert_eq!(received.opcode(), OpCode::Text);
    assert_eq!(received.payload(), payload.as_slice());
}
/// With no fragmentation configured, sends a 1000-byte binary message. The
/// server must receive it intact as one binary frame.
#[tokio::test]
async fn test_no_fragmentation_when_not_configured() {
    let (client_stream, server_stream) = MockStream::pair(8192);
    // Only the read limits are set explicitly; `fragmentation` and every
    // other field stay at their `Default` (None / false) values.
    let negotiation = Negotiation {
        max_payload_read: MAX_PAYLOAD_READ,
        max_read_buffer: MAX_READ_BUFFER,
        ..Default::default()
    };
    let mut client_ws = WebSocket::new(
        Role::Client,
        client_stream,
        Bytes::new(),
        negotiation.clone(),
    );
    let mut server_ws = WebSocket::new(Role::Server, server_stream, Bytes::new(), negotiation);

    let payload = vec![b'C'; 1000];
    client_ws.send(Frame::binary(payload.clone())).await.unwrap();

    let received = server_ws.next_frame().await.unwrap();
    assert_eq!(received.opcode(), OpCode::Binary);
    assert_eq!(received.payload(), payload.as_slice());
}
/// Splits "Hello, World!" into three fragments and slips a ping and a pong
/// between them. The server must yield both control frames first, in order,
/// then a single reassembled, FIN-flagged text message.
#[tokio::test]
async fn test_interleave_control_frames_with_continuation_frames() {
    let (mut client, mut server) = create_websocket_pair(4096);

    let mut opening = Frame::text("Hello, ");
    opening.set_fin(false);
    let mut middle = Frame::continuation("World");
    middle.set_fin(false);
    let closing = Frame::continuation("!");

    let outgoing = [
        (opening, "Failed to send first fragment"),
        (Frame::ping("ping during fragmentation"), "Failed to send ping"),
        (middle, "Failed to send second fragment"),
        (Frame::pong("pong during fragmentation"), "Failed to send pong"),
        (closing, "Failed to send final fragment"),
    ];
    for (frame, failure_msg) in outgoing {
        client.send(frame).await.expect(failure_msg);
    }

    let ping = server
        .next_frame()
        .await
        .expect("Failed to receive ping frame");
    assert_eq!(ping.opcode(), OpCode::Ping);
    assert_eq!(ping.payload(), b"ping during fragmentation" as &[u8]);

    let pong = server
        .next_frame()
        .await
        .expect("Failed to receive pong frame");
    assert_eq!(pong.opcode(), OpCode::Pong);
    assert_eq!(pong.payload(), b"pong during fragmentation" as &[u8]);

    let message = server
        .next_frame()
        .await
        .expect("Failed to receive reassembled message");
    assert_eq!(message.opcode(), OpCode::Text);
    assert!(message.is_fin());
    assert_eq!(message.payload(), b"Hello, World!" as &[u8]);
}
/// Manually fragments a 1 MiB binary payload into 64 KiB frames. The first
/// frame is a `Binary` opener, the rest are `Continuation`s, and FIN is set on
/// the last. The connection is configured with `Compression::best()`. The
/// server must reassemble the exact original bytes into a single FIN-flagged
/// binary frame.
///
/// The receiver runs in its own task so it reads while fragments are still
/// being written. NOTE(review): presumably required because the mock buffer
/// (first argument, 256 KiB) is smaller than the payload; confirm.
#[tokio::test]
async fn test_large_compressed_fragmented_payload() {
    const FRAGMENT_SIZE: usize = 65536;
    const PAYLOAD_SIZE: usize = 1024 * 1024;
    use flate2::Compression;
    let (mut client, mut server) = create_websocket_pair_with_config(
        256 * 1024, None, Some(Compression::best()),
    );
    let payload: Vec<u8> = (0..PAYLOAD_SIZE).map(|i| (i % 256) as u8).collect();
    let total_fragments = PAYLOAD_SIZE.div_ceil(FRAGMENT_SIZE);
    println!(
        "Sending {} bytes in {} fragments of {} bytes each",
        PAYLOAD_SIZE, total_fragments, FRAGMENT_SIZE
    );
    let server_task = tokio::spawn(async move {
        server
            .next_frame()
            .await
            .expect("Failed to receive large payload")
    });
    // `chunks` yields exactly `total_fragments` slices; the last one may be
    // shorter than FRAGMENT_SIZE.
    for (fragment_num, chunk) in payload.chunks(FRAGMENT_SIZE).enumerate() {
        let is_final = fragment_num + 1 == total_fragments;
        let mut frame = if fragment_num == 0 {
            Frame::binary(chunk.to_vec())
        } else {
            Frame::continuation(chunk.to_vec())
        };
        frame.set_fin(is_final);
        println!(
            "Sending fragment {}/{}: {} bytes, OpCode={:?} FIN={}",
            fragment_num + 1,
            total_fragments,
            frame.payload().len(),
            frame.opcode(),
            is_final
        );
        // Include the underlying error so a failure is diagnosable.
        client.send(frame).await.unwrap_or_else(|e| {
            panic!("Failed to send fragment {}: {:?}", fragment_num + 1, e)
        });
    }
    let received_frame = server_task.await.expect("Server task failed");
    assert_eq!(received_frame.opcode(), OpCode::Binary);
    assert!(received_frame.is_fin());
    assert_eq!(received_frame.payload().len(), PAYLOAD_SIZE);
    assert_eq!(received_frame.payload().as_ref(), &payload[..]);
    println!(
        "Successfully sent {} manual fragments, compressed, decompressed, and reassembled {} bytes",
        total_fragments, PAYLOAD_SIZE
    );
}
/// Sends a highly repetitive ~250 KB message over a connection configured
/// with a 64 KiB fragment size and `Compression::best()`, followed by a ping.
/// The server must yield both a single FIN-flagged binary message equal to
/// the original text and the ping. Either arrival order is accepted.
///
/// The ping is only sent after the data message's `send` has completed.
/// NOTE(review): it therefore probably does not land between the data
/// fragments on the wire; confirm whether true interleaving was intended.
///
/// The sender task returns the client, and its `JoinHandle` is awaited. This
/// keeps the connection open while reading, and makes a panic inside the
/// sender fail this test instead of being silently discarded.
#[tokio::test]
async fn test_compressed_fragmented_with_interleaved_control() {
    const FRAGMENT_SIZE: usize = 65536;
    use flate2::Compression;
    let (mut client, mut server) = create_websocket_pair_with_config(
        128 * 1024,
        Some(FRAGMENT_SIZE),
        Some(Compression::best()),
    );
    let payload = "This is a test payload that should compress well. ".repeat(5000);
    // Independent byte copy for the sender; `payload` stays here for comparison.
    let payload_bytes = payload.as_bytes().to_vec();
    let sender = tokio::spawn(async move {
        client
            .send(Frame::binary(payload_bytes))
            .await
            .expect("Failed to send payload");
        tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
        client
            .send(Frame::ping("test"))
            .await
            .expect("Failed to send ping");
        client
    });
    let mut received_message = None;
    let mut received_ping = false;
    for _ in 0..2 {
        let frame = server.next_frame().await.expect("Failed to receive frame");
        match frame.opcode() {
            OpCode::Binary => {
                assert!(frame.is_fin());
                received_message = Some(frame.payload().to_vec());
            }
            OpCode::Ping => {
                received_ping = true;
            }
            _ => panic!("Unexpected frame type: {:?}", frame.opcode()),
        }
    }
    sender.await.expect("Sender task failed");
    let received_message = received_message.expect("Message not received");
    assert!(received_ping, "Ping not received");
    let received = String::from_utf8(received_message)
        .expect("Invalid UTF-8 in received payload");
    assert_eq!(
        received, payload,
        "Compressed fragmented payload mismatch"
    );
}
}