Skip to main content

core_utils/circuit/v2/
ops.rs

1use std::iter::repeat_n;
2
3use ff::Field;
4use num_traits::{One, Zero};
5use primitives::{
6    algebra::{
7        elliptic_curve::{BaseFieldElement, Curve, Point, Scalar},
8        field::{subfield_element::Mersenne107Element, Bit, FieldExtension, SubfieldElement},
9        BoxedUint,
10    },
11    types::PeerNumber,
12};
13use serde::{Deserialize, Serialize};
14use typenum::Unsigned;
15use wincode::{SchemaRead, SchemaWrite};
16
17use crate::{
18    circuit::{errors::BatchSizeError, AlgebraicType, BatchSize, GateIndex, ShareOrPlaintext},
19    errors::{AbortError, FaultyPeer},
20};
21
/// Enum representing unary operations on field element plaintexts.
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    Hash,
    Serialize,
    Deserialize,
    SchemaRead,
    SchemaWrite,
    PartialOrd,
    Ord,
)]
#[repr(C)]
pub enum FieldPlaintextUnaryOp {
    /// Additive negation, `-x`.
    Neg,
    /// Computes the multiplicative inverse. Note: for 0 we return 0.
    MulInverse,
    /// Extracts a specified bit from a field element, returned as a field element.
    ///
    /// With `signed` set, an element greater than its own negation (`x > -x`) is read as a
    /// negative integer and the bit is taken from its two's-complement encoding; otherwise the
    /// bit comes from the canonical unsigned representative.
    BitExtract {
        /// Position of the bit to extract, counted from the least-significant bit (index 0).
        little_endian_bit_idx: u16,
        /// Whether to interpret the element as a signed (two's-complement) value.
        signed: bool,
    },
    /// Square root; evaluation fails when the element is a quadratic non-residue.
    Sqrt,
    /// Exponentiation, `x^exp`.
    Pow {
        /// The exponent.
        exp: BoxedUint,
    },
}
51
52impl FieldPlaintextUnaryOp {
53    // TODO: see if returning CtOption could make more sense
54    pub fn eval<F: FieldExtension>(
55        &self,
56        label: GateIndex,
57        x: &SubfieldElement<F>,
58    ) -> Result<SubfieldElement<F>, AbortError> {
59        match self {
60            FieldPlaintextUnaryOp::Neg => Ok(-x),
61            FieldPlaintextUnaryOp::MulInverse => {
62                Ok(x.invert().unwrap_or(SubfieldElement::<F>::zero()))
63            }
64            FieldPlaintextUnaryOp::BitExtract {
65                little_endian_bit_idx: idx,
66                signed,
67            } => {
68                let bit = if *signed && *x > -x {
69                    !(-SubfieldElement::<F>::one() - x)
70                        .to_biguint()
71                        .bit(*idx as u64)
72                } else {
73                    x.to_biguint().bit(*idx as u64)
74                };
75                Ok(SubfieldElement::<F>::from(bit))
76            }
77            FieldPlaintextUnaryOp::Sqrt => {
78                let (choice, sqrt) =
79                    SubfieldElement::<F>::sqrt_ratio(x, &SubfieldElement::<F>::one());
80                if !bool::from(choice) {
81                    return Err(AbortError::quadratic_non_residue(label, FaultyPeer::Local));
82                }
83                Ok(sqrt)
84            }
85            FieldPlaintextUnaryOp::Pow { exp } => Ok(x.pow(exp)),
86        }
87    }
88}
89
/// Enum representing binary operations on field element plaintexts.
///
/// Comparison variants (`Gt`, `Ge`, `Eq`) produce their boolean outcome converted into a field
/// element. `Xor` and `Or` are arithmetic formulas that agree with the boolean operations when
/// both operands are 0 or 1.
#[derive(
    Debug,
    Clone,
    Copy,
    PartialEq,
    Eq,
    Hash,
    Serialize,
    Deserialize,
    SchemaRead,
    SchemaWrite,
    PartialOrd,
    Ord,
)]
#[repr(C)]
pub enum FieldPlaintextBinaryOp {
    /// Field addition, `x + y`.
    Add,
    /// Field multiplication, `x * y`.
    Mul,
    /// Integer (floor) quotient of the canonical integer representatives; fails if `y == 0`.
    EuclDiv,
    /// Integer remainder of the canonical integer representatives; fails if `y == 0`.
    Mod,
    /// `x > y` under the element ordering.
    Gt,
    /// `x >= y` under the element ordering.
    Ge,
    /// `x == y`.
    Eq,
    /// Arithmetic XOR, `x + y - 2xy`.
    Xor,
    /// Arithmetic OR, `x + y - xy`.
    Or,
}
117
118impl FieldPlaintextBinaryOp {
119    pub fn eval<F: FieldExtension>(
120        &self,
121        x: &SubfieldElement<F>,
122        y: &SubfieldElement<F>,
123        label: GateIndex,
124    ) -> Result<SubfieldElement<F>, AbortError> {
125        match self {
126            FieldPlaintextBinaryOp::Add => Ok(x + y),
127            FieldPlaintextBinaryOp::Mul => Ok(x * y),
128            FieldPlaintextBinaryOp::EuclDiv => euclidean_division::<F>(x, y, label),
129            FieldPlaintextBinaryOp::Mod => modulo::<F>(x, y, label),
130            FieldPlaintextBinaryOp::Gt => Ok(SubfieldElement::<F>::from(x > y)),
131            FieldPlaintextBinaryOp::Ge => Ok(SubfieldElement::<F>::from(x >= y)),
132            FieldPlaintextBinaryOp::Eq => Ok(SubfieldElement::<F>::from(x == y)),
133            FieldPlaintextBinaryOp::Xor => Ok(x + y - SubfieldElement::<F>::from(2u32) * x * y),
134            FieldPlaintextBinaryOp::Or => Ok(x + y - x * y),
135        }
136    }
137}
138
139pub(crate) fn euclidean_division<F: FieldExtension>(
140    x: &SubfieldElement<F>,
141    y: &SubfieldElement<F>,
142    label: GateIndex,
143) -> Result<SubfieldElement<F>, AbortError> {
144    if *y == SubfieldElement::<F>::zero() {
145        return Err(AbortError::division_by_zero(label, FaultyPeer::Local));
146    }
147
148    // Convert to BigUint
149    let x = x.to_biguint();
150    let y = y.to_biguint();
151
152    let div = (x / y).to_bytes_be();
153    // Pad with zeroes as big-endian
154    let div = repeat_n(0, F::FieldBytesSize::USIZE - div.len())
155        .chain(div)
156        .collect::<Vec<_>>();
157
158    Ok(SubfieldElement::<F>::from_be_bytes(&div)?)
159}
160
161fn modulo<F: FieldExtension>(
162    x: &SubfieldElement<F>,
163    y: &SubfieldElement<F>,
164    label: GateIndex,
165) -> Result<SubfieldElement<F>, AbortError> {
166    if *y == SubfieldElement::<F>::zero() {
167        return Err(AbortError::division_by_zero(label, FaultyPeer::Local));
168    }
169
170    // Convert to BigUint
171    let x = x.to_biguint();
172    let y = y.to_biguint();
173
174    let modulo = x.modpow(&num_bigint::BigUint::from(1u32), &y).to_bytes_be();
175    // Pad with zeroes as big-endian
176    let modulo = repeat_n(0, F::FieldBytesSize::USIZE - modulo.len())
177        .chain(modulo)
178        .collect::<Vec<_>>();
179
180    Ok(SubfieldElement::<F>::from_be_bytes(&modulo)?)
181}
182
/// Enum representing unary operations on a field share (a secret-shared field element).
#[derive(
    Debug,
    Clone,
    Copy,
    PartialEq,
    Eq,
    Hash,
    Serialize,
    Deserialize,
    SchemaRead,
    SchemaWrite,
    PartialOrd,
    Ord,
)]
#[repr(C)]
pub enum FieldShareUnaryOp {
    /// Negation of a field share.
    Neg,
    /// Multiplicative inverse of a field share.
    MulInverse,
    /// Opens a field share to reveal the underlying value as a plaintext.
    Open,
    /// Checks if the field share is zero, returning a plaintext value.
    IsZero,
}
209
/// Enum representing binary operations on field shares. This includes the case where the second
/// operand is a plaintext value.
#[derive(
    Debug,
    Clone,
    Copy,
    PartialEq,
    Eq,
    Hash,
    Serialize,
    Deserialize,
    SchemaRead,
    SchemaWrite,
    PartialOrd,
    Ord,
)]
#[repr(C)]
pub enum FieldShareBinaryOp {
    /// Addition of a field share and a second operand (a field share or a plaintext).
    Add,
    /// Multiplication of a field share by a second operand (a field share or a plaintext).
    Mul,
}
233
/// Enum representing unary operations on binary shares (secret-shared bits).
#[derive(
    Debug,
    Clone,
    Copy,
    PartialEq,
    Eq,
    Hash,
    Serialize,
    Deserialize,
    SchemaRead,
    SchemaWrite,
    PartialOrd,
    Ord,
)]
#[repr(C)]
pub enum BitShareUnaryOp {
    /// NOT operation on a bit share.
    Not,
    /// Opens a bit share to reveal the underlying value as a plaintext.
    Open,
}
256
/// Enum representing binary operations on binary shares (secret-shared bits).
#[derive(
    Debug,
    Clone,
    Copy,
    PartialEq,
    Eq,
    Hash,
    Serialize,
    Deserialize,
    SchemaRead,
    SchemaWrite,
    PartialOrd,
    Ord,
)]
#[repr(C)]
pub enum BitShareBinaryOp {
    /// Exclusive OR operation on two bit shares.
    Xor,
    /// OR operation on two bit shares.
    Or,
    /// AND operation on two bit shares.
    And,
}
281
/// Enum representing unary operations on plaintext bits.
#[derive(
    Debug,
    Clone,
    Copy,
    PartialEq,
    Eq,
    Hash,
    Serialize,
    Deserialize,
    SchemaRead,
    SchemaWrite,
    PartialOrd,
    Ord,
)]
#[repr(C)]
pub enum BitPlaintextUnaryOp {
    /// NOT operation, evaluated as `1 - x`.
    Not,
}
302
303impl BitPlaintextUnaryOp {
304    pub fn eval(&self, x: Bit) -> Bit {
305        match self {
306            BitPlaintextUnaryOp::Not => Bit::ONE - x,
307        }
308    }
309}
310
/// Enum representing binary operations on plaintext bits.
#[derive(
    Debug,
    Clone,
    Copy,
    PartialEq,
    Eq,
    Hash,
    Serialize,
    Deserialize,
    SchemaRead,
    SchemaWrite,
    PartialOrd,
    Ord,
)]
#[repr(C)]
pub enum BitPlaintextBinaryOp {
    /// Exclusive OR operation on two bits, evaluated as `x + y`.
    Xor,
    /// OR operation on two bits, evaluated as `x + y - x * y`.
    Or,
    /// AND operation on two bits, evaluated as `x * y`.
    And,
}
335
336impl BitPlaintextBinaryOp {
337    pub fn eval(&self, x: Bit, y: Bit) -> Bit {
338        match self {
339            BitPlaintextBinaryOp::Xor => x + y,
340            BitPlaintextBinaryOp::Or => x + y - x * y,
341            BitPlaintextBinaryOp::And => x * y,
342        }
343    }
344}
345
/// Enum representing unary operations on plaintext points.
#[derive(
    Debug,
    Clone,
    Copy,
    PartialEq,
    Eq,
    Hash,
    Serialize,
    Deserialize,
    SchemaRead,
    SchemaWrite,
    PartialOrd,
    Ord,
)]
#[repr(C)]
pub enum PointPlaintextUnaryOp {
    /// Negation of a point, `-x`.
    Neg,
}
366
367impl PointPlaintextUnaryOp {
368    pub fn eval<C: Curve>(&self, x: &Point<C>) -> Result<Point<C>, AbortError> {
369        match self {
370            PointPlaintextUnaryOp::Neg => Ok(-x),
371        }
372    }
373}
374
/// Enum representing binary operations on plaintext points/scalars.
#[derive(
    Debug,
    Clone,
    Copy,
    PartialEq,
    Eq,
    Hash,
    Serialize,
    Deserialize,
    SchemaRead,
    SchemaWrite,
    PartialOrd,
    Ord,
)]
#[repr(C)]
pub enum PointPlaintextBinaryOp {
    /// Addition of two plaintext points.
    Add,
    /// Multiplication of a plaintext point by a plaintext scalar.
    ///
    /// Not evaluable through `PointPlaintextBinaryOp::eval`, which takes two points and returns
    /// an internal error for this variant.
    ScalarMul,
}
397
398impl PointPlaintextBinaryOp {
399    pub fn eval<C: Curve>(&self, x: &Point<C>, y: &Point<C>) -> Result<Point<C>, AbortError> {
400        match self {
401            PointPlaintextBinaryOp::Add => Ok(x + y),
402            PointPlaintextBinaryOp::ScalarMul => Err(AbortError::internal_error(
403                "PointPlaintextBinaryOp::eval not supported for PointPlaintextBinaryOp::ScalarMul.",
404            )),
405        }
406    }
407}
408
/// Enum representing unary operations on point shares (secret-shared curve points).
#[derive(
    Debug,
    Clone,
    Copy,
    PartialEq,
    Eq,
    Hash,
    Serialize,
    Deserialize,
    SchemaRead,
    SchemaWrite,
    PartialOrd,
    Ord,
)]
#[repr(C)]
pub enum PointShareUnaryOp {
    /// Negation of a point share.
    Neg,
    /// Opens a point share to reveal the underlying point as a plaintext.
    Open,
    /// Checks if the point share is zero, returning a plaintext value.
    IsZero,
}
433
/// Enum representing binary operations on point shares (secret-shared curve points).
#[derive(
    Debug,
    Clone,
    Copy,
    PartialEq,
    Eq,
    Hash,
    Serialize,
    Deserialize,
    SchemaRead,
    SchemaWrite,
    PartialOrd,
    Ord,
)]
#[repr(C)]
pub enum PointShareBinaryOp {
    /// Addition of two point shares.
    Add,
    /// Multiplication of a point share by a scalar.
    ScalarMul,
}
456
/// A circuit input, which can be a plaintext value, a secret plaintext value, or a secret
/// share.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, SchemaRead, SchemaWrite)]
#[repr(C)]
pub enum Input {
    /// A plaintext input; classified as `ShareOrPlaintext::Plaintext`.
    Plaintext {
        /// Algebraic type of the input values.
        algebraic_type: AlgebraicType,
        /// Number of values in this input.
        batch_size: BatchSize,
    },
    /// A secret plaintext input; classified as `ShareOrPlaintext::Share`.
    SecretPlaintext {
        /// Peer associated with this input (presumably the peer that supplies the value —
        /// confirm against the input-handling code).
        inputer: PeerNumber,
        /// Algebraic type of the input values.
        algebraic_type: AlgebraicType,
        /// Number of values in this input.
        batch_size: BatchSize,
    },
    /// A secret-shared input; classified as `ShareOrPlaintext::Share`.
    Share {
        /// Algebraic type of the input values.
        algebraic_type: AlgebraicType,
        /// Number of values in this input.
        batch_size: BatchSize,
    },
}
476
477impl Input {
478    pub fn batch_size(&self) -> u32 {
479        match self {
480            Input::Plaintext { batch_size, .. }
481            | Input::SecretPlaintext { batch_size, .. }
482            | Input::Share { batch_size, .. } => *batch_size,
483        }
484    }
485
486    pub fn algebraic_type(&self) -> AlgebraicType {
487        match self {
488            Input::Plaintext { algebraic_type, .. }
489            | Input::Share { algebraic_type, .. }
490            | Input::SecretPlaintext { algebraic_type, .. } => *algebraic_type,
491        }
492    }
493
494    pub fn share_or_plaintext(&self) -> ShareOrPlaintext {
495        match self {
496            Input::SecretPlaintext { .. } | Input::Share { .. } => ShareOrPlaintext::Share,
497            Input::Plaintext { .. } => ShareOrPlaintext::Plaintext,
498        }
499    }
500}
501
/// A constant value embedded in a circuit: either a single element or a batch of elements of one
/// algebraic type.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, SchemaRead, SchemaWrite)]
#[serde(bound(
    serialize = "Scalar<C>: Serialize, Point<C>: Serialize",
    deserialize = "Scalar<C>: Deserialize<'de>, Point<C>: Deserialize<'de>"
))]
#[repr(C)]
pub enum Constant<C: Curve> {
    /// A single scalar-field element (`AlgebraicType::ScalarField`).
    Scalar(Scalar<C>),
    /// A batch of scalar-field elements (`AlgebraicType::ScalarField`).
    ScalarBatch(Vec<Scalar<C>>),
    /// A single base-field element (`AlgebraicType::BaseField`).
    BaseField(BaseFieldElement<C>),
    /// A batch of base-field elements (`AlgebraicType::BaseField`).
    BaseFieldBatch(Vec<BaseFieldElement<C>>),
    /// A single Mersenne-107 field element (`AlgebraicType::Mersenne107`).
    Mersenne107(Mersenne107Element),
    /// A batch of Mersenne-107 field elements (`AlgebraicType::Mersenne107`).
    Mersenne107Batch(Vec<Mersenne107Element>),
    /// A single bit (`AlgebraicType::Bit`).
    Bit(Bit),
    /// A batch of bits (`AlgebraicType::Bit`).
    BitBatch(Vec<Bit>),
    /// A single curve point (`AlgebraicType::Point`).
    Point(Point<C>),
    /// A batch of curve points (`AlgebraicType::Point`).
    PointBatch(Vec<Point<C>>),
}
520
521impl<C: Curve> Constant<C> {
522    pub fn batch_size(&self) -> Result<u32, BatchSizeError> {
523        let n = match self {
524            Constant::ScalarBatch(v) => v.len(),
525            Constant::BaseFieldBatch(v) => v.len(),
526            Constant::Mersenne107Batch(v) => v.len(),
527            Constant::BitBatch(v) => v.len(),
528            Constant::PointBatch(v) => v.len(),
529            Constant::Scalar(_)
530            | Constant::BaseField(_)
531            | Constant::Mersenne107(_)
532            | Constant::Bit(_)
533            | Constant::Point(_) => 1,
534        };
535        if let Ok(n) = u32::try_from(n) {
536            Ok(n)
537        } else {
538            Err(BatchSizeError(n))
539        }
540    }
541
542    pub fn algebraic_type(&self) -> AlgebraicType {
543        match self {
544            Constant::Scalar(_) | Constant::ScalarBatch(_) => AlgebraicType::ScalarField,
545            Constant::BaseField(_) | Constant::BaseFieldBatch(_) => AlgebraicType::BaseField,
546            Constant::Mersenne107(_) | Constant::Mersenne107Batch(_) => AlgebraicType::Mersenne107,
547            Constant::Bit(_) | Constant::BitBatch(_) => AlgebraicType::Bit,
548            Constant::Point(_) | Constant::PointBatch(_) => AlgebraicType::Point,
549        }
550    }
551}
552
#[cfg(test)]
mod tests {
    use primitives::algebra::{
        elliptic_curve::{BaseField, Curve25519Ristretto as C, ScalarField},
        field::SubfieldElement,
    };

