arcis-compiler 0.9.7

A framework for writing secure multi-party computation (MPC) circuits to be executed on the Arcium network.
Documentation
use crate::{
    core::{
        actually_used_field::ActuallyUsedField,
        bounds::FieldBounds,
        circuits::{
            boolean::{
                boolean_value::BooleanValue,
                utils::{icpot_signed, shift_right, CircuitType},
            },
            traits::arithmetic_circuit::ArithmeticCircuit,
        },
        expressions::expr::EvalFailure,
        global_value::value::FieldValue,
    },
    traits::Select,
    types::DOUBLE_PRECISION_MANTISSA,
    utils::{number::Number, used_field::UsedField},
};

/// ln(2) scaled by 2^DOUBLE_PRECISION_MANTISSA and stored as an integer. With a 52-bit mantissa
/// this is a 52-bit number, i.e. it lies in [2^51, 2^52).
const LN_2: u64 = 3_121_657_384_082_679;
/// Extra fractional bits carried through the Chebyshev evaluation to reduce the rounding noise of
/// the intermediate fixed-point products.
const PRECISION_MARGIN: usize = 4;

/// The first 21 Chebyshev coefficients a_0, ..., a_20 of f(z) = log2((z + 3) / 2) for z in
/// [-1, 1], used as f(z) ≈ a_0 * T_0(z) + a_1 * T_1(z) + ... + a_20 * T_20(z) (a_0 is the full
/// constant term, not halved).
///
/// With z1 = -3 + 2 * sqrt(2) and z1_tilde = -3 - 2 * sqrt(2) (the roots of z^2 + 6z + 1), the
/// coefficients are computed as a_0 = log2(-z1_tilde) - 2 and a_i = -2 * z1^i / (i * ln(2)) for
/// i > 0. Each coefficient is then multiplied by
/// 2^(DOUBLE_PRECISION_MANTISSA + PRECISION_MARGIN) and rounded to an integer.
const CHEBYSHEV_COEFFS: [i64; 21] = [
    39_134_955_358_043_840,
    35_672_448_621_869_204,
    -3_060_212_288_698_950,
    350_032_947_506_080,
    -45_042_119_427_888,
    6_182_404_750_215,
    -883_944_132_481,
    129_995_002_606,
    -19_515_664_320,
    2_976_318_791,
    -459_590_015,
    71_684_709,
    -11_274_222,
    1_785_555,
    -284_470,
    45_554,
    -7_327,
    1_183,
    -192,
    31,
    -5,
];

/// Fixed-point base-2 logarithm circuit: computes log2(x).
///
/// The input carries `precision_in` fractional bits. The result carries `precision_out`
/// fractional bits, so its lowest DOUBLE_PRECISION_MANTISSA bits are the fractional part
/// (`Log2::new` sets `precision_out` to DOUBLE_PRECISION_MANTISSA).
#[derive(Clone, Debug)]
pub struct Log2 {
    /// Count of bits after the binary point in the input.
    precision_in: usize,
    /// Count of bits after the binary point in the output.
    precision_out: usize,
}

impl Log2 {
    /// Creates a `Log2` circuit whose input has `precision_in` fractional bits.
    ///
    /// The output always has `DOUBLE_PRECISION_MANTISSA` fractional bits.
    ///
    /// # Panics
    ///
    /// Panics if `precision_in > DOUBLE_PRECISION_MANTISSA`.
    #[allow(unused)]
    pub const fn new(precision_in: usize) -> Self {
        if precision_in > DOUBLE_PRECISION_MANTISSA {
            panic!("precision_in must be at most 52");
        }
        Log2 {
            precision_in,
            precision_out: DOUBLE_PRECISION_MANTISSA,
        }
    }

