// core_utils/circuit/ops.rs

1use std::iter::repeat_n;
2
3use ff::Field;
4use num_traits::{One, Zero};
5use primitives::{
6    algebra::{
7        elliptic_curve::{Curve, Point, Scalar},
8        field::{FieldExtension, SubfieldElement},
9        BoxedUint,
10    },
11    types::PeerNumber,
12};
13use serde::{Deserialize, Serialize};
14use typenum::Unsigned;
15
16use crate::{
17    circuit::{
18        AlgebraicType,
19        BaseFieldPlaintext,
20        BaseFieldPlaintextBatch,
21        BitPlaintext,
22        BitPlaintextBatch,
23        Mersenne107Plaintext,
24        Mersenne107PlaintextBatch,
25        PointPlaintext,
26        PointPlaintextBatch,
27        ScalarPlaintext,
28        ScalarPlaintextBatch,
29    },
30    errors::{AbortError, FaultyPeer},
31    types::Label,
32};
33
34/// Enum representing unary operations on field element plaintexts.
35#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
36pub enum FieldPlaintextUnaryOp {
37    Neg,
38    // Computes the multiplicative inverse. Note: for 0 we return 0.
39    MulInverse,
40    // Extracts a specified bit from a field element
41    BitExtract {
42        little_endian_bit_idx: u16,
43        signed: bool,
44    },
45    Sqrt,
46    Pow {
47        exp: BoxedUint,
48    },
49}
50
51impl FieldPlaintextUnaryOp {
52    // TODO: see if returning CtOption could make more sense
53    pub fn eval<F: FieldExtension>(
54        &self,
55        label: Label,
56        x: SubfieldElement<F>,
57    ) -> Result<SubfieldElement<F>, AbortError> {
58        match self {
59            FieldPlaintextUnaryOp::Neg => Ok(-x),
60            FieldPlaintextUnaryOp::MulInverse => {
61                if x == SubfieldElement::<F>::zero() {
62                    Ok(SubfieldElement::<F>::zero())
63                } else {
64                    Ok(x.invert().unwrap())
65                }
66            }
67            FieldPlaintextUnaryOp::BitExtract {
68                little_endian_bit_idx: idx,
69                signed,
70            } => {
71                let bit = if *signed && x > -x {
72                    !(-SubfieldElement::<F>::one() - x)
73                        .to_biguint()
74                        .bit(*idx as u64)
75                } else {
76                    x.to_biguint().bit(*idx as u64)
77                };
78                Ok(SubfieldElement::<F>::from(bit))
79            }
80            FieldPlaintextUnaryOp::Sqrt => {
81                let (choice, sqrt) =
82                    SubfieldElement::<F>::sqrt_ratio(&x, &SubfieldElement::<F>::one());
83                if !bool::from(choice) {
84                    return Err(AbortError::quadratic_non_residue(label, FaultyPeer::Local));
85                }
86                Ok(sqrt)
87            }
88            FieldPlaintextUnaryOp::Pow { exp } => Ok(x.pow(exp)),
89        }
90    }
91}
92
93/// Enum representing binary operations on field element plaintexts.
94#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
95pub enum FieldPlaintextBinaryOp {
96    Add,
97    Mul,
98    EuclDiv,
99    Mod,
100    Gt,
101    Ge,
102    Eq,
103    Xor,
104    Or,
105}
106
107impl FieldPlaintextBinaryOp {
108    pub fn eval<F: FieldExtension>(
109        &self,
110        x: SubfieldElement<F>,
111        y: SubfieldElement<F>,
112        label: Label,
113    ) -> Result<SubfieldElement<F>, AbortError> {
114        match self {
115            FieldPlaintextBinaryOp::Add => Ok(x + y),
116            FieldPlaintextBinaryOp::Mul => Ok(x * y),
117            FieldPlaintextBinaryOp::EuclDiv => euclidean_division::<F>(x, y, label),
118            FieldPlaintextBinaryOp::Mod => modulo::<F>(x, y, label),
119            FieldPlaintextBinaryOp::Gt => Ok(SubfieldElement::<F>::from(x > y)),
120            FieldPlaintextBinaryOp::Ge => Ok(SubfieldElement::<F>::from(x >= y)),
121            FieldPlaintextBinaryOp::Eq => Ok(SubfieldElement::<F>::from(x == y)),
122            FieldPlaintextBinaryOp::Xor => Ok(x + y - SubfieldElement::<F>::from(2u32) * x * y),
123            FieldPlaintextBinaryOp::Or => Ok(x + y - x * y),
124        }
125    }
126}
127
128pub(crate) fn euclidean_division<F: FieldExtension>(
129    x: SubfieldElement<F>,
130    y: SubfieldElement<F>,
131    label: Label,
132) -> Result<SubfieldElement<F>, AbortError> {
133    if y == SubfieldElement::<F>::zero() {
134        return Err(AbortError::division_by_zero(label, FaultyPeer::Local));
135    }
136
137    // Convert to BigUint
138    let x = x.to_biguint();
139    let y = y.to_biguint();
140
141    let div = (x / y).to_bytes_be();
142    // Pad with zeroes as big-endian
143    let div = repeat_n(0, F::FieldBytesSize::USIZE - div.len())
144        .chain(div)
145        .collect::<Vec<_>>();
146
147    Ok(SubfieldElement::<F>::from_be_bytes(&div).unwrap())
148}
149
150fn modulo<F: FieldExtension>(
151    x: SubfieldElement<F>,
152    y: SubfieldElement<F>,
153    label: Label,
154) -> Result<SubfieldElement<F>, AbortError> {
155    if y == SubfieldElement::<F>::zero() {
156        return Err(AbortError::division_by_zero(label, FaultyPeer::Local));
157    }
158
159    // Convert to BigUint
160    let x = x.to_biguint();
161    let y = y.to_biguint();
162
163    let modulo = x.modpow(&num_bigint::BigUint::from(1u32), &y).to_bytes_be();
164    // Pad with zeroes as big-endian
165    let modulo = repeat_n(0, F::FieldBytesSize::USIZE - modulo.len())
166        .chain(modulo)
167        .collect::<Vec<_>>();
168
169    Ok(SubfieldElement::<F>::from_be_bytes(&modulo).unwrap())
170}
171
172/// Enum representing unary operations on a field share.
173#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
174pub enum FieldShareUnaryOp {
175    /// Negation of a field share.
176    Neg,
177    /// Multiplicative inverse of a field share.
178    MulInverse,
179    /// Opens a field share to reveal the underlying value.
180    Open,
181    /// Checks if the field share is zero, returning a plaintext value.
182    IsZero,
183}
184
185/// Enum representing binary operations on field shares. This includes the case where the second
186/// operand is a plaintext value.
187#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
188pub enum FieldShareBinaryOp {
189    /// Addition of two field shares.
190    Add,
191    /// Multiplication of two field shares.
192    Mul,
193}
194
195/// Enum representing unary operations on binary shares
196#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
197pub enum BitShareUnaryOp {
198    /// NOT operation
199    Not,
200    /// Opens a bit share to reveal the underlying value.
201    Open,
202}
203
204/// Enum representing binary operations on binary shares.
205#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
206pub enum BitShareBinaryOp {
207    /// Exclusive OR operation on two bit shares.
208    Xor,
209    /// OR operation on two bit shares.
210    Or,
211    /// AND operation on two bit shares.
212    And,
213}
214
215/// Enum representing unary operations on plaintext points.
216#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
217pub enum PointPlaintextUnaryOp {
218    /// Negation of a point.
219    Neg,
220}
221
222impl PointPlaintextUnaryOp {
223    pub fn eval<C: Curve>(&self, x: Point<C>) -> Result<Point<C>, AbortError> {
224        match self {
225            PointPlaintextUnaryOp::Neg => Ok(-x),
226        }
227    }
228}
229
230/// Enum representing binary operations on plaintext points/scalars.
231#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
232pub enum PointPlaintextBinaryOp {
233    /// Addition of two plaintext points.
234    Add,
235    /// Multiplication of a plaintext point by a plaintext scalar.
236    ScalarMul,
237}
238
239impl PointPlaintextBinaryOp {
240    pub fn eval<C: Curve>(&self, x: Point<C>, y: Point<C>) -> Result<Point<C>, AbortError> {
241        match self {
242            PointPlaintextBinaryOp::Add => Ok(x + y),
243            PointPlaintextBinaryOp::ScalarMul => Err(AbortError::internal_error(
244                "PointPlaintextBinaryOp::eval not supported for PointPlaintextBinaryOp::ScalarMul.",
245            )),
246        }
247    }
248}
249
250/// Enum representing unary operations on point shares.
251#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
252pub enum PointShareUnaryOp {
253    /// Negation of a point share.
254    Neg,
255    /// Opens a point share to reveal the underlying value.
256    Open,
257    /// Checks if the point share is zero, returning a plaintext value.
258    IsZero,
259}
260
261/// Enum representing binary operations on point shares.
262#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
263pub enum PointShareBinaryOp {
264    /// Addition of two point shares.
265    Add,
266    /// Multiplication of a point share by a scalar.
267    ScalarMul,
268}
269
270#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
271#[serde(bound(
272    serialize = "Scalar<C>: Serialize, Point<C>: Serialize",
273    deserialize = "Scalar<C>: Deserialize<'de>, Point<C>: Deserialize<'de>"
274))]
275#[derive(Hash)]
276pub enum Input<C: Curve> {
277    SecretPlaintext {
278        inputer: PeerNumber,
279        algebraic_type: AlgebraicType,
280        batched: Batched,
281    },
282    Share {
283        algebraic_type: AlgebraicType,
284        batched: Batched,
285    },
286    RandomShare {
287        algebraic_type: AlgebraicType,
288        batched: Batched,
289    },
290    Scalar(ScalarPlaintext<C>),
291    ScalarBatch(ScalarPlaintextBatch<C>),
292    BaseField(BaseFieldPlaintext<C>),
293    BaseFieldBatch(BaseFieldPlaintextBatch<C>),
294    Mersenne107(Mersenne107Plaintext),
295    Mersenne107Batch(Mersenne107PlaintextBatch),
296    Bit(BitPlaintext),
297    BitBatch(BitPlaintextBatch),
298    Point(PointPlaintext<C>),
299    PointBatch(PointPlaintextBatch<C>),
300    ElGamalCiphertext {
301        c: PointPlaintext<C>,
302        r: PointPlaintext<C>,
303    },
304}
305
306impl<C: Curve> Input<C> {
307    pub fn batched(&self) -> Batched {
308        match self {
309            Input::SecretPlaintext { batched, .. } => *batched,
310            Input::Share { batched, .. } => *batched,
311            Input::RandomShare { batched, .. } => *batched,
312            Input::ScalarBatch(input) => Batched::Yes(input.len()),
313            Input::BaseFieldBatch(input) => Batched::Yes(input.len()),
314            Input::Mersenne107Batch(input) => Batched::Yes(input.len()),
315            Input::BitBatch(input) => Batched::Yes(input.len()),
316            Input::PointBatch(input) => Batched::Yes(input.len()),
317            Input::ElGamalCiphertext { .. } => Batched::No,
318            Input::Scalar { .. } => Batched::No,
319            Input::BaseField { .. } => Batched::No,
320            Input::Mersenne107 { .. } => Batched::No,
321            Input::Bit { .. } => Batched::No,
322            Input::Point { .. } => Batched::No,
323        }
324    }
325}
326
327/// Enum representing whether a gate is batched or not.
328#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
329pub enum Batched {
330    Yes(usize),
331    No,
332}
333
334impl Batched {
335    pub fn count(&self) -> usize {
336        match self {
337            Batched::Yes(count) => *count,
338            Batched::No => 1,
339        }
340    }
341
342    pub fn is_batched(&self) -> bool {
343        match self {
344            Batched::Yes(_) => true,
345            Batched::No => false,
346        }
347    }
348}
349
#[cfg(test)]
mod tests {
    use primitives::algebra::{
        elliptic_curve::{BaseField, Curve25519Ristretto as C, ScalarField},
        field::SubfieldElement,
    };

