1use std::iter::repeat_n;
2
3use ff::Field;
4use num_traits::{One, Zero};
5use primitives::{
6 algebra::{
7 elliptic_curve::{Curve, Point, Scalar},
8 field::{FieldExtension, SubfieldElement},
9 BoxedUint,
10 },
11 types::PeerNumber,
12};
13use serde::{Deserialize, Serialize};
14use typenum::Unsigned;
15
16use crate::{
17 circuit::{
18 AlgebraicType,
19 BaseFieldPlaintext,
20 BaseFieldPlaintextBatch,
21 BitPlaintext,
22 BitPlaintextBatch,
23 Mersenne107Plaintext,
24 Mersenne107PlaintextBatch,
25 PointPlaintext,
26 PointPlaintextBatch,
27 ScalarPlaintext,
28 ScalarPlaintextBatch,
29 },
30 errors::{AbortError, FaultyPeer},
31 types::Label,
32};
33
/// Unary operations on plaintext field elements, evaluated by
/// [`FieldPlaintextUnaryOp::eval`].
///
/// Keep the variant order stable: the derived `Hash` and serde's variant indices follow the
/// declaration order.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum FieldPlaintextUnaryOp {
    /// Additive inverse, `-x`.
    Neg,
    /// Multiplicative inverse; `eval` maps zero to zero instead of failing.
    MulInverse,
    /// Extracts a single bit of the element's integer representative, returned as 0 or 1.
    BitExtract {
        /// Bit position, counting from the least significant bit (index 0).
        little_endian_bit_idx: u16,
        /// When set, elements with `x > -x` are interpreted as negative numbers and the bit is
        /// taken from their two's-complement representation.
        signed: bool,
    },
    /// Square root; `eval` returns an error when the element is a quadratic non-residue.
    Sqrt,
    /// Exponentiation, `x^exp`.
    Pow {
        /// Exponent handed to the field's `pow`; presumably little-endian `u64` limbs
        /// (see `test_pow`) — confirm against `BoxedUint`.
        exp: BoxedUint,
    },
}
50
51impl FieldPlaintextUnaryOp {
52 pub fn eval<F: FieldExtension>(
54 &self,
55 label: Label,
56 x: SubfieldElement<F>,
57 ) -> Result<SubfieldElement<F>, AbortError> {
58 match self {
59 FieldPlaintextUnaryOp::Neg => Ok(-x),
60 FieldPlaintextUnaryOp::MulInverse => {
61 if x == SubfieldElement::<F>::zero() {
62 Ok(SubfieldElement::<F>::zero())
63 } else {
64 Ok(x.invert().unwrap())
65 }
66 }
67 FieldPlaintextUnaryOp::BitExtract {
68 little_endian_bit_idx: idx,
69 signed,
70 } => {
71 let bit = if *signed && x > -x {
72 !(-SubfieldElement::<F>::one() - x)
73 .to_biguint()
74 .bit(*idx as u64)
75 } else {
76 x.to_biguint().bit(*idx as u64)
77 };
78 Ok(SubfieldElement::<F>::from(bit))
79 }
80 FieldPlaintextUnaryOp::Sqrt => {
81 let (choice, sqrt) =
82 SubfieldElement::<F>::sqrt_ratio(&x, &SubfieldElement::<F>::one());
83 if !bool::from(choice) {
84 return Err(AbortError::quadratic_non_residue(label, FaultyPeer::Local));
85 }
86 Ok(sqrt)
87 }
88 FieldPlaintextUnaryOp::Pow { exp } => Ok(x.pow(exp)),
89 }
90 }
91}
92
/// Binary operations on plaintext field elements, evaluated by
/// [`FieldPlaintextBinaryOp::eval`].
///
/// Comparisons and the boolean operations use the elements 0 and 1 for false and true.
/// Keep the variant order stable: the derived `Hash` and serde's variant indices follow it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum FieldPlaintextBinaryOp {
    /// Field addition.
    Add,
    /// Field multiplication (also acts as AND on 0/1 inputs).
    Mul,
    /// Integer floor division of the `to_biguint` representatives; errors on a zero divisor.
    EuclDiv,
    /// Integer remainder of the `to_biguint` representatives; errors on a zero divisor.
    Mod,
    /// `x > y` as 0/1, using the element type's `PartialOrd`.
    Gt,
    /// `x >= y` as 0/1, using the element type's `PartialOrd`.
    Ge,
    /// `x == y` as 0/1.
    Eq,
    /// Exclusive or, computed as `x + y - 2xy`; only meaningful for 0/1 inputs.
    Xor,
    /// Inclusive or, computed as `x + y - xy`; only meaningful for 0/1 inputs.
    Or,
}
106
107impl FieldPlaintextBinaryOp {
108 pub fn eval<F: FieldExtension>(
109 &self,
110 x: SubfieldElement<F>,
111 y: SubfieldElement<F>,
112 label: Label,
113 ) -> Result<SubfieldElement<F>, AbortError> {
114 match self {
115 FieldPlaintextBinaryOp::Add => Ok(x + y),
116 FieldPlaintextBinaryOp::Mul => Ok(x * y),
117 FieldPlaintextBinaryOp::EuclDiv => euclidean_division::<F>(x, y, label),
118 FieldPlaintextBinaryOp::Mod => modulo::<F>(x, y, label),
119 FieldPlaintextBinaryOp::Gt => Ok(SubfieldElement::<F>::from(x > y)),
120 FieldPlaintextBinaryOp::Ge => Ok(SubfieldElement::<F>::from(x >= y)),
121 FieldPlaintextBinaryOp::Eq => Ok(SubfieldElement::<F>::from(x == y)),
122 FieldPlaintextBinaryOp::Xor => Ok(x + y - SubfieldElement::<F>::from(2u32) * x * y),
123 FieldPlaintextBinaryOp::Or => Ok(x + y - x * y),
124 }
125 }
126}
127
128pub(crate) fn euclidean_division<F: FieldExtension>(
129 x: SubfieldElement<F>,
130 y: SubfieldElement<F>,
131 label: Label,
132) -> Result<SubfieldElement<F>, AbortError> {
133 if y == SubfieldElement::<F>::zero() {
134 return Err(AbortError::division_by_zero(label, FaultyPeer::Local));
135 }
136
137 let x = x.to_biguint();
139 let y = y.to_biguint();
140
141 let div = (x / y).to_bytes_be();
142 let div = repeat_n(0, F::FieldBytesSize::USIZE - div.len())
144 .chain(div)
145 .collect::<Vec<_>>();
146
147 Ok(SubfieldElement::<F>::from_be_bytes(&div).unwrap())
148}
149
150fn modulo<F: FieldExtension>(
151 x: SubfieldElement<F>,
152 y: SubfieldElement<F>,
153 label: Label,
154) -> Result<SubfieldElement<F>, AbortError> {
155 if y == SubfieldElement::<F>::zero() {
156 return Err(AbortError::division_by_zero(label, FaultyPeer::Local));
157 }
158
159 let x = x.to_biguint();
161 let y = y.to_biguint();
162
163 let modulo = x.modpow(&num_bigint::BigUint::from(1u32), &y).to_bytes_be();
164 let modulo = repeat_n(0, F::FieldBytesSize::USIZE - modulo.len())
166 .chain(modulo)
167 .collect::<Vec<_>>();
168
169 Ok(SubfieldElement::<F>::from_be_bytes(&modulo).unwrap())
170}
171
/// Unary operations on secret-shared field elements.
///
/// No evaluation logic is visible in this file; NOTE(review): confirm semantics against the
/// protocol implementation. Keep the variant order stable (derived `Hash` / serde indices).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum FieldShareUnaryOp {
    /// Negation of the shared value.
    Neg,
    /// Multiplicative inverse of the shared value (behavior on zero not visible here).
    MulInverse,
    /// Presumably reveals the shared value as a plaintext — confirm.
    Open,
    /// Presumably a shared 0/1 indicator of whether the value is zero — confirm.
    IsZero,
}
184
/// Binary operations on secret-shared field elements.
///
/// Keep the variant order stable (derived `Hash` / serde indices).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum FieldShareBinaryOp {
    /// Addition of two shared values.
    Add,
    /// Multiplication of two shared values.
    Mul,
}
194
/// Unary operations on secret-shared bits.
///
/// Keep the variant order stable (derived `Hash` / serde indices).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum BitShareUnaryOp {
    /// Logical negation of the shared bit.
    Not,
    /// Presumably reveals the shared bit as a plaintext — confirm.
    Open,
}
203
/// Binary operations on secret-shared bits.
///
/// Keep the variant order stable (derived `Hash` / serde indices).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum BitShareBinaryOp {
    /// Exclusive or.
    Xor,
    /// Inclusive or.
    Or,
    /// Logical and.
    And,
}
214
/// Unary operations on plaintext curve points, evaluated by [`PointPlaintextUnaryOp::eval`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum PointPlaintextUnaryOp {
    /// Point negation, `-P`.
    Neg,
}
221
222impl PointPlaintextUnaryOp {
223 pub fn eval<C: Curve>(&self, x: Point<C>) -> Result<Point<C>, AbortError> {
224 match self {
225 PointPlaintextUnaryOp::Neg => Ok(-x),
226 }
227 }
228}
229
/// Binary operations on plaintext curve points.
///
/// Keep the variant order stable (derived `Hash` / serde indices).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum PointPlaintextBinaryOp {
    /// Point addition, `P + Q`.
    Add,
    /// Scalar multiplication; not supported by [`PointPlaintextBinaryOp::eval`], which returns
    /// an internal error for it.
    ScalarMul,
}
238
239impl PointPlaintextBinaryOp {
240 pub fn eval<C: Curve>(&self, x: Point<C>, y: Point<C>) -> Result<Point<C>, AbortError> {
241 match self {
242 PointPlaintextBinaryOp::Add => Ok(x + y),
243 PointPlaintextBinaryOp::ScalarMul => Err(AbortError::internal_error(
244 "PointPlaintextBinaryOp::eval not supported for PointPlaintextBinaryOp::ScalarMul.",
245 )),
246 }
247 }
248}
249
/// Unary operations on secret-shared curve points.
///
/// Keep the variant order stable (derived `Hash` / serde indices).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum PointShareUnaryOp {
    /// Negation of the shared point.
    Neg,
    /// Presumably reveals the shared point as a plaintext — confirm.
    Open,
    /// Presumably a shared indicator of whether the point is the identity — confirm.
    IsZero,
}
260
/// Binary operations on secret-shared curve points.
///
/// Keep the variant order stable (derived `Hash` / serde indices).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum PointShareBinaryOp {
    /// Addition of two shared points.
    Add,
    /// Scalar multiplication of a shared point (operand layout not visible here).
    ScalarMul,
}
269
/// An input value description, generic over the curve `C`.
///
/// The struct-like variants describe the type and batching of a value that is not carried
/// here. The tuple variants carry a concrete plaintext (presumably a public constant —
/// confirm). Keep the variant order stable (derived `Hash` / serde indices).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(bound(
    serialize = "Scalar<C>: Serialize, Point<C>: Serialize",
    deserialize = "Scalar<C>: Deserialize<'de>, Point<C>: Deserialize<'de>"
))]
#[derive(Hash)]
pub enum Input<C: Curve> {
    /// A plaintext supplied by peer `inputer` (presumably secret-shared by that peer before
    /// use — confirm).
    SecretPlaintext {
        inputer: PeerNumber,
        algebraic_type: AlgebraicType,
        batched: Batched,
    },
    /// Presumably an already secret-shared value of the given type — confirm.
    Share {
        algebraic_type: AlgebraicType,
        batched: Batched,
    },
    /// Presumably a fresh random shared value of the given type — confirm.
    RandomShare {
        algebraic_type: AlgebraicType,
        batched: Batched,
    },
    /// A single scalar-field plaintext.
    Scalar(ScalarPlaintext<C>),
    /// A batch of scalar-field plaintexts.
    ScalarBatch(ScalarPlaintextBatch<C>),
    /// A single base-field plaintext.
    BaseField(BaseFieldPlaintext<C>),
    /// A batch of base-field plaintexts.
    BaseFieldBatch(BaseFieldPlaintextBatch<C>),
    /// A single Mersenne-107 field plaintext.
    Mersenne107(Mersenne107Plaintext),
    /// A batch of Mersenne-107 field plaintexts.
    Mersenne107Batch(Mersenne107PlaintextBatch),
    /// A single bit plaintext.
    Bit(BitPlaintext),
    /// A batch of bit plaintexts.
    BitBatch(BitPlaintextBatch),
    /// A single curve-point plaintext.
    Point(PointPlaintext<C>),
    /// A batch of curve-point plaintexts.
    PointBatch(PointPlaintextBatch<C>),
}
301
302impl<C: Curve> Input<C> {
303 pub fn batched(&self) -> Batched {
304 match self {
305 Input::SecretPlaintext { batched, .. } => *batched,
306 Input::Share { batched, .. } => *batched,
307 Input::RandomShare { batched, .. } => *batched,
308 Input::ScalarBatch(input) => Batched::Yes(input.len()),
309 Input::BaseFieldBatch(input) => Batched::Yes(input.len()),
310 Input::Mersenne107Batch(input) => Batched::Yes(input.len()),
311 Input::BitBatch(input) => Batched::Yes(input.len()),
312 Input::PointBatch(input) => Batched::Yes(input.len()),
313 Input::Scalar { .. } => Batched::No,
314 Input::BaseField { .. } => Batched::No,
315 Input::Mersenne107 { .. } => Batched::No,
316 Input::Bit { .. } => Batched::No,
317 Input::Point { .. } => Batched::No,
318 }
319 }
320}
321
/// Whether a value is a batch, and if so of how many elements.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Batched {
    /// A batch with the given number of elements.
    Yes(usize),
    /// A single, unbatched value (counts as one element).
    No,
}
328
329impl Batched {
330 pub fn count(&self) -> usize {
331 match self {
332 Batched::Yes(count) => *count,
333 Batched::No => 1,
334 }
335 }
336
337 pub fn is_batched(&self) -> bool {
338 match self {
339 Batched::Yes(_) => true,
340 Batched::No => false,
341 }
342 }
343}
344
#[cfg(test)]
mod tests {
    use primitives::algebra::{
        elliptic_curve::{BaseField, Curve25519Ristretto as C, ScalarField},
        field::SubfieldElement,
    };

