1use std::iter::repeat_n;
2
3use ff::Field;
4use num_traits::{One, Zero};
5use primitives::{
6 algebra::{
7 elliptic_curve::{Curve, Point, Scalar},
8 field::{FieldExtension, SubfieldElement},
9 BoxedUint,
10 },
11 types::PeerNumber,
12};
13use serde::{Deserialize, Serialize};
14use typenum::Unsigned;
15
16use crate::{
17 circuit::{
18 AlgebraicType,
19 BaseFieldPlaintext,
20 BaseFieldPlaintextBatch,
21 BitPlaintext,
22 BitPlaintextBatch,
23 PointPlaintext,
24 PointPlaintextBatch,
25 ScalarPlaintext,
26 ScalarPlaintextBatch,
27 },
28 errors::{AbortError, FaultyPeer},
29 types::Label,
30};
31
/// Unary operations on plaintext field elements, evaluated by
/// [`FieldPlaintextUnaryOp::eval`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum FieldPlaintextUnaryOp {
    /// Additive negation, `-x`.
    Neg,
    /// Multiplicative inverse `1 / x`; zero is mapped to zero instead of aborting.
    MulInverse,
    /// Extracts one bit of the integer representative of `x`, returned as the field
    /// element 0 or 1.
    BitExtract {
        /// Bit position, counted from the least significant bit (index 0).
        little_endian_bit_idx: u16,
        /// When `true`, elements with `x > -x` are read as negative integers and the bit is
        /// taken from their two's-complement representation.
        signed: bool,
    },
    /// A square root of `x`; evaluation fails if `x` is a quadratic non-residue.
    Sqrt,
    /// Exponentiation `x^exp`.
    Pow {
        /// The exponent.
        exp: BoxedUint,
    },
}
48
49impl FieldPlaintextUnaryOp {
50 pub fn eval<F: FieldExtension>(
52 &self,
53 label: Label,
54 x: SubfieldElement<F>,
55 ) -> Result<SubfieldElement<F>, AbortError> {
56 match self {
57 FieldPlaintextUnaryOp::Neg => Ok(-x),
58 FieldPlaintextUnaryOp::MulInverse => {
59 if x == SubfieldElement::<F>::zero() {
60 Ok(SubfieldElement::<F>::zero())
61 } else {
62 Ok(x.invert().unwrap())
63 }
64 }
65 FieldPlaintextUnaryOp::BitExtract {
66 little_endian_bit_idx: idx,
67 signed,
68 } => {
69 let bit = if *signed && x > -x {
70 !(-SubfieldElement::<F>::one() - x)
71 .to_biguint()
72 .bit(*idx as u64)
73 } else {
74 x.to_biguint().bit(*idx as u64)
75 };
76 Ok(SubfieldElement::<F>::from(bit))
77 }
78 FieldPlaintextUnaryOp::Sqrt => {
79 let (choice, sqrt) =
80 SubfieldElement::<F>::sqrt_ratio(&x, &SubfieldElement::<F>::one());
81 if !bool::from(choice) {
82 return Err(AbortError::quadratic_non_residue(label, FaultyPeer::Local));
83 }
84 Ok(sqrt)
85 }
86 FieldPlaintextUnaryOp::Pow { exp } => Ok(x.pow(exp)),
87 }
88 }
89}
90
91#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
93pub enum FieldPlaintextBinaryOp {
94 Add,
95 Mul,
96 EuclDiv,
97 Mod,
98 Gt,
99 Ge,
100 Eq,
101 Xor,
102 Or,
103}
104
105impl FieldPlaintextBinaryOp {
106 pub fn eval<F: FieldExtension>(
107 &self,
108 x: SubfieldElement<F>,
109 y: SubfieldElement<F>,
110 label: Label,
111 ) -> Result<SubfieldElement<F>, AbortError> {
112 match self {
113 FieldPlaintextBinaryOp::Add => Ok(x + y),
114 FieldPlaintextBinaryOp::Mul => Ok(x * y),
115 FieldPlaintextBinaryOp::EuclDiv => euclidean_division::<F>(x, y, label),
116 FieldPlaintextBinaryOp::Mod => modulo::<F>(x, y, label),
117 FieldPlaintextBinaryOp::Gt => Ok(SubfieldElement::<F>::from(x > y)),
118 FieldPlaintextBinaryOp::Ge => Ok(SubfieldElement::<F>::from(x >= y)),
119 FieldPlaintextBinaryOp::Eq => Ok(SubfieldElement::<F>::from(x == y)),
120 FieldPlaintextBinaryOp::Xor => Ok(x + y - SubfieldElement::<F>::from(2u32) * x * y),
121 FieldPlaintextBinaryOp::Or => Ok(x + y - x * y),
122 }
123 }
124}
125
126pub(crate) fn euclidean_division<F: FieldExtension>(
127 x: SubfieldElement<F>,
128 y: SubfieldElement<F>,
129 label: Label,
130) -> Result<SubfieldElement<F>, AbortError> {
131 if y == SubfieldElement::<F>::zero() {
132 return Err(AbortError::division_by_zero(label, FaultyPeer::Local));
133 }
134
135 let x = x.to_biguint();
137 let y = y.to_biguint();
138
139 let div = (x / y).to_bytes_be();
140 let div = repeat_n(0, F::FieldBytesSize::USIZE - div.len())
142 .chain(div)
143 .collect::<Vec<_>>();
144
145 Ok(SubfieldElement::<F>::from_be_bytes(&div).unwrap())
146}
147
148fn modulo<F: FieldExtension>(
149 x: SubfieldElement<F>,
150 y: SubfieldElement<F>,
151 label: Label,
152) -> Result<SubfieldElement<F>, AbortError> {
153 if y == SubfieldElement::<F>::zero() {
154 return Err(AbortError::division_by_zero(label, FaultyPeer::Local));
155 }
156
157 let x = x.to_biguint();
159 let y = y.to_biguint();
160
161 let modulo = x.modpow(&num_bigint::BigUint::from(1u32), &y).to_bytes_be();
162 let modulo = repeat_n(0, F::FieldBytesSize::USIZE - modulo.len())
164 .chain(modulo)
165 .collect::<Vec<_>>();
166
167 Ok(SubfieldElement::<F>::from_be_bytes(&modulo).unwrap())
168}
169
170#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
172pub enum FieldShareUnaryOp {
173 Neg,
175 MulInverse,
177 Open,
179 IsZero,
181}
182
183#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
186pub enum FieldShareBinaryOp {
187 Add,
189 Mul,
191}
192
193#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
195pub enum BitShareUnaryOp {
196 Not,
198 Open,
200}
201
202#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
204pub enum BitShareBinaryOp {
205 Xor,
207 Or,
209 And,
211}
212
213#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
215pub enum PointPlaintextUnaryOp {
216 Neg,
218}
219
220impl PointPlaintextUnaryOp {
221 pub fn eval<C: Curve>(&self, x: Point<C>) -> Result<Point<C>, AbortError> {
222 match self {
223 PointPlaintextUnaryOp::Neg => Ok(-x),
224 }
225 }
226}
227
228#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
230pub enum PointPlaintextBinaryOp {
231 Add,
233 ScalarMul,
235}
236
237impl PointPlaintextBinaryOp {
238 pub fn eval<C: Curve>(&self, x: Point<C>, y: Point<C>) -> Result<Point<C>, AbortError> {
239 match self {
240 PointPlaintextBinaryOp::Add => Ok(x + y),
241 PointPlaintextBinaryOp::ScalarMul => Err(AbortError::internal_error(
242 "PointPlaintextBinaryOp::eval not supported for PointPlaintextBinaryOp::ScalarMul.",
243 )),
244 }
245 }
246}
247
248#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
250pub enum PointShareUnaryOp {
251 Neg,
253 Open,
255 IsZero,
257}
258
259#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
261pub enum PointShareBinaryOp {
262 Add,
264 ScalarMul,
266}
267
/// An input: either a descriptor for a value that is provided or generated elsewhere
/// (`SecretPlaintext`, `Share`, `RandomShare`), or a plaintext value carried inline.
///
/// The explicit `serde(bound(...))` replaces serde's default `C: Serialize` /
/// `C: Deserialize<'de>` bounds: only `Scalar<C>` and `Point<C>` must be (de)serializable,
/// not `C` itself.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(bound(
    serialize = "Scalar<C>: Serialize, Point<C>: Serialize",
    deserialize = "Scalar<C>: Deserialize<'de>, Point<C>: Deserialize<'de>"
))]
pub enum Input<C: Curve> {
    /// A plaintext value supplied by peer `inputer` (presumably kept secret from the other
    /// peers — confirm against the protocol).
    SecretPlaintext {
        inputer: PeerNumber,
        algebraic_type: AlgebraicType,
        batched: Batched,
    },
    /// A value that is already held as a share (presumably secret-shared across peers).
    Share {
        algebraic_type: AlgebraicType,
        batched: Batched,
    },
    /// A share of a random value (presumably generated by the protocol — confirm).
    RandomShare {
        algebraic_type: AlgebraicType,
        batched: Batched,
    },
    /// A single plaintext scalar carried inline.
    Scalar(ScalarPlaintext<C>),
    /// A batch of plaintext scalars carried inline.
    ScalarBatch(ScalarPlaintextBatch<C>),
    /// A single plaintext base-field element carried inline.
    BaseField(BaseFieldPlaintext<C>),
    /// A batch of plaintext base-field elements carried inline.
    BaseFieldBatch(BaseFieldPlaintextBatch<C>),
    /// A single plaintext bit carried inline.
    Bit(BitPlaintext),
    /// A batch of plaintext bits carried inline.
    BitBatch(BitPlaintextBatch),
    /// A single plaintext curve point carried inline.
    Point(PointPlaintext<C>),
    /// A batch of plaintext curve points carried inline.
    PointBatch(PointPlaintextBatch<C>),
    /// A plaintext ElGamal ciphertext given by its two point components; always unbatched.
    ElGamalCiphertext {
        c: PointPlaintext<C>,
        r: PointPlaintext<C>,
    },
}
300
301impl<C: Curve> Input<C> {
302 pub fn batched(&self) -> Batched {
303 match self {
304 Input::SecretPlaintext { batched, .. } => *batched,
305 Input::Share { batched, .. } => *batched,
306 Input::RandomShare { batched, .. } => *batched,
307 Input::ScalarBatch(input) => Batched::Yes(input.len()),
308 Input::BaseFieldBatch(input) => Batched::Yes(input.len()),
309 Input::BitBatch(input) => Batched::Yes(input.len()),
310 Input::PointBatch(input) => Batched::Yes(input.len()),
311 Input::ElGamalCiphertext { .. } => Batched::No,
312 Input::Scalar { .. } => Batched::No,
313 Input::BaseField { .. } => Batched::No,
314 Input::Bit { .. } => Batched::No,
315 Input::Point { .. } => Batched::No,
316 }
317 }
318}
319
/// Whether a value is a batch and, if so, how many elements it holds.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum Batched {
    /// A batch of the given number of elements.
    Yes(usize),
    /// A single, unbatched element (counted as one by [`Batched::count`]).
    No,
}
326
327impl Batched {
328 pub fn count(&self) -> usize {
329 match self {
330 Batched::Yes(count) => *count,
331 Batched::No => 1,
332 }
333 }
334
335 pub fn is_batched(&self) -> bool {
336 match self {
337 Batched::Yes(_) => true,
338 Batched::No => false,
339 }
340 }
341}
342
#[cfg(test)]
mod tests {
    //! Unit tests for the plaintext operation evaluators and the integer division helpers.

