use std::{
pin::Pin,
task::{Context, Poll},
};
use async_tungstenite::{self, tokio::TokioAdapter, tungstenite::protocol, WebSocketStream};
use base64::{engine::general_purpose::STANDARD, Engine};
use futures::{Future, Sink, Stream};
use http::{
header::{self, HeaderName, SET_COOKIE},
HeaderValue, Method, Response, StatusCode,
};
use hyper::upgrade::{OnUpgrade, Upgraded};
use sha1::{Digest, Sha1};
// Clippy flags `HeaderValue` consts as interior-mutable, hence the allows.
// Every use site gets its own fresh copy, which is what the response builder
// wants, so the lint does not point at a real problem here.

/// `Connection` header value sent with `101 Switching Protocols`.
#[allow(clippy::declare_interior_mutable_const)]
const UPGRADE: HeaderValue = HeaderValue::from_static("upgrade");

/// `Upgrade` header value sent with `101 Switching Protocols`.
#[allow(clippy::declare_interior_mutable_const)]
const WEBSOCKET: HeaderValue = HeaderValue::from_static("websocket");
use crate::{Error, HttpResponse, Request};
use super::{Message, Websocket};
/// Namespace for the constructors that turn a [`Request`] into a pending
/// WebSocket upgrade.
///
/// The enum has no variants and cannot be instantiated; it only carries the
/// associated functions below.
pub enum WebsocketUpgrade {}

impl WebsocketUpgrade {
    /// Wraps `req` and `handler` into a [`WebSocketUpgradeResponse`].
    ///
    /// Nothing is validated here. The handshake is checked when the returned
    /// value is turned into an HTTP response. If the handshake is accepted,
    /// `handler` is later called from a spawned task with the request and the
    /// upgraded socket.
    pub fn from_req<THandler, TFut>(
        req: Request,
        handler: THandler,
    ) -> WebSocketUpgradeResponse<THandler, TFut>
    where
        THandler: FnOnce(Request, Box<dyn Websocket + Send>) -> TFut + Send + Sync + 'static,
        TFut: Future<Output = ()> + Send + 'static,
    {
        WebSocketUpgradeResponse {
            req,
            handler,
            #[cfg(feature = "cookies")]
            cookies: None,
        }
    }

    /// Same as [`WebsocketUpgrade::from_req`], but also attaches `cookies`.
    ///
    /// The jar's delta is written as `Set-Cookie` headers on the produced
    /// response.
    #[cfg(feature = "cookies")]
    pub fn from_req_with_cookies<THandler, TFut>(
        req: Request,
        cookies: cookie::CookieJar,
        handler: THandler,
    ) -> WebSocketUpgradeResponse<THandler, TFut>
    where
        THandler: FnOnce(Request, Box<dyn Websocket + Send>) -> TFut + Send + Sync + 'static,
        TFut: Future<Output = ()> + Send + 'static,
    {
        let mut upgrade = Self::from_req(req, handler);
        upgrade.cookies = Some(cookies);
        upgrade
    }
}
/// A WebSocket upgrade that has not been answered yet.
///
/// It holds the incoming request, the handler that will take over the
/// connection once it is upgraded, and (with the `cookies` feature) an
/// optional cookie jar whose changes are sent back as `Set-Cookie` headers.
pub struct WebSocketUpgradeResponse<
    THandler: FnOnce(Request, Box<dyn Websocket + Send>) -> TFut + Send + Sync + 'static,
    TFut: Future<Output = ()> + Send + 'static,
