core_utils/circuit/ops.rs

1use std::iter::repeat_n;
2
3use ff::Field;
4use num_traits::{One, Zero};
5use primitives::{
6    algebra::{
7        elliptic_curve::{Curve, Point, Scalar},
8        field::{FieldExtension, SubfieldElement},
9        BoxedUint,
10    },
11    types::PeerNumber,
12};
13use serde::{Deserialize, Serialize};
14use typenum::Unsigned;
15
16use crate::{
17    circuit::{
18        AlgebraicType,
19        BaseFieldPlaintext,
20        BaseFieldPlaintextBatch,
21        BitPlaintext,
22        BitPlaintextBatch,
23        PointPlaintext,
24        PointPlaintextBatch,
25        ScalarPlaintext,
26        ScalarPlaintextBatch,
27    },
28    errors::{AbortError, FaultyPeer},
29    types::Label,
30};
31
/// Enum representing unary operations on field element plaintexts.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum FieldPlaintextUnaryOp {
    /// Additive negation, `-x`.
    Neg,
    /// Computes the multiplicative inverse. Note: for 0 we return 0 instead of failing.
    MulInverse,
    /// Extracts a specified bit from the integer representative of a field element.
    BitExtract {
        /// Position of the bit to extract, counted from the least-significant bit (index 0).
        little_endian_bit_idx: u16,
        /// When `true`, an element greater than its own negation is treated as negative and
        /// the bit is read from its two's-complement representation.
        signed: bool,
    },
    /// Square root. Evaluation fails when the input is a quadratic non-residue.
    Sqrt,
    /// Exponentiation, `x^exp`.
    Pow {
        /// The (non-negative) exponent.
        exp: BoxedUint,
    },
}
48
49impl FieldPlaintextUnaryOp {
50    // TODO: see if returning CtOption could make more sense
51    pub fn eval<F: FieldExtension>(
52        &self,
53        label: Label,
54        x: SubfieldElement<F>,
55    ) -> Result<SubfieldElement<F>, AbortError> {
56        match self {
57            FieldPlaintextUnaryOp::Neg => Ok(-x),
58            FieldPlaintextUnaryOp::MulInverse => {
59                if x == SubfieldElement::<F>::zero() {
60                    Ok(SubfieldElement::<F>::zero())
61                } else {
62                    Ok(x.invert().unwrap())
63                }
64            }
65            FieldPlaintextUnaryOp::BitExtract {
66                little_endian_bit_idx: idx,
67                signed,
68            } => {
69                let bit = if *signed && x > -x {
70                    !(-SubfieldElement::<F>::one() - x)
71                        .to_biguint()
72                        .bit(*idx as u64)
73                } else {
74                    x.to_biguint().bit(*idx as u64)
75                };
76                Ok(SubfieldElement::<F>::from(bit))
77            }
78            FieldPlaintextUnaryOp::Sqrt => {
79                let (choice, sqrt) =
80                    SubfieldElement::<F>::sqrt_ratio(&x, &SubfieldElement::<F>::one());
81                if !bool::from(choice) {
82                    return Err(AbortError::quadratic_non_residue(label, FaultyPeer::Local));
83                }
84                Ok(sqrt)
85            }
86            FieldPlaintextUnaryOp::Pow { exp } => Ok(x.pow(exp)),
87        }
88    }
89}
90
/// Enum representing binary operations on field element plaintexts.
///
/// Comparison variants produce the field elements 0 or 1 (via `From<bool>`).
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum FieldPlaintextBinaryOp {
    /// Field addition, `x + y`.
    Add,
    /// Field multiplication, `x * y`; coincides with boolean AND on 0/1 inputs.
    Mul,
    /// Integer (Euclidean) quotient of the operands' integer representatives.
    /// Fails with a division-by-zero error when `y` is zero.
    EuclDiv,
    /// Integer remainder of the operands' integer representatives.
    /// Fails with a division-by-zero error when `y` is zero.
    Mod,
    /// `x > y`, using the ordering of `SubfieldElement`'s `PartialOrd` implementation.
    Gt,
    /// `x >= y`, using the ordering of `SubfieldElement`'s `PartialOrd` implementation.
    Ge,
    /// Equality test, `x == y`.
    Eq,
    /// Arithmetic XOR, `x + y - 2xy`; equals boolean XOR only when both inputs are 0 or 1.
    Xor,
    /// Arithmetic OR, `x + y - xy`; equals boolean OR only when both inputs are 0 or 1.
    Or,
}
104
105impl FieldPlaintextBinaryOp {
106    pub fn eval<F: FieldExtension>(
107        &self,
108        x: SubfieldElement<F>,
109        y: SubfieldElement<F>,
110        label: Label,
111    ) -> Result<SubfieldElement<F>, AbortError> {
112        match self {
113            FieldPlaintextBinaryOp::Add => Ok(x + y),
114            FieldPlaintextBinaryOp::Mul => Ok(x * y),
115            FieldPlaintextBinaryOp::EuclDiv => euclidean_division::<F>(x, y, label),
116            FieldPlaintextBinaryOp::Mod => modulo::<F>(x, y, label),
117            FieldPlaintextBinaryOp::Gt => Ok(SubfieldElement::<F>::from(x > y)),
118            FieldPlaintextBinaryOp::Ge => Ok(SubfieldElement::<F>::from(x >= y)),
119            FieldPlaintextBinaryOp::Eq => Ok(SubfieldElement::<F>::from(x == y)),
120            FieldPlaintextBinaryOp::Xor => Ok(x + y - SubfieldElement::<F>::from(2u32) * x * y),
121            FieldPlaintextBinaryOp::Or => Ok(x + y - x * y),
122        }
123    }
124}
125
126pub(crate) fn euclidean_division<F: FieldExtension>(
127    x: SubfieldElement<F>,
128    y: SubfieldElement<F>,
129    label: Label,
130) -> Result<SubfieldElement<F>, AbortError> {
131    if y == SubfieldElement::<F>::zero() {
132        return Err(AbortError::division_by_zero(label, FaultyPeer::Local));
133    }
134
135    // Convert to BigUint
136    let x = x.to_biguint();
137    let y = y.to_biguint();
138
139    let div = (x / y).to_bytes_be();
140    // Pad with zeroes as big-endian
141    let div = repeat_n(0, F::FieldBytesSize::USIZE - div.len())
142        .chain(div)
143        .collect::<Vec<_>>();
144
145    Ok(SubfieldElement::<F>::from_be_bytes(&div).unwrap())
146}
147
148fn modulo<F: FieldExtension>(
149    x: SubfieldElement<F>,
150    y: SubfieldElement<F>,
151    label: Label,
152) -> Result<SubfieldElement<F>, AbortError> {
153    if y == SubfieldElement::<F>::zero() {
154        return Err(AbortError::division_by_zero(label, FaultyPeer::Local));
155    }
156
157    // Convert to BigUint
158    let x = x.to_biguint();
159    let y = y.to_biguint();
160
161    let modulo = x.modpow(&num_bigint::BigUint::from(1u32), &y).to_bytes_be();
162    // Pad with zeroes as big-endian
163    let modulo = repeat_n(0, F::FieldBytesSize::USIZE - modulo.len())
164        .chain(modulo)
165        .collect::<Vec<_>>();
166
167    Ok(SubfieldElement::<F>::from_be_bytes(&modulo).unwrap())
168}
169
/// Enum representing unary operations on a field share.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum FieldShareUnaryOp {
    /// Negation of a field share.
    Neg,
    /// Multiplicative inverse of a field share.
    MulInverse,
    /// Opens a field share to reveal the underlying value.
    Open,
    /// Checks if the field share is zero, returning a plaintext value.
    IsZero,
}
182
/// Enum representing binary operations on field shares. This includes the case where the second
/// operand is a plaintext value rather than a share.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum FieldShareBinaryOp {
    /// Addition of two field shares (or of a share and a plaintext).
    Add,
    /// Multiplication of two field shares (or of a share and a plaintext).
    Mul,
}
192
/// Enum representing unary operations on binary (bit) shares.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum BitShareUnaryOp {
    /// Logical NOT of a bit share.
    Not,
    /// Opens a bit share to reveal the underlying value.
    Open,
}
201
/// Enum representing binary operations on binary (bit) shares.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum BitShareBinaryOp {
    /// Exclusive OR operation on two bit shares.
    Xor,
    /// OR operation on two bit shares.
    Or,
    /// AND operation on two bit shares.
    And,
}
212
/// Enum representing unary operations on point shares.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum PointShareUnaryOp {
    /// Negation of a point share.
    Neg,
    /// Opens a point share to reveal the underlying value.
    Open,
    /// Checks if the point share is zero (presumably the identity point — confirm), returning a
    /// plaintext value.
    IsZero,
}
223
/// Describes a single input value, parameterized over the curve `C`.
///
/// There are two kinds of variant:
/// - Declared variants (`SecretPlaintext`, `Share`, `RandomShare`) record only the algebraic
///   type and batching of a value that is not stored in the variant itself.
/// - Inline plaintext variants carry the value directly.
///
/// Use [`Input::batched`] to query the batching of any variant.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
// The explicit bounds replace the ones serde would otherwise infer for the type parameter `C`.
// Only the concrete scalar and point types need to be (de)serializable.
#[serde(bound(
    serialize = "Scalar<C>: Serialize, Point<C>: Serialize",
    deserialize = "Scalar<C>: Deserialize<'de>, Point<C>: Deserialize<'de>"
))]
pub enum Input<C: Curve> {
    /// A secret plaintext of the given type and batching.
    /// NOTE(review): presumably provided privately by peer `inputer`; confirm with callers.
    SecretPlaintext {
        inputer: PeerNumber,
        algebraic_type: AlgebraicType,
        batched: Batched,
    },
    /// An input provided as a share of the given type and batching.
    Share {
        algebraic_type: AlgebraicType,
        batched: Batched,
    },
    /// A random share of the given type and batching.
    RandomShare {
        algebraic_type: AlgebraicType,
        batched: Batched,
    },
    /// A single inline scalar plaintext.
    Scalar(ScalarPlaintext<C>),
    /// An inline batch of scalar plaintexts.
    ScalarBatch(ScalarPlaintextBatch<C>),
    /// A single inline base-field plaintext.
    BaseField(BaseFieldPlaintext<C>),
    /// An inline batch of base-field plaintexts.
    BaseFieldBatch(BaseFieldPlaintextBatch<C>),
    /// A single inline bit plaintext.
    Bit(BitPlaintext),
    /// An inline batch of bit plaintexts.
    BitBatch(BitPlaintextBatch),
    /// A single inline point plaintext.
    Point(PointPlaintext<C>),
    /// An inline batch of point plaintexts.
    PointBatch(PointPlaintextBatch<C>),
    /// An inline ElGamal ciphertext given as the two points `c` and `r`.
    /// NOTE(review): the exact roles of `c` and `r` are defined by the consumer; confirm.
    ElGamalCiphertext {
        c: PointPlaintext<C>,
        r: PointPlaintext<C>,
    },
}
256
257impl<C: Curve> Input<C> {
258    pub fn batched(&self) -> Batched {
259        match self {
260            Input::SecretPlaintext { batched, .. } => *batched,
261            Input::Share { batched, .. } => *batched,
262            Input::RandomShare { batched, .. } => *batched,
263            Input::ScalarBatch(input) => Batched::Yes(input.len()),
264            Input::BaseFieldBatch(input) => Batched::Yes(input.len()),
265            Input::BitBatch(input) => Batched::Yes(input.len()),
266            Input::PointBatch(input) => Batched::Yes(input.len()),
267            Input::ElGamalCiphertext { .. } => Batched::No,
268            Input::Scalar { .. } => Batched::No,
269            Input::BaseField { .. } => Batched::No,
270            Input::Bit { .. } => Batched::No,
271            Input::Point { .. } => Batched::No,
272        }
273    }
274}
275
/// Enum representing binary operations on point shares.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum PointShareBinaryOp {
    /// Addition of two point shares.
    Add,
    /// Multiplication of a point share by a scalar.
    ScalarMul,
}
284
/// Enum representing whether a gate is batched or not.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum Batched {
    /// Batched over the given number of elements (returned as-is by `count`).
    Yes(usize),
    /// Not batched: a single element (`count` returns 1).
    No,
}
291
292impl Batched {
293    pub fn count(&self) -> usize {
294        match self {
295            Batched::Yes(count) => *count,
296            Batched::No => 1,
297        }
298    }
299
300    pub fn is_batched(&self) -> bool {
301        match self {
302            Batched::Yes(_) => true,
303            Batched::No => false,
304        }
305    }
306}
307
#[cfg(test)]
mod tests {
    use primitives::algebra::{
        elliptic_curve::{BaseField, Curve25519Ristretto as C, ScalarField},
        field::SubfieldElement,
    };

