// arcis-compiler 0.9.2
//
// A framework for writing secure multi-party computation (MPC) circuits to be executed
// on the Arcium network.
//
// (Source listing taken from the crate documentation.)
use crate::{
    core::{
        actually_used_field::ActuallyUsedField,
        bounds::FieldBounds,
        circuits::{
            arithmetic::{
                abs::AbsCircuit,
                bitwise_and::BitwiseAnd,
                float_div::Div,
                float_exp::{Exp, Exp2},
                float_log::{Ln, Log2},
                float_sqrt::{DivSqrt, Sqrt},
                lowest_bigger_power_of_two_minus_one::LowestBiggerPowerOfTwoMinusOne,
                max::Max,
                min::Min,
                sigmoid::Sigmoid,
                zero::ZeroCircuit,
            },
            boolean::{
                ed25519::{
                    Ed25519MXESign,
                    Ed25519Sign,
                    Ed25519Verify,
                    Ed25519VerifyingKeyFromSecretKey,
                },
                sha3::{SHA3_256, SHA3_512},
            },
            general::{conversion::ConversionCircuit, identity::IdentityCircuit},
            traits::{arithmetic_circuit::ArithmeticCircuit, general_circuit::GeneralCircuit},
        },
        expressions::{
            expr::{EvalValue, Expr},
            field_expr::FieldExpr,
            other_expr::OtherExpr,
        },
        global_value::{
            field_array::FieldArray,
            global_expr_store::with_global_expr_store_as_local,
            value::FieldValue,
        },
    },
    types::DOUBLE_PRECISION_MANTISSA,
    utils::field::BaseField,
};
use serde::{Deserialize, Serialize};

// Each configured circuit lives in an associated constant so that a reference such as
// `&StaticCircuits::MIN` is promoted to a `&'static` reference.
// `ArithmeticCircuitId::to_circuit` relies on this promotion to return
// `&'static dyn ArithmeticCircuit<F>`.
//
// This is the desugaring of inline `const { .. }` blocks described in
// https://github.com/rust-lang/rust/pull/104087. It was originally chosen because
// `inline_const` was unstable. Inline const blocks have been stable since Rust 1.79, so
// either form now works; this one is kept to avoid churn.
struct StaticCircuits;
impl StaticCircuits {
    const MIN: Min = Min::new(true);
    const MAX: Max = Max::new(true);
    const LOG2: Log2 = Log2::new(DOUBLE_PRECISION_MANTISSA);
    const LN: Ln = Ln::new(DOUBLE_PRECISION_MANTISSA);
    const EXP2: Exp2 = Exp2::new(DOUBLE_PRECISION_MANTISSA);
    const EXP: Exp = Exp::new(DOUBLE_PRECISION_MANTISSA);
    const SQRT: Sqrt = Sqrt::new(DOUBLE_PRECISION_MANTISSA);
    const DIV_SQRT: DivSqrt = DivSqrt::new(DOUBLE_PRECISION_MANTISSA, DOUBLE_PRECISION_MANTISSA);
    const DIV: Div = Div::new(DOUBLE_PRECISION_MANTISSA, DOUBLE_PRECISION_MANTISSA);
    const BITWISE_AND: BitwiseAnd = BitwiseAnd::new(true);
    const SIGMOID: Sigmoid = Sigmoid::new(DOUBLE_PRECISION_MANTISSA);
}

/// Identifier for an arithmetic circuit that `ArithmeticCircuitId::to_circuit` can
/// resolve for any field `F: ActuallyUsedField`.
///
/// NOTE: the derived `Hash`, and any serde format that encodes variant indices, depend
/// on declaration order. Prefer appending new variants to reordering existing ones.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ArithmeticCircuitId {
    /// `Min::new(true)`; textual name `"min"`.
    Min,
    /// `Max::new(true)`; textual name `"max"`.
    Max,
    /// `Log2::new(DOUBLE_PRECISION_MANTISSA)`; textual name `"float_log2"`.
    Log2,
    /// `Ln::new(DOUBLE_PRECISION_MANTISSA)`; textual name `"float_ln"`.
    Ln,
    /// `Exp2::new(DOUBLE_PRECISION_MANTISSA)`; textual name `"float_exp2"`.
    Exp2,
    /// `Exp::new(DOUBLE_PRECISION_MANTISSA)`; textual name `"float_exp"`.
    Exp,
    /// `Sqrt::new(DOUBLE_PRECISION_MANTISSA)`; textual name `"float_sqrt"`.
    Sqrt,
    /// `DivSqrt::new(DOUBLE_PRECISION_MANTISSA, DOUBLE_PRECISION_MANTISSA)`.
    /// NOTE(review): has no textual name in `BaseCircuitId::new_from_str`.
    DivSqrt,
    /// `Div::new(DOUBLE_PRECISION_MANTISSA, DOUBLE_PRECISION_MANTISSA)`; textual name
    /// `"float_div"`.
    Div,
    /// `BitwiseAnd::new(true)`; textual name `"bitwise_and"`.
    BitwiseAnd,
    /// `LowestBiggerPowerOfTwoMinusOne`; textual name
    /// `"lowest_bigger_power_of_two_minus_one"`.
    LowestBiggerPowerOfTwoMinusOne,
    /// `Sigmoid::new(DOUBLE_PRECISION_MANTISSA)`; textual name `"sigmoid"`.
    Sigmoid,
    /// `ZeroCircuit`; textual name `"zero"`.
    Zero,
    /// `AbsCircuit`; textual name `"abs"`.
    Abs,
    /// `IdentityCircuit`; textual name `"identity"`.
    Identity,
}

