use super::utils::decode_jwt;
use super::{listen_to_redis, Channel, ChannelControl, ChannelError, ChannelMessage};
use futures::SinkExt;
use futures::StreamExt;
use itertools::Itertools;
use redis::aio::MultiplexedConnection;
use redis::AsyncCommands;
use redis::RedisResult;
use serde::{Deserialize, Serialize};
use serde_json::json;
use serde_tuple::{Deserialize_tuple, Serialize_tuple};
use std::collections::HashMap;
use std::fmt;
use std::fmt::{Display, Error};
use std::sync::Arc;
use tokio::sync::Mutex;
use tokio::task::JoinHandle;
use tracing::{debug, error, info, warn};
/// A server-to-client frame.
///
/// Serialized with `Serialize_tuple`, i.e. as a 5-element JSON array
/// `[join_ref, ref, topic, event, payload]`. The field ORDER below is
/// therefore the wire order and must not be changed.
#[derive(Clone, Debug, Serialize_tuple)]
pub struct ServerMessage {
// Positions 0 and 1: join_ref (`null` when `None`) and the request/event ref.
pub join_ref: Option<String>, pub event_ref: String,
// Positions 2 and 3: topic (channel name) and event name (e.g. "phx_reply").
pub topic: String, pub event: String,
// Position 4: the payload object.
pub payload: ServerPayload,
}
/// Payload slot (array position 4) of a [`ServerMessage`].
///
/// `untagged`: each variant serializes as its inner value with no wrapper.
#[derive(Clone, Debug, Serialize)]
#[serde(untagged)]
pub enum ServerPayload {
// Structured `{"status": ..., "response": ...}` body (phx_reply, datetime ticks).
ServerResponse(ServerResponse),
// Arbitrary JSON emitted verbatim (e.g. the presence_state map).
ServerJsonValue(serde_json::Value),
}
/// Reply body serialized as `{"status": ..., "response": ...}`.
#[derive(Clone, Debug, Serialize)]
pub struct ServerResponse {
// Every constructor in this file uses "ok".
pub status: String,
pub response: Response,
}
/// The `response` field of a [`ServerResponse`].
///
/// The enum is `untagged`: variants are (de)serialized purely by shape and the
/// variant name never appears on the wire. On deserialization serde tries the
/// variants **in declaration order** and keeps the first that fits. Unknown
/// fields are ignored, so the field-less variants (`Heartbeat {}`, `Empty {}`)
/// accept *any* JSON object. They must therefore come last, or `Datetime` and
/// `Message` could never be deserialized.
///
/// `Empty` follows `Heartbeat` and has the same shape, so an empty object always
/// deserializes as `Heartbeat`. `Empty` exists for serialization only.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(untagged)]
pub enum Response {
    /// `{"id": "<conn_id>:<channel>:<join_ref>"}`, sent when a reply carries a join_ref.
    Join { id: String },
    /// `{"datetime": ..., "counter": ...}` periodic clock tick.
    Datetime { datetime: String, counter: u32 },
    /// `{"message": ...}` free-form text.
    Message { message: String },
    /// `{}`.
    Heartbeat {},
    /// `{}`; used for replies without a join_ref.
    Empty {},
}
impl fmt::Display for ServerMessage {
    /// One-line, log-friendly rendering of a server frame.
    ///
    /// For structured replies only the status is shown and the response body is
    /// elided as `...`; raw JSON payloads are printed in full.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let payload_display = match &self.payload {
            ServerPayload::ServerJsonValue(value) => format!("<ServerJsonResponse {}>", value),
            ServerPayload::ServerResponse(resp) => {
                // The response body is intentionally not rendered.
                let elided = "...";
                format!("<Payload status={}, response={}>", resp.status, elided)
            }
        };
        write!(
            f,
            "Message join_ref={}, ref={}, topic={}, event={}, {}",
            self.join_ref.as_deref().unwrap_or("None"),
            self.event_ref,
            self.topic,
            self.event,
            payload_display
        )
    }
}
/// A client-to-server frame, decoded from a 5-element JSON array
/// `[join_ref, ref, topic, event, payload]` via `Deserialize_tuple`.
///
/// Field order is the wire order. Per the tests, arrays with fewer than five
/// elements are rejected, and `join_ref` must be a string or `null` (a number
/// is rejected).
#[derive(Debug, Deserialize_tuple)]
struct RequestMessage {
// join_ref identifies the join this frame belongs to; event_ref is echoed back in replies.
join_ref: Option<String>, event_ref: String,
// topic is the channel name; event is e.g. "phx_join", "phx_leave", "heartbeat".
topic: String, event: String,
payload: RequestPayload,
}
impl Display for RequestMessage {
fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> Result<(), Error> {
write!(
formatter,
"<RequestMessage: join_ref={:?}, ref={}, topic={}, event={}, payload=...>",
self.join_ref, self.event_ref, self.topic, self.event
)
}
}
/// Payload slot of a [`RequestMessage`].
///
/// `untagged`: serde tries the variants in declaration order, so the order is
/// significant. An object with a string `token` becomes `Join`; otherwise an
/// object with a string `message` becomes `Message`. Anything else falls
/// through to `JsonValue`, including `{}`, `null`, numbers, strings and arrays
/// (see the tests).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(untagged)]
enum RequestPayload {
Join { token: String },
Message { message: String },
JsonValue(serde_json::Value), }
/// Shared server state handed (behind an `Arc`) to every connection task.
pub struct State {
// Channel / agent / connection registry; access is serialized by a tokio Mutex.
pub ctl: Mutex<ChannelControl>,
// Used to open Redis connections for publishing and for per-channel listen tasks.
pub redis_client: redis::Client,
// NOTE(review): not read anywhere in this file (connection ids use a hard-coded
// `nanoid!(8)`); confirm whether it is used elsewhere.
pub id_length: u8,
// Secret passed to `decode_jwt` when validating join tokens (and to `generate_jwt` in tests).
pub jwt_secret: String,
// Token lifetime in seconds; in this file only read by the tests' token helper.
pub jwt_expiration_secs: i64,
}
// Currently empty.
impl State {}
pub async fn axum_on_connected(
ws: axum::extract::ws::WebSocket,
state: Arc<State>,
user_token: Option<String>,
) {
info!("params: {:?}", user_token);
let conn_id = nanoid::nanoid!(8).to_string();
state.ctl.lock().await.conn_add_tx(conn_id.clone()).await;
info!("AXUM / WS_TX / new connection connected: {}", conn_id);
let (mut ws_tx, mut ws_rx) = ws.split();
let ws_tx_state = state.clone();
let ws_tx_conn_id = conn_id.clone();
let mut ws_tx_task = tokio::spawn(async move {
info!("AXUM / WS_TX / launch websocket tx task (conn rx => ws tx) ...");
let mut conn_rx = ws_tx_state
.ctl
.lock()
.await
.conn_rx(ws_tx_conn_id.clone())
.await
.unwrap();
loop {
match conn_rx.recv().await {
Ok(channel_message) => {
let ChannelMessage::Reply(reply_message) = channel_message;
let text_result = serde_json::to_string(&reply_message);
if text_result.is_err() {
error!(
"AXUM / WS_TX / fail to serialize reply message: {}",
text_result.err().unwrap()
);
break;
}
let text = text_result.unwrap();
let sending_result = ws_tx
.send(axum::extract::ws::Message::Text(text.into()))
.await;
if sending_result.is_err() {
error!(
"AXUM / WS_TX / websocket tx sending failed: {}",
sending_result.err().unwrap()
);
break; }
}
Err(e) => {
error!("AXUM / WS_TX / rx error: {:?}", e);
break;
}
}
}
});
let ws_rx_state = state.clone();
let ws_rx_conn_id = conn_id.clone();
let ws_rx_user_token = user_token.clone();
let mut ws_rx_task = tokio::spawn(async move {
info!("AXUM / WS_RX / websocket rx handling (ws rx =>) ...");
let mut redis_conn = ws_rx_state
.redis_client
.get_multiplexed_async_connection()
.await
.unwrap();
while let Some(msg_result) = ws_rx.next().await {
match msg_result {
Ok(msg) => match msg {
axum::extract::ws::Message::Text(text) => {
if let Err(e) = handle_message(
ws_rx_state.clone(),
ws_rx_user_token.clone(),
&ws_rx_conn_id,
&text,
&mut redis_conn,
)
.await
{
error!("AXUM / WS_RX / handle_message error: {:?}", e);
}
}
axum::extract::ws::Message::Close(_) => {
debug!("AXUM / WS_RX / close frame received");
break;
}
_ => {}
},
Err(e) => {
error!("AXUM / WS_RX / rx error: {:?}", e);
break;
}
}
}
});
tokio::select! {
_ = (&mut ws_tx_task) => {
info!("AXUM / ws_tx_task exits.");
ws_rx_task.abort();
info!("AXUM / ws_rx_task aborts.");
},
_ = (&mut ws_rx_task) => {
info!("AXUM / ws_rx_task exits.");
ws_tx_task.abort();
info!("AXUM / ws_tx_task aborts.");
},
}
state.ctl.lock().await.conn_cleanup(conn_id.clone()).await;
info!("AXUM / CONNECTION CLOSED");
}
async fn handle_message(
state: Arc<State>,
user_token: Option<String>,
conn_id: &str,
text: &str,
redis_conn: &mut redis::aio::MultiplexedConnection,
) -> RedisResult<()> {
let rm_result = serde_json::from_str::<RequestMessage>(text);
if rm_result.is_err() {
error!(
"WS_RX / conn: {}, error: {:?} text: {}",
&conn_id,
rm_result.err(),
text
);
return Ok(());
}
let rm: RequestMessage = rm_result.unwrap();
let channel_name = &rm.topic;
let join_ref = &rm.join_ref;
let event_ref = &rm.event_ref;
let event = &rm.event;
let payload = &rm.payload;
if channel_name == "phoenix" && event == "heartbeat" {
ok_reply(conn_id, None, event_ref, "phoenix", state.clone()).await;
let message = format!(r#"{{"conn_id": "{}"}}"#, conn_id); publish_event(redis_conn, "from:phoenix:heartbeat".to_string(), message).await;
}
if event == "phx_join" {
let _relay_task = handle_join(user_token, &rm, state.clone(), conn_id).await;
debug!("WS_RX / join processed");
}
if event == "phx_leave" {
handle_leave(
state.clone(),
conn_id,
join_ref.clone(),
event_ref,
channel_name.clone(),
)
.await;
debug!("WS_RX / leave processed");
}
let redis_topic = format!("from:{}:{}", channel_name, event.clone());
let message = serde_json::to_string(&payload).unwrap();
publish_event(redis_conn, redis_topic, message).await;
Ok(())
}
/// Publish `message` on Redis pub/sub channel `redis_topic`.
///
/// Best-effort: a failure is logged and otherwise ignored.
async fn publish_event(
    redis_conn: &mut redis::aio::MultiplexedConnection,
    redis_topic: String,
    message: String,
) {
    // PUBLISH replies with an integer (number of receivers). Decoding that
    // reply as `String` is a type mismatch that some redis-rs versions reject,
    // which would report a failure for every successful publish.
    let publish_result: RedisResult<i64> = redis_conn.publish(redis_topic, message).await;
    if let Err(e) = publish_result {
        error!("fail to publish to redis: {}", e)
    }
}
/// Returns `true` for the reserved channel names `phoenix`, `admin` and `system`.
///
/// The match is exact and case-sensitive: e.g. `system_abc` is not special.
/// Special channels skip dynamic channel creation / Redis-listener launch on
/// join and skip empty-channel removal on leave.
pub fn is_special_channel(ch: &str) -> bool {
    // `matches!` avoids allocating a Vec on every call.
    matches!(ch, "phoenix" | "admin" | "system")
}
pub async fn add_channel(ctl: &Mutex<ChannelControl>, channel_name: String) {
let ctl = ctl.lock().await;
let mut channels = ctl.channels.lock().await;
let channel_exists = channels.contains_key(&channel_name);
if channel_exists {
warn!("ADD_CH / channel {} already exists", channel_name);
}
channels
.entry(channel_name.clone())
.or_insert_with(|| Channel::new(channel_name.clone(), None));
warn!("ADD_CH / {} added", channel_name);
let channel_names = channels.keys().cloned().collect::<Vec<String>>();
info!(
"ADD_CH / {} created, channels: {} {:?}",
channel_name,
channel_names.len(),
channel_names
);
let meta = json!({
"channel": channel_name,
"channels": channels.keys().cloned().collect::<Vec<String>>(),
});
ctl.pub_meta_event("channe".into(), "add".into(), meta)
.await;
}
/// Start the Redis listener for `channel_name` unless one is already running.
///
/// Spawns `listen_to_redis(state, channel.tx, redis_client, channel_name)` and
/// stores the handle in `channel.redis_listen_task`. Logs a warning and does
/// nothing if the channel is not registered (e.g. it was removed between being
/// added and this call) or already has a listener.
pub async fn launch_channel_redis_listen_task(
    state: Arc<State>,
    ctl: &Mutex<ChannelControl>,
    channel_name: String,
    redis_client: redis::Client,
) {
    let ctl = ctl.lock().await;
    let mut channels = ctl.channels.lock().await;
    let Some(channel) = channels.get_mut(&channel_name) else {
        warn!(
            "LAUNCH_REDIS_TASK / channel {} not found, redis_listen_task not launched",
            channel_name
        );
        return;
    };
    if channel.redis_listen_task.is_some() {
        warn!(
            "LAUNCH_REDIS_TASK / channel {} redis_listen_task already exists",
            channel_name
        );
        return;
    }
    channel.redis_listen_task = Some(tokio::spawn(listen_to_redis(
        state,
        channel.tx.clone(),
        redis_client,
        channel_name.clone(),
    )));
    info!(
        "LAUNCH_REDIS_TASK / channel {} redis_listen_task launched",
        channel_name
    );
}
async fn handle_join(
user_token: Option<String>,
rm: &RequestMessage,
state: Arc<State>,
conn_id: &str,
) -> Result<JoinHandle<()>, ChannelError> {
let token = match &rm.payload {
RequestPayload::Join { token } => Ok(token.clone()),
_ => user_token.ok_or_else(|| {
error!("JOIN / invalid payload: {:?}", rm.payload);
ChannelError::BadToken
}),
}?;
let claims = match decode_jwt(&token, state.jwt_secret.clone()).await {
Ok(claims) => claims,
Err(e) => {
error!("JOIN / fail to decode JWT, {}, {}", e, token);
return Err(ChannelError::BadToken);
}
};
debug!("JOIN / claims: {:?}", claims);
let channel_name = rm.topic.clone();
if is_special_channel(&channel_name) {
info!("ADD_CH / channel {} is special, ignored", channel_name);
} else {
add_channel(&state.ctl, channel_name.clone()).await;
launch_channel_redis_listen_task(
state.clone(),
&state.ctl,
channel_name.clone(),
state.redis_client.clone(),
)
.await;
}
let agent_id = format!(
"{}:{}:{}",
conn_id,
channel_name.clone(),
rm.join_ref.clone().unwrap()
);
let join_ref = rm.join_ref.clone();
let event_ref = rm.event_ref.clone();
info!(
"JOIN / agent joining ({} => {}) ...",
agent_id, channel_name
);
state
.ctl
.lock()
.await
.agent_add(agent_id.to_string(), None)
.await; match state
.ctl
.lock()
.await
.channel_join(
&channel_name.clone(),
agent_id.to_string(),
claims.id.clone(),
)
.await
{
Ok(_) => {}
Err(e) => {
error!("JOIN / fail to join: {}", e);
return Err(e);
}
}
let relay_state = state.clone();
let local_join_ref = rm.join_ref.clone();
let local_conn_id = conn_id.to_string();
let local_agent_id = agent_id.clone();
let relay_task = tokio::spawn(async move {
let mut agent_rx = relay_state
.ctl
.lock()
.await
.agent_rx(local_agent_id.clone())
.await
.unwrap();
let conn_tx = relay_state
.ctl
.lock()
.await
.conn_tx(local_conn_id.to_string())
.await
.unwrap();
debug!(
"R / agent {} => conn {}",
local_agent_id.clone(),
local_conn_id.clone()
);
loop {
let message_opt = agent_rx.recv().await;
if message_opt.is_err() {
error!(
"R / fail to get message from agent rx: {}",
message_opt.err().unwrap()
);
break;
}
let mut channel_message = message_opt.unwrap();
let ChannelMessage::Reply(ref mut reply) = channel_message;
reply.join_ref = local_join_ref.clone();
let send_result = conn_tx.send(channel_message.clone()); if send_result.is_err() {
error!(
"R / agent {}, conn: {}, sending failure: {:?}, exit ...",
&local_agent_id,
&local_conn_id,
send_result.err().unwrap()
);
break; }
}
});
ok_reply(
conn_id,
join_ref.clone(),
&event_ref,
&channel_name,
state.clone(),
)
.await;
info!("JOIN / acked");
if channel_name == "admin" {
info!("JOIN / handling admin initialization ...");
let ctl = state.ctl.lock().await;
let channels = ctl.channels.lock().await;
for (name, channel) in channels.iter() {
let meta = json!({"channel": name, "agents": *channel.agents.lock().await});
ctl.pub_meta_event("channel".into(), "list".into(), meta)
.await;
}
}
presence_state(
conn_id,
join_ref.clone(),
&event_ref,
&channel_name,
state.clone(),
)
.await;
let mut redis_conn = state
.redis_client
.get_multiplexed_async_connection()
.await
.unwrap();
presence_diff(
&mut redis_conn,
channel_name.clone(),
agent_id.clone(),
claims.id.clone(),
PresenceAction::Join,
)
.await;
Ok(relay_task)
}
async fn handle_leave(
state: Arc<State>,
conn_id: &str,
join_ref: Option<String>,
event_ref: &str,
channel_name: String,
) {
let agent_id = format!("{}:{}:{}", conn_id, channel_name, join_ref.clone().unwrap());
let external_id_opt = state.ctl.lock().await.agent_rm(agent_id.clone()).await;
let agent_count = state
.ctl
.lock()
.await
.channel_leave(channel_name.clone(), agent_id.clone())
.await
.unwrap();
if agent_count == 0 && !is_special_channel(&channel_name) {
state
.ctl
.lock()
.await
.channel_remove_if_empty(channel_name.clone())
.await;
}
ok_reply(conn_id, join_ref, event_ref, &channel_name, state.clone()).await;
if external_id_opt.is_none() {
error!("LEAVE / agent {} not found", agent_id);
return;
}
info!("LEAVE / send presense_diff");
let mut redis_conn = state
.redis_client
.get_multiplexed_async_connection()
.await
.unwrap();
presence_diff(
&mut redis_conn,
channel_name.clone(),
agent_id.clone(),
external_id_opt.unwrap(),
PresenceAction::Leave,
)
.await;
}
async fn ok_reply(
conn_id: &str,
join_ref: Option<String>,
event_ref: &str,
channel_name: &str,
state: Arc<State>,
) {
let response = match join_ref {
None => Response::Empty {}, Some(ref join_ref) => Response::Join {
id: format!("{}:{}:{}", conn_id, channel_name, join_ref),
}, };
let join_reply_message = ServerMessage {
join_ref: join_ref.clone(),
event_ref: event_ref.to_string(),
topic: channel_name.to_string(),
event: "phx_reply".to_string(),
payload: ServerPayload::ServerResponse(ServerResponse {
status: "ok".to_string(),
response,
}),
};
state
.ctl
.lock()
.await
.conn_send(
conn_id.to_string(),
ChannelMessage::Reply(join_reply_message),
)
.await
.unwrap();
}
async fn presence_state(
conn_id: &str,
join_ref: Option<String>,
event_ref: &str,
channel_name: &str,
state: Arc<State>,
) {
let hashed_agents = state
.ctl
.lock()
.await
.agents
.lock()
.await
.iter()
.filter(|(_, agent)| agent.channel == channel_name)
.into_group_map_by(|(_, agent)| &agent.external_id)
.into_iter()
.map(|(external_id, group)| {
(external_id.clone(), {
json!({
"metas": group.into_iter()
.map(|(id, _)| json!({ "phx_ref": id }))
.collect::<Vec<_>>()
})
})
})
.collect::<HashMap<_, _>>();
let reply = ServerMessage {
join_ref: join_ref.clone(),
event_ref: event_ref.to_string(),
topic: channel_name.to_string(),
event: "presence_state".to_string(),
payload: ServerPayload::ServerJsonValue(json!(hashed_agents)),
};
state
.ctl
.lock()
.await
.conn_send(conn_id.to_string(), ChannelMessage::Reply(reply))
.await
.unwrap();
info!("P_STATE / sent");
}
/// Direction of a presence change published by `presence_diff` / `presence_diff_many`.
///
/// A fieldless value type, so it is `Copy` and comparable.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PresenceAction {
    /// Entries are published under `"joins"`.
    Join,
    /// Entries are published under `"leaves"`.
    Leave,
}
pub async fn presence_diff_many(
redis_conn: &mut MultiplexedConnection,
channel_name: String,
action: PresenceAction,
items: serde_json::Value,
) {
let diff = match action {
PresenceAction::Join => json!({"joins": items, "leaves": {}}),
PresenceAction::Leave => json!({"joins": {}, "leaves": items}),
};
let redis_topic = format!("to:{}:presence_diff", channel_name);
let message = serde_json::to_string(&diff).unwrap();
let publish_result: RedisResult<String> = redis_conn
.publish(redis_topic.clone(), message.clone())
.await;
if let Err(e) = publish_result {
error!("P_DIFF_MANY / fail to publish to redis: {}", e)
} else {
info!("P_DIFF_MANY / sent, {:?}", action);
}
}
pub async fn presence_diff(
redis_conn: &mut MultiplexedConnection,
channel_name: String,
agent_id: String,
external_id: String,
action: PresenceAction,
) {
let items = json!({
external_id.clone(): {
"metas": [
{"phx_ref": agent_id.clone()}
],
},
});
let diff = match action {
PresenceAction::Join => json!({"joins": items, "leaves": {}}),
PresenceAction::Leave => json!({"joins": {}, "leaves": items}),
};
let redis_topic = format!("to:{}:presence_diff", channel_name);
let message = serde_json::to_string(&diff).unwrap();
let publish_result: RedisResult<String> = redis_conn
.publish(redis_topic.clone(), message.clone())
.await;
if let Err(e) = publish_result {
error!("P_DIFF / fail to publish to redis: {}", e)
} else {
info!("P_DIFF / sent, {:?}", action);
}
}
/// Broadcast a `datetime` tick to `channel_name` about once per second, forever.
///
/// Waits 10 s before the first tick. Each tick's `event_ref` and `counter` are
/// the tick number, starting at 0. Its `datetime` is local time in RFC 3339
/// with millisecond precision. Broadcast errors other than
/// `ChannelError::ChannelEmpty` are logged; successes are silent.
pub async fn datetime_handler(state: Arc<State>, channel_name: String) {
    tokio::time::sleep(tokio::time::Duration::from_secs(10)).await;
    info!("launch system/datetime task...");
    let event = "datetime";
    let mut counter: u32 = 0;
    loop {
        let datetime = chrono::Local::now().to_rfc3339_opts(chrono::SecondsFormat::Millis, false);
        let tick = ServerMessage {
            join_ref: None,
            event_ref: counter.to_string(),
            topic: channel_name.to_string(),
            event: event.to_string(),
            payload: ServerPayload::ServerResponse(ServerResponse {
                status: "ok".to_string(),
                response: Response::Datetime { datetime, counter },
            }),
        };
        let outcome = state
            .ctl
            .lock()
            .await
            .channel_broadcast(channel_name.to_string(), ChannelMessage::Reply(tick))
            .await;
        if let Err(e) = outcome {
            // An empty channel is expected (nobody subscribed yet) and not an error.
            if !matches!(e, ChannelError::ChannelEmpty) {
                error!(
                    "DT / fail to broadcast, channel: {}, event: {}, error: {}",
                    channel_name, event, e
                );
            }
        }
        tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
        counter += 1;
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::channel::utils::generate_jwt;
use axum::{
extract::{Query, State as AxumState, WebSocketUpgrade},
response::IntoResponse,
routing::get,
Router,
};
use futures::{SinkExt, StreamExt};
use serde::Deserialize;
use serde_json::json;
use std::sync::Arc;
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::Mutex;
use tokio_tungstenite::{connect_async, tungstenite::Message, MaybeTlsStream, WebSocketStream};
/// Query parameters of the test `/websocket` endpoint.
#[derive(Debug, Deserialize)]
struct WebSocketParams {
// Optional fallback join token, from `?userToken=...`.
#[serde(rename = "userToken")]
user_token: Option<String>,
// Protocol version from `?vsn=...`; parsed but not used.
#[allow(dead_code)]
#[serde(rename = "vsn")]
version: Option<String>,
}
async fn axum_websocket_handler(
ws: WebSocketUpgrade,
Query(params): Query<WebSocketParams>,
AxumState(state): AxumState<Arc<State>>,
) -> impl IntoResponse {
let user_token = params.user_token.clone();
ws.on_upgrade(move |socket| axum_on_connected(socket, state, user_token))
}
/// Start an in-process axum server on an ephemeral 127.0.0.1 port.
///
/// Requires a Redis server at 127.0.0.1:6379. Registers the `phoenix`, a
/// randomly suffixed `system_<id>` and the `streaming` channels, and starts
/// `datetime_handler` on the system channel.
///
/// Returns `(websocket URL, shared state, system channel name)`.
async fn setup_test_server() -> (String, Arc<State>, String) {
let redis_url = "redis://127.0.0.1:6379".to_string();
let redis_client = redis::Client::open(redis_url.clone()).unwrap();
let channel_control = ChannelControl::new(Arc::new(redis_client.clone()));
let state = Arc::new(State {
ctl: Mutex::new(channel_control),
redis_client,
id_length: 8,
jwt_secret: "secret".into(),
jwt_expiration_secs: 3600,
});
// A unique channel per test so concurrently running tests do not share agents.
let rand_suffix = nanoid::nanoid!(8);
let system_channel = format!("system_{}", rand_suffix);
state
.ctl
.lock()
.await
.channel_add("phoenix".into(), None)
.await;
state
.ctl
.lock()
.await
.channel_add(system_channel.clone(), None)
.await;
state
.ctl
.lock()
.await
.channel_add("streaming".into(), None)
.await;
tokio::spawn(datetime_handler(state.clone(), system_channel.clone()));
let app = Router::new()
.route("/websocket", get(axum_websocket_handler))
.with_state(state.clone());
// Port 0: let the OS pick a free port.
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let port = listener.local_addr().unwrap().port();
let addr = format!("ws://127.0.0.1:{}/websocket", port);
tokio::spawn(async move {
axum::serve(listener, app).await.unwrap();
});
(addr, state, system_channel)
}
/// Open a client WebSocket to `addr` and split it into (sink, stream).
///
/// Panics with "Failed to connect" if the handshake fails.
async fn connect_client(
    addr: &str,
) -> (
    futures::stream::SplitSink<
        WebSocketStream<MaybeTlsStream<TcpStream>>,
        tokio_tungstenite::tungstenite::Message,
    >,
    futures::stream::SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
) {
    let (socket, _handshake_response) = connect_async(addr).await.expect("Failed to connect");
    let (sink, stream) = socket.split();
    (sink, stream)
}
/// Mint a JWT for user `id` on `channel`, signed with the state's secret and lifetime.
async fn gen_token(state: &Arc<State>, channel: &str, id: &str) -> String {
    let subject = id.to_string();
    let channel_owned = channel.to_string();
    let secret = state.jwt_secret.clone();
    let lifetime_secs = state.jwt_expiration_secs;
    generate_jwt(subject, channel_owned, secret, lifetime_secs)
        .await
        .unwrap()
}
/// A Phoenix heartbeat gets an `ok` phx_reply on topic `phoenix`.
///
/// NOTE(review): the `if let` lets the test pass vacuously if the stream yields
/// nothing or an error; consider asserting the frame arrives (with a timeout).
#[tokio::test]
async fn test_ws_websocket_connection() {
let (addr, _, _) = setup_test_server().await;
let (mut tx, mut rx) = connect_client(&addr).await;
let heartbeat = r#"[null,"1","phoenix","heartbeat",{}]"#;
tx.send(Message::text(heartbeat)).await.unwrap();
if let Some(Ok(msg)) = rx.next().await {
let response: serde_json::Value = serde_json::from_str(&msg.to_string()).unwrap();
assert_eq!(response[2], "phoenix");
assert_eq!(response[4]["status"], "ok");
}
}
/// Join then leave the per-test system channel.
///
/// Waits (at most 10 frames, 5 s each) for the `ok` reply to each request,
/// and checks the agent count is 1 after joining and 0 after leaving. The
/// channel may be removed entirely once empty, hence the `if let` on the
/// final check.
#[tokio::test]
async fn test_ws_channel_join_leave_flow() {
let (addr, state, sys_chan) = setup_test_server().await;
let (mut tx, mut rx) = connect_client(&addr).await;
let token = gen_token(&state, &sys_chan, "user1").await;
let join_msg = format!(
r#"["1","ref1","{}","phx_join",{{"token":"{}"}}]"#,
sys_chan, token
);
tx.send(Message::text(join_msg)).await.unwrap();
let mut join_confirmed = false;
let timeout_duration = std::time::Duration::from_secs(5);
// Other frames (presence_state, presence_diff, ...) may arrive first; skip them.
'join_loop: for _ in 0..10 {
match tokio::time::timeout(timeout_duration, rx.next()).await {
Ok(Some(Ok(msg))) => {
let resp: serde_json::Value = serde_json::from_str(&msg.to_string()).unwrap();
if resp[1] == "ref1" && resp[2] == sys_chan && resp[3] == "phx_reply" {
assert_eq!(resp[4]["status"], "ok");
join_confirmed = true;
break 'join_loop;
}
}
Ok(Some(Err(_))) => panic!("WebSocket error during join"),
Ok(None) => panic!("WebSocket closed unexpectedly during join"),
Err(_) => panic!("Timed out waiting for join response"),
}
}
assert!(
join_confirmed,
"Never received join confirmation after reading multiple messages"
);
{
let ctl = state.ctl.lock().await;
let channels = ctl.channels.lock().await;
let agents = channels.get(&sys_chan).unwrap().agents.lock().await;
assert_eq!(agents.len(), 1);
}
let leave_msg = format!(r#"["1","ref2","{}","phx_leave",{{}}]"#, sys_chan);
tx.send(Message::text(leave_msg)).await.unwrap();
let mut leave_confirmed = false;
'leave_loop: for _ in 0..10 {
match tokio::time::timeout(timeout_duration, rx.next()).await {
Ok(Some(Ok(msg))) => {
let resp: serde_json::Value = serde_json::from_str(&msg.to_string()).unwrap();
if resp[1] == "ref2" && resp[2] == sys_chan && resp[3] == "phx_reply" {
assert_eq!(resp[4]["status"], "ok");
leave_confirmed = true;
break 'leave_loop;
}
}
Ok(Some(Err(_))) => panic!("WebSocket error during leave"),
Ok(None) => panic!("WebSocket closed unexpectedly during leave"),
Err(_) => panic!("Timed out waiting for leave response"),
}
}
assert!(
leave_confirmed,
"Never received leave confirmation after reading multiple messages"
);
{
let ctl = state.ctl.lock().await;
let channels = ctl.channels.lock().await;
if let Some(channel) = channels.get(&sys_chan) {
let agents = channel.agents.lock().await;
assert_eq!(agents.len(), 0);
}
}
drop(tx);
drop(rx);
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
}
/// Dropping the client socket removes its agent from the channel.
///
/// The agent count is 1 after joining and 0 about 200 ms after the client disconnects.
#[tokio::test]
async fn test_ws_connection_close_websocket() {
let (addr, state, sys_chan) = setup_test_server().await;
let (mut tx, mut rx) = connect_client(&addr).await;
let token = gen_token(&state, &sys_chan, "user1").await;
let join_msg = format!(
r#"["1","ref1","{}","phx_join",{{"token":"{}"}}]"#,
sys_chan, token
);
tx.send(Message::text(join_msg)).await.unwrap();
if let Some(Ok(_)) = rx.next().await {}
tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
let agent_count = {
let ctl = state.ctl.lock().await;
let channels = ctl.channels.lock().await;
if let Some(channel) = channels.get(&sys_chan) {
channel.agents.lock().await.len()
} else {
0
}
};
assert_eq!(agent_count, 1, "Agent should be joined");
drop(tx);
drop(rx);
// Give the server time to notice the disconnect and run conn_cleanup.
tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;
let agent_count = {
let ctl = state.ctl.lock().await;
let channels = ctl.channels.lock().await;
if let Some(channel) = channels.get(&sys_chan) {
channel.agents.lock().await.len()
} else {
0
}
};
assert_eq!(
agent_count, 0,
"Agent should be removed after connection close"
);
}
/// Three clients join the same channel; it ends up with three agents.
///
/// Each client's first frame is awaited with a 2 s timeout.
#[tokio::test]
async fn test_ws_multiple_clients_fixed() {
let (addr, state, sys_chan) = setup_test_server().await;
let mut clients = vec![];
for i in 0..3 {
let (mut tx, mut rx) = connect_client(&addr).await;
let token = gen_token(&state, &sys_chan, &format!("user{}", i)).await;
let join_msg = format!(
r#"["{}","ref{}","{}","phx_join",{{"token":"{}"}}]"#,
i, i, sys_chan, token
);
tx.send(Message::text(join_msg)).await.unwrap();
if let Some(Ok(_)) = tokio::time::timeout(std::time::Duration::from_secs(2), rx.next())
.await
.unwrap()
{}
// Keep the sockets alive so the agents are not cleaned up.
clients.push((tx, rx));
}
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
let agent_count = {
let ctl = state.ctl.lock().await;
let channels = ctl.channels.lock().await;
let system_channel = channels.get(&sys_chan).unwrap();
let agents = system_channel.agents.lock().await;
agents.len()
};
assert_eq!(agent_count, 3, "Should have 3 agents connected");
drop(clients);
}
/// `setup_test_server` registers `phoenix`, the system channel and
/// `streaming`, and the system channel starts with no agents.
#[tokio::test]
async fn test_ws_flow_server() {
let (_addr, state, sys_chan) = setup_test_server().await;
let ctl = state.ctl.lock().await;
let channels = ctl.channels.lock().await;
assert!(channels.contains_key("phoenix"));
assert!(channels.contains_key(&sys_chan));
assert!(channels.contains_key("streaming"));
let agents = channels.get(&sys_chan).unwrap().agents.lock().await;
assert_eq!(agents.len(), 0);
}
/// Join/leave round trip that reads frames until each matching `phx_reply` arrives.
///
/// NOTE(review): the read loops have no timeout. If a reply never arrives
/// while the socket stays open, the test hangs instead of failing.
#[tokio::test]
async fn test_ws_flow_join_leave() {
let (addr, state, sys_chan) = setup_test_server().await;
let (mut tx, mut rx) = connect_client(&addr).await;
let token = gen_token(&state, &sys_chan, "user1").await;
let join_msg = format!(
r#"["1","ref1","{}","phx_join",{{"token":"{}"}}]"#,
sys_chan, token
);
tx.send(Message::text(join_msg)).await.unwrap();
loop {
if let Some(Ok(msg)) = rx.next().await {
let resp: serde_json::Value = serde_json::from_str(&msg.to_string()).unwrap();
if resp[3] == "phx_reply" && resp[1] == "ref1" {
assert_eq!(resp[2], sys_chan);
assert_eq!(resp[4]["status"], "ok");
break;
}
} else {
panic!("Stream ended or error before join reply");
}
}
{
let ctl = state.ctl.lock().await;
let channels = ctl.channels.lock().await;
let agents = channels.get(&sys_chan).unwrap().agents.lock().await;
assert_eq!(agents.len(), 1);
}
let leave_msg = format!(r#"["1","ref2","{}","phx_leave",{{}}]"#, sys_chan);
tx.send(Message::text(leave_msg)).await.unwrap();
loop {
if let Some(Ok(msg)) = rx.next().await {
let resp: serde_json::Value = serde_json::from_str(&msg.to_string()).unwrap();
if resp[3] == "phx_reply" && resp[1] == "ref2" {
assert_eq!(resp[2], sys_chan);
assert_eq!(resp[4]["status"], "ok");
break;
}
} else {
panic!("Stream ended or error before leave reply");
}
}
{
let ctl = state.ctl.lock().await;
let channels = ctl.channels.lock().await;
if let Some(channel) = channels.get(&sys_chan) {
let agents = channel.agents.lock().await;
assert_eq!(agents.len(), 0);
}
}
}
/// Three clients join; each one's first frame is an `ok` reply, and the channel has three agents.
///
/// The first frame is read without a timeout.
#[tokio::test]
async fn test_ws_multiple_clients() {
let (addr, state, sys_chan) = setup_test_server().await;
let mut clients = vec![];
assert_eq!(clients.len(), 0);
for i in 0..3 {
let (mut tx, mut rx) = connect_client(&addr).await;
let token = gen_token(&state, &sys_chan, &format!("user{}", i)).await;
let join_msg = format!(
r#"["{}","ref{}","{}","phx_join",{{"token":"{}"}}]"#,
i, i, sys_chan, token
);
tx.send(Message::text(join_msg)).await.unwrap();
if let Some(Ok(msg)) = rx.next().await {
let resp: serde_json::Value = serde_json::from_str(&msg.to_string()).unwrap();
assert_eq!(resp[4]["status"], "ok");
}
clients.push((tx, rx));
}
assert_eq!(clients.len(), 3);
let ctl = state.ctl.lock().await;
let channels = ctl.channels.lock().await;
let agents = channels.get(&sys_chan).unwrap().agents.lock().await;
assert_eq!(agents.len(), 3);
}
/// A `channel_broadcast` to the system channel reaches both joined clients.
///
/// NOTE(review): each client consumes only one frame after joining. The join
/// also sends `presence_state` right after the ack, so the frame read in the
/// final loop is likely that one rather than the broadcast. Then the
/// `resp[1] == "broadcast"` guard is false and the assertion never runs,
/// which makes the test likely vacuous.
#[tokio::test]
async fn test_ws_message_broadcast() {
let (addr, state, sys_chan) = setup_test_server().await;
let (mut tx1, mut rx1) = connect_client(&addr).await;
let (mut tx2, mut rx2) = connect_client(&addr).await;
for (tx, i) in [(&mut tx1, 1), (&mut tx2, 2)] {
let token = gen_token(&state, &sys_chan, &format!("user{}", i)).await;
let join_msg = format!(
r#"["{}","ref{}","{}","phx_join",{{"token":"{}"}}]"#,
i, i, sys_chan, token
);
tx.send(Message::text(join_msg)).await.unwrap();
if let Some(Ok(_)) = if i == 1 {
rx1.next().await
} else {
rx2.next().await
} {}
}
let message = ServerMessage {
join_ref: None,
event_ref: "broadcast".to_string(),
topic: sys_chan.clone(),
event: "test".to_string(),
payload: ServerPayload::ServerResponse(ServerResponse {
status: "ok".to_string(),
response: Response::Message {
message: "test broadcast".to_string(),
},
}),
};
state
.ctl
.lock()
.await
.channel_broadcast(sys_chan.clone(), ChannelMessage::Reply(message))
.await
.unwrap();
for rx in [&mut rx1, &mut rx2] {
if let Some(Ok(msg)) = rx.next().await {
let resp: serde_json::Value = serde_json::from_str(&msg.to_string()).unwrap();
if resp[1] == "broadcast" {
assert_eq!(resp[4]["response"]["message"], "test broadcast");
}
}
}
}
/// Malformed frames and a join with a bad token produce no reply and do not
/// kill the connection: a following heartbeat is still answered `ok`.
#[tokio::test]
async fn test_ws_invalid_messages() {
let (addr, _, _) = setup_test_server().await;
let (mut tx, mut rx) = connect_client(&addr).await;
tx.send(Message::text("invalid json")).await.unwrap();
tx.send(Message::text(r#"["invalid","format"]"#))
.await
.unwrap();
// "test" is not a valid JWT, so the join is rejected without a reply.
let invalid_channel = r#"["1","ref1","nonexistent","phx_join",{"token":"test"}]"#;
tx.send(Message::text(invalid_channel)).await.unwrap();
let heartbeat = r#"[null,"1","phoenix","heartbeat",{}]"#;
tx.send(Message::text(heartbeat)).await.unwrap();
if let Some(Ok(msg)) = rx.next().await {
let resp: serde_json::Value = serde_json::from_str(&msg.to_string()).unwrap();
assert_eq!(resp[2], "phoenix");
assert_eq!(resp[4]["status"], "ok");
}
}
/// After joining the system channel, a further frame arrives within 5 s.
///
/// NOTE(review): `datetime_handler` waits 10 s before its first tick, and the
/// frame after the join ack is likely `presence_state`. The datetime
/// assertion is therefore usually skipped; in practice the test only checks
/// that a second frame arrives.
#[tokio::test]
async fn test_ws_system_channel() {
let (addr, state, sys_chan) = setup_test_server().await;
let (mut tx, mut rx) = connect_client(&addr).await;
let token = gen_token(&state, &sys_chan, "user1").await;
let join_msg = format!(
r#"["1","ref1","{}","phx_join",{{"token":"{}"}}]"#,
sys_chan, token
);
tx.send(Message::text(join_msg)).await.unwrap();
if let Some(Ok(msg)) = rx.next().await {
let resp: serde_json::Value = serde_json::from_str(&msg.to_string()).unwrap();
assert_eq!(resp[2], sys_chan);
assert_eq!(resp[4]["status"], "ok");
}
match tokio::time::timeout(std::time::Duration::from_secs(5), rx.next()).await {
Ok(Some(Ok(msg))) => {
let resp: serde_json::Value = serde_json::from_str(&msg.to_string()).unwrap();
if resp[2] == sys_chan && resp[3] == "datetime" {
assert!(resp[4]["response"]["datetime"].is_string());
}
}
_ => panic!("Timed out waiting for datetime update"),
}
}
/// A 5-element frame with an empty-object payload decodes with `JsonValue({})` as its payload.
#[test]
fn test_ws_request_json_heartbeat() {
let msg: RequestMessage =
serde_json::from_str(r#"["1", "ref1", "room123", "heartbeat", {}]"#).unwrap();
assert_eq!(msg.join_ref, Some("1".to_string()));
assert_eq!(msg.event_ref, "ref1");
assert_eq!(msg.topic, "room123");
assert_eq!(msg.event, "heartbeat");
assert_eq!(msg.payload, serde_json::from_value(json!({})).unwrap());
}
/// A `{"token": ...}` payload decodes as `RequestPayload::Join`.
#[test]
fn test_ws_request_json_join() {
let msg: RequestMessage = serde_json::from_str(
r#"["1", "ref1", "room123", "phx_join", {"token": "secret_token"}]"#,
)
.unwrap();
assert_eq!(msg.join_ref, Some("1".to_string()));
assert_eq!(msg.event_ref, "ref1");
assert_eq!(msg.topic, "room123");
assert_eq!(msg.event, "phx_join");
assert_eq!(
msg.payload,
RequestPayload::Join {
token: "secret_token".to_string()
}
);
}
/// A `{"message": ...}` payload decodes as `RequestPayload::Message`.
#[test]
fn test_ws_request_json_message() {
let json = r#"["1", "ref4", "room123", "message", {"message": "Hello, World!"}]"#;
let msg: RequestMessage = serde_json::from_str(json).unwrap();
assert_eq!(msg.event, "message");
assert_eq!(
msg.payload,
RequestPayload::Message {
message: "Hello, World!".to_string()
}
);
}
/// The untagged `RequestPayload` picks the variant by shape:
/// `{}` becomes `JsonValue`, `{"token"}` becomes `Join`, `{"message"}` becomes `Message`.
#[test]
fn test_ws_request_json_message_payload() {
let payload: RequestPayload = serde_json::from_value(json!({})).unwrap();
assert_eq!(payload, RequestPayload::JsonValue(json!({})));
let payload: RequestPayload =
serde_json::from_value(json!({"token": "another_token"})).unwrap();
assert_eq!(
payload,
RequestPayload::Join {
token: "another_token".to_string()
}
);
let payload: RequestPayload =
serde_json::from_value(json!({ "message": "test message" })).unwrap();
assert_eq!(
payload,
RequestPayload::Message {
message: "test message".to_string()
}
);
}
/// Frame-level validation of `RequestMessage`.
///
/// Frames with fewer than 5 elements are rejected. Any JSON payload (null,
/// number, string, array) is accepted via `JsonValue`. A numeric `join_ref`
/// is rejected.
#[test]
fn test_ws_request_json_invalid() {
assert!(serde_json::from_str::<RequestMessage>(r#"["1", "ref1", "room123"]"#).is_err());
assert!(
serde_json::from_str::<RequestMessage>(r#"["1", "ref1", "room123", "phx_join"]"#)
.is_err()
);
assert!(serde_json::from_str::<RequestMessage>(
r#"["1", "ref1", "room123", "phx_join", null]"#
)
.is_ok());
assert!(serde_json::from_str::<RequestMessage>(
r#"["1", "ref1", "room123", "phx_join", 23]"#
)
.is_ok());
assert!(serde_json::from_str::<RequestMessage>(
r#"["1", "ref1", "room123", "phx_join", 12.4]"#
)
.is_ok());
assert!(serde_json::from_str::<RequestMessage>(
r#"["1", "ref1", "room123", "phx_join", "nulldirect_token"]"#
)
.is_ok());
assert!(serde_json::from_str::<RequestMessage>(
r#"["1", "ref1", "room123", "phx_join", [1, null, "foobar"]]"#
)
.is_ok());
assert!(serde_json::from_str::<RequestMessage>(
r#"[123, "ref1", "room123", "phx_join", {"token": "secret"}]"#
)
.is_err());
}
/// Reads frames from `rx` until one whose text rendering contains `content`.
///
/// Non-matching frames are discarded.
///
/// # Panics
/// - If no matching frame arrives within two seconds.
/// - If the stream yields a WebSocket error.
/// - If the stream ends (server closed the socket) before a match.
///
/// The last two fail immediately rather than waiting out the deadline.
async fn expect_msg_content(
    rx: &mut futures::stream::SplitStream<
        tokio_tungstenite::WebSocketStream<
            tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
        >,
    >,
    content: &str,
) {
    let deadline = std::time::Instant::now() + std::time::Duration::from_secs(2);
    loop {
        let remaining = deadline.saturating_duration_since(std::time::Instant::now());
        if remaining.is_zero() {
            panic!("Timeout waiting for message containing '{}'", content);
        }
        // Wait at most until the overall deadline for the next frame.
        match tokio::time::timeout(remaining, rx.next()).await {
            Ok(Some(Ok(msg))) => {
                if msg.to_string().contains(content) {
                    return;
                }
            }
            Ok(Some(Err(err))) => panic!(
                "WebSocket error while waiting for message containing '{}': {}",
                content, err
            ),
            // A finished stream resolves `next()` immediately, forever. Retrying would just
            // spin the CPU until the deadline, so fail fast with an accurate message.
            Ok(None) => panic!(
                "Stream closed before a message containing '{}' arrived",
                content
            ),
            Err(_) => panic!("Timeout waiting for message containing '{}'", content),
        }
    }
}
/// Exercises the window in which a channel has no subscribers.
///
/// 1. The only client disconnects.
/// 2. An event is published to nobody.
/// 3. A new client joins the same channel and must still receive events published afterwards.
#[tokio::test]
async fn test_redis_listener_survival_on_zero_subscribers() {
    let (addr, state, channel_name) = setup_test_server().await;
    // Grace period that lets the server observe the disconnect and process the publish.
    let settle = std::time::Duration::from_millis(100);

    // First subscriber joins and confirms Redis -> WebSocket delivery works.
    let (mut first_tx, mut first_rx) = connect_client(&addr).await;
    let first_token = gen_token(&state, &channel_name, "user_a").await;
    let first_join = format!(
        r#"["1","ref1","{}","phx_join",{{"token":"{}"}}]"#,
        channel_name, first_token
    );
    first_tx.send(Message::text(first_join)).await.unwrap();
    expect_msg_content(&mut first_rx, "phx_reply").await;

    let mut publisher = state
        .redis_client
        .get_multiplexed_async_connection()
        .await
        .unwrap();
    let redis_topic = format!("to:{}:test_event", channel_name);
    publisher
        .publish::<_, _, ()>(&redis_topic, r#"{"type":"message","message":"one"}"#)
        .await
        .unwrap();
    expect_msg_content(&mut first_rx, "one").await;

    // Disconnect the only subscriber, then publish while the channel is empty.
    drop((first_tx, first_rx));
    tokio::time::sleep(settle).await;
    publisher
        .publish::<_, _, ()>(&redis_topic, r#"{"type":"message","message":"void"}"#)
        .await
        .unwrap();
    tokio::time::sleep(settle).await;

    // A fresh subscriber must still receive newly published events.
    let (mut second_tx, mut second_rx) = connect_client(&addr).await;
    let second_token = gen_token(&state, &channel_name, "user_b").await;
    let second_join = format!(
        r#"["2","ref2","{}","phx_join",{{"token":"{}"}}]"#,
        channel_name, second_token
    );
    second_tx.send(Message::text(second_join)).await.unwrap();
    expect_msg_content(&mut second_rx, "phx_reply").await;
    publisher
        .publish::<_, _, ()>(&redis_topic, r#"{"type":"message","message":"two"}"#)
        .await
        .unwrap();
    expect_msg_content(&mut second_rx, "two").await;
}
/// Joins a topic containing a colon (`weather:wind`) and publishes a raw JSON object to
/// `to:weather:wind:update`.
///
/// Checks that the client receives an `update` event whose topic and payload are intact.
#[tokio::test]
async fn test_nested_topics_and_raw_json() {
    let (addr, state, _) = setup_test_server().await;
    let nested_channel = "weather:wind";
    let (mut tx, mut rx) = connect_client(&addr).await;
    let token = gen_token(&state, nested_channel, "user_nested").await;
    tx.send(Message::text(format!(
        r#"["1","ref1","{}","phx_join",{{"token":"{}"}}]"#,
        nested_channel, token
    )))
    .await
    .unwrap();
    expect_msg_content(&mut rx, "phx_reply").await;
    let mut redis_conn = state
        .redis_client
        .get_multiplexed_async_connection()
        .await
        .unwrap();
    let raw_topic = "to:weather:wind:update";
    let raw_payload = r#"{"temperature": 25.5, "unit": "C"}"#;
    redis_conn
        .publish::<_, _, ()>(raw_topic, raw_payload)
        .await
        .unwrap();
    let timeout = std::time::Duration::from_secs(2);
    let start = std::time::Instant::now();
    let mut found = false;
    while start.elapsed() < timeout {
        let frame =
            match tokio::time::timeout(std::time::Duration::from_millis(100), rx.next()).await {
                Ok(Some(Ok(frame))) => frame,
                // Stream ended: no further frames can arrive, so stop instead of spinning.
                Ok(None) => break,
                _ => continue,
            };
        // Control frames (e.g. pings) or other non-JSON frames are not protocol messages.
        // Skip them instead of panicking on a parse error.
        let resp: serde_json::Value = match serde_json::from_str(&frame.to_string()) {
            Ok(value) => value,
            Err(_) => continue,
        };
        if resp.get(3).and_then(|v| v.as_str()) == Some("update") {
            assert_eq!(resp[2], "weather:wind", "Topic parsed incorrectly");
            assert_eq!(resp[4]["temperature"], 25.5, "Raw JSON payload corrupted");
            found = true;
            break;
        }
    }
    assert!(found, "Did not receive 'update' event with correct payload");
}
/// A channel created by a client's join must disappear from the channel map once that
/// client disconnects.
#[tokio::test]
async fn test_ghost_channel_cleanup() {
    let (addr, state, _) = setup_test_server().await;
    let ghost_channel = format!("temp_{}", nanoid::nanoid!(4));
    let (mut tx, rx) = connect_client(&addr).await;
    let token = gen_token(&state, &ghost_channel, "ghost_user").await;
    tx.send(Message::text(format!(
        r#"["1","ref1","{}","phx_join",{{"token":"{}"}}]"#,
        ghost_channel, token
    )))
    .await
    .unwrap();

    // Poll rather than sleeping a fixed interval. The channel is registered and removed
    // asynchronously relative to this task, so fixed sleeps make the test timing-dependent.
    let poll = std::time::Duration::from_millis(10);

    let deadline = std::time::Instant::now() + std::time::Duration::from_secs(2);
    loop {
        {
            let ctl = state.ctl.lock().await;
            let channels = ctl.channels.lock().await;
            if channels.contains_key(&ghost_channel) {
                break;
            }
        }
        assert!(std::time::Instant::now() < deadline, "Channel should exist");
        tokio::time::sleep(poll).await;
    }

    drop(tx);
    drop(rx);

    let deadline = std::time::Instant::now() + std::time::Duration::from_secs(2);
    loop {
        {
            let ctl = state.ctl.lock().await;
            let channels = ctl.channels.lock().await;
            if !channels.contains_key(&ghost_channel) {
                break;
            }
        }
        assert!(
            std::time::Instant::now() < deadline,
            "Ghost channel leaked after disconnect"
        );
        tokio::time::sleep(poll).await;
    }
}
}