Skip to main content

core_utils/circuit/
preprocessing.rs

1use std::{
2    collections::HashMap,
3    ops::{Add, AddAssign, Index, IndexMut},
4};
5
6use primitives::algebra::{elliptic_curve::Curve, BoxedUint};
7
8use crate::circuit::{
9    AlgebraicType,
10    BitShareBinaryOp,
11    Circuit,
12    FieldShareBinaryOp,
13    FieldShareUnaryOp,
14    FieldType,
15    Gate,
16    GateExt,
17    Input,
18    PointShareBinaryOp,
19    PointShareUnaryOp,
20    ShareOrPlaintext,
21};
22
23impl<C: Curve> Circuit<C> {
24    /// Counts all the preprocessing required by the circuit.
25    /// This includes scalar triples, base field triples, scalar singlets, base field singlets,
26    /// and base field pow pairs.
27    pub fn required_preprocessing(&self) -> CircuitPreprocessing {
28        let mut circuit_preprocessing = CircuitPreprocessing::default();
29        for gate in self.iter_gates_ext() {
30            self.add_to_required_preprocessing(gate, &mut circuit_preprocessing);
31        }
32        circuit_preprocessing
33    }
34
35    /// Updates the circuit preprocessing structure with the requirements of this gate.
36    pub fn add_to_required_preprocessing(
37        &self,
38        gate: &GateExt<C>,
39        circuit_preprocessing: &mut CircuitPreprocessing,
40    ) {
41        let batch_size = gate.output.get_batch_size() as usize;
42        match &gate.gate {
43            Gate::Input(Input::SecretPlaintext { algebraic_type, .. })
44            | Gate::Random { algebraic_type, .. } => match algebraic_type {
45                AlgebraicType::ScalarField | AlgebraicType::Point => {
46                    circuit_preprocessing.scalar.singlets += batch_size;
47                }
48                AlgebraicType::BaseField => {
49                    circuit_preprocessing.base_field.singlets += batch_size;
50                }
51                AlgebraicType::Bit => {
52                    circuit_preprocessing.bit_singlets += batch_size;
53                }
54                AlgebraicType::Mersenne107 => {
55                    circuit_preprocessing.mersenne107.singlets += batch_size;
56                }
57            },
58            Gate::FieldShareUnaryOp { op, .. } => {
59                let field_type = gate.output.get_field_type_unchecked();
60                match op {
61                    FieldShareUnaryOp::MulInverse | FieldShareUnaryOp::IsZero => {
62                        circuit_preprocessing[field_type].triples += batch_size;
63                        circuit_preprocessing[field_type].singlets += batch_size;
64                    }
65                    FieldShareUnaryOp::Open | FieldShareUnaryOp::Neg => (),
66                }
67            }
68            Gate::FieldShareBinaryOp { op, y, .. } => match op {
69                FieldShareBinaryOp::Mul => {
70                    let field_type = gate.output.get_field_type_unchecked();
71                    if self.gate_output_unchecked(*y).get_form() == ShareOrPlaintext::Share {
72                        circuit_preprocessing[field_type].triples += batch_size;
73                    }
74                }
75                FieldShareBinaryOp::Add => (),
76            },
77            Gate::PointShareUnaryOp { op, .. } => match op {
78                PointShareUnaryOp::IsZero => {
79                    circuit_preprocessing.scalar.triples += batch_size;
80                    circuit_preprocessing.scalar.singlets += batch_size;
81                }
82                PointShareUnaryOp::Open | PointShareUnaryOp::Neg => (),
83            },
84            Gate::PointShareBinaryOp { op, p, y, .. } => match op {
85                PointShareBinaryOp::ScalarMul => {
86                    if self.gate_output_unchecked(*p).get_form() == ShareOrPlaintext::Share
87                        && self.gate_output_unchecked(*y).get_form() == ShareOrPlaintext::Share
88                    {
89                        circuit_preprocessing.scalar.triples += batch_size;
90                    }
91                }
92                PointShareBinaryOp::Add => (),
93            },
94            Gate::BitShareBinaryOp { op, y, .. } => match op {
95                BitShareBinaryOp::And | BitShareBinaryOp::Or => {
96                    if self.gate_output_unchecked(*y).get_form() == ShareOrPlaintext::Share {
97                        circuit_preprocessing.bit_triples += batch_size;
98                    }
99                }
100                BitShareBinaryOp::Xor => (),
101            },
102            Gate::BaseFieldPow { exp, .. } => {
103                *circuit_preprocessing
104                    .base_field_pow_pairs
105                    .entry(exp.clone())
106                    .or_insert(0) += batch_size;
107                circuit_preprocessing.base_field.triples += batch_size;
108            }
109            Gate::DaBit { field_type, .. } => {
110                circuit_preprocessing[*field_type].dabits += batch_size
111            }
112
113            Gate::Input(_)
114            | Gate::Constant { .. }
115            | Gate::BatchSummation { .. }
116            | Gate::BitShareUnaryOp { .. }
117            | Gate::FieldPlaintextUnaryOp { .. }
118            | Gate::FieldPlaintextBinaryOp { .. }
119            | Gate::BitPlaintextUnaryOp { .. }
120            | Gate::BitPlaintextBinaryOp { .. }
121            | Gate::PointPlaintextUnaryOp { .. }
122            | Gate::PointPlaintextBinaryOp { .. }
123            | Gate::GetDaBitFieldShare { .. }
124            | Gate::GetDaBitSharedBit { .. }
125            | Gate::BitPlaintextToField { .. }
126            | Gate::FieldPlaintextToBit { .. }
127            | Gate::ExtractFromBatch { .. }
128            | Gate::CollectToBatch { .. }
129            | Gate::PointFromPlaintextExtendedEdwards { .. }
130            | Gate::PlaintextPointToExtendedEdwards { .. }
131            | Gate::PlaintextKeccakF1600 { .. }
132            | Gate::CompressPlaintextPoint { .. }
133            | Gate::KeyRecoveryPlaintextComputeErrors { .. } => (),
134        };
135    }
136}
137
/// Per-field tally of the preprocessing a circuit requires.
///
/// Tallies combine counter by counter through `+` and `+=`; the default value is all zeros.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct FieldCircuitPreprocessing {
    /// Number of singlets required in this field.
    pub singlets: usize,
    /// Number of triples required in this field.
    pub triples: usize,
    /// Number of daBits required in this field.
    pub dabits: usize,
}

