use {
futures::{FutureExt, TryFutureExt},
id_pool::IdPool as PortPool,
std::collections::{HashMap, HashSet},
std::sync::Arc,
std::sync::Mutex,
std::sync::atomic::{AtomicU32, Ordering},
tokio::sync::Barrier,
tracing::{Instrument, instrument},
warp::ws::{Message, WebSocket},
};
use crate::{
channel::Channel,
cli::Config,
connection,
envvars::{Env, replace_template_env},
error::AppResult,
metrics::Metrics,
process,
types::{CacheBuffer, ConnID, Event, EventRx, EventTx, PortID, ProcessSenders, RoomID},
};
/// Active websocket connection IDs, grouped by room.
type ConnectionMap = HashMap<RoomID, HashSet<ConnID>>;
/// Channel sender handles for the process serving each room.
type ProcessMap = HashMap<RoomID, ProcessSenders>;
/// Shared buffer of cached process output per room, replayed to new clients.
type ProcessCacheMap = HashMap<RoomID, Arc<Mutex<CacheBuffer>>>;
/// Mutable server state owned by the event loop in `handle`.
struct State {
    /// Monotonic counter used to hand out the next connection ID.
    pub conns_next_id: AtomicU32,
    /// Connected client IDs per room.
    pub conns: ConnectionMap,
    /// Cached process output per room, replayed to newly attached clients.
    pub cache: ProcessCacheMap,
    /// Sender handles to each room's running process.
    pub procs: ProcessMap,
    /// Pool of reservable TCP ports, if a port range was configured.
    pub ports: Option<PortPool>,
    /// Parsed CLI configuration.
    pub cfg: Config,
}
/// Main event loop: receives [`Event`]s and dispatches connect, disconnect,
/// process-exit, process-metadata and shutdown handling.
///
/// Always returns `Err(())` once the loop ends — presumably so the caller
/// can use the error to tear down sibling tasks; confirm at the call site.
#[instrument(name = "event", skip_all)]
pub async fn handle(
    tx: EventTx,
    mut rx: EventRx,
    config: Config,
    metrics: Metrics,
) -> Result<(), ()> {
    let is_oneshot = config.oneshot;
    // No configured maximum means effectively unlimited rooms.
    let max_procs = config.max_rooms.unwrap_or(usize::MAX);
    let mut state = State::new(config);
    while let Some(event) = rx.recv().await {
        match event {
            // Connect to a room whose process is already running.
            // NOTE: this guarded arm must stay before the fallback arm below.
            Event::Connect { room, ws, env } if state.procs.contains_key(&room) => {
                if is_oneshot {
                    tracing::warn!("client rejected, no connections permitted in oneshot mode");
                    let _ = ws.close().await;
                    continue;
                }
                metrics.inc_ws_connections(&room);
                attach(room, env, ws, &tx, &mut state, None);
            }
            // Connect to a new room: spawn its process, then attach the client.
            Event::Connect { room, ws, env } => {
                if state.procs.len() >= max_procs {
                    tracing::warn!("client rejected, maximum number of rooms reached");
                    let _ = ws.close().await;
                    continue;
                }
                metrics.inc_ws_connections(&room);
                // Two-party barrier shared between the process task and the
                // connection task; presumably it delays message forwarding
                // until the process is ready — confirm in process/connection
                // handlers.
                let spawn_barrier = Some(Arc::new(Barrier::new(2)));
                let attach_barrier = spawn_barrier.clone();
                spawn(&room, &env, &tx, &mut state, spawn_barrier).ok();
                attach(room, env, ws, &tx, &mut state, attach_barrier);
            }
            Event::Disconnect { room, conn, env } => {
                metrics.dec_ws_connections(&room);
                disconnect(room, env, conn, &mut state);
                // In oneshot mode the server stops after the first session ends.
                if is_oneshot {
                    break;
                }
            }
            Event::ProcessExit { room, code, port } => {
                metrics.clear(&room);
                exit(room, code, port, &mut state);
                if is_oneshot {
                    break;
                }
            }
            Event::ProcessMeta { room, value } => {
                metrics.set_metadata(&room, value);
            }
            Event::Shutdown => {
                break;
            }
        }
    }
    shutdown(state);
    Err(())
}
impl State {
    /// Creates empty server state from the parsed CLI configuration.
    ///
    /// A TCP port pool is only set up when a port range was configured.
    pub fn new(cfg: Config) -> Self {
        // Build the port pool before `cfg` is moved into the struct.
        let ports = cfg.tcp_ports.clone().map(PortPool::new_ranged);
        Self {
            cfg,
            ports,
            conns: HashMap::new(),
            procs: HashMap::new(),
            cache: HashMap::new(),
            conns_next_id: AtomicU32::new(1),
        }
    }

