1use std::collections::{HashMap, HashSet};
2use std::net::SocketAddr;
3use std::path::PathBuf;
4use std::sync::{
5 atomic::{AtomicUsize, Ordering},
6 Arc,
7};
8
9use futures_util::{SinkExt, StreamExt, TryFutureExt};
10use serde::{Deserialize, Serialize};
11use serde_json::Value;
12use thiserror::Error;
13use tokio::sync::{mpsc, Mutex, RwLock};
14use tokio_stream::wrappers::UnboundedReceiverStream;
15use uuid::Uuid;
16use warp::http::header::{HeaderMap, HeaderValue};
17use warp::ws::{Message, WebSocket};
18use warp::Filter;
19
20use crate::services::*;
21use json_rpc2::{Request, Response};
22
23use tracing_subscriber::fmt::format::FmtSpan;
24
/// Counter from which `client_connected` draws a fresh id for every websocket
/// connection. It starts at 1, so 0 is never handed out as a connection id.
static CONNECTION_ID: AtomicUsize = AtomicUsize::new(1);
27
/// Type-erased JSON-RPC service used by `rpc_request`.
///
/// Each call receives a tuple of: the caller's connection id, the shared
/// server state, and a slot in which a handler may leave a `Notification`.
/// `rpc_request` takes that notification out of the slot and delivers it with
/// `rpc_notify` after the direct response has been queued.
type RpcService = Box<
    dyn json_rpc2::futures::Service<
        Data = (usize, Arc<RwLock<State>>, Arc<Mutex<Option<Notification>>>),
    >,
>;
33
/// Errors produced while starting the server or registering session parties.
#[derive(Debug, Error)]
pub enum ServerError {
    /// The static-files path given to `Server::start` is not a directory.
    #[error("{0} is not a directory")]
    NotDirectory(PathBuf),

    /// Party number 0 was supplied (see `Session::load`); numbering starts at 1.
    #[error("party number may not be zero")]
    ZeroPartyNumber,

    /// A party number greater than `Parameters::parties` was supplied
    /// (see `Session::load`).
    #[error("party number is out of range")]
    PartyNumberOutOfRange,

    /// The party number is already signed up in a session; carries that
    /// session's uuid.
    #[error("party number already exists for session {0}")]
    PartyNumberAlreadyExists(Uuid),