    use super::*;

    #[test]
    fn test_scalar_unary_op() {
        let mut rng = rand::thread_rng();
        let x = SubfieldElement::<ScalarField<C>>::random(&mut rng);
        let label = 0;
        let neg = FieldPlaintextUnaryOp::Neg;
        let mul_inverse = FieldPlaintextUnaryOp::MulInverse;

        assert_eq!(neg.eval::<ScalarField<C>>(label, &x), Ok(-x));
        assert_eq!(
            mul_inverse.eval::<ScalarField<C>>(label, &x),
            Ok(x.invert().unwrap())
        );
    }

    /// The documented convention: the "inverse" of zero is zero rather than an error.
    #[test]
    fn test_mul_inverse_of_zero_is_zero() {
        let zero = SubfieldElement::<ScalarField<C>>::zero();
        let label = 0;
        assert_eq!(
            FieldPlaintextUnaryOp::MulInverse.eval::<ScalarField<C>>(label, &zero),
            Ok(zero)
        );
    }

    #[test]
    fn test_scalar_binary_op() {
        let mut rng = rand::thread_rng();
        let x = SubfieldElement::<ScalarField<C>>::random(&mut rng);
        let y = SubfieldElement::<ScalarField<C>>::random(&mut rng);
        let label = 0;

        let add = FieldPlaintextBinaryOp::Add;
        let mul = FieldPlaintextBinaryOp::Mul;
        let eucl_div = FieldPlaintextBinaryOp::EuclDiv;
        let modulo_op = FieldPlaintextBinaryOp::Mod;
        let gt = FieldPlaintextBinaryOp::Gt;
        let ge = FieldPlaintextBinaryOp::Ge;
        let eq = FieldPlaintextBinaryOp::Eq;

        assert_eq!(add.eval::<ScalarField<C>>(&x, &y, label), Ok(x + y));
        assert_eq!(mul.eval::<ScalarField<C>>(&x, &y, label), Ok(x * y));
        assert_eq!(
            eucl_div.eval::<ScalarField<C>>(&x, &y, label),
            euclidean_division::<ScalarField<C>>(&x, &y, label)
        );
        assert_eq!(
            modulo_op.eval::<ScalarField<C>>(&x, &y, label),
            modulo::<ScalarField<C>>(&x, &y, label)
        );
        assert_eq!(
            gt.eval::<ScalarField<C>>(&x, &y, label),
            Ok(SubfieldElement::<ScalarField<C>>::from(x > y))
        );
        assert_eq!(
            ge.eval::<ScalarField<C>>(&x, &y, label),
            Ok(SubfieldElement::<ScalarField<C>>::from(x >= y))
        );
        assert_eq!(
            eq.eval::<ScalarField<C>>(&x, &y, label),
            Ok(SubfieldElement::<ScalarField<C>>::from(x == y))
        );
    }

