use std::{
collections::{BTreeMap, BTreeSet},
io,
sync::{
Arc,
atomic::{AtomicUsize, Ordering},
},
};
use moosicbox_async_service::async_trait;
use moosicbox_ws::{
PlayerAction, WebsocketContext, WebsocketDisconnectError, WebsocketMessageError,
WebsocketSendError, WebsocketSender,
};
use serde_json::Value;
use strum_macros::AsRefStr;
use switchy_async::sync::{RwLock, mpsc, oneshot};
use switchy_database::{config::ConfigDatabase, profiles::PROFILES};
use tokio_util::sync::CancellationToken;
use crate::ws::{ConnId, Msg, RoomId};
#[async_trait]
impl WebsocketSender for WsServer {
async fn send(&self, connection_id: &str, data: &str) -> Result<(), WebsocketSendError> {
let id = connection_id.parse::<ConnId>().unwrap();
log::debug!("Sending to {id}");
self.send_message_to(id, data.to_string());
for sender in &self.senders {
sender.send(connection_id, data).await?;
}
Ok(())
}
async fn send_all(&self, data: &str) -> Result<(), WebsocketSendError> {
self.send_system_message("main", 0, data.to_string());
for sender in &self.senders {
sender.send_all(data).await?;
}
Ok(())
}
async fn send_all_except(
&self,
connection_id: &str,
data: &str,
) -> Result<(), WebsocketSendError> {
self.send_system_message(
"main",
connection_id.parse::<ConnId>().unwrap(),
data.to_string(),
);
for sender in &self.senders {
sender.send_all_except(connection_id, data).await?;
}
Ok(())
}
async fn ping(&self) -> Result<(), WebsocketSendError> {
self.ping_system();
for sender in &self.senders {
sender.ping().await?;
}
Ok(())
}
}
/// A request sent from a [`WsServerHandle`] to the [`WsServer`] command loop.
///
/// `WsServer::process_command` answers variants that carry a `res_tx` once the command has
/// been handled. `Send`, `Broadcast`, `BroadcastExcept` and `Message` are acknowledged with `()`
/// even when handling failed; the failure is only reported through `die_or_error!`.
/// `Disconnect` and `AddPlayerAction` get no reply.
#[derive(Debug, AsRefStr)]
pub enum Command {
    /// Register a player action with the server (only built with the `player` feature).
    #[cfg(feature = "player")]
    AddPlayerAction {
        id: u64,
        action: PlayerAction,
    },
    /// Register a new connection for `profile`. Outgoing text for the connection is pushed
    /// into `conn_tx`, and the newly assigned id is returned through `res_tx`.
    Connect {
        profile: String,
        conn_tx: mpsc::Sender<Msg>,
        res_tx: oneshot::Sender<ConnId>,
    },
    /// Remove connection `conn` from the server.
    Disconnect {
        conn: ConnId,
    },
    /// Deliver `msg` to the single connection `conn`.
    Send {
        msg: Msg,
        conn: ConnId,
        res_tx: oneshot::Sender<()>,
    },
    /// Deliver `msg` to every connection in the `"main"` room.
    Broadcast {
        msg: Msg,
        res_tx: oneshot::Sender<()>,
    },
    /// Deliver `msg` to every connection in the `"main"` room except `conn`.
    BroadcastExcept {
        msg: Msg,
        conn: ConnId,
        res_tx: oneshot::Sender<()>,
    },
    /// Process `msg` as an incoming message received from connection `conn`.
    Message {
        msg: Msg,
        conn: ConnId,
        res_tx: oneshot::Sender<()>,
    },
}
impl std::fmt::Display for Command {
    /// Renders only the variant name (for example `Connect`) and never the payload, so it is
    /// safe to use in debug-level logs.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let variant_name: &str = self.as_ref();
        write!(f, "{variant_name}")
    }
}
/// Server-side bookkeeping for one registered websocket connection.
#[derive(Debug, Clone)]
struct Connection {
    /// Profile name supplied when the connection was registered; passed on to
    /// `moosicbox_ws` in the per-message context.
    profile: String,
    /// Channel the server pushes outgoing text for this connection into.
    sender: mpsc::Sender<Msg>,
}
/// Websocket server state, driven by [`Command`]s received from a [`WsServerHandle`].
#[allow(clippy::module_name_repetitions)]
#[derive(Debug)]
pub struct WsServer {
    /// Registered connections, keyed by their id.
    connections: BTreeMap<ConnId, Connection>,
    /// Database handed to `moosicbox_ws` when processing messages and disconnects.
    config_db: ConfigDatabase,
    /// Room name -> ids of its members. `new` creates `"main"`, and every connection joins it.
    rooms: BTreeMap<RoomId, BTreeSet<ConnId>>,
    /// Profile name -> connection ids. Seeded from `PROFILES` in `new`, but not otherwise
    /// read or updated by the code in this impl.
    #[allow(unused)]
    profiles: BTreeMap<String, BTreeSet<ConnId>>,
    /// Number of connected visitors; adjusted in `connect` / `disconnect`.
    visitor_count: Arc<AtomicUsize>,
    /// Receiving end of the command channel whose sender lives in `WsServerHandle`.
    cmd_rx: flume::Receiver<Command>,
    /// Extra senders that every outgoing message is also forwarded to.
    senders: Vec<Box<dyn WebsocketSender>>,
    /// Player actions copied into every `WebsocketContext` the server builds.
    player_actions: Vec<(u64, PlayerAction)>,
    /// Cancelling this token stops the `run` loop.
    token: CancellationToken,
}
impl WsServer {
    /// Creates a new server and the [`WsServerHandle`] used to send it commands.
    ///
    /// The server starts with an empty `"main"` room, which every connection joins in
    /// `connect`. It also creates an empty connection set for each profile name returned by
    /// `PROFILES.names()`. The server and handle share an unbounded command channel and a
    /// cancellation token. No command is processed until [`WsServer::run`] is awaited.
    #[must_use]
    pub fn new(config_db: ConfigDatabase) -> (Self, WsServerHandle) {
        let mut rooms = BTreeMap::new();
        rooms.insert("main".to_owned(), BTreeSet::new());

        let mut profiles = BTreeMap::new();
        for profile in PROFILES.names() {
            profiles.insert(profile, BTreeSet::new());
        }

        let (cmd_tx, cmd_rx) = flume::unbounded();
        let token = CancellationToken::new();

        let handle = WsServerHandle {
            cmd_tx,
            token: token.clone(),
        };

        (
            Self {
                connections: BTreeMap::new(),
                config_db,
                rooms,
                profiles,
                visitor_count: Arc::new(AtomicUsize::new(0)),
                cmd_rx,
                senders: vec![],
                player_actions: vec![],
                token,
            },
            handle,
        )
    }

