use std::future::Future;
use std::marker::PhantomData;
use axum::extract::ws::{Message, WebSocket, WebSocketUpgrade};
use axum::response::Response;
use axum::routing::{self, MethodRouter};
use futures_util::stream::StreamExt;
use futures_util::SinkExt;
use crate::codec::WsMessage;
use crate::connection::{BoxFuture, ErasedSink, ErasedStream, WsConnection};
use crate::WsEndpoint;
/// Typed WebSocket connection for endpoint `E`.
///
/// This is a [`WsConnection`] parameterized by the endpoint's `ServerMsg` and `ClientMsg`
/// associated types, in that order. NOTE(review): presumably these are the outgoing
/// (server→client) and incoming (client→server) message types; confirm against `WsConnection`.
pub type Connection<E> = WsConnection<<E as WsEndpoint>::ServerMsg, <E as WsEndpoint>::ClientMsg>;
/// Builds a stateless GET route that upgrades the request to a WebSocket and runs
/// `callback` with the resulting [`Connection<E>`].
///
/// The upgrade itself is delegated to [`upgrade`]. Use [`handler_with_state`] when the
/// callback also needs the router state.
pub fn handler<E, F, Fut>(callback: F) -> MethodRouter
where
    E: WsEndpoint,
    F: FnOnce(Connection<E>) -> Fut + Clone + Send + Sync + 'static,
    Fut: Future<Output = ()> + Send + 'static,
{
    // `async move` takes ownership of `callback`, so each invocation of this handler
    // consumes its own copy of the callback.
    let on_get = move |ws: WebSocketUpgrade| async move { upgrade::<E, F, Fut>(ws, callback) };
    routing::get(on_get)
}
pub fn handler_with_state<E, F, Fut, S>(callback: F) -> MethodRouter<S>
where
E: WsEndpoint,
S: Clone + Send + Sync + 'static,
F: FnOnce(Connection<E>, S) -> Fut + Clone + Send + Sync + 'static,
Fut: Future<Output = ()> + Send + 'static,
{
routing::get(
move |ws: WebSocketUpgrade, axum::extract::State(state): axum::extract::State<S>| async move {
let cb = callback;
ws.on_upgrade(move |socket| async move {
let conn = wrap_axum_socket::<E>(socket);
cb(conn, state).await;
})
},
)
}
pub fn upgrade<E, F, Fut>(ws: WebSocketUpgrade, callback: F) -> Response
where
E: WsEndpoint,
F: FnOnce(Connection<E>) -> Fut + Send + 'static,
Fut: Future<Output = ()> + Send + 'static,
{
ws.on_upgrade(move |socket| async move {
let conn = wrap_axum_socket::<E>(socket);
callback(conn).await;
})
}
/// Converts an already-upgraded axum [`WebSocket`] into a typed [`Connection<E>`].
///
/// The socket is split into its write and read halves. These are boxed behind the crate's
/// type-erased sink and stream adapters.
pub fn into_connection<E: WsEndpoint>(socket: WebSocket) -> Connection<E> {
    let (write_half, read_half) = socket.split();
    WsConnection {
        sink: Box::new(AxumSink(write_half)),
        stream: Box::new(AxumStream(read_half)),
        _types: PhantomData,
    }
}

/// Internal alias for [`into_connection`], used by the upgrade paths in this module.
fn wrap_axum_socket<E: WsEndpoint>(socket: WebSocket) -> Connection<E> {
    into_connection::<E>(socket)
}
/// [`ErasedSink`] adapter over the write half of a split axum [`WebSocket`].
struct AxumSink(futures_util::stream::SplitSink<WebSocket, Message>);

impl ErasedSink for AxumSink {
    /// Maps a codec-level [`WsMessage`] to the matching axum frame and sends it.
    ///
    /// Any transport error is collapsed into `Err(())`.
    fn send(&mut self, msg: WsMessage) -> BoxFuture<'_, Result<(), ()>> {
        // The conversion is pure, so doing it before the future is polled is not observable.
        let frame = match msg {
            WsMessage::Binary(bytes) => Message::Binary(bytes.into()),
            WsMessage::Text(text) => Message::Text(text.into()),
        };
        let sink = &mut self.0;
        Box::pin(async move { sink.send(frame).await.map_err(|_| ()) })
    }

    /// Sends a `Close` frame with no close code or reason.
    ///
    /// Any transport error is collapsed into `Err(())`.
    fn close(&mut self) -> BoxFuture<'_, Result<(), ()>> {
        let sink = &mut self.0;
        Box::pin(async move {
            let outcome = sink.send(Message::Close(None)).await;
            outcome.map_err(|_| ())
        })
    }
}
/// [`ErasedStream`] adapter over the read half of a split axum [`WebSocket`].
struct AxumStream(futures_util::stream::SplitStream<WebSocket>);

impl ErasedStream for AxumStream {
    /// Yields the next text or binary frame as a [`WsMessage`].
    ///
    /// Ping and Pong frames are skipped. A `Close` frame, or the end of the underlying
    /// stream, yields `None`. A transport error yields `Some(Err(()))`.
    fn next(&mut self) -> BoxFuture<'_, Option<Result<WsMessage, ()>>> {
        let inner = &mut self.0;
        Box::pin(async move {
            while let Some(item) = inner.next().await {
                let frame = match item {
                    Ok(frame) => frame,
                    Err(_) => return Some(Err(())),
                };
                let decoded = match frame {
                    Message::Text(text) => WsMessage::Text(text.to_string()),
                    Message::Binary(bytes) => WsMessage::Binary(bytes.to_vec()),
                    // Control frames carry no application payload; keep reading.
                    Message::Ping(_) | Message::Pong(_) => continue,
                    Message::Close(_) => return None,
                };
                return Some(Ok(decoded));
            }
            None
        })
    }
}