Skip to main content

mpc_core/mpc/
circuit.rs

1use std::{
2    collections::{HashMap, VecDeque},
3    fmt,
4};
5
6use crate::mpc::{Operation, scheme::MpcScheme};
7use serde::{Deserialize, Serialize};
8use snafu::{Snafu, ensure};
9
10#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
11#[repr(transparent)]
12pub struct WireId(pub usize);
13
14impl fmt::Display for WireId {
15    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
16        write!(f, "{}", self.0)
17    }
18}
19
20#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct Round<S: MpcScheme> {
22    local_operations: Vec<S::Operation>,
23    network_operations: Vec<S::Operation>,
24}
25
26impl<S: MpcScheme> Round<S> {
27    pub fn local_operations(&self) -> &[S::Operation] {
28        &self.local_operations
29    }
30
31    pub fn network_operations(&self) -> &[S::Operation] {
32        &self.network_operations
33    }
34}
35
36#[derive(Debug, Serialize, Deserialize)]
37#[serde(bound = "S: MpcScheme")]
38pub struct MpcCircuit<S: MpcScheme> {
39    scheme: S,
40    offline_rounds: Vec<Round<S>>,
41    online_rounds: Vec<Round<S>>,
42}
43
44#[derive(Debug, Snafu)]
45pub enum MpcCircuitError {
46    #[snafu(display("The provided circuit is not sound"))]
47    CircuitUnsound,
48    #[snafu(display("Wire {wire} produced multiple times"))]
49    WireProducedManyTimes { wire: WireId },
50    #[snafu(display("Wire {wire} is consumed but it is never produced"))]
51    WireConsumedButNotProduced { wire: WireId },
52    #[snafu(display("Circuit is cyclic"))]
53    CircuitCyclic,
54    #[snafu(display("Circuit is empty"))]
55    EmptyCircuit,
56}
57
58impl<S: MpcScheme> MpcCircuit<S> {
59    pub fn new<I>(operations: I, scheme: S) -> Result<Self, MpcCircuitError>
60    where
61        I: IntoIterator<Item = S::Operation>,
62    {
63        let raw_ops = operations.into_iter().collect::<Vec<_>>();
64        let num_ops = raw_ops.len();
65
66        ensure!(num_ops > 0, EmptyCircuitSnafu);
67        ensure!(scheme.is_circuit_sound(&raw_ops), CircuitUnsoundSnafu);
68
69        let (mut ops, adj, mut in_degree) = {
70            let mut adj: Vec<Vec<usize>> = vec![vec![]; num_ops];
71            let mut in_degree: Vec<usize> = vec![0; num_ops];
72            let mut wire_producer: HashMap<WireId, usize> = HashMap::new();
73
74            for (idx, op) in raw_ops.iter().enumerate() {
75                for out_wire in op.outputs() {
76                    let prev = wire_producer.insert(out_wire, idx);
77                    if prev.is_some() {
78                        return Err(MpcCircuitError::WireProducedManyTimes { wire: out_wire });
79                    }
80                }
81            }
82
83            for (consumer_idx, op) in raw_ops.iter().enumerate() {
84                for input_wire in op.inputs() {
85                    if let Some(&producer_idx) = wire_producer.get(&input_wire) {
86                        adj[producer_idx].push(consumer_idx);
87                        in_degree[consumer_idx] += 1;
88                    } else {
89                        return Err(MpcCircuitError::WireConsumedButNotProduced {
90                            wire: input_wire,
91                        });
92                    }
93                }
94            }
95            let ops: Vec<_> = raw_ops.into_iter().map(Some).collect();
96            (ops, adj, in_degree)
97        };
98
99        let mut offline_rounds: Vec<Round<S>> = Vec::new();
100        let mut online_rounds: Vec<Round<S>> = Vec::new();
101
102        let mut offline_wave: VecDeque<_> = in_degree
103            .iter()
104            .enumerate()
105            .filter_map(|(idx, deg)| (*deg == 0).then_some(idx))
106            .collect();
107        let mut online_wave = VecDeque::new();
108
109        while !offline_wave.is_empty() {
110            let mut next_offline_wave = VecDeque::new();
111            let mut local_operations = Vec::new();
112            let mut network_operations = Vec::new();
113
114            while let Some(op_idx) = offline_wave.pop_front() {
115                if ops[op_idx].as_ref().unwrap().is_input() {
116                    online_wave.push_back(op_idx);
117                    continue;
118                }
119
120                let op = ops[op_idx].take().unwrap();
121                let op_local = scheme.is_operation_local(&op);
122
123                if op_local {
124                    local_operations.push(op);
125                } else {
126                    network_operations.push(op);
127                }
128
129                for &consumer_idx in &adj[op_idx] {
130                    in_degree[consumer_idx] -= 1;
131                    if in_degree[consumer_idx] == 0 {
132                        if op_local {
133                            offline_wave.push_back(consumer_idx);
134                        } else {
135                            next_offline_wave.push_back(consumer_idx);
136                        }
137                    }
138                }
139            }
140
141            offline_rounds.push(Round {
142                local_operations,
143                network_operations,
144            });
145
146            std::mem::swap(&mut offline_wave, &mut next_offline_wave);
147        }
148
149        while !online_wave.is_empty() {
150            let mut next_online_wave = VecDeque::new();
151            let mut local_operations = Vec::new();
152            let mut network_operations = Vec::new();
153
154            while let Some(op_idx) = online_wave.pop_front() {
155                let op = ops[op_idx].take().unwrap();
156                let op_local = scheme.is_operation_local(&op);
157
158                if op_local {
159                    local_operations.push(op);
160                } else {
161                    network_operations.push(op);
162                }
163
164                for &consumer_idx in &adj[op_idx] {
165                    in_degree[consumer_idx] -= 1;
166                    if in_degree[consumer_idx] == 0 {
167                        if op_local {
168                            online_wave.push_back(consumer_idx);
169                        } else {
170                            next_online_wave.push_back(consumer_idx);
171                        }
172                    }
173                }
174            }
175
176            online_rounds.push(Round {
177                local_operations,
178                network_operations,
179            });
180
181            std::mem::swap(&mut online_wave, &mut next_online_wave);
182        }
183
184        if ops.iter().any(Option::is_some) {
185            return Err(MpcCircuitError::CircuitCyclic);
186        }
187
188        debug_assert!(offline_rounds.iter().all(|round| {
189            round.network_operations.iter().all(|op| !op.is_input())
190                && round.local_operations.iter().all(|op| !op.is_input())
191        }));
192
193        debug_assert!(
194            online_rounds
195                .iter()
196                .all(|round| { round.local_operations.iter().all(|op| !op.is_input()) })
197        );
198
199        Ok(MpcCircuit {
200            scheme,
201            offline_rounds,
202            online_rounds,
203        })
204    }
205
206    pub fn get_input_operation_wire_ids_of_party(
207        &self,
208        party_id: usize,
209    ) -> impl Iterator<Item = WireId> {
210        self.online_rounds
211            .iter()
212            .flat_map(|round| round.network_operations.iter())
213            .filter(move |op| op.get_input_party_id() == Some(party_id))
214            .map(|op| op.outputs().next().expect("Expected one output"))
215    }
216
217    pub fn scheme(&self) -> &S {
218        &self.scheme
219    }
220
221    pub fn online_rounds(&self) -> &[Round<S>] {
222        &self.online_rounds
223    }
224
225    pub fn offline_rounds(&self) -> &[Round<S>] {
226        &self.offline_rounds
227    }
228
229    pub fn num_total_rounds(&self) -> usize {
230        self.offline_rounds.len() + self.online_rounds.len()
231    }
232
233    pub fn num_offline_rounds(&self) -> usize {
234        self.offline_rounds.len()
235    }
236
237    pub fn num_online_rounds(&self) -> usize {
238        self.online_rounds.len()
239    }
240}