use self::rejection::*;
use super::{rejection::*, FromRequest, RequestParts};
use crate::{
body::{self, Bytes},
response::Response,
Error,
};
use async_trait::async_trait;
use futures_util::{
sink::{Sink, SinkExt},
stream::{Stream, StreamExt},
};
use http::{
header::{self, HeaderName, HeaderValue},
Method, StatusCode,
};
use hyper::upgrade::{OnUpgrade, Upgraded};
use sha1::{Digest, Sha1};
use std::{
borrow::Cow,
future::Future,
pin::Pin,
task::{Context, Poll},
};
use tokio_tungstenite::{
tungstenite::{
self as ts,
protocol::{self, WebSocketConfig},
},
WebSocketStream,
};
/// Extractor that validates a WebSocket handshake request and can complete the
/// upgrade via [`WebSocketUpgrade::on_upgrade`].
///
/// Connection settings can be adjusted with the builder methods before calling
/// `on_upgrade`.
#[derive(Debug)]
#[cfg_attr(docsrs, doc(cfg(feature = "ws")))]
pub struct WebSocketUpgrade {
// tungstenite configuration passed to `WebSocketStream::from_raw_socket`;
// starts as `Default::default()` and is modified by the `max_*` builders.
config: WebSocketConfig,
// Subprotocol selected by `protocols`, echoed in `Sec-WebSocket-Protocol`.
protocol: Option<HeaderValue>,
// The client's `Sec-WebSocket-Key`, used to compute `Sec-WebSocket-Accept`.
sec_websocket_key: HeaderValue,
// hyper upgrade handle taken from the request extensions.
on_upgrade: OnUpgrade,
// Raw `Sec-WebSocket-Protocol` header sent by the client, if any.
sec_websocket_protocol: Option<HeaderValue>,
}
impl WebSocketUpgrade {
    /// Set `max_send_queue` on the [`WebSocketConfig`] used for the upgraded
    /// connection.
    ///
    /// See tungstenite's `WebSocketConfig` documentation for the precise semantics.
    pub fn max_send_queue(mut self, max: usize) -> Self {
        self.config.max_send_queue = Some(max);
        self
    }

    /// Set `max_message_size` on the [`WebSocketConfig`] used for the upgraded
    /// connection.
    ///
    /// See tungstenite's `WebSocketConfig` documentation for the precise semantics.
    pub fn max_message_size(mut self, max: usize) -> Self {
        self.config.max_message_size = Some(max);
        self
    }

    /// Set `max_frame_size` on the [`WebSocketConfig`] used for the upgraded
    /// connection.
    ///
    /// See tungstenite's `WebSocketConfig` documentation for the precise semantics.
    pub fn max_frame_size(mut self, max: usize) -> Self {
        self.config.max_frame_size = Some(max);
        self
    }

    /// Choose the subprotocol to echo back in the `Sec-WebSocket-Protocol`
    /// response header.
    ///
    /// The first entry of `protocols`, in iteration order, that also appears in the
    /// client's comma-separated `Sec-WebSocket-Protocol` request header is selected.
    /// Each client entry is trimmed of surrounding whitespace and compared exactly.
    ///
    /// If the client sent the header but nothing matches, any previously selected
    /// protocol is cleared. If the request had no such header, or the header was
    /// not visible ASCII, the current selection is left unchanged.
    pub fn protocols<I>(mut self, protocols: I) -> Self
    where
        I: IntoIterator,
        I::Item: Into<Cow<'static, str>>,
    {
        if let Some(req_protocols) = self
            .sec_websocket_protocol
            .as_ref()
            .and_then(|p| p.to_str().ok())
        {
            self.protocol = protocols
                .into_iter()
                .map(Into::into)
                .find(|protocol| {
                    req_protocols
                        .split(',')
                        .any(|req_protocol| req_protocol.trim() == protocol)
                })
                // The selected protocol equals a trimmed segment of a header value
                // that passed `to_str()`. It therefore contains only visible ASCII,
                // which is always a valid `HeaderValue`, so neither conversion panics.
                .map(|protocol| match protocol {
                    Cow::Owned(s) => HeaderValue::from_str(&s).unwrap(),
                    Cow::Borrowed(s) => HeaderValue::from_static(s),
                });
        }
        self
    }