    use super::*;

    #[test]
    fn test_scalar_unary_op() {
        let mut rng = rand::thread_rng();
        let x = SubfieldElement::<ScalarField<C>>::random(&mut rng);
        let label = Label::task(0);
        let neg = FieldPlaintextUnaryOp::Neg;
        let mul_inverse = FieldPlaintextUnaryOp::MulInverse;

        assert_eq!(neg.eval::<ScalarField<C>>(label, x), Ok(-x));
        assert_eq!(
            mul_inverse.eval::<ScalarField<C>>(label, x),
            Ok(x.invert().unwrap())
        );
    }

    /// `MulInverse` maps zero to zero instead of failing.
    #[test]
    fn test_mul_inverse_of_zero_is_zero() {
        let zero = SubfieldElement::<ScalarField<C>>::from(0u32);
        let label = Label::task(0);

        assert_eq!(
            FieldPlaintextUnaryOp::MulInverse.eval::<ScalarField<C>>(label, zero),
            Ok(zero)
        );
    }

    #[test]
    fn test_scalar_binary_op() {
        let mut rng = rand::thread_rng();
        let x = SubfieldElement::<ScalarField<C>>::random(&mut rng);
        let y = SubfieldElement::<ScalarField<C>>::random(&mut rng);
        let label = Label::task(0);

        let add = FieldPlaintextBinaryOp::Add;
        let mul = FieldPlaintextBinaryOp::Mul;
        let eucl_div = FieldPlaintextBinaryOp::EuclDiv;
        let modulo_op = FieldPlaintextBinaryOp::Mod;
        let gt = FieldPlaintextBinaryOp::Gt;
        let ge = FieldPlaintextBinaryOp::Ge;
        let eq = FieldPlaintextBinaryOp::Eq;

        assert_eq!(add.eval::<ScalarField<C>>(x, y, label), Ok(x + y));
        assert_eq!(mul.eval::<ScalarField<C>>(x, y, label), Ok(x * y));
        assert_eq!(
            eucl_div.eval::<ScalarField<C>>(x, y, label),
            euclidean_division::<ScalarField<C>>(x, y, label)
        );
        assert_eq!(
            modulo_op.eval::<ScalarField<C>>(x, y, label),
            modulo::<ScalarField<C>>(x, y, label)
        );
        assert_eq!(
            gt.eval::<ScalarField<C>>(x, y, label),
            Ok(SubfieldElement::<ScalarField<C>>::from(x > y))
        );
        assert_eq!(
            ge.eval::<ScalarField<C>>(x, y, label),
            Ok(SubfieldElement::<ScalarField<C>>::from(x >= y))
        );
        assert_eq!(
            eq.eval::<ScalarField<C>>(x, y, label),
            Ok(SubfieldElement::<ScalarField<C>>::from(x == y))
        );
    }

