arcis-compiler 0.9.4

A framework for writing secure multi-party computation (MPC) circuits to be executed on the Arcium network.
Documentation
use crate::{
    core::{
        bounds::{below_power_of_two, BoolBounds, FieldBounds, IsBounds},
        expressions::{expr::EvalFailure, field_expr::InputId},
    },
    traits::FromLeBits,
    utils::{unique_id::UniqueId, used_field::UsedField},
};
use arcis_internal_expr_macro::Expr;
use serde::{Deserialize, Serialize};
use std::{cell::Cell, marker::PhantomData, ops::Add};

/// Identifier attached to an edaBit expression so that distinct edaBits can be told apart.
///
/// It is a thin newtype around [`UniqueId`]; every constructor goes through `UniqueId::new()`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct EdaBitId(UniqueId);

impl EdaBitId {
    /// Creates an identifier wrapping a newly generated [`UniqueId`]
    /// (presumably distinct on every call — see `UniqueId::new`).
    pub fn new() -> Self {
        let id = UniqueId::new();
        Self(id)
    }
}

impl Default for EdaBitId {
    /// Same as [`EdaBitId::new`]: a default identifier is a freshly generated one.
    fn default() -> Self {
        Self::new()
    }
}

/// Expressions for conversions between arithmetic circuits and boolean circuits.
///
/// The type parameters choose how operands are represented, so one enum can carry plaintext
/// values (`ConversionExpr<F, F, bool>`), bounds (`ConversionExpr<F, FieldBounds<F>, BoolBounds>`)
/// or one boolean per operand (`ConversionExpr<F, bool, bool>`):
/// - `F`: the field scalars live in.
/// - `T`: representation of scalar (arithmetic-circuit) operands.
/// - `B`: representation of bit operands used in boolean circuits; defaults to `T`.
/// - `E`: representation of edaBit operands; defaults to `T`.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, Expr)]
pub enum ConversionExpr<F: UsedField, T: Clone, B: Clone = T, E: Clone = T> {
    // T = Scalar
    // B = Bit, used in boolean circuits
    // E = EdaBit, an edaBit
    /// Converts a scalar to a bit by taking its i-th bit (i is the `usize` field), read from
    /// either the signed or the unsigned binary decomposition.
    /// The boolean is true iff signed decomposition is used.
    BitNumToBit(T, usize, bool),
    /// Converts a sequence of bits, least-significant bit first, into a scalar.
    /// The boolean is true iff signed decomposition is used; in that case the last (most
    /// significant) bit carries negative weight.
    BitToBitNum(Vec<B>, bool),
    /// An edaBit: a random integer whose unsigned decomposition has k bits, where k is the
    /// `usize` field (the bits may all be 0). Its bit decomposition can be accessed through
    /// `BitFromEdaBit`. It cannot be evaluated as a plaintext value.
    EdaBit(EdaBitId, usize, PhantomData<F>),
    /// A bit equal to the k-th bit (k is the `usize` field) of the unsigned decomposition of
    /// the edaBit.
    BitFromEdaBit(E, usize),
    /// A scalar equal to a plaintext bit (0 or 1).
    ScalarFromPlaintextBit(B),
    /// A scalar equal to the integer (modulo p) represented by the edaBit.
    ScalarFromEdaBit(E),
}

/// Plaintext result of evaluating a [`ConversionExpr`]: either a single bit or a field scalar.
///
/// The derived trait impls are conditional on `F` implementing the same traits, so they add
/// capabilities without constraining existing users.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConversionValue<F: UsedField> {
    /// Result of a bit-valued conversion (e.g. `BitNumToBit`, `BitFromEdaBit`).
    Bit(bool),
    /// Result of a scalar-valued conversion (e.g. `BitToBitNum`, `ScalarFromEdaBit`).
    Scalar(F),
}

