use std::{
borrow::Cow,
fmt,
net::{IpAddr, Ipv4Addr, Ipv6Addr},
ops::{Deref, DerefMut},
pin::Pin,
task::{Context, Poll},
};
use base64::Engine;
use fastwebsockets::{Frame, Role};
use futures_util::Stream;
use http::{
HeaderMap, HeaderName, HeaderValue, Method, StatusCode, Uri, Version, header, uri::Scheme,
};
#[cfg(feature = "http2")]
use http2::ext::Protocol;
use sha1::{Digest, Sha1};
use super::message::{CloseCode, Message, Utf8Bytes};
use crate::{
EmulationFactory, Error, RequestBuilder, Response, Upgraded, header::OrigHeaderMap,
proxy::Proxy,
};
/// `fastwebsockets` connection running over an upgraded HTTP connection;
/// this is the transport wrapped by [`WebSocket`].
type WebSocketStream = fastwebsockets::WebSocket<Upgraded>;
/// Default value for [`WebSocketConfig::max_message_size`]: 64 MiB.
pub const DEFAULT_MAX_MESSAGE_SIZE: usize = 64 << 20;

/// Settings applied to the socket once the handshake has succeeded.
#[derive(Debug, Clone, Copy)]
pub struct WebSocketConfig {
    /// Limit handed to the socket's `set_max_message_size`; `None` leaves the
    /// underlying library's own limit untouched.
    pub max_message_size: Option<usize>,
    /// Value handed to the socket's `set_auto_close`.
    pub auto_close: bool,
    /// Value handed to the socket's `set_auto_pong`.
    pub auto_pong: bool,
}

