1use std::iter::repeat_n;
2
3use ff::Field;
4use num_traits::{One, Zero};
5use primitives::{
6 algebra::{
7 elliptic_curve::{BaseFieldElement, Curve, Point, Scalar},
8 field::{subfield_element::Mersenne107Element, Bit, FieldExtension, SubfieldElement},
9 BoxedUint,
10 },
11 types::PeerNumber,
12};
13use serde::{Deserialize, Serialize};
14use typenum::Unsigned;
15use wincode::{SchemaRead, SchemaWrite};
16
17use crate::{
18 circuit::{errors::BatchSizeError, AlgebraicType, BatchSize, GateIndex, ShareOrPlaintext},
19 errors::{AbortError, FaultyPeer},
20};
21
22#[derive(
24 Debug,
25 Clone,
26 PartialEq,
27 Eq,
28 Hash,
29 Serialize,
30 Deserialize,
31 SchemaRead,
32 SchemaWrite,
33 PartialOrd,
34 Ord,
35)]
36#[repr(C)]
37pub enum FieldPlaintextUnaryOp {
38 Neg,
39 MulInverse,
41 BitExtract {
43 little_endian_bit_idx: u16,
44 signed: bool,
45 },
46 Sqrt,
47 Pow {
48 exp: BoxedUint,
49 },
50}
51
52impl FieldPlaintextUnaryOp {
53 pub fn eval<F: FieldExtension>(
55 &self,
56 label: GateIndex,
57 x: &SubfieldElement<F>,
58 ) -> Result<SubfieldElement<F>, AbortError> {
59 match self {
60 FieldPlaintextUnaryOp::Neg => Ok(-x),
61 FieldPlaintextUnaryOp::MulInverse => {
62 Ok(x.invert().unwrap_or(SubfieldElement::<F>::zero()))
63 }
64 FieldPlaintextUnaryOp::BitExtract {
65 little_endian_bit_idx: idx,
66 signed,
67 } => {
68 let bit = if *signed && *x > -x {
69 !(-SubfieldElement::<F>::one() - x)
70 .to_biguint()
71 .bit(*idx as u64)
72 } else {
73 x.to_biguint().bit(*idx as u64)
74 };
75 Ok(SubfieldElement::<F>::from(bit))
76 }
77 FieldPlaintextUnaryOp::Sqrt => {
78 let (choice, sqrt) =
79 SubfieldElement::<F>::sqrt_ratio(x, &SubfieldElement::<F>::one());
80 if !bool::from(choice) {
81 return Err(AbortError::quadratic_non_residue(label, FaultyPeer::Local));
82 }
83 Ok(sqrt)
84 }
85 FieldPlaintextUnaryOp::Pow { exp } => Ok(x.pow(exp)),
86 }
87 }
88}
89
90#[derive(
92 Debug,
93 Clone,
94 Copy,
95 PartialEq,
96 Eq,
97 Hash,
98 Serialize,
99 Deserialize,
100 SchemaRead,
101 SchemaWrite,
102 PartialOrd,
103 Ord,
104)]
105#[repr(C)]
106pub enum FieldPlaintextBinaryOp {
107 Add,
108 Mul,
109 EuclDiv,
110 Mod,
111 Gt,
112 Ge,
113 Eq,
114 Xor,
115 Or,
116}
117
118impl FieldPlaintextBinaryOp {
119 pub fn eval<F: FieldExtension>(
120 &self,
121 x: &SubfieldElement<F>,
122 y: &SubfieldElement<F>,
123 label: GateIndex,
124 ) -> Result<SubfieldElement<F>, AbortError> {
125 match self {
126 FieldPlaintextBinaryOp::Add => Ok(x + y),
127 FieldPlaintextBinaryOp::Mul => Ok(x * y),
128 FieldPlaintextBinaryOp::EuclDiv => euclidean_division::<F>(x, y, label),
129 FieldPlaintextBinaryOp::Mod => modulo::<F>(x, y, label),
130 FieldPlaintextBinaryOp::Gt => Ok(SubfieldElement::<F>::from(x > y)),
131 FieldPlaintextBinaryOp::Ge => Ok(SubfieldElement::<F>::from(x >= y)),
132 FieldPlaintextBinaryOp::Eq => Ok(SubfieldElement::<F>::from(x == y)),
133 FieldPlaintextBinaryOp::Xor => Ok(x + y - SubfieldElement::<F>::from(2u32) * x * y),
134 FieldPlaintextBinaryOp::Or => Ok(x + y - x * y),
135 }
136 }
137}
138
139pub(crate) fn euclidean_division<F: FieldExtension>(
140 x: &SubfieldElement<F>,
141 y: &SubfieldElement<F>,
142 label: GateIndex,
143) -> Result<SubfieldElement<F>, AbortError> {
144 if *y == SubfieldElement::<F>::zero() {
145 return Err(AbortError::division_by_zero(label, FaultyPeer::Local));
146 }
147
148 let x = x.to_biguint();
150 let y = y.to_biguint();
151
152 let div = (x / y).to_bytes_be();
153 let div = repeat_n(0, F::FieldBytesSize::USIZE - div.len())
155 .chain(div)
156 .collect::<Vec<_>>();
157
158 Ok(SubfieldElement::<F>::from_be_bytes(&div)?)
159}
160
161fn modulo<F: FieldExtension>(
162 x: &SubfieldElement<F>,
163 y: &SubfieldElement<F>,
164 label: GateIndex,
165) -> Result<SubfieldElement<F>, AbortError> {
166 if *y == SubfieldElement::<F>::zero() {
167 return Err(AbortError::division_by_zero(label, FaultyPeer::Local));
168 }
169
170 let x = x.to_biguint();
172 let y = y.to_biguint();
173
174 let modulo = x.modpow(&num_bigint::BigUint::from(1u32), &y).to_bytes_be();
175 let modulo = repeat_n(0, F::FieldBytesSize::USIZE - modulo.len())
177 .chain(modulo)
178 .collect::<Vec<_>>();
179
180 Ok(SubfieldElement::<F>::from_be_bytes(&modulo)?)
181}
182
183#[derive(
185 Debug,
186 Clone,
187 Copy,
188 PartialEq,
189 Eq,
190 Hash,
191 Serialize,
192 Deserialize,
193 SchemaRead,
194 SchemaWrite,
195 PartialOrd,
196 Ord,
197)]
198#[repr(C)]
199pub enum FieldShareUnaryOp {
200 Neg,
202 MulInverse,
204 Open,
206 IsZero,
208}
209
210#[derive(
213 Debug,
214 Clone,
215 Copy,
216 PartialEq,
217 Eq,
218 Hash,
219 Serialize,
220 Deserialize,
221 SchemaRead,
222 SchemaWrite,
223 PartialOrd,
224 Ord,
225)]
226#[repr(C)]
227pub enum FieldShareBinaryOp {
228 Add,
230 Mul,
232}
233
234#[derive(
236 Debug,
237 Clone,
238 Copy,
239 PartialEq,
240 Eq,
241 Hash,
242 Serialize,
243 Deserialize,
244 SchemaRead,
245 SchemaWrite,
246 PartialOrd,
247 Ord,
248)]
249#[repr(C)]
250pub enum BitShareUnaryOp {
251 Not,
253 Open,
255}
256
257#[derive(
259 Debug,
260 Clone,
261 Copy,
262 PartialEq,
263 Eq,
264 Hash,
265 Serialize,
266 Deserialize,
267 SchemaRead,
268 SchemaWrite,
269 PartialOrd,
270 Ord,
271)]
272#[repr(C)]
273pub enum BitShareBinaryOp {
274 Xor,
276 Or,
278 And,
280}
281
282#[derive(
284 Debug,
285 Clone,
286 Copy,
287 PartialEq,
288 Eq,
289 Hash,
290 Serialize,
291 Deserialize,
292 SchemaRead,
293 SchemaWrite,
294 PartialOrd,
295 Ord,
296)]
297#[repr(C)]
298pub enum BitPlaintextUnaryOp {
299 Not,
301}
302
303impl BitPlaintextUnaryOp {
304 pub fn eval(&self, x: Bit) -> Bit {
305 match self {
306 BitPlaintextUnaryOp::Not => Bit::ONE - x,
307 }
308 }
309}
310
311#[derive(
313 Debug,
314 Clone,
315 Copy,
316 PartialEq,
317 Eq,
318 Hash,
319 Serialize,
320 Deserialize,
321 SchemaRead,
322 SchemaWrite,
323 PartialOrd,
324 Ord,
325)]
326#[repr(C)]
327pub enum BitPlaintextBinaryOp {
328 Xor,
330 Or,
332 And,
334}
335
336impl BitPlaintextBinaryOp {
337 pub fn eval(&self, x: Bit, y: Bit) -> Bit {
338 match self {
339 BitPlaintextBinaryOp::Xor => x + y,
340 BitPlaintextBinaryOp::Or => x + y - x * y,
341 BitPlaintextBinaryOp::And => x * y,
342 }
343 }
344}
345
346#[derive(
348 Debug,
349 Clone,
350 Copy,
351 PartialEq,
352 Eq,
353 Hash,
354 Serialize,
355 Deserialize,
356 SchemaRead,
357 SchemaWrite,
358 PartialOrd,
359 Ord,
360)]
361#[repr(C)]
362pub enum PointPlaintextUnaryOp {
363 Neg,
365}
366
367impl PointPlaintextUnaryOp {
368 pub fn eval<C: Curve>(&self, x: &Point<C>) -> Result<Point<C>, AbortError> {
369 match self {
370 PointPlaintextUnaryOp::Neg => Ok(-x),
371 }
372 }
373}
374
375#[derive(
377 Debug,
378 Clone,
379 Copy,
380 PartialEq,
381 Eq,
382 Hash,
383 Serialize,
384 Deserialize,
385 SchemaRead,
386 SchemaWrite,
387 PartialOrd,
388 Ord,
389)]
390#[repr(C)]
391pub enum PointPlaintextBinaryOp {
392 Add,
394 ScalarMul,
396}
397
398impl PointPlaintextBinaryOp {
399 pub fn eval<C: Curve>(&self, x: &Point<C>, y: &Point<C>) -> Result<Point<C>, AbortError> {
400 match self {
401 PointPlaintextBinaryOp::Add => Ok(x + y),
402 PointPlaintextBinaryOp::ScalarMul => Err(AbortError::internal_error(
403 "PointPlaintextBinaryOp::eval not supported for PointPlaintextBinaryOp::ScalarMul.",
404 )),
405 }
406 }
407}
408
409#[derive(
411 Debug,
412 Clone,
413 Copy,
414 PartialEq,
415 Eq,
416 Hash,
417 Serialize,
418 Deserialize,
419 SchemaRead,
420 SchemaWrite,
421 PartialOrd,
422 Ord,
423)]
424#[repr(C)]
425pub enum PointShareUnaryOp {
426 Neg,
428 Open,
430 IsZero,
432}
433
434#[derive(
436 Debug,
437 Clone,
438 Copy,
439 PartialEq,
440 Eq,
441 Hash,
442 Serialize,
443 Deserialize,
444 SchemaRead,
445 SchemaWrite,
446 PartialOrd,
447 Ord,
448)]
449#[repr(C)]
450pub enum PointShareBinaryOp {
451 Add,
453 ScalarMul,
455}
456
457#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, SchemaRead, SchemaWrite)]
460#[repr(C)]
461pub enum Input {
462 Plaintext {
463 algebraic_type: AlgebraicType,
464 batch_size: BatchSize,
465 },
466 SecretPlaintext {
467 inputer: PeerNumber,
468 algebraic_type: AlgebraicType,
469 batch_size: BatchSize,
470 },
471 Share {
472 algebraic_type: AlgebraicType,
473 batch_size: BatchSize,
474 },
475}
476
477impl Input {
478 pub fn batch_size(&self) -> u32 {
479 match self {
480 Input::Plaintext { batch_size, .. }
481 | Input::SecretPlaintext { batch_size, .. }
482 | Input::Share { batch_size, .. } => *batch_size,
483 }
484 }
485
486 pub fn algebraic_type(&self) -> AlgebraicType {
487 match self {
488 Input::Plaintext { algebraic_type, .. }
489 | Input::Share { algebraic_type, .. }
490 | Input::SecretPlaintext { algebraic_type, .. } => *algebraic_type,
491 }
492 }
493
494 pub fn share_or_plaintext(&self) -> ShareOrPlaintext {
495 match self {
496 Input::SecretPlaintext { .. } | Input::Share { .. } => ShareOrPlaintext::Share,
497 Input::Plaintext { .. } => ShareOrPlaintext::Plaintext,
498 }
499 }
500}
501
502#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, SchemaRead, SchemaWrite)]
503#[serde(bound(
504 serialize = "Scalar<C>: Serialize, Point<C>: Serialize",
505 deserialize = "Scalar<C>: Deserialize<'de>, Point<C>: Deserialize<'de>"
506))]
507#[repr(C)]
508pub enum Constant<C: Curve> {
509 Scalar(Scalar<C>),
510 ScalarBatch(Vec<Scalar<C>>),
511 BaseField(BaseFieldElement<C>),
512 BaseFieldBatch(Vec<BaseFieldElement<C>>),
513 Mersenne107(Mersenne107Element),
514 Mersenne107Batch(Vec<Mersenne107Element>),
515 Bit(Bit),
516 BitBatch(Vec<Bit>),
517 Point(Point<C>),
518 PointBatch(Vec<Point<C>>),
519}
520
521impl<C: Curve> Constant<C> {
522 pub fn batch_size(&self) -> Result<u32, BatchSizeError> {
523 let n = match self {
524 Constant::ScalarBatch(v) => v.len(),
525 Constant::BaseFieldBatch(v) => v.len(),
526 Constant::Mersenne107Batch(v) => v.len(),
527 Constant::BitBatch(v) => v.len(),
528 Constant::PointBatch(v) => v.len(),
529 Constant::Scalar(_)
530 | Constant::BaseField(_)
531 | Constant::Mersenne107(_)
532 | Constant::Bit(_)
533 | Constant::Point(_) => 1,
534 };
535 if let Ok(n) = u32::try_from(n) {
536 Ok(n)
537 } else {
538 Err(BatchSizeError(n))
539 }
540 }
541
542 pub fn algebraic_type(&self) -> AlgebraicType {
543 match self {
544 Constant::Scalar(_) | Constant::ScalarBatch(_) => AlgebraicType::ScalarField,
545 Constant::BaseField(_) | Constant::BaseFieldBatch(_) => AlgebraicType::BaseField,
546 Constant::Mersenne107(_) | Constant::Mersenne107Batch(_) => AlgebraicType::Mersenne107,
547 Constant::Bit(_) | Constant::BitBatch(_) => AlgebraicType::Bit,
548 Constant::Point(_) | Constant::PointBatch(_) => AlgebraicType::Point,
549 }
550 }
551}
552
#[cfg(test)]
mod tests {
    use primitives::algebra::{
        elliptic_curve::{BaseField, Curve25519Ristretto as C, ScalarField},
        field::SubfieldElement,
    };

