Skip to main content

mpc_core/mpc/
scheme.rs

1use core::error::Error;
2use core::fmt::Debug;
3use core::marker::Send;
4
5use super::WireId;
6use crate::{
7    mpc::MpcCircuit,
8    networking::{Network, ReceiveRequest, SendRequest},
9};
10use serde::{Deserialize, Serialize};
11
12/// A type that represents an MPC operation.
13///
14/// An implementation must satisfy the following invariant:
15///
16/// ```rust,ignore
17/// // Any "input" operation must yield exactly one output.
18/// assert!(op.get_input_party_id().is_none() || op.outputs().count() == 1);
19/// ```
20pub trait Operation {
21    /// If the operation is an input, return the id of the party that gives the input. Otherwise,
22    /// return [`None`].
23    fn get_input_party_id(&self) -> Option<usize>;
24
25    /// Return the iterator of the input wires of this operation.
26    fn inputs<'a>(&'a self) -> Box<dyn Iterator<Item = WireId> + 'a>;
27
28    /// Return the iterator of the output wires of this operation.
29    fn outputs<'a>(&'a self) -> Box<dyn Iterator<Item = WireId> + 'a>;
30
31    fn is_input(&self) -> bool {
32        self.get_input_party_id().is_some()
33    }
34}
35
36pub struct NetworkPhaseOutput<'a, S: MpcScheme> {
37    pub pending: S::Pending<'a>,
38    pub send_request: Vec<SendRequest<'a, S::NetworkElement>>,
39    pub receive_request: Vec<ReceiveRequest<S::NetworkElement>>,
40}
41
42#[repr(transparent)]
43pub struct FinalizePhaseOutput<S: MpcScheme>(pub Vec<S::Wire>);
44
45/// A type that represents an MPC scheme.
46pub trait MpcScheme
47where
48    Self: Debug + Clone + Serialize + for<'de> Deserialize<'de>,
49{
50    /// The context type that contains the necessary information that should be remembered during
51    /// the entire execution. See [`MpcScheme::establish_context`],
52    /// [`MpcScheme::do_network_phase`], and [`MpcScheme::do_finalize_phase`].
53    type Context: Debug;
54
55    /// The type that represents the element that travels across the network.
56    type NetworkElement: Debug + Serialize + for<'de> Deserialize<'de>;
57
58    /// The type that represents the wire value.
59    type Wire: Debug;
60
61    /// The type that represents the input from the party.
62    type Input: Debug;
63
64    /// The type that represents the operation.
65    type Operation: Debug + Serialize + for<'de> Deserialize<'de> + Operation;
66
67    /// The type that represents the pending operation returned from [`MpcScheme::do_network_phase`]
68    /// and consumed by [`MpcScheme::do_finalize_phase`].
69    type Pending<'a>: Debug;
70
71    /// The error type returned by [`MpcScheme::establish_context`].
72    type EstablishContextError: Error + Send + Sync + 'static;
73    /// The error type returned by [`MpcScheme::do_network_phase`].
74    type NetworkPhaseError: Error + Send + Sync + 'static;
75    /// The error type returned by [`MpcScheme::do_finalize_phase`].
76    type FinalizePhaseError: Error + Send + Sync + 'static;
77
78    /// Returns true if the circuit "makes sense". [`MpcCircuit::new`] will fail if this returns
79    /// false.
80    ///
81    /// * `circuit`: an iterator of operations
82    fn is_circuit_sound<'a, I>(&self, circuit: I) -> bool
83    where
84        Self::Operation: 'a,
85        I: IntoIterator<Item = &'a Self::Operation>;
86
87    /// Returns `true` if `op` can be done without communication whenever the inputs are ready.
88    ///
89    /// ```rust,ignore
90    /// // The "input" operation of this scheme must not be local.
91    /// assert!(!op.is_input() || !scheme.is_operation_local(op));
92    /// ```
93    ///
94    /// * `op`: The operation.
95    fn is_operation_local(&self, op: &Self::Operation) -> bool;
96
97    /// Specify how the context should be established given a network and the circuit.
98    fn establish_context<N: Network>(
99        &self,
100        network: &mut N,
101        circuit: &MpcCircuit<Self>,
102    ) -> Result<Self::Context, Self::EstablishContextError>;
103
104    /// The caller should ensure that the inputs are valid.
105    fn prepare_user_input<I>(&self, context: &mut Self::Context, inputs: I)
106    where
107        I: IntoIterator<Item = (WireId, Self::Input)>;
108
109    /// Given a context, an operation, and the inputs, return the send/receive requests that should
110    /// be processed before calling the corresponding [`MpcScheme::do_network_phase`]. This
111    /// function must sanitize `inputs`.
112    fn do_network_phase<'a, I>(
113        &self,
114        context: &mut Self::Context,
115        op: &Self::Operation,
116        inputs: I,
117    ) -> Result<NetworkPhaseOutput<'a, Self>, Self::NetworkPhaseError>
118    where
119        Self::Wire: 'a,
120        I: IntoIterator<Item = &'a Self::Wire>;
121
122    /// Given a context, a pending operation, and the (possible) network data, finalize the
123    /// operation and return the corresponding outputs. The return value must be consistent with
124    /// [`Operation::outputs`] (i.e. the number of output wires must match).
125    fn do_finalize_phase<'a, I>(
126        &self,
127        context: &mut Self::Context,
128        pending: Self::Pending<'a>,
129        network_data: I,
130    ) -> Result<FinalizePhaseOutput<Self>, Self::FinalizePhaseError>
131    where
132        I: IntoIterator<Item = Self::NetworkElement>;
133}