// core_utils/circuit/gate.rs

1use std::{collections::HashMap, ops::Add};
2
3use macros::GateMethods;
4use primitives::algebra::{
5    elliptic_curve::{Curve, Point, Scalar},
6    BoxedUint,
7};
8use serde::{Deserialize, Serialize};
9
10use crate::{
11    circuit::{
12        AlgebraicType,
13        Batched,
14        BitShareBinaryOp,
15        BitShareUnaryOp,
16        FieldPlaintextBinaryOp,
17        FieldPlaintextUnaryOp,
18        FieldShareBinaryOp,
19        FieldShareUnaryOp,
20        FieldType,
21        Input,
22        PointPlaintextBinaryOp,
23        PointPlaintextUnaryOp,
24        PointShareBinaryOp,
25        PointShareUnaryOp,
26        ShareOrPlaintext,
27    },
28    types::Label,
29};
30
/// Gate operations, where the operation arguments correspond to _wire_ labels.
///
/// Each [`Label`] field names a wire the gate operates on. Fields of type [`ShareOrPlaintext`]
/// (`*_form`) state whether the corresponding wire holds a secret share or a plaintext.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, GateMethods)]
// Explicit serde bounds on `Scalar<C>` / `Point<C>` replace the bounds the derive would infer
// for `C`. NOTE(review): presumably so that the curve marker type `C` itself need not implement
// the serde traits — confirm.
#[serde(bound(
    serialize = "Scalar<C>: Serialize, Point<C>: Serialize",
    deserialize = "Scalar<C>: Deserialize<'de>, Point<C>: Deserialize<'de>"
))]
pub enum Gate<C: Curve> {
    /// Input a wire, described by `input_type`.
    Input {
        input_type: Input<C>,
    },
    /// Field share unary operation `op` applied to wire `x`, over the field `field_type`.
    FieldShareUnaryOp {
        x: Label,
        op: FieldShareUnaryOp,
        field_type: FieldType,
    },
    /// Field share binary operation `op` on wires `x` and `y`, over the field `field_type`.
    ///
    /// The second wire may be a plaintext; `y_form` says which.
    FieldShareBinaryOp {
        x: Label,
        y: Label,
        y_form: ShareOrPlaintext,
        op: FieldShareBinaryOp,
        field_type: FieldType,
    },
    /// Summation over the batched wire `x` of type `algebraic_type`; `x_form` says whether `x`
    /// is a share or a plaintext.
    ///
    /// NOTE(review): presumably sums the elements of the batch into a single value — confirm.
    BatchSummation {
        x: Label,
        x_form: ShareOrPlaintext,
        algebraic_type: AlgebraicType,
    },
    /// Bit share unary operation `op` applied to wire `x`.
    BitShareUnaryOp {
        x: Label,
        op: BitShareUnaryOp,
    },
    /// Bit share binary operation `op` on wires `x` and `y`; `y_form` says whether `y` is a
    /// share or a plaintext.
    BitShareBinaryOp {
        x: Label,
        y: Label,
        y_form: ShareOrPlaintext,
        op: BitShareBinaryOp,
    },
    /// Elliptic curve point share unary operation `op` applied to wire `p`.
    PointShareUnaryOp {
        p: Label,
        op: PointShareUnaryOp,
    },
    /// Elliptic curve point share binary operation `op` on wires `p` and `y`.
    ///
    /// Each operand's form is given independently by `p_form` and `y_form`.
    PointShareBinaryOp {
        p: Label,
        y: Label,
        p_form: ShareOrPlaintext,
        y_form: ShareOrPlaintext,
        op: PointShareBinaryOp,
    },
    /// Field plaintext unary operation `op` applied to wire `x`, over the field `field_type`.
    FieldPlaintextUnaryOp {
        x: Label,
        op: FieldPlaintextUnaryOp,
        field_type: FieldType,
    },
    /// Field plaintext binary operation `op` on wires `x` and `y`, over the field `field_type`.
    FieldPlaintextBinaryOp {
        x: Label,
        y: Label,
        op: FieldPlaintextBinaryOp,
        field_type: FieldType,
    },
    /// Bit plaintext unary operation `op` applied to wire `x`.
    ///
    /// NOTE(review): `op` reuses [`FieldPlaintextUnaryOp`] rather than a bit-specific op type —
    /// confirm this is intentional.
    BitPlaintextUnaryOp {
        x: Label,
        op: FieldPlaintextUnaryOp,
    },
    /// Bit plaintext binary operation `op` on wires `x` and `y`.
    ///
    /// NOTE(review): `op` reuses [`FieldPlaintextBinaryOp`] rather than a bit-specific op type —
    /// confirm this is intentional.
    BitPlaintextBinaryOp {
        x: Label,
        y: Label,
        op: FieldPlaintextBinaryOp,
    },
    /// Elliptic curve point plaintext unary operation `op` applied to wire `p`.
    PointPlaintextUnaryOp {
        p: Label,
        op: PointPlaintextUnaryOp,
    },
    /// Elliptic curve point plaintext binary operation `op` on wires `p` and `y`.
    PointPlaintextBinaryOp {
        p: Label,
        y: Label,
        op: PointPlaintextBinaryOp,
    },
    /// Request a daBit for the field `field_type`; `batched` describes its batching.
    DaBit {
        field_type: FieldType,
        batched: Batched,
    },
    /// Field-share component of the daBit wire `x`, for the field `field_type`.
    ///
    /// NOTE(review): `x` is presumably the output of a [`Gate::DaBit`] gate — confirm.
    GetDaBitFieldShare {
        x: Label,
        field_type: FieldType,
    },
    /// Shared-bit component of the daBit wire `x`, for the field `field_type`.
    ///
    /// NOTE(review): `x` is presumably the output of a [`Gate::DaBit`] gate — confirm.
    GetDaBitSharedBit {
        x: Label,
        field_type: FieldType,
    },
    /// Base field exponentiation of wire `x` by the constant exponent `exp`.
    BaseFieldPow {
        x: Label,
        exp: BoxedUint,
    },
    /// Convert the bit plaintext wire `x` into a plaintext of the field `field_type`.
    BitPlaintextToField {
        x: Label,
        field_type: FieldType,
    },
    /// Convert the plaintext wire `x` of the field `field_type` into a bit plaintext.
    FieldPlaintextToBit {
        x: Label,
        field_type: FieldType,
    },
    /// Get the element at position `index` of the batched wire `x`.
    ///
    /// `x_type` and `x_form` describe the type and share/plaintext form of `x`.
    BatchGetIndex {
        x: Label,
        x_type: AlgebraicType,
        x_form: ShareOrPlaintext,
        index: usize,
    },
    /// Collect the wires in `wires`, all of type `x_type` and form `x_form`, into one batch.
    CollectToBatch {
        wires: Vec<Label>,
        x_type: AlgebraicType,
        x_form: ShareOrPlaintext,
    },
    /// Build a plaintext point from the wires in `wires`.
    ///
    /// NOTE(review): `wires` presumably holds the extended Edwards coordinates — confirm.
    PointFromPlaintextExtendedEdwards {
        wires: Vec<Label>,
    },
    /// Convert the plaintext point wire `point` to extended Edwards form.
    ///
    /// NOTE(review): presumably the inverse of [`Gate::PointFromPlaintextExtendedEdwards`] —
    /// confirm.
    PlaintextPointToExtendedEdwards {
        point: Label,
    },
    /// Apply the `Keccak-f[1600]` permutation to the plaintext wires `wires`.
    PlaintextKeccakF1600 {
        wires: Vec<Label>,
    },
}
163
164impl<C: Curve> Gate<C> {
165    /// Gathers all the `Label` inside the gate.
166    pub fn get_labels(&self) -> Vec<Label> {
167        let mut labels = Vec::new();
168        self.for_each_label(|label| labels.push(label));
169        labels
170    }
171}
172
/// Counts of preprocessing items, grouped by algebraic structure and kind.
///
/// NOTE(review): presumably the amount of preprocessing a circuit needs before it can be
/// evaluated — confirm against the producer of this struct.
///
/// Two values combine with `+`. This adds every counter and merges `base_field_pow_pairs` by
/// key, summing the counts of keys present in both operands.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct CircuitPreprocessing {
    /// Number of scalar-field singlets.
    pub scalar_singlets: usize,
    /// Number of scalar-field triples.
    pub scalar_triples: usize,
    /// Number of base-field singlets.
    pub base_field_singlets: usize,
    /// Number of base-field triples.
    pub base_field_triples: usize,
    /// Number of base-field power pairs, per exponent.
    ///
    /// NOTE(review): keys presumably correspond to the `exp` of `Gate::BaseFieldPow` — confirm.
    pub base_field_pow_pairs: HashMap<BoxedUint, usize>,
    /// Number of bit singlets.
    pub bit_singlets: usize,
    /// Number of bit triples.
    pub bit_triples: usize,
    /// Number of Mersenne-107 daBits.
    pub mersenne107_dabits: usize,
    /// Number of Mersenne-107 singlets.
    pub mersenne107_singlets: usize,
    /// Number of Mersenne-107 triples.
    pub mersenne107_triples: usize,
    /// Number of scalar-field daBits.
    pub scalar_dabits: usize,
    /// Number of base-field daBits.
    pub base_field_dabits: usize,
}
188
189impl Add for CircuitPreprocessing {
190    type Output = Self;
191
192    fn add(self, other: Self) -> Self::Output {
193        Self {
194            scalar_singlets: self.scalar_singlets + other.scalar_singlets,
195            scalar_triples: self.scalar_triples + other.scalar_triples,
196            base_field_singlets: self.base_field_singlets + other.base_field_singlets,
197            base_field_triples: self.base_field_triples + other.base_field_triples,
198            bit_singlets: self.bit_singlets + other.bit_singlets,
199            bit_triples: self.bit_triples + other.bit_triples,
200            mersenne107_dabits: self.mersenne107_dabits + other.mersenne107_dabits,
201            mersenne107_singlets: self.mersenne107_singlets + other.mersenne107_singlets,
202            mersenne107_triples: self.mersenne107_triples + other.mersenne107_triples,
203            scalar_dabits: self.scalar_dabits + other.scalar_dabits,
204            base_field_dabits: self.base_field_dabits + other.base_field_dabits,
205            base_field_pow_pairs: {
206                let mut combined = self.base_field_pow_pairs;
207                for (k, v) in other.base_field_pow_pairs {
208                    *combined.entry(k).or_insert(0) += v;
209                }
210                combined
211            },
212        }
213    }
214}
215
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use primitives::algebra::elliptic_curve::Curve25519Ristretto as C;

    // `super::*` already re-exports every name imported by the parent module (including
    // `FieldShareBinaryOp`), so no further `crate::circuit` imports are needed here.
    use super::*;

    /// Gates survive a bincode round-trip, and equal gates hash equally.
    #[test]
    fn test_ser_gate() {
        let no_curve_gate: Gate<C> = Gate::FieldShareBinaryOp {
            x: Label::from(1, 2),
            y: Label::from(3, 4),
            y_form: ShareOrPlaintext::Share,
            op: FieldShareBinaryOp::Add,
            field_type: FieldType::ScalarField,
        };
        let scalar_gate: Gate<C> = Gate::FieldShareBinaryOp {
            x: Label::from(1, 2),
            y: Label::from(3, 4),
            y_form: ShareOrPlaintext::Plaintext,
            op: FieldShareBinaryOp::Add,
            field_type: FieldType::ScalarField,
        };
        let point_gate: Gate<C> = Gate::PointShareBinaryOp {
            p: Label::from(1, 2),
            y: Label::from(3, 4),
            p_form: ShareOrPlaintext::Share,
            y_form: ShareOrPlaintext::Plaintext,
            op: PointShareBinaryOp::Add,
        };

        let no_curve_gate_ser = bincode::serialize(&no_curve_gate).unwrap();
        let scalar_gate_ser = bincode::serialize(&scalar_gate).unwrap();
        let point_gate_ser = bincode::serialize(&point_gate).unwrap();

        let no_curve_gate_de: Gate<C> = bincode::deserialize(&no_curve_gate_ser).unwrap();
        let scalar_gate_de: Gate<C> = bincode::deserialize(&scalar_gate_ser).unwrap();
        let point_gate_de: Gate<C> = bincode::deserialize(&point_gate_ser).unwrap();

        assert_eq!(no_curve_gate, no_curve_gate_de);
        assert_eq!(scalar_gate, scalar_gate_de);
        assert_eq!(point_gate, point_gate_de);
        // Each original and its round-tripped copy must collapse to a single set entry.
        let set = HashSet::from([
            no_curve_gate,
            no_curve_gate_de,
            scalar_gate,
            scalar_gate_de,
            point_gate,
            point_gate_de,
        ]);
        assert_eq!(set.len(), 3)
    }

    /// `get_labels` reports every wire label referenced by a gate.
    ///
    /// The comparison ignores order because it is determined by the derived `for_each_label`.
    #[test]
    fn test_get_labels() {
        let field_gate: Gate<C> = Gate::FieldShareBinaryOp {
            x: Label::from(1, 2),
            y: Label::from(3, 4),
            y_form: ShareOrPlaintext::Share,
            op: FieldShareBinaryOp::Add,
            field_type: FieldType::ScalarField,
        };
        let point_gate: Gate<C> = Gate::PointShareBinaryOp {
            p: Label::from(1, 2),
            y: Label::from(3, 4),
            p_form: ShareOrPlaintext::Share,
            y_form: ShareOrPlaintext::Plaintext,
            op: PointShareBinaryOp::Add,
        };
        let expected = HashSet::from([Label::from(1, 2), Label::from(3, 4)]);

        for gate in [field_gate, point_gate] {
            let labels = gate.get_labels();
            assert_eq!(labels.len(), 2);
            assert_eq!(labels.into_iter().collect::<HashSet<_>>(), expected);
        }
    }

    /// `+` adds every counter and merges `base_field_pow_pairs` by key.
    #[test]
    fn test_circuit_preprocessing_add() {
        let a = CircuitPreprocessing {
            scalar_singlets: 1,
            scalar_triples: 2,
            base_field_singlets: 3,
            base_field_triples: 4,
            bit_singlets: 0,
            bit_triples: 1,
            mersenne107_dabits: 0,
            mersenne107_singlets: 0,
            mersenne107_triples: 0,
            scalar_dabits: 1,
            base_field_dabits: 2,
            base_field_pow_pairs: vec![
                (BoxedUint::from(vec![21]), 5),
                (BoxedUint::from(vec![14]), 6),
            ]
            .into_iter()
            .collect(),
        };
        let b = CircuitPreprocessing {
            scalar_singlets: 2,
            scalar_triples: 3,
            base_field_singlets: 0,
            base_field_triples: 5,
            bit_singlets: 3,
            bit_triples: 4,
            mersenne107_dabits: 0,
            mersenne107_singlets: 3,
            mersenne107_triples: 2,
            scalar_dabits: 2,
            base_field_dabits: 3,
            base_field_pow_pairs: vec![
                (BoxedUint::from(vec![21]), 6),
                (BoxedUint::from(vec![13]), 7),
            ]
            .into_iter()
            .collect(),
        };

        let c = a + b;

        assert_eq!(c.scalar_singlets, 3);
        assert_eq!(c.scalar_triples, 5);
        assert_eq!(c.base_field_singlets, 3);
        assert_eq!(c.base_field_triples, 9);
        assert_eq!(c.bit_singlets, 3);
        assert_eq!(c.bit_triples, 5);
        assert_eq!(c.mersenne107_dabits, 0);
        assert_eq!(c.mersenne107_singlets, 3);
        assert_eq!(c.mersenne107_triples, 2);
        assert_eq!(c.scalar_dabits, 3);
        assert_eq!(c.base_field_dabits, 5);
        assert_eq!(
            c.base_field_pow_pairs.get(&BoxedUint::from(vec![21])),
            Some(&11)
        );
        assert_eq!(
            c.base_field_pow_pairs.get(&BoxedUint::from(vec![14])),
            Some(&6)
        );
        assert_eq!(
            c.base_field_pow_pairs.get(&BoxedUint::from(vec![13])),
            Some(&7)
        );
    }

    /// The default (all-zero, empty-map) value is an identity for `+` on either side.
    #[test]
    fn test_circuit_preprocessing_add_default_is_identity() {
        let a = CircuitPreprocessing {
            scalar_triples: 2,
            bit_singlets: 7,
            base_field_pow_pairs: vec![(BoxedUint::from(vec![21]), 5)]
                .into_iter()
                .collect(),
            ..Default::default()
        };

        assert_eq!(a.clone() + CircuitPreprocessing::default(), a);
        assert_eq!(CircuitPreprocessing::default() + a.clone(), a);
    }
}