arcis-compiler 0.9.4

A framework for writing secure multi-party computation (MPC) circuits to be executed on the Arcium network.
Documentation
use crate::{
    core::{
        actually_used_field::ActuallyUsedField,
        circuits::boolean::{
            boolean_value::BooleanValue,
            utils::{
                abs,
                equal,
                euclidean_division,
                greater_than,
                is_number_zero,
                keep_ls_bits,
                shift_right,
                subtraction_circuit,
                CircuitType,
            },
        },
        compile_passes::{compilation_pass::LocalCompilationPass, new_eda_bit},
        expressions::{
            conversion_expr::ConversionExpr,
            expr::Expr,
            field_expr::FieldExpr,
            other_expr::OtherExpr,
        },
        global_value::{
            curve_value::CurveValue,
            global_expr_store::with_local_expr_store_as_global,
            value::FieldValue,
        },
        ir_builder::IRBuilder,
    },
    traits::{GetBit, Reveal},
    utils::{elliptic_curve::ProjectiveEdwardsPoint, field::BaseField, number::Number},
    STATISTICAL_SECURITY_FACTOR,
};
use ff::Field;
use rustc_hash::FxHashMap;

/// Which bit decompositions an operand of a comparison is known to have,
/// or will have once the circuit is built.
#[derive(Clone, Copy, Debug, Default)]
struct ArgInfo {
    /// Set when an unsigned bit decomposition of the operand is (or will be) available.
    will_have_unsigned_binary_decomposition: bool,
    /// Set when a signed bit decomposition of the operand is (or will be) available.
    will_have_signed_binary_decomposition: bool,
}

impl ArgInfo {
    /// Looks at expression `expr_id` in `expr_store` and reports which bit
    /// decompositions will be available for it.
    ///
    /// * A plaintext expression can be decomposed either way, so both flags are set.
    /// * A `BitToBitNum(bits, signed)` conversion already carries a decomposition of
    ///   its own signedness. An unsigned one with fewer than `F::CAPACITY` bits
    ///   describes a small non-negative number, which is a valid signed
    ///   decomposition as well.
    /// * Anything else gets neither flag.
    fn make_from_extra<F: ActuallyUsedField>(expr_id: usize, expr_store: &IRBuilder) -> Self {
        if expr_store.get_is_plaintext(expr_id) {
            return Self {
                will_have_unsigned_binary_decomposition: true,
                will_have_signed_binary_decomposition: true,
            };
        }
        match F::expr_to_conversion_expr(expr_store.get_expr(expr_id).clone()) {
            Some(ConversionExpr::BitToBitNum(_, true)) => Self {
                will_have_unsigned_binary_decomposition: false,
                will_have_signed_binary_decomposition: true,
            },
            Some(ConversionExpr::BitToBitNum(bits, false)) => Self {
                will_have_unsigned_binary_decomposition: true,
                // Short enough to be non-negative under the signed reading too.
                will_have_signed_binary_decomposition: bits.len() < F::CAPACITY as usize,
            },
            _ => Self::default(),
        }
    }

    /// Returns the signedness of a decomposition that both `self` and `other` share.
    ///
    /// Signed is preferred (`Some(true)`) when both operands have it. Otherwise
    /// unsigned (`Some(false)`) is returned when both have that. `None` means there
    /// is no common decomposition.
    fn find_common_signedness(&self, other: &Self) -> Option<bool> {
        let both_signed = self.will_have_signed_binary_decomposition
            && other.will_have_signed_binary_decomposition;
        let both_unsigned = self.will_have_unsigned_binary_decomposition
            && other.will_have_unsigned_binary_decomposition;
        match (both_signed, both_unsigned) {
            (true, _) => Some(true),
            (false, true) => Some(false),
            (false, false) => None,
        }
    }
}

/// What the `Equal` lowering needs to know about one comparison: whether its
/// result is plaintext, and which decompositions each operand has.
#[derive(Clone, Copy, Debug, Default)]
struct EqInfo {
    /// Whether the comparison result is plaintext.
    is_result_plaintext: bool,
    /// Decomposition facts about the left-hand operand.
    left_info: ArgInfo,
    /// Decomposition facts about the right-hand operand.
    right_info: ArgInfo,
}