    /// Registers a player action that is copied into every subsequent `WebsocketContext`.
    #[cfg(feature = "player")]
    pub fn add_player_action(&mut self, id: u64, action: PlayerAction) {
        self.player_actions.push((id, action));
    }

    /// Adds an extra sender that every outgoing message is also forwarded to.
    #[cfg(feature = "tunnel")]
    pub fn add_sender(&mut self, sender: Box<dyn WebsocketSender>) {
        self.senders.push(sender);
    }

    /// Server-side handling of a ping; currently this only emits a trace log.
    #[allow(clippy::unused_self)]
    fn ping_system(&self) {
        log::trace!("ping_system: pong");
    }

    /// Sends `msg` to every connection in `room` except `skip`.
    ///
    /// Unknown rooms, unknown members and send failures are ignored (best effort).
    fn send_system_message(&self, room: &str, skip: ConnId, msg: impl Into<String>) {
        if let Some(sessions) = self.rooms.get(room) {
            let msg = msg.into();
            for conn_id in sessions {
                if *conn_id != skip
                    && let Some(Connection { sender, .. }) = self.connections.get(conn_id)
                {
                    let _ = sender.send(msg.clone());
                }
            }
        }
    }

    /// Sends `msg` to connection `id`; unknown ids and send failures are ignored.
    fn send_message_to(&self, id: ConnId, msg: impl Into<String>) {
        if let Some(Connection { sender, .. }) = self.connections.get(&id) {
            let _ = sender.send(msg.into());
        }
    }