    use super::*;

    /// Scalar field of the test curve.
    type Sf = ScalarField<C>;
    /// Base field of the test curve.
    type Bf = BaseField<C>;
    /// Element of the test curve's scalar field.
    type Fe = SubfieldElement<Sf>;

    /// `Neg` and `MulInverse` agree with the element's own operators.
    #[test]
    fn test_scalar_unary_op() {
        let mut rng = rand::thread_rng();
        let x = Fe::random(&mut rng);
        let label = 0;

        let negated = FieldPlaintextUnaryOp::Neg.eval::<Sf>(label, &x);
        assert_eq!(negated, Ok(-x));

        let inverted = FieldPlaintextUnaryOp::MulInverse.eval::<Sf>(label, &x);
        assert_eq!(inverted, Ok(x.invert().unwrap()));
    }

    /// Every arithmetic and comparison operation matches its direct computation.
    #[test]
    fn test_scalar_binary_op() {
        let mut rng = rand::thread_rng();
        let x = Fe::random(&mut rng);
        let y = Fe::random(&mut rng);
        let label = 0;

        assert_eq!(
            FieldPlaintextBinaryOp::Add.eval::<Sf>(&x, &y, label),
            Ok(x + y)
        );
        assert_eq!(
            FieldPlaintextBinaryOp::Mul.eval::<Sf>(&x, &y, label),
            Ok(x * y)
        );
        assert_eq!(
            FieldPlaintextBinaryOp::EuclDiv.eval::<Sf>(&x, &y, label),
            euclidean_division::<Sf>(&x, &y, label)
        );
        assert_eq!(
            FieldPlaintextBinaryOp::Mod.eval::<Sf>(&x, &y, label),
            modulo::<Sf>(&x, &y, label)
        );
        assert_eq!(
            FieldPlaintextBinaryOp::Gt.eval::<Sf>(&x, &y, label),
            Ok(Fe::from(x > y))
        );
        assert_eq!(
            FieldPlaintextBinaryOp::Ge.eval::<Sf>(&x, &y, label),
            Ok(Fe::from(x >= y))
        );
        assert_eq!(
            FieldPlaintextBinaryOp::Eq.eval::<Sf>(&x, &y, label),
            Ok(Fe::from(x == y))
        );
    }

