use crate::{
SUPPORTED_PROTOCOL_VERSIONS,
server::{CloneableWbApi, common::protocol::Proto},
stats::VERSION,
};
use axum::extract::ws::{Message, WebSocket};
use futures::{
sink::SinkExt,
stream::{SplitSink, StreamExt},
};
use miette::{IntoDiagnostic, Result, miette};
use std::{net::SocketAddr, time::Duration};
use tokio::{spawn, sync::mpsc, time::timeout};
use tracing::{debug, error, info, trace};
use uuid::Uuid;
use worterbuch_common::{Protocol, ServerInfo, ServerMessage, WbApi, Welcome};
/// Drives one WebSocket client session from registration to cleanup.
///
/// The client is first announced to `worterbuch` via `connected`. Only when that
/// succeeds is [`serve_loop`] entered. Failures of the registration or of the loop
/// are logged, not returned. Whatever happened, `disconnected` is called at the end.
///
/// # Errors
///
/// Only an error from `disconnected` is propagated to the caller.
pub(crate) async fn serve(
    client_id: Uuid,
    remote_addr: SocketAddr,
    worterbuch: CloneableWbApi,
    websocket: WebSocket,
) -> Result<()> {
    info!("New client connected: {client_id} ({remote_addr})");

    match worterbuch
        .connected(client_id, Some(remote_addr), Protocol::WS)
        .await
    {
        Err(e) => {
            error!("Error while adding new client: {e}");
        }
        Ok(_) => {
            debug!("Receiving messages from client {client_id} ({remote_addr}) …");
            let session = serve_loop(client_id, remote_addr, worterbuch.clone(), websocket).await;
            if let Err(e) = session {
                error!("Error in serve loop: {e}");
            }
        }
    }

    // Cleanup runs on every path above, including a failed registration.
    info!("Client disconnected: {client_id} ({remote_addr})");
    worterbuch
        .disconnected(client_id, Some(remote_addr))
        .await?;
    Ok(())
}
type WebSocketSender = SplitSink<WebSocket, Message>;
async fn serve_loop(
client_id: Uuid,
remote_addr: SocketAddr,
worterbuch: CloneableWbApi,
websocket: WebSocket,
) -> Result<()> {
let config = worterbuch.config().to_owned();
let authorization_required = config.auth_token_key.is_some();
let send_timeout = config.send_timeout;
let mut authorized = None;
let (ws_tx, mut ws_rx) = websocket.split();
let (ws_send_tx, ws_send_rx) = mpsc::channel(config.channel_buffer_size);
spawn(send_loop(client_id, send_timeout, ws_tx, ws_send_rx));
let supported_protocol_versions = SUPPORTED_PROTOCOL_VERSIONS.into();
ws_send_tx
.send(ServerMessage::Welcome(Welcome {
client_id: client_id.to_string(),
info: ServerInfo::new(
VERSION.to_owned(),
supported_protocol_versions,
authorization_required,
),
}))
.await
.into_diagnostic()?;
let mut proto = Proto::new(
client_id,
ws_send_tx,
authorization_required,
config,
worterbuch,
);
loop {
if let Some(msg) = ws_rx.next().await {
match msg {
Ok(incoming_msg) => {
trace!("Processing incoming message …");
if let Message::Text(text) = incoming_msg {
let msg_processed = proto
.process_incoming_message(&text, &mut authorized)
.await?;
if !msg_processed {
break;
}
}
}
Err(e) => {
error!("Error in WebSocket connection: {e}");
break;
}
}
} else {
info!("WS stream of client {client_id} ({remote_addr}) closed.");
break;
}
}
Ok(())
}
/// Forwards queued [`ServerMessage`]s to the client's WebSocket write half.
///
/// Messages are taken from `ws_send_rx` one at a time and written with
/// [`send_with_timeout`]. The task ends when the channel is closed and drained, or
/// at the first failed send. A failed send is logged, and then both the receiver
/// and the write half are dropped.
async fn send_loop(
    client_id: Uuid,
    send_timeout: Option<Duration>,
    mut ws_tx: SplitSink<WebSocket, Message>,
    mut ws_send_rx: mpsc::Receiver<ServerMessage>,
) {
    while let Some(outgoing) = ws_send_rx.recv().await {
        match send_with_timeout(outgoing, &mut ws_tx, send_timeout, client_id).await {
            Ok(()) => {}
            Err(e) => {
                error!("Error sending WS message: {e}");
                return;
            }
        }
    }
}
/// Serializes `msg` to JSON and writes it to `websocket` as a text frame.
///
/// If `send_timeout` is `Some`, the write must complete within that duration.
/// When it does not, the timeout is logged and an error is returned. If it is
/// `None`, the write is awaited without a limit.
///
/// # Errors
///
/// Returns an error if JSON serialization fails, if the WebSocket write fails, or
/// if the send timeout elapses.
async fn send_with_timeout(
    msg: ServerMessage,
    websocket: &mut WebSocketSender,
    send_timeout: Option<Duration>,
    client_id: Uuid,
) -> Result<()> {
    let json = serde_json::to_string(&msg).into_diagnostic()?;
    let msg = Message::Text(json.into());
    if let Some(send_timeout) = send_timeout {
        // Debug formatting keeps sub-second precision ("500ms", "1.5s").
        // `as_secs()` truncated such values, e.g. logging 500ms as "0s".
        trace!("Sending with timeout {send_timeout:?} …");
        let Ok(sent) = timeout(send_timeout, websocket.send(msg)).await else {
            error!("Send timeout for client {client_id}");
            return Err(miette!("Send timeout for client {client_id}"));
        };
        sent.into_diagnostic()?;
    } else {
        trace!("Sending without timeout …");
        websocket.send(msg).await.into_diagnostic()?;
    }
    trace!("Sending done.");
    Ok(())
}