use super::{WebSocket, WebSocketError};
use base64::{Engine as _, engine::general_purpose::STANDARD};
use futures_util::future::{Ready, ready};
use hyper_util::rt::TokioIo;
use sha1::{Digest, Sha1};
use std::future::Future;
use hyper::{
http::{Method, Version, request::Parts},
upgrade::OnUpgrade,
};
use crate::{
HttpResult,
error::{
Error,
handler::{ErrorArgsSlot, extract_error_args},
},
headers::{
CONNECTION, HeaderValue, SEC_WEBSOCKET_ACCEPT, SEC_WEBSOCKET_KEY, SEC_WEBSOCKET_PROTOCOL,
SEC_WEBSOCKET_VERSION, UPGRADE,
},
http::{
endpoints::args::{FromPayload, Payload, Source},
request_scope::HttpRequestScope,
},
ok, status,
};
use tokio_tungstenite::{
WebSocketStream,
tungstenite::protocol::{Role, WebSocketConfig},
};
/// Extractor for an incoming WebSocket handshake request.
///
/// Built from request parts via `TryFrom<&Parts>` / `FromPayload`, tuned with the
/// `with_*` builder methods, and finished with `on`, which returns the handshake
/// response and runs the user callback on the upgraded connection.
pub struct WebSocketConnection {
    // Stream configuration handed to `WebSocketStream::from_raw_socket` after the upgrade.
    config: WebSocketConfig,
    // Error-handler arguments captured from the request; invoked if awaiting the upgrade fails.
    error_args: ErrorArgsSlot,
    // Upgrade handle cloned from the request extensions; awaited to obtain the raw IO.
    on_upgrade: OnUpgrade,
    // Subprotocol selected by `with_protocols`; echoed as `Sec-WebSocket-Protocol` and
    // passed to the resulting `WebSocket`.
    protocol: Option<HeaderValue>,
    // Client `Sec-WebSocket-Key`: `Some` for HTTP/1.x handshakes, `None` for extended CONNECT.
    sec_websocket_key: Option<HeaderValue>,
    // Raw `Sec-WebSocket-Protocol` request header listing the client's offered subprotocols.
    sec_websocket_protocol: Option<HeaderValue>,
}
impl std::fmt::Debug for WebSocketConnection {
    /// Renders an opaque placeholder; none of the connection's fields are printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "WebSocketConnection(..)")
    }
}
impl WebSocketConnection {
    /// Sets `read_buffer_size` on the WebSocket stream configuration.
    pub fn with_read_buffer_size(mut self, size: usize) -> Self {
        self.config.read_buffer_size = size;
        self
    }

    /// Sets `write_buffer_size` on the WebSocket stream configuration.
    pub fn with_write_buffer_size(mut self, size: usize) -> Self {
        self.config.write_buffer_size = size;
        self
    }

    /// Sets `max_write_buffer_size` on the WebSocket stream configuration.
    pub fn with_max_write_buffer_size(mut self, max: usize) -> Self {
        self.config.max_write_buffer_size = max;
        self
    }

    /// Sets `max_message_size` (to `Some(max)`) on the WebSocket stream configuration.
    pub fn with_max_message_size(mut self, max: usize) -> Self {
        self.config.max_message_size = Some(max);
        self
    }

    /// Sets `max_frame_size` (to `Some(max)`) on the WebSocket stream configuration.
    pub fn with_max_frame_size(mut self, max: usize) -> Self {
        self.config.max_frame_size = Some(max);
        self
    }

    /// Sets `accept_unmasked_frames` on the WebSocket stream configuration.
    pub fn with_accept_unmasked_frames(mut self, accept: bool) -> Self {
        self.config.accept_unmasked_frames = accept;
        self
    }