    /// Finish the handshake.
    ///
    /// Returns a `101 Switching Protocols` response with these headers:
    /// `Connection: upgrade`, `Upgrade: websocket`, the computed
    /// `Sec-WebSocket-Accept`, and `Sec-WebSocket-Protocol` if a protocol was
    /// selected via [`protocols`](Self::protocols).
    ///
    /// A task is spawned with `tokio::spawn`. It waits for the connection upgrade
    /// and then calls `callback` with the resulting [`WebSocket`]. If the upgrade
    /// fails, the task ends without calling `callback`.
    ///
    /// # Panics
    ///
    /// `tokio::spawn` panics when called outside of a Tokio runtime.
    pub fn on_upgrade<F, Fut>(self, callback: F) -> Response
    where
        F: FnOnce(WebSocket) -> Fut + Send + 'static,
        Fut: Future + Send + 'static,
    {
        let on_upgrade = self.on_upgrade;
        let config = self.config;

        tokio::spawn(async move {
            // A failed upgrade is not a server bug. hyper reports an error here
            // when the upgrade cannot be completed, presumably for example when
            // the connection is closed first. Panicking in this detached task
            // would only print a panic message, or abort the whole process when
            // built with `panic = "abort"`, so the task simply ends without
            // running `callback`.
            let upgraded = match on_upgrade.await {
                Ok(upgraded) => upgraded,
                Err(_) => return,
            };
            let socket =
                WebSocketStream::from_raw_socket(upgraded, protocol::Role::Server, Some(config))
                    .await;
            let socket = WebSocket { inner: socket };
            callback(socket).await;
        });

        // `HeaderValue` has interior-mutability in its type signature, which trips
        // this lint for consts; the values here are never mutated.
        #[allow(clippy::declare_interior_mutable_const)]
        const UPGRADE: HeaderValue = HeaderValue::from_static("upgrade");
        #[allow(clippy::declare_interior_mutable_const)]
        const WEBSOCKET: HeaderValue = HeaderValue::from_static("websocket");

        let mut builder = Response::builder()
            .status(StatusCode::SWITCHING_PROTOCOLS)
            .header(header::CONNECTION, UPGRADE)
            .header(header::UPGRADE, WEBSOCKET)
            .header(
                header::SEC_WEBSOCKET_ACCEPT,
                sign(self.sec_websocket_key.as_bytes()),
            );

        if let Some(protocol) = self.protocol {
            builder = builder.header(header::SEC_WEBSOCKET_PROTOCOL, protocol);
        }

        // Every header name and value above is already a typed, valid value,
        // so building the response cannot fail.
        builder.body(body::boxed(body::Empty::new())).unwrap()
    }
}
#[async_trait]
impl<B> FromRequest<B> for WebSocketUpgrade
where
    B: Send,
{
    type Rejection = WebSocketUpgradeRejection;

    /// Validate the WebSocket handshake and capture what is needed to finish it.
    ///
    /// The following checks run in order, and the first failure is returned:
    /// - the method is `GET`;
    /// - the `Connection` header contains `upgrade` (case-insensitive substring);
    /// - `Upgrade` equals `websocket` (case-insensitive);
    /// - `Sec-WebSocket-Version` equals `13`;
    /// - `Sec-WebSocket-Key` is present.
    ///
    /// The `Sec-WebSocket-Key` header is removed from the request and the
    /// [`OnUpgrade`] extension is taken out of its extensions.
    ///
    /// # Panics
    ///
    /// Panics if the request carries no [`OnUpgrade`] extension. Presumably this
    /// means the request did not come through a hyper server connection that
    /// supports upgrades (TODO confirm exact conditions).
    async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
        if req.method() != Method::GET {
            return Err(MethodNotGet.into());
        }

        if !header_contains(req, header::CONNECTION, "upgrade")? {
            return Err(InvalidConnectionHeader.into());
        }

        if !header_eq(req, header::UPGRADE, "websocket")? {
            return Err(InvalidUpgradeHeader.into());
        }

        if !header_eq(req, header::SEC_WEBSOCKET_VERSION, "13")? {
            return Err(InvalidWebSocketVersionHeader.into());
        }

        let sec_websocket_key = req
            .headers_mut()
            .ok_or_else(HeadersAlreadyExtracted::default)?
            .remove(header::SEC_WEBSOCKET_KEY)
            .ok_or(WebSocketKeyHeaderMissing)?;

        let on_upgrade = req
            .extensions_mut()
            .ok_or_else(ExtensionsAlreadyExtracted::default)?
            .remove::<OnUpgrade>()
            .expect(
                "`OnUpgrade` request extension missing; was the request served by a \
                 hyper connection with upgrades enabled?",
            );

        let sec_websocket_protocol = req
            .headers()
            .ok_or_else(HeadersAlreadyExtracted::default)?
            .get(header::SEC_WEBSOCKET_PROTOCOL)
            .cloned();

        Ok(Self {
            config: Default::default(),
            protocol: None,
            sec_websocket_key,
            on_upgrade,
            sec_websocket_protocol,
        })
    }
}
/// Return whether header `key` is present and equals `value`, ignoring ASCII case.
///
/// A missing header yields `Ok(false)`. The error is returned only when the
/// headers have already been taken out of `req`.
fn header_eq<B>(
    req: &RequestParts<B>,
    key: HeaderName,
    value: &'static str,
) -> Result<bool, HeadersAlreadyExtracted> {
    let headers = req.headers().ok_or_else(HeadersAlreadyExtracted::default)?;
    let is_equal = headers.get(&key).map_or(false, |found| {
        found.as_bytes().eq_ignore_ascii_case(value.as_bytes())
    });
    Ok(is_equal)
}
/// Return whether header `key` is present, is valid UTF-8, and, once
/// ASCII-lowercased, contains `value` as a substring.
///
/// A missing or non-UTF-8 header yields `Ok(false)`. The error is returned only
/// when the headers have already been taken out of `req`.
fn header_contains<B>(
    req: &RequestParts<B>,
    key: HeaderName,
    value: &'static str,
) -> Result<bool, HeadersAlreadyExtracted> {
    let headers = req.headers().ok_or_else(HeadersAlreadyExtracted::default)?;
    let found = headers
        .get(&key)
        .and_then(|raw| std::str::from_utf8(raw.as_bytes()).ok())
        .map_or(false, |text| text.to_ascii_lowercase().contains(value));
    Ok(found)
}
/// An upgraded WebSocket connection, handed to the callback passed to
/// [`WebSocketUpgrade::on_upgrade`].
///
/// Messages can be received through [`Stream`] or `recv`, and sent through
/// [`Sink`] or `send`.
#[derive(Debug)]
pub struct WebSocket {
// Underlying tokio-tungstenite stream over the hyper-upgraded connection.
inner: WebSocketStream<Upgraded>,
}
impl WebSocket {
pub async fn recv(&mut self) -> Option<Result<Message, Error>> {
self.next().await
}
pub async fn send(&mut self, msg: Message) -> Result<(), Error> {
self.inner
.send(msg.into_tungstenite())
.await
.map_err(Error::new)
}
pub async fn close(mut self) -> Result<(), Error> {
self.inner.close(None).await.map_err(Error::new)
}
}
impl Stream for WebSocket {
type Item = Result<Message, Error>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.inner.poll_next_unpin(cx).map(|option_msg| {
option_msg.map(|result_msg| {
result_msg
.map_err(Error::new)
.map(Message::from_tungstenite)
})
})
}
}
impl Sink<Message> for WebSocket {
type Error = Error;
fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Pin::new(&mut self.inner).poll_ready(cx).map_err(Error::new)
}
fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
Pin::new(&mut self.inner)
.start_send(item.into_tungstenite())
.map_err(Error::new)
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Pin::new(&mut self.inner).poll_flush(cx).map_err(Error::new)
}
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Pin::new(&mut self.inner).poll_close(cx).map_err(Error::new)
}
}
/// Numeric code carried by a [`CloseFrame`]; converted to and from
/// tungstenite's `CloseCode` when messages are sent or received.
pub type CloseCode = u16;
/// Payload of a close message: a numeric code plus a textual reason.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct CloseFrame<'t> {
/// The close code.
pub code: CloseCode,
/// The close reason, borrowed or owned.
pub reason: Cow<'t, str>,
}
/// A WebSocket message, mirroring tungstenite's message kinds one-to-one.
#[derive(Debug, Eq, PartialEq, Clone)]
pub enum Message {
/// A text message.
Text(String),
/// A binary message.
Binary(Vec<u8>),
/// A ping with its payload.
Ping(Vec<u8>),
/// A pong with its payload.
Pong(Vec<u8>),
/// A close message, optionally carrying a code and reason.
Close(Option<CloseFrame<'static>>),
}
impl Message {
fn into_tungstenite(self) -> ts::Message {
match self {
Self::Text(text) => ts::Message::Text(text),
Self::Binary(binary) => ts::Message::Binary(binary),
Self::Ping(ping) => ts::Message::Ping(ping),
Self::Pong(pong) => ts::Message::Pong(pong),
Self::Close(Some(close)) => ts::Message::Close(Some(ts::protocol::CloseFrame {
code: ts::protocol::frame::coding::CloseCode::from(close.code),
reason: close.reason,
})),
Self::Close(None) => ts::Message::Close(None),
}
}
fn from_tungstenite(message: ts::Message) -> Self {
match message {
ts::Message::Text(text) => Self::Text(text),
ts::Message::Binary(binary) => Self::Binary(binary),
ts::Message::Ping(ping) => Self::Ping(ping),
ts::Message::Pong(pong) => Self::Pong(pong),
ts::Message::Close(Some(close)) => Self::Close(Some(CloseFrame {
code: close.code.into(),
reason: close.reason,
})),
ts::Message::Close(None) => Self::Close(None),
}
}
pub fn into_data(self) -> Vec<u8> {
match self {
Self::Text(string) => string.into_bytes(),
Self::Binary(data) | Self::Ping(data) | Self::Pong(data) => data,
Self::Close(None) => Vec::new(),
Self::Close(Some(frame)) => frame.reason.into_owned().into_bytes(),
}
}
pub fn into_text(self) -> Result<String, Error> {
match self {
Self::Text(string) => Ok(string),
Self::Binary(data) | Self::Ping(data) | Self::Pong(data) => Ok(String::from_utf8(data)
.map_err(|err| err.utf8_error())
.map_err(Error::new)?),
Self::Close(None) => Ok(String::new()),
Self::Close(Some(frame)) => Ok(frame.reason.into_owned()),
}
}
pub fn to_text(&self) -> Result<&str, Error> {
match *self {
Self::Text(ref string) => Ok(string),
Self::Binary(ref data) | Self::Ping(ref data) | Self::Pong(ref data) => {
Ok(std::str::from_utf8(data).map_err(Error::new)?)
}
Self::Close(None) => Ok(""),
Self::Close(Some(ref frame)) => Ok(&frame.reason),
}
}
}
impl From<Message> for Vec<u8> {
fn from(msg: Message) -> Self {
msg.into_data()
}
}
/// Compute the `Sec-WebSocket-Accept` value for a client's `Sec-WebSocket-Key`.
///
/// The key is followed by the fixed WebSocket GUID, the result is hashed with
/// SHA-1, and the digest is base64-encoded.
fn sign(key: &[u8]) -> HeaderValue {
    // Fixed GUID that RFC 6455 appends to the client key before hashing.
    const WEBSOCKET_GUID: &[u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

    let mut hasher = Sha1::default();
    hasher.update(key);
    hasher.update(WEBSOCKET_GUID);
    let digest = hasher.finalize();

    let encoded = Bytes::from(base64::encode(&digest));
    HeaderValue::from_maybe_shared(encoded).expect("base64 is a valid value")
}
/// Rejection types returned when extracting a `WebSocketUpgrade` fails.
pub mod rejection {
use crate::extract::rejection::*;
// Returned when the request method is not `GET`.
define_rejection! {
#[status = METHOD_NOT_ALLOWED]
#[body = "Request method must be `GET`"]
pub struct MethodNotGet;
}
// Returned when the `Connection` header is missing, not UTF-8, or does not
// contain `upgrade` (case-insensitive).
define_rejection! {
#[status = BAD_REQUEST]
#[body = "Connection header did not include 'upgrade'"]
pub struct InvalidConnectionHeader;
}
// Returned when the `Upgrade` header is missing or is not equal to `websocket`
// (case-insensitive).
define_rejection! {
#[status = BAD_REQUEST]
#[body = "`Upgrade` header did not include 'websocket'"]
pub struct InvalidUpgradeHeader;
}
// Returned when the `Sec-WebSocket-Version` header is missing or is not `13`.
define_rejection! {
#[status = BAD_REQUEST]
#[body = "`Sec-WebSocket-Version` header did not include '13'"]
pub struct InvalidWebSocketVersionHeader;
}
// Returned when the `Sec-WebSocket-Key` header is absent.
define_rejection! {
#[status = BAD_REQUEST]
#[body = "`Sec-WebSocket-Key` header missing"]
pub struct WebSocketKeyHeaderMissing;
}
// Every way extracting a `WebSocketUpgrade` can fail, including headers or
// extensions having already been taken by another extractor.
composite_rejection! {
pub enum WebSocketUpgradeRejection {
MethodNotGet,
InvalidConnectionHeader,
InvalidUpgradeHeader,
InvalidWebSocketVersionHeader,
WebSocketKeyHeaderMissing,
HeadersAlreadyExtracted,
ExtensionsAlreadyExtracted,
}
}
}