    /// On 0/1 inputs, `Mul`, `Or` and `Xor` reproduce the boolean truth tables.
    #[test]
    fn test_scalar_boolean_binary_op() {
        let label = 0;
        let truth_table = [(false, false), (false, true), (true, false), (true, true)];

        for (bool_x, bool_y) in truth_table {
            let scalar_x = Fe::from(bool_x);
            let scalar_y = Fe::from(bool_y);

            let conjunction = FieldPlaintextBinaryOp::Mul.eval::<Sf>(&scalar_x, &scalar_y, label);
            assert_eq!(conjunction, Ok((bool_x && bool_y).into()));

            let disjunction = FieldPlaintextBinaryOp::Or.eval::<Sf>(&scalar_x, &scalar_y, label);
            assert_eq!(disjunction, Ok((bool_x || bool_y).into()));

            let exclusive = FieldPlaintextBinaryOp::Xor.eval::<Sf>(&scalar_x, &scalar_y, label);
            assert_eq!(exclusive, Ok((bool_x ^ bool_y).into()));
        }
    }

    /// Plaintext bit operations reproduce the boolean truth tables.
    #[test]
    fn test_bit_ops() {
        for bool_x in [false, true] {
            let negated = BitPlaintextUnaryOp::Not.eval(Bit::from(bool_x));
            assert_eq!(negated, (!bool_x).into());
        }

        let truth_table = [(false, false), (false, true), (true, false), (true, true)];
        for (bool_x, bool_y) in truth_table {
            let x = Bit::from(bool_x);
            let y = Bit::from(bool_y);

            assert_eq!(
                BitPlaintextBinaryOp::And.eval(x, y),
                (bool_x && bool_y).into()
            );
            assert_eq!(
                BitPlaintextBinaryOp::Or.eval(x, y),
                (bool_x || bool_y).into()
            );
            assert_eq!(
                BitPlaintextBinaryOp::Xor.eval(x, y),
                (bool_x ^ bool_y).into()
            );
        }
    }