    use super::*;

    #[test]
    fn test_scalar_unary_op() {
        let mut rng = rand::thread_rng();
        let x = SubfieldElement::<ScalarField<C>>::random(&mut rng);
        let label = Label::task(0);
        let neg = FieldPlaintextUnaryOp::Neg;
        let mul_inverse = FieldPlaintextUnaryOp::MulInverse;

        assert_eq!(neg.eval::<ScalarField<C>>(label, x), Ok(-x));
        assert_eq!(
            mul_inverse.eval::<ScalarField<C>>(label, x),
            Ok(x.invert().unwrap())
        );
    }

    /// `MulInverse` deliberately maps zero to zero instead of failing.
    #[test]
    fn test_mul_inverse_of_zero_is_zero() {
        let zero = SubfieldElement::<ScalarField<C>>::zero();
        let label = Label::task(0);

        assert_eq!(
            FieldPlaintextUnaryOp::MulInverse.eval::<ScalarField<C>>(label, zero),
            Ok(zero)
        );
    }

    #[test]
    fn test_scalar_binary_op() {
        let mut rng = rand::thread_rng();
        let x = SubfieldElement::<ScalarField<C>>::random(&mut rng);
        let y = SubfieldElement::<ScalarField<C>>::random(&mut rng);
        let label = Label::task(0);

        let add = FieldPlaintextBinaryOp::Add;
        let mul = FieldPlaintextBinaryOp::Mul;
        let eucl_div = FieldPlaintextBinaryOp::EuclDiv;
        let modulo_op = FieldPlaintextBinaryOp::Mod;
        let gt = FieldPlaintextBinaryOp::Gt;
        let ge = FieldPlaintextBinaryOp::Ge;
        let eq = FieldPlaintextBinaryOp::Eq;

        assert_eq!(add.eval::<ScalarField<C>>(x, y, label), Ok(x + y));
        assert_eq!(mul.eval::<ScalarField<C>>(x, y, label), Ok(x * y));
        assert_eq!(
            eucl_div.eval::<ScalarField<C>>(x, y, label),
            euclidean_division::<ScalarField<C>>(x, y, label)
        );
        assert_eq!(
            modulo_op.eval::<ScalarField<C>>(x, y, label),
            modulo::<ScalarField<C>>(x, y, label)
        );
        assert_eq!(
            gt.eval::<ScalarField<C>>(x, y, label),
            Ok(SubfieldElement::<ScalarField<C>>::from(x > y))
        );
        assert_eq!(
            ge.eval::<ScalarField<C>>(x, y, label),
            Ok(SubfieldElement::<ScalarField<C>>::from(x >= y))
        );
        assert_eq!(
            eq.eval::<ScalarField<C>>(x, y, label),
            Ok(SubfieldElement::<ScalarField<C>>::from(x == y))
        );
    }

