arcis-compiler 0.9.4

A framework for writing secure multi-party computation (MPC) circuits to be executed on the Arcium network.
Documentation
use crate::{
    core::{
        actually_used_field::ActuallyUsedField,
        bounds::{FieldBounds, IsBounds},
        circuits::{
            boolean::{
                boolean_value::BooleanValue,
                utils::{cpot_circuit, shift_right, CircuitType},
            },
            traits::arithmetic_circuit::ArithmeticCircuit,
        },
        expressions::expr::EvalFailure,
        global_value::value::FieldValue,
    },
    traits::{FromLeBits, GetBit, Select},
    types::DOUBLE_PRECISION_MANTISSA,
    utils::{number::Number, used_field::UsedField},
};

/// Extra fractional bits carried through the Chebyshev evaluation of the inverse (and baked
/// into `CHEBYSHEV_COEFFS`) so that intermediate rounding noise stays below the output precision.
const PRECISION_MARGIN: usize = 4;

// This is the array of the first 23 Chebyshev coefficients for the function 1 / ((z + 3) / 2)
// with z in [-1, 1] (i.e. 1 / x for x = (z + 3) / 2 in [1, 2]). If z1 = -3 + 2 * sqrt(2), the
// coefficients are computed as:
// a_0 = 1 / sqrt(2), and a_i = sqrt(2) * z1^i for i > 0.
// Since z1 ~ -0.1716, each coefficient is about -0.17 times the previous one, so the series
// converges quickly; by the last entry the scaled coefficient has shrunk to 1.
// For the below CHEBYSHEV_COEFFS, we multiply the coefficients by 2^(DOUBLE_PRECISION_MANTISSA +
// PRECISION_MARGIN) and convert them to integers.
// NOTE(review): the tail entries appear to be truncated toward zero rather than rounded to
// nearest (e.g. sqrt(2) * z1^20 * 2^56 ~ 49.79 is stored as 49, assuming
// DOUBLE_PRECISION_MANTISSA = 52). The difference is at most one unit in the last place —
// confirm how this table was generated before regenerating it.
const CHEBYSHEV_COEFFS: [i64; 23] = [
    50952413380206176,
    -17484104129525320,
    2999798016739667,
    -514683970912700,
    88305808736540,
    -15150881506541,
    2599480302707,
    -446000309701,
    76521555499,
    -13129023295,
    2252584276,
    -386482361,
    66309889,
    -11376978,
    1951980,
    -334906,
    57460,
    -9858,
    1691,
    -290,
    49,
    -8,
    1,
];

/// Fixed-point division circuit computing `a / b`.
///
/// Inputs and output are fixed-point values: each field below records how many of the lowest
/// bits form the fractional part. The constructor fixes the output to
/// `DOUBLE_PRECISION_MANTISSA` fractional bits.
#[derive(Debug, Clone)]
pub struct Div {
    /// Fractional bits of the dividend `a`.
    precision_a: usize,
    /// Fractional bits of the divisor `b`.
    precision_b: usize,
    /// Fractional bits of the quotient.
    precision_out: usize,
}

impl Div {
    pub const fn new(precision_a: usize, precision_b: usize) -> Self {
        if precision_a > DOUBLE_PRECISION_MANTISSA || precision_b > DOUBLE_PRECISION_MANTISSA {
            panic!("input precision must be at most 52",);
        }
        Div {
            precision_a,
            precision_b,
            precision_out: DOUBLE_PRECISION_MANTISSA,
        }
    }