    /// Given an input x, this function computes floor_log2_x = floor(log2(abs(x))), icpot =
    /// 2^-floor_log2_x and is_non_pos = x <= 0. If is_non_pos = true then both floor_log2_x and
    /// icpot are 0.
    ///
    /// `floor_log2_x` is returned in fixed point with `precision_out` fractional bits. The input's
    /// own fixed-point scaling is removed by subtracting `precision_in` from each bit position.
    fn init_log2<F: ActuallyUsedField>(
        &self,
        x: FieldValue<F>,
    ) -> (FieldValue<F>, FieldValue<F>, BooleanValue) {
        let (icpot, icpot_bits, is_non_pos) = icpot_signed(x, CircuitType::default());
        // after `rev()`, the bit at enumerate index i contributes (i - precision_in); presumably
        // at most one bit of icpot_bits is set — TODO confirm against icpot_signed
        let floor_log2_x = FieldValue::from(F::power_of_two(self.precision_out))
            * icpot_bits.into_iter().rev().enumerate().fold(
                FieldValue::<F>::from(0),
                |acc, (i, bit)| {
                    acc + bit.select(
                        FieldValue::from(F::from(i as i32 - self.precision_in as i32)),
                        FieldValue::<F>::from(0),
                    )
                },
            );

        (floor_log2_x, icpot, is_non_pos)
    }

    /// Approximates log2 of a normalized value with the Chebyshev series `CHEBYSHEV_COEFFS`.
    ///
    /// `x_normalized` is expected in `[2^precision_out, 2^(precision_out + 1) - 1]`, i.e. it
    /// encodes a value in `[1, 2)` with `precision_out` fractional bits. The result has
    /// `precision_out` fractional bits.
    fn log2_approx<F: ActuallyUsedField>(&self, x_normalized: FieldValue<F>) -> FieldValue<F> {
        // the polynomials are kept with precision_out + PRECISION_MARGIN fractional bits
        let one =
            FieldValue::<F>::from(Number::power_of_two(self.precision_out + PRECISION_MARGIN));
        // affine change of variables x -> z = 2 * x - 3, mapping [1, 2) onto [-1, 1), expressed
        // in the same fixed-point scale as `one`
        let z = FieldValue::from(F::power_of_two(1 + PRECISION_MARGIN)) * x_normalized - 3 * one;
        // three-term recurrence T_0 = 1, T_1 = z, T_i = 2 * z * T_{i-1} - T_{i-2}; exactly one
        // polynomial is built per coefficient so the loop bound follows the coefficient table
        let mut chebyshev_polynomials = vec![one, z];
        for i in 2..CHEBYSHEV_COEFFS.len() {
            let last = chebyshev_polynomials[i - 1];
            let second_last = chebyshev_polynomials[i - 2];
            chebyshev_polynomials.push(
                2 * shift_right(z * last, self.precision_out + PRECISION_MARGIN, true)
                    - second_last,
            );
        }

        // coefficients and polynomials each carry precision_out + PRECISION_MARGIN fractional
        // bits, so each product carries 2 * (precision_out + PRECISION_MARGIN); shifting by
        // precision_out + 2 * PRECISION_MARGIN leaves precision_out fractional bits
        CHEBYSHEV_COEFFS
            .into_iter()
            .zip(chebyshev_polynomials)
            .fold(FieldValue::<F>::from(0), |acc, (c, p)| {
                acc + FieldValue::from(F::from(Number::from(c))) * p
            })
            >> (self.precision_out + 2 * PRECISION_MARGIN)
    }