> {
    /// The request that asked for the upgrade; it is later passed on to `handler`.
    req: Request,
    /// Called with the request and the upgraded socket after a successful handshake.
    handler: THandler,
    /// Cookie changes to emit on the response, if any.
    #[cfg(feature = "cookies")]
    cookies: Option<cookie::CookieJar>,
}
impl<THandler, TFut> HttpResponse for WebSocketUpgradeResponse<THandler, TFut>
where
THandler: FnOnce(Request, Box<dyn Websocket + Send>) -> TFut + Send + Sync + 'static,
TFut: Future<Output = ()> + Send + 'static,
{
fn into_response(mut self) -> Result<Response<Vec<u8>>, Error> {
let mut resp = Response::builder();
#[cfg(feature = "cookies")]
if let Some(jar) = self.cookies {
if let Some(headers) = resp.headers_mut() {
for cookie in jar.delta() {
if let Ok(header_value) = cookie.encoded().to_string().parse() {
headers.append(SET_COOKIE, header_value);
}
}
}
}
if self.req.method() != Method::GET {
return Ok(resp.status(StatusCode::METHOD_NOT_ALLOWED).body(vec![])?);
}
if !header_contains(&self.req, header::CONNECTION, "upgrade") {
return Ok(resp.status(StatusCode::BAD_REQUEST).body(vec![])?);
}
if !header_eq(&self.req, header::UPGRADE, "websocket") {
return Ok(resp.status(StatusCode::BAD_REQUEST).body(vec![])?);
}
if !header_eq(&self.req, header::SEC_WEBSOCKET_VERSION, "13") {
return Ok(resp.status(StatusCode::BAD_REQUEST).body(vec![])?);
}
let sec_websocket_key = match self.req.headers_mut().remove(header::SEC_WEBSOCKET_KEY) {
Some(sec_websocket_key) => sec_websocket_key,
None => return Ok(resp.status(StatusCode::BAD_REQUEST).body(vec![])?),
};
let on_upgrade = match self.req.extensions_mut().remove::<OnUpgrade>() {
Some(on_upgrade) => on_upgrade,
None => return Ok(resp.status(StatusCode::BAD_REQUEST).body(vec![])?),
};
tokio::spawn(async move {
let upgraded = on_upgrade.await.expect("connection upgrade failed");
let upgraded = async_tungstenite::tokio::TokioAdapter::new(upgraded);
let socket = WebSocketStream::from_raw_socket(upgraded, protocol::Role::Server, None) .await;
(self.handler)(self.req, Box::new(TokioSocket(socket))).await;
});
let builder = resp
.status(StatusCode::SWITCHING_PROTOCOLS)
.header(header::CONNECTION, UPGRADE)
.header(header::UPGRADE, WEBSOCKET)
.header(
header::SEC_WEBSOCKET_ACCEPT,
sign(sec_websocket_key.as_bytes()),
);
Ok(builder.body([].to_vec())?)
}
}
/// Reports whether the first `key` header on `req` equals `value`, ignoring
/// ASCII case.
///
/// A missing header counts as "not equal".
fn header_eq(req: &Request, key: HeaderName, value: &'static str) -> bool {
    matches!(
        req.headers().get(&key),
        Some(found) if found.as_bytes().eq_ignore_ascii_case(value.as_bytes())
    )
}
/// Reports whether any `key` header on `req` lists `value` as one of its
/// comma-separated tokens, compared ASCII-case-insensitively.
///
/// Headers such as `Connection` are token lists (e.g. `keep-alive, Upgrade`)
/// and may be spread over several header lines, so every occurrence is
/// checked. Matching whole tokens, rather than substrings, keeps a value like
/// `x-upgrade-foo` from counting as `upgrade`. Header values that are not
/// valid UTF-8 are skipped.
fn header_contains(req: &Request, key: HeaderName, value: &'static str) -> bool {
    req.headers()
        .get_all(&key)
        .iter()
        .filter_map(|header| std::str::from_utf8(header.as_bytes()).ok())
        .flat_map(|header| header.split(','))
        .any(|token| token.trim().eq_ignore_ascii_case(value))
}
/// Derives the `Sec-WebSocket-Accept` header value from a client's
/// `Sec-WebSocket-Key`.
///
/// The value is the standard base64 encoding of the SHA-1 digest of the key
/// followed by the fixed GUID below.
fn sign(key: &[u8]) -> HeaderValue {
    const WEBSOCKET_GUID: &[u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

    let mut hasher = Sha1::default();
    for part in [key, WEBSOCKET_GUID] {
        hasher.update(part);
    }
    let encoded = STANDARD.encode(hasher.finalize());
    HeaderValue::from_maybe_shared(encoded).expect("base64 is a valid value")
}
pub(crate) struct TokioSocket(WebSocketStream<TokioAdapter<Upgraded>>);
impl Sink<Message> for TokioSocket {
type Error = crate::Error;
fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
<WebSocketStream<TokioAdapter<Upgraded>> as Sink<async_tungstenite::tungstenite::Message>>::poll_ready(
Pin::new(&mut self.0),
cx,
).map_err(|e| e.into())
}
fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
<WebSocketStream<TokioAdapter<Upgraded>> as Sink<async_tungstenite::tungstenite::Message>>::start_send(
Pin::new(&mut self.0),
item.into(),
).map_err(|e| e.into())
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
<WebSocketStream<TokioAdapter<Upgraded>> as Sink<async_tungstenite::tungstenite::Message>>::poll_flush(
Pin::new(&mut self.0),
cx,
).map_err(|e| e.into())
}
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
<WebSocketStream<TokioAdapter<Upgraded>> as Sink<async_tungstenite::tungstenite::Message>>::poll_close(
Pin::new(&mut self.0),
cx,
).map_err(|e| e.into())
}
}
impl Stream for TokioSocket {
type Item = Result<Message, crate::Error>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
match <WebSocketStream<TokioAdapter<Upgraded>> as Stream>::poll_next(
Pin::new(&mut self.0),
cx,
) {
Poll::Ready(msg) => Poll::Ready(msg.map(|v| v.map(Into::into).map_err(Into::into))),
Poll::Pending => Poll::Pending,
}
}
fn size_hint(&self) -> (usize, Option<usize>) {
self.0.size_hint()
}
}