    #[test]
    fn test_boolean_binary_op() {
        let and = FieldPlaintextBinaryOp::Mul;
        let or = FieldPlaintextBinaryOp::Or;
        let xor = FieldPlaintextBinaryOp::Xor;
        let label = Label::task(0);
        for bool_x in [false, true] {
            for bool_y in [false, true] {
                let scalar_x = SubfieldElement::<ScalarField<C>>::from(bool_x);
                let scalar_y = SubfieldElement::<ScalarField<C>>::from(bool_y);
                assert_eq!(
                    and.eval::<ScalarField<C>>(scalar_x, scalar_y, label),
                    Ok((bool_x && bool_y).into())
                );
                assert_eq!(
                    or.eval::<ScalarField<C>>(scalar_x, scalar_y, label),
                    Ok((bool_x || bool_y).into())
                );
                assert_eq!(
                    xor.eval::<ScalarField<C>>(scalar_x, scalar_y, label),
                    Ok((bool_x ^ bool_y).into())
                );
            }
        }
    }

    #[test]
    fn test_euclidean_division() {
        let x = SubfieldElement::<ScalarField<C>>::from(37u32);
        let y = SubfieldElement::<ScalarField<C>>::from(12u32);
        let label = Label::task(0);

        let result = euclidean_division::<ScalarField<C>>(x, y, label).unwrap();
        assert_eq!(result, SubfieldElement::<ScalarField<C>>::from(37u32 / 12));
    }

