use std::{
collections::{HashMap, HashSet},
io,
sync::{
atomic::{AtomicUsize, Ordering},
Arc,
},
};
use rand::{rng, Rng as _};
use tokio::sync::mpsc;
use super::{ChatServerHandle, Command, ConnId, Msg, RoomId};
/// Central chat state: the connected sessions and the membership of each room.
///
/// Commands arrive over `cmd_rx` and are processed one at a time by
/// [`ChatServer::run`].
#[derive(Debug)]
pub struct ChatServer {
    /// Outgoing message channel for each connected session, keyed by connection id.
    sessions: HashMap<ConnId, mpsc::UnboundedSender<Msg>>,
    /// Room name -> ids of the connections currently in that room.
    rooms: HashMap<RoomId, HashSet<ConnId>>,
    /// Number of online visitors; incremented on connect, decremented on disconnect.
    visitor_count: Arc<AtomicUsize>,
    /// Receiving half of the command channel; the sending half is placed in the
    /// `ChatServerHandle` returned by `ChatServer::new`.
    cmd_rx: mpsc::UnboundedReceiver<Command>,
}
impl ChatServer {
    /// Creates a new chat server together with the handle used to send it commands.
    ///
    /// The server starts with a single, empty `"main"` room. The returned
    /// [`ChatServerHandle`] holds the sending half of the command channel and the
    /// application `context`; the server keeps the receiving half and only starts
    /// processing commands once [`ChatServer::run`] is awaited.
    pub fn new(
        context: actix_web::web::Data<crate::server::AppContext>,
    ) -> (Self, ChatServerHandle) {
        // Every new connection is placed in "main", so make sure it exists up front.
        let mut rooms = HashMap::with_capacity(4);
        rooms.insert("main".to_owned(), HashSet::new());

        let (cmd_tx, cmd_rx) = mpsc::unbounded_channel();

        (
            Self {
                sessions: HashMap::new(),
                rooms,
                visitor_count: Arc::new(AtomicUsize::new(0)),
                cmd_rx,
            },
            ChatServerHandle { cmd_tx, context },
        )
    }

    /// Sends `msg` to every member of `room` except the connection `skip`.
    ///
    /// Passing `0` as `skip` delivers to every member; `0` is never handed out as a
    /// connection id (see `fresh_conn_id`). Unknown rooms and members without a
    /// registered session are ignored.
    async fn send_system_message(&self, room: &str, skip: ConnId, msg: impl Into<String>) {
        let Some(members) = self.rooms.get(room) else {
            return;
        };
        let msg = msg.into();
        for conn_id in members.iter().filter(|id| **id != skip) {
            if let Some(tx) = self.sessions.get(conn_id) {
                // `UnboundedSender::send` only fails when the receiver has been
                // dropped; delivery is deliberately best-effort.
                let _ = tx.send(msg.clone());
            }
        }
    }

    /// Relays a chat message from `conn` to the other members of its room.
    ///
    /// The sender does not receive its own message. If `conn` is not in any room
    /// the message is dropped.
    async fn send_message(&self, conn: ConnId, msg: impl Into<String>) {
        let room = self
            .rooms
            .iter()
            .find_map(|(name, members)| members.contains(&conn).then_some(name));
        if let Some(room) = room {
            self.send_system_message(room, conn, msg).await;
        }
    }

    /// Registers a new connection, places it in the `"main"` room and returns its id.
    ///
    /// Existing members of `"main"` are told that someone joined. Then everyone in
    /// `"main"`, including the newcomer, receives the updated visitor total.
    async fn connect(&mut self, tx: mpsc::UnboundedSender<Msg>) -> ConnId {
        log::info!("Someone joined");
        self.send_system_message("main", 0, "Someone joined").await;

        let id = self.fresh_conn_id();
        self.sessions.insert(id, tx);
        self.rooms.entry("main".to_owned()).or_default().insert(id);

        let count = self.visitor_count.fetch_add(1, Ordering::SeqCst) + 1;
        self.send_system_message("main", 0, format!("Total online visitors: {count}"))
            .await;
        id
    }

    /// Picks a random connection id that is non-zero and not currently in use.
    ///
    /// `0` is reserved as the "skip nobody" value of `send_system_message`. Reusing
    /// a live id would silently overwrite another client's session in `sessions`.
    fn fresh_conn_id(&self) -> ConnId {
        loop {
            // Truncation on 32-bit targets is harmless: only uniqueness matters.
            let id = rng().random::<u64>() as usize;
            if id != 0 && !self.sessions.contains_key(&id) {
                return id;
            }
        }
    }

    /// Removes `conn_id` from every room it belongs to and returns those rooms' names.
    fn leave_all_rooms(&mut self, conn_id: ConnId) -> Vec<String> {
        self.rooms
            .iter_mut()
            .filter_map(|(name, members)| members.remove(&conn_id).then(|| name.clone()))
            .collect()
    }

    /// Unregisters connection `conn_id` and notifies the rooms it was in.
    ///
    /// Disconnecting an id that is not (or no longer) registered is a no-op. This
    /// prevents a duplicated `Disconnect` from wrapping the visitor counter past zero.
    async fn disconnect(&mut self, conn_id: ConnId) {
        if self.sessions.remove(&conn_id).is_none() {
            return;
        }
        log::info!("Someone disconnected");
        self.visitor_count.fetch_sub(1, Ordering::SeqCst);

        for room in self.leave_all_rooms(conn_id) {
            self.send_system_message(&room, 0, "Someone disconnected")
                .await;
        }
    }

    /// Returns the names of all rooms, in no particular order.
    fn list_rooms(&self) -> Vec<String> {
        self.rooms.keys().cloned().collect()
    }

    /// Moves `conn_id` into `room`, creating the room if it does not exist yet.
    ///
    /// The connection first leaves every room it is currently in, and those rooms
    /// are told "Someone disconnected". Then the other members of `room` are told
    /// "Someone connected". Ids without a registered session are ignored so that a
    /// stale id cannot leave a phantom member behind.
    async fn join_room(&mut self, conn_id: ConnId, room: String) {
        if !self.sessions.contains_key(&conn_id) {
            return;
        }
        for left in self.leave_all_rooms(conn_id) {
            self.send_system_message(&left, 0, "Someone disconnected")
                .await;
        }
        self.rooms.entry(room.clone()).or_default().insert(conn_id);
        self.send_system_message(&room, conn_id, "Someone connected")
            .await;
    }

    /// Processes commands sequentially until the command channel is closed,
    /// i.e. until every sender has been dropped. Always returns `Ok(())`.
    pub async fn run(mut self) -> io::Result<()> {
        while let Some(cmd) = self.cmd_rx.recv().await {
            // Reply send failures are ignored below: presumably the requester
            // stopped waiting for the answer.
            match cmd {
                Command::Connect { conn_tx, res_tx } => {
                    let conn_id = self.connect(conn_tx).await;
                    let _ = res_tx.send(conn_id);
                }
                Command::Disconnect { conn } => {
                    self.disconnect(conn).await;
                }
                Command::List { res_tx } => {
                    let _ = res_tx.send(self.list_rooms());
                }
                Command::Join { conn, room, res_tx } => {
                    self.join_room(conn, room).await;
                    let _ = res_tx.send(());
                }
                Command::Message { conn, msg, res_tx } => {
                    self.send_message(conn, msg).await;
                    let _ = res_tx.send(());
                }
            }
        }
        Ok(())
    }
}