    /// Computes log2(x) in the circuit.
    ///
    /// `x` carries `precision_in` fractional bits; the result carries `precision_out` fractional
    /// bits. Non-positive inputs yield 0. The result is tagged with the bounds from
    /// `log2_bounds`.
    pub fn log2<F: ActuallyUsedField>(&self, x: FieldValue<F>) -> FieldValue<F> {
        let bounds = x.bounds();
        if bounds.signed_max().is_le_zero() {
            FieldValue::<F>::from(0)
        } else {
            // let floor_log2_x = floor(log2(abs(x))) and icpot = 2^-floor_log2_x,
            // so that in case x > 0, log2(x) = floor_log2_x + log2(x * icpot)
            // note that 1 <= x * icpot < 2 and hence 0 <= log2(x * icpot) < 1
            // all that remains is to accurately compute log2(x * icpot)
            let (floor_log2_x, icpot, is_non_pos) = self.init_log2(x);

            let x_icpot = x * icpot;
            // rescale x * icpot so that it has exactly precision_out fractional bits
            let offset = bounds.bin_size(true) as i64 - self.precision_out as i64 - 2;

            let x_normalized = if offset > 0 {
                x_icpot >> (offset as usize)
            } else {
                FieldValue::from(F::power_of_two((-offset) as usize)) * x_icpot
            };
            // x_normalized is in the interval [2^precision_out, 2^(precision_out + 1) - 1],
            // except if x = 0, in which case x_normalized = 0
            let x_normalized_bounds = FieldBounds::new(
                if bounds.signed_min().is_le_zero() {
                    F::ZERO
                } else {
                    F::power_of_two(self.precision_out)
                },
                F::power_of_two(self.precision_out + 1) - F::ONE,
            );
            let log2_x_normalized = self.log2_approx(x_normalized.with_bounds(x_normalized_bounds));

            is_non_pos
                .select(FieldValue::from(F::ZERO), floor_log2_x + log2_x_normalized)
                .with_bounds(self.log2_bounds(bounds))
        }
    }

    /// Plaintext reference: interprets `x` as a signed fixed-point number with `precision_in`
    /// fractional bits and returns its `f64` log2 converted back with `F::from`. Non-positive
    /// inputs yield 0.
    fn log2_public<F: UsedField>(&self, x: F) -> F {
        let x_signed = x.to_signed_number();
        if x_signed > 0 {
            let x_float = f64::from(x_signed) * 2f64.powi(-(self.precision_in as i32));
            F::from(x_float.log2())
        } else {
            F::ZERO
        }
    }

    /// Output bounds of `log2` for an input within `bounds`, widened by `eval_gap` to absorb
    /// approximation error.
    fn log2_bounds<F: UsedField>(&self, bounds: FieldBounds<F>) -> FieldBounds<F> {
        let (min, max) = bounds.min_and_max(true);
        if max.is_le_zero() {
            FieldBounds::new(F::ZERO, F::ZERO)
        } else if min.is_gt_zero() {
            FieldBounds::new(
                self.log2_public(min) - self.eval_gap(&[min]),
                self.log2_public(max) + self.eval_gap(&[max]),
            )
        } else {
            // The input may be non-positive (output 0) or as small as the positive value
            // F::ONE. Apply the same eval_gap margin as the strictly positive branch, so that
            // widening the input range from [1, max] to [min <= 0, max] never tightens the
            // lower bound.
            FieldBounds::new(
                self.log2_public(F::ONE) - self.eval_gap(&[F::ONE]),
                (self.log2_public(max) + self.eval_gap(&[max])).max(F::ZERO, true),
            )
        }
    }
}

impl<F: UsedField> ArithmeticCircuit<F> for Log2 {
    /// Plaintext evaluation: returns `[log2(x)]` for the single input `x`.
    ///
    /// # Panics
    ///
    /// Panics unless exactly one input is supplied.
    fn eval(&self, x: Vec<F>) -> Result<Vec<F>, EvalFailure> {
        match x.as_slice() {
            [input] => Ok(vec![self.log2_public(*input)]),
            _ => panic!("Log2 requires one input"),
        }
    }

    /// Allowed deviation of the circuit output from `eval`: a relative margin of 2^-48 of the
    /// exact result, floored at 4 raw units.
    fn eval_gap(&self, x: &[F]) -> F {
        // near x = 1 the result approaches 0, so a purely relative margin would vanish there;
        // hence the floor of 4
        let magnitude = self.log2_public(x[0]).abs();
        let relative_gap = magnitude >> 48;
        relative_gap.max(F::from(4), true)
    }

