arcis-compiler 0.9.7

A framework for writing secure multi-party computation (MPC) circuits to be executed on the Arcium network.
Documentation
use crate::{
    core::{
        actually_used_field::ActuallyUsedField,
        bounds::FieldBounds,
        circuits::traits::arithmetic_circuit::ArithmeticCircuit,
        expressions::{expr::EvalFailure, field_expr::FieldExpr},
        global_value::value::FieldValue,
    },
    traits::{Equal, Invert, Pow, Reveal},
    utils::{number::Number, used_field::UsedField},
};
use std::ops::Mul;

/// Pow by a constant.
///
/// Describes a circuit that raises its input to a fixed exponent chosen when the
/// circuit is built.
#[derive(Clone, Debug)]
pub struct PowCircuit {
    /// The constant exponent the input is raised to.
    pub exponent: Number,
    /// Hint used by `run` on non-plaintext inputs: when `true`, the zero-input
    /// correction of the masked branches is skipped and the `x^alpha` masking
    /// shortcut becomes eligible. Plaintext evaluation (`eval`) always passes `true`
    /// to the underlying `pow`, regardless of this field.
    pub is_expected_non_zero: bool,
}

/// Computes `a^b` by square-and-multiply, using only `T`'s multiplication.
///
/// First the successive squares `a, a^2, a^4, ...` are collected (one per bit of `b`),
/// then the squares whose bit is set in `b` are multiplied into an accumulator that
/// starts at `T::from(1)`.
///
/// # Panics
///
/// Panics if `b` is not strictly positive.
fn smart_pow<T: Mul<T, Output = T> + Clone + From<i32>>(a: T, b: &Number) -> T {
    assert!(*b > 0);

    // squares[i] holds a^(2^i).
    let mut squares: Vec<T> = Vec::new();
    let mut current = a;
    for _ in 0..b.bits() {
        squares.push(current.clone());
        // The square is also taken after the last push; that final value is unused.
        current = current.clone() * current;
    }

    // Multiply together exactly the squares selected by the set bits of `b`.
    squares
        .into_iter()
        .enumerate()
        .filter(|(i, _)| b.bit(*i))
        .fold(T::from(1), |acc, (_, square)| acc * square)
}

impl PowCircuit {
    #[allow(unused)]
    pub fn new(exponent: Number, is_expected_non_zero: bool) -> Self {
        Self {
            exponent,
            is_expected_non_zero,
        }
    }
}

impl<F: UsedField> ArithmeticCircuit<F> for PowCircuit {
    fn eval(&self, x: Vec<F>) -> Result<Vec<F>, EvalFailure> {
        // on plaintext values we simply set is_expected_non_zero = true
        Ok(x.into_iter()
            .map(|val| val.pow(&self.exponent, true))
            .collect())
    }

    fn bounds(&self, bounds: Vec<FieldBounds<F>>) -> Vec<FieldBounds<F>> {
        bounds
            .into_iter()
            .map(|bound| {
                if self.exponent == 0 {
                    FieldBounds::new(F::ONE, F::ONE)
                } else {
                    smart_pow(bound, &self.exponent)
                }
            })
            .collect()
    }