    use super::*;

    #[test]
    fn test_scalar_unary_op() {
        let mut rng = rand::thread_rng();
        let x = SubfieldElement::<ScalarField<C>>::random(&mut rng);
        let label = Label::task(0);
        let neg = FieldPlaintextUnaryOp::Neg;
        let mul_inverse = FieldPlaintextUnaryOp::MulInverse;

        assert_eq!(neg.eval::<ScalarField<C>>(label, x), Ok(-x));
        assert_eq!(
            mul_inverse.eval::<ScalarField<C>>(label, x),
            Ok(x.invert().unwrap())
        );
    }

    #[test]
    fn test_mul_inverse_of_zero_is_zero() {
        let zero = SubfieldElement::<ScalarField<C>>::from(0u32);
        let label = Label::task(0);

        assert_eq!(
            FieldPlaintextUnaryOp::MulInverse.eval::<ScalarField<C>>(label, zero),
            Ok(zero)
        );
    }

    #[test]
    fn test_scalar_binary_op() {
        let mut rng = rand::thread_rng();
        let x = SubfieldElement::<ScalarField<C>>::random(&mut rng);
        let y = SubfieldElement::<ScalarField<C>>::random(&mut rng);
        let label = Label::task(0);

        let add = FieldPlaintextBinaryOp::Add;
        let mul = FieldPlaintextBinaryOp::Mul;
        let eucl_div = FieldPlaintextBinaryOp::EuclDiv;
        let modulo_op = FieldPlaintextBinaryOp::Mod;
        let gt = FieldPlaintextBinaryOp::Gt;
        let ge = FieldPlaintextBinaryOp::Ge;
        let eq = FieldPlaintextBinaryOp::Eq;

        assert_eq!(add.eval::<ScalarField<C>>(x, y, label), Ok(x + y));
        assert_eq!(mul.eval::<ScalarField<C>>(x, y, label), Ok(x * y));
        assert_eq!(
            eucl_div.eval::<ScalarField<C>>(x, y, label),
            euclidean_division::<ScalarField<C>>(x, y, label)
        );
        assert_eq!(
            modulo_op.eval::<ScalarField<C>>(x, y, label),
            modulo::<ScalarField<C>>(x, y, label)
        );
        assert_eq!(
            gt.eval::<ScalarField<C>>(x, y, label),
            Ok(SubfieldElement::<ScalarField<C>>::from(x > y))
        );
        assert_eq!(
            ge.eval::<ScalarField<C>>(x, y, label),
            Ok(SubfieldElement::<ScalarField<C>>::from(x >= y))
        );
        assert_eq!(
            eq.eval::<ScalarField<C>>(x, y, label),
            Ok(SubfieldElement::<ScalarField<C>>::from(x == y))
        );
    }

