use std::collections::{HashMap, HashSet};
use std::net::SocketAddr;
use std::path::PathBuf;
use std::sync::{
atomic::{AtomicUsize, Ordering},
Arc,
};
use futures_util::{SinkExt, StreamExt, TryFutureExt};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use thiserror::Error;
use tokio::sync::{mpsc, Mutex, RwLock};
use tokio_stream::wrappers::UnboundedReceiverStream;
use uuid::Uuid;
use warp::http::header::{HeaderMap, HeaderValue};
use warp::ws::{Message, WebSocket};
use warp::Filter;
use crate::services::*;
use json_rpc2::{Request, Response};
use tracing_subscriber::fmt::format::FmtSpan;
/// Counter that hands out websocket connection identifiers.
///
/// Starts at 1; each accepted websocket takes the next value via
/// `fetch_add` in `client_connected`.
static CONNECTION_ID: AtomicUsize = AtomicUsize::new(1);
/// Errors produced by the websocket server and its session bookkeeping.
#[derive(Debug, Error)]
pub enum ServerError {
    /// The static files path given to the server is not a directory.
    #[error("{0} is not a directory")]
    NotDirectory(PathBuf),
    /// A party number of zero was supplied; party numbers start at 1.
    #[error("party number may not be zero")]
    ZeroPartyNumber,
    /// A party number greater than the configured number of parties.
    #[error("party number is out of range")]
    PartyNumberOutOfRange,
    /// The party number is already taken in the session with this id.
    #[error("party number already exists for session {0}")]
    PartyNumberAlreadyExists(Uuid),
    /// A socket address could not be parsed.
    #[error(transparent)]
    NetAddrParse(#[from] std::net::AddrParseError),
    /// An I/O error (for example while canonicalizing a path).
    #[error(transparent)]
    Io(#[from] std::io::Error),
    /// An error raised by the JSON-RPC layer.
    #[error(transparent)]
    JsonRpcError(#[from] json_rpc2::Error),
}

/// Result alias whose error type is [`ServerError`].
pub type Result<T> = std::result::Result<T, ServerError>;
/// Party configuration shared by every session of a group.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Parameters {
    /// Total number of parties; also the largest party number that
    /// `Session::load` accepts.
    pub parties: u16,
    /// Protocol threshold. Not interpreted in this file; its exact meaning
    /// is defined by the protocol the clients run.
    pub threshold: u16,
}

impl Default for Parameters {
    /// Three parties with a threshold of one.
    fn default() -> Self {
        Self {
            parties: 3,
            threshold: 1,
        }
    }
}
/// What a session is used for.
///
/// Serialized as the strings `"keygen"` and `"sign"`.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub enum SessionKind {
    /// Key generation.
    #[serde(rename = "keygen")]
    Keygen,
    /// Signing.
    #[serde(rename = "sign")]
    Sign,
}

impl Default for SessionKind {
    /// Sessions are key generation sessions unless stated otherwise.
    fn default() -> Self {
        Self::Keygen
    }
}
#[derive(Debug, Default, Clone, Serialize)]
pub struct Group {
pub uuid: Uuid,
pub params: Parameters,
pub label: String,
#[serde(skip)]
pub(crate) clients: Vec<usize>,
#[serde(skip)]
pub(crate) sessions: HashMap<Uuid, Session>,
}
impl Group {
pub fn new(conn: usize, params: Parameters, label: String) -> Self {
Self {
uuid: Uuid::new_v4(),
clients: vec![conn],
sessions: Default::default(),
params,
label,
}
}
}
#[derive(Debug, Clone, Serialize)]
pub struct Session {
pub uuid: Uuid,
pub kind: SessionKind,
pub value: Option<Value>,
#[serde(skip)]
pub(crate) party_signups: Vec<(u16, usize)>,
#[serde(skip)]
pub(crate) finished: HashSet<u16>,
}
impl Default for Session {
fn default() -> Self {
Self {
uuid: Uuid::new_v4(),
kind: Default::default(),
party_signups: Default::default(),
finished: Default::default(),
value: None,
}
}
}
impl From<(SessionKind, Option<Value>)> for Session {
fn from(value: (SessionKind, Option<Value>)) -> Session {
Self {
uuid: Uuid::new_v4(),
kind: value.0,
party_signups: Default::default(),
finished: Default::default(),
value: value.1,
}
}
}
impl Session {
    /// Sign up connection `conn` and allocate it a new party number.
    ///
    /// The number is one greater than the largest party number already
    /// registered in this session, or `1` for an empty session. Using the
    /// maximum (rather than the most recently pushed entry) keeps the new
    /// number unique even when [`Session::load`] registered parties out of
    /// order.
    pub fn signup(&mut self, conn: usize) -> u16 {
        let num = self
            .party_signups
            .iter()
            .map(|(num, _)| *num)
            .max()
            .map_or(1, |highest| highest + 1);
        self.party_signups.push((num, conn));
        num
    }