    #[test]
    fn test_boolean_binary_op() {
        let and = FieldPlaintextBinaryOp::Mul;
        let or = FieldPlaintextBinaryOp::Or;
        let xor = FieldPlaintextBinaryOp::Xor;
        let label = Label::task(0);
        for bool_x in [false, true] {
            for bool_y in [false, true] {
                let scalar_x = SubfieldElement::<ScalarField<C>>::from(bool_x);
                let scalar_y = SubfieldElement::<ScalarField<C>>::from(bool_y);
                assert_eq!(
                    and.eval::<ScalarField<C>>(scalar_x, scalar_y, label),
                    Ok((bool_x && bool_y).into())
                );
                assert_eq!(
                    or.eval::<ScalarField<C>>(scalar_x, scalar_y, label),
                    Ok((bool_x || bool_y).into())
                );
                assert_eq!(
                    xor.eval::<ScalarField<C>>(scalar_x, scalar_y, label),
                    Ok((bool_x ^ bool_y).into())
                );
            }
        }
    }

    #[test]
    fn test_euclidian_division() {
        let x = SubfieldElement::<ScalarField<C>>::from(37u32);
        let y = SubfieldElement::<ScalarField<C>>::from(12u32);
        let label = Label::task(0);

        let result = euclidean_division::<ScalarField<C>>(x, y, label).unwrap();
        assert_eq!(result, SubfieldElement::<ScalarField<C>>::from(37u32 / 12));
    }

