core_utils/circuit/v2/
preprocessing.rs1use std::{
2 collections::HashMap,
3 ops::{Add, AddAssign, Index, IndexMut},
4};
5
6use primitives::algebra::{elliptic_curve::Curve, BoxedUint};
7
8use crate::circuit::{
9 AlgebraicType,
10 BitShareBinaryOp,
11 Circuit,
12 FieldShareBinaryOp,
13 FieldShareUnaryOp,
14 FieldType,
15 Gate,
16 GateExt,
17 Input,
18 PointShareBinaryOp,
19 PointShareUnaryOp,
20 ShareOrPlaintext,
21};
22
23impl<C: Curve> Circuit<C> {
24 pub fn required_preprocessing(&self) -> CircuitPreprocessing {
28 let mut circuit_preprocessing = CircuitPreprocessing::default();
29 for gate in self.iter_gates_ext() {
30 self.add_to_required_preprocessing(gate, &mut circuit_preprocessing);
31 }
32 circuit_preprocessing
33 }
34
35 pub fn add_to_required_preprocessing(
37 &self,
38 gate: &GateExt<C>,
39 circuit_preprocessing: &mut CircuitPreprocessing,
40 ) {
41 let batch_size = gate.output.get_batch_size() as usize;
42 match &gate.gate {
43 Gate::Input(Input::SecretPlaintext { algebraic_type, .. })
44 | Gate::Random { algebraic_type, .. } => match algebraic_type {
45 AlgebraicType::ScalarField | AlgebraicType::Point => {
46 circuit_preprocessing.scalar.singlets += batch_size;
47 }
48 AlgebraicType::BaseField => {
49 circuit_preprocessing.base_field.singlets += batch_size;
50 }
51 AlgebraicType::Bit => {
52 circuit_preprocessing.bit_singlets += batch_size;
53 }
54 AlgebraicType::Mersenne107 => {
55 circuit_preprocessing.mersenne107.singlets += batch_size;
56 }
57 },
58 Gate::FieldShareUnaryOp { op, .. } => {
59 let field_type = gate.output.get_field_type_unchecked();
60 match op {
61 FieldShareUnaryOp::MulInverse | FieldShareUnaryOp::IsZero => {
62 circuit_preprocessing[field_type].triples += batch_size;
63 circuit_preprocessing[field_type].singlets += batch_size;
64 }
65 FieldShareUnaryOp::Open | FieldShareUnaryOp::Neg => (),
66 }
67 }
68 Gate::FieldShareBinaryOp { op, y, .. } => match op {
69 FieldShareBinaryOp::Mul => {
70 let field_type = gate.output.get_field_type_unchecked();
71 if self.gate_output_unchecked(*y).get_form() == ShareOrPlaintext::Share {
72 circuit_preprocessing[field_type].triples += batch_size;
73 }
74 }
75 FieldShareBinaryOp::Add => (),
76 },
77 Gate::PointShareUnaryOp { op, .. } => match op {
78 PointShareUnaryOp::IsZero => {
79 circuit_preprocessing.scalar.triples += batch_size;
80 circuit_preprocessing.scalar.singlets += batch_size;
81 }
82 PointShareUnaryOp::Open | PointShareUnaryOp::Neg => (),
83 },
84 Gate::PointShareBinaryOp { op, p, y, .. } => match op {
85 PointShareBinaryOp::ScalarMul => {
86 if self.gate_output_unchecked(*p).get_form() == ShareOrPlaintext::Share
87 && self.gate_output_unchecked(*y).get_form() == ShareOrPlaintext::Share
88 {
89 circuit_preprocessing.scalar.triples += batch_size;
90 }
91 }
92 PointShareBinaryOp::Add => (),
93 },
94 Gate::BitShareBinaryOp { op, y, .. } => match op {
95 BitShareBinaryOp::And | BitShareBinaryOp::Or => {
96 if self.gate_output_unchecked(*y).get_form() == ShareOrPlaintext::Share {
97 circuit_preprocessing.bit_triples += batch_size;
98 }
99 }
100 BitShareBinaryOp::Xor => (),
101 },
102 Gate::BaseFieldPow { exp, .. } => {
103 *circuit_preprocessing
104 .base_field_pow_pairs
105 .entry(exp.clone())
106 .or_insert(0) += batch_size;
107 circuit_preprocessing.base_field.triples += batch_size;
108 }
109 Gate::DaBit { field_type, .. } => {
110 circuit_preprocessing[*field_type].dabits += batch_size
111 }
112
113 Gate::Input(_)
114 | Gate::Constant { .. }
115 | Gate::BatchSummation { .. }
116 | Gate::BitShareUnaryOp { .. }
117 | Gate::FieldPlaintextUnaryOp { .. }
118 | Gate::FieldPlaintextBinaryOp { .. }
119 | Gate::BitPlaintextUnaryOp { .. }
120 | Gate::BitPlaintextBinaryOp { .. }
121 | Gate::PointPlaintextUnaryOp { .. }
122 | Gate::PointPlaintextBinaryOp { .. }
123 | Gate::GetDaBitFieldShare { .. }
124 | Gate::GetDaBitSharedBit { .. }
125 | Gate::BitPlaintextToField { .. }
126 | Gate::FieldPlaintextToBit { .. }
127 | Gate::ExtractFromBatch { .. }
128 | Gate::CollectToBatch { .. }
129 | Gate::PointFromPlaintextExtendedEdwards { .. }
130 | Gate::PlaintextPointToExtendedEdwards { .. }
131 | Gate::PlaintextKeccakF1600 { .. }
132 | Gate::CompressPlaintextPoint { .. }
133 | Gate::KeyRecoveryPlaintextComputeErrors { .. } => (),
134 };
135 }
136}
137
/// Counts of preprocessing material required within a single field.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct FieldCircuitPreprocessing {
    /// Number of singlets required.
    pub singlets: usize,
    /// Number of triples required.
    pub triples: usize,
    /// Number of daBits required.
    pub dabits: usize,
}