    fn run(&self, vals: Vec<FieldValue<F>>) -> Vec<FieldValue<F>>
    where
        F: ActuallyUsedField,
    {
        assert!(vals.len() == 1);
        let pow = if self.exponent == 0 {
            FieldValue::<F>::from(1)
        } else if vals[0].is_plaintext() {
            FieldValue::<F>::new(FieldExpr::Pow(vals[0], self.exponent.clone(), false))
        } else if self.exponent == F::get_alpha() && self.is_expected_non_zero {
            // We compute x^alpha as follows:
            // 1. generate a random lambda and compute lambda^-alpha
            // 2. compute z = x * lambda and open z
            // 3. compute z^alpha = x^alpha * lambda^alpha
            // 4. return x^alpha = z^alpha * lambda^-alpha

            // 1.
            let lambda = FieldValue::<F>::random();
            // In order to compute lambda^-alpha we first compute lambda^-1,
            // which we then raise to the power alpha.
            let lambda_inv = lambda.invert(true);
            let lambda_pow_neg_alpha = smart_pow(lambda_inv, &F::get_alpha());

            // 2. multiplicatively mask x
            let x = vals[0];
            let z = (x * lambda).reveal();

            // 3. compute pow on the public value z
            let z_pow_alpha = FieldValue::<F>::new(FieldExpr::Pow(z, F::get_alpha(), true));

            // 4. unmask
            z_pow_alpha * lambda_pow_neg_alpha
        } else if self.exponent == F::get_alpha_inverse() {
            // We compute x^alpha_inverse as follows:
            // 1. generate a random lambda and compute lambda^-(alpha_inverse^-1) = lambda^-alpha
            // 2. compute z = x * lambda^-alpha and open z
            // 3. compute z^alpha_inverse = x^alpha_inverse * lambda^-1
            // 4. return x^alpha_inverse = z^alpha_inverse * lambda

            // 1.
            let lambda = FieldValue::<F>::random();
            // In order to compute lambda^-alpha we first compute lambda^-1,
            // which we then raise to the power alpha.
            let lambda_inv = lambda.invert(true);
            let lambda_pow_neg_alpha = smart_pow(lambda_inv, &F::get_alpha());

            // 2. multiplicatively mask x
            let x = vals[0];
            let is_zero_mask = if !self.is_expected_non_zero {
                FieldValue::from(x.eq(FieldValue::from(0)))
            } else {
                FieldValue::from(0)
            };
            let x = x + is_zero_mask;
            let z = (x * lambda_pow_neg_alpha).reveal();

            // 3. compute pow on the public value z
            let z_pow_alpha_inverse =
                FieldValue::<F>::new(FieldExpr::Pow(z, F::get_alpha_inverse(), true));

            // 4. unmask
            z_pow_alpha_inverse * lambda - is_zero_mask
        } else if self.exponent == (F::modulus() - 1) / 2 {
            // We compute x^(p-1)/2 as follows:
            // 1. generate a random lambda and compute lambda^(p-1)/2
            // 2. compute z = x * lambda and open z
            // 3. compute z^(p-1)/2 = x^(p-1)/2 * lambda^(p-1)/2
            // 4. return x^(p-1)/2 = z^(p-1)/2 * lambda^(p-1)/2

            // 1.
            let lambda = FieldValue::<F>::random();
            let lambda_pow_exponent = smart_pow(lambda, &self.exponent);

            // 2. multiplicatively mask x
            let x = vals[0];
            let is_zero_mask = if !self.is_expected_non_zero {
                FieldValue::from(x.eq(FieldValue::from(0)))
            } else {
                FieldValue::from(0)
            };
            let x = x + is_zero_mask;
            let z = (x * lambda).reveal();

            // 3. compute pow on the public value z
            let z_pow_exponent =
                FieldValue::<F>::new(FieldExpr::Pow(z, self.exponent.clone(), true));

            // 4. unmask
            z_pow_exponent * lambda_pow_exponent - is_zero_mask
        } else if F::modulus() % 4 == 3 && self.exponent == (F::modulus() - 3) / 4 {
            // We compute x^(p-3)/4 as follows:
            // 1. generate a random lambda and compute lambda^(p-1 - (p-3)/4) = lambda^(3p-1)/4
            // 2. compute z = x * lambda and open z
            // 3. compute z^(p-3)/4 = x^(p-3)/4 * lambda^(p-3)/4
            // 4. return x^(p-3)/4 = z^(p-3)/4 * lambda^(3p-1)/4

            // 1.
            let lambda = FieldValue::<F>::random();
            let lambda_pow = smart_pow(lambda, &((3 * F::modulus() - 1) / 4));

            // 2. multiplicatively mask x
            let x = vals[0];
            let is_zero_mask = if !self.is_expected_non_zero {
                FieldValue::from(x.eq(FieldValue::from(0)))
            } else {
                FieldValue::from(0)
            };
            let x = x + is_zero_mask;
            let z = (x * lambda).reveal();

            // 3. compute pow on the public value z
            let z_pow_exponent =
                FieldValue::<F>::new(FieldExpr::Pow(z, self.exponent.clone(), true));

            // 4. unmask
            z_pow_exponent * lambda_pow - is_zero_mask
        } else if F::modulus() % 8 == 5 && self.exponent == (F::modulus() - 5) / 8 {
            // We compute x^(p-5)/8 as follows:
            // 1. generate a random lambda and compute lambda^(p-1 - (p-5)/8) = lambda^(7p-3)/8
            // 2. compute z = x * lambda and open z
            // 3. compute z^(p-5)/8 = x^(p-5)/8 * lambda^(p-5)/8
            // 4. return x^(p-5)/8 = z^(p-5)/8 * lambda^(7p-3)/8

            // 1.
            let lambda = FieldValue::<F>::random();
            let lambda_pow = smart_pow(lambda, &((7 * F::modulus() - 3) / 8));

            // 2. multiplicatively mask x
            let x = vals[0];
            let is_zero_mask = if !self.is_expected_non_zero {
                FieldValue::from(x.eq(FieldValue::from(0)))
            } else {
                FieldValue::from(0)
            };
            let x = x + is_zero_mask;
            let z = (x * lambda).reveal();

            // 3. compute pow on the public value z
            let z_pow_exponent =
                FieldValue::<F>::new(FieldExpr::Pow(z, self.exponent.clone(), true));

            // 4. unmask
            z_pow_exponent * lambda_pow - is_zero_mask
        } else {
            smart_pow(vals[0], &self.exponent)
        };
        vec![pow]
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        core::circuits::traits::arithmetic_circuit::tests::TestedArithmeticCircuit,
        ArcisField,
    };
    use num_traits::ToPrimitive;
    use rand::Rng;
    use std::marker::PhantomData;

