core_utils/circuit/
gate.rs

1use std::{collections::HashMap, ops::Add};
2
3use primitives::algebra::{
4    elliptic_curve::{Curve, Point, Scalar},
5    BoxedUint,
6};
7use serde::{Deserialize, Serialize};
8
9use crate::{
10    circuit::{
11        AlgebraicType,
12        Batched,
13        BitShareBinaryOp,
14        BitShareUnaryOp,
15        FieldPlaintextBinaryOp,
16        FieldPlaintextUnaryOp,
17        FieldShareBinaryOp,
18        FieldShareUnaryOp,
19        FieldType,
20        Input,
21        PointShareBinaryOp,
22        PointShareUnaryOp,
23        ShareOrPlaintext,
24    },
25    types::Label,
26};
27
/// A single gate of a circuit over the elliptic curve `C`.
///
/// Operation arguments are _wire_ labels ([`Label`]) rather than values. Some
/// variants also embed a constant directly in the gate: `EncryptPoint::c` and
/// `BaseFieldPow::exp`.
///
/// The explicit `serde` bounds below require `Scalar<C>` and `Point<C>` to be
/// (de)serializable, instead of the default bound on `C` itself.
///
/// NOTE(review): with non-self-describing formats such as bincode, serde's
/// derived impls encode a variant by its position in this enum. Reordering
/// variants, or inserting new ones in the middle, presumably invalidates
/// previously serialized gates. Confirm before changing the variant order.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(bound(
    serialize = "Scalar<C>: Serialize, Point<C>: Serialize",
    deserialize = "Scalar<C>: Deserialize<'de>, Point<C>: Deserialize<'de>"
))]
pub enum Gate<C: Curve> {
    /// Input a wire; `input_type` describes the kind of input.
    Input {
        input_type: Input<C>,
    },
    /// Field share unary operation `op` on wire `x`, over the field selected by
    /// `field_type`.
    FieldShareUnaryOp {
        x: Label,
        op: FieldShareUnaryOp,
        field_type: FieldType,
    },
    /// Field share binary operation `op` on wires `x` and `y`. The second wire
    /// may be a plaintext; `y_form` records whether `y` is a share or a
    /// plaintext.
    FieldShareBinaryOp {
        x: Label,
        y: Label,
        y_form: ShareOrPlaintext,
        op: FieldShareBinaryOp,
        field_type: FieldType,
    },
    /// Summation over the batched wire `x` (presumably of all its elements;
    /// TODO confirm). `x_form` records whether `x` is a share or a plaintext,
    /// and `algebraic_type` records its algebraic type.
    BatchSummation {
        x: Label,
        x_form: ShareOrPlaintext,
        algebraic_type: AlgebraicType,
    },
    /// Bit share unary operation `op` on wire `x`.
    BitShareUnaryOp {
        x: Label,
        op: BitShareUnaryOp,
    },
    /// Bit share binary operation `op` on wires `x` and `y`; `y_form` records
    /// whether `y` is a share or a plaintext.
    BitShareBinaryOp {
        x: Label,
        y: Label,
        y_form: ShareOrPlaintext,
        op: BitShareBinaryOp,
    },
    /// Operations with elliptic curve points: unary operation `op` on the
    /// point wire `p`.
    PointShareUnaryOp {
        p: Label,
        op: PointShareUnaryOp,
    },
    /// Elliptic curve point binary operation `op` on wires `p` and `y`.
    ///
    /// NOTE(review): the first operand is named `p`, but its form field is
    /// `x_form`. Presumably `x_form` describes `p` and `y_form` describes `y`;
    /// confirm.
    PointShareBinaryOp {
        p: Label,
        y: Label,
        x_form: ShareOrPlaintext,
        y_form: ShareOrPlaintext,
        op: PointShareBinaryOp,
    },
    /// Field plaintext unary operation `op` on wire `x`, over the field
    /// selected by `field_type`.
    FieldPlaintextUnaryOp {
        x: Label,
        op: FieldPlaintextUnaryOp,
        field_type: FieldType,
    },
    /// Field plaintext binary operation `op` on wires `x` and `y`, over the
    /// field selected by `field_type`.
    FieldPlaintextBinaryOp {
        x: Label,
        y: Label,
        op: FieldPlaintextBinaryOp,
        field_type: FieldType,
    },
    /// Bit plaintext unary operation `op` on wire `x`.
    ///
    /// NOTE(review): this reuses `FieldPlaintextUnaryOp` as its operation type
    /// rather than a bit-specific enum. Confirm this is intended.
    BitPlaintextUnaryOp {
        x: Label,
        op: FieldPlaintextUnaryOp,
    },
    /// Bit plaintext binary operation `op` on wires `x` and `y`.
    ///
    /// NOTE(review): this reuses `FieldPlaintextBinaryOp` as its operation type
    /// rather than a bit-specific enum. Confirm this is intended.
    BitPlaintextBinaryOp {
        x: Label,
        y: Label,
        op: FieldPlaintextBinaryOp,
    },
    /// Request a daBit for the field selected by `field_type`. The meaning of
    /// `batched` is defined by the `Batched` type, which is not visible here.
    DaBit {
        field_type: FieldType,
        batched: Batched,
    },
    /// Get the field share part of the daBit wire `x` (presumably the output
    /// of a `DaBit` gate; TODO confirm).
    GetDaBitFieldShare {
        x: Label,
        field_type: FieldType,
    },
    /// Get the shared bit part of the daBit wire `x` (presumably the output of
    /// a `DaBit` gate; TODO confirm).
    GetDaBitSharedBit {
        x: Label,
        field_type: FieldType,
    },
    /// ElGamal encryption of wire `x`. `c` is a point constant embedded in the
    /// gate rather than a wire (presumably the public key; TODO confirm).
    EncryptPoint {
        x: Label,
        c: Point<C>,
    },
    /// ElGamal decryption involving wires `x` and `y`. The roles of the two
    /// operands are not documented here (NOTE(review): confirm).
    DecryptPoint {
        x: Label,
        y: Label,
    },
    /// Base field exponentiation: raise wire `x` to the constant exponent `exp`.
    BaseFieldPow {
        x: Label,
        exp: BoxedUint,
    },
    /// Bit plaintext conversion: convert the bit plaintext wire `x` to a field
    /// plaintext of type `field_type`.
    BitPlaintextToField {
        x: Label,
        field_type: FieldType,
    },
    /// Bit plaintext conversion: convert the field plaintext wire `x` (of type
    /// `field_type`) to a bit plaintext.
    FieldPlaintextToBit {
        x: Label,
        field_type: FieldType,
    },
    /// Get the element at position `index` of the batched wire `x`, whose
    /// algebraic type is `x_type` and whose form is `x_form`.
    BatchGetIndex {
        x: Label,
        x_type: AlgebraicType,
        x_form: ShareOrPlaintext,
        index: usize,
    },
    /// Collect the wires in `wires` into a single batched wire. `x_type` and
    /// `x_form` give the algebraic type and the share/plaintext form of the
    /// elements.
    CollectToBatch {
        wires: Vec<Label>,
        x_type: AlgebraicType,
        x_form: ShareOrPlaintext,
    },
}
151
/// Tally of preprocessing material: singlets, triples, pow pairs and daBits,
/// split by field.
///
/// NOTE(review): presumably this is the amount of material a circuit needs
/// before evaluation; confirm against callers.
///
/// `Default` yields all counters at zero and an empty `base_field_pow_pairs`
/// map. Two tallies can be combined with `+` (the `Add` impl), which sums them
/// field by field.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct CircuitPreprocessing {
    /// Number of scalar-field singlets.
    pub scalar_singlets: usize,
    /// Number of scalar-field triples.
    pub scalar_triples: usize,
    /// Number of base-field singlets.
    pub base_field_singlets: usize,
    /// Number of base-field triples.
    pub base_field_triples: usize,
    /// Number of base-field pow pairs for each `BoxedUint` key. The key is
    /// presumably the exponent, cf. `Gate::BaseFieldPow::exp` (TODO confirm).
    pub base_field_pow_pairs: HashMap<BoxedUint, usize>,
    /// Number of bit singlets.
    pub bit_singlets: usize,
    /// Number of bit triples.
    pub bit_triples: usize,
    /// Number of scalar-field daBits.
    pub scalar_dabits: usize,
    /// Number of base-field daBits.
    pub base_field_dabits: usize,
}
164
165impl Add for CircuitPreprocessing {
166    type Output = Self;
167
168    fn add(self, other: Self) -> Self::Output {
169        Self {
170            scalar_singlets: self.scalar_singlets + other.scalar_singlets,
171            scalar_triples: self.scalar_triples + other.scalar_triples,
172            base_field_singlets: self.base_field_singlets + other.base_field_singlets,
173            base_field_triples: self.base_field_triples + other.base_field_triples,
174            bit_singlets: self.bit_singlets + other.bit_singlets,
175            bit_triples: self.bit_triples + other.bit_triples,
176            scalar_dabits: self.scalar_dabits + other.scalar_dabits,
177            base_field_dabits: self.base_field_dabits + other.base_field_dabits,
178            base_field_pow_pairs: {
179                let mut combined = self.base_field_pow_pairs;
180                for (k, v) in other.base_field_pow_pairs {
181                    *combined.entry(k).or_insert(0) += v;
182                }
183                combined
184            },
185        }
186    }
187}
188
#[cfg(test)]
mod tests {
    use primitives::algebra::elliptic_curve::Curve25519Ristretto as C;

