1use std::{collections::HashMap, ops::Add};
2
3use primitives::algebra::{
4 elliptic_curve::{Curve, Point, Scalar},
5 BoxedUint,
6};
7use serde::{Deserialize, Serialize};
8
9use crate::{
10 circuit::{
11 AlgebraicType,
12 Batched,
13 BitShareBinaryOp,
14 BitShareUnaryOp,
15 FieldPlaintextBinaryOp,
16 FieldPlaintextUnaryOp,
17 FieldShareBinaryOp,
18 FieldShareUnaryOp,
19 FieldType,
20 Input,
21 PointShareBinaryOp,
22 PointShareUnaryOp,
23 ShareOrPlaintext,
24 },
25 types::Label,
26};
27
/// One gate of a circuit over the elliptic curve `C`.
///
/// Operands are referenced by [`Label`]. NOTE(review): presumably a label names
/// the output wire of another gate in the same circuit — confirm against the
/// circuit builder.
///
/// Variant and field order are part of the serialized form: bincode (used in
/// this module's tests) encodes the variant index and the fields positionally,
/// so reordering either would break previously serialized gates.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
// Override the bounds serde's derive would infer from the type parameter `C`:
// what actually gets (de)serialized are the curve's scalar and point types.
#[serde(bound(
    serialize = "Scalar<C>: Serialize, Point<C>: Serialize",
    deserialize = "Scalar<C>: Deserialize<'de>, Point<C>: Deserialize<'de>"
))]
pub enum Gate<C: Curve> {
    /// A circuit input; `input_type` describes what kind of input it is.
    Input {
        input_type: Input<C>,
    },
    /// Unary `op` applied to the field-element share at `x`, in the field
    /// selected by `field_type`.
    FieldShareUnaryOp {
        x: Label,
        op: FieldShareUnaryOp,
        field_type: FieldType,
    },
    /// Binary `op` applied to the field-element share at `x` and the operand at
    /// `y`, in the field selected by `field_type`.
    FieldShareBinaryOp {
        x: Label,
        y: Label,
        /// Whether the operand at `y` is a share or a plaintext.
        y_form: ShareOrPlaintext,
        op: FieldShareBinaryOp,
        field_type: FieldType,
    },
    /// NOTE(review): presumably sums the elements of the batched value at `x`
    /// — confirm. `x_form` and `algebraic_type` describe that operand.
    BatchSummation {
        x: Label,
        x_form: ShareOrPlaintext,
        algebraic_type: AlgebraicType,
    },
    /// Unary `op` applied to the bit share at `x`.
    BitShareUnaryOp {
        x: Label,
        op: BitShareUnaryOp,
    },
    /// Binary `op` applied to the bit share at `x` and the operand at `y`.
    BitShareBinaryOp {
        x: Label,
        y: Label,
        /// Whether the operand at `y` is a share or a plaintext.
        y_form: ShareOrPlaintext,
        op: BitShareBinaryOp,
    },
    /// Unary `op` applied to the curve-point share at `p`.
    PointShareUnaryOp {
        p: Label,
        op: PointShareUnaryOp,
    },
    /// Binary `op` applied to the curve points at `p` and `y`.
    ///
    /// NOTE(review): there is no `x` field here; `x_form` appears to describe
    /// the operand at `p`, and `y_form` the operand at `y` — confirm.
    PointShareBinaryOp {
        p: Label,
        y: Label,
        x_form: ShareOrPlaintext,
        y_form: ShareOrPlaintext,
        op: PointShareBinaryOp,
    },
    /// Unary `op` applied to the plaintext field element at `x`, in the field
    /// selected by `field_type`.
    FieldPlaintextUnaryOp {
        x: Label,
        op: FieldPlaintextUnaryOp,
        field_type: FieldType,
    },
    /// Binary `op` applied to the plaintext field elements at `x` and `y`, in
    /// the field selected by `field_type`.
    FieldPlaintextBinaryOp {
        x: Label,
        y: Label,
        op: FieldPlaintextBinaryOp,
        field_type: FieldType,
    },
    /// Unary operation on the plaintext bit at `x`.
    ///
    /// Note that `op` reuses the field-plaintext operation type.
    BitPlaintextUnaryOp {
        x: Label,
        op: FieldPlaintextUnaryOp,
    },
    /// Binary operation on the plaintext bits at `x` and `y`.
    ///
    /// Note that `op` reuses the field-plaintext operation type.
    BitPlaintextBinaryOp {
        x: Label,
        y: Label,
        op: FieldPlaintextBinaryOp,
    },
    /// Produces daBit material for the field selected by `field_type`.
    ///
    /// NOTE(review): presumably a random bit available both as a field share
    /// and as a bit share (see `GetDaBitFieldShare` / `GetDaBitSharedBit`), with
    /// `batched` controlling how many are produced — confirm.
    DaBit {
        field_type: FieldType,
        batched: Batched,
    },
    /// Field-share component of the daBit at `x` (in `field_type`).
    GetDaBitFieldShare {
        x: Label,
        field_type: FieldType,
    },
    /// Shared-bit component of the daBit at `x` (whose field is `field_type`).
    GetDaBitSharedBit {
        x: Label,
        field_type: FieldType,
    },
    /// Point encryption involving the value at `x` and the curve point `c`,
    /// which is fixed in the circuit.
    ///
    /// NOTE(review): the role of `c` (e.g. a public key) is not visible here —
    /// confirm.
    EncryptPoint {
        x: Label,
        c: Point<C>,
    },
    /// Point decryption taking the operands at `x` and `y`.
    ///
    /// NOTE(review): which operand is the ciphertext and which the key is not
    /// visible here — confirm.
    DecryptPoint {
        x: Label,
        y: Label,
    },
    /// Presumably raises the base-field value at `x` to the exponent `exp`,
    /// which is fixed in the circuit — confirm.
    BaseFieldPow {
        x: Label,
        exp: BoxedUint,
    },
    /// Converts the plaintext bit at `x` into a plaintext element of the field
    /// selected by `field_type`.
    BitPlaintextToField {
        x: Label,
        field_type: FieldType,
    },
    /// Converts the plaintext element at `x` (in `field_type`) into a
    /// plaintext bit.
    FieldPlaintextToBit {
        x: Label,
        field_type: FieldType,
    },
    /// Selects entry `index` of the batched value at `x`; `x_type` and `x_form`
    /// describe that value.
    BatchGetIndex {
        x: Label,
        x_type: AlgebraicType,
        x_form: ShareOrPlaintext,
        index: usize,
    },
    /// Gathers the values at `wires` into a single batched value; `x_type` and
    /// `x_form` describe the gathered values.
    CollectToBatch {
        wires: Vec<Label>,
        x_type: AlgebraicType,
        x_form: ShareOrPlaintext,
    },
}
151
/// Tally of preprocessing items: singlets, triples, power pairs and daBits,
/// split by field.
///
/// NOTE(review): presumably the amount a circuit consumes before it can be
/// evaluated — confirm with the producer of this value.
///
/// Tallies combine with `+`, which sums every count. The derived `Default` is
/// the all-zero tally with no power pairs.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct CircuitPreprocessing {
    /// Number of scalar-field singlets.
    pub scalar_singlets: usize,
    /// Number of scalar-field triples.
    pub scalar_triples: usize,
    /// Number of base-field singlets.
    pub base_field_singlets: usize,
    /// Number of base-field triples.
    pub base_field_triples: usize,
    /// Number of base-field power pairs, counted per exponent.
    /// NOTE(review): presumably keyed by the `exp` of `Gate::BaseFieldPow`
    /// gates — confirm.
    pub base_field_pow_pairs: HashMap<BoxedUint, usize>,
    /// Number of bit singlets.
    pub bit_singlets: usize,
    /// Number of bit triples.
    pub bit_triples: usize,
    /// Number of scalar-field daBits.
    pub scalar_dabits: usize,
    /// Number of base-field daBits.
    pub base_field_dabits: usize,
}
164
165impl Add for CircuitPreprocessing {
166 type Output = Self;
167
168 fn add(self, other: Self) -> Self::Output {
169 Self {
170 scalar_singlets: self.scalar_singlets + other.scalar_singlets,
171 scalar_triples: self.scalar_triples + other.scalar_triples,
172 base_field_singlets: self.base_field_singlets + other.base_field_singlets,
173 base_field_triples: self.base_field_triples + other.base_field_triples,
174 bit_singlets: self.bit_singlets + other.bit_singlets,
175 bit_triples: self.bit_triples + other.bit_triples,
176 scalar_dabits: self.scalar_dabits + other.scalar_dabits,
177 base_field_dabits: self.base_field_dabits + other.base_field_dabits,
178 base_field_pow_pairs: {
179 let mut combined = self.base_field_pow_pairs;
180 for (k, v) in other.base_field_pow_pairs {
181 *combined.entry(k).or_insert(0) += v;
182 }
183 combined
184 },
185 }
186 }
187}
188
#[cfg(test)]
mod tests {
    use primitives::algebra::elliptic_curve::Curve25519Ristretto as C;

