1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
//! The `network` module defines abstractions of the transport used to
//! communicate during the course of an MPC.
mod cert_verifier;
mod config;
mod mock;
mod quic;
mod stream_buffer;

use ark_ec::CurveGroup;
pub use quic::*;

use futures::{Sink, Stream};
#[cfg(any(feature = "test_helpers", feature = "benchmarks", test))]
pub use mock::{MockNetwork, NoRecvNetwork, UnboundedDuplexStream};

use async_trait::async_trait;
use serde::{Deserialize, Serialize};

use crate::{
    algebra::{curve::CurvePoint, scalar::Scalar},
    error::MpcNetworkError,
    fabric::ResultId,
};

/// Identifier of a participant in an MPC
///
/// A plain alias for `u64`; it exists so signatures say what the integer means
pub type PartyId = u64;

// ---------
// | Trait |
// ---------

/// A single message exchanged over the MPC network
///
/// Pairs the identifier of the operation that produced a value with the value itself. The
/// same type is used both for messages handed to the network's `Sink` and for messages
/// yielded by its `Stream`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(bound(serialize = "C: CurveGroup", deserialize = "C: CurveGroup"))]
pub struct NetworkOutbound<C: CurveGroup> {
    /// Identifier of the operation whose result this message carries
    pub result_id: ResultId,
    /// The value being transmitted
    pub payload: NetworkPayload<C>,
}

/// The value carried by a network message
///
/// NOTE: serde's derived impls identify variants by index in compact formats, so variants
/// should be appended rather than reordered if wire compatibility matters.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(bound = "C: CurveGroup")]
pub enum NetworkPayload<C: CurveGroup> {
    /// Raw bytes
    Bytes(Vec<u8>),
    /// One scalar field element
    Scalar(Scalar<C>),
    /// Several scalar field elements sent together
    ScalarBatch(Vec<Scalar<C>>),
    /// One curve point
    Point(CurvePoint<C>),
    /// Several curve points sent together
    PointBatch(Vec<CurvePoint<C>>),
}

// ---------------
// | Conversions |
// ---------------

impl<C: CurveGroup> From<Vec<u8>> for NetworkPayload<C> {
    fn from(bytes: Vec<u8>) -> Self {
        Self::Bytes(bytes)
    }
}

impl<C: CurveGroup> From<Scalar<C>> for NetworkPayload<C> {
    fn from(scalar: Scalar<C>) -> Self {
        Self::Scalar(scalar)
    }
}

impl<C: CurveGroup> From<Vec<Scalar<C>>> for NetworkPayload<C> {
    fn from(scalars: Vec<Scalar<C>>) -> Self {
        Self::ScalarBatch(scalars)
    }
}

impl<C: CurveGroup> From<CurvePoint<C>> for NetworkPayload<C> {
    fn from(point: CurvePoint<C>) -> Self {
        Self::Point(point)
    }
}

impl<C: CurveGroup> From<Vec<CurvePoint<C>>> for NetworkPayload<C> {
    fn from(value: Vec<CurvePoint<C>>) -> Self {
        Self::PointBatch(value)
    }
}

/// The `MpcNetwork` trait defines shared functionality for a network implementing a
/// connection between two parties in a 2PC
///
/// An implementation is a duplex message channel:
/// - it is a [`Sink`] that accepts outbound [`NetworkOutbound`] messages, and
/// - it is a [`Stream`] that yields inbound [`NetworkOutbound`] messages, each wrapped in a
///   `Result` so transport failures surface as [`MpcNetworkError`].
///
/// Each message carries a [`NetworkPayload`], which may be raw bytes, a single scalar or
/// curve point, or a batch of scalars or curve points. This trait does not prescribe how
/// messages are framed on the wire; that is left to the implementation.
#[async_trait]
pub trait MpcNetwork<C: CurveGroup>:
    Send
    + Stream<Item = Result<NetworkOutbound<C>, MpcNetworkError>>
    + Sink<NetworkOutbound<C>, Error = MpcNetworkError>
{
    /// Get the party ID of the local party in the MPC
    fn party_id(&self) -> PartyId;

    /// Closes the connections opened in the handshake phase
    ///
    /// # Errors
    ///
    /// Returns an [`MpcNetworkError`] if the implementation fails to close its connections
    async fn close(&mut self) -> Result<(), MpcNetworkError>;
}