Skip to main content

mpc_core/mpc/
execution.rs

1use itertools::Itertools;
2use sha3::{Digest, Sha3_512, digest::generic_array::GenericArray};
3use snafu::{self, ResultExt, Whatever};
4use std::{collections::HashMap, mem::MaybeUninit};
5
6use crate::{
7    mpc::{
8        NetworkPhaseOutput, Operation,
9        circuit::{Round, WireId},
10        config::MpcConfig,
11        scheme::MpcScheme,
12    },
13    networking::{Network, ReceiveRequest, RecvLen, SendLen},
14};
15
/// Lifecycle of an [`ExecutionContext`].
///
/// The states are entered strictly in declaration order: each driver method of
/// `ExecutionContext` asserts the state it expects and advances to the next state
/// only on success.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ExecutionState {
    /// Just created; no action has been performed yet.
    NewBorn,
    /// Handshake done: the parties checked that they share the same MPC plan.
    Handshaked,
    /// All offline rounds have been executed.
    FinishedOffline,
    /// The user inputs have been accepted.
    ReadyOnline,
    /// Execution has finished.
    Finished,
}
24
/// Contents of the wires evaluated so far, keyed by [`WireId`].
// The fully-qualified `<S as MpcScheme>` path is needed because a type alias does not
// carry an `S: MpcScheme` bound, so a bare `S::Wire` would be ambiguous.
type WireMap<S> = HashMap<WireId, <S as MpcScheme>::Wire>;
26
/// Drives the local party through an MPC execution: handshake, offline rounds,
/// input preparation, and online rounds, in that order.
// NOTE(review): every field appears to be used in this file; the `dead_code`
// allowance may be stale — confirm before removing it.
#[allow(dead_code)]
pub struct ExecutionContext<S, N>
where
    S: MpcScheme,
    N: Network,
{
    /// Circuit, scheme, and party parameters of this execution.
    config: MpcConfig<S>,
    /// Current lifecycle state; gates which methods may be called.
    execution_state: ExecutionState,
    /// Transport used to exchange data with the other parties.
    network: N,
    /// Contents of every wire computed so far.
    wire_contents: WireMap<S>,
    /// Scheme-specific context created by the handshake.
    ///
    /// Invariant: initialized if and only if `execution_state` is not
    /// `ExecutionState::NewBorn`. The `Drop` impl relies on this to decide whether
    /// the context must be dropped.
    scheme_context: MaybeUninit<S::Context>,
}
39
40impl<S, N> Drop for ExecutionContext<S, N>
41where
42    S: MpcScheme,
43    N: Network,
44{
45    fn drop(&mut self) {
46        use ExecutionState::*;
47        if matches!(
48            self.execution_state,
49            Handshaked | FinishedOffline | ReadyOnline | Finished
50        ) {
51            unsafe {
52                self.scheme_context.assume_init_drop();
53            }
54        }
55    }
56}
57
/// Panics unless `$state_var` matches `$state_pat`.
///
/// `$fn_name` names the calling `ExecutionContext` method and is used only to
/// build the panic message.
macro_rules! assert_state {
    ($fn_name:literal, $state_var:expr, $state_pat:pat) => {
        assert!(
            matches!($state_var, $state_pat),
            // The message goes through `"{}"` instead of being the format string itself,
            // so a pattern containing braces (e.g. `Foo { .. }`) cannot corrupt it.
            "{}",
            concat!(
                "`ExecutionContext::",
                $fn_name,
                "` must be called when `ExecutionContext::execution_state` is `",
                stringify!($state_pat),
                "`"
            )
        )
    };
}
72
73impl<S, N> ExecutionContext<S, N>
74where
75    S: MpcScheme,
76    N: Network,
77{
78    pub fn new(config: MpcConfig<S>, network: N) -> Result<Self, Whatever> {
79        if config.n_parties() != network.n_players() {
80            snafu::whatever!("#parties mismatch");
81        }
82        if config.my_id() != network.my_id() {
83            snafu::whatever!("Id mismatch");
84        }
85        Ok(Self {
86            config,
87            execution_state: ExecutionState::NewBorn,
88            network,
89            wire_contents: HashMap::new(),
90            scheme_context: MaybeUninit::uninit(),
91        })
92    }
93
94    pub fn execution_state(&self) -> &ExecutionState {
95        &self.execution_state
96    }
97
98    pub async fn handshake(&mut self) -> Result<(SendLen, RecvLen), Whatever> {
99        fn ctx<E: std::error::Error>(_: &mut E) -> String {
100            "Handshake error".to_string()
101        }
102
103        assert_state!("handshake", self.execution_state, ExecutionState::NewBorn);
104
105        let mut hasher = Sha3_512::new();
106        hasher.update(&postcard::to_stdvec(&self.config.circuit()).with_whatever_context(ctx)?);
107        let hash = hasher.finalize();
108
109        let send_len = self
110            .network
111            .broadcast_object(&hash)
112            .await
113            .with_whatever_context(ctx)?;
114
115        let my_id = self.network.my_id();
116        let request: Vec<_> = (0..self.network.n_players())
117            .filter(|from| *from != my_id)
118            .map(|from| ReceiveRequest::new(from, 1))
119            .collect();
120
121        let (hashes, recv_len) = self
122            .network
123            .recv_objects_many::<GenericArray<_, _>, _>(&request)
124            .await
125            .with_whatever_context(ctx)?;
126
127        if hashes
128            .into_iter()
129            .any(|x| x.into_iter().next().unwrap() != hash)
130        {
131            snafu::whatever!("Plan not agreed");
132        }
133
134        let context = self
135            .config
136            .scheme()
137            .establish_context(&mut self.network, self.config.circuit())
138            .with_whatever_context(ctx)?;
139        self.scheme_context = MaybeUninit::new(context);
140
141        self.execution_state = ExecutionState::Handshaked;
142        Ok((send_len, recv_len))
143    }
144
145    fn do_local_operations(
146        scheme: &S,
147        context: &mut S::Context,
148        wire_contents: &mut WireMap<S>,
149        round: &Round<S>,
150    ) -> Result<(), Whatever> {
151        fn ctx<E: std::error::Error>(_: &mut E) -> String {
152            "Local operation error".to_string()
153        }
154
155        let ops = round.local_operations();
156
157        for op in ops {
158            let inputs = op
159                .inputs()
160                .map(|id| wire_contents.get(&id).expect("Expected a wire content"))
161                .collect::<Vec<_>>();
162            let NetworkPhaseOutput {
163                pending,
164                send_request,
165                receive_request,
166            } = scheme
167                .do_network_phase(context, op, inputs)
168                .with_whatever_context(ctx)?;
169            assert!(send_request.is_empty());
170            assert!(receive_request.is_empty());
171
172            let output = scheme
173                .do_finalize_phase(context, pending, Vec::new())
174                .with_whatever_context(ctx)?;
175
176            for (wire_id, wire_content) in itertools::zip_eq(op.outputs(), output.0) {
177                let prev = wire_contents.insert(wire_id, wire_content);
178                assert!(prev.is_none());
179            }
180        }
181
182        Ok(())
183    }
184
185    async fn do_network_operations(
186        scheme: &S,
187        context: &mut S::Context,
188        wire_contents: &mut WireMap<S>,
189        network: &mut N,
190        round: &Round<S>,
191    ) -> Result<(SendLen, RecvLen), Whatever> {
192        fn ctx<E: std::error::Error>(_: &mut E) -> String {
193            "Network operation error".to_string()
194        }
195
196        let ops = round.network_operations();
197
198        let mut pending_ops = Vec::with_capacity(ops.len());
199        let mut all_receive_requests = Vec::new();
200        let mut all_send_requests = Vec::new();
201        let mut receive_request_counts = Vec::with_capacity(ops.len());
202
203        for op in ops {
204            let inputs = op
205                .inputs()
206                .map(|id| wire_contents.get(&id).expect("Expected a wire content"))
207                .collect::<Vec<_>>();
208
209            let NetworkPhaseOutput {
210                pending,
211                mut send_request,
212                mut receive_request,
213            } = scheme
214                .do_network_phase(context, op, inputs)
215                .with_whatever_context(ctx)?;
216
217            pending_ops.push(pending);
218            receive_request_counts.push(receive_request.len());
219            all_send_requests.append(&mut send_request);
220            all_receive_requests.append(&mut receive_request);
221
222            assert!(send_request.is_empty());
223            assert!(receive_request.is_empty());
224        }
225
226        let send_len = network
227            .send_objects_many(&all_send_requests)
228            .await
229            .with_whatever_context(ctx)?;
230
231        let (flat_network_data, recv_len) = network
232            .recv_objects_many::<S::NetworkElement, _>(&all_receive_requests)
233            .await
234            .with_whatever_context(ctx)?;
235
236        let mut data_iter = flat_network_data.into_iter();
237
238        let mut buffered_results = Vec::with_capacity(ops.len());
239        for (op, pending, req_count) in
240            itertools::izip!(ops, pending_ops.into_iter(), receive_request_counts)
241        {
242            let op_data: Vec<_> = data_iter.by_ref().take(req_count).flatten().collect();
243            let output = scheme
244                .do_finalize_phase(context, pending, op_data)
245                .with_whatever_context(ctx)?;
246            buffered_results.push((op, output));
247        }
248
249        for (op, output) in buffered_results {
250            for (wire_id, wire_content) in itertools::zip_eq(op.outputs(), output.0) {
251                let prev = wire_contents.insert(wire_id, wire_content);
252                assert!(prev.is_none());
253            }
254        }
255
256        Ok((send_len, recv_len))
257    }
258
259    pub async fn do_offline(&mut self) -> Result<(SendLen, RecvLen), Whatever> {
260        assert_state!(
261            "do_offline",
262            self.execution_state,
263            ExecutionState::Handshaked
264        );
265
266        let mut total_send_len = 0;
267        let mut total_recv_len = 0;
268        for round in self.config.circuit().offline_rounds().iter() {
269            Self::do_local_operations(
270                self.config.circuit().scheme(),
271                unsafe { self.scheme_context.assume_init_mut() },
272                &mut self.wire_contents,
273                round,
274            )?;
275            let (send_len, recv_len) = Self::do_network_operations(
276                self.config.circuit().scheme(),
277                unsafe { self.scheme_context.assume_init_mut() },
278                &mut self.wire_contents,
279                &mut self.network,
280                round,
281            )
282            .await?;
283            total_send_len += send_len;
284            total_recv_len += recv_len;
285        }
286
287        self.execution_state = ExecutionState::FinishedOffline;
288
289        Ok((total_send_len, total_recv_len))
290    }
291
292    pub fn prepare_input<I>(&mut self, user_inputs: I) -> Result<(), Whatever>
293    where
294        I: IntoIterator<Item = (WireId, S::Input)>,
295    {
296        use std::collections::hash_map::Entry;
297        assert_state!(
298            "prepare_input",
299            self.execution_state,
300            ExecutionState::FinishedOffline
301        );
302
303        let mut input_wires: HashMap<_, _> = self
304            .config
305            .circuit()
306            .get_input_operation_wire_ids_of_party(self.network.my_id())
307            .map(|x| (x, false))
308            .collect();
309
310        let mut inputs = Vec::with_capacity(input_wires.len());
311        for (wire_id, input) in user_inputs.into_iter() {
312            match input_wires.entry(wire_id) {
313                Entry::Occupied(mut occupied_entry) if !occupied_entry.get() => {
314                    *occupied_entry.get_mut() = true;
315                    inputs.push((wire_id, input));
316                }
317                Entry::Occupied(_) => snafu::whatever!("Duplicate input for wire {}", wire_id),
318                Entry::Vacant(_) => {
319                    snafu::whatever!("Wire {} is not an output of an input operation", wire_id)
320                }
321            };
322        }
323
324        if inputs.len() < input_wires.len() {
325            snafu::whatever!("Not enough inputs are provided");
326        }
327
328        self.config
329            .scheme()
330            .prepare_user_input(unsafe { self.scheme_context.assume_init_mut() }, inputs);
331
332        self.execution_state = ExecutionState::ReadyOnline;
333
334        Ok(())
335    }
336
337    pub async fn do_online(&mut self) -> Result<(SendLen, RecvLen), Whatever> {
338        assert_state!(
339            "do_online",
340            self.execution_state,
341            ExecutionState::ReadyOnline
342        );
343
344        let mut total_send_len = 0;
345        let mut total_recv_len = 0;
346        for round in self.config.circuit().online_rounds().iter() {
347            Self::do_local_operations(
348                self.config.circuit().scheme(),
349                unsafe { self.scheme_context.assume_init_mut() },
350                &mut self.wire_contents,
351                round,
352            )?;
353            let (send_len, recv_len) = Self::do_network_operations(
354                self.config.circuit().scheme(),
355                unsafe { self.scheme_context.assume_init_mut() },
356                &mut self.wire_contents,
357                &mut self.network,
358                round,
359            )
360            .await?;
361
362            total_send_len += send_len;
363            total_recv_len += recv_len;
364        }
365
366        self.execution_state = ExecutionState::Finished;
367
368        Ok((total_send_len, total_recv_len))
369    }
370
371    pub fn dump_wires_with_filter<P>(&self, mut predicate: P)
372    where
373        P: FnMut(&S::Wire) -> bool,
374    {
375        use std::{cmp, fmt};
376        #[allow(unused)]
377        struct Dummy1<'a, S: MpcScheme> {
378            id: WireId,
379            wire: &'a S::Wire,
380        }
381
382        impl<'a, S: MpcScheme> PartialEq for Dummy1<'a, S> {
383            fn eq(&self, other: &Self) -> bool {
384                self.id == other.id
385            }
386        }
387
388        impl<'a, S: MpcScheme> Eq for Dummy1<'a, S> {}
389
390        impl<'a, S: MpcScheme> cmp::PartialOrd for Dummy1<'a, S> {
391            fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
392                Some(self.cmp(other))
393            }
394        }
395
396        impl<'a, S: MpcScheme> cmp::Ord for Dummy1<'a, S> {
397            fn cmp(&self, other: &Self) -> std::cmp::Ordering {
398                self.id.0.cmp(&other.id.0)
399            }
400        }
401
402        struct Dummy2<'a, S: MpcScheme>(Vec<Dummy1<'a, S>>);
403
404        impl<'a, S: MpcScheme> fmt::Debug for Dummy2<'a, S> {
405            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
406                f.debug_map()
407                    .entries(self.0.iter().map(|d| (d.id.0, d.wire)))
408                    .finish()
409            }
410        }
411
412        let to_print = self
413            .wire_contents
414            .iter()
415            .filter_map(move |(&id, wire)| predicate(wire).then_some(Dummy1::<S> { id, wire }))
416            .sorted_unstable()
417            .collect_vec();
418
419        println!("{:#?}", Dummy2(to_print));
420    }
421}