core_utils/circuit/ops.rs

1use std::iter::repeat_n;
2
3use ff::Field;
4use num_traits::{One, Zero};
5use primitives::{
6    algebra::{
7        elliptic_curve::{Curve, Point, Scalar},
8        field::{FieldExtension, SubfieldElement},
9        BoxedUint,
10    },
11    types::PeerNumber,
12};
13use serde::{Deserialize, Serialize};
14use typenum::Unsigned;
15
16use crate::{
17    circuit::{
18        AlgebraicType,
19        BaseFieldPlaintext,
20        BaseFieldPlaintextBatch,
21        BitPlaintext,
22        BitPlaintextBatch,
23        PointPlaintext,
24        PointPlaintextBatch,
25        ScalarPlaintext,
26        ScalarPlaintextBatch,
27    },
28    errors::{AbortError, FaultyPeer},
29    types::Label,
30};
31
32/// Enum representing unary operations on field element plaintexts.
33#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
34pub enum FieldPlaintextUnaryOp {
35    Neg,
36    // Computes the multiplicative inverse. Note: for 0 we return 0.
37    MulInverse,
38    // Extracts a specified bit from a field element
39    BitExtract {
40        little_endian_bit_idx: u16,
41        signed: bool,
42    },
43    Sqrt,
44    Pow {
45        exp: BoxedUint,
46    },
47}
48
49impl FieldPlaintextUnaryOp {
50    // TODO: see if returning CtOption could make more sense
51    pub fn eval<F: FieldExtension>(
52        &self,
53        label: Label,
54        x: SubfieldElement<F>,
55    ) -> Result<SubfieldElement<F>, AbortError> {
56        match self {
57            FieldPlaintextUnaryOp::Neg => Ok(-x),
58            FieldPlaintextUnaryOp::MulInverse => {
59                if x == SubfieldElement::<F>::zero() {
60                    Ok(SubfieldElement::<F>::zero())
61                } else {
62                    Ok(x.invert().unwrap())
63                }
64            }
65            FieldPlaintextUnaryOp::BitExtract {
66                little_endian_bit_idx: idx,
67                signed,
68            } => {
69                let bit = if *signed && x > -x {
70                    !(-SubfieldElement::<F>::one() - x)
71                        .to_biguint()
72                        .bit(*idx as u64)
73                } else {
74                    x.to_biguint().bit(*idx as u64)
75                };
76                Ok(SubfieldElement::<F>::from(bit))
77            }
78            FieldPlaintextUnaryOp::Sqrt => {
79                let (choice, sqrt) =
80                    SubfieldElement::<F>::sqrt_ratio(&x, &SubfieldElement::<F>::one());
81                if !bool::from(choice) {
82                    return Err(AbortError::quadratic_non_residue(label, FaultyPeer::Local));
83                }
84                Ok(sqrt)
85            }
86            FieldPlaintextUnaryOp::Pow { exp } => Ok(x.pow(exp)),
87        }
88    }
89}
90
91/// Enum representing binary operations on field element plaintexts.
92#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
93pub enum FieldPlaintextBinaryOp {
94    Add,
95    Mul,
96    EuclDiv,
97    Mod,
98    Gt,
99    Ge,
100    Eq,
101    Xor,
102    Or,
103}
104
105impl FieldPlaintextBinaryOp {
106    pub fn eval<F: FieldExtension>(
107        &self,
108        x: SubfieldElement<F>,
109        y: SubfieldElement<F>,
110        label: Label,
111    ) -> Result<SubfieldElement<F>, AbortError> {
112        match self {
113            FieldPlaintextBinaryOp::Add => Ok(x + y),
114            FieldPlaintextBinaryOp::Mul => Ok(x * y),
115            FieldPlaintextBinaryOp::EuclDiv => euclidean_division::<F>(x, y, label),
116            FieldPlaintextBinaryOp::Mod => modulo::<F>(x, y, label),
117            FieldPlaintextBinaryOp::Gt => Ok(SubfieldElement::<F>::from(x > y)),
118            FieldPlaintextBinaryOp::Ge => Ok(SubfieldElement::<F>::from(x >= y)),
119            FieldPlaintextBinaryOp::Eq => Ok(SubfieldElement::<F>::from(x == y)),
120            FieldPlaintextBinaryOp::Xor => Ok(x + y - SubfieldElement::<F>::from(2u32) * x * y),
121            FieldPlaintextBinaryOp::Or => Ok(x + y - x * y),
122        }
123    }
124}
125
126pub(crate) fn euclidean_division<F: FieldExtension>(
127    x: SubfieldElement<F>,
128    y: SubfieldElement<F>,
129    label: Label,
130) -> Result<SubfieldElement<F>, AbortError> {
131    if y == SubfieldElement::<F>::zero() {
132        return Err(AbortError::division_by_zero(label, FaultyPeer::Local));
133    }
134
135    // Convert to BigUint
136    let x = x.to_biguint();
137    let y = y.to_biguint();
138
139    let div = (x / y).to_bytes_be();
140    // Pad with zeroes as big-endian
141    let div = repeat_n(0, F::FieldBytesSize::USIZE - div.len())
142        .chain(div)
143        .collect::<Vec<_>>();
144
145    Ok(SubfieldElement::<F>::from_be_bytes(&div).unwrap())
146}
147
148fn modulo<F: FieldExtension>(
149    x: SubfieldElement<F>,
150    y: SubfieldElement<F>,
151    label: Label,
152) -> Result<SubfieldElement<F>, AbortError> {
153    if y == SubfieldElement::<F>::zero() {
154        return Err(AbortError::division_by_zero(label, FaultyPeer::Local));
155    }
156
157    // Convert to BigUint
158    let x = x.to_biguint();
159    let y = y.to_biguint();
160
161    let modulo = x.modpow(&num_bigint::BigUint::from(1u32), &y).to_bytes_be();
162    // Pad with zeroes as big-endian
163    let modulo = repeat_n(0, F::FieldBytesSize::USIZE - modulo.len())
164        .chain(modulo)
165        .collect::<Vec<_>>();
166
167    Ok(SubfieldElement::<F>::from_be_bytes(&modulo).unwrap())
168}
169
170/// Enum representing unary operations on a field share.
171#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
172pub enum FieldShareUnaryOp {
173    /// Negation of a field share.
174    Neg,
175    /// Multiplicative inverse of a field share.
176    MulInverse,
177    /// Opens a field share to reveal the underlying value.
178    Open,
179    /// Checks if the field share is zero, returning a plaintext value.
180    IsZero,
181}
182
183/// Enum representing binary operations on field shares. This includes the case where the second
184/// operand is a plaintext value.
185#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
186pub enum FieldShareBinaryOp {
187    /// Addition of two field shares.
188    Add,
189    /// Multiplication of two field shares.
190    Mul,
191}
192
193/// Enum representing unary operations on binary shares
194#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
195pub enum BitShareUnaryOp {
196    /// NOT operation
197    Not,
198    /// Opens a bit share to reveal the underlying value.
199    Open,
200}
201
202/// Enum representing binary operations on binary shares.
203#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
204pub enum BitShareBinaryOp {
205    /// Exclusive OR operation on two bit shares.
206    Xor,
207    /// OR operation on two bit shares.
208    Or,
209    /// AND operation on two bit shares.
210    And,
211}
212
213/// Enum representing unary operations on plaintext points.
214#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
215pub enum PointPlaintextUnaryOp {
216    /// Negation of a point.
217    Neg,
218}
219
220impl PointPlaintextUnaryOp {
221    pub fn eval<C: Curve>(&self, x: Point<C>) -> Result<Point<C>, AbortError> {
222        match self {
223            PointPlaintextUnaryOp::Neg => Ok(-x),
224        }
225    }
226}
227
228/// Enum representing binary operations on plaintext points/scalars.
229#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
230pub enum PointPlaintextBinaryOp {
231    /// Addition of two plaintext points.
232    Add,
233    /// Multiplication of a plaintext point by a plaintext scalar.
234    ScalarMul,
235}
236
237impl PointPlaintextBinaryOp {
238    pub fn eval<C: Curve>(&self, x: Point<C>, y: Point<C>) -> Result<Point<C>, AbortError> {
239        match self {
240            PointPlaintextBinaryOp::Add => Ok(x + y),
241            PointPlaintextBinaryOp::ScalarMul => Err(AbortError::internal_error(
242                "PointPlaintextBinaryOp::eval not supported for PointPlaintextBinaryOp::ScalarMul.",
243            )),
244        }
245    }
246}
247
248/// Enum representing unary operations on point shares.
249#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
250pub enum PointShareUnaryOp {
251    /// Negation of a point share.
252    Neg,
253    /// Opens a point share to reveal the underlying value.
254    Open,
255    /// Checks if the point share is zero, returning a plaintext value.
256    IsZero,
257}
258
259/// Enum representing binary operations on point shares.
260#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
261pub enum PointShareBinaryOp {
262    /// Addition of two point shares.
263    Add,
264    /// Multiplication of a point share by a scalar.
265    ScalarMul,
266}
267
268#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
269#[serde(bound(
270    serialize = "Scalar<C>: Serialize, Point<C>: Serialize",
271    deserialize = "Scalar<C>: Deserialize<'de>, Point<C>: Deserialize<'de>"
272))]
273pub enum Input<C: Curve> {
274    SecretPlaintext {
275        inputer: PeerNumber,
276        algebraic_type: AlgebraicType,
277        batched: Batched,
278    },
279    Share {
280        algebraic_type: AlgebraicType,
281        batched: Batched,
282    },
283    RandomShare {
284        algebraic_type: AlgebraicType,
285        batched: Batched,
286    },
287    Scalar(ScalarPlaintext<C>),
288    ScalarBatch(ScalarPlaintextBatch<C>),
289    BaseField(BaseFieldPlaintext<C>),
290    BaseFieldBatch(BaseFieldPlaintextBatch<C>),
291    Bit(BitPlaintext),
292    BitBatch(BitPlaintextBatch),
293    Point(PointPlaintext<C>),
294    PointBatch(PointPlaintextBatch<C>),
295    ElGamalCiphertext {
296        c: PointPlaintext<C>,
297        r: PointPlaintext<C>,
298    },
299}
300
301impl<C: Curve> Input<C> {
302    pub fn batched(&self) -> Batched {
303        match self {
304            Input::SecretPlaintext { batched, .. } => *batched,
305            Input::Share { batched, .. } => *batched,
306            Input::RandomShare { batched, .. } => *batched,
307            Input::ScalarBatch(input) => Batched::Yes(input.len()),
308            Input::BaseFieldBatch(input) => Batched::Yes(input.len()),
309            Input::BitBatch(input) => Batched::Yes(input.len()),
310            Input::PointBatch(input) => Batched::Yes(input.len()),
311            Input::ElGamalCiphertext { .. } => Batched::No,
312            Input::Scalar { .. } => Batched::No,
313            Input::BaseField { .. } => Batched::No,
314            Input::Bit { .. } => Batched::No,
315            Input::Point { .. } => Batched::No,
316        }
317    }
318}
319
320/// Enum representing whether a gate is batched or not.
321#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
322pub enum Batched {
323    Yes(usize),
324    No,
325}
326
327impl Batched {
328    pub fn count(&self) -> usize {
329        match self {
330            Batched::Yes(count) => *count,
331            Batched::No => 1,
332        }
333    }
334
335    pub fn is_batched(&self) -> bool {
336        match self {
337            Batched::Yes(_) => true,
338            Batched::No => false,
339        }
340    }
341}
342
#[cfg(test)]
mod tests {
    use primitives::algebra::{
        elliptic_curve::{BaseField, Curve25519Ristretto as C, ScalarField},
        field::SubfieldElement,
    };