impl ArithmeticCircuitId {
    pub fn to_circuit<F: ActuallyUsedField>(self) -> &'static dyn ArithmeticCircuit<F> {
        match self {
            ArithmeticCircuitId::Min => &StaticCircuits::MIN,
            ArithmeticCircuitId::Max => &StaticCircuits::MAX,
            ArithmeticCircuitId::Log2 => &StaticCircuits::LOG2,
            ArithmeticCircuitId::Ln => &StaticCircuits::LN,
            ArithmeticCircuitId::Exp2 => &StaticCircuits::EXP2,
            ArithmeticCircuitId::Exp => &StaticCircuits::EXP,
            ArithmeticCircuitId::Sqrt => &StaticCircuits::SQRT,
            ArithmeticCircuitId::DivSqrt => &StaticCircuits::DIV_SQRT,
            ArithmeticCircuitId::Div => &StaticCircuits::DIV,
            ArithmeticCircuitId::BitwiseAnd => &StaticCircuits::BITWISE_AND,
            ArithmeticCircuitId::LowestBiggerPowerOfTwoMinusOne => &LowestBiggerPowerOfTwoMinusOne,
            ArithmeticCircuitId::Sigmoid => &StaticCircuits::SIGMOID,
            ArithmeticCircuitId::Zero => &ZeroCircuit,
            ArithmeticCircuitId::Abs => &AbsCircuit,
            ArithmeticCircuitId::Identity => &IdentityCircuit,
        }
    }
    pub fn apply<T: CircuitArg>(self, args: Vec<T>) -> Vec<T> {
        T::apply_arithmetic_circuit_id(args, self)
    }
}

/// Identifier for a circuit that resolves to a `dyn GeneralCircuit`
/// (see `GeneralCircuitId::to_circuit`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum GeneralCircuitId {
    /// Resolves to `ConversionCircuit`.
    Conversion,
}

impl GeneralCircuitId {
    /// Returns the `'static` general-circuit implementation named by this id.
    pub fn to_circuit(self) -> &'static dyn GeneralCircuit {
        match self {
            Self::Conversion => &ConversionCircuit,
        }
    }
}

/// Identifier for a circuit that is resolved over `BaseField` (see
/// `BaseCircuitId::to_circuit`). It covers the Ed25519 and SHA3 circuits, plus every
/// `ArithmeticCircuitId` through the `Arith` variant. Textual names below are the ones
/// accepted by `BaseCircuitId::new_from_str`.
///
/// NOTE: the derived `Hash`, and any serde format that encodes variant indices, depend
/// on declaration order. Prefer appending new variants to reordering existing ones.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum BaseCircuitId {
    /// Resolves to `Ed25519Sign`; textual name `"sign"`.
    Ed25519Sign,
    /// Resolves to `Ed25519MXESign`; textual name `"mxe-sign"`.
    Ed25519MXESign,
    /// Resolves to `Ed25519Verify`; textual name `"verify"`.
    Ed25519Verify,
    /// Resolves to `Ed25519VerifyingKeyFromSecretKey`; textual name
    /// `"verifying_key_from_secret_key"`.
    Ed25519VerifyingKeyFromSecretKey,
    /// Resolves to `SHA3_256`; textual name `"sha3-256"`.
    Sha3_256,
    /// Resolves to `SHA3_512`; textual name `"sha3-512"`.
    Sha3_512,
    /// Any arithmetic circuit, resolved via `ArithmeticCircuitId::to_circuit`.
    Arith(ArithmeticCircuitId),
}

