core_utils/circuit/
gate.rs

1use std::{collections::HashMap, ops::Add};
2
3use macros::GateMethods;
4use primitives::algebra::{
5    elliptic_curve::{Curve, Point, Scalar},
6    BoxedUint,
7};
8use serde::{Deserialize, Serialize};
9
10use crate::{
11    circuit::{
12        AlgebraicType,
13        Batched,
14        BitShareBinaryOp,
15        BitShareUnaryOp,
16        FieldPlaintextBinaryOp,
17        FieldPlaintextUnaryOp,
18        FieldShareBinaryOp,
19        FieldShareUnaryOp,
20        FieldType,
21        Input,
22        PointPlaintextBinaryOp,
23        PointPlaintextUnaryOp,
24        PointShareBinaryOp,
25        PointShareUnaryOp,
26        ShareOrPlaintext,
27    },
28    types::Label,
29};
30
/// Gate operations, where the operation arguments correspond to _wire_ labels.
///
/// Most variants name the wires they read through [`Label`] fields (or a
/// `Vec<Label>`), together with an operation selector and type/form tags
/// (`FieldType`, `AlgebraicType`, `ShareOrPlaintext`) describing those wires.
/// The enum is generic over the curve `C` because some variants embed
/// curve-dependent data (`Input<C>`, `Point<C>`).
///
/// The explicit `serde(bound(...))` attribute replaces the bounds serde would
/// otherwise infer for `C`: (de)serialization only requires `Scalar<C>` and
/// `Point<C>` to be (de)serializable.
///
/// NOTE(review): variant order is presumably significant for index-based
/// encodings such as bincode — confirm before reordering or inserting
/// variants in the middle.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, GateMethods)]
#[serde(bound(
    serialize = "Scalar<C>: Serialize, Point<C>: Serialize",
    deserialize = "Scalar<C>: Deserialize<'de>, Point<C>: Deserialize<'de>"
))]
pub enum Gate<C: Curve> {
    /// Input a wire
    Input {
        input_type: Input<C>,
    },
    /// Field share unary operations
    FieldShareUnaryOp {
        x: Label,
        op: FieldShareUnaryOp,
        field_type: FieldType,
    },
    /// Field share binary operations, where the second wire may be a plaintext.
    FieldShareBinaryOp {
        x: Label,
        y: Label,
        y_form: ShareOrPlaintext,
        op: FieldShareBinaryOp,
        field_type: FieldType,
    },
    /// Summation over the batched wire `x`, which may be a share or a
    /// plaintext (`x_form`).
    /// NOTE(review): presumably sums all elements of the batch — confirm.
    BatchSummation {
        x: Label,
        x_form: ShareOrPlaintext,
        algebraic_type: AlgebraicType,
    },
    /// Bit share unary operations
    BitShareUnaryOp {
        x: Label,
        op: BitShareUnaryOp,
    },
    /// Bit share binary operations; `y_form` tags the second wire as a share
    /// or a plaintext.
    BitShareBinaryOp {
        x: Label,
        y: Label,
        y_form: ShareOrPlaintext,
        op: BitShareBinaryOp,
    },
    /// Operations with elliptic curve points
    PointShareUnaryOp {
        p: Label,
        op: PointShareUnaryOp,
    },
    /// Point share binary operations; each operand (`p`, `y`) carries its own
    /// share-or-plaintext tag.
    PointShareBinaryOp {
        p: Label,
        y: Label,
        p_form: ShareOrPlaintext,
        y_form: ShareOrPlaintext,
        op: PointShareBinaryOp,
    },
    /// Field plaintext unary operations
    FieldPlaintextUnaryOp {
        x: Label,
        op: FieldPlaintextUnaryOp,
        field_type: FieldType,
    },
    /// Field plaintext binary operations
    FieldPlaintextBinaryOp {
        x: Label,
        y: Label,
        op: FieldPlaintextBinaryOp,
        field_type: FieldType,
    },
    /// Bit plaintext unary operations.
    /// NOTE(review): `op` reuses the `FieldPlaintextUnaryOp` selector rather
    /// than a bit-specific type — confirm this is intended.
    BitPlaintextUnaryOp {
        x: Label,
        op: FieldPlaintextUnaryOp,
    },
    /// Bit plaintext binary operations.
    /// NOTE(review): `op` reuses the `FieldPlaintextBinaryOp` selector rather
    /// than a bit-specific type — confirm this is intended.
    BitPlaintextBinaryOp {
        x: Label,
        y: Label,
        op: FieldPlaintextBinaryOp,
    },
    /// Point plaintext unary operations
    PointPlaintextUnaryOp {
        p: Label,
        op: PointPlaintextUnaryOp,
    },
    /// Point plaintext binary operations
    PointPlaintextBinaryOp {
        p: Label,
        y: Label,
        op: PointPlaintextBinaryOp,
    },
    /// Request a daBit
    DaBit {
        field_type: FieldType,
        batched: Batched,
    },
    /// Reads the daBit on wire `x`.
    /// NOTE(review): per the variant name this yields the field-share
    /// component of the daBit — confirm.
    GetDaBitFieldShare {
        x: Label,
        field_type: FieldType,
    },
    /// Reads the daBit on wire `x`.
    /// NOTE(review): per the variant name this yields the shared-bit
    /// component of the daBit — confirm.
    GetDaBitSharedBit {
        x: Label,
        field_type: FieldType,
    },
    /// ElGamal Encryption
    ///
    /// `c` is a curve point stored directly in the gate rather than read from
    /// a wire. NOTE(review): presumably the encryption key — confirm.
    EncryptPoint {
        x: Label,
        c: Point<C>,
    },
    /// Decryption counterpart of `EncryptPoint`, reading wires `x` and `y`.
    /// NOTE(review): the roles of `x` and `y` are not visible here — confirm.
    DecryptPoint {
        x: Label,
        y: Label,
    },
    /// Base field exponentiation operation
    BaseFieldPow {
        x: Label,
        exp: BoxedUint,
    },
    /// Bit plaintext conversion operations
    BitPlaintextToField {
        x: Label,
        field_type: FieldType,
    },
    /// Conversion in the opposite direction of `BitPlaintextToField` (per the
    /// variant name): field plaintext to bit plaintext.
    FieldPlaintextToBit {
        x: Label,
        field_type: FieldType,
    },
    /// Get the element at a certain index of a batched wire
    BatchGetIndex {
        x: Label,
        x_type: AlgebraicType,
        x_form: ShareOrPlaintext,
        index: usize,
    },
    /// Collects the listed `wires` into one batched wire (per the variant
    /// name); `x_type` and `x_form` tag the elements.
    CollectToBatch {
        wires: Vec<Label>,
        x_type: AlgebraicType,
        x_form: ShareOrPlaintext,
    },
    /// Builds a point from plaintext `wires`.
    /// NOTE(review): the name suggests extended Edwards coordinates with no
    /// validity check — confirm the expected number and order of wires.
    PointFromPlaintextExtendedEdwardsUnchecked {
        wires: Vec<Label>,
    },
    /// Converts the plaintext point on wire `point`.
    /// NOTE(review): the name suggests the output is its extended Edwards
    /// coordinates — confirm.
    PlaintextPointToExtendedEdwards {
        point: Label,
    },
    /// Keccak-f\[1600\] over plaintext `wires` (per the variant name).
    /// NOTE(review): wire count and layout are not visible here — confirm.
    PlaintextKeccakF1600 {
        wires: Vec<Label>,
    },
}
172
173impl<C: Curve> Gate<C> {
174    /// Gathers all the `Label` inside the gate.
175    pub fn get_labels(&self) -> Vec<Label> {
176        let mut labels = Vec::new();
177        self.for_each_label(|label| labels.push(label));
178        labels
179    }
180}
181
/// Per-kind counters of preprocessing items for a circuit.
///
/// The derived `Default` gives all-zero counters and an empty map. Two values
/// can be combined with `+`, which sums every counter and merges the map by
/// key.
///
/// NOTE(review): the meaning of "singlets", "triples", "pow pairs" and
/// "dabits" is inferred from the field names only — confirm against the
/// protocol documentation.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct CircuitPreprocessing {
    /// Number of scalar-field singlets.
    pub scalar_singlets: usize,
    /// Number of scalar-field triples.
    pub scalar_triples: usize,
    /// Number of base-field singlets.
    pub base_field_singlets: usize,
    /// Number of base-field triples.
    pub base_field_triples: usize,
    /// Number of base-field pow pairs, counted separately per `BoxedUint` key.
    /// NOTE(review): the key is presumably the exponent of a
    /// `Gate::BaseFieldPow` — confirm.
    pub base_field_pow_pairs: HashMap<BoxedUint, usize>,
    /// Number of bit singlets.
    pub bit_singlets: usize,
    /// Number of bit triples.
    pub bit_triples: usize,
    /// Number of scalar-field daBits.
    pub scalar_dabits: usize,
    /// Number of base-field daBits.
    pub base_field_dabits: usize,
}
194
195impl Add for CircuitPreprocessing {
196    type Output = Self;
197
198    fn add(self, other: Self) -> Self::Output {
199        Self {
200            scalar_singlets: self.scalar_singlets + other.scalar_singlets,
201            scalar_triples: self.scalar_triples + other.scalar_triples,
202            base_field_singlets: self.base_field_singlets + other.base_field_singlets,
203            base_field_triples: self.base_field_triples + other.base_field_triples,
204            bit_singlets: self.bit_singlets + other.bit_singlets,
205            bit_triples: self.bit_triples + other.bit_triples,
206            scalar_dabits: self.scalar_dabits + other.scalar_dabits,
207            base_field_dabits: self.base_field_dabits + other.base_field_dabits,
208            base_field_pow_pairs: {
209                let mut combined = self.base_field_pow_pairs;
210                for (k, v) in other.base_field_pow_pairs {
211                    *combined.entry(k).or_insert(0) += v;
212                }
213                combined
214            },
215        }
216    }
217}
218
#[cfg(test)]
mod tests {
    use primitives::algebra::elliptic_curve::Curve25519Ristretto as C;