/// Compilation pass that builds (unoptimized) boolean circuits for comparison,
/// division, bit-manipulation and curve-conversion expressions.
#[derive(Default)]
pub struct BooleanCircuitBuilder {
    /// The IR being rewritten. New expressions are created in it.
    expr_store: IRBuilder,
    /// Ids of the projective X, Y, Z coordinate expressions built for
    /// `OtherExpr::ToProjective`, keyed by the id of the curve expression.
    projective_cache: FxHashMap<usize, [usize; 3]>,
}

impl BooleanCircuitBuilder {
    /// Lowers one field expression over `F` into its boolean-circuit form.
    ///
    /// * `Equal`, `Gt`, `Ge`, `Div` and `Rem` are rewritten only when at least one
    ///   operand is not plaintext. With two plaintext operands they fall through to
    ///   the last arm unchanged.
    /// * `Abs`, `LogicalRightShift` and `KeepLsBits` are always rewritten through
    ///   their circuit helpers.
    /// * `Cap` becomes an explicit list of `size` bits, re-packed with an unsigned
    ///   `BitToBitNum`.
    /// * Every other expression is passed back through `F::field_expr_to_expr`.
    ///
    /// Circuits are built inside `with_local_expr_store_as_global`. That helper
    /// presumably makes `self.expr_store` the target of the `FieldValue`/`BooleanValue`
    /// operations performed in the closure (NOTE(review): inferred from the name;
    /// confirm).
    ///
    /// `is_plaintext` tells whether the result of `expr` is plaintext. It selects the
    /// plaintext-output equality gate, and in `Cap` it allows revealing the low mask bits.
    fn build_circuits<F: ActuallyUsedField>(
        &mut self,
        expr: FieldExpr<F, usize>,
        is_plaintext: bool,
    ) -> Expr<usize> {
        let expr_store = &mut self.expr_store;
        use FieldExpr::*;
        match expr {
            // `e1 == e2` where at least one operand is secret.
            Equal(e1, e2)
                if !expr_store.get_is_plaintext(e1) || !expr_store.get_is_plaintext(e2) =>
            {
                // Record which bit decompositions each operand has (or will have).
                let left_info = ArgInfo::make_from_extra::<F>(e1, expr_store);
                let right_info = ArgInfo::make_from_extra::<F>(e2, expr_store);
                let eq_info = EqInfo {
                    is_result_plaintext: is_plaintext,
                    left_info,
                    right_info,
                };
                with_local_expr_store_as_global(
                    || {
                        let left = FieldValue::<F>::from_id(e1);
                        let right = FieldValue::<F>::from_id(e2);
                        if eq_info.is_result_plaintext {
                            // Here we will end up using
                            // the gate that outputs the result in plaintext.
                            FieldValue::new(Equal(right - left, 0.into()))
                        } else {
                            FieldValue::<F>::from(
                                if let Some(signedness) = eq_info
                                    .left_info
                                    .find_common_signedness(&eq_info.right_info)
                                {
                                    // If both left and right already have or will have
                                    // a boolean decomposition of the same sign,
                                    // we use the boolean circuit.
                                    equal(left, right, signedness, CircuitType::default())
                                } else {
                                    // Otherwise we use the eda-bit based zero test
                                    // on `right - left`.
                                    is_number_zero(right - left, CircuitType::default())
                                },
                            )
                        }
                        .expr()
                    },
                    expr_store,
                )
            }
            // `e1 > e2` where at least one operand is secret. `signed` selects how the
            // operands are interpreted.
            Gt(e1, e2, signed)
                if !expr_store.get_is_plaintext(e1) || !expr_store.get_is_plaintext(e2) =>
            {
                let left_info = ArgInfo::make_from_extra::<F>(e1, expr_store);
                let right_info = ArgInfo::make_from_extra::<F>(e2, expr_store);
                with_local_expr_store_as_global(
                    || {
                        let left = FieldValue::<F>::from_id(e1);
                        let right = FieldValue::<F>::from_id(e2);
                        // Known value ranges of both operands under `signed`.
                        let (l_min, l_max) = left.bounds().min_and_max(signed);
                        let (r_min, r_max) = right.bounds().min_and_max(signed);
                        // The sign-bit shortcut in the `else` branch is only sound if
                        // `right - left` cannot overflow. Its extreme values are
                        // `r_min - l_max` and `r_max - l_min`, which are the two sums
                        // checked below.
                        //
                        // The boolean circuit is chosen instead when any of these holds:
                        // - the operands share a decomposition signedness;
                        // - the difference may overflow;
                        // - the unsigned case in the last clause applies.
                        // NOTE(review): confirm the intended meaning of
                        // `is_ge_zero`/`is_lt_zero` on unsigned bounds.
                        FieldValue::<F>::from(
                            if left_info.find_common_signedness(&right_info).is_some()
                                || r_min.does_add_signed_overflow(-l_max)
                                || r_max.does_add_signed_overflow(-l_min)
                                || !signed
                                    && (l_min.is_ge_zero() || r_min.is_ge_zero())
                                    && (l_max.is_lt_zero() || r_max.is_lt_zero())
                            {
                                // Prefer the signedness both operands already share;
                                // otherwise fall back to the requested one.
                                let signed_input = left_info
                                    .find_common_signedness(&right_info)
                                    .unwrap_or(signed);
                                // If both left and right already have or will have a boolean
                                // decomposition, we use the boolean circuit.
                                // The `false` flag here is `true` in `Ge`; presumably it is
                                // the result when `left == right`.
                                greater_than(
                                    left,
                                    right,
                                    BooleanValue::from(false),
                                    signed,
                                    signed_input,
                                )
                            } else {
                                // left > right exactly when right - left is negative,
                                // i.e. when its top (sign) bit is set.
                                // get_bit internally optimizes the depth - it might be less than
                                // log2(F::NUM_BITS)
                                (right - left).get_bit(F::NUM_BITS as usize - 1, true)
                            },
                        )
                        .expr()
                    },
                    expr_store,
                )
            }
            // `e1 >= e2` where at least one operand is secret. This mirrors `Gt`, but the
            // fallback uses the sign of `left - right`.
            Ge(e1, e2, signed)
                if !expr_store.get_is_plaintext(e1) || !expr_store.get_is_plaintext(e2) =>
            {
                let left_info = ArgInfo::make_from_extra::<F>(e1, expr_store);
                let right_info = ArgInfo::make_from_extra::<F>(e2, expr_store);
                with_local_expr_store_as_global(
                    || {
                        let left = FieldValue::<F>::from_id(e1);
                        let right = FieldValue::<F>::from_id(e2);
                        let (l_min, l_max) = left.bounds().min_and_max(signed);
                        let (r_min, r_max) = right.bounds().min_and_max(signed);
                        // Same selection as in `Gt`, except that the difference whose
                        // overflow matters is `left - right`. Its extremes are
                        // `l_min - r_max` and `l_max - r_min`.
                        FieldValue::<F>::from(
                            if left_info.find_common_signedness(&right_info).is_some()
                                || l_min.does_add_signed_overflow(-r_max)
                                || l_max.does_add_signed_overflow(-r_min)
                                || !signed
                                    && (l_min.is_ge_zero() || r_min.is_ge_zero())
                                    && (l_max.is_lt_zero() || r_max.is_lt_zero())
                            {
                                let signed_input = left_info
                                    .find_common_signedness(&right_info)
                                    .unwrap_or(signed);
                                // If both left and right already have or will have a boolean
                                // decomposition, we use the boolean circuit.
                                greater_than(
                                    left,
                                    right,
                                    BooleanValue::from(true),
                                    signed,
                                    signed_input,
                                )
                            } else {
                                // left >= right exactly when left - right is not negative,
                                // i.e. when its top (sign) bit is clear.
                                // get_bit internally optimizes the depth - it might be less than
                                // log2(F::NUM_BITS)
                                !(left - right).get_bit(F::NUM_BITS as usize - 1, true)
                            },
                        )
                        .expr()
                    },
                    expr_store,
                )
            }
            // Absolute value. This arm has no plaintext guard, so it is always rewritten.
            Abs(e) => with_local_expr_store_as_global(
                || abs(FieldValue::<F>::from_id(e)).expr(),
                expr_store,
            ),
            // Quotient (`.0`) of the Euclidean division, when an operand is secret.
            Div(e1, e2) if !expr_store.get_is_plaintext(e1) || !expr_store.get_is_plaintext(e2) => {
                with_local_expr_store_as_global(
                    || {
                        euclidean_division(
                            FieldValue::<F>::from_id(e1),
                            FieldValue::<F>::from_id(e2),
                        )
                        .0
                        .expr()
                    },
                    expr_store,
                )
            }
            // Remainder (`.1`) of the Euclidean division, when an operand is secret.
            Rem(e1, e2) if !expr_store.get_is_plaintext(e1) || !expr_store.get_is_plaintext(e2) => {
                with_local_expr_store_as_global(
                    || {
                        euclidean_division(
                            FieldValue::<F>::from_id(e1),
                            FieldValue::<F>::from_id(e2),
                        )
                        .1
                        .expr()
                    },
                    expr_store,
                )
            }
            // `e >> c`. The `false` flag presumably requests a logical (not arithmetic)
            // shift, matching the variant name.
            LogicalRightShift(e, c) => with_local_expr_store_as_global(
                || shift_right(FieldValue::<F>::from_id(e), c, false).expr(),
                expr_store,
            ),
            // Keep the `c` least significant bits of `e`. `signed_output` is forwarded
            // to the helper as-is.
            KeepLsBits(e, c, signed_output) => with_local_expr_store_as_global(
                || keep_ls_bits(FieldValue::<F>::from_id(e), c, signed_output).expr(),
                expr_store,
            ),
            // Re-express `e` as `size` explicit bits.
            // NOTE(review): assumes `e` is known to fit in `size` bits, as the name `Cap`
            // suggests; confirm.
            Cap(e, size) => {
                // Mask width: `size` bits plus STATISTICAL_SECURITY_FACTOR extra bits of
                // statistical masking, clamped to the field's bit size.
                let eda_bit_size = (size + STATISTICAL_SECURITY_FACTOR).min(F::NUM_BITS as usize);
                // Masking is used only for a secret `e`, and only when
                // 2^size + 2^eda_bit_size is below the modulus. Under the assumption
                // above, that means `e + r` cannot wrap around the field.
                let res_bits = if !expr_store.get_is_plaintext(e)
                    && Number::power_of_two(size) + Number::power_of_two(eda_bit_size)
                        < F::modulus()
                {
                    with_local_expr_store_as_global(
                        || {
                            let e = FieldValue::<F>::from_id(e);
                            // Presumably a random mask `r` as a field element together
                            // with its bit decomposition.
                            let (eda_bit_scalar, eda_bit_bits, _, _, _) =
                                new_eda_bit::<F>(eda_bit_size, false);
                            // Reveal the masked value `e + r` and take its low `size` bits.
                            let sum = e + eda_bit_scalar;
                            let revealed_sum = sum.reveal();
                            let revealed_sum_bits = (0..size)
                                .map(|i| revealed_sum.get_bit(i, false))
                                .collect::<Vec<BooleanValue>>();
                            // Subtract the low `size` bits of `r`. This presumably yields the
                            // low `size` bits of `(e + r) - r`, i.e. of `e`.
                            subtraction_circuit(
                                revealed_sum_bits,
                                eda_bit_bits
                                    .into_iter()
                                    .take(size)
                                    // if the result is plaintext, we can reveal these bits, but not
                                    // the statistical masking
                                    // ones.
                                    .map(|b| if is_plaintext { b.reveal() } else { b })
                                    .collect::<Vec<BooleanValue>>(),
                                CircuitType::default(),
                            )
                            .into_iter()
                            .map(|bit| bit.get_id())
                            .collect::<Vec<usize>>()
                        },
                        &mut self.expr_store,
                    )
                } else {
                    // Plaintext `e`, or a mask too wide for the field: decompose `e`
                    // directly, with one `BitNumToBit` conversion per bit.
                    (0..size)
                        .map(|i| {
                            self.expr_store.new_expr(F::conversion_expr_to_expr(
                                ConversionExpr::BitNumToBit(e, i, false),
                            ))
                        })
                        .collect()
                };
                // Pack the bits back into an unsigned number.
                F::conversion_expr_to_expr(ConversionExpr::<F, _>::BitToBitNum(res_bits, false))
            }

            // Plaintext-only comparisons/divisions and all other expressions.
            _ => F::field_expr_to_expr(expr),
        }
    }
    /// Lowers non-field expressions. Only `OtherExpr::ToProjective` is rewritten.
    ///
    /// `ToProjective(c, idx)` yields coordinate `idx` (0, 1, 2 for X, Y, Z) of the
    /// curve point `c` in projective form. The three coordinates are built once per `c`
    /// and cached in `projective_cache`.
    ///
    /// For a plaintext point the coordinates come directly from
    /// `to_extended_public`. For a secret point the steps are:
    /// 1. Offset `c` by the point returned from `CurveValue::da_point()` (presumably a
    ///    random mask).
    /// 2. Reveal the offset point and convert it publicly.
    /// 3. Add the mask back in projective coordinates.
    ///
    /// An `idx` outside `0..3` yields the constant zero. The returned expression is
    /// a copy of the stored coordinate expression.
    fn build_other_expr(&mut self, expr: OtherExpr<usize>) -> Expr<usize> {
        match expr {
            OtherExpr::ToProjective(c, idx) => {
                let cache_value = self.projective_cache.entry(c).or_insert_with(|| {
                    with_local_expr_store_as_global(
                        || {
                            let c = CurveValue::new(c);
                            let arr = if c.is_plaintext() {
                                // Public point: convert directly, dropping the fourth
                                // extended coordinate.
                                let (x, y, z, _) = c.to_extended_public();
                                [x, y, z]
                            } else {
                                // Secret point: reveal `c - mask`, convert it publicly,
                                // then add the mask back in projective form.
                                let (da_point_curve, da_point_proj) = CurveValue::da_point();
                                let masked = (c - da_point_curve).reveal();
                                let (x, y, z, _) = masked.to_extended_public();
                                let masked_proj =
                                    ProjectiveEdwardsPoint::new((x, y, z), true, true);
                                let res = masked_proj + da_point_proj;
                                [res.X, res.Y, res.Z]
                            };
                            arr.map(|x| x.get_id())
                        },
                        &mut self.expr_store,
                    )
                });
                cache_value
                    .get(idx)
                    .cloned()
                    .map(|expr_id| self.expr_store.get_expr(expr_id).clone())
                    .unwrap_or(Expr::Base(FieldExpr::Val(BaseField::ZERO)))
            }
            _ => Expr::Other(expr),
        }
    }
}

impl LocalCompilationPass for BooleanCircuitBuilder {
    /// Gives the pass framework mutable access to the IR being rewritten.
    fn expr_store(&mut self) -> &mut IRBuilder {
        &mut self.expr_store
    }

    /// Rewrites one expression and returns its replacement.
    ///
    /// Routing by variant:
    /// * `Scalar` and `Base` field expressions go through `build_circuits`.
    /// * `Other` expressions go through `build_other_expr`.
    /// * Any remaining variant is returned untouched.
    fn transform(&mut self, expr: Expr<usize>, is_plaintext: bool) -> Expr<usize> {
        match expr {
            Expr::Other(other_expr) => self.build_other_expr(other_expr),
            Expr::Base(base_expr) => self.build_circuits(base_expr, is_plaintext),
            Expr::Scalar(scalar_expr) => self.build_circuits(scalar_expr, is_plaintext),
            untouched => untouched,
        }
    }
}