use std::{
borrow::Cow,
error::Error,
fmt,
future::Future,
ops::{Deref, DerefMut},
};
use ahash::AHashSet;
use http::{
header,
header::{HeaderMap, HeaderName, HeaderValue},
method::Method,
request::Parts,
status::StatusCode,
version::Version,
};
use hyper_util::rt::TokioIo;
use tokio_tungstenite::WebSocketStream;
pub use tungstenite::Message;
use tungstenite::{
handshake::derive_accept_key,
protocol::{self, WebSocketConfig},
};
use crate::{
body::Body,
context::ServerContext,
response::Response,
server::{IntoResponse, extract::FromContext},
};
/// Value written to the `Connection` header of the `101 Switching Protocols` response.
const HEADERVALUE_UPGRADE: HeaderValue = HeaderValue::from_static("upgrade");
/// Value written to the `Upgrade` header of the `101 Switching Protocols` response.
const HEADERVALUE_WEBSOCKET: HeaderValue = HeaderValue::from_static("websocket");
/// Extractor representing a validated WebSocket opening handshake.
///
/// It is produced by this type's `FromContext` impl once the request passes the
/// handshake checks. Tune it with the builder methods, then call `on_upgrade` to
/// obtain the `101 Switching Protocols` response.
///
/// `F` is the handler invoked when awaiting the upgrade fails. It defaults to
/// [`DefaultOnFailedUpgrade`], which ignores the error.
#[must_use]
pub struct WebSocketUpgrade<F = DefaultOnFailedUpgrade> {
    /// Settings handed to the WebSocket stream after the upgrade.
    config: WebSocketConfig,
    /// Subprotocol selected by `protocols`, echoed in `Sec-WebSocket-Protocol`.
    protocol: Option<HeaderValue>,
    /// The client's `Sec-WebSocket-Key`, used to derive `Sec-WebSocket-Accept`.
    sec_websocket_key: HeaderValue,
    /// The client's raw `Sec-WebSocket-Protocol` header value, if one was sent.
    sec_websocket_protocol: Option<HeaderValue>,
    /// Taken from the request extensions; awaited in `on_upgrade` to get the upgraded I/O.
    on_upgrade: hyper::upgrade::OnUpgrade,
    /// Receives the error if awaiting `on_upgrade` fails.
    on_failed_upgrade: F,
}
impl<F> WebSocketUpgrade<F> {
    /// Sets `WebSocketConfig::write_buffer_size` for the upgraded stream.
    pub fn write_buffer_size(mut self, size: usize) -> Self {
        self.config.write_buffer_size = size;
        self
    }

    /// Sets `WebSocketConfig::max_write_buffer_size` for the upgraded stream.
    pub fn max_write_buffer_size(mut self, max: usize) -> Self {
        self.config.max_write_buffer_size = max;
        self
    }

    /// Sets `WebSocketConfig::max_message_size` for the upgraded stream.
    ///
    /// See tungstenite's `WebSocketConfig` for how `None` is interpreted.
    pub fn max_message_size(mut self, max: Option<usize>) -> Self {
        self.config.max_message_size = max;
        self
    }

    /// Sets `WebSocketConfig::max_frame_size` for the upgraded stream.
    ///
    /// See tungstenite's `WebSocketConfig` for how `None` is interpreted.
    pub fn max_frame_size(mut self, max: Option<usize>) -> Self {
        self.config.max_frame_size = max;
        self
    }

    /// Sets `WebSocketConfig::accept_unmasked_frames` for the upgraded stream.
    pub fn accept_unmasked_frames(mut self, accept: bool) -> Self {
        self.config.accept_unmasked_frames = accept;
        self
    }