    use super::*;
    use crate::circuit::FieldShareBinaryOp;

    /// Serializes `gate` with bincode and deserializes the bytes back.
    fn roundtrip(gate: &Gate<C>) -> Gate<C> {
        let bytes = bincode::serialize(gate).unwrap();
        bincode::deserialize(&bytes).unwrap()
    }

    /// Gates must survive a bincode serialize/deserialize round trip unchanged.
    #[test]
    fn test_ser_gate() {
        // Field gate whose second operand is a share.
        let no_curve_gate: Gate<C> = Gate::FieldShareBinaryOp {
            x: Label::from(1, 2),
            y: Label::from(3, 4),
            y_form: ShareOrPlaintext::Share,
            op: FieldShareBinaryOp::Add,
            field_type: FieldType::ScalarField,
        };
        // Same gate, but with a plaintext second operand.
        let scalar_gate: Gate<C> = Gate::FieldShareBinaryOp {
            x: Label::from(1, 2),
            y: Label::from(3, 4),
            y_form: ShareOrPlaintext::Plaintext,
            op: FieldShareBinaryOp::Add,
            field_type: FieldType::ScalarField,
        };
        // Point gate mixing a shared and a plaintext operand.
        let point_gate: Gate<C> = Gate::PointShareBinaryOp {
            p: Label::from(1, 2),
            y: Label::from(3, 4),
            p_form: ShareOrPlaintext::Share,
            y_form: ShareOrPlaintext::Plaintext,
            op: PointShareBinaryOp::Add,
        };

        let no_curve_gate_de = roundtrip(&no_curve_gate);
        let scalar_gate_de = roundtrip(&scalar_gate);
        let point_gate_de = roundtrip(&point_gate);

        assert_eq!(no_curve_gate, no_curve_gate_de);
        assert_eq!(scalar_gate, scalar_gate_de);
        assert_eq!(point_gate, point_gate_de);
    }

