use acton_service::prelude::*;
use acton_service::websocket::{Broadcaster, ConnectionId, Message, WebSocket, WebSocketUpgrade};
use futures::{SinkExt, StreamExt};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
/// A frame sent by a client, decoded from a JSON text message.
///
/// Internally tagged: a `"type"` field carries the snake_case variant name,
/// e.g. `{"type": "join", "room": "lobby"}` or
/// `{"type": "message", "room": "lobby", "content": "hi"}`.
#[derive(Deserialize, Debug)]
#[serde(rename_all = "snake_case", tag = "type")]
enum IncomingMessage {
    /// Ask to join `room`.
    Join {
        room: String,
    },
    /// Ask to leave `room`.
    Leave {
        room: String,
    },
    /// Chat text `content` addressed to `room`.
    Message {
        room: String,
        content: String,
    },
}
/// A frame sent to a client, encoded as a JSON text message.
///
/// Internally tagged the same way as the incoming frames: a `"type"` field
/// holds the snake_case variant name alongside the variant's own fields.
#[derive(Serialize, Debug)]
#[serde(rename_all = "snake_case", tag = "type")]
enum OutgoingMessage {
    /// Acknowledges a join request for `room`.
    Joined {
        room: String,
    },
    /// Acknowledges a leave request for `room`.
    Left {
        room: String,
    },
    /// A chat message; `from` identifies the sending connection.
    Message {
        room: String,
        content: String,
        from: String,
    },
    /// Reports a problem with something the client sent.
    Error {
        message: String,
    },
    /// Server-originated informational text.
    System {
        message: String,
    },
}
async fn ws_handler(
ws: WebSocketUpgrade,
Extension(broadcaster): Extension<Arc<Broadcaster>>,
) -> impl IntoResponse {
ws.on_upgrade(move |socket| handle_socket(socket, broadcaster))
}
async fn handle_socket(socket: WebSocket, broadcaster: Arc<Broadcaster>) {
let (mut sender, mut receiver) = socket.split();
let connection_id = ConnectionId::new();
let (tx, mut rx) = tokio::sync::mpsc::channel::<Message>(32);
broadcaster.register(connection_id, tx.clone()).await;
tracing::info!(connection_id = %connection_id, "New WebSocket connection");
let welcome = OutgoingMessage::System {
message: format!("Welcome! Your connection ID is {}", connection_id),
};
let _ = sender
.send(Message::Text(
serde_json::to_string(&welcome).unwrap().into(),
))
.await;
let send_task = tokio::spawn(async move {
while let Some(msg) = rx.recv().await {
if sender.send(msg).await.is_err() {
break;
}
}
});
let broadcaster_clone = broadcaster.clone();
let conn_id_str = connection_id.to_string();
while let Some(result) = receiver.next().await {
match result {
Ok(Message::Text(text)) => match serde_json::from_str::<IncomingMessage>(&text) {
Ok(msg) => {
handle_incoming_message(
msg,
connection_id,
&conn_id_str,
&tx,
&broadcaster_clone,
)
.await;
}
Err(e) => {
let error = OutgoingMessage::Error {
message: format!("Invalid message format: {}", e),
};
let _ = tx
.send(Message::Text(serde_json::to_string(&error).unwrap().into()))
.await;
}
},
Ok(Message::Ping(data)) => {
let _ = tx.send(Message::Pong(data)).await;
}
Ok(Message::Close(_)) => {
break;
}
Err(e) => {
tracing::warn!(connection_id = %connection_id, error = %e, "WebSocket error");
break;
}
_ => {}
}
}
broadcaster.unregister(&connection_id).await;
send_task.abort();
tracing::info!(connection_id = %connection_id, "WebSocket connection closed");
}
async fn handle_incoming_message(
msg: IncomingMessage,
connection_id: ConnectionId,
conn_id_str: &str,
tx: &tokio::sync::mpsc::Sender<Message>,
broadcaster: &Broadcaster,
) {
match msg {
IncomingMessage::Join { room } => {
tracing::info!(connection_id = %connection_id, room = %room, "User joining room");
let response = OutgoingMessage::Joined { room: room.clone() };
let _ = tx
.send(Message::Text(
serde_json::to_string(&response).unwrap().into(),
))
.await;
let notification = OutgoingMessage::System {
message: format!("User {} joined the room", conn_id_str),
};
let _ = broadcaster
.broadcast_except(
&[connection_id],
Message::Text(serde_json::to_string(¬ification).unwrap().into()),
)
.await;
}
IncomingMessage::Leave { room } => {
tracing::info!(connection_id = %connection_id, room = %room, "User leaving room");
let response = OutgoingMessage::Left { room };
let _ = tx
.send(Message::Text(
serde_json::to_string(&response).unwrap().into(),
))
.await;
}
IncomingMessage::Message { room, content } => {
tracing::debug!(
connection_id = %connection_id,
room = %room,
"Broadcasting message"
);
let broadcast_msg = OutgoingMessage::Message {
room,
content,
from: conn_id_str.to_string(),
};
let _ = broadcaster
.broadcast_except(
&[connection_id],
Message::Text(serde_json::to_string(&broadcast_msg).unwrap().into()),
)
.await;
let _ = tx
.send(Message::Text(
serde_json::to_string(&broadcast_msg).unwrap().into(),
))
.await;
}
}
}
/// Starts the chat server.
///
/// Initializes INFO-level tracing and creates one shared [`Broadcaster`]. It
/// mounts the WebSocket handler at `/ws` under base path `/api`, API version
/// V1, with the broadcaster injected as a request `Extension`. It then serves
/// the routes as service `chat-server`.
///
/// # Errors
/// Returns any error produced by `serve()`.
#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // Single source of truth for the listening port. The log lines below are
    // derived from it so they cannot drift from the configured value.
    let port = 8080;

    tracing_subscriber::fmt()
        .with_max_level(tracing::Level::INFO)
        .init();

    let broadcaster = Arc::new(Broadcaster::new());
    let routes = VersionedApiBuilder::new()
        .with_base_path("/api")
        .add_version(ApiVersion::V1, |router| {
            router
                .route("/ws", get(ws_handler))
                .layer(Extension(broadcaster.clone()))
        })
        .build_routes();

    let mut config = Config::<()>::default();
    config.service.name = "chat-server".to_string();
    config.service.port = port;

    tracing::info!("Starting chat server on http://localhost:{}", port);
    tracing::info!("Connect via WebSocket at ws://localhost:{}/api/v1/ws", port);

    ServiceBuilder::new()
        .with_config(config)
        .with_routes(routes)
        .build()
        .serve()
        .await?;
    Ok(())
}