    /// Given an input b, this function computes signed_icpot = (-1)^sign(b) *
    /// 2^-floor(log2(abs(b))) and is_zero = b == 0. Note that if b is a negative power of 2
    /// (i.e., b = -2^e) then the circuit returns -2^(-e + 1) and not -2^(-e). The circuit would
    /// be much slower if we wanted to return the latter, while the division circuit still runs
    /// correctly with the former (b_normalized = 2 for a negative power of 2, while
    /// b_normalized = 1 for a positive power of 2). If is_zero = true then signed_icpot is 0.
    /// Note: the precision of signed_icpot is precision_b.
    fn init_inv<F: ActuallyUsedField>(
        b: FieldValue<F>,
    ) -> (FieldValue<F>, BooleanValue, BooleanValue) {
        let b_bounds = b.bounds();
        let circuit_size = b_bounds.signed_bin_size();
        let bits = (0..circuit_size)
            .map(|i| b.get_bit(i, true))
            .collect::<Vec<BooleanValue>>();
        let sign = *bits.last().unwrap();
        // xored corresponds to b if b >= 0 and to -b - 1 if b < 0
        let xored = bits
            .iter()
            .map(|bit| *bit ^ sign)
            .collect::<Vec<BooleanValue>>();
        let (mut icpot_bits, is_zero) = cpot_circuit(
            xored.into_iter().rev().collect::<Vec<BooleanValue>>(),
            CircuitType::default(),
        );
        // in case b = 1 icpot's leading bit would be 1
        icpot_bits.push(BooleanValue::from(false));
        if b_bounds == FieldBounds::All {
            // since we added a bit we also need to sacrifice one
            // we can still correctly divide by numbers of the order F::modulus() / 2
            icpot_bits.remove(0);
        }
        let neg_icpot_bits = icpot_bits
            .iter()
            .scan(BooleanValue::from(false), |acc, bit| {
                *acc ^= *bit;
                Some(*acc)
            })
            .collect::<Vec<BooleanValue>>();
        let signed_icpot_bits = sign.select(neg_icpot_bits, icpot_bits);
        let signed_icpot = FieldValue::<F>::from_le_bits(signed_icpot_bits, true);

        // the above circuit cannot handle the case where b = -1 (because xored = 0)
        // that's why we return the corresponding bit and handle it in the div function
        let is_neg_one = sign & is_zero;

        (signed_icpot, is_zero, is_neg_one)
    }

    fn inv_approx<F: ActuallyUsedField>(&self, b_normalized: FieldValue<F>) -> FieldValue<F> {
        // b_normalized is in the interval [2^precision_out, 2^(precision_out + 1) - 1]
        let one =
            FieldValue::<F>::from(Number::power_of_two(self.precision_out + PRECISION_MARGIN));
        // affine change of variables b_normalized -> z in [-2^(precision_out + PRECISION_MARGIN),
        // 2^(precision_out + PRECISION_MARGIN) - 1]
        let z = FieldValue::from(F::power_of_two(1 + PRECISION_MARGIN)) * b_normalized - 3 * one;
        let mut chebyshev_polynomials = vec![one, z];
        for i in 2..23 {
            let last = chebyshev_polynomials[i - 1];
            let second_last = chebyshev_polynomials[i - 2];
            chebyshev_polynomials.push(
                2 * shift_right(z * last, self.precision_out + PRECISION_MARGIN, true)
                    - second_last,
            );
        }

        CHEBYSHEV_COEFFS
            .into_iter()
            .zip(chebyshev_polynomials)
            .fold(FieldValue::<F>::from(0), |acc, (c, p)| {
                acc + FieldValue::from(F::from(Number::from(c))) * p
            })
            >> (self.precision_out + 2 * PRECISION_MARGIN)
    }

