arcis-compiler 0.9.4

A framework for writing secure multi-party computation (MPC) circuits to be executed on the Arcium network.
Documentation
use crate::{
    core::{
        actually_used_field::ActuallyUsedField,
        bounds::FieldBounds,
        circuits::{
            boolean::boolean_value::BooleanValue,
            traits::arithmetic_circuit::ArithmeticCircuit,
        },
        expressions::expr::EvalFailure,
        global_value::value::FieldValue,
    },
    traits::{Equal, GetBit, Pow, Select, WithBooleanBounds},
    utils::number::Number,
};
use std::ops::{Add, Mul, Sub};

/// Arithmetic circuit that maps every input field element to one of its square roots,
/// or to 0 when the element is a quadratic non-residue.
#[derive(Debug, Clone)]
pub struct SqrtCircuit;

/// Compute a square root of the finite field element `a`, together with a flag telling
/// whether `a` is a square in `F`.
///
/// Return (1, a.sqrt()) for quadratic residues, where `a.sqrt()` is one of the two roots
/// `r`/`-r` (which one is not normalized here; the tests compare up to sign via `abs`).
/// `a == 0` counts as a residue and yields (1, 0).
/// Return (0, 0) for quadratic non-residues.
///
/// The same code is used for plaintext evaluation (`T = F`, `B = bool`) and for circuit
/// construction (`T = FieldValue<F>`, `B = BooleanValue`). Only the modulus and the
/// Rust-level `is_expected_non_zero` flag drive Rust control flow. Every data-dependent
/// choice on `a` goes through `select`.
///
/// # Parameters
/// - `a`: the element whose square root is computed.
/// - `is_expected_non_zero`: the caller's promise that `a != 0`. On the `p % 4 == 3` path
///   it is forwarded to `Pow::pow`. On the `p % 8 == 5` path, `true` skips the zero-masking
///   step. NOTE(review): the result for `a == 0` with this flag set to `true` depends on
///   `Pow::pow` and is not specified by this function.
///
/// # Panics
/// Panics if `F::modulus() % 8 == 1`. The closed-form exponentiations used here only
/// cover p ≡ 3 (mod 4) and p ≡ 5 (mod 8).
pub fn sqrt<
    F: ActuallyUsedField,
    B: Select<T, T, T> + Copy,
    T: GetBit<Output = B>
        + Equal<T, Output = B>
        + From<B>
        + Copy
        + Add<T, Output = T>
        + Sub<T, Output = T>
        + Mul<T, Output = T>
        + Pow
        + WithBooleanBounds
        + From<i32>
        + From<Number>
        + From<F>,
>(
    a: T,
    is_expected_non_zero: bool,
) -> (B, T) {
    let p = F::modulus();
    if &p % 8 == 1 {
        panic!("p % 8 == 1 not supported");
    };
    if &p % 4 == 3 {
        // serves to compute all other powers: a^((p-3)/4)
        let base_pow = a.pow(&((&p - 3) / 4), is_expected_non_zero);
        // pow = a^((p+1)/4) = a^((p-3)/4) * a
        // For a square a this is a root: pow^2 = a * a^((p-1)/2) = a.
        let pow = base_pow * a;
        // is_square_field = a^((p-1)/2) = a^((p+1)/4) * a^((p-3)/4)
        // Euler's criterion: 1 for non-zero squares, -1 for non-squares, 0 for a == 0.
        let is_square_field = pow * base_pow;
        // is_square_bool is 0 or 1: the polynomial 1 - s*(s-1)/2 maps 1 -> 1, 0 -> 1,
        // -1 -> 0 (F::TWO_INV is assumed to be the inverse of 2 in F).
        let is_square_bool =
            T::from(1) - is_square_field * (is_square_field - T::from(1)) * T::from(F::TWO_INV);
        // Turn the 0/1 field value into a boolean `B` by reading its bit 0.
        let is_square = is_square_bool.with_boolean_bounds().get_bit(0, false);

        // For a == 0, pow is already 0, so no extra zero handling is needed on this path.
        (is_square, is_square.select(pow, T::from(0)))
    } else {
        // Replace a == 0 by 1 so the exponentiations below only ever see a non-zero value.
        // The mask is subtracted from the result at the end (sqrt(1) = 1 here, so 1 - 1 = 0).
        let is_zero_mask = if !is_expected_non_zero {
            T::from(a.eq(T::from(0)))
        } else {
            T::from(0)
        };
        let a = a + is_zero_mask;
        // we're in the case p % 8 == 5 (assumes p is an odd prime)
        // serves to compute all other powers: a^((p-5)/8)
        let base_pow = a.pow(&((&p - 5) / 8), true);
        // pow = a^((p+3)/8) = a^((p-5)/8) * a
        // pow^2 = a * a^((p-1)/4), so pow is a root exactly when a^((p-1)/4) == 1.
        let pow = base_pow * a;
        // is_fourth_power_field = a^((p-1)/4) = a^((p+3)/8) * a^((p-5)/8)
        // This is 1 or -1 for squares and a square root of -1 for non-squares.
        let is_fourth_power_field = pow * base_pow;
        // is_square_field = a^((p-1)/2) (Euler's criterion: 1 or -1, since a != 0 here)
        let is_square_field = is_fourth_power_field * is_fourth_power_field;
        // is_square_bool is 0 or 1 (same 1 - s*(s-1)/2 mapping as in the other branch)
        let is_square_bool =
            T::from(1) - is_square_field * (is_square_field - T::from(1)) * T::from(F::TWO_INV);
        let is_square = is_square_bool.with_boolean_bounds().get_bit(0, false);
        // is_fourth_power_bool can be different from 0 or 1
        // (though is_square_bool * is_fourth_power_bool is 0 or 1)
        // (f + 1) / 2 maps 1 -> 1 and -1 -> 0. For non-squares f is a root of -1, giving a
        // non-boolean value, which is cancelled by is_square_bool == 0.
        let is_fourth_power_bool = (is_fourth_power_field + T::from(1)) * T::from(F::TWO_INV);
        let is_fourth_power = (is_square_bool * is_fourth_power_bool)
            .with_boolean_bounds()
            .get_bit(0, false);

        // If a^((p-1)/4) == -1 then pow^2 == -a, so pow is multiplied by a square root of -1.
        // For p ≡ 5 (mod 8), 2 is a non-residue, so 2^((p-1)/4) squares to -1.
        (
            is_square,
            is_square.select(
                pow * is_fourth_power
                    .select(T::from(1), T::from(F::from(2).pow(&((&p - 1) / 4), true))),
                T::from(0),
            ) - is_zero_mask,
        )
    }
}

