polysig_protocol/
protocol.rs

1use crate::{encoding::types, PartyNumber, Result, TAGLEN};
2use http::StatusCode;
3use serde::{Deserialize, Serialize};
4use snow::{HandshakeState, TransportState};
5use std::{
6    collections::{HashMap, HashSet},
7    time::{Duration, SystemTime},
8};
9
10/// Identifier for sessions.
11pub type SessionId = uuid::Uuid;
12
13/// User identifier wraps an SHA-256 hash of a
14/// unique arbitrary value.
15#[derive(
16    Debug, Clone, Copy, Hash, Eq, PartialEq, Serialize, Deserialize,
17)]
18pub struct UserId([u8; 32]);
19
20impl AsRef<[u8; 32]> for UserId {
21    fn as_ref(&self) -> &[u8; 32] {
22        &self.0
23    }
24}
25
26impl From<[u8; 32]> for UserId {
27    fn from(value: [u8; 32]) -> Self {
28        Self(value)
29    }
30}
31
32/// Parameters used during key generation.
33#[derive(Debug, Copy, Clone, Serialize, Deserialize)]
34pub struct Parameters {
35    /// Number of parties `n`.
36    pub parties: u16,
37    /// Threshold for signing `t`.
38    ///
39    /// The threshold must be crossed (`t + 1`) for signing
40    /// to commence.
41    pub threshold: u16,
42}
43
44impl Default for Parameters {
45    fn default() -> Self {
46        Self {
47            parties: 3,
48            threshold: 1,
49        }
50    }
51}
52
53/// Enumeration of protocol states.
54pub enum ProtocolState {
55    /// Noise handshake state.
56    Handshake(Box<HandshakeState>),
57    /// Noise transport state.
58    Transport(TransportState),
59}
60
61/// Handshake messages.
62#[derive(Default, Debug)]
63pub enum HandshakeMessage {
64    #[default]
65    #[doc(hidden)]
66    Noop,
67    /// Handshake initiator.
68    Initiator(usize, Vec<u8>),
69    /// Handshake responder.
70    Responder(usize, Vec<u8>),
71}
72
73impl From<&HandshakeMessage> for u8 {
74    fn from(value: &HandshakeMessage) -> Self {
75        match value {
76            HandshakeMessage::Noop => types::NOOP,
77            HandshakeMessage::Initiator(_, _) => {
78                types::HANDSHAKE_INITIATOR
79            }
80            HandshakeMessage::Responder(_, _) => {
81                types::HANDSHAKE_RESPONDER
82            }
83        }
84    }
85}
86
87/// Transparent messages are not encrypted.
88#[derive(Default, Debug)]
89pub enum TransparentMessage {
90    #[default]
91    #[doc(hidden)]
92    Noop,
93    /// Return an error message to the client.
94    Error(StatusCode, String),
95    /// Handshake message.
96    ServerHandshake(HandshakeMessage),
97    /// Relayed peer handshake message.
98    PeerHandshake {
99        /// Public key of the receiver.
100        public_key: Vec<u8>,
101        /// Handshake message.
102        message: HandshakeMessage,
103    },
104}
105
106impl From<&TransparentMessage> for u8 {
107    fn from(value: &TransparentMessage) -> Self {
108        match value {
109            TransparentMessage::Noop => types::NOOP,
110            TransparentMessage::Error(_, _) => types::ERROR,
111            TransparentMessage::ServerHandshake(_) => {
112                types::HANDSHAKE_SERVER
113            }
114            TransparentMessage::PeerHandshake { .. } => {
115                types::HANDSHAKE_PEER
116            }
117        }
118    }
119}
120
121/// Message sent between the server and a client.
122#[derive(Default, Debug)]
123pub enum ServerMessage {
124    #[default]
125    #[doc(hidden)]
126    Noop,
127    /// Return an error message to the client.
128    Error(StatusCode, String),
129    /// Request a new session.
130    NewSession(SessionRequest),
131    /// Register a peer connection in a session.
132    SessionConnection {
133        /// Session identifier.
134        session_id: SessionId,
135        /// Public key of the peer.
136        peer_key: Vec<u8>,
137    },
138    /// Response to a new session request.
139    SessionCreated(SessionState),
140    /// Notification dispatched to all participants
141    /// in a session when they have all completed
142    /// the server handshake.
143    SessionReady(SessionState),
144    /// Notification dispatched to all participants
145    /// in a session when they have all established
146    /// peer connections to each other.
147    SessionActive(SessionState),
148    /// Notification dispatched to all participants
149    /// in a session when the participants did not
150    /// all connect within the expected timeframe.
151    SessionTimeout(SessionId),
152    /// Request to close a session.
153    CloseSession(SessionId),
154    /// Message sent when a session was closed.
155    SessionFinished(SessionId),
156}
157
158impl From<&ServerMessage> for u8 {
159    fn from(value: &ServerMessage) -> Self {
160        match value {
161            ServerMessage::Noop => types::NOOP,
162            ServerMessage::Error(_, _) => types::ERROR,
163            ServerMessage::NewSession(_) => types::SESSION_NEW,
164            ServerMessage::SessionConnection { .. } => {
165                types::SESSION_CONNECTION
166            }
167            ServerMessage::SessionCreated(_) => {
168                types::SESSION_CREATED
169            }
170            ServerMessage::SessionReady(_) => types::SESSION_READY,
171            ServerMessage::SessionActive(_) => types::SESSION_ACTIVE,
172            ServerMessage::SessionTimeout(_) => {
173                types::SESSION_TIMEOUT
174            }
175            ServerMessage::CloseSession(_) => types::SESSION_CLOSE,
176            ServerMessage::SessionFinished(_) => {
177                types::SESSION_FINISHED
178            }
179        }
180    }
181}
182
183/// Opaque messaages are encrypted.
184#[derive(Default, Debug)]
185pub enum OpaqueMessage {
186    #[default]
187    #[doc(hidden)]
188    Noop,
189
190    /// Encrypted message sent between the server and a client.
191    ///
192    /// After decrypting it can be decoded to a server message.
193    ServerMessage(SealedEnvelope),
194
195    /// Relay an encrypted message to a peer.
196    PeerMessage {
197        /// Public key of the receiver.
198        public_key: Vec<u8>,
199        /// Session identifier.
200        session_id: Option<SessionId>,
201        /// Message envelope.
202        envelope: SealedEnvelope,
203    },
204}
205
206impl From<&OpaqueMessage> for u8 {
207    fn from(value: &OpaqueMessage) -> Self {
208        match value {
209            OpaqueMessage::Noop => types::NOOP,
210            OpaqueMessage::ServerMessage(_) => types::OPAQUE_SERVER,
211            OpaqueMessage::PeerMessage { .. } => types::OPAQUE_PEER,
212        }
213    }
214}
215
216/// Request message sent to the server or another peer.
217#[derive(Default, Debug)]
218pub enum RequestMessage {
219    #[default]
220    #[doc(hidden)]
221    Noop,
222
223    /// Transparent message used for the handshake(s).
224    Transparent(TransparentMessage),
225
226    /// Opaque encrypted messages.
227    Opaque(OpaqueMessage),
228}
229
230impl From<&RequestMessage> for u8 {
231    fn from(value: &RequestMessage) -> Self {
232        match value {
233            RequestMessage::Noop => types::NOOP,
234            RequestMessage::Transparent(_) => types::TRANSPARENT,
235            RequestMessage::Opaque(_) => types::OPAQUE,
236        }
237    }
238}
239
240/// Response message sent by the server or a peer.
241#[derive(Default, Debug)]
242pub enum ResponseMessage {
243    #[default]
244    #[doc(hidden)]
245    Noop,
246
247    /// Transparent message used for the handshake(s).
248    Transparent(TransparentMessage),
249
250    /// Opaque encrypted messages.
251    Opaque(OpaqueMessage),
252}
253
254impl From<&ResponseMessage> for u8 {
255    fn from(value: &ResponseMessage) -> Self {
256        match value {
257            ResponseMessage::Noop => types::NOOP,
258            ResponseMessage::Transparent(_) => types::TRANSPARENT,
259            ResponseMessage::Opaque(_) => types::OPAQUE,
260        }
261    }
262}
263
264/// Encoding for message payloads.
265#[derive(Default, Clone, Copy, Debug)]
266pub enum Encoding {
267    #[default]
268    #[doc(hidden)]
269    Noop,
270    /// Binary encoding.
271    Blob,
272    /// JSON encoding.
273    Json,
274}
275
276impl From<Encoding> for u8 {
277    fn from(value: Encoding) -> Self {
278        match value {
279            Encoding::Noop => types::NOOP,
280            Encoding::Blob => types::ENCODING_BLOB,
281            Encoding::Json => types::ENCODING_JSON,
282        }
283    }
284}
285
286/// Chunk is used to respect the 65535 limit for
287/// noise protocol messages.
288///
289/// Payloads may be larger than this limit so we chunk
290/// them into individually encrypted payloads which then
291/// need to be re-combined after each chunk has been decrypted.
292#[derive(Default, Debug)]
293pub struct Chunk {
294    /// Length of the payload data.
295    pub length: usize,
296    /// Encrypted payload.
297    pub contents: Vec<u8>,
298}
299
300impl Chunk {
301    const CHUNK_SIZE: usize = 65535 - TAGLEN;
302
303    /// Split a payload into encrypted chunks.
304    pub fn split(
305        payload: &[u8],
306        transport: &mut TransportState,
307    ) -> Result<Vec<Chunk>> {
308        let mut chunks = Vec::new();
309        for chunk in payload.chunks(Self::CHUNK_SIZE) {
310            let mut contents = vec![0; chunk.len() + TAGLEN];
311            let length =
312                transport.write_message(chunk, &mut contents)?;
313            chunks.push(Chunk { length, contents });
314        }
315        Ok(chunks)
316    }
317
318    /// Decrypt chunks and join into a single payload.
319    pub fn join(
320        chunks: Vec<Chunk>,
321        transport: &mut TransportState,
322    ) -> Result<Vec<u8>> {
323        let mut payload = Vec::new();
324        for chunk in chunks {
325            let mut contents = vec![0; chunk.length];
326            transport.read_message(
327                &chunk.contents[..chunk.length],
328                &mut contents,
329            )?;
330            let new_length = contents.len() - TAGLEN;
331            contents.truncate(new_length);
332            payload.extend_from_slice(contents.as_slice());
333        }
334        Ok(payload)
335    }
336}
337
338/// Sealed envelope is an encrypted message.
339///
340/// The payload has been encrypted using the noise protocol
341/// channel and the recipient must decrypt and decode the payload.
342#[derive(Default, Debug)]
343pub struct SealedEnvelope {
344    /// Encoding for the payload.
345    pub encoding: Encoding,
346    /// Encrypted chunks.
347    pub chunks: Vec<Chunk>,
348    /// Whether this is a broadcast message.
349    pub broadcast: bool,
350}
351
/// A namespace grouping a set of participants that communicate
/// over a series of rounds.
///
/// Intended for the key generation, signing or key refresh
/// phases of an MPC protocol.
pub struct Session {
    /// Public key of the participant that created (and owns)
    /// this session.
    owner_key: Vec<u8>,