    pub fn div<F: ActuallyUsedField>(&self, a: FieldValue<F>, b: FieldValue<F>) -> FieldValue<F> {
        let a_bounds = a.bounds();
        let b_bounds = b.bounds();
        if b_bounds.signed_max().eq(&F::ZERO) && b_bounds.signed_min().eq(&F::ZERO) {
            FieldValue::<F>::from(0)
        } else {
            // let icpot = (-1)^sign(b) * 2^-floor(log2(abs(b)))
            // (there is the exceptional case of a negative power of 2, see
            // the function doc for init_inv, and in particular b = -1)
            // in case b != 0 we have 1 / b = icpot * inv(b * icpot)
            // note that 1 <= b * icpot <= 2 and hence 1/2 <= inv(b * icpot) <= 1
            // all that remains is to accurately compute inv(b * icpot)
            let (icpot, _, is_neg_one) = Self::init_inv(b);
            let b_icpot = b * icpot;
            let offset_icpot = b_bounds.bin_size(true) as i32
                - self.precision_out as i32
                - 1
                - ((b_bounds == FieldBounds::All) as i32);
            let b_normalized = if offset_icpot > 0 {
                b_icpot >> (offset_icpot as usize)
            } else {
                FieldValue::from(F::power_of_two((-offset_icpot) as usize)) * b_icpot
            };
            // b_normalized is in the interval [2^precision_out, 2^(precision_out + 1)],
            // except if b = 0, in which case b_normalized = 0
            let b_normalized_bounds = FieldBounds::new(
                if b_bounds.signed_min().is_le_zero() {
                    F::ZERO
                } else {
                    F::power_of_two(self.precision_out)
                },
                F::power_of_two(self.precision_out + 1),
            );
            let inv_b_normalized = self.inv_approx(b_normalized.with_bounds(b_normalized_bounds));

            let offset_a_icpot = offset_icpot + self.precision_a as i32 - self.precision_b as i32;
            let a_normalized = if offset_a_icpot > 0 {
                if b_bounds == FieldBounds::All {
                    // this is to handle cases where b_bounds = FieldBounds::All while b being
                    // small
                    a * shift_right(icpot, offset_a_icpot as usize, true)
                } else {
                    shift_right(a * icpot, offset_a_icpot as usize, true)
                }
            } else {
                FieldValue::from(F::power_of_two((-offset_a_icpot) as usize)) * a * icpot
            };

            let res = if a_normalized.bounds().bin_size(true) + self.precision_out
                >= F::NUM_BITS as usize
            {
                shift_right(a_normalized, self.precision_out, true) * inv_b_normalized
            } else {
                shift_right(a_normalized * inv_b_normalized, self.precision_out, true)
            };

            // in case b = -1 we have icpot = 0 (i.e., res = 0) and we can correct res by hand
            (res + is_neg_one.select(
                FieldValue::from(F::negative_power_of_two(
                    self.precision_out + self.precision_b - self.precision_a,
                )) * a,
                FieldValue::<F>::from(0),
            ))
            .with_bounds(self.div_bounds(a_bounds, b_bounds))
        }
    }

    fn div_public<F: UsedField>(&self, a: F, b: F) -> F {
        let b_signed = b.to_signed_number();
        if b_signed != 0 {
            let a_signed = a.to_signed_number();
            let a_float = f64::from(a_signed) * 2f64.powi(-(self.precision_a as i32));
            let b_float = f64::from(b_signed) * 2f64.powi(-(self.precision_b as i32));
            F::from(a_float / b_float)
        } else {
            F::ZERO
        }
    }

    fn div_bounds<F: UsedField>(
        &self,
        a_bounds: FieldBounds<F>,
        b_bounds: FieldBounds<F>,
    ) -> FieldBounds<F> {
        if a_bounds.bin_size(true) + self.precision_out + self.precision_b - self.precision_a
            >= F::NUM_BITS as usize
            || b_bounds.bin_size(true) + self.precision_out > F::NUM_BITS as usize
        {
            FieldBounds::All
        } else {
            let (a_min, a_max) = a_bounds.min_and_max(true);
            let (b_min, b_max) = b_bounds.min_and_max(true);
            let (min, max) = if b_min.is_ge_zero() {
                (
                    // the lhs accounts for the case a_min > 0 and the rhs for the case a_min < 0
                    (self.div_public(a_min, b_max) - self.eval_gap(&[a_min, b_max])).min(
                        self.div_public(a_min, b_min.max(F::ONE, true))
                            - self.eval_gap(&[a_min, b_min.max(F::ONE, true)]),
                        true,
                    ),
                    // the lhs accounts for the case a_max > 0 and the rhs for the case a_max < 0
                    (self.div_public(a_max, b_min.max(F::ONE, true))
                        + self.eval_gap(&[a_max, b_min.max(F::ONE, true)]))
                    .max(
                        self.div_public(a_max, b_max) + self.eval_gap(&[a_max, b_max]),
                        true,
                    ),
                )
            } else if b_max.is_le_zero() {
                (
                    // the lhs accounts for the case a_max > 0 and the rhs for the case a_max < 0
                    (self.div_public(a_max, b_max.min(-F::ONE, true))
                        - self.eval_gap(&[a_max, b_max.min(-F::ONE, true)]))
                    .min(
                        self.div_public(a_max, b_min) - self.eval_gap(&[a_max, b_min]),
                        true,
                    ),
                    // the lhs accounts for the case a_min > 0 and the rhs for the case a_min < 0
                    (self.div_public(a_min, b_min) + self.eval_gap(&[a_min, b_min])).max(
                        self.div_public(a_min, b_max.min(-F::ONE, true))
                            + self.eval_gap(&[a_min, b_max.min(-F::ONE, true)]),
                        true,
                    ),
                )
            } else {
                // [b_min, b_max] contains both -1 and 1
                let extr = (self.div_public(a_max.abs(), F::ONE)
                    + self.eval_gap(&[a_max.abs(), F::ONE]))
                .max(
                    self.div_public(a_min.abs(), F::ONE) + self.eval_gap(&[a_min.abs(), F::ONE]),
                    true,
                );
                (-extr, extr)
            };
            if b_bounds.contains(F::ZERO) {
                FieldBounds::new(min.min(F::ZERO, true), max.max(F::ZERO, true))
            } else {
                FieldBounds::new(min, max)
            }
        }
    }
}