    /// Register connection `conn` under an explicit `party_number`.
    ///
    /// # Errors
    ///
    /// * [`ServerError::ZeroPartyNumber`] if `party_number` is zero.
    /// * [`ServerError::PartyNumberOutOfRange`] if `party_number` exceeds
    ///   `parameters.parties`.
    /// * [`ServerError::PartyNumberAlreadyExists`] if this session already
    ///   has a signup with `party_number`.
    pub fn load(
        &mut self,
        parameters: &Parameters,
        conn: usize,
        party_number: u16,
    ) -> Result<()> {
        if party_number == 0 {
            return Err(ServerError::ZeroPartyNumber);
        }
        if party_number > parameters.parties {
            return Err(ServerError::PartyNumberOutOfRange);
        }
        if self
            .party_signups
            .iter()
            .any(|(num, _)| *num == party_number)
        {
            return Err(ServerError::PartyNumberAlreadyExists(self.uuid));
        }
        self.party_signups.push((party_number, conn));
        Ok(())
    }
}
/// Shared server state: per-connection outgoing channels and groups.
///
/// The server keeps it behind an `Arc<RwLock<State>>`.
#[derive(Debug)]
pub struct State {
    /// Outgoing message channel for each connected client, keyed by
    /// connection id.
    pub clients: HashMap<usize, mpsc::UnboundedSender<Message>>,
    /// Groups keyed by group id.
    pub groups: HashMap<Uuid, Group>,
}
/// A deferred fan-out of responses to connections.
///
/// `rpc_request` hands the service an empty slot for one of these and,
/// after the direct response (if any) has been sent, dispatches whatever
/// the slot holds through `rpc_notify`.
#[derive(Debug)]
pub enum Notification {
    /// Nothing to send.
    Noop,
    /// Send `response` to every client of group `group_id`, except the
    /// connection ids listed in `filter`.
    Group {
        group_id: Uuid,
        filter: Option<Vec<usize>>,
        response: Response,
    },
    /// Send `response` to every connection signed up to session
    /// `session_id` of group `group_id`, except the ids listed in `filter`.
    Session {
        group_id: Uuid,
        session_id: Uuid,
        filter: Option<Vec<usize>>,
        response: Response,
    },
    /// Send each `(connection_id, response)` pair to its connection.
    Relay {
        messages: Vec<(usize, Response)>,
    },
}
impl Default for Notification {
    /// Defaults to [`Notification::Noop`].
    fn default() -> Self {
        Self::Noop
    }
}
pub struct Server;
impl Server {
pub async fn start(
path: &'static str,
addr: impl Into<SocketAddr>,
static_files: PathBuf,
) -> Result<()> {
let filter = std::env::var("RUST_LOG").unwrap_or_else(|_| {
"tracing=info,warp=debug,mpc_websocket=info".to_owned()
});
if cfg!(debug_assertions) {
tracing_subscriber::fmt()
.with_env_filter(filter)
.with_span_events(FmtSpan::CLOSE)
.init();
} else {
tracing_subscriber::fmt()
.with_env_filter(filter)
.with_span_events(FmtSpan::CLOSE)
.json()
.init();
}
let state = Arc::new(RwLock::new(State {
clients: HashMap::new(),
groups: Default::default(),
}));
let state = warp::any().map(move || state.clone());
if !static_files.is_dir() {
return Err(ServerError::NotDirectory(static_files));
}
let static_files = static_files.canonicalize()?;
let static_path = static_files.to_string_lossy().into_owned();
tracing::info!(%static_path);
tracing::info!(path);
let client = warp::any().and(warp::fs::dir(static_files));
let mut headers = HeaderMap::new();
headers.insert(
"Cross-Origin-Embedder-Policy",
HeaderValue::from_static("require-corp"),
);
headers.insert(
"Cross-Origin-Opener-Policy",
HeaderValue::from_static("same-origin"),
);
let websocket = warp::path(path).and(warp::ws()).and(state).map(
|ws: warp::ws::Ws, state| {
ws.on_upgrade(move |socket| client_connected(socket, state))
},
);
let routes = websocket
.or(client)
.with(warp::reply::with::headers(headers))
.with(warp::trace::request());
warp::serve(routes).run(addr).await;
Ok(())
}
}
/// Drive one upgraded websocket connection until it ends.
///
/// Assigns a connection id and spawns a writer task that forwards
/// messages from an unbounded channel to the socket. After each sent
/// message, the writer closes the socket and stops if the shared close
/// flag is set. The channel's sender is registered in `state.clients`.
/// Incoming frames are handled until the stream ends or errors, and then
/// the client is removed from the shared state.
async fn client_connected(ws: WebSocket, state: Arc<RwLock<State>>) {
    let conn_id = CONNECTION_ID.fetch_add(1, Ordering::Relaxed);
    tracing::info!(conn_id, "connected");

    let (mut sink, mut stream) = ws.split();
    let (outgoing_tx, outgoing_rx) = mpsc::unbounded_channel::<Message>();
    let mut outgoing = UnboundedReceiverStream::new(outgoing_rx);

    // Set by the request handler when this connection should be closed.
    let mut close_flag = Arc::new(RwLock::new(false));
    let writer_close_flag = Arc::clone(&close_flag);

    tokio::task::spawn(async move {
        while let Some(message) = outgoing.next().await {
            sink.send(message)
                .unwrap_or_else(|e| {
                    tracing::error!(?e, "websocket send error");
                })
                .await;
            // The read guard is kept while closing, as before.
            let closing = writer_close_flag.read().await;
            if !*closing {
                continue;
            }
            if let Err(e) = sink.close().await {
                tracing::warn!(?e, "failed to close websocket");
            }
            break;
        }
    });

    state.write().await.clients.insert(conn_id, outgoing_tx);

    loop {
        let msg = match stream.next().await {
            None => break,
            Some(Ok(msg)) => msg,
            Some(Err(e)) => {
                tracing::error!(conn_id, ?e, "websocket rx error");
                break;
            }
        };
        client_incoming_message(conn_id, &mut close_flag, msg, &state).await;
    }

    client_disconnected(conn_id, &state).await;
}
async fn client_incoming_message(
conn_id: usize,
close_flag: &mut Arc<RwLock<bool>>,
msg: Message,
state: &Arc<RwLock<State>>,
) {
let msg = if let Ok(s) = msg.to_str() {
s
} else {
return;
};
match json_rpc2::from_str(msg) {
Ok(req) => rpc_request(conn_id, close_flag, req, state).await,
Err(e) => tracing::warn!(conn_id, ?e, "websocket rx JSON error"),
}
}
/// Serve one JSON-RPC request received on connection `conn_id`.
///
/// The request is passed to `ServiceHandler` with
/// `(conn_id, state, notification slot)` as service data. The response is
/// sent back to the requesting connection. If that response is an error
/// whose `data` equals `CLOSE_CONNECTION`, the connection's close flag is
/// set. Finally, any notification the handler left in the slot is
/// dispatched through `rpc_notify`.
async fn rpc_request(
    conn_id: usize,
    close_flag: &mut Arc<RwLock<bool>>,
    request: Request,
    state: &Arc<RwLock<State>>,
) {
    // json_rpc2's async `Server`/`Service` types, used in this function only.
    use json_rpc2::futures::*;
    let service: Box<
        dyn Service<
            Data = (
                usize,
                Arc<RwLock<State>>,
                Arc<Mutex<Option<Notification>>>,
            ),
        >,
    > = Box::new(ServiceHandler {});
    let server = Server::new(vec![&service]);

    // Slot the service handler may fill with a follow-up fan-out.
    let notification: Arc<Mutex<Option<Notification>>> =
        Arc::new(Mutex::new(None));

    if let Some(response) = server
        .serve(
            &request,
            &(conn_id, Arc::clone(state), Arc::clone(&notification)),
        )
        .await
    {
        rpc_response(conn_id, &response, state).await;
        if let Some(error) = response.error() {
            if let Some(data) = &error.data {
                if data == CLOSE_CONNECTION {
                    // NOTE(review): the writer task in `client_connected`
                    // reads this flag only after it finishes sending a
                    // message. The response above is queued before the flag
                    // is set, so if the writer already sent it and read the
                    // flag, the socket stays open until another message is
                    // queued. Confirm whether that is acceptable.
                    let mut writer = close_flag.write().await;
                    *writer = true;
                }
            }
        }
    }

    // Take the notification and release the mutex before dispatching it.
    let pending = notification.lock().await.take();
    if let Some(notification) = pending {
        rpc_notify(state, notification).await;
    }
}
/// Return `clients` without any connection id listed in `filter`.
///
/// `filter` is an exclusion list; `None` returns `clients` unchanged. The
/// relative order of the remaining ids is preserved, and filtering happens
/// in place without allocating a new vector.
fn filter_clients(
    clients: Vec<usize>,
    filter: Option<Vec<usize>>,
) -> Vec<usize> {
    let mut clients = clients;
    if let Some(excluded) = filter {
        clients.retain(|conn| !excluded.contains(conn));
    }
    clients
}
/// Deliver a [`Notification`] to its recipients.
///
/// * `Group`: every client of the group, minus the ids in `filter`.
/// * `Session`: every connection signed up to the session, minus the ids
///   in `filter`.
/// * `Relay`: each message to its own connection.
/// * `Noop`: nothing.
///
/// Recipients are resolved under the state read lock, and that lock is
/// released before anything is sent. `rpc_response` takes its own read
/// lock, and tokio's `RwLock` stops granting read locks while a writer is
/// queued. Holding one read guard while acquiring another could therefore
/// deadlock against a concurrent `state.write()`.
///
/// A missing group or session is logged, and nothing is sent.
async fn rpc_notify(state: &Arc<RwLock<State>>, notification: Notification) {
    let (clients, filter, response) = match notification {
        Notification::Noop => return,
        Notification::Relay { messages } => {
            for (conn_id, response) in messages {
                rpc_response(conn_id, &response, state).await;
            }
            return;
        }
        Notification::Group {
            group_id,
            filter,
            response,
        } => {
            // `reader` is dropped at the end of this arm, before sending.
            let reader = state.read().await;
            let clients: Vec<usize> = match reader.groups.get(&group_id) {
                Some(group) => group.clients.clone(),
                None => {
                    tracing::warn!(
                        %group_id,
                        "notification group does not exist");
                    Vec::new()
                }
            };
            (clients, filter, response)
        }
        Notification::Session {
            group_id,
            session_id,
            filter,
            response,
        } => {
            // `reader` is dropped at the end of this arm, before sending.
            let reader = state.read().await;
            let clients: Vec<usize> = match reader.groups.get(&group_id) {
                Some(group) => match group.sessions.get(&session_id) {
                    Some(session) => session
                        .party_signups
                        .iter()
                        .map(|(_, conn)| *conn)
                        .collect(),
                    None => {
                        tracing::warn!(
                            %session_id,
                            "notification session does not exist");
                        Vec::new()
                    }
                },
                None => {
                    tracing::warn!(
                        %group_id,
                        "notification group does not exist");
                    Vec::new()
                }
            };
            (clients, filter, response)
        }
    };

    for conn_id in filter_clients(clients, filter) {
        rpc_response(conn_id, &response, state).await;
    }
}
async fn rpc_response(
conn_id: usize,
response: &json_rpc2::Response,
state: &Arc<RwLock<State>>,
) {
tracing::debug!(conn_id, "send message");
if let Some(tx) = state.read().await.clients.get(&conn_id) {
tracing::debug!(?response, "send response");
let msg = serde_json::to_string(response).unwrap();
if let Err(_disconnected) = tx.send(Message::text(msg)) {
}
} else {
tracing::warn!(conn_id, "could not find tx for websocket");
}
}
/// Remove a disconnected client from the shared state.
///
/// Drops the connection's outgoing channel sender, removes every
/// occurrence of the connection from every group, and deletes groups left
/// without members. All of this happens under one write lock, so no other
/// task can change group membership between pruning the member lists and
/// deleting the empty groups. Session signups are not touched.
async fn client_disconnected(conn_id: usize, state: &Arc<RwLock<State>>) {
    tracing::info!(conn_id, "disconnected");
    let mut writer = state.write().await;
    writer.clients.remove(&conn_id);
    writer.groups.retain(|key, group| {
        group.clients.retain(|client| *client != conn_id);
        let keep = !group.clients.is_empty();
        if !keep {
            tracing::info!(%key, "removed group");
        }
        keep
    });
}