1use std::iter::repeat_n;
2
3use ff::Field;
4use num_traits::{One, Zero};
5use primitives::algebra::{
6 field::{FieldExtension, SubfieldElement},
7 BoxedUint,
8};
9use serde::{Deserialize, Serialize};
10use typenum::Unsigned;
11
12use crate::{
13 errors::{AbortError, FaultyPeer},
14 types::Label,
15};
16
/// A unary operation on a single field element, evaluated in the clear by
/// [`FieldUnaryOp::eval`].
///
/// NOTE(review): this enum derives `Serialize`/`Deserialize`, so variant order and
/// field names are part of its wire representation for index-based formats —
/// confirm before reordering or renaming.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum FieldUnaryOp {
    /// Additive inverse: `-x`.
    Neg,
    /// Multiplicative inverse `x^-1`; `eval` maps zero to zero instead of failing.
    MulInverse,
    /// Extracts one bit of the element's integer representation, returned as 0 or 1.
    BitExtract {
        /// Bit position counted from the least significant bit (0 = LSB).
        little_endian_bit_idx: u16,
        /// When `true`, `eval` treats elements with `x > -x` as negative and reads the
        /// bit from their two's-complement form; otherwise the plain representation is used.
        signed: bool,
    },
    /// Square root via `sqrt_ratio`; `eval` aborts when `x` is a quadratic non-residue.
    Sqrt,
    /// Exponentiation `x^exp`.
    Pow {
        /// The exponent, as an unsigned big integer.
        exp: BoxedUint,
    },
}
32
33impl FieldUnaryOp {
34 pub fn eval<F: FieldExtension>(
36 &self,
37 label: Label,
38 x: SubfieldElement<F>,
39 ) -> Result<SubfieldElement<F>, AbortError> {
40 match self {
41 FieldUnaryOp::Neg => Ok(-x),
42 FieldUnaryOp::MulInverse => {
43 if x == SubfieldElement::<F>::zero() {
44 Ok(SubfieldElement::<F>::zero())
45 } else {
46 Ok(x.invert().unwrap())
47 }
48 }
49 FieldUnaryOp::BitExtract {
50 little_endian_bit_idx: idx,
51 signed,
52 } => {
53 let bit = if *signed && x > -x {
54 !(-SubfieldElement::<F>::one() - x)
55 .to_biguint()
56 .bit(*idx as u64)
57 } else {
58 x.to_biguint().bit(*idx as u64)
59 };
60 Ok(SubfieldElement::<F>::from(bit))
61 }
62 FieldUnaryOp::Sqrt => {
63 let (choice, sqrt) =
64 SubfieldElement::<F>::sqrt_ratio(&x, &SubfieldElement::<F>::one());
65 if !bool::from(choice) {
66 return Err(AbortError::quadratic_non_residue(label, FaultyPeer::Local));
67 }
68 Ok(sqrt)
69 }
70 FieldUnaryOp::Pow { exp } => Ok(x.pow(exp)),
71 }
72 }
73}
74
/// A binary operation on two field elements, evaluated in the clear by
/// [`FieldBinaryOp::eval`].
///
/// NOTE(review): this enum derives `Serialize`/`Deserialize`, so variant order is
/// part of its wire representation for index-based formats — confirm before reordering.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum FieldBinaryOp {
    /// Field addition `x + y`.
    Add,
    /// Field multiplication `x * y` (also acts as boolean AND on 0/1 operands).
    Mul,
    /// Integer floor division of the elements' integer representations; aborts when `y` is zero.
    EuclDiv,
    /// Integer remainder of the elements' integer representations; aborts when `y` is zero.
    Mod,
    /// `1` if `x > y`, else `0`, using `SubfieldElement`'s ordering
    /// (presumably canonical integer representatives — confirm).
    Gt,
    /// `1` if `x >= y`, else `0`, using the same ordering as `Gt`.
    Ge,
    /// `1` if `x == y`, else `0`.
    Eq,
    /// `x + y - 2xy`: boolean XOR when both operands are 0 or 1.
    Xor,
    /// `x + y - xy`: boolean OR when both operands are 0 or 1.
    Or,
}
87
88impl FieldBinaryOp {
89 pub fn eval<F: FieldExtension>(
90 &self,
91 x: SubfieldElement<F>,
92 y: SubfieldElement<F>,
93 label: Label,
94 ) -> Result<SubfieldElement<F>, AbortError> {
95 match self {
96 FieldBinaryOp::Add => Ok(x + y),
97 FieldBinaryOp::Mul => Ok(x * y),
98 FieldBinaryOp::EuclDiv => euclidean_division::<F>(x, y, label),
99 FieldBinaryOp::Mod => modulo::<F>(x, y, label),
100 FieldBinaryOp::Gt => Ok(SubfieldElement::<F>::from(x > y)),
101 FieldBinaryOp::Ge => Ok(SubfieldElement::<F>::from(x >= y)),
102 FieldBinaryOp::Eq => Ok(SubfieldElement::<F>::from(x == y)),
103 FieldBinaryOp::Xor => Ok(x + y - SubfieldElement::<F>::from(2u32) * x * y),
104 FieldBinaryOp::Or => Ok(x + y - x * y),
105 }
106 }
107}
108
109fn euclidean_division<F: FieldExtension>(
110 x: SubfieldElement<F>,
111 y: SubfieldElement<F>,
112 label: Label,
113) -> Result<SubfieldElement<F>, AbortError> {
114 if y == SubfieldElement::<F>::zero() {
115 return Err(AbortError::division_by_zero(label, FaultyPeer::Local));
116 }
117
118 let x = x.to_biguint();
120 let y = y.to_biguint();
121
122 let div = (x / y).to_bytes_be();
123 let div = repeat_n(0, F::FieldBytesSize::USIZE - div.len())
125 .chain(div)
126 .collect::<Vec<_>>();
127
128 Ok(SubfieldElement::<F>::from_be_bytes(&div).unwrap())
129}
130
131fn modulo<F: FieldExtension>(
132 x: SubfieldElement<F>,
133 y: SubfieldElement<F>,
134 label: Label,
135) -> Result<SubfieldElement<F>, AbortError> {
136 if y == SubfieldElement::<F>::zero() {
137 return Err(AbortError::division_by_zero(label, FaultyPeer::Local));
138 }
139
140 let x = x.to_biguint();
142 let y = y.to_biguint();
143
144 let modulo = x.modpow(&num_bigint::BigUint::from(1u32), &y).to_bytes_be();
145 let modulo = repeat_n(0, F::FieldBytesSize::USIZE - modulo.len())
147 .chain(modulo)
148 .collect::<Vec<_>>();
149
150 Ok(SubfieldElement::<F>::from_be_bytes(&modulo).unwrap())
151}
152
#[cfg(test)]
mod tests {
    use primitives::algebra::{
        elliptic_curve::{BaseField, Curve25519Ristretto as C, Scalar, ScalarField},
        field::SubfieldElement,
    };

