use crate::{handler::Handler, request::Request, response::Builder, tokio_io::TokioIO, Result};
use futures::{
stream::{SplitSink, SplitStream},
SinkExt, StreamExt, TryStreamExt,
};
use hyper::{upgrade::Upgraded, StatusCode};
use std::{future::Future, sync::Arc};
use tokio_tungstenite::{tungstenite::Message, WebSocketStream};
/// A server-side WebSocket connection whose stream has been split into an
/// outgoing half and an incoming half.
///
/// Values of this type are built by `upgrade_connection` once the HTTP upgrade
/// has completed, and are handed to the user's WebSocket handler together with
/// the originating `Request`.
#[derive(Debug)]
pub struct WebSocket {
/// Outgoing half: messages fed into this sink are sent to the peer.
pub sender: SplitSink<WebSocketStream<TokioIO<Upgraded>>, Message>,
/// Incoming half: yields messages (or read errors) received from the peer.
pub receiver: SplitStream<WebSocketStream<TokioIO<Upgraded>>>,
}
impl WebSocket {
pub async fn send(&mut self, msg: Message) -> Result<()> {
self.sender.send(msg).await?;
Ok(())
}
pub async fn receive(&mut self) -> Result<Option<Message>> {
let msg = self.receiver.try_next().await?;
Ok(msg)
}
}
/// Route handler that answers the WebSocket opening handshake and then hands
/// the upgraded connection to a user-supplied callback.
///
/// `H` is the callback and `Fut` the future it returns. The future type is
/// never stored, so it is recorded with a `PhantomData<fn() -> Fut>` marker:
/// rustc rejects a struct type parameter that appears only in bounds (E0392).
/// The `fn() -> Fut` form keeps `WsHandler` `Send + Sync` even though `Fut` is
/// only required to be `Send`.
#[derive(Debug)]
pub struct WsHandler<H, Fut>
where
    Fut: Future<Output = Result<()>> + Send + 'static,
    H: Send + Sync + 'static + Fn(Request, WebSocket) -> Fut,
{
    /// User callback, reference-counted so every upgraded connection's task
    /// can hold its own handle to it.
    handler: Arc<H>,
    /// Marker tying the otherwise-unused `Fut` parameter to the struct.
    _future: std::marker::PhantomData<fn() -> Fut>,
}

/// Wraps `handler` in a [`WsHandler`] so it can be registered as a route
/// handler for WebSocket upgrade requests.
pub(crate) fn new_ws<H, Fut>(handler: H) -> WsHandler<H, Fut>
where
    Fut: Future<Output = Result<()>> + Send + 'static,
    H: Send + Sync + 'static + Fn(Request, WebSocket) -> Fut,
{
    WsHandler {
        handler: Arc::new(handler),
        _future: std::marker::PhantomData,
    }
}
#[async_trait::async_trait]
impl<H, Fut> Handler for WsHandler<H, Fut>
where
Fut: Future<Output = Result<()>> + Send + 'static,
H: Send + Sync + 'static + Fn(Request, WebSocket) -> Fut,
{
async fn handle(&self, req: Request) -> Builder {
let handler = self.handler.clone();
upgrade_connection(req, handler).await
}
}
/// Validates a WebSocket opening handshake and, if it is acceptable, schedules
/// the connection upgrade.
///
/// The request must carry a `Connection` header listing `upgrade`, an
/// `Upgrade: websocket` header and a `Sec-WebSocket-Key` header. If any check
/// fails, a `400 Bad Request` builder is returned and `handler` is never run.
///
/// On success this returns a `101 Switching Protocols` builder carrying the
/// matching `Sec-WebSocket-Accept` header. It also spawns a background task
/// that:
/// 1. awaits the hyper upgrade;
/// 2. wraps the raw connection in a [`WebSocket`];
/// 3. invokes `handler` with the original request.
///
/// Anything that goes wrong inside that task can no longer be reported to the
/// client through this response. This covers the upgrade failing as well as
/// the handler returning `Err`. Both are printed to stdout and the task ends.
async fn upgrade_connection<H, Fut>(mut req: Request, handler: Arc<H>) -> Builder
where
    Fut: Future<Output = Result<()>> + Send + 'static,
    H: Send + Sync + 'static + Fn(Request, WebSocket) -> Fut,
{
    let builder = Builder::new();

    // `Connection` may list several tokens; it only has to include `upgrade`.
    if !matches!(
        req.header::<headers::Connection>(),
        Some(conn) if conn.contains(hyper::header::UPGRADE)
    ) {
        return builder.status(StatusCode::BAD_REQUEST);
    }

    // NOTE(review): RFC 6455 says the `websocket` token is matched
    // case-insensitively, but this is an exact `headers::Upgrade` equality
    // check. Confirm whether clients sending e.g. `WebSocket` must be accepted.
    if !matches!(
        req.header::<headers::Upgrade>(),
        Some(upgrade) if upgrade == headers::Upgrade::websocket()
    ) {
        return builder.status(StatusCode::BAD_REQUEST);
    }

    let key = match req.header::<headers::SecWebsocketKey>() {
        Some(sec_key) => sec_key,
        None => return builder.status(StatusCode::BAD_REQUEST),
    };

    let builder = builder
        .status(StatusCode::SWITCHING_PROTOCOLS)
        .header_set(headers::Upgrade::websocket())
        .header_set(headers::Connection::upgrade())
        .header_set(headers::SecWebsocketAccept::from(key));

    println!("upgrading connection to websocket");
    tokio::spawn(async move {
        // Whether the upgrade succeeds depends on the peer and the network,
        // not on a program invariant. Log the failure and end this task
        // instead of panicking it.
        let upgraded = match hyper::upgrade::on(req.request()).await {
            Ok(upgraded) => upgraded,
            Err(e) => {
                println!("websocket upgrade failed: {}", e);
                return;
            }
        };

        // The handshake was already answered by the 101 response, so the
        // raw stream is wrapped directly in the server role.
        let ws = WebSocketStream::from_raw_socket(
            TokioIO::new(upgraded),
            tokio_tungstenite::tungstenite::protocol::Role::Server,
            None,
        )
        .await;
        let (tx, rx) = ws.split();

        let res = (handler)(
            req,
            WebSocket {
                sender: tx,
                receiver: rx,
            },
        )
        .await;
        match res {
            Ok(()) => println!("websocket handler returned"),
            Err(e) => println!("websocket handler returned an error: {}", e),
        };
    });

    builder
}