    #[test]
    fn test_modulo() {
        let x = SubfieldElement::<ScalarField<C>>::from(37u32);
        let y = SubfieldElement::<ScalarField<C>>::from(12u32);
        let label = Label::task(0);

        let result = modulo::<ScalarField<C>>(x, y, label).unwrap();
        assert_eq!(result, SubfieldElement::<ScalarField<C>>::from(37u32 % 12));
    }

    /// When the dividend is smaller than the divisor, the quotient is zero (which exercises
    /// the padding of an all-zero result) and the remainder is the dividend itself.
    #[test]
    fn test_division_with_smaller_dividend() {
        let x = SubfieldElement::<ScalarField<C>>::from(5u32);
        let y = SubfieldElement::<ScalarField<C>>::from(12u32);
        let label = Label::task(0);

        assert_eq!(
            euclidean_division::<ScalarField<C>>(x, y, label).unwrap(),
            SubfieldElement::<ScalarField<C>>::from(0u32)
        );
        assert_eq!(modulo::<ScalarField<C>>(x, y, label).unwrap(), x);
    }

    /// Integer division and remainder by zero are rejected rather than panicking.
    #[test]
    fn test_division_and_modulo_by_zero_fail() {
        let x = SubfieldElement::<ScalarField<C>>::from(37u32);
        let zero = SubfieldElement::<ScalarField<C>>::from(0u32);
        let label = Label::task(0);

        assert!(euclidean_division::<ScalarField<C>>(x, zero, label).is_err());
        assert!(modulo::<ScalarField<C>>(x, zero, label).is_err());
        assert!(FieldPlaintextBinaryOp::EuclDiv
            .eval::<ScalarField<C>>(x, zero, label)
            .is_err());
        assert!(FieldPlaintextBinaryOp::Mod
            .eval::<ScalarField<C>>(x, zero, label)
            .is_err());
    }

