polytune 0.1.0

Maliciously-Secure Multi-Party Computation (MPC) Engine using Authenticated Garbling
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
//! Provides communication channels for sending and receiving messages between parties.
//!
//! This module defines the fundamental abstraction for communication in the form of the [`Channel`]
//! trait, which can be implemented to support various communication methods and environments.
//!
//! The core design philosophy is to separate the protocol logic from the specifics of message
//! transport. The protocol implementation does not need to be concerned with how messages are
//! physically transmitted - it only interacts with the abstract `Channel` interface. This means you
//! can switch between different channel implementations (network sockets, in-memory channels, etc.)
//! without changing protocol code.
//!
//! ## Message Chunking
//!
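//! The module provides automatic chunking of large messages to avoid issues with message size
//! limits. Messages are split into chunks of a bounded number of elements, each chunk is
//! serialized and sent separately together with the number of chunks that still follow it, and
//! the receiving end reassembles the full message transparently.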
//!
//! ## Serialization
//!
//! Messages are serialized using `bincode`, allowing for efficient binary encoding of structured
//! data. The channel primarily works with byte vectors, while higher-level send/receive functions
//! handle serialization and deserialization of application-level messages.

use std::fmt;
#[cfg(not(target_arch = "wasm32"))]
use std::sync::atomic::{AtomicU64, Ordering};

use futures::future::{try_join, try_join_all};
use serde::{Deserialize, Serialize, de::DeserializeOwned};
#[cfg(not(target_arch = "wasm32"))]
use tokio::sync::{
    Mutex,
    mpsc::{Receiver, Sender, channel},
};
#[cfg(not(target_arch = "wasm32"))]
use tracing::{Level, trace};

/// Error raised while sending, receiving, serializing or deserializing a message.
#[derive(Debug)]
pub struct Error {
    /// Name of the protocol phase that was running when the error occurred.
    pub phase: String,
    /// What exactly went wrong.
    pub reason: ErrorKind,
}

/// The concrete failure behind an [`Error`].
#[derive(Debug)]
pub enum ErrorKind {
    /// Receiving a (serialized) message over the channel failed.
    RecvError(String),
    /// Sending a (serialized) message over the channel failed.
    SendError(String),
    /// A message could not be serialized before sending, or deserialized after receiving.
    SerdeError(String),
    /// A `Vec` did not have the expected length (received data, or inconsistent outgoing data).
    InvalidLength,
}

/// One outgoing chunk, borrowed from the full message, plus how many chunks still follow it.
///
/// The field order is part of the wire format (bincode encodes fields positionally) and must
/// match `RecvChunk`.
#[derive(Serialize, Debug)]
struct SendChunk<'a, T> {
    /// The elements carried by this chunk.
    chunk: &'a [T],
    /// How many more chunks will be sent after this one (`0` for the last chunk).
    remaining_chunks: usize,
}

/// One incoming chunk, owned, plus how many chunks the sender announced will still follow it.
///
/// The field order is part of the wire format (bincode decodes fields positionally) and must
/// match `SendChunk`.
#[derive(Deserialize, Debug)]
struct RecvChunk<T> {
    /// The elements carried by this chunk.
    chunk: Vec<T>,
    /// How many more chunks follow this one according to the sender (`0` for the last chunk).
    remaining_chunks: usize,
}

/// Position of an outgoing chunk within its full message, handed to the channel for logging.
#[derive(Debug, Clone)]
pub struct SendInfo {
    phase: String,
    current_msg: usize,
    remaining_msgs: usize,
}

impl SendInfo {
    /// Name of the protocol phase that is sending the message.
    pub fn phase(&self) -> &str {
        self.phase.as_str()
    }

    /// Number of chunks sent so far including the current one (1-based: `1` for the first chunk).
    pub fn sent(&self) -> usize {
        1 + self.current_msg
    }

    /// Number of chunks that still follow the current one.
    pub fn remaining(&self) -> usize {
        self.remaining_msgs
    }

    /// Total number of chunks making up the full message.
    pub fn total(&self) -> usize {
        let done = self.sent();
        let todo = self.remaining();
        done + todo
    }
}

/// Position of an incoming chunk within its full message, handed to the channel for logging.
#[derive(Debug, Clone)]
pub struct RecvInfo {
    phase: String,
    current_msg: usize,
    remaining_msgs: Option<usize>,
}

impl RecvInfo {
    /// Name of the protocol phase the message belongs to.
    pub fn phase(&self) -> &str {
        self.phase.as_str()
    }

    /// Number of chunks received so far including the current one (1-based).
    pub fn sent(&self) -> usize {
        1 + self.current_msg
    }

    /// Number of chunks still expected after the current one.
    ///
    /// `None` while waiting for the first chunk, because the count is only known once the first
    /// chunk has arrived.
    pub fn remaining(&self) -> Option<usize> {
        self.remaining_msgs
    }

    /// Total number of chunks making up the full message.
    ///
    /// `None` while waiting for the first chunk (see [`RecvInfo::remaining`]).
    pub fn total(&self) -> Option<usize> {
        let todo = self.remaining()?;
        Some(self.sent() + todo)
    }
}

/// A communication channel used to send/receive messages to/from another party.
///
/// This trait defines the core interface for message transport in the protocol.
/// Implementations of this trait determine how messages are physically sent and received,
/// which can vary based on the environment (network, in-process, etc.).
///
/// Implementations only deal with opaque byte vectors: the helpers in this module serialize
/// messages, split them into chunks and call these methods once per chunk, in order.
pub trait Channel {
    /// The error that can occur sending messages over the channel.
    type SendError: fmt::Debug;
    /// The error that can occur receiving messages over the channel.
    type RecvError: fmt::Debug;

    /// Sends one serialized chunk to the party with the given index (must be between
    /// `0..participants`).
    ///
    /// `info` describes where the chunk sits within the full message (see [`SendInfo`]) and can
    /// be used for logging.
    // We allow the async_fn_in_trait lint because we don't need to place additional bounds on
    // the returned future. We don't want to enforce returning Send futures as that is not
    // compatible with the `examples/wasm-http-channels` implementation.
    #[allow(async_fn_in_trait)]
    async fn send_bytes_to(
        &self,
        party: usize,
        chunk: Vec<u8>,
        info: SendInfo,
    ) -> Result<(), Self::SendError>;

    /// Awaits the next serialized chunk from the party with the given index (must be between
    /// `0..participants`).
    ///
    /// `info` describes which chunk of the full message is expected (see [`RecvInfo`]) and can
    /// be used for logging.
    // Allowed for the same reason as on `send_bytes_to`: no `Send` bound on the returned future.
    #[allow(async_fn_in_trait)]
    async fn recv_bytes_from(
        &self,
        party: usize,
        info: RecvInfo,
    ) -> Result<Vec<u8>, Self::RecvError>;
}

/// Serializes and sends an MPC message to the other party.
pub(crate) async fn send_to<S: Serialize + std::fmt::Debug>(
    channel: &impl Channel,
    party: usize,
    phase: &str,
    msg: &[S],
) -> Result<(), Error> {
    let chunk_size = 5_000_000;
    let mut chunks: Vec<_> = msg.chunks(chunk_size).collect();
    if chunks.is_empty() {
        chunks.push(&[]);
    }
    let length = chunks.len();
    for (i, chunk) in chunks.into_iter().enumerate() {
        let remaining_chunks = length - i - 1;
        let chunk = SendChunk {
            chunk,
            remaining_chunks,
        };
        let chunk = bincode::serialize(&chunk).map_err(|e| Error {
            phase: format!("sending {phase}"),
            reason: ErrorKind::SerdeError(format!("{e:?}")),
        })?;
        let info = SendInfo {
            phase: phase.to_string(),
            current_msg: i,
            remaining_msgs: remaining_chunks,
        };
        channel
            .send_bytes_to(party, chunk, info)
            .await
            .map_err(|e| Error {
                phase: phase.to_string(),
                reason: ErrorKind::SendError(format!("{e:?}")),
            })?;
    }
    Ok(())
}

/// Receives and deserializes an MPC message from the other party.
pub(crate) async fn recv_from<T: DeserializeOwned + std::fmt::Debug>(
    channel: &impl Channel,
    party: usize,
    phase: &str,
) -> Result<Vec<T>, Error> {
    let mut msg = vec![];
    let mut i = 0;
    let mut remaining = None;
    loop {
        let info = RecvInfo {
            phase: phase.to_string(),
            current_msg: i,
            remaining_msgs: remaining,
        };
        let chunk = channel
            .recv_bytes_from(party, info)
            .await
            .map_err(|e| Error {
                phase: phase.to_string(),
                reason: ErrorKind::RecvError(format!("{e:?}")),
            })?;
        let RecvChunk {
            chunk,
            remaining_chunks,
        }: RecvChunk<T> = bincode::deserialize(&chunk).map_err(|e| Error {
            phase: format!("receiving {phase}"),
            reason: ErrorKind::SerdeError(format!("{e:?}")),
        })?;
        msg.extend(chunk);
        if remaining_chunks == 0 {
            return Ok(msg);
        }
        remaining = Some(remaining_chunks);
        i += 1;
    }
}

/// Receives and deserializes a Vec from the other party (while checking the length).
pub(crate) async fn recv_vec_from<T: DeserializeOwned + std::fmt::Debug>(
    channel: &impl Channel,
    party: usize,
    phase: &str,
    len: usize,
) -> Result<Vec<T>, Error> {
    let v: Vec<T> = recv_from(channel, party, phase).await?;
    if v.len() == len {
        Ok(v)
    } else {
        Err(Error {
            phase: phase.to_string(),
            reason: ErrorKind::InvalidLength,
        })
    }
}

/// Broadcasts the same data to all parties except self and receives responses from all other parties.
///
/// All sending and receiving is done concurrently.
///
/// # Security
/// Note that this is an unverified broadcast. If you need a broadcast that verifies that
/// each party actually sends the same data to the others, use [`crate::faand::broadcast`].
///
/// # Arguments
/// * `channel` - The communication channel
/// * `own_party` - Index of the current party (won't send to or receive from itself)
/// * `num_parties` - Total number of parties
/// * `phase` - Protocol phase name for message identification
/// * `data` - Data to send to all other parties
///
/// Every other party is expected to respond with exactly `data.len()` elements.
///
/// # Returns
/// A vector indexed by party ID, where `result[i]` contains the response from party `i`.
/// The entry for `own_party` will be empty.
///
/// # Errors
/// Returns an error if any send or receive fails, or if a response does not have exactly
/// `data.len()` elements ([`ErrorKind::InvalidLength`]).
pub(crate) async fn unverified_broadcast<T>(
    channel: &impl Channel,
    own_party: usize,
    num_parties: usize,
    phase: &str,
    data: &[T],
) -> Result<Vec<Vec<T>>, Error>
where
    T: Serialize + DeserializeOwned + std::fmt::Debug,
{
    // Every party broadcasts the same amount of data, so responses must match our own length.
    let expected_recv_len = data.len();
    let send_fut = try_join_all(
        (0..num_parties)
            .filter(|p| *p != own_party)
            .map(|p| send_to(channel, p, phase, data)),
    );

    // Iterate over all parties (not just the others) so the result stays indexed by party ID.
    let recv_fut = try_join_all((0..num_parties).map(async |p| {
        if p != own_party {
            recv_vec_from(channel, p, phase, expected_recv_len).await
        } else {
            Ok(vec![])
        }
    }));

    let (_, responses) = try_join(send_fut, recv_fut).await?;
    Ok(responses)
}

/// Scatters different data to each party and receives responses from all other parties.
///
/// All sending and receiving is done concurrently.
///
/// # Arguments
/// * `channel` - The communication channel
/// * `own_party` - Index of the current party (won't send to or receive from itself)
/// * `phase` - Protocol phase name for message identification
/// * `data_per_party` - Vector where `data_per_party[i]` is sent to party `i` except
///   when `i == own_party` (that entry is ignored)
///
/// All vectors sent to other parties must have the same length. Each other party is expected
/// to respond with a vector of that same length.
///
/// # Returns
/// A vector indexed by party ID, where `result[i]` contains the response from party `i`.
/// `result[own_party]` will be an empty `Vec`. If there is nothing to send (all outgoing vectors
/// are empty, or there are no other parties), no messages are exchanged. Every entry is then
/// an empty `Vec`.
///
/// # Errors
/// Returns [`ErrorKind::InvalidLength`] if the outgoing vectors have different lengths, or if a
/// response does not have the expected length. Propagates any send/receive error.
pub(crate) async fn scatter<T>(
    channel: &impl Channel,
    own_party: usize,
    phase: &str,
    data_per_party: &[Vec<T>],
) -> Result<Vec<Vec<T>>, Error>
where
    T: Serialize + DeserializeOwned + std::fmt::Debug,
{
    let num_parties = data_per_party.len();

    // Lengths of the vectors destined for the other parties (our own entry is never sent).
    let mut lengths = data_per_party
        .iter()
        .enumerate()
        .filter(|(p, _)| *p != own_party)
        .map(|(_, data)| data.len());
    // We expect every other party to answer with as many elements as we send it, so all outgoing
    // vectors must have the same length. An empty vector mixed with non-empty ones is rejected
    // regardless of where it appears.
    let expected_recv_len = lengths.next().unwrap_or(0);
    if lengths.any(|len| len != expected_recv_len) {
        return Err(Error {
            phase: phase.to_string(),
            reason: ErrorKind::InvalidLength,
        });
    }
    if expected_recv_len == 0 {
        // Nothing to send to anyone: skip communication entirely, but still return one (empty)
        // entry per party so that `result[i]` stays valid for every party index.
        return Ok((0..num_parties).map(|_| Vec::new()).collect());
    }

    let send_fut = try_join_all(
        (0..num_parties)
            .filter(|p| *p != own_party)
            .map(|p| send_to(channel, p, phase, &data_per_party[p])),
    );

    // Iterate over all parties (not just the others) so the result stays indexed by party ID.
    let recv_fut = try_join_all((0..num_parties).map(async |p| {
        if p != own_party {
            recv_vec_from(channel, p, phase, expected_recv_len).await
        } else {
            Ok(vec![])
        }
    }));

    let (_, responses) = try_join(send_fut, recv_fut).await?;
    Ok(responses)
}

/// A simple asynchronous channel using [`Sender`] and [`Receiver`].
///
/// [`SimpleChannel::channels`] creates in-memory, bounded queues between every pair of parties,
/// one queue per direction.
#[cfg(not(target_arch = "wasm32"))]
#[derive(Debug)]
#[allow(dead_code)]
#[doc(hidden)]
pub struct SimpleChannel {
    /// `s[p]` sends messages to party `p` (`None` for the own index).
    s: Vec<Option<Sender<Vec<u8>>>>,
    /// `r[p]` receives messages sent by party `p` (`None` for the own index).
    r: Vec<Option<Mutex<Receiver<Vec<u8>>>>>,
    /// The total number of bytes sent over the channel.
    bytes_sent: AtomicU64,
}

#[cfg(not(target_arch = "wasm32"))]
impl SimpleChannel {
    /// Creates channels for N parties to communicate with each other.
    ///
    /// The returned vector has one `SimpleChannel` per party. `channels[a]` can send to and
    /// receive from every party `b != a`.
    pub fn channels(parties: usize) -> Vec<Self> {
        let buffer_capacity = 1024;
        let mut channels: Vec<Self> = (0..parties)
            .map(|_| SimpleChannel {
                s: (0..parties).map(|_| None).collect(),
                r: (0..parties).map(|_| None).collect(),
                bytes_sent: AtomicU64::new(0),
            })
            .collect();
        // Wire each unordered pair {a, b} exactly once: one queue per direction. Visiting
        // ordered pairs would build every pair's queues twice and discard the first set.
        for a in 0..parties {
            for b in (a + 1)..parties {
                let (send_a_to_b, recv_a_to_b) = channel(buffer_capacity);
                let (send_b_to_a, recv_b_to_a) = channel(buffer_capacity);
                channels[a].s[b] = Some(send_a_to_b);
                channels[b].s[a] = Some(send_b_to_a);
                channels[a].r[b] = Some(Mutex::new(recv_b_to_a));
                channels[b].r[a] = Some(Mutex::new(recv_a_to_b));
            }
        }
        channels
    }

    /// Returns the total number of bytes sent on this channel.
    pub fn bytes_sent(&self) -> u64 {
        self.bytes_sent.load(Ordering::Relaxed)
    }
}

/// Failure modes of receiving on a [`SimpleChannel`].
#[cfg(not(target_arch = "wasm32"))]
#[derive(Debug)]
#[doc(hidden)]
pub enum AsyncRecvError {
    /// The channel was closed, so no further messages can arrive.
    Closed,
    /// No message arrived before the receive timeout expired.
    TimeoutElapsed,
}

#[cfg(not(target_arch = "wasm32"))]
impl Channel for SimpleChannel {
    type SendError = tokio::sync::mpsc::error::SendError<Vec<u8>>;
    type RecvError = AsyncRecvError;

    /// Counts the chunk's bytes, logs its size and pushes it into the queue for party `p`.
    ///
    /// # Panics
    /// Panics if there is no queue to party `p` (e.g. `p` is the own index).
    #[tracing::instrument(level = Level::TRACE, skip(self, msg))]
    async fn send_bytes_to(
        &self,
        p: usize,
        msg: Vec<u8>,
        info: SendInfo,
    ) -> Result<(), tokio::sync::mpsc::error::SendError<Vec<u8>>> {
        let len = msg.len();
        self.bytes_sent.fetch_add(len as u64, Ordering::Relaxed);
        let mb = len as f64 / 1024.0 / 1024.0;
        match info.sent() {
            1 => {
                trace!(size = mb, "Sending msg");
            }
            _ => {
                trace!(size = mb, "  (continued sending msg)");
            }
        }
        let Some(sender) = self.s[p].as_ref() else {
            panic!("No sender for party {p}");
        };
        sender.send(msg).await
    }

    /// Waits (at most 10 minutes) for the next chunk from party `p`.
    ///
    /// # Panics
    /// Panics if there is no queue from party `p` (e.g. `p` is the own index).
    #[tracing::instrument(level = Level::TRACE, skip(self), fields(info = ?_info))]
    async fn recv_bytes_from(&self, p: usize, _info: RecvInfo) -> Result<Vec<u8>, AsyncRecvError> {
        const RECV_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10 * 60);
        let Some(receiver) = self.r[p].as_ref() else {
            panic!("No receiver for party {p}");
        };
        let mut guard = receiver.lock().await;
        let Ok(received) = tokio::time::timeout(RECV_TIMEOUT, guard.recv()).await else {
            return Err(AsyncRecvError::TimeoutElapsed);
        };
        let chunk = received.ok_or(AsyncRecvError::Closed)?;
        let mb = chunk.len() as f64 / 1024.0 / 1024.0;
        trace!(size = mb, "Received chunk");
        Ok(chunk)
    }
}