core_utils/circuit/
ops.rs

1use std::iter::repeat_n;
2
3use ff::Field;
4use num_traits::{One, Zero};
5use primitives::{
6    algebra::{
7        elliptic_curve::{Curve, Point, Scalar},
8        field::{FieldExtension, SubfieldElement},
9        BoxedUint,
10    },
11    types::PeerNumber,
12};
13use serde::{Deserialize, Serialize};
14use typenum::Unsigned;
15
16use crate::{
17    circuit::{
18        AlgebraicType,
19        BaseFieldPlaintext,
20        BaseFieldPlaintextBatch,
21        BitPlaintext,
22        BitPlaintextBatch,
23        Mersenne107Plaintext,
24        Mersenne107PlaintextBatch,
25        PointPlaintext,
26        PointPlaintextBatch,
27        ScalarPlaintext,
28        ScalarPlaintextBatch,
29    },
30    errors::{AbortError, FaultyPeer},
31    types::Label,
32};
33
34/// Enum representing unary operations on field element plaintexts.
35#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
36pub enum FieldPlaintextUnaryOp {
37    Neg,
38    // Computes the multiplicative inverse. Note: for 0 we return 0.
39    MulInverse,
40    // Extracts a specified bit from a field element
41    BitExtract {
42        little_endian_bit_idx: u16,
43        signed: bool,
44    },
45    Sqrt,
46    Pow {
47        exp: BoxedUint,
48    },
49}
50
51impl FieldPlaintextUnaryOp {
52    // TODO: see if returning CtOption could make more sense
53    pub fn eval<F: FieldExtension>(
54        &self,
55        label: Label,
56        x: SubfieldElement<F>,
57    ) -> Result<SubfieldElement<F>, AbortError> {
58        match self {
59            FieldPlaintextUnaryOp::Neg => Ok(-x),
60            FieldPlaintextUnaryOp::MulInverse => {
61                if x == SubfieldElement::<F>::zero() {
62                    Ok(SubfieldElement::<F>::zero())
63                } else {
64                    Ok(x.invert().unwrap())
65                }
66            }
67            FieldPlaintextUnaryOp::BitExtract {
68                little_endian_bit_idx: idx,
69                signed,
70            } => {
71                let bit = if *signed && x > -x {
72                    !(-SubfieldElement::<F>::one() - x)
73                        .to_biguint()
74                        .bit(*idx as u64)
75                } else {
76                    x.to_biguint().bit(*idx as u64)
77                };
78                Ok(SubfieldElement::<F>::from(bit))
79            }
80            FieldPlaintextUnaryOp::Sqrt => {
81                let (choice, sqrt) =
82                    SubfieldElement::<F>::sqrt_ratio(&x, &SubfieldElement::<F>::one());
83                if !bool::from(choice) {
84                    return Err(AbortError::quadratic_non_residue(label, FaultyPeer::Local));
85                }
86                Ok(sqrt)
87            }
88            FieldPlaintextUnaryOp::Pow { exp } => Ok(x.pow(exp)),
89        }
90    }
91}
92
93/// Enum representing binary operations on field element plaintexts.
94#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
95pub enum FieldPlaintextBinaryOp {
96    Add,
97    Mul,
98    EuclDiv,
99    Mod,
100    Gt,
101    Ge,
102    Eq,
103    Xor,
104    Or,
105}
106
107impl FieldPlaintextBinaryOp {
108    pub fn eval<F: FieldExtension>(
109        &self,
110        x: SubfieldElement<F>,
111        y: SubfieldElement<F>,
112        label: Label,
113    ) -> Result<SubfieldElement<F>, AbortError> {
114        match self {
115            FieldPlaintextBinaryOp::Add => Ok(x + y),
116            FieldPlaintextBinaryOp::Mul => Ok(x * y),
117            FieldPlaintextBinaryOp::EuclDiv => euclidean_division::<F>(x, y, label),
118            FieldPlaintextBinaryOp::Mod => modulo::<F>(x, y, label),
119            FieldPlaintextBinaryOp::Gt => Ok(SubfieldElement::<F>::from(x > y)),
120            FieldPlaintextBinaryOp::Ge => Ok(SubfieldElement::<F>::from(x >= y)),
121            FieldPlaintextBinaryOp::Eq => Ok(SubfieldElement::<F>::from(x == y)),
122            FieldPlaintextBinaryOp::Xor => Ok(x + y - SubfieldElement::<F>::from(2u32) * x * y),
123            FieldPlaintextBinaryOp::Or => Ok(x + y - x * y),
124        }
125    }
126}
127
128pub(crate) fn euclidean_division<F: FieldExtension>(
129    x: SubfieldElement<F>,
130    y: SubfieldElement<F>,
131    label: Label,
132) -> Result<SubfieldElement<F>, AbortError> {
133    if y == SubfieldElement::<F>::zero() {
134        return Err(AbortError::division_by_zero(label, FaultyPeer::Local));
135    }
136
137    // Convert to BigUint
138    let x = x.to_biguint();
139    let y = y.to_biguint();
140
141    let div = (x / y).to_bytes_be();
142    // Pad with zeroes as big-endian
143    let div = repeat_n(0, F::FieldBytesSize::USIZE - div.len())
144        .chain(div)
145        .collect::<Vec<_>>();
146
147    Ok(SubfieldElement::<F>::from_be_bytes(&div).unwrap())
148}
149
150fn modulo<F: FieldExtension>(
151    x: SubfieldElement<F>,
152    y: SubfieldElement<F>,
153    label: Label,
154) -> Result<SubfieldElement<F>, AbortError> {
155    if y == SubfieldElement::<F>::zero() {
156        return Err(AbortError::division_by_zero(label, FaultyPeer::Local));
157    }
158
159    // Convert to BigUint
160    let x = x.to_biguint();
161    let y = y.to_biguint();
162
163    let modulo = x.modpow(&num_bigint::BigUint::from(1u32), &y).to_bytes_be();
164    // Pad with zeroes as big-endian
165    let modulo = repeat_n(0, F::FieldBytesSize::USIZE - modulo.len())
166        .chain(modulo)
167        .collect::<Vec<_>>();
168
169    Ok(SubfieldElement::<F>::from_be_bytes(&modulo).unwrap())
170}
171
172/// Enum representing unary operations on a field share.
173#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
174pub enum FieldShareUnaryOp {
175    /// Negation of a field share.
176    Neg,
177    /// Multiplicative inverse of a field share.
178    MulInverse,
179    /// Opens a field share to reveal the underlying value.
180    Open,
181    /// Checks if the field share is zero, returning a plaintext value.
182    IsZero,
183}
184
185/// Enum representing binary operations on field shares. This includes the case where the second
186/// operand is a plaintext value.
187#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
188pub enum FieldShareBinaryOp {
189    /// Addition of two field shares.
190    Add,
191    /// Multiplication of two field shares.
192    Mul,
193}
194
195/// Enum representing unary operations on binary shares
196#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
197pub enum BitShareUnaryOp {
198    /// NOT operation
199    Not,
200    /// Opens a bit share to reveal the underlying value.
201    Open,
202}
203
204/// Enum representing binary operations on binary shares.
205#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
206pub enum BitShareBinaryOp {
207    /// Exclusive OR operation on two bit shares.
208    Xor,
209    /// OR operation on two bit shares.
210    Or,
211    /// AND operation on two bit shares.
212    And,
213}
214
215/// Enum representing unary operations on plaintext points.
216#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
217pub enum PointPlaintextUnaryOp {
218    /// Negation of a point.
219    Neg,
220}
221
222impl PointPlaintextUnaryOp {
223    pub fn eval<C: Curve>(&self, x: Point<C>) -> Result<Point<C>, AbortError> {
224        match self {
225            PointPlaintextUnaryOp::Neg => Ok(-x),
226        }
227    }
228}
229
230/// Enum representing binary operations on plaintext points/scalars.
231#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
232pub enum PointPlaintextBinaryOp {
233    /// Addition of two plaintext points.
234    Add,
235    /// Multiplication of a plaintext point by a plaintext scalar.
236    ScalarMul,
237}
238
239impl PointPlaintextBinaryOp {
240    pub fn eval<C: Curve>(&self, x: Point<C>, y: Point<C>) -> Result<Point<C>, AbortError> {
241        match self {
242            PointPlaintextBinaryOp::Add => Ok(x + y),
243            PointPlaintextBinaryOp::ScalarMul => Err(AbortError::internal_error(
244                "PointPlaintextBinaryOp::eval not supported for PointPlaintextBinaryOp::ScalarMul.",
245            )),
246        }
247    }
248}
249
250/// Enum representing unary operations on point shares.
251#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
252pub enum PointShareUnaryOp {
253    /// Negation of a point share.
254    Neg,
255    /// Opens a point share to reveal the underlying value.
256    Open,
257    /// Checks if the point share is zero, returning a plaintext value.
258    IsZero,
259}
260
261/// Enum representing binary operations on point shares.
262#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
263pub enum PointShareBinaryOp {
264    /// Addition of two point shares.
265    Add,
266    /// Multiplication of a point share by a scalar.
267    ScalarMul,
268}
269
270#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
271#[serde(bound(
272    serialize = "Scalar<C>: Serialize, Point<C>: Serialize",
273    deserialize = "Scalar<C>: Deserialize<'de>, Point<C>: Deserialize<'de>"
274))]
275#[derive(Hash)]
276pub enum Input<C: Curve> {
277    SecretPlaintext {
278        inputer: PeerNumber,
279        algebraic_type: AlgebraicType,
280        batched: Batched,
281    },
282    Share {
283        algebraic_type: AlgebraicType,
284        batched: Batched,
285    },
286    RandomShare {
287        algebraic_type: AlgebraicType,
288        batched: Batched,
289    },
290    Scalar(ScalarPlaintext<C>),
291    ScalarBatch(ScalarPlaintextBatch<C>),
292    BaseField(BaseFieldPlaintext<C>),
293    BaseFieldBatch(BaseFieldPlaintextBatch<C>),
294    Mersenne107(Mersenne107Plaintext),
295    Mersenne107Batch(Mersenne107PlaintextBatch),
296    Bit(BitPlaintext),
297    BitBatch(BitPlaintextBatch),
298    Point(PointPlaintext<C>),
299    PointBatch(PointPlaintextBatch<C>),
300}
301
302impl<C: Curve> Input<C> {
303    pub fn batched(&self) -> Batched {
304        match self {
305            Input::SecretPlaintext { batched, .. } => *batched,
306            Input::Share { batched, .. } => *batched,
307            Input::RandomShare { batched, .. } => *batched,
308            Input::ScalarBatch(input) => Batched::Yes(input.len()),
309            Input::BaseFieldBatch(input) => Batched::Yes(input.len()),
310            Input::Mersenne107Batch(input) => Batched::Yes(input.len()),
311            Input::BitBatch(input) => Batched::Yes(input.len()),
312            Input::PointBatch(input) => Batched::Yes(input.len()),
313            Input::Scalar { .. } => Batched::No,
314            Input::BaseField { .. } => Batched::No,
315            Input::Mersenne107 { .. } => Batched::No,
316            Input::Bit { .. } => Batched::No,
317            Input::Point { .. } => Batched::No,
318        }
319    }
320}
321
322/// Enum representing whether a gate is batched or not.
323#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
324pub enum Batched {
325    Yes(usize),
326    No,
327}
328
329impl Batched {
330    pub fn count(&self) -> usize {
331        match self {
332            Batched::Yes(count) => *count,
333            Batched::No => 1,
334        }
335    }
336
337    pub fn is_batched(&self) -> bool {
338        match self {
339            Batched::Yes(_) => true,
340            Batched::No => false,
341        }
342    }
343}
344
#[cfg(test)]
mod tests {
    use primitives::algebra::{
        elliptic_curve::{BaseField, Curve25519Ristretto as C, ScalarField},
        field::SubfieldElement,
    };

