use crate::container::App;
use bytes::Bytes;
use ferro_broadcast::{Broadcaster, ClientMessage, PresenceMember, ServerMessage};
use futures_util::{SinkExt, StreamExt};
use http_body_util::Full;
use hyper_tungstenite::tungstenite::Message as WsMsg;
use tokio::sync::mpsc;
use tokio::time::Instant;
use uuid::Uuid;
/// Upgrades an incoming HTTP request to a WebSocket connection.
///
/// Responses:
/// * `503` with body `"Broadcasting not configured"` when
///   `App::get::<Broadcaster>()` yields nothing.
/// * `400` with body `"WebSocket upgrade failed"` when
///   `hyper_tungstenite::upgrade` rejects the request. The underlying error is
///   also logged to stderr.
/// * Otherwise, the handshake response produced by `hyper_tungstenite::upgrade`.
///
/// The upgraded socket is awaited and serviced on a separately spawned Tokio
/// task, so this function returns without waiting for the connection to
/// finish.
pub(crate) fn handle_ws_upgrade(
    mut req: hyper::Request<hyper::body::Incoming>,
) -> hyper::Response<Full<Bytes>> {
    // Builds a plain-text error response. The only statuses passed in are
    // fixed, valid codes, so the builder is not expected to fail.
    fn plain_error(status: u16, body: &'static str) -> hyper::Response<Full<Bytes>> {
        hyper::Response::builder()
            .status(status)
            .body(Full::new(Bytes::from(body)))
            .unwrap()
    }

    let broadcaster = match App::get::<Broadcaster>() {
        Some(found) => found,
        None => return plain_error(503, "Broadcasting not configured"),
    };

    let (response, pending_socket) = match hyper_tungstenite::upgrade(&mut req, None) {
        Ok(handshake) => handshake,
        Err(e) => {
            eprintln!("WebSocket upgrade failed: {e}");
            return plain_error(400, "WebSocket upgrade failed");
        }
    };

    // The handshake response must be returned before the socket becomes
    // usable, so the connection is completed and driven in the background.
    tokio::spawn(async move {
        match pending_socket.await {
            Ok(ws_stream) => handle_connection(ws_stream, broadcaster).await,
            Err(e) => eprintln!("WebSocket connection failed: {e}"),
        }
    });

    response
}
/// Services one upgraded WebSocket until it closes, errors, or times out.
///
/// Setup:
/// 1. The socket is registered with `broadcaster` under a freshly generated
///    UUID.
/// 2. The client is sent a `ServerMessage::Connected` carrying that id, on a
///    best-effort basis.
///
/// The main loop then waits on three sources at once:
/// * **Client frames.** Any frame refreshes the activity timestamp. Text
///   frames are handed to `handle_client_message`. A `Close` frame, a read
///   error, or the end of the stream ends the loop.
/// * **Outbound messages.** These arrive on the mpsc channel registered with
///   the broadcaster and are forwarded to the client. A closed channel or a
///   failed write ends the loop.
/// * **Heartbeat ticks.** Each tick ends the loop if the client has been
///   silent longer than `client_timeout`; otherwise it sends a `Ping`. A
///   failed write ends the loop.
///
/// Teardown: a `Close` frame is attempted (errors ignored) and the socket is
/// removed from the broadcaster.
async fn handle_connection(
    ws_stream: hyper_tungstenite::HyperWebsocketStream,
    broadcaster: Broadcaster,
) {
    let socket_id = Uuid::new_v4().to_string();
    let (mut sink, mut stream) = ws_stream.split();

    // Bounded queue (capacity 32) through which the broadcaster reaches this socket.
    let (outbox_tx, mut outbox) = mpsc::channel::<ServerMessage>(32);
    broadcaster.add_client(socket_id.clone(), outbox_tx);

    let greeting = ServerMessage::Connected {
        socket_id: socket_id.clone(),
    };
    if let Ok(greeting_frame) = greeting.to_ws_message() {
        let _ = sink.send(greeting_frame).await;
    }

    let config = broadcaster.config().clone();
    // A tokio interval completes its first tick immediately, so the first
    // Ping goes out as soon as the loop starts.
    let mut heartbeat = tokio::time::interval(config.heartbeat_interval);
    let mut last_seen = Instant::now();

    loop {
        tokio::select! {
            incoming = stream.next() => {
                match incoming {
                    None | Some(Err(_)) => break,
                    Some(Ok(WsMsg::Close(_))) => break,
                    Some(Ok(WsMsg::Text(text))) => {
                        last_seen = Instant::now();
                        handle_client_message(&text, &socket_id, &broadcaster, &mut sink).await;
                    }
                    // Pongs and every other frame kind only prove the client is alive.
                    Some(Ok(_)) => last_seen = Instant::now(),
                }
            }
            queued = outbox.recv() => {
                // `None` means the channel is closed: every sender was dropped.
                let encoded = match queued {
                    Some(server_msg) => server_msg.to_ws_message(),
                    None => break,
                };
                // Messages that fail to serialize are skipped; a failed write ends the session.
                if let Ok(outgoing_frame) = encoded {
                    if sink.send(outgoing_frame).await.is_err() {
                        break;
                    }
                }
            }
            _ = heartbeat.tick() => {
                let idle_too_long = last_seen.elapsed() > config.client_timeout;
                // Short-circuit: an idle client is dropped without being pinged.
                if idle_too_long || sink.send(WsMsg::Ping(vec![].into())).await.is_err() {
                    break;
                }
            }
        }
    }

    let _ = sink.send(WsMsg::Close(None)).await;
    broadcaster.remove_client(&socket_id);
}
async fn handle_client_message<S>(
text: &str,
socket_id: &str,
broadcaster: &Broadcaster,
ws_write: &mut S,
) where
S: SinkExt<WsMsg, Error = hyper_tungstenite::tungstenite::Error> + Unpin,
{
let client_msg = match ClientMessage::from_ws_text(text) {
Ok(msg) => msg,
Err(_) => {
let err = ServerMessage::Error {
message: "Invalid message format".into(),
};
if let Ok(ws_msg) = err.to_ws_message() {
let _ = ws_write.send(ws_msg).await;
}
return;
}
};
match client_msg {
ClientMessage::Subscribe {
channel,
auth,
channel_data,
} => {
let member_info = if channel.starts_with("presence-") {
channel_data.and_then(|data| {
let user_id = data.get("user_id")?.as_str()?.to_string();
let user_info = data
.get("user_info")
.cloned()
.unwrap_or(serde_json::Value::Null);
Some(PresenceMember::new(socket_id, user_id).with_info(user_info))
})
} else {
None
};
match broadcaster
.subscribe(socket_id, &channel, auth.as_deref(), member_info)
.await
{
Ok(()) => {
let msg = ServerMessage::Subscribed { channel };
if let Ok(ws_msg) = msg.to_ws_message() {
let _ = ws_write.send(ws_msg).await;
}
}
Err(e) => {
let msg = ServerMessage::SubscriptionError {
channel,
error: e.to_string(),
};
if let Ok(ws_msg) = msg.to_ws_message() {
let _ = ws_write.send(ws_msg).await;
}
}
}
}
ClientMessage::Unsubscribe { channel } => {
broadcaster.unsubscribe(socket_id, &channel).await;
let msg = ServerMessage::Unsubscribed { channel };
if let Ok(ws_msg) = msg.to_ws_message() {
let _ = ws_write.send(ws_msg).await;
}
}
ClientMessage::Whisper {
channel,
event,
data,
} => {
if let Err(e) = broadcaster.whisper(socket_id, &channel, &event, data).await {
let msg = ServerMessage::Error {
message: e.to_string(),
};
if let Ok(ws_msg) = msg.to_ws_message() {
let _ = ws_write.send(ws_msg).await;
}
}
}
ClientMessage::Ping => {
let msg = ServerMessage::Pong;
if let Ok(ws_msg) = msg.to_ws_message() {
let _ = ws_write.send(ws_msg).await;
}
}
}
}