    /// Parses an incoming message from connection `id` as JSON and dispatches it through
    /// `moosicbox_ws::process_message`, with this server acting as the reply sender.
    ///
    /// `run` processes commands in independently spawned tasks, so a `Message` can be handled
    /// after the matching `Disconnect`. Such messages are dropped with a warning instead of
    /// panicking on the missing connection.
    ///
    /// # Errors
    ///
    /// * [`WebsocketMessageError::InvalidPayload`] if `msg` is not valid JSON
    /// * any error returned by `moosicbox_ws::process_message`
    async fn on_message(
        &self,
        id: ConnId,
        msg: impl Into<String> + Send,
    ) -> Result<(), WebsocketMessageError> {
        let connection_id = id.to_string();
        let Some(profile) = self.connections.get(&id).map(|c| c.profile.clone()) else {
            log::warn!(
                "on_message: no connection registered for connection_id={connection_id}; dropping message"
            );
            return Ok(());
        };

        log::trace!(
            "on_message connection_id={connection_id} player_actions.len={}",
            self.player_actions.len()
        );

        let context = WebsocketContext {
            connection_id,
            profile: Some(profile),
            player_actions: self.player_actions.clone(),
        };

        let payload = msg.into();
        let body = serde_json::from_str::<Value>(&payload)
            .map_err(|e| WebsocketMessageError::InvalidPayload(payload, e.to_string()))?;

        moosicbox_ws::process_message(&self.config_db, body, context, self).await?;

        Ok(())
    }

    /// Picks a random connection id that is non-zero and not currently in use.
    ///
    /// `send_all` passes `0` as the "skip nobody" id, so a connection with id `0` would miss
    /// broadcasts. Reusing a live id would silently replace that connection's sender.
    fn next_conn_id(&self) -> ConnId {
        loop {
            let candidate = switchy_random::rng().next_u64();
            if candidate != 0 && !self.connections.contains_key(&candidate) {
                return candidate;
            }
        }
    }

    /// Registers a new connection for `profile`, adds it to the `"main"` room, bumps the
    /// visitor count and notifies `moosicbox_ws`. Returns the new connection id.
    fn connect(&mut self, profile: String, tx: mpsc::Sender<Msg>) -> ConnId {
        log::debug!("Someone joined");

        let id = self.next_conn_id();

        self.connections.insert(
            id,
            Connection {
                profile: profile.clone(),
                sender: tx,
            },
        );

        self.rooms.entry("main".to_owned()).or_default().insert(id);

        let count = self.visitor_count.fetch_add(1, Ordering::SeqCst);
        log::debug!("Visitor count: {}", count + 1);

        let connection_id = id.to_string();
        let context = WebsocketContext {
            connection_id,
            profile: Some(profile),
            player_actions: self.player_actions.clone(),
        };

        let _ = moosicbox_ws::connect(self, &context);

        id
    }

