core_utils/circuit/gate.rs

1use std::{collections::HashMap, ops::Add};
2
3use primitives::{
4    algebra::{
5        elliptic_curve::{Curve, Point, Scalar},
6        BoxedUint,
7    },
8    types::PeerNumber,
9};
10use serde::{Deserialize, Serialize};
11
12use super::BaseFieldPlaintext;
13use crate::{
14    circuit::{BitPlaintext, FieldBinaryOp, FieldUnaryOp, PointPlaintext, ScalarPlaintext},
15    types::Label,
16};
17
/// Gate operations, where the operation arguments correspond to _wire_ label.
///
/// Every field of type [`Label`] names a wire from which the gate reads an operand. Fields of
/// other types (`ScalarPlaintext<C>`, `PointPlaintext<C>`, `Point<C>`, `BoxedUint`, `usize`,
/// `PeerNumber`, ...) are values embedded directly in the gate.
///
/// NOTE(review): serde formats that encode enum variants by their index (bincode, used in this
/// module's tests, does) make the variant order part of the serialized format. Append new
/// variants at the end rather than inserting or reordering them.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
// Replace the bounds serde would otherwise infer for `C`: the derived impls only require
// `Scalar<C>` and `Point<C>` to be (de)serializable.
#[serde(bound(
    serialize = "Scalar<C>: Serialize, Point<C>: Serialize",
    deserialize = "Scalar<C>: Deserialize<'de>, Point<C>: Deserialize<'de>"
))]
pub enum Gate<C: Curve> {
    /// Input a scalar, supplied by peer `inputer`, to be secret shared.
    ScalarInput {
        inputer: PeerNumber,
    },
    /// Input a base-field element, supplied by peer `inputer`, to be secret shared.
    BaseFieldInput {
        inputer: PeerNumber,
    },
    /// Input a curve point, supplied by peer `inputer`, to be secret shared.
    PointInput {
        inputer: PeerNumber,
    },
    /// Input the plaintext scalar `c`, embedded in the gate.
    ScalarPlaintextInput {
        c: ScalarPlaintext<C>,
    },
    /// Input the plaintext base-field element `c`, embedded in the gate.
    BaseFieldPlaintextInput {
        c: BaseFieldPlaintext<C>,
    },
    /// Input an already secret-shared scalar.
    ScalarShareInput,
    /// Input an already secret-shared base-field element.
    BaseFieldShareInput,
    /// Input an already secret-shared curve point.
    PointShareInput,
    /// Input an ElGamal ciphertext given by the two plaintext points `c` and `r`.
    ///
    /// NOTE(review): which ciphertext component each field holds is not evident here — confirm.
    ElGamalCiphertextInput {
        c: PointPlaintext<C>,
        r: PointPlaintext<C>,
    },
    /// Add the secret-shared scalars on wires `x` and `y`.
    ScalarAdd {
        x: Label,
        y: Label,
    },
    /// Add the secret-shared base-field elements on wires `x` and `y`.
    BaseFieldAdd {
        x: Label,
        y: Label,
    },
    /// Add the secret-shared points on wires `x` and `y`.
    PointAdd {
        x: Label,
        y: Label,
    },
    /// Multiply the scalars on wires `x` and `y`.
    ScalarMul {
        x: Label,
        y: Label,
    },
    /// Multiply the base-field elements on wires `x` and `y`.
    BaseFieldMul {
        x: Label,
        y: Label,
    },
    /// Multiply the values on wires `x` and `y`.
    ///
    /// NOTE(review): presumably a scalar times a point — confirm which wire carries which.
    PointMul {
        x: Label,
        y: Label,
    },
    /// Add the plaintext scalar on wire `c` to the secret-shared scalar on wire `x`.
    ScalarAddPlaintext {
        x: Label,
        c: Label,
    },
    /// Add the plaintext base-field element on wire `c` to the secret share on wire `x`.
    BaseFieldAddPlaintext {
        x: Label,
        c: Label,
    },
    /// Add the plaintext point on wire `c` to the secret-shared point on wire `x`.
    PointAddPlaintext {
        x: Label,
        c: Label,
    },
    /// Multiply the secret-shared scalar on wire `x` by the plaintext scalar on wire `c`.
    ScalarMulScalarPlaintext {
        x: Label,
        c: Label,
    },
    /// Multiply the secret-shared base-field element on wire `x` by the plaintext element on
    /// wire `c`.
    BaseFieldMulBaseFieldPlaintext {
        x: Label,
        c: Label,
    },
    /// Multiply the secret-shared point on wire `x` by the plaintext scalar `c`.
    ///
    /// Unlike the variants above, the constant is embedded in the gate, not read from a wire.
    PointMulScalarPlaintext {
        x: Label,
        c: ScalarPlaintext<C>,
    },
    /// Multiply the secret-shared scalar on wire `x` by the plaintext point `c`.
    ///
    /// The constant is embedded in the gate, not read from a wire.
    ScalarMulPointPlaintext {
        x: Label,
        c: PointPlaintext<C>,
    },
    /// Open the secret-shared scalar on wire `x`.
    ScalarShareOpen {
        x: Label,
    },
    /// Open the secret-shared base-field element on wire `x`.
    BaseFieldShareOpen {
        x: Label,
    },
    /// Open the secret-shared point on wire `x`.
    PointShareOpen {
        x: Label,
    },
    /// Apply the unary field operation `op` to the plaintext scalar on wire `x`.
    ScalarPlaintextUnaryOp {
        x: Label,
        op: FieldUnaryOp,
    },
    /// Apply the unary field operation `op` to the plaintext base-field element on wire `x`.
    BaseFieldPlaintextUnaryOp {
        x: Label,
        op: FieldUnaryOp,
    },
    /// Apply the binary field operation `op` to the plaintext scalars on wires `x` and `y`.
    ScalarPlaintextBinaryOp {
        x: Label,
        y: Label,
        op: FieldBinaryOp,
    },
    /// Apply the binary field operation `op` to the plaintext base-field elements on wires
    /// `x` and `y`.
    BaseFieldPlaintextBinaryOp {
        x: Label,
        y: Label,
        op: FieldBinaryOp,
    },
    /// Multiplicative inverse of the scalar on wire `x`.
    ScalarMultiplicativeInverse {
        x: Label,
    },
    /// Multiplicative inverse of the base-field element on wire `x`.
    BaseFieldMultiplicativeInverse {
        x: Label,
    },
    /// Additive inverse of the scalar on wire `x`.
    ScalarAdditiveInverse {
        x: Label,
    },
    /// Additive inverse of the base-field element on wire `x`.
    BaseFieldAdditiveInverse {
        x: Label,
    },
    /// Additive inverse of the point on wire `x`.
    PointAdditiveInverse {
        x: Label,
    },
    /// Comparison of the scalars on wires `x` and `y`: `x > y`.
    ScalarGreaterThan {
        x: Label,
        y: Label,
    },
    /// Comparison of the base-field elements on wires `x` and `y`: `x > y`.
    BaseFieldGreaterThan {
        x: Label,
        y: Label,
    },
    /// Comparison of the scalars on wires `x` and `y`: `x >= y`.
    ScalarGreaterThanOrEqual {
        x: Label,
        y: Label,
    },
    /// Comparison of the base-field elements on wires `x` and `y`: `x >= y`.
    BaseFieldGreaterThanOrEqual {
        x: Label,
        y: Label,
    },
    /// Test whether the scalar on wire `x` is zero.
    ScalarZeroTest {
        x: Label,
    },
    /// Test whether the base-field element on wire `x` is zero.
    BaseFieldZeroTest {
        x: Label,
    },
    /// Test whether the point on wire `x` is zero.
    PointZeroTest {
        x: Label,
    },
    /// Encrypt the point on wire `x`, using the point `c` embedded in the gate.
    ///
    /// NOTE(review): `c` is presumably the encryption public key — confirm.
    EncryptPoint {
        x: Label,
        c: Point<C>,
    },
    /// Decrypt using the values on wires `x` and `y`.
    ///
    /// NOTE(review): which wire holds the ciphertext and which the key is not evident here —
    /// confirm.
    DecryptPoint {
        x: Label,
        y: Label,
    },
    /// Produce a random secret-shared base-field element.
    RandomBaseFieldShare,
    /// Raise the base-field element on wire `x` to the power `exp`, embedded in the gate.
    BaseFieldPow {
        x: Label,
        exp: BoxedUint,
    },
    /// Input the plaintext curve point `c`, embedded in the gate.
    PointPlaintextInput {
        c: PointPlaintext<C>,
    },
    // Operations on bits.
    /// Input a bit, supplied by peer `inputer`, to be secret shared.
    BitInput {
        inputer: PeerNumber,
    },
    /// XOR of the bits on wires `x` and `y`.
    BitXor {
        x: Label,
        y: Label,
    },
    /// Negation of the bit on wire `x`.
    BitNot {
        x: Label,
    },
    /// AND of the bits on wires `x` and `y`.
    BitAnd {
        x: Label,
        y: Label,
    },
    /// XOR of the bit on wire `x` with the plaintext bit on wire `c`.
    BitPlaintextXor {
        x: Label,
        c: Label,
    },
    /// AND of the bit on wire `x` with the plaintext bit on wire `c`.
    BitPlaintextAnd {
        x: Label,
        c: Label,
    },
    /// Input an already secret-shared bit.
    BitShareInput,
    /// Input the plaintext bit `c`, embedded in the gate.
    BitPlaintextInput {
        c: BitPlaintext,
    },
    /// Produce a random secret-shared bit.
    RandomBitShare,
    /// Apply the unary operation `op` to the plaintext bit on wire `x`.
    BitPlaintextUnaryOp {
        x: Label,
        op: FieldUnaryOp,
    },
    /// Apply the binary operation `op` to the plaintext bits on wires `x` and `y`.
    BitPlaintextBinaryOp {
        x: Label,
        y: Label,
        op: FieldBinaryOp,
    },
    /// Open the secret-shared bit on wire `x`.
    BitShareOpen {
        x: Label,
    },
    /// Produce a scalar daBit.
    ///
    /// Its parts are read back with [`Gate::GetDaBitScalarShare`] and
    /// [`Gate::GetScalarDaBitSharedBit`].
    ScalarDaBit,
    /// Produce a base-field daBit.
    ///
    /// Its parts are read back with [`Gate::GetDaBitBaseFieldShare`] and
    /// [`Gate::GetBaseFieldDaBitSharedBit`].
    BaseFieldDaBit,
    /// Extract the scalar share of the daBit on wire `x`.
    GetDaBitScalarShare {
        x: Label,
    },
    /// Extract the base-field share of the daBit on wire `x`.
    GetDaBitBaseFieldShare {
        x: Label,
    },
    /// Extract the shared bit of the scalar daBit on wire `x`.
    GetScalarDaBitSharedBit {
        x: Label,
    },
    /// Extract the shared bit of the base-field daBit on wire `x`.
    GetBaseFieldDaBitSharedBit {
        x: Label,
    },
    /// Select the bit at position `index` of the bit share on wire `x`.
    ///
    /// NOTE(review): bit ordering (e.g. least-significant first) is not specified here —
    /// confirm.
    BitShareGetIndex {
        x: Label,
        index: usize,
    },
    /// Convert the plaintext bit on wire `x` into a plaintext scalar.
    BitPlaintextToScalar {
        x: Label,
    },
    /// Convert the plaintext bit on wire `x` into a plaintext base-field element.
    BitPlaintextToBaseField {
        x: Label,
    },
    /// Convert the plaintext scalar on wire `x` into a plaintext bit.
    ///
    /// NOTE(review): behavior for values other than 0 and 1 is not defined here — confirm.
    ScalarPlaintextToBit {
        x: Label,
    },
    /// Convert the plaintext base-field element on wire `x` into a plaintext bit.
    ///
    /// NOTE(review): behavior for values other than 0 and 1 is not defined here — confirm.
    BaseFieldPlaintextToBit {
        x: Label,
    },
}
267
/// Counts of preprocessing material (singlets, triples, pow pairs, daBits) for a circuit.
///
/// The derived `Default` has every counter at zero and an empty `base_field_pow_pairs` map.
/// Two values are combined with `+`, which sums the counters and merges the pow-pair maps
/// exponent by exponent.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct CircuitPreprocessing {
    /// Number of scalar singlets.
    pub scalar_singlets: usize,
    /// Number of scalar triples.
    pub scalar_triples: usize,
    /// Number of base-field singlets.
    pub base_field_singlets: usize,
    /// Number of base-field triples.
    pub base_field_triples: usize,
    /// Number of base-field pow pairs, keyed by exponent.
    ///
    /// NOTE(review): presumably one entry per distinct `exp` of `Gate::BaseFieldPow` — confirm
    /// against the code that fills this in.
    pub base_field_pow_pairs: HashMap<BoxedUint, usize>,
    /// Number of bit singlets.
    pub bit_singlets: usize,
    /// Number of bit triples.
    pub bit_triples: usize,
    /// Number of scalar daBits.
    pub scalar_dabits: usize,
    /// Number of base-field daBits.
    pub base_field_dabits: usize,
}
280
281impl Add for CircuitPreprocessing {
282    type Output = Self;
283
284    fn add(self, other: Self) -> Self::Output {
285        Self {
286            scalar_singlets: self.scalar_singlets + other.scalar_singlets,
287            scalar_triples: self.scalar_triples + other.scalar_triples,
288            base_field_singlets: self.base_field_singlets + other.base_field_singlets,
289            base_field_triples: self.base_field_triples + other.base_field_triples,
290            bit_singlets: self.bit_singlets + other.bit_singlets,
291            bit_triples: self.bit_triples + other.bit_triples,
292            scalar_dabits: self.scalar_dabits + other.scalar_dabits,
293            base_field_dabits: self.base_field_dabits + other.base_field_dabits,
294            base_field_pow_pairs: {
295                let mut combined = self.base_field_pow_pairs;
296                for (k, v) in other.base_field_pow_pairs {
297                    *combined.entry(k).or_insert(0) += v;
298                }
299                combined
300            },
301        }
302    }
303}
304
#[cfg(test)]
mod tests {
    use primitives::algebra::elliptic_curve::Curve25519Ristretto as C;

