1use std::{collections::HashMap, ops::Add};
2
3use macros::GateMethods;
4use primitives::algebra::{
5 elliptic_curve::{Curve, Point, Scalar},
6 BoxedUint,
7};
8use serde::{Deserialize, Serialize};
9
10use crate::{
11 circuit::{
12 AlgebraicType,
13 Batched,
14 BitShareBinaryOp,
15 BitShareUnaryOp,
16 FieldPlaintextBinaryOp,
17 FieldPlaintextUnaryOp,
18 FieldShareBinaryOp,
19 FieldShareUnaryOp,
20 FieldType,
21 Input,
22 PointPlaintextBinaryOp,
23 PointPlaintextUnaryOp,
24 PointShareBinaryOp,
25 PointShareUnaryOp,
26 ShareOrPlaintext,
27 },
28 types::Label,
29};
30
/// A single gate of a circuit whose elliptic-curve values live on curve `C`.
///
/// Wires are referred to by [`Label`]. The explicit `serde(bound)` attribute makes
/// (de)serialization require only `Scalar<C>` and `Point<C>` to be (de)serializable,
/// replacing the bounds serde would otherwise derive on the type parameter `C`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, GateMethods)]
#[serde(bound(
    serialize = "Scalar<C>: Serialize, Point<C>: Serialize",
    deserialize = "Scalar<C>: Deserialize<'de>, Point<C>: Deserialize<'de>"
))]
pub enum Gate<C: Curve> {
    /// Introduces a value into the circuit, described by `input_type`.
    Input {
        input_type: Input<C>,
    },
    /// Unary share operation `op` on wire `x` over the field `field_type`.
    FieldShareUnaryOp {
        x: Label,
        op: FieldShareUnaryOp,
        field_type: FieldType,
    },
    /// Binary share operation `op` on wires `x` and `y` over `field_type`;
    /// `y_form` states whether `y` is a share or a plaintext.
    FieldShareBinaryOp {
        x: Label,
        y: Label,
        y_form: ShareOrPlaintext,
        op: FieldShareBinaryOp,
        field_type: FieldType,
    },
    /// Summation gate over wire `x` (of form `x_form` and type `algebraic_type`).
    BatchSummation {
        x: Label,
        x_form: ShareOrPlaintext,
        algebraic_type: AlgebraicType,
    },
    /// Unary bit-share operation `op` on wire `x`.
    BitShareUnaryOp {
        x: Label,
        op: BitShareUnaryOp,
    },
    /// Binary bit-share operation `op` on wires `x` and `y`;
    /// `y_form` states whether `y` is a share or a plaintext.
    BitShareBinaryOp {
        x: Label,
        y: Label,
        y_form: ShareOrPlaintext,
        op: BitShareBinaryOp,
    },
    /// Unary point-share operation `op` on wire `p`.
    PointShareUnaryOp {
        p: Label,
        op: PointShareUnaryOp,
    },
    /// Binary point-share operation `op` on wires `p` and `y`;
    /// `p_form` / `y_form` state whether each operand is a share or a plaintext.
    PointShareBinaryOp {
        p: Label,
        y: Label,
        p_form: ShareOrPlaintext,
        y_form: ShareOrPlaintext,
        op: PointShareBinaryOp,
    },
    /// Unary plaintext operation `op` on wire `x` over `field_type`.
    FieldPlaintextUnaryOp {
        x: Label,
        op: FieldPlaintextUnaryOp,
        field_type: FieldType,
    },
    /// Binary plaintext operation `op` on wires `x` and `y` over `field_type`.
    FieldPlaintextBinaryOp {
        x: Label,
        y: Label,
        op: FieldPlaintextBinaryOp,
        field_type: FieldType,
    },
    /// Unary plaintext operation `op` on the bit wire `x`.
    // NOTE(review): reuses `FieldPlaintextUnaryOp` as its op type — confirm this is intended.
    BitPlaintextUnaryOp {
        x: Label,
        op: FieldPlaintextUnaryOp,
    },
    /// Binary plaintext operation `op` on the bit wires `x` and `y`.
    // NOTE(review): reuses `FieldPlaintextBinaryOp` as its op type — confirm this is intended.
    BitPlaintextBinaryOp {
        x: Label,
        y: Label,
        op: FieldPlaintextBinaryOp,
    },
    /// Unary plaintext operation `op` on the point wire `p`.
    PointPlaintextUnaryOp {
        p: Label,
        op: PointPlaintextUnaryOp,
    },
    /// Binary plaintext operation `op` on the point wire `p` and wire `y`.
    PointPlaintextBinaryOp {
        p: Label,
        y: Label,
        op: PointPlaintextBinaryOp,
    },
    /// Produces daBits for `field_type`; carries its own `batched`, which (not the caller's
    /// batching) sizes its preprocessing requirement.
    DaBit {
        field_type: FieldType,
        batched: Batched,
    },
    /// Extracts the field-share component of the daBit on wire `x`.
    GetDaBitFieldShare {
        x: Label,
        field_type: FieldType,
    },
    /// Extracts the shared-bit component of the daBit on wire `x`.
    GetDaBitSharedBit {
        x: Label,
        field_type: FieldType,
    },
    /// Base-field power gate on wire `x` with exponent `exp`; its preprocessing is tracked
    /// per distinct exponent.
    BaseFieldPow {
        x: Label,
        exp: BoxedUint,
    },
    /// Converts the plaintext bit on wire `x` into an element of `field_type`.
    BitPlaintextToField {
        x: Label,
        field_type: FieldType,
    },
    /// Converts the plaintext `field_type` element on wire `x` into a bit.
    FieldPlaintextToBit {
        x: Label,
        field_type: FieldType,
    },
    /// Selects element `index` from the batched wire `x` (of type `x_type`, form `x_form`).
    BatchGetIndex {
        x: Label,
        x_type: AlgebraicType,
        x_form: ShareOrPlaintext,
        index: usize,
    },
    /// Gathers the values on `wires` (all of type `x_type`, form `x_form`) into one batch.
    CollectToBatch {
        wires: Vec<Label>,
        x_type: AlgebraicType,
        x_form: ShareOrPlaintext,
    },
    /// Builds a plaintext point from extended-Edwards coordinate wires `wires`.
    PointFromPlaintextExtendedEdwards {
        wires: Vec<Label>,
    },
    /// Splits the plaintext point on wire `point` into extended-Edwards coordinates.
    PlaintextPointToExtendedEdwards {
        point: Label,
    },
    /// Applies the Keccak-f[1600] permutation to the plaintext wires `wires`.
    PlaintextKeccakF1600 {
        wires: Vec<Label>,
    },
    /// Compresses the plaintext point on wire `point`.
    CompressPlaintextPoint {
        point: Label,
    },
    /// Plaintext key-recovery step taking the `d_minus_one` and `syndromes` wires;
    /// exact semantics live in the evaluator (not visible here).
    KeyRecoveryPlaintextComputeErrors {
        d_minus_one: Label,
        syndromes: Label,
    },
}
170
171impl<C: Curve> Gate<C> {
172 pub fn get_labels(&self) -> Vec<Label> {
174 let mut labels = Vec::new();
175 self.for_each_label(|label| labels.push(label));
176 labels
177 }
178 pub fn add_to_required_preprocessing(
179 &self,
180 batched: Batched,
181 circuit_preprocessing: &mut CircuitPreprocessing,
182 ) {
183 match self {
184 Gate::Input { input_type } => match input_type {
185 Input::SecretPlaintext {
186 algebraic_type: AlgebraicType::ScalarField | AlgebraicType::Point,
187 batched,
188 ..
189 } => circuit_preprocessing.scalar_singlets += batched.count(),
190 Input::SecretPlaintext {
191 algebraic_type: AlgebraicType::BaseField,
192 batched,
193 ..
194 } => circuit_preprocessing.base_field_singlets += batched.count(),
195 Input::SecretPlaintext {
196 algebraic_type: AlgebraicType::Mersenne107,
197 ..
198 } => circuit_preprocessing.mersenne107_singlets += batched.count(),
199 Input::SecretPlaintext {
200 algebraic_type: AlgebraicType::Bit,
201 batched,
202 ..
203 } => circuit_preprocessing.bit_singlets += batched.count(),
204 Input::RandomShare {
205 algebraic_type,
206 batched,
207 } => match algebraic_type {
208 AlgebraicType::ScalarField | AlgebraicType::Point => {
209 circuit_preprocessing.scalar_singlets += batched.count();
210 }
211 AlgebraicType::BaseField => {
212 circuit_preprocessing.base_field_singlets += batched.count();
213 }
214 AlgebraicType::Bit => {
215 circuit_preprocessing.bit_singlets += batched.count();
216 }
217 AlgebraicType::Mersenne107 => {
218 circuit_preprocessing.mersenne107_singlets += batched.count();
219 }
220 },
221 Input::Share { .. }
222 | Input::Scalar { .. }
223 | Input::ScalarBatch { .. }
224 | Input::BaseField { .. }
225 | Input::Mersenne107 { .. }
226 | Input::Bit { .. }
227 | Input::Point { .. }
228 | Input::BaseFieldBatch { .. }
229 | Input::Mersenne107Batch { .. }
230 | Input::BitBatch { .. }
231 | Input::PointBatch { .. } => (),
232 },
233 Gate::FieldShareBinaryOp {
234 y_form: ShareOrPlaintext::Share,
235 op,
236 field_type,
237 ..
238 } => match op {
239 FieldShareBinaryOp::Mul => match field_type {
240 FieldType::ScalarField => {
241 circuit_preprocessing.scalar_triples += batched.count()
242 }
243 FieldType::BaseField => {
244 circuit_preprocessing.base_field_triples += batched.count()
245 }
246 FieldType::Mersenne107 => {
247 circuit_preprocessing.mersenne107_triples += batched.count()
248 }
249 },
250 FieldShareBinaryOp::Add => (),
251 },
252 Gate::FieldShareUnaryOp { op, field_type, .. } => match op {
253 FieldShareUnaryOp::MulInverse => match field_type {
254 FieldType::ScalarField => {
255 circuit_preprocessing.scalar_triples += batched.count();
256 circuit_preprocessing.scalar_singlets += batched.count();
257 }
258 FieldType::BaseField => {
259 circuit_preprocessing.base_field_triples += batched.count();
260 circuit_preprocessing.base_field_singlets += batched.count();
261 }
262 FieldType::Mersenne107 => {
263 circuit_preprocessing.mersenne107_triples += batched.count();
264 circuit_preprocessing.mersenne107_singlets += batched.count();
265 }
266 },
267 FieldShareUnaryOp::IsZero => match field_type {
268 FieldType::ScalarField => {
269 circuit_preprocessing.scalar_triples += batched.count();
270 circuit_preprocessing.scalar_singlets += batched.count();
271 }
272 FieldType::BaseField => {
273 circuit_preprocessing.base_field_triples += batched.count();
274 circuit_preprocessing.base_field_singlets += batched.count();
275 }
276 FieldType::Mersenne107 => {
277 circuit_preprocessing.mersenne107_triples += batched.count();
278 circuit_preprocessing.mersenne107_singlets += batched.count();
279 }
280 },
281 FieldShareUnaryOp::Open | FieldShareUnaryOp::Neg => (),
282 },
283 Gate::PointShareUnaryOp { op, .. } => match op {
284 PointShareUnaryOp::IsZero => {
285 circuit_preprocessing.scalar_triples += batched.count();
286 circuit_preprocessing.scalar_singlets += batched.count();
287 }
288 PointShareUnaryOp::Open | PointShareUnaryOp::Neg => (),
289 },
290 Gate::PointShareBinaryOp {
291 p_form: ShareOrPlaintext::Share,
292 y_form: ShareOrPlaintext::Share,
293 op,
294 ..
295 } => match op {
296 PointShareBinaryOp::ScalarMul => {
297 circuit_preprocessing.scalar_triples += batched.count()
298 }
299 PointShareBinaryOp::Add => (),
300 },
301 Gate::BitShareBinaryOp {
302 y_form: ShareOrPlaintext::Share,
303 op,
304 ..
305 } => match op {
306 BitShareBinaryOp::And | BitShareBinaryOp::Or => {
307 circuit_preprocessing.bit_triples += batched.count();
308 }
309 BitShareBinaryOp::Xor => (),
310 },
311 Gate::BaseFieldPow { exp, .. } => {
312 *circuit_preprocessing
313 .base_field_pow_pairs
314 .entry(exp.clone())
315 .or_insert(0) += batched.count();
316 circuit_preprocessing.base_field_triples += batched.count();
317 }
318 Gate::DaBit {
319 field_type,
320 batched,
321 } => match field_type {
322 FieldType::ScalarField => {
323 circuit_preprocessing.scalar_dabits += batched.count();
324 }
325 FieldType::BaseField => {
326 circuit_preprocessing.base_field_dabits += batched.count();
327 }
328 FieldType::Mersenne107 => {
329 circuit_preprocessing.mersenne107_dabits += batched.count();
330 }
331 },
332 Gate::BatchSummation { .. }
333 | Gate::BitShareUnaryOp { .. }
334 | Gate::BitShareBinaryOp {
335 y_form: ShareOrPlaintext::Plaintext,
336 ..
337 }
338 | Gate::PointShareBinaryOp {
339 p_form: ShareOrPlaintext::Plaintext,
340 ..
341 }
342 | Gate::PointShareBinaryOp {
343 p_form: ShareOrPlaintext::Share,
344 y_form: ShareOrPlaintext::Plaintext,
345 ..
346 }
347 | Gate::FieldPlaintextUnaryOp { .. }
348 | Gate::FieldPlaintextBinaryOp { .. }
349 | Gate::BitPlaintextUnaryOp { .. }
350 | Gate::BitPlaintextBinaryOp { .. }
351 | Gate::PointPlaintextUnaryOp { .. }
352 | Gate::PointPlaintextBinaryOp { .. }
353 | Gate::GetDaBitFieldShare { .. }
354 | Gate::GetDaBitSharedBit { .. }
355 | Gate::BitPlaintextToField { .. }
356 | Gate::FieldPlaintextToBit { .. }
357 | Gate::BatchGetIndex { .. }
358 | Gate::CollectToBatch { .. }
359 | Gate::PointFromPlaintextExtendedEdwards { .. }
360 | Gate::PlaintextPointToExtendedEdwards { .. }
361 | Gate::PlaintextKeccakF1600 { .. }
362 | Gate::CompressPlaintextPoint { .. }
363 | Gate::FieldShareBinaryOp {
364 y_form: ShareOrPlaintext::Plaintext,
365 ..
366 }
367 | Gate::KeyRecoveryPlaintextComputeErrors { .. } => (),
368 };
369 }
370}
371
/// Running tally of the preprocessing material a circuit requires.
///
/// Counters are incremented by [`Gate::add_to_required_preprocessing`] and two tallies can
/// be combined with `+`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct CircuitPreprocessing {
    /// Scalar-field singlets (point-typed inputs and point zero-tests also draw from this).
    pub scalar_singlets: usize,
    /// Scalar-field triples (also consumed by point zero-tests and share-by-share scalar
    /// multiplication of points).
    pub scalar_triples: usize,
    /// Base-field singlets.
    pub base_field_singlets: usize,
    /// Base-field triples (also incremented by `Gate::BaseFieldPow`).
    pub base_field_triples: usize,
    /// Number of base-field power pairs needed, keyed by the exponent of `Gate::BaseFieldPow`.
    pub base_field_pow_pairs: HashMap<BoxedUint, usize>,
    /// Bit singlets.
    pub bit_singlets: usize,
    /// Bit triples (consumed by share-by-share `And`/`Or`).
    pub bit_triples: usize,
    /// Mersenne107 daBits.
    pub mersenne107_dabits: usize,
    /// Mersenne107 singlets.
    pub mersenne107_singlets: usize,
    /// Mersenne107 triples.
    pub mersenne107_triples: usize,
    /// Scalar-field daBits.
    pub scalar_dabits: usize,
    /// Base-field daBits.
    pub base_field_dabits: usize,
}
387
388impl Add for CircuitPreprocessing {
389 type Output = Self;
390
391 fn add(self, other: Self) -> Self::Output {
392 Self {
393 scalar_singlets: self.scalar_singlets + other.scalar_singlets,
394 scalar_triples: self.scalar_triples + other.scalar_triples,
395 base_field_singlets: self.base_field_singlets + other.base_field_singlets,
396 base_field_triples: self.base_field_triples + other.base_field_triples,
397 bit_singlets: self.bit_singlets + other.bit_singlets,
398 bit_triples: self.bit_triples + other.bit_triples,
399 mersenne107_dabits: self.mersenne107_dabits + other.mersenne107_dabits,
400 mersenne107_singlets: self.mersenne107_singlets + other.mersenne107_singlets,
401 mersenne107_triples: self.mersenne107_triples + other.mersenne107_triples,
402 scalar_dabits: self.scalar_dabits + other.scalar_dabits,
403 base_field_dabits: self.base_field_dabits + other.base_field_dabits,
404 base_field_pow_pairs: {
405 let mut combined = self.base_field_pow_pairs;
406 for (k, v) in other.base_field_pow_pairs {
407 *combined.entry(k).or_insert(0) += v;
408 }
409 combined
410 },
411 }
412 }
413}
414
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use primitives::algebra::elliptic_curve::Curve25519Ristretto as C;

    use super::*;
    use crate::circuit::FieldShareBinaryOp;

    /// Encodes `gate` with bincode and decodes it again, panicking on any codec error.
    fn bincode_roundtrip(gate: &Gate<C>) -> Gate<C> {
        let encoded = bincode::serialize(gate).unwrap();
        bincode::deserialize(&encoded).unwrap()
    }

    /// Three pairwise-distinct gates survive a bincode round trip unchanged, and each decoded
    /// copy hashes equal to its original, so the combined set holds exactly three gates.
    #[test]
    fn test_ser_gate() {
        let field_share_by_share: Gate<C> = Gate::FieldShareBinaryOp {
            x: Label::from(1, 2),
            y: Label::from(3, 4),
            y_form: ShareOrPlaintext::Share,
            op: FieldShareBinaryOp::Add,
            field_type: FieldType::ScalarField,
        };
        let field_share_by_plaintext: Gate<C> = Gate::FieldShareBinaryOp {
            x: Label::from(1, 2),
            y: Label::from(3, 4),
            y_form: ShareOrPlaintext::Plaintext,
            op: FieldShareBinaryOp::Add,
            field_type: FieldType::ScalarField,
        };
        let point_share_by_plaintext: Gate<C> = Gate::PointShareBinaryOp {
            p: Label::from(1, 2),
            y: Label::from(3, 4),
            p_form: ShareOrPlaintext::Share,
            y_form: ShareOrPlaintext::Plaintext,
            op: PointShareBinaryOp::Add,
        };

        let originals: Vec<Gate<C>> = vec![
            field_share_by_share,
            field_share_by_plaintext,
            point_share_by_plaintext,
        ];
        let decoded: Vec<Gate<C>> = originals.iter().map(bincode_roundtrip).collect();

        originals
            .iter()
            .zip(&decoded)
            .for_each(|(before, after)| assert_eq!(before, after));

        let distinct: HashSet<Gate<C>> = originals.into_iter().chain(decoded).collect();
        assert_eq!(distinct.len(), 3)
    }

    /// Adding two tallies sums every counter and merges the exponent maps, summing the counts
    /// of exponents present in both.
    #[test]
    fn test_circuit_preprocessing_add() {
        let lhs = CircuitPreprocessing {
            scalar_singlets: 1,
            scalar_triples: 2,
            base_field_singlets: 3,
            base_field_triples: 4,
            bit_singlets: 0,
            bit_triples: 1,
            mersenne107_dabits: 0,
            mersenne107_singlets: 0,
            mersenne107_triples: 0,
            scalar_dabits: 1,
            base_field_dabits: 2,
            base_field_pow_pairs: HashMap::from([
                (BoxedUint::from(vec![21]), 5),
                (BoxedUint::from(vec![14]), 6),
            ]),
        };
        let rhs = CircuitPreprocessing {
            scalar_singlets: 2,
            scalar_triples: 3,
            base_field_singlets: 0,
            base_field_triples: 5,
            bit_singlets: 3,
            bit_triples: 4,
            mersenne107_dabits: 0,
            mersenne107_singlets: 3,
            mersenne107_triples: 2,
            scalar_dabits: 2,
            base_field_dabits: 3,
            base_field_pow_pairs: HashMap::from([
                (BoxedUint::from(vec![21]), 6),
                (BoxedUint::from(vec![13]), 7),
            ]),
        };

        let sum = lhs + rhs;

        assert_eq!(sum.scalar_singlets, 3);
        assert_eq!(sum.scalar_triples, 5);
        assert_eq!(sum.base_field_singlets, 3);
        assert_eq!(sum.base_field_triples, 9);
        assert_eq!(sum.bit_singlets, 3);
        assert_eq!(sum.bit_triples, 5);
        assert_eq!(sum.mersenne107_dabits, 0);
        assert_eq!(sum.mersenne107_singlets, 3);
        assert_eq!(sum.mersenne107_triples, 2);
        assert_eq!(sum.scalar_dabits, 3);
        assert_eq!(sum.base_field_dabits, 5);

        // Shared exponent 21 is summed; 14 and 13 carry over from one side each.
        for (exponent, expected) in [(21, 11), (14, 6), (13, 7)] {
            assert_eq!(
                sum.base_field_pow_pairs.get(&BoxedUint::from(vec![exponent])),
                Some(&expected)
            );
        }
    }
}