// core_utils/circuit/gate.rs

1use std::{collections::HashMap, ops::Add};
2
3use primitives::algebra::{
4    elliptic_curve::{Curve, Point, Scalar},
5    BoxedUint,
6};
7use serde::{Deserialize, Serialize};
8
9use crate::{
10    circuit::{
11        AlgebraicType,
12        Batched,
13        BitShareBinaryOp,
14        BitShareUnaryOp,
15        FieldPlaintextBinaryOp,
16        FieldPlaintextUnaryOp,
17        FieldShareBinaryOp,
18        FieldShareUnaryOp,
19        FieldType,
20        Input,
21        PointPlaintextBinaryOp,
22        PointPlaintextUnaryOp,
23        PointShareBinaryOp,
24        PointShareUnaryOp,
25        ShareOrPlaintext,
26    },
27    types::Label,
28};
29
/// Gate operations, where each operation's arguments are _wire_ labels.
///
/// `C` is the elliptic curve the circuit works over. It appears directly in
/// `EncryptPoint` (as a `Point<C>`) and indirectly through `Input<C>`.
///
/// Serialization: the explicit `serde(bound(...))` below makes the derived impls
/// require only `Scalar<C>` and `Point<C>` to be (de)serializable, instead of the
/// default bound on `C` itself.
///
/// NOTE(review): binary formats such as bincode (used in the tests) encode enum
/// variants by index and struct-variant fields positionally. Reordering variants or
/// fields therefore changes the serialized format.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(bound(
    serialize = "Scalar<C>: Serialize, Point<C>: Serialize",
    deserialize = "Scalar<C>: Deserialize<'de>, Point<C>: Deserialize<'de>"
))]
pub enum Gate<C: Curve> {
    /// Input a wire, of the kind described by `input_type`.
    Input {
        input_type: Input<C>,
    },
    /// Field share unary operation `op` on wire `x`, over the field selected by
    /// `field_type`.
    FieldShareUnaryOp {
        x: Label,
        op: FieldShareUnaryOp,
        field_type: FieldType,
    },
    /// Field share binary operation `op` on wires `x` and `y`.
    ///
    /// The second wire may be a plaintext; `y_form` records whether `y` is a share
    /// or a plaintext.
    FieldShareBinaryOp {
        x: Label,
        y: Label,
        y_form: ShareOrPlaintext,
        op: FieldShareBinaryOp,
        field_type: FieldType,
    },
    /// Batch summation of wire `x`.
    ///
    /// `x_form` records whether `x` is a share or a plaintext. `algebraic_type`
    /// presumably gives the type of the summed elements — confirm.
    BatchSummation {
        x: Label,
        x_form: ShareOrPlaintext,
        algebraic_type: AlgebraicType,
    },
    /// Bit share unary operation `op` on wire `x`.
    BitShareUnaryOp {
        x: Label,
        op: BitShareUnaryOp,
    },
    /// Bit share binary operation `op` on wires `x` and `y`.
    ///
    /// `y_form` records whether `y` is a share or a plaintext.
    BitShareBinaryOp {
        x: Label,
        y: Label,
        y_form: ShareOrPlaintext,
        op: BitShareBinaryOp,
    },
    /// Operations with elliptic curve points: unary operation `op` on the
    /// point-share wire `p`.
    PointShareUnaryOp {
        p: Label,
        op: PointShareUnaryOp,
    },
    /// Point share binary operation `op` on wires `p` and `y`.
    ///
    /// `p_form` and `y_form` record whether each operand is a share or a plaintext.
    PointShareBinaryOp {
        p: Label,
        y: Label,
        p_form: ShareOrPlaintext,
        y_form: ShareOrPlaintext,
        op: PointShareBinaryOp,
    },
    /// Field plaintext unary operation `op` on wire `x`, over the field selected by
    /// `field_type`.
    FieldPlaintextUnaryOp {
        x: Label,
        op: FieldPlaintextUnaryOp,
        field_type: FieldType,
    },
    /// Field plaintext binary operation `op` on wires `x` and `y`, over the field
    /// selected by `field_type`.
    FieldPlaintextBinaryOp {
        x: Label,
        y: Label,
        op: FieldPlaintextBinaryOp,
        field_type: FieldType,
    },
    /// Bit plaintext unary operation `op` on wire `x`.
    ///
    /// NOTE(review): reuses the `FieldPlaintextUnaryOp` enum for `op`; confirm which
    /// of its operations are valid on bits.
    BitPlaintextUnaryOp {
        x: Label,
        op: FieldPlaintextUnaryOp,
    },
    /// Bit plaintext binary operation `op` on wires `x` and `y`.
    ///
    /// NOTE(review): reuses the `FieldPlaintextBinaryOp` enum for `op`; confirm which
    /// of its operations are valid on bits.
    BitPlaintextBinaryOp {
        x: Label,
        y: Label,
        op: FieldPlaintextBinaryOp,
    },
    /// Point plaintext unary operation `op` on wire `p`.
    PointPlaintextUnaryOp {
        p: Label,
        op: PointPlaintextUnaryOp,
    },
    /// Point plaintext binary operation `op` on wires `p` and `y`.
    PointPlaintextBinaryOp {
        p: Label,
        y: Label,
        op: PointPlaintextBinaryOp,
    },
    /// Request a daBit for the field selected by `field_type`.
    ///
    /// `batched` presumably selects a single or batched daBit — confirm.
    DaBit {
        field_type: FieldType,
        batched: Batched,
    },
    /// Field-share component of the daBit on wire `x` (field given by `field_type`).
    GetDaBitFieldShare {
        x: Label,
        field_type: FieldType,
    },
    /// Shared-bit component of the daBit on wire `x` (field given by `field_type`).
    GetDaBitSharedBit {
        x: Label,
        field_type: FieldType,
    },
    /// ElGamal encryption of wire `x`.
    ///
    /// `c` is a curve point stored in the gate itself rather than on a wire;
    /// presumably the public key — confirm.
    EncryptPoint {
        x: Label,
        c: Point<C>,
    },
    /// ElGamal decryption taking wires `x` and `y`.
    DecryptPoint {
        x: Label,
        y: Label,
    },
    /// Base field exponentiation: raise wire `x` to the constant exponent `exp`,
    /// which is stored in the gate rather than on a wire.
    BaseFieldPow {
        x: Label,
        exp: BoxedUint,
    },
    /// Bit plaintext conversion: convert the plaintext bit on wire `x` into an
    /// element of the field selected by `field_type`.
    BitPlaintextToField {
        x: Label,
        field_type: FieldType,
    },
    /// Convert the plaintext field element on wire `x` (field given by `field_type`)
    /// into a plaintext bit.
    FieldPlaintextToBit {
        x: Label,
        field_type: FieldType,
    },
    /// Get the element at position `index` of the batched wire `x`.
    ///
    /// `x_type` and `x_form` describe the batched wire.
    BatchGetIndex {
        x: Label,
        x_type: AlgebraicType,
        x_form: ShareOrPlaintext,
        index: usize,
    },
    /// Collect the listed `wires` into one batched wire.
    ///
    /// `x_type` and `x_form` describe the collected values.
    CollectToBatch {
        wires: Vec<Label>,
        x_type: AlgebraicType,
        x_form: ShareOrPlaintext,
    },
    /// Build a plaintext point from `wires`.
    ///
    /// Presumably the wires hold extended Edwards coordinates, and "Unchecked" means
    /// no validity check is performed — confirm.
    PointFromPlaintextExtendedEdwardsUnchecked {
        wires: Vec<Label>,
    },
    /// Convert the plaintext point on wire `point` to extended Edwards form.
    PlaintextPointToExtendedEdwards {
        point: Label,
    },
    /// Keccak-f[1600] on the plaintext `wires`.
    ///
    /// Presumably the wires are the permutation state lanes — confirm.
    PlaintextKeccakF1600 {
        wires: Vec<Label>,
    },
}
171
/// Preprocessing requirements of a circuit, expressed as counts per kind of material.
///
/// Every field is a plain count except `base_field_pow_pairs`, which keeps one count
/// per exponent. The derived `Default` is all zeros with an empty map. Two values
/// can be combined with `+`, which sums the counts field by field.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct CircuitPreprocessing {
    /// Number of scalar-field singlets.
    pub scalar_singlets: usize,
    /// Number of scalar-field triples.
    pub scalar_triples: usize,
    /// Number of base-field singlets.
    pub base_field_singlets: usize,
    /// Number of base-field triples.
    pub base_field_triples: usize,
    /// Number of base-field pow pairs, keyed by exponent.
    ///
    /// Presumably keyed by the `exp` of `Gate::BaseFieldPow` — confirm.
    pub base_field_pow_pairs: HashMap<BoxedUint, usize>,
    /// Number of bit singlets.
    pub bit_singlets: usize,
    /// Number of bit triples.
    pub bit_triples: usize,
    /// Number of daBits over the scalar field.
    pub scalar_dabits: usize,
    /// Number of daBits over the base field.
    pub base_field_dabits: usize,
}
184
185impl Add for CircuitPreprocessing {
186    type Output = Self;
187
188    fn add(self, other: Self) -> Self::Output {
189        Self {
190            scalar_singlets: self.scalar_singlets + other.scalar_singlets,
191            scalar_triples: self.scalar_triples + other.scalar_triples,
192            base_field_singlets: self.base_field_singlets + other.base_field_singlets,
193            base_field_triples: self.base_field_triples + other.base_field_triples,
194            bit_singlets: self.bit_singlets + other.bit_singlets,
195            bit_triples: self.bit_triples + other.bit_triples,
196            scalar_dabits: self.scalar_dabits + other.scalar_dabits,
197            base_field_dabits: self.base_field_dabits + other.base_field_dabits,
198            base_field_pow_pairs: {
199                let mut combined = self.base_field_pow_pairs;
200                for (k, v) in other.base_field_pow_pairs {
201                    *combined.entry(k).or_insert(0) += v;
202                }
203                combined
204            },
205        }
206    }
207}
208
#[cfg(test)]
mod tests {
    use primitives::algebra::elliptic_curve::Curve25519Ristretto as C;