    use primitives::algebra::{
        elliptic_curve::{BaseField, Curve25519Ristretto as C, ScalarField},
        field::SubfieldElement,
    };

    use super::*;

    #[test]
    fn test_scalar_unary_op() {
        let mut rng = rand::thread_rng();
        let x = SubfieldElement::<ScalarField<C>>::random(&mut rng);
        let label = Label::task(0);
        let neg = FieldPlaintextUnaryOp::Neg;
        let mul_inverse = FieldPlaintextUnaryOp::MulInverse;

        assert_eq!(neg.eval::<ScalarField<C>>(label, x), Ok(-x));
        assert_eq!(
            mul_inverse.eval::<ScalarField<C>>(label, x),
            Ok(x.invert().unwrap())
        );
    }

    #[test]
    fn test_mul_inverse_of_zero_is_zero() {
        let label = Label::task(0);
        let zero = SubfieldElement::<ScalarField<C>>::from(0u32);

        assert_eq!(
            FieldPlaintextUnaryOp::MulInverse.eval::<ScalarField<C>>(label, zero),
            Ok(zero)
        );
    }

    #[test]
    fn test_scalar_binary_op() {
        let mut rng = rand::thread_rng();
        let x = SubfieldElement::<ScalarField<C>>::random(&mut rng);
        let y = SubfieldElement::<ScalarField<C>>::random(&mut rng);
        let label = Label::task(0);

        let add = FieldPlaintextBinaryOp::Add;
        let mul = FieldPlaintextBinaryOp::Mul;
        let eucl_div = FieldPlaintextBinaryOp::EuclDiv;
        let modulo_op = FieldPlaintextBinaryOp::Mod;
        let gt = FieldPlaintextBinaryOp::Gt;
        let ge = FieldPlaintextBinaryOp::Ge;
        let eq = FieldPlaintextBinaryOp::Eq;

        assert_eq!(add.eval::<ScalarField<C>>(x, y, label), Ok(x + y));
        assert_eq!(mul.eval::<ScalarField<C>>(x, y, label), Ok(x * y));
        assert_eq!(
            eucl_div.eval::<ScalarField<C>>(x, y, label),
            euclidean_division::<ScalarField<C>>(x, y, label)
        );
        assert_eq!(
            modulo_op.eval::<ScalarField<C>>(x, y, label),
            modulo::<ScalarField<C>>(x, y, label)
        );
        assert_eq!(
            gt.eval::<ScalarField<C>>(x, y, label),
            Ok(SubfieldElement::<ScalarField<C>>::from(x > y))
        );
        assert_eq!(
            ge.eval::<ScalarField<C>>(x, y, label),
            Ok(SubfieldElement::<ScalarField<C>>::from(x >= y))
        );
        assert_eq!(
            eq.eval::<ScalarField<C>>(x, y, label),
            Ok(SubfieldElement::<ScalarField<C>>::from(x == y))
        );
    }

