use super::connection::WebSocketConnection;
use super::registry::ConnectionRegistry;
use super::types::{ConnectionId, WebSocketConfig, WebSocketResult};
use axum::extract::ws::WebSocketUpgrade as AxumWebSocketUpgrade;
use std::sync::Arc;
/// Wrapper around axum's WebSocket upgrade that carries this module's configuration
/// and connection registry.
///
/// NOTE(review): this type is currently a stub ("foundation mode"). Both fields are
/// stored but never read (hence the `_` prefixes), and [`WebSocketUpgrade::upgrade`]
/// never calls the handler it is given.
pub struct WebSocketUpgrade {
    _config: WebSocketConfig,
    _registry: Arc<ConnectionRegistry>,
}

impl WebSocketUpgrade {
    /// Creates an upgrader that uses `registry` and `WebSocketConfig::default()`.
    pub fn new(registry: Arc<ConnectionRegistry>) -> Self {
        Self::with_config(registry, WebSocketConfig::default())
    }

    /// Creates an upgrader that uses `registry` and an explicit `config`.
    pub fn with_config(registry: Arc<ConnectionRegistry>, config: WebSocketConfig) -> Self {
        let (cfg, reg) = (config, registry);
        Self {
            _config: cfg,
            _registry: reg,
        }
    }

    /// Hands the upgrade to axum and returns the response built by `on_upgrade`.
    ///
    /// The callback given to axum ignores the upgraded socket. It only emits an
    /// `info` trace and then finishes.
    ///
    /// NOTE(review): `_handler` is accepted but never invoked, and no
    /// [`WebSocketConnection`] is created or registered. Callers therefore get an
    /// upgraded socket that is discarded straight away. Confirm this is intended
    /// before relying on it.
    pub async fn upgrade<H, F>(
        self,
        ws: AxumWebSocketUpgrade,
        _handler: H,
    ) -> axum::response::Response
    where
        H: FnOnce(ConnectionId, Arc<WebSocketConnection>) -> F + Send + 'static,
        F: std::future::Future<Output = ()> + Send + 'static,
    {
        ws.on_upgrade(
            // The socket argument is intentionally unused while in foundation mode.
            |_socket| async move {
                tracing::info!("WebSocket connection upgraded (foundation mode)");
            },
        )
    }
}
/// Asynchronous per-connection callback.
///
/// Implementors must be shareable between threads (`Send + Sync`) and own all of
/// their data (`'static`). The future returned by
/// [`WebSocketHandler::handle_connection`] must be `Send`, which lets it be moved
/// across threads.
pub trait WebSocketHandler: Send + Sync + 'static {
    /// Serves the connection identified by `id`, using the shared `connection` handle.
    fn handle_connection(&self, id: ConnectionId, connection: Arc<WebSocketConnection>) -> impl core::future::Future<Output = ()> + Send;
}
/// Returns the extracted axum upgrade unchanged, wrapped in `Ok`.
///
/// No check is performed, so this never returns `Err` today. Presumably the
/// `WebSocketResult` return type leaves room for future validation — TODO confirm.
pub fn extract_websocket_upgrade(
    ws: AxumWebSocketUpgrade,
) -> WebSocketResult<AxumWebSocketUpgrade> {
    let accepted = ws;
    Ok(accepted)
}
/// Adapter that turns a plain closure into a [`WebSocketHandler`].
///
/// The closure receives the connection id and a shared connection handle, and
/// returns the future that serves that connection. The derived `Clone` clones the
/// closure, so this type is `Clone` only when `F` is.
#[derive(Clone)]
pub struct SimpleWebSocketHandler<F> {
    handler: F,
}

impl<F, Fut> SimpleWebSocketHandler<F>
where
    F: Fn(ConnectionId, Arc<WebSocketConnection>) -> Fut + Send + Sync + 'static,
    Fut: std::future::Future<Output = ()> + Send,
{
    /// Wraps `handler` so it can be used anywhere a [`WebSocketHandler`] is expected.
    pub fn new(handler: F) -> Self {
        SimpleWebSocketHandler { handler }
    }
}

impl<F, Fut> WebSocketHandler for SimpleWebSocketHandler<F>
where
    F: Fn(ConnectionId, Arc<WebSocketConnection>) -> Fut + Send + Sync + 'static,
    Fut: std::future::Future<Output = ()> + Send,
{
    /// Calls the wrapped closure and awaits the future it returns.
    ///
    /// The call sits inside this `async fn` body, so the closure does not run until
    /// the returned future is first polled.
    async fn handle_connection(&self, id: ConnectionId, connection: Arc<WebSocketConnection>) {
        let serve = &self.handler;
        let session = serve(id, connection);
        session.await
    }
}
/// Builds a `SimpleWebSocketHandler` from closure-like syntax.
///
/// Write the argument exactly as
/// `|id: ConnectionId, conn: Arc<WebSocketConnection>| body`. The type tokens are
/// matched literally, so aliases or fully-qualified paths will not match. The
/// macro wraps `body` in `async move { ... }`, which yields a closure that returns a
/// future.
///
/// NOTE(review): the expansion uses `SimpleWebSocketHandler`, `ConnectionId`, `Arc`
/// and `WebSocketConnection` without any path qualifier (no `$crate::` prefix). All
/// of them must therefore be in scope at the call site, even though
/// `#[macro_export]` makes the macro itself reachable from the crate root. Consider
/// using `$crate::<module path>::` once that path is settled.
///
/// NOTE(review): because of the `async move`, a `body` that uses a non-`Copy`
/// variable from the enclosing scope moves that variable out of the closure. This
/// likely makes the closure `FnOnce`, which fails the `Fn` bound on
/// `SimpleWebSocketHandler::new`. Clone such values inside `body`, or construct the
/// handler directly.
#[macro_export]
macro_rules! websocket_handler {
(|$id:ident: ConnectionId, $conn:ident: Arc<WebSocketConnection>| $body:expr) => {
SimpleWebSocketHandler::new(
|$id: ConnectionId, $conn: Arc<WebSocketConnection>| async move { $body },
)
};
}