    /// Public keys of every participant other than the owner.
    participant_keys: HashSet<Vec<u8>>,

    /// Peer connections registered in this session; each pair is
    /// stored in the order it was registered.
    connections: HashSet<(Vec<u8>, Vec<u8>)>,

    /// Time of the most recent access, so stale sessions can be reaped.
    last_access: SystemTime,
}

impl Session {
    /// Public key of the session owner.
    pub fn owner_key(&self) -> &[u8] {
        &self.owner_key
    }

    /// Public keys of every participant, with the owner first.
    ///
    /// The remaining keys come from a `HashSet`, so their relative
    /// order is unspecified.
    pub fn public_keys(&self) -> Vec<&[u8]> {
        std::iter::once(self.owner_key.as_slice())
            .chain(self.participant_keys.iter().map(Vec::as_slice))
            .collect()
    }

    /// Record that `peer` and `other` have connected to each other.
    pub fn register_connection(
        &mut self,
        peer: Vec<u8>,
        other: Vec<u8>,
    ) {
        let pair = (peer, other);
        self.connections.insert(pair);
    }

    /// Whether every pair of distinct participants has a registered
    /// connection (in either direction).
    pub fn is_active(&self) -> bool {
        let keys = self.public_keys();
        // Connections are undirected: a pair may have been registered in
        // either order, so accept both orientations.
        let linked = |a: &[u8], b: &[u8]| {
            self.connections.contains(&(a.to_vec(), b.to_vec()))
                || self.connections.contains(&(b.to_vec(), a.to_vec()))
        };
        keys.iter().enumerate().all(|(i, &a)| {
            keys[i + 1..].iter().all(|&b| a == b || linked(a, b))
        })
    }
}
447
448/// Manages a collection of sessions.
449#[derive(Default)]
450pub struct SessionManager {
451    sessions: HashMap<SessionId, Session>,
452}
453
454impl SessionManager {
455    /// Create a new session.
456    pub fn new_session(
457        &mut self,
458        owner_key: Vec<u8>,
459        participant_keys: Vec<Vec<u8>>,
460    ) -> SessionId {
461        let session_id = SessionId::new_v4();
462        let session = Session {
463            owner_key,
464            participant_keys: participant_keys.into_iter().collect(),
465            connections: Default::default(),
466            last_access: SystemTime::now(),
467        };
468        self.sessions.insert(session_id, session);
469        session_id
470    }
471
472    /// Get a session.
473    pub fn get_session(&self, id: &SessionId) -> Option<&Session> {
474        self.sessions.get(id)
475    }
476
477    /// Get a mutable session.
478    pub fn get_session_mut(
479        &mut self,
480        id: &SessionId,
481    ) -> Option<&mut Session> {
482        self.sessions.get_mut(id)
483    }
484
485    /// Remove a session.
486    pub fn remove_session(
487        &mut self,
488        id: &SessionId,
489    ) -> Option<Session> {
490        self.sessions.remove(id)
491    }
492
493    /// Retrieve and update the last access time for a session.
494    pub fn touch_session(
495        &mut self,
496        id: &SessionId,
497    ) -> Option<&Session> {
498        if let Some(session) = self.sessions.get_mut(id) {
499            session.last_access = SystemTime::now();
500            Some(&*session)
501        } else {
502            None
503        }
504    }
505
506    /// Get the keys of sessions that have expired.
507    pub fn expired_keys(&self, timeout: u64) -> Vec<SessionId> {
508        self.sessions
509            .iter()
510            .filter(|(_, v)| {
511                let now = SystemTime::now();
512                let ttl = Duration::from_millis(timeout * 1000);
513                if let Some(current) = v.last_access.checked_add(ttl)
514                {
515                    current < now
516                } else {
517                    false
518                }
519            })
520            .map(|(k, _)| *k)
521            .collect::<Vec<_>>()
522    }
523}
524
/// Request to create a new session.
///
/// Do not include the initiator's own public key; it is added
/// automatically as the session *owner*.
#[derive(Default, Debug)]
pub struct SessionRequest {
    /// Public keys of the other session participants.
    pub participant_keys: Vec<Vec<u8>>,
}
534
535/// Response from creating new session.
536#[derive(Default, Debug, Clone)]
537pub struct SessionState {
538    /// Session identifier.
539    pub session_id: SessionId,
540    /// Public keys of all participants.
541    pub all_participants: Vec<Vec<u8>>,
542}
543
544impl SessionState {
545    /// Total number of participants in this session.
546    pub fn len(&self) -> usize {
547        self.all_participants.len()
548    }
549
550    /// Get the party index from a public key.
551    pub fn party_number(
552        &self,
553        public_key: impl AsRef<[u8]>,
554    ) -> Option<PartyNumber> {
555        self.all_participants
556            .iter()
557            .position(|k| k == public_key.as_ref())
558            .map(|pos| PartyNumber::new((pos + 1) as u16).unwrap())
559    }
560
561    /// Get the public key for a party number.
562    pub fn peer_key(
563        &self,
564        party_number: PartyNumber,
565    ) -> Option<&[u8]> {
566        for (index, key) in self.all_participants.iter().enumerate() {
567            if index + 1 == party_number.get() as usize {
568                return Some(key.as_slice());
569            }
570        }
571        None
572    }
573
574    /// Get the connections a peer should make.
575    pub fn connections(&self, own_key: &[u8]) -> &[Vec<u8>] {
576        if self.all_participants.is_empty() {
577            return &[];
578        }
579
580        if let Some(position) =
581            self.all_participants.iter().position(|k| k == own_key)
582        {
583            if position < self.all_participants.len() - 1 {
584                &self.all_participants[position + 1..]
585            } else {
586                &[]
587            }
588        } else {
589            &[]
590        }
591    }
592
593    /// Get the recipients for a broadcast message.
594    pub fn recipients(&self, own_key: &[u8]) -> Vec<Vec<u8>> {
595        self.all_participants
596            .iter()
597            .filter(|&k| k != own_key)
598            .map(|k| k.to_vec())
599            .collect()
600    }
601}
602
#[cfg(test)]
mod tests {
    use super::Chunk;
    use crate::PATTERN;
    use anyhow::Result;