    // `super::*` also brings in the parent's imports (`FieldShareBinaryOp`,
    // `HashMap`, `BoxedUint`, ...), so no separate re-imports are needed.
    use super::*;

    /// Serializes `gate` with bincode, deserializes it again and asserts the
    /// round trip is lossless.
    fn assert_bincode_round_trip(gate: &Gate<C>) {
        let bytes = bincode::serialize(gate).expect("gate should serialize");
        let decoded: Gate<C> = bincode::deserialize(&bytes).expect("gate should deserialize");
        assert_eq!(gate, &decoded);
    }

    #[test]
    fn test_ser_gate() {
        let gates: [Gate<C>; 3] = [
            // Field gate whose second operand is a share.
            Gate::FieldShareBinaryOp {
                x: Label::from(1, 2),
                y: Label::from(3, 4),
                y_form: ShareOrPlaintext::Share,
                op: FieldShareBinaryOp::Add,
                field_type: FieldType::ScalarField,
            },
            // Same gate with a plaintext second operand.
            Gate::FieldShareBinaryOp {
                x: Label::from(1, 2),
                y: Label::from(3, 4),
                y_form: ShareOrPlaintext::Plaintext,
                op: FieldShareBinaryOp::Add,
                field_type: FieldType::ScalarField,
            },
            // Point gate mixing a share and a plaintext operand.
            Gate::PointShareBinaryOp {
                p: Label::from(1, 2),
                y: Label::from(3, 4),
                p_form: ShareOrPlaintext::Share,
                y_form: ShareOrPlaintext::Plaintext,
                op: PointShareBinaryOp::Add,
            },
        ];

        for gate in &gates {
            assert_bincode_round_trip(gate);
        }
    }