    #[test]
    fn test_boolean_binary_op() {
        let and = FieldPlaintextBinaryOp::Mul;
        let or = FieldPlaintextBinaryOp::Or;
        let xor = FieldPlaintextBinaryOp::Xor;
        let label = Label::task(0);
        for bool_x in [false, true] {
            for bool_y in [false, true] {
                let scalar_x = SubfieldElement::<ScalarField<C>>::from(bool_x);
                let scalar_y = SubfieldElement::<ScalarField<C>>::from(bool_y);
                assert_eq!(
                    and.eval::<ScalarField<C>>(scalar_x, scalar_y, label),
                    Ok((bool_x && bool_y).into())
                );
                assert_eq!(
                    or.eval::<ScalarField<C>>(scalar_x, scalar_y, label),
                    Ok((bool_x || bool_y).into())
                );
                assert_eq!(
                    xor.eval::<ScalarField<C>>(scalar_x, scalar_y, label),
                    Ok((bool_x ^ bool_y).into())
                );
            }
        }
    }

    #[test]
    fn test_euclidean_division() {
        let x = SubfieldElement::<ScalarField<C>>::from(37u32);
        let y = SubfieldElement::<ScalarField<C>>::from(12u32);
        let label = Label::task(0);

        let result = euclidean_division::<ScalarField<C>>(x, y, label).unwrap();
        assert_eq!(result, SubfieldElement::<ScalarField<C>>::from(37u32 / 12));

        // A dividend smaller than the divisor yields a zero quotient.
        let result = euclidean_division::<ScalarField<C>>(y, x, label).unwrap();
        assert_eq!(result, SubfieldElement::<ScalarField<C>>::from(0u32));
    }