    /// Run the two-message handshake between freshly generated keypairs
    /// and return the `(initiator, responder)` transports.
    fn connected_pair() -> Result<(snow::TransportState, snow::TransportState)> {
        let initiator_builder = snow::Builder::new(PATTERN.parse()?);
        let responder_builder = snow::Builder::new(PATTERN.parse()?);

        let initiator_keys = initiator_builder.generate_keypair()?;
        let responder_keys = responder_builder.generate_keypair()?;

        let mut initiator = initiator_builder
            .local_private_key(&initiator_keys.private)
            .remote_public_key(&responder_keys.public)
            .build_initiator()?;
        let mut responder = responder_builder
            .local_private_key(&responder_keys.private)
            .remote_public_key(&initiator_keys.public)
            .build_responder()?;

        let mut scratch = [0u8; 1024];
        let mut message = [0u8; 1024];

        // -> e
        let len = initiator.write_message(&[], &mut message)?;
        responder.read_message(&message[..len], &mut scratch)?;

        // <- e, ee
        let len = responder.write_message(&[], &mut message)?;
        initiator.read_message(&message[..len], &mut scratch)?;

        // Handshake complete; switch both sides to transport mode.
        Ok((
            initiator.into_transport_mode()?,
            responder.into_transport_mode()?,
        ))
    }

    #[test]
    fn chunks_split_join() -> Result<()> {
        let (mut sender, mut receiver) = connected_pair()?;

        // Larger than a single noise message, so it must be chunked.
        let payload = vec![0u8; 76893];

        let chunks = Chunk::split(&payload, &mut sender)?;
        assert_eq!(2, chunks.len());

        let joined = Chunk::join(chunks, &mut receiver)?;
        assert_eq!(payload, joined);

        Ok(())
    }
}