    #[test]
    fn test_boolean_binary_op() {
        let and = FieldPlaintextBinaryOp::Mul;
        let or = FieldPlaintextBinaryOp::Or;
        let xor = FieldPlaintextBinaryOp::Xor;
        let label = Label::task(0);
        for bool_x in [false, true] {
            for bool_y in [false, true] {
                let scalar_x = SubfieldElement::<ScalarField<C>>::from(bool_x);
                let scalar_y = SubfieldElement::<ScalarField<C>>::from(bool_y);
                assert_eq!(
                    and.eval::<ScalarField<C>>(scalar_x, scalar_y, label),
                    Ok((bool_x && bool_y).into())
                );
                assert_eq!(
                    or.eval::<ScalarField<C>>(scalar_x, scalar_y, label),
                    Ok((bool_x || bool_y).into())
                );
                assert_eq!(
                    xor.eval::<ScalarField<C>>(scalar_x, scalar_y, label),
                    Ok((bool_x ^ bool_y).into())
                );
            }
        }
    }

    #[test]
    fn test_euclidean_division() {
        let x = SubfieldElement::<ScalarField<C>>::from(37u32);
        let y = SubfieldElement::<ScalarField<C>>::from(12u32);
        let label = Label::task(0);

        let result = euclidean_division::<ScalarField<C>>(x, y, label).unwrap();
        assert_eq!(result, SubfieldElement::<ScalarField<C>>::from(37u32 / 12));
    }

