ark_mpc/
network.rs

1//! The `network` module defines abstractions of the transport used to
2//! communicate during the course of an MPC
3mod cert_verifier;
4mod config;
5mod mock;
6mod quic;
7mod stream_buffer;
8
9use ark_ec::CurveGroup;
10pub use quic::*;
11
12use futures::{Sink, Stream};
13#[cfg(any(feature = "test_helpers", feature = "benchmarks", test))]
14pub use mock::{MockNetwork, NoRecvNetwork, UnboundedDuplexStream};
15
16use async_trait::async_trait;
17use serde::{Deserialize, Serialize};
18
19use crate::{
20    algebra::{CurvePoint, Scalar},
21    error::MpcNetworkError,
22    fabric::ResultId,
23};
24
/// Identifier of a participant in an MPC.
///
/// This is a plain alias for `u64`, so any `u64` value can be used wherever a
/// `PartyId` is expected.
pub type PartyId = u64;
27
28// ---------
29// | Trait |
30// ---------
31
32/// The type that the network sender receives
33#[derive(Clone, Debug, Serialize, Deserialize)]
34#[serde(bound = "C: CurveGroup")]
35pub struct NetworkOutbound<C: CurveGroup> {
36    /// The operation ID that generated this message
37    pub result_id: ResultId,
38    /// The body of the message
39    pub payload: NetworkPayload<C>,
40}
41
42/// The payload of an outbound message
43#[derive(Clone, Debug, Serialize, Deserialize)]
44#[serde(bound(serialize = "C: CurveGroup", deserialize = "C: CurveGroup"))]
45pub enum NetworkPayload<C: CurveGroup> {
46    /// A byte value
47    Bytes(Vec<u8>),
48    /// A scalar value
49    Scalar(Scalar<C>),
50    /// A batch of scalar values
51    ScalarBatch(Vec<Scalar<C>>),
52    /// A point on the curve
53    Point(CurvePoint<C>),
54    /// A batch of points on the curve
55    PointBatch(Vec<CurvePoint<C>>),
56}
57
58// ---------------
59// | Conversions |
60// ---------------
61
62impl<C: CurveGroup> From<Vec<u8>> for NetworkPayload<C> {
63    fn from(bytes: Vec<u8>) -> Self {
64        Self::Bytes(bytes)
65    }
66}
67
68impl<C: CurveGroup> From<Scalar<C>> for NetworkPayload<C> {
69    fn from(scalar: Scalar<C>) -> Self {
70        Self::Scalar(scalar)
71    }
72}
73
74impl<C: CurveGroup> From<Vec<Scalar<C>>> for NetworkPayload<C> {
75    fn from(scalars: Vec<Scalar<C>>) -> Self {
76        Self::ScalarBatch(scalars)
77    }
78}
79
80impl<C: CurveGroup> From<CurvePoint<C>> for NetworkPayload<C> {
81    fn from(point: CurvePoint<C>) -> Self {
82        Self::Point(point)
83    }
84}
85
86impl<C: CurveGroup> From<Vec<CurvePoint<C>>> for NetworkPayload<C> {
87    fn from(value: Vec<CurvePoint<C>>) -> Self {
88        Self::PointBatch(value)
89    }
90}
91
92/// The `MpcNetwork` trait defines shared functionality for a network
93/// implementing a connection between two parties in a 2PC
94///
95/// Values are sent as bytes, scalars, or curve points and always in batch form
96/// with the message length (measured in the number of elements sent) prepended
97/// to the message
98#[async_trait]
99pub trait MpcNetwork<C: CurveGroup>:
100    Send
101    + Stream<Item = Result<NetworkOutbound<C>, MpcNetworkError>>
102    + Sink<NetworkOutbound<C>, Error = MpcNetworkError>
103{
104    /// Get the party ID of the local party in the MPC
105    fn party_id(&self) -> PartyId;
106    /// Closes the connections opened in the handshake phase
107    async fn close(&mut self) -> Result<(), MpcNetworkError>;
108}