mpc_protocol/
protocol.rs

1use crate::{encoding::types, PartyNumber, Result, TAGLEN};
2use http::StatusCode;
3use serde::{Deserialize, Serialize};
4use serde_json::Value;
5use snow::{HandshakeState, TransportState};
6use std::{
7    collections::{HashMap, HashSet},
8    time::{Duration, SystemTime},
9};
10
11/// Identifier for meeting points.
12pub type MeetingId = uuid::Uuid;
13
14/// Identifier for sessions.
15pub type SessionId = uuid::Uuid;
16
/// Identifier for a user.
///
/// Holds 32 bytes, intended to be the SHA-256 digest of some unique,
/// arbitrary value. Equality and hashing operate on the raw bytes.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub struct UserId([u8; 32]);

impl From<[u8; 32]> for UserId {
    /// Wrap raw digest bytes as a user identifier (no hashing is done here).
    fn from(digest: [u8; 32]) -> Self {
        UserId(digest)
    }
}

impl AsRef<[u8; 32]> for UserId {
    /// Borrow the underlying 32 bytes.
    fn as_ref(&self) -> &[u8; 32] {
        let UserId(bytes) = self;
        bytes
    }
}
33
34/// Parameters used during key generation.
35#[derive(Debug, Copy, Clone, Serialize, Deserialize)]
36pub struct Parameters {
37    /// Number of parties `n`.
38    pub parties: u16,
39    /// Threshold for signing `t`.
40    ///
41    /// The threshold must be crossed (`t + 1`) for signing
42    /// to commence.
43    pub threshold: u16,
44}
45
46impl Default for Parameters {
47    fn default() -> Self {
48        Self {
49            parties: 3,
50            threshold: 1,
51        }
52    }
53}
54
55/// Enumeration of protocol states.
56pub enum ProtocolState {
57    /// Noise handshake state.
58    Handshake(Box<HandshakeState>),
59    /// Noise transport state.
60    Transport(TransportState),
61}
62
63/// Handshake messages.
64#[derive(Default, Debug)]
65pub enum HandshakeMessage {
66    #[default]
67    #[doc(hidden)]
68    Noop,
69    /// Handshake initiator.
70    Initiator(usize, Vec<u8>),
71    /// Handshake responder.
72    Responder(usize, Vec<u8>),
73}
74
75impl From<&HandshakeMessage> for u8 {
76    fn from(value: &HandshakeMessage) -> Self {
77        match value {
78            HandshakeMessage::Noop => types::NOOP,
79            HandshakeMessage::Initiator(_, _) => {
80                types::HANDSHAKE_INITIATOR
81            }
82            HandshakeMessage::Responder(_, _) => {
83                types::HANDSHAKE_RESPONDER
84            }
85        }
86    }
87}
88
89/// Transparent messages are not encrypted.
90#[derive(Default, Debug)]
91pub enum TransparentMessage {
92    #[default]
93    #[doc(hidden)]
94    Noop,
95    /// Return an error message to the client.
96    Error(StatusCode, String),
97    /// Handshake message.
98    ServerHandshake(HandshakeMessage),
99    /// Relayed peer handshake message.
100    PeerHandshake {
101        /// Public key of the receiver.
102        public_key: Vec<u8>,
103        /// Handshake message.
104        message: HandshakeMessage,
105    },
106}
107
108impl From<&TransparentMessage> for u8 {
109    fn from(value: &TransparentMessage) -> Self {
110        match value {
111            TransparentMessage::Noop => types::NOOP,
112            TransparentMessage::Error(_, _) => types::ERROR,
113            TransparentMessage::ServerHandshake(_) => {
114                types::HANDSHAKE_SERVER
115            }
116            TransparentMessage::PeerHandshake { .. } => {
117                types::HANDSHAKE_PEER
118            }
119        }
120    }
121}
122
123/// Message sent between the server and a client.
124#[derive(Default, Debug)]
125pub enum ServerMessage {
126    #[default]
127    #[doc(hidden)]
128    Noop,
129    /// Return an error message to the client.
130    Error(StatusCode, String),
131    /// Request a new meeting point.
132    NewMeeting {
133        /// The identifier for the owner of the meeting point.
134        ///
135        /// The owner id must exist in the set of slots.
136        owner_id: UserId,
137        /// Slots for participants in the meeting.
138        slots: HashSet<UserId>,
139        /// Data associated aith the meeting point.
140        data: Value,
141    },
142    /// Response to a new meeting point request.
143    MeetingCreated(MeetingState),
144    /// Participant joins a meeting.
145    JoinMeeting(MeetingId, UserId),
146    /// Notification dispatched to all participants
147    /// in a meeting when the limit for the meeting
148    /// has been reached.
149    MeetingReady(MeetingState),
150    /// Request a new session.
151    NewSession(SessionRequest),
152    /// Register a peer connection in a session.
153    SessionConnection {
154        /// Session identifier.
155        session_id: SessionId,
156        /// Public key of the peer.
157        peer_key: Vec<u8>,
158    },
159    /// Response to a new session request.
160    SessionCreated(SessionState),
161    /// Notification dispatched to all participants
162    /// in a session when they have all completed
163    /// the server handshake.
164    SessionReady(SessionState),
165    /// Notification dispatched to all participants
166    /// in a session when they have all established
167    /// peer connections to each other.
168    SessionActive(SessionState),
169    /// Notification dispatched to all participants
170    /// in a session when the participants did not
171    /// all connect within the expected timeframe.
172    SessionTimeout(SessionId),
173    /// Request to close a session.
174    CloseSession(SessionId),
175    /// Message sent when a session was closed.
176    SessionFinished(SessionId),
177}
178
179impl From<&ServerMessage> for u8 {
180    fn from(value: &ServerMessage) -> Self {
181        match value {
182            ServerMessage::Noop => types::NOOP,
183            ServerMessage::Error(_, _) => types::ERROR,
184            ServerMessage::NewMeeting { .. } => types::MEETING_NEW,
185            ServerMessage::MeetingCreated(_) => {
186                types::MEETING_CREATED
187            }
188            ServerMessage::JoinMeeting(_, _) => types::MEETING_JOIN,
189            ServerMessage::MeetingReady(_) => types::MEETING_READY,
190            ServerMessage::NewSession(_) => types::SESSION_NEW,
191            ServerMessage::SessionConnection { .. } => {
192                types::SESSION_CONNECTION
193            }
194            ServerMessage::SessionCreated(_) => {
195                types::SESSION_CREATED
196            }
197            ServerMessage::SessionReady(_) => types::SESSION_READY,
198            ServerMessage::SessionActive(_) => types::SESSION_ACTIVE,
199            ServerMessage::SessionTimeout(_) => {
200                types::SESSION_TIMEOUT
201            }
202            ServerMessage::CloseSession(_) => types::SESSION_CLOSE,
203            ServerMessage::SessionFinished(_) => {
204                types::SESSION_FINISHED
205            }
206        }
207    }
208}
209
210/// Opaque messaages are encrypted.
211#[derive(Default, Debug)]
212pub enum OpaqueMessage {
213    #[default]
214    #[doc(hidden)]
215    Noop,
216
217    /// Encrypted message sent between the server and a client.
218    ///
219    /// After decrypting it can be decoded to a server message.
220    ServerMessage(SealedEnvelope),
221
222    /// Relay an encrypted message to a peer.
223    PeerMessage {
224        /// Public key of the receiver.
225        public_key: Vec<u8>,
226        /// Session identifier.
227        session_id: Option<SessionId>,
228        /// Message envelope.
229        envelope: SealedEnvelope,
230    },
231}
232
233impl From<&OpaqueMessage> for u8 {
234    fn from(value: &OpaqueMessage) -> Self {
235        match value {
236            OpaqueMessage::Noop => types::NOOP,
237            OpaqueMessage::ServerMessage(_) => types::OPAQUE_SERVER,
238            OpaqueMessage::PeerMessage { .. } => types::OPAQUE_PEER,
239        }
240    }
241}
242
243/// Request message sent to the server or another peer.
244#[derive(Default, Debug)]
245pub enum RequestMessage {
246    #[default]
247    #[doc(hidden)]
248    Noop,
249
250    /// Transparent message used for the handshake(s).
251    Transparent(TransparentMessage),
252
253    /// Opaque encrypted messages.
254    Opaque(OpaqueMessage),
255}
256
257impl From<&RequestMessage> for u8 {
258    fn from(value: &RequestMessage) -> Self {
259        match value {
260            RequestMessage::Noop => types::NOOP,
261            RequestMessage::Transparent(_) => types::TRANSPARENT,
262            RequestMessage::Opaque(_) => types::OPAQUE,
263        }
264    }
265}
266
267/// Response message sent by the server or a peer.
268#[derive(Default, Debug)]
269pub enum ResponseMessage {
270    #[default]
271    #[doc(hidden)]
272    Noop,
273
274    /// Transparent message used for the handshake(s).
275    Transparent(TransparentMessage),
276
277    /// Opaque encrypted messages.
278    Opaque(OpaqueMessage),
279}
280
281impl From<&ResponseMessage> for u8 {
282    fn from(value: &ResponseMessage) -> Self {
283        match value {
284            ResponseMessage::Noop => types::NOOP,
285            ResponseMessage::Transparent(_) => types::TRANSPARENT,
286            ResponseMessage::Opaque(_) => types::OPAQUE,
287        }
288    }
289}
290
291/// Encoding for message payloads.
292#[derive(Default, Clone, Copy, Debug)]
293pub enum Encoding {
294    #[default]
295    #[doc(hidden)]
296    Noop,
297    /// Binary encoding.
298    Blob,
299    /// JSON encoding.
300    Json,
301}
302
303impl From<Encoding> for u8 {
304    fn from(value: Encoding) -> Self {
305        match value {
306            Encoding::Noop => types::NOOP,
307            Encoding::Blob => types::ENCODING_BLOB,
308            Encoding::Json => types::ENCODING_JSON,
309        }
310    }
311}
312
313/// Chunk is used to respect the 65535 limit for
314/// noise protocol messages.
315///
316/// Payloads may be larger than this limit so we chunk
317/// them into individually encrypted payloads which then
318/// need to be re-combined after each chunk has been decrypted.
319#[derive(Default, Debug)]
320pub struct Chunk {
321    /// Length of the payload data.
322    pub length: usize,
323    /// Encrypted payload.
324    pub contents: Vec<u8>,
325}
326
327impl Chunk {
328    const CHUNK_SIZE: usize = 65535 - TAGLEN;
329
330    /// Split a payload into encrypted chunks.
331    pub fn split(
332        payload: &[u8],
333        transport: &mut TransportState,
334    ) -> Result<Vec<Chunk>> {
335        let mut chunks = Vec::new();
336        for chunk in payload.chunks(Self::CHUNK_SIZE) {
337            let mut contents = vec![0; chunk.len() + TAGLEN];
338            let length =
339                transport.write_message(chunk, &mut contents)?;
340            chunks.push(Chunk { length, contents });
341        }
342        Ok(chunks)
343    }
344
345    /// Decrypt chunks and join into a single payload.
346    pub fn join(
347        chunks: Vec<Chunk>,
348        transport: &mut TransportState,
349    ) -> Result<Vec<u8>> {
350        let mut payload = Vec::new();
351        for chunk in chunks {
352            let mut contents = vec![0; chunk.length];
353            transport.read_message(
354                &chunk.contents[..chunk.length],
355                &mut contents,
356            )?;
357            let new_length = contents.len() - TAGLEN;
358            contents.truncate(new_length);
359            payload.extend_from_slice(contents.as_slice());
360        }
361        Ok(payload)
362    }
363}
364
365/// Sealed envelope is an encrypted message.
366///
367/// The payload has been encrypted using the noise protocol
368/// channel and the recipient must decrypt and decode the payload.
369#[derive(Default, Debug)]
370pub struct SealedEnvelope {
371    /// Encoding for the payload.
372    pub encoding: Encoding,
373    /// Encrypted chunks.
374    pub chunks: Vec<Chunk>,
375    /// Whether this is a broadcast message.
376    pub broadcast: bool,
377}
378
/// A namespace in which a group of participants communicate over a
/// series of rounds.
///
/// Use this for the keygen, signing or key refresh phases of an MPC
/// protocol.
pub struct Session {
    /// Public key of the owner, i.e. the initiator that created the session.
    owner_key: Vec<u8>,

