arcis-compiler 0.9.4

A framework for writing secure multi-party computation (MPC) circuits to be executed on the Arcium network.
Documentation
use crate::{
    core::{
        bounds::FieldBounds,
        circuits::{
            arithmetic::sqrt,
            boolean::{
                boolean_value::{Boolean, BooleanValue},
                byte::Byte,
            },
            traits::arithmetic_circuit::ArithmeticCircuit,
        },
        expressions::expr::EvalFailure,
        global_value::value::FieldValue,
    },
    traits::{Equal, FromLeBytes, GetBit, GreaterEqual, Select},
    utils::{
        elliptic_curve::{
            AffineEdwardsPoint,
            CURVE25519_A,
            EDWARDS25519_D,
            EIGHT_INV_MOD_ELL,
            F25519,
            SQRT_NEG_ONE,
        },
        field::BaseField,
        used_field::UsedField,
    },
};
use ff::Field;

/// Encodes a base field element into a point on Edwards25519 using Elligator 2
/// (see <https://eprint.iacr.org/2013/325.pdf> for the encoding into the Montgomery curve).
/// The encoded point lies in the large prime order subgroup.
/// To make the decoding work we require the base field element to be at most 2^128.
///
/// The struct only caches curve constants; all of them are filled in by
/// [`Edwards25519ElligatorEncodingCircuit::new`].
#[derive(Clone, Debug)]
// `non_snake_case`: the field `A` keeps the letter used in the formulas.
// `dead_code`: in the code visible in this file, `d` is only read by the test-only `get_d`.
#[allow(non_snake_case, dead_code)]
pub struct Edwards25519ElligatorEncodingCircuit {
    // Constant `A` appearing in the Elligator 2 formulas of `encode`; loaded from
    // `CURVE25519_A`.
    A: BaseField,
    // Coefficient d of the target curve -x^2 + y^2 = 1 + d * x^2 * y^2; loaded from
    // `EDWARDS25519_D`.
    d: BaseField,
    // Corresponds to d of the isomorphic curve x^2 + y^2 = 1 + d_isom * x^2 * y^2.
    // Note: d_isom = -d.
    d_isom: BaseField,
}

impl Default for Edwards25519ElligatorEncodingCircuit {
    fn default() -> Self {
        Self::new()
    }
}

impl Edwards25519ElligatorEncodingCircuit {
    /// Builds the circuit description from the byte constants `CURVE25519_A` and
    /// `EDWARDS25519_D`.
    #[allow(dead_code)]
    pub fn new() -> Self {
        // Parse d once; the isomorphic non-twisted curve uses its negation.
        let edwards_d = BaseField::from_le_bytes(EDWARDS25519_D);
        Self {
            A: BaseField::from_le_bytes(CURVE25519_A),
            d: edwards_d,
            d_isom: -edwards_d,
        }
    }

    /// Returns the stored curve coefficient `d`. Only compiled for tests.
    #[cfg(test)]
    pub fn get_d(&self) -> BaseField {
        self.d
    }
}