    /// Picks the first entry of `protocols` (in the server's order) that the
    /// client also listed in its comma-separated `Sec-WebSocket-Protocol` header.
    ///
    /// Returns `None` in three cases:
    /// - the client sent no such header;
    /// - the header is not a visible-ASCII string;
    /// - nothing matches.
    fn get_protocol<I>(&self, protocols: I) -> Option<HeaderValue>
    where
        I: IntoIterator,
        I::Item: Into<Cow<'static, str>>,
    {
        let req_protocols = self
            .sec_websocket_protocol
            .as_ref()?
            .to_str()
            .ok()?
            .split(',')
            .map(str::trim)
            .collect::<AHashSet<_>>();
        for protocol in protocols.into_iter().map(Into::into) {
            if req_protocols.contains(protocol.as_ref()) {
                let protocol = match protocol {
                    Cow::Owned(s) => HeaderValue::from_str(&s).ok()?,
                    // The string equals a token of a header that passed `to_str`,
                    // so it is a valid static header value.
                    Cow::Borrowed(s) => HeaderValue::from_static(s),
                };
                return Some(protocol);
            }
        }
        None
    }

    /// Negotiates the subprotocol against the client's offer.
    ///
    /// The first of `protocols` that the client also requested is recorded and
    /// echoed back in the handshake response. If none match, no subprotocol is
    /// selected.
    pub fn protocols<I>(mut self, protocols: I) -> Self
    where
        I: IntoIterator,
        I::Item: Into<Cow<'static, str>>,
    {
        self.protocol = self.get_protocol(protocols);
        self
    }

    /// Replaces the handler that is called when awaiting the upgrade fails.
    pub fn on_failed_upgrade<F2>(self, callback: F2) -> WebSocketUpgrade<F2>
    where
        F2: OnFailedUpgrade,
    {
        WebSocketUpgrade {
            config: self.config,
            protocol: self.protocol,
            sec_websocket_key: self.sec_websocket_key,
            sec_websocket_protocol: self.sec_websocket_protocol,
            on_upgrade: self.on_upgrade,
            on_failed_upgrade: callback,
        }
    }

