1use std::{
2 collections::{HashMap, VecDeque},
3 fmt,
4};
5
6use crate::mpc::{Operation, scheme::MpcScheme};
7use serde::{Deserialize, Serialize};
8use snafu::{Snafu, ensure};
9
/// Identifier of a single wire connecting operations of an [`MpcCircuit`].
///
/// [`MpcCircuit::new`] rejects circuits in which a wire is produced more than once, or
/// consumed without being produced; a wire may be consumed by any number of operations.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[repr(transparent)]
pub struct WireId(pub usize);
13
14impl fmt::Display for WireId {
15 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
16 write!(f, "{}", self.0)
17 }
18}
19
/// One round of a scheduled [`MpcCircuit`].
///
/// The operations scheduled in this round are partitioned by
/// `MpcScheme::is_operation_local`: `local_operations` holds those classified as local and
/// `network_operations` holds the rest. Each list keeps the order in which the operations
/// were scheduled.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Round<S: MpcScheme> {
    /// Operations for which the scheme's `is_operation_local` returned `true`.
    local_operations: Vec<S::Operation>,
    /// Operations for which the scheme's `is_operation_local` returned `false`.
    network_operations: Vec<S::Operation>,
}
25
26impl<S: MpcScheme> Round<S> {
27 pub fn local_operations(&self) -> &[S::Operation] {
28 &self.local_operations
29 }
30
31 pub fn network_operations(&self) -> &[S::Operation] {
32 &self.network_operations
33 }
34}
35
/// A validated circuit whose operations have been scheduled into rounds for a scheme `S`.
///
/// Constructed by [`MpcCircuit::new`]: operations that do not depend on any input operation
/// are placed in the offline rounds; input operations and everything that depends on them
/// are placed in the online rounds.
#[derive(Debug, Serialize, Deserialize)]
// Replaces serde's inferred bounds; presumably `MpcScheme` itself implies the
// (de)serialization bounds needed for `S` and `S::Operation` — TODO confirm.
#[serde(bound = "S: MpcScheme")]
pub struct MpcCircuit<S: MpcScheme> {
    /// The scheme the operations were validated and classified (local/network) against.
    scheme: S,
    /// Rounds containing the operations that do not depend on any input operation.
    offline_rounds: Vec<Round<S>>,
    /// Rounds containing the input operations and every operation that depends on them.
    online_rounds: Vec<Round<S>>,
}
43
/// Reasons why [`MpcCircuit::new`] rejects a set of operations.
#[derive(Debug, Snafu)]
pub enum MpcCircuitError {
    /// The scheme's `is_circuit_sound` check rejected the operations.
    #[snafu(display("The provided circuit is not sound"))]
    CircuitUnsound,
    /// `wire` appears more than once among the outputs of all operations.
    #[snafu(display("Wire {wire} produced multiple times"))]
    WireProducedManyTimes { wire: WireId },
    /// Some operation consumes `wire`, but no operation lists it as an output.
    #[snafu(display("Wire {wire} is consumed but it is never produced"))]
    WireConsumedButNotProduced { wire: WireId },
    /// The producer/consumer dependencies contain a cycle, so no schedule exists.
    #[snafu(display("Circuit is cyclic"))]
    CircuitCyclic,
    /// No operations were supplied.
    #[snafu(display("Circuit is empty"))]
    EmptyCircuit,
}
57
58impl<S: MpcScheme> MpcCircuit<S> {
59 pub fn new<I>(operations: I, scheme: S) -> Result<Self, MpcCircuitError>
60 where
61 I: IntoIterator<Item = S::Operation>,
62 {
63 let raw_ops = operations.into_iter().collect::<Vec<_>>();
64 let num_ops = raw_ops.len();
65
66 ensure!(num_ops > 0, EmptyCircuitSnafu);
67 ensure!(scheme.is_circuit_sound(&raw_ops), CircuitUnsoundSnafu);
68
69 let (mut ops, adj, mut in_degree) = {
70 let mut adj: Vec<Vec<usize>> = vec![vec![]; num_ops];
71 let mut in_degree: Vec<usize> = vec![0; num_ops];
72 let mut wire_producer: HashMap<WireId, usize> = HashMap::new();
73
74 for (idx, op) in raw_ops.iter().enumerate() {
75 for out_wire in op.outputs() {
76 let prev = wire_producer.insert(out_wire, idx);
77 if prev.is_some() {
78 return Err(MpcCircuitError::WireProducedManyTimes { wire: out_wire });
79 }
80 }
81 }
82
83 for (consumer_idx, op) in raw_ops.iter().enumerate() {
84 for input_wire in op.inputs() {
85 if let Some(&producer_idx) = wire_producer.get(&input_wire) {
86 adj[producer_idx].push(consumer_idx);
87 in_degree[consumer_idx] += 1;
88 } else {
89 return Err(MpcCircuitError::WireConsumedButNotProduced {
90 wire: input_wire,
91 });
92 }
93 }
94 }
95 let ops: Vec<_> = raw_ops.into_iter().map(Some).collect();
96 (ops, adj, in_degree)
97 };
98
99 let mut offline_rounds: Vec<Round<S>> = Vec::new();
100 let mut online_rounds: Vec<Round<S>> = Vec::new();
101
102 let mut offline_wave: VecDeque<_> = in_degree
103 .iter()
104 .enumerate()
105 .filter_map(|(idx, deg)| (*deg == 0).then_some(idx))
106 .collect();
107 let mut online_wave = VecDeque::new();
108
109 while !offline_wave.is_empty() {
110 let mut next_offline_wave = VecDeque::new();
111 let mut local_operations = Vec::new();
112 let mut network_operations = Vec::new();
113
114 while let Some(op_idx) = offline_wave.pop_front() {
115 if ops[op_idx].as_ref().unwrap().is_input() {
116 online_wave.push_back(op_idx);
117 continue;
118 }
119
120 let op = ops[op_idx].take().unwrap();
121 let op_local = scheme.is_operation_local(&op);
122
123 if op_local {
124 local_operations.push(op);
125 } else {
126 network_operations.push(op);
127 }
128
129 for &consumer_idx in &adj[op_idx] {
130 in_degree[consumer_idx] -= 1;
131 if in_degree[consumer_idx] == 0 {
132 if op_local {
133 offline_wave.push_back(consumer_idx);
134 } else {
135 next_offline_wave.push_back(consumer_idx);
136 }
137 }
138 }
139 }
140
141 offline_rounds.push(Round {
142 local_operations,
143 network_operations,
144 });
145
146 std::mem::swap(&mut offline_wave, &mut next_offline_wave);
147 }
148
149 while !online_wave.is_empty() {
150 let mut next_online_wave = VecDeque::new();
151 let mut local_operations = Vec::new();
152 let mut network_operations = Vec::new();
153
154 while let Some(op_idx) = online_wave.pop_front() {
155 let op = ops[op_idx].take().unwrap();
156 let op_local = scheme.is_operation_local(&op);
157
158 if op_local {
159 local_operations.push(op);
160 } else {
161 network_operations.push(op);
162 }
163
164 for &consumer_idx in &adj[op_idx] {
165 in_degree[consumer_idx] -= 1;
166 if in_degree[consumer_idx] == 0 {
167 if op_local {
168 online_wave.push_back(consumer_idx);
169 } else {
170 next_online_wave.push_back(consumer_idx);
171 }
172 }
173 }
174 }
175
176 online_rounds.push(Round {
177 local_operations,
178 network_operations,
179 });
180
181 std::mem::swap(&mut online_wave, &mut next_online_wave);
182 }
183
184 if ops.iter().any(Option::is_some) {
185 return Err(MpcCircuitError::CircuitCyclic);
186 }
187
188 debug_assert!(offline_rounds.iter().all(|round| {
189 round.network_operations.iter().all(|op| !op.is_input())
190 && round.local_operations.iter().all(|op| !op.is_input())
191 }));
192
193 debug_assert!(
194 online_rounds
195 .iter()
196 .all(|round| { round.local_operations.iter().all(|op| !op.is_input()) })
197 );
198
199 Ok(MpcCircuit {
200 scheme,
201 offline_rounds,
202 online_rounds,
203 })
204 }
205
206 pub fn get_input_operation_wire_ids_of_party(
207 &self,
208 party_id: usize,
209 ) -> impl Iterator<Item = WireId> {
210 self.online_rounds
211 .iter()
212 .flat_map(|round| round.network_operations.iter())
213 .filter(move |op| op.get_input_party_id() == Some(party_id))
214 .map(|op| op.outputs().next().expect("Expected one output"))
215 }
216
217 pub fn scheme(&self) -> &S {
218 &self.scheme
219 }
220
221 pub fn online_rounds(&self) -> &[Round<S>] {
222 &self.online_rounds
223 }
224
225 pub fn offline_rounds(&self) -> &[Round<S>] {
226 &self.offline_rounds
227 }
228
229 pub fn num_total_rounds(&self) -> usize {
230 self.offline_rounds.len() + self.online_rounds.len()
231 }
232
233 pub fn num_offline_rounds(&self) -> usize {
234 self.offline_rounds.len()
235 }
236
237 pub fn num_online_rounds(&self) -> usize {
238 self.online_rounds.len()
239 }
240}