    use super::*;

    #[test]
    fn test_scalar_unary_op() {
        let mut rng = rand::thread_rng();
        let x = SubfieldElement::<ScalarField<C>>::random(&mut rng);
        let label = Label::task(0);
        let neg = FieldPlaintextUnaryOp::Neg;
        let mul_inverse = FieldPlaintextUnaryOp::MulInverse;

        assert_eq!(neg.eval::<ScalarField<C>>(label, x), Ok(-x));
        assert_eq!(
            mul_inverse.eval::<ScalarField<C>>(label, x),
            Ok(x.invert().unwrap())
        );
    }

    #[test]
    fn test_mul_inverse_of_zero_is_zero() {
        let zero = SubfieldElement::<ScalarField<C>>::zero();
        let label = Label::task(0);

        assert_eq!(
            FieldPlaintextUnaryOp::MulInverse.eval::<ScalarField<C>>(label, zero),
            Ok(zero)
        );
    }

    #[test]
    fn test_scalar_binary_op() {
        let mut rng = rand::thread_rng();
        let x = SubfieldElement::<ScalarField<C>>::random(&mut rng);
        let y = SubfieldElement::<ScalarField<C>>::random(&mut rng);
        let label = Label::task(0);

        let add = FieldPlaintextBinaryOp::Add;
        let mul = FieldPlaintextBinaryOp::Mul;
        let eucl_div = FieldPlaintextBinaryOp::EuclDiv;
        let modulo_op = FieldPlaintextBinaryOp::Mod;
        let gt = FieldPlaintextBinaryOp::Gt;
        let ge = FieldPlaintextBinaryOp::Ge;
        let eq = FieldPlaintextBinaryOp::Eq;

        assert_eq!(add.eval::<ScalarField<C>>(x, y, label), Ok(x + y));
        assert_eq!(mul.eval::<ScalarField<C>>(x, y, label), Ok(x * y));
        assert_eq!(
            eucl_div.eval::<ScalarField<C>>(x, y, label),
            euclidean_division::<ScalarField<C>>(x, y, label)
        );
        assert_eq!(
            modulo_op.eval::<ScalarField<C>>(x, y, label),
            modulo::<ScalarField<C>>(x, y, label)
        );
        assert_eq!(
            gt.eval::<ScalarField<C>>(x, y, label),
            Ok(SubfieldElement::<ScalarField<C>>::from(x > y))
        );
        assert_eq!(
            ge.eval::<ScalarField<C>>(x, y, label),
            Ok(SubfieldElement::<ScalarField<C>>::from(x >= y))
        );
        assert_eq!(
            eq.eval::<ScalarField<C>>(x, y, label),
            Ok(SubfieldElement::<ScalarField<C>>::from(x == y))
        );
    }

