1use std::{collections::HashMap, ops::Add};
2
3use primitives::{
4 algebra::{
5 elliptic_curve::{Curve, Point, Scalar},
6 BoxedUint,
7 },
8 types::PeerNumber,
9};
10use serde::{Deserialize, Serialize};
11
12use super::BaseFieldPlaintext;
13use crate::{
14 circuit::{BitPlaintext, FieldBinaryOp, FieldUnaryOp, PointPlaintext, ScalarPlaintext},
15 types::Label,
16};
17
/// One gate of a circuit, generic over the elliptic curve `C`.
///
/// Operands typed [`Label`] (all `x` / `y` fields and some `c` fields) refer to other values
/// of the circuit by label (presumably the outputs of earlier gates — confirm against the
/// circuit builder). Fields typed `ScalarPlaintext<C>`, `BaseFieldPlaintext<C>`,
/// `PointPlaintext<C>`, `BitPlaintext`, `Point<C>`, `BoxedUint` or `usize` are constants
/// stored directly in the gate.
///
/// The `Scalar*`, `BaseField*`, `Point*` and `Bit*` name prefixes indicate the kind of value
/// a gate operates on.
///
/// Serialization note: with index-based serde formats such as `bincode`, a variant is encoded
/// by its position in this declaration. Reordering existing variants or inserting new ones
/// in the middle therefore changes the encoding; append new variants at the end instead.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
// Replace serde's inferred `C: Serialize` / `C: Deserialize<'de>` bounds so the curve type
// `C` itself does not have to implement the serde traits; only these curve value types do.
// NOTE(review): every `C`-dependent field type (`*Plaintext<C>`, `Point<C>`) must be
// (de)serializable under these bounds alone — revisit this list when adding such fields.
#[serde(bound(
    serialize = "Scalar<C>: Serialize, Point<C>: Serialize",
    deserialize = "Scalar<C>: Deserialize<'de>, Point<C>: Deserialize<'de>"
))]
pub enum Gate<C: Curve> {
    // Inputs whose value is supplied by the peer identified by `inputer`.
    ScalarInput {
        inputer: PeerNumber,
    },
    BaseFieldInput {
        inputer: PeerNumber,
    },
    PointInput {
        inputer: PeerNumber,
    },
    // Inputs whose value `c` is a plaintext constant embedded in the gate.
    ScalarPlaintextInput {
        c: ScalarPlaintext<C>,
    },
    BaseFieldPlaintextInput {
        c: BaseFieldPlaintext<C>,
    },
    // Operand-less share inputs; where the share comes from is not visible in this file.
    ScalarShareInput,
    BaseFieldShareInput,
    PointShareInput,
    /// An ElGamal ciphertext input given as the two embedded constant points `c` and `r`.
    ElGamalCiphertextInput {
        c: PointPlaintext<C>,
        r: PointPlaintext<C>,
    },
    // Binary operations on two labelled operands `x` and `y`.
    ScalarAdd {
        x: Label,
        y: Label,
    },
    BaseFieldAdd {
        x: Label,
        y: Label,
    },
    PointAdd {
        x: Label,
        y: Label,
    },
    ScalarMul {
        x: Label,
        y: Label,
    },
    BaseFieldMul {
        x: Label,
        y: Label,
    },
    PointMul {
        x: Label,
        y: Label,
    },
    // Operations between `x` and a plaintext operand `c`. Despite the `Plaintext` suffix,
    // `c` is referenced by `Label` here rather than embedded in the gate.
    ScalarAddPlaintext {
        x: Label,
        c: Label,
    },
    BaseFieldAddPlaintext {
        x: Label,
        c: Label,
    },
    PointAddPlaintext {
        x: Label,
        c: Label,
    },
    ScalarMulScalarPlaintext {
        x: Label,
        c: Label,
    },
    BaseFieldMulBaseFieldPlaintext {
        x: Label,
        c: Label,
    },
    // Operations whose plaintext operand `c` is a constant embedded in the gate.
    PointMulScalarPlaintext {
        x: Label,
        c: ScalarPlaintext<C>,
    },
    ScalarMulPointPlaintext {
        x: Label,
        c: PointPlaintext<C>,
    },
    // Open the labelled share `x` (presumably reconstructing its plaintext — confirm).
    ScalarShareOpen {
        x: Label,
    },
    BaseFieldShareOpen {
        x: Label,
    },
    PointShareOpen {
        x: Label,
    },
    // Operations on plaintext values; the concrete operation is selected by `op`.
    ScalarPlaintextUnaryOp {
        x: Label,
        op: FieldUnaryOp,
    },
    BaseFieldPlaintextUnaryOp {
        x: Label,
        op: FieldUnaryOp,
    },
    ScalarPlaintextBinaryOp {
        x: Label,
        y: Label,
        op: FieldBinaryOp,
    },
    BaseFieldPlaintextBinaryOp {
        x: Label,
        y: Label,
        op: FieldBinaryOp,
    },
    // Multiplicative / additive inverses of the labelled operand `x`.
    ScalarMultiplicativeInverse {
        x: Label,
    },
    BaseFieldMultiplicativeInverse {
        x: Label,
    },
    ScalarAdditiveInverse {
        x: Label,
    },
    BaseFieldAdditiveInverse {
        x: Label,
    },
    PointAdditiveInverse {
        x: Label,
    },
    // Comparisons between the labelled operands `x` and `y`.
    ScalarGreaterThan {
        x: Label,
        y: Label,
    },
    BaseFieldGreaterThan {
        x: Label,
        y: Label,
    },
    ScalarGreaterThanOrEqual {
        x: Label,
        y: Label,
    },
    BaseFieldGreaterThanOrEqual {
        x: Label,
        y: Label,
    },
    // Zero tests on the labelled operand `x`.
    ScalarZeroTest {
        x: Label,
    },
    BaseFieldZeroTest {
        x: Label,
    },
    PointZeroTest {
        x: Label,
    },
    /// Point encryption of `x` using the embedded constant point `c` (presumably an
    /// encryption key — confirm).
    EncryptPoint {
        x: Label,
        c: Point<C>,
    },
    /// Point decryption over the labelled operands `x` and `y`; which operand plays which
    /// role is not visible in this file.
    DecryptPoint {
        x: Label,
        y: Label,
    },
    /// Operand-less gate producing a random base-field share (per its name).
    RandomBaseFieldShare,
    /// Raises `x` to the constant exponent `exp`.
    /// NOTE(review): presumably consumes preprocessing counted in
    /// `CircuitPreprocessing::base_field_pow_pairs` (also keyed by `BoxedUint`) — confirm.
    BaseFieldPow {
        x: Label,
        exp: BoxedUint,
    },
    /// Point input whose value `c` is a plaintext constant embedded in the gate. It sits
    /// apart from the other plaintext inputs; keep its position (see the serialization note).
    PointPlaintextInput {
        c: PointPlaintext<C>,
    },
    // Boolean gates. `BitInput` follows the same `inputer` convention as the inputs above.
    BitInput {
        inputer: PeerNumber,
    },
    BitXor {
        x: Label,
        y: Label,
    },
    BitNot {
        x: Label,
    },
    BitAnd {
        x: Label,
        y: Label,
    },
    // As with the field `*Plaintext` gates above, `c` is referenced by `Label`.
    BitPlaintextXor {
        x: Label,
        c: Label,
    },
    BitPlaintextAnd {
        x: Label,
        c: Label,
    },
    BitShareInput,
    BitPlaintextInput {
        c: BitPlaintext,
    },
    RandomBitShare,
    // NOTE(review): the bit plaintext ops reuse `FieldUnaryOp` / `FieldBinaryOp`; confirm
    // that is intentional rather than a missing bit-specific op type.
    BitPlaintextUnaryOp {
        x: Label,
        op: FieldUnaryOp,
    },
    BitPlaintextBinaryOp {
        x: Label,
        y: Label,
        op: FieldBinaryOp,
    },
    BitShareOpen {
        x: Label,
    },
    // daBit gates: `ScalarDaBit` / `BaseFieldDaBit` take no operands; the `Get*` gates take a
    // labelled daBit `x` and, per their names, extract one of its components.
    ScalarDaBit,
    BaseFieldDaBit,
    GetDaBitScalarShare {
        x: Label,
    },
    GetDaBitBaseFieldShare {
        x: Label,
    },
    GetScalarDaBitSharedBit {
        x: Label,
    },
    GetBaseFieldDaBitSharedBit {
        x: Label,
    },
    /// Selects position `index` of the labelled bit share `x` (presumably a multi-bit
    /// share — confirm).
    BitShareGetIndex {
        x: Label,
        index: usize,
    },
    // Conversions between plaintext bits and plaintext scalar / base-field values.
    BitPlaintextToScalar {
        x: Label,
    },
    BitPlaintextToBaseField {
        x: Label,
    },
    ScalarPlaintextToBit {
        x: Label,
    },
    BaseFieldPlaintextToBit {
        x: Label,
    },
}
267
/// Amount of each kind of preprocessed material required (e.g. by a circuit).
///
/// Every field is a count. `Default` yields all-zero counters and an empty
/// `base_field_pow_pairs` map. Two values are combined field-wise with `+`, summing the
/// `base_field_pow_pairs` counts per key.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct CircuitPreprocessing {
    /// Number of scalar singlets required.
    pub scalar_singlets: usize,
    /// Number of scalar triples required (presumably multiplication triples — confirm).
    pub scalar_triples: usize,
    /// Number of base-field singlets required.
    pub base_field_singlets: usize,
    /// Number of base-field triples required (presumably multiplication triples — confirm).
    pub base_field_triples: usize,
    /// Required pow pairs as `key -> count`. Keys are presumably the exponents of
    /// `Gate::BaseFieldPow` gates — confirm.
    pub base_field_pow_pairs: HashMap<BoxedUint, usize>,
    /// Number of bit singlets required.
    pub bit_singlets: usize,
    /// Number of bit triples required.
    pub bit_triples: usize,
    /// Number of scalar daBits required.
    pub scalar_dabits: usize,
    /// Number of base-field daBits required.
    pub base_field_dabits: usize,
}
280
281impl Add for CircuitPreprocessing {
282 type Output = Self;
283
284 fn add(self, other: Self) -> Self::Output {
285 Self {
286 scalar_singlets: self.scalar_singlets + other.scalar_singlets,
287 scalar_triples: self.scalar_triples + other.scalar_triples,
288 base_field_singlets: self.base_field_singlets + other.base_field_singlets,
289 base_field_triples: self.base_field_triples + other.base_field_triples,
290 bit_singlets: self.bit_singlets + other.bit_singlets,
291 bit_triples: self.bit_triples + other.bit_triples,
292 scalar_dabits: self.scalar_dabits + other.scalar_dabits,
293 base_field_dabits: self.base_field_dabits + other.base_field_dabits,
294 base_field_pow_pairs: {
295 let mut combined = self.base_field_pow_pairs;
296 for (k, v) in other.base_field_pow_pairs {
297 *combined.entry(k).or_insert(0) += v;
298 }
299 combined
300 },
301 }
302 }
303}
304
#[cfg(test)]
mod tests {
    use primitives::algebra::elliptic_curve::Curve25519Ristretto as C;

