// arcis-compiler 0.9.4
//
// A framework for writing secure multi-party computation (MPC) circuits to be executed on the
// Arcium network.
// Documentation
use crate::{
    core::{
        bounds::{BoolBounds, Bounds, CurveBounds, FieldBounds, IsBounds},
        expressions::{
            circuit::{BaseCircuitId, GeneralCircuitId},
            expr::{EvalFailure, EvalValue},
            field_expr::InputId,
        },
        mxe_input::{ArxInput, MxeInput},
    },
    traits::FromLeBytes,
    utils::{
        curve_point::CurvePoint,
        elliptic_curve::{EDWARDS25519_D, FOUR_INV_MOD_ELL},
        field::{BaseField, ScalarField},
    },
};
use arcis_internal_expr_macro::Expr;
use core::marker::PhantomData;
use curve25519_dalek_arcium_fork::{field::FieldElement, EdwardsPoint, RistrettoPoint, Scalar};
use ff::{Field, PrimeField};
use serde::{Deserialize, Serialize};
use std::cell::Cell;

/// Expression nodes covering circuit outputs, conversions between plaintext curve points and
/// extended Edwards coordinates, point-compression bits, and MXE key inputs.
///
/// The type parameters are the operand payloads in scalar (`T`), base-field (`B`) and curve (`C`)
/// positions; `B` and `C` default to `T`. This file instantiates them with concrete values
/// (`ScalarField`, `BaseField`, `CurvePoint`) for `eval`, with bounds (`FieldBounds<_>`,
/// `CurveBounds`) for `bounds`, and with `bool` for `is_plaintext`.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, Expr)]
#[allow(clippy::enum_variant_names)]
pub enum OtherExpr<T: Clone, B: Clone = T, C: Clone = T> {
    /// A scalar equal to the ith scalar result of applying the general circuit to the provided
    /// scalar (`Vec<T>`) and base (`Vec<B>`) inputs, where i is the `usize` field.
    /// If there are i or fewer scalar results, then the result is zero.
    ScalarGeneralCircuit(Vec<T>, Vec<B>, GeneralCircuitId, usize),
    /// A base element equal to the ith base result of applying the general circuit to the
    /// provided scalar and base inputs, where i is the `usize` field.
    /// 0 if the requested result is not present.
    BaseGeneralCircuit(Vec<T>, Vec<B>, GeneralCircuitId, usize),
    /// A base element equal to the ith result of applying the base-field arithmetic circuit to
    /// the provided base inputs, where i is the `usize` field; 0 if there is no such result.
    BaseArithmeticCircuit(Vec<B>, BaseCircuitId, usize),
    /// Plaintext curve to extended Edwards coordinates conversion.
    /// Returns one coordinate of the prime order representative, computed as
    /// [4^-1 mod ell] * ([4] P), dehomogenized so that Z = 1. The `usize` selects the coordinate:
    /// 0 -> X, 1 -> Y, 2 -> Z (always 1), 3 -> T, anything else -> 0.
    PlaintextCurveToExtendedEdwards(C, usize),
    /// Curve from plaintext extended Edwards coordinates (X, Y, Z, T) conversion.
    /// If -X^2 + Y^2 = Z^2 + d*T^2, X*Y = Z*T and Z != 0 all hold, the point is mapped through
    /// P -> [2^-1 mod ell] * ([2] P) into a Ristretto point; otherwise the result is the identity.
    CurveFromPlaintextExtendedEdwards(B, B, B, B),
    /// Compress (encode) a plaintext point, as explained [here](https://ristretto.group/details/isogeny_encoding.html).
    /// The result is the bit selected by the `usize` i: bit `i % 8` (least significant first) of
    /// byte `i / 8` of the encoding. `eval` asserts that i < 256.
    CompressPlaintextPoint(C, usize),
    /// One coordinate of the dehomogenized prime order representative (the same value as
    /// `PlaintextCurveToExtendedEdwards` for indices 0..3):
    /// if usize is 0, x
    /// if usize is 1, y
    /// if usize is 2, z (which evaluates to 1)
    /// else 0
    ToProjective(C, usize),
    /// An MXE key input; plaintext evaluation uses `MxeInput::mock_eval`.
    MxeKey(MxeInput),
}

