use std::net::SocketAddr;
use axum::extract::connect_info::ConnectInfo;
use axum::extract::ws::{Message, WebSocket, WebSocketUpgrade};
use axum::routing::{MethodRouter, get};
use perspective_client::virtual_server::{VirtualServer, VirtualServerHandler};
/// Boxed, thread-safe error type exported by this module.
pub type PSPError = Box<dyn std::error::Error + Send + Sync>;
/// Private alias for [`PSPError`]; the error type of the WebSocket message loop.
type PerspectiveWSError = PSPError;
async fn process_message_loop(
socket: &mut WebSocket,
handler: impl VirtualServerHandler,
) -> Result<(), PerspectiveWSError> {
use Message::*;
let mut processor = VirtualServer::new(handler);
loop {
match socket.recv().await {
Some(Ok(Binary(msg))) => {
socket
.send(Binary(processor.handle_request(msg).await?))
.await?
},
Some(_) | None => {
tracing::debug!("Unexpected msg");
break;
},
};
}
Ok(())
}
pub fn custom_websocket_handler<S, T>(handler: T) -> MethodRouter<S>
where
T: VirtualServerHandler + Clone + Send + Sync + 'static,
S: Clone + Send + Sync + 'static,
{
let websocket_handler_internal = async |ws: WebSocketUpgrade,
ConnectInfo(addr): ConnectInfo<SocketAddr>|
-> axum::response::Response {
tracing::info!("{addr} Connected.");
ws.on_upgrade(move |mut socket| async move {
if let Err(msg) = process_message_loop(&mut socket, handler).await {
tracing::error!("Internal error {}", msg);
}
tracing::info!("{addr} Disconnected.");
})
};
get(websocket_handler_internal)
}