    #[test]
    fn test_modulo() {
        let x = SubfieldElement::<ScalarField<C>>::from(37u32);
        let y = SubfieldElement::<ScalarField<C>>::from(12u32);
        let label = Label::task(0);

        let result = modulo::<ScalarField<C>>(x, y, label).unwrap();
        assert_eq!(result, SubfieldElement::<ScalarField<C>>::from(37u32 % 12));
    }

    /// A dividend smaller than the divisor gives quotient 0 and leaves the dividend as the
    /// remainder; this also exercises the zero-padding of a one-byte encoding.
    #[test]
    fn test_division_with_smaller_dividend() {
        let x = SubfieldElement::<ScalarField<C>>::from(5u32);
        let y = SubfieldElement::<ScalarField<C>>::from(12u32);
        let label = Label::task(0);

        assert_eq!(
            euclidean_division::<ScalarField<C>>(x, y, label).unwrap(),
            SubfieldElement::<ScalarField<C>>::zero()
        );
        assert_eq!(modulo::<ScalarField<C>>(x, y, label).unwrap(), x);
    }

    #[test]
    fn test_division_by_zero_is_rejected() {
        let x = SubfieldElement::<ScalarField<C>>::from(37u32);
        let zero = SubfieldElement::<ScalarField<C>>::zero();
        let label = Label::task(0);

        assert!(euclidean_division::<ScalarField<C>>(x, zero, label).is_err());
        assert!(modulo::<ScalarField<C>>(x, zero, label).is_err());
    }