    // `super::*` already brings in every type the tests need, including
    // `FieldShareBinaryOp`, `HashMap` and `BoxedUint`.
    use super::*;

    /// Serializes `gate` with bincode, deserializes it back, and asserts that
    /// the round trip reproduces the original gate.
    fn assert_bincode_roundtrip(gate: &Gate<C>) {
        let bytes = bincode::serialize(gate).unwrap();
        let decoded: Gate<C> = bincode::deserialize(&bytes).unwrap();
        assert_eq!(gate, &decoded);
    }

    #[test]
    fn test_ser_gate() {
        let no_curve_gate: Gate<C> = Gate::FieldShareBinaryOp {
            x: Label::from(1, 2),
            y: Label::from(3, 4),
            y_form: ShareOrPlaintext::Share,
            op: FieldShareBinaryOp::Add,
            field_type: FieldType::ScalarField,
        };
        let scalar_gate: Gate<C> = Gate::FieldShareBinaryOp {
            x: Label::from(1, 2),
            y: Label::from(3, 4),
            y_form: ShareOrPlaintext::Plaintext,
            op: FieldShareBinaryOp::Add,
            field_type: FieldType::ScalarField,
        };
        let point_gate: Gate<C> = Gate::PointShareBinaryOp {
            p: Label::from(1, 2),
            y: Label::from(3, 4),
            x_form: ShareOrPlaintext::Share,
            y_form: ShareOrPlaintext::Plaintext,
            op: PointShareBinaryOp::Add,
        };

        for gate in [no_curve_gate, scalar_gate, point_gate] {
            assert_bincode_roundtrip(&gate);
        }
    }

