use std::{future::Future, ops::ControlFlow};
use async_fn_traits::AsyncFn1;
use futures::StreamExt;
use http::StatusCode;
use http_body_util::BodyExt;
use super::{
connection_id::ConnectionId, connection_storage::ConnectionStorage, http_error::HttpError,
response::Response, response_body::ResponseBody,
};
use crate::{
prelude::*,
state::{
connection_state::ConnectionState,
context::{HttpRequestContext, WebsocketRequestContext},
global_state::GlobalState,
request_state::RequestState,
session_state::SessionState,
},
};
/// Extension for request contexts whose request can be upgraded to a WebSocket connection.
pub trait ContextWebsocketExt {
    /// Upgrades the current request to a WebSocket connection and dispatches the
    /// messages received on it to `handler_fn`.
    ///
    /// `handler_fn` is invoked with a mutable [`WebsocketRequestContext`]. The handler and
    /// the futures it returns must be `Send`, and the handler must be `'static`, so
    /// implementations can drive it from a spawned task. The `HttpRequestContext`
    /// implementation does exactly that with `tokio::spawn`.
    ///
    /// The returned future resolves to the implementation's [`ControlFlow`] decision, or
    /// to an [`HttpError`] when the upgrade cannot be performed.
    fn next_websocket<
        H: Send + 'static + for<'a> AsyncFn1<&'a mut WebsocketRequestContext, Output = ()>,
    >(
        &mut self,
        handler_fn: H,
    ) -> impl Future<Output = Result<ControlFlow<()>, HttpError>>
    where
        for<'a> <H as AsyncFn1<&'a mut WebsocketRequestContext>>::OutputFuture: Send;
}
/// WebSocket support for [`HttpRequestContext`].
impl ContextWebsocketExt for HttpRequestContext {
    /// Upgrades this request to a WebSocket and serves the connection on a detached
    /// `tokio` task, calling `handler_fn` once per received message.
    ///
    /// The spawned task:
    /// 1. awaits the handshake. On failure it logs at `debug` level and ends;
    /// 2. splits the socket and stores the sending half, together with a freshly
    ///    generated [`ConnectionId`], in a new [`ConnectionState`]. That state is
    ///    registered in the session's [`ConnectionStorage`] under the id;
    /// 3. for every received message, replaces the context's [`RequestState`] with a
    ///    fresh one that holds only that message, then awaits `handler_fn`;
    /// 4. when the stream ends or a receive error occurs, removes the connection from
    ///    the session's [`ConnectionStorage`].
    ///
    /// The handshake response produced by `hyper_tungstenite::upgrade` is handed to
    /// `self.next`, and its result is returned.
    ///
    /// # Errors
    ///
    /// - `426 Upgrade Required` (with an `Upgrade: websocket` header) if the request is
    ///   not a WebSocket upgrade request.
    /// - `400 Bad Request` if `hyper_tungstenite::upgrade` rejects the request.
    /// - Any error raised while building the `426` response, converted via `?`.
    async fn next_websocket<
        H: Send + 'static + for<'a> AsyncFn1<&'a mut WebsocketRequestContext, Output = ()>,
    >(
        &mut self,
        handler_fn: H,
    ) -> Result<ControlFlow<()>, HttpError>
    where
        for<'a> <H as AsyncFn1<&'a mut WebsocketRequestContext>>::OutputFuture: Send,
    {
        if !hyper_tungstenite::is_upgrade_request(self.request_mut()) {
            return Err(HttpError::from(
                Response::builder()
                    .status(StatusCode::UPGRADE_REQUIRED)
                    .header("Upgrade", "websocket")
                    .body(ResponseBody::empty())?,
            ));
        }
        let (response, websocket) = hyper_tungstenite::upgrade(self.request_mut(), None)
            .map_err(|err| HttpError::new(StatusCode::BAD_REQUEST, err.to_string()))?;
        // Owned copies, so the detached task below does not borrow `self`.
        let session_state = SessionState::get_from_ctx(self).clone();
        let global_state = GlobalState::get_from_ctx(self).clone();
        tokio::spawn(async move {
            let (connection_state, connection_id, mut rx) = {
                let (tx, rx) = {
                    let websocket = match websocket.await {
                        Ok(ws) => ws,
                        Err(err) => {
                            tracing::debug!("failed to initialize websocket connection: {err}");
                            return;
                        }
                    };
                    websocket.split()
                };
                let connection_id = ConnectionId::generate();
                let connection_state = {
                    let connection_state = ConnectionState::default();
                    connection_state.insert(tx).await;
                    connection_state.insert(connection_id).await;
                    connection_state
                };
                {
                    // Scoped so the value returned by `get_mut_or_insert_default`
                    // (presumably a lock guard; TODO confirm) is dropped before the
                    // receive loop starts.
                    let mut connection_storage = session_state
                        .get_mut_or_insert_default::<ConnectionStorage>()
                        .await;
                    connection_storage
                        .get_mut()
                        .insert(connection_id, connection_state.clone());
                }
                (connection_state, connection_id, rx)
            };
            let mut ctx = WebsocketRequestContext::from_states(
                global_state,
                session_state.clone(),
                connection_state,
                RequestState::default(),
            );
            while let Some(message) = rx.next().await {
                let message = match message {
                    Ok(message) => message,
                    Err(err) => {
                        // A failed read does not recover on the same stream. Stop here
                        // instead of polling again (which could spin on repeated
                        // errors), and fall through to the cleanup below.
                        tracing::debug!("websocket receive failed: {err}");
                        break;
                    }
                };
                // Each message gets a fresh request state, so nothing leaks between
                // handler invocations.
                let mut request_state = RequestState::default();
                request_state.insert(message);
                *RequestState::get_mut_from_ctx(&mut ctx) = request_state;
                handler_fn(&mut ctx).await;
            }
            session_state
                .get_mut_or_insert_default::<ConnectionStorage>()
                .await
                .get_mut()
                .remove(&connection_id);
        });
        // Re-box the handshake response body into this crate's response body type.
        let converted_response = {
            let (response_parts, response_body) = response.into_parts();
            Response::from_parts(
                response_parts,
                response_body.map_err(|err| match err {}).boxed(),
            )
        };
        self.next(converted_response)
    }
}