use std::future::Future;
use std::marker::PhantomData;
use std::sync::Arc;
use headers::{
Connection, HeaderMapExt, SecWebsocketAccept, SecWebsocketKey, SecWebsocketVersion, Upgrade,
};
use hyper::upgrade::{self, Upgraded};
pub use tokio_tungstenite::tungstenite;
pub use tokio_tungstenite::tungstenite::protocol::{Message, WebSocketConfig};
use tokio_tungstenite::WebSocketStream;
use crate::http::header::UPGRADE;
use crate::http::StatusCode;
use crate::{async_trait, throw, Context, Endpoint, State, Status};
/// The stream handed to the user task (the second argument of `F`) once the
/// HTTP connection has been upgraded.
pub type SocketStream = WebSocketStream<Upgraded>;

/// An endpoint that answers a websocket handshake and then runs `task` on the
/// upgraded connection.
///
/// `S` and `Fut` only appear in the bound on `F`; the struct never stores a value
/// of either type.
pub struct Websocket<F, S, Fut>
where
    F: Fn(Context<S>, SocketStream) -> Fut,
{
    /// The user task, shared (via `Arc`) with every spawned connection handler.
    task: Arc<F>,
    /// Optional configuration forwarded to every upgraded stream.
    config: Option<WebSocketConfig>,
    // `fn() -> T` markers bind the type parameters without claiming ownership of an
    // `S` or `Fut` value. As a result, the compiler-derived `Send`/`Sync` impls depend
    // only on `Arc<F>` (i.e. `F: Send + Sync`) — the same condition the former
    // hand-written `unsafe impl`s asserted — and no `unsafe` is needed.
    _s: PhantomData<fn() -> S>,
    _fut: PhantomData<fn() -> Fut>,
}
impl<F, S, Fut> Websocket<F, S, Fut>
where
    F: Fn(Context<S>, SocketStream) -> Fut,
{
    /// Shared constructor: moves `task` into an `Arc` (so it can be cloned into
    /// each connection handler) and records the optional stream configuration.
    fn config(config: Option<WebSocketConfig>, task: F) -> Self {
        let shared_task = Arc::new(task);
        Self {
            task: shared_task,
            config,
            _s: PhantomData,
            _fut: PhantomData,
        }
    }

    /// Creates a websocket endpoint without a custom `WebSocketConfig`.
    pub fn new(task: F) -> Self {
        let no_config: Option<WebSocketConfig> = None;
        Self::config(no_config, task)
    }

    /// Creates a websocket endpoint that passes `config` to every upgraded stream.
    pub fn with_config(config: WebSocketConfig, task: F) -> Self {
        Self::config(config.into(), task)
    }
}
#[async_trait(?Send)]
impl<'a, F, S, Fut> Endpoint<'a, S> for Websocket<F, S, Fut>
where
S: State,
F: 'static + Sync + Send + Fn(Context<S>, SocketStream) -> Fut,
Fut: 'static + Send + Future<Output = ()>,
{
#[inline]
async fn call(&'a self, ctx: &'a mut Context<S>) -> Result<(), Status> {
let header_map = &ctx.req.headers;
let key = header_map
.typed_get::<Upgrade>()
.filter(|upgrade| upgrade == &Upgrade::websocket())
.and(header_map.typed_get::<Connection>())
.filter(|connection| connection.contains(UPGRADE))
.and(header_map.typed_get::<SecWebsocketVersion>())
.filter(|version| version == &SecWebsocketVersion::V13)
.and(header_map.typed_get::<SecWebsocketKey>());
match key {
None => throw!(StatusCode::BAD_REQUEST, "invalid websocket upgrade request"),
Some(key) => {
let raw_req = ctx.req.take_raw();
let context = ctx.clone();
let task = self.task.clone();
let config = self.config;
ctx.exec.spawn(async move {
match upgrade::on(raw_req).await {
Err(err) => tracing::error!("websocket upgrade error: {}", err),
Ok(upgraded) => {
let websocket = WebSocketStream::from_raw_socket(
upgraded,
tungstenite::protocol::Role::Server,
config,
)
.await;
task(context, websocket).await
}
}
});
ctx.resp.status = StatusCode::SWITCHING_PROTOCOLS;
ctx.resp.headers.typed_insert(Connection::upgrade());
ctx.resp.headers.typed_insert(Upgrade::websocket());
ctx.resp.headers.typed_insert(SecWebsocketAccept::from(key));
Ok(())
}
}
}
}