    use super::*;

    /// Round-trips `gate` through bincode and asserts the decoded gate equals the original.
    fn assert_bincode_roundtrip(gate: &Gate<C>) {
        let bytes = bincode::serialize(gate).unwrap();
        let decoded: Gate<C> = bincode::deserialize(&bytes).unwrap();
        assert_eq!(gate, &decoded);
    }

    #[test]
    fn test_ser_gate() {
        // Cover several payload shapes: label-only operands (including `*Plaintext` gates
        // whose `c` is a `Label`), a unit variant, and variants embedding non-label constants.
        let gates: Vec<Gate<C>> = vec![
            Gate::ScalarAdd {
                x: Label::from(1, 2),
                y: Label::from(3, 4),
            },
            Gate::ScalarAddPlaintext {
                x: Label::from(1, 2),
                c: Label::from(3, 4),
            },
            Gate::PointAddPlaintext {
                x: Label::from(1, 2),
                c: Label::from(3, 4),
            },
            Gate::ScalarShareInput,
            Gate::BaseFieldPow {
                x: Label::from(5, 6),
                exp: BoxedUint::from(vec![21]),
            },
            Gate::BitShareGetIndex {
                x: Label::from(7, 8),
                index: 3,
            },
        ];

        for gate in &gates {
            assert_bincode_roundtrip(gate);
        }
    }