    #[test]
    fn test_modulo() {
        let x = SubfieldElement::<ScalarField<C>>::from(37u32);
        let y = SubfieldElement::<ScalarField<C>>::from(12u32);
        let label = Label::task(0);

        let result = modulo::<ScalarField<C>>(x, y, label).unwrap();
        assert_eq!(result, SubfieldElement::<ScalarField<C>>::from(37u32 % 12));

        // A dividend smaller than the divisor is its own remainder.
        assert_eq!(modulo::<ScalarField<C>>(y, x, label).unwrap(), y);

        // An exact multiple leaves a zero remainder.
        let thirty_six = SubfieldElement::<ScalarField<C>>::from(36u32);
        assert_eq!(
            modulo::<ScalarField<C>>(thirty_six, y, label).unwrap(),
            SubfieldElement::<ScalarField<C>>::from(0u32)
        );
    }

    #[test]
    fn test_division_by_zero_is_rejected() {
        let x = SubfieldElement::<ScalarField<C>>::from(37u32);
        let zero = SubfieldElement::<ScalarField<C>>::from(0u32);
        let label = Label::task(0);

        assert!(FieldPlaintextBinaryOp::EuclDiv
            .eval::<ScalarField<C>>(x, zero, label)
            .is_err());
        assert!(FieldPlaintextBinaryOp::Mod
            .eval::<ScalarField<C>>(x, zero, label)
            .is_err());
    }

