Skip to main content

core_utils/circuit/v2/
gate.rs

1use std::hash::Hash;
2
3use primitives::algebra::{
4    elliptic_curve::{Curve, Point, Scalar},
5    BoxedUint,
6};
7use serde::{Deserialize, Serialize};
8use wincode::{SchemaRead, SchemaWrite};
9
10use crate::circuit::{
11    errors::CircuitError,
12    AlgebraicType,
13    BatchSize,
14    BitPlaintextBinaryOp,
15    BitPlaintextUnaryOp,
16    BitShareBinaryOp,
17    BitShareUnaryOp,
18    Constant,
19    FieldPlaintextBinaryOp,
20    FieldPlaintextUnaryOp,
21    FieldShareBinaryOp,
22    FieldShareUnaryOp,
23    FieldType,
24    GateIndex,
25    Input,
26    PointPlaintextBinaryOp,
27    PointPlaintextUnaryOp,
28    PointShareBinaryOp,
29    PointShareUnaryOp,
30    Slice,
31};
32
33/// Gate operations, where the operation arguments correspond to _wire_ label.
34#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, SchemaRead, SchemaWrite)]
35#[serde(bound(
36    serialize = "Scalar<C>: Serialize, Point<C>: Serialize",
37    deserialize = "Scalar<C>: Deserialize<'de>, Point<C>: Deserialize<'de>"
38))]
39#[repr(C)]
40pub enum Gate<C: Curve> {
41    /// Input a wire
42    Input(Input),
43    /// Input a constant value
44    Constant(Constant<C>),
45    /// Generate random shares
46    Random {
47        algebraic_type: AlgebraicType,
48        batch_size: BatchSize,
49    },
50    /// Field share unary operations
51    FieldShareUnaryOp {
52        x: GateIndex,
53        op: FieldShareUnaryOp,
54    },
55    /// Field share binary operations, where the second wire may be a plaintext.
56    FieldShareBinaryOp {
57        x: GateIndex,
58        y: GateIndex,
59        op: FieldShareBinaryOp,
60    },
61    BatchSummation {
62        x: GateIndex,
63    },
64    BitShareUnaryOp {
65        x: GateIndex,
66        op: BitShareUnaryOp,
67    },
68    BitShareBinaryOp {
69        x: GateIndex,
70        y: GateIndex,
71        op: BitShareBinaryOp,
72    },
73    /// Operations with elliptic curve points
74    PointShareUnaryOp {
75        p: GateIndex,
76        op: PointShareUnaryOp,
77    },
78    PointShareBinaryOp {
79        p: GateIndex,
80        y: GateIndex,
81        op: PointShareBinaryOp,
82    },
83    /// Field plaintext unary operations
84    FieldPlaintextUnaryOp {
85        x: GateIndex,
86        op: FieldPlaintextUnaryOp,
87    },
88    /// Field plaintext binary operations
89    FieldPlaintextBinaryOp {
90        x: GateIndex,
91        y: GateIndex,
92        op: FieldPlaintextBinaryOp,
93    },
94    BitPlaintextUnaryOp {
95        x: GateIndex,
96        op: BitPlaintextUnaryOp,
97    },
98    BitPlaintextBinaryOp {
99        x: GateIndex,
100        y: GateIndex,
101        op: BitPlaintextBinaryOp,
102    },
103    PointPlaintextUnaryOp {
104        p: GateIndex,
105        op: PointPlaintextUnaryOp,
106    },
107    PointPlaintextBinaryOp {
108        p: GateIndex,
109        y: GateIndex,
110        op: PointPlaintextBinaryOp,
111    },
112    /// Request a daBit
113    DaBit {
114        field_type: FieldType,
115        batch_size: BatchSize,
116    },
117    GetDaBitFieldShare {
118        x: GateIndex,
119    },
120    GetDaBitSharedBit {
121        x: GateIndex,
122    },
123    /// Base field exponentiation operation
124    BaseFieldPow {
125        x: GateIndex,
126        exp: BoxedUint,
127    },
128    /// Bit plaintext conversion operations
129    BitPlaintextToField {
130        x: GateIndex,
131        field_type: FieldType,
132    },
133    FieldPlaintextToBit {
134        x: GateIndex,
135    },
136    /// Get a slice of elements from a batched wire
137    ExtractFromBatch {
138        x: GateIndex,
139        slice: Slice,
140    },
141    CollectToBatch {
142        wires: Vec<GateIndex>,
143    },
144    PointFromPlaintextExtendedEdwards {
145        wires: Vec<GateIndex>,
146    },
147    PlaintextPointToExtendedEdwards {
148        point: GateIndex,
149    },
150    PlaintextKeccakF1600 {
151        x: GateIndex,
152    },
153    CompressPlaintextPoint {
154        point: GateIndex,
155    },
156    KeyRecoveryPlaintextComputeErrors {
157        d_minus_one: GateIndex,
158        syndromes: GateIndex,
159    },
160}
161
162impl<C: Curve> Gate<C> {
163    /// Check whether the gate is an input gate.
164    pub fn is_input(&self) -> bool {
165        matches!(self, Gate::Input { .. })
166    }
167
168    /// Returns the indices of gate inputs.
169    pub fn get_inputs(&self) -> Vec<GateIndex> {
170        match &self {
171            Gate::Input(_) | Gate::Random { .. } | Gate::Constant(_) | Gate::DaBit { .. } => {
172                Vec::new()
173            }
174
175            Gate::FieldShareUnaryOp { x, .. }
176            | Gate::BatchSummation { x, .. }
177            | Gate::BitShareUnaryOp { x, .. }
178            | Gate::PointShareUnaryOp { p: x, .. }
179            | Gate::FieldPlaintextUnaryOp { x, .. }
180            | Gate::BitPlaintextUnaryOp { x, .. }
181            | Gate::PointPlaintextUnaryOp { p: x, .. }
182            | Gate::GetDaBitFieldShare { x, .. }
183            | Gate::GetDaBitSharedBit { x, .. }
184            | Gate::BaseFieldPow { x, .. }
185            | Gate::BitPlaintextToField { x, .. }
186            | Gate::FieldPlaintextToBit { x, .. }
187            | Gate::ExtractFromBatch { x, .. }
188            | Gate::PlaintextPointToExtendedEdwards { point: x, .. }
189            | Gate::CompressPlaintextPoint { point: x, .. }
190            | Gate::PlaintextKeccakF1600 { x } => {
191                vec![*x]
192            }
193
194            Gate::FieldShareBinaryOp { x, y, .. }
195            | Gate::BitShareBinaryOp { x, y, .. }
196            | Gate::PointShareBinaryOp { p: x, y, .. }
197            | Gate::FieldPlaintextBinaryOp { x, y, .. }
198            | Gate::BitPlaintextBinaryOp { x, y, .. }
199            | Gate::PointPlaintextBinaryOp { p: x, y, .. }
200            | Gate::KeyRecoveryPlaintextComputeErrors {
201                d_minus_one: x,
202                syndromes: y,
203                ..
204            } => {
205                vec![*x, *y]
206            }
207
208            Gate::CollectToBatch { wires, .. }
209            | Gate::PointFromPlaintextExtendedEdwards { wires, .. } => wires.clone(),
210        }
211    }
212
213    /// Maps inplace gate inputs using the given function.
214    pub fn map_inputs<F: FnMut(GateIndex) -> GateIndex>(mut self, mut f: F) -> Self {
215        match &mut self {
216            Gate::Input(_) | Gate::Random { .. } | Gate::Constant(_) | Gate::DaBit { .. } => (),
217
218            Gate::FieldShareUnaryOp { x, .. }
219            | Gate::BatchSummation { x, .. }
220            | Gate::BitShareUnaryOp { x, .. }
221            | Gate::PointShareUnaryOp { p: x, .. }
222            | Gate::FieldPlaintextUnaryOp { x, .. }
223            | Gate::BitPlaintextUnaryOp { x, .. }
224            | Gate::PointPlaintextUnaryOp { p: x, .. }
225            | Gate::GetDaBitFieldShare { x, .. }
226            | Gate::GetDaBitSharedBit { x, .. }
227            | Gate::BaseFieldPow { x, .. }
228            | Gate::BitPlaintextToField { x, .. }
229            | Gate::FieldPlaintextToBit { x, .. }
230            | Gate::ExtractFromBatch { x, .. }
231            | Gate::PlaintextPointToExtendedEdwards { point: x, .. }
232            | Gate::CompressPlaintextPoint { point: x, .. }
233            | Gate::PlaintextKeccakF1600 { x } => {
234                *x = f(*x);
235            }
236
237            Gate::FieldShareBinaryOp { x, y, .. }
238            | Gate::BitShareBinaryOp { x, y, .. }
239            | Gate::PointShareBinaryOp { p: x, y, .. }
240            | Gate::FieldPlaintextBinaryOp { x, y, .. }
241            | Gate::BitPlaintextBinaryOp { x, y, .. }
242            | Gate::PointPlaintextBinaryOp { p: x, y, .. }
243            | Gate::KeyRecoveryPlaintextComputeErrors {
244                d_minus_one: x,
245                syndromes: y,
246                ..
247            } => {
248                *x = f(*x);
249                *y = f(*y);
250            }
251
252            Gate::CollectToBatch { wires, .. }
253            | Gate::PointFromPlaintextExtendedEdwards { wires, .. } => {
254                wires.iter_mut().for_each(|x| *x = f(*x))
255            }
256        };
257
258        self
259    }
260
261    /// Tries to replace the gate inputs with the given ones.
262    ///
263    /// This function returns an error if the number of given inputs does not match the number of
264    /// gate inputs.
265    pub fn try_replace_inputs(mut self, inputs: Vec<GateIndex>) -> Result<Self, CircuitError<C>> {
266        if inputs.len() != self.get_inputs().len() {
267            return Err(CircuitError::InvalidGateInputCount {
268                expected: self.get_inputs().len(),
269                found: inputs.len(),
270            });
271        }
272
273        match &mut self {
274            Gate::Input(_) | Gate::Random { .. } | Gate::Constant(_) | Gate::DaBit { .. } => (),
275
276            Gate::FieldShareUnaryOp { x, .. }
277            | Gate::BatchSummation { x, .. }
278            | Gate::BitShareUnaryOp { x, .. }
279            | Gate::PointShareUnaryOp { p: x, .. }
280            | Gate::FieldPlaintextUnaryOp { x, .. }
281            | Gate::BitPlaintextUnaryOp { x, .. }
282            | Gate::PointPlaintextUnaryOp { p: x, .. }
283            | Gate::GetDaBitFieldShare { x, .. }
284            | Gate::GetDaBitSharedBit { x, .. }
285            | Gate::BaseFieldPow { x, .. }
286            | Gate::BitPlaintextToField { x, .. }
287            | Gate::FieldPlaintextToBit { x, .. }
288            | Gate::ExtractFromBatch { x, .. }
289            | Gate::PlaintextPointToExtendedEdwards { point: x, .. }
290            | Gate::CompressPlaintextPoint { point: x, .. }
291            | Gate::PlaintextKeccakF1600 { x } => {
292                *x = inputs[0];
293            }
294
295            Gate::FieldShareBinaryOp { x, y, .. }
296            | Gate::BitShareBinaryOp { x, y, .. }
297            | Gate::PointShareBinaryOp { p: x, y, .. }
298            | Gate::FieldPlaintextBinaryOp { x, y, .. }
299            | Gate::BitPlaintextBinaryOp { x, y, .. }
300            | Gate::PointPlaintextBinaryOp { p: x, y, .. }
301            | Gate::KeyRecoveryPlaintextComputeErrors {
302                d_minus_one: x,
303                syndromes: y,
304                ..
305            } => {
306                *x = inputs[0];
307                *y = inputs[1];
308            }
309
310            Gate::CollectToBatch { wires, .. }
311            | Gate::PointFromPlaintextExtendedEdwards { wires, .. } => *wires = inputs,
312        };
313
314        Ok(self)
315    }
316}
317
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use primitives::algebra::elliptic_curve::Curve25519Ristretto as C;

    use super::*;
    use crate::circuit::{
        preprocessing::{CircuitPreprocessing, FieldCircuitPreprocessing},
        FieldShareBinaryOp,
    };

    /// A gate must survive a bincode round trip unchanged, and a gate and its decoded copy must
    /// compare and hash as the same value.
    #[test]
    fn test_ser_gate() {
        let gates: [Gate<C>; 2] = [
            Gate::FieldShareBinaryOp {
                x: 1,
                y: 3,
                op: FieldShareBinaryOp::Add,
            },
            Gate::PointShareBinaryOp {
                p: 1,
                y: 3,
                op: PointShareBinaryOp::Add,
            },
        ];

        let mut distinct = HashSet::new();
        for gate in gates {
            let encoded = bincode::serialize(&gate).unwrap();
            let decoded: Gate<C> = bincode::deserialize(&encoded).unwrap();

            assert_eq!(gate, decoded);

            distinct.insert(gate);
            distinct.insert(decoded);
        }

        // Two originals plus two decoded copies collapse to two distinct gates.
        assert_eq!(distinct.len(), 2);
    }

    /// Adding two `CircuitPreprocessing` values sums every counter and merges the pow-pair
    /// entries, adding up the counts of exponents present on both sides.
    #[test]
    fn test_circuit_preprocessing_add() {
        let lhs = CircuitPreprocessing {
            bit_singlets: 0,
            bit_triples: 1,
            base_field: FieldCircuitPreprocessing {
                singlets: 3,
                triples: 4,
                dabits: 2,
            },
            scalar: FieldCircuitPreprocessing {
                singlets: 1,
                triples: 2,
                dabits: 1,
            },
            base_field_pow_pairs: vec![
                (BoxedUint::from(vec![21]), 5),
                (BoxedUint::from(vec![14]), 6),
            ]
            .into_iter()
            .collect(),
            mersenne107: FieldCircuitPreprocessing {
                singlets: 0,
                triples: 0,
                dabits: 0,
            },
        };
        let rhs = CircuitPreprocessing {
            bit_singlets: 3,
            bit_triples: 4,
            base_field: FieldCircuitPreprocessing {
                singlets: 0,
                triples: 5,
                dabits: 3,
            },
            scalar: FieldCircuitPreprocessing {
                singlets: 2,
                triples: 3,
                dabits: 2,
            },
            base_field_pow_pairs: vec![
                (BoxedUint::from(vec![21]), 6),
                (BoxedUint::from(vec![13]), 7),
            ]
            .into_iter()
            .collect(),
            mersenne107: FieldCircuitPreprocessing {
                singlets: 3,
                triples: 2,
                dabits: 0,
            },
        };

        let total = lhs + rhs;

        // Bit preprocessing.
        assert_eq!(total.bit_singlets, 3);
        assert_eq!(total.bit_triples, 5);

        // Scalar field.
        assert_eq!(total.scalar.singlets, 3);
        assert_eq!(total.scalar.triples, 5);
        assert_eq!(total.scalar.dabits, 3);

        // Base field.
        assert_eq!(total.base_field.singlets, 3);
        assert_eq!(total.base_field.triples, 9);
        assert_eq!(total.base_field.dabits, 5);

        // Mersenne-107 field.
        assert_eq!(total.mersenne107.singlets, 3);
        assert_eq!(total.mersenne107.triples, 2);
        assert_eq!(total.mersenne107.dabits, 0);

        // Exponent 21 occurs on both sides (5 + 6); 14 and 13 occur on one side only.
        for (exponent, count) in [(21, 11), (14, 6), (13, 7)] {
            assert_eq!(
                total.base_field_pow_pairs.get(&BoxedUint::from(vec![exponent])),
                Some(&count)
            );
        }
    }
}