    /// Removes connection `conn_id` from the server and every room, then notifies
    /// `moosicbox_ws`.
    ///
    /// The visitor count is decremented only when the connection was actually registered. A
    /// repeated or unknown disconnect therefore cannot wrap the counter below zero.
    ///
    /// # Errors
    ///
    /// * any error returned by `moosicbox_ws::disconnect`
    async fn disconnect(&mut self, conn_id: ConnId) -> Result<(), WebsocketDisconnectError> {
        log::debug!("Someone disconnected {conn_id}");

        if self.connections.remove(&conn_id).is_some() {
            let count = self.visitor_count.fetch_sub(1, Ordering::SeqCst);
            log::debug!("Visitor count: {}", count.saturating_sub(1));
            for sessions in self.rooms.values_mut() {
                sessions.remove(&conn_id);
            }
        } else {
            log::debug!("Disconnect for unknown connection {conn_id}; visitor count unchanged");
        }

        let connection_id = conn_id.to_string();
        let context = WebsocketContext {
            connection_id,
            profile: None,
            player_actions: self.player_actions.clone(),
        };

        moosicbox_ws::disconnect(&self.config_db, self, &context).await?;

        Ok(())
    }

    /// Executes a single command against the shared server state.
    ///
    /// `run` spawns one task per command, so commands may complete out of order. Reply
    /// channels are acknowledged even when handling fails; failures are reported through
    /// `die_or_error!`.
    ///
    /// # Errors
    ///
    /// * if the requester of a `Connect` command dropped its reply receiver
    #[allow(clippy::cognitive_complexity)]
    async fn process_command(ctx: Arc<RwLock<Self>>, cmd: Command) -> io::Result<()> {
        let cmd_str = cmd.to_string();

        if log::log_enabled!(log::Level::Trace) {
            log::trace!("process_command: cmd={cmd:?}");
        } else {
            log::debug!("process_command: cmd={cmd_str}");
        }

        match cmd {
            #[cfg(feature = "player")]
            Command::AddPlayerAction { id, action } => {
                ctx.write().await.add_player_action(id, action);
                log::debug!("Added a player action with id={id}");
            }
            Command::Connect {
                profile,
                conn_tx,
                res_tx,
            } => {
                let conn_id = ctx.write().await.connect(profile, conn_tx);

                res_tx
                    .send(conn_id)
                    .map_err(|e| std::io::Error::other(format!("Failed to send: {e:?}")))?;
            }
            Command::Disconnect { conn } => {
                let response = ctx.write().await.disconnect(conn).await;

                if let Err(error) = response {
                    moosicbox_assert::die_or_error!(
                        "Failed to disconnect connection {conn}: {:?}",
                        error
                    );
                }
            }
            Command::Send { msg, conn, res_tx } => {
                let response = ctx.read().await.send(&conn.to_string(), &msg).await;

                if let Err(error) = response {
                    moosicbox_assert::die_or_error!(
                        "Failed to send message to {conn} {msg:?}: {error:?}",
                    );
                }

                let _ = res_tx.send(());
            }
            Command::Broadcast { msg, res_tx } => {
                let response = ctx.read().await.send_all(&msg).await;

                if let Err(error) = response {
                    moosicbox_assert::die_or_error!(
                        "Failed to broadcast message {msg:?}: {error:?}",
                    );
                }

                let _ = res_tx.send(());
            }
            Command::BroadcastExcept { msg, conn, res_tx } => {
                let response = ctx
                    .read()
                    .await
                    .send_all_except(&conn.to_string(), &msg)
                    .await;

                if let Err(error) = response {
                    moosicbox_assert::die_or_error!(
                        "Failed to broadcast message {msg:?}: {error:?}",
                    );
                }

                let _ = res_tx.send(());
            }
            Command::Message { conn, msg, res_tx } => {
                let response = ctx.read().await.on_message(conn, msg.clone()).await;

                if let Err(error) = response {
                    moosicbox_assert::die_or_error!(
                        "Failed to process message from {conn}: {msg:?}: {error:?}"
                    );
                }

                let _ = res_tx.send(());
            }
        }

        log::debug!("process_command: Finished processing cmd {cmd_str}");

        Ok(())
    }

