1use std::iter::repeat_n;
2
3use ff::Field;
4use num_traits::{One, Zero};
5use primitives::{
6 algebra::{
7 elliptic_curve::{Curve, Point, Scalar},
8 field::{FieldExtension, SubfieldElement},
9 BoxedUint,
10 },
11 types::PeerNumber,
12};
13use serde::{Deserialize, Serialize};
14use typenum::Unsigned;
15
16use crate::{
17 circuit::{
18 AlgebraicType,
19 BaseFieldPlaintext,
20 BaseFieldPlaintextBatch,
21 BitPlaintext,
22 BitPlaintextBatch,
23 Mersenne107Plaintext,
24 Mersenne107PlaintextBatch,
25 PointPlaintext,
26 PointPlaintextBatch,
27 ScalarPlaintext,
28 ScalarPlaintextBatch,
29 },
30 errors::{AbortError, FaultyPeer},
31 types::Label,
32};
33
34#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
36pub enum FieldPlaintextUnaryOp {
37 Neg,
38 MulInverse,
40 BitExtract {
42 little_endian_bit_idx: u16,
43 signed: bool,
44 },
45 Sqrt,
46 Pow {
47 exp: BoxedUint,
48 },
49}
50
51impl FieldPlaintextUnaryOp {
52 pub fn eval<F: FieldExtension>(
54 &self,
55 label: Label,
56 x: SubfieldElement<F>,
57 ) -> Result<SubfieldElement<F>, AbortError> {
58 match self {
59 FieldPlaintextUnaryOp::Neg => Ok(-x),
60 FieldPlaintextUnaryOp::MulInverse => {
61 if x == SubfieldElement::<F>::zero() {
62 Ok(SubfieldElement::<F>::zero())
63 } else {
64 Ok(x.invert().unwrap())
65 }
66 }
67 FieldPlaintextUnaryOp::BitExtract {
68 little_endian_bit_idx: idx,
69 signed,
70 } => {
71 let bit = if *signed && x > -x {
72 !(-SubfieldElement::<F>::one() - x)
73 .to_biguint()
74 .bit(*idx as u64)
75 } else {
76 x.to_biguint().bit(*idx as u64)
77 };
78 Ok(SubfieldElement::<F>::from(bit))
79 }
80 FieldPlaintextUnaryOp::Sqrt => {
81 let (choice, sqrt) =
82 SubfieldElement::<F>::sqrt_ratio(&x, &SubfieldElement::<F>::one());
83 if !bool::from(choice) {
84 return Err(AbortError::quadratic_non_residue(label, FaultyPeer::Local));
85 }
86 Ok(sqrt)
87 }
88 FieldPlaintextUnaryOp::Pow { exp } => Ok(x.pow(exp)),
89 }
90 }
91}
92
93#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
95pub enum FieldPlaintextBinaryOp {
96 Add,
97 Mul,
98 EuclDiv,
99 Mod,
100 Gt,
101 Ge,
102 Eq,
103 Xor,
104 Or,
105}
106
107impl FieldPlaintextBinaryOp {
108 pub fn eval<F: FieldExtension>(
109 &self,
110 x: SubfieldElement<F>,
111 y: SubfieldElement<F>,
112 label: Label,
113 ) -> Result<SubfieldElement<F>, AbortError> {
114 match self {
115 FieldPlaintextBinaryOp::Add => Ok(x + y),
116 FieldPlaintextBinaryOp::Mul => Ok(x * y),
117 FieldPlaintextBinaryOp::EuclDiv => euclidean_division::<F>(x, y, label),
118 FieldPlaintextBinaryOp::Mod => modulo::<F>(x, y, label),
119 FieldPlaintextBinaryOp::Gt => Ok(SubfieldElement::<F>::from(x > y)),
120 FieldPlaintextBinaryOp::Ge => Ok(SubfieldElement::<F>::from(x >= y)),
121 FieldPlaintextBinaryOp::Eq => Ok(SubfieldElement::<F>::from(x == y)),
122 FieldPlaintextBinaryOp::Xor => Ok(x + y - SubfieldElement::<F>::from(2u32) * x * y),
123 FieldPlaintextBinaryOp::Or => Ok(x + y - x * y),
124 }
125 }
126}
127
128pub(crate) fn euclidean_division<F: FieldExtension>(
129 x: SubfieldElement<F>,
130 y: SubfieldElement<F>,
131 label: Label,
132) -> Result<SubfieldElement<F>, AbortError> {
133 if y == SubfieldElement::<F>::zero() {
134 return Err(AbortError::division_by_zero(label, FaultyPeer::Local));
135 }
136
137 let x = x.to_biguint();
139 let y = y.to_biguint();
140
141 let div = (x / y).to_bytes_be();
142 let div = repeat_n(0, F::FieldBytesSize::USIZE - div.len())
144 .chain(div)
145 .collect::<Vec<_>>();
146
147 Ok(SubfieldElement::<F>::from_be_bytes(&div).unwrap())
148}
149
150fn modulo<F: FieldExtension>(
151 x: SubfieldElement<F>,
152 y: SubfieldElement<F>,
153 label: Label,
154) -> Result<SubfieldElement<F>, AbortError> {
155 if y == SubfieldElement::<F>::zero() {
156 return Err(AbortError::division_by_zero(label, FaultyPeer::Local));
157 }
158
159 let x = x.to_biguint();
161 let y = y.to_biguint();
162
163 let modulo = x.modpow(&num_bigint::BigUint::from(1u32), &y).to_bytes_be();
164 let modulo = repeat_n(0, F::FieldBytesSize::USIZE - modulo.len())
166 .chain(modulo)
167 .collect::<Vec<_>>();
168
169 Ok(SubfieldElement::<F>::from_be_bytes(&modulo).unwrap())
170}
171
172#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
174pub enum FieldShareUnaryOp {
175 Neg,
177 MulInverse,
179 Open,
181 IsZero,
183}
184
185#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
188pub enum FieldShareBinaryOp {
189 Add,
191 Mul,
193}
194
195#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
197pub enum BitShareUnaryOp {
198 Not,
200 Open,
202}
203
204#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
206pub enum BitShareBinaryOp {
207 Xor,
209 Or,
211 And,
213}
214
215#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
217pub enum PointPlaintextUnaryOp {
218 Neg,
220}
221
222impl PointPlaintextUnaryOp {
223 pub fn eval<C: Curve>(&self, x: Point<C>) -> Result<Point<C>, AbortError> {
224 match self {
225 PointPlaintextUnaryOp::Neg => Ok(-x),
226 }
227 }
228}
229
230#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
232pub enum PointPlaintextBinaryOp {
233 Add,
235 ScalarMul,
237}
238
239impl PointPlaintextBinaryOp {
240 pub fn eval<C: Curve>(&self, x: Point<C>, y: Point<C>) -> Result<Point<C>, AbortError> {
241 match self {
242 PointPlaintextBinaryOp::Add => Ok(x + y),
243 PointPlaintextBinaryOp::ScalarMul => Err(AbortError::internal_error(
244 "PointPlaintextBinaryOp::eval not supported for PointPlaintextBinaryOp::ScalarMul.",
245 )),
246 }
247 }
248}
249
250#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
252pub enum PointShareUnaryOp {
253 Neg,
255 Open,
257 IsZero,
259}
260
261#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
263pub enum PointShareBinaryOp {
264 Add,
266 ScalarMul,
268}
269
270#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
271#[serde(bound(
272 serialize = "Scalar<C>: Serialize, Point<C>: Serialize",
273 deserialize = "Scalar<C>: Deserialize<'de>, Point<C>: Deserialize<'de>"
274))]
275#[derive(Hash)]
276pub enum Input<C: Curve> {
277 SecretPlaintext {
278 inputer: PeerNumber,
279 algebraic_type: AlgebraicType,
280 batched: Batched,
281 },
282 Share {
283 algebraic_type: AlgebraicType,
284 batched: Batched,
285 },
286 RandomShare {
287 algebraic_type: AlgebraicType,
288 batched: Batched,
289 },
290 Scalar(ScalarPlaintext<C>),
291 ScalarBatch(ScalarPlaintextBatch<C>),
292 BaseField(BaseFieldPlaintext<C>),
293 BaseFieldBatch(BaseFieldPlaintextBatch<C>),
294 Mersenne107(Mersenne107Plaintext),
295 Mersenne107Batch(Mersenne107PlaintextBatch),
296 Bit(BitPlaintext),
297 BitBatch(BitPlaintextBatch),
298 Point(PointPlaintext<C>),
299 PointBatch(PointPlaintextBatch<C>),
300 ElGamalCiphertext {
301 c: PointPlaintext<C>,
302 r: PointPlaintext<C>,
303 },
304}
305
306impl<C: Curve> Input<C> {
307 pub fn batched(&self) -> Batched {
308 match self {
309 Input::SecretPlaintext { batched, .. } => *batched,
310 Input::Share { batched, .. } => *batched,
311 Input::RandomShare { batched, .. } => *batched,
312 Input::ScalarBatch(input) => Batched::Yes(input.len()),
313 Input::BaseFieldBatch(input) => Batched::Yes(input.len()),
314 Input::Mersenne107Batch(input) => Batched::Yes(input.len()),
315 Input::BitBatch(input) => Batched::Yes(input.len()),
316 Input::PointBatch(input) => Batched::Yes(input.len()),
317 Input::ElGamalCiphertext { .. } => Batched::No,
318 Input::Scalar { .. } => Batched::No,
319 Input::BaseField { .. } => Batched::No,
320 Input::Mersenne107 { .. } => Batched::No,
321 Input::Bit { .. } => Batched::No,
322 Input::Point { .. } => Batched::No,
323 }
324 }
325}
326
327#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
329pub enum Batched {
330 Yes(usize),
331 No,
332}
333
334impl Batched {
335 pub fn count(&self) -> usize {
336 match self {
337 Batched::Yes(count) => *count,
338 Batched::No => 1,
339 }
340 }
341
342 pub fn is_batched(&self) -> bool {
343 match self {
344 Batched::Yes(_) => true,
345 Batched::No => false,
346 }
347 }
348}
349
#[cfg(test)]
mod tests {
    use primitives::algebra::{
        elliptic_curve::{BaseField, Curve25519Ristretto as C, ScalarField},
        field::SubfieldElement,
    };