    /// Public keys of every other participant.
    participant_keys: HashSet<Vec<u8>>,

    /// Peer connections registered in this session, stored as the
    /// `(peer, other)` pairs passed to `register_connection`.
    connections: HashSet<(Vec<u8>, Vec<u8>)>,

    /// When the session was last accessed, so the server can reap
    /// stale sessions.
    last_access: SystemTime,
}

impl Session {
    /// Public key of the session owner.
    pub fn owner_key(&self) -> &[u8] {
        &self.owner_key
    }

    /// Public keys of all participants.
    ///
    /// The owner comes first; the remaining keys follow in the set's
    /// (unspecified) iteration order.
    pub fn public_keys(&self) -> Vec<&[u8]> {
        std::iter::once(self.owner_key.as_slice())
            .chain(self.participant_keys.iter().map(Vec::as_slice))
            .collect()
    }

    /// Record that `peer` has connected to `other`.
    pub fn register_connection(&mut self, peer: Vec<u8>, other: Vec<u8>) {
        self.connections.insert((peer, other));
    }

    /// Whether every participant has a peer connection to every other
    /// participant.
    ///
    /// A connection registered in either direction counts, so each
    /// unordered pair of distinct keys is checked once.
    pub fn is_active(&self) -> bool {
        let keys = self.public_keys();
        let linked = |a: &[u8], b: &[u8]| {
            self.connections.contains(&(a.to_vec(), b.to_vec()))
                || self.connections.contains(&(b.to_vec(), a.to_vec()))
        };
        keys.iter().enumerate().all(|(i, &a)| {
            keys[i + 1..].iter().all(|&b| a == b || linked(a, b))
        })
    }
}
474
475/// Meeting point information.
476#[derive(Debug)]
477pub struct Meeting {
478    /// Map of user identifiers to public keys.
479    slots: HashMap<UserId, Option<Vec<u8>>>,
480
481    /// Last access time so the server can reap
482    /// stale meetings.
483    last_access: SystemTime,
484
485    /// Associated data for the meeting.
486    data: Value,
487}
488
489impl Meeting {
490    /// Add a participant public key to this meeting.
491    pub fn join(&mut self, user_id: UserId, public_key: Vec<u8>) {
492        self.slots.insert(user_id, Some(public_key));
493        self.last_access = SystemTime::now();
494    }
495
496    /// Whether this meeting point is full.
497    pub fn is_full(&self) -> bool {
498        self.slots.values().all(|s| s.is_some())
499    }
500
501    /// Public keys of the meeting participants.
502    pub fn participants(&self) -> Vec<Vec<u8>> {
503        self.slots
504            .values()
505            .filter(|s| s.is_some())
506            .map(|s| s.as_ref().unwrap().to_owned())
507            .collect()
508    }
509
510    /// Associated data.
511    pub fn data(&self) -> &Value {
512        &self.data
513    }
514}
515
516/// Manages a collection of meeting points.
517#[derive(Default)]
518pub struct MeetingManager {
519    meetings: HashMap<MeetingId, Meeting>,
520}
521
522impl MeetingManager {
523    /// Create a new meeting point.
524    pub fn new_meeting(
525        &mut self,
526        owner_key: Vec<u8>,
527        owner_id: UserId,
528        slots: HashSet<UserId>,
529        data: Value,
530    ) -> MeetingId {
531        let meeting_id = MeetingId::new_v4();
532        let slots: HashMap<UserId, Option<Vec<u8>>> =
533            slots.into_iter().map(|id| (id, None)).collect();
534
535        let mut meeting = Meeting {
536            slots,
537            last_access: SystemTime::now(),
538            data,
539        };
540        meeting.join(owner_id, owner_key);
541
542        self.meetings.insert(meeting_id, meeting);
543        meeting_id
544    }
545
546    /// Remove a meeting.
547    pub fn remove_meeting(
548        &mut self,
549        id: &MeetingId,
550    ) -> Option<Meeting> {
551        self.meetings.remove(id)
552    }
553
554    /// Get a meeting.
555    pub fn get_meeting(&self, id: &MeetingId) -> Option<&Meeting> {
556        self.meetings.get(id)
557    }
558
559    /// Get a mutable meeting.
560    pub fn get_meeting_mut(
561        &mut self,
562        id: &MeetingId,
563    ) -> Option<&mut Meeting> {
564        self.meetings.get_mut(id)
565    }
566
567    /// Get the keys of meetings that have expired.
568    pub fn expired_keys(&self, timeout: u64) -> Vec<MeetingId> {
569        self.meetings
570            .iter()
571            .filter(|(_, v)| {
572                let now = SystemTime::now();
573                let ttl = Duration::from_millis(timeout * 1000);
574                if let Some(current) = v.last_access.checked_add(ttl)
575                {
576                    current < now
577                } else {
578                    false
579                }
580            })
581            .map(|(k, _)| *k)
582            .collect::<Vec<_>>()
583    }
584}
585
586/// Manages a collection of sessions.
587#[derive(Default)]
588pub struct SessionManager {
589    sessions: HashMap<SessionId, Session>,
590}
591
592impl SessionManager {
593    /// Create a new session.
594    pub fn new_session(
595        &mut self,
596        owner_key: Vec<u8>,
597        participant_keys: Vec<Vec<u8>>,
598    ) -> SessionId {
599        let session_id = SessionId::new_v4();
600        let session = Session {
601            owner_key,
602            participant_keys: participant_keys.into_iter().collect(),
603            connections: Default::default(),
604            last_access: SystemTime::now(),
605        };
606        self.sessions.insert(session_id, session);
607        session_id
608    }
609
610    /// Get a session.
611    pub fn get_session(&self, id: &SessionId) -> Option<&Session> {
612        self.sessions.get(id)
613    }
614
615    /// Get a mutable session.
616    pub fn get_session_mut(
617        &mut self,
618        id: &SessionId,
619    ) -> Option<&mut Session> {
620        self.sessions.get_mut(id)
621    }
622
623    /// Remove a session.
624    pub fn remove_session(
625        &mut self,
626        id: &SessionId,
627    ) -> Option<Session> {
628        self.sessions.remove(id)
629    }
630
631    /// Retrieve and update the last access time for a session.
632    pub fn touch_session(
633        &mut self,
634        id: &SessionId,
635    ) -> Option<&Session> {
636        if let Some(session) = self.sessions.get_mut(id) {
637            session.last_access = SystemTime::now();
638            Some(&*session)
639        } else {
640            None
641        }
642    }
643
644    /// Get the keys of sessions that have expired.
645    pub fn expired_keys(&self, timeout: u64) -> Vec<SessionId> {
646        self.sessions
647            .iter()
648            .filter(|(_, v)| {
649                let now = SystemTime::now();
650                let ttl = Duration::from_millis(timeout * 1000);
651                if let Some(current) = v.last_access.checked_add(ttl)
652                {
653                    current < now
654                } else {
655                    false
656                }
657            })
658            .map(|(k, _)| *k)
659            .collect::<Vec<_>>()
660    }
661}
662
663/// Response from creating a meeting point.
664#[derive(Default, Debug, Clone)]
665pub struct MeetingState {
666    /// Meeting identifier.
667    pub meeting_id: MeetingId,
668    /// Public keys of the registered participants.
669    pub registered_participants: Vec<Vec<u8>>,
670    /// Data for the meeting state.
671    pub data: Value,
672}
673
/// Request to create a new session.
///
/// Do not include the initiator's public key: it is added automatically
/// as the session *owner*.
#[derive(Default, Debug)]
pub struct SessionRequest {
    /// Public keys of the other session participants.
    pub participant_keys: Vec<Vec<u8>>,
}
683
684/// Response from creating new session.
685#[derive(Default, Debug, Clone)]
686pub struct SessionState {
687    /// Session identifier.
688    pub session_id: SessionId,
689    /// Public keys of all participants.
690    pub all_participants: Vec<Vec<u8>>,
691}
692
693impl SessionState {
694    /// Total number of participants in this session.
695    pub fn len(&self) -> usize {
696        self.all_participants.len()
697    }
698
699    /// Get the party index from a public key.
700    pub fn party_number(
701        &self,
702        public_key: impl AsRef<[u8]>,
703    ) -> Option<PartyNumber> {
704        self.all_participants
705            .iter()
706            .position(|k| k == public_key.as_ref())
707            .map(|pos| PartyNumber::new((pos + 1) as u16).unwrap())
708    }
709
710    /// Get the public key for a party number.
711    pub fn peer_key(
712        &self,
713        party_number: PartyNumber,
714    ) -> Option<&[u8]> {
715        for (index, key) in self.all_participants.iter().enumerate() {
716            if index + 1 == party_number.get() as usize {
717                return Some(key.as_slice());
718            }
719        }
720        None
721    }
722
723    /// Get the connections a peer should make.
724    pub fn connections(&self, own_key: &[u8]) -> &[Vec<u8>] {
725        if self.all_participants.is_empty() {
726            return &[];
727        }
728
729        if let Some(position) =
730            self.all_participants.iter().position(|k| k == own_key)
731        {
732            if position < self.all_participants.len() - 1 {
733                &self.all_participants[position + 1..]
734            } else {
735                &[]
736            }
737        } else {
738            &[]
739        }
740    }
741
742    /// Get the recipients for a broadcast message.
743    pub fn recipients(&self, own_key: &[u8]) -> Vec<Vec<u8>> {
744        self.all_participants
745            .iter()
746            .filter(|&k| k != own_key)
747            .map(|k| k.to_vec())
748            .collect()
749    }
750}
751
#[cfg(test)]
mod tests {
    use super::Chunk;
    use crate::PATTERN;
    use anyhow::Result;

