1use std::iter::repeat_n;
2
3use ff::Field;
4use num_traits::{One, Zero};
5use primitives::{
6 algebra::{
7 elliptic_curve::{Curve, Point, Scalar},
8 field::{FieldExtension, SubfieldElement},
9 BoxedUint,
10 },
11 types::PeerNumber,
12};
13use serde::{Deserialize, Serialize};
14use typenum::Unsigned;
15
16use crate::{
17 circuit::{
18 AlgebraicType,
19 BaseFieldPlaintext,
20 BaseFieldPlaintextBatch,
21 BitPlaintext,
22 BitPlaintextBatch,
23 PointPlaintext,
24 PointPlaintextBatch,
25 ScalarPlaintext,
26 ScalarPlaintextBatch,
27 },
28 errors::{AbortError, FaultyPeer},
29 types::Label,
30};
31
32#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
34pub enum FieldPlaintextUnaryOp {
35 Neg,
36 MulInverse,
38 BitExtract {
40 little_endian_bit_idx: u16,
41 signed: bool,
42 },
43 Sqrt,
44 Pow {
45 exp: BoxedUint,
46 },
47}
48
49impl FieldPlaintextUnaryOp {
50 pub fn eval<F: FieldExtension>(
52 &self,
53 label: Label,
54 x: SubfieldElement<F>,
55 ) -> Result<SubfieldElement<F>, AbortError> {
56 match self {
57 FieldPlaintextUnaryOp::Neg => Ok(-x),
58 FieldPlaintextUnaryOp::MulInverse => {
59 if x == SubfieldElement::<F>::zero() {
60 Ok(SubfieldElement::<F>::zero())
61 } else {
62 Ok(x.invert().unwrap())
63 }
64 }
65 FieldPlaintextUnaryOp::BitExtract {
66 little_endian_bit_idx: idx,
67 signed,
68 } => {
69 let bit = if *signed && x > -x {
70 !(-SubfieldElement::<F>::one() - x)
71 .to_biguint()
72 .bit(*idx as u64)
73 } else {
74 x.to_biguint().bit(*idx as u64)
75 };
76 Ok(SubfieldElement::<F>::from(bit))
77 }
78 FieldPlaintextUnaryOp::Sqrt => {
79 let (choice, sqrt) =
80 SubfieldElement::<F>::sqrt_ratio(&x, &SubfieldElement::<F>::one());
81 if !bool::from(choice) {
82 return Err(AbortError::quadratic_non_residue(label, FaultyPeer::Local));
83 }
84 Ok(sqrt)
85 }
86 FieldPlaintextUnaryOp::Pow { exp } => Ok(x.pow(exp)),
87 }
88 }
89}
90
91#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
93pub enum FieldPlaintextBinaryOp {
94 Add,
95 Mul,
96 EuclDiv,
97 Mod,
98 Gt,
99 Ge,
100 Eq,
101 Xor,
102 Or,
103}
104
105impl FieldPlaintextBinaryOp {
106 pub fn eval<F: FieldExtension>(
107 &self,
108 x: SubfieldElement<F>,
109 y: SubfieldElement<F>,
110 label: Label,
111 ) -> Result<SubfieldElement<F>, AbortError> {
112 match self {
113 FieldPlaintextBinaryOp::Add => Ok(x + y),
114 FieldPlaintextBinaryOp::Mul => Ok(x * y),
115 FieldPlaintextBinaryOp::EuclDiv => euclidean_division::<F>(x, y, label),
116 FieldPlaintextBinaryOp::Mod => modulo::<F>(x, y, label),
117 FieldPlaintextBinaryOp::Gt => Ok(SubfieldElement::<F>::from(x > y)),
118 FieldPlaintextBinaryOp::Ge => Ok(SubfieldElement::<F>::from(x >= y)),
119 FieldPlaintextBinaryOp::Eq => Ok(SubfieldElement::<F>::from(x == y)),
120 FieldPlaintextBinaryOp::Xor => Ok(x + y - SubfieldElement::<F>::from(2u32) * x * y),
121 FieldPlaintextBinaryOp::Or => Ok(x + y - x * y),
122 }
123 }
124}
125
126pub(crate) fn euclidean_division<F: FieldExtension>(
127 x: SubfieldElement<F>,
128 y: SubfieldElement<F>,
129 label: Label,
130) -> Result<SubfieldElement<F>, AbortError> {
131 if y == SubfieldElement::<F>::zero() {
132 return Err(AbortError::division_by_zero(label, FaultyPeer::Local));
133 }
134
135 let x = x.to_biguint();
137 let y = y.to_biguint();
138
139 let div = (x / y).to_bytes_be();
140 let div = repeat_n(0, F::FieldBytesSize::USIZE - div.len())
142 .chain(div)
143 .collect::<Vec<_>>();
144
145 Ok(SubfieldElement::<F>::from_be_bytes(&div).unwrap())
146}
147
148fn modulo<F: FieldExtension>(
149 x: SubfieldElement<F>,
150 y: SubfieldElement<F>,
151 label: Label,
152) -> Result<SubfieldElement<F>, AbortError> {
153 if y == SubfieldElement::<F>::zero() {
154 return Err(AbortError::division_by_zero(label, FaultyPeer::Local));
155 }
156
157 let x = x.to_biguint();
159 let y = y.to_biguint();
160
161 let modulo = x.modpow(&num_bigint::BigUint::from(1u32), &y).to_bytes_be();
162 let modulo = repeat_n(0, F::FieldBytesSize::USIZE - modulo.len())
164 .chain(modulo)
165 .collect::<Vec<_>>();
166
167 Ok(SubfieldElement::<F>::from_be_bytes(&modulo).unwrap())
168}
169
170#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
172pub enum FieldShareUnaryOp {
173 Neg,
175 MulInverse,
177 Open,
179 IsZero,
181}
182
183#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
186pub enum FieldShareBinaryOp {
187 Add,
189 Mul,
191}
192
193#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
195pub enum BitShareUnaryOp {
196 Not,
198 Open,
200}
201
202#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
204pub enum BitShareBinaryOp {
205 Xor,
207 Or,
209 And,
211}
212
213#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
215pub enum PointShareUnaryOp {
216 Neg,
218 Open,
220 IsZero,
222}
223
224#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
225#[serde(bound(
226 serialize = "Scalar<C>: Serialize, Point<C>: Serialize",
227 deserialize = "Scalar<C>: Deserialize<'de>, Point<C>: Deserialize<'de>"
228))]
229pub enum Input<C: Curve> {
230 SecretPlaintext {
231 inputer: PeerNumber,
232 algebraic_type: AlgebraicType,
233 batched: Batched,
234 },
235 Share {
236 algebraic_type: AlgebraicType,
237 batched: Batched,
238 },
239 RandomShare {
240 algebraic_type: AlgebraicType,
241 batched: Batched,
242 },
243 Scalar(ScalarPlaintext<C>),
244 ScalarBatch(ScalarPlaintextBatch<C>),
245 BaseField(BaseFieldPlaintext<C>),
246 BaseFieldBatch(BaseFieldPlaintextBatch<C>),
247 Bit(BitPlaintext),
248 BitBatch(BitPlaintextBatch),
249 Point(PointPlaintext<C>),
250 PointBatch(PointPlaintextBatch<C>),
251 ElGamalCiphertext {
252 c: PointPlaintext<C>,
253 r: PointPlaintext<C>,
254 },
255}
256
257impl<C: Curve> Input<C> {
258 pub fn batched(&self) -> Batched {
259 match self {
260 Input::SecretPlaintext { batched, .. } => *batched,
261 Input::Share { batched, .. } => *batched,
262 Input::RandomShare { batched, .. } => *batched,
263 Input::ScalarBatch(input) => Batched::Yes(input.len()),
264 Input::BaseFieldBatch(input) => Batched::Yes(input.len()),
265 Input::BitBatch(input) => Batched::Yes(input.len()),
266 Input::PointBatch(input) => Batched::Yes(input.len()),
267 Input::ElGamalCiphertext { .. } => Batched::No,
268 Input::Scalar { .. } => Batched::No,
269 Input::BaseField { .. } => Batched::No,
270 Input::Bit { .. } => Batched::No,
271 Input::Point { .. } => Batched::No,
272 }
273 }
274}
275
276#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
278pub enum PointShareBinaryOp {
279 Add,
281 ScalarMul,
283}
284
285#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
287pub enum Batched {
288 Yes(usize),
289 No,
290}
291
292impl Batched {
293 pub fn count(&self) -> usize {
294 match self {
295 Batched::Yes(count) => *count,
296 Batched::No => 1,
297 }
298 }
299
300 pub fn is_batched(&self) -> bool {
301 match self {
302 Batched::Yes(_) => true,
303 Batched::No => false,
304 }
305 }
306}
307
#[cfg(test)]
mod tests {
    use primitives::algebra::{
        elliptic_curve::{BaseField, Curve25519Ristretto as C, ScalarField},
        field::SubfieldElement,
    };

