use std::net::SocketAddr;
use axum::body::Bytes;
use axum::extract::State;
use axum::extract::connect_info::ConnectInfo;
use axum::extract::ws::{Message, WebSocket, WebSocketUpgrade};
use axum::response::IntoResponse;
use axum::routing::{MethodRouter, get};
use futures::channel::mpsc::{UnboundedReceiver, UnboundedSender, unbounded};
use futures::future::{Either, select};
use futures::{FutureExt, SinkExt, StreamExt};
use crate::client::Session;
use crate::server::{LocalSession, Server, SessionHandler};
/// Boxed, thread-safe error used throughout this WebSocket transport, so that
/// socket errors, channel errors and ad-hoc string messages share one type.
type PerspectiveWSError = Box<dyn std::error::Error + Sync + Send>;
/// One step of work for the connection loop, produced by racing the client
/// socket against the session's outgoing response queue.
enum PerspectiveWSMessage {
    /// A binary request payload received from the client.
    Incoming(Bytes),
    /// A response payload queued by the session, to be written to the client.
    Outgoing(Bytes),
    /// Either side has finished; the loop should stop.
    End,
}
/// Outgoing half of a Perspective WebSocket connection.
///
/// Wraps the sending side of an unbounded channel. Every response the session
/// emits is copied into an owned [`Bytes`] buffer and queued; the receiving
/// side is drained by `process_message_loop`, which writes each buffer to the
/// socket as a binary frame.
#[derive(Clone)]
struct PerspectiveWSConnection(UnboundedSender<Bytes>);

impl SessionHandler for PerspectiveWSConnection {
    /// Queues `resp` for delivery to the client.
    ///
    /// # Errors
    ///
    /// Fails if the channel's receiving side has been dropped.
    async fn send_response<'a>(&'a mut self, resp: &'a [u8]) -> Result<(), PerspectiveWSError> {
        let payload = Bytes::copy_from_slice(resp);
        self.0.send(payload).await?;
        Ok(())
    }
}
/// Drives a single WebSocket connection until either side ends it.
///
/// Each iteration waits for whichever is ready first:
/// - a frame from the client on `socket`; binary payloads are passed to
///   `session.handle_request`, or
/// - a response queued on `receiver` by the session's handler, which is
///   written back to the client as a binary frame.
///
/// Returns `Ok(())` when the client sends a `Close` frame, the socket stream
/// ends, or the outgoing channel is closed.
///
/// # Errors
///
/// Returns an error on a socket receive/send failure, when the client sends a
/// non-binary data frame (e.g. text), or when `handle_request` fails.
async fn process_message_loop(
    socket: &mut WebSocket,
    receiver: &mut UnboundedReceiver<Bytes>,
    session: &mut LocalSession,
) -> Result<(), PerspectiveWSError> {
    use Either::*;
    use Message::*;
    use PerspectiveWSMessage::*;

    loop {
        let msg = match select(socket.recv().boxed(), receiver.next()).await {
            Right((Some(bytes), _)) => Ok(Outgoing(bytes)),
            Left((Some(Ok(Binary(bytes))), _)) => Ok(Incoming(bytes)),
            // axum replies to pings on its own but still hands ping/pong frames
            // to the application. They carry no request, so skip them instead
            // of treating them as a protocol error and dropping the client.
            Left((Some(Ok(Ping(_) | Pong(_))), _)) => continue,
            Right((None, _)) | Left((None | Some(Ok(Close(_))), _)) => Ok(End),
            Left((Some(Ok(_)), _)) => Err("Unexpected message type".to_string()),
            Left((Some(Err(err)), _)) => Err(format!("{err}")),
        }?;

        match msg {
            End => break,
            Outgoing(bytes) => socket.send(Binary(bytes)).await?,
            Incoming(bytes) => {
                session.handle_request(&bytes).await?;
            },
        }
    }

    Ok(())
}
pub fn websocket_handler() -> MethodRouter<Server> {
async fn websocket_handler_internal(
ws: WebSocketUpgrade,
State(server): State<Server>,
ConnectInfo(addr): ConnectInfo<SocketAddr>,
) -> impl IntoResponse {
tracing::info!("{addr} Connected.");
ws.on_upgrade(move |mut socket| async move {
let (send, mut receiver) = unbounded::<Bytes>();
let mut session = server.new_session(PerspectiveWSConnection(send)).await;
if let Err(msg) = process_message_loop(&mut socket, &mut receiver, &mut session).await {
tracing::error!("Internal error {}", msg);
}
tracing::info!("{addr} Disconnected.");
session.close().await;
})
}
get(websocket_handler_internal)
}