use axum::{
extract::{
ws::{Message, WebSocket, WebSocketUpgrade},
Extension,
},
http::StatusCode,
response::Response,
};
use futures::{
stream::{SplitSink, SplitStream},
SinkExt, StreamExt,
};
use std::{
fmt,
sync::{atomic::Ordering, Arc},
};
use tokio::sync::Mutex;
use crate::{server::State, Result};
use polysig_protocol::{zlib, MeetingRequest, MeetingResponse};
/// Shared, lock-protected handle to one websocket client.
///
/// The `tokio::sync::Mutex` lets tasks other than the connection's own reader
/// (for example a room broadcast) lock the connection and write to its socket.
pub type Connection = Arc<Mutex<WebSocketConnection>>;
/// Server-side record for a single upgraded websocket.
pub struct WebSocketConnection {
/// Id allocated from the server's connection counter when the socket is
/// upgraded; it is also the key under which the connection is stored in the
/// server's `connections` map.
pub(crate) id: u64,
/// Write half of the split websocket. Outgoing payloads normally go through
/// `WebSocketConnection::send`, which zlib-compresses them first.
pub writer: SplitSink<WebSocket, Message>,
}
impl fmt::Debug for WebSocketConnection {
    /// Renders the connection by its id only; the socket writer is
    /// intentionally left out of the debug output.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("WebSocketConnection");
        builder.field("id", &self.id);
        builder.finish()
    }
}
impl WebSocketConnection {
    /// Compresses `buffer` with zlib and writes it to the peer as a single
    /// binary websocket frame.
    ///
    /// # Errors
    ///
    /// Returns an error if compression fails or the frame cannot be written
    /// to the underlying sink.
    pub async fn send(&mut self, buffer: &[u8]) -> Result<()> {
        let frame = Message::Binary(zlib::deflate(buffer)?);
        self.writer.send(frame).await?;
        Ok(())
    }
}
/// Axum handler that upgrades an HTTP request to a websocket.
///
/// Once the upgrade completes, the socket is split. The write half is wrapped
/// in a [`WebSocketConnection`] and registered in `state.connections` under a
/// freshly allocated id. The read half is handed to a spawned `read` task
/// together with that registered connection.
///
/// # Errors
///
/// The body always returns `Ok`; the `StatusCode` error type only satisfies
/// the handler signature.
pub async fn upgrade(
    Extension(state): Extension<State>,
    ws: WebSocketUpgrade,
) -> std::result::Result<Response, StatusCode> {
    tracing::debug!("websocket upgrade request");
    let reader_state = Arc::clone(&state);

    let response = ws.on_upgrade(move |socket| {
        let (sink, stream) = socket.split();
        async move {
            // Allocate the id and register the connection under one write lock.
            let registered = {
                let mut guard = state.write().await;
                let id = guard.id.fetch_add(1, Ordering::SeqCst);
                let entry = Arc::new(Mutex::new(WebSocketConnection {
                    id,
                    writer: sink,
                }));
                guard.connections.insert(id, Arc::clone(&entry));
                entry
            };
            tokio::task::spawn(read(stream, reader_state, registered));
        }
    });

    Ok(response)
}
/// Removes `conn` from the server's connection registry.
///
/// Only `state.connections` is updated here; meeting state is not touched.
async fn disconnect(state: State, conn: Connection) {
    // The connection guard is a temporary, released as soon as the id is copied.
    let id = conn.lock().await.id;
    tracing::debug!("disconnect");
    state.write().await.connections.remove(&id);
}
async fn read(
mut receiver: SplitStream<WebSocket>,
state: State,
conn: Connection,
) -> Result<()> {
while let Some(msg) = receiver.next().await {
match msg {
Ok(msg) => match msg {
Message::Text(_) => {}
Message::Binary(buffer) => {
match zlib::inflate(&buffer) {
Ok(inflated) => {
let message: MeetingRequest =
serde_json::from_slice(&inflated)?;
if let Err(e) = handle_message(
state.clone(),
conn.clone(),
message,
)
.await
{
tracing::error!(
error = %e,
"meeting_server::handle_message_error");
}
}
Err(e) => {
tracing::error!(
error = %e,
"meeting_server::zlib_inflate");
}
}
}
Message::Ping(_) => {}
Message::Pong(_) => {}
Message::Close(_frame) => {
disconnect(state, Arc::clone(&conn)).await;
return Ok(());
}
},
Err(e) => {
tracing::warn!(error = %e,"meeting_server::read_error");
disconnect(state, Arc::clone(&conn)).await;
return Err(e.into());
}
}
}
Ok(())
}
async fn handle_message(
state: State,
conn: Connection,
message: MeetingRequest,
) -> Result<()> {
match message {
MeetingRequest::NewRoom { owner_id, slots } => {
let mut state = state.write().await;
let meeting_id = state.meetings.new_room(owner_id, slots);
let mut socket = conn.lock().await;
let response = MeetingResponse::RoomCreated {
meeting_id,
owner_id,
};
let buffer = serde_json::to_vec(&response)?;
socket.send(&buffer).await?;
}
MeetingRequest::JoinRoom {
meeting_id,
user_id,
data,
} => {
let conn_id = {
let conn = conn.lock().await;
conn.id
};
let is_full = {
let mut state = state.write().await;
if let Some(meeting) =
state.meetings.room_mut(&meeting_id)
{
meeting.join(user_id, conn_id, data);
meeting.is_full()
} else {
tracing::warn!(id = %meeting_id, "no meeting");
false
}
};
let result = if is_full {
let mut state = state.write().await;
if let Some(meeting) =
state.meetings.remove_room(&meeting_id)
{
let mut participants =
Vec::with_capacity(meeting.slots.len());
let mut sockets =
Vec::with_capacity(meeting.slots.len());
for (user_id, value) in meeting.slots {
let (conn_id, data) = value.unwrap();
participants.push((user_id, data));
sockets.push(conn_id);
}
Some((sockets, participants))
} else {
None
}
} else {
None
};
if let Some((sockets, participants)) = result {
let message =
MeetingResponse::RoomReady { participants };
let buffer = serde_json::to_vec(&message)?;
let state = state.read().await;
for conn_id in sockets {
if let Some(conn) =
state.connections.get(&conn_id)
{
let mut conn = conn.lock().await;
conn.send(&buffer).await?;
}
}
}
}
}
Ok(())
}