    use super::*;

    #[test]
    fn test_scalar_unary_op() {
        let mut rng = rand::thread_rng();
        let x = SubfieldElement::<ScalarField<C>>::random(&mut rng);
        let label = Label::task(0);
        let neg = FieldPlaintextUnaryOp::Neg;
        let mul_inverse = FieldPlaintextUnaryOp::MulInverse;

        assert_eq!(neg.eval::<ScalarField<C>>(label, x), Ok(-x));
        assert_eq!(
            mul_inverse.eval::<ScalarField<C>>(label, x),
            Ok(x.invert().unwrap())
        );
    }

    /// The zero guard in `MulInverse` maps zero to zero instead of failing.
    #[test]
    fn test_mul_inverse_of_zero_is_zero() {
        let zero = SubfieldElement::<ScalarField<C>>::from(0u32);
        let label = Label::task(0);

        assert_eq!(
            FieldPlaintextUnaryOp::MulInverse.eval::<ScalarField<C>>(label, zero),
            Ok(zero)
        );
    }

    #[test]
    fn test_scalar_binary_op() {
        let mut rng = rand::thread_rng();
        let x = SubfieldElement::<ScalarField<C>>::random(&mut rng);
        let y = SubfieldElement::<ScalarField<C>>::random(&mut rng);
        let label = Label::task(0);

        let add = FieldPlaintextBinaryOp::Add;
        let mul = FieldPlaintextBinaryOp::Mul;
        let eucl_div = FieldPlaintextBinaryOp::EuclDiv;
        let modulo_op = FieldPlaintextBinaryOp::Mod;
        let gt = FieldPlaintextBinaryOp::Gt;
        let ge = FieldPlaintextBinaryOp::Ge;
        let eq = FieldPlaintextBinaryOp::Eq;

        assert_eq!(add.eval::<ScalarField<C>>(x, y, label), Ok(x + y));
        assert_eq!(mul.eval::<ScalarField<C>>(x, y, label), Ok(x * y));
        assert_eq!(
            eucl_div.eval::<ScalarField<C>>(x, y, label),
            euclidean_division::<ScalarField<C>>(x, y, label)
        );
        assert_eq!(
            modulo_op.eval::<ScalarField<C>>(x, y, label),
            modulo::<ScalarField<C>>(x, y, label)
        );
        assert_eq!(
            gt.eval::<ScalarField<C>>(x, y, label),
            Ok(SubfieldElement::<ScalarField<C>>::from(x > y))
        );
        assert_eq!(
            ge.eval::<ScalarField<C>>(x, y, label),
            Ok(SubfieldElement::<ScalarField<C>>::from(x >= y))
        );
        assert_eq!(
            eq.eval::<ScalarField<C>>(x, y, label),
            Ok(SubfieldElement::<ScalarField<C>>::from(x == y))
        );
    }