impl OtherExpr<ScalarField, BaseField, CurvePoint> {
    #[allow(non_snake_case)]
    pub fn eval(self) -> Result<EvalValue, EvalFailure> {
        match self {
            OtherExpr::ScalarGeneralCircuit(s, b, c, i) => {
                let (res, _) = c.to_circuit().eval(s, b)?;
                Ok(if i < res.len() {
                    EvalValue::Scalar(res[i])
                } else {
                    EvalValue::Scalar(ScalarField::ZERO)
                })
            }
            OtherExpr::BaseGeneralCircuit(s, b, c, i) => {
                let (_, res) = c.to_circuit().eval(s, b)?;
                Ok(if i < res.len() {
                    EvalValue::Base(res[i])
                } else {
                    EvalValue::Base(BaseField::ZERO)
                })
            }
            OtherExpr::BaseArithmeticCircuit(v, c, i) => Ok(EvalValue::Base(
                c.to_circuit()
                    .eval(v)?
                    .get(i)
                    .cloned()
                    .unwrap_or(BaseField::ZERO),
            )),
            OtherExpr::PlaintextCurveToExtendedEdwards(c, i) => {
                let edwards_point = c.inner().inner();
                let prime_order_repr = Scalar::from_canonical_bytes(FOUR_INV_MOD_ELL).unwrap()
                    * (Scalar::from(4u8) * edwards_point);
                if prime_order_repr.Z() == FieldElement::ZERO {
                    Err(EvalFailure::ub("Point at infinity"))
                } else {
                    // Dehomogenize with respect to Z.
                    let Z_inv = prime_order_repr.Z().invert();
                    Ok(match i {
                        0 => EvalValue::Base(BaseField::from_le_bytes(
                            (&prime_order_repr.X() * &Z_inv).as_bytes(),
                        )),
                        1 => EvalValue::Base(BaseField::from_le_bytes(
                            (&prime_order_repr.Y() * &Z_inv).as_bytes(),
                        )),
                        2 => EvalValue::Base(BaseField::from(1)),
                        3 => EvalValue::Base(BaseField::from_le_bytes(
                            (&prime_order_repr.T() * &Z_inv).as_bytes(),
                        )),
                        _ => EvalValue::Base(BaseField::ZERO),
                    })
                }
            }
            OtherExpr::CurveFromPlaintextExtendedEdwards(X, Y, Z, T) => {
                if -X * X + Y * Y == Z * Z + BaseField::from_le_bytes(EDWARDS25519_D) * T * T
                    && X * Y == Z * T
                    && Z != BaseField::ZERO
                {
                    let edwards_point = EdwardsPoint::new_unchecked(
                        FieldElement::from_bytes(&X.to_le_bytes()),
                        FieldElement::from_bytes(&Y.to_le_bytes()),
                        FieldElement::from_bytes(&Z.to_le_bytes()),
                        FieldElement::from_bytes(&T.to_le_bytes()),
                    );
                    // apply the map E -> [2]E, P \mapsto [2^-1 mod \ell] \circ [2] P
                    let ristretto_point = RistrettoPoint::new_unchecked(
                        (edwards_point * Scalar::from(2u8)) * Scalar::TWO_INV,
                    );
                    Ok(EvalValue::Curve(CurvePoint::new(ristretto_point)))
                } else {
                    Ok(EvalValue::Curve(CurvePoint::identity()))
                }
            }
            OtherExpr::CompressPlaintextPoint(c, i) => {
                assert!(i < 256);
                let compressed = c.to_bytes();
                Ok(EvalValue::Bit(compressed[i / 8] >> (i % 8) & 1u8 == 1u8))
            }
            OtherExpr::ToProjective(c, idx) => {
                if idx < 3 {
                    OtherExpr::PlaintextCurveToExtendedEdwards(c, idx).eval()
                } else {
                    Ok(EvalValue::Base(BaseField::ZERO))
                }
            }
            OtherExpr::MxeKey(k) => Ok(k.mock_eval()),
        }
    }
}