    #[test]
    fn test_scalar_boolean_binary_op() {
        let and = FieldPlaintextBinaryOp::Mul;
        let or = FieldPlaintextBinaryOp::Or;
        let xor = FieldPlaintextBinaryOp::Xor;
        let label = 0;
        for bool_x in [false, true] {
            for bool_y in [false, true] {
                let scalar_x = SubfieldElement::<ScalarField<C>>::from(bool_x);
                let scalar_y = SubfieldElement::<ScalarField<C>>::from(bool_y);
                assert_eq!(
                    and.eval::<ScalarField<C>>(&scalar_x, &scalar_y, label),
                    Ok((bool_x && bool_y).into())
                );
                assert_eq!(
                    or.eval::<ScalarField<C>>(&scalar_x, &scalar_y, label),
                    Ok((bool_x || bool_y).into())
                );
                assert_eq!(
                    xor.eval::<ScalarField<C>>(&scalar_x, &scalar_y, label),
                    Ok((bool_x ^ bool_y).into())
                );
            }
        }
    }

    #[test]
    fn test_bit_ops() {
        let not = BitPlaintextUnaryOp::Not;
        for bool_x in [false, true] {
            let x = Bit::from(bool_x);
            assert_eq!(not.eval(x), (!bool_x).into());
        }

        let and = BitPlaintextBinaryOp::And;
        let or = BitPlaintextBinaryOp::Or;
        let xor = BitPlaintextBinaryOp::Xor;
        for bool_x in [false, true] {
            for bool_y in [false, true] {
                let x = Bit::from(bool_x);
                let y = Bit::from(bool_y);
                assert_eq!(and.eval(x, y), (bool_x && bool_y).into());
                assert_eq!(or.eval(x, y), (bool_x || bool_y).into());
                assert_eq!(xor.eval(x, y), (bool_x ^ bool_y).into());
            }
        }
    }