    /// Completes the handshake and returns the `101 Switching Protocols` response.
    ///
    /// A tokio task is spawned that does the following:
    /// 1. Awaits hyper's upgrade.
    /// 2. Wraps the upgraded I/O in a [`WebSocket`] configured with the builder
    ///    settings.
    /// 3. Runs `callback` with it.
    ///
    /// If awaiting the upgrade fails, the `on_failed_upgrade` handler receives the
    /// error instead and `callback` is never called.
    ///
    /// Returns `400 Bad Request` (and spawns nothing) if the derived accept key
    /// cannot be represented as a header value.
    pub fn on_upgrade<C, Fut>(self, callback: C) -> Response
    where
        C: FnOnce(WebSocket) -> Fut + Send + 'static,
        Fut: Future<Output = ()> + Send,
        F: OnFailedUpgrade + Send + 'static,
    {
        // Split `self` up front so it is explicit which parts feed the response
        // and which move into the background task.
        let Self {
            config,
            protocol,
            sec_websocket_key,
            on_upgrade,
            on_failed_upgrade,
            ..
        } = self;

        // Do the only fallible step before building anything else, so the failure
        // path does not construct a task it will immediately discard.
        let Ok(accept_key) =
            HeaderValue::from_str(&derive_accept_key(sec_websocket_key.as_bytes()))
        else {
            return StatusCode::BAD_REQUEST.into_response();
        };

        let mut resp = Response::new(Body::empty());
        *resp.status_mut() = StatusCode::SWITCHING_PROTOCOLS;
        let headers = resp.headers_mut();
        headers.insert(header::CONNECTION, HEADERVALUE_UPGRADE);
        headers.insert(header::UPGRADE, HEADERVALUE_WEBSOCKET);
        headers.insert(header::SEC_WEBSOCKET_ACCEPT, accept_key);
        if let Some(selected) = &protocol {
            // Already a validated `HeaderValue`; re-parsing its bytes is unnecessary.
            headers.insert(header::SEC_WEBSOCKET_PROTOCOL, selected.clone());
        }

        tokio::spawn(async move {
            let upgraded = match on_upgrade.await {
                Ok(upgraded) => upgraded,
                Err(err) => {
                    on_failed_upgrade.call(WebSocketError::Upgrade(err));
                    return;
                }
            };
            let socket = WebSocketStream::from_raw_socket(
                TokioIo::new(upgraded),
                protocol::Role::Server,
                Some(config),
            )
            .await;
            callback(WebSocket {
                inner: socket,
                protocol,
            })
            .await;
        });

        resp
    }
}
/// Returns `true` if any `key` header in `headers` lists `value` as one of its
/// comma-separated tokens, compared ASCII case-insensitively.
///
/// Comparison is per token, not by substring, so `Connection: noupgrade` does not
/// count as containing `upgrade`. All header lines with this name are inspected,
/// not just the first. A value that is not valid UTF-8 is skipped.
fn header_contains(headers: &HeaderMap, key: HeaderName, value: &'static str) -> bool {
    headers.get_all(&key).iter().any(|header| {
        let Ok(header) = simdutf8::basic::from_utf8(header.as_bytes()) else {
            return false;
        };
        header
            .split(',')
            .any(|token| token.trim().eq_ignore_ascii_case(value))
    })
}
/// Reports whether the first `key` header in `headers` equals `value`, ignoring
/// ASCII case. An absent header is treated as a mismatch.
fn header_eq(headers: &HeaderMap, key: HeaderName, value: &'static str) -> bool {
    matches!(
        headers.get(&key),
        Some(found) if found.as_bytes().eq_ignore_ascii_case(value.as_bytes())
    )
}
impl FromContext for WebSocketUpgrade<DefaultOnFailedUpgrade> {
type Rejection = WebSocketUpgradeRejectionError;
async fn from_context(
_: &mut ServerContext,
parts: &mut Parts,
) -> Result<Self, Self::Rejection> {
if parts.method != Method::GET {
return Err(WebSocketUpgradeRejectionError::MethodNotGet);
}
if parts.version < Version::HTTP_11 {
return Err(WebSocketUpgradeRejectionError::InvalidHttpVersion);
}
if !header_contains(&parts.headers, header::CONNECTION, "upgrade") {
return Err(WebSocketUpgradeRejectionError::InvalidConnectionHeader);
}
if !header_eq(&parts.headers, header::UPGRADE, "websocket") {
return Err(WebSocketUpgradeRejectionError::InvalidUpgradeHeader);
}
if !header_eq(&parts.headers, header::SEC_WEBSOCKET_VERSION, "13") {
return Err(WebSocketUpgradeRejectionError::InvalidWebSocketVersionHeader);
}
let sec_websocket_key = parts
.headers
.get(header::SEC_WEBSOCKET_KEY)
.ok_or(WebSocketUpgradeRejectionError::WebSocketKeyHeaderMissing)?
.clone();
let sec_websocket_protocol = parts.headers.get(header::SEC_WEBSOCKET_PROTOCOL).cloned();
let on_upgrade = parts
.extensions
.remove::<hyper::upgrade::OnUpgrade>()
.expect("`OnUpgrade` is unavailable, maybe something wrong with `hyper`");
Ok(Self {
config: Default::default(),
protocol: None,
sec_websocket_key,
sec_websocket_protocol,
on_upgrade,
on_failed_upgrade: DefaultOnFailedUpgrade,
})
}
}
/// An established WebSocket connection, handed to the `on_upgrade` callback.
///
/// It dereferences (mutably, too) to the wrapped `WebSocketStream`, so the
/// stream's methods can be called on it directly.
pub struct WebSocket {
    /// The WebSocket stream running over the upgraded hyper connection.
    inner: WebSocketStream<TokioIo<hyper::upgrade::Upgraded>>,
    /// The subprotocol selected during the handshake, if any.
    protocol: Option<HeaderValue>,
}
impl WebSocket {
    /// Returns the negotiated subprotocol as a string slice.
    ///
    /// Yields `None` when no subprotocol was selected or its bytes are not UTF-8.
    pub fn protocol(&self) -> Option<&str> {
        self.protocol
            .as_ref()
            .and_then(|value| simdutf8::basic::from_utf8(value.as_bytes()).ok())
    }
}
impl Deref for WebSocket {
type Target = WebSocketStream<TokioIo<hyper::upgrade::Upgraded>>;
fn deref(&self) -> &Self::Target {
&self.inner
}
}
impl DerefMut for WebSocket {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.inner
}
}
/// Errors reported to an [`OnFailedUpgrade`] handler.
#[derive(Debug)]
pub enum WebSocketError {
    /// Awaiting hyper's `OnUpgrade` future returned this error.
    Upgrade(hyper::Error),
}
/// Renders as `failed to upgrade: <hyper error>`.
impl fmt::Display for WebSocketError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self::Upgrade(cause) = self;
        write!(f, "failed to upgrade: {cause}")
    }
}
/// Exposes the wrapped `hyper::Error` as the error source.
impl Error for WebSocketError {
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        let Self::Upgrade(cause) = self;
        Some(cause)
    }
}
/// Handler invoked when awaiting the connection upgrade fails.
///
/// Implemented for every `FnOnce(WebSocketError)` closure and by
/// [`DefaultOnFailedUpgrade`].
pub trait OnFailedUpgrade {
    /// Consumes the handler, passing it the upgrade `error`.
    fn call(self, error: WebSocketError);
}
/// Any one-shot closure accepting a [`WebSocketError`] works as a failure handler.
impl<Callback: FnOnce(WebSocketError)> OnFailedUpgrade for Callback {
    fn call(self, error: WebSocketError) {
        self(error);
    }
}
/// The failure handler used when none is configured: the error is discarded.
#[derive(Debug)]
pub struct DefaultOnFailedUpgrade;