    use super::*;

    /// `Neg` and `MulInverse` agree with the field's own negation and inversion.
    #[test]
    fn test_scalar_unary_op() {
        let mut rng = rand::thread_rng();
        let x = SubfieldElement::<ScalarField<C>>::random(&mut rng);
        let label = Label::task(0);
        let neg = FieldPlaintextUnaryOp::Neg;
        let mul_inverse = FieldPlaintextUnaryOp::MulInverse;

        assert_eq!(neg.eval::<ScalarField<C>>(label, x), Ok(-x));
        assert_eq!(
            mul_inverse.eval::<ScalarField<C>>(label, x),
            Ok(x.invert().unwrap())
        );
    }

    /// The inverse of zero is defined to be zero rather than an error or a panic.
    #[test]
    fn test_mul_inverse_of_zero() {
        let zero = SubfieldElement::<ScalarField<C>>::from(0u32);
        let label = Label::task(0);

        assert_eq!(
            FieldPlaintextUnaryOp::MulInverse.eval::<ScalarField<C>>(label, zero),
            Ok(zero)
        );
    }

    /// Every arithmetic and comparison operation dispatches to the expected computation.
    #[test]
    fn test_scalar_binary_op() {
        let mut rng = rand::thread_rng();
        let x = SubfieldElement::<ScalarField<C>>::random(&mut rng);
        let y = SubfieldElement::<ScalarField<C>>::random(&mut rng);
        let label = Label::task(0);

        let add = FieldPlaintextBinaryOp::Add;
        let mul = FieldPlaintextBinaryOp::Mul;
        let eucl_div = FieldPlaintextBinaryOp::EuclDiv;
        let modulo_op = FieldPlaintextBinaryOp::Mod;
        let gt = FieldPlaintextBinaryOp::Gt;
        let ge = FieldPlaintextBinaryOp::Ge;
        let eq = FieldPlaintextBinaryOp::Eq;

        assert_eq!(add.eval::<ScalarField<C>>(x, y, label), Ok(x + y));
        assert_eq!(mul.eval::<ScalarField<C>>(x, y, label), Ok(x * y));
        assert_eq!(
            eucl_div.eval::<ScalarField<C>>(x, y, label),
            euclidean_division::<ScalarField<C>>(x, y, label)
        );
        assert_eq!(
            modulo_op.eval::<ScalarField<C>>(x, y, label),
            modulo::<ScalarField<C>>(x, y, label)
        );
        assert_eq!(
            gt.eval::<ScalarField<C>>(x, y, label),
            Ok(SubfieldElement::<ScalarField<C>>::from(x > y))
        );
        assert_eq!(
            ge.eval::<ScalarField<C>>(x, y, label),
            Ok(SubfieldElement::<ScalarField<C>>::from(x >= y))
        );
        assert_eq!(
            eq.eval::<ScalarField<C>>(x, y, label),
            Ok(SubfieldElement::<ScalarField<C>>::from(x == y))
        );
    }