    #[test]
    fn test_circuit_preprocessing_add() {
        let a = CircuitPreprocessing {
            scalar_singlets: 1,
            scalar_triples: 2,
            base_field_singlets: 3,
            base_field_triples: 4,
            bit_singlets: 0,
            bit_triples: 1,
            scalar_dabits: 1,
            base_field_dabits: 2,
            base_field_pow_pairs: HashMap::from([
                (BoxedUint::from(vec![21]), 5),
                (BoxedUint::from(vec![14]), 6),
            ]),
        };
        let b = CircuitPreprocessing {
            scalar_singlets: 2,
            scalar_triples: 3,
            base_field_singlets: 0,
            base_field_triples: 5,
            bit_singlets: 3,
            bit_triples: 4,
            scalar_dabits: 2,
            base_field_dabits: 3,
            base_field_pow_pairs: HashMap::from([
                (BoxedUint::from(vec![21]), 6),
                (BoxedUint::from(vec![13]), 7),
            ]),
        };

        // Comparing whole structs checks every count and also that the merged map
        // holds exactly these exponents (no missing or spurious keys).
        let expected = CircuitPreprocessing {
            scalar_singlets: 3,
            scalar_triples: 5,
            base_field_singlets: 3,
            base_field_triples: 9,
            bit_singlets: 3,
            bit_triples: 5,
            scalar_dabits: 3,
            base_field_dabits: 5,
            base_field_pow_pairs: HashMap::from([
                (BoxedUint::from(vec![21]), 11),
                (BoxedUint::from(vec![14]), 6),
                (BoxedUint::from(vec![13]), 7),
            ]),
        };

        assert_eq!(a.clone() + b.clone(), expected);
        // The combination must not depend on operand order.
        assert_eq!(b + a, expected);
    }
}