    use super::*;

    /// Serializes `gate` with bincode, deserializes it back and checks the result is unchanged.
    fn assert_bincode_roundtrip(gate: &Gate<C>) {
        let bytes = bincode::serialize(gate).unwrap();
        let decoded: Gate<C> = bincode::deserialize(&bytes).unwrap();
        assert_eq!(gate, &decoded, "bincode round-trip changed {:?}", gate);
    }

    #[test]
    fn test_ser_gate() {
        let gates: Vec<Gate<C>> = vec![
            // Gates whose only operands are wire labels.
            Gate::ScalarAdd {
                x: Label::from(1, 2),
                y: Label::from(3, 4),
            },
            Gate::ScalarAddPlaintext {
                x: Label::from(1, 2),
                c: Label::from(3, 4),
            },
            Gate::PointAddPlaintext {
                x: Label::from(1, 2),
                c: Label::from(3, 4),
            },
            // Unit variants carry no payload at all.
            Gate::ScalarShareInput,
            Gate::RandomBitShare,
            // Gates with a constant embedded next to the wire label.
            Gate::BitShareGetIndex {
                x: Label::from(5, 6),
                index: 7,
            },
            Gate::BaseFieldPow {
                x: Label::from(1, 2),
                exp: BoxedUint::from(vec![21]),
            },
        ];

        for gate in &gates {
            assert_bincode_roundtrip(gate);
        }
    }

