1use std::{collections::HashMap, ops::Add};
2
3use macros::GateMethods;
4use primitives::algebra::{
5 elliptic_curve::{Curve, Point, Scalar},
6 BoxedUint,
7};
8use serde::{Deserialize, Serialize};
9
10use crate::{
11 circuit::{
12 AlgebraicType,
13 Batched,
14 BitShareBinaryOp,
15 BitShareUnaryOp,
16 FieldPlaintextBinaryOp,
17 FieldPlaintextUnaryOp,
18 FieldShareBinaryOp,
19 FieldShareUnaryOp,
20 FieldType,
21 Input,
22 PointPlaintextBinaryOp,
23 PointPlaintextUnaryOp,
24 PointShareBinaryOp,
25 PointShareUnaryOp,
26 ShareOrPlaintext,
27 },
28 types::Label,
29};
30
/// One gate of a circuit over the elliptic curve `C`.
///
/// Operands are referenced by [`Label`]; [`Gate::get_labels`] collects every label a gate
/// references. Conventions shared across variants:
/// - `*_form` fields ([`ShareOrPlaintext`]) record whether the corresponding operand is a
///   share or a plaintext;
/// - `field_type` ([`FieldType`]) selects the field an operation works in;
/// - `x_type` / `algebraic_type` ([`AlgebraicType`]) give the algebraic type of the operand
///   (used by the batch-related variants).
///
/// NOTE(review): with positional serde formats such as bincode (used in this module's
/// tests) a variant is encoded by its index, so reordering or inserting variants presumably
/// breaks previously serialized gates — confirm before changing the variant order.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, GateMethods)]
// Explicit serde bounds replace the derive's default `C: Serialize` / `C: Deserialize<'de>`
// bounds: only `Scalar<C>` and `Point<C>` are required to be (de)serializable.
#[serde(bound(
    serialize = "Scalar<C>: Serialize, Point<C>: Serialize",
    deserialize = "Scalar<C>: Deserialize<'de>, Point<C>: Deserialize<'de>"
))]
pub enum Gate<C: Curve> {
    /// A circuit input, described by `input_type`.
    Input {
        input_type: Input<C>,
    },
    /// Unary field-share operation `op` applied to `x` in `field_type`.
    FieldShareUnaryOp {
        x: Label,
        op: FieldShareUnaryOp,
        field_type: FieldType,
    },
    /// Binary field-share operation `op` on `x` and `y` in `field_type`; `y_form` records
    /// whether `y` is a share or a plaintext.
    FieldShareBinaryOp {
        x: Label,
        y: Label,
        y_form: ShareOrPlaintext,
        op: FieldShareBinaryOp,
        field_type: FieldType,
    },
    /// Presumably sums the elements of the batch `x` (form `x_form`, type
    /// `algebraic_type`) — confirm against the evaluator.
    BatchSummation {
        x: Label,
        x_form: ShareOrPlaintext,
        algebraic_type: AlgebraicType,
    },
    /// Unary bit-share operation `op` applied to `x`.
    BitShareUnaryOp {
        x: Label,
        op: BitShareUnaryOp,
    },
    /// Binary bit-share operation `op` on `x` and `y`; `y_form` records the form of `y`.
    BitShareBinaryOp {
        x: Label,
        y: Label,
        y_form: ShareOrPlaintext,
        op: BitShareBinaryOp,
    },
    /// Unary point-share operation `op` applied to the point `p`.
    PointShareUnaryOp {
        p: Label,
        op: PointShareUnaryOp,
    },
    /// Binary point-share operation `op` on `p` and `y`. Unlike the field/bit share
    /// variants, both operands carry a form (`p_form` and `y_form`).
    PointShareBinaryOp {
        p: Label,
        y: Label,
        p_form: ShareOrPlaintext,
        y_form: ShareOrPlaintext,
        op: PointShareBinaryOp,
    },
    /// Unary plaintext operation `op` applied to `x` in `field_type`.
    FieldPlaintextUnaryOp {
        x: Label,
        op: FieldPlaintextUnaryOp,
        field_type: FieldType,
    },
    /// Binary plaintext operation `op` on `x` and `y` in `field_type`.
    FieldPlaintextBinaryOp {
        x: Label,
        y: Label,
        op: FieldPlaintextBinaryOp,
        field_type: FieldType,
    },
    /// Unary operation `op` on the plaintext bit `x`.
    ///
    /// NOTE(review): `op` reuses the field op type `FieldPlaintextUnaryOp` (no bit-specific
    /// plaintext op type is imported here) — confirm which ops are valid on bits.
    BitPlaintextUnaryOp {
        x: Label,
        op: FieldPlaintextUnaryOp,
    },
    /// Binary operation `op` on the plaintext bits `x` and `y`.
    ///
    /// NOTE(review): `op` reuses the field op type `FieldPlaintextBinaryOp` — confirm which
    /// ops are valid on bits.
    BitPlaintextBinaryOp {
        x: Label,
        y: Label,
        op: FieldPlaintextBinaryOp,
    },
    /// Unary plaintext point operation `op` applied to `p`.
    PointPlaintextUnaryOp {
        p: Label,
        op: PointPlaintextUnaryOp,
    },
    /// Binary plaintext point operation `op` on `p` and `y`.
    PointPlaintextBinaryOp {
        p: Label,
        y: Label,
        op: PointPlaintextBinaryOp,
    },
    /// Produces daBit material for `field_type`; this variant references no operand labels.
    /// `batched` presumably controls how many are produced — confirm.
    DaBit {
        field_type: FieldType,
        batched: Batched,
    },
    /// Presumably extracts the `field_type` share component of the daBit `x` — confirm.
    GetDaBitFieldShare {
        x: Label,
        field_type: FieldType,
    },
    /// Presumably extracts the shared-bit component of the daBit `x` (produced for
    /// `field_type`) — confirm.
    GetDaBitSharedBit {
        x: Label,
        field_type: FieldType,
    },
    /// Point-encryption gate on `x`, storing the curve point `c` inline (the only variant
    /// that embeds a `Point<C>` directly).
    ///
    /// NOTE(review): the roles of `x` and `c` (e.g. message vs. key) are not visible in this
    /// file — confirm.
    EncryptPoint {
        x: Label,
        c: Point<C>,
    },
    /// Point-decryption gate on `x` and `y`.
    ///
    /// NOTE(review): operand roles are not visible in this file — confirm.
    DecryptPoint {
        x: Label,
        y: Label,
    },
    /// Raises `x` to the exponent `exp` in the base field. The exponent is stored inline
    /// as a constant rather than referenced by a label.
    BaseFieldPow {
        x: Label,
        exp: BoxedUint,
    },
    /// Converts the plaintext bit `x` into an element of `field_type`.
    BitPlaintextToField {
        x: Label,
        field_type: FieldType,
    },
    /// Converts the plaintext `field_type` element `x` into a bit.
    ///
    /// NOTE(review): behaviour for values other than 0/1 is not visible here — confirm.
    FieldPlaintextToBit {
        x: Label,
        field_type: FieldType,
    },
    /// Reads the element at position `index` of the batch `x` (type `x_type`, form
    /// `x_form`).
    BatchGetIndex {
        x: Label,
        x_type: AlgebraicType,
        x_form: ShareOrPlaintext,
        index: usize,
    },
    /// Gathers the values on `wires` into one batch described by `x_type` and `x_form`
    /// (presumably every wire must match them — confirm).
    CollectToBatch {
        wires: Vec<Label>,
        x_type: AlgebraicType,
        x_form: ShareOrPlaintext,
    },
    /// Builds a plaintext point from the coordinates on `wires`, presumably extended
    /// Edwards (X, Y, Z, T). "Unchecked" presumably means the result is not validated —
    /// confirm.
    PointFromPlaintextExtendedEdwardsUnchecked {
        wires: Vec<Label>,
    },
    /// Converts the plaintext point `point` into extended Edwards coordinates.
    PlaintextPointToExtendedEdwards {
        point: Label,
    },
    /// Applies the Keccak-f[1600] permutation to the plaintext state on `wires`
    /// (presumably 25 lanes — confirm).
    PlaintextKeccakF1600 {
        wires: Vec<Label>,
    },
}
172
173impl<C: Curve> Gate<C> {
174 pub fn get_labels(&self) -> Vec<Label> {
176 let mut labels = Vec::new();
177 self.for_each_label(|label| labels.push(label));
178 labels
179 }
180}
181
/// Tally of preprocessing elements (singlets, triples, daBits and power pairs) associated
/// with a circuit, one counter per kind.
///
/// Two tallies combine with `+` (see the `Add` impl in this module). Every counter is summed
/// field-wise, and `base_field_pow_pairs` is merged by key, adding the counts for shared
/// keys. `Default` yields all-zero counters and an empty map.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct CircuitPreprocessing {
    /// Number of scalar-field singlets.
    pub scalar_singlets: usize,
    /// Number of scalar-field triples.
    pub scalar_triples: usize,
    /// Number of base-field singlets.
    pub base_field_singlets: usize,
    /// Number of base-field triples.
    pub base_field_triples: usize,
    /// Count of base-field power pairs, keyed by exponent.
    /// NOTE(review): presumably keyed by the `exp` of `Gate::BaseFieldPow` — confirm.
    pub base_field_pow_pairs: HashMap<BoxedUint, usize>,
    /// Number of bit singlets.
    pub bit_singlets: usize,
    /// Number of bit triples.
    pub bit_triples: usize,
    /// Number of Mersenne-107 daBits.
    pub mersenne107_dabits: usize,
    /// Number of Mersenne-107 singlets.
    pub mersenne107_singlets: usize,
    /// Number of Mersenne-107 triples.
    pub mersenne107_triples: usize,
    /// Number of scalar-field daBits.
    pub scalar_dabits: usize,
    /// Number of base-field daBits.
    pub base_field_dabits: usize,
}
197
198impl Add for CircuitPreprocessing {
199 type Output = Self;
200
201 fn add(self, other: Self) -> Self::Output {
202 Self {
203 scalar_singlets: self.scalar_singlets + other.scalar_singlets,
204 scalar_triples: self.scalar_triples + other.scalar_triples,
205 base_field_singlets: self.base_field_singlets + other.base_field_singlets,
206 base_field_triples: self.base_field_triples + other.base_field_triples,
207 bit_singlets: self.bit_singlets + other.bit_singlets,
208 bit_triples: self.bit_triples + other.bit_triples,
209 mersenne107_dabits: self.mersenne107_dabits + other.mersenne107_dabits,
210 mersenne107_singlets: self.mersenne107_singlets + other.mersenne107_singlets,
211 mersenne107_triples: self.mersenne107_triples + other.mersenne107_triples,
212 scalar_dabits: self.scalar_dabits + other.scalar_dabits,
213 base_field_dabits: self.base_field_dabits + other.base_field_dabits,
214 base_field_pow_pairs: {
215 let mut combined = self.base_field_pow_pairs;
216 for (k, v) in other.base_field_pow_pairs {
217 *combined.entry(k).or_insert(0) += v;
218 }
219 combined
220 },
221 }
222 }
223}
224
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use primitives::algebra::elliptic_curve::Curve25519Ristretto as C;

    // The glob also brings in the parent module's imports (`HashMap`, `BoxedUint`,
    // `FieldShareBinaryOp`, ...), so they need no separate `use` here.
    use super::*;

    /// Round-trips three gates through bincode and checks that equality and hashing agree
    /// with the originals: the six values must collapse to three distinct set entries.
    #[test]
    fn test_ser_gate() {
        let no_curve_gate: Gate<C> = Gate::FieldShareBinaryOp {
            x: Label::from(1, 2),
            y: Label::from(3, 4),
            y_form: ShareOrPlaintext::Share,
            op: FieldShareBinaryOp::Add,
            field_type: FieldType::ScalarField,
        };
        let scalar_gate: Gate<C> = Gate::FieldShareBinaryOp {
            x: Label::from(1, 2),
            y: Label::from(3, 4),
            y_form: ShareOrPlaintext::Plaintext,
            op: FieldShareBinaryOp::Add,
            field_type: FieldType::ScalarField,
        };
        let point_gate: Gate<C> = Gate::PointShareBinaryOp {
            p: Label::from(1, 2),
            y: Label::from(3, 4),
            p_form: ShareOrPlaintext::Share,
            y_form: ShareOrPlaintext::Plaintext,
            op: PointShareBinaryOp::Add,
        };

        let no_curve_gate_ser = bincode::serialize(&no_curve_gate).unwrap();
        let scalar_gate_ser = bincode::serialize(&scalar_gate).unwrap();
        let point_gate_ser = bincode::serialize(&point_gate).unwrap();

        let no_curve_gate_de: Gate<C> = bincode::deserialize(&no_curve_gate_ser).unwrap();
        let scalar_gate_de: Gate<C> = bincode::deserialize(&scalar_gate_ser).unwrap();
        let point_gate_de: Gate<C> = bincode::deserialize(&point_gate_ser).unwrap();

        assert_eq!(no_curve_gate, no_curve_gate_de);
        assert_eq!(scalar_gate, scalar_gate_de);
        assert_eq!(point_gate, point_gate_de);
        let set = HashSet::from([
            no_curve_gate,
            no_curve_gate_de,
            scalar_gate,
            scalar_gate_de,
            point_gate,
            point_gate_de,
        ]);
        assert_eq!(set.len(), 3);
    }

    /// Checks these properties of `+`:
    /// - every counter is summed;
    /// - `base_field_pow_pairs` is merged by exponent with no stray keys;
    /// - the operation is commutative;
    /// - `Default` is its identity.
    #[test]
    fn test_circuit_preprocessing_add() {
        let a = CircuitPreprocessing {
            scalar_singlets: 1,
            scalar_triples: 2,
            base_field_singlets: 3,
            base_field_triples: 4,
            bit_singlets: 0,
            bit_triples: 1,
            mersenne107_dabits: 0,
            mersenne107_singlets: 0,
            mersenne107_triples: 0,
            scalar_dabits: 1,
            base_field_dabits: 2,
            base_field_pow_pairs: HashMap::from([
                (BoxedUint::from(vec![21]), 5),
                (BoxedUint::from(vec![14]), 6),
            ]),
        };
        let b = CircuitPreprocessing {
            scalar_singlets: 2,
            scalar_triples: 3,
            base_field_singlets: 0,
            base_field_triples: 5,
            bit_singlets: 3,
            bit_triples: 4,
            mersenne107_dabits: 0,
            mersenne107_singlets: 3,
            mersenne107_triples: 2,
            scalar_dabits: 2,
            base_field_dabits: 3,
            base_field_pow_pairs: HashMap::from([
                (BoxedUint::from(vec![21]), 6),
                (BoxedUint::from(vec![13]), 7),
            ]),
        };

        // Comparing whole structs also fails on a missing or unexpected map key, which
        // per-key lookups alone would not catch.
        let expected = CircuitPreprocessing {
            scalar_singlets: 3,
            scalar_triples: 5,
            base_field_singlets: 3,
            base_field_triples: 9,
            bit_singlets: 3,
            bit_triples: 5,
            mersenne107_dabits: 0,
            mersenne107_singlets: 3,
            mersenne107_triples: 2,
            scalar_dabits: 3,
            base_field_dabits: 5,
            base_field_pow_pairs: HashMap::from([
                (BoxedUint::from(vec![21]), 11),
                (BoxedUint::from(vec![14]), 6),
                (BoxedUint::from(vec![13]), 7),
            ]),
        };

        assert_eq!(a.clone() + b.clone(), expected);
        // The merge must not depend on which operand's map is reused.
        assert_eq!(b.clone() + a.clone(), expected);
        // All-zero counters and an empty map leave the other operand unchanged.
        assert_eq!(a.clone() + CircuitPreprocessing::default(), a);
        assert_eq!(CircuitPreprocessing::default() + b.clone(), b);
    }
}