    #[test]
    fn test_euclidean_division() {
        let x = SubfieldElement::<ScalarField<C>>::from(37u32);
        let y = SubfieldElement::<ScalarField<C>>::from(12u32);
        let label = 0;

        let result = euclidean_division::<ScalarField<C>>(&x, &y, label).unwrap();
        assert_eq!(result, SubfieldElement::<ScalarField<C>>::from(37u32 / 12));
    }

    #[test]
    fn test_modulo() {
        let x = SubfieldElement::<ScalarField<C>>::from(37u32);
        let y = SubfieldElement::<ScalarField<C>>::from(12u32);
        let label = 0;

        let result = modulo::<ScalarField<C>>(&x, &y, label).unwrap();
        assert_eq!(result, SubfieldElement::<ScalarField<C>>::from(37u32 % 12));
    }

    /// A dividend smaller than the divisor yields quotient 0 and leaves the dividend as remainder.
    #[test]
    fn test_division_with_smaller_dividend() {
        let x = SubfieldElement::<ScalarField<C>>::from(5u32);
        let y = SubfieldElement::<ScalarField<C>>::from(12u32);
        let zero = SubfieldElement::<ScalarField<C>>::zero();
        let label = 0;

        assert_eq!(
            euclidean_division::<ScalarField<C>>(&x, &y, label).unwrap(),
            zero
        );
        assert_eq!(modulo::<ScalarField<C>>(&x, &y, label).unwrap(), x);
    }