    #[test]
    fn test_boolean_binary_op() {
        let and = FieldPlaintextBinaryOp::Mul;
        let or = FieldPlaintextBinaryOp::Or;
        let xor = FieldPlaintextBinaryOp::Xor;
        let label = Label::task(0);
        for bool_x in [false, true] {
            for bool_y in [false, true] {
                let scalar_x = SubfieldElement::<ScalarField<C>>::from(bool_x);
                let scalar_y = SubfieldElement::<ScalarField<C>>::from(bool_y);
                assert_eq!(
                    and.eval::<ScalarField<C>>(scalar_x, scalar_y, label),
                    Ok((bool_x && bool_y).into())
                );
                assert_eq!(
                    or.eval::<ScalarField<C>>(scalar_x, scalar_y, label),
                    Ok((bool_x || bool_y).into())
                );
                assert_eq!(
                    xor.eval::<ScalarField<C>>(scalar_x, scalar_y, label),
                    Ok((bool_x ^ bool_y).into())
                );
            }
        }
    }

    #[test]
    fn test_euclidean_division() {
        let x = SubfieldElement::<ScalarField<C>>::from(37u32);
        let y = SubfieldElement::<ScalarField<C>>::from(12u32);
        let label = Label::task(0);

        let result = euclidean_division::<ScalarField<C>>(x, y, label).unwrap();
        assert_eq!(result, SubfieldElement::<ScalarField<C>>::from(37u32 / 12));

        // A dividend smaller than the divisor has quotient zero.
        let small = SubfieldElement::<ScalarField<C>>::from(5u32);
        let result = euclidean_division::<ScalarField<C>>(small, y, label).unwrap();
        assert_eq!(result, SubfieldElement::<ScalarField<C>>::zero());
    }