    /// A payload larger than one Noise message survives a round trip
    /// through `Chunk::split` and `Chunk::join`.
    #[test]
    fn chunks_split_join() -> Result<()> {
        let alice_builder = snow::Builder::new(PATTERN.parse()?);
        let bob_builder = snow::Builder::new(PATTERN.parse()?);

        let alice_keys = alice_builder.generate_keypair()?;
        let bob_keys = bob_builder.generate_keypair()?;

        let mut alice = alice_builder
            .local_private_key(&alice_keys.private)
            .remote_public_key(&bob_keys.public)
            .build_initiator()?;
        let mut bob = bob_builder
            .local_private_key(&bob_keys.private)
            .remote_public_key(&alice_keys.public)
            .build_responder()?;

        let mut scratch = [0u8; 1024];
        let mut wire = [0u8; 1024];

        // Initiator -> responder: first handshake message.
        let sent = alice.write_message(&[], &mut wire)?;
        bob.read_message(&wire[..sent], &mut scratch)?;

        // Responder -> initiator: second handshake message.
        let sent = bob.write_message(&[], &mut wire)?;
        alice.read_message(&wire[..sent], &mut scratch)?;

        // Both sides have finished the handshake; switch to transport mode.
        let mut alice = alice.into_transport_mode()?;
        let mut bob = bob.into_transport_mode()?;

        // Larger than a single chunk, so it must be split in two.
        let original = vec![0; 76893];

        let chunks = Chunk::split(&original, &mut alice)?;
        assert_eq!(2, chunks.len());

        let restored = Chunk::join(chunks, &mut bob)?;
        assert_eq!(original, restored);

        Ok(())
    }
}