impl OnFailedUpgrade for DefaultOnFailedUpgrade {
    fn call(self, _error: WebSocketError) {
        // Deliberately a no-op.
    }
}
/// Reasons a request is rejected as a WebSocket opening handshake.
///
/// Each variant maps to an HTTP status code in its `IntoResponse` impl.
#[derive(Debug)]
pub enum WebSocketUpgradeRejectionError {
    /// The method is not `GET` (405).
    MethodNotGet,
    /// The HTTP version is below HTTP/1.1 (505).
    InvalidHttpVersion,
    /// The `Connection` header does not include `upgrade` (426).
    InvalidConnectionHeader,
    /// The `Upgrade` header is not `websocket` (400).
    InvalidUpgradeHeader,
    /// `Sec-WebSocket-Version` is not `13` (400).
    InvalidWebSocketVersionHeader,
    /// `Sec-WebSocket-Key` is absent (400).
    WebSocketKeyHeaderMissing,
}
impl WebSocketUpgradeRejectionError {
    /// Maps the rejection to the status code returned to the client.
    fn to_status_code(&self) -> StatusCode {
        match self {
            Self::MethodNotGet => StatusCode::METHOD_NOT_ALLOWED,
            Self::InvalidHttpVersion => StatusCode::HTTP_VERSION_NOT_SUPPORTED,
            Self::InvalidConnectionHeader => StatusCode::UPGRADE_REQUIRED,
            Self::InvalidUpgradeHeader
            | Self::InvalidWebSocketVersionHeader
            | Self::WebSocketKeyHeaderMissing => StatusCode::BAD_REQUEST,
        }
    }
}
impl Error for WebSocketUpgradeRejectionError {}

/// Human-readable explanation of why the handshake request was rejected.
impl fmt::Display for WebSocketUpgradeRejectionError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::MethodNotGet => f.write_str("Request method must be `GET`"),
            Self::InvalidHttpVersion => f.write_str("HTTP version not supported"),
            Self::InvalidConnectionHeader => {
                f.write_str("Header `Connection` does not include `upgrade`")
            }
            Self::InvalidUpgradeHeader => f.write_str("Header `Upgrade` is not `websocket`"),
            Self::InvalidWebSocketVersionHeader => {
                f.write_str("Header `Sec-WebSocket-Version` is not `13`")
            }
            Self::WebSocketKeyHeaderMissing => f.write_str("Header `Sec-WebSocket-Key` is missing"),
        }
    }
}
/// A rejection becomes a response carrying just its mapped status code.
impl IntoResponse for WebSocketUpgradeRejectionError {
    fn into_response(self) -> Response {
        let status = self.to_status_code();
        status.into_response()
    }
}
#[cfg(test)]
mod websocket_tests {
    use std::{
        convert::Infallible,
        net::{IpAddr, Ipv4Addr, SocketAddr},
        str::FromStr,
    };