    use super::*;

    #[test]
    fn test_scalar_unary_op() {
        let mut rng = primitives::random::test_rng();
        let x = SubfieldElement::<ScalarField<C>>::random(&mut rng);
        let label = Label::task(0);
        let neg = FieldUnaryOp::Neg;
        let mul_inverse = FieldUnaryOp::MulInverse;

        assert_eq!(neg.eval::<ScalarField<C>>(label, x), Ok(-x));
        assert_eq!(
            mul_inverse.eval::<ScalarField<C>>(label, x),
            Ok(x.invert().unwrap())
        );
    }

    /// Zero has no inverse; `MulInverse` is specified to map it to zero.
    #[test]
    fn test_mul_inverse_of_zero_is_zero() {
        let label = Label::task(0);
        let zero = SubfieldElement::<ScalarField<C>>::zero();

        assert_eq!(
            FieldUnaryOp::MulInverse.eval::<ScalarField<C>>(label, zero),
            Ok(zero)
        );
    }

    #[test]
    fn test_scalar_binary_op() {
        let mut rng = primitives::random::test_rng();
        let x = SubfieldElement::<ScalarField<C>>::random(&mut rng);
        let y = SubfieldElement::<ScalarField<C>>::random(&mut rng);
        let label = Label::task(0);

        let add = FieldBinaryOp::Add;
        let mul = FieldBinaryOp::Mul;
        let eucl_div = FieldBinaryOp::EuclDiv;
        let modulo_op = FieldBinaryOp::Mod;
        let gt = FieldBinaryOp::Gt;
        let ge = FieldBinaryOp::Ge;
        let eq = FieldBinaryOp::Eq;

        assert_eq!(add.eval::<ScalarField<C>>(x, y, label), Ok(x + y));
        assert_eq!(mul.eval::<ScalarField<C>>(x, y, label), Ok(x * y));
        assert_eq!(
            eucl_div.eval::<ScalarField<C>>(x, y, label),
            euclidean_division::<ScalarField<C>>(x, y, label)
        );
        assert_eq!(
            modulo_op.eval::<ScalarField<C>>(x, y, label),
            modulo::<ScalarField<C>>(x, y, label)
        );
        assert_eq!(
            gt.eval::<ScalarField<C>>(x, y, label),
            Ok(SubfieldElement::<ScalarField<C>>::from(x > y))
        );
        assert_eq!(
            ge.eval::<ScalarField<C>>(x, y, label),
            Ok(SubfieldElement::<ScalarField<C>>::from(x >= y))
        );
        assert_eq!(
            eq.eval::<ScalarField<C>>(x, y, label),
            Ok(SubfieldElement::<ScalarField<C>>::from(x == y))
        );
    }