    use super::*;

    /// `Neg` and `MulInverse` agree with the corresponding field operations on a random
    /// element.
    #[test]
    fn test_scalar_unary_op() {
        let mut rng = rand::thread_rng();
        let x = SubfieldElement::<ScalarField<C>>::random(&mut rng);
        let label = Label::task(0);
        let neg = FieldPlaintextUnaryOp::Neg;
        let mul_inverse = FieldPlaintextUnaryOp::MulInverse;

        assert_eq!(neg.eval::<ScalarField<C>>(label, x), Ok(-x));
        assert_eq!(
            mul_inverse.eval::<ScalarField<C>>(label, x),
            Ok(x.invert().unwrap())
        );
    }

    /// The inverse of zero is defined to be zero rather than an error.
    #[test]
    fn test_mul_inverse_of_zero() {
        let zero = SubfieldElement::<ScalarField<C>>::zero();
        let label = Label::task(0);

        assert_eq!(
            FieldPlaintextUnaryOp::MulInverse.eval::<ScalarField<C>>(label, zero),
            Ok(zero)
        );
    }

    /// Each arithmetic/comparison operation dispatches to the expected computation.
    #[test]
    fn test_scalar_binary_op() {
        let mut rng = rand::thread_rng();
        let x = SubfieldElement::<ScalarField<C>>::random(&mut rng);
        let y = SubfieldElement::<ScalarField<C>>::random(&mut rng);
        let label = Label::task(0);

        let add = FieldPlaintextBinaryOp::Add;
        let mul = FieldPlaintextBinaryOp::Mul;
        let eucl_div = FieldPlaintextBinaryOp::EuclDiv;
        let modulo_op = FieldPlaintextBinaryOp::Mod;
        let gt = FieldPlaintextBinaryOp::Gt;
        let ge = FieldPlaintextBinaryOp::Ge;
        let eq = FieldPlaintextBinaryOp::Eq;

        assert_eq!(add.eval::<ScalarField<C>>(x, y, label), Ok(x + y));
        assert_eq!(mul.eval::<ScalarField<C>>(x, y, label), Ok(x * y));
        assert_eq!(
            eucl_div.eval::<ScalarField<C>>(x, y, label),
            euclidean_division::<ScalarField<C>>(x, y, label)
        );
        assert_eq!(
            modulo_op.eval::<ScalarField<C>>(x, y, label),
            modulo::<ScalarField<C>>(x, y, label)
        );
        assert_eq!(
            gt.eval::<ScalarField<C>>(x, y, label),
            Ok(SubfieldElement::<ScalarField<C>>::from(x > y))
        );
        assert_eq!(
            ge.eval::<ScalarField<C>>(x, y, label),
            Ok(SubfieldElement::<ScalarField<C>>::from(x >= y))
        );
        assert_eq!(
            eq.eval::<ScalarField<C>>(x, y, label),
            Ok(SubfieldElement::<ScalarField<C>>::from(x == y))
        );
    }