impl OtherExpr<bool> {
    /// Reports whether this expression's value is plaintext.
    ///
    /// In this instantiation every dependency slot holds a `bool` flag (presumably "this operand
    /// is plaintext" — confirm against the caller). `MxeKey` defers to its key input; every
    /// other variant is plaintext exactly when all of its dependency flags are `true`.
    pub fn is_plaintext(&self) -> bool {
        match self {
            OtherExpr::MxeKey(key) => key.is_plaintext(),
            _ => self.get_deps().iter().all(|flag| *flag),
        }
    }
}

impl<T: Clone, B: Clone, C: Clone> OtherExpr<T, B, C> {
    /// Returns whether evaluation is a deterministic function of the dependencies.
    ///
    /// Only `MxeKey` reports `false`; every other variant reports `true`.
    pub fn is_eval_deterministic_fn_from_deps(&self) -> bool {
        let is_mxe_key = matches!(self, OtherExpr::MxeKey(_));
        !is_mxe_key
    }

    /// Always `None`: no `OtherExpr` variant carries an `InputId`.
    pub fn get_input(&self) -> Option<InputId> {
        Option::None
    }

    /// Always the empty string, since there is no associated input.
    pub fn get_input_name(&self) -> &str {
        ""
    }

    /// Always `None`: there is no input whose optimization state could be tracked.
    pub fn get_is_input_already_optimized_out(&self) -> Option<&Cell<bool>> {
        Option::None
    }
}

