arcis-compiler 0.9.4

A framework for writing secure multi-party computation (MPC) circuits to be executed on the Arcium network.
Documentation
use crate::{
    core::{
        bounds::{CurveBounds, FieldBounds, IsBounds},
        expressions::{
            expr::EvalFailure,
            macro_uses::{BoundWrap, ConstantBoundSampler},
            InputKind,
        },
    },
    utils::{curve_point::CurvePoint, field::ScalarField, ignore_for_equality::IgnoreForEquality},
};
use arcis_internal_expr_macro::Expr;
use primitives::algebra::{
    elliptic_curve::{Curve as AsyncMPCCurve, Curve25519Ristretto},
    field::subfield_element::SubfieldElement,
};
use serde::{Deserialize, Serialize};
use std::{cell::Cell, marker::PhantomData, rc::Rc};

/// Identifier of a circuit input, as carried by `CurveExpr::Input` and returned by
/// `CurveExpr::get_input`.
pub type InputId = usize;

/// Metadata describing a single circuit input.
///
/// Values of this type are held behind an `Rc` in `CurveExpr::Input`; the flag field uses a
/// `Cell` so it can be updated through that shared reference.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct InputInfo {
    /// How the input is supplied to the circuit (see [`InputKind`]).
    pub kind: InputKind,
    /// Name of the input; [`InputInfo::default`] uses `"_"`.
    pub name: String,
    /// Flag recording whether the input has already been found unused (exposed through
    /// `CurveExpr::get_is_input_already_optimized_out`). Wrapped in [`IgnoreForEquality`],
    /// presumably so it does not take part in `PartialEq`/`Hash` — confirm against that type.
    pub has_already_been_found_unused: IgnoreForEquality<Cell<bool>>,
}

impl InputInfo {
    /// Builds the `core_utils` circuit input descriptor for this input over
    /// [`Curve25519Ristretto`].
    ///
    /// Every input is described as a single, unbatched curve point:
    /// - `InputKind::Plaintext` maps to `Input::Point(PointPlaintext::Input(1))`;
    /// - `InputKind::SecretFromPlayer(i)` maps to `Input::SecretPlaintext` with inputer `i`;
    /// - `InputKind::Secret` maps to `Input::Share`.
    pub fn to_async_mpc_input(&self) -> core_utils::circuit::Input<Curve25519Ristretto> {
        use core_utils::circuit::{AlgebraicType, Batched, Input, PointPlaintext};

        match self.kind {
            InputKind::Plaintext => Input::Point(PointPlaintext::<Curve25519Ristretto>::Input(1)),
            InputKind::SecretFromPlayer(inputer) => Input::SecretPlaintext {
                inputer,
                algebraic_type: AlgebraicType::Point,
                batched: Batched::No,
            },
            InputKind::Secret => Input::Share {
                algebraic_type: AlgebraicType::Point,
                batched: Batched::No,
            },
        }
    }
}

impl Default for InputInfo {
    fn default() -> Self {
        let kind = InputKind::Plaintext;
        let name = "_".to_owned();
        let has_already_been_found_unused = IgnoreForEquality(Cell::new(false));
        Self {
            kind,
            name,
            has_already_been_found_unused,
        }
    }
}

impl InputInfo {
    /// Returns `true` when this input's kind reports itself as plaintext.
    pub fn is_plaintext(&self) -> bool {
        let kind = &self.kind;
        kind.is_plaintext()
    }
}

impl From<InputKind> for InputInfo {
    fn from(value: InputKind) -> Self {
        InputInfo {
            kind: value,
            ..InputInfo::default()
        }
    }
}

/// Expressions for elliptic curve circuits.
///
/// The node is generic over what stands in for its operands:
/// - `C` is the operand type for curve-point positions;
/// - `S` is the operand type for scalar positions (defaults to `C`).
///
/// This file instantiates it as `CurveExpr<bool>` (plaintext-ness),
/// `CurveExpr<CurvePoint, ScalarField>` (evaluation) and
/// `CurveExpr<CurveBounds, FieldBounds<ScalarField>>` (bounds analysis).
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, Expr)]
pub enum CurveExpr<C: Clone, S: Clone = C> {
    // C = Curve
    // S = Scalar
    /// An input that will produce a curve point, identified by its [`InputId`] and
    /// carrying shared [`InputInfo`] metadata.
    Input(InputId, Rc<InputInfo>),
    /// Addition between two curve points.
    Add(C, C),
    /// Negation of a curve point.
    Neg(C),
    /// Multiplication between a scalar (first field) and a curve point (second field).
    Mul(S, C),
    /// Reveals a curve point, making it plaintext.
    /// Revealing an already revealed or plaintext curve point will work and not increase circuit
    /// size.
    Reveal(C),
    /// A curve point constant.
    Val(CurvePoint),
}

