1use std::{collections::HashMap, ops::Add};
2
3use primitives::algebra::{
4 elliptic_curve::{Curve, Point, Scalar},
5 BoxedUint,
6};
7use serde::{Deserialize, Serialize};
8
9use crate::{
10 circuit::{
11 AlgebraicType,
12 Batched,
13 BitShareBinaryOp,
14 BitShareUnaryOp,
15 FieldPlaintextBinaryOp,
16 FieldPlaintextUnaryOp,
17 FieldShareBinaryOp,
18 FieldShareUnaryOp,
19 FieldType,
20 Input,
21 PointPlaintextBinaryOp,
22 PointPlaintextUnaryOp,
23 PointShareBinaryOp,
24 PointShareUnaryOp,
25 ShareOrPlaintext,
26 },
27 types::Label,
28};
29
/// A single gate of a circuit over the elliptic curve `C`.
///
/// Operands are referenced through wire [`Label`]s. Variants whose name contains
/// `Share` appear to act on secret-shared values and those containing `Plaintext`
/// on public values — NOTE(review): inferred from naming only; confirm against the
/// circuit evaluator.
///
/// The explicit `serde(bound)` replaces serde's default `C: Serialize` /
/// `C: Deserialize<'de>` bounds, so `C` itself need not be (de)serializable as long
/// as `Scalar<C>` and `Point<C>` are.
///
/// NOTE: bincode (used by this module's tests) encodes an enum variant by its
/// position and struct-like fields in declaration order, so reordering variants or
/// fields changes the serialized format.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(bound(
    serialize = "Scalar<C>: Serialize, Point<C>: Serialize",
    deserialize = "Scalar<C>: Deserialize<'de>, Point<C>: Deserialize<'de>"
))]
pub enum Gate<C: Curve> {
    /// A circuit input, described by `input_type`.
    Input {
        input_type: Input<C>,
    },
    /// Applies the unary `op` to wire `x` in the field selected by `field_type`.
    FieldShareUnaryOp {
        x: Label,
        op: FieldShareUnaryOp,
        field_type: FieldType,
    },
    /// Applies the binary `op` to wires `x` and `y` in the field selected by
    /// `field_type`; `y_form` records whether `y` is a share or a plaintext.
    FieldShareBinaryOp {
        x: Label,
        y: Label,
        y_form: ShareOrPlaintext,
        op: FieldShareBinaryOp,
        field_type: FieldType,
    },
    /// Sums the batched value at wire `x` (presumably over its batch elements —
    /// confirm); `x_form` and `algebraic_type` describe that value.
    BatchSummation {
        x: Label,
        x_form: ShareOrPlaintext,
        algebraic_type: AlgebraicType,
    },
    /// Applies the unary bit `op` to wire `x`.
    BitShareUnaryOp {
        x: Label,
        op: BitShareUnaryOp,
    },
    /// Applies the binary bit `op` to wires `x` and `y`; `y_form` records whether
    /// `y` is a share or a plaintext.
    BitShareBinaryOp {
        x: Label,
        y: Label,
        y_form: ShareOrPlaintext,
        op: BitShareBinaryOp,
    },
    /// Applies the unary point `op` to wire `p`.
    PointShareUnaryOp {
        p: Label,
        op: PointShareUnaryOp,
    },
    /// Applies the binary point `op` to wires `p` and `y`; `p_form` and `y_form`
    /// record whether each operand is a share or a plaintext.
    PointShareBinaryOp {
        p: Label,
        y: Label,
        p_form: ShareOrPlaintext,
        y_form: ShareOrPlaintext,
        op: PointShareBinaryOp,
    },
    /// Applies the unary plaintext `op` to wire `x` in the field selected by
    /// `field_type`.
    FieldPlaintextUnaryOp {
        x: Label,
        op: FieldPlaintextUnaryOp,
        field_type: FieldType,
    },
    /// Applies the binary plaintext `op` to wires `x` and `y` in the field selected
    /// by `field_type`.
    FieldPlaintextBinaryOp {
        x: Label,
        y: Label,
        op: FieldPlaintextBinaryOp,
        field_type: FieldType,
    },
    /// Unary plaintext operation on the bit at wire `x`.
    ///
    /// NOTE(review): `op` reuses [`FieldPlaintextUnaryOp`] rather than a
    /// bit-specific op type; confirm this is intentional.
    BitPlaintextUnaryOp {
        x: Label,
        op: FieldPlaintextUnaryOp,
    },
    /// Binary plaintext operation on the bits at wires `x` and `y`.
    ///
    /// NOTE(review): `op` reuses [`FieldPlaintextBinaryOp`] rather than a
    /// bit-specific op type; confirm this is intentional.
    BitPlaintextBinaryOp {
        x: Label,
        y: Label,
        op: FieldPlaintextBinaryOp,
    },
    /// Applies the unary plaintext point `op` to wire `p`.
    PointPlaintextUnaryOp {
        p: Label,
        op: PointPlaintextUnaryOp,
    },
    /// Applies the binary plaintext point `op` to wires `p` and `y`.
    PointPlaintextBinaryOp {
        p: Label,
        y: Label,
        op: PointPlaintextBinaryOp,
    },
    /// Produces a daBit for the field selected by `field_type`; `batched`
    /// controls batching (see [`Batched`]).
    DaBit {
        field_type: FieldType,
        batched: Batched,
    },
    /// Extracts the field-share component of the daBit at wire `x`
    /// (in the field selected by `field_type`).
    GetDaBitFieldShare {
        x: Label,
        field_type: FieldType,
    },
    /// Extracts the shared-bit component of the daBit at wire `x`
    /// (whose field is selected by `field_type`).
    GetDaBitSharedBit {
        x: Label,
        field_type: FieldType,
    },
    /// Encrypts the value at wire `x` using the public point `c`.
    /// NOTE(review): the scheme and exact role of `c` are not visible here.
    EncryptPoint {
        x: Label,
        c: Point<C>,
    },
    /// Decrypts using wires `x` and `y`.
    /// NOTE(review): which operand is the ciphertext and which the key is not
    /// visible here — confirm.
    DecryptPoint {
        x: Label,
        y: Label,
    },
    /// Raises the base-field value at wire `x` to the public exponent `exp`.
    BaseFieldPow {
        x: Label,
        exp: BoxedUint,
    },
    /// Converts the plaintext bit at wire `x` into an element of the field
    /// selected by `field_type`.
    BitPlaintextToField {
        x: Label,
        field_type: FieldType,
    },
    /// Converts the plaintext element of `field_type` at wire `x` into a bit.
    FieldPlaintextToBit {
        x: Label,
        field_type: FieldType,
    },
    /// Selects the element at position `index` of the batch at wire `x`;
    /// `x_type` and `x_form` describe that batch.
    BatchGetIndex {
        x: Label,
        x_type: AlgebraicType,
        x_form: ShareOrPlaintext,
        index: usize,
    },
    /// Gathers the values on `wires` into one batch of type `x_type` and form
    /// `x_form`.
    CollectToBatch {
        wires: Vec<Label>,
        x_type: AlgebraicType,
        x_form: ShareOrPlaintext,
    },
    /// Builds a plaintext point from extended-Edwards coordinates on `wires`,
    /// presumably without validity checks (per the `Unchecked` suffix — confirm).
    PointFromPlaintextExtendedEdwardsUnchecked {
        wires: Vec<Label>,
    },
    /// Converts the plaintext point at wire `point` to extended-Edwards form.
    PlaintextPointToExtendedEdwards {
        point: Label,
    },
    /// Presumably applies the `Keccak-f[1600]` permutation (per the name) to the
    /// plaintext state carried on `wires`.
    PlaintextKeccakF1600 {
        wires: Vec<Label>,
    },
}
171
/// Tally of the preprocessing material a circuit requires, broken down by kind.
///
/// Two tallies combine with `+` (see the `Add` impl for this type), which adds
/// every counter field-wise and merges `base_field_pow_pairs` by key.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct CircuitPreprocessing {
    /// Number of scalar-field singlets.
    pub scalar_singlets: usize,
    /// Number of scalar-field triples.
    pub scalar_triples: usize,
    /// Number of base-field singlets.
    pub base_field_singlets: usize,
    /// Number of base-field triples.
    pub base_field_triples: usize,
    /// Number of base-field pow pairs needed per exponent, keyed by exponent
    /// (presumably the `exp` of `Gate::BaseFieldPow` gates — confirm).
    pub base_field_pow_pairs: HashMap<BoxedUint, usize>,
    /// Number of bit singlets.
    pub bit_singlets: usize,
    /// Number of bit triples.
    pub bit_triples: usize,
    /// Number of daBits over the scalar field.
    pub scalar_dabits: usize,
    /// Number of daBits over the base field.
    pub base_field_dabits: usize,
}
184
185impl Add for CircuitPreprocessing {
186 type Output = Self;
187
188 fn add(self, other: Self) -> Self::Output {
189 Self {
190 scalar_singlets: self.scalar_singlets + other.scalar_singlets,
191 scalar_triples: self.scalar_triples + other.scalar_triples,
192 base_field_singlets: self.base_field_singlets + other.base_field_singlets,
193 base_field_triples: self.base_field_triples + other.base_field_triples,
194 bit_singlets: self.bit_singlets + other.bit_singlets,
195 bit_triples: self.bit_triples + other.bit_triples,
196 scalar_dabits: self.scalar_dabits + other.scalar_dabits,
197 base_field_dabits: self.base_field_dabits + other.base_field_dabits,
198 base_field_pow_pairs: {
199 let mut combined = self.base_field_pow_pairs;
200 for (k, v) in other.base_field_pow_pairs {
201 *combined.entry(k).or_insert(0) += v;
202 }
203 combined
204 },
205 }
206 }
207}
208
#[cfg(test)]
mod tests {
    use primitives::algebra::elliptic_curve::Curve25519Ristretto as C;