    // The glob also brings in the parent module's imports (`Label`,
    // `FieldShareBinaryOp`, `ShareOrPlaintext`, `HashMap`, `BoxedUint`, ...),
    // so no further `use` lines are needed.
    use super::*;

    /// Serializes `gate` with bincode and deserializes the bytes back.
    fn bincode_roundtrip(gate: &Gate<C>) -> Gate<C> {
        let bytes = bincode::serialize(gate).unwrap();
        bincode::deserialize(&bytes).unwrap()
    }

    /// Builds a `base_field_pow_pairs` map from `(exponent, count)` entries.
    fn pow_pairs(entries: Vec<(BoxedUint, usize)>) -> HashMap<BoxedUint, usize> {
        entries.into_iter().collect()
    }

    /// Gates must survive a bincode round trip unchanged.
    #[test]
    fn test_ser_gate() {
        let share_operand_gate: Gate<C> = Gate::FieldShareBinaryOp {
            x: Label::from(1, 2),
            y: Label::from(3, 4),
            y_form: ShareOrPlaintext::Share,
            op: FieldShareBinaryOp::Add,
            field_type: FieldType::ScalarField,
        };
        let plaintext_operand_gate: Gate<C> = Gate::FieldShareBinaryOp {
            x: Label::from(1, 2),
            y: Label::from(3, 4),
            y_form: ShareOrPlaintext::Plaintext,
            op: FieldShareBinaryOp::Add,
            field_type: FieldType::ScalarField,
        };
        let point_gate: Gate<C> = Gate::PointShareBinaryOp {
            p: Label::from(1, 2),
            y: Label::from(3, 4),
            x_form: ShareOrPlaintext::Share,
            y_form: ShareOrPlaintext::Plaintext,
            op: PointShareBinaryOp::Add,
        };

        for gate in &[share_operand_gate, plaintext_operand_gate, point_gate] {
            assert_eq!(&bincode_roundtrip(gate), gate);
        }
    }