    #[test]
    fn test_division_by_zero_is_rejected() {
        let x = SubfieldElement::<ScalarField<C>>::from(37u32);
        let zero = SubfieldElement::<ScalarField<C>>::zero();
        let label = 0;

        assert!(euclidean_division::<ScalarField<C>>(&x, &zero, label).is_err());
        assert!(modulo::<ScalarField<C>>(&x, &zero, label).is_err());
        assert!(FieldPlaintextBinaryOp::EuclDiv
            .eval::<ScalarField<C>>(&x, &zero, label)
            .is_err());
        assert!(FieldPlaintextBinaryOp::Mod
            .eval::<ScalarField<C>>(&x, &zero, label)
            .is_err());
    }

    /// For a small positive value, signed and unsigned extraction read the same bits.
    #[test]
    fn test_unsigned_bit_extract() {
        let x = SubfieldElement::<ScalarField<C>>::from(37u32);
        let label = 0;
        for signed in [false, true] {
            for i in 0..8 {
                let op = FieldPlaintextUnaryOp::BitExtract {
                    little_endian_bit_idx: i,
                    signed,
                };
                let result = op.eval::<ScalarField<C>>(label, &x);
                assert_eq!(result.unwrap(), ((37u32 >> i) & 1 == 1).into())
            }
        }
    }