impl Default for WebSocketConfig {
    /// A [`DEFAULT_MAX_MESSAGE_SIZE`] limit with automatic close and pong
    /// handling both switched on.
    fn default() -> Self {
        let max_message_size = Some(DEFAULT_MAX_MESSAGE_SIZE);
        Self {
            auto_close: true,
            auto_pong: true,
            max_message_size,
        }
    }
}
pub struct WebSocketRequestBuilder {
inner: RequestBuilder,
accept_key: Option<Cow<'static, str>>,
protocols: Option<Vec<Cow<'static, str>>>,
config: WebSocketConfig,
}
impl WebSocketRequestBuilder {
pub fn new(inner: RequestBuilder) -> Self {
Self {
inner: inner.version(Version::HTTP_11),
accept_key: None,
protocols: None,
config: WebSocketConfig::default(),
}
}
#[inline]
pub fn accept_key<K>(mut self, key: K) -> Self
where
K: Into<Cow<'static, str>>,
{
self.accept_key = Some(key.into());
self
}
#[inline]
pub fn force_http2(mut self) -> Self {
self.inner = self.inner.version(Version::HTTP_2);
self
}
#[inline]
pub fn protocols<P>(mut self, protocols: P) -> Self
where
P: IntoIterator,
P::Item: Into<Cow<'static, str>>,
{
let protocols = protocols.into_iter().map(Into::into).collect();
self.protocols = Some(protocols);
self
}
#[inline]
pub fn max_message_size(mut self, max_message_size: usize) -> Self {
self.config.max_message_size = Some(max_message_size);
self
}
#[inline]
pub fn auto_close(mut self, auto_close: bool) -> Self {
self.config.auto_close = auto_close;
self
}
#[inline]
pub fn auto_pong(mut self, auto_pong: bool) -> Self {
self.config.auto_pong = auto_pong;
self
}
#[inline]
pub fn header<K, V>(mut self, key: K, value: V) -> Self
where
HeaderName: TryFrom<K>,
<HeaderName as TryFrom<K>>::Error: Into<http::Error>,
HeaderValue: TryFrom<V>,
<HeaderValue as TryFrom<V>>::Error: Into<http::Error>,
{
self.inner = self.inner.header(key, value);
self
}
#[inline]
pub fn headers(mut self, headers: HeaderMap) -> Self {
self.inner = self.inner.headers(headers);
self
}
#[inline]
pub fn orig_headers(mut self, orig_headers: OrigHeaderMap) -> Self {
self.inner = self.inner.orig_headers(orig_headers);
self
}
pub fn default_headers(mut self, enable: bool) -> Self {
self.inner = self.inner.default_headers(enable);
self
}
#[inline]
pub fn auth<V>(mut self, value: V) -> Self
where
HeaderValue: TryFrom<V>,
<HeaderValue as TryFrom<V>>::Error: Into<http::Error>,
{
self.inner = self.inner.auth(value);
self
}
#[inline]
pub fn basic_auth<U, P>(mut self, username: U, password: Option<P>) -> Self
where
U: fmt::Display,
P: fmt::Display,
{
self.inner = self.inner.basic_auth(username, password);
self
}
#[inline]
pub fn bearer_auth<T>(mut self, token: T) -> Self
where
T: fmt::Display,
{
self.inner = self.inner.bearer_auth(token);
self
}
#[inline]
#[cfg(feature = "query")]
#[cfg_attr(docsrs, doc(cfg(feature = "query")))]
pub fn query<T: serde::Serialize + ?Sized>(mut self, query: &T) -> Self {
self.inner = self.inner.query(query);
self
}
#[inline]
pub fn proxy(mut self, proxy: Proxy) -> Self {
self.inner = self.inner.proxy(proxy);
self
}
#[inline]
pub fn local_address<V>(mut self, local_address: V) -> Self
where
V: Into<Option<IpAddr>>,
{
self.inner = self.inner.local_address(local_address);
self
}
#[inline]
pub fn local_addresses<V4, V6>(mut self, ipv4: V4, ipv6: V6) -> Self
where
V4: Into<Option<Ipv4Addr>>,
V6: Into<Option<Ipv6Addr>>,
{
self.inner = self.inner.local_addresses(ipv4, ipv6);
self
}
#[inline]
#[cfg(any(
target_os = "android",
target_os = "fuchsia",
target_os = "illumos",
target_os = "ios",
target_os = "linux",
target_os = "macos",
target_os = "solaris",
target_os = "tvos",
target_os = "visionos",
target_os = "watchos",
))]
#[cfg_attr(
docsrs,
doc(cfg(any(
target_os = "android",
target_os = "fuchsia",
target_os = "illumos",
target_os = "ios",
target_os = "linux",
target_os = "macos",
target_os = "solaris",
target_os = "tvos",
target_os = "visionos",
target_os = "watchos",
)))
)]
pub fn interface<I>(mut self, interface: I) -> Self
where
I: Into<std::borrow::Cow<'static, str>>,
{
self.inner = self.inner.interface(interface);
self
}
#[inline]
pub fn emulation<P>(mut self, factory: P) -> Self
where
P: EmulationFactory,
{
self.inner = self.inner.emulation(factory);
self
}
pub async fn send(self) -> Result<WebSocketResponse, Error> {
let (client, request) = self.inner.build_split();
let mut request = request?;
let uri = request.uri_mut();
let scheme = match uri.scheme_str() {
Some("ws") => Some(Scheme::HTTP),
Some("wss") => Some(Scheme::HTTPS),
_ => None,
};
if scheme.is_some() {
let mut parts = uri.clone().into_parts();
parts.scheme = scheme;
*uri = Uri::from_parts(parts).map_err(Error::builder)?;
}
let version = request.version();
let headers = request.headers_mut();
headers.insert(
header::SEC_WEBSOCKET_VERSION,
HeaderValue::from_static("13"),
);
let accept_key = match version {
Some(Version::HTTP_10 | Version::HTTP_11) => {
let nonce = self
.accept_key
.unwrap_or_else(|| Cow::Owned(fastwebsockets::handshake::generate_key()));
headers.insert(header::UPGRADE, HeaderValue::from_static("websocket"));
headers.insert(header::CONNECTION, HeaderValue::from_static("upgrade"));
headers.insert(
header::SEC_WEBSOCKET_KEY,
HeaderValue::from_str(&nonce).map_err(Error::builder)?,
);
*request.method_mut() = Method::GET;
*request.version_mut() = Some(Version::HTTP_11);
Some(nonce)
}
Some(Version::HTTP_2) => {
#[cfg(feature = "http2")]
{
*request.method_mut() = Method::CONNECT;
*request.version_mut() = Some(Version::HTTP_2);
request
.extensions_mut()
.insert(Protocol::from_static("websocket"));
None
}
#[cfg(not(feature = "http2"))]
{
return Err(Error::upgrade("HTTP/2 WebSockets require 'http2' feature"));
}
}
unsupported => {
return Err(Error::upgrade(format!(
"unsupported version: {unsupported:?}"
)));
}
};
if let Some(ref protocols) = self.protocols {
if !protocols.is_empty() {
let subprotocols = protocols
.iter()
.map(|s| s.as_ref())
.collect::<Vec<&str>>()
.join(", ");
request.headers_mut().insert(
header::SEC_WEBSOCKET_PROTOCOL,
subprotocols.parse().map_err(Error::builder)?,
);
}
}
client
.execute(request)
.await
.map(|inner| WebSocketResponse {
inner,
accept_key,
protocols: self.protocols,
config: self.config,
})
}
}
/// Response to a WebSocket handshake request, as returned by
/// [`WebSocketRequestBuilder::send`].
///
/// Dereferences to the underlying [`Response`], so status and headers can be
/// inspected directly. Call [`WebSocketResponse::into_websocket`] to validate
/// the handshake and obtain a [`WebSocket`].
#[derive(Debug)]
pub struct WebSocketResponse {
    /// The raw handshake response.
    inner: Response,
    /// Nonce sent as `Sec-WebSocket-Key` (only set for HTTP/1.x handshakes).
    accept_key: Option<Cow<'static, str>>,
    /// Subprotocols offered in `Sec-WebSocket-Protocol`, if any.
    protocols: Option<Vec<Cow<'static, str>>>,
    /// Settings applied to the socket after the upgrade.
    config: WebSocketConfig,
}

