1use std::{
2 collections::HashMap,
3 ops::{Add, AddAssign, Index, IndexMut},
4};
5
6use macros::GateMethods;
7use primitives::algebra::{
8 elliptic_curve::{Curve, Point, Scalar},
9 BoxedUint,
10};
11use serde::{Deserialize, Serialize};
12
13use crate::circuit::{
14 AlgebraicType,
15 Batched,
16 BitShareBinaryOp,
17 BitShareUnaryOp,
18 FieldPlaintextBinaryOp,
19 FieldPlaintextUnaryOp,
20 FieldShareBinaryOp,
21 FieldShareUnaryOp,
22 FieldType,
23 GateIndex,
24 Input,
25 PointPlaintextBinaryOp,
26 PointPlaintextUnaryOp,
27 PointShareBinaryOp,
28 PointShareUnaryOp,
29 ShareOrPlaintext,
30};
31
32#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, GateMethods)]
34#[serde(bound(
35 serialize = "Scalar<C>: Serialize, Point<C>: Serialize",
36 deserialize = "Scalar<C>: Deserialize<'de>, Point<C>: Deserialize<'de>"
37))]
38pub enum Gate<C: Curve> {
39 Input {
41 input_type: Input<C>,
42 },
43 FieldShareUnaryOp {
45 x: GateIndex,
46 op: FieldShareUnaryOp,
47 field_type: FieldType,
48 },
49 FieldShareBinaryOp {
51 x: GateIndex,
52 y: GateIndex,
53 y_form: ShareOrPlaintext,
54 op: FieldShareBinaryOp,
55 field_type: FieldType,
56 },
57 BatchSummation {
58 x: GateIndex,
59 x_form: ShareOrPlaintext,
60 algebraic_type: AlgebraicType,
61 },
62 BitShareUnaryOp {
63 x: GateIndex,
64 op: BitShareUnaryOp,
65 },
66 BitShareBinaryOp {
67 x: GateIndex,
68 y: GateIndex,
69 y_form: ShareOrPlaintext,
70 op: BitShareBinaryOp,
71 },
72 PointShareUnaryOp {
74 p: GateIndex,
75 op: PointShareUnaryOp,
76 },
77 PointShareBinaryOp {
78 p: GateIndex,
79 y: GateIndex,
80 p_form: ShareOrPlaintext,
81 y_form: ShareOrPlaintext,
82 op: PointShareBinaryOp,
83 },
84 FieldPlaintextUnaryOp {
86 x: GateIndex,
87 op: FieldPlaintextUnaryOp,
88 field_type: FieldType,
89 },
90 FieldPlaintextBinaryOp {
92 x: GateIndex,
93 y: GateIndex,
94 op: FieldPlaintextBinaryOp,
95 field_type: FieldType,
96 },
97 BitPlaintextUnaryOp {
98 x: GateIndex,
99 op: FieldPlaintextUnaryOp,
100 },
101 BitPlaintextBinaryOp {
102 x: GateIndex,
103 y: GateIndex,
104 op: FieldPlaintextBinaryOp,
105 },
106 PointPlaintextUnaryOp {
107 p: GateIndex,
108 op: PointPlaintextUnaryOp,
109 },
110 PointPlaintextBinaryOp {
111 p: GateIndex,
112 y: GateIndex,
113 op: PointPlaintextBinaryOp,
114 },
115 DaBit {
117 field_type: FieldType,
118 batched: Batched,
119 },
120 GetDaBitFieldShare {
121 x: GateIndex,
122 field_type: FieldType,
123 },
124 GetDaBitSharedBit {
125 x: GateIndex,
126 field_type: FieldType,
127 },
128 BaseFieldPow {
130 x: GateIndex,
131 exp: BoxedUint,
132 },
133 BitPlaintextToField {
135 x: GateIndex,
136 field_type: FieldType,
137 },
138 FieldPlaintextToBit {
139 x: GateIndex,
140 field_type: FieldType,
141 },
142 BatchGetIndex {
144 x: GateIndex,
145 x_type: AlgebraicType,
146 x_form: ShareOrPlaintext,
147 index: usize,
148 },
149 CollectToBatch {
150 wires: Vec<GateIndex>,
151 x_type: AlgebraicType,
152 x_form: ShareOrPlaintext,
153 },
154 PointFromPlaintextExtendedEdwards {
155 wires: Vec<GateIndex>,
156 },
157 PlaintextPointToExtendedEdwards {
158 point: GateIndex,
159 },
160 PlaintextKeccakF1600 {
161 wires: Vec<GateIndex>,
162 },
163 CompressPlaintextPoint {
164 point: GateIndex,
165 },
166 KeyRecoveryPlaintextComputeErrors {
167 d_minus_one: GateIndex,
168 syndromes: GateIndex,
169 },
170}
171
172impl<C: Curve> Gate<C> {
173 pub fn get_gate_indices(&self) -> Vec<GateIndex> {
175 let mut gate_indices = Vec::new();
176 self.for_each_gate_index(|idx| gate_indices.push(idx));
177 gate_indices
178 }
179
180 pub fn add_to_required_preprocessing(
181 &self,
182 batched: Batched,
183 circuit_preprocessing: &mut CircuitPreprocessing,
184 ) {
185 match self {
186 Gate::Input { input_type } => match input_type {
187 Input::SecretPlaintext {
188 algebraic_type: AlgebraicType::ScalarField | AlgebraicType::Point,
189 batched,
190 ..
191 } => circuit_preprocessing.scalar.singlets += batched.count(),
192 Input::SecretPlaintext {
193 algebraic_type: AlgebraicType::BaseField,
194 batched,
195 ..
196 } => circuit_preprocessing.base_field.singlets += batched.count(),
197 Input::SecretPlaintext {
198 algebraic_type: AlgebraicType::Mersenne107,
199 ..
200 } => circuit_preprocessing.mersenne107.singlets += batched.count(),
201 Input::SecretPlaintext {
202 algebraic_type: AlgebraicType::Bit,
203 batched,
204 ..
205 } => circuit_preprocessing.bit_singlets += batched.count(),
206 Input::RandomShare {
207 algebraic_type,
208 batched,
209 } => match algebraic_type {
210 AlgebraicType::ScalarField | AlgebraicType::Point => {
211 circuit_preprocessing.scalar.singlets += batched.count();
212 }
213 AlgebraicType::BaseField => {
214 circuit_preprocessing.base_field.singlets += batched.count();
215 }
216 AlgebraicType::Bit => {
217 circuit_preprocessing.bit_singlets += batched.count();
218 }
219 AlgebraicType::Mersenne107 => {
220 circuit_preprocessing.mersenne107.singlets += batched.count();
221 }
222 },
223 Input::Share { .. }
224 | Input::Scalar { .. }
225 | Input::ScalarBatch { .. }
226 | Input::BaseField { .. }
227 | Input::Mersenne107 { .. }
228 | Input::Bit { .. }
229 | Input::Point { .. }
230 | Input::BaseFieldBatch { .. }
231 | Input::Mersenne107Batch { .. }
232 | Input::BitBatch { .. }
233 | Input::PointBatch { .. } => (),
234 },
235 Gate::FieldShareBinaryOp {
236 y_form: ShareOrPlaintext::Share,
237 op,
238 field_type,
239 ..
240 } => match op {
241 FieldShareBinaryOp::Mul => {
242 circuit_preprocessing[*field_type].triples += batched.count();
243 }
244 FieldShareBinaryOp::Add => (),
245 },
246 Gate::FieldShareUnaryOp { op, field_type, .. } => match op {
247 FieldShareUnaryOp::MulInverse | FieldShareUnaryOp::IsZero => {
248 circuit_preprocessing[*field_type].triples += batched.count();
249 circuit_preprocessing[*field_type].singlets += batched.count();
250 }
251 FieldShareUnaryOp::Open | FieldShareUnaryOp::Neg => (),
252 },
253 Gate::PointShareUnaryOp { op, .. } => match op {
254 PointShareUnaryOp::IsZero => {
255 circuit_preprocessing.scalar.triples += batched.count();
256 circuit_preprocessing.scalar.singlets += batched.count();
257 }
258 PointShareUnaryOp::Open | PointShareUnaryOp::Neg => (),
259 },
260 Gate::PointShareBinaryOp {
261 p_form: ShareOrPlaintext::Share,
262 y_form: ShareOrPlaintext::Share,
263 op,
264 ..
265 } => match op {
266 PointShareBinaryOp::ScalarMul => {
267 circuit_preprocessing.scalar.triples += batched.count();
268 }
269 PointShareBinaryOp::Add => (),
270 },
271 Gate::BitShareBinaryOp {
272 y_form: ShareOrPlaintext::Share,
273 op,
274 ..
275 } => match op {
276 BitShareBinaryOp::And | BitShareBinaryOp::Or => {
277 circuit_preprocessing.bit_triples += batched.count();
278 }
279 BitShareBinaryOp::Xor => (),
280 },
281 Gate::BaseFieldPow { exp, .. } => {
282 *circuit_preprocessing
283 .base_field_pow_pairs
284 .entry(exp.clone())
285 .or_insert(0) += batched.count();
286 circuit_preprocessing.base_field.triples += batched.count();
287 }
288 Gate::DaBit {
289 field_type,
290 batched,
291 } => circuit_preprocessing[*field_type].dabits += batched.count(),
292 Gate::BatchSummation { .. }
293 | Gate::BitShareUnaryOp { .. }
294 | Gate::BitShareBinaryOp {
295 y_form: ShareOrPlaintext::Plaintext,
296 ..
297 }
298 | Gate::PointShareBinaryOp {
299 p_form: ShareOrPlaintext::Plaintext,
300 ..
301 }
302 | Gate::PointShareBinaryOp {
303 p_form: ShareOrPlaintext::Share,
304 y_form: ShareOrPlaintext::Plaintext,
305 ..
306 }
307 | Gate::FieldPlaintextUnaryOp { .. }
308 | Gate::FieldPlaintextBinaryOp { .. }
309 | Gate::BitPlaintextUnaryOp { .. }
310 | Gate::BitPlaintextBinaryOp { .. }
311 | Gate::PointPlaintextUnaryOp { .. }
312 | Gate::PointPlaintextBinaryOp { .. }
313 | Gate::GetDaBitFieldShare { .. }
314 | Gate::GetDaBitSharedBit { .. }
315 | Gate::BitPlaintextToField { .. }
316 | Gate::FieldPlaintextToBit { .. }
317 | Gate::BatchGetIndex { .. }
318 | Gate::CollectToBatch { .. }
319 | Gate::PointFromPlaintextExtendedEdwards { .. }
320 | Gate::PlaintextPointToExtendedEdwards { .. }
321 | Gate::PlaintextKeccakF1600 { .. }
322 | Gate::CompressPlaintextPoint { .. }
323 | Gate::FieldShareBinaryOp {
324 y_form: ShareOrPlaintext::Plaintext,
325 ..
326 }
327 | Gate::KeyRecoveryPlaintextComputeErrors { .. } => (),
328 };
329 }
330}
/// Preprocessing counters required for a single field.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct FieldCircuitPreprocessing {
    /// Number of singlets needed.
    pub singlets: usize,
    /// Number of triples needed.
    pub triples: usize,
    /// Number of daBits needed.
    pub dabits: usize,
}
337
338impl AddAssign for FieldCircuitPreprocessing {
339 fn add_assign(&mut self, rhs: Self) {
340 self.singlets += rhs.singlets;
341 self.triples += rhs.triples;
342 self.dabits += rhs.dabits;
343 }
344}
345
346impl Add for FieldCircuitPreprocessing {
347 type Output = Self;
348 fn add(self, rhs: Self) -> Self::Output {
349 let mut res = self;
350 res += rhs;
351 res
352 }
353}
354
355#[derive(Debug, Clone, Default, PartialEq, Eq)]
356pub struct CircuitPreprocessing {
357 pub base_field_pow_pairs: HashMap<BoxedUint, usize>,
358 pub bit_singlets: usize,
359 pub bit_triples: usize,
360 pub base_field: FieldCircuitPreprocessing,
361 pub scalar: FieldCircuitPreprocessing,
362 pub mersenne107: FieldCircuitPreprocessing,
363}
364
365impl AddAssign for CircuitPreprocessing {
366 fn add_assign(&mut self, rhs: Self) {
367 self.bit_singlets += rhs.bit_singlets;
368 self.bit_triples += rhs.bit_triples;
369 self.base_field += rhs.base_field;
370 self.scalar += rhs.scalar;
371 self.mersenne107 += rhs.mersenne107;
372 for (k, v) in rhs.base_field_pow_pairs {
373 *self.base_field_pow_pairs.entry(k).or_insert(0) += v;
374 }
375 }
376}
377
378impl Add for CircuitPreprocessing {
379 type Output = Self;
380
381 fn add(self, other: Self) -> Self::Output {
382 let mut res = self;
383 res += other;
384 res
385 }
386}
387
388impl Index<FieldType> for CircuitPreprocessing {
389 type Output = FieldCircuitPreprocessing;
390
391 fn index(&self, index: FieldType) -> &Self::Output {
392 match index {
393 FieldType::BaseField => &self.base_field,
394 FieldType::ScalarField => &self.scalar,
395 FieldType::Mersenne107 => &self.mersenne107,
396 }
397 }
398}
399
400impl IndexMut<FieldType> for CircuitPreprocessing {
401 fn index_mut(&mut self, index: FieldType) -> &mut Self::Output {
402 match index {
403 FieldType::BaseField => &mut self.base_field,
404 FieldType::ScalarField => &mut self.scalar,
405 FieldType::Mersenne107 => &mut self.mersenne107,
406 }
407 }
408}
409
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use primitives::algebra::elliptic_curve::Curve25519Ristretto as C;

    use super::*;
    use crate::circuit::FieldShareBinaryOp;

    /// Encodes `gate` with bincode and decodes it back into a fresh `Gate`.
    fn bincode_round_trip(gate: &Gate<C>) -> Gate<C> {
        let encoded = bincode::serialize(gate).unwrap();
        bincode::deserialize(&encoded).unwrap()
    }

    /// Shorthand constructor for one field's preprocessing counters.
    fn field_counts(singlets: usize, triples: usize, dabits: usize) -> FieldCircuitPreprocessing {
        FieldCircuitPreprocessing {
            singlets,
            triples,
            dabits,
        }
    }

    #[test]
    fn test_ser_gate() {
        let gates: [Gate<C>; 3] = [
            Gate::FieldShareBinaryOp {
                x: 1,
                y: 3,
                y_form: ShareOrPlaintext::Share,
                op: FieldShareBinaryOp::Add,
                field_type: FieldType::ScalarField,
            },
            Gate::FieldShareBinaryOp {
                x: 1,
                y: 3,
                y_form: ShareOrPlaintext::Plaintext,
                op: FieldShareBinaryOp::Add,
                field_type: FieldType::ScalarField,
            },
            Gate::PointShareBinaryOp {
                p: 1,
                y: 3,
                p_form: ShareOrPlaintext::Share,
                y_form: ShareOrPlaintext::Plaintext,
                op: PointShareBinaryOp::Add,
            },
        ];

        let mut distinct = HashSet::new();
        for gate in gates {
            let decoded = bincode_round_trip(&gate);
            assert_eq!(gate, decoded);
            distinct.insert(gate);
            distinct.insert(decoded);
        }
        // A gate and its decoded copy are equal and hash alike. The set therefore keeps
        // exactly one entry per original gate.
        assert_eq!(distinct.len(), 3)
    }

    #[test]
    fn test_circuit_preprocessing_add() {
        let lhs = CircuitPreprocessing {
            base_field_pow_pairs: HashMap::from([
                (BoxedUint::from(vec![21]), 5),
                (BoxedUint::from(vec![14]), 6),
            ]),
            bit_singlets: 0,
            bit_triples: 1,
            base_field: field_counts(3, 4, 2),
            scalar: field_counts(1, 2, 1),
            mersenne107: field_counts(0, 0, 0),
        };
        let rhs = CircuitPreprocessing {
            base_field_pow_pairs: HashMap::from([
                (BoxedUint::from(vec![21]), 6),
                (BoxedUint::from(vec![13]), 7),
            ]),
            bit_singlets: 3,
            bit_triples: 4,
            base_field: field_counts(0, 5, 3),
            scalar: field_counts(2, 3, 2),
            mersenne107: field_counts(3, 2, 0),
        };

        let sum = lhs + rhs;

        assert_eq!(sum.bit_singlets, 3);
        assert_eq!(sum.bit_triples, 5);

        assert_eq!(sum.scalar.singlets, 3);
        assert_eq!(sum.scalar.triples, 5);
        assert_eq!(sum.scalar.dabits, 3);

        assert_eq!(sum.base_field.singlets, 3);
        assert_eq!(sum.base_field.triples, 9);
        assert_eq!(sum.base_field.dabits, 5);

        assert_eq!(sum.mersenne107.singlets, 3);
        assert_eq!(sum.mersenne107.triples, 2);
        assert_eq!(sum.mersenne107.dabits, 0);

        // Shared exponents accumulate; exponents unique to one operand carry over.
        for (exponent, expected) in [(21, 11), (14, 6), (13, 7)] {
            assert_eq!(
                sum.base_field_pow_pairs.get(&BoxedUint::from(vec![exponent])),
                Some(&expected)
            );
        }
    }
}