// arcis-compiler 0.9.4
//
// A framework for writing secure multi-party computation (MPC) circuits to be executed on the Arcium network.
// Documentation
use crate::{
    core::{
        actually_used_field::ActuallyUsedField,
        bounds::FieldBounds,
        circuits::{
            arithmetic::{fast_divide::FastDivide, float_exp::Exp},
            boolean::utils::{neg_abs, sign_bit},
            traits::arithmetic_circuit::ArithmeticCircuit,
        },
        expressions::expr::EvalFailure,
        global_value::value::FieldValue,
    },
    traits::{GreaterEqual, Select},
    types::DOUBLE_PRECISION_MANTISSA,
    utils::{number::Number, used_field::UsedField},
};

/// Saturation threshold of the sigmoid input, as a fixed-point number with
/// `DOUBLE_PRECISION_MANTISSA` fractional bits.
///
/// Equals `ln(2^DOUBLE_PRECISION_MANTISSA) * 2^DOUBLE_PRECISION_MANTISSA + 13`, i.e. the natural
/// logarithm of `2^DOUBLE_PRECISION_MANTISSA` encoded in fixed point, plus a margin of 13 units.
/// This is the tightest cap we can impose while still avoiding a second capping inside exp2;
/// the `+13` margin was determined experimentally.
const LIMIT: usize = 162_326_183_972_299_341;

/// Computes the logistic sigmoid `sigmoid(x) = 1 / (1 + exp(-x))` on fixed-point values.
///
/// The input carries `precision_in` fractional bits. The output carries
/// `DOUBLE_PRECISION_MANTISSA` fractional bits, i.e. its lowest `DOUBLE_PRECISION_MANTISSA`
/// bits represent the fractional part.
#[derive(Clone, Debug)]
pub struct Sigmoid {
    /// Number of bits after the binary point of the input.
    precision_in: usize,
    /// Number of bits after the binary point of the output.
    precision_out: usize,
}

impl Sigmoid {
    /// Creates a sigmoid circuit whose input has `precision_in` fractional bits.
    ///
    /// The output precision is always `DOUBLE_PRECISION_MANTISSA`, so an output value of
    /// `2^DOUBLE_PRECISION_MANTISSA` represents 1.0.
    ///
    /// # Panics
    ///
    /// Panics if `precision_in > DOUBLE_PRECISION_MANTISSA`. Because this is a `const fn`,
    /// the panic becomes a compile-time error when `new` is evaluated in a const context.
    // NOTE(review): `#[allow(unused)]` presumably silences dead-code warnings in builds where
    // the constructor is only called from tests — confirm.
    #[allow(unused)]
    pub const fn new(precision_in: usize) -> Self {
        if precision_in > DOUBLE_PRECISION_MANTISSA {
            // The message hard-codes the value of DOUBLE_PRECISION_MANTISSA (52).
            panic!("precision_in must be at most 52");
        }
        Sigmoid {
            precision_in,
            precision_out: DOUBLE_PRECISION_MANTISSA,
        }
    }

    /// Builds the circuit value for the fixed-point sigmoid of `x`.
    ///
    /// `x` carries `self.precision_in` fractional bits. The result carries
    /// `self.precision_out` fractional bits and is bounded to `[0, 2^precision_out]`.
    ///
    /// Bound-based fast paths:
    /// - If the bounds of `x` lie entirely at or above `LIMIT` (rescaled to the input
    ///   precision), the constant 1.0 is returned.
    /// - If they lie entirely at or below `-LIMIT`, the constant 0.0 is returned.
    ///
    /// Otherwise only `exp(-|x|)` (which never exceeds 1) is evaluated. Negative inputs are
    /// handled through `sigmoid(x) = 1 - sigmoid(|x|)`.
    pub fn sigmoid<F: ActuallyUsedField>(&self, x: FieldValue<F>) -> FieldValue<F> {
        let bounds = x.bounds();
        // NOTE(review): the `true` flag presumably requests a signed interpretation — confirm.
        let (min, max) = bounds.min_and_max(true);
        // LIMIT has precision_out fractional bits; shift it to the input's precision.
        let limit = F::from((LIMIT >> (self.precision_out - self.precision_in)) as u64);
        // x >= limit everywhere: saturate to 1.0. The leading is_ge_zero() check presumably
        // prevents `min - limit` from wrapping around the field for very negative `min`.
        if min.is_ge_zero() && (min - limit).is_ge_zero() {
            FieldValue::from(F::power_of_two(self.precision_out))
        // x <= -limit everywhere: saturate to 0.0 (same wrap-around guard on `max + limit`).
        } else if max.is_lt_zero() && (max + limit).is_le_zero() {
            FieldValue::from(F::ZERO)
        } else {
            // we use the fact that sigmo(x) = sign + sigmo(|x|) - 2 * sign * sigmo(|x|)
            // (this identity requires `sign` to be 1 when x < 0 and 0 otherwise).
            let sign = sign_bit(x);
            // -|x| is never positive, so exp(-|x|) stays in (0, 1].
            let neg_abs_x = neg_abs(x);
            let (min_abs, max_abs) = neg_abs_x.bounds().min_and_max(true);
            // Clamp -|x| from below at -limit when its bounds allow smaller values, and
            // tighten the bounds to [-limit, max(-limit, max_abs)] accordingly.
            let neg_abs_x = if (min_abs + limit).is_ge_zero() {
                neg_abs_x
            } else {
                (neg_abs_x.lt(-limit))
                    .select(FieldValue::from(-limit), neg_abs_x)
                    .with_bounds((-limit, (-limit).max(max_abs, true)))
            };
            // NOTE(review): the result is added to `one` below, so it is presumably expressed
            // with precision_out fractional bits — confirm against Exp.
            let exp_neg_abs_x = Exp::new(self.precision_in).exp(neg_abs_x);
            // 1.0 in the output fixed-point format.
            let one = FieldValue::<F>::from(Number::power_of_two(self.precision_out));
            // Intended result: sigmo(|x|) = 1 / (1 + exp(-|x|)) in the output format.
            // NOTE(review): the divisor is halved here. The TODO shows the intended unhalved
            // call (precision precision_out + 1). How inv_approx's precision arguments
            // compensate for the halving is not visible from this file — confirm.
            let sigmo_abs_x = FastDivide::new(0, 0).inv_approx(
                // TODO: make the below work
                // one + exp_neg_abs_x,
                // self.precision_out + 1,
                (one + exp_neg_abs_x) >> 1,
                self.precision_out,
                self.precision_out,
            );
            // sign * 1 + s - sign * 2s, where the last term is selected rather than multiplied:
            // it equals s for x >= 0 and 1 - s for x < 0.
            (FieldValue::<F>::from(sign) * one + sigmo_abs_x
                - sign.select(2 * sigmo_abs_x, FieldValue::<F>::from(0)))
            .with_bounds(Self::sigmoid_bounds(self))
        }
    }