impl Edwards25519ElligatorEncodingCircuit {
    /// Maps the message `m` to an affine point on Edwards25519.
    ///
    /// Steps: shift `m` by one, evaluate the two Elligator 2 candidates on the non-twisted
    /// curve x^2 + y^2 = 1 + d_isom * x^2 * y^2, keep the candidate whose x^2 is a square,
    /// move to the twisted curve by scaling x with sqrt(-1), and finally apply
    /// `mul_str("0001")` (presumably the cofactor multiplication that puts the point into the
    /// prime order subgroup; confirm against `mul_str`).
    #[allow(dead_code)]
    fn encode<B, T>(&self, m: T) -> AffineEdwardsPoint<T>
    where
        B: Boolean + Select<T, T, T>,
        T: F25519 + GetBit<Output = B> + Equal<T, Output = B> + From<B>,
    {
        // TODO: can we add the requirement on the input being at most 2^128 here?

        // Constants of the formulas below, lifted into T once.
        let one = T::from(1);
        let two = T::from(2);
        let four = T::from(4);
        let eight = T::from(8);
        let a = T::from(self.A);
        let d_isom = T::from(self.d_isom);

        // The plain Elligator2 map sends m = 0 to the 2-torsion point (0, 0) of the Montgomery
        // curve, where the birational equivalence is singular. Shifting the message by one
        // avoids that input.
        let shifted = m + one;
        let m2 = shifted * shifted;
        let m4 = m2 * m2;

        // Note: all formulas target the isomorphic non-twisted curve
        // x^2 + y^2 = 1 + d_isom * x^2 * y^2 rather than -x^2 + y^2 = 1 + d * x^2 * y^2.
        // They were originally derived for the non-twisted curve; instead of re-deriving them
        // for the twisted one, the isomorphism (x-coordinate times sqrt(-1)) is applied at the
        // very end.

        // y1 = (v-1)/(v+1) with v = -A/(1+2*shifted^2).
        // v+1 is never 0: that would force shifted^2 = (A-1)/2, a quadratic non-residue.
        let y1_neg_num = -(one + two * m2 + a);
        let y1_den_inv = (one + two * m2 - a).invert(true);
        let y1 = y1_neg_num * y1_den_inv;

        // y2 = (-(v+A)-1)/(-(v+A)+1) with the same v.
        // -(v+A)+1 is never 0: that would force shifted^2 = 1/(2*(A-1)), a quadratic
        // non-residue.
        let y2_neg_num = -(one + two * (one + a) * m2);
        let y2_den_inv = (one + two * (one - a) * m2).invert(true);
        let y2 = y2_neg_num * y2_den_inv;

        // x12 = g(y1), where g(y) = (1 - y^2)/(1 - d_isom*y^2).
        // TODO: explain why denom (= 1 - d_isom*((v-1)^2 / (v+1)^2)) cannot be 0
        // for the time being we set is_expected_non_zero = false
        let x12_num = -four * a * (one + two * m2);
        let x12_den = four * m4 * (one - d_isom)
            + four * m2 * (one - a)
            - four * d_isom * m2 * (one + a)
            + (a - one) * (a - one)
            - d_isom * (a + one) * (a + one);
        let x12 = x12_num * x12_den.invert(false);

        // x22 = g(y2), with the same g.
        // TODO: explain why denom (= 1 - d_isom*((-(v+A)-1)^2 / (-(v+A)+1)^2)) cannot be 0
        // for the time being we set is_expected_non_zero = false
        let x22_num = -eight * a * (m2 + two * m4);
        let x22_den = one
            + four * (one - a) * m2
            + four * (one - a) * (one - a) * m4
            - d_isom
                * (one
                    + four * (one + a) * m2
                    + four * (one + a) * (one + a) * m4);
        let x22 = x22_num * x22_den.invert(false);

        let (x12_is_square, x1) = sqrt::<BaseField, B, _>(x12, true);
        let (_, x2) = sqrt::<BaseField, B, _>(x22, false);

        // Keep the first candidate when x12 is a square, the second one otherwise.
        let x = x12_is_square.select(x1, x2);
        let y = x12_is_square.select(y1, y2);

        // TODO: here, we will switch from base field secret-sharing to EC point secret-sharing.
        // When generating a random curve point Lambda in the prime order subgroup
        // (both in EC point and base field secret-sharing) one must not forget to add
        // a 8-torsion part to the base field secret-shared mask.

        // Scaling the x-coordinate by sqrt(-1) carries a point of the non-twisted curve over
        // to the twisted curve.
        let sqrt_neg_one = T::from(BaseField::from_le_bytes(SQRT_NEG_ONE));
        let twisted = AffineEdwardsPoint::new((sqrt_neg_one * x, y), true, false);
        twisted.to_projective().mul_str("0001").to_affine()
    }
}