impl BaseCircuitId {
    pub fn to_circuit(self) -> &'static dyn ArithmeticCircuit<BaseField> {
        match self {
            BaseCircuitId::Ed25519Sign => &Ed25519Sign,
            BaseCircuitId::Ed25519MXESign => &Ed25519MXESign,
            BaseCircuitId::Ed25519Verify => &Ed25519Verify,
            BaseCircuitId::Ed25519VerifyingKeyFromSecretKey => &Ed25519VerifyingKeyFromSecretKey,
            BaseCircuitId::Sha3_256 => &SHA3_256,
            BaseCircuitId::Sha3_512 => &SHA3_512,
            BaseCircuitId::Arith(a) => a.to_circuit(),
        }
    }
    pub fn apply<T: CircuitArg>(self, args: Vec<T>) -> Vec<T> {
        T::apply_base_circuit_id(args, self)
    }
    pub fn new_from_str(str: &str) -> Option<Self> {
        let res = match str {
            // Please keep alphabetical order.
            "abs" => BaseCircuitId::Arith(ArithmeticCircuitId::Abs),
            "bitwise_and" => BaseCircuitId::Arith(ArithmeticCircuitId::BitwiseAnd),
            "float_div" => BaseCircuitId::Arith(ArithmeticCircuitId::Div),
            "float_exp" => BaseCircuitId::Arith(ArithmeticCircuitId::Exp),
            "float_exp2" => BaseCircuitId::Arith(ArithmeticCircuitId::Exp2),
            "float_ln" => BaseCircuitId::Arith(ArithmeticCircuitId::Ln),
            "float_log2" => BaseCircuitId::Arith(ArithmeticCircuitId::Log2),
            "float_sqrt" => BaseCircuitId::Arith(ArithmeticCircuitId::Sqrt),
            "identity" => BaseCircuitId::Arith(ArithmeticCircuitId::Identity),
            "lowest_bigger_power_of_two_minus_one" => {
                BaseCircuitId::Arith(ArithmeticCircuitId::LowestBiggerPowerOfTwoMinusOne)
            }
            "max" => BaseCircuitId::Arith(ArithmeticCircuitId::Max),
            "min" => BaseCircuitId::Arith(ArithmeticCircuitId::Min),
            "sigmoid" => BaseCircuitId::Arith(ArithmeticCircuitId::Sigmoid),
            "sign" => BaseCircuitId::Ed25519Sign,
            "mxe-sign" => BaseCircuitId::Ed25519MXESign,
            "verify" => BaseCircuitId::Ed25519Verify,
            "verifying_key_from_secret_key" => BaseCircuitId::Ed25519VerifyingKeyFromSecretKey,
            "sha3-256" => BaseCircuitId::Sha3_256,
            "sha3-512" => BaseCircuitId::Sha3_512,
            "zero" => BaseCircuitId::Arith(ArithmeticCircuitId::Zero),
            _ => return None,
        };
        Some(res)
    }
}

/// A value type that circuit ids can be applied to (see `ArithmeticCircuitId::apply`
/// and `BaseCircuitId::apply`).
///
/// The implementations alongside this trait either evaluate the circuit concretely
/// (`BaseField`, `EvalValue`) or record expression nodes for it
/// (`FieldValue<BaseField>`, `FieldArray<N, BaseField>`).
pub trait CircuitArg: Sized {
    /// Applies the arithmetic circuit `c` to the inputs `v` and returns its outputs.
    fn apply_arithmetic_circuit_id(v: Vec<Self>, c: ArithmeticCircuitId) -> Vec<Self>;
    /// Applies the base-field circuit `c` to the inputs `v` and returns its outputs.
    fn apply_base_circuit_id(v: Vec<Self>, c: BaseCircuitId) -> Vec<Self>;
}

impl CircuitArg for BaseField {
    /// Evaluates the arithmetic circuit `c` concretely on the field elements `v`.
    ///
    /// # Panics
    /// If the circuit's `eval` does not succeed (its result is unwrapped).
    fn apply_arithmetic_circuit_id(v: Vec<Self>, c: ArithmeticCircuitId) -> Vec<Self> {
        let circuit = c.to_circuit::<BaseField>();
        circuit.eval(v).unwrap()
    }

    /// Evaluates the base-field circuit `c` concretely on the field elements `v`.
    ///
    /// # Panics
    /// If the circuit's `eval` does not succeed (its result is unwrapped).
    fn apply_base_circuit_id(v: Vec<Self>, c: BaseCircuitId) -> Vec<Self> {
        let circuit = c.to_circuit();
        circuit.eval(v).unwrap()
    }
}

impl CircuitArg for EvalValue {
    fn apply_arithmetic_circuit_id(v: Vec<Self>, c: ArithmeticCircuitId) -> Vec<Self> {
        CircuitArg::apply_arithmetic_circuit_id(
            v.into_iter()
                .map(|x| BaseField::from(x.to_signed_number()))
                .collect(),
            c,
        )
        .into_iter()
        .map(EvalValue::Base)
        .collect()
    }