    // `use super::*` also brings in the parent module's imports (`FieldShareBinaryOp`,
    // `PointShareBinaryOp`, `Label`, `HashMap`, `BoxedUint`, ...), so no extra
    // `use crate::circuit::...` line is needed here.
    use super::*;

    /// Round-trips a share/share field gate, a share/plaintext field gate and a
    /// point gate through bincode and checks each decodes to the original.
    #[test]
    fn test_ser_gate() {
        let no_curve_gate: Gate<C> = Gate::FieldShareBinaryOp {
            x: Label::from(1, 2),
            y: Label::from(3, 4),
            y_form: ShareOrPlaintext::Share,
            op: FieldShareBinaryOp::Add,
            field_type: FieldType::ScalarField,
        };
        let scalar_gate: Gate<C> = Gate::FieldShareBinaryOp {
            x: Label::from(1, 2),
            y: Label::from(3, 4),
            y_form: ShareOrPlaintext::Plaintext,
            op: FieldShareBinaryOp::Add,
            field_type: FieldType::ScalarField,
        };
        let point_gate: Gate<C> = Gate::PointShareBinaryOp {
            p: Label::from(1, 2),
            y: Label::from(3, 4),
            p_form: ShareOrPlaintext::Share,
            y_form: ShareOrPlaintext::Plaintext,
            op: PointShareBinaryOp::Add,
        };

        let no_curve_gate_ser = bincode::serialize(&no_curve_gate).unwrap();
        let scalar_gate_ser = bincode::serialize(&scalar_gate).unwrap();
        let point_gate_ser = bincode::serialize(&point_gate).unwrap();

        let no_curve_gate_de: Gate<C> = bincode::deserialize(&no_curve_gate_ser).unwrap();
        let scalar_gate_de: Gate<C> = bincode::deserialize(&scalar_gate_ser).unwrap();
        let point_gate_de: Gate<C> = bincode::deserialize(&point_gate_ser).unwrap();

        assert_eq!(no_curve_gate, no_curve_gate_de);
        assert_eq!(scalar_gate, scalar_gate_de);
        assert_eq!(point_gate, point_gate_de);
    }

