Skip to main content

core_utils/circuit/
circuit.rs

1use primitives::algebra::elliptic_curve::Curve;
2
3use crate::{
4    circuit::{
5        errors::{CircuitError, ConversionError},
6        gate::Gate,
7        AlgebraicType,
8        BitShareBinaryOp,
9        BitShareUnaryOp,
10        FieldShareBinaryOp,
11        FieldShareUnaryOp,
12        FieldType,
13        GateIndex,
14        Input,
15        PointPlaintextBinaryOp,
16        PointShareBinaryOp,
17        PointShareUnaryOp,
18        ShareOrPlaintext,
19    },
20    key_recovery::{MXE_KEY_RECOVERY_D, MXE_KEY_RECOVERY_N},
21};
22
/// A circuit composed of a sequence of gates, input and output identifiers.
///
/// Each circuit gate carries additional information about its output characteristics; refer to
/// the [`GateExt`] struct for more details.
///
/// Gates added through `add_gate` are validated before they are stored.
#[derive(Default, PartialEq, Debug, Clone)]
pub struct Circuit<C: Curve> {
    /// The circuit gates, in insertion order. A gate's index is its position in this vector.
    pub(super) gates: Vec<GateExt<C>>,
    /// The indices of the input gates, in order of definition.
    pub(super) inputs: Vec<GateIndex>,
    /// The indices of the gates declared as circuit outputs, in order of definition.
    pub(super) outputs: Vec<GateIndex>,
}
38
/// A circuit gate together with additional information about its output.
///
/// The additional information is automatically deduced when the gate is added to the circuit.
#[derive(Clone, Debug, PartialEq)]
pub struct GateExt<C: Curve> {
    /// The gate itself.
    pub gate: Gate<C>,
    /// Characteristics of the value produced by the gate (type, form and batch size).
    pub output: GateOutput,
    /// Position of the gate in the execution schedule, see [`GateLevel`].
    pub level: GateLevel,
}
47
/// Gate output characteristics: algebraic type, visibility (share or plaintext) and batch size.
#[derive(PartialEq, Copy, Clone, Debug)]
pub struct GateOutput {
    /// Algebraic type of the produced value(s).
    pub(super) algebraic_type: AlgebraicType,
    /// Whether the produced value is secret-shared or in plaintext.
    pub(super) form: ShareOrPlaintext,
    /// Number of elements in the produced batch.
    pub(super) batch_size: u32,
}
55
/// The level of a gate in a circuit. All gates with the same level can be executed in parallel as
/// they do not depend on each other.
///
/// The gate level is a pair of integers: the first one is the communication round (level), and the
/// second one is a relative level within the same communication round. The communication round of
/// a gate is the number of communication rounds that have elapsed once the gate has been executed.
///
/// The derived ordering compares `comm_level` first and `level` second.
#[derive(Copy, Clone, Debug, Ord, PartialOrd, Eq, PartialEq, Default)]
pub struct GateLevel {
    /// Communication round the gate belongs to.
    comm_level: usize,
    level: usize, // relative counter for ordering gates within a multiplicative level
}
67
68impl<C: Curve> GateExt<C> {
69    pub fn new(gate: Gate<C>, output: GateOutput, level: GateLevel) -> Self {
70        Self {
71            gate,
72            output,
73            level,
74        }
75    }
76}
77
78impl GateOutput {
79    pub fn get_type(&self) -> AlgebraicType {
80        self.algebraic_type
81    }
82
83    pub fn get_field_type(&self) -> Result<FieldType, ConversionError> {
84        FieldType::try_from(self.algebraic_type)
85    }
86
87    pub fn get_field_type_unchecked(&self) -> FieldType {
88        self.get_field_type().unwrap()
89    }
90
91    pub fn get_form(&self) -> ShareOrPlaintext {
92        self.form
93    }
94
95    pub fn get_batch_size(&self) -> u32 {
96        self.batch_size
97    }
98
99    fn is_field(&self) -> bool {
100        FieldType::try_from(self.algebraic_type).is_ok()
101    }
102
103    fn is_bit(&self) -> bool {
104        self.algebraic_type == AlgebraicType::Bit
105    }
106
107    fn is_point(&self) -> bool {
108        self.algebraic_type == AlgebraicType::Point
109    }
110
111    fn is_base_field(&self) -> bool {
112        self.algebraic_type == AlgebraicType::BaseField
113    }
114
115    fn is_scalar_field(&self) -> bool {
116        self.algebraic_type == AlgebraicType::ScalarField
117    }
118
119    fn is_share(&self) -> bool {
120        self.form == ShareOrPlaintext::Share
121    }
122
123    pub fn is_plaintext(&self) -> bool {
124        self.form == ShareOrPlaintext::Plaintext
125    }
126
127    fn with_new_form(self, form: ShareOrPlaintext) -> Self {
128        let mut res = self;
129        res.form = form;
130        res
131    }
132
133    fn with_new_type(self, algebraic_type: AlgebraicType) -> Self {
134        let mut res = self;
135        res.algebraic_type = algebraic_type;
136        res
137    }
138
139    fn with_new_batch_size(self, batch_size: u32) -> Self {
140        let mut res = self;
141        res.batch_size = batch_size;
142        res
143    }
144}
145
146impl GateLevel {
147    fn next(&self, comm_rounds: usize) -> GateLevel {
148        if comm_rounds > 0 {
149            GateLevel {
150                comm_level: self.comm_level + comm_rounds,
151                level: 0,
152            }
153        } else {
154            GateLevel {
155                comm_level: self.comm_level,
156                level: self.level + 1,
157            }
158        }
159    }
160
161    pub fn comm_level(&self) -> usize {
162        self.comm_level
163    }
164}
165
166impl<C: Curve> Circuit<C> {
167    pub fn new() -> Self {
168        Self::default()
169    }
170
171    /// Tries to add a gate to the circuit.
172    ///
173    /// This function validates the gate before adding the gate or fails otherwise.
174    pub fn add_gate(&mut self, gate: Gate<C>) -> Result<GateIndex, CircuitError<C>> {
175        self.validate_gate(&gate)?;
176
177        let index = self.nb_gates();
178        if index == GateIndex::MAX {
179            return Err(CircuitError::CircuitTooBig);
180        }
181
182        let gate_output = self.comp_gate_output(&gate);
183        let level = self.comp_gate_level(&gate);
184
185        if gate.is_input() {
186            self.inputs.push(index);
187        }
188        self.gates.push(GateExt::new(gate, gate_output?, level));
189
190        Ok(index)
191    }
192
193    /// Tries to set a gate as circuit output.
194    ///
195    /// This function fails if there is no gate with the given index.
196    pub fn add_output(&mut self, index: GateIndex) -> Result<(), CircuitError<C>> {
197        if index < self.nb_gates() {
198            self.outputs.push(index);
199            Ok(())
200        } else {
201            Err(CircuitError::GateIndexOutOfBounds(index, self.nb_gates()))
202        }
203    }
204
205    pub fn nb_gates(&self) -> GateIndex {
206        self.gates.len() as GateIndex
207    }
208
209    pub fn nb_inputs(&self) -> GateIndex {
210        self.inputs.len() as GateIndex
211    }
212
213    pub fn nb_outputs(&self) -> GateIndex {
214        self.outputs.len() as GateIndex
215    }
216
217    /// Consumes the circuit and returns the list of gates.
218    pub fn into_gates(self) -> Vec<GateExt<C>> {
219        self.gates
220    }
221
222    pub fn iter_gates_ext(
223        &self,
224    ) -> impl ExactSizeIterator<Item = &GateExt<C>> + DoubleEndedIterator {
225        self.gates.iter()
226    }
227
228    pub fn iter_gates(&self) -> impl ExactSizeIterator<Item = &Gate<C>> + DoubleEndedIterator {
229        self.gates.iter().map(|g| &g.gate)
230    }
231
232    pub fn iter_output_indices(&self) -> impl ExactSizeIterator<Item = &GateIndex> {
233        self.outputs.iter()
234    }
235
236    pub fn iter_input_indices(&self) -> impl ExactSizeIterator<Item = &GateIndex> {
237        self.inputs.iter()
238    }
239
240    pub fn gate_ext(&self, index: GateIndex) -> Result<&GateExt<C>, CircuitError<C>> {
241        if index < self.nb_gates() {
242            Ok(&self.gates[index as usize])
243        } else {
244            Err(CircuitError::GateIndexOutOfBounds(index, self.nb_gates()))
245        }
246    }
247
248    pub fn gate_ext_unchecked(&self, index: GateIndex) -> &GateExt<C> {
249        &self.gates[index as usize]
250    }
251
252    pub fn gate(&self, index: GateIndex) -> Result<&Gate<C>, CircuitError<C>> {
253        self.gate_ext(index).map(|g| &g.gate)
254    }
255
256    pub fn gate_unchecked(&self, index: GateIndex) -> &Gate<C> {
257        &self.gate_ext_unchecked(index).gate
258    }
259
260    pub fn gate_output(&self, index: GateIndex) -> Result<GateOutput, CircuitError<C>> {
261        self.gate_ext(index).map(|g| g.output)
262    }
263
264    pub fn gate_output_unchecked(&self, index: GateIndex) -> GateOutput {
265        self.gate_ext_unchecked(index).output
266    }
267
268    pub fn gate_level(&self, index: GateIndex) -> Result<GateLevel, CircuitError<C>> {
269        self.gate_ext(index).map(|g| g.level)
270    }
271
272    pub fn gate_level_unchecked(&self, index: GateIndex) -> GateLevel {
273        self.gate_ext_unchecked(index).level
274    }
275}
276
/// Returns `CircuitError::InvalidGateAlgebraicType` from the enclosing function when the two
/// algebraic types differ.
///
/// Both arguments are evaluated exactly once, even on the error path, so expressions with side
/// effects or a `?` inside them are not executed twice.
macro_rules! check_algebraic_type {
    ($exp_type:expr, $found_type:expr) => {{
        let (expected, found) = ($exp_type, $found_type);
        if expected != found {
            return Err(CircuitError::InvalidGateAlgebraicType { expected, found });
        }
    }};
}
287
288impl<C: Curve> Circuit<C> {
289    /// Opens and outputs as scalar field a given gate
290    pub fn open_and_output_scalar(&mut self, x: GateIndex) -> Result<(), CircuitError<C>> {
291        check_algebraic_type!(AlgebraicType::ScalarField, self.gate_output(x)?.get_type());
292        let opening_index = self.add_gate(Gate::FieldShareUnaryOp {
293            x,
294            op: FieldShareUnaryOp::Open,
295        })?;
296        self.add_output(opening_index)
297    }
298
299    /// Opens and outputs as base field a given gate
300    pub fn open_and_output_base_field(&mut self, x: GateIndex) -> Result<(), CircuitError<C>> {
301        check_algebraic_type!(AlgebraicType::BaseField, self.gate_output(x)?.get_type());
302        let opening_index = self.add_gate(Gate::FieldShareUnaryOp {
303            x,
304            op: FieldShareUnaryOp::Open,
305        })?;
306        self.add_output(opening_index)
307    }
308
309    /// Opens and outputs as Mersenne107 a given gate
310    pub fn open_and_output_mersenne107(&mut self, x: GateIndex) -> Result<(), CircuitError<C>> {
311        check_algebraic_type!(AlgebraicType::Mersenne107, self.gate_output(x)?.get_type());
312        let opening_index = self.add_gate(Gate::FieldShareUnaryOp {
313            x,
314            op: FieldShareUnaryOp::Open,
315        })?;
316        self.add_output(opening_index)
317    }
318
319    /// Opens and outputs as point a given gate
320    pub fn open_and_output_point(&mut self, p: GateIndex) -> Result<(), CircuitError<C>> {
321        check_algebraic_type!(AlgebraicType::Point, self.gate_output(p)?.get_type());
322        let opening_index = self.add_gate(Gate::PointShareUnaryOp {
323            p,
324            op: PointShareUnaryOp::Open,
325        })?;
326        self.add_output(opening_index)
327    }
328
329    /// Opens and outputs as bit a given gate
330    pub fn open_and_output_bit(&mut self, x: GateIndex) -> Result<(), CircuitError<C>> {
331        check_algebraic_type!(AlgebraicType::Bit, self.gate_output(x)?.get_type());
332        let opening_index = self.add_gate(Gate::BitShareUnaryOp {
333            x,
334            op: BitShareUnaryOp::Open,
335        })?;
336        self.add_output(opening_index)
337    }
338}
339
340impl<C: Curve> Circuit<C> {
    /// Validates a gate against the current content of the circuit.
    ///
    /// Checks that:
    ///     - gate inputs are present in the circuit
    ///     - gate input types correspond to the specification
    ///     - gate input batch sizes are compatible
    ///     - gate parameters are valid
    ///
    /// # Errors
    ///
    /// Returns an out-of-bounds error (through `gate_ext`) when a referenced gate does not
    /// exist, and `CircuitError::InvalidGate` with a descriptive message when any other check
    /// fails. Errors raised by `constant.batch_size()` are propagated as is.
    fn validate_gate(&self, gate: &Gate<C>) -> Result<(), CircuitError<C>> {
        // Checks one or more `lhs op rhs` relations; on the first failing one, returns
        // `InvalidGate` carrying a clone of the gate and a message built from `$msg` and the
        // offending operands. Note that the operands are expanded both in the condition and in
        // the message.
        macro_rules! check_op {
            ($msg:expr, $gate:expr => $($val1:expr, $op:tt, $val2:expr);+$(;)?) => {
                $(if !($val1 $op $val2) {
                    return Err(CircuitError::InvalidGate(
                        $gate.clone(),
                        format!("{}: {:?} {:?} {:?}", $msg, $val1, stringify!($op), $val2),
                    ));
                })+
            };
        }

        // Calls each listed `GateOutput` predicate on the output of an already-stored gate
        // (`$gate` is a `GateExt`); on the first predicate returning false, returns
        // `InvalidGate` naming that predicate.
        macro_rules! check_gate_properties {
            ($gate:expr, $($func:ident),* $(,)?) => {
                $(if !($gate.output.$func()) {
                    return Err(CircuitError::InvalidGate(
                        $gate.gate.clone(),
                        format!("{:?} fails - {}", $gate.output, stringify!($func)))); }
                )*
            };
        }

        match gate {
            // Source gates: only the batch size has to be checked.
            Gate::Input(input) => {
                check_op!(
                    "input batch size must be non-zero",
                    gate =>
                    0, <, input.batch_size();
                );
            }
            Gate::Constant(constant) => {
                check_op!(
                    "constant batch size must be non-zero",
                    gate =>
                     0, <, constant.batch_size()?;
                );
            }
            Gate::Random { batch_size, .. } => {
                check_op!(
                    "random batch size must be non-zero",
                    gate =>
                     0, <, *batch_size;
                );
            }
            // Operations on shares: the first operand must be a share, the second operand (if
            // any) only has to have a compatible type and batch size.
            Gate::FieldShareUnaryOp { x, .. } => {
                check_gate_properties!(self.gate_ext(*x)?, is_field, is_share);
            }
            Gate::FieldShareBinaryOp { x, y, .. } => {
                let (gx, gy) = (self.gate_ext(*x)?, self.gate_ext(*y)?);
                check_gate_properties!(gx, is_field, is_share);
                check_gate_properties!(gy, is_field);
                check_op!(
                    "inputs must have same batch-size and field type",
                    gate =>
                    gx.output.batch_size, ==, gy.output.batch_size;
                    gx.output.algebraic_type, ==, gy.output.algebraic_type
                );
            }
            Gate::BatchSummation { x } => {
                // Only the existence of the input gate is checked here.
                self.gate_ext(*x)?;
            }
            Gate::BitShareUnaryOp { x, .. } => {
                check_gate_properties!(self.gate_ext(*x)?, is_bit, is_share);
            }
            Gate::BitShareBinaryOp { x, y, .. } => {
                let (gx, gy) = (self.gate_ext(*x)?, self.gate_ext(*y)?);
                check_gate_properties!(gx, is_bit, is_share);
                check_gate_properties!(gy, is_bit);
                check_op!(
                    "inputs must have same batch-size",
                    gate =>
                    gx.output.batch_size, ==, gy.output.batch_size
                );
            }
            Gate::PointShareUnaryOp { p: x, .. } => {
                check_gate_properties!(self.gate_ext(*x)?, is_point, is_share);
            }
            Gate::PointShareBinaryOp { p: x, y, op } => {
                let (gx, gy) = (self.gate_ext(*x)?, self.gate_ext(*y)?);
                // Unlike the field/bit variants, either operand may be the share.
                if gx.output.is_plaintext() && gy.output.is_plaintext() {
                    return Err(CircuitError::InvalidGate(
                        gate.clone(),
                        "at least one input must be share".to_string(),
                    ));
                }
                check_gate_properties!(gx, is_point);
                match op {
                    PointShareBinaryOp::Add => {
                        check_gate_properties!(gy, is_point);
                    }
                    PointShareBinaryOp::ScalarMul => {
                        check_gate_properties!(gy, is_scalar_field);
                    }
                };
                check_op!(
                    "inputs must have same batch-size",
                    gate =>
                    gx.output.batch_size, ==, gy.output.batch_size
                );
            }
            // Operations on plaintexts: every operand must be a plaintext.
            Gate::FieldPlaintextUnaryOp { x, .. } => {
                check_gate_properties!(self.gate_ext(*x)?, is_field, is_plaintext);
            }
            Gate::FieldPlaintextBinaryOp { x, y, .. } => {
                let (gx, gy) = (self.gate_ext(*x)?, self.gate_ext(*y)?);
                check_gate_properties!(gx, is_field, is_plaintext);
                check_gate_properties!(gy, is_field, is_plaintext);
                // NOTE(review): the message only mentions the field type although the batch
                // size is checked as well.
                check_op!(
                    "inputs must have same field type",
                    gate =>
                    gx.output.algebraic_type, ==, gy.output.algebraic_type;
                    gx.output.batch_size, ==, gy.output.batch_size
                );
            }
            Gate::BitPlaintextUnaryOp { x, .. } => {
                check_gate_properties!(self.gate_ext(*x)?, is_bit, is_plaintext);
            }
            Gate::BitPlaintextBinaryOp { x, y, .. } => {
                let (gx, gy) = (self.gate_ext(*x)?, self.gate_ext(*y)?);
                check_gate_properties!(gx, is_bit, is_plaintext);
                check_gate_properties!(gy, is_bit, is_plaintext);
                check_op!(
                    "inputs must have same batch-size",
                    gate =>
                    gx.output.batch_size, ==, gy.output.batch_size
                );
            }
            Gate::PointPlaintextUnaryOp { p: x, .. } => {
                check_gate_properties!(self.gate_ext(*x)?, is_point, is_plaintext);
            }
            Gate::PointPlaintextBinaryOp { p: x, y, op } => {
                let (gx, gy) = (self.gate_ext(*x)?, self.gate_ext(*y)?);
                check_gate_properties!(gx, is_point, is_plaintext);
                match op {
                    PointPlaintextBinaryOp::Add => {
                        check_gate_properties!(gy, is_point, is_plaintext);
                    }
                    PointPlaintextBinaryOp::ScalarMul => {
                        check_gate_properties!(gy, is_scalar_field, is_plaintext);
                    }
                }
                check_op!(
                    "inputs must have same batch-size",
                    gate =>
                    gx.output.batch_size, ==, gy.output.batch_size
                );
            }
            Gate::DaBit { batch_size, .. } => {
                check_op!(
                    "input batch size must be non-zero",
                    gate =>
                    0, <, *batch_size
                );
            }
            Gate::GetDaBitFieldShare { x, .. } => {
                // By convention, we suppose that the `DaBit` output is a field element
                check_gate_properties!(self.gate_ext(*x)?, is_field, is_share);
            }
            Gate::GetDaBitSharedBit { x, .. } => {
                // By convention, we suppose that the `DaBit` output is a field element
                check_gate_properties!(self.gate_ext(*x)?, is_field, is_share);
            }
            Gate::BaseFieldPow { x, .. } => {
                check_gate_properties!(self.gate_ext(*x)?, is_base_field, is_share);
            }
            // Plaintext conversions between bits and field elements.
            Gate::BitPlaintextToField { x, .. } => {
                check_gate_properties!(self.gate_ext(*x)?, is_bit, is_plaintext);
            }
            Gate::FieldPlaintextToBit { x, .. } => {
                check_gate_properties!(self.gate_ext(*x)?, is_field, is_plaintext);
            }
            // Batch manipulation.
            Gate::ExtractFromBatch { x, slice, .. } => {
                let gx = self.gate_ext(*x)?;
                if slice.is_empty() {
                    return Err(CircuitError::InvalidGate(
                        gate.clone(),
                        format!("slice must be non-empty: {slice:?}"),
                    ));
                }
                // The largest selected index must address an element of the input batch.
                if slice.get_indices()
                    .into_iter()
                    .max()
                    .expect("non-empty slice expected") // never fails as we check that the slice is non-empty
                    >= gx.output.batch_size
                {
                    return Err(CircuitError::InvalidGate(
                        gate.clone(),
                        format!("slice indices out-of-range: {slice:?}"),
                    ));
                }
            }
            Gate::CollectToBatch { wires } => {
                check_op!("expected at least one input", gate => 0, <, wires.len());
                // Every wire must agree with the first one on type and form; batch sizes may
                // differ.
                let first = self.gate_ext(wires[0])?.output;
                for x in wires.iter().skip(1) {
                    let gx = self.gate_ext(*x)?.output;
                    check_op!(
                        "all inputs must have the same type",
                        gate =>
                        first.algebraic_type, ==, gx.algebraic_type;
                        first.form, ==, gx.form
                    );
                }
            }
            // Special-purpose plaintext gates with fixed input shapes.
            Gate::PointFromPlaintextExtendedEdwards { wires } => {
                check_op!("expected exactly 4 inputs", gate => wires.len(), ==, 4);
                for x in wires {
                    let gx = self.gate_ext(*x)?;
                    check_gate_properties!(gx, is_base_field, is_plaintext);
                    check_op!("expected batch-size 1", gate => gx.output.batch_size, ==, 1);
                }
            }
            Gate::PlaintextPointToExtendedEdwards { point: x, .. } => {
                let gx = self.gate_ext(*x)?;
                check_gate_properties!(gx, is_point, is_plaintext);
                check_op!("expected batch-size 1", gate => gx.output.batch_size, ==, 1);
            }
            Gate::PlaintextKeccakF1600 { x } => {
                let gx = self.gate_ext(*x)?;
                check_gate_properties!(gx, is_bit, is_plaintext);
                check_op!("expected batch-size 1600", gate => gx.output.batch_size, ==, 1600);
            }
            Gate::CompressPlaintextPoint { point: x, .. } => {
                let gx = self.gate_ext(*x)?;
                check_gate_properties!(gx, is_point, is_plaintext);
                check_op!("expected batch-size 1", gate => gx.output.batch_size, ==, 1);
            }
            Gate::KeyRecoveryPlaintextComputeErrors {
                d_minus_one,
                syndromes,
            } => {
                let g1 = self.gate_ext(*d_minus_one)?;
                let g2 = self.gate_ext(*syndromes)?;
                check_gate_properties!(g1, is_base_field, is_plaintext);
                check_gate_properties!(g2, is_base_field, is_plaintext);

                check_op!("expected batch-size 1", gate => g1.output.batch_size, ==, 1);
                // TODO: Check that the batch size of `g2` is correct.
                // NOTE(review): the check below already compares the batch size of `g2` with
                // `MXE_KEY_RECOVERY_D - 1`; the TODO presumably asks to confirm that this is
                // the right expected value.
                check_op!(format!("expected batch-size {}", MXE_KEY_RECOVERY_D - 1),
                    gate => g2.output.batch_size, ==, MXE_KEY_RECOVERY_D as u32 - 1);
            }
        }

        Ok(())
    }
593
594    /// Computes the output type of gate.
595    ///
596    /// **Note: ** This function can panic if the gate is not valid.
597    fn comp_gate_output(&self, gate: &Gate<C>) -> Result<GateOutput, CircuitError<C>> {
598        let r = match gate {
599            Gate::Input(input_type) => GateOutput {
600                batch_size: input_type.batch_size(),
601                algebraic_type: input_type.algebraic_type(),
602                form: input_type.share_or_plaintext(),
603            },
604
605            Gate::Constant(const_type) => GateOutput {
606                batch_size: const_type.batch_size()?,
607                algebraic_type: const_type.algebraic_type(),
608                form: ShareOrPlaintext::Plaintext,
609            },
610
611            Gate::Random {
612                algebraic_type,
613                batch_size,
614            } => GateOutput {
615                batch_size: *batch_size,
616                algebraic_type: *algebraic_type,
617                form: ShareOrPlaintext::Share,
618            },
619
620            Gate::FieldShareUnaryOp { x, op } => match op {
621                FieldShareUnaryOp::Neg | FieldShareUnaryOp::MulInverse => {
622                    self.gate_output_unchecked(*x)
623                }
624                FieldShareUnaryOp::Open | FieldShareUnaryOp::IsZero => self
625                    .gate_output_unchecked(*x)
626                    .with_new_form(ShareOrPlaintext::Plaintext),
627            },
628
629            Gate::FieldShareBinaryOp { x, .. }
630            | Gate::BitShareBinaryOp { x, .. }
631            | Gate::FieldPlaintextUnaryOp { x, .. }
632            | Gate::FieldPlaintextBinaryOp { x, .. }
633            | Gate::BitPlaintextUnaryOp { x, .. }
634            | Gate::BitPlaintextBinaryOp { x, .. }
635            | Gate::PointPlaintextUnaryOp { p: x, .. }
636            | Gate::PointPlaintextBinaryOp { p: x, .. }
637            | Gate::GetDaBitFieldShare { x, .. }
638            | Gate::BaseFieldPow { x, .. } => self.gate_output_unchecked(*x),
639
640            Gate::BatchSummation { x, .. } => self.gate_output_unchecked(*x).with_new_batch_size(1),
641
642            Gate::PointShareBinaryOp { p: x, .. } => self
643                .gate_output_unchecked(*x)
644                .with_new_form(ShareOrPlaintext::Share),
645
646            Gate::BitShareUnaryOp { x, op } => match op {
647                BitShareUnaryOp::Not => self.gate_output_unchecked(*x),
648                BitShareUnaryOp::Open => self
649                    .gate_output_unchecked(*x)
650                    .with_new_form(ShareOrPlaintext::Plaintext),
651            },
652
653            Gate::PointShareUnaryOp { p: x, op } => match op {
654                PointShareUnaryOp::Neg => self.gate_output_unchecked(*x),
655                PointShareUnaryOp::Open => self
656                    .gate_output_unchecked(*x)
657                    .with_new_form(ShareOrPlaintext::Plaintext),
658                PointShareUnaryOp::IsZero => self
659                    .gate_output_unchecked(*x)
660                    .with_new_form(ShareOrPlaintext::Plaintext)
661                    .with_new_type(AlgebraicType::ScalarField),
662            },
663
664            Gate::DaBit {
665                field_type,
666                batch_size,
667            } => GateOutput {
668                batch_size: *batch_size,
669                algebraic_type: AlgebraicType::from(*field_type),
670                form: ShareOrPlaintext::Share,
671            },
672
673            Gate::GetDaBitSharedBit { x, .. } => self
674                .gate_output_unchecked(*x)
675                .with_new_type(AlgebraicType::Bit),
676
677            Gate::BitPlaintextToField { x, field_type } => self
678                .gate_output_unchecked(*x)
679                .with_new_type(AlgebraicType::from(*field_type)),
680
681            Gate::FieldPlaintextToBit { x } => self
682                .gate_output_unchecked(*x)
683                .with_new_type(AlgebraicType::Bit),
684
685            Gate::ExtractFromBatch { x, slice } => self
686                .gate_output_unchecked(*x)
687                .with_new_batch_size(slice.len()),
688
689            Gate::CollectToBatch { wires, .. } => {
690                let batch_size = wires
691                    .iter()
692                    .map(|x| self.gate_output_unchecked(*x).batch_size)
693                    .sum();
694                self.gate_output_unchecked(wires[0])
695                    .with_new_batch_size(batch_size)
696            }
697
698            Gate::PointFromPlaintextExtendedEdwards { .. } => GateOutput {
699                algebraic_type: AlgebraicType::Point,
700                form: ShareOrPlaintext::Plaintext,
701                batch_size: 1,
702            },
703            Gate::PlaintextPointToExtendedEdwards { .. } => GateOutput {
704                algebraic_type: AlgebraicType::BaseField,
705                form: ShareOrPlaintext::Plaintext,
706                batch_size: 4,
707            },
708            Gate::PlaintextKeccakF1600 { .. } => GateOutput {
709                algebraic_type: AlgebraicType::Bit,
710                form: ShareOrPlaintext::Plaintext,
711                batch_size: 1600,
712            },
713            Gate::CompressPlaintextPoint { .. } => GateOutput {
714                algebraic_type: AlgebraicType::Bit,
715                form: ShareOrPlaintext::Plaintext,
716                batch_size: 256,
717            },
718            Gate::KeyRecoveryPlaintextComputeErrors { .. } => GateOutput {
719                algebraic_type: AlgebraicType::BaseField,
720                form: ShareOrPlaintext::Plaintext,
721                batch_size: MXE_KEY_RECOVERY_N as u32,
722            },
723        };
724
725        Ok(r)
726    }
727
728    /// Computes the number of rounds required to evaluate the gate.
729    ///
730    /// **Note: ** This function can panic if the gate is not valid.
731    fn comp_gate_comm_rounds(&self, gate: &Gate<C>) -> usize {
732        match gate {
733            Gate::Input(input_type) => match input_type {
734                Input::SecretPlaintext { .. } => 1,
735                _ => 0,
736            },
737
738            Gate::Constant(_) | Gate::Random { .. } => 0,
739
740            Gate::FieldShareUnaryOp { op, .. } => match op {
741                FieldShareUnaryOp::Neg => 0,
742                FieldShareUnaryOp::MulInverse => 2,
743                FieldShareUnaryOp::Open => 1,
744                FieldShareUnaryOp::IsZero => 2,
745            },
746            Gate::FieldShareBinaryOp { op, y, .. } => {
747                match (op, self.gate_output_unchecked(*y).form) {
748                    (FieldShareBinaryOp::Mul, ShareOrPlaintext::Share) => 1,
749                    (FieldShareBinaryOp::Mul, ShareOrPlaintext::Plaintext)
750                    | (FieldShareBinaryOp::Add, _) => 0,
751                }
752            }
753            Gate::BatchSummation { .. } => 0,
754            Gate::BitShareUnaryOp { op, .. } => match op {
755                BitShareUnaryOp::Not => 0,
756                BitShareUnaryOp::Open => 1,
757            },
758            Gate::BitShareBinaryOp { op, y, .. } => {
759                match (op, self.gate_output_unchecked(*y).form) {
760                    (BitShareBinaryOp::Xor, _) => 0,
761                    (_, ShareOrPlaintext::Share) => 1,
762                    (_, ShareOrPlaintext::Plaintext) => 0,
763                }
764            }
765            Gate::PointShareUnaryOp { op, .. } => match op {
766                PointShareUnaryOp::Neg => 0,
767                PointShareUnaryOp::Open => 1,
768                PointShareUnaryOp::IsZero => 2,
769            },
770            Gate::PointShareBinaryOp { y, op, .. } => {
771                match (op, self.gate_output_unchecked(*y).form) {
772                    (PointShareBinaryOp::Add, _) => 0,
773                    (PointShareBinaryOp::ScalarMul, ShareOrPlaintext::Share) => 1,
774                    (PointShareBinaryOp::ScalarMul, ShareOrPlaintext::Plaintext) => 0,
775                }
776            }
777
778            Gate::BaseFieldPow { .. } => 2,
779
780            Gate::FieldPlaintextUnaryOp { .. }
781            | Gate::FieldPlaintextBinaryOp { .. }
782            | Gate::BitPlaintextUnaryOp { .. }
783            | Gate::BitPlaintextBinaryOp { .. }
784            | Gate::PointPlaintextUnaryOp { .. }
785            | Gate::PointPlaintextBinaryOp { .. }
786            | Gate::DaBit { .. }
787            | Gate::GetDaBitFieldShare { .. }
788            | Gate::GetDaBitSharedBit { .. }
789            | Gate::BitPlaintextToField { .. }
790            | Gate::FieldPlaintextToBit { .. }
791            | Gate::ExtractFromBatch { .. }
792            | Gate::CollectToBatch { .. }
793            | Gate::PointFromPlaintextExtendedEdwards { .. }
794            | Gate::PlaintextPointToExtendedEdwards { .. }
795            | Gate::PlaintextKeccakF1600 { .. }
796            | Gate::CompressPlaintextPoint { .. }
797            | Gate::KeyRecoveryPlaintextComputeErrors { .. } => 0,
798        }
799    }
800
801    /// Computes the communication level of the gate.
802    ///
803    /// **Note: ** This function can panic if the gate is not valid.
804    fn comp_gate_level(&self, gate: &Gate<C>) -> GateLevel {
805        let comm_rounds = self.comp_gate_comm_rounds(gate);
806        match gate
807            .get_inputs()
808            .iter()
809            .map(|pred| self.gate_level_unchecked(*pred))
810            .max()
811        {
812            None => GateLevel::default(),
813            Some(preds_level) => preds_level.next(comm_rounds),
814        }
815    }
816}
817
#[cfg(test)]
mod tests {
    use primitives::algebra::elliptic_curve::Curve25519Ristretto as C;

    use crate::circuit::{AlgebraicType, Circuit, FieldShareBinaryOp, Gate, Input};

    /// Builds a secret Mersenne107 input gate of batch size 3 provided by inputer 0.
    fn secret_mersenne_input() -> Gate<C> {
        Gate::Input(Input::SecretPlaintext {
            inputer: 0,
            algebraic_type: AlgebraicType::Mersenne107,
            batch_size: 3,
        })
    }

    /// Two secret inputs multiplied together: checks the gate, input and output counters.
    #[test]
    fn test_circuit_new() {
        let mut circuit = Circuit::<C>::new();

        let x = circuit.add_gate(secret_mersenne_input()).unwrap();
        let y = circuit.add_gate(secret_mersenne_input()).unwrap();

        let product = Gate::FieldShareBinaryOp {
            x,
            y,
            op: FieldShareBinaryOp::Mul,
        };
        let z = circuit.add_gate(product).unwrap();

        circuit.add_output(z).unwrap();

        assert_eq!(circuit.nb_inputs(), 2);
        assert_eq!(circuit.nb_gates(), 2 + 1);
        assert_eq!(circuit.nb_outputs(), 1);
    }
}