    #[test]
    fn test_modulo() {
        let x = SubfieldElement::<ScalarField<C>>::from(37u32);
        let y = SubfieldElement::<ScalarField<C>>::from(12u32);
        let label = Label::task(0);

        let result = modulo::<ScalarField<C>>(x, y, label).unwrap();
        assert_eq!(result, SubfieldElement::<ScalarField<C>>::from(37u32 % 12));

        // A dividend smaller than the divisor is its own remainder.
        let small = SubfieldElement::<ScalarField<C>>::from(5u32);
        assert_eq!(modulo::<ScalarField<C>>(small, y, label).unwrap(), small);

        // An exact multiple leaves no remainder.
        let multiple = SubfieldElement::<ScalarField<C>>::from(36u32);
        assert_eq!(
            modulo::<ScalarField<C>>(multiple, y, label).unwrap(),
            SubfieldElement::<ScalarField<C>>::zero()
        );
    }

    #[test]
    fn test_division_by_zero_is_rejected() {
        let x = SubfieldElement::<ScalarField<C>>::from(37u32);
        let zero = SubfieldElement::<ScalarField<C>>::zero();
        let label = Label::task(0);

        assert!(euclidean_division::<ScalarField<C>>(x, zero, label).is_err());
        assert!(modulo::<ScalarField<C>>(x, zero, label).is_err());
        assert!(FieldPlaintextBinaryOp::EuclDiv
            .eval::<ScalarField<C>>(x, zero, label)
            .is_err());
        assert!(FieldPlaintextBinaryOp::Mod
            .eval::<ScalarField<C>>(x, zero, label)
            .is_err());
    }