impl<F: UsedField + FromLeBits<bool>> ConversionExpr<F, F, bool> {
    pub fn eval(self) -> Result<ConversionValue<F>, EvalFailure> {
        use ConversionExpr::*;
        use ConversionValue::*;
        let val = match self {
            BitNumToBit(e, c, b) => {
                if b {
                    Bit(e.signed_bit(c))
                } else {
                    Bit(e.unsigned_bit(c))
                }
            }
            BitToBitNum(v, signed) => Scalar(<F>::from_le_bits(v, signed)),
            EdaBit(_, _, _) => EvalFailure::err_imp("edaBit not evaluable here")?,
            BitFromEdaBit(e, k) => Bit(e.unsigned_bit(k)),
            ScalarFromEdaBit(e) => Scalar(e),
            ScalarFromPlaintextBit(b) => Scalar(F::from(b)),
        };
        Ok(val)
    }
}
impl<F: UsedField> ConversionExpr<F, bool, bool> {
    /// Returns `true` iff this is not an edaBit and every dependency flag is `true`.
    ///
    /// Here each operand is replaced by a boolean (presumably "this operand is plaintext" —
    /// confirm at call sites).
    pub fn is_plaintext(&self) -> bool {
        if let Self::EdaBit(..) = self {
            // An edaBit is never plaintext. It presumably has no operand dependencies, so the
            // `all` below would otherwise return `true` vacuously.
            return false;
        }
        self.get_deps().iter().all(|&dep| dep)
    }
}
impl<F: UsedField, T: Clone, B: Clone> ConversionExpr<F, T, B> {
    pub fn is_boolean(&self) -> bool {
        use ConversionExpr::*;
        match self {
            BitNumToBit(_, _, _) => true,
            BitToBitNum(_, _) => false,
            EdaBit(_, _, _) => false,
            BitFromEdaBit(_, _) => true,
            ScalarFromPlaintextBit(_) => false,
            ScalarFromEdaBit(_) => false,
        }
    }
    pub fn is_eval_deterministic_fn_from_deps(&self) -> bool {
        !matches!(self, ConversionExpr::EdaBit(..))
    }
    pub fn get_input(&self) -> Option<InputId> {
        None
    }
    pub fn get_input_name(&self) -> &str {
        ""
    }
    pub fn get_is_input_already_optimized_out(&self) -> Option<&Cell<bool>> {
        None
    }
}

/// Bounds on a value handled by a conversion expression: either a boolean-circuit bit or a
/// field scalar. Produced by `ConversionExpr::bounds` and by the `From<bool>` / `From<F>` impls.
pub enum ConversionBounds<F: UsedField> {
    /// Bounds on a bit-valued result.
    Bit(BoolBounds),
    /// Bounds on a scalar-valued result.
    Scalar(FieldBounds<F>),
}

/// Bit bounds that do not pin the bit down: it may be `false` or `true`.
fn all_bit_bools<F: UsedField>() -> ConversionBounds<F> {
    let either_value = BoolBounds::new(true, true);
    ConversionBounds::Bit(either_value)
}

/// Scalar bounds spanning `F::ZERO` to `F::ONE`, i.e. a bit lifted into the field whose value
/// is unknown.
fn all_scalar_bools<F: UsedField>() -> ConversionBounds<F> {
    let zero_to_one = FieldBounds::new(F::ZERO, F::ONE);
    ConversionBounds::Scalar(zero_to_one)
}

/// Bounds for a constant bit, built via `BoolBounds::from`.
impl<F: UsedField> From<bool> for ConversionBounds<F> {
    fn from(bit: bool) -> Self {
        Self::Bit(BoolBounds::from(bit))
    }
}

/// Bounds for a constant scalar, built via `FieldBounds::from`.
impl<F: UsedField> From<F> for ConversionBounds<F> {
    fn from(value: F) -> Self {
        Self::Scalar(FieldBounds::from(value))
    }
}

impl<F: UsedField> ConversionExpr<F, FieldBounds<F>, BoolBounds> {
    pub fn bounds(self) -> ConversionBounds<F> {
        use ConversionBounds::*;
        use ConversionExpr::*;
        match self {
            BitNumToBit(b, idx, signed) => {
                let (min, max) = b.min_and_max(signed);
                if max - min >= F::power_of_two(idx) {
                    all_bit_bools()
                } else {
                    let (b1, b2) = b.bits_in_pos(idx, signed);
                    if b1 != b2 {
                        all_bit_bools()
                    } else {
                        b1.into()
                    }
                }
            }
            BitToBitNum(v, signed) => {
                let n = v.len();
                Scalar(
                    v.into_iter()
                        .enumerate()
                        .map(|(i, bounds)| {
                            let is_negative = signed && i + 1 == n;
                            bounds.multiply_by_power_of_two(i, is_negative)
                        })
                        .fold(FieldBounds::from(F::ZERO), Add::add),
                )
            }
            EdaBit(_, u, _) => Scalar(below_power_of_two(u)),
            BitFromEdaBit(b, k) => {
                if b.unsigned_max() >= F::power_of_two(k) {
                    all_bit_bools()
                } else {
                    Bit(false.into())
                }
            }
            ScalarFromPlaintextBit(b) => {
                if let Some(v) = b.as_constant() {
                    ConversionBounds::from(F::from(v))
                } else {
                    all_scalar_bools()
                }
            }
            ScalarFromEdaBit(b) => Scalar(b),
        }
    }
}