core_utils/circuit/
gate.rs

1use std::{collections::HashMap, ops::Add};
2
3use macros::GateMethods;
4use primitives::algebra::{
5    elliptic_curve::{Curve, Point, Scalar},
6    BoxedUint,
7};
8use serde::{Deserialize, Serialize};
9
10use crate::{
11    circuit::{
12        AlgebraicType,
13        Batched,
14        BitShareBinaryOp,
15        BitShareUnaryOp,
16        FieldPlaintextBinaryOp,
17        FieldPlaintextUnaryOp,
18        FieldShareBinaryOp,
19        FieldShareUnaryOp,
20        FieldType,
21        Input,
22        PointPlaintextBinaryOp,
23        PointPlaintextUnaryOp,
24        PointShareBinaryOp,
25        PointShareUnaryOp,
26        ShareOrPlaintext,
27    },
28    types::Label,
29};
30
/// Gate operations, where the operation arguments correspond to _wire_ labels.
///
/// `Label` fields name the wires a gate reads. `ShareOrPlaintext` fields (`x_form`,
/// `y_form`, `p_form`) record whether the corresponding wire holds a secret share or a
/// plaintext value.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, GateMethods)]
// The (de)serialization bounds are stated on `Scalar<C>` / `Point<C>` explicitly,
// replacing the bounds serde would otherwise infer for the type parameter `C`.
#[serde(bound(
    serialize = "Scalar<C>: Serialize, Point<C>: Serialize",
    deserialize = "Scalar<C>: Deserialize<'de>, Point<C>: Deserialize<'de>"
))]
pub enum Gate<C: Curve> {
    /// Input a wire, described by `input_type`.
    Input {
        input_type: Input<C>,
    },
    /// Field share unary operations on wire `x`, over the field `field_type`.
    FieldShareUnaryOp {
        x: Label,
        op: FieldShareUnaryOp,
        field_type: FieldType,
    },
    /// Field share binary operations, where the second wire may be a plaintext.
    ///
    /// `y_form` records whether `y` is a share or a plaintext.
    FieldShareBinaryOp {
        x: Label,
        y: Label,
        y_form: ShareOrPlaintext,
        op: FieldShareBinaryOp,
        field_type: FieldType,
    },
    /// Summation over the batched wire `x` (presumably of all its elements — confirm).
    /// `x_form` records whether `x` is a share or a plaintext; `algebraic_type` is its type.
    BatchSummation {
        x: Label,
        x_form: ShareOrPlaintext,
        algebraic_type: AlgebraicType,
    },
    /// Bit share unary operation `op` on wire `x`.
    BitShareUnaryOp {
        x: Label,
        op: BitShareUnaryOp,
    },
    /// Bit share binary operation `op`, where the second wire `y` may be a plaintext
    /// (see `y_form`).
    BitShareBinaryOp {
        x: Label,
        y: Label,
        y_form: ShareOrPlaintext,
        op: BitShareBinaryOp,
    },
    /// Operations with elliptic curve points: unary point share operation on wire `p`.
    PointShareUnaryOp {
        p: Label,
        op: PointShareUnaryOp,
    },
    /// Binary point share operation; `p_form` and `y_form` record whether `p` and `y`
    /// are shares or plaintexts.
    PointShareBinaryOp {
        p: Label,
        y: Label,
        p_form: ShareOrPlaintext,
        y_form: ShareOrPlaintext,
        op: PointShareBinaryOp,
    },
    /// Field plaintext unary operations
    FieldPlaintextUnaryOp {
        x: Label,
        op: FieldPlaintextUnaryOp,
        field_type: FieldType,
    },
    /// Field plaintext binary operations
    FieldPlaintextBinaryOp {
        x: Label,
        y: Label,
        op: FieldPlaintextBinaryOp,
        field_type: FieldType,
    },
    /// Unary operation on the plaintext bit wire `x`.
    ///
    /// NOTE(review): `op` reuses `FieldPlaintextUnaryOp` rather than a bit-specific op
    /// type — confirm this is intended.
    BitPlaintextUnaryOp {
        x: Label,
        op: FieldPlaintextUnaryOp,
    },
    /// Binary operation on the plaintext bit wires `x` and `y`.
    ///
    /// NOTE(review): `op` reuses `FieldPlaintextBinaryOp` rather than a bit-specific op
    /// type — confirm this is intended.
    BitPlaintextBinaryOp {
        x: Label,
        y: Label,
        op: FieldPlaintextBinaryOp,
    },
    /// Unary operation on the plaintext point wire `p`.
    PointPlaintextUnaryOp {
        p: Label,
        op: PointPlaintextUnaryOp,
    },
    /// Binary operation on the plaintext point wire `p` and the wire `y`.
    PointPlaintextBinaryOp {
        p: Label,
        y: Label,
        op: PointPlaintextBinaryOp,
    },
    /// Request a daBit for `field_type`; `batched` determines how many are requested.
    DaBit {
        field_type: FieldType,
        batched: Batched,
    },
    /// Extracts the field share part of the daBit wire `x` (inferred from the name —
    /// confirm).
    GetDaBitFieldShare {
        x: Label,
        field_type: FieldType,
    },
    /// Extracts the shared-bit part of the daBit wire `x` (inferred from the name —
    /// confirm).
    GetDaBitSharedBit {
        x: Label,
        field_type: FieldType,
    },
    /// Base field exponentiation operation: raises wire `x` to the exponent `exp`,
    /// which is fixed in the gate.
    BaseFieldPow {
        x: Label,
        exp: BoxedUint,
    },
    /// Bit plaintext conversion operations: plaintext bit wire `x` to a `field_type`
    /// plaintext.
    BitPlaintextToField {
        x: Label,
        field_type: FieldType,
    },
    /// Converts the `field_type` plaintext wire `x` to a plaintext bit.
    FieldPlaintextToBit {
        x: Label,
        field_type: FieldType,
    },
    /// Get the element at a certain index of a batched wire.
    ///
    /// `x_type` and `x_form` describe the batched wire `x`.
    BatchGetIndex {
        x: Label,
        x_type: AlgebraicType,
        x_form: ShareOrPlaintext,
        index: usize,
    },
    /// Collects `wires` into one batched wire of type `x_type` and form `x_form`
    /// (inferred from the name — confirm).
    CollectToBatch {
        wires: Vec<Label>,
        x_type: AlgebraicType,
        x_form: ShareOrPlaintext,
    },
    /// Builds a plaintext point from `wires` (presumably its extended Edwards
    /// coordinates — confirm).
    PointFromPlaintextExtendedEdwards {
        wires: Vec<Label>,
    },
    /// Converts the plaintext point wire `point` to extended Edwards form (inferred
    /// from the name — confirm).
    PlaintextPointToExtendedEdwards {
        point: Label,
    },
    /// Applies Keccak-f[1600] to the plaintext `wires` (presumably the permutation
    /// state — confirm).
    PlaintextKeccakF1600 {
        wires: Vec<Label>,
    },
    /// Compresses the plaintext point wire `point` (encoding not visible here).
    CompressPlaintextPoint {
        point: Label,
    },
}
166
167impl<C: Curve> Gate<C> {
168    /// Gathers all the `Label` inside the gate.
169    pub fn get_labels(&self) -> Vec<Label> {
170        let mut labels = Vec::new();
171        self.for_each_label(|label| labels.push(label));
172        labels
173    }
174    pub fn add_to_required_preprocessing(
175        &self,
176        batched: Batched,
177        circuit_preprocessing: &mut CircuitPreprocessing,
178    ) {
179        match self {
180            Gate::Input { input_type } => match input_type {
181                Input::SecretPlaintext {
182                    algebraic_type: AlgebraicType::ScalarField | AlgebraicType::Point,
183                    batched,
184                    ..
185                } => circuit_preprocessing.scalar_singlets += batched.count(),
186                Input::SecretPlaintext {
187                    algebraic_type: AlgebraicType::BaseField,
188                    batched,
189                    ..
190                } => circuit_preprocessing.base_field_singlets += batched.count(),
191                Input::SecretPlaintext {
192                    algebraic_type: AlgebraicType::Mersenne107,
193                    ..
194                } => circuit_preprocessing.mersenne107_singlets += batched.count(),
195                Input::SecretPlaintext {
196                    algebraic_type: AlgebraicType::Bit,
197                    batched,
198                    ..
199                } => circuit_preprocessing.bit_singlets += batched.count(),
200                Input::RandomShare {
201                    algebraic_type,
202                    batched,
203                } => match algebraic_type {
204                    AlgebraicType::ScalarField | AlgebraicType::Point => {
205                        circuit_preprocessing.scalar_singlets += batched.count();
206                    }
207                    AlgebraicType::BaseField => {
208                        circuit_preprocessing.base_field_singlets += batched.count();
209                    }
210                    AlgebraicType::Bit => {
211                        circuit_preprocessing.bit_singlets += batched.count();
212                    }
213                    AlgebraicType::Mersenne107 => {
214                        circuit_preprocessing.mersenne107_singlets += batched.count();
215                    }
216                },
217                Input::Share { .. }
218                | Input::Scalar { .. }
219                | Input::ScalarBatch { .. }
220                | Input::BaseField { .. }
221                | Input::Mersenne107 { .. }
222                | Input::Bit { .. }
223                | Input::Point { .. }
224                | Input::BaseFieldBatch { .. }
225                | Input::Mersenne107Batch { .. }
226                | Input::BitBatch { .. }
227                | Input::PointBatch { .. } => (),
228            },
229            Gate::FieldShareBinaryOp {
230                y_form: ShareOrPlaintext::Share,
231                op,
232                field_type,
233                ..
234            } => match op {
235                FieldShareBinaryOp::Mul => match field_type {
236                    FieldType::ScalarField => {
237                        circuit_preprocessing.scalar_triples += batched.count()
238                    }
239                    FieldType::BaseField => {
240                        circuit_preprocessing.base_field_triples += batched.count()
241                    }
242                    FieldType::Mersenne107 => {
243                        circuit_preprocessing.mersenne107_triples += batched.count()
244                    }
245                },
246                FieldShareBinaryOp::Add => (),
247            },
248            Gate::FieldShareUnaryOp { op, field_type, .. } => match op {
249                FieldShareUnaryOp::MulInverse => match field_type {
250                    FieldType::ScalarField => {
251                        circuit_preprocessing.scalar_triples += batched.count();
252                        circuit_preprocessing.scalar_singlets += batched.count();
253                    }
254                    FieldType::BaseField => {
255                        circuit_preprocessing.base_field_triples += batched.count();
256                        circuit_preprocessing.base_field_singlets += batched.count();
257                    }
258                    FieldType::Mersenne107 => {
259                        circuit_preprocessing.mersenne107_triples += batched.count();
260                        circuit_preprocessing.mersenne107_singlets += batched.count();
261                    }
262                },
263                FieldShareUnaryOp::IsZero => match field_type {
264                    FieldType::ScalarField => {
265                        circuit_preprocessing.scalar_triples += batched.count();
266                        circuit_preprocessing.scalar_singlets += batched.count();
267                    }
268                    FieldType::BaseField => {
269                        circuit_preprocessing.base_field_triples += batched.count();
270                        circuit_preprocessing.base_field_singlets += batched.count();
271                    }
272                    FieldType::Mersenne107 => {
273                        circuit_preprocessing.mersenne107_triples += batched.count();
274                        circuit_preprocessing.mersenne107_singlets += batched.count();
275                    }
276                },
277                FieldShareUnaryOp::Open | FieldShareUnaryOp::Neg => (),
278            },
279            Gate::PointShareUnaryOp { op, .. } => match op {
280                PointShareUnaryOp::IsZero => {
281                    circuit_preprocessing.scalar_triples += batched.count();
282                    circuit_preprocessing.scalar_singlets += batched.count();
283                }
284                PointShareUnaryOp::Open | PointShareUnaryOp::Neg => (),
285            },
286            Gate::PointShareBinaryOp {
287                p_form: ShareOrPlaintext::Share,
288                y_form: ShareOrPlaintext::Share,
289                op,
290                ..
291            } => match op {
292                PointShareBinaryOp::ScalarMul => {
293                    circuit_preprocessing.scalar_triples += batched.count()
294                }
295                PointShareBinaryOp::Add => (),
296            },
297            Gate::BitShareBinaryOp {
298                y_form: ShareOrPlaintext::Share,
299                op,
300                ..
301            } => match op {
302                BitShareBinaryOp::And | BitShareBinaryOp::Or => {
303                    circuit_preprocessing.bit_triples += batched.count();
304                }
305                BitShareBinaryOp::Xor => (),
306            },
307            Gate::BaseFieldPow { exp, .. } => {
308                *circuit_preprocessing
309                    .base_field_pow_pairs
310                    .entry(exp.clone())
311                    .or_insert(0) += batched.count();
312                circuit_preprocessing.base_field_triples += batched.count();
313            }
314            Gate::DaBit {
315                field_type,
316                batched,
317            } => match field_type {
318                FieldType::ScalarField => {
319                    circuit_preprocessing.scalar_dabits += batched.count();
320                }
321                FieldType::BaseField => {
322                    circuit_preprocessing.base_field_dabits += batched.count();
323                }
324                FieldType::Mersenne107 => {
325                    circuit_preprocessing.mersenne107_dabits += batched.count();
326                }
327            },
328            Gate::BatchSummation { .. }
329            | Gate::BitShareUnaryOp { .. }
330            | Gate::BitShareBinaryOp {
331                y_form: ShareOrPlaintext::Plaintext,
332                ..
333            }
334            | Gate::PointShareBinaryOp {
335                p_form: ShareOrPlaintext::Plaintext,
336                ..
337            }
338            | Gate::PointShareBinaryOp {
339                p_form: ShareOrPlaintext::Share,
340                y_form: ShareOrPlaintext::Plaintext,
341                ..
342            }
343            | Gate::FieldPlaintextUnaryOp { .. }
344            | Gate::FieldPlaintextBinaryOp { .. }
345            | Gate::BitPlaintextUnaryOp { .. }
346            | Gate::BitPlaintextBinaryOp { .. }
347            | Gate::PointPlaintextUnaryOp { .. }
348            | Gate::PointPlaintextBinaryOp { .. }
349            | Gate::GetDaBitFieldShare { .. }
350            | Gate::GetDaBitSharedBit { .. }
351            | Gate::BitPlaintextToField { .. }
352            | Gate::FieldPlaintextToBit { .. }
353            | Gate::BatchGetIndex { .. }
354            | Gate::CollectToBatch { .. }
355            | Gate::PointFromPlaintextExtendedEdwards { .. }
356            | Gate::PlaintextPointToExtendedEdwards { .. }
357            | Gate::PlaintextKeccakF1600 { .. }
358            | Gate::CompressPlaintextPoint { .. }
359            | Gate::FieldShareBinaryOp {
360                y_form: ShareOrPlaintext::Plaintext,
361                ..
362            } => (),
363        };
364    }
365}
366
/// Totals of preprocessing material required to evaluate a circuit.
///
/// Filled in gate by gate via `Gate::add_to_required_preprocessing`; two values can be
/// combined with `+` (see the `Add` impl below).
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct CircuitPreprocessing {
    /// Scalar-field singlets (also counted for point inputs and point zero tests).
    pub scalar_singlets: usize,
    /// Scalar-field triples (also counted for point scalar multiplication and point
    /// zero tests).
    pub scalar_triples: usize,
    /// Base-field singlets.
    pub base_field_singlets: usize,
    /// Base-field triples.
    pub base_field_triples: usize,
    /// Base-field power pairs required, keyed by the exponent of a `BaseFieldPow` gate.
    pub base_field_pow_pairs: HashMap<BoxedUint, usize>,
    /// Bit singlets.
    pub bit_singlets: usize,
    /// Bit triples.
    pub bit_triples: usize,
    /// Mersenne107 daBits.
    pub mersenne107_dabits: usize,
    /// Mersenne107 singlets.
    pub mersenne107_singlets: usize,
    /// Mersenne107 triples.
    pub mersenne107_triples: usize,
    /// Scalar-field daBits.
    pub scalar_dabits: usize,
    /// Base-field daBits.
    pub base_field_dabits: usize,
}
382
383impl Add for CircuitPreprocessing {
384    type Output = Self;
385
386    fn add(self, other: Self) -> Self::Output {
387        Self {
388            scalar_singlets: self.scalar_singlets + other.scalar_singlets,
389            scalar_triples: self.scalar_triples + other.scalar_triples,
390            base_field_singlets: self.base_field_singlets + other.base_field_singlets,
391            base_field_triples: self.base_field_triples + other.base_field_triples,
392            bit_singlets: self.bit_singlets + other.bit_singlets,
393            bit_triples: self.bit_triples + other.bit_triples,
394            mersenne107_dabits: self.mersenne107_dabits + other.mersenne107_dabits,
395            mersenne107_singlets: self.mersenne107_singlets + other.mersenne107_singlets,
396            mersenne107_triples: self.mersenne107_triples + other.mersenne107_triples,
397            scalar_dabits: self.scalar_dabits + other.scalar_dabits,
398            base_field_dabits: self.base_field_dabits + other.base_field_dabits,
399            base_field_pow_pairs: {
400                let mut combined = self.base_field_pow_pairs;
401                for (k, v) in other.base_field_pow_pairs {
402                    *combined.entry(k).or_insert(0) += v;
403                }
404                combined
405            },
406        }
407    }
408}
409
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use primitives::algebra::elliptic_curve::Curve25519Ristretto as C;

    use super::*;
    use crate::circuit::FieldShareBinaryOp;

    /// Serializes `gate` with bincode and decodes it back into a fresh `Gate`.
    fn bincode_round_trip(gate: &Gate<C>) -> Gate<C> {
        let encoded = bincode::serialize(gate).unwrap();
        bincode::deserialize(&encoded).unwrap()
    }

    #[test]
    fn test_ser_gate() {
        // A share-only gate, a gate with a plaintext field operand, and a point gate
        // with a plaintext operand.
        let gates: [Gate<C>; 3] = [
            Gate::FieldShareBinaryOp {
                x: Label::from(1, 2),
                y: Label::from(3, 4),
                y_form: ShareOrPlaintext::Share,
                op: FieldShareBinaryOp::Add,
                field_type: FieldType::ScalarField,
            },
            Gate::FieldShareBinaryOp {
                x: Label::from(1, 2),
                y: Label::from(3, 4),
                y_form: ShareOrPlaintext::Plaintext,
                op: FieldShareBinaryOp::Add,
                field_type: FieldType::ScalarField,
            },
            Gate::PointShareBinaryOp {
                p: Label::from(1, 2),
                y: Label::from(3, 4),
                p_form: ShareOrPlaintext::Share,
                y_form: ShareOrPlaintext::Plaintext,
                op: PointShareBinaryOp::Add,
            },
        ];

        let mut distinct = HashSet::new();
        for gate in gates {
            let decoded = bincode_round_trip(&gate);
            assert_eq!(gate, decoded);
            distinct.insert(decoded);
            distinct.insert(gate);
        }
        // Round-tripped gates hash equal to their originals, and the three gates stay
        // distinct from each other.
        assert_eq!(distinct.len(), 3)
    }

    #[test]
    fn test_circuit_preprocessing_add() {
        let lhs = CircuitPreprocessing {
            scalar_singlets: 1,
            scalar_triples: 2,
            base_field_singlets: 3,
            base_field_triples: 4,
            bit_triples: 1,
            scalar_dabits: 1,
            base_field_dabits: 2,
            base_field_pow_pairs: HashMap::from([
                (BoxedUint::from(vec![21]), 5),
                (BoxedUint::from(vec![14]), 6),
            ]),
            ..Default::default()
        };
        let rhs = CircuitPreprocessing {
            scalar_singlets: 2,
            scalar_triples: 3,
            base_field_triples: 5,
            bit_singlets: 3,
            bit_triples: 4,
            mersenne107_singlets: 3,
            mersenne107_triples: 2,
            scalar_dabits: 2,
            base_field_dabits: 3,
            base_field_pow_pairs: HashMap::from([
                (BoxedUint::from(vec![21]), 6),
                (BoxedUint::from(vec![13]), 7),
            ]),
            ..Default::default()
        };

        // Counters add up; the shared exponent 21 merges to 5 + 6 = 11.
        let expected = CircuitPreprocessing {
            scalar_singlets: 3,
            scalar_triples: 5,
            base_field_singlets: 3,
            base_field_triples: 9,
            base_field_pow_pairs: HashMap::from([
                (BoxedUint::from(vec![21]), 11),
                (BoxedUint::from(vec![14]), 6),
                (BoxedUint::from(vec![13]), 7),
            ]),
            bit_singlets: 3,
            bit_triples: 5,
            mersenne107_dabits: 0,
            mersenne107_singlets: 3,
            mersenne107_triples: 2,
            scalar_dabits: 3,
            base_field_dabits: 5,
        };

        assert_eq!(lhs + rhs, expected);
    }
}