// mpc_websocket/server.rs

1use std::collections::{HashMap, HashSet};
2use std::net::SocketAddr;
3use std::path::PathBuf;
4use std::sync::{
5    atomic::{AtomicUsize, Ordering},
6    Arc,
7};
8
9use futures_util::{SinkExt, StreamExt, TryFutureExt};
10use serde::{Deserialize, Serialize};
11use serde_json::Value;
12use thiserror::Error;
13use tokio::sync::{mpsc, Mutex, RwLock};
14use tokio_stream::wrappers::UnboundedReceiverStream;
15use uuid::Uuid;
16use warp::http::header::{HeaderMap, HeaderValue};
17use warp::ws::{Message, WebSocket};
18use warp::Filter;
19
20use crate::services::*;
21use json_rpc2::{Request, Response};
22
23use tracing_subscriber::fmt::format::FmtSpan;
24
/// Process-wide counter from which every websocket connection draws its
/// unique identifier; the first identifier handed out is 1.
static CONNECTION_ID: AtomicUsize = AtomicUsize::new(1);
27
28type RpcService = Box<
29    dyn json_rpc2::futures::Service<
30        Data = (usize, Arc<RwLock<State>>, Arc<Mutex<Option<Notification>>>),
31    >,
32>;
33
34/// Error thrown by the server.
35#[derive(Debug, Error)]
36pub enum ServerError {
37    /// Error generated when a directory is expected.
38    #[error("{0} is not a directory")]
39    NotDirectory(PathBuf),
40
41    /// Error generated if party number is zero.
42    #[error("party number may not be zero")]
43    ZeroPartyNumber,
44
45    /// Error generated if a party number is out of range.
46    #[error("party number is out of range")]
47    PartyNumberOutOfRange,
48
49    /// Error generated if a party number already exists for a session.
50    #[error("party number already exists for session {0}")]
51    PartyNumberAlreadyExists(Uuid),
52
53    /// Error generated parsing a socket address.
54    #[error(transparent)]
55    NetAddrParse(#[from] std::net::AddrParseError),
56
57    /// Error generated by the `std::io` module.
58    #[error(transparent)]
59    Io(#[from] std::io::Error),
60
61    /// Error generated by the JSON-RPC services.
62    #[error(transparent)]
63    JsonRpcError(#[from] json_rpc2::Error),
64}
65
66/// Result type for server errors.
67pub type Result<T> = std::result::Result<T, ServerError>;
68
69/// Parameters used during key generation and signing.
70#[derive(Debug, Clone, Serialize, Deserialize)]
71pub struct Parameters {
72    /// Number of parties `n`.
73    pub parties: u16,
74    /// Threshold for signing `t`.
75    ///
76    /// The threshold must be crossed (`t + 1`) for signing
77    /// to commence.
78    pub threshold: u16,
79}
80
81impl Default for Parameters {
82    fn default() -> Self {
83        Self {
84            parties: 3,
85            threshold: 1,
86        }
87    }
88}
89
90/// Represents the type of session.
91#[derive(Debug, Serialize, Deserialize, Clone)]
92pub enum SessionKind {
93    /// Key generation session.
94    #[serde(rename = "keygen")]
95    Keygen,
96    /// Signing session.
97    #[serde(rename = "sign")]
98    Sign,
99}
100
101impl Default for SessionKind {
102    fn default() -> Self {
103        SessionKind::Keygen
104    }
105}
106
107/// Group is a collection of connected websocket clients.
108#[derive(Debug, Default, Clone, Serialize)]
109pub struct Group {
110    /// Unique identifier for the group.
111    pub uuid: Uuid,
112    /// Parameters for key generation.
113    pub params: Parameters,
114    /// Human-readable label for the group.
115    pub label: String,
116    /// Collection of client identifiers.
117    #[serde(skip)]
118    pub(crate) clients: Vec<usize>,
119    /// Sessions belonging to this group.
120    #[serde(skip)]
121    pub(crate) sessions: HashMap<Uuid, Session>,
122}
123
124impl Group {
125    /// Create a new group.
126    ///
127    /// The connection identifier `conn` becomes the initial client for the group.
128    pub fn new(conn: usize, params: Parameters, label: String) -> Self {
129        Self {
130            uuid: Uuid::new_v4(),
131            clients: vec![conn],
132            sessions: Default::default(),
133            params,
134            label,
135        }
136    }
137}
138
139/// Session used for key generation or signing communication.
140#[derive(Debug, Clone, Serialize)]
141pub struct Session {
142    /// Unique identifier for the session.
143    pub uuid: Uuid,
144    /// Kind of the session.
145    pub kind: SessionKind,
146    /// Public value associated with the session.
147    ///
148    /// The owner of a session will assign this when
149    /// the session is created and other participants
150    /// in the session can read this value.
151    ///
152    /// This can be used to assign public data like the
153    /// message or transaction that will be signed during
154    /// a signing session.
155    pub value: Option<Value>,
156
157    /// Map party number to connection identifier
158    #[serde(skip)]
159    pub(crate) party_signups: Vec<(u16, usize)>,
160
161    /// Party numbers for those that have
162    /// marked the session as finished.
163    #[serde(skip)]
164    pub(crate) finished: HashSet<u16>,
165
166    /// Map receiver indices to server issued party numbers
167    /// which can then be used to resolve a connection identifier.
168    ///
169    /// During keygen we don't have a pre-defined index so we use
170    /// the server issued party number; whereas during signing it
171    /// is imperative that we use the index into the array of the
172    /// indices allocated during keygen.
173    #[serde(skip)]
174    pub(crate) participants: HashMap<u16, u16>,
175}
176
177impl Default for Session {
178    fn default() -> Self {
179        Self {
180            uuid: Uuid::new_v4(),
181            kind: Default::default(),
182            party_signups: Default::default(),
183            finished: Default::default(),
184            value: None,
185            participants: Default::default(),
186        }
187    }
188}
189
190impl From<(SessionKind, Option<Value>)> for Session {
191    fn from(value: (SessionKind, Option<Value>)) -> Session {
192        Self {
193            uuid: Uuid::new_v4(),
194            kind: value.0,
195            party_signups: Default::default(),
196            finished: Default::default(),
197            value: value.1,
198            participants: Default::default(),
199        }
200    }
201}
202
203impl Session {
204    /// Signup to a session.
205    ///
206    /// This marks a connected client as actively participating in
207    /// this session and issues them a unique party signup number.
208    pub fn signup(&mut self, conn: usize) -> u16 {
209        let last = self.party_signups.last();
210        let num = if let Some((num, _)) = last {
211            num + 1
212        } else {
213            1
214        };
215        /*
216        let num = if last.is_none() {
217            1
218        } else {
219            let (num, _) = last.unwrap();
220            num + 1
221        };
222        */
223        self.party_signups.push((num, conn));
224        num
225    }
226
227    /// Load an existing party signup number into this session.
228    ///
229    /// This is used when loading key shares that have been persisted
230    /// to perform signing using the saved key shares.
231    pub fn load(
232        &mut self,
233        parameters: &Parameters,
234        conn: usize,
235        party_number: u16,
236    ) -> Result<()> {
237        if party_number == 0 {
238            return Err(ServerError::ZeroPartyNumber);
239        }
240        if party_number > parameters.parties {
241            return Err(ServerError::PartyNumberOutOfRange);
242        }
243        if self
244            .party_signups
245            .iter()
246            .any(|(num, _)| num == &party_number)
247        {
248            return Err(ServerError::PartyNumberAlreadyExists(self.uuid));
249        }
250        self.party_signups.push((party_number, conn));
251        Ok(())
252    }
253
254    /// Resolve a receiver identifier for a peer to peer message
255    /// that is being related to a party signup and connection id.
256    pub fn resolve(&self, receiver: u16) -> Option<&(u16, usize)> {
257        if let SessionKind::Sign = self.kind {
258            if let Some(party_signup) = self.participants.get(&receiver) {
259                return self
260                    .party_signups
261                    .iter()
262                    .find(|s| s.0 == *party_signup);
263            } else {
264                None
265            }
266        } else {
267            self.party_signups.iter().find(|s| s.0 == receiver)
268        }
269    }
270}
271
272/// Collection of clients and groups managed by the server.
273#[derive(Debug)]
274pub struct State {
275    /// Connected clients.
276    pub clients: HashMap<usize, mpsc::UnboundedSender<Message>>,
277    /// Groups keyed by unique identifier (UUID)
278    pub groups: HashMap<Uuid, Group>,
279}
280
281/// Notification sent by the server to multiple connected clients.
282#[derive(Debug)]
283pub enum Notification {
284    /// Indicates that the response should be ignored
285    /// and no notification messages should be sent.
286    ///
287    /// This is used when testing a threshold for sending
288    /// notifications; before a threshold has been reached
289    /// we want to return a response but not actually send
290    /// any notifications.
291    Noop,
292
293    /// Sends the response to all clients in the group.
294    Group {
295        /// The group identifier.
296        group_id: Uuid,
297        /// Ignore these clients.
298        filter: Option<Vec<usize>>,
299        /// Message to send to the clients.
300        response: Response,
301    },
302
303    /// Sends the response to all clients in the session.
304    Session {
305        /// The group identifier.
306        group_id: Uuid,
307        /// The session identifier.
308        session_id: Uuid,
309        /// Ignore these clients.
310        filter: Option<Vec<usize>>,
311        /// Message to send to the clients.
312        response: Response,
313    },
314
315    /// Relay messages to specific clients.
316    ///
317    /// Used for relaying peer to peer messages.
318    Relay {
319        /// Mapping of client connection identifiers to messages.
320        messages: Vec<(usize, Response)>,
321    },
322}
323
324impl Default for Notification {
325    fn default() -> Self {
326        Self::Noop
327    }
328}
329
330/// MPC websocket server handling JSON-RPC requests.
331pub struct Server;
332
333impl Server {
334    /// Start the server.
335    ///
336    /// The websocket endpoint is mounted at `path`,
337    /// the server will bind to `addr` and static assets
338    /// are served from `static_files`.
339    ///
340    /// Logs are emitted using the [tracing](https://docs.rs/tracing)
341    /// library, in release mode the logs are formatted as JSON.
342    pub async fn start(
343        path: &'static str,
344        addr: impl Into<SocketAddr>,
345        static_files: PathBuf,
346    ) -> Result<()> {
347        // Filter traces based on the RUST_LOG env var.
348        let filter = std::env::var("RUST_LOG").unwrap_or_else(|_| {
349            "tracing=info,warp=debug,mpc_websocket=info".to_owned()
350        });
351
352        if cfg!(debug_assertions) {
353            tracing_subscriber::fmt()
354                .with_env_filter(filter)
355                .with_span_events(FmtSpan::CLOSE)
356                .init();
357        } else {
358            tracing_subscriber::fmt()
359                .with_env_filter(filter)
360                .with_span_events(FmtSpan::CLOSE)
361                .json()
362                .init();
363        }
364
365        let state = Arc::new(RwLock::new(State {
366            clients: HashMap::new(),
367            groups: Default::default(),
368        }));
369        let state = warp::any().map(move || state.clone());
370
371        if !static_files.is_dir() {
372            return Err(ServerError::NotDirectory(static_files));
373        }
374
375        let static_files = static_files.canonicalize()?;
376        let static_path = static_files.to_string_lossy().into_owned();
377        tracing::info!(%static_path);
378        tracing::info!(path);
379
380        let client = warp::any().and(warp::fs::dir(static_files));
381
382        let mut headers = HeaderMap::new();
383        headers.insert(
384            "Cross-Origin-Embedder-Policy",
385            HeaderValue::from_static("require-corp"),
386        );
387        headers.insert(
388            "Cross-Origin-Opener-Policy",
389            HeaderValue::from_static("same-origin"),
390        );
391
392        let websocket = warp::path(path).and(warp::ws()).and(state).map(
393            |ws: warp::ws::Ws, state| {
394                ws.on_upgrade(move |socket| client_connected(socket, state))
395            },
396        );
397
398        let routes = websocket
399            .or(client)
400            .with(warp::reply::with::headers(headers))
401            .with(warp::trace::request());
402
403        warp::serve(routes).run(addr).await;
404        Ok(())
405    }
406}
407
408async fn client_connected(ws: WebSocket, state: Arc<RwLock<State>>) {
409    let conn_id = CONNECTION_ID.fetch_add(1, Ordering::Relaxed);
410
411    tracing::info!(conn_id, "connected");
412
413    // Split the socket into a sender and receive of messages.
414    let (mut user_ws_tx, mut user_ws_rx) = ws.split();
415
416    // Use an unbounded channel to handle buffering and flushing of messages
417    // to the websocket.
418    let (tx, rx) = mpsc::unbounded_channel::<Message>();
419    let mut rx = UnboundedReceiverStream::new(rx);
420
421    let mut close_flag = Arc::new(RwLock::new(false));
422    let should_close = Arc::clone(&close_flag);
423
424    tokio::task::spawn(async move {
425        while let Some(message) = rx.next().await {
426            user_ws_tx
427                .send(message)
428                .unwrap_or_else(|e| {
429                    tracing::error!(?e, "websocket send error");
430                })
431                .await;
432
433            let reader = should_close.read().await;
434            if *reader {
435                if let Err(e) = user_ws_tx.close().await {
436                    tracing::warn!(?e, "failed to close websocket");
437                }
438                break;
439            }
440        }
441    });
442
443    // Save the sender in our list of connected clients.
444    state.write().await.clients.insert(conn_id, tx);
445
446    // Handle incoming requests from clients
447    while let Some(result) = user_ws_rx.next().await {
448        let msg = match result {
449            Ok(msg) => msg,
450            Err(e) => {
451                tracing::error!(conn_id, ?e, "websocket rx error");
452                break;
453            }
454        };
455
456        client_incoming_message(conn_id, &mut close_flag, msg, &state).await;
457    }
458
459    // user_ws_rx stream will keep processing as long as the user stays
460    // connected. Once they disconnect, then...
461    client_disconnected(conn_id, &state).await;
462}
463
464async fn client_incoming_message(
465    conn_id: usize,
466    close_flag: &mut Arc<RwLock<bool>>,
467    msg: Message,
468    state: &Arc<RwLock<State>>,
469) {
470    let msg = if let Ok(s) = msg.to_str() {
471        s
472    } else {
473        return;
474    };
475
476    match json_rpc2::from_str(msg) {
477        Ok(req) => rpc_request(conn_id, close_flag, req, state).await,
478        Err(e) => tracing::warn!(conn_id, ?e, "websocket rx JSON error"),
479    }
480}
481
482/// Process a request message from a client.
483async fn rpc_request(
484    conn_id: usize,
485    close_flag: &mut Arc<RwLock<bool>>,
486    request: Request,
487    state: &Arc<RwLock<State>>,
488) {
489    use json_rpc2::futures::*;
490
491    let service: RpcService = Box::new(ServiceHandler {});
492    let server = Server::new(vec![&service]);
493
494    let notification: Arc<Mutex<Option<Notification>>> =
495        Arc::new(Mutex::new(None));
496
497    if let Some(response) = server
498        .serve(
499            &request,
500            &(conn_id, Arc::clone(state), Arc::clone(&notification)),
501        )
502        .await
503    {
504        rpc_response(conn_id, &response, state).await;
505
506        if let Some(error) = response.error() {
507            if let Some(data) = &error.data {
508                if data == CLOSE_CONNECTION {
509                    let mut writer = close_flag.write().await;
510                    *writer = true;
511                }
512            }
513        }
514    }
515
516    let mut writer = notification.lock().await;
517    if let Some(notification) = writer.take() {
518        rpc_notify(state, notification).await;
519    }
520}
521
/// Remove every connection listed in `filter` from `clients`.
///
/// The remaining clients keep their relative order. With `filter` set to
/// `None`, `clients` is returned unchanged.
fn filter_clients(
    clients: Vec<usize>,
    filter: Option<Vec<usize>>,
) -> Vec<usize> {
    let mut kept = clients;
    if let Some(excluded) = filter {
        kept.retain(|conn| !excluded.contains(conn));
    }
    kept
}
536
537/// Send notification to connected client(s).
538async fn rpc_notify(state: &Arc<RwLock<State>>, notification: Notification) {
539    let reader = state.read().await;
540    match notification {
541        Notification::Group {
542            group_id,
543            filter,
544            response,
545        } => {
546            let clients = if let Some(group) = reader.groups.get(&group_id) {
547                group.clients.clone()
548            } else {
549                vec![0usize]
550            };
551
552            let clients = filter_clients(clients, filter);
553            for conn_id in clients {
554                rpc_response(conn_id, &response, state).await;
555            }
556        }
557        Notification::Session {
558            group_id,
559            session_id,
560            filter,
561            response,
562        } => {
563            let clients = if let Some(group) = reader.groups.get(&group_id) {
564                if let Some(session) = group.sessions.get(&session_id) {
565                    session.party_signups.iter().map(|i| i.1).collect()
566                } else {
567                    tracing::warn!(
568                        %session_id,
569                        "notification session does not exist");
570                    vec![0usize]
571                }
572            } else {
573                vec![0usize]
574            };
575
576            let clients = filter_clients(clients, filter);
577            for conn_id in clients {
578                rpc_response(conn_id, &response, state).await;
579            }
580        }
581        Notification::Relay { messages } => {
582            for (conn_id, response) in messages {
583                rpc_response(conn_id, &response, state).await;
584            }
585        }
586        Notification::Noop => {}
587    }
588}
589
590/// Send a message to a single client.
591async fn rpc_response(
592    conn_id: usize,
593    response: &json_rpc2::Response,
594    state: &Arc<RwLock<State>>,
595) {
596    tracing::debug!(conn_id, "send message");
597    if let Some(tx) = state.read().await.clients.get(&conn_id) {
598        tracing::debug!(?response, "send response");
599        let msg = serde_json::to_string(response).unwrap();
600        if let Err(_disconnected) = tx.send(Message::text(msg)) {
601            // The tx is disconnected, our `client_disconnected` code
602            // should be happening in another task, nothing more to
603            // do here.
604        }
605    } else {
606        tracing::warn!(conn_id, "could not find tx for websocket");
607    }
608}
609
610async fn client_disconnected(conn_id: usize, state: &Arc<RwLock<State>>) {
611    tracing::info!(conn_id, "disconnected");
612
613    // FIXME: prune session party signups for disconnected clients?
614
615    let mut empty_groups: Vec<Uuid> = Vec::new();
616    {
617        let mut writer = state.write().await;
618        // Stream closed up, so remove from the client list
619        writer.clients.remove(&conn_id);
620        // Remove the connection from any client groups
621        for (key, group) in writer.groups.iter_mut() {
622            if let Some(index) =
623                group.clients.iter().position(|x| *x == conn_id)
624            {
625                group.clients.remove(index);
626            }
627
628            // Group has no more connected clients so flag it for removal
629            if group.clients.is_empty() {
630                empty_groups.push(*key);
631            }
632        }
633    }
634
635    // Prune empty groups
636    let mut writer = state.write().await;
637    for key in empty_groups {
638        writer.groups.remove(&key);
639        tracing::info!(%key, "removed group");
640    }
641}