impl ArithmeticCircuit<BaseField> for Edwards25519ElligatorEncodingCircuit {
    /// Encodes the single input element and returns the point coordinates as `[x, y]`.
    ///
    /// Fails with `EvalFailure::err_ub` when the input is larger than 2^128.
    ///
    /// # Panics
    /// Panics if `x` does not contain exactly one element.
    fn eval(&self, x: Vec<BaseField>) -> Result<Vec<BaseField>, EvalFailure> {
        let message = match x.as_slice() {
            &[message] => message,
            _ => panic!("Edwards25519 Elligator encoding requires input of length 1"),
        };
        if message > BaseField::power_of_two(128) {
            return EvalFailure::err_ub("Input must be at most 2^128");
        }
        let (px, py) = self.encode::<bool, _>(message).inner();
        Ok(vec![px, py])
    }

    /// Reports unconstrained bounds for both output coordinates.
    fn bounds(&self, _bounds: Vec<FieldBounds<BaseField>>) -> Vec<FieldBounds<BaseField>> {
        Vec::from([FieldBounds::All, FieldBounds::All])
    }

    /// Encodes the single input value and returns the point coordinates as `[x, y]`.
    ///
    /// # Panics
    /// Panics if `vals` does not contain exactly one element.
    fn run(&self, vals: Vec<FieldValue<BaseField>>) -> Vec<FieldValue<BaseField>> {
        let message = match vals.as_slice() {
            &[message] => message,
            _ => panic!("Edwards25519 Elligator encoding requires input of length 1"),
        };
        // TODO: add some diagnostics, along the lines of
        // let bounds = BaseField::bounds_to_field_bounds(vals[0].bounds());
        // if bounds.has_negatives() || bounds.unsigned_max() > BaseField::power_of_two(128) {
        //     EvalFailure::ub("Input must be at most 2^128")
        // }
        let (px, py) = self.encode::<BooleanValue, _>(message).inner();
        vec![px, py]
    }
}

/// Decodes a base field element which was encoded via Edwards25519ElligatorEncodingCircuit.
///
/// The struct only caches curve constants; both are filled in by
/// [`Edwards25519ElligatorDecodingCircuit::new`].
#[derive(Clone, Debug)]
// `non_snake_case`: the field `A` keeps the letter used in the formulas.
// `dead_code`: `d` is not read anywhere in the code visible in this file.
#[allow(non_snake_case, dead_code)]
pub struct Edwards25519ElligatorDecodingCircuit {
    // Constant `A` used by `recover_message`; loaded from `CURVE25519_A`.
    A: BaseField,
    // Loaded from `EDWARDS25519_D`.
    d: BaseField,
}

impl Default for Edwards25519ElligatorDecodingCircuit {
    fn default() -> Self {
        Self::new()
    }
}

impl Edwards25519ElligatorDecodingCircuit {
    /// Builds the circuit description from the byte constants `CURVE25519_A` and
    /// `EDWARDS25519_D`.
    #[allow(dead_code)]
    pub fn new() -> Self {
        let montgomery_a = BaseField::from_le_bytes(CURVE25519_A);
        let edwards_d = BaseField::from_le_bytes(EDWARDS25519_D);
        Self {
            A: montgomery_a,
            d: edwards_d,
        }
    }
}