    #[test]
    fn test_signed_bit_extract() {
        let x = -Scalar::<C>::from(9u32);
        let label = Label::task(0);
        for i in 0..5 {
            let op = FieldPlaintextUnaryOp::BitExtract {
                little_endian_bit_idx: i,
                signed: true,
            };
            let result = op.eval::<ScalarField<C>>(label, x);
            assert_eq!(result.unwrap(), ((-9i32 >> i) & 1 == 1).into())
        }
    }

    /// In signed mode, bits far above the magnitude replicate the sign bit.
    #[test]
    fn test_signed_bit_extract_sign_extension() {
        let label = Label::task(0);
        let high_bit = FieldPlaintextUnaryOp::BitExtract {
            little_endian_bit_idx: 200,
            signed: true,
        };

        assert_eq!(
            high_bit
                .eval::<ScalarField<C>>(label, -Scalar::<C>::from(9u32))
                .unwrap(),
            true.into()
        );
        assert_eq!(
            high_bit
                .eval::<ScalarField<C>>(label, Scalar::<C>::from(9u32))
                .unwrap(),
            false.into()
        );
    }

    #[test]
    fn test_unsigned_bit_extract() {
        let x = SubfieldElement::<ScalarField<C>>::from(37u32);
        let label = Label::task(0);
        for i in 0..8 {
            let op = FieldPlaintextUnaryOp::BitExtract {
                little_endian_bit_idx: i,
                signed: false,
            };
            let result = op.eval::<ScalarField<C>>(label, x);
            assert_eq!(result.unwrap(), ((37u32 >> i) & 1 == 1).into())
        }
    }

    #[test]
    fn test_sqrt() {
        let mut rng = rand::thread_rng();
        let x = SubfieldElement::<ScalarField<C>>::random(&mut rng);
        let label = Label::task(0);
        let result = FieldPlaintextUnaryOp::Sqrt
            .eval::<ScalarField<C>>(label, x * x)
            .unwrap();

        assert_eq!(result * result, x * x)
    }

    /// The Ristretto scalar field order ℓ = 2^252 + 27742317777372353535851937790883648493
    /// satisfies ℓ ≡ 5 (mod 8), so 2 is a quadratic non-residue and `Sqrt` must fail.
    #[test]
    fn test_sqrt_of_non_residue_fails() {
        let two = SubfieldElement::<ScalarField<C>>::from(2u32);
        let label = Label::task(0);

        assert!(FieldPlaintextUnaryOp::Sqrt
            .eval::<ScalarField<C>>(label, two)
            .is_err());
    }

    #[test]
    fn test_pow() {
        let mut rng = rand::thread_rng();
        let x = SubfieldElement::<BaseField<C>>::random(&mut rng);
        let label = Label::task(0);
        let five = BoxedUint::from(vec![5u64]);
        // Little-endian limbs of 5^{-1} mod (p - 1) for p = 2^255 - 19, so (x^5)^five_inv = x.
        let five_inv = BoxedUint::from(vec![
            14757395258967641281,
            14757395258967641292,
            14757395258967641292,
            5534023222112865484,
        ]);
        let x_pow_5 = FieldPlaintextUnaryOp::Pow { exp: five }
            .eval::<BaseField<C>>(label, x)
            .unwrap();
        let x_again = FieldPlaintextUnaryOp::Pow { exp: five_inv }
            .eval::<BaseField<C>>(label, x_pow_5)
            .unwrap();

        assert_eq!(x_again, x)
    }
}