    #[test]
    fn test_modulo() {
        let x = SubfieldElement::<ScalarField<C>>::from(37u32);
        let y = SubfieldElement::<ScalarField<C>>::from(12u32);
        let label = Label::task(0);

        let result = modulo::<ScalarField<C>>(x, y, label).unwrap();
        assert_eq!(result, SubfieldElement::<ScalarField<C>>::from(37u32 % 12));
    }

    #[test]
    fn test_division_with_smaller_dividend() {
        let x = SubfieldElement::<ScalarField<C>>::from(5u32);
        let y = SubfieldElement::<ScalarField<C>>::from(12u32);
        let label = Label::task(0);

        assert_eq!(modulo::<ScalarField<C>>(x, y, label), Ok(x));
        assert_eq!(
            euclidean_division::<ScalarField<C>>(x, y, label),
            Ok(SubfieldElement::<ScalarField<C>>::from(0u32))
        );
    }

    #[test]
    fn test_division_by_zero_is_rejected() {
        let x = SubfieldElement::<ScalarField<C>>::from(37u32);
        let zero = SubfieldElement::<ScalarField<C>>::from(0u32);
        let label = Label::task(0);

        assert!(euclidean_division::<ScalarField<C>>(x, zero, label).is_err());
        assert!(modulo::<ScalarField<C>>(x, zero, label).is_err());
        assert!(FieldPlaintextBinaryOp::EuclDiv
            .eval::<ScalarField<C>>(x, zero, label)
            .is_err());
        assert!(FieldPlaintextBinaryOp::Mod
            .eval::<ScalarField<C>>(x, zero, label)
            .is_err());
    }

