core_utils/circuit/
ops.rs

1use std::iter::repeat_n;
2
3use ff::Field;
4use num_traits::{One, Zero};
5use primitives::algebra::{
6    field::{FieldExtension, SubfieldElement},
7    BoxedUint,
8};
9use serde::{Deserialize, Serialize};
10use typenum::Unsigned;
11
12use crate::{
13    errors::{AbortError, FaultyPeer},
14    types::Label,
15};
16
17#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
18pub enum FieldUnaryOp {
19    Neg,
20    // Computes the multiplicative inverse. Note: for 0 we return 0.
21    MulInverse,
22    // Extracts a specified bit from a field element
23    BitExtract {
24        little_endian_bit_idx: u16,
25        signed: bool,
26    },
27    Sqrt,
28    Pow {
29        exp: BoxedUint,
30    },
31}
32
33impl FieldUnaryOp {
34    // TODO: see if returning CtOption could make more sense
35    pub fn eval<F: FieldExtension>(
36        &self,
37        label: Label,
38        x: SubfieldElement<F>,
39    ) -> Result<SubfieldElement<F>, AbortError> {
40        match self {
41            FieldUnaryOp::Neg => Ok(-x),
42            FieldUnaryOp::MulInverse => {
43                if x == SubfieldElement::<F>::zero() {
44                    Ok(SubfieldElement::<F>::zero())
45                } else {
46                    Ok(x.invert().unwrap())
47                }
48            }
49            FieldUnaryOp::BitExtract {
50                little_endian_bit_idx: idx,
51                signed,
52            } => {
53                let bit = if *signed && x > -x {
54                    !(-SubfieldElement::<F>::one() - x)
55                        .to_biguint()
56                        .bit(*idx as u64)
57                } else {
58                    x.to_biguint().bit(*idx as u64)
59                };
60                Ok(SubfieldElement::<F>::from(bit))
61            }
62            FieldUnaryOp::Sqrt => {
63                let (choice, sqrt) =
64                    SubfieldElement::<F>::sqrt_ratio(&x, &SubfieldElement::<F>::one());
65                if !bool::from(choice) {
66                    return Err(AbortError::quadratic_non_residue(label, FaultyPeer::Local));
67                }
68                Ok(sqrt)
69            }
70            FieldUnaryOp::Pow { exp } => Ok(x.pow(exp)),
71        }
72    }
73}
74
75#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
76pub enum FieldBinaryOp {
77    Add,
78    Mul,
79    EuclDiv,
80    Mod,
81    Gt,
82    Ge,
83    Eq,
84    Xor,
85    Or,
86}
87
88impl FieldBinaryOp {
89    pub fn eval<F: FieldExtension>(
90        &self,
91        x: SubfieldElement<F>,
92        y: SubfieldElement<F>,
93        label: Label,
94    ) -> Result<SubfieldElement<F>, AbortError> {
95        match self {
96            FieldBinaryOp::Add => Ok(x + y),
97            FieldBinaryOp::Mul => Ok(x * y),
98            FieldBinaryOp::EuclDiv => euclidean_division::<F>(x, y, label),
99            FieldBinaryOp::Mod => modulo::<F>(x, y, label),
100            FieldBinaryOp::Gt => Ok(SubfieldElement::<F>::from(x > y)),
101            FieldBinaryOp::Ge => Ok(SubfieldElement::<F>::from(x >= y)),
102            FieldBinaryOp::Eq => Ok(SubfieldElement::<F>::from(x == y)),
103            FieldBinaryOp::Xor => Ok(x + y - SubfieldElement::<F>::from(2u32) * x * y),
104            FieldBinaryOp::Or => Ok(x + y - x * y),
105        }
106    }
107}
108
109fn euclidean_division<F: FieldExtension>(
110    x: SubfieldElement<F>,
111    y: SubfieldElement<F>,
112    label: Label,
113) -> Result<SubfieldElement<F>, AbortError> {
114    if y == SubfieldElement::<F>::zero() {
115        return Err(AbortError::division_by_zero(label, FaultyPeer::Local));
116    }
117
118    // Convert to BigUint
119    let x = x.to_biguint();
120    let y = y.to_biguint();
121
122    let div = (x / y).to_bytes_be();
123    // Pad with zeroes as big-endian
124    let div = repeat_n(0, F::FieldBytesSize::USIZE - div.len())
125        .chain(div)
126        .collect::<Vec<_>>();
127
128    Ok(SubfieldElement::<F>::from_be_bytes(&div).unwrap())
129}
130
131fn modulo<F: FieldExtension>(
132    x: SubfieldElement<F>,
133    y: SubfieldElement<F>,
134    label: Label,
135) -> Result<SubfieldElement<F>, AbortError> {
136    if y == SubfieldElement::<F>::zero() {
137        return Err(AbortError::division_by_zero(label, FaultyPeer::Local));
138    }
139
140    // Convert to BigUint
141    let x = x.to_biguint();
142    let y = y.to_biguint();
143
144    let modulo = x.modpow(&num_bigint::BigUint::from(1u32), &y).to_bytes_be();
145    // Pad with zeroes as big-endian
146    let modulo = repeat_n(0, F::FieldBytesSize::USIZE - modulo.len())
147        .chain(modulo)
148        .collect::<Vec<_>>();
149
150    Ok(SubfieldElement::<F>::from_be_bytes(&modulo).unwrap())
151}
152
#[cfg(test)]
mod tests {
    use primitives::algebra::{
        elliptic_curve::{BaseField, Curve25519Ristretto as C, Scalar, ScalarField},
        field::SubfieldElement,
    };

    use super::*;