    /// Plaintext reference implementation of the sigmoid, evaluated in `f64`.
    ///
    /// `x` is read as a signed fixed-point number with `self.precision_in` fractional bits.
    /// The numerically stable form is chosen by the sign of `x`.
    // NOTE(review): the result is converted with `F::from(f64)` without explicit
    // `2^precision_out` scaling. This presumably relies on that conversion encoding with
    // DOUBLE_PRECISION_MANTISSA (== precision_out) fractional bits — confirm.
    fn sigmoid_public<F: UsedField>(&self, x: F) -> F {
        let x_float = f64::from(x.to_signed_number()) * 2f64.powi(-(self.precision_in as i32));
        let res_float = if x_float >= 0.0 {
            // sigmo(x) = 1/(1 + exp(-x))
            1.0 / (1.0 + (-x_float).exp())
        } else {
            // for numerical stability reasons, we compute the sigmoid as
            // sigmo(x) = exp(x)/(exp(x) + 1)
            let exp_x = x_float.exp();
            exp_x / (exp_x + 1.0)
        };
        F::from(res_float)
    }

    /// Bounds of every sigmoid output: `[0, 2^precision_out]`, i.e. `[0.0, 1.0]` in the
    /// output fixed-point format.
    fn sigmoid_bounds<F: UsedField>(&self) -> FieldBounds<F> {
        FieldBounds::new(F::ZERO, F::power_of_two(self.precision_out))
    }
}

impl<F: UsedField> ArithmeticCircuit<F> for Sigmoid {
    /// Evaluates the sigmoid on a plaintext input using the `f64` reference implementation.
    ///
    /// # Panics
    ///
    /// Panics unless `x` holds exactly one value.
    fn eval(&self, x: Vec<F>) -> Result<Vec<F>, EvalFailure> {
        match x.as_slice() {
            [input] => Ok(vec![self.sigmoid_public(*input)]),
            _ => panic!("Sigmoid requires one input"),
        }
    }

    /// Returns the constant 5, independently of the input.
    // NOTE(review): presumably the tolerated discrepancy (in output units) between `eval` and
    // the approximating circuit built by `run` — confirm against the trait definition.
    fn eval_gap(&self, _x: &[F]) -> F {
        F::from(5)
    }

    /// The output always lies in `[0, 2^precision_out]`, whatever the input bounds are.
    fn bounds(&self, _bounds: Vec<FieldBounds<F>>) -> Vec<FieldBounds<F>> {
        let output_bounds = self.sigmoid_bounds();
        vec![output_bounds]
    }

    /// Builds the sigmoid circuit for the single input value.
    ///
    /// # Panics
    ///
    /// Panics unless `vals` holds exactly one value.
    fn run(&self, vals: Vec<FieldValue<F>>) -> Vec<FieldValue<F>>
    where
        F: ActuallyUsedField,
    {
        match vals.as_slice() {
            [value] => vec![self.sigmoid(*value)],
            _ => panic!("Sigmoid requires one input"),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        core::circuits::traits::arithmetic_circuit::tests::TestedArithmeticCircuit,
        utils::field::ScalarField,
    };
    use rand::Rng;

    impl TestedArithmeticCircuit<ScalarField> for Sigmoid {
        /// Draws a random `Sigmoid`. The input precision starts at 52 and drops by one bit for
        /// every consecutive successful fair coin flip, so high precisions are the most likely.
        fn gen_desc<R: Rng + ?Sized>(rng: &mut R) -> Self {
            let dropped_bits = std::iter::repeat_with(|| rng.gen_bool(0.5))
                .take_while(|&heads| heads)
                .count();
            Self::new((52 - dropped_bits as i64) as usize)
        }

        /// Sigmoid is unary.
        fn gen_n_inputs<R: Rng + ?Sized>(&self, _rng: &mut R) -> usize {
            1
        }

        /// The single input uses the configured input precision.
        fn input_precisions(&self) -> Vec<usize> {
            let precision = self.precision_in;
            vec![precision]
        }
    }

    #[test]
    fn tested_sigmoid() {
        Sigmoid::test(1, 16)
    }
}