    #[test]
    fn test_circuit_preprocessing_add() {
        let a = CircuitPreprocessing {
            scalar_singlets: 1,
            scalar_triples: 2,
            base_field_singlets: 3,
            base_field_triples: 4,
            bit_singlets: 0,
            bit_triples: 1,
            scalar_dabits: 1,
            base_field_dabits: 2,
            base_field_pow_pairs: pow_pairs(vec![
                (BoxedUint::from(vec![21]), 5),
                (BoxedUint::from(vec![14]), 6),
            ]),
        };
        let b = CircuitPreprocessing {
            scalar_singlets: 2,
            scalar_triples: 3,
            base_field_singlets: 0,
            base_field_triples: 5,
            bit_singlets: 3,
            bit_triples: 4,
            scalar_dabits: 2,
            base_field_dabits: 3,
            base_field_pow_pairs: pow_pairs(vec![
                (BoxedUint::from(vec![21]), 6),
                (BoxedUint::from(vec![13]), 7),
            ]),
        };

        // Exponent 21 appears in both operands and is summed. Exponents 14 and
        // 13 each appear once and are carried over. Comparing whole structs
        // also rules out stray extra map entries.
        let expected = CircuitPreprocessing {
            scalar_singlets: 3,
            scalar_triples: 5,
            base_field_singlets: 3,
            base_field_triples: 9,
            bit_singlets: 3,
            bit_triples: 5,
            scalar_dabits: 3,
            base_field_dabits: 5,
            base_field_pow_pairs: pow_pairs(vec![
                (BoxedUint::from(vec![21]), 11),
                (BoxedUint::from(vec![14]), 6),
                (BoxedUint::from(vec![13]), 7),
            ]),
        };

        assert_eq!(a + b, expected);
    }

    /// `Default` is the additive identity, and `+` does not depend on operand
    /// order.
    #[test]
    fn test_circuit_preprocessing_add_identity_and_commutativity() {
        let a = CircuitPreprocessing {
            scalar_singlets: 4,
            bit_triples: 2,
            base_field_pow_pairs: pow_pairs(vec![(BoxedUint::from(vec![5]), 3)]),
            ..Default::default()
        };
        let b = CircuitPreprocessing {
            base_field_triples: 1,
            scalar_dabits: 6,
            base_field_pow_pairs: pow_pairs(vec![
                (BoxedUint::from(vec![5]), 1),
                (BoxedUint::from(vec![7]), 2),
            ]),
            ..Default::default()
        };

        assert_eq!(a.clone() + CircuitPreprocessing::default(), a);
        assert_eq!(CircuitPreprocessing::default() + a.clone(), a);
        assert_eq!(a.clone() + b.clone(), b + a);
    }
}