    /// Returns a fresh, unique connection identifier.
    pub fn new_conn_id(&self) -> ConnID {
        self.conns_next_id.fetch_add(1, Ordering::Relaxed)
    }
}
/// Attaches a websocket client to an already-registered room process.
///
/// Subscribes the connection to the process broadcast channel, snapshots any
/// cached output for replay, registers the connection ID, optionally sends
/// the configured join message, and spawns the connection task.
///
/// # Panics
/// Panics if `room` has no entry in `state.procs`; callers must spawn the
/// process first.
#[instrument(name = "attach", skip(env, ws, tx, state, barrier))]
fn attach(
    room: RoomID,
    env: Env,
    ws: Box<WebSocket>,
    tx: &EventTx,
    state: &mut State,
    barrier: Option<Arc<Barrier>>,
) {
    let conn = state.new_conn_id();
    // Framing mode is derived from the CLI configuration via `From`/`Into`.
    let framing = (&state.cfg).into();
    let (proc_tx_broadcast, proc_tx, _) = state.procs.get(&room).expect("room not in process map");
    let proc_rx = proc_tx_broadcast.subscribe();
    // Copy the room's cached output under the lock so the new client can
    // catch up on messages sent before it connected.
    let cache = match state.cache.get(&room) {
        Some(shared) => shared.lock().expect("poisoned lock").to_vec(),
        None => Vec::new(),
    };
    let mut on_init = || {
        let is_inserted = state
            .conns
            .entry(room.to_string())
            .or_default()
            .insert(conn);
        // Only announce connections that were not already registered.
        if is_inserted {
            tracing::info!(id = conn, "client connected");
            if let Some(ref join_msg_template) = state.cfg.join_msg {
                let join_msg = replace_template_env(join_msg_template, conn, &env);
                let _ = proc_tx.send(Message::text(join_msg));
            }
        }
    };
    // Builds the continuation run when the connection task finishes: it
    // feeds a Disconnect event back into the main loop.
    let on_disconnect = || {
        let tx = tx.clone();
        let room = room.clone();
        let env = env.clone();
        move |_| {
            tracing::debug!(id = conn, "client disconnecting");
            let _ = tx.send(Event::Disconnect { room, conn, env });
            futures::future::ready(())
        }
    };
    tokio::spawn(
        // NOTE: the block argument to `then` runs `on_init()` eagerly (before
        // the connection future completes); only the disconnect closure is
        // actually passed to `then` as the continuation.
        connection::handle(*ws, conn, framing, proc_rx, proc_tx.clone(), barrier, cache)
            .then({
                on_init();
                on_disconnect()
            })
            .in_current_span(),
    );
}
/// Spawns the process serving `room` and registers its channel senders.
///
/// Reserves a TCP port when a port pool is configured, creates (or reuses)
/// the room's shared output cache when caching is enabled, and launches the
/// process task.
#[instrument(name = "spawn", skip(env, tx, state, barrier))]
fn spawn(
    room: &str,
    env: &Env,
    tx: &EventTx,
    state: &mut State,
    barrier: Option<Arc<Barrier>>,
) -> AppResult<()> {
    // `None` either when no pool is configured or when the pool is exhausted.
    let port = state.ports.as_mut().and_then(|p| p.request_id());
    if let Some(port) = port {
        tracing::debug!("reserved port {}", port);
    }
    let cache = match state.cfg.cache {
        Some(ref c) => {
            // Insert an empty buffer on first spawn; reuse the existing one
            // otherwise (e.g. when the cache persists across process exits).
            state
                .cache
                .entry(room.to_string())
                .or_insert_with(|| Arc::new(Mutex::new(CacheBuffer::new(c))));
            state.cache.get(room).cloned()
        }
        None => None,
    };
    let mut proc = Channel::new(&state.cfg, port, room, env.cgi.clone(), cache);
    let senders = proc.take_senders();
    proc.give_sender(tx.clone());
    let on_init = || {
        state.procs.insert(room.to_string(), senders);
    };
    // Builds the callback invoked with the exit code when the process dies:
    // it feeds a ProcessExit event back into the main loop.
    let on_kill = || {
        let tx = tx.clone();
        let room = room.to_string();
        move |code: Option<i32>| {
            let _ = tx.send(Event::ProcessExit { room, code, port });
            Ok(())
        }
    };
    tokio::spawn(
        process::handle(proc, barrier)
            .map_ok_or_else(
                move |e| {
                    tracing::error!("{}", e);
                    Err(e)
                },
                // NOTE: this block runs `on_init()` eagerly, registering the
                // senders before the process task is polled; only the kill
                // callback is handed to `map_ok_or_else` as the ok-branch.
                {
                    on_init();
                    on_kill()
                },
            )
            .in_current_span(),
    );
    Ok(())
}
/// Handles a client disconnect: unregisters the connection, optionally sends
/// the configured leave message, and kills the room's process once the last
/// client is gone.
///
/// # Panics
/// Panics if `room` has no entry in `state.procs`.
#[instrument(name = "disconnect", skip(env, conn, state))]
fn disconnect(room: RoomID, env: Env, conn: ConnID, state: &mut State) {
    let room_conns = state.conns.entry(room.clone()).or_default();
    let (_, proc_tx, _) = state.procs.get(&room).expect("room not in process map");
    let is_removed = room_conns.remove(&conn);
    // Only announce connections that were actually registered.
    if is_removed {
        tracing::info!(id = conn, "client disconnected");
        if let Some(ref leave_msg_template) = state.cfg.leave_msg {
            let leave_msg = replace_template_env(leave_msg_template, conn, &env);
            let _ = proc_tx.send(Message::text(leave_msg));
        }
    }
    // Last client left: drop the process senders and signal the kill channel.
    // Logs only when the kill signal was actually delivered.
    if room_conns.is_empty()
        && let Some((_, _, kill_tx)) = state.procs.remove(&room)
        && kill_tx.send(()).is_ok()
    {
        tracing::info!("all clients disconnected, killing process");
    }
}
/// Handles a process exit: returns the room's reserved port to the pool and
/// drops its output cache unless cache persistence is configured.
#[instrument(name = "exit", skip(code, port, state))]
fn exit(room: RoomID, code: Option<i32>, port: Option<PortID>, state: &mut State) {
    if let Some(port) = port {
        // `if let` instead of `Option::map` — the combinator was being used
        // purely for its side effect (Clippy: option_map_unit_fn).
        if let Some(pool) = state.ports.as_mut() {
            let _ = pool.return_id(port);
        }
        tracing::debug!("released port {}", port);
    }
    if !state.cfg.cache_persist {
        state.cache.remove(&room);
    }
    // A process still registered here exited on its own rather than being
    // killed after the last disconnect. NOTE(review): the stale entry is not
    // removed from `procs` — confirm whether later connects should reuse it.
    if state.procs.contains_key(&room) {
        tracing::error!(room, code, "process exited");
    }
}
/// Sends a kill signal to every remaining room process during shutdown.
#[instrument(name = "shutdown", skip_all)]
fn shutdown(state: State) {
    tracing::debug!("killing processes");
    for (_, _, kill_tx) in state.procs.into_values() {
        // Receivers may already be gone; a failed send is fine here.
        let _ = kill_tx.send(());
    }
}
#[cfg(test)]
mod tests {
    use std::{
        collections::{HashMap, HashSet},
        sync::{Arc, Mutex, atomic::AtomicU32},
    };
    use clap::Parser;
    use tokio::sync::{
        self, broadcast,
        mpsc::{self},
        oneshot,
    };
    use warp::{Filter, filters::ws::Message};
    use super::{Env, Event, State, attach, disconnect};
    use crate::{
        cli::Config,
        types::{Cache, CacheBuffer, ProcessSenders, ToProcessRx},
    };