    /// First operand shared by the `Add` tests.
    fn sample_a() -> CircuitPreprocessing {
        CircuitPreprocessing {
            scalar_singlets: 1,
            scalar_triples: 2,
            base_field_singlets: 3,
            base_field_triples: 4,
            bit_singlets: 0,
            bit_triples: 1,
            scalar_dabits: 1,
            base_field_dabits: 2,
            base_field_pow_pairs: HashMap::from([
                (BoxedUint::from(vec![21]), 5),
                (BoxedUint::from(vec![14]), 6),
            ]),
        }
    }

    /// Second operand; shares exponent 21 with [`sample_a`] and adds exponent 13.
    fn sample_b() -> CircuitPreprocessing {
        CircuitPreprocessing {
            scalar_singlets: 2,
            scalar_triples: 3,
            base_field_singlets: 0,
            base_field_triples: 5,
            bit_singlets: 3,
            bit_triples: 4,
            scalar_dabits: 2,
            base_field_dabits: 3,
            base_field_pow_pairs: HashMap::from([
                (BoxedUint::from(vec![21]), 6),
                (BoxedUint::from(vec![13]), 7),
            ]),
        }
    }

    /// Counters add field-wise; pow pairs merge by key (shared keys summed).
    #[test]
    fn test_circuit_preprocessing_add() {
        let c = sample_a() + sample_b();

        assert_eq!(c.scalar_singlets, 3);
        assert_eq!(c.scalar_triples, 5);
        assert_eq!(c.base_field_singlets, 3);
        assert_eq!(c.base_field_triples, 9);
        assert_eq!(c.bit_singlets, 3);
        assert_eq!(c.bit_triples, 5);
        assert_eq!(c.scalar_dabits, 3);
        assert_eq!(c.base_field_dabits, 5);
        assert_eq!(
            c.base_field_pow_pairs.get(&BoxedUint::from(vec![21])),
            Some(&11)
        );
        assert_eq!(
            c.base_field_pow_pairs.get(&BoxedUint::from(vec![14])),
            Some(&6)
        );
        assert_eq!(
            c.base_field_pow_pairs.get(&BoxedUint::from(vec![13])),
            Some(&7)
        );
        // Exactly the union of both operands' exponents — no spurious keys.
        assert_eq!(c.base_field_pow_pairs.len(), 3);
    }

    /// `Default` (all zeros, empty map) is a two-sided identity for `+`.
    #[test]
    fn test_circuit_preprocessing_add_default_is_identity() {
        let a = sample_a();
        assert_eq!(a.clone() + CircuitPreprocessing::default(), a);
        assert_eq!(CircuitPreprocessing::default() + a.clone(), a);
    }

    /// Operand order does not affect the result.
    #[test]
    fn test_circuit_preprocessing_add_is_commutative() {
        assert_eq!(sample_a() + sample_b(), sample_b() + sample_a());
    }
}