impl<F: UsedField> ArithmeticCircuit<F> for Div {
    /// Evaluates the division on public values `x = [a, b]`, returning `[a / b]` (0 if b is 0).
    ///
    /// # Errors
    ///
    /// Returns `EvalFailure::err_ub("input out of range")` when `a` or `b` is too wide for the
    /// quotient at `precision_out` fractional bits to fit in the field. This mirrors the
    /// width test in `div_bounds`.
    ///
    /// # Panics
    ///
    /// Panics unless exactly two inputs are supplied.
    fn eval(&self, x: Vec<F>) -> Result<Vec<F>, EvalFailure> {
        if x.len() != 2 {
            panic!("div requires two inputs")
        }
        let a = x[0];
        let b = x[1];
        if a.signed_bits() + self.precision_out + self.precision_b - self.precision_a
            >= F::NUM_BITS as usize
            || b.signed_bits() + self.precision_out > F::NUM_BITS as usize
        {
            return EvalFailure::err_ub("input out of range");
        }
        Ok(vec![self.div_public(a, b)])
    }

    /// Error margin for inputs `x = [a, b]`: `|a / b|` shifted right by 42 bits, but never less
    /// than 2. `div_bounds` uses it to widen the output range.
    fn eval_gap(&self, x: &[F]) -> F {
        (self.div_public(x[0], x[1]).abs() >> 42).max(F::from(2), true)
    }

    /// Output bounds for input bounds `[a_bounds, b_bounds]`.
    fn bounds(&self, bounds: Vec<FieldBounds<F>>) -> Vec<FieldBounds<F>> {
        vec![self.div_bounds(bounds[0], bounds[1])]
    }

    /// Builds the division circuit for inputs `[a, b]`.
    ///
    /// # Panics
    ///
    /// Panics unless exactly two inputs are supplied.
    fn run(&self, vals: Vec<FieldValue<F>>) -> Vec<FieldValue<F>>
    where
        F: ActuallyUsedField,
    {
        if vals.len() != 2 {
            // same message as `eval` (previously contained the typo "input2")
            panic!("div requires two inputs")
        }
        vec![self.div(vals[0], vals[1])]
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        core::circuits::traits::arithmetic_circuit::tests::TestedArithmeticCircuit,
        utils::field::ScalarField,
    };
    use rand::Rng;

    impl TestedArithmeticCircuit<ScalarField> for Div {
        /// Draws a random `Div`. Each input precision starts at 52 and drops by one for every
        /// consecutive successful coin flip, so high precisions dominate while lower ones are
        /// still exercised. The precision of `a` is drawn before that of `b`.
        fn gen_desc<R: Rng + ?Sized>(rng: &mut R) -> Self {
            let mut draw_precision = || {
                let drops = std::iter::repeat_with(|| rng.gen_bool(0.5))
                    .take_while(|&heads| heads)
                    .count();
                (52 - drops as i64) as usize
            };
            let precision_a = draw_precision();
            let precision_b = draw_precision();
            Self::new(precision_a, precision_b)
        }

        /// Division always takes exactly two operands.
        fn gen_n_inputs<R: Rng + ?Sized>(&self, _rng: &mut R) -> usize {
            2
        }

        /// Fractional bits of the dividend and divisor, in input order.
        fn input_precisions(&self) -> Vec<usize> {
            [self.precision_a, self.precision_b].to_vec()
        }
    }

    #[test]
    fn tested_div() {
        Div::test(16, 4)
    }
}