    #[test]
    fn test_circuit_preprocessing_add() {
        let a = CircuitPreprocessing {
            scalar_singlets: 1,
            scalar_triples: 2,
            base_field_singlets: 3,
            base_field_triples: 4,
            bit_singlets: 0,
            bit_triples: 1,
            scalar_dabits: 1,
            base_field_dabits: 2,
            base_field_pow_pairs: HashMap::from([
                (BoxedUint::from(vec![21]), 5),
                (BoxedUint::from(vec![14]), 6),
            ]),
        };
        let b = CircuitPreprocessing {
            scalar_singlets: 2,
            scalar_triples: 3,
            base_field_singlets: 0,
            base_field_triples: 5,
            bit_singlets: 3,
            bit_triples: 4,
            scalar_dabits: 2,
            base_field_dabits: 3,
            base_field_pow_pairs: HashMap::from([
                (BoxedUint::from(vec![21]), 6),
                (BoxedUint::from(vec![13]), 7),
            ]),
        };
        let expected = CircuitPreprocessing {
            scalar_singlets: 3,
            scalar_triples: 5,
            base_field_singlets: 3,
            base_field_triples: 9,
            bit_singlets: 3,
            bit_triples: 5,
            scalar_dabits: 3,
            base_field_dabits: 5,
            base_field_pow_pairs: HashMap::from([
                (BoxedUint::from(vec![21]), 11),
                (BoxedUint::from(vec![14]), 6),
                (BoxedUint::from(vec![13]), 7),
            ]),
        };

        // Comparing whole structs also rules out stray entries in the merged map.
        assert_eq!(a + b, expected);
    }

    #[test]
    fn test_circuit_preprocessing_add_default_is_identity() {
        let a = CircuitPreprocessing {
            scalar_triples: 2,
            bit_singlets: 4,
            base_field_dabits: 1,
            base_field_pow_pairs: HashMap::from([(BoxedUint::from(vec![21]), 5)]),
            ..Default::default()
        };

        assert_eq!(a.clone() + CircuitPreprocessing::default(), a);
        assert_eq!(CircuitPreprocessing::default() + a.clone(), a);
    }
}