impl AddAssign for FieldCircuitPreprocessing {
    /// Adds each counter of `rhs` onto the matching counter of `self`.
    fn add_assign(&mut self, rhs: Self) {
        let Self {
            singlets,
            triples,
            dabits,
        } = rhs;
        self.singlets += singlets;
        self.triples += triples;
        self.dabits += dabits;
    }
}

impl Add for FieldCircuitPreprocessing {
    type Output = Self;

    /// Returns the counter-wise sum of `self` and `rhs`.
    fn add(mut self, rhs: Self) -> Self {
        self += rhs;
        self
    }
}
162
163#[derive(Debug, Clone, Default, PartialEq, Eq)]
165pub struct CircuitPreprocessing {
166 pub base_field_pow_pairs: HashMap<BoxedUint, usize>,
167 pub bit_singlets: usize,
168 pub bit_triples: usize,
169 pub base_field: FieldCircuitPreprocessing,
170 pub scalar: FieldCircuitPreprocessing,
171 pub mersenne107: FieldCircuitPreprocessing,
172}
173
174impl AddAssign for CircuitPreprocessing {
175 fn add_assign(&mut self, rhs: Self) {
176 self.bit_singlets += rhs.bit_singlets;
177 self.bit_triples += rhs.bit_triples;
178 self.base_field += rhs.base_field;
179 self.scalar += rhs.scalar;
180 self.mersenne107 += rhs.mersenne107;
181 for (k, v) in rhs.base_field_pow_pairs {
182 *self.base_field_pow_pairs.entry(k).or_insert(0) += v;
183 }
184 }
185}
186
187impl Add for CircuitPreprocessing {
188 type Output = Self;
189
190 fn add(self, other: Self) -> Self::Output {
191 let mut res = self;
192 res += other;
193 res
194 }
195}
196
197impl Index<FieldType> for CircuitPreprocessing {
198 type Output = FieldCircuitPreprocessing;
199
200 fn index(&self, index: FieldType) -> &Self::Output {
201 match index {
202 FieldType::BaseField => &self.base_field,
203 FieldType::ScalarField => &self.scalar,
204 FieldType::Mersenne107 => &self.mersenne107,
205 }
206 }
207}
208
209impl IndexMut<FieldType> for CircuitPreprocessing {
210 fn index_mut(&mut self, index: FieldType) -> &mut Self::Output {
211 match index {
212 FieldType::BaseField => &mut self.base_field,
213 FieldType::ScalarField => &mut self.scalar,
214 FieldType::Mersenne107 => &mut self.mersenne107,
215 }
216 }
217}