    #[test]
    fn test_boolean_binary_op() {
        let and = FieldBinaryOp::Mul;
        let or = FieldBinaryOp::Or;
        let xor = FieldBinaryOp::Xor;
        let label = Label::task(0);
        for bool_x in [false, true] {
            for bool_y in [false, true] {
                let scalar_x = SubfieldElement::<ScalarField<C>>::from(bool_x);
                let scalar_y = SubfieldElement::<ScalarField<C>>::from(bool_y);
                assert_eq!(
                    and.eval::<ScalarField<C>>(scalar_x, scalar_y, label),
                    Ok((bool_x && bool_y).into())
                );
                assert_eq!(
                    or.eval::<ScalarField<C>>(scalar_x, scalar_y, label),
                    Ok((bool_x || bool_y).into())
                );
                assert_eq!(
                    xor.eval::<ScalarField<C>>(scalar_x, scalar_y, label),
                    Ok((bool_x ^ bool_y).into())
                );
            }
        }
    }

    #[test]
    fn test_euclidian_division() {
        let x = SubfieldElement::<ScalarField<C>>::from(37u32);
        let y = SubfieldElement::<ScalarField<C>>::from(12u32);
        let label = Label::task(0);

        let result = euclidean_division::<ScalarField<C>>(x, y, label).unwrap();
        assert_eq!(result, SubfieldElement::<ScalarField<C>>::from(37u32 / 12));
    }

    #[test]
    fn test_modulo() {
        let x = SubfieldElement::<ScalarField<C>>::from(37u32);
        let y = SubfieldElement::<ScalarField<C>>::from(12u32);
        let label = Label::task(0);

        let result = modulo::<ScalarField<C>>(x, y, label).unwrap();
        assert_eq!(result, SubfieldElement::<ScalarField<C>>::from(37u32 % 12));
    }

    /// Both integer operations must abort instead of panicking when dividing by zero.
    #[test]
    fn test_division_by_zero_aborts() {
        let x = SubfieldElement::<ScalarField<C>>::from(37u32);
        let zero = SubfieldElement::<ScalarField<C>>::zero();
        let label = Label::task(0);

        assert!(FieldBinaryOp::EuclDiv
            .eval::<ScalarField<C>>(x, zero, label)
            .is_err());
        assert!(FieldBinaryOp::Mod
            .eval::<ScalarField<C>>(x, zero, label)
            .is_err());
    }

    #[test]
    fn test_signed_bit_extract() {
        let x = -Scalar::<C>::from(9u32);
        let label = Label::task(0);
        for i in 0..5 {
            let op = FieldUnaryOp::BitExtract {
                little_endian_bit_idx: i,
                signed: true,
            };
            let result = op.eval::<ScalarField<C>>(label, x);
            assert_eq!(result.unwrap(), ((-9i32 >> i) & 1 == 1).into())
        }
    }

    /// A non-negative value in signed mode must read its plain binary digits.
    #[test]
    fn test_signed_bit_extract_of_positive_value() {
        let x = SubfieldElement::<ScalarField<C>>::from(9u32);
        let label = Label::task(0);
        for i in 0..5 {
            let op = FieldUnaryOp::BitExtract {
                little_endian_bit_idx: i,
                signed: true,
            };
            let result = op.eval::<ScalarField<C>>(label, x);
            assert_eq!(result.unwrap(), ((9u32 >> i) & 1 == 1).into())
        }
    }

    #[test]
    fn test_unsigned_bit_extract() {
        let x = SubfieldElement::<ScalarField<C>>::from(37u32);
        let label = Label::task(0);
        for i in 0..8 {
            let op = FieldUnaryOp::BitExtract {
                little_endian_bit_idx: i,
                signed: false,
            };
            let result = op.eval::<ScalarField<C>>(label, x);
            assert_eq!(result.unwrap(), ((37u32 >> i) & 1 == 1).into())
        }
    }

    #[test]
    fn test_sqrt() {
        let mut rng = primitives::random::test_rng();
        let x = SubfieldElement::<ScalarField<C>>::random(&mut rng);
        let label = Label::task(0);
        let result = FieldUnaryOp::Sqrt
            .eval::<ScalarField<C>>(label, x * x)
            .unwrap();

        assert_eq!(result * result, x * x)
    }

    /// The Ristretto scalar field order satisfies l = 5 (mod 8), so 2 is a quadratic
    /// non-residue and `Sqrt` must abort on it.
    #[test]
    fn test_sqrt_of_non_residue_aborts() {
        let label = Label::task(0);
        let two = SubfieldElement::<ScalarField<C>>::from(2u32);

        assert!(FieldUnaryOp::Sqrt.eval::<ScalarField<C>>(label, two).is_err());
    }

    #[test]
    fn test_pow() {
        let mut rng = primitives::random::test_rng();
        let x = SubfieldElement::<BaseField<C>>::random(&mut rng);
        let label = Label::task(0);
        let five = BoxedUint::from(vec![5u64]);
        let five_inv = BoxedUint::from(vec![
            14757395258967641281,
            14757395258967641292,
            14757395258967641292,
            5534023222112865484,
        ]);
        let x_pow_5 = FieldUnaryOp::Pow { exp: five }
            .eval::<BaseField<C>>(label, x)
            .unwrap();
        let x_again = FieldUnaryOp::Pow { exp: five_inv }
            .eval::<BaseField<C>>(label, x_pow_5)
            .unwrap();

        assert_eq!(x_again, x)
    }
}