    /// Parses a whitespace-separated argument string into a `Config`.
    fn create_config(args: &'static str) -> Config {
        Config::parse_from(args.split_whitespace())
    }

    /// Convenience wrapper returning only the sender half of a fake process.
    fn create_process_senders() -> ProcessSenders {
        create_process().1
    }

    /// Builds the channel plumbing of a fake room process (no real child).
    fn create_process() -> (ToProcessRx, ProcessSenders) {
        let (proc_tx, proc_rx) = mpsc::unbounded_channel();
        let broadcast_tx = broadcast::Sender::new(16);
        let (kill_tx, _) = oneshot::channel();
        (proc_rx, (broadcast_tx, proc_tx, kill_tx))
    }

    /// Like `create_process`, but also returns an empty cache buffer.
    fn create_process_with_cache() -> (ToProcessRx, ProcessSenders, CacheBuffer) {
        let (proc_tx, proc_rx) = mpsc::unbounded_channel();
        let broadcast_tx = broadcast::Sender::new(16);
        let (kill_tx, _) = oneshot::channel();
        let cache = CacheBuffer::new(&Cache::All(8));
        (proc_rx, (broadcast_tx, proc_tx, kill_tx), cache)
    }

    /// Performs a test websocket handshake and returns the server-side
    /// socket together with the test client.
    async fn create_ws() -> (warp::ws::WebSocket, warp::test::WsClient) {
        let (tx, mut rx) = sync::mpsc::unbounded_channel();
        let route = warp::ws().map(move |websocket: warp::ws::Ws| {
            let tx = tx.clone();
            websocket.on_upgrade(move |ws| {
                // Hand the upgraded server-side socket back via the channel.
                tx.send(ws).ok();
                futures::future::ready(())
            })
        });
        let wsc = warp::test::ws().handshake(route).await.expect("handshake");
        let ws = rx.recv().await.unwrap();
        (ws, wsc)
    }