    /// 37 / 12 yields the integer quotient.
    #[test]
    fn test_euclidian_division() {
        let label = 0;
        let dividend = Fe::from(37u32);
        let divisor = Fe::from(12u32);

        let quotient = euclidean_division::<Sf>(&dividend, &divisor, label).unwrap();
        assert_eq!(quotient, Fe::from(37u32 / 12));
    }

    /// 37 mod 12 yields the integer remainder.
    #[test]
    fn test_modulo() {
        let label = 0;
        let dividend = Fe::from(37u32);
        let divisor = Fe::from(12u32);

        let remainder = modulo::<Sf>(&dividend, &divisor, label).unwrap();
        assert_eq!(remainder, Fe::from(37u32 % 12));
    }

    /// Signed extraction on -9 matches the low bits of `-9i32`.
    #[test]
    fn test_signed_bit_extract() {
        let x = -Scalar::<C>::from(9u32);
        let label = 0;

        for i in 0..5 {
            let extract = FieldPlaintextUnaryOp::BitExtract {
                little_endian_bit_idx: i,
                signed: true,
            };
            let expected_bit = (-9i32 >> i) & 1 == 1;

            let extracted = extract.eval::<Sf>(label, &x).unwrap();
            assert_eq!(extracted, expected_bit.into())
        }
    }

    /// The square root of a square squares back to the same value.
    #[test]
    fn test_sqrt() {
        let mut rng = rand::thread_rng();
        let x = Fe::random(&mut rng);
        let label = 0;

        let square = x * x;
        let root = FieldPlaintextUnaryOp::Sqrt
            .eval::<Sf>(label, &square)
            .unwrap();

        assert_eq!(root * root, x * x)
    }

    /// Raising to 5 and then to the second exponent below returns the
    /// original element.
    #[test]
    fn test_pow() {
        let mut rng = rand::thread_rng();
        let x = SubfieldElement::<Bf>::random(&mut rng);
        let label = 0;

        let five = BoxedUint::from(vec![5u64]);
        let five_inv = BoxedUint::from(vec![
            14757395258967641281,
            14757395258967641292,
            14757395258967641292,
            5534023222112865484,
        ]);

        let raise = FieldPlaintextUnaryOp::Pow { exp: five };
        let undo = FieldPlaintextUnaryOp::Pow { exp: five_inv };

        let x_pow_5 = raise.eval::<Bf>(label, &x).unwrap();
        let x_again = undo.eval::<Bf>(label, &x_pow_5).unwrap();

        assert_eq!(x_again, x)
    }
}