    use futures_util::{sink::SinkExt, stream::StreamExt};
    use http::uri::Uri;
    use motore::service::Service;
    use tokio::net::TcpStream;
    use tokio_tungstenite::MaybeTlsStream;
    use tungstenite::ClientRequestBuilder;
    use volo::net::Address;

    use super::*;
    use crate::{Server, request::Request, server::test_helpers};

    /// Builds request parts for a well-formed handshake; each rejection case in
    /// `rejection` breaks exactly one aspect of it.
    fn simple_parts() -> Parts {
        let req = Request::builder()
            .method(Method::GET)
            .version(Version::HTTP_11)
            .header(header::HOST, "localhost")
            .header(header::CONNECTION, super::HEADERVALUE_UPGRADE)
            .header(header::UPGRADE, super::HEADERVALUE_WEBSOCKET)
            .header(header::SEC_WEBSOCKET_KEY, "6D69KGBOr4Re+Nj6zx9aQA==")
            .header(header::SEC_WEBSOCKET_VERSION, "13")
            .body(())
            .unwrap();
        req.into_parts().0
    }

    /// Serves `service` on 127.0.0.1:`port`, then connects as a WebSocket client
    /// (offering `sub_protocol` when given). Returns the client stream and the
    /// handshake response.
    ///
    /// NOTE(review): the fixed ports and the 1 s sleep used to wait for the
    /// server to bind make these tests timing-dependent and prone to port clashes.
    async fn run_ws_handler<S>(
        service: S,
        sub_protocol: Option<&'static str>,
        port: u16,
    ) -> (
        WebSocketStream<MaybeTlsStream<TcpStream>>,
        Response<Option<Vec<u8>>>,
    )
    where
        S: Service<ServerContext, Request, Response = Response, Error = Infallible>
            + Send
            + Sync
            + 'static,
    {
        let addr = Address::Ip(SocketAddr::new(
            IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
            port,
        ));
        tokio::spawn(Server::new(service).run(addr.clone()));
        // Give the spawned server time to start listening.
        tokio::time::sleep(std::time::Duration::from_secs(1)).await;
        let mut req = ClientRequestBuilder::new(Uri::from_str(&format!("ws://{addr}/")).unwrap());
        if let Some(sub_protocol) = sub_protocol {
            req = req.with_sub_protocol(sub_protocol);
        }
        tokio_tungstenite::connect_async(req).await.unwrap()
    }