    /// Scalar-field element of the test curve, used by most tests below.
    type Fe = SubfieldElement<ScalarField<C>>;

    /// `Neg` and `MulInverse` agree with the field's own operators.
    #[test]
    fn test_scalar_unary_op() {
        let mut rng = primitives::random::test_rng();
        let sample = Fe::random(&mut rng);
        let label = Label::task(0);

        let negated = FieldUnaryOp::Neg.eval::<ScalarField<C>>(label, sample);
        assert_eq!(negated, Ok(-sample));

        let inverted = FieldUnaryOp::MulInverse.eval::<ScalarField<C>>(label, sample);
        assert_eq!(inverted, Ok(sample.invert().unwrap()));
    }

    /// Every arithmetic/comparison op matches its direct computation.
    #[test]
    fn test_scalar_binary_op() {
        let mut rng = primitives::random::test_rng();
        let lhs = Fe::random(&mut rng);
        let rhs = Fe::random(&mut rng);
        let label = Label::task(0);

        let cases: [(FieldBinaryOp, Result<Fe, AbortError>); 7] = [
            (FieldBinaryOp::Add, Ok(lhs + rhs)),
            (FieldBinaryOp::Mul, Ok(lhs * rhs)),
            (
                FieldBinaryOp::EuclDiv,
                euclidean_division::<ScalarField<C>>(lhs, rhs, label),
            ),
            (
                FieldBinaryOp::Mod,
                modulo::<ScalarField<C>>(lhs, rhs, label),
            ),
            (FieldBinaryOp::Gt, Ok(Fe::from(lhs > rhs))),
            (FieldBinaryOp::Ge, Ok(Fe::from(lhs >= rhs))),
            (FieldBinaryOp::Eq, Ok(Fe::from(lhs == rhs))),
        ];

        for (op, expected) in cases {
            assert_eq!(op.eval::<ScalarField<C>>(lhs, rhs, label), expected);
        }
    }

    /// On 0/1 inputs, `Mul`, `Or` and `Xor` behave as AND, OR and XOR.
    #[test]
    fn test_boolean_binary_op() {
        let label = Label::task(0);
        let truth_table = [(false, false), (false, true), (true, false), (true, true)];

        for (a, b) in truth_table {
            let fa = Fe::from(a);
            let fb = Fe::from(b);

            assert_eq!(
                FieldBinaryOp::Mul.eval::<ScalarField<C>>(fa, fb, label),
                Ok(Fe::from(a && b))
            );
            assert_eq!(
                FieldBinaryOp::Or.eval::<ScalarField<C>>(fa, fb, label),
                Ok(Fe::from(a || b))
            );
            assert_eq!(
                FieldBinaryOp::Xor.eval::<ScalarField<C>>(fa, fb, label),
                Ok(Fe::from(a ^ b))
            );
        }
    }

    /// 37 / 12 on field elements equals the integer quotient.
    #[test]
    fn test_euclidian_division() {
        let dividend = Fe::from(37u32);
        let divisor = Fe::from(12u32);
        let label = Label::task(0);

        let quotient = euclidean_division::<ScalarField<C>>(dividend, divisor, label).unwrap();
        assert_eq!(quotient, Fe::from(37u32 / 12));
    }

    /// 37 mod 12 on field elements equals the integer remainder.
    #[test]
    fn test_modulo() {
        let dividend = Fe::from(37u32);
        let divisor = Fe::from(12u32);
        let label = Label::task(0);

        let remainder = modulo::<ScalarField<C>>(dividend, divisor, label).unwrap();
        assert_eq!(remainder, Fe::from(37u32 % 12));
    }

    /// Signed bit extraction of -9 matches the bits of `-9i32`.
    #[test]
    fn test_signed_bit_extract() {
        let minus_nine = -Scalar::<C>::from(9u32);
        let label = Label::task(0);

        for bit_idx in 0..5 {
            let extract = FieldUnaryOp::BitExtract {
                little_endian_bit_idx: bit_idx,
                signed: true,
            };
            let expected_bit = (-9i32 >> bit_idx) & 1 == 1;

            let extracted = extract.eval::<ScalarField<C>>(label, minus_nine);
            assert_eq!(extracted.unwrap(), expected_bit.into())
        }
    }

    /// The square root of a square, squared again, gives back the square.
    #[test]
    fn test_sqrt() {
        let mut rng = primitives::random::test_rng();
        let sample = Fe::random(&mut rng);
        let label = Label::task(0);
        let square = sample * sample;

        let root = FieldUnaryOp::Sqrt
            .eval::<ScalarField<C>>(label, square)
            .unwrap();

        assert_eq!(root * root, square)
    }

    /// Raising to the 5th power and then to `five_inv` is the identity.
    #[test]
    fn test_pow() {
        let mut rng = primitives::random::test_rng();
        let sample = SubfieldElement::<BaseField<C>>::random(&mut rng);
        let label = Label::task(0);

        let five = BoxedUint::from(vec![5u64]);
        // Presumably 5^{-1} modulo (p - 1) for the base field, given as
        // 64-bit limbs; the assertion below relies on that property.
        let five_inv = BoxedUint::from(vec![
            14757395258967641281,
            14757395258967641292,
            14757395258967641292,
            5534023222112865484,
        ]);

        let fifth_power = FieldUnaryOp::Pow { exp: five }
            .eval::<BaseField<C>>(label, sample)
            .unwrap();
        let round_trip = FieldUnaryOp::Pow { exp: five_inv }
            .eval::<BaseField<C>>(label, fifth_power)
            .unwrap();

        assert_eq!(round_trip, sample)
    }
}