    /// Maps the bounds of the single input to the bounds of the output.
    fn bounds(&self, bounds: Vec<FieldBounds<F>>) -> Vec<FieldBounds<F>> {
        let input_bounds = bounds[0];
        vec![self.log2_bounds(input_bounds)]
    }

    /// Builds the circuit for the single input value.
    ///
    /// # Panics
    ///
    /// Panics unless exactly one value is supplied.
    fn run(&self, vals: Vec<FieldValue<F>>) -> Vec<FieldValue<F>>
    where
        F: ActuallyUsedField,
    {
        match vals.as_slice() {
            [val] => vec![self.log2(*val)],
            _ => panic!("Log2 requires one input"),
        }
    }
}

/// Fixed-point natural logarithm circuit: computes ln(x).
///
/// The input carries `precision_in` fractional bits. The result carries `precision_out`
/// fractional bits, so its lowest DOUBLE_PRECISION_MANTISSA bits are the fractional part
/// (`Ln::new` sets `precision_out` to DOUBLE_PRECISION_MANTISSA).
#[derive(Clone, Debug)]
pub struct Ln {
    /// Count of bits after the binary point in the input.
    precision_in: usize,
    /// Count of bits after the binary point in the output.
    precision_out: usize,
}

impl Ln {
    /// Creates an `Ln` circuit whose input has `precision_in` fractional bits.
    ///
    /// The output always has `DOUBLE_PRECISION_MANTISSA` fractional bits.
    ///
    /// # Panics
    ///
    /// Panics if `precision_in > DOUBLE_PRECISION_MANTISSA`.
    #[allow(unused)]
    pub const fn new(precision_in: usize) -> Self {
        if precision_in > DOUBLE_PRECISION_MANTISSA {
            panic!("precision_in must be at most 52");
        }
        Ln {
            precision_in,
            precision_out: DOUBLE_PRECISION_MANTISSA,
        }
    }

    /// Computes ln(x) in the circuit as ln(2) * log2(x), reusing `Log2` with the same input
    /// precision.
    ///
    /// `x` carries `precision_in` fractional bits; the result carries `precision_out` fractional
    /// bits. The result is tagged with the bounds from `ln_bounds`.
    pub fn ln<F: ActuallyUsedField>(&self, x: FieldValue<F>) -> FieldValue<F> {
        let bounds = x.bounds();
        shift_right(
            // the precision of LN_2 and of log2(x) is precision_out
            // we want the precision of LN_2 * log2(x) to be precision_out
            FieldValue::from(F::from(LN_2)) * Log2::new(self.precision_in).log2(x),
            self.precision_out,
            true,
        )
        .with_bounds(self.ln_bounds(bounds))
    }

    /// Plaintext reference: interprets `x` as a signed fixed-point number with `precision_in`
    /// fractional bits and returns its `f64` natural log converted back with `F::from`.
    /// Non-positive inputs yield 0.
    fn ln_public<F: UsedField>(&self, x: F) -> F {
        let x_signed = x.to_signed_number();
        if x_signed > 0 {
            let x_float = f64::from(x_signed) * 2f64.powi(-(self.precision_in as i32));
            F::from(x_float.ln())
        } else {
            F::ZERO
        }
    }

    /// Output bounds of `ln` for an input within `bounds`, widened by `eval_gap` to absorb
    /// approximation error.
    fn ln_bounds<F: UsedField>(&self, bounds: FieldBounds<F>) -> FieldBounds<F> {
        let (min, max) = bounds.min_and_max(true);
        if max.is_le_zero() {
            FieldBounds::new(F::ZERO, F::ZERO)
        } else if min.is_gt_zero() {
            FieldBounds::new(
                self.ln_public(min) - self.eval_gap(&[min]),
                self.ln_public(max) + self.eval_gap(&[max]),
            )
        } else {
            // The input may be non-positive (output 0) or as small as the positive value
            // F::ONE. Apply the same eval_gap margin as the strictly positive branch: the
            // f64-derived ln_public(F::ONE) is not exact, and widening the input range must
            // never tighten the lower bound.
            FieldBounds::new(
                self.ln_public(F::ONE) - self.eval_gap(&[F::ONE]),
                (self.ln_public(max) + self.eval_gap(&[max])).max(F::ZERO, true),
            )
        }
    }
}