    #[test]
    fn test_signed_bit_extract() {
        let x = -Scalar::<C>::from(9u32);
        let label = Label::task(0);
        for i in 0..5 {
            let op = FieldPlaintextUnaryOp::BitExtract {
                little_endian_bit_idx: i,
                signed: true,
            };
            let result = op.eval::<ScalarField<C>>(label, x);
            assert_eq!(result.unwrap(), ((-9i32 >> i) & 1 == 1).into())
        }
    }

    #[test]
    fn test_unsigned_bit_extract() {
        // A small positive value has the same bits whether read signed or unsigned.
        let x = SubfieldElement::<ScalarField<C>>::from(9u32);
        let label = Label::task(0);
        for i in 0..5 {
            for signed in [false, true] {
                let op = FieldPlaintextUnaryOp::BitExtract {
                    little_endian_bit_idx: i,
                    signed,
                };
                let result = op.eval::<ScalarField<C>>(label, x);
                assert_eq!(result.unwrap(), ((9u32 >> i) & 1 == 1).into())
            }
        }
    }

    #[test]
    fn test_sqrt() {
        let mut rng = rand::thread_rng();
        let x = SubfieldElement::<ScalarField<C>>::random(&mut rng);
        let label = Label::task(0);
        let result = FieldPlaintextUnaryOp::Sqrt
            .eval::<ScalarField<C>>(label, x * x)
            .unwrap();

        assert_eq!(result * result, x * x)
    }

    #[test]
    fn test_pow() {
        let mut rng = rand::thread_rng();
        let x = SubfieldElement::<BaseField<C>>::random(&mut rng);
        let label = Label::task(0);
        let five = BoxedUint::from(vec![5u64]);
        let five_inv = BoxedUint::from(vec![
            14757395258967641281,
            14757395258967641292,
            14757395258967641292,
            5534023222112865484,
        ]);
        let x_pow_5 = FieldPlaintextUnaryOp::Pow { exp: five }
            .eval::<BaseField<C>>(label, x)
            .unwrap();
        let x_again = FieldPlaintextUnaryOp::Pow { exp: five_inv }
            .eval::<BaseField<C>>(label, x_pow_5)
            .unwrap();

        assert_eq!(x_again, x)
    }
}