    /// Compares `smart_pow` with naive repeated multiplication on a few signed bases.
    #[test]
    fn test_smart_pow() {
        let alpha = 5usize;
        let alpha_number = Number::from(alpha);
        let bases = [Number::from(-1), 5.into(), (-3).into()];
        for base in bases {
            let tested = smart_pow(base.clone(), &alpha_number);
            // Reference value: multiply `base` into 1 exactly `alpha` times.
            let expected = (0..alpha).fold(Number::from(1), |acc, _| acc * &base);

            assert_eq!(tested, expected);
        }
    }

    impl<F: ActuallyUsedField> TestedArithmeticCircuit<F> for PowCircuit {
        /// Draws a random circuit description, favouring the exponents that `run`
        /// handles with dedicated branches.
        ///
        /// NOTE(review): `F::get_alpha()` is never generated explicitly here, so the
        /// alpha branch of `run` is only reached when a random exponent happens to
        /// equal it.
        fn gen_desc<R: Rng + ?Sized>(rng: &mut R) -> Self {
            let is_expected_non_zero = rng.gen_bool(0.5);

            // Often try the special case alpha == 0.
            if rng.gen_bool(0.25) {
                return Self::new(0.into(), is_expected_non_zero);
            }

            let exponent: Number = match rng.next_u64() % 5 {
                // small exponents, checkable by repeated multiplication
                0 => (rng.next_u64() % 256).into(),
                1 => F::get_alpha_inverse(),
                2 => (F::modulus() - 1) / 2,
                3 if F::modulus() % 4 == 3 => (F::modulus() - 3) / 4,
                3 if F::modulus() % 8 == 5 => (F::modulus() - 5) / 8,
                // arbitrary 64-bit exponent
                _ => rng.next_u64().into(),
            };
            Self::new(exponent, is_expected_non_zero)
        }

        /// The circuit always takes exactly one input.
        fn gen_n_inputs<R: Rng + ?Sized>(&self, _rng: &mut R) -> usize {
            1
        }

        /// Re-checks the outputs independently: exponent 0 must give one, and
        /// exponents fitting in a `u8` are recomputed by repeated multiplication.
        fn extra_checks(&self, inputs: Vec<F>, outputs: Vec<F>) {
            assert_eq!(inputs.len(), outputs.len());
            for (input, output) in inputs.into_iter().zip(outputs) {
                if self.exponent == 0 {
                    assert_eq!(output, F::ONE);
                    continue;
                }
                if let Some(n) = self.exponent.to_u8() {
                    // Start at input^1 and multiply n - 1 more times.
                    let mut expected = input;
                    for _ in 1..n {
                        expected *= input;
                    }
                    assert_eq!(output, expected);
                }
            }
        }
    }

    /// Runs the generic arithmetic-circuit test harness on `PowCircuit`.
    #[test]
    fn tested() {
        PowCircuit::test_with_marker(64, 4, PhantomData::<ArcisField>)
    }
}