    #[test]
    fn test_signed_bit_extract() {
        let x = -Scalar::<C>::from(9u32);
        let label = Label::task(0);
        for i in 0..5 {
            let op = FieldPlaintextUnaryOp::BitExtract {
                little_endian_bit_idx: i,
                signed: true,
            };
            let result = op.eval::<ScalarField<C>>(label, x);
            assert_eq!(result.unwrap(), ((-9i32 >> i) & 1 == 1).into())
        }
    }

    #[test]
    fn test_unsigned_bit_extract() {
        let x = SubfieldElement::<ScalarField<C>>::from(9u32);
        let label = Label::task(0);
        for i in 0..5 {
            let op = FieldPlaintextUnaryOp::BitExtract {
                little_endian_bit_idx: i,
                signed: false,
            };
            let result = op.eval::<ScalarField<C>>(label, x);
            assert_eq!(result.unwrap(), ((9u32 >> i) & 1 == 1).into())
        }
    }

    #[test]
    fn test_sqrt() {
        let mut rng = rand::thread_rng();
        let x = SubfieldElement::<ScalarField<C>>::random(&mut rng);
        let label = Label::task(0);
        let result = FieldPlaintextUnaryOp::Sqrt
            .eval::<ScalarField<C>>(label, x * x)
            .unwrap();

        assert_eq!(result * result, x * x)
    }

