1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
//! Relay service websocket client using the [noise](https://noiseprotocol.org/)
//! protocol for end-to-end encryption, intended for multi-party computation
//! and threshold signature applications.
//!
//! To support the web platform, this client library uses
//! [web-sys](https://docs.rs/web-sys/latest/web_sys/) when compiling for
//! WebAssembly (`wasm32-unknown-unknown`); otherwise it uses
//! [tokio-tungstenite](https://docs.rs/tokio-tungstenite/latest/tokio_tungstenite/).

#![deny(missing_docs)]

mod client;
mod error;
mod event_loop;
mod transport;

pub(crate) use client::{client_impl, client_transport_impl};
pub use event_loop::{Event, EventStream, JsonMessage};
pub use transport::{NetworkTransport, Transport};

#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
mod native;

#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
pub use native::{
    NativeClient as Client, NativeEventLoop as EventLoop,
};

#[cfg(all(target_arch = "wasm32", target_os = "unknown"))]
mod web;

#[cfg(all(target_arch = "wasm32", target_os = "unknown"))]
pub use web::{WebClient as Client, WebEventLoop as EventLoop};

use mpc_protocol::{
    hex, Encoding, Keypair, OpaqueMessage, ProtocolState,
    RequestMessage, SealedEnvelope, SessionId, TAGLEN,
};
use std::{collections::HashMap, sync::Arc};
use tokio::sync::RwLock;

/// Shared, lock-protected map from a byte-string key to the noise
/// protocol state held for that peer.
///
/// NOTE(review): the `Vec<u8>` keys appear to be peer public keys — confirm
/// against the client implementation.
pub(crate) type Peers = Arc<RwLock<HashMap<Vec<u8>, ProtocolState>>>;
/// Shared, lock-protected noise protocol state for the connection to the
/// relay server.
///
/// NOTE(review): presumably `None` until a server handshake has been
/// started — confirm against the client implementation.
pub(crate) type Server = Arc<RwLock<Option<ProtocolState>>>;

/// Options used to create a new websocket client.
pub struct ClientOptions {
    /// Static keypair identifying this client.
    ///
    /// Its public key is hex-encoded into the connection URL built by
    /// [`ClientOptions::url`].
    pub keypair: Keypair,
    /// Static public key of the server to connect to.
    pub server_public_key: Vec<u8>,
}

impl ClientOptions {
    /// Build the connection URL for `server`.
    ///
    /// Every trailing `/` is stripped from `server`, and then
    /// `/?public_key=<hex>` is appended. `<hex>` is the hex encoding of
    /// the public half of this client's own [`ClientOptions::keypair`],
    /// which the server expects as a query string parameter.
    pub fn url(&self, server: &str) -> String {
        let mut url = server.trim_end_matches('/').to_owned();
        url.push_str("/?public_key=");
        url.push_str(&hex::encode(self.keypair.public_key()));
        url
    }
}

pub use error::Error;

/// Result type for the client library.
///
/// This is an alias for [`std::result::Result`] with the crate's
/// [`Error`] as the error type.
pub type Result<T> = std::result::Result<T, Error>;

/// Encrypt a message to send to a peer.
///
/// The protocol must be in transport mode.
async fn encrypt_peer_channel(
    public_key: impl AsRef<[u8]>,
    peer: &mut ProtocolState,
    payload: &[u8],
    encoding: Encoding,
    broadcast: bool,
    session_id: Option<SessionId>,
) -> Result<RequestMessage> {
    match peer {
        ProtocolState::Transport(transport) => {
            let mut contents = vec![0; payload.len() + TAGLEN];
            let length =
                transport.write_message(payload, &mut contents)?;
            let envelope = SealedEnvelope {
                length,
                encoding,
                payload: contents,
                broadcast,
            };

            let request =
                RequestMessage::Opaque(OpaqueMessage::PeerMessage {
                    public_key: public_key.as_ref().to_vec(),
                    session_id,
                    envelope,
                });

            Ok(request)
        }
        _ => Err(Error::NotTransportState),
    }
}

/// Decrypt a message received from a peer.
///
/// The peer's noise session must be in transport mode; any other state
/// yields [`Error::NotTransportState`]. Errors from the transport's
/// `read_message` (for example, a failed authentication check) are
/// propagated with `?`.
///
/// Only the first `envelope.length` bytes of `envelope.payload` are
/// treated as ciphertext. The returned buffer is truncated to the number
/// of plaintext bytes the transport reports having written.
///
/// The envelope arrives over the network, so its `length` field is not
/// trusted. A length larger than the payload is clamped to the payload
/// size rather than panicking on an out-of-bounds slice. The field sits
/// outside the encrypted bytes, so clamping does not bypass the
/// transport's own check of the ciphertext it is given.
async fn decrypt_peer_channel(
    peer: &mut ProtocolState,
    envelope: &SealedEnvelope,
) -> Result<Vec<u8>> {
    match peer {
        ProtocolState::Transport(transport) => {
            // Never slice past the end of the received payload.
            let length = envelope.length.min(envelope.payload.len());
            let ciphertext = &envelope.payload[..length];

            // The plaintext is never longer than the ciphertext.
            let mut contents = vec![0; length];
            let written =
                transport.read_message(ciphertext, &mut contents)?;

            // Keep exactly the bytes the transport wrote, instead of
            // re-deriving the size from `TAGLEN` (which could underflow).
            contents.truncate(written);
            Ok(contents)
        }
        _ => Err(Error::NotTransportState),
    }
}