use std::future::Future;
use base64::Engine as _;
use base64::engine::general_purpose::STANDARD;
use futures_util::FutureExt;
use http::StatusCode;
use http::header;
use hyper::upgrade::Upgraded;
use hyper_util::rt::TokioIo;
use sha1::Digest;
use sha1::Sha1;
use tokio_tungstenite::WebSocketStream;
use tokio_tungstenite::tungstenite::protocol::Role;
use crate::body::TakoBody;
use crate::responder::Responder;
use crate::types::Request;
use crate::types::Response;
/// Responder that turns an HTTP request into a WebSocket connection.
///
/// It bundles the incoming request with a one-shot handler. Converting it into a
/// response answers the WebSocket handshake. Once the underlying connection has been
/// upgraded, the handler is called with the resulting [`WebSocketStream`].
#[doc(alias = "websocket")]
#[doc(alias = "ws")]
pub struct TakoWs<
    H: FnOnce(WebSocketStream<TokioIo<Upgraded>>) -> Fut + Send + 'static,
    Fut: Future<Output = ()> + Send + 'static,
> {
    /// Request whose headers and extensions drive the handshake.
    request: Request,
    /// Callback run (at most once) with the upgraded WebSocket stream.
    handler: H,
}
impl<H, Fut> TakoWs<H, Fut>
where
H: FnOnce(WebSocketStream<TokioIo<Upgraded>>) -> Fut + Send + 'static,
Fut: Future<Output = ()> + Send + 'static,
{
pub fn new(request: Request, handler: H) -> Self {
Self { request, handler }
}
}
impl<H, Fut> Responder for TakoWs<H, Fut>
where
H: FnOnce(WebSocketStream<TokioIo<Upgraded>>) -> Fut + Send + 'static,
Fut: Future<Output = ()> + Send + 'static,
{
fn into_response(self) -> Response {
let (parts, body) = self.request.into_parts();
let req = http::Request::from_parts(parts, body);
let key = match req.headers().get("Sec-WebSocket-Key") {
Some(k) => k,
None => {
return http::Response::builder()
.status(StatusCode::BAD_REQUEST)
.body(TakoBody::from("Missing Sec-WebSocket-Key".to_string()))
.expect("valid bad request response");
}
};
let accept = {
let mut sha1 = Sha1::new();
sha1.update(key.as_bytes());
sha1.update(b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
STANDARD.encode(sha1.finalize())
};
let response = http::Response::builder()
.status(StatusCode::SWITCHING_PROTOCOLS)
.header(header::UPGRADE, "websocket")
.header(header::CONNECTION, "Upgrade")
.header("Sec-WebSocket-Accept", accept)
.body(TakoBody::empty())
.expect("valid WebSocket upgrade response");
if let Some(on_upgrade) = req.extensions().get::<hyper::upgrade::OnUpgrade>().cloned() {
let handler = self.handler;
tokio::spawn(async move {
if let Ok(upgraded) = on_upgrade.await {
let upgraded = TokioIo::new(upgraded);
let ws = WebSocketStream::from_raw_socket(upgraded, Role::Server, None).await;
let _ = std::panic::AssertUnwindSafe(handler(ws))
.catch_unwind()
.await;
}
});
}
response
}
}