    /// Selects the subprotocol to answer the client with.
    ///
    /// The first entry of `known` (so `known` expresses the server's preference order)
    /// that the client listed in its `Sec-WebSocket-Protocol` header becomes the selected
    /// protocol; if none match, no protocol is selected. When the client sent no such
    /// header, or it cannot be read as a string, the current selection is left unchanged.
    pub fn with_protocols<const N: usize>(mut self, known: [&'static str; N]) -> Self {
        let requested = self
            .sec_websocket_protocol
            .as_ref()
            .and_then(|header| header.to_str().ok());
        if let Some(requested) = requested {
            self.protocol =
                Self::negotiate_protocol(requested, &known).map(HeaderValue::from_static);
        }
        self
    }

    /// Returns the first entry of `known` that appears in the comma-separated `requested`
    /// list. List entries are whitespace-trimmed and compared exactly (case-sensitive).
    fn negotiate_protocol(requested: &str, known: &[&'static str]) -> Option<&'static str> {
        known.iter().copied().find(|&candidate| {
            // Re-split for every candidate: sharing one iterator across candidates would
            // let an unsuccessful `any` exhaust it and hide matches for later candidates.
            requested
                .split(',')
                .map(str::trim)
                .any(|offered| offered == candidate)
        })
    }

    /// Completes the handshake and runs `func` on the upgraded connection.
    ///
    /// Returns the handshake response: `101` with `Upgrade`, `Connection` and
    /// `Sec-WebSocket-Accept` headers when the request carried a `Sec-WebSocket-Key`
    /// (HTTP/1.x), otherwise the `ok!()` response (extended CONNECT). A protocol selected
    /// via `with_protocols` is added as `Sec-WebSocket-Protocol`.
    ///
    /// `func` runs on a spawned tokio task once the upgrade resolves. If awaiting the
    /// upgrade fails, the request's error handler is called with a server error and
    /// `func` is never invoked.
    ///
    /// # Errors
    /// Returns the error from building the handshake response; in that case no task is
    /// spawned.
    pub fn on<F, Fut>(self, func: F) -> HttpResult
    where
        F: FnOnce(WebSocket) -> Fut + Send + 'static,
        Fut: Future<Output = ()> + Send + 'static,
    {
        let WebSocketConnection {
            config,
            error_args,
            protocol,
            on_upgrade,
            sec_websocket_key,
            sec_websocket_protocol: _,
        } = self;

        // Build the response first so a failure here doesn't leave behind a spawned task
        // waiting on an upgrade that can never happen.
        let response: HttpResult = match &sec_websocket_key {
            Some(key) => {
                let accept_key = Self::generate_websocket_accept_key(key.as_bytes());
                status!(101; [
                    (UPGRADE, super::WEBSOCKET),
                    (CONNECTION, super::UPGRADE),
                    (SEC_WEBSOCKET_ACCEPT, accept_key)
                ])
            }
            None => ok!(),
        };
        let mut response = response?;
        if let Some(protocol) = &protocol {
            response
                .headers_mut()
                .insert(SEC_WEBSOCKET_PROTOCOL, protocol.clone());
        }

        tokio::spawn(async move {
            let upgraded = match on_upgrade.await {
                Ok(upgraded) => TokioIo::new(upgraded),
                Err(err) => {
                    _ = error_args.call(Error::server_error(err)).await;
                    return;
                }
            };
            let stream =
                WebSocketStream::from_raw_socket(upgraded, Role::Server, Some(config)).await;
            let socket = WebSocket::new(stream, protocol);
            func(socket).await;
        });

        Ok(response)
    }

    /// Computes the `Sec-WebSocket-Accept` value: base64 of SHA-1 over the client key
    /// followed by `super::WEBSOCKET_GUID`.
    #[inline]
    fn generate_websocket_accept_key(key: &[u8]) -> String {
        let mut hasher = Sha1::new();
        hasher.update(key);
        hasher.update(super::WEBSOCKET_GUID.as_bytes());
        STANDARD.encode(hasher.finalize())
    }
}
/// Reports whether the comma-separated header `value` lists `token`.
///
/// Every list element is stripped of surrounding ASCII whitespace and compared with
/// `token` ignoring ASCII case, so `"keep-alive, Upgrade"` lists `"upgrade"`.
#[inline]
fn header_contains_token(value: &HeaderValue, token: &str) -> bool {
    let needle = token.as_bytes();
    let mut remaining = value.as_bytes();
    while !remaining.is_empty() {
        // Split off the next element; `rest` is everything after its terminating comma.
        let (element, rest) = match remaining.iter().position(|&byte| byte == b',') {
            Some(comma) => (&remaining[..comma], &remaining[comma + 1..]),
            None => (remaining, &remaining[remaining.len()..]),
        };
        let begin = element
            .iter()
            .position(|byte| !byte.is_ascii_whitespace())
            .unwrap_or(element.len());
        let finish = element
            .iter()
            .rposition(|byte| !byte.is_ascii_whitespace())
            .map_or(begin, |last| last + 1);
        if element[begin..finish].eq_ignore_ascii_case(needle) {
            return true;
        }
        remaining = rest;
    }
    false
}
impl TryFrom<&Parts> for WebSocketConnection {
type Error = Error;
fn try_from(parts: &Parts) -> Result<Self, Self::Error> {
let sec_websocket_key = if parts.version <= Version::HTTP_11 {
if parts.method != Method::GET {
return Err(WebSocketError::invalid_method());
}
if !matches!(parts.headers.get(&UPGRADE), Some(upgrade) if header_contains_token(upgrade, super::WEBSOCKET))
{
return Err(WebSocketError::invalid_upgrade_header());
}
if !matches!(parts.headers.get(&CONNECTION), Some(conn) if header_contains_token(conn, super::UPGRADE))
{
return Err(WebSocketError::invalid_connection_header());
}
if !matches!(parts.headers.get(&SEC_WEBSOCKET_VERSION), Some(version) if version == super::VERSION)
{
return Err(WebSocketError::invalid_version_header());
}
let key = parts
.headers
.get(&SEC_WEBSOCKET_KEY)
.ok_or(WebSocketError::websocket_key_missing())?
.clone();
Some(key)
} else {
if parts.method != Method::CONNECT {
return Err(WebSocketError::invalid_method());
}
let protocol = parts
.extensions
.get::<hyper::ext::Protocol>()
.ok_or(WebSocketError::invalid_connect_protocol())?;
if !protocol.as_str().eq_ignore_ascii_case(super::WEBSOCKET) {
return Err(WebSocketError::invalid_connect_protocol());
}
None
};
let on_upgrade = parts
.extensions
.get::<OnUpgrade>()
.ok_or(WebSocketError::not_upgradable_connection())?
.clone();
let error_handler = parts
.extensions
.get::<HttpRequestScope>()
.map(|s| &s.error_handler)
.ok_or(Error::server_error(
"Server error: error handler is missing",
))?;
let error_args = extract_error_args(error_handler, parts);
let sec_websocket_protocol = parts.headers.get(&SEC_WEBSOCKET_PROTOCOL).cloned();
Ok(Self {
config: Default::default(),
error_args,
protocol: None,
on_upgrade,
sec_websocket_key,
sec_websocket_protocol,
})
}
}
impl FromPayload for WebSocketConnection {
    type Future = Ready<Result<Self, Error>>;
    const SOURCE: Source = Source::Parts;

    /// Runs the handshake validation from `TryFrom<&Parts>` and yields its result as an
    /// already-completed future.
    ///
    /// `SOURCE` declares `Source::Parts`; any other payload variant panics via
    /// `unreachable!`.
    #[inline]
    fn from_payload(payload: Payload<'_>) -> Self::Future {
        match payload {
            Payload::Parts(parts) => ready(Self::try_from(parts)),
            _ => unreachable!(),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::WebSocketConnection;
    use crate::error::ErrorFunc;
    use crate::error::handler::PipelineErrorHandler;
    use crate::headers::SEC_WEBSOCKET_PROTOCOL;
    use crate::http::{
        endpoints::args::{FromPayload, Payload},
        request_scope::HttpRequestScope,
    };
    use hyper::http::HeaderValue;
    use hyper::{Request, Version};

    /// Headers of a complete HTTP/1.1 WebSocket handshake request.
    const HANDSHAKE_HEADERS: [(&str, &str); 4] = [
        ("upgrade", "websocket"),
        ("connection", "Upgrade"),
        ("sec-websocket-version", "13"),
        ("sec-websocket-key", "123abc"),
    ];

    /// Builds an HTTP/1.1 `GET /ws` request with every handshake header except `omit`.
    fn http11_request(omit: Option<&str>) -> Request<()> {
        let mut builder = Request::get("/ws").version(Version::HTTP_11);
        for &(name, value) in &HANDSHAKE_HEADERS {
            if Some(name) != omit {
                builder = builder.header(name, value);
            }
        }
        builder.body(()).unwrap()
    }

    /// Inserts the `OnUpgrade` handle obtained from `hyper::upgrade::on`.
    fn attach_upgrade(req: &mut Request<()>) {
        let on_upgrade = hyper::upgrade::on(&mut *req);
        req.extensions_mut().insert(on_upgrade);
    }

    /// Inserts a request scope whose error handler does nothing.
    fn attach_scope(req: &mut Request<()>) {
        let error_handler = PipelineErrorHandler::from(ErrorFunc::new(|_| async move {}));
        req.extensions_mut().insert(HttpRequestScope {
            error_handler,
            ..HttpRequestScope::default()
        });
    }

    /// Runs the `FromPayload` extractor on `req` and reports whether it failed.
    async fn extraction_fails(req: Request<()>) -> bool {
        let (parts, _) = req.into_parts();
        WebSocketConnection::from_payload(Payload::Parts(&parts))
            .await
            .is_err()
    }

    #[tokio::test]
    async fn it_creates_ws_connection_from_payload() {
        let mut req = http11_request(None);
        attach_upgrade(&mut req);
        attach_scope(&mut req);
        let (parts, _) = req.into_parts();
        let conn = WebSocketConnection::from_payload(Payload::Parts(&parts))
            .await
            .unwrap();
        assert_eq!(conn.protocol, None);
        assert_eq!(
            conn.sec_websocket_key,
            parts.headers.get("Sec-WebSocket-Key").cloned()
        );
    }

    #[tokio::test]
    async fn it_tries_to_create_not_upgradable_ws_connection_from_payload() {
        let mut req = http11_request(None);
        attach_scope(&mut req);
        assert!(extraction_fails(req).await);
    }

    #[tokio::test]
    async fn it_tries_to_create_ws_connection_with_missing_err_handler_from_payload() {
        let mut req = http11_request(None);
        attach_upgrade(&mut req);
        assert!(extraction_fails(req).await);
    }

    #[tokio::test]
    async fn it_tries_to_create_ws_connection_without_upgrade_header_from_payload() {
        let mut req = http11_request(Some("upgrade"));
        attach_upgrade(&mut req);
        attach_scope(&mut req);
        assert!(extraction_fails(req).await);
    }

    #[tokio::test]
    async fn it_tries_to_create_ws_connection_without_connection_header_from_payload() {
        let mut req = http11_request(Some("connection"));
        attach_upgrade(&mut req);
        attach_scope(&mut req);
        assert!(extraction_fails(req).await);
    }

    #[tokio::test]
    async fn it_tries_to_create_ws_connection_without_sec_websocket_version_header_from_payload() {
        let mut req = http11_request(Some("sec-websocket-version"));
        attach_upgrade(&mut req);
        attach_scope(&mut req);
        assert!(extraction_fails(req).await);
    }

    #[tokio::test]
    async fn it_tries_to_create_ws_connection_without_sec_websocket_key_header_from_payload() {
        let mut req = http11_request(Some("sec-websocket-key"));
        attach_upgrade(&mut req);
        attach_scope(&mut req);
        assert!(extraction_fails(req).await);
    }

    #[tokio::test]
    async fn it_creates_ws_connection_from_extended_connect_parts() {
        let mut req = Request::connect("/ws")
            .version(Version::HTTP_2)
            .header(SEC_WEBSOCKET_PROTOCOL, "foo-ws")
            .body(())
            .unwrap();
        attach_upgrade(&mut req);
        req.extensions_mut()
            .insert(hyper::ext::Protocol::from_static("websocket"));
        attach_scope(&mut req);
        let (parts, _) = req.into_parts();
        let conn = WebSocketConnection::try_from(&parts)
            .unwrap()
            .with_max_frame_size(1024)
            .with_accept_unmasked_frames(true)
            .with_protocols(["foo-ws"])
            .with_max_message_size(1024)
            .with_read_buffer_size(1024)
            .with_max_write_buffer_size(1024)
            .with_write_buffer_size(1024);
        assert_eq!(conn.protocol, Some(HeaderValue::from_static("foo-ws")));
        assert_eq!(conn.sec_websocket_key, None);
        assert!(conn.config.accept_unmasked_frames);
        assert_eq!(conn.config.max_message_size, Some(1024usize));
        assert_eq!(conn.config.max_write_buffer_size, 1024);
        assert_eq!(conn.config.read_buffer_size, 1024);
        assert_eq!(conn.config.write_buffer_size, 1024);
        assert_eq!(conn.config.max_frame_size, Some(1024usize));
    }

    #[test]
    fn it_generates_websocket_accept_key() {
        let key = WebSocketConnection::generate_websocket_accept_key(b"123");
        assert_eq!(key, "V5hz1RKy1V4JclILDswC1e3Fek0=");
    }
}