    /// `Mul`, `Or` and `Xor` behave as boolean AND, OR and XOR on `0`/`1` operands.
    #[test]
    fn test_boolean_binary_op() {
        let and = FieldPlaintextBinaryOp::Mul;
        let or = FieldPlaintextBinaryOp::Or;
        let xor = FieldPlaintextBinaryOp::Xor;
        let label = Label::task(0);
        for bool_x in [false, true] {
            for bool_y in [false, true] {
                let scalar_x = SubfieldElement::<ScalarField<C>>::from(bool_x);
                let scalar_y = SubfieldElement::<ScalarField<C>>::from(bool_y);
                assert_eq!(
                    and.eval::<ScalarField<C>>(scalar_x, scalar_y, label),
                    Ok((bool_x && bool_y).into())
                );
                assert_eq!(
                    or.eval::<ScalarField<C>>(scalar_x, scalar_y, label),
                    Ok((bool_x || bool_y).into())
                );
                assert_eq!(
                    xor.eval::<ScalarField<C>>(scalar_x, scalar_y, label),
                    Ok((bool_x ^ bool_y).into())
                );
            }
        }
    }

    #[test]
    fn test_euclidean_division() {
        let x = SubfieldElement::<ScalarField<C>>::from(37u32);
        let y = SubfieldElement::<ScalarField<C>>::from(12u32);
        let label = Label::task(0);

        let result = euclidean_division::<ScalarField<C>>(x, y, label).unwrap();
        assert_eq!(result, SubfieldElement::<ScalarField<C>>::from(37u32 / 12));
    }