    /// On 0/1 inputs, `Mul`, `Or` and `Xor` behave as boolean AND, OR and XOR.
    #[test]
    fn test_boolean_binary_op() {
        let and = FieldPlaintextBinaryOp::Mul;
        let or = FieldPlaintextBinaryOp::Or;
        let xor = FieldPlaintextBinaryOp::Xor;
        let label = Label::task(0);
        for bool_x in [false, true] {
            for bool_y in [false, true] {
                let scalar_x = SubfieldElement::<ScalarField<C>>::from(bool_x);
                let scalar_y = SubfieldElement::<ScalarField<C>>::from(bool_y);
                assert_eq!(
                    and.eval::<ScalarField<C>>(scalar_x, scalar_y, label),
                    Ok((bool_x && bool_y).into())
                );
                assert_eq!(
                    or.eval::<ScalarField<C>>(scalar_x, scalar_y, label),
                    Ok((bool_x || bool_y).into())
                );
                assert_eq!(
                    xor.eval::<ScalarField<C>>(scalar_x, scalar_y, label),
                    Ok((bool_x ^ bool_y).into())
                );
            }
        }
    }

    /// Euclidean division matches integer division on small values.
    #[test]
    fn test_euclidean_division() {
        let x = SubfieldElement::<ScalarField<C>>::from(37u32);
        let y = SubfieldElement::<ScalarField<C>>::from(12u32);
        let label = Label::task(0);

        let result = euclidean_division::<ScalarField<C>>(x, y, label).unwrap();
        assert_eq!(result, SubfieldElement::<ScalarField<C>>::from(37u32 / 12));
    }

