arcis-compiler 0.9.7

A framework for writing secure multi-party computation (MPC) circuits to be executed on the Arcium network.
Documentation
use crate::{
    core::{
        actually_used_field::ActuallyUsedField,
        circuits::boolean::{
            boolean_value::BooleanValue,
            utils::{addition_circuit, subtraction_circuit, CircuitType},
        },
        compile_passes::compilation_pass::LocalCompilationPass,
        expressions::{
            conversion_expr::{
                ConversionExpr::{self, BitFromEdaBit, EdaBit, ScalarFromEdaBit},
                EdaBitId,
            },
            expr::Expr,
        },
        global_value::{
            global_expr_store::{with_global_expr_store_as_local, with_local_expr_store_as_global},
            value::FieldValue,
        },
        ir::IntermediateRepresentation,
        ir_builder::{ExprStore, IRBuilder},
    },
    traits::{GetBit, Reveal, Select},
    utils::number::Number,
    STATISTICAL_SECURITY_FACTOR,
};
use num_bigint::BigInt;
use rustc_hash::FxHashMap;
use std::marker::PhantomData;

/// Generate a new eda bit of a specific size.
///
/// The eda bit is allocated in the global expression store as one `EdaBit` expression, one
/// `ScalarFromEdaBit` expression (its field-element component) and `size` `BitFromEdaBit`
/// expressions (its boolean component; bit `i` is stored at index `i`, i.e. least-significant
/// bit first).
///
/// Returns, in order:
/// 1. the scalar component of the eda bit;
/// 2. the boolean component (`size` bits), reduced modulo `F::modulus()` when
///    `size == F::NUM_BITS` and `reduced == true`;
/// 3. the bits of eda - F::modulus() as a signed (F::NUM_BITS + 1)-bit integer;
/// 4. the bits of eda + (F::modulus()-1)/2 as an unsigned (F::NUM_BITS + 1)-bit integer;
/// 5. the bits of eda - (F::modulus()+1)/2 as a signed (F::NUM_BITS + 1)-bit integer.
///
/// Items 3-5 are only computed when `size == F::NUM_BITS` and `reduced == true`; in every other
/// case they are placeholders made of `size + 1` constant `false` bits.
///
/// Note:
/// - If size < F::NUM_BITS then the binary expansion of the eda bit is reduced.
/// - If size == F::NUM_BITS then one can choose whether to reduce or not.
/// - If size > F::NUM_BITS we only support unreduced expansion.
///
/// # Panics
///
/// Panics if `size > F::NUM_BITS` and `reduced` is `true`.
// The 5-tuple return type is deliberate; its layout is documented above.
#[allow(clippy::type_complexity)]
pub fn new_eda_bit<F: ActuallyUsedField>(
    size: usize,
    reduced: bool,
) -> (
    FieldValue<F>,
    Vec<BooleanValue>,
    Vec<BooleanValue>,
    Vec<BooleanValue>,
    Vec<BooleanValue>,
) {
    assert!(
        size <= F::NUM_BITS as usize || !reduced,
        "Reduced eda bit for size {size} not supported."
    );
    // Push the EdaBit itself, its scalar component and one expression per bit into the
    // (temporarily local) global expression store, and keep only the resulting ids.
    let (eda_bit_scalar_id, eda_bit_bit_ids) = with_global_expr_store_as_local(|expr_store| {
        let eda_bit_expr_id =
            expr_store.push_conversion(EdaBit(EdaBitId::new(), size, PhantomData::<F>));
        let eda_bit_scalar_id =
            expr_store.push_conversion(ScalarFromEdaBit::<F, usize>(eda_bit_expr_id));
        let eda_bit_bit_ids = (0..size)
            .map(|i| expr_store.push_conversion(BitFromEdaBit::<F, usize>(eda_bit_expr_id, i)))
            .collect::<Vec<usize>>();
        (eda_bit_scalar_id, eda_bit_bit_ids)
    });
    let eda_bit_scalar = FieldValue::<F>::from_id(eda_bit_scalar_id);
    let mut eda_bit_bits = eda_bit_bit_ids
        .into_iter()
        .map(BooleanValue::new)
        .collect::<Vec<BooleanValue>>();
    // The three derived expansions below default to all-false placeholders of size + 1 bits;
    // they are only overwritten in the reduced, full-width case.
    // eda - p
    let mut eda_bit_minus_modulus_bits = vec![BooleanValue::from(false); eda_bit_bits.len() + 1];
    // eda + (p-1)/2
    let mut eda_bit_plus_offset_bits = vec![BooleanValue::from(false); eda_bit_bits.len() + 1];
    // eda - p + (p-1)/2 = eda - (p+1)/2
    let mut eda_bit_minus_modulus_plus_offset_bits =
        vec![BooleanValue::from(false); eda_bit_bits.len() + 1];
    if size == F::NUM_BITS as usize && reduced {
        // Make sure eda_bit_bits represents a number strictly less than the modulus.
        // Append a zero top bit so all arithmetic below works on (NUM_BITS + 1)-bit values and
        // differences can be read as two's-complement numbers.
        eda_bit_bits.push(BooleanValue::from(false));
        let modulus = BigInt::from(F::modulus());
        // Constant bits of p, least-significant first, over NUM_BITS + 1 positions.
        let modulus_bits = (0..eda_bit_bits.len())
            .map(|i| {
                BooleanValue::from((modulus.clone() >> i) & BigInt::from(1) == BigInt::from(1))
            })
            .collect::<Vec<BooleanValue>>();
        // rotate_right(1) moves every bit to the next-higher index (and the old top bit to
        // index 0), i.e. it computes 2p. This is exact only while the top bit of modulus_bits
        // is 0, i.e. p < 2^NUM_BITS (assumed here).
        let mut two_modulus_bits = modulus_bits.clone();
        two_modulus_bits.rotate_right(1);
        let eda_bit_minus_modulus =
            subtraction_circuit(eda_bit_bits.clone(), modulus_bits, CircuitType::default());
        // eda - 2p might not be correctly represented if eda is already reduced
        let eda_bit_minus_two_modulus = subtraction_circuit(
            eda_bit_bits.clone(),
            two_modulus_bits,
            CircuitType::default(),
        );
        // Top (sign) bit of eda - p: set iff eda < p, i.e. eda was already reduced (assuming
        // subtraction_circuit returns a same-width two's-complement difference).
        let is_already_reduced = *eda_bit_minus_modulus.last().unwrap();
        // A single conditional subtraction of p suffices only if eda < 2p, which holds when
        // 2^NUM_BITS <= 2p, i.e. NUM_BITS is the bit length of p -- TODO confirm.
        eda_bit_bits = is_already_reduced.select(eda_bit_bits, eda_bit_minus_modulus.clone());
        eda_bit_minus_modulus_bits =
            is_already_reduced.select(eda_bit_minus_modulus, eda_bit_minus_two_modulus);

        // now compute eda + (p-1)/2 and eda - (p+1)/2
        // TWO_INV = (p+1)/2, hence TWO_INV - 1 = (p-1)/2.
        let offset = FieldValue::from(F::TWO_INV - F::ONE);
        let offset_bits = (0..eda_bit_bits.len())
            .map(|i| offset.get_bit(i, false))
            .collect::<Vec<BooleanValue>>();
        eda_bit_plus_offset_bits = addition_circuit(
            eda_bit_bits.clone(),
            offset_bits.clone(),
            BooleanValue::from(false),
            CircuitType::default(),
        );
        eda_bit_minus_modulus_plus_offset_bits = addition_circuit(
            eda_bit_minus_modulus_bits.clone(),
            offset_bits,
            BooleanValue::from(false),
            CircuitType::default(),
        );
        // Drop the padding bit pushed above so the returned expansion has `size` bits again.
        let _ = eda_bit_bits.pop();
    }
    (
        eda_bit_scalar,
        eda_bit_bits,
        eda_bit_minus_modulus_bits,
        eda_bit_plus_offset_bits,
        eda_bit_minus_modulus_plus_offset_bits,
    )
}