impl OtherExpr<FieldBounds<ScalarField>, FieldBounds<BaseField>, CurveBounds> {
    /// Computes bounds on the value of this expression from bounds on its operands.
    ///
    /// Circuit variants delegate to the circuit's own `bounds`, and `MxeKey` to the key input.
    /// The curve-related variants return the exact constant when the operands they depend on
    /// are constant. Otherwise they fall back to the unconstrained bounds (`FieldBounds::All`,
    /// `CurveBounds::All`, or `BoolBounds::new(true, true)` for compression bits). Coordinate
    /// selectors that `eval` maps to zero regardless of the point (out-of-range indices) always
    /// get constant-zero bounds.
    ///
    /// # Panics
    /// `CompressPlaintextPoint` with a constant point indexes byte `i / 8` of the encoding and
    /// panics if that byte does not exist.
    pub fn bounds(self) -> Bounds {
        match self {
            OtherExpr::ScalarGeneralCircuit(s, b, c, i) => {
                let (res, _) = c.to_circuit().bounds(s, b);
                Bounds::Scalar(
                    res.get(i)
                        .copied()
                        .unwrap_or_else(|| FieldBounds::new(ScalarField::ZERO, ScalarField::ZERO)),
                )
            }
            OtherExpr::BaseGeneralCircuit(s, b, c, i) => {
                let (_, res) = c.to_circuit().bounds(s, b);
                Bounds::Base(
                    res.get(i)
                        .copied()
                        .unwrap_or_else(|| FieldBounds::new(BaseField::ZERO, BaseField::ZERO)),
                )
            }
            OtherExpr::BaseArithmeticCircuit(v, c, i) => Bounds::Base(
                c.to_circuit()
                    .bounds(v)
                    .get(i)
                    .copied()
                    .unwrap_or_else(|| FieldBounds::new(BaseField::ZERO, BaseField::ZERO)),
            ),
            OtherExpr::PlaintextCurveToExtendedEdwards(c, i) => Bounds::Base(if i >= 4 {
                // Out-of-range selectors evaluate to zero whatever the point is.
                FieldBounds::from(BaseField::ZERO)
            } else {
                match c.as_constant() {
                    None => FieldBounds::All,
                    Some(point) => {
                        // Same prime order representative as `eval`: [4^-1 mod ell] * ([4] P).
                        let repr = CurvePoint::new(
                            Scalar::from_canonical_bytes(FOUR_INV_MOD_ELL).unwrap()
                                * (Scalar::from(4u8) * point.inner()),
                        );
                        let edwards = repr.inner().inner();
                        let z = edwards.Z();
                        if z == FieldElement::ZERO {
                            // `eval` reports UB for this case; the bounds report constant 0.
                            FieldBounds::from(BaseField::ZERO)
                        } else if i == 2 {
                            // Dehomogenized Z coordinate.
                            FieldBounds::from(BaseField::from(1))
                        } else {
                            // Dehomogenize X (i == 0), Y (i == 1) or T (i == 3) by Z.
                            let numerator = match i {
                                0 => edwards.X(),
                                1 => edwards.Y(),
                                _ => edwards.T(),
                            };
                            FieldBounds::from(BaseField::from_le_bytes(
                                (&numerator * &z.invert()).as_bytes(),
                            ))
                        }
                    }
                }
            }),
            OtherExpr::CurveFromPlaintextExtendedEdwards(x, y, z, t) => {
                match (x.as_constant(), y.as_constant(), z.as_constant(), t.as_constant()) {
                    (Some(x), Some(y), Some(z), Some(t)) => {
                        let d = BaseField::from_le_bytes(EDWARDS25519_D);
                        // Same validity test as `eval`: curve equation, X*Y = Z*T, and Z != 0.
                        if -x * x + y * y == z * z + d * t * t
                            && x * y == z * t
                            && z != BaseField::ZERO
                        {
                            let edwards_point = EdwardsPoint::new_unchecked(
                                FieldElement::from_bytes(&x.to_le_bytes()),
                                FieldElement::from_bytes(&y.to_le_bytes()),
                                FieldElement::from_bytes(&z.to_le_bytes()),
                                FieldElement::from_bytes(&t.to_le_bytes()),
                            );
                            // apply the map E -> [2]E, P \mapsto [2^-1 mod \ell] \circ [2] P
                            let ristretto_point = RistrettoPoint::new_unchecked(
                                (edwards_point * Scalar::from(2u8)) * Scalar::TWO_INV,
                            );
                            Bounds::Curve(CurveBounds::Constant(CurvePoint::new(ristretto_point)))
                        } else {
                            Bounds::Curve(CurveBounds::Constant(CurvePoint::identity()))
                        }
                    }
                    _ => Bounds::Curve(CurveBounds::All),
                }
            }
            OtherExpr::CompressPlaintextPoint(c, i) => Bounds::Bit(match c.as_constant() {
                Some(point) => {
                    // Bit i is bit (i % 8), least significant first, of byte (i / 8).
                    let byte = point.to_bytes()[i / 8];
                    BoolBounds::from((byte >> (i % 8)) & 1u8 == 1u8)
                }
                // NOTE(review): confirm that `BoolBounds::new(true, true)` encodes "the bit may be
                // either value". If its arguments are a (min, max) pair, this claims the bit is
                // always set for an unknown point.
                None => BoolBounds::new(true, true),
            }),
            // `eval` returns zero for idx >= 3 without looking at the point, so constant-zero
            // bounds are exact even when the point is unknown.
            OtherExpr::ToProjective(_, idx) if idx >= 3 => {
                Bounds::Base(FieldBounds::new(BaseField::ZERO, BaseField::ZERO))
            }
            OtherExpr::ToProjective(c, idx) => match c.as_constant() {
                Some(point) => match OtherExpr::ToProjective(point, idx).eval() {
                    Ok(EvalValue::Base(v)) => Bounds::Base(FieldBounds::new(v, v)),
                    _ => Bounds::Base(FieldBounds::All),
                },
                None => Bounds::Base(FieldBounds::All),
            },
            OtherExpr::MxeKey(k) => k.bounds(),
        }
    }
}