1use std::hash::Hash;
2
3use primitives::algebra::{
4 elliptic_curve::{Curve, Point, Scalar},
5 BoxedUint,
6};
7use serde::{Deserialize, Serialize};
8use wincode::{SchemaRead, SchemaWrite};
9
10use crate::circuit::{
11 errors::CircuitError,
12 AlgebraicType,
13 BatchSize,
14 BitPlaintextBinaryOp,
15 BitPlaintextUnaryOp,
16 BitShareBinaryOp,
17 BitShareUnaryOp,
18 Constant,
19 FieldPlaintextBinaryOp,
20 FieldPlaintextUnaryOp,
21 FieldShareBinaryOp,
22 FieldShareUnaryOp,
23 FieldType,
24 GateIndex,
25 Input,
26 PointPlaintextBinaryOp,
27 PointPlaintextUnaryOp,
28 PointShareBinaryOp,
29 PointShareUnaryOp,
30 Slice,
31};
32
/// A single gate of a circuit, generic over the curve `C`.
///
/// Gates that consume the output of other gates refer to them by
/// [`GateIndex`]. [`Gate::get_inputs`] reports those references for every
/// variant, in operand order.
///
/// NOTE(review): the type derives serde and wincode schema traits and is
/// `#[repr(C)]`, so variant order and field names are presumably part of the
/// encoded format — confirm before reordering or renaming anything here.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, SchemaRead, SchemaWrite)]
#[serde(bound(
    serialize = "Scalar<C>: Serialize, Point<C>: Serialize",
    deserialize = "Scalar<C>: Deserialize<'de>, Point<C>: Deserialize<'de>"
))]
#[repr(C)]
pub enum Gate<C: Curve> {
    /// Gate described by an [`Input`]; it has no input wires.
    Input(Input),
    /// Gate described by a [`Constant`]; it has no input wires.
    Constant(Constant<C>),
    /// Gate without input wires, parameterised by an algebraic type and a
    /// batch size. Presumably yields fresh random values — confirm against
    /// the evaluator.
    Random {
        /// Algebraic type associated with this gate.
        algebraic_type: AlgebraicType,
        /// Batch size associated with this gate.
        batch_size: BatchSize,
    },
    /// Unary field-share operation `op` with operand gate `x`.
    FieldShareUnaryOp {
        /// Index of the operand gate.
        x: GateIndex,
        /// Operation selector.
        op: FieldShareUnaryOp,
    },
    /// Binary field-share operation `op` with operand gates `x` and `y`.
    FieldShareBinaryOp {
        /// Index of the first operand gate.
        x: GateIndex,
        /// Index of the second operand gate.
        y: GateIndex,
        /// Operation selector.
        op: FieldShareBinaryOp,
    },
    /// Single-input gate over `x`. Going by the name it sums the entries of a
    /// batch — TODO confirm.
    BatchSummation {
        /// Index of the operand gate.
        x: GateIndex,
    },
    /// Unary bit-share operation `op` with operand gate `x`.
    BitShareUnaryOp {
        /// Index of the operand gate.
        x: GateIndex,
        /// Operation selector.
        op: BitShareUnaryOp,
    },
    /// Binary bit-share operation `op` with operand gates `x` and `y`.
    BitShareBinaryOp {
        /// Index of the first operand gate.
        x: GateIndex,
        /// Index of the second operand gate.
        y: GateIndex,
        /// Operation selector.
        op: BitShareBinaryOp,
    },
    /// Unary point-share operation `op` with operand gate `p`.
    PointShareUnaryOp {
        /// Index of the operand gate.
        p: GateIndex,
        /// Operation selector.
        op: PointShareUnaryOp,
    },
    /// Binary point-share operation `op` with operand gates `p` and `y`.
    PointShareBinaryOp {
        /// Index of the first operand gate.
        p: GateIndex,
        /// Index of the second operand gate.
        y: GateIndex,
        /// Operation selector.
        op: PointShareBinaryOp,
    },
    /// Unary field-plaintext operation `op` with operand gate `x`.
    FieldPlaintextUnaryOp {
        /// Index of the operand gate.
        x: GateIndex,
        /// Operation selector.
        op: FieldPlaintextUnaryOp,
    },
    /// Binary field-plaintext operation `op` with operand gates `x` and `y`.
    FieldPlaintextBinaryOp {
        /// Index of the first operand gate.
        x: GateIndex,
        /// Index of the second operand gate.
        y: GateIndex,
        /// Operation selector.
        op: FieldPlaintextBinaryOp,
    },
    /// Unary bit-plaintext operation `op` with operand gate `x`.
    BitPlaintextUnaryOp {
        /// Index of the operand gate.
        x: GateIndex,
        /// Operation selector.
        op: BitPlaintextUnaryOp,
    },
    /// Binary bit-plaintext operation `op` with operand gates `x` and `y`.
    BitPlaintextBinaryOp {
        /// Index of the first operand gate.
        x: GateIndex,
        /// Index of the second operand gate.
        y: GateIndex,
        /// Operation selector.
        op: BitPlaintextBinaryOp,
    },
    /// Unary point-plaintext operation `op` with operand gate `p`.
    PointPlaintextUnaryOp {
        /// Index of the operand gate.
        p: GateIndex,
        /// Operation selector.
        op: PointPlaintextUnaryOp,
    },
    /// Binary point-plaintext operation `op` with operand gates `p` and `y`.
    PointPlaintextBinaryOp {
        /// Index of the first operand gate.
        p: GateIndex,
        /// Index of the second operand gate.
        y: GateIndex,
        /// Operation selector.
        op: PointPlaintextBinaryOp,
    },
    /// Gate without input wires, parameterised by a field type and a batch
    /// size. NOTE(review): "daBit" semantics are not visible in this file —
    /// confirm against the protocol documentation.
    DaBit {
        /// Field type associated with this gate.
        field_type: FieldType,
        /// Batch size associated with this gate.
        batch_size: BatchSize,
    },
    /// Single-input gate over `x`; presumably `x` refers to a [`Gate::DaBit`]
    /// whose field share is selected — TODO confirm.
    GetDaBitFieldShare {
        /// Index of the operand gate.
        x: GateIndex,
    },
    /// Single-input gate over `x`; presumably `x` refers to a [`Gate::DaBit`]
    /// whose shared bit is selected — TODO confirm.
    GetDaBitSharedBit {
        /// Index of the operand gate.
        x: GateIndex,
    },
    /// Single-input gate over `x` carrying an exponent `exp`.
    BaseFieldPow {
        /// Index of the operand gate.
        x: GateIndex,
        /// Exponent, stored as a [`BoxedUint`].
        exp: BoxedUint,
    },
    /// Single-input gate over `x` carrying a target field type.
    BitPlaintextToField {
        /// Index of the operand gate.
        x: GateIndex,
        /// Field type associated with this gate.
        field_type: FieldType,
    },
    /// Single-input gate over `x`.
    FieldPlaintextToBit {
        /// Index of the operand gate.
        x: GateIndex,
    },
    /// Single-input gate over `x` carrying a [`Slice`]; presumably selects
    /// that slice of the batch produced by `x` — TODO confirm.
    ExtractFromBatch {
        /// Index of the operand gate.
        x: GateIndex,
        /// Slice associated with this gate.
        slice: Slice,
    },
    /// Gate with a variable number of inputs, listed in `wires`.
    CollectToBatch {
        /// Indices of the operand gates, in order.
        wires: Vec<GateIndex>,
    },
    /// Gate with a variable number of inputs, listed in `wires`.
    PointFromPlaintextExtendedEdwards {
        /// Indices of the operand gates, in order.
        wires: Vec<GateIndex>,
    },
    /// Single-input gate over `point`.
    PlaintextPointToExtendedEdwards {
        /// Index of the operand gate.
        point: GateIndex,
    },
    /// Single-input gate over `x`.
    PlaintextKeccakF1600 {
        /// Index of the operand gate.
        x: GateIndex,
    },
    /// Single-input gate over `point`.
    CompressPlaintextPoint {
        /// Index of the operand gate.
        point: GateIndex,
    },
    /// Two-input gate; `d_minus_one` is reported before `syndromes` by
    /// [`Gate::get_inputs`].
    KeyRecoveryPlaintextComputeErrors {
        /// Index of the first operand gate.
        d_minus_one: GateIndex,
        /// Index of the second operand gate.
        syndromes: GateIndex,
    },
}
161
162impl<C: Curve> Gate<C> {
163 pub fn is_input(&self) -> bool {
165 matches!(self, Gate::Input { .. })
166 }
167
168 pub fn get_inputs(&self) -> Vec<GateIndex> {
170 match &self {
171 Gate::Input(_) | Gate::Random { .. } | Gate::Constant(_) | Gate::DaBit { .. } => {
172 Vec::new()
173 }
174
175 Gate::FieldShareUnaryOp { x, .. }
176 | Gate::BatchSummation { x, .. }
177 | Gate::BitShareUnaryOp { x, .. }
178 | Gate::PointShareUnaryOp { p: x, .. }
179 | Gate::FieldPlaintextUnaryOp { x, .. }
180 | Gate::BitPlaintextUnaryOp { x, .. }
181 | Gate::PointPlaintextUnaryOp { p: x, .. }
182 | Gate::GetDaBitFieldShare { x, .. }
183 | Gate::GetDaBitSharedBit { x, .. }
184 | Gate::BaseFieldPow { x, .. }
185 | Gate::BitPlaintextToField { x, .. }
186 | Gate::FieldPlaintextToBit { x, .. }
187 | Gate::ExtractFromBatch { x, .. }
188 | Gate::PlaintextPointToExtendedEdwards { point: x, .. }
189 | Gate::CompressPlaintextPoint { point: x, .. }
190 | Gate::PlaintextKeccakF1600 { x } => {
191 vec![*x]
192 }
193
194 Gate::FieldShareBinaryOp { x, y, .. }
195 | Gate::BitShareBinaryOp { x, y, .. }
196 | Gate::PointShareBinaryOp { p: x, y, .. }
197 | Gate::FieldPlaintextBinaryOp { x, y, .. }
198 | Gate::BitPlaintextBinaryOp { x, y, .. }
199 | Gate::PointPlaintextBinaryOp { p: x, y, .. }
200 | Gate::KeyRecoveryPlaintextComputeErrors {
201 d_minus_one: x,
202 syndromes: y,
203 ..
204 } => {
205 vec![*x, *y]
206 }
207
208 Gate::CollectToBatch { wires, .. }
209 | Gate::PointFromPlaintextExtendedEdwards { wires, .. } => wires.clone(),
210 }
211 }
212
213 pub fn map_inputs<F: FnMut(GateIndex) -> GateIndex>(mut self, mut f: F) -> Self {
215 match &mut self {
216 Gate::Input(_) | Gate::Random { .. } | Gate::Constant(_) | Gate::DaBit { .. } => (),
217
218 Gate::FieldShareUnaryOp { x, .. }
219 | Gate::BatchSummation { x, .. }
220 | Gate::BitShareUnaryOp { x, .. }
221 | Gate::PointShareUnaryOp { p: x, .. }
222 | Gate::FieldPlaintextUnaryOp { x, .. }
223 | Gate::BitPlaintextUnaryOp { x, .. }
224 | Gate::PointPlaintextUnaryOp { p: x, .. }
225 | Gate::GetDaBitFieldShare { x, .. }
226 | Gate::GetDaBitSharedBit { x, .. }
227 | Gate::BaseFieldPow { x, .. }
228 | Gate::BitPlaintextToField { x, .. }
229 | Gate::FieldPlaintextToBit { x, .. }
230 | Gate::ExtractFromBatch { x, .. }
231 | Gate::PlaintextPointToExtendedEdwards { point: x, .. }
232 | Gate::CompressPlaintextPoint { point: x, .. }
233 | Gate::PlaintextKeccakF1600 { x } => {
234 *x = f(*x);
235 }
236
237 Gate::FieldShareBinaryOp { x, y, .. }
238 | Gate::BitShareBinaryOp { x, y, .. }
239 | Gate::PointShareBinaryOp { p: x, y, .. }
240 | Gate::FieldPlaintextBinaryOp { x, y, .. }
241 | Gate::BitPlaintextBinaryOp { x, y, .. }
242 | Gate::PointPlaintextBinaryOp { p: x, y, .. }
243 | Gate::KeyRecoveryPlaintextComputeErrors {
244 d_minus_one: x,
245 syndromes: y,
246 ..
247 } => {
248 *x = f(*x);
249 *y = f(*y);
250 }
251
252 Gate::CollectToBatch { wires, .. }
253 | Gate::PointFromPlaintextExtendedEdwards { wires, .. } => {
254 wires.iter_mut().for_each(|x| *x = f(*x))
255 }
256 };
257
258 self
259 }
260
261 pub fn try_replace_inputs(mut self, inputs: Vec<GateIndex>) -> Result<Self, CircuitError<C>> {
266 if inputs.len() != self.get_inputs().len() {
267 return Err(CircuitError::InvalidGateInputCount {
268 expected: self.get_inputs().len(),
269 found: inputs.len(),
270 });
271 }
272
273 match &mut self {
274 Gate::Input(_) | Gate::Random { .. } | Gate::Constant(_) | Gate::DaBit { .. } => (),
275
276 Gate::FieldShareUnaryOp { x, .. }
277 | Gate::BatchSummation { x, .. }
278 | Gate::BitShareUnaryOp { x, .. }
279 | Gate::PointShareUnaryOp { p: x, .. }
280 | Gate::FieldPlaintextUnaryOp { x, .. }
281 | Gate::BitPlaintextUnaryOp { x, .. }
282 | Gate::PointPlaintextUnaryOp { p: x, .. }
283 | Gate::GetDaBitFieldShare { x, .. }
284 | Gate::GetDaBitSharedBit { x, .. }
285 | Gate::BaseFieldPow { x, .. }
286 | Gate::BitPlaintextToField { x, .. }
287 | Gate::FieldPlaintextToBit { x, .. }
288 | Gate::ExtractFromBatch { x, .. }
289 | Gate::PlaintextPointToExtendedEdwards { point: x, .. }
290 | Gate::CompressPlaintextPoint { point: x, .. }
291 | Gate::PlaintextKeccakF1600 { x } => {
292 *x = inputs[0];
293 }
294
295 Gate::FieldShareBinaryOp { x, y, .. }
296 | Gate::BitShareBinaryOp { x, y, .. }
297 | Gate::PointShareBinaryOp { p: x, y, .. }
298 | Gate::FieldPlaintextBinaryOp { x, y, .. }
299 | Gate::BitPlaintextBinaryOp { x, y, .. }
300 | Gate::PointPlaintextBinaryOp { p: x, y, .. }
301 | Gate::KeyRecoveryPlaintextComputeErrors {
302 d_minus_one: x,
303 syndromes: y,
304 ..
305 } => {
306 *x = inputs[0];
307 *y = inputs[1];
308 }
309
310 Gate::CollectToBatch { wires, .. }
311 | Gate::PointFromPlaintextExtendedEdwards { wires, .. } => *wires = inputs,
312 };
313
314 Ok(self)
315 }
316}
317
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use primitives::algebra::elliptic_curve::Curve25519Ristretto as C;

    use super::*;
    use crate::circuit::{
        preprocessing::{CircuitPreprocessing, FieldCircuitPreprocessing},
        FieldShareBinaryOp,
    };

    /// A gate must survive a bincode round trip unchanged, and a gate and its
    /// round-tripped copy must collapse to one entry in a `HashSet`.
    #[test]
    fn test_ser_gate() {
        let field_add: Gate<C> = Gate::FieldShareBinaryOp {
            x: 1,
            y: 3,
            op: FieldShareBinaryOp::Add,
        };
        let point_add: Gate<C> = Gate::PointShareBinaryOp {
            p: 1,
            y: 3,
            op: PointShareBinaryOp::Add,
        };

        let field_add_bytes = bincode::serialize(&field_add).unwrap();
        let point_add_bytes = bincode::serialize(&point_add).unwrap();

        let field_add_back: Gate<C> = bincode::deserialize(&field_add_bytes).unwrap();
        let point_add_back: Gate<C> = bincode::deserialize(&point_add_bytes).unwrap();

        assert_eq!(field_add, field_add_back);
        assert_eq!(point_add, point_add_back);

        // Four gates go in, but only two distinct values should remain.
        let distinct = HashSet::from([field_add, field_add_back, point_add, point_add_back]);
        assert_eq!(distinct.len(), 2);
    }

    /// Adding two `CircuitPreprocessing` values sums every counter and merges
    /// the pow-pair maps key by key.
    #[test]
    fn test_circuit_preprocessing_add() {
        // Shorthand for the three per-field counters.
        let field = |singlets, triples, dabits| FieldCircuitPreprocessing {
            singlets,
            triples,
            dabits,
        };

        let lhs = CircuitPreprocessing {
            bit_singlets: 0,
            bit_triples: 1,
            base_field: field(3, 4, 2),
            scalar: field(1, 2, 1),
            base_field_pow_pairs: vec![
                (BoxedUint::from(vec![21]), 5),
                (BoxedUint::from(vec![14]), 6),
            ]
            .into_iter()
            .collect(),
            mersenne107: field(0, 0, 0),
        };
        let rhs = CircuitPreprocessing {
            bit_singlets: 3,
            bit_triples: 4,
            base_field: field(0, 5, 3),
            scalar: field(2, 3, 2),
            base_field_pow_pairs: vec![
                (BoxedUint::from(vec![21]), 6),
                (BoxedUint::from(vec![13]), 7),
            ]
            .into_iter()
            .collect(),
            mersenne107: field(3, 2, 0),
        };

        let total = lhs + rhs;

        // Bit counters.
        assert_eq!(total.bit_singlets, 3);
        assert_eq!(total.bit_triples, 5);

        // Scalar field counters.
        assert_eq!(total.scalar.singlets, 3);
        assert_eq!(total.scalar.triples, 5);
        assert_eq!(total.scalar.dabits, 3);

        // Base field counters.
        assert_eq!(total.base_field.singlets, 3);
        assert_eq!(total.base_field.triples, 9);
        assert_eq!(total.base_field.dabits, 5);

        // Mersenne107 counters.
        assert_eq!(total.mersenne107.singlets, 3);
        assert_eq!(total.mersenne107.triples, 2);
        assert_eq!(total.mersenne107.dabits, 0);

        // Exponent 21 appears on both sides (5 + 6); 14 and 13 on one side each.
        for (exponent, expected) in [(21, 11), (14, 6), (13, 7)] {
            assert_eq!(
                total.base_field_pow_pairs.get(&BoxedUint::from(vec![exponent])),
                Some(&expected)
            );
        }
    }
}