impl<F: UsedField> ArithmeticCircuit<F> for Ln {
    /// Plaintext evaluation: returns `[ln(x)]` for the single input `x`.
    ///
    /// # Panics
    ///
    /// Panics unless exactly one input is supplied.
    fn eval(&self, x: Vec<F>) -> Result<Vec<F>, EvalFailure> {
        match x.as_slice() {
            [input] => Ok(vec![self.ln_public(*input)]),
            _ => panic!("Ln requires one input"),
        }
    }

    /// Allowed deviation of the circuit output from `eval`: a relative margin of 2^-48 of the
    /// exact result, floored at 4 raw units.
    fn eval_gap(&self, x: &[F]) -> F {
        // near x = 1 the result approaches 0, so a purely relative margin would vanish there;
        // hence the floor of 4
        let magnitude = self.ln_public(x[0]).abs();
        let relative_gap = magnitude >> 48;
        relative_gap.max(F::from(4), true)
    }

    /// Maps the bounds of the single input to the bounds of the output.
    fn bounds(&self, bounds: Vec<FieldBounds<F>>) -> Vec<FieldBounds<F>> {
        let input_bounds = bounds[0];
        vec![self.ln_bounds(input_bounds)]
    }

    /// Builds the circuit for the single input value.
    ///
    /// # Panics
    ///
    /// Panics unless exactly one value is supplied.
    fn run(&self, vals: Vec<FieldValue<F>>) -> Vec<FieldValue<F>>
    where
        F: ActuallyUsedField,
    {
        match vals.as_slice() {
            [val] => vec![self.ln(*val)],
            _ => panic!("Ln requires one input"),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        core::circuits::traits::arithmetic_circuit::tests::TestedArithmeticCircuit,
        utils::field::ScalarField,
    };
    use rand::Rng;

    /// Draws a random input precision for a generated circuit description.
    ///
    /// Starts at DOUBLE_PRECISION_MANTISSA, the largest precision `new` accepts, and decrements
    /// once per successful fair coin flip. Lower precisions are therefore exponentially less
    /// likely. The `precision > 0` guard keeps the value from ever going below 0.
    fn gen_precision<R: Rng + ?Sized>(rng: &mut R) -> usize {
        let mut precision = DOUBLE_PRECISION_MANTISSA;
        while precision > 0 && rng.gen_bool(0.5) {
            precision -= 1;
        }
        precision
    }

    impl TestedArithmeticCircuit<ScalarField> for Log2 {
        fn gen_desc<R: Rng + ?Sized>(rng: &mut R) -> Self {
            Self::new(gen_precision(rng))
        }

        fn gen_n_inputs<R: Rng + ?Sized>(&self, _rng: &mut R) -> usize {
            1
        }

        fn input_precisions(&self) -> Vec<usize> {
            vec![self.precision_in]
        }
    }

    impl TestedArithmeticCircuit<ScalarField> for Ln {
        fn gen_desc<R: Rng + ?Sized>(rng: &mut R) -> Self {
            Self::new(gen_precision(rng))
        }

        fn gen_n_inputs<R: Rng + ?Sized>(&self, _rng: &mut R) -> usize {
            1
        }

        fn input_precisions(&self) -> Vec<usize> {
            vec![self.precision_in]
        }
    }

    #[test]
    fn tested_log2() {
        Log2::test(16, 4)
    }

    #[test]
    fn tested_ln() {
        Ln::test(16, 4)
    }
}