    /// Modulo matches the integer remainder on small values.
    #[test]
    fn test_modulo() {
        let x = SubfieldElement::<ScalarField<C>>::from(37u32);
        let y = SubfieldElement::<ScalarField<C>>::from(12u32);
        let label = Label::task(0);

        let result = modulo::<ScalarField<C>>(x, y, label).unwrap();
        assert_eq!(result, SubfieldElement::<ScalarField<C>>::from(37u32 % 12));
    }

    /// A zero divisor is reported as an error by both division helpers.
    #[test]
    fn test_division_by_zero() {
        let x = SubfieldElement::<ScalarField<C>>::from(37u32);
        let zero = SubfieldElement::<ScalarField<C>>::zero();
        let label = Label::task(0);

        assert!(euclidean_division::<ScalarField<C>>(x, zero, label).is_err());
        assert!(modulo::<ScalarField<C>>(x, zero, label).is_err());
    }

    /// Signed extraction on a negative element yields the two's-complement bits.
    #[test]
    fn test_signed_bit_extract() {
        let x = -Scalar::<C>::from(9u32);
        let label = Label::task(0);
        for i in 0..5 {
            let op = FieldPlaintextUnaryOp::BitExtract {
                little_endian_bit_idx: i,
                signed: true,
            };
            let result = op.eval::<ScalarField<C>>(label, x);
            assert_eq!(result.unwrap(), ((-9i32 >> i) & 1 == 1).into())
        }
    }