    /// Each scope breaks one handshake requirement and checks the matching rejection.
    #[tokio::test]
    async fn rejection() {
        // Wrong method.
        {
            let mut parts = simple_parts();
            parts.method = Method::POST;
            let res =
                WebSocketUpgrade::from_context(&mut test_helpers::empty_cx(), &mut parts).await;
            assert!(matches!(
                res,
                Err(WebSocketUpgradeRejectionError::MethodNotGet)
            ));
        }
        // HTTP version below 1.1.
        {
            let mut parts = simple_parts();
            parts.version = Version::HTTP_10;
            let res =
                WebSocketUpgrade::from_context(&mut test_helpers::empty_cx(), &mut parts).await;
            assert!(matches!(
                res,
                Err(WebSocketUpgradeRejectionError::InvalidHttpVersion)
            ));
        }
        // Missing `Connection`.
        {
            let mut parts = simple_parts();
            parts.headers.remove(header::CONNECTION);
            let res =
                WebSocketUpgrade::from_context(&mut test_helpers::empty_cx(), &mut parts).await;
            assert!(matches!(
                res,
                Err(WebSocketUpgradeRejectionError::InvalidConnectionHeader)
            ));
        }
        // `Connection` present but without the `upgrade` token.
        {
            let mut parts = simple_parts();
            parts.headers.remove(header::CONNECTION);
            parts
                .headers
                .insert(header::CONNECTION, HeaderValue::from_static("downgrade"));
            let res =
                WebSocketUpgrade::from_context(&mut test_helpers::empty_cx(), &mut parts).await;
            assert!(matches!(
                res,
                Err(WebSocketUpgradeRejectionError::InvalidConnectionHeader)
            ));
        }
        // Missing `Upgrade`.
        {
            let mut parts = simple_parts();
            parts.headers.remove(header::UPGRADE);
            let res =
                WebSocketUpgrade::from_context(&mut test_helpers::empty_cx(), &mut parts).await;
            assert!(matches!(
                res,
                Err(WebSocketUpgradeRejectionError::InvalidUpgradeHeader)
            ));
        }
        // `Upgrade` names something other than `websocket`.
        {
            let mut parts = simple_parts();
            parts.headers.remove(header::UPGRADE);
            parts
                .headers
                .insert(header::UPGRADE, HeaderValue::from_static("supersocket"));
            let res =
                WebSocketUpgrade::from_context(&mut test_helpers::empty_cx(), &mut parts).await;
            assert!(matches!(
                res,
                Err(WebSocketUpgradeRejectionError::InvalidUpgradeHeader)
            ));
        }
        // Missing `Sec-WebSocket-Version`.
        {
            let mut parts = simple_parts();
            parts.headers.remove(header::SEC_WEBSOCKET_VERSION);
            let res =
                WebSocketUpgrade::from_context(&mut test_helpers::empty_cx(), &mut parts).await;
            assert!(matches!(
                res,
                Err(WebSocketUpgradeRejectionError::InvalidWebSocketVersionHeader)
            ));
        }
        // `Sec-WebSocket-Version` other than `13`.
        {
            let mut parts = simple_parts();
            parts.headers.remove(header::SEC_WEBSOCKET_VERSION);
            parts.headers.insert(
                header::SEC_WEBSOCKET_VERSION,
                HeaderValue::from_static("114514"),
            );
            let res =
                WebSocketUpgrade::from_context(&mut test_helpers::empty_cx(), &mut parts).await;
            assert!(matches!(
                res,
                Err(WebSocketUpgradeRejectionError::InvalidWebSocketVersionHeader)
            ));
        }
        // Missing `Sec-WebSocket-Key`.
        {
            let mut parts = simple_parts();
            parts.headers.remove(header::SEC_WEBSOCKET_KEY);
            let res =
                WebSocketUpgrade::from_context(&mut test_helpers::empty_cx(), &mut parts).await;
            assert!(matches!(
                res,
                Err(WebSocketUpgradeRejectionError::WebSocketKeyHeaderMissing)
            ));
        }
    }

    /// The server offers several subprotocols and the client asks for
    /// `graphql-ws`; the response must echo that choice.
    #[tokio::test]
    async fn protocol_test() {
        async fn handler(ws: WebSocketUpgrade) -> Response {
            ws.protocols(["soap", "wmap", "graphql-ws", "chat"])
                .on_upgrade(|_| async {})
        }
        let (_, resp) =
            run_ws_handler(test_helpers::to_service(handler), Some("graphql-ws"), 25230).await;
        assert_eq!(
            resp.headers()
                .get(http::header::SEC_WEBSOCKET_PROTOCOL)
                .unwrap(),
            "graphql-ws"
        );
    }

    /// End-to-end check: an echo handler returns the text frame it receives, and
    /// a client ping is answered with a pong carrying the same payload.
    #[tokio::test]
    async fn success_on_upgrade() {
        async fn echo(mut socket: WebSocket) {
            while let Some(Ok(msg)) = socket.next().await {
                // Control frames are not echoed back.
                if msg.is_ping() || msg.is_pong() {
                    continue;
                }
                if socket.send(msg).await.is_err() {
                    break;
                }
            }
        }
        async fn handler(ws: WebSocketUpgrade) -> Response {
            ws.on_upgrade(echo)
        }
        let (mut ws_stream, _) =
            run_ws_handler(test_helpers::to_service(handler), None, 25231).await;
        let input = Message::Text("foobar".into());
        ws_stream.send(input.clone()).await.unwrap();
        let output = ws_stream.next().await.unwrap().unwrap();
        assert_eq!(input, output);
        let input = Message::Ping("foobar".into());
        ws_stream.send(input).await.unwrap();
        let output = ws_stream.next().await.unwrap().unwrap();
        assert_eq!(output, Message::Pong("foobar".into()));
    }
}