    #[test]
    fn test_signed_bit_extract() {
        let x = -Scalar::<C>::from(9u32);
        let label = Label::task(0);
        for i in 0..5 {
            let op = FieldPlaintextUnaryOp::BitExtract {
                little_endian_bit_idx: i,
                signed: true,
            };
            let result = op.eval::<ScalarField<C>>(label, x);
            assert_eq!(result.unwrap(), ((-9i32 >> i) & 1 == 1).into())
        }
    }

    #[test]
    fn test_signed_bit_extract_of_non_negative_value() {
        let x = SubfieldElement::<ScalarField<C>>::from(9u32);
        let label = Label::task(0);
        for i in 0..5 {
            let op = FieldPlaintextUnaryOp::BitExtract {
                little_endian_bit_idx: i,
                signed: true,
            };
            let result = op.eval::<ScalarField<C>>(label, x);
            assert_eq!(result.unwrap(), ((9i32 >> i) & 1 == 1).into())
        }
    }

    #[test]
    fn test_unsigned_bit_extract() {
        let x = SubfieldElement::<ScalarField<C>>::from(0b1011u32);
        let label = Label::task(0);
        for i in 0..8 {
            let op = FieldPlaintextUnaryOp::BitExtract {
                little_endian_bit_idx: i,
                signed: false,
            };
            let result = op.eval::<ScalarField<C>>(label, x);
            assert_eq!(result.unwrap(), ((0b1011u32 >> i) & 1 == 1).into())
        }
    }

    #[test]
    fn test_sqrt() {
        let mut rng = rand::thread_rng();
        let x = SubfieldElement::<ScalarField<C>>::random(&mut rng);
        let label = Label::task(0);
        let result = FieldPlaintextUnaryOp::Sqrt
            .eval::<ScalarField<C>>(label, x * x)
            .unwrap();

        assert_eq!(result * result, x * x)
    }

    #[test]
    fn test_sqrt_of_non_residue_fails() {
        // The Ristretto255 scalar modulus is congruent to 5 mod 8, so 2 is a quadratic
        // non-residue and has no square root.
        let two = SubfieldElement::<ScalarField<C>>::from(2u32);
        let label = Label::task(0);

        assert!(FieldPlaintextUnaryOp::Sqrt
            .eval::<ScalarField<C>>(label, two)
            .is_err());
    }

    #[test]
    fn test_pow() {
        let mut rng = rand::thread_rng();
        let x = SubfieldElement::<BaseField<C>>::random(&mut rng);
        let label = Label::task(0);
        let five = BoxedUint::from(vec![5u64]);
        let five_inv = BoxedUint::from(vec![
            14757395258967641281,
            14757395258967641292,
            14757395258967641292,
            5534023222112865484,
        ]);
        let x_pow_5 = FieldPlaintextUnaryOp::Pow { exp: five }
            .eval::<BaseField<C>>(label, x)
            .unwrap();
        let x_again = FieldPlaintextUnaryOp::Pow { exp: five_inv }
            .eval::<BaseField<C>>(label, x_pow_5)
            .unwrap();

        assert_eq!(x_again, x)
    }
}