1use itertools::Itertools;
2use sha3::{Digest, Sha3_512, digest::generic_array::GenericArray};
3use snafu::{self, ResultExt, Whatever};
4use std::{collections::HashMap, mem::MaybeUninit};
5
6use crate::{
7 mpc::{
8 NetworkPhaseOutput, Operation,
9 circuit::{Round, WireId},
10 config::MpcConfig,
11 scheme::MpcScheme,
12 },
13 networking::{Network, ReceiveRequest, RecvLen, SendLen},
14};
15
/// Lifecycle stage of an [`ExecutionContext`].
///
/// Stages advance in declaration order; the stage-changing methods of
/// `ExecutionContext` assert the stage they require before doing any work.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ExecutionState {
    /// Set by `ExecutionContext::new`; no scheme context exists yet.
    NewBorn,
    /// `handshake` succeeded: all parties agreed on the circuit and the scheme
    /// context has been established.
    Handshaked,
    /// `do_offline` has evaluated every offline round.
    FinishedOffline,
    /// `prepare_input` has accepted this party's inputs.
    ReadyOnline,
    /// `do_online` has evaluated every online round.
    Finished,
}
24
25type WireMap<S> = HashMap<WireId, <S as MpcScheme>::Wire>;
26
27#[allow(dead_code)]
28pub struct ExecutionContext<S, N>
29where
30 S: MpcScheme,
31 N: Network,
32{
33 config: MpcConfig<S>,
34 execution_state: ExecutionState,
35 network: N,
36 wire_contents: WireMap<S>,
37 scheme_context: MaybeUninit<S::Context>,
38}
39
40impl<S, N> Drop for ExecutionContext<S, N>
41where
42 S: MpcScheme,
43 N: Network,
44{
45 fn drop(&mut self) {
46 use ExecutionState::*;
47 if matches!(
48 self.execution_state,
49 Handshaked | FinishedOffline | ReadyOnline | Finished
50 ) {
51 unsafe {
52 self.scheme_context.assume_init_drop();
53 }
54 }
55 }
56}
57
/// Panics unless `$state_var` matches `$state_pat`.
///
/// `$fn_name` names the `ExecutionContext` method performing the check. The panic
/// message contains the method name, the stringified expected pattern and the
/// actual state, so the state's type must implement `Debug`. `$state_var` is
/// evaluated exactly once.
macro_rules! assert_state {
    ($fn_name:literal, $state_var:expr, $state_pat:pat) => {{
        let state = &$state_var;
        // The static text is passed as an argument rather than used as the format
        // string. Otherwise a pattern containing braces (e.g. `Foo { .. }`) would
        // be parsed as a format placeholder and fail to compile.
        assert!(
            matches!(state, $state_pat),
            "{}, but it is `{:?}`",
            concat!(
                "`ExecutionContext::",
                $fn_name,
                "` must be called when `ExecutionContext::execution_state` is `",
                stringify!($state_pat),
                "`"
            ),
            state
        )
    }};
}
72
73impl<S, N> ExecutionContext<S, N>
74where
75 S: MpcScheme,
76 N: Network,
77{
78 pub fn new(config: MpcConfig<S>, network: N) -> Result<Self, Whatever> {
79 if config.n_parties() != network.n_players() {
80 snafu::whatever!("#parties mismatch");
81 }
82 if config.my_id() != network.my_id() {
83 snafu::whatever!("Id mismatch");
84 }
85 Ok(Self {
86 config,
87 execution_state: ExecutionState::NewBorn,
88 network,
89 wire_contents: HashMap::new(),
90 scheme_context: MaybeUninit::uninit(),
91 })
92 }
93
94 pub fn execution_state(&self) -> &ExecutionState {
95 &self.execution_state
96 }
97
98 pub async fn handshake(&mut self) -> Result<(SendLen, RecvLen), Whatever> {
99 fn ctx<E: std::error::Error>(_: &mut E) -> String {
100 "Handshake error".to_string()
101 }
102
103 assert_state!("handshake", self.execution_state, ExecutionState::NewBorn);
104
105 let mut hasher = Sha3_512::new();
106 hasher.update(&postcard::to_stdvec(&self.config.circuit()).with_whatever_context(ctx)?);
107 let hash = hasher.finalize();
108
109 let send_len = self
110 .network
111 .broadcast_object(&hash)
112 .await
113 .with_whatever_context(ctx)?;
114
115 let my_id = self.network.my_id();
116 let request: Vec<_> = (0..self.network.n_players())
117 .filter(|from| *from != my_id)
118 .map(|from| ReceiveRequest::new(from, 1))
119 .collect();
120
121 let (hashes, recv_len) = self
122 .network
123 .recv_objects_many::<GenericArray<_, _>, _>(&request)
124 .await
125 .with_whatever_context(ctx)?;
126
127 if hashes
128 .into_iter()
129 .any(|x| x.into_iter().next().unwrap() != hash)
130 {
131 snafu::whatever!("Plan not agreed");
132 }
133
134 let context = self
135 .config
136 .scheme()
137 .establish_context(&mut self.network, self.config.circuit())
138 .with_whatever_context(ctx)?;
139 self.scheme_context = MaybeUninit::new(context);
140
141 self.execution_state = ExecutionState::Handshaked;
142 Ok((send_len, recv_len))
143 }
144
145 fn do_local_operations(
146 scheme: &S,
147 context: &mut S::Context,
148 wire_contents: &mut WireMap<S>,
149 round: &Round<S>,
150 ) -> Result<(), Whatever> {
151 fn ctx<E: std::error::Error>(_: &mut E) -> String {
152 "Local operation error".to_string()
153 }
154
155 let ops = round.local_operations();
156
157 for op in ops {
158 let inputs = op
159 .inputs()
160 .map(|id| wire_contents.get(&id).expect("Expected a wire content"))
161 .collect::<Vec<_>>();
162 let NetworkPhaseOutput {
163 pending,
164 send_request,
165 receive_request,
166 } = scheme
167 .do_network_phase(context, op, inputs)
168 .with_whatever_context(ctx)?;
169 assert!(send_request.is_empty());
170 assert!(receive_request.is_empty());
171
172 let output = scheme
173 .do_finalize_phase(context, pending, Vec::new())
174 .with_whatever_context(ctx)?;
175
176 for (wire_id, wire_content) in itertools::zip_eq(op.outputs(), output.0) {
177 let prev = wire_contents.insert(wire_id, wire_content);
178 assert!(prev.is_none());
179 }
180 }
181
182 Ok(())
183 }
184
185 async fn do_network_operations(
186 scheme: &S,
187 context: &mut S::Context,
188 wire_contents: &mut WireMap<S>,
189 network: &mut N,
190 round: &Round<S>,
191 ) -> Result<(SendLen, RecvLen), Whatever> {
192 fn ctx<E: std::error::Error>(_: &mut E) -> String {
193 "Network operation error".to_string()
194 }
195
196 let ops = round.network_operations();
197
198 let mut pending_ops = Vec::with_capacity(ops.len());
199 let mut all_receive_requests = Vec::new();
200 let mut all_send_requests = Vec::new();
201 let mut receive_request_counts = Vec::with_capacity(ops.len());
202
203 for op in ops {
204 let inputs = op
205 .inputs()
206 .map(|id| wire_contents.get(&id).expect("Expected a wire content"))
207 .collect::<Vec<_>>();
208
209 let NetworkPhaseOutput {
210 pending,
211 mut send_request,
212 mut receive_request,
213 } = scheme
214 .do_network_phase(context, op, inputs)
215 .with_whatever_context(ctx)?;
216
217 pending_ops.push(pending);
218 receive_request_counts.push(receive_request.len());
219 all_send_requests.append(&mut send_request);
220 all_receive_requests.append(&mut receive_request);
221
222 assert!(send_request.is_empty());
223 assert!(receive_request.is_empty());
224 }
225
226 let send_len = network
227 .send_objects_many(&all_send_requests)
228 .await
229 .with_whatever_context(ctx)?;
230
231 let (flat_network_data, recv_len) = network
232 .recv_objects_many::<S::NetworkElement, _>(&all_receive_requests)
233 .await
234 .with_whatever_context(ctx)?;
235
236 let mut data_iter = flat_network_data.into_iter();
237
238 let mut buffered_results = Vec::with_capacity(ops.len());
239 for (op, pending, req_count) in
240 itertools::izip!(ops, pending_ops.into_iter(), receive_request_counts)
241 {
242 let op_data: Vec<_> = data_iter.by_ref().take(req_count).flatten().collect();
243 let output = scheme
244 .do_finalize_phase(context, pending, op_data)
245 .with_whatever_context(ctx)?;
246 buffered_results.push((op, output));
247 }
248
249 for (op, output) in buffered_results {
250 for (wire_id, wire_content) in itertools::zip_eq(op.outputs(), output.0) {
251 let prev = wire_contents.insert(wire_id, wire_content);
252 assert!(prev.is_none());
253 }
254 }
255
256 Ok((send_len, recv_len))
257 }
258
259 pub async fn do_offline(&mut self) -> Result<(SendLen, RecvLen), Whatever> {
260 assert_state!(
261 "do_offline",
262 self.execution_state,
263 ExecutionState::Handshaked
264 );
265
266 let mut total_send_len = 0;
267 let mut total_recv_len = 0;
268 for round in self.config.circuit().offline_rounds().iter() {
269 Self::do_local_operations(
270 self.config.circuit().scheme(),
271 unsafe { self.scheme_context.assume_init_mut() },
272 &mut self.wire_contents,
273 round,
274 )?;
275 let (send_len, recv_len) = Self::do_network_operations(
276 self.config.circuit().scheme(),
277 unsafe { self.scheme_context.assume_init_mut() },
278 &mut self.wire_contents,
279 &mut self.network,
280 round,
281 )
282 .await?;
283 total_send_len += send_len;
284 total_recv_len += recv_len;
285 }
286
287 self.execution_state = ExecutionState::FinishedOffline;
288
289 Ok((total_send_len, total_recv_len))
290 }
291
292 pub fn prepare_input<I>(&mut self, user_inputs: I) -> Result<(), Whatever>
293 where
294 I: IntoIterator<Item = (WireId, S::Input)>,
295 {
296 use std::collections::hash_map::Entry;
297 assert_state!(
298 "prepare_input",
299 self.execution_state,
300 ExecutionState::FinishedOffline
301 );
302
303 let mut input_wires: HashMap<_, _> = self
304 .config
305 .circuit()
306 .get_input_operation_wire_ids_of_party(self.network.my_id())
307 .map(|x| (x, false))
308 .collect();
309
310 let mut inputs = Vec::with_capacity(input_wires.len());
311 for (wire_id, input) in user_inputs.into_iter() {
312 match input_wires.entry(wire_id) {
313 Entry::Occupied(mut occupied_entry) if !occupied_entry.get() => {
314 *occupied_entry.get_mut() = true;
315 inputs.push((wire_id, input));
316 }
317 Entry::Occupied(_) => snafu::whatever!("Duplicate input for wire {}", wire_id),
318 Entry::Vacant(_) => {
319 snafu::whatever!("Wire {} is not an output of an input operation", wire_id)
320 }
321 };
322 }
323
324 if inputs.len() < input_wires.len() {
325 snafu::whatever!("Not enough inputs are provided");
326 }
327
328 self.config
329 .scheme()
330 .prepare_user_input(unsafe { self.scheme_context.assume_init_mut() }, inputs);
331
332 self.execution_state = ExecutionState::ReadyOnline;
333
334 Ok(())
335 }
336
337 pub async fn do_online(&mut self) -> Result<(SendLen, RecvLen), Whatever> {
338 assert_state!(
339 "do_online",
340 self.execution_state,
341 ExecutionState::ReadyOnline
342 );
343
344 let mut total_send_len = 0;
345 let mut total_recv_len = 0;
346 for round in self.config.circuit().online_rounds().iter() {
347 Self::do_local_operations(
348 self.config.circuit().scheme(),
349 unsafe { self.scheme_context.assume_init_mut() },
350 &mut self.wire_contents,
351 round,
352 )?;
353 let (send_len, recv_len) = Self::do_network_operations(
354 self.config.circuit().scheme(),
355 unsafe { self.scheme_context.assume_init_mut() },
356 &mut self.wire_contents,
357 &mut self.network,
358 round,
359 )
360 .await?;
361
362 total_send_len += send_len;
363 total_recv_len += recv_len;
364 }
365
366 self.execution_state = ExecutionState::Finished;
367
368 Ok((total_send_len, total_recv_len))
369 }
370
371 pub fn dump_wires_with_filter<P>(&self, mut predicate: P)
372 where
373 P: FnMut(&S::Wire) -> bool,
374 {
375 use std::{cmp, fmt};
376 #[allow(unused)]
377 struct Dummy1<'a, S: MpcScheme> {
378 id: WireId,
379 wire: &'a S::Wire,
380 }
381
382 impl<'a, S: MpcScheme> PartialEq for Dummy1<'a, S> {
383 fn eq(&self, other: &Self) -> bool {
384 self.id == other.id
385 }
386 }
387
388 impl<'a, S: MpcScheme> Eq for Dummy1<'a, S> {}
389
390 impl<'a, S: MpcScheme> cmp::PartialOrd for Dummy1<'a, S> {
391 fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
392 Some(self.cmp(other))
393 }
394 }
395
396 impl<'a, S: MpcScheme> cmp::Ord for Dummy1<'a, S> {
397 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
398 self.id.0.cmp(&other.id.0)
399 }
400 }
401
402 struct Dummy2<'a, S: MpcScheme>(Vec<Dummy1<'a, S>>);
403
404 impl<'a, S: MpcScheme> fmt::Debug for Dummy2<'a, S> {
405 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
406 f.debug_map()
407 .entries(self.0.iter().map(|d| (d.id.0, d.wire)))
408 .finish()
409 }
410 }
411
412 let to_print = self
413 .wire_contents
414 .iter()
415 .filter_map(move |(&id, wire)| predicate(wire).then_some(Dummy1::<S> { id, wire }))
416 .sorted_unstable()
417 .collect_vec();
418
419 println!("{:#?}", Dummy2(to_print));
420 }
421}