impl<F: ActuallyUsedField> ArithmeticCircuit<F> for SqrtCircuit {
    /// Plaintext evaluation: replaces every input by the root returned by `sqrt`
    /// (0 for quadratic non-residues). This never yields an `EvalFailure`.
    fn eval(&self, x: Vec<F>) -> Result<Vec<F>, EvalFailure> {
        let mut roots = Vec::with_capacity(x.len());
        for val in x {
            // Inputs may be zero, so no non-zero promise is made to `sqrt`.
            let (_is_square, root) = sqrt::<F, bool, _>(val, false);
            roots.push(root);
        }
        Ok(roots)
    }

    /// Output bounds: a square root may be any field element, so every output is
    /// unconstrained regardless of the input bounds.
    fn bounds(&self, bounds: Vec<FieldBounds<F>>) -> Vec<FieldBounds<F>> {
        bounds.iter().map(|_| FieldBounds::All).collect()
    }

    /// Circuit construction: applies `sqrt` element-wise to the (possibly secret) inputs,
    /// keeping only the root and discarding the is-square flag.
    fn run(&self, vals: Vec<FieldValue<F>>) -> Vec<FieldValue<F>> {
        let mut roots = Vec::with_capacity(vals.len());
        for val in vals {
            // Inputs may be zero, so no non-zero promise is made to `sqrt`.
            let (_is_square, root) = sqrt::<F, BooleanValue, _>(val, false);
            roots.push(root);
        }
        roots
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        core::circuits::traits::arithmetic_circuit::tests::TestedArithmeticCircuit,
        utils::{field::ScalarField, used_field::UsedField},
    };
    use ff::Field;
    use rand::Rng;
    use std::marker::PhantomData;

    /// Plaintext check of `sqrt` over `ScalarField`. It covers the fixed points 0 and 1,
    /// then 100 random elements compared with the field type's own `sqrt`.
    #[test]
    fn test_sqrt() {
        // 0 and 1 are their own square roots and are both reported as squares.
        for fixed in [ScalarField::ZERO, ScalarField::ONE] {
            assert_eq!(sqrt::<ScalarField, bool, _>(fixed, false), (true, fixed));
        }

        let rng = &mut crate::utils::test_rng::get();
        for _ in 0..100 {
            let x = ScalarField::random(&mut *rng);
            let (is_square, root) = sqrt::<ScalarField, bool, _>(x, false);
            // Roots are only determined up to sign, so both sides are normalized with `abs`.
            let actual = (is_square, root.abs());
            let reference = x.sqrt();
            let expected = if reference.is_some().into() {
                (true, reference.unwrap().abs())
            } else {
                (false, ScalarField::ZERO)
            };
            assert_eq!(actual, expected);
        }
    }

    impl<F: ActuallyUsedField> TestedArithmeticCircuit<F> for SqrtCircuit {
        fn gen_desc<R: Rng + ?Sized>(_rng: &mut R) -> Self {
            SqrtCircuit
        }

        fn gen_n_inputs<R: Rng + ?Sized>(&self, _rng: &mut R) -> usize {
            1
        }

        /// Each output squared must equal its input when the input is a square (by Euler's
        /// criterion), and must equal 0 otherwise.
        fn extra_checks(&self, inputs: Vec<F>, outputs: Vec<F>) {
            assert_eq!(inputs.len(), outputs.len());
            let p = F::modulus();
            for (val, output) in inputs.into_iter().zip(outputs) {
                // on plaintext values we simply set is_expected_non_zero = true
                let euler = val.pow(&((&p - 1) / 2), true);
                let expected = if euler == F::ONE { val } else { F::ZERO };
                assert_eq!(output * output, expected);
            }
        }
    }

    /// Runs the shared arithmetic-circuit test harness for `SqrtCircuit` over `ScalarField`.
    #[test]
    fn tested() {
        SqrtCircuit::test_with_marker(1, 64, PhantomData::<ScalarField>)
    }
}