impl Deref for WebSocketResponse {
    type Target = Response;

    #[inline]
    fn deref(&self) -> &Response {
        &self.inner
    }
}

impl DerefMut for WebSocketResponse {
    #[inline]
    fn deref_mut(&mut self) -> &mut Response {
        &mut self.inner
    }
}
impl WebSocketResponse {
pub async fn into_websocket(self) -> Result<WebSocket, Error> {
let (inner, protocol) = {
let status = self.inner.status();
let headers = self.inner.headers();
match self.inner.version() {
Version::HTTP_10 | Version::HTTP_11 => {
if status != StatusCode::SWITCHING_PROTOCOLS {
return Err(Error::upgrade(format!("unexpected status code: {status}")));
}
if !header_contains(self.inner.headers(), header::CONNECTION, "upgrade") {
return Err(Error::upgrade("missing connection header"));
}
if !header_eq(self.inner.headers(), header::UPGRADE, "websocket") {
return Err(Error::upgrade("invalid upgrade header"));
}
match self
.accept_key
.zip(headers.get(header::SEC_WEBSOCKET_ACCEPT))
{
Some((nonce, header)) => {
if !header
.to_str()
.is_ok_and(|s| s == derive_accept_key(nonce.as_bytes()))
{
return Err(Error::upgrade(format!(
"invalid accept key: {header:?}"
)));
}
}
None => {
return Err(Error::upgrade("missing accept key"));
}
}
}
Version::HTTP_2 => {
if status != StatusCode::OK {
return Err(Error::upgrade(format!("unexpected status code: {status}")));
}
}
_ => {
return Err(Error::upgrade(format!(
"unsupported version: {:?}",
self.inner.version()
)));
}
}
let protocol = headers.get(header::SEC_WEBSOCKET_PROTOCOL).cloned();
let requested = self.protocols.as_ref().filter(|p| !p.is_empty());
let replied = protocol.as_ref().and_then(|v| v.to_str().ok());
match (requested, replied) {
(Some(req), Some(rep)) => {
if !req.contains(&Cow::Borrowed(rep)) {
return Err(Error::upgrade(format!("invalid protocol: {rep}")));
}
}
(Some(_), None) => {
return Err(Error::upgrade(format!(
"missing protocol: {:?}",
self.protocols
)));
}
(None, Some(_)) => {
return Err(Error::upgrade(format!("invalid protocol: {protocol:?}")));
}
(None, None) => {}
};
let upgraded = self.inner.upgrade().await?;
let mut ws = fastwebsockets::WebSocket::after_handshake(upgraded, Role::Client);
if let Some(size) = self.config.max_message_size {
ws.set_max_message_size(size);
}
ws.set_auto_close(self.config.auto_close);
ws.set_auto_pong(self.config.auto_pong);
(ws, protocol)
};
Ok(WebSocket { inner, protocol })
}
}
/// Returns `true` when the first `key` header exists and equals `value`,
/// ignoring ASCII case. A missing header yields `false`.
fn header_eq(headers: &HeaderMap, key: HeaderName, value: &'static str) -> bool {
    headers
        .get(&key)
        .is_some_and(|found| found.as_bytes().eq_ignore_ascii_case(value.as_bytes()))
}
/// Returns `true` if any `key` header lists `value` among its comma-separated
/// tokens, compared ASCII case-insensitively.
///
/// RFC 6455 §4.1 requires the handshake's `Connection` header to contain an
/// `Upgrade` *token*. A substring search would also accept unrelated tokens
/// such as `x-upgrade-hint`, and would miss a matching token in a second
/// `Connection` field line, so each field line is split into tokens and the
/// tokens are compared whole.
///
/// Header values that are not valid UTF-8 are skipped. A missing header
/// yields `false`.
fn header_contains(headers: &HeaderMap, key: HeaderName, value: &'static str) -> bool {
    headers
        .get_all(&key)
        .iter()
        .filter_map(|header| std::str::from_utf8(header.as_bytes()).ok())
        .flat_map(|list| list.split(','))
        .any(|token| token.trim().eq_ignore_ascii_case(value))
}
/// Computes the `Sec-WebSocket-Accept` value expected for the client nonce
/// `key`: the standard base64 encoding of SHA-1(`key` followed by the fixed
/// WebSocket GUID).
fn derive_accept_key(key: &[u8]) -> String {
    // Fixed GUID appended to the nonce by the WebSocket handshake.
    const WEBSOCKET_GUID: &[u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

    let mut sha = Sha1::new();
    for chunk in [key, WEBSOCKET_GUID] {
        sha.update(chunk);
    }
    let digest = sha.finalize();
    base64::engine::general_purpose::STANDARD.encode(digest)
}
pub struct WebSocket {
inner: WebSocketStream,
protocol: Option<HeaderValue>,
}
impl WebSocket {
#[inline]
pub fn protocol(&self) -> Option<&HeaderValue> {
self.protocol.as_ref()
}
#[inline]
pub async fn recv(&mut self) -> Option<Result<Message, Error>> {
match self.inner.read_frame().await {
Ok(frame) => Some(Ok(Message::from_frame(frame))),
Err(e) => Some(Err(Error::websocket(e))),
}
}
#[inline]
pub async fn send(&mut self, msg: Message) -> Result<(), Error> {
let frame = msg.into_frame();
self.inner
.write_frame(frame)
.await
.map_err(Error::websocket)
}
pub async fn close<C, R>(mut self, code: C, reason: R) -> Result<(), Error>
where
C: Into<CloseCode>,
R: Into<Utf8Bytes>,
{
let code = code.into();
let reason = reason.into();
let frame = Frame::close(code.0, reason.as_str().as_bytes());
self.inner
.write_frame(frame)
.await
.map_err(Error::websocket)
}
pub fn split(self) -> (WebSocketWrite, WebSocketRead) {
let (r, w) = self.inner.split(tokio::io::split);
(WebSocketWrite { inner: w }, WebSocketRead { inner: r })
}
}
pub struct WebSocketRead {
inner: fastwebsockets::WebSocketRead<tokio::io::ReadHalf<Upgraded>>,
}
impl Stream for WebSocketRead {
type Item = Result<Message, Error>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let mut send_fn = |_| async { Ok::<(), Error>(()) };
let fut = self.inner.read_frame(&mut send_fn);
tokio::pin!(fut);
match fut.poll(cx) {
Poll::Ready(Ok(frame)) => Poll::Ready(Some(Ok(Message::from_frame(frame)))),
Poll::Ready(Err(e)) => Poll::Ready(Some(Err(Error::websocket(e)))),
Poll::Pending => Poll::Pending,
}
}
}
impl WebSocketRead {
pub async fn recv(&mut self) -> Option<Result<Message, Error>> {
match self
.inner
.read_frame(&mut |_| async { Ok::<(), Error>(()) })
.await
{
Ok(frame) => Some(Ok(Message::from_frame(frame))),
Err(e) => Some(Err(Error::websocket(e))),
}
}
}
/// Write half of a split [`WebSocket`].
pub struct WebSocketWrite {
    inner: fastwebsockets::WebSocketWrite<tokio::io::WriteHalf<Upgraded>>,
}

