use std::future::Future;
use std::marker::{PhantomData, Send};
use crate::async_tungstenite::WebSocketStream;
use crate::tungstenite::protocol::Role;
use crate::WebSocketConnection;
use async_dup::Arc;
use async_std::task;
use sha1::{Digest, Sha1};
use tide::http::format_err;
use tide::http::headers::{HeaderName, CONNECTION, UPGRADE};
use tide::{Middleware, Request, Response, Result, StatusCode};
/// Fixed GUID string that is hashed together with the client's
/// `Sec-Websocket-Key` to produce the `Sec-Websocket-Accept` response header
/// (this is the value specified by RFC 6455 for the opening handshake).
const WEBSOCKET_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
/// Websocket upgrade handler that can be mounted either as a tide endpoint
/// or as a tide middleware.
///
/// `S` is the tide application state type and `H` is the user-supplied
/// handler that is invoked with the original request and the established
/// websocket connection.
#[derive(Debug)]
pub struct WebSocket<S, H> {
// Shared handle to the user handler; cloned into the task spawned per upgrade.
handler: Arc<H>,
// Zero-sized marker so that the state type `S` appears in the struct.
ghostly_apparition: PhantomData<S>,
// Subprotocol names this server accepts, compared against the request's
// `Sec-Websocket-Protocol` header during the handshake.
protocols: Vec<String>,
}
/// Outcome of inspecting a request for a websocket upgrade.
enum UpgradeStatus<S> {
/// The request asked for a websocket upgrade. Holds either the handshake
/// response to send back or the error that aborted the handshake.
Upgraded(Result<Response>),
/// The request did not ask for a websocket upgrade; the untouched request
/// is handed back so the caller can decide what to do with it.
NotUpgraded(Request<S>),
}
use UpgradeStatus::{NotUpgraded, Upgraded};
/// Reports whether any comma-separated token of the named request header
/// equals `value`, ignoring ASCII case and surrounding whitespace.
///
/// Every value stored for the header is inspected, so a header that the
/// client sent on several lines is treated the same as a single comma-joined
/// line. Returns `false` when the header is absent.
fn header_contains_ignore_case<T>(req: &Request<T>, header_name: HeaderName, value: &str) -> bool {
    // Trim once instead of on every token comparison.
    let wanted = value.trim();
    req.header(header_name).map_or(false, |values| {
        values
            .iter()
            // `HeaderValues::as_str` only exposes the last value, so walk
            // each stored value explicitly.
            .flat_map(|single| single.as_str().split(','))
            .any(|token| token.trim().eq_ignore_ascii_case(wanted))
    })
}
impl<S, H, Fut> WebSocket<S, H>
where
S: Send + Sync + Clone + 'static,
H: Fn(Request<S>, WebSocketConnection) -> Fut + Sync + Send + 'static,
Fut: Future<Output = Result<()>> + Send + 'static,
{
pub fn new(handler: H) -> Self {
Self {
handler: Arc::new(handler),
ghostly_apparition: PhantomData,
protocols: Default::default(),
}
}
pub fn with_protocols(self, protocols: &[&str]) -> Self {
Self {
protocols: protocols.iter().map(ToString::to_string).collect(),
..self
}
}
async fn handle_upgrade(&self, req: Request<S>) -> UpgradeStatus<S> {
let connection_upgrade = header_contains_ignore_case(&req, CONNECTION, "upgrade");
let upgrade_to_websocket = header_contains_ignore_case(&req, UPGRADE, "websocket");
let upgrade_requested = connection_upgrade && upgrade_to_websocket;
if !upgrade_requested {
return NotUpgraded(req);
}
let header = match req.header("Sec-Websocket-Key") {
Some(h) => h.as_str(),
None => return Upgraded(Err(format_err!("expected sec-websocket-key"))),
};
let protocol = req.header("Sec-Websocket-Protocol").and_then(|value| {
value
.as_str()
.split(',')
.map(str::trim)
.find(|req_p| self.protocols.iter().any(|p| p == req_p))
});
let mut response = Response::new(StatusCode::SwitchingProtocols);
response.insert_header(UPGRADE, "websocket");
response.insert_header(CONNECTION, "Upgrade");
let hash = Sha1::new().chain(header).chain(WEBSOCKET_GUID).finalize();
response.insert_header("Sec-Websocket-Accept", base64::encode(&hash[..]));
response.insert_header("Sec-Websocket-Version", "13");
if let Some(protocol) = protocol {
response.insert_header("Sec-Websocket-Protocol", protocol);
}
let http_res: &mut tide::http::Response = response.as_mut();
let upgrade_receiver = http_res.recv_upgrade().await;
let handler = self.handler.clone();
task::spawn(async move {
if let Some(stream) = upgrade_receiver.await {
let stream = WebSocketStream::from_raw_socket(stream, Role::Server, None).await;
handler(req, stream.into()).await
} else {
Err(format_err!("never received an upgrade!"))
}
});
Upgraded(Ok(response))
}
}
#[tide::utils::async_trait]
impl<H, S, Fut> tide::Endpoint<S> for WebSocket<S, H>
where
H: Fn(Request<S>, WebSocketConnection) -> Fut + Sync + Send + 'static,
Fut: Future<Output = Result<()>> + Send + 'static,
S: Send + Sync + Clone + 'static,
{
async fn call(&self, req: Request<S>) -> Result {
match self.handle_upgrade(req).await {
Upgraded(result) => result,
NotUpgraded(_) => Ok(Response::new(StatusCode::UpgradeRequired)),
}
}
}
#[tide::utils::async_trait]
impl<H, S, Fut> Middleware<S> for WebSocket<S, H>
where
    H: Fn(Request<S>, WebSocketConnection) -> Fut + Sync + Send + 'static,
    Fut: Future<Output = Result<()>> + Send + 'static,
    S: Send + Sync + Clone + 'static,
{
    /// Intercepts websocket upgrade requests and lets everything else
    /// continue down the middleware chain.
    async fn handle(&self, req: Request<S>, next: tide::Next<'_, S>) -> Result {
        // An upgrade attempt short-circuits with the handshake outcome;
        // otherwise the request is recovered and passed on.
        let passthrough = match self.handle_upgrade(req).await {
            Upgraded(outcome) => return outcome,
            NotUpgraded(original) => original,
        };
        Ok(next.run(passthrough).await)
    }
}