use std::{borrow::Cow, future::Future};
use futures_util::{FutureExt, future::BoxFuture};
use headers::HeaderMapExt;
use tokio_tungstenite::tungstenite::protocol::{Role, WebSocketConfig};
use super::{WebSocketStream, utils::sign};
use crate::{
Body, FromRequest, IntoResponse, OnUpgrade, Request, RequestBody, Response, Result,
error::WebSocketError,
http::{
Method, StatusCode,
header::{self, HeaderValue},
},
};
/// State captured from a WebSocket handshake request.
///
/// Instances are built by `internal_from_request` and later consumed when the
/// upgrade response is produced.
pub struct WebSocket {
/// Value of the request's `Sec-WebSocket-Key` header; it is passed to `sign`
/// to produce the `Sec-WebSocket-Accept` response header.
key: HeaderValue,
/// Upgrade handle taken from the request; awaited in the spawned task to
/// obtain the raw upgraded connection.
on_upgrade: OnUpgrade,
/// Subprotocols configured through `protocols()`; `None` until that method
/// is called.
protocols: Option<Box<[Cow<'static, str>]>>,
/// Raw `Sec-WebSocket-Protocol` header of the request, if it was present.
sec_websocket_protocol: Option<HeaderValue>,
/// Optional tungstenite configuration forwarded to `from_raw_socket`.
config: Option<WebSocketConfig>,
}
impl WebSocket {
    /// Validates the WebSocket opening handshake carried by `req` and captures
    /// the state needed to complete the upgrade later.
    ///
    /// The request is accepted only when all of the following hold:
    ///
    /// * the method is `GET`;
    /// * the `Upgrade` header equals `websocket`, compared ASCII
    ///   case-insensitively;
    /// * the `Sec-WebSocket-Version` header is exactly `13`;
    /// * the `Connection` header contains the `upgrade` token;
    /// * a `Sec-WebSocket-Key` header is present.
    ///
    /// # Errors
    ///
    /// Returns [`WebSocketError::InvalidProtocol`] when any of the checks above
    /// fails. An error returned by `Request::take_upgrade` is converted into a
    /// [`WebSocketError`] through `?`.
    async fn internal_from_request(req: &Request) -> Result<Self, WebSocketError> {
        let header_map = req.headers();

        // RFC 6455 §4.2.1 treats the `Upgrade` value as ASCII case-insensitive,
        // so spellings such as `WEBSOCKET` must be accepted as well.
        let is_valid_upgrade_header = header_map
            .get(header::UPGRADE)
            .is_some_and(|value| value.as_bytes().eq_ignore_ascii_case(b"websocket"));

        if req.method() != Method::GET
            || !is_valid_upgrade_header
            || header_map.get(header::SEC_WEBSOCKET_VERSION)
                != Some(&HeaderValue::from_static("13"))
        {
            return Err(WebSocketError::InvalidProtocol);
        }

        // A missing or unparsable `Connection` header is treated the same as
        // one that does not list `upgrade`.
        if !matches!(
            header_map
                .typed_get::<headers::Connection>()
                .map(|connection| connection.contains(header::UPGRADE)),
            Some(true)
        ) {
            return Err(WebSocketError::InvalidProtocol);
        }

        let key = header_map
            .get(header::SEC_WEBSOCKET_KEY)
            .cloned()
            .ok_or(WebSocketError::InvalidProtocol)?;

        let sec_websocket_protocol = header_map.get(header::SEC_WEBSOCKET_PROTOCOL).cloned();

        // The upgrade handle is taken last, after every header check passed.
        Ok(Self {
            key,
            on_upgrade: req.take_upgrade()?,
            protocols: None,
            sec_websocket_protocol,
            config: None,
        })
    }
}
impl<'a> FromRequest<'a> for WebSocket {
    /// Extracts a [`WebSocket`] from the request headers.
    ///
    /// The request body is not read. A handshake validation failure is
    /// converted into the crate-level error type.
    async fn from_request(req: &'a Request, _body: &mut RequestBody) -> Result<Self> {
        match Self::internal_from_request(req).await {
            Ok(websocket) => Ok(websocket),
            Err(err) => Err(err.into()),
        }
    }
}
impl WebSocket {
    /// Sets the subprotocols the server is willing to speak.
    ///
    /// When the upgrade response is built, the first protocol listed in the
    /// request's `Sec-WebSocket-Protocol` header that also appears in this
    /// list is echoed back to the client.
    #[must_use]
    pub fn protocols<I>(mut self, protocols: I) -> Self
    where
        I: IntoIterator,
        I::Item: Into<Cow<'static, str>>,
    {
        // Collect straight into the boxed slice; no intermediate `Vec` binding
        // is needed.
        self.protocols = Some(protocols.into_iter().map(Into::into).collect());
        self
    }