    /// `attach` on a registered room completes without panicking.
    #[tokio::test]
    async fn test_attach() {
        let (mut proc_rx, senders) = create_process();
        let mut state = State {
            conns_next_id: AtomicU32::new(1),
            conns: HashMap::new(),
            procs: HashMap::from([("room1".to_string(), senders)]),
            cfg: create_config("scalesocket cat --joinmsg=foo"),
            ports: None,
            cache: HashMap::new(),
        };
        let (tx, _) = sync::mpsc::unbounded_channel::<Event>();
        let (ws, _) = create_ws().await;
        attach(
            "room1".to_string(),
            Env::default(),
            Box::new(ws),
            &tx,
            &mut state,
            None,
        );
        // Wait for the join message so the spawned task has run.
        let _ = proc_rx.recv().await;
    }

    /// The configured `--joinmsg` is forwarded to the process on attach.
    #[tokio::test]
    async fn test_attach_sends_joinmsg() {
        let (mut proc_rx, senders) = create_process();
        let mut state = State {
            conns_next_id: AtomicU32::new(1),
            conns: HashMap::new(),
            procs: HashMap::from([("room1".to_string(), senders)]),
            cfg: create_config("scalesocket cat --joinmsg=foo"),
            ports: None,
            cache: HashMap::new(),
        };
        let (tx, _) = sync::mpsc::unbounded_channel::<Event>();
        let (ws, _) = create_ws().await;
        attach(
            "room1".to_string(),
            Env::default(),
            Box::new(ws),
            &tx,
            &mut state,
            None,
        );
        let received_msg = proc_rx.recv().await.unwrap();
        let received_msg = std::str::from_utf8(&received_msg.as_bytes()).unwrap();
        assert_eq!("foo", received_msg);
    }

    /// Cached process output is replayed to a newly attached client, in order.
    #[tokio::test]
    async fn test_attach_sends_cache() {
        let (_proc_rx, senders, mut cache) = create_process_with_cache();
        cache.write(Message::text("foo"));
        cache.write(Message::text("bar"));
        let mut state = State {
            conns_next_id: AtomicU32::new(1),
            conns: HashMap::new(),
            procs: HashMap::from([("room1".to_string(), senders)]),
            cfg: create_config("scalesocket --cache=all:64 --joinmsg=baz cat"),
            ports: None,
            cache: HashMap::from([("room1".to_string(), Arc::new(Mutex::new(cache)))]),
        };
        let (tx, _) = sync::mpsc::unbounded_channel::<Event>();
        let (ws, mut wsc) = create_ws().await;
        attach(
            "room1".to_string(),
            Env::default(),
            Box::new(ws),
            &tx,
            &mut state,
            None,
        );
        assert_eq!(wsc.recv().await.unwrap(), Message::text("foo"));
        assert_eq!(wsc.recv().await.unwrap(), Message::text("bar"));
    }

    /// `disconnect` removes only the given connection from the given room.
    #[tokio::test]
    async fn test_disconnect() {
        let mut state = State {
            conns_next_id: AtomicU32::new(1),
            conns: HashMap::from([
                ("room1".to_string(), HashSet::from([1])),
                ("room2".to_string(), HashSet::from([2])),
            ]),
            procs: HashMap::from([("room1".to_string(), create_process_senders())]),
            cfg: create_config("scalesocket cat"),
            ports: None,
            cache: HashMap::new(),
        };
        disconnect("room1".to_string(), Env::default(), 1, &mut state);
        assert!(state.conns.get("room1").unwrap().is_empty());
        assert!(!state.conns.get("room2").unwrap().is_empty());
    }
}