    /// A socket address failed to parse.
    #[error(transparent)]
    NetAddrParse(#[from] std::net::AddrParseError),

    /// An I/O error, e.g. from canonicalizing the static-files directory.
    #[error(transparent)]
    Io(#[from] std::io::Error),

    /// An error reported by the `json_rpc2` crate.
    #[error(transparent)]
    JsonRpcError(#[from] json_rpc2::Error),
}

/// Result alias whose error type is [`ServerError`].
pub type Result<T> = std::result::Result<T, ServerError>;
68
69#[derive(Debug, Clone, Serialize, Deserialize)]
71pub struct Parameters {
72 pub parties: u16,
74 pub threshold: u16,
79}
80
81impl Default for Parameters {
82 fn default() -> Self {
83 Self {
84 parties: 3,
85 threshold: 1,
86 }
87 }
88}
89
90#[derive(Debug, Serialize, Deserialize, Clone)]
92pub enum SessionKind {
93 #[serde(rename = "keygen")]
95 Keygen,
96 #[serde(rename = "sign")]
98 Sign,
99}
100
101impl Default for SessionKind {
102 fn default() -> Self {
103 SessionKind::Keygen
104 }
105}
106
107#[derive(Debug, Default, Clone, Serialize)]
109pub struct Group {
110 pub uuid: Uuid,
112 pub params: Parameters,
114 pub label: String,
116 #[serde(skip)]
118 pub(crate) clients: Vec<usize>,
119 #[serde(skip)]
121 pub(crate) sessions: HashMap<Uuid, Session>,
122}
123
124impl Group {
125 pub fn new(conn: usize, params: Parameters, label: String) -> Self {
129 Self {
130 uuid: Uuid::new_v4(),
131 clients: vec![conn],
132 sessions: Default::default(),
133 params,
134 label,
135 }
136 }
137}
138
/// A keygen or signing session within a group.
#[derive(Debug, Clone, Serialize)]
pub struct Session {
    /// Identifier of this session (a random v4 uuid from the `Default` and
    /// `From` constructors).
    pub uuid: Uuid,
    /// Whether this session generates keys or signs.
    pub kind: SessionKind,
    /// Optional JSON payload (supplied through the
    /// `From<(SessionKind, Option<Value>)>` constructor).
    pub value: Option<Value>,

    /// `(party_number, connection_id)` pairs in registration order; filled by
    /// `signup` and `load` (not serialized).
    #[serde(skip)]
    pub(crate) party_signups: Vec<(u16, usize)>,

    /// Party numbers marked as finished (not serialized).
    /// NOTE(review): populated outside this file — confirm exact semantics.
    #[serde(skip)]
    pub(crate) finished: HashSet<u16>,

    /// For signing sessions, maps a message receiver number to the party
    /// number looked up in `party_signups` (see `Session::resolve`); not
    /// serialized.
    #[serde(skip)]
    pub(crate) participants: HashMap<u16, u16>,
}
176
177impl Default for Session {
178 fn default() -> Self {
179 Self {
180 uuid: Uuid::new_v4(),
181 kind: Default::default(),
182 party_signups: Default::default(),
183 finished: Default::default(),
184 value: None,
185 participants: Default::default(),
186 }
187 }
188}
189
190impl From<(SessionKind, Option<Value>)> for Session {
191 fn from(value: (SessionKind, Option<Value>)) -> Session {
192 Self {
193 uuid: Uuid::new_v4(),
194 kind: value.0,
195 party_signups: Default::default(),
196 finished: Default::default(),
197 value: value.1,
198 participants: Default::default(),
199 }
200 }
201}
202
203impl Session {
204 pub fn signup(&mut self, conn: usize) -> u16 {
209 let last = self.party_signups.last();
210 let num = if let Some((num, _)) = last {
211 num + 1
212 } else {
213 1
214 };
215 self.party_signups.push((num, conn));
224 num
225 }
226
227 pub fn load(
232 &mut self,
233 parameters: &Parameters,
234 conn: usize,
235 party_number: u16,
236 ) -> Result<()> {
237 if party_number == 0 {
238 return Err(ServerError::ZeroPartyNumber);
239 }
240 if party_number > parameters.parties {
241 return Err(ServerError::PartyNumberOutOfRange);
242 }
243 if self
244 .party_signups
245 .iter()
246 .any(|(num, _)| num == &party_number)
247 {
248 return Err(ServerError::PartyNumberAlreadyExists(self.uuid));
249 }
250 self.party_signups.push((party_number, conn));
251 Ok(())
252 }
253
254 pub fn resolve(&self, receiver: u16) -> Option<&(u16, usize)> {
257 if let SessionKind::Sign = self.kind {
258 if let Some(party_signup) = self.participants.get(&receiver) {
259 return self
260 .party_signups
261 .iter()
262 .find(|s| s.0 == *party_signup);
263 } else {
264 None
265 }
266 } else {
267 self.party_signups.iter().find(|s| s.0 == receiver)
268 }
269 }
270}
271
/// Server-wide state, shared between connection tasks as `Arc<RwLock<State>>`.
#[derive(Debug)]
pub struct State {
    /// Outgoing message channel of each live connection, keyed by connection id.
    pub clients: HashMap<usize, mpsc::UnboundedSender<Message>>,
    /// All groups, keyed by group uuid.
    pub groups: HashMap<Uuid, Group>,
}
280
281#[derive(Debug)]
283pub enum Notification {
284 Noop,
292
293 Group {
295 group_id: Uuid,
297 filter: Option<Vec<usize>>,
299 response: Response,
301 },
302
303 Session {
305 group_id: Uuid,
307 session_id: Uuid,
309 filter: Option<Vec<usize>>,
311 response: Response,
313 },
314
315 Relay {
319 messages: Vec<(usize, Response)>,
321 },
322}
323
324impl Default for Notification {
325 fn default() -> Self {
326 Self::Noop
327 }
328}
329
330pub struct Server;
332
333impl Server {
334 pub async fn start(
343 path: &'static str,
344 addr: impl Into<SocketAddr>,
345 static_files: PathBuf,
346 ) -> Result<()> {
347 let filter = std::env::var("RUST_LOG").unwrap_or_else(|_| {
349 "tracing=info,warp=debug,mpc_websocket=info".to_owned()
350 });
351
352 if cfg!(debug_assertions) {
353 tracing_subscriber::fmt()
354 .with_env_filter(filter)
355 .with_span_events(FmtSpan::CLOSE)
356 .init();
357 } else {
358 tracing_subscriber::fmt()
359 .with_env_filter(filter)
360 .with_span_events(FmtSpan::CLOSE)
361 .json()
362 .init();
363 }
364
365 let state = Arc::new(RwLock::new(State {
366 clients: HashMap::new(),
367 groups: Default::default(),
368 }));
369 let state = warp::any().map(move || state.clone());
370
371 if !static_files.is_dir() {
372 return Err(ServerError::NotDirectory(static_files));
373 }
374
375 let static_files = static_files.canonicalize()?;
376 let static_path = static_files.to_string_lossy().into_owned();
377 tracing::info!(%static_path);
378 tracing::info!(path);
379
380 let client = warp::any().and(warp::fs::dir(static_files));
381
382 let mut headers = HeaderMap::new();
383 headers.insert(
384 "Cross-Origin-Embedder-Policy",
385 HeaderValue::from_static("require-corp"),
386 );
387 headers.insert(
388 "Cross-Origin-Opener-Policy",
389 HeaderValue::from_static("same-origin"),
390 );
391
392 let websocket = warp::path(path).and(warp::ws()).and(state).map(
393 |ws: warp::ws::Ws, state| {
394 ws.on_upgrade(move |socket| client_connected(socket, state))
395 },
396 );
397
398 let routes = websocket
399 .or(client)
400 .with(warp::reply::with::headers(headers))
401 .with(warp::trace::request());
402
403 warp::serve(routes).run(addr).await;
404 Ok(())
405 }
406}
407
408async fn client_connected(ws: WebSocket, state: Arc<RwLock<State>>) {
409 let conn_id = CONNECTION_ID.fetch_add(1, Ordering::Relaxed);
410
411 tracing::info!(conn_id, "connected");
412
413 let (mut user_ws_tx, mut user_ws_rx) = ws.split();
415
416 let (tx, rx) = mpsc::unbounded_channel::<Message>();
419 let mut rx = UnboundedReceiverStream::new(rx);
420
421 let mut close_flag = Arc::new(RwLock::new(false));
422 let should_close = Arc::clone(&close_flag);
423
424 tokio::task::spawn(async move {
425 while let Some(message) = rx.next().await {
426 user_ws_tx
427 .send(message)
428 .unwrap_or_else(|e| {
429 tracing::error!(?e, "websocket send error");
430 })
431 .await;
432
433 let reader = should_close.read().await;
434 if *reader {
435 if let Err(e) = user_ws_tx.close().await {
436 tracing::warn!(?e, "failed to close websocket");
437 }
438 break;
439 }
440 }
441 });
442
443 state.write().await.clients.insert(conn_id, tx);
445
446 while let Some(result) = user_ws_rx.next().await {
448 let msg = match result {
449 Ok(msg) => msg,
450 Err(e) => {
451 tracing::error!(conn_id, ?e, "websocket rx error");
452 break;
453 }
454 };
455
456 client_incoming_message(conn_id, &mut close_flag, msg, &state).await;
457 }
458
459 client_disconnected(conn_id, &state).await;
462}
463
464async fn client_incoming_message(
465 conn_id: usize,
466 close_flag: &mut Arc<RwLock<bool>>,
467 msg: Message,
468 state: &Arc<RwLock<State>>,
469) {
470 let msg = if let Ok(s) = msg.to_str() {
471 s
472 } else {
473 return;
474 };
475
476 match json_rpc2::from_str(msg) {
477 Ok(req) => rpc_request(conn_id, close_flag, req, state).await,
478 Err(e) => tracing::warn!(conn_id, ?e, "websocket rx JSON error"),
479 }
480}
481
482async fn rpc_request(
484 conn_id: usize,
485 close_flag: &mut Arc<RwLock<bool>>,
486 request: Request,
487 state: &Arc<RwLock<State>>,
488) {
489 use json_rpc2::futures::*;
490
491 let service: RpcService = Box::new(ServiceHandler {});
492 let server = Server::new(vec![&service]);
493
494 let notification: Arc<Mutex<Option<Notification>>> =
495 Arc::new(Mutex::new(None));
496
497 if let Some(response) = server
498 .serve(
499 &request,
500 &(conn_id, Arc::clone(state), Arc::clone(¬ification)),
501 )
502 .await
503 {
504 rpc_response(conn_id, &response, state).await;
505
506 if let Some(error) = response.error() {
507 if let Some(data) = &error.data {
508 if data == CLOSE_CONNECTION {
509 let mut writer = close_flag.write().await;
510 *writer = true;
511 }
512 }
513 }
514 }
515
516 let mut writer = notification.lock().await;
517 if let Some(notification) = writer.take() {
518 rpc_notify(state, notification).await;
519 }
520}
521
/// Removes every connection id listed in `filter` from `clients`.
///
/// `filter` is an exclusion list; `None` returns `clients` unchanged. The
/// relative order of the remaining ids is preserved, and the input vector is
/// reused rather than reallocated.
fn filter_clients(
    mut clients: Vec<usize>,
    filter: Option<Vec<usize>>,
) -> Vec<usize> {
    if let Some(excluded) = filter {
        clients.retain(|conn| !excluded.contains(conn));
    }
    clients
}
536
537async fn rpc_notify(state: &Arc<RwLock<State>>, notification: Notification) {
539 let reader = state.read().await;
540 match notification {
541 Notification::Group {
542 group_id,
543 filter,
544 response,
545 } => {
546 let clients = if let Some(group) = reader.groups.get(&group_id) {
547 group.clients.clone()
548 } else {
549 vec![0usize]
550 };
551
552 let clients = filter_clients(clients, filter);
553 for conn_id in clients {
554 rpc_response(conn_id, &response, state).await;
555 }
556 }
557 Notification::Session {
558 group_id,
559 session_id,
560 filter,
561 response,
562 } => {
563 let clients = if let Some(group) = reader.groups.get(&group_id) {
564 if let Some(session) = group.sessions.get(&session_id) {
565 session.party_signups.iter().map(|i| i.1).collect()
566 } else {
567 tracing::warn!(
568 %session_id,
569 "notification session does not exist");
570 vec![0usize]
571 }
572 } else {
573 vec![0usize]
574 };
575
576 let clients = filter_clients(clients, filter);
577 for conn_id in clients {
578 rpc_response(conn_id, &response, state).await;
579 }
580 }
581 Notification::Relay { messages } => {
582 for (conn_id, response) in messages {
583 rpc_response(conn_id, &response, state).await;
584 }
585 }
586 Notification::Noop => {}
587 }
588}
589
590async fn rpc_response(
592 conn_id: usize,
593 response: &json_rpc2::Response,
594 state: &Arc<RwLock<State>>,
595) {
596 tracing::debug!(conn_id, "send message");
597 if let Some(tx) = state.read().await.clients.get(&conn_id) {
598 tracing::debug!(?response, "send response");
599 let msg = serde_json::to_string(response).unwrap();
600 if let Err(_disconnected) = tx.send(Message::text(msg)) {
601 }
605 } else {
606 tracing::warn!(conn_id, "could not find tx for websocket");
607 }
608}
609
610async fn client_disconnected(conn_id: usize, state: &Arc<RwLock<State>>) {
611 tracing::info!(conn_id, "disconnected");
612
613 let mut empty_groups: Vec<Uuid> = Vec::new();
616 {
617 let mut writer = state.write().await;
618 writer.clients.remove(&conn_id);
620 for (key, group) in writer.groups.iter_mut() {
622 if let Some(index) =
623 group.clients.iter().position(|x| *x == conn_id)
624 {
625 group.clients.remove(index);
626 }
627
628 if group.clients.is_empty() {
630 empty_groups.push(*key);
631 }
632 }
633 }
634
635 let mut writer = state.write().await;
637 for key in empty_groups {
638 writer.groups.remove(&key);
639 tracing::info!(%key, "removed group");
640 }
641}