core_utils/circuit/
gate.rs

1use std::{collections::HashMap, ops::Add};
2
3use macros::GateMethods;
4use primitives::algebra::{
5    elliptic_curve::{Curve, Point, Scalar},
6    BoxedUint,
7};
8use serde::{Deserialize, Serialize};
9
10use crate::{
11    circuit::{
12        AlgebraicType,
13        Batched,
14        BitShareBinaryOp,
15        BitShareUnaryOp,
16        FieldPlaintextBinaryOp,
17        FieldPlaintextUnaryOp,
18        FieldShareBinaryOp,
19        FieldShareUnaryOp,
20        FieldType,
21        Input,
22        PointPlaintextBinaryOp,
23        PointPlaintextUnaryOp,
24        PointShareBinaryOp,
25        PointShareUnaryOp,
26        ShareOrPlaintext,
27    },
28    types::Label,
29};
30
/// Gate operations, where the operation arguments correspond to _wire_ labels.
///
/// NOTE(review): variant order and field names are presumably load-bearing. The enum is
/// (de)serialized through serde (round-tripped with bincode in this module's tests), and the
/// `GateMethods` derive appears to generate helpers such as `for_each_label` from the fields.
/// Prefer appending new variants to reordering existing ones.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, GateMethods)]
// Explicit serde bounds: the generated impls require `Scalar<C>` and `Point<C>` to be
// (de)serializable instead of relying on the bounds serde would infer for `C` itself.
#[serde(bound(
    serialize = "Scalar<C>: Serialize, Point<C>: Serialize",
    deserialize = "Scalar<C>: Deserialize<'de>, Point<C>: Deserialize<'de>"
))]
pub enum Gate<C: Curve> {
    /// Introduces a new wire whose value is described by `input_type` (see [`Input`]:
    /// scalar / base-field / Mersenne107 / bit / point values and batches, shares, secret
    /// plaintexts, or random shares).
    Input {
        input_type: Input<C>,
    },
    /// Unary operation (`Neg`, `Open`, `MulInverse` or `IsZero`) on a secret-shared element
    /// of `field_type`.
    FieldShareUnaryOp {
        x: Label,
        op: FieldShareUnaryOp,
        field_type: FieldType,
    },
    /// Binary operation (`Add` or `Mul`) on field shares, where the second wire `y` may be a
    /// plaintext, as indicated by `y_form`.
    FieldShareBinaryOp {
        x: Label,
        y: Label,
        y_form: ShareOrPlaintext,
        op: FieldShareBinaryOp,
        field_type: FieldType,
    },
    /// Summation over the batched wire `x`, which is a share or a plaintext according to
    /// `x_form`. NOTE(review): presumably reduces the batch to a single element — confirm.
    BatchSummation {
        x: Label,
        x_form: ShareOrPlaintext,
        algebraic_type: AlgebraicType,
    },
    /// Unary operation on a secret-shared bit.
    BitShareUnaryOp {
        x: Label,
        op: BitShareUnaryOp,
    },
    /// Binary operation (`And`, `Or` or `Xor`) on secret-shared bits, where `y` may be a
    /// plaintext, as indicated by `y_form`.
    BitShareBinaryOp {
        x: Label,
        y: Label,
        y_form: ShareOrPlaintext,
        op: BitShareBinaryOp,
    },
    /// Unary operation (`Neg`, `Open` or `IsZero`) on a secret-shared elliptic curve point.
    PointShareUnaryOp {
        p: Label,
        op: PointShareUnaryOp,
    },
    /// Binary operation (`Add` or `ScalarMul`) between the point wire `p` and the wire `y`.
    /// Each operand is a share or a plaintext according to `p_form` / `y_form`.
    PointShareBinaryOp {
        p: Label,
        y: Label,
        p_form: ShareOrPlaintext,
        y_form: ShareOrPlaintext,
        op: PointShareBinaryOp,
    },
    /// Unary operation on a plaintext element of `field_type`.
    FieldPlaintextUnaryOp {
        x: Label,
        op: FieldPlaintextUnaryOp,
        field_type: FieldType,
    },
    /// Binary operation on two plaintext elements of `field_type`.
    FieldPlaintextBinaryOp {
        x: Label,
        y: Label,
        op: FieldPlaintextBinaryOp,
        field_type: FieldType,
    },
    /// Unary operation on a plaintext bit.
    /// NOTE(review): the op type is `FieldPlaintextUnaryOp`, shared with the field variant —
    /// confirm this is intentional.
    BitPlaintextUnaryOp {
        x: Label,
        op: FieldPlaintextUnaryOp,
    },
    /// Binary operation on two plaintext bits.
    /// NOTE(review): the op type is `FieldPlaintextBinaryOp`, shared with the field variant —
    /// confirm this is intentional.
    BitPlaintextBinaryOp {
        x: Label,
        y: Label,
        op: FieldPlaintextBinaryOp,
    },
    /// Unary operation on a plaintext elliptic curve point.
    PointPlaintextUnaryOp {
        p: Label,
        op: PointPlaintextUnaryOp,
    },
    /// Binary operation between the plaintext point wire `p` and the plaintext wire `y`.
    PointPlaintextBinaryOp {
        p: Label,
        y: Label,
        op: PointPlaintextBinaryOp,
    },
    /// Request a daBit over `field_type`.
    ///
    /// The gate carries its own `batched`, which determines how many daBits are counted as
    /// required preprocessing. Its components are presumably extracted by the
    /// `GetDaBitFieldShare` / `GetDaBitSharedBit` gates.
    DaBit {
        field_type: FieldType,
        batched: Batched,
    },
    /// Extracts the field-share component (over `field_type`) of the daBit wire `x`.
    GetDaBitFieldShare {
        x: Label,
        field_type: FieldType,
    },
    /// Extracts the shared-bit component of the daBit wire `x` (requested over `field_type`).
    GetDaBitSharedBit {
        x: Label,
        field_type: FieldType,
    },
    /// Base field exponentiation: raises the wire `x` to the exponent `exp`, which is fixed in
    /// the gate itself rather than read from a wire.
    BaseFieldPow {
        x: Label,
        exp: BoxedUint,
    },
    /// Converts a plaintext bit wire into a plaintext element of `field_type`.
    BitPlaintextToField {
        x: Label,
        field_type: FieldType,
    },
    /// Converts a plaintext element of `field_type` into a plaintext bit wire.
    FieldPlaintextToBit {
        x: Label,
        field_type: FieldType,
    },
    /// Get the element at position `index` of the batched wire `x`.
    BatchGetIndex {
        x: Label,
        x_type: AlgebraicType,
        x_form: ShareOrPlaintext,
        index: usize,
    },
    /// Collects the wires in `wires` (all of type `x_type` and form `x_form`) into a single
    /// batched wire.
    CollectToBatch {
        wires: Vec<Label>,
        x_type: AlgebraicType,
        x_form: ShareOrPlaintext,
    },
    /// Builds a plaintext point from extended Edwards coordinate wires.
    /// NOTE(review): presumably `wires` holds the coordinates in a fixed order — confirm.
    PointFromPlaintextExtendedEdwards {
        wires: Vec<Label>,
    },
    /// Converts the plaintext point wire `point` to extended Edwards coordinates.
    PlaintextPointToExtendedEdwards {
        point: Label,
    },
    /// Applies the Keccak-f[1600] permutation to plaintext wires.
    /// NOTE(review): the expected number and layout of `wires` is not visible here — confirm.
    PlaintextKeccakF1600 {
        wires: Vec<Label>,
    },
    /// Compresses the plaintext point wire `point`.
    CompressPlaintextPoint {
        point: Label,
    },
    /// Plaintext key-recovery step that computes errors from the wires `d_minus_one` and
    /// `syndromes`.
    KeyRecoveryPlaintextComputeErrors {
        d_minus_one: Label,
        syndromes: Label,
    },
}
170
171impl<C: Curve> Gate<C> {
172    /// Gathers all the `Label` inside the gate.
173    pub fn get_labels(&self) -> Vec<Label> {
174        let mut labels = Vec::new();
175        self.for_each_label(|label| labels.push(label));
176        labels
177    }
178    pub fn add_to_required_preprocessing(
179        &self,
180        batched: Batched,
181        circuit_preprocessing: &mut CircuitPreprocessing,
182    ) {
183        match self {
184            Gate::Input { input_type } => match input_type {
185                Input::SecretPlaintext {
186                    algebraic_type: AlgebraicType::ScalarField | AlgebraicType::Point,
187                    batched,
188                    ..
189                } => circuit_preprocessing.scalar_singlets += batched.count(),
190                Input::SecretPlaintext {
191                    algebraic_type: AlgebraicType::BaseField,
192                    batched,
193                    ..
194                } => circuit_preprocessing.base_field_singlets += batched.count(),
195                Input::SecretPlaintext {
196                    algebraic_type: AlgebraicType::Mersenne107,
197                    ..
198                } => circuit_preprocessing.mersenne107_singlets += batched.count(),
199                Input::SecretPlaintext {
200                    algebraic_type: AlgebraicType::Bit,
201                    batched,
202                    ..
203                } => circuit_preprocessing.bit_singlets += batched.count(),
204                Input::RandomShare {
205                    algebraic_type,
206                    batched,
207                } => match algebraic_type {
208                    AlgebraicType::ScalarField | AlgebraicType::Point => {
209                        circuit_preprocessing.scalar_singlets += batched.count();
210                    }
211                    AlgebraicType::BaseField => {
212                        circuit_preprocessing.base_field_singlets += batched.count();
213                    }
214                    AlgebraicType::Bit => {
215                        circuit_preprocessing.bit_singlets += batched.count();
216                    }
217                    AlgebraicType::Mersenne107 => {
218                        circuit_preprocessing.mersenne107_singlets += batched.count();
219                    }
220                },
221                Input::Share { .. }
222                | Input::Scalar { .. }
223                | Input::ScalarBatch { .. }
224                | Input::BaseField { .. }
225                | Input::Mersenne107 { .. }
226                | Input::Bit { .. }
227                | Input::Point { .. }
228                | Input::BaseFieldBatch { .. }
229                | Input::Mersenne107Batch { .. }
230                | Input::BitBatch { .. }
231                | Input::PointBatch { .. } => (),
232            },
233            Gate::FieldShareBinaryOp {
234                y_form: ShareOrPlaintext::Share,
235                op,
236                field_type,
237                ..
238            } => match op {
239                FieldShareBinaryOp::Mul => match field_type {
240                    FieldType::ScalarField => {
241                        circuit_preprocessing.scalar_triples += batched.count()
242                    }
243                    FieldType::BaseField => {
244                        circuit_preprocessing.base_field_triples += batched.count()
245                    }
246                    FieldType::Mersenne107 => {
247                        circuit_preprocessing.mersenne107_triples += batched.count()
248                    }
249                },
250                FieldShareBinaryOp::Add => (),
251            },
252            Gate::FieldShareUnaryOp { op, field_type, .. } => match op {
253                FieldShareUnaryOp::MulInverse => match field_type {
254                    FieldType::ScalarField => {
255                        circuit_preprocessing.scalar_triples += batched.count();
256                        circuit_preprocessing.scalar_singlets += batched.count();
257                    }
258                    FieldType::BaseField => {
259                        circuit_preprocessing.base_field_triples += batched.count();
260                        circuit_preprocessing.base_field_singlets += batched.count();
261                    }
262                    FieldType::Mersenne107 => {
263                        circuit_preprocessing.mersenne107_triples += batched.count();
264                        circuit_preprocessing.mersenne107_singlets += batched.count();
265                    }
266                },
267                FieldShareUnaryOp::IsZero => match field_type {
268                    FieldType::ScalarField => {
269                        circuit_preprocessing.scalar_triples += batched.count();
270                        circuit_preprocessing.scalar_singlets += batched.count();
271                    }
272                    FieldType::BaseField => {
273                        circuit_preprocessing.base_field_triples += batched.count();
274                        circuit_preprocessing.base_field_singlets += batched.count();
275                    }
276                    FieldType::Mersenne107 => {
277                        circuit_preprocessing.mersenne107_triples += batched.count();
278                        circuit_preprocessing.mersenne107_singlets += batched.count();
279                    }
280                },
281                FieldShareUnaryOp::Open | FieldShareUnaryOp::Neg => (),
282            },
283            Gate::PointShareUnaryOp { op, .. } => match op {
284                PointShareUnaryOp::IsZero => {
285                    circuit_preprocessing.scalar_triples += batched.count();
286                    circuit_preprocessing.scalar_singlets += batched.count();
287                }
288                PointShareUnaryOp::Open | PointShareUnaryOp::Neg => (),
289            },
290            Gate::PointShareBinaryOp {
291                p_form: ShareOrPlaintext::Share,
292                y_form: ShareOrPlaintext::Share,
293                op,
294                ..
295            } => match op {
296                PointShareBinaryOp::ScalarMul => {
297                    circuit_preprocessing.scalar_triples += batched.count()
298                }
299                PointShareBinaryOp::Add => (),
300            },
301            Gate::BitShareBinaryOp {
302                y_form: ShareOrPlaintext::Share,
303                op,
304                ..
305            } => match op {
306                BitShareBinaryOp::And | BitShareBinaryOp::Or => {
307                    circuit_preprocessing.bit_triples += batched.count();
308                }
309                BitShareBinaryOp::Xor => (),
310            },
311            Gate::BaseFieldPow { exp, .. } => {
312                *circuit_preprocessing
313                    .base_field_pow_pairs
314                    .entry(exp.clone())
315                    .or_insert(0) += batched.count();
316                circuit_preprocessing.base_field_triples += batched.count();
317            }
318            Gate::DaBit {
319                field_type,
320                batched,
321            } => match field_type {
322                FieldType::ScalarField => {
323                    circuit_preprocessing.scalar_dabits += batched.count();
324                }
325                FieldType::BaseField => {
326                    circuit_preprocessing.base_field_dabits += batched.count();
327                }
328                FieldType::Mersenne107 => {
329                    circuit_preprocessing.mersenne107_dabits += batched.count();
330                }
331            },
332            Gate::BatchSummation { .. }
333            | Gate::BitShareUnaryOp { .. }
334            | Gate::BitShareBinaryOp {
335                y_form: ShareOrPlaintext::Plaintext,
336                ..
337            }
338            | Gate::PointShareBinaryOp {
339                p_form: ShareOrPlaintext::Plaintext,
340                ..
341            }
342            | Gate::PointShareBinaryOp {
343                p_form: ShareOrPlaintext::Share,
344                y_form: ShareOrPlaintext::Plaintext,
345                ..
346            }
347            | Gate::FieldPlaintextUnaryOp { .. }
348            | Gate::FieldPlaintextBinaryOp { .. }
349            | Gate::BitPlaintextUnaryOp { .. }
350            | Gate::BitPlaintextBinaryOp { .. }
351            | Gate::PointPlaintextUnaryOp { .. }
352            | Gate::PointPlaintextBinaryOp { .. }
353            | Gate::GetDaBitFieldShare { .. }
354            | Gate::GetDaBitSharedBit { .. }
355            | Gate::BitPlaintextToField { .. }
356            | Gate::FieldPlaintextToBit { .. }
357            | Gate::BatchGetIndex { .. }
358            | Gate::CollectToBatch { .. }
359            | Gate::PointFromPlaintextExtendedEdwards { .. }
360            | Gate::PlaintextPointToExtendedEdwards { .. }
361            | Gate::PlaintextKeccakF1600 { .. }
362            | Gate::CompressPlaintextPoint { .. }
363            | Gate::FieldShareBinaryOp {
364                y_form: ShareOrPlaintext::Plaintext,
365                ..
366            }
367            | Gate::KeyRecoveryPlaintextComputeErrors { .. } => (),
368        };
369    }
370}
371
372#[derive(Debug, Clone, Default, PartialEq, Eq)]
373pub struct CircuitPreprocessing {
374    pub scalar_singlets: usize,
375    pub scalar_triples: usize,
376    pub base_field_singlets: usize,
377    pub base_field_triples: usize,
378    pub base_field_pow_pairs: HashMap<BoxedUint, usize>,
379    pub bit_singlets: usize,
380    pub bit_triples: usize,
381    pub mersenne107_dabits: usize,
382    pub mersenne107_singlets: usize,
383    pub mersenne107_triples: usize,
384    pub scalar_dabits: usize,
385    pub base_field_dabits: usize,
386}
387
388impl Add for CircuitPreprocessing {
389    type Output = Self;
390
391    fn add(self, other: Self) -> Self::Output {
392        Self {
393            scalar_singlets: self.scalar_singlets + other.scalar_singlets,
394            scalar_triples: self.scalar_triples + other.scalar_triples,
395            base_field_singlets: self.base_field_singlets + other.base_field_singlets,
396            base_field_triples: self.base_field_triples + other.base_field_triples,
397            bit_singlets: self.bit_singlets + other.bit_singlets,
398            bit_triples: self.bit_triples + other.bit_triples,
399            mersenne107_dabits: self.mersenne107_dabits + other.mersenne107_dabits,
400            mersenne107_singlets: self.mersenne107_singlets + other.mersenne107_singlets,
401            mersenne107_triples: self.mersenne107_triples + other.mersenne107_triples,
402            scalar_dabits: self.scalar_dabits + other.scalar_dabits,
403            base_field_dabits: self.base_field_dabits + other.base_field_dabits,
404            base_field_pow_pairs: {
405                let mut combined = self.base_field_pow_pairs;
406                for (k, v) in other.base_field_pow_pairs {
407                    *combined.entry(k).or_insert(0) += v;
408                }
409                combined
410            },
411        }
412    }
413}
414
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use primitives::algebra::elliptic_curve::Curve25519Ristretto as C;

    use super::*;
    use crate::circuit::FieldShareBinaryOp;

    /// Encodes `gate` with bincode and decodes it again.
    fn bincode_round_trip(gate: &Gate<C>) -> Gate<C> {
        let bytes = bincode::serialize(gate).unwrap();
        bincode::deserialize(&bytes).unwrap()
    }

    #[test]
    fn test_ser_gate() {
        let field_add = |y_form: ShareOrPlaintext| -> Gate<C> {
            Gate::FieldShareBinaryOp {
                x: Label::from(1, 2),
                y: Label::from(3, 4),
                y_form,
                op: FieldShareBinaryOp::Add,
                field_type: FieldType::ScalarField,
            }
        };
        let gates: [Gate<C>; 3] = [
            field_add(ShareOrPlaintext::Share),
            field_add(ShareOrPlaintext::Plaintext),
            Gate::PointShareBinaryOp {
                p: Label::from(1, 2),
                y: Label::from(3, 4),
                p_form: ShareOrPlaintext::Share,
                y_form: ShareOrPlaintext::Plaintext,
                op: PointShareBinaryOp::Add,
            },
        ];

        // Each gate must survive the round trip, and a decoded copy must hash like its
        // original, so the set ends up holding exactly the three distinct gates.
        let mut distinct = HashSet::new();
        for gate in gates {
            let decoded = bincode_round_trip(&gate);
            assert_eq!(gate, decoded);
            distinct.insert(gate);
            distinct.insert(decoded);
        }
        assert_eq!(distinct.len(), 3);
    }

    #[test]
    fn test_circuit_preprocessing_add() {
        let lhs = CircuitPreprocessing {
            scalar_singlets: 1,
            scalar_triples: 2,
            base_field_singlets: 3,
            base_field_triples: 4,
            bit_triples: 1,
            scalar_dabits: 1,
            base_field_dabits: 2,
            base_field_pow_pairs: HashMap::from([
                (BoxedUint::from(vec![21]), 5),
                (BoxedUint::from(vec![14]), 6),
            ]),
            ..Default::default()
        };
        let rhs = CircuitPreprocessing {
            scalar_singlets: 2,
            scalar_triples: 3,
            base_field_triples: 5,
            bit_singlets: 3,
            bit_triples: 4,
            mersenne107_singlets: 3,
            mersenne107_triples: 2,
            scalar_dabits: 2,
            base_field_dabits: 3,
            base_field_pow_pairs: HashMap::from([
                (BoxedUint::from(vec![21]), 6),
                (BoxedUint::from(vec![13]), 7),
            ]),
            ..Default::default()
        };

        // Counters add up; the shared exponent 21 merges (5 + 6), the others pass through.
        let expected = CircuitPreprocessing {
            scalar_singlets: 3,
            scalar_triples: 5,
            base_field_singlets: 3,
            base_field_triples: 9,
            bit_singlets: 3,
            bit_triples: 5,
            mersenne107_dabits: 0,
            mersenne107_singlets: 3,
            mersenne107_triples: 2,
            scalar_dabits: 3,
            base_field_dabits: 5,
            base_field_pow_pairs: HashMap::from([
                (BoxedUint::from(vec![21]), 11),
                (BoxedUint::from(vec![14]), 6),
                (BoxedUint::from(vec![13]), 7),
            ]),
        };
        assert_eq!(lhs + rhs, expected);
    }
}