    /// Runs the command loop until the cancellation token fires or the command channel closes.
    ///
    /// Each received command is processed in its own spawned task.
    ///
    /// # Errors
    ///
    /// Currently never returns an error; the `Result` is kept for API stability.
    pub async fn run(self) -> io::Result<()> {
        let token = self.token.clone();
        let cmd_rx = self.cmd_rx.clone();
        let ctx = Arc::new(RwLock::new(self));

        while let Ok(Ok(cmd)) = switchy_async::select!(
            () = token.cancelled() => {
                log::debug!("WsServer was cancelled");
                Err(std::io::Error::new(std::io::ErrorKind::Interrupted, "Cancelled"))
            }
            cmd = cmd_rx.recv_async() => { Ok(cmd) }
        ) {
            log::trace!("Received WsServer command {cmd}");
            switchy_async::runtime::Handle::current().spawn_with_name(
                "server: WsServer process_command",
                Self::process_command(ctx.clone(), cmd),
            );
        }

        log::debug!("Stopped WsServer");

        Ok(())
    }
}
/// Cloneable handle used to send [`Command`]s to a running [`WsServer`].
#[derive(Debug, Clone)]
pub struct WsServerHandle {
    /// Sending end of the server's unbounded command channel.
    cmd_tx: flume::Sender<Command>,
    /// Shared with the server; cancelling it (see `shutdown`) stops the server's `run` loop.
    token: CancellationToken,
}
#[async_trait]
impl WebsocketSender for WsServerHandle {
    /// Queues a `Send` command for `connection_id` and waits for the server to acknowledge it.
    ///
    /// # Errors
    ///
    /// * [`WebsocketSendError::Unknown`] if `connection_id` does not parse as a [`ConnId`]
    async fn send(&self, connection_id: &str, data: &str) -> Result<(), WebsocketSendError> {
        let id = connection_id.parse::<ConnId>().map_err(|e| {
            WebsocketSendError::Unknown(format!("Invalid connection id '{connection_id}': {e}"))
        })?;
        // Resolves to the inherent `WsServerHandle::send` (inherent methods take precedence).
        self.send(id, data.to_string()).await;
        Ok(())
    }

    /// Queues a `Broadcast` command and waits for the server to acknowledge it.
    async fn send_all(&self, data: &str) -> Result<(), WebsocketSendError> {
        if log::log_enabled!(log::Level::Trace) {
            log::trace!("Broadcasting message to all: {data}");
        } else {
            log::debug!("Broadcasting message to all");
        }

        self.broadcast(data.to_string()).await;
        Ok(())
    }

    /// Queues a `BroadcastExcept` command and waits for the server to acknowledge it.
    ///
    /// # Errors
    ///
    /// * [`WebsocketSendError::Unknown`] if `connection_id` does not parse as a [`ConnId`]
    async fn send_all_except(
        &self,
        connection_id: &str,
        data: &str,
    ) -> Result<(), WebsocketSendError> {
        if log::log_enabled!(log::Level::Trace) {
            log::trace!("Broadcasting message to all except {connection_id}: {data}");
        } else {
            log::debug!("Broadcasting message to all except {connection_id}");
        }

        let skip = connection_id.parse::<ConnId>().map_err(|e| {
            WebsocketSendError::Unknown(format!("Invalid connection id '{connection_id}': {e}"))
        })?;
        self.broadcast_except(skip, data.to_string()).await;
        Ok(())
    }

