core_utils/circuit/gate.rs

1use std::{collections::HashMap, ops::Add};
2
3use macros::GateMethods;
4use primitives::algebra::{
5    elliptic_curve::{Curve, Point, Scalar},
6    BoxedUint,
7};
8use serde::{Deserialize, Serialize};
9
10use crate::{
11    circuit::{
12        AlgebraicType,
13        Batched,
14        BitShareBinaryOp,
15        BitShareUnaryOp,
16        FieldPlaintextBinaryOp,
17        FieldPlaintextUnaryOp,
18        FieldShareBinaryOp,
19        FieldShareUnaryOp,
20        FieldType,
21        Input,
22        PointPlaintextBinaryOp,
23        PointPlaintextUnaryOp,
24        PointShareBinaryOp,
25        PointShareUnaryOp,
26        ShareOrPlaintext,
27    },
28    types::Label,
29};
30
/// Gate operations, where the operation arguments correspond to _wire_ labels.
///
/// Variant and field names are consumed by serde and by the `GateMethods` derive
/// (presumably the source of the `for_each_label` visitor used by `get_labels`).
/// NOTE(review): with index-based encodings such as bincode, the variant order
/// determines the encoded tag, so reordering variants is a wire-format change.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, GateMethods)]
// Bound the serde impls on the curve-dependent value types rather than on `C`
// itself; without this, serde's derive would infer `C: Serialize` /
// `C: Deserialize<'de>` for the type parameter.
#[serde(bound(
    serialize = "Scalar<C>: Serialize, Point<C>: Serialize",
    deserialize = "Scalar<C>: Deserialize<'de>, Point<C>: Deserialize<'de>"
))]
pub enum Gate<C: Curve> {
    /// Input a wire; `input_type` describes the kind of input.
    Input {
        input_type: Input<C>,
    },
    /// Field share unary operation `op` on wire `x`, in the field selected by
    /// `field_type`.
    FieldShareUnaryOp {
        x: Label,
        op: FieldShareUnaryOp,
        field_type: FieldType,
    },
    /// Field share binary operation `op` on wires `x` and `y`, where the second
    /// wire may be a plaintext (`y_form` records which).
    FieldShareBinaryOp {
        x: Label,
        y: Label,
        y_form: ShareOrPlaintext,
        op: FieldShareBinaryOp,
        field_type: FieldType,
    },
    /// Summation over the batched wire `x`, which is a share or a plaintext per
    /// `x_form`; `algebraic_type` gives the element type.
    /// NOTE(review): semantics inferred from the name (presumably sums the batch
    /// elements) — confirm against the evaluator.
    BatchSummation {
        x: Label,
        x_form: ShareOrPlaintext,
        algebraic_type: AlgebraicType,
    },
    /// Bit share unary operation `op` on wire `x`.
    BitShareUnaryOp {
        x: Label,
        op: BitShareUnaryOp,
    },
    /// Bit share binary operation `op` on wires `x` and `y`, where `y` may be a
    /// share or a plaintext (`y_form` records which).
    BitShareBinaryOp {
        x: Label,
        y: Label,
        y_form: ShareOrPlaintext,
        op: BitShareBinaryOp,
    },
    /// Elliptic curve point share unary operation `op` on wire `p`.
    PointShareUnaryOp {
        p: Label,
        op: PointShareUnaryOp,
    },
    /// Elliptic curve point binary operation `op` on wires `p` and `y`; each
    /// operand may be a share or a plaintext (`p_form` / `y_form`).
    PointShareBinaryOp {
        p: Label,
        y: Label,
        p_form: ShareOrPlaintext,
        y_form: ShareOrPlaintext,
        op: PointShareBinaryOp,
    },
    /// Field plaintext unary operation `op` on wire `x`, in the field selected by
    /// `field_type`.
    FieldPlaintextUnaryOp {
        x: Label,
        op: FieldPlaintextUnaryOp,
        field_type: FieldType,
    },
    /// Field plaintext binary operation `op` on wires `x` and `y`, in the field
    /// selected by `field_type`.
    FieldPlaintextBinaryOp {
        x: Label,
        y: Label,
        op: FieldPlaintextBinaryOp,
        field_type: FieldType,
    },
    /// Bit plaintext unary operation on wire `x`.
    ///
    /// NOTE(review): `op` is typed as `FieldPlaintextUnaryOp` (the field op enum),
    /// not a bit-specific one — confirm this is intentional.
    BitPlaintextUnaryOp {
        x: Label,
        op: FieldPlaintextUnaryOp,
    },
    /// Bit plaintext binary operation on wires `x` and `y`.
    ///
    /// NOTE(review): `op` is typed as `FieldPlaintextBinaryOp` (the field op enum),
    /// not a bit-specific one — confirm this is intentional.
    BitPlaintextBinaryOp {
        x: Label,
        y: Label,
        op: FieldPlaintextBinaryOp,
    },
    /// Elliptic curve point plaintext unary operation `op` on wire `p`.
    PointPlaintextUnaryOp {
        p: Label,
        op: PointPlaintextUnaryOp,
    },
    /// Elliptic curve point plaintext binary operation `op` on wires `p` and `y`.
    PointPlaintextBinaryOp {
        p: Label,
        y: Label,
        op: PointPlaintextBinaryOp,
    },
    /// Request a daBit in the field selected by `field_type`; `batched` carries
    /// the batching setting.
    DaBit {
        field_type: FieldType,
        batched: Batched,
    },
    /// Field share component of the daBit wire `x` (presumably the output of a
    /// `DaBit` gate — confirm).
    GetDaBitFieldShare {
        x: Label,
        field_type: FieldType,
    },
    /// Shared bit component of the daBit wire `x` (presumably the output of a
    /// `DaBit` gate — confirm).
    GetDaBitSharedBit {
        x: Label,
        field_type: FieldType,
    },
    /// ElGamal encryption of wire `x`.
    ///
    /// `c` is a curve point stored in the gate itself rather than read from a wire.
    /// NOTE(review): presumably the encryption public key — confirm.
    EncryptPoint {
        x: Label,
        c: Point<C>,
    },
    /// Operates on wires `x` and `y`.
    /// NOTE(review): presumably the ElGamal decryption counterpart of
    /// `EncryptPoint` — confirm which operand holds what.
    DecryptPoint {
        x: Label,
        y: Label,
    },
    /// Base field exponentiation of wire `x` by the exponent `exp`, which is stored
    /// in the gate rather than read from a wire.
    BaseFieldPow {
        x: Label,
        exp: BoxedUint,
    },
    /// Bit plaintext conversion: bit plaintext wire `x` to a plaintext of the field
    /// selected by `field_type`.
    BitPlaintextToField {
        x: Label,
        field_type: FieldType,
    },
    /// Field plaintext conversion: field plaintext wire `x` (in `field_type`) to a
    /// bit plaintext.
    FieldPlaintextToBit {
        x: Label,
        field_type: FieldType,
    },
    /// Get the element at position `index` of the batched wire `x`, whose element
    /// type and form are given by `x_type` / `x_form`.
    BatchGetIndex {
        x: Label,
        x_type: AlgebraicType,
        x_form: ShareOrPlaintext,
        index: usize,
    },
    /// Collects the wires in `wires` into one batched wire with element type and
    /// form `x_type` / `x_form` (presumably the inverse of `BatchGetIndex`).
    CollectToBatch {
        wires: Vec<Label>,
        x_type: AlgebraicType,
        x_form: ShareOrPlaintext,
    },
    /// Builds a plaintext point from the coordinate wires in `wires`, given in
    /// extended Edwards form.
    /// NOTE(review): expected number/order of coordinates is not visible here.
    PointFromPlaintextExtendedEdwards {
        wires: Vec<Label>,
    },
    /// Converts the plaintext point wire `point` to extended Edwards form
    /// (presumably the inverse of `PointFromPlaintextExtendedEdwards`).
    PlaintextPointToExtendedEdwards {
        point: Label,
    },
    /// Applies the `Keccak-f[1600]` permutation to the plaintext state held in
    /// `wires`.
    /// NOTE(review): lane count and encoding are not visible here — confirm.
    PlaintextKeccakF1600 {
        wires: Vec<Label>,
    },
}
172
173impl<C: Curve> Gate<C> {
174    /// Gathers all the `Label` inside the gate.
175    pub fn get_labels(&self) -> Vec<Label> {
176        let mut labels = Vec::new();
177        self.for_each_label(|label| labels.push(label));
178        labels
179    }
180}
181
/// Tally of the preprocessing material associated with a circuit.
///
/// Every counter is a `usize` count. Field names pair a domain (`scalar`,
/// `base_field`, `bit`, `mersenne107`) with a kind of material (`singlets`,
/// `triples`, `dabits`).
/// NOTE(review): presumably random shared values, multiplication triples and
/// daBits respectively — confirm with the preprocessing generator.
///
/// Tallies combine with `+`, which adds the counters field by field.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct CircuitPreprocessing {
    pub scalar_singlets: usize,
    pub scalar_triples: usize,
    pub base_field_singlets: usize,
    pub base_field_triples: usize,
    /// Number of pow pairs required per key. `+` adds the counts stored under
    /// equal keys. NOTE(review): the key is presumably the exponent of a
    /// `Gate::BaseFieldPow` (both are `BoxedUint`) — confirm.
    pub base_field_pow_pairs: HashMap<BoxedUint, usize>,
    pub bit_singlets: usize,
    pub bit_triples: usize,
    pub mersenne107_dabits: usize,
    pub mersenne107_singlets: usize,
    pub mersenne107_triples: usize,
    pub scalar_dabits: usize,
    pub base_field_dabits: usize,
}
197
198impl Add for CircuitPreprocessing {
199    type Output = Self;
200
201    fn add(self, other: Self) -> Self::Output {
202        Self {
203            scalar_singlets: self.scalar_singlets + other.scalar_singlets,
204            scalar_triples: self.scalar_triples + other.scalar_triples,
205            base_field_singlets: self.base_field_singlets + other.base_field_singlets,
206            base_field_triples: self.base_field_triples + other.base_field_triples,
207            bit_singlets: self.bit_singlets + other.bit_singlets,
208            bit_triples: self.bit_triples + other.bit_triples,
209            mersenne107_dabits: self.mersenne107_dabits + other.mersenne107_dabits,
210            mersenne107_singlets: self.mersenne107_singlets + other.mersenne107_singlets,
211            mersenne107_triples: self.mersenne107_triples + other.mersenne107_triples,
212            scalar_dabits: self.scalar_dabits + other.scalar_dabits,
213            base_field_dabits: self.base_field_dabits + other.base_field_dabits,
214            base_field_pow_pairs: {
215                let mut combined = self.base_field_pow_pairs;
216                for (k, v) in other.base_field_pow_pairs {
217                    *combined.entry(k).or_insert(0) += v;
218                }
219                combined
220            },
221        }
222    }
223}
224
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use primitives::algebra::elliptic_curve::Curve25519Ristretto as C;

    // The glob also brings in the parent module's imports (`Label`, `FieldType`,
    // `FieldShareBinaryOp`, `HashMap`, ...), so no per-type `use` lines are needed.
    use super::*;

    /// Gates survive a bincode round trip unchanged, and each decoded gate hashes
    /// and compares equal to the gate it was encoded from.
    #[test]
    fn test_ser_gate() {
        // Field-share binary op whose second operand is also a share.
        let share_share_gate: Gate<C> = Gate::FieldShareBinaryOp {
            x: Label::from(1, 2),
            y: Label::from(3, 4),
            y_form: ShareOrPlaintext::Share,
            op: FieldShareBinaryOp::Add,
            field_type: FieldType::ScalarField,
        };
        // Same operation, but with a plaintext second operand.
        let share_plaintext_gate: Gate<C> = Gate::FieldShareBinaryOp {
            x: Label::from(1, 2),
            y: Label::from(3, 4),
            y_form: ShareOrPlaintext::Plaintext,
            op: FieldShareBinaryOp::Add,
            field_type: FieldType::ScalarField,
        };
        // Point-share binary op mixing a share and a plaintext operand.
        let point_gate: Gate<C> = Gate::PointShareBinaryOp {
            p: Label::from(1, 2),
            y: Label::from(3, 4),
            p_form: ShareOrPlaintext::Share,
            y_form: ShareOrPlaintext::Plaintext,
            op: PointShareBinaryOp::Add,
        };

        let mut seen = HashSet::new();
        for gate in [share_share_gate, share_plaintext_gate, point_gate] {
            let bytes = bincode::serialize(&gate).unwrap();
            let decoded: Gate<C> = bincode::deserialize(&bytes).unwrap();
            assert_eq!(gate, decoded);
            seen.insert(gate);
            seen.insert(decoded);
        }
        // Every decoded gate collapses onto its original, leaving three distinct gates.
        assert_eq!(seen.len(), 3);
    }

    /// `+` sums every counter and merges `base_field_pow_pairs` key by key.
    #[test]
    fn test_circuit_preprocessing_add() {
        let a = CircuitPreprocessing {
            scalar_singlets: 1,
            scalar_triples: 2,
            base_field_singlets: 3,
            base_field_triples: 4,
            bit_singlets: 0,
            bit_triples: 1,
            mersenne107_dabits: 0,
            mersenne107_singlets: 0,
            mersenne107_triples: 0,
            scalar_dabits: 1,
            base_field_dabits: 2,
            base_field_pow_pairs: HashMap::from([
                (BoxedUint::from(vec![21]), 5),
                (BoxedUint::from(vec![14]), 6),
            ]),
        };
        let b = CircuitPreprocessing {
            scalar_singlets: 2,
            scalar_triples: 3,
            base_field_singlets: 0,
            base_field_triples: 5,
            bit_singlets: 3,
            bit_triples: 4,
            mersenne107_dabits: 0,
            mersenne107_singlets: 3,
            mersenne107_triples: 2,
            scalar_dabits: 2,
            base_field_dabits: 3,
            base_field_pow_pairs: HashMap::from([
                (BoxedUint::from(vec![21]), 6),
                (BoxedUint::from(vec![13]), 7),
            ]),
        };
        let expected = CircuitPreprocessing {
            scalar_singlets: 3,
            scalar_triples: 5,
            base_field_singlets: 3,
            base_field_triples: 9,
            bit_singlets: 3,
            bit_triples: 5,
            mersenne107_dabits: 0,
            mersenne107_singlets: 3,
            mersenne107_triples: 2,
            scalar_dabits: 3,
            base_field_dabits: 5,
            base_field_pow_pairs: HashMap::from([
                (BoxedUint::from(vec![21]), 11),
                (BoxedUint::from(vec![14]), 6),
                (BoxedUint::from(vec![13]), 7),
            ]),
        };

        // Whole-struct comparison also catches unexpected extra pow-pair keys,
        // which per-key lookups would miss.
        assert_eq!(a + b, expected);
    }
}