    #[test]
    fn test_signed_bit_extract() {
        let x = -Scalar::<C>::from(9u32);
        let label = 0;
        for i in 0..5 {
            let op = FieldPlaintextUnaryOp::BitExtract {
                little_endian_bit_idx: i,
                signed: true,
            };
            let result = op.eval::<ScalarField<C>>(label, &x);
            assert_eq!(result.unwrap(), ((-9i32 >> i) & 1 == 1).into())
        }
    }

    #[test]
    fn test_sqrt() {
        let mut rng = rand::thread_rng();
        let x = SubfieldElement::<ScalarField<C>>::random(&mut rng);
        let label = 0;
        let result = FieldPlaintextUnaryOp::Sqrt
            .eval::<ScalarField<C>>(label, &(x * x))
            .unwrap();

        assert_eq!(result * result, x * x)
    }

    #[test]
    fn test_pow() {
        let mut rng = rand::thread_rng();
        let x = SubfieldElement::<BaseField<C>>::random(&mut rng);
        let label = 0;
        let five = BoxedUint::from(vec![5u64]);
        let five_inv = BoxedUint::from(vec![
            14757395258967641281,
            14757395258967641292,
            14757395258967641292,
            5534023222112865484,
        ]);
        let x_pow_5 = FieldPlaintextUnaryOp::Pow { exp: five }
            .eval::<BaseField<C>>(label, &x)
            .unwrap();
        let x_again = FieldPlaintextUnaryOp::Pow { exp: five_inv }
            .eval::<BaseField<C>>(label, &x_pow_5)
            .unwrap();

        assert_eq!(x_again, x)
    }
}