    /// `+` sums every counter and merges the pow-pair maps key by key.
    #[test]
    fn test_circuit_preprocessing_add() {
        let lhs = CircuitPreprocessing {
            scalar_singlets: 1,
            scalar_triples: 2,
            base_field_singlets: 3,
            base_field_triples: 4,
            bit_singlets: 0,
            bit_triples: 1,
            scalar_dabits: 1,
            base_field_dabits: 2,
            base_field_pow_pairs: HashMap::from([
                (BoxedUint::from(vec![21]), 5),
                (BoxedUint::from(vec![14]), 6),
            ]),
        };
        let rhs = CircuitPreprocessing {
            scalar_singlets: 2,
            scalar_triples: 3,
            base_field_singlets: 0,
            base_field_triples: 5,
            bit_singlets: 3,
            bit_triples: 4,
            scalar_dabits: 2,
            base_field_dabits: 3,
            base_field_pow_pairs: HashMap::from([
                (BoxedUint::from(vec![21]), 6),
                (BoxedUint::from(vec![13]), 7),
            ]),
        };

        let total = lhs + rhs;

        // Plain counters are summed pairwise.
        assert_eq!(total.scalar_singlets, 3);
        assert_eq!(total.scalar_triples, 5);
        assert_eq!(total.base_field_singlets, 3);
        assert_eq!(total.base_field_triples, 9);
        assert_eq!(total.bit_singlets, 3);
        assert_eq!(total.bit_triples, 5);
        assert_eq!(total.scalar_dabits, 3);
        assert_eq!(total.base_field_dabits, 5);

        // Key shared by both operands: counts add up (5 + 6).
        let shared_key = BoxedUint::from(vec![21]);
        assert_eq!(total.base_field_pow_pairs.get(&shared_key), Some(&11));

        // Keys present on one side only keep their own count.
        let left_only_key = BoxedUint::from(vec![14]);
        assert_eq!(total.base_field_pow_pairs.get(&left_only_key), Some(&6));
        let right_only_key = BoxedUint::from(vec![13]);
        assert_eq!(total.base_field_pow_pairs.get(&right_only_key), Some(&7));
    }
}