    /// Answers a ping on the handle side.
    ///
    /// `WsServerHandle` has no inherent `ping`, so calling `self.ping()` here would resolve
    /// back to this very trait method and recurse without end. There is no `Ping` command
    /// to forward, so the handle simply acknowledges the ping.
    async fn ping(&self) -> Result<(), WebsocketSendError> {
        log::trace!("WsServerHandle ping: pong");
        Ok(())
    }
}
impl WsServerHandle {
#[cfg(feature = "player")]
pub async fn add_player_action(&self, player_id: u64, action: PlayerAction) {
log::trace!("Sending AddPlayerAction command id={player_id}");
if let Err(e) = self
.cmd_tx
.send_async(Command::AddPlayerAction {
id: player_id,
action,
})
.await
{
moosicbox_assert::die_or_error!("Failed to send command: {e:?}");
}
}
pub async fn connect(&self, profile: String, conn_tx: mpsc::Sender<String>) -> ConnId {
log::trace!("Sending Connect command");
let (res_tx, res_rx) = oneshot::channel();
switchy_async::runtime::Handle::current().spawn_with_name("ws server connect", {
let cmd_tx = self.cmd_tx.clone();
async move {
if let Err(e) = cmd_tx
.send_async(Command::Connect {
profile,
conn_tx,
res_tx,
})
.await
{
moosicbox_assert::die_or_error!("Failed to send command: {e:?}");
}
}
});
res_rx.await.unwrap_or_else(|e| {
moosicbox_assert::die_or_panic!("Failed to recv response from ws server: {e:?}")
})
}
pub async fn send(&self, conn: ConnId, msg: impl Into<String> + Send) {
log::trace!("Sending Send command");
let (res_tx, res_rx) = oneshot::channel();
switchy_async::runtime::Handle::current().spawn_with_name("ws server send", {
let cmd_tx = self.cmd_tx.clone();
let msg = msg.into();
async move {
if let Err(e) = cmd_tx.send_async(Command::Send { msg, conn, res_tx }).await {
moosicbox_assert::die_or_error!("Failed to send command: {e:?}");
}
}
});
res_rx.await.unwrap_or_else(|e| {
moosicbox_assert::die_or_error!("Failed to recv response from ws server: {e:?}");
});
}
pub async fn broadcast(&self, msg: impl Into<String> + Send) {
log::trace!("Sending Broadcast command");
let (res_tx, res_rx) = oneshot::channel();
switchy_async::runtime::Handle::current().spawn_with_name("ws server broadcast", {
let cmd_tx = self.cmd_tx.clone();
let msg = msg.into();
async move {
if let Err(e) = cmd_tx.send_async(Command::Broadcast { msg, res_tx }).await {
moosicbox_assert::die_or_error!("Failed to send command: {e:?}");
}
}
});
res_rx.await.unwrap_or_else(|e| {
moosicbox_assert::die_or_error!("Failed to recv response from ws server: {e:?}");
});
}
pub async fn broadcast_except(&self, conn: ConnId, msg: impl Into<String> + Send) {
log::trace!("Sending BroadcastExcept command");
let (res_tx, res_rx) = oneshot::channel();
switchy_async::runtime::Handle::current().spawn_with_name("ws server broadcast_except", {
let cmd_tx = self.cmd_tx.clone();
let msg = msg.into();
async move {
if let Err(e) = cmd_tx
.send_async(Command::BroadcastExcept { msg, conn, res_tx })
.await
{
moosicbox_assert::die_or_error!("Failed to send command: {e:?}");
}
}
});
res_rx.await.unwrap_or_else(|e| {
moosicbox_assert::die_or_error!("Failed to recv response from ws server: {e:?}");
});
}
pub async fn send_message(&self, conn: ConnId, msg: impl Into<String> + Send) {
log::trace!("Sending Message command");
let (res_tx, res_rx) = oneshot::channel();
switchy_async::runtime::Handle::current().spawn_with_name("ws server send_message", {
let cmd_tx = self.cmd_tx.clone();
let msg = msg.into();
async move {
if let Err(e) = cmd_tx
.send_async(Command::Message { msg, conn, res_tx })
.await
{
moosicbox_assert::die_or_error!("Failed to send command: {e:?}");
}
}
});
res_rx.await.unwrap_or_else(|e| {
moosicbox_assert::die_or_error!("Failed to recv response from ws server: {e:?}");
});
}
pub async fn disconnect(&self, conn: ConnId) {
log::trace!("Sending Disconnect command");
if let Err(e) = self.cmd_tx.send_async(Command::Disconnect { conn }).await {
moosicbox_assert::die_or_error!("Failed to send command: {e:?}");
}
}
pub fn shutdown(&self) {
self.token.cancel();
}
}