    /// Sets the tungstenite configuration used when the upgraded connection is
    /// wrapped into a WebSocket stream.
    #[must_use]
    pub fn config(mut self, config: WebSocketConfig) -> Self {
        self.config = Some(config);
        self
    }

    /// Attaches the callback that receives the [`WebSocketStream`] once the
    /// connection has been upgraded.
    ///
    /// The returned [`WebSocketUpgraded`] has to be turned into a response for
    /// the handshake to be answered; dropping it discards the upgrade.
    #[must_use]
    pub fn on_upgrade<F, Fut>(self, callback: F) -> WebSocketUpgraded<F>
    where
        F: FnOnce(WebSocketStream) -> Fut + Send + Sync + 'static,
        Fut: Future + Send + 'static,
    {
        WebSocketUpgraded {
            websocket: self,
            callback,
        }
    }
}
/// A [`WebSocket`] paired with the callback to run once the connection has
/// been upgraded.
///
/// Created by `WebSocket::on_upgrade`.
pub struct WebSocketUpgraded<F> {
/// Handshake state captured from the request.
websocket: WebSocket,
/// Callback invoked with the established WebSocket stream.
callback: F,
}
/// Type-erased upgrade callback: a boxed `FnOnce` that takes the WebSocket
/// stream and returns a boxed future resolving to `()`.
type BoxWebSocketHandler =
Box<dyn FnOnce(WebSocketStream) -> BoxFuture<'static, ()> + Send + Sync + 'static>;
/// A [`WebSocketUpgraded`] whose callback has been boxed, giving it a single
/// nameable type regardless of the original closure.
pub type BoxWebSocketUpgraded = WebSocketUpgraded<BoxWebSocketHandler>;
impl<F, Fut> WebSocketUpgraded<F>
where
F: FnOnce(WebSocketStream) -> Fut + Send + Sync + 'static,
Fut: Future + Send + 'static,
{
pub fn boxed(self) -> BoxWebSocketUpgraded {
WebSocketUpgraded {
websocket: self.websocket,
callback: Box::new(|stream| (self.callback)(stream).map(|_| ()).boxed()),
}
}
}
impl<F, Fut> IntoResponse for WebSocketUpgraded<F>
where
F: FnOnce(WebSocketStream) -> Fut + Send + Sync + 'static,
Fut: Future + Send + 'static,
{
fn into_response(self) -> Response {
let protocol = self
.websocket
.sec_websocket_protocol
.as_ref()
.and_then(|req_protocols| {
let req_protocols = req_protocols.to_str().ok()?;
let protocols = self.websocket.protocols.as_ref()?;
req_protocols
.split(',')
.map(|req_p| req_p.trim())
.find(|req_p| protocols.iter().any(|p| p == req_p))
});
let mut builder = Response::builder()
.status(StatusCode::SWITCHING_PROTOCOLS)
.header(header::CONNECTION, "upgrade")
.header(header::UPGRADE, "websocket")
.header(
header::SEC_WEBSOCKET_ACCEPT,
sign(self.websocket.key.as_bytes()),
);
if let Some(protocol) = protocol {
builder = builder.header(
header::SEC_WEBSOCKET_PROTOCOL,
HeaderValue::from_str(protocol).unwrap(),
);
}
let resp = builder.body(Body::empty());
tokio::spawn(async move {
let upgraded = match self.websocket.on_upgrade.await {
Ok(upgraded) => upgraded,
Err(_) => return,
};
let stream = tokio_tungstenite::WebSocketStream::from_raw_socket(
upgraded,
Role::Server,
self.websocket.config,
)
.await;
(self.callback)(WebSocketStream::new(stream)).await;
});
resp
}
}