    #[test]
    fn test_circuit_preprocessing_add() {
        let a = CircuitPreprocessing {
            scalar_singlets: 1,
            scalar_triples: 2,
            base_field_singlets: 3,
            base_field_triples: 4,
            bit_singlets: 0,
            bit_triples: 1,
            scalar_dabits: 1,
            base_field_dabits: 2,
            base_field_pow_pairs: vec![
                (BoxedUint::from(vec![21]), 5),
                (BoxedUint::from(vec![14]), 6),
            ]
            .into_iter()
            .collect(),
        };
        let b = CircuitPreprocessing {
            scalar_singlets: 2,
            scalar_triples: 3,
            base_field_singlets: 0,
            base_field_triples: 5,
            bit_singlets: 3,
            bit_triples: 4,
            scalar_dabits: 2,
            base_field_dabits: 3,
            base_field_pow_pairs: vec![
                (BoxedUint::from(vec![21]), 6),
                (BoxedUint::from(vec![13]), 7),
            ]
            .into_iter()
            .collect(),
        };

        let c = a.clone() + b.clone();

        assert_eq!(c.scalar_singlets, 3);
        assert_eq!(c.scalar_triples, 5);
        assert_eq!(c.base_field_singlets, 3);
        assert_eq!(c.base_field_triples, 9);
        assert_eq!(c.bit_singlets, 3);
        assert_eq!(c.bit_triples, 5);
        assert_eq!(c.scalar_dabits, 3);
        assert_eq!(c.base_field_dabits, 5);
        assert_eq!(
            c.base_field_pow_pairs.get(&BoxedUint::from(vec![21])),
            Some(&11)
        );
        assert_eq!(
            c.base_field_pow_pairs.get(&BoxedUint::from(vec![14])),
            Some(&6)
        );
        assert_eq!(
            c.base_field_pow_pairs.get(&BoxedUint::from(vec![13])),
            Some(&7)
        );
        assert_eq!(c.base_field_pow_pairs.len(), 3);

        // Combining requirements must not depend on operand order.
        assert_eq!(b + a, c);
    }
}
397}