impl AddAssign for FieldCircuitPreprocessing {
    /// Adds each counter of `rhs` onto the matching counter of `self`.
    fn add_assign(&mut self, rhs: Self) {
        // Destructuring makes the compiler flag this impl whenever a counter is added.
        let Self {
            singlets,
            triples,
            dabits,
        } = rhs;
        self.singlets += singlets;
        self.triples += triples;
        self.dabits += dabits;
    }
}

impl Add for FieldCircuitPreprocessing {
    type Output = Self;

    /// Returns the counter-wise sum of the two tallies.
    fn add(mut self, rhs: Self) -> Self::Output {
        self += rhs;
        self
    }
}
162
163/// Preprocessing requirements for a circuit.
164#[derive(Debug, Clone, Default, PartialEq, Eq)]
165pub struct CircuitPreprocessing {
166    pub base_field_pow_pairs: HashMap<BoxedUint, usize>,
167    pub bit_singlets: usize,
168    pub bit_triples: usize,
169    pub base_field: FieldCircuitPreprocessing,
170    pub scalar: FieldCircuitPreprocessing,
171    pub mersenne107: FieldCircuitPreprocessing,
172}
173
174impl AddAssign for CircuitPreprocessing {
175    fn add_assign(&mut self, rhs: Self) {
176        self.bit_singlets += rhs.bit_singlets;
177        self.bit_triples += rhs.bit_triples;
178        self.base_field += rhs.base_field;
179        self.scalar += rhs.scalar;
180        self.mersenne107 += rhs.mersenne107;
181        for (k, v) in rhs.base_field_pow_pairs {
182            *self.base_field_pow_pairs.entry(k).or_insert(0) += v;
183        }
184    }
185}
186
187impl Add for CircuitPreprocessing {
188    type Output = Self;
189
190    fn add(self, other: Self) -> Self::Output {
191        let mut res = self;
192        res += other;
193        res
194    }
195}
196
197impl Index<FieldType> for CircuitPreprocessing {
198    type Output = FieldCircuitPreprocessing;
199
200    fn index(&self, index: FieldType) -> &Self::Output {
201        match index {
202            FieldType::BaseField => &self.base_field,
203            FieldType::ScalarField => &self.scalar,
204            FieldType::Mersenne107 => &self.mersenne107,
205        }
206    }
207}
208
209impl IndexMut<FieldType> for CircuitPreprocessing {
210    fn index_mut(&mut self, index: FieldType) -> &mut Self::Output {
211        match index {
212            FieldType::BaseField => &mut self.base_field,
213            FieldType::ScalarField => &mut self.scalar,
214            FieldType::Mersenne107 => &mut self.mersenne107,
215        }
216    }
217}