    /// The Ristretto scalar order l satisfies l ≡ 5 (mod 8), so 2 is a quadratic non-residue.
    #[test]
    fn test_sqrt_of_non_residue_fails() {
        let two = SubfieldElement::<ScalarField<C>>::from(2u32);
        let label = Label::task(0);

        assert!(FieldPlaintextUnaryOp::Sqrt
            .eval::<ScalarField<C>>(label, two)
            .is_err());
    }

    #[test]
    fn test_pow() {
        let mut rng = rand::thread_rng();
        let x = SubfieldElement::<BaseField<C>>::random(&mut rng);
        let label = Label::task(0);
        let five = BoxedUint::from(vec![5u64]);
        let five_inv = BoxedUint::from(vec![
            14757395258967641281,
            14757395258967641292,
            14757395258967641292,
            5534023222112865484,
        ]);
        let x_pow_5 = FieldPlaintextUnaryOp::Pow { exp: five }
            .eval::<BaseField<C>>(label, x)
            .unwrap();
        let x_again = FieldPlaintextUnaryOp::Pow { exp: five_inv }
            .eval::<BaseField<C>>(label, x_pow_5)
            .unwrap();

        assert_eq!(x_again, x)
    }

    #[test]
    fn test_batched_helpers() {
        assert_eq!(Batched::No.count(), 1);
        assert!(!Batched::No.is_batched());
        assert_eq!(Batched::Yes(7).count(), 7);
        assert!(Batched::Yes(7).is_batched());
    }
}