    #[test]
    fn test_boolean_binary_op() {
        let and = FieldPlaintextBinaryOp::Mul;
        let or = FieldPlaintextBinaryOp::Or;
        let xor = FieldPlaintextBinaryOp::Xor;
        let label = Label::task(0);
        for bool_x in [false, true] {
            for bool_y in [false, true] {
                let scalar_x = SubfieldElement::<ScalarField<C>>::from(bool_x);
                let scalar_y = SubfieldElement::<ScalarField<C>>::from(bool_y);
                assert_eq!(
                    and.eval::<ScalarField<C>>(scalar_x, scalar_y, label),
                    Ok((bool_x && bool_y).into())
                );
                assert_eq!(
                    or.eval::<ScalarField<C>>(scalar_x, scalar_y, label),
                    Ok((bool_x || bool_y).into())
                );
                assert_eq!(
                    xor.eval::<ScalarField<C>>(scalar_x, scalar_y, label),
                    Ok((bool_x ^ bool_y).into())
                );
            }
        }
    }

    #[test]
    fn test_euclidian_division() {
        let x = SubfieldElement::<ScalarField<C>>::from(37u32);
        let y = SubfieldElement::<ScalarField<C>>::from(12u32);
        let label = Label::task(0);

        let result = euclidean_division::<ScalarField<C>>(x, y, label).unwrap();
        assert_eq!(result, SubfieldElement::<ScalarField<C>>::from(37u32 / 12));
    }