    #[test]
    fn test_unsigned_bit_extract() {
        let x = SubfieldElement::<ScalarField<C>>::from(9u32);
        let label = Label::task(0);
        for i in 0..5 {
            let op = FieldPlaintextUnaryOp::BitExtract {
                little_endian_bit_idx: i,
                signed: false,
            };
            let result = op.eval::<ScalarField<C>>(label, x);
            assert_eq!(result.unwrap(), ((9u32 >> i) & 1 == 1).into())
        }
    }

    #[test]
    fn test_signed_bit_extract() {
        let x = -Scalar::<C>::from(9u32);
        let label = Label::task(0);
        for i in 0..5 {
            let op = FieldPlaintextUnaryOp::BitExtract {
                little_endian_bit_idx: i,
                signed: true,
            };
            let result = op.eval::<ScalarField<C>>(label, x);
            assert_eq!(result.unwrap(), ((-9i32 >> i) & 1 == 1).into())
        }
    }

    #[test]
    fn test_sqrt() {
        let mut rng = rand::thread_rng();
        let x = SubfieldElement::<ScalarField<C>>::random(&mut rng);
        let label = Label::task(0);
        let result = FieldPlaintextUnaryOp::Sqrt
            .eval::<ScalarField<C>>(label, x * x)
            .unwrap();

        assert_eq!(result * result, x * x)
    }

    #[test]
    fn test_sqrt_of_non_residue_fails() {
        // The base field has prime modulus 2^255 - 19, which is 5 mod 8, so 2 is a quadratic
        // non-residue.
        let two = SubfieldElement::<BaseField<C>>::from(2u32);
        let label = Label::task(0);

        assert!(FieldPlaintextUnaryOp::Sqrt
            .eval::<BaseField<C>>(label, two)
            .is_err());
    }

    #[test]
    fn test_pow() {
        let mut rng = rand::thread_rng();
        let x = SubfieldElement::<BaseField<C>>::random(&mut rng);
        let label = Label::task(0);
        let five = BoxedUint::from(vec![5u64]);
        // 5^{-1} mod (p - 1) for p = 2^255 - 19, as little-endian 64-bit limbs.
        let five_inv = BoxedUint::from(vec![
            14757395258967641281,
            14757395258967641292,
            14757395258967641292,
            5534023222112865484,
        ]);
        let x_pow_5 = FieldPlaintextUnaryOp::Pow { exp: five }
            .eval::<BaseField<C>>(label, x)
            .unwrap();
        let x_again = FieldPlaintextUnaryOp::Pow { exp: five_inv }
            .eval::<BaseField<C>>(label, x_pow_5)
            .unwrap();

        assert_eq!(x_again, x)
    }
}