    #[test]
    fn test_circuit_preprocessing_add() {
        let a = CircuitPreprocessing {
            scalar_singlets: 1,
            scalar_triples: 2,
            base_field_singlets: 3,
            base_field_triples: 4,
            bit_singlets: 0,
            bit_triples: 1,
            scalar_dabits: 1,
            base_field_dabits: 2,
            base_field_pow_pairs: HashMap::from([
                (BoxedUint::from(vec![21]), 5),
                (BoxedUint::from(vec![14]), 6),
            ]),
        };
        let b = CircuitPreprocessing {
            scalar_singlets: 2,
            scalar_triples: 3,
            base_field_singlets: 0,
            base_field_triples: 5,
            bit_singlets: 3,
            bit_triples: 4,
            scalar_dabits: 2,
            base_field_dabits: 3,
            base_field_pow_pairs: HashMap::from([
                (BoxedUint::from(vec![21]), 6),
                (BoxedUint::from(vec![13]), 7),
            ]),
        };
        let expected = CircuitPreprocessing {
            scalar_singlets: 3,
            scalar_triples: 5,
            base_field_singlets: 3,
            base_field_triples: 9,
            bit_singlets: 3,
            bit_triples: 5,
            scalar_dabits: 3,
            base_field_dabits: 5,
            base_field_pow_pairs: HashMap::from([
                // Shared key: counts are summed.
                (BoxedUint::from(vec![21]), 11),
                // Keys present in only one operand are carried over.
                (BoxedUint::from(vec![14]), 6),
                (BoxedUint::from(vec![13]), 7),
            ]),
        };

        // Comparing the whole struct also proves that no unexpected keys
        // appear in the merged pow-pair map.
        assert_eq!(a + b, expected);
    }

    #[test]
    fn test_circuit_preprocessing_add_default_is_identity() {
        let a = CircuitPreprocessing {
            scalar_triples: 2,
            bit_singlets: 1,
            base_field_dabits: 4,
            base_field_pow_pairs: HashMap::from([(BoxedUint::from(vec![21]), 5)]),
            ..Default::default()
        };

        assert_eq!(a.clone() + CircuitPreprocessing::default(), a);
        assert_eq!(CircuitPreprocessing::default() + a.clone(), a);
    }
}