    fn apply_base_circuit_id(v: Vec<Self>, c: BaseCircuitId) -> Vec<Self> {
        CircuitArg::apply_base_circuit_id(
            v.into_iter()
                .map(|x| BaseField::from(x.to_signed_number()))
                .collect(),
            c,
        )
        .into_iter()
        .map(EvalValue::Base)
        .collect()
    }
}

impl CircuitArg for FieldValue<BaseField> {
    /// Records one `FieldExpr::SubCircuit` node per output of the arithmetic circuit `c`.
    ///
    /// The output count is the length of `c`'s `bounds` when every input is given
    /// `FieldBounds::All`.
    fn apply_arithmetic_circuit_id(v: Vec<Self>, c: ArithmeticCircuitId) -> Vec<Self> {
        let input_bounds = vec![FieldBounds::<BaseField>::All; v.len()];
        let output_count = c.to_circuit().bounds(input_bounds).len();
        let mut outputs = Vec::with_capacity(output_count);
        for i in 0..output_count {
            outputs.push(FieldValue::new(FieldExpr::SubCircuit(v.clone(), c, i)));
        }
        outputs
    }

    /// Records one `OtherExpr::BaseArithmeticCircuit` node per output of `c` in the
    /// global expression store. Each node receives the ids of all inputs.
    ///
    /// The output count is the length of `c`'s `bounds` when every input is given
    /// `FieldBounds::All`.
    fn apply_base_circuit_id(v: Vec<Self>, c: BaseCircuitId) -> Vec<Self> {
        let input_bounds = vec![FieldBounds::All; v.len()];
        let output_count = c.to_circuit().bounds(input_bounds).len();
        let mut outputs = Vec::with_capacity(output_count);
        for i in 0..output_count {
            let id = with_global_expr_store_as_local(|expr_store| {
                let input_ids: Vec<usize> = v.iter().map(FieldValue::get_id).collect();
                expr_store.new_expr(Expr::Other(OtherExpr::BaseArithmeticCircuit(
                    input_ids, c, i,
                )))
            });
            outputs.push(FieldValue::from_id(id));
        }
        outputs
    }
}

impl<const N: usize> CircuitArg for FieldArray<N, BaseField> {
    fn apply_arithmetic_circuit_id(v: Vec<Self>, c: ArithmeticCircuitId) -> Vec<Self> {
        let all_bounds =
            std::iter::repeat_n(FieldBounds::<BaseField>::All, v.len()).collect::<Vec<_>>();
        let n = c.to_circuit().bounds(all_bounds).len();
        (0..n)
            .map(|i| {
                FieldArray::from(
                    TryInto::<[FieldValue<BaseField>; N]>::try_into(
                        (0..N)
                            .map(|j| {
                                FieldValue::new(FieldExpr::SubCircuit(
                                    v.iter()
                                        .copied()
                                        .map(|x| x[j])
                                        .collect::<Vec<FieldValue<BaseField>>>(),
                                    c,
                                    i,
                                ))
                            })
                            .collect::<Vec<FieldValue<BaseField>>>(),
                    )
                    .unwrap_or_else(|v: Vec<FieldValue<BaseField>>| {
                        panic!("Expected a Vec of length {} (found {})", N, v.len())
                    }),
                )
            })
            .collect::<Vec<Self>>()
    }

    fn apply_base_circuit_id(v: Vec<Self>, c: BaseCircuitId) -> Vec<Self> {
        let all_bounds = std::iter::repeat_n(FieldBounds::All, v.len()).collect::<Vec<_>>();
        let n = c.to_circuit().bounds(all_bounds).len();
        (0..n)
            .map(|i| {
                FieldArray::from(
                    TryInto::<[FieldValue<BaseField>; N]>::try_into(
                        (0..N)
                            .map(|j| {
                                FieldValue::from_id(with_global_expr_store_as_local(|expr_store| {
                                    expr_store.new_expr(Expr::Other(
                                        OtherExpr::BaseArithmeticCircuit(
                                            v.iter()
                                                .copied()
                                                .map(|x| x[j].get_id())
                                                .collect::<Vec<usize>>(),
                                            c,
                                            i,
                                        ),
                                    ))
                                }))
                            })
                            .collect::<Vec<FieldValue<BaseField>>>(),
                    )
                    .unwrap_or_else(|v: Vec<FieldValue<BaseField>>| {
                        panic!("Expected a Vec of length {} (found {})", N, v.len())
                    }),
                )
            })
            .collect::<Vec<Self>>()
    }
}