#[allow(non_snake_case, dead_code)]
impl Edwards25519ElligatorDecodingCircuit {
    /// Returns the smallest Elligator preimage found over the 8-torsion coset of `P`, with the
    /// offset of 1 added during encoding removed again.
    ///
    /// `P_plus_Q` is expected to be `P + Q`, where Q is any of the four rational points of
    /// order 8.
    fn recover_message<B, T>(&self, P: AffineEdwardsPoint<T>, P_plus_Q: AffineEdwardsPoint<T>) -> T
    where
        B: Boolean + Select<T, T, T>,
        T: F25519 + GetBit<Output = B> + Equal<T, Output = B> + GreaterEqual<T, Output = B> + From<B>,
    {
        let one = T::from(1);
        let a = T::from(self.A);
        let sqrt_neg_one = T::from(BaseField::from_le_bytes(SQRT_NEG_ONE));

        // Note: acting with the rational 4-torsion on (x, y) gives the coset
        // {(x, y), (sqrt(-1)*y, sqrt(-1)*x), (-x, -y), (-sqrt(-1)*y, -sqrt(-1)*x)}, so its
        // y-coordinates are exactly {y, -y, sqrt(-1)*x, -sqrt(-1)*x}. Doing this for both P and
        // P + Q covers the whole 8-torsion coset.
        let coset_ys = [
            P.y,
            -P.y,
            sqrt_neg_one * P.x,
            -sqrt_neg_one * P.x,
            P_plus_Q.y,
            -P_plus_Q.y,
            sqrt_neg_one * P_plus_Q.x,
            -sqrt_neg_one * P_plus_Q.x,
        ];

        let two_inv = T::from((BaseField::modulus() + 1) / 2);

        // Each y-coordinate contributes four candidate preimages, all tagged with the same
        // "is in the image of the encoding" flag.
        let mut candidates: Vec<(B, T)> = Vec::with_capacity(4 * coset_ys.len());
        for y in coset_ys {
            let c1 = one + y;
            let c2 = one + a + y * (one - a);
            // c2 is never 0: y = -(1+A)/(1-A) is not the y-coordinate of any curve point.
            let c = c1 * c2.invert(true);
            // c1 is never 0: the only point with y = -1 is the 2-torsion point (0, -1),
            // whereas P is a non-zero element of the large prime order subgroup
            // (thanks to the offset added while encoding).
            let c_inv = c2 * c1.invert(true);
            // is_square says whether y is the y-coordinate of a point in the image of the
            // elligator encoding.
            let (is_square, root) = sqrt::<BaseField, B, _>(-c * two_inv, true);
            candidates.extend([
                (is_square, root),
                (is_square, -root),
                (is_square, root * c_inv),
                (is_square, -(root * c_inv)),
            ]);
        }

        // Running minimum over the valid candidates, starting from the largest field element.
        // TODO: implement in depth 5
        let mut smallest = T::from(BaseField::modulus() - 1);
        for (is_in_image, cand) in candidates {
            smallest = (is_in_image & cand.lt(smallest)).select(cand, smallest);
        }

        // Undo the offset applied by the encoder.
        smallest - one
    }

    /// Decodes a base field element which was encoded via Edwards25519ElligatorEncodingCircuit.
    ///
    /// `P` is first multiplied by the scalar whose bits come from `EIGHT_INV_MOD_ELL`
    /// (judging by the name, the inverse of 8 modulo the subgroup order; confirm against its
    /// definition), then the message is recovered from the result and its translate by an
    /// 8-torsion point.
    fn decode<B, T>(&self, P: AffineEdwardsPoint<T>) -> T
    where
        B: Boolean + Select<T, T, T>,
        T: F25519 + GetBit<Output = B> + Equal<T, Output = B> + GreaterEqual<T, Output = B> + From<B>,
    {
        let scalar_bits: Vec<bool> = EIGHT_INV_MOD_ELL
            .into_iter()
            .flat_map(|byte| Byte::<bool>::from(byte).to_vec())
            .collect();
        let unscaled = P.to_projective().mul_bits(scalar_bits).to_affine();
        let translated = unscaled + AffineEdwardsPoint::eight_torsion_point();
        self.recover_message(unscaled, translated)
    }
}