    #[test]
    fn test_modulo() {
        let x = SubfieldElement::<ScalarField<C>>::from(37u32);
        let y = SubfieldElement::<ScalarField<C>>::from(12u32);
        let label = Label::task(0);

        let result = modulo::<ScalarField<C>>(x, y, label).unwrap();
        assert_eq!(result, SubfieldElement::<ScalarField<C>>::from(37u32 % 12));
    }

    /// Dividing or reducing by zero is reported as an error instead of panicking.
    #[test]
    fn test_division_by_zero() {
        let x = SubfieldElement::<ScalarField<C>>::from(37u32);
        let zero = SubfieldElement::<ScalarField<C>>::from(0u32);
        let label = Label::task(0);

        assert!(euclidean_division::<ScalarField<C>>(x, zero, label).is_err());
        assert!(modulo::<ScalarField<C>>(x, zero, label).is_err());
        assert!(FieldPlaintextBinaryOp::EuclDiv
            .eval::<ScalarField<C>>(x, zero, label)
            .is_err());
        assert!(FieldPlaintextBinaryOp::Mod
            .eval::<ScalarField<C>>(x, zero, label)
            .is_err());
    }

    /// Signed extraction of a negative value yields its two's-complement bits.
    #[test]
    fn test_signed_bit_extract() {
        let x = -Scalar::<C>::from(9u32);
        let label = Label::task(0);
        for i in 0..5 {
            let op = FieldPlaintextUnaryOp::BitExtract {
                little_endian_bit_idx: i,
                signed: true,
            };
            let result = op.eval::<ScalarField<C>>(label, x);
            assert_eq!(result.unwrap(), ((-9i32 >> i) & 1 == 1).into())
        }
    }

    /// Unsigned extraction yields the plain binary digits of the value.
    #[test]
    fn test_unsigned_bit_extract() {
        let x = SubfieldElement::<ScalarField<C>>::from(9u32);
        let label = Label::task(0);
        for i in 0..5 {
            let op = FieldPlaintextUnaryOp::BitExtract {
                little_endian_bit_idx: i,
                signed: false,
            };
            let result = op.eval::<ScalarField<C>>(label, x);
            assert_eq!(result.unwrap(), ((9u32 >> i) & 1 == 1).into())
        }
    }

    /// The square root of a square squares back to the original value.
    #[test]
    fn test_sqrt() {
        let mut rng = rand::thread_rng();
        let x = SubfieldElement::<ScalarField<C>>::random(&mut rng);
        let label = Label::task(0);
        let result = FieldPlaintextUnaryOp::Sqrt
            .eval::<ScalarField<C>>(label, x * x)
            .unwrap();

        assert_eq!(result * result, x * x)
    }

    /// Raising to the fifth power and then to the exponent in `five_inv` returns the input.
    #[test]
    fn test_pow() {
        let mut rng = rand::thread_rng();
        let x = SubfieldElement::<BaseField<C>>::random(&mut rng);
        let label = Label::task(0);
        let five = BoxedUint::from(vec![5u64]);
        let five_inv = BoxedUint::from(vec![
            14757395258967641281,
            14757395258967641292,
            14757395258967641292,
            5534023222112865484,
        ]);
        let x_pow_5 = FieldPlaintextUnaryOp::Pow { exp: five }
            .eval::<BaseField<C>>(label, x)
            .unwrap();
        let x_again = FieldPlaintextUnaryOp::Pow { exp: five_inv }
            .eval::<BaseField<C>>(label, x_pow_5)
            .unwrap();

        assert_eq!(x_again, x)
    }
}