Skip to main content

core_utils/circuit/
gate.rs

1use std::{
2    collections::HashMap,
3    ops::{Add, AddAssign, Index, IndexMut},
4};
5
6use macros::GateMethods;
7use primitives::algebra::{
8    elliptic_curve::{Curve, Point, Scalar},
9    BoxedUint,
10};
11use serde::{Deserialize, Serialize};
12
13use crate::circuit::{
14    AlgebraicType,
15    Batched,
16    BitShareBinaryOp,
17    BitShareUnaryOp,
18    FieldPlaintextBinaryOp,
19    FieldPlaintextUnaryOp,
20    FieldShareBinaryOp,
21    FieldShareUnaryOp,
22    FieldType,
23    GateIndex,
24    Input,
25    PointPlaintextBinaryOp,
26    PointPlaintextUnaryOp,
27    PointShareBinaryOp,
28    PointShareUnaryOp,
29    ShareOrPlaintext,
30};
31
/// Gate operations, where the operation arguments correspond to _wire_ labels.
///
/// Operand fields of type [`GateIndex`] name the wires a gate reads; they can be
/// gathered with [`Gate::get_gate_indices`]. The preprocessing material each gate
/// needs is tallied by [`Gate::add_to_required_preprocessing`].
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, GateMethods)]
#[serde(bound(
    serialize = "Scalar<C>: Serialize, Point<C>: Serialize",
    deserialize = "Scalar<C>: Deserialize<'de>, Point<C>: Deserialize<'de>"
))]
pub enum Gate<C: Curve> {
    /// Input a wire. `input_type` describes what kind of input it is
    /// (e.g. a secret plaintext or a random share).
    Input {
        input_type: Input<C>,
    },
    /// Field share unary operations: `op` applied to wire `x`, in the field
    /// selected by `field_type`.
    FieldShareUnaryOp {
        x: GateIndex,
        op: FieldShareUnaryOp,
        field_type: FieldType,
    },
    /// Field share binary operations, where the second wire may be a plaintext.
    /// `y_form` says whether `y` is a share or a plaintext; `field_type` selects the field.
    FieldShareBinaryOp {
        x: GateIndex,
        y: GateIndex,
        y_form: ShareOrPlaintext,
        op: FieldShareBinaryOp,
        field_type: FieldType,
    },
    /// Summation over the batched wire `x` (presumably sums all of its elements — confirm).
    /// `x_form` says whether `x` is a share or a plaintext and `algebraic_type` gives its type.
    BatchSummation {
        x: GateIndex,
        x_form: ShareOrPlaintext,
        algebraic_type: AlgebraicType,
    },
    /// Bit share unary operations on wire `x`.
    BitShareUnaryOp {
        x: GateIndex,
        op: BitShareUnaryOp,
    },
    /// Bit share binary operations, where the second wire may be a plaintext (see `y_form`).
    BitShareBinaryOp {
        x: GateIndex,
        y: GateIndex,
        y_form: ShareOrPlaintext,
        op: BitShareBinaryOp,
    },
    /// Operations with elliptic curve points: unary operation `op` on the point share `p`.
    PointShareUnaryOp {
        p: GateIndex,
        op: PointShareUnaryOp,
    },
    /// Point share binary operations. `p_form` and `y_form` say whether `p` and `y`
    /// are shares or plaintexts.
    ///
    /// NOTE(review): for `ScalarMul`, `y` is presumably the scalar operand — confirm.
    PointShareBinaryOp {
        p: GateIndex,
        y: GateIndex,
        p_form: ShareOrPlaintext,
        y_form: ShareOrPlaintext,
        op: PointShareBinaryOp,
    },
    /// Field plaintext unary operations
    FieldPlaintextUnaryOp {
        x: GateIndex,
        op: FieldPlaintextUnaryOp,
        field_type: FieldType,
    },
    /// Field plaintext binary operations
    FieldPlaintextBinaryOp {
        x: GateIndex,
        y: GateIndex,
        op: FieldPlaintextBinaryOp,
        field_type: FieldType,
    },
    /// Bit plaintext unary operations on wire `x`.
    ///
    /// NOTE(review): `op` reuses the field plaintext operator type [`FieldPlaintextUnaryOp`].
    BitPlaintextUnaryOp {
        x: GateIndex,
        op: FieldPlaintextUnaryOp,
    },
    /// Bit plaintext binary operations on wires `x` and `y`.
    ///
    /// NOTE(review): `op` reuses the field plaintext operator type [`FieldPlaintextBinaryOp`].
    BitPlaintextBinaryOp {
        x: GateIndex,
        y: GateIndex,
        op: FieldPlaintextBinaryOp,
    },
    /// Point plaintext unary operations on wire `p`.
    PointPlaintextUnaryOp {
        p: GateIndex,
        op: PointPlaintextUnaryOp,
    },
    /// Point plaintext binary operations on wires `p` and `y`.
    PointPlaintextBinaryOp {
        p: GateIndex,
        y: GateIndex,
        op: PointPlaintextBinaryOp,
    },
    /// Request a daBit (or a batch of them, per `batched`) for the field `field_type`.
    DaBit {
        field_type: FieldType,
        batched: Batched,
    },
    /// Presumably extracts the field-share part of the daBit on wire `x` — confirm.
    GetDaBitFieldShare {
        x: GateIndex,
        field_type: FieldType,
    },
    /// Presumably extracts the shared-bit part of the daBit on wire `x` — confirm.
    GetDaBitSharedBit {
        x: GateIndex,
        field_type: FieldType,
    },
    /// Base field exponentiation operation on wire `x` with the constant exponent `exp`.
    /// Each distinct `exp` is tracked separately in
    /// [`CircuitPreprocessing::base_field_pow_pairs`].
    BaseFieldPow {
        x: GateIndex,
        exp: BoxedUint,
    },
    /// Bit plaintext conversion operations: bit plaintext `x` to a plaintext of `field_type`.
    BitPlaintextToField {
        x: GateIndex,
        field_type: FieldType,
    },
    /// Converts the field plaintext `x` (of `field_type`) to a bit plaintext.
    FieldPlaintextToBit {
        x: GateIndex,
        field_type: FieldType,
    },
    /// Get the element at a certain index of a batched wire.
    /// `x_type` and `x_form` describe the batched wire `x`.
    BatchGetIndex {
        x: GateIndex,
        x_type: AlgebraicType,
        x_form: ShareOrPlaintext,
        index: usize,
    },
    /// Presumably collects `wires` into a single batched wire of type `x_type`
    /// and form `x_form` — confirm.
    CollectToBatch {
        wires: Vec<GateIndex>,
        x_type: AlgebraicType,
        x_form: ShareOrPlaintext,
    },
    /// Builds a plaintext point from `wires` (presumably extended Edwards
    /// coordinates — confirm).
    PointFromPlaintextExtendedEdwards {
        wires: Vec<GateIndex>,
    },
    /// Converts the plaintext point `point` to extended Edwards form
    /// (presumably its coordinates — confirm).
    PlaintextPointToExtendedEdwards {
        point: GateIndex,
    },
    /// `Keccak-f[1600]` permutation over the plaintext `wires`
    /// (presumably the state lanes — confirm).
    PlaintextKeccakF1600 {
        wires: Vec<GateIndex>,
    },
    /// Compresses the plaintext point `point`.
    CompressPlaintextPoint {
        point: GateIndex,
    },
    /// Plaintext error computation for key recovery from `d_minus_one` and `syndromes`.
    ///
    /// NOTE(review): exact semantics are not visible here — confirm against the evaluator.
    KeyRecoveryPlaintextComputeErrors {
        d_minus_one: GateIndex,
        syndromes: GateIndex,
    },
}
171
172impl<C: Curve> Gate<C> {
173    /// Gathers all the `GateIndex` inside the gate.
174    pub fn get_gate_indices(&self) -> Vec<GateIndex> {
175        let mut gate_indices = Vec::new();
176        self.for_each_gate_index(|idx| gate_indices.push(idx));
177        gate_indices
178    }
179
180    pub fn add_to_required_preprocessing(
181        &self,
182        batched: Batched,
183        circuit_preprocessing: &mut CircuitPreprocessing,
184    ) {
185        match self {
186            Gate::Input { input_type } => match input_type {
187                Input::SecretPlaintext {
188                    algebraic_type: AlgebraicType::ScalarField | AlgebraicType::Point,
189                    batched,
190                    ..
191                } => circuit_preprocessing.scalar.singlets += batched.count(),
192                Input::SecretPlaintext {
193                    algebraic_type: AlgebraicType::BaseField,
194                    batched,
195                    ..
196                } => circuit_preprocessing.base_field.singlets += batched.count(),
197                Input::SecretPlaintext {
198                    algebraic_type: AlgebraicType::Mersenne107,
199                    ..
200                } => circuit_preprocessing.mersenne107.singlets += batched.count(),
201                Input::SecretPlaintext {
202                    algebraic_type: AlgebraicType::Bit,
203                    batched,
204                    ..
205                } => circuit_preprocessing.bit_singlets += batched.count(),
206                Input::RandomShare {
207                    algebraic_type,
208                    batched,
209                } => match algebraic_type {
210                    AlgebraicType::ScalarField | AlgebraicType::Point => {
211                        circuit_preprocessing.scalar.singlets += batched.count();
212                    }
213                    AlgebraicType::BaseField => {
214                        circuit_preprocessing.base_field.singlets += batched.count();
215                    }
216                    AlgebraicType::Bit => {
217                        circuit_preprocessing.bit_singlets += batched.count();
218                    }
219                    AlgebraicType::Mersenne107 => {
220                        circuit_preprocessing.mersenne107.singlets += batched.count();
221                    }
222                },
223                Input::Share { .. }
224                | Input::Scalar { .. }
225                | Input::ScalarBatch { .. }
226                | Input::BaseField { .. }
227                | Input::Mersenne107 { .. }
228                | Input::Bit { .. }
229                | Input::Point { .. }
230                | Input::BaseFieldBatch { .. }
231                | Input::Mersenne107Batch { .. }
232                | Input::BitBatch { .. }
233                | Input::PointBatch { .. } => (),
234            },
235            Gate::FieldShareBinaryOp {
236                y_form: ShareOrPlaintext::Share,
237                op,
238                field_type,
239                ..
240            } => match op {
241                FieldShareBinaryOp::Mul => {
242                    circuit_preprocessing[*field_type].triples += batched.count();
243                }
244                FieldShareBinaryOp::Add => (),
245            },
246            Gate::FieldShareUnaryOp { op, field_type, .. } => match op {
247                FieldShareUnaryOp::MulInverse | FieldShareUnaryOp::IsZero => {
248                    circuit_preprocessing[*field_type].triples += batched.count();
249                    circuit_preprocessing[*field_type].singlets += batched.count();
250                }
251                FieldShareUnaryOp::Open | FieldShareUnaryOp::Neg => (),
252            },
253            Gate::PointShareUnaryOp { op, .. } => match op {
254                PointShareUnaryOp::IsZero => {
255                    circuit_preprocessing.scalar.triples += batched.count();
256                    circuit_preprocessing.scalar.singlets += batched.count();
257                }
258                PointShareUnaryOp::Open | PointShareUnaryOp::Neg => (),
259            },
260            Gate::PointShareBinaryOp {
261                p_form: ShareOrPlaintext::Share,
262                y_form: ShareOrPlaintext::Share,
263                op,
264                ..
265            } => match op {
266                PointShareBinaryOp::ScalarMul => {
267                    circuit_preprocessing.scalar.triples += batched.count();
268                }
269                PointShareBinaryOp::Add => (),
270            },
271            Gate::BitShareBinaryOp {
272                y_form: ShareOrPlaintext::Share,
273                op,
274                ..
275            } => match op {
276                BitShareBinaryOp::And | BitShareBinaryOp::Or => {
277                    circuit_preprocessing.bit_triples += batched.count();
278                }
279                BitShareBinaryOp::Xor => (),
280            },
281            Gate::BaseFieldPow { exp, .. } => {
282                *circuit_preprocessing
283                    .base_field_pow_pairs
284                    .entry(exp.clone())
285                    .or_insert(0) += batched.count();
286                circuit_preprocessing.base_field.triples += batched.count();
287            }
288            Gate::DaBit {
289                field_type,
290                batched,
291            } => circuit_preprocessing[*field_type].dabits += batched.count(),
292            Gate::BatchSummation { .. }
293            | Gate::BitShareUnaryOp { .. }
294            | Gate::BitShareBinaryOp {
295                y_form: ShareOrPlaintext::Plaintext,
296                ..
297            }
298            | Gate::PointShareBinaryOp {
299                p_form: ShareOrPlaintext::Plaintext,
300                ..
301            }
302            | Gate::PointShareBinaryOp {
303                p_form: ShareOrPlaintext::Share,
304                y_form: ShareOrPlaintext::Plaintext,
305                ..
306            }
307            | Gate::FieldPlaintextUnaryOp { .. }
308            | Gate::FieldPlaintextBinaryOp { .. }
309            | Gate::BitPlaintextUnaryOp { .. }
310            | Gate::BitPlaintextBinaryOp { .. }
311            | Gate::PointPlaintextUnaryOp { .. }
312            | Gate::PointPlaintextBinaryOp { .. }
313            | Gate::GetDaBitFieldShare { .. }
314            | Gate::GetDaBitSharedBit { .. }
315            | Gate::BitPlaintextToField { .. }
316            | Gate::FieldPlaintextToBit { .. }
317            | Gate::BatchGetIndex { .. }
318            | Gate::CollectToBatch { .. }
319            | Gate::PointFromPlaintextExtendedEdwards { .. }
320            | Gate::PlaintextPointToExtendedEdwards { .. }
321            | Gate::PlaintextKeccakF1600 { .. }
322            | Gate::CompressPlaintextPoint { .. }
323            | Gate::FieldShareBinaryOp {
324                y_form: ShareOrPlaintext::Plaintext,
325                ..
326            }
327            | Gate::KeyRecoveryPlaintextComputeErrors { .. } => (),
328        };
329    }
330}
/// Counts of preprocessing material required within a single field.
///
/// One instance exists per field type inside `CircuitPreprocessing`; instances can be
/// combined with `+` / `+=`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct FieldCircuitPreprocessing {
    /// Number of singlets (e.g. one per secret-plaintext or random-share input element).
    pub singlets: usize,
    /// Number of triples (e.g. one per share-by-share `Mul`).
    pub triples: usize,
    /// Number of daBits (requested by `DaBit` gates).
    pub dabits: usize,
}
337
338impl AddAssign for FieldCircuitPreprocessing {
339    fn add_assign(&mut self, rhs: Self) {
340        self.singlets += rhs.singlets;
341        self.triples += rhs.triples;
342        self.dabits += rhs.dabits;
343    }
344}
345
346impl Add for FieldCircuitPreprocessing {
347    type Output = Self;
348    fn add(self, rhs: Self) -> Self::Output {
349        let mut res = self;
350        res += rhs;
351        res
352    }
353}
354
/// Total preprocessing material required by a circuit.
///
/// Accumulated by `Gate::add_to_required_preprocessing` and combinable with `+` / `+=`.
/// Per-field requirements can also be reached by indexing with a `FieldType`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct CircuitPreprocessing {
    // Per exponent used by `BaseFieldPow` gates, the number of elements needing it
    // (presumably precomputed power pairs for that exponent — confirm).
    pub base_field_pow_pairs: HashMap<BoxedUint, usize>,
    // Singlets for bit-typed inputs.
    pub bit_singlets: usize,
    // Triples for bit share `And`/`Or` gates.
    pub bit_triples: usize,
    // Requirements in the base field.
    pub base_field: FieldCircuitPreprocessing,
    // Requirements in the scalar field (point-typed inputs and operations count here too).
    pub scalar: FieldCircuitPreprocessing,
    // Requirements in the Mersenne107 field.
    pub mersenne107: FieldCircuitPreprocessing,
}
364
365impl AddAssign for CircuitPreprocessing {
366    fn add_assign(&mut self, rhs: Self) {
367        self.bit_singlets += rhs.bit_singlets;
368        self.bit_triples += rhs.bit_triples;
369        self.base_field += rhs.base_field;
370        self.scalar += rhs.scalar;
371        self.mersenne107 += rhs.mersenne107;
372        for (k, v) in rhs.base_field_pow_pairs {
373            *self.base_field_pow_pairs.entry(k).or_insert(0) += v;
374        }
375    }
376}
377
378impl Add for CircuitPreprocessing {
379    type Output = Self;
380
381    fn add(self, other: Self) -> Self::Output {
382        let mut res = self;
383        res += other;
384        res
385    }
386}
387
388impl Index<FieldType> for CircuitPreprocessing {
389    type Output = FieldCircuitPreprocessing;
390
391    fn index(&self, index: FieldType) -> &Self::Output {
392        match index {
393            FieldType::BaseField => &self.base_field,
394            FieldType::ScalarField => &self.scalar,
395            FieldType::Mersenne107 => &self.mersenne107,
396        }
397    }
398}
399
400impl IndexMut<FieldType> for CircuitPreprocessing {
401    fn index_mut(&mut self, index: FieldType) -> &mut Self::Output {
402        match index {
403            FieldType::BaseField => &mut self.base_field,
404            FieldType::ScalarField => &mut self.scalar,
405            FieldType::Mersenne107 => &mut self.mersenne107,
406        }
407    }
408}
409
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use primitives::algebra::elliptic_curve::Curve25519Ristretto as C;

    // `super::*` also brings in the parent's imports (e.g. `FieldShareBinaryOp`,
    // `PointShareBinaryOp`), so no explicit re-import is needed.
    use super::*;

    /// Gates survive a bincode round trip and hash/compare consistently.
    #[test]
    fn test_ser_gate() {
        let no_curve_gate: Gate<C> = Gate::FieldShareBinaryOp {
            x: 1,
            y: 3,
            y_form: ShareOrPlaintext::Share,
            op: FieldShareBinaryOp::Add,
            field_type: FieldType::ScalarField,
        };
        let scalar_gate: Gate<C> = Gate::FieldShareBinaryOp {
            x: 1,
            y: 3,
            y_form: ShareOrPlaintext::Plaintext,
            op: FieldShareBinaryOp::Add,
            field_type: FieldType::ScalarField,
        };
        let point_gate: Gate<C> = Gate::PointShareBinaryOp {
            p: 1,
            y: 3,
            p_form: ShareOrPlaintext::Share,
            y_form: ShareOrPlaintext::Plaintext,
            op: PointShareBinaryOp::Add,
        };

        let no_curve_gate_ser = bincode::serialize(&no_curve_gate).unwrap();
        let scalar_gate_ser = bincode::serialize(&scalar_gate).unwrap();
        let point_gate_ser = bincode::serialize(&point_gate).unwrap();

        let no_curve_gate_de: Gate<C> = bincode::deserialize(&no_curve_gate_ser).unwrap();
        let scalar_gate_de: Gate<C> = bincode::deserialize(&scalar_gate_ser).unwrap();
        let point_gate_de: Gate<C> = bincode::deserialize(&point_gate_ser).unwrap();

        assert_eq!(no_curve_gate, no_curve_gate_de);
        assert_eq!(scalar_gate, scalar_gate_de);
        assert_eq!(point_gate, point_gate_de);
        let set = HashSet::from([
            no_curve_gate,
            no_curve_gate_de,
            scalar_gate,
            scalar_gate_de,
            point_gate,
            point_gate_de,
        ]);
        assert_eq!(set.len(), 3)
    }

    /// `+` sums every counter and merges exponent requirements key by key.
    #[test]
    fn test_circuit_preprocessing_add() {
        let a = CircuitPreprocessing {
            bit_singlets: 0,
            bit_triples: 1,
            base_field: FieldCircuitPreprocessing {
                singlets: 3,
                triples: 4,
                dabits: 2,
            },
            scalar: FieldCircuitPreprocessing {
                singlets: 1,
                triples: 2,
                dabits: 1,
            },
            base_field_pow_pairs: vec![
                (BoxedUint::from(vec![21]), 5),
                (BoxedUint::from(vec![14]), 6),
            ]
            .into_iter()
            .collect(),
            mersenne107: FieldCircuitPreprocessing {
                singlets: 0,
                triples: 0,
                dabits: 0,
            },
        };
        let b = CircuitPreprocessing {
            bit_singlets: 3,
            bit_triples: 4,
            base_field: FieldCircuitPreprocessing {
                singlets: 0,
                triples: 5,
                dabits: 3,
            },
            scalar: FieldCircuitPreprocessing {
                singlets: 2,
                triples: 3,
                dabits: 2,
            },
            base_field_pow_pairs: vec![
                (BoxedUint::from(vec![21]), 6),
                (BoxedUint::from(vec![13]), 7),
            ]
            .into_iter()
            .collect(),
            mersenne107: FieldCircuitPreprocessing {
                singlets: 3,
                triples: 2,
                dabits: 0,
            },
        };

        let c = a + b;

        assert_eq!(c.scalar.singlets, 3);
        assert_eq!(c.scalar.triples, 5);
        assert_eq!(c.base_field.singlets, 3);
        assert_eq!(c.base_field.triples, 9);
        assert_eq!(c.bit_singlets, 3);
        assert_eq!(c.bit_triples, 5);
        assert_eq!(c.mersenne107.dabits, 0);
        assert_eq!(c.mersenne107.singlets, 3);
        assert_eq!(c.mersenne107.triples, 2);
        assert_eq!(c.scalar.dabits, 3);
        assert_eq!(c.base_field.dabits, 5);
        assert_eq!(
            c.base_field_pow_pairs.get(&BoxedUint::from(vec![21])),
            Some(&11)
        );
        assert_eq!(
            c.base_field_pow_pairs.get(&BoxedUint::from(vec![14])),
            Some(&6)
        );
        assert_eq!(
            c.base_field_pow_pairs.get(&BoxedUint::from(vec![13])),
            Some(&7)
        );
    }

    /// `Index`/`IndexMut` by `FieldType` reach the matching per-field counters.
    #[test]
    fn test_circuit_preprocessing_index() {
        let mut preprocessing = CircuitPreprocessing::default();
        preprocessing[FieldType::BaseField].triples = 1;
        preprocessing[FieldType::ScalarField].singlets = 2;
        preprocessing[FieldType::Mersenne107].dabits = 3;

        assert_eq!(preprocessing.base_field.triples, 1);
        assert_eq!(preprocessing.scalar.singlets, 2);
        assert_eq!(preprocessing.mersenne107.dabits, 3);

        assert_eq!(preprocessing[FieldType::BaseField], preprocessing.base_field);
        assert_eq!(preprocessing[FieldType::ScalarField], preprocessing.scalar);
        assert_eq!(preprocessing[FieldType::Mersenne107], preprocessing.mersenne107);
    }
}