    #[test]
    fn test_modulo() {
        let x = SubfieldElement::<ScalarField<C>>::from(37u32);
        let y = SubfieldElement::<ScalarField<C>>::from(12u32);
        let label = Label::task(0);

        let result = modulo::<ScalarField<C>>(x, y, label).unwrap();
        assert_eq!(result, SubfieldElement::<ScalarField<C>>::from(37u32 % 12));
    }

    /// Both integer-division helpers (and the ops that wrap them) must reject a zero divisor.
    #[test]
    fn test_division_by_zero_is_rejected() {
        let x = SubfieldElement::<ScalarField<C>>::from(37u32);
        let zero = SubfieldElement::<ScalarField<C>>::from(0u32);
        let label = Label::task(0);

        assert!(euclidean_division::<ScalarField<C>>(x, zero, label).is_err());
        assert!(modulo::<ScalarField<C>>(x, zero, label).is_err());
        assert!(FieldPlaintextBinaryOp::EuclDiv
            .eval::<ScalarField<C>>(x, zero, label)
            .is_err());
        assert!(FieldPlaintextBinaryOp::Mod
            .eval::<ScalarField<C>>(x, zero, label)
            .is_err());
    }

    #[test]
    fn test_signed_bit_extract() {
        let x = -Scalar::<C>::from(9u32);
        let label = Label::task(0);
        for i in 0..5 {
            let op = FieldPlaintextUnaryOp::BitExtract {
                little_endian_bit_idx: i,
                signed: true,
            };
            let result = op.eval::<ScalarField<C>>(label, x);
            assert_eq!(result.unwrap(), ((-9i32 >> i) & 1 == 1).into())
        }
    }

    /// Signed extraction of a small positive value reads its plain binary digits.
    #[test]
    fn test_signed_bit_extract_of_positive_value() {
        let x = SubfieldElement::<ScalarField<C>>::from(9u32);
        let label = Label::task(0);
        for i in 0..5 {
            let op = FieldPlaintextUnaryOp::BitExtract {
                little_endian_bit_idx: i,
                signed: true,
            };
            let result = op.eval::<ScalarField<C>>(label, x);
            assert_eq!(result.unwrap(), ((9u32 >> i) & 1 == 1).into())
        }
    }

    #[test]
    fn test_unsigned_bit_extract() {
        let x = SubfieldElement::<ScalarField<C>>::from(9u32);
        let label = Label::task(0);
        for i in 0..5 {
            let op = FieldPlaintextUnaryOp::BitExtract {
                little_endian_bit_idx: i,
                signed: false,
            };
            let result = op.eval::<ScalarField<C>>(label, x);
            assert_eq!(result.unwrap(), ((9u32 >> i) & 1 == 1).into())
        }
    }

    #[test]
    fn test_sqrt() {
        let mut rng = rand::thread_rng();
        let x = SubfieldElement::<ScalarField<C>>::random(&mut rng);
        let label = Label::task(0);
        let result = FieldPlaintextUnaryOp::Sqrt
            .eval::<ScalarField<C>>(label, x * x)
            .unwrap();

        assert_eq!(result * result, x * x)
    }

    #[test]
    fn test_pow() {
        let mut rng = rand::thread_rng();
        let x = SubfieldElement::<BaseField<C>>::random(&mut rng);
        let label = Label::task(0);
        let five = BoxedUint::from(vec![5u64]);
        let five_inv = BoxedUint::from(vec![
            14757395258967641281,
            14757395258967641292,
            14757395258967641292,
            5534023222112865484,
        ]);
        let x_pow_5 = FieldPlaintextUnaryOp::Pow { exp: five }
            .eval::<BaseField<C>>(label, x)
            .unwrap();
        let x_again = FieldPlaintextUnaryOp::Pow { exp: five_inv }
            .eval::<BaseField<C>>(label, x_pow_5)
            .unwrap();

        assert_eq!(x_again, x)
    }
}