impl WebSocketWrite {
    /// Sends `msg` as a single frame.
    pub async fn send(&mut self, msg: Message) -> Result<(), Error> {
        self.inner
            .write_frame(msg.into_frame())
            .await
            .map_err(Error::websocket)
    }

    /// Writes a Close frame carrying `code` and `reason`, consuming the write
    /// half.
    pub async fn close<C, R>(mut self, code: C, reason: R) -> Result<(), Error>
    where
        C: Into<CloseCode>,
        R: Into<Utf8Bytes>,
    {
        let close_code: CloseCode = code.into();
        let close_reason: Utf8Bytes = reason.into();
        let close_frame = Frame::close(close_code.0, close_reason.as_str().as_bytes());
        self.inner
            .write_frame(close_frame)
            .await
            .map_err(Error::websocket)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn default_config_sets_a_message_size_limit() {
        assert_eq!(
            WebSocketConfig::default().max_message_size,
            Some(DEFAULT_MAX_MESSAGE_SIZE)
        );
    }

    #[test]
    fn default_config_enables_auto_close_and_auto_pong() {
        let config = WebSocketConfig::default();
        assert!(config.auto_close);
        assert!(config.auto_pong);
    }

    /// Sample nonce/accept pair from RFC 6455, section 1.3.
    #[test]
    fn derive_accept_key_matches_rfc6455_example() {
        assert_eq!(
            derive_accept_key(b"dGhlIHNhbXBsZSBub25jZQ=="),
            "s3pPLMBiTxaQ9kYGzzhZRbK+xOo="
        );
    }

    #[test]
    fn header_eq_ignores_ascii_case() {
        let mut headers = HeaderMap::new();
        headers.insert(header::UPGRADE, HeaderValue::from_static("WebSocket"));
        assert!(header_eq(&headers, header::UPGRADE, "websocket"));
        assert!(!header_eq(&headers, header::CONNECTION, "websocket"));
    }

    #[test]
    fn header_contains_finds_token_in_list() {
        let mut headers = HeaderMap::new();
        headers.insert(
            header::CONNECTION,
            HeaderValue::from_static("keep-alive, Upgrade"),
        );
        assert!(header_contains(&headers, header::CONNECTION, "upgrade"));
        assert!(!header_contains(&headers, header::UPGRADE, "upgrade"));
    }
}