/// State carried across all conversions lowered by one `EdaBitIntroducer` pass.
#[derive(Default)]
struct ConversionInfo {
    /// Number of extra random bits given to a mask beyond the width of the masked value
    /// (capped at `F::NUM_BITS`). Written once in `setup`, read-only afterwards.
    security_factor: usize,
    /// Memoized bit expansions: `(scalar expr id, bit index, signed)` maps to the id of the
    /// boolean expression holding that bit.
    bit_num_to_bit_to_expr_id: FxHashMap<(usize, usize, bool), usize>,
}

/// Compilation pass that lowers arithmetic <-> boolean conversions by introducing EdaBits
/// (extended doubly authenticated bits): random values that are available both as a field
/// element and as their bit decomposition.
#[derive(Default)]
pub struct EdaBitIntroducer {
    /// Memoized bit expansions and the security factor used when masking.
    conversion_info: ConversionInfo,
    /// Builder that receives the rewritten expressions.
    expr_store: IRBuilder,
}

impl EdaBitIntroducer {
    /// Lowers one conversion expression, introducing eda bits wherever a secret value has to move
    /// between the arithmetic (field) domain and the boolean domain.
    ///
    /// - `BitNumToBit(e, i, signed)` with a non-plaintext `e`: returns bit `i` of the signed or
    ///   unsigned binary expansion of `e`. The whole expansion (`bounds.bin_size(signed)` bits) is
    ///   computed at once and memoized in `conversion_info.bit_num_to_bit_to_expr_id` under
    ///   `(e, bit, signed)`. Bits at or above that width are `false` when unsigned and a copy of
    ///   the top bit when signed.
    /// - `BitToBitNum(v, signed)`: recombines the bits `v` (bit `i` weighted `2^i`; when `signed`
    ///   the last bit gets the negative weight `F::negative_power_of_two(n - 1)`) into a field
    ///   element.
    /// - Any other conversion, including `BitNumToBit` on a plaintext scalar, is forwarded to
    ///   `F::conversion_expr_to_expr`.
    ///
    /// `is_plaintext` is only read by the `BitToBitNum` arm; presumably it tells whether the
    /// converted value may be public (TODO confirm against the pass driver).
    fn conversion_eda_bits<F: ActuallyUsedField>(
        &mut self,
        expr: ConversionExpr<F, usize>,
        is_plaintext: bool,
    ) -> Expr<usize> {
        use ConversionExpr::*;
        match expr {
            BitNumToBit(e, i, signed) if !self.expr_store.get_is_plaintext(e) => {
                match self
                    .conversion_info
                    .bit_num_to_bit_to_expr_id
                    .get(&(e, i, signed))
                {
                    // This exact bit was produced by an earlier conversion of the same scalar.
                    Some(n) => self.expr_store.get_expr(*n).clone(),
                    None => {
                        // To convert a scalar x to binary, we:
                        // 1. Have an edaBit r, a random number whose unsigned decomposition can be
                        //    accessed.
                        // 2. We compute and reveal s = x + r.
                        // 3. We compute a binary decomposition of s.
                        // 4. We compute s - r in binary.
                        with_local_expr_store_as_global(
                            || {
                                let e_id = e;
                                let e = FieldValue::<F>::from_id(e_id);
                                let bounds = e.bounds();
                                let num_size = bounds.bin_size(signed);
                                if num_size == 0 {
                                    // A zero-bit expansion can only describe x == 0 (assuming
                                    // `bin_size` is the minimal width covering `bounds`), so
                                    // every bit of it, sign-extended or not, is false. This is
                                    // what the unsigned path yields for num_size == 0; returning
                                    // early skips a useless mask/reveal and keeps the
                                    // `num_size - 1` expressions below from underflowing.
                                    return BooleanValue::from(false).expr();
                                }
                                // Maybe we already computed this. Bits 0..num_size are always
                                // memoized together, so if the top bit is cached while bit `i`
                                // was not, then `i >= num_size`: the answer is the sign bit
                                // (signed) or false (unsigned).
                                if let Some(n) = self
                                    .conversion_info
                                    .bit_num_to_bit_to_expr_id
                                    .get(&(e_id, num_size - 1, signed))
                                {
                                    return if signed {
                                        BooleanValue::new(*n).expr()
                                    } else {
                                        BooleanValue::from(false).expr()
                                    };
                                }

                                // The mask gets `security_factor` more bits than x (presumably
                                // so the revealed sum statistically hides x), but never more
                                // than the field width.
                                let eda_bit_size = (num_size
                                    + self.conversion_info.security_factor)
                                    .min(F::NUM_BITS as usize);
                                // could_underflow would work here as a name too.
                                // If there is overflow when we add eda_bit,
                                // then there will be underflow when we subtract it.
                                let could_overflow = Number::power_of_two(eda_bit_size)
                                    + Number::power_of_two(num_size)
                                    >= F::modulus();
                                let (
                                    eda_bit_scalar,
                                    mut eda_bit_bits,
                                    eda_bit_minus_modulus_bits,
                                    eda_bit_plus_offset_bits,
                                    eda_bit_minus_modulus_plus_offset_bits,
                                ) = new_eda_bit::<F>(eda_bit_size, could_overflow);

                                let sum = e + eda_bit_scalar;
                                let mut revealed_sum = sum.reveal();
                                if could_overflow && signed {
                                    // If it could overflow and it is signed,
                                    // then we do an unsigned decomposition of x + (p-1)/2,
                                    // from which we can extract a signed decomposition of x.
                                    let offset = FieldValue::from(F::TWO_INV - F::ONE);
                                    revealed_sum += offset;
                                }

                                // With possible wrap-around the full field width is needed to
                                // detect it; otherwise num_size bits are enough.
                                let used_num_size = if could_overflow {
                                    F::NUM_BITS as usize
                                } else {
                                    num_size
                                };
                                eda_bit_bits = eda_bit_bits
                                    .into_iter()
                                    .take(used_num_size)
                                    .collect::<Vec<BooleanValue>>();

                                // The sum is public, so its bits are plaintext constants.
                                let mut revealed_sum_bits = (0..used_num_size)
                                    .map(|i| {
                                        revealed_sum.get_bit(
                                            i,
                                            // With signed decomposition, it is hard to handle
                                            // underflow.
                                            // So we only do signed decomposition when no
                                            // underflow.
                                            (!could_overflow) && signed,
                                        )
                                    })
                                    .collect::<Vec<BooleanValue>>();

                                let sub_bits = if could_overflow {
                                    // Widen both operands by one zero bit so the top bit of the
                                    // difference acts as a sign bit.
                                    revealed_sum_bits.push(BooleanValue::from(false));
                                    eda_bit_bits.push(BooleanValue::from(false));
                                    debug_assert!(
                                        eda_bit_bits.len() == eda_bit_minus_modulus_bits.len()
                                    );
                                    // sum - eda
                                    // if sub_bits1 is non-negative it is the expansion of
                                    //   x + (p-1)/2 mod p if signed
                                    //   x                 otherwise
                                    let mut sub_bits1 = subtraction_circuit(
                                        revealed_sum_bits.clone(),
                                        eda_bit_bits,
                                        CircuitType::default(),
                                    );
                                    let has_sum_overflowed = sub_bits1.pop().unwrap();
                                    if signed {
                                        // sum - (eda + (p-1)/2)
                                        let mut sub_bits2 = subtraction_circuit(
                                            revealed_sum_bits.clone(),
                                            eda_bit_plus_offset_bits,
                                            CircuitType::default(),
                                        );
                                        let _ = sub_bits2.pop();
                                        // sum - (eda - (p+1)/2)
                                        let mut sub_bits3 = subtraction_circuit(
                                            revealed_sum_bits,
                                            eda_bit_minus_modulus_plus_offset_bits,
                                            CircuitType::default(),
                                        );
                                        let _ = sub_bits3.pop();
                                        // The correct signed decomposition of x is either
                                        // sub_bits2 or sub_bits3 depending on whether sum has
                                        // overflowed or not.
                                        has_sum_overflowed.select(sub_bits3, sub_bits2)
                                    } else {
                                        // sum - (eda - p)
                                        let mut sub_bits2 = subtraction_circuit(
                                            revealed_sum_bits,
                                            eda_bit_minus_modulus_bits,
                                            CircuitType::default(),
                                        );
                                        let _ = sub_bits2.pop();
                                        // The correct unsigned decomposition of x is either
                                        // sub_bits1 or sub_bits2 depending on whether sum has
                                        // overflowed or not.
                                        has_sum_overflowed.select(sub_bits2, sub_bits1)
                                    }
                                } else {
                                    subtraction_circuit(
                                        revealed_sum_bits,
                                        eda_bit_bits,
                                        CircuitType::default(),
                                    )
                                };

                                // Memoize the whole expansion so later requests for other bits
                                // of the same scalar reuse it.
                                sub_bits[0..num_size]
                                    .iter()
                                    .enumerate()
                                    .for_each(|(i, bit)| {
                                        self.conversion_info
                                            .bit_num_to_bit_to_expr_id
                                            .insert((e_id, i, signed), bit.get_id());
                                    });
                                // Bits beyond the expansion: sign extension when signed,
                                // zero otherwise. num_size >= 1 is guaranteed above.
                                if signed {
                                    sub_bits[i.min(num_size - 1)].expr()
                                } else if i >= num_size {
                                    BooleanValue::from(false).expr()
                                } else {
                                    sub_bits[i].expr()
                                }
                            },
                            &mut self.expr_store,
                        )
                    }
                }
            }
            BitToBitNum(v, signed) => {
                let n = v.len();
                // If the result may be public and the bits fit in the field, the bits are
                // presumably determined by the result anyway, so revealing them leaks nothing
                // extra and avoids one eda bit per bit.
                let are_bits_revealed = is_plaintext && n < F::NUM_BITS as usize;
                with_local_expr_store_as_global(
                    || {
                        let mut res = FieldValue::<F>::from(0);
                        v.into_iter()
                            .map(BooleanValue::new)
                            .enumerate()
                            .for_each(|(i, bit)| {
                                let c = if signed && i + 1 == n {
                                    FieldValue::from(F::negative_power_of_two(i))
                                } else {
                                    FieldValue::from(F::power_of_two(i))
                                };
                                if bit.is_plaintext() {
                                    res += c * FieldValue::<F>::from(bit)
                                } else if are_bits_revealed {
                                    res += c * FieldValue::<F>::from(bit.reveal())
                                } else {
                                    // Mask the secret bit b with a random 1-bit eda bit r,
                                    // reveal d = b ^ r, then rebuild b = d ^ r arithmetically
                                    // as r + d - 2*d*r (d is public, r is secret).
                                    let (eda_bit_scalar, eda_bit_bits, _, _, _) =
                                        new_eda_bit::<F>(1, false);
                                    let eda_bit_bit = eda_bit_bits[0];
                                    let xor = bit ^ eda_bit_bit;
                                    let revealed = xor.reveal();
                                    let revealed_scalar = FieldValue::<F>::from(revealed);
                                    let prod = revealed_scalar * eda_bit_scalar;
                                    let scalar_xor = eda_bit_scalar + revealed_scalar
                                        - FieldValue::<F>::from(2) * prod;
                                    res += c * scalar_xor
                                }
                            });
                        res.expr()
                    },
                    &mut self.expr_store,
                )
            }
            _ => F::conversion_expr_to_expr(expr),
        }
    }
}

impl LocalCompilationPass for EdaBitIntroducer {
    /// Exposes the builder that collects the rewritten expressions.
    fn expr_store(&mut self) -> &mut IRBuilder {
        let Self { expr_store, .. } = self;
        expr_store
    }

    /// Setup hook: records the crate-wide `STATISTICAL_SECURITY_FACTOR` used to size masks.
    /// The previous IR is not inspected.
    fn setup(&mut self, _old_ir: &IntermediateRepresentation) {
        let info = &mut self.conversion_info;
        info.security_factor = STATISTICAL_SECURITY_FACTOR;
    }

    /// Routes base-field and scalar-field conversions through `conversion_eda_bits`; every
    /// other expression is passed through untouched.
    fn transform(&mut self, expr: Expr<usize>, is_plaintext: bool) -> Expr<usize> {
        match expr {
            Expr::BaseConversion(conv) => self.conversion_eda_bits(conv, is_plaintext),
            Expr::ScalarConversion(conv) => self.conversion_eda_bits(conv, is_plaintext),
            unchanged => unchanged,
        }
    }
}