    /// Unsigned extraction yields the plain binary digits of the representative.
    #[test]
    fn test_unsigned_bit_extract() {
        let x = SubfieldElement::<ScalarField<C>>::from(9u32);
        let label = Label::task(0);
        for i in 0..5 {
            let op = FieldPlaintextUnaryOp::BitExtract {
                little_endian_bit_idx: i,
                signed: false,
            };
            let result = op.eval::<ScalarField<C>>(label, x);
            assert_eq!(result.unwrap(), ((9u32 >> i) & 1 == 1).into())
        }
    }

    /// The square root of a square, squared again, gives back the square.
    #[test]
    fn test_sqrt() {
        let mut rng = rand::thread_rng();
        let x = SubfieldElement::<ScalarField<C>>::random(&mut rng);
        let label = Label::task(0);
        let result = FieldPlaintextUnaryOp::Sqrt
            .eval::<ScalarField<C>>(label, x * x)
            .unwrap();

        assert_eq!(result * result, x * x)
    }

    /// Raising to the 5th power and then to the inverse exponent round-trips.
    #[test]
    fn test_pow() {
        let mut rng = rand::thread_rng();
        let x = SubfieldElement::<BaseField<C>>::random(&mut rng);
        let label = Label::task(0);
        let five = BoxedUint::from(vec![5u64]);
        let five_inv = BoxedUint::from(vec![
            14757395258967641281,
            14757395258967641292,
            14757395258967641292,
            5534023222112865484,
        ]);
        let x_pow_5 = FieldPlaintextUnaryOp::Pow { exp: five }
            .eval::<BaseField<C>>(label, x)
            .unwrap();
        let x_again = FieldPlaintextUnaryOp::Pow { exp: five_inv }
            .eval::<BaseField<C>>(label, x_pow_5)
            .unwrap();

        assert_eq!(x_again, x)
    }
}