impl CurveExpr<bool> {
    /// Decides whether this node's value is plaintext, given a `bool` per operand that says
    /// whether that operand is plaintext.
    ///
    /// Inputs defer to their [`InputInfo`]; reveals and constants are always plaintext; the
    /// arithmetic nodes are plaintext exactly when every operand is.
    pub fn is_plaintext(&self) -> bool {
        match self {
            CurveExpr::Input(_, info) => info.is_plaintext(),
            CurveExpr::Reveal(_) | CurveExpr::Val(_) => true,
            CurveExpr::Add(..) | CurveExpr::Neg(_) | CurveExpr::Mul(..) => {
                self.get_deps().iter().all(|&dep| dep)
            }
        }
    }
}

impl<T: Clone, S: Clone> CurveExpr<T, S> {
    /// Returns `false` for `Input` nodes and `true` for every other node.
    pub fn is_eval_deterministic_fn_from_deps(&self) -> bool {
        self.get_input().is_none()
    }

    /// Returns the input identifier when this node is an `Input`, otherwise `None`.
    pub fn get_input(&self) -> Option<InputId> {
        if let CurveExpr::Input(id, _) = self {
            Some(*id)
        } else {
            None
        }
    }

    /// Returns the input's name for `Input` nodes, or the empty string for any other node.
    pub fn get_input_name(&self) -> &str {
        if let CurveExpr::Input(_, info) = self {
            info.name.as_str()
        } else {
            ""
        }
    }

    /// Returns the input's "already found unused" flag for `Input` nodes, otherwise `None`.
    pub fn get_is_input_already_optimized_out(&self) -> Option<&Cell<bool>> {
        if let CurveExpr::Input(_, info) = self {
            Some(&info.has_already_been_found_unused.0)
        } else {
            None
        }
    }
}

impl CurveExpr<CurvePoint, ScalarField> {
    /// Evaluates this node on concrete operand values.
    ///
    /// # Errors
    /// Returns an [`EvalFailure`] for `Input` nodes, which cannot be evaluated here.
    ///
    /// # Panics
    /// Panics if `from_canonical_bytes` rejects the little-endian bytes of a `Mul` node's
    /// scalar operand (the result is unwrapped).
    pub fn eval(self) -> Result<CurvePoint, EvalFailure> {
        let point: CurvePoint = match self {
            CurveExpr::Val(constant) => constant,
            CurveExpr::Reveal(inner) => inner,
            CurveExpr::Neg(inner) => -inner,
            CurveExpr::Add(lhs, rhs) => lhs + rhs,
            CurveExpr::Mul(scalar, base) => {
                // Re-encode the scalar as the async-MPC backend's scalar type.
                let scalar_bytes = scalar.to_le_bytes();
                let backend_scalar = SubfieldElement::new(
                    <Curve25519Ristretto as AsyncMPCCurve>::Scalar::from_canonical_bytes(
                        scalar_bytes,
                    )
                    .unwrap(),
                );
                // in async-mpc, multiplication is curve_point * scalar
                base * backend_scalar
            }
            CurveExpr::Input(_, _) => EvalFailure::err_imp("Input not evaluable here")?,
        };
        Ok(point)
    }
}

impl CurveExpr<CurveBounds, FieldBounds<ScalarField>> {
    /// Computes the bounds of this node's value from the bounds of its operands.
    ///
    /// - `Input` nodes are unconstrained (`CurveBounds::All`).
    /// - If any operand's bounds are empty, the result is `CurveBounds::Empty`.
    /// - If every operand is a constant, the node is evaluated on those constants and the
    ///   result is `CurveBounds::Constant`.
    /// - Otherwise the result is `CurveBounds::All`.
    ///
    /// # Panics
    /// Panics if evaluating the all-constant node fails.
    pub fn bounds(self) -> CurveBounds {
        if matches!(self, CurveExpr::Input(..)) {
            return CurveBounds::All;
        }

        // `self` is needed again for constant sampling, so wrap a copy.
        let operand_bounds = ApplyCurveExpr::apply(&mut BoundWrap, self.clone()).get_deps();
        if operand_bounds.iter().any(|bound| bound.is_empty()) {
            return CurveBounds::Empty;
        }
        if !operand_bounds
            .iter()
            .all(|bound| bound.as_constant().is_some())
        {
            return CurveBounds::All;
        }

        let constant_expr = ApplyCurveExpr::apply(&mut ConstantBoundSampler, self);
        let value = constant_expr
            .eval()
            .expect("Can't evaluate constant CurveExpr");
        CurveBounds::Constant(value)
    }
}