use super::{
authenticate_endpoint, parse_account_id, Caller, ConnectionQuery,
};
use crate::{Result, ServerBackend, ServerState};
use axum::{
extract::{
ws::{Message, WebSocket, WebSocketUpgrade},
Extension, OriginalUri, Query,
},
http::{HeaderMap, StatusCode},
response::Response,
};
use axum_extra::{
headers::{authorization::Bearer, Authorization},
typed_header::TypedHeader,
};
use futures::{
stream::{SplitSink, SplitStream},
SinkExt, StreamExt,
};
use sos_core::AccountId;
use std::{collections::HashMap, sync::Arc};
use tokio::sync::{broadcast, watch};
#[derive(Clone)]
pub struct WebSocketConnection {
send_tx: broadcast::Sender<Vec<u8>>,
close_tx: watch::Sender<Message>,
}
#[derive(Default)]
pub struct WebSocketAccount {
connections: HashMap<String, WebSocketConnection>,
}
impl WebSocketAccount {
pub async fn broadcast(
&self,
caller: &Caller,
message: Vec<u8>,
) -> Result<()> {
for (conn_id, conn) in &self.connections {
if conn_id.is_empty()
|| caller.connection_id().is_none()
|| conn_id == caller.connection_id().as_ref().unwrap()
{
continue;
}
if conn.send_tx.receiver_count() > 0 {
let _ = conn.send_tx.send(message.clone());
}
}
Ok(())
}
pub fn len(&self) -> usize {
self.connections.len()
}
}
pub async fn upgrade(
Extension(state): Extension<ServerState>,
Extension(backend): Extension<ServerBackend>,
TypedHeader(bearer): TypedHeader<Authorization<Bearer>>,
Query(query): Query<ConnectionQuery>,
OriginalUri(uri): OriginalUri,
headers: HeaderMap,
ws: WebSocketUpgrade,
) -> std::result::Result<Response, StatusCode> {
tracing::debug!("ws_server::upgrade_request");
let uri = uri.path().to_string();
let account_id = parse_account_id(&headers);
let caller = authenticate_endpoint(
account_id,
bearer,
uri.as_bytes(),
Some(query.clone()),
Arc::clone(&state),
Arc::clone(&backend),
)
.await
.map_err(|_| StatusCode::BAD_REQUEST)?;
let account_id = *caller.account_id();
let connection_id = query.connection_id;
let (close_tx, _) = watch::channel(Message::Close(None));
let (send_tx, _) = broadcast::channel(64);
let conn = WebSocketConnection { send_tx, close_tx };
{
let mut writer = state.write().await;
let account = writer
.sockets
.entry(account_id)
.or_insert(Default::default());
if account.connections.get(&connection_id).is_some() {
return Err(StatusCode::CONFLICT);
}
account
.connections
.insert(connection_id.clone(), conn.clone());
}
Ok(ws.on_upgrade(move |socket| {
handle_socket(socket, state, account_id, connection_id, conn)
}))
}
async fn disconnect(
state: ServerState,
account_id: &AccountId,
connection_id: &str,
) {
let mut writer = state.write().await;
tracing::debug!(account_id = %account_id, "ws_server::disconnect");
if let Some(account) = writer.sockets.get_mut(account_id) {
tracing::debug!(
account_id = %account_id,
"ws_server::disconnect::remove_socket",
);
if let Some(conn) = account.connections.remove(connection_id) {
tracing::info!(
account_id = %account_id,
connection_id = %connection_id,
"ws_server::disconnect",
);
if let Err(error) = conn.close_tx.send(Message::Close(None)) {
tracing::warn!(error = ?error);
}
}
};
}
async fn handle_socket(
socket: WebSocket,
state: ServerState,
account_id: AccountId,
connection_id: String,
conn: WebSocketConnection,
) {
tracing::info!(
account_id = %account_id,
connection_id = %connection_id,
"ws_server::connect",
);
let (writer, reader) = socket.split();
tokio::spawn(write(
Arc::clone(&state),
account_id,
connection_id.clone(),
writer,
conn.clone(),
));
tokio::spawn(read(
Arc::clone(&state),
account_id,
connection_id,
reader,
conn,
));
}
async fn read(
state: ServerState,
account_id: AccountId,
connection_id: String,
mut stream: SplitStream<WebSocket>,
conn: WebSocketConnection,
) -> Result<()> {
while let Some(msg) = stream.next().await {
match msg {
Ok(msg) => match msg {
Message::Text(_) => {}
Message::Binary(_) => {}
Message::Ping(_) => {}
Message::Pong(_) => {}
Message::Close(frame) => {
if let Err(error) =
conn.close_tx.send(Message::Close(frame))
{
tracing::warn!(error = ?error);
}
tracing::trace!(
account_id = %account_id,
"ws_server::disconnect::close_message",
);
disconnect(state, &account_id, &connection_id).await;
return Ok(());
}
},
Err(e) => {
tracing::trace!(
account_id = %account_id,
"ws_server::disconnect::read_error",
);
disconnect(state, &account_id, &connection_id).await;
return Err(e.into());
}
}
}
Ok(())
}
async fn write(
state: ServerState,
account_id: AccountId,
connection_id: String,
mut sink: SplitSink<WebSocket, Message>,
conn: WebSocketConnection,
) -> Result<()> {
let mut close_rx = conn.close_tx.subscribe();
let mut outgoing = conn.send_tx.subscribe();
loop {
tokio::select! {
_ = close_rx.changed() => {
let msg = close_rx.borrow().clone();
let _ = sink.send(msg).await;
break;
}
event = outgoing.recv() => {
match event {
Ok(msg) => {
if sink.send(Message::Binary(msg.into())).await.is_err() {
tracing::trace!(
account_id = %account_id,
"ws_server::disconnect::send_error",
);
disconnect(
state,
&account_id,
&connection_id,
).await;
break;
}
}
_ => {}
}
},
}
}
Ok(())
}