// core_utils/circuit/gate.rs

1use std::{collections::HashMap, ops::Add};
2
3use macros::GateMethods;
4use primitives::algebra::{
5    elliptic_curve::{Curve, Point, Scalar},
6    BoxedUint,
7};
8use serde::{Deserialize, Serialize};
9
10use crate::{
11    circuit::{
12        AlgebraicType,
13        Batched,
14        BitShareBinaryOp,
15        BitShareUnaryOp,
16        FieldPlaintextBinaryOp,
17        FieldPlaintextUnaryOp,
18        FieldShareBinaryOp,
19        FieldShareUnaryOp,
20        FieldType,
21        Input,
22        PointPlaintextBinaryOp,
23        PointPlaintextUnaryOp,
24        PointShareBinaryOp,
25        PointShareUnaryOp,
26        ShareOrPlaintext,
27    },
28    types::Label,
29};
30
31/// Gate operations, where the operation arguments correspond to _wire_ label.
32#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, GateMethods)]
33#[serde(bound(
34    serialize = "Scalar<C>: Serialize, Point<C>: Serialize",
35    deserialize = "Scalar<C>: Deserialize<'de>, Point<C>: Deserialize<'de>"
36))]
37pub enum Gate<C: Curve> {
38    /// Input a wire
39    Input {
40        input_type: Input<C>,
41    },
42    /// Field share unary operations
43    FieldShareUnaryOp {
44        x: Label,
45        op: FieldShareUnaryOp,
46        field_type: FieldType,
47    },
48    /// Field share binary operations, where the second wire may be a plaintext.
49    FieldShareBinaryOp {
50        x: Label,
51        y: Label,
52        y_form: ShareOrPlaintext,
53        op: FieldShareBinaryOp,
54        field_type: FieldType,
55    },
56    BatchSummation {
57        x: Label,
58        x_form: ShareOrPlaintext,
59        algebraic_type: AlgebraicType,
60    },
61    BitShareUnaryOp {
62        x: Label,
63        op: BitShareUnaryOp,
64    },
65    BitShareBinaryOp {
66        x: Label,
67        y: Label,
68        y_form: ShareOrPlaintext,
69        op: BitShareBinaryOp,
70    },
71    /// Operations with elliptic curve points
72    PointShareUnaryOp {
73        p: Label,
74        op: PointShareUnaryOp,
75    },
76    PointShareBinaryOp {
77        p: Label,
78        y: Label,
79        p_form: ShareOrPlaintext,
80        y_form: ShareOrPlaintext,
81        op: PointShareBinaryOp,
82    },
83    /// Field plaintext unary operations
84    FieldPlaintextUnaryOp {
85        x: Label,
86        op: FieldPlaintextUnaryOp,
87        field_type: FieldType,
88    },
89    /// Field plaintext binary operations
90    FieldPlaintextBinaryOp {
91        x: Label,
92        y: Label,
93        op: FieldPlaintextBinaryOp,
94        field_type: FieldType,
95    },
96    BitPlaintextUnaryOp {
97        x: Label,
98        op: FieldPlaintextUnaryOp,
99    },
100    BitPlaintextBinaryOp {
101        x: Label,
102        y: Label,
103        op: FieldPlaintextBinaryOp,
104    },
105    PointPlaintextUnaryOp {
106        p: Label,
107        op: PointPlaintextUnaryOp,
108    },
109    PointPlaintextBinaryOp {
110        p: Label,
111        y: Label,
112        op: PointPlaintextBinaryOp,
113    },
114    /// Request a daBit
115    DaBit {
116        field_type: FieldType,
117        batched: Batched,
118    },
119    GetDaBitFieldShare {
120        x: Label,
121        field_type: FieldType,
122    },
123    GetDaBitSharedBit {
124        x: Label,
125        field_type: FieldType,
126    },
127    /// Base field exponentiation operation
128    BaseFieldPow {
129        x: Label,
130        exp: BoxedUint,
131    },
132    /// Bit plaintext conversion operations
133    BitPlaintextToField {
134        x: Label,
135        field_type: FieldType,
136    },
137    FieldPlaintextToBit {
138        x: Label,
139        field_type: FieldType,
140    },
141    /// Get the element at a certain index of a batched wire
142    BatchGetIndex {
143        x: Label,
144        x_type: AlgebraicType,
145        x_form: ShareOrPlaintext,
146        index: usize,
147    },
148    CollectToBatch {
149        wires: Vec<Label>,
150        x_type: AlgebraicType,
151        x_form: ShareOrPlaintext,
152    },
153    PointFromPlaintextExtendedEdwards {
154        wires: Vec<Label>,
155    },
156    PlaintextPointToExtendedEdwards {
157        point: Label,
158    },
159    PlaintextKeccakF1600 {
160        wires: Vec<Label>,
161    },
162    CompressPlaintextPoint {
163        point: Label,
164    },
165}
166
167impl<C: Curve> Gate<C> {
168    /// Gathers all the `Label` inside the gate.
169    pub fn get_labels(&self) -> Vec<Label> {
170        let mut labels = Vec::new();
171        self.for_each_label(|label| labels.push(label));
172        labels
173    }
174}
175
176#[derive(Debug, Clone, Default, PartialEq, Eq)]
177pub struct CircuitPreprocessing {
178    pub scalar_singlets: usize,
179    pub scalar_triples: usize,
180    pub base_field_singlets: usize,
181    pub base_field_triples: usize,
182    pub base_field_pow_pairs: HashMap<BoxedUint, usize>,
183    pub bit_singlets: usize,
184    pub bit_triples: usize,
185    pub mersenne107_dabits: usize,
186    pub mersenne107_singlets: usize,
187    pub mersenne107_triples: usize,
188    pub scalar_dabits: usize,
189    pub base_field_dabits: usize,
190}
191
192impl Add for CircuitPreprocessing {
193    type Output = Self;
194
195    fn add(self, other: Self) -> Self::Output {
196        Self {
197            scalar_singlets: self.scalar_singlets + other.scalar_singlets,
198            scalar_triples: self.scalar_triples + other.scalar_triples,
199            base_field_singlets: self.base_field_singlets + other.base_field_singlets,
200            base_field_triples: self.base_field_triples + other.base_field_triples,
201            bit_singlets: self.bit_singlets + other.bit_singlets,
202            bit_triples: self.bit_triples + other.bit_triples,
203            mersenne107_dabits: self.mersenne107_dabits + other.mersenne107_dabits,
204            mersenne107_singlets: self.mersenne107_singlets + other.mersenne107_singlets,
205            mersenne107_triples: self.mersenne107_triples + other.mersenne107_triples,
206            scalar_dabits: self.scalar_dabits + other.scalar_dabits,
207            base_field_dabits: self.base_field_dabits + other.base_field_dabits,
208            base_field_pow_pairs: {
209                let mut combined = self.base_field_pow_pairs;
210                for (k, v) in other.base_field_pow_pairs {
211                    *combined.entry(k).or_insert(0) += v;
212                }
213                combined
214            },
215        }
216    }
217}
218
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use primitives::algebra::elliptic_curve::Curve25519Ristretto as C;

    // `use super::*` also brings in the parent's imports (e.g. `FieldShareBinaryOp`,
    // `HashMap`), so they are not re-imported here.
    use super::*;

    /// Round-trips representative gates through `bincode` and checks that equality and
    /// hashing agree between originals and deserialized copies.
    #[test]
    fn test_ser_gate() {
        let no_curve_gate: Gate<C> = Gate::FieldShareBinaryOp {
            x: Label::from(1, 2),
            y: Label::from(3, 4),
            y_form: ShareOrPlaintext::Share,
            op: FieldShareBinaryOp::Add,
            field_type: FieldType::ScalarField,
        };
        let scalar_gate: Gate<C> = Gate::FieldShareBinaryOp {
            x: Label::from(1, 2),
            y: Label::from(3, 4),
            y_form: ShareOrPlaintext::Plaintext,
            op: FieldShareBinaryOp::Add,
            field_type: FieldType::ScalarField,
        };
        let point_gate: Gate<C> = Gate::PointShareBinaryOp {
            p: Label::from(1, 2),
            y: Label::from(3, 4),
            p_form: ShareOrPlaintext::Share,
            y_form: ShareOrPlaintext::Plaintext,
            op: PointShareBinaryOp::Add,
        };

        let no_curve_gate_ser = bincode::serialize(&no_curve_gate).unwrap();
        let scalar_gate_ser = bincode::serialize(&scalar_gate).unwrap();
        let point_gate_ser = bincode::serialize(&point_gate).unwrap();

        let no_curve_gate_de: Gate<C> = bincode::deserialize(&no_curve_gate_ser).unwrap();
        let scalar_gate_de: Gate<C> = bincode::deserialize(&scalar_gate_ser).unwrap();
        let point_gate_de: Gate<C> = bincode::deserialize(&point_gate_ser).unwrap();

        assert_eq!(no_curve_gate, no_curve_gate_de);
        assert_eq!(scalar_gate, scalar_gate_de);
        assert_eq!(point_gate, point_gate_de);

        // Originals and copies must hash identically, leaving exactly three distinct gates.
        let set = HashSet::from([
            no_curve_gate,
            no_curve_gate_de,
            scalar_gate,
            scalar_gate_de,
            point_gate,
            point_gate_de,
        ]);
        assert_eq!(set.len(), 3)
    }

    /// Checks that `+` sums every counter and merges `base_field_pow_pairs` per key.
    ///
    /// Every counter is nonzero in both operands and differs between them, and the
    /// resulting sums are pairwise distinct. A buggy impl that drops an operand, adds an
    /// operand to itself, or mixes up two fields therefore fails the comparison.
    #[test]
    fn test_circuit_preprocessing_add() {
        let a = CircuitPreprocessing {
            scalar_singlets: 1,
            scalar_triples: 2,
            base_field_singlets: 3,
            base_field_triples: 4,
            bit_singlets: 5,
            bit_triples: 6,
            mersenne107_dabits: 7,
            mersenne107_singlets: 8,
            mersenne107_triples: 9,
            scalar_dabits: 10,
            base_field_dabits: 11,
            base_field_pow_pairs: HashMap::from([
                (BoxedUint::from(vec![21]), 5),
                (BoxedUint::from(vec![14]), 6),
            ]),
        };
        let b = CircuitPreprocessing {
            scalar_singlets: 10,
            scalar_triples: 20,
            base_field_singlets: 30,
            base_field_triples: 40,
            bit_singlets: 50,
            bit_triples: 60,
            mersenne107_dabits: 70,
            mersenne107_singlets: 80,
            mersenne107_triples: 90,
            scalar_dabits: 100,
            base_field_dabits: 110,
            base_field_pow_pairs: HashMap::from([
                (BoxedUint::from(vec![21]), 6),
                (BoxedUint::from(vec![13]), 7),
            ]),
        };
        let expected = CircuitPreprocessing {
            scalar_singlets: 11,
            scalar_triples: 22,
            base_field_singlets: 33,
            base_field_triples: 44,
            bit_singlets: 55,
            bit_triples: 66,
            mersenne107_dabits: 77,
            mersenne107_singlets: 88,
            mersenne107_triples: 99,
            scalar_dabits: 110,
            base_field_dabits: 121,
            base_field_pow_pairs: HashMap::from([
                (BoxedUint::from(vec![21]), 11),
                (BoxedUint::from(vec![14]), 6),
                (BoxedUint::from(vec![13]), 7),
            ]),
        };

        assert_eq!(a.clone() + b.clone(), expected);
        // Addition must not depend on operand order.
        assert_eq!(b + a, expected);
    }
}