impl ArithmeticCircuit<BaseField> for Edwards25519ElligatorDecodingCircuit {
    /// Decodes the point given as `[x, y]` and returns the message as a one-element vector.
    ///
    /// Fails with `EvalFailure::err_ub` when the decoded message is larger than 2^128.
    ///
    /// # Panics
    /// Panics if `x` does not contain exactly two elements.
    fn eval(&self, x: Vec<BaseField>) -> Result<Vec<BaseField>, EvalFailure> {
        let point = match x.as_slice() {
            &[px, py] => AffineEdwardsPoint::new((px, py), false, false),
            _ => panic!("EC Point to Base field decoding requires input of length 2"),
        };
        let message = self.decode::<bool, _>(point);
        if message > BaseField::power_of_two(128) {
            return EvalFailure::err_ub(
                "Circuit only supported for decoded messages of at most 2^128",
            );
        }
        Ok(vec![message])
    }

    /// Reports the single output as lying in the range from 0 to 2^128.
    fn bounds(&self, _bounds: Vec<FieldBounds<BaseField>>) -> Vec<FieldBounds<BaseField>> {
        let lower = BaseField::ZERO;
        let upper = BaseField::power_of_two(128);
        vec![FieldBounds::new(lower, upper)]
    }

    /// Decodes the point given as `[x, y]` and returns the message as a one-element vector.
    ///
    /// # Panics
    /// Panics if `vals` does not contain exactly two elements.
    fn run(&self, vals: Vec<FieldValue<BaseField>>) -> Vec<FieldValue<BaseField>> {
        let point = match vals.as_slice() {
            &[px, py] => AffineEdwardsPoint::new((px, py), false, false),
            _ => panic!("EC Point to Base field decoding requires input of length 2"),
        };
        vec![self.decode::<BooleanValue, _>(point)]
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::circuits::traits::arithmetic_circuit::tests::TestedArithmeticCircuit;
    use rand::Rng;

    /// Round-trips random messages of growing magnitude through encode/decode and checks that
    /// every encoded point satisfies -x^2 + y^2 = 1 + d * x^2 * y^2.
    #[test]
    fn test_encoding_decoding() {
        let rng = &mut crate::utils::test_rng::get();
        let encoder = Edwards25519ElligatorEncodingCircuit::new();
        let decoder = Edwards25519ElligatorDecodingCircuit::new();
        for _round in 0..4 {
            for log_bound in [8, 16, 32, 64, 128] {
                let message = BaseField::gen_inclusive_range(
                    &mut *rng,
                    BaseField::ZERO,
                    BaseField::power_of_two(log_bound),
                );
                let point = encoder.encode::<bool, _>(message);

                // The encoded point must lie on the twisted Edwards curve.
                let (x, y) = point.inner();
                let x_sq = x * x;
                let y_sq = y * y;
                let lhs = -x_sq + y_sq;
                let rhs = BaseField::ONE + encoder.get_d() * x_sq * y_sq;
                assert_eq!(lhs, rhs);

                // Decoding must give back the original message.
                let recovered = decoder.decode::<bool, _>(point);
                assert_eq!(recovered, message);
            }
        }
    }

    impl TestedArithmeticCircuit<BaseField> for Edwards25519ElligatorEncodingCircuit {
        /// The circuit has no random parameters, so the rng is ignored.
        fn gen_desc<R: Rng + ?Sized>(_rng: &mut R) -> Self {
            Edwards25519ElligatorEncodingCircuit::new()
        }

        /// The encoding always takes a single input.
        fn gen_n_inputs<R: Rng + ?Sized>(&self, _rng: &mut R) -> usize {
            1
        }

        /// Checks that decoding the circuit outputs gives back the inputs.
        fn extra_checks(&self, inputs: Vec<BaseField>, outputs: Vec<BaseField>) {
            let decoder = Edwards25519ElligatorDecodingCircuit::new();
            let round_trip = decoder.eval(outputs).unwrap();
            assert_eq!(inputs, round_trip);
        }
    }

    /// Runs the generic arithmetic-circuit test harness on the encoding circuit.
    #[test]
    fn tested_elligator_encoding() {
        Edwards25519ElligatorEncodingCircuit::test(1, 4)
    }
}