mpc_core/mpc/scheme.rs
1use core::error::Error;
2use core::fmt::Debug;
3use core::marker::Send;
4
5use super::WireId;
6use crate::{
7 mpc::MpcCircuit,
8 networking::{Network, ReceiveRequest, SendRequest},
9};
10use serde::{Deserialize, Serialize};
11
12/// A type that represents an MPC operation.
13///
14/// An implementation must satisfy the following invariant:
15///
16/// ```rust,ignore
17/// // Any "input" operation must yield exactly one output.
18/// assert!(op.get_input_party_id().is_none() || op.outputs().count() == 1);
19/// ```
20pub trait Operation {
21 /// If the operation is an input, return the id of the party that gives the input. Otherwise,
22 /// return [`None`].
23 fn get_input_party_id(&self) -> Option<usize>;
24
25 /// Return the iterator of the input wires of this operation.
26 fn inputs<'a>(&'a self) -> Box<dyn Iterator<Item = WireId> + 'a>;
27
28 /// Return the iterator of the output wires of this operation.
29 fn outputs<'a>(&'a self) -> Box<dyn Iterator<Item = WireId> + 'a>;
30
31 fn is_input(&self) -> bool {
32 self.get_input_party_id().is_some()
33 }
34}
35
36pub struct NetworkPhaseOutput<'a, S: MpcScheme> {
37 pub pending: S::Pending<'a>,
38 pub send_request: Vec<SendRequest<'a, S::NetworkElement>>,
39 pub receive_request: Vec<ReceiveRequest<S::NetworkElement>>,
40}
41
42#[repr(transparent)]
43pub struct FinalizePhaseOutput<S: MpcScheme>(pub Vec<S::Wire>);
44
45/// A type that represents an MPC scheme.
46pub trait MpcScheme
47where
48 Self: Debug + Clone + Serialize + for<'de> Deserialize<'de>,
49{
50 /// The context type that contains the necessary information that should be remembered during
51 /// the entire execution. See [`MpcScheme::establish_context`],
52 /// [`MpcScheme::do_network_phase`], and [`MpcScheme::do_finalize_phase`].
53 type Context: Debug;
54
55 /// The type that represents the element that travels across the network.
56 type NetworkElement: Debug + Serialize + for<'de> Deserialize<'de>;
57
58 /// The type that represents the wire value.
59 type Wire: Debug;
60
61 /// The type that represents the input from the party.
62 type Input: Debug;
63
64 /// The type that represents the operation.
65 type Operation: Debug + Serialize + for<'de> Deserialize<'de> + Operation;
66
67 /// The type that represents the pending operation returned from [`MpcScheme::do_network_phase`]
68 /// and consumed by [`MpcScheme::do_finalize_phase`].
69 type Pending<'a>: Debug;
70
71 /// The error type returned by [`MpcScheme::establish_context`].
72 type EstablishContextError: Error + Send + Sync + 'static;
73 /// The error type returned by [`MpcScheme::do_network_phase`].
74 type NetworkPhaseError: Error + Send + Sync + 'static;
75 /// The error type returned by [`MpcScheme::do_finalize_phase`].
76 type FinalizePhaseError: Error + Send + Sync + 'static;
77
78 /// Returns true if the circuit "makes sense". [`MpcCircuit::new`] will fail if this returns
79 /// false.
80 ///
81 /// * `circuit`: an iterator of operations
82 fn is_circuit_sound<'a, I>(&self, circuit: I) -> bool
83 where
84 Self::Operation: 'a,
85 I: IntoIterator<Item = &'a Self::Operation>;
86
87 /// Returns `true` if `op` can be done without communication whenever the inputs are ready.
88 ///
89 /// ```rust,ignore
90 /// // The "input" operation of this scheme must not be local.
91 /// assert!(!op.is_input() || !scheme.is_operation_local(op));
92 /// ```
93 ///
94 /// * `op`: The operation.
95 fn is_operation_local(&self, op: &Self::Operation) -> bool;
96
97 /// Specify how the context should be established given a network and the circuit.
98 fn establish_context<N: Network>(
99 &self,
100 network: &mut N,
101 circuit: &MpcCircuit<Self>,
102 ) -> Result<Self::Context, Self::EstablishContextError>;
103
104 /// The caller should ensure that the inputs are valid.
105 fn prepare_user_input<I>(&self, context: &mut Self::Context, inputs: I)
106 where
107 I: IntoIterator<Item = (WireId, Self::Input)>;
108
109 /// Given a context, an operation, and the inputs, return the send/receive requests that should
110 /// be processed before calling the corresponding [`MpcScheme::do_network_phase`]. This
111 /// function must sanitize `inputs`.
112 fn do_network_phase<'a, I>(
113 &self,
114 context: &mut Self::Context,
115 op: &Self::Operation,
116 inputs: I,
117 ) -> Result<NetworkPhaseOutput<'a, Self>, Self::NetworkPhaseError>
118 where
119 Self::Wire: 'a,
120 I: IntoIterator<Item = &'a Self::Wire>;
121
122 /// Given a context, a pending operation, and the (possible) network data, finalize the
123 /// operation and return the corresponding outputs. The return value must be consistent with
124 /// [`Operation::outputs`] (i.e. the number of output wires must match).
125 fn do_finalize_phase<'a, I>(
126 &self,
127 context: &mut Self::Context,
128 pending: Self::Pending<'a>,
129 network_data: I,
130 ) -> Result<FinalizePhaseOutput<Self>, Self::FinalizePhaseError>
131 where
132 I: IntoIterator<Item = Self::NetworkElement>;
133}